├── README.md
├── config
    ├── lolv1.yml
    ├── lolv1_test.json
    ├── lolv1_train.json
    ├── lolv2_real.yml
    ├── lolv2_real_test.json
    ├── lolv2_real_train.json
    ├── lolv2_syn.yml
    ├── lolv2_syn_test.json
    ├── lolv2_syn_train.json
    └── test_unpaired.json
├── core
    ├── __pycache__
    │   ├── logger.cpython-38.pyc
    │   └── metrics.cpython-38.pyc
    ├── logger.py
    └── metrics.py
├── data
    ├── LoL_dataset.py
    ├── __init__.py
    ├── __pycache__
    │   ├── LoL_dataset.cpython-38.pyc
    │   └── __init__.cpython-38.pyc
    ├── single_image_dataset.py
    └── util.py
├── dataset
    └── LOLv1
        └── readme.txt
├── model
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-38.pyc
    │   ├── base_model.cpython-38.pyc
    │   ├── model.cpython-38.pyc
    │   └── networks.cpython-38.pyc
    ├── base_model.py
    ├── ddpm_modules
    │   ├── __pycache__
    │   │   ├── diffusion.cpython-38.pyc
    │   │   └── unet.cpython-38.pyc
    │   ├── diffusion.py
    │   └── unet.py
    ├── model.py
    └── networks.py
├── options
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-38.pyc
    │   └── options.cpython-38.pyc
    └── options.py
├── requirements.txt
├── test.py
├── test.sh
├── test_unpaired.py
├── train.py
├── train_lol1.sh
├── train_lol2_real.sh
├── train_lol2_syn.sh
└── utils
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-38.pyc
    │   ├── ema.cpython-38.pyc
    │   └── util.cpython-38.pyc
    ├── ema.py
    ├── niqe.py
    ├── niqe_image_params.mat
    └── util.py
/README.md:
--------------------------------------------------------------------------------
  1 | 
  2 | 
  3 | 
  4 |   
  5 | # **[CVPR2025]** Efficient Diffusion as Low Light Enhancer
  6 | 
Dataset layout:
 64 | 
 65 | ```
 66 | ├── dataset
 67 |     ├── LOLv1
 68 |         ├── our485
 69 |             ├── low
 70 |             ├── high
 71 |         ├── eval15
 72 |             ├── low
 73 |             ├── high
 74 |     ├── LOLv2
 75 |         ├── Real_captured
 76 |             ├── Train
 77 |             ├── Test
 78 |         ├── Synthetic
 79 |             ├── Train
 80 |             ├── Test
 82 | ```
 83 | 
 84 |  
 85 | 
 86 | ### :blue_book: Testing
 87 | 
 88 | Note: Following LLFlow and KinD, we have also adjusted the brightness of the output image produced by the network, based on the average value of Ground Truth (GT). ``It should be noted that this adjustment process does not influence the texture details generated; it is merely a straightforward method to regulate the overall illumination.`` Moreover, it can be easily adjusted according to user preferences in practical applications.
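The adjustment described in the note can be pictured as a single global gain. The following is a minimal sketch, not the repository's code: it assumes images are float NumPy arrays in [0, 1], and the helper name `match_mean_brightness` is hypothetical.

```python
import numpy as np


def match_mean_brightness(output: np.ndarray, gt: np.ndarray) -> np.ndarray:
    """Scale `output` so that its mean intensity matches the mean of `gt`.

    Hypothetical helper for illustration only. A single global gain changes
    the overall illumination; ratios between pixels are preserved wherever
    the final clipping does not saturate.
    """
    out_mean = float(output.mean())
    if out_mean <= 0.0:
        # An all-black prediction cannot be rescaled; return it unchanged.
        return output.copy()
    gain = float(gt.mean()) / out_mean
    # Apply the gain, then clip back into the valid [0, 1] range.
    return np.clip(output * gain, 0.0, 1.0)
```

In practical use without ground truth, the target mean could be replaced by a user-chosen brightness level, which is one way to read "adjusted according to user preferences".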
 89 | 
 90 | You can download the checkpoints from [Google Drive](https://drive.google.com/file/d/13_XM8nFxJc2IfUotC2_lJo9ATt0rcIyg/view?usp=sharing) or [百度网盘 (Baidu Netdisk)](https://pan.baidu.com/s/1J7MP33Ws5kE673F-8zc2RA?pwd=nj8b) and put them in the following folder:
 91 | 
 92 | ```
 93 | ├── checkpoints
 94 |     ├── lolv1_8step_gen.pth
 95 |     ├── lolv1_4step_gen.pth
 96 |     ├── lolv1_2step_gen.pth
 97 |     ......
 98 | ```
 99 | To test the model, run ``sh test.sh`` and set the `n_timestep` and `time_scale` parameters to match the step count of the checkpoint you load. The `val` schedules for the 8-step, 4-step, and 2-step models are:
100 | ```
101 | "val": {
102 |     "schedule": "linear",
103 |     "n_timestep": 8,
104 |     "linear_start": 1e-4,
105 |     "linear_end": 2e-2,
106 |     "time_scale": 64
107 | }
108 | ```
109 | 
110 | ```
111 | "val": {
112 |     "schedule": "linear",
113 |     "n_timestep": 4,
114 |     "linear_start": 1e-4,
115 |     "linear_end": 2e-2,
116 |     "time_scale": 128
117 | }
118 | ```
119 | 
120 | ```
121 | "val": {
122 |     "schedule": "linear",
123 |     "n_timestep": 2,
124 |     "linear_start": 1e-4,
125 |     "linear_end": 2e-2,
126 |     "time_scale": 256
127 | }
128 | ```
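In the three `val` blocks above, `n_timestep` multiplied by `time_scale` is 512 each time (8 × 64, 4 × 128, 2 × 256). The README does not state this as a rule, so the sketch below treats it as an assumption; the helper name `val_schedule` and its `total_steps` parameter are hypothetical.

```python
def val_schedule(n_timestep: int, total_steps: int = 512) -> dict:
    """Build a `val` beta-schedule block for a given number of sampling steps.

    Assumption (inferred from the README examples, not documented):
    n_timestep * time_scale == total_steps, with total_steps = 512.
    """
    if n_timestep <= 0 or total_steps % n_timestep != 0:
        raise ValueError(
            f"n_timestep must be a positive divisor of {total_steps}, got {n_timestep}"
        )
    return {
        "schedule": "linear",
        "n_timestep": n_timestep,
        "linear_start": 1e-4,
        "linear_end": 2e-2,
        # Derived rather than hand-typed, so the two values cannot drift apart.
        "time_scale": total_steps // n_timestep,
    }
```

Calling `val_schedule(4)` reproduces the 4-step block shown above. Whether other step counts work depends on which checkpoints exist.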
129 | ### :blue_book: Testing on unpaired data
130 | 
131 | ```
132 | python test_unpaired.py  --config config/test_unpaired.json --input unpaired_image_folder
133 | ```
134 | 
135 | You can use any of the three pre-trained models with different numbers of sampling steps to obtain visually pleasing results by editing the corresponding entries (for example `resume_state`, `n_timestep`, and `time_scale`) in `test_unpaired.json`.
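Editing the config by hand works; the same change can also be scripted. This is a minimal sketch under stated assumptions: the function name `set_unpaired_sampling` is hypothetical, only the `val` schedule and `path.resume_state` are touched, and the file is assumed to be plain JSON without `//` comments.

```python
import json
from pathlib import Path


def set_unpaired_sampling(config_path, checkpoint, n_timestep, time_scale):
    """Point a test config at a checkpoint and set its sampling schedule.

    Hypothetical helper: rewrites the JSON file in place and returns the
    updated dictionary. Numbers such as 1e-4 are re-serialised by `json`
    in decimal form, which does not change their value.
    """
    path = Path(config_path)
    config = json.loads(path.read_text())

    # Checkpoint to load at test time.
    config["path"]["resume_state"] = checkpoint

    # Sampling schedule used during validation/testing.
    val = config["model"]["beta_schedule"]["val"]
    val["n_timestep"] = n_timestep
    val["time_scale"] = time_scale

    path.write_text(json.dumps(config, indent=4) + "\n")
    return config
```

For example, `set_unpaired_sampling("config/test_unpaired.json", "./checkpoint/lolv2_syn_4step_gen.pth", 4, 128)` would select a 4-step checkpoint, assuming that file has been downloaded to that location.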
136 | 
137 | 
138 | 
139 | ### :rocket: Training
140 | 
141 | ```
142 | bash train_lol1.sh  # or train_lol2_real.sh / train_lol2_syn.sh
143 | ```
144 | 
145 | 
146 | ## :black_nib: Citation
147 | 
148 |    If you find our repo useful for your research, please consider citing our paper:
149 | 
150 |    ```bibtex
151 |     @InProceedings{Lan_2025_CVPR,
152 |         author    = {Lan, Guanzhou and Ma, Qianli and Yang, Yuqi and Wang, Zhigang and Wang, Dong and Li, Xuelong and Zhao, Bin},
153 |         title     = {Efficient Diffusion as Low Light Enhancer},
154 |         booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)},
155 |         month     = {June},
156 |         year      = {2025},
157 |         pages     = {21277-21286}
158 |     }
159 |    ```
160 | 
161 | 
162 | ## :heart: Acknowledgement
163 | 
164 | Our code is built upon [SR3](https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement). Thanks to the contributors for their great work.
165 | 
--------------------------------------------------------------------------------
/config/lolv1.yml:
--------------------------------------------------------------------------------
 1 | 
 2 | dataset: LOLv1
 3 | 
 4 | #### datasets
 5 | datasets:
 6 |   train:
 7 |     dist: False
 8 |     root: ./dataset/LOLv1
 9 |     use_shuffle: true
10 |     n_workers: 8  
11 |     batch_size: 16
12 |     use_flip: true
13 |     use_crop: true
14 |     patch_size: 96
15 | 
16 |   val:
17 |     dist: False
18 |     root: ./dataset/LOLv1
19 |     n_workers: 1
20 |     use_crop: true
21 |     batch_size: 1 
22 | 
23 | 
24 | 
25 | 
--------------------------------------------------------------------------------
/config/lolv1_test.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "lolv1_test_8",
 3 |     "phase": "test",
 4 |     "distill": false,
 5 |     "gpu_ids": [
 6 |         0
 7 |     ],
 8 |     "path": {
 9 |         "log": "logs",
10 |         "tb_logger": "tb_logger",
11 |         "results": "results",
12 |         "checkpoint": "checkpoint", 
13 |         "resume_state": "./checkpoint/lolv1_4step_gen.pth"
14 |     },
15 |     "model": {
16 |         "which_model_G": "ddpm",
17 |         "finetune_norm": false,
18 |         "unet": {
19 |             "in_channel": 6,
20 |             "out_channel": 3,
21 |             "inner_channel": 64,
22 |             "channel_multiplier": [
23 |                 1,
24 |                 1,
25 |                 2,
26 |                 2,
27 |                 4
28 |             ],
29 |             "attn_res": [
30 |                 16
31 |             ],
32 |             "res_blocks": 2,
33 |             "dropout": 0
34 |         },
35 |         "beta_schedule": {
36 |             "train": {
37 |                 "schedule": "linear",
38 |                 "n_timestep": 4,
39 |                 "linear_start": 1e-4,
40 |                 "linear_end": 2e-2,
41 |                 "time_scale": 128
42 | 
43 |         
44 |             },
45 |             "val": {
46 |                 "schedule": "linear",
47 |                 "n_timestep": 4,
48 |                 "linear_start": 1e-4,
49 |                 "linear_end": 2e-2,
50 |                 "time_scale": 128
51 |                 
52 |            
53 |             }
54 |         },
55 |         "diffusion": {
56 |             "image_size": 128,
57 |             "channels": 6, 
58 |             "conditional": true,
59 |             "w_gt": 0.1,
60 |             "w_snr": 0.5,
61 |             "w_str": 0.1,
62 |             "w_lpips": 0.2
63 |         }
64 |     },
65 |     "train": {
66 |         "n_iter": 1000000,
67 |         "val_freq": 1e4,
68 |         "save_checkpoint_freq": 5e4,
69 |         "print_freq": 200,
70 |         "optimizer": {
71 |             "type": "adam",
72 |             "lr": 1e-4,
73 |             "lr_policy":"linear",
74 |             "lr_decay_iters":3000,
75 |             "n_lr_iters": 2000
76 |         },
77 |         "ema_scheduler": {
78 |             "step_start_ema": 5000,
79 |             "update_ema_every": 1,
80 |             "ema_decay": 0.9999
81 |         }
82 |     },
83 |     "wandb": {
84 |         "project": "llie_ddpm"
85 |     }
86 | }
--------------------------------------------------------------------------------
/config/lolv1_train.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "lolv1_train",
 3 |     "phase": "train",
 4 |     "distill": true,
 5 |     "gpu_ids": [
 6 |         0
 7 |     ],
 8 |     "path": {
 9 |         "log": "logs",
10 |         "tb_logger": "tb_logger",
11 |         "results": "results",
12 |         "checkpoint": "checkpoint",
13 |         "resume_state": "./checkpoint/lolv1_4step_gen.pth"
14 |     },
15 |     "model": {
16 |         "which_model_G": "ddpm", 
17 |         "finetune_norm": false,
18 |         "unet": {
19 |             "in_channel": 6,
20 |             "out_channel": 3,
21 |             "inner_channel": 64,
22 |             "channel_multiplier": [
23 |                 1,
24 |                 1,
25 |                 2,
26 |                 2,
27 |                 4
28 |             ],
29 |             "attn_res": [
30 |                 16
31 |             ],
32 |             "res_blocks": 2,
33 |             "dropout": 0
34 |         },
35 |         "beta_schedule": {
36 |             "train": {
37 |                 "schedule": "linear",
38 |                 "n_timestep": 513,
39 |                 "linear_start": 1e-4,
40 |                 "linear_end": 2e-2,
41 |                 "time_scale": 1,
42 |                 "reflow": false
43 |             },
44 |             "val": {
45 |                 "schedule": "linear",
46 |                 "n_timestep": 513,
47 |                 "linear_start": 1e-4,
48 |                 "linear_end": 2e-2
49 | 
50 |             }
51 |         },
52 |         "diffusion": {
53 |             "image_size": 128,
54 |             "channels": 6, 
55 |             "conditional": true, 
56 |             "w_gt": 0.1,
57 |             "w_snr": 0.5,
58 |             "w_str": 0.1,
59 |             "w_lpips": 0.2
60 |         }
61 |     },
62 |     "train": {
63 |         "n_iter": 5000,
64 |         "val_freq": 100,
65 |         "save_checkpoint_freq": 100,
66 |         "print_freq": 200,
67 |         "optimizer": {
68 |             "type": "adam",
69 |             "lr": 1e-4,
70 |             "lr_policy":"linear",
71 |             "lr_decay_iters":3000,
72 |             "n_lr_iters": 2000
73 |         },
74 |         "ema_scheduler": {
75 |             "step_start_ema": 5000,
76 |             "update_ema_every": 1,
77 |             "ema_decay": 0.9999
78 |         }
79 |     },
80 |     "wandb": {
81 |         "project": "llie_ddpm"
82 |     }
83 | }
84 | 
--------------------------------------------------------------------------------
/config/lolv2_real.yml:
--------------------------------------------------------------------------------
 1 | 
 2 | dataset: LOLv2
 3 | 
 4 | #### datasets
 5 | datasets:
 6 |   train:
 7 |     dist: False
 8 |     root: ./dataset/LOL-v2
 9 |     use_shuffle: true
10 |     n_workers: 8  
11 |     batch_size: 16
12 |     use_flip: true
13 |     use_crop: true
14 |     patch_size: 96
15 |     sub_data: Real_captured
16 | 
17 |   val:
18 |     dist: False
19 |     root: ./dataset/LOL-v2
20 |     n_workers: 1
21 |     batch_size: 1 
22 |     sub_data: Real_captured
23 | 
24 | 
25 | 
26 | 
--------------------------------------------------------------------------------
/config/lolv2_real_test.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "lolv2_test_real",
 3 |     "phase": "test",
 4 |     "gpu_ids": [
 5 |         0
 6 |     ],
 7 | 
 8 |     "path": {
 9 |         "log": "logs",
10 |         "tb_logger": "tb_logger",
11 |         "results": "results",
12 |         "checkpoint": "checkpoint", 
13 |         "resume_state": "./checkpoint/lolv2_real_4step_gen.pth"
14 |     },
15 |     "freq_aware": false,
16 |     "freq_awareUNet": {
17 |         "b1": 1.6,
18 |         "b2": 1.6,
19 |         "s1": 0.9,
20 |         "s2": 0.9
21 |     },
22 |     "model": {
23 |         "which_model_G": "ddpm",
24 |         "finetune_norm": false,
25 |         "unet": {
26 |             "in_channel": 6,
27 |             "out_channel": 3,
28 |             "inner_channel": 64,
29 |             "channel_multiplier": [
30 |                 1,
31 |                 1,
32 |                 2,
33 |                 2,
34 |                 4
35 |             ],
36 |             "attn_res": [
37 |                 16
38 |             ],
39 |             "res_blocks": 2,
40 |             "dropout": 0
41 |         },
42 |         "beta_schedule": {
43 |             "train": {
44 |                 "schedule": "linear",
45 |                 "n_timestep": 4,
46 |                 "linear_start": 1e-4,
47 |                 "linear_end": 2e-2,
48 |                 "time_scale": 128
49 |             },
50 |             "val": {
51 |                 "schedule": "linear",
52 |                 "n_timestep": 4,
53 |                 "linear_start": 1e-4,
54 |                 "linear_end": 2e-2,
55 |                 "time_scale": 128
56 |             }
57 |         },
58 |         "diffusion": {
59 |             "image_size": 128,
60 |             "channels": 6, 
61 |             "conditional": true
62 |         }
63 |     },
64 |     "train": {
65 |         "n_iter": 1000000,
66 |         "val_freq": 1e4,
67 |         "save_checkpoint_freq": 5e4,
68 |         "print_freq": 200,
69 |         "optimizer": {
70 |             "type": "adam",
71 |             "lr": 1e-4
72 |         },
73 |         "ema_scheduler": {
74 |             "step_start_ema": 5000,
75 |             "update_ema_every": 1,
76 |             "ema_decay": 0.9999
77 |         }
78 |     },
79 |     "wandb": {
80 |         "project": "llie_ddpm"
81 |     }
82 | }
--------------------------------------------------------------------------------
/config/lolv2_real_train.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "lolv2_train_real",
 3 |     "phase": "train",
 4 |     "distill": true,
 5 |     "CD": false,
 6 |     "gpu_ids": [
 7 |         0
 8 |     ],
 9 | 
10 |     "path": {
11 |         "log": "logs",
12 |         "tb_logger": "tb_logger",
13 |         "results": "results",
14 |         "checkpoint": "checkpoints", 
15 |         "resume_state": "./checkpoint/lolv2_real_4step_gen.pth"
16 |     },
17 |     "model": {
18 |         "which_model_G": "ddpm",
19 |         "finetune_norm": false,
20 |         "unet": {
21 |             "in_channel": 6,
22 |             "out_channel": 3,
23 |             "inner_channel": 64,
24 |             "channel_multiplier": [
25 |                 1,
26 |                 1,
27 |                 2,
28 |                 2,
29 |                 4
30 |             ],
31 |             "attn_res": [
32 |                 16
33 |             ],
34 |             "res_blocks": 2,
35 |             "dropout": 0
36 |         },
37 |         "beta_schedule": {
38 |             "train": {
39 |                 "schedule": "linear",
40 |                 "n_timestep": 513,
41 |                 "linear_start": 1e-4,
42 |                 "linear_end": 2e-2,
43 |                 "reflow": false,
44 |                 "time_scale": 1
45 |             },
46 |             "val": {
47 |                 "schedule": "linear",
48 | 
49 |                 "n_timestep": 513,
50 |                 "linear_start": 1e-4,
51 |                 "linear_end": 2e-2,
52 |                 "time_scale": 1
53 |             }
54 |         },
55 |         "diffusion": {
56 |             "image_size": 128,
57 |             "channels": 6, 
58 |             "conditional": true,
59 |             "w_gt": 0.1,
60 |             "w_snr": 0.5,
61 |             "w_str": 0.1,
62 |             "w_lpips": 0.2
63 |             
64 |         }
65 |     },
66 |     "train": {
67 |         "n_iter": 5000,
68 |         "val_freq": 100,
69 |         "save_checkpoint_freq": 100,
70 |         "print_freq": 100,
71 |         "optimizer": {
72 |             "type": "adam",
73 |             "lr": 1e-4,
74 |             "lr_policy":"linear",
75 |             "lr_decay_iters":3000,
76 |             "n_lr_iters": 2000
77 |         },
78 |         "ema_scheduler": {
79 |             "step_start_ema": 5000,
80 |             "update_ema_every": 1,
81 |             "ema_decay": 0.9999
82 |         }
83 |     },
84 |     "wandb": {
85 |         "project": "llie_ddpm"
86 |     }
87 | }
--------------------------------------------------------------------------------
/config/lolv2_syn.yml:
--------------------------------------------------------------------------------
 1 | 
 2 | dataset: LOLv2
 3 | 
 4 | #### datasets
 5 | datasets:
 6 |   train:
 7 |     dist: False
 8 |     root: ./dataset/LOL-v2
 9 |     use_shuffle: true
10 |     n_workers: 8  
11 |     batch_size: 16
12 |     use_flip: true
13 |     use_crop: true
14 |     patch_size: 96
15 |     sub_data: Synthetic
16 | 
17 |   val:
18 |     dist: False
19 |     root: ./dataset/LOL-v2
20 |     n_workers: 1
21 |     batch_size: 1 
22 |     sub_data: Synthetic
23 | 
24 | 
25 | 
26 | 
--------------------------------------------------------------------------------
/config/lolv2_syn_test.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "lolv2_test_syn",
 3 |     "phase": "test",
 4 |     "gpu_ids": [
 5 |         0
 6 |     ],
 7 |     "path": {
 8 |         "log": "logs",
 9 |         "tb_logger": "tb_logger",
10 |         "results": "results",
11 |         "checkpoint": "checkpoint", 
12 |         "resume_state": "./checkpoint/lolv2_syn_4step_gen.pth"
13 |     },
14 |     "model": {
15 |         "which_model_G": "ddpm",
16 |         "finetune_norm": false,
17 |         "unet": {
18 |             "in_channel": 6,
19 |             "out_channel": 3,
20 |             "inner_channel": 64,
21 |             "channel_multiplier": [
22 |                 1,
23 |                 1,
24 |                 2,
25 |                 2,
26 |                 4
27 |             ],
28 |             "attn_res": [
29 |                 16
30 |             ],
31 |             "res_blocks": 2,
32 |             "dropout": 0
33 |         },
34 |         "beta_schedule": {
35 |             "train": {
36 |                 "schedule": "linear",
37 |                 "n_timestep": 4,
38 |                 "linear_start": 1e-4,
39 |                 "linear_end": 2e-2,
40 |                 "time_scale": 128
41 |             },
42 |             "val": {
43 |                 "schedule": "linear",
44 |                 "n_timestep": 4,
45 |                 "linear_start": 1e-4,
46 |                 "linear_end": 2e-2,
47 |                 "time_scale": 128
48 |             }
49 |         },
50 |         "diffusion": {
51 |             "image_size": 128,
52 |             "channels": 6, 
53 |             "conditional": true
54 |         }
55 |     },
56 |     "train": {
57 |         "n_iter": 1000000,
58 |         "val_freq": 1e4,
59 |         "save_checkpoint_freq": 5e4,
60 |         "print_freq": 200,
61 |         "optimizer": {
62 |             "type": "adam",
63 |             "lr": 1e-4
64 |         },
65 |         "ema_scheduler": {
66 |             "step_start_ema": 5000,
67 |             "update_ema_every": 1,
68 |             "ema_decay": 0.9999
69 |         }
70 |     },
71 |     "wandb": {
72 |         "project": "llie_ddpm"
73 |     }
74 | }
--------------------------------------------------------------------------------
/config/lolv2_syn_train.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "lolv2_train_syn",
 3 |     "phase": "train",
 4 |     "distill": true,
 5 |     "gpu_ids": [
 6 |         0
 7 |     ],
 8 |     "path": {
 9 |         "log": "logs",
10 |         "tb_logger": "tb_logger",
11 |         "results": "results",
12 |         "checkpoint": "checkpoint", 
13 |         "resume_state": "./checkpoint/lolv2_syn_4step_gen.pth"
14 |     },
15 |     "model": {
16 |         "which_model_G": "ddpm",
17 |         "finetune_norm": false,
18 |         "unet": {
19 |             "in_channel": 6,
20 |             "out_channel": 3,
21 |             "inner_channel": 64,
22 |             "channel_multiplier": [
23 |                 1,
24 |                 1,
25 |                 2,
26 |                 2,
27 |                 4
28 |             ],
29 |             "attn_res": [
30 |                 16
31 |             ],
32 |             "res_blocks": 2,
33 |             "dropout": 0
34 |         },
35 |         "beta_schedule": {
36 |             "train": {
37 |                 "schedule": "linear",
38 |                 "n_timestep": 513,
39 |                 "linear_start": 1e-4,
40 |                 "linear_end": 2e-2,
41 |                 "reflow": false,
42 |                 "time_scale": 1
43 |             },
44 |             "val": {
45 |                 "schedule": "linear",
46 |                 "n_timestep": 513,
47 |                 "linear_start": 1e-4,
48 |                 "linear_end": 2e-2
49 |             }
50 |         },
51 |         "diffusion": {
52 |             "image_size": 128,
53 |             "channels": 6, 
54 |             "conditional": true,
55 |             "w_gt": 0.1,
56 |             "w_snr": 0.5,
57 |             "w_str": 0.1,
58 |             "w_lpips": 0.2
59 |         }
60 |     },
61 |     "train": {
62 |         "n_iter": 5000,
63 |         "val_freq": 100,
64 |         "save_checkpoint_freq": 100,
65 |         "print_freq": 100,
66 |         "optimizer": {
67 |             "type": "adam",
68 |             "lr": 1e-4,
69 |             "lr_policy":"linear",
70 |             "lr_decay_iters":3000,
71 |             "n_lr_iters": 2000
72 |         },
73 |         "ema_scheduler": {
74 |             "step_start_ema": 5000,
75 |             "update_ema_every": 1,
76 |             "ema_decay": 0.9999
77 |         }
78 |     },
79 |     "wandb": {
80 |         "project": "llie_ddpm"
81 |     }
82 | }
--------------------------------------------------------------------------------
/config/test_unpaired.json:
--------------------------------------------------------------------------------
 1 | {
 2 |     "name": "test_unpaired",
 3 |     "phase": "test",
 4 |     "gpu_ids": [
 5 |         0
 6 |     ],
 7 |     "path": {
 8 |         "log": "logs",
 9 |         "tb_logger": "tb_logger",
10 |         "results": "results",
11 |         "checkpoint": "checkpoint", 
12 |         "resume_state": "experiments/lolv2_train_syn/lolv2_train_syn_w_snr:0.2_w_str:0.0_w_gt:1.0_w_lpips:0.6_240410_125713/checkpoint/num_step_8/psnr30.1459_ssim0.9424_lpips0.0284_I2700_E47_gen_ema.pth"
13 |     },
14 |     "freq_aware": false,
15 |     "freq_awareUNet": {
16 |         "b1": 1.6,
17 |         "b2": 1.6,
18 |         "s1": 0.9,
19 |         "s2": 0.9
20 |     },
21 |     "model": {
22 |         "which_model_G": "ddpm",
23 |         "finetune_norm": false,
24 |         "unet": {
25 |             "in_channel": 6,
26 |             "out_channel": 3,
27 |             "inner_channel": 64,
28 |             "channel_multiplier": [
29 |                 1,
30 |                 1,
31 |                 2,
32 |                 2,
33 |                 4
34 |             ],
35 |             "attn_res": [
36 |                 16
37 |             ],
38 |             "res_blocks": 2,
39 |             "dropout": 0
40 |         },
41 |         "beta_schedule": {
42 |             "train": {
43 |                 "schedule": "linear",
44 |                 "n_timestep": 4,
45 |                 "linear_start": 1e-4,
46 |                 "linear_end": 2e-2,
47 |                 "time_scale": 128
48 |             },
49 |             "val": {
50 |                 "schedule": "linear",
51 |                 "n_timestep": 4,
52 |                 "linear_start": 1e-4,
53 |                 "linear_end": 2e-2,
54 |                 "time_scale": 128
55 |             }
56 |         },
57 |         "diffusion": {
58 |             "image_size": 128,
59 |             "channels": 6, 
60 |             "conditional": true
61 |         }
62 |     },
63 |     "train": {
64 |         "n_iter": 1000000,
65 |         "val_freq": 1e4,
66 |         "save_checkpoint_freq": 5e4,
67 |         "print_freq": 200,
68 |         "optimizer": {
69 |             "type": "adam",
70 |             "lr": 1e-4
71 |         },
72 |         "ema_scheduler": {
73 |             "step_start_ema": 5000,
74 |             "update_ema_every": 1,
75 |             "ema_decay": 0.9999
76 |         }
77 |     },
78 |     "wandb": {
79 |         "project": "llie_ddpm"
80 |     }
81 | }
--------------------------------------------------------------------------------
/core/__pycache__/logger.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/core/__pycache__/logger.cpython-38.pyc
--------------------------------------------------------------------------------
/core/__pycache__/metrics.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/core/__pycache__/metrics.cpython-38.pyc
--------------------------------------------------------------------------------
/core/logger.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import os.path as osp
  3 | import logging
  4 | from collections import OrderedDict
  5 | import json
  6 | from datetime import datetime
  7 | 
  8 | 
  9 | def mkdirs(paths):
 10 |     if isinstance(paths, str):
 11 |         os.makedirs(paths, exist_ok=True)
 12 |     else:
 13 |         for path in paths:
 14 |             os.makedirs(path, exist_ok=True)
 15 | 
 16 | 
 17 | def get_timestamp():
 18 |     return datetime.now().strftime('%y%m%d_%H%M%S')
 19 | 
 20 | 
 21 | def parse(args):
 22 |     phase = args.phase
 23 |     opt_path = args.config
 24 |     gpu_ids = args.gpu_ids
 25 |     # remove comments starting with '//'
 26 |     json_str = ''
 27 |     with open(opt_path, 'r') as f:
 28 |         for line in f:
 29 |             line = line.split('//')[0] + '\n'
 30 |             json_str += line
 31 |     opt = json.loads(json_str, object_pairs_hook=OrderedDict)
 32 | 
 33 |     # set log directory
 34 |     if args.debug:
 35 |         opt['name'] = 'debug_{}'.format(opt['name'])
 36 |     if args.brutal_search:
 37 |         experiments_root = os.path.join(
 38 |             'experiments', opt['name'], '{}_noise_start:{}_noise_end:{}_{}'.format(opt['name'], args.noise_start, args.noise_end, get_timestamp()))
 39 |     else:
 40 |         if opt['phase'] == 'train':
 41 |             experiments_root = os.path.join(
 42 |                 'experiments', opt['name'], '{}_w_snr:{}_w_str:{}_w_gt:{}_w_lpips:{}_{}'.format(opt['name'], args.w_snr, args.w_str, args.w_gt, args.w_lpips, get_timestamp()))
 43 |         elif opt['phase'] == 'test':
 44 |             experiments_root = os.path.join(
 45 |                 'experiments', opt['name'], '{}_numstep:{}_w_snr:{}_w_gt:{}_w_lpips:{}_{}'.format(opt['name'],opt["model"]['beta_schedule']['val']['n_timestep'], args.w_snr, args.w_gt, args.w_lpips, get_timestamp()))
 46 |     opt['path']['experiments_root'] = experiments_root
 47 |     for key, path in opt['path'].items():
 48 |         if 'resume' not in key and 'experiments' not in key:
 49 |             opt['path'][key] = os.path.join(experiments_root, path)
 50 |             mkdirs(opt['path'][key])
 51 | 
 52 |     # the phase given on the command line overrides the one in the config file
 53 |     opt['phase'] = phase
 54 | 
 55 |     # resolve GPU ids: args.gpu_ids, when given, overrides the config file
 56 |     if gpu_ids is not None:
 57 |         opt['gpu_ids'] = [int(id) for id in gpu_ids.split(',')]
 58 |         gpu_list = gpu_ids
 59 |     else:
 60 |         gpu_list = ','.join(str(x) for x in opt['gpu_ids'])
 61 |     
 62 |     if len(opt['gpu_ids']) > 1:  # count GPU ids, not characters in the id string
 63 |         opt['distributed'] = True
 64 |     else:
 65 |         opt['distributed'] = False
 66 | 
 67 |     # debug
 68 |     if 'debug' in opt['name']:
 69 |         opt['train']['val_freq'] = 2
 70 |         opt['train']['print_freq'] = 2
 71 |         opt['train']['save_checkpoint_freq'] = 3
 72 |         opt['datasets']['train']['batch_size'] = 2
 73 |         opt['model']['beta_schedule']['train']['n_timestep'] = 10
 74 |         opt['model']['beta_schedule']['val']['n_timestep'] = 10
 75 |         opt['datasets']['train']['data_len'] = 6
 76 |         opt['datasets']['val']['data_len'] = 3
 77 | 
 78 |     # validation in train phase
 79 |     # if phase == 'train':
 80 |     #     opt['datasets']['val']['data_len'] = 3
 81 | 
 82 |     # W&B Logging
 83 |     try:
 84 |         log_wandb_ckpt = args.log_wandb_ckpt
 85 |         opt['log_wandb_ckpt'] = log_wandb_ckpt
 86 |     except AttributeError:  # argument not defined by the calling script
 87 |         pass
 88 |     try:
 89 |         log_eval = args.log_eval
 90 |         opt['log_eval'] = log_eval
 91 |     except AttributeError:
 92 |         pass
 93 |     try:
 94 |         log_infer = args.log_infer
 95 |         opt['log_infer'] = log_infer
 96 |     except AttributeError:
 97 |         pass
 98 |     
 99 |     return opt
100 | 
101 | 
102 | class NoneDict(dict):
103 |     def __missing__(self, key):
104 |         return None
105 | 
106 | 
107 | # convert to NoneDict, which return None for missing key.
108 | def dict_to_nonedict(opt):
109 |     if isinstance(opt, dict):
110 |         new_opt = dict()
111 |         for key, sub_opt in opt.items():
112 |             new_opt[key] = dict_to_nonedict(sub_opt)
113 |         return NoneDict(**new_opt)
114 |     elif isinstance(opt, list):
115 |         return [dict_to_nonedict(sub_opt) for sub_opt in opt]
116 |     else:
117 |         return opt
118 | 
119 | 
120 | def dict2str(opt, indent_l=1):
121 |     '''dict to string for logger'''
122 |     msg = ''
123 |     for k, v in opt.items():
124 |         if isinstance(v, dict):
125 |             msg += ' ' * (indent_l * 2) + k + ':[\n'
126 |             msg += dict2str(v, indent_l + 1)
127 |             msg += ' ' * (indent_l * 2) + ']\n'
128 |         else:
129 |             msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
130 |     return msg
131 | 
132 | 
133 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False):
134 |     '''set up logger'''
135 |     l = logging.getLogger(logger_name)
136 |     formatter = logging.Formatter(
137 |         '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S')
138 |     log_file = os.path.join(root, '{}.log'.format(phase))
139 |     fh = logging.FileHandler(log_file, mode='w')
140 |     fh.setFormatter(formatter)
141 |     l.setLevel(level)
142 |     l.addHandler(fh)
143 |     if screen:
144 |         sh = logging.StreamHandler()
145 |         sh.setFormatter(formatter)
146 |         l.addHandler(sh)
147 | 
--------------------------------------------------------------------------------
/core/metrics.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import math
  3 | import numpy as np
  4 | import cv2
  5 | from torchvision.utils import make_grid
  6 | 
  7 | 
  8 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
  9 |     '''
 10 |     Converts a torch Tensor into an image Numpy array
 11 |     Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
 12 |     Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
 13 |     '''
 14 |     tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
 15 |     tensor = (tensor - min_max[0]) / \
 16 |         (min_max[1] - min_max[0])  # to range [0,1]
 17 |     n_dim = tensor.dim()
 18 |     if n_dim == 4:
 19 |         n_img = len(tensor)
 20 |         # img_np = make_grid(tensor, nrow=int(
 21 |         #     math.sqrt(n_img)), normalize=False).numpy()
 22 |         # img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
 23 | 
 24 |         img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
 25 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
 26 |     elif n_dim == 3:
 27 |         img_np = tensor.numpy()
 28 |         # img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
 29 | 
 30 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
 31 |     elif n_dim == 2:
 32 |         img_np = tensor.numpy()
 33 |     else:
 34 |         raise TypeError(
 35 |             'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
 36 |     if out_type == np.uint8:
 37 |         img_np = np.clip((img_np * 255.0).round(), 0, 255)
 38 |         # img_np = (img_np * 255.0).round()
 39 |         # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
 40 |     return img_np.astype(out_type)
 41 | 
 42 | def tensor2img2(tensor, out_type=np.uint8, min_max=(-1, 1)):
 43 |     '''
 44 |     Converts a torch Tensor into an image Numpy array
 45 |     Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
 46 |     Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
 47 |     '''
 48 |     tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
 49 | 
 50 |     n_dim = tensor.dim()
 51 |     if n_dim == 4:
 52 |         n_img = len(tensor)
 53 |         # img_np = make_grid(tensor, nrow=int(
 54 |         #     math.sqrt(n_img)), normalize=False).numpy()
 55 |         # img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
 56 | 
 57 |         img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
 58 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
 59 |     elif n_dim == 3:
 60 |         img_np = tensor.numpy()
 61 |         # img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
 62 | 
 63 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
 64 |     elif n_dim == 2:
 65 |         img_np = tensor.numpy()
 66 |     else:
 67 |         raise TypeError(
 68 |             'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
 69 |     if out_type == np.uint8:
 70 |         img_np = np.clip((img_np * 255.0).round(), 0, 255)
 71 |         # img_np = (img_np * 255.0).round()
 72 |         # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
 73 |     return img_np.astype(out_type)
 74 | 
 75 | def save_img(img, img_path, mode='RGB'):
 76 |     cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
 77 |     # cv2.imwrite(img_path, img)
 78 | 
 79 | 
 80 | def calculate_psnr(img1, img2):
 81 |     # img1 and img2 have range [0, 255]
 82 |     img1 = img1.astype(np.float64)
 83 |     img2 = img2.astype(np.float64)
 84 |     mse = np.mean((img1 - img2)**2)
 85 |     if mse == 0:
 86 |         return float('inf')
 87 |     return 20 * math.log10(255.0 / math.sqrt(mse))
 88 | 
 89 | 
 90 | def ssim(img1, img2):
 91 |     C1 = (0.01 * 255)**2
 92 |     C2 = (0.03 * 255)**2
 93 | 
 94 |     img1 = img1.astype(np.float64)
 95 |     img2 = img2.astype(np.float64)
 96 |     kernel = cv2.getGaussianKernel(11, 1.5)
 97 |     window = np.outer(kernel, kernel.transpose())
 98 | 
 99 |     mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
100 |     mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
101 |     mu1_sq = mu1**2
102 |     mu2_sq = mu2**2
103 |     mu1_mu2 = mu1 * mu2
104 |     sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
105 |     sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
106 |     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
107 | 
108 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
109 |                                                             (sigma1_sq + sigma2_sq + C2))
110 |     return ssim_map.mean()
111 | 
112 | 
113 | def calculate_ssim(img1, img2):
114 |     '''calculate SSIM
115 |     the same outputs as MATLAB's
116 |     img1, img2: [0, 255]
117 |     '''
118 |     if not img1.shape == img2.shape:
119 |         raise ValueError('Input images must have the same dimensions.')
120 |     if img1.ndim == 2:
121 |         return ssim(img1, img2)
122 |     elif img1.ndim == 3:
123 |         if img1.shape[2] == 3:
124 |             ssims = []
125 |             for i in range(3):
126 |                 ssims.append(ssim(img1[..., i], img2[..., i]))  # per-channel SSIM
127 |             return np.array(ssims).mean()
128 |         elif img1.shape[2] == 1:
129 |             return ssim(np.squeeze(img1), np.squeeze(img2))
130 |     else:
131 |         raise ValueError('Wrong input image dimensions.')
132 | 
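
As a quick sanity check of the PSNR formula used in `calculate_psnr`, the sketch below re-implements it as a standalone function (`psnr_uint8` is a name chosen here, not one from the repository) and evaluates it on two tiny synthetic images.

```python
import math
import numpy as np


def psnr_uint8(img1, img2):
    """PSNR in dB for images in [0, 255]; mirrors calculate_psnr above."""
    diff = img1.astype(np.float64) - img2.astype(np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))


ref = np.zeros((4, 4, 3), dtype=np.uint8)
off_by_one = np.ones((4, 4, 3), dtype=np.uint8)  # every pixel differs by 1, so MSE = 1

print(psnr_uint8(ref, ref))                   # inf
print(round(psnr_uint8(ref, off_by_one), 2))  # 48.13
```

The cast to `float64` before subtracting matters: subtracting `uint8` arrays directly wraps around (0 - 1 becomes 255) and would inflate the error.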
--------------------------------------------------------------------------------
/data/LoL_dataset.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import torch.utils.data as data
  3 | import numpy as np
  4 | import torch
  5 | import cv2
  6 | from torchvision.transforms import ToTensor
  7 | import torchvision.transforms as T
  8 | import torchvision
  9 | 
 10 | 
 11 | class LOLv1_Dataset(data.Dataset):
 12 |     def __init__(self, opt, train, all_opt):
 13 |         self.root = opt["root"]
 14 |         self.opt = opt
 15 |         self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False
 16 |         self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False
 17 |         self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False
 18 |         self.crop_size = opt.get("patch_size", None)
 19 |         if train:
 20 |             self.split = 'train'
 21 |             self.root = os.path.join(self.root, 'our485')
 22 |         else:
 23 |             self.split = 'val'
 24 |             self.root = os.path.join(self.root, 'eval15')
 25 |         self.pairs = self.load_pairs(self.root)
 26 |         self.to_tensor = ToTensor()
 27 | 
 28 |     def __len__(self):
 29 |         return len(self.pairs)
 30 | 
 31 |     def load_pairs(self, folder_path):
 32 | 
 33 |         low_list = os.listdir(os.path.join(folder_path, 'low'))
 34 |         low_list = filter(lambda x: 'png' in x, low_list)
 35 | 
 36 |         pairs = []
 37 |         for idx, f_name in enumerate(low_list):
 38 |             
 39 |             if self.split == 'val':
 40 |                 pairs.append(
 41 |                     [cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'low', f_name)), cv2.COLOR_BGR2RGB),  
 42 |                      cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'high', f_name)), cv2.COLOR_BGR2RGB),
 43 |                     f_name.split('.')[0]])
 44 |             else:
 45 |                 pairs.append(
 46 |                     [cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'low', f_name)), cv2.COLOR_BGR2RGB),  
 47 |                      cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'high', f_name)), cv2.COLOR_BGR2RGB),
 48 |                     f_name.split('.')[0]])
 49 |         return pairs
 50 |     
 51 |     def get_max(self,input):
 52 |         T,_=torch.max(input,dim=0)
 53 |         T=T+0.1
 54 |         input[0,:,:] = input[0,:,:]/ T
 55 |         input[1,:,:] = input[1,:,:]/ T
 56 |         input[2,:,:]= input[2,:,:] / T
 57 |         return input
 58 |     
 59 |     def __getitem__(self, item):
 60 |         lr, hr, f_name = self.pairs[item]
 61 | 
 62 | 
 63 |         if self.use_crop and self.split != 'val':
 64 |             hr, lr = random_crop(hr, lr, self.crop_size)
 65 |         elif self.split == 'val':
 66 |             lr = cv2.copyMakeBorder(lr, 8,8,4,4,cv2.BORDER_REFLECT)
 67 | 
 68 |         if self.use_flip:
 69 |             hr, lr = random_flip(hr, lr)
 70 | 
 71 |         if self.use_rot:
 72 |             hr, lr = random_rotation(hr, lr)
 73 | 
 74 |         hr = self.to_tensor(hr)
 75 |         lr = self.to_tensor(lr)
 76 |         # lr = self.get_max(lr)
 77 | 
 78 |         [lr, hr] = transform_augment(
 79 |                 [lr, hr], split=self.split, min_max=(-1, 1))
 80 | 
 81 |         return {'LQ': lr, 'GT': hr, 'LQ_path': f_name, 'GT_path': f_name}
 82 | 
 83 | class LOLv2_Dataset(data.Dataset):
 84 |     def __init__(self, opt, train, all_opt):
 85 |         self.root = opt["root"]
 86 |         self.opt = opt
 87 |         self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False
 88 |         self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False
 89 |         self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False
 90 |         self.crop_size = opt.get("patch_size", None)
 91 |         self.sub_data = opt.get("sub_data", None)
 92 |         self.pairs = []
 93 |         self.train = train
 94 |         if train:
 95 |             self.split = 'train'
 96 |             root = os.path.join(self.root, self.sub_data, 'Train')
 97 |         else:
 98 |             self.split = 'val'
 99 |             root = os.path.join(self.root, self.sub_data, 'Test')
100 |         self.pairs.extend(self.load_pairs(root))
101 |         self.to_tensor = ToTensor()
102 |         self.gamma_aug = opt['gamma_aug'] if 'gamma_aug' in opt.keys() else False
103 | 
104 |     def __len__(self):
105 |         return len(self.pairs)
106 |     
107 |     def get_max(self,input):
108 |         T,_=torch.max(input,dim=0)
109 |         T=T+0.1
110 |         input[0,:,:] = input[0,:,:]/ T
111 |         input[1,:,:] = input[1,:,:]/ T
112 |         input[2,:,:]= input[2,:,:] / T
113 |         return input
114 | 
115 |     def load_pairs(self, folder_path):
116 | 
117 |         low_list = os.listdir(os.path.join(folder_path, 'Low'))
118 |         low_list = sorted(list(filter(lambda x: 'png' in x, low_list)))
119 |         high_list = os.listdir(os.path.join(folder_path, 'Normal'))
120 |         high_list = sorted(list(filter(lambda x: 'png' in x, high_list)))
121 |         pairs = []
122 | 
123 |         for idx in range(len(low_list)):
124 |             f_name_low = low_list[idx]
125 |             f_name_high = high_list[idx]
126 |             pairs.append(
127 |                 [cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'Low', f_name_low)),
128 |                                 cv2.COLOR_BGR2RGB),
129 |                     cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'Normal', f_name_high)),
130 |                                 cv2.COLOR_BGR2RGB),
131 |                     f_name_high.split('.')[0]])
132 |         return pairs
133 | 
134 |     def __getitem__(self, item):
135 |         
136 |         lr, hr, f_name = self.pairs[item]
137 | 
138 |         if self.use_crop and self.split != 'val':
139 |             hr, lr = random_crop(hr, lr, self.crop_size)
140 |         elif self.sub_data == 'Real_captured' and self.split == 'val': # for Real_captured
141 |             lr = cv2.copyMakeBorder(lr, 8,8,4,4,cv2.BORDER_REFLECT)
142 | 
143 |         if self.use_flip:
144 |             hr, lr = random_flip(hr, lr)
145 | 
146 |         if self.use_rot:
147 |             hr, lr = random_rotation(hr, lr)
148 | 
149 | 
150 |         hr = self.to_tensor(hr)
151 |         lr = self.to_tensor(lr)
152 |         # lr = self.get_max(lr)
153 | 
154 |         
155 |         [lr, hr] = transform_augment(
156 |                 [lr, hr], split=self.split, min_max=(-1, 1))
157 | 
158 |         return {'LQ': lr, 'GT': hr, 'LQ_path': f_name, 'GT_path': f_name}
159 | 
160 | 
161 | def random_flip(img, seg):
162 |     random_choice = np.random.choice([True, False])
163 |     img = img if random_choice else np.flip(img, 1).copy()
164 |     seg = seg if random_choice else np.flip(seg, 1).copy()
165 | 
166 |     return img, seg
167 | 
168 | 
169 | def gamma_aug(img, gamma=0):
170 |     max_val = img.max()
171 |     img_after_norm = img / max_val
172 |     img_after_norm = np.power(img_after_norm, gamma)
173 |     return img_after_norm * max_val
174 | 
175 | 
176 | def random_rotation(img, seg):
177 |     random_choice = np.random.choice([0, 1, 3])
178 |     img = np.rot90(img, random_choice, axes=(0, 1)).copy()
179 |     seg = np.rot90(seg, random_choice, axes=(0, 1)).copy()
180 |     
181 |     return img, seg
182 | 
183 | 
184 | def random_crop(hr, lr, size_hr):
185 |     size_lr = size_hr
186 | 
187 |     size_lr_x = lr.shape[0]
188 |     size_lr_y = lr.shape[1]
189 | 
190 |     start_x_lr = np.random.randint(low=0, high=(size_lr_x - size_lr) + 1) if size_lr_x > size_lr else 0
191 |     start_y_lr = np.random.randint(low=0, high=(size_lr_y - size_lr) + 1) if size_lr_y > size_lr else 0
192 | 
193 |     # LR Patch
194 |     lr_patch = lr[start_x_lr:start_x_lr + size_lr, start_y_lr:start_y_lr + size_lr, :]
195 | 
196 |     # HR Patch
197 |     start_x_hr = start_x_lr
198 |     start_y_hr = start_y_lr
199 |     hr_patch = hr[start_x_hr:start_x_hr + size_hr, start_y_hr:start_y_hr + size_hr, :]
200 | 
201 |     # HisEq Patch
202 |     his_eq_patch = None
203 |     return hr_patch, lr_patch
204 | 
205 | 
206 | # implementation by torchvision, detail in https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement/issues/14
207 | totensor = torchvision.transforms.ToTensor()
208 | hflip = torchvision.transforms.RandomHorizontalFlip()
209 | def transform_augment(imgs, split='val', min_max=(0, 1)):    
210 |     # imgs = [totensor(img) for img in img_list]
211 |     if split == 'train':
212 |         imgs = torch.stack(imgs, 0)
213 |         # imgs = hflip(imgs)
214 |         imgs = torch.unbind(imgs, dim=0)
215 |     ret_img = [img * (min_max[1] - min_max[0]) + min_max[0] for img in imgs]
216 |     return ret_img
217 | 
218 | 
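
Both dataset classes call `transform_augment(..., min_max=(-1, 1))`, which rescales `ToTensor` output from [0, 1] to [-1, 1]; `tensor2img` in `core/metrics.py` applies the inverse after clamping. The sketch below shows that round trip with NumPy arrays instead of torch tensors, purely to keep it dependency-light; `to_range` and `from_range` are illustrative names, not functions from the repository.

```python
import numpy as np


def to_range(img01, min_max=(-1, 1)):
    """Map values from [0, 1] to [min, max], as transform_augment does."""
    lo, hi = min_max
    return img01 * (hi - lo) + lo


def from_range(img, min_max=(-1, 1)):
    """Clamp to [min, max] and map back to [0, 1], as tensor2img does."""
    lo, hi = min_max
    return (np.clip(img, lo, hi) - lo) / (hi - lo)


x = np.array([0.0, 0.25, 1.0])
y = to_range(x)

print(y.tolist())              # [-1.0, -0.5, 1.0]
print(from_range(y).tolist())  # [0.0, 0.25, 1.0]
```

Because the inverse clamps first, model outputs that overshoot the range (say 1.3) come back as exactly 1.0 rather than as values above 1.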
--------------------------------------------------------------------------------
/data/__init__.py:
--------------------------------------------------------------------------------
 1 | '''create dataset and dataloader'''
 2 | import logging
 3 | import torch
 4 | import torch.utils.data
 5 | 
 6 | 
 7 | 
 8 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
 9 |     phase = dataset_opt['phase']
10 |     if phase == 'train':
11 |         if dataset_opt['dist']:
12 |             world_size = torch.distributed.get_world_size()
13 |             num_workers = dataset_opt['n_workers']
14 |             assert dataset_opt['batch_size'] % world_size == 0
15 |             batch_size = dataset_opt['batch_size'] // world_size
16 |             shuffle = False
17 |             sampler=torch.utils.data.distributed.DistributedSampler(dataset)
18 |         else:
19 |             num_workers = dataset_opt['n_workers'] # * len(opt['gpu_ids'])
20 |             batch_size = dataset_opt['batch_size']
21 |             shuffle = True
22 |         return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
23 |                                            num_workers=num_workers, sampler=sampler, drop_last=True,
24 |                                            pin_memory=True)
25 |     else:
26 |         return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
27 |                                            pin_memory=True)
28 | 
29 | 
30 | # def create_dataloader(train, dataset, dataset_opt, opt=None, sampler=None):
31 | #     # gpu_ids = opt.get('gpu_ids', None)
32 | #     gpu_ids = []
33 | #     gpu_ids = gpu_ids if gpu_ids else []
34 | #     num_workers = dataset_opt['n_workers'] * (len(gpu_ids)+1)
35 | #     batch_size = dataset_opt['batch_size']
36 | #     shuffle = True
37 | #     if train:
38 | #         return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
39 | #                                            num_workers=num_workers, sampler=sampler, drop_last=True,
40 | #                                            pin_memory=False)
41 | #     else:
42 | #         return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False,
43 | #                                            num_workers=num_workers, sampler=sampler, drop_last=False,
44 | #                                            pin_memory=False)
45 | 
46 | 
--------------------------------------------------------------------------------
/data/single_image_dataset.py:
--------------------------------------------------------------------------------
 1 | from os import path as osp
 2 | from torch.utils import data as data
 3 | from torchvision.transforms.functional import normalize
 4 | 
 5 | from basicsr.data.data_util import paths_from_lmdb
 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir
 7 | 
 8 | 
 9 | class SingleImageDataset(data.Dataset):
10 |     """Read only lq images in the test phase.
11 | 
12 |     Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc).
13 | 
14 |     There are two modes:
15 |     1. 'meta_info_file': Use meta information file to generate paths.
16 |     2. 'folder': Scan folders to generate paths.
17 | 
18 |     Args:
19 |         opt (dict): Config for train datasets. It contains the following keys:
20 |             dataroot_lq (str): Data root path for lq.
21 |             meta_info_file (str): Path for meta information file.
22 |             io_backend (dict): IO backend type and other kwarg.
23 |     """
24 | 
25 |     def __init__(self, opt):
26 |         super(SingleImageDataset, self).__init__()
27 |         self.opt = opt
28 |         # file client (io backend)
29 |         self.file_client = None
30 |         self.io_backend_opt = opt['io_backend']
31 |         self.mean = opt['mean'] if 'mean' in opt else None
32 |         self.std = opt['std'] if 'std' in opt else None
33 |         self.lq_folder = opt['dataroot_lq']
34 | 
35 |         if self.io_backend_opt['type'] == 'lmdb':
36 |             self.io_backend_opt['db_paths'] = [self.lq_folder]
37 |             self.io_backend_opt['client_keys'] = ['lq']
38 |             self.paths = paths_from_lmdb(self.lq_folder)
39 |         elif 'meta_info_file' in self.opt:
40 |             with open(self.opt['meta_info_file'], 'r') as fin:
41 |                 self.paths = [
42 |                     osp.join(self.lq_folder,
43 |                              line.split(' ')[0]) for line in fin
44 |                 ]
45 |         else:
46 |             self.paths = sorted(list(scandir(self.lq_folder, full_path=True)))
47 | 
48 |     def __getitem__(self, index):
49 |         if self.file_client is None:
50 |             self.file_client = FileClient(
51 |                 self.io_backend_opt.pop('type'), **self.io_backend_opt)
52 | 
53 |         # load lq image
54 |         lq_path = self.paths[index]
55 |         img_bytes = self.file_client.get(lq_path, 'lq')
56 |         img_lq = imfrombytes(img_bytes, float32=True)
57 | 
58 |         # TODO: color space transform
59 |         # BGR to RGB, HWC to CHW, numpy to tensor
60 |         img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True)
61 |         # normalize
62 |         if self.mean is not None or self.std is not None:
63 |             normalize(img_lq, self.mean, self.std, inplace=True)
64 |         return {'lq': img_lq, 'lq_path': lq_path}
65 | 
66 |     def __len__(self):
67 |         return len(self.paths)
68 | 
--------------------------------------------------------------------------------
/data/util.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | import math
  3 | import pickle
  4 | import random
  5 | import numpy as np
  6 | import glob
  7 | import torch
  8 | import cv2
  9 | 
 10 | ####################
 11 | # Files & IO
 12 | ####################
 13 | 
 14 | ###################### get image path list ######################
 15 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP']
 16 | 
 17 | 
 18 | def flip(x, dim):
 19 |     indices = [slice(None)] * x.dim()
 20 |     indices[dim] = torch.arange(x.size(dim) - 1, -1, -1,
 21 |                                 dtype=torch.long, device=x.device)
 22 |     return x[tuple(indices)]
 23 | 
 24 | 
 25 | def is_image_file(filename):
 26 |     return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
 27 | 
 28 | 
 29 | def _get_paths_from_images(path):
 30 |     """get image path list from image folder"""
 31 |     assert os.path.isdir(path), '{:s} is not a valid directory'.format(path)
 32 |     images = []
 33 |     for dirpath, _, fnames in sorted(os.walk(path)):
 34 |         for fname in sorted(fnames):
 35 |             if is_image_file(fname):
 36 |                 img_path = os.path.join(dirpath, fname)
 37 |                 images.append(img_path)
 38 |     assert images, '{:s} has no valid image file'.format(path)
 39 |     return images
 40 | 
 41 | 
 42 | def _get_paths_from_lmdb(dataroot):
 43 |     """get image path list from lmdb meta info"""
 44 |     meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb'))
 45 |     paths = meta_info['keys']
 46 |     sizes = meta_info['resolution']
 47 |     if len(sizes) == 1:
 48 |         sizes = sizes * len(paths)
 49 |     return paths, sizes
 50 | 
 51 | 
 52 | def get_image_paths(data_type, dataroot):
 53 |     """get image path list
 54 |     support lmdb or image files"""
 55 |     paths, sizes = None, None
 56 |     if dataroot is not None:
 57 |         if data_type == 'lmdb':
 58 |             paths, sizes = _get_paths_from_lmdb(dataroot)
 59 |         elif data_type == 'img':
 60 |             paths = sorted(_get_paths_from_images(dataroot))
 61 |         else:
 62 |             raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type))
 63 |     return paths, sizes
 64 | 
 65 | 
 66 | def glob_file_list(root):
 67 |     return sorted(glob.glob(os.path.join(root, '*')))
 68 | 
 69 | 
 70 | ###################### read images ######################
 71 | def _read_img_lmdb(env, key, size):
 72 |     """read image from lmdb with key (w/ and w/o fixed size)
 73 |     size: (C, H, W) tuple"""
 74 |     with env.begin(write=False) as txn:
 75 |         buf = txn.get(key.encode('ascii'))
 76 |     img_flat = np.frombuffer(buf, dtype=np.uint8)
 77 |     C, H, W = size
 78 |     img = img_flat.reshape(H, W, C)
 79 |     return img
 80 | 
 81 | 
 82 | def read_img(env, path, size=None):
 83 |     """read image by cv2 or from lmdb
 84 |     return: Numpy float32, HWC, BGR, [0,1]"""
 85 |     if env is None:  # img
 86 |         img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
 87 |         if img is None:
 88 |             print(path)
 89 |         if size is not None:
 90 |             img = cv2.resize(img, (size[0], size[1]))
 91 |     else:
 92 |         img = _read_img_lmdb(env, path, size)
 93 | 
 94 |     img = img.astype(np.float32) / 255.
 95 |     if img.ndim == 2:
 96 |         img = np.expand_dims(img, axis=2)
 97 |     # some images have 4 channels
 98 |     if img.shape[2] > 3:
 99 |         img = img[:, :, :3]
100 |     return img
101 | 
102 | 
103 | def read_img2(env, path, size=None):
104 |     """read image from a .npy file (np.load) or from lmdb
105 |     return: Numpy float32, HWC, BGR, [0,1]"""
106 |     if env is None:  # img
107 |         img = np.load(path)
108 |         if img is None:
109 |             print(path)
110 |         if size is not None:
111 |             img = cv2.resize(img, (size[0], size[1]))
112 |             # img = cv2.resize(img, size)
113 |     else:
114 |         img = _read_img_lmdb(env, path, size)
115 |     img = get_max(img)
116 |     img = img.astype(np.float32) / 255.
117 |     if img.ndim == 2:
118 |         img = np.expand_dims(img, axis=2)
119 |     # some images have 4 channels
120 |     if img.shape[2] > 3:
121 |         img = img[:, :, :3]
122 |     return img
123 | 
124 | 
125 | def read_img_seq(path, size=None):
126 |     """Read a sequence of images from a given folder path
127 |     Args:
128 |         path (list/str): list of image paths/image folder path
129 | 
130 |     Returns:
131 |         imgs (Tensor): size (T, C, H, W), RGB, [0, 1]
132 |     """
133 |     # print(path)
134 |     if type(path) is list:
135 |         img_path_l = path
136 |     else:
137 |         img_path_l = sorted(glob.glob(os.path.join(path, '*')))
138 | 
139 |     img_l = [read_img(None, v, size) for v in img_path_l]
140 |     # stack to Torch tensor
141 |     imgs = np.stack(img_l, axis=0)
142 |     try:
143 |         imgs = imgs[:, :, :, [2, 1, 0]]  # BGR -> RGB
144 |     except IndexError as err:
145 |         raise ValueError('expected 3-channel images, got shape {}'.format(imgs.shape)) from err
146 |     imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
147 |     return imgs
148 | 
149 | 
150 | def read_img_seq2(path, size=None):
151 |     """Read a sequence of images from a given folder path
152 |     Args:
153 |         path (list/str): list of image paths/image folder path
154 | 
155 |     Returns:
156 |         imgs (Tensor): size (T, C, H, W), RGB, [0, 1]
157 |     """
158 |     # print(path)
159 |     if type(path) is list:
160 |         img_path_l = path
161 |     else:
162 |         img_path_l = sorted(glob.glob(os.path.join(path, '*')))
163 | 
164 |     img_l = [read_img2(None, v, size) for v in img_path_l]
165 |     # stack to Torch tensor
166 |     imgs = np.stack(img_l, axis=0)
167 |     try:
168 |         imgs = imgs[:, :, :, [2, 1, 0]]  # BGR -> RGB
169 |     except IndexError as err:
170 |         raise ValueError('expected 3-channel images, got shape {}'.format(imgs.shape)) from err
171 |     imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float()
172 |     return imgs
173 | 
174 | def get_max(x):
175 |     T =np.max(x,axis=0)
176 |     T=T+0.1
177 |     x[0,:,:] = x[0,:,:]/ T
178 |     x[1,:,:] = x[1,:,:]/ T
179 |     x[2,:,:]= x[2,:,:] / T
180 |     return x
181 | 
182 | 
183 | def index_generation(crt_i, max_n, N, padding='reflection'):
184 |     """Generate an index list for reading N frames from a sequence of images
185 |     Args:
186 |         crt_i (int): current center index
187 |         max_n (int): max number of the sequence of images (calculated from 1)
188 |         N (int): reading N frames
189 |         padding (str): padding mode, one of replicate | reflection | new_info | circle
190 |             Example: crt_i = 0, N = 5
191 |             replicate: [0, 0, 0, 1, 2]
192 |             reflection: [2, 1, 0, 1, 2]
193 |             new_info: [4, 3, 0, 1, 2]
194 |             circle: [3, 4, 0, 1, 2]
195 | 
196 |     Returns:
197 |         return_l (list [int]): a list of indexes
198 |     """
199 |     max_n = max_n - 1
200 |     n_pad = N // 2
201 |     return_l = []
202 | 
203 |     for i in range(crt_i - n_pad, crt_i + n_pad + 1):
204 |         if i < 0:
205 |             if padding == 'replicate':
206 |                 add_idx = 0
207 |             elif padding == 'reflection':
208 |                 add_idx = -i
209 |             elif padding == 'new_info':
210 |                 add_idx = (crt_i + n_pad) + (-i)
211 |             elif padding == 'circle':
212 |                 add_idx = N + i
213 |             else:
214 |                 raise ValueError('Wrong padding mode')
215 |         elif i > max_n:
216 |             if padding == 'replicate':
217 |                 add_idx = max_n
218 |             elif padding == 'reflection':
219 |                 add_idx = max_n * 2 - i
220 |             elif padding == 'new_info':
221 |                 add_idx = (crt_i - n_pad) - (i - max_n)
222 |             elif padding == 'circle':
223 |                 add_idx = i - N
224 |             else:
225 |                 raise ValueError('Wrong padding mode')
226 |         else:
227 |             add_idx = i
228 |         return_l.append(add_idx)
229 |     return return_l
230 | 
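
The padding examples in the `index_generation` docstring can be reproduced with a compact re-implementation. This is a sketch under the same index arithmetic as the function above; `frame_indices` and its parameter names are chosen here for illustration, and unlike the original it raises `KeyError` (not `ValueError`) for an unknown padding mode.

```python
def frame_indices(center, num_frames, n, padding='reflection'):
    """Indices of n frames centred on `center`, padded at sequence borders."""
    last = num_frames - 1
    half = n // 2
    out = []
    for i in range(center - half, center + half + 1):
        if i < 0:  # ran off the start of the sequence
            idx = {'replicate': 0,
                   'reflection': -i,
                   'new_info': (center + half) - i,
                   'circle': n + i}[padding]
        elif i > last:  # ran off the end of the sequence
            idx = {'replicate': last,
                   'reflection': 2 * last - i,
                   'new_info': (center - half) - (i - last),
                   'circle': i - n}[padding]
        else:
            idx = i
        out.append(idx)
    return out


for mode in ('replicate', 'reflection', 'new_info', 'circle'):
    print(mode, frame_indices(0, 10, 5, mode))
# replicate [0, 0, 0, 1, 2]
# reflection [2, 1, 0, 1, 2]
# new_info [4, 3, 0, 1, 2]
# circle [3, 4, 0, 1, 2]
```

At the other border the same rules mirror: with 10 frames, `frame_indices(9, 10, 5, 'reflection')` gives `[7, 8, 9, 8, 7]`.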
231 | 
232 | ####################
233 | # image processing
234 | # process on numpy image
235 | ####################
236 | 
237 | 
238 | def augment(img_list, hflip=True, rot=True):
239 |     """horizontal flip OR rotate (0, 90, 180, 270 degrees)"""
240 |     hflip = hflip and random.random() < 0.5
241 |     vflip = rot and random.random() < 0.5
242 |     rot90 = rot and random.random() < 0.5
243 | 
244 |     def _augment(img):
245 |         if hflip:
246 |             img = img[:, ::-1, :]
247 |         if vflip:
248 |             img = img[::-1, :, :]
249 |         if rot90:
250 |             # import pdb; pdb.set_trace()
251 |             img = img.transpose(1, 0, 2)
252 |         return img
253 | 
254 |     return [_augment(img) for img in img_list]
255 | 
256 | 
257 | 
258 | def augment_torch(img_list, hflip=True, rot=True):
259 |     """horizontal flip and/or vertical flip (the 90-degree transpose is disabled)"""
260 |     hflip = hflip and random.random() < 0.5
261 |     vflip = rot and random.random() < 0.5
262 |     # rot90 = rot and random.random() < 0.5
263 | 
264 |     def _augment(img):
265 |         if hflip:
266 |             img = flip(img, 2)
267 |         if vflip:
268 |             img = flip(img, 1)
269 |         # if rot90:
270 |         #     # import pdb; pdb.set_trace()
271 |         #     img = img.transpose(1, 0, 2)
272 |         return img
273 | 
274 |     return [_augment(img) for img in img_list]
275 | 
276 | 
277 | def augment_flow(img_list, flow_list, hflip=True, rot=True):
278 |     """horizontal flip OR rotate (0, 90, 180, 270 degrees) with flows"""
279 |     hflip = hflip and random.random() < 0.5
280 |     vflip = rot and random.random() < 0.5
281 |     rot90 = rot and random.random() < 0.5
282 | 
283 |     def _augment(img):
284 |         if hflip:
285 |             img = img[:, ::-1, :]
286 |         if vflip:
287 |             img = img[::-1, :, :]
288 |         if rot90:
289 |             img = img.transpose(1, 0, 2)
290 |         return img
291 | 
292 |     def _augment_flow(flow):
293 |         if hflip:
294 |             flow = flow[:, ::-1, :]
295 |             flow[:, :, 0] *= -1
296 |         if vflip:
297 |             flow = flow[::-1, :, :]
298 |             flow[:, :, 1] *= -1
299 |         if rot90:
300 |             flow = flow.transpose(1, 0, 2)
301 |             flow = flow[:, :, [1, 0]]
302 |         return flow
303 | 
304 |     rlt_img_list = [_augment(img) for img in img_list]
305 |     rlt_flow_list = [_augment_flow(flow) for flow in flow_list]
306 | 
307 |     return rlt_img_list, rlt_flow_list
308 | 
309 | 
310 | def channel_convert(in_c, tar_type, img_list):
311 |     """conversion among BGR, gray and y"""
312 |     if in_c == 3 and tar_type == 'gray':  # BGR to gray
313 |         gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list]
314 |         return [np.expand_dims(img, axis=2) for img in gray_list]
315 |     elif in_c == 3 and tar_type == 'y':  # BGR to y
316 |         y_list = [bgr2ycbcr(img, only_y=True) for img in img_list]
317 |         return [np.expand_dims(img, axis=2) for img in y_list]
318 |     elif in_c == 1 and tar_type == 'RGB':  # gray/y to BGR
319 |         return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list]
320 |     else:
321 |         return img_list
322 | 
323 | 
324 | def rgb2ycbcr(img, only_y=True):
325 |     """same as matlab rgb2ycbcr
326 |     only_y: only return Y channel
327 |     Input:
328 |         uint8, [0, 255]
329 |         float, [0, 1]
330 |     """
331 |     in_img_type = img.dtype
332 |     img = img.astype(np.float32)  # astype returns a copy; keep it so the caller's array is not scaled in place
333 |     if in_img_type != np.uint8:
334 |         img *= 255.
335 |     # convert
336 |     if only_y:
337 |         rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
338 |     else:
339 |         rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
340 |                               [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
341 |     if in_img_type == np.uint8:
342 |         rlt = rlt.round()
343 |     else:
344 |         rlt /= 255.
345 |     return rlt.astype(in_img_type)
346 | 
347 | 
348 | def bgr2ycbcr(img, only_y=True):
349 |     """bgr version of rgb2ycbcr
350 |     only_y: only return Y channel
351 |     Input:
352 |         uint8, [0, 255]
353 |         float, [0, 1]
354 |     """
355 |     in_img_type = img.dtype
356 |     img = img.astype(np.float32)  # astype returns a copy; keep it so the caller's array is not scaled in place
357 |     if in_img_type != np.uint8:
358 |         img *= 255.
359 |     # convert
360 |     if only_y:
361 |         rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
362 |     else:
363 |         rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
364 |                               [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
365 |     if in_img_type == np.uint8:
366 |         rlt = rlt.round()
367 |     else:
368 |         rlt /= 255.
369 |     return rlt.astype(in_img_type)
370 | 
371 | 
372 | def ycbcr2rgb(img):
373 |     """same as matlab ycbcr2rgb
374 |     Input:
375 |         uint8, [0, 255]
376 |         float, [0, 1]
377 |     """
378 |     in_img_type = img.dtype
379 |     img = img.astype(np.float32)  # astype returns a copy; keep it so the caller's array is not scaled in place
380 |     if in_img_type != np.uint8:
381 |         img *= 255.
382 |     # convert
383 |     rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
384 |                           [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
385 |     if in_img_type == np.uint8:
386 |         rlt = rlt.round()
387 |     else:
388 |         rlt /= 255.
389 |     return rlt.astype(in_img_type)
390 | 
391 | 
392 | def modcrop(img_in, scale):
393 |     """img_in: Numpy, HWC or HW"""
394 |     img = np.copy(img_in)
395 |     if img.ndim == 2:
396 |         H, W = img.shape
397 |         H_r, W_r = H % scale, W % scale
398 |         img = img[:H - H_r, :W - W_r]
399 |     elif img.ndim == 3:
400 |         H, W, C = img.shape
401 |         H_r, W_r = H % scale, W % scale
402 |         img = img[:H - H_r, :W - W_r, :]
403 |     else:
404 |         raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim))
405 |     return img
406 | 
407 | 
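The colour conversions in this file return the same dtype and value range they were given (`uint8` in [0, 255] or float in [0, 1]). A minimal NumPy sketch of the Y-channel path, assuming RGB channel order; `rgb_to_y` and `Y_WEIGHTS` are hypothetical names for illustration, not functions in this repo, and the sketch works on a float32 copy so the input array is left untouched:

```python
import numpy as np

# BT.601 luma weights in RGB order (the same constants rgb2ycbcr uses).
Y_WEIGHTS = np.array([65.481, 128.553, 24.966], dtype=np.float32)


def rgb_to_y(img):
    """Return the Y channel of an HxWx3 RGB image, preserving dtype and range."""
    in_dtype = img.dtype
    work = img.astype(np.float32)      # copy: never scale the caller's array
    if in_dtype != np.uint8:
        work = work * 255.0            # float input is assumed to be in [0, 1]
    y = work @ Y_WEIGHTS / 255.0 + 16.0
    if in_dtype == np.uint8:
        y = y.round()
    else:
        y = y / 255.0                  # back to [0, 1] for float input
    return y.astype(in_dtype)
```

With these weights a pure-white pixel maps to Y = 235 and a black pixel to Y = 16 for `uint8` input.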
--------------------------------------------------------------------------------
/dataset/LOLv1/readme.txt:
--------------------------------------------------------------------------------
1 | Download the [LOL](https://daooshee.github.io/BMVC2018website/) dataset from the linked page and place it in this folder.
2 | 
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
 1 | import logging
 2 | logger = logging.getLogger('base')
 3 | 
 4 | 
 5 | def create_model(opt):
 6 |     from .model import DDPM as M
 7 |     from .model import DDPM_PD as M_PD
 8 |     # print(opt['distill'])
 9 |     # import pdb; pdb.set_trace()
10 |     if opt['distill']:
11 |         m = M_PD(opt)
12 |     else:
13 |         m = M(opt)
14 |     logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
15 |     return m
16 | 
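The factory above selects the distillation model when `opt['distill']` is truthy. A small sketch of the same dispatch, assuming a plain `dict` for `opt`; the two classes are empty stand-ins, not the repo's `DDPM` and `DDPM_PD`, and the use of `.get()` (so a config without a `distill` key falls back to the base model) is a suggestion rather than the repo's behaviour:

```python
class DDPM:
    """Stand-in for the base model class (hypothetical, for illustration)."""

    def __init__(self, opt):
        self.opt = opt


class DDPM_PD:
    """Stand-in for the distillation model class (hypothetical)."""

    def __init__(self, opt):
        self.opt = opt


def create_model(opt):
    # A missing 'distill' key is treated as False instead of raising KeyError.
    model_cls = DDPM_PD if opt.get('distill', False) else DDPM
    return model_cls(opt)
```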
--------------------------------------------------------------------------------
/model/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/model/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/base_model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/model/__pycache__/base_model.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/model.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/model/__pycache__/model.cpython-38.pyc
--------------------------------------------------------------------------------
/model/__pycache__/networks.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/model/__pycache__/networks.cpython-38.pyc
--------------------------------------------------------------------------------
/model/base_model.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import torch
 3 | import torch.nn as nn
 4 | 
 5 | 
 6 | class BaseModel():
 7 |     def __init__(self, opt):
 8 |         self.opt = opt
 9 |         self.device = torch.device(
10 |             'cuda' if opt['gpu_ids'] is not None else 'cpu')
11 |         self.begin_step = 0
12 |         self.begin_epoch = 0
13 | 
14 |     def feed_data(self, data):
15 |         pass
16 | 
17 |     def optimize_parameters(self):
18 |         pass
19 | 
20 |     def get_current_visuals(self):
21 |         pass
22 | 
23 |     def get_current_losses(self):
24 |         pass
25 | 
26 |     def print_network(self):
27 |         pass
28 | 
29 |     def set_device(self, x):
30 |         if isinstance(x, dict):
31 |             for key, item in x.items():
32 |                 if item is not None:
33 |                     x[key] = item.to(self.device)
34 |         elif isinstance(x, list):
35 |             for i, item in enumerate(x):
36 |                 if item is not None:
37 |                     x[i] = item.to(self.device)  # write back; rebinding `item` alone leaves the list unchanged
38 |         else:
39 |             x = x.to(self.device)
40 |         return x
41 | 
42 |     def get_network_description(self, network):
43 |         '''Get the string and total parameters of the network'''
44 |         if isinstance(network, nn.DataParallel):
45 |             network = network.module
46 |         s = str(network)
47 |         n = sum(map(lambda x: x.numel(), network.parameters()))
48 |         return s, n
49 | 
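`set_device` walks a dict, a list, or a single tensor. Because `Tensor.to` returns a new tensor rather than moving it in place, each result has to be stored back into its container. A framework-free sketch of that pattern; `move_nested` and the `move` callable are hypothetical names, and with PyTorch `move` would typically be something like `lambda t: t.to(device)`:

```python
def move_nested(x, move):
    """Apply `move` to every non-None leaf of a dict, a list, or a single object.

    A new container is returned, so the caller must use the return value.
    """
    if isinstance(x, dict):
        return {k: (move(v) if v is not None else None) for k, v in x.items()}
    if isinstance(x, list):
        return [move(v) if v is not None else None for v in x]
    return move(x)
```

Returning a fresh container sidesteps the question of whether the input was updated in place, at the cost of one shallow copy.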
--------------------------------------------------------------------------------
/model/ddpm_modules/__pycache__/diffusion.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/model/ddpm_modules/__pycache__/diffusion.cpython-38.pyc
--------------------------------------------------------------------------------
/model/ddpm_modules/__pycache__/unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/model/ddpm_modules/__pycache__/unet.cpython-38.pyc
--------------------------------------------------------------------------------
/model/ddpm_modules/diffusion.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | import torch
  3 | from torch import device, nn, einsum
  4 | import torch.nn.functional as F
  5 | from inspect import isfunction
  6 | from functools import partial
  7 | import numpy as np
  8 | from tqdm import tqdm
  9 | import cv2
 10 | import torchvision.transforms as T
 11 | 
 12 | from sklearn.cluster import AgglomerativeClustering
 13 | from sklearn.cluster import MeanShift
 14 | from sklearn.cluster import DBSCAN
 15 | from sklearn.cluster import SpectralClustering
 16 | import lpips
 17 | from torchvision.utils import save_image
 18 | from torch.optim.swa_utils import AveragedModel
 19 | 
 20 |  
 21 | 
 22 | transform = T.Lambda(lambda t: (t + 1) / 2)
 23 | 
 24 | def extract(v, t, x_shape):
 25 |    
 26 |     try:
 27 |         out = torch.gather(v, index=t, dim=0).float()
 28 |     except RuntimeError:
 29 |         # Report the offending index and table shape, then re-raise;
 30 |         # falling through would hit an undefined `out` below.
 31 |         print(t)
 32 |         print(v.shape)
 33 |         raise
 34 |     return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
 35 | 
 36 | 
 37 | 
 38 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
 39 |     betas = linear_end * np.ones(n_timestep, dtype=np.float64)
 40 |     warmup_time = int(n_timestep * warmup_frac)
 41 |     betas[:warmup_time] = np.linspace(
 42 |         linear_start, linear_end, warmup_time, dtype=np.float64)
 43 |     return betas
 44 | 
 45 | 
 46 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
 47 |     if schedule == 'quad':
 48 |         betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
 49 |                             n_timestep, dtype=np.float64) ** 2
 50 |     elif schedule == 'linear':
 51 |         betas = np.linspace(linear_start, linear_end,
 52 |                             n_timestep, dtype=np.float64)
 53 |     elif schedule == 'warmup10':
 54 |         betas = _warmup_beta(linear_start, linear_end,
 55 |                              n_timestep, 0.1)
 56 |     elif schedule == 'warmup50':
 57 |         betas = _warmup_beta(linear_start, linear_end,
 58 |                              n_timestep, 0.5)
 59 |     elif schedule == 'const':
 60 |         betas = linear_end * np.ones(n_timestep, dtype=np.float64)
 61 |     elif schedule == 'jsd':  # 1/T, 1/(T-1), 1/(T-2), ..., 1
 62 |         betas = 1. / np.linspace(n_timestep,
 63 |                                  1, n_timestep, dtype=np.float64)
 64 |     elif schedule == "cosine":
 65 |         timesteps = (
 66 |             torch.arange(n_timestep + 1, dtype=torch.float64) /
 67 |             n_timestep + cosine_s
 68 |         )
 69 |         alphas = timesteps / (1 + cosine_s) * math.pi / 2
 70 |         alphas = torch.cos(alphas).pow(2)
 71 |         alphas = alphas / alphas[0]
 72 |         betas = 1 - alphas[1:] / alphas[:-1]
 73 |         betas = betas.clamp(max=0.999)
 74 |     else:
 75 |         raise NotImplementedError(schedule)
 76 |     return betas
 77 | 
 78 | 
 79 | # gaussian diffusion trainer class
 80 | 
 81 | def exists(x):
 82 |     return x is not None
 83 | 
 84 | 
 85 | def default(val, d):
 86 |     if exists(val):
 87 |         return val
 88 |     return d() if isfunction(d) else d
 89 | 
 90 | 
 91 | class GaussianDiffusion(nn.Module):
 92 |     def __init__(
 93 |         self,
 94 |         denoise_fn,
 95 |         image_size,
 96 |         num_timesteps,
 97 |         time_scale,
 98 |         w_str,
 99 |         w_gt,
100 |         w_snr,
101 |         w_lpips,
102 |         channels=3,
103 |         loss_type='l1',
104 |         conditional=True,
105 |         schedule_opt=None
106 |     ):
107 |         super().__init__()
108 |         self.channels = channels
109 |         self.image_size = image_size
110 |         self.denoise_fn = denoise_fn
111 |         self.loss_type = loss_type
112 |         self.conditional = conditional
113 |         self.num_timesteps = num_timesteps
114 |         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
115 | 
116 |         self.w_str = w_str 
117 |         self.w_gt = w_gt
118 |         self.w_snr = w_snr
119 |         self.w_lpips = w_lpips
120 |         # self.lpips = lpips.LPIPS(net='vgg').cuda()
121 |         # print(self.num_timesteps)
122 |         # import pdb; pdb.set_trace()
123 |         self.time_scale = time_scale
124 |         self.CD = False
125 |         if schedule_opt is not None:
126 |             self.set_new_noise_schedule(schedule_opt, device)
127 | 
128 |     def set_loss(self, device):
129 |         if self.loss_type == 'l1':
130 |             self.loss_func = nn.L1Loss(reduction='sum').to(device)
131 |         elif self.loss_type == 'l2':
132 |             self.loss_func = nn.MSELoss(reduction='sum').to(device)
133 |         else:
134 |             raise NotImplementedError()
135 | 
136 |     def set_new_noise_schedule(self, schedule_opt, device):
137 |         to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
138 | 
139 |         betas = make_beta_schedule(
140 |             schedule=schedule_opt['schedule'],
141 |             n_timestep= self.num_timesteps* self.time_scale + 1,
142 |             linear_start=schedule_opt['linear_start'],
143 |             linear_end=schedule_opt['linear_end'])
144 |         betas = betas.detach().cpu().numpy() if isinstance(
145 |             betas, torch.Tensor) else betas
146 |         alphas = 1. - betas
147 |         alphas_cumprod = np.cumprod(alphas, axis=0)
148 |         alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
149 |         self.sqrt_alphas_cumprod_prev = np.sqrt(
150 |             np.append(1., alphas_cumprod))
151 |         
152 |         timesteps, = betas.shape
153 |         self.register_buffer('betas', to_torch(betas))
154 |         self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
155 |         self.register_buffer('alphas_cumprod_prev',
156 |                              to_torch(alphas_cumprod_prev))
157 | 
158 |         # calculations for diffusion q(x_t | x_{t-1}) and others
159 |         self.register_buffer('sqrt_alphas_cumprod',
160 |                              to_torch(np.sqrt(alphas_cumprod)))
161 |         self.register_buffer('sqrt_one_minus_alphas_cumprod',
162 |                              to_torch(np.sqrt(1. - alphas_cumprod)))
163 |         self.register_buffer('log_one_minus_alphas_cumprod',
164 |                              to_torch(np.log(1. - alphas_cumprod)))
165 |         self.register_buffer('sqrt_recip_alphas_cumprod',
166 |                              to_torch(np.sqrt(1. / alphas_cumprod)))
167 |         self.register_buffer('sqrt_recipm1_alphas_cumprod',
168 |                              to_torch(np.sqrt(1. / alphas_cumprod - 1)))
169 | 
170 |         # calculations for posterior q(x_{t-1} | x_t, x_0)
171 |         posterior_variance = betas * \
172 |             (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
173 |         # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
174 |         self.register_buffer('posterior_variance',
175 |                              to_torch(posterior_variance))
176 |         # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
177 |         self.register_buffer('posterior_log_variance_clipped', to_torch(
178 |             np.log(np.maximum(posterior_variance, 1e-20))))
179 |         self.register_buffer('posterior_mean_coef1', to_torch(
180 |             betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod)))
181 |         self.register_buffer('posterior_mean_coef2', to_torch(
182 |             (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)))
183 |         
184 | 
185 |     def predict_start_from_noise(self, x_t, t, noise):
186 |         return self.sqrt_recip_alphas_cumprod[t] * x_t - \
187 |             self.sqrt_recipm1_alphas_cumprod[t] * noise
188 |     
189 |     def predict_eps_from_x(self, x_t, x_0, t):
190 | 
191 |         eps = (x_t -self.sqrt_alphas_cumprod[t] * x_0) / self.sqrt_one_minus_alphas_cumprod[t]
192 |         return eps
193 |     
194 |     def predict_eps(self, x_t, x_0, continuous_sqrt_alpha_cumprod):
195 |         
196 |         eps = (1. / (1 - continuous_sqrt_alpha_cumprod **2).sqrt()) * x_t - \
197 |             (1. / (1 - continuous_sqrt_alpha_cumprod**2) -1).sqrt() * x_0
198 | 
199 |         return eps
200 |     
201 |     def predict_start(self, x_t, continuous_sqrt_alpha_cumprod, noise):
202 | 
203 |         return (1. / continuous_sqrt_alpha_cumprod) * x_t - \
204 |             (1. / continuous_sqrt_alpha_cumprod**2 - 1).sqrt() * noise
205 | 
206 |     def predict_t_minus1(self, x, t, continuous_sqrt_alpha_cumprod, noise, clip_denoised=True):
207 | 
208 |         x_recon = self.predict_start(x, 
209 |                     continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), 
210 |                     noise=noise)
211 | 
212 |         if clip_denoised:
213 |             x_recon.clamp_(-1., 1.)
214 | 
215 |         model_mean, model_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
216 | 
217 |         noise_z = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
218 | 
219 |         return model_mean + noise_z * (0.5 * model_log_variance).exp() 
220 | 
221 |     def q_posterior(self, x_start, x_t, t):
222 |         posterior_mean = self.posterior_mean_coef1[t] * \
223 |             x_start + self.posterior_mean_coef2[t] * x_t
224 |         posterior_log_variance_clipped = self.posterior_log_variance_clipped[t]
225 |         return posterior_mean, posterior_log_variance_clipped
226 | 
227 |     def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None):
228 |         batch_size = x.shape[0]
229 |         noise_level = torch.FloatTensor(
230 |             [self.sqrt_alphas_cumprod_prev[t+self.time_scale]]).repeat(batch_size, 1).to(x.device)
231 |         
232 |         eps = self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level)[0]
233 |         # print(t)
234 | 
235 |         x_recon = self.predict_start_from_noise(x, t=t*self.time_scale, noise=eps)
236 | 
237 | 
238 |         if clip_denoised:
239 |             x_recon.clamp_(-1., 1.)
240 | 
241 |         model_mean, posterior_log_variance = self.q_posterior(
242 |             x_start=x_recon, x_t=x, t=t)
243 |         return model_mean, posterior_log_variance, eps
244 | 
245 |     @torch.no_grad()
246 |     def p_sample(self, x, t, clip_denoised=True, condition_x=None):
247 |         model_mean, model_log_variance, eps = self.p_mean_variance(
248 |             x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x)
249 |         noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x)
250 |         return model_mean + noise * (0.5 * model_log_variance).exp()
251 | 
252 |     @torch.no_grad()
253 |     def p_sample_loop(self, x_in, continous=False):
254 |         device = self.betas.device
255 |         sample_inter = (1 | (self.num_timesteps//10))
256 |         if not self.conditional:
257 |             shape = x_in
258 |             img = torch.randn(shape, device=device)
259 |             ret_img = img
260 |             for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
261 |                 img = self.p_sample(img, i)
262 |                 if i % sample_inter == 0:
263 |                     ret_img = torch.cat([ret_img, img], dim=0)
264 |         else:
265 |             x = x_in
266 |             shape = x.shape
267 |             img = torch.randn(shape, device=device)
268 |             ret_img = x
269 |             # print(self.time_scale)
270 |             # import pdb; pdb.set_trace()
271 |             for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps):
272 |                 img = self.p_sample(img, i, condition_x=x)
273 |                 if i % sample_inter == 0:
274 |                     ret_img = torch.cat([ret_img, img], dim=0)
275 |         if continous:
276 |             return ret_img
277 |         else:
278 |             return ret_img[-1]
279 | 
280 |     @torch.no_grad()
281 |     def sample(self, batch_size=1, continous=False):
282 |         image_size = self.image_size
283 |         channels = self.channels
284 |         return self.p_sample_loop((batch_size, channels, image_size, image_size), continous)
285 | 
286 |     @torch.no_grad()
287 |     def super_resolution(self, x_in, continous=False, stride=1):
288 |         return self.ddim(x_in, continous, stride=stride)
289 | 
290 |     def q_sample(self, x_start, continuous_sqrt_alpha_cumprod, noise=None):
291 |         noise = default(noise, lambda: torch.randn_like(x_start))
292 |         return (
293 |             continuous_sqrt_alpha_cumprod * x_start +
294 |             (1 - continuous_sqrt_alpha_cumprod**2).sqrt() * noise
295 |         )
296 | 
297 |     def ddim(self, x_in, continous=False, snr_aware=False, stride=1, clip_denoised=True): 
298 |         x = x_in
299 |         condition_x = x_in 
300 |         x_t = torch.randn(x.shape, device=x.device)
301 | 
302 |         batch_size = x_in.shape[0]
303 |   
304 |         for time_step in reversed(range(stride, self.num_timesteps + 1, stride)):
305 | 
306 |             
307 |             t = x_t.new_ones([x_t.shape[0], ], dtype=torch.long) * time_step
308 |             s = x_t.new_ones([x_t.shape[0], ], dtype=torch.long) * (time_step - stride)
309 |             noise_level = torch.FloatTensor([self.sqrt_alphas_cumprod_prev[t* self.time_scale]]).repeat(batch_size, 1).to(x.device)
310 |             eps =self.denoise_fn(torch.cat([condition_x, x_t], dim=1), noise_level)[0]
311 |             x_0 = self.predict_start_from_noise(x_t, t * self.time_scale, eps)
312 |             if clip_denoised:
313 |                 x_0 = torch.clip(x_0, -1., 1.)
314 |                 eps = self.predict_eps_from_x(x_t, x_0, t * self.time_scale)
315 | 
316 |             x_t = self.sqrt_alphas_cumprod[s * self.time_scale] * x_0 + self.sqrt_one_minus_alphas_cumprod[s * self.time_scale] * eps
317 | 
318 |         return torch.clip(x_t, -1, 1)
319 | 
320 |     def SNR_map(self, x_0):
321 |         blur_transform = T.GaussianBlur(kernel_size=15, sigma=3)
322 |         blur_x_0 = blur_transform(x_0)
323 |         gray_blur_x_0 = blur_x_0[:, 0:1, :, :] * 0.299 + blur_x_0[:, 1:2, :, :] * 0.587 + blur_x_0[:, 2:3, :, :] * 0.114
324 |         gray_x_0 = x_0[:, 0:1, :, :] * 0.299 + x_0[:, 1:2, :, :] * 0.587 + x_0[:, 2:3, :, :] * 0.114
325 |         noise = torch.abs(gray_blur_x_0 - gray_x_0)
326 | 
327 |         return noise
328 |     
329 | 
330 |         
331 |     def loss(self, x_in, student, noise=None, lpips_func=None):
332 |         x_0 = x_in['GT']
333 |         [b, c, h, w] = x_0.shape
334 |       
335 |         t = 2 * np.random.randint(1, student.num_timesteps + 1) 
336 | 
337 |         continuous_sqrt_alpha_cumprod = torch.FloatTensor(
338 |             np.random.uniform(
339 |                 self.sqrt_alphas_cumprod_prev[(t-1)*self.time_scale],
340 |                 self.sqrt_alphas_cumprod_prev[t*self.time_scale],
341 |                 size=b
342 |             )
343 |         ).to(x_0.device)
344 | 
345 |         continuous_sqrt_alpha_cumprod_t_mins_1 = torch.FloatTensor(
346 |             np.random.uniform(
347 |                 self.sqrt_alphas_cumprod_prev[(t-2)*self.time_scale],
348 |                 self.sqrt_alphas_cumprod_prev[(t-1)*self.time_scale],
349 |                 size=b
350 |             )
351 |         ).to(x_0.device)
352 |         continuous_sqrt_alpha_cumprod_t_mins_2 = torch.FloatTensor(
353 |             np.random.uniform(
354 |                 self.sqrt_alphas_cumprod_prev[(t-3)*self.time_scale],
355 |                 self.sqrt_alphas_cumprod_prev[(t-2)*self.time_scale],
356 |                 size=b
357 |             )
358 |         ).to(x_0.device)
359 |        
360 |         continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(b, -1)
361 |         continuous_sqrt_alpha_cumprod_t_mins_1 = continuous_sqrt_alpha_cumprod_t_mins_1.view(b, -1)
362 |         continuous_sqrt_alpha_cumprod_t_mins_2 = continuous_sqrt_alpha_cumprod_t_mins_2.view(b, -1)
363 |         
364 |         noise = default(noise, lambda: torch.randn_like(x_0))
365 |         t = torch.tensor([t], dtype=torch.int64).to(x_0.device)
366 |         bs = x_0.size(0)
367 | 
368 |         with torch.no_grad():   
369 |             z_t = self.q_sample(x_0, continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), noise)
370 |             eps_rec, _ = self.denoise_fn(torch.cat([x_in['LQ'], z_t], dim=1), continuous_sqrt_alpha_cumprod) 
371 |             x_0_rec = self.predict_start(z_t, continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), eps_rec)
372 |             z_t_minus_1 = self.q_sample(x_0_rec, continuous_sqrt_alpha_cumprod_t_mins_1.view(-1, 1, 1, 1), eps_rec)
373 |             eps_rec_rec, _ = self.denoise_fn(torch.cat([x_in['LQ'], z_t_minus_1], dim=1), continuous_sqrt_alpha_cumprod_t_mins_1)
374 |             x_0_rec_rec = self.predict_start(z_t_minus_1, continuous_sqrt_alpha_cumprod_t_mins_1.view(-1, 1, 1, 1), eps_rec_rec)
375 |             z_t_minus_2 = self.q_sample(x_0_rec_rec, continuous_sqrt_alpha_cumprod_t_mins_2.view(-1, 1, 1, 1), eps_rec_rec)
376 |             frac = (1 - continuous_sqrt_alpha_cumprod_t_mins_2.view(-1, 1, 1, 1)**2).sqrt() / (1- continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1)**2).sqrt()         
377 |             if self.w_snr != 0:
378 |                 y = x_in['LQ']
379 |                 y_max, _ = torch.max(y, dim=1, keepdim=True)  # per-pixel channel max; avoids shadowing the torchvision alias `T`
380 |                 y_max = y_max + 0.1
381 |                 y = y / y_max
382 |                 iso_noise = self.SNR_map(y)
383 |                 y = y - iso_noise      
384 |                 refine_x_0 =  y                  
385 |                 z_t_minus_2_refine = self.q_sample(refine_x_0, continuous_sqrt_alpha_cumprod_t_mins_2.view(-1, 1, 1, 1), eps_rec_rec)
386 |                 z_t_minus_2 =  z_t_minus_2 + self.w_snr *(z_t_minus_2_refine - z_t_minus_2)
387 |             
388 |             x_target = (z_t_minus_2 - frac * z_t) / ( continuous_sqrt_alpha_cumprod_t_mins_2.view(-1, 1, 1, 1) - frac * continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1 ))
389 |             eps_target = self.predict_eps(z_t, x_target, continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1)) 
390 | 
391 |         eps_predicted, _ = student.denoise_fn(torch.cat([x_in['LQ'], z_t], dim=1), continuous_sqrt_alpha_cumprod)
392 |         x_0_predicted = self.predict_start(z_t, continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), eps_predicted)
393 |         loss_x_0 = torch.mean(F.mse_loss(x_0_predicted, x_target, reduction='none').reshape(bs, -1), dim=-1)
394 |         loss_eps = torch.mean(F.mse_loss(eps_predicted, eps_target, reduction='none').reshape(bs, -1), dim=-1)
395 | 
396 |         loss_stru = torch.zeros_like(loss_x_0) # 0.
397 |         if self.w_gt != 0:
398 |             loss_output_x0 = torch.mean(F.mse_loss(x_0, x_0_predicted, reduction='none').reshape(bs, -1), dim=-1)  
399 |             loss_output_eps = torch.mean(F.mse_loss(noise, eps_predicted, reduction='none').reshape(bs, -1), dim=-1)  
400 | 
401 |         else:
402 |             loss_output_x0 = torch.zeros_like(loss_x_0) # 0.
403 |             loss_output_eps = torch.zeros_like(loss_eps) # 0.
404 | 
405 |         if self.w_lpips != 0:
406 |             loss_lpips = torch.mean(lpips_func(x_0, x_0_predicted)) 
407 |         else:
408 |             loss_lpips = torch.zeros_like(loss_x_0) # 0.
409 | 
410 |         return torch.mean(torch.maximum(loss_x_0, loss_eps)) + \
411 |               self.w_gt * torch.mean(torch.maximum(loss_output_x0, loss_output_eps)) + \
412 |               self.w_lpips*torch.mean(loss_lpips) + \
413 |               self.w_str*torch.mean(loss_stru)
414 | 
415 |     
416 |     
417 | 
418 | 
419 |     def forward(self, x, s_model=None,  *args, **kwargs):
420 |             return self.loss(x, s_model, *args, **kwargs)
421 | 
422 | 
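`q_sample` and `predict_start` in `GaussianDiffusion` are inverses of each other when they are given the same noise and the same value of sqrt(alpha_bar). A NumPy-only sketch that checks this round trip on a linear beta schedule; the function names mirror the methods above, but these are standalone re-implementations written for illustration, and the shapes, seed, and timestep index are arbitrary assumptions:

```python
import numpy as np


def linear_alphas_cumprod(n_timestep, linear_start=1e-4, linear_end=2e-2):
    """Cumulative product of (1 - beta_t) for a linear beta schedule."""
    betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)
    return np.cumprod(1.0 - betas)


def q_sample(x_start, sqrt_ac, noise):
    """Forward process: x_t = sqrt(a) * x_0 + sqrt(1 - a) * eps."""
    return sqrt_ac * x_start + np.sqrt(1.0 - sqrt_ac ** 2) * noise


def predict_start(x_t, sqrt_ac, noise):
    """Invert q_sample: x_0 = x_t / sqrt(a) - sqrt(1/a - 1) * eps."""
    return x_t / sqrt_ac - np.sqrt(1.0 / sqrt_ac ** 2 - 1.0) * noise


rng = np.random.default_rng(0)
x0 = rng.uniform(-1.0, 1.0, size=(2, 3, 4, 4))
eps = rng.standard_normal(x0.shape)
alphas_cumprod = linear_alphas_cumprod(100)
sqrt_ac = np.sqrt(alphas_cumprod)[50]

x_t = q_sample(x0, sqrt_ac, eps)
x0_rec = predict_start(x_t, sqrt_ac, eps)
```

During training the network's predicted noise takes the place of `eps`, so the reconstruction is only as good as that prediction.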
423 | 
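The `ddim` method above alternates between estimating `x_0` from the predicted noise and re-noising that estimate to the previous timestep. A NumPy sketch of the same deterministic update, under stated assumptions: `ddim_sample`, `eps_fn`, and `oracle` are hypothetical names, the table `sqrt_ac` carries a leading 1.0 (like `sqrt_alphas_cumprod_prev`) and is used for both the noise level and the update, which differs from the indexing in the repo, and an oracle noise predictor stands in for the trained network:

```python
import numpy as np


def ddim_sample(x_T, eps_fn, sqrt_ac, num_steps, stride=1, clip=True):
    """Deterministic DDIM sampling; sqrt_ac[k] = sqrt(alpha_bar_k), sqrt_ac[0] = 1."""
    x_t = x_T
    for step in reversed(range(stride, num_steps + 1, stride)):
        a_t, a_s = sqrt_ac[step], sqrt_ac[step - stride]
        sigma_t = np.sqrt(1.0 - a_t ** 2)
        eps = eps_fn(x_t, a_t)
        x_0 = (x_t - sigma_t * eps) / a_t          # estimate of the clean image
        if clip:
            x_0 = np.clip(x_0, -1.0, 1.0)
            eps = (x_t - a_t * x_0) / sigma_t      # keep eps consistent with the clipped x_0
        x_t = a_s * x_0 + np.sqrt(1.0 - a_s ** 2) * eps
    return np.clip(x_t, -1.0, 1.0)


betas = np.linspace(1e-4, 2e-2, 100, dtype=np.float64)
sqrt_ac = np.sqrt(np.append(1.0, np.cumprod(1.0 - betas)))

rng = np.random.default_rng(0)
x0_true = rng.uniform(-0.9, 0.9, size=(1, 3, 8, 8))


def oracle(x_t, a_t):
    """Noise predictor that already knows x0_true (test-only stand-in)."""
    return (x_t - a_t * x0_true) / np.sqrt(1.0 - a_t ** 2)


x_T = rng.standard_normal(x0_true.shape)
sample = ddim_sample(x_T, oracle, sqrt_ac, num_steps=100, stride=10)
```

With a perfect noise predictor the sampler returns the clean image regardless of `stride`, which is why larger strides trade little quality when the network's predictions are accurate.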
--------------------------------------------------------------------------------
/model/ddpm_modules/unet.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | import torch
  3 | import torch.fft as fft
  4 | from torch import nn
  5 | import torch.nn.functional as F
  6 | from inspect import isfunction
  7 | import cv2
  8 | import torchvision.transforms as T
  9 | import numpy as np
 10 | def exists(x):
 11 |     return x is not None
 12 | 
 13 | # PositionalEncoding Source: https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py
 14 | class PositionalEncoding(nn.Module):
 15 |     def __init__(self, dim):
 16 |         super().__init__()
 17 |         self.dim = dim
 18 | 
 19 |     def forward(self, noise_level):
 20 |         count = self.dim // 2
 21 |         step = torch.arange(count, dtype=noise_level.dtype,
 22 |                             device=noise_level.device) / count
 23 |         encoding = noise_level.unsqueeze(
 24 |             1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
 25 |         encoding = torch.cat(
 26 |             [torch.sin(encoding), torch.cos(encoding)], dim=-1)
 27 |         return encoding
 28 | 
 29 | class FeatureWiseAffine(nn.Module):
 30 |     def __init__(self, in_channels, out_channels, use_affine_level=False):
 31 |         super(FeatureWiseAffine, self).__init__()
 32 |         self.use_affine_level = use_affine_level
 33 |         self.noise_func = nn.Sequential(
 34 |             nn.Linear(in_channels, out_channels*(1+self.use_affine_level))
 35 |         )
 36 | 
 37 |     def forward(self, x, noise_embed):
 38 |         batch = x.shape[0]
 39 |         if self.use_affine_level:
 40 |             gamma, beta = self.noise_func(noise_embed).view(
 41 |                 batch, -1, 1, 1).chunk(2, dim=1)
 42 |             x = (1 + gamma) * x + beta
 43 |         else:
 44 |             x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1)
 45 |         return x
 46 | 
 47 | def default(val, d):
 48 |     if exists(val):
 49 |         return val
 50 |     return d() if isfunction(d) else d
 51 | 
 52 | # model
 53 | class Swish(nn.Module):
 54 |     def forward(self, x):
 55 |         return x * torch.sigmoid(x)
 56 | 
 57 | 
 58 | class Upsample(nn.Module):
 59 |     def __init__(self, dim):
 60 |         super().__init__()
 61 |         self.up = nn.Upsample(scale_factor=2, mode="nearest")
 62 |         self.conv = nn.Conv2d(dim, dim, 3, padding=1)
 63 | 
 64 |     def forward(self, x):
 65 |         return self.conv(self.up(x))
 66 | 
 67 | 
 68 | class Downsample(nn.Module):
 69 |     def __init__(self, dim):
 70 |         super().__init__()
 71 |         self.conv = nn.Conv2d(dim, dim, 3, 2, 1)
 72 | 
 73 |     def forward(self, x):
 74 |         return self.conv(x)
 75 | 
 76 | 
 77 | # building block modules
 78 | class Block(nn.Module):
 79 |     def __init__(self, dim, dim_out, groups=32, dropout=0):
 80 |         super().__init__()
 81 |         self.block = nn.Sequential(
 82 |             nn.GroupNorm(groups, dim),
 83 |             Swish(),
 84 |             nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
 85 |             nn.Conv2d(dim, dim_out, 3, padding=1)
 86 |         )
 87 | 
 88 |     def forward(self, x):
 89 |         return self.block(x)
 90 | 
 91 | 
 92 | class ResnetBlock(nn.Module):
 93 |     def __init__(self, dim, dim_out, time_emb_dim=None, dropout=0, norm_groups=32):
 94 |         super().__init__()
 95 |         self.mlp = nn.Sequential(
 96 |             Swish(),
 97 |             nn.Linear(time_emb_dim, dim_out)
 98 |         ) if exists(time_emb_dim) else None
 99 |         self.noise_func = FeatureWiseAffine(
100 |             time_emb_dim, dim_out, use_affine_level=False)
101 | 
102 |         self.block1 = Block(dim, dim_out, groups=norm_groups)
103 |         self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
104 |         self.res_conv = nn.Conv2d(
105 |             dim, dim_out, 1) if dim != dim_out else nn.Identity()
106 | 
107 |     def forward(self, x, time_emb):
108 |         h = self.block1(x)
109 |         h = self.noise_func(h, time_emb)
110 |         h = self.block2(h)
111 |         return h + self.res_conv(x)
112 | 
113 | 
114 | class SelfAttention(nn.Module):
115 |     def __init__(self, in_channel, n_head=1, norm_groups=32):
116 |         super().__init__()
117 | 
118 |         self.n_head = n_head
119 | 
120 |         self.norm = nn.GroupNorm(norm_groups, in_channel)
121 |         self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
122 |         self.out = nn.Conv2d(in_channel, in_channel, 1)
123 | 
124 |     def forward(self, input):
125 |         batch, channel, height, width = input.shape
126 |         n_head = self.n_head
127 |         head_dim = channel // n_head
128 | 
129 |         norm = self.norm(input)
130 |         qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
131 |         query, key, value = qkv.chunk(3, dim=2)  # bhdyx
132 | 
133 |         attn = torch.einsum(
134 |             "bnchw, bncyx -> bnhwyx", query, key
135 |         ).contiguous() / math.sqrt(channel)
136 |         attn = attn.view(batch, n_head, height, width, -1)
137 |         attn = torch.softmax(attn, -1)
138 |         attn = attn.view(batch, n_head, height, width, height, width)
139 | 
140 |         out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()
141 |         out = self.out(out.view(batch, channel, height, width))
142 | 
143 |         return out + input
144 | 
145 | 
146 | class ResnetBlocWithAttn(nn.Module):
147 |     def __init__(self, dim, dim_out, *, time_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
148 |         super().__init__()
149 |         self.with_attn = with_attn
150 |         self.res_block = ResnetBlock(
151 |             dim, dim_out, time_emb_dim, norm_groups=norm_groups, dropout=dropout)
152 |         if with_attn:
153 |             self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
154 | 
155 |     def forward(self, x, time_emb):
156 |         x = self.res_block(x, time_emb)
157 |         if(self.with_attn):
158 |             x = self.attn(x)
159 |         return x
160 | 
161 | 
162 | class UNet(nn.Module):
163 |     def __init__(
164 |         self,
165 |         in_channel=6,
166 |         out_channel=3,
167 |         inner_channel=32,
168 |         norm_groups=32,
169 |         channel_mults=(1, 1, 2, 2, 4),
 170 |         attn_res=(8,),
171 |         res_blocks=3,
172 |         dropout=0,
173 |         with_noise_level_emb=True,
174 |         image_size=128
175 |     ):
176 |         super().__init__()
177 |         if with_noise_level_emb:
178 |             noise_level_channel = inner_channel
179 |             self.noise_level_mlp = nn.Sequential(
180 |                 PositionalEncoding(inner_channel),
181 |                 nn.Linear(inner_channel, inner_channel * 4),
182 |                 Swish(),
183 |                 nn.Linear(inner_channel * 4, inner_channel)
184 |             )
185 |         else:
186 |             noise_level_channel = None
187 |             self.noise_level_mlp = None
188 | 
189 | 
190 |         num_mults = len(channel_mults)
191 |         pre_channel = inner_channel
192 |         feat_channels = [pre_channel]
193 |         now_res = image_size
194 |         downs = [nn.Conv2d(in_channel, inner_channel,
195 |                            kernel_size=3, padding=1)]
196 |         for ind in range(num_mults):
197 |             is_last = (ind == num_mults - 1)
198 |             use_attn = (now_res in attn_res)
199 |             channel_mult = inner_channel * channel_mults[ind]
200 |             for _ in range(0, res_blocks):
201 |                 downs.append(ResnetBlocWithAttn(
202 |                     pre_channel, channel_mult, time_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn))
203 |                 feat_channels.append(channel_mult)
204 |                 pre_channel = channel_mult
205 |             if not is_last:
206 |                 downs.append(Downsample(pre_channel))
207 |                 feat_channels.append(pre_channel)
208 |                 now_res = now_res//2
209 |         self.downs = nn.ModuleList(downs)
210 | 
211 |         self.mid = nn.ModuleList([
212 |             ResnetBlocWithAttn(pre_channel, pre_channel, time_emb_dim=noise_level_channel, norm_groups=norm_groups,
213 |                                dropout=dropout, with_attn=True),
214 |             ResnetBlocWithAttn(pre_channel, pre_channel, time_emb_dim=noise_level_channel, norm_groups=norm_groups, 
215 |                                 dropout=dropout, with_attn=False)
216 |         ])
217 | 
218 |         ups = []
219 |         for ind in reversed(range(num_mults)):
220 |             is_last = (ind < 1)
221 |             use_attn = (now_res in attn_res)
222 |             channel_mult = inner_channel * channel_mults[ind]
223 |             for _ in range(0, res_blocks+1):
224 |                 ups.append(ResnetBlocWithAttn(
225 |                     pre_channel+feat_channels.pop(), channel_mult, time_emb_dim=noise_level_channel, dropout=dropout, norm_groups=norm_groups, with_attn=use_attn))
226 |                 pre_channel = channel_mult
227 |             if not is_last:
228 |                 ups.append(Upsample(pre_channel))
229 |                 now_res = now_res*2
230 | 
231 |         self.ups = nn.ModuleList(ups)
232 | 
233 |         self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups)
234 | 
235 |         self.var_conv = nn.Sequential(*[
236 |             nn.Conv2d(pre_channel, pre_channel, 3, padding=(3//2), bias=True), 
237 |             nn.ELU(), 
238 |             nn.Conv2d(pre_channel, pre_channel, 3, padding=(3//2), bias=True),
239 |             nn.ELU(), 
240 |             nn.Conv2d(pre_channel, 3, 3, padding=(3//2), bias=True),
241 |             nn.ELU()
242 |         ])
 243 |     @staticmethod
 244 |     def default_conv(in_channels, out_channels, kernel_size, bias=True):
245 |         return nn.Conv2d(
246 |             in_channels, out_channels, kernel_size,
247 |             padding=(kernel_size//2), bias=bias)
248 | 
249 |     def forward(self, x, noise):
250 |         noise_level = self.noise_level_mlp(noise) if exists(self.noise_level_mlp) else None
251 |         feats = []
252 |         for layer in self.downs:
253 |             if isinstance(layer, ResnetBlocWithAttn):
254 |                 x = layer(x, noise_level)
255 |             else:
256 |                 x = layer(x)
257 |             feats.append(x)
258 | 
259 |         for layer in self.mid:
260 |             if isinstance(layer, ResnetBlocWithAttn):
261 |                 x = layer(x, noise_level)
262 |             else:
263 |                 x = layer(x)
264 | 
265 |         for layer in self.ups:
266 |             if isinstance(layer, ResnetBlocWithAttn):
267 |                 x = layer(torch.cat((x, feats.pop()), dim=1), noise_level)
268 |             else:
269 |                 x = layer(x)
270 |         return self.final_conv(x), self.var_conv(x)
271 | 
272 | 
273 | 
274 | # FreeU
275 | def Fourier_filter(x, threshold, scale):
276 |     # FFT
277 |     x_freq = fft.fftn(x, dim=(-2, -1))
278 |     x_freq = fft.fftshift(x_freq, dim=(-2, -1))
279 |     
280 |     B, C, H, W = x_freq.shape
 281 |     mask = torch.ones((B, C, H, W), device=x.device)
282 | 
283 |     crow, ccol = H // 2, W //2
284 |     mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
285 |     x_freq = x_freq * mask
286 | 
287 |     # IFFT
288 |     x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
289 |     x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
290 |     
291 |     return x_filtered
292 | 
293 | def SNR_filter(x_in, threshold, scale):
294 | 
295 |     blur_transform = T.GaussianBlur(kernel_size=15, sigma=3)
296 |     blur_x_in = blur_transform(x_in)
297 |     gray_blur_x_in = blur_x_in[:, 0:1, :, :] * 0.299 + blur_x_in[:, 1:2, :, :] * 0.587 + blur_x_in[:, 2:3, :, :] * 0.114
298 |     gray_x_in = x_in[:, 0:1, :, :] * 0.299 + x_in[:, 1:2, :, :] * 0.587 + x_in[:, 2:3, :, :] * 0.114
299 |     noise = torch.abs(gray_blur_x_in - gray_x_in)
300 |     mask = torch.div(gray_x_in, noise + 0.0001)
301 | 
302 |     batch_size = mask.shape[0]
303 |     height = mask.shape[2]
304 |     width = mask.shape[3]
305 |     mask_max = torch.max(mask.view(batch_size, -1), dim=1)[0]
306 |     mask_max = mask_max.view(batch_size, 1, 1, 1)
307 |     mask_max = mask_max.repeat(1, 1, height, width)
308 |     mask = mask * 1.0 / (mask_max+0.0001)
309 |     mask = torch.clamp(mask, min=0, max=1.0)
310 |     mask = mask.float()
311 | 
312 |     return mask
313 | 
314 |     
315 | class Free_UNet(UNet):
316 |     """
317 |     :param b1: backbone factor of the first stage block of decoder.
318 |     :param b2: backbone factor of the second stage block of decoder.
319 |     :param s1: skip factor of the first stage block of decoder.
320 |     :param s2: skip factor of the second stage block of decoder.
321 |     """
322 | 
323 |     def __init__(
324 |         self,
325 |         b1=1.3,
326 |         b2=1.4,
327 |         s1=0.9,
328 |         s2=0.2,
329 |         *args,
330 |         **kwargs
331 |     ):
332 |         super().__init__(*args, **kwargs)
333 |         self.b1 = b1 
334 |         self.b2 = b2
335 |         self.s1 = s1
336 |         self.s2 = s2
337 | 
338 |     def forward(self, h, noise):
 339 |         # Only the input tensor and the noise level are needed.
 340 |         """
 341 |         Apply the model to an input batch.
 342 |         :param h: an [N x C x H x W] Tensor of inputs.
 343 |         :param noise: the noise-level input fed to noise_level_mlp
 344 |             (unused when the model has no noise-level embedding).
 345 |         :return: a tuple (final_conv output, var_conv output), both
 346 |             computed from the last decoder feature map.
 347 |         """
348 |         hs = []
349 |         noise_level = self.noise_level_mlp(noise) if exists(self.noise_level_mlp) else None
350 | 
351 | 
352 | 
353 |         for layer in self.downs:
354 |             if isinstance(layer, ResnetBlocWithAttn):
355 |                 h = layer(h, noise_level)
356 |             else:
357 |                 h = layer(h)
358 |             hs.append(h)
359 |         
360 |         for layer in self.mid:
361 |             if isinstance(layer, ResnetBlocWithAttn):
362 |                 h = layer(h, noise_level)
363 |             else:
364 |                 h = layer(h)
365 | 
366 |         for layer in self.ups:
367 |             # --------------- FreeU code -----------------------
368 |             # Only operate on the first two stages
369 |             if h.shape[1] == 256:
370 |                 hs_ = hs.pop()
371 |                 hidden_mean = h.mean(1).unsqueeze(1)
372 |                 B = hidden_mean.shape[0]
373 |                 hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 
374 |                 hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
375 |                 hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
376 | 
377 |                 h[:,:128] = h[:,:128] * ((self.b1 - 1 ) * hidden_mean + 1)
378 |                 hs_ = Fourier_filter(hs_, threshold=1, scale=self.s1)
379 |                 hs.append(hs_)
380 |             if h.shape[1] == 128:
381 |                 hs_ = hs.pop()
382 |                 hidden_mean = h.mean(1).unsqueeze(1)
383 |                 B = hidden_mean.shape[0]
384 |                 hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 
385 |                 hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
386 |                 hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
387 | 
388 |                 h[:,:64] = h[:,:64] * ((self.b2 - 1 ) * hidden_mean + 1)
389 |                 hs_ = Fourier_filter(hs_, threshold=1, scale=self.s2)
390 |                 hs.append(hs_)
391 |             # ---------------------------------------------------------
392 |             # exit(print(h.shape, hs_.shape)) # torch.Size([1, 256, 26, 38]) torch.Size([1, 256, 26, 38])
393 |             # print(h.shape, hs_.shape)
394 |             # h = torch.cat((h, hs_), dim=1)
395 | 
396 |             if isinstance(layer, ResnetBlocWithAttn):
397 |                 h = layer(torch.cat((h, hs.pop()), dim=1), noise_level)
398 |             else:
399 |                 h = layer(h)
400 |         return self.final_conv(h), self.var_conv(h)
 401 | # NOTE: this LAUNet definition is shadowed by the LAUNet class defined after it.
402 | class LAUNet(UNet):
403 |     """
404 |     :param b1: backbone factor of the first stage block of decoder.
405 |     :param b2: backbone factor of the second stage block of decoder.
406 |     :param s1: skip factor of the first stage block of decoder.
407 |     :param s2: skip factor of the second stage block of decoder.
408 |     """
409 | 
410 |     def __init__(
411 |         self,
412 |         b1=1.3,
413 |         b2=1.4,
414 |         s1=0.9,
415 |         s2=0.2,
416 |         *args,
417 |         **kwargs
418 |     ):
419 |         super().__init__(*args, **kwargs)
420 |         self.b1 = b1 
421 |         self.b2 = b2
422 |         self.s1 = s1
423 |         self.s2 = s2
424 | 
425 |     def forward(self, h, noise):
 426 |         # Only the input tensor and the noise level are needed.
 427 |         """
 428 |         Apply the model to an input batch.
 429 |         :param h: an [N x C x H x W] Tensor of inputs.
 430 |         :param noise: the noise-level input fed to noise_level_mlp
 431 |             (unused when the model has no noise-level embedding).
 432 |         :return: a tuple (final_conv output, var_conv output), both
 433 |             computed from the last decoder feature map.
 434 |         """
435 |         hs = []
436 |         noise_level = self.noise_level_mlp(noise) if exists(self.noise_level_mlp) else None
437 | 
438 | 
439 | 
440 |         for layer in self.downs:
441 |             if isinstance(layer, ResnetBlocWithAttn):
442 |                 h = layer(h, noise_level)
443 |             else:
444 |                 h = layer(h)
445 |             hs.append(h)
446 |         
447 |         for layer in self.mid:
448 |             if isinstance(layer, ResnetBlocWithAttn):
449 |                 h = layer(h, noise_level)
450 |             else:
451 |                 h = layer(h)
452 | 
453 |         for layer in self.ups:
454 |             # --------------- FreeU code -----------------------
455 |             # Only operate on the first two stages
456 |             if h.shape[1] == 256:
457 |                 hs_ = hs.pop()
458 |                 hidden_mean = h.mean(1).unsqueeze(1)
459 |                 B = hidden_mean.shape[0]
460 |                 hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 
461 |                 hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
462 |                 hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
463 | 
464 |                 h[:,:128] = h[:,:128] * ((self.b1 - 1 ) * hidden_mean + 1)
465 |                 hs_ = Fourier_filter(hs_, threshold=1, scale=self.s1)
466 |                 hs.append(hs_)
467 |             if h.shape[1] == 128:
468 |                 hs_ = hs.pop()
469 |                 hidden_mean = h.mean(1).unsqueeze(1)
470 |                 B = hidden_mean.shape[0]
471 |                 hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 
472 |                 hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
473 |                 hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
474 | 
475 |                 h[:,:64] = h[:,:64] * ((self.b2 - 1 ) * hidden_mean + 1)
476 |                 hs_ = Fourier_filter(hs_, threshold=1, scale=self.s2)
477 |                 hs.append(hs_)
478 |             # ---------------------------------------------------------
479 |             # exit(print(h.shape, hs_.shape)) # torch.Size([1, 256, 26, 38]) torch.Size([1, 256, 26, 38])
480 |             # print(h.shape, hs_.shape)
481 |             # h = torch.cat((h, hs_), dim=1)
482 | 
483 |             if isinstance(layer, ResnetBlocWithAttn):
484 |                 h = layer(torch.cat((h, hs.pop()), dim=1), noise_level)
485 |             else:
486 |                 h = layer(h)
487 |         return self.final_conv(h), self.var_conv(h)
488 | 
489 | class LAUNet(UNet):
490 |     """
491 |     :param b1: backbone factor of the first stage block of decoder.
492 |     :param b2: backbone factor of the second stage block of decoder.
493 |     :param s1: skip factor of the first stage block of decoder.
494 |     :param s2: skip factor of the second stage block of decoder.
495 |     """
496 | 
497 |     def __init__(
498 |         self,
499 |         b1=1.3,
500 |         b2=1.4,
501 |         s1=0.9,
502 |         s2=0.2,
503 |         *args,
504 |         **kwargs
505 |     ):
506 |         super().__init__(*args, **kwargs)
507 |         self.b1 = b1 
508 |         self.b2 = b2
509 |         self.s1 = s1
510 |         self.s2 = s2
511 | 
512 |     def forward(self, h, noise):
 513 |         # Only the input tensor and the noise level are needed.
 514 |         """
 515 |         Apply the model to an input batch.
 516 |         :param h: an [N x C x H x W] Tensor of inputs.
 517 |         :param noise: the noise-level input fed to noise_level_mlp
 518 |             (unused when the model has no noise-level embedding).
 519 |         :return: a tuple (final_conv output, var_conv output), both
 520 |             computed from the last decoder feature map.
 521 |         """
522 |         hs = []
523 |         noise_level = self.noise_level_mlp(noise) if exists(self.noise_level_mlp) else None
524 | 
525 | 
526 | 
527 |         for layer in self.downs:
528 |             # print(h.shape)
529 |             if isinstance(layer, ResnetBlocWithAttn):
530 |                 h = layer(h, noise_level)
531 |             else:
532 |                 h = layer(h)
533 |             hs.append(h)
534 |         # print("\n")
535 |         for layer in self.mid:
536 |             # print(h.shape)
537 |             if isinstance(layer, ResnetBlocWithAttn):
538 |                 h = layer(h, noise_level)
539 |             else:
540 |                 h = layer(h)
541 |         # print("\n")
542 |         for layer in self.ups:
543 |             # print(h.shape)
544 |             # --------------- FreeU code -----------------------
545 |             # Only operate on the first two stages
546 |             if h.shape[1] == 256:
547 |                 hs_ = hs.pop()
548 |                 hidden_mean = h.mean(1).unsqueeze(1)
549 |                 B = hidden_mean.shape[0]
550 |                 hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 
551 |                 hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
552 |                 hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
553 | 
554 |                 h[:,:64] = h[:,:64] * ((self.b1 - 1 ) * hidden_mean + 1)
555 |                 hs_ = Fourier_filter(hs_, threshold=1, scale=self.s1)
556 |                 hs.append(hs_)
557 |             # if h.shape[1] == 128 and h.shape[2] == 104:
558 |             #     hs_ = hs.pop()
559 |             #     hidden_mean = h.mean(1).unsqueeze(1)
560 |             #     B = hidden_mean.shape[0]
561 |             #     hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 
562 |             #     hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True)
563 |             #     hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3)
564 | 
565 |             #     h[:,:64] = h[:,:64] * ((self.b2 - 1 ) * hidden_mean + 1)
566 |             #     hs_ = Fourier_filter(hs_, threshold=1, scale=self.s2)
567 |             #     hs.append(hs_)
568 |             # ---------------------------------------------------------
569 |             # exit(print(h.shape, hs_.shape)) # torch.Size([1, 256, 26, 38]) torch.Size([1, 256, 26, 38])
570 |             # print(h.shape, hs_.shape)
571 |             # h = torch.cat((h, hs_), dim=1)
572 |             if isinstance(layer, ResnetBlocWithAttn):
573 |                 h = layer(torch.cat((h, hs.pop()), dim=1), noise_level)
574 |             else:
575 |                 h = layer(h)
576 |         
577 |         # exit(-1)
578 |         return self.final_conv(h), self.var_conv(h)
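
The `b1`/`b2` "backbone factor" described in the class docstrings above is applied as `(b - 1) * hidden_mean + 1`, where `hidden_mean` is the channel-averaged feature map after min-max normalisation. The following is a minimal pure-Python sketch of that arithmetic on a small 2-D map, not the repository's implementation. The helper name `backbone_scale` is hypothetical, and the guard for a constant map is an added assumption (the tensor code above divides by `hidden_max - hidden_min` without such a guard).

```python
def backbone_scale(mean_map, b):
    """Map a 2-D list of channel means to per-pixel scale factors in [1, b].

    Mirrors the expression (b - 1) * normalised_mean + 1 used in the
    FreeU-style decoder code, with min-max normalisation over the whole map.
    """
    flat = [v for row in mean_map for v in row]
    lo, hi = min(flat), max(flat)
    span = hi - lo
    if span == 0:
        # Assumption (not in the original code): a constant map gets no
        # amplification instead of producing a division by zero.
        return [[1.0 for _ in row] for row in mean_map]
    return [[(b - 1.0) * ((v - lo) / span) + 1.0 for v in row]
            for row in mean_map]


# The smallest mean maps to a factor of 1, the largest to a factor of b.
scale = backbone_scale([[0.0, 1.0], [2.0, 4.0]], b=1.3)
for row in scale:
    print(row)
```

Because the factor never drops below 1 when `b >= 1`, the scaled backbone channels are only ever amplified, most strongly where the averaged activation is largest.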
--------------------------------------------------------------------------------
/model/model.py:
--------------------------------------------------------------------------------
  1 | import logging
  2 | from collections import OrderedDict
  3 | import torch
  4 | import torch.nn as nn
  5 | import os
  6 | import model.networks as networks
  7 | from .base_model import BaseModel
  8 | from torch.nn.parallel import DistributedDataParallel as DDP
  9 | from utils.ema import EMA
 10 | from torch.optim import lr_scheduler
 11 | import lpips
 12 | 
 13 | logger = logging.getLogger('base')
 14 | skip_para = []
 15 | 
 16 | skip_para = ['betas', 'alphas_cumprod', 'alphas_cumprod_prev', 'sqrt_alphas_cumprod', 
 17 |              'sqrt_one_minus_alphas_cumprod', 'log_one_minus_alphas_cumprod', 'sqrt_recip_alphas_cumprod',
 18 |                'sqrt_recipm1_alphas_cumprod', 'posterior_variance', 'posterior_log_variance_clipped', 
 19 |                'posterior_mean_coef1', 'posterior_mean_coef2',]
 20 | 
 21 | def get_scheduler(optimizer, opt):
 22 |     if opt['train']["optimizer"]['lr_policy'] == 'linear':
 23 |         def lambda_rule(iteration):
 24 |             lr_l = 1.0 - max(0, iteration-opt['train']["optimizer"]["n_lr_iters"]) / float(opt['train']["optimizer"]["lr_decay_iters"] + 1)
 25 |             return lr_l
 26 |         scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
 27 |     elif opt['train']["optimizer"]['lr_policy'] == 'step':
 28 |         scheduler = lr_scheduler.StepLR(optimizer, step_size=opt['train']["optimizer"]["lr_decay_iters"], gamma=0.8)
 29 |     elif opt['train']["optimizer"]['lr_policy'] == 'plateau':
 30 |         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
 31 |     elif opt['train']["optimizer"]['lr_policy'] == 'cosine':
  32 |         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt.n_epochs, eta_min=0)  # NOTE: opt is indexed as a dict elsewhere; opt.n_epochs may need a dict key instead
 33 |     else:
  34 |         raise NotImplementedError('learning rate policy [%s] is not implemented' % opt['train']["optimizer"]['lr_policy'])
 35 |     return scheduler
 36 | 
 37 | 
 38 | class DDPM(BaseModel):
 39 |     def __init__(self, opt):
 40 |         super(DDPM, self).__init__(opt)
 41 | 
 42 |         if opt['dist']:
 43 |             self.local_rank = torch.distributed.get_rank()
 44 |             torch.cuda.set_device(self.local_rank)
 45 |             device = torch.device("cuda", self.local_rank)
 46 |         # define network and load pretrained models
 47 |         self.netG = self.set_device(networks.define_G(opt, student=False))
 48 |         if opt['dist']:
 49 |             self.netG.to(device)
 50 |        
 51 |         # self.netG.to(device)
 52 | 
 53 |         self.schedule_phase = None
 54 |         self.opt = opt
 55 | 
 56 |         # set loss and load resume state
 57 |         self.set_loss()
 58 | 
 59 |         if self.opt['phase'] == 'train':
 60 |             self.netG.train()
 61 |             # find the parameters to optimize
 62 |             if opt['model']['finetune_norm']:
 63 |                 optim_params = []
 64 |                 for k, v in self.netG.named_parameters():
 65 |                     v.requires_grad = False
 66 |                     if k.find('transformer') >= 0:
 67 |                         v.requires_grad = True
 68 |                         v.data.zero_()
 69 |                         optim_params.append(v)
 70 |                         logger.info(
 71 |                             'Params [{:s}] initialized to 0 and will optimize.'.format(k))
 72 |             else:
 73 |                 optim_params = list(self.netG.parameters())
 74 | 
 75 |             self.optG = torch.optim.Adam(
 76 |                 optim_params, lr=opt['train']["optimizer"]["lr"])
 77 |             self.log_dict = OrderedDict()
 78 | 
 79 |         if self.opt['phase'] == 'test':
 80 |             self.netG.load_state_dict(torch.load(self.opt['path']['resume_state']), strict=True)
 81 | 
 82 |         else:
 83 |             self.load_network()
 84 |             if opt['dist']:
 85 |                 self.netG = DDP(self.netG, device_ids=[self.local_rank], output_device=self.local_rank,find_unused_parameters=True)
 86 | 
 87 | 
 88 |     def feed_data(self, data):
 89 | 
 90 |         dic = {}
 91 | 
 92 |         if self.opt['dist']:
 93 |             dic = {}
 94 |             dic['LQ'] = data['LQ'].to(self.local_rank)
 95 |             dic['GT'] = data['GT'].to(self.local_rank)
 96 |             self.data = dic
 97 |         else:
 98 |             dic['LQ'] = data['LQ']
 99 |             dic['GT'] = data['GT']
100 | 
101 |             self.data = self.set_device(dic)
102 | 
103 | 
104 |     def test(self, continous=False):
105 |         self.netG.eval()
106 |         with torch.no_grad():
107 |             if isinstance(self.netG, nn.DataParallel):
108 |                 self.SR = self.netG.module.super_resolution(
109 |                     self.data['LQ'], continous)
110 |                 
111 |             else:
112 |                 if self.opt['dist']:
113 |                     self.SR = self.netG.module.super_resolution(self.data['LQ'], continous)
114 |                 else:
115 |                     self.SR = self.netG.super_resolution(self.data['LQ'], continous)
116 | 
117 |         self.netG.train()
118 | 
119 |     def sample(self, batch_size=1, continous=False):
120 |         self.netG.eval()
121 |         with torch.no_grad():
122 |             if isinstance(self.netG, nn.DataParallel):
123 |                 self.SR = self.netG.module.sample(batch_size, continous)
124 |             else:
125 |                 self.SR = self.netG.sample(batch_size, continous)
126 |         self.netG.train()
127 | 
128 |     def set_loss(self):
129 |         if isinstance(self.netG, nn.DataParallel):
130 |             self.netG.module.set_loss(self.device)
131 |         else:
132 |             self.netG.set_loss(self.device)
133 | 
134 |     def set_new_noise_schedule(self, schedule_opt, schedule_phase='train'):
135 | 
136 |         if self.opt['dist']:
137 | 
138 |             device = torch.device("cuda", self.local_rank)
139 |             if self.schedule_phase is None or self.schedule_phase != schedule_phase:
140 |                 self.schedule_phase = schedule_phase
141 |                 if isinstance(self.netG, nn.DataParallel):
142 |                     self.netG.module.set_new_noise_schedule(
143 |                         schedule_opt, self.device)
144 |                 else:
145 |                     self.netG.set_new_noise_schedule(schedule_opt)
146 | 
147 |         else:
148 |             self.schedule_phase = schedule_phase
149 |             if isinstance(self.netG, nn.DataParallel):
150 |                 self.netG.module.set_new_noise_schedule(
151 |                     schedule_opt, self.device)
152 |             else:
153 |                 # self.netG.set_new_noise_schedule(schedule_opt, self.device)
154 |                 self.netG.set_new_noise_schedule(schedule_opt, self.device)
155 | 
156 | 
157 |     def get_current_log(self):
158 |         return self.log_dict
159 | 
160 |     def get_current_visuals(self, need_LR=True, sample=False):
161 |         out_dict = OrderedDict()
162 |         if sample:
163 |             out_dict['SAM'] = self.SR.detach().float().cpu()
164 |         else:
165 |             out_dict['HQ'] = self.SR.detach().float().cpu()
166 |             out_dict['INF'] = self.data['LQ'].detach().float().cpu()
167 |             out_dict['GT'] = self.data['GT'].detach()[0].float().cpu()
168 |             if need_LR and 'LR' in self.data:
169 |                 out_dict['LQ'] = self.data['LQ'].detach().float().cpu()
170 |             else:
171 |                 out_dict['LQ'] = out_dict['INF']
172 |         return out_dict
173 | 
174 |     def print_network(self):
175 |         s, n = self.get_network_description(self.netG)
176 |         if isinstance(self.netG, nn.DataParallel):
177 |             net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
178 |                                              self.netG.module.__class__.__name__)
179 |         else:
180 |             net_struc_str = '{}'.format(self.netG.__class__.__name__)
181 | 
182 |         logger.info(s)
183 | 
184 |         logger.info(
185 |             'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
186 | 
187 |     def save_network(self, distill_step, epoch, iter_step):
 188 |         gen_path = os.path.join(
 189 |             self.opt['path']['checkpoint'], 'num_step_{}'.format(distill_step), 'I{}_E{}_gen.pth'.format(iter_step, epoch))
 190 |         opt_path = os.path.join(
 191 |             self.opt['path']['checkpoint'], 'num_step_{}'.format(distill_step), 'I{}_E{}_opt.pth'.format(iter_step, epoch))
192 |         
193 |         # gen
194 |         network = self.netG
195 |         if isinstance(self.netG, nn.DataParallel):
196 |             network = network.module
197 |         state_dict = network.state_dict()
198 |         for key, param in state_dict.items():
199 |             state_dict[key] = param.cpu()
200 |         torch.save(state_dict, gen_path)
201 | 
202 | 
203 | 
204 |         logger.info(
205 |             'Saved model in [{:s}] ...'.format(gen_path))
206 | 
207 |     def load_network(self):
208 |         load_path = self.opt['path']['resume_state']
209 |         if load_path is not None:
210 |             logger.info(
211 |                 'Loading pretrained model for G [{:s}] ...'.format(load_path))
212 |             gen_path = '{}'.format(load_path)
213 |             
214 |             # gen
215 |             networks = [self.netG, self.netG]
216 |             for network in networks:
217 |                 if isinstance(network, nn.DataParallel):
218 |                     network = network.module
219 | 
220 |                 # network = nn.DataParallel(network).cuda()
221 |                 ckpt = torch.load(gen_path)
222 |                 current_state_dict = network.state_dict()
223 |                 for name, param in ckpt.items():
224 |                     if name in skip_para:
225 |                         continue
226 |                         # print(name)
227 |                         # import pdb; pdb.set_trace()
228 |                     else:
229 |                         current_state_dict[name] = param
230 | 
231 |                 network.load_state_dict(current_state_dict, strict=False)
232 |             if self.opt['phase'] == 'train':
233 |                 self.begin_step = 0
234 |                 self.begin_epoch = 0
235 | 
236 | 
237 | 
238 | 
239 | class DDPM_PD(BaseModel):
240 |     def __init__(self, opt):
241 |         super(DDPM_PD, self).__init__(opt)
242 | 
243 |         if opt['dist']:
244 |             self.local_rank = torch.distributed.get_rank()
245 |             torch.cuda.set_device(self.local_rank)
246 |             device = torch.device("cuda", self.local_rank)
247 |         # define network and load pretrained models
248 |         self.netG_t = self.set_device(networks.define_G(opt, student=False))
249 |         if  opt['CD'] :
250 |             self.netG_s = self.set_device(networks.define_G(opt, student=False))
251 |         else:
252 |             self.netG_s = self.set_device(networks.define_G(opt, student=True))
253 |         if opt['dist']:
254 |             self.netG_t.to(device)
255 |             self.netG_s.to(device)
256 |        
257 |         # self.netG.to(device)
258 |        
259 | 
260 |         self.schedule_phase = None
261 |         self.opt = opt
262 | 
263 |         # set loss and load resume state
264 | 
265 |         self.set_loss()
266 |         self.lpips = lpips.LPIPS(net='vgg').cuda()
267 | 
268 |         # self.set_new_noise_schedule(opt['model']['beta_schedule']['train'], schedule_phase='train')
269 | 
270 |         if self.opt['phase'] == 'train':
271 |             self.netG_s.train()
272 |             # find the parameters to optimize
273 | 
274 |             if opt['model']['finetune_norm']:
275 |                 optim_params = []
276 |                 for k, v in self.netG_s.named_parameters():
277 |                     v.requires_grad = False
278 |                     if k.find('transformer') >= 0:
279 |                         v.requires_grad = True
280 |                         v.data.zero_()
281 |                         optim_params.append(v)
282 |                         logger.info(
283 |                             'Params [{:s}] initialized to 0 and will optimize.'.format(k))
284 |             else:
285 |                 optim_params = list(self.netG_s.parameters())
286 | 
287 |             self.optG = torch.optim.Adam(optim_params, lr=opt['train']["optimizer"]["lr"])
288 |             self.scheduler = get_scheduler(self.optG, opt)
289 |             self.log_dict = OrderedDict()
290 | 
291 | 
292 |        
293 | 
294 |         if self.opt['phase'] == 'test':
295 |             self.netG_s.load_state_dict(torch.load(self.opt['path']['resume_state']), strict=True)
296 |         else:
297 |             self.load_network()
298 |             # self.netG_t.load_state_dict(torch.load(self.opt['path']['resume_state']), strict=True)
299 |             if opt['dist']:
300 |                 self.netG_s = DDP(self.netG_s, device_ids=[self.local_rank], output_device=self.local_rank,find_unused_parameters=True)
301 |                 self.netG_t = DDP(self.netG_t, device_ids=[self.local_rank], output_device=self.local_rank,find_unused_parameters=True)
302 |         for p in self.netG_t.parameters():
303 |             p.requires_grad_(False)
304 |         self.netG_t.eval()
305 |         self.netG_t.CD = opt['CD'] 
306 |         self.netG_s.CD = opt['CD'] 
307 | 
308 |         # self.print_network()
309 | 
310 |         # define ema
311 |         self.ema_decay = opt['train']["ema_scheduler"]["ema_decay"]
312 |         if self.opt['dist']:
313 |             self.ema_student = EMA(
314 |                 self.netG_s.module,
315 |                 decay = self.ema_decay,              # exponential moving average factor
316 |             )
317 |         else:
318 |             self.ema_student = EMA(
319 |                 self.netG_s,
320 |                 decay = self.ema_decay,              # exponential moving average factor
321 |             )
322 |         
323 |         self.ema_student.register()
324 | 
325 |     def feed_data(self, data):
326 | 
327 |         dic = {}
328 | 
329 |         if self.opt['dist']:
330 |             dic = {}
331 |             dic['LQ'] = data['LQ'].to(self.local_rank)
332 |             dic['GT'] = data['GT'].to(self.local_rank)
333 |             self.data = dic
334 |         else:
335 |             dic['LQ'] = data['LQ']
336 |             dic['GT'] = data['GT']
337 | 
338 |             self.data = self.set_device(dic)
339 | 
340 |     def optimize_parameters(self):
341 | 
342 |         self.optG.zero_grad()
343 |         if self.opt['dist']:
344 |             l_pd = self.netG_t(self.data, self.netG_s.module, lpips_func=self.lpips)
345 |         else:
346 |             l_pd = self.netG_t(self.data, self.netG_s, lpips_func=self.lpips)
347 |         # print("to be debug")
348 |         # import pdb; pdb.set_trace()
349 |         #
350 | 
351 |        
352 |         loss = l_pd
353 |         # print(l_pd)
354 |         # import pdb; pdb.set_trace()
355 |         loss.backward()
356 |         torch.nn.utils.clip_grad_norm_(self.netG_s.parameters(), 1)
357 |         self.optG.step()
358 |         self.scheduler.step()
359 |         # print( self.optG.param_groups[0]['lr'])
360 |         self.ema_student.update()
361 |         # set log
362 |         self.log_dict['total_loss'] = loss.item()
363 | 
364 |         
365 | 
366 | 
367 |     def test(self, continous=False, stride=1):
368 |         self.ema_student.apply_shadow() # apply shadow weights here
369 |         self.netG_s.eval()
370 |         with torch.no_grad():
371 |             if isinstance(self.netG_s, nn.DataParallel):
372 |                 self.SR = self.netG_s.module.super_resolution(
373 |                     self.data['LQ'], continous, stride)
374 |                 
375 |             else:
376 |                 if self.opt['dist']:
377 |                     self.SR = self.netG_s.module.super_resolution(self.data['LQ'], continous, stride)
378 |                 else:
379 |                     self.SR = self.netG_s.super_resolution(self.data['LQ'], continous, stride)
380 |         self.ema_student.restore()# restore shadow weights here
381 | 
382 |         self.netG_s.train()
383 | 
384 |     def sample(self, batch_size=1, continous=False):
385 |         self.ema_student.apply_shadow() # apply shadow weights here
386 |         self.netG_s.eval()
387 |         with torch.no_grad():
388 |             if isinstance(self.netG_s, nn.DataParallel):
389 |                 self.SR = self.netG_s.module.sample(batch_size, continous)
390 |             else:
391 |                 self.SR = self.netG_s.sample(batch_size, continous)
392 |         self.ema_student.restore()# restore shadow weights here
393 |         self.netG_s.train()
394 | 
395 |     def set_loss(self):
396 |         if isinstance(self.netG_s, nn.DataParallel):
397 |             self.netG_s.module.set_loss(self.device)
398 |         else:
399 |             self.netG_s.set_loss(self.device)
400 |     
401 | 
402 |     def get_current_log(self):
403 |         return self.log_dict
404 | 
405 |     def get_current_visuals(self, need_LR=True, sample=False):
406 |         out_dict = OrderedDict()
407 |         if sample:
408 |             out_dict['SAM'] = self.SR.detach().float().cpu()
409 |         else:
410 |             out_dict['HQ'] = self.SR.detach().float().cpu()
411 |             out_dict['INF'] = self.data['LQ'].detach().float().cpu()
412 |             out_dict['GT'] = self.data['GT'].detach()[0].float().cpu()
413 |             if need_LR and 'LR' in self.data:
414 |                 out_dict['LQ'] = self.data['LQ'].detach().float().cpu()
415 |             else:
416 |                 out_dict['LQ'] = out_dict['INF']
417 |         return out_dict
418 | 
419 |     def print_network(self):
420 |         s, n = self.get_network_description(self.netG_s)
421 |         if isinstance(self.netG_s, nn.DataParallel):
422 |             net_struc_str = '{} - {}'.format(self.netG_s.__class__.__name__,
423 |                                              self.netG_s.module.__class__.__name__)
424 |         else:
425 |             net_struc_str = '{}'.format(self.netG_s.__class__.__name__)
426 | 
427 |         logger.info(s)
428 | 
429 |         logger.info(
430 |             'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
431 | 
432 |     def save_network(self, distill_step, epoch, iter_step, psnr, ssim, lpips):
433 |         save_root = os.path.join(self.opt['path']['checkpoint'], 'num_step_{}'.format(distill_step))
434 |         os.makedirs(save_root, exist_ok=True)
435 |         # gen_path = os.path.join(save_root, 'P{:.4e}_S{:.4e}_I{}_E{}_gen.pth'.format(psnr, ssim, iter_step, epoch))
436 |         # opt_path = os.path.join(save_root, 'P{:.4e}_S{:.4e}_I{}_E{}_opt.pth'.format(psnr, ssim, iter_step, epoch))
437 |         ema_path = os.path.join(save_root, 'psnr{:.4f}_ssim{:.4f}_lpips{:.4f}_I{}_E{}_gen_ema.pth'.format(psnr, ssim, lpips, iter_step, epoch))
438 |         
439 |         # gen
440 |         # network = self.netG_s
441 |         # if isinstance(self.netG_s, nn.DataParallel):
442 |         #     network = network.module
443 |         # state_dict = network.state_dict()
444 |         # for key, param in state_dict.items():
445 |         #     state_dict[key] = param.cpu()
446 |         # torch.save(state_dict, gen_path)
447 | 
448 |         # opt
449 |  
450 | 
451 | 
452 |         # ema
453 |         self.ema_student.apply_shadow()
454 |         network = self.ema_student.model
455 |         if isinstance(self.ema_student.model, nn.DataParallel):
456 |             network = network.module
457 |         ema_ckpt = network.state_dict()
458 |         for key, param in ema_ckpt.items():
459 |             ema_ckpt[key] = param.cpu()
460 |         torch.save(ema_ckpt, ema_path)
461 |         self.ema_student.restore()
462 |         # logger.info(
463 |         #     'Saved model in [{:s}] ...'.format(gen_path))
464 |         logger.info(
465 |             'Saved model in [{:s}] ...'.format(ema_path))
466 |         return ema_path # gen_path
467 | 
468 |     def load_network(self):
469 |         load_path = self.opt['path']['resume_state']
470 |         if load_path is not None:
471 |             logger.info(
472 |                 'Loading pretrained model for G [{:s}] ...'.format(load_path))
473 |             gen_path = '{}'.format(load_path)
474 |             
475 |             # gen
476 |             networks = [self.netG_t, self.netG_s]
477 |             for network in networks:
478 |                 if isinstance(network, nn.DataParallel):
479 |                     network = network.module
480 |                 ckpt = torch.load(gen_path)
481 | 
482 |                 current_state_dict = network.state_dict()
483 |                 for name, param in ckpt.items():
484 |                     if name in skip_para:
485 |                         continue
486 | 
487 |                     else:
488 |                         current_state_dict[name] = param
489 | 
490 |                 network.load_state_dict(current_state_dict, strict=False)
491 | 
492 |             if self.opt['phase'] == 'train':
493 | 
494 |                 self.begin_step = 0
495 |                 self.begin_epoch = 0
496 | 
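The checkpoint naming used by `DDPM_PD.save_network` above can be sketched in isolation. This is a minimal illustration, not the repository's code: the helper name `ema_checkpoint_path` and the example values are assumptions; only the directory and file-name templates come from the method.

```python
import os


def ema_checkpoint_path(checkpoint_root, distill_step, epoch, iter_step, psnr, ssim, lpips):
    """Hypothetical helper mirroring the path templates in DDPM_PD.save_network."""
    # Each template gets its own .format() call; a single trailing .format()
    # would only apply to the last string literal passed to os.path.join.
    save_root = os.path.join(checkpoint_root, 'num_step_{}'.format(distill_step))
    file_name = 'psnr{:.4f}_ssim{:.4f}_lpips{:.4f}_I{}_E{}_gen_ema.pth'.format(
        psnr, ssim, lpips, iter_step, epoch)
    return os.path.join(save_root, file_name)


# Example values are made up for illustration.
path = ema_checkpoint_path('ckpt', distill_step=8, epoch=3, iter_step=1200,
                           psnr=23.5, ssim=0.85, lpips=0.12)
print(path)
```

Keeping the distillation step in the directory name and the metrics in the file name makes it possible to pick the best EMA checkpoint per step count by listing a single folder.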
--------------------------------------------------------------------------------
/model/networks.py:
--------------------------------------------------------------------------------
  1 | import functools
  2 | import logging
  3 | import torch
  4 | import torch.nn as nn
  5 | from torch.nn import init
  6 | from torch.nn import modules
  7 | logger = logging.getLogger('base')
  8 | ####################
  9 | # initialize
 10 | ####################
 11 | 
 12 | 
 13 | def weights_init_normal(m, std=0.02):
 14 |     classname = m.__class__.__name__
 15 |     if classname.find('Conv') != -1:
 16 |         init.normal_(m.weight.data, 0.0, std)
 17 |         if m.bias is not None:
 18 |             m.bias.data.zero_()
 19 |     elif classname.find('Linear') != -1:
 20 |         init.normal_(m.weight.data, 0.0, std)
 21 |         if m.bias is not None:
 22 |             m.bias.data.zero_()
 23 |     elif classname.find('BatchNorm2d') != -1:
 24 |         init.normal_(m.weight.data, 1.0, std)  # BN also uses norm
 25 |         init.constant_(m.bias.data, 0.0)
 26 | 
 27 | 
 28 | def weights_init_kaiming(m, scale=1):
 29 |     classname = m.__class__.__name__
 30 |     if classname.find('Conv2d') != -1:
 31 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
 32 |         m.weight.data *= scale
 33 |         if m.bias is not None:
 34 |             m.bias.data.zero_()
 35 |     elif classname.find('Linear') != -1:
 36 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
 37 |         m.weight.data *= scale
 38 |         if m.bias is not None:
 39 |             m.bias.data.zero_()
 40 |     elif classname.find('BatchNorm2d') != -1:
 41 |         init.constant_(m.weight.data, 1.0)
 42 |         init.constant_(m.bias.data, 0.0)
 43 | 
 44 | 
 45 | def weights_init_orthogonal(m):
 46 |     classname = m.__class__.__name__
 47 |     if classname.find('Conv') != -1:
 48 |         init.orthogonal_(m.weight.data, gain=1)
 49 |         if m.bias is not None:
 50 |             m.bias.data.zero_()
 51 |     elif classname.find('Linear') != -1:
 52 |         init.orthogonal_(m.weight.data, gain=1)
 53 |         if m.bias is not None:
 54 |             m.bias.data.zero_()
 55 |     elif classname.find('BatchNorm2d') != -1:
 56 |         init.constant_(m.weight.data, 1.0)
 57 |         init.constant_(m.bias.data, 0.0)
 58 | 
 59 | 
 60 | def init_weights(net, init_type='kaiming', scale=1, std=0.02):
 61 |     # scale for 'kaiming', std for 'normal'.
 62 |     logger.info('Initialization method [{:s}]'.format(init_type))
 63 |     if init_type == 'normal':
 64 |         weights_init_normal_ = functools.partial(weights_init_normal, std=std)
 65 |         net.apply(weights_init_normal_)
 66 |     elif init_type == 'kaiming':
 67 |         weights_init_kaiming_ = functools.partial(
 68 |             weights_init_kaiming, scale=scale)
 69 |         net.apply(weights_init_kaiming_)
 70 |     elif init_type == 'orthogonal':
 71 |         net.apply(weights_init_orthogonal)
 72 |     else:
 73 |         raise NotImplementedError(
 74 |             'initialization method [{:s}] not implemented'.format(init_type))
 75 | 
 76 | 
 77 | ####################
 78 | # define network
 79 | ####################
 80 | 
 81 | 
 82 | # Generator
 83 | def define_G(opt, student=False):
 84 |     model_opt = opt['model']
 85 |     print(model_opt['which_model_G'])
 86 |     if model_opt['which_model_G'] == 'ddpm':
 87 |         from .ddpm_modules import diffusion, unet
 88 |     if ('norm_groups' not in model_opt['unet']) or model_opt['unet']['norm_groups'] is None:
 89 |         model_opt['unet']['norm_groups']=32
 90 | 
 91 | 
 92 |     model = unet.UNet(
 93 |         in_channel=model_opt['unet']['in_channel'],
 94 |         out_channel=model_opt['unet']['out_channel'],
 95 |         norm_groups=model_opt['unet']['norm_groups'],
 96 |         inner_channel=model_opt['unet']['inner_channel'],
 97 |         channel_mults=model_opt['unet']['channel_multiplier'],
 98 |         attn_res=model_opt['unet']['attn_res'],
 99 |         res_blocks=model_opt['unet']['res_blocks'],
100 |         dropout=model_opt['unet']['dropout'],
101 |         image_size=model_opt['diffusion']['image_size']
102 |     )
103 |     '''
104 |     if opt['freq_aware']:
105 |         model = unet.Free_UNet(
106 |             in_channel=model_opt['unet']['in_channel'],
107 |             out_channel=model_opt['unet']['out_channel'],
108 |             norm_groups=model_opt['unet']['norm_groups'],
109 |             inner_channel=model_opt['unet']['inner_channel'],
110 |             channel_mults=model_opt['unet']['channel_multiplier'],
111 |             attn_res=model_opt['unet']['attn_res'],
112 |             res_blocks=model_opt['unet']['res_blocks'],
113 |             dropout=model_opt['unet']['dropout'],
114 |             image_size=model_opt['diffusion']['image_size'],
115 |             b1=opt['freq_awareUNet']['b1'],
116 |             b2=opt['freq_awareUNet']['b2'],
117 |             s1=opt['freq_awareUNet']['s1'],
118 |             s2=opt['freq_awareUNet']['s2']
119 |         )
120 |     '''
121 | 
122 |     # print(model_opt['beta_schedule']['train']['n_timestep'])
123 |     if student: 
124 |         netG = diffusion.GaussianDiffusion(
125 |             model,
126 |             image_size=model_opt['diffusion']['image_size'],
127 |             num_timesteps=model_opt['beta_schedule']['train']['n_timestep'] // 2,
128 |             time_scale=model_opt['beta_schedule']['train']['time_scale'] * 2,
129 |             channels=model_opt['diffusion']['channels'],
130 |             w_gt= model_opt['diffusion']['w_gt'],
131 |             w_snr= model_opt['diffusion']['w_snr'],
132 |             w_str= model_opt['diffusion']['w_str'],
133 |             w_lpips= model_opt['diffusion']['w_lpips'],
134 |             loss_type='l1',   
135 |             conditional=model_opt['diffusion']['conditional'],
136 |             schedule_opt=model_opt['beta_schedule']['train'])
137 |     else:
138 |   
139 |         netG = diffusion.GaussianDiffusion(
140 |             model,
141 |             image_size=model_opt['diffusion']['image_size'],
142 |             num_timesteps=model_opt['beta_schedule']['train']['n_timestep'] ,
143 |             time_scale=model_opt['beta_schedule']['train']['time_scale'],
144 |             channels=model_opt['diffusion']['channels'],
145 |             w_gt= model_opt['diffusion']['w_gt'],
146 |             w_snr= model_opt['diffusion']['w_snr'],
147 |             w_str= model_opt['diffusion']['w_str'],
148 |             w_lpips= model_opt['diffusion']['w_lpips'],
149 |             loss_type='l1',   
150 |             conditional=model_opt['diffusion']['conditional'],
151 |             schedule_opt=model_opt['beta_schedule']['train'])
152 | 
153 |     if opt['phase'] == 'train':
154 |         # init_weights(netG, init_type='kaiming', scale=0.1)
155 |         init_weights(netG, init_type='orthogonal')
156 |     if opt['gpu_ids'] and opt['distributed']:
157 |         assert torch.cuda.is_available()
158 |         netG = nn.DataParallel(netG)
159 |     return netG
160 | 
161 |     
162 | # Generator
163 | def define_GGG(opt):
164 |     model_opt = opt['model']
165 |     print(model_opt['which_model_G'])
166 |     if model_opt['which_model_G'] == 'ddpm':
167 |         from .ddpm_modules import diffusion, unet
168 |     if ('norm_groups' not in model_opt['unet']) or model_opt['unet']['norm_groups'] is None:
169 |         model_opt['unet']['norm_groups']=32
170 |     model = unet.UNet(
171 |         in_channel=model_opt['unet']['in_channel'],
172 |         out_channel=model_opt['unet']['out_channel'],
173 |         norm_groups=model_opt['unet']['norm_groups'],
174 |         inner_channel=model_opt['unet']['inner_channel'],
175 |         channel_mults=model_opt['unet']['channel_multiplier'],
176 |         attn_res=model_opt['unet']['attn_res'],
177 |         res_blocks=model_opt['unet']['res_blocks'],
178 |         dropout=model_opt['unet']['dropout'],
179 |         image_size=model_opt['diffusion']['image_size']
180 |     )
181 |     netGVar = diffusion.GaussianDiffusion(
182 |         model,
183 |         image_size=model_opt['diffusion']['image_size'],
184 |         channels=model_opt['diffusion']['channels'],
185 |         loss_type='l1',  
186 |         conditional=model_opt['diffusion']['conditional'],
187 |         schedule_opt=model_opt['beta_schedule']['train']
188 |     )
189 |     if opt['phase'] == 'train':
190 |         # init_weights(netG, init_type='kaiming', scale=0.1)
191 |         init_weights(netGVar, init_type='orthogonal')
192 |     if opt['gpu_ids'] and opt['distributed']:
193 |         assert torch.cuda.is_available()
194 |         netGVar = nn.DataParallel(netGVar)
195 |     return netGVar
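`define_G` above builds the student with half the teacher's `n_timestep` and twice its `time_scale`. A minimal sketch of that relation, assuming plain integer values; the function name `student_schedule` is hypothetical, and the repeated halving in the usage loop is an assumption rather than something this file performs.

```python
def student_schedule(n_timestep, time_scale):
    """Hypothetical helper: timestep settings define_G uses when student=True."""
    # The student takes half as many steps, each spanning twice the time scale,
    # so n_timestep * time_scale is preserved when n_timestep is even.
    return n_timestep // 2, time_scale * 2


# Assumed usage: halve repeatedly, starting from a 16-step teacher.
schedule = [(16, 1)]
while schedule[-1][0] > 2:
    schedule.append(student_schedule(*schedule[-1]))
print(schedule)
```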
--------------------------------------------------------------------------------
/options/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/options/options.py:
--------------------------------------------------------------------------------
 1 | import os
 2 | import os.path as osp
 3 | import logging
 4 | import yaml
 5 | from utils.util import OrderedYaml
 6 | 
 7 | Loader, Dumper = OrderedYaml()
 8 | 
 9 | 
10 | def parse(opt_path, is_train=True):
11 |     with open(opt_path, mode='r') as f:
12 |         opt = yaml.load(f, Loader=Loader)
13 |     # export CUDA_VISIBLE_DEVICES
14 |     gpu_list = ','.join(str(x) for x in opt.get('gpu_ids', []))
15 |     # os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list
16 |     # print('export CUDA_VISIBLE_DEVICES=' + gpu_list)
17 |     opt['is_train'] = is_train
18 | 
19 |     # datasets
20 |     for phase, dataset in opt['datasets'].items():
21 |         phase = phase.split('_')[0]
22 |         dataset['phase'] = phase
23 | 
24 |         is_lmdb = False
25 |         if dataset.get('dataroot_GT', None) is not None:
26 |             dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT'])
27 |             if dataset['dataroot_GT'].endswith('lmdb'):
28 |                 is_lmdb = True
29 |         if dataset.get('dataroot_LQ', None) is not None:
30 |             dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ'])
31 |             if dataset['dataroot_LQ'].endswith('lmdb'):
32 |                 is_lmdb = True
33 |         dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
34 | 
35 |     # relative learning rate
36 |     if 'train' in opt:
37 |         niter = opt['train']['niter']
38 |         if 'T_period_rel' in opt['train']:
39 |             opt['train']['T_period'] = [int(x * niter) for x in opt['train']['T_period_rel']]
40 |         if 'restarts_rel' in opt['train']:
41 |             opt['train']['restarts'] = [int(x * niter) for x in opt['train']['restarts_rel']]
42 |         if 'lr_steps_rel' in opt['train']:
43 |             opt['train']['lr_steps'] = [int(x * niter) for x in opt['train']['lr_steps_rel']]
44 |         if 'lr_steps_inverse_rel' in opt['train']:
45 |             opt['train']['lr_steps_inverse'] = [int(x * niter) for x in opt['train']['lr_steps_inverse_rel']]
46 |         print(opt['train'])
47 | 
48 |     return opt
49 | 
50 | 
51 | def dict2str(opt, indent_l=1):
52 |     '''dict to string for logger'''
53 |     msg = ''
54 |     for k, v in opt.items():
55 |         if isinstance(v, dict):
56 |             msg += ' ' * (indent_l * 2) + k + ':[\n'
57 |             msg += dict2str(v, indent_l + 1)
58 |             msg += ' ' * (indent_l * 2) + ']\n'
59 |         else:
60 |             msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
61 |     return msg
62 | 
63 | 
64 | class NoneDict(dict):
65 |     def __missing__(self, key):
66 |         return None
67 | 
68 | 
69 | # convert to NoneDict, which returns None for missing keys.
70 | def dict_to_nonedict(opt):
71 |     if isinstance(opt, dict):
72 |         new_opt = dict()
73 |         for key, sub_opt in opt.items():
74 |             new_opt[key] = dict_to_nonedict(sub_opt)
75 |         return NoneDict(**new_opt)
76 |     elif isinstance(opt, list):
77 |         return [dict_to_nonedict(sub_opt) for sub_opt in opt]
78 |     else:
79 |         return opt
80 | 
81 | 
82 | def check_resume(opt, resume_iter):
83 |     '''Check resume states and pretrain_model paths'''
84 |     logger = logging.getLogger('base')
85 |     if opt['path']['resume_state']:
86 |         if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get(
87 |                 'pretrain_model_D', None) is not None:
88 |             logger.warning('pretrain_model path will be ignored when resuming training.')
89 | 
90 |         opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'],
91 |                                                    '{}_G.pth'.format(resume_iter))
92 |         logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G'])
93 |         if 'gan' in opt['model']:
94 |             opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'],
95 |                                                        '{}_D.pth'.format(resume_iter))
96 |             logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D'])
97 | 
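The `NoneDict` conversion and the relative-step expansion in `parse` above can be exercised without a YAML file. A minimal sketch: the two definitions are restated from this file in condensed form, while the sample option values are made up for illustration.

```python
class NoneDict(dict):
    def __missing__(self, key):
        # Missing keys yield None instead of raising KeyError.
        return None


def dict_to_nonedict(opt):
    # Recursively wrap nested dicts (including dicts inside lists).
    if isinstance(opt, dict):
        return NoneDict(**{k: dict_to_nonedict(v) for k, v in opt.items()})
    if isinstance(opt, list):
        return [dict_to_nonedict(v) for v in opt]
    return opt


# Illustrative values only; not taken from the repository's configs.
opt = {'train': {'niter': 1000, 'lr_steps_rel': [0.5, 0.75]}}
niter = opt['train']['niter']
opt['train']['lr_steps'] = [int(x * niter) for x in opt['train']['lr_steps_rel']]

opt = dict_to_nonedict(opt)
print(opt['train']['lr_steps'])
print(opt['seed'])
```

Note that only subscript access (`opt['seed']`) goes through `__missing__`; `opt.get('seed')` behaves like a regular `dict.get`.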
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python==4.5.2.54
2 | PyYAML==6.0
3 | natsort==8.1.0
4 | scikit-image==0.18.1
5 | lpips==0.1.4
6 | kmeans_pytorch
7 | scikit-learn==1.0
8 | 
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | from os.path import basename
  3 | import math
  4 | import argparse
  5 | import random
  6 | import logging
  7 | import cv2
  8 | import sys
  9 | import numpy as np
 10 | import torch
 11 | import torch.distributed as dist
 12 | import torch.multiprocessing as mp
 13 | import torch.nn as nn
 14 | 
 15 | import options.options as option
 16 | from utils import util
 17 | from data import create_dataloader
 18 | from data.LoL_dataset import LOLv1_Dataset, LOLv2_Dataset
 19 | import torchvision.transforms as T
 20 | import lpips
 21 | import model as Model
 22 | import core.logger as Logger
 23 | import core.metrics as Metrics
 24 | from torchvision import transforms
 25 | 
 26 | 
 27 | transform = transforms.Lambda(lambda t: (t * 2) - 1)
 28 | 
 29 | def main():
 30 | 
 31 |     parser = argparse.ArgumentParser()
 33 |     parser.add_argument('--dataset', type=str, help='Path to option YAML file.',
 34 |                             default='./config/lolv1.yml')
 35 |     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
 36 |                         help='job launcher')
 37 |     parser.add_argument('--local_rank', type=int, default=0)
 38 |     parser.add_argument('--tfboard', action='store_true')
 39 | 
 40 | 
 41 |     parser.add_argument('-c', '--config', type=str, default='config/lolv1_test.json',
 42 |                         help='JSON file for configuration')
 43 |     parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'],
 44 |                         help='Run either train(training) or val(generation)', default='train')
 45 |     parser.add_argument('-gpu', '--gpu_ids', type=str, default="0")
 46 |     parser.add_argument('-debug', '-d', action='store_true')
 47 |     parser.add_argument('-log_eval', action='store_true')
 48 | 
 49 |     # search noise schedule
 50 |     parser.add_argument('--brutal_search', action='store_true')
 51 |     parser.add_argument('--noise_start', type=float, default=9e-4)
 52 |     parser.add_argument('--noise_end', type=float, default=8.5e-1)
 53 |     parser.add_argument('--n_timestep', type=int, default=16)
 54 | 
 55 |     parser.add_argument('--w_str', type=float, default=0.)
 56 |     parser.add_argument('--w_snr', type=float, default=0.)
 57 |     parser.add_argument('--w_gt', type=float, default=0.1)
 58 |     parser.add_argument('--w_lpips', type=float, default=0.1)
 59 | 
 60 |     parser.add_argument('--stride', type=int, default=1)
 61 | 
 62 | 
 63 |     # for freq_aware
 64 |     parser.add_argument('--freq_aware', action='store_true')
 65 |     parser.add_argument('--b1', type=float, default=1.6)
 66 |     parser.add_argument('--b2', type=float, default=1.6)
 67 |     parser.add_argument('--s1', type=float, default=0.9)
 68 |     parser.add_argument('--s2', type=float, default=0.9)
 69 | 
 70 |     args = parser.parse_args()
 71 |     opt = Logger.parse(args)
 72 |     opt = Logger.dict_to_nonedict(opt)
 73 |     opt_dataset = option.parse(args.dataset, is_train=True)
 74 | 
 75 | 
 76 | 
 77 |     os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 78 | 
 79 |     opt['phase'] = 'test'
 80 |     opt['distill'] = False
 81 |     opt['uncertainty_train'] = False
 82 | 
 83 |     #### distributed training settings
 84 |     opt['dist'] = False
 85 |     rank = -1
 86 |     print('Disabled distributed training.')
 87 | 
 88 |     #### mkdir and loggers
 89 |     if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
 90 |         # config loggers. Before it, the log will not work
 91 |         util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
 92 |                           screen=True, tofile=True)
 93 |         util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
 94 |                           screen=True, tofile=True)
 95 |         logger = logging.getLogger('base')
 96 |         logger.info(option.dict2str(opt))
 97 | 
 98 |     # convert to NoneDict, which returns None for missing keys
 99 |     opt = option.dict_to_nonedict(opt)
100 | 
101 |     #### seed
102 |     seed = opt['seed']
103 |     if seed is None:
104 |         seed = random.randint(1, 10000)
105 |     if rank <= 0:
106 |         logger.info('Seed: {}'.format(seed))
107 |     util.set_random_seed(seed)
108 | 
109 |     torch.backends.cudnn.benchmark = True
110 |     # torch.backends.cudnn.deterministic = True
111 | 
112 |     #### create train and val dataloader
113 |     if opt_dataset['dataset'] == 'LOLv1':
114 |         dataset_cls = LOLv1_Dataset
115 |     elif opt_dataset['dataset'] == 'LOLv2':
116 |         dataset_cls = LOLv2_Dataset
117 | 
118 |     else:
119 |         raise NotImplementedError()
120 | 
121 |     for phase, dataset_opt in opt_dataset['datasets'].items():
122 |         if phase == 'val':
123 |             val_set = dataset_cls(opt=dataset_opt, train=False, all_opt=opt_dataset)
124 |             val_loader = create_dataloader(val_set, dataset_opt, opt_dataset, None)
125 | 
126 |     # opt["model"]['beta_schedule']["train"]["time_scale"] = 1
127 | 
128 |     opt["model"]["diffusion"]["w_snr"] = args.w_snr
129 |     opt["model"]["diffusion"]["w_str"] = args.w_str
130 |     opt["model"]["diffusion"]["w_gt"] = args.w_gt
131 | 
132 |     # model
133 |     diffusion = Model.create_model(opt)
134 |     logger.info('Initial Model Finished')
135 | 
136 | 
137 | 
138 |     loss_fn_vgg = lpips.LPIPS(net='vgg').cuda()
139 |     result_path = '{}'.format(opt['path']['results'])
140 |     result_path_gt = result_path+'/gt/'
141 |     result_path_out = result_path+'/output/'
142 |     result_path_input = result_path+'/input/'
143 |     os.makedirs(result_path_gt, exist_ok=True)
144 |     os.makedirs(result_path_out, exist_ok=True)
145 |     os.makedirs(result_path_input, exist_ok=True)
146 | 
147 |     #diffusion.set_new_noise_schedule(
148 |         #opt['model']['beta_schedule']['val'], schedule_phase='val')
149 |     
150 |     
151 |     logger_val = logging.getLogger('val')  # validation logger
152 | 
153 |     avg_psnr = 0.0
154 |     avg_ssim = 0.0
155 |     avg_lpips = 0.0
156 |     idx = 0
157 |     lpipss = []
158 |     
159 |     for val_data in val_loader:
160 | 
161 |         idx += 1
162 |         diffusion.feed_data(val_data)
163 |         diffusion.test(continous=False)
164 | 
165 |         visuals = diffusion.get_current_visuals()
166 |         
167 |         normal_img = Metrics.tensor2img(visuals['HQ'])
168 |         if normal_img.shape[0] != normal_img.shape[1]: # lolv1 and lolv2-real
169 |             normal_img = normal_img[8:408, 4:604,:]
170 |         gt_img = Metrics.tensor2img(visuals['GT'])
171 |         ll_img = Metrics.tensor2img(visuals['LQ'])
172 | 
173 |         img_mode = 'single'
174 |         if img_mode == 'single':
175 |             util.save_img(
176 |                 gt_img, '{}/{}_gt.png'.format(result_path_gt, idx))
177 |             util.save_img(
178 |                 ll_img, '{}/{}_lq.png'.format(result_path_input, idx))
179 |             # util.save_img(
180 |             #     normal_img, '{}/{}_normal_noadjust.png'.format(result_path, idx))
181 |         else:
182 |             util.save_img(
183 |                 gt_img, '{}/{}_gt.png'.format(result_path, idx))
184 |             util.save_img(
185 |                 normal_img, '{}/{}_normal_process.png'.format(result_path, idx))
186 |             # for i in range(visuals['HQ'].shape[0]):
187 |             #     util.save_img(Metrics.tensor2img(visuals['HQ'][i]), '{}/{}_{}_normal.png'.format(result_path, idx, i))
188 |             # util.save_img(
189 |             #     Metrics.tensor2img(visuals['HQ'][-1]), '{}/{}_normal.png'.format(result_path, idx))
190 |             normal_img = Metrics.tensor2img(visuals['HQ'][-1])
191 | 
192 |         # Following LLFlow and KinD, rescale the output so that its mean gray level matches the
193 |         # ground truth; see line 73 of https://github.com/zhangyhuaee/KinD/blob/master/evaluate_LOLdataset.py.
194 |         gt_img = gt_img / 255.
195 |         normal_img = normal_img / 255.
196 |         mean_gray_out = cv2.cvtColor(normal_img.astype(np.float32), cv2.COLOR_BGR2GRAY).mean()
197 |         mean_gray_gt = cv2.cvtColor(gt_img.astype(np.float32), cv2.COLOR_BGR2GRAY).mean()
198 |         normal_img_adjust = np.clip(normal_img * (mean_gray_gt / mean_gray_out), 0, 1)
199 | 
200 |         normal_img = (normal_img_adjust * 255).astype(np.uint8)
201 |         gt_img = (gt_img * 255).astype(np.uint8)
202 | 
203 |         psnr = util.calculate_psnr(normal_img, gt_img)
204 |         ssim = util.calculate_ssim(normal_img, gt_img)
205 | 
206 |         normal_img_tensor = torch.tensor(normal_img.astype(np.float32))
207 |         gt_img_tensor = torch.tensor(gt_img.astype(np.float32))
208 |         normal_img_tensor = normal_img_tensor.permute(2, 0, 1).cuda()
209 |         gt_img_tensor = gt_img_tensor.permute(2, 0, 1).cuda()
210 |         lpips_scores = loss_fn_vgg(normal_img_tensor, gt_img_tensor).item()
211 |         
212 |         util.save_img(normal_img, '{}/{}_normal.png'.format(result_path_out, idx))
213 | 
214 |         # lpips
215 |         
216 |         # lpips_ = loss_fn_vgg(visuals['HQ'], visuals['GT'])
217 |         # lpipss.append(lpips_scores.numpy())
218 | 
219 |         logger_val.info('### {} cPSNR: {:.4e} cSSIM: {:.4e} cLPIPS: {:.4e}'.format(idx, psnr, ssim, lpips_scores))
220 |         avg_ssim += ssim
221 |         avg_psnr += psnr
222 |         avg_lpips += lpips_scores
223 | 
224 |     avg_psnr = avg_psnr / idx
225 |     avg_ssim = avg_ssim / idx
226 |     avg_lpips = avg_lpips / idx
227 | 
228 |     # log
229 |     logger_val.info('# Validation # avgPSNR: {:.4e} avgSSIM: {:.4e} avgLPIPS: {:.4e}'.format(avg_psnr, avg_ssim, avg_lpips))
230 |     # logger_val.info(f"n_timestep: {args.n_timestep}, noise_start: {args.noise_start}, noise_end: {args.noise_end}")
231 | 
232 | 
233 | 
234 | if __name__ == '__main__':
235 |     main()
236 | 
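The brightness alignment performed in the validation loop above can be isolated into a small helper. The following is a minimal sketch and not part of the repository: the name `match_mean_brightness` is hypothetical, the gray conversion uses the BT.601 weights on BGR input in plain NumPy instead of `cv2.cvtColor` (assumed to give the same mean up to rounding), and the zero-mean guard is an addition for safety.

```python
import numpy as np


def match_mean_brightness(out_bgr, gt_bgr):
    """Rescale out_bgr (uint8, HxWx3, BGR) so its mean gray level matches gt_bgr."""
    out = out_bgr.astype(np.float32) / 255.0
    gt = gt_bgr.astype(np.float32) / 255.0
    # BT.601 luma weights in B, G, R order (assumption: stands in for cv2.COLOR_BGR2GRAY).
    weights = np.array([0.114, 0.587, 0.299], dtype=np.float32)
    mean_gray_out = (out @ weights).mean()
    mean_gray_gt = (gt @ weights).mean()
    if mean_gray_out == 0:
        # A completely black output cannot be rescaled; return it unchanged.
        return out_bgr.copy()
    adjusted = np.clip(out * (mean_gray_gt / mean_gray_out), 0.0, 1.0)
    return (adjusted * 255).astype(np.uint8)


# Toy usage: a uniformly dark image is scaled towards a brighter reference.
dark = np.full((4, 4, 3), 50, dtype=np.uint8)
bright = np.full((4, 4, 3), 100, dtype=np.uint8)
aligned = match_mean_brightness(dark, bright)
```

Because the scale factor is computed from the ground truth, this step is only available for paired evaluation; it should be kept in mind when comparing the reported PSNR/SSIM with methods that do not apply it.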
--------------------------------------------------------------------------------
/test.sh:
--------------------------------------------------------------------------------
1 | python test.py --dataset ./config/lolv2_real.yml --config config/lolv2_real_test.json --w_str 0.9 --w_snr 0.2 --w_gt 0.2 
2 | 
3 | 
4 | 
--------------------------------------------------------------------------------
/test_unpaired.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | from os.path import basename
  3 | import math
  4 | import argparse
  5 | import random
  6 | import logging
  7 | import cv2
  8 | import torch
  9 | import torchvision.transforms.functional as TF
 10 | from PIL import Image
 11 | import options.options as option
 12 | from utils import util
 13 | import torchvision.transforms as T
 14 | import model as Model
 15 | import core.logger as Logger
 16 | import core.metrics as Metrics
 17 | import natsort
 18 | from torchvision import transforms
 19 | from utils.niqe import niqe
 20 | 
 21 | transform = transforms.Lambda(lambda t: (t * 2) - 1)
 22 | 
 23 | def main():
 24 |     #### options
 25 |     parser = argparse.ArgumentParser()
 27 |     parser.add_argument('--dataset', type=str, help='Path to option YAML file.',
 28 |                             default='./config/dataset.yml') # 
 29 |     parser.add_argument('--input', type=str, help='testing the unpaired image',
 30 |                             default='images/unpaired/')
 31 |     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
 32 |                         help='job launcher')
 33 |     parser.add_argument('--local_rank', type=int, default=0)
 34 |     parser.add_argument('--tfboard', action='store_true')
 35 |     parser.add_argument('-c', '--config', type=str, default='config/test_unpaired.json',
 36 |                         help='JSON file for configuration')
 37 |     parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'],
 38 |                         help='Run either train(training) or val(generation)', default='train')
 39 |     parser.add_argument('-gpu', '--gpu_ids', type=str, default="0")
 40 |     parser.add_argument('-debug', '-d', action='store_true')
 41 |     parser.add_argument('-enable_wandb', action='store_true')
 42 |     parser.add_argument('-log_wandb_ckpt', action='store_true')
 43 |     parser.add_argument('-log_eval', action='store_true')
 44 | 
 45 |     parser.add_argument('--n_timestep', type=int, default=8)
 46 |     parser.add_argument('--w_str', type=float, default=0.)
 47 |     parser.add_argument('--w_snr', type=float, default=0.)
 48 |     parser.add_argument('--w_gt', type=float, default=0.)
 49 |     parser.add_argument('--w_lpips', type=float, default=0.)
 50 | 
 51 | 
 52 |     parser.add_argument('--brutal_search', action='store_true')
 53 | 
 54 |     # parse configs
 55 |     args = parser.parse_args()
 56 |     opt = Logger.parse(args)
 57 |     # Convert to NoneDict, which returns None for missing keys.
 58 |     opt = Logger.dict_to_nonedict(opt)
 59 | 
 60 |     # os.environ['CUDA_VISIBLE_DEVICES'] = '0'
 61 | 
 62 |     opt['phase'] = 'test'
 63 | 
 64 |     #### distributed training settings
 65 |     opt['dist'] = False
 66 |     rank = -1
 67 |     print('Disabled distributed training.')
 68 | 
 69 |     #### mkdir and loggers
 70 |     if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
 71 |         # config loggers. Before it, the log will not work
 72 |         util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
 73 |                           screen=True, tofile=True)
 74 |         logger = logging.getLogger('base')
 75 |         logger.info(option.dict2str(opt))
 76 | 
 77 | 
 78 |     util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True,  tofile=True)
 79 |     logger = logging.getLogger('base')
 80 | 
 81 |     # convert to NoneDict, which returns None for missing keys
 82 |     opt = option.dict_to_nonedict(opt)
 83 | 
 84 |     #### random seed
 85 |     seed = opt['train']['manual_seed']
 86 |     if seed is None:
 87 |         seed = random.randint(1, 10000)
 88 |     if rank <= 0:
 89 |         logger.info('Random seed: {}'.format(seed))
 90 |     util.set_random_seed(seed)
 91 | 
 92 |     torch.backends.cudnn.benchmark = True
 93 |     # torch.backends.cudnn.deterministic = True
 94 | 
 95 | 
 96 |     # model
 97 |     diffusion = Model.create_model(opt)
 98 |     logger.info('Initial Model Finished')
 99 | 
100 |     result_path = '{}'.format(opt['path']['results'])
101 |     os.makedirs(result_path, exist_ok=True)
102 | 
103 |     # diffusion.set_new_noise_schedule(
104 |     #     opt['model']['beta_schedule']['val'], schedule_phase='val')
105 | 
106 |     InputPath = args.input
107 |     Image_names = natsort.natsorted(os.listdir(InputPath), alg=natsort.ns.PATH)
108 | 
109 |     ave_niqe = 0.
110 | 
111 |     for i in range(len(Image_names)):
112 | 
113 |         path = os.path.join(InputPath, Image_names[i])
114 |         raw_img = Image.open(path).convert('RGB')
115 |         img_w = raw_img.size[0]
116 |         img_h = raw_img.size[1]
117 |         raw_img = transforms.Resize((img_h // 16 * 16, img_w // 16 * 16))(raw_img)
118 | 
119 |         raw_img = transform(TF.to_tensor(raw_img)).unsqueeze(0).cuda()
120 | 
121 |         val_data = {}
122 |         val_data['LQ'] = raw_img
123 |         val_data['GT'] = raw_img
124 |         diffusion.feed_data(val_data)
125 |         diffusion.test(continous=False)
126 | 
127 |         visuals = diffusion.get_current_visuals()
128 |         
129 |         normal_img = Metrics.tensor2img(visuals['HQ']) 
130 |         # normal_img = cv2.resize(normal_img, (img_w, img_h))
131 |         ll_img = Metrics.tensor2img(visuals['LQ']) 
132 |         niqe_scores =  niqe(normal_img)
133 |         ave_niqe = ave_niqe + niqe_scores
134 |         llie_img_mode = 'single'
135 |         if llie_img_mode == 'single':
136 |             # util.save_img(ll_img, '{}/{}_input.png'.format(result_path, idx))
137 |             util.save_img(
138 |                 normal_img, '{}/{}_normal.png'.format(result_path, i+1))
139 |         else:
140 |             util.save_img(
141 |                 normal_img, '{}/{}_normal_process.png'.format(result_path, i))
142 |             util.save_img(
143 |                 Metrics.tensor2img(visuals['HQ'][-1]), '{}/{}_normal.png'.format(result_path, i))
144 |             normal_img = Metrics.tensor2img(visuals['HQ'][-1])
145 |         logger.info('CNIQE: {} on {}'.format(niqe_scores, Image_names[i]))
146 |     ave_niqe = ave_niqe / len(Image_names)
147 |     logger.info('NIQE: {} on {}'.format(ave_niqe, InputPath))
148 | 
149 | 
150 | 
151 | if __name__ == '__main__':
152 |     main()
153 |     print("finish!")
154 | 
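`test_unpaired.py` resizes every input so that its height and width are multiples of 16 before it is passed to the model (presumably so that the network's downsampling stages divide evenly, although the source does not state the reason). The arithmetic can be sketched as follows; the helper names `floor_to_multiple` and `resized_shape` are hypothetical and do not exist in the repository.

```python
def floor_to_multiple(size, base=16):
    """Round size down to the nearest multiple of base."""
    return size // base * base


def resized_shape(img_w, img_h, base=16):
    """Return (height, width) in the order expected by transforms.Resize."""
    return floor_to_multiple(img_h, base), floor_to_multiple(img_w, base)


# A 600x400 (width x height) image keeps its height and loses 8 columns.
shape = resized_shape(600, 400)
tiny = resized_shape(10, 10)
```

Note that any side shorter than 16 pixels rounds down to 0, so very small inputs would need to be padded or skipped before resizing.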
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
  1 | import os
  2 | from os.path import basename
  3 | import math
  4 | import argparse
  5 | import logging
  6 | import cv2
  7 | import sys
  8 | import numpy as np
  9 | import torch
 10 | import torch.distributed as dist
 11 | import torch.multiprocessing as mp
 12 | import torch.nn as nn
 13 | from torchvision import transforms
 14 | import options.options as option
 15 | from utils import util
 16 | from data import create_dataloader
 17 | import data as Data
 18 | from data.LoL_dataset import LOLv1_Dataset, LOLv2_Dataset
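 19 | from data.SDSD_image_dataset import Dataset_SDSDImage  # NOTE: data/SDSD_image_dataset.py is not listed in this repository's tree
 20 | from data.SID import ImageDataset2  # NOTE: data/SID.py is not listed in this repository's tree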
 21 | import torchvision.transforms as T
 22 | import model as Model
 23 | import core.logger as Logger
 24 | import core.metrics as Metrics
 25 | import random
 26 | import lpips
 27 | 
 28 | import pdb
 29 | 
 30 | 
 31 | 
 32 | 
 33 | def init_dist(backend='nccl', **kwargs):
 34 |     """initialization for distributed training"""
 35 |     if mp.get_start_method(allow_none=True) != 'spawn':
 36 |         mp.set_start_method('spawn')
 37 |     rank = int(os.environ['RANK'])
 38 |     num_gpus = torch.cuda.device_count()
 39 |     torch.cuda.set_device(rank % num_gpus)
 40 |     dist.init_process_group(backend=backend, **kwargs)
 41 | 
 42 | 
 43 | def main():
 44 | 
 45 |     parser = argparse.ArgumentParser()
 46 |     parser.add_argument('--dataset', type=str, help='Path to option YAML file.',
 47 |                             default='./config/lolv2_real.yml') # 
 48 |     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
 49 |                         help='job launcher')
 50 |     parser.add_argument('--local_rank', type=int, default=0)
 51 | 
 52 |     parser.add_argument('-c', '--config', type=str, default='config/lolv1_train.json',
 53 |                         help='JSON file for configuration')
 54 |     parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'],
 55 |                         help='Run either train(training) or val(generation)', default='train')
 56 |     parser.add_argument('-gpu', '--gpu_ids', type=str, default=None)
 57 |     parser.add_argument('-debug', '-d', action='store_true')
 58 |     parser.add_argument('-log_eval', action='store_true')
 59 |     parser.add_argument('-uncertainty', action='store_true')
 60 | 
 61 |     # for ablation
 62 |     parser.add_argument('--ablation', action='store_true')
 63 |     parser.add_argument('--w_str', type=float, default=0.2)
 64 |     parser.add_argument('--w_snr', type=float, default=0.9)
 65 |     parser.add_argument('--w_gt', type=float, default=0.2)
 66 |     parser.add_argument('--w_lpips', type=float, default=0.2)
 67 | 
 68 |     parser.add_argument('--progressive', action='store_true')
 69 |     parser.add_argument('--CD', action='store_true')
 70 | 
 71 | 
 72 | 
 73 |     parser.add_argument('--brutal_search', action='store_true')
 74 | 
 75 |     # ema config
 76 |     parser.add_argument('--ema_decay', type=float, default=0.999)
 77 | 
 78 |     # parse configs
 79 |     args = parser.parse_args()
 80 |     opt = Logger.parse(args)
 81 |     # Convert to NoneDict, which returns None for missing keys.
 82 |     opt = Logger.dict_to_nonedict(opt)
 83 |     opt_dataset = option.parse(args.dataset, is_train=True)
 84 | 
 85 |     if args.ablation:
 86 |         opt["model"]["diffusion"]["w_snr"] = args.w_snr
 87 |         opt["model"]["diffusion"]["w_str"] = args.w_str
 88 |         opt["model"]["diffusion"]["w_gt"] = args.w_gt
 89 |         opt["model"]["diffusion"]["w_lpips"] = args.w_lpips
 90 |     if args.CD:
 91 |         opt["CD"] = True
 92 |     if args.launcher == 'none':  # disabled distributed training
 93 |         opt['dist'] = False
 94 |         rank = -1
 95 |         print('Disabled distributed training.')
 96 |     else:
 97 |         opt['dist'] = True
 98 |         init_dist()
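 99 |         # init_dist() has already called dist.init_process_group(backend='nccl');
100 |         # calling it again here would raise a RuntimeError because the default
101 |         # process group can only be initialised once. With no init_method given,
102 |         # PyTorch falls back to 'env://'.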
103 |         world_size = torch.distributed.get_world_size()
104 |         rank = torch.distributed.get_rank()
105 |         device = torch.device("cuda", rank)
106 | 
107 |     if rank <= 0:  # normal training (rank -1) OR distributed training (rank 0)
108 | 
109 |         # config loggers. Before it, the log will not work
110 |         util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
111 |                           screen=True, tofile=True)
112 |         util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
113 |                           screen=True, tofile=True)
114 |         logger = logging.getLogger('base')
115 |         # import pdb; pdb.set_trace()
116 |         logger.info(option.dict2str(opt))
117 | 
118 |         # import pdb; pdb.set_trace()
119 | 
120 |         # tensorboard logger
121 |         if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
122 |             version = float(torch.__version__[0:3])
123 |             if version >= 1.1:  # PyTorch 1.1
124 |                 # from torch.utils.tensorboard import SummaryWriter
125 |                 if sys.platform != 'win32':
126 |                     from tensorboardX import SummaryWriter
127 |                 else:
128 |                     from tensorboardX import SummaryWriter
129 |                     # from torch.utils.tensorboard import SummaryWriter
130 |             else:
131 |                 logger.info(
132 |                     'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
133 |                 from tensorboardX import SummaryWriter
134 |             conf_name = basename(args.dataset).replace(".yml", "")
135 |             exp_dir = opt['path']['experiments_root']
136 |             log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train')
137 |             log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid')
138 |             tb_logger_train = SummaryWriter(log_dir=log_dir_train)
139 |             tb_logger_valid = SummaryWriter(log_dir=log_dir_valid)
140 |     else:
141 |         util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
142 |         logger = logging.getLogger('base')
143 | 
144 |     # convert to NoneDict, which returns None for missing keys
145 |     opt = option.dict_to_nonedict(opt)
146 | 
147 |     #### random seed
148 |     seed = opt['seed']
149 |     if seed is None:
150 |         seed = random.randint(1, 10000)
151 |     if rank <= 0:
152 |         logger.info('Random seed: {}'.format(seed))
153 |     util.set_random_seed(seed)
154 | 
155 |     torch.backends.cudnn.benchmark = True
156 |     torch.backends.cudnn.deterministic = False
157 |     torch.backends.cudnn.allow_tf32 = True
158 | 
159 |     #### create train and val dataloader
160 |     if opt_dataset['dataset'] == 'LOLv1':
161 |         dataset_cls = LOLv1_Dataset
162 |         PD_steps = [16, 8, 4, 2, 1]
163 |         temp_time_scale = [1, 2, 4, 8, 16]
164 |         time_scale = [i * 32 for i in temp_time_scale]
165 |     elif opt_dataset['dataset'] == 'LOLv2':
166 |         dataset_cls = LOLv2_Dataset      
167 |         PD_steps = [16, 8, 4, 2, 1]
168 |         temp_time_scale = [1, 2, 4, 8, 16]
169 |         time_scale = [i * 32 for i in temp_time_scale]
170 |     elif opt_dataset['dataset'] == 'SDSD_indoor':
171 |         dataset_cls = Dataset_SDSDImage
172 |         PD_steps = [16, 8, 4, 2, 1]
173 |         temp_time_scale = [1, 2, 4, 8, 16]
174 |         time_scale = [i * 32 for i in temp_time_scale]
175 | 
176 |     elif opt_dataset['dataset'] == 'SID':
177 |         dataset_cls = ImageDataset2
178 |         PD_steps = [16, 8, 4, 2, 1]
179 |         temp_time_scale = [1, 2, 4, 8, 16]
180 |         time_scale = [i * 32 for i in temp_time_scale]
181 |     else:
182 |         raise NotImplementedError()
183 | 
184 |     for phase, dataset_opt in opt_dataset['datasets'].items():
185 |         if phase == 'train':
186 |             train_set = dataset_cls(opt=dataset_opt, train=True, all_opt=opt_dataset)
187 |             train_loader = create_dataloader(train_set, dataset_opt, opt_dataset, None)
188 |         elif phase == 'val':
189 |             val_set = dataset_cls(opt=dataset_opt, train=False, all_opt=opt_dataset)
190 |             val_loader = create_dataloader(val_set, dataset_opt, opt_dataset, None)
191 | 
192 |     # model
193 |     resume_state =  opt["path"]["resume_state"]
194 |     lpips_func = lpips.LPIPS(net='vgg').cuda()
195 |             
196 |     for i in range(len(PD_steps) - 1):  # each stage reads PD_steps[i] and PD_steps[i+1]
197 |         opt["model"]['beta_schedule']["train"]["n_timestep"] = PD_steps[i]
198 |         opt["model"]['beta_schedule']["val"]["n_timestep"] = PD_steps[i+1]
199 |         
200 |         opt["path"]["resume_state"] = resume_state
201 |         opt["model"]['beta_schedule']["train"]["time_scale"] = time_scale[i]
202 |         logger.info('Distillation from {:d} to {:d}'.format(opt["model"]['beta_schedule']["train"]["n_timestep"],  opt["model"]['beta_schedule']["val"]["n_timestep"]))
203 |         logger.info(f"w_snr: {opt['model']['diffusion']['w_snr']}, w_str: {opt['model']['diffusion']['w_str']}")
204 | 
205 |         diffusion = Model.create_model(opt)
206 |         
207 |         logger.info('Initial Model Finished')
208 |         # Train
209 |         current_step = diffusion.begin_step
210 |         current_epoch = diffusion.begin_epoch
211 |         n_iter = opt['train']['n_iter'] # * iter_scale[i]
212 |         # training
213 |         logger.info('Start training from epoch: {:d}, iter: {:d}'.format(current_epoch, current_step))
214 |         avg_psnr = 0
215 |         best_psnr = 0
216 |         best_ssim = 0
217 |         best_lpips = 0
218 | 
219 |         # pdb.set_trace()
220 |         while current_step < n_iter:
221 | 
222 |             current_epoch += 1
223 |             for _, train_data in enumerate(train_loader):
224 | 
225 |                 current_step += 1
226 |                 if current_step > n_iter:
227 |                     break
228 |                
229 |                 diffusion.feed_data(train_data)
230 |                 diffusion.optimize_parameters()
231 |                 # log
232 |                 if current_step % opt['train']['print_freq'] == 0 and rank <= 0:
233 |                     logs = diffusion.get_current_log()
234 |                     message = '<epoch:{:3d}, iter:{:8,d}> '.format(
235 |                         current_epoch, current_step)
236 |                     for k, v in logs.items():
237 |                         message += '{:s}: {:.4e} '.format(k, v)
238 |                     logger.info(message)
239 | 
240 |                 # validation
241 |                 if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
242 | 
243 |                     avg_psnr = 0.0
244 |                     avg_ssim = 0.0
245 |                     avg_lpips = 0.0
246 |                     idx = 0
247 |     
248 |                     result_path = '{}/{}'.format(opt['path']['results'], PD_steps[i+1], current_epoch)
249 |                     result_path_gt = result_path+'/gt/'
250 |                     result_path_out = result_path+'/output/'
251 |                     result_path_input = result_path+'/input/'
252 | 
253 |                     os.makedirs(result_path_gt, exist_ok=True)
254 |                     os.makedirs(result_path_out, exist_ok=True)
255 |                     os.makedirs(result_path_input, exist_ok=True)
256 | 
257 |       
258 |                     for val_data in val_loader:
259 | 
260 |                         idx += 1
261 |                         diffusion.feed_data(val_data)
262 |                         diffusion.test(continous=False)
263 | 
264 |                         visuals = diffusion.get_current_visuals()
265 |                         
266 |                         
267 |                         if opt_dataset['dataset'] == 'LOLv1' or  opt_dataset['dataset'] == 'LOLv2' :
268 |                             normal_img = Metrics.tensor2img(visuals['HQ'])
269 |                             if normal_img.shape[0] != normal_img.shape[1]: # lolv1 and lolv2-real
270 |                                 normal_img = normal_img[8:408, 4:604,:]
271 |                             gt_img = Metrics.tensor2img(visuals['GT'])
272 |                             ll_img = Metrics.tensor2img(visuals['LQ'])
273 |                         else:
274 |                             normal_img = Metrics.tensor2img2(visuals['HQ'])
275 |                             gt_img = Metrics.tensor2img2(visuals['GT'])
276 |                             ll_img = Metrics.tensor2img2(visuals['LQ'])
277 | 
278 |                         img_mode = 'single'
279 |                         '''
280 |                         if img_mode == 'single':
281 |                             util.save_img(
282 |                                 gt_img, '{}/{}_gt.png'.format(result_path_gt, idx))
283 |                             util.save_img(
284 |                                 ll_img, '{}/{}_in.png'.format(result_path_input, idx))
285 |                             # util.save_img(
286 |                             #     normal_img, '{}/{}_normal.png'.format(result_path_out, idx))
287 |                         else:
288 |                             util.save_img(
289 |                                 gt_img, '{}/{}_{}_gt.png'.format(result_path, current_step, idx))
290 |                             util.save_img(
291 |                                 normal_img, '{}/{}_{}_normal_process.png'.format(result_path, current_step, idx))
292 |                             util.save_img(
293 |                                 Metrics.tensor2img(visuals['HQ'][-1]), '{}/{}_{}_normal.png'.format(result_path, current_step, idx))
294 |                             normal_img = Metrics.tensor2img(visuals['HQ'][-1])
295 |                         '''
296 |     
297 | 
298 |                         # Following LLFlow and KinD,
299 |                         # rescale the output so that its mean gray level matches the ground truth;
300 |                         # see line 73 of https://github.com/zhangyhuaee/KinD/blob/master/evaluate_LOLdataset.py.
301 |                         if opt_dataset['dataset'] == 'LOLv1' or opt_dataset['dataset'] == 'LOLv2':
302 |                             gt_img = gt_img / 255.
303 |                             normal_img = normal_img / 255.
304 | 
305 |                             mean_gray_out = cv2.cvtColor(normal_img.astype(np.float32), cv2.COLOR_BGR2GRAY).mean()
306 |                             mean_gray_gt = cv2.cvtColor(gt_img.astype(np.float32), cv2.COLOR_BGR2GRAY).mean()
307 |                             normal_img_adjust = np.clip(normal_img * (mean_gray_gt / mean_gray_out), 0, 1)
308 | 
309 |                             normal_img = (normal_img_adjust * 255).astype(np.uint8)
310 |                             gt_img = (gt_img * 255).astype(np.uint8)
311 | 
312 |                         psnr = util.calculate_psnr(normal_img, gt_img)
313 |                         ssim = util.calculate_ssim(normal_img, gt_img)
314 |                         
315 |                         normal_img_tensor = torch.tensor(normal_img.astype(np.float32))
316 |                         gt_img_tensor = torch.tensor(gt_img.astype(np.float32))
317 |                         normal_img_tensor = normal_img_tensor.permute(2, 0, 1).cuda()
318 |                         gt_img_tensor = gt_img_tensor.permute(2, 0, 1).cuda()
319 |                         lpips_scores = lpips_func(normal_img_tensor, gt_img_tensor).item()
320 |                         
321 |                         util.save_img(normal_img, '{}/{}_normal.png'.format(result_path_out, idx))
322 | 
323 |                         logger.info('cPSNR: {:.4e} cSSIM: {:.4e} cLPIPS: {:.4e}'.format(psnr, ssim, lpips_scores))
324 | 
325 |                         avg_ssim += ssim
326 |                         avg_psnr += psnr
327 |                         avg_lpips += lpips_scores
328 |                         # break
329 | 
330 |                     avg_psnr = avg_psnr / idx
331 |                     avg_ssim = avg_ssim / idx
332 |                     avg_lpips = avg_lpips / idx
333 | 
334 |                     if avg_psnr > best_psnr:
335 |                         best_psnr = avg_psnr
336 |                         best_ssim = avg_ssim
337 |                         best_lpips = avg_lpips
338 |                         if current_step % opt['train']['save_checkpoint_freq'] == 0 and rank <= 0:
339 |                             logger.info('Saving models and training states.')
340 |                             gen_path = diffusion.save_network(PD_steps[i+1], current_epoch, current_step, best_psnr, best_ssim, best_lpips)
341 |                             if args.progressive:
342 |                                 resume_state = gen_path
343 |                     # logger.info('# Validation Avg scores at timesteps {:3d} # PSNR: {:.4e} SSIM: {:.4e} LPIPS: {:.4e}'.format(PD_steps[i+1], avg_psnr, avg_ssim, avg_lpips))
344 |                     logger_val = logging.getLogger('val')  
345 |                     logger_val.info('# Avg scores # timesteps: {:d} epoch: {:d} iter: {:d} psnr: {:.4e} SSIM: {:.4e} LPIPS: {:.4e}'.format(PD_steps[i+1],
346 |                         current_epoch, current_step, avg_psnr, avg_ssim, avg_lpips))
347 |         logging.getLogger('val').info('# Best scores # timesteps: {:d} psnr: {:.4e} SSIM: {:.4e} LPIPS: {:.4e}'.format(PD_steps[i+1], best_psnr, best_ssim, best_lpips))
348 |         if opt["model"]['beta_schedule']["val"]["n_timestep"] == 2:
349 |             break
350 |                    
351 | 
352 |  
353 | if __name__ == '__main__':
354 |     
355 |     main()
356 | 
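The outer loop of `train.py` walks through a sequence of distillation stages, halving the number of sampling steps at each stage and stopping once the validation schedule reaches 2 steps. The sketch below reproduces only that bookkeeping, using the `PD_steps` and `time_scale` values from the script; the function `distillation_stages` and the labels "teacher" and "student" for the train and val `n_timestep` values are illustrative assumptions, not names used by the repository.

```python
PD_steps = [16, 8, 4, 2, 1]
time_scale = [32 * s for s in [1, 2, 4, 8, 16]]


def distillation_stages(pd_steps, scales, stop_at=2):
    """List (teacher_steps, student_steps, time_scale) for each training stage."""
    stages = []
    for i in range(len(pd_steps) - 1):
        teacher, student = pd_steps[i], pd_steps[i + 1]
        stages.append((teacher, student, scales[i]))
        if student == stop_at:
            # Mirrors the `break` in train.py once the val schedule has 2 steps.
            break
    return stages


stages = distillation_stages(PD_steps, time_scale)
for teacher, student, scale in stages:
    print('Distillation from {:d} to {:d} (time_scale={:d})'.format(teacher, student, scale))
```

With the values above the final entry of `PD_steps` (1 step) is never trained, which matches the early `break` in the script.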
--------------------------------------------------------------------------------
/train_lol1.sh:
--------------------------------------------------------------------------------
1 | 
2 | python train.py --config ./config/lolv1_train.json --dataset ./config/lolv1.yml --w_str 0.0 --w_snr 0.8 --w_gt 1.0 --w_lpips 0.6 --ablation 
3 | 
--------------------------------------------------------------------------------
/train_lol2_real.sh:
--------------------------------------------------------------------------------
1 | 
2 | python train.py --config ./config/lolv2_real_train.json --dataset ./config/lolv2_real.yml --w_str 0.0 --w_snr 0.4 --w_gt 0.0 --w_lpips 0.6 --ablation 
--------------------------------------------------------------------------------
/train_lol2_syn.sh:
--------------------------------------------------------------------------------
1 | 
2 | 
3 | python train.py --config ./config/lolv2_syn_train.json --dataset ./config/lolv2_syn.yml --w_str 0.0 --w_snr 0.4 --w_gt 0.0 --w_lpips 0.6 --ablation  &
4 | 
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/utils/__init__.py
--------------------------------------------------------------------------------
/utils/ema.py:
--------------------------------------------------------------------------------
 1 | class EMA():
 2 |     def __init__(self, model, decay):
 3 |         self.model = model
 4 |         self.decay = decay
 5 |         self.shadow = {}
 6 |         self.backup = {}
 7 | 
 8 |     def register(self):
 9 |         for name, param in self.model.named_parameters():
10 |             if param.requires_grad:
11 |                 self.shadow[name] = param.data.clone()
12 | 
13 |     def update(self):
14 |         for name, param in self.model.named_parameters():
15 |             if param.requires_grad:
16 |                 assert name in self.shadow
17 |                 new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name]
18 |                 self.shadow[name] = new_average.clone()
19 | 
20 |     def apply_shadow(self):
21 |         for name, param in self.model.named_parameters():
22 |             if param.requires_grad:
23 |                 assert name in self.shadow
24 |                 self.backup[name] = param.data
25 |                 param.data = self.shadow[name]
26 | 
27 |     def restore(self):
28 |         for name, param in self.model.named_parameters():
29 |             if param.requires_grad:
30 |                 assert name in self.backup
31 |                 param.data = self.backup[name]
32 |         self.backup = {}
--------------------------------------------------------------------------------
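The bookkeeping in `utils/ema.py` can be followed without torch. Below is a minimal sketch that mirrors the `register` / `update` / `apply_shadow` / `restore` cycle, with plain floats standing in for parameter tensors. The class name `ScalarEMA` and the `params` dict are hypothetical stand-ins introduced for illustration only; they are not part of the repository.

```python
class ScalarEMA:
    """Hypothetical float-based stand-in for the EMA class in utils/ema.py."""

    def __init__(self, params, decay):
        self.params = params   # live "model" parameters: name -> float
        self.decay = decay
        self.shadow = {}       # exponential moving average of each parameter
        self.backup = {}       # raw values saved while the shadow is applied

    def register(self):
        # Start the average from the current parameter values.
        self.shadow = dict(self.params)

    def update(self):
        # shadow <- (1 - decay) * param + decay * shadow
        for name, value in self.params.items():
            self.shadow[name] = (1.0 - self.decay) * value + self.decay * self.shadow[name]

    def apply_shadow(self):
        # Swap the averaged values in, keeping the raw ones for restore().
        self.backup = dict(self.params)
        self.params.update(self.shadow)

    def restore(self):
        # Put the raw training values back.
        self.params.update(self.backup)
        self.backup = {}


params = {"w": 0.0}
ema = ScalarEMA(params, decay=0.9)
ema.register()           # shadow["w"] starts at 0.0
params["w"] = 1.0        # pretend an optimizer step changed the weight
ema.update()             # shadow["w"] = 0.1 * 1.0 + 0.9 * 0.0
ema.apply_shadow()       # evaluate with the averaged weight
averaged = params["w"]
ema.restore()            # continue training with the raw weight
restored = params["w"]
```

With the real class the same call order would presumably apply: `register()` once after building the model, `update()` after each optimizer step, and `apply_shadow()` / `restore()` around evaluation.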
/utils/niqe.py:
--------------------------------------------------------------------------------
  1 | import math
  2 | from os.path import dirname, join
  3 | 
  4 | import cv2
  5 | import numpy as np
  6 | import scipy
  7 | import scipy.io
  8 | import scipy.linalg
  9 | import scipy.ndimage
 10 | import scipy.special
 11 | from PIL import Image
 12 | 
 13 | gamma_range = np.arange(0.2, 10, 0.001)
 14 | a = scipy.special.gamma(2.0/gamma_range)
 15 | a *= a
 16 | b = scipy.special.gamma(1.0/gamma_range)
 17 | c = scipy.special.gamma(3.0/gamma_range)
 18 | prec_gammas = a/(b*c)
 19 | 
 20 | 
 21 | def aggd_features(imdata):
 22 |     # flatten imdata
 23 |     imdata.shape = (len(imdata.flat),)
 24 |     imdata2 = imdata*imdata
 25 |     left_data = imdata2[imdata < 0]
 26 |     right_data = imdata2[imdata >= 0]
 27 |     left_mean_sqrt = 0
 28 |     right_mean_sqrt = 0
 29 |     if len(left_data) > 0:
 30 |         left_mean_sqrt = np.sqrt(np.average(left_data))
 31 |     if len(right_data) > 0:
 32 |         right_mean_sqrt = np.sqrt(np.average(right_data))
 33 | 
 34 |     if right_mean_sqrt != 0:
 35 |         gamma_hat = left_mean_sqrt/right_mean_sqrt
 36 |     else:
 37 |         gamma_hat = np.inf
 38 |     # solve r-hat norm
 39 | 
 40 |     imdata2_mean = np.mean(imdata2)
 41 |     if imdata2_mean != 0:
 42 |         r_hat = (np.average(np.abs(imdata))**2) / (np.average(imdata2))
 43 |     else:
 44 |         r_hat = np.inf
 45 |     rhat_norm = r_hat * (((math.pow(gamma_hat, 3) + 1) *
 46 |                           (gamma_hat + 1)) / math.pow(math.pow(gamma_hat, 2) + 1, 2))
 47 | 
 48 |     # solve alpha by guessing values that minimize ro
 49 |     pos = np.argmin((prec_gammas - rhat_norm)**2)
 50 |     alpha = gamma_range[pos]
 51 | 
 52 |     gam1 = scipy.special.gamma(1.0/alpha)
 53 |     gam2 = scipy.special.gamma(2.0/alpha)
 54 |     gam3 = scipy.special.gamma(3.0/alpha)
 55 | 
 56 |     aggdratio = np.sqrt(gam1) / np.sqrt(gam3)
 57 |     bl = aggdratio * left_mean_sqrt
 58 |     br = aggdratio * right_mean_sqrt
 59 | 
 60 |     # mean parameter
 61 |     N = (br - bl)*(gam2 / gam1)  # *aggdratio
 62 |     return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt)
 63 | 
 64 | 
 65 | def ggd_features(imdata):
 66 |     nr_gam = 1/prec_gammas
 67 |     sigma_sq = np.var(imdata)
 68 |     E = np.mean(np.abs(imdata))
 69 |     rho = sigma_sq/E**2
 70 |     pos = np.argmin(np.abs(nr_gam - rho))
 71 |     return gamma_range[pos], sigma_sq
 72 | 
 73 | 
 74 | def paired_product(new_im):
 75 |     shift1 = np.roll(new_im.copy(), 1, axis=1)
 76 |     shift2 = np.roll(new_im.copy(), 1, axis=0)
 77 |     shift3 = np.roll(np.roll(new_im.copy(), 1, axis=0), 1, axis=1)
 78 |     shift4 = np.roll(np.roll(new_im.copy(), 1, axis=0), -1, axis=1)
 79 | 
 80 |     H_img = shift1 * new_im
 81 |     V_img = shift2 * new_im
 82 |     D1_img = shift3 * new_im
 83 |     D2_img = shift4 * new_im
 84 | 
 85 |     return (H_img, V_img, D1_img, D2_img)
 86 | 
 87 | 
 88 | def gen_gauss_window(lw, sigma):
 89 |     sd = np.float32(sigma)
 90 |     lw = int(lw)
 91 |     weights = [0.0] * (2 * lw + 1)
 92 |     weights[lw] = 1.0
 93 |     sum = 1.0
 94 |     sd *= sd
 95 |     for ii in range(1, lw + 1):
 96 |         tmp = np.exp(-0.5 * np.float32(ii * ii) / sd)
 97 |         weights[lw + ii] = tmp
 98 |         weights[lw - ii] = tmp
 99 |         sum += 2.0 * tmp
100 |     for ii in range(2 * lw + 1):
101 |         weights[ii] /= sum
102 |     return weights
103 | 
104 | 
105 | def compute_image_mscn_transform(image, C=1, avg_window=None, extend_mode='constant'):
106 |     if avg_window is None:
107 |         avg_window = gen_gauss_window(3, 7.0/6.0)
108 |     assert len(np.shape(image)) == 2
109 |     h, w = np.shape(image)
110 |     mu_image = np.zeros((h, w), dtype=np.float32)
111 |     var_image = np.zeros((h, w), dtype=np.float32)
112 |     image = np.array(image).astype('float32')
113 |     scipy.ndimage.correlate1d(image, avg_window, 0, mu_image, mode=extend_mode)
114 |     scipy.ndimage.correlate1d(mu_image, avg_window, 1,
115 |                               mu_image, mode=extend_mode)
116 |     scipy.ndimage.correlate1d(image**2, avg_window, 0,
117 |                               var_image, mode=extend_mode)
118 |     scipy.ndimage.correlate1d(var_image, avg_window,
119 |                               1, var_image, mode=extend_mode)
120 |     var_image = np.sqrt(np.abs(var_image - mu_image**2))
121 |     return (image - mu_image)/(var_image + C), var_image, mu_image
122 | 
123 | 
124 | def _niqe_extract_subband_feats(mscncoefs):
125 |     # alpha_m,  = extract_ggd_features(mscncoefs)
126 |     alpha_m, N, bl, br, lsq, rsq = aggd_features(mscncoefs.copy())
127 |     pps1, pps2, pps3, pps4 = paired_product(mscncoefs)
128 |     alpha1, N1, bl1, br1, lsq1, rsq1 = aggd_features(pps1)
129 |     alpha2, N2, bl2, br2, lsq2, rsq2 = aggd_features(pps2)
130 |     alpha3, N3, bl3, br3, lsq3, rsq3 = aggd_features(pps3)
131 |     alpha4, N4, bl4, br4, lsq4, rsq4 = aggd_features(pps4)
132 |     return np.array([alpha_m, (bl+br)/2.0,
133 |                      alpha1, N1, bl1, br1,  # (H)
134 |                      alpha2, N2, bl2, br2,  # (V)
135 |                      alpha3, N3, bl3, br3,  # (D1)
136 |                      alpha4, N4, bl4, br4,  # (D2)
137 |                      ])
138 | 
139 | 
140 | def get_patches_train_features(img, patch_size, stride=8):
141 |     return _get_patches_generic(img, patch_size, 1, stride)
142 | 
143 | 
144 | def get_patches_test_features(img, patch_size, stride=8):
145 |     return _get_patches_generic(img, patch_size, 0, stride)
146 | 
147 | 
148 | def extract_on_patches(img, patch_size):
149 |     h, w = img.shape
150 |     patch_size = np.int32(patch_size)
151 |     patches = []
152 |     for j in range(0, h-patch_size+1, patch_size):
153 |         for i in range(0, w-patch_size+1, patch_size):
154 |             patch = img[j:j+patch_size, i:i+patch_size]
155 |             patches.append(patch)
156 | 
157 |     patches = np.array(patches)
158 | 
159 |     patch_features = []
160 |     for p in patches:
161 |         patch_features.append(_niqe_extract_subband_feats(p))
162 |     patch_features = np.array(patch_features)
163 | 
164 |     return patch_features
165 | 
166 | 
167 | def _get_patches_generic(img, patch_size, is_train, stride):
168 |     h, w = np.shape(img)
169 |     if h < patch_size or w < patch_size:
170 |         raise ValueError(
171 |             "Input image is too small for patch size {}".format(patch_size))
172 | 
173 |     # ensure that the patch divides evenly into img
174 |     hoffset = (h % patch_size)
175 |     woffset = (w % patch_size)
176 | 
177 |     if hoffset > 0:
178 |         img = img[:-hoffset, :]
179 |     if woffset > 0:
180 |         img = img[:, :-woffset]
181 | 
182 |     img = img.astype(np.float32)
183 |     # img2 = scipy.misc.imresize(img, 0.5, interp='bicubic', mode='F')
184 |     img2 = cv2.resize(img, (0, 0), fx=0.5, fy=0.5)
185 | 
186 |     mscn1, var, mu = compute_image_mscn_transform(img)
187 |     mscn1 = mscn1.astype(np.float32)
188 | 
189 |     mscn2, _, _ = compute_image_mscn_transform(img2)
190 |     mscn2 = mscn2.astype(np.float32)
191 | 
192 |     feats_lvl1 = extract_on_patches(mscn1, patch_size)
193 |     feats_lvl2 = extract_on_patches(mscn2, patch_size // 2)
194 | 
195 |     feats = np.hstack((feats_lvl1, feats_lvl2))  # feats_lvl3))
196 | 
197 |     return feats
198 | 
199 | 
200 | def niqe(inputImgData):
201 | 
202 |     patch_size = 8
203 |     module_path = dirname(__file__)
204 | 
205 |     # TODO: memoize
206 |     params = scipy.io.loadmat(
207 |         join(module_path, 'niqe_image_params.mat'))
208 |     pop_mu = np.ravel(params["pop_mu"])
209 |     pop_cov = params["pop_cov"]
210 | 
211 |     if inputImgData.ndim == 3:
212 |         inputImgData = cv2.cvtColor(inputImgData, cv2.COLOR_BGR2GRAY)
213 |     M, N = inputImgData.shape
214 | 
215 |     # assert C == 1, "niqe called with videos containing %d channels. Please supply only the luminance channel" % (C,)
216 |     assert M > (patch_size*2+1), "niqe called with small frame size, requires height > %d for the current patch size" % (patch_size*2+1)
217 |     assert N > (patch_size*2+1), "niqe called with small frame size, requires width > %d for the current patch size" % (patch_size*2+1)
218 | 
219 |     feats = get_patches_test_features(inputImgData, patch_size)
220 |     sample_mu = np.mean(feats, axis=0)
221 |     sample_cov = np.cov(feats.T)
222 | 
223 |     X = sample_mu - pop_mu
224 |     covmat = ((pop_cov+sample_cov)/2.0)
225 |     pinvmat = scipy.linalg.pinv(covmat)
226 |     niqe_score = np.sqrt(np.dot(np.dot(X, pinvmat), X))
227 | 
228 |     return niqe_score
229 | 
--------------------------------------------------------------------------------
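When no window is passed, `compute_image_mscn_transform` calls `gen_gauss_window(3, 7.0/6.0)`, which yields a 7-tap Gaussian normalized to sum to 1. The sketch below reproduces that construction in pure Python so the weights can be inspected without NumPy. The function name `gauss_window` is a hypothetical helper for illustration, not something the repository defines.

```python
import math


def gauss_window(half_width, sigma):
    """Return 2 * half_width + 1 Gaussian weights that sum to 1."""
    offsets = range(-half_width, half_width + 1)
    # Unnormalized Gaussian, centered on the middle tap.
    raw = [math.exp(-0.5 * (k * k) / (sigma * sigma)) for k in offsets]
    total = sum(raw)
    return [w / total for w in raw]


window = gauss_window(3, 7.0 / 6.0)
```

The window is separable, which is why the module applies it once along each axis with `scipy.ndimage.correlate1d` instead of building a 2D kernel.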
/utils/niqe_image_params.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/utils/niqe_image_params.mat
--------------------------------------------------------------------------------
/utils/util.py:
--------------------------------------------------------------------------------
  1 | import glob
  2 | import os
  3 | import math
  4 | from datetime import datetime
  5 | import random
  6 | import logging
  7 | from collections import OrderedDict
  8 | 
  9 | import natsort
 10 | import numpy as np
 11 | import cv2
 12 | import torch
 13 | from torchvision.utils import make_grid
 14 | from shutil import get_terminal_size
 15 | import torch
 16 | import torch.nn.functional as F
 17 | from torch.autograd import Variable
 18 | import numpy as np
 19 | from math import exp
 20 | 
 21 | from skimage.metrics import structural_similarity as SSIM
 22 | 
 23 | 
 24 | def gaussian(window_size, sigma):
 25 |     gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
 26 |     return gauss / gauss.sum()
 27 | 
 28 | 
 29 | def create_window(window_size, channel):
 30 |     _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
 31 |     _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
 32 |     window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous())
 33 |     return window
 34 | 
 35 | 
 36 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
 37 |     mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
 38 |     mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
 39 | 
 40 |     mu1_sq = mu1.pow(2)
 41 |     mu2_sq = mu2.pow(2)
 42 |     mu1_mu2 = mu1 * mu2
 43 | 
 44 |     sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
 45 |     sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
 46 |     sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
 47 | 
 48 |     C1 = 0.01 ** 2
 49 |     C2 = 0.03 ** 2
 50 | 
 51 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
 52 | 
 53 |     if size_average:
 54 |         return ssim_map.mean()
 55 |     else:
 56 |         return ssim_map.mean(1).mean(1).mean(1)
 57 | 
 58 | 
 59 | # class SSIM(torch.nn.Module):
 60 | #     def __init__(self, window_size=11, size_average=True):
 61 | #         super(SSIM, self).__init__()
 62 | #         self.window_size = window_size
 63 | #         self.size_average = size_average
 64 | #         self.channel = 1
 65 | #         self.window = create_window(window_size, self.channel)
 66 | 
 67 | #     def forward(self, img1, img2):
 68 | #         (_, channel, _, _) = img1.size()
 69 | 
 70 | #         if channel == self.channel and self.window.data.type() == img1.data.type():
 71 | #             window = self.window
 72 | #         else:
 73 | #             window = create_window(self.window_size, channel)
 74 | 
 75 | #             if img1.is_cuda:
 76 | #                 window = window.cuda(img1.get_device())
 77 | #             window = window.type_as(img1)
 78 | 
 79 | #             self.window = window
 80 | #             self.channel = channel
 81 | 
 82 | #         return _ssim(img1, img2, window, self.window_size, channel, self.size_average)
 83 | 
 84 | 
 85 | def ssim(img1, img2, window_size=11, size_average=True):
 86 |     (_, channel, _, _) = img1.size()
 87 |     window = create_window(window_size, channel)
 88 | 
 89 |     if img1.is_cuda:
 90 |         window = window.cuda(img1.get_device())
 91 |     window = window.type_as(img1)
 92 | 
 93 |     return _ssim(img1.mean(dim=0, keepdims=True), img2.mean(dim=0, keepdims=True), window, window_size, channel, size_average)
 94 | 
 95 | 
 96 | import yaml
 97 | 
 98 | try:
 99 |     from yaml import CLoader as Loader, CDumper as Dumper
100 | except ImportError:
101 |     from yaml import Loader, Dumper
102 | 
103 | 
104 | def OrderedYaml():
105 |     '''yaml orderedDict support'''
106 |     _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
107 | 
108 |     def dict_representer(dumper, data):
109 |         return dumper.represent_dict(data.items())
110 | 
111 |     def dict_constructor(loader, node):
112 |         return OrderedDict(loader.construct_pairs(node))
113 | 
114 |     Dumper.add_representer(OrderedDict, dict_representer)
115 |     Loader.add_constructor(_mapping_tag, dict_constructor)
116 |     return Loader, Dumper
117 | 
118 | 
119 | ####################
120 | # miscellaneous
121 | ####################
122 | 
123 | 
124 | def get_timestamp():
125 |     return datetime.now().strftime('%y%m%d-%H%M%S')
126 | 
127 | 
128 | def mkdir(path):
129 |     # exist_ok avoids a race between the existence check and the creation
130 |     os.makedirs(path, exist_ok=True)
131 | 
132 | 
133 | def mkdirs(paths):
134 |     if isinstance(paths, str):
135 |         mkdir(paths)
136 |     else:
137 |         for path in paths:
138 |             mkdir(path)
139 | 
140 | 
141 | def mkdir_and_rename(path):
142 |     if os.path.exists(path):
143 |         new_name = path + '_archived_' + get_timestamp()
144 |         print('Path already exists. Rename it to [{:s}]'.format(new_name))
145 |         logger = logging.getLogger('base')
146 |         logger.info('Path already exists. Rename it to [{:s}]'.format(new_name))
147 |         os.rename(path, new_name)
148 |     os.makedirs(path)
149 | 
150 | 
151 | def set_random_seed(seed):
152 |     random.seed(seed)
153 |     np.random.seed(seed)
154 |     torch.manual_seed(seed)
155 |     torch.cuda.manual_seed_all(seed)
156 | 
157 | 
158 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
159 |     '''set up logger'''
160 |     lg = logging.getLogger(logger_name)
161 |     formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
162 |                                   datefmt='%y-%m-%d %H:%M:%S')
163 |     lg.setLevel(level)
164 |     if tofile:
165 |         log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp()))
166 |         fh = logging.FileHandler(log_file, mode='w')
167 |         fh.setFormatter(formatter)
168 |         lg.addHandler(fh)
169 |     if screen:
170 |         sh = logging.StreamHandler()
171 |         sh.setFormatter(formatter)
172 |         lg.addHandler(sh)
173 | 
174 | 
175 | ####################
176 | # image convert
177 | ####################
178 | 
179 | 
180 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
181 |     '''
182 |     Clamps a torch Tensor to min_max and rescales it to [0, 1]
183 |     Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
184 |     Output: the rescaled torch Tensor (the NumPy image built below is currently not returned)
185 |     '''
186 |     if hasattr(tensor, 'detach'):
187 |         tensor = tensor.detach()
188 |     tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
189 |     tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])  # to range [0,1]
190 |     n_dim = tensor.dim()
191 |     if n_dim == 4:
192 |         n_img = len(tensor)
193 |         img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
194 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
195 |     elif n_dim == 3:
196 |         img_np = tensor.numpy()
197 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
198 |     elif n_dim == 2:
199 |         img_np = tensor.numpy()
200 |     else:
201 |         raise TypeError(
202 |             'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
203 |     if out_type == np.uint8:
204 |         img_np = np.clip((img_np * 255.0).round(), 0, 255)
205 |         # Important. Unlike MATLAB, numpy.uint8() WILL NOT round by default.
206 |     return tensor
207 |     # return img_np.astype(out_type)
208 | 
209 | 
210 | def save_img(img, img_path, mode='RGB'):
211 |     cv2.imwrite(img_path, img)
212 | 
213 | 
214 | ####################
215 | # metric
216 | ####################
217 | 
218 | 
219 | def calculate_psnr(img1, img2):
220 |     # img1 and img2 have range [0, 255]
221 |     img1 = img1.astype(np.float64)
222 |     img2 = img2.astype(np.float64)
223 |     mse = np.mean((img1 - img2) ** 2)
224 |     if mse == 0:
225 |         return float('inf')
226 |     return 20 * math.log10(255.0 / math.sqrt(mse))
227 | 
228 | 
229 | def calculate_ssim(imgA, imgB, gray_scale=False):
230 | 
231 |     if gray_scale:
232 |         score, diff = SSIM(cv2.cvtColor(imgA, cv2.COLOR_RGB2GRAY), cv2.cvtColor(imgB, cv2.COLOR_RGB2GRAY), full=True)
233 |     else:
234 |         score, diff = SSIM(imgA, imgB, full=True, multichannel=True)
235 |     return score
236 | 
237 | 
238 | def calculate_lpips(imgA, imgB, gray_scale=False):
239 |     # NOTE: placeholder. This computes SSIM (same as calculate_ssim), not LPIPS.
240 |     if gray_scale:
241 |         score, diff = SSIM(cv2.cvtColor(imgA, cv2.COLOR_RGB2GRAY), cv2.cvtColor(imgB, cv2.COLOR_RGB2GRAY), full=True)
242 |     else:
243 |         score, diff = SSIM(imgA, imgB, full=True, multichannel=True)
244 |     return score
245 | 
246 | def get_resume_paths(opt):
247 |     resume_state_path = None
248 |     resume_model_path = None
249 |     ts = opt_get(opt, ['path', 'training_state'])
250 |     if opt.get('path', {}).get('resume_state', None) == "auto" and ts is not None:
251 |         wildcard = os.path.join(ts, "*")
252 |         paths = natsort.natsorted(glob.glob(wildcard))
253 |         if len(paths) > 0:
254 |             resume_state_path = paths[-1]
255 |             resume_model_path = resume_state_path.replace('training_state', 'models').replace('.state', '_G.pth')
256 |     else:
257 |         resume_state_path = opt.get('path', {}).get('resume_state')
258 |     return resume_state_path, resume_model_path
259 | 
260 | 
261 | def opt_get(opt, keys, default=None):
262 |     if opt is None:
263 |         return default
264 |     ret = opt
265 |     for k in keys:
266 |         ret = ret.get(k, None)
267 |         if ret is None:
268 |             return default
269 |     return ret
270 | 
--------------------------------------------------------------------------------
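The formula in `calculate_psnr` is `20 * log10(255 / sqrt(mse))` for images in the `[0, 255]` range. A minimal pure-Python sketch of the same arithmetic follows, operating on flat lists of pixel values instead of NumPy arrays. The name `psnr_from_pixels` and the `peak` parameter are assumptions added for illustration.

```python
import math


def psnr_from_pixels(pixels_a, pixels_b, peak=255.0):
    """PSNR in dB between two equal-length sequences of pixel values."""
    if len(pixels_a) != len(pixels_b):
        raise ValueError("inputs must have the same number of pixels")
    # Mean squared error over all pixels.
    mse = sum((a - b) ** 2 for a, b in zip(pixels_a, pixels_b)) / len(pixels_a)
    if mse == 0:
        return float("inf")  # identical images
    return 20 * math.log10(peak / math.sqrt(mse))


identical = psnr_from_pixels([10, 20, 30], [10, 20, 30])
off_by_one = psnr_from_pixels([10, 20, 30], [11, 21, 31])  # mse == 1
```

Every pixel differing by exactly one level gives an MSE of 1, so the result reduces to `20 * log10(255)`, roughly 48.13 dB, which is a handy sanity check for the NumPy version.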