├── .gitignore
├── LICENSE
├── README.md
├── datasets
│   └── README.md
├── diffglv
│   ├── __init__.py
│   ├── archs
│   │   ├── __init__.py
│   │   └── unet_BI_DiffSR_arch.py
│   ├── data
│   │   ├── __init__.py
│   │   ├── data_util.py
│   │   └── paired_image_dataset.py
│   ├── losses
│   │   ├── __init__.py
│   │   └── at_loss.py
│   ├── metrics
│   │   ├── __init__.py
│   │   └── lpips.py
│   ├── models
│   │   ├── BI_DiffSR_model.py
│   │   └── __init__.py
│   └── utils
│       ├── GPU_memory.py
│       ├── __init__.py
│       ├── base_model.py
│       ├── beta_schedule.py
│       ├── extract_subimages.py
│       ├── logger.py
│       ├── lr_scheduler.py
│       ├── make_ds.py
│       ├── options.py
│       └── transforms.py
├── experiments
│   ├── README.md
│   └── pretrained_models
│       └── README.md
├── figs
│   ├── BI-DiffSR.png
│   ├── F1.png
│   ├── F2-1.png
│   ├── F2-2.png
│   ├── F3-1.png
│   ├── F3-2.png
│   ├── T1.png
│   ├── compare
│   │   ├── ComS_img_023_BBCU_x4.png
│   │   ├── ComS_img_023_BI-DiffSR_x4.png
│   │   ├── ComS_img_023_Bicubic_x4.png
│   │   ├── ComS_img_023_HR_x4.png
│   │   ├── ComS_img_023_SR3_x4.png
│   │   ├── ComS_img_033_BBCU_x4.png
│   │   ├── ComS_img_033_BI-DiffSR_x4.png
│   │   ├── ComS_img_033_Bicubic_x4.png
│   │   ├── ComS_img_033_HR_x4.png
│   │   └── ComS_img_033_SR3_x4.png
│   └── logo.png
├── options
│   ├── test
│   │   ├── test_BI_DiffSR_x2.yml
│   │   └── test_BI_DiffSR_x4.yml
│   └── train
│       ├── train_BI_DiffSR_x2.yml
│       └── train_BI_DiffSR_x4.yml
├── requirements.txt
├── results
│   └── README.md
├── test.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignored folders
2 | datasets/*
3 | experiments/*
4 | results/*
5 | tb_logger/*
6 | wandb/*
7 | tmp/*
8 | slurm/*
9 | scripts/metrics/*
10 |
11 | options/euler/*
12 |
13 | *.DS_Store
14 | .idea
15 |
16 | # ignored files
17 | version.py
18 |
19 | # ignored files with suffix
20 | *.html
21 | *.png
22 | *.jpeg
23 | *.jpg
24 | *.gif
25 | *.pth
26 | *.zip
27 | *.npy
28 | *.pdf
29 |
30 | !figs/*.png
31 |
32 | # slurm
33 | *.err
34 | *.out
35 |
36 | # template
37 |
38 | # Byte-compiled / optimized / DLL files
39 | __pycache__/
40 | *.py[cod]
41 | *$py.class
42 |
43 | # C extensions
44 | *.so
45 |
46 | # Distribution / packaging
47 | .Python
48 | build/
49 | develop-eggs/
50 | dist/
51 | downloads/
52 | eggs/
53 | .eggs/
54 | lib/
55 | lib64/
56 | parts/
57 | sdist/
58 | var/
59 | wheels/
60 | *.egg-info/
61 | .installed.cfg
62 | *.egg
63 | MANIFEST
64 |
65 | # PyInstaller
66 | # Usually these files are written by a python script from a template
67 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
68 | *.manifest
69 | *.spec
70 |
71 | # Installer logs
72 | pip-log.txt
73 | pip-delete-this-directory.txt
74 |
75 | # Unit test / coverage reports
76 | htmlcov/
77 | .tox/
78 | .coverage
79 | .coverage.*
80 | .cache
81 | nosetests.xml
82 | coverage.xml
83 | *.cover
84 | .hypothesis/
85 | .pytest_cache/
86 |
87 | # Translations
88 | *.mo
89 | *.pot
90 |
91 | # Django stuff:
92 | *.log
93 | local_settings.py
94 | db.sqlite3
95 |
96 | # Flask stuff:
97 | instance/
98 | .webassets-cache
99 |
100 | # Scrapy stuff:
101 | .scrapy
102 |
103 | # Sphinx documentation
104 | docs/_build/
105 |
106 | # PyBuilder
107 | target/
108 |
109 | # Jupyter Notebook
110 | .ipynb_checkpoints
111 |
112 | # pyenv
113 | .python-version
114 |
115 | # celery beat schedule file
116 | celerybeat-schedule
117 |
118 | # SageMath parsed files
119 | *.sage.py
120 |
121 | # Environments
122 | .env
123 | .venv
124 | env/
125 | venv/
126 | ENV/
127 | env.bak/
128 | venv.bak/
129 |
130 | # Spyder project settings
131 | .spyderproject
132 | .spyproject
133 |
134 | # Rope project settings
135 | .ropeproject
136 |
137 | # mkdocs documentation
138 | /site
139 |
140 | # mypy
141 | .mypy_cache/
142 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2024 BI-DiffSR Authors
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | # Binarized Diffusion Model for Image Super-Resolution
6 |
7 | [Zheng Chen](https://zhengchen1999.github.io/), [Haotong Qin](https://htqin.github.io/), [Yong Guo](https://www.guoyongcs.com/), [Xiongfei Su](https://ieeexplore.ieee.org/author/37086348852), [Xin Yuan](https://en.westlake.edu.cn/faculty/xin-yuan.html), [Linghe Kong](https://www.cs.sjtu.edu.cn/~linghe.kong/), and [Yulun Zhang](http://yulunzhang.com/), "Binarized Diffusion Model for Image Super-Resolution", NeurIPS, 2024
8 |
9 | [[project](https://zhengchen1999.github.io/BI-DiffSR/)] [[arXiv](https://arxiv.org/abs/2406.05723)] [[supplementary material](https://github.com/zhengchen1999/BI-DiffSR/releases/download/v1/Supplementary_Material.pdf)] [[visual results](https://drive.google.com/drive/folders/1-Mfy8XHG55Bc19gAXqNaNitO0GEx7O1r?usp=drive_link)] [[pretrained models](https://drive.google.com/drive/folders/1hoHAG2yoLltloQ0SYv-QLxwk9Y8ZnTnH?usp=drive_link)]
10 |
11 |
12 |
13 | #### 🔥🔥🔥 News
14 |
15 | - **2024-10-23:** [Project Page](https://zhengchen1999.github.io/BI-DiffSR/) is accessible. 📃📃📃
16 | - **2024-10-14:** Code and pre-trained models are released. ⭐️⭐️⭐️
17 | - **2024-09-26:** BI-DiffSR is accepted at NeurIPS 2024. 🎉🎉🎉
18 | - **2024-06-09:** This repo is released.
19 |
20 | ---
21 |
22 | > **Abstract:** Advanced diffusion models (DMs) perform impressively in image super-resolution (SR), but their high memory and computational costs hinder deployment. Binarization, an ultra-compression algorithm, offers the potential to accelerate DMs effectively. Nonetheless, due to the model structure and the multi-step iterative attribute of DMs, existing binarization methods result in significant performance degradation. In this paper, we introduce a novel binarized diffusion model, BI-DiffSR, for image SR. First, for the model structure, we design a UNet architecture optimized for binarization. We propose the consistent-pixel-downsample (CP-Down) and consistent-pixel-upsample (CP-Up) to maintain dimension consistency and facilitate full-precision information transfer. Meanwhile, we design the channel-shuffle-fusion (CS-Fusion) to enhance feature fusion in the skip connection. Second, for the activation difference across timesteps, we design the timestep-aware redistribution (TaR) and activation function (TaA). The TaR and TaA dynamically adjust the distribution of activations based on different timesteps, improving the flexibility and representation ability of the binarized module. Comprehensive experiments demonstrate that our BI-DiffSR outperforms existing binarization methods.
23 |
24 | <img src="figs/BI-DiffSR.png" width="90%">
25 |
26 | ---
27 |
28 | | HR | LR | [SR3 (FP)](https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement) | [BBCU](https://github.com/Zj-BinXia/BBCU) | BI-DiffSR (ours) |
29 | | :-------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :---------------------------------------------------------: | :----------------------------------------------------------: |
30 | | <img src="figs/compare/ComS_img_023_HR_x4.png"> | <img src="figs/compare/ComS_img_023_Bicubic_x4.png"> | <img src="figs/compare/ComS_img_023_SR3_x4.png"> | <img src="figs/compare/ComS_img_023_BBCU_x4.png"> | <img src="figs/compare/ComS_img_023_BI-DiffSR_x4.png"> |
31 | | <img src="figs/compare/ComS_img_033_HR_x4.png"> | <img src="figs/compare/ComS_img_033_Bicubic_x4.png"> | <img src="figs/compare/ComS_img_033_SR3_x4.png"> | <img src="figs/compare/ComS_img_033_BBCU_x4.png"> | <img src="figs/compare/ComS_img_033_BI-DiffSR_x4.png"> |
32 |
33 | ## TODO
34 |
35 | * [x] Release code and pretrained models
36 |
37 | ## Dependencies
38 |
39 | - Python 3.9
40 | - PyTorch 1.13.1+cu117
41 |
42 | ```bash
43 | # Clone the GitHub repo and enter the default directory 'BI-DiffSR'.
44 | git clone https://github.com/zhengchen1999/BI-DiffSR.git && cd BI-DiffSR
45 | conda create -n bi_diffsr python=3.9
46 | conda activate bi_diffsr
47 | pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
48 | git clone https://github.com/huggingface/diffusers.git
49 | cd diffusers
50 | pip install -e ".[torch]"
51 | ```
52 |
53 | ## Contents
54 |
55 | 1. [Datasets](#datasets)
56 | 1. [Models](#models)
57 | 1. [Training](#training)
58 | 1. [Testing](#testing)
59 | 1. [Results](#results)
60 | 1. [Citation](#citation)
61 | 1. [Acknowledgements](#acknowledgements)
62 |
63 | ## Datasets
64 |
65 | The training and testing sets used in the paper can be downloaded as follows:
66 |
67 | | Training Set | Testing Set | Visual Results |
68 | | :----------------------------------------------------------- | :----------------------------------------------------------: | :----------------------------------------------------------: |
69 | | [DIV2K](https://data.vision.ee.ethz.ch/cvl/DIV2K/) (800 training images, 100 validation images) + [Flickr2K](https://cv.snu.ac.kr/research/EDSR/Flickr2K.tar) (2650 images) [complete training dataset DF2K: [Google Drive](https://drive.google.com/file/d/1TubDkirxl4qAWelfOnpwaSKoj3KLAIG4/view?usp=share_link) / [Baidu Disk](https://pan.baidu.com/s/1KIcPNz3qDsGSM0uDKl4DRw?pwd=74yc)] | Set5 + Set14 + BSD100 + Urban100 + Manga109 [complete testing dataset: [Google Drive](https://drive.google.com/file/d/1yMbItvFKVaCT93yPWmlP3883XtJ-wSee/view?usp=sharing) / [Baidu Disk](https://pan.baidu.com/s/1Tf8WT14vhlA49TO2lz3Y1Q?pwd=8xen)] | [Google Drive](https://drive.google.com/drive/folders/1ZMaZyCer44ZX6tdcDmjIrc_hSsKoMKg2?usp=drive_link) / [Baidu Disk](https://pan.baidu.com/s/1LO-INqy40F5T_coAJsl5qw?pwd=dqnv#list/path=%2F) |
70 |
71 | Download training and testing datasets and put them into the corresponding folders of `datasets/`.
72 |
73 | ## Models
74 |
75 | | Method | Params (M) | FLOPs (G) | PSNR (dB) | LPIPS | Model Zoo | Visual Results |
76 | | :-------- | :--------: | :-------: | :-------: | :----: | :----------------------------------------------------------: | :----------------------------------------------------------: |
77 | | BI-DiffSR | 4.58 | 36.67 | 24.11 | 0.1823 | [Google Drive](https://drive.google.com/drive/folders/1hoHAG2yoLltloQ0SYv-QLxwk9Y8ZnTnH?usp=sharing) | [Google Drive](https://drive.google.com/drive/folders/1-Mfy8XHG55Bc19gAXqNaNitO0GEx7O1r?usp=sharing) |
78 |
79 | Performance is reported on Urban100 (×4). FLOPs are measured for an output size of 3×256×256.
80 |
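As a sanity check on model size, the parameter count can be computed directly from the released architecture. A minimal sketch; the hyper-parameters below mirror the arch file's own `__main__` example and are illustrative, while the actual ×2/×4 settings live in `options/train/*.yml`:

```python
from diffglv.archs.unet_BI_DiffSR_arch import BIDiffSRUNet

# Illustrative configuration, copied from the arch file's __main__ block.
model = BIDiffSRUNet(in_channel=6, out_channel=3, inner_channel=64,
                     norm_groups=16, channel_mults=[1, 2, 4, 8],
                     attn_res=[], res_blocks=2, dropout=0.2,
                     image_size=256, fp_res=[256, 128], dynamic_group=5)
print(f'Params: {sum(p.numel() for p in model.parameters()) / 1e6:.2f} M')
```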
81 | ## Training
82 |
83 | - The ×2 task requires **4×8 GB** VRAM (4 GPUs with 8 GB each), and the ×4 task requires **4×20 GB** VRAM (4 GPUs with 20 GB each).
84 |
85 | - Download [training](https://drive.google.com/file/d/1TubDkirxl4qAWelfOnpwaSKoj3KLAIG4/view?usp=share_link) (DF2K, already processed) and [testing](https://drive.google.com/file/d/1yMbItvFKVaCT93yPWmlP3883XtJ-wSee/view?usp=sharing) (Set5, BSD100, Urban100, Manga109, already processed) datasets, place them in `datasets/`.
86 |
87 | - Run the following scripts. The training configuration is in `options/train/`.
88 |
89 | ```shell
90 | # BI-DiffSR, input=64x64, 4 GPUs
91 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train/train_BI_DiffSR_x2.yml --launcher pytorch
92 | python -m torch.distributed.launch --nproc_per_node=4 --master_port=4321 train.py -opt options/train/train_BI_DiffSR_x4.yml --launcher pytorch
93 | ```
94 |
95 | - The training experiment is in `experiments/`.
96 |
97 | ## Testing
98 |
99 | - Download the pre-trained [models](https://drive.google.com/drive/folders/1hoHAG2yoLltloQ0SYv-QLxwk9Y8ZnTnH?usp=sharing) and place them in `experiments/pretrained_models/`.
100 |
101 | We provide pre-trained models for image SR (×2, ×4).
102 |
103 | - Download [testing](https://drive.google.com/file/d/1yMbItvFKVaCT93yPWmlP3883XtJ-wSee/view?usp=sharing) (Set5, BSD100, Urban100, Manga109) datasets, place them in `datasets/`.
104 |
105 | - Run the following scripts. The testing configuration is in `options/test/`.
106 |
107 | ```shell
108 | # BI-DiffSR, reproduces results in Table 2 of the main paper
109 | python test.py -opt options/test/test_BI_DiffSR_x2.yml
110 | python test.py -opt options/test/test_BI_DiffSR_x4.yml
111 | ```
112 |
113 | Due to the randomness of the diffusion model ([diffusers](https://huggingface.co/docs/diffusers)), results may vary slightly; see the seeding sketch after this list.
114 |
115 | - The output is in `results/`.
116 |
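Fixing the global seeds before sampling narrows this run-to-run variation (bit-exact reproduction across GPUs and library versions is still not guaranteed). A minimal sketch; BasicSR-style yml configs usually also expose a `manual_seed` option:

```python
import random
import numpy as np
import torch

def set_seed(seed: int = 10):
    # Seed Python, NumPy, and PyTorch (CPU and all visible GPUs).
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

set_seed(10)
```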
117 | ## Results
118 |
119 | We achieve state-of-the-art performance. Detailed results can be found in the paper.
120 |
121 | <details>
122 | <summary>Quantitative Comparisons (click to expand)</summary>
123 |
124 | - Results in Table 2 (main paper)
125 |
126 |   <img src="figs/T1.png" width="80%">
127 |
128 | </details>
129 |
130 |
131 |
132 |
133 | <details>
134 | <summary>Visual Comparisons (click to expand)</summary>
135 |
136 |
137 | - Results in Figure 8 (main paper)
138 |
139 |   <img src="figs/F1.png" width="80%">
140 |
141 |
142 |
143 |
144 |
145 | - Results in Figure 5 (supplemental material)
146 |
147 |   <img src="figs/F2-1.png" width="80%">
148 |   <img src="figs/F2-2.png" width="80%">
149 |
150 |
151 |
152 |
153 |
154 |
155 | - Results in Figure 6 (supplemental material)
156 |
157 |   <img src="figs/F3-1.png" width="80%">
158 |   <img src="figs/F3-2.png" width="80%">
159 |
160 |
161 |
162 |
163 |
164 | </details>
165 |
166 | ## Citation
167 |
168 | If you find the code helpful in your research or work, please cite the following paper(s).
169 |
170 | ```
171 | @inproceedings{chen2024binarized,
172 | title={Binarized Diffusion Model for Image Super-Resolution},
173 | author={Chen, Zheng and Qin, Haotong and Guo, Yong and Su, Xiongfei and Yuan, Xin and Kong, Linghe and Zhang, Yulun},
174 | booktitle={NeurIPS},
175 | year={2024}
176 | }
177 | ```
178 |
179 |
180 |
181 | ## Acknowledgements
182 |
183 | This code is built on [BasicSR](https://github.com/XPixelGroup/BasicSR) and [Image-Super-Resolution-via-Iterative-Refinement](https://github.com/Janspiry/Image-Super-Resolution-via-Iterative-Refinement).
184 |
--------------------------------------------------------------------------------
/datasets/README.md:
--------------------------------------------------------------------------------
1 | For training and testing, the directory structure is as follows:
2 |
3 | ```shell
4 | |-- datasets
5 | # train
6 | |-- DF2K
7 | |-- HR
8 | |-- LR_bicubic
9 | |-- X2
10 | |-- X3
11 | |-- X4
12 | # test
13 | |-- benchmark
14 | |-- Set5
15 | |-- HR
16 | |-- LR_bicubic
17 | |-- X2
18 | |-- X3
19 | |-- X4
20 | |-- B100
21 | |-- HR
22 | |-- LR_bicubic
23 | |-- X2
24 | |-- X3
25 | |-- X4
26 | |-- Urban100
27 | |-- HR
28 | |-- LR_bicubic
29 | |-- X2
30 | |-- X3
31 | |-- X4
32 | |-- Manga109
33 | |-- HR
34 | |-- LR_bicubic
35 | |-- X2
36 | |-- X3
37 | |-- X4
38 | ```
39 |
40 | You can also download the complete datasets we have collected; download links are provided in the main README.
41 |
--------------------------------------------------------------------------------
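Before training, it can help to sanity-check the layout above. A minimal sketch, assuming the `datasets/` root shown in the tree (the subset of checked paths is illustrative):

```python
from pathlib import Path

root = Path('datasets')  # assumed root, as in the tree above
expected = [
    'DF2K/HR', 'DF2K/LR_bicubic/X2', 'DF2K/LR_bicubic/X4',
    'benchmark/Set5/HR', 'benchmark/Set5/LR_bicubic/X4',
    'benchmark/Urban100/HR', 'benchmark/Urban100/LR_bicubic/X4',
]
for rel in expected:
    path = root / rel
    print(f"{'ok' if path.is_dir() else 'MISSING':10s}{path}")
```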
/diffglv/__init__.py:
--------------------------------------------------------------------------------
1 | from .archs import *
2 | from .data import *
3 | from .losses import *
4 | from .metrics import *
5 | from .models import *
6 |
7 |
--------------------------------------------------------------------------------
/diffglv/archs/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from os import path as osp
4 |
5 | from basicsr.utils import get_root_logger, scandir
6 | from basicsr.utils.registry import ARCH_REGISTRY
7 |
8 | __all__ = ['build_network']
9 |
10 | # automatically scan and import arch modules for registry
11 | # scan all the files under the 'archs' folder and collect files ending with
12 | # '_arch.py'
13 | arch_folder = osp.dirname(osp.abspath(__file__))
14 | arch_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) if v.endswith('_arch.py')]
15 | # import all the arch modules
16 | _arch_modules = [importlib.import_module(f'diffglv.archs.{file_name}') for file_name in arch_filenames]
17 |
18 |
19 | def build_network(opt):
20 | opt = deepcopy(opt)
21 | network_type = opt.pop('type')
22 | net = ARCH_REGISTRY.get(network_type)(**opt)
23 | logger = get_root_logger()
24 | logger.info(f'Network [{net.__class__.__name__}] is created.')
25 | return net
26 |
--------------------------------------------------------------------------------
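For reference, `build_network` pops `type` from the options dict to look up a class in `ARCH_REGISTRY` and forwards the remaining keys as constructor kwargs. A minimal usage sketch with the registered `BIDiffSRUNet` (hyper-parameter values are illustrative, not the released config):

```python
from diffglv.archs import build_network

opt = {
    'type': 'BIDiffSRUNet',  # registry key; remaining keys become kwargs
    'in_channel': 6, 'out_channel': 3, 'inner_channel': 64,
    'norm_groups': 16, 'channel_mults': [1, 2, 4, 8], 'attn_res': [],
    'res_blocks': 2, 'image_size': 256, 'fp_res': [256, 128],
    'dynamic_group': 5,
}
net = build_network(opt)
```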
/diffglv/archs/unet_BI_DiffSR_arch.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
4 | import torch.nn.functional as F
5 | from inspect import isfunction
6 |
7 | from basicsr.utils.registry import ARCH_REGISTRY
8 |
9 |
10 |
11 | def exists(x):
12 | return x is not None
13 |
14 | def default(val, d):
15 | if exists(val):
16 | return val
17 | return d() if isfunction(d) else d
18 |
19 | # --------------------------------------------- BI Basic Units: START -----------------------------------------------------------------
20 | class RPReLU(nn.Module):
21 | def __init__(self, inplanes):
22 | super(RPReLU, self).__init__()
23 | self.pr_bias0 = LearnableBias(inplanes)
24 | self.pr_prelu = nn.PReLU(inplanes)
25 | self.pr_bias1 = LearnableBias(inplanes)
26 |
27 | def forward(self, x):
28 | x = self.pr_bias1(self.pr_prelu(self.pr_bias0(x)))
29 | return x
30 |
31 | class BinaryActivation(nn.Module):
32 | def __init__(self):
33 | super(BinaryActivation, self).__init__()
34 |
35 | def forward(self, x):
36 |         out_forward = torch.sign(x)
37 |         mask1 = x < -1
38 |         mask2 = x < 0
39 |         mask3 = x < 1
40 |         out1 = (-1) * mask1.type(torch.float32) + (x*x + 2*x) * (1 - mask1.type(torch.float32))  # -1 where x < -1, else x^2 + 2x
41 |         out2 = out1 * mask2.type(torch.float32) + (-x*x + 2*x) * (1 - mask2.type(torch.float32))  # keep previous where x < 0, else -x^2 + 2x
42 |         out3 = out2 * mask3.type(torch.float32) + 1 * (1 - mask3.type(torch.float32))  # keep previous where x < 1, else +1
43 |         out = out_forward.detach() - out3.detach() + out3  # straight-through estimator: sign() in forward, piecewise-polynomial gradient in backward
44 |
45 | return out
46 |
47 | class LearnableBias(nn.Module):
48 | def __init__(self, out_chn):
49 | super(LearnableBias, self).__init__()
50 | self.bias = nn.Parameter(torch.zeros(1,out_chn,1,1), requires_grad=True)
51 |
52 | def forward(self, x):
53 | out = x + self.bias.expand_as(x)
54 | return out
55 |
56 | class HardBinaryConv(nn.Conv2d):
57 | def __init__(self, in_chn, out_chn, kernel_size=3, stride=1, padding=1,groups=1,bias=True):
58 | super(HardBinaryConv, self).__init__(
59 | in_chn,
60 | out_chn,
61 | kernel_size,
62 | stride=stride,
63 | padding=padding,
64 | groups=groups,
65 | bias=bias
66 | )
67 |
68 |     def forward(self, x):
69 |         real_weights = self.weight
70 |         scaling_factor = torch.mean(torch.mean(torch.mean(abs(real_weights), dim=3, keepdim=True), dim=2, keepdim=True), dim=1, keepdim=True)  # per-output-channel mean |W|
71 |         scaling_factor = scaling_factor.detach()
72 |         binary_weights_no_grad = scaling_factor * torch.sign(real_weights)
73 |         clipped_weights = torch.clamp(real_weights, -1.0, 1.0)
74 |         binary_weights = binary_weights_no_grad.detach() - clipped_weights.detach() + clipped_weights  # STE: scaled binary weights in forward, clipped real-weight gradient in backward
75 |         y = F.conv2d(x, binary_weights, self.bias, stride=self.stride, padding=self.padding, groups=self.groups)
76 |         return y
77 |
78 | class BIConv(nn.Module):
79 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, bias=False, groups=1, dynamic_group=5):
80 | super(BIConv, self).__init__()
81 | self.TaR = nn.ModuleDict({
82 | f'dynamic_move_{i}': LearnableBias(in_channels) for i in range(dynamic_group)
83 | })
84 |
85 | self.binary_activation = BinaryActivation()
86 | self.binary_conv = HardBinaryConv(in_channels,
87 | out_channels,
88 | kernel_size,
89 | padding=(kernel_size//2),
90 | bias=bias)
91 | self.TaA = nn.ModuleDict({
92 | f'dynamic_relu_{i}': RPReLU(out_channels) for i in range(dynamic_group)
93 | })
94 |
95 | def forward(self, x, t):
96 | out = self.TaR[f'dynamic_move_{t}'](x)
97 | out = self.binary_activation(out)
98 | out = self.binary_conv(out)
99 | out = self.TaA[f'dynamic_relu_{t}'](out)
100 | out = out + x
101 | return out
102 | # --------------------------------------------- BI Basic Units: END -----------------------------------------------------------------
103 |
104 | # --------------------------------------------- FP Module: START --------------------------------------------------------------------
105 | # PositionalEncoding Source: https://github.com/lmnt-com/wavegrad/blob/master/src/wavegrad/model.py
106 | class PositionalEncoding(nn.Module):
107 | def __init__(self, dim):
108 | super().__init__()
109 | self.dim = dim
110 |
111 | def forward(self, timestep_level):
112 | count = self.dim // 2
113 | step = torch.arange(count, dtype=timestep_level.dtype,
114 | device=timestep_level.device) / count
115 | encoding = timestep_level.unsqueeze(
116 | 1) * torch.exp(-math.log(1e4) * step.unsqueeze(0))
117 | encoding = torch.cat(
118 | [torch.sin(encoding), torch.cos(encoding)], dim=-1)
119 | return encoding
120 |
121 | class Swish(nn.Module):
122 | def forward(self, x):
123 | return x * torch.sigmoid(x)
124 |
125 | class Block(nn.Module):
126 | def __init__(self, dim, dim_out, groups=32, dropout=0):
127 | super().__init__()
128 | self.block = nn.Sequential(
129 | nn.GroupNorm(groups, dim),
130 | Swish(),
131 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
132 | )
133 |
134 |         assert dim == dim_out, f"Error: input ({dim}) and output ({dim_out}) channel dimensions must match."
135 |
136 | if dim == dim_out:
137 | self.conv = nn.Conv2d(dim, dim_out, 3, padding=1)
138 |
139 | def forward(self, x):
140 | return self.conv(self.block(x))
141 |
142 | class Block_F(nn.Module):
143 | def __init__(self, dim, dim_out, groups=32, dropout=0):
144 | super().__init__()
145 | self.block = nn.Sequential(
146 | nn.GroupNorm(groups, dim),
147 | Swish(),
148 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
149 | nn.Conv2d(dim, dim_out, 3, padding=1)
150 | )
151 |
152 | def forward(self, x):
153 | return self.block(x)
154 |
155 | class SelfAttention(nn.Module):
156 | def __init__(self, in_channel, n_head=1, norm_groups=32):
157 | super().__init__()
158 |
159 | self.n_head = n_head
160 |
161 | self.norm = nn.GroupNorm(norm_groups, in_channel)
162 | self.qkv = nn.Conv2d(in_channel, in_channel * 3, 1, bias=False)
163 | self.out = nn.Conv2d(in_channel, in_channel, 1)
164 |
165 | def forward(self, input):
166 | batch, channel, height, width = input.shape
167 | n_head = self.n_head
168 | head_dim = channel // n_head
169 |
170 | norm = self.norm(input)
171 | qkv = self.qkv(norm).view(batch, n_head, head_dim * 3, height, width)
172 | query, key, value = qkv.chunk(3, dim=2) # bhdyx
173 |
174 | attn = torch.einsum(
175 | "bnchw, bncyx -> bnhwyx", query, key
176 | ).contiguous() / math.sqrt(channel)
177 | attn = attn.view(batch, n_head, height, width, -1)
178 | attn = torch.softmax(attn, -1)
179 | attn = attn.view(batch, n_head, height, width, height, width)
180 |
181 | out = torch.einsum("bnhwyx, bncyx -> bnchw", attn, value).contiguous()
182 | out = self.out(out.view(batch, channel, height, width))
183 |
184 | return out + input
185 |
186 | class CP_Up_FP(nn.Module):
187 | def __init__(self, dim):
188 | super().__init__()
189 | self.biconv1 = nn.Conv2d(dim, dim, 3, 1, 1)
190 | self.biconv2 = nn.Conv2d(dim, dim, 3, 1, 1)
191 | self.up = nn.PixelShuffle(2)
192 |
193 | def forward(self, x):
194 | '''
195 | input: b,c,h,w
196 | output: b,c/2,h*2,w*2
197 | '''
198 | out1 = self.biconv1(x)
199 | out2 = self.biconv2(x)
200 | out = torch.cat([out1, out2], dim=1)
201 | out = self.up(out)
202 | return out
203 |
204 | class CP_Down_FP(nn.Module):
205 | def __init__(self, dim):
206 | super().__init__()
207 | self.biconv1 = nn.Conv2d(dim//2, dim//2, 3, padding=1)
208 | self.biconv2 = nn.Conv2d(dim//2, dim//2, 3, padding=1)
209 | self.down = nn.PixelUnshuffle(2)
210 |
211 | def forward(self, x):
212 | '''
213 | input: b,c,h,w
214 | output: b,2c,h/2,w/2
215 | '''
216 | b,c,h,w = x.shape
217 | out1 = self.biconv1(x[:,:c//2,:,:])
218 | out2 = self.biconv2(x[:,c//2:,:,:])
219 | out = out1 + out2
220 | out = self.down(out)
221 | return out
222 |
223 | class CS_Fusion_FP(nn.Module):
224 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False, groups=1):
225 | super(CS_Fusion_FP, self).__init__()
226 |
227 |         assert in_channels // 2 == out_channels, f"Error: output channels ({out_channels}) must be half of input channels ({in_channels})."
228 |
229 | self.biconv_1 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, groups=groups)
230 | self.biconv_2 = nn.Conv2d(out_channels, out_channels, kernel_size, stride, padding, groups=groups)
231 |
232 | def forward(self, x):
233 | '''
234 | x: b,c,h,w
235 | out: b,c/2,h,w
236 | '''
237 | b,c,h,w = x.shape
238 | in_1 = x[:,:c//2,:,:]
239 | in_2 = x[:,c//2:,:,:]
240 |
241 | fu_1 = torch.cat((in_1[:, 1::2, :, :], in_2[:, 0::2, :, :]), dim=1)
242 | fu_2 = torch.cat((in_1[:, 0::2, :, :], in_2[:, 1::2, :, :]), dim=1)
243 |
244 | out_1 = self.biconv_1(fu_1)
245 | out_2 = self.biconv_2(fu_2)
246 |
247 | out = out_1 + out_2
248 | return out
249 |
250 | class ResnetBlock(nn.Module):
251 | def __init__(self, dim, dim_out, timestep_level_emb_dim=None, dropout=0, norm_groups=32):
252 | super().__init__()
253 | self.timestep_func = nn.Sequential(
254 | Swish(),
255 | nn.Linear(timestep_level_emb_dim, dim_out)
256 | )
257 |
258 |         assert dim == dim_out, f"Error: input ({dim}) and output ({dim_out}) channel dimensions must match."
259 |
260 | self.block1 = Block(dim, dim_out, groups=norm_groups)
261 | self.block2 = Block(dim_out, dim_out, groups=norm_groups, dropout=dropout)
262 | if dim == dim_out:
263 | self.res_conv = nn.Identity()
264 |
265 | def forward(self, x, time_emb):
266 | b, c, h, w = x.shape
267 | h = self.block1(x)
268 | t_emb = self.timestep_func(time_emb).type(h.dtype)
269 | h = h + t_emb[:, :, None, None]
270 | h = self.block2(h)
271 | return h + self.res_conv(x)
272 |
273 | class ResnetBlocWithAttn(nn.Module):
274 | def __init__(self, dim, dim_out, *, timestep_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False):
275 | super().__init__()
276 | self.with_attn = with_attn
277 | self.res_block = ResnetBlock(
278 | dim, dim_out, timestep_level_emb_dim, norm_groups=norm_groups, dropout=dropout)
279 | if with_attn:
280 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
281 |
282 | def forward(self, x, time_emb):
283 | x = self.res_block(x, time_emb)
284 | if(self.with_attn):
285 | x = self.attn(x)
286 | return x
287 | # -------------------------------------------------- FP Module: END ----------------------------------------------------------------------
288 |
289 | # -------------------------------------------------- BI Module: START --------------------------------------------------------------------
290 | class CP_Up(nn.Module):
291 | def __init__(self, dim, dynamic_group=5):
292 | super().__init__()
293 | self.biconv1 = BIConv(dim, dim, 3, 1, 1, dynamic_group=dynamic_group)
294 | self.biconv2 = BIConv(dim, dim, 3, 1, 1, dynamic_group=dynamic_group)
295 | self.up = nn.PixelShuffle(2)
296 |
297 | def forward(self, x, t):
298 | '''
299 | input: b,c,h,w
300 | output: b,c/2,h*2,w*2
301 | '''
302 | out1 = self.biconv1(x, t)
303 | out2 = self.biconv2(x, t)
304 | out = torch.cat([out1, out2], dim=1)
305 | out = self.up(out)
306 | return out
307 |
308 | class CP_Down(nn.Module):
309 | def __init__(self, dim, dynamic_group=5):
310 | super().__init__()
311 | self.biconv1 = BIConv(dim//2, dim//2, 3, padding=1, dynamic_group=dynamic_group)
312 | self.biconv2 = BIConv(dim//2, dim//2, 3, padding=1, dynamic_group=dynamic_group)
313 | self.down = nn.PixelUnshuffle(2)
314 |
315 | def forward(self, x, t):
316 | '''
317 | input: b,c,h,w
318 | output: b,2c,h/2,w/2
319 | '''
320 | b,c,h,w = x.shape
321 | out1 = self.biconv1(x[:,:c//2,:,:], t)
322 | out2 = self.biconv2(x[:,c//2:,:,:], t)
323 | out = out1 + out2
324 | out = self.down(out)
325 | return out
326 |
327 | class CS_Fusion(nn.Module):
328 | def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=False, groups=1, dynamic_group=5):
329 | super(CS_Fusion, self).__init__()
330 |
331 |         assert in_channels // 2 == out_channels, f"Error: output channels ({out_channels}) must be half of input channels ({in_channels})."
332 |
333 | self.biconv_1 = BIConv(out_channels, out_channels, kernel_size, stride, padding, bias, groups, dynamic_group=dynamic_group)
334 | self.biconv_2 = BIConv(out_channels, out_channels, kernel_size, stride, padding, bias, groups, dynamic_group=dynamic_group)
335 |
336 | def forward(self, x, t):
337 | '''
338 | x: b,c,h,w
339 | out: b,c/2,h,w
340 | '''
341 | b,c,h,w = x.shape
342 | in_1 = x[:,:c//2,:,:]
343 | in_2 = x[:,c//2:,:,:]
344 |
345 | fu_1 = torch.cat((in_1[:, 1::2, :, :], in_2[:, 0::2, :, :]), dim=1)
346 | fu_2 = torch.cat((in_1[:, 0::2, :, :], in_2[:, 1::2, :, :]), dim=1)
347 |
348 | out_1 = self.biconv_1(fu_1, t)
349 | out_2 = self.biconv_2(fu_2, t)
350 |
351 | out = out_1 + out_2
352 | return out
353 |
354 | class BI_Block(nn.Module):
355 | def __init__(self, dim, dim_out, groups=32, dropout=0, dynamic_group=5):
356 | super().__init__()
357 | self.block = nn.Sequential(
358 | nn.GroupNorm(groups, dim),
359 | Swish(),
360 | nn.Dropout(dropout) if dropout != 0 else nn.Identity(),
361 | )
362 |
363 |         assert dim == dim_out, f"Error: input ({dim}) and output ({dim_out}) channel dimensions must match."
364 |
365 | if dim == dim_out:
366 | self.conv = BIConv(dim, dim_out, 3, padding=1, dynamic_group=dynamic_group)
367 |
368 | def forward(self, x, t):
369 | return self.conv(self.block(x), t)
370 |
371 | class BI_ResnetBlock(nn.Module):
372 | def __init__(self, dim, dim_out, timestep_level_emb_dim=None, dropout=0, norm_groups=32, dynamic_group=5):
373 | super().__init__()
374 | self.timestep_func = nn.Sequential(
375 | Swish(),
376 | nn.Linear(timestep_level_emb_dim, dim_out)
377 | )
378 |
379 |         assert dim == dim_out, f"Error: input ({dim}) and output ({dim_out}) channel dimensions must match."
380 |
381 | self.block1 = BI_Block(dim, dim_out, groups=norm_groups, dynamic_group=dynamic_group)
382 | self.block2 = BI_Block(dim_out, dim_out, groups=norm_groups, dropout=dropout, dynamic_group=dynamic_group)
383 | if dim == dim_out:
384 | self.res_conv = nn.Identity()
385 |
386 | def forward(self, x, time_emb, t):
387 | b, c, h, w = x.shape
388 | h = self.block1(x, t)
389 | t_emb = self.timestep_func(time_emb).type(h.dtype)
390 | h = h + t_emb[:, :, None, None]
391 | h = self.block2(h, t)
392 | return h + self.res_conv(x)
393 |
394 | class BI_ResnetBlocWithAttn(nn.Module):
395 | def __init__(self, dim, dim_out, *, timestep_level_emb_dim=None, norm_groups=32, dropout=0, with_attn=False, dynamic_group=5):
396 | super().__init__()
397 | self.with_attn = with_attn
398 | self.res_block = BI_ResnetBlock(
399 | dim, dim_out, timestep_level_emb_dim, norm_groups=norm_groups, dropout=dropout, dynamic_group=dynamic_group)
400 | if with_attn:
401 | self.attn = SelfAttention(dim_out, norm_groups=norm_groups)
402 |
403 | def forward(self, x, time_emb, t):
404 | x = self.res_block(x, time_emb, t)
405 | if(self.with_attn):
406 | x = self.attn(x)
407 | return x
408 | # -------------------------------------------------- BI Module: END --------------------------------------------------------------------
409 |
410 | # ----------------------------------------------- BI-DiffSR UNet: START ----------------------------------------------------------------
411 | @ARCH_REGISTRY.register()
412 | class BIDiffSRUNet(nn.Module):
413 | def __init__(
414 | self,
415 | in_channel=6,
416 | out_channel=3,
417 | inner_channel=32,
418 | norm_groups=32,
419 | channel_mults=(1, 2, 4, 8, 8),
420 |         attn_res=(8,),  # (8,) keeps this a tuple so the 'now_res in attn_res' check works
421 | res_blocks=3,
422 | dropout=0,
423 | image_size=128,
424 |         fp_res=(0,),  # (0,) keeps this a tuple so the 'now_res in fp_res' check works
425 | total_step=2000,
426 | dynamic_group=5
427 | ):
428 | super(BIDiffSRUNet, self).__init__()
429 |
430 | self.in_channel = in_channel
431 | self.total_step = total_step
432 | self.dynamic_group = dynamic_group
433 |
434 | timestep_level_channel = inner_channel
435 | self.timestep_level_mlp = nn.Sequential(
436 | PositionalEncoding(inner_channel),
437 | nn.Linear(inner_channel, inner_channel * 4),
438 | Swish(),
439 | nn.Linear(inner_channel * 4, inner_channel)
440 | )
441 |
442 | num_mults = len(channel_mults)
443 | pre_channel = inner_channel
444 | feat_channels = [pre_channel]
445 | now_res = image_size
446 | downs = [nn.Conv2d(in_channel, inner_channel,
447 | kernel_size=3, padding=1)]
448 | for ind in range(num_mults):
449 | is_last = (ind == num_mults - 1)
450 | use_attn = (now_res in attn_res)
451 | channel_mult = inner_channel * channel_mults[ind]
452 | for _ in range(0, res_blocks):
453 | downs.append(BI_ResnetBlocWithAttn(
454 | pre_channel, channel_mult, timestep_level_emb_dim=timestep_level_channel, norm_groups=norm_groups, dropout=dropout, with_attn=use_attn, dynamic_group=dynamic_group))
455 | feat_channels.append(channel_mult)
456 | pre_channel = channel_mult
457 | if not is_last:
458 | downs.append(CP_Down(pre_channel, dynamic_group=dynamic_group))
459 | now_res = now_res//2
460 | pre_channel = pre_channel*2
461 | feat_channels.append(pre_channel)
462 | self.downs = nn.ModuleList(downs)
463 |
464 | self.mid = nn.ModuleList([
465 | BI_ResnetBlocWithAttn(pre_channel, pre_channel, timestep_level_emb_dim=timestep_level_channel, norm_groups=norm_groups,
466 | dropout=dropout, with_attn=True, dynamic_group=dynamic_group),
467 | BI_ResnetBlocWithAttn(pre_channel, pre_channel, timestep_level_emb_dim=timestep_level_channel, norm_groups=norm_groups,
468 | dropout=dropout, with_attn=False, dynamic_group=dynamic_group)
469 | ])
470 |
471 | ups = []
472 | for ind in reversed(range(num_mults)):
473 | is_last = (ind < 1)
474 | use_attn = (now_res in attn_res)
475 |             use_fp = (now_res in fp_res)
476 | channel_mult = inner_channel * channel_mults[ind]
477 | for _ in range(0, res_blocks+1):
478 | if use_fp:
479 | ups.append(CS_Fusion_FP(pre_channel+feat_channels.pop(), channel_mult, kernel_size=1, stride=1, padding=0, bias=False, groups=1))
480 | ups.append(ResnetBlocWithAttn(
481 | channel_mult, channel_mult, timestep_level_emb_dim=timestep_level_channel, norm_groups=norm_groups,
482 | dropout=dropout, with_attn=use_attn))
483 | else:
484 | ups.append(CS_Fusion(pre_channel+feat_channels.pop(), channel_mult, kernel_size=1, stride=1, padding=0, bias=False, groups=1, dynamic_group=dynamic_group))
485 | ups.append(BI_ResnetBlocWithAttn(
486 | channel_mult, channel_mult, timestep_level_emb_dim=timestep_level_channel, norm_groups=norm_groups,
487 | dropout=dropout, with_attn=use_attn, dynamic_group=dynamic_group))
488 | pre_channel = channel_mult
489 | if not is_last:
490 | ups.append(CP_Up(pre_channel, dynamic_group=dynamic_group))
491 | now_res = now_res*2
492 | pre_channel = pre_channel//2
493 |
494 | self.ups = nn.ModuleList(ups)
495 |
496 | self.final_conv = Block_F(pre_channel, default(out_channel, in_channel), groups=norm_groups)
497 |
498 | def forward(self, x, c, time):
499 |         index_dynamic = int(time[0][0] * self.dynamic_group / self.total_step)  # map the diffusion timestep to a TaR/TaA group index
500 |         index_dynamic = max(0, min(index_dynamic, self.dynamic_group - 1))  # clamp to [0, dynamic_group - 1]
501 |
502 | time = time.squeeze(1) # consistent with the original code
503 |
504 | if self.in_channel != 3:
505 | x = torch.cat([c, x], dim=1)
506 | t = self.timestep_level_mlp(time) if exists(
507 | self.timestep_level_mlp) else None
508 |
509 | feats = []
510 | for layer in self.downs:
511 | if isinstance(layer, BI_ResnetBlocWithAttn):
512 | x = layer(x, t, index_dynamic)
513 | elif isinstance(layer, ResnetBlocWithAttn):
514 | x = layer(x, t)
515 | elif isinstance(layer, CP_Down):
516 | x = layer(x, index_dynamic)
517 | else:
518 | x = layer(x)
519 | feats.append(x)
520 |
521 | for layer in self.mid:
522 | if isinstance(layer, BI_ResnetBlocWithAttn):
523 | x = layer(x, t, index_dynamic)
524 | else:
525 | x = layer(x)
526 |
527 | for layer in self.ups:
528 | if isinstance(layer, CS_Fusion):
529 | x = layer(torch.cat((x, feats.pop()), dim=1), index_dynamic)
530 | elif isinstance(layer, CS_Fusion_FP):
531 | x = layer(torch.cat((x, feats.pop()), dim=1))
532 | elif isinstance(layer, BI_ResnetBlocWithAttn):
533 | x = layer(x, t, index_dynamic)
534 | elif isinstance(layer, ResnetBlocWithAttn):
535 | x = layer(x, t)
536 | elif isinstance(layer, CP_Up):
537 | x = layer(x, index_dynamic)
538 | else:
539 | x = layer(x)
540 |
541 | return self.final_conv(x)
542 |
543 |
544 | if __name__ == '__main__':
545 | model = BIDiffSRUNet(
546 | in_channel = 6,
547 | out_channel = 3,
548 | inner_channel = 64,
549 | norm_groups = 16,
550 | channel_mults = [1, 2, 4, 8],
551 | attn_res = [],
552 | res_blocks = 2,
553 | dropout = 0.2,
554 | image_size = 256,
555 | fp_res= [256, 128],
556 | dynamic_group=5
557 | )
558 | print(model)
559 |
560 | x = torch.randn((2, 3, 128, 128))
561 | c = torch.randn((2, 3, 128, 128))
562 | timesteps = torch.randint(0, 10, (2,)).long().unsqueeze(1)
563 | x = model(x, c, timesteps)
564 | print(x.shape)
565 | print(sum(map(lambda x: x.numel(), model.parameters())))
--------------------------------------------------------------------------------
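To make the timestep-aware (TaR/TaA) switching above concrete: `BIDiffSRUNet.forward` maps the diffusion timestep onto one of `dynamic_group` parameter sets, so different denoising stages use different learnable biases and activations. A standalone sketch of that mapping, using the same arithmetic as the `index_dynamic` lines in `forward`:

```python
def timestep_to_group(t: int, total_step: int = 2000, dynamic_group: int = 5) -> int:
    # Partition [0, total_step) into dynamic_group equal bins, then clamp.
    idx = int(t * dynamic_group / total_step)
    return max(0, min(idx, dynamic_group - 1))

for t in (0, 399, 400, 1000, 1999):
    print(t, '->', timestep_to_group(t))  # 0->0, 399->0, 400->1, 1000->2, 1999->4
```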
/diffglv/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import numpy as np
3 | import random
4 | import torch
5 | import torch.utils.data
6 | from copy import deepcopy
7 | from functools import partial
8 | from os import path as osp
9 |
10 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader
11 | from basicsr.utils import get_root_logger, scandir
12 | from basicsr.utils.dist_util import get_dist_info
13 | from basicsr.utils.registry import DATASET_REGISTRY
14 |
15 | __all__ = ['build_dataset', 'build_dataloader', 'build_dataset_generate']
16 |
17 | # automatically scan and import dataset modules for registry
18 | # scan all the files under the data folder with '_dataset' in file names
19 | data_folder = osp.dirname(osp.abspath(__file__))
20 | dataset_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(data_folder) if v.endswith('_dataset.py')]
21 | # import all the dataset modules
22 | _dataset_modules = [importlib.import_module(f'diffglv.data.{file_name}') for file_name in dataset_filenames]
23 |
24 |
25 | def build_dataset_generate(dataset_opt, opt):
26 | """Build dataset from options.
27 |
28 | Args:
29 | dataset_opt (dict): Configuration for dataset. It must contain:
30 | name (str): Dataset name.
31 | type (str): Dataset type.
32 | """
33 | dataset_opt = deepcopy(dataset_opt)
34 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt, opt)
35 | logger = get_root_logger()
36 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
37 | return dataset
38 |
39 |
40 | def build_dataset(dataset_opt):
41 | """Build dataset from options.
42 |
43 | Args:
44 | dataset_opt (dict): Configuration for dataset. It must contain:
45 | name (str): Dataset name.
46 | type (str): Dataset type.
47 | """
48 | dataset_opt = deepcopy(dataset_opt)
49 | dataset = DATASET_REGISTRY.get(dataset_opt['type'])(dataset_opt)
50 | logger = get_root_logger()
51 | logger.info(f'Dataset [{dataset.__class__.__name__}] - {dataset_opt["name"]} is built.')
52 | return dataset
53 |
54 |
55 | def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None, seed=None):
56 | """Build dataloader.
57 |
58 | Args:
59 | dataset (torch.utils.data.Dataset): Dataset.
60 | dataset_opt (dict): Dataset options. It contains the following keys:
61 | phase (str): 'train' or 'val'.
62 | num_worker_per_gpu (int): Number of workers for each GPU.
63 | batch_size_per_gpu (int): Training batch size for each GPU.
64 | num_gpu (int): Number of GPUs. Used only in the train phase.
65 | Default: 1.
66 | dist (bool): Whether in distributed training. Used only in the train
67 | phase. Default: False.
68 | sampler (torch.utils.data.sampler): Data sampler. Default: None.
69 | seed (int | None): Seed. Default: None
70 | """
71 | phase = dataset_opt['phase']
72 | rank, _ = get_dist_info()
73 | if phase == 'train':
74 | if dist: # distributed training
75 | batch_size = dataset_opt['batch_size_per_gpu']
76 | num_workers = dataset_opt['num_worker_per_gpu']
77 | else: # non-distributed training
78 | multiplier = 1 if num_gpu == 0 else num_gpu
79 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
80 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
81 | dataloader_args = dict(
82 | dataset=dataset,
83 | batch_size=batch_size,
84 | shuffle=False,
85 | num_workers=num_workers,
86 | sampler=sampler,
87 | drop_last=True)
88 | if sampler is None:
89 | dataloader_args['shuffle'] = True
90 | dataloader_args['worker_init_fn'] = partial(
91 | worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
92 | elif phase in ['val', 'test']: # validation
93 | dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
94 | else:
95 | raise ValueError(f"Wrong dataset phase: {phase}. Supported ones are 'train', 'val' and 'test'.")
96 |
97 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
98 | dataloader_args['persistent_workers'] = dataset_opt.get('persistent_workers', False)
99 |
100 | prefetch_mode = dataset_opt.get('prefetch_mode')
101 | if prefetch_mode == 'cpu': # CPUPrefetcher
102 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
103 | logger = get_root_logger()
104 | logger.info(f'Use {prefetch_mode} prefetch dataloader: num_prefetch_queue = {num_prefetch_queue}')
105 | return PrefetchDataLoader(num_prefetch_queue=num_prefetch_queue, **dataloader_args)
106 | else:
107 | # prefetch_mode=None: Normal dataloader
108 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher
109 | return torch.utils.data.DataLoader(**dataloader_args)
110 |
111 |
112 | def worker_init_fn(worker_id, num_workers, rank, seed):
113 | # Set the worker seed to num_workers * rank + worker_id + seed
114 | worker_seed = num_workers * rank + worker_id + seed
115 | np.random.seed(worker_seed)
116 | random.seed(worker_seed)
117 |
--------------------------------------------------------------------------------
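For reference, a minimal usage sketch for the two builders above. The dataset `type` and folder paths follow this repo's `MultiPairedImageDataset` and `datasets/` layout, but the concrete values here are illustrative; the real ones come from the yml configs:

```python
from diffglv.data import build_dataset, build_dataloader

dataset_opt = {
    'name': 'DF2K', 'type': 'MultiPairedImageDataset', 'phase': 'train',
    'scale': 4,
    'dataroot_gt': 'datasets/DF2K/HR',
    'dataroot_lq': 'datasets/DF2K/LR_bicubic/X4',
    'io_backend': {'type': 'disk'},
    'gt_size': 256, 'use_hflip': True, 'use_rot': True,
    'batch_size_per_gpu': 4, 'num_worker_per_gpu': 4,
}
train_set = build_dataset(dataset_opt)
train_loader = build_dataloader(train_set, dataset_opt, num_gpu=1, dist=False, seed=10)
```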
/diffglv/data/data_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import torch
4 |
5 |
6 | def paired_center_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
7 | """Paired random crop. Support Numpy array and Tensor inputs.
8 |
9 | It crops lists of lq and gt images with corresponding locations.
10 |
11 | Args:
12 | img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
13 | should have the same shape. If the input is an ndarray, it will
14 | be transformed to a list containing itself.
15 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
16 | should have the same shape. If the input is an ndarray, it will
17 | be transformed to a list containing itself.
18 | gt_patch_size (int): GT patch size.
19 | scale (int): Scale factor.
20 | gt_path (str): Path to ground-truth. Default: None.
21 |
22 | Returns:
23 | list[ndarray] | ndarray: GT images and LQ images. If returned results
24 | only have one element, just return ndarray.
25 | """
26 |
27 | if not isinstance(img_gts, list):
28 | img_gts = [img_gts]
29 | if not isinstance(img_lqs, list):
30 | img_lqs = [img_lqs]
31 |
32 | # determine input type: Numpy array or Tensor
33 | input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
34 |
35 | if input_type == 'Tensor':
36 | h_lq, w_lq = img_lqs[0].size()[-2:]
37 | h_gt, w_gt = img_gts[0].size()[-2:]
38 | else:
39 | h_lq, w_lq = img_lqs[0].shape[0:2]
40 | h_gt, w_gt = img_gts[0].shape[0:2]
41 | lq_patch_size = gt_patch_size // scale
42 |
43 | if h_gt != h_lq * scale or w_gt != w_lq * scale:
44 |         raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
45 |                          f'the size of LQ ({h_lq}, {w_lq}).')
46 | if h_lq < lq_patch_size or w_lq < lq_patch_size:
47 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
48 | f'({lq_patch_size}, {lq_patch_size}). '
49 | f'Please remove {gt_path}.')
50 |
51 | top = (h_lq - lq_patch_size) // 2
52 |     left = (w_lq - lq_patch_size) // 2
53 |
54 | # crop lq patch
55 | if input_type == 'Tensor':
56 | img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
57 | else:
58 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
59 |
60 | # crop corresponding gt patch
61 | top_gt, left_gt = int(top * scale), int(left * scale)
62 | if input_type == 'Tensor':
63 | img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
64 | else:
65 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
66 | if len(img_gts) == 1:
67 | img_gts = img_gts[0]
68 | if len(img_lqs) == 1:
69 | img_lqs = img_lqs[0]
70 | return img_gts, img_lqs
71 |
72 |
73 | def random_crop(img_gts, gt_patch_size, gt_path=None):
74 | if not isinstance(img_gts, list):
75 | img_gts = [img_gts]
76 |
77 |
78 | # determine input type: Numpy array or Tensor
79 | input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
80 |
81 | if input_type == 'Tensor':
82 | h_gt, w_gt = img_gts[0].size()[-2:]
83 | else:
84 | h_gt, w_gt = img_gts[0].shape[0:2]
85 |
86 | if h_gt < gt_patch_size or w_gt < gt_patch_size:
87 |         raise ValueError(f'GT ({h_gt}, {w_gt}) is smaller than patch size '
88 | f'({gt_patch_size}, {gt_patch_size}). '
89 | f'Please remove {gt_path}.')
90 |
91 | top = random.randint(0, h_gt - gt_patch_size)
92 | left = random.randint(0, w_gt - gt_patch_size)
93 |
94 |
95 | if input_type == 'Tensor':
96 | img_gts = [v[:, :, top:top + gt_patch_size, left:left + gt_patch_size] for v in img_gts]
97 | else:
98 | img_gts = [v[top:top + gt_patch_size, left:left + gt_patch_size, ...] for v in img_gts]
99 | if len(img_gts) == 1:
100 | img_gts = img_gts[0]
101 |
102 | return img_gts
--------------------------------------------------------------------------------
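A quick usage sketch for `paired_center_crop` with NumPy inputs (array shapes chosen purely for illustration):

```python
import numpy as np
from diffglv.data.data_util import paired_center_crop

gt = np.zeros((512, 512, 3), dtype=np.float32)  # HR image, HWC
lq = np.zeros((128, 128, 3), dtype=np.float32)  # LR image at scale 4

gt_patch, lq_patch = paired_center_crop(gt, lq, gt_patch_size=256, scale=4)
print(gt_patch.shape, lq_patch.shape)  # (256, 256, 3) (64, 64, 3)
```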
/diffglv/data/paired_image_dataset.py:
--------------------------------------------------------------------------------
1 | from torch.utils import data as data
2 | from torchvision.transforms.functional import normalize
3 |
4 | from basicsr.data.data_util import paired_paths_from_folder, paired_paths_from_lmdb, paired_paths_from_meta_info_file
5 | from basicsr.data.transforms import augment
6 | from basicsr.utils import FileClient, imfrombytes, img2tensor
7 | # from basicsr.utils.matlab_functions import rgb2ycbcr
8 | from basicsr.utils.registry import DATASET_REGISTRY
9 | from diffglv.data.data_util import paired_center_crop
10 | from diffglv.utils.transforms import paired_random_crop
11 |
12 | import numpy as np
13 | import cv2
14 | import random
15 |
16 | @DATASET_REGISTRY.register()
17 | class MultiPairedImageDataset(data.Dataset):
18 | """Paired image dataset for image restoration.
19 |
20 | Read LQ (Low Quality, e.g. LR (Low Resolution), blurry, noisy, etc) and GT image pairs.
21 |
22 | There are three modes:
23 | 1. 'lmdb': Use lmdb files.
24 | If opt['io_backend'] == lmdb.
25 | 2. 'meta_info_file': Use meta information file to generate paths.
26 | If opt['io_backend'] != lmdb and opt['meta_info_file'] is not None.
27 | 3. 'folder': Scan folders to generate paths.
28 | The rest.
29 |
30 | Args:
31 | opt (dict): Config for train datasets. It contains the following keys:
32 | dataroot_gt (str): Data root path for gt.
33 | dataroot_lq (str): Data root path for lq.
34 | meta_info_file (str): Path for meta information file.
35 | io_backend (dict): IO backend type and other kwarg.
36 | filename_tmpl (str): Template for each filename. Note that the template excludes the file extension.
37 | Default: '{}'.
38 |             gt_size (int): Cropped patch size for gt patches.
39 | use_hflip (bool): Use horizontal flips.
40 | use_rot (bool): Use rotation (use vertical flip and transposing h and w for implementation).
41 |
42 | scale (bool): Scale, which will be added automatically.
43 | phase (str): 'train' or 'val'.
44 | """
45 |
46 | def __init__(self, opt):
47 | super(MultiPairedImageDataset, self).__init__()
48 | self.opt = opt
49 | # file client (io backend)
50 | self.file_client = None
51 | self.io_backend_opt = opt['io_backend']
52 | self.mean = opt['mean'] if 'mean' in opt else None
53 | self.std = opt['std'] if 'std' in opt else None
54 | self.task = opt['task'] if 'task' in opt else None
55 | self.noise = opt['noise'] if 'noise' in opt else 0
56 |
57 | self.gt_folder, self.lq_folder = opt['dataroot_gt'], opt['dataroot_lq']
58 |
59 | if 'filename_tmpl' in opt:
60 | self.filename_tmpl = opt['filename_tmpl']
61 | else:
62 | self.filename_tmpl = '{}'
63 |
64 | if self.io_backend_opt['type'] == 'lmdb':
65 | self.io_backend_opt['db_paths'] = [self.lq_folder, self.gt_folder]
66 | self.io_backend_opt['client_keys'] = ['lq', 'gt']
67 | self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
68 | elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
69 | self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
70 | self.opt['meta_info_file'], self.filename_tmpl)
71 | else:
72 | self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
73 |
74 | def __getitem__(self, index):
75 | if self.file_client is None:
76 | self.file_client = FileClient(self.io_backend_opt.pop('type'), **self.io_backend_opt)
77 |
78 | scale = self.opt['scale']
79 |
80 | # Load gt and lq images. Dimension order: HWC; channel order: BGR;
81 |
82 | if self.task == 'CAR':
83 |             # loaded as [0, 255] uint8 grayscale; converted below to [0, 1] float32, H W 1
84 |
85 | gt_path = self.paths[index]['gt_path']
86 | img_bytes = self.file_client.get(gt_path, 'gt')
87 | img_gt = imfrombytes(img_bytes, flag='grayscale', float32=False)
88 | lq_path = self.paths[index]['lq_path']
89 | img_bytes = self.file_client.get(lq_path, 'lq')
90 | img_lq = imfrombytes(img_bytes, flag='grayscale', float32=False)
91 | img_gt = np.expand_dims(img_gt, axis=2).astype(np.float32) / 255.
92 | img_lq = np.expand_dims(img_lq, axis=2).astype(np.float32) / 255.
93 |
94 | # gt_path = self.paths[index]['gt_path']
95 | # img_bytes = self.file_client.get(gt_path, 'gt')
96 | # img_gt = imfrombytes(img_bytes, float32=False)
97 | # lq_path = self.paths[index]['lq_path']
98 | # img_bytes = self.file_client.get(lq_path, 'lq')
99 | # img_lq = imfrombytes(img_bytes, float32=False)
100 | # img_gt = img_gt[:,:,0,None]
101 | # img_lq = img_lq[:,:,0,None]
102 |
103 | elif self.task == 'Color Denoising':
104 | gt_path = self.paths[index]['gt_path']
105 | lq_path = gt_path
106 | img_bytes = self.file_client.get(gt_path, 'gt')
107 | img_gt = imfrombytes(img_bytes, float32=True)
108 | if self.opt['phase'] != 'train':
109 | np.random.seed(seed=0)
110 | img_lq = img_gt + np.random.normal(0, self.noise/255., img_gt.shape)
111 |
112 | elif self.task == 'SR':
113 |             # image range: [0, 1], float32, H W 3
114 | gt_path = self.paths[index]['gt_path']
115 | img_bytes = self.file_client.get(gt_path, 'gt')
116 | img_gt = imfrombytes(img_bytes, float32=True)
117 | lq_path = self.paths[index]['lq_path']
118 | img_bytes = self.file_client.get(lq_path, 'lq')
119 | img_lq = imfrombytes(img_bytes, float32=True)
120 |             # bicubic pre-upsample: bring LQ to the GT resolution before cropping
121 | img_lq = cv2.resize(img_lq, (img_lq.shape[1]*scale, img_lq.shape[0]*scale), interpolation=cv2.INTER_CUBIC)
122 |
123 | else:
124 |             # image range: [0, 1], float32, H W 3
125 | gt_path = self.paths[index]['gt_path']
126 | img_bytes = self.file_client.get(gt_path, 'gt')
127 | img_gt = imfrombytes(img_bytes, float32=True)
128 | lq_path = self.paths[index]['lq_path']
129 | img_bytes = self.file_client.get(lq_path, 'lq')
130 | img_lq = imfrombytes(img_bytes, float32=True)
131 |
132 |         scale = 1  # LQ now matches the GT resolution, so the paired crops below use scale 1
133 | # augmentation for training
134 | if self.opt['phase'] == 'train':
135 | gt_size = self.opt['gt_size']
136 | # random crop
137 | img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_size, scale, gt_path)
138 | # flip, rotation
139 | img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_hflip'], self.opt['use_rot'])
140 |
141 | # color space transform
142 | if 'color' in self.opt and self.opt['color'] == 'y':
143 |             # Y-channel training is not supported yet; fail loudly instead of exiting silently
144 |             raise NotImplementedError('color == "y" is not implemented.')
145 | # img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None]
146 | # img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
147 | else:
148 | # for val set
149 | if 'gt_size' in self.opt:
150 | gt_size = self.opt['gt_size']
151 | img_gt, img_lq = paired_center_crop(img_gt, img_lq, gt_size, scale, gt_path)
152 |
153 | # crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
154 | # TODO: It is better to update the datasets, rather than force to crop
155 | if self.opt['phase'] != 'train':
156 | img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]
157 |
158 | # BGR to RGB, HWC to CHW, numpy to tensor
159 | img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
160 | # normalize
161 | if self.mean is not None or self.std is not None:
162 | normalize(img_lq, self.mean, self.std, inplace=True)
163 | normalize(img_gt, self.mean, self.std, inplace=True)
164 |
165 | return {'lq': img_lq, 'gt': img_gt, 'lq_path': lq_path, 'gt_path': gt_path, 'task': self.task}
166 |
167 | def __len__(self):
168 | return len(self.paths)
169 |
--------------------------------------------------------------------------------
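A usage sketch for the dataset above (paths and sizes are illustrative; the keys follow the class docstring, and 'disk' is the plain-folder io backend):

    dataset = MultiPairedImageDataset({
        'name': 'DF2K', 'phase': 'train', 'scale': 4, 'task': 'SR',
        'dataroot_gt': 'datasets/DF2K/HR_sub',
        'dataroot_lq': 'datasets/DF2K/LR_bicubic_X4_sub',
        'io_backend': {'type': 'disk'},
        'gt_size': 160, 'use_hflip': True, 'use_rot': True,
    })
    sample = dataset[0]  # {'lq': CHW tensor, 'gt': CHW tensor, 'lq_path': ..., 'gt_path': ..., 'task': 'SR'}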
/diffglv/losses/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from os import path as osp
4 |
5 | from basicsr.utils import get_root_logger, scandir
6 | from basicsr.utils.registry import LOSS_REGISTRY
7 | from basicsr.losses.gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty
8 |
9 | __all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize']
10 |
11 | # automatically scan and import loss modules for registry
12 | # scan all the files under the 'losses' folder and collect files ending with '_loss.py'
13 | loss_folder = osp.dirname(osp.abspath(__file__))
14 | loss_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(loss_folder) if v.endswith('_loss.py')]
15 | # import all the loss modules
16 | _loss_modules = [importlib.import_module(f'diffglv.losses.{file_name}') for file_name in loss_filenames]
17 |
18 |
19 | def build_loss(opt):
20 | """Build loss from options.
21 |
22 | Args:
23 | opt (dict): Configuration. It must contain:
24 |             type (str): Loss type.
25 | """
26 | opt = deepcopy(opt)
27 | loss_type = opt.pop('type')
28 | loss = LOSS_REGISTRY.get(loss_type)(**opt)
29 | logger = get_root_logger()
30 | logger.info(f'Loss [{loss.__class__.__name__}] is created.')
31 | return loss
32 |
--------------------------------------------------------------------------------
/diffglv/losses/at_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 | from basicsr.utils.registry import LOSS_REGISTRY
5 |
6 |
7 | @LOSS_REGISTRY.register()
8 | class ATLoss(nn.Module):
9 |
10 |     def __init__(self, loss_weight=1.0, reduction='mean'):
11 |         super(ATLoss, self).__init__()
12 |         self.loss_weight = loss_weight
13 |
14 | def forward(self, pred, target):
15 | """
16 | Args:
17 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
18 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
19 | """
20 | Attention_pred = F.normalize(pred.pow(2).mean(1).view(pred.size(0), -1)) # (N, H*W)
21 | Attention_target = F.normalize(target.pow(2).mean(1).view(target.size(0), -1)) # (N, H*W)
22 |
23 |         return self.loss_weight * nn.MSELoss()(Attention_pred, Attention_target)
24 |
--------------------------------------------------------------------------------
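A sketch of the registry path above: `build_loss` pops 'type' and forwards the remaining keys to the registered class's constructor (tensor shapes illustrative):

    import torch
    from diffglv.losses import build_loss

    criterion = build_loss({'type': 'ATLoss', 'loss_weight': 1.0})
    pred, target = torch.rand(2, 64, 32, 32), torch.rand(2, 64, 32, 32)
    loss = criterion(pred, target)  # MSE between channel-pooled, L2-normalized attention maps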
/diffglv/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from os import path as osp
4 |
5 | from basicsr.utils import scandir
6 | from basicsr.utils.registry import METRIC_REGISTRY
7 | from basicsr.metrics.niqe import calculate_niqe
8 | from basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim
9 |
10 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
11 |
12 | metric_folder = osp.dirname(osp.abspath(__file__))
13 | metric_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(metric_folder) if v.endswith('_metric.py')]
14 | _metric_modules = [importlib.import_module(f'diffglv.metrics.{file_name}') for file_name in metric_filenames]
15 |
16 |
17 | def calculate_metric(data, opt):
18 | """Calculate metric from data and options.
19 |
20 |     Args:
21 |         data (dict): Data passed to the metric as keyword arguments.
22 |         opt (dict): Configuration. It must contain 'type' (str), the metric type.
23 |     """
24 | opt = deepcopy(opt)
25 | metric_type = opt.pop('type')
26 | metric = METRIC_REGISTRY.get(metric_type)(**data, **opt)
27 | return metric
28 |
--------------------------------------------------------------------------------
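A usage sketch for `calculate_metric` (sr_img and gt_img are assumed HWC uint8 arrays; the kwargs mirror basicsr's `calculate_psnr` signature):

    from diffglv.metrics import calculate_metric

    psnr = calculate_metric({'img': sr_img, 'img2': gt_img},
                            {'type': 'calculate_psnr', 'crop_border': 4, 'test_y_channel': True})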
/diffglv/metrics/lpips.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import lpips
3 | from torch import nn
4 |
5 | def disabled_train(self: nn.Module) -> nn.Module:
6 | """Overwrite model.train with this function to make sure train/eval mode
7 | does not change anymore."""
8 | return self
9 |
10 | def frozen_module(module: nn.Module) -> None:
11 | module.eval()
12 | module.train = disabled_train
13 | for p in module.parameters():
14 | p.requires_grad = False
15 |
16 | class LPIPS:
17 | def __init__(self, net: str) -> None:
18 | self.model = lpips.LPIPS(net=net)
19 | frozen_module(self.model)
20 |
21 | @torch.no_grad()
22 |     def __call__(self, img1: torch.Tensor, img2: torch.Tensor, normalize: bool, boundarypixels: int = 0) -> torch.Tensor:
23 | """
24 | Compute LPIPS.
25 |
26 | Args:
27 |             img1 (torch.Tensor): The first image (NCHW, RGB, [-1, 1]). Set `normalize` if the input range is [0, 1].
28 |             img2 (torch.Tensor): The second image (NCHW, RGB, [-1, 1]). Set `normalize` if the input range is [0, 1].
29 |             normalize (bool): If True, the inputs are rescaled from [0, 1] to [-1, 1].
30 |             boundarypixels (int): If > 0, crop H and W down to multiples of this value before scoring.
31 | 
32 |
33 | Returns:
34 | lpips_values (torch.Tensor): The lpips scores of this batch.
35 | """
36 |
37 |         # crop H and W down to multiples of `boundarypixels`; skip when 0 to avoid modulo-by-zero
38 |         if boundarypixels > 0:
39 |             _, _, h, w = img1.shape
40 |             img1 = img1[:, :, :h - h % boundarypixels, :w - w % boundarypixels]
41 |             _, _, h, w = img2.shape
42 |             img2 = img2[:, :, :h - h % boundarypixels, :w - w % boundarypixels]
43 |
44 | return self.model(img1, img2, normalize=normalize)
45 |
46 | def to(self, device: str) -> "LPIPS":
47 | self.model.to(device)
48 | return self
49 |
--------------------------------------------------------------------------------
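A usage sketch for the LPIPS wrapper above (random tensors stand in for SR/HR batches in [0, 1]):

    import torch

    metric = LPIPS(net='alex').to('cuda')
    sr = torch.rand(1, 3, 256, 256, device='cuda')
    hr = torch.rand(1, 3, 256, 256, device='cuda')
    score = metric(sr, hr, normalize=True)  # lower is better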
/diffglv/models/BI_DiffSR_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from collections import OrderedDict
3 | from os import path as osp
4 | from tqdm import tqdm
5 |
6 | from basicsr.archs import build_network
7 | from basicsr.losses import build_loss
8 | from basicsr.metrics import calculate_metric
9 | from basicsr.utils import get_root_logger, imwrite, tensor2img, img2tensor
10 | from basicsr.utils.registry import MODEL_REGISTRY
11 | from diffglv.utils.base_model import BaseModel
12 | from torch.nn import functional as F
13 | import numpy as np
14 | from diffusers import DDPMScheduler, DDIMScheduler
15 |
16 | from diffglv.metrics.lpips import LPIPS
17 |
18 | @MODEL_REGISTRY.register()
19 | class BIDiffSRModel(BaseModel):
20 | """DiffIR model for stage two."""
21 |
22 | def __init__(self, opt):
23 | super(BIDiffSRModel, self).__init__(opt)
24 |
25 | # define network
26 | self.net_g = build_network(opt['network_g'])
27 | self.net_g = self.model_to_device(self.net_g)
28 | self.print_network(self.net_g)
29 |
30 | # load pretrained models
31 | load_path = self.opt['path'].get('pretrain_network_g', None)
32 | if load_path is not None:
33 | param_key = self.opt['path'].get('param_key_g', 'params')
34 | self.load_network(self.net_g, load_path, self.opt['path'].get('strict_load_g', True), param_key)
35 |
36 | # diffusion
37 | self.set_new_noise_schedule(self.opt['beta_schedule'], self.device)
38 |
39 |         # lpips
40 |         self.lpips_opt = self.opt['val']['metrics'].get('lpips', None)
41 |         if self.lpips_opt is not None:
42 | self.lpips_metric = LPIPS(net="alex").to(self.device)
43 |
44 | if self.is_train:
45 | self.init_training_settings()
46 |
47 | def init_training_settings(self):
48 | self.net_g.train()
49 | train_opt = self.opt['train']
50 |
51 | self.ema_decay = train_opt.get('ema_decay', 0)
52 | if self.ema_decay > 0:
53 | logger = get_root_logger()
54 | logger.info(f'Use Exponential Moving Average with decay: {self.ema_decay}')
55 | # define network net_g with Exponential Moving Average (EMA)
56 | # net_g_ema is used only for testing on one GPU and saving
57 | # There is no need to wrap with DistributedDataParallel
58 | self.net_g_ema = build_network(self.opt['network_g']).to(self.device)
59 | # load pretrained model
60 | load_path = self.opt['path'].get('pretrain_network_g', None)
61 | if load_path is not None:
62 | self.load_network(self.net_g_ema, load_path, self.opt['path'].get('strict_load_g', True), 'params_ema')
63 | else:
64 | self.model_ema(0) # copy net_g weight
65 | self.net_g_ema.eval()
66 |
67 | # define losses
68 | if train_opt.get('pixel_opt'):
69 | self.cri_pix = build_loss(train_opt['pixel_opt']).to(self.device)
70 | else:
71 | self.cri_pix = None
72 |
73 | if self.cri_pix is None:
74 | raise ValueError('pixel loss is None.')
75 |
76 | # set up optimizers and schedulers
77 | self.setup_optimizers()
78 | self.setup_schedulers()
79 |
80 | def setup_optimizers(self):
81 | train_opt = self.opt['train']
82 | optim_params = []
83 | for k, v in self.net_g.named_parameters():
84 | if v.requires_grad:
85 | optim_params.append(v)
86 | else:
87 | logger = get_root_logger()
88 | logger.warning(f'Network G: Params {k} will not be optimized.')
89 |
90 | optim_type = train_opt['optim_g'].pop('type')
91 | self.optimizer = self.get_optimizer(optim_type, optim_params, **train_opt['optim_g'])
92 | self.optimizers.append(self.optimizer)
93 |
94 | def set_new_noise_schedule(self, schedule_opt, device):
95 |         scheduler_opt = schedule_opt  # identical to self.opt['beta_schedule'] at the call site
96 | scheduler_type = scheduler_opt.get('scheduler_type', None)
97 | _prediction_type = scheduler_opt.get('prediction_type', None)
98 | if scheduler_type == 'DDPM':
99 | self.noise_scheduler = DDPMScheduler(num_train_timesteps=schedule_opt['n_timestep'],
100 | beta_start=schedule_opt['linear_start'],
101 | beta_end=schedule_opt['linear_end'],
102 | beta_schedule=schedule_opt['schedule'])
103 | elif scheduler_type == 'DDIM':
104 | self.noise_scheduler = DDIMScheduler(num_train_timesteps=schedule_opt['n_timestep'],
105 | beta_start=schedule_opt['linear_start'],
106 | beta_end=schedule_opt['linear_end'],
107 | beta_schedule=schedule_opt['schedule'])
108 | else:
109 | raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
110 |
111 | if _prediction_type is not None:
112 | # set prediction_type of scheduler if defined
113 | self.noise_scheduler.register_to_config(prediction_type=_prediction_type)
114 |
115 | def feed_data(self, data):
116 | self.lq = data['lq'].to(self.device)
117 | if 'gt' in data:
118 | self.gt = data['gt'].to(self.device)
119 |
120 | def optimize_parameters(self, current_iter, noise=None):
121 | self.optimizer.zero_grad()
122 |
123 | noise = torch.randn_like(self.gt).to(self.device)
124 | bsz = self.gt.shape[0]
125 |         # Sample one random timestep and share it across the batch
126 | random_timestep = torch.randint(0, self.noise_scheduler.config.num_train_timesteps, (1,), device=self.device)
127 | timesteps = random_timestep.repeat(bsz).long()
128 |
129 | # Add noise to the latents according to the noise magnitude at each timestep
130 | noisy_image = self.noise_scheduler.add_noise(self.gt, noise, timesteps)
131 |
132 | # Get the target for loss depending on the prediction type
133 | if self.noise_scheduler.config.prediction_type == "epsilon":
134 | target = noise
135 | elif self.noise_scheduler.config.prediction_type == "v_prediction":
136 | target = self.noise_scheduler.get_velocity(self.gt, noise, timesteps)
137 | elif self.noise_scheduler.config.prediction_type == "sample":
138 | target = self.gt
139 | else:
140 | raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}")
141 |
142 | # Predict the noise residual and compute loss
143 | _timesteps = timesteps.unsqueeze(1).to(self.device)
144 | noise_pred = self.net_g(noisy_image, self.lq, _timesteps)
145 | l_total = 0
146 | loss_dict = OrderedDict()
147 |
148 | if self.cri_pix:
149 | l_pix = self.cri_pix(noise_pred, target)
150 | l_total += l_pix
151 | loss_dict['l_pix'] = l_pix
152 |
153 | l_total.backward()
154 | self.optimizer.step()
155 |
156 | self.log_dict = self.reduce_loss_dict(loss_dict)
157 |
158 | if self.ema_decay > 0:
159 | self.model_ema(decay=self.ema_decay)
160 |
161 | def test(self):
162 |         scale = 1  # LQ is pre-upsampled to the GT resolution by the dataset
163 |         window_size = 8  # pad H and W up to multiples of 8 before sampling
164 | mod_pad_h, mod_pad_w = 0, 0
165 | _, _, h, w = self.lq.size()
166 | if h % window_size != 0:
167 | mod_pad_h = window_size - h % window_size
168 | if w % window_size != 0:
169 | mod_pad_w = window_size - w % window_size
170 | img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
171 |
172 |         if hasattr(self, 'net_g_ema'):
173 |             print("TODO")  # TODO: sampling with the EMA network is not implemented yet
174 | else:
175 | self.net_g.eval()
176 |
177 | is_guidance = self.opt['beta_schedule'].get('is_guidance', False)
178 |
179 | if not is_guidance:
180 | # original conditional
181 | latents = torch.randn_like(img).to(self.device)
182 |
183 | self.noise_scheduler.set_timesteps(self.opt['beta_schedule']['num_inference_steps'])
184 |
185 | for t in self.noise_scheduler.timesteps:
186 |             # no classifier-free guidance is applied here, so the latents are used as-is
187 |             latent_model_input = latents
188 | lq_image = img
189 | _t = t.unsqueeze(0).unsqueeze(1).to(self.device)
190 |
191 | latent_model_input = self.noise_scheduler.scale_model_input(latent_model_input, timestep=t)
192 |
193 | # predict the noise residual
194 | with torch.no_grad():
195 | noise_pred = self.net_g(latent_model_input, lq_image, _t)
196 |
197 | # compute the previous noisy sample x_t -> x_t-1
198 | latents = self.noise_scheduler.step(noise_pred, t, latents).prev_sample
199 |         else:
200 |             # classifier-free guidance is not implemented; fail loudly instead of leaving `latents` undefined
201 |             raise NotImplementedError('Classifier-free guidance sampling is not implemented yet.')
202 |
203 | self.output = latents
204 | _, _, h, w = self.output.size()
205 | self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
206 |
207 | self.net_g.train()
208 |
209 | def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
210 | if self.opt['rank'] == 0:
211 | self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
212 |
213 | def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
214 | dataset_name = dataloader.dataset.opt['name']
215 | with_metrics = self.opt['val'].get('metrics') is not None
216 | use_pbar = self.opt['val'].get('pbar', False)
217 |
218 | if with_metrics:
219 | if not hasattr(self, 'metric_results'): # only execute in the first run
220 | self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
221 | # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
222 | self._initialize_best_metric_results(dataset_name)
223 | # zero self.metric_results
224 | if with_metrics:
225 | self.metric_results = {metric: 0 for metric in self.metric_results}
226 |
227 | metric_data = dict()
228 | if use_pbar:
229 | pbar = tqdm(total=len(dataloader), unit='image')
230 |
231 | for idx, val_data in enumerate(dataloader):
232 | img_name = osp.splitext(osp.basename(val_data['lq_path'][0]))[0]
233 | self.feed_data(val_data)
234 | self.test()
235 |
236 | visuals = self.get_current_visuals()
237 | sr_img = tensor2img([visuals['result']])
238 | metric_data['img'] = sr_img
239 | if 'gt' in visuals:
240 | gt_img = tensor2img([visuals['gt']])
241 | metric_data['img2'] = gt_img
242 | del self.gt
243 |
244 | # tentative for out of GPU memory
245 | del self.lq
246 | del self.output
247 | torch.cuda.empty_cache()
248 |
249 | if save_img:
250 | if self.opt['is_train']:
251 | save_img_path = osp.join(self.opt['path']['visualization'], img_name,
252 | f'{img_name}_{current_iter}.png')
253 | else:
254 | if self.opt['val']['suffix']:
255 | save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
256 | f'{img_name}_{self.opt["val"]["suffix"]}.png')
257 | else:
258 | save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
259 | f'{img_name}_{self.opt["name"]}.png')
260 |
261 | imwrite(sr_img, save_img_path)
262 |
263 | if with_metrics:
264 | # calculate metrics
265 | for name, opt_ in self.opt['val']['metrics'].items():
266 | if name == 'lpips': continue
267 | self.metric_results[name] += calculate_metric(metric_data, opt_)
268 |                 if self.lpips_opt is not None:
269 | sr_img = (img2tensor(metric_data['img']) / 255.0).unsqueeze(0).to(self.device)
270 | hq_img = (img2tensor(metric_data['img2']) / 255.0).unsqueeze(0).to(self.device)
271 | self.metric_results['lpips'] += self.lpips_metric(sr_img, hq_img, normalize=True, boundarypixels=self.lpips_opt['crop_border']).item()
272 | if use_pbar:
273 | pbar.update(1)
274 | pbar.set_description(f'Test {img_name}')
275 | if use_pbar:
276 | pbar.close()
277 |
278 | if with_metrics:
279 | for metric in self.metric_results.keys():
280 | self.metric_results[metric] /= (idx + 1)
281 | # update the best metric result
282 | self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
283 |
284 | self._log_validation_metric_values(current_iter, dataset_name, tb_logger)
285 |
286 | def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
287 | log_str = f'Validation {dataset_name}\n'
288 | for metric, value in self.metric_results.items():
289 | log_str += f'\t # {metric}: {value:.4f}'
290 | if hasattr(self, 'best_metric_results'):
291 | log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
292 | f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
293 | log_str += '\n'
294 |
295 | logger = get_root_logger()
296 | logger.info(log_str)
297 | if tb_logger:
298 | for metric, value in self.metric_results.items():
299 | tb_logger.add_scalar(f'metrics/{dataset_name}/{metric}', value, current_iter)
300 |
301 | def get_current_visuals(self):
302 | out_dict = OrderedDict()
303 | out_dict['lq'] = self.lq.detach().cpu()
304 | out_dict['result'] = self.output.detach().cpu()
305 | if hasattr(self, 'gt'):
306 | out_dict['gt'] = self.gt.detach().cpu()
307 | return out_dict
308 |
309 | def save(self, epoch, current_iter):
310 | if hasattr(self, 'net_g_ema'):
311 | self.save_network([self.net_g, self.net_g_ema], 'net_g', current_iter, param_key=['params', 'params_ema'])
312 | else:
313 | self.save_network(self.net_g, 'net_g', current_iter)
314 | self.save_training_state(epoch, current_iter)
315 |
--------------------------------------------------------------------------------
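A condensed sketch of the train/test logic above, stripped of the BasicSR plumbing; `net_g` (the conditional UNet) is mocked with random outputs, and sizes and step counts are illustrative:

    import torch
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000, beta_start=1e-4,
                              beta_end=2e-2, beta_schedule='linear')

    # training step: epsilon prediction at one shared random timestep
    gt = torch.rand(4, 3, 128, 128)             # HR batch (x_0)
    lq = torch.rand(4, 3, 128, 128)             # pre-upsampled LR condition
    noise = torch.randn_like(gt)
    t = torch.randint(0, 1000, (1,)).repeat(4)
    noisy = scheduler.add_noise(gt, noise, t)   # x_t
    # l_pix = cri_pix(net_g(noisy, lq, t.unsqueeze(1)), noise)

    # sampling: iterate x_t -> x_{t-1} from pure noise
    scheduler.set_timesteps(50)
    latents = torch.randn_like(lq)
    for step in scheduler.timesteps:
        eps = torch.randn_like(latents)         # stands in for net_g(latents, lq, ...)
        latents = scheduler.step(eps, step, latents).prev_sample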
/diffglv/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from copy import deepcopy
3 | from os import path as osp
4 |
5 | from basicsr.utils import get_root_logger, scandir
6 | from basicsr.utils.registry import MODEL_REGISTRY
7 |
8 | __all__ = ['build_model']
9 |
10 | # automatically scan and import model modules for registry
11 | # scan all the files under the 'models' folder and collect files ending with '_model.py'
12 | model_folder = osp.dirname(osp.abspath(__file__))
13 | model_filenames = [osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) if v.endswith('_model.py')]
14 | # import all the model modules
15 | _model_modules = [importlib.import_module(f'diffglv.models.{file_name}') for file_name in model_filenames]
16 |
17 |
18 | def build_model(opt):
19 | """Build model from options.
20 |
21 | Args:
22 | opt (dict): Configuration. It must contain:
23 | model_type (str): Model type.
24 | """
25 | opt = deepcopy(opt)
26 | model = MODEL_REGISTRY.get(opt['model_type'])(opt)
27 | logger = get_root_logger()
28 | logger.info(f'Model [{model.__class__.__name__}] is created.')
29 | return model
30 |
--------------------------------------------------------------------------------
/diffglv/utils/GPU_memory.py:
--------------------------------------------------------------------------------
1 | import pynvml
2 | pynvml.nvmlInit()
3 | gpuDeviceCount = pynvml.nvmlDeviceGetCount()
4 | UNIT = 1024 * 1024  # bytes per MiB
5 | for i in range(gpuDeviceCount):
6 |     handle = pynvml.nvmlDeviceGetHandleByIndex(i)  # get the handle of GPU i; later queries go through it
7 |
8 |     memoryInfo = pynvml.nvmlDeviceGetMemoryInfo(handle)  # query memory info of GPU i via the handle
9 |
10 | m_total = memoryInfo.total/UNIT
11 | m_used = memoryInfo.used/UNIT
12 | print('[%s][%s/%s]' % (i, m_used, m_total))
13 | pynvml.nvmlShutdown()
--------------------------------------------------------------------------------
/diffglv/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/diffglv/utils/__init__.py
--------------------------------------------------------------------------------
/diffglv/utils/base_model.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | from collections import OrderedDict
5 | from copy import deepcopy
6 | from torch.nn.parallel import DataParallel, DistributedDataParallel
7 |
8 | from diffglv.utils import lr_scheduler as lr_scheduler
9 | from basicsr.utils import get_root_logger
10 | from basicsr.utils.dist_util import master_only
11 | import torch.nn as nn
12 |
13 | class BaseModel(nn.Module):
14 | """Base model."""
15 |
16 | def __init__(self, opt):
17 | super().__init__()
18 | self.opt = opt
19 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
20 | self.is_train = opt['is_train']
21 | self.schedulers = []
22 | self.optimizers = []
23 |
24 | def feed_data(self, data):
25 | pass
26 |
27 | def optimize_parameters(self):
28 | pass
29 |
30 | def get_current_visuals(self):
31 | pass
32 |
33 | def save(self, epoch, current_iter):
34 | """Save networks and training state."""
35 | pass
36 |
37 | def validation(self, dataloader, current_iter, tb_logger, save_img=False):
38 | """Validation function.
39 |
40 | Args:
41 | dataloader (torch.utils.data.DataLoader): Validation dataloader.
42 | current_iter (int): Current iteration.
43 | tb_logger (tensorboard logger): Tensorboard logger.
44 | save_img (bool): Whether to save images. Default: False.
45 | """
46 | if self.opt['dist']:
47 | self.dist_validation(dataloader, current_iter, tb_logger, save_img)
48 | else:
49 | self.nondist_validation(dataloader, current_iter, tb_logger, save_img)
50 |
51 | def _initialize_best_metric_results(self, dataset_name):
52 | """Initialize the best metric results dict for recording the best metric value and iteration."""
53 | if hasattr(self, 'best_metric_results') and dataset_name in self.best_metric_results:
54 | return
55 | elif not hasattr(self, 'best_metric_results'):
56 | self.best_metric_results = dict()
57 |
58 | # add a dataset record
59 | record = dict()
60 | for metric, content in self.opt['val']['metrics'].items():
61 | better = content.get('better', 'higher')
62 | init_val = float('-inf') if better == 'higher' else float('inf')
63 | record[metric] = dict(better=better, val=init_val, iter=-1)
64 | self.best_metric_results[dataset_name] = record
65 |
66 | def _update_best_metric_result(self, dataset_name, metric, val, current_iter):
67 | if self.best_metric_results[dataset_name][metric]['better'] == 'higher':
68 | if val >= self.best_metric_results[dataset_name][metric]['val']:
69 | self.best_metric_results[dataset_name][metric]['val'] = val
70 | self.best_metric_results[dataset_name][metric]['iter'] = current_iter
71 | else:
72 | if val <= self.best_metric_results[dataset_name][metric]['val']:
73 | self.best_metric_results[dataset_name][metric]['val'] = val
74 | self.best_metric_results[dataset_name][metric]['iter'] = current_iter
75 |
76 | def model_ema(self, decay=0.999):
77 | net_g = self.get_bare_model(self.net_g)
78 |
79 | net_g_params = dict(net_g.named_parameters())
80 | net_g_ema_params = dict(self.net_g_ema.named_parameters())
81 |
82 | for k in net_g_ema_params.keys():
83 | net_g_ema_params[k].data.mul_(decay).add_(net_g_params[k].data, alpha=1 - decay)
84 |
85 | def get_current_log(self):
86 | return self.log_dict
87 |
88 | def model_to_device(self, net):
89 | """Model to device. It also warps models with DistributedDataParallel
90 | or DataParallel.
91 |
92 | Args:
93 | net (nn.Module)
94 | """
95 | net = net.to(self.device)
96 | if self.opt['dist']:
97 | find_unused_parameters = self.opt.get('find_unused_parameters', False)
98 | net = DistributedDataParallel(
99 | net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
100 | elif self.opt['num_gpu'] > 1:
101 | net = DataParallel(net)
102 | return net
103 |
104 | def get_optimizer(self, optim_type, params, lr, **kwargs):
105 | if optim_type == 'Adam':
106 | optimizer = torch.optim.Adam(params, lr, **kwargs)
107 | else:
108 |             raise NotImplementedError(f'optimizer {optim_type} is not supported yet.')
109 | return optimizer
110 |
111 | def setup_schedulers(self):
112 | """Set up schedulers."""
113 | train_opt = self.opt['train']
114 | scheduler_type = train_opt['scheduler'].pop('type')
115 | if scheduler_type in ['MultiStepLR', 'MultiStepRestartLR']:
116 | for optimizer in self.optimizers:
117 | self.schedulers.append(lr_scheduler.MultiStepRestartLR(optimizer, **train_opt['scheduler']))
118 | elif scheduler_type == 'CosineAnnealingRestartLR':
119 | for optimizer in self.optimizers:
120 | self.schedulers.append(lr_scheduler.CosineAnnealingRestartLR(optimizer, **train_opt['scheduler']))
121 | elif scheduler_type == 'CosineAnnealingRestartCyclicLR':
122 | for optimizer in self.optimizers:
123 | self.schedulers.append(lr_scheduler.CosineAnnealingRestartCyclicLR(optimizer, **train_opt['scheduler']))
124 | else:
125 | raise NotImplementedError(f'Scheduler {scheduler_type} is not implemented yet.')
126 |
127 | def get_bare_model(self, net):
128 | """Get bare model, especially under wrapping with
129 | DistributedDataParallel or DataParallel.
130 | """
131 | if isinstance(net, (DataParallel, DistributedDataParallel)):
132 | net = net.module
133 | return net
134 |
135 | @master_only
136 | def print_network(self, net):
137 | """Print the str and parameter number of a network.
138 |
139 | Args:
140 | net (nn.Module)
141 | """
142 | if isinstance(net, (DataParallel, DistributedDataParallel)):
143 | net_cls_str = f'{net.__class__.__name__} - {net.module.__class__.__name__}'
144 | else:
145 | net_cls_str = f'{net.__class__.__name__}'
146 |
147 | net = self.get_bare_model(net)
148 | net_str = str(net)
149 | net_params = sum(map(lambda x: x.numel(), net.parameters()))
150 |
151 | logger = get_root_logger()
152 | logger.info(f'Network: {net_cls_str}, with parameters: {net_params:,d}')
153 | logger.info(net_str)
154 |
155 | def _set_lr(self, lr_groups_l):
156 | """Set learning rate for warmup.
157 |
158 | Args:
159 | lr_groups_l (list): List for lr_groups, each for an optimizer.
160 | """
161 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
162 | for param_group, lr in zip(optimizer.param_groups, lr_groups):
163 | param_group['lr'] = lr
164 |
165 | def _get_init_lr(self):
166 | """Get the initial lr, which is set by the scheduler.
167 | """
168 | init_lr_groups_l = []
169 | for optimizer in self.optimizers:
170 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
171 | return init_lr_groups_l
172 |
173 | def update_learning_rate(self, current_iter, warmup_iter=-1):
174 | """Update learning rate.
175 |
176 | Args:
177 | current_iter (int): Current iteration.
178 | warmup_iter (int): Warmup iter numbers. -1 for no warmup.
179 | Default: -1.
180 | """
181 | if current_iter > 1:
182 | for scheduler in self.schedulers:
183 | scheduler.step()
184 | # set up warm-up learning rate
185 | if current_iter < warmup_iter:
186 | # get initial lr for each group
187 | init_lr_g_l = self._get_init_lr()
188 | # modify warming-up learning rates
189 | # currently only support linearly warm up
190 | warm_up_lr_l = []
191 | for init_lr_g in init_lr_g_l:
192 | warm_up_lr_l.append([v / warmup_iter * current_iter for v in init_lr_g])
193 | # set learning rate
194 | self._set_lr(warm_up_lr_l)
195 |
196 | def get_current_learning_rate(self):
197 | return [param_group['lr'] for param_group in self.optimizers[0].param_groups]
198 |
199 | @master_only
200 | def save_network(self, net, net_label, current_iter, param_key='params'):
201 | """Save networks.
202 |
203 | Args:
204 | net (nn.Module | list[nn.Module]): Network(s) to be saved.
205 | net_label (str): Network label.
206 | current_iter (int): Current iter number.
207 | param_key (str | list[str]): The parameter key(s) to save network.
208 | Default: 'params'.
209 | """
210 | if current_iter == -1:
211 | current_iter = 'latest'
212 | save_filename = f'{net_label}_{current_iter}.pth'
213 | save_path = os.path.join(self.opt['path']['models'], save_filename)
214 |
215 | net = net if isinstance(net, list) else [net]
216 | param_key = param_key if isinstance(param_key, list) else [param_key]
217 | assert len(net) == len(param_key), 'The lengths of net and param_key should be the same.'
218 |
219 | save_dict = {}
220 | for net_, param_key_ in zip(net, param_key):
221 | net_ = self.get_bare_model(net_)
222 | state_dict = net_.state_dict()
223 |             for key, param in list(state_dict.items()):  # iterate over a copy; keys may be added below
224 | if key.startswith('module.'): # remove unnecessary 'module.'
225 | key = key[7:]
226 | state_dict[key] = param.cpu()
227 | save_dict[param_key_] = state_dict
228 |
229 | # avoid occasional writing errors
230 | retry = 3
231 | while retry > 0:
232 | try:
233 | torch.save(save_dict, save_path)
234 | except Exception as e:
235 | logger = get_root_logger()
236 | logger.warning(f'Save model error: {e}, remaining retry times: {retry - 1}')
237 | time.sleep(1)
238 | else:
239 | break
240 | finally:
241 | retry -= 1
242 | if retry == 0:
243 | logger.warning(f'Still cannot save {save_path}. Just ignore it.')
244 | # raise IOError(f'Cannot save {save_path}.')
245 |
246 | def _print_different_keys_loading(self, crt_net, load_net, strict=True):
247 | """Print keys with different name or different size when loading models.
248 |
249 | 1. Print keys with different names.
250 | 2. If strict=False, print the same key but with different tensor size.
251 |             It also ignores keys with different sizes (they are not loaded).
252 |
253 | Args:
254 | crt_net (torch model): Current network.
255 | load_net (dict): Loaded network.
256 | strict (bool): Whether strictly loaded. Default: True.
257 | """
258 | crt_net = self.get_bare_model(crt_net)
259 | crt_net = crt_net.state_dict()
260 | crt_net_keys = set(crt_net.keys())
261 | load_net_keys = set(load_net.keys())
262 |
263 | logger = get_root_logger()
264 | if crt_net_keys != load_net_keys:
265 | logger.warning('Current net - loaded net:')
266 | for v in sorted(list(crt_net_keys - load_net_keys)):
267 | logger.warning(f' {v}')
268 | logger.warning('Loaded net - current net:')
269 | for v in sorted(list(load_net_keys - crt_net_keys)):
270 | logger.warning(f' {v}')
271 |
272 | # check the size for the same keys
273 | if not strict:
274 | common_keys = crt_net_keys & load_net_keys
275 | for k in common_keys:
276 | if crt_net[k].size() != load_net[k].size():
277 | logger.warning(f'Size different, ignore [{k}]: crt_net: '
278 | f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
279 | load_net[k + '.ignore'] = load_net.pop(k)
280 |
281 | def load_network(self, net, load_path, strict=True, param_key='params'):
282 | """Load network.
283 |
284 | Args:
285 | load_path (str): The path of networks to be loaded.
286 | net (nn.Module): Network.
287 | strict (bool): Whether strictly loaded.
288 | param_key (str): The parameter key of loaded network. If set to
289 | None, use the root 'path'.
290 | Default: 'params'.
291 | """
292 | logger = get_root_logger()
293 | net = self.get_bare_model(net)
294 | load_net = torch.load(load_path, map_location=lambda storage, loc: storage)
295 | if param_key is not None:
296 |             if param_key not in load_net and 'params' in load_net:
297 |                 logger.info(f'Loading: "{param_key}" does not exist, use "params" instead.')
298 |                 param_key = 'params'
299 | load_net = load_net[param_key]
300 | logger.info(f'Loading {net.__class__.__name__} model from {load_path}, with param key: [{param_key}].')
301 | # remove unnecessary 'module.'
302 | for k, v in deepcopy(load_net).items():
303 | if k.startswith('module.'):
304 | load_net[k[7:]] = v
305 | load_net.pop(k)
306 | self._print_different_keys_loading(net, load_net, strict)
307 | net.load_state_dict(load_net, strict=strict)
308 |
309 | @master_only
310 | def save_training_state(self, epoch, current_iter):
311 | """Save training states during training, which will be used for
312 | resuming.
313 |
314 | Args:
315 | epoch (int): Current epoch.
316 | current_iter (int): Current iteration.
317 | """
318 | if current_iter != -1:
319 | state = {'epoch': epoch, 'iter': current_iter, 'optimizers': [], 'schedulers': []}
320 | for o in self.optimizers:
321 | state['optimizers'].append(o.state_dict())
322 | for s in self.schedulers:
323 | state['schedulers'].append(s.state_dict())
324 | save_filename = f'{current_iter}.state'
325 | save_path = os.path.join(self.opt['path']['training_states'], save_filename)
326 |
327 | # avoid occasional writing errors
328 | retry = 3
329 | while retry > 0:
330 | try:
331 | torch.save(state, save_path)
332 | except Exception as e:
333 | logger = get_root_logger()
334 | logger.warning(f'Save training state error: {e}, remaining retry times: {retry - 1}')
335 | time.sleep(1)
336 | else:
337 | break
338 | finally:
339 | retry -= 1
340 | if retry == 0:
341 | logger.warning(f'Still cannot save {save_path}. Just ignore it.')
342 | # raise IOError(f'Cannot save {save_path}.')
343 |
344 | def resume_training(self, resume_state):
345 | """Reload the optimizers and schedulers for resumed training.
346 |
347 | Args:
348 | resume_state (dict): Resume state.
349 | """
350 | resume_optimizers = resume_state['optimizers']
351 | resume_schedulers = resume_state['schedulers']
352 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
353 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
354 | for i, o in enumerate(resume_optimizers):
355 | self.optimizers[i].load_state_dict(o)
356 | for i, s in enumerate(resume_schedulers):
357 | self.schedulers[i].load_state_dict(s)
358 |
359 | def reduce_loss_dict(self, loss_dict):
360 | """reduce loss dict.
361 |
362 | In distributed training, it averages the losses among different GPUs .
363 |
364 | Args:
365 | loss_dict (OrderedDict): Loss dict.
366 | """
367 | with torch.no_grad():
368 | if self.opt['dist']:
369 | keys = []
370 | losses = []
371 | for name, value in loss_dict.items():
372 | keys.append(name)
373 | losses.append(value)
374 | losses = torch.stack(losses, 0)
375 | torch.distributed.reduce(losses, dst=0)
376 | if self.opt['rank'] == 0:
377 | losses /= self.opt['world_size']
378 | loss_dict = {key: loss for key, loss in zip(keys, losses)}
379 |
380 | log_dict = OrderedDict()
381 | for name, value in loss_dict.items():
382 | log_dict[name] = value.mean().item()
383 |
384 | return log_dict
--------------------------------------------------------------------------------
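The EMA rule used by `model_ema` above, shown on a single tensor (decay value illustrative):

    import torch

    decay = 0.999
    ema_w, w = torch.zeros(3), torch.ones(3)
    ema_w.mul_(decay).add_(w, alpha=1 - decay)  # ema <- decay * ema + (1 - decay) * w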
/diffglv/utils/beta_schedule.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import device, nn, einsum
4 | import torch.nn.functional as F
5 | from inspect import isfunction
6 | from functools import partial
7 | import numpy as np
8 | from tqdm import tqdm
9 |
10 |
11 | def _warmup_beta(linear_start, linear_end, n_timestep, warmup_frac):
12 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
13 | warmup_time = int(n_timestep * warmup_frac)
14 | betas[:warmup_time] = np.linspace(
15 | linear_start, linear_end, warmup_time, dtype=np.float64)
16 | return betas
17 |
18 |
19 | def make_beta_schedule(schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3):
20 | linear_start = float(linear_start)
21 | linear_end = float(linear_end)
22 | if schedule == 'quad':
23 | betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5,
24 | n_timestep, dtype=np.float64) ** 2
25 | elif schedule == 'linear':
26 | betas = np.linspace(linear_start, linear_end,
27 | n_timestep, dtype=np.float64)
28 | elif schedule == 'warmup10':
29 | betas = _warmup_beta(linear_start, linear_end,
30 | n_timestep, 0.1)
31 | elif schedule == 'warmup50':
32 | betas = _warmup_beta(linear_start, linear_end,
33 | n_timestep, 0.5)
34 | elif schedule == 'const':
35 | betas = linear_end * np.ones(n_timestep, dtype=np.float64)
36 | elif schedule == 'jsd': # 1/T, 1/(T-1), 1/(T-2), ..., 1
37 | betas = 1. / np.linspace(n_timestep,
38 | 1, n_timestep, dtype=np.float64)
39 | elif schedule == "cosine":
40 | timesteps = (
41 | torch.arange(n_timestep + 1, dtype=torch.float64) /
42 | n_timestep + cosine_s
43 | )
44 | alphas = timesteps / (1 + cosine_s) * math.pi / 2
45 | alphas = torch.cos(alphas).pow(2)
46 | alphas = alphas / alphas[0]
47 | betas = 1 - alphas[1:] / alphas[:-1]
48 | betas = betas.clamp(max=0.999)
49 | else:
50 | raise NotImplementedError(schedule)
51 | return betas
52 |
53 | def exists(x):
54 | return x is not None
55 |
56 |
57 | def default(val, d):
58 | if exists(val):
59 | return val
60 | return d() if isfunction(d) else d
61 |
--------------------------------------------------------------------------------
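A usage sketch: build a linear schedule and derive the cumulative-alpha terms used by DDPM-style samplers:

    import numpy as np
    from diffglv.utils.beta_schedule import make_beta_schedule

    betas = make_beta_schedule('linear', n_timestep=1000)  # shape (1000,)
    alphas_cumprod = np.cumprod(1.0 - betas)               # alpha_bar_t
    # forward process: q(x_t | x_0) = N(sqrt(alpha_bar_t) * x_0, (1 - alpha_bar_t) * I)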
/diffglv/utils/extract_subimages.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | import sys
5 | from multiprocessing import Pool
6 | from os import path as osp
7 | from tqdm import tqdm
8 |
9 | from basicsr.utils import scandir
10 |
11 |
12 | def main():
13 | """A multi-thread tool to crop large images to sub-images for faster IO.
14 | It is used for DIV2K dataset.
15 | Args:
16 | opt (dict): Configuration dict. It contains:
17 | n_thread (int): Thread number.
18 | compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and
19 | longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2.
20 | input_folder (str): Path to the input folder.
21 | save_folder (str): Path to save folder.
22 | crop_size (int): Crop size.
23 | step (int): Step for overlapped sliding window.
24 | thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
25 | Usage:
26 | For each folder, run this script.
27 | Typically, there are four folders to be processed for DIV2K dataset.
28 | * DIV2K_train_HR
29 | * DIV2K_train_LR_bicubic/X2
30 | * DIV2K_train_LR_bicubic/X3
31 | * DIV2K_train_LR_bicubic/X4
32 | After process, each sub_folder should have the same number of subimages.
33 | Remember to modify opt configurations according to your settings.
34 | """
35 |
36 | opt = {}
37 | opt['n_thread'] = 20
38 | opt['compression_level'] = 3
39 |
40 | # opt['input_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_HR_X8'
41 | # opt['save_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_HR_X8_sub'
42 | # opt['crop_size'] = 512
43 | # opt['step'] = 256
44 | # opt['thresh_size'] = 0
45 | # extract_subimages(opt)
46 | #
47 | # opt['input_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_HR'
48 | # opt['save_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_HR_sub'
49 | # opt['crop_size'] = 512
50 | # opt['step'] = 256
51 | # opt['thresh_size'] = 0
52 | # extract_subimages(opt)
53 | #
54 | # opt['input_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_LR_X2_rs'
55 | # opt['save_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_LR_X2_rs_sub'
56 | # opt['crop_size'] = 512
57 | # opt['step'] = 256
58 | # opt['thresh_size'] = 0
59 | # extract_subimages(opt)
60 | #
61 | # opt['input_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_LR_X4_rs'
62 | # opt['save_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_LR_X4_rs_sub'
63 | # opt['crop_size'] = 512
64 | # opt['step'] = 256
65 | # opt['thresh_size'] = 0
66 | # extract_subimages(opt)
67 |
68 | opt['input_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_LR_X8_rs'
69 | opt['save_folder'] = '/mnt/petrelfs/gujinjin/fhyu/DF2K/DF2K_LR_X8_rs_sub'
70 | opt['crop_size'] = 512
71 | opt['step'] = 256
72 | opt['thresh_size'] = 0
73 | extract_subimages(opt)
74 |
75 |
76 |
77 | def extract_subimages(opt):
78 | """Crop images to subimages.
79 | Args:
80 | opt (dict): Configuration dict. It contains:
81 | input_folder (str): Path to the input folder.
82 | save_folder (str): Path to save folder.
83 | n_thread (int): Thread number.
84 | """
85 | input_folder = opt['input_folder']
86 | save_folder = opt['save_folder']
87 | if not osp.exists(save_folder):
88 | os.makedirs(save_folder)
89 | print(f'mkdir {save_folder} ...')
90 | else:
91 | print(f'Folder {save_folder} already exists. Exit.')
92 | sys.exit(1)
93 |
94 | img_list = list(scandir(input_folder, full_path=True))
95 |
96 | pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
97 | pool = Pool(opt['n_thread'])
98 | for path in img_list:
99 | pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
100 | pool.close()
101 | pool.join()
102 | pbar.close()
103 | print('All processes done.')
104 |
105 |
106 | def worker(path, opt):
107 | """Worker for each process.
108 | Args:
109 | path (str): Image path.
110 | opt (dict): Configuration dict. It contains:
111 | crop_size (int): Crop size.
112 | step (int): Step for overlapped sliding window.
113 | thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped.
114 | save_folder (str): Path to save folder.
115 | compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION.
116 | Returns:
117 | process_info (str): Process information displayed in progress bar.
118 | """
119 | crop_size = opt['crop_size']
120 | step = opt['step']
121 | thresh_size = opt['thresh_size']
122 | img_name, extension = osp.splitext(osp.basename(path))
123 |
124 | # remove the x2, x3, x4 and x8 in the filename for DIV2K
125 | img_name = img_name.replace('x2', '').replace('x3', '').replace('x4', '').replace('x8', '')
126 |
127 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
128 |
129 | h, w = img.shape[0:2]
130 | h_space = np.arange(0, h - crop_size + 1, step)
131 | if h - (h_space[-1] + crop_size) > thresh_size:
132 | h_space = np.append(h_space, h - crop_size)
133 | w_space = np.arange(0, w - crop_size + 1, step)
134 | if w - (w_space[-1] + crop_size) > thresh_size:
135 | w_space = np.append(w_space, w - crop_size)
136 |
137 | index = 0
138 | for x in h_space:
139 | for y in w_space:
140 | index += 1
141 | cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
142 | cropped_img = np.ascontiguousarray(cropped_img)
143 | cv2.imwrite(
144 | osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
145 | [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
146 | process_info = f'Processing {img_name} ...'
147 | return process_info
148 |
149 |
150 | if __name__ == '__main__':
151 | main()
--------------------------------------------------------------------------------
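A worked example of the sliding-window grid in `worker` (numbers illustrative):

    # h = 1000, crop_size = 512, step = 256, thresh_size = 0
    # h_space = np.arange(0, 1000 - 512 + 1, 256) -> [0, 256]
    # leftover 1000 - (256 + 512) = 232 > thresh_size, so h - crop_size = 488 is appended
    # patches start at rows 0, 256, and 488; the last crop is flush with the bottom edge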
/diffglv/utils/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import time
4 |
5 | from basicsr.utils import get_root_logger
6 | from basicsr.utils.dist_util import master_only
7 |
8 | class MessageLogger():
9 | """Message logger for printing.
10 |
11 | Args:
12 | opt (dict): Config. It contains the following keys:
13 | name (str): Exp name.
14 | logger (dict): Contains 'print_freq' (str) for logger interval.
15 | train (dict): Contains 'total_iter' (int) for total iters.
16 | use_tb_logger (bool): Use tensorboard logger.
17 | start_iter (int): Start iter. Default: 1.
18 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
19 | """
20 |
21 | def __init__(self, opt, start_iter=1, tb_logger=None):
22 | self.exp_name = opt['name']
23 | self.interval = opt['logger']['print_freq']
24 | self.start_iter = start_iter
25 | self.max_iters = opt['train']['total_iter']
26 | self.use_tb_logger = opt['logger']['use_tb_logger']
27 | self.tb_logger = tb_logger
28 | self.start_time = time.time()
29 | self.logger = get_root_logger()
30 |
31 | def reset_start_time(self):
32 | self.start_time = time.time()
33 |
34 | @master_only
35 | def __call__(self, log_vars):
36 | """Format logging message.
37 |
38 | Args:
39 | log_vars (dict): It contains the following keys:
40 | epoch (int): Epoch number.
41 | iter (int): Current iter.
42 | lrs (list): List for learning rates.
43 |
44 | time (float): Iter time.
45 | data_time (float): Data time for each iter.
46 | """
47 | # epoch, iter, learning rates
48 | epoch = log_vars.pop('epoch')
49 | current_iter = log_vars.pop('iter')
50 | lrs = log_vars.pop('lrs')
51 |
52 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
53 | for v in lrs:
54 | message += f'{v:.3e},'
55 | message += ')] '
56 |
57 | if 'task' in log_vars.keys():
58 | message += '['
59 | task = log_vars.pop('task')
60 | message += f'task: {task}, '
61 | dataset = log_vars.pop('dataset')
62 | message += f'dataset: {dataset}'
63 | message += '] '
64 |
65 | # time and estimated time
66 | if 'time' in log_vars.keys():
67 | iter_time = log_vars.pop('time')
68 | data_time = log_vars.pop('data_time')
69 |
70 | total_time = time.time() - self.start_time
71 | time_sec_avg = total_time / (current_iter - self.start_iter + 1)
72 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
73 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
74 | message += f'[eta: {eta_str}, '
75 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
76 |
77 | # other items, especially losses
78 | for k, v in log_vars.items():
79 | message += f'{k}: {v:.4e} '
80 | # tensorboard logger
81 | if self.use_tb_logger and 'debug' not in self.exp_name:
82 | if k.startswith('l_'):
83 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
84 | else:
85 | self.tb_logger.add_scalar(k, v, current_iter)
86 | self.logger.info(message)
--------------------------------------------------------------------------------
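A usage sketch for `MessageLogger` (the opt dict is a minimal stand-in for the training YAML; values illustrative):

    opt = {'name': 'train_BI_DiffSR_x4',
           'logger': {'print_freq': 100, 'use_tb_logger': False},
           'train': {'total_iter': 1000000}}
    msg_logger = MessageLogger(opt, start_iter=1, tb_logger=None)
    msg_logger({'epoch': 1, 'iter': 100, 'lrs': [2e-4],
                'time': 0.35, 'data_time': 0.02, 'l_pix': 1.23e-2})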
/diffglv/utils/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from torch.optim.lr_scheduler import _LRScheduler
4 | import torch
5 |
6 |
7 | class MultiStepRestartLR(_LRScheduler):
8 | """ MultiStep with restarts learning rate scheme.
9 |
10 | Args:
11 | optimizer (torch.nn.optimizer): Torch optimizer.
12 | milestones (list): Iterations that will decrease learning rate.
13 | gamma (float): Decrease ratio. Default: 0.1.
14 | restarts (list): Restart iterations. Default: [0].
15 | restart_weights (list): Restart weights at each restart iteration.
16 | Default: [1].
17 | last_epoch (int): Used in _LRScheduler. Default: -1.
18 | """
19 |
20 | def __init__(self,
21 | optimizer,
22 | milestones,
23 | gamma=0.1,
24 | restarts=(0, ),
25 | restart_weights=(1, ),
26 | last_epoch=-1):
27 | self.milestones = Counter(milestones)
28 | self.gamma = gamma
29 | self.restarts = restarts
30 | self.restart_weights = restart_weights
31 | assert len(self.restarts) == len(
32 | self.restart_weights), 'restarts and their weights do not match.'
33 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
34 |
35 | def get_lr(self):
36 | if self.last_epoch in self.restarts:
37 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
38 | return [
39 | group['initial_lr'] * weight
40 | for group in self.optimizer.param_groups
41 | ]
42 | if self.last_epoch not in self.milestones:
43 | return [group['lr'] for group in self.optimizer.param_groups]
44 | return [
45 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
46 | for group in self.optimizer.param_groups
47 | ]
48 |
49 | class LinearLR(_LRScheduler):
50 | """
51 |     Linear learning-rate decay scheme.
52 | Args:
53 | optimizer (torch.nn.optimizer): Torch optimizer.
54 |         total_iter (int): Total number of iterations; the lr decays linearly from the initial lr to 0 over this span.
55 |         last_epoch (int): Used in _LRScheduler. Default: -1.
56 | 
57 | """
58 |
59 | def __init__(self,
60 | optimizer,
61 | total_iter,
62 | last_epoch=-1):
63 | self.total_iter = total_iter
64 | super(LinearLR, self).__init__(optimizer, last_epoch)
65 |
66 | def get_lr(self):
67 | process = self.last_epoch / self.total_iter
68 | weight = (1 - process)
69 | # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups])
70 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
71 |
72 | class VibrateLR(_LRScheduler):
73 | """
74 |     Oscillating ("vibrate") learning-rate scheme.
75 | Args:
76 | optimizer (torch.nn.optimizer): Torch optimizer.
77 |         total_iter (int): Total number of iterations; sets the oscillation period (total_iter // 80) and the decaying envelope.
78 |         last_epoch (int): Used in _LRScheduler. Default: -1.
79 | 
80 | """
81 |
82 | def __init__(self,
83 | optimizer,
84 | total_iter,
85 | last_epoch=-1):
86 | self.total_iter = total_iter
87 | super(VibrateLR, self).__init__(optimizer, last_epoch)
88 |
89 | def get_lr(self):
90 | process = self.last_epoch / self.total_iter
91 |
92 | f = 0.1
93 | if process < 3 / 8:
94 | f = 1 - process * 8 / 3
95 | elif process < 5 / 8:
96 | f = 0.2
97 |
98 | T = self.total_iter // 80
99 | Th = T // 2
100 |
101 | t = self.last_epoch % T
102 |
103 | f2 = t / Th
104 | if t >= Th:
105 | f2 = 2 - f2
106 |
107 | weight = f * f2
108 |
109 | if self.last_epoch < Th:
110 | weight = max(0.1, weight)
111 |
112 | # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2))
113 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
114 |
115 | def get_position_from_periods(iteration, cumulative_period):
116 | """Get the position from a period list.
117 |
118 | It will return the index of the right-closest number in the period list.
119 | For example, the cumulative_period = [100, 200, 300, 400],
120 | if iteration == 50, return 0;
121 | if iteration == 210, return 2;
122 | if iteration == 300, return 2.
123 |
124 | Args:
125 | iteration (int): Current iteration.
126 | cumulative_period (list[int]): Cumulative period list.
127 |
128 | Returns:
129 | int: The position of the right-closest number in the period list.
130 | """
131 | for i, period in enumerate(cumulative_period):
132 | if iteration <= period:
133 | return i
134 |
135 |
136 | class CosineAnnealingRestartLR(_LRScheduler):
137 | """ Cosine annealing with restarts learning rate scheme.
138 |
139 | An example of config:
140 | periods = [10, 10, 10, 10]
141 | restart_weights = [1, 0.5, 0.5, 0.5]
142 | eta_min=1e-7
143 |
144 |     It has four cycles, each with 10 iterations. At the 10th, 20th, and 30th
145 |     iterations, the scheduler will restart with the weights in restart_weights.
146 |
147 | Args:
148 | optimizer (torch.nn.optimizer): Torch optimizer.
149 |         periods (list): Period for each cosine annealing cycle.
150 | restart_weights (list): Restart weights at each restart iteration.
151 | Default: [1].
152 |         eta_min (float): The minimum lr. Default: 0.
153 | last_epoch (int): Used in _LRScheduler. Default: -1.
154 | """
155 |
156 | def __init__(self,
157 | optimizer,
158 | periods,
159 | restart_weights=(1, ),
160 | eta_min=0,
161 | last_epoch=-1):
162 | self.periods = periods
163 | self.restart_weights = restart_weights
164 | self.eta_min = eta_min
165 | assert (len(self.periods) == len(self.restart_weights)
166 | ), 'periods and restart_weights should have the same length.'
167 | self.cumulative_period = [
168 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
169 | ]
170 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
171 |
172 | def get_lr(self):
173 | idx = get_position_from_periods(self.last_epoch,
174 | self.cumulative_period)
175 | current_weight = self.restart_weights[idx]
176 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
177 | current_period = self.periods[idx]
178 |
179 | return [
180 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
181 | (1 + math.cos(math.pi * (
182 | (self.last_epoch - nearest_restart) / current_period)))
183 | for base_lr in self.base_lrs
184 | ]
185 |
186 | class CosineAnnealingRestartCyclicLR(_LRScheduler):
187 | """ Cosine annealing with restarts learning rate scheme.
188 | An example of config:
189 | periods = [10, 10, 10, 10]
190 | restart_weights = [1, 0.5, 0.5, 0.5]
191 | eta_min=1e-7
192 |     It has four cycles, each with 10 iterations. At the 10th, 20th, and 30th
193 |     iterations, the scheduler will restart with the weights in restart_weights.
194 | Args:
195 | optimizer (torch.nn.optimizer): Torch optimizer.
196 |         periods (list): Period for each cosine annealing cycle.
197 | restart_weights (list): Restart weights at each restart iteration.
198 | Default: [1].
199 |         eta_mins (list): The minimum lr of each cycle. Default: [0].
200 | last_epoch (int): Used in _LRScheduler. Default: -1.
201 | """
202 |
203 | def __init__(self,
204 | optimizer,
205 | periods,
206 | restart_weights=(1, ),
207 | eta_mins=(0, ),
208 | last_epoch=-1):
209 | self.periods = periods
210 | self.restart_weights = restart_weights
211 | self.eta_mins = eta_mins
212 | assert (len(self.periods) == len(self.restart_weights)
213 | ), 'periods and restart_weights should have the same length.'
214 | self.cumulative_period = [
215 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
216 | ]
217 | super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch)
218 |
219 | def get_lr(self):
220 | idx = get_position_from_periods(self.last_epoch,
221 | self.cumulative_period)
222 | current_weight = self.restart_weights[idx]
223 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
224 | current_period = self.periods[idx]
225 | eta_min = self.eta_mins[idx]
226 |
227 | return [
228 | eta_min + current_weight * 0.5 * (base_lr - eta_min) *
229 | (1 + math.cos(math.pi * (
230 | (self.last_epoch - nearest_restart) / current_period)))
231 | for base_lr in self.base_lrs
232 | ]
233 |
--------------------------------------------------------------------------------
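A minimal usage sketch for the cosine-restart scheduler above (assuming the module path `diffglv/utils/lr_scheduler.py` shown here; the dummy optimizer and the period values are illustrative only):

```python
import torch
from diffglv.utils.lr_scheduler import CosineAnnealingRestartLR

# Dummy parameter/optimizer pair, just to drive the scheduler.
params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.Adam(params, lr=1e-4)

# Four cycles of 10 iterations; later cycles restart at half the base lr.
scheduler = CosineAnnealingRestartLR(
    optimizer,
    periods=[10, 10, 10, 10],
    restart_weights=[1, 0.5, 0.5, 0.5],
    eta_min=1e-7)

for it in range(40):
    optimizer.step()
    scheduler.step()
    print(it, scheduler.get_last_lr())  # anneals to eta_min, then restarts
```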
/diffglv/utils/make_ds.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | from PIL import Image
3 | import os
4 | import sys
5 | from multiprocessing import Pool
6 | from os import path as osp
7 | from tqdm import tqdm
8 | from basicsr.utils import scandir
9 |
10 |
11 | def make_downsampling(input_folder, save_folder, scale, rescaling=False, downsample_type='bicubic',
12 | n_thread=20, wash_only=False):
13 |     """Downsample HR images in a folder to generate their LR counterparts.
14 |     Args:
15 |         input_folder (str): Path to the input (HR) folder.
16 |         save_folder (str): Path to the save folder.
17 |         scale (int): Downsampling scale factor.
18 |         n_thread (int): Number of worker processes.
19 |     """
20 | opt = {}
21 | opt['scale'] = scale
22 | opt['rescaling'] = rescaling
23 | opt['save_folder'] = save_folder
24 | opt['downsample_type'] = downsample_type
25 | opt['wash_only'] = wash_only
26 |
27 | if not osp.exists(save_folder):
28 | os.makedirs(save_folder)
29 | print(f'mkdir {save_folder} ...')
30 | else:
31 | print(f'Folder {save_folder} already exists. Exit.')
32 | sys.exit(1)
33 |
34 | img_list = list(scandir(input_folder, full_path=True))
35 |
36 | pbar = tqdm(total=len(img_list), unit='image', desc='Extract')
37 | pool = Pool(n_thread)
38 | for path in img_list:
39 | pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1))
40 | pool.close()
41 | pool.join()
42 | pbar.close()
43 | print('All processes done.')
44 |
45 |
46 | def worker(path, opt):
47 |     """Worker for each process.
48 |     Args:
49 |         path (str): Image path.
50 |         opt (dict): Configuration dict. It contains:
51 |             scale (int): Downsampling scale factor.
52 |             rescaling (bool): Resize the LR image back to the HR size.
53 |             downsample_type (str): 'bicubic' or 'bilinear'.
54 |             save_folder (str): Path to save folder.
55 |             wash_only (bool): Only crop HR to a multiple of scale; skip resizing.
56 |     Returns:
57 |         process_info (str): Process information displayed in progress bar.
58 |     """
59 | scale = opt['scale']
60 | save_folder = opt['save_folder']
61 | downsample_type = opt['downsample_type']
62 |
63 | hr = Image.open(path)
64 | w, h = hr.size
65 | if w % scale + h % scale:
66 | print('\n\nHR needs data washing\n')
67 |         hr = hr.crop([0, 0, w // scale * scale, h // scale * scale])
68 |
69 | ds_func = Image.Resampling.BICUBIC if downsample_type == 'bicubic' else Image.Resampling.BILINEAR
70 |
71 | if not opt['wash_only']:
72 | lr = hr.resize((w//scale, h//scale), ds_func)
73 | if opt['rescaling']:
74 | lr = lr.resize((w//scale*scale, h//scale*scale), ds_func)
75 | else:
76 | lr = hr
77 |
78 | lr.save(osp.join(save_folder, osp.split(path)[-1]))
79 | process_info = f'Processing {osp.split(path)[-1]} ...'
80 | return process_info
81 |
82 |
83 | if __name__ == '__main__':
84 | parser = argparse.ArgumentParser()
85 | parser.add_argument('--src', type=str)
86 | parser.add_argument('--dst', type=str)
87 | parser.add_argument('--scale', type=int)
88 | parser.add_argument('--rescaling', '-rs', action='store_true')
89 | parser.add_argument('--n_worker', type=int, default=20)
90 | parser.add_argument('--ds_func', type=str, default='bicubic')
91 | parser.add_argument('--wash_only', '-wo', action='store_true')
92 |
93 | args = parser.parse_args()
94 |
95 | make_downsampling(args.src, args.dst, args.scale, args.rescaling, args.ds_func, args.n_worker, args.wash_only)
--------------------------------------------------------------------------------
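A hedged usage sketch for `make_downsampling` (the folder names are illustrative and follow the dataset layout used by the training configs below; note the function aborts if the save folder already exists):

```python
from diffglv.utils.make_ds import make_downsampling

# Generate x4 bicubic LR images from a folder of HR images.
make_downsampling(
    input_folder='datasets/DF2K/HR',
    save_folder='datasets/DF2K/LR_bicubic/X4',  # must not exist yet
    scale=4,
    rescaling=False,            # keep LR at 1/4 of the HR resolution
    downsample_type='bicubic',
    n_thread=20)
```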
/diffglv/utils/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import random
3 | import torch
4 | import yaml
5 | from collections import OrderedDict
6 | from os import path as osp
7 |
8 | from basicsr.utils import set_random_seed
9 | from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
10 |
11 |
12 | def ordered_yaml():
13 | """Support OrderedDict for yaml.
14 |
15 | Returns:
16 | yaml Loader and Dumper.
17 | """
18 | try:
19 | from yaml import CDumper as Dumper
20 | from yaml import CLoader as Loader
21 | except ImportError:
22 | from yaml import Dumper, Loader
23 |
24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
25 |
26 | def dict_representer(dumper, data):
27 | return dumper.represent_dict(data.items())
28 |
29 | def dict_constructor(loader, node):
30 | return OrderedDict(loader.construct_pairs(node))
31 |
32 | Dumper.add_representer(OrderedDict, dict_representer)
33 | Loader.add_constructor(_mapping_tag, dict_constructor)
34 | return Loader, Dumper
35 |
36 |
37 | def dict2str(opt, indent_level=1):
38 | """dict to string for printing options.
39 |
40 | Args:
41 | opt (dict): Option dict.
42 | indent_level (int): Indent level. Default: 1.
43 |
44 | Return:
45 | (str): Option string for printing.
46 | """
47 | msg = '\n'
48 | for k, v in opt.items():
49 | if isinstance(v, dict):
50 | msg += ' ' * (indent_level * 2) + k + ':['
51 | msg += dict2str(v, indent_level + 1)
52 | msg += ' ' * (indent_level * 2) + ']\n'
53 | else:
54 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
55 | return msg
56 |
57 |
58 | def _postprocess_yml_value(value):
59 | # None
60 | if value == '~' or value.lower() == 'none':
61 | return None
62 | # bool
63 | if value.lower() == 'true':
64 | return True
65 | elif value.lower() == 'false':
66 | return False
67 | # !!float number
68 | if value.startswith('!!float'):
69 | return float(value.replace('!!float', ''))
70 | # number
71 | if value.isdigit():
72 | return int(value)
73 | elif value.replace('.', '', 1).isdigit() and value.count('.') < 2:
74 | return float(value)
75 | # list
76 | if value.startswith('['):
77 | return eval(value)
78 | # str
79 | return value
80 |
81 |
82 | def parse_options(root_path, is_train=True):
83 | parser = argparse.ArgumentParser()
84 | parser.add_argument('-opt', type=str, required=True, help='Path to option YAML file.')
85 | parser.add_argument('--launcher', choices=['none', 'pytorch', 'slurm'], default='none', help='job launcher')
86 | parser.add_argument('--auto_resume', action='store_true')
87 | parser.add_argument('--debug', action='store_true')
88 | parser.add_argument('--local_rank', type=int, default=0)
89 | parser.add_argument(
90 | '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
91 | args = parser.parse_args()
92 |
93 | # parse yml to dict
94 | with open(args.opt, mode='r') as f:
95 | opt = yaml.load(f, Loader=ordered_yaml()[0])
96 |
97 | # distributed settings
98 | if args.launcher == 'none':
99 | opt['dist'] = False
100 | print('Disable distributed.', flush=True)
101 | else:
102 | opt['dist'] = True
103 | if args.launcher == 'slurm' and 'dist_params' in opt:
104 | init_dist(args.launcher, **opt['dist_params'])
105 | else:
106 | init_dist(args.launcher)
107 | opt['rank'], opt['world_size'] = get_dist_info()
108 |
109 | # random seed
110 | seed = opt.get('manual_seed')
111 | if seed is None:
112 | seed = random.randint(1, 10000)
113 | opt['manual_seed'] = seed
114 | set_random_seed(seed + opt['rank'])
115 |
116 | # force to update yml options
117 | if args.force_yml is not None:
118 | for entry in args.force_yml:
119 |             # creating new keys is not supported yet
120 | keys, value = entry.split('=')
121 | keys, value = keys.strip(), value.strip()
122 | value = _postprocess_yml_value(value)
123 | eval_str = 'opt'
124 | for key in keys.split(':'):
125 | eval_str += f'["{key}"]'
126 | eval_str += '=value'
127 | # using exec function
128 | exec(eval_str)
129 |
130 | opt['auto_resume'] = args.auto_resume
131 | opt['is_train'] = is_train
132 |
133 | # debug setting
134 | if args.debug and not opt['name'].startswith('debug'):
135 | opt['name'] = 'debug_' + opt['name']
136 |
137 | if opt['num_gpu'] == 'auto':
138 | opt['num_gpu'] = torch.cuda.device_count()
139 |
140 | # datasets
141 | for phase, dataset in opt['datasets'].items():
142 | # for multiple datasets, e.g., val_1, val_2; test_1, test_2
143 | phase = phase.split('_')[0]
144 | dataset['phase'] = phase
145 | if 'scale' in opt:
146 | dataset['scale'] = opt['scale']
147 | if dataset.get('dataroot_gt') is not None:
148 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
149 | if dataset.get('dataroot_lq') is not None:
150 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
151 |
152 | # paths
153 | for key, val in opt['path'].items():
154 | if (val is not None) and ('resume_state' in key or 'pretrain_network' in key):
155 | opt['path'][key] = osp.expanduser(val)
156 |
157 | if is_train:
158 | experiments_root = osp.join(root_path, 'experiments', opt['name'])
159 | opt['path']['experiments_root'] = experiments_root
160 | opt['path']['models'] = osp.join(experiments_root, 'models')
161 | opt['path']['training_states'] = osp.join(experiments_root, 'training_states')
162 | opt['path']['log'] = experiments_root
163 | opt['path']['visualization'] = osp.join(experiments_root, 'visualization')
164 |
165 | # change some options for debug mode
166 | if 'debug' in opt['name']:
167 | if 'val' in opt:
168 | opt['val']['val_freq'] = 8
169 | opt['logger']['print_freq'] = 1
170 | opt['logger']['save_checkpoint_freq'] = 8
171 | else: # test
172 | results_root = osp.join(root_path, 'results', opt['name'])
173 | opt['path']['results_root'] = results_root
174 | opt['path']['log'] = results_root
175 | opt['path']['visualization'] = osp.join(results_root, 'visualization')
176 |
177 | return opt, args
178 |
179 |
180 | @master_only
181 | def copy_opt_file(opt_file, experiments_root):
182 | # copy the yml file to the experiment root
183 | import sys
184 | import time
185 | from shutil import copyfile
186 | cmd = ' '.join(sys.argv)
187 | filename = osp.join(experiments_root, osp.basename(opt_file))
188 | copyfile(opt_file, filename)
189 |
190 | with open(filename, 'r+') as f:
191 | lines = f.readlines()
192 | lines.insert(0, f'# GENERATE TIME: {time.asctime()}\n# CMD:\n# {cmd}\n\n')
193 | f.seek(0)
194 | f.writelines(lines)
195 |
--------------------------------------------------------------------------------
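A small sketch of what a `--force_yml` entry does, mirroring the key-path logic in `parse_options` above but using plain dict traversal instead of `exec` (the option names are illustrative):

```python
from diffglv.utils.options import _postprocess_yml_value

opt = {'train': {'optim_g': {'lr': 1e-4}}, 'scale': 4}

entry = 'train:optim_g:lr=!!float 2e-4'  # as passed on the command line
keys, value = entry.split('=')
value = _postprocess_yml_value(value.strip())

# Walk down the ':'-separated key path and assign the coerced value.
node = opt
*parents, leaf = keys.strip().split(':')
for key in parents:
    node = node[key]
node[leaf] = value

print(opt['train']['optim_g']['lr'])  # 0.0002
```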
/diffglv/utils/transforms.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import numpy as np
4 | import torch
5 |
6 |
7 | def mod_crop(img, scale):
8 | """Mod crop images, used during testing.
9 |
10 | Args:
11 | img (ndarray): Input image.
12 | scale (int): Scale factor.
13 |
14 | Returns:
15 | ndarray: Result image.
16 | """
17 | img = img.copy()
18 | if img.ndim in (2, 3):
19 | h, w = img.shape[0], img.shape[1]
20 | h_remainder, w_remainder = h % scale, w % scale
21 | img = img[:h - h_remainder, :w - w_remainder, ...]
22 | else:
23 | raise ValueError(f'Wrong img ndim: {img.ndim}.')
24 | return img
25 |
26 |
27 | def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
28 | """Paired random crop. Support Numpy array and Tensor inputs.
29 |
30 | It crops lists of lq and gt images with corresponding locations.
31 |
32 | Args:
33 | img_gts (list[ndarray] | ndarray | list[Tensor] | Tensor): GT images. Note that all images
34 | should have the same shape. If the input is an ndarray, it will
35 | be transformed to a list containing itself.
36 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
37 | should have the same shape. If the input is an ndarray, it will
38 | be transformed to a list containing itself.
39 | gt_patch_size (int): GT patch size.
40 | scale (int): Scale factor.
41 | gt_path (str): Path to ground-truth. Default: None.
42 |
43 | Returns:
44 | list[ndarray] | ndarray: GT images and LQ images. If returned results
45 | only have one element, just return ndarray.
46 | """
47 |
48 | if not isinstance(img_gts, list):
49 | img_gts = [img_gts]
50 | if not isinstance(img_lqs, list):
51 | img_lqs = [img_lqs]
52 |
53 | # determine input type: Numpy array or Tensor
54 | input_type = 'Tensor' if torch.is_tensor(img_gts[0]) else 'Numpy'
55 |
56 | if input_type == 'Tensor':
57 | h_lq, w_lq = img_lqs[0].size()[-2:]
58 | h_gt, w_gt = img_gts[0].size()[-2:]
59 | else:
60 | h_lq, w_lq = img_lqs[0].shape[0:2]
61 | h_gt, w_gt = img_gts[0].shape[0:2]
62 | lq_patch_size = gt_patch_size // scale
63 |
64 | if h_gt != h_lq * scale or w_gt != w_lq * scale:
65 |         raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x '
66 |                          f'multiplication of LQ ({h_lq}, {w_lq}).')
67 | if h_lq < lq_patch_size or w_lq < lq_patch_size:
68 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
69 | f'({lq_patch_size}, {lq_patch_size}). '
70 | f'Please remove {gt_path}.')
71 |
72 | # randomly choose top and left coordinates for lq patch
73 | top = random.randint(0, h_lq - lq_patch_size)
74 | left = random.randint(0, w_lq - lq_patch_size)
75 |
76 | # crop lq patch
77 | if input_type == 'Tensor':
78 | img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
79 | else:
80 | img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
81 |
82 | # crop corresponding gt patch
83 | top_gt, left_gt = int(top * scale), int(left * scale)
84 | if input_type == 'Tensor':
85 | img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
86 | else:
87 | img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
88 | if len(img_gts) == 1:
89 | img_gts = img_gts[0]
90 | if len(img_lqs) == 1:
91 | img_lqs = img_lqs[0]
92 | return img_gts, img_lqs
93 |
94 |
95 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
96 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
97 |
98 | We use vertical flip and transpose for rotation implementation.
99 | All the images in the list use the same augmentation.
100 |
101 | Args:
102 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input
103 | is an ndarray, it will be transformed to a list.
104 | hflip (bool): Horizontal flip. Default: True.
105 |         rotation (bool): Rotation. Default: True.
106 |         flows (list[ndarray]): Flows to be augmented. If the input is an
107 | ndarray, it will be transformed to a list.
108 | Dimension is (h, w, 2). Default: None.
109 | return_status (bool): Return the status of flip and rotation.
110 | Default: False.
111 |
112 | Returns:
113 | list[ndarray] | ndarray: Augmented images and flows. If returned
114 | results only have one element, just return ndarray.
115 |
116 | """
117 | hflip = hflip and random.random() < 0.5
118 | vflip = rotation and random.random() < 0.5
119 | rot90 = rotation and random.random() < 0.5
120 |
121 | def _augment(img):
122 | if hflip: # horizontal
123 | cv2.flip(img, 1, img)
124 | if vflip: # vertical
125 | cv2.flip(img, 0, img)
126 | if rot90:
127 | img = img.transpose(1, 0, 2)
128 | return img
129 |
130 | def _augment_flow(flow):
131 | if hflip: # horizontal
132 | cv2.flip(flow, 1, flow)
133 | flow[:, :, 0] *= -1
134 | if vflip: # vertical
135 | cv2.flip(flow, 0, flow)
136 | flow[:, :, 1] *= -1
137 | if rot90:
138 | flow = flow.transpose(1, 0, 2)
139 | flow = flow[:, :, [1, 0]]
140 | return flow
141 |
142 | if not isinstance(imgs, list):
143 | imgs = [imgs]
144 | imgs = [_augment(img) for img in imgs]
145 | if len(imgs) == 1:
146 | imgs = imgs[0]
147 |
148 | if flows is not None:
149 | if not isinstance(flows, list):
150 | flows = [flows]
151 | flows = [_augment_flow(flow) for flow in flows]
152 | if len(flows) == 1:
153 | flows = flows[0]
154 | return imgs, flows
155 | else:
156 | if return_status:
157 | return imgs, (hflip, vflip, rot90)
158 | else:
159 | return imgs
160 |
161 |
162 | def img_rotate(img, angle, center=None, scale=1.0):
163 | """Rotate image.
164 |
165 | Args:
166 | img (ndarray): Image to be rotated.
167 | angle (float): Rotation angle in degrees. Positive values mean
168 | counter-clockwise rotation.
169 | center (tuple[int]): Rotation center. If the center is None,
170 | initialize it as the center of the image. Default: None.
171 | scale (float): Isotropic scale factor. Default: 1.0.
172 | """
173 | (h, w) = img.shape[:2]
174 |
175 | if center is None:
176 | center = (w // 2, h // 2)
177 |
178 | matrix = cv2.getRotationMatrix2D(center, angle, scale)
179 | rotated_img = cv2.warpAffine(img, matrix, (w, h))
180 | return rotated_img
181 |
182 |
183 | def data_augmentation(image, mode):
184 | """
185 | Performs data augmentation of the input image
186 | Input:
187 | image: a cv2 (OpenCV) image
188 | mode: int. Choice of transformation to apply to the image
189 | 0 - no transformation
190 | 1 - flip up and down
191 |             2 - rotate counterclockwise 90 degrees
192 |             3 - rotate 90 degrees and flip up and down
193 |             4 - rotate 180 degrees
194 |             5 - rotate 180 degrees and flip
195 |             6 - rotate 270 degrees
196 |             7 - rotate 270 degrees and flip
197 | """
198 | if mode == 0:
199 | # original
200 | out = image
201 | elif mode == 1:
202 | # flip up and down
203 | out = np.flipud(image)
204 | elif mode == 2:
205 |         # rotate counterclockwise 90 degrees
206 | out = np.rot90(image)
207 | elif mode == 3:
208 | # rotate 90 degree and flip up and down
209 | out = np.rot90(image)
210 | out = np.flipud(out)
211 | elif mode == 4:
212 | # rotate 180 degree
213 | out = np.rot90(image, k=2)
214 | elif mode == 5:
215 | # rotate 180 degree and flip
216 | out = np.rot90(image, k=2)
217 | out = np.flipud(out)
218 | elif mode == 6:
219 | # rotate 270 degree
220 | out = np.rot90(image, k=3)
221 | elif mode == 7:
222 | # rotate 270 degree and flip
223 | out = np.rot90(image, k=3)
224 | out = np.flipud(out)
225 | else:
226 | raise Exception('Invalid choice of image transformation')
227 |
228 | return out
229 |
230 | def random_augmentation(*args):
231 | out = []
232 |     flag_aug = random.randint(0, 7)
233 | for data in args:
234 | out.append(data_augmentation(data, flag_aug).copy())
235 | return out
236 |
--------------------------------------------------------------------------------
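A minimal sketch of the training-time crop/augment flow implemented above, using random arrays in place of a real HR/LR pair (the shapes are illustrative):

```python
import numpy as np
from diffglv.utils.transforms import paired_random_crop, augment

scale, gt_patch_size = 4, 256
img_gt = np.random.rand(480, 320, 3).astype(np.float32)  # HR image (h, w, c)
img_lq = np.random.rand(120, 80, 3).astype(np.float32)   # LR image, 4x smaller

# Crop a 256x256 GT patch and the spatially aligned 64x64 LQ patch.
img_gt, img_lq = paired_random_crop(img_gt, img_lq, gt_patch_size, scale)

# The same random flip/rotation is applied to both patches.
img_gt, img_lq = augment([img_gt, img_lq], hflip=True, rotation=True)

print(img_gt.shape, img_lq.shape)  # (256, 256, 3) (64, 64, 3)
```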
/experiments/README.md:
--------------------------------------------------------------------------------
1 | Place pretrained models in `pretrained_models`.
2 |
--------------------------------------------------------------------------------
/experiments/pretrained_models/README.md:
--------------------------------------------------------------------------------
1 | Place pretrained models here.
--------------------------------------------------------------------------------
/figs/BI-DiffSR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/BI-DiffSR.png
--------------------------------------------------------------------------------
/figs/F1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/F1.png
--------------------------------------------------------------------------------
/figs/F2-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/F2-1.png
--------------------------------------------------------------------------------
/figs/F2-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/F2-2.png
--------------------------------------------------------------------------------
/figs/F3-1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/F3-1.png
--------------------------------------------------------------------------------
/figs/F3-2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/F3-2.png
--------------------------------------------------------------------------------
/figs/T1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/T1.png
--------------------------------------------------------------------------------
/figs/compare/ComS_img_023_BBCU_x4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_023_BBCU_x4.png
--------------------------------------------------------------------------------
/figs/compare/ComS_img_023_BI-DiffSR_x4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_023_BI-DiffSR_x4.png
--------------------------------------------------------------------------------
/figs/compare/ComS_img_023_Bicubic_x4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_023_Bicubic_x4.png
--------------------------------------------------------------------------------
/figs/compare/ComS_img_023_HR_x4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_023_HR_x4.png
--------------------------------------------------------------------------------
/figs/compare/ComS_img_023_SR3_x4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_023_SR3_x4.png
--------------------------------------------------------------------------------
/figs/compare/ComS_img_033_BBCU_x4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_033_BBCU_x4.png
--------------------------------------------------------------------------------
/figs/compare/ComS_img_033_BI-DiffSR_x4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_033_BI-DiffSR_x4.png
--------------------------------------------------------------------------------
/figs/compare/ComS_img_033_Bicubic_x4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_033_Bicubic_x4.png
--------------------------------------------------------------------------------
/figs/compare/ComS_img_033_HR_x4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_033_HR_x4.png
--------------------------------------------------------------------------------
/figs/compare/ComS_img_033_SR3_x4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/compare/ComS_img_033_SR3_x4.png
--------------------------------------------------------------------------------
/figs/logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/zhengchen1999/BI-DiffSR/7f3cf4097e973dc54218bf27f46c7bc49acf9a70/figs/logo.png
--------------------------------------------------------------------------------
/options/test/test_BI_DiffSR_x2.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: test_BI_DiffSR_DDIM_S50_x2
3 | model_type: BIDiffSRModel
4 | scale: 2
5 | num_gpu: 1
6 | manual_seed: 10
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | test_1:
11 | task: SR
12 | name: Set5
13 | type: MultiPairedImageDataset
14 | dataroot_gt: datasets/benchmark/Set5/HR
15 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2
16 | filename_tmpl: '{}x2'
17 | io_backend:
18 | type: disk
19 |
20 | test_2:
21 | task: SR
22 | name: B100
23 | type: MultiPairedImageDataset
24 | dataroot_gt: datasets/benchmark/B100/HR
25 | dataroot_lq: datasets/benchmark/B100/LR_bicubic/X2
26 | filename_tmpl: '{}x2'
27 | io_backend:
28 | type: disk
29 |
30 | test_3:
31 | task: SR
32 | name: Urban100
33 | type: MultiPairedImageDataset
34 | dataroot_gt: datasets/benchmark/Urban100/HR
35 | dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X2
36 | filename_tmpl: '{}x2'
37 | io_backend:
38 | type: disk
39 |
40 | test_4:
41 | task: SR
42 | name: Manga109
43 | type: MultiPairedImageDataset
44 | dataroot_gt: datasets/benchmark/Manga109/HR
45 | dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X2
46 | filename_tmpl: '{}_LRBI_x2'
47 | io_backend:
48 | type: disk
49 |
50 | # network structures
51 | network_g:
52 | type: BIDiffSRUNet
53 | in_channel: 6
54 | out_channel: 3
55 | inner_channel: 64
56 | norm_groups: 16
57 | channel_mults: [1, 2, 4, 8]
58 | attn_res: []
59 | res_blocks: 2
60 | dropout: 0.2
61 | image_size: 256
62 | fp_res: [256, 128]
63 | total_step: 2000
64 | dynamic_group: 5 # K
65 |
66 | # schedule
67 | beta_schedule:
68 | scheduler_type: DDIM
69 | schedule: linear
70 | n_timestep: 2000
71 | linear_start: !!float 1e-6
72 | linear_end: !!float 1e-2
73 | prediction_type: epsilon
74 | num_inference_steps: 50
75 | guidance_scale: 7.5
76 | is_guidance: False
77 |
78 | # path
79 | path:
80 | pretrain_network_g: experiments/pretrained_models/BI_DiffSR_x2.pth
81 | strict_load_g: true
82 | resume_state: params
83 |
84 | # validation settings
85 | val:
86 | save_img: true
87 | suffix: 'test' # add suffix to saved images, if None, use exp name
88 |
89 | metrics:
90 | psnr: # metric name, can be arbitrary
91 | type: calculate_psnr
92 | crop_border: 2
93 | test_y_channel: true
94 |
95 | ssim: # metric name, can be arbitrary
96 | type: calculate_ssim
97 | crop_border: 2
98 | test_y_channel: true
99 |
100 | lpips: # metric name, can be arbitrary
101 | type: calculate_lpips
102 | crop_border: 2
103 | better: lower
104 |
--------------------------------------------------------------------------------
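Rough arithmetic behind the DDIM settings above: 50 inference steps are drawn from the 2000-step training schedule, i.e. roughly every 40th timestep. The exact spacing is defined by the scheduler in `diffglv/utils/beta_schedule.py`; this is only an illustration:

```python
n_timestep = 2000
num_inference_steps = 50

# Evenly spaced timestep subsequence, visited from noisy to clean.
step = n_timestep // num_inference_steps
timesteps = list(range(0, n_timestep, step))[::-1]
print(len(timesteps), timesteps[:3], timesteps[-1])  # 50 [1960, 1920, 1880] 0
```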
/options/test/test_BI_DiffSR_x4.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: test_BI_DiffSR_DDIM_S50_x4
3 | model_type: BIDiffSRModel
4 | scale: 4
5 | num_gpu: 1
6 | manual_seed: 10
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | test_1:
11 | task: SR
12 | name: Set5
13 | type: MultiPairedImageDataset
14 | dataroot_gt: datasets/benchmark/Set5/HR
15 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
16 | filename_tmpl: '{}x4'
17 | io_backend:
18 | type: disk
19 |
20 | test_2:
21 | task: SR
22 | name: B100
23 | type: MultiPairedImageDataset
24 | dataroot_gt: datasets/benchmark/B100/HR
25 | dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
26 | filename_tmpl: '{}x4'
27 | io_backend:
28 | type: disk
29 |
30 | test_3:
31 | task: SR
32 | name: Urban100
33 | type: MultiPairedImageDataset
34 | dataroot_gt: datasets/benchmark/Urban100/HR
35 | dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
36 | filename_tmpl: '{}x4'
37 | io_backend:
38 | type: disk
39 |
40 | test_4:
41 | task: SR
42 | name: Manga109
43 | type: MultiPairedImageDataset
44 | dataroot_gt: datasets/benchmark/Manga109/HR
45 | dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4
46 | filename_tmpl: '{}_LRBI_x4'
47 | io_backend:
48 | type: disk
49 |
50 | # network structures
51 | network_g:
52 | type: BIDiffSRUNet
53 | in_channel: 6
54 | out_channel: 3
55 | inner_channel: 64
56 | norm_groups: 16
57 | channel_mults: [1, 2, 4, 8]
58 | attn_res: []
59 | res_blocks: 2
60 | dropout: 0.2
61 | image_size: 256
62 | fp_res: [256, 128]
63 | total_step: 2000
64 | dynamic_group: 5 # K
65 |
66 | # schedule
67 | beta_schedule:
68 | scheduler_type: DDIM
69 | schedule: linear
70 | n_timestep: 2000
71 | linear_start: !!float 1e-6
72 | linear_end: !!float 1e-2
73 | prediction_type: epsilon
74 | num_inference_steps: 50
75 | guidance_scale: 7.5
76 | is_guidance: False
77 |
78 | # path
79 | path:
80 | pretrain_network_g: experiments/pretrained_models/BI_DiffSR_x4.pth
81 | strict_load_g: true
82 | resume_state: params
83 |
84 | # validation settings
85 | val:
86 | save_img: true
87 | suffix: 'test' # add suffix to saved images, if None, use exp name
88 |
89 | metrics:
90 | psnr: # metric name, can be arbitrary
91 | type: calculate_psnr
92 | crop_border: 4
93 | test_y_channel: true
94 |
95 | ssim: # metric name, can be arbitrary
96 | type: calculate_ssim
97 | crop_border: 4
98 | test_y_channel: true
99 |
100 | lpips: # metric name, can be arbitrary
101 | type: calculate_lpips
102 | crop_border: 4
103 | better: lower
104 |
105 |
--------------------------------------------------------------------------------
/options/train/train_BI_DiffSR_x2.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: train_BI_DiffSR_DDIM_S50_x2
3 | model_type: BIDiffSRModel
4 | scale: 2
5 | num_gpu: auto
6 | manual_seed: 10
7 | find_unused_parameters: True
8 |
9 | # dataset and data loader settings
10 | datasets:
11 | train:
12 | task: SR
13 | name: DF2K
14 | type: MultiPairedImageDataset
15 | dataroot_gt: datasets/DF2K/HR
16 | dataroot_lq: datasets/DF2K/LR_bicubic/X2
17 | filename_tmpl: '{}x2'
18 | io_backend:
19 | type: disk
20 |
21 | gt_size: 128
22 | use_hflip: True
23 | use_rot: True
24 |
25 | # data loader
26 | use_shuffle: True
27 | num_worker_per_gpu: 8
28 | batch_size_per_gpu: 4
29 | dataset_enlarge_ratio: 100
30 | prefetch_mode: ~
31 |
32 | val:
33 | task: SR
34 | name: Set5
35 | type: MultiPairedImageDataset
36 | dataroot_gt: datasets/benchmark/Set5/HR
37 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X2
38 | filename_tmpl: '{}x2'
39 | io_backend:
40 | type: disk
41 |
42 | # network structures
43 | network_g:
44 | type: BIDiffSRUNet
45 | in_channel: 6
46 | out_channel: 3
47 | inner_channel: 64
48 | norm_groups: 16
49 | channel_mults: [1, 2, 4, 8]
50 | attn_res: []
51 | res_blocks: 2
52 | dropout: 0.2
53 | image_size: 256
54 | fp_res: [256, 128]
55 | total_step: 2000
56 | dynamic_group: 5 # K
57 |
58 | # schedule
59 | beta_schedule:
60 | scheduler_type: DDIM
61 | schedule: linear
62 | n_timestep: 2000
63 | linear_start: !!float 1e-6
64 | linear_end: !!float 1e-2
65 | prediction_type: epsilon
66 | num_inference_steps: 50
67 | guidance_scale: 7.5
68 | is_guidance: False
69 |
70 | # path
71 | path:
72 | pretrain_network_g: ~
73 | strict_load_g: true
74 | resume_state: ~
75 |
76 | train:
77 | # ema_decay: 0.999
78 | optim_g:
79 | type: Adam
80 | lr: !!float 1e-4
81 | weight_decay: 0
82 | betas: [0.9, 0.99]
83 |
84 | scheduler:
85 | type: MultiStepLR
86 | milestones: [500000]
87 | gamma: 1
88 |
89 | total_iter: 1000000
90 | warmup_iter: -1 # no warm up
91 |
92 | # losses
93 | pixel_opt:
94 | type: L1Loss
95 | loss_weight: 1.0
96 | reduction: mean
97 |
98 | # validation settings
99 | val:
100 | val_freq: !!float 2e4
101 | save_img: false
102 |
103 | metrics:
104 | psnr: # metric name, can be arbitrary
105 | type: calculate_psnr
106 | crop_border: 2
107 | test_y_channel: true
108 |
109 | ssim: # metric name, can be arbitrary
110 | type: calculate_ssim
111 | crop_border: 2
112 | test_y_channel: true
113 |
114 | lpips: # metric name, can be arbitrary
115 | type: calculate_lpips
116 | crop_border: 2
117 | better: lower
118 |
119 | # logging settings
120 | logger:
121 | print_freq: 500
122 | save_checkpoint_freq: !!float 2e4
123 | use_tb_logger: true
124 | wandb:
125 | project: ~
126 | resume_id: ~
127 |
128 | # dist training settings
129 | dist_params:
130 | backend: nccl
131 | port: 29500
132 |
--------------------------------------------------------------------------------
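One detail worth noting in the `train` section above: `MultiStepLR` with `gamma: 1` leaves the learning rate constant at 1e-4, so the milestone at 500k iterations is effectively a no-op. A minimal sketch of the optimizer/scheduler pair this config requests:

```python
import torch

params = [torch.nn.Parameter(torch.zeros(1))]  # stand-in for network_g
optimizer = torch.optim.Adam(params, lr=1e-4, weight_decay=0, betas=(0.9, 0.99))
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer, milestones=[500000], gamma=1)  # gamma=1: lr never changes
```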
/options/train/train_BI_DiffSR_x4.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: train_BI_DiffSR_DDIM_S50_x4
3 | model_type: BIDiffSRModel
4 | scale: 4
5 | num_gpu: auto
6 | manual_seed: 10
7 | find_unused_parameters: True
8 |
9 | # dataset and data loader settings
10 | datasets:
11 | train:
12 | task: SR
13 | name: DF2K
14 | type: MultiPairedImageDataset
15 | dataroot_gt: datasets/DF2K/HR
16 | dataroot_lq: datasets/DF2K/LR_bicubic/X4
17 | filename_tmpl: '{}x4'
18 | io_backend:
19 | type: disk
20 |
21 | gt_size: 256
22 | use_hflip: True
23 | use_rot: True
24 |
25 | # data loader
26 | use_shuffle: True
27 | num_worker_per_gpu: 8
28 | batch_size_per_gpu: 4
29 | dataset_enlarge_ratio: 100
30 | prefetch_mode: ~
31 |
32 | val:
33 | task: SR
34 | name: Set5
35 | type: MultiPairedImageDataset
36 | dataroot_gt: datasets/benchmark/Set5/HR
37 | dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
38 | filename_tmpl: '{}x4'
39 | io_backend:
40 | type: disk
41 |
42 | # network structures
43 | network_g:
44 | type: BIDiffSRUNet
45 | in_channel: 6
46 | out_channel: 3
47 | inner_channel: 64
48 | norm_groups: 16
49 | channel_mults: [1, 2, 4, 8]
50 | attn_res: []
51 | res_blocks: 2
52 | dropout: 0.2
53 | image_size: 256
54 | fp_res: [256, 128]
55 | total_step: 2000
56 | dynamic_group: 5 # K
57 |
58 | # schedule
59 | beta_schedule:
60 | scheduler_type: DDIM
61 | schedule: linear
62 | n_timestep: 2000
63 | linear_start: !!float 1e-6
64 | linear_end: !!float 1e-2
65 | prediction_type: epsilon
66 | num_inference_steps: 50
67 | guidance_scale: 7.5
68 | is_guidance: False
69 |
70 | # path
71 | path:
72 | pretrain_network_g: ~
73 | strict_load_g: false
74 | resume_state: ~
75 |
76 | train:
77 | # ema_decay: 0.999
78 | optim_g:
79 | type: Adam
80 | lr: !!float 1e-4
81 | weight_decay: 0
82 | betas: [0.9, 0.99]
83 |
84 | scheduler:
85 | type: MultiStepLR
86 | milestones: [500000]
87 | gamma: 1
88 |
89 | total_iter: 1000000
90 | warmup_iter: -1 # no warm up
91 |
92 | # losses
93 | pixel_opt:
94 | type: L1Loss
95 | loss_weight: 1.0
96 | reduction: mean
97 |
98 | # validation settings
99 | val:
100 | val_freq: !!float 2e4
101 | save_img: false
102 |
103 | metrics:
104 | psnr: # metric name, can be arbitrary
105 | type: calculate_psnr
106 | crop_border: 4
107 | test_y_channel: true
108 |
109 | ssim: # metric name, can be arbitrary
110 | type: calculate_ssim
111 | crop_border: 4
112 | test_y_channel: true
113 |
114 | lpips: # metric name, can be arbitrary
115 | type: calculate_lpips
116 | crop_border: 4
117 | better: lower
118 |
119 | # logging settings
120 | logger:
121 | print_freq: 500
122 | save_checkpoint_freq: !!float 2e4
123 | use_tb_logger: true
124 | wandb:
125 | project: ~
126 | resume_id: ~
127 |
128 | # dist training settings
129 | dist_params:
130 | backend: nccl
131 | port: 29500
132 |
--------------------------------------------------------------------------------
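For reference, a common definition of the `linear` beta schedule these options describe; the authoritative version lives in `diffglv/utils/beta_schedule.py` and may differ in detail:

```python
import numpy as np

n_timestep, linear_start, linear_end = 2000, 1e-6, 1e-2
betas = np.linspace(linear_start, linear_end, n_timestep, dtype=np.float64)
alphas_cumprod = np.cumprod(1.0 - betas)  # signal retention after t steps
print(betas[0], betas[-1], alphas_cumprod[-1])
```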
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.13.1+cu117
2 | torchvision==0.14.1+cu117
3 | addict
4 | future
5 | lmdb
6 | numpy<2.0.0
7 | opencv-python
8 | Pillow
9 | pyyaml
10 | requests
11 | scikit-image
12 | scipy
13 | tb-nightly
14 | tqdm
15 | yapf
16 | timm
17 | einops
18 | natsort
19 | joblib
20 | wandb
21 | lpips
22 | matplotlib
23 | seaborn
24 | scikit-learn
25 | basicsr
--------------------------------------------------------------------------------
/results/README.md:
--------------------------------------------------------------------------------
1 | Testing results are saved in this folder.
2 |
3 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from os import path as osp
4 |
5 | from basicsr.data import build_dataloader, build_dataset
6 | from basicsr.models import build_model
7 | from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs
8 | from basicsr.utils.options import dict2str
9 | from diffglv.utils.options import parse_options
10 | from basicsr.utils import set_random_seed
11 | import random
12 |
13 |
14 | def test_pipeline(root_path):
15 |     # parse options, set distributed setting, set random seed
16 | opt, _ = parse_options(root_path, is_train=False)
17 |
18 | torch.backends.cudnn.benchmark = True
19 | # torch.backends.cudnn.deterministic = True
20 |
21 | # mkdir and initialize loggers
22 | make_exp_dirs(opt)
23 | log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log")
24 | logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
25 | logger.info(get_env_info())
26 | logger.info(dict2str(opt))
27 |
28 | # create test dataset and dataloader
29 | test_loaders = []
30 | for _, dataset_opt in sorted(opt['datasets'].items()):
31 | test_set = build_dataset(dataset_opt)
32 | test_loader = build_dataloader(
33 | test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
34 | logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
35 | test_loaders.append(test_loader)
36 |
37 | # create model
38 | model = build_model(opt)
39 |
40 | for test_loader in test_loaders:
41 | seed = opt.get('manual_seed')
42 | if seed is None:
43 | seed = random.randint(1, 10000)
44 | opt['manual_seed'] = seed
45 | set_random_seed(seed + opt['rank'])
46 | test_set_name = test_loader.dataset.opt['name']
47 | logger.info(f'Testing {test_set_name}...')
48 | model.validation(test_loader, current_iter=opt['name'], tb_logger=None, save_img=opt['val']['save_img'])
49 |
50 |
51 | if __name__ == '__main__':
52 | root_path = osp.abspath(osp.join(__file__, osp.pardir))
53 | test_pipeline(root_path)
54 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os.path as osp
2 | import basicsr
3 | import diffglv
4 |
5 | if __name__ == '__main__':
6 | root_path = osp.abspath(osp.join(__file__, osp.pardir))
7 | basicsr.train_pipeline(root_path)
8 |
--------------------------------------------------------------------------------