├── README.md ├── config ├── lolv1.yml ├── lolv1_test.json ├── lolv1_train.json ├── lolv2_real.yml ├── lolv2_real_test.json ├── lolv2_real_train.json ├── lolv2_syn.yml ├── lolv2_syn_test.json ├── lolv2_syn_train.json └── test_unpaired.json ├── core ├── __pycache__ │ ├── logger.cpython-38.pyc │ └── metrics.cpython-38.pyc ├── logger.py └── metrics.py ├── data ├── LoL_dataset.py ├── __init__.py ├── __pycache__ │ ├── LoL_dataset.cpython-38.pyc │ └── __init__.cpython-38.pyc ├── single_image_dataset.py └── util.py ├── dataset └── LOLv1 │ └── readme.txt ├── model ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── base_model.cpython-38.pyc │ ├── model.cpython-38.pyc │ └── networks.cpython-38.pyc ├── base_model.py ├── ddpm_modules │ ├── __pycache__ │ │ ├── diffusion.cpython-38.pyc │ │ └── unet.cpython-38.pyc │ ├── diffusion.py │ └── unet.py ├── model.py └── networks.py ├── options ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ └── options.cpython-38.pyc └── options.py ├── requirements.txt ├── test.py ├── test.sh ├── test_unpaired.py ├── train.py ├── train_lol1.sh ├── train_lol2_real.sh ├── train_lol2_syn.sh └── utils ├── __init__.py ├── __pycache__ ├── __init__.cpython-38.pyc ├── ema.cpython-38.pyc └── util.cpython-38.pyc ├── ema.py ├── niqe.py ├── niqe_image_params.mat └── util.py /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 |
4 | 5 | # **[CVPR2025]** Efficient Diffusion as Low Light Enhancer 6 | 7 |

13 | 14 | ## :fire: News
15 |
16 | - [2025/03/04] We have released the training code and inference code! 🚀🚀
17 | - [2025/02/27] ReDDiT has been accepted to CVPR 2025! 🤗🤗
18 |
19 | ## :memo: TODO
20 |
21 | - [x] Training code
22 | - [x] Inference code
23 | - [x] CVPR Camera-ready Version
24 | - [x] Project page
25 | - [ ] Journal Version & Teacher Model
26 |
27 | ## :hammer: Get Started
28 |
29 | ### :mag: Dependencies and Installation
30 |
31 | - Python 3.8
32 | - PyTorch 1.11
33 |
34 | 1. Create Conda Environment
35 |
36 | ```
37 | conda create --name ReDDiT python=3.8
38 | conda activate ReDDiT
39 | ```
40 |
41 | 2. Install PyTorch
42 |
43 | ```
44 | pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 torchaudio==0.11.0 --extra-index-url https://download.pytorch.org/whl/cu113
45 | ```
46 |
47 | 3. Install Dependencies
48 |
49 | ```
50 | cd ReDDiT
51 | pip install -r requirements.txt
52 | ```
53 |
54 | ### :page_with_curl: Data Preparation
55 |
56 | You can refer to the following links to download the datasets.
57 |
58 | - [LOLv1](https://daooshee.github.io/BMVC2018website/)
59 | - [LOLv2](https://github.com/flyywh/CVPR-2020-Semi-Low-Light)
60 |
61 | Then, put them in the following folder:
62 |
63 |
dataset (click to expand)
64 |
65 | ```
66 | ├── dataset
67 | │   ├── LOLv1
68 | │   │   ├── our485
69 | │   │   │   ├── low
70 | │   │   │   ├── high
71 | │   │   ├── eval15
72 | │   │   │   ├── low
73 | │   │   │   ├── high
75 | │   ├── LOLv2
76 | │   │   ├── Real_captured
77 | │   │   │   ├── Train
78 | │   │   │   ├── Test
79 | │   │   ├── Synthetic
80 | │   │   │   ├── Train
81 | │   │   │   ├── Test
82 | ```
83 |
84 | Note: the LOLv2 configs (`config/lolv2_real.yml` and `config/lolv2_syn.yml`) set `root: ./dataset/LOL-v2`, so name the LOLv2 folder to match the `root` value of the config you run.
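Before training, you can sanity-check the layout against the paths the configs and dataloaders expect. The snippet below is an illustrative helper, not part of the repo; the expected sub-folders are taken from `config/lolv1.yml`, the LOLv2 configs (`root: ./dataset/LOL-v2`), and `data/LoL_dataset.py`, which reads `Low`/`Normal` folders inside `Train`/`Test`.

```python
import os

# Hypothetical layout checker -- adjust the roots if you store the data elsewhere.
EXPECTED = [
    "dataset/LOLv1/our485/low",
    "dataset/LOLv1/our485/high",
    "dataset/LOLv1/eval15/low",
    "dataset/LOLv1/eval15/high",
    "dataset/LOL-v2/Real_captured/Train/Low",
    "dataset/LOL-v2/Real_captured/Train/Normal",
    "dataset/LOL-v2/Real_captured/Test/Low",
    "dataset/LOL-v2/Real_captured/Test/Normal",
    "dataset/LOL-v2/Synthetic/Train/Low",
    "dataset/LOL-v2/Synthetic/Train/Normal",
    "dataset/LOL-v2/Synthetic/Test/Low",
    "dataset/LOL-v2/Synthetic/Test/Normal",
]

for path in EXPECTED:
    print(f"{path}: {'ok' if os.path.isdir(path) else 'MISSING'}")
```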
85 |
86 | ### :blue_book: Testing
87 |
88 | Note: Following LLFlow and KinD, we also adjust the brightness of the output image produced by the network, based on the average value of the ground truth (GT). ``It should be noted that this adjustment process does not influence the texture details generated; it is merely a straightforward method to regulate the overall illumination.`` Moreover, it can be easily adjusted according to user preferences in practical applications.
89 |
90 | You can refer to the following links to download the checkpoints from [Google Drive](https://drive.google.com/file/d/13_XM8nFxJc2IfUotC2_lJo9ATt0rcIyg/view?usp=sharing) or [百度网盘 (Baidu Netdisk)](https://pan.baidu.com/s/1J7MP33Ws5kE673F-8zc2RA?pwd=nj8b) and put them in the following folder:
91 |
92 | ```
93 | ├── checkpoints
94 | ├── lolv1_8step_gen.pth
95 | ├── lolv1_4step_gen.pth
96 | ├── lolv1_2step_gen.pth
97 | ......
98 | ```
99 | To test the model, run the ``sh test.sh`` command and modify the `n_timestep` and `time_scale` parameters to match the step count of the checkpoint. Note that the presets below keep `n_timestep × time_scale` fixed at 512 (8×64, 4×128, 2×256); see the step-preset reference at the end of this README. Here's a general outline of the settings:
100 | ```
101 | "val": {
102 | "schedule": "linear",
103 | "n_timestep": 8,
104 | "linear_start": 1e-4,
105 | "linear_end": 2e-2,
106 | "time_scale": 64
107 | }
108 | ```
109 |
110 | ```
111 | "val": {
112 | "schedule": "linear",
113 | "n_timestep": 4,
114 | "linear_start": 1e-4,
115 | "linear_end": 2e-2,
116 | "time_scale": 128
117 | }
118 | ```
119 |
120 | ```
121 | "val": {
122 | "schedule": "linear",
123 | "n_timestep": 2,
124 | "linear_start": 1e-4,
125 | "linear_end": 2e-2,
126 | "time_scale": 256
127 | }
128 | ```
129 | ### :blue_book: Testing on unpaired data
130 |
131 | ```
132 | python test_unpaired.py --config config/test_unpaired.json --input unpaired_image_folder
133 | ```
134 |
135 | You can use any one of these three pre-trained models and employ different numbers of sampling steps to obtain visually pleasing results by modifying these terms in `config/test_unpaired.json`.
136 |
137 |
138 |
139 | ### :rocket: Training
140 |
141 | ```
142 | bash train_lol1.sh        # LOLv1; use train_lol2_real.sh / train_lol2_syn.sh for LOLv2
143 | ```
144 |
145 |
146 | ## :black_nib: Citation
147 |
148 | If you find our repo useful for your research, please consider citing our paper:
149 |
150 | ```bibtex
151 | @InProceedings{Lan_2025_CVPR,
152 | author = {Lan, Guanzhou and Ma, Qianli and Yang, Yuqi and Wang, Zhigang and Wang, Dong and Li, Xuelong and Zhao, Bin},
153 | title = {Efficient Diffusion as Low Light Enhancer},
154 | booktitle = {Proceedings of the Computer Vision and Pattern Recognition Conference (CVPR)},
155 | month = {June},
156 | year = {2025},
157 | pages = {21277-21286}
158 | }
159 | ```
160 |
161 |
162 | ## :heart: Acknowledgement
163 |
164 | Our code is built upon [SR3](https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement). Thanks to the contributors for their great work.
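## :bulb: Step Preset Reference

As referenced in the Testing section: a minimal sketch (a hypothetical helper, not part of this repo) of how the step presets relate. The distilled checkpoints keep `n_timestep * time_scale` at 512, so the matching `time_scale` can be derived from the number of sampling steps.

```python
# The presets in the Testing section are 8x64, 4x128 and 2x256 -- all 512.
TOTAL_STEPS = 512  # assumed constant implied by the shipped presets

def time_scale_for(n_timestep: int) -> int:
    """Return the time_scale that pairs with a given number of sampling steps."""
    assert TOTAL_STEPS % n_timestep == 0, "n_timestep must divide 512"
    return TOTAL_STEPS // n_timestep

for n in (8, 4, 2):
    print(n, time_scale_for(n))  # -> 64, 128, 256
```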
165 | -------------------------------------------------------------------------------- /config/lolv1.yml: -------------------------------------------------------------------------------- 1 | 2 | dataset: LOLv1 3 | 4 | #### datasets 5 | datasets: 6 | train: 7 | dist: False 8 | root: ./dataset/LOLv1 9 | use_shuffle: true 10 | n_workers: 8 11 | batch_size: 16 12 | use_flip: true 13 | use_crop: true 14 | patch_size: 96 15 | 16 | val: 17 | dist: False 18 | root: ./dataset/LOLv1 19 | n_workers: 1 20 | use_crop: true 21 | batch_size: 1 22 | 23 | 24 | 25 | -------------------------------------------------------------------------------- /config/lolv1_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "lolv1_test_8", 3 | "phase": "test", 4 | "distill": false, 5 | "gpu_ids": [ 6 | 0 7 | ], 8 | "path": { 9 | "log": "logs", 10 | "tb_logger": "tb_logger", 11 | "results": "results", 12 | "checkpoint": "checkpoint", 13 | "resume_state": "./checkpoint/lolv1_4step_gen.pth" 14 | }, 15 | "model": { 16 | "which_model_G": "ddpm", 17 | "finetune_norm": false, 18 | "unet": { 19 | "in_channel": 6, 20 | "out_channel": 3, 21 | "inner_channel": 64, 22 | "channel_multiplier": [ 23 | 1, 24 | 1, 25 | 2, 26 | 2, 27 | 4 28 | ], 29 | "attn_res": [ 30 | 16 31 | ], 32 | "res_blocks": 2, 33 | "dropout": 0 34 | }, 35 | "beta_schedule": { 36 | "train": { 37 | "schedule": "linear", 38 | "n_timestep": 4, 39 | "linear_start": 1e-4, 40 | "linear_end": 2e-2, 41 | "time_scale": 128 42 | 43 | 44 | }, 45 | "val": { 46 | "schedule": "linear", 47 | "n_timestep": 4, 48 | "linear_start": 1e-4, 49 | "linear_end": 2e-2, 50 | "time_scale": 128 51 | 52 | 53 | } 54 | }, 55 | "diffusion": { 56 | "image_size": 128, 57 | "channels": 6, 58 | "conditional": true, 59 | "w_gt": 0.1, 60 | "w_snr": 0.5, 61 | "w_str": 0.1, 62 | "w_lpips": 0.2 63 | } 64 | }, 65 | "train": { 66 | "n_iter": 1000000, 67 | "val_freq": 1e4, 68 | "save_checkpoint_freq": 5e4, 69 | "print_freq": 200, 70 | "optimizer": { 71 | "type": "adam", 72 | "lr": 1e-4, 73 | "lr_policy":"linear", 74 | "lr_decay_iters":3000, 75 | "n_lr_iters": 2000 76 | }, 77 | "ema_scheduler": { 78 | "step_start_ema": 5000, 79 | "update_ema_every": 1, 80 | "ema_decay": 0.9999 81 | } 82 | }, 83 | "wandb": { 84 | "project": "llie_ddpm" 85 | } 86 | } -------------------------------------------------------------------------------- /config/lolv1_train.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "lolv1_train", 3 | "phase": "train", 4 | "distill": true, 5 | "gpu_ids": [ 6 | 0 7 | ], 8 | "path": { 9 | "log": "logs", 10 | "tb_logger": "tb_logger", 11 | "results": "results", 12 | "checkpoint": "checkpoint", 13 | "resume_state": "./checkpoint/lolv1_4step_gen.pth" 14 | }, 15 | "model": { 16 | "which_model_G": "ddpm", 17 | "finetune_norm": false, 18 | "unet": { 19 | "in_channel": 6, 20 | "out_channel": 3, 21 | "inner_channel": 64, 22 | "channel_multiplier": [ 23 | 1, 24 | 1, 25 | 2, 26 | 2, 27 | 4 28 | ], 29 | "attn_res": [ 30 | 16 31 | ], 32 | "res_blocks": 2, 33 | "dropout": 0 34 | }, 35 | "beta_schedule": { 36 | "train": { 37 | "schedule": "linear", 38 | "n_timestep": 513, 39 | "linear_start": 1e-4, 40 | "linear_end": 2e-2, 41 | "time_scale": 1, 42 | "reflow": false 43 | }, 44 | "val": { 45 | "schedule": "linear", 46 | "n_timestep": 513, 47 | "linear_start": 1e-4, 48 | "linear_end": 2e-2 49 | 50 | } 51 | }, 52 | "diffusion": { 53 | "image_size": 128, 54 | "channels": 6, 55 | "conditional": true, 56 | 
"w_gt": 0.1, 57 | "w_snr": 0.5, 58 | "w_str": 0.1, 59 | "w_lpips": 0.2 60 | } 61 | }, 62 | "train": { 63 | "n_iter": 5000, 64 | "val_freq": 100, 65 | "save_checkpoint_freq": 100, 66 | "print_freq": 200, 67 | "optimizer": { 68 | "type": "adam", 69 | "lr": 1e-4, 70 | "lr_policy":"linear", 71 | "lr_decay_iters":3000, 72 | "n_lr_iters": 2000 73 | }, 74 | "ema_scheduler": { 75 | "step_start_ema": 5000, 76 | "update_ema_every": 1, 77 | "ema_decay": 0.9999 78 | } 79 | }, 80 | "wandb": { 81 | "project": "llie_ddpm" 82 | } 83 | } 84 | -------------------------------------------------------------------------------- /config/lolv2_real.yml: -------------------------------------------------------------------------------- 1 | 2 | dataset: LOLv2 3 | 4 | #### datasets 5 | datasets: 6 | train: 7 | dist: False 8 | root: ./dataset/LOL-v2 9 | use_shuffle: true 10 | n_workers: 8 11 | batch_size: 16 12 | use_flip: true 13 | use_crop: true 14 | patch_size: 96 15 | sub_data: Real_captured 16 | 17 | val: 18 | dist: False 19 | root: ./dataset/LOL-v2 20 | n_workers: 1 21 | batch_size: 1 22 | sub_data: Real_captured 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /config/lolv2_real_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "lolv2_test_real", 3 | "phase": "test", 4 | "gpu_ids": [ 5 | 0 6 | ], 7 | 8 | "path": { 9 | "log": "logs", 10 | "tb_logger": "tb_logger", 11 | "results": "results", 12 | "checkpoint": "checkpoint", 13 | "resume_state": "./checkpoint/lolv2_real_4step_gen.pth" 14 | }, 15 | "freq_aware": false, 16 | "freq_awareUNet": { 17 | "b1": 1.6, 18 | "b2": 1.6, 19 | "s1": 0.9, 20 | "s2": 0.9 21 | }, 22 | "model": { 23 | "which_model_G": "ddpm", 24 | "finetune_norm": false, 25 | "unet": { 26 | "in_channel": 6, 27 | "out_channel": 3, 28 | "inner_channel": 64, 29 | "channel_multiplier": [ 30 | 1, 31 | 1, 32 | 2, 33 | 2, 34 | 4 35 | ], 36 | "attn_res": [ 37 | 16 38 | ], 39 | "res_blocks": 2, 40 | "dropout": 0 41 | }, 42 | "beta_schedule": { 43 | "train": { 44 | "schedule": "linear", 45 | "n_timestep": 4, 46 | "linear_start": 1e-4, 47 | "linear_end": 2e-2, 48 | "time_scale": 128 49 | }, 50 | "val": { 51 | "schedule": "linear", 52 | "n_timestep": 4, 53 | "linear_start": 1e-4, 54 | "linear_end": 2e-2, 55 | "time_scale": 128 56 | } 57 | }, 58 | "diffusion": { 59 | "image_size": 128, 60 | "channels": 6, 61 | "conditional": true 62 | } 63 | }, 64 | "train": { 65 | "n_iter": 1000000, 66 | "val_freq": 1e4, 67 | "save_checkpoint_freq": 5e4, 68 | "print_freq": 200, 69 | "optimizer": { 70 | "type": "adam", 71 | "lr": 1e-4 72 | }, 73 | "ema_scheduler": { 74 | "step_start_ema": 5000, 75 | "update_ema_every": 1, 76 | "ema_decay": 0.9999 77 | } 78 | }, 79 | "wandb": { 80 | "project": "llie_ddpm" 81 | } 82 | } -------------------------------------------------------------------------------- /config/lolv2_real_train.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "lolv2_train_real", 3 | "phase": "train", 4 | "distill": true, 5 | "CD":false, 6 | "gpu_ids": [ 7 | 0 8 | ], 9 | 10 | "path": { 11 | "log": "logs", 12 | "tb_logger": "tb_logger", 13 | "results": "results", 14 | "checkpoint": "checkpoints", 15 | "resume_state": "./checkpoint/lolv2_real_4step_gen.pth" 16 | }, 17 | "model": { 18 | "which_model_G": "ddpm", 19 | "finetune_norm": false, 20 | "unet": { 21 | "in_channel": 6, 22 | "out_channel": 3, 23 | "inner_channel": 64, 24 | "channel_multiplier": [ 25 | 
1, 26 | 1, 27 | 2, 28 | 2, 29 | 4 30 | ], 31 | "attn_res": [ 32 | 16 33 | ], 34 | "res_blocks": 2, 35 | "dropout": 0 36 | }, 37 | "beta_schedule": { 38 | "train": { 39 | "schedule": "linear", 40 | "n_timestep": 513, 41 | "linear_start": 1e-4, 42 | "linear_end": 2e-2, 43 | "reflow": false, 44 | "time_scale": 1 45 | }, 46 | "val": { 47 | "schedule": "linear", 48 | 49 | "n_timestep": 513, 50 | "linear_start": 1e-4, 51 | "linear_end": 2e-2, 52 | "time_scale": 1 53 | } 54 | }, 55 | "diffusion": { 56 | "image_size": 128, 57 | "channels": 6, 58 | "conditional": true, 59 | "w_gt": 0.1, 60 | "w_snr": 0.5, 61 | "w_str": 0.1, 62 | "w_lpips": 0.2 63 | 64 | } 65 | }, 66 | "train": { 67 | "n_iter": 5000, 68 | "val_freq": 100, 69 | "save_checkpoint_freq": 100, 70 | "print_freq": 100, 71 | "optimizer": { 72 | "type": "adam", 73 | "lr": 1e-4, 74 | "lr_policy":"linear", 75 | "lr_decay_iters":3000, 76 | "n_lr_iters": 2000 77 | }, 78 | "ema_scheduler": { 79 | "step_start_ema": 5000, 80 | "update_ema_every": 1, 81 | "ema_decay": 0.9999 82 | } 83 | }, 84 | "wandb": { 85 | "project": "llie_ddpm" 86 | } 87 | } -------------------------------------------------------------------------------- /config/lolv2_syn.yml: -------------------------------------------------------------------------------- 1 | 2 | dataset: LOLv2 3 | 4 | #### datasets 5 | datasets: 6 | train: 7 | dist: False 8 | root: ./dataset/LOL-v2 9 | use_shuffle: true 10 | n_workers: 8 11 | batch_size: 16 12 | use_flip: true 13 | use_crop: true 14 | patch_size: 96 15 | sub_data: Synthetic 16 | 17 | val: 18 | dist: False 19 | root: ./dataset/LOL-v2 20 | n_workers: 1 21 | batch_size: 1 22 | sub_data: Synthetic 23 | 24 | 25 | 26 | -------------------------------------------------------------------------------- /config/lolv2_syn_test.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "lolv2_test_syn", 3 | "phase": "test", 4 | "gpu_ids": [ 5 | 0 6 | ], 7 | "path": { 8 | "log": "logs", 9 | "tb_logger": "tb_logger", 10 | "results": "results", 11 | "checkpoint": "checkpoint", 12 | "resume_state": "./checkpoint/lolv2_syn_4step_gen.pth" 13 | }, 14 | "model": { 15 | "which_model_G": "ddpm", 16 | "finetune_norm": false, 17 | "unet": { 18 | "in_channel": 6, 19 | "out_channel": 3, 20 | "inner_channel": 64, 21 | "channel_multiplier": [ 22 | 1, 23 | 1, 24 | 2, 25 | 2, 26 | 4 27 | ], 28 | "attn_res": [ 29 | 16 30 | ], 31 | "res_blocks": 2, 32 | "dropout": 0 33 | }, 34 | "beta_schedule": { 35 | "train": { 36 | "schedule": "linear", 37 | "n_timestep": 4, 38 | "linear_start": 1e-4, 39 | "linear_end": 2e-2, 40 | "time_scale": 128 41 | }, 42 | "val": { 43 | "schedule": "linear", 44 | "n_timestep": 4, 45 | "linear_start": 1e-4, 46 | "linear_end": 2e-2, 47 | "time_scale": 128 48 | } 49 | }, 50 | "diffusion": { 51 | "image_size": 128, 52 | "channels": 6, 53 | "conditional": true 54 | } 55 | }, 56 | "train": { 57 | "n_iter": 1000000, 58 | "val_freq": 1e4, 59 | "save_checkpoint_freq": 5e4, 60 | "print_freq": 200, 61 | "optimizer": { 62 | "type": "adam", 63 | "lr": 1e-4 64 | }, 65 | "ema_scheduler": { 66 | "step_start_ema": 5000, 67 | "update_ema_every": 1, 68 | "ema_decay": 0.9999 69 | } 70 | }, 71 | "wandb": { 72 | "project": "llie_ddpm" 73 | } 74 | } -------------------------------------------------------------------------------- /config/lolv2_syn_train.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "lolv2_train_syn", 3 | "phase": "train", 4 | "distill": true, 5 | 
"gpu_ids": [ 6 | 0 7 | ], 8 | "path": { 9 | "log": "logs", 10 | "tb_logger": "tb_logger", 11 | "results": "results", 12 | "checkpoint": "checkpoint", 13 | "resume_state": "./checkpoint/lolv2_syn_4step_gen.pth" 14 | }, 15 | "model": { 16 | "which_model_G": "ddpm", 17 | "finetune_norm": false, 18 | "unet": { 19 | "in_channel": 6, 20 | "out_channel": 3, 21 | "inner_channel": 64, 22 | "channel_multiplier": [ 23 | 1, 24 | 1, 25 | 2, 26 | 2, 27 | 4 28 | ], 29 | "attn_res": [ 30 | 16 31 | ], 32 | "res_blocks": 2, 33 | "dropout": 0 34 | }, 35 | "beta_schedule": { 36 | "train": { 37 | "schedule": "linear", 38 | "n_timestep": 513, 39 | "linear_start": 1e-4, 40 | "linear_end": 2e-2, 41 | "reflow": false, 42 | "time_scale": 1 43 | }, 44 | "val": { 45 | "schedule": "linear", 46 | "n_timestep": 513, 47 | "linear_start": 1e-4, 48 | "linear_end": 2e-2 49 | } 50 | }, 51 | "diffusion": { 52 | "image_size": 128, 53 | "channels": 6, 54 | "conditional": true, 55 | "w_gt": 0.1, 56 | "w_snr": 0.5, 57 | "w_str": 0.1, 58 | "w_lpips": 0.2 59 | } 60 | }, 61 | "train": { 62 | "n_iter": 5000 , 63 | "val_freq": 100, 64 | "save_checkpoint_freq": 100, 65 | "print_freq": 100, 66 | "optimizer": { 67 | "type": "adam", 68 | "lr": 1e-4, 69 | "lr_policy":"linear", 70 | "lr_decay_iters":3000, 71 | "n_lr_iters": 2000 72 | }, 73 | "ema_scheduler": { 74 | "step_start_ema": 5000, 75 | "update_ema_every": 1, 76 | "ema_decay": 0.9999 77 | } 78 | }, 79 | "wandb": { 80 | "project": "llie_ddpm" 81 | } 82 | } -------------------------------------------------------------------------------- /config/test_unpaired.json: -------------------------------------------------------------------------------- 1 | { 2 | "name": "test_unpaired", 3 | "phase": "test", 4 | "gpu_ids": [ 5 | 0 6 | ], 7 | "path": { 8 | "log": "logs", 9 | "tb_logger": "tb_logger", 10 | "results": "results", 11 | "checkpoint": "checkpoint", 12 | "resume_state": "experiments/lolv2_train_syn/lolv2_train_syn_w_snr:0.2_w_str:0.0_w_gt:1.0_w_lpips:0.6_240410_125713/checkpoint/num_step_8/psnr30.1459_ssim0.9424_lpips0.0284_I2700_E47_gen_ema.pth" 13 | }, 14 | "freq_aware": false, 15 | "freq_awareUNet": { 16 | "b1": 1.6, 17 | "b2": 1.6, 18 | "s1": 0.9, 19 | "s2": 0.9 20 | }, 21 | "model": { 22 | "which_model_G": "ddpm", 23 | "finetune_norm": false, 24 | "unet": { 25 | "in_channel": 6, 26 | "out_channel": 3, 27 | "inner_channel": 64, 28 | "channel_multiplier": [ 29 | 1, 30 | 1, 31 | 2, 32 | 2, 33 | 4 34 | ], 35 | "attn_res": [ 36 | 16 37 | ], 38 | "res_blocks": 2, 39 | "dropout": 0 40 | }, 41 | "beta_schedule": { 42 | "train": { 43 | "schedule": "linear", 44 | "n_timestep": 4, 45 | "linear_start": 1e-4, 46 | "linear_end": 2e-2, 47 | "time_scale": 128 48 | }, 49 | "val": { 50 | "schedule": "linear", 51 | "n_timestep": 4, 52 | "linear_start": 1e-4, 53 | "linear_end": 2e-2, 54 | "time_scale": 128 55 | } 56 | }, 57 | "diffusion": { 58 | "image_size": 128, 59 | "channels": 6, 60 | "conditional": true 61 | } 62 | }, 63 | "train": { 64 | "n_iter": 1000000, 65 | "val_freq": 1e4, 66 | "save_checkpoint_freq": 5e4, 67 | "print_freq": 200, 68 | "optimizer": { 69 | "type": "adam", 70 | "lr": 1e-4 71 | }, 72 | "ema_scheduler": { 73 | "step_start_ema": 5000, 74 | "update_ema_every": 1, 75 | "ema_decay": 0.9999 76 | } 77 | }, 78 | "wandb": { 79 | "project": "llie_ddpm" 80 | } 81 | } -------------------------------------------------------------------------------- /core/__pycache__/logger.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/core/__pycache__/logger.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/metrics.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/core/__pycache__/metrics.cpython-38.pyc -------------------------------------------------------------------------------- /core/logger.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | from collections import OrderedDict 5 | import json 6 | from datetime import datetime 7 | 8 | 9 | def mkdirs(paths): 10 | if isinstance(paths, str): 11 | os.makedirs(paths, exist_ok=True) 12 | else: 13 | for path in paths: 14 | os.makedirs(path, exist_ok=True) 15 | 16 | 17 | def get_timestamp(): 18 | return datetime.now().strftime('%y%m%d_%H%M%S') 19 | 20 | 21 | def parse(args): 22 | phase = args.phase 23 | opt_path = args.config 24 | gpu_ids = args.gpu_ids 25 | # remove comments starting with '//' 26 | json_str = '' 27 | with open(opt_path, 'r') as f: 28 | for line in f: 29 | line = line.split('//')[0] + '\n' 30 | json_str += line 31 | opt = json.loads(json_str, object_pairs_hook=OrderedDict) 32 | 33 | # set log directory 34 | if args.debug: 35 | opt['name'] = 'debug_{}'.format(opt['name']) 36 | if args.brutal_search: 37 | experiments_root = os.path.join( 38 | 'experiments', opt['name'], '{}_noise_start:{}_noise_end:{}_{}'.format(opt['name'], args.noise_start, args.noise_end, get_timestamp())) 39 | else: 40 | if opt['phase'] == 'train': 41 | experiments_root = os.path.join( 42 | 'experiments', opt['name'], '{}_w_snr:{}_w_str:{}_w_gt:{}_w_lpips:{}_{}'.format(opt['name'], args.w_snr, args.w_str, args.w_gt, args.w_lpips, get_timestamp())) 43 | elif opt['phase'] == 'test': 44 | experiments_root = os.path.join( 45 | 'experiments', opt['name'], '{}_numstep:{}_w_snr:{}_w_gt:{}_w_lpips:{}_{}'.format(opt['name'],opt["model"]['beta_schedule']['val']['n_timestep'], args.w_snr, args.w_gt, args.w_lpips, get_timestamp())) 46 | opt['path']['experiments_root'] = experiments_root 47 | for key, path in opt['path'].items(): 48 | if 'resume' not in key and 'experiments' not in key: 49 | opt['path'][key] = os.path.join(experiments_root, path) 50 | mkdirs(opt['path'][key]) 51 | 52 | # change dataset length limit 53 | opt['phase'] = phase 54 | 55 | # export CUDA_VISIBLE_DEVICES 56 | if gpu_ids is not None: 57 | opt['gpu_ids'] = [int(id) for id in gpu_ids.split(',')] 58 | gpu_list = gpu_ids 59 | else: 60 | gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 61 | 62 | if len(gpu_list) > 1: 63 | opt['distributed'] = True 64 | else: 65 | opt['distributed'] = False 66 | 67 | # debug 68 | if 'debug' in opt['name']: 69 | opt['train']['val_freq'] = 2 70 | opt['train']['print_freq'] = 2 71 | opt['train']['save_checkpoint_freq'] = 3 72 | opt['datasets']['train']['batch_size'] = 2 73 | opt['model']['beta_schedule']['train']['n_timestep'] = 10 74 | opt['model']['beta_schedule']['val']['n_timestep'] = 10 75 | opt['datasets']['train']['data_len'] = 6 76 | opt['datasets']['val']['data_len'] = 3 77 | 78 | # validation in train phase 79 | # if phase == 'train': 80 | # opt['datasets']['val']['data_len'] = 3 81 | 82 | # W&B Logging 83 | try: 84 | log_wandb_ckpt = args.log_wandb_ckpt 85 | opt['log_wandb_ckpt'] = log_wandb_ckpt 86 | 
except AttributeError:
87 |         pass
88 |     try:
89 |         log_eval = args.log_eval
90 |         opt['log_eval'] = log_eval
91 |     except AttributeError:
92 |         pass
93 |     try:
94 |         log_infer = args.log_infer
95 |         opt['log_infer'] = log_infer
96 |     except AttributeError:
97 |         pass
98 |
99 |     return opt
100 |
101 |
102 | class NoneDict(dict):
103 |     def __missing__(self, key):
104 |         return None
105 |
106 |
107 | # convert to NoneDict, which returns None for missing keys.
108 | def dict_to_nonedict(opt):
109 |     if isinstance(opt, dict):
110 |         new_opt = dict()
111 |         for key, sub_opt in opt.items():
112 |             new_opt[key] = dict_to_nonedict(sub_opt)
113 |         return NoneDict(**new_opt)
114 |     elif isinstance(opt, list):
115 |         return [dict_to_nonedict(sub_opt) for sub_opt in opt]
116 |     else:
117 |         return opt
118 |
119 |
120 | def dict2str(opt, indent_l=1):
121 |     '''dict to string for logger'''
122 |     msg = ''
123 |     for k, v in opt.items():
124 |         if isinstance(v, dict):
125 |             msg += ' ' * (indent_l * 2) + k + ':[\n'
126 |             msg += dict2str(v, indent_l + 1)
127 |             msg += ' ' * (indent_l * 2) + ']\n'
128 |         else:
129 |             msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n'
130 |     return msg
131 |
132 |
133 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False):
134 |     '''set up logger'''
135 |     l = logging.getLogger(logger_name)
136 |     formatter = logging.Formatter(
137 |         '%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', datefmt='%y-%m-%d %H:%M:%S')
138 |     log_file = os.path.join(root, '{}.log'.format(phase))
139 |     fh = logging.FileHandler(log_file, mode='w')
140 |     fh.setFormatter(formatter)
141 |     l.setLevel(level)
142 |     l.addHandler(fh)
143 |     if screen:
144 |         sh = logging.StreamHandler()
145 |         sh.setFormatter(formatter)
146 |         l.addHandler(sh)
147 |
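# Usage sketch (illustrative, not part of the original file): a typical call
# sequence from a training script, assuming an argparse namespace carrying the
# attributes parse() reads (phase, config, gpu_ids, debug, the loss weights, ...):
#   opt = parse(args)
#   opt = dict_to_nonedict(opt)  # missing keys then read as None instead of raising
#   setup_logger('base', opt['path']['log'], 'train', screen=True)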
-------------------------------------------------------------------------------- /core/metrics.py: --------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import numpy as np
4 | import cv2
5 | from torchvision.utils import make_grid
6 |
7 |
8 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)):
9 |     '''
10 |     Converts a torch Tensor into an image Numpy array
11 |     Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
12 |     Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
13 |     '''
14 |     tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
15 |     tensor = (tensor - min_max[0]) / \
16 |         (min_max[1] - min_max[0])  # to range [0,1]
17 |     n_dim = tensor.dim()
18 |     if n_dim == 4:
19 |         n_img = len(tensor)
20 |         # img_np = make_grid(tensor, nrow=int(
21 |         #     math.sqrt(n_img)), normalize=False).numpy()
22 |         # img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
23 |
24 |         img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
25 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
26 |     elif n_dim == 3:
27 |         img_np = tensor.numpy()
28 |         # img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
29 |
30 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
31 |     elif n_dim == 2:
32 |         img_np = tensor.numpy()
33 |     else:
34 |         raise TypeError(
35 |             'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
36 |     if out_type == np.uint8:
37 |         img_np = np.clip((img_np * 255.0).round(), 0, 255)
38 |         # img_np = (img_np * 255.0).round()
39 |         # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
40 |     return img_np.astype(out_type)
41 |
42 | def tensor2img2(tensor, out_type=np.uint8, min_max=(-1, 1)):
43 |     '''
44 |     Converts a torch Tensor into an image Numpy array
45 |     Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
46 |     Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
47 |     '''
48 |     tensor = tensor.squeeze().float().cpu().clamp_(*min_max)  # clamp
49 |
50 |     n_dim = tensor.dim()
51 |     if n_dim == 4:
52 |         n_img = len(tensor)
53 |         # img_np = make_grid(tensor, nrow=int(
54 |         #     math.sqrt(n_img)), normalize=False).numpy()
55 |         # img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
56 |
57 |         img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
58 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
59 |     elif n_dim == 3:
60 |         img_np = tensor.numpy()
61 |         # img_np = np.transpose(img_np, (1, 2, 0))  # HWC, RGB
62 |
63 |         img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))  # HWC, BGR
64 |     elif n_dim == 2:
65 |         img_np = tensor.numpy()
66 |     else:
67 |         raise TypeError(
68 |             'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
69 |     if out_type == np.uint8:
70 |         img_np = np.clip((img_np * 255.0).round(), 0, 255)
71 |         # img_np = (img_np * 255.0).round()
72 |         # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
73 |     return img_np.astype(out_type)
74 |
75 | def save_img(img, img_path, mode='RGB'):
76 |     cv2.imwrite(img_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
77 |     # cv2.imwrite(img_path, img)
78 |
79 |
80 | def calculate_psnr(img1, img2):
81 |     # img1 and img2 have range [0, 255]
82 |     img1 = img1.astype(np.float64)
83 |     img2 = img2.astype(np.float64)
84 |     mse = np.mean((img1 - img2)**2)
85 |     if mse == 0:
86 |         return float('inf')
87 |     return 20 * math.log10(255.0 / math.sqrt(mse))
88 |
89 |
90 | def ssim(img1, img2):
91 |     C1 = (0.01 * 255)**2
92 |     C2 = (0.03 * 255)**2
93 |
94 |     img1 = img1.astype(np.float64)
95 |     img2 = img2.astype(np.float64)
96 |     kernel = cv2.getGaussianKernel(11, 1.5)
97 |     window = np.outer(kernel, kernel.transpose())
98 |
99 |     mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid
100 |     mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
101 |     mu1_sq = mu1**2
102 |     mu2_sq = mu2**2
103 |     mu1_mu2 = mu1 * mu2
104 |     sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
105 |     sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
106 |     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
107 |
108 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
109 |                                                             (sigma1_sq + sigma2_sq + C2))
110 |     return ssim_map.mean()
111 |
112 |
113 | def calculate_ssim(img1, img2):
114 |     '''calculate SSIM
115 |     the same outputs as MATLAB's
116 |     img1, img2: [0, 255]
117 |     '''
118 |     if not img1.shape == img2.shape:
119 |         raise ValueError('Input images must have the same dimensions.')
120 |     if img1.ndim == 2:
121 |         return ssim(img1, img2)
122 |     elif img1.ndim == 3:
123 |         if img1.shape[2] == 3:
124 |             ssims = []
125 |             for i in range(3):
126 |                 ssims.append(ssim(img1[:, :, i], img2[:, :, i]))  # SSIM per channel
127 |             return np.array(ssims).mean()
128 |         elif img1.shape[2] == 1:
129 |             return ssim(np.squeeze(img1), np.squeeze(img2))
130 |     else:
131 |         raise ValueError('Wrong input image dimensions.')
132 |
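# Usage sketch (illustrative, not part of the original file): the metrics above
# expect uint8 images in the [0, 255] range, e.g.:
#   sr = tensor2img(visuals['SR'])  # model output tensor -> uint8 BGR image
#   psnr = calculate_psnr(sr, gt)
#   ssim_val = calculate_ssim(sr, gt)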
-------------------------------------------------------------------------------- /data/LoL_dataset.py: --------------------------------------------------------------------------------
1 | import os
2 | import torch.utils.data as data
3 | import numpy as np
4 | import torch
5 | import cv2
6 | from torchvision.transforms import ToTensor
7 | import torchvision.transforms as T
8 | import torchvision
9 |
10 |
11 | class LOLv1_Dataset(data.Dataset):
12 |     def __init__(self, opt, train, all_opt):
13 |         self.root = opt["root"]
14 |         self.opt = opt
15 |         self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False
16 |         self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False
17 |         self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False
18 |         self.crop_size = opt.get("patch_size", None)
19 |         if train:
20 |             self.split = 'train'
21 |             self.root = os.path.join(self.root, 'our485')
22 |         else:
23 |             self.split = 'val'
24 |             self.root = os.path.join(self.root, 'eval15')
25 |         self.pairs = self.load_pairs(self.root)
26 |         self.to_tensor = ToTensor()
27 |
28 |     def __len__(self):
29 |         return len(self.pairs)
30 |
31 |     def load_pairs(self, folder_path):
32 |
33 |         low_list = os.listdir(os.path.join(folder_path, 'low'))
34 |         low_list = filter(lambda x: 'png' in x, low_list)
35 |
36 |         pairs = []
37 |         for idx, f_name in enumerate(low_list):
38 |             # low/high pairs share the same file name in both splits
40 |             pairs.append(
41 |                 [cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'low', f_name)), cv2.COLOR_BGR2RGB),
42 |                  cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'high', f_name)), cv2.COLOR_BGR2RGB),
43 |                  f_name.split('.')[0]])
49 |         return pairs
50 |
51 |     def get_max(self, input):
52 |         T, _ = torch.max(input, dim=0)
53 |         T = T + 0.1
54 |         input[0, :, :] = input[0, :, :] / T
55 |         input[1, :, :] = input[1, :, :] / T
56 |         input[2, :, :] = input[2, :, :] / T
57 |         return input
58 |
59 |     def __getitem__(self, item):
60 |         lr, hr, f_name = self.pairs[item]
61 |
62 |
63 |         if self.use_crop and self.split != 'val':
64 |             hr, lr = random_crop(hr, lr, self.crop_size)
65 |         elif self.split == 'val':
66 |             lr = cv2.copyMakeBorder(lr, 8, 8, 4, 4, cv2.BORDER_REFLECT)
67 |
68 |         if self.use_flip:
69 |             hr, lr = random_flip(hr, lr)
70 |
71 |         if self.use_rot:
72 |             hr, lr = random_rotation(hr, lr)
73 |
74 |         hr = self.to_tensor(hr)
75 |         lr = self.to_tensor(lr)
76 |         # lr = self.get_max(lr)
77 |
78 |         [lr, hr] = transform_augment(
79 |             [lr, hr], split=self.split, min_max=(-1, 1))
80 |
81 |         return {'LQ': lr, 'GT': hr, 'LQ_path': f_name, 'GT_path': f_name}
82 |
83 | class LOLv2_Dataset(data.Dataset):
84 |     def __init__(self, opt, train, all_opt):
85 |         self.root = opt["root"]
86 |         self.opt = opt
87 |         self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False
88 |         self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False
89 |         self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False
90 |         self.crop_size = opt.get("patch_size", None)
91 |         self.sub_data = opt.get("sub_data", None)
92 |         self.pairs = []
93 |         self.train = train
94 |         if train:
95 |             self.split = 'train'
96 |             root = os.path.join(self.root, self.sub_data, 'Train')
97 |         else:
98 |             self.split = 'val'
99 |             root = os.path.join(self.root, self.sub_data, 'Test')
100 |         self.pairs.extend(self.load_pairs(root))
101 |         self.to_tensor = ToTensor()
102 |         self.gamma_aug = opt['gamma_aug'] if 'gamma_aug' in opt.keys() else False
103 |
104 |     def __len__(self):
105 |         return len(self.pairs)
106 |
107 |     def get_max(self, input):
108 |         T, _ = torch.max(input, dim=0)
109 |         T = T + 0.1
110 |         input[0, :, :] = input[0, :, :] / T
111 |         input[1, :, :] = input[1, :, :] / T
112 |         input[2, :, :] = input[2, :, :] / T
113 |         return input
114 |
115 |     def load_pairs(self, folder_path):
116 |
117 |         low_list = os.listdir(os.path.join(folder_path, 'Low'))
118 |         low_list = sorted(list(filter(lambda x: 'png' in x, low_list)))
119 |         high_list = os.listdir(os.path.join(folder_path, 'Normal'))
120 |         high_list = sorted(list(filter(lambda x: 'png' in x, high_list)))
121 |         pairs = []
122 |
123 |         for idx in range(len(low_list)):
124 |             f_name_low = low_list[idx]
125 |             f_name_high = high_list[idx]
126 |             pairs.append(
127 |                 [cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'Low', f_name_low)),
128 |                               cv2.COLOR_BGR2RGB),
129 |                  cv2.cvtColor(cv2.imread(os.path.join(folder_path, 'Normal', f_name_high)),
130 |                               cv2.COLOR_BGR2RGB),
131 |                  f_name_high.split('.')[0]])
132 |         return pairs
133 |
134 |     def __getitem__(self, item):
135 |
136 |         lr, hr, f_name = self.pairs[item]
137 |
138 |         if self.use_crop and self.split != 'val':
139 |             hr, lr = random_crop(hr, lr, self.crop_size)
140 |         elif self.sub_data == 'Real_captured' and self.split == 'val':  # for Real_captured
141 |             lr = cv2.copyMakeBorder(lr, 8, 8, 4, 4, cv2.BORDER_REFLECT)
142 |
143 |         if self.use_flip:
144 |             hr, lr = random_flip(hr, lr)
145 |
146 |         if self.use_rot:
147 |             hr, lr = random_rotation(hr, lr)
148 |
149 |
150 |         hr = self.to_tensor(hr)
151 |         lr = self.to_tensor(lr)
152 |         # lr = self.get_max(lr)
153 |
154 |
155 |         [lr, hr] = transform_augment(
156 |             [lr, hr], split=self.split, min_max=(-1, 1))
157 |
158 |         return {'LQ': lr, 'GT': hr, 'LQ_path': f_name, 'GT_path': f_name}
159 |
160 |
161 | def random_flip(img, seg):
162 |     random_choice = np.random.choice([True, False])
163 |     img = img if random_choice else np.flip(img, 1).copy()
164 |     seg = seg if random_choice else np.flip(seg, 1).copy()
165 |
166 |     return img, seg
167 |
168 |
169 | def gamma_aug(img, gamma=0):
170 |     max_val = img.max()
171 |     img_after_norm = img / max_val
172 |     img_after_norm = np.power(img_after_norm, gamma)
173 |     return img_after_norm * max_val
174 |
175 |
176 | def random_rotation(img, seg):
177 |     random_choice = np.random.choice([0, 1, 3])
178 |     img = np.rot90(img, random_choice, axes=(0, 1)).copy()
179 |     seg = np.rot90(seg, random_choice, axes=(0, 1)).copy()
180 |
181 |     return img, seg
182 |
183 |
184 | def random_crop(hr, lr, size_hr):
185 |     size_lr = size_hr
186 |
187 |     size_lr_x = lr.shape[0]
188 |     size_lr_y = lr.shape[1]
189 |
190 |     start_x_lr = np.random.randint(low=0, high=(size_lr_x - size_lr) + 1) if size_lr_x > size_lr else 0
191 |     start_y_lr = np.random.randint(low=0, high=(size_lr_y - size_lr) + 1) if size_lr_y > size_lr else 0
192 |
193 |     # LR Patch
194 |     lr_patch = lr[start_x_lr:start_x_lr + size_lr, start_y_lr:start_y_lr + size_lr, :]
195 |
196 |     # HR Patch
197 |     start_x_hr = start_x_lr
198 |     start_y_hr = start_y_lr
199 |     hr_patch = hr[start_x_hr:start_x_hr + size_hr, start_y_hr:start_y_hr + size_hr, :]
200 |
203 |     return hr_patch, lr_patch
204 |
205 |
206 | # implementation by torchvision, detail in https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement/issues/14
207 | totensor = torchvision.transforms.ToTensor()
208 | hflip = torchvision.transforms.RandomHorizontalFlip()
209 | def transform_augment(imgs, split='val', min_max=(0, 1)):
210 |     # imgs = [totensor(img) for img in img_list]
211 |     if split == 'train':
212 |         imgs =
torch.stack(imgs, 0) 213 | # imgs = hflip(imgs) 214 | imgs = torch.unbind(imgs, dim=0) 215 | ret_img = [img * (min_max[1] - min_max[0]) + min_max[0] for img in imgs] 216 | return ret_img 217 | 218 | -------------------------------------------------------------------------------- /data/__init__.py: -------------------------------------------------------------------------------- 1 | '''create dataset and dataloader''' 2 | import logging 3 | import torch 4 | import torch.utils.data 5 | 6 | 7 | 8 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): 9 | phase = dataset_opt['phase'] 10 | if phase == 'train': 11 | if dataset_opt['dist']: 12 | world_size = torch.distributed.get_world_size() 13 | num_workers = dataset_opt['n_workers'] 14 | assert dataset_opt['batch_size'] % world_size == 0 15 | batch_size = dataset_opt['batch_size'] // world_size 16 | shuffle = False 17 | sampler=torch.utils.data.distributed.DistributedSampler(dataset) 18 | else: 19 | num_workers = dataset_opt['n_workers'] # * len(opt['gpu_ids']) 20 | batch_size = dataset_opt['batch_size'] 21 | shuffle = True 22 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 23 | num_workers=num_workers, sampler=sampler, drop_last=True, 24 | pin_memory=True) 25 | else: 26 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, 27 | pin_memory=True) 28 | 29 | 30 | # def create_dataloader(train, dataset, dataset_opt, opt=None, sampler=None): 31 | # # gpu_ids = opt.get('gpu_ids', None) 32 | # gpu_ids = [] 33 | # gpu_ids = gpu_ids if gpu_ids else [] 34 | # num_workers = dataset_opt['n_workers'] * (len(gpu_ids)+1) 35 | # batch_size = dataset_opt['batch_size'] 36 | # shuffle = True 37 | # if train: 38 | # return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 39 | # num_workers=num_workers, sampler=sampler, drop_last=True, 40 | # pin_memory=False) 41 | # else: 42 | # return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=False, 43 | # num_workers=num_workers, sampler=sampler, drop_last=False, 44 | # pin_memory=False) 45 | 46 | -------------------------------------------------------------------------------- /data/__pycache__/LoL_dataset.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/data/__pycache__/LoL_dataset.cpython-38.pyc -------------------------------------------------------------------------------- /data/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/data/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /data/single_image_dataset.py: -------------------------------------------------------------------------------- 1 | from os import path as osp 2 | from torch.utils import data as data 3 | from torchvision.transforms.functional import normalize 4 | 5 | from basicsr.data.data_util import paths_from_lmdb 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor, scandir 7 | 8 | 9 | class SingleImageDataset(data.Dataset): 10 | """Read only lq images in the test phase. 11 | 12 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc). 13 | 14 | There are two modes: 15 | 1. 
'meta_info_file': Use meta information file to generate paths. 16 | 2. 'folder': Scan folders to generate paths. 17 | 18 | Args: 19 | opt (dict): Config for train datasets. It contains the following keys: 20 | dataroot_lq (str): Data root path for lq. 21 | meta_info_file (str): Path for meta information file. 22 | io_backend (dict): IO backend type and other kwarg. 23 | """ 24 | 25 | def __init__(self, opt): 26 | super(SingleImageDataset, self).__init__() 27 | self.opt = opt 28 | # file client (io backend) 29 | self.file_client = None 30 | self.io_backend_opt = opt['io_backend'] 31 | self.mean = opt['mean'] if 'mean' in opt else None 32 | self.std = opt['std'] if 'std' in opt else None 33 | self.lq_folder = opt['dataroot_lq'] 34 | 35 | if self.io_backend_opt['type'] == 'lmdb': 36 | self.io_backend_opt['db_paths'] = [self.lq_folder] 37 | self.io_backend_opt['client_keys'] = ['lq'] 38 | self.paths = paths_from_lmdb(self.lq_folder) 39 | elif 'meta_info_file' in self.opt: 40 | with open(self.opt['meta_info_file'], 'r') as fin: 41 | self.paths = [ 42 | osp.join(self.lq_folder, 43 | line.split(' ')[0]) for line in fin 44 | ] 45 | else: 46 | self.paths = sorted(list(scandir(self.lq_folder, full_path=True))) 47 | 48 | def __getitem__(self, index): 49 | if self.file_client is None: 50 | self.file_client = FileClient( 51 | self.io_backend_opt.pop('type'), **self.io_backend_opt) 52 | 53 | # load lq image 54 | lq_path = self.paths[index] 55 | img_bytes = self.file_client.get(lq_path, 'lq') 56 | img_lq = imfrombytes(img_bytes, float32=True) 57 | 58 | # TODO: color space transform 59 | # BGR to RGB, HWC to CHW, numpy to tensor 60 | img_lq = img2tensor(img_lq, bgr2rgb=True, float32=True) 61 | # normalize 62 | if self.mean is not None or self.std is not None: 63 | normalize(img_lq, self.mean, self.std, inplace=True) 64 | return {'lq': img_lq, 'lq_path': lq_path} 65 | 66 | def __len__(self): 67 | return len(self.paths) 68 | -------------------------------------------------------------------------------- /data/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import pickle 4 | import random 5 | import numpy as np 6 | import glob 7 | import torch 8 | import cv2 9 | 10 | #################### 11 | # Files & IO 12 | #################### 13 | 14 | ###################### get image path list ###################### 15 | IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP'] 16 | 17 | 18 | def flip(x, dim): 19 | indices = [slice(None)] * x.dim() 20 | indices[dim] = torch.arange(x.size(dim) - 1, -1, -1, 21 | dtype=torch.long, device=x.device) 22 | return x[tuple(indices)] 23 | 24 | 25 | def is_image_file(filename): 26 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 27 | 28 | 29 | def _get_paths_from_images(path): 30 | """get image path list from image folder""" 31 | assert os.path.isdir(path), '{:s} is not a valid directory'.format(path) 32 | images = [] 33 | for dirpath, _, fnames in sorted(os.walk(path)): 34 | for fname in sorted(fnames): 35 | if is_image_file(fname): 36 | img_path = os.path.join(dirpath, fname) 37 | images.append(img_path) 38 | assert images, '{:s} has no valid image file'.format(path) 39 | return images 40 | 41 | 42 | def _get_paths_from_lmdb(dataroot): 43 | """get image path list from lmdb meta info""" 44 | meta_info = pickle.load(open(os.path.join(dataroot, 'meta_info.pkl'), 'rb')) 45 | paths = meta_info['keys'] 46 | sizes = meta_info['resolution'] 47 
| if len(sizes) == 1: 48 | sizes = sizes * len(paths) 49 | return paths, sizes 50 | 51 | 52 | def get_image_paths(data_type, dataroot): 53 | """get image path list 54 | support lmdb or image files""" 55 | paths, sizes = None, None 56 | if dataroot is not None: 57 | if data_type == 'lmdb': 58 | paths, sizes = _get_paths_from_lmdb(dataroot) 59 | elif data_type == 'img': 60 | paths = sorted(_get_paths_from_images(dataroot)) 61 | else: 62 | raise NotImplementedError('data_type [{:s}] is not recognized.'.format(data_type)) 63 | return paths, sizes 64 | 65 | 66 | def glob_file_list(root): 67 | return sorted(glob.glob(os.path.join(root, '*'))) 68 | 69 | 70 | ###################### read images ###################### 71 | def _read_img_lmdb(env, key, size): 72 | """read image from lmdb with key (w/ and w/o fixed size) 73 | size: (C, H, W) tuple""" 74 | with env.begin(write=False) as txn: 75 | buf = txn.get(key.encode('ascii')) 76 | img_flat = np.frombuffer(buf, dtype=np.uint8) 77 | C, H, W = size 78 | img = img_flat.reshape(H, W, C) 79 | return img 80 | 81 | 82 | def read_img(env, path, size=None): 83 | """read image by cv2 or from lmdb 84 | return: Numpy float32, HWC, BGR, [0,1]""" 85 | if env is None: # img 86 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 87 | if img is None: 88 | print(path) 89 | if size is not None: 90 | img = cv2.resize(img, (size[0], size[1])) 91 | else: 92 | img = _read_img_lmdb(env, path, size) 93 | 94 | img = img.astype(np.float32) / 255. 95 | if img.ndim == 2: 96 | img = np.expand_dims(img, axis=2) 97 | # some images have 4 channels 98 | if img.shape[2] > 3: 99 | img = img[:, :, :3] 100 | return img 101 | 102 | 103 | def read_img2(env, path, size=None): 104 | """read image by cv2 or from lmdb 105 | return: Numpy float32, HWC, BGR, [0,1]""" 106 | if env is None: # img 107 | img = np.load(path) 108 | if img is None: 109 | print(path) 110 | if size is not None: 111 | img = cv2.resize(img, (size[0], size[1])) 112 | # img = cv2.resize(img, size) 113 | else: 114 | img = _read_img_lmdb(env, path, size) 115 | img = get_max(img) 116 | img = img.astype(np.float32) / 255. 
117 | if img.ndim == 2: 118 | img = np.expand_dims(img, axis=2) 119 | # some images have 4 channels 120 | if img.shape[2] > 3: 121 | img = img[:, :, :3] 122 | return img 123 | 124 | 125 | def read_img_seq(path, size=None): 126 | """Read a sequence of images from a given folder path 127 | Args: 128 | path (list/str): list of image paths/image folder path 129 | 130 | Returns: 131 | imgs (Tensor): size (T, C, H, W), RGB, [0, 1] 132 | """ 133 | # print(path) 134 | if type(path) is list: 135 | img_path_l = path 136 | else: 137 | img_path_l = sorted(glob.glob(os.path.join(path, '*'))) 138 | 139 | img_l = [read_img(None, v, size) for v in img_path_l] 140 | # stack to Torch tensor 141 | imgs = np.stack(img_l, axis=0) 142 | try: 143 | imgs = imgs[:, :, :, [2, 1, 0]] 144 | except Exception: 145 | import ipdb; ipdb.set_trace() 146 | imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() 147 | return imgs 148 | 149 | 150 | def read_img_seq2(path, size=None): 151 | """Read a sequence of images from a given folder path 152 | Args: 153 | path (list/str): list of image paths/image folder path 154 | 155 | Returns: 156 | imgs (Tensor): size (T, C, H, W), RGB, [0, 1] 157 | """ 158 | # print(path) 159 | if type(path) is list: 160 | img_path_l = path 161 | else: 162 | img_path_l = sorted(glob.glob(os.path.join(path, '*'))) 163 | 164 | img_l = [read_img2(None, v, size) for v in img_path_l] 165 | # stack to Torch tensor 166 | imgs = np.stack(img_l, axis=0) 167 | try: 168 | imgs = imgs[:, :, :, [2, 1, 0]] 169 | except Exception: 170 | import ipdb; ipdb.set_trace() 171 | imgs = torch.from_numpy(np.ascontiguousarray(np.transpose(imgs, (0, 3, 1, 2)))).float() 172 | return imgs 173 | 174 | def get_max(x): 175 | T =np.max(x,axis=0) 176 | T=T+0.1 177 | x[0,:,:] = x[0,:,:]/ T 178 | x[1,:,:] = x[1,:,:]/ T 179 | x[2,:,:]= x[2,:,:] / T 180 | return x 181 | 182 | 183 | def index_generation(crt_i, max_n, N, padding='reflection'): 184 | """Generate an index list for reading N frames from a sequence of images 185 | Args: 186 | crt_i (int): current center index 187 | max_n (int): max number of the sequence of images (calculated from 1) 188 | N (int): reading N frames 189 | padding (str): padding mode, one of replicate | reflection | new_info | circle 190 | Example: crt_i = 0, N = 5 191 | replicate: [0, 0, 0, 1, 2] 192 | reflection: [2, 1, 0, 1, 2] 193 | new_info: [4, 3, 0, 1, 2] 194 | circle: [3, 4, 0, 1, 2] 195 | 196 | Returns: 197 | return_l (list [int]): a list of indexes 198 | """ 199 | max_n = max_n - 1 200 | n_pad = N // 2 201 | return_l = [] 202 | 203 | for i in range(crt_i - n_pad, crt_i + n_pad + 1): 204 | if i < 0: 205 | if padding == 'replicate': 206 | add_idx = 0 207 | elif padding == 'reflection': 208 | add_idx = -i 209 | elif padding == 'new_info': 210 | add_idx = (crt_i + n_pad) + (-i) 211 | elif padding == 'circle': 212 | add_idx = N + i 213 | else: 214 | raise ValueError('Wrong padding mode') 215 | elif i > max_n: 216 | if padding == 'replicate': 217 | add_idx = max_n 218 | elif padding == 'reflection': 219 | add_idx = max_n * 2 - i 220 | elif padding == 'new_info': 221 | add_idx = (crt_i - n_pad) - (i - max_n) 222 | elif padding == 'circle': 223 | add_idx = i - N 224 | else: 225 | raise ValueError('Wrong padding mode') 226 | else: 227 | add_idx = i 228 | return_l.append(add_idx) 229 | return return_l 230 | 231 | 232 | #################### 233 | # image processing 234 | # process on numpy image 235 | #################### 236 | 237 | 238 | def augment(img_list, hflip=True, 
rot=True): 239 | """horizontal flip OR rotate (0, 90, 180, 270 degrees)""" 240 | hflip = hflip and random.random() < 0.5 241 | vflip = rot and random.random() < 0.5 242 | rot90 = rot and random.random() < 0.5 243 | 244 | def _augment(img): 245 | if hflip: 246 | img = img[:, ::-1, :] 247 | if vflip: 248 | img = img[::-1, :, :] 249 | if rot90: 250 | # import pdb; pdb.set_trace() 251 | img = img.transpose(1, 0, 2) 252 | return img 253 | 254 | return [_augment(img) for img in img_list] 255 | 256 | 257 | 258 | def augment_torch(img_list, hflip=True, rot=True): 259 | """horizontal flip OR rotate (0, 90, 180, 270 degrees)""" 260 | hflip = hflip and random.random() < 0.5 261 | vflip = rot and random.random() < 0.5 262 | # rot90 = rot and random.random() < 0.5 263 | 264 | def _augment(img): 265 | if hflip: 266 | img = flip(img, 2) 267 | if vflip: 268 | img = flip(img, 1) 269 | # if rot90: 270 | # # import pdb; pdb.set_trace() 271 | # img = img.transpose(1, 0, 2) 272 | return img 273 | 274 | return [_augment(img) for img in img_list] 275 | 276 | 277 | def augment_flow(img_list, flow_list, hflip=True, rot=True): 278 | """horizontal flip OR rotate (0, 90, 180, 270 degrees) with flows""" 279 | hflip = hflip and random.random() < 0.5 280 | vflip = rot and random.random() < 0.5 281 | rot90 = rot and random.random() < 0.5 282 | 283 | def _augment(img): 284 | if hflip: 285 | img = img[:, ::-1, :] 286 | if vflip: 287 | img = img[::-1, :, :] 288 | if rot90: 289 | img = img.transpose(1, 0, 2) 290 | return img 291 | 292 | def _augment_flow(flow): 293 | if hflip: 294 | flow = flow[:, ::-1, :] 295 | flow[:, :, 0] *= -1 296 | if vflip: 297 | flow = flow[::-1, :, :] 298 | flow[:, :, 1] *= -1 299 | if rot90: 300 | flow = flow.transpose(1, 0, 2) 301 | flow = flow[:, :, [1, 0]] 302 | return flow 303 | 304 | rlt_img_list = [_augment(img) for img in img_list] 305 | rlt_flow_list = [_augment_flow(flow) for flow in flow_list] 306 | 307 | return rlt_img_list, rlt_flow_list 308 | 309 | 310 | def channel_convert(in_c, tar_type, img_list): 311 | """conversion among BGR, gray and y""" 312 | if in_c == 3 and tar_type == 'gray': # BGR to gray 313 | gray_list = [cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) for img in img_list] 314 | return [np.expand_dims(img, axis=2) for img in gray_list] 315 | elif in_c == 3 and tar_type == 'y': # BGR to y 316 | y_list = [bgr2ycbcr(img, only_y=True) for img in img_list] 317 | return [np.expand_dims(img, axis=2) for img in y_list] 318 | elif in_c == 1 and tar_type == 'RGB': # gray/y to BGR 319 | return [cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) for img in img_list] 320 | else: 321 | return img_list 322 | 323 | 324 | def rgb2ycbcr(img, only_y=True): 325 | """same as matlab rgb2ycbcr 326 | only_y: only return Y channel 327 | Input: 328 | uint8, [0, 255] 329 | float, [0, 1] 330 | """ 331 | in_img_type = img.dtype 332 | img.astype(np.float32) 333 | if in_img_type != np.uint8: 334 | img *= 255. 335 | # convert 336 | if only_y: 337 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 338 | else: 339 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], 340 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] 341 | if in_img_type == np.uint8: 342 | rlt = rlt.round() 343 | else: 344 | rlt /= 255. 
345 | return rlt.astype(in_img_type) 346 | 347 | 348 | def bgr2ycbcr(img, only_y=True): 349 | """bgr version of rgb2ycbcr 350 | only_y: only return Y channel 351 | Input: 352 | uint8, [0, 255] 353 | float, [0, 1] 354 | """ 355 | in_img_type = img.dtype 356 | img.astype(np.float32) 357 | if in_img_type != np.uint8: 358 | img *= 255. 359 | # convert 360 | if only_y: 361 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 362 | else: 363 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 364 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 365 | if in_img_type == np.uint8: 366 | rlt = rlt.round() 367 | else: 368 | rlt /= 255. 369 | return rlt.astype(in_img_type) 370 | 371 | 372 | def ycbcr2rgb(img): 373 | """same as matlab ycbcr2rgb 374 | Input: 375 | uint8, [0, 255] 376 | float, [0, 1] 377 | """ 378 | in_img_type = img.dtype 379 | img.astype(np.float32) 380 | if in_img_type != np.uint8: 381 | img *= 255. 382 | # convert 383 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], 384 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] 385 | if in_img_type == np.uint8: 386 | rlt = rlt.round() 387 | else: 388 | rlt /= 255. 389 | return rlt.astype(in_img_type) 390 | 391 | 392 | def modcrop(img_in, scale): 393 | """img_in: Numpy, HWC or HW""" 394 | img = np.copy(img_in) 395 | if img.ndim == 2: 396 | H, W = img.shape 397 | H_r, W_r = H % scale, W % scale 398 | img = img[:H - H_r, :W - W_r] 399 | elif img.ndim == 3: 400 | H, W, C = img.shape 401 | H_r, W_r = H % scale, W % scale 402 | img = img[:H - H_r, :W - W_r, :] 403 | else: 404 | raise ValueError('Wrong img ndim: [{:d}].'.format(img.ndim)) 405 | return img 406 | 407 | -------------------------------------------------------------------------------- /dataset/LOLv1/readme.txt: -------------------------------------------------------------------------------- 1 | You can refer to the corresponding link to download the [LOL](https://daooshee.github.io/BMVC2018website/) dataset and put it in this folder. 
2 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger('base') 3 | 4 | 5 | def create_model(opt): 6 | from .model import DDPM as M 7 | from .model import DDPM_PD as M_PD 8 | # print(opt['distill']) 9 | # import pdb; pdb.set_trace() 10 | if opt['distill']: 11 | m=M_PD(opt) 12 | else: 13 | m = M(opt) 14 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) 15 | return m 16 | -------------------------------------------------------------------------------- /model/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/model/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/base_model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/model/__pycache__/base_model.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/model.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/model/__pycache__/model.cpython-38.pyc -------------------------------------------------------------------------------- /model/__pycache__/networks.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/model/__pycache__/networks.cpython-38.pyc -------------------------------------------------------------------------------- /model/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | 5 | 6 | class BaseModel(): 7 | def __init__(self, opt): 8 | self.opt = opt 9 | self.device = torch.device( 10 | 'cuda' if opt['gpu_ids'] is not None else 'cpu') 11 | self.begin_step = 0 12 | self.begin_epoch = 0 13 | 14 | def feed_data(self, data): 15 | pass 16 | 17 | def optimize_parameters(self): 18 | pass 19 | 20 | def get_current_visuals(self): 21 | pass 22 | 23 | def get_current_losses(self): 24 | pass 25 | 26 | def print_network(self): 27 | pass 28 | 29 | def set_device(self, x): 30 | if isinstance(x, dict): 31 | for key, item in x.items(): 32 | if item is not None: 33 | x[key] = item.to(self.device) 34 | elif isinstance(x, list): 35 | for item in x: 36 | if item is not None: 37 | item = item.to(self.device) 38 | else: 39 | x = x.to(self.device) 40 | return x 41 | 42 | def get_network_description(self, network): 43 | '''Get the string and total parameters of the network''' 44 | if isinstance(network, nn.DataParallel): 45 | network = network.module 46 | s = str(network) 47 | n = sum(map(lambda x: x.numel(), network.parameters())) 48 | return s, n 49 | -------------------------------------------------------------------------------- /model/ddpm_modules/__pycache__/diffusion.cpython-38.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/model/ddpm_modules/__pycache__/diffusion.cpython-38.pyc
--------------------------------------------------------------------------------
/model/ddpm_modules/__pycache__/unet.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/model/ddpm_modules/__pycache__/unet.cpython-38.pyc
--------------------------------------------------------------------------------
/model/ddpm_modules/diffusion.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | from inspect import isfunction
6 | from functools import partial
7 | import numpy as np
8 | from tqdm import tqdm
10 | import torchvision.transforms as T
19 | 
20 | 
21 | # map model outputs in [-1, 1] back to images in [0, 1]
22 | transform = T.Lambda(lambda t: (t + 1) / 2)
23 | 
24 | def extract(v, t, x_shape):
25 |     """Gather the entries of v at indices t, reshaped to broadcast over x_shape."""
26 |     # v is a 1-D schedule buffer; t holds one step index per sample
27 |     out = torch.gather(v, index=t, dim=0).float()
28 |     return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
35 | 
36 | 
37 | 
38 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
39 |     betas = linear_end * np.ones(n_timestep, dtype=np.float64)
40 |     warmup_time = int(n_timestep * warmup_frac)
41 |     betas[:warmup_time] = np.linspace(
42 |         linear_start, linear_end, warmup_time, dtype=np.float64)
43 |     return betas
44 | 
45 | 
46 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
47 |     if schedule == 'quad':
48 |         betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
49 |                             n_timestep, dtype=np.float64) ** 2
50 |     elif schedule == 'linear':
51 |         betas = np.linspace(linear_start, linear_end,
52 |                             n_timestep, dtype=np.float64)
53 |     elif schedule == 'warmup10':
54 |         betas = _warmup_beta(linear_start, linear_end,
55 |                              n_timestep, 0.1)
56 |     elif schedule == 'warmup50':
57 |         betas = _warmup_beta(linear_start, linear_end,
58 |                              n_timestep, 0.5)
59 |     elif schedule == 'const':
60 |         betas = linear_end * np.ones(n_timestep, dtype=np.float64)
61 |     elif schedule == 'jsd':  # 1/T, 1/(T-1), 1/(T-2), ..., 1
62 |         betas = 1.
/ np.linspace(n_timestep, 63 | 1, n_timestep, dtype=np.float64) 64 | elif schedule == "cosine": 65 | timesteps = ( 66 | torch.arange(n_timestep + 1, dtype=torch.float64) / 67 | n_timestep + cosine_s 68 | ) 69 | alphas = timesteps / (1 + cosine_s) * math.pi / 2 70 | alphas = torch.cos(alphas).pow(2) 71 | alphas = alphas / alphas[0] 72 | betas = 1 - alphas[1:] / alphas[:-1] 73 | betas = betas.clamp(max=0.999) 74 | else: 75 | raise NotImplementedError(schedule) 76 | return betas 77 | 78 | 79 | # gaussian diffusion trainer class 80 | 81 | def exists(x): 82 | return x is not None 83 | 84 | 85 | def default(val, d): 86 | if exists(val): 87 | return val 88 | return d() if isfunction(d) else d 89 | 90 | 91 | class GaussianDiffusion(nn.Module): 92 | def __init__( 93 | self, 94 | denoise_fn, 95 | image_size, 96 | num_timesteps, 97 | time_scale, 98 | w_str, 99 | w_gt, 100 | w_snr, 101 | w_lpips, 102 | channels=3, 103 | loss_type='l1', 104 | conditional=True, 105 | schedule_opt=None 106 | ): 107 | super().__init__() 108 | self.channels = channels 109 | self.image_size = image_size 110 | self.denoise_fn = denoise_fn 111 | self.loss_type = loss_type 112 | self.conditional = conditional 113 | self.num_timesteps = num_timesteps 114 | device = torch.device("cuda") 115 | 116 | self.w_str = w_str 117 | self.w_gt = w_gt 118 | self.w_snr = w_snr 119 | self.w_lpips = w_lpips 120 | # self.lpips = lpips.LPIPS(net='vgg').cuda() 121 | # print(self.num_timesteps) 122 | # import pdb; pdb.set_trace() 123 | self.time_scale = time_scale 124 | self.CD = False 125 | if schedule_opt is not None: 126 | self.set_new_noise_schedule(schedule_opt, device) 127 | 128 | def set_loss(self, device): 129 | if self.loss_type == 'l1': 130 | self.loss_func = nn.L1Loss(reduction='sum').to(device) 131 | elif self.loss_type == 'l2': 132 | self.loss_func = nn.MSELoss(reduction='sum').to(device) 133 | else: 134 | raise NotImplementedError() 135 | 136 | def set_new_noise_schedule(self, schedule_opt, device): 137 | to_torch = partial(torch.tensor, dtype=torch.float32, device=device) 138 | 139 | betas = make_beta_schedule( 140 | schedule=schedule_opt['schedule'], 141 | n_timestep= self.num_timesteps* self.time_scale + 1, 142 | linear_start=schedule_opt['linear_start'], 143 | linear_end=schedule_opt['linear_end']) 144 | betas = betas.detach().cpu().numpy() if isinstance( 145 | betas, torch.Tensor) else betas 146 | alphas = 1. - betas 147 | alphas_cumprod = np.cumprod(alphas, axis=0) 148 | alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) 149 | self.sqrt_alphas_cumprod_prev = np.sqrt( 150 | np.append(1., alphas_cumprod)) 151 | 152 | timesteps, = betas.shape 153 | self.register_buffer('betas', to_torch(betas)) 154 | self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) 155 | self.register_buffer('alphas_cumprod_prev', 156 | to_torch(alphas_cumprod_prev)) 157 | 158 | # calculations for diffusion q(x_t | x_{t-1}) and others 159 | self.register_buffer('sqrt_alphas_cumprod', 160 | to_torch(np.sqrt(alphas_cumprod))) 161 | self.register_buffer('sqrt_one_minus_alphas_cumprod', 162 | to_torch(np.sqrt(1. - alphas_cumprod))) 163 | self.register_buffer('log_one_minus_alphas_cumprod', 164 | to_torch(np.log(1. - alphas_cumprod))) 165 | self.register_buffer('sqrt_recip_alphas_cumprod', 166 | to_torch(np.sqrt(1. / alphas_cumprod))) 167 | self.register_buffer('sqrt_recipm1_alphas_cumprod', 168 | to_torch(np.sqrt(1. 
/ alphas_cumprod - 1))) 169 | 170 | # calculations for posterior q(x_{t-1} | x_t, x_0) 171 | posterior_variance = betas * \ 172 | (1. - alphas_cumprod_prev) / (1. - alphas_cumprod) 173 | # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t) 174 | self.register_buffer('posterior_variance', 175 | to_torch(posterior_variance)) 176 | # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain 177 | self.register_buffer('posterior_log_variance_clipped', to_torch( 178 | np.log(np.maximum(posterior_variance, 1e-20)))) 179 | self.register_buffer('posterior_mean_coef1', to_torch( 180 | betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))) 181 | self.register_buffer('posterior_mean_coef2', to_torch( 182 | (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod))) 183 | 184 | 185 | def predict_start_from_noise(self, x_t, t, noise): 186 | return self.sqrt_recip_alphas_cumprod[t] * x_t - \ 187 | self.sqrt_recipm1_alphas_cumprod[t] * noise 188 | 189 | def predict_eps_from_x(self, x_t, x_0, t): 190 | 191 | eps = (x_t -self.sqrt_alphas_cumprod[t] * x_0) / self.sqrt_one_minus_alphas_cumprod[t] 192 | return eps 193 | 194 | def predict_eps(self, x_t, x_0, continuous_sqrt_alpha_cumprod): 195 | 196 | eps = (1. / (1 - continuous_sqrt_alpha_cumprod **2).sqrt()) * x_t - \ 197 | (1. / (1 - continuous_sqrt_alpha_cumprod**2) -1).sqrt() * x_0 198 | 199 | return eps 200 | 201 | def predict_start(self, x_t, continuous_sqrt_alpha_cumprod, noise): 202 | 203 | return (1. / continuous_sqrt_alpha_cumprod) * x_t - \ 204 | (1. / continuous_sqrt_alpha_cumprod**2 - 1).sqrt() * noise 205 | 206 | def predict_t_minus1(self, x, t, continuous_sqrt_alpha_cumprod, noise, clip_denoised=True): 207 | 208 | x_recon = self.predict_start(x, 209 | continuous_sqrt_alpha_cumprod=continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), 210 | noise=noise) 211 | 212 | if clip_denoised: 213 | x_recon.clamp_(-1., 1.) 214 | 215 | model_mean, model_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t) 216 | 217 | noise_z = torch.randn_like(x) if t > 0 else torch.zeros_like(x) 218 | 219 | return model_mean + noise_z * (0.5 * model_log_variance).exp() 220 | 221 | def q_posterior(self, x_start, x_t, t): 222 | posterior_mean = self.posterior_mean_coef1[t] * \ 223 | x_start + self.posterior_mean_coef2[t] * x_t 224 | posterior_log_variance_clipped = self.posterior_log_variance_clipped[t] 225 | return posterior_mean, posterior_log_variance_clipped 226 | 227 | def p_mean_variance(self, x, t, clip_denoised: bool, condition_x=None): 228 | batch_size = x.shape[0] 229 | noise_level = torch.FloatTensor( 230 | [self.sqrt_alphas_cumprod_prev[t+self.time_scale]]).repeat(batch_size, 1).to(x.device) 231 | 232 | eps = self.denoise_fn(torch.cat([condition_x, x], dim=1), noise_level)[0] 233 | # print(t) 234 | 235 | x_recon = self.predict_start_from_noise(x, t=t*self.time_scale, noise=eps) 236 | 237 | 238 | if clip_denoised: 239 | x_recon.clamp_(-1., 1.) 
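        # x_recon is the network's estimate of the clean image x_0; clamping it to
        # [-1, 1] keeps the posterior mean computed below inside the normalized
        # image range (`transform` at the top of this file maps [-1, 1] back to [0, 1]).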
240 | 241 | model_mean, posterior_log_variance = self.q_posterior( 242 | x_start=x_recon, x_t=x, t=t) 243 | return model_mean, posterior_log_variance, eps 244 | 245 | @torch.no_grad() 246 | def p_sample(self, x, t, clip_denoised=True, condition_x=None): 247 | model_mean, model_log_variance, eps = self.p_mean_variance( 248 | x=x, t=t, clip_denoised=clip_denoised, condition_x=condition_x) 249 | noise = torch.randn_like(x) if t > 0 else torch.zeros_like(x) 250 | return model_mean + noise * (0.5 * model_log_variance).exp() 251 | 252 | @torch.no_grad() 253 | def p_sample_loop(self, x_in, continous=False): 254 | device = self.betas.device 255 | sample_inter = (1 | (self.num_timesteps//10)) 256 | if not self.conditional: 257 | shape = x_in 258 | img = torch.randn(shape, device=device) 259 | ret_img = img 260 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): 261 | img = self.p_sample(img, i) 262 | if i % sample_inter == 0: 263 | ret_img = torch.cat([ret_img, img], dim=0) 264 | else: 265 | x = x_in 266 | shape = x.shape 267 | img = torch.randn(shape, device=device) 268 | ret_img = x 269 | # print(self.time_scale) 270 | # import pdb; pdb.set_trace() 271 | for i in tqdm(reversed(range(0, self.num_timesteps)), desc='sampling loop time step', total=self.num_timesteps): 272 | img = self.p_sample(img, i, condition_x=x) 273 | if i % sample_inter == 0: 274 | ret_img = torch.cat([ret_img, img], dim=0) 275 | if continous: 276 | return ret_img 277 | else: 278 | return ret_img[-1] 279 | 280 | @torch.no_grad() 281 | def sample(self, batch_size=1, continous=False): 282 | image_size = self.image_size 283 | channels = self.channels 284 | return self.p_sample_loop((batch_size, channels, image_size, image_size), continous) 285 | 286 | @torch.no_grad() 287 | def super_resolution(self, x_in, continous=False, stride=1): 288 | return self.ddim(x_in, continous, stride=stride) 289 | 290 | def q_sample(self, x_start, continuous_sqrt_alpha_cumprod, noise=None): 291 | 292 | return ( 293 | continuous_sqrt_alpha_cumprod * x_start + 294 | (1 - continuous_sqrt_alpha_cumprod**2).sqrt() * noise 295 | ) 296 | 297 | def ddim(self, x_in, continous=False, snr_aware=False, stride=1, clip_denoised=True): 298 | x = x_in 299 | condition_x = x_in 300 | x_t = torch.randn(x.shape, device=x.device) 301 | 302 | batch_size = x_in.shape[0] 303 | 304 | for time_step in reversed(range(stride, self.num_timesteps + 1, stride)): 305 | 306 | 307 | t = x_t.new_ones([x_t.shape[0], ], dtype=torch.long) * time_step 308 | s = x_t.new_ones([x_t.shape[0], ], dtype=torch.long) * (time_step - stride) 309 | noise_level = torch.FloatTensor([self.sqrt_alphas_cumprod_prev[t* self.time_scale]]).repeat(batch_size, 1).to(x.device) 310 | eps =self.denoise_fn(torch.cat([condition_x, x_t], dim=1), noise_level)[0] 311 | x_0 = self.predict_start_from_noise(x_t, t * self.time_scale, eps) 312 | if clip_denoised: 313 | x_0 = torch.clip(x_0, -1., 1.) 
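            # Deterministic DDIM-style jump: re-derive the noise consistent with the
            # clipped x_0, then move from step t to s = t - stride via
            #   x_s = sqrt(alpha_bar_s) * x_0 + sqrt(1 - alpha_bar_s) * eps
            # without injecting fresh noise (see the next two statements).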
314 | eps = self.predict_eps_from_x(x_t, x_0, t * self.time_scale) 315 | 316 | x_t = self.sqrt_alphas_cumprod[s * self.time_scale] * x_0 + self.sqrt_one_minus_alphas_cumprod[s * self.time_scale] * eps 317 | 318 | return torch.clip(x_t, -1, 1) 319 | 320 | def SNR_map(self, x_0): 321 | blur_transform = T.GaussianBlur(kernel_size=15, sigma=3) 322 | blur_x_0 = blur_transform(x_0) 323 | gray_blur_x_0 = blur_x_0[:, 0:1, :, :] * 0.299 + blur_x_0[:, 1:2, :, :] * 0.587 + blur_x_0[:, 2:3, :, :] * 0.114 324 | gray_x_0 = x_0[:, 0:1, :, :] * 0.299 + x_0[:, 1:2, :, :] * 0.587 + x_0[:, 2:3, :, :] * 0.114 325 | noise = torch.abs(gray_blur_x_0 - gray_x_0) 326 | 327 | return noise 328 | 329 | 330 | 331 | def loss(self, x_in, student, noise=None, lpips_func=None): 332 | x_0 = x_in['GT'] 333 | [b, c, h, w] = x_0.shape 334 | 335 | t = 2 * np.random.randint(1, student.num_timesteps + 1) 336 | 337 | continuous_sqrt_alpha_cumprod = torch.FloatTensor( 338 | np.random.uniform( 339 | self.sqrt_alphas_cumprod_prev[(t-1)*self.time_scale], 340 | self.sqrt_alphas_cumprod_prev[t*self.time_scale], 341 | size=b 342 | ) 343 | ).to(x_0.device) 344 | 345 | continuous_sqrt_alpha_cumprod_t_mins_1 = torch.FloatTensor( 346 | np.random.uniform( 347 | self.sqrt_alphas_cumprod_prev[(t-2)*self.time_scale], 348 | self.sqrt_alphas_cumprod_prev[(t-1)*self.time_scale], 349 | size=b 350 | ) 351 | ).to(x_0.device) 352 | continuous_sqrt_alpha_cumprod_t_mins_2 = torch.FloatTensor( 353 | np.random.uniform( 354 | self.sqrt_alphas_cumprod_prev[(t-3)*self.time_scale], 355 | self.sqrt_alphas_cumprod_prev[(t-2)*self.time_scale], 356 | size=b 357 | ) 358 | ).to(x_0.device) 359 | 360 | continuous_sqrt_alpha_cumprod = continuous_sqrt_alpha_cumprod.view(b, -1) 361 | continuous_sqrt_alpha_cumprod_t_mins_1 = continuous_sqrt_alpha_cumprod_t_mins_1.view(b, -1) 362 | continuous_sqrt_alpha_cumprod_t_mins_2 = continuous_sqrt_alpha_cumprod_t_mins_2.view(b, -1) 363 | 364 | noise = default(noise, lambda: torch.randn_like(x_0)) 365 | t = torch.tensor([t], dtype=torch.int64).to(x_0.device) 366 | bs = x_0.size(0) 367 | 368 | with torch.no_grad(): 369 | z_t = self.q_sample(x_0, continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), noise) 370 | eps_rec, _ = self.denoise_fn(torch.cat([x_in['LQ'], z_t], dim=1), continuous_sqrt_alpha_cumprod) 371 | x_0_rec = self.predict_start(z_t, continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), eps_rec) 372 | z_t_minus_1 = self.q_sample(x_0_rec, continuous_sqrt_alpha_cumprod_t_mins_1.view(-1, 1, 1, 1), eps_rec) 373 | eps_rec_rec, _ = self.denoise_fn(torch.cat([x_in['LQ'], z_t_minus_1], dim=1), continuous_sqrt_alpha_cumprod_t_mins_1) 374 | x_0_rec_rec = self.predict_start(z_t_minus_1, continuous_sqrt_alpha_cumprod_t_mins_1.view(-1, 1, 1, 1), eps_rec_rec) 375 | z_t_minus_2 = self.q_sample(x_0_rec_rec, continuous_sqrt_alpha_cumprod_t_mins_2.view(-1, 1, 1, 1), eps_rec_rec) 376 | frac = (1 - continuous_sqrt_alpha_cumprod_t_mins_2.view(-1, 1, 1, 1)**2).sqrt() / (1- continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1)**2).sqrt() 377 | if self.w_snr != 0: 378 | y = x_in['LQ'] 379 | T,_=torch.max(y,dim=1, keepdim=True) 380 | T=T+0.1 381 | y = y / T 382 | iso_noise = self.SNR_map(y) 383 | y = y - iso_noise 384 | refine_x_0 = y 385 | z_t_minus_2_refine = self.q_sample(refine_x_0, continuous_sqrt_alpha_cumprod_t_mins_2.view(-1, 1, 1, 1), eps_rec_rec) 386 | z_t_minus_2 = z_t_minus_2 + self.w_snr *(z_t_minus_2_refine - z_t_minus_2) 387 | 388 | x_target = (z_t_minus_2 - frac * z_t) / ( continuous_sqrt_alpha_cumprod_t_mins_2.view(-1, 1, 1, 1) - frac * 
continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1 )) 389 | eps_target = self.predict_eps(z_t, x_target, continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1)) 390 | 391 | eps_predicted, _ = student.denoise_fn(torch.cat([x_in['LQ'], z_t], dim=1), continuous_sqrt_alpha_cumprod) 392 | x_0_predicted = self.predict_start(z_t, continuous_sqrt_alpha_cumprod.view(-1, 1, 1, 1), eps_predicted) 393 | loss_x_0 = torch.mean(F.mse_loss(x_0_predicted, x_target, reduction='none').reshape(bs, -1), dim=-1) 394 | loss_eps = torch.mean(F.mse_loss(eps_predicted, eps_target, reduction='none').reshape(bs, -1), dim=-1) 395 | 396 | loss_stru = torch.zeros_like(loss_x_0) # 0. 397 | if self.w_gt != 0: 398 | loss_output_x0 = torch.mean(F.mse_loss(x_0, x_0_predicted, reduction='none').reshape(bs, -1), dim=-1) 399 | loss_output_eps = torch.mean(F.mse_loss(noise, eps_predicted, reduction='none').reshape(bs, -1), dim=-1) 400 | 401 | else: 402 | loss_output_x0 = torch.zeros_like(loss_x_0) # 0. 403 | loss_output_eps = torch.zeros_like(loss_eps) # 0. 404 | 405 | if self.w_lpips != 0: 406 | loss_lpips = torch.mean(lpips_func(x_0, x_0_predicted)) 407 | else: 408 | loss_lpips = torch.zeros_like(loss_x_0) # 0. 409 | 410 | return torch.mean(torch.maximum(loss_x_0, loss_eps)) + \ 411 | self.w_gt * torch.mean(torch.maximum(loss_output_x0, loss_output_eps)) + \ 412 | self.w_lpips*torch.mean(loss_lpips) + \ 413 | self.w_str*torch.mean(loss_stru) 414 | 415 | 416 | 417 | 418 | 419 | def forward(self, x, s_model=None, *args, **kwargs): 420 | return self.loss(x, s_model, *args, **kwargs) 421 | 422 | 423 | -------------------------------------------------------------------------------- /model/ddpm_modules/unet.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | import torch.fft as fft 4 | from torch import nn 5 | import torch.nn.functional as F 6 | from inspect import isfunction 7 | import cv2 8 | import torchvision.transforms as T 9 | import numpy as np 10 | def exists(x): 11 | return x is not None 12 | 13 | # PositionalEncoding Source: https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py 14 | class PositionalEncoding(nn.Module): 15 | def __init__(self, dim): 16 | super().__init__() 17 | self.dim = dim 18 | 19 | def forward(self, noise_level): 20 | count = self.dim // 2 21 | step = torch.arange(count, dtype=noise_level.dtype, 22 | device=noise_level.device) / count 23 | encoding = noise_level.unsqueeze( 24 | 1) * torch.exp(-math.log(1e4) * step.unsqueeze(0)) 25 | encoding = torch.cat( 26 | [torch.sin(encoding), torch.cos(encoding)], dim=-1) 27 | return encoding 28 | 29 | class FeatureWiseAffine(nn.Module): 30 | def __init__(self, in_channels, out_channels, use_affine_level=False): 31 | super(FeatureWiseAffine, self).__init__() 32 | self.use_affine_level = use_affine_level 33 | self.noise_func = nn.Sequential( 34 | nn.Linear(in_channels, out_channels*(1+self.use_affine_level)) 35 | ) 36 | 37 | def forward(self, x, noise_embed): 38 | batch = x.shape[0] 39 | if self.use_affine_level: 40 | gamma, beta = self.noise_func(noise_embed).view( 41 | batch, -1, 1, 1).chunk(2, dim=1) 42 | x = (1 + gamma) * x + beta 43 | else: 44 | x = x + self.noise_func(noise_embed).view(batch, -1, 1, 1) 45 | return x 46 | 47 | def default(val, d): 48 | if exists(val): 49 | return val 50 | return d() if isfunction(d) else d 51 | 52 | # model 53 | class Swish(nn.Module): 54 | def forward(self, x): 55 | return x * torch.sigmoid(x) 56 | 57 | 58 | class Upsample(nn.Module): 59 | def 
__init__(self, dim): 60 | super().__init__() 61 | self.up = nn.Upsample(scale_factor=2, mode="nearest") 62 | self.conv = nn.Conv2d(dim, dim, 3, padding=1) 63 | 64 | def forward(self, x): 65 | return self.conv(self.up(x)) 66 | 67 | 68 | class Downsample(nn.Module): 69 | def __init__(self, dim): 70 | super().__init__() 71 | self.conv = nn.Conv2d(dim, dim, 3, 2, 1) 72 | 73 | def forward(self, x): 74 | return self.conv(x) 75 | 76 | 77 | # building block modules 78 | class Block(nn.Module): 79 | def __init__(self, dim, dim_out, groups=32, dropout=0): 80 | super().__init__() 81 | self.block = nn.Sequential( 82 | nn.GroupNorm(groups, dim), 83 | Swish(), 84 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(), 85 | nn.Conv2d(dim, dim_out, 3, padding=1) 86 | ) 87 | 88 | def forward(self, x): 89 | return self.block(x) 90 | 91 | 92 | class ResnetBlock(nn.Module): 93 | def __init__(self, dim, dim_out, time_emb_dim=None, dropout=0, norm_groups=32): 94 | super().__init__() 95 | self.mlp = nn.Sequential( 96 | Swish(), 97 | nn.Linear(time_emb_dim, dim_out) 98 | ) if exists(time_emb_dim) else None 99 | self.noise_func = FeatureWiseAffine( 100 | time_emb_dim, dim_out, use_affine_level=False) 101 | 102 | self.block1 = Block(dim, dim_out, groups=norm_groups) 103 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout) 104 | self.res_conv = nn.Conv2d( 105 | dim, dim_out, 1) if dim != dim_out else nn.Identity() 106 | 107 | def forward(self, x, time_emb): 108 | h = self.block1(x) 109 | h = self.noise_func(h, time_emb) 110 | h = self.block2(h) 111 | return h + self.res_conv(x) 112 | 113 | 114 | class SelfAttention(nn.Module): 115 | def __init__(self, in_channel, n_head=1, norm_groups=32): 116 | super().__init__() 117 | 118 | self.n_head = n_head 119 | 120 | self.norm = nn.GroupNorm(norm_groups, in_channel) 121 | self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False) 122 | self.out = nn.Conv2d(in_channel, in_channel, 1) 123 | 124 | def forward(self, input): 125 | batch, channel, height, width = input.shape 126 | n_head = self.n_head 127 | head_dim = channel // n_head 128 | 129 | norm = self.norm(input) 130 | qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width) 131 | query, key, value = qkv.chunk(3, dim=2) # bhdyx 132 | 133 | attn = torch.einsum( 134 | "bnchw, bncyx -> bnhwyx", query, key 135 | ).contiguous() / math.sqrt(channel) 136 | attn = attn.view(batch, n_head, height, width, -1) 137 | attn = torch.softmax(attn, -1) 138 | attn = attn.view(batch, n_head, height, width, height, width) 139 | 140 | out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous() 141 | out = self.out(out.view(batch, channel, height, width)) 142 | 143 | return out + input 144 | 145 | 146 | class ResnetBlocWithAttn(nn.Module): 147 | def __init__(self, dim, dim_out, *, time_emb_dim=None, norm_groups=32, dropout=0, with_attn=False): 148 | super().__init__() 149 | self.with_attn = with_attn 150 | self.res_block = ResnetBlock( 151 | dim, dim_out, time_emb_dim, norm_groups=norm_groups, dropout=dropout) 152 | if with_attn: 153 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups) 154 | 155 | def forward(self, x, time_emb): 156 | x = self.res_block(x, time_emb) 157 | if(self.with_attn): 158 | x = self.attn(x) 159 | return x 160 | 161 | 162 | class UNet(nn.Module): 163 | def __init__( 164 | self, 165 | in_channel=6, 166 | out_channel=3, 167 | inner_channel=32, 168 | norm_groups=32, 169 | channel_mults=(1, 1, 2, 2, 4), 170 | attn_res=(8), 171 | res_blocks=3, 172 | dropout=0, 173 
| with_noise_level_emb=True, 174 | image_size=128 175 | ): 176 | super().__init__() 177 | if with_noise_level_emb: 178 | noise_level_channel = inner_channel 179 | self.noise_level_mlp = nn.Sequential( 180 | PositionalEncoding(inner_channel), 181 | nn.Linear(inner_channel, inner_channel * 4), 182 | Swish(), 183 | nn.Linear(inner_channel * 4, inner_channel) 184 | ) 185 | else: 186 | noise_level_channel = None 187 | self.noise_level_mlp = None 188 | 189 | 190 | num_mults = len(channel_mults) 191 | pre_channel = inner_channel 192 | feat_channels = [pre_channel] 193 | now_res = image_size 194 | downs = [nn.Conv2d(in_channel, inner_channel, 195 | kernel_size=3, padding=1)] 196 | for ind in range(num_mults): 197 | is_last = (ind == num_mults - 1) 198 | use_attn = (now_res in attn_res) 199 | channel_mult = inner_channel * channel_mults[ind] 200 | for _ in range(0, res_blocks): 201 | downs.append(ResnetBlocWithAttn( 202 | pre_channel, channel_mult, time_emb_dim=noise_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn)) 203 | feat_channels.append(channel_mult) 204 | pre_channel = channel_mult 205 | if not is_last: 206 | downs.append(Downsample(pre_channel)) 207 | feat_channels.append(pre_channel) 208 | now_res = now_res//2 209 | self.downs = nn.ModuleList(downs) 210 | 211 | self.mid = nn.ModuleList([ 212 | ResnetBlocWithAttn(pre_channel, pre_channel, time_emb_dim=noise_level_channel, norm_groups=norm_groups, 213 | dropout=dropout, with_attn=True), 214 | ResnetBlocWithAttn(pre_channel, pre_channel, time_emb_dim=noise_level_channel, norm_groups=norm_groups, 215 | dropout=dropout, with_attn=False) 216 | ]) 217 | 218 | ups = [] 219 | for ind in reversed(range(num_mults)): 220 | is_last = (ind < 1) 221 | use_attn = (now_res in attn_res) 222 | channel_mult = inner_channel * channel_mults[ind] 223 | for _ in range(0, res_blocks+1): 224 | ups.append(ResnetBlocWithAttn( 225 | pre_channel+feat_channels.pop(), channel_mult, time_emb_dim=noise_level_channel, dropout=dropout, norm_groups=norm_groups, with_attn=use_attn)) 226 | pre_channel = channel_mult 227 | if not is_last: 228 | ups.append(Upsample(pre_channel)) 229 | now_res = now_res*2 230 | 231 | self.ups = nn.ModuleList(ups) 232 | 233 | self.final_conv = Block(pre_channel, default(out_channel, in_channel), groups=norm_groups) 234 | 235 | self.var_conv = nn.Sequential(*[ 236 | nn.Conv2d(pre_channel, pre_channel, 3, padding=(3//2), bias=True), 237 | nn.ELU(), 238 | nn.Conv2d(pre_channel, pre_channel, 3, padding=(3//2), bias=True), 239 | nn.ELU(), 240 | nn.Conv2d(pre_channel, 3, 3, padding=(3//2), bias=True), 241 | nn.ELU() 242 | ]) 243 | # self.swish = Swish() 244 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 245 | return nn.Conv2d( 246 | in_channels, out_channels, kernel_size, 247 | padding=(kernel_size//2), bias=bias) 248 | 249 | def forward(self, x, noise): 250 | noise_level = self.noise_level_mlp(noise) if exists(self.noise_level_mlp) else None 251 | feats = [] 252 | for layer in self.downs: 253 | if isinstance(layer, ResnetBlocWithAttn): 254 | x = layer(x, noise_level) 255 | else: 256 | x = layer(x) 257 | feats.append(x) 258 | 259 | for layer in self.mid: 260 | if isinstance(layer, ResnetBlocWithAttn): 261 | x = layer(x, noise_level) 262 | else: 263 | x = layer(x) 264 | 265 | for layer in self.ups: 266 | if isinstance(layer, ResnetBlocWithAttn): 267 | x = layer(torch.cat((x, feats.pop()), dim=1), noise_level) 268 | else: 269 | x = layer(x) 270 | return self.final_conv(x), self.var_conv(x) 271 | 272 | 273 
| 274 | # FreeU 275 | def Fourier_filter(x, threshold, scale): 276 | # FFT 277 | x_freq = fft.fftn(x, dim=(-2, -1)) 278 | x_freq = fft.fftshift(x_freq, dim=(-2, -1)) 279 | 280 | B, C, H, W = x_freq.shape 281 | mask = torch.ones((B, C, H, W)).cuda() 282 | 283 | crow, ccol = H // 2, W //2 284 | mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale 285 | x_freq = x_freq * mask 286 | 287 | # IFFT 288 | x_freq = fft.ifftshift(x_freq, dim=(-2, -1)) 289 | x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real 290 | 291 | return x_filtered 292 | 293 | def SNR_filter(x_in, threshold, scale): 294 | 295 | blur_transform = T.GaussianBlur(kernel_size=15, sigma=3) 296 | blur_x_in = blur_transform(x_in) 297 | gray_blur_x_in = blur_x_in[:, 0:1, :, :] * 0.299 + blur_x_in[:, 1:2, :, :] * 0.587 + blur_x_in[:, 2:3, :, :] * 0.114 298 | gray_x_in = x_in[:, 0:1, :, :] * 0.299 + x_in[:, 1:2, :, :] * 0.587 + x_in[:, 2:3, :, :] * 0.114 299 | noise = torch.abs(gray_blur_x_in - gray_x_in) 300 | mask = torch.div(gray_x_in, noise + 0.0001) 301 | 302 | batch_size = mask.shape[0] 303 | height = mask.shape[2] 304 | width = mask.shape[3] 305 | mask_max = torch.max(mask.view(batch_size, -1), dim=1)[0] 306 | mask_max = mask_max.view(batch_size, 1, 1, 1) 307 | mask_max = mask_max.repeat(1, 1, height, width) 308 | mask = mask * 1.0 / (mask_max+0.0001) 309 | mask = torch.clamp(mask, min=0, max=1.0) 310 | mask = mask.float() 311 | 312 | return mask 313 | 314 | 315 | class Free_UNet(UNet): 316 | """ 317 | :param b1: backbone factor of the first stage block of decoder. 318 | :param b2: backbone factor of the second stage block of decoder. 319 | :param s1: skip factor of the first stage block of decoder. 320 | :param s2: skip factor of the second stage block of decoder. 321 | """ 322 | 323 | def __init__( 324 | self, 325 | b1=1.3, 326 | b2=1.4, 327 | s1=0.9, 328 | s2=0.2, 329 | *args, 330 | **kwargs 331 | ): 332 | super().__init__(*args, **kwargs) 333 | self.b1 = b1 334 | self.b2 = b2 335 | self.s1 = s1 336 | self.s2 = s2 337 | 338 | def forward(self, h, noise): 339 | # what we need is only x and noise 340 | """ 341 | Apply the model to an input batch. 342 | :param x: an [N x C x ...] Tensor of inputs. 343 | :param timesteps: a 1-D batch of timesteps. 344 | :param context: conditioning plugged in via crossattn 345 | :param y: an [N] Tensor of labels, if class-conditional. 346 | :return: an [N x C x ...] Tensor of outputs. 
347 | """ 348 | hs = [] 349 | noise_level = self.noise_level_mlp(noise) if exists(self.noise_level_mlp) else None 350 | 351 | 352 | 353 | for layer in self.downs: 354 | if isinstance(layer, ResnetBlocWithAttn): 355 | h = layer(h, noise_level) 356 | else: 357 | h = layer(h) 358 | hs.append(h) 359 | 360 | for layer in self.mid: 361 | if isinstance(layer, ResnetBlocWithAttn): 362 | h = layer(h, noise_level) 363 | else: 364 | h = layer(h) 365 | 366 | for layer in self.ups: 367 | # --------------- FreeU code ----------------------- 368 | # Only operate on the first two stages 369 | if h.shape[1] == 256: 370 | hs_ = hs.pop() 371 | hidden_mean = h.mean(1).unsqueeze(1) 372 | B = hidden_mean.shape[0] 373 | hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 374 | hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) 375 | hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) 376 | 377 | h[:,:128] = h[:,:128] * ((self.b1 - 1 ) * hidden_mean + 1) 378 | hs_ = Fourier_filter(hs_, threshold=1, scale=self.s1) 379 | hs.append(hs_) 380 | if h.shape[1] == 128: 381 | hs_ = hs.pop() 382 | hidden_mean = h.mean(1).unsqueeze(1) 383 | B = hidden_mean.shape[0] 384 | hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 385 | hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) 386 | hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) 387 | 388 | h[:,:64] = h[:,:64] * ((self.b2 - 1 ) * hidden_mean + 1) 389 | hs_ = Fourier_filter(hs_, threshold=1, scale=self.s2) 390 | hs.append(hs_) 391 | # --------------------------------------------------------- 392 | # exit(print(h.shape, hs_.shape)) # torch.Size([1, 256, 26, 38]) torch.Size([1, 256, 26, 38]) 393 | # print(h.shape, hs_.shape) 394 | # h = torch.cat((h, hs_), dim=1) 395 | 396 | if isinstance(layer, ResnetBlocWithAttn): 397 | h = layer(torch.cat((h, hs.pop()), dim=1), noise_level) 398 | else: 399 | h = layer(h) 400 | return self.final_conv(h), self.var_conv(h) 401 | 402 | class LAUNet(UNet): 403 | """ 404 | :param b1: backbone factor of the first stage block of decoder. 405 | :param b2: backbone factor of the second stage block of decoder. 406 | :param s1: skip factor of the first stage block of decoder. 407 | :param s2: skip factor of the second stage block of decoder. 408 | """ 409 | 410 | def __init__( 411 | self, 412 | b1=1.3, 413 | b2=1.4, 414 | s1=0.9, 415 | s2=0.2, 416 | *args, 417 | **kwargs 418 | ): 419 | super().__init__(*args, **kwargs) 420 | self.b1 = b1 421 | self.b2 = b2 422 | self.s1 = s1 423 | self.s2 = s2 424 | 425 | def forward(self, h, noise): 426 | # what we need is only x and noise 427 | """ 428 | Apply the model to an input batch. 429 | :param x: an [N x C x ...] Tensor of inputs. 430 | :param timesteps: a 1-D batch of timesteps. 431 | :param context: conditioning plugged in via crossattn 432 | :param y: an [N] Tensor of labels, if class-conditional. 433 | :return: an [N x C x ...] Tensor of outputs. 
434 | """ 435 | hs = [] 436 | noise_level = self.noise_level_mlp(noise) if exists(self.noise_level_mlp) else None 437 | 438 | 439 | 440 | for layer in self.downs: 441 | if isinstance(layer, ResnetBlocWithAttn): 442 | h = layer(h, noise_level) 443 | else: 444 | h = layer(h) 445 | hs.append(h) 446 | 447 | for layer in self.mid: 448 | if isinstance(layer, ResnetBlocWithAttn): 449 | h = layer(h, noise_level) 450 | else: 451 | h = layer(h) 452 | 453 | for layer in self.ups: 454 | # --------------- FreeU code ----------------------- 455 | # Only operate on the first two stages 456 | if h.shape[1] == 256: 457 | hs_ = hs.pop() 458 | hidden_mean = h.mean(1).unsqueeze(1) 459 | B = hidden_mean.shape[0] 460 | hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 461 | hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) 462 | hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) 463 | 464 | h[:,:128] = h[:,:128] * ((self.b1 - 1 ) * hidden_mean + 1) 465 | hs_ = Fourier_filter(hs_, threshold=1, scale=self.s1) 466 | hs.append(hs_) 467 | if h.shape[1] == 128: 468 | hs_ = hs.pop() 469 | hidden_mean = h.mean(1).unsqueeze(1) 470 | B = hidden_mean.shape[0] 471 | hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 472 | hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) 473 | hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) 474 | 475 | h[:,:64] = h[:,:64] * ((self.b2 - 1 ) * hidden_mean + 1) 476 | hs_ = Fourier_filter(hs_, threshold=1, scale=self.s2) 477 | hs.append(hs_) 478 | # --------------------------------------------------------- 479 | # exit(print(h.shape, hs_.shape)) # torch.Size([1, 256, 26, 38]) torch.Size([1, 256, 26, 38]) 480 | # print(h.shape, hs_.shape) 481 | # h = torch.cat((h, hs_), dim=1) 482 | 483 | if isinstance(layer, ResnetBlocWithAttn): 484 | h = layer(torch.cat((h, hs.pop()), dim=1), noise_level) 485 | else: 486 | h = layer(h) 487 | return self.final_conv(h), self.var_conv(h) 488 | 489 | class LAUNet(UNet): 490 | """ 491 | :param b1: backbone factor of the first stage block of decoder. 492 | :param b2: backbone factor of the second stage block of decoder. 493 | :param s1: skip factor of the first stage block of decoder. 494 | :param s2: skip factor of the second stage block of decoder. 495 | """ 496 | 497 | def __init__( 498 | self, 499 | b1=1.3, 500 | b2=1.4, 501 | s1=0.9, 502 | s2=0.2, 503 | *args, 504 | **kwargs 505 | ): 506 | super().__init__(*args, **kwargs) 507 | self.b1 = b1 508 | self.b2 = b2 509 | self.s1 = s1 510 | self.s2 = s2 511 | 512 | def forward(self, h, noise): 513 | # what we need is only x and noise 514 | """ 515 | Apply the model to an input batch. 516 | :param x: an [N x C x ...] Tensor of inputs. 517 | :param timesteps: a 1-D batch of timesteps. 518 | :param context: conditioning plugged in via crossattn 519 | :param y: an [N] Tensor of labels, if class-conditional. 520 | :return: an [N x C x ...] Tensor of outputs. 
521 | """ 522 | hs = [] 523 | noise_level = self.noise_level_mlp(noise) if exists(self.noise_level_mlp) else None 524 | 525 | 526 | 527 | for layer in self.downs: 528 | # print(h.shape) 529 | if isinstance(layer, ResnetBlocWithAttn): 530 | h = layer(h, noise_level) 531 | else: 532 | h = layer(h) 533 | hs.append(h) 534 | # print("\n") 535 | for layer in self.mid: 536 | # print(h.shape) 537 | if isinstance(layer, ResnetBlocWithAttn): 538 | h = layer(h, noise_level) 539 | else: 540 | h = layer(h) 541 | # print("\n") 542 | for layer in self.ups: 543 | # print(h.shape) 544 | # --------------- FreeU code ----------------------- 545 | # Only operate on the first two stages 546 | if h.shape[1] == 256: 547 | hs_ = hs.pop() 548 | hidden_mean = h.mean(1).unsqueeze(1) 549 | B = hidden_mean.shape[0] 550 | hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 551 | hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) 552 | hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) 553 | 554 | h[:,:64] = h[:,:64] * ((self.b1 - 1 ) * hidden_mean + 1) 555 | hs_ = Fourier_filter(hs_, threshold=1, scale=self.s1) 556 | hs.append(hs_) 557 | # if h.shape[1] == 128 and h.shape[2] == 104: 558 | # hs_ = hs.pop() 559 | # hidden_mean = h.mean(1).unsqueeze(1) 560 | # B = hidden_mean.shape[0] 561 | # hidden_max, _ = torch.max(hidden_mean.view(B, -1), dim=-1, keepdim=True) 562 | # hidden_min, _ = torch.min(hidden_mean.view(B, -1), dim=-1, keepdim=True) 563 | # hidden_mean = (hidden_mean - hidden_min.unsqueeze(2).unsqueeze(3)) / (hidden_max - hidden_min).unsqueeze(2).unsqueeze(3) 564 | 565 | # h[:,:64] = h[:,:64] * ((self.b2 - 1 ) * hidden_mean + 1) 566 | # hs_ = Fourier_filter(hs_, threshold=1, scale=self.s2) 567 | # hs.append(hs_) 568 | # --------------------------------------------------------- 569 | # exit(print(h.shape, hs_.shape)) # torch.Size([1, 256, 26, 38]) torch.Size([1, 256, 26, 38]) 570 | # print(h.shape, hs_.shape) 571 | # h = torch.cat((h, hs_), dim=1) 572 | if isinstance(layer, ResnetBlocWithAttn): 573 | h = layer(torch.cat((h, hs.pop()), dim=1), noise_level) 574 | else: 575 | h = layer(h) 576 | 577 | # exit(-1) 578 | return self.final_conv(h), self.var_conv(h) -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | import os 6 | import model.networks as networks 7 | from .base_model import BaseModel 8 | from torch.nn.parallel import DistributedDataParallel as DDP 9 | from utils.ema import EMA 10 | from torch.optim import lr_scheduler 11 | import lpips 12 | 13 | logger = logging.getLogger('base') 14 | skip_para = [] 15 | 16 | skip_para = ['betas', 'alphas_cumprod', 'alphas_cumprod_prev', 'sqrt_alphas_cumprod', 17 | 'sqrt_one_minus_alphas_cumprod', 'log_one_minus_alphas_cumprod', 'sqrt_recip_alphas_cumprod', 18 | 'sqrt_recipm1_alphas_cumprod', 'posterior_variance', 'posterior_log_variance_clipped', 19 | 'posterior_mean_coef1', 'posterior_mean_coef2',] 20 | 21 | def get_scheduler(optimizer, opt): 22 | if opt['train']["optimizer"]['lr_policy'] == 'linear': 23 | def lambda_rule(iteration): 24 | lr_l = 1.0 - max(0, iteration-opt['train']["optimizer"]["n_lr_iters"]) / float(opt['train']["optimizer"]["lr_decay_iters"] + 1) 25 | return lr_l 26 | scheduler = 
lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
27 |     elif opt['train']["optimizer"]['lr_policy'] == 'step':
28 |         scheduler = lr_scheduler.StepLR(optimizer, step_size=opt['train']["optimizer"]["lr_decay_iters"], gamma=0.8)
29 |     elif opt['train']["optimizer"]['lr_policy'] == 'plateau':
30 |         scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
31 |     elif opt['train']["optimizer"]['lr_policy'] == 'cosine':
32 |         scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=opt['n_epochs'], eta_min=0)  # config must provide n_epochs for the cosine policy
33 |     else:
34 |         raise NotImplementedError('learning rate policy [{:s}] is not implemented'.format(opt['train']["optimizer"]['lr_policy']))
35 |     return scheduler
36 | 
37 | 
38 | class DDPM(BaseModel):
39 |     def __init__(self, opt):
40 |         super(DDPM, self).__init__(opt)
41 | 
42 |         if opt['dist']:
43 |             self.local_rank = torch.distributed.get_rank()
44 |             torch.cuda.set_device(self.local_rank)
45 |             device = torch.device("cuda", self.local_rank)
46 |         # define network and load pretrained models
47 |         self.netG = self.set_device(networks.define_G(opt, student=False))
48 |         if opt['dist']:
49 |             self.netG.to(device)
50 | 
53 |         self.schedule_phase = None
54 |         self.opt = opt
55 | 
56 |         # set loss and load resume state
57 |         self.set_loss()
58 | 
59 |         if self.opt['phase'] == 'train':
60 |             self.netG.train()
61 |             # find the parameters to optimize
62 |             if opt['model']['finetune_norm']:
63 |                 optim_params = []
64 |                 for k, v in self.netG.named_parameters():
65 |                     v.requires_grad = False
66 |                     if k.find('transformer') >= 0:
67 |                         v.requires_grad = True
68 |                         v.data.zero_()
69 |                         optim_params.append(v)
70 |                         logger.info(
71 |                             'Params [{:s}] initialized to 0 and will optimize.'.format(k))
72 |             else:
73 |                 optim_params = list(self.netG.parameters())
74 | 
75 |             self.optG = torch.optim.Adam(
76 |                 optim_params, lr=opt['train']["optimizer"]["lr"])
77 |             self.log_dict = OrderedDict()
78 | 
79 |         if self.opt['phase'] == 'test':
80 |             self.netG.load_state_dict(torch.load(self.opt['path']['resume_state']), strict=True)
81 | 
82 |         else:
83 |             self.load_network()
84 |             if opt['dist']:
85 |                 self.netG = DDP(self.netG, device_ids=[self.local_rank], output_device=self.local_rank, find_unused_parameters=True)
86 | 
87 | 
88 |     def feed_data(self, data):
89 |         dic = {}
90 |         if self.opt['dist']:
91 |             dic['LQ'] = data['LQ'].to(self.local_rank)
92 |             dic['GT'] = data['GT'].to(self.local_rank)
93 |             self.data = dic
94 |         else:
95 |             dic['LQ'] = data['LQ']
96 |             dic['GT'] = data['GT']
97 |             self.data = self.set_device(dic)
98 | 
104 |     def test(self, continous=False):
105 |         self.netG.eval()
106 |         with torch.no_grad():
107 |             if isinstance(self.netG, nn.DataParallel):
108 |                 self.SR = self.netG.module.super_resolution(
109 |                     self.data['LQ'], continous)
111 |             else:
112 |                 if self.opt['dist']:
113 |                     self.SR = self.netG.module.super_resolution(self.data['LQ'], continous)
114 |                 else:
115 |                     self.SR = self.netG.super_resolution(self.data['LQ'], continous)
116 | 
117 |         self.netG.train()
118 | 
119 |     def sample(self, batch_size=1, continous=False):
120 |         self.netG.eval()
121 |         with torch.no_grad():
122 |             if isinstance(self.netG, nn.DataParallel):
123 |                 self.SR = self.netG.module.sample(batch_size, continous)
124 |             else:
125 |                 self.SR = self.netG.sample(batch_size, continous)
126 |         self.netG.train()
127 | 
128 |     def set_loss(self):
129 |         if isinstance(self.netG, nn.DataParallel):
130 |             self.netG.module.set_loss(self.device)
131 |         else:
132 |             self.netG.set_loss(self.device)
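    # Note: the beta tables built by set_new_noise_schedule below hold
    # num_timesteps * time_scale + 1 entries, so halving n_timestep while
    # doubling time_scale (as networks.define_G does for the distilled
    # student) keeps the same fine-grained schedule but samples fewer steps.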
133 | 
134 |     def set_new_noise_schedule(self, schedule_opt, schedule_phase='train'):
136 |         if self.opt['dist']:
138 |             device = torch.device("cuda", self.local_rank)
139 |             if self.schedule_phase is None or self.schedule_phase != schedule_phase:
140 |                 self.schedule_phase = schedule_phase
141 |                 if isinstance(self.netG, nn.DataParallel):
142 |                     self.netG.module.set_new_noise_schedule(
143 |                         schedule_opt, self.device)
144 |                 else:
145 |                     self.netG.set_new_noise_schedule(schedule_opt, device)  # pass the rank-local device
146 | 
147 |         else:
148 |             self.schedule_phase = schedule_phase
149 |             if isinstance(self.netG, nn.DataParallel):
150 |                 self.netG.module.set_new_noise_schedule(
151 |                     schedule_opt, self.device)
152 |             else:
153 |                 self.netG.set_new_noise_schedule(schedule_opt, self.device)
155 | 
156 | 
157 |     def get_current_log(self):
158 |         return self.log_dict
159 | 
160 |     def get_current_visuals(self, need_LR=True, sample=False):
161 |         out_dict = OrderedDict()
162 |         if sample:
163 |             out_dict['SAM'] = self.SR.detach().float().cpu()
164 |         else:
165 |             out_dict['HQ'] = self.SR.detach().float().cpu()
166 |             out_dict['INF'] = self.data['LQ'].detach().float().cpu()
167 |             out_dict['GT'] = self.data['GT'].detach()[0].float().cpu()
168 |             if need_LR and 'LR' in self.data:
169 |                 out_dict['LQ'] = self.data['LQ'].detach().float().cpu()
170 |             else:
171 |                 out_dict['LQ'] = out_dict['INF']
172 |         return out_dict
173 | 
174 |     def print_network(self):
175 |         s, n = self.get_network_description(self.netG)
176 |         if isinstance(self.netG, nn.DataParallel):
177 |             net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
178 |                                              self.netG.module.__class__.__name__)
179 |         else:
180 |             net_struc_str = '{}'.format(self.netG.__class__.__name__)
181 | 
182 |         logger.info(s)
183 | 
184 |         logger.info(
185 |             'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
186 | 
187 |     def save_network(self, distill_step, epoch, iter_step):
188 |         save_root = os.path.join(self.opt['path']['checkpoint'],
189 |                                  'num_step_{}'.format(distill_step))
190 |         os.makedirs(save_root, exist_ok=True)
191 |         gen_path = os.path.join(save_root, 'I{}_E{}_gen.pth'.format(iter_step, epoch))
192 | 
193 |         # gen
194 |         network = self.netG
195 |         if isinstance(self.netG, nn.DataParallel):
196 |             network = network.module
197 |         state_dict = network.state_dict()
198 |         for key, param in state_dict.items():
199 |             state_dict[key] = param.cpu()
200 |         torch.save(state_dict, gen_path)
201 | 
204 |         logger.info(
205 |             'Saved model in [{:s}] ...'.format(gen_path))
206 | 
207 |     def load_network(self):
208 |         load_path = self.opt['path']['resume_state']
209 |         if load_path is not None:
210 |             logger.info(
211 |                 'Loading pretrained model for G [{:s}] ...'.format(load_path))
212 |             gen_path = '{}'.format(load_path)
213 | 
214 |             # gen
215 |             network = self.netG
216 |             if isinstance(network, nn.DataParallel):
217 |                 network = network.module
218 |             ckpt = torch.load(gen_path)
219 |             current_state_dict = network.state_dict()
220 |             for name, param in ckpt.items():
221 |                 if name in skip_para:
222 |                     continue  # noise-schedule buffers are rebuilt, not loaded
223 |                 current_state_dict[name] = param
224 | 
231 |             network.load_state_dict(current_state_dict, strict=False)
232 |             if self.opt['phase'] == 'train':
233 |                 self.begin_step = 0
234 |                 self.begin_epoch = 0
235 | 
236 | 
237 | 
238 | 
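For orientation, a minimal sketch of how the progressive-distillation wrapper defined below is typically driven. The `opt` dict and `loader` here are hypothetical placeholders: `opt` is assumed to come from parsing one of the YAMLs in `config/` with `opt['distill']` set, so that `model.create_model` returns `DDPM_PD`:

```python
from model import create_model

model = create_model(opt)        # DDPM_PD: frozen teacher netG_t + trainable student netG_s

for batch in loader:             # each batch is a dict with 'LQ' and 'GT' image tensors
    model.feed_data(batch)
    model.optimize_parameters()  # teacher-guided loss, gradient clipping, EMA update
    log = model.get_current_log()  # {'total_loss': ...}

model.test(continous=False)      # samples with the student's EMA (shadow) weights
enhanced = model.get_current_visuals()['HQ']
```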
239 | class DDPM_PD(BaseModel): 240 | def __init__(self, opt): 241 | super(DDPM_PD, self).__init__(opt) 242 | 243 | if opt['dist']: 244 | self.local_rank = torch.distributed.get_rank() 245 | torch.cuda.set_device(self.local_rank) 246 | device = torch.device("cuda", self.local_rank) 247 | # define network and load pretrained models 248 | self.netG_t = self.set_device(networks.define_G(opt, student=False)) 249 | if opt['CD'] : 250 | self.netG_s = self.set_device(networks.define_G(opt, student=False)) 251 | else: 252 | self.netG_s = self.set_device(networks.define_G(opt, student=True)) 253 | if opt['dist']: 254 | self.netG_t.to(device) 255 | self.netG_s.to(device) 256 | 257 | # self.netG.to(device) 258 | 259 | 260 | self.schedule_phase = None 261 | self.opt = opt 262 | 263 | # set loss and load resume state 264 | 265 | self.set_loss() 266 | self.lpips = lpips.LPIPS(net='vgg').cuda() 267 | 268 | # self.set_new_noise_schedule(opt['model']['beta_schedule']['train'], schedule_phase='train') 269 | 270 | if self.opt['phase'] == 'train': 271 | self.netG_s.train() 272 | # find the parameters to optimize 273 | 274 | if opt['model']['finetune_norm']: 275 | optim_params = [] 276 | for k, v in self.netG_s.named_parameters(): 277 | v.requires_grad = False 278 | if k.find('transformer') >= 0: 279 | v.requires_grad = True 280 | v.data.zero_() 281 | optim_params.append(v) 282 | logger.info( 283 | 'Params [{:s}] initialized to 0 and will optimize.'.format(k)) 284 | else: 285 | optim_params = list(self.netG_s.parameters()) 286 | 287 | self.optG = torch.optim.Adam(optim_params, lr=opt['train']["optimizer"]["lr"]) 288 | self.scheduler = get_scheduler(self.optG, opt) 289 | self.log_dict = OrderedDict() 290 | 291 | 292 | 293 | 294 | if self.opt['phase'] == 'test': 295 | self.netG_s.load_state_dict(torch.load(self.opt['path']['resume_state']), strict=True) 296 | else: 297 | self.load_network() 298 | # self.netG_t.load_state_dict(torch.load(self.opt['path']['resume_state']), strict=True) 299 | if opt['dist']: 300 | self.netG_s = DDP(self.netG_s, device_ids=[self.local_rank], output_device=self.local_rank,find_unused_parameters=True) 301 | self.netG_t = DDP(self.netG_t, device_ids=[self.local_rank], output_device=self.local_rank,find_unused_parameters=True) 302 | for p in self.netG_t.parameters(): 303 | p.requires_grad_(False) 304 | self.netG_t.eval() 305 | self.netG_t.CD = opt['CD'] 306 | self.netG_s.CD = opt['CD'] 307 | 308 | # self.print_network() 309 | 310 | # define ema 311 | self.ema_decay = opt['train']["ema_scheduler"]["ema_decay"] 312 | if self.opt['dist']: 313 | self.ema_student = EMA( 314 | self.netG_s.module, 315 | decay = self.ema_decay, # exponential moving average factor 316 | ) 317 | else: 318 | self.ema_student = EMA( 319 | self.netG_s, 320 | decay = self.ema_decay, # exponential moving average factor 321 | ) 322 | 323 | self.ema_student.register() 324 | 325 | def feed_data(self, data): 326 | 327 | dic = {} 328 | 329 | if self.opt['dist']: 330 | dic = {} 331 | dic['LQ'] = data['LQ'].to(self.local_rank) 332 | dic['GT'] = data['GT'].to(self.local_rank) 333 | self.data = dic 334 | else: 335 | dic['LQ'] = data['LQ'] 336 | dic['GT'] = data['GT'] 337 | 338 | self.data = self.set_device(dic) 339 | 340 | def optimize_parameters(self): 341 | 342 | self.optG.zero_grad() 343 | if self.opt['dist']: 344 | l_pd = self.netG_t(self.data, self.netG_s.module, lpips_func=self.lpips) 345 | else: 346 | l_pd = self.netG_t(self.data, self.netG_s, lpips_func=self.lpips) 347 | # print("to be debug") 348 | # import pdb; 
pdb.set_trace() 349 | # 350 | 351 | 352 | loss = l_pd 353 | # print(l_pd) 354 | # import pdb; pdb.set_trace() 355 | loss.backward() 356 | torch.nn.utils.clip_grad_norm_(self.netG_s.parameters(), 1) 357 | self.optG.step() 358 | self.scheduler.step() 359 | # print( self.optG.param_groups[0]['lr']) 360 | self.ema_student.update() 361 | # set log 362 | self.log_dict['total_loss'] = loss.item() 363 | 364 | 365 | 366 | 367 | def test(self, continous=False, stride=1): 368 | self.ema_student.apply_shadow() # apply shadow weights here 369 | self.netG_s.eval() 370 | with torch.no_grad(): 371 | if isinstance(self.netG_s, nn.DataParallel): 372 | self.SR = self.netG_s.module.super_resolution( 373 | self.data['LQ'], continous, stride) 374 | 375 | else: 376 | if self.opt['dist']: 377 | self.SR = self.netG_s.module.super_resolution(self.data['LQ'], continous, stride) 378 | else: 379 | self.SR = self.netG_s.super_resolution(self.data['LQ'], continous, stride) 380 | self.ema_student.restore()# restore shadow weights here 381 | 382 | self.netG_s.train() 383 | 384 | def sample(self, batch_size=1, continous=False): 385 | self.ema_student.apply_shadow() # apply shadow weights here 386 | self.netG_s.eval() 387 | with torch.no_grad(): 388 | if isinstance(self.netG_s, nn.DataParallel): 389 | self.SR = self.netG_s.module.sample(batch_size, continous) 390 | else: 391 | self.SR = self.netG_s.sample(batch_size, continous) 392 | self.ema_student.restore()# restore shadow weights here 393 | self.netG_s.train() 394 | 395 | def set_loss(self): 396 | if isinstance(self.netG_s, nn.DataParallel): 397 | self.netG_s.module.set_loss(self.device) 398 | else: 399 | self.netG_s.set_loss(self.device) 400 | 401 | 402 | def get_current_log(self): 403 | return self.log_dict 404 | 405 | def get_current_visuals(self, need_LR=True, sample=False): 406 | out_dict = OrderedDict() 407 | if sample: 408 | out_dict['SAM'] = self.SR.detach().float().cpu() 409 | else: 410 | out_dict['HQ'] = self.SR.detach().float().cpu() 411 | out_dict['INF'] = self.data['LQ'].detach().float().cpu() 412 | out_dict['GT'] = self.data['GT'].detach()[0].float().cpu() 413 | if need_LR and 'LR' in self.data: 414 | out_dict['LQ'] = self.data['LQ'].detach().float().cpu() 415 | else: 416 | out_dict['LQ'] = out_dict['INF'] 417 | return out_dict 418 | 419 | def print_network(self): 420 | s, n = self.get_network_description(self.netG_s) 421 | if isinstance(self.netG_s, nn.DataParallel): 422 | net_struc_str = '{} - {}'.format(self.netG_s.__class__.__name__, 423 | self.netG_s.module.__class__.__name__) 424 | else: 425 | net_struc_str = '{}'.format(self.netG_s.__class__.__name__) 426 | 427 | logger.info(s) 428 | 429 | logger.info( 430 | 'Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 431 | 432 | def save_network(self, distill_step, epoch, iter_step, psnr, ssim, lpips): 433 | save_root = os.path.join(self.opt['path']['checkpoint'], 'num_step_{}'.format(distill_step)) 434 | os.makedirs(save_root, exist_ok=True) 435 | # gen_path = os.path.join(save_root, 'P{:.4e}_S{:.4e}_I{}_E{}_gen.pth'.format(psnr, ssim, iter_step, epoch)) 436 | # opt_path = os.path.join(save_root, 'P{:.4e}_S{:.4e}_I{}_E{}_opt.pth'.format(psnr, ssim, iter_step, epoch)) 437 | ema_path = os.path.join(save_root, 'psnr{:.4f}_ssim{:.4f}_lpips{:.4f}_I{}_E{}_gen_ema.pth'.format(psnr, ssim, lpips, iter_step, epoch)) 438 | 439 | # gen 440 | # network = self.netG_s 441 | # if isinstance(self.netG_s, nn.DataParallel): 442 | # network = network.module 443 | # state_dict = network.state_dict() 444 
| # for key, param in state_dict.items(): 445 | # state_dict[key] = param.cpu() 446 | # torch.save(state_dict, gen_path) 447 | 448 | # opt 449 | 450 | 451 | 452 | # ema 453 | self.ema_student.apply_shadow() 454 | network = self.ema_student.model 455 | if isinstance(self.ema_student.model, nn.DataParallel): 456 | network = network.module 457 | ema_ckpt = network.state_dict() 458 | for key, param in ema_ckpt.items(): 459 | ema_ckpt[key] = param.cpu() 460 | torch.save(ema_ckpt, ema_path) 461 | self.ema_student.restore() 462 | # logger.info( 463 | # 'Saved model in [{:s}] ...'.format(gen_path)) 464 | logger.info( 465 | 'Saved model in [{:s}] ...'.format(ema_path)) 466 | return ema_path # gen_path 467 | 468 | def load_network(self): 469 | load_path = self.opt['path']['resume_state'] 470 | if load_path is not None: 471 | logger.info( 472 | 'Loading pretrained model for G [{:s}] ...'.format(load_path)) 473 | gen_path = '{}'.format(load_path) 474 | 475 | # gen 476 | networks = [self.netG_t, self.netG_s] 477 | for network in networks: 478 | if isinstance(network, nn.DataParallel): 479 | network = network.module 480 | ckpt = torch.load(gen_path) 481 | 482 | current_state_dict = network.state_dict() 483 | for name, param in ckpt.items(): 484 | if name in skip_para: 485 | continue 486 | 487 | else: 488 | current_state_dict[name] = param 489 | 490 | network.load_state_dict(current_state_dict, strict=False) 491 | 492 | if self.opt['phase'] == 'train': 493 | 494 | self.begin_step = 0 495 | self.begin_epoch = 0 496 | -------------------------------------------------------------------------------- /model/networks.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import logging 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn import init 6 | from torch.nn import modules 7 | logger = logging.getLogger('base') 8 | #################### 9 | # initialize 10 | #################### 11 | 12 | 13 | def weights_init_normal(m, std=0.02): 14 | classname = m.__class__.__name__ 15 | if classname.find('Conv') != -1: 16 | init.normal_(m.weight.data, 0.0, std) 17 | if m.bias is not None: 18 | m.bias.data.zero_() 19 | elif classname.find('Linear') != -1: 20 | init.normal_(m.weight.data, 0.0, std) 21 | if m.bias is not None: 22 | m.bias.data.zero_() 23 | elif classname.find('BatchNorm2d') != -1: 24 | init.normal_(m.weight.data, 1.0, std) # BN also uses norm 25 | init.constant_(m.bias.data, 0.0) 26 | 27 | 28 | def weights_init_kaiming(m, scale=1): 29 | classname = m.__class__.__name__ 30 | if classname.find('Conv2d') != -1: 31 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 32 | m.weight.data *= scale 33 | if m.bias is not None: 34 | m.bias.data.zero_() 35 | elif classname.find('Linear') != -1: 36 | init.kaiming_normal_(m.weight.data, a=0, mode='fan_in') 37 | m.weight.data *= scale 38 | if m.bias is not None: 39 | m.bias.data.zero_() 40 | elif classname.find('BatchNorm2d') != -1: 41 | init.constant_(m.weight.data, 1.0) 42 | init.constant_(m.bias.data, 0.0) 43 | 44 | 45 | def weights_init_orthogonal(m): 46 | classname = m.__class__.__name__ 47 | if classname.find('Conv') != -1: 48 | init.orthogonal_(m.weight.data, gain=1) 49 | if m.bias is not None: 50 | m.bias.data.zero_() 51 | elif classname.find('Linear') != -1: 52 | init.orthogonal_(m.weight.data, gain=1) 53 | if m.bias is not None: 54 | m.bias.data.zero_() 55 | elif classname.find('BatchNorm2d') != -1: 56 | init.constant_(m.weight.data, 1.0) 57 | init.constant_(m.bias.data, 0.0) 58 | 59 | 60 
| def init_weights(net, init_type='kaiming', scale=1, std=0.02): 61 | # scale for 'kaiming', std for 'normal'. 62 | logger.info('Initialization method [{:s}]'.format(init_type)) 63 | if init_type == 'normal': 64 | weights_init_normal_ = functools.partial(weights_init_normal, std=std) 65 | net.apply(weights_init_normal_) 66 | elif init_type == 'kaiming': 67 | weights_init_kaiming_ = functools.partial( 68 | weights_init_kaiming, scale=scale) 69 | net.apply(weights_init_kaiming_) 70 | elif init_type == 'orthogonal': 71 | net.apply(weights_init_orthogonal) 72 | else: 73 | raise NotImplementedError( 74 | 'initialization method [{:s}] not implemented'.format(init_type)) 75 | 76 | 77 | #################### 78 | # define network 79 | #################### 80 | 81 | 82 | # Generator 83 | def define_G(opt, student=False): 84 | model_opt = opt['model'] 85 | print(model_opt['which_model_G']) 86 | if model_opt['which_model_G'] == 'ddpm': 87 | from .ddpm_modules import diffusion, unet 88 | if ('norm_groups' not in model_opt['unet']) or model_opt['unet']['norm_groups'] is None: 89 | model_opt['unet']['norm_groups']=32 90 | 91 | 92 | model = unet.UNet( 93 | in_channel=model_opt['unet']['in_channel'], 94 | out_channel=model_opt['unet']['out_channel'], 95 | norm_groups=model_opt['unet']['norm_groups'], 96 | inner_channel=model_opt['unet']['inner_channel'], 97 | channel_mults=model_opt['unet']['channel_multiplier'], 98 | attn_res=model_opt['unet']['attn_res'], 99 | res_blocks=model_opt['unet']['res_blocks'], 100 | dropout=model_opt['unet']['dropout'], 101 | image_size=model_opt['diffusion']['image_size'] 102 | ) 103 | ''' 104 | if opt['freq_aware']: 105 | model = unet.Free_UNet( 106 | in_channel=model_opt['unet']['in_channel'], 107 | out_channel=model_opt['unet']['out_channel'], 108 | norm_groups=model_opt['unet']['norm_groups'], 109 | inner_channel=model_opt['unet']['inner_channel'], 110 | channel_mults=model_opt['unet']['channel_multiplier'], 111 | attn_res=model_opt['unet']['attn_res'], 112 | res_blocks=model_opt['unet']['res_blocks'], 113 | dropout=model_opt['unet']['dropout'], 114 | image_size=model_opt['diffusion']['image_size'], 115 | b1=opt['freq_awareUNet']['b1'], 116 | b2=opt['freq_awareUNet']['b2'], 117 | s1=opt['freq_awareUNet']['s1'], 118 | s2=opt['freq_awareUNet']['s2'] 119 | ) 120 | ''' 121 | 122 | # print(model_opt['beta_schedule']['train']['n_timestep']) 123 | if student: 124 | netG = diffusion.GaussianDiffusion( 125 | model, 126 | image_size=model_opt['diffusion']['image_size'], 127 | num_timesteps=model_opt['beta_schedule']['train']['n_timestep'] // 2, 128 | time_scale=model_opt['beta_schedule']['train']['time_scale'] * 2, 129 | channels=model_opt['diffusion']['channels'], 130 | w_gt= model_opt['diffusion']['w_gt'], 131 | w_snr= model_opt['diffusion']['w_snr'], 132 | w_str= model_opt['diffusion']['w_str'], 133 | w_lpips= model_opt['diffusion']['w_lpips'], 134 | loss_type='l1', 135 | conditional=model_opt['diffusion']['conditional'], 136 | schedule_opt=model_opt['beta_schedule']['train']) 137 | else: 138 | 139 | netG = diffusion.GaussianDiffusion( 140 | model, 141 | image_size=model_opt['diffusion']['image_size'], 142 | num_timesteps=model_opt['beta_schedule']['train']['n_timestep'] , 143 | time_scale=model_opt['beta_schedule']['train']['time_scale'], 144 | channels=model_opt['diffusion']['channels'], 145 | w_gt= model_opt['diffusion']['w_gt'], 146 | w_snr= model_opt['diffusion']['w_snr'], 147 | w_str= model_opt['diffusion']['w_str'], 148 | w_lpips= model_opt['diffusion']['w_lpips'], 149 | 
loss_type='l1', 150 | conditional=model_opt['diffusion']['conditional'], 151 | schedule_opt=model_opt['beta_schedule']['train']) 152 | 153 | if opt['phase'] == 'train': 154 | # init_weights(netG, init_type='kaiming', scale=0.1) 155 | init_weights(netG, init_type='orthogonal') 156 | if opt['gpu_ids'] and opt['distributed']: 157 | assert torch.cuda.is_available() 158 | netG = nn.DataParallel(netG) 159 | return netG 160 | 161 | 162 | # Generator 163 | def define_GGG(opt): 164 | model_opt = opt['model'] 165 | print(model_opt['which_model_G']) 166 | if model_opt['which_model_G'] == 'ddpm': 167 | from .ddpm_modules import diffusion, unet 168 | if ('norm_groups' not in model_opt['unet']) or model_opt['unet']['norm_groups'] is None: 169 | model_opt['unet']['norm_groups']=32 170 | model = unet.UNet( 171 | in_channel=model_opt['unet']['in_channel'], 172 | out_channel=model_opt['unet']['out_channel'], 173 | norm_groups=model_opt['unet']['norm_groups'], 174 | inner_channel=model_opt['unet']['inner_channel'], 175 | channel_mults=model_opt['unet']['channel_multiplier'], 176 | attn_res=model_opt['unet']['attn_res'], 177 | res_blocks=model_opt['unet']['res_blocks'], 178 | dropout=model_opt['unet']['dropout'], 179 | image_size=model_opt['diffusion']['image_size'] 180 | ) 181 | netGVar = diffusion.GaussianDiffusion( 182 | model, 183 | image_size=model_opt['diffusion']['image_size'], 184 | channels=model_opt['diffusion']['channels'], 185 | loss_type='l1', 186 | conditional=model_opt['diffusion']['conditional'], 187 | schedule_opt=model_opt['beta_schedule']['train'] 188 | ) 189 | if opt['phase'] == 'train': 190 | # init_weights(netG, init_type='kaiming', scale=0.1) 191 | init_weights(netGVar, init_type='orthogonal') 192 | if opt['gpu_ids'] and opt['distributed']: 193 | assert torch.cuda.is_available() 194 | netGVar = nn.DataParallel(netGVar) 195 | return netGVar -------------------------------------------------------------------------------- /options/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /options/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/options/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /options/__pycache__/options.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/options/__pycache__/options.cpython-38.pyc -------------------------------------------------------------------------------- /options/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import yaml 5 | from utils.util import OrderedYaml 6 | 7 | Loader, Dumper = OrderedYaml() 8 | 9 | 10 | def parse(opt_path, is_train=True): 11 | with open(opt_path, mode='r') as f: 12 | opt = yaml.load(f, Loader=Loader) 13 | # export CUDA_VISIBLE_DEVICES 14 | gpu_list = ','.join(str(x) for x in opt.get('gpu_ids', [])) 15 | # os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 16 | # print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 17 | opt['is_train'] = is_train 18 | 19 | # datasets 20 | for phase, dataset in opt['datasets'].items(): 21 | phase = 
phase.split('_')[0] 22 | dataset['phase'] = phase 23 | 24 | is_lmdb = False 25 | if dataset.get('dataroot_GT', None) is not None: 26 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) 27 | if dataset['dataroot_GT'].endswith('lmdb'): 28 | is_lmdb = True 29 | if dataset.get('dataroot_LQ', None) is not None: 30 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) 31 | if dataset['dataroot_LQ'].endswith('lmdb'): 32 | is_lmdb = True 33 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img' 34 | 35 | # relative learning rate 36 | if 'train' in opt: 37 | niter = opt['train']['niter'] 38 | if 'T_period_rel' in opt['train']: 39 | opt['train']['T_period'] = [int(x * niter) for x in opt['train']['T_period_rel']] 40 | if 'restarts_rel' in opt['train']: 41 | opt['train']['restarts'] = [int(x * niter) for x in opt['train']['restarts_rel']] 42 | if 'lr_steps_rel' in opt['train']: 43 | opt['train']['lr_steps'] = [int(x * niter) for x in opt['train']['lr_steps_rel']] 44 | if 'lr_steps_inverse_rel' in opt['train']: 45 | opt['train']['lr_steps_inverse'] = [int(x * niter) for x in opt['train']['lr_steps_inverse_rel']] 46 | print(opt['train']) 47 | 48 | return opt 49 | 50 | 51 | def dict2str(opt, indent_l=1): 52 | '''dict to string for logger''' 53 | msg = '' 54 | for k, v in opt.items(): 55 | if isinstance(v, dict): 56 | msg += ' ' * (indent_l * 2) + k + ':[\n' 57 | msg += dict2str(v, indent_l + 1) 58 | msg += ' ' * (indent_l * 2) + ']\n' 59 | else: 60 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 61 | return msg 62 | 63 | 64 | class NoneDict(dict): 65 | def __missing__(self, key): 66 | return None 67 | 68 | 69 | # convert to NoneDict, which return None for missing key. 70 | def dict_to_nonedict(opt): 71 | if isinstance(opt, dict): 72 | new_opt = dict() 73 | for key, sub_opt in opt.items(): 74 | new_opt[key] = dict_to_nonedict(sub_opt) 75 | return NoneDict(**new_opt) 76 | elif isinstance(opt, list): 77 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 78 | else: 79 | return opt 80 | 81 | 82 | def check_resume(opt, resume_iter): 83 | '''Check resume states and pretrain_model paths''' 84 | logger = logging.getLogger('base') 85 | if opt['path']['resume_state']: 86 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( 87 | 'pretrain_model_D', None) is not None: 88 | logger.warning('pretrain_model path will be ignored when resuming training.') 89 | 90 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], 91 | '{}_G.pth'.format(resume_iter)) 92 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 93 | if 'gan' in opt['model']: 94 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], 95 | '{}_D.pth'.format(resume_iter)) 96 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) 97 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python==4.5.2.54 2 | PyYAML==6.0 3 | natsort==8.1.0 4 | scikit-image==0.18.1 5 | lpips==0.1.4 6 | kmeans_pytorch 7 | scikit-learn==1.0 8 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import basename 3 | import math 4 | import argparse 5 | import random 6 | import logging 7 | import cv2 8 | import sys 9 | import numpy as np 10 | import torch 11 | import 
torch.distributed as dist 12 | import torch.multiprocessing as mp 13 | import torch.nn as nn 14 | 15 | import options.options as option 16 | from utils import util 17 | from data import create_dataloader 18 | from data.LoL_dataset import LOLv1_Dataset, LOLv2_Dataset 19 | import torchvision.transforms as T 20 | import lpips 21 | import model as Model 22 | import core.logger as Logger 23 | import core.metrics as Metrics 24 | from torchvision import transforms 25 | 26 | 27 | transform = transforms.Lambda(lambda t: (t * 2) - 1) 28 | 29 | def main(): 30 | 31 | parser = argparse.ArgumentParser() 32 | 33 | parser.add_argument('--dataset', type=str, help='Path to option YAML file.', 34 | default='./config/lolv1.yml') # 35 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', 36 | help='job launcher') 37 | parser.add_argument('--local_rank', type=int, default=0) 38 | parser.add_argument('--tfboard', action='store_true') 39 | 40 | 41 | parser.add_argument('-c', '--config', type=str, default='config/lolv1_test.json', 42 | help='JSON file for configuration') 43 | parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'], 44 | help='Run either train(training) or val(generation)', default='train') 45 | parser.add_argument('-gpu', '--gpu_ids', type=str, default="0") 46 | parser.add_argument('-debug', '-d', action='store_true') 47 | parser.add_argument('-log_eval', action='store_true') 48 | 49 | # search noise schedule 50 | parser.add_argument('--brutal_search', action='store_true') 51 | parser.add_argument('--noise_start', type=float, default=9e-4) 52 | parser.add_argument('--noise_end', type=float, default=8.5e-1) 53 | parser.add_argument('--n_timestep', type=int, default=16) 54 | 55 | parser.add_argument('--w_str', type=float, default=0.) 56 | parser.add_argument('--w_snr', type=float, default=0.) 57 | parser.add_argument('--w_gt', type=float, default=0.1) 58 | parser.add_argument('--w_lpips', type=float, default=0.1) 59 | 60 | parser.add_argument('--stride', type=int, default=1) 61 | 62 | 63 | # for freq_aware 64 | parser.add_argument('--freq_aware', action='store_true') 65 | parser.add_argument('--b1', type=float, default=1.6) 66 | parser.add_argument('--b2', type=float, default=1.6) 67 | parser.add_argument('--s1', type=float, default=0.9) 68 | parser.add_argument('--s2', type=float, default=0.9) 69 | 70 | args = parser.parse_args() 71 | opt = Logger.parse(args) 72 | opt = Logger.dict_to_nonedict(opt) 73 | opt_dataset = option.parse(args.dataset, is_train=True) 74 | 75 | 76 | 77 | os.environ['CUDA_VISIBLE_DEVICES'] = '0' 78 | 79 | opt['phase'] = 'test' 80 | opt['distill'] = False 81 | opt['uncertainty_train'] = False 82 | 83 | #### distributed training settings 84 | opt['dist'] = False 85 | rank = -1 86 | print('Disabled distributed training.') 87 | 88 | #### mkdir and loggers 89 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) 90 | # config loggers.
Before it, the log will not work 91 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, 92 | screen=True, tofile=True) 93 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, 94 | screen=True, tofile=True) 95 | logger = logging.getLogger('base') 96 | logger.info(option.dict2str(opt)) 97 | 98 | # convert to NoneDict, which returns None for missing keys 99 | opt = option.dict_to_nonedict(opt) 100 | 101 | #### seed 102 | seed = opt['seed'] 103 | if seed is None: 104 | seed = random.randint(1, 10000) 105 | if rank <= 0: 106 | logger.info('Seed: {}'.format(seed)) 107 | util.set_random_seed(seed) 108 | 109 | torch.backends.cudnn.benchmark = True 110 | # torch.backends.cudnn.deterministic = True 111 | 112 | #### create train and val dataloader 113 | if opt_dataset['dataset'] == 'LOLv1': 114 | dataset_cls = LOLv1_Dataset 115 | elif opt_dataset['dataset'] == 'LOLv2': 116 | dataset_cls = LOLv2_Dataset 117 | 118 | else: 119 | raise NotImplementedError() 120 | 121 | for phase, dataset_opt in opt_dataset['datasets'].items(): 122 | if phase == 'val': 123 | val_set = dataset_cls(opt=dataset_opt, train=False, all_opt=opt_dataset) 124 | val_loader = create_dataloader(val_set, dataset_opt, opt_dataset, None) 125 | 126 | # opt["model"]['beta_schedule']["train"]["time_scale"] = 1 127 | 128 | opt["model"]["diffusion"]["w_snr"] = args.w_snr 129 | opt["model"]["diffusion"]["w_str"] = args.w_str 130 | opt["model"]["diffusion"]["w_gt"] = args.w_gt 131 | 132 | # model 133 | diffusion = Model.create_model(opt) 134 | logger.info('Initial Model Finished') 135 | 136 | 137 | 138 | loss_fn_vgg = lpips.LPIPS(net='vgg').cuda() 139 | result_path = '{}'.format(opt['path']['results']) 140 | result_path_gt = result_path+'/gt/' 141 | result_path_out = result_path+'/output/' 142 | result_path_input = result_path+'/input/' 143 | os.makedirs(result_path_gt, exist_ok=True) 144 | os.makedirs(result_path_out, exist_ok=True) 145 | os.makedirs(result_path_input, exist_ok=True) 146 | 147 | #diffusion.set_new_noise_schedule( 148 | #opt['model']['beta_schedule']['val'], schedule_phase='val') 149 | 150 | 151 | logger_val = logging.getLogger('val') # validation logger 152 | 153 | avg_psnr = 0.0 154 | avg_ssim = 0.0 155 | avg_lpips = 0.0 156 | idx = 0 157 | lpipss = [] 158 | 159 | for val_data in val_loader: 160 | 161 | idx += 1 162 | diffusion.feed_data(val_data) 163 | diffusion.test(continous=False) 164 | 165 | visuals = diffusion.get_current_visuals() 166 | 167 | normal_img = Metrics.tensor2img(visuals['HQ']) 168 | if normal_img.shape[0] != normal_img.shape[1]: # lolv1 and lolv2-real 169 | normal_img = normal_img[8:408, 4:604,:] 170 | gt_img = Metrics.tensor2img(visuals['GT']) 171 | ll_img = Metrics.tensor2img(visuals['LQ']) 172 | 173 | img_mode = 'single' 174 | if img_mode == 'single': 175 | util.save_img( 176 | gt_img, '{}/{}_gt.png'.format(result_path_gt, idx)) 177 | util.save_img( 178 | ll_img, '{}/{}_lq.png'.format(result_path_input, idx)) 179 | # util.save_img( 180 | # normal_img, '{}/{}_normal_noadjust.png'.format(result_path, idx)) 181 | else: 182 | util.save_img( 183 | gt_img, '{}/{}_gt.png'.format(result_path, idx)) 184 | util.save_img( 185 | normal_img, '{}/{}_normal_process.png'.format(result_path, idx)) 186 | # for i in range(visuals['HQ'].shape[0]): 187 | # util.save_img(Metrics.tensor2img(visuals['HQ'][i]), '{}/{}_{}_normal.png'.format(result_path, idx, i)) 188 | # util.save_img( 189 | # Metrics.tensor2img(visuals['HQ'][-1]),
'{}/{}_normal.png'.format(result_path, idx)) 190 | normal_img = Metrics.tensor2img(visuals['HQ'][-1]) 191 | 192 | # Similar to LLFlow, we follow a similar way of 'Kind' to finetune the overall brightness 193 | # as illustrated in Line 73 (https://github.com/zhangyhuaee/KinD/blob/master/evaluate_LOLdataset.py). 194 | gt_img = gt_img / 255. 195 | normal_img = normal_img / 255. 196 | mean_gray_out = cv2.cvtColor(normal_img.astype(np.float32), cv2.COLOR_BGR2GRAY).mean() 197 | mean_gray_gt = cv2.cvtColor(gt_img.astype(np.float32), cv2.COLOR_BGR2GRAY).mean() 198 | normal_img_adjust = np.clip(normal_img * (mean_gray_gt / mean_gray_out), 0, 1) 199 | 200 | normal_img = (normal_img_adjust * 255).astype(np.uint8) 201 | gt_img = (gt_img * 255).astype(np.uint8) 202 | 203 | psnr = util.calculate_psnr(normal_img, gt_img) 204 | ssim = util.calculate_ssim(normal_img, gt_img) 205 | 206 | normal_img_tensor = torch.tensor(normal_img.astype(np.float32)) 207 | gt_img_tensor = torch.tensor(gt_img.astype(np.float32)) 208 | normal_img_tensor = normal_img_tensor.permute(2, 0, 1).cuda() 209 | gt_img_tensor = gt_img_tensor.permute(2, 0, 1).cuda() 210 | lpips_scores = loss_fn_vgg(normal_img_tensor, gt_img_tensor).item() 211 | 212 | util.save_img(normal_img, '{}/{}_normal.png'.format(result_path_out, idx)) 213 | 214 | # lpips 215 | 216 | # lpips_ = loss_fn_vgg(visuals['HQ'], visuals['GT']) 217 | # lpipss.append(lpips_scores.numpy()) 218 | 219 | logger_val.info('### {} cPSNR: {:.4e} cSSIM: {:.4e} cLPIPS: {:.4e}'.format(idx, psnr, ssim, lpips_scores)) 220 | avg_ssim += ssim 221 | avg_psnr += psnr 222 | avg_lpips += lpips_scores 223 | 224 | avg_psnr = avg_psnr / idx 225 | avg_ssim = avg_ssim / idx 226 | avg_lpips = avg_lpips / idx 227 | 228 | # log 229 | logger_val.info('# Validation # avgPSNR: {:.4e} avgSSIM: {:.4e} avgLPIPS: {:.4e}'.format(avg_psnr, avg_ssim, avg_lpips)) 230 | # logger_val.info(f"n_timestep: {args.n_timestep}, noise_start: {args.noise_start}, noise_end: {args.noise_end}") 231 | 232 | 233 | 234 | if __name__ == '__main__': 235 | main() 236 | -------------------------------------------------------------------------------- /test.sh: -------------------------------------------------------------------------------- 1 | python test.py --dataset ./config/lolv2_real.yml --config config/lolv2_real_test.json --w_str 0.9 --w_snr 0.2 --w_gt 0.2 2 | 3 | 4 | -------------------------------------------------------------------------------- /test_unpaired.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import basename 3 | import math 4 | import argparse 5 | import random 6 | import logging 7 | import cv2 8 | import torch 9 | import torchvision.transforms.functional as TF 10 | from PIL import Image 11 | import options.options as option 12 | from utils import util 13 | import torchvision.transforms as T 14 | import model as Model 15 | import core.logger as Logger 16 | import core.metrics as Metrics 17 | import natsort 18 | from torchvision import transforms 19 | from utils.niqe import niqe 20 | 21 | transform = transforms.Lambda(lambda t: (t * 2) - 1) 22 | 23 | def main(): 24 | #### options 25 | parser = argparse.ArgumentParser() 26 | 27 | parser.add_argument('--dataset', type=str, help='Path to option YAML file.', 28 | default='./config/dataset.yml') # 29 | parser.add_argument('--input', type=str, help='testing the unpaired image', 30 | default='images/unpaired/') 31 | parser.add_argument('--launcher', choices=['none',
'pytorch'], default='none', 32 | help='job launcher') 33 | parser.add_argument('--local_rank', type=int, default=0) 34 | parser.add_argument('--tfboard', action='store_true') 35 | parser.add_argument('-c', '--config', type=str, default='config/test_unpaired.json', 36 | help='JSON file for configuration') 37 | parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'], 38 | help='Run either train(training) or val(generation)', default='train') 39 | parser.add_argument('-gpu', '--gpu_ids', type=str, default="0") 40 | parser.add_argument('-debug', '-d', action='store_true') 41 | parser.add_argument('-enable_wandb', action='store_true') 42 | parser.add_argument('-log_wandb_ckpt', action='store_true') 43 | parser.add_argument('-log_eval', action='store_true') 44 | 45 | parser.add_argument('--n_timestep', type=int, default=8) 46 | parser.add_argument('--w_str', type=float, default=0.) 47 | parser.add_argument('--w_snr', type=float, default=0.) 48 | parser.add_argument('--w_gt', type=float, default=0.) 49 | parser.add_argument('--w_lpips', type=float, default=0.) 50 | 51 | 52 | parser.add_argument('--brutal_search', action='store_true') 53 | 54 | # parse configs 55 | args = parser.parse_args() 56 | opt = Logger.parse(args) 57 | # Convert to NoneDict, which returns None for missing keys. 58 | opt = Logger.dict_to_nonedict(opt) 59 | 60 | # os.environ['CUDA_VISIBLE_DEVICES'] = '0' 61 | 62 | opt['phase'] = 'test' 63 | 64 | #### distributed training settings 65 | opt['dist'] = False 66 | rank = -1 67 | print('Disabled distributed training.') 68 | 69 | #### mkdir and loggers 70 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) 71 | # config loggers. Before it, the log will not work 72 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True, tofile=True) 73 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, 74 | screen=True, tofile=True) 75 | logger = logging.getLogger('base') 76 | logger.info(option.dict2str(opt)) 77 | 78 | 79 | 80 | 81 | # convert to NoneDict, which returns None for missing keys 82 | opt = option.dict_to_nonedict(opt) 83 | 84 | #### random seed 85 | seed = opt['train']['manual_seed'] 86 | if seed is None: 87 | seed = random.randint(1, 10000) 88 | if rank <= 0: 89 | logger.info('Random seed: {}'.format(seed)) 90 | util.set_random_seed(seed) 91 | 92 | torch.backends.cudnn.benchmark = True 93 | # torch.backends.cudnn.deterministic = True 94 | 95 | 96 | # model 97 | diffusion = Model.create_model(opt) 98 | logger.info('Initial Model Finished') 99 | 100 | result_path = '{}'.format(opt['path']['results']) 101 | os.makedirs(result_path, exist_ok=True) 102 | 103 | # diffusion.set_new_noise_schedule( 104 | # opt['model']['beta_schedule']['val'], schedule_phase='val') 105 | 106 | InputPath = args.input 107 | Image_names = natsort.natsorted(os.listdir(InputPath), alg=natsort.ns.PATH) 108 | 109 | ave_niqe = 0.
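# NIQE is a no-reference quality metric, so unpaired inputs are scored without any
# ground truth: each enhanced output is rated individually and the scores are
# averaged over the folder below. A minimal sketch of the same scoring logic, with
# a hypothetical enhance() standing in for the diffusion model:
#
#   scores = [niqe(Metrics.tensor2img(enhance(img))) for img in images]
#   mean_niqe = sum(scores) / len(scores)
#
# Since no reference image exists, feed_data() in the loop below receives the
# low-light image as both the 'LQ' input and a placeholder 'GT'.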
110 | 111 | for i in range(len(Image_names)): 112 | 113 | path = InputPath + Image_names[i] 114 | raw_img = Image.open(path).convert('RGB') 115 | img_w = raw_img.size[0] 116 | img_h = raw_img.size[1] 117 | raw_img = transforms.Resize((img_h // 16 * 16, img_w // 16 * 16))(raw_img) # round H, W down to multiples of 16 so the network's downsampling stages divide evenly 118 | 119 | raw_img = transform(TF.to_tensor(raw_img)).unsqueeze(0).cuda() 120 | 121 | val_data = {} 122 | val_data['LQ'] = raw_img 123 | val_data['GT'] = raw_img 124 | diffusion.feed_data(val_data) 125 | diffusion.test(continous=False) 126 | 127 | visuals = diffusion.get_current_visuals() 128 | 129 | normal_img = Metrics.tensor2img(visuals['HQ']) 130 | # normal_img = cv2.resize(normal_img, (img_w, img_h)) 131 | ll_img = Metrics.tensor2img(visuals['LQ']) 132 | niqe_scores = niqe(normal_img) 133 | ave_niqe = ave_niqe + niqe_scores 134 | llie_img_mode = 'single' 135 | if llie_img_mode == 'single': 136 | # util.save_img(ll_img, '{}/{}_input.png'.format(result_path, idx)) 137 | util.save_img( 138 | normal_img, '{}/{}_normal.png'.format(result_path, i+1)) 139 | else: 140 | util.save_img( 141 | normal_img, '{}/{}_normal_process.png'.format(result_path, i)) 142 | util.save_img( 143 | Metrics.tensor2img(visuals['HQ'][-1]), '{}/{}_normal.png'.format(result_path, i)) 144 | normal_img = Metrics.tensor2img(visuals['HQ'][-1]) 145 | logger.info('CNIQE: {} on {}'.format(niqe_scores, Image_names[i])) 146 | ave_niqe = ave_niqe / len(Image_names) 147 | logger.info('NIQE: {} on {}'.format(ave_niqe, InputPath)) 148 | 149 | 150 | 151 | if __name__ == '__main__': 152 | main() 153 | print("finish!") 154 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | from os.path import basename 3 | import math 4 | import argparse 5 | import logging 6 | import cv2 7 | import sys 8 | import numpy as np 9 | import torch 10 | import torch.distributed as dist 11 | import torch.multiprocessing as mp 12 | import torch.nn as nn 13 | from torchvision import transforms 14 | import options.options as option 15 | from utils import util 16 | from data import create_dataloader 17 | import data as Data 18 | from data.LoL_dataset import LOLv1_Dataset, LOLv2_Dataset 19 | from data.SDSD_image_dataset import Dataset_SDSDImage 20 | from data.SID import ImageDataset2 21 | import torchvision.transforms as T 22 | import model as Model 23 | import core.logger as Logger 24 | import core.metrics as Metrics 25 | import random 26 | import lpips 27 | 28 | import pdb 29 | 30 | 31 | 32 | 33 | def init_dist(backend='nccl', **kwargs): 34 | """initialization for distributed training""" 35 | if mp.get_start_method(allow_none=True) != 'spawn': 36 | mp.set_start_method('spawn') 37 | rank = int(os.environ['RANK']) 38 | num_gpus = torch.cuda.device_count() 39 | torch.cuda.set_device(rank % num_gpus) 40 | dist.init_process_group(backend=backend, **kwargs) 41 | 42 | 43 | def main(): 44 | 45 | parser = argparse.ArgumentParser() 46 | parser.add_argument('--dataset', type=str, help='Path to option YAML file.', 47 | default='./config/lolv2_real.yml') # 48 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', 49 | help='job launcher') 50 | parser.add_argument('--local_rank', type=int, default=0) 51 | 52 | parser.add_argument('-c', '--config', type=str, default='config/lolv1_train.json', 53 | help='JSON file for configuration') 54 | parser.add_argument('-p', '--phase', type=str, choices=['train', 'val'], 55 | help='Run either
train(training) or val(generation)', default='train') 56 | parser.add_argument('-gpu', '--gpu_ids', type=str, default=None) 57 | parser.add_argument('-debug', '-d', action='store_true') 58 | parser.add_argument('-log_eval', action='store_true') 59 | parser.add_argument('-uncertainty', action='store_true') 60 | 61 | # for ablation 62 | parser.add_argument('--ablation', action='store_true') 63 | parser.add_argument('--w_str', type=float, default=0.2) 64 | parser.add_argument('--w_snr', type=float, default=0.9) 65 | parser.add_argument('--w_gt', type=float, default=0.2) 66 | parser.add_argument('--w_lpips', type=float, default=0.2) 67 | 68 | parser.add_argument('--progressive', action='store_true') 69 | parser.add_argument('--CD', action='store_true') 70 | 71 | 72 | 73 | parser.add_argument('--brutal_search', action='store_true') 74 | 75 | # ema config 76 | parser.add_argument('--ema_decay', type=float, default=0.999) 77 | 78 | # parse configs 79 | args = parser.parse_args() 80 | opt = Logger.parse(args) 81 | # Convert to NoneDict, which returns None for missing keys. 82 | opt = Logger.dict_to_nonedict(opt) 83 | opt_dataset = option.parse(args.dataset, is_train=True) 84 | 85 | if args.ablation: 86 | opt["model"]["diffusion"]["w_snr"] = args.w_snr 87 | opt["model"]["diffusion"]["w_str"] = args.w_str 88 | opt["model"]["diffusion"]["w_gt"] = args.w_gt 89 | opt["model"]["diffusion"]["w_lpips"] = args.w_lpips 90 | if args.CD: 91 | opt["CD"] = True 92 | if args.launcher == 'none': # disabled distributed training 93 | opt['dist'] = False 94 | rank = -1 95 | print('Disabled distributed training.') 96 | else: 97 | opt['dist'] = True 98 | init_dist() # init_dist() already calls dist.init_process_group('nccl'), so the process group is initialized exactly once 99 | 100 | 101 | 102 | 103 | world_size = torch.distributed.get_world_size() 104 | rank = torch.distributed.get_rank() 105 | device = torch.device("cuda", rank) 106 | 107 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) 108 | 109 | # config loggers. Before it, the log will not work 110 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, 111 | screen=True, tofile=True) 112 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, 113 | screen=True, tofile=True) 114 | logger = logging.getLogger('base') 115 | # import pdb; pdb.set_trace() 116 | logger.info(option.dict2str(opt)) 117 | 118 | # import pdb; pdb.set_trace() 119 | 120 | # tensorboard logger 121 | if opt.get('use_tb_logger', False) and 'debug' not in opt['name']: 122 | version = float(torch.__version__[0:3]) 123 | if version >= 1.1: # PyTorch 1.1 124 | # from torch.utils.tensorboard import SummaryWriter 125 | if sys.platform != 'win32': 126 | from tensorboardX import SummaryWriter 127 | else: 128 | from tensorboardX import SummaryWriter 129 | # from torch.utils.tensorboard import SummaryWriter 130 | else: 131 | logger.info( 132 | 'You are using PyTorch {}.
Tensorboard will use [tensorboardX]'.format(version)) 133 | from tensorboardX import SummaryWriter 134 | conf_name = basename(args.dataset).replace(".yml", "") 135 | exp_dir = opt['path']['experiments_root'] 136 | log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train') 137 | log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid') 138 | tb_logger_train = SummaryWriter(log_dir=log_dir_train) 139 | tb_logger_valid = SummaryWriter(log_dir=log_dir_valid) 140 | else: 141 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) 142 | logger = logging.getLogger('base') 143 | 144 | # convert to NoneDict, which returns None for missing keys 145 | opt = option.dict_to_nonedict(opt) 146 | 147 | #### random seed 148 | seed = opt['seed'] 149 | if seed is None: 150 | seed = random.randint(1, 10000) 151 | if rank <= 0: 152 | logger.info('Random seed: {}'.format(seed)) 153 | util.set_random_seed(seed) 154 | 155 | torch.backends.cudnn.benchmark = True 156 | torch.backends.cudnn.deterministic = False 157 | torch.backends.cudnn.allow_tf32 = True 158 | 159 | #### create train and val dataloader 160 | if opt_dataset['dataset'] == 'LOLv1': 161 | dataset_cls = LOLv1_Dataset 162 | PD_steps = [16, 8, 4, 2, 1] 163 | temp_time_scale = [1, 2, 4, 8, 16] 164 | time_scale = [i * 32 for i in temp_time_scale] 165 | elif opt_dataset['dataset'] == 'LOLv2': 166 | dataset_cls = LOLv2_Dataset 167 | PD_steps = [16, 8, 4, 2, 1] 168 | temp_time_scale = [1, 2, 4, 8, 16] 169 | time_scale = [i * 32 for i in temp_time_scale] 170 | elif opt_dataset['dataset'] == 'SDSD_indoor': 171 | dataset_cls = Dataset_SDSDImage 172 | PD_steps = [16, 8, 4, 2, 1] 173 | temp_time_scale = [1, 2, 4, 8, 16] 174 | time_scale = [i * 32 for i in temp_time_scale] 175 | 176 | elif opt_dataset['dataset'] == 'SID': 177 | dataset_cls = ImageDataset2 178 | PD_steps = [16, 8, 4, 2, 1] 179 | temp_time_scale = [1, 2, 4, 8, 16] 180 | time_scale = [i * 32 for i in temp_time_scale] 181 | else: 182 | raise NotImplementedError() 183 | 184 | for phase, dataset_opt in opt_dataset['datasets'].items(): 185 | if phase == 'train': 186 | train_set = dataset_cls(opt=dataset_opt, train=True, all_opt=opt_dataset) 187 | train_loader = create_dataloader(train_set, dataset_opt, opt_dataset, None) 188 | elif phase == 'val': 189 | val_set = dataset_cls(opt=dataset_opt, train=False, all_opt=opt_dataset) 190 | val_loader = create_dataloader(val_set, dataset_opt, opt_dataset, None) 191 | 192 | # model 193 | resume_state = opt["path"]["resume_state"] 194 | lpips_func = lpips.LPIPS(net='vgg').cuda() 195 | 196 | for i in range(len(PD_steps)): 197 | opt["model"]['beta_schedule']["train"]["n_timestep"] = PD_steps[i] 198 | opt["model"]['beta_schedule']["val"]["n_timestep"] = PD_steps[i+1] 199 | 200 | opt["path"]["resume_state"] = resume_state 201 | opt["model"]['beta_schedule']["train"]["time_scale"] = time_scale[i] 202 | logger.info('Distillation from {:d} to {:d}'.format(opt["model"]['beta_schedule']["train"]["n_timestep"], opt["model"]['beta_schedule']["val"]["n_timestep"])) 203 | logger.info(f"w_snr: {opt['model']['diffusion']['w_snr']}, w_str: {opt['model']['diffusion']['w_str']}") 204 | 205 | diffusion = Model.create_model(opt) 206 | 207 | logger.info('Initial Model Finished') 208 | # Train 209 | current_step = diffusion.begin_step 210 | current_epoch = diffusion.begin_epoch 211 | n_iter = opt['train']['n_iter'] # * iter_scale[i] 212 | # training 213 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(current_epoch, current_step))
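# A note on the schedule set up above: each distillation round halves the number of
# sampling steps while doubling time_scale (define_G with student=True in
# model/networks.py applies exactly this halving/doubling when building the student),
# so the product n_timestep * time_scale is invariant across rounds:
#
#   for steps, scale in zip(PD_steps, time_scale):
#       assert steps * scale == 512   # 16*32, 8*64, 4*128, 2*256, 1*512
#
# keeping the effective length of the underlying noise schedule fixed while the
# number of sampling steps shrinks.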
214 | avg_psnr = 0 215 | best_psnr = 0 216 | best_ssim = 0 217 | best_lpips = 0 218 | 219 | # pdb.set_trace() 220 | while current_step < n_iter: 221 | 222 | current_epoch += 1 223 | for _, train_data in enumerate(train_loader): 224 | 225 | current_step += 1 226 | if current_step > n_iter: 227 | break 228 | 229 | diffusion.feed_data(train_data) 230 | diffusion.optimize_parameters() 231 | # log 232 | if current_step % opt['train']['print_freq'] == 0 and rank <= 0: 233 | logs = diffusion.get_current_log() 234 | message = '<epoch: {:3d}, iter: {:8,d}> '.format( 235 | current_epoch, current_step) 236 | for k, v in logs.items(): 237 | message += '{:s}: {:.4e} '.format(k, v) 238 | logger.info(message) 239 | 240 | # validation 241 | if current_step % opt['train']['val_freq'] == 0 and rank <= 0: 242 | 243 | avg_psnr = 0.0 244 | avg_ssim = 0.0 245 | avg_lpips = 0.0 246 | idx = 0 247 | 248 | result_path = '{}/{}'.format(opt['path']['results'], PD_steps[i+1]) 249 | result_path_gt = result_path+'/gt/' 250 | result_path_out = result_path+'/output/' 251 | result_path_input = result_path+'/input/' 252 | 253 | os.makedirs(result_path_gt, exist_ok=True) 254 | os.makedirs(result_path_out, exist_ok=True) 255 | os.makedirs(result_path_input, exist_ok=True) 256 | 257 | 258 | for val_data in val_loader: 259 | 260 | idx += 1 261 | diffusion.feed_data(val_data) 262 | diffusion.test(continous=False) 263 | 264 | visuals = diffusion.get_current_visuals() 265 | 266 | 267 | if opt_dataset['dataset'] == 'LOLv1' or opt_dataset['dataset'] == 'LOLv2': 268 | normal_img = Metrics.tensor2img(visuals['HQ']) 269 | if normal_img.shape[0] != normal_img.shape[1]: # lolv1 and lolv2-real 270 | normal_img = normal_img[8:408, 4:604,:] 271 | gt_img = Metrics.tensor2img(visuals['GT']) 272 | ll_img = Metrics.tensor2img(visuals['LQ']) 273 | else: 274 | normal_img = Metrics.tensor2img2(visuals['HQ']) 275 | gt_img = Metrics.tensor2img2(visuals['GT']) 276 | ll_img = Metrics.tensor2img2(visuals['LQ']) 277 | 278 | img_mode = 'single' 279 | ''' 280 | if img_mode == 'single': 281 | util.save_img( 282 | gt_img, '{}/{}_gt.png'.format(result_path_gt, idx)) 283 | util.save_img( 284 | ll_img, '{}/{}_in.png'.format(result_path_input, idx)) 285 | # util.save_img( 286 | # normal_img, '{}/{}_normal.png'.format(result_path_out, idx)) 287 | else: 288 | util.save_img( 289 | gt_img, '{}/{}_{}_gt.png'.format(result_path, current_step, idx)) 290 | util.save_img( 291 | normal_img, '{}/{}_{}_normal_process.png'.format(result_path, current_step, idx)) 292 | util.save_img( 293 | Metrics.tensor2img(visuals['HQ'][-1]), '{}/{}_{}_normal.png'.format(result_path, current_step, idx)) 294 | normal_img = Metrics.tensor2img(visuals['HQ'][-1]) 295 | ''' 296 | 297 | 298 | # Similar to LLFlow, 299 | # we follow a similar way of 'Kind' to finetune the overall brightness as illustrated 300 | # in Line 73 (https://github.com/zhangyhuaee/KinD/blob/master/evaluate_LOLdataset.py). 301 | if opt_dataset['dataset'] == 'LOLv1' or opt_dataset['dataset'] == 'LOLv2': 302 | gt_img = gt_img / 255. 303 | normal_img = normal_img / 255.
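# The adjustment below rescales the output by the ratio of mean gray levels,
#
#   normal' = clip(normal * mean(gray(gt)) / mean(gray(normal)), 0, 1)
#
# a single global scalar, so only overall brightness changes while local
# texture is untouched.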
304 | 305 | mean_gray_out = cv2.cvtColor(normal_img.astype(np.float32), cv2.COLOR_BGR2GRAY).mean() 306 | mean_gray_gt = cv2.cvtColor(gt_img.astype(np.float32), cv2.COLOR_BGR2GRAY).mean() 307 | normal_img_adjust = np.clip(normal_img * (mean_gray_gt / mean_gray_out), 0, 1) 308 | 309 | normal_img = (normal_img_adjust * 255).astype(np.uint8) 310 | gt_img = (gt_img * 255).astype(np.uint8) 311 | 312 | psnr = util.calculate_psnr(normal_img, gt_img) 313 | ssim = util.calculate_ssim(normal_img, gt_img) 314 | 315 | normal_img_tensor = torch.tensor(normal_img.astype(np.float32)) 316 | gt_img_tensor = torch.tensor(gt_img.astype(np.float32)) 317 | normal_img_tensor = normal_img_tensor.permute(2, 0, 1).cuda() 318 | gt_img_tensor = gt_img_tensor.permute(2, 0, 1).cuda() 319 | lpips_scores = lpips_func(normal_img_tensor, gt_img_tensor).item() 320 | 321 | util.save_img(normal_img, '{}/{}_normal.png'.format(result_path_out, idx)) 322 | 323 | logger.info('cPSNR: {:.4e} cSSIM: {:.4e} cLPIPS: {:.4e}'.format(psnr, ssim, lpips_scores)) 324 | 325 | avg_ssim += ssim 326 | avg_psnr += psnr 327 | avg_lpips += lpips_scores 328 | # break 329 | 330 | avg_psnr = avg_psnr / idx 331 | avg_ssim = avg_ssim / idx 332 | avg_lpips = avg_lpips / idx 333 | 334 | if avg_psnr > best_psnr: 335 | best_psnr = avg_psnr 336 | best_ssim = avg_ssim 337 | best_lpips = avg_lpips 338 | if current_step % opt['train']['save_checkpoint_freq'] == 0 and rank <= 0: 339 | logger.info('Saving models and training states.') 340 | gen_path = diffusion.save_network(PD_steps[i+1], current_epoch, current_step, best_psnr, best_ssim, best_lpips) 341 | if args.progressive: 342 | resume_state = gen_path 343 | # logger.info('# Validation Avg scores at timesteps {:3d} # PSNR: {:.4e} SSIM: {:.4e} LPIPS: {:.4e}'.format(PD_steps[i+1], avg_psnr, avg_ssim, avg_lpips)) 344 | logger_val = logging.getLogger('val') 345 | logger_val.info('# Avg scores at {:d} steps <epoch: {:d}, iter: {:d}> # PSNR: {:.4e} SSIM: {:.4e} LPIPS: {:.4e}'.format(PD_steps[i+1], 346 | current_epoch, current_step, avg_psnr, avg_ssim, avg_lpips)) 347 | logger_val.info('# Best scores at {:d} steps # PSNR: {:.4e} SSIM: {:.4e} LPIPS: {:.4e}'.format(PD_steps[i+1], best_psnr, best_ssim, best_lpips)) 348 | if opt["model"]['beta_schedule']["val"]["n_timestep"] == 2: 349 | break 350 | 351 | 352 | 353 | if __name__ == '__main__': 354 | 355 | main() 356 | -------------------------------------------------------------------------------- /train_lol1.sh: -------------------------------------------------------------------------------- 1 | 2 | python train.py --config ./config/lolv1_train.json --dataset ./config/lolv1.yml --w_str 0.0 --w_snr 0.8 --w_gt 1.0 --w_lpips 0.6 --ablation 3 | -------------------------------------------------------------------------------- /train_lol2_real.sh: -------------------------------------------------------------------------------- 1 | 2 | python train.py --config ./config/lolv2_real_train.json --dataset ./config/lolv2_real.yml --w_str 0.0 --w_snr 0.4 --w_gt 0.0 --w_lpips 0.6 --ablation -------------------------------------------------------------------------------- /train_lol2_syn.sh: -------------------------------------------------------------------------------- 1 | 2 | 3 | python train.py --config ./config/lolv2_syn_train.json --dataset ./config/lolv2_syn.yml --w_str 0.0 --w_snr 0.4 --w_gt 0.0 --w_lpips 0.6 --ablation & 4 | -------------------------------------------------------------------------------- /utils/__init__.py: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/utils/__init__.py -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/ema.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/utils/__pycache__/ema.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/util.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/utils/__pycache__/util.cpython-38.pyc -------------------------------------------------------------------------------- /utils/ema.py: -------------------------------------------------------------------------------- 1 | class EMA(): 2 | def __init__(self, model, decay): 3 | self.model = model 4 | self.decay = decay 5 | self.shadow = {} 6 | self.backup = {} 7 | 8 | def register(self): 9 | for name, param in self.model.named_parameters(): 10 | if param.requires_grad: 11 | self.shadow[name] = param.data.clone() 12 | 13 | def update(self): 14 | for name, param in self.model.named_parameters(): 15 | if param.requires_grad: 16 | assert name in self.shadow 17 | new_average = (1.0 - self.decay) * param.data + self.decay * self.shadow[name] 18 | self.shadow[name] = new_average.clone() 19 | 20 | def apply_shadow(self): 21 | for name, param in self.model.named_parameters(): 22 | if param.requires_grad: 23 | assert name in self.shadow 24 | self.backup[name] = param.data 25 | param.data = self.shadow[name] 26 | 27 | def restore(self): 28 | for name, param in self.model.named_parameters(): 29 | if param.requires_grad: 30 | assert name in self.backup 31 | param.data = self.backup[name] 32 | self.backup = {} -------------------------------------------------------------------------------- /utils/niqe.py: -------------------------------------------------------------------------------- 1 | import math 2 | from os.path import dirname, join 3 | 4 | import cv2 5 | import numpy as np 6 | import scipy 7 | import scipy.io 8 | import scipy.misc 9 | import scipy.ndimage 10 | import scipy.special 11 | from PIL import Image 12 | 13 | gamma_range = np.arange(0.2, 10, 0.001) 14 | a = scipy.special.gamma(2.0/gamma_range) 15 | a *= a 16 | b = scipy.special.gamma(1.0/gamma_range) 17 | c = scipy.special.gamma(3.0/gamma_range) 18 | prec_gammas = a/(b*c) 19 | 20 | 21 | def aggd_features(imdata): 22 | # flatten imdata 23 | imdata.shape = (len(imdata.flat),) 24 | imdata2 = imdata*imdata 25 | left_data = imdata2[imdata < 0] 26 | right_data = imdata2[imdata >= 0] 27 | left_mean_sqrt = 0 28 | right_mean_sqrt = 0 29 | if len(left_data) > 0: 30 | left_mean_sqrt = np.sqrt(np.average(left_data)) 31 | if len(right_data) > 0: 32 | right_mean_sqrt = np.sqrt(np.average(right_data)) 33 | 34 | if right_mean_sqrt != 0: 35 | gamma_hat = left_mean_sqrt/right_mean_sqrt 36 | else: 37 | gamma_hat = np.inf 38 | # solve r-hat norm 39 | 40 | imdata2_mean = np.mean(imdata2) 41 | if imdata2_mean != 0: 
42 | r_hat = (np.average(np.abs(imdata))**2) / (np.average(imdata2)) 43 | else: 44 | r_hat = np.inf 45 | rhat_norm = r_hat * (((math.pow(gamma_hat, 3) + 1) * 46 | (gamma_hat + 1)) / math.pow(math.pow(gamma_hat, 2) + 1, 2)) 47 | 48 | # solve alpha by guessing values that minimize ro 49 | pos = np.argmin((prec_gammas - rhat_norm)**2) 50 | alpha = gamma_range[pos] 51 | 52 | gam1 = scipy.special.gamma(1.0/alpha) 53 | gam2 = scipy.special.gamma(2.0/alpha) 54 | gam3 = scipy.special.gamma(3.0/alpha) 55 | 56 | aggdratio = np.sqrt(gam1) / np.sqrt(gam3) 57 | bl = aggdratio * left_mean_sqrt 58 | br = aggdratio * right_mean_sqrt 59 | 60 | # mean parameter 61 | N = (br - bl)*(gam2 / gam1) # *aggdratio 62 | return (alpha, N, bl, br, left_mean_sqrt, right_mean_sqrt) 63 | 64 | 65 | def ggd_features(imdata): 66 | nr_gam = 1/prec_gammas 67 | sigma_sq = np.var(imdata) 68 | E = np.mean(np.abs(imdata)) 69 | rho = sigma_sq/E**2 70 | pos = np.argmin(np.abs(nr_gam - rho)) 71 | return gamma_range[pos], sigma_sq 72 | 73 | 74 | def paired_product(new_im): 75 | shift1 = np.roll(new_im.copy(), 1, axis=1) 76 | shift2 = np.roll(new_im.copy(), 1, axis=0) 77 | shift3 = np.roll(np.roll(new_im.copy(), 1, axis=0), 1, axis=1) 78 | shift4 = np.roll(np.roll(new_im.copy(), 1, axis=0), -1, axis=1) 79 | 80 | H_img = shift1 * new_im 81 | V_img = shift2 * new_im 82 | D1_img = shift3 * new_im 83 | D2_img = shift4 * new_im 84 | 85 | return (H_img, V_img, D1_img, D2_img) 86 | 87 | 88 | def gen_gauss_window(lw, sigma): 89 | sd = np.float32(sigma) 90 | lw = int(lw) 91 | weights = [0.0] * (2 * lw + 1) 92 | weights[lw] = 1.0 93 | sum = 1.0 94 | sd *= sd 95 | for ii in range(1, lw + 1): 96 | tmp = np.exp(-0.5 * np.float32(ii * ii) / sd) 97 | weights[lw + ii] = tmp 98 | weights[lw - ii] = tmp 99 | sum += 2.0 * tmp 100 | for ii in range(2 * lw + 1): 101 | weights[ii] /= sum 102 | return weights 103 | 104 | 105 | def compute_image_mscn_transform(image, C=1, avg_window=None, extend_mode='constant'): 106 | if avg_window is None: 107 | avg_window = gen_gauss_window(3, 7.0/6.0) 108 | assert len(np.shape(image)) == 2 109 | h, w = np.shape(image) 110 | mu_image = np.zeros((h, w), dtype=np.float32) 111 | var_image = np.zeros((h, w), dtype=np.float32) 112 | image = np.array(image).astype('float32') 113 | scipy.ndimage.correlate1d(image, avg_window, 0, mu_image, mode=extend_mode) 114 | scipy.ndimage.correlate1d(mu_image, avg_window, 1, 115 | mu_image, mode=extend_mode) 116 | scipy.ndimage.correlate1d(image**2, avg_window, 0, 117 | var_image, mode=extend_mode) 118 | scipy.ndimage.correlate1d(var_image, avg_window, 119 | 1, var_image, mode=extend_mode) 120 | var_image = np.sqrt(np.abs(var_image - mu_image**2)) 121 | return (image - mu_image)/(var_image + C), var_image, mu_image 122 | 123 | 124 | def _niqe_extract_subband_feats(mscncoefs): 125 | # alpha_m, = extract_ggd_features(mscncoefs) 126 | alpha_m, N, bl, br, lsq, rsq = aggd_features(mscncoefs.copy()) 127 | pps1, pps2, pps3, pps4 = paired_product(mscncoefs) 128 | alpha1, N1, bl1, br1, lsq1, rsq1 = aggd_features(pps1) 129 | alpha2, N2, bl2, br2, lsq2, rsq2 = aggd_features(pps2) 130 | alpha3, N3, bl3, br3, lsq3, rsq3 = aggd_features(pps3) 131 | alpha4, N4, bl4, br4, lsq4, rsq4 = aggd_features(pps4) 132 | return np.array([alpha_m, (bl+br)/2.0, 133 | alpha1, N1, bl1, br1, # (V) 134 | alpha2, N2, bl2, br2, # (H) 135 | alpha3, N3, bl3, bl3, # (D1) 136 | alpha4, N4, bl4, bl4, # (D2) 137 | ]) 138 | 139 | 140 | def get_patches_train_features(img, patch_size, stride=8): 141 | return _get_patches_generic(img, 
patch_size, 1, stride) 142 | 143 | 144 | def get_patches_test_features(img, patch_size, stride=8): 145 | return _get_patches_generic(img, patch_size, 0, stride) 146 | 147 | 148 | def extract_on_patches(img, patch_size): 149 | h, w = img.shape 150 | patch_size = np.int32(patch_size) 151 | patches = [] 152 | for j in range(0, h-patch_size+1, patch_size): 153 | for i in range(0, w-patch_size+1, patch_size): 154 | patch = img[j:j+patch_size, i:i+patch_size] 155 | patches.append(patch) 156 | 157 | patches = np.array(patches) 158 | 159 | patch_features = [] 160 | for p in patches: 161 | patch_features.append(_niqe_extract_subband_feats(p)) 162 | patch_features = np.array(patch_features) 163 | 164 | return patch_features 165 | 166 | 167 | def _get_patches_generic(img, patch_size, is_train, stride): 168 | h, w = np.shape(img) 169 | if h < patch_size or w < patch_size: 170 | print("Input image is too small") 171 | exit(0) 172 | 173 | # ensure that the patch divides evenly into img 174 | hoffset = (h % patch_size) 175 | woffset = (w % patch_size) 176 | 177 | if hoffset > 0: 178 | img = img[:-hoffset, :] 179 | if woffset > 0: 180 | img = img[:, :-woffset] 181 | 182 | img = img.astype(np.float32) 183 | # img2 = scipy.misc.imresize(img, 0.5, interp='bicubic', mode='F') 184 | img2 = cv2.resize(img, (0, 0), fx=0.5, fy=0.5) 185 | 186 | mscn1, var, mu = compute_image_mscn_transform(img) 187 | mscn1 = mscn1.astype(np.float32) 188 | 189 | mscn2, _, _ = compute_image_mscn_transform(img2) 190 | mscn2 = mscn2.astype(np.float32) 191 | 192 | feats_lvl1 = extract_on_patches(mscn1, patch_size) 193 | feats_lvl2 = extract_on_patches(mscn2, patch_size/2) 194 | 195 | feats = np.hstack((feats_lvl1, feats_lvl2)) # feats_lvl3)) 196 | 197 | return feats 198 | 199 | 200 | def niqe(inputImgData): 201 | 202 | patch_size = 8 203 | module_path = dirname(__file__) 204 | 205 | # TODO: memoize 206 | params = scipy.io.loadmat( 207 | join(module_path, 'niqe_image_params.mat')) 208 | pop_mu = np.ravel(params["pop_mu"]) 209 | pop_cov = params["pop_cov"] 210 | 211 | if inputImgData.ndim == 3: 212 | inputImgData = cv2.cvtColor(inputImgData, cv2.COLOR_BGR2GRAY) 213 | M, N = inputImgData.shape 214 | 215 | # assert C == 1, "niqe called with videos containing %d channels. 
Please supply only the luminance channel" % (C,) 216 | assert M > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters" 217 | assert N > (patch_size*2+1), "niqe called with small frame size, requires > 192x192 resolution video using current training parameters" 218 | 219 | feats = get_patches_test_features(inputImgData, patch_size) 220 | sample_mu = np.mean(feats, axis=0) 221 | sample_cov = np.cov(feats.T) 222 | 223 | X = sample_mu - pop_mu 224 | covmat = ((pop_cov+sample_cov)/2.0) 225 | pinvmat = scipy.linalg.pinv(covmat) 226 | niqe_score = np.sqrt(np.dot(np.dot(X, pinvmat), X)) 227 | 228 | return niqe_score 229 | -------------------------------------------------------------------------------- /utils/niqe_image_params.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lgz-0713/ReDDiT/ba046491b7135850ddf74900ad2ff691d57b8331/utils/niqe_image_params.mat -------------------------------------------------------------------------------- /utils/util.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import math 4 | from datetime import datetime 5 | import random 6 | import logging 7 | from collections import OrderedDict 8 | 9 | import natsort 10 | import numpy as np 11 | import cv2 12 | import torch 13 | from torchvision.utils import make_grid 14 | from shutil import get_terminal_size 15 | import torch 16 | import torch.nn.functional as F 17 | from torch.autograd import Variable 18 | import numpy as np 19 | from math import exp 20 | 21 | from skimage.metrics import structural_similarity as SSIM 22 | 23 | 24 | def gaussian(window_size, sigma): 25 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 26 | return gauss / gauss.sum() 27 | 28 | 29 | def create_window(window_size, channel): 30 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 31 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 32 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 33 | return window 34 | 35 | 36 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 37 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 38 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 39 | 40 | mu1_sq = mu1.pow(2) 41 | mu2_sq = mu2.pow(2) 42 | mu1_mu2 = mu1 * mu2 43 | 44 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 45 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 46 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 47 | 48 | C1 = 0.01 ** 2 49 | C2 = 0.03 ** 2 50 | 51 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 52 | 53 | if size_average: 54 | return ssim_map.mean() 55 | else: 56 | return ssim_map.mean(1).mean(1).mean(1) 57 | 58 | 59 | # class SSIM(torch.nn.Module): 60 | # def __init__(self, window_size=11, size_average=True): 61 | # super(SSIM, self).__init__() 62 | # self.window_size = window_size 63 | # self.size_average = size_average 64 | # self.channel = 1 65 | # self.window = create_window(window_size, self.channel) 66 | 67 | # def forward(self, img1, img2): 68 | # (_, channel, _, _) = img1.size() 69 | 70 | # if channel == self.channel and 
self.window.data.type() == img1.data.type(): 71 | # window = self.window 72 | # else: 73 | # window = create_window(self.window_size, channel) 74 | 75 | # if img1.is_cuda: 76 | # window = window.cuda(img1.get_device()) 77 | # window = window.type_as(img1) 78 | 79 | # self.window = window 80 | # self.channel = channel 81 | 82 | # return _ssim(img1, img2, window, self.window_size, channel, self.size_average) 83 | 84 | 85 | def ssim(img1, img2, window_size=11, size_average=True): 86 | (_, channel, _, _) = img1.size() 87 | window = create_window(window_size, channel) 88 | 89 | if img1.is_cuda: 90 | window = window.cuda(img1.get_device()) 91 | window = window.type_as(img1) 92 | 93 | return _ssim(img1.mean(dim=0, keepdims=True), img2.mean(dim=0, keepdims=True), window, window_size, channel, size_average) 94 | 95 | 96 | import yaml 97 | 98 | try: 99 | from yaml import CLoader as Loader, CDumper as Dumper 100 | except ImportError: 101 | from yaml import Loader, Dumper 102 | 103 | 104 | def OrderedYaml(): 105 | '''yaml orderedDict support''' 106 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 107 | 108 | def dict_representer(dumper, data): 109 | return dumper.represent_dict(data.items()) 110 | 111 | def dict_constructor(loader, node): 112 | return OrderedDict(loader.construct_pairs(node)) 113 | 114 | Dumper.add_representer(OrderedDict, dict_representer) 115 | Loader.add_constructor(_mapping_tag, dict_constructor) 116 | return Loader, Dumper 117 | 118 | 119 | #################### 120 | # miscellaneous 121 | #################### 122 | 123 | 124 | def get_timestamp(): 125 | return datetime.now().strftime('%y%m%d-%H%M%S') 126 | 127 | 128 | def mkdir(path): 129 | if not os.path.exists(path): 130 | os.makedirs(path) 131 | 132 | 133 | def mkdirs(paths): 134 | if isinstance(paths, str): 135 | mkdir(paths) 136 | else: 137 | for path in paths: 138 | mkdir(path) 139 | 140 | 141 | def mkdir_and_rename(path): 142 | if os.path.exists(path): 143 | new_name = path + '_archived_' + get_timestamp() 144 | print('Path already exists. Rename it to [{:s}]'.format(new_name)) 145 | logger = logging.getLogger('base') 146 | logger.info('Path already exists. 
Rename it to [{:s}]'.format(new_name)) 147 | os.rename(path, new_name) 148 | os.makedirs(path) 149 | 150 | 151 | def set_random_seed(seed): 152 | random.seed(seed) 153 | np.random.seed(seed) 154 | torch.manual_seed(seed) 155 | torch.cuda.manual_seed_all(seed) 156 | 157 | 158 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): 159 | '''set up logger''' 160 | lg = logging.getLogger(logger_name) 161 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 162 | datefmt='%y-%m-%d %H:%M:%S') 163 | lg.setLevel(level) 164 | if tofile: 165 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp())) 166 | fh = logging.FileHandler(log_file, mode='w') 167 | fh.setFormatter(formatter) 168 | lg.addHandler(fh) 169 | if screen: 170 | sh = logging.StreamHandler() 171 | sh.setFormatter(formatter) 172 | lg.addHandler(sh) 173 | 174 | 175 | #################### 176 | # image convert 177 | #################### 178 | 179 | 180 | def tensor2img(tensor, out_type=np.uint8, min_max=(-1, 1)): 181 | ''' 182 | Converts a torch Tensor into an image Numpy array 183 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 184 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 185 | ''' 186 | if hasattr(tensor, 'detach'): 187 | tensor = tensor.detach() 188 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 189 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] 190 | n_dim = tensor.dim() 191 | if n_dim == 4: 192 | n_img = len(tensor) 193 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 194 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 195 | elif n_dim == 3: 196 | img_np = tensor.numpy() 197 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 198 | elif n_dim == 2: 199 | img_np = tensor.numpy() 200 | else: 201 | raise TypeError( 202 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 203 | if out_type == np.uint8: 204 | img_np = np.clip((img_np * 255.0).round(), 0, 255) 205 | # Important. Unlike matlab, numpy.uint8() WILL NOT round by default.
206 | return img_np.astype(out_type) 207 | 208 | 209 | 210 | def save_img(img, img_path, mode='RGB'): 211 | cv2.imwrite(img_path, img) # expects a BGR image (tensor2img returns BGR); the 'mode' argument is unused 212 | 213 | 214 | #################### 215 | # metric 216 | #################### 217 | 218 | 219 | def calculate_psnr(img1, img2): 220 | # img1 and img2 have range [0, 255] 221 | img1 = img1.astype(np.float64) 222 | img2 = img2.astype(np.float64) 223 | mse = np.mean((img1 - img2) ** 2) 224 | if mse == 0: 225 | return float('inf') 226 | return 20 * math.log10(255.0 / math.sqrt(mse)) 227 | 228 | 229 | def calculate_ssim(imgA, imgB, gray_scale=False): 230 | 231 | if gray_scale: 232 | score, diff = SSIM(cv2.cvtColor(imgA, cv2.COLOR_RGB2GRAY), cv2.cvtColor(imgB, cv2.COLOR_RGB2GRAY), full=True, multichannel=True) 233 | else: 234 | score, diff = SSIM(imgA, imgB, full=True, multichannel=True) 235 | return score 236 | 237 | 238 | def calculate_lpips(imgA, imgB, gray_scale=False): 239 | # NOTE: this duplicates calculate_ssim; LPIPS itself is computed directly via the lpips package in test.py and train.py 240 | if gray_scale: 241 | score, diff = SSIM(cv2.cvtColor(imgA, cv2.COLOR_RGB2GRAY), cv2.cvtColor(imgB, cv2.COLOR_RGB2GRAY), full=True, multichannel=True) 242 | else: 243 | score, diff = SSIM(imgA, imgB, full=True, multichannel=True) 244 | return score 245 | 246 | def get_resume_paths(opt): 247 | resume_state_path = None 248 | resume_model_path = None 249 | ts = opt_get(opt, ['path', 'training_state']) 250 | if opt.get('path', {}).get('resume_state', None) == "auto" and ts is not None: 251 | wildcard = os.path.join(ts, "*") 252 | paths = natsort.natsorted(glob.glob(wildcard)) 253 | if len(paths) > 0: 254 | resume_state_path = paths[-1] 255 | resume_model_path = resume_state_path.replace('training_state', 'models').replace('.state', '_G.pth') 256 | else: 257 | resume_state_path = opt.get('path', {}).get('resume_state') 258 | return resume_state_path, resume_model_path 259 | 260 | 261 | def opt_get(opt, keys, default=None): 262 | if opt is None: 263 | return default 264 | ret = opt 265 | for k in keys: 266 | ret = ret.get(k, None) 267 | if ret is None: 268 | return default 269 | return ret 270 | --------------------------------------------------------------------------------