├── README.md
├── config
├── lolv1.yml
├── lolv1_test.json
├── lolv1_train.json
├── lolv2_real.yml
├── lolv2_real_test.json
├── lolv2_real_train.json
├── lolv2_syn.yml
├── lolv2_syn_test.json
├── lolv2_syn_train.json
└── test_unpaired.json
├── core
├── __pycache__
│ ├── logger.cpython-38.pyc
│ └── metrics.cpython-38.pyc
├── logger.py
└── metrics.py
├── data
├── LoL_dataset.py
├── __init__.py
├── __pycache__
│ ├── LoL_dataset.cpython-38.pyc
│ └── __init__.cpython-38.pyc
├── single_image_dataset.py
└── util.py
├── dataset
└── LOLv1
│ └── readme.txt
├── model
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-38.pyc
│ ├── base_model.cpython-38.pyc
│ ├── model.cpython-38.pyc
│ └── networks.cpython-38.pyc
├── base_model.py
├── ddpm_modules
│ ├── __pycache__
│ │ ├── diffusion.cpython-38.pyc
│ │ └── unet.cpython-38.pyc
│ ├── diffusion.py
│ └── unet.py
├── model.py
└── networks.py
├── options
├── __init__.py
├── __pycache__
│ ├── __init__.cpython-38.pyc
│ └── options.cpython-38.pyc
└── options.py
├── requirements.txt
├── test.py
├── test.sh
├── test_unpaired.py
├── train.py
├── train_lol1.sh
├── train_lol2_real.sh
├── train_lol2_syn.sh
└── utils
├── __init__.py
├── __pycache__
├── __init__.cpython-38.pyc
├── ema.cpython-38.pyc
└── util.cpython-38.pyc
├── ema.py
├── niqe.py
├── niqe_image_params.mat
└── util.py
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # **[CVPR2025]** Efficient Diffusion as Low Light Enhancer
6 |
7 |
8 |
9 |
10 |
11 |
12 |
13 |
14 | ## :fire: News
15 |
16 | - [2025/03/04] We have released the training code and inference code! 🚀🚀
17 | - [2025/02/27] ReDDiT has been accepted to CVPR 2025! 🤗🤗
18 |
19 | ## :memo: TODO
20 |
21 | - [x] Training code
22 | - [x] Inference code
23 | - [x] CVPR Camera-ready Version
24 | - [x] Project page
25 | - [ ] Journal Version & Teacher Model
26 |
27 | ## :hammer: Get Started
28 |
29 | ### :mag: Dependencies and Installation
30 |
31 | - Python 3.8
32 | - PyTorch 1.11
33 |
34 | 1. Create Conda Environment
35 |
36 | ```
37 | conda create --name ReDDiT python=3.8
38 | conda activate ReDDiT
39 | ```
40 |
41 | 2. Install PyTorch
42 |
43 | ```
44 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
45 | ```
46 |
47 | 3. Install Dependencies
48 |
49 | ```
50 | cd ReDDiT
51 | pip install -r requirements.txt
52 | ```
53 |
54 | ### :page_with_curl: Data Preparation
55 |
56 | You can refer to the following links to download the datasets.
57 |
58 | - [LOLv1](https://daooshee.github.io/BMVC2018website/)
59 | - [LOLv2](https://github.com/flyywh/CVPR-2020-Semi-Low-Light)
60 |
61 | Then, put them in the following folders (note that the configs expect the LOLv2 root to be named `LOL-v2`):
62 |
64 |
65 | ```
66 | ├── dataset
67 | │   ├── LOLv1
68 | │   │   ├── our485
69 | │   │   │   ├── low
70 | │   │   │   └── high
71 | │   │   └── eval15
72 | │   │       ├── low
73 | │   │       └── high
74 | │   └── LOL-v2
75 | │       ├── Real_captured
76 | │       │   ├── Train
77 | │       │   └── Test
78 | │       └── Synthetic
79 | │           ├── Train
80 | │           └── Test
82 | ```
83 |
84 |
85 |
86 | ### :blue_book: Testing
87 |
88 | Note: Following LLFlow and KinD, we also adjust the brightness of the network output based on the mean value of the ground truth (GT). This adjustment does not affect the generated texture details; it is merely a straightforward way to regulate the overall illumination, and it can easily be tuned to user preference in practical applications.
89 |
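For reference, here is a minimal sketch of this kind of GT-mean alignment (an illustration assuming `pred` and `gt` are arrays in [0, 255]; not the repository's exact implementation):

```
import numpy as np

def align_brightness(pred, gt, eps=1e-8):
    # scale the enhanced output so its mean intensity matches the reference;
    # a single global gain leaves the relative texture details untouched
    gain = gt.mean() / (pred.mean() + eps)
    return np.clip(pred * gain, 0.0, 255.0)
```
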
90 | Download the checkpoints from [Google Drive](https://drive.google.com/file/d/13_XM8nFxJc2IfUotC2_lJo9ATt0rcIyg/view?usp=sharing) or [百度网盘 (Baidu Netdisk)](https://pan.baidu.com/s/1J7MP33Ws5kE673F-8zc2RA?pwd=nj8b) and put them in the following folder:
91 |
92 | ```
93 | ├── checkpoints
94 | ├── lolv1_8step_gen.pth
95 | ├── lolv1_4step_gen.pth
96 | ├── lolv1_2step_gen.pth
97 | ......
98 | ```
99 | To test the model, run ``sh test.sh`` and modify the `n_timestep` and `time_scale` parameters in the config to select a different step model. The `val` blocks for the 8-, 4-, and 2-step models are:
100 | ```
101 | "val": {
102 | "schedule": "linear",
103 | "n_timestep": 8,
104 | "linear_start": 1e-4,
105 | "linear_end": 2e-2,
106 | "time_scale": 64
107 | }
108 | ```
109 |
110 | ```
111 | "val": {
112 | "schedule": "linear",
113 | "n_timestep": 4,
114 | "linear_start": 1e-4,
115 | "linear_end": 2e-2,
116 | "time_scale": 128
117 | }
118 | ```
119 |
120 | ```
121 | "val": {
122 | "schedule": "linear",
123 | "n_timestep": 2,
124 | "linear_start": 1e-4,
125 | "linear_end": 2e-2,
126 | "time_scale": 256
127 | }
128 | ```
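In each case `n_timestep` × `time_scale` equals 512, so every distilled schedule spans the same effective range as the teacher schedule (the training configs use `n_timestep: 513` with `time_scale: 1`). A quick sanity check:

```
for n_timestep, time_scale in [(8, 64), (4, 128), (2, 256)]:
    assert n_timestep * time_scale == 512  # each distilled schedule covers the same range
```
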
129 | ### :blue_book: Testing on unpaired data
130 |
131 | ```
132 | python test_unpaired.py --config config/test_unpaired.json --input unpaired_image_folder
133 | ```
134 |
135 | You can use any of the three pre-trained models and try different sampling steps to obtain visually pleasing results by modifying the corresponding terms in `config/test_unpaired.json`.
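The relevant entries live in the `beta_schedule` blocks of `config/test_unpaired.json`; the shipped default corresponds to the 4-step model:

```
"val": {
    "schedule": "linear",
    "n_timestep": 4,
    "linear_start": 1e-4,
    "linear_end": 2e-2,
    "time_scale": 128
}
```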
136 |
137 |
138 |
139 | ### :rocket: Training
140 |
141 | ```
142 | bash train_lol1.sh   # or: bash train_lol2_real.sh / bash train_lol2_syn.sh
143 | ```
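The three scripts train on LOLv1, LOLv2-Real, and LOLv2-Synthetic, respectively; each presumably wraps `train.py` with the matching config (e.g., `config/lolv1_train.json` for LOLv1).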
144 |
145 |
146 | ## :black_nib: Citation
147 |
148 | If you find our repo useful for your research, please consider citing our paper:
149 |
150 | ```bibtex
151 | @InProceedings{Lan_2025_CVPR,
152 | author = {Lan, Guanzhou and Ma, Qianli and Yang, Yuqi and Wang, Zhigang and Wang, Dong and Li, Xuelong and Zhao, Bin},
153 | title = {Efficient Diffusion as Low Light Enhancer},
154 | booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)},
155 | month = {June},
156 | year = {2025},
157 | pages = {21277-21286}
158 | }
159 | ```
160 |
161 |
162 | ## :heart: Acknowledgement
163 |
164 | Our code is built upon [SR3](https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement). Thanks to the contributors for their great work.
165 |
--------------------------------------------------------------------------------
/config/lolv1.yml:
--------------------------------------------------------------------------------
1 |
2 | dataset: LOLv1
3 |
4 | #### datasets
5 | datasets:
6 | train:
7 | dist: False
8 | root: ./dataset/LOLv1
9 | use_shuffle: true
10 | n_workers: 8
11 | batch_size: 16
12 | use_flip: true
13 | use_crop: true
14 | patch_size: 96
15 |
16 | val:
17 | dist: False
18 | root: ./dataset/LOLv1
19 | n_workers: 1
20 | use_crop: true
21 | batch_size: 1
22 |
23 |
24 |
25 |
--------------------------------------------------------------------------------
/config/lolv1_test.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "lolv1_test_8",
3 | "phase": "test",
4 | "distill": false,
5 | "gpu_ids": [
6 | 0
7 | ],
8 | "path": {
9 | "log": "logs",
10 | "tb_logger": "tb_logger",
11 | "results": "results",
12 | "checkpoint": "checkpoint",
13 | "resume_state": "./checkpoint/lolv1_4step_gen.pth"
14 | },
15 | "model": {
16 | "which_model_G": "ddpm",
17 | "finetune_norm": false,
18 | "unet": {
19 | "in_channel": 6,
20 | "out_channel": 3,
21 | "inner_channel": 64,
22 | "channel_multiplier": [
23 | 1,
24 | 1,
25 | 2,
26 | 2,
27 | 4
28 | ],
29 | "attn_res": [
30 | 16
31 | ],
32 | "res_blocks": 2,
33 | "dropout": 0
34 | },
35 | "beta_schedule": {
36 | "train": {
37 | "schedule": "linear",
38 | "n_timestep": 4,
39 | "linear_start": 1e-4,
40 | "linear_end": 2e-2,
41 | "time_scale": 128
44 | },
45 | "val": {
46 | "schedule": "linear",
47 | "n_timestep": 4,
48 | "linear_start": 1e-4,
49 | "linear_end": 2e-2,
50 | "time_scale": 128
53 | }
54 | },
55 | "diffusion": {
56 | "image_size": 128,
57 | "channels": 6,
58 | "conditional": true,
59 | "w_gt": 0.1,
60 | "w_snr": 0.5,
61 | "w_str": 0.1,
62 | "w_lpips": 0.2
63 | }
64 | },
65 | "train": {
66 | "n_iter": 1000000,
67 | "val_freq": 1e4,
68 | "save_checkpoint_freq": 5e4,
69 | "print_freq": 200,
70 | "optimizer": {
71 | "type": "adam",
72 | "lr": 1e-4,
73 | "lr_policy":"linear",
74 | "lr_decay_iters":3000,
75 | "n_lr_iters": 2000
76 | },
77 | "ema_scheduler": {
78 | "step_start_ema": 5000,
79 | "update_ema_every": 1,
80 | "ema_decay": 0.9999
81 | }
82 | },
83 | "wandb": {
84 | "project": "llie_ddpm"
85 | }
86 | }
--------------------------------------------------------------------------------
/config/lolv1_train.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "lolv1_train",
3 | "phase": "train",
4 | "distill": true,
5 | "gpu_ids": [
6 | 0
7 | ],
8 | "path": {
9 | "log": "logs",
10 | "tb_logger": "tb_logger",
11 | "results": "results",
12 | "checkpoint": "checkpoint",
13 | "resume_state": "./checkpoint/lolv1_4step_gen.pth"
14 | },
15 | "model": {
16 | "which_model_G": "ddpm",
17 | "finetune_norm": false,
18 | "unet": {
19 | "in_channel": 6,
20 | "out_channel": 3,
21 | "inner_channel": 64,
22 | "channel_multiplier": [
23 | 1,
24 | 1,
25 | 2,
26 | 2,
27 | 4
28 | ],
29 | "attn_res": [
30 | 16
31 | ],
32 | "res_blocks": 2,
33 | "dropout": 0
34 | },
35 | "beta_schedule": {
36 | "train": {
37 | "schedule": "linear",
38 | "n_timestep": 513,
39 | "linear_start": 1e-4,
40 | "linear_end": 2e-2,
41 | "time_scale": 1,
42 | "reflow": false
43 | },
44 | "val": {
45 | "schedule": "linear",
46 | "n_timestep": 513,
47 | "linear_start": 1e-4,
48 | "linear_end": 2e-2
50 | }
51 | },
52 | "diffusion": {
53 | "image_size": 128,
54 | "channels": 6,
55 | "conditional": true,
56 | "w_gt": 0.1,
57 | "w_snr": 0.5,
58 | "w_str": 0.1,
59 | "w_lpips": 0.2
60 | }
61 | },
62 | "train": {
63 | "n_iter": 5000,
64 | "val_freq": 100,
65 | "save_checkpoint_freq": 100,
66 | "print_freq": 200,
67 | "optimizer": {
68 | "type": "adam",
69 | "lr": 1e-4,
70 | "lr_policy":"linear",
71 | "lr_decay_iters":3000,
72 | "n_lr_iters": 2000
73 | },
74 | "ema_scheduler": {
75 | "step_start_ema": 5000,
76 | "update_ema_every": 1,
77 | "ema_decay": 0.9999
78 | }
79 | },
80 | "wandb": {
81 | "project": "llie_ddpm"
82 | }
83 | }
84 |
--------------------------------------------------------------------------------
/config/lolv2_real.yml:
--------------------------------------------------------------------------------
1 |
2 | dataset: LOLv2
3 |
4 | #### datasets
5 | datasets:
6 | train:
7 | dist: False
8 | root: ./dataset/LOL-v2
9 | use_shuffle: true
10 | n_workers: 8
11 | batch_size: 16
12 | use_flip: true
13 | use_crop: true
14 | patch_size: 96
15 | sub_data: Real_captured
16 |
17 | val:
18 | dist: False
19 | root: ./dataset/LOL-v2
20 | n_workers: 1
21 | batch_size: 1
22 | sub_data: Real_captured
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/config/lolv2_real_test.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "lolv2_test_real",
3 | "phase": "test",
4 | "gpu_ids": [
5 | 0
6 | ],
7 |
8 | "path": {
9 | "log": "logs",
10 | "tb_logger": "tb_logger",
11 | "results": "results",
12 | "checkpoint": "checkpoint",
13 | "resume_state": "./checkpoint/lolv2_real_4step_gen.pth"
14 | },
15 | "freq_aware": false,
16 | "freq_awareUNet": {
17 | "b1": 1.6,
18 | "b2": 1.6,
19 | "s1": 0.9,
20 | "s2": 0.9
21 | },
22 | "model": {
23 | "which_model_G": "ddpm",
24 | "finetune_norm": false,
25 | "unet": {
26 | "in_channel": 6,
27 | "out_channel": 3,
28 | "inner_channel": 64,
29 | "channel_multiplier": [
30 | 1,
31 | 1,
32 | 2,
33 | 2,
34 | 4
35 | ],
36 | "attn_res": [
37 | 16
38 | ],
39 | "res_blocks": 2,
40 | "dropout": 0
41 | },
42 | "beta_schedule": {
43 | "train": {
44 | "schedule": "linear",
45 | "n_timestep": 4,
46 | "linear_start": 1e-4,
47 | "linear_end": 2e-2,
48 | "time_scale": 128
49 | },
50 | "val": {
51 | "schedule": "linear",
52 | "n_timestep": 4,
53 | "linear_start": 1e-4,
54 | "linear_end": 2e-2,
55 | "time_scale": 128
56 | }
57 | },
58 | "diffusion": {
59 | "image_size": 128,
60 | "channels": 6,
61 | "conditional": true
62 | }
63 | },
64 | "train": {
65 | "n_iter": 1000000,
66 | "val_freq": 1e4,
67 | "save_checkpoint_freq": 5e4,
68 | "print_freq": 200,
69 | "optimizer": {
70 | "type": "adam",
71 | "lr": 1e-4
72 | },
73 | "ema_scheduler": {
74 | "step_start_ema": 5000,
75 | "update_ema_every": 1,
76 | "ema_decay": 0.9999
77 | }
78 | },
79 | "wandb": {
80 | "project": "llie_ddpm"
81 | }
82 | }
--------------------------------------------------------------------------------
/config/lolv2_real_train.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "lolv2_train_real",
3 | "phase": "train",
4 | "distill": true,
5 | "CD":false,
6 | "gpu_ids": [
7 | 0
8 | ],
9 |
10 | "path": {
11 | "log": "logs",
12 | "tb_logger": "tb_logger",
13 | "results": "results",
14 | "checkpoint": "checkpoints",
15 | "resume_state": "./checkpoint/lolv2_real_4step_gen.pth"
16 | },
17 | "model": {
18 | "which_model_G": "ddpm",
19 | "finetune_norm": false,
20 | "unet": {
21 | "in_channel": 6,
22 | "out_channel": 3,
23 | "inner_channel": 64,
24 | "channel_multiplier": [
25 | 1,
26 | 1,
27 | 2,
28 | 2,
29 | 4
30 | ],
31 | "attn_res": [
32 | 16
33 | ],
34 | "res_blocks": 2,
35 | "dropout": 0
36 | },
37 | "beta_schedule": {
38 | "train": {
39 | "schedule": "linear",
40 | "n_timestep": 513,
41 | "linear_start": 1e-4,
42 | "linear_end": 2e-2,
43 | "reflow": false,
44 | "time_scale": 1
45 | },
46 | "val": {
47 | "schedule": "linear",
49 | "n_timestep": 513,
50 | "linear_start": 1e-4,
51 | "linear_end": 2e-2,
52 | "time_scale": 1
53 | }
54 | },
55 | "diffusion": {
56 | "image_size": 128,
57 | "channels": 6,
58 | "conditional": true,
59 | "w_gt": 0.1,
60 | "w_snr": 0.5,
61 | "w_str": 0.1,
62 | "w_lpips": 0.2
64 | }
65 | },
66 | "train": {
67 | "n_iter": 5000,
68 | "val_freq": 100,
69 | "save_checkpoint_freq": 100,
70 | "print_freq": 100,
71 | "optimizer": {
72 | "type": "adam",
73 | "lr": 1e-4,
74 | "lr_policy":"linear",
75 | "lr_decay_iters":3000,
76 | "n_lr_iters": 2000
77 | },
78 | "ema_scheduler": {
79 | "step_start_ema": 5000,
80 | "update_ema_every": 1,
81 | "ema_decay": 0.9999
82 | }
83 | },
84 | "wandb": {
85 | "project": "llie_ddpm"
86 | }
87 | }
--------------------------------------------------------------------------------
/config/lolv2_syn.yml:
--------------------------------------------------------------------------------
1 |
2 | dataset: LOLv2
3 |
4 | #### datasets
5 | datasets:
6 | train:
7 | dist: False
8 | root: ./dataset/LOL-v2
9 | use_shuffle: true
10 | n_workers: 8
11 | batch_size: 16
12 | use_flip: true
13 | use_crop: true
14 | patch_size: 96
15 | sub_data: Synthetic
16 |
17 | val:
18 | dist: False
19 | root: ./dataset/LOL-v2
20 | n_workers: 1
21 | batch_size: 1
22 | sub_data: Synthetic
23 |
24 |
25 |
26 |
--------------------------------------------------------------------------------
/config/lolv2_syn_test.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "lolv2_test_syn",
3 | "phase": "test",
4 | "gpu_ids": [
5 | 0
6 | ],
7 | "path": {
8 | "log": "logs",
9 | "tb_logger": "tb_logger",
10 | "results": "results",
11 | "checkpoint": "checkpoint",
12 | "resume_state": "./checkpoint/lolv2_syn_4step_gen.pth"
13 | },
14 | "model": {
15 | "which_model_G": "ddpm",
16 | "finetune_norm": false,
17 | "unet": {
18 | "in_channel": 6,
19 | "out_channel": 3,
20 | "inner_channel": 64,
21 | "channel_multiplier": [
22 | 1,
23 | 1,
24 | 2,
25 | 2,
26 | 4
27 | ],
28 | "attn_res": [
29 | 16
30 | ],
31 | "res_blocks": 2,
32 | "dropout": 0
33 | },
34 | "beta_schedule": {
35 | "train": {
36 | "schedule": "linear",
37 | "n_timestep": 4,
38 | "linear_start": 1e-4,
39 | "linear_end": 2e-2,
40 | "time_scale": 128
41 | },
42 | "val": {
43 | "schedule": "linear",
44 | "n_timestep": 4,
45 | "linear_start": 1e-4,
46 | "linear_end": 2e-2,
47 | "time_scale": 128
48 | }
49 | },
50 | "diffusion": {
51 | "image_size": 128,
52 | "channels": 6,
53 | "conditional": true
54 | }
55 | },
56 | "train": {
57 | "n_iter": 1000000,
58 | "val_freq": 1e4,
59 | "save_checkpoint_freq": 5e4,
60 | "print_freq": 200,
61 | "optimizer": {
62 | "type": "adam",
63 | "lr": 1e-4
64 | },
65 | "ema_scheduler": {
66 | "step_start_ema": 5000,
67 | "update_ema_every": 1,
68 | "ema_decay": 0.9999
69 | }
70 | },
71 | "wandb": {
72 | "project": "llie_ddpm"
73 | }
74 | }
--------------------------------------------------------------------------------
/config/lolv2_syn_train.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "lolv2_train_syn",
3 | "phase": "train",
4 | "distill": true,
5 | "gpu_ids": [
6 | 0
7 | ],
8 | "path": {
9 | "log": "logs",
10 | "tb_logger": "tb_logger",
11 | "results": "results",
12 | "checkpoint": "checkpoint",
13 | "resume_state": "./checkpoint/lolv2_syn_4step_gen.pth"
14 | },
15 | "model": {
16 | "which_model_G": "ddpm",
17 | "finetune_norm": false,
18 | "unet": {
19 | "in_channel": 6,
20 | "out_channel": 3,
21 | "inner_channel": 64,
22 | "channel_multiplier": [
23 | 1,
24 | 1,
25 | 2,
26 | 2,
27 | 4
28 | ],
29 | "attn_res": [
30 | 16
31 | ],
32 | "res_blocks": 2,
33 | "dropout": 0
34 | },
35 | "beta_schedule": {
36 | "train": {
37 | "schedule": "linear",
38 | "n_timestep": 513,
39 | "linear_start": 1e-4,
40 | "linear_end": 2e-2,
41 | "reflow": false,
42 | "time_scale": 1
43 | },
44 | "val": {
45 | "schedule": "linear",
46 | "n_timestep": 513,
47 | "linear_start": 1e-4,
48 | "linear_end": 2e-2
49 | }
50 | },
51 | "diffusion": {
52 | "image_size": 128,
53 | "channels": 6,
54 | "conditional": true,
55 | "w_gt": 0.1,
56 | "w_snr": 0.5,
57 | "w_str": 0.1,
58 | "w_lpips": 0.2
59 | }
60 | },
61 | "train": {
62 |         "n_iter": 5000,
63 | "val_freq": 100,
64 | "save_checkpoint_freq": 100,
65 | "print_freq": 100,
66 | "optimizer": {
67 | "type": "adam",
68 | "lr": 1e-4,
69 | "lr_policy":"linear",
70 | "lr_decay_iters":3000,
71 | "n_lr_iters": 2000
72 | },
73 | "ema_scheduler": {
74 | "step_start_ema": 5000,
75 | "update_ema_every": 1,
76 | "ema_decay": 0.9999
77 | }
78 | },
79 | "wandb": {
80 | "project": "llie_ddpm"
81 | }
82 | }
--------------------------------------------------------------------------------
/config/test_unpaired.json:
--------------------------------------------------------------------------------
1 | {
2 | "name": "test_unpaired",
3 | "phase": "test",
4 | "gpu_ids": [
5 | 0
6 | ],
7 | "path": {
8 | "log": "logs",
9 | "tb_logger": "tb_logger",
10 | "results": "results",
11 | "checkpoint": "checkpoint",
12 | "resume_state": "experiments/lolv2_train_syn/lolv2_train_syn_w_snr:0.2_w_str:0.0_w_gt:1.0_w_lpips:0.6_240410_125713/checkpoint/num_step_8/psnr30.1459_ssim0.9424_lpips0.0284_I2700_E47_gen_ema.pth"
13 | },
14 | "freq_aware": false,
15 | "freq_awareUNet": {
16 | "b1": 1.6,
17 | "b2": 1.6,
18 | "s1": 0.9,
19 | "s2": 0.9
20 | },
21 | "model": {
22 | "which_model_G": "ddpm",
23 | "finetune_norm": false,
24 | "unet": {
25 | "in_channel": 6,
26 | "out_channel": 3,
27 | "inner_channel": 64,
28 | "channel_multiplier": [
29 | 1,
30 | 1,
31 | 2,
32 | 2,
33 | 4
34 | ],
35 | "attn_res": [
36 | 16
37 | ],
38 | "res_blocks": 2,
39 | "dropout": 0
40 | },
41 | "beta_schedule": {
42 | "train": {
43 | "schedule": "linear",
44 | "n_timestep": 4,
45 | "linear_start": 1e-4,
46 | "linear_end": 2e-2,
47 | "time_scale": 128
48 | },
49 | "val": {
50 | "schedule": "linear",
51 | "n_timestep": 4,
52 | "linear_start": 1e-4,
53 | "linear_end": 2e-2,
54 | "time_scale": 128
55 | }
56 | },
57 | "diffusion": {
58 | "image_size": 128,
59 | "channels": 6,
60 | "conditional": true
61 | }
62 | },
63 | "train": {
64 | "n_iter": 1000000,
65 | "val_freq": 1e4,
66 | "save_checkpoint_freq": 5e4,
67 | "print_freq": 200,
68 | "optimizer": {
69 | "type": "adam",
70 | "lr": 1e-4
71 | },
72 | "ema_scheduler": {
73 | "step_start_ema": 5000,
74 | "update_ema_every": 1,
75 | "ema_decay": 0.9999
76 | }
77 | },
78 | "wandb": {
79 | "project": "llie_ddpm"
80 | }
81 | }
--------------------------------------------------------------------------------
/core/__pycache__/logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/core/__pycache__/logger.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/core/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/core/logger.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | from collections import OrderedDict
5 | import json
6 | from datetime import datetime
7 |
8 |
9 | def mkdirs(paths):
10 | if isinstance(paths, str):
11 | os.makedirs(paths, exist_ok=True)
12 | else:
13 | for path in paths:
14 | os.makedirs(path, exist_ok=True)
15 |
16 |
17 | def get_timestamp():
18 | return datetime.now().strftime('%y%m%d_%H%M%S')
19 |
20 |
21 | def parse(args):
22 | phase = args.phase
23 | opt_path = args.config
24 | gpu_ids = args.gpu_ids
25 | # remove comments starting with '//'
26 | json_str = ''
27 | with open(opt_path, 'r') as f:
28 | for line in f:
29 | line = line.split('//')[0] + '\n'
30 | json_str += line
31 | opt = json.loads(json_str, object_pairs_hook=OrderedDict)
32 |
33 | # set log directory
34 | if args.debug:
35 | opt['name'] = 'debug_{}'.format(opt['name'])
36 | if args.brutal_search:
37 | experiments_root = os.path.join(
38 | 'experiments', opt['name'], '{}_noise_start:{}_noise_end:{}_{}'.format(opt['name'], args.noise_start, args.noise_end, get_timestamp()))
39 | else:
40 | if opt['phase'] == 'train':
41 | experiments_root = os.path.join(
42 | 'experiments', opt['name'], '{}_w_snr:{}_w_str:{}_w_gt:{}_w_lpips:{}_{}'.format(opt['name'], args.w_snr, args.w_str, args.w_gt, args.w_lpips, get_timestamp()))
43 | elif opt['phase'] == 'test':
44 | experiments_root = os.path.join(
45 | 'experiments', opt['name'], '{}_numstep:{}_w_snr:{}_w_gt:{}_w_lpips:{}_{}'.format(opt['name'],opt["model"]['beta_schedule']['val']['n_timestep'], args.w_snr, args.w_gt, args.w_lpips, get_timestamp()))
46 | opt['path']['experiments_root'] = experiments_root
47 | for key, path in opt['path'].items():
48 | if 'resume' not in key and 'experiments' not in key:
49 | opt['path'][key] = os.path.join(experiments_root, path)
50 | mkdirs(opt['path'][key])
51 |
52 | # change dataset length limit
53 | opt['phase'] = phase
54 |
55 | # export CUDA_VISIBLE_DEVICES
56 | if gpu_ids is not None:
57 | opt['gpu_ids'] = [int(id) for id in gpu_ids.split(',')]
58 | gpu_list = gpu_ids
59 | else:
60 | gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
61 |
62 | if len(gpu_list) > 1:
63 | opt['distributed'] = True
64 | else:
65 | opt['distributed'] = False
66 |
67 | # debug
68 | if 'debug' in opt['name']:
69 | opt['train']['val_freq'] = 2
70 | opt['train']['print_freq'] = 2
71 | opt['train']['save_checkpoint_freq'] = 3
72 | opt['datasets']['train']['batch_size'] = 2
73 | opt['model']['beta_schedule']['train']['n_timestep'] = 10
74 | opt['model']['beta_schedule']['val']['n_timestep'] = 10
75 | opt['datasets']['train']['data_len'] = 6
76 | opt['datasets']['val']['data_len'] = 3
77 |
78 | # validation in train phase
79 | # if phase == 'train':
80 | # opt['datasets']['val']['data_len'] = 3
81 |
82 | # W&B Logging
83 | try:
84 | log_wandb_ckpt = args.log_wandb_ckpt
85 | opt['log_wandb_ckpt'] = log_wandb_ckpt
86 |     except AttributeError:
87 | pass
88 | try:
89 | log_eval = args.log_eval
90 | opt['log_eval'] = log_eval
91 |     except AttributeError:
92 | pass
93 | try:
94 | log_infer = args.log_infer
95 | opt['log_infer'] = log_infer
96 |     except AttributeError:
97 | pass
98 |
99 | return opt
100 |
101 |
102 | class NoneDict(dict):
103 | def __missing__(self, key):
104 | return None
105 |
106 |
107 | # convert to NoneDict, which returns None for missing keys.
108 | def dict_to_nonedict(opt):
109 | if isinstance(opt, dict):
110 | new_opt = dict()
111 | for key, sub_opt in opt.items():
112 | new_opt[key] = dict_to_nonedict(sub_opt)
113 | return NoneDict(**new_opt)
114 | elif isinstance(opt, list):
115 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
116 | else:
117 | return opt
118 |
119 |
120 | def dict2str(opt, indent_l=1):
121 | '''dict to string for logger'''
122 | msg = ''
123 | for k, v in opt.items():
124 | if isinstance(v, dict):
125 | msg += ' ' * (indent_l * 2) + k + ':[\n'
126 | msg += dict2str(v, indent_l + 1)
127 | msg += ' ' * (indent_l * 2) + ']\n'
128 | else:
129 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
130 | return msg
131 |
132 |
133 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False):
134 | '''set up logger'''
135 | l = logging.getLogger(logger_name)
136 | formatter = logging.Formatter(
137 | '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S')
138 | log_file = os.path.join(root, '{}.log'.format(phase))
139 | fh = logging.FileHandler(log_file, mode='w')
140 | fh.setFormatter(formatter)
141 | l.setLevel(level)
142 | l.addHandler(fh)
143 | if screen:
144 | sh = logging.StreamHandler()
145 | sh.setFormatter(formatter)
146 | l.addHandler(sh)
147 |
--------------------------------------------------------------------------------
/core/metrics.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 | import cv2
5 | from torchvision.utils import make_grid
6 |
7 |
8 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
9 | '''
10 | Converts a torch Tensor into an image Numpy array
11 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
12 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
13 | '''
14 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
15 | tensor = (tensor - min_max[0]) / \
16 | (min_max[1] - min_max[0]) # to range [0,1]
17 | n_dim = tensor.dim()
18 | if n_dim == 4:
19 | n_img = len(tensor)
20 | # img_np = make_grid(tensor, nrow=int(
21 | # math.sqrt(n_img)), normalize=False).numpy()
22 | # img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
23 |
24 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
25 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
26 | elif n_dim == 3:
27 | img_np = tensor.numpy()
28 | # img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
29 |
30 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
31 | elif n_dim == 2:
32 | img_np = tensor.numpy()
33 | else:
34 | raise TypeError(
35 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
36 | if out_type == np.uint8:
37 | img_np = np.clip((img_np * 255.0).round(), 0, 255)
38 | # img_np = (img_np * 255.0).round()
39 |         # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
40 | return img_np.astype(out_type)
41 |
42 | def tensor2img2(tensor, out_type=np.uint8, min_max=(-1, 1)):
43 | '''
44 | Converts a torch Tensor into an image Numpy array
45 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
46 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
47 | '''
48 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
49 |
50 | n_dim = tensor.dim()
51 | if n_dim == 4:
52 | n_img = len(tensor)
53 | # img_np = make_grid(tensor, nrow=int(
54 | # math.sqrt(n_img)), normalize=False).numpy()
55 | # img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
56 |
57 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
58 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
59 | elif n_dim == 3:
60 | img_np = tensor.numpy()
61 | # img_np = np.transpose(img_np, (1, 2, 0)) # HWC, RGB
62 |
63 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
64 | elif n_dim == 2:
65 | img_np = tensor.numpy()
66 | else:
67 | raise TypeError(
68 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
69 | if out_type == np.uint8:
70 | img_np = np.clip((img_np * 255.0).round(), 0, 255)
71 | # img_np = (img_np * 255.0).round()
72 |         # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
73 | return img_np.astype(out_type)
74 |
75 | def save_img(img, img_path, mode='RGB'):
76 | cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
77 | # cv2.imwrite(img_path, img)
78 |
79 |
80 | def calculate_psnr(img1, img2):
81 | # img1 and img2 have range [0, 255]
82 | img1 = img1.astype(np.float64)
83 | img2 = img2.astype(np.float64)
84 | mse = np.mean((img1 - img2)**2)
85 | if mse == 0:
86 | return float('inf')
87 | return 20 * math.log10(255.0 / math.sqrt(mse))
88 |
89 |
90 | def ssim(img1, img2):
91 | C1 = (0.01 * 255)**2
92 | C2 = (0.03 * 255)**2
93 |
94 | img1 = img1.astype(np.float64)
95 | img2 = img2.astype(np.float64)
96 | kernel = cv2.getGaussianKernel(11, 1.5)
97 | window = np.outer(kernel, kernel.transpose())
98 |
99 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
100 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
101 | mu1_sq = mu1**2
102 | mu2_sq = mu2**2
103 | mu1_mu2 = mu1 * mu2
104 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
105 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
106 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
107 |
108 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
109 | (sigma1_sq + sigma2_sq + C2))
110 | return ssim_map.mean()
111 |
112 |
113 | def calculate_ssim(img1, img2):
114 | '''calculate SSIM
115 | the same outputs as MATLAB's
116 | img1, img2: [0, 255]
117 | '''
118 | if not img1.shape == img2.shape:
119 | raise ValueError('Input images must have the same dimensions.')
120 | if img1.ndim == 2:
121 | return ssim(img1, img2)
122 | elif img1.ndim == 3:
123 | if img1.shape[2] == 3:
124 | ssims = []
125 |             for i in range(3):
126 |                 ssims.append(ssim(img1[..., i], img2[..., i]))
127 | return np.array(ssims).mean()
128 | elif img1.shape[2] == 1:
129 | return ssim(np.squeeze(img1), np.squeeze(img2))
130 | else:
131 | raise ValueError('Wrong input image dimensions.')
132 |
--------------------------------------------------------------------------------
/data/LoL_dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.data as data
3 | import numpy as np
4 | import torch
5 | import cv2
6 | from torchvision.transforms import ToTensor
7 | import torchvision.transforms as T
8 | import torchvision
9 |
10 |
11 | class LOLv1_Dataset(data.Dataset):
12 | def __init__(self, opt, train, all_opt):
13 | self.root = opt["root"]
14 | self.opt = opt
15 | self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False
16 | self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False
17 | self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False
18 | self.crop_size = opt.get("patch_size", None)
19 | if train:
20 | self.split = 'train'
21 | self.root = os.path.join(self.root, 'our485')
22 | else:
23 | self.split = 'val'
24 | self.root = os.path.join(self.root, 'eval15')
25 | self.pairs = self.load_pairs(self.root)
26 | self.to_tensor = ToTensor()
27 |
28 | def __len__(self):
29 | return len(self.pairs)
30 |
31 | def load_pairs(self, folder_path):
32 |
33 | low_list = os.listdir(os.path.join(folder_path, 'low'))
34 | low_list = filter(lambda x: 'png' in x, low_list)
35 |
36 | pairs = []
37 | for idx, f_name in enumerate(low_list):
38 |
39 |             # train and val splits load pairs identically
40 |             pairs.append(
41 |                 [cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'low', f_name)), cv2.COLOR_BGR2RGB),
42 |                  cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'high', f_name)), cv2.COLOR_BGR2RGB),
43 |                  f_name.split('.')[0]])
49 | return pairs
50 |
51 |     def get_max(self, input):
52 |         T, _ = torch.max(input, dim=0)
53 |         T = T + 0.1
54 |         input[0, :, :] = input[0, :, :] / T
55 |         input[1, :, :] = input[1, :, :] / T
56 |         input[2, :, :] = input[2, :, :] / T
57 |         return input
58 |
59 | def __getitem__(self, item):
60 | lr, hr, f_name = self.pairs[item]
61 |
62 |
63 | if self.use_crop and self.split != 'val':
64 | hr, lr = random_crop(hr, lr, self.crop_size)
65 | elif self.split == 'val':
66 | lr = cv2.copyMakeBorder(lr, 8,8,4,4,cv2.BORDER_REFLECT)
67 |
68 | if self.use_flip:
69 | hr, lr = random_flip(hr, lr)
70 |
71 | if self.use_rot:
72 | hr, lr = random_rotation(hr, lr)
73 |
74 | hr = self.to_tensor(hr)
75 | lr = self.to_tensor(lr)
76 | # lr = self.get_max(lr)
77 |
78 | [lr, hr] = transform_augment(
79 | [lr, hr], split=self.split, min_max=(-1, 1))
80 |
81 | return {'LQ': lr, 'GT': hr, 'LQ_path': f_name, 'GT_path': f_name}
82 |
83 | class LOLv2_Dataset(data.Dataset):
84 | def __init__(self, opt, train, all_opt):
85 | self.root = opt["root"]
86 | self.opt = opt
87 | self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False
88 | self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False
89 | self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False
90 | self.crop_size = opt.get("patch_size", None)
91 | self.sub_data = opt.get("sub_data", None)
92 | self.pairs = []
93 | self.train = train
94 | if train:
95 | self.split = 'train'
96 | root = os.path.join(self.root, self.sub_data, 'Train')
97 | else:
98 | self.split = 'val'
99 | root = os.path.join(self.root, self.sub_data, 'Test')
100 | self.pairs.extend(self.load_pairs(root))
101 | self.to_tensor = ToTensor()
102 | self.gamma_aug = opt['gamma_aug'] if 'gamma_aug' in opt.keys() else False
103 |
104 | def __len__(self):
105 | return len(self.pairs)
106 |
107 |     def get_max(self, input):
108 |         T, _ = torch.max(input, dim=0)
109 |         T = T + 0.1
110 |         input[0, :, :] = input[0, :, :] / T
111 |         input[1, :, :] = input[1, :, :] / T
112 |         input[2, :, :] = input[2, :, :] / T
113 |         return input
114 |
115 | def load_pairs(self, folder_path):
116 |
117 |         low_list = os.listdir(os.path.join(folder_path, 'Low'))
118 | low_list = sorted(list(filter(lambda x: 'png' in x, low_list)))
119 |         high_list = os.listdir(os.path.join(folder_path, 'Normal'))
120 | high_list = sorted(list(filter(lambda x: 'png' in x, high_list)))
121 | pairs = []
122 |
123 | for idx in range(len(low_list)):
124 | f_name_low = low_list[idx]
125 | f_name_high = high_list[idx]
126 |             pairs.append(
127 |                 [cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'Low', f_name_low)),
128 |                               cv2.COLOR_BGR2RGB),
129 |                  cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'Normal', f_name_high)),
130 |                               cv2.COLOR_BGR2RGB),
131 |                  f_name_high.split('.')[0]])
132 | return pairs
133 |
134 | def __getitem__(self, item):
135 |
136 | lr, hr, f_name = self.pairs[item]
137 |
138 | if self.use_crop and self.split != 'val':
139 | hr, lr = random_crop(hr, lr, self.crop_size)
140 | elif self.sub_data == 'Real_captured' and self.split == 'val': # for Real_captured
141 | lr = cv2.copyMakeBorder(lr, 8,8,4,4,cv2.BORDER_REFLECT)
142 |
143 | if self.use_flip:
144 | hr, lr = random_flip(hr, lr)
145 |
146 | if self.use_rot:
147 | hr, lr = random_rotation(hr, lr)
148 |
149 |
150 | hr = self.to_tensor(hr)
151 | lr = self.to_tensor(lr)
152 | # lr = self.get_max(lr)
153 |
154 |
155 | [lr, hr] = transform_augment(
156 | [lr, hr], split=self.split, min_max=(-1, 1))
157 |
158 | return {'LQ': lr, 'GT': hr, 'LQ_path': f_name, 'GT_path': f_name}
159 |
160 |
161 | def random_flip(img, seg):
162 | random_choice = np.random.choice([True, False])
163 | img = img if random_choice else np.flip(img, 1).copy()
164 | seg = seg if random_choice else np.flip(seg, 1).copy()
165 |
166 | return img, seg
167 |
168 |
169 | def gamma_aug(img, gamma=0):
170 | max_val = img.max()
171 | img_after_norm = img / max_val
172 | img_after_norm = np.power(img_after_norm, gamma)
173 | return img_after_norm * max_val
174 |
175 |
176 | def random_rotation(img, seg):
177 | random_choice = np.random.choice([0, 1, 3])
178 | img = np.rot90(img, random_choice, axes=(0, 1)).copy()
179 | seg = np.rot90(seg, random_choice, axes=(0, 1)).copy()
180 |
181 | return img, seg
182 |
183 |
184 | def random_crop(hr, lr, size_hr):
185 | size_lr = size_hr
186 |
187 | size_lr_x = lr.shape[0]
188 | size_lr_y = lr.shape[1]
189 |
190 | start_x_lr = np.random.randint(low=0, high=(size_lr_x - size_lr) + 1) if size_lr_x > size_lr else 0
191 | start_y_lr = np.random.randint(low=0, high=(size_lr_y - size_lr) + 1) if size_lr_y > size_lr else 0
192 |
193 | # LR Patch
194 | lr_patch = lr[start_x_lr:start_x_lr + size_lr, start_y_lr:start_y_lr + size_lr, :]
195 |
196 | # HR Patch
197 | start_x_hr = start_x_lr
198 | start_y_hr = start_y_lr
199 | hr_patch = hr[start_x_hr:start_x_hr + size_hr, start_y_hr:start_y_hr + size_hr, :]
200 |
201 |     return hr_patch, lr_patch
204 |
205 |
206 | # implementation by torchvision, detail in https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement/issues/14
207 | totensor = torchvision.transforms.ToTensor()
208 | hflip = torchvision.transforms.RandomHorizontalFlip()
209 | def transform_augment(imgs, split='val', min_max=(0, 1)):
210 | # imgs = [totensor(img) for img in img_list]
211 | if split == 'train':
212 | imgs = torch.stack(imgs, 0)
213 | # imgs = hflip(imgs)
214 | imgs = torch.unbind(imgs, dim=0)
215 | ret_img = [img * (min_max[1] - min_max[0]) + min_max[0] for img in imgs]
216 | return ret_img
217 |
218 |
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
1 | '''create dataset and dataloader'''
2 | import logging
3 | import torch
4 | import torch.utils.data
5 |
6 |
7 |
8 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
9 | phase = dataset_opt['phase']
10 | if phase == 'train':
11 | if dataset_opt['dist']:
12 | world_size = torch.distributed.get_world_size()
13 | num_workers = dataset_opt['n_workers']
14 | assert dataset_opt['batch_size'] % world_size == 0
15 | batch_size = dataset_opt['batch_size'] // world_size
16 | shuffle = False
17 |             sampler = torch.utils.data.distributed.DistributedSampler(dataset)
18 | else:
19 | num_workers = dataset_opt['n_workers'] # * len(opt['gpu_ids'])
20 | batch_size = dataset_opt['batch_size']
21 | shuffle = True
22 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
23 | num_workers=num_workers, sampler=sampler, drop_last=True,
24 | pin_memory=True)
25 | else:
26 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
27 | pin_memory=True)
28 |
29 |
30 | # def create_dataloader(train, dataset, dataset_opt, opt=None, sampler=None):
31 | # # gpu_ids = opt.get('gpu_ids', None)
32 | # gpu_ids = []
33 | # gpu_ids = gpu_ids if gpu_ids else []
34 | # num_workers = dataset_opt['n_workers'] * (len(gpu_ids)+1)
35 | # batch_size = dataset_opt['batch_size']
36 | # shuffle = True
37 | # if train:
38 | # return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
39 | # num_workers=num_workers, sampler=sampler, drop_last=True,
40 | # pin_memory=False)
41 | # else:
42 | # return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
43 | # num_workers=num_workers, sampler=sampler, drop_last=False,
44 | # pin_memory=False)
45 |
46 |
--------------------------------------------------------------------------------
/data/__pycache__/LoL_dataset.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/data/__pycache__/LoL_dataset.cpython-38.pyc
--------------------------------------------------------------------------------
/data/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/data/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/data/single_image_dataset.py:
--------------------------------------------------------------------------------
1 | from os import path as osp
2 | from torch.utils import data as data
3 | from torchvision.transforms.functional import normalize
4 |
5 | from basicsr.data.data_util import paths_from_lmdb
6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
7 |
8 |
9 | class SingleImageDataset(data.Dataset):
10 | """Read only lq images in the test phase.
11 |
12 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).
13 |
14 | There are two modes:
15 | 1. 'meta_info_file': Use meta information file to generate paths.
16 | 2. 'folder': Scan folders to generate paths.
17 |
18 | Args:
19 | opt (dict): Config for train datasets. It contains the following keys:
20 | dataroot_lq (str): Data root path for lq.
21 | meta_info_file (str): Path for meta information file.
22 | io_backend (dict): IO backend type and other kwarg.
23 | """
24 |
25 | def __init__(self, opt):
26 | super(SingleImageDataset, self).__init__()
27 | self.opt = opt
28 | # file client (io backend)
29 | self.file_client = None
30 | self.io_backend_opt = opt['io_backend']
31 | self.mean = opt['mean'] if 'mean' in opt else None
32 | self.std = opt['std'] if 'std' in opt else None
33 | self.lq_folder = opt['dataroot_lq']
34 |
35 | if self.io_backend_opt['type'] == 'lmdb':
36 | self.io_backend_opt['db_paths'] = [self.lq_folder]
37 | self.io_backend_opt['client_keys'] = ['lq']
38 | self.paths = paths_from_lmdb(self.lq_folder)
39 | elif 'meta_info_file' in self.opt:
40 | with open(self.opt['meta_info_file'], 'r') as fin:
41 | self.paths = [
42 | osp.join(self.lq_folder,
43 | line.split(' ')[0]) for line in fin
44 | ]
45 | else:
46 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
47 |
48 | def __getitem__(self, index):
49 | if self.file_client is None:
50 | self.file_client = FileClient(
51 | self.io_backend_opt.pop('type'), **self.io_backend_opt)
52 |
53 | # load lq image
54 | lq_path = self.paths[index]
55 | img_bytes = self.file_client.get(lq_path, 'lq')
56 | img_lq = imfrombytes(img_bytes, float32=True)
57 |
58 | # TODO: color space transform
59 | # BGR to RGB, HWC to CHW, numpy to tensor
60 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
61 | # normalize
62 | if self.mean is not None or self.std is not None:
63 | normalize(img_lq, self.mean, self.std, inplace=True)
64 | return {'lq': img_lq, 'lq_path': lq_path}
65 |
66 | def __len__(self):
67 | return len(self.paths)
68 |
--------------------------------------------------------------------------------
/data/util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import pickle
4 | import random
5 | import numpy as np
6 | import glob
7 | import torch
8 | import cv2
9 |
10 | ####################
11 | # Files & IO
12 | ####################
13 |
14 | ###################### get image path list ######################
15 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
16 |
17 |
18 | def flip(x, dim):
19 | indices = [slice(None)] * x.dim()
20 | indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
21 | dtype=torch.long, device=x.device)
22 | return x[tuple(indices)]
23 |
24 |
25 | def is_image_file(filename):
26 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
27 |
28 |
29 | def _get_paths_from_images(path):
30 | """get image path list from image folder"""
31 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
32 | images = []
33 | for dirpath, _, fnames in sorted(os.walk(path)):
34 | for fname in sorted(fnames):
35 | if is_image_file(fname):
36 | img_path = os.path.join(dirpath, fname)
37 | images.append(img_path)
38 | assert images, '{:s} has no valid image file'.format(path)
39 | return images
40 |
41 |
42 | def _get_paths_from_lmdb(dataroot):
43 | """get image path list from lmdb meta info"""
44 | meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
45 | paths = meta_info['keys']
46 | sizes = meta_info['resolution']
47 | if len(sizes) == 1:
48 | sizes = sizes * len(paths)
49 | return paths, sizes
50 |
51 |
52 | def get_image_paths(data_type, dataroot):
53 | """get image path list
54 | support lmdb or image files"""
55 | paths, sizes = None, None
56 | if dataroot is not None:
57 | if data_type == 'lmdb':
58 | paths, sizes = _get_paths_from_lmdb(dataroot)
59 | elif data_type == 'img':
60 | paths = sorted(_get_paths_from_images(dataroot))
61 | else:
62 | raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
63 | return paths, sizes
64 |
65 |
66 | def glob_file_list(root):
67 | return sorted(glob.glob(os.path.join(root, '*')))
68 |
69 |
70 | ###################### read images ######################
71 | def _read_img_lmdb(env, key, size):
72 | """read image from lmdb with key (w/ and w/o fixed size)
73 | size: (C, H, W) tuple"""
74 | with env.begin(write=False) as txn:
75 | buf = txn.get(key.encode('ascii'))
76 | img_flat = np.frombuffer(buf, dtype=np.uint8)
77 | C, H, W = size
78 | img = img_flat.reshape(H, W, C)
79 | return img
80 |
81 |
82 | def read_img(env, path, size=None):
83 | """read image by cv2 or from lmdb
84 | return: Numpy float32, HWC, BGR, [0,1]"""
85 | if env is None: # img
86 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
87 | if img is None:
88 | print(path)
89 | if size is not None:
90 | img = cv2.resize(img, (size[0], size[1]))
91 | else:
92 | img = _read_img_lmdb(env, path, size)
93 |
94 | img = img.astype(np.float32) / 255.
95 | if img.ndim == 2:
96 | img = np.expand_dims(img, axis=2)
97 | # some images have 4 channels
98 | if img.shape[2] > 3:
99 | img = img[:, :, :3]
100 | return img
101 |
102 |
103 | def read_img2(env, path, size=None):
104 | """read image by cv2 or from lmdb
105 | return: Numpy float32, HWC, BGR, [0,1]"""
106 | if env is None: # img
107 | img = np.load(path)
108 | if img is None:
109 | print(path)
110 | if size is not None:
111 | img = cv2.resize(img, (size[0], size[1]))
112 | # img = cv2.resize(img, size)
113 | else:
114 | img = _read_img_lmdb(env, path, size)
115 | img = get_max(img)
116 | img = img.astype(np.float32) / 255.
117 | if img.ndim == 2:
118 | img = np.expand_dims(img, axis=2)
119 | # some images have 4 channels
120 | if img.shape[2] > 3:
121 | img = img[:, :, :3]
122 | return img
123 |
124 |
125 | def read_img_seq(path, size=None):
126 | """Read a sequence of images from a given folder path
127 | Args:
128 | path (list/str): list of image paths/image folder path
129 |
130 | Returns:
131 | imgs (Tensor): size (T, C, H, W), RGB, [0, 1]
132 | """
133 | # print(path)
134 | if type(path) is list:
135 | img_path_l = path
136 | else:
137 | img_path_l = sorted(glob.glob(os.path.join(path, '*')))
138 |
139 | img_l = [read_img(None, v, size) for v in img_path_l]
140 | # stack to Torch tensor
141 | imgs = np.stack(img_l, axis=0)
142 |     # BGR -> RGB; single-channel (grayscale) stacks are left unchanged
143 |     if imgs.shape[-1] == 3:
144 |         imgs = imgs[:, :, :, [2, 1, 0]]
146 | imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
147 | return imgs
148 |
149 |
150 | def read_img_seq2(path, size=None):
151 | """Read a sequence of images from a given folder path
152 | Args:
153 | path (list/str): list of image paths/image folder path
154 |
155 | Returns:
156 | imgs (Tensor): size (T, C, H, W), RGB, [0, 1]
157 | """
158 | # print(path)
159 | if type(path) is list:
160 | img_path_l = path
161 | else:
162 | img_path_l = sorted(glob.glob(os.path.join(path, '*')))
163 |
164 | img_l = [read_img2(None, v, size) for v in img_path_l]
165 | # stack to Torch tensor
166 | imgs = np.stack(img_l, axis=0)
167 |     # BGR -> RGB; single-channel (grayscale) stacks are left unchanged
168 |     if imgs.shape[-1] == 3:
169 |         imgs = imgs[:, :, :, [2, 1, 0]]
171 | imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
172 | return imgs
173 |
174 | def get_max(x):
175 |     T = np.max(x, axis=0)
176 |     T = T + 0.1
177 |     x[0, :, :] = x[0, :, :] / T
178 |     x[1, :, :] = x[1, :, :] / T
179 |     x[2, :, :] = x[2, :, :] / T
180 |     return x
181 |
182 |
183 | def index_generation(crt_i, max_n, N, padding='reflection'):
184 | """Generate an index list for reading N frames from a sequence of images
185 | Args:
186 | crt_i (int): current center index
187 |         max_n (int): total number of images in the sequence (counted from 1)
188 | N (int): reading N frames
189 | padding (str): padding mode, one of replicate | reflection | new_info | circle
190 | Example: crt_i = 0, N = 5
191 | replicate: [0, 0, 0, 1, 2]
192 | reflection: [2, 1, 0, 1, 2]
193 | new_info: [4, 3, 0, 1, 2]
194 | circle: [3, 4, 0, 1, 2]
195 |
196 | Returns:
197 | return_l (list [int]): a list of indexes
198 | """
199 | max_n = max_n - 1
200 | n_pad = N // 2
201 | return_l = []
202 |
203 | for i in range(crt_i - n_pad, crt_i + n_pad + 1):
204 | if i < 0:
205 | if padding == 'replicate':
206 | add_idx = 0
207 | elif padding == 'reflection':
208 | add_idx = -i
209 | elif padding == 'new_info':
210 | add_idx = (crt_i + n_pad) + (-i)
211 | elif padding == 'circle':
212 | add_idx = N + i
213 | else:
214 | raise ValueError('Wrong padding mode')
215 | elif i > max_n:
216 | if padding == 'replicate':
217 | add_idx = max_n
218 | elif padding == 'reflection':
219 | add_idx = max_n * 2 - i
220 | elif padding == 'new_info':
221 | add_idx = (crt_i - n_pad) - (i - max_n)
222 | elif padding == 'circle':
223 | add_idx = i - N
224 | else:
225 | raise ValueError('Wrong padding mode')
226 | else:
227 | add_idx = i
228 | return_l.append(add_idx)
229 | return return_l
230 |
231 |
232 | ####################
233 | # image processing
234 | # process on numpy image
235 | ####################
236 |
237 |
238 | def augment(img_list, hflip=True, rot=True):
239 | """horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
240 | hflip = hflip and random.random() < 0.5
241 | vflip = rot and random.random() < 0.5
242 | rot90 = rot and random.random() < 0.5
243 |
244 | def _augment(img):
245 | if hflip:
246 | img = img[:, ::-1, :]
247 | if vflip:
248 | img = img[::-1, :, :]
249 | if rot90:
250 | # import pdb; pdb.set_trace()
251 | img = img.transpose(1, 0, 2)
252 | return img
253 |
254 | return [_augment(img) for img in img_list]
255 |
256 |
257 |
258 | def augment_torch(img_list, hflip=True, rot=True):
259 | """horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
260 | hflip = hflip and random.random() < 0.5
261 | vflip = rot and random.random() < 0.5
262 | # rot90 = rot and random.random() < 0.5
263 |
264 | def _augment(img):
265 | if hflip:
266 | img = flip(img, 2)
267 | if vflip:
268 | img = flip(img, 1)
269 | # if rot90:
270 | # # import pdb; pdb.set_trace()
271 | # img = img.transpose(1, 0, 2)
272 | return img
273 |
274 | return [_augment(img) for img in img_list]
275 |
276 |
277 | def augment_flow(img_list, flow_list, hflip=True, rot=True):
278 | """horizontal flip OR rotate (0, 90, 180, 270 degrees) with flows"""
279 | hflip = hflip and random.random() < 0.5
280 | vflip = rot and random.random() < 0.5
281 | rot90 = rot and random.random() < 0.5
282 |
283 | def _augment(img):
284 | if hflip:
285 | img = img[:, ::-1, :]
286 | if vflip:
287 | img = img[::-1, :, :]
288 | if rot90:
289 | img = img.transpose(1, 0, 2)
290 | return img
291 |
292 | def _augment_flow(flow):
293 | if hflip:
294 | flow = flow[:, ::-1, :]
295 | flow[:, :, 0] *= -1
296 | if vflip:
297 | flow = flow[::-1, :, :]
298 | flow[:, :, 1] *= -1
299 | if rot90:
300 | flow = flow.transpose(1, 0, 2)
301 | flow = flow[:, :, [1, 0]]
302 | return flow
303 |
304 | rlt_img_list = [_augment(img) for img in img_list]
305 | rlt_flow_list = [_augment_flow(flow) for flow in flow_list]
306 |
307 | return rlt_img_list, rlt_flow_list
308 |
309 |
310 | def channel_convert(in_c, tar_type, img_list):
311 | """conversion among BGR, gray and y"""
312 | if in_c == 3 and tar_type == 'gray': # BGR to gray
313 | gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
314 | return [np.expand_dims(img, axis=2) for img in gray_list]
315 | elif in_c == 3 and tar_type == 'y': # BGR to y
316 | y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
317 | return [np.expand_dims(img, axis=2) for img in y_list]
318 | elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR
319 | return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
320 | else:
321 | return img_list
322 |
323 |
324 | def rgb2ycbcr(img, only_y=True):
325 | """same as matlab rgb2ycbcr
326 | only_y: only return Y channel
327 | Input:
328 | uint8, [0, 255]
329 | float, [0, 1]
330 | """
331 | in_img_type = img.dtype
332 |     img = img.astype(np.float32)
333 | if in_img_type != np.uint8:
334 | img *= 255.
335 | # convert
336 | if only_y:
337 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
338 | else:
339 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
340 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
341 | if in_img_type == np.uint8:
342 | rlt = rlt.round()
343 | else:
344 | rlt /= 255.
345 | return rlt.astype(in_img_type)
346 |
347 |
348 | def bgr2ycbcr(img, only_y=True):
349 | """bgr version of rgb2ycbcr
350 | only_y: only return Y channel
351 | Input:
352 | uint8, [0, 255]
353 | float, [0, 1]
354 | """
355 | in_img_type = img.dtype
356 |     img = img.astype(np.float32)
357 | if in_img_type != np.uint8:
358 | img *= 255.
359 | # convert
360 | if only_y:
361 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
362 | else:
363 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
364 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
365 | if in_img_type == np.uint8:
366 | rlt = rlt.round()
367 | else:
368 | rlt /= 255.
369 | return rlt.astype(in_img_type)
370 |
371 |
372 | def ycbcr2rgb(img):
373 | """same as matlab ycbcr2rgb
374 | Input:
375 | uint8, [0, 255]
376 | float, [0, 1]
377 | """
378 | in_img_type = img.dtype
379 |     img = img.astype(np.float32)
380 | if in_img_type != np.uint8:
381 | img *= 255.
382 | # convert
383 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
384 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
385 | if in_img_type == np.uint8:
386 | rlt = rlt.round()
387 | else:
388 | rlt /= 255.
389 | return rlt.astype(in_img_type)
390 |
391 |
392 | def modcrop(img_in, scale):
393 | """img_in: Numpy, HWC or HW"""
394 | img = np.copy(img_in)
395 | if img.ndim == 2:
396 | H, W = img.shape
397 | H_r, W_r = H % scale, W % scale
398 | img = img[:H - H_r, :W - W_r]
399 | elif img.ndim == 3:
400 | H, W, C = img.shape
401 | H_r, W_r = H % scale, W % scale
402 | img = img[:H - H_r, :W - W_r, :]
403 | else:
404 | raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
405 | return img
406 |
407 |
--------------------------------------------------------------------------------
/dataset/LOLv1/readme.txt:
--------------------------------------------------------------------------------
1 | You can refer to the corresponding link to download the [LOL](https://daooshee.github.io/BMVC2018website/) dataset and put it in this folder.
2 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | logger = logging.getLogger('base')
3 |
4 |
5 | def create_model(opt):
6 | from .model import DDPM as M
7 | from .model import DDPM_PD as M_PD
8 | # print(opt['distill'])
9 | # import pdb; pdb.set_trace()
10 | if opt['distill']:
11 |         m = M_PD(opt)
12 | else:
13 | m = M(opt)
14 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
15 | return m
16 |
--------------------------------------------------------------------------------
/model/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 |
5 |
6 | class BaseModel():
7 | def __init__(self, opt):
8 | self.opt = opt
9 | self.device = torch.device(
10 | 'cuda' if opt['gpu_ids'] is not None else 'cpu')
11 | self.begin_step = 0
12 | self.begin_epoch = 0
13 |
14 | def feed_data(self, data):
15 | pass
16 |
17 | def optimize_parameters(self):
18 | pass
19 |
20 | def get_current_visuals(self):
21 | pass
22 |
23 | def get_current_losses(self):
24 | pass
25 |
26 | def print_network(self):
27 | pass
28 |
29 | def set_device(self, x):
30 | if isinstance(x, dict):
31 | for key, item in x.items():
32 | if item is not None:
33 | x[key] = item.to(self.device)
34 |         elif isinstance(x, list):
35 |             # rebuild the list: rebinding the loop variable would not
36 |             # actually move the stored tensors
37 |             x = [item.to(self.device) if item is not None else item for item in x]
38 | else:
39 | x = x.to(self.device)
40 | return x
41 |
42 | def get_network_description(self, network):
43 | '''Get the string and total parameters of the network'''
44 | if isinstance(network, nn.DataParallel):
45 | network = network.module
46 | s = str(network)
47 | n = sum(map(lambda x: x.numel(), network.parameters()))
48 | return s, n
49 |
--------------------------------------------------------------------------------
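A self-contained sketch of the `set_device` pattern above, including the list fix (it assumes only that PyTorch is installed and falls back to CPU without a GPU):

```python
import torch

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def set_device(x):
    # dicts and lists are rebuilt so the moved tensors are actually kept;
    # rebinding a loop variable would leave the originals untouched
    if isinstance(x, dict):
        return {k: (v.to(device) if v is not None else v) for k, v in x.items()}
    if isinstance(x, list):
        return [v.to(device) if v is not None else v for v in x]
    return x.to(device)

batch = {'LQ': torch.zeros(1, 3, 8, 8), 'GT': torch.zeros(1, 3, 8, 8)}
batch = set_device(batch)
print(batch['LQ'].device)
```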
/model/ddpm_modules/diffusion.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import device, nn, einsum
4 | import torch.nn.functional as F
5 | from inspect import isfunction
6 | from functools import partial
7 | import numpy as np
8 | from tqdm import tqdm
9 | import cv2
10 | import torchvision.transforms as T
11 |
12 | from sklearn.cluster import AgglomerativeClustering
13 | from sklearn.cluster import MeanShift
14 | from sklearn.cluster import DBSCAN
15 | from sklearn.cluster import SpectralClustering
16 | import lpips
17 | from torchvision.utils import save_image
18 | from torch.optim.swa_utils import AveragedModel
19 |
20 |
21 |
22 | transform = T.Lambda(lambda t: (t + 1) / 2)
23 |
24 | def extract(v, t, x_shape):
25 |     """Gather per-timestep coefficients and reshape them for broadcasting."""
26 |     try:
27 |         out = torch.gather(v, index=t, dim=0).float()
28 |     except Exception:
29 |         # log the offending inputs, then re-raise: swallowing the error
30 |         # would leave `out` undefined
31 |         print(t)
32 |         print(v.shape)
33 |         raise
34 |     return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
35 |
36 |
37 |
38 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
39 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
40 | warmup_time = int(n_timestep * warmup_frac)
41 | betas[:warmup_time] = np.linspace(
42 | linear_start, linear_end, warmup_time, dtype=np.float64)
43 | return betas
44 |
45 |
46 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
47 | if schedule == 'quad':
48 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
49 | n_timestep, dtype=np.float64) ** 2
50 | elif schedule == 'linear':
51 | betas = np.linspace(linear_start, linear_end,
52 | n_timestep, dtype=np.float64)
53 | elif schedule == 'warmup10':
54 | betas = _warmup_beta(linear_start, linear_end,
55 | n_timestep, 0.1)
56 | elif schedule == 'warmup50':
57 | betas = _warmup_beta(linear_start, linear_end,
58 | n_timestep, 0.5)
59 | elif schedule == 'const':
60 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
61 | elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1
62 | betas = 1. / np.linspace(n_timestep,
63 | 1, n_timestep, dtype=np.float64)
64 | elif schedule == "cosine":
65 | timesteps = (
66 | torch.arange(n_timestep + 1, dtype=torch.float64) /
67 | n_timestep + cosine_s
68 | )
69 | alphas = timesteps / (1 + cosine_s) * math.pi / 2
70 | alphas = torch.cos(alphas).pow(2)
71 | alphas = alphas / alphas[0]
72 | betas = 1 - alphas[1:] / alphas[:-1]
73 | betas = betas.clamp(max=0.999)
74 | else:
75 | raise NotImplementedError(schedule)
76 | return betas
77 |
78 |
79 | # gaussian diffusion trainer class
80 |
81 | def exists(x):
82 | return x is not None
83 |
84 |
85 | def default(val, d):
86 | if exists(val):
87 | return val
88 | return d() if isfunction(d) else d
89 |
90 |
91 | class GaussianDiffusion(nn.Module):
92 | def __init__(
93 | self,
94 | denoise_fn,
95 | image_size,
96 | num_timesteps,
97 | time_scale,
98 | w_str,
99 | w_gt,
100 | w_snr,
101 | w_lpips,
102 | channels=3,
103 | loss_type='l1',
104 | conditional=True,
105 | schedule_opt=None
106 | ):
107 | super().__init__()
108 | self.channels = channels
109 | self.image_size = image_size
110 | self.denoise_fn = denoise_fn
111 | self.loss_type = loss_type
112 | self.conditional = conditional
113 | self.num_timesteps = num_timesteps
114 | device = torch.device("cuda")
115 |
116 | self.w_str = w_str
117 | self.w_gt = w_gt
118 | self.w_snr = w_snr
119 | self.w_lpips = w_lpips
120 | # self.lpips = lpips.LPIPS(net='vgg').cuda()
121 | # print(self.num_timesteps)
122 | # import pdb; pdb.set_trace()
123 | self.time_scale = time_scale
124 | self.CD = False
125 | if schedule_opt is not None:
126 | self.set_new_noise_schedule(schedule_opt, device)
127 |
128 | def set_loss(self, device):
129 | if self.loss_type == 'l1':
130 | self.loss_func = nn.L1Loss(reduction='sum').to(device)
131 | elif self.loss_type == 'l2':
132 | self.loss_func = nn.MSELoss(reduction='sum').to(device)
133 | else:
134 | raise NotImplementedError()
135 |
136 | def set_new_noise_schedule(self, schedule_opt, device):
137 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
138 |
139 | betas = make_beta_schedule(
140 | schedule=schedule_opt['schedule'],
141 | n_timestep= self.num_timesteps* self.time_scale + 1,
142 | linear_start=schedule_opt['linear_start'],
143 | linear_end=schedule_opt['linear_end'])
144 | betas = betas.detach().cpu().numpy() if isinstance(
145 | betas, torch.Tensor) else betas
146 | alphas = 1. - betas
147 | alphas_cumprod = np.cumprod(alphas, axis=0)
148 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
149 | self.sqrt_alphas_cumprod_prev = np.sqrt(
150 | np.append(1., alphas_cumprod))
151 |
152 | timesteps, = betas.shape
153 | self.register_buffer('betas', to_torch(betas))
154 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
155 | self.register_buffer('alphas_cumprod_prev',
156 | to_torch(alphas_cumprod_prev))
157 |
158 | # calculations for diffusion q(x_t | x_{t-1}) and others
159 | self.register_buffer('sqrt_alphas_cumprod',
160 | to_torch(np.sqrt(alphas_cumprod)))
161 | self.register_buffer('sqrt_one_minus_alphas_cumprod',
162 | to_torch(np.sqrt(1. - alphas_cumprod)))
163 | self.register_buffer('log_one_minus_alphas_cumprod',
164 | to_torch(np.log(1. - alphas_cumprod)))
165 | self.register_buffer('sqrt_recip_alphas_cumprod',
166 | to_torch(np.sqrt(1. / alphas_cumprod)))
167 | self.register_buffer('sqrt_recipm1_alphas_cumprod',
168 | to_torch(np.sqrt(1. / alphas_cumprod - 1)))
169 |
170 | # calculations for posterior q(x_{t-1} | x_t, x_0)
171 | posterior_variance = betas * \
172 | (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
173 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
174 | self.register_buffer('posterior_variance',
175 | to_torch(posterior_variance))
176 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
177 | self.register_buffer('posterior_log_variance_clipped', to_torch(
178 | np.log(np.maximum(posterior_variance, 1e-20))))
179 | self.register_buffer('posterior_mean_coef1', to_torch(
180 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
181 | self.register_buffer('posterior_mean_coef2', to_torch(
182 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
183 |
184 |
185 | def predict_start_from_noise(self, x_t, t, noise):
186 | return self.sqrt_recip_alphas_cumprod[t] * x_t - \
187 | self.sqrt_recipm1_alphas_cumprod[t] * noise
188 |
189 | def predict_eps_from_x(self, x_t, x_0, t):
190 |
191 | eps = (x_t -self.sqrt_alphas_cumprod[t] * x_0) / self.sqrt_one_minus_alphas_cumprod[t]
192 | return eps
193 |
194 | def predict_eps(self, x_t, x_0, continuous_sqrt_alpha_cumprod):
195 |
196 | eps = (1. / (1 - continuous_sqrt_alpha_cumprod **2).sqrt()) * x_t - \
197 | (1. / (1 - continuous_sqrt_alpha_cumprod**2) -1).sqrt() * x_0
198 |
199 | return eps
200 |
201 | def predict_start(self, x_t, continuous_sqrt_alpha_cumprod, noise):
202 |
203 | return (1. / continuous_sqrt_alpha_cumprod) * x_t - \
204 | (1. / continuous_sqrt_alpha_cumprod**2 - 1).sqrt() * noise
205 |
206 | def predict_t_minus1(self, x, t, continuous_sqrt_alpha_cumprod, noise, clip_denoised=True):
207 |
208 | x_recon = self.predict_start(x,
209 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1),
210 | noise=noise)
211 |
212 | if clip_denoised:
213 | x_recon.clamp_(-1., 1.)
214 |
215 | model_mean, model_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
216 |
217 | noise_z = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
218 |
219 | return model_mean + noise_z * (0.5 * model_log_variance).exp()
220 |
221 | def q_posterior(self, x_start, x_t, t):
222 | posterior_mean = self.posterior_mean_coef1[t] * \
223 | x_start + self.posterior_mean_coef2[t] * x_t
224 | posterior_log_variance_clipped = self.posterior_log_variance_clipped[t]
225 | return posterior_mean, posterior_log_variance_clipped
226 |
227 | def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None):
228 | batch_size = x.shape[0]
229 | noise_level = torch.FloatTensor(
230 | [self.sqrt_alphas_cumprod_prev[t+self.time_scale]]).repeat(batch_size, 1).to(x.device)
231 |
232 | eps = self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level)[0]
233 | # print(t)
234 |
235 | x_recon = self.predict_start_from_noise(x, t=t*self.time_scale, noise=eps)
236 |
237 |
238 | if clip_denoised:
239 | x_recon.clamp_(-1., 1.)
240 |
241 | model_mean, posterior_log_variance = self.q_posterior(
242 | x_start=x_recon, x_t=x, t=t)
243 | return model_mean, posterior_log_variance, eps
244 |
245 | @torch.no_grad()
246 | def p_sample(self, x, t, clip_denoised=True, condition_x=None):
247 | model_mean, model_log_variance, eps = self.p_mean_variance(
248 | x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)
249 | noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
250 | return model_mean + noise * (0.5 * model_log_variance).exp()
251 |
252 | @torch.no_grad()
253 | def p_sample_loop(self, x_in, continous=False):
254 | device = self.betas.device
255 | sample_inter = (1 | (self.num_timesteps//10))
256 | if not self.conditional:
257 | shape = x_in
258 | img = torch.randn(shape, device=device)
259 | ret_img = img
260 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
261 | img = self.p_sample(img, i)
262 | if i % sample_inter == 0:
263 | ret_img = torch.cat([ret_img, img], dim=0)
264 | else:
265 | x = x_in
266 | shape = x.shape
267 | img = torch.randn(shape, device=device)
268 | ret_img = x
269 | # print(self.time_scale)
270 | # import pdb; pdb.set_trace()
271 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
272 | img = self.p_sample(img, i, condition_x=x)
273 | if i % sample_inter == 0:
274 | ret_img = torch.cat([ret_img, img], dim=0)
275 | if continous:
276 | return ret_img
277 | else:
278 | return ret_img[-1]
279 |
280 | @torch.no_grad()
281 | def sample(self, batch_size=1, continous=False):
282 | image_size = self.image_size
283 | channels = self.channels
284 | return self.p_sample_loop((batch_size, channels, image_size, image_size), continous)
285 |
286 | @torch.no_grad()
287 | def super_resolution(self, x_in, continous=False, stride=1):
288 | return self.ddim(x_in, continous, stride=stride)
289 |
290 | def q_sample(self, x_start, continuous_sqrt_alpha_cumprod, noise=None):
291 |
292 | return (
293 | continuous_sqrt_alpha_cumprod * x_start +
294 | (1 - continuous_sqrt_alpha_cumprod**2).sqrt() * noise
295 | )
296 |
297 | def ddim(self, x_in, continous=False, snr_aware=False, stride=1, clip_denoised=True):
298 | x = x_in
299 | condition_x = x_in
300 | x_t = torch.randn(x.shape, device=x.device)
301 |
302 | batch_size = x_in.shape[0]
303 |
304 | for time_step in reversed(range(stride, self.num_timesteps + 1, stride)):
305 |
306 |
307 | t = x_t.new_ones([x_t.shape[0], ], dtype=torch.long) * time_step
308 | s = x_t.new_ones([x_t.shape[0], ], dtype=torch.long) * (time_step - stride)
309 | noise_level = torch.FloatTensor([self.sqrt_alphas_cumprod_prev[t* self.time_scale]]).repeat(batch_size, 1).to(x.device)
310 | eps =self.denoise_fn(torch.cat([condition_x, x_t], dim=1), noise_level)[0]
311 | x_0 = self.predict_start_from_noise(x_t, t * self.time_scale, eps)
312 | if clip_denoised:
313 | x_0 = torch.clip(x_0, -1., 1.)
314 | eps = self.predict_eps_from_x(x_t, x_0, t * self.time_scale)
315 |
316 | x_t = self.sqrt_alphas_cumprod[s * self.time_scale] * x_0 + self.sqrt_one_minus_alphas_cumprod[s * self.time_scale] * eps
317 |
318 | return torch.clip(x_t, -1, 1)
319 |
320 | def SNR_map(self, x_0):
321 | blur_transform = T.GaussianBlur(kernel_size=15, sigma=3)
322 | blur_x_0 = blur_transform(x_0)
323 | gray_blur_x_0 = blur_x_0[:, 0:1, :, :] * 0.299 + blur_x_0[:, 1:2, :, :] * 0.587 + blur_x_0[:, 2:3, :, :] * 0.114
324 | gray_x_0 = x_0[:, 0:1, :, :] * 0.299 + x_0[:, 1:2, :, :] * 0.587 + x_0[:, 2:3, :, :] * 0.114
325 | noise = torch.abs(gray_blur_x_0 - gray_x_0)
326 |
327 | return noise
328 |
329 |
330 |
331 | def loss(self, x_in, student, noise=None, lpips_func=None):
332 | x_0 = x_in['GT']
333 | [b, c, h, w] = x_0.shape
334 |
335 | t = 2 * np.random.randint(1, student.num_timesteps + 1)
336 |
337 | continuous_sqrt_alpha_cumprod = torch.FloatTensor(
338 | np.random.uniform(
339 | self.sqrt_alphas_cumprod_prev[(t-1)*self.time_scale],
340 | self.sqrt_alphas_cumprod_prev[t*self.time_scale],
341 | size=b
342 | )
343 | ).to(x_0.device)
344 |
345 | continuous_sqrt_alpha_cumprod_t_mins_1 = torch.FloatTensor(
346 | np.random.uniform(
347 | self.sqrt_alphas_cumprod_prev[(t-2)*self.time_scale],
348 | self.sqrt_alphas_cumprod_prev[(t-1)*self.time_scale],
349 | size=b
350 | )
351 | ).to(x_0.device)
352 | continuous_sqrt_alpha_cumprod_t_mins_2 = torch.FloatTensor(
353 | np.random.uniform(
354 | self.sqrt_alphas_cumprod_prev[(t-3)*self.time_scale],
355 | self.sqrt_alphas_cumprod_prev[(t-2)*self.time_scale],
356 | size=b
357 | )
358 | ).to(x_0.device)
359 |
360 | continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(b, -1)
361 | continuous_sqrt_alpha_cumprod_t_mins_1 = continuous_sqrt_alpha_cumprod_t_mins_1.view(b, -1)
362 | continuous_sqrt_alpha_cumprod_t_mins_2 = continuous_sqrt_alpha_cumprod_t_mins_2.view(b, -1)
363 |
364 | noise = default(noise, lambda: torch.randn_like(x_0))
365 | t = torch.tensor([t], dtype=torch.int64).to(x_0.device)
366 | bs = x_0.size(0)
367 |
368 | with torch.no_grad():
369 | z_t = self.q_sample(x_0, continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), noise)
370 | eps_rec, _ = self.denoise_fn(torch.cat([x_in['LQ'], z_t], dim=1), continuous_sqrt_alpha_cumprod)
371 | x_0_rec = self.predict_start(z_t, continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), eps_rec)
372 | z_t_minus_1 = self.q_sample(x_0_rec, continuous_sqrt_alpha_cumprod_t_mins_1.view(-1, 1, 1, 1), eps_rec)
373 | eps_rec_rec, _ = self.denoise_fn(torch.cat([x_in['LQ'], z_t_minus_1], dim=1), continuous_sqrt_alpha_cumprod_t_mins_1)
374 | x_0_rec_rec = self.predict_start(z_t_minus_1, continuous_sqrt_alpha_cumprod_t_mins_1.view(-1, 1, 1, 1), eps_rec_rec)
375 | z_t_minus_2 = self.q_sample(x_0_rec_rec, continuous_sqrt_alpha_cumprod_t_mins_2.view(-1, 1, 1, 1), eps_rec_rec)
376 | frac = (1 - continuous_sqrt_alpha_cumprod_t_mins_2.view(-1, 1, 1, 1)**2).sqrt() / (1- continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1)**2).sqrt()
377 | if self.w_snr != 0:
378 | y = x_in['LQ']
379 |             t_max, _ = torch.max(y, dim=1, keepdim=True)  # renamed from T to avoid shadowing the transforms alias
380 |             t_max = t_max + 0.1
381 |             y = y / t_max
382 | iso_noise = self.SNR_map(y)
383 | y = y - iso_noise
384 | refine_x_0 = y
385 | z_t_minus_2_refine = self.q_sample(refine_x_0, continuous_sqrt_alpha_cumprod_t_mins_2.view(-1, 1, 1, 1), eps_rec_rec)
386 | z_t_minus_2 = z_t_minus_2 + self.w_snr *(z_t_minus_2_refine - z_t_minus_2)
387 |
388 | x_target = (z_t_minus_2 - frac * z_t) / ( continuous_sqrt_alpha_cumprod_t_mins_2.view(-1, 1, 1, 1) - frac * continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1 ))
389 | eps_target = self.predict_eps(z_t, x_target, continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1))
390 |
391 | eps_predicted, _ = student.denoise_fn(torch.cat([x_in['LQ'], z_t], dim=1), continuous_sqrt_alpha_cumprod)
392 | x_0_predicted = self.predict_start(z_t, continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), eps_predicted)
393 | loss_x_0 = torch.mean(F.mse_loss(x_0_predicted, x_target, reduction='none').reshape(bs, -1), dim=-1)
394 | loss_eps = torch.mean(F.mse_loss(eps_predicted, eps_target, reduction='none').reshape(bs, -1), dim=-1)
395 |
396 | loss_stru = torch.zeros_like(loss_x_0) # 0.
397 | if self.w_gt != 0:
398 | loss_output_x0 = torch.mean(F.mse_loss(x_0, x_0_predicted, reduction='none').reshape(bs, -1), dim=-1)
399 | loss_output_eps = torch.mean(F.mse_loss(noise, eps_predicted, reduction='none').reshape(bs, -1), dim=-1)
400 |
401 | else:
402 | loss_output_x0 = torch.zeros_like(loss_x_0) # 0.
403 | loss_output_eps = torch.zeros_like(loss_eps) # 0.
404 |
405 | if self.w_lpips != 0:
406 | loss_lpips = torch.mean(lpips_func(x_0, x_0_predicted))
407 | else:
408 | loss_lpips = torch.zeros_like(loss_x_0) # 0.
409 |
410 | return torch.mean(torch.maximum(loss_x_0, loss_eps)) + \
411 | self.w_gt * torch.mean(torch.maximum(loss_output_x0, loss_output_eps)) + \
412 | self.w_lpips*torch.mean(loss_lpips) + \
413 | self.w_str*torch.mean(loss_stru)
414 |
415 |
416 |
417 |
418 |
419 | def forward(self, x, s_model=None, *args, **kwargs):
420 | return self.loss(x, s_model, *args, **kwargs)
421 |
422 |
423 |
--------------------------------------------------------------------------------
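To see what `set_new_noise_schedule` registers, here is a minimal sketch of the 'linear' branch of `make_beta_schedule` and the derived `sqrt_alphas_cumprod_prev` table (illustrative values, not the training configuration):

```python
import numpy as np

n_timestep, linear_start, linear_end = 200, 1e-4, 2e-2
betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)
alphas_cumprod = np.cumprod(1.0 - betas)

# index 0 is the "no noise" level, matching the np.append(1., ...) above
sqrt_alphas_cumprod_prev = np.sqrt(np.append(1.0, alphas_cumprod))
print(sqrt_alphas_cumprod_prev[0])   # 1.0
print(sqrt_alphas_cumprod_prev[-1])  # small: mostly noise at t = T
```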
/model/ddpm_modules/unet.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | import torch.fft as fft
4 | from torch import nn
5 | import torch.nn.functional as F
6 | from inspect import isfunction
7 | import cv2
8 | import torchvision.transforms as T
9 | import numpy as np
10 | def exists(x):
11 | return x is not None
12 |
13 | # PositionalEncoding Source: https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py
14 | class PositionalEncoding(nn.Module):
15 | def __init__(self, dim):
16 | super().__init__()
17 | self.dim = dim
18 |
19 | def forward(self, noise_level):
20 | count = self.dim // 2
21 | step = torch.arange(count, dtype=noise_level.dtype,
22 | device=noise_level.device) / count
23 | encoding = noise_level.unsqueeze(
24 | 1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
25 | encoding = torch.cat(
26 | [torch.sin(encoding), torch.cos(encoding)], dim=-1)
27 | return encoding
28 |
29 | class FeatureWiseAffine(nn.Module):
30 | def __init__(self, in_channels, out_channels, use_affine_level=False):
31 | super(FeatureWiseAffine, self).__init__()
32 | self.use_affine_level = use_affine_level
33 | self.noise_func = nn.Sequential(
34 | nn.Linear(in_channels, out_channels*(1+self.use_affine_level))
35 | )
36 |
37 | def forward(self, x, noise_embed):
38 | batch = x.shape[0]
39 | if self.use_affine_level:
40 | gamma, beta = self.noise_func(noise_embed).view(
41 | batch, -1, 1, 1).chunk(2, dim=1)
42 | x = (1 + gamma) * x + beta
43 | else:
44 | x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1)
45 | return x
46 |
47 | def default(val, d):
48 | if exists(val):
49 | return val
50 | return d() if isfunction(d) else d
51 |
52 | # model
53 | class Swish(nn.Module):
54 | def forward(self, x):
55 | return x * torch.sigmoid(x)
56 |
57 |
58 | class Upsample(nn.Module):
59 | def __init__(self, dim):
60 | super().__init__()
61 | self.up = nn.Upsample(scale_factor=2, mode="nearest")
62 | self.conv = nn.Conv2d(dim, dim, 3, padding=1)
63 |
64 | def forward(self, x):
65 | return self.conv(self.up(x))
66 |
67 |
68 | class Downsample(nn.Module):
69 | def __init__(self, dim):
70 | super().__init__()
71 | self.conv = nn.Conv2d(dim, dim, 3, 2, 1)
72 |
73 | def forward(self, x):
74 | return self.conv(x)
75 |
76 |
77 | # building block modules
78 | class Block(nn.Module):
79 | def __init__(self, dim, dim_out, groups=32, dropout=0):
80 | super().__init__()
81 | self.block = nn.Sequential(
82 | nn.GroupNorm(groups, dim),
83 | Swish(),
84 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
85 | nn.Conv2d(dim, dim_out, 3, padding=1)
86 | )
87 |
88 | def forward(self, x):
89 | return self.block(x)
90 |
91 |
92 | class ResnetBlock(nn.Module):
93 | def __init__(self, dim, dim_out, time_emb_dim=None, dropout=0, norm_groups=32):
94 | super().__init__()
95 | self.mlp = nn.Sequential(
96 | Swish(),
97 | nn.Linear(time_emb_dim, dim_out)
98 | ) if exists(time_emb_dim) else None
99 | self.noise_func = FeatureWiseAffine(
100 | time_emb_dim, dim_out, use_affine_level=False)
101 |
102 | self.block1 = Block(dim, dim_out, groups=norm_groups)
103 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
104 | self.res_conv = nn.Conv2d(
105 | dim, dim_out, 1) if dim != dim_out else nn.Identity()
106 |
107 | def forward(self, x, time_emb):
108 | h = self.block1(x)
109 | h = self.noise_func(h, time_emb)
110 | h = self.block2(h)
111 | return h + self.res_conv(x)
112 |
113 |
114 | class SelfAttention(nn.Module):
115 | def __init__(self, in_channel, n_head=1, norm_groups=32):
116 | super().__init__()
117 |
118 | self.n_head = n_head
119 |
120 | self.norm = nn.GroupNorm(norm_groups, in_channel)
121 | self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
122 | self.out = nn.Conv2d(in_channel, in_channel, 1)
123 |
124 | def forward(self, input):
125 | batch, channel, height, width = input.shape
126 | n_head = self.n_head
127 | head_dim = channel // n_head
128 |
129 | norm = self.norm(input)
130 | qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
131 | query, key, value = qkv.chunk(3, dim=2) # bhdyx
132 |
133 | attn = torch.einsum(
134 | "bnchw, bncyx -> bnhwyx", query, key
135 | ).contiguous() / math.sqrt(channel)
136 | attn = attn.view(batch, n_head, height, width, -1)
137 | attn = torch.softmax(attn, -1)
138 | attn = attn.view(batch, n_head, height, width, height, width)
139 |
140 | out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()
141 | out = self.out(out.view(batch, channel, height, width))
142 |
143 | return out + input
144 |
145 |
146 | class ResnetBlocWithAttn(nn.Module):
147 | def __init__(self, dim, dim_out, *, time_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
148 | super().__init__()
149 | self.with_attn = with_attn
150 | self.res_block = ResnetBlock(
151 | dim, dim_out, time_emb_dim, norm_groups=norm_groups, dropout=dropout)
152 | if with_attn:
153 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
154 |
155 | def forward(self, x, time_emb):
156 | x = self.res_block(x, time_emb)
157 | if(self.with_attn):
158 | x = self.attn(x)
159 | return x
160 |
161 |
162 | class UNet(nn.Module):
163 | def __init__(
164 | self,
165 | in_channel=6,
166 | out_channel=3,
167 | inner_channel=32,
168 | norm_groups=32,
169 | channel_mults=(1, 1, 2, 2, 4),
170 |         attn_res=(8,),  # note the comma: a bare (8) is an int and breaks the 'in' checks below
171 | res_blocks=3,
172 | dropout=0,
173 | with_noise_level_emb=True,
174 | image_size=128
175 | ):
176 | super().__init__()
177 | if with_noise_level_emb:
178 | noise_level_channel = inner_channel
179 | self.noise_level_mlp = nn.Sequential(
180 | PositionalEncoding(inner_channel),
181 | nn.Linear(inner_channel, inner_channel * 4),
182 | Swish(),
183 | nn.Linear(inner_channel * 4, inner_channel)
184 | )
185 | else:
186 | noise_level_channel = None
187 | self.noise_level_mlp = None
188 |
189 |
190 | num_mults = len(channel_mults)
191 | pre_channel = inner_channel
192 | feat_channels = [pre_channel]
193 | now_res = image_size
194 | downs = [nn.Conv2d(in_channel, inner_channel,
195 | kernel_size=3, padding=1)]
196 | for ind in range(num_mults):
197 | is_last = (ind == num_mults - 1)
198 | use_attn = (now_res in attn_res)
199 | channel_mult = inner_channel * channel_mults[ind]
200 | for _ in range(0, res_blocks):
201 | downs.append(ResnetBlocWithAttn(
202 | pre_channel, channel_mult, time_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn))
203 | feat_channels.append(channel_mult)
204 | pre_channel = channel_mult
205 | if not is_last:
206 | downs.append(Downsample(pre_channel))
207 | feat_channels.append(pre_channel)
208 | now_res = now_res//2
209 | self.downs = nn.ModuleList(downs)
210 |
211 | self.mid = nn.ModuleList([
212 | ResnetBlocWithAttn(pre_channel, pre_channel, time_emb_dim=noise_level_channel, norm_groups=norm_groups,
213 | dropout=dropout, with_attn=True),
214 | ResnetBlocWithAttn(pre_channel, pre_channel, time_emb_dim=noise_level_channel, norm_groups=norm_groups,
215 | dropout=dropout, with_attn=False)
216 | ])
217 |
218 | ups = []
219 | for ind in reversed(range(num_mults)):
220 | is_last = (ind < 1)
221 | use_attn = (now_res in attn_res)
222 | channel_mult = inner_channel * channel_mults[ind]
223 | for _ in range(0, res_blocks+1):
224 | ups.append(ResnetBlocWithAttn(
225 | pre_channel+feat_channels.pop(), channel_mult, time_emb_dim=noise_level_channel, dropout=dropout, norm_groups=norm_groups, with_attn=use_attn))
226 | pre_channel = channel_mult
227 | if not is_last:
228 | ups.append(Upsample(pre_channel))
229 | now_res = now_res*2
230 |
231 | self.ups = nn.ModuleList(ups)
232 |
233 | self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups)
234 |
235 | self.var_conv = nn.Sequential(*[
236 | nn.Conv2d(pre_channel, pre_channel, 3, padding=(3//2), bias=True),
237 | nn.ELU(),
238 | nn.Conv2d(pre_channel, pre_channel, 3, padding=(3//2), bias=True),
239 | nn.ELU(),
240 | nn.Conv2d(pre_channel, 3, 3, padding=(3//2), bias=True),
241 | nn.ELU()
242 | ])
243 |     # unused helper kept from the original code; marked static since it
244 |     # does not take self
245 |     @staticmethod
246 |     def default_conv(in_channels, out_channels, kernel_size, bias=True):
247 |         return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)
248 |
249 | def forward(self, x, noise):
250 | noise_level = self.noise_level_mlp(noise) if exists(self.noise_level_mlp) else None
251 | feats = []
252 | for layer in self.downs:
253 | if isinstance(layer, ResnetBlocWithAttn):
254 | x = layer(x, noise_level)
255 | else:
256 | x = layer(x)
257 | feats.append(x)
258 |
259 | for layer in self.mid:
260 | if isinstance(layer, ResnetBlocWithAttn):
261 | x = layer(x, noise_level)
262 | else:
263 | x = layer(x)
264 |
265 | for layer in self.ups:
266 | if isinstance(layer, ResnetBlocWithAttn):
267 | x = layer(torch.cat((x, feats.pop()), dim=1), noise_level)
268 | else:
269 | x = layer(x)
270 | return self.final_conv(x), self.var_conv(x)
271 |
272 |
273 |
274 | # FreeU
275 | def Fourier_filter(x, threshold, scale):
276 | # FFT
277 | x_freq = fft.fftn(x, dim=(-2, -1))
278 | x_freq = fft.fftshift(x_freq, dim=(-2, -1))
279 |
280 | B, C, H, W = x_freq.shape
281 |     mask = torch.ones((B, C, H, W), device=x.device)  # follow the input's device instead of assuming CUDA
282 |
283 | crow, ccol = H // 2, W //2
284 | mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
285 | x_freq = x_freq * mask
286 |
287 | # IFFT
288 | x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
289 | x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
290 |
291 | return x_filtered
292 |
293 | def SNR_filter(x_in, threshold, scale):
294 |
295 | blur_transform = T.GaussianBlur(kernel_size=15, sigma=3)
296 | blur_x_in = blur_transform(x_in)
297 | gray_blur_x_in = blur_x_in[:, 0:1, :, :] * 0.299 + blur_x_in[:, 1:2, :, :] * 0.587 + blur_x_in[:, 2:3, :, :] * 0.114
298 | gray_x_in = x_in[:, 0:1, :, :] * 0.299 + x_in[:, 1:2, :, :] * 0.587 + x_in[:, 2:3, :, :] * 0.114
299 | noise = torch.abs(gray_blur_x_in - gray_x_in)
300 | mask = torch.div(gray_x_in, noise + 0.0001)
301 |
302 | batch_size = mask.shape[0]
303 | height = mask.shape[2]
304 | width = mask.shape[3]
305 | mask_max = torch.max(mask.view(batch_size, -1), dim=1)[0]
306 | mask_max = mask_max.view(batch_size, 1, 1, 1)
307 | mask_max = mask_max.repeat(1, 1, height, width)
308 | mask = mask * 1.0 / (mask_max+0.0001)
309 | mask = torch.clamp(mask, min=0, max=1.0)
310 | mask = mask.float()
311 |
312 | return mask
313 |
314 |
315 | class Free_UNet(UNet):
316 | """
317 | :param b1: backbone factor of the first stage block of decoder.
318 | :param b2: backbone factor of the second stage block of decoder.
319 | :param s1: skip factor of the first stage block of decoder.
320 | :param s2: skip factor of the second stage block of decoder.
321 | """
322 |
323 | def __init__(
324 | self,
325 | b1=1.3,
326 | b2=1.4,
327 | s1=0.9,
328 | s2=0.2,
329 | *args,
330 | **kwargs
331 | ):
332 | super().__init__(*args, **kwargs)
333 | self.b1 = b1
334 | self.b2 = b2
335 | self.s1 = s1
336 | self.s2 = s2
337 |
338 | def forward(self, h, noise):
339 |         # only the input features and the continuous noise level are needed
340 |         """
341 |         Apply the model to an input batch, with FreeU-style modulation of
342 |         the first two decoder stages.
343 |         :param h: an [N x C x ...] tensor of inputs (condition concatenated
344 |             with the noisy image).
345 |         :param noise: an [N x 1] tensor of continuous noise levels.
346 |         :return: (eps prediction, variance head output).
347 |         """
348 | hs = []
349 | noise_level = self.noise_level_mlp(noise) if exists(self.noise_level_mlp) else None
350 |
351 |
352 |
353 | for layer in self.downs:
354 | if isinstance(layer, ResnetBlocWithAttn):
355 | h = layer(h, noise_level)
356 | else:
357 | h = layer(h)
358 | hs.append(h)
359 |
360 | for layer in self.mid:
361 | if isinstance(layer, ResnetBlocWithAttn):
362 | h = layer(h, noise_level)
363 | else:
364 | h = layer(h)
365 |
366 | for layer in self.ups:
367 | # --------------- FreeU code -----------------------
368 | # Only operate on the first two stages
369 | if h.shape[1] == 256:
370 | hs_ = hs.pop()
371 | hidden_mean = h.mean(1).unsqueeze(1)
372 | B = hidden_mean.shape[0]
373 | hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
374 | hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
375 | hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
376 |
377 | h[:,:128] = h[:,:128] * ((self.b1 - 1 ) * hidden_mean + 1)
378 | hs_ = Fourier_filter(hs_, threshold=1, scale=self.s1)
379 | hs.append(hs_)
380 | if h.shape[1] == 128:
381 | hs_ = hs.pop()
382 | hidden_mean = h.mean(1).unsqueeze(1)
383 | B = hidden_mean.shape[0]
384 | hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
385 | hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
386 | hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
387 |
388 | h[:,:64] = h[:,:64] * ((self.b2 - 1 ) * hidden_mean + 1)
389 | hs_ = Fourier_filter(hs_, threshold=1, scale=self.s2)
390 | hs.append(hs_)
391 | # ---------------------------------------------------------
392 | # exit(print(h.shape, hs_.shape)) # torch.Size([1, 256, 26, 38]) torch.Size([1, 256, 26, 38])
393 | # print(h.shape, hs_.shape)
394 | # h = torch.cat((h, hs_), dim=1)
395 |
396 | if isinstance(layer, ResnetBlocWithAttn):
397 | h = layer(torch.cat((h, hs.pop()), dim=1), noise_level)
398 | else:
399 | h = layer(h)
400 | return self.final_conv(h), self.var_conv(h)
401 |
402 | class LAUNet(UNet):
403 |     """
404 |     :param b1: backbone factor of the first stage block of decoder.
405 |     :param b2: backbone factor of the second stage block of decoder.
406 |     :param s1: skip factor of the first stage block of decoder.
407 |     :param s2: skip factor of the second stage block of decoder.
408 |     """
409 | 
410 |     def __init__(
411 |         self,
412 |         b1=1.3,
413 |         b2=1.4,
414 |         s1=0.9,
415 |         s2=0.2,
416 |         *args,
417 |         **kwargs
418 |     ):
419 |         super().__init__(*args, **kwargs)
420 |         self.b1 = b1
421 |         self.b2 = b2
422 |         self.s1 = s1
423 |         self.s2 = s2
424 | 
425 |     def forward(self, h, noise):
426 |         """
427 |         Apply the model to an input batch. Unlike Free_UNet, only the
428 |         first decoder stage is modulated; the second-stage branch is
429 |         disabled in this variant.
430 |         :param h: an [N x C x ...] tensor of inputs.
431 |         :param noise: an [N x 1] tensor of continuous noise levels.
432 |         :return: (eps prediction, variance head output).
433 |         """
434 |         hs = []
435 |         noise_level = self.noise_level_mlp(noise) if exists(self.noise_level_mlp) else None
436 | 
437 |         for layer in self.downs:
438 |             if isinstance(layer, ResnetBlocWithAttn):
439 |                 h = layer(h, noise_level)
440 |             else:
441 |                 h = layer(h)
442 |             hs.append(h)
443 | 
444 |         for layer in self.mid:
445 |             if isinstance(layer, ResnetBlocWithAttn):
446 |                 h = layer(h, noise_level)
447 |             else:
448 |                 h = layer(h)
449 | 
450 |         for layer in self.ups:
451 |             # --------------- FreeU code -----------------------
452 |             # Only the first (widest) decoder stage is modulated here.
453 |             if h.shape[1] == 256:
454 |                 hs_ = hs.pop()
455 |                 hidden_mean = h.mean(1).unsqueeze(1)
456 |                 B = hidden_mean.shape[0]
457 |                 hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True)
458 |                 hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
459 |                 hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
460 | 
461 |                 h[:, :64] = h[:, :64] * ((self.b1 - 1) * hidden_mean + 1)
462 |                 hs_ = Fourier_filter(hs_, threshold=1, scale=self.s1)
463 |                 hs.append(hs_)
464 |             # ---------------------------------------------------------
465 |             if isinstance(layer, ResnetBlocWithAttn):
466 |                 h = layer(torch.cat((h, hs.pop()), dim=1), noise_level)
467 |             else:
468 |                 h = layer(h)
469 | 
470 |         return self.final_conv(h), self.var_conv(h)
--------------------------------------------------------------------------------
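A smoke-test sketch for the `UNet` above (assumes the repo root is on `PYTHONPATH`; the small hyperparameters are illustrative, not the paper's configuration):

```python
import torch

from model.ddpm_modules.unet import UNet

net = UNet(in_channel=6, out_channel=3, inner_channel=32, norm_groups=16,
           channel_mults=(1, 2), attn_res=(8,), res_blocks=1, image_size=32)
x = torch.randn(1, 6, 32, 32)   # condition and noisy image, concatenated
noise_level = torch.rand(1, 1)  # continuous sqrt(alpha-bar) level
eps, var = net(x, noise_level)  # the network returns (eps head, variance head)
print(eps.shape, var.shape)     # both torch.Size([1, 3, 32, 32])
```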
/model/model.py:
--------------------------------------------------------------------------------
1 | import logging
2 | from collections import OrderedDict
3 | import torch
4 | import torch.nn as nn
5 | import os
6 | import model.networks as networks
7 | from .base_model import BaseModel
8 | from torch.nn.parallel import DistributedDataParallel as DDP
9 | from utils.ema import EMA
10 | from torch.optim import lr_scheduler
11 | import lpips
12 |
13 | logger = logging.getLogger('base')
14 | # the noise-schedule buffers are rebuilt from the config at runtime, so
15 | # they are skipped when restoring a checkpoint:
16 | skip_para = ['betas', 'alphas_cumprod', 'alphas_cumprod_prev', 'sqrt_alphas_cumprod',
17 | 'sqrt_one_minus_alphas_cumprod', 'log_one_minus_alphas_cumprod', 'sqrt_recip_alphas_cumprod',
18 | 'sqrt_recipm1_alphas_cumprod', 'posterior_variance', 'posterior_log_variance_clipped',
19 | 'posterior_mean_coef1', 'posterior_mean_coef2',]
20 |
21 | def get_scheduler(optimizer, opt):
22 | if opt['train']["optimizer"]['lr_policy'] == 'linear':
23 | def lambda_rule(iteration):
24 | lr_l = 1.0 - max(0, iteration-opt['train']["optimizer"]["n_lr_iters"]) / float(opt['train']["optimizer"]["lr_decay_iters"] + 1)
25 | return lr_l
26 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
27 | elif opt['train']["optimizer"]['lr_policy'] == 'step':
28 | scheduler = lr_scheduler.StepLR(optimizer, step_size=opt['train']["optimizer"]["lr_decay_iters"], gamma=0.8)
29 | elif opt['train']["optimizer"]['lr_policy'] == 'plateau':
30 | scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
31 |     elif opt['train']["optimizer"]['lr_policy'] == 'cosine':
32 |         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)  # note: expects an attribute-style opt; the dict-style opts used elsewhere would raise here
33 |     else:
34 |         raise NotImplementedError('learning rate policy [{}] is not implemented'.format(opt['train']["optimizer"]['lr_policy']))
35 | return scheduler
36 |
37 |
38 | class DDPM(BaseModel):
39 | def __init__(self, opt):
40 | super(DDPM, self).__init__(opt)
41 |
42 | if opt['dist']:
43 | self.local_rank = torch.distributed.get_rank()
44 | torch.cuda.set_device(self.local_rank)
45 | device = torch.device("cuda", self.local_rank)
46 | # define network and load pretrained models
47 | self.netG = self.set_device(networks.define_G(opt, student=False))
48 | if opt['dist']:
49 | self.netG.to(device)
50 |
51 | # self.netG.to(device)
52 |
53 | self.schedule_phase = None
54 | self.opt = opt
55 |
56 | # set loss and load resume state
57 | self.set_loss()
58 |
59 | if self.opt['phase'] == 'train':
60 | self.netG.train()
61 | # find the parameters to optimize
62 | if opt['model']['finetune_norm']:
63 | optim_params = []
64 | for k, v in self.netG.named_parameters():
65 | v.requires_grad = False
66 | if k.find('transformer') >= 0:
67 | v.requires_grad = True
68 | v.data.zero_()
69 | optim_params.append(v)
70 | logger.info(
71 | 'Params [{:s}] initialized to 0 and will optimize.'.format(k))
72 | else:
73 | optim_params = list(self.netG.parameters())
74 |
75 | self.optG = torch.optim.Adam(
76 | optim_params, lr=opt['train']["optimizer"]["lr"])
77 | self.log_dict = OrderedDict()
78 |
79 | if self.opt['phase'] == 'test':
80 | self.netG.load_state_dict(torch.load(self.opt['path']['resume_state']), strict=True)
81 |
82 | else:
83 | self.load_network()
84 | if opt['dist']:
85 | self.netG = DDP(self.netG, device_ids=[self.local_rank], output_device=self.local_rank,find_unused_parameters=True)
86 |
87 |
88 | def feed_data(self, data):
89 |
90 |         dic = {}
91 | 
92 |         if self.opt['dist']:
93 |             # under DDP, tensors go straight to this rank's device
94 | dic['LQ'] = data['LQ'].to(self.local_rank)
95 | dic['GT'] = data['GT'].to(self.local_rank)
96 | self.data = dic
97 | else:
98 | dic['LQ'] = data['LQ']
99 | dic['GT'] = data['GT']
100 |
101 | self.data = self.set_device(dic)
102 |
103 |
104 | def test(self, continous=False):
105 | self.netG.eval()
106 | with torch.no_grad():
107 | if isinstance(self.netG, nn.DataParallel):
108 | self.SR = self.netG.module.super_resolution(
109 | self.data['LQ'], continous)
110 |
111 | else:
112 | if self.opt['dist']:
113 | self.SR = self.netG.module.super_resolution(self.data['LQ'], continous)
114 | else:
115 | self.SR = self.netG.super_resolution(self.data['LQ'], continous)
116 |
117 | self.netG.train()
118 |
119 | def sample(self, batch_size=1, continous=False):
120 | self.netG.eval()
121 | with torch.no_grad():
122 | if isinstance(self.netG, nn.DataParallel):
123 | self.SR = self.netG.module.sample(batch_size, continous)
124 | else:
125 | self.SR = self.netG.sample(batch_size, continous)
126 | self.netG.train()
127 |
128 | def set_loss(self):
129 | if isinstance(self.netG, nn.DataParallel):
130 | self.netG.module.set_loss(self.device)
131 | else:
132 | self.netG.set_loss(self.device)
133 |
134 | def set_new_noise_schedule(self, schedule_opt, schedule_phase='train'):
135 |
136 | if self.opt['dist']:
137 |
138 | device = torch.device("cuda", self.local_rank)
139 | if self.schedule_phase is None or self.schedule_phase != schedule_phase:
140 | self.schedule_phase = schedule_phase
141 | if isinstance(self.netG, nn.DataParallel):
142 | self.netG.module.set_new_noise_schedule(
143 | schedule_opt, self.device)
144 | else:
145 |                 self.netG.set_new_noise_schedule(schedule_opt, device)  # pass the rank's device (it was missing here)
146 |
147 | else:
148 | self.schedule_phase = schedule_phase
149 | if isinstance(self.netG, nn.DataParallel):
150 | self.netG.module.set_new_noise_schedule(
151 | schedule_opt, self.device)
152 | else:
153 |                 # single-process case: buffers are created on the model's device
154 | self.netG.set_new_noise_schedule(schedule_opt, self.device)
155 |
156 |
157 | def get_current_log(self):
158 | return self.log_dict
159 |
160 | def get_current_visuals(self, need_LR=True, sample=False):
161 | out_dict = OrderedDict()
162 | if sample:
163 | out_dict['SAM'] = self.SR.detach().float().cpu()
164 | else:
165 | out_dict['HQ'] = self.SR.detach().float().cpu()
166 | out_dict['INF'] = self.data['LQ'].detach().float().cpu()
167 | out_dict['GT'] = self.data['GT'].detach()[0].float().cpu()
168 | if need_LR and 'LR' in self.data:
169 | out_dict['LQ'] = self.data['LQ'].detach().float().cpu()
170 | else:
171 | out_dict['LQ'] = out_dict['INF']
172 | return out_dict
173 |
174 | def print_network(self):
175 | s, n = self.get_network_description(self.netG)
176 | if isinstance(self.netG, nn.DataParallel):
177 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
178 | self.netG.module.__class__.__name__)
179 | else:
180 | net_struc_str = '{}'.format(self.netG.__class__.__name__)
181 |
182 | logger.info(s)
183 |
184 | logger.info(
185 | 'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
186 |
187 |     def save_network(self, distill_step, epoch, iter_step):
188 |         save_root = os.path.join(self.opt['path']['checkpoint'], 'num_step_{}'.format(distill_step))
189 |         os.makedirs(save_root, exist_ok=True)
190 |         # only the generator weights are saved here; optimizer state is not persisted
191 |         gen_path = os.path.join(save_root, 'I{}_E{}_gen.pth'.format(iter_step, epoch))
192 |
193 | # gen
194 | network = self.netG
195 | if isinstance(self.netG, nn.DataParallel):
196 | network = network.module
197 | state_dict = network.state_dict()
198 | for key, param in state_dict.items():
199 | state_dict[key] = param.cpu()
200 | torch.save(state_dict, gen_path)
201 |
202 |
203 |
204 | logger.info(
205 | 'Saved model in [{:s}] ...'.format(gen_path))
206 |
207 | def load_network(self):
208 | load_path = self.opt['path']['resume_state']
209 | if load_path is not None:
210 | logger.info(
211 | 'Loading pretrained model for G [{:s}] ...'.format(load_path))
212 | gen_path = '{}'.format(load_path)
213 |
214 | # gen
215 |             networks = [self.netG]  # only one network here (the distilled variant loads teacher and student)
216 | for network in networks:
217 | if isinstance(network, nn.DataParallel):
218 | network = network.module
219 |
220 | # network = nn.DataParallel(network).cuda()
221 | ckpt = torch.load(gen_path)
222 | current_state_dict = network.state_dict()
223 |                 for name, param in ckpt.items():
224 |                     # schedule buffers are rebuilt from the config at
225 |                     # runtime, so they are skipped when loading
226 |                     if name in skip_para:
227 |                         continue
228 |                     current_state_dict[name] = param
229 | 
230 |
231 | network.load_state_dict(current_state_dict, strict=False)
232 | if self.opt['phase'] == 'train':
233 | self.begin_step = 0
234 | self.begin_epoch = 0
235 |
236 |
237 |
238 |
239 | class DDPM_PD(BaseModel):
240 | def __init__(self, opt):
241 | super(DDPM_PD, self).__init__(opt)
242 |
243 | if opt['dist']:
244 | self.local_rank = torch.distributed.get_rank()
245 | torch.cuda.set_device(self.local_rank)
246 | device = torch.device("cuda", self.local_rank)
247 | # define network and load pretrained models
248 | self.netG_t = self.set_device(networks.define_G(opt, student=False))
249 |         if opt['CD']:  # 'CD' mode: the student reuses the full teacher architecture
250 | self.netG_s = self.set_device(networks.define_G(opt, student=False))
251 | else:
252 | self.netG_s = self.set_device(networks.define_G(opt, student=True))
253 | if opt['dist']:
254 | self.netG_t.to(device)
255 | self.netG_s.to(device)
256 |
257 | # self.netG.to(device)
258 |
259 |
260 | self.schedule_phase = None
261 | self.opt = opt
262 |
263 | # set loss and load resume state
264 |
265 | self.set_loss()
266 | self.lpips = lpips.LPIPS(net='vgg').cuda()
267 |
268 | # self.set_new_noise_schedule(opt['model']['beta_schedule']['train'], schedule_phase='train')
269 |
270 | if self.opt['phase'] == 'train':
271 | self.netG_s.train()
272 | # find the parameters to optimize
273 |
274 | if opt['model']['finetune_norm']:
275 | optim_params = []
276 | for k, v in self.netG_s.named_parameters():
277 | v.requires_grad = False
278 | if k.find('transformer') >= 0:
279 | v.requires_grad = True
280 | v.data.zero_()
281 | optim_params.append(v)
282 | logger.info(
283 | 'Params [{:s}] initialized to 0 and will optimize.'.format(k))
284 | else:
285 | optim_params = list(self.netG_s.parameters())
286 |
287 | self.optG = torch.optim.Adam(optim_params, lr=opt['train']["optimizer"]["lr"])
288 | self.scheduler = get_scheduler(self.optG, opt)
289 | self.log_dict = OrderedDict()
290 |
291 |
292 |
293 |
294 | if self.opt['phase'] == 'test':
295 | self.netG_s.load_state_dict(torch.load(self.opt['path']['resume_state']), strict=True)
296 | else:
297 | self.load_network()
298 | # self.netG_t.load_state_dict(torch.load(self.opt['path']['resume_state']), strict=True)
299 | if opt['dist']:
300 | self.netG_s = DDP(self.netG_s, device_ids=[self.local_rank], output_device=self.local_rank,find_unused_parameters=True)
301 | self.netG_t = DDP(self.netG_t, device_ids=[self.local_rank], output_device=self.local_rank,find_unused_parameters=True)
302 | for p in self.netG_t.parameters():
303 | p.requires_grad_(False)
304 | self.netG_t.eval()
305 | self.netG_t.CD = opt['CD']
306 | self.netG_s.CD = opt['CD']
307 |
308 | # self.print_network()
309 |
310 | # define ema
311 | self.ema_decay = opt['train']["ema_scheduler"]["ema_decay"]
312 | if self.opt['dist']:
313 | self.ema_student = EMA(
314 | self.netG_s.module,
315 | decay = self.ema_decay, # exponential moving average factor
316 | )
317 | else:
318 | self.ema_student = EMA(
319 | self.netG_s,
320 | decay = self.ema_decay, # exponential moving average factor
321 | )
322 |
323 | self.ema_student.register()
324 |
325 | def feed_data(self, data):
326 |
327 |         dic = {}
328 | 
329 |         if self.opt['dist']:
330 |             # under DDP, tensors go straight to this rank's device
331 | dic['LQ'] = data['LQ'].to(self.local_rank)
332 | dic['GT'] = data['GT'].to(self.local_rank)
333 | self.data = dic
334 | else:
335 | dic['LQ'] = data['LQ']
336 | dic['GT'] = data['GT']
337 |
338 | self.data = self.set_device(dic)
339 |
340 | def optimize_parameters(self):
341 |
342 | self.optG.zero_grad()
343 | if self.opt['dist']:
344 | l_pd = self.netG_t(self.data, self.netG_s.module, lpips_func=self.lpips)
345 | else:
346 | l_pd = self.netG_t(self.data, self.netG_s, lpips_func=self.lpips)
347 | # print("to be debug")
348 | # import pdb; pdb.set_trace()
349 | #
350 |
351 |
352 | loss = l_pd
353 | # print(l_pd)
354 | # import pdb; pdb.set_trace()
355 | loss.backward()
356 | torch.nn.utils.clip_grad_norm_(self.netG_s.parameters(), 1)
357 | self.optG.step()
358 | self.scheduler.step()
359 | # print( self.optG.param_groups[0]['lr'])
360 | self.ema_student.update()
361 | # set log
362 | self.log_dict['total_loss'] = loss.item()
363 |
364 |
365 |
366 |
367 | def test(self, continous=False, stride=1):
368 | self.ema_student.apply_shadow() # apply shadow weights here
369 | self.netG_s.eval()
370 | with torch.no_grad():
371 | if isinstance(self.netG_s, nn.DataParallel):
372 | self.SR = self.netG_s.module.super_resolution(
373 | self.data['LQ'], continous, stride)
374 |
375 | else:
376 | if self.opt['dist']:
377 | self.SR = self.netG_s.module.super_resolution(self.data['LQ'], continous, stride)
378 | else:
379 | self.SR = self.netG_s.super_resolution(self.data['LQ'], continous, stride)
380 |         self.ema_student.restore()  # restore the online (non-EMA) weights
381 |
382 | self.netG_s.train()
383 |
384 | def sample(self, batch_size=1, continous=False):
385 | self.ema_student.apply_shadow() # apply shadow weights here
386 | self.netG_s.eval()
387 | with torch.no_grad():
388 | if isinstance(self.netG_s, nn.DataParallel):
389 | self.SR = self.netG_s.module.sample(batch_size, continous)
390 | else:
391 | self.SR = self.netG_s.sample(batch_size, continous)
392 |         self.ema_student.restore()  # restore the online (non-EMA) weights
393 | self.netG_s.train()
394 |
395 | def set_loss(self):
396 | if isinstance(self.netG_s, nn.DataParallel):
397 | self.netG_s.module.set_loss(self.device)
398 | else:
399 | self.netG_s.set_loss(self.device)
400 |
401 |
402 | def get_current_log(self):
403 | return self.log_dict
404 |
405 | def get_current_visuals(self, need_LR=True, sample=False):
406 | out_dict = OrderedDict()
407 | if sample:
408 | out_dict['SAM'] = self.SR.detach().float().cpu()
409 | else:
410 | out_dict['HQ'] = self.SR.detach().float().cpu()
411 | out_dict['INF'] = self.data['LQ'].detach().float().cpu()
412 | out_dict['GT'] = self.data['GT'].detach()[0].float().cpu()
413 | if need_LR and 'LR' in self.data:
414 | out_dict['LQ'] = self.data['LQ'].detach().float().cpu()
415 | else:
416 | out_dict['LQ'] = out_dict['INF']
417 | return out_dict
418 |
419 | def print_network(self):
420 | s, n = self.get_network_description(self.netG_s)
421 | if isinstance(self.netG_s, nn.DataParallel):
422 | net_struc_str = '{} - {}'.format(self.netG_s.__class__.__name__,
423 | self.netG_s.module.__class__.__name__)
424 | else:
425 | net_struc_str = '{}'.format(self.netG_s.__class__.__name__)
426 |
427 | logger.info(s)
428 |
429 | logger.info(
430 | 'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
431 |
432 | def save_network(self, distill_step, epoch, iter_step, psnr, ssim, lpips):
433 | save_root = os.path.join(self.opt['path']['checkpoint'], 'num_step_{}'.format(distill_step))
434 | os.makedirs(save_root, exist_ok=True)
435 | # gen_path = os.path.join(save_root, 'P{:.4e}_S{:.4e}_I{}_E{}_gen.pth'.format(psnr, ssim, iter_step, epoch))
436 | # opt_path = os.path.join(save_root, 'P{:.4e}_S{:.4e}_I{}_E{}_opt.pth'.format(psnr, ssim, iter_step, epoch))
437 | ema_path = os.path.join(save_root, 'psnr{:.4f}_ssim{:.4f}_lpips{:.4f}_I{}_E{}_gen_ema.pth'.format(psnr, ssim, lpips, iter_step, epoch))
438 |
439 | # gen
440 | # network = self.netG_s
441 | # if isinstance(self.netG_s, nn.DataParallel):
442 | # network = network.module
443 | # state_dict = network.state_dict()
444 | # for key, param in state_dict.items():
445 | # state_dict[key] = param.cpu()
446 | # torch.save(state_dict, gen_path)
447 |
448 | # opt
449 |
450 |
451 |
452 | # ema
453 | self.ema_student.apply_shadow()
454 | network = self.ema_student.model
455 | if isinstance(self.ema_student.model, nn.DataParallel):
456 | network = network.module
457 | ema_ckpt = network.state_dict()
458 | for key, param in ema_ckpt.items():
459 | ema_ckpt[key] = param.cpu()
460 | torch.save(ema_ckpt, ema_path)
461 | self.ema_student.restore()
462 | # logger.info(
463 | # 'Saved model in [{:s}] ...'.format(gen_path))
464 | logger.info(
465 | 'Saved model in [{:s}] ...'.format(ema_path))
466 | return ema_path # gen_path
467 |
468 | def load_network(self):
469 | load_path = self.opt['path']['resume_state']
470 | if load_path is not None:
471 | logger.info(
472 | 'Loading pretrained model for G [{:s}] ...'.format(load_path))
473 | gen_path = '{}'.format(load_path)
474 |
475 | # gen
476 | networks = [self.netG_t, self.netG_s]
477 | for network in networks:
478 | if isinstance(network, nn.DataParallel):
479 | network = network.module
480 | ckpt = torch.load(gen_path)
481 |
482 | current_state_dict = network.state_dict()
483 |                 for name, param in ckpt.items():
484 |                     # skip the noise-schedule buffers, as above
485 |                     if name in skip_para:
486 |                         continue
487 |                     current_state_dict[name] = param
488 | 
489 |
490 | network.load_state_dict(current_state_dict, strict=False)
491 |
492 | if self.opt['phase'] == 'train':
493 |
494 | self.begin_step = 0
495 | self.begin_epoch = 0
496 |
--------------------------------------------------------------------------------
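A standalone mirror of `get_scheduler`'s 'step' branch driven by a dict-style opt (the values here are illustrative; the real ones come from the YAML configs):

```python
import torch
from torch.optim import lr_scheduler

opt = {'train': {'optimizer': {'lr_policy': 'step', 'lr': 1e-4,
                               'lr_decay_iters': 2}}}
param = torch.nn.Parameter(torch.zeros(1))
optim = torch.optim.Adam([param], lr=opt['train']['optimizer']['lr'])
sched = lr_scheduler.StepLR(
    optim, step_size=opt['train']['optimizer']['lr_decay_iters'], gamma=0.8)

for it in range(5):
    optim.step()       # no gradients here; Adam skips params without grads
    sched.step()
    print(it, optim.param_groups[0]['lr'])  # lr decays by 0.8 every 2 steps
```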
/model/networks.py:
--------------------------------------------------------------------------------
1 | import functools
2 | import logging
3 | import torch
4 | import torch.nn as nn
5 | from torch.nn import init
6 | from torch.nn import modules
7 | logger = logging.getLogger('base')
8 | ####################
9 | # initialize
10 | ####################
11 |
12 |
13 | def weights_init_normal(m, std=0.02):
14 | classname = m.__class__.__name__
15 | if classname.find('Conv') != -1:
16 | init.normal_(m.weight.data, 0.0, std)
17 | if m.bias is not None:
18 | m.bias.data.zero_()
19 | elif classname.find('Linear') != -1:
20 | init.normal_(m.weight.data, 0.0, std)
21 | if m.bias is not None:
22 | m.bias.data.zero_()
23 | elif classname.find('BatchNorm2d') != -1:
24 |         init.normal_(m.weight.data, 1.0, std)  # BatchNorm affine weights are also normally initialized
25 | init.constant_(m.bias.data, 0.0)
26 |
27 |
28 | def weights_init_kaiming(m, scale=1):
29 | classname = m.__class__.__name__
30 | if classname.find('Conv2d') != -1:
31 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
32 | m.weight.data *= scale
33 | if m.bias is not None:
34 | m.bias.data.zero_()
35 | elif classname.find('Linear') != -1:
36 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
37 | m.weight.data *= scale
38 | if m.bias is not None:
39 | m.bias.data.zero_()
40 | elif classname.find('BatchNorm2d') != -1:
41 | init.constant_(m.weight.data, 1.0)
42 | init.constant_(m.bias.data, 0.0)
43 |
44 |
45 | def weights_init_orthogonal(m):
46 | classname = m.__class__.__name__
47 | if classname.find('Conv') != -1:
48 | init.orthogonal_(m.weight.data, gain=1)
49 | if m.bias is not None:
50 | m.bias.data.zero_()
51 | elif classname.find('Linear') != -1:
52 | init.orthogonal_(m.weight.data, gain=1)
53 | if m.bias is not None:
54 | m.bias.data.zero_()
55 | elif classname.find('BatchNorm2d') != -1:
56 | init.constant_(m.weight.data, 1.0)
57 | init.constant_(m.bias.data, 0.0)
58 |
59 |
60 | def init_weights(net, init_type='kaiming', scale=1, std=0.02):
61 | # scale for 'kaiming', std for 'normal'.
62 | logger.info('Initialization method [{:s}]'.format(init_type))
63 | if init_type == 'normal':
64 | weights_init_normal_ = functools.partial(weights_init_normal, std=std)
65 | net.apply(weights_init_normal_)
66 | elif init_type == 'kaiming':
67 | weights_init_kaiming_ = functools.partial(
68 | weights_init_kaiming, scale=scale)
69 | net.apply(weights_init_kaiming_)
70 | elif init_type == 'orthogonal':
71 | net.apply(weights_init_orthogonal)
72 | else:
73 | raise NotImplementedError(
74 | 'initialization method [{:s}] not implemented'.format(init_type))
75 |
76 |
77 | ####################
78 | # define network
79 | ####################
80 |
81 |
82 | # Generator
83 | def define_G(opt, student=False):
84 | model_opt = opt['model']
85 | print(model_opt['which_model_G'])
86 | if model_opt['which_model_G'] == 'ddpm':
87 | from .ddpm_modules import diffusion, unet
88 | if ('norm_groups' not in model_opt['unet']) or model_opt['unet']['norm_groups'] is None:
89 |             model_opt['unet']['norm_groups'] = 32
90 |
91 |
92 | model = unet.UNet(
93 | in_channel=model_opt['unet']['in_channel'],
94 | out_channel=model_opt['unet']['out_channel'],
95 | norm_groups=model_opt['unet']['norm_groups'],
96 | inner_channel=model_opt['unet']['inner_channel'],
97 | channel_mults=model_opt['unet']['channel_multiplier'],
98 | attn_res=model_opt['unet']['attn_res'],
99 | res_blocks=model_opt['unet']['res_blocks'],
100 | dropout=model_opt['unet']['dropout'],
101 | image_size=model_opt['diffusion']['image_size']
102 | )
103 | '''
104 | if opt['freq_aware']:
105 | model = unet.Free_UNet(
106 | in_channel=model_opt['unet']['in_channel'],
107 | out_channel=model_opt['unet']['out_channel'],
108 | norm_groups=model_opt['unet']['norm_groups'],
109 | inner_channel=model_opt['unet']['inner_channel'],
110 | channel_mults=model_opt['unet']['channel_multiplier'],
111 | attn_res=model_opt['unet']['attn_res'],
112 | res_blocks=model_opt['unet']['res_blocks'],
113 | dropout=model_opt['unet']['dropout'],
114 | image_size=model_opt['diffusion']['image_size'],
115 | b1=opt['freq_awareUNet']['b1'],
116 | b2=opt['freq_awareUNet']['b2'],
117 | s1=opt['freq_awareUNet']['s1'],
118 | s2=opt['freq_awareUNet']['s2']
119 | )
120 | '''
121 |
122 | # print(model_opt['beta_schedule']['train']['n_timestep'])
123 | if student:
124 | netG = diffusion.GaussianDiffusion(
125 | model,
126 | image_size=model_opt['diffusion']['image_size'],
127 | num_timesteps=model_opt['beta_schedule']['train']['n_timestep'] // 2,
128 | time_scale=model_opt['beta_schedule']['train']['time_scale'] * 2,
129 | channels=model_opt['diffusion']['channels'],
130 |             w_gt=model_opt['diffusion']['w_gt'],
131 |             w_snr=model_opt['diffusion']['w_snr'],
132 |             w_str=model_opt['diffusion']['w_str'],
133 |             w_lpips=model_opt['diffusion']['w_lpips'],
134 | loss_type='l1',
135 | conditional=model_opt['diffusion']['conditional'],
136 | schedule_opt=model_opt['beta_schedule']['train'])
137 | else:
138 |
139 | netG = diffusion.GaussianDiffusion(
140 | model,
141 | image_size=model_opt['diffusion']['image_size'],
142 |             num_timesteps=model_opt['beta_schedule']['train']['n_timestep'],
143 | time_scale=model_opt['beta_schedule']['train']['time_scale'],
144 | channels=model_opt['diffusion']['channels'],
145 |             w_gt=model_opt['diffusion']['w_gt'],
146 |             w_snr=model_opt['diffusion']['w_snr'],
147 |             w_str=model_opt['diffusion']['w_str'],
148 |             w_lpips=model_opt['diffusion']['w_lpips'],
149 | loss_type='l1',
150 | conditional=model_opt['diffusion']['conditional'],
151 | schedule_opt=model_opt['beta_schedule']['train'])
152 |
153 | if opt['phase'] == 'train':
154 | # init_weights(netG, init_type='kaiming', scale=0.1)
155 | init_weights(netG, init_type='orthogonal')
156 | if opt['gpu_ids'] and opt['distributed']:
157 | assert torch.cuda.is_available()
158 | netG = nn.DataParallel(netG)
159 | return netG
160 |
161 |
162 | # Generator
163 | def define_GGG(opt):
164 | model_opt = opt['model']
165 | print(model_opt['which_model_G'])
166 | if model_opt['which_model_G'] == 'ddpm':
167 | from .ddpm_modules import diffusion, unet
168 | if ('norm_groups' not in model_opt['unet']) or model_opt['unet']['norm_groups'] is None:
169 |             model_opt['unet']['norm_groups'] = 32
170 | model = unet.UNet(
171 | in_channel=model_opt['unet']['in_channel'],
172 | out_channel=model_opt['unet']['out_channel'],
173 | norm_groups=model_opt['unet']['norm_groups'],
174 | inner_channel=model_opt['unet']['inner_channel'],
175 | channel_mults=model_opt['unet']['channel_multiplier'],
176 | attn_res=model_opt['unet']['attn_res'],
177 | res_blocks=model_opt['unet']['res_blocks'],
178 | dropout=model_opt['unet']['dropout'],
179 | image_size=model_opt['diffusion']['image_size']
180 | )
181 | netGVar = diffusion.GaussianDiffusion(
182 | model,
183 | image_size=model_opt['diffusion']['image_size'],
184 | channels=model_opt['diffusion']['channels'],
185 | loss_type='l1',
186 | conditional=model_opt['diffusion']['conditional'],
187 | schedule_opt=model_opt['beta_schedule']['train']
188 | )
189 | if opt['phase'] == 'train':
190 | # init_weights(netG, init_type='kaiming', scale=0.1)
191 | init_weights(netGVar, init_type='orthogonal')
192 | if opt['gpu_ids'] and opt['distributed']:
193 | assert torch.cuda.is_available()
194 | netGVar = nn.DataParallel(netGVar)
195 | return netGVar
--------------------------------------------------------------------------------
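
Note the only difference between the student and teacher branches of `define_G`: the student gets half the teacher's `n_timestep` at twice its `time_scale`, so both schedules span the same diffusion horizon. A toy check of that invariant (values are illustrative, mirroring the YAML keys used above):

```python
# Toy check: halving n_timestep while doubling time_scale keeps the
# covered horizon n_timestep * time_scale unchanged.
beta_schedule = {'n_timestep': 16, 'time_scale': 32}   # illustrative values

teacher = (beta_schedule['n_timestep'], beta_schedule['time_scale'])
student = (teacher[0] // 2, teacher[1] * 2)

assert teacher[0] * teacher[1] == student[0] * student[1]
print('teacher', teacher, '-> student', student)       # (16, 32) -> (8, 64)
```
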
/options/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/options/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/options/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/options/__pycache__/options.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/options/__pycache__/options.cpython-38.pyc
--------------------------------------------------------------------------------
/options/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path as osp
3 | import logging
4 | import yaml
5 | from utils.util import OrderedYaml
6 |
7 | Loader, Dumper = OrderedYaml()
8 |
9 |
10 | def parse(opt_path, is_train=True):
11 | with open(opt_path, mode='r') as f:
12 | opt = yaml.load(f, Loader=Loader)
13 | # export CUDA_VISIBLE_DEVICES
14 | gpu_list = ','.join(str(x) for x in opt.get('gpu_ids', []))
15 | # os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
16 | # print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
17 | opt['is_train'] = is_train
18 |
19 | # datasets
20 | for phase, dataset in opt['datasets'].items():
21 | phase = phase.split('_')[0]
22 | dataset['phase'] = phase
23 |
24 | is_lmdb = False
25 | if dataset.get('dataroot_GT', None) is not None:
26 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
27 | if dataset['dataroot_GT'].endswith('lmdb'):
28 | is_lmdb = True
29 | if dataset.get('dataroot_LQ', None) is not None:
30 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
31 | if dataset['dataroot_LQ'].endswith('lmdb'):
32 | is_lmdb = True
33 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
34 |
35 | # relative learning rate
36 | if 'train' in opt:
37 | niter = opt['train']['niter']
38 | if 'T_period_rel' in opt['train']:
39 | opt['train']['T_period'] = [int(x * niter) for x in opt['train']['T_period_rel']]
40 | if 'restarts_rel' in opt['train']:
41 | opt['train']['restarts'] = [int(x * niter) for x in opt['train']['restarts_rel']]
42 | if 'lr_steps_rel' in opt['train']:
43 | opt['train']['lr_steps'] = [int(x * niter) for x in opt['train']['lr_steps_rel']]
44 | if 'lr_steps_inverse_rel' in opt['train']:
45 | opt['train']['lr_steps_inverse'] = [int(x * niter) for x in opt['train']['lr_steps_inverse_rel']]
46 | print(opt['train'])
47 |
48 | return opt
49 |
50 |
51 | def dict2str(opt, indent_l=1):
52 | '''dict to string for logger'''
53 | msg = ''
54 | for k, v in opt.items():
55 | if isinstance(v, dict):
56 | msg += ' ' * (indent_l * 2) + k + ':[\n'
57 | msg += dict2str(v, indent_l + 1)
58 | msg += ' ' * (indent_l * 2) + ']\n'
59 | else:
60 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
61 | return msg
62 |
63 |
64 | class NoneDict(dict):
65 | def __missing__(self, key):
66 | return None
67 |
68 |
69 | # convert to NoneDict, which return None for missing key.
70 | def dict_to_nonedict(opt):
71 | if isinstance(opt, dict):
72 | new_opt = dict()
73 | for key, sub_opt in opt.items():
74 | new_opt[key] = dict_to_nonedict(sub_opt)
75 | return NoneDict(**new_opt)
76 | elif isinstance(opt, list):
77 | return [dict_to_nonedict(sub_opt) for sub_opt in opt]
78 | else:
79 | return opt
80 |
81 |
82 | def check_resume(opt, resume_iter):
83 | '''Check resume states and pretrain_model paths'''
84 | logger = logging.getLogger('base')
85 | if opt['path']['resume_state']:
86 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
87 | 'pretrain_model_D', None) is not None:
88 | logger.warning('pretrain_model path will be ignored when resuming training.')
89 |
90 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
91 | '{}_G.pth'.format(resume_iter))
92 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
93 | if 'gan' in opt['model']:
94 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
95 | '{}_D.pth'.format(resume_iter))
96 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
97 |
--------------------------------------------------------------------------------
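
A minimal usage sketch for this module, assuming a YAML config at `config/lolv1.yml` that contains the keys the parser above expects (`gpu_ids`, a `datasets` section, and optionally `train`):

```python
# Sketch: load and post-process a config the way test.py does.
import options.options as option

opt = option.parse('config/lolv1.yml', is_train=True)  # path is an assumption
opt = option.dict_to_nonedict(opt)  # missing keys now return None instead of raising

print(opt['is_train'])          # True
print(opt['no_such_key'])       # None, courtesy of NoneDict.__missing__
```
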
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python==4.5.2.54
2 | PyYAML==6.0
3 | natsort==8.1.0
4 | scikit-image==0.18.1
5 | lpips==0.1.4
6 | kmeans_pytorch
7 | scikit-learn==1.0
8 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import basename
3 | import math
4 | import argparse
5 | import random
6 | import logging
7 | import cv2
8 | import sys
9 | import numpy as np
10 | import torch
11 | import torch.distributed as dist
12 | import torch.multiprocessing as mp
13 | import torch.nn as nn
14 |
15 | import options.options as option
16 | from utils import util
17 | from data import create_dataloader
18 | from data.LoL_dataset import LOLv1_Dataset, LOLv2_Dataset
19 | import torchvision.transforms as T
20 | import lpips
21 | import model as Model
22 | import core.logger as Logger
23 | import core.metrics as Metrics
24 | from torchvision import transforms
25 |
26 |
27 | transform = transforms.Lambda(lambda t: (t * 2) - 1)
28 |
29 | def main():
30 |
31 |     parser = argparse.ArgumentParser()
32 |     parser.add_argument('--dataset', type=str, help='Path to option YAML file.',
33 |                         default='./config/lolv1.yml')
35 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
36 | help='job launcher')
37 | parser.add_argument('--local_rank', type=int, default=0)
38 | parser.add_argument('--tfboard', action='store_true')
39 |
40 |
41 | parser.add_argument('-c', '--config', type=str, default='config/lolv1_test.json',
42 | help='JSON file for configuration')
43 | parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'],
44 | help='Run either train(training) or val(generation)', default='train')
45 | parser.add_argument('-gpu', '--gpu_ids', type=str, default="0")
46 | parser.add_argument('-debug', '-d', action='store_true')
47 | parser.add_argument('-log_eval', action='store_true')
48 |
49 | # search noise schedule
50 | parser.add_argument('--brutal_search', action='store_true')
51 | parser.add_argument('--noise_start', type=float, default=9e-4)
52 | parser.add_argument('--noise_end', type=float, default=8.5e-1)
53 | parser.add_argument('--n_timestep', type=int, default=16)
54 |
55 | parser.add_argument('--w_str', type=float, default=0.)
56 | parser.add_argument('--w_snr', type=float, default=0.)
57 | parser.add_argument('--w_gt', type=float, default=0.1)
58 | parser.add_argument('--w_lpips', type=float, default=0.1)
59 |
60 | parser.add_argument('--stride', type=int, default=1)
61 |
62 |
63 | # for freq_aware
64 | parser.add_argument('--freq_aware', action='store_true')
65 | parser.add_argument('--b1', type=float, default=1.6)
66 | parser.add_argument('--b2', type=float, default=1.6)
67 | parser.add_argument('--s1', type=float, default=0.9)
68 | parser.add_argument('--s2', type=float, default=0.9)
69 |
70 | args = parser.parse_args()
71 | opt = Logger.parse(args)
72 | opt = Logger.dict_to_nonedict(opt)
73 | opt_dataset = option.parse(args.dataset, is_train=True)
74 |
75 |
76 |
77 |     os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu_ids
78 |
79 | opt['phase'] = 'test'
80 | opt['distill'] = False
81 | opt['uncertainty_train'] = False
82 |
83 | #### distributed training settings
84 | opt['dist'] = False
85 | rank = -1
86 | print('Disabled distributed training.')
87 |
88 | #### mkdir and loggers
89 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
90 | # config loggers. Before it, the log will not work
91 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
92 | screen=True, tofile=True)
93 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
94 | screen=True, tofile=True)
95 | logger = logging.getLogger('base')
96 | logger.info(option.dict2str(opt))
97 |
98 | # convert to NoneDict, which returns None for missing keys
99 | opt = option.dict_to_nonedict(opt)
100 |
101 | #### seed
102 | seed = opt['seed']
103 | if seed is None:
104 | seed = random.randint(1, 10000)
105 | if rank <= 0:
106 | logger.info('Seed: {}'.format(seed))
107 | util.set_random_seed(seed)
108 |
109 | torch.backends.cudnn.benchmark = True
110 | # torch.backends.cudnn.deterministic = True
111 |
112 | #### create train and val dataloader
113 | if opt_dataset['dataset'] == 'LOLv1':
114 | dataset_cls = LOLv1_Dataset
115 | elif opt_dataset['dataset'] == 'LOLv2':
116 | dataset_cls = LOLv2_Dataset
117 |
118 | else:
119 | raise NotImplementedError()
120 |
121 | for phase, dataset_opt in opt_dataset['datasets'].items():
122 | if phase == 'val':
123 | val_set = dataset_cls(opt=dataset_opt, train=False, all_opt=opt_dataset)
124 | val_loader = create_dataloader(val_set, dataset_opt, opt_dataset, None)
125 |
126 | # opt["model"]['beta_schedule']["train"]["time_scale"] = 1
127 |
128 | opt["model"]["diffusion"]["w_snr"] = args.w_snr
129 | opt["model"]["diffusion"]["w_str"] = args.w_str
130 | opt["model"]["diffusion"]["w_gt"] = args.w_gt
131 |
132 | # model
133 | diffusion = Model.create_model(opt)
134 | logger.info('Initial Model Finished')
135 |
136 |
137 |
138 | loss_fn_vgg = lpips.LPIPS(net='vgg').cuda()
139 | result_path = '{}'.format(opt['path']['results'])
140 | result_path_gt = result_path+'/gt/'
141 | result_path_out = result_path+'/output/'
142 | result_path_input = result_path+'/input/'
143 | os.makedirs(result_path_gt, exist_ok=True)
144 | os.makedirs(result_path_out, exist_ok=True)
145 | os.makedirs(result_path_input, exist_ok=True)
146 |
147 | #diffusion.set_new_noise_schedule(
148 | #opt['model']['beta_schedule']['val'], schedule_phase='val')
149 |
150 |
151 | logger_val = logging.getLogger('val') # validation logger
152 |
153 | avg_psnr = 0.0
154 | avg_ssim = 0.0
155 | avg_lpips = 0.0
156 | idx = 0
157 | lpipss = []
158 |
159 | for val_data in val_loader:
160 |
161 | idx += 1
162 | diffusion.feed_data(val_data)
163 | diffusion.test(continous=False)
164 |
165 | visuals = diffusion.get_current_visuals()
166 |
167 | normal_img = Metrics.tensor2img(visuals['HQ'])
168 | if normal_img.shape[0] != normal_img.shape[1]: # lolv1 and lolv2-real
169 | normal_img = normal_img[8:408, 4:604,:]
170 | gt_img = Metrics.tensor2img(visuals['GT'])
171 | ll_img = Metrics.tensor2img(visuals['LQ'])
172 |
173 | img_mode = 'single'
174 | if img_mode == 'single':
175 | util.save_img(
176 | gt_img, '{}/{}_gt.png'.format(result_path_gt, idx))
177 | util.save_img(
178 | ll_img, '{}/{}_lq.png'.format(result_path_input, idx))
179 | # util.save_img(
180 | # normal_img, '{}/{}_normal_noadjust.png'.format(result_path, idx))
181 | else:
182 | util.save_img(
183 | gt_img, '{}/{}_gt.png'.format(result_path, idx))
184 | util.save_img(
185 |                     normal_img, '{}/{}_normal_process.png'.format(result_path, idx))
186 | # for i in range(visuals['HQ'].shape[0]):
187 | # util.save_img(Metrics.tensor2img(visuals['HQ'][i]), '{}/{}_{}_normal.png'.format(result_path, idx, i))
188 | # util.save_img(
189 | # Metrics.tensor2img(visuals['HQ'][-1]), '{}/{}_normal.png'.format(result_path, idx))
190 |                 normal_img = Metrics.tensor2img(visuals['HQ'][-1])
191 |
192 |         # Following LLFlow, we adopt KinD's approach to fine-tune the overall brightness,
193 |         # as illustrated in Line 73 (https://github.com/zhangyhuaee/KinD/blob/master/evaluate_LOLdataset.py).
194 | gt_img = gt_img / 255.
195 | normal_img = normal_img / 255.
196 | mean_gray_out = cv2.cvtColor(normal_img.astype(np.float32), cv2.COLOR_BGR2GRAY).mean()
197 | mean_gray_gt = cv2.cvtColor(gt_img.astype(np.float32), cv2.COLOR_BGR2GRAY).mean()
198 | normal_img_adjust = np.clip(normal_img * (mean_gray_gt / mean_gray_out), 0, 1)
199 |
200 | normal_img = (normal_img_adjust * 255).astype(np.uint8)
201 | gt_img = (gt_img * 255).astype(np.uint8)
202 |
203 | psnr = util.calculate_psnr(normal_img, gt_img)
204 | ssim = util.calculate_ssim(normal_img, gt_img)
205 |
206 | normal_img_tensor = torch.tensor(normal_img.astype(np.float32))
207 | gt_img_tensor = torch.tensor(gt_img.astype(np.float32))
208 | normal_img_tensor = normal_img_tensor.permute(2, 0, 1).cuda()
209 | gt_img_tensor = gt_img_tensor.permute(2, 0, 1).cuda()
210 | lpips_scores = loss_fn_vgg(normal_img_tensor, gt_img_tensor).item()
211 |
212 | util.save_img(normal_img, '{}/{}_normal.png'.format(result_path_out, idx))
213 |
214 | # lpips
215 |
216 | # lpips_ = loss_fn_vgg(visuals['HQ'], visuals['GT'])
217 | # lpipss.append(lpips_scores.numpy())
218 |
219 | logger_val.info('### {} cPSNR: {:.4e} cSSIM: {:.4e} cLPIPS: {:.4e}'.format(idx, psnr, ssim, lpips_scores))
220 | avg_ssim += ssim
221 | avg_psnr += psnr
222 | avg_lpips += lpips_scores
223 |
224 | avg_psnr = avg_psnr / idx
225 | avg_ssim = avg_ssim / idx
226 | avg_lpips = avg_lpips / idx
227 |
228 | # log
229 | logger_val.info('# Validation # avgPSNR: {:.4e} avgSSIM: {:.4e} avgLPIPS: {:.4e}'.format(avg_psnr, avg_ssim, avg_lpips))
230 | # logger_val.info(f"n_timestep: {args.n_timestep}, noise_start: {args.noise_start}, noise_end: {args.noise_end}")
231 |
232 |
233 |
234 | if __name__ == '__main__':
235 | main()
236 |
--------------------------------------------------------------------------------
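
The GT-mean brightness adjustment inside the loop above is a single global gain: the output is scaled so its mean gray level matches the ground truth's, then clipped to [0, 1]. Isolated as a small function (a sketch; it operates on float32 BGR images in [0, 1], exactly as in the loop):

```python
# Standalone version of the KinD-style brightness adjustment used in test.py.
import cv2
import numpy as np

def adjust_brightness_to_gt(pred, gt):
    """Scale pred so its mean gray level matches gt's; inputs are float32 BGR in [0, 1]."""
    mean_gray_out = cv2.cvtColor(pred, cv2.COLOR_BGR2GRAY).mean()
    mean_gray_gt = cv2.cvtColor(gt, cv2.COLOR_BGR2GRAY).mean()
    return np.clip(pred * (mean_gray_gt / mean_gray_out), 0, 1)

# toy example: a uniformly dim prediction is rescaled to the GT's brightness
pred = np.full((4, 4, 3), 0.2, dtype=np.float32)
gt = np.full((4, 4, 3), 0.4, dtype=np.float32)
print(adjust_brightness_to_gt(pred, gt).mean())  # ~0.4
```
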
/test.sh:
--------------------------------------------------------------------------------
1 | python test.py --dataset ./config/lolv2_real.yml --config config/lolv2_real_test.json --w_str 0.9 --w_snr 0.2 --w_gt 0.2
2 |
3 |
4 |
--------------------------------------------------------------------------------
/test_unpaired.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import basename
3 | import math
4 | import argparse
5 | import random
6 | import logging
7 | import cv2
8 | import torch
9 | import torchvision.transforms.functional as TF
10 | from PIL import Image
11 | import options.options as option
12 | from utils import util
13 | import torchvision.transforms as T
14 | import model as Model
15 | import core.logger as Logger
16 | import core.metrics as Metrics
17 | import natsort
18 | from torchvision import transforms
19 | from utils.niqe import niqe
20 |
21 | transform = transforms.Lambda(lambda t: (t * 2) - 1)
22 |
23 | def main():
24 | #### options
25 |     parser = argparse.ArgumentParser()
26 |     parser.add_argument('--dataset', type=str, help='Path to option YAML file.',
27 |                         default='./config/dataset.yml')
29 | parser.add_argument('--input', type=str, help='testing the unpaired image',
30 | default='images/unpaired/')
31 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
32 | help='job launcher')
33 | parser.add_argument('--local_rank', type=int, default=0)
34 | parser.add_argument('--tfboard', action='store_true')
35 | parser.add_argument('-c', '--config', type=str, default='config/test_unpaired.json',
36 | help='JSON file for configuration')
37 | parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'],
38 | help='Run either train(training) or val(generation)', default='train')
39 | parser.add_argument('-gpu', '--gpu_ids', type=str, default="0")
40 | parser.add_argument('-debug', '-d', action='store_true')
41 | parser.add_argument('-enable_wandb', action='store_true')
42 | parser.add_argument('-log_wandb_ckpt', action='store_true')
43 | parser.add_argument('-log_eval', action='store_true')
44 |
45 | parser.add_argument('--n_timestep', type=int, default=8)
46 | parser.add_argument('--w_str', type=float, default=0.)
47 | parser.add_argument('--w_snr', type=float, default=0.)
48 | parser.add_argument('--w_gt', type=float, default=0.)
49 | parser.add_argument('--w_lpips', type=float, default=0.)
50 |
51 |
52 | parser.add_argument('--brutal_search', action='store_true')
53 |
54 | # parse configs
55 | args = parser.parse_args()
56 | opt = Logger.parse(args)
57 | # Convert to NoneDict, which return None for missing key.
58 | opt = Logger.dict_to_nonedict(opt)
59 |
60 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
61 |
62 | opt['phase'] = 'test'
63 |
64 | #### distributed training settings
65 | opt['dist'] = False
66 | rank = -1
67 | print('Disabled distributed training.')
68 |
69 | #### mkdir and loggers
70 |     if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
71 |         # config loggers. Before it, the log will not work
72 |         util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO,
73 |                           screen=True, tofile=True)
74 |         util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
75 |                           screen=True, tofile=True)
76 |         logger = logging.getLogger('base')
77 |         logger.info(option.dict2str(opt))
80 |
81 | # convert to NoneDict, which returns None for missing keys
82 | opt = option.dict_to_nonedict(opt)
83 |
84 | #### random seed
85 | seed = opt['train']['manual_seed']
86 | if seed is None:
87 | seed = random.randint(1, 10000)
88 | if rank <= 0:
89 | logger.info('Random seed: {}'.format(seed))
90 | util.set_random_seed(seed)
91 |
92 | torch.backends.cudnn.benchmark = True
93 | # torch.backends.cudnn.deterministic = True
94 |
95 |
96 | # model
97 | diffusion = Model.create_model(opt)
98 | logger.info('Initial Model Finished')
99 |
100 | result_path = '{}'.format(opt['path']['results'])
101 | os.makedirs(result_path, exist_ok=True)
102 |
103 | # diffusion.set_new_noise_schedule(
104 | # opt['model']['beta_schedule']['val'], schedule_phase='val')
105 |
106 | InputPath = args.input
107 | Image_names = natsort.natsorted(os.listdir(InputPath), alg=natsort.ns.PATH)
108 |
109 | ave_niqe = 0.
110 |
111 | for i in range(len(Image_names)):
112 |
113 |         path = os.path.join(InputPath, Image_names[i])
114 | raw_img = Image.open(path).convert('RGB')
115 | img_w = raw_img.size[0]
116 | img_h = raw_img.size[1]
117 | raw_img = transforms.Resize((img_h // 16 * 16, img_w // 16 * 16))(raw_img)
118 |
119 | raw_img = transform(TF.to_tensor(raw_img)).unsqueeze(0).cuda()
120 |
121 | val_data = {}
122 | val_data['LQ'] = raw_img
123 | val_data['GT'] = raw_img
124 | diffusion.feed_data(val_data)
125 | diffusion.test(continous=False)
126 |
127 | visuals = diffusion.get_current_visuals()
128 |
129 | normal_img = Metrics.tensor2img(visuals['HQ'])
130 | # normal_img = cv2.resize(normal_img, (img_w, img_h))
131 | ll_img = Metrics.tensor2img(visuals['LQ'])
132 | niqe_scores = niqe(normal_img)
133 | ave_niqe = ave_niqe + niqe_scores
134 | llie_img_mode = 'single'
135 | if llie_img_mode == 'single':
136 | # util.save_img(ll_img, '{}/{}_input.png'.format(result_path, idx))
137 | util.save_img(
138 | normal_img, '{}/{}_normal.png'.format(result_path, i+1))
139 | else:
140 |             util.save_img(
141 |                 normal_img, '{}/{}_normal_process.png'.format(result_path, i))
142 | util.save_img(
143 | Metrics.tensor2img(visuals['HQ'][-1]), '{}/{}_normal.png'.format(result_path, i))
144 | normal_img = Metrics.tensor2img(visuals['HQ'][-1])
145 | logger.info('CNIQE: {} on {}'.format(niqe_scores, Image_names[i]))
146 | ave_niqe = ave_niqe / len(Image_names)
147 | logger.info('NIQE: {} on {}'.format(ave_niqe, InputPath))
148 |
149 |
150 |
151 | if __name__ == '__main__':
152 | main()
153 | print("finish!")
154 |
--------------------------------------------------------------------------------
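
Two details of the unpaired pipeline above are worth isolating: each image is snapped down to the nearest multiple of 16 per side (so the UNet's downsampling stages divide evenly) and mapped from [0, 1] to [-1, 1] before being fed as both `LQ` and the dummy `GT`. The same preprocessing, sketched on a stand-in image:

```python
# Sketch of the unpaired-input preprocessing in test_unpaired.py.
import torchvision.transforms.functional as TF
from PIL import Image
from torchvision import transforms

raw_img = Image.new('RGB', (605, 403))  # stand-in for Image.open(path).convert('RGB')
img_w, img_h = raw_img.size
raw_img = transforms.Resize((img_h // 16 * 16, img_w // 16 * 16))(raw_img)

x = TF.to_tensor(raw_img) * 2 - 1       # values in [-1, 1]
print(x.shape)                          # torch.Size([3, 400, 592])
```
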
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | from os.path import basename
3 | import math
4 | import argparse
5 | import logging
6 | import cv2
7 | import sys
8 | import numpy as np
9 | import torch
10 | import torch.distributed as dist
11 | import torch.multiprocessing as mp
12 | import torch.nn as nn
13 | from torchvision import transforms
14 | import options.options as option
15 | from utils import util
16 | from data import create_dataloader
17 | import data as Data
18 | from data.LoL_dataset import LOLv1_Dataset, LOLv2_Dataset
19 | # from data.SDSD_image_dataset import Dataset_SDSDImage  # dataset module not included in this release
20 | # from data.SID import ImageDataset2  # dataset module not included in this release
21 | import torchvision.transforms as T
22 | import model as Model
23 | import core.logger as Logger
24 | import core.metrics as Metrics
25 | import random
26 | import lpips
27 |
28 | import pdb
29 |
30 |
31 |
32 |
33 | def init_dist(backend='nccl', **kwargs):
34 | """initialization for distributed training"""
35 | if mp.get_start_method(allow_none=True) != 'spawn':
36 | mp.set_start_method('spawn')
37 | rank = int(os.environ['RANK'])
38 | num_gpus = torch.cuda.device_count()
39 | torch.cuda.set_device(rank % num_gpus)
40 | dist.init_process_group(backend=backend, **kwargs)
41 |
42 |
43 | def main():
44 |
45 | parser = argparse.ArgumentParser()
46 |     parser.add_argument('--dataset', type=str, help='Path to option YAML file.',
47 |                         default='./config/lolv2_real.yml')
48 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
49 | help='job launcher')
50 | parser.add_argument('--local_rank', type=int, default=0)
51 |
52 | parser.add_argument('-c', '--config', type=str, default='config/lolv1_train.json',
53 | help='JSON file for configuration')
54 | parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'],
55 | help='Run either train(training) or val(generation)', default='train')
56 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
57 | parser.add_argument('-debug', '-d', action='store_true')
58 | parser.add_argument('-log_eval', action='store_true')
59 | parser.add_argument('-uncertainty', action='store_true')
60 |
61 | # for ablation
62 | parser.add_argument('--ablation', action='store_true')
63 | parser.add_argument('--w_str', type=float, default=0.2)
64 | parser.add_argument('--w_snr', type=float, default=0.9)
65 | parser.add_argument('--w_gt', type=float, default=0.2)
66 | parser.add_argument('--w_lpips', type=float, default=0.2)
67 |
68 | parser.add_argument('--progressive', action='store_true')
69 | parser.add_argument('--CD', action='store_true')
70 |
71 |
72 |
73 | parser.add_argument('--brutal_search', action='store_true')
74 |
75 | # ema config
76 | parser.add_argument('--ema_decay', type=float, default=0.999)
77 |
78 | # parse configs
79 | args = parser.parse_args()
80 | opt = Logger.parse(args)
81 | # Convert to NoneDict, which return None for missing key.
82 | opt = Logger.dict_to_nonedict(opt)
83 | opt_dataset = option.parse(args.dataset, is_train=True)
84 |
85 | if args.ablation:
86 | opt["model"]["diffusion"]["w_snr"] = args.w_snr
87 | opt["model"]["diffusion"]["w_str"] = args.w_str
88 | opt["model"]["diffusion"]["w_gt"] = args.w_gt
89 | opt["model"]["diffusion"]["w_lpips"] = args.w_lpips
90 | if args.CD:
91 | opt["CD"] = True
92 | if args.launcher == 'none': # disabled distributed training
93 | opt['dist'] = False
94 | rank = -1
95 | print('Disabled distributed training.')
96 | else:
97 | opt['dist'] = True
98 |         init_dist()  # sets the CUDA device and calls dist.init_process_group('nccl')
103 | world_size = torch.distributed.get_world_size()
104 | rank = torch.distributed.get_rank()
105 | device = torch.device("cuda", rank)
106 |
107 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
108 |
109 | # config loggers. Before it, the log will not work
110 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
111 | screen=True, tofile=True)
112 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
113 | screen=True, tofile=True)
114 | logger = logging.getLogger('base')
115 | # import pdb; pdb.set_trace()
116 | logger.info(option.dict2str(opt))
117 |
118 | # import pdb; pdb.set_trace()
119 |
120 | # tensorboard logger
121 | if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
122 |             version = float(torch.__version__[0:3])
123 |             if version >= 1.1:  # PyTorch 1.1+
124 |                 from tensorboardX import SummaryWriter
125 |             else:
126 |                 logger.info(
127 |                     'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
128 |                 from tensorboardX import SummaryWriter
129 |             conf_name = basename(args.dataset).replace(".yml", "")
135 | exp_dir = opt['path']['experiments_root']
136 | log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train')
137 | log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid')
138 | tb_logger_train = SummaryWriter(log_dir=log_dir_train)
139 | tb_logger_valid = SummaryWriter(log_dir=log_dir_valid)
140 | else:
141 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
142 | logger = logging.getLogger('base')
143 |
144 | # convert to NoneDict, which returns None for missing keys
145 | opt = option.dict_to_nonedict(opt)
146 |
147 | #### random seed
148 | seed = opt['seed']
149 | if seed is None:
150 | seed = random.randint(1, 10000)
151 | if rank <= 0:
152 | logger.info('Random seed: {}'.format(seed))
153 | util.set_random_seed(seed)
154 |
155 | torch.backends.cudnn.benchmark = True
156 | torch.backends.cudnn.deterministic = False
157 | torch.backends.cudnn.allow_tf32 = True
158 |
159 | #### create train and val dataloader
160 |     if opt_dataset['dataset'] == 'LOLv1':
161 |         dataset_cls = LOLv1_Dataset
162 |     elif opt_dataset['dataset'] == 'LOLv2':
163 |         dataset_cls = LOLv2_Dataset
164 |     # 'SDSD_indoor' (Dataset_SDSDImage) and 'SID' (ImageDataset2) would go here,
165 |     # but their dataset modules are not included in this release
166 |     else:
167 |         raise NotImplementedError()
168 |     # progressive distillation schedule: the step count halves at every stage
169 |     # while the time scale doubles, keeping the covered diffusion horizon fixed
170 |     PD_steps = [16, 8, 4, 2, 1]
171 |     time_scale = [i * 32 for i in [1, 2, 4, 8, 16]]
183 |
184 | for phase, dataset_opt in opt_dataset['datasets'].items():
185 | if phase == 'train':
186 | train_set = dataset_cls(opt=dataset_opt, train=True, all_opt=opt_dataset)
187 | train_loader = create_dataloader(train_set, dataset_opt, opt_dataset, None)
188 | elif phase == 'val':
189 | val_set = dataset_cls(opt=dataset_opt, train=False, all_opt=opt_dataset)
190 | val_loader = create_dataloader(val_set, dataset_opt, opt_dataset, None)
191 |
192 | # model
193 | resume_state = opt["path"]["resume_state"]
194 | lpips_func = lpips.LPIPS(net='vgg').cuda()
195 |
196 | for i in range(len(PD_steps)):
197 | opt["model"]['beta_schedule']["train"]["n_timestep"] = PD_steps[i]
198 | opt["model"]['beta_schedule']["val"]["n_timestep"] = PD_steps[i+1]
199 |
200 | opt["path"]["resume_state"] = resume_state
201 | opt["model"]['beta_schedule']["train"]["time_scale"] = time_scale[i]
202 | logger.info('Distillation from {:d} to {:d}'.format(opt["model"]['beta_schedule']["train"]["n_timestep"], opt["model"]['beta_schedule']["val"]["n_timestep"]))
203 | logger.info(f"w_snr: {opt['model']['diffusion']['w_snr']}, w_str: {opt['model']['diffusion']['w_str']}")
204 |
205 | diffusion = Model.create_model(opt)
206 |
207 | logger.info('Initial Model Finished')
208 | # Train
209 | current_step = diffusion.begin_step
210 | current_epoch = diffusion.begin_epoch
211 | n_iter = opt['train']['n_iter'] # * iter_scale[i]
212 | # training
213 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(current_epoch, current_step))
214 | avg_psnr = 0
215 | best_psnr = 0
216 | best_ssim = 0
217 | best_lpips = 0
218 |
219 | # pdb.set_trace()
220 | while current_step < n_iter:
221 |
222 | current_epoch += 1
223 | for _, train_data in enumerate(train_loader):
224 |
225 | current_step += 1
226 | if current_step > n_iter:
227 | break
228 |
229 | diffusion.feed_data(train_data)
230 | diffusion.optimize_parameters()
231 | # log
232 | if current_step % opt['train']['print_freq'] == 0 and rank <= 0:
233 | logs = diffusion.get_current_log()
234 |                     message = '<epoch: {:d}, iter: {:d}> '.format(
235 |                         current_epoch, current_step)
236 | for k, v in logs.items():
237 | message += '{:s}: {:.4e} '.format(k, v)
238 | logger.info(message)
239 |
240 | # validation
241 | if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
242 |
243 | avg_psnr = 0.0
244 | avg_ssim = 0.0
245 | avg_lpips = 0.0
246 | idx = 0
247 |
248 |                     result_path = '{}/{}'.format(opt['path']['results'], PD_steps[i+1])
249 | result_path_gt = result_path+'/gt/'
250 | result_path_out = result_path+'/output/'
251 | result_path_input = result_path+'/input/'
252 |
253 | os.makedirs(result_path_gt, exist_ok=True)
254 | os.makedirs(result_path_out, exist_ok=True)
255 | os.makedirs(result_path_input, exist_ok=True)
256 |
257 |
258 | for val_data in val_loader:
259 |
260 | idx += 1
261 | diffusion.feed_data(val_data)
262 | diffusion.test(continous=False)
263 |
264 | visuals = diffusion.get_current_visuals()
265 |
266 |
267 | if opt_dataset['dataset'] == 'LOLv1' or opt_dataset['dataset'] == 'LOLv2' :
268 | normal_img = Metrics.tensor2img(visuals['HQ'])
269 | if normal_img.shape[0] != normal_img.shape[1]: # lolv1 and lolv2-real
270 | normal_img = normal_img[8:408, 4:604,:]
271 | gt_img = Metrics.tensor2img(visuals['GT'])
272 | ll_img = Metrics.tensor2img(visuals['LQ'])
273 | else:
274 | normal_img = Metrics.tensor2img2(visuals['HQ'])
275 | gt_img = Metrics.tensor2img2(visuals['GT'])
276 | ll_img = Metrics.tensor2img2(visuals['LQ'])
277 |
278 | img_mode = 'single'
279 | '''
280 | if img_mode == 'single':
281 | util.save_img(
282 | gt_img, '{}/{}_gt.png'.format(result_path_gt, idx))
283 | util.save_img(
284 | ll_img, '{}/{}_in.png'.format(result_path_input, idx))
285 | # util.save_img(
286 | # normal_img, '{}/{}_normal.png'.format(result_path_out, idx))
287 | else:
288 | util.save_img(
289 | gt_img, '{}/{}_{}_gt.png'.format(result_path, current_step, idx))
290 | util.save_img(
291 | normal_img, '{}/{}_{}_normal_process.png'.format(result_path, current_step, idx))
292 | util.save_img(
293 | Metrics.tensor2img(visuals['HQ'][-1]), '{}/{}_{}_normal.png'.format(result_path, current_step, idx))
294 | normal_img = Metrics.tensor2img(visuals['HQ'][-1])
295 | '''
296 |
297 |
298 |                         # Following LLFlow, we adopt KinD's approach to fine-tune the
299 |                         # overall brightness, as illustrated in Line 73
300 |                         # (https://github.com/zhangyhuaee/KinD/blob/master/evaluate_LOLdataset.py).
301 | if opt_dataset['dataset'] == 'LOLv1' or opt_dataset['dataset'] == 'LOLv2':
302 | gt_img = gt_img / 255.
303 | normal_img = normal_img / 255.
304 |
305 | mean_gray_out = cv2.cvtColor(normal_img.astype(np.float32), cv2.COLOR_BGR2GRAY).mean()
306 | mean_gray_gt = cv2.cvtColor(gt_img.astype(np.float32), cv2.COLOR_BGR2GRAY).mean()
307 | normal_img_adjust = np.clip(normal_img * (mean_gray_gt / mean_gray_out), 0, 1)
308 |
309 | normal_img = (normal_img_adjust * 255).astype(np.uint8)
310 | gt_img = (gt_img * 255).astype(np.uint8)
311 |
312 | psnr = util.calculate_psnr(normal_img, gt_img)
313 | ssim = util.calculate_ssim(normal_img, gt_img)
314 |
315 | normal_img_tensor = torch.tensor(normal_img.astype(np.float32))
316 | gt_img_tensor = torch.tensor(gt_img.astype(np.float32))
317 | normal_img_tensor = normal_img_tensor.permute(2, 0, 1).cuda()
318 | gt_img_tensor = gt_img_tensor.permute(2, 0, 1).cuda()
319 | lpips_scores = lpips_func(normal_img_tensor, gt_img_tensor).item()
320 |
321 | util.save_img(normal_img, '{}/{}_normal.png'.format(result_path_out, idx))
322 |
323 | logger.info('cPSNR: {:.4e} cSSIM: {:.4e} cLPIPS: {:.4e}'.format(psnr, ssim, lpips_scores))
324 |
325 | avg_ssim += ssim
326 | avg_psnr += psnr
327 | avg_lpips += lpips_scores
328 | # break
329 |
330 | avg_psnr = avg_psnr / idx
331 | avg_ssim = avg_ssim / idx
332 | avg_lpips = avg_lpips / idx
333 |
334 | if avg_psnr > best_psnr:
335 | best_psnr = avg_psnr
336 | best_ssim = avg_ssim
337 | best_lpips = avg_lpips
338 | if current_step % opt['train']['save_checkpoint_freq'] == 0 and rank <= 0:
339 | logger.info('Saving models and training states.')
340 | gen_path = diffusion.save_network(PD_steps[i+1], current_epoch, current_step, best_psnr, best_ssim, best_lpips)
341 | if args.progressive:
342 | resume_state = gen_path
343 | # logger.info('# Validation Avg scores at timesteps {:3d} # PSNR: {:.4e} SSIM: {:.4e} LPIPS: {:.4e}'.format(PD_steps[i+1], avg_psnr, avg_ssim, avg_lpips))
344 |                     logger_val = logging.getLogger('val')
345 |                     logger_val.info('# Avg scores at timesteps {:d} (epoch {:d}, iter {:d}) # PSNR: {:.4e} SSIM: {:.4e} LPIPS: {:.4e}'.format(
346 |                         PD_steps[i+1], current_epoch, current_step, avg_psnr, avg_ssim, avg_lpips))
347 |                     logger_val.info('# Best scores at timesteps {:d} # PSNR: {:.4e} SSIM: {:.4e} LPIPS: {:.4e}'.format(PD_steps[i+1], best_psnr, best_ssim, best_lpips))
348 | if opt["model"]['beta_schedule']["val"]["n_timestep"] == 2:
349 | break
350 |
351 |
352 |
353 | if __name__ == '__main__':
354 |
355 | main()
356 |
--------------------------------------------------------------------------------
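
The outer `for i in range(len(PD_steps))` loop runs one distillation stage per entry, training a student with `PD_steps[i+1]` steps against a teacher with `PD_steps[i]`, and stops once the target step count reaches 2. The stage bookkeeping, in isolation:

```python
# Sketch of the progressive distillation schedule driven by train.py.
PD_steps = [16, 8, 4, 2, 1]
time_scale = [i * 32 for i in [1, 2, 4, 8, 16]]

for i in range(len(PD_steps) - 1):
    teacher_steps, student_steps = PD_steps[i], PD_steps[i + 1]
    print('distill {} -> {} steps, time_scale {}'.format(teacher_steps, student_steps, time_scale[i]))
    if student_steps == 2:   # train.py stops at a 2-step student
        break
# distill 16 -> 8 steps, time_scale 32
# distill 8 -> 4 steps, time_scale 64
# distill 4 -> 2 steps, time_scale 128
```
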
/train_lol1.sh:
--------------------------------------------------------------------------------
1 |
2 | python train.py --config ./config/lolv1_train.json --dataset ./config/lolv1.yml --w_str 0.0 --w_snr 0.8 --w_gt 1.0 --w_lpips 0.6 --ablation
3 |
--------------------------------------------------------------------------------
/train_lol2_real.sh:
--------------------------------------------------------------------------------
1 |
2 | python train.py --config ./config/lolv2_real_train.json --dataset ./config/lolv2_real.yml --w_str 0.0 --w_snr 0.4 --w_gt 0.0 --w_lpips 0.6 --ablation
--------------------------------------------------------------------------------
/train_lol2_syn.sh:
--------------------------------------------------------------------------------
1 |
2 |
3 | python train.py --config ./config/lolv2_syn_train.json --dataset ./config/lolv2_syn.yml --w_str 0.0 --w_snr 0.4 --w_gt 0.0 --w_lpips 0.6 --ablation
4 |
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/utils/__init__.py
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/ema.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/utils/__pycache__/ema.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/util.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/utils/__pycache__/util.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/ema.py:
--------------------------------------------------------------------------------
1 | class EMA():
2 | def __init__(self, model, decay):
3 | self.model = model
4 | self.decay = decay
5 | self.shadow = {}
6 | self.backup = {}
7 |
8 | def register(self):
9 | for name, param in self.model.named_parameters():
10 | if param.requires_grad:
11 | self.shadow[name] = param.data.clone()
12 |
13 | def update(self):
14 | for name, param in self.model.named_parameters():
15 | if param.requires_grad:
16 | assert name in self.shadow
17 | new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
18 | self.shadow[name] = new_average.clone()
19 |
20 | def apply_shadow(self):
21 | for name, param in self.model.named_parameters():
22 | if param.requires_grad:
23 | assert name in self.shadow
24 | self.backup[name] = param.data
25 | param.data = self.shadow[name]
26 |
27 | def restore(self):
28 | for name, param in self.model.named_parameters():
29 | if param.requires_grad:
30 | assert name in self.backup
31 | param.data = self.backup[name]
32 | self.backup = {}
--------------------------------------------------------------------------------
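
`EMA` keeps an exponential moving average (the shadow) of every trainable parameter; `apply_shadow`/`restore` swap the smoothed weights in and out around evaluation or checkpointing, which is exactly how `save_network` in model.py uses `ema_student`. Typical usage with a stand-in model:

```python
# Usage sketch for utils/ema.py; the model and the "optimizer step" are stand-ins.
import torch
import torch.nn as nn
from utils.ema import EMA

model = nn.Linear(4, 2)
ema = EMA(model, decay=0.999)
ema.register()                    # snapshot current weights as the shadow

for _ in range(10):               # training loop stand-in
    model.weight.data += 0.01     # pretend an optimizer step happened
    ema.update()                  # shadow = decay * shadow + (1 - decay) * param

ema.apply_shadow()                # evaluate / save with the smoothed weights
torch.save(model.state_dict(), 'ema_ckpt.pth')
ema.restore()                     # back to the raw training weights
```
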
/utils/niqe.py:
--------------------------------------------------------------------------------
1 | import math
2 | from os.path import dirname, join
3 |
4 | import cv2
5 | import numpy as np
6 | import scipy.linalg  # scipy.linalg.pinv is used in niqe() below
7 | import scipy.io
8 | import scipy.misc
9 | import scipy.ndimage
10 | import scipy.special
11 | from PIL import Image
12 |
13 | gamma_range = np.arange(0.2, 10, 0.001)
14 | a = scipy.special.gamma(2.0/gamma_range)
15 | a *= a
16 | b = scipy.special.gamma(1.0/gamma_range)
17 | c = scipy.special.gamma(3.0/gamma_range)
18 | prec_gammas = a/(b*c)
19 |
20 |
21 | def aggd_features(imdata):
22 | # flatten imdata
23 | imdata.shape = (len(imdata.flat),)
24 | imdata2 = imdata*imdata
25 | left_data = imdata2[imdata < 0]
26 | right_data = imdata2[imdata >= 0]
27 | left_mean_sqrt = 0
28 | right_mean_sqrt = 0
29 | if len(left_data) > 0:
30 | left_mean_sqrt = np.sqrt(np.average(left_data))
31 | if len(right_data) > 0:
32 | right_mean_sqrt = np.sqrt(np.average(right_data))
33 |
34 | if right_mean_sqrt != 0:
35 | gamma_hat = left_mean_sqrt/right_mean_sqrt
36 | else:
37 | gamma_hat = np.inf
38 | # solve r-hat norm
39 |
40 | imdata2_mean = np.mean(imdata2)
41 | if imdata2_mean != 0:
42 | r_hat = (np.average(np.abs(imdata))**2) / (np.average(imdata2))
43 | else:
44 | r_hat = np.inf
45 | rhat_norm = r_hat * (((math.pow(gamma_hat, 3) + 1) *
46 | (gamma_hat + 1)) / math.pow(math.pow(gamma_hat, 2) + 1, 2))
47 |
48 | # solve alpha by guessing values that minimize ro
49 | pos = np.argmin((prec_gammas - rhat_norm)**2)
50 | alpha = gamma_range[pos]
51 |
52 | gam1 = scipy.special.gamma(1.0/alpha)
53 | gam2 = scipy.special.gamma(2.0/alpha)
54 | gam3 = scipy.special.gamma(3.0/alpha)
55 |
56 | aggdratio = np.sqrt(gam1) / np.sqrt(gam3)
57 | bl = aggdratio * left_mean_sqrt
58 | br = aggdratio * right_mean_sqrt
59 |
60 | # mean parameter
61 | N = (br - bl)*(gam2 / gam1) # *aggdratio
62 | return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt)
63 |
64 |
65 | def ggd_features(imdata):
66 | nr_gam = 1/prec_gammas
67 | sigma_sq = np.var(imdata)
68 | E = np.mean(np.abs(imdata))
69 | rho = sigma_sq/E**2
70 | pos = np.argmin(np.abs(nr_gam - rho))
71 | return gamma_range[pos], sigma_sq
72 |
73 |
74 | def paired_product(new_im):
75 | shift1 = np.roll(new_im.copy(), 1, axis=1)
76 | shift2 = np.roll(new_im.copy(), 1, axis=0)
77 | shift3 = np.roll(np.roll(new_im.copy(), 1, axis=0), 1, axis=1)
78 | shift4 = np.roll(np.roll(new_im.copy(), 1, axis=0), -1, axis=1)
79 |
80 | H_img = shift1 * new_im
81 | V_img = shift2 * new_im
82 | D1_img = shift3 * new_im
83 | D2_img = shift4 * new_im
84 |
85 | return (H_img, V_img, D1_img, D2_img)
86 |
87 |
88 | def gen_gauss_window(lw, sigma):
89 | sd = np.float32(sigma)
90 | lw = int(lw)
91 | weights = [0.0] * (2 * lw + 1)
92 | weights[lw] = 1.0
93 | sum = 1.0
94 | sd *= sd
95 | for ii in range(1, lw + 1):
96 | tmp = np.exp(-0.5 * np.float32(ii * ii) / sd)
97 | weights[lw + ii] = tmp
98 | weights[lw - ii] = tmp
99 | sum += 2.0 * tmp
100 | for ii in range(2 * lw + 1):
101 | weights[ii] /= sum
102 | return weights
103 |
104 |
105 | def compute_image_mscn_transform(image, C=1, avg_window=None, extend_mode='constant'):
106 | if avg_window is None:
107 | avg_window = gen_gauss_window(3, 7.0/6.0)
108 | assert len(np.shape(image)) == 2
109 | h, w = np.shape(image)
110 | mu_image = np.zeros((h, w), dtype=np.float32)
111 | var_image = np.zeros((h, w), dtype=np.float32)
112 | image = np.array(image).astype('float32')
113 | scipy.ndimage.correlate1d(image, avg_window, 0, mu_image, mode=extend_mode)
114 | scipy.ndimage.correlate1d(mu_image, avg_window, 1,
115 | mu_image, mode=extend_mode)
116 | scipy.ndimage.correlate1d(image**2, avg_window, 0,
117 | var_image, mode=extend_mode)
118 | scipy.ndimage.correlate1d(var_image, avg_window,
119 | 1, var_image, mode=extend_mode)
120 | var_image = np.sqrt(np.abs(var_image - mu_image**2))
121 | return (image - mu_image)/(var_image + C), var_image, mu_image
122 |
123 |
124 | def _niqe_extract_subband_feats(mscncoefs):
125 | # alpha_m, = extract_ggd_features(mscncoefs)
126 | alpha_m, N, bl, br, lsq, rsq = aggd_features(mscncoefs.copy())
127 | pps1, pps2, pps3, pps4 = paired_product(mscncoefs)
128 | alpha1, N1, bl1, br1, lsq1, rsq1 = aggd_features(pps1)
129 | alpha2, N2, bl2, br2, lsq2, rsq2 = aggd_features(pps2)
130 | alpha3, N3, bl3, br3, lsq3, rsq3 = aggd_features(pps3)
131 | alpha4, N4, bl4, br4, lsq4, rsq4 = aggd_features(pps4)
132 | return np.array([alpha_m, (bl+br)/2.0,
133 | alpha1, N1, bl1, br1, # (V)
134 | alpha2, N2, bl2, br2, # (H)
135 |                      alpha3, N3, bl3, br3,  # (D1)
136 |                      alpha4, N4, bl4, br4,  # (D2)
137 | ])
138 |
139 |
140 | def get_patches_train_features(img, patch_size, stride=8):
141 | return _get_patches_generic(img, patch_size, 1, stride)
142 |
143 |
144 | def get_patches_test_features(img, patch_size, stride=8):
145 | return _get_patches_generic(img, patch_size, 0, stride)
146 |
147 |
148 | def extract_on_patches(img, patch_size):
149 | h, w = img.shape
150 | patch_size = np.int32(patch_size)
151 | patches = []
152 | for j in range(0, h-patch_size+1, patch_size):
153 | for i in range(0, w-patch_size+1, patch_size):
154 | patch = img[j:j+patch_size, i:i+patch_size]
155 | patches.append(patch)
156 |
157 | patches = np.array(patches)
158 |
159 | patch_features = []
160 | for p in patches:
161 | patch_features.append(_niqe_extract_subband_feats(p))
162 | patch_features = np.array(patch_features)
163 |
164 | return patch_features
165 |
166 |
167 | def _get_patches_generic(img, patch_size, is_train, stride):
168 | h, w = np.shape(img)
169 |     if h < patch_size or w < patch_size:
170 |         raise ValueError(
171 |             "Input image is too small: {}x{}, need at least {}x{}".format(h, w, patch_size, patch_size))
172 |
173 | # ensure that the patch divides evenly into img
174 | hoffset = (h % patch_size)
175 | woffset = (w % patch_size)
176 |
177 | if hoffset > 0:
178 | img = img[:-hoffset, :]
179 | if woffset > 0:
180 | img = img[:, :-woffset]
181 |
182 | img = img.astype(np.float32)
183 | # img2 = scipy.misc.imresize(img, 0.5, interp='bicubic', mode='F')
184 | img2 = cv2.resize(img, (0, 0), fx=0.5, fy=0.5)
185 |
186 | mscn1, var, mu = compute_image_mscn_transform(img)
187 | mscn1 = mscn1.astype(np.float32)
188 |
189 | mscn2, _, _ = compute_image_mscn_transform(img2)
190 | mscn2 = mscn2.astype(np.float32)
191 |
192 | feats_lvl1 = extract_on_patches(mscn1, patch_size)
193 | feats_lvl2 = extract_on_patches(mscn2, patch_size/2)
194 |
195 | feats = np.hstack((feats_lvl1, feats_lvl2)) # feats_lvl3))
196 |
197 | return feats
198 |
199 |
200 | def niqe(inputImgData):
201 |
202 | patch_size = 8
203 | module_path = dirname(__file__)
204 |
205 | # TODO: memoize
206 | params = scipy.io.loadmat(
207 | join(module_path, 'niqe_image_params.mat'))
208 | pop_mu = np.ravel(params["pop_mu"])
209 | pop_cov = params["pop_cov"]
210 |
211 | if inputImgData.ndim == 3:
212 | inputImgData = cv2.cvtColor(inputImgData, cv2.COLOR_BGR2GRAY)
213 | M, N = inputImgData.shape
214 |
215 | # assert C == 1, "niqe called with videos containing %d channels. Please supply only the luminance channel" % (C,)
216 | assert M > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters"
217 | assert N > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters"
218 |
219 | feats = get_patches_test_features(inputImgData, patch_size)
220 | sample_mu = np.mean(feats, axis=0)
221 | sample_cov = np.cov(feats.T)
222 |
223 | X = sample_mu - pop_mu
224 | covmat = ((pop_cov+sample_cov)/2.0)
225 | pinvmat = scipy.linalg.pinv(covmat)
226 | niqe_score = np.sqrt(np.dot(np.dot(X, pinvmat), X))
227 |
228 | return niqe_score
229 |
--------------------------------------------------------------------------------
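
`niqe` is a no-reference metric: it extracts AGGD statistics of the MSCN coefficients on 8-pixel patches at two scales, then measures the distance between the test image's feature Gaussian and the pristine-image model shipped in `niqe_image_params.mat`. Calling it needs only a sufficiently large image (the random input below is just a smoke test; real inputs are the enhanced images saved by test_unpaired.py):

```python
# Usage sketch for utils/niqe.py; lower scores indicate better perceptual quality.
import numpy as np
from utils.niqe import niqe

img = (np.random.rand(256, 256, 3) * 255).astype(np.uint8)  # stand-in BGR image
score = niqe(img)   # 3-channel inputs are converted to grayscale internally
print('NIQE: {:.3f}'.format(score))
```
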
/utils/niqe_image_params.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/utils/niqe_image_params.mat
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import math
4 | from datetime import datetime
5 | import random
6 | import logging
7 | from collections import OrderedDict
8 |
9 | import natsort
10 | import numpy as np
11 | import cv2
12 | import torch
13 | from torchvision.utils import make_grid
14 | from shutil import get_terminal_size
15 | import torch.nn.functional as F
16 | from torch.autograd import Variable
17 | from math import exp
20 |
21 | from skimage.metrics import structural_similarity as SSIM
22 |
23 |
24 | def gaussian(window_size, sigma):
25 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
26 | return gauss / gauss.sum()
27 |
28 |
29 | def create_window(window_size, channel):
30 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
31 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
32 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
33 | return window
34 |
35 |
36 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
37 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
38 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
39 |
40 | mu1_sq = mu1.pow(2)
41 | mu2_sq = mu2.pow(2)
42 | mu1_mu2 = mu1 * mu2
43 |
44 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
45 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
46 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
47 |
48 | C1 = 0.01 ** 2
49 | C2 = 0.03 ** 2
50 |
51 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
52 |
53 | if size_average:
54 | return ssim_map.mean()
55 | else:
56 | return ssim_map.mean(1).mean(1).mean(1)
57 |
58 |
59 | # class SSIM(torch.nn.Module):
60 | # def __init__(self, window_size=11, size_average=True):
61 | # super(SSIM, self).__init__()
62 | # self.window_size = window_size
63 | # self.size_average = size_average
64 | # self.channel = 1
65 | # self.window = create_window(window_size, self.channel)
66 |
67 | # def forward(self, img1, img2):
68 | # (_, channel, _, _) = img1.size()
69 |
70 | # if channel == self.channel and self.window.data.type() == img1.data.type():
71 | # window = self.window
72 | # else:
73 | # window = create_window(self.window_size, channel)
74 |
75 | # if img1.is_cuda:
76 | # window = window.cuda(img1.get_device())
77 | # window = window.type_as(img1)
78 |
79 | # self.window = window
80 | # self.channel = channel
81 |
82 | # return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
83 |
84 |
85 | def ssim(img1, img2, window_size=11, size_average=True):
86 | (_, channel, _, _) = img1.size()
87 | window = create_window(window_size, channel)
88 |
89 | if img1.is_cuda:
90 | window = window.cuda(img1.get_device())
91 | window = window.type_as(img1)
92 |
93 |     return _ssim(img1.mean(dim=0, keepdim=True), img2.mean(dim=0, keepdim=True), window, window_size, channel, size_average)
94 |
95 |
96 | import yaml
97 |
98 | try:
99 | from yaml import CLoader as Loader, CDumper as Dumper
100 | except ImportError:
101 | from yaml import Loader, Dumper
102 |
103 |
104 | def OrderedYaml():
105 | '''yaml orderedDict support'''
106 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
107 |
108 | def dict_representer(dumper, data):
109 | return dumper.represent_dict(data.items())
110 |
111 | def dict_constructor(loader, node):
112 | return OrderedDict(loader.construct_pairs(node))
113 |
114 | Dumper.add_representer(OrderedDict, dict_representer)
115 | Loader.add_constructor(_mapping_tag, dict_constructor)
116 | return Loader, Dumper
117 |
118 |
119 | ####################
120 | # miscellaneous
121 | ####################
122 |
123 |
124 | def get_timestamp():
125 | return datetime.now().strftime('%y%m%d-%H%M%S')
126 |
127 |
128 | def mkdir(path):
129 | if not os.path.exists(path):
130 | os.makedirs(path)
131 |
132 |
133 | def mkdirs(paths):
134 | if isinstance(paths, str):
135 | mkdir(paths)
136 | else:
137 | for path in paths:
138 | mkdir(path)
139 |
140 |
141 | def mkdir_and_rename(path):
142 | if os.path.exists(path):
143 | new_name = path + '_archived_' + get_timestamp()
144 |         print('Path already exists. Renaming it to [{:s}]'.format(new_name))
145 |         logger = logging.getLogger('base')
146 |         logger.info('Path already exists. Renaming it to [{:s}]'.format(new_name))
147 | os.rename(path, new_name)
148 | os.makedirs(path)
149 |
150 |
151 | def set_random_seed(seed):
152 | random.seed(seed)
153 | np.random.seed(seed)
154 | torch.manual_seed(seed)
155 | torch.cuda.manual_seed_all(seed)
156 |
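   | # Editor's note: for stricter run-to-run reproducibility one can additionally set
   | #   torch.backends.cudnn.deterministic = True
   | #   torch.backends.cudnn.benchmark = False
   | # at some speed cost; the function above only seeds the Python, NumPy, and CUDA RNGs.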
157 |
158 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
159 | '''set up logger'''
160 | lg = logging.getLogger(logger_name)
161 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
162 | datefmt='%y-%m-%d %H:%M:%S')
163 | lg.setLevel(level)
164 | if tofile:
165 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
166 | fh = logging.FileHandler(log_file, mode='w')
167 | fh.setFormatter(formatter)
168 | lg.addHandler(fh)
169 | if screen:
170 | sh = logging.StreamHandler()
171 | sh.setFormatter(formatter)
172 | lg.addHandler(sh)
173 |
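   | # Usage sketch (root path and phase name are illustrative):
   | #   setup_logger('base', './experiments', 'train', screen=True, tofile=True)
   | #   logging.getLogger('base').info('logger ready')  # goes to console and train_<timestamp>.log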
174 |
175 | ####################
176 | # image convert
177 | ####################
178 |
179 |
180 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
181 | '''
182 | Converts a torch Tensor into an image Numpy array
183 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
184 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
185 | '''
186 | if hasattr(tensor, 'detach'):
187 | tensor = tensor.detach()
188 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
189 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
190 | n_dim = tensor.dim()
191 | if n_dim == 4:
192 | n_img = len(tensor)
193 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
194 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
195 | elif n_dim == 3:
196 | img_np = tensor.numpy()
197 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
198 | elif n_dim == 2:
199 | img_np = tensor.numpy()
200 | else:
201 | raise TypeError(
202 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
203 |     if out_type == np.uint8:
204 |         img_np = np.clip((img_np * 255.0).round(), 0, 255)
205 |         # Important: unlike MATLAB, numpy.uint8() will NOT round by default.
206 |     return img_np.astype(out_type)  # the original returned `tensor`, contradicting the docstring above
208 |
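   | # Usage sketch: convert a model output in [-1, 1] to a savable uint8 image
   | # (variable names are illustrative):
   | #   img = tensor2img(fake_img, out_type=np.uint8, min_max=(-1, 1))  # HWC, BGR
   | #   save_img(img, 'result.png')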
209 |
210 | def save_img(img, img_path, mode='RGB'):
   |     # tensor2img emits BGR, which is what cv2.imwrite expects; `mode` is accepted for API compatibility but unused.
211 |     cv2.imwrite(img_path, img)
212 |
213 |
214 | ####################
215 | # metric
216 | ####################
217 |
218 |
219 | def calculate_psnr(img1, img2):
220 | # img1 and img2 have range [0, 255]
221 | img1 = img1.astype(np.float64)
222 | img2 = img2.astype(np.float64)
223 | mse = np.mean((img1 - img2) ** 2)
224 | if mse == 0:
225 | return float('inf')
226 | return 20 * math.log10(255.0 / math.sqrt(mse))
227 |
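   | # Worked example (editor's note): a uniform error of one gray level gives MSE = 1,
   | # so PSNR = 20 * log10(255 / sqrt(1)) ≈ 48.13 dB; identical images return inf.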
228 |
229 | def calculate_ssim(imgA, imgB, gray_scale=False):
230 |
231 |     if gray_scale:
   |         # cv2.cvtColor yields 2D arrays here, so the multichannel flag must be off.
232 |         score, diff = SSIM(cv2.cvtColor(imgA, cv2.COLOR_RGB2GRAY), cv2.cvtColor(imgB, cv2.COLOR_RGB2GRAY), full=True)
233 |     else:
234 |         score, diff = SSIM(imgA, imgB, full=True, multichannel=True)
235 |     return score
236 |
237 |
238 | def calculate_lpips(imgA, imgB, gray_scale=False):
   |     # The original body duplicated calculate_ssim and never computed LPIPS. This is a
   |     # minimal sketch using the `lpips` package (an assumed extra dependency,
   |     # `pip install lpips`); `gray_scale` is kept for API compatibility but unused.
   |     import lpips
   |     loss_fn = lpips.LPIPS(net='alex')  # consider caching this across calls
   |     # HWC uint8 in [0, 255] -> NCHW float in [-1, 1], the input range LPIPS expects
   |     tA = torch.from_numpy(np.ascontiguousarray(imgA)).permute(2, 0, 1).unsqueeze(0).float() / 127.5 - 1.0
   |     tB = torch.from_numpy(np.ascontiguousarray(imgB)).permute(2, 0, 1).unsqueeze(0).float() / 127.5 - 1.0
   |     with torch.no_grad():
244 |         return loss_fn(tA, tB).item()
245 |
246 | def get_resume_paths(opt):
247 | resume_state_path = None
248 | resume_model_path = None
249 | ts = opt_get(opt, ['path', 'training_state'])
250 | if opt.get('path', {}).get('resume_state', None) == "auto" and ts is not None:
251 | wildcard = os.path.join(ts, "*")
252 | paths = natsort.natsorted(glob.glob(wildcard))
253 | if len(paths) > 0:
254 | resume_state_path = paths[-1]
255 | resume_model_path = resume_state_path.replace('training_state', 'models').replace('.state', '_G.pth')
256 | else:
257 | resume_state_path = opt.get('path', {}).get('resume_state')
258 | return resume_state_path, resume_model_path
259 |
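   | # Editor's note: with path.resume_state set to "auto", the newest file under
   | # path.training_state wins (natural sort), and the matching generator weights are
   | # derived by swapping 'training_state' -> 'models' and '.state' -> '_G.pth'.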
260 |
261 | def opt_get(opt, keys, default=None):
262 | if opt is None:
263 | return default
264 | ret = opt
265 | for k in keys:
266 | ret = ret.get(k, None)
267 | if ret is None:
268 | return default
269 | return ret
270 |
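   | # Usage sketch (keys are illustrative):
   | #   lr = opt_get(opt, ['train', 'lr'], default=1e-4)
   | # returns opt['train']['lr'] when every level exists, else the default, without
   | # raising KeyError part-way down.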
--------------------------------------------------------------------------------