├── .gitignore ├── LICENSE ├── README.md ├── datasets └── README.md ├── diffglv ├── __init__.py ├── archs │ ├── __init__.py │ └── unet_BI_DiffSR_arch.py ├── data │ ├── __init__.py │ ├── data_util.py │ └── paired_image_dataset.py ├── losses │ ├── __init__.py │ └── at_loss.py ├── metrics │ ├── __init__.py │ └── lpips.py ├── models │ ├── BI_DiffSR_model.py │ └── __init__.py └── utils │ ├── GPU_memory.py │ ├── __init__.py │ ├── base_model.py │ ├── beta_schedule.py │ ├── extract_subimages.py │ ├── logger.py │ ├── lr_scheduler.py │ ├── make_ds.py │ ├── options.py │ └── transforms.py ├── experiments ├── README.md └── pretrained_models │ └── README.md ├── figs ├── BI-DiffSR.png ├── F1.png ├── F2-1.png ├── F2-2.png ├── F3-1.png ├── F3-2.png ├── T1.png ├── compare │ ├── ComS_img_023_BBCU_x4.png │ ├── ComS_img_023_BI-DiffSR_x4.png │ ├── ComS_img_023_Bicubic_x4.png │ ├── ComS_img_023_HR_x4.png │ ├── ComS_img_023_SR3_x4.png │ ├── ComS_img_033_BBCU_x4.png │ ├── ComS_img_033_BI-DiffSR_x4.png │ ├── ComS_img_033_Bicubic_x4.png │ ├── ComS_img_033_HR_x4.png │ └── ComS_img_033_SR3_x4.png └── logo.png ├── options ├── test │ ├── test_BI_DiffSR_x2.yml │ └── test_BI_DiffSR_x4.yml └── train │ ├── train_BI_DiffSR_x2.yml │ └── train_BI_DiffSR_x4.yml ├── requirements.txt ├── results └── README.md ├── test.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | # ignored folders 2 | datasets/* 3 | experiments/* 4 | results/* 5 | tb_logger/* 6 | wandb/* 7 | tmp/* 8 | slurm/* 9 | scripts/metrics/* 10 | 11 | options/euler/* 12 | 13 | *.DS_Store 14 | .idea 15 | 16 | # ignored files 17 | version.py 18 | 19 | # ignored files with suffix 20 | *.html 21 | *.png 22 | *.jpeg 23 | *.jpg 24 | *.gif 25 | *.pth 26 | *.zip 27 | *.npy 28 | *.pdf 29 | 30 | !figs/*.png 31 | 32 | # slurm 33 | *.err 34 | *.out 35 | 36 | # template 37 | 38 | # Byte-compiled / optimized / DLL files 39 | __pycache__/ 40 | *.py[cod] 41 | *$py.class 42 | 43 | # C extensions 44 | *.so 45 | 46 | # Distribution / packaging 47 | .Python 48 | build/ 49 | develop-eggs/ 50 | dist/ 51 | downloads/ 52 | eggs/ 53 | .eggs/ 54 | lib/ 55 | lib64/ 56 | parts/ 57 | sdist/ 58 | var/ 59 | wheels/ 60 | *.egg-info/ 61 | .installed.cfg 62 | *.egg 63 | MANIFEST 64 | 65 | # PyInstaller 66 | # Usually these files are written by a python script from a template 67 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
68 | *.manifest 69 | *.spec 70 | 71 | # Installer logs 72 | pip-log.txt 73 | pip-delete-this-directory.txt 74 | 75 | # Unit test / coverage reports 76 | htmlcov/ 77 | .tox/ 78 | .coverage 79 | .coverage.* 80 | .cache 81 | nosetests.xml 82 | coverage.xml 83 | *.cover 84 | .hypothesis/ 85 | .pytest_cache/ 86 | 87 | # Translations 88 | *.mo 89 | *.pot 90 | 91 | # Django stuff: 92 | *.log 93 | local_settings.py 94 | db.sqlite3 95 | 96 | # Flask stuff: 97 | instance/ 98 | .webassets-cache 99 | 100 | # Scrapy stuff: 101 | .scrapy 102 | 103 | # Sphinx documentation 104 | docs/_build/ 105 | 106 | # PyBuilder 107 | target/ 108 | 109 | # Jupyter Notebook 110 | .ipynb_checkpoints 111 | 112 | # pyenv 113 | .python-version 114 | 115 | # celery beat schedule file 116 | celerybeat-schedule 117 | 118 | # SageMath parsed files 119 | *.sage.py 120 | 121 | # Environments 122 | .env 123 | .venv 124 | env/ 125 | venv/ 126 | ENV/ 127 | env.bak/ 128 | venv.bak/ 129 | 130 | # Spyder project settings 131 | .spyderproject 132 | .spyproject 133 | 134 | # Rope project settings 135 | .ropeproject 136 | 137 | # mkdocs documentation 138 | /site 139 | 140 | # mypy 141 | .mypy_cache/ 142 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. 
For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2024 BI-DiffSR Authors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |
<p align="center"> 2 | <img src="figs/logo.png"> 3 | </p>
4 | 5 | # Binarized Diffusion Model for Image Super-Resolution 6 | 7 | [Zheng Chen](https://zhengchen1999.github.io/), [Haotong Qin](https://htqin.github.io/), [Yong Guo](https://www.guoyongcs.com/), [Xiongfei Su](https://ieeexplore.ieee.org/author/37086348852), [Xin Yuan](https://en.westlake.edu.cn/faculty/xin-yuan.html), [Linghe Kong](https://www.cs.sjtu.edu.cn/~linghe.kong/), and [Yulun Zhang](http://yulunzhang.com/), "Binarized Diffusion Model for Image Super-Resolution", NeurIPS, 2024 8 | 9 | [[project](https://zhengchen1999.github.io/BI-DiffSR/)] [[arXiv](https://arxiv.org/abs/2406.05723)] [[supplementary material](https://github.com/zhengchen1999/BI-DiffSR/releases/download/v1/Supplementary_Material.pdf)] [[visual results](https://drive.google.com/drive/folders/1-Mfy8XHG55Bc19gAXqNaNitO0GEx7O1r?usp=drive_link)] [[pretrained models](https://drive.google.com/drive/folders/1hoHAG2yoLltloQ0SYv-QLxwk9Y8ZnTnH?usp=drive_link)] 10 | 11 | 12 | 13 | #### 🔥🔥🔥 News 14 | 15 | - **2024-10-23:** [Project Page](https://zhengchen1999.github.io/BI-DiffSR/) is accessible. 📃📃📃 16 | - **2024-10-14:** Code and pre-trained models are released. ⭐️⭐️⭐️ 17 | - **2024-09-26:** BI-DiffSR is accepted at NeurIPS 2024. 🎉🎉🎉 18 | - **2024-06-09:** This repo is released. 19 | 20 | --- 21 | 22 | > **Abstract:** Advanced diffusion models (DMs) perform impressively in image super-resolution (SR), but the high memory and computational costs hinder their deployment. Binarization, an ultra-compression algorithm, offers the potential for effectively accelerating DMs. Nonetheless, due to the model structure and the multi-step iterative attribute of DMs, existing binarization methods result in significant performance degradation. In this paper, we introduce a novel binarized diffusion model, BI-DiffSR, for image SR. First, for the model structure, we design a UNet architecture optimized for binarization. We propose the consistent-pixel-downsample (CP-Down) and consistent-pixel-upsample (CP-Up) to maintain dimension consistency and facilitate full-precision information transfer. Meanwhile, we design the channel-shuffle-fusion (CS-Fusion) to enhance feature fusion in the skip connections. Second, for the activation difference across timesteps, we design the timestep-aware redistribution (TaR) and activation function (TaA). The TaR and TaA dynamically adjust the distribution of activations based on different timesteps, improving the flexibility and representation ability of the binarized module. Comprehensive experiments demonstrate that our BI-DiffSR outperforms existing binarization methods. 23 | 24 | ![](figs/BI-DiffSR.png) 25 | 26 | --- 27 | 28 | | HR | LR | [SR3 (FP)](https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement) | [BBCU](https://github.com/Zj-BinXia/BBCU) | BI-DiffSR (ours) | 29 | | :-------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :---------------------------------------------------------: | :----------------------------------------------------------: | 30 | | <img src="figs/compare/ComS_img_023_HR_x4.png"> | <img src="figs/compare/ComS_img_023_Bicubic_x4.png"> | <img src="figs/compare/ComS_img_023_SR3_x4.png"> | <img src="figs/compare/ComS_img_023_BBCU_x4.png"> | <img src="figs/compare/ComS_img_023_BI-DiffSR_x4.png"> | 31 | | <img src="figs/compare/ComS_img_033_HR_x4.png"> | <img src="figs/compare/ComS_img_033_Bicubic_x4.png"> | <img src="figs/compare/ComS_img_033_SR3_x4.png"> | <img src="figs/compare/ComS_img_033_BBCU_x4.png"> | <img src="figs/compare/ComS_img_033_BI-DiffSR_x4.png"> | 32 | 33 | ## TODO 34 | 35 | * [x] Release code and pretrained models 36 | 37 | ## Dependencies 38 | 39 | - Python 3.9 40 | - PyTorch 1.13.1+cu117 41 | 42 | ```bash 43 | # Clone the github repo and go to the default directory 'BI-DiffSR'.
44 | git clone https://github.com/zhengchen1999/BI-DiffSR.git 45 | conda create -n bi_diffsr python=3.9 46 | conda activate bi_diffsr 47 | pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html 48 | git clone https://github.com/huggingface/diffusers.git 49 | cd diffusers 50 | pip install -e ".[torch]" 51 | ``` 52 | 53 | ## Contents 54 | 55 | 1. [Datasets](#datasets) 56 | 1. [Models](#models) 57 | 1. [Training](#training) 58 | 1. [Testing](#testing) 59 | 1. [Results](#results) 60 | 1. [Citation](#citation) 61 | 1. [Acknowledgements](#acknowledgements) 62 | 63 | ## Datasets 64 | 65 | The training and testing sets used in this work can be downloaded as follows: 66 | 67 | | Training Set | Testing Set | Visual Results | 68 | | :----------------------------------------------------------- | :----------------------------------------------------------: | :----------------------------------------------------------: | 69 | | [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) (800 training images, 100 validation images) + [Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) (2650 images) [complete training dataset DF2K: [Google Drive](https://drive.google.com/file/d/1TubDkirxl4qAWelfOnpwaSKoj3KLAIG4/view?usp=share_link) / [Baidu Disk](https://pan.baidu.com/s/1KIcPNz3qDsGSM0uDKl4DRw?pwd=74yc)] | Set5 + Set14 + BSD100 + Urban100 + Manga109 [complete testing dataset: [Google Drive](https://drive.google.com/file/d/1yMbItvFKVaCT93yPWmlP3883XtJ-wSee/view?usp=sharing) / [Baidu Disk](https://pan.baidu.com/s/1Tf8WT14vhlA49TO2lz3Y1Q?pwd=8xen)] | [Google Drive](https://drive.google.com/drive/folders/1ZMaZyCer44ZX6tdcDmjIrc_hSsKoMKg2?usp=drive_link) / [Baidu Disk](https://pan.baidu.com/s/1LO-INqy40F5T_coAJsl5qw?pwd=dqnv#list/path=%2F) | 70 | 71 | Download the training and testing datasets and put them into the corresponding folders of `datasets/`. 72 | 73 | ## Models 74 | 75 | | Method | Params (M) | FLOPs (G) | PSNR (dB) | LPIPS | Model Zoo | Visual Results | 76 | | :-------- | :--------: | :-------: | :-------: | :----: | :----------------------------------------------------------: | :----------------------------------------------------------: | 77 | | BI-DiffSR | 4.58 | 36.67 | 24.11 | 0.1823 | [Google Drive](https://drive.google.com/drive/folders/1hoHAG2yoLltloQ0SYv-QLxwk9Y8ZnTnH?usp=sharing) | [Google Drive](https://drive.google.com/drive/folders/1-Mfy8XHG55Bc19gAXqNaNitO0GEx7O1r?usp=sharing) | 78 | 79 | Performance is reported on Urban100 (×4); FLOPs are measured for an output size of 3×256×256. 80 | 81 | ## Training 82 | 83 | - The ×2 task requires **4×8 GB** of VRAM (4 GPUs with 8 GB each), and the ×4 task requires **4×20 GB**. 84 | 85 | - Download the [training](https://drive.google.com/file/d/1TubDkirxl4qAWelfOnpwaSKoj3KLAIG4/view?usp=share_link) (DF2K, already processed) and [testing](https://drive.google.com/file/d/1yMbItvFKVaCT93yPWmlP3883XtJ-wSee/view?usp=sharing) (Set5, BSD100, Urban100, Manga109, already processed) datasets, and place them in `datasets/`. 86 | 87 | - Run the following scripts. The training configuration is in `options/train/`. 88 | 89 | ```shell 90 | # BI-DiffSR, input=64x64, 4 GPUs 91 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train/train_BI_DiffSR_x2.yml --launcher pytorch 92 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train/train_BI_DiffSR_x4.yml --launcher pytorch 93 | ``` 94 | 95 | - Training logs and checkpoints are saved in `experiments/`.
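As a quick sanity check on the Params number in the Models table above, the parameter count can be reproduced directly from the architecture class. This is a minimal sketch: the hyper-parameters below are copied from the example `__main__` block at the bottom of `diffglv/archs/unet_BI_DiffSR_arch.py`, and may differ from the released ×2/×4 configurations in `options/train/`.

```python
# Count parameters of the BI-DiffSR UNet. Binarized weights are stored in
# full precision during training, so they are counted like any other
# parameter here.
from diffglv.archs.unet_BI_DiffSR_arch import BIDiffSRUNet

model = BIDiffSRUNet(
    in_channel=6, out_channel=3, inner_channel=64, norm_groups=16,
    channel_mults=[1, 2, 4, 8], attn_res=[], res_blocks=2,
    dropout=0.2, image_size=256, fp_res=[256, 128], dynamic_group=5,
)
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.2f} M parameters')
```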
96 | 97 | ## Testing 98 | 99 | - Download the pre-trained [models](https://drive.google.com/drive/folders/1hoHAG2yoLltloQ0SYv-QLxwk9Y8ZnTnH?usp=sharing) and place them in `experiments/pretrained_models/`. 100 | 101 | We provide pre-trained models for image SR (×2, ×4). 102 | 103 | - Download the [testing](https://drive.google.com/file/d/1yMbItvFKVaCT93yPWmlP3883XtJ-wSee/view?usp=sharing) (Set5, BSD100, Urban100, Manga109) datasets and place them in `datasets/`. 104 | 105 | - Run the following scripts. The testing configuration is in `options/test/`. 106 | 107 | ```shell 108 | # BI-DiffSR, reproduces results in Table 2 of the main paper 109 | python test.py -opt options/test/test_BI_DiffSR_x2.yml 110 | python test.py -opt options/test/test_BI_DiffSR_x4.yml 111 | ``` 112 | 113 | Due to the randomness of the diffusion model ([diffusers](https://huggingface.co/docs/diffusers)), results may vary slightly. 114 | 115 | - The outputs are saved in `results/`. 116 | 117 | ## Results 118 | 119 | We achieve state-of-the-art performance. Detailed results can be found in the paper. 120 | 121 |
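To spot-check a reproduced output against its ground truth, BasicSR's metric functions can be used directly. This is a minimal sketch with hypothetical file paths; the exact evaluation protocol of the paper (e.g., border cropping or Y-channel evaluation) may differ.

```python
# Compare one restored image with its HR counterpart.
# The paths below are examples only; substitute your own results/ and
# datasets/ files.
import cv2
from basicsr.metrics import calculate_psnr

sr = cv2.imread('results/example/img_023_BI-DiffSR_x4.png')      # hypothetical path
hr = cv2.imread('datasets/benchmark/Urban100/HR/img_023.png')    # hypothetical path
# crop_border is usually set to the SR scale (4 for x4)
print(calculate_psnr(sr, hr, crop_border=4, input_order='HWC'))
```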
122 | <details> 123 | <summary>Quantitative Comparisons (click to expand)</summary> 124 | 125 | - Results in Table 2 (main paper) 126 | 127 | <!-- quantitative comparison figure omitted --> 128 | 129 | </details> 130 | 131 | 132 | 133 | <details> 134 | <summary>Visual Comparisons (click to expand)</summary> 135 | 136 | 137 | - Results in Figure 8 (main paper) 138 | 139 | <!-- visual comparison figures omitted --> 140 | 141 | 142 | 143 | 144 | 145 | - Results in Figure 5 (supplemental material) 146 | 147 | <!-- visual comparison figures omitted --> 148 | 149 | 150 | 151 | 152 | 153 | 154 | 155 | - Results in Figure 6 (supplemental material) 156 | 157 | <!-- visual comparison figures omitted --> 158 | 159 | 160 | 161 | </details>
163 | 164 | 165 | 166 | ## Citation 167 | 168 | If you find the code helpful in your research or work, please cite the following paper(s). 169 | 170 | ``` 171 | @inproceedings{chen2024binarized, 172 | title={Binarized Diffusion Model for Image Super-Resolution}, 173 | author={Chen, Zheng and Qin, Haotong and Guo, Yong and Su, Xiongfei and Yuan, Xin and Kong, Linghe and Zhang, Yulun}, 174 | booktitle={NeurIPS}, 175 | year={2024} 176 | } 177 | ``` 178 | 179 | 180 | 181 | ## Acknowledgements 182 | 183 | This code is built on [BasicSR](https://github.com/XPixelGroup/BasicSR), [Image-Super-Resolution-via-Iterative-Refinement](https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement). 184 | -------------------------------------------------------------------------------- /datasets/README.md: -------------------------------------------------------------------------------- 1 | For training and testing, the directory structure is as follows: 2 | 3 | ```shell 4 | |-- datasets 5 | # train 6 | |-- DF2K 7 | |-- HR 8 | |-- LR_bicubic 9 | |-- X2 10 | |-- X3 11 | |-- X4 12 | # test 13 | |-- benchmark 14 | |-- Set5 15 | |-- HR 16 | |-- LR_bicubic 17 | |-- X2 18 | |-- X3 19 | |-- X4 20 | |-- B100 21 | |-- HR 22 | |-- LR_bicubic 23 | |-- X2 24 | |-- X3 25 | |-- X4 26 | |-- Urban100 27 | |-- HR 28 | |-- LR_bicubic 29 | |-- X2 30 | |-- X3 31 | |-- X4 32 | |-- Manga109 33 | |-- HR 34 | |-- LR_bicubic 35 | |-- X2 36 | |-- X3 37 | |-- X4 38 | ``` 39 | 40 | You can download the complete datasets we have collected. 41 | -------------------------------------------------------------------------------- /diffglv/__init__.py: -------------------------------------------------------------------------------- 1 | from .archs import * 2 | from .data import * 3 | from .losses import * 4 | from .metrics import * 5 | from .models import * 6 | 7 | -------------------------------------------------------------------------------- /diffglv/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import ARCH_REGISTRY 7 | 8 | __all__ = ['build_network'] 9 | 10 | # automatically scan and import arch modules for registry 11 | # scan all the files under the 'archs' folder and collect files ending with 12 | # '_arch.py' 13 | arch_folder = osp.dirname(osp.abspath(__file__)) 14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')] 15 | # import all the arch modules 16 | _arch_modules = [importlib.import_module(f'diffglv.archs.{file_name}') for file_name in arch_filenames] 17 | 18 | 19 | def build_network(opt): 20 | opt = deepcopy(opt) 21 | network_type = opt.pop('type') 22 | net = ARCH_REGISTRY.get(network_type)(**opt) 23 | logger = get_root_logger() 24 | logger.info(f'Network [{net.__class__.__name__}] is created.') 25 | return net 26 | -------------------------------------------------------------------------------- /diffglv/archs/unet_BI_DiffSR_arch.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | import torch.nn.functional as F 5 | from inspect import isfunction 6 | 7 | from basicsr.utils.registry import ARCH_REGISTRY 8 | 9 | 10 | 11 | def exists(x): 12 | return x is not None 13 | 14 | def default(val, d): 15 | if exists(val): 16 | return val 17 | return 
d() if isfunction(d) else d 18 | 19 | # --------------------------------------------- BI Basic Units: START ----------------------------------------------------------------- 20 | class RPReLU(nn.Module): 21 | def __init__(self, inplanes): 22 | super(RPReLU, self).__init__() 23 | self.pr_bias0 = LearnableBias(inplanes) 24 | self.pr_prelu = nn.PReLU(inplanes) 25 | self.pr_bias1 = LearnableBias(inplanes) 26 | 27 | def forward(self, x): 28 | x = self.pr_bias1(self.pr_prelu(self.pr_bias0(x))) 29 | return x 30 | 31 | class BinaryActivation(nn.Module): 32 | def __init__(self): 33 | super(BinaryActivation, self).__init__() 34 | 35 | def forward(self, x): 36 | out_forward = torch.sign(x) 37 | mask1 = x < -1 38 | mask2 = x < 0 39 | mask3 = x < 1 40 | out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1-mask1.type(torch.float32)) 41 | out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1-mask2.type(torch.float32)) 42 | out3 = out2 * mask3.type(torch.float32) + 1 * (1- mask3.type(torch.float32)) 43 | out = out_forward.detach() - out3.detach() + out3 44 | 45 | return out 46 | 47 | class LearnableBias(nn.Module): 48 | def __init__(self, out_chn): 49 | super(LearnableBias, self).__init__() 50 | self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True) 51 | 52 | def forward(self, x): 53 | out = x + self.bias.expand_as(x) 54 | return out 55 | 56 | class HardBinaryConv(nn.Conv2d): 57 | def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1,groups=1,bias=True): 58 | super(HardBinaryConv, self).__init__( 59 | in_chn, 60 | out_chn, 61 | kernel_size, 62 | stride=stride, 63 | padding=padding, 64 | groups=groups, 65 | bias=bias 66 | ) 67 | 68 | def forward(self, x): 69 | real_weights = self.weight 70 | scaling_factor = torch.mean(torch.mean(torch.mean(abs(real_weights),dim=3,keepdim=True),dim=2,keepdim=True),dim=1,keepdim=True) 71 | scaling_factor = scaling_factor.detach() 72 | binary_weights_no_grad = scaling_factor * torch.sign(real_weights) 73 | cliped_weights = torch.clamp(real_weights, -1.0, 1.0) 74 | binary_weights = binary_weights_no_grad.detach() - cliped_weights.detach() + cliped_weights 75 | y = F.conv2d(x, binary_weights,self.bias, stride=self.stride, padding=self.padding) 76 | return y 77 | 78 | class BIConv(nn.Module): 79 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False, groups=1, dynamic_group=5): 80 | super(BIConv, self).__init__() 81 | self.TaR = nn.ModuleDict({ 82 | f'dynamic_move_{i}': LearnableBias(in_channels) for i in range(dynamic_group) 83 | }) 84 | 85 | self.binary_activation = BinaryActivation() 86 | self.binary_conv = HardBinaryConv(in_channels, 87 | out_channels, 88 | kernel_size, 89 | padding=(kernel_size//2), 90 | bias=bias) 91 | self.TaA = nn.ModuleDict({ 92 | f'dynamic_relu_{i}': RPReLU(out_channels) for i in range(dynamic_group) 93 | }) 94 | 95 | def forward(self, x, t): 96 | out = self.TaR[f'dynamic_move_{t}'](x) 97 | out = self.binary_activation(out) 98 | out = self.binary_conv(out) 99 | out = self.TaA[f'dynamic_relu_{t}'](out) 100 | out = out + x 101 | return out 102 | # --------------------------------------------- BI Basic Units: END ----------------------------------------------------------------- 103 | 104 | # --------------------------------------------- FP Module: START -------------------------------------------------------------------- 105 | # PositionalEncoding Source: https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py 106 | class PositionalEncoding(nn.Module): 107 
| def __init__(self, dim): 108 | super().__init__() 109 | self.dim = dim 110 | 111 | def forward(self, timestep_level): 112 | count = self.dim // 2 113 | step = torch.arange(count, dtype=timestep_level.dtype, 114 | device=timestep_level.device) / count 115 | encoding = timestep_level.unsqueeze( 116 | 1) * torch.exp(-math.log(1e4) * step.unsqueeze(0)) 117 | encoding = torch.cat( 118 | [torch.sin(encoding), torch.cos(encoding)], dim=-1) 119 | return encoding 120 | 121 | class Swish(nn.Module): 122 | def forward(self, x): 123 | return x * torch.sigmoid(x) 124 | 125 | class Block(nn.Module): 126 | def __init__(self, dim, dim_out, groups=32, dropout=0): 127 | super().__init__() 128 | self.block = nn.Sequential( 129 | nn.GroupNorm(groups, dim), 130 | Swish(), 131 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(), 132 | ) 133 | 134 | assert dim == dim_out, f"Error: input ({dim}) and output ({dim_out}) channel dimension." 135 | 136 | if dim == dim_out: 137 | self.conv = nn.Conv2d(dim, dim_out, 3, padding=1) 138 | 139 | def forward(self, x): 140 | return self.conv(self.block(x)) 141 | 142 | class Block_F(nn.Module): 143 | def __init__(self, dim, dim_out, groups=32, dropout=0): 144 | super().__init__() 145 | self.block = nn.Sequential( 146 | nn.GroupNorm(groups, dim), 147 | Swish(), 148 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(), 149 | nn.Conv2d(dim, dim_out, 3, padding=1) 150 | ) 151 | 152 | def forward(self, x): 153 | return self.block(x) 154 | 155 | class SelfAttention(nn.Module): 156 | def __init__(self, in_channel, n_head=1, norm_groups=32): 157 | super().__init__() 158 | 159 | self.n_head = n_head 160 | 161 | self.norm = nn.GroupNorm(norm_groups, in_channel) 162 | self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False) 163 | self.out = nn.Conv2d(in_channel, in_channel, 1) 164 | 165 | def forward(self, input): 166 | batch, channel, height, width = input.shape 167 | n_head = self.n_head 168 | head_dim = channel // n_head 169 | 170 | norm = self.norm(input) 171 | qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width) 172 | query, key, value = qkv.chunk(3, dim=2) # bhdyx 173 | 174 | attn = torch.einsum( 175 | "bnchw, bncyx -> bnhwyx", query, key 176 | ).contiguous() / math.sqrt(channel) 177 | attn = attn.view(batch, n_head, height, width, -1) 178 | attn = torch.softmax(attn, -1) 179 | attn = attn.view(batch, n_head, height, width, height, width) 180 | 181 | out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous() 182 | out = self.out(out.view(batch, channel, height, width)) 183 | 184 | return out + input 185 | 186 | class CP_Up_FP(nn.Module): 187 | def __init__(self, dim): 188 | super().__init__() 189 | self.biconv1 = nn.Conv2d(dim, dim, 3, 1, 1) 190 | self.biconv2 = nn.Conv2d(dim, dim, 3, 1, 1) 191 | self.up = nn.PixelShuffle(2) 192 | 193 | def forward(self, x): 194 | ''' 195 | input: b,c,h,w 196 | output: b,c/2,h*2,w*2 197 | ''' 198 | out1 = self.biconv1(x) 199 | out2 = self.biconv2(x) 200 | out = torch.cat([out1, out2], dim=1) 201 | out = self.up(out) 202 | return out 203 | 204 | class CP_Down_FP(nn.Module): 205 | def __init__(self, dim): 206 | super().__init__() 207 | self.biconv1 = nn.Conv2d(dim//2, dim//2, 3, padding=1) 208 | self.biconv2 = nn.Conv2d(dim//2, dim//2, 3, padding=1) 209 | self.down = nn.PixelUnshuffle(2) 210 | 211 | def forward(self, x): 212 | ''' 213 | input: b,c,h,w 214 | output: b,2c,h/2,w/2 215 | ''' 216 | b,c,h,w = x.shape 217 | out1 = self.biconv1(x[:,:c//2,:,:]) 218 | out2 = self.biconv2(x[:,c//2:,:,:]) 219 | out 
= out1 + out2 220 | out = self.down(out) 221 | return out 222 | 223 | class CS_Fusion_FP(nn.Module): 224 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False, groups=1): 225 | super(CS_Fusion_FP, self).__init__() 226 | 227 | assert in_channels // 2 == out_channels, f"Error: input ({in_channels}) and output ({out_channels}) channel dimension." 228 | 229 | self.biconv_1 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, groups=groups) 230 | self.biconv_2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, groups=groups) 231 | 232 | def forward(self, x): 233 | ''' 234 | x: b,c,h,w 235 | out: b,c/2,h,w 236 | ''' 237 | b,c,h,w = x.shape 238 | in_1 = x[:,:c//2,:,:] 239 | in_2 = x[:,c//2:,:,:] 240 | 241 | fu_1 = torch.cat((in_1[:, 1::2, :, :], in_2[:, 0::2, :, :]), dim=1) 242 | fu_2 = torch.cat((in_1[:, 0::2, :, :], in_2[:, 1::2, :, :]), dim=1) 243 | 244 | out_1 = self.biconv_1(fu_1) 245 | out_2 = self.biconv_2(fu_2) 246 | 247 | out = out_1 + out_2 248 | return out 249 | 250 | class ResnetBlock(nn.Module): 251 | def __init__(self, dim, dim_out, timestep_level_emb_dim=None, dropout=0, norm_groups=32): 252 | super().__init__() 253 | self.timestep_func = nn.Sequential( 254 | Swish(), 255 | nn.Linear(timestep_level_emb_dim, dim_out) 256 | ) 257 | 258 | assert dim == dim_out, f"Error: input ({dim}) and output ({dim_out}) channel dimension." 259 | 260 | self.block1 = Block(dim, dim_out, groups=norm_groups) 261 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout) 262 | if dim == dim_out: 263 | self.res_conv = nn.Identity() 264 | 265 | def forward(self, x, time_emb): 266 | b, c, h, w = x.shape 267 | h = self.block1(x) 268 | t_emb = self.timestep_func(time_emb).type(h.dtype) 269 | h = h + t_emb[:, :, None, None] 270 | h = self.block2(h) 271 | return h + self.res_conv(x) 272 | 273 | class ResnetBlocWithAttn(nn.Module): 274 | def __init__(self, dim, dim_out, *, timestep_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False): 275 | super().__init__() 276 | self.with_attn = with_attn 277 | self.res_block = ResnetBlock( 278 | dim, dim_out, timestep_level_emb_dim, norm_groups=norm_groups, dropout=dropout) 279 | if with_attn: 280 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups) 281 | 282 | def forward(self, x, time_emb): 283 | x = self.res_block(x, time_emb) 284 | if(self.with_attn): 285 | x = self.attn(x) 286 | return x 287 | # -------------------------------------------------- FP Module: END ---------------------------------------------------------------------- 288 | 289 | # -------------------------------------------------- BI Module: START -------------------------------------------------------------------- 290 | class CP_Up(nn.Module): 291 | def __init__(self, dim, dynamic_group=5): 292 | super().__init__() 293 | self.biconv1 = BIConv(dim, dim, 3, 1, 1, dynamic_group=dynamic_group) 294 | self.biconv2 = BIConv(dim, dim, 3, 1, 1, dynamic_group=dynamic_group) 295 | self.up = nn.PixelShuffle(2) 296 | 297 | def forward(self, x, t): 298 | ''' 299 | input: b,c,h,w 300 | output: b,c/2,h*2,w*2 301 | ''' 302 | out1 = self.biconv1(x, t) 303 | out2 = self.biconv2(x, t) 304 | out = torch.cat([out1, out2], dim=1) 305 | out = self.up(out) 306 | return out 307 | 308 | class CP_Down(nn.Module): 309 | def __init__(self, dim, dynamic_group=5): 310 | super().__init__() 311 | self.biconv1 = BIConv(dim//2, dim//2, 3, padding=1, dynamic_group=dynamic_group) 312 | self.biconv2 = BIConv(dim//2, dim//2, 3, 
padding=1, dynamic_group=dynamic_group) 313 | self.down = nn.PixelUnshuffle(2) 314 | 315 | def forward(self, x, t): 316 | ''' 317 | input: b,c,h,w 318 | output: b,2c,h/2,w/2 319 | ''' 320 | b,c,h,w = x.shape 321 | out1 = self.biconv1(x[:,:c//2,:,:], t) 322 | out2 = self.biconv2(x[:,c//2:,:,:], t) 323 | out = out1 + out2 324 | out = self.down(out) 325 | return out 326 | 327 | class CS_Fusion(nn.Module): 328 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False, groups=1, dynamic_group=5): 329 | super(CS_Fusion, self).__init__() 330 | 331 | assert in_channels // 2 == out_channels, f"Error: input ({in_channels}) and output ({out_channels}) channel dimension." 332 | 333 | self.biconv_1 = BIConv(out_channels, out_channels, kernel_size, stride, padding, bias, groups, dynamic_group=dynamic_group) 334 | self.biconv_2 = BIConv(out_channels, out_channels, kernel_size, stride, padding, bias, groups, dynamic_group=dynamic_group) 335 | 336 | def forward(self, x, t): 337 | ''' 338 | x: b,c,h,w 339 | out: b,c/2,h,w 340 | ''' 341 | b,c,h,w = x.shape 342 | in_1 = x[:,:c//2,:,:] 343 | in_2 = x[:,c//2:,:,:] 344 | 345 | fu_1 = torch.cat((in_1[:, 1::2, :, :], in_2[:, 0::2, :, :]), dim=1) 346 | fu_2 = torch.cat((in_1[:, 0::2, :, :], in_2[:, 1::2, :, :]), dim=1) 347 | 348 | out_1 = self.biconv_1(fu_1, t) 349 | out_2 = self.biconv_2(fu_2, t) 350 | 351 | out = out_1 + out_2 352 | return out 353 | 354 | class BI_Block(nn.Module): 355 | def __init__(self, dim, dim_out, groups=32, dropout=0, dynamic_group=5): 356 | super().__init__() 357 | self.block = nn.Sequential( 358 | nn.GroupNorm(groups, dim), 359 | Swish(), 360 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(), 361 | ) 362 | 363 | assert dim == dim_out, f"Error: input ({dim}) and output ({dim_out}) channel dimension." 364 | 365 | if dim == dim_out: 366 | self.conv = BIConv(dim, dim_out, 3, padding=1, dynamic_group=dynamic_group) 367 | 368 | def forward(self, x, t): 369 | return self.conv(self.block(x), t) 370 | 371 | class BI_ResnetBlock(nn.Module): 372 | def __init__(self, dim, dim_out, timestep_level_emb_dim=None, dropout=0, norm_groups=32, dynamic_group=5): 373 | super().__init__() 374 | self.timestep_func = nn.Sequential( 375 | Swish(), 376 | nn.Linear(timestep_level_emb_dim, dim_out) 377 | ) 378 | 379 | assert dim == dim_out, f"Error: input ({dim}) and output ({dim_out}) channel dimension." 
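# dim == dim_out is required so the residual shortcut (self.res_conv below) can remain nn.Identity().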
380 | 381 | self.block1 = BI_Block(dim, dim_out, groups=norm_groups, dynamic_group=dynamic_group) 382 | self.block2 = BI_Block(dim_out, dim_out, groups=norm_groups, dropout=dropout, dynamic_group=dynamic_group) 383 | if dim == dim_out: 384 | self.res_conv = nn.Identity() 385 | 386 | def forward(self, x, time_emb, t): 387 | b, c, h, w = x.shape 388 | h = self.block1(x, t) 389 | t_emb = self.timestep_func(time_emb).type(h.dtype) 390 | h = h + t_emb[:, :, None, None] 391 | h = self.block2(h, t) 392 | return h + self.res_conv(x) 393 | 394 | class BI_ResnetBlocWithAttn(nn.Module): 395 | def __init__(self, dim, dim_out, *, timestep_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False, dynamic_group=5): 396 | super().__init__() 397 | self.with_attn = with_attn 398 | self.res_block = BI_ResnetBlock( 399 | dim, dim_out, timestep_level_emb_dim, norm_groups=norm_groups, dropout=dropout, dynamic_group=dynamic_group) 400 | if with_attn: 401 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups) 402 | 403 | def forward(self, x, time_emb, t): 404 | x = self.res_block(x, time_emb, t) 405 | if(self.with_attn): 406 | x = self.attn(x) 407 | return x 408 | # -------------------------------------------------- BI Module: END -------------------------------------------------------------------- 409 | 410 | # ----------------------------------------------- BI-DiffSR UNet: START ---------------------------------------------------------------- 411 | @ARCH_REGISTRY.register() 412 | class BIDiffSRUNet(nn.Module): 413 | def __init__( 414 | self, 415 | in_channel=6, 416 | out_channel=3, 417 | inner_channel=32, 418 | norm_groups=32, 419 | channel_mults=(1, 2, 4, 8, 8), 420 | attn_res=(8), 421 | res_blocks=3, 422 | dropout=0, 423 | image_size=128, 424 | fp_res=(0), 425 | total_step=2000, 426 | dynamic_group=5 427 | ): 428 | super(BIDiffSRUNet, self).__init__() 429 | 430 | self.in_channel = in_channel 431 | self.total_step = total_step 432 | self.dynamic_group = dynamic_group 433 | 434 | timestep_level_channel = inner_channel 435 | self.timestep_level_mlp = nn.Sequential( 436 | PositionalEncoding(inner_channel), 437 | nn.Linear(inner_channel, inner_channel * 4), 438 | Swish(), 439 | nn.Linear(inner_channel * 4, inner_channel) 440 | ) 441 | 442 | num_mults = len(channel_mults) 443 | pre_channel = inner_channel 444 | feat_channels = [pre_channel] 445 | now_res = image_size 446 | downs = [nn.Conv2d(in_channel, inner_channel, 447 | kernel_size=3, padding=1)] 448 | for ind in range(num_mults): 449 | is_last = (ind == num_mults - 1) 450 | use_attn = (now_res in attn_res) 451 | channel_mult = inner_channel * channel_mults[ind] 452 | for _ in range(0, res_blocks): 453 | downs.append(BI_ResnetBlocWithAttn( 454 | pre_channel, channel_mult, timestep_level_emb_dim=timestep_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, dynamic_group=dynamic_group)) 455 | feat_channels.append(channel_mult) 456 | pre_channel = channel_mult 457 | if not is_last: 458 | downs.append(CP_Down(pre_channel, dynamic_group=dynamic_group)) 459 | now_res = now_res//2 460 | pre_channel = pre_channel*2 461 | feat_channels.append(pre_channel) 462 | self.downs = nn.ModuleList(downs) 463 | 464 | self.mid = nn.ModuleList([ 465 | BI_ResnetBlocWithAttn(pre_channel, pre_channel, timestep_level_emb_dim=timestep_level_channel, norm_groups=norm_groups, 466 | dropout=dropout, with_attn=True, dynamic_group=dynamic_group), 467 | BI_ResnetBlocWithAttn(pre_channel, pre_channel, timestep_level_emb_dim=timestep_level_channel, 
norm_groups=norm_groups, 468 | dropout=dropout, with_attn=False, dynamic_group=dynamic_group) 469 | ]) 470 | 471 | ups = [] 472 | for ind in reversed(range(num_mults)): 473 | is_last = (ind < 1) 474 | use_attn = (now_res in attn_res) 475 | use_fp= (now_res in fp_res) 476 | channel_mult = inner_channel * channel_mults[ind] 477 | for _ in range(0, res_blocks+1): 478 | if use_fp: 479 | ups.append(CS_Fusion_FP(pre_channel+feat_channels.pop(), channel_mult, kernel_size=1, stride=1, padding=0, bias=False, groups=1)) 480 | ups.append(ResnetBlocWithAttn( 481 | channel_mult, channel_mult, timestep_level_emb_dim=timestep_level_channel, norm_groups=norm_groups, 482 | dropout=dropout, with_attn=use_attn)) 483 | else: 484 | ups.append(CS_Fusion(pre_channel+feat_channels.pop(), channel_mult, kernel_size=1, stride=1, padding=0, bias=False, groups=1, dynamic_group=dynamic_group)) 485 | ups.append(BI_ResnetBlocWithAttn( 486 | channel_mult, channel_mult, timestep_level_emb_dim=timestep_level_channel, norm_groups=norm_groups, 487 | dropout=dropout, with_attn=use_attn, dynamic_group=dynamic_group)) 488 | pre_channel = channel_mult 489 | if not is_last: 490 | ups.append(CP_Up(pre_channel, dynamic_group=dynamic_group)) 491 | now_res = now_res*2 492 | pre_channel = pre_channel//2 493 | 494 | self.ups = nn.ModuleList(ups) 495 | 496 | self.final_conv = Block_F(pre_channel, default(out_channel, in_channel), groups=norm_groups) 497 | 498 | def forward(self, x, c, time): 499 | index_dynamic = int(time[0][0] * self.dynamic_group / self.total_step) 500 | index_dynamic = max(0, min(index_dynamic, self.dynamic_group - 1)) 501 | 502 | time = time.squeeze(1) # consistent with the original code 503 | 504 | if self.in_channel != 3: 505 | x = torch.cat([c, x], dim=1) 506 | t = self.timestep_level_mlp(time) if exists( 507 | self.timestep_level_mlp) else None 508 | 509 | feats = [] 510 | for layer in self.downs: 511 | if isinstance(layer, BI_ResnetBlocWithAttn): 512 | x = layer(x, t, index_dynamic) 513 | elif isinstance(layer, ResnetBlocWithAttn): 514 | x = layer(x, t) 515 | elif isinstance(layer, CP_Down): 516 | x = layer(x, index_dynamic) 517 | else: 518 | x = layer(x) 519 | feats.append(x) 520 | 521 | for layer in self.mid: 522 | if isinstance(layer, BI_ResnetBlocWithAttn): 523 | x = layer(x, t, index_dynamic) 524 | else: 525 | x = layer(x) 526 | 527 | for layer in self.ups: 528 | if isinstance(layer, CS_Fusion): 529 | x = layer(torch.cat((x, feats.pop()), dim=1), index_dynamic) 530 | elif isinstance(layer, CS_Fusion_FP): 531 | x = layer(torch.cat((x, feats.pop()), dim=1)) 532 | elif isinstance(layer, BI_ResnetBlocWithAttn): 533 | x = layer(x, t, index_dynamic) 534 | elif isinstance(layer, ResnetBlocWithAttn): 535 | x = layer(x, t) 536 | elif isinstance(layer, CP_Up): 537 | x = layer(x, index_dynamic) 538 | else: 539 | x = layer(x) 540 | 541 | return self.final_conv(x) 542 | 543 | 544 | if __name__ == '__main__': 545 | model = BIDiffSRUNet( 546 | in_channel = 6, 547 | out_channel = 3, 548 | inner_channel = 64, 549 | norm_groups = 16, 550 | channel_mults = [1, 2, 4, 8], 551 | attn_res = [], 552 | res_blocks = 2, 553 | dropout = 0.2, 554 | image_size = 256, 555 | fp_res= [256, 128], 556 | dynamic_group=5 557 | ) 558 | print(model) 559 | 560 | x = torch.randn((2, 3, 128, 128)) 561 | c = torch.randn((2, 3, 128, 128)) 562 | timesteps = torch.randint(0, 10, (2,)).long().unsqueeze(1) 563 | x = model(x, c, timesteps) 564 | print(x.shape) 565 | print(sum(map(lambda x: x.numel(), model.parameters()))) 
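# Note on the timestep-aware (TaR/TaA) branching above: BIDiffSRUNet.forward maps
# the diffusion timestep t to one of `dynamic_group` branch indices via
#     index_dynamic = int(t * dynamic_group / total_step), clamped to [0, dynamic_group - 1].
# With the defaults total_step=2000 and dynamic_group=5:
#     t = 0..399     -> branch 0   (int(399 * 5 / 2000) = 0)
#     t = 400..799   -> branch 1   (int(400 * 5 / 2000) = 1)
#     ...
#     t = 1600..1999 -> branch 4
# so each BIConv keeps a separate LearnableBias (TaR) / RPReLU (TaA) pair for each
# contiguous timestep range.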
-------------------------------------------------------------------------------- /diffglv/data/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import numpy as np 3 | import random 4 | import torch 5 | import torch.utils.data 6 | from copy import deepcopy 7 | from functools import partial 8 | from os import path as osp 9 | 10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader 11 | from basicsr.utils import get_root_logger, scandir 12 | from basicsr.utils.dist_util import get_dist_info 13 | from basicsr.utils.registry import DATASET_REGISTRY 14 | 15 | __all__ = ['build_dataset', 'build_dataloader', 'build_dataset_generate'] 16 | 17 | # automatically scan and import dataset modules for registry 18 | # scan all the files under the data folder with '_dataset' in file names 19 | data_folder = osp.dirname(osp.abspath(__file__)) 20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')] 21 | # import all the dataset modules 22 | _dataset_modules = [importlib.import_module(f'diffglv.data.{file_name}') for file_name in dataset_filenames] 23 | 24 | 25 | def build_dataset_generate(dataset_opt, opt): 26 | """Build dataset from options. 27 | 28 | Args: 29 | dataset_opt (dict): Configuration for dataset. It must contain: 30 | name (str): Dataset name. 31 | type (str): Dataset type. 32 | """ 33 | dataset_opt = deepcopy(dataset_opt) 34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt, opt) 35 | logger = get_root_logger() 36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') 37 | return dataset 38 | 39 | 40 | def build_dataset(dataset_opt): 41 | """Build dataset from options. 42 | 43 | Args: 44 | dataset_opt (dict): Configuration for dataset. It must contain: 45 | name (str): Dataset name. 46 | type (str): Dataset type. 47 | """ 48 | dataset_opt = deepcopy(dataset_opt) 49 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt) 50 | logger = get_root_logger() 51 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.') 52 | return dataset 53 | 54 | 55 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None): 56 | """Build dataloader. 57 | 58 | Args: 59 | dataset (torch.utils.data.Dataset): Dataset. 60 | dataset_opt (dict): Dataset options. It contains the following keys: 61 | phase (str): 'train' or 'val'. 62 | num_worker_per_gpu (int): Number of workers for each GPU. 63 | batch_size_per_gpu (int): Training batch size for each GPU. 64 | num_gpu (int): Number of GPUs. Used only in the train phase. 65 | Default: 1. 66 | dist (bool): Whether in distributed training. Used only in the train 67 | phase. Default: False. 68 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 69 | seed (int | None): Seed. 
Default: None 70 | """ 71 | phase = dataset_opt['phase'] 72 | rank, _ = get_dist_info() 73 | if phase == 'train': 74 | if dist: # distributed training 75 | batch_size = dataset_opt['batch_size_per_gpu'] 76 | num_workers = dataset_opt['num_worker_per_gpu'] 77 | else: # non-distributed training 78 | multiplier = 1 if num_gpu == 0 else num_gpu 79 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 80 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 81 | dataloader_args = dict( 82 | dataset=dataset, 83 | batch_size=batch_size, 84 | shuffle=False, 85 | num_workers=num_workers, 86 | sampler=sampler, 87 | drop_last=True) 88 | if sampler is None: 89 | dataloader_args['shuffle'] = True 90 | dataloader_args['worker_init_fn'] = partial( 91 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None 92 | elif phase in ['val', 'test']: # validation 93 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 94 | else: 95 | raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.") 96 | 97 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 98 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False) 99 | 100 | prefetch_mode = dataset_opt.get('prefetch_mode') 101 | if prefetch_mode == 'cpu': # CPUPrefetcher 102 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 103 | logger = get_root_logger() 104 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}') 105 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args) 106 | else: 107 | # prefetch_mode=None: Normal dataloader 108 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 109 | return torch.utils.data.DataLoader(**dataloader_args) 110 | 111 | 112 | def worker_init_fn(worker_id, num_workers, rank, seed): 113 | # Set the worker seed to num_workers * rank + worker_id + seed 114 | worker_seed = num_workers * rank + worker_id + seed 115 | np.random.seed(worker_seed) 116 | random.seed(worker_seed) 117 | -------------------------------------------------------------------------------- /diffglv/data/data_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import torch 4 | 5 | 6 | def paired_center_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None): 7 | """Paired random crop. Support Numpy array and Tensor inputs. 8 | 9 | It crops lists of lq and gt images with corresponding locations. 10 | 11 | Args: 12 | img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images 13 | should have the same shape. If the input is an ndarray, it will 14 | be transformed to a list containing itself. 15 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 16 | should have the same shape. If the input is an ndarray, it will 17 | be transformed to a list containing itself. 18 | gt_patch_size (int): GT patch size. 19 | scale (int): Scale factor. 20 | gt_path (str): Path to ground-truth. Default: None. 21 | 22 | Returns: 23 | list[ndarray] | ndarray: GT images and LQ images. If returned results 24 | only have one element, just return ndarray. 
25 | """ 26 | 27 | if not isinstance(img_gts, list): 28 | img_gts = [img_gts] 29 | if not isinstance(img_lqs, list): 30 | img_lqs = [img_lqs] 31 | 32 | # determine input type: Numpy array or Tensor 33 | input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' 34 | 35 | if input_type == 'Tensor': 36 | h_lq, w_lq = img_lqs[0].size()[-2:] 37 | h_gt, w_gt = img_gts[0].size()[-2:] 38 | else: 39 | h_lq, w_lq = img_lqs[0].shape[0:2] 40 | h_gt, w_gt = img_gts[0].shape[0:2] 41 | lq_patch_size = gt_patch_size // scale 42 | 43 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 44 | raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', 45 | f'multiplication of LQ ({h_lq}, {w_lq}).') 46 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 47 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 48 | f'({lq_patch_size}, {lq_patch_size}). ' 49 | f'Please remove {gt_path}.') 50 | 51 | top = (h_lq - lq_patch_size) // 2 52 | left =(w_lq - lq_patch_size) // 2 53 | 54 | # crop lq patch 55 | if input_type == 'Tensor': 56 | img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs] 57 | else: 58 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs] 59 | 60 | # crop corresponding gt patch 61 | top_gt, left_gt = int(top * scale), int(left * scale) 62 | if input_type == 'Tensor': 63 | img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts] 64 | else: 65 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts] 66 | if len(img_gts) == 1: 67 | img_gts = img_gts[0] 68 | if len(img_lqs) == 1: 69 | img_lqs = img_lqs[0] 70 | return img_gts, img_lqs 71 | 72 | 73 | def random_crop(img_gts, gt_patch_size, gt_path=None): 74 | if not isinstance(img_gts, list): 75 | img_gts = [img_gts] 76 | 77 | 78 | # determine input type: Numpy array or Tensor 79 | input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy' 80 | 81 | if input_type == 'Tensor': 82 | h_gt, w_gt = img_gts[0].size()[-2:] 83 | else: 84 | h_gt, w_gt = img_gts[0].shape[0:2] 85 | 86 | if h_gt < gt_patch_size or w_gt < gt_patch_size: 87 | raise ValueError(f' ({h_gt}, {w_gt}) is smaller than patch size ' 88 | f'({gt_patch_size}, {gt_patch_size}). ' 89 | f'Please remove {gt_path}.') 90 | 91 | top = random.randint(0, h_gt - gt_patch_size) 92 | left = random.randint(0, w_gt - gt_patch_size) 93 | 94 | 95 | if input_type == 'Tensor': 96 | img_gts = [v[:, :, top:top + gt_patch_size, left:left + gt_patch_size] for v in img_gts] 97 | else: 98 | img_gts = [v[top:top + gt_patch_size, left:left + gt_patch_size, ...] 
for v in img_gts] 99 | if len(img_gts) == 1: 100 | img_gts = img_gts[0] 101 | 102 | return img_gts -------------------------------------------------------------------------------- /diffglv/data/paired_image_dataset.py: -------------------------------------------------------------------------------- 1 | from torch.utils import data as data 2 | from torchvision.transforms.functional import normalize 3 | 4 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file 5 | from basicsr.data.transforms import augment 6 | from basicsr.utils import FileClient, imfrombytes, img2tensor 7 | # from basicsr.utils.matlab_functions import rgb2ycbcr 8 | from basicsr.utils.registry import DATASET_REGISTRY 9 | from diffglv.data.data_util import paired_center_crop 10 | from diffglv.utils.transforms import paired_random_crop 11 | 12 | import numpy as np 13 | import cv2 14 | import random 15 | 16 | @DATASET_REGISTRY.register() 17 | class MultiPairedImageDataset(data.Dataset): 18 | """Paired image dataset for image restoration. 19 | 20 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs. 21 | 22 | There are three modes: 23 | 1. 'lmdb': Use lmdb files. 24 | If opt['io_backend'] == lmdb. 25 | 2. 'meta_info_file': Use meta information file to generate paths. 26 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None. 27 | 3. 'folder': Scan folders to generate paths. 28 | The rest. 29 | 30 | Args: 31 | opt (dict): Config for train datasets. It contains the following keys: 32 | dataroot_gt (str): Data root path for gt. 33 | dataroot_lq (str): Data root path for lq. 34 | meta_info_file (str): Path for meta information file. 35 | io_backend (dict): IO backend type and other kwarg. 36 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension. 37 | Default: '{}'. 38 | gt_size (int): Cropped patched size for gt patches. 39 | use_hflip (bool): Use horizontal flips. 40 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation). 41 | 42 | scale (bool): Scale, which will be added automatically. 43 | phase (str): 'train' or 'val'. 
44 | """ 45 | 46 | def __init__(self, opt): 47 | super(MultiPairedImageDataset, self).__init__() 48 | self.opt = opt 49 | # file client (io backend) 50 | self.file_client = None 51 | self.io_backend_opt = opt['io_backend'] 52 | self.mean = opt['mean'] if 'mean' in opt else None 53 | self.std = opt['std'] if 'std' in opt else None 54 | self.task = opt['task'] if 'task' in opt else None 55 | self.noise = opt['noise'] if 'noise' in opt else 0 56 | 57 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq'] 58 | 59 | if 'filename_tmpl' in opt: 60 | self.filename_tmpl = opt['filename_tmpl'] 61 | else: 62 | self.filename_tmpl = '{}' 63 | 64 | if self.io_backend_opt['type'] == 'lmdb': 65 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder] 66 | self.io_backend_opt['client_keys'] = ['lq', 'gt'] 67 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt']) 68 | elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None: 69 | self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'], 70 | self.opt['meta_info_file'], self.filename_tmpl) 71 | else: 72 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl) 73 | 74 | def __getitem__(self, index): 75 | if self.file_client is None: 76 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt) 77 | 78 | scale = self.opt['scale'] 79 | 80 | # Load gt and lq images. Dimension order: HWC; channel order: BGR; 81 | 82 | if self.task == 'CAR': 83 | # image range: [0, 255], int., H W 1 84 | 85 | gt_path = self.paths[index]['gt_path'] 86 | img_bytes = self.file_client.get(gt_path, 'gt') 87 | img_gt = imfrombytes(img_bytes, flag='grayscale', float32=False) 88 | lq_path = self.paths[index]['lq_path'] 89 | img_bytes = self.file_client.get(lq_path, 'lq') 90 | img_lq = imfrombytes(img_bytes, flag='grayscale', float32=False) 91 | img_gt = np.expand_dims(img_gt, axis=2).astype(np.float32) / 255. 92 | img_lq = np.expand_dims(img_lq, axis=2).astype(np.float32) / 255. 
93 | 94 | # gt_path = self.paths[index]['gt_path'] 95 | # img_bytes = self.file_client.get(gt_path, 'gt') 96 | # img_gt = imfrombytes(img_bytes, float32=False) 97 | # lq_path = self.paths[index]['lq_path'] 98 | # img_bytes = self.file_client.get(lq_path, 'lq') 99 | # img_lq = imfrombytes(img_bytes, float32=False) 100 | # img_gt = img_gt[:,:,0,None] 101 | # img_lq = img_lq[:,:,0,None] 102 | 103 | elif self.task == 'Color Denoising': 104 | gt_path = self.paths[index]['gt_path'] 105 | lq_path = gt_path 106 | img_bytes = self.file_client.get(gt_path, 'gt') 107 | img_gt = imfrombytes(img_bytes, float32=True) 108 | if self.opt['phase'] != 'train': 109 | np.random.seed(seed=0) 110 | img_lq = img_gt + np.random.normal(0, self.noise/255., img_gt.shape) 111 | 112 | elif self.task == 'SR': 113 | # image range: [0, 1], float32., H W 3 114 | gt_path = self.paths[index]['gt_path'] 115 | img_bytes = self.file_client.get(gt_path, 'gt') 116 | img_gt = imfrombytes(img_bytes, float32=True) 117 | lq_path = self.paths[index]['lq_path'] 118 | img_bytes = self.file_client.get(lq_path, 'lq') 119 | img_lq = imfrombytes(img_bytes, float32=True) 120 | # bicubic 121 | img_lq = cv2.resize(img_lq, (img_lq.shape[1]*scale, img_lq.shape[0]*scale), interpolation=cv2.INTER_CUBIC) 122 | 123 | else: 124 | # image range: [0, 1], float32., H W 3 125 | gt_path = self.paths[index]['gt_path'] 126 | img_bytes = self.file_client.get(gt_path, 'gt') 127 | img_gt = imfrombytes(img_bytes, float32=True) 128 | lq_path = self.paths[index]['lq_path'] 129 | img_bytes = self.file_client.get(lq_path, 'lq') 130 | img_lq = imfrombytes(img_bytes, float32=True) 131 | 132 | scale = 1 133 | # augmentation for training 134 | if self.opt['phase'] == 'train': 135 | gt_size = self.opt['gt_size'] 136 | # random crop 137 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path) 138 | # flip, rotation 139 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot']) 140 | 141 | # color space transform 142 | if 'color' in self.opt and self.opt['color'] == 'y': 143 | print('Wrong: TODO') 144 | exit() 145 | # img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None] 146 | # img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None] 147 | else: 148 | # for val set 149 | if 'gt_size' in self.opt: 150 | gt_size = self.opt['gt_size'] 151 | img_gt, img_lq = paired_center_crop(img_gt, img_lq, gt_size, scale, gt_path) 152 | 153 | # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets 154 | # TODO: It is better to update the datasets, rather than force to crop 155 | if self.opt['phase'] != 'train': 156 | img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :] 157 | 158 | # BGR to RGB, HWC to CHW, numpy to tensor 159 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True) 160 | # normalize 161 | if self.mean is not None or self.std is not None: 162 | normalize(img_lq, self.mean, self.std, inplace=True) 163 | normalize(img_gt, self.mean, self.std, inplace=True) 164 | 165 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path, 'task': self.task} 166 | 167 | def __len__(self): 168 | return len(self.paths) 169 | -------------------------------------------------------------------------------- /diffglv/losses/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 
| from basicsr.utils.registry import LOSS_REGISTRY
 7 | from basicsr.losses.gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty
 8 | 
 9 | __all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize']
10 | 
11 | # automatically scan and import loss modules for registry
12 | # scan all the files under the 'losses' folder and collect files ending with '_loss.py'
13 | loss_folder = osp.dirname(osp.abspath(__file__))
14 | loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
15 | # import all the loss modules
16 | _model_modules = [importlib.import_module(f'diffglv.losses.{file_name}') for file_name in loss_filenames]
17 | 
18 | 
19 | def build_loss(opt):
20 |     """Build loss from options.
21 | 
22 |     Args:
23 |         opt (dict): Configuration. It must contain:
24 |             type (str): Loss type.
25 |     """
26 |     opt = deepcopy(opt)
27 |     loss_type = opt.pop('type')
28 |     loss = LOSS_REGISTRY.get(loss_type)(**opt)
29 |     logger = get_root_logger()
30 |     logger.info(f'Loss [{loss.__class__.__name__}] is created.')
31 |     return loss
32 | 
--------------------------------------------------------------------------------
/diffglv/losses/at_loss.py:
--------------------------------------------------------------------------------
 1 | import torch
 2 | from torch import nn as nn
 3 | from torch.nn import functional as F
 4 | from basicsr.utils.registry import LOSS_REGISTRY
 5 | 
 6 | 
 7 | @LOSS_REGISTRY.register()
 8 | class ATLoss(nn.Module):
 9 | 
10 |     def __init__(self, loss_weight=1.0, reduction='mean'):
11 |         super(ATLoss, self).__init__()
12 |         # keep the configured weight (it was previously ignored); `reduction` is fixed to 'mean' below
13 |         self.loss_weight = loss_weight
14 | 
15 |     def forward(self, pred, target):
16 |         """
17 |         Args:
18 |             pred (Tensor): of shape (N, C, H, W). Predicted tensor.
19 |             target (Tensor): of shape (N, C, H, W). Ground truth tensor.
20 |         """
21 |         Attention_pred = F.normalize(pred.pow(2).mean(1).view(pred.size(0), -1))  # (N, H*W)
22 |         Attention_target = F.normalize(target.pow(2).mean(1).view(target.size(0), -1))  # (N, H*W)
23 | 
24 |         return self.loss_weight * F.mse_loss(Attention_pred, Attention_target)
--------------------------------------------------------------------------------
/diffglv/metrics/__init__.py:
--------------------------------------------------------------------------------
 1 | import importlib
 2 | from copy import deepcopy
 3 | from os import path as osp
 4 | 
 5 | from basicsr.utils import scandir
 6 | from basicsr.utils.registry import METRIC_REGISTRY
 7 | from basicsr.metrics.niqe import calculate_niqe
 8 | from basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim
 9 | 
10 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
11 | 
12 | metric_folder = osp.dirname(osp.abspath(__file__))
13 | metric_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(metric_folder) if v.endswith('_metric.py')]
14 | _metric_modules = [importlib.import_module(f'diffglv.metrics.{file_name}') for file_name in metric_filenames]
15 | 
16 | 
17 | def calculate_metric(data, opt):
18 |     """Calculate metric from data and options.
19 | 
20 |     Args:
21 |         opt (dict): Configuration. It must contain:
22 |             type (str): Metric type.
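    Example (a sketch; the metric name and kwargs must match a registered
    metric, here basicsr's calculate_psnr):
        data = {'img': sr_img, 'img2': gt_img}
        opt = {'type': 'calculate_psnr', 'crop_border': 4, 'test_y_channel': True}
        psnr = calculate_metric(data, opt)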
23 | """ 24 | opt = deepcopy(opt) 25 | metric_type = opt.pop('type') 26 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt) 27 | return metric 28 | -------------------------------------------------------------------------------- /diffglv/metrics/lpips.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import lpips 3 | from torch import nn 4 | 5 | def disabled_train(self: nn.Module) -> nn.Module: 6 | """Overwrite model.train with this function to make sure train/eval mode 7 | does not change anymore.""" 8 | return self 9 | 10 | def frozen_module(module: nn.Module) -> None: 11 | module.eval() 12 | module.train = disabled_train 13 | for p in module.parameters(): 14 | p.requires_grad = False 15 | 16 | class LPIPS: 17 | def __init__(self, net: str) -> None: 18 | self.model = lpips.LPIPS(net=net) 19 | frozen_module(self.model) 20 | 21 | @torch.no_grad() 22 | def __call__(self, img1: torch.Tensor, img2: torch.Tensor, normalize: bool, boundarypixels=0) -> torch.Tensor: 23 | """ 24 | Compute LPIPS. 25 | 26 | Args: 27 | img1 (torch.Tensor): The first image (NCHW, RGB, [-1, 1]). Specify `normalize` if input 28 | image is range in [0, 1]. 29 | img2 (torch.Tensor): The second image (NCHW, RGB, [-1, 1]). Specify `normalize` if input 30 | image is range in [0, 1]. 31 | normalize (bool): If specified, the input images will be normalized from [0, 1] to [-1, 1]. 32 | 33 | Returns: 34 | lpips_values (torch.Tensor): The lpips scores of this batch. 35 | """ 36 | 37 | b, c, h, w = img1.shape 38 | img1 = img1[:, :, :h-h%boundarypixels, :w-w%boundarypixels] 39 | # img1 = img1[:,:, boundarypixels:-boundarypixels,boundarypixels:-boundarypixels] 40 | b, c, h, w = img2.shape 41 | img2 = img2[:, :, :h-h%boundarypixels, :w-w%boundarypixels] 42 | # img2 = img2[:,:, boundarypixels:-boundarypixels,boundarypixels:-boundarypixels] 43 | 44 | return self.model(img1, img2, normalize=normalize) 45 | 46 | def to(self, device: str) -> "LPIPS": 47 | self.model.to(device) 48 | return self 49 | -------------------------------------------------------------------------------- /diffglv/models/BI_DiffSR_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from collections import OrderedDict 3 | from os import path as osp 4 | from tqdm import tqdm 5 | 6 | from basicsr.archs import build_network 7 | from basicsr.losses import build_loss 8 | from basicsr.metrics import calculate_metric 9 | from basicsr.utils import get_root_logger, imwrite, tensor2img, img2tensor 10 | from basicsr.utils.registry import MODEL_REGISTRY 11 | from diffglv.utils.base_model import BaseModel 12 | from torch.nn import functional as F 13 | import numpy as np 14 | from diffusers import DDPMScheduler, DDIMScheduler 15 | 16 | from diffglv.metrics.lpips import LPIPS 17 | 18 | @MODEL_REGISTRY.register() 19 | class BIDiffSRModel(BaseModel): 20 | """DiffIR model for stage two.""" 21 | 22 | def __init__(self, opt): 23 | super(BIDiffSRModel, self).__init__(opt) 24 | 25 | # define network 26 | self.net_g = build_network(opt['network_g']) 27 | self.net_g = self.model_to_device(self.net_g) 28 | self.print_network(self.net_g) 29 | 30 | # load pretrained models 31 | load_path = self.opt['path'].get('pretrain_network_g', None) 32 | if load_path is not None: 33 | param_key = self.opt['path'].get('param_key_g', 'params') 34 | self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key) 35 | 36 | # diffusion 37 | 
self.set_new_noise_schedule(self.opt['beta_schedule'], self.device) 38 | 39 | # lipis 40 | self.lpips_opt = self.opt['val']['metrics'].get('lpips', None) 41 | if self.lpips_opt != None: 42 | self.lpips_metric = LPIPS(net="alex").to(self.device) 43 | 44 | if self.is_train: 45 | self.init_training_settings() 46 | 47 | def init_training_settings(self): 48 | self.net_g.train() 49 | train_opt = self.opt['train'] 50 | 51 | self.ema_decay = train_opt.get('ema_decay', 0) 52 | if self.ema_decay > 0: 53 | logger = get_root_logger() 54 | logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}') 55 | # define network net_g with Exponential Moving Average (EMA) 56 | # net_g_ema is used only for testing on one GPU and saving 57 | # There is no need to wrap with DistributedDataParallel 58 | self.net_g_ema = build_network(self.opt['network_g']).to(self.device) 59 | # load pretrained model 60 | load_path = self.opt['path'].get('pretrain_network_g', None) 61 | if load_path is not None: 62 | self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema') 63 | else: 64 | self.model_ema(0) # copy net_g weight 65 | self.net_g_ema.eval() 66 | 67 | # define losses 68 | if train_opt.get('pixel_opt'): 69 | self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device) 70 | else: 71 | self.cri_pix = None 72 | 73 | if self.cri_pix is None: 74 | raise ValueError('pixel loss is None.') 75 | 76 | # set up optimizers and schedulers 77 | self.setup_optimizers() 78 | self.setup_schedulers() 79 | 80 | def setup_optimizers(self): 81 | train_opt = self.opt['train'] 82 | optim_params = [] 83 | for k, v in self.net_g.named_parameters(): 84 | if v.requires_grad: 85 | optim_params.append(v) 86 | else: 87 | logger = get_root_logger() 88 | logger.warning(f'Network G: Params {k} will not be optimized.') 89 | 90 | optim_type = train_opt['optim_g'].pop('type') 91 | self.optimizer = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g']) 92 | self.optimizers.append(self.optimizer) 93 | 94 | def set_new_noise_schedule(self, schedule_opt, device): 95 | scheduler_opt = self.opt['beta_schedule'] 96 | scheduler_type = scheduler_opt.get('scheduler_type', None) 97 | _prediction_type = scheduler_opt.get('prediction_type', None) 98 | if scheduler_type == 'DDPM': 99 | self.noise_scheduler = DDPMScheduler(num_train_timesteps=schedule_opt['n_timestep'], 100 | beta_start=schedule_opt['linear_start'], 101 | beta_end=schedule_opt['linear_end'], 102 | beta_schedule=schedule_opt['schedule']) 103 | elif scheduler_type == 'DDIM': 104 | self.noise_scheduler = DDIMScheduler(num_train_timesteps=schedule_opt['n_timestep'], 105 | beta_start=schedule_opt['linear_start'], 106 | beta_end=schedule_opt['linear_end'], 107 | beta_schedule=schedule_opt['schedule']) 108 | else: 109 | raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.') 110 | 111 | if _prediction_type is not None: 112 | # set prediction_type of scheduler if defined 113 | self.noise_scheduler.register_to_config(prediction_type=_prediction_type) 114 | 115 | def feed_data(self, data): 116 | self.lq = data['lq'].to(self.device) 117 | if 'gt' in data: 118 | self.gt = data['gt'].to(self.device) 119 | 120 | def optimize_parameters(self, current_iter, noise=None): 121 | self.optimizer.zero_grad() 122 | 123 | noise = torch.randn_like(self.gt).to(self.device) 124 | bsz = self.gt.shape[0] 125 | # Sample a random timestep for each image 126 | random_timestep = torch.randint(0, 
self.noise_scheduler.config.num_train_timesteps, (1,), device=self.device) 127 | timesteps = random_timestep.repeat(bsz).long() 128 | 129 | # Add noise to the latents according to the noise magnitude at each timestep 130 | noisy_image = self.noise_scheduler.add_noise(self.gt, noise, timesteps) 131 | 132 | # Get the target for loss depending on the prediction type 133 | if self.noise_scheduler.config.prediction_type == "epsilon": 134 | target = noise 135 | elif self.noise_scheduler.config.prediction_type == "v_prediction": 136 | target = self.noise_scheduler.get_velocity(self.gt, noise, timesteps) 137 | elif self.noise_scheduler.config.prediction_type == "sample": 138 | target = self.gt 139 | else: 140 | raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}") 141 | 142 | # Predict the noise residual and compute loss 143 | _timesteps = timesteps.unsqueeze(1).to(self.device) 144 | noise_pred = self.net_g(noisy_image, self.lq, _timesteps) 145 | l_total = 0 146 | loss_dict = OrderedDict() 147 | 148 | if self.cri_pix: 149 | l_pix = self.cri_pix(noise_pred, target) 150 | l_total += l_pix 151 | loss_dict['l_pix'] = l_pix 152 | 153 | l_total.backward() 154 | self.optimizer.step() 155 | 156 | self.log_dict = self.reduce_loss_dict(loss_dict) 157 | 158 | if self.ema_decay > 0: 159 | self.model_ema(decay=self.ema_decay) 160 | 161 | def test(self): 162 | scale = 1 163 | window_size = 8 164 | mod_pad_h, mod_pad_w = 0, 0 165 | _, _, h, w = self.lq.size() 166 | if h % window_size != 0: 167 | mod_pad_h = window_size - h % window_size 168 | if w % window_size != 0: 169 | mod_pad_w = window_size - w % window_size 170 | img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect') 171 | 172 | if hasattr(self, 'net_g_ema'): 173 | print("TODO") 174 | else: 175 | self.net_g.eval() 176 | 177 | is_guidance = self.opt['beta_schedule'].get('is_guidance', False) 178 | 179 | if not is_guidance: 180 | # original conditional 181 | latents = torch.randn_like(img).to(self.device) 182 | 183 | self.noise_scheduler.set_timesteps(self.opt['beta_schedule']['num_inference_steps']) 184 | 185 | for t in self.noise_scheduler.timesteps: 186 | # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes. 
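                # (Guidance is disabled in this branch, so no input expansion happens here.)
                # One reverse-diffusion step: the UNet predicts the noise residual for
                # x_t conditioned on the LQ image, then scheduler.step() applies the
                # DDPM/DDIM update rule to produce x_{t-1}.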
187 | latent_model_input = latents 188 | lq_image = img 189 | _t = t.unsqueeze(0).unsqueeze(1).to(self.device) 190 | 191 | latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t) 192 | 193 | # predict the noise residual 194 | with torch.no_grad(): 195 | noise_pred = self.net_g(latent_model_input, lq_image, _t) 196 | 197 | # compute the previous noisy sample x_t -> x_t-1 198 | latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample 199 | else: 200 | # classifier-free guidance 201 | print("TODO") 202 | 203 | self.output = latents 204 | _, _, h, w = self.output.size() 205 | self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale] 206 | 207 | self.net_g.train() 208 | 209 | def dist_validation(self, dataloader, current_iter, tb_logger, save_img): 210 | if self.opt['rank'] == 0: 211 | self.nondist_validation(dataloader, current_iter, tb_logger, save_img) 212 | 213 | def nondist_validation(self, dataloader, current_iter, tb_logger, save_img): 214 | dataset_name = dataloader.dataset.opt['name'] 215 | with_metrics = self.opt['val'].get('metrics') is not None 216 | use_pbar = self.opt['val'].get('pbar', False) 217 | 218 | if with_metrics: 219 | if not hasattr(self, 'metric_results'): # only execute in the first run 220 | self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()} 221 | # initialize the best metric results for each dataset_name (supporting multiple validation datasets) 222 | self._initialize_best_metric_results(dataset_name) 223 | # zero self.metric_results 224 | if with_metrics: 225 | self.metric_results = {metric: 0 for metric in self.metric_results} 226 | 227 | metric_data = dict() 228 | if use_pbar: 229 | pbar = tqdm(total=len(dataloader), unit='image') 230 | 231 | for idx, val_data in enumerate(dataloader): 232 | img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0] 233 | self.feed_data(val_data) 234 | self.test() 235 | 236 | visuals = self.get_current_visuals() 237 | sr_img = tensor2img([visuals['result']]) 238 | metric_data['img'] = sr_img 239 | if 'gt' in visuals: 240 | gt_img = tensor2img([visuals['gt']]) 241 | metric_data['img2'] = gt_img 242 | del self.gt 243 | 244 | # tentative for out of GPU memory 245 | del self.lq 246 | del self.output 247 | torch.cuda.empty_cache() 248 | 249 | if save_img: 250 | if self.opt['is_train']: 251 | save_img_path = osp.join(self.opt['path']['visualization'], img_name, 252 | f'{img_name}_{current_iter}.png') 253 | else: 254 | if self.opt['val']['suffix']: 255 | save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, 256 | f'{img_name}_{self.opt["val"]["suffix"]}.png') 257 | else: 258 | save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, 259 | f'{img_name}_{self.opt["name"]}.png') 260 | 261 | imwrite(sr_img, save_img_path) 262 | 263 | if with_metrics: 264 | # calculate metrics 265 | for name, opt_ in self.opt['val']['metrics'].items(): 266 | if name == 'lpips': continue 267 | self.metric_results[name] += calculate_metric(metric_data, opt_) 268 | if self.lpips_opt != None: 269 | sr_img = (img2tensor(metric_data['img']) / 255.0).unsqueeze(0).to(self.device) 270 | hq_img = (img2tensor(metric_data['img2']) / 255.0).unsqueeze(0).to(self.device) 271 | self.metric_results['lpips'] += self.lpips_metric(sr_img, hq_img, normalize=True, boundarypixels=self.lpips_opt['crop_border']).item() 272 | if use_pbar: 273 | pbar.update(1) 274 | pbar.set_description(f'Test {img_name}') 275 | if use_pbar: 276 | 
pbar.close() 277 | 278 | if with_metrics: 279 | for metric in self.metric_results.keys(): 280 | self.metric_results[metric] /= (idx + 1) 281 | # update the best metric result 282 | self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter) 283 | 284 | self._log_validation_metric_values(current_iter, dataset_name, tb_logger) 285 | 286 | def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger): 287 | log_str = f'Validation {dataset_name}\n' 288 | for metric, value in self.metric_results.items(): 289 | log_str += f'\t # {metric}: {value:.4f}' 290 | if hasattr(self, 'best_metric_results'): 291 | log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ ' 292 | f'{self.best_metric_results[dataset_name][metric]["iter"]} iter') 293 | log_str += '\n' 294 | 295 | logger = get_root_logger() 296 | logger.info(log_str) 297 | if tb_logger: 298 | for metric, value in self.metric_results.items(): 299 | tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter) 300 | 301 | def get_current_visuals(self): 302 | out_dict = OrderedDict() 303 | out_dict['lq'] = self.lq.detach().cpu() 304 | out_dict['result'] = self.output.detach().cpu() 305 | if hasattr(self, 'gt'): 306 | out_dict['gt'] = self.gt.detach().cpu() 307 | return out_dict 308 | 309 | def save(self, epoch, current_iter): 310 | if hasattr(self, 'net_g_ema'): 311 | self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema']) 312 | else: 313 | self.save_network(self.net_g, 'net_g', current_iter) 314 | self.save_training_state(epoch, current_iter) 315 | -------------------------------------------------------------------------------- /diffglv/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from copy import deepcopy 3 | from os import path as osp 4 | 5 | from basicsr.utils import get_root_logger, scandir 6 | from basicsr.utils.registry import MODEL_REGISTRY 7 | 8 | __all__ = ['build_model'] 9 | 10 | # automatically scan and import model modules for registry 11 | # scan all the files under the 'models' folder and collect files ending with '_model.py' 12 | model_folder = osp.dirname(osp.abspath(__file__)) 13 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')] 14 | # import all the model modules 15 | _model_modules = [importlib.import_module(f'diffglv.models.{file_name}') for file_name in model_filenames] 16 | 17 | 18 | def build_model(opt): 19 | """Build model from options. 20 | 21 | Args: 22 | opt (dict): Configuration. It must contain: 23 | model_type (str): Model type. 
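    Example (illustrative; the full opt is the parsed options YAML):
        model = build_model(opt)  # e.g. opt['model_type'] == 'BIDiffSRModel'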
24 | """ 25 | opt = deepcopy(opt) 26 | model = MODEL_REGISTRY.get(opt['model_type'])(opt) 27 | logger = get_root_logger() 28 | logger.info(f'Model [{model.__class__.__name__}] is created.') 29 | return model 30 | -------------------------------------------------------------------------------- /diffglv/utils/GPU_memory.py: -------------------------------------------------------------------------------- 1 | import pynvml 2 | pynvml.nvmlInit() 3 | gpuDeviceCount = pynvml.nvmlDeviceGetCount() 4 | UNIT = 1024 * 1024 5 | for i in range(gpuDeviceCount): 6 | handle = pynvml.nvmlDeviceGetHandleByIndex(i)#获取GPU i的handle,后续通过handle来处理 7 | 8 | memoryInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)#通过handle获取GPU i的信息 9 | 10 | m_total = memoryInfo.total/UNIT 11 | m_used = memoryInfo.used/UNIT 12 | print('[%s][%s/%s]' % (i, m_used, m_total)) 13 | pynvml.nvmlShutdown() -------------------------------------------------------------------------------- /diffglv/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/diffglv/utils/__init__.py -------------------------------------------------------------------------------- /diffglv/utils/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | from collections import OrderedDict 5 | from copy import deepcopy 6 | from torch.nn.parallel import DataParallel, DistributedDataParallel 7 | 8 | from diffglv.utils import lr_scheduler as lr_scheduler 9 | from basicsr.utils import get_root_logger 10 | from basicsr.utils.dist_util import master_only 11 | import torch.nn as nn 12 | 13 | class BaseModel(nn.Module): 14 | """Base model.""" 15 | 16 | def __init__(self, opt): 17 | super().__init__() 18 | self.opt = opt 19 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 20 | self.is_train = opt['is_train'] 21 | self.schedulers = [] 22 | self.optimizers = [] 23 | 24 | def feed_data(self, data): 25 | pass 26 | 27 | def optimize_parameters(self): 28 | pass 29 | 30 | def get_current_visuals(self): 31 | pass 32 | 33 | def save(self, epoch, current_iter): 34 | """Save networks and training state.""" 35 | pass 36 | 37 | def validation(self, dataloader, current_iter, tb_logger, save_img=False): 38 | """Validation function. 39 | 40 | Args: 41 | dataloader (torch.utils.data.DataLoader): Validation dataloader. 42 | current_iter (int): Current iteration. 43 | tb_logger (tensorboard logger): Tensorboard logger. 44 | save_img (bool): Whether to save images. Default: False. 
45 | """ 46 | if self.opt['dist']: 47 | self.dist_validation(dataloader, current_iter, tb_logger, save_img) 48 | else: 49 | self.nondist_validation(dataloader, current_iter, tb_logger, save_img) 50 | 51 | def _initialize_best_metric_results(self, dataset_name): 52 | """Initialize the best metric results dict for recording the best metric value and iteration.""" 53 | if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results: 54 | return 55 | elif not hasattr(self, 'best_metric_results'): 56 | self.best_metric_results = dict() 57 | 58 | # add a dataset record 59 | record = dict() 60 | for metric, content in self.opt['val']['metrics'].items(): 61 | better = content.get('better', 'higher') 62 | init_val = float('-inf') if better == 'higher' else float('inf') 63 | record[metric] = dict(better=better, val=init_val, iter=-1) 64 | self.best_metric_results[dataset_name] = record 65 | 66 | def _update_best_metric_result(self, dataset_name, metric, val, current_iter): 67 | if self.best_metric_results[dataset_name][metric]['better'] == 'higher': 68 | if val >= self.best_metric_results[dataset_name][metric]['val']: 69 | self.best_metric_results[dataset_name][metric]['val'] = val 70 | self.best_metric_results[dataset_name][metric]['iter'] = current_iter 71 | else: 72 | if val <= self.best_metric_results[dataset_name][metric]['val']: 73 | self.best_metric_results[dataset_name][metric]['val'] = val 74 | self.best_metric_results[dataset_name][metric]['iter'] = current_iter 75 | 76 | def model_ema(self, decay=0.999): 77 | net_g = self.get_bare_model(self.net_g) 78 | 79 | net_g_params = dict(net_g.named_parameters()) 80 | net_g_ema_params = dict(self.net_g_ema.named_parameters()) 81 | 82 | for k in net_g_ema_params.keys(): 83 | net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay) 84 | 85 | def get_current_log(self): 86 | return self.log_dict 87 | 88 | def model_to_device(self, net): 89 | """Model to device. It also warps models with DistributedDataParallel 90 | or DataParallel. 
91 | 
92 |         Args:
93 |             net (nn.Module)
94 |         """
95 |         net = net.to(self.device)
96 |         if self.opt['dist']:
97 |             find_unused_parameters = self.opt.get('find_unused_parameters', False)
98 |             net = DistributedDataParallel(
99 |                 net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
100 |         elif self.opt['num_gpu'] > 1:
101 |             net = DataParallel(net)
102 |         return net
103 | 
104 |     def get_optimizer(self, optim_type, params, lr, **kwargs):
105 |         if optim_type == 'Adam':
106 |             optimizer = torch.optim.Adam(params, lr, **kwargs)
107 |         else:
108 |             raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
109 |         return optimizer
110 | 
111 |     def setup_schedulers(self):
112 |         """Set up schedulers."""
113 |         train_opt = self.opt['train']
114 |         scheduler_type = train_opt['scheduler'].pop('type')
115 |         if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
116 |             for optimizer in self.optimizers:
117 |                 self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
118 |         elif scheduler_type == 'CosineAnnealingRestartLR':
119 |             for optimizer in self.optimizers:
120 |                 self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
121 |         elif scheduler_type == 'CosineAnnealingRestartCyclicLR':
122 |             for optimizer in self.optimizers:
123 |                 self.schedulers.append(lr_scheduler.CosineAnnealingRestartCyclicLR(optimizer, **train_opt['scheduler']))
124 |         else:
125 |             raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
126 | 
127 |     def get_bare_model(self, net):
128 |         """Get bare model, especially under wrapping with
129 |         DistributedDataParallel or DataParallel.
130 |         """
131 |         if isinstance(net, (DataParallel, DistributedDataParallel)):
132 |             net = net.module
133 |         return net
134 | 
135 |     @master_only
136 |     def print_network(self, net):
137 |         """Print the str and parameter number of a network.
138 | 
139 |         Args:
140 |             net (nn.Module)
141 |         """
142 |         if isinstance(net, (DataParallel, DistributedDataParallel)):
143 |             net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
144 |         else:
145 |             net_cls_str = f'{net.__class__.__name__}'
146 | 
147 |         net = self.get_bare_model(net)
148 |         net_str = str(net)
149 |         net_params = sum(map(lambda x: x.numel(), net.parameters()))
150 | 
151 |         logger = get_root_logger()
152 |         logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
153 |         logger.info(net_str)
154 | 
155 |     def _set_lr(self, lr_groups_l):
156 |         """Set learning rate for warmup.
157 | 
158 |         Args:
159 |             lr_groups_l (list): List for lr_groups, each for an optimizer.
160 |         """
161 |         for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
162 |             for param_group, lr in zip(optimizer.param_groups, lr_groups):
163 |                 param_group['lr'] = lr
164 | 
165 |     def _get_init_lr(self):
166 |         """Get the initial lr, which is set by the scheduler.
167 |         """
168 |         init_lr_groups_l = []
169 |         for optimizer in self.optimizers:
170 |             init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
171 |         return init_lr_groups_l
172 | 
173 |     def update_learning_rate(self, current_iter, warmup_iter=-1):
174 |         """Update learning rate.
175 | 
176 |         Args:
177 |             current_iter (int): Current iteration.
178 |             warmup_iter (int): Warmup iter numbers. -1 for no warmup.
179 |                 Default: -1.
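        Example:
            With warmup_iter=1000 and an initial lr of 2e-4, iteration 250
            uses lr = 2e-4 / 1000 * 250 = 5e-5, ramping linearly until the
            scheduler lr takes over at iteration 1000.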
180 | """ 181 | if current_iter > 1: 182 | for scheduler in self.schedulers: 183 | scheduler.step() 184 | # set up warm-up learning rate 185 | if current_iter < warmup_iter: 186 | # get initial lr for each group 187 | init_lr_g_l = self._get_init_lr() 188 | # modify warming-up learning rates 189 | # currently only support linearly warm up 190 | warm_up_lr_l = [] 191 | for init_lr_g in init_lr_g_l: 192 | warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g]) 193 | # set learning rate 194 | self._set_lr(warm_up_lr_l) 195 | 196 | def get_current_learning_rate(self): 197 | return [param_group['lr'] for param_group in self.optimizers[0].param_groups] 198 | 199 | @master_only 200 | def save_network(self, net, net_label, current_iter, param_key='params'): 201 | """Save networks. 202 | 203 | Args: 204 | net (nn.Module | list[nn.Module]): Network(s) to be saved. 205 | net_label (str): Network label. 206 | current_iter (int): Current iter number. 207 | param_key (str | list[str]): The parameter key(s) to save network. 208 | Default: 'params'. 209 | """ 210 | if current_iter == -1: 211 | current_iter = 'latest' 212 | save_filename = f'{net_label}_{current_iter}.pth' 213 | save_path = os.path.join(self.opt['path']['models'], save_filename) 214 | 215 | net = net if isinstance(net, list) else [net] 216 | param_key = param_key if isinstance(param_key, list) else [param_key] 217 | assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.' 218 | 219 | save_dict = {} 220 | for net_, param_key_ in zip(net, param_key): 221 | net_ = self.get_bare_model(net_) 222 | state_dict = net_.state_dict() 223 | for key, param in state_dict.items(): 224 | if key.startswith('module.'): # remove unnecessary 'module.' 225 | key = key[7:] 226 | state_dict[key] = param.cpu() 227 | save_dict[param_key_] = state_dict 228 | 229 | # avoid occasional writing errors 230 | retry = 3 231 | while retry > 0: 232 | try: 233 | torch.save(save_dict, save_path) 234 | except Exception as e: 235 | logger = get_root_logger() 236 | logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}') 237 | time.sleep(1) 238 | else: 239 | break 240 | finally: 241 | retry -= 1 242 | if retry == 0: 243 | logger.warning(f'Still cannot save {save_path}. Just ignore it.') 244 | # raise IOError(f'Cannot save {save_path}.') 245 | 246 | def _print_different_keys_loading(self, crt_net, load_net, strict=True): 247 | """Print keys with different name or different size when loading models. 248 | 249 | 1. Print keys with different names. 250 | 2. If strict=False, print the same key but with different tensor size. 251 | It also ignore these keys with different sizes (not load). 252 | 253 | Args: 254 | crt_net (torch model): Current network. 255 | load_net (dict): Loaded network. 256 | strict (bool): Whether strictly loaded. Default: True. 
257 | """ 258 | crt_net = self.get_bare_model(crt_net) 259 | crt_net = crt_net.state_dict() 260 | crt_net_keys = set(crt_net.keys()) 261 | load_net_keys = set(load_net.keys()) 262 | 263 | logger = get_root_logger() 264 | if crt_net_keys != load_net_keys: 265 | logger.warning('Current net - loaded net:') 266 | for v in sorted(list(crt_net_keys - load_net_keys)): 267 | logger.warning(f' {v}') 268 | logger.warning('Loaded net - current net:') 269 | for v in sorted(list(load_net_keys - crt_net_keys)): 270 | logger.warning(f' {v}') 271 | 272 | # check the size for the same keys 273 | if not strict: 274 | common_keys = crt_net_keys & load_net_keys 275 | for k in common_keys: 276 | if crt_net[k].size() != load_net[k].size(): 277 | logger.warning(f'Size different, ignore [{k}]: crt_net: ' 278 | f'{crt_net[k].shape}; load_net: {load_net[k].shape}') 279 | load_net[k + '.ignore'] = load_net.pop(k) 280 | 281 | def load_network(self, net, load_path, strict=True, param_key='params'): 282 | """Load network. 283 | 284 | Args: 285 | load_path (str): The path of networks to be loaded. 286 | net (nn.Module): Network. 287 | strict (bool): Whether strictly loaded. 288 | param_key (str): The parameter key of loaded network. If set to 289 | None, use the root 'path'. 290 | Default: 'params'. 291 | """ 292 | logger = get_root_logger() 293 | net = self.get_bare_model(net) 294 | load_net = torch.load(load_path, map_location=lambda storage, loc: storage) 295 | if param_key is not None: 296 | if param_key not in load_net and 'params' in load_net: 297 | param_key = 'params' 298 | logger.info('Loading: params_ema does not exist, use params.') 299 | load_net = load_net[param_key] 300 | logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].') 301 | # remove unnecessary 'module.' 302 | for k, v in deepcopy(load_net).items(): 303 | if k.startswith('module.'): 304 | load_net[k[7:]] = v 305 | load_net.pop(k) 306 | self._print_different_keys_loading(net, load_net, strict) 307 | net.load_state_dict(load_net, strict=strict) 308 | 309 | @master_only 310 | def save_training_state(self, epoch, current_iter): 311 | """Save training states during training, which will be used for 312 | resuming. 313 | 314 | Args: 315 | epoch (int): Current epoch. 316 | current_iter (int): Current iteration. 317 | """ 318 | if current_iter != -1: 319 | state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []} 320 | for o in self.optimizers: 321 | state['optimizers'].append(o.state_dict()) 322 | for s in self.schedulers: 323 | state['schedulers'].append(s.state_dict()) 324 | save_filename = f'{current_iter}.state' 325 | save_path = os.path.join(self.opt['path']['training_states'], save_filename) 326 | 327 | # avoid occasional writing errors 328 | retry = 3 329 | while retry > 0: 330 | try: 331 | torch.save(state, save_path) 332 | except Exception as e: 333 | logger = get_root_logger() 334 | logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}') 335 | time.sleep(1) 336 | else: 337 | break 338 | finally: 339 | retry -= 1 340 | if retry == 0: 341 | logger.warning(f'Still cannot save {save_path}. Just ignore it.') 342 | # raise IOError(f'Cannot save {save_path}.') 343 | 344 | def resume_training(self, resume_state): 345 | """Reload the optimizers and schedulers for resumed training. 346 | 347 | Args: 348 | resume_state (dict): Resume state. 
349 | """ 350 | resume_optimizers = resume_state['optimizers'] 351 | resume_schedulers = resume_state['schedulers'] 352 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' 353 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' 354 | for i, o in enumerate(resume_optimizers): 355 | self.optimizers[i].load_state_dict(o) 356 | for i, s in enumerate(resume_schedulers): 357 | self.schedulers[i].load_state_dict(s) 358 | 359 | def reduce_loss_dict(self, loss_dict): 360 | """reduce loss dict. 361 | 362 | In distributed training, it averages the losses among different GPUs . 363 | 364 | Args: 365 | loss_dict (OrderedDict): Loss dict. 366 | """ 367 | with torch.no_grad(): 368 | if self.opt['dist']: 369 | keys = [] 370 | losses = [] 371 | for name, value in loss_dict.items(): 372 | keys.append(name) 373 | losses.append(value) 374 | losses = torch.stack(losses, 0) 375 | torch.distributed.reduce(losses, dst=0) 376 | if self.opt['rank'] == 0: 377 | losses /= self.opt['world_size'] 378 | loss_dict = {key: loss for key, loss in zip(keys, losses)} 379 | 380 | log_dict = OrderedDict() 381 | for name, value in loss_dict.items(): 382 | log_dict[name] = value.mean().item() 383 | 384 | return log_dict -------------------------------------------------------------------------------- /diffglv/utils/beta_schedule.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import device, nn, einsum 4 | import torch.nn.functional as F 5 | from inspect import isfunction 6 | from functools import partial 7 | import numpy as np 8 | from tqdm import tqdm 9 | 10 | 11 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac): 12 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 13 | warmup_time = int(n_timestep * warmup_frac) 14 | betas[:warmup_time] = np.linspace( 15 | linear_start, linear_end, warmup_time, dtype=np.float64) 16 | return betas 17 | 18 | 19 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3): 20 | linear_start = float(linear_start) 21 | linear_end = float(linear_end) 22 | if schedule == 'quad': 23 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, 24 | n_timestep, dtype=np.float64) ** 2 25 | elif schedule == 'linear': 26 | betas = np.linspace(linear_start, linear_end, 27 | n_timestep, dtype=np.float64) 28 | elif schedule == 'warmup10': 29 | betas = _warmup_beta(linear_start, linear_end, 30 | n_timestep, 0.1) 31 | elif schedule == 'warmup50': 32 | betas = _warmup_beta(linear_start, linear_end, 33 | n_timestep, 0.5) 34 | elif schedule == 'const': 35 | betas = linear_end * np.ones(n_timestep, dtype=np.float64) 36 | elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1 37 | betas = 1. 
/ np.linspace(n_timestep, 38 | 1, n_timestep, dtype=np.float64) 39 | elif schedule == "cosine": 40 | timesteps = ( 41 | torch.arange(n_timestep + 1, dtype=torch.float64) / 42 | n_timestep + cosine_s 43 | ) 44 | alphas = timesteps / (1 + cosine_s) * math.pi / 2 45 | alphas = torch.cos(alphas).pow(2) 46 | alphas = alphas / alphas[0] 47 | betas = 1 - alphas[1:] / alphas[:-1] 48 | betas = betas.clamp(max=0.999) 49 | else: 50 | raise NotImplementedError(schedule) 51 | return betas 52 | 53 | def exists(x): 54 | return x is not None 55 | 56 | 57 | def default(val, d): 58 | if exists(val): 59 | return val 60 | return d() if isfunction(d) else d 61 | -------------------------------------------------------------------------------- /diffglv/utils/extract_subimages.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import sys 5 | from multiprocessing import Pool 6 | from os import path as osp 7 | from tqdm import tqdm 8 | 9 | from basicsr.utils import scandir 10 | 11 | 12 | def main(): 13 | """A multi-thread tool to crop large images to sub-images for faster IO. 14 | It is used for DIV2K dataset. 15 | Args: 16 | opt (dict): Configuration dict. It contains: 17 | n_thread (int): Thread number. 18 | compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and 19 | longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2. 20 | input_folder (str): Path to the input folder. 21 | save_folder (str): Path to save folder. 22 | crop_size (int): Crop size. 23 | step (int): Step for overlapped sliding window. 24 | thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped. 25 | Usage: 26 | For each folder, run this script. 27 | Typically, there are four folders to be processed for DIV2K dataset. 28 | * DIV2K_train_HR 29 | * DIV2K_train_LR_bicubic/X2 30 | * DIV2K_train_LR_bicubic/X3 31 | * DIV2K_train_LR_bicubic/X4 32 | After process, each sub_folder should have the same number of subimages. 33 | Remember to modify opt configurations according to your settings. 
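    Example:
        Edit the hard-coded opt paths in main() below for your machine,
        then run:
            python diffglv/utils/extract_subimages.py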
34 | """ 35 | 36 | opt = {} 37 | opt['n_thread'] = 20 38 | opt['compression_level'] = 3 39 | 40 | # opt['input_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_HR_X8' 41 | # opt['save_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_HR_X8_sub' 42 | # opt['crop_size'] = 512 43 | # opt['step'] = 256 44 | # opt['thresh_size'] = 0 45 | # extract_subimages(opt) 46 | # 47 | # opt['input_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_HR' 48 | # opt['save_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_HR_sub' 49 | # opt['crop_size'] = 512 50 | # opt['step'] = 256 51 | # opt['thresh_size'] = 0 52 | # extract_subimages(opt) 53 | # 54 | # opt['input_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_LR_X2_rs' 55 | # opt['save_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_LR_X2_rs_sub' 56 | # opt['crop_size'] = 512 57 | # opt['step'] = 256 58 | # opt['thresh_size'] = 0 59 | # extract_subimages(opt) 60 | # 61 | # opt['input_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_LR_X4_rs' 62 | # opt['save_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_LR_X4_rs_sub' 63 | # opt['crop_size'] = 512 64 | # opt['step'] = 256 65 | # opt['thresh_size'] = 0 66 | # extract_subimages(opt) 67 | 68 | opt['input_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_LR_X8_rs' 69 | opt['save_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_LR_X8_rs_sub' 70 | opt['crop_size'] = 512 71 | opt['step'] = 256 72 | opt['thresh_size'] = 0 73 | extract_subimages(opt) 74 | 75 | 76 | 77 | def extract_subimages(opt): 78 | """Crop images to subimages. 79 | Args: 80 | opt (dict): Configuration dict. It contains: 81 | input_folder (str): Path to the input folder. 82 | save_folder (str): Path to save folder. 83 | n_thread (int): Thread number. 84 | """ 85 | input_folder = opt['input_folder'] 86 | save_folder = opt['save_folder'] 87 | if not osp.exists(save_folder): 88 | os.makedirs(save_folder) 89 | print(f'mkdir {save_folder} ...') 90 | else: 91 | print(f'Folder {save_folder} already exists. Exit.') 92 | sys.exit(1) 93 | 94 | img_list = list(scandir(input_folder, full_path=True)) 95 | 96 | pbar = tqdm(total=len(img_list), unit='image', desc='Extract') 97 | pool = Pool(opt['n_thread']) 98 | for path in img_list: 99 | pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1)) 100 | pool.close() 101 | pool.join() 102 | pbar.close() 103 | print('All processes done.') 104 | 105 | 106 | def worker(path, opt): 107 | """Worker for each process. 108 | Args: 109 | path (str): Image path. 110 | opt (dict): Configuration dict. It contains: 111 | crop_size (int): Crop size. 112 | step (int): Step for overlapped sliding window. 113 | thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped. 114 | save_folder (str): Path to save folder. 115 | compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION. 116 | Returns: 117 | process_info (str): Process information displayed in progress bar. 
118 | """ 119 | crop_size = opt['crop_size'] 120 | step = opt['step'] 121 | thresh_size = opt['thresh_size'] 122 | img_name, extension = osp.splitext(osp.basename(path)) 123 | 124 | # remove the x2, x3, x4 and x8 in the filename for DIV2K 125 | img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '') 126 | 127 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 128 | 129 | h, w = img.shape[0:2] 130 | h_space = np.arange(0, h - crop_size + 1, step) 131 | if h - (h_space[-1] + crop_size) > thresh_size: 132 | h_space = np.append(h_space, h - crop_size) 133 | w_space = np.arange(0, w - crop_size + 1, step) 134 | if w - (w_space[-1] + crop_size) > thresh_size: 135 | w_space = np.append(w_space, w - crop_size) 136 | 137 | index = 0 138 | for x in h_space: 139 | for y in w_space: 140 | index += 1 141 | cropped_img = img[x:x + crop_size, y:y + crop_size, ...] 142 | cropped_img = np.ascontiguousarray(cropped_img) 143 | cv2.imwrite( 144 | osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img, 145 | [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']]) 146 | process_info = f'Processing {img_name} ...' 147 | return process_info 148 | 149 | 150 | if __name__ == '__main__': 151 | main() -------------------------------------------------------------------------------- /diffglv/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from basicsr.utils import get_root_logger 6 | from basicsr.utils.dist_util import master_only 7 | 8 | class MessageLogger(): 9 | """Message logger for printing. 10 | 11 | Args: 12 | opt (dict): Config. It contains the following keys: 13 | name (str): Exp name. 14 | logger (dict): Contains 'print_freq' (str) for logger interval. 15 | train (dict): Contains 'total_iter' (int) for total iters. 16 | use_tb_logger (bool): Use tensorboard logger. 17 | start_iter (int): Start iter. Default: 1. 18 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 19 | """ 20 | 21 | def __init__(self, opt, start_iter=1, tb_logger=None): 22 | self.exp_name = opt['name'] 23 | self.interval = opt['logger']['print_freq'] 24 | self.start_iter = start_iter 25 | self.max_iters = opt['train']['total_iter'] 26 | self.use_tb_logger = opt['logger']['use_tb_logger'] 27 | self.tb_logger = tb_logger 28 | self.start_time = time.time() 29 | self.logger = get_root_logger() 30 | 31 | def reset_start_time(self): 32 | self.start_time = time.time() 33 | 34 | @master_only 35 | def __call__(self, log_vars): 36 | """Format logging message. 37 | 38 | Args: 39 | log_vars (dict): It contains the following keys: 40 | epoch (int): Epoch number. 41 | iter (int): Current iter. 42 | lrs (list): List for learning rates. 43 | 44 | time (float): Iter time. 45 | data_time (float): Data time for each iter. 
46 | """ 47 | # epoch, iter, learning rates 48 | epoch = log_vars.pop('epoch') 49 | current_iter = log_vars.pop('iter') 50 | lrs = log_vars.pop('lrs') 51 | 52 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(') 53 | for v in lrs: 54 | message += f'{v:.3e},' 55 | message += ')] ' 56 | 57 | if 'task' in log_vars.keys(): 58 | message += '[' 59 | task = log_vars.pop('task') 60 | message += f'task: {task}, ' 61 | dataset = log_vars.pop('dataset') 62 | message += f'dataset: {dataset}' 63 | message += '] ' 64 | 65 | # time and estimated time 66 | if 'time' in log_vars.keys(): 67 | iter_time = log_vars.pop('time') 68 | data_time = log_vars.pop('data_time') 69 | 70 | total_time = time.time() - self.start_time 71 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 72 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 73 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 74 | message += f'[eta: {eta_str}, ' 75 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 76 | 77 | # other items, especially losses 78 | for k, v in log_vars.items(): 79 | message += f'{k}: {v:.4e} ' 80 | # tensorboard logger 81 | if self.use_tb_logger and 'debug' not in self.exp_name: 82 | if k.startswith('l_'): 83 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 84 | else: 85 | self.tb_logger.add_scalar(k, v, current_iter) 86 | self.logger.info(message) -------------------------------------------------------------------------------- /diffglv/utils/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | import torch 5 | 6 | 7 | class MultiStepRestartLR(_LRScheduler): 8 | """ MultiStep with restarts learning rate scheme. 9 | 10 | Args: 11 | optimizer (torch.nn.optimizer): Torch optimizer. 12 | milestones (list): Iterations that will decrease learning rate. 13 | gamma (float): Decrease ratio. Default: 0.1. 14 | restarts (list): Restart iterations. Default: [0]. 15 | restart_weights (list): Restart weights at each restart iteration. 16 | Default: [1]. 17 | last_epoch (int): Used in _LRScheduler. Default: -1. 18 | """ 19 | 20 | def __init__(self, 21 | optimizer, 22 | milestones, 23 | gamma=0.1, 24 | restarts=(0, ), 25 | restart_weights=(1, ), 26 | last_epoch=-1): 27 | self.milestones = Counter(milestones) 28 | self.gamma = gamma 29 | self.restarts = restarts 30 | self.restart_weights = restart_weights 31 | assert len(self.restarts) == len( 32 | self.restart_weights), 'restarts and their weights do not match.' 33 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 34 | 35 | def get_lr(self): 36 | if self.last_epoch in self.restarts: 37 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 38 | return [ 39 | group['initial_lr'] * weight 40 | for group in self.optimizer.param_groups 41 | ] 42 | if self.last_epoch not in self.milestones: 43 | return [group['lr'] for group in self.optimizer.param_groups] 44 | return [ 45 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 46 | for group in self.optimizer.param_groups 47 | ] 48 | 49 | class LinearLR(_LRScheduler): 50 | """ 51 | 52 | Args: 53 | optimizer (torch.nn.optimizer): Torch optimizer. 54 | milestones (list): Iterations that will decrease learning rate. 55 | gamma (float): Decrease ratio. Default: 0.1. 56 | last_epoch (int): Used in _LRScheduler. Default: -1. 
57 | """ 58 | 59 | def __init__(self, 60 | optimizer, 61 | total_iter, 62 | last_epoch=-1): 63 | self.total_iter = total_iter 64 | super(LinearLR, self).__init__(optimizer, last_epoch) 65 | 66 | def get_lr(self): 67 | process = self.last_epoch / self.total_iter 68 | weight = (1 - process) 69 | # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups]) 70 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups] 71 | 72 | class VibrateLR(_LRScheduler): 73 | """ 74 | 75 | Args: 76 | optimizer (torch.nn.optimizer): Torch optimizer. 77 | milestones (list): Iterations that will decrease learning rate. 78 | gamma (float): Decrease ratio. Default: 0.1. 79 | last_epoch (int): Used in _LRScheduler. Default: -1. 80 | """ 81 | 82 | def __init__(self, 83 | optimizer, 84 | total_iter, 85 | last_epoch=-1): 86 | self.total_iter = total_iter 87 | super(VibrateLR, self).__init__(optimizer, last_epoch) 88 | 89 | def get_lr(self): 90 | process = self.last_epoch / self.total_iter 91 | 92 | f = 0.1 93 | if process < 3 / 8: 94 | f = 1 - process * 8 / 3 95 | elif process < 5 / 8: 96 | f = 0.2 97 | 98 | T = self.total_iter // 80 99 | Th = T // 2 100 | 101 | t = self.last_epoch % T 102 | 103 | f2 = t / Th 104 | if t >= Th: 105 | f2 = 2 - f2 106 | 107 | weight = f * f2 108 | 109 | if self.last_epoch < Th: 110 | weight = max(0.1, weight) 111 | 112 | # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2)) 113 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups] 114 | 115 | def get_position_from_periods(iteration, cumulative_period): 116 | """Get the position from a period list. 117 | 118 | It will return the index of the right-closest number in the period list. 119 | For example, the cumulative_period = [100, 200, 300, 400], 120 | if iteration == 50, return 0; 121 | if iteration == 210, return 2; 122 | if iteration == 300, return 2. 123 | 124 | Args: 125 | iteration (int): Current iteration. 126 | cumulative_period (list[int]): Cumulative period list. 127 | 128 | Returns: 129 | int: The position of the right-closest number in the period list. 130 | """ 131 | for i, period in enumerate(cumulative_period): 132 | if iteration <= period: 133 | return i 134 | 135 | 136 | class CosineAnnealingRestartLR(_LRScheduler): 137 | """ Cosine annealing with restarts learning rate scheme. 138 | 139 | An example of config: 140 | periods = [10, 10, 10, 10] 141 | restart_weights = [1, 0.5, 0.5, 0.5] 142 | eta_min=1e-7 143 | 144 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 145 | scheduler will restart with the weights in restart_weights. 146 | 147 | Args: 148 | optimizer (torch.nn.optimizer): Torch optimizer. 149 | periods (list): Period for each cosine anneling cycle. 150 | restart_weights (list): Restart weights at each restart iteration. 151 | Default: [1]. 152 | eta_min (float): The mimimum lr. Default: 0. 153 | last_epoch (int): Used in _LRScheduler. Default: -1. 154 | """ 155 | 156 | def __init__(self, 157 | optimizer, 158 | periods, 159 | restart_weights=(1, ), 160 | eta_min=0, 161 | last_epoch=-1): 162 | self.periods = periods 163 | self.restart_weights = restart_weights 164 | self.eta_min = eta_min 165 | assert (len(self.periods) == len(self.restart_weights) 166 | ), 'periods and restart_weights should have the same length.' 
167 |         self.cumulative_period = [
168 |             sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
169 |         ]
170 |         super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
171 | 
172 |     def get_lr(self):
173 |         idx = get_position_from_periods(self.last_epoch,
174 |                                         self.cumulative_period)
175 |         current_weight = self.restart_weights[idx]
176 |         nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
177 |         current_period = self.periods[idx]
178 | 
179 |         return [
180 |             self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
181 |             (1 + math.cos(math.pi * (
182 |                 (self.last_epoch - nearest_restart) / current_period)))
183 |             for base_lr in self.base_lrs
184 |         ]
185 | 
186 | class CosineAnnealingRestartCyclicLR(_LRScheduler):
187 |     """ Cosine annealing with restarts learning rate scheme.
188 |     An example of config:
189 |     periods = [10, 10, 10, 10]
190 |     restart_weights = [1, 0.5, 0.5, 0.5]
191 |     eta_min=1e-7
192 |     It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
193 |     scheduler will restart with the weights in restart_weights.
194 |     Args:
195 |         optimizer (torch.nn.optimizer): Torch optimizer.
196 |         periods (list): Period for each cosine annealing cycle.
197 |         restart_weights (list): Restart weights at each restart iteration.
198 |             Default: [1].
199 |         eta_mins (list): The minimum lr for each cycle. Default: [0].
200 |         last_epoch (int): Used in _LRScheduler. Default: -1.
201 |     """
202 | 
203 |     def __init__(self,
204 |                  optimizer,
205 |                  periods,
206 |                  restart_weights=(1, ),
207 |                  eta_mins=(0, ),
208 |                  last_epoch=-1):
209 |         self.periods = periods
210 |         self.restart_weights = restart_weights
211 |         self.eta_mins = eta_mins
212 |         assert (len(self.periods) == len(self.restart_weights)
213 |                 ), 'periods and restart_weights should have the same length.'
214 |         self.cumulative_period = [
215 |             sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
216 |         ]
217 |         super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch)
218 | 
219 |     def get_lr(self):
220 |         idx = get_position_from_periods(self.last_epoch,
221 |                                         self.cumulative_period)
222 |         current_weight = self.restart_weights[idx]
223 |         nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
224 |         current_period = self.periods[idx]
225 |         eta_min = self.eta_mins[idx]
226 | 
227 |         return [
228 |             eta_min + current_weight * 0.5 * (base_lr - eta_min) *
229 |             (1 + math.cos(math.pi * (
230 |                 (self.last_epoch - nearest_restart) / current_period)))
231 |             for base_lr in self.base_lrs
232 |         ]
233 | 
--------------------------------------------------------------------------------
/diffglv/utils/make_ds.py:
--------------------------------------------------------------------------------
 1 | import argparse
 2 | from PIL import Image
 3 | import os
 4 | import sys
 5 | from multiprocessing import Pool
 6 | from os import path as osp
 7 | from tqdm import tqdm
 8 | from basicsr.utils import scandir
 9 | 
10 | 
11 | def make_downsampling(input_folder, save_folder, scale, rescaling=False, downsample_type='bicubic',
12 |                       n_thread=20, wash_only=False):
13 |     """Generate downsampled LR images from HR inputs.
14 |     Args:
15 |         input_folder (str): Path to the input (HR) folder.
16 |         save_folder (str): Path to the save folder.
17 |         scale (int): Downsampling scale.
18 |         n_thread (int): Thread number.
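    Example (illustrative paths):
        python diffglv/utils/make_ds.py --src datasets/DIV2K/HR \
            --dst datasets/DIV2K/LR_X4 --scale 4 -rs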
19 | """ 20 | opt = {} 21 | opt['scale'] = scale 22 | opt['rescaling'] = rescaling 23 | opt['save_folder'] = save_folder 24 | opt['downsample_type'] = downsample_type 25 | opt['wash_only'] = wash_only 26 | 27 | if not osp.exists(save_folder): 28 | os.makedirs(save_folder) 29 | print(f'mkdir {save_folder} ...') 30 | else: 31 | print(f'Folder {save_folder} already exists. Exit.') 32 | sys.exit(1) 33 | 34 | img_list = list(scandir(input_folder, full_path=True)) 35 | 36 | pbar = tqdm(total=len(img_list), unit='image', desc='Extract') 37 | pool = Pool(n_thread) 38 | for path in img_list: 39 | pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1)) 40 | pool.close() 41 | pool.join() 42 | pbar.close() 43 | print('All processes done.') 44 | 45 | 46 | def worker(path, opt): 47 | """Worker for each process. 48 | Args: 49 | path (str): Image path. 50 | opt (dict): Configuration dict. It contains: 51 | crop_size (int): Crop size. 52 | step (int): Step for overlapped sliding window. 53 | thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped. 54 | save_folder (str): Path to save folder. 55 | compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION. 56 | Returns: 57 | process_info (str): Process information displayed in progress bar. 58 | """ 59 | scale = opt['scale'] 60 | save_folder = opt['save_folder'] 61 | downsample_type = opt['downsample_type'] 62 | 63 | hr = Image.open(path) 64 | w, h = hr.size 65 | if w % scale + h % scale: 66 | print('\n\nHR needs data washing\n') 67 | hr = hr.crop([0, 0, w //scale * scale, h //scale *scale]) 68 | 69 | ds_func = Image.Resampling.BICUBIC if downsample_type == 'bicubic' else Image.Resampling.BILINEAR 70 | 71 | if not opt['wash_only']: 72 | lr = hr.resize((w//scale, h//scale), ds_func) 73 | if opt['rescaling']: 74 | lr = lr.resize((w//scale*scale, h//scale*scale), ds_func) 75 | else: 76 | lr = hr 77 | 78 | lr.save(osp.join(save_folder, osp.split(path)[-1])) 79 | process_info = f'Processing {osp.split(path)[-1]} ...' 80 | return process_info 81 | 82 | 83 | if __name__ == '__main__': 84 | parser = argparse.ArgumentParser() 85 | parser.add_argument('--src', type=str) 86 | parser.add_argument('--dst', type=str) 87 | parser.add_argument('--scale', type=int) 88 | parser.add_argument('--rescaling', '-rs', action='store_true') 89 | parser.add_argument('--n_worker', type=int, default=20) 90 | parser.add_argument('--ds_func', type=str, default='bicubic') 91 | parser.add_argument('--wash_only', '-wo', action='store_true') 92 | 93 | args = parser.parse_args() 94 | 95 | make_downsampling(args.src, args.dst, args.scale, args.rescaling, args.ds_func, args.n_worker, args.wash_only) -------------------------------------------------------------------------------- /diffglv/utils/options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import random 3 | import torch 4 | import yaml 5 | from collections import OrderedDict 6 | from os import path as osp 7 | 8 | from basicsr.utils import set_random_seed 9 | from basicsr.utils.dist_util import get_dist_info, init_dist, master_only 10 | 11 | 12 | def ordered_yaml(): 13 | """Support OrderedDict for yaml. 14 | 15 | Returns: 16 | yaml Loader and Dumper. 
17 | """ 18 | try: 19 | from yaml import CDumper as Dumper 20 | from yaml import CLoader as Loader 21 | except ImportError: 22 | from yaml import Dumper, Loader 23 | 24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 25 | 26 | def dict_representer(dumper, data): 27 | return dumper.represent_dict(data.items()) 28 | 29 | def dict_constructor(loader, node): 30 | return OrderedDict(loader.construct_pairs(node)) 31 | 32 | Dumper.add_representer(OrderedDict, dict_representer) 33 | Loader.add_constructor(_mapping_tag, dict_constructor) 34 | return Loader, Dumper 35 | 36 | 37 | def dict2str(opt, indent_level=1): 38 | """dict to string for printing options. 39 | 40 | Args: 41 | opt (dict): Option dict. 42 | indent_level (int): Indent level. Default: 1. 43 | 44 | Return: 45 | (str): Option string for printing. 46 | """ 47 | msg = '\n' 48 | for k, v in opt.items(): 49 | if isinstance(v, dict): 50 | msg += ' ' * (indent_level * 2) + k + ':[' 51 | msg += dict2str(v, indent_level + 1) 52 | msg += ' ' * (indent_level * 2) + ']\n' 53 | else: 54 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 55 | return msg 56 | 57 | 58 | def _postprocess_yml_value(value): 59 | # None 60 | if value == '~' or value.lower() == 'none': 61 | return None 62 | # bool 63 | if value.lower() == 'true': 64 | return True 65 | elif value.lower() == 'false': 66 | return False 67 | # !!float number 68 | if value.startswith('!!float'): 69 | return float(value.replace('!!float', '')) 70 | # number 71 | if value.isdigit(): 72 | return int(value) 73 | elif value.replace('.', '', 1).isdigit() and value.count('.') < 2: 74 | return float(value) 75 | # list 76 | if value.startswith('['): 77 | return eval(value) 78 | # str 79 | return value 80 | 81 | 82 | def parse_options(root_path, is_train=True): 83 | parser = argparse.ArgumentParser() 84 | parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.') 85 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher') 86 | parser.add_argument('--auto_resume', action='store_true') 87 | parser.add_argument('--debug', action='store_true') 88 | parser.add_argument('--local_rank', type=int, default=0) 89 | parser.add_argument( 90 | '--force_yml', nargs='+', default=None, help='Force to update yml files. 
Examples: train:ema_decay=0.999')
91 | args = parser.parse_args()
92 | 
93 | # parse yml to dict
94 | with open(args.opt, mode='r') as f:
95 | opt = yaml.load(f, Loader=ordered_yaml()[0])
96 | 
97 | # distributed settings
98 | if args.launcher == 'none':
99 | opt['dist'] = False
100 | print('Disable distributed.', flush=True)
101 | else:
102 | opt['dist'] = True
103 | if args.launcher == 'slurm' and 'dist_params' in opt:
104 | init_dist(args.launcher, **opt['dist_params'])
105 | else:
106 | init_dist(args.launcher)
107 | opt['rank'], opt['world_size'] = get_dist_info()
108 | 
109 | # random seed
110 | seed = opt.get('manual_seed')
111 | if seed is None:
112 | seed = random.randint(1, 10000)
113 | opt['manual_seed'] = seed
114 | set_random_seed(seed + opt['rank'])
115 | 
116 | # force to update yml options
117 | if args.force_yml is not None:
118 | for entry in args.force_yml:
119 | # creating new keys is not supported yet
120 | keys, value = entry.split('=')
121 | keys, value = keys.strip(), value.strip()
122 | value = _postprocess_yml_value(value)
123 | eval_str = 'opt'
124 | for key in keys.split(':'):
125 | eval_str += f'["{key}"]'
126 | eval_str += '=value'
127 | # use exec to assign into the nested option dict
128 | exec(eval_str)
129 | 
130 | opt['auto_resume'] = args.auto_resume
131 | opt['is_train'] = is_train
132 | 
133 | # debug setting
134 | if args.debug and not opt['name'].startswith('debug'):
135 | opt['name'] = 'debug_' + opt['name']
136 | 
137 | if opt['num_gpu'] == 'auto':
138 | opt['num_gpu'] = torch.cuda.device_count()
139 | 
140 | # datasets
141 | for phase, dataset in opt['datasets'].items():
142 | # for multiple datasets, e.g., val_1, val_2; test_1, test_2
143 | phase = phase.split('_')[0]
144 | dataset['phase'] = phase
145 | if 'scale' in opt:
146 | dataset['scale'] = opt['scale']
147 | if dataset.get('dataroot_gt') is not None:
148 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
149 | if dataset.get('dataroot_lq') is not None:
150 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
151 | 
152 | # paths
153 | for key, val in opt['path'].items():
154 | if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
155 | opt['path'][key] = osp.expanduser(val)
156 | 
157 | if is_train:
158 | experiments_root = osp.join(root_path, 'experiments', opt['name'])
159 | opt['path']['experiments_root'] = experiments_root
160 | opt['path']['models'] = osp.join(experiments_root, 'models')
161 | opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
162 | opt['path']['log'] = experiments_root
163 | opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
164 | 
165 | # change some options for debug mode
166 | if 'debug' in opt['name']:
167 | if 'val' in opt:
168 | opt['val']['val_freq'] = 8
169 | opt['logger']['print_freq'] = 1
170 | opt['logger']['save_checkpoint_freq'] = 8
171 | else: # test
172 | results_root = osp.join(root_path, 'results', opt['name'])
173 | opt['path']['results_root'] = results_root
174 | opt['path']['log'] = results_root
175 | opt['path']['visualization'] = osp.join(results_root, 'visualization')
176 | 
177 | return opt, args
178 | 
179 | 
180 | @master_only
181 | def copy_opt_file(opt_file, experiments_root):
182 | # copy the yml file to the experiment root
183 | import sys
184 | import time
185 | from shutil import copyfile
186 | cmd = ' '.join(sys.argv)
187 | filename = osp.join(experiments_root, osp.basename(opt_file))
188 | copyfile(opt_file, filename)
189 | 
190 | with open(filename, 'r+') as f:
191 | lines = f.readlines()
192 | lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
193 | f.seek(0)
194 | f.writelines(lines)
195 | 
-------------------------------------------------------------------------------- /diffglv/utils/transforms.py: --------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import numpy as np
4 | import torch
5 | 
6 | 
7 | def mod_crop(img, scale):
8 | """Mod crop images, used during testing.
9 | 
10 | Args:
11 | img (ndarray): Input image.
12 | scale (int): Scale factor.
13 | 
14 | Returns:
15 | ndarray: Result image.
16 | """
17 | img = img.copy()
18 | if img.ndim in (2, 3):
19 | h, w = img.shape[0], img.shape[1]
20 | h_remainder, w_remainder = h % scale, w % scale
21 | img = img[:h - h_remainder, :w - w_remainder, ...]
22 | else:
23 | raise ValueError(f'Wrong img ndim: {img.ndim}.')
24 | return img
25 | 
26 | 
27 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
28 | """Paired random crop. Supports Numpy array and Tensor inputs.
29 | 
30 | It crops lists of lq and gt images with corresponding locations.
31 | 
32 | Args:
33 | img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
34 | should have the same shape. If the input is an ndarray, it will
35 | be transformed to a list containing itself.
36 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
37 | should have the same shape. If the input is an ndarray, it will
38 | be transformed to a list containing itself.
39 | gt_patch_size (int): GT patch size.
40 | scale (int): Scale factor.
41 | gt_path (str): Path to ground-truth. Default: None.
42 | 
43 | Returns:
44 | list[ndarray] | ndarray: GT images and LQ images. If the returned list
45 | has only one element, the ndarray itself is returned.
46 | """
47 | 
48 | if not isinstance(img_gts, list):
49 | img_gts = [img_gts]
50 | if not isinstance(img_lqs, list):
51 | img_lqs = [img_lqs]
52 | 
53 | # determine input type: Numpy array or Tensor
54 | input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
55 | 
56 | if input_type == 'Tensor':
57 | h_lq, w_lq = img_lqs[0].size()[-2:]
58 | h_gt, w_gt = img_gts[0].size()[-2:]
59 | else:
60 | h_lq, w_lq = img_lqs[0].shape[0:2]
61 | h_gt, w_gt = img_gts[0].shape[0:2]
62 | lq_patch_size = gt_patch_size // scale
63 | 
64 | if h_gt != h_lq * scale or w_gt != w_lq * scale:
65 | raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not a {scale}x '
66 | f'multiple of LQ ({h_lq}, {w_lq}).')
67 | if h_lq < lq_patch_size or w_lq < lq_patch_size:
68 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
69 | f'({lq_patch_size}, {lq_patch_size}). '
70 | f'Please remove {gt_path}.')
71 | 
72 | # randomly choose top and left coordinates for lq patch
73 | top = random.randint(0, h_lq - lq_patch_size)
74 | left = random.randint(0, w_lq - lq_patch_size)
75 | 
76 | # crop lq patch
77 | if input_type == 'Tensor':
78 | img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
79 | else:
80 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
81 | 
82 | # crop corresponding gt patch
83 | top_gt, left_gt = int(top * scale), int(left * scale)
84 | if input_type == 'Tensor':
85 | img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
86 | else:
87 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
88 | if len(img_gts) == 1:
89 | img_gts = img_gts[0]
90 | if len(img_lqs) == 1:
91 | img_lqs = img_lqs[0]
92 | return img_gts, img_lqs
93 | 
94 | 
95 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
96 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
97 | 
98 | We use vertical flip and transpose for rotation implementation.
99 | All the images in the list use the same augmentation.
100 | 
101 | Args:
102 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input
103 | is an ndarray, it will be transformed to a list.
104 | hflip (bool): Horizontal flip. Default: True.
105 | rotation (bool): Rotation. Default: True.
106 | flows (list[ndarray]): Flows to be augmented. If the input is an
107 | ndarray, it will be transformed to a list.
108 | Dimension is (h, w, 2). Default: None.
109 | return_status (bool): Return the status of flip and rotation.
110 | Default: False.
111 | 
112 | Returns:
113 | list[ndarray] | ndarray: Augmented images and flows. If the returned
114 | list has only one element, the ndarray itself is returned.
115 | 
116 | """
117 | hflip = hflip and random.random() < 0.5
118 | vflip = rotation and random.random() < 0.5
119 | rot90 = rotation and random.random() < 0.5
120 | 
121 | def _augment(img):
122 | if hflip: # horizontal
123 | cv2.flip(img, 1, img)
124 | if vflip: # vertical
125 | cv2.flip(img, 0, img)
126 | if rot90:
127 | img = img.transpose(1, 0, 2)
128 | return img
129 | 
130 | def _augment_flow(flow):
131 | if hflip: # horizontal
132 | cv2.flip(flow, 1, flow)
133 | flow[:, :, 0] *= -1
134 | if vflip: # vertical
135 | cv2.flip(flow, 0, flow)
136 | flow[:, :, 1] *= -1
137 | if rot90:
138 | flow = flow.transpose(1, 0, 2)
139 | flow = flow[:, :, [1, 0]]
140 | return flow
141 | 
142 | if not isinstance(imgs, list):
143 | imgs = [imgs]
144 | imgs = [_augment(img) for img in imgs]
145 | if len(imgs) == 1:
146 | imgs = imgs[0]
147 | 
148 | if flows is not None:
149 | if not isinstance(flows, list):
150 | flows = [flows]
151 | flows = [_augment_flow(flow) for flow in flows]
152 | if len(flows) == 1:
153 | flows = flows[0]
154 | return imgs, flows
155 | else:
156 | if return_status:
157 | return imgs, (hflip, vflip, rot90)
158 | else:
159 | return imgs
160 | 
161 | 
162 | def img_rotate(img, angle, center=None, scale=1.0):
163 | """Rotate image.
164 | 
165 | Args:
166 | img (ndarray): Image to be rotated.
167 | angle (float): Rotation angle in degrees. Positive values mean
168 | counter-clockwise rotation.
169 | center (tuple[int]): Rotation center. If the center is None,
170 | initialize it as the center of the image. Default: None.
171 | scale (float): Isotropic scale factor. Default: 1.0.
172 | """
173 | (h, w) = img.shape[:2]
174 | 
175 | if center is None:
176 | center = (w // 2, h // 2)
177 | 
178 | matrix = cv2.getRotationMatrix2D(center, angle, scale)
179 | rotated_img = cv2.warpAffine(img, matrix, (w, h))
180 | return rotated_img
181 | 
182 | 
183 | def data_augmentation(image, mode):
184 | """
185 | Performs data augmentation of the input image
186 | Input:
187 | image: a cv2 (OpenCV) image
188 | mode: int. Choice of transformation to apply to the image
189 | 0 - no transformation
190 | 1 - flip up and down
191 | 2 - rotate counterclockwise by 90 degrees
192 | 3 - rotate 90 degrees and flip up and down
193 | 4 - rotate 180 degrees
194 | 5 - rotate 180 degrees and flip
195 | 6 - rotate 270 degrees
196 | 7 - rotate 270 degrees and flip
197 | """
198 | if mode == 0:
199 | # original
200 | out = image
201 | elif mode == 1:
202 | # flip up and down
203 | out = np.flipud(image)
204 | elif mode == 2:
205 | # rotate counterclockwise by 90 degrees
206 | out = np.rot90(image)
207 | elif mode == 3:
208 | # rotate 90 degrees and flip up and down
209 | out = np.rot90(image)
210 | out = np.flipud(out)
211 | elif mode == 4:
212 | # rotate 180 degrees
213 | out = np.rot90(image, k=2)
214 | elif mode == 5:
215 | # rotate 180 degrees and flip
216 | out = np.rot90(image, k=2)
217 | out = np.flipud(out)
218 | elif mode == 6:
219 | # rotate 270 degrees
220 | out = np.rot90(image, k=3)
221 | elif mode == 7:
222 | # rotate 270 degrees and flip
223 | out = np.rot90(image, k=3)
224 | out = np.flipud(out)
225 | else:
226 | raise Exception('Invalid choice of image transformation')
227 | 
228 | return out
229 | 
230 | def random_augmentation(*args):
231 | out = []
232 | flag_aug = random.randint(0, 7)
233 | for data in args:
234 | out.append(data_augmentation(data, flag_aug).copy())
235 | return out
236 | 
-------------------------------------------------------------------------------- /experiments/README.md: --------------------------------------------------------------------------------
1 | Place pretrained models in `pretrained_models`.
2 | 
-------------------------------------------------------------------------------- /experiments/pretrained_models/README.md: --------------------------------------------------------------------------------
1 | Place pretrained models here.
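A minimal usage sketch for the paired-crop and augmentation utilities in diffglv/utils/transforms.py (the shapes, scale, and patch size below are illustrative, not values taken from this repo):

import numpy as np
from diffglv.utils.transforms import mod_crop, paired_random_crop, random_augmentation

gt = np.random.rand(510, 514, 3).astype(np.float32)  # stand-in HR image (HWC)
gt = mod_crop(gt, scale=4)                           # crop to 508x512 so HR is divisible by the scale
lq = np.random.rand(127, 128, 3).astype(np.float32)  # stand-in 4x-downsampled LR image

# aligned 256 (HR) / 64 (LR) patches, then the same random flip/rotation on both
gt_patch, lq_patch = paired_random_crop(gt, lq, gt_patch_size=256, scale=4)
gt_patch, lq_patch = random_augmentation(gt_patch, lq_patch)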
-------------------------------------------------------------------------------- /figs/BI-DiffSR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/BI-DiffSR.png -------------------------------------------------------------------------------- /figs/F1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/F1.png -------------------------------------------------------------------------------- /figs/F2-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/F2-1.png -------------------------------------------------------------------------------- /figs/F2-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/F2-2.png -------------------------------------------------------------------------------- /figs/F3-1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/F3-1.png -------------------------------------------------------------------------------- /figs/F3-2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/F3-2.png -------------------------------------------------------------------------------- /figs/T1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/T1.png -------------------------------------------------------------------------------- /figs/compare/ComS_img_023_BBCU_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_023_BBCU_x4.png -------------------------------------------------------------------------------- /figs/compare/ComS_img_023_BI-DiffSR_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_023_BI-DiffSR_x4.png -------------------------------------------------------------------------------- /figs/compare/ComS_img_023_Bicubic_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_023_Bicubic_x4.png -------------------------------------------------------------------------------- /figs/compare/ComS_img_023_HR_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_023_HR_x4.png -------------------------------------------------------------------------------- /figs/compare/ComS_img_023_SR3_x4.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_023_SR3_x4.png -------------------------------------------------------------------------------- /figs/compare/ComS_img_033_BBCU_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_033_BBCU_x4.png -------------------------------------------------------------------------------- /figs/compare/ComS_img_033_BI-DiffSR_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_033_BI-DiffSR_x4.png -------------------------------------------------------------------------------- /figs/compare/ComS_img_033_Bicubic_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_033_Bicubic_x4.png -------------------------------------------------------------------------------- /figs/compare/ComS_img_033_HR_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_033_HR_x4.png -------------------------------------------------------------------------------- /figs/compare/ComS_img_033_SR3_x4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_033_SR3_x4.png -------------------------------------------------------------------------------- /figs/logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/logo.png -------------------------------------------------------------------------------- /options/test/test_BI_DiffSR_x2.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: test_BI_DiffSR_DDIM_S50_x2 3 | model_type: BIDiffSRModel 4 | scale: 2 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | test_1: 11 | task: SR 12 | name: Set5 13 | type: MultiPairedImageDataset 14 | dataroot_gt: datasets/benchmark/Set5/HR 15 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2 16 | filename_tmpl: '{}x2' 17 | io_backend: 18 | type: disk 19 | 20 | test_2: 21 | task: SR 22 | name: B100 23 | type: MultiPairedImageDataset 24 | dataroot_gt: datasets/benchmark/B100/HR 25 | dataroot_lq: datasets/benchmark/B100/LR_bicubic/X2 26 | filename_tmpl: '{}x2' 27 | io_backend: 28 | type: disk 29 | 30 | test_3: 31 | task: SR 32 | name: Urban100 33 | type: MultiPairedImageDataset 34 | dataroot_gt: datasets/benchmark/Urban100/HR 35 | dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X2 36 | filename_tmpl: '{}x2' 37 | io_backend: 38 | type: disk 39 | 40 | test_4: 41 | task: SR 42 | name: Manga109 43 | type: MultiPairedImageDataset 44 | dataroot_gt: datasets/benchmark/Manga109/HR 45 | dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X2 46 | 
filename_tmpl: '{}_LRBI_x2' 47 | io_backend: 48 | type: disk 49 | 50 | # network structures 51 | network_g: 52 | type: BIDiffSRUNet 53 | in_channel: 6 54 | out_channel: 3 55 | inner_channel: 64 56 | norm_groups: 16 57 | channel_mults: [1, 2, 4, 8] 58 | attn_res: [] 59 | res_blocks: 2 60 | dropout: 0.2 61 | image_size: 256 62 | fp_res: [256, 128] 63 | total_step: 2000 64 | dynamic_group: 5 # K 65 | 66 | # schedule 67 | beta_schedule: 68 | scheduler_type: DDIM 69 | schedule: linear 70 | n_timestep: 2000 71 | linear_start: !!float 1e-6 72 | linear_end: !!float 1e-2 73 | prediction_type: epsilon 74 | num_inference_steps: 50 75 | guidance_scale: 7.5 76 | is_guidance: False 77 | 78 | # path 79 | path: 80 | pretrain_network_g: experiments/pretrained_models/BI_DiffSR_x2.pth 81 | strict_load_g: true 82 | resume_state: params 83 | 84 | # validation settings 85 | val: 86 | save_img: true 87 | suffix: 'test' # add suffix to saved images, if None, use exp name 88 | 89 | metrics: 90 | psnr: # metric name, can be arbitrary 91 | type: calculate_psnr 92 | crop_border: 2 93 | test_y_channel: true 94 | 95 | ssim: # metric name, can be arbitrary 96 | type: calculate_ssim 97 | crop_border: 2 98 | test_y_channel: true 99 | 100 | lpips: # metric name, can be arbitrary 101 | type: calculate_lpips 102 | crop_border: 2 103 | better: lower 104 | -------------------------------------------------------------------------------- /options/test/test_BI_DiffSR_x4.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: test_BI_DiffSR_DDIM_S50_x4 3 | model_type: BIDiffSRModel 4 | scale: 4 5 | num_gpu: 1 6 | manual_seed: 10 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | test_1: 11 | task: SR 12 | name: Set5 13 | type: MultiPairedImageDataset 14 | dataroot_gt: datasets/benchmark/Set5/HR 15 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 16 | filename_tmpl: '{}x4' 17 | io_backend: 18 | type: disk 19 | 20 | test_2: 21 | task: SR 22 | name: B100 23 | type: MultiPairedImageDataset 24 | dataroot_gt: datasets/benchmark/B100/HR 25 | dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4 26 | filename_tmpl: '{}x4' 27 | io_backend: 28 | type: disk 29 | 30 | test_3: 31 | task: SR 32 | name: Urban100 33 | type: MultiPairedImageDataset 34 | dataroot_gt: datasets/benchmark/Urban100/HR 35 | dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4 36 | filename_tmpl: '{}x4' 37 | io_backend: 38 | type: disk 39 | 40 | test_4: 41 | task: SR 42 | name: Manga109 43 | type: MultiPairedImageDataset 44 | dataroot_gt: datasets/benchmark/Manga109/HR 45 | dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4 46 | filename_tmpl: '{}_LRBI_x4' 47 | io_backend: 48 | type: disk 49 | 50 | # network structures 51 | network_g: 52 | type: BIDiffSRUNet 53 | in_channel: 6 54 | out_channel: 3 55 | inner_channel: 64 56 | norm_groups: 16 57 | channel_mults: [1, 2, 4, 8] 58 | attn_res: [] 59 | res_blocks: 2 60 | dropout: 0.2 61 | image_size: 256 62 | fp_res: [256, 128] 63 | total_step: 2000 64 | dynamic_group: 5 # K 65 | 66 | # schedule 67 | beta_schedule: 68 | scheduler_type: DDIM 69 | schedule: linear 70 | n_timestep: 2000 71 | linear_start: !!float 1e-6 72 | linear_end: !!float 1e-2 73 | prediction_type: epsilon 74 | num_inference_steps: 50 75 | guidance_scale: 7.5 76 | is_guidance: False 77 | 78 | # path 79 | path: 80 | pretrain_network_g: experiments/pretrained_models/BI_DiffSR_x4.pth 81 | strict_load_g: true 82 | resume_state: params 83 | 84 | # validation settings 85 | val: 86 | 
save_img: true 87 | suffix: 'test' # add suffix to saved images, if None, use exp name 88 | 89 | metrics: 90 | psnr: # metric name, can be arbitrary 91 | type: calculate_psnr 92 | crop_border: 4 93 | test_y_channel: true 94 | 95 | ssim: # metric name, can be arbitrary 96 | type: calculate_ssim 97 | crop_border: 4 98 | test_y_channel: true 99 | 100 | lpips: # metric name, can be arbitrary 101 | type: calculate_lpips 102 | crop_border: 4 103 | better: lower 104 | 105 | -------------------------------------------------------------------------------- /options/train/train_BI_DiffSR_x2.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_BI_DiffSR_DDIM_S50_x2 3 | model_type: BIDiffSRModel 4 | scale: 2 5 | num_gpu: auto 6 | manual_seed: 10 7 | find_unused_parameters: True 8 | 9 | # dataset and data loader settings 10 | datasets: 11 | train: 12 | task: SR 13 | name: DF2K 14 | type: MultiPairedImageDataset 15 | dataroot_gt: datasets/DF2K/HR 16 | dataroot_lq: datasets/DF2K/LR_bicubic/X2 17 | filename_tmpl: '{}x2' 18 | io_backend: 19 | type: disk 20 | 21 | gt_size: 128 22 | use_hflip: True 23 | use_rot: True 24 | 25 | # data loader 26 | use_shuffle: True 27 | num_worker_per_gpu: 8 28 | batch_size_per_gpu: 4 29 | dataset_enlarge_ratio: 100 30 | prefetch_mode: ~ 31 | 32 | val: 33 | task: SR 34 | name: Set5 35 | type: MultiPairedImageDataset 36 | dataroot_gt: datasets/benchmark/Set5/HR 37 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2 38 | filename_tmpl: '{}x2' 39 | io_backend: 40 | type: disk 41 | 42 | # network structures 43 | network_g: 44 | type: BIDiffSRUNet 45 | in_channel: 6 46 | out_channel: 3 47 | inner_channel: 64 48 | norm_groups: 16 49 | channel_mults: [1, 2, 4, 8] 50 | attn_res: [] 51 | res_blocks: 2 52 | dropout: 0.2 53 | image_size: 256 54 | fp_res: [256, 128] 55 | total_step: 2000 56 | dynamic_group: 5 # K 57 | 58 | # schedule 59 | beta_schedule: 60 | scheduler_type: DDIM 61 | schedule: linear 62 | n_timestep: 2000 63 | linear_start: !!float 1e-6 64 | linear_end: !!float 1e-2 65 | prediction_type: epsilon 66 | num_inference_steps: 50 67 | guidance_scale: 7.5 68 | is_guidance: False 69 | 70 | # path 71 | path: 72 | pretrain_network_g: ~ 73 | strict_load_g: true 74 | resume_state: ~ 75 | 76 | train: 77 | # ema_decay: 0.999 78 | optim_g: 79 | type: Adam 80 | lr: !!float 1e-4 81 | weight_decay: 0 82 | betas: [0.9, 0.99] 83 | 84 | scheduler: 85 | type: MultiStepLR 86 | milestones: [500000] 87 | gamma: 1 88 | 89 | total_iter: 1000000 90 | warmup_iter: -1 # no warm up 91 | 92 | # losses 93 | pixel_opt: 94 | type: L1Loss 95 | loss_weight: 1.0 96 | reduction: mean 97 | 98 | # validation settings 99 | val: 100 | val_freq: !!float 2e4 101 | save_img: false 102 | 103 | metrics: 104 | psnr: # metric name, can be arbitrary 105 | type: calculate_psnr 106 | crop_border: 2 107 | test_y_channel: true 108 | 109 | ssim: # metric name, can be arbitrary 110 | type: calculate_ssim 111 | crop_border: 2 112 | test_y_channel: true 113 | 114 | lpips: # metric name, can be arbitrary 115 | type: calculate_lpips 116 | crop_border: 2 117 | better: lower 118 | 119 | # logging settings 120 | logger: 121 | print_freq: 500 122 | save_checkpoint_freq: !!float 2e4 123 | use_tb_logger: true 124 | wandb: 125 | project: ~ 126 | resume_id: ~ 127 | 128 | # dist training settings 129 | dist_params: 130 | backend: nccl 131 | port: 29500 132 | -------------------------------------------------------------------------------- 
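The scheduler block above keeps the learning rate constant (MultiStepLR with gamma 1), but diffglv/utils/lr_scheduler.py also ships the cosine-restart schedulers shown earlier. A small, self-contained sketch of CosineAnnealingRestartLR with toy periods (not values from any config in this repo):

import torch
from diffglv.utils.lr_scheduler import CosineAnnealingRestartLR

param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-4)
# two 10-iteration cosine cycles; the second cycle restarts at half the base lr
scheduler = CosineAnnealingRestartLR(optimizer, periods=[10, 10], restart_weights=[1, 0.5], eta_min=1e-7)
for _ in range(20):
    optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr())  # decays 1e-4 -> ~1e-7, then restarts near 5e-5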
/options/train/train_BI_DiffSR_x4.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_BI_DiffSR_DDIM_S50_x4 3 | model_type: BIDiffSRModel 4 | scale: 4 5 | num_gpu: auto 6 | manual_seed: 10 7 | find_unused_parameters: True 8 | 9 | # dataset and data loader settings 10 | datasets: 11 | train: 12 | task: SR 13 | name: DF2K 14 | type: MultiPairedImageDataset 15 | dataroot_gt: datasets/DF2K/HR 16 | dataroot_lq: datasets/DF2K/LR_bicubic/X4 17 | filename_tmpl: '{}x4' 18 | io_backend: 19 | type: disk 20 | 21 | gt_size: 256 22 | use_hflip: True 23 | use_rot: True 24 | 25 | # data loader 26 | use_shuffle: True 27 | num_worker_per_gpu: 8 28 | batch_size_per_gpu: 4 29 | dataset_enlarge_ratio: 100 30 | prefetch_mode: ~ 31 | 32 | val: 33 | task: SR 34 | name: Set5 35 | type: MultiPairedImageDataset 36 | dataroot_gt: datasets/benchmark/Set5/HR 37 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4 38 | filename_tmpl: '{}x4' 39 | io_backend: 40 | type: disk 41 | 42 | # network structures 43 | network_g: 44 | type: BIDiffSRUNet 45 | in_channel: 6 46 | out_channel: 3 47 | inner_channel: 64 48 | norm_groups: 16 49 | channel_mults: [1, 2, 4, 8] 50 | attn_res: [] 51 | res_blocks: 2 52 | dropout: 0.2 53 | image_size: 256 54 | fp_res: [256, 128] 55 | total_step: 2000 56 | dynamic_group: 5 # K 57 | 58 | # schedule 59 | beta_schedule: 60 | scheduler_type: DDIM 61 | schedule: linear 62 | n_timestep: 2000 63 | linear_start: !!float 1e-6 64 | linear_end: !!float 1e-2 65 | prediction_type: epsilon 66 | num_inference_steps: 50 67 | guidance_scale: 7.5 68 | is_guidance: False 69 | 70 | # path 71 | path: 72 | pretrain_network_g: ~ 73 | strict_load_g: false 74 | resume_state: ~ 75 | 76 | train: 77 | # ema_decay: 0.999 78 | optim_g: 79 | type: Adam 80 | lr: !!float 1e-4 81 | weight_decay: 0 82 | betas: [0.9, 0.99] 83 | 84 | scheduler: 85 | type: MultiStepLR 86 | milestones: [500000] 87 | gamma: 1 88 | 89 | total_iter: 1000000 90 | warmup_iter: -1 # no warm up 91 | 92 | # losses 93 | pixel_opt: 94 | type: L1Loss 95 | loss_weight: 1.0 96 | reduction: mean 97 | 98 | # validation settings 99 | val: 100 | val_freq: !!float 2e4 101 | save_img: false 102 | 103 | metrics: 104 | psnr: # metric name, can be arbitrary 105 | type: calculate_psnr 106 | crop_border: 4 107 | test_y_channel: true 108 | 109 | ssim: # metric name, can be arbitrary 110 | type: calculate_ssim 111 | crop_border: 4 112 | test_y_channel: true 113 | 114 | lpips: # metric name, can be arbitrary 115 | type: calculate_lpips 116 | crop_border: 4 117 | better: lower 118 | 119 | # logging settings 120 | logger: 121 | print_freq: 500 122 | save_checkpoint_freq: !!float 2e4 123 | use_tb_logger: true 124 | wandb: 125 | project: ~ 126 | resume_id: ~ 127 | 128 | # dist training settings 129 | dist_params: 130 | backend: nccl 131 | port: 29500 132 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | torch==1.13.1+cu117 2 | torchvision==0.14.1+cu117 3 | addict 4 | future 5 | lmdb 6 | numpy<2.0.0 7 | opencv-python 8 | Pillow 9 | pyyaml 10 | requests 11 | scikit-image 12 | scipy 13 | tb-nightly 14 | tqdm 15 | yapf 16 | timm 17 | einops 18 | natsort 19 | joblib 20 | wandb 21 | lpips 22 | matplotlib 23 | seaborn 24 | scikit-learn 25 | basicsr -------------------------------------------------------------------------------- /results/README.md: 
--------------------------------------------------------------------------------
1 | The testing results.
2 | 
3 | 
-------------------------------------------------------------------------------- /test.py: --------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from os import path as osp
4 | 
5 | from basicsr.data import build_dataloader, build_dataset
6 | from basicsr.models import build_model
7 | from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs
8 | from basicsr.utils.options import dict2str
9 | from diffglv.utils.options import parse_options
10 | from basicsr.utils import set_random_seed
11 | import random
12 | 
13 | 
14 | def test_pipeline(root_path):
15 | # parse options, set distributed setting, set random seed
16 | opt, _ = parse_options(root_path, is_train=False)
17 | 
18 | torch.backends.cudnn.benchmark = True
19 | # torch.backends.cudnn.deterministic = True
20 | 
21 | # mkdir and initialize loggers
22 | make_exp_dirs(opt)
23 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log")
24 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
25 | logger.info(get_env_info())
26 | logger.info(dict2str(opt))
27 | 
28 | # create test dataset and dataloader
29 | test_loaders = []
30 | for _, dataset_opt in sorted(opt['datasets'].items()):
31 | test_set = build_dataset(dataset_opt)
32 | test_loader = build_dataloader(
33 | test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
34 | logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
35 | test_loaders.append(test_loader)
36 | 
37 | # create model
38 | model = build_model(opt)
39 | 
40 | for test_loader in test_loaders:
41 | seed = opt.get('manual_seed')
42 | if seed is None:
43 | seed = random.randint(1, 10000)
44 | opt['manual_seed'] = seed
45 | set_random_seed(seed + opt['rank'])
46 | test_set_name = test_loader.dataset.opt['name']
47 | logger.info(f'Testing {test_set_name}...')
48 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img'])
49 | 
50 | 
51 | if __name__ == '__main__':
52 | root_path = osp.abspath(osp.join(__file__, osp.pardir))
53 | test_pipeline(root_path)
54 | 
-------------------------------------------------------------------------------- /train.py: --------------------------------------------------------------------------------
1 | import os.path as osp
2 | import basicsr
3 | import diffglv
4 | 
5 | if __name__ == '__main__':
6 | root_path = osp.abspath(osp.join(__file__, osp.pardir))
7 | basicsr.train_pipeline(root_path)
8 | 
--------------------------------------------------------------------------------
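A closing note on option overrides: values passed through --force_yml (parsed by diffglv/utils/options.py above) arrive as strings and are coerced by _postprocess_yml_value before being written into the nested option dict. A hedged sketch of that coercion:

from diffglv.utils.options import _postprocess_yml_value

print(_postprocess_yml_value('none'))          # None
print(_postprocess_yml_value('true'))          # True
print(_postprocess_yml_value('!!float 1e-4'))  # 0.0001
print(_postprocess_yml_value('500000'))        # 500000 (int)
print(_postprocess_yml_value('[1, 2, 4, 8]'))  # [1, 2, 4, 8] (parsed via eval)
# e.g. `python test.py -opt options/test/test_BI_DiffSR_x4.yml --force_yml val:save_img=false`
# ends up as opt['val']['save_img'] = False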