├── .gitignore
├── LICENSE
├── LICENSES
│   ├── LICENSE
│   ├── LICENSE_BasicSR
│   ├── LICENSE_GLOW
│   └── README.md
├── README.md
├── code
│   ├── Measure.py
│   ├── confs
│   │   ├── RRDB_CelebA_8X.yml
│   │   ├── RRDB_DF2K_4X.yml
│   │   ├── RRDB_DF2K_8X.yml
│   │   ├── SRFlow_CelebA_8X.yml
│   │   ├── SRFlow_DF2K_4X.yml
│   │   └── SRFlow_DF2K_8X.yml
│   ├── data
│   │   ├── LRHR_PKL_dataset.py
│   │   └── __init__.py
│   ├── demo_on_pretrained.ipynb
│   ├── imresize.py
│   ├── models
│   │   ├── SRFlow_model.py
│   │   ├── SR_model.py
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── lr_scheduler.py
│   │   ├── modules
│   │   │   ├── FlowActNorms.py
│   │   │   ├── FlowAffineCouplingsAblation.py
│   │   │   ├── FlowStep.py
│   │   │   ├── FlowUpsamplerNet.py
│   │   │   ├── Permutations.py
│   │   │   ├── RRDBNet_arch.py
│   │   │   ├── SRFlowNet_arch.py
│   │   │   ├── Split.py
│   │   │   ├── __init__.py
│   │   │   ├── flow.py
│   │   │   ├── glow_arch.py
│   │   │   ├── loss.py
│   │   │   ├── module_util.py
│   │   │   └── thops.py
│   │   └── networks.py
│   ├── options
│   │   ├── __init__.py
│   │   └── options.py
│   ├── prepare_data.py
│   ├── test.py
│   ├── train.py
│   └── utils
│       ├── __init__.py
│       ├── timer.py
│       └── util.py
├── requirements.txt
├── run_jupyter.sh
└── setup.sh
/.gitignore:
--------------------------------------------------------------------------------
1 | __MACOSX/
2 | .DS_Store
3 |
4 | *.pth
5 | *.pklv4
6 | *.zip
7 |
8 | datasets/
9 |
10 | data/
11 | myenv/
12 | .idea
13 | checkpoints/
14 | code/notebooks/
15 | code/local_config.py
16 | *.ipynb
17 |
18 | TRAIN_DONE
19 |
20 | # folder
21 | .vscode
22 |
23 | experiments/*
24 | !experiments/pretrained_models
25 | experiments/pretrained_models/*
26 | # !experiments/pretrained_models/SRResNet_bicx4_in3nf64nb16.pth
27 | !experiments/pretrained_models/README.md
28 |
29 | results/*
30 | tb_logger/*
31 |
32 | # file type
33 | *.svg
34 | *.pyc
35 | *.t7
36 | *.caffemodel
37 | *.mat
38 | *.npy
39 |
40 | # latex
41 | *.aux
42 | *.bbl
43 | *.blg
44 | *.log
45 | *.out
46 | *.synctex.gz
47 |
48 | # TODO
49 | data_samples/samples_byteimg
50 | data_samples/samples_colorimg
51 | data_samples/samples_segprob
52 | data_samples/samples_result
53 |
54 |
55 | # Created by https://www.gitignore.io/api/vim,python,pycharm
56 | # Edit at https://www.gitignore.io/?templates=vim,python,pycharm
57 |
58 | ### PyCharm ###
59 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
60 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
61 |
62 | # User-specific stuff
63 | .idea/**/workspace.xml
64 | .idea/**/tasks.xml
65 | .idea/**/usage.statistics.xml
66 | .idea/**/dictionaries
67 | .idea/**/shelf
68 |
69 | # Generated files
70 | .idea/**/contentModel.xml
71 |
72 | # Sensitive or high-churn files
73 | .idea/**/dataSources/
74 | .idea/**/dataSources.ids
75 | .idea/**/dataSources.local.xml
76 | .idea/**/sqlDataSources.xml
77 | .idea/**/dynamic.xml
78 | .idea/**/uiDesigner.xml
79 | .idea/**/dbnavigator.xml
80 |
81 | # Gradle
82 | .idea/**/gradle.xml
83 | .idea/**/libraries
84 |
85 | # Gradle and Maven with auto-import
86 | # When using Gradle or Maven with auto-import, you should exclude module files,
87 | # since they will be recreated, and may cause churn. Uncomment if using
88 | # auto-import.
89 | # .idea/modules.xml
90 | # .idea/*.iml
91 | # .idea/modules
92 | # *.iml
93 | # *.ipr
94 |
95 | # CMake
96 | cmake-build-*/
97 |
98 | # Mongo Explorer plugin
99 | .idea/**/mongoSettings.xml
100 |
101 | # File-based project format
102 | *.iws
103 |
104 | # IntelliJ
105 | out/
106 |
107 | # mpeltonen/sbt-idea plugin
108 | .idea_modules/
109 |
110 | # JIRA plugin
111 | atlassian-ide-plugin.xml
112 |
113 | # Cursive Clojure plugin
114 | .idea/replstate.xml
115 |
116 | # Crashlytics plugin (for Android Studio and IntelliJ)
117 | com_crashlytics_export_strings.xml
118 | crashlytics.properties
119 | crashlytics-build.properties
120 | fabric.properties
121 |
122 | # Editor-based Rest Client
123 | .idea/httpRequests
124 |
125 | # Android studio 3.1+ serialized cache file
126 | .idea/caches/build_file_checksums.ser
127 |
128 | ### PyCharm Patch ###
129 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721
130 |
131 | # *.iml
132 | # modules.xml
133 | # .idea/misc.xml
134 | # *.ipr
135 |
136 | # Sonarlint plugin
137 | .idea/sonarlint
138 |
139 | ### Python ###
140 | # Byte-compiled / optimized / DLL files
141 | __pycache__/
142 | *.py[cod]
143 | *$py.class
144 |
145 | # C extensions
146 | *.so
147 |
148 | # Distribution / packaging
149 | .Python
150 | build/
151 | develop-eggs/
152 | dist/
153 | downloads/
154 | eggs/
155 | .eggs/
156 | lib/
157 | lib64/
158 | parts/
159 | sdist/
160 | var/
161 | wheels/
162 | pip-wheel-metadata/
163 | share/python-wheels/
164 | *.egg-info/
165 | .installed.cfg
166 | *.egg
167 | MANIFEST
168 |
169 | # PyInstaller
170 | # Usually these files are written by a python script from a template
171 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
172 | *.manifest
173 | *.spec
174 |
175 | # Installer logs
176 | pip-log.txt
177 | pip-delete-this-directory.txt
178 |
179 | # Unit test / coverage reports
180 | htmlcov/
181 | .tox/
182 | .nox/
183 | .coverage
184 | .coverage.*
185 | .cache
186 | nosetests.xml
187 | coverage.xml
188 | *.cover
189 | .hypothesis/
190 | .pytest_cache/
191 |
192 | # Translations
193 | *.mo
194 | *.pot
195 |
196 | # Django stuff:
197 | *.log
198 | local_settings.py
199 | db.sqlite3
200 | db.sqlite3-journal
201 |
202 | # Flask stuff:
203 | instance/
204 | .webassets-cache
205 |
206 | # Scrapy stuff:
207 | .scrapy
208 |
209 | # Sphinx documentation
210 | docs/_build/
211 |
212 | # PyBuilder
213 | target/
214 |
215 | # Jupyter Notebook
216 | .ipynb_checkpoints
217 |
218 | # IPython
219 | profile_default/
220 | ipython_config.py
221 |
222 | # pyenv
223 | .python-version
224 |
225 | # pipenv
226 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
227 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
228 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
229 | # install all needed dependencies.
230 | #Pipfile.lock
231 |
232 | # celery beat schedule file
233 | celerybeat-schedule
234 |
235 | # SageMath parsed files
236 | *.sage.py
237 |
238 | # Environments
239 | .env
240 | .venv
241 | env/
242 | venv/
243 | ENV/
244 | env.bak/
245 | venv.bak/
246 |
247 | # Spyder project settings
248 | .spyderproject
249 | .spyproject
250 |
251 | # Rope project settings
252 | .ropeproject
253 |
254 | # mkdocs documentation
255 | /site
256 |
257 | # mypy
258 | .mypy_cache/
259 | .dmypy.json
260 | dmypy.json
261 |
262 | # Pyre type checker
263 | .pyre/
264 |
265 | ### Vim ###
266 | # Swap
267 | [._]*.s[a-v][a-z]
268 | [._]*.sw[a-p]
269 | [._]s[a-rt-v][a-z]
270 | [._]ss[a-gi-z]
271 | [._]sw[a-p]
272 |
273 | # Session
274 | Session.vim
275 | Sessionx.vim
276 |
277 | # Temporary
278 | .netrwhist
279 | *~
280 | # Auto-generated tag files
281 | tags
282 | # Persistent undo
283 | [._]*.un~
284 |
285 | # End of https://www.gitignore.io/api/vim,python,pycharm
286 |
287 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | you may not use this file except in compliance with the License.
4 | You may obtain a copy of the License at
5 |
6 | https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 |
8 | The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 |
15 | Parts of this repository are licensed by
16 | https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
17 | https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
18 |
--------------------------------------------------------------------------------
/LICENSES/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | you may not use this file except in compliance with the License.
4 | You may obtain a copy of the License at
5 |
6 | https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 |
8 | The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | Unless required by applicable law or agreed to in writing, software
10 | distributed under the License is distributed on an "AS IS" BASIS,
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | See the License for the specific language governing permissions and
13 | limitations under the License.
14 |
15 | Parts of this repository are licensed by
16 | https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
17 | https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
--------------------------------------------------------------------------------
/LICENSES/LICENSE_BasicSR:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright 2018-2020 BasicSR Authors
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
--------------------------------------------------------------------------------
/LICENSES/LICENSE_GLOW:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2019 Yuki-Chai
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
--------------------------------------------------------------------------------
/LICENSES/README.md:
--------------------------------------------------------------------------------
1 | # License and Acknowledgement
2 |
3 | A big thanks to the following contributors, who open-sourced their code and thereby helped us a lot in developing SRFlow!
4 |
5 | ## BasicSR
6 | The training framework was adapted from https://github.com/xinntao/BasicSR
7 |
8 | ## GLOW
9 | The Normalizing Flow modules were adapted from https://github.com/chaiyujin/glow-pytorch
10 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SRFlow
2 | #### Official SRFlow training code: Super-Resolution using Normalizing Flow in PyTorch
3 | #### [[Paper] ECCV 2020 Spotlight](https://bit.ly/2DkwQcg)
4 |
5 |
6 |
7 | **News:** Unified Image Super-Resolution and Rescaling [[code](https://bit.ly/2VOKHBb)]
8 |
9 |
10 |
11 | [](https://bit.ly/3jWFRcr)
12 |
13 |
14 |
15 |
16 | # Setup: Data, Environment, PyTorch Demo
17 |
18 |
19 |
20 | ```bash
21 | git clone https://github.com/andreas128/SRFlow.git && cd SRFlow && ./setup.sh
22 | ```
23 |
24 |
25 |
26 | This oneliner will:
27 | - Clone SRFlow
28 | - Setup a python3 virtual env
29 | - Install the packages from `requirements.txt`
30 | - Download the pretrained models
31 | - Download the DIV2K validation data
32 | - Run the Demo Jupyter Notebook
33 |
34 | If you want to install it manually, read the `setup.sh` file. (Links to data/models, pip packages)
35 |
36 |
37 |
38 |
39 | # Demo: Try Normalizing Flow in PyTorch
40 |
41 | ```bash
42 | ./run_jupyter.sh
43 | ```
44 |
45 | This notebook lets you:
46 | - Load the pretrained models.
47 | - Super-resolve images.
48 | - Measure PSNR/SSIM/LPIPS.
49 | - Infer the Normalizing Flow latent space.
50 |
51 |
52 |
53 | # Testing: Apply the included pretrained models
54 |
55 | ```bash
56 | source myenv/bin/activate # Use the env you created using setup.sh
57 | cd code
58 | CUDA_VISIBLE_DEVICES=-1 python test.py ./confs/SRFlow_DF2K_4X.yml # Diverse Images 4X (Dataset Included)
59 | CUDA_VISIBLE_DEVICES=-1 python test.py ./confs/SRFlow_DF2K_8X.yml # Diverse Images 8X (Dataset Included)
60 | CUDA_VISIBLE_DEVICES=-1 python test.py ./confs/SRFlow_CelebA_8X.yml # Faces 8X
61 | ```
62 | For testing, we apply SRFlow to the full images on CPU.
63 |
64 |
65 |
66 | # Training: Reproduce or train on your Data
67 |
68 | The following commands train the Super-Resolution network using Normalizing Flow in PyTorch:
69 |
70 | ```bash
71 | source myenv/bin/activate # Use the env you created using setup.sh
72 | cd code
73 | python train.py -opt ./confs/SRFlow_DF2K_4X.yml # Diverse Images 4X (Dataset Included)
74 | python train.py -opt ./confs/SRFlow_DF2K_8X.yml # Diverse Images 8X (Dataset Included)
75 | python train.py -opt ./confs/SRFlow_CelebA_8X.yml # Faces 8X
76 | ```
77 |
78 | - To reduce GPU memory usage, reduce the batch size in the yml file.
79 | - CelebA does not allow us to host the dataset. A script will follow.
80 |
81 | ### How to prepare CelebA?
82 |
83 | **1. Get HD-CelebA-Cropper**
84 |
85 | ```git clone https://github.com/LynnHo/HD-CelebA-Cropper```
86 |
87 | **2. Download the dataset**
88 |
89 | `img_celeba.7z` and `annotations.zip` as described in the [Readme](https://github.com/LynnHo/HD-CelebA-Cropper).
90 |
91 | **3. Run the crop align**
92 |
93 | ```python3 align.py --img_dir ./data/data --crop_size_h 640 --crop_size_w 640 --order 3 --face_factor 0.6 --n_worker 8```
94 |
95 | **4. Downsample for GT**
96 |
97 | Use the [matlablike kernel](https://github.com/fatheral/matlab_imresize) to downscale to 160x160 for the GT images.
98 |
99 | **5. Downsample for LR**
100 |
101 | Downscale the GT using the Matlab kernel to the LR size (40x40 for 4X or 20x20 for 8X).
102 |
103 | **6. Train/Validation**
104 |
105 | For training and validation, we use the corresponding sets defined by CelebA (Train: 000001-162770, Validation: 162771-182637)
106 |
107 | **7. Pack to pickle for training**
108 |
109 | `cd code && python prepare_data.py /path/to/img_dir`
110 |
111 |
112 |
113 | # Dataset: How to train on your own data
114 |
115 | The following command creates the pickle files that you can use in the yaml config file:
116 |
117 | ```bash
118 | cd code
119 | python prepare_data.py /path/to/img_dir
120 | ```
121 |
122 | The precomputed DF2K dataset gets downloaded using `setup.sh`. You can reproduce it or prepare your own dataset.
123 |
124 |
125 |
126 | # Our paper explains
127 |
128 | - **How to train Conditional Normalizing Flow**
129 | We designed an architecture that achieves state-of-the-art super-resolution quality.
130 | - **How to train Normalizing Flow on a single GPU**
131 | We based our network on GLOW, which uses up to 40 GPUs to train for image generation. SRFlow only needs a single GPU for training conditional image generation.
132 | - **How to use Normalizing Flow for image manipulation**
133 | How to exploit the latent space of Normalizing Flow for controlled image manipulation.
134 | - **See many Visual Results**
135 | Compare GAN vs Normalizing Flow yourself. We've included a lot of visual results in our [[Paper]](https://bit.ly/2D9cN0L).
136 |
137 |
138 |
139 | # GAN vs Normalizing Flow - Blog
140 |
141 | [](https://bit.ly/2EdJzhy)
142 |
143 | - **Sampling:** SRFlow outputs many different images for a single input.
144 | - **Stable Training:** SRFlow has much fewer hyperparameters than GAN approaches, and we did not encounter training stability issues.
145 | - **Convergence:** While GANs cannot converge, conditional Normalizing Flows converge monotonically and stably.
146 | - **Higher Consistency:** When downsampling the super-resolved image, one obtains almost exactly the input.
147 |
148 | Get a quick introduction to Normalizing Flow in our [[Blog]](https://bit.ly/320bAkH).
149 |
150 |
151 |
152 |
153 | # Wanna help to improve the code?
154 |
155 | If you found a bug or improved the code, please do the following:
156 |
157 | - Fork this repo.
158 | - Push the changes to your repo.
159 | - Create a pull request.
160 |
161 |
162 |
163 | # Paper
164 | [[Paper] ECCV 2020 Spotlight](https://bit.ly/2XcmSks)
165 |
166 | ```bibtex
167 | @inproceedings{lugmayr2020srflow,
168 | title={SRFlow: Learning the Super-Resolution Space with Normalizing Flow},
169 | author={Lugmayr, Andreas and Danelljan, Martin and Van Gool, Luc and Timofte, Radu},
170 | booktitle={ECCV},
171 | year={2020}
172 | }
173 | ```
174 |
175 |
--------------------------------------------------------------------------------
/code/Measure.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import glob
16 | import os
17 | import time
18 | from collections import OrderedDict
19 |
20 | import numpy as np
21 | import torch
22 | import cv2
23 | import argparse
24 |
25 | from natsort import natsort
26 | from skimage.metrics import structural_similarity as ssim
27 | from skimage.metrics import peak_signal_noise_ratio as psnr
28 | import lpips
29 |
30 |
class Measure():
    """Compute PSNR, SSIM and LPIPS between two images.

    Images are numpy arrays with the channel axis last (HWC). The LPIPS path
    converts them with the module-level ``t`` helper, which asserts a 3-D
    ``uint8`` input.
    """

    def __init__(self, net='alex', use_gpu=False):
        """Build the LPIPS model.

        Args:
            net: backbone name forwarded to ``lpips.LPIPS`` (e.g. ``'alex'``).
            use_gpu: if True the model and inputs live on ``'cuda'``,
                otherwise on ``'cpu'``.
        """
        self.device = 'cuda' if use_gpu else 'cpu'
        self.model = lpips.LPIPS(net=net)
        self.model.to(self.device)

    def measure(self, imgA, imgB):
        """Return ``[psnr, ssim, lpips]`` for the image pair, each as a Python float."""
        return [float(f(imgA, imgB)) for f in [self.psnr, self.ssim, self.lpips]]

    def lpips(self, imgA, imgB, model=None):
        """LPIPS distance between two images (lower means more similar).

        ``model`` is unused; it is kept only so existing callers keep working.
        """
        tA = t(imgA).to(self.device)
        tB = t(imgB).to(self.device)
        # Pure inference: skip autograd graph construction to save memory/time.
        with torch.no_grad():
            dist01 = self.model.forward(tA, tB).item()
        return dist01

    def ssim(self, imgA, imgB):
        """SSIM computed per channel (last axis) and averaged over channels."""
        import inspect

        # scikit-image 0.19 replaced ``multichannel=True`` with
        # ``channel_axis=-1`` and newer releases no longer honour the old
        # keyword (an HWC image would then be treated as a 3-D volume).
        # Use whichever keyword the installed version supports.
        if 'channel_axis' in inspect.signature(ssim).parameters:
            score, _ = ssim(imgA, imgB, full=True, channel_axis=-1)
        else:
            score, _ = ssim(imgA, imgB, full=True, multichannel=True)
        return score

    def psnr(self, imgA, imgB):
        """Peak signal-to-noise ratio via skimage's ``peak_signal_noise_ratio`` (default args)."""
        psnr_val = psnr(imgA, imgB)
        return psnr_val
54 |
55 |
def t(img):
    """Turn an HWC ``uint8`` image into a float tensor of shape (1, C, H, W) in [-1, 1].

    The channel axis is moved to the front first. The result is then checked
    to be 3-D and ``uint8`` (AssertionError otherwise), a leading batch axis is
    added, and pixel values are mapped with ``x / 127.5 - 1``.
    """
    # HWC -> CHW; numpy raises here if ``img`` does not have exactly 3 axes.
    chw = np.transpose(img, [2, 0, 1])
    assert len(chw.shape) == 3
    assert chw.dtype == np.uint8

    # CHW -> 1xCxHxW (same view np.expand_dims(chw, axis=0) would produce).
    nchw = chw[np.newaxis, ...]
    assert len(nchw.shape) == 4

    return torch.Tensor(nchw) / 127.5 - 1
71 |
72 |
def fiFindByWildcard(wildcard):
    """List paths matching a glob pattern (``**`` recursion enabled), naturally sorted."""
    matches = glob.glob(wildcard, recursive=True)
    return natsort.natsorted(matches)
75 |
76 |
def imread(path):
    """Read an image file with OpenCV and return it with channels reordered BGR -> RGB.

    Raises:
        IOError: if OpenCV could not read ``path`` (missing file or
            unsupported/corrupt image).
    """
    img = cv2.imread(path)
    # cv2.imread reports failure by returning None rather than raising;
    # without this check the subscript below fails with an opaque TypeError.
    if img is None:
        raise IOError(f"Could not read image: {path}")
    # OpenCV stores channels as BGR; reorder to RGB. Fancy indexing returns a
    # contiguous copy (no negative strides), which keeps torch conversion safe.
    return img[:, :, [2, 1, 0]]
79 |
80 |
def format_result(psnr, ssim, lpips):
    """Render metrics as ``'PSNR, SSIM, LPIPS'`` with 2, 3 and 3 decimal places."""
    parts = (format(psnr, '0.2f'), format(ssim, '0.3f'), format(lpips, '0.3f'))
    return ', '.join(parts)
83 |
def measure_dirs(dirA, dirB, use_gpu, verbose=False, ext=None):
    """Compare the images of two directories pairwise with PSNR / SSIM / LPIPS.

    Files matching ``*.<ext>`` are naturally sorted in each directory and
    paired by position.

    Args:
        dirA, dirB: directories holding the two image sets.
        use_gpu: run LPIPS on CUDA.
        verbose: print per-pair and final results.
        ext: file extension to match (without the dot). When None, the
            module-level string ``type`` set by the command-line entry point
            is used if present, otherwise 'png'.

    Returns:
        ``(psnr, ssim, lpips)``, the means over all compared pairs.
        Each is NaN when no pair was found.
    """
    import warnings

    if verbose:
        vprint = print
    else:
        vprint = lambda *args, **kwargs: None

    if ext is None:
        # The CLI stores the extension in a module global named `type`. When
        # this module is imported instead, no such global exists.
        cli_ext = globals().get('type')
        ext = cli_ext if isinstance(cli_ext, str) else 'png'

    t_init = time.time()

    paths_A = fiFindByWildcard(os.path.join(dirA, f'*.{ext}'))
    paths_B = fiFindByWildcard(os.path.join(dirB, f'*.{ext}'))

    vprint("Comparing: ")
    vprint(dirA)
    vprint(dirB)

    if len(paths_A) != len(paths_B):
        # zip() below silently stops at the shorter list; make that visible.
        warnings.warn(
            f"{len(paths_A)} '*.{ext}' files in {dirA} but {len(paths_B)} in {dirB}; "
            f"only the first {min(len(paths_A), len(paths_B))} pairs (sorted order) are compared.")

    measure = Measure(use_gpu=use_gpu)

    results = []
    for pathA, pathB in zip(paths_A, paths_B):
        result = OrderedDict()

        t_start = time.time()
        result['psnr'], result['ssim'], result['lpips'] = measure.measure(imread(pathA), imread(pathB))
        elapsed = time.time() - t_start
        vprint(f"{os.path.basename(pathA)}, {os.path.basename(pathB)}, {format_result(**result)}, {elapsed:0.1f}")

        results.append(result)

    if results:
        mean_psnr = float(np.mean([r['psnr'] for r in results]))
        mean_ssim = float(np.mean([r['ssim'] for r in results]))
        mean_lpips = float(np.mean([r['lpips'] for r in results]))
    else:
        # np.mean([]) would emit a RuntimeWarning and return NaN; be explicit.
        mean_psnr = mean_ssim = mean_lpips = float('nan')

    vprint(f"Final Result: {format_result(mean_psnr, mean_ssim, mean_lpips)}, {time.time() - t_init:0.1f}s")
    return mean_psnr, mean_ssim, mean_lpips
118 |
119 |
if __name__ == "__main__":
    cli = argparse.ArgumentParser()
    cli.add_argument('-dirA', default='', type=str)
    cli.add_argument('-dirB', default='', type=str)
    cli.add_argument('-type', default='png')
    cli.add_argument('--use_gpu', action='store_true', default=False)
    args = cli.parse_args()

    # measure_dirs() reads the file extension through this module-level name,
    # so it has to be bound at module scope (it shadows the builtin `type`).
    type = args.type
    dirA, dirB, use_gpu = args.dirA, args.dirB, args.use_gpu

    # Nothing to compare unless both directories were given.
    if dirA and dirB:
        measure_dirs(dirA, dirB, use_gpu=use_gpu, verbose=True)
135 |
--------------------------------------------------------------------------------
/code/confs/RRDB_CelebA_8X.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | #### general settings
18 | name: train
19 | use_tb_logger: true
20 | model: SR
21 | distortion: sr
22 | scale: 8
23 | #gpu_ids: [ 0 ]
24 |
25 | #### datasets
26 | datasets:
27 | train:
28 | name: CelebA_160_tr
29 | mode: LRHR_PKL
30 | dataroot_GT: ../datasets/celebA-train-gt_1pct.pklv4
31 | dataroot_LQ: ../datasets/celebA-train-x8_1pct.pklv4
32 |
33 | use_shuffle: true
34 | n_workers: 0 # per GPU
35 | batch_size: 16
36 | GT_size: 160
37 | use_flip: true
38 | use_rot: true
39 | color: RGB
40 | val:
41 | name: CelebA_160_va
42 | mode: LRHR_PKL
43 | dataroot_GT: ../datasets/celebA-valid-gt_1pct.pklv4
44 | dataroot_LQ: ../datasets/celebA-valid-x8_1pct.pklv4
45 | n_max: 10
46 |
47 | #### network structures
48 | network_G:
49 | which_model_G: RRDBNet
50 | in_nc: 3
51 | out_nc: 3
52 | nf: 64
53 | nb: 23
54 |
55 | #### path
56 | path:
57 | pretrain_model_G: ~
58 | strict_load: true
59 | resume_state: auto
60 |
61 | #### training settings: learning rate scheme, loss
62 | train:
63 | lr_G: !!float 2e-4
64 | lr_scheme: CosineAnnealingLR_Restart
65 | beta1: 0.9
66 | beta2: 0.99
67 | niter: 200000
68 | warmup_iter: -1 # no warm up
69 | T_period: [ 50000, 50000, 50000, 50000 ]
70 | restarts: [ 50000, 100000, 150000 ]
71 | restart_weights: [ 1, 1, 1 ]
72 | eta_min: !!float 1e-7
73 |
74 | pixel_criterion: l1
75 | pixel_weight: 1.0
76 |
77 | manual_seed: 10
78 | val_freq: !!float 5e3
79 |
80 | #### logger
81 | logger:
82 | print_freq: 100
83 | save_checkpoint_freq: !!float 1e3
84 |
--------------------------------------------------------------------------------
/code/confs/RRDB_DF2K_4X.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | #### general settings
18 | name: train
19 | use_tb_logger: true
20 | model: SR
21 | distortion: sr
22 | scale: 4
23 | gpu_ids: [ 0 ]
24 |
25 | #### datasets
26 | datasets:
27 | train:
28 | name: CelebA_160_tr
29 | mode: LRHR_PKL
30 | dataroot_GT: ../datasets/DF2K-train-gt_1pct.pklv4
31 | dataroot_LQ: ../datasets/DF2K-train-x4_1pct.pklv4
32 | quant: 32
33 |
34 | use_shuffle: true
35 | n_workers: 3 # per GPU
36 | batch_size: 16
37 | GT_size: 160
38 | use_flip: true
39 | color: RGB
40 | val:
41 | name: CelebA_160_va
42 | mode: LRHR_PKL
43 | dataroot_GT: ../datasets/DF2K-valid-gt_1pct.pklv4
44 | dataroot_LQ: ../datasets/DF2K-valid-x4_1pct.pklv4
45 | quant: 32
46 | n_max: 20
47 |
48 | #### network structures
49 | network_G:
50 | which_model_G: RRDBNet
51 | use_orig: True
52 | in_nc: 3
53 | out_nc: 3
54 | nf: 64
55 | nb: 23
56 |
57 | #### path
58 | path:
59 | pretrain_model_G: ~
60 | strict_load: true
61 | resume_state: auto
62 |
63 | #### training settings: learning rate scheme, loss
64 | train:
65 | lr_G: !!float 2e-4
66 | lr_scheme: CosineAnnealingLR_Restart
67 | beta1: 0.9
68 | beta2: 0.99
69 | niter: 1000000
70 | warmup_iter: -1 # no warm up
71 | T_period: [ 50000, 50000, 50000, 50000 ]
72 | restarts: [ 50000, 100000, 150000 ]
73 | restart_weights: [ 1, 1, 1 ]
74 | eta_min: !!float 1e-7
75 |
76 | pixel_criterion: l1
77 | pixel_weight: 1.0
78 |
79 | manual_seed: 10
80 | val_freq: !!float 5e3
81 |
82 | #### logger
83 | logger:
84 | print_freq: 100
85 | save_checkpoint_freq: !!float 1e3
86 |
--------------------------------------------------------------------------------
/code/confs/RRDB_DF2K_8X.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | #### general settings
18 | name: train
19 | use_tb_logger: true
20 | model: SR
21 | distortion: sr
22 | scale: 8
23 | gpu_ids: [ 0 ]
24 |
25 | #### datasets
26 | datasets:
27 | train:
28 | name: CelebA_160_tr
29 | mode: LRHR_PKL
30 | dataroot_GT: ../datasets/DF2K-train-gt_1pct.pklv4
31 | dataroot_LQ: ../datasets/DF2K-train-x8_1pct.pklv4
32 | quant: 32
33 |
34 | use_shuffle: true
35 | n_workers: 3 # per GPU
36 | batch_size: 16
37 | GT_size: 160
38 | use_flip: true
39 | color: RGB
40 |
41 | val:
42 | name: CelebA_160_va
43 | mode: LRHR_PKL
44 | dataroot_GT: ../datasets/DF2K-valid-gt_1pct.pklv4
45 | dataroot_LQ: ../datasets/DF2K-valid-x8_1pct.pklv4
46 | quant: 32
47 | n_max: 20
48 |
49 | #### network structures
50 | network_G:
51 | which_model_G: RRDBNet
52 | in_nc: 3
53 | out_nc: 3
54 | nf: 64
55 | nb: 23
56 |
57 | #### path
58 | path:
59 | pretrain_model_G: ~
60 | strict_load: true
61 | resume_state: auto
62 |
63 | #### training settings: learning rate scheme, loss
64 | train:
65 | lr_G: !!float 2e-4
66 | lr_scheme: CosineAnnealingLR_Restart
67 | beta1: 0.9
68 | beta2: 0.99
69 | niter: 200000
70 | warmup_iter: -1 # no warm up
71 | T_period: [ 50000, 50000, 50000, 50000 ]
72 | restarts: [ 50000, 100000, 150000 ]
73 | restart_weights: [ 1, 1, 1 ]
74 | eta_min: !!float 1e-7
75 |
76 | pixel_criterion: l1
77 | pixel_weight: 1.0
78 |
79 | manual_seed: 10
80 | val_freq: !!float 5e3
81 |
82 | #### logger
83 | logger:
84 | print_freq: 100
85 | save_checkpoint_freq: !!float 1e3
86 |
--------------------------------------------------------------------------------
/code/confs/SRFlow_CelebA_8X.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | #### general settings
18 | name: train
19 | use_tb_logger: true
20 | model: SRFlow
21 | distortion: sr
22 | scale: 8
23 | gpu_ids: [ 0 ]
24 |
25 | #### datasets
26 | datasets:
27 | train:
28 | name: CelebA_160_tr
29 | mode: LRHR_PKL
30 | dataroot_GT: ../datasets/celebA-train-gt.pklv4
31 | dataroot_LQ: ../datasets/celebA-train-x8.pklv4
32 | quant: 32
33 |
34 | use_shuffle: true
35 | n_workers: 3 # per GPU
36 | batch_size: 16
37 | GT_size: 160
38 | use_flip: true
39 | color: RGB
40 | val:
41 | name: CelebA_160_va
42 | mode: LRHR_PKL
43 | dataroot_GT: ../datasets/celebA-train-gt.pklv4
44 | dataroot_LQ: ../datasets/celebA-train-x8.pklv4
45 | quant: 32
46 | n_max: 20
47 |
48 | #### Test Settings
49 | dataroot_GT: ../datasets/celebA-validation-gt
50 | dataroot_LR: ../datasets/celebA-validation-x8
51 | model_path: ../pretrained_models/SRFlow_CelebA_8X.pth
52 | heat: 0.9 # This is the standard deviation of the latent vectors
53 |
54 | #### network structures
55 | network_G:
56 | which_model_G: SRFlowNet
57 | in_nc: 3
58 | out_nc: 3
59 | nf: 64
60 | nb: 8
61 | upscale: 8
62 | train_RRDB: false
63 | train_RRDB_delay: 0.5
64 |
65 | flow:
66 | K: 16
67 | L: 4
68 | noInitialInj: true
69 | coupling: CondAffineSeparatedAndCond
70 | additionalFlowNoAffine: 2
71 | split:
72 | enable: true
73 | fea_up0: true
74 | stackRRDB:
75 | blocks: [ 1, 3, 5, 7 ]
76 | concat: true
77 |
78 | #### path
79 | path:
80 | pretrain_model_G: ../pretrained_models/RRDB_CelebA_8X.pth
81 | strict_load: true
82 | resume_state: auto
83 |
84 | #### training settings: learning rate scheme, loss
85 | train:
86 | manual_seed: 10
87 | lr_G: !!float 5e-4
88 | weight_decay_G: 0
89 | beta1: 0.9
90 | beta2: 0.99
91 | lr_scheme: MultiStepLR
92 | warmup_iter: -1 # no warm up
93 | lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ]
94 | lr_gamma: 0.5
95 |
96 | niter: 200000
97 | val_freq: 40000
98 |
99 | #### validation settings
100 | val:
101 | heats: [ 0.0, 0.5, 0.75, 1.0 ]
102 | n_sample: 3
103 |
104 | #### logger
105 | logger:
106 | print_freq: 100
107 | save_checkpoint_freq: !!float 1e3
108 |
--------------------------------------------------------------------------------
/code/confs/SRFlow_DF2K_4X.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | #### general settings
18 | name: train
19 | use_tb_logger: true
20 | model: SRFlow
21 | distortion: sr
22 | scale: 4
23 | gpu_ids: [ 0 ]
24 |
25 | #### datasets
26 | datasets:
27 | train:
28 | name: CelebA_160_tr
29 | mode: LRHR_PKL
30 | dataroot_GT: ../datasets/DF2K-tr.pklv4
31 | dataroot_LQ: ../datasets/DF2K-tr_X4.pklv4
32 | quant: 32
33 |
34 | use_shuffle: true
35 | n_workers: 3 # per GPU
36 | batch_size: 12
37 | GT_size: 160
38 | use_flip: true
39 | color: RGB
40 | val:
41 | name: CelebA_160_va
42 | mode: LRHR_PKL
43 | dataroot_GT: ../datasets/DIV2K-va.pklv4
44 | dataroot_LQ: ../datasets/DIV2K-va_X4.pklv4
45 | quant: 32
46 | n_max: 20
47 |
48 | #### Test Settings
49 | dataroot_GT: ../datasets/div2k-validation-modcrop8-gt
50 | dataroot_LR: ../datasets/div2k-validation-modcrop8-x4
51 | model_path: ../pretrained_models/SRFlow_DF2K_4X.pth
52 | heat: 0.9 # This is the standard deviation of the latent vectors
53 |
54 | #### network structures
55 | network_G:
56 | which_model_G: SRFlowNet
57 | in_nc: 3
58 | out_nc: 3
59 | nf: 64
60 | nb: 23
61 | upscale: 4
62 | train_RRDB: false
63 | train_RRDB_delay: 0.5
64 |
65 | flow:
66 | K: 16
67 | L: 3
68 | noInitialInj: true
69 | coupling: CondAffineSeparatedAndCond
70 | additionalFlowNoAffine: 2
71 | split:
72 | enable: true
73 | fea_up0: true
74 | stackRRDB:
75 | blocks: [ 1, 8, 15, 22 ]
76 | concat: true
77 |
78 | #### path
79 | path:
80 | pretrain_model_G: ../pretrained_models/RRDB_DF2K_4X.pth
81 | strict_load: true
82 | resume_state: auto
83 |
84 | #### training settings: learning rate scheme, loss
85 | train:
86 | manual_seed: 10
87 | lr_G: !!float 2.5e-4
88 | weight_decay_G: 0
89 | beta1: 0.9
90 | beta2: 0.99
91 | lr_scheme: MultiStepLR
92 | warmup_iter: -1 # no warm up
93 | lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ]
94 | lr_gamma: 0.5
95 |
96 | niter: 200000
97 | val_freq: 40000
98 |
99 | #### validation settings
100 | val:
101 | heats: [ 0.0, 0.5, 0.75, 1.0 ]
102 | n_sample: 3
103 |
104 | #### logger
105 | logger:
106 | print_freq: 100
107 | save_checkpoint_freq: !!float 1e3
108 |
--------------------------------------------------------------------------------
/code/confs/SRFlow_DF2K_8X.yml:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | #### general settings
18 | name: train
19 | use_tb_logger: true
20 | model: SRFlow
21 | distortion: sr
22 | scale: 8
23 | gpu_ids: [ 0 ]
24 |
25 | #### datasets
26 | datasets:
27 | train:
28 | name: CelebA_160_tr
29 | mode: LRHR_PKL
30 | dataroot_GT: ../datasets/DF2K-tr.pklv4
31 | dataroot_LQ: ../datasets/DF2K-tr_X8.pklv4
32 | quant: 32
33 |
34 | use_shuffle: true
35 | n_workers: 3 # per GPU
36 | batch_size: 16
37 | GT_size: 160
38 | use_flip: true
39 | color: RGB
40 |
41 | val:
42 | name: CelebA_160_va
43 | mode: LRHR_PKL
44 | dataroot_GT: ../datasets/DIV2K-va.pklv4
45 | dataroot_LQ: ../datasets/DIV2K-va_X8.pklv4
46 | quant: 32
47 | n_max: 20
48 |
49 | #### Test Settings
50 | dataroot_GT: ../datasets/div2k-validation-modcrop8-gt
51 | dataroot_LR: ../datasets/div2k-validation-modcrop8-x8
52 | model_path: ../pretrained_models/SRFlow_DF2K_8X.pth
53 | heat: 0.9 # This is the standard deviation of the latent vectors
54 |
55 | #### network structures
56 | network_G:
57 | which_model_G: SRFlowNet
58 | in_nc: 3
59 | out_nc: 3
60 | nf: 64
61 | nb: 23
62 | upscale: 8
63 | train_RRDB: false
64 | train_RRDB_delay: 0.5
65 |
66 | flow:
67 | K: 16
68 | L: 4
69 | noInitialInj: true
70 | coupling: CondAffineSeparatedAndCond
71 | additionalFlowNoAffine: 2
72 | split:
73 | enable: true
74 | fea_up0: true
75 | stackRRDB:
76 | blocks: [ 1, 3, 5, 7 ]
77 | concat: true
78 |
79 | #### path
80 | path:
81 | pretrain_model_G: ../pretrained_models/RRDB_DF2K_8X.pth
82 | strict_load: true
83 | resume_state: auto
84 |
85 | #### training settings: learning rate scheme, loss
86 | train:
87 | manual_seed: 10
88 | lr_G: !!float 5e-4
89 | weight_decay_G: 0
90 | beta1: 0.9
91 | beta2: 0.99
92 | lr_scheme: MultiStepLR
93 | warmup_iter: -1 # no warm up
94 | lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ]
95 | lr_gamma: 0.5
96 |
97 | niter: 200000
98 | val_freq: 40000
99 |
100 | #### validation settings
101 | val:
102 | heats: [ 0.0, 0.5, 0.75, 1.0 ]
103 | n_sample: 3
104 |
105 | test:
106 | heats: [ 0.0, 0.7, 0.8, 0.9 ]
107 |
108 | #### logger
109 | logger:
110 | # Debug print_freq: 100
111 | print_freq: 100
112 | save_checkpoint_freq: !!float 1e3
113 |
--------------------------------------------------------------------------------
/code/data/LRHR_PKL_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | import os
18 | import subprocess
19 | import torch.utils.data as data
20 | import numpy as np
21 | import time
22 | import torch
23 |
24 | import pickle
25 |
26 |
class LRHR_PKLDataset(data.Dataset):
    """Paired low-/high-resolution images loaded from two pickle files.

    Each pickle holds a sequence of images whose last axis is moved to the
    front on load (``np.transpose(image, [2, 0, 1])``), giving C x H x W
    arrays. Item ``i`` of the LR file is paired with item ``i`` of the HR file.

    ``opt`` keys read here:
        dataroot_GT, dataroot_LQ: HR / LR pickle paths (required).
        GT_size: HR crop size used when ``use_crop`` is set.
        use_crop, use_flip, use_rot: enable random crop / flip / rotation.
        center_crop_hr_size: if set, center-crop HR to this size (LR to size // scale).
        n_max: keep at most this many images (default 1e8).
    """

    def __init__(self, opt):
        super(LRHR_PKLDataset, self).__init__()
        self.opt = opt
        self.crop_size = opt.get("GT_size", None)
        self.scale = None  # LR->HR factor, inferred from the first item fetched
        self.random_scale_list = [1]

        hr_file_path = opt["dataroot_GT"]
        lr_file_path = opt["dataroot_LQ"]

        # Kept as public attributes for compatibility; not read by this class.
        self.gpu = True
        self.augment = True

        self.use_flip = opt.get("use_flip", False)
        self.use_rot = opt.get("use_rot", False)
        self.use_crop = opt.get("use_crop", False)
        self.center_crop_hr_size = opt.get("center_crop_hr_size", None)

        n_max = opt.get("n_max", int(1e8))

        t = time.time()
        self.lr_images = self.load_pkls(lr_file_path, n_max)
        self.hr_images = self.load_pkls(hr_file_path, n_max)

        # Pixel range is sampled from the first 20 images only (for logging).
        min_val_hr, max_val_hr = self._value_range(self.hr_images[:20])
        min_val_lr, max_val_lr = self._value_range(self.lr_images[:20])

        t = time.time() - t
        print("Loaded {} HR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
              format(len(self.hr_images), min_val_hr, max_val_hr, t, hr_file_path))
        print("Loaded {} LR images with [{:.2f}, {:.2f}] in {:.2f}s from {}".
              format(len(self.lr_images), min_val_lr, max_val_lr, t, lr_file_path))

        # Pixel statistics refreshed in __getitem__, reported and cleared by print_and_reset.
        self.measures = None

    @staticmethod
    def _value_range(images):
        """Return (min, max) pixel value over ``images``, or (nan, nan) when empty."""
        if len(images) == 0:
            return float('nan'), float('nan')
        return (float(np.min([i.min() for i in images])),
                float(np.max([i.max() for i in images])))

    def load_pkls(self, path, n_max):
        """Load a pickled image list, keep the first ``n_max`` and move the last axis first."""
        assert os.path.isfile(path), path
        images = []
        # SECURITY: pickle.load can execute arbitrary code -- only load trusted files.
        with open(path, "rb") as f:
            images += pickle.load(f)
        assert len(images) > 0, path
        images = images[:n_max]
        return [np.transpose(image, [2, 0, 1]) for image in images]

    def __len__(self):
        return len(self.hr_images)

    def __getitem__(self, item):
        """Return ``{'LQ', 'GT', 'LQ_path', 'GT_path'}`` for pair ``item``.

        LQ / GT are float tensors with pixel values divided by 255. The paths
        are just ``str(item)``, because the images come from a pickle.
        """
        hr = self.hr_images[item]
        lr = self.lr_images[item]

        if self.scale is None:
            self.scale = hr.shape[1] // lr.shape[1]
        assert hr.shape[1] == self.scale * lr.shape[1], ('non-fractional ratio', lr.shape, hr.shape)

        if self.use_crop:
            hr, lr = random_crop(hr, lr, self.crop_size, self.scale, self.use_crop)

        if self.center_crop_hr_size:
            hr, lr = center_crop(hr, self.center_crop_hr_size), center_crop(lr, self.center_crop_hr_size // self.scale)

        if self.use_flip:
            hr, lr = random_flip(hr, lr)

        if self.use_rot:
            hr, lr = random_rotation(hr, lr)

        hr = hr / 255.0
        lr = lr / 255.0

        # Refresh the statistics on the first call, then for roughly 5% of items.
        if self.measures is None or np.random.random() < 0.05:
            if self.measures is None:
                self.measures = {}
            self.measures['hr_means'] = np.mean(hr)
            self.measures['hr_stds'] = np.std(hr)
            self.measures['lr_means'] = np.mean(lr)
            self.measures['lr_stds'] = np.std(lr)

        hr = torch.Tensor(hr)
        lr = torch.Tensor(lr)

        return {'LQ': lr, 'GT': hr, 'LQ_path': str(item), 'GT_path': str(item)}

    def print_and_reset(self, tag):
        """Print the collected statistics prefixed with ``tag`` and clear them.

        Safe to call before any item has been fetched (prints an empty list).
        """
        m = self.measures or {}
        kvs = []
        for k in sorted(m.keys()):
            kvs.append("{}={:.2f}".format(k, m[k]))
        print("[KPI] " + tag + ": " + ", ".join(kvs))
        self.measures = None
130 |
131 |
def random_flip(img, seg):
    """With probability 1/2, flip both arrays along axis 2 (the last axis of C x H x W).

    The same random decision applies to both arrays. When no flip happens the
    inputs are returned unchanged (same objects); flipped results are copies.
    """
    keep = np.random.choice([True, False])
    if keep:
        return img, seg
    return np.flip(img, 2).copy(), np.flip(seg, 2).copy()
137 |
138 |
def random_rotation(img, seg):
    """Rotate both arrays in the (1, 2) plane by the same random number of quarter turns.

    The number of quarter turns is drawn from {0, 1, 3}; a half turn is never
    chosen. Results are copies.
    """
    quarter_turns = np.random.choice([0, 1, 3])

    def rotate(arr):
        return np.rot90(arr, quarter_turns, axes=(1, 2)).copy()

    return rotate(img), rotate(seg)
144 |
145 |
def random_crop(hr, lr, size_hr, scale, random):
    """Cut aligned random patches from C x H x W HR and LR arrays.

    The LR patch is ``size_hr // scale`` square. The HR patch is ``size_hr``
    square and starts at the LR offset times ``scale``. An axis no larger
    than the LR patch starts at 0 without drawing a random number. The
    ``random`` argument is accepted but not used.

    Returns:
        (hr_patch, lr_patch)
    """
    size_lr = size_hr // scale

    def pick_start(extent):
        if extent > size_lr:
            return np.random.randint(low=0, high=(extent - size_lr) + 1)
        return 0

    # Row offset is drawn before column offset.
    row_lr = pick_start(lr.shape[1])
    col_lr = pick_start(lr.shape[2])

    lr_patch = lr[:, row_lr:row_lr + size_lr, col_lr:col_lr + size_lr]

    row_hr, col_hr = row_lr * scale, col_lr * scale
    hr_patch = hr[:, row_hr:row_hr + size_hr, col_hr:col_hr + size_hr]

    return hr_patch, lr_patch
164 |
165 |
def center_crop(img, size):
    """Return the central ``size`` x ``size`` window of a square C x H x W array.

    Raises:
        AssertionError: if the spatial dims differ, if ``size`` exceeds them,
            or if the leftover border cannot be split evenly on both sides.
    """
    assert img.shape[1] == img.shape[2], img.shape
    border_double = img.shape[1] - size
    assert border_double >= 0 and border_double % 2 == 0, (img.shape, size)
    border = border_double // 2
    # Explicit end index: `border:-border` is an empty slice when border == 0
    # (i.e. when size equals the image size).
    return img[:, border:border + size, border:border + size]
172 |
173 |
def center_crop_tensor(img, size):
    """Return the central ``size`` x ``size`` window of a square N x C x H x W tensor/array.

    Raises:
        AssertionError: if the spatial dims differ, if ``size`` exceeds them,
            or if the leftover border cannot be split evenly on both sides.
    """
    assert img.shape[2] == img.shape[3], img.shape
    border_double = img.shape[2] - size
    assert border_double >= 0 and border_double % 2 == 0, (img.shape, size)
    border = border_double // 2
    # Explicit end index: `border:-border` is empty when border == 0.
    return img[:, :, border:border + size, border:border + size]
180 |
--------------------------------------------------------------------------------
/code/data/__init__.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | '''create dataset and dataloader'''
18 | import logging
19 | import torch
20 | import torch.utils.data
21 |
22 |
def create_dataloader(dataset, dataset_opt, opt=None, sampler=None):
    """Wrap ``dataset`` in a ``torch.utils.data.DataLoader``.

    Train phase (``dataset_opt['phase'] == 'train'``):
        - batch size is ``dataset_opt['batch_size']``;
        - ``dataset_opt['n_workers']`` workers per entry of ``opt['gpu_ids']``
          (0 workers when ``opt`` is None or lists no GPU);
        - the last incomplete batch is dropped;
        - order is shuffled unless a ``sampler`` is given, in which case the
          sampler decides the order.

    Any other phase (default 'test'): batch size 1, original order, one worker.
    """
    phase = dataset_opt.get('phase', 'test')
    if phase == 'train':
        gpu_ids = (opt or {}).get('gpu_ids', None) or []
        num_workers = dataset_opt['n_workers'] * len(gpu_ids)
        batch_size = dataset_opt['batch_size']
        # DataLoader raises ValueError for shuffle=True combined with a sampler.
        shuffle = sampler is None
        return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle,
                                           num_workers=num_workers, sampler=sampler, drop_last=True,
                                           pin_memory=False)
    else:
        return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1,
                                           pin_memory=True)
37 |
38 |
def create_dataset(dataset_opt):
    """Build the dataset selected by ``dataset_opt['mode']`` and log its creation.

    Only the 'LRHR_PKL' mode is supported. The options dict is printed first.

    Raises:
        NotImplementedError: for any other mode.
    """
    print(dataset_opt)
    mode = dataset_opt['mode']
    if mode != 'LRHR_PKL':
        raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode))

    # Imported lazily so the module loads without the dataset implementation.
    from data.LRHR_PKL_dataset import LRHR_PKLDataset
    dataset = LRHR_PKLDataset(dataset_opt)

    message = 'Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__,
                                                         dataset_opt['name'])
    logging.getLogger('base').info(message)
    return dataset
52 |
--------------------------------------------------------------------------------
/code/imresize.py:
--------------------------------------------------------------------------------
1 | # https://github.com/fatheral/matlab_imresize
2 | #
3 | # MIT License
4 | #
5 | # Copyright (c) 2020 Alex
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to the following conditions:
13 | #
14 | # The above copyright notice and this permission notice shall be included in all
15 | # copies or substantial portions of the Software.
16 | #
17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
23 | # SOFTWARE.
24 |
25 |
26 | from __future__ import print_function
27 | import numpy as np
28 | from math import ceil, floor
29 |
30 |
def deriveSizeFromScale(img_shape, scale):
    """Output [rows, cols] when the first two dims of ``img_shape`` are scaled by ``scale``.

    Each size is rounded up to the next integer.
    """
    return [int(ceil(scale[axis] * img_shape[axis])) for axis in range(2)]
36 |
37 |
def deriveScaleFromSize(img_shape_in, img_shape_out):
    """Per-axis float ratio output/input for the first two dimensions."""
    return [1.0 * img_shape_out[axis] / img_shape_in[axis] for axis in range(2)]
43 |
44 |
def triangle(x):
    """Linear (tent) interpolation kernel: 1 - |x| on [-1, 1], 0 elsewhere.

    Evaluated element-wise in float64.
    """
    arr = np.array(x).astype(np.float64)
    on_left = (arr >= -1) & (arr < 0)
    on_right = (arr >= 0) & (arr <= 1)
    # Boolean masks act as 0/1 weights selecting the matching linear piece.
    return (arr + 1) * on_left + (1 - arr) * on_right
51 |
52 |
def cubic(x):
    """Piecewise-cubic interpolation kernel with support [-2, 2], evaluated element-wise.

    The coefficients match the Keys cubic convolution kernel with a = -0.5:
    1.5|x|^3 - 2.5|x|^2 + 1 for |x| <= 1, and
    -0.5|x|^3 + 2.5|x|^2 - 4|x| + 2 for 1 < |x| <= 2.
    """
    ax = np.absolute(np.array(x).astype(np.float64))
    ax2 = ax * ax
    ax3 = ax2 * ax
    inner = 1.5 * ax3 - 2.5 * ax2 + 1
    outer = -0.5 * ax3 + 2.5 * ax2 - 4 * ax + 2
    # Boolean masks select the piece that applies to each element.
    return inner * (ax <= 1) + outer * ((1 < ax) & (ax <= 2))
61 |
62 |
def contributions(in_length, out_length, scale, kernel, k_width):
    """Compute resampling weights and source indices for one image axis.

    Args:
        in_length: number of samples along the input axis.
        out_length: number of samples along the output axis.
        scale: output/input size ratio for this axis (< 1 means downscaling).
        kernel: interpolation kernel function (e.g. ``cubic`` or ``triangle``).
        k_width: kernel support width at scale 1.

    Returns:
        (weights, indices): for each output position, the row-normalised
        kernel weights and the 0-based input indices they apply to. Both are
        3-D, (out_length, 1, n_kept), because they are indexed below with
        the tuple returned by ``np.nonzero``.
    """
    if scale < 1:
        # Downscaling: stretch the kernel by 1/scale and scale its amplitude
        # by `scale`, widening the support accordingly.
        h = lambda x: scale * kernel(scale * x)
        kernel_width = 1.0 * k_width / scale
    else:
        h = kernel
        kernel_width = k_width
    # 1-based output coordinates mapped to 1-based input coordinates.
    x = np.arange(1, out_length + 1).astype(np.float64)
    u = x / scale + 0.5 * (1 - 1 / scale)
    # Leftmost input position that can fall inside the kernel support.
    left = np.floor(u - kernel_width / 2)
    # Number of candidate input samples considered per output sample.
    P = int(ceil(kernel_width)) + 2
    ind = np.expand_dims(left, axis=1) + np.arange(P) - 1  # -1 because indexing from 0
    indices = ind.astype(np.int32)
    weights = h(np.expand_dims(u, axis=1) - indices - 1)  # -1 because indexing from 0
    # Normalise each row so the weights of one output sample sum to 1.
    weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1))
    # aux = [0, 1, ..., n-1, n-1, ..., 0]; indexing it modulo 2n mirrors
    # out-of-range indices back into [0, in_length) (symmetric padding).
    aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32)
    indices = aux[np.mod(indices, aux.size)]
    # Drop candidate columns whose weight is zero for every output sample.
    ind2store = np.nonzero(np.any(weights, axis=0))
    weights = weights[:, ind2store]
    indices = indices[:, ind2store]
    return weights, indices
84 |
85 |
def imresizemex(inimg, weights, indices, dim):
    """Resize ``inimg`` along axis ``dim`` using explicit Python loops.

    Loop-based counterpart of ``imresizevec``. Each output sample along
    ``dim`` is the weighted sum of the input samples selected by ``indices``.

    Args:
        inimg: image to resize. NOTE(review): the squeeze / ``w.T`` arithmetic
            below appears to require a 3-D (H, W, C) array -- confirm before
            passing 2-D images.
        weights, indices: per-output-sample weights and input indices as
            produced by ``contributions`` (3-D, with a singleton middle axis).
        dim: 0 resizes rows, 1 resizes columns. Neither branch runs for other
            values, so the result stays all zeros.

    Returns:
        The resized image. uint8 input is clipped to [0, 255], rounded and
        returned as uint8. Any other dtype is returned as float64.
    """
    in_shape = inimg.shape
    w_shape = weights.shape
    out_shape = list(in_shape)
    out_shape[dim] = w_shape[0]
    outimg = np.zeros(out_shape)
    if dim == 0:
        for i_img in range(in_shape[1]):
            for i_w in range(w_shape[0]):
                w = weights[i_w, :]
                ind = indices[i_w, :]
                im_slice = inimg[ind, i_img].astype(np.float64)
                # Drop the singleton axis, weight each contributing sample, sum them.
                outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
    elif dim == 1:
        for i_img in range(in_shape[0]):
            for i_w in range(w_shape[0]):
                w = weights[i_w, :]
                ind = indices[i_w, :]
                im_slice = inimg[i_img, ind].astype(np.float64)
                outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0)
    if inimg.dtype == np.uint8:
        outimg = np.clip(outimg, 0, 255)
        return np.around(outimg).astype(np.uint8)
    else:
        return outimg
111 |
112 |
def imresizevec(inimg, weights, indices, dim):
    """Vectorised resampling of ``inimg`` along axis ``dim`` (0 or 1).

    Gathers all taps with fancy indexing and reduces them in one weighted sum.
    uint8 input is clipped to [0, 255], rounded and returned as uint8; any other
    dtype is returned as float64.
    """
    shp = weights.shape
    if dim == 0:
        gathered = inimg[indices].squeeze(axis=1).astype(np.float64)
        outimg = np.sum(weights.reshape((shp[0], shp[2], 1, 1)) * gathered, axis=1)
    elif dim == 1:
        gathered = inimg[:, indices].squeeze(axis=2).astype(np.float64)
        outimg = np.sum(weights.reshape((1, shp[0], shp[2], 1)) * gathered, axis=2)
    if inimg.dtype == np.uint8:
        return np.around(np.clip(outimg, 0, 255)).astype(np.uint8)
    return outimg
126 |
127 |
def resizeAlongDim(A, dim, weights, indices, mode="vec"):
    """Resample ``A`` along axis ``dim``.

    ``mode == "org"`` selects the loop implementation (imresizemex); any other value
    selects the vectorised one (imresizevec).
    """
    resample = imresizemex if mode == "org" else imresizevec
    return resample(A, weights, indices, dim)
134 |
135 |
def imresize(I, scalar_scale=None, method='bicubic', output_shape=None, mode="vec"):
    """Resize an image with bicubic or bilinear interpolation (MATLAB imresize port).

    Args:
        I: 2-D image, or 3-D image whose first two axes are the spatial ones.
        scalar_scale: uniform scale factor for both spatial axes; takes precedence.
        method: 'bicubic' (default) or 'bilinear'.
        output_shape: target spatial size, used when ``scalar_scale`` is None.
        mode: "org" for the loop resampler, anything else for the vectorised one.

    Returns:
        The resized array. Returns None (after printing an error) when neither
        ``scalar_scale`` nor ``output_shape`` is given.

    Raises:
        ValueError: if ``method`` is not 'bicubic' or 'bilinear'.
    """
    # Compare by value: `is` tests object identity and fails for strings built at runtime.
    if method == 'bicubic':
        kernel = cubic
    elif method == 'bilinear':
        kernel = triangle
    else:
        raise ValueError('Error: Unidentified method supplied: {!r}'.format(method))

    kernel_width = 4.0
    # Fill scale and output_size
    if scalar_scale is not None:
        scalar_scale = float(scalar_scale)
        scale = [scalar_scale, scalar_scale]
        output_size = deriveSizeFromScale(I.shape, scale)
    elif output_shape is not None:
        scale = deriveScaleFromSize(I.shape, output_shape)
        output_size = list(output_shape)
    else:
        print('Error: scalar_scale OR output_shape should be defined!')
        return
    # Axes are processed in ascending order of scale factor.
    order = np.argsort(np.array(scale))
    weights = []
    indices = []
    for k in range(2):
        w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width)
        weights.append(w)
        indices.append(ind)
    B = np.copy(I)
    flag2D = B.ndim == 2
    if flag2D:
        # Add a trailing channel axis so both resampling passes see 3-D input.
        B = np.expand_dims(B, axis=2)
    for k in range(2):
        dim = order[k]
        B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode)
    if flag2D:
        B = np.squeeze(B, axis=2)
    return B
175 |
176 |
def convertDouble2Byte(I):
    """Map a float image to uint8.

    Values are clipped to [0, 1], scaled by 255, and rounded with np.around
    (round half to even).
    """
    scaled = 255 * np.clip(I, 0.0, 1.0)
    return np.around(scaled).astype(np.uint8)
--------------------------------------------------------------------------------
/code/models/SRFlow_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | import logging
18 | from collections import OrderedDict
19 | from utils.util import get_resume_paths, opt_get
20 |
21 | import torch
22 | import torch.nn as nn
23 | from torch.nn.parallel import DataParallel, DistributedDataParallel
24 | import models.networks as networks
25 | import models.lr_scheduler as lr_scheduler
26 | from .base_model import BaseModel
27 |
# Module logger; shares the 'base' logger used throughout the models package.
logger = logging.getLogger(name='base')
29 |
30 |
class SRFlowModel(BaseModel):
    """SRFlow model: flow-based generator ``netG`` plus its optimizer and LR scheduler.

    ``netG`` is built by ``networks.define_Flow`` and wrapped in DistributedDataParallel
    when ``opt['dist']`` is set, otherwise in DataParallel. Trainable parameters whose
    name contains ``'.RRDB.'`` get a separate Adam param group. That group can use
    ``lr_RRDB`` and can be filled later (see ``train_RRDB_delay``).
    """

    def __init__(self, opt, step):
        super(SRFlowModel, self).__init__(opt)
        self.opt = opt

        # Validation sampling: the list of heats (temperatures) and samples per heat.
        self.heats = opt['val']['heats']
        self.n_sample = opt['val']['n_sample']
        # HR crop size (160 when unset) and the corresponding LR size, used by get_z.
        self.hr_size = opt_get(opt, ['datasets', 'train', 'center_crop_hr_size'])
        self.hr_size = 160 if self.hr_size is None else self.hr_size
        self.lr_size = self.hr_size // opt['scale']

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained models
        self.netG = networks.define_Flow(opt, step).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()

        # opt_get is given 1 as its third argument (presumably the default for a missing
        # key), so only an explicit resume_state of None skips loading.
        if opt_get(opt, ['path', 'resume_state'], 1) is not None:
            self.load()
        else:
            print("WARNING: skipping initial loading, due to resume_state None")

        # Defined unconditionally so get_current_log() also works outside training.
        self.log_dict = OrderedDict()
        if self.is_train:
            self.netG.train()
            self.init_optimizer_and_scheduler(train_opt)

    def to(self, device):
        """Record ``device`` and move the generator to it."""
        self.device = device
        self.netG.to(device)

    def init_optimizer_and_scheduler(self, train_opt):
        """Create the Adam optimizer (two param groups) and its LR scheduler.

        Group 0: trainable non-RRDB parameters, lr ``lr_G``.
        Group 1: trainable parameters whose name contains ``'.RRDB.'``, lr ``lr_RRDB``
        (falls back to ``lr_G``). It may start empty and be filled later by
        ``add_optimizer_and_scheduler_RRDB``.

        Raises:
            NotImplementedError: if ``train_opt['lr_scheme']`` is not supported.
        """
        # optimizers
        self.optimizers = []
        wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
        optim_params_RRDB = []
        optim_params_other = []
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            print(k, v.requires_grad)
            if not v.requires_grad:
                # Only frozen parameters are left out of the optimizer.
                if self.rank <= 0:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
                continue
            if '.RRDB.' in k:
                optim_params_RRDB.append(v)
                print('opt', k)
            else:
                optim_params_other.append(v)

        print('rrdb params', len(optim_params_RRDB))

        # torch.optim.Adam reads its coefficients from a 'betas' tuple; separate
        # 'beta1'/'beta2' keys are ignored and the defaults would silently apply.
        betas = (train_opt['beta1'], train_opt['beta2'])
        self.optimizer_G = torch.optim.Adam(
            [
                {"params": optim_params_other, "lr": train_opt['lr_G'], 'betas': betas,
                 'weight_decay': wd_G},
                {"params": optim_params_RRDB, "lr": train_opt.get('lr_RRDB', train_opt['lr_G']),
                 'betas': betas, 'weight_decay': wd_G}
            ],
        )

        self.optimizers.append(self.optimizer_G)
        # schedulers
        if train_opt['lr_scheme'] == 'MultiStepLR':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                     restarts=train_opt['restarts'],
                                                     weights=train_opt['restart_weights'],
                                                     gamma=train_opt['lr_gamma'],
                                                     clear_state=train_opt['clear_state'],
                                                     lr_steps_invese=train_opt.get('lr_steps_inverse', [])))
        elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
            for optimizer in self.optimizers:
                self.schedulers.append(
                    lr_scheduler.CosineAnnealingLR_Restart(
                        optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                        restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
        else:
            raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

    def add_optimizer_and_scheduler_RRDB(self, train_opt):
        """Append all trainable ``'.RRDB.'`` parameters to the (empty) RRDB param group.

        Asserts that there is exactly one optimizer, that group 1 is empty beforehand,
        and that it is non-empty afterwards. ``train_opt`` is not used.
        """
        # optimizers
        assert len(self.optimizers) == 1, self.optimizers
        assert len(self.optimizer_G.param_groups[1]['params']) == 0, self.optimizer_G.param_groups[1]
        for k, v in self.netG.named_parameters():  # can optimize for a part of the model
            if v.requires_grad and '.RRDB.' in k:
                self.optimizer_G.param_groups[1]['params'].append(v)
        assert len(self.optimizer_G.param_groups[1]['params']) > 0

    def feed_data(self, data, need_GT=True):
        """Move ``data['LQ']`` (and ``data['GT']`` when ``need_GT``) to the model device."""
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def optimize_parameters(self, step):
        """Run one optimisation step and return the total loss as a Python float.

        Once ``step`` exceeds ``train_RRDB_delay * niter``, RRDB training is switched on.
        The loss is ``weight_fl * mean(nll)``, where ``weight_fl`` defaults to 1 and the
        term is skipped when it is <= 0. When ``weight_l1 > 0``, it also includes
        ``weight_l1 * mean|SR - GT|``, with SR generated from a heat-0 latent.
        """
        train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay'])
        if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \
                and not self.netG.module.RRDB_training:
            # NOTE(review): set_rrdb_training presumably returns True when the mode changed — confirm.
            if self.netG.module.set_rrdb_training(True):
                self.add_optimizer_and_scheduler_RRDB(self.opt['train'])

        # self.print_rrdb_state()

        self.netG.train()
        self.log_dict = OrderedDict()
        self.optimizer_G.zero_grad()

        losses = {}
        weight_fl = opt_get(self.opt, ['train', 'weight_fl'])
        weight_fl = 1 if weight_fl is None else weight_fl
        if weight_fl > 0:
            z, nll, y_logits = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
            nll_loss = torch.mean(nll)
            losses['nll_loss'] = nll_loss * weight_fl

        weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0
        if weight_l1 > 0:
            z = self.get_z(heat=0, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape)
            sr, logdet = self.netG(lr=self.var_L, z=z, eps_std=0, reverse=True, reverse_with_grad=True)
            l1_loss = (sr - self.real_H).abs().mean()
            losses['l1_loss'] = l1_loss * weight_l1

        total_loss = sum(losses.values())
        total_loss.backward()
        self.optimizer_G.step()

        mean = total_loss.item()
        return mean

    def print_rrdb_state(self):
        """Debug helper: print requires_grad/abs-sum of RRDB.conv_first.weight and param-group sizes."""
        for name, param in self.netG.module.named_parameters():
            if "RRDB.conv_first.weight" in name:
                print(name, param.requires_grad, param.data.abs().sum())
        print('params', [len(p['params']) for p in self.optimizer_G.param_groups])

    def test(self):
        """Fill ``self.fake_H[(heat, i)]`` with SR samples and return the mean NLL of the GT.

        Generates ``n_sample`` samples for each heat in eval mode without gradients, then
        restores train mode. Requires ``feed_data`` to have been called with GT.
        """
        self.netG.eval()
        self.fake_H = {}
        for heat in self.heats:
            for i in range(self.n_sample):
                z = self.get_z(heat, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape)
                with torch.no_grad():
                    self.fake_H[(heat, i)], logdet = self.netG(lr=self.var_L, z=z, eps_std=heat, reverse=True)
        with torch.no_grad():
            _, nll, _ = self.netG(gt=self.real_H, lr=self.var_L, reverse=False)
        self.netG.train()
        return nll.mean().item()

    def get_encode_nll(self, lq, gt):
        """Return the mean NLL of ``gt`` conditioned on ``lq`` (eval mode, no gradients)."""
        self.netG.eval()
        with torch.no_grad():
            _, nll, _ = self.netG(gt=gt, lr=lq, reverse=False)
        self.netG.train()
        return nll.mean().item()

    def get_sr(self, lq, heat=None, seed=None, z=None, epses=None):
        """Return only the SR image from ``get_sr_with_z``."""
        return self.get_sr_with_z(lq, heat, seed, z, epses)[0]

    def get_encode_z(self, lq, gt, epses=None, add_gt_noise=True):
        """Encode ``gt`` (conditioned on ``lq``) to its latent ``z`` without gradients."""
        self.netG.eval()
        with torch.no_grad():
            z, _, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise)
        self.netG.train()
        return z

    def get_encode_z_and_nll(self, lq, gt, epses=None, add_gt_noise=True):
        """Encode ``gt`` (conditioned on ``lq``) and return ``(z, nll)`` without gradients."""
        self.netG.eval()
        with torch.no_grad():
            z, nll, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise)
        self.netG.train()
        return z, nll

    def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None):
        """Generate an SR image from ``lq`` and return ``(sr, z)``.

        A latent is sampled with ``get_z`` only when neither ``z`` nor ``epses`` is given.
        """
        self.netG.eval()

        z = self.get_z(heat, seed, batch_size=lq.shape[0], lr_shape=lq.shape) if z is None and epses is None else z

        with torch.no_grad():
            sr, logdet = self.netG(lr=lq, z=z, eps_std=heat, reverse=True, epses=epses)
        self.netG.train()
        return sr, z

    def get_z(self, heat, seed=None, batch_size=1, lr_shape=None):
        """Sample a latent ``z`` ~ N(0, heat^2) for a reverse (generation) pass.

        With the flow split enabled, the shape is (batch, C, H, W), derived from
        ``lr_shape`` and the upsampler's scale, and ``heat <= 0`` gives zeros. Otherwise
        the shape is derived from ``self.lr_size`` and the flow level count ``L``
        (default 3).
        """
        # Compare against None so that seed=0 is honoured too.
        if seed is not None:
            torch.manual_seed(seed)
        if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']):
            upsampler = self.netG.module.flowUpsamplerNet
            C = upsampler.C
            H = int(self.opt['scale'] * lr_shape[2] // upsampler.scaleH)
            W = int(self.opt['scale'] * lr_shape[3] // upsampler.scaleW)
            size = (batch_size, C, H, W)
            z = torch.normal(mean=0, std=heat, size=size) if heat > 0 else torch.zeros(size)
        else:
            L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3
            fac = 2 ** (L - 3)
            z_size = int(self.lr_size // fac)
            z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size))
        return z

    def get_current_log(self):
        """Return the current log dict."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the first batch item of LQ, every SR sample and (optionally) GT, as float CPU tensors."""
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        for heat in self.heats:
            for i in range(self.n_sample):
                out_dict[('SR', heat, i)] = self.fake_H[(heat, i)].detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the network structure and parameter count (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load weights for netG.

        A resume checkpoint (from ``get_resume_paths``) is loaded strictly into the whole
        network. Otherwise ``pretrain_model_G`` is loaded, if set, into ``load_submodule``
        (default 'RRDB').
        """
        _, get_resume_model_path = get_resume_paths(self.opt)
        if get_resume_model_path is not None:
            self.load_network(get_resume_model_path, self.netG, strict=True, submodule=None)
            return

        load_path_G = self.opt['path']['pretrain_model_G']
        load_submodule = self.opt['path']['load_submodule'] if 'load_submodule' in self.opt['path'].keys() else 'RRDB'
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path'].get('strict_load', True),
                              submodule=load_submodule)

    def save(self, iter_label):
        """Save netG's weights tagged with ``iter_label``."""
        self.save_network(self.netG, 'G', iter_label)
279 |
--------------------------------------------------------------------------------
/code/models/SR_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | import logging
18 | from collections import OrderedDict
19 |
20 | import torch
21 | import torch.nn as nn
22 | from torch.nn.parallel import DataParallel, DistributedDataParallel
23 | import models.networks as networks
24 | import models.lr_scheduler as lr_scheduler
25 | from utils.util import opt_get
26 | from .base_model import BaseModel
27 | from models.modules.loss import CharbonnierLoss
28 |
# Module logger; shares the 'base' logger used throughout the models package.
logger = logging.getLogger(name='base')
30 |
31 |
class SRModel(BaseModel):
    """Feed-forward super-resolution model trained with a pixel loss.

    The generator comes from ``networks.define_G`` and is wrapped in DistributedDataParallel
    when ``opt['dist']`` is set, otherwise in DataParallel. When training, the model sets up
    the pixel criterion ('l1', 'l2' or 'cb'), an Adam optimizer and an LR scheduler.
    """

    def __init__(self, opt, step):
        super(SRModel, self).__init__(opt)

        self.step = step

        if opt['dist']:
            self.rank = torch.distributed.get_rank()
        else:
            self.rank = -1  # non dist training
        train_opt = opt['train']

        # define network and load pretrained_models models
        self.netG = networks.define_G(opt).to(self.device)
        if opt['dist']:
            self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()])
        else:
            self.netG = DataParallel(self.netG)
        # print network
        self.print_network()
        self.load()

        # Defined unconditionally so get_current_log() also works outside training.
        self.log_dict = OrderedDict()

        if self.is_train:
            self.netG.train()

            # loss
            loss_type = train_opt['pixel_criterion']
            if loss_type == 'l1':
                self.cri_pix = nn.L1Loss().to(self.device)
            elif loss_type == 'l2':
                self.cri_pix = nn.MSELoss().to(self.device)
            elif loss_type == 'cb':
                self.cri_pix = CharbonnierLoss().to(self.device)
            else:
                raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type))
            self.l_pix_w = train_opt['pixel_weight']

            # optimizers
            wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0
            optim_params = []
            for k, v in self.netG.named_parameters():  # can optimize for a part of the model
                if v.requires_grad:
                    optim_params.append(v)
                elif self.rank <= 0:
                    logger.warning('Params [{:s}] will not optimize.'.format(k))
            self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'],
                                                weight_decay=wd_G,
                                                betas=(train_opt['beta1'], train_opt['beta2']))
            self.optimizers.append(self.optimizer_G)

            # schedulers
            if train_opt['lr_scheme'] == 'MultiStepLR':
                for optimizer in self.optimizers:
                    # MultiStepLR_Restart asserts lr_steps_invese is not None, so it must be passed.
                    self.schedulers.append(
                        lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
                                                         restarts=train_opt['restarts'],
                                                         weights=train_opt['restart_weights'],
                                                         gamma=train_opt['lr_gamma'],
                                                         clear_state=train_opt['clear_state'],
                                                         lr_steps_invese=train_opt.get('lr_steps_inverse', [])))
            elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
                for optimizer in self.optimizers:
                    self.schedulers.append(
                        lr_scheduler.CosineAnnealingLR_Restart(
                            optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
                            restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
            else:
                raise NotImplementedError('MultiStepLR learning rate scheme is enough.')

    def feed_data(self, data, need_GT=True):
        """Move ``data['LQ']`` (and ``data['GT']`` when ``need_GT``) to the model device."""
        self.var_L = data['LQ'].to(self.device)  # LQ
        if need_GT:
            self.real_H = data['GT'].to(self.device)  # GT

    def to(self, device):
        """Record ``device`` and move the generator to it."""
        self.device = device
        self.netG.to(device)

    def optimize_parameters(self, step):
        """One training step on ``pixel_weight * criterion(netG(LQ), GT)``; logs 'l_pix'.

        If the DEBUG_FEED_IMAGES environment variable is set, the first LQ and GT images
        of the batch are written to /tmp as PNG files first.
        """
        import os

        if "DEBUG_FEED_IMAGES" in os.environ:
            import imageio
            import random
            i = random.randint(0, 10000)
            label = self.var_L.cpu().numpy()[0].transpose([1, 2, 0])
            print("var_L", label.min(), label.max(), label.shape)
            imageio.imwrite("/tmp/{}_l.png".format(i), label)
            image = self.real_H.cpu().numpy()[0].transpose([1, 2, 0])
            print("self.real_H", image.min(), image.max(), image.shape)
            imageio.imwrite("/tmp/{}_gt.png".format(i), image)
        self.optimizer_G.zero_grad()
        self.fake_H = self.netG(self.var_L)
        l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H.to(self.fake_H.device))
        l_pix.backward()
        self.optimizer_G.step()

        # set log
        self.log_dict['l_pix'] = l_pix.item()

    def test(self):
        """Compute ``self.fake_H = netG(var_L)`` in eval mode without gradients."""
        self.netG.eval()
        with torch.no_grad():
            self.fake_H = self.netG(self.var_L)
        self.netG.train()

    def get_encode_nll(self, lq, gt):
        """Placeholder NLL (a constant 1e14 tensor); this model has no likelihood."""
        return torch.ones(1) * 1e14

    def get_sr(self, lq, heat=None, seed=None):
        """Return ``netG(lq)`` computed in eval mode without tracking gradients.

        ``heat`` and ``seed`` are accepted for interface parity with SRFlowModel and are ignored.
        """
        self.netG.eval()
        with torch.no_grad():
            sr = self.netG(lq)
        self.netG.train()
        return sr

    def test_x8(self):
        """Self-ensemble: average netG outputs over the 8 flip/transpose variants of var_L.

        NOTE(review): outputs are concatenated and averaged over dim 0, which assumes
        var_L has batch size 1.
        """
        # from https://github.com/thstkdgus35/EDSR-PyTorch
        self.netG.eval()

        def _transform(v, op):
            # if self.precision != 'single': v = v.float()
            v2np = v.data.cpu().numpy()
            if op == 'v':
                tfnp = v2np[:, :, :, ::-1].copy()
            elif op == 'h':
                tfnp = v2np[:, :, ::-1, :].copy()
            elif op == 't':
                tfnp = v2np.transpose((0, 1, 3, 2)).copy()

            ret = torch.Tensor(tfnp).to(self.device)
            # if self.precision == 'half': ret = ret.half()

            return ret

        lr_list = [self.var_L]
        for tf in 'v', 'h', 't':
            lr_list.extend([_transform(t, tf) for t in lr_list])
        with torch.no_grad():
            sr_list = [self.netG(aug) for aug in lr_list]
        # Undo the augmentations in reverse order of application.
        for i in range(len(sr_list)):
            if i > 3:
                sr_list[i] = _transform(sr_list[i], 't')
            if i % 4 > 1:
                sr_list[i] = _transform(sr_list[i], 'h')
            if (i % 4) % 2 == 1:
                sr_list[i] = _transform(sr_list[i], 'v')

        output_cat = torch.cat(sr_list, dim=0)
        self.fake_H = output_cat.mean(dim=0, keepdim=True)
        self.netG.train()

    def get_current_log(self):
        """Return the current log dict."""
        return self.log_dict

    def get_current_visuals(self, need_GT=True):
        """Return the first batch item of LQ, SR and (optionally) GT as float CPU tensors."""
        out_dict = OrderedDict()
        out_dict['LQ'] = self.var_L.detach()[0].float().cpu()
        out_dict['SR'] = self.fake_H.detach()[0].float().cpu()
        if need_GT:
            out_dict['GT'] = self.real_H.detach()[0].float().cpu()
        return out_dict

    def print_network(self):
        """Log the network structure and parameter count (rank <= 0 only)."""
        s, n = self.get_network_description(self.netG)
        if isinstance(self.netG, (nn.DataParallel, DistributedDataParallel)):
            net_struc_str = '{} - {}'.format(self.netG.__class__.__name__,
                                             self.netG.module.__class__.__name__)
        else:
            net_struc_str = '{}'.format(self.netG.__class__.__name__)
        if self.rank <= 0:
            logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n))
            logger.info(s)

    def load(self):
        """Load ``opt['path']['pretrain_model_G']`` into netG when set.

        ``strict_load`` defaults to True, matching SRFlowModel.
        """
        load_path_G = self.opt['path']['pretrain_model_G']
        if load_path_G is not None:
            logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
            self.load_network(load_path_G, self.netG, self.opt['path'].get('strict_load', True))

    def save(self, iter_label):
        """Save netG's weights tagged with ``iter_label``."""
        self.save_network(self.netG, 'G', iter_label)

    def get_encode_z_and_nll(self, *args, **kwargs):
        """Placeholder for SRFlowModel parity: no latent, zero NLL."""
        return [], torch.zeros(1)
218 |
--------------------------------------------------------------------------------
/code/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import logging
3 | import os
4 |
try:
    # Optional machine-specific settings module (create_model reads its checkpoint_path).
    import local_config
except ImportError:
    # Only a missing module is expected; errors raised inside local_config should surface.
    local_config = None


logger = logging.getLogger('base')
12 |
13 |
def find_model_using_name(model_name):
    """Import ``models.<model_name>_model`` and return its model class.

    The class is found by comparing names case-insensitively with ``model_name``
    (underscores removed) + ``'Model'``; for example 'SRFlow' matches 'SRFlowModel'.
    If several names match, the last one in the module namespace wins.

    Raises:
        ValueError: if the module defines no matching class.
    """
    model_filename = "models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)

    target_model_name = model_name.replace('_', '') + 'Model'
    model = None
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower():
            model = cls

    if model is None:
        # Raise so the failure is visible to callers and, if uncaught, gives a non-zero exit status.
        raise ValueError(
            "In %s.py, there should be a class with a name that matches %s (case-insensitive)." % (
                model_filename, target_model_name))

    return model
37 |
38 |
def create_model(opt, step=0, **opt_kwargs):
    """Build the model class named by ``opt['model']`` and return an instance of it.

    When ``local_config`` is available, ``opt['path']['pretrain_model_G']`` is overridden
    with ``<local_config.checkpoint_path>/<basename(results_root)>.pth``. Extra keyword
    arguments are written into ``opt`` (mutating it) before the model is constructed
    as ``M(opt, step)``.
    """
    if local_config is not None:
        checkpoint_name = os.path.basename(opt['path']['results_root'] + '.pth')
        opt['path']['pretrain_model_G'] = os.path.join(local_config.checkpoint_path, checkpoint_name)

    for key, value in opt_kwargs.items():
        opt[key] = value

    model_cls = find_model_using_name(opt['model'])
    instance = model_cls(opt, step)
    logger.info('Model [{:s}] is created.'.format(instance.__class__.__name__))
    return instance
53 |
--------------------------------------------------------------------------------
/code/models/base_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | import os
18 | from collections import OrderedDict
19 | import torch
20 | import torch.nn as nn
21 | from torch.nn.parallel import DistributedDataParallel
22 | import natsort
23 | import glob
24 |
25 |
class BaseModel():
    """Shared model scaffolding: device choice, LR warm-up, checkpoint and state I/O.

    Subclasses fill ``self.optimizers`` / ``self.schedulers`` and override the no-op hooks.
    """

    def __init__(self, opt):
        self.opt = opt
        # CUDA only when the options list GPU ids; otherwise CPU.
        self.device = torch.device('cuda' if opt.get('gpu_ids', None) is not None else 'cpu')
        self.is_train = opt['is_train']
        self.schedulers = []
        self.optimizers = []

    def feed_data(self, data):
        pass

    def optimize_parameters(self):
        pass

    def get_current_visuals(self):
        pass

    def get_current_losses(self):
        pass

    def print_network(self):
        pass

    def save(self, label):
        pass

    def load(self):
        pass

    def _set_lr(self, lr_groups_l):
        ''' set learning rate for warmup,
        lr_groups_l: list for lr_groups. each for a optimizer'''
        for optimizer, lr_groups in zip(self.optimizers, lr_groups_l):
            for param_group, lr in zip(optimizer.param_groups, lr_groups):
                param_group['lr'] = lr

    def _get_init_lr(self):
        """Return, per optimizer, the list of each param group's 'initial_lr'."""
        # get the initial lr, which is set by the scheduler
        init_lr_groups_l = []
        for optimizer in self.optimizers:
            init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
        return init_lr_groups_l

    def update_learning_rate(self, cur_iter, warmup_iter=-1):
        """Step every scheduler, then apply linear warm-up while ``cur_iter < warmup_iter``.

        During warm-up each group's lr is ``initial_lr / warmup_iter * cur_iter``.
        """
        for scheduler in self.schedulers:
            scheduler.step()
        #### set up warm up learning rate
        if cur_iter < warmup_iter:
            # get initial lr for each group
            init_lr_g_l = self._get_init_lr()
            # modify warming-up learning rates
            warm_up_lr_l = []
            for init_lr_g in init_lr_g_l:
                warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g])
            # set learning rate
            self._set_lr(warm_up_lr_l)

    def get_current_learning_rate(self):
        """Return the lr of the first param group of the first optimizer."""
        # return self.schedulers[0].get_lr()[0]
        return self.optimizers[0].param_groups[0]['lr']

    def get_network_description(self, network):
        '''Get the string and total parameters of the network'''
        if isinstance(network, (nn.DataParallel, DistributedDataParallel)):
            network = network.module
        s = str(network)
        n = sum(map(lambda x: x.numel(), network.parameters()))
        return s, n

    def save_network(self, network, network_label, iter_label):
        """Save ``network``'s state dict (on CPU) as ``<iter_label>_<network_label>.pth``.

        Older ``*_<network_label>.pth`` files in ``opt['path']['models']`` are pruned first.
        The two newest (natural sort, descending) are kept, and so is any path containing
        'latest_' or any file whose '_'-separated name parts include a multiple of 10000
        between 0 and 1,000,000.
        """
        models_dir = self.opt['path']['models']
        paths = natsort.natsorted(glob.glob(os.path.join(models_dir, "*_{}.pth".format(network_label))),
                                  reverse=True)
        protected = {str(i * 10000) for i in range(101)}
        # os.path.basename instead of split("/") so the check also works with Windows separators.
        paths = [p for p in paths
                 if "latest_" not in p and not protected.intersection(os.path.basename(p).split("_"))]
        for path in paths[2:]:
            os.remove(path)
        save_filename = '{}_{}.pth'.format(iter_label, network_label)
        save_path = os.path.join(models_dir, save_filename)
        if isinstance(network, (nn.DataParallel, DistributedDataParallel)):
            network = network.module
        state_dict = network.state_dict()
        for key, param in state_dict.items():
            state_dict[key] = param.cpu()
        torch.save(state_dict, save_path)

    def load_network(self, load_path, network, strict=True, submodule=None):
        """Load a state dict from ``load_path`` into ``network`` (or its ``submodule``).

        Parallel wrappers are unwrapped, and a leading 'module.' is stripped from keys.
        A ``submodule`` of None or 'none' (any case) loads into the whole network.
        """
        if isinstance(network, (nn.DataParallel, DistributedDataParallel)):
            network = network.module
        if not (submodule is None or submodule.lower() == 'none'.lower()):
            network = network.__getattr__(submodule)
        # Map storages to CPU so GPU-saved checkpoints also load on CPU-only machines;
        # load_state_dict copies the values onto the parameters' own device.
        load_net = torch.load(load_path, map_location='cpu')
        load_net_clean = OrderedDict()  # remove unnecessary 'module.'
        for k, v in load_net.items():
            load_net_clean[k[7:] if k.startswith('module.') else k] = v
        network.load_state_dict(load_net_clean, strict=strict)

    def save_training_state(self, epoch, iter_step):
        '''Saves training state during training, which will be used for resuming'''
        state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []}
        for s in self.schedulers:
            state['schedulers'].append(s.state_dict())
        for o in self.optimizers:
            state['optimizers'].append(o.state_dict())
        save_filename = '{}.state'.format(iter_step)
        save_path = os.path.join(self.opt['path']['training_state'], save_filename)

        # Keep only the two newest existing states (plus 'latest_*') before writing the new one.
        paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['training_state'], "*.state")),
                                  reverse=True)
        paths = [p for p in paths if "latest_" not in p]
        for path in paths[2:]:
            os.remove(path)

        torch.save(state, save_path)

    def resume_training(self, resume_state):
        '''Resume the optimizers and schedulers for training'''
        resume_optimizers = resume_state['optimizers']
        resume_schedulers = resume_state['schedulers']
        assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers'
        assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers'
        for i, o in enumerate(resume_optimizers):
            self.optimizers[i].load_state_dict(o)
        for i, s in enumerate(resume_schedulers):
            self.schedulers[i].load_state_dict(s)
155 |
--------------------------------------------------------------------------------
/code/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | import math
18 | from collections import Counter
19 | from collections import defaultdict
20 | import torch
21 | from torch.optim.lr_scheduler import _LRScheduler
22 |
23 |
class MultiStepLR_Restart(_LRScheduler):
    """Multi-step LR decay with optional restarts and "inverse" (lr-raising) steps.

    At each epoch listed in ``milestones`` the lr is multiplied by ``gamma`` (once per
    occurrence). At each epoch in ``lr_steps_invese`` it is divided by ``gamma`` (once per
    occurrence). At a restart epoch the lr is reset to ``initial_lr * weight``, and the
    optimizer state is cleared if ``clear_state``. ``lr_steps_invese`` must be given
    (use an empty list).
    """

    def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
                 clear_state=False, last_epoch=-1, lr_steps_invese=None):
        assert lr_steps_invese is not None, "Use empty list"
        self.milestones = Counter(milestones)
        self.lr_steps_inverse = Counter(lr_steps_invese)
        self.gamma = gamma
        self.clear_state = clear_state
        self.restarts = restarts or [0]
        self.restart_weights = weights or [1]
        assert len(self.restarts) == len(self.restart_weights), 'restarts and their weights do not match.'
        super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the lr of every param group for ``self.last_epoch``."""
        epoch = self.last_epoch
        groups = self.optimizer.param_groups
        if epoch in self.restarts:
            if self.clear_state:
                self.optimizer.state = defaultdict(dict)
            weight = self.restart_weights[self.restarts.index(epoch)]
            return [g['initial_lr'] * weight for g in groups]
        is_step = epoch in self.milestones or epoch in self.lr_steps_inverse
        if not is_step:
            return [g['lr'] for g in groups]
        down = self.milestones[epoch]
        up = self.lr_steps_inverse[epoch]
        return [g['lr'] * (self.gamma ** down) * (self.gamma ** (-up)) for g in groups]
51 |
52 |
class CosineAnnealingLR_Restart(_LRScheduler):
    """Cosine annealing of the learning rate with warm restarts.

    ``T_period[0]`` is the length of the first cosine cycle.  On the epoch
    ``restarts[i]`` the lr jumps back to ``initial_lr * weights[i]`` and a new
    cycle of length ``T_period[i + 1]`` starts, so ``T_period`` needs one more
    entry than ``restarts``.  Inside a cycle the lr follows the recursive form
    of cosine annealing towards ``eta_min``.

    Args:
        optimizer: wrapped optimizer.
        T_period: cycle lengths.
        restarts: restart epochs (``[0]`` when empty/None).
        weights: multipliers of ``initial_lr`` per restart (``[1]`` when
            empty/None); same length as ``restarts``.
        eta_min: lower bound of the learning rate.
        last_epoch: index of the last epoch, as in ``_LRScheduler``.
    """

    def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
        self.T_period = T_period
        self.T_max = T_period[0]  # length of the cycle currently running
        self.eta_min = eta_min
        self.restarts = restarts or [0]
        self.restart_weights = weights or [1]
        self.last_restart = 0
        assert len(self.restarts) == len(
            self.restart_weights), 'restarts and their weights do not match.'
        # The base constructor performs the initial step (calls get_lr()).
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return the learning rate of every param group for ``self.last_epoch``."""
        epoch = self.last_epoch
        groups = self.optimizer.param_groups

        if epoch == 0:
            return self.base_lrs

        if epoch in self.restarts:
            cycle = self.restarts.index(epoch)
            self.last_restart = epoch
            # T_period[0] is the first cycle, so restart i opens cycle i + 1.
            self.T_max = self.T_period[cycle + 1]
            return [g['initial_lr'] * self.restart_weights[cycle] for g in groups]

        step = epoch - self.last_restart
        period = self.T_max
        if (step - 1 - period) % (2 * period) == 0:
            # Here the denominator of the recursive ratio below would be zero,
            # so the lr is advanced additively instead.
            return [g['lr'] + (base - self.eta_min) * (1 - math.cos(math.pi / period)) / 2
                    for base, g in zip(self.base_lrs, groups)]

        ratio = ((1 + math.cos(math.pi * step / period)) /
                 (1 + math.cos(math.pi * (step - 1) / period)))
        return [ratio * (g['lr'] - self.eta_min) + self.eta_min for g in groups]
82 |
83 |
84 | if __name__ == "__main__":
85 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0,
86 | betas=(0.9, 0.99))
87 | ##############################
88 | # MultiStepLR_Restart
89 | ##############################
90 | ## Original
91 | lr_steps = [200000, 400000, 600000, 800000]
92 | restarts = None
93 | restart_weights = None
94 |
95 | ## two
96 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000]
97 | restarts = [500000]
98 | restart_weights = [1]
99 |
100 | ## four
101 | lr_steps = [
102 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000,
103 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000
104 | ]
105 | restarts = [250000, 500000, 750000]
106 | restart_weights = [1, 1, 1]
107 |
108 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5,
109 | clear_state=False)
110 |
111 | ##############################
112 | # Cosine Annealing Restart
113 | ##############################
114 | ## two
115 | T_period = [500000, 500000]
116 | restarts = [500000]
117 | restart_weights = [1]
118 |
119 | ## four
120 | T_period = [250000, 250000, 250000, 250000]
121 | restarts = [250000, 500000, 750000]
122 | restart_weights = [1, 1, 1]
123 |
124 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts,
125 | weights=restart_weights)
126 |
127 | ##############################
128 | # Draw figure
129 | ##############################
130 | N_iter = 1000000
131 | lr_l = list(range(N_iter))
132 | for i in range(N_iter):
133 | scheduler.step()
134 | current_lr = optimizer.param_groups[0]['lr']
135 | lr_l[i] = current_lr
136 |
137 | import matplotlib as mpl
138 | from matplotlib import pyplot as plt
139 | import matplotlib.ticker as mtick
140 |
141 | mpl.style.use('default')
142 | import seaborn
143 |
144 | seaborn.set(style='whitegrid')
145 | seaborn.set_context('paper')
146 |
147 | plt.figure(1)
148 | plt.subplot(111)
149 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0))
150 | plt.title('Title', fontsize=16, color='k')
151 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme')
152 | legend = plt.legend(loc='upper right', shadow=False)
153 | ax = plt.gca()
154 | labels = ax.get_xticks().tolist()
155 | for k, v in enumerate(labels):
156 | labels[k] = str(int(v / 1000)) + 'K'
157 | ax.set_xticklabels(labels)
158 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e'))
159 |
160 | ax.set_ylabel('Learning rate')
161 | ax.set_xlabel('Iteration')
162 | fig = plt.gcf()
163 | plt.show()
164 |
--------------------------------------------------------------------------------
/code/models/modules/FlowActNorms.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import torch
18 | from torch import nn as nn
19 |
20 | from models.modules import thops
21 |
22 |
23 | class _ActNorm(nn.Module):
24 | """
25 | Activation Normalization
26 | Initialize the bias and scale with a given minibatch,
27 | so that the output per-channel have zero mean and unit variance for that.
28 |
29 | After initialization, `bias` and `logs` will be trained as parameters.
30 | """
31 |
32 | def __init__(self, num_features, scale=1.):
33 | super().__init__()
34 | # register mean and scale
35 | size = [1, num_features, 1, 1]
36 | self.register_parameter("bias", nn.Parameter(torch.zeros(*size)))
37 | self.register_parameter("logs", nn.Parameter(torch.zeros(*size)))
38 | self.num_features = num_features
39 | self.scale = float(scale)
40 | self.inited = False
41 |
42 | def _check_input_dim(self, input):
43 | return NotImplemented
44 |
45 | def initialize_parameters(self, input):
46 | self._check_input_dim(input)
47 | if not self.training:
48 | return
49 | if (self.bias != 0).any():
50 | self.inited = True
51 | return
52 | assert input.device == self.bias.device, (input.device, self.bias.device)
53 | with torch.no_grad():
54 | bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0
55 | vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True)
56 | logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6))
57 | self.bias.data.copy_(bias.data)
58 | self.logs.data.copy_(logs.data)
59 | self.inited = True
60 |
61 | def _center(self, input, reverse=False, offset=None):
62 | bias = self.bias
63 |
64 | if offset is not None:
65 | bias = bias + offset
66 |
67 | if not reverse:
68 | return input + bias
69 | else:
70 | return input - bias
71 |
72 | def _scale(self, input, logdet=None, reverse=False, offset=None):
73 | logs = self.logs
74 |
75 | if offset is not None:
76 | logs = logs + offset
77 |
78 | if not reverse:
79 | input = input * torch.exp(logs) # should have shape batchsize, n_channels, 1, 1
80 | # input = input * torch.exp(logs+logs_offset)
81 | else:
82 | input = input * torch.exp(-logs)
83 | if logdet is not None:
84 | """
85 | logs is log_std of `mean of channels`
86 | so we need to multiply pixels
87 | """
88 | dlogdet = thops.sum(logs) * thops.pixels(input)
89 | if reverse:
90 | dlogdet *= -1
91 | logdet = logdet + dlogdet
92 | return input, logdet
93 |
94 | def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None):
95 | if not self.inited:
96 | self.initialize_parameters(input)
97 | self._check_input_dim(input)
98 |
99 | if offset_mask is not None:
100 | logs_offset *= offset_mask
101 | bias_offset *= offset_mask
102 | # no need to permute dims as old version
103 | if not reverse:
104 | # center and scale
105 |
106 | # self.input = input
107 | input = self._center(input, reverse, bias_offset)
108 | input, logdet = self._scale(input, logdet, reverse, logs_offset)
109 | else:
110 | # scale and center
111 | input, logdet = self._scale(input, logdet, reverse, logs_offset)
112 | input = self._center(input, reverse, bias_offset)
113 | return input, logdet
114 |
115 |
class ActNorm2d(_ActNorm):
    """ActNorm for 4-D inputs laid out as ``BCHW`` (channels on dim 1)."""

    def __init__(self, num_features, scale=1.):
        super().__init__(num_features, scale)

    def _check_input_dim(self, input):
        """Assert that ``input`` is 4-D with ``self.num_features`` channels on dim 1."""
        assert len(input.size()) == 4, (
            "[ActNorm]: input should be 4-D `BCHW`, got shape {}".format(tuple(input.size())))
        # Report the actual channel count (not the whole size) after "rather than".
        assert input.size(1) == self.num_features, (
            "[ActNorm]: input should be in shape as `BCHW`,"
            " channels should be {} rather than {} (input shape {})".format(
                self.num_features, input.size(1), tuple(input.size())))
126 |
127 |
class MaskedActNorm2d(ActNorm2d):
    """ActNorm2d that only applies to the entries selected by a boolean ``mask``.

    Entries where ``mask`` is False keep their input value (and input logdet).
    """

    def __init__(self, num_features, scale=1.):
        super().__init__(num_features, scale)

    def forward(self, input, mask, logdet=None, reverse=False):
        """
        Args:
            input: tensor accepted by ``ActNorm2d``.
            mask: bool tensor used as a boolean index into ``input`` and
                ``logdet``; its shape must match their leading dims
                (presumably ``(B,)`` -- TODO confirm against callers).
            logdet: optional log-determinant accumulator; None passes through.
            reverse: run the inverse transform.

        Returns:
            (output, logdet) as new tensors; the arguments are left unmodified.
        """
        assert mask.dtype == torch.bool
        output, logdet_out = super().forward(input, logdet, reverse)

        # Work on copies: writing into `input`/`logdet` in place would mutate
        # the caller's tensors and fails for leaf tensors that require grad.
        result = input.clone()
        result[mask] = output[mask]
        if logdet is not None:
            # Without this guard the default logdet=None raised a TypeError.
            logdet = logdet.clone()
            logdet[mask] = logdet_out[mask]

        return result, logdet
141 |
142 |
--------------------------------------------------------------------------------
/code/models/modules/FlowAffineCouplingsAblation.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import torch
18 | from torch import nn as nn
19 |
20 | from models.modules import thops
21 | from models.modules.flow import Conv2d, Conv2dZeros
22 | from utils.util import opt_get
23 |
24 |
class CondAffineSeparatedAndCond(nn.Module):
    """Conditional affine coupling layer.

    Forward direction (``reverse=False``) applies two affine transforms:

    1. Feature conditional: ``(shift, scale)`` predicted from the conditioning
       features ``ft`` alone is applied to every channel of ``z``.
    2. Self conditional: ``z`` is split channel-wise into ``z1`` / ``z2``;
       ``(shift, scale)`` predicted from ``cat(z1, ft)`` is applied to ``z2``.

    Each transform is ``z = (z + shift) * scale`` with
    ``scale = sigmoid(raw + 2) + affine_eps`` (strictly positive), and
    ``sum(log(scale))`` over dims 1-3 is added to ``logdet``.
    ``reverse=True`` undoes both transforms in the opposite order.

    Options read from ``opt['network_G']['flow']['CondAffineSeparatedAndCond']``:
    ``hidden_channels`` (default 64), ``eps`` (default 0.0001) and
    ``in_channels_rrdb`` (default 320, the channel count of ``ft``).
    """

    def __init__(self, in_channels, opt):
        super().__init__()
        self.need_features = True
        self.in_channels = in_channels
        cfg_path = ['network_G', 'flow', 'CondAffineSeparatedAndCond']

        # Channel count of the conditioning features `ft`; 320 unless configured.
        in_channels_rrdb = opt_get(opt, cfg_path + ['in_channels_rrdb'])
        self.in_channels_rrdb = 320 if in_channels_rrdb is None else in_channels_rrdb
        self.kernel_hidden = 1
        self.n_hidden_layers = 1
        hidden_channels = opt_get(opt, cfg_path + ['hidden_channels'])
        self.hidden_channels = 64 if hidden_channels is None else hidden_channels

        self.affine_eps = opt_get(opt, cfg_path + ['eps'], 0.0001)

        # z1 = first half of the channels (passed through), z2 = the rest (transformed).
        self.channels_for_nn = self.in_channels // 2
        self.channels_for_co = self.in_channels - self.channels_for_nn

        # Predicts (shift, scale) for z2 from cat(z1, ft).
        self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb,
                              out_channels=self.channels_for_co * 2,
                              hidden_channels=self.hidden_channels,
                              kernel_hidden=self.kernel_hidden,
                              n_hidden_layers=self.n_hidden_layers)

        # Predicts (shift, scale) for all of z from ft alone.
        self.fFeatures = self.F(in_channels=self.in_channels_rrdb,
                                out_channels=self.in_channels * 2,
                                hidden_channels=self.hidden_channels,
                                kernel_hidden=self.kernel_hidden,
                                n_hidden_layers=self.n_hidden_layers)

    def forward(self, input: torch.Tensor, logdet=None, reverse=False, ft=None):
        """Apply the coupling (or its inverse); returns ``(output, logdet)``."""
        if not reverse:
            z = input
            assert z.shape[1] == self.in_channels, (z.shape[1], self.in_channels)

            # Feature Conditional
            scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
            z = z + shiftFt
            z = z * scaleFt
            logdet = logdet + self.get_logdet(scaleFt)

            # Self Conditional
            z1, z2 = self.split(z)
            scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
            self.asserts(scale, shift, z1, z2)
            z2 = z2 + shift
            z2 = z2 * scale

            logdet = logdet + self.get_logdet(scale)
            z = thops.cat_feature(z1, z2)
            output = z
        else:
            z = input

            # Self Conditional (undone first, mirroring the forward order)
            z1, z2 = self.split(z)
            scale, shift = self.feature_extract_aff(z1, ft, self.fAffine)
            self.asserts(scale, shift, z1, z2)
            z2 = z2 / scale
            z2 = z2 - shift
            z = thops.cat_feature(z1, z2)
            logdet = logdet - self.get_logdet(scale)

            # Feature Conditional
            scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures)
            z = z / scaleFt
            z = z - shiftFt
            logdet = logdet - self.get_logdet(scaleFt)

            output = z
        return output, logdet

    def asserts(self, scale, shift, z1, z2):
        """Check channel counts of the split halves and of the predicted scale/shift."""
        assert z1.shape[1] == self.channels_for_nn, (z1.shape[1], self.channels_for_nn)
        assert z2.shape[1] == self.channels_for_co, (z2.shape[1], self.channels_for_co)
        assert scale.shape[1] == shift.shape[1], (scale.shape[1], shift.shape[1])
        assert scale.shape[1] == z2.shape[1], (scale.shape[1], z1.shape[1], z2.shape[1])

    def get_logdet(self, scale):
        """Sum of ``log(scale)`` over dims 1, 2, 3."""
        return thops.sum(torch.log(scale), dim=[1, 2, 3])

    def feature_extract(self, z, f):
        """Predict ``(scale, shift)`` by running network ``f`` on ``z``."""
        h = f(z)
        # thops "cross" split; presumably interleaved channels -- see thops.
        shift, scale = thops.split_feature(h, "cross")
        scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
        return scale, shift

    def feature_extract_aff(self, z1, ft, f):
        """Predict ``(scale, shift)`` by running ``f`` on ``cat([z1, ft], dim=1)``."""
        z = torch.cat([z1, ft], dim=1)
        h = f(z)
        shift, scale = thops.split_feature(h, "cross")
        scale = (torch.sigmoid(scale + 2.) + self.affine_eps)
        return scale, shift

    def split(self, z):
        """Split ``z`` on dim 1 into the first ``channels_for_nn`` channels and the rest."""
        z1 = z[:, :self.channels_for_nn]
        z2 = z[:, self.channels_for_nn:]
        assert z1.shape[1] + z2.shape[1] == z.shape[1], (z1.shape[1], z2.shape[1], z.shape[1])
        return z1, z2

    def F(self, in_channels, out_channels, hidden_channels, kernel_hidden=1, n_hidden_layers=1):
        """Conv2d -> ReLU -> n_hidden_layers x (Conv2d(kernel_hidden) -> ReLU) -> Conv2dZeros."""
        layers = [Conv2d(in_channels, hidden_channels), nn.ReLU(inplace=False)]

        for _ in range(n_hidden_layers):
            layers.append(Conv2d(hidden_channels, hidden_channels, kernel_size=[kernel_hidden, kernel_hidden]))
            layers.append(nn.ReLU(inplace=False))
        layers.append(Conv2dZeros(hidden_channels, out_channels))

        return nn.Sequential(*layers)
136 |
--------------------------------------------------------------------------------
/code/models/modules/FlowStep.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import torch
18 | from torch import nn as nn
19 |
20 | import models.modules
21 | import models.modules.Permutations
22 | from models.modules import flow, thops, FlowAffineCouplingsAblation
23 | from utils.util import opt_get
24 |
25 |
def getConditional(rrdbResults, position):
    """Pick the conditioning features for one flow step.

    A tensor is returned as-is; anything else is indexed with ``position``
    (for example a dict of feature maps keyed by name).
    """
    if isinstance(rrdbResults, torch.Tensor):
        return rrdbResults
    return rrdbResults[position]
29 |
30 |
class FlowStep(nn.Module):
    """One flow step: actnorm -> permutation -> (conditional) affine coupling.

    ``reverse=True`` applies the inverses in the opposite order.  The
    coupling is conditioned on features selected from ``rrdbResults`` with
    ``self.position`` (see ``getConditional``).
    """
    FlowPermutation = {
        "reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet),
        "shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet),
        "invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "squeeze_invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "resqueeze_invconv_alternating_2_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "resqueeze_invconv_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1GridAlign": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1SubblocksShuf": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1GridAlignIndepBorder": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
        "InvertibleConv1x1GridAlignIndepBorder4": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev),
    }
    # NOTE(review): only "invconv" creates `self.invconv` in __init__; the other
    # keys expect attributes (`invconv`, `reverse`, `shuffle`) this class never sets.

    def __init__(self, in_channels, hidden_channels,
                 actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive",
                 LU_decomposed=False, opt=None, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None,
                 position=None):
        # check configures
        assert flow_permutation in FlowStep.FlowPermutation, \
            "flow_permutation should be in `{}`".format(
                FlowStep.FlowPermutation.keys())
        super().__init__()
        self.flow_permutation = flow_permutation
        self.flow_coupling = flow_coupling
        self.image_injector = image_injector

        self.norm_type = normOpt['type'] if normOpt else 'ActNorm2d'

        self.in_shape = in_shape
        # Key used to select this step's conditioning features from rrdbResults.
        self.position = position
        self.acOpt = acOpt

        # 1. actnorm
        self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale)

        # 2. permute
        if flow_permutation == "invconv":
            self.invconv = models.modules.Permutations.InvertibleConv1x1(
                in_channels, LU_decomposed=LU_decomposed)

        # 3. coupling
        if flow_coupling == "CondAffineSeparatedAndCond":
            self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels,
                                                                                               opt=opt)
        elif flow_coupling == "noCoupling":
            pass
        else:
            raise RuntimeError("coupling not Found:", flow_coupling)

    def forward(self, input, logdet=None, reverse=False, rrdbResults=None):
        """Dispatch to ``normal_flow`` or ``reverse_flow``; returns ``(z, logdet)``."""
        if not reverse:
            return self.normal_flow(input, logdet, rrdbResults)
        else:
            return self.reverse_flow(input, logdet, rrdbResults)

    def normal_flow(self, z, logdet, rrdbResults=None):
        """Forward direction: actnorm -> permutation -> coupling."""
        if self.flow_coupling == "bentIdentityPreAct":
            # NOTE(review): unreachable with this __init__, which rejects this coupling.
            z, logdet = self.bentIdentPar(z, logdet, reverse=False)

        # 1. actnorm
        if self.norm_type == "ConditionalActNormImageInjector":
            img_ft = getConditional(rrdbResults, self.position)
            z, logdet = self.actnorm(z, img_ft=img_ft, logdet=logdet, reverse=False)
        elif self.norm_type == "noNorm":
            pass
        else:
            z, logdet = self.actnorm(z, logdet=logdet, reverse=False)

        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, False)

        need_features = self.affine_need_features()

        # 3. coupling
        if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
            img_ft = getConditional(rrdbResults, self.position)
            z, logdet = self.affine(input=z, logdet=logdet, reverse=False, ft=img_ft)
        return z, logdet

    def reverse_flow(self, z, logdet, rrdbResults=None):
        """Inverse direction: coupling^-1 -> permutation^-1 -> actnorm^-1."""
        need_features = self.affine_need_features()

        # 1. coupling
        if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
            img_ft = getConditional(rrdbResults, self.position)
            z, logdet = self.affine(input=z, logdet=logdet, reverse=True, ft=img_ft)

        # 2. permute
        z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
            self, z, logdet, True)

        # 3. actnorm -- skipped for "noNorm" exactly as in normal_flow, so that
        # reverse_flow really inverts normal_flow.
        if self.norm_type != "noNorm":
            z, logdet = self.actnorm(z, logdet=logdet, reverse=True)

        return z, logdet

    def affine_need_features(self):
        """Return ``self.affine.need_features``, or False when there is no such attribute."""
        affine = getattr(self, 'affine', None)
        return getattr(affine, 'need_features', False)
138 |
--------------------------------------------------------------------------------
/code/models/modules/FlowUpsamplerNet.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import numpy as np
18 | import torch
19 | from torch import nn as nn
20 |
21 | import models.modules.Split
22 | from models.modules import flow, thops
23 | from models.modules.Split import Split2d
24 | from models.modules.glow_arch import f_conv2d_bias
25 | from models.modules.FlowStep import FlowStep
26 | from utils.util import opt_get
27 |
28 |
29 | class FlowUpsamplerNet(nn.Module):
def __init__(self, image_shape, hidden_channels, K, L=None,
             actnorm_scale=1.0,
             flow_permutation=None,
             flow_coupling="affine",
             LU_decomposed=False, opt=None):
    """Build the multi-level conditional flow.

    For every level 1..L the following are appended to ``self.layers``:
    squeeze (channels x4, spatial /2), optional invconv-only steps,
    ``self.K[level]`` FlowSteps and an optional Split2d.

    Args:
        image_shape: ``(H, W, C)`` of the HR input; C must be 1 or 3.
        hidden_channels, actnorm_scale, flow_coupling, LU_decomposed:
            forwarded to the FlowSteps.
        K, L: kept for interface compatibility; ``network_G.flow.K`` and
            ``network_G.flow.L`` from ``opt`` are what is used.
        flow_permutation: passed to ``get_flow_permutation``, which reads
            ``network_G.flow.flow_permutation`` from ``opt`` instead.
        opt: option dict; ``opt['scale']`` selects the RRDB feature names.
    """
    super().__init__()

    self.layers = nn.ModuleList()
    self.output_shapes = []
    self.L = opt_get(opt, ['network_G', 'flow', 'L'])
    self.K = opt_get(opt, ['network_G', 'flow', 'K'])
    if isinstance(self.K, int):
        # Broadcast the configured step count to every level (index 0..L).
        self.K = [self.K] * (self.L + 1)

    self.opt = opt
    H, W, self.C = image_shape
    self.check_image_shape()

    # Level index -> name of the RRDB feature map used as conditioning.
    level_names_by_scale = {
        16: ['fea_up16', 'fea_up8', 'fea_up4', 'fea_up2', 'fea_up1'],
        8: ['fea_up8', 'fea_up4', 'fea_up2', 'fea_up1', 'fea_up0'],
        4: ['fea_up4', 'fea_up2', 'fea_up1', 'fea_up0', 'fea_up-1'],
    }
    if opt['scale'] in level_names_by_scale:
        self.levelToName = dict(enumerate(level_names_by_scale[opt['scale']]))

    affineInCh = self.get_affineInCh(opt_get)
    flow_permutation = self.get_flow_permutation(flow_permutation, opt)

    normOpt = opt_get(opt, ['network_G', 'flow', 'norm'])

    # Conditioning channels per level: the RRDB features plus, when
    # levelConditional.n_channels is set, bypass channels from deeper levels
    # (level 1 gets 2..L, level 2 gets 3..L, ..., level L gets none).
    n_rrdb = self.get_n_rrdb_channels(opt, opt_get)
    n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels'])
    conditional_channels = {0: n_rrdb}
    for level in range(1, self.L + 1):
        n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels
        conditional_channels[level] = n_rrdb + n_bypass

    # Upsampler
    for level in range(1, self.L + 1):
        # 1. Squeeze
        H, W = self.arch_squeeze(H, W)

        # 2. K FlowStep
        self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels, opt)
        self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling,
                           flow_permutation,
                           hidden_channels, normOpt, opt, opt_get,
                           n_conditinal_channels=conditional_channels[level])
        # 3. Split
        self.arch_split(H, W, level, self.L, opt, opt_get)

    # NOTE(review): the meaning of the 2 * 3 * 64 output width is not visible here.
    if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
        self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
    else:
        self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)

    self.H = H
    self.W = W
    self.scaleH = 160 / H
    self.scaleW = 160 / W
116 |
def get_n_rrdb_channels(self, opt, opt_get):
    """Channels of the RRDB conditioning features.

    64 when ``network_G.flow.stackRRDB.blocks`` is unset, otherwise
    64 per listed block plus 64.
    """
    stacked_blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
    if stacked_blocks is None:
        return 64
    return 64 * (len(stacked_blocks) + 1)
121 |
def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation,
                  hidden_channels, normOpt, opt, opt_get, n_conditinal_channels=None):
    """Append ``K`` FlowSteps working on ``self.C`` channels at spatial size (H, W).

    Each step also records ``[-1, self.C, H, W]`` in ``self.output_shapes``.
    The conditional-affine and norm option dicts are copied before being
    specialised, so the shared config is not modified and steps of different
    levels do not share one mutable dict.
    """
    condAff = self.get_condAffSetting(opt, opt_get)
    if condAff is not None:
        # opt_get returns the dict stored inside `opt`; copy it so this level's
        # value does not overwrite the config (and earlier levels' acOpt).
        condAff = dict(condAff)
        condAff['in_channels_rrdb'] = n_conditinal_channels

    for k in range(K):
        position_name = get_position_name(H, self.opt['scale'])
        # Per-step copy instead of writing 'position' into the shared normOpt.
        step_normOpt = dict(normOpt, position=position_name) if normOpt else normOpt

        self.layers.append(
            FlowStep(in_channels=self.C,
                     hidden_channels=hidden_channels,
                     actnorm_scale=actnorm_scale,
                     flow_permutation=flow_permutation,
                     flow_coupling=flow_coupling,
                     acOpt=condAff,
                     position=position_name,
                     LU_decomposed=LU_decomposed, opt=opt, idx=k, normOpt=step_normOpt))
        self.output_shapes.append(
            [-1, self.C, H, W])
143 |
def get_condAffSetting(self, opt, opt_get):
    """Return the conditional-affine options.

    ``network_G.flow.condAff`` is read first, then ``network_G.flow.condFtAffine``;
    the latter wins when truthy, else the former when truthy, else None.
    """
    cond_aff = opt_get(opt, ['network_G', 'flow', 'condAff'])
    if not cond_aff:
        cond_aff = None
    cond_ft_affine = opt_get(opt, ['network_G', 'flow', 'condFtAffine'])
    return cond_ft_affine if cond_ft_affine else cond_aff
148 |
def arch_split(self, H, W, L, levels, opt, opt_get):
    """Append a Split2d after level ``L`` when ``network_G.flow.split.enable`` is set.

    Without ``split.correct_splits`` the last two levels get no split; with
    it only the last level is skipped.  After a split ``self.C`` becomes the
    number of channels passed on to the next level.  Split types other than
    ``'Split2d'`` add nothing.
    """
    split_key = ['network_G', 'flow', 'split']
    correction = 0 if opt_get(opt, split_key + ['correct_splits'], False) else 1
    if not opt_get(opt, split_key + ['enable']) or L >= levels - correction:
        return

    logs_eps = opt_get(opt, split_key + ['logs_eps']) or 0
    consume_ratio = opt_get(opt, split_key + ['consume_ratio']) or 0.5
    position_name = get_position_name(H, self.opt['scale'])
    position = position_name if opt_get(opt, split_key + ['conditional']) else None
    cond_channels = opt_get(opt, split_key + ['cond_channels'])
    if cond_channels is None:
        cond_channels = 0

    split_type = opt_get(opt, split_key + ['type'], 'Split2d')
    if split_type != 'Split2d':
        return

    split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
                                         cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
    self.layers.append(split)
    self.output_shapes.append([-1, split.num_channels_pass, H, W])
    self.C = split.num_channels_pass
168 |
def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
    """Append ``network_G.flow.additionalFlowNoAffine`` extra FlowSteps, if configured.

    These steps use an invconv permutation and no coupling; each records
    ``[-1, self.C, H, W]`` in ``self.output_shapes``.
    """
    flow_cfg = opt['network_G']['flow']
    if 'additionalFlowNoAffine' not in flow_cfg:
        return
    n_extra = int(flow_cfg['additionalFlowNoAffine'])
    for _ in range(n_extra):
        extra_step = FlowStep(in_channels=self.C,
                              hidden_channels=hidden_channels,
                              actnorm_scale=actnorm_scale,
                              flow_permutation='invconv',
                              flow_coupling='noCoupling',
                              LU_decomposed=LU_decomposed, opt=opt)
        self.layers.append(extra_step)
        self.output_shapes.append([-1, self.C, H, W])
182 |
def arch_squeeze(self, H, W):
    """Append a factor-2 SqueezeLayer.

    ``self.C`` is multiplied by 4 and the returned height/width are halved
    (floor division); the new shape is recorded in ``self.output_shapes``.
    """
    half_h, half_w = H // 2, W // 2
    self.C *= 4
    self.layers.append(flow.SqueezeLayer(factor=2))
    self.output_shapes.append([-1, self.C, half_h, half_w])
    return half_h, half_w
188 |
def get_flow_permutation(self, flow_permutation, opt):
    """Return ``opt['network_G']['flow']['flow_permutation']``, defaulting to ``'invconv'``.

    NOTE(review): the ``flow_permutation`` argument is ignored; the config
    value always wins.
    """
    flow_cfg = opt['network_G']['flow']
    return flow_cfg.get('flow_permutation', 'invconv')
192 |
def get_affineInCh(self, opt_get):
    """Input channels of ``self.f``: 64 per ``stackRRDB`` block plus 64."""
    stacked_blocks = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
    return 64 * (len(stacked_blocks) + 1)
197 |
def check_image_shape(self):
    """Assert that the image has 1 or 3 channels (``self.C``)."""
    # The old message concatenated two literals without a separator and did
    # not report the offending value.
    assert self.C == 1 or self.C == 3, (
        "image_shape should be HWC, like (64, 64, 3); "
        "expected self.C == 1 or self.C == 3, got {}".format(self.C))
201 |
def forward(self, gt=None, rrdbResults=None, z=None, epses=None, logdet=0., reverse=False, eps_std=None,
            y_onehot=None):
    """Run the flow in either direction.

    reverse=False: ``encode(gt, rrdbResults, ...)``; ``gt`` and ``rrdbResults``
    are required.  reverse=True: ``decode(rrdbResults, z, eps_std, ...)``.
    A list ``epses`` is shallow-copied before decoding because ``decode``
    pops from it.

    Returns:
        (z or epses, logdet) when encoding, (sr, logdet) when decoding.
    """
    if not reverse:
        assert gt is not None
        assert rrdbResults is not None
        z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)
        return z, logdet

    eps_for_decode = list(epses) if isinstance(epses, list) else epses
    sr, logdet = self.decode(rrdbResults, z, eps_std, epses=eps_for_decode, logdet=logdet, y_onehot=y_onehot)
    return sr, logdet
216 |
def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
    """Forward (image -> latent) pass through every layer.

    Args:
        gt: HR image fed to the first layer.
        rrdbResults: mapping from the names in ``self.levelToName`` to the
            conditioning features of each level.
        logdet: initial log-determinant.
        epses: when a list, Split2d layers append their split-off latents to
            it and the final latent is appended last.
        y_onehot: forwarded to Split2d layers.

    Returns:
        (z, logdet), or (epses, logdet) when ``epses`` is a list.
    """
    fl_fea = gt
    reverse = False
    level_conditionals = {}

    for layer, shape in zip(self.layers, self.output_shapes):
        # Level = log2(160 / H) of this layer's static output height.
        # NOTE(review): 160 is assumed to be the HR size the net was built
        # for (cf. scaleH/scaleW in __init__) -- confirm.
        size = shape[2]
        level = int(np.log(160 / size) / np.log(2))
        level_conditionals[level] = rrdbResults[self.levelToName[level]]

        if isinstance(layer, FlowStep):
            fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level])
        elif isinstance(layer, Split2d):
            fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level],
                                                  y_onehot=y_onehot)
        else:
            fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse)

    z = fl_fea

    if not isinstance(epses, list):
        return z, logdet

    epses.append(z)
    return epses, logdet
252 |
def forward_preFlow(self, fl_fea, logdet, reverse):
    """Run the optional ``self.preFlow`` layers in order; identity when absent."""
    for pre_layer in getattr(self, 'preFlow', ()):
        fl_fea, logdet = pre_layer(fl_fea, logdet, reverse=reverse)
    return fl_fea, logdet
258 |
def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
    """Apply one Split2d layer in the encoding direction.

    When ``layer.position`` is set, ``rrdbResults[layer.position]`` is passed as ``ft``.
    The eps produced by the layer is appended to ``epses`` when ``epses`` is a list.

    Returns:
        ``(fl_fea, logdet)`` from the layer.
    """
    feature = rrdbResults[layer.position] if layer.position is not None else None
    out, new_logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=feature, y_onehot=y_onehot)

    if isinstance(epses, list):
        epses.append(eps)
    return out, new_logdet
266 |
def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None):
    """Run the flow layers in reverse order, mapping a latent back to an image.

    Args:
        rrdbResults: dict of encoder features, keyed by the names in ``self.levelToName``.
        z: top-level latent; replaced by ``epses.pop()`` when ``epses`` is a list.
        eps_std: standard deviation forwarded to Split2d layers.
        epses: optional list of latents. It is consumed from the end (the last element is
            the top-level latent, then one per Split2d in reverse order), so it is mutated.
        logdet: running log-determinant.
        y_onehot: optional label tensor forwarded to Split2d layers.

    Returns:
        ``(sr, logdet)``. ``sr`` is asserted to have 3 channels.
    """
    fl_fea = epses.pop() if isinstance(epses, list) else z

    for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)):
        size = shape[2]
        level = int(np.log(160 / size) / np.log(2))

        # The conditioning lookup mirrors encode(), so both directions see the same
        # features. The lookup is done lazily, only for layers that consume it.
        if isinstance(layer, Split2d):
            fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer,
                                                          rrdbResults[self.levelToName[level]], logdet=logdet,
                                                          y_onehot=y_onehot)
        elif isinstance(layer, FlowStep):
            fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True,
                                   rrdbResults=rrdbResults[self.levelToName[level]])
        else:
            fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True)

    sr = fl_fea

    assert sr.shape[1] == 3
    return sr, logdet
297 |
def forward_split2d_reverse(self, eps_std, epses, fl_fea, layer, rrdbResults, logdet, y_onehot=None):
    """Apply one Split2d layer in the decoding direction.

    The eps is taken from ``epses.pop()`` when ``epses`` is a list. Otherwise ``None`` is
    passed and the layer is left to sample one.
    """
    feature = rrdbResults[layer.position] if layer.position is not None else None
    popped_eps = epses.pop() if isinstance(epses, list) else None
    out, new_logdet = layer(fl_fea, logdet=logdet, reverse=True,
                            eps=popped_eps,
                            eps_std=eps_std, ft=feature, y_onehot=y_onehot)
    return out, new_logdet
304 |
305 |
def get_position_name(H, scale):
    """Return the encoder-feature key (e.g. ``'fea_up2'``) for a flow level of height ``H``.

    ``H`` is measured against the 160-pixel reference resolution. The upsampling ratio is
    ``scale / (160 // H)``. Integral ratios are formatted without a decimal point so the
    name matches keys such as ``'fea_up1'``, ``'fea_up2'`` and ``'fea_up4'``. Fractional
    ratios keep their float form (e.g. ``'fea_up0.5'``).
    """
    downscale_factor = 160 // H
    ratio = scale / downscale_factor
    if float(ratio).is_integer():
        # 8 / 4 == 2.0 would otherwise produce 'fea_up2.0', which is never a feature key.
        ratio = int(ratio)
    return 'fea_up{}'.format(ratio)
310 |
--------------------------------------------------------------------------------
/code/models/modules/Permutations.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import numpy as np
18 | import torch
19 | from torch import nn as nn
20 | from torch.nn import functional as F
21 |
22 | from models.modules import thops
23 |
24 |
class InvertibleConv1x1(nn.Module):
    """Invertible 1x1 convolution with a learned square channel-mixing matrix.

    The forward direction convolves with ``weight`` and adds ``log|det W| * pixels`` to
    ``logdet``. The reverse direction convolves with ``W^-1`` (inverted in float64) and
    subtracts the same amount.

    NOTE(review): ``LU_decomposed`` is stored in ``self.LU`` but no LU parameterisation
    is implemented here.
    """

    def __init__(self, num_channels, LU_decomposed=False):
        super().__init__()
        shape = [num_channels, num_channels]
        # Initialise with the Q factor of a Gaussian matrix: orthogonal, so |det| == 1.
        q_factor = np.linalg.qr(np.random.randn(*shape))[0].astype(np.float32)
        self.register_parameter("weight", nn.Parameter(torch.Tensor(q_factor)))
        self.w_shape = shape
        self.LU = LU_decomposed

    def get_weight(self, input, reverse):
        """Return the 1x1 conv kernel (W, or W^-1 when ``reverse``) and ``log|det W| * pixels``."""
        rows, cols = self.w_shape
        dlogdet = torch.slogdet(self.weight)[1] * thops.pixels(input)
        if reverse:
            # Invert in double precision for numerical stability, then cast back.
            matrix = torch.inverse(self.weight.double()).float()
        else:
            matrix = self.weight
        return matrix.view(rows, cols, 1, 1), dlogdet

    def forward(self, input, logdet=None, reverse=False):
        """
        log-det = log|abs(|W|)| * pixels
        """
        weight, dlogdet = self.get_weight(input, reverse)
        z = F.conv2d(input, weight)
        if logdet is not None:
            logdet = logdet - dlogdet if reverse else logdet + dlogdet
        return z, logdet
59 |
--------------------------------------------------------------------------------
/code/models/modules/RRDBNet_arch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import functools
18 | import torch
19 | import torch.nn as nn
20 | import torch.nn.functional as F
21 | import models.modules.module_util as mutil
22 | from utils.util import opt_get
23 |
24 |
class ResidualDenseBlock_5C(nn.Module):
    """Five-convolution densely connected block with a scaled residual connection.

    Each of conv1..conv4 sees the concatenation of the block input and every earlier
    activation, and produces ``gc`` (growth) channels. conv5 maps the full concatenation
    back to ``nf`` channels. The output is ``conv5(...) * 0.2 + x``.
    """

    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # Created in order conv1..conv5 so parameter names and init order are unchanged.
        for i in range(5):
            out_channels = gc if i < 4 else nf
            setattr(self, 'conv{}'.format(i + 1), nn.Conv2d(nf + i * gc, out_channels, 3, 1, 1, bias=bias))
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # Scale down the initial weights (factor 0.1) for the residual path.
        mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        features = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features.append(self.lrelu(conv(torch.cat(features, 1))))
        fused = self.conv5(torch.cat(features, 1))
        return fused * 0.2 + x
46 |
47 |
class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    # Three ResidualDenseBlock_5C applied in sequence; their output is scaled by 0.2
    # and added back to the block input.

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        hidden = x
        for dense_block in (self.RDB1, self.RDB2, self.RDB3):
            hidden = dense_block(hidden)
        return hidden * 0.2 + x
62 |
63 |
class RRDBNet(nn.Module):
    """RRDB trunk followed by nearest-neighbour + conv upsampling stages.

    ``forward(x)`` returns the output image. ``forward(x, get_steps=True)`` returns a dict
    of intermediate features:

    - ``'last_lr_fea'`` / ``'fea_up1'``: trunk features at LR resolution.
    - ``'fea_up2'``, ``'fea_up4'``, and ``'fea_up8'`` / ``'fea_up16'`` / ``'fea_up32'``
      for scale >= 8 / 16 / 32 (``None`` otherwise): pre-activation outputs of each
      upsampling conv.
    - optional ``'fea_up0'`` / ``'fea_up-1'`` (LR features downsampled by 2 / 4), when the
      corresponding ``network_G.flow`` options are set.
    - ``'block_<idx>'`` for every trunk index in ``network_G.flow.stackRRDB.blocks``.
    - ``'out'``.

    NOTE(review): two x2 upsampling stages are always built, so scale values below 4
    still upsample by 4 — confirm only scale >= 4 is used.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None):
        # nn.Module.__init__ must run before attributes are assigned on the module.
        super(RRDBNet, self).__init__()
        self.opt = opt
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
        self.scale = scale

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.scale >= 8:
            self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.scale >= 16:
            self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.scale >= 32:
            self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)

        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x, get_steps=False):
        fea = self.conv_first(x)

        block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
        block_results = {}

        for idx, m in enumerate(self.RRDB_trunk.children()):
            fea = m(fea)
            # Keep the outputs of the trunk blocks the flow wants to stack.
            if idx in block_idxs:
                block_results["block_{}".format(idx)] = fea

        trunk = self.trunk_conv(fea)

        last_lr_fea = fea + trunk

        fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(fea_up2)

        fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))
        fea = self.lrelu(fea_up4)

        fea_up8 = None
        fea_up16 = None
        fea_up32 = None

        if self.scale >= 8:
            fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up8)
        if self.scale >= 16:
            fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up16)
        if self.scale >= 32:
            fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest'))
            fea = self.lrelu(fea_up32)

        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        results = {'last_lr_fea': last_lr_fea,
                   'fea_up1': last_lr_fea,
                   'fea_up2': fea_up2,
                   'fea_up4': fea_up4,
                   'fea_up8': fea_up8,
                   'fea_up16': fea_up16,
                   'fea_up32': fea_up32,
                   'out': out}

        fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False
        if fea_up0_en:
            results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True)
        fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False
        if fea_upn1_en:
            results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True)

        if get_steps:
            results.update(block_results)
            return results
        else:
            return out
149 |
--------------------------------------------------------------------------------
/code/models/modules/SRFlowNet_arch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import math
18 |
19 | import torch
20 | import torch.nn as nn
21 | import torch.nn.functional as F
22 | import numpy as np
23 | from models.modules.RRDBNet_arch import RRDBNet
24 | from models.modules.FlowUpsamplerNet import FlowUpsamplerNet
25 | import models.modules.thops as thops
26 | import models.modules.flow as flow
27 | from utils.util import opt_get
28 |
29 |
class SRFlowNet(nn.Module):
    """LR-conditioned flow network: an RRDBNet encodes the LR image, and its features
    condition ``FlowUpsamplerNet``.

    - ``forward(gt, lr)`` (``reverse=False``) maps an HR image to latents and returns the
      negative log-likelihood in bits per pixel.
    - ``forward(lr=..., z=..., reverse=True)`` maps latents back to an HR image.
    """

    def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None):
        super(SRFlowNet, self).__init__()

        self.opt = opt
        # Number of intensity levels used for dequantisation noise; 255 unless configured.
        quant = opt_get(opt, ['datasets', 'train', 'quant'])
        self.quant = 255 if quant is None else quant
        self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt)
        hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels'])
        hidden_channels = hidden_channels or 64
        self.RRDB_training = True  # Default is true

        self.flowUpsamplerNet = \
            FlowUpsamplerNet((160, 160, 3), hidden_channels, K,
                             flow_coupling=opt['network_G']['flow']['coupling'], opt=opt)
        self.i = 0

    def set_rrdb_training(self, trainable):
        """Toggle ``requires_grad`` on all RRDB parameters. Return True if the state changed."""
        if self.RRDB_training != trainable:
            for p in self.RRDB.parameters():
                p.requires_grad = trainable
            self.RRDB_training = trainable
            return True
        return False

    def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False,
                lr_enc=None,
                add_gt_noise=False, step=None, y_label=None):
        """Encode (``reverse=False``) via ``normal_flow`` or decode via ``reverse_flow``.

        Decoding runs under ``torch.no_grad()`` unless ``reverse_with_grad`` is True.
        """
        if not reverse:
            return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step,
                                    y_onehot=y_label)
        else:
            assert lr.shape[1] == 3
            if reverse_with_grad:
                return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
                                         add_gt_noise=add_gt_noise)
            else:
                with torch.no_grad():
                    return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc,
                                             add_gt_noise=add_gt_noise)

    def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None):
        """Map ``gt`` to latents and compute the NLL in bits per pixel.

        Returns:
            ``(epses, nll, logdet)`` when ``epses`` is a list, else ``(z, nll, logdet)``.
        """
        if lr_enc is None:
            lr_enc = self.rrdbPreprocessing(lr)

        logdet = torch.zeros_like(gt[:, 0, 0, 0])
        pixels = thops.pixels(gt)

        z = gt

        if add_gt_noise:
            # Uniform dequantisation noise in [-0.5, 0.5) / quant. The -log(quant) per pixel
            # term is applied whether or not the noise itself is enabled.
            noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True)
            if noiseQuant:
                z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant)
            logdet = logdet + float(-np.log(self.quant) * pixels)

        # Encode
        epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses,
                                              y_onehot=y_onehot)

        objective = logdet.clone()

        if isinstance(epses, (list, tuple)):
            z = epses[-1]
        else:
            z = epses

        # Standard-normal prior on the top-level latent.
        objective = objective + flow.GaussianDiag.logp(None, None, z)

        nll = (-objective) / float(np.log(2.) * pixels)

        if isinstance(epses, list):
            return epses, nll, logdet
        return z, nll, logdet

    def rrdbPreprocessing(self, lr):
        """Run the RRDB encoder and optionally append stacked trunk-block features.

        When ``stackRRDB.concat`` is enabled and ``stackRRDB.blocks`` is non-empty, the
        selected block outputs are concatenated. The result is resized to each feature
        map's size and concatenated onto it along the channel dimension.
        """
        rrdbResults = self.RRDB(lr, get_steps=True)
        block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
        concat_enabled = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False

        # Without configured blocks there is nothing to stack. This avoids an
        # UnboundLocalError on `concat` when only the concat flag is set.
        if len(block_idxs) > 0 and concat_enabled:
            concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1)

            keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4']
            if 'fea_up0' in rrdbResults.keys():
                keys.append('fea_up0')
            if 'fea_up-1' in rrdbResults.keys():
                keys.append('fea_up-1')
            if self.opt['scale'] >= 8:
                keys.append('fea_up8')
            if self.opt['scale'] == 16:
                keys.append('fea_up16')
            for k in keys:
                h = rrdbResults[k].shape[2]
                w = rrdbResults[k].shape[3]
                rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1)
        return rrdbResults

    def get_score(self, disc_loss_sigma, z):
        score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * thops.sum(z ** 2, dim=[1, 2, 3]) - \
                     z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma)
        return -score_real

    def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True):
        """Decode latents to an HR image conditioned on ``lr``. Returns ``(x, logdet)``."""
        logdet = torch.zeros_like(lr[:, 0, 0, 0])
        pixels = thops.pixels(lr) * self.opt['scale'] ** 2

        if add_gt_noise:
            # Undo the dequantisation term added in normal_flow.
            logdet = logdet - float(-np.log(self.quant) * pixels)

        if lr_enc is None:
            lr_enc = self.rrdbPreprocessing(lr)

        # Forward the label so the reverse pass receives the same conditioning as encoding.
        x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses,
                                          logdet=logdet, y_onehot=y_onehot)

        return x, logdet
159 |
--------------------------------------------------------------------------------
/code/models/modules/Split.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import torch
18 | from torch import nn as nn
19 |
20 | from models.modules import thops
21 | from models.modules.FlowStep import FlowStep
22 | from models.modules.flow import Conv2dZeros, GaussianDiag
23 | from utils.util import opt_get
24 |
25 |
class Split2d(nn.Module):
    """Factor out part of the channels, modelling them with a learned conditional Gaussian.

    Channels ``[:num_channels_pass]`` (z1) pass through. The remaining channels (z2) are
    modelled by mean/log-scale predicted from z1, plus the optional feature ``ft``, through
    a zero-initialised conv.
    """

    def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None):
        super().__init__()

        consumed = int(round(num_channels * consume_ratio))
        self.num_channels_consume = consumed
        self.num_channels_pass = num_channels - consumed

        self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels,
                                out_channels=consumed * 2)
        self.logs_eps = logs_eps
        self.position = position
        self.opt = opt

    def split2d_prior(self, z, ft):
        """Predict ``(mean, logs)`` for the factored-out channels from ``z`` (and ``ft``)."""
        conditioning = z if ft is None else torch.cat([z, ft], dim=1)
        return thops.split_feature(self.conv(conditioning), "cross")

    def exp_eps(self, logs):
        """Standard deviation ``exp(logs)`` plus the constant ``logs_eps`` offset."""
        return torch.exp(logs) + self.logs_eps

    def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None):
        """Forward: return ``(z1, logdet, eps)``. Reverse: rebuild z2 from ``eps`` and return ``(z, logdet)``.

        In reverse mode, a missing ``eps`` is sampled with ``GaussianDiag.sample_eps``.
        """
        if reverse:
            z1 = input
            mean, logs = self.split2d_prior(z1, ft)

            if eps is None:
                eps = GaussianDiag.sample_eps(mean.shape, eps_std)

            eps = eps.to(mean.device)
            z2 = mean + self.exp_eps(logs) * eps

            merged = thops.cat_feature(z1, z2)
            logdet = logdet - self.get_logdet(logs, mean, z2)
            return merged, logdet

        z1, z2 = self.split_ratio(input)
        mean, logs = self.split2d_prior(z1, ft)
        eps = (z2 - mean) / self.exp_eps(logs)
        logdet = logdet + self.get_logdet(logs, mean, z2)
        return z1, logdet, eps

    def get_logdet(self, logs, mean, z2):
        """Gaussian log-density of ``z2`` under ``(mean, logs)``, summed per sample."""
        return GaussianDiag.logp(mean, logs, z2)

    def split_ratio(self, input):
        """Split ``input`` along channels into (pass-through, factored-out) parts."""
        cut = self.num_channels_pass
        return input[:, :cut, ...], input[:, cut:, ...]
--------------------------------------------------------------------------------
/code/models/modules/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreas128/SRFlow/5a007ad591c7be8bf32cf23171bfa4473e71683c/code/models/modules/__init__.py
--------------------------------------------------------------------------------
/code/models/modules/flow.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | import numpy as np
21 |
22 | from models.modules.FlowActNorms import ActNorm2d
23 | from . import thops
24 |
25 |
class Conv2d(nn.Conv2d):
    """Conv2d with named padding modes, N(0, weight_std) weight init, and optional ActNorm.

    If ``do_actnorm`` is True, the conv has no bias and its output goes through
    ``ActNorm2d``. Otherwise the bias is zero-initialised.
    """
    pad_dict = {
        "same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)],
        "valid": lambda kernel, stride: [0 for _ in kernel]
    }

    @staticmethod
    def get_padding(padding, kernel_size, stride):
        """Resolve a padding name ("same"/"valid", case-insensitive) to per-dimension amounts.

        Non-string ``padding`` is returned unchanged. Raises ValueError for unknown names.
        """
        if not isinstance(padding, str):
            return padding
        kernel = [kernel_size, kernel_size] if isinstance(kernel_size, int) else kernel_size
        strides = [stride, stride] if isinstance(stride, int) else stride
        mode = padding.lower()
        try:
            return Conv2d.pad_dict[mode](kernel, strides)
        except KeyError:
            raise ValueError("{} is not supported".format(mode))

    def __init__(self, in_channels, out_channels,
                 kernel_size=[3, 3], stride=[1, 1],
                 padding="same", do_actnorm=True, weight_std=0.05):
        resolved_padding = Conv2d.get_padding(padding, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride,
                         resolved_padding, bias=(not do_actnorm))
        self.weight.data.normal_(mean=0.0, std=weight_std)
        if do_actnorm:
            self.actnorm = ActNorm2d(out_channels)
        else:
            self.bias.data.zero_()
        self.do_actnorm = do_actnorm

    def forward(self, input):
        out = super().forward(input)
        if not self.do_actnorm:
            return out
        out, _ = self.actnorm(out)
        return out
66 |
67 |
class Conv2dZeros(nn.Conv2d):
    """Zero-initialised conv whose output is scaled by ``exp(logs * logscale_factor)``.

    Weight, bias and the per-channel ``logs`` all start at zero, so the layer outputs zeros
    at initialisation.
    """

    def __init__(self, in_channels, out_channels,
                 kernel_size=[3, 3], stride=[1, 1],
                 padding="same", logscale_factor=3):
        resolved_padding = Conv2d.get_padding(padding, kernel_size, stride)
        super().__init__(in_channels, out_channels, kernel_size, stride, resolved_padding)
        self.logscale_factor = logscale_factor
        # Learned per-output-channel log-scale, broadcast over H and W.
        self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1)))
        self.weight.data.zero_()
        self.bias.data.zero_()

    def forward(self, input):
        gain = torch.exp(self.logs * self.logscale_factor)
        return super().forward(input) * gain
84 |
85 |
class GaussianDiag:
    """Diagonal Gaussian helpers parameterised by ``mean`` and ``logs`` (log std-dev).

    The variance is ``exp(2 * logs)``. ``None`` for ``mean`` means zero mean, and ``None``
    for ``logs`` means unit standard deviation.
    """
    Log2PI = float(np.log(2 * np.pi))

    @staticmethod
    def likelihood(mean, logs, x):
        """
        Element-wise log-density of ``x``:
        lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) }
        k = 1 (Independent)
        Var = exp(logs * 2)
        """
        if mean is None and logs is None:
            return -0.5 * (x ** 2 + GaussianDiag.Log2PI)
        if mean is None:
            mean = 0.
        if logs is None:
            # Unit variance: ln|Var| == 0 and Var^-1 == 1.
            return -0.5 * (((x - mean) ** 2) + GaussianDiag.Log2PI)
        return -0.5 * (logs * 2. + ((x - mean) ** 2) / torch.exp(logs * 2.) + GaussianDiag.Log2PI)

    @staticmethod
    def logp(mean, logs, x):
        """Log-density summed over dims 1, 2, 3 (one value per batch element)."""
        likelihood = GaussianDiag.likelihood(mean, logs, x)
        return thops.sum(likelihood, dim=[1, 2, 3])

    @staticmethod
    def sample(mean, logs, eps_std=None):
        """Draw ``mean + exp(logs) * eps`` with ``eps ~ N(0, eps_std^2)``.

        ``eps_std=None`` means 1. ``eps_std=0`` is honoured and returns ``mean``.
        """
        if eps_std is None:
            eps_std = 1
        eps = torch.normal(mean=torch.zeros_like(mean),
                           std=torch.ones_like(logs) * eps_std)
        return mean + torch.exp(logs) * eps

    @staticmethod
    def sample_eps(shape, eps_std, seed=None):
        """Draw a tensor of ``shape`` from N(0, eps_std^2); ``eps_std=None`` means 1.

        If ``seed`` is given, the global torch RNG is reseeded first (a global side effect).
        """
        if seed is not None:
            torch.manual_seed(seed)
        if eps_std is None:
            # Callers such as Split2d forward an unset eps_std; previously this raised TypeError.
            eps_std = 1
        eps = torch.normal(mean=torch.zeros(shape),
                           std=torch.ones(shape) * eps_std)
        return eps
120 |
121 |
def squeeze2d(input, factor=2):
    """Space-to-depth: fold each ``factor x factor`` spatial patch into channels.

    (B, C, H, W) -> (B, C * factor**2, H // factor, W // factor). ``factor == 1`` returns
    ``input`` itself. H and W must be divisible by ``factor``.
    """
    assert factor >= 1 and isinstance(factor, int)
    if factor == 1:
        return input
    dims = input.size()
    B, C, H, W = dims[0], dims[1], dims[2], dims[3]
    assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor))
    h_out, w_out = H // factor, W // factor
    blocks = input.view(B, C, h_out, factor, w_out, factor)
    # Move the two intra-patch offsets next to the channel axis.
    blocks = blocks.permute(0, 1, 3, 5, 2, 4).contiguous()
    return blocks.view(B, C * factor * factor, h_out, w_out)
136 |
137 |
def unsqueeze2d(input, factor=2):
    """Depth-to-space: the inverse of ``squeeze2d``.

    (B, C, H, W) -> (B, C // factor**2, H * factor, W * factor). ``factor == 1`` returns
    ``input`` itself. C must be divisible by ``factor**2``.
    """
    assert factor >= 1 and isinstance(factor, int)
    patch = factor ** 2
    if factor == 1:
        return input
    dims = input.size()
    B, C, H, W = dims[0], dims[1], dims[2], dims[3]
    assert C % (patch) == 0, "{}".format(C)
    unfolded = input.view(B, C // patch, factor, factor, H, W)
    # Interleave the intra-patch offsets back into the spatial axes.
    unfolded = unfolded.permute(0, 1, 4, 2, 5, 3).contiguous()
    return unfolded.view(B, C // (patch), H * factor, W * factor)
153 |
154 |
class SqueezeLayer(nn.Module):
    """Flow layer wrapping squeeze2d (forward) / unsqueeze2d (reverse).

    The transform is a pure reshuffle of elements, so ``logdet`` is passed through unchanged.
    """

    def __init__(self, factor):
        super().__init__()
        self.factor = factor

    def forward(self, input, logdet=None, reverse=False):
        reshuffle = unsqueeze2d if reverse else squeeze2d
        return reshuffle(input, self.factor), logdet
167 |
--------------------------------------------------------------------------------
/code/models/modules/glow_arch.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import torch.nn as nn
18 |
19 |
def f_conv2d_bias(in_channels, out_channels):
    """Return ``nn.Sequential`` holding one 3x3, stride-1, padding-1 conv with bias."""
    kernel, strides = [3, 3], [1, 1]
    # Glow "same" padding rule; for a 3x3 stride-1 kernel it must come out as 1 per side.
    padding = [((k - 1) * s + 1) // 2 for k, s in zip(kernel, strides)]
    assert padding == [1, 1], padding
    conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=[3, 3], stride=1, padding=1,
                     bias=True)
    return nn.Sequential(conv)
29 |
--------------------------------------------------------------------------------
/code/models/modules/loss.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import torch
18 | import torch.nn as nn
19 |
20 |
class CharbonnierLoss(nn.Module):
    """Charbonnier loss (a smooth L1): ``sum(sqrt((x - y)^2 + eps))``, summed, not averaged."""

    def __init__(self, eps=1e-6):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        residual = x - y
        return torch.sqrt(residual * residual + self.eps).sum()
32 |
33 |
34 | # Define GAN loss: [vanilla | lsgan | wgan-gp]
class GANLoss(nn.Module):
    """Adversarial loss for the 'gan'/'ragan' (BCE with logits), 'lsgan' (MSE) or 'wgan-gp' objectives.

    ``gan_type`` is case-insensitive. For 'wgan-gp', the target is the boolean
    ``target_is_real`` itself and the loss is ``-mean(input)`` for real and ``mean(input)``
    for fake.
    """

    def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0):
        super(GANLoss, self).__init__()
        self.gan_type = gan_type.lower()
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        kind = self.gan_type
        if kind in ('gan', 'ragan'):
            self.loss = nn.BCEWithLogitsLoss()
        elif kind == 'lsgan':
            self.loss = nn.MSELoss()
        elif kind == 'wgan-gp':

            def wgan_loss(input, target):
                # target is a bool: True -> real sample.
                return input.mean() if not target else -1 * input.mean()

            self.loss = wgan_loss
        else:
            raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type))

    def get_target_label(self, input, target_is_real):
        """Build a label tensor shaped like ``input`` (or pass the bool through for wgan-gp)."""
        if self.gan_type == 'wgan-gp':
            return target_is_real
        fill_value = self.real_label_val if target_is_real else self.fake_label_val
        return torch.empty_like(input).fill_(fill_value)

    def forward(self, input, target_is_real):
        return self.loss(input, self.get_target_label(input, target_is_real))
68 |
69 |
class GradientPenaltyLoss(nn.Module):
    """WGAN-GP gradient penalty: ``mean((||d critic / d interp||_2 - 1)^2)`` over the batch.

    A ones tensor used as ``grad_outputs`` is cached in the ``grad_outputs`` buffer. It is
    resized and refilled only when the critic output's size changes.
    """

    def __init__(self, device=torch.device('cpu')):
        super(GradientPenaltyLoss, self).__init__()
        self.register_buffer('grad_outputs', torch.Tensor())
        self.grad_outputs = self.grad_outputs.to(device)

    def get_grad_outputs(self, input):
        cached = self.grad_outputs
        if cached.size() == input.size():
            return cached
        cached.resize_(input.size()).fill_(1.0)
        return self.grad_outputs

    def forward(self, interp, interp_crit):
        ones = self.get_grad_outputs(interp_crit)
        grads = torch.autograd.grad(outputs=interp_crit, inputs=interp,
                                    grad_outputs=ones, create_graph=True,
                                    retain_graph=True, only_inputs=True)[0]
        # One L2 norm per sample over all non-batch dimensions.
        per_sample_norm = grads.view(grads.size(0), -1).norm(2, dim=1)
        return ((per_sample_norm - 1) ** 2).mean()
91 |
--------------------------------------------------------------------------------
/code/models/modules/module_util.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.init as init
20 | import torch.nn.functional as F
21 |
22 |
def initialize_weights(net_l, scale=1):
    """Initialise Conv2d / Linear / BatchNorm2d parameters of one or more networks in place.

    Conv2d and Linear weights get Kaiming-normal init (``a=0``, ``mode='fan_in'``)
    and are then multiplied by ``scale``, e.g. 0.1 to damp residual branches.
    Their biases, when present, are zeroed. BatchNorm2d affine parameters are
    reset to weight=1, bias=0. A BatchNorm2d built with ``affine=False`` has no
    parameters and is skipped; it used to crash here.

    Args:
        net_l (nn.Module | list[nn.Module]): module or list of modules to initialise.
        scale (float): multiplier applied to Conv2d/Linear weights after init.
    """
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                if m.weight is not None:
                    init.constant_(m.weight, 1)
                if m.bias is not None:
                    init.constant_(m.bias.data, 0.0)
41 |
42 |
def make_layer(block, n_layers):
    """Stack ``n_layers`` freshly constructed ``block()`` modules into an ``nn.Sequential``.

    Args:
        block (callable): zero-argument factory returning an ``nn.Module``.
        n_layers (int): number of instances to create.
    """
    return nn.Sequential(*(block() for _ in range(n_layers)))
48 |
49 |
class ResidualBlock_noBN(nn.Module):
    '''Residual block without batch normalisation.

        x --Conv3x3--ReLU--Conv3x3--(+)--> out
        |____________________________|

    Both convolutions map ``nf`` -> ``nf`` channels with kernel 3, stride 1 and
    padding 1. They are initialised via ``initialize_weights(..., 0.1)``.

    Args:
        nf (int): number of feature channels.
    '''

    def __init__(self, nf=64):
        super(ResidualBlock_noBN, self).__init__()
        self.conv1 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=True)
        initialize_weights([self.conv1, self.conv2], 0.1)

    def forward(self, x):
        residual = self.conv2(F.relu(self.conv1(x), inplace=True))
        return x + residual
69 |
70 |
def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros', align_corners=True):
    """Warp an image or feature map with optical flow.

    Args:
        x (Tensor): size (N, C, H, W).
        flow (Tensor): size (N, H, W, 2), displacement in pixels. Channel 0 is the
            x (width) offset and channel 1 the y (height) offset.
        interp_mode (str): 'nearest' or 'bilinear'.
        padding_mode (str): 'zeros' or 'border' or 'reflection'.
        align_corners (bool): forwarded to ``F.grid_sample``. The normalisation
            below maps pixel centres 0 and W-1 to -1 and +1, which is the
            ``align_corners=True`` convention. With grid_sample's default of False,
            a zero flow would not reproduce the input.

    Returns:
        Tensor: warped image or feature map, size (N, C, H, W).
    """
    assert x.size()[-2:] == flow.size()[1:3]
    _, _, H, W = x.size()
    # mesh grid of absolute pixel coordinates, created directly on x's device
    grid_y, grid_x = torch.meshgrid(torch.arange(0, H, device=x.device),
                                    torch.arange(0, W, device=x.device))
    grid = torch.stack((grid_x, grid_y), 2).type_as(x)  # (H, W, 2) in (x, y) order
    vgrid = grid + flow
    # scale grid to [-1, 1]
    vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0
    vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0
    vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3)
    return F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode,
                         align_corners=align_corners)
96 |
--------------------------------------------------------------------------------
/code/models/modules/thops.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 |
17 | import torch
18 |
19 |
def sum(tensor, dim=None, keepdim=False):
    """Sum ``tensor`` over one or several dimensions.

    Args:
        tensor (Tensor): input.
        dim (int | list[int] | None): dimension(s) to reduce. ``None`` reduces all
            of them and returns a 0-d tensor. Negative indices are allowed and
            duplicates are ignored.
        keepdim (bool): keep reduced dimensions with size 1.

    Returns:
        Tensor: the reduced tensor.
    """
    if dim is None:
        # sum up all dim
        return torch.sum(tensor)
    if isinstance(dim, int):
        dim = [dim]
    # Normalise negative indices before sorting. Otherwise e.g. [1, -1] sorts to
    # [-1, 1] and the squeeze offsets below remove the wrong axis. Out-of-range
    # values are left alone so torch still reports them.
    ndim = tensor.dim()
    dim = sorted({d + ndim if -ndim <= d < 0 else d for d in dim})
    for d in dim:
        tensor = tensor.sum(dim=d, keepdim=True)
    if not keepdim:
        # each squeeze shifts the later (larger) axes left by one
        for i, d in enumerate(dim):
            tensor.squeeze_(d - i)
    return tensor
34 |
35 |
def mean(tensor, dim=None, keepdim=False):
    """Average ``tensor`` over one or several dimensions.

    Args:
        tensor (Tensor): input.
        dim (int | list[int] | None): dimension(s) to reduce. ``None`` reduces all
            of them and returns a 0-d tensor. Negative indices are allowed and
            duplicates are ignored.
        keepdim (bool): keep reduced dimensions with size 1.

    Returns:
        Tensor: the reduced tensor.
    """
    if dim is None:
        # mean all dim
        return torch.mean(tensor)
    if isinstance(dim, int):
        dim = [dim]
    # Normalise negative indices before sorting, so the squeeze offsets below
    # target the right axes. Out-of-range values are left for torch to reject.
    ndim = tensor.dim()
    dim = sorted({d + ndim if -ndim <= d < 0 else d for d in dim})
    for d in dim:
        tensor = tensor.mean(dim=d, keepdim=True)
    if not keepdim:
        # each squeeze shifts the later (larger) axes left by one
        for i, d in enumerate(dim):
            tensor.squeeze_(d - i)
    return tensor
50 |
51 |
def split_feature(tensor, type="split"):
    """Split ``tensor`` into two parts along the channel dimension (dim 1).

    Args:
        tensor (Tensor): input with channels on dim 1.
        type (str): ``"split"`` gives the first half and second half of the
            channels; ``"cross"`` gives the even-indexed and odd-indexed channels.

    Returns:
        tuple[Tensor, Tensor]: the two parts (slices of ``tensor``).

    Raises:
        ValueError: for any other ``type``. This previously returned ``None``,
            which only failed later when the caller unpacked it.
    """
    C = tensor.size(1)
    if type == "split":
        return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...]
    if type == "cross":
        return tensor[:, 0::2, ...], tensor[:, 1::2, ...]
    raise ValueError('Unknown split type "{}", expected "split" or "cross"'.format(type))
61 |
62 |
def cat_feature(tensor_a, tensor_b):
    """Concatenate two tensors along the channel dimension (dim 1)."""
    pair = (tensor_a, tensor_b)
    return torch.cat(pair, dim=1)
65 |
66 |
def pixels(tensor):
    """Return ``size(2) * size(3)`` as a Python int (H * W for an N, C, H, W tensor)."""
    height, width = tensor.size(2), tensor.size(3)
    return int(height * width)
--------------------------------------------------------------------------------
/code/models/networks.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | import importlib
18 |
19 | import torch
20 | import logging
21 | import models.modules.RRDBNet_arch as RRDBNet_arch
22 |
# Module logger; train.py configures and uses the same 'base' logger.
logger = logging.getLogger(name='base')
24 |
25 |
def find_model_using_name(model_name):
    """Return the attribute of ``models.modules.<model_name>_arch`` naming the network class.

    The attribute name is matched case-insensitively against ``model_name``
    with every ``'_Net'`` substring removed. If several attributes match
    (names differing only in case), the last one in the module namespace wins.

    Raises:
        NotImplementedError: if nothing matches. The function used to print a
            message and call ``exit(0)``, which ended the process with a success
            status.
    """
    model_filename = "models.modules." + model_name + "_arch"
    modellib = importlib.import_module(model_filename)

    model = None
    target_model_name = model_name.replace('_Net', '')
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower():
            model = cls

    if model is None:
        raise NotImplementedError(
            "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." % (
                model_filename, target_model_name))

    return model
43 |
44 |
45 | ####################
46 | # define network
47 | ####################
48 | #### Generator
def define_G(opt):
    """Build the generator named by ``opt['network_G']['which_model_G']``.

    * ``'RRDBNet'``: ``RRDBNet_arch.RRDBNet`` with in_nc/out_nc/nf/nb from
      ``opt['network_G']``, ``scale=opt['scale']`` and the full ``opt``.
    * ``'EDSRNet'``: class from ``find_model_using_name``, built with ``scale``.
    * ``'rankSRGAN'``: class from ``find_model_using_name``, built with ``upscale``.

    Raises:
        NotImplementedError: for any other name.
    """
    net_opt = opt['network_G']
    arch_name = net_opt['which_model_G']

    if arch_name == 'RRDBNet':
        return RRDBNet_arch.RRDBNet(in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'],
                                    nf=net_opt['nf'], nb=net_opt['nb'], scale=opt['scale'], opt=opt)
    if arch_name == 'EDSRNet':
        return find_model_using_name(arch_name)(scale=opt['scale'])
    if arch_name == 'rankSRGAN':
        return find_model_using_name(arch_name)(upscale=opt['scale'])
    # An 'sft_arch' (SFT-GAN) branch existed only as commented-out code.
    raise NotImplementedError('Generator model [{:s}] not recognized'.format(arch_name))
67 |
68 |
def define_Flow(opt, step):
    """Instantiate the flow generator named by ``opt['network_G']['which_model_G']``.

    The class is resolved with ``find_model_using_name``. It is constructed from
    ``opt['network_G']`` (in_nc, out_nc, nf, nb, flow.K), ``opt['scale']``, the
    full ``opt`` and the current training ``step``.
    """
    net_opt = opt['network_G']
    arch_cls = find_model_using_name(net_opt['which_model_G'])
    ctor_kwargs = dict(in_nc=net_opt['in_nc'], out_nc=net_opt['out_nc'],
                       nf=net_opt['nf'], nb=net_opt['nb'], scale=opt['scale'],
                       K=net_opt['flow']['K'], opt=opt, step=step)
    return arch_cls(**ctor_kwargs)
78 |
79 |
80 | #### Discriminator
def define_D(opt):
    """Build the discriminator named by ``opt['network_D']['which_model_D']``.

    Only ``'discriminator_vgg_128'`` is supported. It is built from
    ``models.modules.SRGAN_arch.Discriminator_VGG_128`` with ``in_nc``/``nf``
    and ``hidden_units`` (default 8192) from ``opt['network_D']``.

    Raises:
        NotImplementedError: for an unknown discriminator name.
        ImportError: if ``models.modules.SRGAN_arch`` is not available.
    """
    opt_net = opt['network_D']
    which_model = opt_net['which_model_D']

    if which_model == 'discriminator_vgg_128':
        # SRGAN_arch is not imported at module level (referencing it was a
        # guaranteed NameError); import it only when a discriminator is requested.
        SRGAN_arch = importlib.import_module('models.modules.SRGAN_arch')
        hidden_units = opt_net.get('hidden_units', 8192)
        netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'],
                                                hidden_units=hidden_units)
    else:
        raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
    return netD
91 |
92 |
93 | #### Define Network used for Perceptual Loss
def define_F(opt, use_bn=False):
    """Build the frozen VGG feature extractor used for the perceptual loss.

    Uses feature layer 49 when ``use_bn`` is true, else 34 (VGG19-54 before
    ReLU, per the original comment). The ``device`` handed to the extractor is
    CUDA when ``opt['gpu_ids']`` is truthy, otherwise CPU. The network is put
    in eval mode.

    Raises:
        ImportError: if ``models.modules.SRGAN_arch`` is not available.
    """
    gpu_ids = opt.get('gpu_ids', None)
    device = torch.device('cuda' if gpu_ids else 'cpu')
    # PyTorch pretrained_models VGG19-54, before ReLU.
    feature_layer = 49 if use_bn else 34
    # SRGAN_arch is not imported at module level (referencing it was a
    # guaranteed NameError); import it on demand.
    SRGAN_arch = importlib.import_module('models.modules.SRGAN_arch')
    netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn,
                                          use_input_norm=True, device=device)
    netF.eval()  # No need to train
    return netF
106 |
--------------------------------------------------------------------------------
/code/options/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreas128/SRFlow/5a007ad591c7be8bf32cf23171bfa4473e71683c/code/options/__init__.py
--------------------------------------------------------------------------------
/code/options/options.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | import os
18 | import os.path as osp
19 | import logging
20 | import yaml
21 | from utils.util import OrderedYaml
22 |
# YAML loader/dumper pair from utils.util.OrderedYaml (presumably order-preserving;
# confirm in utils.util). parse() reads option files with this Loader.
Loader, Dumper = tuple(OrderedYaml())
24 |
25 |
def parse(opt_path, is_train=True):
    """Load a YAML option file and fill in derived settings.

    After reading ``opt_path`` with ``Loader`` this:
      * sets ``opt['is_train']``;
      * for every dataset, sets ``phase`` (key prefix before ``'_'``), copies
        ``scale`` when ``distortion == 'sr'``, expands ``~`` in
        ``dataroot_GT``/``dataroot_LQ`` and derives ``data_type``
        (``'lmdb'``, ``'img'`` or ``'mc'``);
      * expands ``~`` in truthy ``opt['path']`` entries except ``strict_load``,
        and sets ``opt['path']['root']`` to three levels above this file;
      * training: derives experiment/model/state/log/val-image directories and
        shrinks frequencies for names containing ``'debug'``; testing: derives
        ``results_root``/``log``;
      * copies ``scale`` into ``opt['network_G']`` for ``'sr'``;
      * converts ``*_rel`` schedule entries in ``opt['train']`` into absolute
        iteration counts, multiplied by ``niter``.

    Args:
        opt_path (str): path to the YAML option file.
        is_train (bool): whether the options are used for training.

    Returns:
        dict: the option tree.
    """
    with open(opt_path, mode='r') as f:
        opt = yaml.load(f, Loader=Loader)
    # The CUDA_VISIBLE_DEVICES export is intentionally disabled. The unused
    # ','.join(...) over opt['gpu_ids'] that was computed here crashed with a
    # TypeError when the YAML set ``gpu_ids: ~``, so it has been removed.
    opt['is_train'] = is_train
    is_sr = opt['distortion'] == 'sr'
    if is_sr:
        scale = opt['scale']

    # datasets
    for phase, dataset in opt['datasets'].items():
        dataset['phase'] = phase.split('_')[0]
        if is_sr:
            dataset['scale'] = scale
        is_lmdb = False
        for root_key in ('dataroot_GT', 'dataroot_LQ'):
            if dataset.get(root_key, None) is not None:
                dataset[root_key] = osp.expanduser(dataset[root_key])
                if dataset[root_key].endswith('lmdb'):
                    is_lmdb = True
        dataset['data_type'] = 'lmdb' if is_lmdb else 'img'
        if dataset['mode'].endswith('mc'):  # for memcached
            dataset['data_type'] = 'mc'
            dataset['mode'] = dataset['mode'].replace('_mc', '')

    # path
    for key, path in opt['path'].items():
        if path and key != 'strict_load':
            opt['path'][key] = osp.expanduser(path)
    opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
    if is_train:
        experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name'])
        opt['path']['experiments_root'] = experiments_root
        opt['path']['models'] = osp.join(experiments_root, 'models')
        opt['path']['training_state'] = osp.join(experiments_root, 'training_state')
        opt['path']['log'] = experiments_root
        opt['path']['val_images'] = osp.join(experiments_root, 'val_images')

        # change some options for debug mode
        if 'debug' in opt['name']:
            opt['train']['val_freq'] = 8
            opt['logger']['print_freq'] = 1
            opt['logger']['save_checkpoint_freq'] = 8
    else:  # test
        if not opt['path'].get('results_root', None):
            opt['path']['results_root'] = osp.join(opt['path']['root'], 'results', opt['name'])
        opt['path']['log'] = opt['path']['results_root']

    # network
    if is_sr:
        opt['network_G']['scale'] = scale

    # relative learning-rate schedule -> absolute iteration counts
    if 'train' in opt:
        niter = opt['train']['niter']
        for rel_key, abs_key in (('T_period_rel', 'T_period'),
                                 ('restarts_rel', 'restarts'),
                                 ('lr_steps_rel', 'lr_steps'),
                                 ('lr_steps_inverse_rel', 'lr_steps_inverse')):
            if rel_key in opt['train']:
                opt['train'][abs_key] = [int(x * niter) for x in opt['train'][rel_key]]
        print(opt['train'])

    return opt
99 |
100 |
def dict2str(opt, indent_l=1):
    '''Render a (nested) option dict as an indented multi-line string for logging.

    Every nesting level adds two spaces of indentation. Nested dicts are wrapped
    as ``key:[`` ... ``]``.
    '''
    pad = ' ' * (indent_l * 2)
    parts = []
    for key, value in opt.items():
        if not isinstance(value, dict):
            parts.append(pad + key + ': ' + str(value) + '\n')
            continue
        parts.append(pad + key + ':[\n')
        parts.append(dict2str(value, indent_l + 1))
        parts.append(pad + ']\n')
    return ''.join(parts)
112 |
113 |
class NoneDict(dict):
    """dict whose ``d[key]`` returns None for absent keys instead of raising KeyError."""

    def __missing__(self, _key):
        # Invoked by dict.__getitem__ only when the key is absent.
        return None
117 |
118 |
# Convert to NoneDict, which returns None for missing keys.
def dict_to_nonedict(opt):
    """Recursively convert every dict in ``opt``, including dicts inside lists, into a ``NoneDict``.

    Other values are returned unchanged and key order is preserved. Keys do not
    have to be strings: the items are passed positionally to ``NoneDict``.
    The previous ``NoneDict(**new_opt)`` raised TypeError for, e.g., integer keys.
    """
    if isinstance(opt, dict):
        return NoneDict((key, dict_to_nonedict(sub_opt)) for key, sub_opt in opt.items())
    if isinstance(opt, list):
        return [dict_to_nonedict(sub_opt) for sub_opt in opt]
    return opt
130 |
131 |
def check_resume(opt, resume_iter):
    '''Check resume states and pretrain_model paths.

    Only acts when ``opt['path']['resume_state']`` is truthy. It warns if explicit
    pretrain model paths were configured, since they are overridden. It then points
    ``pretrain_model_G`` at ``<models>/<resume_iter>_G.pth``, and for model names
    containing ``'gan'`` also ``pretrain_model_D`` at ``<models>/<resume_iter>_D.pth``.
    '''
    logger = logging.getLogger('base')
    paths = opt['path']
    if not paths['resume_state']:
        return

    has_explicit = (paths.get('pretrain_model_G', None) is not None
                    or paths.get('pretrain_model_D', None) is not None)
    if has_explicit:
        logger.warning('pretrain_model path will be ignored when resuming training.')

    paths['pretrain_model_G'] = osp.join(paths['models'], '{}_G.pth'.format(resume_iter))
    logger.info('Set [pretrain_model_G] to ' + paths['pretrain_model_G'])
    if 'gan' in opt['model']:
        paths['pretrain_model_D'] = osp.join(paths['models'], '{}_D.pth'.format(resume_iter))
        logger.info('Set [pretrain_model_D] to ' + paths['pretrain_model_D'])
147 |
--------------------------------------------------------------------------------
/code/prepare_data.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | import glob
16 | import os
17 | import sys
18 |
19 | import numpy as np
20 | import random
21 | import imageio
22 | import pickle
23 |
24 | from natsort import natsort
25 | from tqdm import tqdm
26 |
def get_img_paths(dir_path, wildcard='*.png'):
    """Naturally sorted paths of files directly inside ``dir_path`` that match ``wildcard``."""
    pattern = dir_path + '/' + wildcard
    matches = glob.glob(pattern)
    return natsort.natsorted(matches)
29 |
def create_all_dirs(path):
    """Create the directory for ``path``, including missing parents.

    If the last ``/``-separated component contains a dot, ``path`` is treated
    as a file path and its parent directory is created; otherwise ``path``
    itself is created. Existing directories are accepted. A bare file name with
    no directory part is a no-op, because ``os.makedirs('')`` would raise.
    """
    if "." in path.split("/")[-1]:
        dirs = os.path.dirname(path)
    else:
        dirs = path
    if dirs:
        os.makedirs(dirs, exist_ok=True)
36 |
def to_pklv4(obj, path, vebose=False):
    """Pickle ``obj`` to ``path`` with protocol 4, creating its directory first.

    When ``vebose`` (sic) is true, the written path is printed.
    """
    create_all_dirs(path)
    with open(path, 'wb') as handle:
        pickle.dump(obj, handle, protocol=4)
    if not vebose:
        return
    print("Wrote {}".format(path))
43 |
44 |
45 | from imresize import imresize
46 |
def random_crop(img, size):
    """Return a uniformly random ``size`` x ``size`` crop of an H x W (x C) image.

    The result is a slice (view) of ``img``. Every valid top-left corner,
    including ``h - size`` and ``w - size``, can be chosen. Previously the
    exclusive upper bound of ``np.random.randint`` excluded the last offset and
    made ``size == h`` (or ``w``) raise.

    Raises:
        ValueError: if the image is smaller than ``size`` in either dimension.
    """
    h, w = img.shape[:2]
    if h < size or w < size:
        raise ValueError('Crop size {} exceeds image size {}x{}'.format(size, h, w))

    h_start = np.random.randint(0, h - size + 1)
    w_start = np.random.randint(0, w - size + 1)
    return img[h_start:h_start + size, w_start:w_start + size]
57 |
58 |
def imread(img_path):
    """Read an image with imageio; a 2-D (single-channel) image is replicated to 3 channels."""
    img = imageio.imread(img_path)
    if img.ndim != 2:
        return img
    return np.stack([img] * 3, axis=2)
64 |
65 |
def to_pklv4_1pct(obj, path, vebose):
    """Write the first 1% (rounded) of ``obj`` to ``<path without ext>_1pct<ext>``.

    Only the file extension gets the suffix. The previous
    ``path.replace(".", "_1pct.")`` rewrote every dot, corrupting paths such as
    ``./data/x.pklv4`` or directories with dots in their names. A path without
    an extension was not renamed at all, so the full dump was overwritten.
    ``vebose`` is now forwarded instead of being ignored.
    """
    n = int(round(len(obj) * 0.01))
    root, ext = os.path.splitext(path)
    to_pklv4(obj[:n], root + "_1pct" + ext, vebose=vebose)
70 |
71 |
def main(dir_path):
    """Build paired HR/LR training pickles from every ``*.png`` in ``dir_path``.

    From each image, 47 random 160x160 HR crops are taken. Each crop is
    downscaled with ``imresize(..., scalar_scale=0.25)``. The (HR, LR) pairs are
    shuffled together and written to ``get_hrs_path(dir_path)`` and
    ``get_lqs_path(dir_path)``, each followed by a 1% subset via ``to_pklv4_1pct``.
    """
    hrs = []
    lqs = []

    for img_path in tqdm(get_img_paths(dir_path)):
        img = imread(img_path)
        for _ in range(47):
            hr_crop = random_crop(img, 160)
            lq_crop = imresize(hr_crop, scalar_scale=0.25)
            hrs.append(hr_crop)
            lqs.append(lq_crop)

    shuffle_combined(hrs, lqs)

    for samples, path_for in ((hrs, get_hrs_path), (lqs, get_lqs_path)):
        out_path = path_for(dir_path)
        to_pklv4(samples, out_path, vebose=True)
        to_pklv4_1pct(samples, out_path, vebose=True)
95 |
96 |
def get_hrs_path(dir_path):
    """Path of the HR pickle for ``dir_path``: ``<parent>/pkls/<name>.pklv4``.

    ``dir_path`` is normalised first so that a trailing separator
    (``data/train/``) still gives ``data/pkls/train.pklv4`` rather than
    ``data/train/pkls/.pklv4``.
    """
    dir_path = os.path.normpath(dir_path)
    base_dir = os.path.dirname(dir_path)
    name = os.path.basename(dir_path)
    return os.path.join(base_dir, 'pkls', name + '.pklv4')
102 |
103 |
def get_lqs_path(dir_path):
    """Path of the x4 LR pickle for ``dir_path``: ``<parent>/pkls/<name>_X4.pklv4``.

    ``dir_path`` is normalised first so that a trailing separator
    (``data/train/``) still gives ``data/pkls/train_X4.pklv4``.
    """
    dir_path = os.path.normpath(dir_path)
    base_dir = os.path.dirname(dir_path)
    name = os.path.basename(dir_path)
    lqs_path = os.path.join(base_dir, 'pkls', name + '_X4.pklv4')
    return lqs_path
109 |
110 |
def shuffle_combined(hrs, lqs):
    """Shuffle two lists in place with one shared permutation (``random.shuffle``).

    Pairs ``(hrs[i], lqs[i])`` stay aligned. If the lengths differ, both lists
    are truncated to the shorter one, as before. Empty input is now a no-op;
    previously the ``zip(*[])`` unpacking raised ValueError.
    """
    combined = list(zip(hrs, lqs))
    random.shuffle(combined)
    hrs[:] = [hr for hr, _ in combined]
    lqs[:] = [lq for _, lq in combined]
115 |
116 |
if __name__ == "__main__":
    # Validate explicitly: an ``assert`` is stripped under ``python -O``, and a
    # missing argument used to surface as a bare IndexError.
    if len(sys.argv) < 2 or not os.path.isdir(sys.argv[1]):
        sys.exit("Usage: python prepare_data.py <directory containing *.png images>")
    main(sys.argv[1])
121 |
--------------------------------------------------------------------------------
/code/test.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 |
18 | import glob
19 | import sys
20 | from collections import OrderedDict
21 |
22 | from natsort import natsort
23 |
24 | import options.options as option
25 | from Measure import Measure, psnr
26 | from imresize import imresize
27 | from models import create_model
28 | import torch
29 | from utils.util import opt_get
30 | import numpy as np
31 | import pandas as pd
32 | import os
33 | import cv2
34 |
35 |
def fiFindByWildcard(wildcard):
    """Naturally sorted paths matching the glob ``wildcard``; ``**`` matches recursively."""
    found = glob.glob(wildcard, recursive=True)
    return natsort.natsorted(found)
38 |
39 |
def load_model(conf_path):
    """Build the model described by ``conf_path`` in test mode and load its weights.

    ``gpu_ids`` is forced to ``None`` before the options are converted with
    ``option.dict_to_nonedict``. Weights are loaded from the ``model_path``
    option into ``model.netG``.

    Returns:
        tuple: ``(model, opt)``.
    """
    parsed = option.parse(conf_path, is_train=False)
    parsed['gpu_ids'] = None
    opt = option.dict_to_nonedict(parsed)

    model = create_model(opt)
    weights_path = opt_get(opt, ['model_path'], None)
    model.load_network(load_path=weights_path, network=model.netG)
    return model, opt
49 |
50 |
def predict(model, lr):
    """Run ``model`` on one low-resolution image and return its output visual.

    ``lr`` is converted with ``t`` and fed as ``{"LQ": ...}`` without ground
    truth. The ``'rlt'`` visual is returned if present, else ``'SR'``, else None.
    """
    model.feed_data({"LQ": t(lr)}, need_GT=False)
    model.test()
    visuals = model.get_current_visuals(need_GT=False)
    if 'rlt' in visuals:
        return visuals['rlt']
    return visuals.get("SR")
56 |
57 |
def t(array):
    """HWC numpy image -> float tensor of shape (1, C, H, W), divided by 255."""
    chw = array.transpose([2, 0, 1]).astype(np.float32)
    batched = chw[np.newaxis, ...]
    return torch.Tensor(batched) / 255
59 |
60 |
def rgb(t):
    """(C, H, W) or (N, C, H, W) tensor -> HWC uint8 array (first batch item only).

    Values are clipped to [0, 1] and scaled by 255 before a truncating uint8 cast.
    """
    img = t[0] if len(t.shape) == 4 else t
    hwc = img.detach().cpu().numpy().transpose([1, 2, 0])
    return (np.clip(hwc, 0, 1) * 255).astype(np.uint8)
64 |
65 |
def imread(path):
    """Read an image with OpenCV and return it in RGB channel order.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``. ``cv2.imread`` returns
            ``None`` instead of raising, which used to surface as an opaque
            TypeError on the channel indexing.
    """
    img = cv2.imread(path)
    if img is None:
        raise FileNotFoundError('Could not read image: {}'.format(path))
    return img[:, :, [2, 1, 0]]  # BGR -> RGB
68 |
69 |
def imwrite(path, img):
    """Write an RGB image to ``path`` with OpenCV, creating parent directories as needed.

    Raises:
        OSError: if ``cv2.imwrite`` reports failure (for example an unsupported
            extension). It returns False rather than raising, so results used to
            be lost silently.
    """
    out_dir = os.path.dirname(path)
    if out_dir:  # bare file name -> current directory; os.makedirs('') would raise
        os.makedirs(out_dir, exist_ok=True)
    if not cv2.imwrite(path, img[:, :, [2, 1, 0]]):  # RGB -> BGR
        raise OSError('Could not write image: {}'.format(path))
73 |
74 |
def imCropCenter(img, size):
    """Crop a ``size`` x ``size`` window around the centre of an H x W x C image.

    The window is clamped to the image bounds, so a smaller image gives a smaller crop.
    """
    h, w, _ = img.shape
    top = max(h // 2 - size // 2, 0)
    left = max(w // 2 - size // 2, 0)
    bottom = min(top + size, h)
    right = min(left + size, w)
    return img[top:bottom, left:right]
85 |
86 |
def impad(img, top=0, bottom=0, left=0, right=0, color=255):
    """Pad an H x W x C image spatially by reflection; channels are not padded.

    ``color`` is accepted for API compatibility but ignored, because padding is
    always done in 'reflect' mode.
    """
    pad_width = ((top, bottom), (left, right), (0, 0))
    return np.pad(img, pad_width, mode='reflect')
89 |
90 |
def main():
    """Evaluate a trained model on a paired LR/HR PNG test set.

    Usage: ``python test.py <conf.yml>``. The options must provide
    ``dataroot_LR``, ``dataroot_GT``, ``scale`` and ``heat``. LR/HR files are
    paired by natural sort order. For each pair, the SR output is written to
    ``results/<conf>/<heat>/<idx>.png``, and PSNR/SSIM/LPIPS plus LR-consistency
    PSNR are accumulated in ``measure_full.csv``. Pairs already measured in an
    existing CSV (same ``heat`` and index) are skipped, so an interrupted run can
    be resumed.
    """
    conf_path = sys.argv[1]
    conf = conf_path.split('/')[-1].replace('.yml', '')
    model, opt = load_model(conf_path)

    lr_dir = opt['dataroot_LR']
    hr_dir = opt['dataroot_GT']

    lr_paths = fiFindByWildcard(os.path.join(lr_dir, '*.png'))
    hr_paths = fiFindByWildcard(os.path.join(hr_dir, '*.png'))
    if len(lr_paths) != len(hr_paths):
        print(f"Warning: {len(lr_paths)} LR vs {len(hr_paths)} HR images; only the first "
              f"{min(len(lr_paths), len(hr_paths))} pairs are evaluated.")

    this_dir = os.path.dirname(os.path.realpath(__file__))
    test_dir = os.path.join(this_dir, '..', 'results', conf)
    print(f"Out dir: {test_dir}")

    measure = Measure(use_gpu=False)

    fname = 'measure_full.csv'
    fname_tmp = fname + "_"
    path_out_measures = os.path.join(test_dir, fname_tmp)
    path_out_measures_final = os.path.join(test_dir, fname)

    if os.path.isfile(path_out_measures_final):
        df = pd.read_csv(path_out_measures_final)
    elif os.path.isfile(path_out_measures):
        df = pd.read_csv(path_out_measures)
    else:
        df = None

    scale = opt['scale']
    heat = opt['heat']
    pad_factor = 2

    for idx_test, (lr_path, hr_path) in enumerate(zip(lr_paths, hr_paths)):
        # Skip pairs already measured by a previous run (before doing any I/O).
        if df is not None and len(df[(df['heat'] == heat) & (df['name'] == idx_test)]) == 1:
            continue

        lr = imread(lr_path)
        hr = imread(hr_path)

        # Pad bottom/right (by reflection) to a multiple of pad_factor; the SR
        # output is cropped back to h * scale x w * scale below.
        h, w = lr.shape[:2]
        lq_orig = lr.copy()
        lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h),
                   right=int(np.ceil(w / pad_factor) * pad_factor - w))
        lr_t = t(lr)

        sr_t = model.get_sr(lq=lr_t, heat=heat)
        sr = rgb(torch.clamp(sr_t, 0, 1))
        sr = sr[:h * scale, :w * scale]

        path_out_sr = os.path.join(test_dir, "{:0.2f}".format(heat).replace('.', ''), "{:06d}.png".format(idx_test))
        imwrite(path_out_sr, sr)

        meas = OrderedDict(conf=conf, heat=heat, name=idx_test)
        meas['PSNR'], meas['SSIM'], meas['LPIPS'] = measure.measure(sr, hr)

        lr_reconstruct_rgb = imresize(sr, 1 / scale)
        meas['LRC PSNR'] = psnr(lq_orig, lr_reconstruct_rgb)

        print(format_measurements(meas))

        df = pd.DataFrame([meas]) if df is None else pd.concat([pd.DataFrame([meas]), df])

        # Write, then rename, so an interrupted run never leaves a truncated CSV.
        # os.replace (unlike os.rename) also overwrites an existing target on Windows.
        df.to_csv(path_out_measures + "_", index=False)
        os.replace(path_out_measures + "_", path_out_measures)

    if df is None:
        print("No images were evaluated.")
        return

    df.to_csv(path_out_measures, index=False)
    os.replace(path_out_measures, path_out_measures_final)

    # numeric_only: the string 'conf' column cannot be averaged (pandas >= 2 raises otherwise).
    str_out = format_measurements(df.mean(numeric_only=True))
    print(f"Results in: {path_out_measures_final}")
    print('Mean: ' + str_out)
170 |
171 |
def format_measurements(meas):
    """Render ``key: value`` pairs joined by ', '; float values get two decimals."""
    return ", ".join(
        "{}: {}".format(key, f"{val:0.2f}" if isinstance(val, float) else val)
        for key, val in meas.items()
    )
179 |
180 |
if __name__ == "__main__":
    # main() reads the config path from sys.argv[1]; report a usage message
    # instead of an IndexError when it is missing.
    if len(sys.argv) < 2:
        sys.exit("Usage: python test.py <path/to/conf.yml>")
    main()
183 |
--------------------------------------------------------------------------------
/code/train.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | import os
18 | from os.path import basename
19 | import math
20 | import argparse
21 | import random
22 | import logging
23 | import cv2
24 |
25 | import torch
26 | import torch.distributed as dist
27 | import torch.multiprocessing as mp
28 |
29 | import options.options as option
30 | from utils import util
31 | from data import create_dataloader, create_dataset
32 | from models import create_model
33 | from utils.timer import Timer, TickTock
34 | from utils.util import get_resume_paths
35 |
36 |
def getEnv(name):
    """Return True if environment variable ``name`` is set (to any value), else False."""
    return name in os.environ
38 |
39 |
def init_dist(backend='nccl', **kwargs):
    '''Initialization for distributed training.

    Switches multiprocessing to the 'spawn' start method and binds this process
    to GPU ``RANK % torch.cuda.device_count()``, where ``RANK`` is read from the
    environment. It then initialises the default process group with ``backend``
    and ``kwargs``.
    '''
    if mp.get_start_method(allow_none=True) != 'spawn':
        # force=True: set_start_method raises RuntimeError if another start
        # method has already been fixed for this process.
        mp.set_start_method('spawn', force=True)
    rank = int(os.environ['RANK'])
    num_gpus = torch.cuda.device_count()
    # Was ``torch.cuda.set_deviceDistIterSampler``, a non-existent attribute that
    # made every distributed launch fail with AttributeError.
    torch.cuda.set_device(rank % num_gpus)
    dist.init_process_group(backend=backend, **kwargs)
49 |
50 |
51 | def main():
52 | #### options
53 | parser = argparse.ArgumentParser()
54 | parser.add_argument('-opt', type=str, help='Path to option YMAL file.')
55 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none',
56 | help='job launcher')
57 | parser.add_argument('--local_rank', type=int, default=0)
58 | args = parser.parse_args()
59 | opt = option.parse(args.opt, is_train=True)
60 |
61 | #### distributed training settings
62 | opt['dist'] = False
63 | rank = -1
64 | print('Disabled distributed training.')
65 |
66 | #### loading resume state if exists
67 | if opt['path'].get('resume_state', None):
68 | resume_state_path, _ = get_resume_paths(opt)
69 |
70 | # distributed resuming: all load into default GPU
71 | if resume_state_path is None:
72 | resume_state = None
73 | else:
74 | device_id = torch.cuda.current_device()
75 | resume_state = torch.load(resume_state_path,
76 | map_location=lambda storage, loc: storage.cuda(device_id))
77 | option.check_resume(opt, resume_state['iter']) # check resume options
78 | else:
79 | resume_state = None
80 |
81 | #### mkdir and loggers
82 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0)
83 | if resume_state is None:
84 | util.mkdir_and_rename(
85 | opt['path']['experiments_root']) # rename experiment folder if exists
86 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root'
87 | and 'pretrain_model' not in key and 'resume' not in key))
88 |
89 | # config loggers. Before it, the log will not work
90 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO,
91 | screen=True, tofile=True)
92 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO,
93 | screen=True, tofile=True)
94 | logger = logging.getLogger('base')
95 | logger.info(option.dict2str(opt))
96 |
97 | # tensorboard logger
98 | if opt.get('use_tb_logger', False) and 'debug' not in opt['name']:
99 | version = float(torch.__version__[0:3])
100 | if version >= 1.1: # PyTorch 1.1
101 | from torch.utils.tensorboard import SummaryWriter
102 | else:
103 | logger.info(
104 | 'You are using PyTorch {}. Tensorboard will use [tensorboardX]'.format(version))
105 | from tensorboardX import SummaryWriter
106 | conf_name = basename(args.opt).replace(".yml", "")
107 | exp_dir = opt['path']['experiments_root']
108 | log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train')
109 | log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid')
110 | tb_logger_train = SummaryWriter(log_dir=log_dir_train)
111 | tb_logger_valid = SummaryWriter(log_dir=log_dir_valid)
112 | else:
113 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True)
114 | logger = logging.getLogger('base')
115 |
116 | # convert to NoneDict, which returns None for missing keys
117 | opt = option.dict_to_nonedict(opt)
118 |
119 | #### random seed
120 | seed = opt['train']['manual_seed']
121 | if seed is None:
122 | seed = random.randint(1, 10000)
123 | if rank <= 0:
124 | logger.info('Random seed: {}'.format(seed))
125 | util.set_random_seed(seed)
126 |
127 | torch.backends.cudnn.benchmark = True
128 | # torch.backends.cudnn.deterministic = True
129 |
130 | #### create train and val dataloader
131 | dataset_ratio = 200 # enlarge the size of each epoch
132 | for phase, dataset_opt in opt['datasets'].items():
133 | if phase == 'train':
134 | train_set = create_dataset(dataset_opt)
135 | print('Dataset created')
136 | train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size']))
137 | total_iters = int(opt['train']['niter'])
138 | total_epochs = int(math.ceil(total_iters / train_size))
139 | train_sampler = None
140 | train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler)
141 | if rank <= 0:
142 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format(
143 | len(train_set), train_size))
144 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format(
145 | total_epochs, total_iters))
146 | elif phase == 'val':
147 | val_set = create_dataset(dataset_opt)
148 | val_loader = create_dataloader(val_set, dataset_opt, opt, None)
149 | if rank <= 0:
150 | logger.info('Number of val images in [{:s}]: {:d}'.format(
151 | dataset_opt['name'], len(val_set)))
152 | else:
153 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase))
154 | assert train_loader is not None
155 |
156 | #### create model
157 | current_step = 0 if resume_state is None else resume_state['iter']
158 | model = create_model(opt, current_step)
159 |
160 | #### resume training
161 | if resume_state:
162 | logger.info('Resuming training from epoch: {}, iter: {}.'.format(
163 | resume_state['epoch'], resume_state['iter']))
164 |
165 | start_epoch = resume_state['epoch']
166 | current_step = resume_state['iter']
167 | model.resume_training(resume_state) # handle optimizers and schedulers
168 | else:
169 | current_step = 0
170 | start_epoch = 0
171 |
172 | #### training
173 | timer = Timer()
174 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step))
175 | timerData = TickTock()
176 |
177 | for epoch in range(start_epoch, total_epochs + 1):
178 | if opt['dist']:
179 | train_sampler.set_epoch(epoch)
180 |
181 | timerData.tick()
182 | for _, train_data in enumerate(train_loader):
183 | timerData.tock()
184 | current_step += 1
185 | if current_step > total_iters:
186 | break
187 |
188 | #### training
189 | model.feed_data(train_data)
190 |
191 | #### update learning rate
192 | model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter'])
193 |
194 | try:
195 | nll = model.optimize_parameters(current_step)
196 | except RuntimeError as e:
197 | print("Skipping ERROR caught in nll = model.optimize_parameters(current_step): ")
198 | print(e)
199 |
200 | if nll is None:
201 | nll = 0
202 |
203 | #### log
204 | def eta(t_iter):
205 | return (t_iter * (opt['train']['niter'] - current_step)) / 3600
206 |
207 | if current_step % opt['logger']['print_freq'] == 0 \
208 | or current_step - (resume_state['iter'] if resume_state else 0) < 25:
209 | avg_time = timer.get_average_and_reset()
210 | avg_data_time = timerData.get_average_and_reset()
211 | message = ' '.format(
212 | epoch, current_step, model.get_current_learning_rate(), avg_time, avg_data_time,
213 | eta(avg_time), nll)
214 | print(message)
215 | timer.tick()
216 | # Reduce number of logs
217 | if current_step % 5 == 0:
218 | tb_logger_train.add_scalar('loss/nll', nll, current_step)
219 | tb_logger_train.add_scalar('lr/base', model.get_current_learning_rate(), current_step)
220 | tb_logger_train.add_scalar('time/iteration', timer.get_last_iteration(), current_step)
221 | tb_logger_train.add_scalar('time/data', timerData.get_last_iteration(), current_step)
222 | tb_logger_train.add_scalar('time/eta', eta(timer.get_last_iteration()), current_step)
223 | for k, v in model.get_current_log().items():
224 | tb_logger_train.add_scalar(k, v, current_step)
225 |
226 | # validation
227 | if current_step % opt['train']['val_freq'] == 0 and rank <= 0:
228 | avg_psnr = 0.0
229 | idx = 0
230 | nlls = []
231 | for val_data in val_loader:
232 | idx += 1
233 | img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0]
234 | img_dir = os.path.join(opt['path']['val_images'], img_name)
235 | util.mkdir(img_dir)
236 |
237 | model.feed_data(val_data)
238 |
239 | nll = model.test()
240 | if nll is None:
241 | nll = 0
242 | nlls.append(nll)
243 |
244 | visuals = model.get_current_visuals()
245 |
246 | sr_img = None
247 | # Save SR images for reference
248 | if hasattr(model, 'heats'):
249 | for heat in model.heats:
250 | for i in range(model.n_sample):
251 | sr_img = util.tensor2img(visuals['SR', heat, i]) # uint8
252 | save_img_path = os.path.join(img_dir,
253 | '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(img_name,
254 | current_step,
255 | int(heat * 100), i))
256 | util.save_img(sr_img, save_img_path)
257 | else:
258 | sr_img = util.tensor2img(visuals['SR']) # uint8
259 | save_img_path = os.path.join(img_dir,
260 | '{:s}_{:d}.png'.format(img_name, current_step))
261 | util.save_img(sr_img, save_img_path)
262 | assert sr_img is not None
263 |
264 | # Save LQ images for reference
265 | save_img_path_lq = os.path.join(img_dir,
266 | '{:s}_LQ.png'.format(img_name))
267 | if not os.path.isfile(save_img_path_lq):
268 | lq_img = util.tensor2img(visuals['LQ']) # uint8
269 | util.save_img(
270 | cv2.resize(lq_img, dsize=None, fx=opt['scale'], fy=opt['scale'],
271 | interpolation=cv2.INTER_NEAREST),
272 | save_img_path_lq)
273 |
274 | # Save GT images for reference
275 | gt_img = util.tensor2img(visuals['GT']) # uint8
276 | save_img_path_gt = os.path.join(img_dir,
277 | '{:s}_GT.png'.format(img_name))
278 | if not os.path.isfile(save_img_path_gt):
279 | util.save_img(gt_img, save_img_path_gt)
280 |
281 | # calculate PSNR
282 | crop_size = opt['scale']
283 | gt_img = gt_img / 255.
284 | sr_img = sr_img / 255.
285 | cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
286 | cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
287 | avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
288 |
289 | avg_psnr = avg_psnr / idx
290 | avg_nll = sum(nlls) / len(nlls)
291 |
292 | # log
293 | logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr))
294 | logger_val = logging.getLogger('val') # validation logger
295 | logger_val.info(' psnr: {:.4e}'.format(
296 | epoch, current_step, avg_psnr))
297 |
298 | # tensorboard logger
299 | tb_logger_valid.add_scalar('loss/psnr', avg_psnr, current_step)
300 | tb_logger_valid.add_scalar('loss/nll', avg_nll, current_step)
301 |
302 | tb_logger_train.flush()
303 | tb_logger_valid.flush()
304 |
305 | #### save models and training states
306 | if current_step % opt['logger']['save_checkpoint_freq'] == 0:
307 | if rank <= 0:
308 | logger.info('Saving models and training states.')
309 | model.save(current_step)
310 | model.save_training_state(epoch, current_step)
311 |
312 | timerData.tick()
313 |
314 | with open(os.path.join(opt['path']['root'], "TRAIN_DONE"), 'w') as f:
315 | f.write("TRAIN_DONE")
316 |
317 | if rank <= 0:
318 | logger.info('Saving the final model.')
319 | model.save('latest')
320 | logger.info('End of training.')
321 |
322 |
if __name__ == '__main__':
    # Run training only when this file is executed as a script, not when imported.
    main()
325 |
--------------------------------------------------------------------------------
/code/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreas128/SRFlow/5a007ad591c7be8bf32cf23171bfa4473e71683c/code/utils/__init__.py
--------------------------------------------------------------------------------
/code/utils/timer.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 |
17 | import time
18 |
19 |
class ScopeTimer:
    """Context manager that prints how long its ``with`` body took.

    On exit it prints ``"<name> <seconds formatted as {:.3E}>"``. It also leaves
    ``start``, ``end`` (both from time.time()) and ``interval`` (their
    difference, in seconds) on the instance. Exceptions are not suppressed.
    """

    def __init__(self, name):
        # Label printed in front of the measured duration.
        self.name = name

    def __enter__(self):
        self.start = time.time()
        return self

    def __exit__(self, *exc_info):
        finished = time.time()
        self.end = finished
        self.interval = finished - self.start
        message = "{} {:.3E}".format(self.name, self.interval)
        print(message)
32 |
33 |
class Timer:
    """Records tick() timestamps and reports the mean spacing between them."""

    def __init__(self):
        # time.time() values of every tick since the last reset.
        self.times = []

    def tick(self):
        """Append the current time.time() value."""
        now = time.time()
        self.times.append(now)

    def get_average_and_reset(self):
        """Return the mean seconds between recorded ticks and keep only the newest one.

        Returns -1, and leaves the state untouched, when fewer than two ticks exist.
        """
        count = len(self.times)
        if count >= 2:
            first, last = self.times[0], self.times[-1]
            # Keep the newest tick so the next window starts where this one ended.
            self.times = [last]
            return (last - first) / (count - 1)
        return -1

    def get_last_iteration(self):
        """Return seconds between the two most recent ticks, or 0 if fewer than two exist."""
        if len(self.times) >= 2:
            previous, latest = self.times[-2:]
            return latest - previous
        return 0
52 |
53 |
class TickTock:
    """Measures intervals bracketed by tick() ... tock() calls.

    Each tock() stores a ``[tick_time, tock_time]`` pair (time.time() values).
    get_average_and_reset() reports the mean interval and clears the pairs.
    The duration of the most recent interval is remembered across that reset,
    so get_last_iteration() called right after a reset still returns a real
    duration instead of the -1 sentinel.
    """

    def __init__(self):
        # Completed [start, end] pairs since the last reset.
        self.time_pairs = []
        # Start time of the currently open interval, or None if none is open.
        self.current_time = None
        # Duration of the last completed interval seen by a reset, or None.
        self._last_delta = None

    def tick(self):
        """Open an interval at the current time."""
        self.current_time = time.time()

    def tock(self):
        """Close the interval opened by tick(); tick() must have been called first."""
        assert self.current_time is not None, self.current_time
        self.time_pairs.append([self.current_time, time.time()])
        self.current_time = None

    def get_average_and_reset(self):
        """Return the mean interval length (seconds) and clear the recorded pairs.

        Returns -1 when no interval has been recorded since the last reset.
        """
        if len(self.time_pairs) == 0:
            return -1
        deltas = [t2 - t1 for t1, t2 in self.time_pairs]
        avg = sum(deltas) / len(deltas)
        # Remember the newest duration so get_last_iteration() survives the reset.
        self._last_delta = deltas[-1]
        self.time_pairs = []
        return avg

    def get_last_iteration(self):
        """Return the length (seconds) of the most recent interval, or -1 if none was ever recorded."""
        if self.time_pairs:
            start, end = self.time_pairs[-1]
            return end - start
        # getattr keeps instances created without __init__ (e.g. unpickled) working.
        last_delta = getattr(self, '_last_delta', None)
        if last_delta is not None:
            return last_delta
        return -1
79 |
--------------------------------------------------------------------------------
/code/utils/util.py:
--------------------------------------------------------------------------------
1 | import glob
2 | import os
3 | import sys
4 | import time
5 | import math
6 | from datetime import datetime
7 | import random
8 | import logging
9 | from collections import OrderedDict
10 |
11 | import natsort
12 | import numpy as np
13 | import cv2
14 | import torch
15 | from torchvision.utils import make_grid
16 | from shutil import get_terminal_size
17 |
18 | import yaml
19 |
20 | try:
21 | from yaml import CLoader as Loader, CDumper as Dumper
22 | except ImportError:
23 | from yaml import Loader, Dumper
24 |
25 |
def OrderedYaml():
    '''Register OrderedDict support on the module-level yaml Loader/Dumper and return both.

    YAML mappings are loaded as OrderedDict, which keeps file key order.
    OrderedDict values are dumped as plain YAML mappings. The hooks are
    registered on the Loader and Dumper classes themselves (not on instances),
    so every user of those classes sees them.
    '''
    mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG

    def represent_ordered(dumper, data):
        return dumper.represent_dict(data.items())

    def construct_ordered(loader, node):
        return OrderedDict(loader.construct_pairs(node))

    Dumper.add_representer(OrderedDict, represent_ordered)
    Loader.add_constructor(mapping_tag, construct_ordered)
    return Loader, Dumper
39 |
40 |
41 | ####################
42 | # miscellaneous
43 | ####################
44 |
45 |
def get_timestamp():
    """Return the current local time formatted as ``YYMMDD-HHMMSS``."""
    now = datetime.now()
    return now.strftime('%y%m%d-%H%M%S')
48 |
49 |
def mkdir(path):
    """Create directory ``path`` (and any missing parents) if nothing exists there yet.

    An existing entry at ``path`` (directory or file) is left untouched.
    ``exist_ok=True`` closes the race where another process creates the
    directory between the existence check and ``os.makedirs``. Previously that
    race raised FileExistsError.
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
53 |
54 |
def mkdirs(paths):
    """Create one directory (when ``paths`` is a str) or each directory in an iterable of paths."""
    targets = [paths] if isinstance(paths, str) else paths
    for target in targets:
        mkdir(target)
61 |
62 |
def mkdir_and_rename(path):
    """Create directory ``path``, first moving any existing entry aside.

    An existing ``path`` is renamed to ``<path>_archived_<timestamp>``. The
    rename is reported both on stdout and through the 'base' logger.
    """
    if os.path.exists(path):
        archived = path + '_archived_' + get_timestamp()
        notice = 'Path already exists. Rename it to [{:s}]'.format(archived)
        print(notice)
        logging.getLogger('base').info(notice)
        os.rename(path, archived)
    os.makedirs(path)
71 |
72 |
def set_random_seed(seed):
    """Seed Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA devices) with ``seed``."""
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(seed)
78 |
79 |
def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False):
    '''Configure the logger named ``logger_name``.

    Sets its level and attaches formatted handlers:
      * tofile: a FileHandler writing (mode 'w') to
        ``<root>/<phase>_<timestamp>.log``;
      * screen: a StreamHandler.
    Handlers are appended on every call; existing handlers are not removed.
    '''
    logger = logging.getLogger(logger_name)
    fmt = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s',
                            datefmt='%y-%m-%d %H:%M:%S')
    logger.setLevel(level)
    handlers = []
    if tofile:
        log_name = phase + '_{}.log'.format(get_timestamp())
        handlers.append(logging.FileHandler(os.path.join(root, log_name), mode='w'))
    if screen:
        handlers.append(logging.StreamHandler())
    for handler in handlers:
        handler.setFormatter(fmt)
        logger.addHandler(handler)
95 |
96 |
97 | ####################
98 | # image convert
99 | ####################
100 |
101 |
102 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
103 | '''
104 | Converts a torch Tensor into an image Numpy array
105 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order
106 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default)
107 | '''
108 | if hasattr(tensor, 'detach'):
109 | tensor = tensor.detach()
110 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp
111 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1]
112 | n_dim = tensor.dim()
113 | if n_dim == 4:
114 | n_img = len(tensor)
115 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy()
116 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
117 | elif n_dim == 3:
118 | img_np = tensor.numpy()
119 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR
120 | elif n_dim == 2:
121 | img_np = tensor.numpy()
122 | else:
123 | raise TypeError(
124 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
125 | if out_type == np.uint8:
126 | img_np = (img_np * 255.0).round()
127 | # Important. Unlike matlab, numpy.unit8() WILL NOT round by default.
128 | return img_np.astype(out_type)
129 |
130 |
def save_img(img, img_path, mode='RGB'):
    """Write ``img`` to ``img_path`` with OpenCV.

    ``img`` is passed to cv2.imwrite unchanged, so it presumably must already
    be in OpenCV's BGR channel order (as produced by tensor2img). ``mode`` is
    accepted for backward compatibility and is not used.

    Returns:
        bool: the success flag reported by cv2.imwrite. cv2.imwrite can report
        failure (e.g. a missing output directory) through this flag instead of
        raising. A warning is logged on the 'base' logger so the failure is no
        longer silent.
    """
    ok = cv2.imwrite(img_path, img)
    if not ok:
        logging.getLogger('base').warning('Failed to write image to [{}]'.format(img_path))
    return ok
133 |
134 |
135 | ####################
136 | # metric
137 | ####################
138 |
139 |
def calculate_psnr(img1, img2):
    """Peak signal-to-noise ratio (dB) between two arrays whose values lie in [0, 255].

    Both inputs are cast to float64 before subtracting, so unsigned inputs cannot
    wrap around. Identical inputs (zero MSE) give ``float('inf')``.
    """
    squared_error = (img1.astype(np.float64) - img2.astype(np.float64)) ** 2
    mse = np.mean(squared_error)
    return float('inf') if mse == 0 else 20 * math.log10(255.0 / math.sqrt(mse))
148 |
149 |
def get_resume_paths(opt):
    """Resolve which training-state checkpoint (and matching model file) to resume from.

    Reads ``opt['path']['resume_state']``:
      * ``"auto"``: take the naturally-sorted last entry of
        ``opt['path']['training_state']``. The model path is derived from it by
        replacing ``training_state`` -> ``models`` and ``.state`` -> ``_G.pth``.
        Returns ``(None, None)`` when that directory is not configured or is empty.
      * anything else (an explicit path or None): returned unchanged, with no
        model path.

    Returns:
        tuple: ``(resume_state_path, resume_model_path)``; either may be None.
    """
    path_opt = opt.get('path') or {}
    resume_state = path_opt.get('resume_state')
    if resume_state != "auto":
        return resume_state, None

    # "auto" is a request, never a real file: without a training_state
    # directory there is nothing to resume from (previously "auto" itself was
    # returned as the checkpoint path).
    ts = path_opt.get('training_state')
    if ts is None:
        return None, None

    paths = natsort.natsorted(glob.glob(os.path.join(ts, "*")))
    if not paths:
        return None, None
    resume_state_path = paths[-1]
    resume_model_path = resume_state_path.replace('training_state', 'models').replace('.state', '_G.pth')
    return resume_state_path, resume_model_path
163 |
164 |
def opt_get(opt, keys, default=None):
    """Walk nested mappings along ``keys`` and return the value found there.

    Returns ``default`` when ``opt`` is None or when any step yields None
    (missing key or an explicit None). Falsy non-None values such as 0 or ''
    are returned as-is. With an empty ``keys`` the input itself is returned,
    or ``default`` if it is None.
    """
    current = opt
    for key in keys:
        if current is None:
            break
        current = current.get(key)
    return default if current is None else current
174 |
175 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | appnope==0.1.0
2 | argon2-cffi==20.1.0
3 | async-generator==1.10
4 | attrs==20.2.0
5 | backcall==0.2.0
6 | bleach==3.2.1
7 | certifi==2020.6.20
8 | cffi==1.14.3
9 | cycler==0.10.0
10 | dataclasses==0.6
11 | decorator==4.4.2
12 | defusedxml==0.6.0
13 | entrypoints==0.3
14 | environment-kernels==1.1.1
15 | future==0.18.2
16 | imageio==2.9.0
17 | importlib-metadata==2.0.0
18 | ipykernel==5.3.4
19 | ipython==7.19.0
20 | ipython-genutils==0.2.0
21 | ipywidgets==7.5.1
22 | jedi==0.17.2
23 | Jinja2==2.11.2
24 | jsonschema==3.2.0
25 | jupyter==1.0.0
26 | jupyter-client==6.1.7
27 | jupyter-console==6.2.0
28 | jupyter-core==4.6.3
29 | jupyterlab-pygments==0.1.2
30 | kiwisolver==1.3.1
31 | lpips==0.1.3
32 | MarkupSafe==1.1.1
33 | matplotlib==3.3.2
34 | mistune==0.8.4
35 | natsort==7.0.1
36 | nbclient==0.5.1
37 | nbconvert==6.0.7
38 | nbformat==5.0.8
39 | nest-asyncio==1.4.2
40 | networkx==2.5
41 | notebook==6.1.4
42 | numpy==1.19.4
43 | opencv-python==4.4.0.46
44 | packaging==20.4
45 | pandas==1.1.4
46 | pandocfilters==1.4.3
47 | parso==0.7.1
48 | pexpect==4.8.0
49 | pickleshare==0.7.5
50 | Pillow==8.0.1
51 | prometheus-client==0.8.0
52 | prompt-toolkit==3.0.8
53 | ptyprocess==0.6.0
54 | pycparser==2.20
55 | Pygments==2.7.2
56 | pyparsing==2.4.7
57 | pyrsistent==0.17.3
58 | python-dateutil==2.8.1
59 | pytz==2020.4
60 | PyWavelets==1.1.1
61 | PyYAML==5.3.1
62 | pyzmq==19.0.2
63 | qtconsole==4.7.7
64 | QtPy==1.9.0
65 | scikit-image==0.17.2
66 | scipy==1.5.3
67 | Send2Trash==1.5.0
68 | six==1.15.0
69 | terminado==0.9.1
70 | tensorboard==2.4.0
71 | testpath==0.4.4
72 | tifffile==2020.10.1
73 | torch==1.7.0
74 | torchvision==0.8.1
75 | tornado==6.1
76 | tqdm==4.51.0
77 | traitlets==5.0.5
78 | typing-extensions==3.7.4.3
79 | wcwidth==0.2.5
80 | webencodings==0.5.1
81 | widgetsnbextension==3.5.1
82 | zipp==3.4.0
83 |
--------------------------------------------------------------------------------
/run_jupyter.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e # exit script if an error occurs
4 |
5 | source myenv/bin/activate # Install with ./setup.sh
6 | cd code
7 | python -m jupyter notebook demo_on_pretrained.ipynb # Start jupyter using the python from the virtual environment
8 |
--------------------------------------------------------------------------------
/setup.sh:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 |
3 | set -e # exit script if an error occurs
4 |
5 |
6 |
7 | echo ""
8 | echo "########################################"
9 | echo "Setup Virtual Environment"
10 | echo "########################################"
11 | echo ""
12 |
13 | python3 -m venv myenv # Create a new virtual environment (venv) using native python3.7 venv
14 | source myenv/bin/activate # This replaces the python/pip command with the ones from the venv
15 | which python # shoud output: ./myenv/bin/python
16 |
17 | pip install --upgrade pip # Update pip
18 | pip install -r requirements.txt # Install the exact same packages that we used
19 |
20 | # Alternatively you can install globally using pip
21 | # pip install jupyter torch natsort pyyaml opencv-python torchvision scikit-image tqdm lpips pandas environment_kernels
22 |
23 |
24 |
25 | echo ""
26 | echo "########################################"
27 | echo "Download models, data"
28 | echo "########################################"
29 | echo ""
30 |
31 | wget --continue http://data.vision.ee.ethz.ch/alugmayr/SRFlow/datasets.zip
32 | unzip datasets.zip
33 | rm datasets.zip
34 |
35 | wget --continue http://data.vision.ee.ethz.ch/alugmayr/SRFlow/pretrained_models.zip
36 | unzip pretrained_models.zip
37 | rm pretrained_models.zip
38 |
39 |
40 | echo ""
41 | echo "########################################"
42 | echo "Start Demo"
43 | echo "########################################"
44 | echo ""
45 |
46 | ./run_jupyter.sh
47 |
--------------------------------------------------------------------------------