├── .gitignore
├── LICENSE
├── README.md
├── imgs
│   └── title.webp
├── requirements.txt
└── src
    ├── __init__.py
    ├── config
    │   ├── aug
    │   │   ├── none.yaml
    │   │   └── resize512.yaml
    │   ├── config.yaml
    │   ├── ds
    │   │   ├── test.yaml
    │   │   ├── train.yaml
    │   │   └── valid.yaml
    │   └── runtime
    │       ├── bilateral_upsample_net.default.yaml
    │       ├── hist_unet.default.yaml
    │       ├── lcdpnet.default.yaml
    │       └── lcdpnet.release.yaml
    ├── data
    │   ├── __init__.py
    │   ├── augmentation.py
    │   └── img_dataset.py
    ├── env.yaml
    ├── globalenv.py
    ├── model
    │   ├── __init__.py
    │   ├── arch
    │   │   ├── __init__.py
    │   │   ├── drconv.py
    │   │   ├── hist.py
    │   │   ├── nonlocal_block_embedded_gaussian.py
    │   │   └── unet_based
    │   │       ├── __init__.py
    │   │       └── hist_unet.py
    │   ├── basemodel.py
    │   ├── basic_loss.py
    │   ├── bilateralupsamplenet.py
    │   ├── lcdpnet.py
    │   └── single_net_basemodel.py
    ├── test.py
    ├── train.py
    └── utils
        ├── __init__.py
        └── util.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 | imgs/title.png
9 | .DS_Store
10 |
11 | # Distribution / packaging
12 | .Python
13 | build/
14 | develop-eggs/
15 | dist/
16 | downloads/
17 | eggs/
18 | .eggs/
19 | lib/
20 | lib64/
21 | parts/
22 | sdist/
23 | var/
24 | wheels/
25 | pip-wheel-metadata/
26 | share/python-wheels/
27 | *.egg-info/
28 | .installed.cfg
29 | *.egg
30 | MANIFEST
31 |
32 | # PyInstaller
33 | # Usually these files are written by a python script from a template
34 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
35 | *.manifest
36 | *.spec
37 |
38 | # Installer logs
39 | pip-log.txt
40 | pip-delete-this-directory.txt
41 |
42 | # Unit test / coverage reports
43 | htmlcov/
44 | .tox/
45 | .nox/
46 | .coverage
47 | .coverage.*
48 | .cache
49 | nosetests.xml
50 | coverage.xml
51 | *.cover
52 | *.py,cover
53 | .hypothesis/
54 | .pytest_cache/
55 |
56 | # Translations
57 | *.mo
58 | *.pot
59 |
60 | # Django stuff:
61 | *.log
62 | local_settings.py
63 | db.sqlite3
64 | db.sqlite3-journal
65 |
66 | # Flask stuff:
67 | instance/
68 | .webassets-cache
69 |
70 | # Scrapy stuff:
71 | .scrapy
72 |
73 | # Sphinx documentation
74 | docs/_build/
75 |
76 | # PyBuilder
77 | target/
78 |
79 | # Jupyter Notebook
80 | .ipynb_checkpoints
81 |
82 | # IPython
83 | profile_default/
84 | ipython_config.py
85 |
86 | # pyenv
87 | .python-version
88 |
89 | # pipenv
90 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
91 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
92 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
93 | # install all needed dependencies.
94 | #Pipfile.lock
95 |
96 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow
97 | __pypackages__/
98 |
99 | # Celery stuff
100 | celerybeat-schedule
101 | celerybeat.pid
102 |
103 | # SageMath parsed files
104 | *.sage.py
105 |
106 | # Environments
107 | .env
108 | .venv
109 | env/
110 | venv/
111 | ENV/
112 | env.bak/
113 | venv.bak/
114 |
115 | # Spyder project settings
116 | .spyderproject
117 | .spyproject
118 |
119 | # Rope project settings
120 | .ropeproject
121 |
122 | # mkdocs documentation
123 | /site
124 |
125 | # mypy
126 | .mypy_cache/
127 | .dmypy.json
128 | dmypy.json
129 |
130 | # Pyre type checker
131 | .pyre/
132 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2022 hywang99
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 |
6 |
7 |
8 |
9 |
10 |
11 | [`🌐 Website`](https://whyy.site/paper/lcdp) · [`📃 Paper`](https://www.cs.cityu.edu.hk/~rynson/papers/eccv22b.pdf) · [`🗃️ Dataset`](https://drive.google.com/drive/folders/10Reaq-N0DiZiFpSrZ8j5g3g0EJes4JiS?usp=sharing)
12 |
13 |
14 |
15 | **Abstract:** Existing image enhancement methods are typically designed to address either the over- or under-exposure problem in the input image. When the illumination of the input image contains both over- and under-exposure problems, these existing methods may not work well. We observe from the image statistics that the local color distributions (LCDs) of an image suffering from both problems tend to vary across different regions of the image, depending on the local illuminations. Based on this observation, we propose in this paper to exploit these LCDs as a prior for locating and enhancing the two types of regions (i.e., over-/under-exposed regions). First, we leverage the LCDs to represent these regions, and propose a novel local color distribution embedded (LCDE) module to formulate LCDs in multi-scales to model the correlations across different regions. Second, we propose a dual-illumination learning mechanism to enhance the two types of regions. Third, we construct a new dataset to facilitate the learning process, by following the camera image signal processing (ISP) pipeline to render standard RGB images with both under-/over-exposures from raw data. Extensive experiments demonstrate that the proposed method outperforms existing state-of-the-art methods quantitatively and qualitatively.
16 |
17 | ## 📻 News
18 |
19 | - 2023.7.21: If you are interested in low-light enhancement and NeRF, please check out my latest ICCV 2023 work, [LLNeRF](https://github.com/onpix/LLNeRF)! 🔥🔥🔥
20 | - 2023.7.21: Update the README.
21 | - 2023.2.7: Merge the `tar.gz` files of our dataset into a single `7z` file.
22 | - 2023.2.8: Update package versions in `requirements.txt`.
23 | - 2023.2.8: Upload `env.yaml`.
24 |
25 | ## 🔥 Our Model
26 |
27 | 
28 |
29 |
30 | ## ⚙️ Setup
31 |
32 | 1. Clone the repository: `git clone https://github.com/onpix/LCDPNet.git`
33 | 2. Enter the directory: `cd LCDPNet`
34 | 3. Install the required packages: `pip install -r requirements.txt`
35 |
36 | We also provide `env.yaml` for quickly installing packages. Note that you may need to change the env name to avoid overwriting your existing environment, and to adjust the cudatoolkit and cudnn versions in `env.yaml` to match your local CUDA version.
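For example, a minimal way to use it might be (the env name `lcdpnet` below is only a placeholder, not a name defined by this repo):

```bash
# Edit src/env.yaml first:
#   - change "name: base" to e.g. "name: lcdpnet" so your base env is not overwritten
#   - adjust the pinned cudatoolkit / cudnn versions to match your local CUDA
conda env create -f src/env.yaml
conda activate lcdpnet
```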
37 |
38 | ## ⌨️ How to run
39 |
40 | To train our model:
41 |
42 | 1. Prepare data: Modify `src/config/ds/train.yaml` and `src/config/ds/valid.yaml`.
43 | 2. Modify configs in `src/config`. Note that we use `hydra` for config management.
44 | 3. Run: `python src/train.py name= num_epoch=200 log_every=2000 valid_every=20`
45 |
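For reference, a filled-in training command might look like the following. The run name `my_experiment` is just a placeholder, and every `key=value` pair is a standard hydra override of the options in `src/config/config.yaml`:

```bash
# hypothetical example: replace "my_experiment" with your own run name
python src/train.py name=my_experiment num_epoch=200 log_every=2000 valid_every=20 batchsize=16 lr=1e-4
```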
46 | To test our model:
47 |
48 | 1. Prepare data: Modify `src/config/ds/test.yaml`
49 | 2. Run: `python src/test.py checkpoint_path=`
50 |
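Likewise, a test invocation might look like this; the checkpoint path is illustrative, so point `checkpoint_path` at wherever you saved the downloaded or trained `.ckpt` file:

```bash
# hypothetical example: adjust the path to your local checkpoint
python src/test.py checkpoint_path=pretrained_models/trained_on_ours.ckpt
```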
51 | ## 📂 Dataset & Pretrained Model
52 |
53 | The LCDP Dataset is available here: [[Google Drive]](https://drive.google.com/drive/folders/10Reaq-N0DiZiFpSrZ8j5g3g0EJes4JiS?usp=sharing). Please unzip `lcdp_dataset.7z`. The training and test images are:
54 |
55 | | | Train | Test |
56 | | ----- | ------------- | ------------------ |
57 | | Input | `input/*.png` | `test-input/*.png` |
58 | | GT | `gt/*.png` | `test-gt/*.png` |
59 |
60 | We provide two pretrained models, `pretrained_models/trained_on_ours.ckpt` and `pretrained_models/trained_on_MSEC.ckpt`, so that researchers can reproduce the results in Tables 1 and 2 of our paper. Note that `pretrained_models/trained_on_MSEC.ckpt` is trained on the Expert C subset of the MSEC dataset, which contains both over- and under-exposed images.
61 |
62 | | Filename | Training data | Testing data | Test PSNR | Test SSIM |
63 | | -------------------- | ------------------------------------------------------------ | ---------------------------- | --------- | --------- |
64 | | trained_on_ours.ckpt | Ours | Our testing data | 23.239 | 0.842 |
65 | | trained_on_MSEC.ckpt | [MSEC](https://github.com/mahmoudnafifi/Exposure_Correction) | MSEC testing data (Expert C) | 22.295 | 0.855 |
66 |
67 | Our model is lightweight. Experiments show that increasing the model size further improves the quality of the results. To train a larger model, increase the values of `runtime.bilateral_upsample_net.hist_unet.channel_nums`.
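Because `hydra` resolves dotted keys, this can also be done from the command line without editing the YAML files. A sketch (the widths below simply double each entry of the release config's `[8,16,32,64,128]`, and `my_wide_experiment` is a placeholder run name):

```bash
# hypothetical override: doubles every channel width of the release config
python src/train.py name=my_wide_experiment \
    'runtime.bilateral_upsample_net.hist_unet.channel_nums=[16,32,64,128,256]'
```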
68 |
69 | ## 🔗 Cite This Paper
70 |
71 | If you find our work or code helpful, or your research benefits from this repo, please cite our paper:
72 |
73 | ```bibtex
74 | @inproceedings{wang2022lcdp,
75 | title = {Local Color Distributions Prior for Image Enhancement},
76 | author = {Wang, Haoyuan and Xu, Ke and Lau, Rynson W.H.},
77 | booktitle = {Proceedings of the European Conference on Computer Vision (ECCV)},
78 | year = {2022}
79 | }
80 | ```
--------------------------------------------------------------------------------
/imgs/title.webp:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/imgs/title.webp
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | ipdb == 0.13.9
2 | hydra-core == 1.1.1
3 | pytorch_lightning == 1.7.6
4 | scikit-image == 0.19.2
5 | kornia == 0.6.7
6 | wandb == 0.13.8
7 | opencv-python == 4.5.5.64
8 | matplotlib == 3.5.1
9 | torchvision == 0.13.1
10 |
--------------------------------------------------------------------------------
/src/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/src/__init__.py
--------------------------------------------------------------------------------
/src/config/aug/none.yaml:
--------------------------------------------------------------------------------
1 | # @package aug
2 |
3 |
4 | # crop size:
5 | # - false for not cropping;
6 | # - int n for cropping a square of size [n, n];
7 | # - [h, w] for cropping a specific size.
8 | crop: false
9 |
10 |
11 | # downsample factor.
12 | # - value > 1 for downsample;
13 | # - 0.x for upsample;
14 | # - [h, w] for specific size;
15 | # - false for not resizing.
16 | downsample: false
17 |
18 | h-flip: false
19 | v-flip: false
20 |
21 | lightness_adjust: false
22 | contrast_adjust: false
23 |
24 | resize_divisible_n: false # resize the input so that its size is divisible by N, to avoid size-mismatch errors.
--------------------------------------------------------------------------------
/src/config/aug/resize512.yaml:
--------------------------------------------------------------------------------
1 | # @package aug
2 |
3 | # crop size:
4 | # - false for not cropping;
5 | # - int n for cropping a square of size [n, n];
6 | # - [h, w] for cropping a specific size.
7 | crop: false
8 |
9 |
10 | # downsample factor.
11 | # - value > 1 for downsample;
12 | # - 0.x for upsample;
13 | # - [h, w] for specific size;
14 | # - false for not resizing.
15 | downsample: [ 512, 512 ]
16 |
17 | h-flip: true
18 | v-flip: true
--------------------------------------------------------------------------------
/src/config/config.yaml:
--------------------------------------------------------------------------------
1 | project: default_proj
2 | name: default_name # name of experiment
3 | comment: false
4 | debug: false
5 | val_debug_step_nums: 2 # sanity check num
6 | gpu: -1 # number of gpus to use (int) or which GPUs to use (list or str)
7 | backend: ddp # gpu accelerator. value: ddp or none
8 | runtime_precision: 16
9 | amp_backend: native # native or apex
10 | amp_level: O1
11 | dataloader_num_worker: 5
12 | mode: train
13 | logger: tb
14 |
15 | # frequently changed configs:
16 | num_epoch: 1000
17 | valid_every: 10 # validate every N EPOCHS
18 | savemodel_every: 4 # run ModelCheckpoint every N EPOCHS
19 | log_every: 100 # log your message, curve or images every N STEPS
20 | batchsize: 16
21 | valid_batchsize: 1
22 | lr: 1e-4
23 | checkpoint_path: null
24 |
25 | checkpoint_monitor: loss
26 | resume_training: true
27 | monitor_mode: min
28 | early_stop: false
29 | valid_ratio: 0.1
30 |
31 | flags: { }
32 |
33 | defaults:
34 | - aug: resize512
35 | - ds@train_ds: train
36 | - ds@test_ds: test
37 | - ds@valid_ds: valid
38 | - runtime: lcdpnet.release
39 |
40 | hydra:
41 | run:
42 | dir: ./
43 |
--------------------------------------------------------------------------------
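Note: the `defaults` list in `config.yaml` above composes hydra config groups from the `aug`, `ds` and `runtime` subdirectories, so any group can be swapped from the command line without editing files. A hedged illustration (these commands are examples, not taken from the repo):

```bash
# pick the no-op augmentation and the default (non-release) runtime config
python src/train.py aug=none runtime=lcdpnet.default

# explicitly keep the shipped defaults (equivalent to overriding nothing)
python src/train.py aug=resize512 runtime=lcdpnet.release
```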
/src/config/ds/test.yaml:
--------------------------------------------------------------------------------
1 | # @package ds
2 | class: img_dataset
3 | name: lcdp_data.test
4 | input:
5 | - your_dataset_path/test-input/*
6 | GT:
7 | - your_dataset_path/test-gt/*
8 |
--------------------------------------------------------------------------------
/src/config/ds/train.yaml:
--------------------------------------------------------------------------------
1 | # @package ds
2 | class: img_dataset
3 | name: lcdp_data.train
4 | input:
5 | - your_dataset_path/input/*
6 | GT:
7 | - your_dataset_path/gt/*
8 |
--------------------------------------------------------------------------------
/src/config/ds/valid.yaml:
--------------------------------------------------------------------------------
1 | # @package ds
2 | class: img_dataset
3 | name: lcdp_data.valid
4 | input:
5 | - your_dataset_path/valid-input/*
6 | GT:
7 | - your_dataset_path/valid-gt/*
8 |
--------------------------------------------------------------------------------
/src/config/runtime/bilateral_upsample_net.default.yaml:
--------------------------------------------------------------------------------
1 | # @package runtime
2 |
3 | modelname: bilateral_upsample_net
4 | predict_illumination: false
5 | loss: # when self_supervised is true, this option will be ignored.
6 | mse: 1.0
7 | cos: 0.1
8 | ltv: 0.1 # only matters when predict_illumination is true.
9 |
10 | luma_bins: 8
11 | channel_multiplier: 1
12 | spatial_bin: 16
13 | batch_norm: true
14 | low_resolution: 256
15 | coeffs_type: matrix # selected from [matrix, gamma, retinex]
16 |
17 | conv_type: conv
18 |
19 | # choose from: [ori, hist-unet]
20 | backbone: ori
21 |
22 | # type: false or int.
23 | # if set to N, apply illu_map **= N to adjust the brightness of the output.
24 | illu_map_power: false
25 |
26 | # only works when using hist-unet
27 | defaults:
28 | - hist_unet.default@hist_unet
29 |
--------------------------------------------------------------------------------
/src/config/runtime/hist_unet.default.yaml:
--------------------------------------------------------------------------------
1 | n_bins: 8
2 | hist_as_guide: false
3 | channel_nums: false
4 | encoder_use_hist: false
5 | guide_feature_from_hist: false
6 | region_num: 8
7 | use_gray_hist: false
8 | conv_type: drconv
9 | down_ratio: 2
10 | hist_conv_trainable: false
11 | drconv_position: [ 1,1 ]
--------------------------------------------------------------------------------
/src/config/runtime/lcdpnet.default.yaml:
--------------------------------------------------------------------------------
1 | # @package runtime
2 |
3 | modelname: lcdpnet
4 | use_wavelet: false
5 | use_attn_map: false
6 | use_non_local: false
7 |
8 | how_to_fuse: cnn-weights
9 |
10 | backbone: unet
11 |
12 | # choose from : [conv, drconv]
13 | conv_type: conv
14 |
15 | # If true, the output of the backbone is the illumination map.
16 | # If false, the output of the backbone is used directly to darken & brighten the input.
17 | backbone_out_illu: true
18 |
19 | illumap_channel: 3 # 1 or 3
20 |
21 | # 2 branches share the same weights to predict the illu map.
22 | share_weights: true
23 |
24 | # only works when using hist-unet
25 | n_bins: 8
26 | hist_as_guide: false
27 |
28 | loss:
29 | ltv: 0 # ltv applied on the FINAL OUTPUT.
30 | cos: 0.5 # cos similarity loss
31 | weighted_loss: 0 # weighted loss instead of l1loss
32 | tvloss1: 0.01 # tvloss applied on the illumination
33 | tvloss2: 0.01 # tvloss applied on the inverse illumination
34 | tvloss1_new: 0
35 | tvloss2_new: 0
36 |
37 | l1_loss: 1.0 # default pixel-wise l1 loss
38 | ssim_loss: 0
39 | psnr_loss: 0
40 | illumap_loss: 0 # constrain illumap1 + illumap2 -> 1
41 | hist_loss: 0
42 | inter_hist_loss: 0
43 | vgg_loss: 0
44 |
45 | defaults:
46 | - bilateral_upsample_net.default@bilateral_upsample_net
47 | - hist_unet.default@hist_unet
--------------------------------------------------------------------------------
/src/config/runtime/lcdpnet.release.yaml:
--------------------------------------------------------------------------------
1 | # @package runtime
2 |
3 | defaults:
4 | - lcdpnet.default
5 |
6 | loss:
7 | ltv: 0 # ltv applied on the FINAL OUTPUT.
8 | cos: 0 # cos similarity loss
9 | cos2: 0.5 # cos with sigmoid
10 | tvloss1: 0 # tvloss applied on the illumination
11 | tvloss2: 0 # tvloss applied on the inverse illumination
12 | l1_loss: 1.0 # default pixel-wise l1 loss
13 | tvloss1_new: 0.01
14 | tvloss2_new: 0.01
15 |
16 | backbone: bilateral_upsample_net
17 |
18 | bilateral_upsample_net:
19 | backbone: hist-unet
20 | hist_unet:
21 | guide_feature_from_hist: true
22 | region_num: 2
23 | drconv_position: [ 0,1 ]
24 | channel_nums: [ 8,16,32,64,128 ]
--------------------------------------------------------------------------------
/src/data/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/src/data/__init__.py
--------------------------------------------------------------------------------
/src/data/augmentation.py:
--------------------------------------------------------------------------------
1 | # import cv2
2 | from collections.abc import Iterable
3 |
4 | import numpy as np
5 | import torchvision.transforms.functional as F
6 | from torchvision import transforms
7 |
8 | from globalenv import *
9 |
10 |
11 | class Downsample:
12 | def __init__(self, downsample_factor=None):
13 | self.downsample_factor = downsample_factor
14 |
15 | if isinstance(self.downsample_factor, Iterable):
16 | # should be [h, w]
17 | assert len(downsample_factor) == 2
18 |
19 | def __call__(self, img):
20 | '''
21 | img: passed by the previous transforms. PIL image or np.ndarray
22 | '''
23 | origin_h = img.size[1]
24 | origin_w = img.size[0]
25 | if isinstance(self.downsample_factor, Iterable):
26 | # pass [h,w]
27 | if -1 in self.downsample_factor:
28 | # automatically calculate the output size:
29 | h_scale = origin_h / self.downsample_factor[0]
30 | w_scale = origin_w / self.downsample_factor[1]
31 |
32 | # choose the correct one
33 | scale = max(w_scale, h_scale)
34 | new_size = [
35 | int(origin_h / scale), # H
36 | int(origin_w / scale) # W
37 | ]
38 | else:
39 | new_size = self.downsample_factor # [H, W]
40 |
41 | elif type(self.downsample_factor + 0.1) == float:
42 | # pass a number as scale factor
43 | # PIL.Image, cv2.resize and torchvision.transforms.Resize all accept [W, H]
44 | new_size = [
45 | int(img.size[1] / self.downsample_factor), # H
46 | int(img.size[0] / self.downsample_factor) # W
47 | ]
48 | else:
49 | raise RuntimeError(f'ERR: Wrong config aug.downsample: {self.downsample_factor}')
50 |
51 | img = img.resize(new_size[::-1]) # reverse passed [h, w] to [w, h]
52 | return img
53 |
54 | def __repr__(self):
55 | return self.__class__.__name__ + f'({self.downsample_factor})'
56 |
57 |
58 | def get_value(d, k):
59 | if k in d and d[k]:
60 | return d[k]
61 | else:
62 | return False
63 |
64 |
65 | def parseAugmentation(opt):
66 | '''
67 | return: pytorch composed transform
68 | '''
69 | aug_config = opt[AUGMENTATION]
70 | aug_list = [transforms.ToPILImage(), ]
71 |
72 | # the order is fixed:
73 | augmentaionFactory = {
74 | DOWNSAMPLE: Downsample(aug_config[DOWNSAMPLE])
75 | if get_value(aug_config, DOWNSAMPLE) else None,
76 | CROP: transforms.RandomCrop(aug_config[CROP])
77 | if get_value(aug_config, CROP) else None,
78 | HORIZON_FLIP: transforms.RandomHorizontalFlip(),
79 | VERTICAL_FLIP: transforms.RandomVerticalFlip(),
80 | }
81 |
82 | for k, v in augmentaionFactory.items():
83 | if get_value(aug_config, k):
84 | aug_list.append(v)
85 |
86 | aug_list.append(transforms.ToTensor())
87 | print('Dataset augmentation:')
88 | print(aug_list)
89 | return transforms.Compose(aug_list)
90 |
91 |
92 | if __name__ == '__main__':
93 | pass
94 |
--------------------------------------------------------------------------------
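Note: to make concrete how the `aug` configs above map onto this transform pipeline, here is a small illustrative sketch (it mirrors `aug/resize512.yaml` and the key names in `globalenv.py`; it is not code from the repo):

```python
# Illustration only (not repo code). Assumes it is run from the repository root so that
# src/ can be put on sys.path, mirroring how the repo's entry points import modules.
import sys

sys.path.insert(0, 'src')

import numpy as np
from data.augmentation import parseAugmentation

# Keys mirror globalenv.py: AUGMENTATION='aug', DOWNSAMPLE='downsample', CROP='crop',
# HORIZON_FLIP='h-flip', VERTICAL_FLIP='v-flip'; values mirror aug/resize512.yaml.
opt = {'aug': {'downsample': [512, 512], 'crop': False, 'h-flip': True, 'v-flip': True}}

transform = parseAugmentation(opt)  # ToPILImage -> Downsample -> random flips -> ToTensor
img = np.random.randint(0, 255, size=(600, 800, 3), dtype=np.uint8)  # dummy H x W x 3 image
print(transform(img).shape)  # torch.Size([3, 512, 512])
```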
/src/data/img_dataset.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import random
3 | from collections.abc import Iterable
4 | from glob import glob
5 |
6 | import cv2
7 | import ipdb
8 | import numpy as np
9 | import torch
10 | import torchvision
11 | from pytorch_lightning import LightningDataModule
12 |
13 | from globalenv import *
14 | from .augmentation import parseAugmentation
15 | from utils import util  # needed by ImagesDataset.debug_save_item below
16 |
17 | def parse_item_list_txt(txtpath):
18 | '''
19 | Parse txt file containing all file paths. Each line one file.
20 | '''
21 | txt = Path(txtpath)
22 | basedir = txt.parent
23 | content = txt.open().read().splitlines()
24 |
25 | sample = content[0]
26 | if sample.split('/')[0] in str(basedir):
27 | raise NotImplementedError('Not implemented: file path in txt and basedir have a common prefix.')
28 | else:
29 | assert (basedir / sample).exists()
30 | return [str(basedir / x) for x in content]
31 |
32 |
33 | def load_from_glob_list(globs):
34 | if type(globs) == str:
35 | if globs.endswith('.txt'):
36 | print(f'Parse txt file: {globs}')
37 | return parse_item_list_txt(globs)
38 | else:
39 | return sorted(glob(globs))
40 |
41 | elif isinstance(globs, Iterable):
42 | # if images are glob lists, sort EACH list **individually**
43 | res = []
44 | for g in globs:
45 | assert not g.endswith('.txt'), 'TXT file should not in glob list.'
46 | res.extend(sorted(glob(g)))
47 | return res
48 | else:
49 | ipdb.set_trace()
50 | raise TypeError(
51 | f'ERR: `ds.GT` or `ds.input` has wrong type: expect `str` or `list` but get {type(globs)}')
52 |
53 |
54 | def augment_one_img(img, seed, transform=None):
55 | img = img.astype(np.uint8)
56 | random.seed(seed)
57 | torch.manual_seed(seed)
58 | if transform:
59 | img = transform(img)
60 | return img
61 |
62 |
63 | class ImagesDataset(torch.utils.data.Dataset):
64 |
65 | def __init__(self, opt, ds_type=TRAIN_DATA, transform=None, batchsize=None):
66 | """Initialisation for the Dataset object
67 | transform: PyTorch image transformations to apply to the images
68 | """
69 | self.transform = transform
70 | self.opt = opt
71 |
72 | gt_globs = opt[ds_type][GT]
73 | input_globs = opt[ds_type][INPUT]
74 | self.have_gt = True if gt_globs else False
75 |
76 | print(f'{ds_type} - GT Directory path: [yellow]{gt_globs}[/yellow]')
77 | print(f'{ds_type} - Input Directory path: [yellow]{input_globs}[/yellow]')
78 |
79 | # load input images:
80 | self.input_list = load_from_glob_list(input_globs)
81 |
82 | # load GT images:
83 | if self.have_gt:
84 | self.gt_list = load_from_glob_list(gt_globs)
85 | try:
86 | assert len(self.input_list) == len(self.gt_list)
87 | except:
88 | ipdb.set_trace()
89 | raise AssertionError(
90 | f'In [{ds_type}]: len(input_images) ({len(self.input_list)}) != len(gt_images) ({len(self.gt_list)})! ')
91 |
92 | # import ipdb; ipdb.set_trace()
93 | print(
94 | f'{ds_type} Dataset length: {self.__len__()}, batch num: {self.__len__() // batchsize}')
95 |
96 | if self.__len__() == 0:
97 | print(f'Error occurred! Your ds is: TYPE={ds_type}, config:')
98 | print(opt[ds_type])
99 | raise RuntimeError(f'[ Err ] Dataset input nums is 0!')
100 |
101 | def __len__(self):
102 | return (len(self.input_list))
103 |
104 | def debug_save_item(self, input, gt):
105 | # home = os.environ['HOME']
106 | util.saveTensorAsImg(input, 'i.png')
107 | util.saveTensorAsImg(gt, 'o.png')
108 |
109 | def __getitem__(self, idx):
110 | """Returns a pair of images with the given identifier. This is lazy loading
111 | of data into memory. Only those image pairs needed for the current batch
112 | are loaded.
113 |
114 | :param idx: image pair identifier
115 | :returns: dictionary containing input and output images and their identifier
116 | :rtype: dictionary
117 |
118 | """
119 | res_item = {INPUT_FPATH: self.input_list[idx]}
120 |
121 | # different seed for different item, but same for GT and INPUT in one item:
122 | # the "seed of seed" is fixed for reproducing
123 | # random.seed(GLOBAL_SEED)
124 | seed = random.randint(0, 100000)
125 | input_img = cv2.imread(self.input_list[idx])[:, :, [2, 1, 0]]
126 | if self.have_gt and self.gt_list[idx].endswith('.hdr'):
127 | input_img = torch.Tensor(input_img / 255).permute(2, 0, 1)
128 | else:
129 | input_img = augment_one_img(input_img, seed, transform=self.transform)
130 | res_item[INPUT] = input_img
131 |
132 | if self.have_gt:
133 | res_item[GT_FPATH] = self.gt_list[idx]
134 |
135 | if res_item[GT_FPATH].endswith('.hdr'):
136 | # gt may be HDR
137 | # do not augment HDR image.
138 | gt_img = cv2.imread(self.gt_list[idx], flags=cv2.IMREAD_ANYDEPTH)[:, :, [2, 1, 0]]
139 | gt_img = torch.Tensor(np.log10(gt_img + 1)).permute(2, 0, 1)
140 | else:
141 | gt_img = cv2.imread(self.gt_list[idx])[:, :, [2, 1, 0]]
142 | gt_img = augment_one_img(gt_img, seed, transform=self.transform)
143 |
144 | res_item[GT] = gt_img
145 | assert res_item[GT].shape == res_item[INPUT].shape
146 |
147 | return res_item
148 |
149 |
150 | class DataModule(LightningDataModule):
151 | def __init__(self, opt, apply_test_transform=False, apply_valid_transform=False):
152 | super().__init__()
153 | self.opt = opt
154 | self.transform = parseAugmentation(opt)
155 | # self.transform = None
156 | if apply_test_transform:
157 | self.test_transform = self.transform
158 | else:
159 | self.test_transform = torchvision.transforms.ToTensor()
160 |
161 | if apply_valid_transform:
162 | self.valid_transform = self.transform
163 | else:
164 | self.valid_transform = torchvision.transforms.ToTensor()
165 |
166 | self.training_dataset = None
167 | self.valid_dataset = None
168 | self.test_dataset = None
169 |
170 | def prepare_data(self):
171 | # download, split, etc...
172 | # only called on 1 GPU/TPU in distributed
173 | ...
174 |
175 | def setup(self, stage):
176 | opt = self.opt
177 | # from train.py
178 | if stage == "fit":
179 | # if opt[TRAIN_DATA]:
180 | assert opt[TRAIN_DATA][INPUT]
181 | self.training_dataset = ImagesDataset(opt, ds_type=TRAIN_DATA, transform=self.transform, batchsize=opt.batchsize)
182 |
183 | # valid data provided:
184 | if opt[VALID_DATA] and opt[VALID_DATA][INPUT]:
185 | self.valid_dataset = ImagesDataset(opt, ds_type=VALID_DATA, transform=self.valid_transform, batchsize=opt.valid_batchsize)
186 |
187 | # no valid data, split from training data:
188 | elif opt[VALID_RATIO]:
189 | print(f'Split valid dataset from training data. Ratio: {opt[VALID_RATIO]}')
190 | valid_size = int(opt[VALID_RATIO] * len(self.training_dataset))
191 | train_size = len(self.training_dataset) - valid_size
192 | torch.manual_seed(233)
193 | self.training_dataset, self.valid_dataset = torch.utils.data.random_split(self.training_dataset, [
194 | train_size, valid_size
195 | ])
196 | print(
197 | f'Update - training data: {len(self.training_dataset)}; valid data: {len(self.valid_dataset)}')
198 |
199 | # testing phase
200 | # if stage == 'test':
201 | if opt[TEST_DATA] and opt[TEST_DATA][INPUT]:
202 | self.test_dataset = ImagesDataset(opt, ds_type=TEST_DATA, transform=self.test_transform, batchsize=1)
203 |
204 | def train_dataloader(self):
205 | if self.training_dataset:
206 | trainloader = torch.utils.data.DataLoader(
207 | self.training_dataset,
208 | batch_size=self.opt[BATCHSIZE],
209 | num_workers=self.opt[DATALOADER_N],
210 | shuffle=True,
211 | drop_last=True,
212 | pin_memory=True
213 | )
214 | return trainloader
215 |
216 | def val_dataloader(self):
217 | if self.valid_dataset:
218 | return torch.utils.data.DataLoader(
219 | self.valid_dataset,
220 | batch_size=self.opt[VALID_BATCHSIZE],
221 | shuffle=False,
222 | num_workers=self.opt[DATALOADER_N]
223 | )
224 |
225 | def test_dataloader(self):
226 | if self.test_dataset:
227 | return torch.utils.data.DataLoader(
228 | self.test_dataset,
229 | batch_size=1,
230 | shuffle=False,
231 | num_workers=self.opt[DATALOADER_N],
232 | pin_memory=True
233 | )
234 |
235 | def teardown(self, stage):
236 | # clean up after fit or test
237 | # called on every process in DDP
238 | ...
239 |
240 |
241 | if __name__ == '__main__':
242 | ...
243 |
--------------------------------------------------------------------------------
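Note: a minimal sketch of how the `DataModule` above can be exercised once the LCDP dataset is available locally (illustration only; this is not the repo's actual `train.py`):

```python
# Illustration only (not repo code). Run from the repository root and point the globs
# below at the unzipped LCDP dataset before running.
import sys

sys.path.insert(0, 'src')

from omegaconf import OmegaConf
from data.img_dataset import DataModule

opt = OmegaConf.create({
    'aug': {'downsample': [512, 512], 'crop': False, 'h-flip': True, 'v-flip': True},
    'train_ds': {'name': 'lcdp_data.train',
                 'input': ['your_dataset_path/input/*'],
                 'GT': ['your_dataset_path/gt/*']},
    'valid_ds': {'input': None, 'GT': None},  # empty: a validation split is taken from train_ds via valid_ratio
    'test_ds': {'input': None, 'GT': None},
    'valid_ratio': 0.1,
    'batchsize': 16, 'valid_batchsize': 1, 'dataloader_num_worker': 5,
})

dm = DataModule(opt)
dm.setup('fit')                            # builds the train/valid ImagesDataset instances
batch = next(iter(dm.train_dataloader()))  # dict with keys 'input', 'GT', 'input_fpath', 'gt_fpath'
print(batch['input'].shape, batch['GT'].shape)
```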
/src/env.yaml:
--------------------------------------------------------------------------------
1 | name: base
2 | channels:
3 | - pytorch
4 | - bottler
5 | - conda-forge
6 | - defaults
7 | dependencies:
8 | - _anaconda_depends=2020.07=py38_0
9 | - _ipyw_jlab_nb_ext_conf=0.1.0=py38_0
10 | - _libgcc_mutex=0.1=main
11 | - alabaster=0.7.12=pyhd3eb1b0_0
12 | - anaconda=custom=py38_1
13 | - anaconda-client=1.7.2=py38_0
14 | - anaconda-navigator=2.0.3=py38_0
15 | - anaconda-project=0.9.1=pyhd3eb1b0_1
16 | - appdirs=1.4.4=py_0
17 | - argh=0.26.2=py38_0
18 | - argon2-cffi=20.1.0=py38h27cfd23_1
19 | - asn1crypto=1.4.0=py_0
20 | - astroid=2.5=py38h06a4308_1
21 | - astropy=4.2.1=py38h27cfd23_1
22 | - async_generator=1.10=pyhd3eb1b0_0
23 | - atomicwrites=1.4.0=py_0
24 | - attrs=20.3.0=pyhd3eb1b0_0
25 | - autopep8=1.5.6=pyhd3eb1b0_0
26 | - babel=2.9.0=pyhd3eb1b0_0
27 | - backcall=0.2.0=pyhd3eb1b0_0
28 | - backports=1.0=pyhd3eb1b0_2
29 | - backports.functools_lru_cache=1.6.4=pyhd3eb1b0_0
30 | - backports.shutil_get_terminal_size=1.0.0=pyhd3eb1b0_3
31 | - backports.tempfile=1.0=pyhd3eb1b0_1
32 | - backports.weakref=1.0.post1=py_1
33 | - beautifulsoup4=4.9.3=pyha847dfd_0
34 | - bitarray=2.1.0=py38h27cfd23_1
35 | - bkcharts=0.2=py38_0
36 | - black=19.10b0=py_0
37 | - blas=1.0=mkl
38 | - bleach=3.3.0=pyhd3eb1b0_0
39 | - blosc=1.21.0=h8c45485_0
40 | - bokeh=2.3.2=py38h06a4308_0
41 | - boto=2.49.0=py38_0
42 | - bottleneck=1.3.2=py38heb32a55_1
43 | - brotlipy=0.7.0=py38h27cfd23_1003
44 | - bzip2=1.0.8=h7b6447c_0
45 | - c-ares=1.17.1=h27cfd23_0
46 | - ca-certificates=2022.4.26=h06a4308_0
47 | - cairo=1.16.0=hf32fb01_1
48 | - certifi=2021.10.8=py38h06a4308_2
49 | - cffi=1.14.5=py38h261ae71_0
50 | - chardet=4.0.0=py38h06a4308_1003
51 | - cloudpickle=1.6.0=py_0
52 | - clyent=1.2.2=py38_1
53 | - colorama=0.4.4=pyhd3eb1b0_0
54 | - conda=4.12.0=py38h06a4308_0
55 | - conda-build=3.21.4=py38h06a4308_0
56 | - conda-content-trust=0.1.1=pyhd3eb1b0_0
57 | - conda-env=2.6.0=1
58 | - conda-package-handling=1.7.3=py38h27cfd23_1
59 | - conda-repo-cli=1.0.4=pyhd3eb1b0_0
60 | - conda-token=0.3.0=pyhd3eb1b0_0
61 | - conda-verify=3.4.2=py_1
62 | - contextlib2=0.6.0.post1=py_0
63 | - cryptography=3.4.7=py38hd23ed53_0
64 | - cudatoolkit=10.2.89=hfd86e86_1
65 | - cudatoolkit-dev=10.1.243=h516909a_3
66 | - curl=7.71.1=hbc83047_1
67 | - cython=0.29.23=py38h2531618_0
68 | - cytoolz=0.11.0=py38h7b6447c_0
69 | - dask=2021.4.0=pyhd3eb1b0_0
70 | - dask-core=2021.4.0=pyhd3eb1b0_0
71 | - dbus=1.13.18=hb2f20db_0
72 | - decorator=5.0.6=pyhd3eb1b0_0
73 | - defusedxml=0.7.1=pyhd3eb1b0_0
74 | - diff-match-patch=20200713=py_0
75 | - distributed=2021.4.1=py38h06a4308_0
76 | - docutils=0.17.1=py38h06a4308_1
77 | - entrypoints=0.3=py38_0
78 | - et_xmlfile=1.0.1=py_1001
79 | - expat=2.3.0=h2531618_2
80 | - fastcache=1.1.0=py38h7b6447c_0
81 | - ffmpeg=4.2.2=h20bf706_0
82 | - flake8=3.9.0=pyhd3eb1b0_0
83 | - flask=1.1.2=pyhd3eb1b0_0
84 | - fontconfig=2.13.1=h6c09931_0
85 | - freetype=2.10.4=h5ab3b9f_0
86 | - fribidi=1.0.10=h7b6447c_0
87 | - fsspec=2021.8.1=pyhd3eb1b0_0
88 | - future=0.18.2=py38_1
89 | - get_terminal_size=1.0.0=haa9412d_0
90 | - gevent=21.1.2=py38h27cfd23_1
91 | - glib=2.68.1=h36276a3_0
92 | - glob2=0.7=pyhd3eb1b0_0
93 | - gmp=6.2.1=h2531618_2
94 | - gmpy2=2.0.8=py38hd5f6e3b_3
95 | - gnutls=3.6.15=he1e5248_0
96 | - graphite2=1.3.14=h23475e2_0
97 | - greenlet=1.0.0=py38h2531618_2
98 | - gst-plugins-base=1.14.0=h8213a91_2
99 | - gstreamer=1.14.0=h28cd5cc_2
100 | - h5py=2.10.0=py38h7918eee_0
101 | - harfbuzz=2.8.0=h6f93f22_0
102 | - hdf5=1.10.4=hb1b8bf9_0
103 | - heapdict=1.0.1=py_0
104 | - html5lib=1.1=py_0
105 | - icu=58.2=he6710b0_3
106 | - idna=2.10=pyhd3eb1b0_0
107 | - imagesize=1.2.0=pyhd3eb1b0_0
108 | - importlib-metadata=3.10.0=py38h06a4308_0
109 | - importlib_metadata=3.10.0=hd3eb1b0_0
110 | - iniconfig=1.1.1=pyhd3eb1b0_0
111 | - intel-openmp=2021.2.0=h06a4308_610
112 | - intervaltree=3.1.0=py_0
113 | - ipykernel=5.3.4=py38h5ca1d4c_0
114 | - ipython=7.22.0=py38hb070fc8_0
115 | - ipython_genutils=0.2.0=pyhd3eb1b0_1
116 | - ipywidgets=7.6.3=pyhd3eb1b0_1
117 | - isort=5.8.0=pyhd3eb1b0_0
118 | - itsdangerous=1.1.0=pyhd3eb1b0_0
119 | - jbig=2.1=hdba287a_0
120 | - jdcal=1.4.1=py_0
121 | - jedi=0.17.2=py38h06a4308_1
122 | - jeepney=0.6.0=pyhd3eb1b0_0
123 | - joblib=1.0.1=pyhd3eb1b0_0
124 | - jpeg=9b=h024ee3a_2
125 | - json5=0.9.5=py_0
126 | - jsonschema=3.2.0=py_2
127 | - jupyter=1.0.0=py38_7
128 | - jupyter-packaging=0.7.12=pyhd3eb1b0_0
129 | - jupyter_client=6.1.12=pyhd3eb1b0_0
130 | - jupyter_console=6.4.0=pyhd3eb1b0_0
131 | - jupyter_core=4.7.1=py38h06a4308_0
132 | - jupyterlab_pygments=0.1.2=py_0
133 | - jupyterlab_widgets=1.0.0=pyhd3eb1b0_1
134 | - keyring=22.3.0=py38h06a4308_0
135 | - krb5=1.18.2=h173b8e3_0
136 | - lame=3.100=h7b6447c_0
137 | - lazy-object-proxy=1.6.0=py38h27cfd23_0
138 | - lcms2=2.12=h3be6417_0
139 | - ld_impl_linux-64=2.33.1=h53a641e_7
140 | - libarchive=3.4.2=h62408e4_0
141 | - libcurl=7.71.1=h20c2e04_1
142 | - libedit=3.1.20210216=h27cfd23_1
143 | - libev=4.33=h7b6447c_0
144 | - libffi=3.3=he6710b0_2
145 | - libgcc-ng=9.1.0=hdf63c60_0
146 | - libgfortran-ng=7.3.0=hdf63c60_0
147 | - libiconv=1.15=h63c8f33_5
148 | - libidn2=2.3.2=h7f8727e_0
149 | - liblief=0.10.1=he6710b0_0
150 | - libllvm10=10.0.1=hbcb73fb_5
151 | - libllvm9=9.0.1=h4a3c616_1
152 | - libopus=1.3.1=h7b6447c_0
153 | - libpng=1.6.37=hbc83047_0
154 | - libsodium=1.0.18=h7b6447c_0
155 | - libspatialindex=1.9.3=h2531618_0
156 | - libssh2=1.9.0=h1ba5d50_1
157 | - libstdcxx-ng=9.1.0=hdf63c60_0
158 | - libtasn1=4.16.0=h27cfd23_0
159 | - libtiff=4.2.0=h85742a9_0
160 | - libtool=2.4.6=h7b6447c_1005
161 | - libunistring=0.9.10=h27cfd23_0
162 | - libuuid=1.0.3=h1bed415_2
163 | - libuv=1.40.0=h7b6447c_0
164 | - libvpx=1.7.0=h439df22_0
165 | - libwebp-base=1.2.0=h27cfd23_0
166 | - libxcb=1.14=h7b6447c_0
167 | - libxml2=2.9.10=hb55368b_3
168 | - libxslt=1.1.34=hc22bd24_0
169 | - llvmlite=0.36.0=py38h612dafd_4
170 | - locket=0.2.1=py38h06a4308_1
171 | - lxml=4.6.3=py38h9120a33_0
172 | - lz4-c=1.9.3=h2531618_0
173 | - lzo=2.10=h7b6447c_2
174 | - matplotlib-base=3.3.4=py38h62a2d02_0
175 | - mccabe=0.6.1=py38_1
176 | - mkl=2021.4.0=h06a4308_640
177 | - mkl-service=2.3.0=py38h27cfd23_1
178 | - mkl_fft=1.3.0=py38h42c9631_2
179 | - mkl_random=1.2.1=py38ha9443f7_2
180 | - mock=4.0.3=pyhd3eb1b0_0
181 | - more-itertools=8.7.0=pyhd3eb1b0_0
182 | - mpc=1.1.0=h10f8cd9_1
183 | - mpfr=4.0.2=hb69a4c5_1
184 | - mpmath=1.2.1=py38h06a4308_0
185 | - msgpack-python=1.0.2=py38hff7bd54_1
186 | - multipledispatch=0.6.0=py38_0
187 | - mypy_extensions=0.4.3=py38_0
188 | - navigator-updater=0.2.1=py38_0
189 | - nbclient=0.5.3=pyhd3eb1b0_0
190 | - ncurses=6.2=he6710b0_1
191 | - nest-asyncio=1.5.1=pyhd3eb1b0_0
192 | - nettle=3.7.3=hbbd107a_1
193 | - nltk=3.6.1=pyhd3eb1b0_0
194 | - nose=1.3.7=pyhd3eb1b0_1006
195 | - notebook=6.3.0=py38h06a4308_0
196 | - numba=0.53.1=py38ha9443f7_0
197 | - numexpr=2.7.3=py38h22e1b3c_1
198 | - numpy-base=1.21.5=py38hf524024_2
199 | - numpydoc=1.1.0=pyhd3eb1b0_1
200 | - nvidiacub=1.10.0=0
201 | - olefile=0.46=py_0
202 | - openh264=2.1.0=hd408876_0
203 | - openpyxl=3.0.7=pyhd3eb1b0_0
204 | - openssl=1.1.1n=h7f8727e_0
205 | - pandas=1.2.4=py38h2531618_0
206 | - pandoc=2.12=h06a4308_0
207 | - pandocfilters=1.4.3=py38h06a4308_1
208 | - pango=1.45.3=hd140c19_0
209 | - parso=0.7.0=py_0
210 | - partd=1.2.0=pyhd3eb1b0_0
211 | - patchelf=0.12=h2531618_1
212 | - path=15.1.2=py38h06a4308_0
213 | - path.py=12.5.0=0
214 | - pathlib2=2.3.5=py38h06a4308_2
215 | - pathspec=0.7.0=py_0
216 | - pathtools=0.1.2=pyhd3eb1b0_1
217 | - patsy=0.5.1=py38_0
218 | - pcre=8.44=he6710b0_0
219 | - pep8=1.7.1=py38_0
220 | - pexpect=4.8.0=pyhd3eb1b0_3
221 | - pickleshare=0.7.5=pyhd3eb1b0_1003
222 | - pixman=0.40.0=h7b6447c_0
223 | - pkginfo=1.7.0=py38h06a4308_0
224 | - pluggy=0.13.1=py38h06a4308_0
225 | - ply=3.11=py38_0
226 | - prometheus_client=0.10.1=pyhd3eb1b0_0
227 | - prompt-toolkit=3.0.17=pyh06a4308_0
228 | - prompt_toolkit=3.0.17=hd3eb1b0_0
229 | - psutil=5.8.0=py38h27cfd23_1
230 | - ptyprocess=0.7.0=pyhd3eb1b0_2
231 | - py=1.10.0=pyhd3eb1b0_0
232 | - py-lief=0.10.1=py38h403a769_0
233 | - pycodestyle=2.6.0=pyhd3eb1b0_0
234 | - pycosat=0.6.3=py38h7b6447c_1
235 | - pycparser=2.20=py_2
236 | - pycurl=7.43.0.6=py38h1ba5d50_0
237 | - pydocstyle=6.0.0=pyhd3eb1b0_0
238 | - pyerfa=1.7.3=py38h27cfd23_0
239 | - pyflakes=2.2.0=pyhd3eb1b0_0
240 | - pygments=2.8.1=pyhd3eb1b0_0
241 | - pylint=2.7.4=py38h06a4308_1
242 | - pyls-black=0.4.6=hd3eb1b0_0
243 | - pyls-spyder=0.3.2=pyhd3eb1b0_0
244 | - pyodbc=4.0.30=py38he6710b0_0
245 | - pyopenssl=20.0.1=pyhd3eb1b0_1
246 | - pyqt=5.9.2=py38h05f1152_4
247 | - pyrsistent=0.17.3=py38h7b6447c_0
248 | - pysocks=1.7.1=py38h06a4308_0
249 | - pytables=3.6.1=py38h9fd0a39_0
250 | - pytest=6.2.3=py38h06a4308_2
251 | - python=3.8.8=hdb3f193_5
252 | - python-jsonrpc-server=0.4.0=py_0
253 | - python-language-server=0.36.2=pyhd3eb1b0_0
254 | - python-libarchive-c=2.9=pyhd3eb1b0_1
255 | - python_abi=3.8=2_cp38
256 | - pytz=2021.1=pyhd3eb1b0_0
257 | - pyxdg=0.27=pyhd3eb1b0_0
258 | - pyzmq=20.0.0=py38h2531618_1
259 | - qdarkstyle=2.8.1=py_0
260 | - qt=5.9.7=h5867ecd_1
261 | - qtawesome=1.0.2=pyhd3eb1b0_0
262 | - qtconsole=5.0.3=pyhd3eb1b0_0
263 | - qtpy=1.9.0=py_0
264 | - readline=8.1=h27cfd23_0
265 | - regex=2021.4.4=py38h27cfd23_0
266 | - requests=2.25.1=pyhd3eb1b0_0
267 | - ripgrep=12.1.1=0
268 | - rope=0.18.0=py_0
269 | - rtree=0.9.7=py38h06a4308_1
270 | - ruamel_yaml=0.15.100=py38h27cfd23_0
271 | - scikit-learn=0.24.1=py38ha9443f7_0
272 | - seaborn=0.11.1=pyhd3eb1b0_0
273 | - secretstorage=3.3.1=py38h06a4308_0
274 | - setuptools=52.0.0=py38h06a4308_0
275 | - simplegeneric=0.8.1=py38_2
276 | - singledispatch=3.6.1=pyhd3eb1b0_1001
277 | - sip=4.19.13=py38he6710b0_0
278 | - snappy=1.1.8=he6710b0_0
279 | - sniffio=1.2.0=py38h06a4308_1
280 | - snowballstemmer=2.1.0=pyhd3eb1b0_0
281 | - sortedcollections=2.1.0=pyhd3eb1b0_0
282 | - sortedcontainers=2.3.0=pyhd3eb1b0_0
283 | - soupsieve=2.2.1=pyhd3eb1b0_0
284 | - sphinx=4.0.1=pyhd3eb1b0_0
285 | - sphinxcontrib=1.0=py38_1
286 | - sphinxcontrib-applehelp=1.0.2=pyhd3eb1b0_0
287 | - sphinxcontrib-devhelp=1.0.2=pyhd3eb1b0_0
288 | - sphinxcontrib-htmlhelp=1.0.3=pyhd3eb1b0_0
289 | - sphinxcontrib-jsmath=1.0.1=pyhd3eb1b0_0
290 | - sphinxcontrib-qthelp=1.0.3=pyhd3eb1b0_0
291 | - sphinxcontrib-serializinghtml=1.1.4=pyhd3eb1b0_0
292 | - sphinxcontrib-websupport=1.2.4=py_0
293 | - spyder=4.2.5=py38h06a4308_0
294 | - spyder-kernels=1.10.2=py38h06a4308_0
295 | - sqlalchemy=1.4.15=py38h27cfd23_0
296 | - sqlite=3.35.4=hdfb4753_0
297 | - statsmodels=0.12.2=py38h27cfd23_0
298 | - sympy=1.8=py38h06a4308_0
299 | - tbb=2020.3=hfd86e86_0
300 | - tblib=1.7.0=py_0
301 | - terminado=0.9.4=py38h06a4308_0
302 | - testpath=0.4.4=pyhd3eb1b0_0
303 | - textdistance=4.2.1=pyhd3eb1b0_0
304 | - threadpoolctl=2.1.0=pyh5ca1d4c_0
305 | - three-merge=0.1.1=pyhd3eb1b0_0
306 | - tk=8.6.10=hbc83047_0
307 | - toml=0.10.2=pyhd3eb1b0_0
308 | - toolz=0.11.1=pyhd3eb1b0_0
309 | - torchaudio=0.9.1=py38
310 | - tornado=6.1=py38h27cfd23_0
311 | - typed-ast=1.4.2=py38h27cfd23_1
312 | - ujson=4.0.2=py38h2531618_0
313 | - unicodecsv=0.14.1=py38_0
314 | - unixodbc=2.3.9=h7b6447c_0
315 | - urllib3=1.26.4=pyhd3eb1b0_0
316 | - watchdog=1.0.2=py38h06a4308_1
317 | - wcwidth=0.2.5=py_0
318 | - webencodings=0.5.1=py38_1
319 | - werkzeug=1.0.1=pyhd3eb1b0_0
320 | - wheel=0.36.2=pyhd3eb1b0_0
321 | - widgetsnbextension=3.5.1=py38_0
322 | - wrapt=1.12.1=py38h7b6447c_1
323 | - wurlitzer=2.1.0=py38h06a4308_0
324 | - x264=1!157.20191217=h7b6447c_0
325 | - xlrd=2.0.1=pyhd3eb1b0_0
326 | - xlsxwriter=1.3.8=pyhd3eb1b0_0
327 | - xlwt=1.3.0=py38_0
328 | - xmltodict=0.12.0=py_0
329 | - xz=5.2.5=h7b6447c_0
330 | - yaml=0.2.5=h7b6447c_0
331 | - yapf=0.31.0=pyhd3eb1b0_0
332 | - zeromq=4.3.4=h2531618_0
333 | - zict=2.0.0=pyhd3eb1b0_0
334 | - zipp=3.4.1=pyhd3eb1b0_0
335 | - zlib=1.2.11=h7b6447c_3
336 | - zope=1.0=py38_1
337 | - zope.event=4.5.0=py38_0
338 | - zope.interface=5.3.0=py38h27cfd23_0
339 | - zstd=1.4.5=h9ceee32_0
340 | - pip:
341 | - absl-py==0.13.0
342 | - addict==2.4.0
343 | - aiohttp==3.7.4.post0
344 | - aiosignal==1.2.0
345 | - albumentations==1.1.0
346 | - altair==4.2.0
347 | - antlr4-python3-runtime==4.9.3
348 | - anyio==3.6.1
349 | - anykeystore==0.2
350 | - apex==0.1
351 | - astor==0.8.1
352 | - async-timeout==3.0.1
353 | - audioread==2.1.9
354 | - augly==0.1.9
355 | - av==10.0.0
356 | - backports-zoneinfo==0.2.1
357 | - base58==2.1.1
358 | - bcrypt==4.0.1
359 | - blinker==1.4
360 | - blis==0.7.5
361 | - cachetools==4.2.2
362 | - catalogue==2.0.6
363 | - chex==0.1.4
364 | - clean-fid==0.1.29
365 | - click==8.1.3
366 | - clip==1.0
367 | - commonmark==0.9.1
368 | - configargparse==1.5.3
369 | - configparser==5.0.2
370 | - cryptacular==1.6.2
371 | - cycler==0.11.0
372 | - cymem==2.0.6
373 | - deepdiff==5.7.0
374 | - diffusers==0.3.0
375 | - distlib==0.3.4
376 | - dm-pix==0.3.3
377 | - dm-tree==0.1.7
378 | - docker-pycreds==0.4.0
379 | - dominate==2.6.0
380 | - easydict==1.9
381 | - editdistance==0.6.0
382 | - einops==0.4.1
383 | - etils==0.7.1
384 | - exifread==3.0.0
385 | - facexlib==0.2.5
386 | - fairscale==0.4.9
387 | - fairseq==0.9.0
388 | - fastai==2.5.3
389 | - fastapi==0.85.0
390 | - fastcore==1.3.27
391 | - fastdownload==0.0.5
392 | - fastjsonschema==2.16.1
393 | - fastprogress==1.0.0
394 | - ffmpy==0.3.0
395 | - filelock==3.7.1
396 | - filterpy==1.4.5
397 | - fire==0.4.0
398 | - flax==0.6.1
399 | - font-roboto==0.0.1
400 | - fonts==0.0.3
401 | - fonttools==4.31.2
402 | - frozenlist==1.3.0
403 | - ftfy==6.1.1
404 | - fvcore==0.1.5.post20220305
405 | - gdown==4.5.1
406 | - gfpgan==1.3.8
407 | - gin-config==0.5.0
408 | - gitdb==4.0.7
409 | - gitpython==3.1.18
410 | - google-auth==1.35.0
411 | - google-auth-oauthlib==0.4.6
412 | - gradio==3.4.1
413 | - grpcio==1.40.0
414 | - h11==0.12.0
415 | - htcondor==9.2.0
416 | - httpcore==0.15.0
417 | - httpimport==0.7.2
418 | - httpx==0.23.0
419 | - huggingface-hub==0.10.1
420 | - hupper==1.10.3
421 | - hydra-core==1.1.1
422 | - imageio==2.16.1
423 | - imageio-ffmpeg==0.4.5
424 | - imgaug==0.4.0
425 | - importlib-resources==5.2.2
426 | - install==1.3.5
427 | - iopath==0.1.9
428 | - ipdb==0.13.9
429 | - jax==0.3.23
430 | - jaxlib==0.3.22
431 | - jinja2==3.1.2
432 | - jsonmerge==1.8.0
433 | - jsonpatch==1.32
434 | - jsonpointer==2.2
435 | - jupyter-server==1.18.1
436 | - jupyter-server-mathjax==0.2.6
437 | - jupyterlab==3.4.5
438 | - jupyterlab-git==0.38.0
439 | - jupyterlab-server==2.15.1
440 | - kiwisolver==1.4.0
441 | - kornia==0.6.7
442 | - kornia-moons==0.2.0
443 | - lark==1.1.2
444 | - librosa==0.8.1
445 | - linkify-it-py==1.0.3
446 | - lmdb==1.2.1
447 | - lpips==0.1.4
448 | - markdown==3.3.4
449 | - markdown-it-py==2.1.0
450 | - markupsafe==2.1.1
451 | - matplotlib==3.5.1
452 | - mdit-py-plugins==0.3.1
453 | - mdurl==0.1.2
454 | - mediapy==1.1.0
455 | - mistune==2.0.4
456 | - mrcfile==1.3.0
457 | - multidict==5.1.0
458 | - murmurhash==1.0.6
459 | - natsort==8.1.0
460 | - nbclassic==0.4.3
461 | - nbconvert==7.0.0
462 | - nbdime==3.1.1
463 | - nbformat==5.4.0
464 | - networkx==2.7.1
465 | - ninja==1.10.2.3
466 | - nlpaug==1.1.3
467 | - notebook-shim==0.1.0
468 | - numpy==1.23.3
469 | - oauthlib==3.1.1
470 | - omegaconf==2.2.3
471 | - opencv-python==4.5.5.64
472 | - opencv-python-headless==4.5.4.58
473 | - opt-einsum==3.3.0
474 | - optax==0.1.3
475 | - ordered-set==4.0.2
476 | - orjson==3.8.0
477 | - packaging==21.3
478 | - paramiko==2.11.0
479 | - pastedeploy==2.1.1
480 | - patchify==0.2.3
481 | - pathlib==1.0.1
482 | - pathy==0.6.1
483 | - pbkdf2==1.3
484 | - piexif==1.1.3
485 | - pillow==9.0.1
486 | - pip==22.3.1
487 | - plaster==1.0
488 | - plaster-pastedeploy==0.7
489 | - platformdirs==2.5.2
490 | - plotly==5.3.1
491 | - plyfile==0.7.4
492 | - pooch==1.5.2
493 | - portalocker==2.3.2
494 | - preshed==3.0.6
495 | - promise==2.3
496 | - protobuf==3.17.3
497 | - pyarrow==7.0.0
498 | - pyasn1==0.4.8
499 | - pyasn1-modules==0.2.8
500 | - pycryptodome==3.15.0
501 | - pydantic==1.8.2
502 | - pydeck==0.7.1
503 | - pydeprecate==0.3.1
504 | - pydub==0.25.1
505 | - pyexiftool==0.5.4
506 | - pympler==1.0.1
507 | - pynacl==1.5.0
508 | - pyparsing==3.0.7
509 | - pyramid==2.0
510 | - pyramid-mailer==0.15.1
511 | - python-dateutil==2.8.2
512 | - python-magic==0.4.24
513 | - python-multipart==0.0.4
514 | - python3-openid==3.2.0
515 | - pytorch-fid==0.2.1
516 | - pytorch-lightning==1.7.6
517 | - pytz-deprecation-shim==0.1.0.post0
518 | - pywavelets==1.3.0
519 | - pyyaml==6.0
520 | - qudida==0.0.4
521 | - rawpy==0.17.3
522 | - ray==1.13.0
523 | - realesrgan==0.3.0
524 | - repoze-sendmail==4.4.1
525 | - requests-oauthlib==1.3.0
526 | - resampy==0.2.2
527 | - resize-right==0.0.2
528 | - rfc3986==1.5.0
529 | - rich==11.2.0
530 | - rich-cli==1.8.0
531 | - rich-rst==1.1.7
532 | - rsa==4.7.2
533 | - sacrebleu==2.0.0
534 | - scikit-image==0.19.2
535 | - scikit-video==1.1.11
536 | - scipy==1.8.0
537 | - semver==2.13.0
538 | - send2trash==1.8.0
539 | - sentry-sdk==1.3.1
540 | - setproctitle==1.2.2
541 | - shapely==1.7.1
542 | - shortuuid==1.0.1
543 | - simdkalman==1.0.2
544 | - six==1.16.0
545 | - smart-open==5.2.1
546 | - smmap==4.0.0
547 | - soundfile==0.10.3.post1
548 | - spacy==3.1.4
549 | - spacy-legacy==3.0.8
550 | - srsly==2.4.2
551 | - starlette==0.20.4
552 | - streamlit==1.3.1
553 | - subprocess32==3.5.4
554 | - svox2-csrc==0.0.1.dev0+sphtexcub.lincolor.fast
555 | - tabulate==0.8.9
556 | - tb-nightly==2.11.0a20221012
557 | - tenacity==8.0.1
558 | - tensorboard==2.10.1
559 | - tensorboard-data-server==0.6.1
560 | - tensorboard-plugin-wit==1.8.0
561 | - tensorboardx==2.4.1
562 | - termcolor==1.1.0
563 | - test-tube==0.7.5
564 | - textual==0.1.18
565 | - thinc==8.0.12
566 | - thop==0.0.31-2005241907
567 | - tifffile==2022.3.16
568 | - timm==0.6.7
569 | - tinycss2==1.1.1
570 | - tinycudann==1.6
571 | - tl2==0.0.8
572 | - tokenizers==0.12.1
573 | - torch==1.9.0
574 | - torch-ema==0.3
575 | - torch-scatter==2.0.9
576 | - torchdiffeq==0.2.3
577 | - torchfile==0.1.0
578 | - torchmetrics==0.9.2
579 | - torchsummary==1.5.1
580 | - torchvision==0.10.0
581 | - tqdm==4.64.0
582 | - traitlets==5.3.0
583 | - transaction==3.0.1
584 | - transformers==4.19.2
585 | - translationstring==1.4
586 | - trimesh==3.11.2
587 | - tsmoothie==1.0.4
588 | - typeguard==2.13.3
589 | - typer==0.4.0
590 | - typing-extensions==4.3.0
591 | - tzdata==2021.5
592 | - tzlocal==4.1
593 | - uc-micro-py==1.0.1
594 | - uvicorn==0.18.3
595 | - validators==0.18.2
596 | - velruse==1.1.1
597 | - venusian==3.0.0
598 | - virtualenv==20.15.1
599 | - visdom==0.1.8.9
600 | - vren==1.0
601 | - wand==0.6.7
602 | - wandb==0.13.8
603 | - warmup-scheduler==0.3
604 | - wasabi==0.8.2
605 | - webob==1.8.7
606 | - websocket-client==1.2.3
607 | - websockets==10.3
608 | - wtforms==3.0.1
609 | - wtforms-recaptcha==0.3.2
610 | - yacs==0.1.8
611 | - yarl==1.6.3
612 | - yaspin==2.1.0
613 | - zope-deprecation==4.4.0
614 | - zope-sqlalchemy==1.6
615 | prefix: /home/grads/hywang26/anaconda3
616 |
--------------------------------------------------------------------------------
/src/globalenv.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from pathlib import Path
3 |
4 | SRC_PATH = Path(__file__).absolute().parent
5 | ROOT_PATH = SRC_PATH.parent
6 | sys.path.append(str(SRC_PATH))
7 |
8 | LOGGER_BUFFER_LOCK = False
9 | SPLIT = '————————————————————————————————————————————————————'
10 |
11 | GLOBAL_SEED = 233
12 | TEST_RESULT_DIRNAME = 'test_result'
13 | TRAIN_LOG_DIRNAME = 'log'
14 | CONFIG_DIR = 'config'
15 | CONFIG_FILEPATH = 'config/config.yaml'
16 | LMDB_DIRPATH = ROOT_PATH / 'lmdb'
17 | METRICS_LOG_DIRPATH = ROOT_PATH / 'metrics_log'
18 | OPT_FILENAME = 'CONFIG.yaml'
19 | LOG_FILENAME = 'run.log'
20 | LOG_TIME_FORMAT = '%Y-%m-%d_%H:%M:%S'
21 | INPUT = 'input'
22 | OUTPUT = 'output'
23 | GT = 'GT'
24 | STRING_FALSE = 'False'
25 | SKIP_FLAG = 'q'
26 | DEFAULTS = 'defaults'
27 | HYDRA = 'hydra'
28 |
29 | INPUT_FPATH = 'input_fpath'
30 | GT_FPATH = 'gt_fpath'
31 |
32 | DEBUG = 'debug'
33 | BACKEND = 'backend'
34 | CHECKPOINT_PATH = 'checkpoint_path'
35 | LOG_DIRPATH = 'log_dirpath'
36 | IMG_DIRPATH = 'img_dirpath'
37 | DATALOADER_N = 'dataloader_num_worker'
38 | VAL_DEBUG_STEP_NUMS = 'val_debug_step_nums'
39 | VALID_EVERY = 'valid_every'
40 | LOG_EVERY = 'log_every'
41 | AUGMENTATION = 'aug'
42 | RUNTIME_PRECISION = 'runtime_precision'
43 | NUM_EPOCH = 'num_epoch'
44 | NAME = 'name'
45 | LOSS = 'loss'
46 | TRAIN_DATA = 'train_ds'
47 | VALID_DATA = 'valid_ds'
48 | TEST_DATA = 'test_ds'
49 | GPU = 'gpu'
50 | RUNTIME = 'runtime'
51 | CLASS = 'class'
52 | MODELNAME = 'modelname'
53 | BATCHSIZE = 'batchsize'
54 | VALID_BATCHSIZE = 'valid_batchsize'
55 | LR = 'lr'
56 | CHECKPOINT_MONITOR = 'checkpoint_monitor'
57 | MONITOR_MODE = 'monitor_mode'
58 | COMMENT = 'comment'
59 | EARLY_STOP = 'early_stop'
60 | AMP_BACKEND = 'amp_backend'
61 | AMP_LEVEL = 'amp_level'
62 | VALID_RATIO = 'valid_ratio'
63 |
64 | LTV_LOSS = 'ltv'
65 | COS_LOSS = 'cos'
66 | SSIM_LOSS = 'ssim_loss'
67 | L1_LOSS = 'l1_loss'
68 | COLOR_LOSS = 'l_color'
69 | SPATIAL_LOSS = 'l_spa'
70 | EXPOSURE_LOSS = 'l_exp'
71 | WEIGHTED_LOSS = 'weighted_loss'
72 | PSNR_LOSS = 'psnr_loss'
73 | HIST_LOSS = 'hist_loss'
74 | INTER_HIST_LOSS = 'inter_hist_loss'
75 | VGG_LOSS = 'vgg_loss'
76 |
77 | PSNR = 'psnr'
78 | SSIM = 'ssim'
79 |
80 | VERTICAL_FLIP = 'v-flip'
81 | HORIZON_FLIP = 'h-flip'
82 | DOWNSAMPLE = 'downsample'
83 | RESIZE_DIVISIBLE_N = 'resize_divisible_n'
84 | CROP = 'crop'
85 | LIGHTNESS_ADJUST = 'lightness_adjust'
86 | CONTRAST_ADJUST = 'contrast_adjust'
87 |
88 | BUNET = 'bilateral_upsample_net'
89 | UNET = 'unet'
90 | HIST_UNET = 'hist_unet'
91 | PREDICT_ILLUMINATION = 'predict_illumination'
92 | FILTERS = 'filters'
93 |
94 | MODE = 'mode'
95 | COLOR_SPACE = 'color_space'
96 | BETA1 = 'beta1'
97 | BETA2 = 'beta2'
98 | LAMBDA_SMOOTH = 'lambda_smooth'
99 | LAMBDA_MONOTONICITY = 'lambda_monotonicity'
100 | MSE = 'mse'
101 | L2_LOSS = 'l2_loss'
102 | TV_CONS = 'tv_cons'
103 | MN_CONS = 'mv_cons'
104 | WEIGHTS_NORM = 'wnorm'
105 | TEST_PTH = 'test_pth'
106 |
107 | LUMA_BINS = 'luma_bins'
108 | CHANNEL_MULTIPLIER = 'channel_multiplier'
109 | SPATIAL_BIN = 'spatial_bin'
110 | BATCH_NORM = 'batch_norm'
111 | NET_INPUT_SIZE = 'net_input_size'
112 | LOW_RESOLUTION = 'low_resolution'
113 | ONNX_EXPORTING_MODE = 'onnx_exporting_mode'
114 | SELF_SUPERVISED = 'self_supervised'
115 | COEFFS_TYPE = 'coeffs_type'
116 | ILLU_MAP_POWER = 'illu_map_power'
117 | GAMMA = 'gamma'
118 | MATRIX = 'matrix'
119 | GUIDEMAP = 'guidemap'
120 | USE_HSV = 'use_hsv'
121 |
122 | USE_WAVELET = 'use_wavelet'
123 | NON_LOCAL = 'use_non_local'
124 | USE_ATTN_MAP = 'use_attn_map'
125 | ILLUMAP_CHANNEL = 'illumap_channel'
126 | HOW_TO_FUSE = 'how_to_fuse'
127 | SHARE_WEIGHTS = 'share_weights'
128 | BACKBONE = 'backbone'
129 | ARCH = 'arch'
130 | N_BINS = 'n_bins'
131 | BACKBONE_OUT_ILLU = 'backbone_out_illu'
132 | CONV_TYPE = 'conv_type'
133 | HIST_AS_GUIDE_ = 'hist_as_guide'
134 | ENCODER_USE_HIST = 'encoder_use_hist'
135 | GUIDE_FEATURE_FROM_HIST = 'guide_feature_from_hist'
136 | NC = 'channel_nums'
137 |
138 | ILLU_MAP = 'illu_map'
139 | INVERSE_ILLU_MAP = 'inverse_illu_map'
140 | BRIGHTEN_INPUT = 'brighten_input'
141 | DARKEN_INPUT = 'darken_input'
142 |
143 | TRAIN = 'train'
144 | TEST = 'test'
145 | VALID = 'valid'
146 | ONNX = 'onnx'
147 | CONDOR = 'condor'
148 | IMAGES = 'images'
149 |
--------------------------------------------------------------------------------
/src/model/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/src/model/__init__.py
--------------------------------------------------------------------------------
/src/model/arch/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/src/model/arch/__init__.py
--------------------------------------------------------------------------------
/src/model/arch/drconv.py:
--------------------------------------------------------------------------------
1 | import ipdb
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torch.autograd import Function
6 |
7 |
8 | class asign_index(torch.autograd.Function):
9 | @staticmethod
10 | def forward(ctx, kernel, guide_feature):
11 | ctx.save_for_backward(kernel, guide_feature)
12 | guide_mask = torch.zeros_like(guide_feature).scatter_(1, guide_feature.argmax(dim=1, keepdim=True),
13 | 1).unsqueeze(2) # B x 3 x 1 x 25 x 25
14 | return torch.sum(kernel * guide_mask, dim=1)
15 |
16 | @staticmethod
17 | def backward(ctx, grad_output):
18 | kernel, guide_feature = ctx.saved_tensors
19 | guide_mask = torch.zeros_like(guide_feature).scatter_(1, guide_feature.argmax(dim=1, keepdim=True),
20 | 1).unsqueeze(2) # B x 3 x 1 x 25 x 25
21 | grad_kernel = grad_output.clone().unsqueeze(1) * guide_mask # B x 3 x 256 x 25 x 25
22 | grad_guide = grad_output.clone().unsqueeze(1) * kernel # B x 3 x 256 x 25 x 25
23 | grad_guide = grad_guide.sum(dim=2) # B x 3 x 25 x 25
24 | softmax = F.softmax(guide_feature, 1) # B x 3 x 25 x 25
25 | grad_guide = softmax * (grad_guide - (softmax * grad_guide).sum(dim=1, keepdim=True)) # B x 3 x 25 x 25
26 | return grad_kernel, grad_guide
27 |
28 |
29 | def xcorr_slow(x, kernel, kwargs):
30 | """for loop to calculate cross correlation
31 | """
32 | batch = x.size()[0]
33 | out = []
34 | for i in range(batch):
35 | px = x[i]
36 | pk = kernel[i]
37 | px = px.view(1, px.size()[0], px.size()[1], px.size()[2])
38 | pk = pk.view(-1, px.size()[1], pk.size()[1], pk.size()[2])
39 | po = F.conv2d(px, pk, **kwargs)
40 | out.append(po)
41 | out = torch.cat(out, 0)
42 | return out
43 |
44 |
45 | def xcorr_fast(x, kernel, kwargs):
46 | """group conv2d to calculate cross correlation
47 | """
48 | batch = kernel.size()[0]
49 | pk = kernel.view(-1, x.size()[1], kernel.size()[2], kernel.size()[3])
50 | px = x.view(1, -1, x.size()[2], x.size()[3])
51 | po = F.conv2d(px, pk, **kwargs, groups=batch)
52 | po = po.view(batch, -1, po.size()[2], po.size()[3])
53 | return po
54 |
55 |
56 | class Corr(Function):
57 | @staticmethod
58 | def symbolic(g, x, kernel, groups):
59 | return g.op("Corr", x, kernel, groups_i=groups)
60 |
61 | @staticmethod
62 | def forward(self, x, kernel, groups, kwargs):
63 | """group conv2d to calculate cross correlation
64 | """
65 | batch = x.size(0)
66 | channel = x.size(1)
67 | x = x.view(1, -1, x.size(2), x.size(3))
68 | kernel = kernel.view(-1, channel // groups, kernel.size(2), kernel.size(3))
69 | out = F.conv2d(x, kernel, **kwargs, groups=groups * batch)
70 | out = out.view(batch, -1, out.size(2), out.size(3))
71 | return out
72 |
73 |
74 | class Correlation(nn.Module):
75 | use_slow = True
76 |
77 | def __init__(self, use_slow=None):
78 | super(Correlation, self).__init__()
79 | if use_slow is not None:
80 | self.use_slow = use_slow
81 | else:
82 | self.use_slow = Correlation.use_slow
83 |
84 | def extra_repr(self):
85 | if self.use_slow: return "xcorr_slow"
86 | return "xcorr_fast"
87 |
88 | def forward(self, x, kernel, **kwargs):
89 | if self.training:
90 | if self.use_slow:
91 | return xcorr_slow(x, kernel, kwargs)
92 | else:
93 | return xcorr_fast(x, kernel, kwargs)
94 | else:
95 | return Corr.apply(x, kernel, 1, kwargs)
96 |
97 |
98 | class DRConv2d(nn.Module):
99 | def __init__(self, in_channels, out_channels, kernel_size, region_num=8, guide_input_channel=False, **kwargs):
100 | super(DRConv2d, self).__init__()
101 | self.region_num = region_num
102 | self.guide_input_channel = guide_input_channel
103 |
104 | self.conv_kernel = nn.Sequential(
105 | nn.AdaptiveAvgPool2d((kernel_size, kernel_size)),
106 | nn.Conv2d(in_channels, region_num * region_num, kernel_size=1),
107 | nn.Sigmoid(),
108 | nn.Conv2d(region_num * region_num, region_num * in_channels * out_channels, kernel_size=1,
109 | groups=region_num)
110 | )
111 | if guide_input_channel:
112 | # get guide feature from a user input tensor.
113 | self.conv_guide = nn.Conv2d(guide_input_channel, region_num, kernel_size=kernel_size, **kwargs)
114 | else:
115 | self.conv_guide = nn.Conv2d(in_channels, region_num, kernel_size=kernel_size, **kwargs)
116 |
117 | self.corr = Correlation(use_slow=False)
118 | self.kwargs = kwargs
119 | self.asign_index = asign_index.apply
120 |
121 | def forward(self, input, guide_input=None):
122 | kernel = self.conv_kernel(input)
123 | # kernel = kernel.view(kernel.size(0), -1, kernel.size(2), kernel.size(3)) # B x (r*in*out) x W X H
124 | output = self.corr(input, kernel, **self.kwargs) # B x (r*out) x W x H
125 | output = output.view(output.size(0), self.region_num, -1, output.size(2), output.size(3)) # B x r x out x W x H
126 | if self.guide_input_channel:
127 | guide_feature = self.conv_guide(guide_input)
128 | else:
129 | guide_feature = self.conv_guide(input)
130 | self.guide_feature = guide_feature
131 | # self.guide_feature = torch.zeros_like(guide_feature).scatter_(1, guide_feature.argmax(dim=1, keepdim=True), 1).unsqueeze(2) # B x 3 x 1 x 25 x 25
132 | output = self.asign_index(output, guide_feature)
133 | return output
134 |
135 |
136 | class HistDRConv2d(DRConv2d):
137 | def forward(self, input, histmap):
138 | """
139 | use histmap as guide feature directly.
140 | histmap.shape: [bs, n_bins, h, w]
141 | """
142 | histmap.requires_grad_(False)
143 |
144 | kernel = self.conv_kernel(input)
145 | output = self.corr(input, kernel, **self.kwargs) # B x (r*out) x W x H
146 | output = output.view(output.size(0), self.region_num, -1, output.size(2), output.size(3)) # B x r x out x W x H
147 | output = self.asign_index(output, histmap)
148 | return output
149 |
150 |
151 | if __name__ == '__main__':
152 | B = 16
153 | in_channels = 256
154 | out_channels = 512
155 | size = 89
156 | conv = DRConv2d(in_channels, out_channels, kernel_size=3, region_num=8).cuda()
157 | conv.train()
158 | input = torch.ones(B, in_channels, size, size).cuda()
159 | output = conv(input)
160 | print(input.shape, output.shape)
161 |
162 | # flops, params
163 | from thop import profile
164 |
165 |
166 | class Conv2d(nn.Module):
167 | def __init__(self):
168 | super(Conv2d, self).__init__()
169 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3)
170 |
171 | def forward(self, input):
172 | return self.conv(input)
173 |
174 |
175 | ipdb.set_trace()
176 | conv2 = Conv2d().cuda()
177 | conv2.train()
178 | print(input.shape, conv2(input).shape)
179 | flops2, params2 = profile(conv2, inputs=(input,))
180 | flops, params = profile(conv, inputs=(input,))
181 |
182 | print('[ * ] DRconv FLOPs = ' + str(flops / 1000 ** 3) + 'G')
183 | print('[ * ] DRconv Params Num = ' + str(params / 1000 ** 2) + 'M')
184 |
185 | print('[ * ] Conv FLOPs = ' + str(flops2 / 1000 ** 3) + 'G')
186 | print('[ * ] Conv Params Num = ' + str(params2 / 1000 ** 2) + 'M')
187 |
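A minimal CPU usage sketch of the DRConv2d defined above (shapes only; the import path assumes running from `src/`, like the rest of this repo). It is not part of the file, just an illustration of where the per-pixel region logits end up:

    import torch
    from model.arch.drconv import DRConv2d

    conv = DRConv2d(in_channels=16, out_channels=32, kernel_size=3, region_num=8, padding=1)
    conv.train()                         # Correlation uses xcorr_fast in training mode
    x = torch.rand(2, 16, 64, 64)
    y = conv(x)                          # region-wise kernels, selected per pixel by conv_guide's argmax
    print(y.shape)                       # torch.Size([2, 32, 64, 64])
    print(conv.guide_feature.shape)      # torch.Size([2, 8, 64, 64]): one logit map per region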
--------------------------------------------------------------------------------
/src/model/arch/hist.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def get_gray(img):
5 | r = img[:, 0, ...]
6 | g = img[:, 1, ...]
7 | b = img[:, 2, ...]
8 | return (0.299 * r + 0.587 * g + 0.114 * b).unsqueeze(1)
9 |
10 |
11 | def pack_tensor(x, n_bins):
12 |     # pack tensor: transform gt_hist.shape [n_bins, bs, c, h, w] -> [bs*c, n_bins, h, w]
13 | # merge dim 1 (bs) and dim 2 (channel).
14 | return x.reshape(n_bins, -1, *x.shape[-2:]).permute(1, 0, 2, 3)
15 |
16 |
17 | def get_hist(img, n_bins, grayscale=False):
18 | """
19 |     Given an img (shape: bs, c, h, w),
20 | return the SOFT histogram map (shape: n_bins, bs, c, h, w)
21 | or (shape: n_bins, bs, h, w) when grayscale=True.
22 | """
23 | if grayscale:
24 | img = get_gray(img)
25 | return torch.stack([
26 | torch.nn.functional.relu(1 - torch.abs(img - (2 * b - 1) / float(2 * n_bins)) * float(n_bins))
27 | for b in range(1, n_bins + 1)
28 | ])
29 |
30 |
31 | def get_hist_conv(n_bins, kernel_size=2, train=False):
32 | """
33 | Return a conv kernel.
34 |     The kernel is applied to the histogram map to shrink the spatial scale of the hist-map.
35 | """
36 | conv = torch.nn.Conv2d(n_bins, n_bins, kernel_size, kernel_size, bias=False, groups=1)
37 | conv.weight.data.zero_()
38 | for i in range(conv.weight.shape[1]):
39 | alpha = kernel_size ** 2
40 | # alpha = 1
41 | conv.weight.data[i, i, ...] = torch.ones(kernel_size, kernel_size) / alpha
42 | if not train:
43 | conv.requires_grad_(False)
44 | return conv
45 |
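A small shape-check sketch for the helpers above (values are illustrative, not from the repo configs): get_hist builds soft triangular bin memberships per pixel, pack_tensor folds batch and channel together so the averaging conv from get_hist_conv can shrink each bin plane:

    import torch
    from model.arch.hist import get_hist, get_hist_conv, pack_tensor

    img = torch.rand(2, 3, 32, 32)                    # bs=2, c=3
    hist = get_hist(img, n_bins=8)                    # (8, 2, 3, 32, 32): soft membership per bin
    packed = pack_tensor(hist, 8)                     # (6, 8, 32, 32): bs*c merged into the batch dim
    down = get_hist_conv(8, kernel_size=2)(packed)    # (6, 8, 16, 16): 2x2 average per bin plane
    print(hist.shape, packed.shape, down.shape)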
--------------------------------------------------------------------------------
/src/model/arch/nonlocal_block_embedded_gaussian.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 | from torch.nn import functional as F
4 | import ipdb
5 |
6 |
7 | class _NonLocalBlockND(nn.Module):
8 | def __init__(self, in_channels, inter_channels=None, dimension=3, sub_sample='pool', bn_layer=True):
9 | """
10 | :param in_channels:
11 | :param inter_channels:
12 | :param dimension:
13 | :param sub_sample: 'pool' or 'bilinear' or False
14 | :param bn_layer:
15 | """
16 |
17 | super(_NonLocalBlockND, self).__init__()
18 |
19 | assert dimension in [1, 2, 3]
20 |
21 | self.dimension = dimension
22 | self.sub_sample = sub_sample
23 |
24 | self.in_channels = in_channels
25 | self.inter_channels = inter_channels
26 |
27 | if self.inter_channels is None:
28 | self.inter_channels = in_channels // 2
29 | if self.inter_channels == 0:
30 | self.inter_channels = 1
31 |
32 | if dimension == 3:
33 | conv_nd = nn.Conv3d
34 | max_pool_layer = nn.MaxPool3d(kernel_size=(1, 2, 2))
35 | bn = nn.BatchNorm3d
36 | elif dimension == 2:
37 | conv_nd = nn.Conv2d
38 | if sub_sample == 'pool':
39 | max_pool_layer = nn.MaxPool2d(kernel_size=(2, 2))
40 | elif sub_sample == 'bilinear':
41 | max_pool_layer = nn.UpsamplingBilinear2d([16, 16])
42 | else:
43 | raise NotImplementedError(f'[ ERR ] Unknown down sample method: {sub_sample}')
44 | bn = nn.BatchNorm2d
45 | else:
46 | conv_nd = nn.Conv1d
47 | max_pool_layer = nn.MaxPool1d(kernel_size=(2))
48 | bn = nn.BatchNorm1d
49 |
50 | self.g = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
51 | kernel_size=1, stride=1, padding=0)
52 |
53 | if bn_layer:
54 | self.W = nn.Sequential(
55 | conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
56 | kernel_size=1, stride=1, padding=0),
57 | bn(self.in_channels)
58 | )
59 | nn.init.constant_(self.W[1].weight, 0)
60 | nn.init.constant_(self.W[1].bias, 0)
61 | else:
62 | self.W = conv_nd(in_channels=self.inter_channels, out_channels=self.in_channels,
63 | kernel_size=1, stride=1, padding=0)
64 | nn.init.constant_(self.W.weight, 0)
65 | nn.init.constant_(self.W.bias, 0)
66 |
67 | self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
68 | kernel_size=1, stride=1, padding=0)
69 | self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
70 | kernel_size=1, stride=1, padding=0)
71 |
72 | if sub_sample:
73 | self.g = nn.Sequential(self.g, max_pool_layer)
74 | self.phi = nn.Sequential(self.phi, max_pool_layer)
75 |
76 | def forward(self, x, return_nl_map=False):
77 | """
78 | :param x: (b, c, t, h, w)
79 | :param return_nl_map: if True return z, nl_map, else only return z.
80 | :return:
81 | """
82 |
83 | batch_size = x.size(0)
84 |
85 | g_x = self.g(x).view(batch_size, self.inter_channels, -1)
86 | g_x = g_x.permute(0, 2, 1)
87 |
88 | theta_x = self.theta(x).view(batch_size, self.inter_channels, -1)
89 | theta_x = theta_x.permute(0, 2, 1)
90 | phi_x = self.phi(x).view(batch_size, self.inter_channels, -1)
91 | f = torch.matmul(theta_x, phi_x)
92 | f_div_C = F.softmax(f, dim=-1)
93 |
94 | y = torch.matmul(f_div_C, g_x)
95 | y = y.permute(0, 2, 1).contiguous()
96 | y = y.view(batch_size, self.inter_channels, *x.size()[2:])
97 | W_y = self.W(y)
98 | z = W_y + x
99 |
100 | if return_nl_map:
101 | return z, f_div_C
102 | return z
103 |
104 |
105 | class NONLocalBlock1D(_NonLocalBlockND):
106 | def __init__(self, in_channels, inter_channels=None, sub_sample='pool', bn_layer=True):
107 | super(NONLocalBlock1D, self).__init__(in_channels,
108 | inter_channels=inter_channels,
109 | dimension=1, sub_sample=sub_sample,
110 | bn_layer=bn_layer)
111 |
112 |
113 | class NONLocalBlock2D(_NonLocalBlockND):
114 | def __init__(self, in_channels, inter_channels=None, sub_sample='pool', bn_layer=True):
115 | super(NONLocalBlock2D, self).__init__(in_channels,
116 | inter_channels=inter_channels,
117 | dimension=2, sub_sample=sub_sample,
118 | bn_layer=bn_layer, )
119 |
120 |
121 | class NONLocalBlock3D(_NonLocalBlockND):
122 | def __init__(self, in_channels, inter_channels=None, sub_sample='pool', bn_layer=True):
123 | super(NONLocalBlock3D, self).__init__(in_channels,
124 | inter_channels=inter_channels,
125 | dimension=3, sub_sample=sub_sample,
126 | bn_layer=bn_layer, )
127 |
128 |
129 | if __name__ == '__main__':
130 | import torch
131 |
132 | # for (sub_sample_, bn_layer_) in [(True, True), (False, False), (True, False), (False, True)]:
133 | # img = torch.zeros(2, 3, 20)
134 | # net = NONLocalBlock1D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
135 | # out = net(img)
136 | # print(out.size())
137 |
138 | # img = torch.zeros(2, 3, 20, 20)
139 | # net = NONLocalBlock2D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
140 | # out = net(img)
141 | # print(out.size())
142 |
143 | # img = torch.randn(2, 3, 8, 20, 20)
144 | # net = NONLocalBlock3D(3, sub_sample=sub_sample_, bn_layer=bn_layer_)
145 | # out = net(img)
146 | # print(out.size())
147 |
148 | img = torch.zeros(4, 16, 20, 20).cuda()
149 | net = NONLocalBlock2D(16, sub_sample=True, bn_layer=False).cuda()
150 | out = net(img)
151 | ipdb.set_trace()
152 |
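A CPU sketch of the 2D block with the 'bilinear' sub-sampling branch, which is the variant lcdpnet.py below instantiates (sizes here are illustrative): keys and values are resampled to a fixed 16x16 grid, so the attention map is queries x 256:

    import torch
    from model.arch.nonlocal_block_embedded_gaussian import NONLocalBlock2D

    block = NONLocalBlock2D(in_channels=32, sub_sample='bilinear', bn_layer=False)
    x = torch.rand(2, 32, 40, 40)
    z, nl_map = block(x, return_nl_map=True)
    print(z.shape)       # torch.Size([2, 32, 40, 40]): residual output, same size as the input
    print(nl_map.shape)  # torch.Size([2, 1600, 256]): 40*40 queries attending over the 16*16 grid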
--------------------------------------------------------------------------------
/src/model/arch/unet_based/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/src/model/arch/unet_based/__init__.py
--------------------------------------------------------------------------------
/src/model/arch/unet_based/hist_unet.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import ipdb
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from model.arch.drconv import DRConv2d, HistDRConv2d
8 | from model.arch.hist import get_hist, get_hist_conv, pack_tensor
9 |
10 |
11 | class DoubleConv(nn.Module):
12 | def __init__(self, in_channels, out_channels, mid_channels=None):
13 | super().__init__()
14 | if not mid_channels:
15 | mid_channels = out_channels
16 | self.double_conv = nn.Sequential(
17 | nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1),
18 | nn.BatchNorm2d(mid_channels),
19 | nn.ReLU(inplace=True),
20 | nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1),
21 | nn.BatchNorm2d(out_channels),
22 | nn.ReLU(inplace=True)
23 | )
24 |
25 | def forward(self, x):
26 | return self.double_conv(x)
27 |
28 |
29 | class DRDoubleConv(nn.Module):
30 | def __init__(self, in_channels, out_channels, mid_channels=None, **kargs):
31 | super().__init__()
32 | if not mid_channels:
33 | mid_channels = out_channels
34 |
35 | self.double_conv = nn.Sequential(
36 | DRConv2d(in_channels, mid_channels, kernel_size=3, region_num=REGION_NUM_, padding=1, **kargs),
37 | nn.BatchNorm2d(mid_channels),
38 | nn.ReLU(inplace=True),
39 | DRConv2d(mid_channels, out_channels, kernel_size=3, region_num=REGION_NUM_, padding=1, **kargs),
40 | nn.BatchNorm2d(out_channels),
41 | nn.ReLU(inplace=True)
42 | )
43 | assert len(DRCONV_POSITION_) == 2
44 | assert DRCONV_POSITION_[0] or DRCONV_POSITION_[1]
45 | if DRCONV_POSITION_[0] == 0:
46 | print('[ WARN ] Use Conv in DRDoubleConv[0] instead of DRconv.')
47 | self.double_conv[0] = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1)
48 | if DRCONV_POSITION_[1] == 0:
49 | print('[ WARN ] Use Conv in DRDoubleConv[3] instead of DRconv.')
50 | self.double_conv[3] = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1)
51 |
52 | def forward(self, x):
53 | res = self.double_conv(x)
54 | self.guide_features = []
55 | if DRCONV_POSITION_[0]:
56 | self.guide_features.append(self.double_conv[0].guide_feature)
57 | if DRCONV_POSITION_[1]:
58 | self.guide_features.append(self.double_conv[3].guide_feature)
59 | return res
60 |
61 |
62 | class HistDRDoubleConv(nn.Module):
63 | def __init__(self, in_channels, out_channels, mid_channels=None):
64 | super().__init__()
65 | if not mid_channels:
66 | mid_channels = out_channels
67 | self.conv1 = HistDRConv2d(in_channels, mid_channels, kernel_size=3, region_num=REGION_NUM_, padding=1)
68 | self.inter1 = nn.Sequential(
69 | nn.BatchNorm2d(mid_channels),
70 | nn.ReLU(inplace=True)
71 | )
72 | self.conv2 = HistDRConv2d(mid_channels, out_channels, kernel_size=3, region_num=REGION_NUM_, padding=1)
73 | self.inter2 = nn.Sequential(
74 | nn.BatchNorm2d(out_channels),
75 | nn.ReLU(inplace=True)
76 | )
77 |
78 | def forward(self, x, histmap):
79 | y = self.conv1(x, histmap)
80 | y = self.inter1(y)
81 | y = self.conv2(y, histmap)
82 | return self.inter2(y)
83 |
84 |
85 | class HistGuidedDRDoubleConv(nn.Module):
86 | def __init__(self, in_channels, out_channels, mid_channels=None, **kargs):
87 | super().__init__()
88 | assert len(DRCONV_POSITION_) == 2
89 | assert DRCONV_POSITION_[0] or DRCONV_POSITION_[1]
90 |
91 | if not mid_channels:
92 | mid_channels = out_channels
93 | if DRCONV_POSITION_[0]:
94 | self.conv1 = DRConv2d(in_channels, mid_channels, kernel_size=3, region_num=REGION_NUM_, padding=1, **kargs)
95 | else:
96 | print('[ WARN ] Use Conv in HistGuidedDRDoubleConv[0] instead of DRconv.')
97 | self.conv1 = nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1)
98 |
99 | self.inter1 = nn.Sequential(
100 | nn.BatchNorm2d(mid_channels),
101 | nn.ReLU(inplace=True)
102 | )
103 | if DRCONV_POSITION_[1]:
104 | self.conv2 = DRConv2d(mid_channels, out_channels, kernel_size=3, region_num=REGION_NUM_, padding=1, **kargs)
105 | else:
106 |             print('[ WARN ] Use Conv in HistGuidedDRDoubleConv[1] instead of DRconv.')
107 | self.conv2 = nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1)
108 |
109 | self.inter2 = nn.Sequential(
110 | nn.BatchNorm2d(out_channels),
111 | nn.ReLU(inplace=True)
112 | )
113 |
114 | def forward(self, x, histmap):
115 | if DRCONV_POSITION_[0]:
116 | y = self.conv1(x, histmap)
117 | else:
118 | y = self.conv1(x)
119 | y = self.inter1(y)
120 |
121 | if DRCONV_POSITION_[1]:
122 | y = self.conv2(y, histmap)
123 | else:
124 | y = self.conv2(y)
125 |
126 | # self.guide_features = [self.conv1.guide_feature, self.conv2.guide_feature]
127 | self.guide_features = []
128 | if DRCONV_POSITION_[0]:
129 | self.guide_features.append(self.conv1.guide_feature)
130 | if DRCONV_POSITION_[1]:
131 | self.guide_features.append(self.conv2.guide_feature)
132 |
133 | return self.inter2(y)
134 |
135 |
136 | class Up(nn.Module):
137 | def __init__(self, in_channels, out_channels, bilinear=True, **kargs):
138 | super().__init__()
139 | self.up = nn.Upsample(scale_factor=DOWN_RATIO_, mode='bilinear', align_corners=True)
140 | if CONV_TYPE_ == 'drconv':
141 | if HIST_AS_GUIDE_:
142 | self.conv = HistDRDoubleConv(in_channels, out_channels, in_channels // 2)
143 | elif GUIDE_FEATURE_FROM_HIST_:
144 | self.conv = HistGuidedDRDoubleConv(in_channels, out_channels, in_channels // 2, **kargs)
145 | else:
146 | self.conv = DRDoubleConv(in_channels, out_channels, in_channels // 2)
147 | # elif CONV_TYPE_ == 'dconv':
148 | # self.conv = HistDyDoubleConv(in_channels, out_channels, in_channels // 2)
149 |
150 | def forward(self, x1, x2, histmap):
151 | """
152 | histmap: shape [bs, c * n_bins, h, w]
153 | """
154 | x1 = self.up(x1)
155 |
156 | # input is CHW
157 | diffY = x2.size()[2] - x1.size()[2]
158 | diffX = x2.size()[3] - x1.size()[3]
159 |
160 | x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
161 | diffY // 2, diffY - diffY // 2])
162 |
163 | if HIST_AS_GUIDE_ or GUIDE_FEATURE_FROM_HIST_ or CONV_TYPE_ == 'dconv':
164 | x = torch.cat([x2, x1], dim=1)
165 | res = self.conv(x, histmap)
166 | else:
167 | x = torch.cat([x2, x1, histmap], dim=1)
168 | res = self.conv(x)
169 | self.guide_features = self.conv.guide_features
170 | return res
171 |
172 |
173 | class Down(nn.Module):
174 | def __init__(self, in_channels, out_channels, use_hist=False):
175 | super().__init__()
176 | self.use_hist = use_hist
177 | if not use_hist:
178 | self.maxpool_conv = nn.Sequential(
179 | nn.MaxPool2d(DOWN_RATIO_),
180 | DoubleConv(in_channels, out_channels)
181 | )
182 | else:
183 | if HIST_AS_GUIDE_:
184 | # self.maxpool_conv = nn.Sequential(
185 | # nn.MaxPool2d(2),
186 | # HistDRDoubleConv(in_channels, out_channels, in_channels // 2)
187 | # )
188 | raise NotImplementedError()
189 | elif GUIDE_FEATURE_FROM_HIST_:
190 | self.maxpool = nn.MaxPool2d(DOWN_RATIO_)
191 | self.conv = HistGuidedDRDoubleConv(in_channels, out_channels, in_channels // 2)
192 | else:
193 | self.maxpool_conv = nn.Sequential(
194 | nn.MaxPool2d(DOWN_RATIO_),
195 | DRDoubleConv(in_channels, out_channels, in_channels // 2)
196 | )
197 |
198 | def forward(self, x, histmap=None):
199 | if GUIDE_FEATURE_FROM_HIST_ and self.use_hist:
200 | x = self.maxpool(x)
201 | return self.conv(x, histmap)
202 | elif self.use_hist:
203 | return self.maxpool_conv(torch.cat([x, histmap], axis=1))
204 | else:
205 | return self.maxpool_conv(x)
206 |
207 |
208 | class OutConv(nn.Module):
209 | def __init__(self, in_channels, out_channels):
210 | super(OutConv, self).__init__()
211 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
212 |
213 | def forward(self, x):
214 | return self.conv(x)
215 |
216 |
217 | class HistUNet(nn.Module):
218 | def __init__(self,
219 | in_channels=3,
220 | out_channels=3,
221 | bilinear=True,
222 | n_bins=8,
223 | hist_as_guide=False,
224 | channel_nums=None,
225 | hist_conv_trainable=False,
226 | encoder_use_hist=False,
227 | guide_feature_from_hist=False,
228 | region_num=8,
229 | use_gray_hist=False,
230 | conv_type='drconv',
231 | down_ratio=1,
232 | drconv_position=[1, 1],
233 | ):
234 | super().__init__()
235 | C_NUMS = [16, 32, 64, 128, 256]
236 | if channel_nums:
237 | C_NUMS = channel_nums
238 | self.maxpool = nn.MaxPool2d(2)
239 | self.n_bins = n_bins
240 | self.encoder_use_hist = encoder_use_hist
241 | self.use_gray_hist = use_gray_hist
242 | self.hist_conv_trainable = hist_conv_trainable
243 |
244 | global HIST_AS_GUIDE_, GUIDE_FEATURE_FROM_HIST_, REGION_NUM_, CONV_TYPE_, DOWN_RATIO_, DRCONV_POSITION_
245 | HIST_AS_GUIDE_ = hist_as_guide
246 | GUIDE_FEATURE_FROM_HIST_ = guide_feature_from_hist
247 | REGION_NUM_ = region_num
248 | CONV_TYPE_ = conv_type
249 | DOWN_RATIO_ = down_ratio
250 | DRCONV_POSITION_ = drconv_position
251 |
252 | if hist_conv_trainable:
253 | self.hist_conv1 = get_hist_conv(n_bins * in_channels, down_ratio, train=True)
254 | self.hist_conv2 = get_hist_conv(n_bins * in_channels, down_ratio, train=True)
255 | self.hist_conv3 = get_hist_conv(n_bins * in_channels, down_ratio, train=True)
256 | else:
257 | self.hist_conv = get_hist_conv(n_bins, down_ratio)
258 |
259 | factor = 2 if bilinear else 1
260 | self.inc = DoubleConv(in_channels, C_NUMS[0])
261 | if hist_as_guide or guide_feature_from_hist or conv_type == 'dconv':
262 | extra_c_num = 0
263 | elif use_gray_hist:
264 | extra_c_num = n_bins
265 | else:
266 | extra_c_num = n_bins * in_channels
267 |
268 | if guide_feature_from_hist:
269 | kargs = {
270 | 'guide_input_channel': n_bins if use_gray_hist else n_bins * in_channels
271 | }
272 | else:
273 | kargs = {}
274 |
275 | if encoder_use_hist:
276 | encoder_extra_c_num = extra_c_num
277 | else:
278 | encoder_extra_c_num = 0
279 |
280 | self.down1 = Down(C_NUMS[0] + encoder_extra_c_num, C_NUMS[1], use_hist=encoder_use_hist)
281 | self.down2 = Down(C_NUMS[1] + encoder_extra_c_num, C_NUMS[2], use_hist=encoder_use_hist)
282 | self.down3 = Down(C_NUMS[2] + encoder_extra_c_num, C_NUMS[3], use_hist=encoder_use_hist)
283 | self.down4 = Down(C_NUMS[3] + encoder_extra_c_num, C_NUMS[4] // factor, use_hist=encoder_use_hist)
284 |
285 | self.up1 = Up(C_NUMS[4] + extra_c_num, C_NUMS[3] // factor, bilinear, **kargs)
286 | self.up2 = Up(C_NUMS[3] + extra_c_num, C_NUMS[2] // factor, bilinear, **kargs)
287 | self.up3 = Up(C_NUMS[2] + extra_c_num, C_NUMS[1] // factor, bilinear, **kargs)
288 | self.up4 = Up(C_NUMS[1] + extra_c_num, C_NUMS[0], bilinear, **kargs)
289 | self.outc = OutConv(C_NUMS[0], out_channels)
290 |
291 | def forward(self, x):
292 | # ipdb.set_trace()
293 | # get histograms
294 | # (`get_hist` return shape: n_bins, bs, c, h, w).
295 | if HIST_AS_GUIDE_ or self.use_gray_hist:
296 | histmap = get_hist(x, self.n_bins, grayscale=True)
297 | else:
298 | histmap = get_hist(x, self.n_bins)
299 |
300 | bs = x.shape[0]
301 | histmap = pack_tensor(histmap, self.n_bins).detach() # out: [bs * c, n_bins, h, w]
302 | if not self.hist_conv_trainable:
303 | hist_down2 = self.hist_conv(histmap)
304 | hist_down4 = self.hist_conv(hist_down2)
305 | hist_down8 = self.hist_conv(hist_down4)
306 |
307 |             # [bs * c, n_bins, h, w] -> [bs, c*n_bins, h, w]
308 | for item in [histmap, hist_down2, hist_down4, hist_down8]:
309 | item.data = item.reshape(bs, -1, *item.shape[-2:])
310 | else:
311 | histmap = histmap.reshape(bs, -1, *histmap.shape[-2:])
312 | hist_down2 = self.hist_conv1(histmap)
313 | hist_down4 = self.hist_conv2(hist_down2)
314 | hist_down8 = self.hist_conv3(hist_down4) # [bs, n_bins * c, h/n, w/n]
315 |
316 | # forward
317 | encoder_hists = [None, ] * 4
318 | if self.encoder_use_hist:
319 | encoder_hists = [histmap, hist_down2, hist_down4, hist_down8]
320 |
321 | x1 = self.inc(x)
322 |         x2 = self.down1(x1, encoder_hists[0])  # x2: C_NUMS[1] channels
323 |         x3 = self.down2(x2, encoder_hists[1])  # x3: C_NUMS[2] channels
324 |         x4 = self.down3(x3, encoder_hists[2])  # x4: C_NUMS[3] channels
325 |         x5 = self.down4(x4, encoder_hists[3])  # x5: C_NUMS[4] // factor channels
326 |
327 | # always apply hist in decoder:
328 | # ipdb.set_trace()
329 |         x = self.up1(x5, x4, hist_down8)  # Up: cat(x4, upsampled x5)
330 |         x = self.up2(x, x3, hist_down4)  # Up: cat(x3, upsampled x)
331 | x = self.up3(x, x2, hist_down2)
332 | x = self.up4(x, x1, histmap)
333 |
334 | self.guide_features = [layer.guide_features for layer in [
335 | self.up1,
336 | self.up2,
337 | self.up3,
338 | self.up4,
339 | ]]
340 |
341 | logits = self.outc(x)
342 | return logits
343 |
344 |
345 | if __name__ == '__main__':
346 | model = HistUNet()
347 | x = torch.rand(4, 3, 512, 512)
348 | model(x)
349 | ipdb.set_trace()
350 |
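The __main__ block above exercises the default configuration; the sketch below shows the histogram-guided variant (argument values are illustrative, not the released settings): guide_feature_from_hist=True routes the packed histogram maps into conv_guide of each decoder DRConv2d, and down_ratio=2 keeps the pooled histograms aligned with each decoder scale:

    import torch
    from model.arch.unet_based.hist_unet import HistUNet

    net = HistUNet(in_channels=3, out_channels=3, n_bins=8,
                   guide_feature_from_hist=True, region_num=8, down_ratio=2)
    x = torch.rand(1, 3, 128, 128)
    y = net(x)
    print(y.shape)                   # torch.Size([1, 3, 128, 128])
    print(len(net.guide_features))   # 4: per-region guide logits from each decoder Up block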
--------------------------------------------------------------------------------
/src/model/basemodel.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import pathlib
4 | from collections.abc import Iterable
5 |
6 | import pytorch_lightning as pl
7 | import torchvision
8 | import wandb
9 |
10 | import utils.util as util
11 |
12 | try:
13 | from thop import profile
14 | except Exception as e:
15 | print('ERR: import thop failed, skip. error msg:')
16 | print(e)
17 |
18 | from globalenv import *
19 |
20 | global LOGGER_BUFFER_LOCK
21 |
22 |
23 | class BaseModel(pl.core.LightningModule):
24 | def __init__(self, opt, running_modes):
25 | '''
26 |         running_modes: image group names for the wandb logger. Recommended: ['train', 'valid']
27 | '''
28 |
29 | super().__init__()
30 | self.save_hyperparameters(dict(opt))
31 | print('Running initialization for BaseModel')
32 |
33 | if IMG_DIRPATH in opt:
34 | # in training mode.
35 | # if in test mode, configLogging is not called.
36 | if TRAIN in running_modes:
37 | self.train_img_dirpath = osp.join(opt[IMG_DIRPATH], TRAIN)
38 | util.mkdir(self.train_img_dirpath)
39 | if VALID in running_modes and (len(opt[VALID_DATA].keys()) > 1 or opt[VALID_RATIO]):
40 | self.valid_img_dirpath = osp.join(opt[IMG_DIRPATH], VALID)
41 | util.mkdir(self.valid_img_dirpath)
42 |
43 | self.opt = opt
44 | self.learning_rate = self.opt[LR]
45 |
46 | self.MODEL_WATCHED = False # for wandb watching model
47 | self.global_valid_step = 0
48 | self.iogt = {} # a dict, saving input, output and gt batch
49 |
50 | assert isinstance(running_modes, Iterable)
51 | self.logger_image_buffer = {k: [] for k in running_modes}
52 |
53 | def show_flops_and_param_num(self, inputs):
54 | # inputs: arguments of `forward()`
55 | try:
56 | flops, params = profile(self, inputs=inputs)
57 | print('[ * ] FLOPs = ' + str(flops / 1000 ** 3) + 'G')
58 | print('[ * ] Params Num = ' + str(params / 1000 ** 2) + 'M')
59 | except Exception as e:
60 |             print(f'Error occurred while calculating FLOPs: {str(e)}')
61 |
62 | # def get_progress_bar_dict(self):
63 | # items = super().get_progress_bar_dict()
64 | # items.pop("v_num", None)
65 | # # items.pop("loss", None)
66 | # return items
67 |
68 | def build_test_res_dir(self):
69 | assert self.opt[CHECKPOINT_PATH]
70 | modelpath = pathlib.Path(self.opt[CHECKPOINT_PATH])
71 |
72 | # only `test_ds` is supported when testing.
73 | ds_type = TEST_DATA
74 | runtime_dirname = f'{self.opt.runtime.modelname}_{modelpath.parent.name}_{modelpath.name}@{self.opt.test_ds.name}'
75 | dirpath = modelpath.parent / TEST_RESULT_DIRNAME
76 |
77 | if (dirpath / runtime_dirname).exists():
78 | if len(os.listdir(dirpath / runtime_dirname)) == 0:
79 | # an existing but empty dir
80 | pass
81 | else:
82 | try:
83 | input_str = input(
84 | f'[ WARN ] Result directory "{runtime_dirname}" exists. Press ENTER to overwrite or input suffix '
85 | f'to create a new one:\n> New name: {runtime_dirname}.')
86 | except Exception as e:
87 | print(
88 |                         f'[ WARN ] Exception {e} occurred, ignoring input and setting `input_str` empty.')
89 | input_str = ''
90 | if input_str == '':
91 | print(
92 | f"[ WARN ] Overwrite result_dir: {runtime_dirname}")
93 | pass
94 | else:
95 | runtime_dirname += '.' + input_str
96 | # fname += '.new'
97 |
98 | dirpath /= runtime_dirname
99 | util.mkdir(dirpath)
100 | print('TEST - Result save path:')
101 | print(str(dirpath))
102 |
103 | util.save_opt(dirpath, self.opt)
104 | return str(dirpath)
105 |
106 | @staticmethod
107 | def save_img_batch(batch, dirpath, fname, save_num=1):
108 | util.mkdir(dirpath)
109 | imgpath = osp.join(dirpath, fname)
110 |
111 |         # If you want to visualize a single image, call .unsqueeze(0)
112 | assert len(batch.shape) == 4
113 | torchvision.utils.save_image(batch[:save_num], imgpath)
114 |
115 | def calc_and_log_losses(self, loss_lambda_map):
116 | logged_losses = {}
117 | loss = 0
118 | for loss_name, loss_weight in self.opt[RUNTIME][LOSS].items():
119 | if loss_weight:
120 | current = loss_lambda_map[loss_name]()
121 |                 if current is not None:
122 | current *= loss_weight
123 | logged_losses[loss_name] = current
124 | loss += current
125 |
126 | logged_losses[LOSS] = loss
127 | self.log_dict(logged_losses)
128 | return loss
129 |
130 | def log_images_dict(self, mode, input_fname, img_batch_dict, gt_fname=None):
131 | """
132 | log input, output and gt images to local disk and remote wandb logger.
133 | mode: TRAIN or VALID
134 | """
135 | if self.opt[DEBUG]:
136 | return
137 |
138 | global LOGGER_BUFFER_LOCK
139 | if LOGGER_BUFFER_LOCK and self.opt.logger == 'wandb':
140 | # buffer is used by other GPU-thread.
141 | # print('Buffer locked!')
142 | return
143 |
144 | assert mode in [TRAIN, VALID]
145 | if mode == VALID:
146 | local_dirpath = self.valid_img_dirpath
147 | step = self.global_valid_step
148 | if self.global_valid_step == 0:
149 | print(
150 |                     'WARN: Found global_valid_step=0. Maybe you forgot to increase `self.global_valid_step` in `self.validation_step`?')
151 | # log_step = step # to avoid valid log step = train log step
152 | elif mode == TRAIN:
153 | local_dirpath = self.train_img_dirpath
154 | step = self.global_step
155 | # log_step = None
156 |
157 | if step % self.opt[LOG_EVERY] == 0:
158 |             suffix = f'_epoch{self.current_epoch}_step{step}.png'
159 |             input_fname = osp.basename(input_fname) + suffix
160 | 
161 |             if gt_fname:
162 |                 gt_fname = osp.basename(gt_fname) + suffix
163 |
164 |             # ****** public buffer operation ******
165 | LOGGER_BUFFER_LOCK = True
166 | for name, batch in img_batch_dict.items():
167 | if batch is None or batch is False:
168 | # image is None or False, skip.
169 | continue
170 |
171 | # save local image:
172 | fname = input_fname
173 | if name == GT and gt_fname:
174 | fname = gt_fname
175 | self.save_img_batch(
176 | batch,
177 | # e.g. ../train_log/train/output
178 | osp.join(local_dirpath, name),
179 | fname)
180 |
181 | # save remote image:
182 | if self.opt.logger == 'wandb':
183 | self.add_img_to_buffer(mode, batch, mode, name, fname)
184 | else:
185 | # tb logger
186 | self.logger.experiment.add_image(f'{mode}/{name}', batch[0], step)
187 |
188 | if self.opt.logger == 'wandb':
189 | self.commit_logger_buffer(mode)
190 |
191 | # self.buffer_img_step += 1
192 | LOGGER_BUFFER_LOCK = False
193 |             # ****** public buffer operation ******
194 |
195 | def add_img_to_buffer(self, group_name, batch, *caption):
196 | if len(batch.shape) == 3:
197 | # when input is not a batch:
198 | batch = batch.unsqueeze(0)
199 |
200 | self.logger_image_buffer[group_name].append(
201 | wandb.Image(batch[0], caption='-'.join(caption))
202 | )
203 |
204 | def commit_logger_buffer(self, groupname, **kwargs):
205 | assert self.logger
206 | self.logger.experiment.log({
207 | groupname: self.logger_image_buffer[groupname]
208 | }, **kwargs)
209 |
210 | # clear buffer after each commit for the next commit
211 | self.logger_image_buffer[groupname].clear()
212 |
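The lambda-map pattern consumed by calc_and_log_losses above can be summarized standalone (the loss names and weights below are made up): losses are passed as zero-argument callables, so a term whose weight in opt[RUNTIME][LOSS] is zero is never computed:

    loss_weights = {'l1': 1.0, 'ssim': 0.0}    # stands in for opt[RUNTIME][LOSS]
    loss_terms = {
        'l1':   lambda: 0.12,                  # stands in for F.l1_loss(output, gt)
        'ssim': lambda: 0.34,                  # skipped below because its weight is 0
    }
    total = sum(w * loss_terms[k]() for k, w in loss_weights.items() if w)
    print(total)                               # 0.12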
--------------------------------------------------------------------------------
/src/model/basic_loss.py:
--------------------------------------------------------------------------------
1 | import kornia as kn
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 | from torchvision import models, transforms
6 |
7 | from .arch.hist import get_hist, get_hist_conv
8 |
9 |
10 | class HistogramLoss(nn.Module):
11 | def __init__(self, n_bins=8, downscale=16):
12 | super().__init__()
13 | self.n_bins = n_bins
14 | self.hist_conv = get_hist_conv(n_bins, downscale)
15 |
16 |         # pack tensor: transform gt_hist.shape [n_bins, bs, c, h, w] -> [bs*c, n_bins, h, w]
17 | # merge dim 1 (bs) and dim 2 (channel).
18 | self.pack_tensor = lambda x: x.reshape(self.n_bins, -1, *x.shape[-2:]).permute(1, 0, 2, 3)
19 |
20 | def forward(self, output, gt):
21 | gt_hist = get_hist(gt, self.n_bins)
22 | output_hist = get_hist(output, self.n_bins)
23 |
24 | shrink_hist_gt = self.hist_conv(self.pack_tensor(gt_hist))
25 | shrink_hist_output = self.hist_conv(self.pack_tensor(output_hist))
26 |
27 | return F.mse_loss(shrink_hist_gt, shrink_hist_output)
28 |
29 |
30 | class IntermediateHistogramLoss(HistogramLoss):
31 | def __init__(self, n_bins=8, downscale=16):
32 | super().__init__(n_bins, downscale)
33 | self.exposure_threshold = 0.5
34 |
35 | def forward(self, img, gt, brighten, darken):
36 | """
37 | input brighten and darken img, get errors between:
38 | - brighten img & darken region in GT
39 | - darken img & brighten region in GT
40 | """
41 | bs, c, _, _ = gt.shape
42 | gt_hist = get_hist(gt, self.n_bins)
43 | shrink_hist_gt = self.hist_conv(self.pack_tensor(gt_hist))
44 |
45 | down_size = shrink_hist_gt.shape[-2:]
46 | shrink_hist_gt = shrink_hist_gt.reshape(bs, c, self.n_bins, *down_size)
47 | down_x = F.interpolate(img, size=down_size)
48 |
49 | # get mask from the input:
50 | over_ixs = down_x > self.exposure_threshold
51 | under_ixs = down_x <= self.exposure_threshold
52 | over_mask = down_x.clone()
53 | over_mask[under_ixs] = 0
54 | over_mask[over_ixs] = 1
55 | over_mask.unsqueeze_(2)
56 | under_mask = down_x.clone()
57 | under_mask[under_ixs] = 1
58 | under_mask[over_ixs] = 0
59 | under_mask.unsqueeze_(2)
60 |
61 | shrink_darken_hist = self.hist_conv(self.pack_tensor(get_hist(darken, self.n_bins))).reshape(bs, c, self.n_bins,
62 | *down_size)
63 | shrink_brighten_hist = self.hist_conv(self.pack_tensor(get_hist(brighten, self.n_bins))).reshape(bs, c,
64 | self.n_bins,
65 | *down_size)
66 |
67 | # [ 046 ] use ssim loss
68 | return 0.5 * kn.losses.ssim_loss((shrink_hist_gt * over_mask).view(-1, c, *down_size),
69 | (shrink_darken_hist * over_mask).view(-1, c, *down_size),
70 | window_size=5) + 0.5 * kn.losses.ssim_loss(
71 | (shrink_hist_gt * under_mask).view(-1, c, *down_size),
72 | (shrink_brighten_hist * under_mask).view(-1, c, *down_size), window_size=5)
73 |
74 | # [ 042 ] use l2 loss
75 | # return 0.5 * F.mse_loss(shrink_hist_gt * over_mask, shrink_darken_hist * over_mask) + 0.5 * F.mse_loss(shrink_hist_gt * under_mask, shrink_brighten_hist * under_mask)
76 |
77 |
78 | class WeightedL1Loss(nn.Module):
79 | def __init__(self):
80 | super().__init__()
81 |
82 | def forward(self, input, output, gt):
83 | bias = 0.1
84 | weights = (torch.abs(input - 0.5) + bias) / 0.5
85 | weights = weights.mean(axis=1).unsqueeze(1).repeat(1, 3, 1, 1)
86 | loss = torch.mean(torch.abs(output - gt) * weights.detach())
87 | return loss
88 |
89 |
90 | class LTVloss(nn.Module):
91 | def __init__(self, alpha=1.2, beta=1.5, eps=1e-4):
92 | super(LTVloss, self).__init__()
93 | self.alpha = alpha
94 | self.beta = beta
95 | self.eps = eps
96 |
97 | def forward(self, origin, illumination, weight):
98 | '''
99 | origin: one batch of input data. shape [batchsize, 3, h, w]
100 | illumination: one batch of predicted illumination data. if predicted_illumination
101 | is False, then use the output (predicted result) of the network.
102 | '''
103 |
104 | # # re-normalize origin to 0 ~ 1
105 | # origin = (input_ - input_.min().item()) / (input_.max().item() - input_.min().item())
106 |
107 | I = origin[:, 0:1, :, :] * 0.299 + origin[:, 1:2, :, :] * \
108 | 0.587 + origin[:, 2:3, :, :] * 0.114
109 | L = torch.log(I + self.eps)
110 | dx = L[:, :, :-1, :-1] - L[:, :, :-1, 1:]
111 | dy = L[:, :, :-1, :-1] - L[:, :, 1:, :-1]
112 |
113 | dx = self.beta / (torch.pow(torch.abs(dx), self.alpha) + self.eps)
114 | dy = self.beta / (torch.pow(torch.abs(dy), self.alpha) + self.eps)
115 |
116 | x_loss = dx * \
117 | ((illumination[:, :, :-1, :-1] - illumination[:, :, :-1, 1:]) ** 2)
118 | y_loss = dy * \
119 | ((illumination[:, :, :-1, :-1] - illumination[:, :, 1:, :-1]) ** 2)
120 | tvloss = torch.mean(x_loss + y_loss) / 2.0
121 |
122 | return tvloss * weight
123 |
124 |
125 | class L_TV(nn.Module):
126 | def __init__(self, TVLoss_weight=1):
127 | super(L_TV, self).__init__()
128 | self.TVLoss_weight = TVLoss_weight
129 |
130 | def forward(self, x):
131 | batch_size = x.size()[0]
132 | h_x = x.size()[2]
133 | w_x = x.size()[3]
134 | count_h = (x.size()[2] - 1) * x.size()[3]
135 | count_w = x.size()[2] * (x.size()[3] - 1)
136 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
137 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
138 | return self.TVLoss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size
139 |
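A shape-level sketch of the two simpler losses above (sizes are illustrative; imports assume running from src/ with the repo dependencies installed): HistogramLoss compares pooled soft histograms, WeightedL1Loss up-weights pixels whose input is far from mid-gray:

    import torch
    from model.basic_loss import HistogramLoss, WeightedL1Loss

    inp = torch.rand(2, 3, 64, 64)
    out = torch.rand(2, 3, 64, 64)
    gt = torch.rand(2, 3, 64, 64)

    h_loss = HistogramLoss(n_bins=8, downscale=16)(out, gt)   # MSE between 16x-downscaled soft histograms
    w_loss = WeightedL1Loss()(inp, out, gt)                   # L1 weighted by (|inp - 0.5| + 0.1) / 0.5
    print(h_loss.item(), w_loss.item())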
--------------------------------------------------------------------------------
/src/model/bilateralupsamplenet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from globalenv import *
7 | from .arch.drconv import DRConv2d
8 | from model.arch.unet_based.hist_unet import HistUNet
9 | from .basic_loss import LTVloss
10 | from .single_net_basemodel import SingleNetBaseModel
11 |
12 |
13 | class LitModel(SingleNetBaseModel):
14 | def __init__(self, opt):
15 | super().__init__(opt, BilateralUpsampleNet(opt[RUNTIME]), [TRAIN, VALID])
16 | low_res = opt[RUNTIME][LOW_RESOLUTION]
17 |
18 | self.down_sampler = lambda x: F.interpolate(x, size=(low_res, low_res), mode='bicubic', align_corners=False)
19 | self.use_illu = opt[RUNTIME][PREDICT_ILLUMINATION]
20 |
21 | self.mse = torch.nn.MSELoss()
22 | self.ltv = LTVloss()
23 | self.cos = torch.nn.CosineSimilarity(1, 1e-8)
24 |
25 | self.net.train()
26 |
27 | def training_step(self, batch, batch_idx):
28 | input_batch, gt_batch, output_batch = super().training_step_forward(batch, batch_idx)
29 | loss_lambda_map = {
30 | MSE: lambda: self.mse(output_batch, gt_batch),
31 | COS_LOSS: lambda: (1 - self.cos(output_batch, gt_batch).mean()) * 0.5,
32 | LTV_LOSS: lambda: self.ltv(input_batch, self.net.illu_map, 1) if self.use_illu else None,
33 | }
34 |
35 | # logging:
36 | loss = self.calc_and_log_losses(loss_lambda_map)
37 | self.log_training_iogt_img(batch, extra_img_dict={
38 | PREDICT_ILLUMINATION: self.net.illu_map,
39 | GUIDEMAP: self.net.guidemap
40 | })
41 | return loss
42 |
43 | def validation_step(self, batch, batch_idx):
44 | super().validation_step(batch, batch_idx)
45 |
46 | def test_step(self, batch, batch_ix):
47 | super().test_step(batch, batch_ix)
48 |
49 | def forward(self, x):
50 | low_res_x = self.down_sampler(x)
51 | return self.net(low_res_x, x)
52 |
53 |
54 | class ConvBlock(nn.Module):
55 | def __init__(self, inc, outc, kernel_size=3, padding=1, stride=1, use_bias=True, activation=nn.ReLU,
56 | batch_norm=False):
57 | super(ConvBlock, self).__init__()
58 | conv_type = OPT['conv_type']
59 | if conv_type == 'conv':
60 | self.conv = nn.Conv2d(int(inc), int(outc), kernel_size, padding=padding, stride=stride, bias=use_bias)
61 | elif conv_type.startswith('drconv'):
62 | region_num = int(conv_type.replace('drconv', ''))
63 | self.conv = DRConv2d(int(inc), int(outc), kernel_size, region_num=region_num, padding=padding,
64 | stride=stride)
65 | print(f'[ WARN ] Using DRconv2d(n_region={region_num}) instead of Conv2d in BilateralUpsampleNet.')
66 | else:
67 | raise NotImplementedError()
68 |
69 | self.activation = activation() if activation else None
70 | self.bn = nn.BatchNorm2d(outc) if batch_norm else None
71 |
72 | def forward(self, x):
73 | x = self.conv(x)
74 | if self.bn:
75 | x = self.bn(x)
76 | if self.activation:
77 | x = self.activation(x)
78 | return x
79 |
80 |
81 | class FC(nn.Module):
82 | def __init__(self, inc, outc, activation=nn.ReLU, batch_norm=False):
83 | super(FC, self).__init__()
84 | self.fc = nn.Linear(int(inc), int(outc), bias=(not batch_norm))
85 | self.activation = activation() if activation else None
86 | self.bn = nn.BatchNorm1d(outc) if batch_norm else None
87 |
88 | def forward(self, x):
89 | x = self.fc(x)
90 | if self.bn:
91 | x = self.bn(x)
92 | if self.activation:
93 | x = self.activation(x)
94 | return x
95 |
96 |
97 | class SliceNode(nn.Module):
98 | def __init__(self, opt):
99 | super(SliceNode, self).__init__()
100 | self.opt = opt
101 |
102 | def forward(self, bilateral_grid, guidemap):
103 | # bilateral_grid shape: Nx12x8x16x16
104 | device = bilateral_grid.get_device()
105 | N, _, H, W = guidemap.shape
106 | hg, wg = torch.meshgrid([torch.arange(0, H), torch.arange(0, W)]) # [0,511] HxW
107 | if device >= 0:
108 | hg = hg.to(device)
109 | wg = wg.to(device)
110 |
111 | hg = hg.float().repeat(N, 1, 1).unsqueeze(3) / (H - 1) * 2 - 1 # norm to [-1,1] NxHxWx1
112 | wg = wg.float().repeat(N, 1, 1).unsqueeze(3) / (W - 1) * 2 - 1 # norm to [-1,1] NxHxWx1
113 | guidemap = guidemap * 2 - 1
114 | guidemap = guidemap.permute(0, 2, 3, 1).contiguous()
115 | guidemap_guide = torch.cat([wg, hg, guidemap], dim=3).unsqueeze(1)
116 |
117 | # guidemap shape: [N, 1 (D), H, W]
118 | # bilateral_grid shape: [N, 12 (c), 8 (d), 16 (h), 16 (w)], which is considered as a 3D space: [8, 16, 16]
119 | # guidemap_guide shape: [N, 1 (D), H, W, 3], which is considered as a 3D space: [1, H, W]
120 | # coeff shape: [N, 12 (c), 1 (D), H, W]
121 |
122 |         # in F.grid_sample, grid is guidemap_guide, input is bilateral_grid
123 |         # guidemap_guide[N, D, H, W] is a 3-vector <x, y, z>, where:
124 |         # x -> W, y -> H, z -> D in bilateral_grid
125 |         # What it actually does:
126 |         # [ 1 ] For each pixel in guidemap_guide[D, H, W], take its <x, y, z>, and:
127 |         # [ 2 ] Normalize <x, y, z> from [-1, 1] to [0, w - 1], [0, h - 1], [0, d - 1], respectively.
128 |         # [ 3 ] Locate the cell in bilateral_grid at position [N, :, z, y, x].
129 |         # [ 4 ] Interpolate using the neighbor values as the output affine matrix.
130 |
131 |         # Force them to have the same type for fp16 training:
132 | guidemap_guide = guidemap_guide.type_as(bilateral_grid)
133 | # bilateral_grid = bilateral_grid.type_as(guidemap_guide)
134 | coeff = F.grid_sample(bilateral_grid, guidemap_guide, 'bilinear', align_corners=True)
135 | return coeff.squeeze(2)
136 |
137 |
138 | class ApplyCoeffs(nn.Module):
139 | def __init__(self):
140 | super(ApplyCoeffs, self).__init__()
141 |
142 | def forward(self, coeff, full_res_input):
143 | '''
144 | coeff shape: [bs, 12, h, w]
145 | input shape: [bs, 3, h, w]
146 | Affine:
147 | r = a11*r + a12*g + a13*b + a14
148 | g = a21*r + a22*g + a23*b + a24
149 | ...
150 | '''
151 | R = torch.sum(full_res_input * coeff[:, 0:3, :, :], dim=1, keepdim=True) + coeff[:, 3:4, :, :]
152 | G = torch.sum(full_res_input * coeff[:, 4:7, :, :], dim=1, keepdim=True) + coeff[:, 7:8, :, :]
153 | B = torch.sum(full_res_input * coeff[:, 8:11, :, :], dim=1, keepdim=True) + coeff[:, 11:12, :, :]
154 |
155 | return torch.cat([R, G, B], dim=1)
156 |
157 |
158 | class ApplyCoeffsGamma(nn.Module):
159 | def __init__(self):
160 | super(ApplyCoeffsGamma, self).__init__()
161 |         print('[ WARN ] Use alternative methods instead of affine matrix.')
162 |
163 | def forward(self, x_r, x):
164 | '''
165 |         coeff (x_r) shape: [bs, 24, h, w]
166 |         apply the zeroDCE curve iteratively.
167 | '''
168 |
169 | # [ 008 ] single iteration alpha map:
170 | # coeff channel num: 3
171 | # return x + x_r * (torch.pow(x, 2) - x)
172 |
173 |         # [ 009 ] 8 iterations:
174 | # coeff channel num: 24
175 | r1, r2, r3, r4, r5, r6, r7, r8 = torch.split(x_r, 3, dim=1)
176 | x = x + r1 * (torch.pow(x, 2) - x)
177 | x = x + r2 * (torch.pow(x, 2) - x)
178 | x = x + r3 * (torch.pow(x, 2) - x)
179 | enhance_image_1 = x + r4 * (torch.pow(x, 2) - x)
180 | x = enhance_image_1 + r5 * (torch.pow(enhance_image_1, 2) - enhance_image_1)
181 | x = x + r6 * (torch.pow(x, 2) - x)
182 | x = x + r7 * (torch.pow(x, 2) - x)
183 | enhance_image = x + r8 * (torch.pow(x, 2) - x)
184 | r = torch.cat([r1, r2, r3, r4, r5, r6, r7, r8], 1)
185 | return enhance_image
186 |
187 | # [ 014 ] use illu map:
188 | # coeff channel num: 3
189 | # return x / (torch.where(x_r < x, x, x_r) + 1e-7)
190 |
191 | # [ 015 ] use HSV and only affine V channel:
192 | # coeff channel num: 3
193 | # V = torch.sum(x * x_r, dim=1, keepdim=True) + x_r
194 | # return torch.cat([x[:, 0:2, ...], V], dim=1)
195 |
196 |
197 | class ApplyCoeffsRetinex(nn.Module):
198 | def __init__(self):
199 | super().__init__()
200 |         print('[ WARN ] Use alternative methods instead of affine matrix.')
201 |
202 | def forward(self, x_r, x):
203 | '''
204 |         coeff (x_r) shape: [bs, 3, h, w]
205 |         apply division by the illumination map.
206 | '''
207 |
208 | # [ 014 ] use illu map:
209 | # coeff channel num: 3
210 | return x / (torch.where(x_r < x, x, x_r) + 1e-7)
211 |
212 |
213 | class GuideNet(nn.Module):
214 | def __init__(self, params=None, out_channel=1):
215 | super(GuideNet, self).__init__()
216 | self.params = params
217 | self.conv1 = ConvBlock(3, 16, kernel_size=1, padding=0, batch_norm=True)
218 | self.conv2 = ConvBlock(16, out_channel, kernel_size=1, padding=0, activation=nn.Sigmoid) # nn.Tanh
219 |
220 | def forward(self, x):
221 | return self.conv2(self.conv1(x)) # .squeeze(1)
222 |
223 |
224 | class LowResNet(nn.Module):
225 | def __init__(self, coeff_dim=12, opt=None):
226 | super(LowResNet, self).__init__()
227 | self.params = opt
228 | self.coeff_dim = coeff_dim
229 |
230 | lb = opt[LUMA_BINS]
231 | cm = opt[CHANNEL_MULTIPLIER]
232 | sb = opt[SPATIAL_BIN]
233 | bn = opt[BATCH_NORM]
234 | nsize = opt[LOW_RESOLUTION]
235 |
236 | self.relu = nn.ReLU()
237 |
238 | # splat features
239 | n_layers_splat = int(np.log2(nsize / sb))
240 | self.splat_features = nn.ModuleList()
241 | prev_ch = 3
242 | for i in range(n_layers_splat):
243 | use_bn = bn if i > 0 else False
244 | self.splat_features.append(ConvBlock(prev_ch, cm * (2 ** i) * lb, 3, stride=2, batch_norm=use_bn))
245 | prev_ch = splat_ch = cm * (2 ** i) * lb
246 |
247 | # global features
248 | n_layers_global = int(np.log2(sb / 4))
249 | # print(n_layers_global)
250 | self.global_features_conv = nn.ModuleList()
251 | self.global_features_fc = nn.ModuleList()
252 | for i in range(n_layers_global):
253 | self.global_features_conv.append(ConvBlock(prev_ch, cm * 8 * lb, 3, stride=2, batch_norm=bn))
254 | prev_ch = cm * 8 * lb
255 |
256 | n_total = n_layers_splat + n_layers_global
257 | prev_ch = prev_ch * (nsize / 2 ** n_total) ** 2
258 | self.global_features_fc.append(FC(prev_ch, 32 * cm * lb, batch_norm=bn))
259 | self.global_features_fc.append(FC(32 * cm * lb, 16 * cm * lb, batch_norm=bn))
260 | self.global_features_fc.append(FC(16 * cm * lb, 8 * cm * lb, activation=None, batch_norm=bn))
261 |
262 | # local features
263 | self.local_features = nn.ModuleList()
264 | self.local_features.append(ConvBlock(splat_ch, 8 * cm * lb, 3, batch_norm=bn))
265 | self.local_features.append(ConvBlock(8 * cm * lb, 8 * cm * lb, 3, activation=None, use_bias=False))
266 |
267 |         # prediction
268 | self.conv_out = ConvBlock(8 * cm * lb, lb * coeff_dim, 1, padding=0, activation=None)
269 |
270 | def forward(self, lowres_input):
271 | params = self.params
272 | bs = lowres_input.shape[0]
273 | lb = params[LUMA_BINS]
274 | cm = params[CHANNEL_MULTIPLIER]
275 | sb = params[SPATIAL_BIN]
276 |
277 | x = lowres_input
278 | for layer in self.splat_features:
279 | x = layer(x)
280 | splat_features = x
281 |
282 | for layer in self.global_features_conv:
283 | x = layer(x)
284 | x = x.view(bs, -1)
285 | for layer in self.global_features_fc:
286 | x = layer(x)
287 | global_features = x
288 |
289 | x = splat_features
290 | for layer in self.local_features:
291 | x = layer(x)
292 | local_features = x
293 |
294 | # shape: bs x 64 x 16 x 16
295 | fusion_grid = local_features
296 |
297 | # shape: bs x 64 x 1 x 1
298 | fusion_global = global_features.view(bs, 8 * cm * lb, 1, 1)
299 | fusion = self.relu(fusion_grid + fusion_global)
300 |
301 | x = self.conv_out(fusion)
302 |
303 | # reshape channel dimension -> bilateral grid dimensions:
304 | # [bs, 96, 16, 16] -> [bs, 12, 8, 16, 16]
305 | y = torch.stack(torch.split(x, self.coeff_dim, 1), 2)
306 | return y
307 |
308 |
309 | class LowResHistUNet(HistUNet):
310 | def __init__(self, coeff_dim=12, opt=None):
311 | super(LowResHistUNet, self).__init__(
312 | in_channels=3,
313 | out_channels=coeff_dim * opt[LUMA_BINS],
314 | bilinear=True,
315 | **opt[HIST_UNET]
316 | )
317 | self.coeff_dim = coeff_dim
318 | print('[[ WARN ]] Using HistUNet in BilateralUpsampleNet as backbone')
319 |
320 | def forward(self, x):
321 | y = super(LowResHistUNet, self).forward(x)
322 | y = torch.stack(torch.split(y, self.coeff_dim, 1), 2)
323 | return y
324 |
325 |
326 | class BilateralUpsampleNet(nn.Module):
327 | def __init__(self, opt):
328 | super(BilateralUpsampleNet, self).__init__()
329 | self.opt = opt
330 | global OPT
331 | OPT = opt
332 | self.guide = GuideNet(params=opt)
333 | self.slice = SliceNode(opt)
334 | self.build_coeffs_network(opt)
335 |
336 | def build_coeffs_network(self, opt):
337 | # Choose backbone:
338 | if opt[BACKBONE] == 'ori':
339 | Backbone = LowResNet
340 | elif opt[BACKBONE] == 'hist-unet':
341 | Backbone = LowResHistUNet
342 | else:
343 | raise NotImplementedError()
344 |
345 | # How to apply coeffs:
346 | # ───────────────────────────────────────────────────────────────────
347 | if opt[COEFFS_TYPE] == MATRIX:
348 | self.coeffs = Backbone(opt=opt)
349 | self.apply_coeffs = ApplyCoeffs()
350 |
351 | elif opt[COEFFS_TYPE] == GAMMA:
352 |             print('[[ WARN ]] HDRPointwiseNN uses COEFFS_TYPE: GAMMA.')
353 |
354 | # [ 008 ] change affine matrix -> other methods (alpha map, illu map)
355 | self.coeffs = Backbone(opt=opt, coeff_dim=24)
356 | self.apply_coeffs = ApplyCoeffsGamma()
357 |
358 | elif opt[COEFFS_TYPE] == 'retinex':
359 |             print('[[ WARN ]] HDRPointwiseNN uses COEFFS_TYPE: retinex.')
360 | self.coeffs = Backbone(opt=opt, coeff_dim=3)
361 | self.apply_coeffs = ApplyCoeffsRetinex()
362 |
363 | else:
364 |             raise NotImplementedError(f'[ ERR ] coeff type {opt[COEFFS_TYPE]} unknown.')
365 | # ─────────────────────────────────────────────────────────────────────────────
366 |
367 | def forward(self, lowres, fullres):
368 | bilateral_grid = self.coeffs(lowres)
369 | try:
370 | self.guide_features = self.coeffs.guide_features
371 | except:
372 | ...
373 | guide = self.guide(fullres)
374 | self.guidemap = guide
375 |
376 | slice_coeffs = self.slice(bilateral_grid, guide)
377 | out = self.apply_coeffs(slice_coeffs, fullres)
378 |
379 | # use illu map:
380 | self.slice_coeffs = slice_coeffs
381 | # if self.opt[PREDICT_ILLUMINATION]:
382 | #
383 | # power = self.opt[ILLU_MAP_POWER]
384 | # if power:
385 | # assert type(power + 0.1) == float
386 | # out = out.pow(power)
387 | #
388 | # out = out.clamp(fullres, torch.ones_like(out))
389 | # # out = torch.where(out < fullres, fullres, out)
390 | # self.illu_map = out
391 | # out = fullres / (out + 1e-7)
392 | # else:
393 | self.illu_map = None
394 |
395 | if self.opt[PREDICT_ILLUMINATION]:
396 | return fullres / (out.clamp(fullres, torch.ones_like(out)) + 1e-7)
397 | else:
398 | return out
399 |
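ApplyCoeffs above can be checked standalone (sizes are illustrative; the import assumes the repo's src/ layout and dependencies): after slicing, every pixel carries a 3x4 affine color matrix packed into 12 coefficient channels, applied row by row to the full-resolution RGB input:

    import torch
    from model.bilateralupsamplenet import ApplyCoeffs

    coeff = torch.rand(1, 12, 32, 32)   # per-pixel 3x4 affine matrices (rows for R, G, B)
    img = torch.rand(1, 3, 32, 32)
    out = ApplyCoeffs()(coeff, img)     # e.g. R = coeff[:, 0:3] . rgb + coeff[:, 3:4]
    print(out.shape)                    # torch.Size([1, 3, 32, 32])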
--------------------------------------------------------------------------------
/src/model/lcdpnet.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 |
3 | import kornia as kn
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 | import torchvision.utils
8 | import utils.util as util
9 |
10 | from globalenv import *
11 | from .arch.nonlocal_block_embedded_gaussian import NONLocalBlock2D
12 | from .basic_loss import L_TV, WeightedL1Loss, HistogramLoss, IntermediateHistogramLoss, LTVloss
13 | # from .hdrunet import tanh_L1Loss
14 | from .single_net_basemodel import SingleNetBaseModel
15 |
16 |
17 | class tanh_L1Loss(nn.Module):
18 | def __init__(self):
19 | super(tanh_L1Loss, self).__init__()
20 |
21 | def forward(self, x, y):
22 | loss = torch.mean(torch.abs(torch.tanh(x) - torch.tanh(y)))
23 | return loss
24 |
25 |
26 | class LitModel(SingleNetBaseModel):
27 | def __init__(self, opt):
28 | super().__init__(opt, DeepWBNet(opt[RUNTIME]), [TRAIN, VALID])
29 | # self.pixel_loss = torch.nn.MSELoss()
30 |
31 | # [ 008-u1 ] use tanh L1 loss
32 | self.pixel_loss = tanh_L1Loss()
33 |
34 | # [ 019 ] use log L1 loss
35 | # self.pixel_loss = LogL2Loss()
36 |
37 | # [ 028 ] use weighted L1 loss.
38 | self.weighted_loss = WeightedL1Loss()
39 | self.tvloss = L_TV()
40 | self.ltv2 = LTVloss()
41 | self.cos = torch.nn.CosineSimilarity(1, 1e-8)
42 | self.histloss = HistogramLoss()
43 | # self.vggloss = VGGLoss(shift=2)
44 | # self.vggloss.train()
45 | self.inter_histloss = IntermediateHistogramLoss()
46 |
47 | def training_step(self, batch, batch_idx):
48 | input_batch, gt_batch, output_batch = super().training_step_forward(batch, batch_idx)
49 | # print('[*] Now running:', batch[INPUT].shape, batch[GT].shape, output_batch.shape, batch[INPUT_FPATH], batch[GT_FPATH])
50 |
51 | loss_lambda_map = {
52 | L1_LOSS: lambda: self.pixel_loss(output_batch, gt_batch),
53 | COS_LOSS: lambda: (1 - self.cos(output_batch, gt_batch).mean()) * 0.5,
54 | COS_LOSS + '2': lambda: 1 - F.sigmoid(self.cos(output_batch, gt_batch).mean()),
55 | LTV_LOSS: lambda: self.tvloss(output_batch),
56 | 'tvloss1': lambda: self.tvloss(self.net.res[ILLU_MAP]),
57 | 'tvloss2': lambda: self.tvloss(self.net.res[INVERSE_ILLU_MAP]),
58 |
59 | 'tvloss1_new': lambda: self.ltv2(input_batch, self.net.res[ILLU_MAP], 1),
60 | 'tvloss2_new': lambda: self.ltv2(1 - input_batch, self.net.res[INVERSE_ILLU_MAP], 1),
61 | 'illumap_loss': lambda: F.mse_loss(self.net.res[ILLU_MAP], 1 - self.net.res[INVERSE_ILLU_MAP]),
62 | WEIGHTED_LOSS: lambda: self.weighted_loss(input_batch.detach(), output_batch, gt_batch),
63 | SSIM_LOSS: lambda: kn.losses.ssim_loss(output_batch, gt_batch, window_size=5),
64 | PSNR_LOSS: lambda: kn.losses.psnr_loss(output_batch, gt_batch, max_val=1.0),
65 | HIST_LOSS: lambda: self.histloss(output_batch, gt_batch),
66 | INTER_HIST_LOSS: lambda: self.inter_histloss(
67 | input_batch, gt_batch, self.net.res[BRIGHTEN_INPUT], self.net.res[DARKEN_INPUT]),
68 | VGG_LOSS: lambda: self.vggloss(input_batch, gt_batch),
69 | }
70 | loss = self.calc_and_log_losses(loss_lambda_map)
71 |
72 | # logging images:
73 | self.log_training_iogt_img(batch)
74 | return loss
75 |
76 | def validation_step(self, batch, batch_idx):
77 | super().validation_step(batch, batch_idx)
78 |
79 | def test_step(self, batch, batch_ix):
80 | super().test_step(batch, batch_ix)
81 |
82 |         # save intermediate results
83 | for k, v in self.net.res.items():
84 | dirpath = Path(self.opt[IMG_DIRPATH]) / k
85 | fname = osp.basename(batch[INPUT_FPATH][0])
86 | if 'illu' in k:
87 | util.mkdir(dirpath)
88 | torchvision.utils.save_image(v[0].unsqueeze(1), dirpath / fname)
89 | elif k == 'guide_features':
90 | # v.shape: [bs, region_num, h, w]
91 | util.mkdir(dirpath)
92 | max_size = v[-1][-1].shape[-2:]
93 | final = []
94 | for level_guide in v:
95 | gs = [F.interpolate(g, max_size) for g in level_guide]
96 | final.extend(gs)
97 | # import ipdb
98 | # ipdb.set_trace()
99 | region_num = final[0].shape[1]
100 | final = torch.stack(final).argmax(axis=2).float() / region_num
101 | # ipdb.set_trace()
102 | torchvision.utils.save_image(final, dirpath / fname)
103 | else:
104 | self.save_img_batch(v, dirpath, fname)
105 |
106 |
107 | class DeepWBNet(nn.Module):
108 | def build_illu_net(self):
109 | # if self.opt[BACKBONE] == 'unet':
110 | # if self.opt[USE_ATTN_MAP]:
111 | # return UNet(
112 | # self.opt,
113 | # in_channels=4,
114 | # out_channels=1,
115 | # wavelet=self.opt[USE_WAVELET],
116 | # non_local=self.opt[NON_LOCAL]
117 | # )
118 | # else:
119 | # return UNet(self.opt, out_channels=self.opt[ILLUMAP_CHANNEL], wavelet=self.opt[USE_WAVELET])
120 |
121 | from .bilateralupsamplenet import BilateralUpsampleNet
122 | return BilateralUpsampleNet(self.opt[BUNET])
123 | #
124 | # elif self.opt[BACKBONE] == 'ynet':
125 | # from .arch.ynet import YNet
126 | # return YNet()
127 | #
128 | # elif self.opt[BACKBONE] == 'hdrunet':
129 | # from .hdrunet import HDRUNet
130 | # return HDRUNet()
131 | #
132 | # elif self.opt[BACKBONE] == 'hist-unet':
133 | # from model.arch.unet_based.hist_unet import HistUNet
134 | # return HistUNet(**self.opt[HIST_UNET])
135 | #
136 | # else:
137 | # raise NotImplementedError(f'[[ ERR ]] Unknown backbone arch: {self.opt[BACKBONE]}')
138 |
139 | def backbone_forward(self, net, x):
140 | if self.opt[BACKBONE] in ['unet', 'hdrunet', 'hist-unet']:
141 | return net(x)
142 |
143 | elif self.opt[BACKBONE] == 'ynet':
144 | return net.forward_2input(x, 1 - x)
145 |
146 | elif self.opt[BACKBONE] == BUNET:
147 | low_x = self.down_sampler(x)
148 | res = net(low_x, x)
149 | try:
150 | self.res.update({'guide_features': net.guide_features})
151 | except:
152 | ...
153 | # print('[yellow]No guide feature found in BilateralUpsampleNet[/yellow]')
154 | return res
155 |
156 | def __init__(self, opt=None):
157 | super(DeepWBNet, self).__init__()
158 | self.opt = opt
159 | self.down_sampler = lambda x: F.interpolate(x, size=(256, 256), mode='bicubic', align_corners=False)
160 | self.illu_net = self.build_illu_net()
161 |
162 | # [ 021 ] use 2 illu nets (do not share weights).
163 | if not opt[SHARE_WEIGHTS]:
164 | self.illu_net2 = self.build_illu_net()
165 |
166 | # self.guide_net = GuideNN(out_channel=3)
167 | if opt[HOW_TO_FUSE] in ['cnn-weights', 'cnn-direct', 'cnn-softmax3']:
168 | # self.out_net = UNet(in_channels=9, wavelet=opt[USE_WAVELET])
169 |
170 | # [ 008-u1 ] use a simple network
171 | nf = 32
172 | self.out_net = nn.Sequential(
173 | nn.Conv2d(9, nf, 3, 1, 1),
174 | nn.ReLU(inplace=True),
175 | nn.Conv2d(nf, nf, 3, 1, 1),
176 | nn.ReLU(inplace=True),
177 | NONLocalBlock2D(nf, sub_sample='bilinear', bn_layer=False),
178 | nn.Conv2d(nf, nf, 1),
179 | nn.ReLU(inplace=True),
180 | nn.Conv2d(nf, 3, 1),
181 | NONLocalBlock2D(3, sub_sample='bilinear', bn_layer=False),
182 | )
183 |
184 | elif opt[HOW_TO_FUSE] in ['cnn-color']:
185 | # self.out_net = UNet(in_channels=3, wavelet=opt[USE_WAVELET])
186 | ...
187 |
188 | if not self.opt[BACKBONE_OUT_ILLU]:
189 | print('[[ WARN ]] Using backbone output directly as the brighten & darken results.')
190 | self.res = {}
191 |
192 | def decomp(self, x1, illu_map):
193 | return x1 / (torch.where(illu_map < x1, x1, illu_map.float()) + 1e-7)
194 |
195 | def one_iter(self, x, attn_map, inverse_attn_map):
196 | # used only when USE_ATTN_MAP
197 | x1 = torch.cat((x, attn_map), 1)
198 | inverse_x1 = torch.cat((1 - x, inverse_attn_map), 1)
199 |
200 | illu_map = self.illu_net(x1, attn_map)
201 | inverse_illu_map = self.illu_net(inverse_x1)
202 | return illu_map, inverse_illu_map
203 |
204 | def forward(self, x):
205 | # ──────────────────────────────────────────────────────────
206 | # [ <008 ] use guideNN
207 | # x1 = self.guide_net(x).clamp(0, 1)
208 |
209 | # [ 008 ] use original input
210 | x1 = x
211 | inverse_x1 = 1 - x1
212 |
213 | if self.opt[USE_ATTN_MAP]:
214 | # [ 015 ] use attn map iteration to get illu map
215 | r, g, b = x[:, 0] + 1, x[:, 1] + 1, x[:, 2] + 1
216 |
217 | # init attn map as illumination channel of original input img:
218 | attn_map = (1. - (0.299 * r + 0.587 * g + 0.114 * b) / 2.).unsqueeze(1)
219 | inverse_attn_map = 1 - attn_map
220 | for _ in range(3):
221 | inverse_attn_map, attn_map = self.one_iter(x, attn_map, inverse_attn_map)
222 | illu_map, inverse_illu_map = inverse_attn_map, attn_map
223 |
224 | elif self.opt[BACKBONE] == 'ynet':
225 | # [ 024 ] one encoder, 2 decoders.
226 | illu_map, inverse_illu_map = self.backbone_forward(self.illu_net, x1)
227 |
228 | else:
229 | illu_map = self.backbone_forward(self.illu_net, x1)
230 | if self.opt[SHARE_WEIGHTS]:
231 | inverse_illu_map = self.backbone_forward(self.illu_net, inverse_x1)
232 | else:
233 | # [ 021 ] use 2 illu nets
234 | inverse_illu_map = self.backbone_forward(self.illu_net2, inverse_x1)
235 | # ──────────────────────────────────────────────────────────
236 |
237 | if self.opt[BACKBONE_OUT_ILLU]:
238 | brighten_x1 = self.decomp(x1, illu_map)
239 | inverse_x2 = self.decomp(inverse_x1, inverse_illu_map)
240 | else:
241 | brighten_x1 = illu_map
242 | inverse_x2 = inverse_illu_map
243 | darken_x1 = 1 - inverse_x2
244 | # ──────────────────────────────────────────────────────────
245 |
246 | self.res.update({
247 | ILLU_MAP: illu_map,
248 | INVERSE_ILLU_MAP: inverse_illu_map,
249 | BRIGHTEN_INPUT: brighten_x1,
250 | DARKEN_INPUT: darken_x1,
251 | })
252 |
253 | # fusion:
254 | # ──────────────────────────────────────────────────────────
255 | if self.opt[HOW_TO_FUSE] == 'cnn-weights':
256 | # [ 009 ] only fuse 2 output image
257 | # fused_x = torch.cat([brighten_x1, darken_x1], dim=1)
258 |
259 | fused_x = torch.cat([x, brighten_x1, darken_x1], dim=1)
260 |
261 | # [ 007 ] get weight-map from UNet, then get output from weight-map
262 | weight_map = self.out_net(fused_x) # <- 3 channels, [ N, 3, H, W ]
263 | w1 = weight_map[:, 0, ...].unsqueeze(1)
264 | w2 = weight_map[:, 1, ...].unsqueeze(1)
265 | w3 = weight_map[:, 2, ...].unsqueeze(1)
266 | out = x * w1 + brighten_x1 * w2 + darken_x1 * w3
267 |
268 | # [ 009 ] only fuse 2 output image
269 | # out = brighten_x1 * w1 + darken_x1 * w2
270 | # ────────────────────────────────────────────────────────────
271 |
272 | elif self.opt[HOW_TO_FUSE] == 'cnn-softmax3':
273 | fused_x = torch.cat([x, brighten_x1, darken_x1], dim=1)
274 | weight_map = F.softmax(self.out_net(fused_x), dim=1) # <- 3 channels, [ N, 3, H, W ]
275 | w1 = weight_map[:, 0, ...].unsqueeze(1)
276 | w2 = weight_map[:, 1, ...].unsqueeze(1)
277 | w3 = weight_map[:, 2, ...].unsqueeze(1)
278 | out = x * w1 + brighten_x1 * w2 + darken_x1 * w3
279 |
280 | # [ 006 ] get output directly from UNet
281 | elif self.opt[HOW_TO_FUSE] == 'cnn-direct':
282 | fused_x = torch.cat([x, brighten_x1, darken_x1], dim=1)
283 | out = self.out_net(fused_x)
284 |
285 | # [ 016 ] average 2 outputs.
286 | elif self.opt[HOW_TO_FUSE] == 'avg':
287 | out = 0.5 * brighten_x1 + 0.5 * darken_x1
288 |
289 | # [ 017 ] global color adjust
290 | elif self.opt[HOW_TO_FUSE] == 'cnn-color':
291 | out = 0.5 * brighten_x1 + 0.5 * darken_x1
292 |
293 | # elif self.opt[HOW_TO_FUSE] == 'cnn-residual':
294 | # out = x +
295 |
296 | else:
297 | raise NotImplementedError(f'Unknown fusion method: {self.opt[HOW_TO_FUSE]}')
298 | return out
299 |
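The core of DeepWBNet.forward above is a Retinex-style decomposition followed by a learned per-pixel fusion. Below is a minimal, self-contained sketch of that data flow; the random tensors stand in for the real illumination maps and for out_net's weight prediction, so this is only an illustration of the shapes and arithmetic, not the trained model.

    import torch

    def decomp(x, illu_map):
        # same as DeepWBNet.decomp: divide by max(x, illu_map), with eps for stability
        return x / (torch.where(illu_map < x, x, illu_map.float()) + 1e-7)

    x = torch.rand(1, 3, 64, 64)                 # input image in [0, 1]
    illu = torch.rand(1, 3, 64, 64)              # stand-in for illu_net(x)
    inverse_illu = torch.rand(1, 3, 64, 64)      # stand-in for illu_net(1 - x)

    brighten_x1 = decomp(x, illu)                # brightened candidate
    darken_x1 = 1 - decomp(1 - x, inverse_illu)  # darkened candidate

    # 'cnn-weights' fusion: a 3-channel weight map blends the three candidates
    weight_map = torch.rand(1, 3, 64, 64)        # stand-in for out_net(cat([x, brighten, darken]))
    w1, w2, w3 = (weight_map[:, i].unsqueeze(1) for i in range(3))
    out = x * w1 + brighten_x1 * w2 + darken_x1 * w3
    print(out.shape)                             # torch.Size([1, 3, 64, 64])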
--------------------------------------------------------------------------------
/src/model/single_net_basemodel.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import os.path as osp
4 |
5 | import cv2
6 | import ipdb
7 | import torch
8 | import torch.nn.functional as F
9 | import torch.optim as optim
10 | import torchmetrics
11 |
12 | import utils.util as util
13 | from globalenv import *
14 | from .basemodel import BaseModel
15 |
16 |
17 | # dict_merge = lambda a, b: a.update(b) or a
18 |
19 |
20 | class SingleNetBaseModel(BaseModel):
21 | # for models with only one self.net
22 | def __init__(self, opt, net, running_modes, valid_ssim=False, print_arch=True):
23 | super().__init__(opt, running_modes)
24 | self.net = net
25 | self.net.train()
26 |
27 | # config for SingleNetBaseModel
28 | if print_arch:
29 | print(str(net))
30 | self.valid_ssim = valid_ssim # whether to compute SSIM in validation
31 | self.tonemapper = cv2.createTonemapReinhard(2.2)
32 |
33 | self.psnr_func = torchmetrics.PeakSignalNoiseRatio(data_range=1)
34 | self.ssim_func = torchmetrics.StructuralSimilarityIndexMeasure(data_range=1)
35 |
36 | def configure_optimizers(self):
37 | # self.parameters in a LitModel behaves the same as in nn.Module:
38 | # once an nn.Module is assigned as an attribute in __init__, self.parameters() will include it.
39 | optimizer = optim.Adam(self.net.parameters(), lr=self.learning_rate)
40 | # optimizer = optim.Adam(self.net.parameters(), lr=self.opt[LR])
41 |
42 | scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)
43 | return [optimizer], [scheduler]
44 |
45 | def forward(self, x):
46 | return self.net(x)
47 |
48 | def training_step_forward(self, batch, batch_idx):
49 | if not self.MODEL_WATCHED and not self.opt[DEBUG] and self.opt.logger == 'wandb':
50 | self.logger.experiment.watch(
51 | self.net, log_freq=self.opt[LOG_EVERY] * 2, log_graph=True)
52 | self.MODEL_WATCHED = True
53 | # self.show_flops_and_param_num([batch[INPUT]])
54 |
55 | input_batch, gt_batch = batch[INPUT], batch[GT]
56 | output_batch = self(input_batch)
57 | self.iogt = {
58 | INPUT: input_batch,
59 | OUTPUT: output_batch,
60 | GT: gt_batch,
61 | }
62 | return input_batch, gt_batch, output_batch
63 |
64 | def validation_step(self, batch, batch_idx):
65 | input_batch, gt_batch = batch[INPUT], batch[GT]
66 | output_batch = self(input_batch)
67 |
68 | # log psnr
69 | output_ = util.cuda_tensor_to_ndarray(output_batch)
70 | y_ = util.cuda_tensor_to_ndarray(gt_batch)
71 | try:
72 | psnr = util.ImageProcessing.compute_psnr(output_, y_, 1.0)
73 | except:
74 | ipdb.set_trace()
75 | self.log(PSNR, psnr)
76 |
77 | # log SSIM (optional)
78 | if self.valid_ssim:
79 | ssim = util.ImageProcessing.compute_ssim(output_batch, gt_batch)
80 | self.log(SSIM, ssim)
81 |
82 | # log images
83 | if self.global_valid_step % self.opt.log_every == 0:
84 | self.log_images_dict(
85 | VALID,
86 | osp.basename(batch[INPUT_FPATH][0]),
87 | {
88 | INPUT: input_batch,
89 | OUTPUT: output_batch,
90 | GT: gt_batch,
91 | },
92 | gt_fname=osp.basename(batch[GT_FPATH][0])
93 | )
94 | self.global_valid_step += 1
95 | return output_batch
96 |
97 | def log_training_iogt_img(self, batch, extra_img_dict=None):
98 | """
99 | Only used in training_step
100 | """
101 | if extra_img_dict:
102 | img_dict = {**self.iogt, **extra_img_dict}
103 | else:
104 | img_dict = self.iogt
105 |
106 | if self.global_step % self.opt.log_every == 0:
107 | self.log_images_dict(
108 | TRAIN,
109 | osp.basename(batch[INPUT_FPATH][0]),
110 | img_dict,
111 | gt_fname=osp.basename(batch[GT_FPATH][0])
112 | )
113 |
114 | @staticmethod
115 | def logdomain2hdr(ldr_batch):
116 | return 10 ** ldr_batch - 1
117 |
118 | def on_test_start(self):
119 | self.total_psnr = 0
120 | self.total_ssim = 0
121 | self.global_test_step = 0
122 |
123 | def on_test_end(self):
124 | print(
125 | f'Test step: {self.global_test_step}, Manual PSNR: {self.total_psnr / self.global_test_step}, Manual SSIM: {self.total_ssim / self.global_test_step}')
126 |
127 | def test_step(self, batch, batch_ix):
128 | """
129 | save the test results and calculate PSNR and SSIM for `self.net` (when GT is available)
130 | """
131 | # test without GT image:
132 | self.global_test_step += 1
133 | input_batch = batch[INPUT]
134 | assert input_batch.shape[0] == 1
135 | output_batch = self(input_batch)
136 | save_num = 1
137 | # visualized_batch = torch.cat([input_batch, output_batch])
138 | # save_num = 2
139 |
140 | # test with GT:
141 | # if GT in batch:
142 | # gt_batch = batch[GT]
143 | # if output_batch.shape != batch[GT].shape:
144 | # print(
145 | # f'[[ WARN ]] output.shape is {output_batch.shape} but GT.shape is {batch[GT].shape}. Resize to get PSNR.')
146 | # gt_batch = F.interpolate(batch[GT], output_batch.shape[2:])
147 | #
148 | # visualized_batch = torch.cat([visualized_batch, gt_batch])
149 | # save_num = 3
150 | #
151 | # # calculate metrics:
152 | # # psnr = float(self.psnr_func(output_batch, gt_batch).cpu().numpy())
153 | # # ssim = float(self.ssim_func(output_batch, gt_batch).cpu().numpy())
154 | # # ipdb.set_trace()
155 | #
156 | # # output_ = util.cuda_tensor_to_ndarray(output_batch)
157 | # # y_ = util.cuda_tensor_to_ndarray(gt_batch)
158 | # # psnr = util.ImageProcessing.compute_psnr(output_, y_, 1.0)
159 | # # ssim = util.ImageProcessing.compute_ssim(output_, y_)
160 | # # self.log_dict({
161 | # # 'test-' + PSNR: psnr,
162 | # # 'test-' + SSIM: ssim
163 | # # }, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
164 | # # self.total_psnr += psnr
165 | # # self.total_ssim += ssim
166 | # # print(
167 | # # f'{batch[INPUT_FPATH][0].split("/")[-1]}: psnr: {psnr:.4f}, ssim: {ssim:.4f}, avgpsnr: {self.total_psnr / self.global_test_step:.4f}, avgssim: {self.total_ssim / self.global_test_step:.4f}')
168 | #
169 | # # if batch[GT_FPATH][0].endswith('.hdr'):
170 | # # # output and GT -> HDR domain -> tonemap back to LDR
171 | # # output_vis = util.cuda_tensor_to_ndarray(
172 | # # self.logdomain2hdr(visualized_batch[1]).permute(1, 2, 0))
173 | # # gt_vis = util.cuda_tensor_to_ndarray(
174 | # # self.logdomain2hdr(visualized_batch[2]).permute(1, 2, 0))
175 | # # output_ldr = self.tonemapper.process(output_vis)
176 | # # gt_ldr = self.tonemapper.process(gt_vis)
177 | # # visualized_batch[1] = torch.tensor(
178 | # # output_ldr).permute(2, 0, 1).cuda()
179 | # # visualized_batch[2] = torch.tensor(
180 | # # gt_ldr).permute(2, 0, 1).cuda()
181 | # #
182 | # # hdr_outpath = Path(self.opt[IMG_DIRPATH]) / 'hdr_output'
183 | # # util.mkdir(hdr_outpath)
184 | # # cv2.imwrite(
185 | # # str(hdr_outpath / (osp.basename(batch[INPUT_FPATH][0]) + '.hdr')), output_vis)
186 |
187 | # save images
188 | self.save_img_batch(
189 | output_batch,
190 | self.opt[IMG_DIRPATH],
191 | osp.basename(batch[INPUT_FPATH][0]),
192 | save_num=save_num
193 | )
194 |
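SingleNetBaseModel.logdomain2hdr maps a log-domain batch back to linear HDR with 10 ** x - 1; its inverse is log10(1 + x). A quick round-trip check with synthetic values (not taken from the repo's data pipeline):

    import torch

    hdr = torch.tensor([0.0, 1.0, 9.0, 99.0])   # synthetic linear HDR intensities
    ldr = torch.log10(hdr + 1)                  # log-domain representation: [0, ~0.30, 1, 2]
    restored = 10 ** ldr - 1                    # what logdomain2hdr computes
    print(torch.allclose(restored, hdr))        # True (up to float precision)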
--------------------------------------------------------------------------------
/src/test.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import time
3 |
4 | import hydra
5 | import pytorch_lightning as pl
6 | from omegaconf import open_dict
7 | from pytorch_lightning import Trainer
8 |
9 | from globalenv import *
10 | from utils.util import parse_config
11 |
12 | pl.seed_everything(GLOBAL_SEED)
13 |
14 |
15 | @hydra.main(config_path='config', config_name="config")
16 | def main(opt):
17 | opt = parse_config(opt, TEST)
18 | print('Running config:', opt)
19 | from model.lcdpnet import LitModel as ModelClass
20 | ckpt = opt[CHECKPOINT_PATH]
21 | assert ckpt
22 | model = ModelClass.load_from_checkpoint(ckpt, opt=opt)
23 | # model.opt = opt
24 | with open_dict(opt):
25 | model.opt[IMG_DIRPATH] = model.build_test_res_dir()
26 | opt.mode = 'test'
27 | print(f'Loading model from: {ckpt}')
28 |
29 | from data.img_dataset import DataModule
30 | datamodule = DataModule(opt)
31 |
32 | trainer = Trainer(
33 | gpus=opt[GPU],
34 | strategy=opt[BACKEND],
35 | precision=opt[RUNTIME_PRECISION])
36 |
37 | beg = time.time()
38 | trainer.test(model, datamodule)
39 | print(f'[ TIMER ] Total time usage: {time.time() - beg}')
40 | print('[ PATH ] The results are in :')
41 | print(model.opt[IMG_DIRPATH])
42 |
43 |
44 | if __name__ == "__main__":
45 | main()
46 |
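Hydra hands main() a struct-mode OmegaConf object, which is why the IMG_DIRPATH assignment above is wrapped in open_dict: struct mode rejects keys that are not already present in the config. A small standalone illustration with a toy config (not the project's actual config tree):

    from omegaconf import OmegaConf, open_dict

    cfg = OmegaConf.create({'mode': 'train'})
    OmegaConf.set_struct(cfg, True)          # Hydra configs arrive in struct mode
    try:
        cfg.img_dirpath = '/tmp/results'     # unknown key -> rejected in struct mode
    except Exception as err:
        print(type(err).__name__)            # ConfigAttributeError
    with open_dict(cfg):
        cfg.img_dirpath = '/tmp/results'     # allowed while struct mode is suspended
    print(cfg.img_dirpath)                   # /tmp/results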
--------------------------------------------------------------------------------
/src/train.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import os
3 | import platform
4 | import time
5 | import warnings
6 |
7 | import hydra
8 | import pytorch_lightning as pl
9 | import torch
10 | from omegaconf import open_dict
11 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
12 | from pytorch_lightning.loggers import WandbLogger, TensorBoardLogger
13 |
14 | import utils.util as util
15 | from globalenv import *
16 | from utils.util import parse_config, init_logging
17 | from data.img_dataset import DataModule
18 |
19 | pl.seed_everything(GLOBAL_SEED)
20 |
21 |
22 | @hydra.main(config_path='config', config_name="config")
23 | def main(config):
24 | try:
25 | print('GPU status info:')
26 | os.system('nvidia-smi')
27 | except:
28 | ...
29 |
30 | # print(config)
31 | opt = parse_config(config, TRAIN)
32 | if opt.name == DEBUG:
33 | opt.debug = True
34 |
35 | if opt.debug:
36 | # ipdb.set_trace()
37 | mylogger = None
38 | if opt.checkpoint_path:
39 | continue_epoch = torch.load(
40 | opt.checkpoint_path, map_location=torch.device('cpu'))['global_step']
41 | debug_config = {
42 | DATALOADER_N: 0,
43 | NAME: DEBUG,
44 | LOG_EVERY: 1,
45 | VALID_EVERY: 1,
46 | NUM_EPOCH: 2 if not opt.checkpoint_path else continue_epoch + 2
47 | }
48 | opt.update(debug_config)
49 | debug_str = '[red]>>>> [[ WARN ]] You are in debug mode; configs overridden. <<<<[/red]'
50 | print(f'{debug_str}\n{debug_config}\n{debug_str}')
51 |
52 | else:
53 | # rename the exp
54 | spl = '_' if platform.system() == 'Windows' else ':'
55 | opt.name = f'{opt.runtime.modelname}{spl}{opt.name}@{opt.train_ds.name}'
56 |
57 | # trainer logger. init early to record all console output.
58 | mylogger = TensorBoardLogger(
59 | name=opt.name,
60 | save_dir=ROOT_PATH / 'tb_logs',
61 | )
62 |
63 | # init logging
64 | print('Running config:', opt)
65 | # opt[LOG_DIRPATH], opt.img_dirpath = init_logging(TRAIN, opt)
66 | with open_dict(opt):
67 | opt.log_dirpath, opt.img_dirpath = init_logging(TRAIN, opt)
68 |
69 | # load data
70 | # DataModuleClass = parse_ds_class(opt[TRAIN_DATA][CLASS])
71 | datamodule = DataModule(opt)
72 |
73 | # callbacks:
74 | callbacks = []
75 | if opt[EARLY_STOP]:
76 | print(
77 | f'Applying EarlyStopping on `{opt.checkpoint_monitor}` (mode: {opt.monitor_mode})')
78 | callbacks.append(EarlyStopping(
79 | opt.checkpoint_monitor, mode=opt.monitor_mode))
80 |
81 | # checkpoint callback:
82 | checkpoint_callback = ModelCheckpoint(
83 | dirpath=opt[LOG_DIRPATH],
84 | save_last=True,
85 | save_top_k=5,
86 | mode=opt.monitor_mode,
87 | monitor=opt.checkpoint_monitor,
88 | save_on_train_epoch_end=True,
89 | every_n_epochs=opt.savemodel_every
90 | )
91 | callbacks.append(checkpoint_callback)
92 |
93 | if opt[AMP_BACKEND] != 'native':
94 | print(
95 | f'WARN: Running in APEX, mode: {opt[AMP_BACKEND]}-{opt[AMP_LEVEL]}')
96 | else:
97 | opt[AMP_LEVEL] = None
98 |
99 | # init trainer:
100 | trainer = pl.Trainer(
101 | gpus=opt[GPU],
102 | max_epochs=opt[NUM_EPOCH],
103 | logger=mylogger,
104 | callbacks=callbacks,
105 | check_val_every_n_epoch=opt[VALID_EVERY],
106 | num_sanity_val_steps=opt[VAL_DEBUG_STEP_NUMS],
107 | strategy=opt[BACKEND],
108 | precision=opt[RUNTIME_PRECISION],
109 | amp_backend=opt[AMP_BACKEND],
110 | amp_level=opt[AMP_LEVEL],
111 | **opt.flags
112 | )
113 | print('Trainer initialized.')
114 |
115 | # training loop
116 | from model.lcdpnet import LitModel as ModelClass
117 | if opt.checkpoint_path and not opt.resume_training:
118 | print('Load ckpt and train from step 0...')
119 | model = ModelClass.load_from_checkpoint(opt.checkpoint_path, opt=opt)
120 | trainer.fit(model, datamodule)
121 | else:
122 | model = ModelClass(opt)
123 | print(f'Continue training: {opt.checkpoint_path}')
124 | trainer.fit(model, datamodule, ckpt_path=opt.checkpoint_path)
125 |
126 |
127 | if __name__ == "__main__":
128 | main()
129 |
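For reference, the experiment renaming in the non-debug branch produces names of the form below, using '_' as the separator on Windows (where ':' is not allowed in paths) and ':' elsewhere; the config values here are made up:

    import platform

    modelname, exp_name, ds_name = 'lcdpnet', 'myrun', 'mytrainset'   # hypothetical values
    spl = '_' if platform.system() == 'Windows' else ':'
    print(f'{modelname}{spl}{exp_name}@{ds_name}')
    # e.g. 'lcdpnet:myrun@mytrainset' on Linux / macOS, 'lcdpnet_myrun@mytrainset' on Windows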
--------------------------------------------------------------------------------
/src/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/onpix/LCDPNet/4faa0d98e8ff45f53a3569dd005a74353995b335/src/utils/__init__.py
--------------------------------------------------------------------------------
/src/utils/util.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | import logging
3 | import os
4 | import os.path as osp
5 | import smtplib
6 | from email.mime.multipart import MIMEMultipart
7 | from email.mime.text import MIMEText
8 |
9 | import cv2
10 | import ipdb
11 | import numpy as np
12 | import omegaconf
13 | import torch
14 | import yaml
15 | from PIL import Image
16 | from matplotlib.image import imread
17 | from skimage.metrics import structural_similarity as calc_ssim
18 | from torch.autograd import Variable
19 |
20 | from globalenv import *
21 |
22 |
23 | # from model import parse_model_class
24 |
25 |
26 | def update_global_opt(global_opt, valued_opt):
27 | for k, v in valued_opt.items():
28 | global_opt[k] = v
29 |
30 |
31 | def mkdir(dirpath):
32 | if not osp.exists(dirpath):
33 | print(f'Creating directory: "{dirpath}"')
34 | try:
35 | os.makedirs(dirpath)
36 | except:
37 | ipdb.set_trace()
38 | return
39 | # print(f'Directory {dirpath} already exists, skip creating.')
40 |
41 |
42 | def cuda_tensor_to_ndarray(cuda_tensor):
43 | return cuda_tensor.clone().detach().cpu().numpy()
44 |
45 |
46 | def tensor2pil(img):
47 | img = img.squeeze() # * 0.5 + 0.5
48 | return Image.fromarray(img.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy())
49 |
50 |
51 | #
52 | # def calculate_psnr(img1, img2):
53 | # # img1 and img2 have range [0, 255]
54 | # # shape: [H, W, C]
55 | # img1 = img1.astype(np.float64)
56 | # img2 = img2.astype(np.float64)
57 | # mse = np.mean((img1 - img2) ** 2)
58 | # if mse == 0:
59 | # return float('inf')
60 | # return 20 * np.log10(255.0 / np.sqrt(mse))
61 | #
62 | #
63 | # def ssim_com(img1, img2):
64 | # C1 = (0.01 * 255) ** 2
65 | # C2 = (0.03 * 255) ** 2
66 | #
67 | # img1 = img1.astype(np.float64)
68 | # img2 = img2.astype(np.float64)
69 | # kernel = cv2.getGaussianKernel(11, 1.5)
70 | # window = np.outer(kernel, kernel.transpose())
71 | #
72 | # mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
73 | # mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
74 | # mu1_sq = mu1 ** 2
75 | # mu2_sq = mu2 ** 2
76 | # mu1_mu2 = mu1 * mu2
77 | # sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
78 | # sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
79 | # sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
80 | #
81 | # ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
82 | # (sigma1_sq + sigma2_sq + C2))
83 | # return ssim_map.mean()
84 | #
85 | #
86 | # def calculate_ssim(img1, img2):
87 | # """
88 | # calculate SSIM
89 | # the same outputs as MATLAB's
90 | # img1, img2: [0, 255] numpy array. shape: [H, W, C]
91 | # """
92 | # if not img1.shape == img2.shape:
93 | # raise ValueError('Input images must have the same dimensions.')
94 | # if img1.ndim == 2:
95 | # return ssim_com(img1, img2)
96 | # elif img1.ndim == 3:
97 | # if img1.shape[2] == 3:
98 | # ssims = []
99 | # for i in range(3):
100 | # ssims.append(ssim_com(img1, img2))
101 | # return np.array(ssims).mean()
102 | # elif img1.shape[2] == 1:
103 | # return ssim_com(np.squeeze(img1), np.squeeze(img2))
104 | # else:
105 | # raise ValueError('Wrong input image dimensions.')
106 |
107 |
108 | def save_opt(dirpath, opt):
109 | save_opt_fpath = dirpath / OPT_FILENAME
110 |
111 | with save_opt_fpath.open('w', encoding="utf-8") as f:
112 | yaml.dump(omega2dict(opt), f, default_flow_style=False)
113 |
114 |
115 | def init_logging(mode, opt):
116 | # mode: only for training
117 | assert mode == TRAIN
118 |
119 | log_dirpath = ROOT_PATH / TRAIN_LOG_DIRNAME / opt[RUNTIME][MODELNAME] / opt[NAME]
120 | # + datetime.datetime.now().strftime(LOG_TIME_FORMAT)
121 | img_dirpath = log_dirpath / IMAGES
122 |
123 | mkdir(log_dirpath)
124 | mkdir(img_dirpath)
125 | save_opt(log_dirpath, opt)
126 |
127 | # pl_logger = logging.getLogger("lightning")
128 | # pl_logger.propagate = False
129 |
130 | return str(log_dirpath), str(img_dirpath)
131 |
132 |
133 | def saveTensorAsImg(output, path, downsample_factor=False):
134 | # save image of A BATCH or a SINGLE IMAGE.
135 | # input dtype: must be Tensor
136 | output = output.squeeze(0)
137 |
138 | if len(output.shape) == 4:
139 | # a batch
140 | res = []
141 | for i in range(len(output)):
142 | res.append(saveTensorAsImg(output[i], f'{path}-{i}.png', downsample_factor))
143 | return res
144 |
145 | # a single image
146 | assert len(output.shape) == 3
147 | outImg = cuda_tensor_to_ndarray(
148 | output.permute(1, 2, 0)
149 | ) * 255.0
150 | outImg = outImg[:, :, [2, 1, 0]].astype(np.uint8)
151 |
152 | if downsample_factor:
153 | assert isinstance(downsample_factor, (int, float))
154 | h = outImg.shape[0]
155 | w = outImg.shape[1]
156 | outImg = cv2.resize(outImg, (int(w / downsample_factor), int(h / downsample_factor))).astype(np.uint8)
157 |
158 | cv2.imwrite(path, outImg)
159 | return outImg
160 |
161 |
162 | def parse_config(opt, mode):
163 | def checkField(opt, name, raise_msg):
164 | try:
165 | assert name in opt
166 | except:
167 | raise RuntimeError(raise_msg)
168 |
169 | # check necessary arguments for ALL MODELS and ALL MODES:
170 | # for x in GENERAL_NECESSARY_ARGUMENTS:
171 | # checkField(opt, x, ARGUMENTS_MISSING_ERRS[x])
172 |
173 | # check necessary arguments for all models for EACH MODE:
174 | # if mode == TRAIN:
175 | # necessaryFields = TRAIN_NECESSARY_ARGUMENTS
176 | # elif mode in [TEST, VALID]:
177 | # necessaryFields = TEST_NECESSARY_ARGUMENTS
178 | # else:
179 | # raise NotImplementedError('[ ERR ] In function [checkConfig]: unknown mode', mode)
180 | # for x in necessaryFields:
181 | # checkField(opt, x, ARGUMENTS_MISSING_ERRS[x])
182 |
183 | # make sure the model is implemented:
184 | modelname = opt[RUNTIME][MODELNAME]
185 | # assert parse_model_class(modelname)
186 |
187 | # check fields in runtime config is the same as template.
188 | # use `modelname.default.yaml` as template.
189 | runtime_config_dir = SRC_PATH.absolute() / CONFIG_DIR / RUNTIME
190 | template_yml_path = runtime_config_dir / f'{modelname}.default.yaml'
191 | print(f'Check runtime config: use "{template_yml_path}" as template.')
192 | assert template_yml_path.exists()
193 | # for x in load_yml(str(template_yml_path)):
194 | # checkField(opt[RUNTIME], x, f'[ ERR ] Runtime config missing argument: {x}')
195 |
196 | # if type(opt) == omegaconf.DictConfig:
197 | # return omegaconf.OmegaConf.to_container(opt)
198 |
199 | pl_logger = logging.getLogger("lightning")
200 | pl_logger.propagate = False
201 | return opt
202 |
203 |
204 | def omega2dict(opt):
205 | if type(opt) == omegaconf.DictConfig:
206 | return omegaconf.OmegaConf.to_container(opt)
207 | else:
208 | return opt
209 |
210 |
211 | # (Discarded)
212 | def load_yml(ymlpath):
213 | '''
214 | input config file path (yml file), return config dict.
215 | '''
216 | print(f'* Reading config from: {ymlpath}')
217 |
218 | if ymlpath.startswith('http'):
219 | import requests
220 | ymlContent = requests.get(ymlpath).content
221 | else:
222 | ymlContent = open(ymlpath, 'r').read()
223 |
224 | yml = yaml.load(ymlContent, Loader=yaml.FullLoader)
225 | return yml
226 |
227 |
228 | class ImageProcessing(object):
229 |
230 | @staticmethod
231 | def rgb_to_lab(img, is_training=True):
232 | """ PyTorch implementation of RGB to LAB conversion: https://docs.opencv.org/3.3.0/de/d25/imgproc_color_conversions.html
233 | Based roughly on a similar implementation here: https://github.com/affinelayer/pix2pix-tensorflow/blob/master/pix2pix.py
234 | :param img: RGB image tensor with values in [0, 1]
235 | :returns: image converted to a normalised LAB representation
236 | :rtype: torch.Tensor
237 |
238 | """
239 | img = img.permute(2, 1, 0)
240 | shape = img.shape
241 | img = img.contiguous()
242 | img = img.view(-1, 3)
243 |
244 | img = (img / 12.92) * img.le(0.04045).float() + (((torch.clamp(img,
245 | min=0.0001) + 0.055) / 1.055) ** 2.4) * img.gt(
246 | 0.04045).float()
247 |
248 | rgb_to_xyz = Variable(torch.FloatTensor([ # X Y Z
249 | [0.412453, 0.212671,
250 | 0.019334], # R
251 | [0.357580, 0.715160,
252 | 0.119193], # G
253 | [0.180423, 0.072169,
254 | 0.950227], # B
255 | ]), requires_grad=False).type_as(img)
256 |
257 | img = torch.matmul(img, rgb_to_xyz)
258 | img = torch.mul(img, Variable(torch.FloatTensor(
259 | [1 / 0.950456, 1.0, 1 / 1.088754]), requires_grad=False).type_as(img))
260 |
261 | epsilon = 6 / 29
262 |
263 | img = ((img / (3.0 * epsilon ** 2) + 4.0 / 29.0) * img.le(epsilon ** 3).float()) + \
264 | (torch.clamp(img, min=0.0001) ** (1.0 / 3.0) * img.gt(epsilon ** 3).float())
265 |
266 | fxfyfz_to_lab = Variable(torch.FloatTensor([[0.0, 500.0, 0.0], # fx
267 | [116.0, -500.0, 200.0], # fy
268 | [0.0, 0.0, -200.0], # fz
269 | ]), requires_grad=False).type_as(img)
270 |
271 | img = torch.matmul(img, fxfyfz_to_lab) + Variable(
272 | torch.FloatTensor([-16.0, 0.0, 0.0]), requires_grad=False).type_as(img)
273 |
274 | img = img.view(shape)
275 | img = img.permute(2, 1, 0)
276 |
277 | '''
278 | L_chan: black and white with input range [0, 100]
279 | a_chan/b_chan: color channels with input range ~[-110, 110], not exact
280 | [0, 100] => [0, 1], ~[-110, 110] => [0, 1]
281 | '''
282 | img[0, :, :] = img[0, :, :] / 100
283 | img[1, :, :] = (img[1, :, :] / 110 + 1) / 2
284 | img[2, :, :] = (img[2, :, :] / 110 + 1) / 2
285 |
286 | img[(img != img).detach()] = 0
287 |
288 | img = img.contiguous()
289 |
290 | return img
291 |
292 | @staticmethod
293 | def swapimdims_3HW_HW3(img):
294 | """Move the image channels to the first dimension of the numpy
295 | multi-dimensional array
296 |
297 | :param img: numpy nd array representing the image
298 | :returns: numpy nd array with permuted axes
299 | :rtype: numpy nd array
300 |
301 | """
302 | if img.ndim == 3:
303 | return np.swapaxes(np.swapaxes(img, 1, 2), 0, 2)
304 | elif img.ndim == 4:
305 | return np.swapaxes(np.swapaxes(img, 2, 3), 1, 3)
306 |
307 | @staticmethod
308 | def swapimdims_HW3_3HW(img):
309 | """Move the image channels to the last dimensiion of the numpy
310 | multi-dimensional array
311 |
312 | :param img: numpy nd array representing the image
313 | :returns: numpy nd array with permuted axes
314 | :rtype: numpy nd array
315 |
316 | """
317 | if img.ndim == 3:
318 | return np.swapaxes(np.swapaxes(img, 0, 2), 1, 2)
319 | elif img.ndim == 4:
320 | return np.swapaxes(np.swapaxes(img, 1, 3), 2, 3)
321 |
322 | @staticmethod
323 | def load_image(img_filepath, normaliser):
324 | """Loads an image from file as a numpy multi-dimensional array
325 |
326 | :param img_filepath: filepath to the image
327 | :returns: image as a multi-dimensional numpy array
328 | :rtype: multi-dimensional numpy array
329 |
330 | """
331 | img = ImageProcessing.normalise_image(
332 | imread(img_filepath), normaliser) # NB: imread normalises to 0-1
333 | return img
334 |
335 | @staticmethod
336 | def normalise_image(img, normaliser):
337 | """Normalises image data to be a float between 0 and 1
338 |
339 | :param img: Image as a numpy multi-dimensional image array
340 | :returns: Normalised image as a numpy multi-dimensional image array
341 | :rtype: Numpy array
342 |
343 | """
344 | img = img.astype('float32') / normaliser
345 | return img
346 |
347 | @staticmethod
348 | def compute_mse(original, result):
349 | """Computes the mean squared error between to RGB images represented as multi-dimensional numpy arrays.
350 |
351 | :param original: input RGB image as a numpy array
352 | :param result: target RGB image as a numpy array
353 | :returns: the mean squared error between the input and target images
354 | :rtype: float
355 |
356 | """
357 | return ((original - result) ** 2).mean()
358 |
359 | @staticmethod
360 | def compute_psnr(image_batchA, image_batchB, max_intensity):
361 | """Computes the average PSNR for a batch of input and output images
362 | could be used during training / validation
363 |
364 | :param image_batchA: numpy nd-array representing the image batch A of shape Bx3xWxH
365 | :param image_batchB: numpy nd-array representing the image batch B of shape Bx3xWxH
366 | :param max_intensity: maximum intensity possible in the image (e.g. 255)
367 | :returns: average PSNR for the batch of images
368 | :rtype: float
369 |
370 | """
371 | num_images = image_batchA.shape[0]
372 | psnr_val = 0.0
373 |
374 | for i in range(0, num_images):
375 | imageA = image_batchA[i, 0:3, :, :]
376 | imageB = image_batchB[i, 0:3, :, :]
377 | imageB = np.maximum(0, np.minimum(imageB, max_intensity))
378 | psnr_val += 10 * \
379 | np.log10(max_intensity ** 2 /
380 | ImageProcessing.compute_mse(imageA, imageB))
381 |
382 | return psnr_val / num_images
383 |
384 | @staticmethod
385 | def compute_ssim(image_batchA, image_batchB):
386 | """Computes the SSIM for a batch of input and output images
387 |
388 | :param image_batchA: numpy nd-array representing the image batch A of shape Bx3xWxH
389 | :param image_batchB: numpy nd-array representing the image batch B of shape Bx3xWxH
390 | 
391 | :returns: average SSIM for the batch of images
392 | :rtype: float
393 |
394 | """
395 | num_images = image_batchA.shape[0]
396 | ssim_val = 0.0
397 |
398 | for i in range(0, num_images):
399 | imageA = ImageProcessing.swapimdims_3HW_HW3(
400 | image_batchA[i, 0:3, :, :])
401 | imageB = ImageProcessing.swapimdims_3HW_HW3(
402 | image_batchB[i, 0:3, :, :])
403 | ssim_val += calc_ssim(imageA, imageB, data_range=imageA.max() - imageA.min(), multichannel=True,
404 | gaussian_weights=True, win_size=11)
405 |
406 | return ssim_val / num_images
407 |
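A small numeric check of the formula used by ImageProcessing.compute_psnr: with a constant per-pixel error of 0.1 and max_intensity = 1.0, the MSE is 0.01 and the PSNR is 10 * log10(1 / 0.01) = 20 dB. The arrays below are synthetic and only illustrate the arithmetic:

    import numpy as np

    gt = np.zeros((1, 3, 4, 4), dtype=np.float64)     # batch of one 3x4x4 image
    pred = gt + 0.1                                   # constant error of 0.1 per pixel
    mse = ((gt - pred) ** 2).mean()                   # 0.01
    psnr = 10 * np.log10(1.0 ** 2 / mse)              # same formula as compute_psnr, per image
    print(round(psnr, 2))                             # 20.0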
--------------------------------------------------------------------------------