├── .gitignore
├── LICENSE
├── LICENSES
│   ├── LICENSE
│   ├── LICENSE_BasicSR
│   ├── LICENSE_GLOW
│   └── README.md
├── README.md
├── code
│   ├── Measure.py
│   ├── confs
│   │   ├── RRDB_CelebA_8X.yml
│   │   ├── RRDB_DF2K_4X.yml
│   │   ├── RRDB_DF2K_8X.yml
│   │   ├── SRFlow_CelebA_8X.yml
│   │   ├── SRFlow_DF2K_4X.yml
│   │   └── SRFlow_DF2K_8X.yml
│   ├── data
│   │   ├── LRHR_PKL_dataset.py
│   │   └── __init__.py
│   ├── demo_on_pretrained.ipynb
│   ├── imresize.py
│   ├── models
│   │   ├── SRFlow_model.py
│   │   ├── SR_model.py
│   │   ├── __init__.py
│   │   ├── base_model.py
│   │   ├── lr_scheduler.py
│   │   ├── modules
│   │   │   ├── FlowActNorms.py
│   │   │   ├── FlowAffineCouplingsAblation.py
│   │   │   ├── FlowStep.py
│   │   │   ├── FlowUpsamplerNet.py
│   │   │   ├── Permutations.py
│   │   │   ├── RRDBNet_arch.py
│   │   │   ├── SRFlowNet_arch.py
│   │   │   ├── Split.py
│   │   │   ├── __init__.py
│   │   │   ├── flow.py
│   │   │   ├── glow_arch.py
│   │   │   ├── loss.py
│   │   │   ├── module_util.py
│   │   │   └── thops.py
│   │   └── networks.py
│   ├── options
│   │   ├── __init__.py
│   │   └── options.py
│   ├── prepare_data.py
│   ├── test.py
│   ├── train.py
│   └── utils
│       ├── __init__.py
│       ├── timer.py
│       └── util.py
├── requirements.txt
├── run_jupyter.sh
└── setup.sh
/.gitignore: -------------------------------------------------------------------------------- 1 | __MACOSX/ 2 | .DS_Store 3 | 4 | *.pth 5 | *.pklv4 6 | *.zip 7 | 8 | datasets/ 9 | 10 | data/ 11 | myenv/ 12 | .idea 13 | checkpoints/ 14 | code/notebooks/ 15 | code/local_config.py 16 | *.ipynb 17 | 18 | TRAIN_DONE 19 | 20 | # folder 21 | .vscode 22 | 23 | experiments/* 24 | !experiments/pretrained_models 25 | experiments/pretrained_models/* 26 | # !experiments/pretrained_models/SRResNet_bicx4_in3nf64nb16.pth 27 | !experiments/pretrained_models/README.md 28 | 29 | results/* 30 | tb_logger/* 31 | 32 | # file type 33 | *.svg 34 | *.pyc 35 | *.t7 36 | *.caffemodel 37 | *.mat 38 | *.npy 39 | 40 | # latex 41 | *.aux 42 | *.bbl 43 | *.blg 44 | *.log 45 | *.out 46 | *.synctex.gz 47 | 48 | # TODO 49 | data_samples/samples_byteimg 50 | data_samples/samples_colorimg 51 | data_samples/samples_segprob 52 | data_samples/samples_result 53 | 54 | 55 | # Created by https://www.gitignore.io/api/vim,python,pycharm 56 | # Edit at https://www.gitignore.io/?templates=vim,python,pycharm 57 | 58 | ### PyCharm ### 59 | # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm 60 | # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 61 | 62 | # User-specific stuff 63 | .idea/**/workspace.xml 64 | .idea/**/tasks.xml 65 | .idea/**/usage.statistics.xml 66 | .idea/**/dictionaries 67 | .idea/**/shelf 68 | 69 | # Generated files 70 | .idea/**/contentModel.xml 71 | 72 | # Sensitive or high-churn files 73 | .idea/**/dataSources/ 74 | .idea/**/dataSources.ids 75 | .idea/**/dataSources.local.xml 76 | .idea/**/sqlDataSources.xml 77 | .idea/**/dynamic.xml 78 | .idea/**/uiDesigner.xml 79 | .idea/**/dbnavigator.xml 80 | 81 | # Gradle 82 | .idea/**/gradle.xml 83 | .idea/**/libraries 84 | 85 | # Gradle and Maven with auto-import 86 | # When using Gradle or Maven with auto-import, you should exclude module files, 87 | # since they will be recreated, and may cause churn. Uncomment if using 88 | # auto-import.
89 | # .idea/modules.xml 90 | # .idea/*.iml 91 | # .idea/modules 92 | # *.iml 93 | # *.ipr 94 | 95 | # CMake 96 | cmake-build-*/ 97 | 98 | # Mongo Explorer plugin 99 | .idea/**/mongoSettings.xml 100 | 101 | # File-based project format 102 | *.iws 103 | 104 | # IntelliJ 105 | out/ 106 | 107 | # mpeltonen/sbt-idea plugin 108 | .idea_modules/ 109 | 110 | # JIRA plugin 111 | atlassian-ide-plugin.xml 112 | 113 | # Cursive Clojure plugin 114 | .idea/replstate.xml 115 | 116 | # Crashlytics plugin (for Android Studio and IntelliJ) 117 | com_crashlytics_export_strings.xml 118 | crashlytics.properties 119 | crashlytics-build.properties 120 | fabric.properties 121 | 122 | # Editor-based Rest Client 123 | .idea/httpRequests 124 | 125 | # Android studio 3.1+ serialized cache file 126 | .idea/caches/build_file_checksums.ser 127 | 128 | ### PyCharm Patch ### 129 | # Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 130 | 131 | # *.iml 132 | # modules.xml 133 | # .idea/misc.xml 134 | # *.ipr 135 | 136 | # Sonarlint plugin 137 | .idea/sonarlint 138 | 139 | ### Python ### 140 | # Byte-compiled / optimized / DLL files 141 | __pycache__/ 142 | *.py[cod] 143 | *$py.class 144 | 145 | # C extensions 146 | *.so 147 | 148 | # Distribution / packaging 149 | .Python 150 | build/ 151 | develop-eggs/ 152 | dist/ 153 | downloads/ 154 | eggs/ 155 | .eggs/ 156 | lib/ 157 | lib64/ 158 | parts/ 159 | sdist/ 160 | var/ 161 | wheels/ 162 | pip-wheel-metadata/ 163 | share/python-wheels/ 164 | *.egg-info/ 165 | .installed.cfg 166 | *.egg 167 | MANIFEST 168 | 169 | # PyInstaller 170 | # Usually these files are written by a python script from a template 171 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 172 | *.manifest 173 | *.spec 174 | 175 | # Installer logs 176 | pip-log.txt 177 | pip-delete-this-directory.txt 178 | 179 | # Unit test / coverage reports 180 | htmlcov/ 181 | .tox/ 182 | .nox/ 183 | .coverage 184 | .coverage.* 185 | .cache 186 | nosetests.xml 187 | coverage.xml 188 | *.cover 189 | .hypothesis/ 190 | .pytest_cache/ 191 | 192 | # Translations 193 | *.mo 194 | *.pot 195 | 196 | # Django stuff: 197 | *.log 198 | local_settings.py 199 | db.sqlite3 200 | db.sqlite3-journal 201 | 202 | # Flask stuff: 203 | instance/ 204 | .webassets-cache 205 | 206 | # Scrapy stuff: 207 | .scrapy 208 | 209 | # Sphinx documentation 210 | docs/_build/ 211 | 212 | # PyBuilder 213 | target/ 214 | 215 | # Jupyter Notebook 216 | .ipynb_checkpoints 217 | 218 | # IPython 219 | profile_default/ 220 | ipython_config.py 221 | 222 | # pyenv 223 | .python-version 224 | 225 | # pipenv 226 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 227 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 228 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 229 | # install all needed dependencies. 
230 | #Pipfile.lock 231 | 232 | # celery beat schedule file 233 | celerybeat-schedule 234 | 235 | # SageMath parsed files 236 | *.sage.py 237 | 238 | # Environments 239 | .env 240 | .venv 241 | env/ 242 | venv/ 243 | ENV/ 244 | env.bak/ 245 | venv.bak/ 246 | 247 | # Spyder project settings 248 | .spyderproject 249 | .spyproject 250 | 251 | # Rope project settings 252 | .ropeproject 253 | 254 | # mkdocs documentation 255 | /site 256 | 257 | # mypy 258 | .mypy_cache/ 259 | .dmypy.json 260 | dmypy.json 261 | 262 | # Pyre type checker 263 | .pyre/ 264 | 265 | ### Vim ### 266 | # Swap 267 | [._]*.s[a-v][a-z] 268 | [._]*.sw[a-p] 269 | [._]s[a-rt-v][a-z] 270 | [._]ss[a-gi-z] 271 | [._]sw[a-p] 272 | 273 | # Session 274 | Session.vim 275 | Sessionx.vim 276 | 277 | # Temporary 278 | .netrwhist 279 | *~ 280 | # Auto-generated tag files 281 | tags 282 | # Persistent undo 283 | [._]*.un~ 284 | 285 | # End of https://www.gitignore.io/api/vim,python,pycharm 286 | 287 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | you may not use this file except in compliance with the License. 4 | You may obtain a copy of the License at 5 | 6 | https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | 8 | The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 14 | 15 | Parts of this repository are licensed by 16 | https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 17 | https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 18 | -------------------------------------------------------------------------------- /LICENSES/LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | you may not use this file except in compliance with the License. 4 | You may obtain a copy of the License at 5 | 6 | https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | 8 | The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | Unless required by applicable law or agreed to in writing, software 10 | distributed under the License is distributed on an "AS IS" BASIS, 11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | See the License for the specific language governing permissions and 13 | limitations under the License. 
14 | 15 | Parts of this repository are licensed by 16 | https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 17 | https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE -------------------------------------------------------------------------------- /LICENSES/LICENSE_BasicSR: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright 2018-2020 BasicSR Authors 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
-------------------------------------------------------------------------------- /LICENSES/LICENSE_GLOW: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Yuki-Chai 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. -------------------------------------------------------------------------------- /LICENSES/README.md: -------------------------------------------------------------------------------- 1 | # License and Acknowledgement 2 | 3 | A big thanks to the following contributors who open-sourced their code and thereby helped us a lot in developing SRFlow! 4 | 5 | ## BasicSR 6 | The training framework was adapted from https://github.com/xinntao/BasicSR 7 | 8 | ## GLOW 9 | The Normalizing Flow modules were adapted from https://github.com/chaiyujin/glow-pytorch 10 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SRFlow 2 | #### Official SRFlow training code: Super-Resolution using Normalizing Flow in PyTorch

3 | #### [[Paper] ECCV 2020 Spotlight](https://bit.ly/2DkwQcg) 4 | 5 |
6 | 7 | **News:** Unified Image Super-Resolution and Rescaling [[code](https://bit.ly/2VOKHBb)] 8 |
9 |
10 | 11 | [![SRFlow](https://user-images.githubusercontent.com/11280511/98149322-7ed5c580-1ecd-11eb-8279-f02de9f0df12.gif)](https://bit.ly/3jWFRcr) 12 |
13 |
14 |
15 | 16 | # Setup: Data, Environment, PyTorch Demo 17 | 18 |
19 | 20 | ```bash 21 | git clone https://github.com/andreas128/SRFlow.git && cd SRFlow && ./setup.sh 22 | ``` 23 | 24 |
25 | 26 | This one-liner will: 27 | - Clone SRFlow 28 | - Set up a Python 3 virtual env 29 | - Install the packages from `requirements.txt` 30 | - Download the pretrained models 31 | - Download the DIV2K validation data 32 | - Run the Demo Jupyter Notebook 33 | 34 | If you want to install it manually, read the `setup.sh` file (it contains the links to the data/models and the pip packages). 35 |
37 |
38 | 39 | # Demo: Try Normalizing Flow in PyTorch 40 | 41 | ```bash 42 | ./run_jupyter.sh 43 | ``` 44 | 45 | This notebook lets you: 46 | - Load the pretrained models. 47 | - Super-resolve images. 48 | - Measure PSNR/SSIM/LPIPS. 49 | - Infer the Normalizing Flow latent space. 50 | 51 |
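If you prefer to compute the metrics outside the notebook, `code/Measure.py` (included further down in this listing) wraps PSNR, SSIM, and LPIPS behind a single class. A minimal sketch, run from inside `code/` with the environment from `setup.sh`; the image paths are placeholders:

```python
from Measure import Measure, imread  # helpers from code/Measure.py

measure = Measure(net='alex', use_gpu=False)  # LPIPS with the AlexNet backbone

sr = imread('path/to/sr.png')  # imread returns an HxWx3 RGB uint8 array
gt = imread('path/to/gt.png')  # must have the same size as sr

# measure() returns [PSNR, SSIM, LPIPS] as plain floats
psnr, ssim, lpips_dist = measure.measure(sr, gt)
print(f'PSNR {psnr:0.2f} dB, SSIM {ssim:0.3f}, LPIPS {lpips_dist:0.3f}')
```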

52 | 53 | # Testing: Apply the included pretrained models 54 | 55 | ```bash 56 | source myenv/bin/activate # Use the env you created using setup.sh 57 | cd code 58 | CUDA_VISIBLE_DEVICES=-1 python test.py ./confs/SRFlow_DF2K_4X.yml # Diverse Images 4X (Dataset Included) 59 | CUDA_VISIBLE_DEVICES=-1 python test.py ./confs/SRFlow_DF2K_8X.yml # Diverse Images 8X (Dataset Included) 60 | CUDA_VISIBLE_DEVICES=-1 python test.py ./confs/SRFlow_CelebA_8X.yml # Faces 8X 61 | ``` 62 | For testing, we apply SRFlow to the full images on CPU. 63 | 64 |
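The test configs below set `heat: 0.9`, which is the standard deviation of the Gaussian latent that the flow decodes into an SR image: heat 0 always returns the single most likely prediction, while higher heats trade fidelity for sample diversity. A toy sketch of the sampling idea in plain PyTorch (not the repo's API):

```python
import torch

torch.manual_seed(0)

# SRFlow decodes z ~ N(0, heat^2), conditioned on the LR image. At heat = 0
# every draw collapses to the same output; larger heats spread the samples
# out and yield more varied textures.
for heat in [0.0, 0.5, 0.9]:
    z = heat * torch.randn(3, 160, 160)  # latent with standard deviation = heat
    print(f'heat={heat}: empirical std of z = {z.std().item():.3f}')
```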

65 | 66 | # Training: Reproduce or train on your data 67 | 68 | The following commands train the Super-Resolution network using Normalizing Flow in PyTorch: 69 | 70 | ```bash 71 | source myenv/bin/activate # Use the env you created using setup.sh 72 | cd code 73 | python train.py -opt ./confs/SRFlow_DF2K_4X.yml # Diverse Images 4X (Dataset Included) 74 | python train.py -opt ./confs/SRFlow_DF2K_8X.yml # Diverse Images 8X (Dataset Included) 75 | python train.py -opt ./confs/SRFlow_CelebA_8X.yml # Faces 8X 76 | ``` 77 | 78 | - To reduce GPU memory usage, reduce the batch size in the yml file. 79 | - The CelebA license does not allow us to host the dataset; the preparation steps are described below. 80 | 81 | ### How to prepare CelebA? 82 | 83 | **1. Get HD-CelebA-Cropper** 84 | 85 | ```git clone https://github.com/LynnHo/HD-CelebA-Cropper``` 86 | 87 | **2. Download the dataset** 88 | 89 | Get `img_celeba.7z` and `annotations.zip` as described in the [Readme](https://github.com/LynnHo/HD-CelebA-Cropper). 90 | 91 | **3. Run the crop align** 92 | 93 | ```python3 align.py --img_dir ./data/data --crop_size_h 640 --crop_size_w 640 --order 3 --face_factor 0.6 --n_worker 8``` 94 | 95 | **4. Downsample for GT** 96 | 97 | Use the [MATLAB-like kernel](https://github.com/fatheral/matlab_imresize) to downscale the crops to 160x160 for the GT images. 98 | 99 | **5. Downsample for LR** 100 | 101 | Downscale the GT with the same MATLAB kernel to the LR size (40x40 or 20x20); a code sketch for steps 4 and 5 follows after this list. 102 | 103 | **6. Train/Validation** 104 | 105 | For training and validation, we use the corresponding sets defined by CelebA (Train: 000001-162770, Validation: 162771-182637). 106 | 107 | **7. Pack to pickle for training** 108 | 109 | `cd code && python prepare_data.py /path/to/img_dir` 110 | 111 |
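For steps 4 and 5, the repo ships `code/imresize.py`, a NumPy port of MATLAB's `imresize` (the file appears at the end of this listing). A sketch under the assumption that the cropped faces are 640x640 `uint8` arrays and that the port keeps the upstream `imresize(img, scalar_scale=..., output_shape=...)` signature:

```python
import numpy as np
from imresize import imresize  # code/imresize.py, port of fatheral/matlab_imresize

face = np.random.randint(0, 256, (640, 640, 3), dtype=np.uint8)  # stand-in for a cropped face

gt = imresize(face, output_shape=(160, 160))  # step 4: 640x640 -> 160x160 GT
lr = imresize(gt, scalar_scale=1 / 8)         # step 5: 160x160 -> 20x20 LR for 8X

print(gt.shape, lr.shape)  # (160, 160, 3) (20, 20, 3)
```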

112 | 113 | # Dataset: How to train on your own data 114 | 115 | The following command creates the pickle files that you can use in the yaml config file: 116 | 117 | ```bash 118 | cd code 119 | python prepare_data.py /path/to/img_dir 120 | ``` 121 | 122 | The precomputed DF2K dataset is downloaded by `setup.sh`. You can reproduce it or prepare your own dataset as sketched below. 123 | 124 |
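Judging from `LRHR_PKL_dataset.load_pkls` (shown below), a `.pklv4` file is simply a pickled, non-empty list of HxWxC `uint8` images; the "v4" presumably refers to pickle protocol 4. A hedged sketch that writes a compatible toy GT/LR pair with made-up file names:

```python
import pickle
import numpy as np

# Matching HR/LR lists of HxWxC uint8 arrays (HR is 4x the LR size here).
hr_images = [np.zeros((160, 160, 3), dtype=np.uint8) for _ in range(4)]
lr_images = [np.zeros((40, 40, 3), dtype=np.uint8) for _ in range(4)]

# load_pkls() just calls pickle.load(f) and asserts a non-empty list,
# so a plain pickled list round-trips through the dataset loader.
with open('toy-train-gt.pklv4', 'wb') as f:
    pickle.dump(hr_images, f, protocol=4)
with open('toy-train-x4.pklv4', 'wb') as f:
    pickle.dump(lr_images, f, protocol=4)
```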

125 | 126 | # Our paper explains 127 | 128 | - **How to train Conditional Normalizing Flow**
129 | We designed an architecture that achieves state-of-the-art super-resolution quality. 130 | - **How to train Normalizing Flow on a single GPU**
131 | We based our network on GLOW, which uses up to 40 GPUs to train for image generation. SRFlow only needs a single GPU for training conditional image generation. 132 | - **How to use Normalizing Flow for image manipulation**
133 | We show how to exploit the latent space of Normalizing Flows for controlled image manipulations. 134 | - **See many Visual Results**
135 | Compare GAN vs Normalizing Flow yourself. We've included many visual results in our [[Paper]](https://bit.ly/2D9cN0L). 136 | 137 |

138 | 139 | # GAN vs Normalizing Flow - Blog 140 | 141 | [![](https://user-images.githubusercontent.com/11280511/98148862-56e66200-1ecd-11eb-817e-87e99dcab6ca.gif)](https://bit.ly/2EdJzhy) 142 | 143 | - **Sampling:** SRFlow outputs many different images for a single input. 144 | - **Stable Training:** SRFlow has far fewer hyperparameters than GAN approaches, and we did not encounter training stability issues. 145 | - **Convergence:** While GANs cannot converge, conditional Normalizing Flows converge monotonically and stably. 146 | - **Higher Consistency:** When downsampling the super-resolved image, one obtains almost exactly the input (see the sketch below). 147 | 148 | Get a quick introduction to Normalizing Flow in our [[Blog]](https://bit.ly/320bAkH). 149 |
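The consistency claim is easy to verify numerically: bicubically downscale a super-resolved result and compare it with the LR input it was generated from. A sketch using the repo's own helpers, run from inside `code/` with hypothetical file names:

```python
from skimage.metrics import peak_signal_noise_ratio as psnr

from Measure import imread     # returns an HxWx3 RGB uint8 array
from imresize import imresize  # MATLAB-style bicubic kernel

sr = imread('path/to/sr.png')  # a 4X result written by test.py
lr = imread('path/to/lr.png')  # the LR image it was generated from

sr_down = imresize(sr, scalar_scale=1 / 4)  # same kernel family as the data preparation
print('LR-consistency PSNR:', psnr(lr, sr_down))  # the higher, the more consistent
```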


150 | 151 |

152 | 153 | # Wanna help to improve the code? 154 | 155 | If you found a bug or improved the code, please do the following: 156 | 157 | - Fork this repo. 158 | - Push the changes to your repo. 159 | - Create a pull request. 160 | 161 |

162 | 163 | # Paper 164 | [[Paper] ECCV 2020 Spotlight](https://bit.ly/2XcmSks) 165 | 166 | ```bibtex 167 | @inproceedings{lugmayr2020srflow, 168 | title={SRFlow: Learning the Super-Resolution Space with Normalizing Flow}, 169 | author={Lugmayr, Andreas and Danelljan, Martin and Van Gool, Luc and Timofte, Radu}, 170 | booktitle={ECCV}, 171 | year={2020} 172 | } 173 | ``` 174 |

175 | -------------------------------------------------------------------------------- /code/Measure.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import glob 16 | import os 17 | import time 18 | from collections import OrderedDict 19 | 20 | import numpy as np 21 | import torch 22 | import cv2 23 | import argparse 24 | 25 | from natsort import natsort 26 | from skimage.metrics import structural_similarity as ssim 27 | from skimage.metrics import peak_signal_noise_ratio as psnr 28 | import lpips 29 | 30 | 31 | class Measure(): 32 | def __init__(self, net='alex', use_gpu=False): 33 | self.device = 'cuda' if use_gpu else 'cpu' 34 | self.model = lpips.LPIPS(net=net) 35 | self.model.to(self.device) 36 | 37 | def measure(self, imgA, imgB): 38 | return [float(f(imgA, imgB)) for f in [self.psnr, self.ssim, self.lpips]] 39 | 40 | def lpips(self, imgA, imgB, model=None): 41 | tA = t(imgA).to(self.device) 42 | tB = t(imgB).to(self.device) 43 | dist01 = self.model.forward(tA, tB).item() 44 | return dist01 45 | 46 | def ssim(self, imgA, imgB): 47 | # multichannel: If True, treat the last dimension of the array as channels. Similarity calculations are done independently for each channel then averaged. 
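# Note: scikit-image >= 0.19 renames this argument to channel_axis (use channel_axis=-1 there), so the call below may need adjusting on newer versions.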
48 | score, diff = ssim(imgA, imgB, full=True, multichannel=True) 49 | return score 50 | 51 | def psnr(self, imgA, imgB): 52 | psnr_val = psnr(imgA, imgB) 53 | return psnr_val 54 | 55 | 56 | def t(img): 57 | def to_4d(img): 58 | assert len(img.shape) == 3 59 | assert img.dtype == np.uint8 60 | img_new = np.expand_dims(img, axis=0) 61 | assert len(img_new.shape) == 4 62 | return img_new 63 | 64 | def to_CHW(img): 65 | return np.transpose(img, [2, 0, 1]) 66 | 67 | def to_tensor(img): 68 | return torch.Tensor(img) 69 | 70 | return to_tensor(to_4d(to_CHW(img))) / 127.5 - 1 71 | 72 | 73 | def fiFindByWildcard(wildcard): 74 | return natsort.natsorted(glob.glob(wildcard, recursive=True)) 75 | 76 | 77 | def imread(path): 78 | return cv2.imread(path)[:, :, [2, 1, 0]] 79 | 80 | 81 | def format_result(psnr, ssim, lpips): 82 | return f'{psnr:0.2f}, {ssim:0.3f}, {lpips:0.3f}' 83 | 84 | def measure_dirs(dirA, dirB, use_gpu, verbose=False): 85 | if verbose: 86 | vprint = lambda x: print(x) 87 | else: 88 | vprint = lambda x: None 89 | 90 | 91 | t_init = time.time() 92 | 93 | paths_A = fiFindByWildcard(os.path.join(dirA, f'*.{type}')) 94 | paths_B = fiFindByWildcard(os.path.join(dirB, f'*.{type}')) 95 | 96 | vprint("Comparing: ") 97 | vprint(dirA) 98 | vprint(dirB) 99 | 100 | measure = Measure(use_gpu=use_gpu) 101 | 102 | results = [] 103 | for pathA, pathB in zip(paths_A, paths_B): 104 | result = OrderedDict() 105 | 106 | t = time.time() 107 | result['psnr'], result['ssim'], result['lpips'] = measure.measure(imread(pathA), imread(pathB)) 108 | d = time.time() - t 109 | vprint(f"{pathA.split('/')[-1]}, {pathB.split('/')[-1]}, {format_result(**result)}, {d:0.1f}") 110 | 111 | results.append(result) 112 | 113 | psnr = np.mean([result['psnr'] for result in results]) 114 | ssim = np.mean([result['ssim'] for result in results]) 115 | lpips = np.mean([result['lpips'] for result in results]) 116 | 117 | vprint(f"Final Result: {format_result(psnr, ssim, lpips)}, {time.time() - t_init:0.1f}s") 118 | 119 | 120 | if __name__ == "__main__": 121 | parser = argparse.ArgumentParser() 122 | parser.add_argument('-dirA', default='', type=str) 123 | parser.add_argument('-dirB', default='', type=str) 124 | parser.add_argument('-type', default='png') 125 | parser.add_argument('--use_gpu', action='store_true', default=False) 126 | args = parser.parse_args() 127 | 128 | dirA = args.dirA 129 | dirB = args.dirB 130 | type = args.type 131 | use_gpu = args.use_gpu 132 | 133 | if len(dirA) > 0 and len(dirB) > 0: 134 | measure_dirs(dirA, dirB, use_gpu=use_gpu, verbose=True) 135 | -------------------------------------------------------------------------------- /code/confs/RRDB_CelebA_8X.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | #### general settings 18 | name: train 19 | use_tb_logger: true 20 | model: SR 21 | distortion: sr 22 | scale: 8 23 | #gpu_ids: [ 0 ] 24 | 25 | #### datasets 26 | datasets: 27 | train: 28 | name: CelebA_160_tr 29 | mode: LRHR_PKL 30 | dataroot_GT: ../datasets/celebA-train-gt_1pct.pklv4 31 | dataroot_LQ: ../datasets/celebA-train-x8_1pct.pklv4 32 | 33 | use_shuffle: true 34 | n_workers: 0 # per GPU 35 | batch_size: 16 36 | GT_size: 160 37 | use_flip: true 38 | use_rot: true 39 | color: RGB 40 | val: 41 | name: CelebA_160_va 42 | mode: LRHR_PKL 43 | dataroot_GT: ../datasets/celebA-valid-gt_1pct.pklv4 44 | dataroot_LQ: ../datasets/celebA-valid-x8_1pct.pklv4 45 | n_max: 10 46 | 47 | #### network structures 48 | network_G: 49 | which_model_G: RRDBNet 50 | in_nc: 3 51 | out_nc: 3 52 | nf: 64 53 | nb: 23 54 | 55 | #### path 56 | path: 57 | pretrain_model_G: ~ 58 | strict_load: true 59 | resume_state: auto 60 | 61 | #### training settings: learning rate scheme, loss 62 | train: 63 | lr_G: !!float 2e-4 64 | lr_scheme: CosineAnnealingLR_Restart 65 | beta1: 0.9 66 | beta2: 0.99 67 | niter: 200000 68 | warmup_iter: -1 # no warm up 69 | T_period: [ 50000, 50000, 50000, 50000 ] 70 | restarts: [ 50000, 100000, 150000 ] 71 | restart_weights: [ 1, 1, 1 ] 72 | eta_min: !!float 1e-7 73 | 74 | pixel_criterion: l1 75 | pixel_weight: 1.0 76 | 77 | manual_seed: 10 78 | val_freq: !!float 5e3 79 | 80 | #### logger 81 | logger: 82 | print_freq: 100 83 | save_checkpoint_freq: !!float 1e3 84 | -------------------------------------------------------------------------------- /code/confs/RRDB_DF2K_4X.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | #### general settings 18 | name: train 19 | use_tb_logger: true 20 | model: SR 21 | distortion: sr 22 | scale: 4 23 | gpu_ids: [ 0 ] 24 | 25 | #### datasets 26 | datasets: 27 | train: 28 | name: CelebA_160_tr 29 | mode: LRHR_PKL 30 | dataroot_GT: ../datasets/DF2K-train-gt_1pct.pklv4 31 | dataroot_LQ: ../datasets/DF2K-train-x4_1pct.pklv4 32 | quant: 32 33 | 34 | use_shuffle: true 35 | n_workers: 3 # per GPU 36 | batch_size: 16 37 | GT_size: 160 38 | use_flip: true 39 | color: RGB 40 | val: 41 | name: CelebA_160_va 42 | mode: LRHR_PKL 43 | dataroot_GT: ../datasets/DF2K-valid-gt_1pct.pklv4 44 | dataroot_LQ: ../datasets/DF2K-valid-x4_1pct.pklv4 45 | quant: 32 46 | n_max: 20 47 | 48 | #### network structures 49 | network_G: 50 | which_model_G: RRDBNet 51 | use_orig: True 52 | in_nc: 3 53 | out_nc: 3 54 | nf: 64 55 | nb: 23 56 | 57 | #### path 58 | path: 59 | pretrain_model_G: ~ 60 | strict_load: true 61 | resume_state: auto 62 | 63 | #### training settings: learning rate scheme, loss 64 | train: 65 | lr_G: !!float 2e-4 66 | lr_scheme: CosineAnnealingLR_Restart 67 | beta1: 0.9 68 | beta2: 0.99 69 | niter: 1000000 70 | warmup_iter: -1 # no warm up 71 | T_period: [ 50000, 50000, 50000, 50000 ] 72 | restarts: [ 50000, 100000, 150000 ] 73 | restart_weights: [ 1, 1, 1 ] 74 | eta_min: !!float 1e-7 75 | 76 | pixel_criterion: l1 77 | pixel_weight: 1.0 78 | 79 | manual_seed: 10 80 | val_freq: !!float 5e3 81 | 82 | #### logger 83 | logger: 84 | print_freq: 100 85 | save_checkpoint_freq: !!float 1e3 86 | -------------------------------------------------------------------------------- /code/confs/RRDB_DF2K_8X.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | #### general settings 18 | name: train 19 | use_tb_logger: true 20 | model: SR 21 | distortion: sr 22 | scale: 8 23 | gpu_ids: [ 0 ] 24 | 25 | #### datasets 26 | datasets: 27 | train: 28 | name: CelebA_160_tr 29 | mode: LRHR_PKL 30 | dataroot_GT: ../datasets/DF2K-train-gt_1pct.pklv4 31 | dataroot_LQ: ../datasets/DF2K-train-x8_1pct.pklv4 32 | quant: 32 33 | 34 | use_shuffle: true 35 | n_workers: 3 # per GPU 36 | batch_size: 16 37 | GT_size: 160 38 | use_flip: true 39 | color: RGB 40 | 41 | val: 42 | name: CelebA_160_va 43 | mode: LRHR_PKL 44 | dataroot_GT: ../datasets/DF2K-valid-gt_1pct.pklv4 45 | dataroot_LQ: ../datasets/DF2K-valid-x8_1pct.pklv4 46 | quant: 32 47 | n_max: 20 48 | 49 | #### network structures 50 | network_G: 51 | which_model_G: RRDBNet 52 | in_nc: 3 53 | out_nc: 3 54 | nf: 64 55 | nb: 23 56 | 57 | #### path 58 | path: 59 | pretrain_model_G: ~ 60 | strict_load: true 61 | resume_state: auto 62 | 63 | #### training settings: learning rate scheme, loss 64 | train: 65 | lr_G: !!float 2e-4 66 | lr_scheme: CosineAnnealingLR_Restart 67 | beta1: 0.9 68 | beta2: 0.99 69 | niter: 200000 70 | warmup_iter: -1 # no warm up 71 | T_period: [ 50000, 50000, 50000, 50000 ] 72 | restarts: [ 50000, 100000, 150000 ] 73 | restart_weights: [ 1, 1, 1 ] 74 | eta_min: !!float 1e-7 75 | 76 | pixel_criterion: l1 77 | pixel_weight: 1.0 78 | 79 | manual_seed: 10 80 | val_freq: !!float 5e3 81 | 82 | #### logger 83 | logger: 84 | print_freq: 100 85 | save_checkpoint_freq: !!float 1e3 86 | -------------------------------------------------------------------------------- /code/confs/SRFlow_CelebA_8X.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | #### general settings 18 | name: train 19 | use_tb_logger: true 20 | model: SRFlow 21 | distortion: sr 22 | scale: 8 23 | gpu_ids: [ 0 ] 24 | 25 | #### datasets 26 | datasets: 27 | train: 28 | name: CelebA_160_tr 29 | mode: LRHR_PKL 30 | dataroot_GT: ../datasets/celebA-train-gt.pklv4 31 | dataroot_LQ: ../datasets/celebA-train-x8.pklv4 32 | quant: 32 33 | 34 | use_shuffle: true 35 | n_workers: 3 # per GPU 36 | batch_size: 16 37 | GT_size: 160 38 | use_flip: true 39 | color: RGB 40 | val: 41 | name: CelebA_160_va 42 | mode: LRHR_PKL 43 | dataroot_GT: ../datasets/celebA-train-gt.pklv4 44 | dataroot_LQ: ../datasets/celebA-train-x8.pklv4 45 | quant: 32 46 | n_max: 20 47 | 48 | #### Test Settings 49 | dataroot_GT: ../datasets/celebA-validation-gt 50 | dataroot_LR: ../datasets/celebA-validation-x8 51 | model_path: ../pretrained_models/SRFlow_CelebA_8X.pth 52 | heat: 0.9 # This is the standard deviation of the latent vectors 53 | 54 | #### network structures 55 | network_G: 56 | which_model_G: SRFlowNet 57 | in_nc: 3 58 | out_nc: 3 59 | nf: 64 60 | nb: 8 61 | upscale: 8 62 | train_RRDB: false 63 | train_RRDB_delay: 0.5 64 | 65 | flow: 66 | K: 16 67 | L: 4 68 | noInitialInj: true 69 | coupling: CondAffineSeparatedAndCond 70 | additionalFlowNoAffine: 2 71 | split: 72 | enable: true 73 | fea_up0: true 74 | stackRRDB: 75 | blocks: [ 1, 3, 5, 7 ] 76 | concat: true 77 | 78 | #### path 79 | path: 80 | pretrain_model_G: ../pretrained_models/RRDB_CelebA_8X.pth 81 | strict_load: true 82 | resume_state: auto 83 | 84 | #### training settings: learning rate scheme, loss 85 | train: 86 | manual_seed: 10 87 | lr_G: !!float 5e-4 88 | weight_decay_G: 0 89 | beta1: 0.9 90 | beta2: 0.99 91 | lr_scheme: MultiStepLR 92 | warmup_iter: -1 # no warm up 93 | lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ] 94 | lr_gamma: 0.5 95 | 96 | niter: 200000 97 | val_freq: 40000 98 | 99 | #### validation settings 100 | val: 101 | heats: [ 0.0, 0.5, 0.75, 1.0 ] 102 | n_sample: 3 103 | 104 | #### logger 105 | logger: 106 | print_freq: 100 107 | save_checkpoint_freq: !!float 1e3 108 | -------------------------------------------------------------------------------- /code/confs/SRFlow_DF2K_4X.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | #### general settings 18 | name: train 19 | use_tb_logger: true 20 | model: SRFlow 21 | distortion: sr 22 | scale: 4 23 | gpu_ids: [ 0 ] 24 | 25 | #### datasets 26 | datasets: 27 | train: 28 | name: CelebA_160_tr 29 | mode: LRHR_PKL 30 | dataroot_GT: ../datasets/DF2K-tr.pklv4 31 | dataroot_LQ: ../datasets/DF2K-tr_X4.pklv4 32 | quant: 32 33 | 34 | use_shuffle: true 35 | n_workers: 3 # per GPU 36 | batch_size: 12 37 | GT_size: 160 38 | use_flip: true 39 | color: RGB 40 | val: 41 | name: CelebA_160_va 42 | mode: LRHR_PKL 43 | dataroot_GT: ../datasets/DIV2K-va.pklv4 44 | dataroot_LQ: ../datasets/DIV2K-va_X4.pklv4 45 | quant: 32 46 | n_max: 20 47 | 48 | #### Test Settings 49 | dataroot_GT: ../datasets/div2k-validation-modcrop8-gt 50 | dataroot_LR: ../datasets/div2k-validation-modcrop8-x4 51 | model_path: ../pretrained_models/SRFlow_DF2K_4X.pth 52 | heat: 0.9 # This is the standard deviation of the latent vectors 53 | 54 | #### network structures 55 | network_G: 56 | which_model_G: SRFlowNet 57 | in_nc: 3 58 | out_nc: 3 59 | nf: 64 60 | nb: 23 61 | upscale: 4 62 | train_RRDB: false 63 | train_RRDB_delay: 0.5 64 | 65 | flow: 66 | K: 16 67 | L: 3 68 | noInitialInj: true 69 | coupling: CondAffineSeparatedAndCond 70 | additionalFlowNoAffine: 2 71 | split: 72 | enable: true 73 | fea_up0: true 74 | stackRRDB: 75 | blocks: [ 1, 8, 15, 22 ] 76 | concat: true 77 | 78 | #### path 79 | path: 80 | pretrain_model_G: ../pretrained_models/RRDB_DF2K_4X.pth 81 | strict_load: true 82 | resume_state: auto 83 | 84 | #### training settings: learning rate scheme, loss 85 | train: 86 | manual_seed: 10 87 | lr_G: !!float 2.5e-4 88 | weight_decay_G: 0 89 | beta1: 0.9 90 | beta2: 0.99 91 | lr_scheme: MultiStepLR 92 | warmup_iter: -1 # no warm up 93 | lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ] 94 | lr_gamma: 0.5 95 | 96 | niter: 200000 97 | val_freq: 40000 98 | 99 | #### validation settings 100 | val: 101 | heats: [ 0.0, 0.5, 0.75, 1.0 ] 102 | n_sample: 3 103 | 104 | #### logger 105 | logger: 106 | print_freq: 100 107 | save_checkpoint_freq: !!float 1e3 108 | -------------------------------------------------------------------------------- /code/confs/SRFlow_DF2K_8X.yml: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | #### general settings 18 | name: train 19 | use_tb_logger: true 20 | model: SRFlow 21 | distortion: sr 22 | scale: 8 23 | gpu_ids: [ 0 ] 24 | 25 | #### datasets 26 | datasets: 27 | train: 28 | name: CelebA_160_tr 29 | mode: LRHR_PKL 30 | dataroot_GT: ../datasets/DF2K-tr.pklv4 31 | dataroot_LQ: ../datasets/DF2K-tr_X8.pklv4 32 | quant: 32 33 | 34 | use_shuffle: true 35 | n_workers: 3 # per GPU 36 | batch_size: 16 37 | GT_size: 160 38 | use_flip: true 39 | color: RGB 40 | 41 | val: 42 | name: CelebA_160_va 43 | mode: LRHR_PKL 44 | dataroot_GT: ../datasets/DIV2K-va.pklv4 45 | dataroot_LQ: ../datasets/DIV2K-va_X8.pklv4 46 | quant: 32 47 | n_max: 20 48 | 49 | #### Test Settings 50 | dataroot_GT: ../datasets/div2k-validation-modcrop8-gt 51 | dataroot_LR: ../datasets/div2k-validation-modcrop8-x8 52 | model_path: ../pretrained_models/SRFlow_DF2K_8X.pth 53 | heat: 0.9 # This is the standard deviation of the latent vectors 54 | 55 | #### network structures 56 | network_G: 57 | which_model_G: SRFlowNet 58 | in_nc: 3 59 | out_nc: 3 60 | nf: 64 61 | nb: 23 62 | upscale: 8 63 | train_RRDB: false 64 | train_RRDB_delay: 0.5 65 | 66 | flow: 67 | K: 16 68 | L: 4 69 | noInitialInj: true 70 | coupling: CondAffineSeparatedAndCond 71 | additionalFlowNoAffine: 2 72 | split: 73 | enable: true 74 | fea_up0: true 75 | stackRRDB: 76 | blocks: [ 1, 3, 5, 7 ] 77 | concat: true 78 | 79 | #### path 80 | path: 81 | pretrain_model_G: ../pretrained_models/RRDB_DF2K_8X.pth 82 | strict_load: true 83 | resume_state: auto 84 | 85 | #### training settings: learning rate scheme, loss 86 | train: 87 | manual_seed: 10 88 | lr_G: !!float 5e-4 89 | weight_decay_G: 0 90 | beta1: 0.9 91 | beta2: 0.99 92 | lr_scheme: MultiStepLR 93 | warmup_iter: -1 # no warm up 94 | lr_steps_rel: [ 0.5, 0.75, 0.9, 0.95 ] 95 | lr_gamma: 0.5 96 | 97 | niter: 200000 98 | val_freq: 40000 99 | 100 | #### validation settings 101 | val: 102 | heats: [ 0.0, 0.5, 0.75, 1.0 ] 103 | n_sample: 3 104 | 105 | test: 106 | heats: [ 0.0, 0.7, 0.8, 0.9 ] 107 | 108 | #### logger 109 | logger: 110 | # Debug print_freq: 100 111 | print_freq: 100 112 | save_checkpoint_freq: !!float 1e3 113 | -------------------------------------------------------------------------------- /code/data/LRHR_PKL_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | import os 18 | import subprocess 19 | import torch.utils.data as data 20 | import numpy as np 21 | import time 22 | import torch 23 | 24 | import pickle 25 | 26 | 27 | class LRHR_PKLDataset(data.Dataset): 28 | def __init__(self, opt): 29 | super(LRHR_PKLDataset, self).__init__() 30 | self.opt = opt 31 | self.crop_size = opt.get("GT_size", None) 32 | self.scale = None 33 | self.random_scale_list = [1] 34 | 35 | hr_file_path = opt["dataroot_GT"] 36 | lr_file_path = opt["dataroot_LQ"] 37 | y_labels_file_path = opt.get('dataroot_y_labels')  # optional; the included configs do not set it 38 | 39 | gpu = True 40 | augment = True 41 | 42 | self.use_flip = opt["use_flip"] if "use_flip" in opt.keys() else False 43 | self.use_rot = opt["use_rot"] if "use_rot" in opt.keys() else False 44 | self.use_crop = opt["use_crop"] if "use_crop" in opt.keys() else False 45 | self.center_crop_hr_size = opt.get("center_crop_hr_size", None) 46 | 47 | n_max = opt["n_max"] if "n_max" in opt.keys() else int(1e8) 48 | 49 | t = time.time() 50 | self.lr_images = self.load_pkls(lr_file_path, n_max) 51 | self.hr_images = self.load_pkls(hr_file_path, n_max) 52 | 53 | min_val_hr = np.min([i.min() for i in self.hr_images[:20]]) 54 | max_val_hr = np.max([i.max() for i in self.hr_images[:20]]) 55 | 56 | min_val_lr = np.min([i.min() for i in self.lr_images[:20]]) 57 | max_val_lr = np.max([i.max() for i in self.lr_images[:20]]) 58 | 59 | t = time.time() - t 60 | print("Loaded {} HR images with [{:.2f}, {:.2f}] in {:.2f}s from {}". 61 | format(len(self.hr_images), min_val_hr, max_val_hr, t, hr_file_path)) 62 | print("Loaded {} LR images with [{:.2f}, {:.2f}] in {:.2f}s from {}". 63 | format(len(self.lr_images), min_val_lr, max_val_lr, t, lr_file_path)) 64 | 65 | self.gpu = gpu 66 | self.augment = augment 67 | 68 | self.measures = None 69 | 70 | def load_pkls(self, path, n_max): 71 | assert os.path.isfile(path), path 72 | images = [] 73 | with open(path, "rb") as f: 74 | images += pickle.load(f) 75 | assert len(images) > 0, path 76 | images = images[:n_max] 77 | images = [np.transpose(image, [2, 0, 1]) for image in images] 78 | return images 79 | 80 | def __len__(self): 81 | return len(self.hr_images) 82 | 83 | def __getitem__(self, item): 84 | hr = self.hr_images[item] 85 | lr = self.lr_images[item] 86 | 87 | if self.scale is None: 88 | self.scale = hr.shape[1] // lr.shape[1] 89 | assert hr.shape[1] == self.scale * lr.shape[1], ('non-integer scale ratio', lr.shape, hr.shape) 90 | 91 | if self.use_crop: 92 | hr, lr = random_crop(hr, lr, self.crop_size, self.scale, self.use_crop) 93 | 94 | if self.center_crop_hr_size: 95 | hr, lr = center_crop(hr, self.center_crop_hr_size), center_crop(lr, self.center_crop_hr_size // self.scale) 96 | 97 | if self.use_flip: 98 | hr, lr = random_flip(hr, lr) 99 | 100 | if self.use_rot: 101 | hr, lr = random_rotation(hr, lr) 102 | 103 | hr = hr / 255.0 104 | lr = lr / 255.0 105 | 106 | if self.measures is None or np.random.random() < 0.05: 107 | if self.measures is None: 108 | self.measures = {} 109 | self.measures['hr_means'] = np.mean(hr) 110 | self.measures['hr_stds'] = np.std(hr) 111 | self.measures['lr_means'] = np.mean(lr) 112 | self.measures['lr_stds'] = np.std(lr) 113 | 114 | hr = torch.Tensor(hr) 115 | lr = torch.Tensor(lr) 116 | 117 | # if self.gpu: 118 | # hr = hr.cuda() 119 | # lr = lr.cuda() 120 | 121 | return {'LQ': lr, 'GT': hr, 'LQ_path': str(item), 'GT_path': str(item)} 122 | 123 | def print_and_reset(self, tag):
124 | m = self.measures 125 | kvs = [] 126 | for k in sorted(m.keys()): 127 | kvs.append("{}={:.2f}".format(k, m[k])) 128 | print("[KPI] " + tag + ": " + ", ".join(kvs)) 129 | self.measures = None 130 | 131 | 132 | def random_flip(img, seg): 133 | random_choice = np.random.choice([True, False]) 134 | img = img if random_choice else np.flip(img, 2).copy() 135 | seg = seg if random_choice else np.flip(seg, 2).copy() 136 | return img, seg 137 | 138 | 139 | def random_rotation(img, seg): 140 | random_choice = np.random.choice([0, 1, 3]) 141 | img = np.rot90(img, random_choice, axes=(1, 2)).copy() 142 | seg = np.rot90(seg, random_choice, axes=(1, 2)).copy() 143 | return img, seg 144 | 145 | 146 | def random_crop(hr, lr, size_hr, scale, random): 147 | size_lr = size_hr // scale 148 | 149 | size_lr_x = lr.shape[1] 150 | size_lr_y = lr.shape[2] 151 | 152 | start_x_lr = np.random.randint(low=0, high=(size_lr_x - size_lr) + 1) if size_lr_x > size_lr else 0 153 | start_y_lr = np.random.randint(low=0, high=(size_lr_y - size_lr) + 1) if size_lr_y > size_lr else 0 154 | 155 | # LR Patch 156 | lr_patch = lr[:, start_x_lr:start_x_lr + size_lr, start_y_lr:start_y_lr + size_lr] 157 | 158 | # HR Patch 159 | start_x_hr = start_x_lr * scale 160 | start_y_hr = start_y_lr * scale 161 | hr_patch = hr[:, start_x_hr:start_x_hr + size_hr, start_y_hr:start_y_hr + size_hr] 162 | 163 | return hr_patch, lr_patch 164 | 165 | 166 | def center_crop(img, size): 167 | assert img.shape[1] == img.shape[2], img.shape 168 | border_double = img.shape[1] - size 169 | assert border_double % 2 == 0, (img.shape, size) 170 | border = border_double // 2 171 | return img[:, border:-border, border:-border] 172 | 173 | 174 | def center_crop_tensor(img, size): 175 | assert img.shape[2] == img.shape[3], img.shape 176 | border_double = img.shape[2] - size 177 | assert border_double % 2 == 0, (img.shape, size) 178 | border = border_double // 2 179 | return img[:, :, border:-border, border:-border] 180 | -------------------------------------------------------------------------------- /code/data/__init__.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | '''create dataset and dataloader''' 18 | import logging 19 | import torch 20 | import torch.utils.data 21 | 22 | 23 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): 24 | phase = dataset_opt.get('phase', 'test') 25 | if phase == 'train': 26 | gpu_ids = opt.get('gpu_ids', None) 27 | gpu_ids = gpu_ids if gpu_ids else [] 28 | num_workers = dataset_opt['n_workers'] * len(gpu_ids) 29 | batch_size = dataset_opt['batch_size'] 30 | shuffle = True 31 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 32 | num_workers=num_workers, sampler=sampler, drop_last=True, 33 | pin_memory=False) 34 | else: 35 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, 36 | pin_memory=True) 37 | 38 | 39 | def create_dataset(dataset_opt): 40 | print(dataset_opt) 41 | mode = dataset_opt['mode'] 42 | if mode == 'LRHR_PKL': 43 | from data.LRHR_PKL_dataset import LRHR_PKLDataset as D 44 | else: 45 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) 46 | dataset = D(dataset_opt) 47 | 48 | logger = logging.getLogger('base') 49 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, 50 | dataset_opt['name'])) 51 | return dataset 52 | -------------------------------------------------------------------------------- /code/imresize.py: -------------------------------------------------------------------------------- 1 | # https://github.com/fatheral/matlab_imresize 2 | # 3 | # MIT License 4 | # 5 | # Copyright (c) 2020 Alex 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to the following conditions: 13 | # 14 | # The above copyright notice and this permission notice shall be included in all 15 | # copies or substantial portions of the Software. 16 | # 17 | # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 18 | # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 19 | # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 20 | # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 21 | # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 22 | # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 23 | # SOFTWARE. 
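
A minimal sketch of how the two factory functions above are typically wired together; the option keys mirror the confs/*.yml files, and all values here are illustrative:

dataset_opt = {
    'name': 'DF2K', 'mode': 'LRHR_PKL', 'phase': 'train',
    'dataroot_GT': '/tmp/hr.pklv4', 'dataroot_LQ': '/tmp/lr.pklv4',  # placeholder paths
    'dataroot_y_labels': None,
    'GT_size': 160, 'use_crop': True, 'use_flip': True,
    'n_workers': 3, 'batch_size': 16,
}
train_set = create_dataset(dataset_opt)
train_loader = create_dataloader(train_set, dataset_opt, opt={'gpu_ids': [0]})
for batch in train_loader:
    lq, gt = batch['LQ'], batch['GT']  # CHW float tensors in [0, 1]
    break
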
24 | 25 | 26 | from __future__ import print_function 27 | import numpy as np 28 | from math import ceil, floor 29 | 30 | 31 | def deriveSizeFromScale(img_shape, scale): 32 | output_shape = [] 33 | for k in range(2): 34 | output_shape.append(int(ceil(scale[k] * img_shape[k]))) 35 | return output_shape 36 | 37 | 38 | def deriveScaleFromSize(img_shape_in, img_shape_out): 39 | scale = [] 40 | for k in range(2): 41 | scale.append(1.0 * img_shape_out[k] / img_shape_in[k]) 42 | return scale 43 | 44 | 45 | def triangle(x): 46 | x = np.array(x).astype(np.float64) 47 | lessthanzero = np.logical_and((x >= -1), x < 0) 48 | greaterthanzero = np.logical_and((x <= 1), x >= 0) 49 | f = np.multiply((x + 1), lessthanzero) + np.multiply((1 - x), greaterthanzero) 50 | return f 51 | 52 | 53 | def cubic(x): 54 | x = np.array(x).astype(np.float64) 55 | absx = np.absolute(x) 56 | absx2 = np.multiply(absx, absx) 57 | absx3 = np.multiply(absx2, absx) 58 | f = np.multiply(1.5 * absx3 - 2.5 * absx2 + 1, absx <= 1) + np.multiply(-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2, 59 | (1 < absx) & (absx <= 2)) 60 | return f 61 | 62 | 63 | def contributions(in_length, out_length, scale, kernel, k_width): 64 | if scale < 1: 65 | h = lambda x: scale * kernel(scale * x) 66 | kernel_width = 1.0 * k_width / scale 67 | else: 68 | h = kernel 69 | kernel_width = k_width 70 | x = np.arange(1, out_length + 1).astype(np.float64) 71 | u = x / scale + 0.5 * (1 - 1 / scale) 72 | left = np.floor(u - kernel_width / 2) 73 | P = int(ceil(kernel_width)) + 2 74 | ind = np.expand_dims(left, axis=1) + np.arange(P) - 1 # -1 because indexing from 0 75 | indices = ind.astype(np.int32) 76 | weights = h(np.expand_dims(u, axis=1) - indices - 1) # -1 because indexing from 0 77 | weights = np.divide(weights, np.expand_dims(np.sum(weights, axis=1), axis=1)) 78 | aux = np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))).astype(np.int32) 79 | indices = aux[np.mod(indices, aux.size)] 80 | ind2store = np.nonzero(np.any(weights, axis=0)) 81 | weights = weights[:, ind2store] 82 | indices = indices[:, ind2store] 83 | return weights, indices 84 | 85 | 86 | def imresizemex(inimg, weights, indices, dim): 87 | in_shape = inimg.shape 88 | w_shape = weights.shape 89 | out_shape = list(in_shape) 90 | out_shape[dim] = w_shape[0] 91 | outimg = np.zeros(out_shape) 92 | if dim == 0: 93 | for i_img in range(in_shape[1]): 94 | for i_w in range(w_shape[0]): 95 | w = weights[i_w, :] 96 | ind = indices[i_w, :] 97 | im_slice = inimg[ind, i_img].astype(np.float64) 98 | outimg[i_w, i_img] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0) 99 | elif dim == 1: 100 | for i_img in range(in_shape[0]): 101 | for i_w in range(w_shape[0]): 102 | w = weights[i_w, :] 103 | ind = indices[i_w, :] 104 | im_slice = inimg[i_img, ind].astype(np.float64) 105 | outimg[i_img, i_w] = np.sum(np.multiply(np.squeeze(im_slice, axis=0), w.T), axis=0) 106 | if inimg.dtype == np.uint8: 107 | outimg = np.clip(outimg, 0, 255) 108 | return np.around(outimg).astype(np.uint8) 109 | else: 110 | return outimg 111 | 112 | 113 | def imresizevec(inimg, weights, indices, dim): 114 | wshape = weights.shape 115 | if dim == 0: 116 | weights = weights.reshape((wshape[0], wshape[2], 1, 1)) 117 | outimg = np.sum(weights * ((inimg[indices].squeeze(axis=1)).astype(np.float64)), axis=1) 118 | elif dim == 1: 119 | weights = weights.reshape((1, wshape[0], wshape[2], 1)) 120 | outimg = np.sum(weights * ((inimg[:, indices].squeeze(axis=2)).astype(np.float64)), axis=2) 121 | if inimg.dtype == 
np.uint8:
122 |         outimg = np.clip(outimg, 0, 255)
123 |         return np.around(outimg).astype(np.uint8)
124 |     else:
125 |         return outimg
126 | 
127 | 
128 | def resizeAlongDim(A, dim, weights, indices, mode="vec"):
129 |     if mode == "org":
130 |         out = imresizemex(A, weights, indices, dim)
131 |     else:
132 |         out = imresizevec(A, weights, indices, dim)
133 |     return out
134 | 
135 | 
136 | def imresize(I, scalar_scale=None, method='bicubic', output_shape=None, mode="vec"):
137 |     if method == 'bicubic':  # string equality; identity comparison is unreliable here
138 |         kernel = cubic
139 |     elif method == 'bilinear':
140 |         kernel = triangle
141 |     else:
142 |         raise ValueError('Unidentified method supplied: {}'.format(method))
143 | 
144 |     kernel_width = 4.0
145 |     # Fill scale and output_size
146 |     if scalar_scale is not None:
147 |         scalar_scale = float(scalar_scale)
148 |         scale = [scalar_scale, scalar_scale]
149 |         output_size = deriveSizeFromScale(I.shape, scale)
150 |     elif output_shape is not None:
151 |         scale = deriveScaleFromSize(I.shape, output_shape)
152 |         output_size = list(output_shape)
153 |     else:
154 |         raise ValueError('scalar_scale OR output_shape should be defined!')
155 | 
156 |     scale_np = np.array(scale)
157 |     order = np.argsort(scale_np)
158 |     weights = []
159 |     indices = []
160 |     for k in range(2):
161 |         w, ind = contributions(I.shape[k], output_size[k], scale[k], kernel, kernel_width)
162 |         weights.append(w)
163 |         indices.append(ind)
164 |     B = np.copy(I)
165 |     flag2D = False
166 |     if B.ndim == 2:
167 |         B = np.expand_dims(B, axis=2)
168 |         flag2D = True
169 |     for k in range(2):
170 |         dim = order[k]
171 |         B = resizeAlongDim(B, dim, weights[dim], indices[dim], mode)
172 |     if flag2D:
173 |         B = np.squeeze(B, axis=2)
174 |     return B
175 | 
176 | 
177 | def convertDouble2Byte(I):
178 |     B = np.clip(I, 0.0, 1.0)
179 |     B = 255 * B
180 |     return np.around(B).astype(np.uint8)
--------------------------------------------------------------------------------
/code/models/SRFlow_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
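
The functions above reimplement MATLAB's imresize (bicubic by default, with antialiasing when downscaling). A quick sketch of the two calling conventions:

import numpy as np

img = (np.random.rand(64, 48) * 255).astype(np.uint8)   # grayscale test image
up = imresize(img, scalar_scale=4.0)                    # by scale factor -> (256, 192)
down = imresize(img, output_shape=(32, 24))             # by explicit target size
# Float inputs in [0, 1] stay float; convertDouble2Byte maps them back to uint8:
small = convertDouble2Byte(imresize(img.astype(np.float64) / 255.0, scalar_scale=0.25))
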
14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | import logging 18 | from collections import OrderedDict 19 | from utils.util import get_resume_paths, opt_get 20 | 21 | import torch 22 | import torch.nn as nn 23 | from torch.nn.parallel import DataParallel, DistributedDataParallel 24 | import models.networks as networks 25 | import models.lr_scheduler as lr_scheduler 26 | from .base_model import BaseModel 27 | 28 | logger = logging.getLogger('base') 29 | 30 | 31 | class SRFlowModel(BaseModel): 32 | def __init__(self, opt, step): 33 | super(SRFlowModel, self).__init__(opt) 34 | self.opt = opt 35 | 36 | self.heats = opt['val']['heats'] 37 | self.n_sample = opt['val']['n_sample'] 38 | self.hr_size = opt_get(opt, ['datasets', 'train', 'center_crop_hr_size']) 39 | self.hr_size = 160 if self.hr_size is None else self.hr_size 40 | self.lr_size = self.hr_size // opt['scale'] 41 | 42 | if opt['dist']: 43 | self.rank = torch.distributed.get_rank() 44 | else: 45 | self.rank = -1 # non dist training 46 | train_opt = opt['train'] 47 | 48 | # define network and load pretrained models 49 | self.netG = networks.define_Flow(opt, step).to(self.device) 50 | if opt['dist']: 51 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) 52 | else: 53 | self.netG = DataParallel(self.netG) 54 | # print network 55 | self.print_network() 56 | 57 | if opt_get(opt, ['path', 'resume_state'], 1) is not None: 58 | self.load() 59 | else: 60 | print("WARNING: skipping initial loading, due to resume_state None") 61 | 62 | if self.is_train: 63 | self.netG.train() 64 | 65 | self.init_optimizer_and_scheduler(train_opt) 66 | self.log_dict = OrderedDict() 67 | 68 | def to(self, device): 69 | self.device = device 70 | self.netG.to(device) 71 | 72 | def init_optimizer_and_scheduler(self, train_opt): 73 | # optimizers 74 | self.optimizers = [] 75 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 76 | optim_params_RRDB = [] 77 | optim_params_other = [] 78 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model 79 | print(k, v.requires_grad) 80 | if v.requires_grad: 81 | if '.RRDB.' 
in k:
82 |                     optim_params_RRDB.append(v)
83 |                     print('opt', k)
84 |                 else:
85 |                     optim_params_other.append(v)
86 |             elif self.rank <= 0:
87 |                 logger.warning('Params [{:s}] will not optimize.'.format(k))
88 | 
89 |         print('rrdb params', len(optim_params_RRDB))
90 | 
91 |         self.optimizer_G = torch.optim.Adam(
92 |             [
93 |                 {"params": optim_params_other, "lr": train_opt['lr_G'],
94 |                  "betas": (train_opt['beta1'], train_opt['beta2']), "weight_decay": wd_G},
95 |                 {"params": optim_params_RRDB, "lr": train_opt.get('lr_RRDB', train_opt['lr_G']),
96 |                  "betas": (train_opt['beta1'], train_opt['beta2']),
97 |                  "weight_decay": wd_G}
98 |             ],
99 |         )
100 | 
101 |         self.optimizers.append(self.optimizer_G)
102 |         # schedulers
103 |         if train_opt['lr_scheme'] == 'MultiStepLR':
104 |             for optimizer in self.optimizers:
105 |                 self.schedulers.append(
106 |                     lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'],
107 |                                                      restarts=train_opt['restarts'],
108 |                                                      weights=train_opt['restart_weights'],
109 |                                                      gamma=train_opt['lr_gamma'],
110 |                                                      clear_state=train_opt['clear_state'],
111 |                                                      lr_steps_invese=train_opt.get('lr_steps_inverse', [])))
112 |         elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart':
113 |             for optimizer in self.optimizers:
114 |                 self.schedulers.append(
115 |                     lr_scheduler.CosineAnnealingLR_Restart(
116 |                         optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'],
117 |                         restarts=train_opt['restarts'], weights=train_opt['restart_weights']))
118 |         else:
119 |             raise NotImplementedError('Only MultiStepLR and CosineAnnealingLR_Restart are implemented.')
120 | 
121 |     def add_optimizer_and_scheduler_RRDB(self, train_opt):
122 |         # optimizers
123 |         assert len(self.optimizers) == 1, self.optimizers
124 |         assert len(self.optimizer_G.param_groups[1]['params']) == 0, self.optimizer_G.param_groups[1]
125 |         for k, v in self.netG.named_parameters():  # can optimize for a part of the model
126 |             if v.requires_grad:
127 |                 if '.RRDB.'
in k: 128 | self.optimizer_G.param_groups[1]['params'].append(v) 129 | assert len(self.optimizer_G.param_groups[1]['params']) > 0 130 | 131 | def feed_data(self, data, need_GT=True): 132 | self.var_L = data['LQ'].to(self.device) # LQ 133 | if need_GT: 134 | self.real_H = data['GT'].to(self.device) # GT 135 | 136 | def optimize_parameters(self, step): 137 | 138 | train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay']) 139 | if train_RRDB_delay is not None and step > int(train_RRDB_delay * self.opt['train']['niter']) \ 140 | and not self.netG.module.RRDB_training: 141 | if self.netG.module.set_rrdb_training(True): 142 | self.add_optimizer_and_scheduler_RRDB(self.opt['train']) 143 | 144 | # self.print_rrdb_state() 145 | 146 | self.netG.train() 147 | self.log_dict = OrderedDict() 148 | self.optimizer_G.zero_grad() 149 | 150 | losses = {} 151 | weight_fl = opt_get(self.opt, ['train', 'weight_fl']) 152 | weight_fl = 1 if weight_fl is None else weight_fl 153 | if weight_fl > 0: 154 | z, nll, y_logits = self.netG(gt=self.real_H, lr=self.var_L, reverse=False) 155 | nll_loss = torch.mean(nll) 156 | losses['nll_loss'] = nll_loss * weight_fl 157 | 158 | weight_l1 = opt_get(self.opt, ['train', 'weight_l1']) or 0 159 | if weight_l1 > 0: 160 | z = self.get_z(heat=0, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape) 161 | sr, logdet = self.netG(lr=self.var_L, z=z, eps_std=0, reverse=True, reverse_with_grad=True) 162 | l1_loss = (sr - self.real_H).abs().mean() 163 | losses['l1_loss'] = l1_loss * weight_l1 164 | 165 | total_loss = sum(losses.values()) 166 | total_loss.backward() 167 | self.optimizer_G.step() 168 | 169 | mean = total_loss.item() 170 | return mean 171 | 172 | def print_rrdb_state(self): 173 | for name, param in self.netG.module.named_parameters(): 174 | if "RRDB.conv_first.weight" in name: 175 | print(name, param.requires_grad, param.data.abs().sum()) 176 | print('params', [len(p['params']) for p in self.optimizer_G.param_groups]) 177 | 178 | def test(self): 179 | self.netG.eval() 180 | self.fake_H = {} 181 | for heat in self.heats: 182 | for i in range(self.n_sample): 183 | z = self.get_z(heat, seed=None, batch_size=self.var_L.shape[0], lr_shape=self.var_L.shape) 184 | with torch.no_grad(): 185 | self.fake_H[(heat, i)], logdet = self.netG(lr=self.var_L, z=z, eps_std=heat, reverse=True) 186 | with torch.no_grad(): 187 | _, nll, _ = self.netG(gt=self.real_H, lr=self.var_L, reverse=False) 188 | self.netG.train() 189 | return nll.mean().item() 190 | 191 | def get_encode_nll(self, lq, gt): 192 | self.netG.eval() 193 | with torch.no_grad(): 194 | _, nll, _ = self.netG(gt=gt, lr=lq, reverse=False) 195 | self.netG.train() 196 | return nll.mean().item() 197 | 198 | def get_sr(self, lq, heat=None, seed=None, z=None, epses=None): 199 | return self.get_sr_with_z(lq, heat, seed, z, epses)[0] 200 | 201 | def get_encode_z(self, lq, gt, epses=None, add_gt_noise=True): 202 | self.netG.eval() 203 | with torch.no_grad(): 204 | z, _, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise) 205 | self.netG.train() 206 | return z 207 | 208 | def get_encode_z_and_nll(self, lq, gt, epses=None, add_gt_noise=True): 209 | self.netG.eval() 210 | with torch.no_grad(): 211 | z, nll, _ = self.netG(gt=gt, lr=lq, reverse=False, epses=epses, add_gt_noise=add_gt_noise) 212 | self.netG.train() 213 | return z, nll 214 | 215 | def get_sr_with_z(self, lq, heat=None, seed=None, z=None, epses=None): 216 | self.netG.eval() 217 | 218 | z = self.get_z(heat, seed, 
batch_size=lq.shape[0], lr_shape=lq.shape) if z is None and epses is None else z 219 | 220 | with torch.no_grad(): 221 | sr, logdet = self.netG(lr=lq, z=z, eps_std=heat, reverse=True, epses=epses) 222 | self.netG.train() 223 | return sr, z 224 | 225 | def get_z(self, heat, seed=None, batch_size=1, lr_shape=None): 226 | if seed: torch.manual_seed(seed) 227 | if opt_get(self.opt, ['network_G', 'flow', 'split', 'enable']): 228 | C = self.netG.module.flowUpsamplerNet.C 229 | H = int(self.opt['scale'] * lr_shape[2] // self.netG.module.flowUpsamplerNet.scaleH) 230 | W = int(self.opt['scale'] * lr_shape[3] // self.netG.module.flowUpsamplerNet.scaleW) 231 | z = torch.normal(mean=0, std=heat, size=(batch_size, C, H, W)) if heat > 0 else torch.zeros( 232 | (batch_size, C, H, W)) 233 | else: 234 | L = opt_get(self.opt, ['network_G', 'flow', 'L']) or 3 235 | fac = 2 ** (L - 3) 236 | z_size = int(self.lr_size // (2 ** (L - 3))) 237 | z = torch.normal(mean=0, std=heat, size=(batch_size, 3 * 8 * 8 * fac * fac, z_size, z_size)) 238 | return z 239 | 240 | def get_current_log(self): 241 | return self.log_dict 242 | 243 | def get_current_visuals(self, need_GT=True): 244 | out_dict = OrderedDict() 245 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu() 246 | for heat in self.heats: 247 | for i in range(self.n_sample): 248 | out_dict[('SR', heat, i)] = self.fake_H[(heat, i)].detach()[0].float().cpu() 249 | if need_GT: 250 | out_dict['GT'] = self.real_H.detach()[0].float().cpu() 251 | return out_dict 252 | 253 | def print_network(self): 254 | s, n = self.get_network_description(self.netG) 255 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): 256 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 257 | self.netG.module.__class__.__name__) 258 | else: 259 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 260 | if self.rank <= 0: 261 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 262 | logger.info(s) 263 | 264 | def load(self): 265 | _, get_resume_model_path = get_resume_paths(self.opt) 266 | if get_resume_model_path is not None: 267 | self.load_network(get_resume_model_path, self.netG, strict=True, submodule=None) 268 | return 269 | 270 | load_path_G = self.opt['path']['pretrain_model_G'] 271 | load_submodule = self.opt['path']['load_submodule'] if 'load_submodule' in self.opt['path'].keys() else 'RRDB' 272 | if load_path_G is not None: 273 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 274 | self.load_network(load_path_G, self.netG, self.opt['path'].get('strict_load', True), 275 | submodule=load_submodule) 276 | 277 | def save(self, iter_label): 278 | self.save_network(self.netG, 'G', iter_label) 279 | -------------------------------------------------------------------------------- /code/models/SR_model.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 
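
For orientation, a hedged sketch of how the SRFlowModel sampling interface above is typically driven; create_model is defined in models/__init__.py further below, and opt and val_data are assumed to have been prepared by options.parse and the dataloader:

model = create_model(opt)           # builds SRFlowModel from the YAML options
model.feed_data(val_data)           # dict with 'LQ' (and optionally 'GT') tensors
for heat in opt['val']['heats']:    # e.g. [0.0, 0.5, 0.75]; higher heat = more texture
    sr = model.get_sr(lq=model.var_L, heat=heat)   # one stochastic sample per call
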
9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | import logging 18 | from collections import OrderedDict 19 | 20 | import torch 21 | import torch.nn as nn 22 | from torch.nn.parallel import DataParallel, DistributedDataParallel 23 | import models.networks as networks 24 | import models.lr_scheduler as lr_scheduler 25 | from utils.util import opt_get 26 | from .base_model import BaseModel 27 | from models.modules.loss import CharbonnierLoss 28 | 29 | logger = logging.getLogger('base') 30 | 31 | 32 | class SRModel(BaseModel): 33 | def __init__(self, opt, step): 34 | super(SRModel, self).__init__(opt) 35 | 36 | self.step = step 37 | 38 | if opt['dist']: 39 | self.rank = torch.distributed.get_rank() 40 | else: 41 | self.rank = -1 # non dist training 42 | train_opt = opt['train'] 43 | 44 | # define network and load pretrained_models models 45 | self.netG = networks.define_G(opt).to(self.device) 46 | if opt['dist']: 47 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) 48 | else: 49 | self.netG = DataParallel(self.netG) 50 | # print network 51 | self.print_network() 52 | self.load() 53 | 54 | if self.is_train: 55 | self.netG.train() 56 | 57 | # loss 58 | loss_type = train_opt['pixel_criterion'] 59 | if loss_type == 'l1': 60 | self.cri_pix = nn.L1Loss().to(self.device) 61 | elif loss_type == 'l2': 62 | self.cri_pix = nn.MSELoss().to(self.device) 63 | elif loss_type == 'cb': 64 | self.cri_pix = CharbonnierLoss().to(self.device) 65 | else: 66 | raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) 67 | self.l_pix_w = train_opt['pixel_weight'] 68 | 69 | # optimizers 70 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 71 | optim_params = [] 72 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model 73 | if v.requires_grad: 74 | optim_params.append(v) 75 | else: 76 | if self.rank <= 0: 77 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 78 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], 79 | weight_decay=wd_G, 80 | betas=(train_opt['beta1'], train_opt['beta2'])) 81 | self.optimizers.append(self.optimizer_G) 82 | 83 | # schedulers 84 | if train_opt['lr_scheme'] == 'MultiStepLR': 85 | for optimizer in self.optimizers: 86 | self.schedulers.append( 87 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], 88 | restarts=train_opt['restarts'], 89 | weights=train_opt['restart_weights'], 90 | gamma=train_opt['lr_gamma'], 91 | clear_state=train_opt['clear_state'])) 92 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': 93 | for optimizer in self.optimizers: 94 | self.schedulers.append( 95 | lr_scheduler.CosineAnnealingLR_Restart( 96 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], 97 | restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) 98 | else: 99 | raise NotImplementedError('MultiStepLR learning rate scheme is enough.') 100 | 101 | self.log_dict = OrderedDict() 102 | 103 | def feed_data(self, data, need_GT=True): 104 | self.var_L = data['LQ'].to(self.device) # LQ 105 | if 
need_GT: 106 | self.real_H = data['GT'].to(self.device) # GT 107 | 108 | def to(self, device): 109 | self.device = device 110 | self.netG.to(device) 111 | 112 | def optimize_parameters(self, step): 113 | def getEnv(name): import os; return True if name in os.environ.keys() else False 114 | 115 | if getEnv("DEBUG_FEED_IMAGES"): 116 | import imageio 117 | import random 118 | i = random.randint(0, 10000) 119 | label = self.var_L.cpu().numpy()[0].transpose([1, 2, 0]) 120 | print("var_L", label.min(), label.max(), label.shape) 121 | imageio.imwrite("/tmp/{}_l.png".format(i), label) 122 | image = self.real_H.cpu().numpy()[0].transpose([1, 2, 0]) 123 | print("self.real_H", image.min(), image.max(), image.shape) 124 | imageio.imwrite("/tmp/{}_gt.png".format(i), image) 125 | self.optimizer_G.zero_grad() 126 | self.fake_H = self.netG(self.var_L) 127 | l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H.to(self.fake_H.device)) 128 | l_pix.backward() 129 | self.optimizer_G.step() 130 | 131 | # set log 132 | self.log_dict['l_pix'] = l_pix.item() 133 | 134 | def test(self): 135 | self.netG.eval() 136 | with torch.no_grad(): 137 | self.fake_H = self.netG(self.var_L) 138 | self.netG.train() 139 | 140 | def get_encode_nll(self, lq, gt): 141 | return torch.ones(1) * 1e14 142 | 143 | def get_sr(self, lq, heat=None, seed=None): 144 | self.netG.eval() 145 | sr = self.netG(lq) 146 | self.netG.train() 147 | return sr 148 | 149 | def test_x8(self): 150 | # from https://github.com/thstkdgus35/EDSR-PyTorch 151 | self.netG.eval() 152 | 153 | def _transform(v, op): 154 | # if self.precision != 'single': v = v.float() 155 | v2np = v.data.cpu().numpy() 156 | if op == 'v': 157 | tfnp = v2np[:, :, :, ::-1].copy() 158 | elif op == 'h': 159 | tfnp = v2np[:, :, ::-1, :].copy() 160 | elif op == 't': 161 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 162 | 163 | ret = torch.Tensor(tfnp).to(self.device) 164 | # if self.precision == 'half': ret = ret.half() 165 | 166 | return ret 167 | 168 | lr_list = [self.var_L] 169 | for tf in 'v', 'h', 't': 170 | lr_list.extend([_transform(t, tf) for t in lr_list]) 171 | with torch.no_grad(): 172 | sr_list = [self.netG(aug) for aug in lr_list] 173 | for i in range(len(sr_list)): 174 | if i > 3: 175 | sr_list[i] = _transform(sr_list[i], 't') 176 | if i % 4 > 1: 177 | sr_list[i] = _transform(sr_list[i], 'h') 178 | if (i % 4) % 2 == 1: 179 | sr_list[i] = _transform(sr_list[i], 'v') 180 | 181 | output_cat = torch.cat(sr_list, dim=0) 182 | self.fake_H = output_cat.mean(dim=0, keepdim=True) 183 | self.netG.train() 184 | 185 | def get_current_log(self): 186 | return self.log_dict 187 | 188 | def get_current_visuals(self, need_GT=True): 189 | out_dict = OrderedDict() 190 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu() 191 | out_dict['SR'] = self.fake_H.detach()[0].float().cpu() 192 | if need_GT: 193 | out_dict['GT'] = self.real_H.detach()[0].float().cpu() 194 | return out_dict 195 | 196 | def print_network(self): 197 | s, n = self.get_network_description(self.netG) 198 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): 199 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 200 | self.netG.module.__class__.__name__) 201 | else: 202 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 203 | if self.rank <= 0: 204 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 205 | logger.info(s) 206 | 207 | def load(self): 208 | load_path_G = self.opt['path']['pretrain_model_G'] 209 | if 
load_path_G is not None:
210 |             logger.info('Loading model for G [{:s}] ...'.format(load_path_G))
211 |             self.load_network(load_path_G, self.netG, self.opt['path']['strict_load'])
212 | 
213 |     def save(self, iter_label):
214 |         self.save_network(self.netG, 'G', iter_label)
215 | 
216 |     def get_encode_z_and_nll(self, *args, **kwargs):
217 |         return [], torch.zeros(1)
218 | 
--------------------------------------------------------------------------------
/code/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import logging
3 | import os
4 | 
5 | try:
6 |     import local_config
7 | except ImportError:  # optional developer-machine overrides
8 |     local_config = None
9 | 
10 | 
11 | logger = logging.getLogger('base')
12 | 
13 | 
14 | def find_model_using_name(model_name):
15 |     # Given the option --model [modelname],
16 |     # the file "models/modelname_model.py"
17 |     # will be imported.
18 |     model_filename = "models." + model_name + "_model"
19 |     modellib = importlib.import_module(model_filename)
20 | 
21 |     # In that file, the class whose name matches
22 |     # "<modelname>Model" (underscores dropped, case-insensitive)
23 |     # will be instantiated. It has to be a subclass of BaseModel.
24 |     model = None
25 |     target_model_name = model_name.replace('_', '') + 'Model'
26 |     for name, cls in modellib.__dict__.items():
27 |         if name.lower() == target_model_name.lower():
28 |             model = cls
29 | 
30 |     if model is None:
31 |         print(
32 |             "In %s.py, there should be a model class whose name matches %s (case-insensitive)." % (
33 |                 model_filename, target_model_name))
34 |         exit(1)
35 | 
36 |     return model
37 | 
38 | 
39 | def create_model(opt, step=0, **opt_kwargs):
40 |     if local_config is not None:
41 |         opt['path']['pretrain_model_G'] = os.path.join(local_config.checkpoint_path, os.path.basename(opt['path']['results_root'] + '.pth'))
42 | 
43 |     for k, v in opt_kwargs.items():
44 |         opt[k] = v
45 | 
46 |     model = opt['model']
47 | 
48 |     M = find_model_using_name(model)
49 | 
50 |     m = M(opt, step)
51 |     logger.info('Model [{:s}] is created.'.format(m.__class__.__name__))
52 |     return m
53 | 
--------------------------------------------------------------------------------
/code/models/base_model.py:
--------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
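
How the name-to-class resolution above plays out for the two models in this repository, as an illustration:

M = find_model_using_name('SRFlow')   # imports models.SRFlow_model
assert M.__name__ == 'SRFlowModel'    # 'srflowmodel' == 'SRFlowModel'.lower()
M = find_model_using_name('SR')       # imports models.SR_model -> SRModel
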
14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | import os 18 | from collections import OrderedDict 19 | import torch 20 | import torch.nn as nn 21 | from torch.nn.parallel import DistributedDataParallel 22 | import natsort 23 | import glob 24 | 25 | 26 | class BaseModel(): 27 | def __init__(self, opt): 28 | self.opt = opt 29 | self.device = torch.device('cuda' if opt.get('gpu_ids', None) is not None else 'cpu') 30 | self.is_train = opt['is_train'] 31 | self.schedulers = [] 32 | self.optimizers = [] 33 | 34 | def feed_data(self, data): 35 | pass 36 | 37 | def optimize_parameters(self): 38 | pass 39 | 40 | def get_current_visuals(self): 41 | pass 42 | 43 | def get_current_losses(self): 44 | pass 45 | 46 | def print_network(self): 47 | pass 48 | 49 | def save(self, label): 50 | pass 51 | 52 | def load(self): 53 | pass 54 | 55 | def _set_lr(self, lr_groups_l): 56 | ''' set learning rate for warmup, 57 | lr_groups_l: list for lr_groups. each for a optimizer''' 58 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 59 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 60 | param_group['lr'] = lr 61 | 62 | def _get_init_lr(self): 63 | # get the initial lr, which is set by the scheduler 64 | init_lr_groups_l = [] 65 | for optimizer in self.optimizers: 66 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) 67 | return init_lr_groups_l 68 | 69 | def update_learning_rate(self, cur_iter, warmup_iter=-1): 70 | for scheduler in self.schedulers: 71 | scheduler.step() 72 | #### set up warm up learning rate 73 | if cur_iter < warmup_iter: 74 | # get initial lr for each group 75 | init_lr_g_l = self._get_init_lr() 76 | # modify warming-up learning rates 77 | warm_up_lr_l = [] 78 | for init_lr_g in init_lr_g_l: 79 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) 80 | # set learning rate 81 | self._set_lr(warm_up_lr_l) 82 | 83 | def get_current_learning_rate(self): 84 | # return self.schedulers[0].get_lr()[0] 85 | return self.optimizers[0].param_groups[0]['lr'] 86 | 87 | def get_network_description(self, network): 88 | '''Get the string and total parameters of the network''' 89 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 90 | network = network.module 91 | s = str(network) 92 | n = sum(map(lambda x: x.numel(), network.parameters())) 93 | return s, n 94 | 95 | def save_network(self, network, network_label, iter_label): 96 | paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['models'], "*_{}.pth".format(network_label))), 97 | reverse=True) 98 | paths = [p for p in paths if 99 | "latest_" not in p and not any([str(i * 10000) in p.split("/")[-1].split("_") for i in range(101)])] 100 | if len(paths) > 2: 101 | for path in paths[2:]: 102 | os.remove(path) 103 | save_filename = '{}_{}.pth'.format(iter_label, network_label) 104 | save_path = os.path.join(self.opt['path']['models'], save_filename) 105 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 106 | network = network.module 107 | state_dict = network.state_dict() 108 | for key, param in state_dict.items(): 109 | state_dict[key] = param.cpu() 110 | torch.save(state_dict, save_path) 111 | 112 | def load_network(self, load_path, network, strict=True, submodule=None): 113 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 114 | network = network.module 115 | if not (submodule 
is None or submodule.lower() == 'none'.lower()): 116 | network = network.__getattr__(submodule) 117 | load_net = torch.load(load_path) 118 | load_net_clean = OrderedDict() # remove unnecessary 'module.' 119 | for k, v in load_net.items(): 120 | if k.startswith('module.'): 121 | load_net_clean[k[7:]] = v 122 | else: 123 | load_net_clean[k] = v 124 | network.load_state_dict(load_net_clean, strict=strict) 125 | 126 | def save_training_state(self, epoch, iter_step): 127 | '''Saves training state during training, which will be used for resuming''' 128 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} 129 | for s in self.schedulers: 130 | state['schedulers'].append(s.state_dict()) 131 | for o in self.optimizers: 132 | state['optimizers'].append(o.state_dict()) 133 | save_filename = '{}.state'.format(iter_step) 134 | save_path = os.path.join(self.opt['path']['training_state'], save_filename) 135 | 136 | paths = natsort.natsorted(glob.glob(os.path.join(self.opt['path']['training_state'], "*.state")), 137 | reverse=True) 138 | paths = [p for p in paths if "latest_" not in p] 139 | if len(paths) > 2: 140 | for path in paths[2:]: 141 | os.remove(path) 142 | 143 | torch.save(state, save_path) 144 | 145 | def resume_training(self, resume_state): 146 | '''Resume the optimizers and schedulers for training''' 147 | resume_optimizers = resume_state['optimizers'] 148 | resume_schedulers = resume_state['schedulers'] 149 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' 150 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' 151 | for i, o in enumerate(resume_optimizers): 152 | self.optimizers[i].load_state_dict(o) 153 | for i, s in enumerate(resume_schedulers): 154 | self.schedulers[i].load_state_dict(s) 155 | -------------------------------------------------------------------------------- /code/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
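
The submodule argument of load_network above is what lets SRFlow warm-start only its RRDB encoder from a pretrained checkpoint; a sketch with a placeholder path:

# Restores weights into model.netG.RRDB only, after stripping any 'module.' prefixes.
model.load_network('/path/to/RRDB_DF2K_8X.pth', model.netG, strict=True, submodule='RRDB')
# Equivalently via the options file: set path.pretrain_model_G and path.load_submodule: RRDB.
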
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 | 
17 | import math
18 | from collections import Counter
19 | from collections import defaultdict
20 | import torch
21 | from torch.optim.lr_scheduler import _LRScheduler
22 | 
23 | 
24 | class MultiStepLR_Restart(_LRScheduler):
25 |     def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1,
26 |                  clear_state=False, last_epoch=-1, lr_steps_invese=None):
27 |         lr_steps_invese = lr_steps_invese if lr_steps_invese is not None else []  # default to no inverse steps
28 |         self.milestones = Counter(milestones)
29 |         self.lr_steps_inverse = Counter(lr_steps_invese)
30 |         self.gamma = gamma
31 |         self.clear_state = clear_state
32 |         self.restarts = restarts if restarts else [0]
33 |         self.restart_weights = weights if weights else [1]
34 |         assert len(self.restarts) == len(
35 |             self.restart_weights), 'restarts and their weights do not match.'
36 |         super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch)
37 | 
38 |     def get_lr(self):
39 |         if self.last_epoch in self.restarts:
40 |             if self.clear_state:
41 |                 self.optimizer.state = defaultdict(dict)
42 |             weight = self.restart_weights[self.restarts.index(self.last_epoch)]
43 |             return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
44 |         if self.last_epoch not in self.milestones and self.last_epoch not in self.lr_steps_inverse:
45 |             return [group['lr'] for group in self.optimizer.param_groups]
46 |         return [
47 |             group['lr'] * (self.gamma ** self.milestones[self.last_epoch]) *
48 |             (self.gamma ** (-self.lr_steps_inverse[self.last_epoch]))
49 |             for group in self.optimizer.param_groups
50 |         ]
51 | 
52 | 
53 | class CosineAnnealingLR_Restart(_LRScheduler):
54 |     def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1):
55 |         self.T_period = T_period
56 |         self.T_max = self.T_period[0]  # current T period
57 |         self.eta_min = eta_min
58 |         self.restarts = restarts if restarts else [0]
59 |         self.restart_weights = weights if weights else [1]
60 |         self.last_restart = 0
61 |         assert len(self.restarts) == len(
62 |             self.restart_weights), 'restarts and their weights do not match.'
63 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) 64 | 65 | def get_lr(self): 66 | if self.last_epoch == 0: 67 | return self.base_lrs 68 | elif self.last_epoch in self.restarts: 69 | self.last_restart = self.last_epoch 70 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] 71 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 72 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 73 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: 74 | return [ 75 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 76 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 77 | ] 78 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / 79 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * 80 | (group['lr'] - self.eta_min) + self.eta_min 81 | for group in self.optimizer.param_groups] 82 | 83 | 84 | if __name__ == "__main__": 85 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, 86 | betas=(0.9, 0.99)) 87 | ############################## 88 | # MultiStepLR_Restart 89 | ############################## 90 | ## Original 91 | lr_steps = [200000, 400000, 600000, 800000] 92 | restarts = None 93 | restart_weights = None 94 | 95 | ## two 96 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] 97 | restarts = [500000] 98 | restart_weights = [1] 99 | 100 | ## four 101 | lr_steps = [ 102 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, 103 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 104 | ] 105 | restarts = [250000, 500000, 750000] 106 | restart_weights = [1, 1, 1] 107 | 108 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, 109 | clear_state=False) 110 | 111 | ############################## 112 | # Cosine Annealing Restart 113 | ############################## 114 | ## two 115 | T_period = [500000, 500000] 116 | restarts = [500000] 117 | restart_weights = [1] 118 | 119 | ## four 120 | T_period = [250000, 250000, 250000, 250000] 121 | restarts = [250000, 500000, 750000] 122 | restart_weights = [1, 1, 1] 123 | 124 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, 125 | weights=restart_weights) 126 | 127 | ############################## 128 | # Draw figure 129 | ############################## 130 | N_iter = 1000000 131 | lr_l = list(range(N_iter)) 132 | for i in range(N_iter): 133 | scheduler.step() 134 | current_lr = optimizer.param_groups[0]['lr'] 135 | lr_l[i] = current_lr 136 | 137 | import matplotlib as mpl 138 | from matplotlib import pyplot as plt 139 | import matplotlib.ticker as mtick 140 | 141 | mpl.style.use('default') 142 | import seaborn 143 | 144 | seaborn.set(style='whitegrid') 145 | seaborn.set_context('paper') 146 | 147 | plt.figure(1) 148 | plt.subplot(111) 149 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) 150 | plt.title('Title', fontsize=16, color='k') 151 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') 152 | legend = plt.legend(loc='upper right', shadow=False) 153 | ax = plt.gca() 154 | labels = ax.get_xticks().tolist() 155 | for k, v in enumerate(labels): 156 | labels[k] = str(int(v / 1000)) + 'K' 157 | ax.set_xticklabels(labels) 158 | 
ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) 159 | 160 | ax.set_ylabel('Learning rate') 161 | ax.set_xlabel('Iteration') 162 | fig = plt.gcf() 163 | plt.show() 164 | -------------------------------------------------------------------------------- /code/models/modules/FlowActNorms.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import torch 18 | from torch import nn as nn 19 | 20 | from models.modules import thops 21 | 22 | 23 | class _ActNorm(nn.Module): 24 | """ 25 | Activation Normalization 26 | Initialize the bias and scale with a given minibatch, 27 | so that the output per-channel have zero mean and unit variance for that. 28 | 29 | After initialization, `bias` and `logs` will be trained as parameters. 30 | """ 31 | 32 | def __init__(self, num_features, scale=1.): 33 | super().__init__() 34 | # register mean and scale 35 | size = [1, num_features, 1, 1] 36 | self.register_parameter("bias", nn.Parameter(torch.zeros(*size))) 37 | self.register_parameter("logs", nn.Parameter(torch.zeros(*size))) 38 | self.num_features = num_features 39 | self.scale = float(scale) 40 | self.inited = False 41 | 42 | def _check_input_dim(self, input): 43 | return NotImplemented 44 | 45 | def initialize_parameters(self, input): 46 | self._check_input_dim(input) 47 | if not self.training: 48 | return 49 | if (self.bias != 0).any(): 50 | self.inited = True 51 | return 52 | assert input.device == self.bias.device, (input.device, self.bias.device) 53 | with torch.no_grad(): 54 | bias = thops.mean(input.clone(), dim=[0, 2, 3], keepdim=True) * -1.0 55 | vars = thops.mean((input.clone() + bias) ** 2, dim=[0, 2, 3], keepdim=True) 56 | logs = torch.log(self.scale / (torch.sqrt(vars) + 1e-6)) 57 | self.bias.data.copy_(bias.data) 58 | self.logs.data.copy_(logs.data) 59 | self.inited = True 60 | 61 | def _center(self, input, reverse=False, offset=None): 62 | bias = self.bias 63 | 64 | if offset is not None: 65 | bias = bias + offset 66 | 67 | if not reverse: 68 | return input + bias 69 | else: 70 | return input - bias 71 | 72 | def _scale(self, input, logdet=None, reverse=False, offset=None): 73 | logs = self.logs 74 | 75 | if offset is not None: 76 | logs = logs + offset 77 | 78 | if not reverse: 79 | input = input * torch.exp(logs) # should have shape batchsize, n_channels, 1, 1 80 | # input = input * torch.exp(logs+logs_offset) 81 | else: 82 | input = input * torch.exp(-logs) 83 | if logdet is not None: 84 | """ 85 | logs is log_std of `mean of channels` 86 | so we need to multiply pixels 87 | """ 88 | dlogdet = thops.sum(logs) * 
thops.pixels(input) 89 | if reverse: 90 | dlogdet *= -1 91 | logdet = logdet + dlogdet 92 | return input, logdet 93 | 94 | def forward(self, input, logdet=None, reverse=False, offset_mask=None, logs_offset=None, bias_offset=None): 95 | if not self.inited: 96 | self.initialize_parameters(input) 97 | self._check_input_dim(input) 98 | 99 | if offset_mask is not None: 100 | logs_offset *= offset_mask 101 | bias_offset *= offset_mask 102 | # no need to permute dims as old version 103 | if not reverse: 104 | # center and scale 105 | 106 | # self.input = input 107 | input = self._center(input, reverse, bias_offset) 108 | input, logdet = self._scale(input, logdet, reverse, logs_offset) 109 | else: 110 | # scale and center 111 | input, logdet = self._scale(input, logdet, reverse, logs_offset) 112 | input = self._center(input, reverse, bias_offset) 113 | return input, logdet 114 | 115 | 116 | class ActNorm2d(_ActNorm): 117 | def __init__(self, num_features, scale=1.): 118 | super().__init__(num_features, scale) 119 | 120 | def _check_input_dim(self, input): 121 | assert len(input.size()) == 4 122 | assert input.size(1) == self.num_features, ( 123 | "[ActNorm]: input should be in shape as `BCHW`," 124 | " channels should be {} rather than {}".format( 125 | self.num_features, input.size())) 126 | 127 | 128 | class MaskedActNorm2d(ActNorm2d): 129 | def __init__(self, num_features, scale=1.): 130 | super().__init__(num_features, scale) 131 | 132 | def forward(self, input, mask, logdet=None, reverse=False): 133 | 134 | assert mask.dtype == torch.bool 135 | output, logdet_out = super().forward(input, logdet, reverse) 136 | 137 | input[mask] = output[mask] 138 | logdet[mask] = logdet_out[mask] 139 | 140 | return input, logdet 141 | 142 | -------------------------------------------------------------------------------- /code/models/modules/FlowAffineCouplingsAblation.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
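
A sketch of ActNorm's data-dependent initialisation as implemented above: the first forward pass in training mode sets bias and logs so that the output is per-channel zero-mean and unit-variance:

import torch

act = ActNorm2d(num_features=12)                 # starts with bias = 0, logs = 0
x = torch.randn(4, 12, 16, 16) * 3.0 + 1.0       # deliberately shifted and scaled
y, logdet = act(x, logdet=torch.zeros(4))        # first call triggers initialize_parameters
print(y.mean(dim=[0, 2, 3]))                     # ~0 per channel
print(y.std(dim=[0, 2, 3]))                      # ~1 per channel
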
14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import torch 18 | from torch import nn as nn 19 | 20 | from models.modules import thops 21 | from models.modules.flow import Conv2d, Conv2dZeros 22 | from utils.util import opt_get 23 | 24 | 25 | class CondAffineSeparatedAndCond(nn.Module): 26 | def __init__(self, in_channels, opt): 27 | super().__init__() 28 | self.need_features = True 29 | self.in_channels = in_channels 30 | self.in_channels_rrdb = 320 31 | self.kernel_hidden = 1 32 | self.affine_eps = 0.0001 33 | self.n_hidden_layers = 1 34 | hidden_channels = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'hidden_channels']) 35 | self.hidden_channels = 64 if hidden_channels is None else hidden_channels 36 | 37 | self.affine_eps = opt_get(opt, ['network_G', 'flow', 'CondAffineSeparatedAndCond', 'eps'], 0.0001) 38 | 39 | self.channels_for_nn = self.in_channels // 2 40 | self.channels_for_co = self.in_channels - self.channels_for_nn 41 | 42 | if self.channels_for_nn is None: 43 | self.channels_for_nn = self.in_channels // 2 44 | 45 | self.fAffine = self.F(in_channels=self.channels_for_nn + self.in_channels_rrdb, 46 | out_channels=self.channels_for_co * 2, 47 | hidden_channels=self.hidden_channels, 48 | kernel_hidden=self.kernel_hidden, 49 | n_hidden_layers=self.n_hidden_layers) 50 | 51 | self.fFeatures = self.F(in_channels=self.in_channels_rrdb, 52 | out_channels=self.in_channels * 2, 53 | hidden_channels=self.hidden_channels, 54 | kernel_hidden=self.kernel_hidden, 55 | n_hidden_layers=self.n_hidden_layers) 56 | 57 | def forward(self, input: torch.Tensor, logdet=None, reverse=False, ft=None): 58 | if not reverse: 59 | z = input 60 | assert z.shape[1] == self.in_channels, (z.shape[1], self.in_channels) 61 | 62 | # Feature Conditional 63 | scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures) 64 | z = z + shiftFt 65 | z = z * scaleFt 66 | logdet = logdet + self.get_logdet(scaleFt) 67 | 68 | # Self Conditional 69 | z1, z2 = self.split(z) 70 | scale, shift = self.feature_extract_aff(z1, ft, self.fAffine) 71 | self.asserts(scale, shift, z1, z2) 72 | z2 = z2 + shift 73 | z2 = z2 * scale 74 | 75 | logdet = logdet + self.get_logdet(scale) 76 | z = thops.cat_feature(z1, z2) 77 | output = z 78 | else: 79 | z = input 80 | 81 | # Self Conditional 82 | z1, z2 = self.split(z) 83 | scale, shift = self.feature_extract_aff(z1, ft, self.fAffine) 84 | self.asserts(scale, shift, z1, z2) 85 | z2 = z2 / scale 86 | z2 = z2 - shift 87 | z = thops.cat_feature(z1, z2) 88 | logdet = logdet - self.get_logdet(scale) 89 | 90 | # Feature Conditional 91 | scaleFt, shiftFt = self.feature_extract(ft, self.fFeatures) 92 | z = z / scaleFt 93 | z = z - shiftFt 94 | logdet = logdet - self.get_logdet(scaleFt) 95 | 96 | output = z 97 | return output, logdet 98 | 99 | def asserts(self, scale, shift, z1, z2): 100 | assert z1.shape[1] == self.channels_for_nn, (z1.shape[1], self.channels_for_nn) 101 | assert z2.shape[1] == self.channels_for_co, (z2.shape[1], self.channels_for_co) 102 | assert scale.shape[1] == shift.shape[1], (scale.shape[1], shift.shape[1]) 103 | assert scale.shape[1] == z2.shape[1], (scale.shape[1], z1.shape[1], z2.shape[1]) 104 | 105 | def get_logdet(self, scale): 106 | return thops.sum(torch.log(scale), dim=[1, 2, 3]) 107 | 108 | def feature_extract(self, z, f): 109 | h = f(z) 110 | shift, scale = thops.split_feature(h, "cross") 111 | scale = (torch.sigmoid(scale + 2.) 
+ self.affine_eps) 112 | return scale, shift 113 | 114 | def feature_extract_aff(self, z1, ft, f): 115 | z = torch.cat([z1, ft], dim=1) 116 | h = f(z) 117 | shift, scale = thops.split_feature(h, "cross") 118 | scale = (torch.sigmoid(scale + 2.) + self.affine_eps) 119 | return scale, shift 120 | 121 | def split(self, z): 122 | z1 = z[:, :self.channels_for_nn] 123 | z2 = z[:, self.channels_for_nn:] 124 | assert z1.shape[1] + z2.shape[1] == z.shape[1], (z1.shape[1], z2.shape[1], z.shape[1]) 125 | return z1, z2 126 | 127 | def F(self, in_channels, out_channels, hidden_channels, kernel_hidden=1, n_hidden_layers=1): 128 | layers = [Conv2d(in_channels, hidden_channels), nn.ReLU(inplace=False)] 129 | 130 | for _ in range(n_hidden_layers): 131 | layers.append(Conv2d(hidden_channels, hidden_channels, kernel_size=[kernel_hidden, kernel_hidden])) 132 | layers.append(nn.ReLU(inplace=False)) 133 | layers.append(Conv2dZeros(hidden_channels, out_channels)) 134 | 135 | return nn.Sequential(*layers) 136 | -------------------------------------------------------------------------------- /code/models/modules/FlowStep.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
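
A round-trip sketch for the conditional coupling above: forward followed by reverse with the same conditioning features recovers the input, and the log-determinant contributions cancel. This assumes opt_get tolerates an empty opt dict, as its use with defaults elsewhere in this repo suggests:

import torch

layer = CondAffineSeparatedAndCond(in_channels=12, opt={})  # hidden_channels falls back to 64
z = torch.randn(2, 12, 8, 8)
ft = torch.randn(2, 320, 8, 8)                   # in_channels_rrdb is hard-coded to 320
y, logdet = layer(z, logdet=torch.zeros(2), reverse=False, ft=ft)
z_rec, logdet_rec = layer(y, logdet=logdet, reverse=True, ft=ft)
print((z - z_rec).abs().max(), logdet_rec.abs().max())   # both ~0 up to float error
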
14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import torch 18 | from torch import nn as nn 19 | 20 | import models.modules 21 | import models.modules.Permutations 22 | from models.modules import flow, thops, FlowAffineCouplingsAblation 23 | from utils.util import opt_get 24 | 25 | 26 | def getConditional(rrdbResults, position): 27 | img_ft = rrdbResults if isinstance(rrdbResults, torch.Tensor) else rrdbResults[position] 28 | return img_ft 29 | 30 | 31 | class FlowStep(nn.Module): 32 | FlowPermutation = { 33 | "reverse": lambda obj, z, logdet, rev: (obj.reverse(z, rev), logdet), 34 | "shuffle": lambda obj, z, logdet, rev: (obj.shuffle(z, rev), logdet), 35 | "invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), 36 | "squeeze_invconv": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), 37 | "resqueeze_invconv_alternating_2_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), 38 | "resqueeze_invconv_3": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), 39 | "InvertibleConv1x1GridAlign": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), 40 | "InvertibleConv1x1SubblocksShuf": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), 41 | "InvertibleConv1x1GridAlignIndepBorder": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), 42 | "InvertibleConv1x1GridAlignIndepBorder4": lambda obj, z, logdet, rev: obj.invconv(z, logdet, rev), 43 | } 44 | 45 | def __init__(self, in_channels, hidden_channels, 46 | actnorm_scale=1.0, flow_permutation="invconv", flow_coupling="additive", 47 | LU_decomposed=False, opt=None, image_injector=None, idx=None, acOpt=None, normOpt=None, in_shape=None, 48 | position=None): 49 | # check configures 50 | assert flow_permutation in FlowStep.FlowPermutation, \ 51 | "float_permutation should be in `{}`".format( 52 | FlowStep.FlowPermutation.keys()) 53 | super().__init__() 54 | self.flow_permutation = flow_permutation 55 | self.flow_coupling = flow_coupling 56 | self.image_injector = image_injector 57 | 58 | self.norm_type = normOpt['type'] if normOpt else 'ActNorm2d' 59 | self.position = normOpt['position'] if normOpt else None 60 | 61 | self.in_shape = in_shape 62 | self.position = position 63 | self.acOpt = acOpt 64 | 65 | # 1. actnorm 66 | self.actnorm = models.modules.FlowActNorms.ActNorm2d(in_channels, actnorm_scale) 67 | 68 | # 2. permute 69 | if flow_permutation == "invconv": 70 | self.invconv = models.modules.Permutations.InvertibleConv1x1( 71 | in_channels, LU_decomposed=LU_decomposed) 72 | 73 | # 3. coupling 74 | if flow_coupling == "CondAffineSeparatedAndCond": 75 | self.affine = models.modules.FlowAffineCouplingsAblation.CondAffineSeparatedAndCond(in_channels=in_channels, 76 | opt=opt) 77 | elif flow_coupling == "noCoupling": 78 | pass 79 | else: 80 | raise RuntimeError("coupling not Found:", flow_coupling) 81 | 82 | def forward(self, input, logdet=None, reverse=False, rrdbResults=None): 83 | if not reverse: 84 | return self.normal_flow(input, logdet, rrdbResults) 85 | else: 86 | return self.reverse_flow(input, logdet, rrdbResults) 87 | 88 | def normal_flow(self, z, logdet, rrdbResults=None): 89 | if self.flow_coupling == "bentIdentityPreAct": 90 | z, logdet = self.bentIdentPar(z, logdet, reverse=False) 91 | 92 | # 1. 
93 |         if self.norm_type == "ConditionalActNormImageInjector":
94 |             img_ft = getConditional(rrdbResults, self.position)
95 |             z, logdet = self.actnorm(z, img_ft=img_ft, logdet=logdet, reverse=False)
96 |         elif self.norm_type == "noNorm":
97 |             pass
98 |         else:
99 |             z, logdet = self.actnorm(z, logdet=logdet, reverse=False)
100 | 
101 |         # 2. permute
102 |         z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
103 |             self, z, logdet, False)
104 | 
105 |         need_features = self.affine_need_features()
106 | 
107 |         # 3. coupling
108 |         if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
109 |             img_ft = getConditional(rrdbResults, self.position)
110 |             z, logdet = self.affine(input=z, logdet=logdet, reverse=False, ft=img_ft)
111 |         return z, logdet
112 | 
113 |     def reverse_flow(self, z, logdet, rrdbResults=None):
114 | 
115 |         need_features = self.affine_need_features()
116 | 
117 |         # 1. coupling
118 |         if need_features or self.flow_coupling in ["condAffine", "condFtAffine", "condNormAffine"]:
119 |             img_ft = getConditional(rrdbResults, self.position)
120 |             z, logdet = self.affine(input=z, logdet=logdet, reverse=True, ft=img_ft)
121 | 
122 |         # 2. permute
123 |         z, logdet = FlowStep.FlowPermutation[self.flow_permutation](
124 |             self, z, logdet, True)
125 | 
126 |         # 3. actnorm
127 |         z, logdet = self.actnorm(z, logdet=logdet, reverse=True)
128 | 
129 |         return z, logdet
130 | 
131 |     def affine_need_features(self):
132 |         need_features = False
133 |         try:
134 |             need_features = self.affine.need_features
135 |         except AttributeError:  # no coupling module attached, or one without need_features
136 |             pass
137 |         return need_features
138 | 
-------------------------------------------------------------------------------- /code/models/modules/FlowUpsamplerNet.py: --------------------------------------------------------------------------------
1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd.
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
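#
# Architecture sketch (illustrative, assuming the default 160x160 HR patch
# size used throughout this file): for each of the L levels the net squeezes
# 2x2 neighborhoods into channels (C *= 4, H and W halve), optionally runs a
# few unconditional FlowSteps (additionalFlowNoAffine), then K RRDB-conditioned
# FlowSteps, and optionally a Split2d that factors out half of the channels.
# E.g. with L = 3 and no splits:
#     (3, 160, 160) -> squeeze -> (12, 80, 80) -> squeeze -> (48, 40, 40)
#                   -> squeeze -> (192, 20, 20)
# The levelToName tables below map each level to the matching RRDB feature
# key ('fea_up4', 'fea_up2', ...) used as conditioning.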
14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import numpy as np 18 | import torch 19 | from torch import nn as nn 20 | 21 | import models.modules.Split 22 | from models.modules import flow, thops 23 | from models.modules.Split import Split2d 24 | from models.modules.glow_arch import f_conv2d_bias 25 | from models.modules.FlowStep import FlowStep 26 | from utils.util import opt_get 27 | 28 | 29 | class FlowUpsamplerNet(nn.Module): 30 | def __init__(self, image_shape, hidden_channels, K, L=None, 31 | actnorm_scale=1.0, 32 | flow_permutation=None, 33 | flow_coupling="affine", 34 | LU_decomposed=False, opt=None): 35 | 36 | super().__init__() 37 | 38 | self.layers = nn.ModuleList() 39 | self.output_shapes = [] 40 | self.L = opt_get(opt, ['network_G', 'flow', 'L']) 41 | self.K = opt_get(opt, ['network_G', 'flow', 'K']) 42 | if isinstance(self.K, int): 43 | self.K = [K for K in [K, ] * (self.L + 1)] 44 | 45 | self.opt = opt 46 | H, W, self.C = image_shape 47 | self.check_image_shape() 48 | 49 | if opt['scale'] == 16: 50 | self.levelToName = { 51 | 0: 'fea_up16', 52 | 1: 'fea_up8', 53 | 2: 'fea_up4', 54 | 3: 'fea_up2', 55 | 4: 'fea_up1', 56 | } 57 | 58 | if opt['scale'] == 8: 59 | self.levelToName = { 60 | 0: 'fea_up8', 61 | 1: 'fea_up4', 62 | 2: 'fea_up2', 63 | 3: 'fea_up1', 64 | 4: 'fea_up0' 65 | } 66 | 67 | elif opt['scale'] == 4: 68 | self.levelToName = { 69 | 0: 'fea_up4', 70 | 1: 'fea_up2', 71 | 2: 'fea_up1', 72 | 3: 'fea_up0', 73 | 4: 'fea_up-1' 74 | } 75 | 76 | affineInCh = self.get_affineInCh(opt_get) 77 | flow_permutation = self.get_flow_permutation(flow_permutation, opt) 78 | 79 | normOpt = opt_get(opt, ['network_G', 'flow', 'norm']) 80 | 81 | conditional_channels = {} 82 | n_rrdb = self.get_n_rrdb_channels(opt, opt_get) 83 | n_bypass_channels = opt_get(opt, ['network_G', 'flow', 'levelConditional', 'n_channels']) 84 | conditional_channels[0] = n_rrdb 85 | for level in range(1, self.L + 1): 86 | # Level 1 gets conditionals from 2, 3, 4 => L - level 87 | # Level 2 gets conditionals from 3, 4 88 | # Level 3 gets conditionals from 4 89 | # Level 4 gets conditionals from None 90 | n_bypass = 0 if n_bypass_channels is None else (self.L - level) * n_bypass_channels 91 | conditional_channels[level] = n_rrdb + n_bypass 92 | 93 | # Upsampler 94 | for level in range(1, self.L + 1): 95 | # 1. Squeeze 96 | H, W = self.arch_squeeze(H, W) 97 | 98 | # 2. 
K FlowStep
99 |             self.arch_additionalFlowAffine(H, LU_decomposed, W, actnorm_scale, hidden_channels, opt)
100 |             self.arch_FlowStep(H, self.K[level], LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling,
101 |                                flow_permutation,
102 |                                hidden_channels, normOpt, opt, opt_get,
103 |                                n_conditional_channels=conditional_channels[level])
104 |             # Split
105 |             self.arch_split(H, W, level, self.L, opt, opt_get)
106 | 
107 |         if opt_get(opt, ['network_G', 'flow', 'split', 'enable']):
108 |             self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64 // 2 // 2)
109 |         else:
110 |             self.f = f_conv2d_bias(affineInCh, 2 * 3 * 64)
111 | 
112 |         self.H = H
113 |         self.W = W
114 |         self.scaleH = 160 / H
115 |         self.scaleW = 160 / W
116 | 
117 |     def get_n_rrdb_channels(self, opt, opt_get):
118 |         blocks = opt_get(opt, ['network_G', 'flow', 'stackRRDB', 'blocks'])
119 |         n_rrdb = 64 if blocks is None else (len(blocks) + 1) * 64
120 |         return n_rrdb
121 | 
122 |     def arch_FlowStep(self, H, K, LU_decomposed, W, actnorm_scale, affineInCh, flow_coupling, flow_permutation,
123 |                       hidden_channels, normOpt, opt, opt_get, n_conditional_channels=None):
124 |         condAff = self.get_condAffSetting(opt, opt_get)
125 |         if condAff is not None:
126 |             condAff['in_channels_rrdb'] = n_conditional_channels
127 | 
128 |         for k in range(K):
129 |             position_name = get_position_name(H, self.opt['scale'])
130 |             if normOpt: normOpt['position'] = position_name
131 | 
132 |             self.layers.append(
133 |                 FlowStep(in_channels=self.C,
134 |                          hidden_channels=hidden_channels,
135 |                          actnorm_scale=actnorm_scale,
136 |                          flow_permutation=flow_permutation,
137 |                          flow_coupling=flow_coupling,
138 |                          acOpt=condAff,
139 |                          position=position_name,
140 |                          LU_decomposed=LU_decomposed, opt=opt, idx=k, normOpt=normOpt))
141 |             self.output_shapes.append(
142 |                 [-1, self.C, H, W])
143 | 
144 |     def get_condAffSetting(self, opt, opt_get):
145 |         condAff = opt_get(opt, ['network_G', 'flow', 'condAff']) or None
146 |         condAff = opt_get(opt, ['network_G', 'flow', 'condFtAffine']) or condAff
147 |         return condAff
148 | 
149 |     def arch_split(self, H, W, L, levels, opt, opt_get):
150 |         correct_splits = opt_get(opt, ['network_G', 'flow', 'split', 'correct_splits'], False)
151 |         correction = 0 if correct_splits else 1
152 |         if opt_get(opt, ['network_G', 'flow', 'split', 'enable']) and L < levels - correction:
153 |             logs_eps = opt_get(opt, ['network_G', 'flow', 'split', 'logs_eps']) or 0
154 |             consume_ratio = opt_get(opt, ['network_G', 'flow', 'split', 'consume_ratio']) or 0.5
155 |             position_name = get_position_name(H, self.opt['scale'])
156 |             position = position_name if opt_get(opt, ['network_G', 'flow', 'split', 'conditional']) else None
157 |             cond_channels = opt_get(opt, ['network_G', 'flow', 'split', 'cond_channels'])
158 |             cond_channels = 0 if cond_channels is None else cond_channels
159 | 
160 |             t = opt_get(opt, ['network_G', 'flow', 'split', 'type'], 'Split2d')
161 | 
162 |             if t == 'Split2d':
163 |                 split = models.modules.Split.Split2d(num_channels=self.C, logs_eps=logs_eps, position=position,
164 |                                                      cond_channels=cond_channels, consume_ratio=consume_ratio, opt=opt)
165 |                 self.layers.append(split)
166 |                 self.output_shapes.append([-1, split.num_channels_pass, H, W])
167 |                 self.C = split.num_channels_pass
168 | 
169 |     def arch_additionalFlowAffine(self, H, LU_decomposed, W, actnorm_scale, hidden_channels, opt):
170 |         if 'additionalFlowNoAffine' in opt['network_G']['flow']:
171 |             n_additionalFlowNoAffine = int(opt['network_G']['flow']['additionalFlowNoAffine'])
172 |             for _ in range(n_additionalFlowNoAffine):
173 |                 self.layers.append(
174 |                     FlowStep(in_channels=self.C,
175 |                              hidden_channels=hidden_channels,
176 |                              actnorm_scale=actnorm_scale,
177 |                              flow_permutation='invconv',
178 |                              flow_coupling='noCoupling',
179 |                              LU_decomposed=LU_decomposed, opt=opt))
180 |                 self.output_shapes.append(
181 |                     [-1, self.C, H, W])
182 | 
183 |     def arch_squeeze(self, H, W):
184 |         self.C, H, W = self.C * 4, H // 2, W // 2
185 |         self.layers.append(flow.SqueezeLayer(factor=2))
186 |         self.output_shapes.append([-1, self.C, H, W])
187 |         return H, W
188 | 
189 |     def get_flow_permutation(self, flow_permutation, opt):
190 |         flow_permutation = opt['network_G']['flow'].get('flow_permutation', 'invconv')
191 |         return flow_permutation
192 | 
193 |     def get_affineInCh(self, opt_get):
194 |         affineInCh = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or []
195 |         affineInCh = (len(affineInCh) + 1) * 64
196 |         return affineInCh
197 | 
198 |     def check_image_shape(self):
199 |         assert self.C == 1 or self.C == 3, ("image_shape should be HWC, like (64, 64, 3); "
200 |                                             "expected self.C == 1 or self.C == 3")
201 | 
202 |     def forward(self, gt=None, rrdbResults=None, z=None, epses=None, logdet=0., reverse=False, eps_std=None,
203 |                 y_onehot=None):
204 | 
205 |         if reverse:
206 |             epses_copy = [eps for eps in epses] if isinstance(epses, list) else epses
207 | 
208 |             sr, logdet = self.decode(rrdbResults, z, eps_std, epses=epses_copy, logdet=logdet, y_onehot=y_onehot)
209 |             return sr, logdet
210 |         else:
211 |             assert gt is not None
212 |             assert rrdbResults is not None
213 |             z, logdet = self.encode(gt, rrdbResults, logdet=logdet, epses=epses, y_onehot=y_onehot)
214 | 
215 |             return z, logdet
216 | 
217 |     def encode(self, gt, rrdbResults, logdet=0.0, epses=None, y_onehot=None):
218 |         fl_fea = gt
219 |         reverse = False
220 |         level_conditionals = {}
221 |         bypasses = {}
222 | 
223 |         L = opt_get(self.opt, ['network_G', 'flow', 'L'])
224 | 
225 |         for level in range(1, L + 1):
226 |             bypasses[level] = torch.nn.functional.interpolate(gt, scale_factor=2 ** -level, mode='bilinear', align_corners=False)
227 | 
228 |         for layer, shape in zip(self.layers, self.output_shapes):
229 |             size = shape[2]
230 |             level = int(np.log(160 / size) / np.log(2))
231 | 
232 |             if level not in level_conditionals.keys():
233 |                 level_conditionals[level] = rrdbResults[self.levelToName[level]]
234 | 
235 |             # (cached once per level; every layer of a level reuses the same RRDB features)
236 | 
237 |             if isinstance(layer, FlowStep):
238 |                 fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse, rrdbResults=level_conditionals[level])
239 |             elif isinstance(layer, Split2d):
240 |                 fl_fea, logdet = self.forward_split2d(epses, fl_fea, layer, logdet, reverse, level_conditionals[level],
241 |                                                       y_onehot=y_onehot)
242 |             else:
243 |                 fl_fea, logdet = layer(fl_fea, logdet, reverse=reverse)
244 | 
245 |         z = fl_fea
246 | 
247 |         if not isinstance(epses, list):
248 |             return z, logdet
249 | 
250 |         epses.append(z)
251 |         return epses, logdet
252 | 
253 |     def forward_preFlow(self, fl_fea, logdet, reverse):
254 |         if hasattr(self, 'preFlow'):
255 |             for l in self.preFlow:
256 |                 fl_fea, logdet = l(fl_fea, logdet, reverse=reverse)
257 |         return fl_fea, logdet
258 | 
259 |     def forward_split2d(self, epses, fl_fea, layer, logdet, reverse, rrdbResults, y_onehot=None):
260 |         ft = None if layer.position is None else rrdbResults[layer.position]
261 |         fl_fea, logdet, eps = layer(fl_fea, logdet, reverse=reverse, eps=epses, ft=ft, y_onehot=y_onehot)
262 | 
263 |         if isinstance(epses, list):
264 |             epses.append(eps)
265 |         return fl_fea, logdet
266 | 
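    # decode() below inverts encode(): it walks self.layers in reverse order,
    # re-derives each layer's level from the stored output shape
    # (level = log2(160 / size)), pops the recorded eps for every Split2d,
    # and feeds the FlowSteps the same per-level RRDB conditionals.
267 | 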
def decode(self, rrdbResults, z, eps_std=None, epses=None, logdet=0.0, y_onehot=None): 268 | z = epses.pop() if isinstance(epses, list) else z 269 | 270 | fl_fea = z 271 | # debug.imwrite("fl_fea", fl_fea) 272 | bypasses = {} 273 | level_conditionals = {} 274 | if not opt_get(self.opt, ['network_G', 'flow', 'levelConditional', 'conditional']) == True: 275 | for level in range(self.L + 1): 276 | level_conditionals[level] = rrdbResults[self.levelToName[level]] 277 | 278 | for layer, shape in zip(reversed(self.layers), reversed(self.output_shapes)): 279 | size = shape[2] 280 | level = int(np.log(160 / size) / np.log(2)) 281 | # size = fl_fea.shape[2] 282 | # level = int(np.log(160 / size) / np.log(2)) 283 | 284 | if isinstance(layer, Split2d): 285 | fl_fea, logdet = self.forward_split2d_reverse(eps_std, epses, fl_fea, layer, 286 | rrdbResults[self.levelToName[level]], logdet=logdet, 287 | y_onehot=y_onehot) 288 | elif isinstance(layer, FlowStep): 289 | fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, rrdbResults=level_conditionals[level]) 290 | else: 291 | fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True) 292 | 293 | sr = fl_fea 294 | 295 | assert sr.shape[1] == 3 296 | return sr, logdet 297 | 298 | def forward_split2d_reverse(self, eps_std, epses, fl_fea, layer, rrdbResults, logdet, y_onehot=None): 299 | ft = None if layer.position is None else rrdbResults[layer.position] 300 | fl_fea, logdet = layer(fl_fea, logdet=logdet, reverse=True, 301 | eps=epses.pop() if isinstance(epses, list) else None, 302 | eps_std=eps_std, ft=ft, y_onehot=y_onehot) 303 | return fl_fea, logdet 304 | 305 | 306 | def get_position_name(H, scale): 307 | downscale_factor = 160 // H 308 | position_name = 'fea_up{}'.format(scale / downscale_factor) 309 | return position_name 310 | -------------------------------------------------------------------------------- /code/models/modules/Permutations.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
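#
# Commentary on the log-determinant below: a 1x1 convolution with a CxC
# weight W applied at every spatial position contributes log|det W| once per
# pixel, hence dlogdet = slogdet(W)[1] * pixels(input). Shape sketch
# (illustrative; `x` is an assumed [B, 12, H, W] tensor):
#
#     conv = InvertibleConv1x1(num_channels=12)
#     z, logdet = conv(x, logdet=torch.zeros(x.shape[0]), reverse=False)
#     x_rec, _ = conv(z, logdet=None, reverse=True)   # applies W^-1
#
# Note that the inverse weight is recomputed on every reverse call.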
14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import numpy as np 18 | import torch 19 | from torch import nn as nn 20 | from torch.nn import functional as F 21 | 22 | from models.modules import thops 23 | 24 | 25 | class InvertibleConv1x1(nn.Module): 26 | def __init__(self, num_channels, LU_decomposed=False): 27 | super().__init__() 28 | w_shape = [num_channels, num_channels] 29 | w_init = np.linalg.qr(np.random.randn(*w_shape))[0].astype(np.float32) 30 | self.register_parameter("weight", nn.Parameter(torch.Tensor(w_init))) 31 | self.w_shape = w_shape 32 | self.LU = LU_decomposed 33 | 34 | def get_weight(self, input, reverse): 35 | w_shape = self.w_shape 36 | pixels = thops.pixels(input) 37 | dlogdet = torch.slogdet(self.weight)[1] * pixels 38 | if not reverse: 39 | weight = self.weight.view(w_shape[0], w_shape[1], 1, 1) 40 | else: 41 | weight = torch.inverse(self.weight.double()).float() \ 42 | .view(w_shape[0], w_shape[1], 1, 1) 43 | return weight, dlogdet 44 | def forward(self, input, logdet=None, reverse=False): 45 | """ 46 | log-det = log|abs(|W|)| * pixels 47 | """ 48 | weight, dlogdet = self.get_weight(input, reverse) 49 | if not reverse: 50 | z = F.conv2d(input, weight) 51 | if logdet is not None: 52 | logdet = logdet + dlogdet 53 | return z, logdet 54 | else: 55 | z = F.conv2d(input, weight) 56 | if logdet is not None: 57 | logdet = logdet - dlogdet 58 | return z, logdet 59 | -------------------------------------------------------------------------------- /code/models/modules/RRDBNet_arch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import functools 18 | import torch 19 | import torch.nn as nn 20 | import torch.nn.functional as F 21 | import models.modules.module_util as mutil 22 | from utils.util import opt_get 23 | 24 | 25 | class ResidualDenseBlock_5C(nn.Module): 26 | def __init__(self, nf=64, gc=32, bias=True): 27 | super(ResidualDenseBlock_5C, self).__init__() 28 | # gc: growth channel, i.e. 
intermediate channels 29 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) 30 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) 31 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) 32 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) 33 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) 34 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 35 | 36 | # initialization 37 | mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 38 | 39 | def forward(self, x): 40 | x1 = self.lrelu(self.conv1(x)) 41 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 42 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 43 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 44 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 45 | return x5 * 0.2 + x 46 | 47 | 48 | class RRDB(nn.Module): 49 | '''Residual in Residual Dense Block''' 50 | 51 | def __init__(self, nf, gc=32): 52 | super(RRDB, self).__init__() 53 | self.RDB1 = ResidualDenseBlock_5C(nf, gc) 54 | self.RDB2 = ResidualDenseBlock_5C(nf, gc) 55 | self.RDB3 = ResidualDenseBlock_5C(nf, gc) 56 | 57 | def forward(self, x): 58 | out = self.RDB1(x) 59 | out = self.RDB2(out) 60 | out = self.RDB3(out) 61 | return out * 0.2 + x 62 | 63 | 64 | class RRDBNet(nn.Module): 65 | def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, opt=None): 66 | self.opt = opt 67 | super(RRDBNet, self).__init__() 68 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) 69 | self.scale = scale 70 | 71 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 72 | self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb) 73 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 74 | #### upsampling 75 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 76 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 77 | if self.scale >= 8: 78 | self.upconv3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 79 | if self.scale >= 16: 80 | self.upconv4 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 81 | if self.scale >= 32: 82 | self.upconv5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 83 | 84 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 85 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 86 | 87 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True) 88 | 89 | def forward(self, x, get_steps=False): 90 | fea = self.conv_first(x) 91 | 92 | block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] 93 | block_results = {} 94 | 95 | for idx, m in enumerate(self.RRDB_trunk.children()): 96 | fea = m(fea) 97 | for b in block_idxs: 98 | if b == idx: 99 | block_results["block_{}".format(idx)] = fea 100 | 101 | trunk = self.trunk_conv(fea) 102 | 103 | last_lr_fea = fea + trunk 104 | 105 | fea_up2 = self.upconv1(F.interpolate(last_lr_fea, scale_factor=2, mode='nearest')) 106 | fea = self.lrelu(fea_up2) 107 | 108 | fea_up4 = self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')) 109 | fea = self.lrelu(fea_up4) 110 | 111 | fea_up8 = None 112 | fea_up16 = None 113 | fea_up32 = None 114 | 115 | if self.scale >= 8: 116 | fea_up8 = self.upconv3(F.interpolate(fea, scale_factor=2, mode='nearest')) 117 | fea = self.lrelu(fea_up8) 118 | if self.scale >= 16: 119 | fea_up16 = self.upconv4(F.interpolate(fea, scale_factor=2, mode='nearest')) 120 | fea = self.lrelu(fea_up16) 121 | if self.scale >= 32: 122 | fea_up32 = self.upconv5(F.interpolate(fea, scale_factor=2, mode='nearest')) 123 | fea = self.lrelu(fea_up32) 124 | 125 | out = 
self.conv_last(self.lrelu(self.HRconv(fea))) 126 | 127 | results = {'last_lr_fea': last_lr_fea, 128 | 'fea_up1': last_lr_fea, 129 | 'fea_up2': fea_up2, 130 | 'fea_up4': fea_up4, 131 | 'fea_up8': fea_up8, 132 | 'fea_up16': fea_up16, 133 | 'fea_up32': fea_up32, 134 | 'out': out} 135 | 136 | fea_up0_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up0']) or False 137 | if fea_up0_en: 138 | results['fea_up0'] = F.interpolate(last_lr_fea, scale_factor=1/2, mode='bilinear', align_corners=False, recompute_scale_factor=True) 139 | fea_upn1_en = opt_get(self.opt, ['network_G', 'flow', 'fea_up-1']) or False 140 | if fea_upn1_en: 141 | results['fea_up-1'] = F.interpolate(last_lr_fea, scale_factor=1/4, mode='bilinear', align_corners=False, recompute_scale_factor=True) 142 | 143 | if get_steps: 144 | for k, v in block_results.items(): 145 | results[k] = v 146 | return results 147 | else: 148 | return out 149 | -------------------------------------------------------------------------------- /code/models/modules/SRFlowNet_arch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
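#
# Loss commentary: SRFlowNet is trained by maximum likelihood. normal_flow()
# below returns the negative log-likelihood in bits per dimension,
#     nll = -(logdet + log p(z)) / (log(2) * pixels),
# where logdet accumulates the per-layer Jacobian terms plus a
# -log(quant) * pixels dequantization correction when uniform noise is added
# to the ground truth.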
14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import math 18 | 19 | import torch 20 | import torch.nn as nn 21 | import torch.nn.functional as F 22 | import numpy as np 23 | from models.modules.RRDBNet_arch import RRDBNet 24 | from models.modules.FlowUpsamplerNet import FlowUpsamplerNet 25 | import models.modules.thops as thops 26 | import models.modules.flow as flow 27 | from utils.util import opt_get 28 | 29 | 30 | class SRFlowNet(nn.Module): 31 | def __init__(self, in_nc, out_nc, nf, nb, gc=32, scale=4, K=None, opt=None, step=None): 32 | super(SRFlowNet, self).__init__() 33 | 34 | self.opt = opt 35 | self.quant = 255 if opt_get(opt, ['datasets', 'train', 'quant']) is \ 36 | None else opt_get(opt, ['datasets', 'train', 'quant']) 37 | self.RRDB = RRDBNet(in_nc, out_nc, nf, nb, gc, scale, opt) 38 | hidden_channels = opt_get(opt, ['network_G', 'flow', 'hidden_channels']) 39 | hidden_channels = hidden_channels or 64 40 | self.RRDB_training = True # Default is true 41 | 42 | train_RRDB_delay = opt_get(self.opt, ['network_G', 'train_RRDB_delay']) 43 | set_RRDB_to_train = False 44 | if set_RRDB_to_train: 45 | self.set_rrdb_training(True) 46 | 47 | self.flowUpsamplerNet = \ 48 | FlowUpsamplerNet((160, 160, 3), hidden_channels, K, 49 | flow_coupling=opt['network_G']['flow']['coupling'], opt=opt) 50 | self.i = 0 51 | 52 | def set_rrdb_training(self, trainable): 53 | if self.RRDB_training != trainable: 54 | for p in self.RRDB.parameters(): 55 | p.requires_grad = trainable 56 | self.RRDB_training = trainable 57 | return True 58 | return False 59 | 60 | def forward(self, gt=None, lr=None, z=None, eps_std=None, reverse=False, epses=None, reverse_with_grad=False, 61 | lr_enc=None, 62 | add_gt_noise=False, step=None, y_label=None): 63 | if not reverse: 64 | return self.normal_flow(gt, lr, epses=epses, lr_enc=lr_enc, add_gt_noise=add_gt_noise, step=step, 65 | y_onehot=y_label) 66 | else: 67 | # assert lr.shape[0] == 1 68 | assert lr.shape[1] == 3 69 | # assert lr.shape[2] == 20 70 | # assert lr.shape[3] == 20 71 | # assert z.shape[0] == 1 72 | # assert z.shape[1] == 3 * 8 * 8 73 | # assert z.shape[2] == 20 74 | # assert z.shape[3] == 20 75 | if reverse_with_grad: 76 | return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc, 77 | add_gt_noise=add_gt_noise) 78 | else: 79 | with torch.no_grad(): 80 | return self.reverse_flow(lr, z, y_onehot=y_label, eps_std=eps_std, epses=epses, lr_enc=lr_enc, 81 | add_gt_noise=add_gt_noise) 82 | 83 | def normal_flow(self, gt, lr, y_onehot=None, epses=None, lr_enc=None, add_gt_noise=True, step=None): 84 | if lr_enc is None: 85 | lr_enc = self.rrdbPreprocessing(lr) 86 | 87 | logdet = torch.zeros_like(gt[:, 0, 0, 0]) 88 | pixels = thops.pixels(gt) 89 | 90 | z = gt 91 | 92 | if add_gt_noise: 93 | # Setup 94 | noiseQuant = opt_get(self.opt, ['network_G', 'flow', 'augmentation', 'noiseQuant'], True) 95 | if noiseQuant: 96 | z = z + ((torch.rand(z.shape, device=z.device) - 0.5) / self.quant) 97 | logdet = logdet + float(-np.log(self.quant) * pixels) 98 | 99 | # Encode 100 | epses, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, gt=z, logdet=logdet, reverse=False, epses=epses, 101 | y_onehot=y_onehot) 102 | 103 | objective = logdet.clone() 104 | 105 | if isinstance(epses, (list, tuple)): 106 | z = epses[-1] 107 | else: 108 | z = epses 109 | 110 | objective = objective + flow.GaussianDiag.logp(None, None, z) 111 | 112 | nll = (-objective) / float(np.log(2.) 
* pixels) 113 | 114 | if isinstance(epses, list): 115 | return epses, nll, logdet 116 | return z, nll, logdet 117 | 118 | def rrdbPreprocessing(self, lr): 119 | rrdbResults = self.RRDB(lr, get_steps=True) 120 | block_idxs = opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'blocks']) or [] 121 | if len(block_idxs) > 0: 122 | concat = torch.cat([rrdbResults["block_{}".format(idx)] for idx in block_idxs], dim=1) 123 | 124 | if opt_get(self.opt, ['network_G', 'flow', 'stackRRDB', 'concat']) or False: 125 | keys = ['last_lr_fea', 'fea_up1', 'fea_up2', 'fea_up4'] 126 | if 'fea_up0' in rrdbResults.keys(): 127 | keys.append('fea_up0') 128 | if 'fea_up-1' in rrdbResults.keys(): 129 | keys.append('fea_up-1') 130 | if self.opt['scale'] >= 8: 131 | keys.append('fea_up8') 132 | if self.opt['scale'] == 16: 133 | keys.append('fea_up16') 134 | for k in keys: 135 | h = rrdbResults[k].shape[2] 136 | w = rrdbResults[k].shape[3] 137 | rrdbResults[k] = torch.cat([rrdbResults[k], F.interpolate(concat, (h, w))], dim=1) 138 | return rrdbResults 139 | 140 | def get_score(self, disc_loss_sigma, z): 141 | score_real = 0.5 * (1 - 1 / (disc_loss_sigma ** 2)) * thops.sum(z ** 2, dim=[1, 2, 3]) - \ 142 | z.shape[1] * z.shape[2] * z.shape[3] * math.log(disc_loss_sigma) 143 | return -score_real 144 | 145 | def reverse_flow(self, lr, z, y_onehot, eps_std, epses=None, lr_enc=None, add_gt_noise=True): 146 | logdet = torch.zeros_like(lr[:, 0, 0, 0]) 147 | pixels = thops.pixels(lr) * self.opt['scale'] ** 2 148 | 149 | if add_gt_noise: 150 | logdet = logdet - float(-np.log(self.quant) * pixels) 151 | 152 | if lr_enc is None: 153 | lr_enc = self.rrdbPreprocessing(lr) 154 | 155 | x, logdet = self.flowUpsamplerNet(rrdbResults=lr_enc, z=z, eps_std=eps_std, reverse=True, epses=epses, 156 | logdet=logdet) 157 | 158 | return x, logdet 159 | -------------------------------------------------------------------------------- /code/models/modules/Split.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
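#
# Commentary: Split2d factors the incoming channels into a pass-through part
# z1 and a consumed part z2. z2 is modelled by a diagonal Gaussian whose
# mean/logs are predicted from z1 (plus optional conditioning features ft);
# the forward pass stores the normalized residual
#     eps = (z2 - mean) / (exp(logs) + logs_eps)
# and adds log p(z2 | mean, logs) to the log-determinant, while the reverse
# pass re-synthesizes z2 = mean + (exp(logs) + logs_eps) * eps.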
14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import torch 18 | from torch import nn as nn 19 | 20 | from models.modules import thops 21 | from models.modules.FlowStep import FlowStep 22 | from models.modules.flow import Conv2dZeros, GaussianDiag 23 | from utils.util import opt_get 24 | 25 | 26 | class Split2d(nn.Module): 27 | def __init__(self, num_channels, logs_eps=0, cond_channels=0, position=None, consume_ratio=0.5, opt=None): 28 | super().__init__() 29 | 30 | self.num_channels_consume = int(round(num_channels * consume_ratio)) 31 | self.num_channels_pass = num_channels - self.num_channels_consume 32 | 33 | self.conv = Conv2dZeros(in_channels=self.num_channels_pass + cond_channels, 34 | out_channels=self.num_channels_consume * 2) 35 | self.logs_eps = logs_eps 36 | self.position = position 37 | self.opt = opt 38 | 39 | def split2d_prior(self, z, ft): 40 | if ft is not None: 41 | z = torch.cat([z, ft], dim=1) 42 | h = self.conv(z) 43 | return thops.split_feature(h, "cross") 44 | 45 | def exp_eps(self, logs): 46 | return torch.exp(logs) + self.logs_eps 47 | 48 | def forward(self, input, logdet=0., reverse=False, eps_std=None, eps=None, ft=None, y_onehot=None): 49 | if not reverse: 50 | # self.input = input 51 | z1, z2 = self.split_ratio(input) 52 | mean, logs = self.split2d_prior(z1, ft) 53 | 54 | eps = (z2 - mean) / self.exp_eps(logs) 55 | 56 | logdet = logdet + self.get_logdet(logs, mean, z2) 57 | 58 | # print(logs.shape, mean.shape, z2.shape) 59 | # self.eps = eps 60 | # print('split, enc eps:', eps) 61 | return z1, logdet, eps 62 | else: 63 | z1 = input 64 | mean, logs = self.split2d_prior(z1, ft) 65 | 66 | if eps is None: 67 | #print("WARNING: eps is None, generating eps untested functionality!") 68 | eps = GaussianDiag.sample_eps(mean.shape, eps_std) 69 | 70 | eps = eps.to(mean.device) 71 | z2 = mean + self.exp_eps(logs) * eps 72 | 73 | z = thops.cat_feature(z1, z2) 74 | logdet = logdet - self.get_logdet(logs, mean, z2) 75 | 76 | return z, logdet 77 | # return z, logdet, eps 78 | 79 | def get_logdet(self, logs, mean, z2): 80 | logdet_diff = GaussianDiag.logp(mean, logs, z2) 81 | # print("Split2D: logdet diff", logdet_diff.item()) 82 | return logdet_diff 83 | 84 | def split_ratio(self, input): 85 | z1, z2 = input[:, :self.num_channels_pass, ...], input[:, self.num_channels_pass:, ...] 86 | return z1, z2 -------------------------------------------------------------------------------- /code/models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreas128/SRFlow/5a007ad591c7be8bf32cf23171bfa4473e71683c/code/models/modules/__init__.py -------------------------------------------------------------------------------- /code/models/modules/flow.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE
16 | 
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | import numpy as np
21 | 
22 | from models.modules.FlowActNorms import ActNorm2d
23 | from . import thops
24 | 
25 | 
26 | class Conv2d(nn.Conv2d):
27 |     pad_dict = {
28 |         "same": lambda kernel, stride: [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)],
29 |         "valid": lambda kernel, stride: [0 for _ in kernel]
30 |     }
31 | 
32 |     @staticmethod
33 |     def get_padding(padding, kernel_size, stride):
34 |         # make padding
35 |         if isinstance(padding, str):
36 |             if isinstance(kernel_size, int):
37 |                 kernel_size = [kernel_size, kernel_size]
38 |             if isinstance(stride, int):
39 |                 stride = [stride, stride]
40 |             padding = padding.lower()
41 |             try:
42 |                 padding = Conv2d.pad_dict[padding](kernel_size, stride)
43 |             except KeyError:
44 |                 raise ValueError("{} is not supported".format(padding))
45 |         return padding
46 | 
47 |     def __init__(self, in_channels, out_channels,
48 |                  kernel_size=[3, 3], stride=[1, 1],
49 |                  padding="same", do_actnorm=True, weight_std=0.05):
50 |         padding = Conv2d.get_padding(padding, kernel_size, stride)
51 |         super().__init__(in_channels, out_channels, kernel_size, stride,
52 |                          padding, bias=(not do_actnorm))
53 |         # init weight with std
54 |         self.weight.data.normal_(mean=0.0, std=weight_std)
55 |         if not do_actnorm:
56 |             self.bias.data.zero_()
57 |         else:
58 |             self.actnorm = ActNorm2d(out_channels)
59 |         self.do_actnorm = do_actnorm
60 | 
61 |     def forward(self, input):
62 |         x = super().forward(input)
63 |         if self.do_actnorm:
64 |             x, _ = self.actnorm(x)
65 |         return x
66 | 
67 | 
68 | class Conv2dZeros(nn.Conv2d):
69 |     def __init__(self, in_channels, out_channels,
70 |                  kernel_size=[3, 3], stride=[1, 1],
71 |                  padding="same", logscale_factor=3):
72 |         padding = Conv2d.get_padding(padding, kernel_size, stride)
73 |         super().__init__(in_channels, out_channels, kernel_size, stride, padding)
74 |         # logscale_factor
75 |         self.logscale_factor = logscale_factor
76 |         self.register_parameter("logs", nn.Parameter(torch.zeros(out_channels, 1, 1)))
77 |         # init
78 |         self.weight.data.zero_()
79 |         self.bias.data.zero_()
80 | 
81 |     def forward(self, input):
82 |         output = super().forward(input)
83 |         return output * torch.exp(self.logs * self.logscale_factor)
84 | 
85 | 
86 | class GaussianDiag:
87 |     Log2PI = float(np.log(2 * np.pi))
88 | 
89 |     @staticmethod
90 |     def likelihood(mean, logs, x):
91 |         """
92 |         lnL = -1/2 * { ln|Var| + ((X - Mu)^T)(Var^-1)(X - Mu) + kln(2*PI) }
93 |         k = 1 (Independent)
94 |         Var = exp(logs * 2)
95 |         """
96 |         if mean is None and logs is None:
97 |             return -0.5 * (x ** 2 + GaussianDiag.Log2PI)
98 |         else:
99 |             return -0.5 * (logs * 2. + ((x - mean) ** 2) / torch.exp(logs * 2.)
+ GaussianDiag.Log2PI) 100 | 101 | @staticmethod 102 | def logp(mean, logs, x): 103 | likelihood = GaussianDiag.likelihood(mean, logs, x) 104 | return thops.sum(likelihood, dim=[1, 2, 3]) 105 | 106 | @staticmethod 107 | def sample(mean, logs, eps_std=None): 108 | eps_std = eps_std or 1 109 | eps = torch.normal(mean=torch.zeros_like(mean), 110 | std=torch.ones_like(logs) * eps_std) 111 | return mean + torch.exp(logs) * eps 112 | 113 | @staticmethod 114 | def sample_eps(shape, eps_std, seed=None): 115 | if seed is not None: 116 | torch.manual_seed(seed) 117 | eps = torch.normal(mean=torch.zeros(shape), 118 | std=torch.ones(shape) * eps_std) 119 | return eps 120 | 121 | 122 | def squeeze2d(input, factor=2): 123 | assert factor >= 1 and isinstance(factor, int) 124 | if factor == 1: 125 | return input 126 | size = input.size() 127 | B = size[0] 128 | C = size[1] 129 | H = size[2] 130 | W = size[3] 131 | assert H % factor == 0 and W % factor == 0, "{}".format((H, W, factor)) 132 | x = input.view(B, C, H // factor, factor, W // factor, factor) 133 | x = x.permute(0, 1, 3, 5, 2, 4).contiguous() 134 | x = x.view(B, C * factor * factor, H // factor, W // factor) 135 | return x 136 | 137 | 138 | def unsqueeze2d(input, factor=2): 139 | assert factor >= 1 and isinstance(factor, int) 140 | factor2 = factor ** 2 141 | if factor == 1: 142 | return input 143 | size = input.size() 144 | B = size[0] 145 | C = size[1] 146 | H = size[2] 147 | W = size[3] 148 | assert C % (factor2) == 0, "{}".format(C) 149 | x = input.view(B, C // factor2, factor, factor, H, W) 150 | x = x.permute(0, 1, 4, 2, 5, 3).contiguous() 151 | x = x.view(B, C // (factor2), H * factor, W * factor) 152 | return x 153 | 154 | 155 | class SqueezeLayer(nn.Module): 156 | def __init__(self, factor): 157 | super().__init__() 158 | self.factor = factor 159 | 160 | def forward(self, input, logdet=None, reverse=False): 161 | if not reverse: 162 | output = squeeze2d(input, self.factor) # Squeeze in forward 163 | return output, logdet 164 | else: 165 | output = unsqueeze2d(input, self.factor) 166 | return output, logdet 167 | -------------------------------------------------------------------------------- /code/models/modules/glow_arch.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
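#
# (Commentary on flow.py above.) squeeze2d and unsqueeze2d trade spatial
# resolution for channels and are exact inverses of each other; SqueezeLayer
# wraps them with the (logdet-neutral) flow interface. Shape sketch
# (illustrative):
#
#     x = torch.randn(1, 3, 8, 8)
#     y = squeeze2d(x, factor=2)                # -> shape (1, 12, 4, 4)
#     assert torch.equal(unsqueeze2d(y, 2), x)  # lossless round trip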
14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import torch.nn as nn 18 | 19 | 20 | def f_conv2d_bias(in_channels, out_channels): 21 | def padding_same(kernel, stride): 22 | return [((k - 1) * s + 1) // 2 for k, s in zip(kernel, stride)] 23 | 24 | padding = padding_same([3, 3], [1, 1]) 25 | assert padding == [1, 1], padding 26 | return nn.Sequential( 27 | nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=[3, 3], stride=1, padding=1, 28 | bias=True)) 29 | -------------------------------------------------------------------------------- /code/models/modules/loss.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class CharbonnierLoss(nn.Module): 22 | """Charbonnier Loss (L1)""" 23 | 24 | def __init__(self, eps=1e-6): 25 | super(CharbonnierLoss, self).__init__() 26 | self.eps = eps 27 | 28 | def forward(self, x, y): 29 | diff = x - y 30 | loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 31 | return loss 32 | 33 | 34 | # Define GAN loss: [vanilla | lsgan | wgan-gp] 35 | class GANLoss(nn.Module): 36 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 37 | super(GANLoss, self).__init__() 38 | self.gan_type = gan_type.lower() 39 | self.real_label_val = real_label_val 40 | self.fake_label_val = fake_label_val 41 | 42 | if self.gan_type == 'gan' or self.gan_type == 'ragan': 43 | self.loss = nn.BCEWithLogitsLoss() 44 | elif self.gan_type == 'lsgan': 45 | self.loss = nn.MSELoss() 46 | elif self.gan_type == 'wgan-gp': 47 | 48 | def wgan_loss(input, target): 49 | # target is boolean 50 | return -1 * input.mean() if target else input.mean() 51 | 52 | self.loss = wgan_loss 53 | else: 54 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) 55 | 56 | def get_target_label(self, input, target_is_real): 57 | if self.gan_type == 'wgan-gp': 58 | return target_is_real 59 | if target_is_real: 60 | return torch.empty_like(input).fill_(self.real_label_val) 61 | else: 62 | return torch.empty_like(input).fill_(self.fake_label_val) 63 | 64 | def forward(self, input, target_is_real): 65 | target_label = self.get_target_label(input, target_is_real) 66 | loss = self.loss(input, target_label) 67 | return loss 68 | 69 | 70 | class GradientPenaltyLoss(nn.Module): 71 | def __init__(self, device=torch.device('cpu')): 72 | super(GradientPenaltyLoss, self).__init__() 73 | self.register_buffer('grad_outputs', torch.Tensor()) 74 | self.grad_outputs = 
self.grad_outputs.to(device) 75 | 76 | def get_grad_outputs(self, input): 77 | if self.grad_outputs.size() != input.size(): 78 | self.grad_outputs.resize_(input.size()).fill_(1.0) 79 | return self.grad_outputs 80 | 81 | def forward(self, interp, interp_crit): 82 | grad_outputs = self.get_grad_outputs(interp_crit) 83 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, 84 | grad_outputs=grad_outputs, create_graph=True, 85 | retain_graph=True, only_inputs=True)[0] 86 | grad_interp = grad_interp.view(grad_interp.size(0), -1) 87 | grad_interp_norm = grad_interp.norm(2, dim=1) 88 | 89 | loss = ((grad_interp_norm - 1)**2).mean() 90 | return loss 91 | -------------------------------------------------------------------------------- /code/models/modules/module_util.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.init as init 20 | import torch.nn.functional as F 21 | 22 | 23 | def initialize_weights(net_l, scale=1): 24 | if not isinstance(net_l, list): 25 | net_l = [net_l] 26 | for net in net_l: 27 | for m in net.modules(): 28 | if isinstance(m, nn.Conv2d): 29 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 30 | m.weight.data *= scale # for residual block 31 | if m.bias is not None: 32 | m.bias.data.zero_() 33 | elif isinstance(m, nn.Linear): 34 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 35 | m.weight.data *= scale 36 | if m.bias is not None: 37 | m.bias.data.zero_() 38 | elif isinstance(m, nn.BatchNorm2d): 39 | init.constant_(m.weight, 1) 40 | init.constant_(m.bias.data, 0.0) 41 | 42 | 43 | def make_layer(block, n_layers): 44 | layers = [] 45 | for _ in range(n_layers): 46 | layers.append(block()) 47 | return nn.Sequential(*layers) 48 | 49 | 50 | class ResidualBlock_noBN(nn.Module): 51 | '''Residual block w/o BN 52 | ---Conv-ReLU-Conv-+- 53 | |________________| 54 | ''' 55 | 56 | def __init__(self, nf=64): 57 | super(ResidualBlock_noBN, self).__init__() 58 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 59 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 60 | 61 | # initialization 62 | initialize_weights([self.conv1, self.conv2], 0.1) 63 | 64 | def forward(self, x): 65 | identity = x 66 | out = F.relu(self.conv1(x), inplace=True) 67 | out = self.conv2(out) 68 | return identity + out 69 | 70 | 71 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 72 | """Warp an image or feature map with optical flow 73 | Args: 74 | x (Tensor): size (N, C, H, W) 75 | flow (Tensor): size (N, H, W, 2), normal value 76 | 
interp_mode (str): 'nearest' or 'bilinear' 77 | padding_mode (str): 'zeros' or 'border' or 'reflection' 78 | 79 | Returns: 80 | Tensor: warped image or feature map 81 | """ 82 | assert x.size()[-2:] == flow.size()[1:3] 83 | B, C, H, W = x.size() 84 | # mesh grid 85 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 86 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 87 | grid.requires_grad = False 88 | grid = grid.type_as(x) 89 | vgrid = grid + flow 90 | # scale grid to [-1,1] 91 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 92 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 93 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 94 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) 95 | return output 96 | -------------------------------------------------------------------------------- /code/models/modules/thops.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/chaiyujin/glow-pytorch/blob/master/LICENSE 16 | 17 | import torch 18 | 19 | 20 | def sum(tensor, dim=None, keepdim=False): 21 | if dim is None: 22 | # sum up all dim 23 | return torch.sum(tensor) 24 | else: 25 | if isinstance(dim, int): 26 | dim = [dim] 27 | dim = sorted(dim) 28 | for d in dim: 29 | tensor = tensor.sum(dim=d, keepdim=True) 30 | if not keepdim: 31 | for i, d in enumerate(dim): 32 | tensor.squeeze_(d-i) 33 | return tensor 34 | 35 | 36 | def mean(tensor, dim=None, keepdim=False): 37 | if dim is None: 38 | # mean all dim 39 | return torch.mean(tensor) 40 | else: 41 | if isinstance(dim, int): 42 | dim = [dim] 43 | dim = sorted(dim) 44 | for d in dim: 45 | tensor = tensor.mean(dim=d, keepdim=True) 46 | if not keepdim: 47 | for i, d in enumerate(dim): 48 | tensor.squeeze_(d-i) 49 | return tensor 50 | 51 | 52 | def split_feature(tensor, type="split"): 53 | """ 54 | type = ["split", "cross"] 55 | """ 56 | C = tensor.size(1) 57 | if type == "split": 58 | return tensor[:, :C // 2, ...], tensor[:, C // 2:, ...] 59 | elif type == "cross": 60 | return tensor[:, 0::2, ...], tensor[:, 1::2, ...] 61 | 62 | 63 | def cat_feature(tensor_a, tensor_b): 64 | return torch.cat((tensor_a, tensor_b), dim=1) 65 | 66 | 67 | def pixels(tensor): 68 | return int(tensor.size(2) * tensor.size(3)) -------------------------------------------------------------------------------- /code/models/networks.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode
7 | #
8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd.
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 | #
15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE
16 | 
17 | import importlib
18 | 
19 | import torch
20 | import logging
21 | import models.modules.RRDBNet_arch as RRDBNet_arch
22 | 
23 | logger = logging.getLogger('base')
24 | 
25 | 
26 | def find_model_using_name(model_name):
27 |     model_filename = "models.modules." + model_name + "_arch"
28 |     modellib = importlib.import_module(model_filename)
29 | 
30 |     model = None
31 |     target_model_name = model_name.replace('_Net', '')
32 |     for name, cls in modellib.__dict__.items():
33 |         if name.lower() == target_model_name.lower():
34 |             model = cls
35 | 
36 |     if model is None:
37 |         print(
38 |             "In %s.py, there should be a subclass of torch.nn.Module with class name that matches %s." % (
39 |                 model_filename, target_model_name))
40 |         exit(1)  # exit with a non-zero status: no matching model class was found
41 | 
42 |     return model
43 | 
44 | 
45 | ####################
46 | # define network
47 | ####################
48 | #### Generator
49 | def define_G(opt):
50 |     opt_net = opt['network_G']
51 |     which_model = opt_net['which_model_G']
52 | 
53 |     if which_model == 'RRDBNet':
54 |         netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
55 |                                     nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], opt=opt)
56 |     elif which_model == 'EDSRNet':
57 |         Arch = find_model_using_name(which_model)
58 |         netG = Arch(scale=opt['scale'])
59 |     elif which_model == 'rankSRGAN':
60 |         Arch = find_model_using_name(which_model)
61 |         netG = Arch(upscale=opt['scale'])
62 |     # elif which_model == 'sft_arch':  # SFT-GAN
63 |     #     netG = sft_arch.SFT_Net()
64 |     else:
65 |         raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model))
66 |     return netG
67 | 
68 | 
69 | def define_Flow(opt, step):
70 |     opt_net = opt['network_G']
71 |     which_model = opt_net['which_model_G']
72 | 
73 |     Arch = find_model_using_name(which_model)
74 |     netG = Arch(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'],
75 |                 nf=opt_net['nf'], nb=opt_net['nb'], scale=opt['scale'], K=opt_net['flow']['K'], opt=opt, step=step)
76 | 
77 |     return netG
78 | 
79 | 
80 | #### Discriminator
81 | def define_D(opt):
82 |     opt_net = opt['network_D']
83 |     which_model = opt_net['which_model_D']
84 | 
85 |     if which_model == 'discriminator_vgg_128':
86 |         hidden_units = opt_net.get('hidden_units', 8192)
87 |         netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf'], hidden_units=hidden_units)  # NOTE: SRGAN_arch is not bundled in this repository; it must be provided for define_D/define_F to work
88 |     else:
89 |         raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model))
90 |     return netD
91 | 
92 | 
93 | #### Define Network used for Perceptual Loss
94 | def define_F(opt, use_bn=False):
95 |     gpu_ids = opt.get('gpu_ids', None)
96 |     device = torch.device('cuda' if gpu_ids else 'cpu')
97 |     # PyTorch
pretrained_models VGG19-54, before ReLU. 98 | if use_bn: 99 | feature_layer = 49 100 | else: 101 | feature_layer = 34 102 | netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, 103 | use_input_norm=True, device=device) 104 | netF.eval() # No need to train 105 | return netF 106 | -------------------------------------------------------------------------------- /code/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreas128/SRFlow/5a007ad591c7be8bf32cf23171bfa4473e71683c/code/options/__init__.py -------------------------------------------------------------------------------- /code/options/options.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | import os 18 | import os.path as osp 19 | import logging 20 | import yaml 21 | from utils.util import OrderedYaml 22 | 23 | Loader, Dumper = OrderedYaml() 24 | 25 | 26 | def parse(opt_path, is_train=True): 27 | with open(opt_path, mode='r') as f: 28 | opt = yaml.load(f, Loader=Loader) 29 | # export CUDA_VISIBLE_DEVICES 30 | gpu_list = ','.join(str(x) for x in opt.get('gpu_ids', [])) 31 | # os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 32 | # print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 33 | opt['is_train'] = is_train 34 | if opt['distortion'] == 'sr': 35 | scale = opt['scale'] 36 | 37 | # datasets 38 | for phase, dataset in opt['datasets'].items(): 39 | phase = phase.split('_')[0] 40 | dataset['phase'] = phase 41 | if opt['distortion'] == 'sr': 42 | dataset['scale'] = scale 43 | is_lmdb = False 44 | if dataset.get('dataroot_GT', None) is not None: 45 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) 46 | if dataset['dataroot_GT'].endswith('lmdb'): 47 | is_lmdb = True 48 | if dataset.get('dataroot_LQ', None) is not None: 49 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) 50 | if dataset['dataroot_LQ'].endswith('lmdb'): 51 | is_lmdb = True 52 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img' 53 | if dataset['mode'].endswith('mc'): # for memcached 54 | dataset['data_type'] = 'mc' 55 | dataset['mode'] = dataset['mode'].replace('_mc', '') 56 | 57 | # path 58 | for key, path in opt['path'].items(): 59 | if path and key in opt['path'] and key != 'strict_load': 60 | opt['path'][key] = osp.expanduser(path) 61 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 62 | if is_train: 63 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) 64 | opt['path']['experiments_root'] = 
experiments_root 65 | opt['path']['models'] = osp.join(experiments_root, 'models') 66 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state') 67 | opt['path']['log'] = experiments_root 68 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images') 69 | 70 | # change some options for debug mode 71 | if 'debug' in opt['name']: 72 | opt['train']['val_freq'] = 8 73 | opt['logger']['print_freq'] = 1 74 | opt['logger']['save_checkpoint_freq'] = 8 75 | else: # test 76 | if not opt['path'].get('results_root', None): 77 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 78 | opt['path']['results_root'] = results_root 79 | opt['path']['log'] = opt['path']['results_root'] 80 | 81 | # network 82 | if opt['distortion'] == 'sr': 83 | opt['network_G']['scale'] = scale 84 | 85 | # relative learning rate 86 | if 'train' in opt: 87 | niter = opt['train']['niter'] 88 | if 'T_period_rel' in opt['train']: 89 | opt['train']['T_period'] = [int(x * niter) for x in opt['train']['T_period_rel']] 90 | if 'restarts_rel' in opt['train']: 91 | opt['train']['restarts'] = [int(x * niter) for x in opt['train']['restarts_rel']] 92 | if 'lr_steps_rel' in opt['train']: 93 | opt['train']['lr_steps'] = [int(x * niter) for x in opt['train']['lr_steps_rel']] 94 | if 'lr_steps_inverse_rel' in opt['train']: 95 | opt['train']['lr_steps_inverse'] = [int(x * niter) for x in opt['train']['lr_steps_inverse_rel']] 96 | print(opt['train']) 97 | 98 | return opt 99 | 100 | 101 | def dict2str(opt, indent_l=1): 102 | '''dict to string for logger''' 103 | msg = '' 104 | for k, v in opt.items(): 105 | if isinstance(v, dict): 106 | msg += ' ' * (indent_l * 2) + k + ':[\n' 107 | msg += dict2str(v, indent_l + 1) 108 | msg += ' ' * (indent_l * 2) + ']\n' 109 | else: 110 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 111 | return msg 112 | 113 | 114 | class NoneDict(dict): 115 | def __missing__(self, key): 116 | return None 117 | 118 | 119 | # convert to NoneDict, which returns None for missing keys. 120 | def dict_to_nonedict(opt): 121 | if isinstance(opt, dict): 122 | new_opt = dict() 123 | for key, sub_opt in opt.items(): 124 | new_opt[key] = dict_to_nonedict(sub_opt) 125 | return NoneDict(**new_opt) 126 | elif isinstance(opt, list): 127 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 128 | else: 129 | return opt 130 | 131 | 132 | def check_resume(opt, resume_iter): 133 | '''Check resume states and pretrain_model paths''' 134 | logger = logging.getLogger('base') 135 | if opt['path']['resume_state']: 136 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( 137 | 'pretrain_model_D', None) is not None: 138 | logger.warning('pretrain_model path will be ignored when resuming training.') 139 | 140 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], 141 | '{}_G.pth'.format(resume_iter)) 142 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 143 | if 'gan' in opt['model']: 144 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], 145 | '{}_D.pth'.format(resume_iter)) 146 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) 147 | -------------------------------------------------------------------------------- /code/prepare_data.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 
2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | import glob 16 | import os 17 | import sys 18 | 19 | import numpy as np 20 | import random 21 | import imageio 22 | import pickle 23 | 24 | from natsort import natsort 25 | from tqdm import tqdm 26 | 27 | def get_img_paths(dir_path, wildcard='*.png'): 28 | return natsort.natsorted(glob.glob(dir_path + '/' + wildcard)) 29 | 30 | def create_all_dirs(path): 31 | if "." in path.split("/")[-1]: 32 | dirs = os.path.dirname(path) 33 | else: 34 | dirs = path 35 | os.makedirs(dirs, exist_ok=True) 36 | 37 | def to_pklv4(obj, path, verbose=False): 38 | create_all_dirs(path) 39 | with open(path, 'wb') as f: 40 | pickle.dump(obj, f, protocol=4) 41 | if verbose: 42 | print("Wrote {}".format(path)) 43 | 44 | 45 | from imresize import imresize 46 | 47 | def random_crop(img, size): 48 | h, w, c = img.shape 49 | 50 | h_start = np.random.randint(0, h - size) 51 | h_end = h_start + size 52 | 53 | w_start = np.random.randint(0, w - size) 54 | w_end = w_start + size 55 | 56 | return img[h_start:h_end, w_start:w_end] 57 | 58 | 59 | def imread(img_path): 60 | img = imageio.imread(img_path) 61 | if len(img.shape) == 2: 62 | img = np.stack([img, ] * 3, axis=2) 63 | return img 64 | 65 | 66 | def to_pklv4_1pct(obj, path, verbose): 67 | n = int(round(len(obj) * 0.01)) 68 | path = path.replace(".", "_1pct.") 69 | to_pklv4(obj[:n], path, verbose=True) 70 | 71 | 72 | def main(dir_path): 73 | hrs = [] 74 | lqs = [] 75 | 76 | img_paths = get_img_paths(dir_path) 77 | for img_path in tqdm(img_paths): 78 | img = imread(img_path) 79 | 80 | for i in range(47): 81 | crop = random_crop(img, 160) 82 | cropX4 = imresize(crop, scalar_scale=0.25) 83 | hrs.append(crop) 84 | lqs.append(cropX4) 85 | 86 | shuffle_combined(hrs, lqs) 87 | 88 | hrs_path = get_hrs_path(dir_path) 89 | to_pklv4(hrs, hrs_path, verbose=True) 90 | to_pklv4_1pct(hrs, hrs_path, verbose=True) 91 | 92 | lqs_path = get_lqs_path(dir_path) 93 | to_pklv4(lqs, lqs_path, verbose=True) 94 | to_pklv4_1pct(lqs, lqs_path, verbose=True) 95 | 96 | 97 | def get_hrs_path(dir_path): 98 | base_dir = os.path.dirname(dir_path) 99 | name = os.path.basename(dir_path) 100 | hrs_path = os.path.join(base_dir, 'pkls', name + '.pklv4') 101 | return hrs_path 102 | 103 | 104 | def get_lqs_path(dir_path): 105 | base_dir = os.path.dirname(dir_path) 106 | name = os.path.basename(dir_path) 107 | lqs_path = os.path.join(base_dir, 'pkls', name + '_X4.pklv4') 108 | return lqs_path 109 | 110 | 111 | def shuffle_combined(hrs, lqs): 112 | combined = list(zip(hrs, lqs)) 113 | random.shuffle(combined) 114 | hrs[:], lqs[:] = zip(*combined) 115 | 116 | 117 | if __name__ == "__main__": 118 | dir_path = sys.argv[1] 119 | assert os.path.isdir(dir_path) 120 | main(dir_path) 121 | 
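For reference, the .pklv4 archives written above are plain pickle files (protocol 4, as used in to_pklv4), so they can be inspected with the standard library alone. A minimal sketch; the path under pkls/ is hypothetical and depends on the directory passed to main():

import pickle

# hrs is a shuffled list of 160x160 HWC uint8 crops; the matching 40x40 LR
# crops are stored in the same order in the sibling *_X4.pklv4 file.
with open('datasets/pkls/DIV2K-train.pklv4', 'rb') as f:  # hypothetical path
    hrs = pickle.load(f)
print(len(hrs), hrs[0].shape)  # 47 crops per source image, each (160, 160, 3)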
-------------------------------------------------------------------------------- /code/test.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | 18 | import glob 19 | import sys 20 | from collections import OrderedDict 21 | 22 | from natsort import natsort 23 | 24 | import options.options as option 25 | from Measure import Measure, psnr 26 | from imresize import imresize 27 | from models import create_model 28 | import torch 29 | from utils.util import opt_get 30 | import numpy as np 31 | import pandas as pd 32 | import os 33 | import cv2 34 | 35 | 36 | def fiFindByWildcard(wildcard): 37 | return natsort.natsorted(glob.glob(wildcard, recursive=True)) 38 | 39 | 40 | def load_model(conf_path): 41 | opt = option.parse(conf_path, is_train=False) 42 | opt['gpu_ids'] = None 43 | opt = option.dict_to_nonedict(opt) 44 | model = create_model(opt) 45 | 46 | model_path = opt_get(opt, ['model_path'], None) 47 | model.load_network(load_path=model_path, network=model.netG) 48 | return model, opt 49 | 50 | 51 | def predict(model, lr): 52 | model.feed_data({"LQ": t(lr)}, need_GT=False) 53 | model.test() 54 | visuals = model.get_current_visuals(need_GT=False) 55 | return visuals.get('rlt', visuals.get("SR")) 56 | 57 | 58 | def t(array): return torch.Tensor(np.expand_dims(array.transpose([2, 0, 1]), axis=0).astype(np.float32)) / 255 59 | 60 | 61 | def rgb(t): return ( 62 | np.clip((t[0] if len(t.shape) == 4 else t).detach().cpu().numpy().transpose([1, 2, 0]), 0, 1) * 255).astype( 63 | np.uint8) 64 | 65 | 66 | def imread(path): 67 | return cv2.imread(path)[:, :, [2, 1, 0]] 68 | 69 | 70 | def imwrite(path, img): 71 | os.makedirs(os.path.dirname(path), exist_ok=True) 72 | cv2.imwrite(path, img[:, :, [2, 1, 0]]) 73 | 74 | 75 | def imCropCenter(img, size): 76 | h, w, c = img.shape 77 | 78 | h_start = max(h // 2 - size // 2, 0) 79 | h_end = min(h_start + size, h) 80 | 81 | w_start = max(w // 2 - size // 2, 0) 82 | w_end = min(w_start + size, w) 83 | 84 | return img[h_start:h_end, w_start:w_end] 85 | 86 | 87 | def impad(img, top=0, bottom=0, left=0, right=0, color=255): 88 | return np.pad(img, [(top, bottom), (left, right), (0, 0)], 'reflect') 89 | 90 | 91 | def main(): 92 | conf_path = sys.argv[1] 93 | conf = conf_path.split('/')[-1].replace('.yml', '') 94 | model, opt = load_model(conf_path) 95 | 96 | lr_dir = opt['dataroot_LR'] 97 | hr_dir = opt['dataroot_GT'] 98 | 99 | lr_paths = fiFindByWildcard(os.path.join(lr_dir, '*.png')) 100 | hr_paths = fiFindByWildcard(os.path.join(hr_dir, '*.png')) 101 | 102 | this_dir = 
os.path.dirname(os.path.realpath(__file__)) 103 | test_dir = os.path.join(this_dir, '..', 'results', conf) 104 | print(f"Out dir: {test_dir}") 105 | 106 | measure = Measure(use_gpu=False) 107 | 108 | fname = 'measure_full.csv' 109 | fname_tmp = fname + "_" 110 | path_out_measures = os.path.join(test_dir, fname_tmp) 111 | path_out_measures_final = os.path.join(test_dir, fname) 112 | 113 | if os.path.isfile(path_out_measures_final): 114 | df = pd.read_csv(path_out_measures_final) 115 | elif os.path.isfile(path_out_measures): 116 | df = pd.read_csv(path_out_measures) 117 | else: 118 | df = None 119 | 120 | scale = opt['scale'] 121 | 122 | pad_factor = 2 123 | 124 | for lr_path, hr_path, idx_test in zip(lr_paths, hr_paths, range(len(lr_paths))): 125 | 126 | lr = imread(lr_path) 127 | hr = imread(hr_path) 128 | 129 | # Pad the image so height and width are divisible by pad_factor 130 | h, w, c = lr.shape 131 | lq_orig = lr.copy() 132 | lr = impad(lr, bottom=int(np.ceil(h / pad_factor) * pad_factor - h), 133 | right=int(np.ceil(w / pad_factor) * pad_factor - w)) 134 | 135 | lr_t = t(lr) 136 | 137 | heat = opt['heat'] 138 | 139 | if df is not None and len(df[(df['heat'] == heat) & (df['name'] == idx_test)]) == 1: 140 | continue 141 | 142 | sr_t = model.get_sr(lq=lr_t, heat=heat) 143 | 144 | sr = rgb(torch.clamp(sr_t, 0, 1)) 145 | sr = sr[:h * scale, :w * scale] 146 | 147 | path_out_sr = os.path.join(test_dir, "{:0.2f}".format(heat).replace('.', ''), "{:06d}.png".format(idx_test)) 148 | imwrite(path_out_sr, sr) 149 | 150 | meas = OrderedDict(conf=conf, heat=heat, name=idx_test) 151 | meas['PSNR'], meas['SSIM'], meas['LPIPS'] = measure.measure(sr, hr) 152 | 153 | lr_reconstruct_rgb = imresize(sr, 1 / opt['scale']) 154 | meas['LRC PSNR'] = psnr(lq_orig, lr_reconstruct_rgb) 155 | 156 | str_out = format_measurements(meas) 157 | print(str_out) 158 | 159 | df = pd.DataFrame([meas]) if df is None else pd.concat([pd.DataFrame([meas]), df]) 160 | 161 | df.to_csv(path_out_measures + "_", index=False) 162 | os.rename(path_out_measures + "_", path_out_measures) 163 | 164 | df.to_csv(path_out_measures, index=False) 165 | os.rename(path_out_measures, path_out_measures_final) 166 | 167 | str_out = format_measurements(df.mean()) 168 | print(f"Results in: {path_out_measures_final}") 169 | print('Mean: ' + str_out) 170 | 171 | 172 | def format_measurements(meas): 173 | s_out = [] 174 | for k, v in meas.items(): 175 | v = f"{v:0.2f}" if isinstance(v, float) else v 176 | s_out.append(f"{k}: {v}") 177 | str_out = ", ".join(s_out) 178 | return str_out 179 | 180 | 181 | if __name__ == "__main__": 182 | main() 183 | -------------------------------------------------------------------------------- /code/train.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | import os 18 | from os.path import basename 19 | import math 20 | import argparse 21 | import random 22 | import logging 23 | import cv2 24 | 25 | import torch 26 | import torch.distributed as dist 27 | import torch.multiprocessing as mp 28 | 29 | import options.options as option 30 | from utils import util 31 | from data import create_dataloader, create_dataset 32 | from models import create_model 33 | from utils.timer import Timer, TickTock 34 | from utils.util import get_resume_paths 35 | 36 | 37 | def getEnv(name): return name in os.environ 38 | 39 | 40 | def init_dist(backend='nccl', **kwargs): 41 | ''' initialization for distributed training''' 42 | # if mp.get_start_method(allow_none=True) is None: 43 | if mp.get_start_method(allow_none=True) != 'spawn': 44 | mp.set_start_method('spawn') 45 | rank = int(os.environ['RANK']) 46 | num_gpus = torch.cuda.device_count() 47 | torch.cuda.set_device(rank % num_gpus) 48 | dist.init_process_group(backend=backend, **kwargs) 49 | 50 | 51 | def main(): 52 | #### options 53 | parser = argparse.ArgumentParser() 54 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 55 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', 56 | help='job launcher') 57 | parser.add_argument('--local_rank', type=int, default=0) 58 | args = parser.parse_args() 59 | opt = option.parse(args.opt, is_train=True) 60 | 61 | #### distributed training settings 62 | opt['dist'] = False 63 | rank = -1 64 | print('Disabled distributed training.') 65 | 66 | #### loading resume state if exists 67 | if opt['path'].get('resume_state', None): 68 | resume_state_path, _ = get_resume_paths(opt) 69 | 70 | # distributed resuming: all load into default GPU 71 | if resume_state_path is None: 72 | resume_state = None 73 | else: 74 | device_id = torch.cuda.current_device() 75 | resume_state = torch.load(resume_state_path, 76 | map_location=lambda storage, loc: storage.cuda(device_id)) 77 | option.check_resume(opt, resume_state['iter']) # check resume options 78 | else: 79 | resume_state = None 80 | 81 | #### mkdir and loggers 82 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) 83 | if resume_state is None: 84 | util.mkdir_and_rename( 85 | opt['path']['experiments_root']) # rename experiment folder if exists 86 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' 87 | and 'pretrain_model' not in key and 'resume' not in key)) 88 | 89 | # configure loggers; before this, logging will not work 90 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, 91 | screen=True, tofile=True) 92 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, 93 | screen=True, tofile=True) 94 | logger = logging.getLogger('base') 95 | logger.info(option.dict2str(opt)) 96 | 97 | # tensorboard logger 98 | if opt.get('use_tb_logger', False) and 'debug' not in opt['name']: 99 | version = float(torch.__version__[0:3]) 100 | if version >= 1.1: # PyTorch 1.1 101 | from torch.utils.tensorboard import SummaryWriter 102 | else: 103 | logger.info( 104 | 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]'.format(version)) 105 | from tensorboardX import SummaryWriter 106 | conf_name = basename(args.opt).replace(".yml", "") 107 | exp_dir = opt['path']['experiments_root'] 108 | log_dir_train = os.path.join(exp_dir, 'tb', conf_name, 'train') 109 | log_dir_valid = os.path.join(exp_dir, 'tb', conf_name, 'valid') 110 | tb_logger_train = SummaryWriter(log_dir=log_dir_train) 111 | tb_logger_valid = SummaryWriter(log_dir=log_dir_valid) 112 | else: 113 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) 114 | logger = logging.getLogger('base') 115 | 116 | # convert to NoneDict, which returns None for missing keys 117 | opt = option.dict_to_nonedict(opt) 118 | 119 | #### random seed 120 | seed = opt['train']['manual_seed'] 121 | if seed is None: 122 | seed = random.randint(1, 10000) 123 | if rank <= 0: 124 | logger.info('Random seed: {}'.format(seed)) 125 | util.set_random_seed(seed) 126 | 127 | torch.backends.cudnn.benchmark = True 128 | # torch.backends.cudnn.deterministic = True 129 | 130 | #### create train and val dataloader 131 | dataset_ratio = 200 # enlarge the size of each epoch 132 | for phase, dataset_opt in opt['datasets'].items(): 133 | if phase == 'train': 134 | train_set = create_dataset(dataset_opt) 135 | print('Dataset created') 136 | train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) 137 | total_iters = int(opt['train']['niter']) 138 | total_epochs = int(math.ceil(total_iters / train_size)) 139 | train_sampler = None 140 | train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) 141 | if rank <= 0: 142 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format( 143 | len(train_set), train_size)) 144 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format( 145 | total_epochs, total_iters)) 146 | elif phase == 'val': 147 | val_set = create_dataset(dataset_opt) 148 | val_loader = create_dataloader(val_set, dataset_opt, opt, None) 149 | if rank <= 0: 150 | logger.info('Number of val images in [{:s}]: {:d}'.format( 151 | dataset_opt['name'], len(val_set))) 152 | else: 153 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) 154 | assert train_loader is not None 155 | 156 | #### create model 157 | current_step = 0 if resume_state is None else resume_state['iter'] 158 | model = create_model(opt, current_step) 159 | 160 | #### resume training 161 | if resume_state: 162 | logger.info('Resuming training from epoch: {}, iter: {}.'.format( 163 | resume_state['epoch'], resume_state['iter'])) 164 | 165 | start_epoch = resume_state['epoch'] 166 | current_step = resume_state['iter'] 167 | model.resume_training(resume_state) # handle optimizers and schedulers 168 | else: 169 | current_step = 0 170 | start_epoch = 0 171 | 172 | #### training 173 | timer = Timer() 174 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) 175 | timerData = TickTock() 176 | 177 | for epoch in range(start_epoch, total_epochs + 1): 178 | if opt['dist']: 179 | train_sampler.set_epoch(epoch) 180 | 181 | timerData.tick() 182 | for _, train_data in enumerate(train_loader): 183 | timerData.tock() 184 | current_step += 1 185 | if current_step > total_iters: 186 | break 187 | 188 | #### training 189 | model.feed_data(train_data) 190 | 191 | #### update learning rate 192 | model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) 193 | 194 | try: 195 | nll = model.optimize_parameters(current_step) 196 | except 
RuntimeError as e: 197 | print("Skipping ERROR caught in nll = model.optimize_parameters(current_step): ") 198 | print(e) 199 | nll = None # ensure nll is defined even if the very first step raises 200 | if nll is None: 201 | nll = 0 202 | 203 | #### log 204 | def eta(t_iter): 205 | return (t_iter * (opt['train']['niter'] - current_step)) / 3600 206 | 207 | if current_step % opt['logger']['print_freq'] == 0 \ 208 | or current_step - (resume_state['iter'] if resume_state else 0) < 25: 209 | avg_time = timer.get_average_and_reset() 210 | avg_data_time = timerData.get_average_and_reset() 211 | message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}, t:{:.2e}, td:{:.2e}, eta:{:.2e}, nll:{:.3e}>'.format( 212 | epoch, current_step, model.get_current_learning_rate(), avg_time, avg_data_time, 213 | eta(avg_time), nll) 214 | print(message) 215 | timer.tick() 216 | # Reduce number of logs 217 | if current_step % 5 == 0: 218 | tb_logger_train.add_scalar('loss/nll', nll, current_step) 219 | tb_logger_train.add_scalar('lr/base', model.get_current_learning_rate(), current_step) 220 | tb_logger_train.add_scalar('time/iteration', timer.get_last_iteration(), current_step) 221 | tb_logger_train.add_scalar('time/data', timerData.get_last_iteration(), current_step) 222 | tb_logger_train.add_scalar('time/eta', eta(timer.get_last_iteration()), current_step) 223 | for k, v in model.get_current_log().items(): 224 | tb_logger_train.add_scalar(k, v, current_step) 225 | 226 | # validation 227 | if current_step % opt['train']['val_freq'] == 0 and rank <= 0: 228 | avg_psnr = 0.0 229 | idx = 0 230 | nlls = [] 231 | for val_data in val_loader: 232 | idx += 1 233 | img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] 234 | img_dir = os.path.join(opt['path']['val_images'], img_name) 235 | util.mkdir(img_dir) 236 | 237 | model.feed_data(val_data) 238 | 239 | nll = model.test() 240 | if nll is None: 241 | nll = 0 242 | nlls.append(nll) 243 | 244 | visuals = model.get_current_visuals() 245 | 246 | sr_img = None 247 | # Save SR images for reference 248 | if hasattr(model, 'heats'): 249 | for heat in model.heats: 250 | for i in range(model.n_sample): 251 | sr_img = util.tensor2img(visuals['SR', heat, i]) # uint8 252 | save_img_path = os.path.join(img_dir, 253 | '{:s}_{:09d}_h{:03d}_s{:d}.png'.format(img_name, 254 | current_step, 255 | int(heat * 100), i)) 256 | util.save_img(sr_img, save_img_path) 257 | else: 258 | sr_img = util.tensor2img(visuals['SR']) # uint8 259 | save_img_path = os.path.join(img_dir, 260 | '{:s}_{:d}.png'.format(img_name, current_step)) 261 | util.save_img(sr_img, save_img_path) 262 | assert sr_img is not None 263 | 264 | # Save LQ images for reference 265 | save_img_path_lq = os.path.join(img_dir, 266 | '{:s}_LQ.png'.format(img_name)) 267 | if not os.path.isfile(save_img_path_lq): 268 | lq_img = util.tensor2img(visuals['LQ']) # uint8 269 | util.save_img( 270 | cv2.resize(lq_img, dsize=None, fx=opt['scale'], fy=opt['scale'], 271 | interpolation=cv2.INTER_NEAREST), 272 | save_img_path_lq) 273 | 274 | # Save GT images for reference 275 | gt_img = util.tensor2img(visuals['GT']) # uint8 276 | save_img_path_gt = os.path.join(img_dir, 277 | '{:s}_GT.png'.format(img_name)) 278 | if not os.path.isfile(save_img_path_gt): 279 | util.save_img(gt_img, save_img_path_gt) 280 | 281 | # calculate PSNR 282 | crop_size = opt['scale'] 283 | gt_img = gt_img / 255. 284 | sr_img = sr_img / 255. 
285 | cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] 286 | cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] 287 | avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255) 288 | 289 | avg_psnr = avg_psnr / idx 290 | avg_nll = sum(nlls) / len(nlls) 291 | 292 | # log 293 | logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) 294 | logger_val = logging.getLogger('val') # validation logger 295 | logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format( 296 | epoch, current_step, avg_psnr)) 297 | 298 | # tensorboard logger 299 | tb_logger_valid.add_scalar('loss/psnr', avg_psnr, current_step) 300 | tb_logger_valid.add_scalar('loss/nll', avg_nll, current_step) 301 | 302 | tb_logger_train.flush() 303 | tb_logger_valid.flush() 304 | 305 | #### save models and training states 306 | if current_step % opt['logger']['save_checkpoint_freq'] == 0: 307 | if rank <= 0: 308 | logger.info('Saving models and training states.') 309 | model.save(current_step) 310 | model.save_training_state(epoch, current_step) 311 | 312 | timerData.tick() 313 | 314 | with open(os.path.join(opt['path']['root'], "TRAIN_DONE"), 'w') as f: 315 | f.write("TRAIN_DONE") 316 | 317 | if rank <= 0: 318 | logger.info('Saving the final model.') 319 | model.save('latest') 320 | logger.info('End of training.') 321 | 322 | 323 | if __name__ == '__main__': 324 | main() 325 | -------------------------------------------------------------------------------- /code/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreas128/SRFlow/5a007ad591c7be8bf32cf23171bfa4473e71683c/code/utils/__init__.py -------------------------------------------------------------------------------- /code/utils/timer.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) 2020 Huawei Technologies Co., Ltd. 2 | # Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International) (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode 7 | # 8 | # The code is released for academic research use only. For commercial use, please contact Huawei Technologies Co., Ltd. 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 
14 | # 15 | # This file contains content licensed by https://github.com/xinntao/BasicSR/blob/master/LICENSE/LICENSE 16 | 17 | import time 18 | 19 | 20 | class ScopeTimer: 21 | def __init__(self, name): 22 | self.name = name 23 | 24 | def __enter__(self): 25 | self.start = time.time() 26 | return self 27 | 28 | def __exit__(self, *args): 29 | self.end = time.time() 30 | self.interval = self.end - self.start 31 | print("{} {:.3E}".format(self.name, self.interval)) 32 | 33 | 34 | class Timer: 35 | def __init__(self): 36 | self.times = [] 37 | 38 | def tick(self): 39 | self.times.append(time.time()) 40 | 41 | def get_average_and_reset(self): 42 | if len(self.times) < 2: 43 | return -1 44 | avg = (self.times[-1] - self.times[0]) / (len(self.times) - 1) 45 | self.times = [self.times[-1]] 46 | return avg 47 | 48 | def get_last_iteration(self): 49 | if len(self.times) < 2: 50 | return 0 51 | return self.times[-1] - self.times[-2] 52 | 53 | 54 | class TickTock: 55 | def __init__(self): 56 | self.time_pairs = [] 57 | self.current_time = None 58 | 59 | def tick(self): 60 | self.current_time = time.time() 61 | 62 | def tock(self): 63 | assert self.current_time is not None, self.current_time 64 | self.time_pairs.append([self.current_time, time.time()]) 65 | self.current_time = None 66 | 67 | def get_average_and_reset(self): 68 | if len(self.time_pairs) == 0: 69 | return -1 70 | deltas = [t2 - t1 for t1, t2 in self.time_pairs] 71 | avg = sum(deltas) / len(deltas) 72 | self.time_pairs = [] 73 | return avg 74 | 75 | def get_last_iteration(self): 76 | if len(self.time_pairs) == 0: 77 | return -1 78 | return self.time_pairs[-1][1] - self.time_pairs[-1][0] 79 | -------------------------------------------------------------------------------- /code/utils/util.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import os 3 | import sys 4 | import time 5 | import math 6 | from datetime import datetime 7 | import random 8 | import logging 9 | from collections import OrderedDict 10 | 11 | import natsort 12 | import numpy as np 13 | import cv2 14 | import torch 15 | from torchvision.utils import make_grid 16 | from shutil import get_terminal_size 17 | 18 | import yaml 19 | 20 | try: 21 | from yaml import CLoader as Loader, CDumper as Dumper 22 | except ImportError: 23 | from yaml import Loader, Dumper 24 | 25 | 26 | def OrderedYaml(): 27 | '''yaml orderedDict support''' 28 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 29 | 30 | def dict_representer(dumper, data): 31 | return dumper.represent_dict(data.items()) 32 | 33 | def dict_constructor(loader, node): 34 | return OrderedDict(loader.construct_pairs(node)) 35 | 36 | Dumper.add_representer(OrderedDict, dict_representer) 37 | Loader.add_constructor(_mapping_tag, dict_constructor) 38 | return Loader, Dumper 39 | 40 | 41 | #################### 42 | # miscellaneous 43 | #################### 44 | 45 | 46 | def get_timestamp(): 47 | return datetime.now().strftime('%y%m%d-%H%M%S') 48 | 49 | 50 | def mkdir(path): 51 | if not os.path.exists(path): 52 | os.makedirs(path) 53 | 54 | 55 | def mkdirs(paths): 56 | if isinstance(paths, str): 57 | mkdir(paths) 58 | else: 59 | for path in paths: 60 | mkdir(path) 61 | 62 | 63 | def mkdir_and_rename(path): 64 | if os.path.exists(path): 65 | new_name = path + '_archived_' + get_timestamp() 66 | print('Path already exists. Rename it to [{:s}]'.format(new_name)) 67 | logger = logging.getLogger('base') 68 | logger.info('Path already exists. 
Rename it to [{:s}]'.format(new_name)) 69 | os.rename(path, new_name) 70 | os.makedirs(path) 71 | 72 | 73 | def set_random_seed(seed): 74 | random.seed(seed) 75 | np.random.seed(seed) 76 | torch.manual_seed(seed) 77 | torch.cuda.manual_seed_all(seed) 78 | 79 | 80 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): 81 | '''set up logger''' 82 | lg = logging.getLogger(logger_name) 83 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 84 | datefmt='%y-%m-%d %H:%M:%S') 85 | lg.setLevel(level) 86 | if tofile: 87 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp())) 88 | fh = logging.FileHandler(log_file, mode='w') 89 | fh.setFormatter(formatter) 90 | lg.addHandler(fh) 91 | if screen: 92 | sh = logging.StreamHandler() 93 | sh.setFormatter(formatter) 94 | lg.addHandler(sh) 95 | 96 | 97 | #################### 98 | # image convert 99 | #################### 100 | 101 | 102 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 103 | ''' 104 | Converts a torch Tensor into an image Numpy array 105 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 106 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 107 | ''' 108 | if hasattr(tensor, 'detach'): 109 | tensor = tensor.detach() 110 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 111 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] 112 | n_dim = tensor.dim() 113 | if n_dim == 4: 114 | n_img = len(tensor) 115 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 116 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 117 | elif n_dim == 3: 118 | img_np = tensor.numpy() 119 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 120 | elif n_dim == 2: 121 | img_np = tensor.numpy() 122 | else: 123 | raise TypeError( 124 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 125 | if out_type == np.uint8: 126 | img_np = (img_np * 255.0).round() 127 | # Important. Unlike matlab, numpy.uint8() WILL NOT round by default. 
128 | return img_np.astype(out_type) 129 | 130 | 131 | def save_img(img, img_path, mode='RGB'): 132 | cv2.imwrite(img_path, img) 133 | 134 | 135 | #################### 136 | # metric 137 | #################### 138 | 139 | 140 | def calculate_psnr(img1, img2): 141 | # img1 and img2 have range [0, 255] 142 | img1 = img1.astype(np.float64) 143 | img2 = img2.astype(np.float64) 144 | mse = np.mean((img1 - img2) ** 2) 145 | if mse == 0: 146 | return float('inf') 147 | return 20 * math.log10(255.0 / math.sqrt(mse)) 148 | 149 | 150 | def get_resume_paths(opt): 151 | resume_state_path = None 152 | resume_model_path = None 153 | ts = opt_get(opt, ['path', 'training_state']) 154 | if opt.get('path', {}).get('resume_state', None) == "auto" and ts is not None: 155 | wildcard = os.path.join(ts, "*") 156 | paths = natsort.natsorted(glob.glob(wildcard)) 157 | if len(paths) > 0: 158 | resume_state_path = paths[-1] 159 | resume_model_path = resume_state_path.replace('training_state', 'models').replace('.state', '_G.pth') 160 | else: 161 | resume_state_path = opt.get('path', {}).get('resume_state') 162 | return resume_state_path, resume_model_path 163 | 164 | 165 | def opt_get(opt, keys, default=None): 166 | if opt is None: 167 | return default 168 | ret = opt 169 | for k in keys: 170 | ret = ret.get(k, None) 171 | if ret is None: 172 | return default 173 | return ret 174 | 175 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | appnope==0.1.0 2 | argon2-cffi==20.1.0 3 | async-generator==1.10 4 | attrs==20.2.0 5 | backcall==0.2.0 6 | bleach==3.2.1 7 | certifi==2020.6.20 8 | cffi==1.14.3 9 | cycler==0.10.0 10 | dataclasses==0.6 11 | decorator==4.4.2 12 | defusedxml==0.6.0 13 | entrypoints==0.3 14 | environment-kernels==1.1.1 15 | future==0.18.2 16 | imageio==2.9.0 17 | importlib-metadata==2.0.0 18 | ipykernel==5.3.4 19 | ipython==7.19.0 20 | ipython-genutils==0.2.0 21 | ipywidgets==7.5.1 22 | jedi==0.17.2 23 | Jinja2==2.11.2 24 | jsonschema==3.2.0 25 | jupyter==1.0.0 26 | jupyter-client==6.1.7 27 | jupyter-console==6.2.0 28 | jupyter-core==4.6.3 29 | jupyterlab-pygments==0.1.2 30 | kiwisolver==1.3.1 31 | lpips==0.1.3 32 | MarkupSafe==1.1.1 33 | matplotlib==3.3.2 34 | mistune==0.8.4 35 | natsort==7.0.1 36 | nbclient==0.5.1 37 | nbconvert==6.0.7 38 | nbformat==5.0.8 39 | nest-asyncio==1.4.2 40 | networkx==2.5 41 | notebook==6.1.4 42 | numpy==1.19.4 43 | opencv-python==4.4.0.46 44 | packaging==20.4 45 | pandas==1.1.4 46 | pandocfilters==1.4.3 47 | parso==0.7.1 48 | pexpect==4.8.0 49 | pickleshare==0.7.5 50 | Pillow==8.0.1 51 | prometheus-client==0.8.0 52 | prompt-toolkit==3.0.8 53 | ptyprocess==0.6.0 54 | pycparser==2.20 55 | Pygments==2.7.2 56 | pyparsing==2.4.7 57 | pyrsistent==0.17.3 58 | python-dateutil==2.8.1 59 | pytz==2020.4 60 | PyWavelets==1.1.1 61 | PyYAML==5.3.1 62 | pyzmq==19.0.2 63 | qtconsole==4.7.7 64 | QtPy==1.9.0 65 | scikit-image==0.17.2 66 | scipy==1.5.3 67 | Send2Trash==1.5.0 68 | six==1.15.0 69 | terminado==0.9.1 70 | tensorboard==2.4.0 71 | testpath==0.4.4 72 | tifffile==2020.10.1 73 | torch==1.7.0 74 | torchvision==0.8.1 75 | tornado==6.1 76 | tqdm==4.51.0 77 | traitlets==5.0.5 78 | typing-extensions==3.7.4.3 79 | wcwidth==0.2.5 80 | webencodings==0.5.1 81 | widgetsnbextension==3.5.1 82 | zipp==3.4.0 83 | -------------------------------------------------------------------------------- /run_jupyter.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e # exit script if an error occurs 4 | 5 | source myenv/bin/activate # Install with ./setup.sh 6 | cd code 7 | python -m jupyter notebook demo_on_pretrained.ipynb # Start jupyter using the python from the virtual environment 8 | -------------------------------------------------------------------------------- /setup.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | set -e # exit script if an error occurs 4 | 5 | 6 | 7 | echo "" 8 | echo "########################################" 9 | echo "Setup Virtual Environment" 10 | echo "########################################" 11 | echo "" 12 | 13 | python3 -m venv myenv # Create a new virtual environment (venv) using native python3.7 venv 14 | source myenv/bin/activate # This replaces the python/pip commands with the ones from the venv 15 | which python # should output: ./myenv/bin/python 16 | 17 | pip install --upgrade pip # Update pip 18 | pip install -r requirements.txt # Install the exact same packages that we used 19 | 20 | # Alternatively you can install globally using pip 21 | # pip install jupyter torch natsort pyyaml opencv-python torchvision scikit-image tqdm lpips pandas environment_kernels 22 | 23 | 24 | 25 | echo "" 26 | echo "########################################" 27 | echo "Download models, data" 28 | echo "########################################" 29 | echo "" 30 | 31 | wget --continue http://data.vision.ee.ethz.ch/alugmayr/SRFlow/datasets.zip 32 | unzip datasets.zip 33 | rm datasets.zip 34 | 35 | wget --continue http://data.vision.ee.ethz.ch/alugmayr/SRFlow/pretrained_models.zip 36 | unzip pretrained_models.zip 37 | rm pretrained_models.zip 38 | 39 | 40 | echo "" 41 | echo "########################################" 42 | echo "Start Demo" 43 | echo "########################################" 44 | echo "" 45 | 46 | ./run_jupyter.sh 47 | --------------------------------------------------------------------------------
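With setup complete, the pretrained models can also be driven without the notebook by reusing the helpers defined in code/test.py. A minimal sketch, assuming it is run from inside code/ with the models and datasets downloaded by setup.sh; the input and output image paths are hypothetical:

from test import load_model, imread, imwrite, t, rgb

model, opt = load_model('./confs/SRFlow_DF2K_4X.yml')  # one of the shipped configs
lr = imread('../datasets/example_lr.png')  # hypothetical LR input; test.py pads sizes to a multiple of 2
# heat is the sampling temperature of the flow; test.py reads it from opt['heat']
sr_t = model.get_sr(lq=t(lr), heat=opt['heat'])
imwrite('../results/example_sr.png', rgb(sr_t))

Since SRFlow samples from a learned distribution, repeated calls to get_sr with heat > 0 yield different plausible super-resolutions of the same input.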