├── .gitignore ├── LICENSE ├── README.md ├── README.md-old ├── codes ├── __init__.py ├── data │ ├── GT_dataset.py │ ├── LRHR_seg_bg_dataset.py │ ├── LR_dataset.py │ ├── __init__.py │ ├── data_loader.py │ ├── data_sampler.py │ └── util.py ├── models │ ├── SFTGAN_ACD_model.py │ ├── SRGAN_model.py │ ├── SR_model.py │ ├── __init__.py │ ├── base_model.py │ ├── lr_scheduler.py │ ├── modules │ │ ├── RRDBNet_arch.py │ │ ├── SRResNet_arch.py │ │ ├── __init__.py │ │ ├── __pycache__ │ │ │ ├── RRDBNet_arch.cpython-36.pyc │ │ │ ├── SRResNet_arch.cpython-36.pyc │ │ │ ├── __init__.cpython-36.pyc │ │ │ ├── discriminator_vgg_arch.cpython-36.pyc │ │ │ ├── loss.cpython-36.pyc │ │ │ └── module_util.cpython-36.pyc │ │ ├── discriminator_vgg_arch.py │ │ ├── loss.py │ │ ├── module_util.py │ │ ├── seg_arch.py │ │ └── sft_arch.py │ └── networks.py ├── options │ ├── __init__.py │ ├── df2k │ │ └── test_df2k.yml │ ├── dped │ │ └── test_dped.yml │ └── options.py ├── scripts │ ├── back_projection │ │ ├── backprojection.m │ │ ├── main_bp.m │ │ └── main_reverse_filter.m │ ├── color2gray.py │ ├── create_lmdb.py │ ├── extract_subimgs_single.py │ ├── generate_mod_LR_bic.m │ ├── generate_mod_LR_bic.py │ ├── transfer_params_MSRResNet.py │ └── transfer_params_sft.py ├── test.py ├── train.py └── utils │ ├── __init__.py │ └── util.py ├── figures ├── 0913.png ├── 0935.png ├── arch.png ├── df2k.png ├── dped.png ├── track1.png └── track2.png └── yoon ├── degradation_pair_data.py ├── img_resize.py ├── options ├── train_df2k.yml └── train_dped.yml ├── stage1_kernel.py ├── stage1_noise.py └── train_realsr.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Created by .ignore support plugin (hsz.mobi) 2 | ### Example user template template 3 | ### Example user template 4 | 5 | # IntelliJ project files 6 | .idea 7 | *.iml 8 | out 9 | gen 10 | ### Python template 11 | # Byte-compiled / optimized / DLL files 12 | __pycache__/ 13 | *.py[cod] 14 | *$py.class 15 | 16 | # C extensions 17 | *.so 18 | 19 | # Distribution / packaging 20 | .Python 21 | build/ 22 | develop-eggs/ 23 | dist/ 24 | downloads/ 25 | eggs/ 26 | .eggs/ 27 | lib/ 28 | lib64/ 29 | parts/ 30 | sdist/ 31 | var/ 32 | wheels/ 33 | pip-wheel-metadata/ 34 | share/python-wheels/ 35 | *.egg-info/ 36 | .installed.cfg 37 | *.egg 38 | MANIFEST 39 | 40 | # PyInstaller 41 | # Usually these files are written by a python script from a template 42 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 43 | *.manifest 44 | *.spec 45 | 46 | # Installer logs 47 | pip-log.txt 48 | pip-delete-this-directory.txt 49 | 50 | # Unit test / coverage reports 51 | htmlcov/ 52 | .tox/ 53 | .nox/ 54 | .coverage 55 | .coverage.* 56 | .cache 57 | nosetests.xml 58 | coverage.xml 59 | *.cover 60 | .hypothesis/ 61 | .pytest_cache/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | target/ 85 | 86 | # Jupyter Notebook 87 | .ipynb_checkpoints 88 | 89 | # IPython 90 | profile_default/ 91 | ipython_config.py 92 | 93 | # pyenv 94 | .python-version 95 | 96 | # pipenv 97 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
98 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 99 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 100 | # install all needed dependencies. 101 | #Pipfile.lock 102 | 103 | # celery beat schedule file 104 | celerybeat-schedule 105 | 106 | # SageMath parsed files 107 | *.sage.py 108 | 109 | # Environments 110 | .env 111 | .venv 112 | env/ 113 | venv/ 114 | ENV/ 115 | env.bak/ 116 | venv.bak/ 117 | 118 | # Spyder project settings 119 | .spyderproject 120 | .spyproject 121 | 122 | # Rope project settings 123 | .ropeproject 124 | 125 | # mkdocs documentation 126 | /site 127 | 128 | # mypy 129 | .mypy_cache/ 130 | .dmypy.json 131 | dmypy.json 132 | 133 | # Pyre type checker 134 | .pyre/ 135 | 136 | .vscode/ 137 | pretrained_model/ 138 | results/ 139 | yoon/kernels/ 140 | yoon/noises/ 141 | experiments/ 142 | .env* 143 | tb_logger/ 144 | exp_archive/ 145 | *.zip 146 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 
47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Intro. 2 | 3 | This repo is an unofficial implementation of [RealSR](https://github.com/jixiaozhong/RealSR) including training codes and kernel/noise estimation codes. 
4 | 5 | # Usage 6 | 7 | ## kernel and noise estimation 8 | 9 | You can download the estimated kernels from [here](https://sianalytics-my.sharepoint.com/:u:/g/personal/yoon28_si-analytics_ai/EVr8bcSAy4lIlTA14L7TClMBmfjQop-MpnM4Y_XFgKqoWA?e=kO4uyB) and the noise images from [here](https://sianalytics-my.sharepoint.com/:u:/g/personal/yoon28_si-analytics_ai/EVm7nV1BekFJtd6xr75sEJQBhY2FdQqIW1o3bZkVv4DEtA?e=ZhhqEB). 10 | 11 | Alternatively, you can estimate kernels yourself from the NTIRE2020 real-world SR data by executing `stage1_kernel.py`, located under the `yoon` folder. 12 | Since kernel estimation requires [KernelGAN](https://github.com/sefibk/KernelGAN), you have to point Python at the KernelGAN code, e.g. `export PYTHONPATH=/path/to/kernelgan`. 13 | 14 | For the noises, you can likewise use `stage1_noise.py` to estimate noise patches from the corrupted images instead of downloading the noise images linked above. 15 | Note that the downloadable noises were produced with the settings `patch_size=128` and `max_var=100`. 16 | 17 | To point the scripts at the NTIRE2020 dataset, modify the variables `DATA_LOC`, `DATA_X`, `DATA_Y`, and `DATA_VAL` near the top of `stage1_kernel.py` and `stage1_noise.py`; a sketch of what that block might look like follows.
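Only the variable names below come from the scripts; the paths and folder names are hypothetical placeholders, so adjust them to wherever you unpacked the dataset:

```python
# Hypothetical example values for the variables near the top of
# stage1_kernel.py / stage1_noise.py -- the paths are placeholders.
DATA_LOC = "/mnt/data/NTIRE2020/realSR/track1"  # dataset root
DATA_X = f"{DATA_LOC}/Corrupted-tr-x"           # source-domain (corrupted) training images
DATA_Y = f"{DATA_LOC}/Corrupted-tr-y"           # target-domain (clean) training images
DATA_VAL = f"{DATA_LOC}/Corrupted-va-x"         # validation images
```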
18 | 19 | ## training 20 | 21 | For training, use the `train_realsr.py` script located under the `yoon` folder. 22 | For example: 23 | 24 | ``` 25 | PYTHONPATH=/mnt/workspace/SR/RealSR/codes CUDA_VISIBLE_DEVICES=14,15 python3 yoon/train_realsr.py -opt yoon/options/train_df2k.yml 26 | ``` 27 | 28 | This example first adds the code under the `codes` folder to `PYTHONPATH` (the original code is kept separate from the code I implemented), and it assigns two GPUs to the training procedure. 29 | 30 | # Misc. 31 | 32 | First, this repo is somewhat messy because I implemented the paper for my own needs. Second, I failed to reproduce the results in terms of SR quality: with my implementation, noise is removed very well, but my results are not as sharp as the original ones. 33 | 34 | In my opinion, reproducing the paper is difficult because the authors did not share many of the hyper-parameter settings, such as the variance cutoff (`max_var`), the noise patch size (`patch_size`), the clean-up scale factor, and so on. 35 | 36 | # Reference 37 | 38 | Ji, Xiaozhong, et al. "Real-World Super-Resolution via Kernel Estimation and Noise Injection." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition Workshops. 2020. -------------------------------------------------------------------------------- /README.md-old: -------------------------------------------------------------------------------- 1 | # RealSR 2 | 3 | Real-World Super-Resolution via Kernel Estimation and Noise Injection 4 | 5 | Xiaozhong Ji, Yun Cao, Ying Tai, Chengjie Wang, Jilin Li, and Feiyue Huang 6 | 7 | *Tencent YouTu Lab* 8 | 9 | Our solution is the **winner of CVPR NTIRE 2020 Challenge on Real-World Super-Resolution** in both tracks. 10 | 11 | (*Official PyTorch Implementation*) 12 | 13 | ## Update - May 26, 2020 14 | - Add [DF2K-JPEG](https://drive.google.com/open?id=1w8QbCLM6g-MMVlIhRERtSXrP-Dh7cPhm) Model. 15 | - [Executable files](https://drive.google.com/open?id=1-FZPyMtuDfEnAPgSBfePYhv0NorznDPU) based on [ncnn](https://github.com/Tencent/ncnn) are available. Test your own images on Windows/Linux/macOS. For more details, refer to [realsr-ncnn-vulkan](https://github.com/nihui/realsr-ncnn-vulkan). 16 | - Usage - ```./realsr-ncnn-vulkan -i in.jpg -o out.png``` 17 | - ```-x``` - use ensemble 18 | - ```-g 0``` - select GPU id. 19 | 20 | ## Introduction 21 | 22 | Recent state-of-the-art super-resolution methods have achieved impressive performance on ideal datasets regardless of blur and noise. However, these methods always fail in real-world image super-resolution, since most of them adopt simple bicubic downsampling from high-quality images to construct Low-Resolution (LR) and High-Resolution (HR) pairs for training which may lose track of frequency-related details. To address this issue, we focus on designing a novel degradation framework for real-world images by estimating various blur kernels as well as real noise distributions. Based on our novel degradation framework, we can acquire LR images sharing a common domain with real-world images. Then, we propose a real-world super-resolution model aiming at better perception. Extensive experiments on synthetic noise data and real-world images demonstrate that our method outperforms the state-of-the-art methods, resulting in lower noise and better visual quality. In addition, our method is the winner of NTIRE 2020 Challenge on both tracks of Real-World Super-Resolution, which significantly outperforms other competitors by large margins. 23 | 24 | ![RealSR](figures/arch.png) 25 | 26 | If you are interested in this work, please cite our [paper](http://openaccess.thecvf.com/content_CVPRW_2020/papers/w31/Ji_Real-World_Super-Resolution_via_Kernel_Estimation_and_Noise_Injection_CVPRW_2020_paper.pdf): 27 | 28 | @InProceedings{Ji_2020_CVPR_Workshops, 29 | author = {Ji, Xiaozhong and Cao, Yun and Tai, Ying and Wang, Chengjie and Li, Jilin and Huang, Feiyue}, 30 | title = {Real-World Super-Resolution via Kernel Estimation and Noise Injection}, 31 | booktitle = {The IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) Workshops}, 32 | month = {June}, 33 | year = {2020} 34 | } 35 | 36 | and the challenge report [NTIRE 2020 Challenge on Real-World Image Super-Resolution: Methods and Results](https://arxiv.org/pdf/2005.01996.pdf): 37 | 38 | @article{Lugmayr2020ntire, 39 | title={NTIRE 2020 Challenge on Real-World Image Super-Resolution: Methods and Results}, 40 | author={Andreas Lugmayr, Martin Danelljan, Radu Timofte, Namhyuk Ahn, Dongwoon Bai, Jie Cai, Yun Cao, Junyang Chen, Kaihua Cheng, SeYoung Chun, Wei Deng, Mostafa El-Khamy, Chiu Man Ho, Xiaozhong Ji, Amin Kheradmand, Gwantae Kim, Hanseok Ko, Kanghyu Lee, Jungwon Lee, Hao Li, Ziluan Liu, Zhi-Song Liu, Shuai Liu, Yunhua Lu, Zibo Meng, Pablo Navarrete Michelini, Christian Micheloni, Kalpesh Prajapati, Haoyu Ren, Yong Hyeok Seo, Wan-Chi Siu, Kyung-Ah Sohn, Ying Tai, Rao Muhammad Umer, Shuangquan Wang, Huibing Wang, Timothy Haoning Wu, Haoning Wu, Biao Yang, Fuzhi Yang, Jaejun Yoo, Tongtong Zhao, Yuanbo Zhou, Haijie Zhuo, Ziyao Zong, Xueyi Zou}, 41 | journal={CVPR Workshops}, 42 | year={2020}, 43 | } 44 | 45 | 46 | 47 | 48 | ## Visual Results 49 | 50 | ![0](figures/0913.png) 51 | 52 | ![1](figures/0935.png) 53 | 54 | ## Quantitative Results Compared with Other Participating Methods 55 | 56 | 'Impressionism' is our team. Note that the final decision is based on MOS (Mean Opinion Score) and MOR (Mean Opinion Rank). 57 | 58 | ![0](figures/track1.png) 59 | 60 | ![1](figures/track2.png) 61 | 62 | ## Qualitative Results Compared with Other Participating Methods 63 | 64 | 'Impressionism' is our team.
65 | 66 | ![0](figures/df2k.png) 67 | 68 | ![1](figures/dped.png) 69 | 70 | 71 | 72 | ## Dependencies and Installation 73 | This code is based on [BasicSR](https://github.com/xinntao/BasicSR). 74 | 75 | - Python 3 (we recommend [Anaconda](https://www.anaconda.com/download/#linux)) 76 | - [PyTorch >= 1.0](https://pytorch.org/) 77 | - NVIDIA GPU + [CUDA](https://developer.nvidia.com/cuda-downloads) 78 | - Python packages: `pip install numpy opencv-python lmdb pyyaml` 79 | - TensorBoard: 80 | - PyTorch >= 1.1: `pip install tb-nightly future` 81 | - PyTorch == 1.0: `pip install tensorboardX` 82 | 83 | 84 | ## Pre-trained models 85 | - Models for challenge results 86 | - [DF2K](https://drive.google.com/open?id=1pWGfSw-UxOkrtbh14GeLQgYnMLdLguOF) for corrupted images with processing noise. 87 | - [DPED](https://drive.google.com/open?id=1zZIuQSepFlupV103AatoP-JSJpwJFS19) for real images taken with a cell-phone camera. 88 | - Extended models 89 | - [DF2K-JPEG](https://drive.google.com/open?id=1w8QbCLM6g-MMVlIhRERtSXrP-Dh7cPhm) for compressed JPEG images. 90 | 91 | ## Testing 92 | Download the dataset from [NTIRE 2020 RWSR](https://competitions.codalab.org/competitions/22220#participate) and unzip it to your path. 93 | 94 | For convenience, we provide [Corrupted-te-x](https://drive.google.com/open?id=1GrLxeE-LruddQoAePV1Z7MFclXdZWHMa) and [DPEDiphone-crop-te-x](https://drive.google.com/open?id=19zlofWRxkhsjf_TuRA2oI9jgozifGvxp). 95 | 96 | ```cd ./codes``` 97 | 98 | ### DF2K: Image processing artifacts 99 | 1. Modify the configuration file options/df2k/test_df2k.yml 100 | - line 1 : 'name' -- dir name for saving the testing results 101 | - line 13 : 'dataroot_LR' -- test images dir 102 | - line 26 : 'pretrain_model_G' -- pre-trained model for testing 103 | 2. Run the command: 104 | ```CUDA_VISIBLE_DEVICES=X python3 test.py -opt options/df2k/test_df2k.yml``` 105 | 3. The output images are saved in '../results/' 106 | 107 | ### DPED: Smartphone images 108 | 1. Modify the configuration file options/dped/test_dped.yml 109 | - line 1 : 'name' -- dir name for saving the testing results 110 | - line 13 : 'dataroot_LR' -- test images dir 111 | - line 26 : 'pretrain_model_G' -- pre-trained model for testing 112 | 2. Run the command: 113 | ```CUDA_VISIBLE_DEVICES=X python3 test.py -opt options/dped/test_dped.yml``` 114 | 3. The output images are saved in '../results/' 115 | 116 | 117 | ## Training code 118 | 119 | To be released soon. 120 | -------------------------------------------------------------------------------- /codes/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/codes/__init__.py -------------------------------------------------------------------------------- /codes/data/GT_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import lmdb 5 | import torch 6 | import torch.utils.data as data 7 | import data.util as util 8 | 9 | 10 | class GTDataset(data.Dataset): 11 | ''' 12 | Read LQ (Low Quality, here LR) and GT image pairs. 13 | If only the GT image is provided, generate the LQ image on-the-fly. 14 | Pairing relies on the 'sorted' function, so please check the naming convention.
15 | ''' 16 | 17 | def __init__(self, opt): 18 | super(GTDataset, self).__init__() 19 | self.opt = opt 20 | self.data_type = self.opt['data_type'] 21 | self.paths_LQ, self.paths_GT = None, None 22 | self.sizes_LQ, self.sizes_GT = None, None 23 | self.LQ_env, self.GT_env = None, None # environment for lmdb 24 | 25 | self.paths_GT, self.sizes_GT = util.get_image_paths(self.data_type, opt['dataroot_GT']) 26 | self.paths_LQ, self.sizes_LQ = util.get_image_paths(self.data_type, opt['dataroot_LQ']) 27 | assert self.paths_GT, 'Error: GT path is empty.' 28 | if self.paths_LQ and self.paths_GT: 29 | assert len(self.paths_LQ) == len( 30 | self.paths_GT 31 | ), 'GT and LQ datasets have different number of images - {}, {}.'.format( 32 | len(self.paths_LQ), len(self.paths_GT)) 33 | self.random_scale_list = [1] 34 | 35 | def _init_lmdb(self): 36 | # https://github.com/chainer/chainermn/issues/129 37 | self.GT_env = lmdb.open(self.opt['dataroot_GT'], readonly=True, lock=False, readahead=False, 38 | meminit=False) 39 | self.LQ_env = lmdb.open(self.opt['dataroot_LQ'], readonly=True, lock=False, readahead=False, 40 | meminit=False) 41 | 42 | def __getitem__(self, index): 43 | if self.data_type == 'lmdb': 44 | if (self.GT_env is None) or (self.LQ_env is None): 45 | self._init_lmdb() 46 | GT_path, LQ_path = None, None 47 | scale = self.opt['scale'] 48 | GT_size = self.opt['GT_size'] 49 | 50 | # get GT image 51 | GT_path = self.paths_GT[index] 52 | if self.data_type == 'lmdb': 53 | resolution = [int(s) for s in self.sizes_GT[index].split('_')] 54 | else: 55 | resolution = None 56 | img_GT = util.read_img(self.GT_env, GT_path, resolution) 57 | # modcrop in the validation / test phase 58 | if self.opt['phase'] != 'train': 59 | img_GT = util.modcrop(img_GT, scale) 60 | # change color space if necessary 61 | if self.opt['color']: 62 | img_GT = util.channel_convert(img_GT.shape[2], self.opt['color'], [img_GT])[0] 63 | 64 | # get LQ image 65 | if self.paths_LQ: 66 | LQ_path = self.paths_LQ[index] 67 | if self.data_type == 'lmdb': 68 | resolution = [int(s) for s in self.sizes_LQ[index].split('_')] 69 | else: 70 | resolution = None 71 | img_LQ = util.read_img(self.LQ_env, LQ_path, resolution) 72 | # img_LQ = util.imresize_np(img_LQ, 4) 73 | # img_LQ = util.imresize_np(img_LQ, 0.25) 74 | else: # down-sampling on-the-fly 75 | # randomly scale during training 76 | if self.opt['phase'] == 'train': 77 | random_scale = random.choice(self.random_scale_list) 78 | H_s, W_s, _ = img_GT.shape 79 | 80 | def _mod(n, random_scale, scale, thres): 81 | rlt = int(n * random_scale) 82 | rlt = (rlt // scale) * scale 83 | return thres if rlt < thres else rlt 84 | 85 | H_s = _mod(H_s, random_scale, scale, GT_size) 86 | W_s = _mod(W_s, random_scale, scale, GT_size) 87 | img_GT = cv2.resize(np.copy(img_GT), (W_s, H_s), interpolation=cv2.INTER_LINEAR) 88 | # force to 3 channels 89 | if img_GT.ndim == 2: 90 | img_GT = cv2.cvtColor(img_GT, cv2.COLOR_GRAY2BGR) 91 | 92 | H, W, _ = img_GT.shape 93 | # using matlab imresize 94 | img_LQ = util.imresize_np(img_GT, 1 / scale, True) 95 | if img_LQ.ndim == 2: 96 | img_LQ = np.expand_dims(img_LQ, axis=2) 97 | 98 | if self.opt['phase'] == 'train': 99 | # if the image size is too small 100 | H, W, _ = img_GT.shape 101 | if H < GT_size or W < GT_size: 102 | img_GT = cv2.resize(np.copy(img_GT), (GT_size, GT_size), 103 | interpolation=cv2.INTER_LINEAR) 104 | # using matlab imresize 105 | img_LQ = util.imresize_np(img_GT, 1 / scale, True) 106 | if img_LQ.ndim == 2: 107 | img_LQ = np.expand_dims(img_LQ, 
axis=2) 108 | 109 | H, W, C = img_LQ.shape 110 | LQ_size = GT_size // scale 111 | 112 | # randomly crop 113 | rnd_h = random.randint(0, max(0, H - LQ_size)) 114 | rnd_w = random.randint(0, max(0, W - LQ_size)) 115 | img_LQ = img_LQ[rnd_h:rnd_h + LQ_size, rnd_w:rnd_w + LQ_size, :] 116 | rnd_h_GT, rnd_w_GT = int(rnd_h * scale), int(rnd_w * scale) 117 | img_GT = img_GT[rnd_h_GT:rnd_h_GT + GT_size, rnd_w_GT:rnd_w_GT + GT_size, :] 118 | 119 | # augmentation - flip, rotate 120 | img_LQ, img_GT = util.augment([img_LQ, img_GT], self.opt['use_flip'], 121 | self.opt['use_rot']) 122 | 123 | # change color space if necessary 124 | if self.opt['color']: 125 | img_LQ = util.channel_convert(C, self.opt['color'], 126 | [img_LQ])[0] # TODO: C is not defined during val 127 | 128 | # BGR to RGB, HWC to CHW, numpy to tensor 129 | if img_GT.shape[2] == 3: 130 | img_GT = img_GT[:, :, [2, 1, 0]] 131 | img_LQ = img_LQ[:, :, [2, 1, 0]] 132 | img_GT = torch.from_numpy(np.ascontiguousarray(np.transpose(img_GT, (2, 0, 1)))).float() 133 | img_LQ = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LQ, (2, 0, 1)))).float() 134 | 135 | if LQ_path is None: 136 | LQ_path = GT_path 137 | return {'LQ': img_LQ, 'GT': img_GT, 'LQ_path': LQ_path, 'GT_path': GT_path} 138 | 139 | def __len__(self): 140 | return len(self.paths_GT) 141 | -------------------------------------------------------------------------------- /codes/data/LRHR_seg_bg_dataset.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | import cv2 4 | import torch 5 | import torch.utils.data as data 6 | import data.util as util 7 | 8 | 9 | class LRHRSeg_BG_Dataset(data.Dataset): 10 | ''' 11 | Read HR images and segmentation probability maps; generate LR images and categories for SFTGAN. 12 | Also samples general scenes for the background; 13 | LR images need to be generated on-the-fly. 14 | ''' 15 | 16 | def __init__(self, opt): 17 | super(LRHRSeg_BG_Dataset, self).__init__() 18 | self.opt = opt 19 | self.paths_LR = None 20 | self.paths_HR = None 21 | self.paths_HR_bg = None # HR images for background scenes 22 | self.LR_env = None # environment for lmdb 23 | self.HR_env = None 24 | self.HR_env_bg = None 25 | 26 | # read image list from lmdb or image files 27 | self.HR_env, self.paths_HR = util.get_image_paths(opt['data_type'], opt['dataroot_GT']) 28 | self.LR_env, self.paths_LR = util.get_image_paths(opt['data_type'], opt['dataroot_LR']) 29 | self.HR_env_bg, self.paths_HR_bg = util.get_image_paths(opt['data_type'], 30 | opt['dataroot_GT_bg']) 31 | 32 | assert self.paths_HR, 'Error: HR path is empty.'
33 | if self.paths_LR and self.paths_HR: 34 | assert len(self.paths_LR) == len(self.paths_HR), \ 35 | 'HR and LR datasets have different number of images - {}, {}.'.format( 36 | len(self.paths_LR), len(self.paths_HR)) 37 | 38 | self.random_scale_list = [1, 0.9, 0.8, 0.7, 0.6, 0.5] 39 | self.ratio = 10 # 10 OST data samples per 1 DIV2K general data sample (background) 40 | 41 | def __getitem__(self, index): 42 | HR_path, LR_path = None, None 43 | scale = self.opt['scale'] 44 | HR_size = self.opt['HR_size'] 45 | 46 | # get HR image 47 | if self.opt['phase'] == 'train' and \ 48 | random.choice(list(range(self.ratio))) == 0: # read background images 49 | bg_index = random.randint(0, len(self.paths_HR_bg) - 1) 50 | HR_path = self.paths_HR_bg[bg_index] 51 | img_HR = util.read_img(self.HR_env_bg, HR_path) 52 | seg = torch.FloatTensor(8, img_HR.shape[0], img_HR.shape[1]).fill_(0) 53 | seg[0, :, :] = 1 # background 54 | else: 55 | HR_path = self.paths_HR[index] 56 | img_HR = util.read_img(self.HR_env, HR_path) 57 | seg = torch.load(HR_path.replace('/img/', '/bicseg/').replace('.png', '.pth')) 58 | # read segmentation files; you should change this to match your settings. 59 | 60 | # modcrop in the validation / test phase 61 | if self.opt['phase'] != 'train': 62 | img_HR = util.modcrop(img_HR, 8) 63 | 64 | seg = np.transpose(seg.numpy(), (1, 2, 0)) 65 | 66 | # get LR image 67 | if self.paths_LR: 68 | LR_path = self.paths_LR[index] 69 | img_LR = util.read_img(self.LR_env, LR_path) 70 | else: # down-sampling on-the-fly 71 | # randomly scale during training 72 | if self.opt['phase'] == 'train': 73 | random_scale = random.choice(self.random_scale_list) 74 | H_s, W_s, _ = seg.shape 75 | 76 | def _mod(n, random_scale, scale, thres): 77 | rlt = int(n * random_scale) 78 | rlt = (rlt // scale) * scale 79 | return thres if rlt < thres else rlt 80 | 81 | H_s = _mod(H_s, random_scale, scale, HR_size) 82 | W_s = _mod(W_s, random_scale, scale, HR_size) 83 | img_HR = cv2.resize(np.copy(img_HR), (W_s, H_s), interpolation=cv2.INTER_LINEAR) 84 | seg = cv2.resize(np.copy(seg), (W_s, H_s), interpolation=cv2.INTER_NEAREST) 85 | 86 | H, W, _ = img_HR.shape 87 | # using matlab imresize 88 | img_LR = util.imresize_np(img_HR, 1 / scale, True) 89 | if img_LR.ndim == 2: 90 | img_LR = np.expand_dims(img_LR, axis=2) 91 | 92 | H, W, C = img_LR.shape 93 | if self.opt['phase'] == 'train': 94 | LR_size = HR_size // scale 95 | 96 | # randomly crop 97 | rnd_h = random.randint(0, max(0, H - LR_size)) 98 | rnd_w = random.randint(0, max(0, W - LR_size)) 99 | img_LR = img_LR[rnd_h:rnd_h + LR_size, rnd_w:rnd_w + LR_size, :] 100 | rnd_h_HR, rnd_w_HR = int(rnd_h * scale), int(rnd_w * scale) 101 | img_HR = img_HR[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :] 102 | seg = seg[rnd_h_HR:rnd_h_HR + HR_size, rnd_w_HR:rnd_w_HR + HR_size, :] 103 | 104 | # augmentation - flip, rotate 105 | img_LR, img_HR, seg = util.augment([img_LR, img_HR, seg], self.opt['use_flip'], 106 | self.opt['use_rot']) 107 | 108 | # category 109 | if 'building' in HR_path: 110 | category = 1 111 | elif 'plant' in HR_path: 112 | category = 2 113 | elif 'mountain' in HR_path: 114 | category = 3 115 | elif 'water' in HR_path: 116 | category = 4 117 | elif 'sky' in HR_path: 118 | category = 5 119 | elif 'grass' in HR_path: 120 | category = 6 121 | elif 'animal' in HR_path: 122 | category = 7 123 | else: 124 | category = 0 # background 125 | else: 126 | category = -1 # unused during val 127 | 128 | # BGR to RGB, HWC to CHW, numpy to tensor 129 | if img_HR.shape[2] == 3:
130 | img_HR = img_HR[:, :, [2, 1, 0]] 131 | img_LR = img_LR[:, :, [2, 1, 0]] 132 | img_HR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_HR, (2, 0, 1)))).float() 133 | img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float() 134 | seg = torch.from_numpy(np.ascontiguousarray(np.transpose(seg, (2, 0, 1)))).float() 135 | 136 | if LR_path is None: 137 | LR_path = HR_path 138 | return { 139 | 'LR': img_LR, 140 | 'HR': img_HR, 141 | 'seg': seg, 142 | 'category': category, 143 | 'LR_path': LR_path, 144 | 'HR_path': HR_path 145 | } 146 | 147 | def __len__(self): 148 | return len(self.paths_HR) 149 | -------------------------------------------------------------------------------- /codes/data/LR_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.utils.data as data 4 | import data.util as util 5 | 6 | 7 | class LRDataset(data.Dataset): 8 | '''Read LR images only in the test phase.''' 9 | 10 | def __init__(self, opt): 11 | super(LRDataset, self).__init__() 12 | self.opt = opt 13 | self.paths_LR = None 14 | self.LR_env = None # environment for lmdb 15 | 16 | # read image list from lmdb or image files 17 | self.paths_LR, _ = util.get_image_paths(opt['data_type'], opt['dataroot_LR']) 18 | assert self.paths_LR, 'Error: LR paths are empty.' 19 | 20 | def __getitem__(self, index): 21 | LR_path = None 22 | 23 | # get LR image 24 | LR_path = self.paths_LR[index] 25 | img_LR = util.read_img(self.LR_env, LR_path) 26 | H, W, C = img_LR.shape 27 | 28 | # change color space if necessary 29 | if self.opt['color']: 30 | img_LR = util.channel_convert(C, self.opt['color'], [img_LR])[0] 31 | 32 | # BGR to RGB, HWC to CHW, numpy to tensor 33 | if img_LR.shape[2] == 3: 34 | img_LR = img_LR[:, :, [2, 1, 0]] 35 | img_LR = torch.from_numpy(np.ascontiguousarray(np.transpose(img_LR, (2, 0, 1)))).float() 36 | 37 | return {'LQ': img_LR, 'LQ_path': LR_path} 38 | 39 | def __len__(self): 40 | return len(self.paths_LR) 41 | -------------------------------------------------------------------------------- /codes/data/__init__.py: -------------------------------------------------------------------------------- 1 | '''create dataset and dataloader''' 2 | import logging 3 | import torch 4 | import torch.utils.data 5 | 6 | 7 | def create_dataloader(dataset, dataset_opt, opt=None, sampler=None): 8 | phase = dataset_opt['phase'] 9 | if phase == 'train': 10 | if opt['dist']: 11 | world_size = torch.distributed.get_world_size() 12 | num_workers = dataset_opt['n_workers'] 13 | assert dataset_opt['batch_size'] % world_size == 0 14 | batch_size = dataset_opt['batch_size'] // world_size 15 | shuffle = False 16 | else: 17 | num_workers = dataset_opt['n_workers'] * len(opt['gpu_ids']) 18 | batch_size = dataset_opt['batch_size'] 19 | shuffle = True 20 | return torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=shuffle, 21 | num_workers=num_workers, sampler=sampler, drop_last=True, 22 | pin_memory=False) 23 | else: 24 | return torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False, num_workers=1, 25 | pin_memory=True) 26 | 27 | 28 | def create_dataset(dataset_opt): 29 | mode = dataset_opt['mode'] 30 | if mode == 'LR': 31 | from data.LR_dataset import LRDataset as D 32 | dataset = D(dataset_opt) 33 | elif mode == 'LQGT': 34 | from data.LQGT_dataset import LQGTDataset as D 35 | dataset = D(dataset_opt) 36 | # elif mode == 'LQGTseg_bg': 37 | # from data.LQGT_seg_bg_dataset import 
LQGTSeg_BG_Dataset as D 38 | elif mode == 'yoon': 39 | if dataset_opt["phase"] == "train": 40 | from degradation_pair_data import DegradationParing as D 41 | hr_folder = dataset_opt["dataroot_GT"] 42 | kern_folder = dataset_opt["kernel_folder"] 43 | noise_folder = dataset_opt["noise_folder"] 44 | gt_patch_size = dataset_opt["GT_size"] 45 | scale_factor = 1 / dataset_opt["scale"] 46 | use_shuffle = dataset_opt["use_shuffle"] 47 | rgb = dataset_opt["color"] == "RGB" 48 | dataset = D(hr_folder, kern_folder, noise_folder, scale_factor, gt_patch_size, permute=use_shuffle, bgr2rgb=rgb) 49 | else: 50 | from degradation_pair_data import TestDataSR as D 51 | lr_folder = dataset_opt["dataroot_LQ"] 52 | gt_folder = dataset_opt["dataroot_GT"] 53 | use_shuffle = dataset_opt["use_shuffle"] 54 | rgb = dataset_opt["color"] == "RGB" 55 | dataset = D(lr_folder, gt_folder, "/mnt/data/NTIRE2020/realSR/track1/Corrupted-te-x", permute=use_shuffle, bgr2rgb=rgb) 56 | else: 57 | raise NotImplementedError('Dataset [{:s}] is not recognized.'.format(mode)) 58 | 59 | logger = logging.getLogger('base') 60 | logger.info('Dataset [{:s} - {:s}] is created.'.format(dataset.__class__.__name__, 61 | dataset_opt['name'])) 62 | return dataset 63 | -------------------------------------------------------------------------------- /codes/data/data_loader.py: -------------------------------------------------------------------------------- 1 | # import torchvision.datasets as dset 2 | # import torchvision.transforms as transforms 3 | # import torch.utils.data as data_utils 4 | import torch.utils.data as data 5 | import torch 6 | from torchvision import transforms 7 | # from functools import partial 8 | import numpy as np 9 | # from imageio import imread 10 | from PIL import Image 11 | import glob 12 | from scipy.io import loadmat 13 | 14 | class kernelDataset(data.Dataset): 15 | def __init__(self, dataset='x2/'): 16 | super(kernelDataset, self).__init__() 17 | 18 | base = dataset 19 | 20 | self.mat_files = sorted(glob.glob(base + '*.mat')) 21 | 22 | def __getitem__(self, index): 23 | mat = loadmat(self.mat_files[index]) 24 | x = np.array([mat['kernel']]) 25 | #x = np.swapaxes(x, 2, 0) 26 | #print(np.shape(x)) 27 | 28 | return torch.from_numpy(x).float() 29 | 30 | def __len__(self): 31 | return len(self.mat_files) 32 | 33 | 34 | class noiseDataset(data.Dataset): 35 | def __init__(self, dataset='x2/'): 36 | super(noiseDataset, self).__init__() 37 | 38 | base = dataset 39 | import os 40 | assert os.path.exists(base) 41 | 42 | # self.mat_files = sorted(glob.glob(base + '*.mat')) 43 | self.noise_imgs = sorted(glob.glob(base + '*.png')) 44 | self.pre_process = transforms.Compose([transforms.RandomCrop(32), 45 | transforms.ToTensor()]) 46 | 47 | def __getitem__(self, index): 48 | # mat = loadmat(self.mat_files[index]) 49 | # x = np.array([mat['kernel']]) 50 | # x = np.swapaxes(x, 2, 0) 51 | # print(np.shape(x)) 52 | noise = self.pre_process(Image.open(self.noise_imgs[index])) 53 | norm_noise = (noise - torch.mean(noise, dim=[1, 2], keepdim=True)) 54 | return norm_noise 55 | 56 | def __len__(self): 57 | return len(self.noise_imgs) 58 | -------------------------------------------------------------------------------- /codes/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | """ 2 | Modified from torch.utils.data.distributed.DistributedSampler 3 | Support enlarging the dataset for *iter-oriented* training, for saving time when restart the 4 | dataloader after each epoch 5 | """ 6 | import 
math 7 | import torch 8 | from torch.utils.data.sampler import Sampler 9 | import torch.distributed as dist 10 | 11 | 12 | class DistIterSampler(Sampler): 13 | """Sampler that restricts data loading to a subset of the dataset. 14 | 15 | It is especially useful in conjunction with 16 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 17 | process can pass a DistributedSampler instance as a DataLoader sampler, 18 | and load a subset of the original dataset that is exclusive to it. 19 | 20 | .. note:: 21 | Dataset is assumed to be of constant size. 22 | 23 | Arguments: 24 | dataset: Dataset used for sampling. 25 | num_replicas (optional): Number of processes participating in 26 | distributed training. 27 | rank (optional): Rank of the current process within num_replicas. 28 | """ 29 | 30 | def __init__(self, dataset, num_replicas=None, rank=None, ratio=100): 31 | if num_replicas is None: 32 | if not dist.is_available(): 33 | raise RuntimeError("Requires distributed package to be available") 34 | num_replicas = dist.get_world_size() 35 | if rank is None: 36 | if not dist.is_available(): 37 | raise RuntimeError("Requires distributed package to be available") 38 | rank = dist.get_rank() 39 | self.dataset = dataset 40 | self.num_replicas = num_replicas 41 | self.rank = rank 42 | self.epoch = 0 43 | self.num_samples = int(math.ceil(len(self.dataset) * ratio / self.num_replicas)) 44 | self.total_size = self.num_samples * self.num_replicas 45 | 46 | def __iter__(self): 47 | # deterministically shuffle based on epoch 48 | g = torch.Generator() 49 | g.manual_seed(self.epoch) 50 | indices = torch.randperm(self.total_size, generator=g).tolist() 51 | 52 | dsize = len(self.dataset) 53 | indices = [v % dsize for v in indices] 54 | 55 | # subsample 56 | indices = indices[self.rank:self.total_size:self.num_replicas] 57 | assert len(indices) == self.num_samples 58 | 59 | return iter(indices) 60 | 61 | def __len__(self): 62 | return self.num_samples 63 | 64 | def set_epoch(self, epoch): 65 | self.epoch = epoch 66 | -------------------------------------------------------------------------------- /codes/models/SFTGAN_ACD_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.optim import lr_scheduler 7 | 8 | import models.networks as networks 9 | from .base_model import BaseModel 10 | from models.modules.loss import GANLoss, GradientPenaltyLoss 11 | 12 | logger = logging.getLogger('base') 13 | 14 | 15 | class SFTGAN_ACD_Model(BaseModel): 16 | def __init__(self, opt): 17 | super(SFTGAN_ACD_Model, self).__init__(opt) 18 | train_opt = opt['train'] 19 | 20 | # define networks and load pretrained models 21 | self.netG = networks.define_G(opt).to(self.device) # G 22 | if self.is_train: 23 | self.netD = networks.define_D(opt).to(self.device) # D 24 | self.netG.train() 25 | self.netD.train() 26 | self.load() # load G and D if needed 27 | 28 | # define losses, optimizer and scheduler 29 | if self.is_train: 30 | # G pixel loss 31 | if train_opt['pixel_weight'] > 0: 32 | l_pix_type = train_opt['pixel_criterion'] 33 | if l_pix_type == 'l1': 34 | self.cri_pix = nn.L1Loss().to(self.device) 35 | elif l_pix_type == 'l2': 36 | self.cri_pix = nn.MSELoss().to(self.device) 37 | else: 38 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_pix_type)) 39 | self.l_pix_w = train_opt['pixel_weight'] 40 | else: 41 | logging.info('Remove pixel 
loss.') 42 | self.cri_pix = None 43 | 44 | # G feature loss 45 | if train_opt['feature_weight'] > 0: 46 | l_fea_type = train_opt['feature_criterion'] 47 | if l_fea_type == 'l1': 48 | self.cri_fea = nn.L1Loss().to(self.device) 49 | elif l_fea_type == 'l2': 50 | self.cri_fea = nn.MSELoss().to(self.device) 51 | else: 52 | raise NotImplementedError('Loss type [{:s}] not recognized.'.format(l_fea_type)) 53 | self.l_fea_w = train_opt['feature_weight'] 54 | else: 55 | logging.info('Remove feature loss.') 56 | self.cri_fea = None 57 | if self.cri_fea: # load VGG perceptual loss 58 | self.netF = networks.define_F(opt, use_bn=False).to(self.device) 59 | 60 | # GD gan loss 61 | self.cri_gan = GANLoss(train_opt['gan_type'], 1.0, 0.0).to(self.device) 62 | self.l_gan_w = train_opt['gan_weight'] 63 | # D_update_ratio and D_init_iters are for WGAN 64 | self.D_update_ratio = train_opt['D_update_ratio'] if train_opt['D_update_ratio'] else 1 65 | self.D_init_iters = train_opt['D_init_iters'] if train_opt['D_init_iters'] else 0 66 | 67 | if train_opt['gan_type'] == 'wgan-gp': 68 | self.random_pt = torch.Tensor(1, 1, 1, 1).to(self.device) 69 | # gradient penalty loss 70 | self.cri_gp = GradientPenaltyLoss(device=self.device).to(self.device) 71 | self.l_gp_w = train_opt['gp_weigth'] 72 | 73 | # D cls loss 74 | self.cri_ce = nn.CrossEntropyLoss(ignore_index=0).to(self.device) 75 | # ignore background, since bg images may conflict with other classes 76 | 77 | # optimizers 78 | # G 79 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 80 | optim_params_SFT = [] 81 | optim_params_other = [] 82 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model 83 | if 'SFT' in k or 'Cond' in k: 84 | optim_params_SFT.append(v) 85 | else: 86 | optim_params_other.append(v) 87 | self.optimizer_G_SFT = torch.optim.Adam(optim_params_SFT, lr=train_opt['lr_G'] * 5, 88 | weight_decay=wd_G, 89 | betas=(train_opt['beta1_G'], 0.999)) 90 | self.optimizer_G_other = torch.optim.Adam(optim_params_other, lr=train_opt['lr_G'], 91 | weight_decay=wd_G, 92 | betas=(train_opt['beta1_G'], 0.999)) 93 | self.optimizers.append(self.optimizer_G_SFT) 94 | self.optimizers.append(self.optimizer_G_other) 95 | # D 96 | wd_D = train_opt['weight_decay_D'] if train_opt['weight_decay_D'] else 0 97 | self.optimizer_D = torch.optim.Adam(self.netD.parameters(), lr=train_opt['lr_D'], 98 | weight_decay=wd_D, 99 | betas=(train_opt['beta1_D'], 0.999)) 100 | self.optimizers.append(self.optimizer_D) 101 | 102 | # schedulers 103 | if train_opt['lr_scheme'] == 'MultiStepLR': 104 | for optimizer in self.optimizers: 105 | self.schedulers.append( 106 | lr_scheduler.MultiStepLR(optimizer, train_opt['lr_steps'], 107 | train_opt['lr_gamma'])) 108 | else: 109 | raise NotImplementedError('MultiStepLR learning rate scheme is enough.') 110 | 111 | self.log_dict = OrderedDict() 112 | # print network 113 | self.print_network() 114 | 115 | def feed_data(self, data, need_GT=True): 116 | # LR 117 | self.var_L = data['LR'].to(self.device) 118 | # seg 119 | self.var_seg = data['seg'].to(self.device) 120 | # category 121 | self.var_cat = data['category'].long().to(self.device) 122 | 123 | if need_GT: # train or val 124 | self.var_H = data['GT'].to(self.device) 125 | 126 | def optimize_parameters(self, step): 127 | # G 128 | self.optimizer_G_SFT.zero_grad() 129 | self.optimizer_G_other.zero_grad() 130 | self.fake_H = self.netG((self.var_L, self.var_seg)) 131 | 132 | l_g_total = 0 133 | if step % self.D_update_ratio == 0 and step > 
self.D_init_iters: 134 | if self.cri_pix: # pixel loss 135 | l_g_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.var_H) 136 | l_g_total += l_g_pix 137 | if self.cri_fea: # feature loss 138 | real_fea = self.netF(self.var_H).detach() 139 | fake_fea = self.netF(self.fake_H) 140 | l_g_fea = self.l_fea_w * self.cri_fea(fake_fea, real_fea) 141 | l_g_total += l_g_fea 142 | # G gan + cls loss 143 | pred_g_fake, cls_g_fake = self.netD(self.fake_H) 144 | l_g_gan = self.l_gan_w * self.cri_gan(pred_g_fake, True) 145 | l_g_cls = self.l_gan_w * self.cri_ce(cls_g_fake, self.var_cat) 146 | l_g_total += l_g_gan 147 | l_g_total += l_g_cls 148 | 149 | l_g_total.backward() 150 | self.optimizer_G_SFT.step() 151 | if step > 20000: 152 | self.optimizer_G_other.step() 153 | 154 | # D 155 | self.optimizer_D.zero_grad() 156 | l_d_total = 0 157 | # real data 158 | pred_d_real, cls_d_real = self.netD(self.var_H) 159 | l_d_real = self.cri_gan(pred_d_real, True) 160 | l_d_cls_real = self.cri_ce(cls_d_real, self.var_cat) 161 | # fake data 162 | pred_d_fake, cls_d_fake = self.netD(self.fake_H.detach()) # detach to avoid BP to G 163 | l_d_fake = self.cri_gan(pred_d_fake, False) 164 | l_d_cls_fake = self.cri_ce(cls_d_fake, self.var_cat) 165 | 166 | l_d_total = l_d_real + l_d_cls_real + l_d_fake + l_d_cls_fake 167 | 168 | if self.opt['train']['gan_type'] == 'wgan-gp': 169 | batch_size = self.var_H.size(0) 170 | if self.random_pt.size(0) != batch_size: 171 | self.random_pt.resize_(batch_size, 1, 1, 1) 172 | self.random_pt.uniform_() # Draw random interpolation points 173 | interp = self.random_pt * self.fake_H.detach() + (1 - self.random_pt) * self.var_H 174 | interp.requires_grad = True 175 | interp_crit, _ = self.netD(interp) 176 | l_d_gp = self.l_gp_w * self.cri_gp(interp, interp_crit) # maybe wrong in cls? 
177 | l_d_total += l_d_gp 178 | 179 | l_d_total.backward() 180 | self.optimizer_D.step() 181 | 182 | # set log 183 | if step % self.D_update_ratio == 0 and step > self.D_init_iters: 184 | # G 185 | if self.cri_pix: 186 | self.log_dict['l_g_pix'] = l_g_pix.item() 187 | if self.cri_fea: 188 | self.log_dict['l_g_fea'] = l_g_fea.item() 189 | self.log_dict['l_g_gan'] = l_g_gan.item() 190 | # D 191 | self.log_dict['l_d_real'] = l_d_real.item() 192 | self.log_dict['l_d_fake'] = l_d_fake.item() 193 | self.log_dict['l_d_cls_real'] = l_d_cls_real.item() 194 | self.log_dict['l_d_cls_fake'] = l_d_cls_fake.item() 195 | if self.opt['train']['gan_type'] == 'wgan-gp': 196 | self.log_dict['l_d_gp'] = l_d_gp.item() 197 | # D outputs 198 | self.log_dict['D_real'] = torch.mean(pred_d_real.detach()) 199 | self.log_dict['D_fake'] = torch.mean(pred_d_fake.detach()) 200 | 201 | def test(self): 202 | self.netG.eval() 203 | with torch.no_grad(): 204 | self.fake_H = self.netG((self.var_L, self.var_seg)) 205 | self.netG.train() 206 | 207 | def get_current_log(self): 208 | return self.log_dict 209 | 210 | def get_current_visuals(self, need_GT=True): 211 | out_dict = OrderedDict() 212 | out_dict['LR'] = self.var_L.detach()[0].float().cpu() 213 | out_dict['SR'] = self.fake_H.detach()[0].float().cpu() 214 | if need_GT: 215 | out_dict['GT'] = self.var_H.detach()[0].float().cpu() 216 | return out_dict 217 | 218 | def print_network(self): 219 | # G 220 | s, n = self.get_network_description(self.netG) 221 | if isinstance(self.netG, nn.DataParallel): 222 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 223 | self.netG.module.__class__.__name__) 224 | else: 225 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 226 | 227 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 228 | logger.info(s) 229 | if self.is_train: 230 | # D 231 | s, n = self.get_network_description(self.netD) 232 | if isinstance(self.netD, nn.DataParallel): 233 | net_struc_str = '{} - {}'.format(self.netD.__class__.__name__, 234 | self.netD.module.__class__.__name__) 235 | else: 236 | net_struc_str = '{}'.format(self.netD.__class__.__name__) 237 | 238 | logger.info('Network D structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 239 | logger.info(s) 240 | 241 | if self.cri_fea: # F, Perceptual Network 242 | s, n = self.get_network_description(self.netF) 243 | if isinstance(self.netF, nn.DataParallel): 244 | net_struc_str = '{} - {}'.format(self.netF.__class__.__name__, 245 | self.netF.module.__class__.__name__) 246 | else: 247 | net_struc_str = '{}'.format(self.netF.__class__.__name__) 248 | 249 | logger.info('Network F structure: {}, with parameters: {:,d}'.format( 250 | net_struc_str, n)) 251 | logger.info(s) 252 | 253 | def load(self): 254 | load_path_G = self.opt['path']['pretrain_model_G'] 255 | if load_path_G is not None: 256 | logger.info('Loading pretrained model for G [{:s}] ...'.format(load_path_G)) 257 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 258 | load_path_D = self.opt['path']['pretrain_model_D'] 259 | if self.opt['is_train'] and load_path_D is not None: 260 | logger.info('Loading pretrained model for D [{:s}] ...'.format(load_path_D)) 261 | self.load_network(load_path_D, self.netD, self.opt['path']['strict_load']) 262 | 263 | def save(self, iter_step): 264 | self.save_network(self.netG, 'G', iter_step) 265 | self.save_network(self.netD, 'D', iter_step) 266 | 
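As a usage note for the model classes above: the sketch below shows how a BasicSR-style training loop might drive them, using only methods defined in this repo (`create_model`, `feed_data`, `optimize_parameters`, `update_learning_rate`, `get_current_log`, `save`). The options dict `opt`, the dataloader, and all interval constants are hypothetical placeholders, not values taken from this repo.

```python
# Minimal training-loop sketch (assumptions: `opt` is a parsed options dict
# and `train_loader` yields batches in the format expected by feed_data).
from models import create_model

model = create_model(opt)                 # e.g. opt['model'] == 'sftgan' -> SFTGAN_ACD_Model
step = 0
for epoch in range(100):                  # arbitrary epoch count
    for batch in train_loader:
        step += 1
        model.update_learning_rate(step)  # scheduler step (with optional warm-up)
        model.feed_data(batch)            # move LR/HR/seg/category tensors to the device
        model.optimize_parameters(step)   # one generator/discriminator update
        if step % 100 == 0:
            print(step, model.get_current_log())
        if step % 5000 == 0:
            model.save(step)              # writes {step}_G.pth / {step}_D.pth
```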
-------------------------------------------------------------------------------- /codes/models/SR_model.py: -------------------------------------------------------------------------------- 1 | import logging 2 | from collections import OrderedDict 3 | 4 | import torch 5 | import torch.nn as nn 6 | from torch.nn.parallel import DataParallel, DistributedDataParallel 7 | import models.networks as networks 8 | import models.lr_scheduler as lr_scheduler 9 | from .base_model import BaseModel 10 | from models.modules.loss import CharbonnierLoss 11 | 12 | logger = logging.getLogger('base') 13 | 14 | 15 | class SRModel(BaseModel): 16 | def __init__(self, opt): 17 | super(SRModel, self).__init__(opt) 18 | 19 | if opt['dist']: 20 | self.rank = torch.distributed.get_rank() 21 | else: 22 | self.rank = -1 # non dist training 23 | train_opt = opt['train'] 24 | 25 | # define network and load pretrained models 26 | self.netG = networks.define_G(opt).to(self.device) 27 | if opt['dist']: 28 | self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) 29 | else: 30 | self.netG = DataParallel(self.netG) 31 | # print network 32 | self.print_network() 33 | self.load() 34 | 35 | if self.is_train: 36 | self.netG.train() 37 | 38 | # loss 39 | loss_type = train_opt['pixel_criterion'] 40 | if loss_type == 'l1': 41 | self.cri_pix = nn.L1Loss().to(self.device) 42 | elif loss_type == 'l2': 43 | self.cri_pix = nn.MSELoss().to(self.device) 44 | elif loss_type == 'cb': 45 | self.cri_pix = CharbonnierLoss().to(self.device) 46 | else: 47 | raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) 48 | self.l_pix_w = train_opt['pixel_weight'] 49 | 50 | # optimizers 51 | wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 52 | optim_params = [] 53 | for k, v in self.netG.named_parameters(): # can optimize for a part of the model 54 | if v.requires_grad: 55 | optim_params.append(v) 56 | else: 57 | if self.rank <= 0: 58 | logger.warning('Params [{:s}] will not optimize.'.format(k)) 59 | self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], 60 | weight_decay=wd_G, 61 | betas=(train_opt['beta1'], train_opt['beta2'])) 62 | self.optimizers.append(self.optimizer_G) 63 | 64 | # schedulers 65 | if train_opt['lr_scheme'] == 'MultiStepLR': 66 | for optimizer in self.optimizers: 67 | self.schedulers.append( 68 | lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], 69 | restarts=train_opt['restarts'], 70 | weights=train_opt['restart_weights'], 71 | gamma=train_opt['lr_gamma'], 72 | clear_state=train_opt['clear_state'])) 73 | elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': 74 | for optimizer in self.optimizers: 75 | self.schedulers.append( 76 | lr_scheduler.CosineAnnealingLR_Restart( 77 | optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], 78 | restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) 79 | else: 80 | raise NotImplementedError('MultiStepLR learning rate scheme is enough.') 81 | 82 | self.log_dict = OrderedDict() 83 | 84 | def feed_data(self, data, need_GT=True): 85 | self.var_L = data['LQ'].to(self.device) # LQ 86 | if need_GT: 87 | self.real_H = data['GT'].to(self.device) # GT 88 | 89 | def optimize_parameters(self, step): 90 | self.optimizer_G.zero_grad() 91 | self.fake_H = self.netG(self.var_L) 92 | l_pix = self.l_pix_w * self.cri_pix(self.fake_H, self.real_H) 93 | l_pix.backward() 94 | self.optimizer_G.step() 95 | 96 | # set log 97 | self.log_dict['l_pix'] = l_pix.item() 98 | 99 
| def test(self): 100 | self.netG.eval() 101 | with torch.no_grad(): 102 | self.fake_H = self.netG(self.var_L) 103 | self.netG.train() 104 | 105 | def test_x8(self): 106 | # from https://github.com/thstkdgus35/EDSR-PyTorch 107 | self.netG.eval() 108 | 109 | def _transform(v, op): 110 | # if self.precision != 'single': v = v.float() 111 | v2np = v.data.cpu().numpy() 112 | if op == 'v': 113 | tfnp = v2np[:, :, :, ::-1].copy() 114 | elif op == 'h': 115 | tfnp = v2np[:, :, ::-1, :].copy() 116 | elif op == 't': 117 | tfnp = v2np.transpose((0, 1, 3, 2)).copy() 118 | 119 | ret = torch.Tensor(tfnp).to(self.device) 120 | # if self.precision == 'half': ret = ret.half() 121 | 122 | return ret 123 | 124 | lr_list = [self.var_L] 125 | for tf in 'v', 'h', 't': 126 | lr_list.extend([_transform(t, tf) for t in lr_list]) 127 | with torch.no_grad(): 128 | sr_list = [self.netG(aug) for aug in lr_list] 129 | for i in range(len(sr_list)): 130 | if i > 3: 131 | sr_list[i] = _transform(sr_list[i], 't') 132 | if i % 4 > 1: 133 | sr_list[i] = _transform(sr_list[i], 'h') 134 | if (i % 4) % 2 == 1: 135 | sr_list[i] = _transform(sr_list[i], 'v') 136 | 137 | output_cat = torch.cat(sr_list, dim=0) 138 | self.fake_H = output_cat.mean(dim=0, keepdim=True) 139 | self.netG.train() 140 | 141 | def get_current_log(self): 142 | return self.log_dict 143 | 144 | def get_current_visuals(self, need_GT=True): 145 | out_dict = OrderedDict() 146 | out_dict['LQ'] = self.var_L.detach()[0].float().cpu() 147 | out_dict['SR'] = self.fake_H.detach()[0].float().cpu() 148 | if need_GT: 149 | out_dict['GT'] = self.real_H.detach()[0].float().cpu() 150 | return out_dict 151 | 152 | def print_network(self): 153 | s, n = self.get_network_description(self.netG) 154 | if isinstance(self.netG, nn.DataParallel) or isinstance(self.netG, DistributedDataParallel): 155 | net_struc_str = '{} - {}'.format(self.netG.__class__.__name__, 156 | self.netG.module.__class__.__name__) 157 | else: 158 | net_struc_str = '{}'.format(self.netG.__class__.__name__) 159 | if self.rank <= 0: 160 | logger.info('Network G structure: {}, with parameters: {:,d}'.format(net_struc_str, n)) 161 | logger.info(s) 162 | 163 | def load(self): 164 | load_path_G = self.opt['path']['pretrain_model_G'] 165 | if load_path_G is not None: 166 | logger.info('Loading model for G [{:s}] ...'.format(load_path_G)) 167 | self.load_network(load_path_G, self.netG, self.opt['path']['strict_load']) 168 | 169 | def save(self, iter_label): 170 | self.save_network(self.netG, 'G', iter_label) 171 | -------------------------------------------------------------------------------- /codes/models/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | logger = logging.getLogger('base') 3 | 4 | 5 | def create_model(opt): 6 | model = opt['model'] 7 | 8 | if model == 'sr': 9 | from .SR_model import SRModel as M 10 | elif model == 'srgan': 11 | from .SRGAN_model import SRGANModel as M 12 | elif model == 'sftgan': 13 | from .SFTGAN_ACD_model import SFTGAN_ACD_Model as M 14 | elif model == 'noisegan': 15 | from .NOISEGAN_model import NOISEGANModel as M 16 | else: 17 | raise NotImplementedError('Model [{:s}] not recognized.'.format(model)) 18 | m = M(opt) 19 | logger.info('Model [{:s}] is created.'.format(m.__class__.__name__)) 20 | return m 21 | -------------------------------------------------------------------------------- /codes/models/base_model.py: -------------------------------------------------------------------------------- 1 | import os 2 | 
from collections import OrderedDict 3 | import torch 4 | import torch.nn as nn 5 | from torch.nn.parallel import DistributedDataParallel 6 | 7 | 8 | class BaseModel(): 9 | def __init__(self, opt): 10 | self.opt = opt 11 | self.device = torch.device('cuda' if opt['gpu_ids'] is not None else 'cpu') 12 | self.is_train = opt['is_train'] 13 | self.schedulers = [] 14 | self.optimizers = [] 15 | 16 | def feed_data(self, data): 17 | pass 18 | 19 | def optimize_parameters(self): 20 | pass 21 | 22 | def get_current_visuals(self): 23 | pass 24 | 25 | def get_current_losses(self): 26 | pass 27 | 28 | def print_network(self): 29 | pass 30 | 31 | def save(self, label): 32 | pass 33 | 34 | def load(self): 35 | pass 36 | 37 | def _set_lr(self, lr_groups_l): 38 | '''Set the learning rate for warm-up. 39 | lr_groups_l: list of lr groups, one for each optimizer''' 40 | for optimizer, lr_groups in zip(self.optimizers, lr_groups_l): 41 | for param_group, lr in zip(optimizer.param_groups, lr_groups): 42 | param_group['lr'] = lr 43 | 44 | def _get_init_lr(self): 45 | # get the initial lr, which is set by the scheduler 46 | init_lr_groups_l = [] 47 | for optimizer in self.optimizers: 48 | init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups]) 49 | return init_lr_groups_l 50 | 51 | def update_learning_rate(self, cur_iter, warmup_iter=-1): 52 | for scheduler in self.schedulers: 53 | scheduler.step() 54 | #### set up warm-up learning rate 55 | if cur_iter < warmup_iter: 56 | # get initial lr for each group 57 | init_lr_g_l = self._get_init_lr() 58 | # modify warming-up learning rates 59 | warm_up_lr_l = [] 60 | for init_lr_g in init_lr_g_l: 61 | warm_up_lr_l.append([v / warmup_iter * cur_iter for v in init_lr_g]) 62 | # set learning rate 63 | self._set_lr(warm_up_lr_l) 64 | 65 | def get_current_learning_rate(self): 66 | # return self.schedulers[0].get_lr()[0] 67 | return self.optimizers[0].param_groups[0]['lr'] 68 | 69 | def get_network_description(self, network): 70 | '''Get the string and total parameters of the network''' 71 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 72 | network = network.module 73 | s = str(network) 74 | n = sum(map(lambda x: x.numel(), network.parameters())) 75 | return s, n 76 | 77 | def save_network(self, network, network_label, iter_label): 78 | save_filename = '{}_{}.pth'.format(iter_label, network_label) 79 | save_path = os.path.join(self.opt['path']['models'], save_filename) 80 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 81 | network = network.module 82 | state_dict = network.state_dict() 83 | for key, param in state_dict.items(): 84 | state_dict[key] = param.cpu() 85 | torch.save(state_dict, save_path) 86 | 87 | def load_network(self, load_path, network, strict=True): 88 | if isinstance(network, nn.DataParallel) or isinstance(network, DistributedDataParallel): 89 | network = network.module 90 | load_net = torch.load(load_path) 91 | load_net_clean = OrderedDict() # remove unnecessary 'module.'
92 | for k, v in load_net.items(): 93 | if k.startswith('module.'): 94 | load_net_clean[k[7:]] = v 95 | else: 96 | load_net_clean[k] = v 97 | network.load_state_dict(load_net_clean, strict=strict) 98 | 99 | def save_training_state(self, epoch, iter_step): 100 | '''Saves training state during training, which will be used for resuming''' 101 | state = {'epoch': epoch, 'iter': iter_step, 'schedulers': [], 'optimizers': []} 102 | for s in self.schedulers: 103 | state['schedulers'].append(s.state_dict()) 104 | for o in self.optimizers: 105 | state['optimizers'].append(o.state_dict()) 106 | save_filename = '{}.state'.format(iter_step) 107 | save_path = os.path.join(self.opt['path']['training_state'], save_filename) 108 | torch.save(state, save_path) 109 | 110 | def resume_training(self, resume_state): 111 | '''Resume the optimizers and schedulers for training''' 112 | resume_optimizers = resume_state['optimizers'] 113 | resume_schedulers = resume_state['schedulers'] 114 | assert len(resume_optimizers) == len(self.optimizers), 'Wrong lengths of optimizers' 115 | assert len(resume_schedulers) == len(self.schedulers), 'Wrong lengths of schedulers' 116 | for i, o in enumerate(resume_optimizers): 117 | self.optimizers[i].load_state_dict(o) 118 | for i, s in enumerate(resume_schedulers): 119 | self.schedulers[i].load_state_dict(s) 120 | 121 | -------------------------------------------------------------------------------- /codes/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from collections import defaultdict 4 | import torch 5 | from torch.optim.lr_scheduler import _LRScheduler 6 | 7 | 8 | class MultiStepLR_Restart(_LRScheduler): 9 | def __init__(self, optimizer, milestones, restarts=None, weights=None, gamma=0.1, 10 | clear_state=False, last_epoch=-1): 11 | self.milestones = Counter(milestones) 12 | self.gamma = gamma 13 | self.clear_state = clear_state 14 | self.restarts = restarts if restarts else [0] 15 | self.restart_weights = weights if weights else [1] 16 | assert len(self.restarts) == len( 17 | self.restart_weights), 'restarts and their weights do not match.' 18 | super(MultiStepLR_Restart, self).__init__(optimizer, last_epoch) 19 | 20 | def get_lr(self): 21 | if self.last_epoch in self.restarts: 22 | if self.clear_state: 23 | self.optimizer.state = defaultdict(dict) 24 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 25 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 26 | if self.last_epoch not in self.milestones: 27 | return [group['lr'] for group in self.optimizer.param_groups] 28 | return [ 29 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 30 | for group in self.optimizer.param_groups 31 | ] 32 | 33 | 34 | class CosineAnnealingLR_Restart(_LRScheduler): 35 | def __init__(self, optimizer, T_period, restarts=None, weights=None, eta_min=0, last_epoch=-1): 36 | self.T_period = T_period 37 | self.T_max = self.T_period[0] # current T period 38 | self.eta_min = eta_min 39 | self.restarts = restarts if restarts else [0] 40 | self.restart_weights = weights if weights else [1] 41 | self.last_restart = 0 42 | assert len(self.restarts) == len( 43 | self.restart_weights), 'restarts and their weights do not match.' 
44 | super(CosineAnnealingLR_Restart, self).__init__(optimizer, last_epoch) 45 | 46 | def get_lr(self): 47 | if self.last_epoch == 0: 48 | return self.base_lrs 49 | elif self.last_epoch in self.restarts: 50 | self.last_restart = self.last_epoch 51 | self.T_max = self.T_period[self.restarts.index(self.last_epoch) + 1] 52 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 53 | return [group['initial_lr'] * weight for group in self.optimizer.param_groups] 54 | elif (self.last_epoch - self.last_restart - 1 - self.T_max) % (2 * self.T_max) == 0: 55 | return [ 56 | group['lr'] + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2 57 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 58 | ] 59 | return [(1 + math.cos(math.pi * (self.last_epoch - self.last_restart) / self.T_max)) / 60 | (1 + math.cos(math.pi * ((self.last_epoch - self.last_restart) - 1) / self.T_max)) * 61 | (group['lr'] - self.eta_min) + self.eta_min 62 | for group in self.optimizer.param_groups] 63 | 64 | 65 | if __name__ == "__main__": 66 | optimizer = torch.optim.Adam([torch.zeros(3, 64, 3, 3)], lr=2e-4, weight_decay=0, 67 | betas=(0.9, 0.99)) 68 | ############################## 69 | # MultiStepLR_Restart 70 | ############################## 71 | ## Original 72 | lr_steps = [200000, 400000, 600000, 800000] 73 | restarts = None 74 | restart_weights = None 75 | 76 | ## two 77 | lr_steps = [100000, 200000, 300000, 400000, 490000, 600000, 700000, 800000, 900000, 990000] 78 | restarts = [500000] 79 | restart_weights = [1] 80 | 81 | ## four 82 | lr_steps = [ 83 | 50000, 100000, 150000, 200000, 240000, 300000, 350000, 400000, 450000, 490000, 550000, 84 | 600000, 650000, 700000, 740000, 800000, 850000, 900000, 950000, 990000 85 | ] 86 | restarts = [250000, 500000, 750000] 87 | restart_weights = [1, 1, 1] 88 | 89 | scheduler = MultiStepLR_Restart(optimizer, lr_steps, restarts, restart_weights, gamma=0.5, 90 | clear_state=False) 91 | 92 | ############################## 93 | # Cosine Annealing Restart 94 | ############################## 95 | ## two 96 | T_period = [500000, 500000] 97 | restarts = [500000] 98 | restart_weights = [1] 99 | 100 | ## four 101 | T_period = [250000, 250000, 250000, 250000] 102 | restarts = [250000, 500000, 750000] 103 | restart_weights = [1, 1, 1] 104 | 105 | scheduler = CosineAnnealingLR_Restart(optimizer, T_period, eta_min=1e-7, restarts=restarts, 106 | weights=restart_weights) 107 | 108 | ############################## 109 | # Draw figure 110 | ############################## 111 | N_iter = 1000000 112 | lr_l = list(range(N_iter)) 113 | for i in range(N_iter): 114 | scheduler.step() 115 | current_lr = optimizer.param_groups[0]['lr'] 116 | lr_l[i] = current_lr 117 | 118 | import matplotlib as mpl 119 | from matplotlib import pyplot as plt 120 | import matplotlib.ticker as mtick 121 | mpl.style.use('default') 122 | import seaborn 123 | seaborn.set(style='whitegrid') 124 | seaborn.set_context('paper') 125 | 126 | plt.figure(1) 127 | plt.subplot(111) 128 | plt.ticklabel_format(style='sci', axis='x', scilimits=(0, 0)) 129 | plt.title('Title', fontsize=16, color='k') 130 | plt.plot(list(range(N_iter)), lr_l, linewidth=1.5, label='learning rate scheme') 131 | legend = plt.legend(loc='upper right', shadow=False) 132 | ax = plt.gca() 133 | labels = ax.get_xticks().tolist() 134 | for k, v in enumerate(labels): 135 | labels[k] = str(int(v / 1000)) + 'K' 136 | ax.set_xticklabels(labels) 137 | ax.yaxis.set_major_formatter(mtick.FormatStrFormatter('%.1e')) 138 | 139 | 
ax.set_ylabel('Learning rate') 140 | ax.set_xlabel('Iteration') 141 | fig = plt.gcf() 142 | plt.show() 143 | -------------------------------------------------------------------------------- /codes/models/modules/RRDBNet_arch.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | import models.modules.module_util as mutil 6 | 7 | 8 | class ResidualDenseBlock_5C(nn.Module): 9 | def __init__(self, nf=64, gc=32, bias=True): 10 | super(ResidualDenseBlock_5C, self).__init__() 11 | # gc: growth channel, i.e. intermediate channels 12 | self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias) 13 | self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias) 14 | self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias) 15 | self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias) 16 | self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias) 17 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=False) 18 | 19 | # initialization 20 | mutil.initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1) 21 | 22 | def forward(self, x): 23 | x1 = self.lrelu(self.conv1(x)) 24 | x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1))) 25 | x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1))) 26 | x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1))) 27 | x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1)) 28 | return x5 * 0.2 + x 29 | 30 | 31 | class RRDB(nn.Module): 32 | '''Residual in Residual Dense Block''' 33 | 34 | def __init__(self, nf, gc=32): 35 | super(RRDB, self).__init__() 36 | self.RDB1 = ResidualDenseBlock_5C(nf, gc) 37 | self.RDB2 = ResidualDenseBlock_5C(nf, gc) 38 | self.RDB3 = ResidualDenseBlock_5C(nf, gc) 39 | 40 | def forward(self, x): 41 | out = self.RDB1(x) 42 | out = self.RDB2(out) 43 | out = self.RDB3(out) 44 | return out * 0.2 + x 45 | 46 | 47 | class RRDBNet(nn.Module): 48 | def __init__(self, in_nc, out_nc, nf, nb, gc=32): 49 | super(RRDBNet, self).__init__() 50 | RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc) 51 | 52 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 53 | self.RRDB_trunk = mutil.make_layer(RRDB_block_f, nb) 54 | self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 55 | #### upsampling 56 | self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 57 | self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 58 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 59 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 60 | 61 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=False) 62 | 63 | def forward(self, x): 64 | fea = self.conv_first(x) 65 | trunk = self.trunk_conv(self.RRDB_trunk(fea)) 66 | fea = fea + trunk 67 | 68 | fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest'))) 69 | fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest'))) 70 | out = self.conv_last(self.lrelu(self.HRconv(fea))) 71 | 72 | return out 73 | -------------------------------------------------------------------------------- /codes/models/modules/SRResNet_arch.py: -------------------------------------------------------------------------------- 1 | import functools 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import models.modules.module_util as mutil 5 | 6 | 7 | class MSRResNet(nn.Module): 8 | ''' modified SRResNet''' 9 | 10 | def __init__(self, in_nc=3, out_nc=3, nf=64, nb=16, upscale=4): 11 | super(MSRResNet, self).__init__() 12 | 
self.upscale = upscale 13 | 14 | self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 15 | basic_block = functools.partial(mutil.ResidualBlock_noBN, nf=nf) 16 | self.recon_trunk = mutil.make_layer(basic_block, nb) 17 | 18 | # upsampling 19 | if self.upscale == 2: 20 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 21 | self.pixel_shuffle = nn.PixelShuffle(2) 22 | elif self.upscale == 3: 23 | self.upconv1 = nn.Conv2d(nf, nf * 9, 3, 1, 1, bias=True) 24 | self.pixel_shuffle = nn.PixelShuffle(3) 25 | elif self.upscale == 4: 26 | self.upconv1 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 27 | self.upconv2 = nn.Conv2d(nf, nf * 4, 3, 1, 1, bias=True) 28 | self.pixel_shuffle = nn.PixelShuffle(2) 29 | 30 | self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 31 | self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True) 32 | 33 | # activation function 34 | self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True) 35 | 36 | # initialization 37 | mutil.initialize_weights([self.conv_first, self.upconv1, self.HRconv, self.conv_last], 0.1) 38 | if self.upscale == 4: 39 | mutil.initialize_weights(self.upconv2, 0.1) 40 | 41 | def forward(self, x): 42 | fea = self.lrelu(self.conv_first(x)) 43 | out = self.recon_trunk(fea) 44 | 45 | if self.upscale == 4: 46 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 47 | out = self.lrelu(self.pixel_shuffle(self.upconv2(out))) 48 | elif self.upscale == 3 or self.upscale == 2: 49 | out = self.lrelu(self.pixel_shuffle(self.upconv1(out))) 50 | 51 | out = self.conv_last(self.lrelu(self.HRconv(out))) 52 | base = F.interpolate(x, scale_factor=self.upscale, mode='bilinear', align_corners=False) 53 | out += base 54 | return out 55 | -------------------------------------------------------------------------------- /codes/models/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/codes/models/modules/__init__.py -------------------------------------------------------------------------------- /codes/models/modules/__pycache__/RRDBNet_arch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/codes/models/modules/__pycache__/RRDBNet_arch.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/modules/__pycache__/SRResNet_arch.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/codes/models/modules/__pycache__/SRResNet_arch.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/modules/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/codes/models/modules/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/modules/__pycache__/discriminator_vgg_arch.cpython-36.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/codes/models/modules/__pycache__/discriminator_vgg_arch.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/modules/__pycache__/loss.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/codes/models/modules/__pycache__/loss.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/modules/__pycache__/module_util.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/codes/models/modules/__pycache__/module_util.cpython-36.pyc -------------------------------------------------------------------------------- /codes/models/modules/discriminator_vgg_arch.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision 4 | import functools 5 | 6 | class NLayerDiscriminator(nn.Module): 7 | """Defines a PatchGAN discriminator""" 8 | 9 | def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d): 10 | """Construct a PatchGAN discriminator 11 | 12 | Parameters: 13 | input_nc (int) -- the number of channels in input images 14 | ndf (int) -- the number of filters in the last conv layer 15 | n_layers (int) -- the number of conv layers in the discriminator 16 | norm_layer -- normalization layer 17 | """ 18 | super(NLayerDiscriminator, self).__init__() 19 | if type(norm_layer) == functools.partial: # no need to use bias as BatchNorm2d has affine parameters 20 | use_bias = norm_layer.func == nn.InstanceNorm2d 21 | else: 22 | use_bias = norm_layer == nn.InstanceNorm2d 23 | # use_bias = False 24 | kw = 4 25 | padw = 1 26 | sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, False)] 27 | nf_mult = 1 28 | nf_mult_prev = 1 29 | for n in range(1, n_layers): # gradually increase the number of filters 30 | nf_mult_prev = nf_mult 31 | nf_mult = min(2 ** n, 8) 32 | sequence += [ 33 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias), 34 | norm_layer(ndf * nf_mult), 35 | nn.LeakyReLU(0.2, False) 36 | ] 37 | 38 | nf_mult_prev = nf_mult 39 | nf_mult = min(2 ** n_layers, 8) 40 | sequence += [ 41 | nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias), 42 | norm_layer(ndf * nf_mult), 43 | nn.LeakyReLU(0.2, False) 44 | ] 45 | 46 | sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map 47 | # TODO 48 | self.model = nn.Sequential(*sequence) 49 | 50 | def forward(self, x): 51 | """Standard forward.""" 52 | return self.model(x) 53 | 54 | 55 | class Discriminator_VGG_128(nn.Module): 56 | def __init__(self, in_nc, nf): 57 | super(Discriminator_VGG_128, self).__init__() 58 | # [64, 128, 128] 59 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 60 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) 61 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True) 62 | # [64, 64, 64] 63 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) 64 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) 65 | self.conv1_1 = nn.Conv2d(nf * 2, 
nf * 2, 4, 2, 1, bias=False) 66 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) 67 | # [128, 32, 32] 68 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) 69 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) 70 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) 71 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) 72 | # [256, 16, 16] 73 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) 74 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) 75 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 76 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) 77 | # [512, 8, 8] 78 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 79 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) 80 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 81 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) 82 | 83 | self.linear1 = nn.Linear(512 * 4 * 4, 100) 84 | self.linear2 = nn.Linear(100, 1) 85 | 86 | # activation function 87 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=False) 88 | 89 | def forward(self, x): 90 | fea = self.lrelu(self.conv0_0(x)) 91 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) 92 | 93 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) 94 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) 95 | 96 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) 97 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) 98 | 99 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) 100 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) 101 | 102 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) 103 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) 104 | 105 | fea = fea.view(fea.size(0), -1) 106 | fea = self.lrelu(self.linear1(fea)) 107 | out = self.linear2(fea) 108 | return out 109 | 110 | class Discriminator_VGG_256(nn.Module): 111 | def __init__(self, in_nc, nf): 112 | super(Discriminator_VGG_256, self).__init__() 113 | # [64, 128, 128] 114 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 115 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) 116 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True) 117 | # [64, 64, 64] 118 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) 119 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) 120 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) 121 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) 122 | # [128, 32, 32] 123 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) 124 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) 125 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) 126 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) 127 | # [256, 16, 16] 128 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) 129 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) 130 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 131 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) 132 | # [512, 8, 8] 133 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 134 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) 135 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 136 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) 137 | 138 | self.conv5_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 139 | self.bn5_0 = nn.BatchNorm2d(nf * 8, affine=True) 140 | self.conv5_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 141 | self.bn5_1 = nn.BatchNorm2d(nf * 8, affine=True) 142 | 143 | self.linear1 = nn.Linear(512 * 4 * 4, 100) 144 | self.linear2 = nn.Linear(100, 1) 145 | 146 | # activation function 147 | 
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=False) 148 | 149 | def forward(self, x): 150 | fea = self.lrelu(self.conv0_0(x)) 151 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) 152 | 153 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) 154 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) 155 | 156 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) 157 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) 158 | 159 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) 160 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) 161 | 162 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) 163 | fea = self.lrelu(self.bn4_1(self.conv4_1(fea))) 164 | 165 | fea = self.lrelu(self.bn5_0(self.conv5_0(fea))) 166 | fea = self.lrelu(self.bn5_1(self.conv5_1(fea))) 167 | 168 | fea = fea.view(fea.size(0), -1) 169 | fea = self.lrelu(self.linear1(fea)) 170 | out = self.linear2(fea) 171 | return out 172 | 173 | 174 | class Discriminator_VGG_512(nn.Module): 175 | def __init__(self, in_nc, nf): 176 | super(Discriminator_VGG_512, self).__init__() 177 | # [64, 128, 128] 178 | self.conv0_0 = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True) 179 | self.conv0_1 = nn.Conv2d(nf, nf, 4, 2, 1, bias=False) 180 | self.bn0_1 = nn.BatchNorm2d(nf, affine=True) 181 | # [64, 64, 64] 182 | self.conv1_0 = nn.Conv2d(nf, nf * 2, 3, 1, 1, bias=False) 183 | self.bn1_0 = nn.BatchNorm2d(nf * 2, affine=True) 184 | self.conv1_1 = nn.Conv2d(nf * 2, nf * 2, 4, 2, 1, bias=False) 185 | self.bn1_1 = nn.BatchNorm2d(nf * 2, affine=True) 186 | # [128, 32, 32] 187 | self.conv2_0 = nn.Conv2d(nf * 2, nf * 4, 3, 1, 1, bias=False) 188 | self.bn2_0 = nn.BatchNorm2d(nf * 4, affine=True) 189 | self.conv2_1 = nn.Conv2d(nf * 4, nf * 4, 4, 2, 1, bias=False) 190 | self.bn2_1 = nn.BatchNorm2d(nf * 4, affine=True) 191 | # [256, 16, 16] 192 | self.conv3_0 = nn.Conv2d(nf * 4, nf * 8, 3, 1, 1, bias=False) 193 | self.bn3_0 = nn.BatchNorm2d(nf * 8, affine=True) 194 | self.conv3_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 195 | self.bn3_1 = nn.BatchNorm2d(nf * 8, affine=True) 196 | # [512, 8, 8] 197 | self.conv4_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 198 | self.bn4_0 = nn.BatchNorm2d(nf * 8, affine=True) 199 | self.conv4_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 200 | self.bn4_1 = nn.BatchNorm2d(nf * 8, affine=True) 201 | 202 | self.conv5_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 203 | self.bn5_0 = nn.BatchNorm2d(nf * 8, affine=True) 204 | self.conv5_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 205 | self.bn5_1 = nn.BatchNorm2d(nf * 8, affine=True) 206 | 207 | self.conv6_0 = nn.Conv2d(nf * 8, nf * 8, 3, 1, 1, bias=False) 208 | self.bn6_0 = nn.BatchNorm2d(nf * 8, affine=True) 209 | self.conv6_1 = nn.Conv2d(nf * 8, nf * 8, 4, 2, 1, bias=False) 210 | self.bn6_1 = nn.BatchNorm2d(nf * 8, affine=True) 211 | 212 | self.linear1 = nn.Linear(512 * 4 * 4, 100) 213 | self.linear2 = nn.Linear(100, 1) 214 | 215 | # activation function 216 | self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=False) 217 | 218 | def forward(self, x): 219 | fea = self.lrelu(self.conv0_0(x)) 220 | fea = self.lrelu(self.bn0_1(self.conv0_1(fea))) 221 | 222 | fea = self.lrelu(self.bn1_0(self.conv1_0(fea))) 223 | fea = self.lrelu(self.bn1_1(self.conv1_1(fea))) 224 | 225 | fea = self.lrelu(self.bn2_0(self.conv2_0(fea))) 226 | fea = self.lrelu(self.bn2_1(self.conv2_1(fea))) 227 | 228 | fea = self.lrelu(self.bn3_0(self.conv3_0(fea))) 229 | fea = self.lrelu(self.bn3_1(self.conv3_1(fea))) 230 | 231 | fea = self.lrelu(self.bn4_0(self.conv4_0(fea))) 232 | fea = 
self.lrelu(self.bn4_1(self.conv4_1(fea))) 233 | 234 | fea = self.lrelu(self.bn5_0(self.conv5_0(fea))) 235 | fea = self.lrelu(self.bn5_1(self.conv5_1(fea))) 236 | 237 | fea = self.lrelu(self.bn6_0(self.conv6_0(fea))) 238 | fea = self.lrelu(self.bn6_1(self.conv6_1(fea))) 239 | 240 | fea = fea.view(fea.size(0), -1) 241 | fea = self.lrelu(self.linear1(fea)) 242 | out = self.linear2(fea) 243 | return out 244 | 245 | 246 | class VGGFeatureExtractor(nn.Module): 247 | def __init__(self, feature_layer=34, use_bn=False, use_input_norm=True, 248 | device=torch.device('cpu')): 249 | super(VGGFeatureExtractor, self).__init__() 250 | self.use_input_norm = use_input_norm 251 | if use_bn: 252 | model = torchvision.models.vgg19_bn(pretrained=True) 253 | else: 254 | model = torchvision.models.vgg19(pretrained=True) 255 | if self.use_input_norm: 256 | mean = torch.Tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(device) 257 | # [0.485 - 1, 0.456 - 1, 0.406 - 1] if input in range [-1, 1] 258 | std = torch.Tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(device) 259 | # [0.229 * 2, 0.224 * 2, 0.225 * 2] if input in range [-1, 1] 260 | self.register_buffer('mean', mean) 261 | self.register_buffer('std', std) 262 | self.features = nn.Sequential(*list(model.features.children())[:(feature_layer + 1)]) 263 | # No need to BP to variable 264 | for k, v in self.features.named_parameters(): 265 | v.requires_grad = False 266 | 267 | def forward(self, x): 268 | # Assume input range is [0, 1] 269 | if self.use_input_norm: 270 | x = (x - self.mean) / self.std 271 | output = self.features(x) 272 | return output 273 | -------------------------------------------------------------------------------- /codes/models/modules/loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class CharbonnierLoss(nn.Module): 6 | """Charbonnier Loss (L1)""" 7 | 8 | def __init__(self, eps=1e-6): 9 | super(CharbonnierLoss, self).__init__() 10 | self.eps = eps 11 | 12 | def forward(self, x, y): 13 | diff = x - y 14 | loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 15 | return loss 16 | 17 | 18 | # Define GAN loss: [vanilla | lsgan | wgan-gp] 19 | class GANLoss(nn.Module): 20 | def __init__(self, gan_type, real_label_val=1.0, fake_label_val=0.0): 21 | super(GANLoss, self).__init__() 22 | self.gan_type = gan_type.lower() 23 | self.real_label_val = real_label_val 24 | self.fake_label_val = fake_label_val 25 | 26 | if self.gan_type == 'gan' or self.gan_type == 'ragan': 27 | self.loss = nn.BCEWithLogitsLoss() 28 | elif self.gan_type == 'lsgan': 29 | self.loss = nn.MSELoss() 30 | elif self.gan_type == 'wgan-gp': 31 | 32 | def wgan_loss(input, target): 33 | # target is boolean 34 | return -1 * input.mean() if target else input.mean() 35 | 36 | self.loss = wgan_loss 37 | else: 38 | raise NotImplementedError('GAN type [{:s}] is not found'.format(self.gan_type)) 39 | 40 | def get_target_label(self, input, target_is_real): 41 | if self.gan_type == 'wgan-gp': 42 | return target_is_real 43 | if target_is_real: 44 | return torch.empty_like(input).fill_(self.real_label_val) 45 | else: 46 | return torch.empty_like(input).fill_(self.fake_label_val) 47 | 48 | def forward(self, input, target_is_real): 49 | target_label = self.get_target_label(input, target_is_real) 50 | loss = self.loss(input, target_label) 51 | return loss 52 | 53 | 54 | class GradientPenaltyLoss(nn.Module): 55 | def __init__(self, device=torch.device('cpu')): 56 | super(GradientPenaltyLoss, 
self).__init__() 57 | self.register_buffer('grad_outputs', torch.Tensor()) 58 | self.grad_outputs = self.grad_outputs.to(device) 59 | 60 | def get_grad_outputs(self, input): 61 | if self.grad_outputs.size() != input.size(): 62 | self.grad_outputs.resize_(input.size()).fill_(1.0) 63 | return self.grad_outputs 64 | 65 | def forward(self, interp, interp_crit): 66 | grad_outputs = self.get_grad_outputs(interp_crit) 67 | grad_interp = torch.autograd.grad(outputs=interp_crit, inputs=interp, 68 | grad_outputs=grad_outputs, create_graph=True, 69 | retain_graph=True, only_inputs=True)[0] 70 | grad_interp = grad_interp.view(grad_interp.size(0), -1) 71 | grad_interp_norm = grad_interp.norm(2, dim=1) 72 | 73 | loss = ((grad_interp_norm - 1)**2).mean() 74 | return loss 75 | -------------------------------------------------------------------------------- /codes/models/modules/module_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import torch.nn.functional as F 5 | 6 | 7 | def initialize_weights(net_l, scale=1): 8 | if not isinstance(net_l, list): 9 | net_l = [net_l] 10 | for net in net_l: 11 | for m in net.modules(): 12 | if isinstance(m, nn.Conv2d): 13 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 14 | m.weight.data *= scale # for residual block 15 | if m.bias is not None: 16 | m.bias.data.zero_() 17 | elif isinstance(m, nn.Linear): 18 | init.kaiming_normal_(m.weight, a=0, mode='fan_in') 19 | m.weight.data *= scale 20 | if m.bias is not None: 21 | m.bias.data.zero_() 22 | elif isinstance(m, nn.BatchNorm2d): 23 | init.constant_(m.weight, 1) 24 | init.constant_(m.bias.data, 0.0) 25 | 26 | 27 | def make_layer(block, n_layers): 28 | layers = [] 29 | for _ in range(n_layers): 30 | layers.append(block()) 31 | return nn.Sequential(*layers) 32 | 33 | 34 | class ResidualBlock_noBN(nn.Module): 35 | '''Residual block w/o BN 36 | ---Conv-ReLU-Conv-+- 37 | |________________| 38 | ''' 39 | 40 | def __init__(self, nf=64): 41 | super(ResidualBlock_noBN, self).__init__() 42 | self.conv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 43 | self.conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True) 44 | 45 | # initialization 46 | initialize_weights([self.conv1, self.conv2], 0.1) 47 | 48 | def forward(self, x): 49 | identity = x 50 | out = F.relu(self.conv1(x), inplace=True) 51 | out = self.conv2(out) 52 | return identity + out 53 | 54 | 55 | def flow_warp(x, flow, interp_mode='bilinear', padding_mode='zeros'): 56 | """Warp an image or feature map with optical flow 57 | Args: 58 | x (Tensor): size (N, C, H, W) 59 | flow (Tensor): size (N, H, W, 2), normal value 60 | interp_mode (str): 'nearest' or 'bilinear' 61 | padding_mode (str): 'zeros' or 'border' or 'reflection' 62 | 63 | Returns: 64 | Tensor: warped image or feature map 65 | """ 66 | assert x.size()[-2:] == flow.size()[1:3] 67 | B, C, H, W = x.size() 68 | # mesh grid 69 | grid_y, grid_x = torch.meshgrid(torch.arange(0, H), torch.arange(0, W)) 70 | grid = torch.stack((grid_x, grid_y), 2).float() # W(x), H(y), 2 71 | grid.requires_grad = False 72 | grid = grid.type_as(x) 73 | vgrid = grid + flow 74 | # scale grid to [-1,1] 75 | vgrid_x = 2.0 * vgrid[:, :, :, 0] / max(W - 1, 1) - 1.0 76 | vgrid_y = 2.0 * vgrid[:, :, :, 1] / max(H - 1, 1) - 1.0 77 | vgrid_scaled = torch.stack((vgrid_x, vgrid_y), dim=3) 78 | output = F.grid_sample(x, vgrid_scaled, mode=interp_mode, padding_mode=padding_mode) 79 | return output 80 | 
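# A minimal usage sketch for flow_warp (added for illustration; the tensor shapes and
# the zero-flow input are assumptions, not values used by this repository):
#     import torch
#     x = torch.randn(2, 3, 16, 16)      # (N, C, H, W) image or feature map
#     flow = torch.zeros(2, 16, 16, 2)   # (N, H, W, 2) per-pixel (dx, dy) offsets
#     out = flow_warp(x, flow)           # zero flow is ~identity, up to interpolation
#     assert out.shape == x.shape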
-------------------------------------------------------------------------------- /codes/models/modules/seg_arch.py: -------------------------------------------------------------------------------- 1 | ''' 2 | architecture for segmentation 3 | ''' 4 | import torch.nn as nn 5 | from . import block as B 6 | 7 | 8 | class Res131(nn.Module): 9 | def __init__(self, in_nc, mid_nc, out_nc, dilation=1, stride=1): 10 | super(Res131, self).__init__() 11 | conv0 = B.conv_block(in_nc, mid_nc, 1, 1, 1, 1, False, 'zero', 'batch') 12 | conv1 = B.conv_block(mid_nc, mid_nc, 3, stride, dilation, 1, False, 'zero', 'batch') 13 | conv2 = B.conv_block(mid_nc, out_nc, 1, 1, 1, 1, False, 'zero', 'batch', None) # No ReLU 14 | self.res = B.sequential(conv0, conv1, conv2) 15 | if in_nc == out_nc: 16 | self.has_proj = False 17 | else: 18 | self.has_proj = True 19 | self.proj = B.conv_block(in_nc, out_nc, 1, stride, 1, 1, False, 'zero', 'batch', None) 20 | # No ReLU 21 | 22 | def forward(self, x): 23 | res = self.res(x) 24 | if self.has_proj: 25 | x = self.proj(x) 26 | return nn.functional.relu(x + res, inplace=True) 27 | 28 | 29 | class OutdoorSceneSeg(nn.Module): 30 | def __init__(self): 31 | super(OutdoorSceneSeg, self).__init__() 32 | # conv1 33 | blocks = [] 34 | conv1_1 = B.conv_block(3, 64, 3, 2, 1, 1, False, 'zero', 'batch') # /2 35 | conv1_2 = B.conv_block(64, 64, 3, 1, 1, 1, False, 'zero', 'batch') 36 | conv1_3 = B.conv_block(64, 128, 3, 1, 1, 1, False, 'zero', 'batch') 37 | max_pool = nn.MaxPool2d(3, stride=2, padding=0, ceil_mode=True) # /2 38 | blocks = [conv1_1, conv1_2, conv1_3, max_pool] 39 | # conv2, 3 blocks 40 | blocks.append(Res131(128, 64, 256)) 41 | for i in range(2): 42 | blocks.append(Res131(256, 64, 256)) 43 | # conv3, 4 blocks 44 | blocks.append(Res131(256, 128, 512, 1, 2)) # /2 45 | for i in range(3): 46 | blocks.append(Res131(512, 128, 512)) 47 | # conv4, 23 blocks 48 | blocks.append(Res131(512, 256, 1024, 2)) 49 | for i in range(22): 50 | blocks.append(Res131(1024, 256, 1024, 2)) 51 | # conv5 52 | blocks.append(Res131(1024, 512, 2048, 4)) 53 | blocks.append(Res131(2048, 512, 2048, 4)) 54 | blocks.append(Res131(2048, 512, 2048, 4)) 55 | blocks.append(B.conv_block(2048, 512, 3, 1, 1, 1, False, 'zero', 'batch')) 56 | blocks.append(nn.Dropout(0.1)) 57 | # # conv6 58 | blocks.append(nn.Conv2d(512, 8, 1, 1)) 59 | 60 | self.feature = B.sequential(*blocks) 61 | # deconv 62 | self.deconv = nn.ConvTranspose2d(8, 8, 16, 8, 4, 0, 8, False, 1) 63 | # softmax 64 | self.softmax = nn.Softmax(1) 65 | 66 | def forward(self, x): 67 | x = self.feature(x) 68 | x = self.deconv(x) 69 | x = self.softmax(x) 70 | return x 71 | -------------------------------------------------------------------------------- /codes/models/modules/sft_arch.py: -------------------------------------------------------------------------------- 1 | ''' 2 | architecture for sft 3 | ''' 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | 8 | class SFTLayer(nn.Module): 9 | def __init__(self): 10 | super(SFTLayer, self).__init__() 11 | self.SFT_scale_conv0 = nn.Conv2d(32, 32, 1) 12 | self.SFT_scale_conv1 = nn.Conv2d(32, 64, 1) 13 | self.SFT_shift_conv0 = nn.Conv2d(32, 32, 1) 14 | self.SFT_shift_conv1 = nn.Conv2d(32, 64, 1) 15 | 16 | def forward(self, x): 17 | # x[0]: fea; x[1]: cond 18 | scale = self.SFT_scale_conv1(F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.1, inplace=True)) 19 | shift = self.SFT_shift_conv1(F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.1, inplace=True)) 20 | return x[0] * (scale + 1) + shift 21 | 22 | 23 | 
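# Shape sketch for SFTLayer (illustrative only; the sizes below are assumptions): the
# layer takes a (feature, condition) tuple and applies a learned spatial affine transform:
#     import torch
#     fea = torch.randn(1, 64, 24, 24)    # features, 64 channels as in the convs above
#     cond = torch.randn(1, 32, 24, 24)   # condition maps, 32 channels
#     out = SFTLayer()((fea, cond))       # out = fea * (scale + 1) + shift
#     assert out.shape == fea.shape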
class ResBlock_SFT(nn.Module): 24 | def __init__(self): 25 | super(ResBlock_SFT, self).__init__() 26 | self.sft0 = SFTLayer() 27 | self.conv0 = nn.Conv2d(64, 64, 3, 1, 1) 28 | self.sft1 = SFTLayer() 29 | self.conv1 = nn.Conv2d(64, 64, 3, 1, 1) 30 | 31 | def forward(self, x): 32 | # x[0]: fea; x[1]: cond 33 | fea = self.sft0(x) 34 | fea = F.relu(self.conv0(fea), inplace=True) 35 | fea = self.sft1((fea, x[1])) 36 | fea = self.conv1(fea) 37 | return (x[0] + fea, x[1]) # return a tuple containing features and conditions 38 | 39 | 40 | class SFT_Net(nn.Module): 41 | def __init__(self): 42 | super(SFT_Net, self).__init__() 43 | self.conv0 = nn.Conv2d(3, 64, 3, 1, 1) 44 | 45 | sft_branch = [] 46 | for i in range(16): 47 | sft_branch.append(ResBlock_SFT()) 48 | sft_branch.append(SFTLayer()) 49 | sft_branch.append(nn.Conv2d(64, 64, 3, 1, 1)) 50 | self.sft_branch = nn.Sequential(*sft_branch) 51 | 52 | self.HR_branch = nn.Sequential(nn.Conv2d(64, 256, 3, 1, 53 | 1), nn.PixelShuffle(2), nn.ReLU(True), 54 | nn.Conv2d(64, 256, 3, 1, 1), nn.PixelShuffle(2), 55 | nn.ReLU(True), nn.Conv2d(64, 64, 3, 1, 1), nn.ReLU(True), 56 | nn.Conv2d(64, 3, 3, 1, 1)) 57 | 58 | self.CondNet = nn.Sequential(nn.Conv2d(8, 128, 4, 4), nn.LeakyReLU(0.1, True), 59 | nn.Conv2d(128, 128, 1), nn.LeakyReLU(0.1, True), 60 | nn.Conv2d(128, 128, 1), nn.LeakyReLU(0.1, True), 61 | nn.Conv2d(128, 128, 1), nn.LeakyReLU(0.1, True), 62 | nn.Conv2d(128, 32, 1)) 63 | 64 | def forward(self, x): 65 | # x[0]: img; x[1]: seg 66 | cond = self.CondNet(x[1]) 67 | fea = self.conv0(x[0]) 68 | res = self.sft_branch((fea, cond)) 69 | fea = fea + res 70 | out = self.HR_branch(fea) 71 | return out 72 | 73 | 74 | # Auxiliary Classifier Discriminator 75 | class ACD_VGG_BN_96(nn.Module): 76 | def __init__(self): 77 | super(ACD_VGG_BN_96, self).__init__() 78 | 79 | self.feature = nn.Sequential( 80 | nn.Conv2d(3, 64, 3, 1, 1), 81 | nn.LeakyReLU(0.1, True), 82 | nn.Conv2d(64, 64, 4, 2, 1), 83 | nn.BatchNorm2d(64, affine=True), 84 | nn.LeakyReLU(0.1, True), 85 | nn.Conv2d(64, 128, 3, 1, 1), 86 | nn.BatchNorm2d(128, affine=True), 87 | nn.LeakyReLU(0.1, True), 88 | nn.Conv2d(128, 128, 4, 2, 1), 89 | nn.BatchNorm2d(128, affine=True), 90 | nn.LeakyReLU(0.1, True), 91 | nn.Conv2d(128, 256, 3, 1, 1), 92 | nn.BatchNorm2d(256, affine=True), 93 | nn.LeakyReLU(0.1, True), 94 | nn.Conv2d(256, 256, 4, 2, 1), 95 | nn.BatchNorm2d(256, affine=True), 96 | nn.LeakyReLU(0.1, True), 97 | nn.Conv2d(256, 512, 3, 1, 1), 98 | nn.BatchNorm2d(512, affine=True), 99 | nn.LeakyReLU(0.1, True), 100 | nn.Conv2d(512, 512, 4, 2, 1), 101 | nn.BatchNorm2d(512, affine=True), 102 | nn.LeakyReLU(0.1, True), 103 | ) 104 | 105 | # gan 106 | self.gan = nn.Sequential(nn.Linear(512 * 6 * 6, 100), nn.LeakyReLU(0.1, True), 107 | nn.Linear(100, 1)) 108 | 109 | self.cls = nn.Sequential(nn.Linear(512 * 6 * 6, 100), nn.LeakyReLU(0.1, True), 110 | nn.Linear(100, 8)) 111 | 112 | def forward(self, x): 113 | fea = self.feature(x) 114 | fea = fea.view(fea.size(0), -1) 115 | gan = self.gan(fea) 116 | cls = self.cls(fea) 117 | return [gan, cls] 118 | 119 | 120 | ############################################# 121 | # below is the sft arch for the torch version 122 | ############################################# 123 | 124 | 125 | class SFTLayer_torch(nn.Module): 126 | def __init__(self): 127 | super(SFTLayer_torch, self).__init__() 128 | self.SFT_scale_conv0 = nn.Conv2d(32, 32, 1) 129 | self.SFT_scale_conv1 = nn.Conv2d(32, 64, 1) 130 | self.SFT_shift_conv0 = nn.Conv2d(32, 32, 1) 131 | self.SFT_shift_conv1 = 
nn.Conv2d(32, 64, 1) 132 | 133 | def forward(self, x): 134 | # x[0]: fea; x[1]: cond 135 | scale = self.SFT_scale_conv1(F.leaky_relu(self.SFT_scale_conv0(x[1]), 0.01, inplace=True)) 136 | shift = self.SFT_shift_conv1(F.leaky_relu(self.SFT_shift_conv0(x[1]), 0.01, inplace=True)) 137 | return x[0] * scale + shift 138 | 139 | 140 | class ResBlock_SFT_torch(nn.Module): 141 | def __init__(self): 142 | super(ResBlock_SFT_torch, self).__init__() 143 | self.sft0 = SFTLayer_torch() 144 | self.conv0 = nn.Conv2d(64, 64, 3, 1, 1) 145 | self.sft1 = SFTLayer_torch() 146 | self.conv1 = nn.Conv2d(64, 64, 3, 1, 1) 147 | 148 | def forward(self, x): 149 | # x[0]: fea; x[1]: cond 150 | fea = F.relu(self.sft0(x), inplace=True) 151 | fea = self.conv0(fea) 152 | fea = F.relu(self.sft1((fea, x[1])), inplace=True) 153 | fea = self.conv1(fea) 154 | return (x[0] + fea, x[1]) # return a tuple containing features and conditions 155 | 156 | 157 | class SFT_Net_torch(nn.Module): 158 | def __init__(self): 159 | super(SFT_Net_torch, self).__init__() 160 | self.conv0 = nn.Conv2d(3, 64, 3, 1, 1) 161 | 162 | sft_branch = [] 163 | for i in range(16): 164 | sft_branch.append(ResBlock_SFT_torch()) 165 | sft_branch.append(SFTLayer_torch()) 166 | sft_branch.append(nn.Conv2d(64, 64, 3, 1, 1)) 167 | self.sft_branch = nn.Sequential(*sft_branch) 168 | 169 | self.HR_branch = nn.Sequential(nn.Upsample(scale_factor=2, mode='nearest'), 170 | nn.Conv2d(64, 64, 3, 1, 1), nn.ReLU(True), 171 | nn.Upsample(scale_factor=2, mode='nearest'), 172 | nn.Conv2d(64, 64, 3, 1, 1), nn.ReLU(True), 173 | nn.Conv2d(64, 64, 3, 1, 1), nn.ReLU(True), 174 | nn.Conv2d(64, 3, 3, 1, 1)) 175 | 176 | # Condition network 177 | self.CondNet = nn.Sequential(nn.Conv2d(8, 128, 4, 4), nn.LeakyReLU(0.1, True), 178 | nn.Conv2d(128, 128, 1), nn.LeakyReLU(0.1, True), 179 | nn.Conv2d(128, 128, 1), nn.LeakyReLU(0.1, True), 180 | nn.Conv2d(128, 128, 1), nn.LeakyReLU(0.1, True), 181 | nn.Conv2d(128, 32, 1)) 182 | 183 | def forward(self, x): 184 | # x[0]: img; x[1]: seg 185 | cond = self.CondNet(x[1]) 186 | fea = self.conv0(x[0]) 187 | res = self.sft_branch((fea, cond)) 188 | fea = fea + res 189 | out = self.HR_branch(fea) 190 | return out 191 | -------------------------------------------------------------------------------- /codes/models/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import logging 4 | import models.modules.SRResNet_arch as SRResNet_arch 5 | import models.modules.discriminator_vgg_arch as SRGAN_arch 6 | import models.modules.RRDBNet_arch as RRDBNet_arch 7 | logger = logging.getLogger('base') 8 | 9 | 10 | class ResidualBlock(nn.Module): 11 | def __init__(self, channels): 12 | super(ResidualBlock, self).__init__() 13 | self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) 14 | self.prelu = nn.PReLU() 15 | self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) 16 | 17 | def forward(self, x): 18 | residual = self.conv1(x) 19 | residual = self.prelu(residual) 20 | residual = self.conv2(residual) 21 | return x + residual 22 | 23 | 24 | 25 | class Generator(nn.Module): 26 | def __init__(self, n_res_blocks=8): 27 | super(Generator, self).__init__() 28 | self.block_input = nn.Sequential( 29 | nn.Conv2d(3, 64, kernel_size=3, padding=1), 30 | nn.PReLU() 31 | ) 32 | self.res_blocks = nn.ModuleList([ResidualBlock(64) for _ in range(n_res_blocks)]) 33 | self.block_output = nn.Conv2d(64, 3, kernel_size=3, padding=1) 34 | # for k, v in
self.features.named_parameters(): 35 | # v.requires_grad = False 36 | # self.high_pass = FilterHigh(kernel_size=5) 37 | # self.noise_level = 1 38 | 39 | def forward(self, x, z): 40 | # noise_map = self.high_pass(noise_img) 41 | # concat_input = torch.cat([x, noise_map], dim=1) 42 | z = z.expand(x.shape) 43 | block = self.block_input(z) 44 | for res_block in self.res_blocks: 45 | block = res_block(block) 46 | noise = self.block_output(block) 47 | # out = torch.tanh(block) * self.noise_level + x 48 | # noise = torch.sigmoid(block) 49 | return torch.clamp(x + noise, 0, 1), noise 50 | 51 | #################### 52 | # define network 53 | #################### 54 | #### Generator 55 | def define_G(opt): 56 | opt_net = opt['network_G'] 57 | which_model = opt_net['which_model_G'] 58 | 59 | if which_model == 'MSRResNet': 60 | netG = SRResNet_arch.MSRResNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], 61 | nf=opt_net['nf'], nb=opt_net['nb'], upscale=opt_net['scale']) 62 | elif which_model == 'RRDBNet': 63 | netG = RRDBNet_arch.RRDBNet(in_nc=opt_net['in_nc'], out_nc=opt_net['out_nc'], 64 | nf=opt_net['nf'], nb=opt_net['nb']) 65 | # elif which_model == 'sft_arch': # SFT-GAN 66 | # netG = sft_arch.SFT_Net() 67 | else: 68 | raise NotImplementedError('Generator model [{:s}] not recognized'.format(which_model)) 69 | return netG 70 | 71 | 72 | #### Discriminator 73 | def define_D(opt): 74 | opt_net = opt['network_D'] 75 | which_model = opt_net['which_model_D'] 76 | 77 | if which_model == 'discriminator_vgg_128': 78 | netD = SRGAN_arch.Discriminator_VGG_128(in_nc=opt_net['in_nc'], nf=opt_net['nf']) 79 | elif which_model == 'discriminator_vgg_256': 80 | netD = SRGAN_arch.Discriminator_VGG_256(in_nc=opt_net['in_nc'], nf=opt_net['nf']) 81 | elif which_model == 'discriminator_vgg_512': 82 | netD = SRGAN_arch.Discriminator_VGG_512(in_nc=opt_net['in_nc'], nf=opt_net['nf']) 83 | elif which_model == 'NLayerDiscriminator': 84 | if opt_net['norm_layer'] == 'batchnorm': 85 | norm_layer = nn.BatchNorm2d 86 | elif opt_net['norm_layer'] == 'instancenorm': 87 | norm_layer = nn.InstanceNorm2d 88 | netD = SRGAN_arch.NLayerDiscriminator(input_nc=opt_net['in_nc'], ndf=opt_net['nf'], n_layers=opt_net['nlayer'], norm_layer=norm_layer) 89 | else: 90 | raise NotImplementedError('Discriminator model [{:s}] not recognized'.format(which_model)) 91 | return netD 92 | 93 | 94 | #### Define Network used for Perceptual Loss 95 | def define_F(opt, use_bn=False): 96 | gpu_ids = opt['gpu_ids'] 97 | device = torch.device('cuda' if gpu_ids else 'cpu') 98 | # PyTorch pretrained VGG19-54, before ReLU. 99 | if use_bn: 100 | feature_layer = 49 101 | else: 102 | feature_layer = 34 103 | netF = SRGAN_arch.VGGFeatureExtractor(feature_layer=feature_layer, use_bn=use_bn, 104 | use_input_norm=True, device=device) 105 | netF.eval() # No need to train 106 | return netF 107 | -------------------------------------------------------------------------------- /codes/options/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/codes/options/__init__.py -------------------------------------------------------------------------------- /codes/options/df2k/test_df2k.yml: -------------------------------------------------------------------------------- 1 | name: Track1 2 | suffix: ~ # add suffix to saved images 3 | model: srgan 4 | distortion: sr 5 | scale: 4 6 | crop_border: ~ # crop border when evaluation. 
If None(~), crop the scale pixels 7 | gpu_ids: [0] 8 | 9 | datasets: 10 | test_1: # the 1st test dataset 11 | name: DIV2K 12 | mode: LR 13 | dataroot_LR: /mnt/data/NTIRE2020/realSR/track1/Corrupted-te-x 14 | 15 | #### network structures 16 | network_G: 17 | which_model_G: RRDBNet 18 | in_nc: 3 19 | out_nc: 3 20 | nf: 64 21 | nb: 23 22 | upscale: 4 23 | 24 | #### path 25 | path: 26 | pretrain_model_G: pretrained_model/origin/DF2K.pth 27 | results_root: ./results/ -------------------------------------------------------------------------------- /codes/options/dped/test_dped.yml: -------------------------------------------------------------------------------- 1 | name: Track2 2 | suffix: ~ # add suffix to saved images 3 | model: srgan 4 | distortion: sr 5 | scale: 4 6 | crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels 7 | gpu_ids: [0] 8 | 9 | datasets: 10 | test_1: # the 1st test dataset 11 | name: DPED 12 | mode: LR 13 | dataroot_LR: /mnt/data/NTIRE2020/realSR/track2/DPEDiphone-crop-te-x 14 | 15 | #### network structures 16 | network_G: 17 | which_model_G: RRDBNet 18 | in_nc: 3 19 | out_nc: 3 20 | nf: 64 21 | nb: 23 22 | upscale: 4 23 | 24 | #### path 25 | path: 26 | pretrain_model_G: pretrained_model/origin/DPED.pth 27 | results_root: ./results/ 28 | 29 | back_projection: False 30 | back_projection_lamda: !!float 0.2 -------------------------------------------------------------------------------- /codes/options/options.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import logging 4 | import yaml 5 | from utils.util import OrderedYaml 6 | Loader, Dumper = OrderedYaml() 7 | 8 | 9 | def parse(opt_path, is_train=True): 10 | with open(opt_path, mode='r') as f: 11 | opt = yaml.load(f, Loader=Loader) 12 | # export CUDA_VISIBLE_DEVICES 13 | #gpu_list = ','.join(str(x) for x in opt['gpu_ids']) 14 | #os.environ['CUDA_VISIBLE_DEVICES'] = gpu_list 15 | #print('export CUDA_VISIBLE_DEVICES=' + gpu_list) 16 | 17 | opt['is_train'] = is_train 18 | if opt['distortion'] == 'sr': 19 | scale = opt['scale'] 20 | 21 | # datasets 22 | for phase, dataset in opt['datasets'].items(): 23 | phase = phase.split('_')[0] 24 | dataset['phase'] = phase 25 | if opt['distortion'] == 'sr': 26 | dataset['scale'] = scale 27 | is_lmdb = False 28 | if dataset.get('dataroot_GT', None) is not None: 29 | dataset['dataroot_GT'] = osp.expanduser(dataset['dataroot_GT']) 30 | if dataset['dataroot_GT'].endswith('lmdb'): 31 | is_lmdb = True 32 | # if dataset.get('dataroot_GT_bg', None) is not None: 33 | # dataset['dataroot_GT_bg'] = osp.expanduser(dataset['dataroot_GT_bg']) 34 | if dataset.get('dataroot_LQ', None) is not None: 35 | dataset['dataroot_LQ'] = osp.expanduser(dataset['dataroot_LQ']) 36 | if dataset['dataroot_LQ'].endswith('lmdb'): 37 | is_lmdb = True 38 | dataset['data_type'] = 'lmdb' if is_lmdb else 'img' 39 | if dataset['mode'].endswith('mc'): # for memcached 40 | dataset['data_type'] = 'mc' 41 | dataset['mode'] = dataset['mode'].replace('_mc', '') 42 | 43 | # path 44 | for key, path in opt['path'].items(): 45 | if path and key in opt['path'] and key != 'strict_load': 46 | opt['path'][key] = osp.expanduser(path) 47 | opt['path']['root'] = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 48 | if is_train: 49 | experiments_root = osp.join(opt['path']['root'], 'experiments', opt['name']) 50 | opt['path']['experiments_root'] = experiments_root 51 | opt['path']['models'] = osp.join(experiments_root, 
'models') 52 | opt['path']['training_state'] = osp.join(experiments_root, 'training_state') 53 | opt['path']['log'] = experiments_root 54 | opt['path']['val_images'] = osp.join(experiments_root, 'val_images') 55 | 56 | # change some options for debug mode 57 | if 'debug' in opt['name']: 58 | opt['train']['val_freq'] = 8 59 | opt['logger']['print_freq'] = 1 60 | opt['logger']['save_checkpoint_freq'] = 8 61 | else: # test 62 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 63 | opt['path']['results_root'] = results_root 64 | opt['path']['log'] = results_root 65 | 66 | # network 67 | if opt['distortion'] == 'sr': 68 | opt['network_G']['scale'] = scale 69 | 70 | return opt 71 | 72 | 73 | def dict2str(opt, indent_l=1): 74 | '''dict to string for logger''' 75 | msg = '' 76 | for k, v in opt.items(): 77 | if isinstance(v, dict): 78 | msg += ' ' * (indent_l * 2) + k + ':[\n' 79 | msg += dict2str(v, indent_l + 1) 80 | msg += ' ' * (indent_l * 2) + ']\n' 81 | else: 82 | msg += ' ' * (indent_l * 2) + k + ': ' + str(v) + '\n' 83 | return msg 84 | 85 | 86 | class NoneDict(dict): 87 | def __missing__(self, key): 88 | return None 89 | 90 | 91 | # convert to NoneDict, which return None for missing key. 92 | def dict_to_nonedict(opt): 93 | if isinstance(opt, dict): 94 | new_opt = dict() 95 | for key, sub_opt in opt.items(): 96 | new_opt[key] = dict_to_nonedict(sub_opt) 97 | return NoneDict(**new_opt) 98 | elif isinstance(opt, list): 99 | return [dict_to_nonedict(sub_opt) for sub_opt in opt] 100 | else: 101 | return opt 102 | 103 | 104 | def check_resume(opt, resume_iter): 105 | '''Check resume states and pretrain_model paths''' 106 | logger = logging.getLogger('base') 107 | if opt['path']['resume_state']: 108 | if opt['path'].get('pretrain_model_G', None) is not None or opt['path'].get( 109 | 'pretrain_model_D', None) is not None: 110 | logger.warning('pretrain_model path will be ignored when resuming training.') 111 | 112 | opt['path']['pretrain_model_G'] = osp.join(opt['path']['models'], 113 | '{}_G.pth'.format(resume_iter)) 114 | logger.info('Set [pretrain_model_G] to ' + opt['path']['pretrain_model_G']) 115 | if 'gan' in opt['model']: 116 | opt['path']['pretrain_model_D'] = osp.join(opt['path']['models'], 117 | '{}_D.pth'.format(resume_iter)) 118 | logger.info('Set [pretrain_model_D] to ' + opt['path']['pretrain_model_D']) 119 | -------------------------------------------------------------------------------- /codes/scripts/back_projection/backprojection.m: -------------------------------------------------------------------------------- 1 | function [im_h] = backprojection(im_h, im_l, maxIter) 2 | 3 | [row_l, col_l,~] = size(im_l); 4 | [row_h, col_h,~] = size(im_h); 5 | 6 | p = fspecial('gaussian', 5, 1); 7 | p = p.^2; 8 | p = p./sum(p(:)); 9 | 10 | im_l = double(im_l); 11 | im_h = double(im_h); 12 | 13 | for ii = 1:maxIter 14 | im_l_s = imresize(im_h, [row_l, col_l], 'bicubic'); 15 | im_diff = im_l - im_l_s; 16 | im_diff = imresize(im_diff, [row_h, col_h], 'bicubic'); 17 | im_h(:,:,1) = im_h(:,:,1) + conv2(im_diff(:,:,1), p, 'same'); 18 | im_h(:,:,2) = im_h(:,:,2) + conv2(im_diff(:,:,2), p, 'same'); 19 | im_h(:,:,3) = im_h(:,:,3) + conv2(im_diff(:,:,3), p, 'same'); 20 | end 21 | -------------------------------------------------------------------------------- /codes/scripts/back_projection/main_bp.m: -------------------------------------------------------------------------------- 1 | clear; close all; clc; 2 | 3 | LR_folder = './LR'; % LR 4 | preout_folder = './results'; % pre 
output 5 | save_folder = './results_20bp'; 6 | filepaths = dir(fullfile(preout_folder, '*.png')); 7 | max_iter = 20; 8 | 9 | if ~ exist(save_folder, 'dir') 10 | mkdir(save_folder); 11 | end 12 | 13 | for idx_im = 1:length(filepaths) 14 | fprintf([num2str(idx_im) '\n']); 15 | im_name = filepaths(idx_im).name; 16 | im_LR = im2double(imread(fullfile(LR_folder, im_name))); 17 | im_out = im2double(imread(fullfile(preout_folder, im_name))); 18 | %tic 19 | im_out = backprojection(im_out, im_LR, max_iter); 20 | %toc 21 | imwrite(im_out, fullfile(save_folder, im_name)); 22 | end 23 | -------------------------------------------------------------------------------- /codes/scripts/back_projection/main_reverse_filter.m: -------------------------------------------------------------------------------- 1 | clear; close all; clc; 2 | 3 | LR_folder = './LR'; % LR 4 | preout_folder = './results'; % pre output 5 | save_folder = './results_20if'; 6 | filepaths = dir(fullfile(preout_folder, '*.png')); 7 | max_iter = 20; 8 | 9 | if ~ exist(save_folder, 'dir') 10 | mkdir(save_folder); 11 | end 12 | 13 | for idx_im = 1:length(filepaths) 14 | fprintf([num2str(idx_im) '\n']); 15 | im_name = filepaths(idx_im).name; 16 | im_LR = im2double(imread(fullfile(LR_folder, im_name))); 17 | im_out = im2double(imread(fullfile(preout_folder, im_name))); 18 | J = imresize(im_LR,4,'bicubic'); 19 | %tic 20 | for m = 1:max_iter 21 | im_out = im_out + (J - imresize(imresize(im_out,1/4,'bicubic'),4,'bicubic')); 22 | end 23 | %toc 24 | imwrite(im_out, fullfile(save_folder, im_name)); 25 | end 26 | -------------------------------------------------------------------------------- /codes/scripts/color2gray.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path 3 | import sys 4 | from multiprocessing import Pool 5 | import cv2 6 | try: 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | from data.util import bgr2ycbcr 9 | from utils.progress_bar import ProgressBar 10 | except ImportError: 11 | pass 12 | 13 | 14 | def main(): 15 | """A multi-thread tool for converting RGB images to gray/Y images.""" 16 | 17 | input_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800' 18 | save_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_gray' 19 | mode = 'gray' # 'gray' | 'y': Y channel in YCbCr space 20 | compression_level = 3 # 3 is the default value in cv2 21 | # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer 22 | # compression time. If raw images are read during training, use 0 for faster IO speed. 23 | n_thread = 20 # thread number 24 | 25 | if not os.path.exists(save_folder): 26 | os.makedirs(save_folder) 27 | print('mkdir [{:s}] ...'.format(save_folder)) 28 | else: 29 | print('Folder [{:s}] already exists.
Exit...'.format(save_folder)) 30 | sys.exit(1) 31 | # print('Parent process {:d}.'.format(os.getpid())) 32 | 33 | img_list = [] 34 | for root, _, file_list in sorted(os.walk(input_folder)): 35 | path = [os.path.join(root, x) for x in file_list] # assume only images in the input_folder 36 | img_list.extend(path) 37 | 38 | def update(arg): 39 | pbar.update(arg) 40 | 41 | pbar = ProgressBar(len(img_list)) 42 | 43 | pool = Pool(n_thread) 44 | for path in img_list: 45 | pool.apply_async(worker, args=(path, save_folder, mode, compression_level), callback=update) 46 | pool.close() 47 | pool.join() 48 | print('All subprocesses done.') 49 | 50 | 51 | def worker(path, save_folder, mode, compression_level): 52 | img_name = os.path.basename(path) 53 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # BGR 54 | if mode == 'gray': 55 | img_y = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) 56 | else: 57 | img_y = bgr2ycbcr(img, only_y=True) 58 | cv2.imwrite(os.path.join(save_folder, img_name), img_y, 59 | [cv2.IMWRITE_PNG_COMPRESSION, compression_level]) 60 | return 'Processing {:s} ...'.format(img_name) 61 | 62 | 63 | if __name__ == '__main__': 64 | main() 65 | -------------------------------------------------------------------------------- /codes/scripts/create_lmdb.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | import glob 5 | import pickle 6 | import lmdb 7 | import cv2 8 | try: 9 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 10 | from utils.util import ProgressBar 11 | except ImportError: 12 | pass 13 | 14 | # configurations 15 | img_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub/*' # glob matching pattern 16 | lmdb_save_path = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub.lmdb' 17 | meta_info = {'name': 'DIV2K800_sub_GT'} 18 | mode = 2 # 1 for reading all the images to memory and then writing to lmdb (more memory); 19 | # 2 for reading several images and then writing to lmdb, loop over (less memory) 20 | batch = 1000 # Used in mode 2. After batch images, lmdb commits. 21 | ########################################### 22 | if not lmdb_save_path.endswith('.lmdb'): 23 | raise ValueError("lmdb_save_path must end with \'lmdb\'.") 24 | #### whether the lmdb file exist 25 | if osp.exists(lmdb_save_path): 26 | print('Folder [{:s}] already exists. 
Exit...'.format(lmdb_save_path)) 27 | sys.exit(1) 28 | img_list = sorted(glob.glob(img_folder)) 29 | if mode == 1: 30 | print('Read images...') 31 | dataset = [cv2.imread(v, cv2.IMREAD_UNCHANGED) for v in img_list] 32 | data_size = sum([img.nbytes for img in dataset]) 33 | elif mode == 2: 34 | print('Calculating the total size of images...') 35 | data_size = sum(os.stat(v).st_size for v in img_list) 36 | else: 37 | raise ValueError('mode should be 1 or 2') 38 | 39 | key_l = [] 40 | resolution_l = [] 41 | pbar = ProgressBar(len(img_list)) 42 | env = lmdb.open(lmdb_save_path, map_size=data_size * 10) 43 | txn = env.begin(write=True) # txn is a Transaction object 44 | for i, v in enumerate(img_list): 45 | pbar.update('Write {}'.format(v)) 46 | base_name = osp.splitext(osp.basename(v))[0] 47 | key = base_name.encode('ascii') 48 | data = dataset[i] if mode == 1 else cv2.imread(v, cv2.IMREAD_UNCHANGED) 49 | if data.ndim == 2: 50 | H, W = data.shape 51 | C = 1 52 | else: 53 | H, W, C = data.shape 54 | txn.put(key, data) 55 | key_l.append(base_name) 56 | resolution_l.append('{:d}_{:d}_{:d}'.format(C, H, W)) 57 | # commit in mode 2 58 | if mode == 2 and i % batch == 1: 59 | txn.commit() 60 | txn = env.begin(write=True) 61 | 62 | txn.commit() 63 | env.close() 64 | 65 | print('Finish writing lmdb.') 66 | 67 | #### create meta information 68 | # check whether all the images are the same size 69 | same_resolution = (len(set(resolution_l)) <= 1) 70 | if same_resolution: 71 | meta_info['resolution'] = [resolution_l[0]] 72 | meta_info['keys'] = key_l 73 | print('All images have the same resolution. Simplify the meta info...') 74 | else: 75 | meta_info['resolution'] = resolution_l 76 | meta_info['keys'] = key_l 77 | print('Not all images have the same resolution. Save meta info for each image...') 78 | 79 | #### pickle dump 80 | pickle.dump(meta_info, open(osp.join(lmdb_save_path, 'meta_info.pkl'), "wb")) 81 | print('Finish creating lmdb meta info.') 82 | -------------------------------------------------------------------------------- /codes/scripts/extract_subimgs_single.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | import sys 4 | from multiprocessing import Pool 5 | import numpy as np 6 | import cv2 7 | try: 8 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 9 | from utils.util import ProgressBar 10 | except ImportError: 11 | pass 12 | 13 | 14 | def main(): 15 | """A multi-thread tool to crop sub-images.""" 16 | input_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800' 17 | save_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub' 18 | n_thread = 20 19 | crop_sz = 480 20 | step = 240 21 | thres_sz = 48 22 | compression_level = 3 # 3 is the default value in cv2 23 | # CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size and longer 24 | # compression time. If raw images are read during training, use 0 for faster IO speed. 25 | 26 | if not os.path.exists(save_folder): 27 | os.makedirs(save_folder) 28 | print('mkdir [{:s}] ...'.format(save_folder)) 29 | else: 30 | print('Folder [{:s}] already exists.
Exit...'.format(save_folder)) 31 | sys.exit(1) 32 | 33 | img_list = [] 34 | for root, _, file_list in sorted(os.walk(input_folder)): 35 | path = [os.path.join(root, x) for x in file_list] # assume only images in the input_folder 36 | img_list.extend(path) 37 | 38 | def update(arg): 39 | pbar.update(arg) 40 | 41 | pbar = ProgressBar(len(img_list)) 42 | 43 | pool = Pool(n_thread) 44 | for path in img_list: 45 | pool.apply_async(worker, 46 | args=(path, save_folder, crop_sz, step, thres_sz, compression_level), 47 | callback=update) 48 | pool.close() 49 | pool.join() 50 | print('All subprocesses done.') 51 | 52 | 53 | def worker(path, save_folder, crop_sz, step, thres_sz, compression_level): 54 | img_name = os.path.basename(path) 55 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 56 | 57 | n_channels = len(img.shape) 58 | if n_channels == 2: 59 | h, w = img.shape 60 | elif n_channels == 3: 61 | h, w, c = img.shape 62 | else: 63 | raise ValueError('Wrong image shape - {}'.format(n_channels)) 64 | 65 | h_space = np.arange(0, h - crop_sz + 1, step) 66 | if h - (h_space[-1] + crop_sz) > thres_sz: 67 | h_space = np.append(h_space, h - crop_sz) 68 | w_space = np.arange(0, w - crop_sz + 1, step) 69 | if w - (w_space[-1] + crop_sz) > thres_sz: 70 | w_space = np.append(w_space, w - crop_sz) 71 | 72 | index = 0 73 | for x in h_space: 74 | for y in w_space: 75 | index += 1 76 | if n_channels == 2: 77 | crop_img = img[x:x + crop_sz, y:y + crop_sz] 78 | else: 79 | crop_img = img[x:x + crop_sz, y:y + crop_sz, :] 80 | crop_img = np.ascontiguousarray(crop_img) 81 | # var = np.var(crop_img / 255) 82 | # if var > 0.008: 83 | # print(img_name, index_str, var) 84 | cv2.imwrite( 85 | os.path.join(save_folder, img_name.replace('.png', '_s{:03d}.png'.format(index))), 86 | crop_img, [cv2.IMWRITE_PNG_COMPRESSION, compression_level]) 87 | return 'Processing {:s} ...'.format(img_name) 88 | 89 | 90 | if __name__ == '__main__': 91 | main() 92 | -------------------------------------------------------------------------------- /codes/scripts/generate_mod_LR_bic.m: -------------------------------------------------------------------------------- 1 | function generate_mod_LR_bic() 2 | %% matlab code to generate mod images, bicubic-downsampled LR, and bicubic-upsampled images. 3 | 4 | %% set parameters 5 | % comment the unnecessary line 6 | input_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub'; 7 | % save_mod_folder = ''; 8 | save_LR_folder = '/mnt/SSD/xtwang/BasicSR_datasets/DIV2K800/DIV2K800_sub_bicLRx4'; 9 | % save_bic_folder = ''; 10 | 11 | up_scale = 4; 12 | mod_scale = 4; 13 | 14 | if exist('save_mod_folder', 'var') 15 | if exist(save_mod_folder, 'dir') 16 | disp(['It will cover ', save_mod_folder]); 17 | else 18 | mkdir(save_mod_folder); 19 | end 20 | end 21 | if exist('save_LR_folder', 'var') 22 | if exist(save_LR_folder, 'dir') 23 | disp(['It will cover ', save_LR_folder]); 24 | else 25 | mkdir(save_LR_folder); 26 | end 27 | end 28 | if exist('save_bic_folder', 'var') 29 | if exist(save_bic_folder, 'dir') 30 | disp(['It will cover ', save_bic_folder]); 31 | else 32 | mkdir(save_bic_folder); 33 | end 34 | end 35 | 36 | idx = 0; 37 | filepaths = dir(fullfile(input_folder,'*.*')); 38 | for i = 1 : length(filepaths) 39 | [paths,imname,ext] = fileparts(filepaths(i).name); 40 | if isempty(imname) 41 | disp('Ignore . folder.'); 42 | elseif strcmp(imname, '.') 43 | disp('Ignore ..
folder.'); 44 | else 45 | idx = idx + 1; 46 | str_rlt = sprintf('%d\t%s.\n', idx, imname); 47 | fprintf(str_rlt); 48 | % read image 49 | img = imread(fullfile(input_folder, [imname, ext])); 50 | img = im2double(img); 51 | % modcrop 52 | img = modcrop(img, mod_scale); 53 | if exist('save_mod_folder', 'var') 54 | imwrite(img, fullfile(save_mod_folder, [imname, '.png'])); 55 | end 56 | % LR 57 | im_LR = imresize(img, 1/up_scale, 'bicubic'); 58 | if exist('save_LR_folder', 'var') 59 | imwrite(im_LR, fullfile(save_LR_folder, [imname, '_bicLRx4.png'])); 60 | end 61 | % Bicubic 62 | if exist('save_bic_folder', 'var') 63 | im_B = imresize(im_LR, up_scale, 'bicubic'); 64 | imwrite(im_B, fullfile(save_bic_folder, [imname, '_bicx4.png'])); 65 | end 66 | end 67 | end 68 | end 69 | 70 | %% modcrop 71 | function img = modcrop(img, modulo) 72 | if size(img,3) == 1 73 | sz = size(img); 74 | sz = sz - mod(sz, modulo); 75 | img = img(1:sz(1), 1:sz(2)); 76 | else 77 | tmpsz = size(img); 78 | sz = tmpsz(1:2); 79 | sz = sz - mod(sz, modulo); 80 | img = img(1:sz(1), 1:sz(2),:); 81 | end 82 | end 83 | -------------------------------------------------------------------------------- /codes/scripts/generate_mod_LR_bic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import cv2 4 | import numpy as np 5 | 6 | try: 7 | sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) 8 | from data.util import imresize_np 9 | except ImportError: 10 | pass 11 | 12 | 13 | def generate_mod_LR_bic(): 14 | # set parameters 15 | up_scale = 4 16 | mod_scale = 4 17 | # set data dir 18 | sourcedir = '/data/datasets/img' 19 | savedir = '/data/datasets/mod' 20 | 21 | saveHRpath = os.path.join(savedir, 'HR', 'x' + str(mod_scale)) 22 | saveLRpath = os.path.join(savedir, 'LR', 'x' + str(up_scale)) 23 | saveBicpath = os.path.join(savedir, 'Bic', 'x' + str(up_scale)) 24 | 25 | if not os.path.isdir(sourcedir): 26 | print('Error: No source data found') 27 | exit(0) 28 | if not os.path.isdir(savedir): 29 | os.mkdir(savedir) 30 | 31 | if not os.path.isdir(os.path.join(savedir, 'HR')): 32 | os.mkdir(os.path.join(savedir, 'HR')) 33 | if not os.path.isdir(os.path.join(savedir, 'LR')): 34 | os.mkdir(os.path.join(savedir, 'LR')) 35 | if not os.path.isdir(os.path.join(savedir, 'Bic')): 36 | os.mkdir(os.path.join(savedir, 'Bic')) 37 | 38 | if not os.path.isdir(saveHRpath): 39 | os.mkdir(saveHRpath) 40 | else: 41 | print('It will cover ' + str(saveHRpath)) 42 | 43 | if not os.path.isdir(saveLRpath): 44 | os.mkdir(saveLRpath) 45 | else: 46 | print('It will cover ' + str(saveLRpath)) 47 | 48 | if not os.path.isdir(saveBicpath): 49 | os.mkdir(saveBicpath) 50 | else: 51 | print('It will cover ' + str(saveBicpath)) 52 | 53 | filepaths = [f for f in os.listdir(sourcedir) if f.endswith('.png')] 54 | num_files = len(filepaths) 55 | 56 | # prepare data with augmentation 57 | for i in range(num_files): 58 | filename = filepaths[i] 59 | print('No.{} -- Processing {}'.format(i, filename)) 60 | # read image 61 | image = cv2.imread(os.path.join(sourcedir, filename)) 62 | 63 | width = int(np.floor(image.shape[1] / mod_scale)) 64 | height = int(np.floor(image.shape[0] / mod_scale)) 65 | # modcrop 66 | if len(image.shape) == 3: 67 | image_HR = image[0:mod_scale * height, 0:mod_scale * width, :] 68 | else: 69 | image_HR = image[0:mod_scale * height, 0:mod_scale * width] 70 | # LR 71 | image_LR = imresize_np(image_HR, 1 / up_scale, True) 72 | # bic 73 | image_Bic = 
imresize_np(image_LR, up_scale, True) 74 | 75 | cv2.imwrite(os.path.join(saveHRpath, filename), image_HR) 76 | cv2.imwrite(os.path.join(saveLRpath, filename), image_LR) 77 | cv2.imwrite(os.path.join(saveBicpath, filename), image_Bic) 78 | 79 | 80 | if __name__ == "__main__": 81 | generate_mod_LR_bic() 82 | -------------------------------------------------------------------------------- /codes/scripts/transfer_params_MSRResNet.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import sys 3 | import torch 4 | try: 5 | sys.path.append(osp.dirname(osp.dirname(osp.abspath(__file__)))) 6 | import models.modules.SRResNet_arch as SRResNet_arch 7 | except ImportError: 8 | pass 9 | 10 | pretrained_net = torch.load('../../experiments/pretrained_models/MSRResNetx4.pth') 11 | crt_model = SRResNet_arch.MSRResNet(in_nc=3, out_nc=3, nf=64, nb=16, upscale=3) 12 | crt_net = crt_model.state_dict() 13 | 14 | for k, v in crt_net.items(): 15 | if k in pretrained_net and 'upconv1' not in k: 16 | crt_net[k] = pretrained_net[k] 17 | print('replace ... ', k) 18 | 19 | # x4 -> x3 20 | crt_net['upconv1.weight'][0:256, :, :, :] = pretrained_net['upconv1.weight'] / 2 21 | crt_net['upconv1.weight'][256:512, :, :, :] = pretrained_net['upconv1.weight'] / 2 22 | crt_net['upconv1.weight'][512:576, :, :, :] = pretrained_net['upconv1.weight'][0:64, :, :, :] / 2 23 | crt_net['upconv1.bias'][0:256] = pretrained_net['upconv1.bias'] / 2 24 | crt_net['upconv1.bias'][256:512] = pretrained_net['upconv1.bias'] / 2 25 | crt_net['upconv1.bias'][512:576] = pretrained_net['upconv1.bias'][0:64] / 2 26 | 27 | torch.save(crt_net, '../../experiments/pretrained_models/MSRResNetx3_ini.pth') 28 | -------------------------------------------------------------------------------- /codes/scripts/transfer_params_sft.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import init 3 | 4 | pretrained_net = torch.load('../../experiments/pretrained_models/SRGAN_bicx4_noBN_DIV2K.pth') 5 | # should run train debug mode first to get an initial model 6 | crt_net = torch.load('../../experiments/pretrained_models/sft_net_raw.pth') 7 | 8 | for k, v in crt_net.items(): 9 | if 'weight' in k: 10 | print(k, 'weight') 11 | init.kaiming_normal(v, a=0, mode='fan_in') 12 | v *= 0.1 13 | elif 'bias' in k: 14 | print(k, 'bias') 15 | v.fill_(0) 16 | 17 | crt_net['conv0.weight'] = pretrained_net['model.0.weight'] 18 | crt_net['conv0.bias'] = pretrained_net['model.0.bias'] 19 | # residual blocks 20 | for i in range(16): 21 | crt_net['sft_branch.{:d}.conv0.weight'.format(i)] = pretrained_net[ 22 | 'model.1.sub.{:d}.res.0.weight'.format(i)] 23 | crt_net['sft_branch.{:d}.conv0.bias'.format(i)] = pretrained_net[ 24 | 'model.1.sub.{:d}.res.0.bias'.format(i)] 25 | crt_net['sft_branch.{:d}.conv1.weight'.format(i)] = pretrained_net[ 26 | 'model.1.sub.{:d}.res.2.weight'.format(i)] 27 | crt_net['sft_branch.{:d}.conv1.bias'.format(i)] = pretrained_net[ 28 | 'model.1.sub.{:d}.res.2.bias'.format(i)] 29 | 30 | crt_net['sft_branch.17.weight'] = pretrained_net['model.1.sub.16.weight'] 31 | crt_net['sft_branch.17.bias'] = pretrained_net['model.1.sub.16.bias'] 32 | 33 | # HR 34 | crt_net['HR_branch.0.weight'] = pretrained_net['model.2.weight'] 35 | crt_net['HR_branch.0.bias'] = pretrained_net['model.2.bias'] 36 | crt_net['HR_branch.3.weight'] = pretrained_net['model.5.weight'] 37 | crt_net['HR_branch.3.bias'] = pretrained_net['model.5.bias'] 38 | 
crt_net['HR_branch.6.weight'] = pretrained_net['model.8.weight'] 39 | crt_net['HR_branch.6.bias'] = pretrained_net['model.8.bias'] 40 | crt_net['HR_branch.8.weight'] = pretrained_net['model.10.weight'] 41 | crt_net['HR_branch.8.bias'] = pretrained_net['model.10.bias'] 42 | 43 | print('OK. \n Saving model...') 44 | torch.save(crt_net, '../../experiments/pretrained_models/sft_net_ini.pth') 45 | -------------------------------------------------------------------------------- /codes/test.py: -------------------------------------------------------------------------------- 1 | import os.path as osp 2 | import logging 3 | import time 4 | import argparse 5 | from collections import OrderedDict 6 | 7 | import options.options as option 8 | import utils.util as util 9 | from data.util import bgr2ycbcr 10 | from data import create_dataset, create_dataloader 11 | from models import create_model 12 | 13 | #### options 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('-opt', type=str, required=True, help='Path to options YAML file.') 16 | opt = option.parse(parser.parse_args().opt, is_train=False) 17 | opt = option.dict_to_nonedict(opt) 18 | 19 | util.mkdirs( 20 | (path for key, path in opt['path'].items() 21 | if not key == 'experiments_root' and 'pretrain_model' not in key and 'resume' not in key)) 22 | util.setup_logger('base', opt['path']['log'], 'test_' + opt['name'], level=logging.INFO, 23 | screen=True, tofile=True) 24 | logger = logging.getLogger('base') 25 | logger.info(option.dict2str(opt)) 26 | 27 | #### Create test dataset and dataloader 28 | test_loaders = [] 29 | for phase, dataset_opt in sorted(opt['datasets'].items()): 30 | test_set = create_dataset(dataset_opt) 31 | test_loader = create_dataloader(test_set, dataset_opt) 32 | logger.info('Number of test images in [{:s}]: {:d}'.format(dataset_opt['name'], len(test_set))) 33 | test_loaders.append(test_loader) 34 | 35 | model = create_model(opt) 36 | for test_loader in test_loaders: 37 | test_set_name = test_loader.dataset.opt['name'] 38 | logger.info('\nTesting [{:s}]...'.format(test_set_name)) 39 | test_start_time = time.time() 40 | dataset_dir = osp.join(opt['path']['results_root'], test_set_name) 41 | util.mkdir(dataset_dir) 42 | 43 | test_results = OrderedDict() 44 | test_results['psnr'] = [] 45 | test_results['ssim'] = [] 46 | test_results['psnr_y'] = [] 47 | test_results['ssim_y'] = [] 48 | 49 | for data in test_loader: 50 | need_GT = test_loader.dataset.opt['dataroot_GT'] is not None 51 | model.feed_data(data, need_GT=need_GT) 52 | img_path = data['GT_path'][0] if need_GT else data['LQ_path'][0] 53 | img_name = osp.splitext(osp.basename(img_path))[0] 54 | if opt['model'] == 'sr': 55 | model.test_x8() 56 | elif opt['large'] is not None: 57 | model.test_chop() 58 | else: 59 | model.test() 60 | if opt['back_projection'] is not None and opt['back_projection'] is True: 61 | model.back_projection() 62 | visuals = model.get_current_visuals(need_GT=need_GT) 63 | 64 | sr_img = util.tensor2img(visuals['SR']) # uint8 65 | 66 | # save images 67 | suffix = opt['suffix'] 68 | if suffix: 69 | save_img_path = osp.join(dataset_dir, img_name + suffix + '.png') 70 | else: 71 | save_img_path = osp.join(dataset_dir, img_name + '.png') 72 | util.save_img(sr_img, save_img_path) 73 | 74 | # calculate PSNR and SSIM 75 | if need_GT: 76 | gt_img = util.tensor2img(visuals['GT']) 77 | gt_img = gt_img / 255. 78 | sr_img = sr_img / 255.
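# A minimal sketch of the metric convention applied below (assuming the same
# [0, 255] uint8 inputs produced by util.tensor2img): both images are rescaled
# to [0, 1] here, multiplied back by 255 at the metric call sites, and
# `crop_border` (or `scale` when crop_border is None) pixels are trimmed from
# every side before the comparison:
#   mse = np.mean((sr.astype(np.float64) - gt.astype(np.float64)) ** 2)
#   psnr = 20 * math.log10(255.0 / math.sqrt(mse))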
79 | 80 | crop_border = opt['crop_border'] if opt['crop_border'] else opt['scale'] 81 | if crop_border == 0: 82 | cropped_sr_img = sr_img 83 | cropped_gt_img = gt_img 84 | else: 85 | cropped_sr_img = sr_img[crop_border:-crop_border, crop_border:-crop_border, :] 86 | cropped_gt_img = gt_img[crop_border:-crop_border, crop_border:-crop_border, :] 87 | 88 | psnr = util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255) 89 | ssim = util.calculate_ssim(cropped_sr_img * 255, cropped_gt_img * 255) 90 | test_results['psnr'].append(psnr) 91 | test_results['ssim'].append(ssim) 92 | 93 | if gt_img.shape[2] == 3: # RGB image 94 | sr_img_y = bgr2ycbcr(sr_img, only_y=True) 95 | gt_img_y = bgr2ycbcr(gt_img, only_y=True) 96 | if crop_border == 0: 97 | cropped_sr_img_y = sr_img_y 98 | cropped_gt_img_y = gt_img_y 99 | else: 100 | cropped_sr_img_y = sr_img_y[crop_border:-crop_border, crop_border:-crop_border] 101 | cropped_gt_img_y = gt_img_y[crop_border:-crop_border, crop_border:-crop_border] 102 | psnr_y = util.calculate_psnr(cropped_sr_img_y * 255, cropped_gt_img_y * 255) 103 | ssim_y = util.calculate_ssim(cropped_sr_img_y * 255, cropped_gt_img_y * 255) 104 | test_results['psnr_y'].append(psnr_y) 105 | test_results['ssim_y'].append(ssim_y) 106 | logger.info( 107 | '{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}; PSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}.'. 108 | format(img_name, psnr, ssim, psnr_y, ssim_y)) 109 | else: 110 | logger.info('{:20s} - PSNR: {:.6f} dB; SSIM: {:.6f}.'.format(img_name, psnr, ssim)) 111 | else: 112 | logger.info(img_name) 113 | 114 | test_run_time = time.time()-test_start_time 115 | print('Runtime {} (s) per image'.format(test_run_time / len(test_loader))) 116 | 117 | if need_GT: # metrics 118 | # Average PSNR/SSIM results 119 | ave_psnr = sum(test_results['psnr']) / len(test_results['psnr']) 120 | ave_ssim = sum(test_results['ssim']) / len(test_results['ssim']) 121 | logger.info( 122 | '----Average PSNR/SSIM results for {}----\n\tPSNR: {:.6f} dB; SSIM: {:.6f}\n'.format( 123 | test_set_name, ave_psnr, ave_ssim)) 124 | if test_results['psnr_y'] and test_results['ssim_y']: 125 | ave_psnr_y = sum(test_results['psnr_y']) / len(test_results['psnr_y']) 126 | ave_ssim_y = sum(test_results['ssim_y']) / len(test_results['ssim_y']) 127 | logger.info( 128 | '----Y channel, average PSNR/SSIM----\n\tPSNR_Y: {:.6f} dB; SSIM_Y: {:.6f}\n'. 
129 | format(ave_psnr_y, ave_ssim_y)) 130 | -------------------------------------------------------------------------------- /codes/train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import argparse 4 | import random 5 | import logging 6 | import torch 7 | import torch.distributed as dist 8 | import torch.multiprocessing as mp 9 | from data.data_sampler import DistIterSampler 10 | 11 | import options.options as option 12 | from utils import util 13 | from data import create_dataloader, create_dataset 14 | from models import create_model 15 | 16 | 17 | def init_dist(backend='nccl', **kwargs): 18 | ''' initialization for distributed training''' 19 | # if mp.get_start_method(allow_none=True) is None: 20 | if mp.get_start_method(allow_none=True) != 'spawn': 21 | mp.set_start_method('spawn') 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def main(): 29 | #### options 30 | parser = argparse.ArgumentParser() 31 | parser.add_argument('-opt', type=str, help='Path to option YAML file.') 32 | parser.add_argument('--launcher', choices=['none', 'pytorch'], default='none', 33 | help='job launcher') 34 | parser.add_argument('--local_rank', type=int, default=0) 35 | args = parser.parse_args() 36 | opt = option.parse(args.opt, is_train=True) 37 | 38 | #### distributed training settings 39 | if args.launcher == 'none': # disabled distributed training 40 | opt['dist'] = False 41 | rank = -1 42 | print('Disabled distributed training.') 43 | else: 44 | opt['dist'] = True 45 | init_dist() 46 | world_size = torch.distributed.get_world_size() 47 | rank = torch.distributed.get_rank() 48 | 49 | #### loading resume state if exists 50 | if opt['path'].get('resume_state', None): 51 | # distributed resuming: all load into default GPU 52 | device_id = torch.cuda.current_device() 53 | resume_state = torch.load(opt['path']['resume_state'], 54 | map_location=lambda storage, loc: storage.cuda(device_id)) 55 | option.check_resume(opt, resume_state['iter']) # check resume options 56 | else: 57 | resume_state = None 58 | 59 | #### mkdir and loggers 60 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) 61 | if resume_state is None: 62 | util.mkdir_and_rename( 63 | opt['path']['experiments_root']) # rename experiment folder if exists 64 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' 65 | and 'pretrain_model' not in key and 'resume' not in key)) 66 | 67 | # configure loggers. Before this, logging will not work 68 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, 69 | screen=True, tofile=True) 70 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, 71 | screen=True, tofile=True) 72 | logger = logging.getLogger('base') 73 | logger.info(option.dict2str(opt)) 74 | # tensorboard logger 75 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 76 | version = float(torch.__version__[0:3]) 77 | if version >= 1.1: # PyTorch 1.1 78 | from torch.utils.tensorboard import SummaryWriter 79 | else: 80 | logger.info( 81 | 'You are using PyTorch {}.
Tensorboard will use [tensorboardX]'.format(version)) 82 | from tensorboardX import SummaryWriter 83 | tb_logger = SummaryWriter(log_dir='../tb_logger/' + opt['name']) 84 | else: 85 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) 86 | logger = logging.getLogger('base') 87 | 88 | # convert to NoneDict, which returns None for missing keys 89 | opt = option.dict_to_nonedict(opt) 90 | 91 | #### random seed 92 | seed = opt['train']['manual_seed'] 93 | if seed is None: 94 | seed = random.randint(1, 10000) 95 | if rank <= 0: 96 | logger.info('Random seed: {}'.format(seed)) 97 | util.set_random_seed(seed) 98 | 99 | torch.backends.cudnn.benchmark = True 100 | # torch.backends.cudnn.deterministic = True 101 | 102 | #### create train and val dataloader 103 | dataset_ratio = 200 # enlarge the size of each epoch 104 | train_loader = None 105 | val_loader = None 106 | for phase, dataset_opt in opt['datasets'].items(): 107 | if phase == 'train': 108 | # print('\n\n\n\n\n\n\n\n', dataset_opt) 109 | train_set = create_dataset(dataset_opt) 110 | train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) 111 | total_iters = int(opt['train']['niter']) 112 | total_epochs = int(math.ceil(total_iters / train_size)) 113 | if opt['dist']: 114 | train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) 115 | total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) 116 | else: 117 | train_sampler = None 118 | train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) 119 | if rank <= 0: 120 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format( 121 | len(train_set), train_size)) 122 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format( 123 | total_epochs, total_iters)) 124 | elif phase == 'val': 125 | val_set = create_dataset(dataset_opt) 126 | val_loader = create_dataloader(val_set, dataset_opt, opt, None) 127 | if rank <= 0: 128 | logger.info('Number of val images in [{:s}]: {:d}'.format( 129 | dataset_opt['name'], len(val_set))) 130 | else: 131 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) 132 | assert train_loader is not None 133 | 134 | #### create model 135 | model = create_model(opt) 136 | 137 | #### resume training 138 | if resume_state: 139 | logger.info('Resuming training from epoch: {}, iter: {}.'.format( 140 | resume_state['epoch'], resume_state['iter'])) 141 | 142 | start_epoch = resume_state['epoch'] 143 | current_step = resume_state['iter'] 144 | model.resume_training(resume_state) # handle optimizers and schedulers 145 | else: 146 | current_step = 0 147 | start_epoch = 0 148 | 149 | #### training 150 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) 151 | for epoch in range(start_epoch, total_epochs + 1): 152 | if opt['dist']: 153 | train_sampler.set_epoch(epoch) 154 | for _, train_data in enumerate(train_loader): 155 | current_step += 1 156 | if current_step > total_iters: 157 | break 158 | #### update learning rate 159 | model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) 160 | 161 | #### training 162 | model.feed_data(train_data) 163 | model.optimize_parameters(current_step) 164 | 165 | #### log 166 | if current_step % opt['logger']['print_freq'] == 0: 167 | logs = model.get_current_log() 168 | message = '<epoch:{:3d}, iter:{:8,d}, lr:{:.3e}> '.format( 169 | epoch, current_step, model.get_current_learning_rate()) 170 | for k, v in logs.items(): 171 | message += '{:s}: {:.4e} '.format(k, v) 172 | # 
tensorboard logger 173 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 174 | if rank <= 0: 175 | tb_logger.add_scalar(k, v, current_step) 176 | if rank <= 0: 177 | logger.info(message) 178 | 179 | # validation 180 | if current_step % opt['train']['val_freq'] == 0 and rank <= 0 and val_loader is not None: 181 | avg_psnr = val_pix_err_f = val_pix_err_nf = val_mean_color_err = 0.0 182 | idx = 0 183 | for val_data in val_loader: 184 | idx += 1 185 | img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] 186 | img_dir = os.path.join(opt['path']['val_images'], img_name) 187 | util.mkdir(img_dir) 188 | 189 | model.feed_data(val_data) 190 | model.test() 191 | 192 | visuals = model.get_current_visuals() 193 | sr_img = util.tensor2img(visuals['SR']) # uint8 194 | gt_img = util.tensor2img(visuals['GT']) # uint8 195 | 196 | # Save SR images for reference 197 | save_img_path = os.path.join(img_dir, 198 | '{:s}_{:d}.png'.format(img_name, current_step)) 199 | util.save_img(sr_img, save_img_path) 200 | 201 | # calculate PSNR 202 | crop_size = opt['scale'] 203 | gt_img = gt_img / 255. 204 | sr_img = sr_img / 255. 205 | cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :] 206 | cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :] 207 | avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255) 208 | 209 | 210 | avg_psnr = avg_psnr / idx 211 | val_pix_err_f /= idx 212 | val_pix_err_nf /= idx 213 | val_mean_color_err /= idx 214 | 215 | # log 216 | logger.info('# Validation # PSNR: {:.4e}'.format(avg_psnr)) 217 | logger_val = logging.getLogger('val') # validation logger 218 | logger_val.info('<epoch:{:3d}, iter:{:8,d}> psnr: {:.4e}'.format( 219 | epoch, current_step, avg_psnr)) 220 | # tensorboard logger 221 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 222 | tb_logger.add_scalar('psnr', avg_psnr, current_step) 223 | tb_logger.add_scalar('val_pix_err_f', val_pix_err_f, current_step) 224 | tb_logger.add_scalar('val_pix_err_nf', val_pix_err_nf, current_step) 225 | tb_logger.add_scalar('val_mean_color_err', val_mean_color_err, current_step) 226 | 227 | #### save models and training states 228 | if current_step % opt['logger']['save_checkpoint_freq'] == 0: 229 | if rank <= 0: 230 | logger.info('Saving models and training states.') 231 | model.save(current_step) 232 | model.save_training_state(epoch, current_step) 233 | 234 | if rank <= 0: 235 | logger.info('Saving the final model.') 236 | model.save('latest') 237 | logger.info('End of training.') 238 | 239 | 240 | if __name__ == '__main__': 241 | main() 242 | -------------------------------------------------------------------------------- /codes/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/codes/utils/__init__.py -------------------------------------------------------------------------------- /codes/utils/util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import math 5 | from datetime import datetime 6 | import random 7 | import logging 8 | from collections import OrderedDict 9 | import numpy as np 10 | import cv2 11 | import torch 12 | from torchvision.utils import make_grid 13 | from shutil import get_terminal_size 14 | 15 | import yaml 16 | try: 17 | from yaml import CLoader as Loader, CDumper as Dumper 18 | except ImportError: 19 | from yaml import
Loader, Dumper 20 | 21 | 22 | def OrderedYaml(): 23 | '''yaml orderedDict support''' 24 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 25 | 26 | def dict_representer(dumper, data): 27 | return dumper.represent_dict(data.items()) 28 | 29 | def dict_constructor(loader, node): 30 | return OrderedDict(loader.construct_pairs(node)) 31 | 32 | Dumper.add_representer(OrderedDict, dict_representer) 33 | Loader.add_constructor(_mapping_tag, dict_constructor) 34 | return Loader, Dumper 35 | 36 | 37 | #################### 38 | # miscellaneous 39 | #################### 40 | 41 | 42 | def get_timestamp(): 43 | return datetime.now().strftime('%y%m%d-%H%M%S') 44 | 45 | 46 | def mkdir(path): 47 | if not os.path.exists(path): 48 | os.makedirs(path) 49 | 50 | 51 | def mkdirs(paths): 52 | if isinstance(paths, str): 53 | mkdir(paths) 54 | else: 55 | for path in paths: 56 | mkdir(path) 57 | 58 | 59 | def mkdir_and_rename(path): 60 | if os.path.exists(path): 61 | new_name = path + '_archived_' + get_timestamp() 62 | print('Path already exists. Rename it to [{:s}]'.format(new_name)) 63 | logger = logging.getLogger('base') 64 | logger.info('Path already exists. Rename it to [{:s}]'.format(new_name)) 65 | os.rename(path, new_name) 66 | os.makedirs(path) 67 | 68 | 69 | def set_random_seed(seed): 70 | random.seed(seed) 71 | np.random.seed(seed) 72 | torch.manual_seed(seed) 73 | torch.cuda.manual_seed(seed) 74 | torch.cuda.manual_seed_all(seed) 75 | 76 | 77 | def setup_logger(logger_name, root, phase, level=logging.INFO, screen=False, tofile=False): 78 | '''set up logger''' 79 | lg = logging.getLogger(logger_name) 80 | formatter = logging.Formatter('%(asctime)s.%(msecs)03d - %(levelname)s: %(message)s', 81 | datefmt='%y-%m-%d %H:%M:%S') 82 | lg.setLevel(level) 83 | if tofile: 84 | log_file = os.path.join(root, phase + '_{}.log'.format(get_timestamp())) 85 | fh = logging.FileHandler(log_file, mode='w') 86 | fh.setFormatter(formatter) 87 | lg.addHandler(fh) 88 | if screen: 89 | sh = logging.StreamHandler() 90 | sh.setFormatter(formatter) 91 | lg.addHandler(sh) 92 | 93 | 94 | #################### 95 | # image convert 96 | #################### 97 | 98 | 99 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 100 | ''' 101 | Converts a torch Tensor into an image Numpy array 102 | Input: 4D(B,(3/1),H,W), 3D(C,H,W), or 2D(H,W), any range, RGB channel order 103 | Output: 3D(H,W,C) or 2D(H,W), [0,255], np.uint8 (default) 104 | ''' 105 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) # clamp 106 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) # to range [0,1] 107 | n_dim = tensor.dim() 108 | if n_dim == 4: 109 | n_img = len(tensor) 110 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), normalize=False).numpy() 111 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 112 | elif n_dim == 3: 113 | img_np = tensor.numpy() 114 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) # HWC, BGR 115 | elif n_dim == 2: 116 | img_np = tensor.numpy() 117 | else: 118 | raise TypeError( 119 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim)) 120 | if out_type == np.uint8: 121 | img_np = (img_np * 255.0).round() 122 | # Important. Unlike MATLAB, numpy.uint8() WILL NOT round by default.
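# e.g. np.uint8(254.6) truncates to 254, while the round() above maps it to
# 255, matching MATLAB's rounding behaviour on near-saturated values.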
123 | return img_np.astype(out_type) 124 | 125 | 126 | def save_img(img, img_path, mode='RGB'): 127 | cv2.imwrite(img_path, img) 128 | 129 | 130 | #################### 131 | # metric 132 | #################### 133 | 134 | 135 | def calculate_psnr(img1, img2): 136 | # img1 and img2 have range [0, 255] 137 | img1 = img1.astype(np.float64) 138 | img2 = img2.astype(np.float64) 139 | mse = np.mean((img1 - img2)**2) 140 | if mse == 0: 141 | return float('inf') 142 | return 20 * math.log10(255.0 / math.sqrt(mse)) 143 | 144 | 145 | def ssim(img1, img2): 146 | C1 = (0.01 * 255)**2 147 | C2 = (0.03 * 255)**2 148 | 149 | img1 = img1.astype(np.float64) 150 | img2 = img2.astype(np.float64) 151 | kernel = cv2.getGaussianKernel(11, 1.5) 152 | window = np.outer(kernel, kernel.transpose()) 153 | 154 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid 155 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 156 | mu1_sq = mu1**2 157 | mu2_sq = mu2**2 158 | mu1_mu2 = mu1 * mu2 159 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 160 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 161 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 162 | 163 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 164 | (sigma1_sq + sigma2_sq + C2)) 165 | return ssim_map.mean() 166 | 167 | 168 | def calculate_ssim(img1, img2): 169 | '''calculate SSIM 170 | the same outputs as MATLAB's 171 | img1, img2: [0, 255] 172 | ''' 173 | if not img1.shape == img2.shape: 174 | raise ValueError('Input images must have the same dimensions.') 175 | if img1.ndim == 2: 176 | return ssim(img1, img2) 177 | elif img1.ndim == 3: 178 | if img1.shape[2] == 3: 179 | ssims = [] 180 | for i in range(3): 181 | ssims.append(ssim(img1, img2)) 182 | return np.array(ssims).mean() 183 | elif img1.shape[2] == 1: 184 | return ssim(np.squeeze(img1), np.squeeze(img2)) 185 | else: 186 | raise ValueError('Wrong input image dimensions.') 187 | 188 | 189 | class ProgressBar(object): 190 | '''A progress bar which can print the progress 191 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py 192 | ''' 193 | 194 | def __init__(self, task_num=0, bar_width=50, start=True): 195 | self.task_num = task_num 196 | max_bar_width = self._get_max_bar_width() 197 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width) 198 | self.completed = 0 199 | if start: 200 | self.start() 201 | 202 | def _get_max_bar_width(self): 203 | terminal_width, _ = get_terminal_size() 204 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) 205 | if max_bar_width < 10: 206 | print('terminal width is too small ({}), please consider widen the terminal for better ' 207 | 'progressbar visualization'.format(terminal_width)) 208 | max_bar_width = 10 209 | return max_bar_width 210 | 211 | def start(self): 212 | if self.task_num > 0: 213 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format( 214 | ' ' * self.bar_width, self.task_num, 'Start...')) 215 | else: 216 | sys.stdout.write('completed: 0, elapsed: 0s') 217 | sys.stdout.flush() 218 | self.start_time = time.time() 219 | 220 | def update(self, msg='In progress...'): 221 | self.completed += 1 222 | elapsed = time.time() - self.start_time 223 | fps = self.completed / elapsed 224 | if self.task_num > 0: 225 | percentage = self.completed / float(self.task_num) 226 | eta = int(elapsed * (1 - percentage) / percentage + 0.5) 227 | mark_width = int(self.bar_width * percentage) 228 | 
bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) 229 | sys.stdout.write('\033[2F') # cursor up 2 lines 230 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display) 231 | sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format( 232 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg)) 233 | else: 234 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format( 235 | self.completed, int(elapsed + 0.5), fps)) 236 | sys.stdout.flush() 237 | -------------------------------------------------------------------------------- /figures/0913.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/figures/0913.png -------------------------------------------------------------------------------- /figures/0935.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/figures/0935.png -------------------------------------------------------------------------------- /figures/arch.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/figures/arch.png -------------------------------------------------------------------------------- /figures/df2k.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/figures/df2k.png -------------------------------------------------------------------------------- /figures/dped.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/figures/dped.png -------------------------------------------------------------------------------- /figures/track1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/figures/track1.png -------------------------------------------------------------------------------- /figures/track2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/yoon28/realsr-noise-injection/402679490bf0972d09aaaadee3b5b9850c2a36e4/figures/track2.png -------------------------------------------------------------------------------- /yoon/degradation_pair_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | import glob 5 | 6 | import random 7 | import numpy as np 8 | import cv2 9 | import torch 10 | from torch.utils.data import Dataset, DataLoader 11 | from torchvision import transforms 12 | import torchvision.transforms.functional as TF 13 | from torchvision.datasets import ImageFolder 14 | 15 | from img_resize import imresize 16 | from scipy.io import loadmat 17 | 18 | def img_random_crop(img, length, yx=None): 19 | img_size = img.shape 20 | y_range = img_size[0] - length + 1 21 | x_range = img_size[1] - length + 1 22 | if yx: 23 | y = yx[0] 24 | x = yx[1] 25 | else: 26 | [y, x] = np.random.randint(0, [y_range, x_range]) 27 | 
return img[y:(y+length), x:(x+length)], [y, x] 28 | 29 | class TestDataSR(Dataset): 30 | def __init__(self, lr_folder, gt_folder=None, test_folder=None, permute=False, bgr2rgb=True): 31 | super(TestDataSR).__init__() 32 | self.bgr2rgb = bgr2rgb 33 | self.permute = permute 34 | self.lr_folder = lr_folder 35 | self.lr_files = glob.glob(os.path.join(lr_folder, "**.png")) 36 | self.lr_files.sort() 37 | self.gt_folder = None 38 | if gt_folder: 39 | self.gt_folder = gt_folder 40 | 41 | if permute: 42 | np.random.shuffle(self.lr_files) 43 | 44 | self.selected_test_samples = ["0913.png", "0935.png"] 45 | self.test_samples = [] 46 | self.test_folder = test_folder 47 | if test_folder: 48 | for s in self.selected_test_samples: 49 | afile = os.path.join(test_folder, s) 50 | img = cv2.imread(afile) 51 | if bgr2rgb: 52 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 53 | self.test_samples.append(img) 54 | 55 | def get_test_sample(self, index): 56 | return {"LQ":TF.to_tensor(self.test_samples[index]).unsqueeze(0), 57 | "LQ_path":[os.path.join(self.test_folder, self.selected_test_samples[index])]} 58 | 59 | def __len__(self): 60 | return len(self.lr_files) 61 | 62 | def __getitem__(self, index): 63 | sample = dict() 64 | img = cv2.imread(self.lr_files[index]) 65 | if self.bgr2rgb: 66 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 67 | lr = TF.to_tensor(img) 68 | sample["LQ"] = lr 69 | sample["LQ_path"] = self.lr_files[index] 70 | if self.gt_folder: 71 | filename = os.path.basename(self.lr_files[index]) 72 | gt = cv2.imread(os.path.join(self.gt_folder, filename)) 73 | if self.bgr2rgb: 74 | gt = cv2.cvtColor(gt, cv2.COLOR_BGR2RGB) 75 | gt = TF.to_tensor(gt) 76 | sample["GT"] = gt 77 | sample["GT_path"] = os.path.join(self.gt_folder, filename) 78 | return sample 79 | 80 | class DegradationParing(Dataset): 81 | def __init__(self, hr_folder, kern_folder, noise_folder, scale_factor, sr_size, corrupted_folder=None, clean_sc=2, permute=True, bgr2rgb=True): 82 | super(DegradationParing, self).__init__() 83 | assert(0 < scale_factor <= 1) 84 | self.scale_factor = scale_factor 85 | self.hr_size = sr_size 86 | self.lr_size = int(sr_size * scale_factor) 87 | self.bgr2rgb = bgr2rgb # opencv loader 88 | self.hr_folder = hr_folder 89 | self.kern_folder = kern_folder 90 | self.noise_folder = noise_folder 91 | self.corrupted_folder = corrupted_folder 92 | self.hr_files = glob.glob(os.path.join(hr_folder, "**.png")) 93 | self.kern_files = glob.glob(os.path.join(kern_folder, "**/*_x4.mat")) 94 | self.noise_files = glob.glob(os.path.join(noise_folder, "**.png")) 95 | self.clean_sc = clean_sc 96 | if self.corrupted_folder: 97 | self.corrupted_folder = glob.glob(os.path.join(corrupted_folder, "**.png")) 98 | self.hr_files.extend(self.corrupted_folder) 99 | 100 | if permute: 101 | np.random.shuffle(self.hr_files) 102 | np.random.shuffle(self.kern_files) 103 | np.random.shuffle(self.noise_files) 104 | 105 | def add_noise_and_preproc(self, hr_im, lr_im, noise): 106 | lr, yx = img_random_crop(lr_im, self.lr_size) 107 | yx[0] = int(round(yx[0] / self.scale_factor)) 108 | yx[1] = int(round(yx[1] / self.scale_factor)) 109 | hr, _ = img_random_crop(hr_im, self.hr_size, yx) 110 | z_im, _ = img_random_crop(noise, self.lr_size) 111 | 112 | if random.random() < 0.5: 113 | z_im = cv2.flip(z_im, 0) # vertical 114 | if random.random() < 0.5: 115 | z_im = cv2.flip(z_im, 1) # horizontal 116 | z_mean = np.mean(z_im.reshape(-1, 3), 0).reshape((1, 1, 3)) 117 | z = z_im.astype(np.float) - z_mean 118 | lr = np.clip(np.round(lr.astype(np.float) + z), 0, 
255).astype(np.uint8) # add noise 119 | 120 | if random.random() < 0.5: 121 | hr = cv2.flip(hr, 0) # vertical flip 122 | lr = cv2.flip(lr, 0) 123 | 124 | if random.random() < 0.5: 125 | hr = cv2.flip(hr, 1) # horizontal 126 | lr = cv2.flip(lr, 1) 127 | 128 | # cv2.imshow("hr", hr[:,:,::-1]) 129 | # cv2.imshow("lr", lr[:,:,::-1]) 130 | # cv2.imshow("z", z_im[:,:,::-1]) 131 | 132 | hr = TF.to_tensor(hr) 133 | lr = TF.to_tensor(lr) 134 | return hr, lr, z_im 135 | 136 | def __getitem__(self, index): 137 | kernel = np.array(loadmat(self.kern_files[index])["Kernel"]) 138 | rand_idx = np.random.randint(0, [len(self.hr_files), len(self.noise_files)]) 139 | hr_file = self.hr_files[rand_idx[0]] 140 | hr_img = cv2.imread(hr_file) 141 | if os.path.basename(hr_file)[:8] == "Flickr2K": # hr_file in self.corrupted_folder: 142 | hr_img = cv2.resize(hr_img, dsize=(0, 0), fx=1./self.clean_sc, fy=1./self.clean_sc, interpolation=cv2.INTER_AREA) 143 | noise = cv2.imread(self.noise_files[rand_idx[1]]) 144 | hr_img, _ = img_random_crop(hr_img, int(self.hr_size * 2)) 145 | if self.bgr2rgb: 146 | hr_img = cv2.cvtColor(hr_img, cv2.COLOR_BGR2RGB) # hr_img[:, :, [2, 1, 0]] 147 | noise = cv2.cvtColor(noise, cv2.COLOR_BGR2RGB) # noise[:, :, ::-1] 148 | dwn_img = imresize(hr_img, self.scale_factor, kernel=kernel) 149 | hr, lr, z = self.add_noise_and_preproc(hr_img, dwn_img, noise) 150 | # cv2.imshow("hr_ori", hr_img[:,:,::-1]) 151 | # cv2.imshow("lr_ori", dwn_img[:,:,::-1]) 152 | # cv2.waitKey() 153 | return {"GT":hr, "LQ":lr, "z":z, "LQ_path":self.hr_files[rand_idx[0]], "GT_path":self.hr_files[rand_idx[0]]} 154 | 155 | def __len__(self): 156 | return len(self.kern_files) 157 | 158 | def valid(): 159 | lr_folder = "/mnt/data/NTIRE2020/realSR/track1/Corrupted-va-x" 160 | gt_folder = "/mnt/data/NTIRE2020/realSR/track1/DIV2K_valid_HR" 161 | test_folder = "/mnt/data/NTIRE2020/realSR/track1/Corrupted-te-x" 162 | data = TestDataSR(lr_folder, gt_folder, test_folder) 163 | 164 | n_test = len(data.test_samples) 165 | for t in range(n_test): 166 | t_im = data.get_test_sample(t) 167 | lr = t_im["LQ"][0].cpu().numpy().transpose(1, 2, 0) * 255 168 | lr = np.round(lr).astype(np.uint8) 169 | cv2.imshow("te_{}".format(t), lr) 170 | 171 | for i in range(len(data)): 172 | lr = data[i]["LQ"].cpu().numpy().transpose(1, 2, 0) * 255 173 | lr = np.round(lr).astype(np.uint8) 174 | gt = data[i]["GT"].cpu().numpy().transpose(1, 2, 0) * 255 175 | gt = np.round(gt).astype(np.uint8) 176 | cv2.imshow("lr", lr[:, :, ::-1]) 177 | cv2.imshow("gt", gt[:, :, ::-1]) 178 | cv2.waitKey() 179 | 180 | if __name__ == "__main__": 181 | #valid() 182 | seed_num = 0 183 | torch.manual_seed(seed_num) 184 | torch.cuda.manual_seed(seed_num) 185 | torch.cuda.manual_seed_all(seed_num) 186 | torch.backends.cudnn.deterministic = True 187 | torch.backends.cudnn.benchmark = False 188 | np.random.seed(seed_num) 189 | random.seed(seed_num) 190 | 191 | hr_folder = "/mnt/data/NTIRE2020/realSR/track1/Corrupted-tr-y" 192 | kern_folder = "yoon/kernels/track1" 193 | noise_folder = "yoon/noises/track1/p128_v100" 194 | corrupted_folder = "/mnt/data/NTIRE2020/realSR/track1/Corrupted-tr-x" 195 | 196 | data = DegradationParing(hr_folder, kern_folder, noise_folder, 0.25, 64, corrupted_folder=corrupted_folder, clean_sc=2) 197 | loader = DataLoader(data, batch_size=4, shuffle=True) 198 | for i, D in enumerate(loader): 199 | for j in range(len(D["GT"])): 200 | hr = D["GT"][j].cpu().numpy().transpose(1, 2, 0) * 255 201 | lr = D["LQ"][j].cpu().numpy().transpose(1, 2, 0) * 255 202 | hr = 
np.round(hr).astype(np.uint8) 203 | lr = np.round(lr).astype(np.uint8) 204 | cv2.imshow("hr_{}".format(j), hr[:, :, ::-1]) 205 | cv2.imshow("lr_{}".format(j), lr[:, :, ::-1]) 206 | cv2.imshow("z_{}".format(j), D["z"][j].cpu().numpy()) 207 | cv2.waitKey() 208 | print("fin.") 209 | -------------------------------------------------------------------------------- /yoon/img_resize.py: -------------------------------------------------------------------------------- 1 | """ 2 | This code is copied from https://github.com/sefibk/KernelGAN/blob/master/imresize.py 3 | """ 4 | 5 | import numpy as np 6 | from scipy.ndimage import filters, measurements, interpolation 7 | from math import pi 8 | 9 | 10 | def imresize(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False): 11 | # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa 12 | scale_factor, output_shape = fix_scale_and_size(im.shape, output_shape, scale_factor) 13 | 14 | # For a given numeric kernel case, just do convolution and sub-sampling (downscaling only) 15 | if type(kernel) == np.ndarray and scale_factor[0] <= 1: 16 | return numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag) 17 | 18 | # Choose interpolation method, each method has the matching kernel size 19 | method, kernel_width = { 20 | "cubic": (cubic, 4.0), 21 | "lanczos2": (lanczos2, 4.0), 22 | "lanczos3": (lanczos3, 6.0), 23 | "box": (box, 1.0), 24 | "linear": (linear, 2.0), 25 | None: (cubic, 4.0) # Default interpolation method is cubic 26 | }.get(kernel) 27 | 28 | # Antialiasing is only used when downscaling 29 | antialiasing *= (scale_factor[0] < 1) 30 | 31 | # Sort indices of dimensions according to scale of each dimension. since we are going dim by dim this is efficient 32 | sorted_dims = np.argsort(np.array(scale_factor)).tolist() 33 | 34 | # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction 35 | out_im = np.copy(im) 36 | for dim in sorted_dims: 37 | # No point doing calculations for scale-factor 1. nothing will happen anyway 38 | if scale_factor[dim] == 1.0: 39 | continue 40 | 41 | # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the 42 | # weights that multiply the values there to get its result. 43 | weights, field_of_view = contributions(im.shape[dim], output_shape[dim], scale_factor[dim], 44 | method, kernel_width, antialiasing) 45 | 46 | # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim 47 | out_im = resize_along_dim(out_im, dim, weights, field_of_view) 48 | 49 | return out_im 50 | 51 | 52 | def fix_scale_and_size(input_shape, output_shape, scale_factor): 53 | # First fix the scale-factor (if given) to the standardized form the function expects (a list of scale factors, 54 | # one per input dimension) 55 | if scale_factor is not None: 56 | # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it.
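# e.g. for an H x W x 3 image, a scalar scale_factor of 0.25 becomes
# [0.25, 0.25] here and is then extended to [0.25, 0.25, 1] below, so the
# channel dimension is left untouched.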
57 | if np.isscalar(scale_factor): 58 | scale_factor = [scale_factor, scale_factor] 59 | 60 | # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales 61 | scale_factor = list(scale_factor) 62 | scale_factor.extend([1] * (len(input_shape) - len(scale_factor))) 63 | 64 | # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size 65 | # to all the unspecified dimensions 66 | if output_shape is not None: 67 | output_shape = list(np.uint(np.array(output_shape))) + list(input_shape[len(output_shape):]) 68 | 69 | # Dealing with the case of a non-given scale-factor, calculating according to output-shape. Note that this is 70 | # sub-optimal, because there can be different scales to the same output-shape. 71 | if scale_factor is None: 72 | scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape) 73 | 74 | # Dealing with missing output-shape. calculating according to scale-factor 75 | if output_shape is None: 76 | output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor))) 77 | 78 | return scale_factor, output_shape 79 | 80 | 81 | def contributions(in_length, out_length, scale, kernel, kernel_width, antialiasing): 82 | # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied 83 | # such that each position from the field_of_view will be multiplied with a matching filter from the 84 | # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers 85 | # around it. This is only done for one dimension of the image. 86 | 87 | # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of 88 | # 1/sf. this makes the filtering more of a 'low-pass' filter. 89 | fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel 90 | kernel_width *= 1.0 / scale if antialiasing else 1.0 91 | 92 | # These are the coordinates of the output image 93 | out_coordinates = np.arange(1, out_length + 1) 94 | 95 | # These are the matching positions of the output-coordinates on the input image coordinates. 96 | # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels: 97 | # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel. 98 | # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to 99 | # the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big 100 | # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor). 101 | # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is 102 | # at d=1, and so on (d = p - 0.5). We calculate (d_new = d_old / sf) which means: 103 | # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf) 104 | match_coordinates = 1.0 * out_coordinates / scale + 0.5 * (1 - 1.0 / scale) 105 | 106 | # This is the left boundary to start multiplying the filter from, it depends on the size of the filter 107 | left_boundary = np.floor(match_coordinates - kernel_width / 2) 108 | 109 | # Kernel width needs to be enlarged because when the coverage has sub-pixel borders, it must 'see' the pixel centers 110 | # of the pixels it only covered a part from.
111 |     expanded_kernel_width = np.ceil(kernel_width) + 2
112 | 
113 |     # Determine a field_of_view for each output position: these are the pixels in the input image
114 |     # that the pixel in the output image 'sees'. We get a matrix whose horizontal dim is the output pixels (big) and
115 |     # whose vertical dim is the pixels it 'sees' (kernel_size + 2)
116 |     field_of_view = np.squeeze(np.uint(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1))
117 | 
118 |     # Assign a weight to each pixel in the field of view: a matrix whose horizontal dim is the output pixels and
119 |     # whose vertical dim is the list of weights matching the pixels in the field of view (as specified in
120 |     # 'field_of_view')
121 |     weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)
122 | 
123 |     # Normalize the weights to sum to 1. Be careful not to divide by 0
124 |     sum_weights = np.sum(weights, axis=1)
125 |     sum_weights[sum_weights == 0] = 1.0
126 |     weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)
127 | 
128 |     # We use this mirror structure as a trick for reflection padding at the boundaries
129 |     mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
130 |     field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]
131 | 
132 |     # Get rid of weights and pixel positions that are of zero weight
133 |     non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
134 |     weights = np.squeeze(weights[:, non_zero_out_pixels])
135 |     field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])
136 | 
137 |     # The final products are the relative positions and the matching weights, both of shape output_size X fixed_kernel_size
138 |     return weights, field_of_view
139 | 
140 | 
141 | def resize_along_dim(im, dim, weights, field_of_view):
142 |     # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize
143 |     tmp_im = np.swapaxes(im, dim, 0)
144 | 
145 |     # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for
146 |     # tmp_im[field_of_view.T], (bsxfun style)
147 |     weights = np.reshape(weights.T, list(weights.T.shape) + (np.ndim(im) - 1) * [1])
148 | 
149 |     # This is a bit of a complicated multiplication: tmp_im[field_of_view.T] is a tensor of order image_dims+1.
150 |     # For each pixel in the output image it gathers the positions that influence it from the input image (along 1 dim
151 |     # only, which is why it adds only 1 dim to the shape). We then multiply, for each pixel, its set of positions with
152 |     # the matching set of weights, via one big broadcast element-wise multiplication (MATLAB bsxfun style).
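    # A concrete shape walk-through (an added illustration, assuming the usual (H, W, 3) image case):
    # when resizing dim 0 to out_H with k filter taps,
    #     field_of_view.T         -> (k, out_H)
    #     tmp_im[field_of_view.T] -> (k, out_H, W, 3)   gather along the resized dim
    #     weights (reshaped)      -> (k, out_H, 1, 1)   broadcasts over W and the channels
    # and the sum over axis 0 below collapses the taps, leaving (out_H, W, 3).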
153 |     # Matching dims are multiplied element-wise, while a singleton dim means the matching dim is all multiplied by
154 |     # the same number.
155 |     tmp_out_im = np.sum(tmp_im[field_of_view.T] * weights, axis=0)
156 | 
157 |     # Finally we swap back the axes to the original order
158 |     return np.swapaxes(tmp_out_im, dim, 0)
159 | 
160 | 
161 | def numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag):
162 |     # See the kernel_shift function to understand what this is
163 |     if kernel_shift_flag:
164 |         kernel = kernel_shift(kernel, scale_factor)
165 | 
166 |     # First run a correlation (convolution with a flipped kernel)
167 |     out_im = np.zeros_like(im)
168 |     for channel in range(np.ndim(im)):  # note: this equals the channel count only for 3-channel (H, W, 3) inputs
169 |         out_im[:, :, channel] = filters.correlate(im[:, :, channel], kernel)
170 | 
171 |     # Then subsample and return
172 |     return out_im[np.round(np.linspace(0, im.shape[0] - 1 / scale_factor[0], output_shape[0])).astype(int)[:, None],
173 |                   np.round(np.linspace(0, im.shape[1] - 1 / scale_factor[1], output_shape[1])).astype(int), :]
174 | 
175 | 
176 | def kernel_shift(kernel, sf):
177 |     # There are two reasons for shifting the kernel:
178 |     # 1. The center of mass is not in the center of the kernel, which creates ambiguity. There is no way to know whether
179 |     #    the degradation process included shifting, so we always assume the center of mass is the center of the kernel.
180 |     # 2. We further shift the kernel center so that the top-left result pixel corresponds to the middle of the sf x sf
181 |     #    first pixels. The default is for an odd-sized kernel to be in the middle of the first pixel and for an even-sized
182 |     #    kernel to be at the top-left corner of the first pixel, which is why odd and even sizes need different shifts.
183 |     # Given that these two conditions are fulfilled, we are happy and aligned. The way to test this is:
184 |     # the input image, when interpolated (regular bicubic), is exactly aligned with the ground truth.
185 | 
186 |     # First calculate the current center of mass for the kernel
187 |     current_center_of_mass = measurements.center_of_mass(kernel)
188 | 
189 |     # The second term ("+ 0.5 * ....") applies condition 2 from the comments above
190 |     wanted_center_of_mass = np.array(kernel.shape) // 2 + 0.5 * (sf - (kernel.shape[0] % 2))
191 |     # wanted_center_of_mass = np.array(kernel.shape) / 2 + 0.5 * (np.array(sf)[0:2] - (kernel.shape[0] % 2))
192 | 
193 |     # Define the shift vector for the kernel shifting (x, y)
194 |     shift_vec = wanted_center_of_mass - current_center_of_mass
195 | 
196 |     # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift
197 |     # (biggest shift among dims + 1 for safety)
198 |     kernel = np.pad(kernel, int(np.ceil(np.max(shift_vec))) + 1, 'constant')  # plain int(): np.int was removed in NumPy >= 1.24
199 | 
200 |     # Finally shift the kernel and return
201 |     return interpolation.shift(kernel, shift_vec)
202 | 
203 | 
204 | # These next functions are all interpolation methods. x is the (signed) distance from the left pixel center.
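# Quick check of the cubic kernel below (an added verification, not original code): a sample halfway
# between two pixels sees taps at |x| = 0.5 and 1.5, giving cubic(0.5) = 0.5625 and cubic(1.5) = -0.0625;
# the four taps sum to 2 * (0.5625 - 0.0625) = 1, i.e. the weights are already normalized in this case.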
205 | 
206 | 
207 | def cubic(x):
208 |     absx = np.abs(x)
209 |     absx2 = absx ** 2
210 |     absx3 = absx ** 3
211 |     return ((1.5 * absx3 - 2.5 * absx2 + 1) * (absx <= 1) +
212 |             (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ((1 < absx) & (absx <= 2)))
213 | 
214 | 
215 | def lanczos2(x):
216 |     return (((np.sin(pi * x) * np.sin(pi * x / 2) + np.finfo(np.float32).eps) /
217 |              ((pi ** 2 * x ** 2 / 2) + np.finfo(np.float32).eps))
218 |             * (abs(x) < 2))
219 | 
220 | 
221 | def box(x):
222 |     return ((-0.5 <= x) & (x < 0.5)) * 1.0
223 | 
224 | 
225 | def lanczos3(x):
226 |     return (((np.sin(pi * x) * np.sin(pi * x / 3) + np.finfo(np.float32).eps) /
227 |              ((pi ** 2 * x ** 2 / 3) + np.finfo(np.float32).eps))
228 |             * (abs(x) < 3))
229 | 
230 | 
231 | def linear(x):
232 |     return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))
233 | 
234 | if __name__ == "__main__":
235 |     import cv2
236 |     from scipy.io import loadmat
237 |     fake_im = np.random.randint(255, size=[500, 500, 3], dtype=np.uint8)
238 |     kernel = np.array(loadmat("yoon/kernels/track1/Flickr2K_000001/Flickr2K_000001_kernel_x2.mat")["Kernel"])
239 |     # kernel = np.ones([17, 17]) / (17 * 17)
240 |     fake_lr = imresize(fake_im, 0.5, kernel=kernel)
241 |     cv2.imshow("fake", fake_im)
242 |     cv2.imshow("lr", fake_lr)
243 |     cv2.waitKey()
244 |     print("fin.")
-------------------------------------------------------------------------------- /yoon/options/train_df2k.yml: --------------------------------------------------------------------------------
1 | name: Track1
2 | use_tb_logger: true
3 | suffix: ~  # add suffix to saved images
4 | model: srgan
5 | distortion: sr
6 | scale: 4
7 | crop_border: ~  # crop border when evaluating. If None (~), crop `scale` pixels
8 | gpu_ids: [0, 1]
9 | 
10 | datasets:
11 |   train:  # the training dataset
12 |     name: DIV2K
13 |     mode: yoon
14 |     dataroot_GT: /mnt/data/NTIRE2020/realSR/track1/Corrupted-tr-y
15 |     kernel_folder: yoon/kernels/track1
16 |     noise_folder: yoon/noises/track1/p128_v100
17 | 
18 |     use_shuffle: true
19 |     n_workers: 24
20 |     batch_size: 48
21 |     GT_size: 128
22 |     use_flip: true
23 |     use_rot: false
24 |     color: RGB
25 | 
26 |   val:
27 |     name: DIV2K_VAL
28 |     mode: yoon
29 |     dataroot_LQ: /mnt/data/NTIRE2020/realSR/track1/Corrupted-va-x
30 |     dataroot_GT: /mnt/data/NTIRE2020/realSR/track1/DIV2K_valid_HR
31 | 
32 |     num_val: 20
33 |     use_shuffle: false
34 |     color: RGB
35 | 
36 | #### network structures
37 | network_G:
38 |   which_model_G: RRDBNet
39 |   in_nc: 3
40 |   out_nc: 3
41 |   nf: 64
42 |   nb: 23
43 |   upscale: 4
44 | 
45 | network_D:
46 |   which_model_D: NLayerDiscriminator
47 |   in_nc: 3
48 |   nf: 64
49 |   nlayer: 3
50 |   norm_layer: instancenorm  # batchnorm
51 | 
52 | #### path
53 | path:
54 |   pretrain_model_G: pretrained_model/esrgan/RRDB_ESRGAN_x4.pth
55 |   results_root: ./results/
56 | 
57 | #### training settings: learning rate scheme, loss
58 | train:
59 |   lr_G: !!float 1e-4
60 |   weight_decay_G: 0
61 |   beta1_G: 0.9
62 |   beta2_G: 0.99
63 |   lr_D: !!float 1e-4
64 |   weight_decay_D: 0
65 |   beta1_D: 0.9
66 |   beta2_D: 0.99
67 |   lr_scheme: MultiStepLR
68 | 
69 |   niter: 40000
70 |   warmup_iter: -1  # no warm-up
71 |   lr_steps: [4000, 8000, 12000]
72 |   lr_gamma: 0.5
73 | 
74 |   pixel_criterion: l1
75 |   pixel_weight: !!float 1e-2
76 |   feature_criterion: l1
77 |   feature_weight: 1
78 |   gan_type: ragan  # gan | ragan
79 |   gan_weight: !!float 5e-3
80 | 
81 |   D_update_ratio: 1
82 |   D_init_iters: 0
83 | 
84 |   manual_seed: 0
85 |   val_freq: !!float 2e3
86 | 
87 | #### logger
88 | logger:
89 |   print_freq: 100000
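  # editorial note: train_realsr.py prints a progress line every iteration regardless of print_freq,
  # so the large value above is effectively unused; with niter 40000, the 2e3 checkpoint frequency
  # below yields roughly 20 checkpoints.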
90 | save_checkpoint_freq: !!float 2e3 -------------------------------------------------------------------------------- /yoon/options/train_dped.yml: -------------------------------------------------------------------------------- 1 | name: Track2 2 | suffix: ~ # add suffix to saved images 3 | model: srgan 4 | distortion: sr 5 | scale: 4 6 | crop_border: ~ # crop border when evaluation. If None(~), crop the scale pixels 7 | gpu_ids: [0] 8 | 9 | datasets: 10 | test_1: # the 1st test dataset 11 | name: DPED 12 | mode: LR 13 | dataroot_LR: /mnt/data/NTIRE2020/realSR/track2/DPEDiphone-crop-te-x 14 | 15 | #### network structures 16 | network_G: 17 | which_model_G: RRDBNet 18 | in_nc: 3 19 | out_nc: 3 20 | nf: 64 21 | nb: 23 22 | upscale: 4 23 | 24 | #### path 25 | path: 26 | pretrain_model_G: pretrained_model/origin/DPED.pth 27 | results_root: ./results/ 28 | 29 | back_projection: False 30 | back_projection_lamda: !!float 0.2 -------------------------------------------------------------------------------- /yoon/stage1_kernel.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | import numpy as np 4 | import cv2 5 | import random 6 | import torch 7 | 8 | from configs import Config 9 | from kernelGAN import KernelGAN 10 | from data import DataGenerator 11 | from learner import Learner 12 | 13 | import tqdm 14 | 15 | DATA_LOC = "/mnt/data/NTIRE2020/realSR/track2" # "/mnt/data/NTIRE2020/realSR/track1" 16 | DATA_X = "DPEDiphone-tr-x" # "Corrupted-tr-x" 17 | DATA_Y = "DPEDiphone-tr-y" # "Corrupted-tr-y" 18 | DATA_VAL = "DPEDiphone-va" # "Corrupted-va-x" 19 | 20 | def config_kernelGAN(afile): 21 | img_folder = os.path.dirname(afile) 22 | img_file = os.path.basename(afile) 23 | out_dir = "yoon/kernels/track2" 24 | 25 | params = ["--input_image_path", afile, 26 | "--output_dir_path", out_dir, 27 | "--noise_scale", str(1.0), 28 | "--X4"] 29 | conf = Config().parse(params) 30 | conf.input2 = None 31 | return conf 32 | 33 | def estimate_kernel(img_file): 34 | conf = config_kernelGAN(img_file) 35 | kgan = KernelGAN(conf) 36 | learner = Learner() 37 | data = DataGenerator(conf, kgan) 38 | for iteration in tqdm.tqdm(range(conf.max_iters), ncols=70): 39 | [g_in, d_in, _] = data.__getitem__(iteration) 40 | kgan.train(g_in, d_in) 41 | learner.update(iteration, kgan) 42 | kgan.finish() 43 | 44 | if __name__ == "__main__": 45 | seed_num = 0 46 | torch.manual_seed(seed_num) 47 | torch.cuda.manual_seed(seed_num) 48 | torch.cuda.manual_seed_all(seed_num) 49 | torch.backends.cudnn.deterministic = True 50 | torch.backends.cudnn.benchmark = False 51 | np.random.seed(seed_num) 52 | random.seed(seed_num) 53 | 54 | # exit(0) 55 | 56 | data = {"X":[os.path.join(DATA_LOC, DATA_X, f) for f in os.listdir(os.path.join(DATA_LOC, DATA_X)) if f[-4:] == ".png"], 57 | "Y":[os.path.join(DATA_LOC, DATA_Y, f) for f in os.listdir(os.path.join(DATA_LOC, DATA_Y)) if f[-4:] == ".png"], 58 | "val":[os.path.join(DATA_LOC, DATA_VAL, f) for f in os.listdir(os.path.join(DATA_LOC, DATA_VAL)) if f[-4:] == ".png"]} 59 | 60 | Kernels = [] 61 | Noises = [] 62 | for f in data["X"]: 63 | estimate_kernel(f) 64 | print("fin.") 65 | -------------------------------------------------------------------------------- /yoon/stage1_noise.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | 3 | import numpy as np 4 | import cv2 5 | import random 6 | import torch 7 | 8 | DATA_LOC = "/mnt/data/NTIRE2020/realSR/track2" # 
"/mnt/data/NTIRE2020/realSR/track1" 9 | DATA_X = "DPEDiphone-tr-x" # "Corrupted-tr-x" # 10 | DATA_Y = "DPEDiphone-tr-y" # "Corrupted-tr-y" # 11 | DATA_VAL = "DPEDiphone-va" # "Corrupted-va-x" 12 | OUT_DIR = "yoon/noises/track2" 13 | 14 | def noises_estimation(img, rescale_0_1=False): 15 | patch_size = 64 16 | block_size = 16 17 | stride_g = patch_size // 2 18 | stride_l = block_size 19 | mu = 0.1 20 | lambd = 0.25 21 | noises = [] 22 | im_size = img.shape[:2] 23 | for y in range(0, im_size[0], stride_g): 24 | for x in range(0, im_size[1], stride_g): 25 | if x + patch_size > im_size[1] or y + patch_size > im_size[0]: 26 | continue 27 | if rescale_0_1: 28 | patch = img[y:(y+patch_size), x:(x+patch_size), :].astype(np.float) / 255.0 29 | else: 30 | patch = img[y:(y+patch_size), x:(x+patch_size), :].astype(np.float) 31 | mean_patch = np.mean(patch.reshape((-1, 3)), 0) 32 | var_patch = np.var(patch.reshape((-1, 3)), 0) 33 | 34 | is_noise = True 35 | for j in range(0, patch_size, stride_l): 36 | for i in range(0, patch_size, stride_l): 37 | if j+block_size > patch_size or i+block_size > patch_size: 38 | continue 39 | block = patch[j:(j+block_size), i:(i+block_size), :] 40 | #assert block.shape[0] == block_size and block.shape[1] == block_size 41 | mean_block = np.mean(block.reshape((-1, 3)), 0) 42 | var_block = np.var(block.reshape((-1, 3)), 0) 43 | if np.greater(np.abs(mean_block - mean_patch), mean_patch*mu).any(): 44 | is_noise = False 45 | break 46 | if np.greater(np.abs(var_block - var_patch), var_patch*lambd).any(): 47 | is_noise = False 48 | break 49 | if not is_noise: 50 | break 51 | if is_noise: 52 | noises.append(img[y:(y+patch_size), x:(x+patch_size), :]) 53 | return noises 54 | 55 | def noises_estimation_simple(img, rescale_0_1=False): 56 | patch_size = 128 57 | stride_g = patch_size // 2 58 | max_var = (10*10) 59 | min_var = (0*0) 60 | noises = [] 61 | im_size = img.shape[:2] 62 | for y in range(0, im_size[0], stride_g): 63 | for x in range(0, im_size[1], stride_g): 64 | if x + patch_size > im_size[1] or y + patch_size > im_size[0]: 65 | continue 66 | if rescale_0_1: 67 | patch = img[y:(y+patch_size), x:(x+patch_size), :].astype(np.float) / 255.0 68 | else: 69 | patch = img[y:(y+patch_size), x:(x+patch_size), :].astype(np.float) 70 | var_patch = np.var(patch.reshape((-1, 3)), 0) 71 | if np.less(var_patch, max_var).all() and np.greater(var_patch, min_var).all(): 72 | noises.append(img[y:(y+patch_size), x:(x+patch_size), :]) 73 | return noises 74 | 75 | if __name__ == "__main__": 76 | seed_num = 0 77 | torch.manual_seed(seed_num) 78 | torch.cuda.manual_seed(seed_num) 79 | torch.cuda.manual_seed_all(seed_num) 80 | torch.backends.cudnn.deterministic = True 81 | torch.backends.cudnn.benchmark = False 82 | np.random.seed(seed_num) 83 | random.seed(seed_num) 84 | 85 | exit(0) 86 | 87 | data = {"X":[os.path.join(DATA_LOC, DATA_X, f) for f in os.listdir(os.path.join(DATA_LOC, DATA_X)) if f[-4:] == ".png"], 88 | "Y":[os.path.join(DATA_LOC, DATA_Y, f) for f in os.listdir(os.path.join(DATA_LOC, DATA_Y)) if f[-4:] == ".png"], 89 | "val":[os.path.join(DATA_LOC, DATA_VAL, f) for f in os.listdir(os.path.join(DATA_LOC, DATA_VAL)) if f[-4:] == ".png"]} 90 | 91 | out_dir = os.path.join(OUT_DIR, "p128_v100") 92 | j = 0 93 | for f in data["X"]: 94 | print(f) 95 | img = cv2.imread(f) 96 | N = noises_estimation_simple(img) 97 | for n in N: 98 | j += 1 99 | filename = os.path.join(out_dir, "noise_{:08d}.png".format(j)) 100 | cv2.imwrite(filename, n) 101 | print("\tsaved: ", filename) 102 | 103 | # if 
103 | # if __name__ == "__main__":
104 | #     seed_num = 0
105 | #     torch.manual_seed(seed_num)
106 | #     torch.cuda.manual_seed(seed_num)
107 | #     torch.cuda.manual_seed_all(seed_num)
108 | #     torch.backends.cudnn.deterministic = True
109 | #     torch.backends.cudnn.benchmark = False
110 | #     np.random.seed(seed_num)
111 | #     random.seed(seed_num)
112 | 
113 | #     data = {"X":[os.path.join(DATA_LOC, DATA_X, f) for f in os.listdir(os.path.join(DATA_LOC, DATA_X)) if f[-4:] == ".png"],
114 | #             "Y":[os.path.join(DATA_LOC, DATA_Y, f) for f in os.listdir(os.path.join(DATA_LOC, DATA_Y)) if f[-4:] == ".png"],
115 | #             "val":[os.path.join(DATA_LOC, DATA_VAL, f) for f in os.listdir(os.path.join(DATA_LOC, DATA_VAL)) if f[-4:] == ".png"]}
116 | #     print(cv2.__version__)
117 | #     Noises = []
118 | #     for j, f in enumerate(data["X"]):
119 | #         img = cv2.imread(f)
120 | #         N = noises_estimation_simple(img)
121 | #         Noises.extend(N)
122 | #         cv2.imshow("X", img)
123 | #         for i, n in enumerate(N):
124 | #             cv2.imshow("N_{}".format(i), n)
125 | #         cv2.waitKey()
126 | #         cv2.destroyAllWindows()
127 | #         print(len(Noises))
128 | # #         if len(Noises) > 5:
129 | # #             break
130 | 
131 | #     test_im = cv2.imread(data["Y"][0])
132 | #     im_size = test_im.shape[:2]
133 | #     for i, n in enumerate(Noises):
134 | #         test_canv = test_im.astype(np.float) / 255.0
135 | #         N = n.astype(np.float) / 255.0
136 | #         N_mean = np.mean(N.reshape((-1, 3)), 0).reshape((1, 1, 3))
137 | #         N_size = N.shape[:2]
138 | #         z = (N-N_mean)
139 | #         crop = test_canv[im_size[0]//2:im_size[0]//2+N_size[0], im_size[1]//2:im_size[1]//2+N_size[1], :]
140 | #         corruped = crop + z
141 | #         corruped = np.round((np.clip(corruped, 0, 1) * 255)).astype(np.uint8)
142 | #         crop = (crop * 255).astype(np.uint8)
143 | #         z = np.round(np.clip(z, 0, 1) * 255).astype(np.uint8)
144 | #         cv2.imshow("noise", z)
145 | #         cv2.imshow("noise_patch", n)
146 | #         cv2.imshow("crop", crop)
147 | #         cv2.imshow("orig", test_im)
148 | #         cv2.imshow("corrup", corruped)
149 | #         cv2.waitKey()
150 | #         print(i)
151 | #     print("fin.")
152 | 
-------------------------------------------------------------------------------- /yoon/train_realsr.py: --------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import argparse
4 | import random
5 | import logging
6 | import torch
7 | import torch.distributed as dist
8 | import torch.multiprocessing as mp
9 | 
10 | from data.data_sampler import DistIterSampler
11 | import options.options as option
12 | from utils import util
13 | from data import create_dataloader, create_dataset
14 | from models import create_model
15 | 
16 | 
17 | def init_dist(backend='nccl', rank=0):
18 |     ''' initialization for distributed training '''
19 |     # if mp.get_start_method(allow_none=True) is None:
20 |     if mp.get_start_method(allow_none=True) != 'spawn':
21 |         mp.set_start_method('spawn')
22 |     # rank = int(os.environ['RANK'])
23 |     num_gpus = torch.cuda.device_count()
24 |     torch.cuda.set_device(rank % num_gpus)
25 |     dist.init_process_group(backend=backend, init_method="tcp://127.0.0.1:23571", world_size=num_gpus, rank=rank)
26 | 
27 | 
28 | def main(gpu, num_gpus):
29 |     #### options
30 |     parser = argparse.ArgumentParser()
31 |     parser.add_argument('-opt', type=str, help='Path to option YAML file.')
32 |     parser.add_argument('--launcher', choices=['none', 'pytorch'], default='pytorch',
33 |                         help='job launcher')
34 |     parser.add_argument('--local_rank', type=int, default=0)
35 |     args = parser.parse_args()
36 |     opt = option.parse(args.opt, is_train=True)
37 | 
38 |     #### distributed
training settings 39 | if args.launcher == 'none': # disabled distributed training 40 | opt['dist'] = False 41 | rank = -1 42 | print('Disabled distributed training.') 43 | else: 44 | opt['dist'] = True 45 | init_dist(rank=gpu) 46 | world_size = torch.distributed.get_world_size() 47 | rank = torch.distributed.get_rank() 48 | 49 | #### loading resume state if exists 50 | if opt['path'].get('resume_state', None): 51 | # distributed resuming: all load into default GPU 52 | device_id = torch.cuda.current_device() 53 | resume_state = torch.load(opt['path']['resume_state'], 54 | map_location=lambda storage, loc: storage.cuda(device_id)) 55 | option.check_resume(opt, resume_state['iter']) # check resume options 56 | else: 57 | resume_state = None 58 | 59 | #### mkdir and loggers 60 | if rank <= 0: # normal training (rank -1) OR distributed training (rank 0) 61 | if resume_state is None: 62 | util.mkdir_and_rename( 63 | opt['path']['experiments_root']) # rename experiment folder if exists 64 | util.mkdirs((path for key, path in opt['path'].items() if not key == 'experiments_root' 65 | and 'pretrain_model' not in key and 'resume' not in key)) 66 | 67 | # config loggers. Before it, the log will not work 68 | util.setup_logger('base', opt['path']['log'], 'train_' + opt['name'], level=logging.INFO, 69 | screen=True, tofile=True) 70 | util.setup_logger('val', opt['path']['log'], 'val_' + opt['name'], level=logging.INFO, 71 | screen=True, tofile=True) 72 | logger = logging.getLogger('base') 73 | logger.info(option.dict2str(opt)) 74 | # tensorboard logger 75 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 76 | version = float(torch.__version__[0:3]) 77 | if version >= 1.1: # PyTorch 1.1 78 | from torch.utils.tensorboard import SummaryWriter 79 | else: 80 | logger.info( 81 | 'You are using PyTorch {}. 
Tensorboard will use [tensorboardX]'.format(version)) 82 | from tensorboardX import SummaryWriter 83 | tb_logger = SummaryWriter(log_dir='tb_logger/' + opt['name']) 84 | else: 85 | util.setup_logger('base', opt['path']['log'], 'train', level=logging.INFO, screen=True) 86 | logger = logging.getLogger('base') 87 | 88 | # convert to NoneDict, which returns None for missing keys 89 | opt = option.dict_to_nonedict(opt) 90 | 91 | #### random seed 92 | seed = opt['train']['manual_seed'] 93 | if seed is None: 94 | seed = random.randint(1, 10000) 95 | if rank <= 0: 96 | logger.info('Random seed: {}'.format(seed)) 97 | util.set_random_seed(seed) 98 | 99 | # torch.backends.cudnn.benckmark = True 100 | # torch.backends.cudnn.deterministic = True 101 | 102 | #### create train and val dataloader 103 | dataset_ratio = 1 # 200 # enlarge the size of each epoch 104 | train_loader = None 105 | val_loader = None 106 | for phase, dataset_opt in opt['datasets'].items(): 107 | if phase == 'train': 108 | # print('\n\n\n\n\n\n\n\n', dataset_opt) 109 | train_set = create_dataset(dataset_opt) 110 | train_size = int(math.ceil(len(train_set) / dataset_opt['batch_size'])) 111 | total_iters = int(opt['train']['niter']) 112 | total_epochs = int(math.ceil(total_iters / train_size)) 113 | if opt['dist']: 114 | train_sampler = DistIterSampler(train_set, world_size, rank, dataset_ratio) 115 | total_epochs = int(math.ceil(total_iters / (train_size * dataset_ratio))) 116 | else: 117 | train_sampler = None 118 | train_loader = create_dataloader(train_set, dataset_opt, opt, train_sampler) 119 | if rank <= 0: 120 | logger.info('Number of train images: {:,d}, iters: {:,d}'.format( 121 | len(train_set), train_size)) 122 | logger.info('Total epochs needed: {:d} for iters {:,d}'.format( 123 | total_epochs, total_iters)) 124 | elif phase == 'val': 125 | val_set = create_dataset(dataset_opt) 126 | val_loader = create_dataloader(val_set, dataset_opt, opt, None) 127 | if rank <= 0: 128 | logger.info('Number of val images in [{:s}]: {:d}'.format( 129 | dataset_opt['name'], len(val_set))) 130 | else: 131 | raise NotImplementedError('Phase [{:s}] is not recognized.'.format(phase)) 132 | assert train_loader is not None 133 | 134 | #### create model 135 | model = create_model(opt) 136 | 137 | #### resume training 138 | if resume_state: 139 | logger.info('Resuming training from epoch: {}, iter: {}.'.format( 140 | resume_state['epoch'], resume_state['iter'])) 141 | 142 | start_epoch = resume_state['epoch'] 143 | current_step = resume_state['iter'] 144 | model.resume_training(resume_state) # handle optimizers and schedulers 145 | else: 146 | current_step = 0 147 | start_epoch = 0 148 | 149 | #### training 150 | logger.info('Start training from epoch: {:d}, iter: {:d}'.format(start_epoch, current_step)) 151 | for epoch in range(start_epoch, total_epochs + 1): 152 | if opt['dist']: 153 | train_sampler.set_epoch(epoch) 154 | for ii, train_data in enumerate(train_loader): 155 | current_step += 1 156 | if current_step > total_iters: 157 | break 158 | #### update learning rate 159 | if ii > 0: 160 | model.update_learning_rate(current_step, warmup_iter=opt['train']['warmup_iter']) 161 | 162 | #### training 163 | model.feed_data(train_data) 164 | model.optimize_parameters(current_step) 165 | 166 | if rank <= 0: 167 | logs = model.get_current_log() 168 | message = "ep={}, iter={}/{}, lr={:6f}".format(epoch, current_step, total_iters, model.get_current_learning_rate()) 169 | for k, v in logs.items(): 170 | message += ', {:s}={:.5f}'.format(k, v) 171 | # 
tensorboard logger 172 | if opt['use_tb_logger'] and 'debug' not in opt['name']: 173 | tb_logger.add_scalar(k, v, current_step) 174 | print(message + "\r", end=" ") 175 | #### log 176 | # if current_step % opt['logger']['print_freq'] == 0: 177 | # logs = model.get_current_log() 178 | # message = ' '.format( 179 | # epoch, current_step, model.get_current_learning_rate()) 180 | # for k, v in logs.items(): 181 | # message += '{:s}: {:.4e} '.format(k, v) 182 | # # tensorboard logger 183 | # if opt['use_tb_logger'] and 'debug' not in opt['name']: 184 | # if rank <= 0: 185 | # tb_logger.add_scalar(k, v, current_step) 186 | # if rank <= 0: 187 | # logger.info(message) 188 | 189 | # validation 190 | if current_step % opt['train']['val_freq'] == 0 and rank <= 0 and val_loader is not None: 191 | avg_psnr = val_pix_err_f = val_pix_err_nf = val_mean_color_err = 0.0 192 | idx = 0 193 | print("\n\tvalidation: ep={}, curr_step={}".format(epoch, current_step)) 194 | n_test_pick = len(val_set.test_samples) 195 | for ti in range(n_test_pick): 196 | test_sample = val_set.get_test_sample(ti) 197 | img_name = "TEST_" + os.path.splitext(os.path.basename(test_sample['LQ_path'][0]))[0] 198 | img_dir = os.path.join(opt['path']['val_images'], img_name) 199 | util.mkdir(img_dir) 200 | model.feed_data(test_sample, need_GT=False) 201 | model.test() 202 | visuals = model.get_current_visuals(need_GT=False) 203 | sr_img = util.tensor2img(visuals['SR']) # uint8 204 | save_img_path = os.path.join(img_dir, '{:s}_{:d}.png'.format(img_name, current_step)) 205 | util.save_img(sr_img, save_img_path) 206 | 207 | for val_data in val_loader: 208 | if idx >= opt['datasets']['val']['num_val']: 209 | break 210 | idx += 1 211 | img_name = os.path.splitext(os.path.basename(val_data['LQ_path'][0]))[0] 212 | img_dir = os.path.join(opt['path']['val_images'], img_name) 213 | # img_dir = opt['path']['val_images'] 214 | util.mkdir(img_dir) 215 | 216 | model.feed_data(val_data, need_GT=True) 217 | model.test() 218 | 219 | visuals = model.get_current_visuals(need_GT=True) 220 | sr_img = util.tensor2img(visuals['SR']) # uint8 221 | gt_img = util.tensor2img(visuals['GT']) # uint8 222 | 223 | # Save SR images for reference 224 | save_img_path = os.path.join(img_dir, 225 | '{:s}_{:d}.png'.format(img_name, current_step)) 226 | util.save_img(sr_img, save_img_path) 227 | 228 | # # calculate PSNR 229 | crop_size = opt['scale'] 230 | gt_img = gt_img / 255. 231 | sr_img = sr_img / 255. 
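                    # crop `scale` border pixels before computing PSNR, the usual SR evaluation convention
                    # (cf. the crop_border comment in the yml options); boundary pixels are unreliable
                    # after upsampling.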
232 |                     cropped_sr_img = sr_img[crop_size:-crop_size, crop_size:-crop_size, :]
233 |                     cropped_gt_img = gt_img[crop_size:-crop_size, crop_size:-crop_size, :]
234 |                     avg_psnr += util.calculate_psnr(cropped_sr_img * 255, cropped_gt_img * 255)
235 |                 avg_psnr = avg_psnr / idx
236 |                 val_pix_err_f /= idx
237 |                 val_pix_err_nf /= idx
238 |                 val_mean_color_err /= idx
239 | 
240 |                 # log
241 |                 logger.info('# Validation # PSNR: {:.3f}'.format(avg_psnr))
242 |                 logger_val = logging.getLogger('val')  # validation logger
243 |                 logger_val.info('<epoch: {:d}, iter: {:d}> psnr: {:.3f}'.format(
244 |                     epoch, current_step, avg_psnr))
245 |                 # tensorboard logger
246 |                 if opt['use_tb_logger'] and 'debug' not in opt['name']:
247 |                     tb_logger.add_scalar('psnr', avg_psnr, current_step)
248 |                     tb_logger.add_scalar('val_pix_err_f', val_pix_err_f, current_step)
249 |                     tb_logger.add_scalar('val_pix_err_nf', val_pix_err_nf, current_step)
250 |                     tb_logger.add_scalar('val_mean_color_err', val_mean_color_err, current_step)
251 | 
252 |             #### save models and training states
253 |             if current_step % opt['logger']['save_checkpoint_freq'] == 0:
254 |                 if rank <= 0:
255 |                     logger.info('Saving models and training states.')
256 |                     model.save(current_step)
257 |                     model.save_training_state(epoch, current_step)
258 | 
259 |     if rank <= 0:
260 |         logger.info('Saving the final model.')
261 |         model.save('latest')
262 |         logger.info('End of training.')
263 | 
264 | 
265 | if __name__ == '__main__':
266 |     torch.backends.cudnn.deterministic = True
267 |     torch.backends.cudnn.benchmark = False
268 |     num_gpus = torch.cuda.device_count()
269 |     torch.multiprocessing.spawn(main, nprocs=num_gpus, args=(num_gpus, ))
270 |     # main(-1, 2)
271 | 
--------------------------------------------------------------------------------
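Taken together, the yoon/ scripts form a three-stage pipeline: per-image kernel estimation (stage1_kernel.py, via KernelGAN), noise-patch extraction (stage1_noise.py), and GAN fine-tuning with the estimated degradations (train_realsr.py). A typical sequence, sketched under the assumption that the dataset paths above exist, that a KernelGAN checkout is importable for stage1_kernel.py, and that the codes/ modules are on the import path for train_realsr.py:

    python yoon/stage1_kernel.py
    python yoon/stage1_noise.py
    python yoon/train_realsr.py -opt yoon/options/train_df2k.yml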