├── .gitignore ├── LICENSE ├── README.md ├── assets └── result.png ├── config.py ├── data └── README.md ├── dataset.py ├── imgproc.py ├── model.py ├── requirements.txt ├── results └── .gitkeep ├── samples └── .gitkeep ├── scripts ├── augment_dataset.py ├── data_utils.py ├── prepare_dataset.py ├── run.py └── split_train_valid_dataset.py ├── setup.py ├── train.py └── validate.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | 6 | # C extensions 7 | *.so 8 | 9 | # Distribution / packaging 10 | .Python 11 | build/ 12 | develop-eggs/ 13 | dist/ 14 | downloads/ 15 | eggs/ 16 | .eggs/ 17 | lib/ 18 | lib64/ 19 | parts/ 20 | sdist/ 21 | var/ 22 | wheels/ 23 | pip-wheel-metadata/ 24 | share/python-wheels/ 25 | *.egg-info/ 26 | .installed.cfg 27 | *.egg 28 | MANIFEST 29 | 30 | # PyInstaller 31 | # Usually these files are written by a python script from a template 32 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 33 | *.manifest 34 | *.spec 35 | 36 | # Installer logs 37 | pip-log.txt 38 | pip-delete-this-directory.txt 39 | 40 | # Unit test / coverage reports 41 | htmlcov/ 42 | .tox/ 43 | .nox/ 44 | .coverage 45 | .coverage.* 46 | .cache 47 | nosetests.xml 48 | coverage.xml 49 | *.cover 50 | .hypothesis/ 51 | .pytest_cache/ 52 | 53 | # Translations 54 | *.mo 55 | *.pot 56 | 57 | # Django stuff: 58 | *.log 59 | local_settings.py 60 | db.sqlite3 61 | db.sqlite3-journal 62 | 63 | # Flask stuff: 64 | instance/ 65 | .webassets-cache 66 | 67 | # Scrapy stuff: 68 | .scrapy 69 | 70 | # Sphinx documentation 71 | docs/_build/ 72 | 73 | # PyBuilder 74 | target/ 75 | 76 | # Jupyter Notebook 77 | .ipynb_checkpoints 78 | 79 | # IPython 80 | profile_default/ 81 | ipython_config.py 82 | 83 | # pyenv 84 | .python-version 85 | 86 | # pipenv 87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 90 | # install all needed dependencies. 91 | #Pipfile.lock 92 | 93 | # celery beat schedule file 94 | celerybeat-schedule 95 | 96 | # SageMath parsed files 97 | *.sage.py 98 | 99 | # Environments 100 | .env 101 | .venv 102 | env/ 103 | venv/ 104 | ENV/ 105 | env.bak/ 106 | venv.bak/ 107 | 108 | # Spyder project settings 109 | .spyderproject 110 | .spyproject 111 | 112 | # Rope project settings 113 | .ropeproject 114 | 115 | # mkdocs documentation 116 | /site 117 | # mypy 118 | .mypy_cache/ 119 | .dmypy.json 120 | dmypy.json 121 | 122 | # Pyre type checker 123 | .pyre/ 124 | 125 | # custom 126 | .idea 127 | .vscode 128 | 129 | # Mac configure file. 130 | .DS_Store 131 | 132 | # Program run create directory. 133 | data 134 | results 135 | samples 136 | 137 | # Program run create file. 138 | *.bmp 139 | *.png 140 | *.mp4 141 | *.zip 142 | *.csv 143 | *.pth 144 | *.mdb 145 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 
8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. 
Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 
122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 
179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # FSRCNN-PyTorch 2 | 3 | ## Overview 4 | 5 | This repository contains an op-for-op PyTorch reimplementation of [Accelerating the Super-Resolution Convolutional Neural Network](https://arxiv.org/abs/1608.00367v1). 6 | 7 | ## Table of contents 8 | 9 | - [FSRCNN-PyTorch](#fsrcnn-pytorch) 10 | - [Overview](#overview) 11 | - [Table of contents](#table-of-contents) 12 | - [About Accelerating the Super-Resolution Convolutional Neural Network](#about-accelerating-the-super-resolution-convolutional-neural-network) 13 | - [Download weights](#download-weights) 14 | - [Download datasets](#download-datasets) 15 | - [Test](#test) 16 | - [Train](#train) 17 | - [Result](#result) 18 | - [Credit](#credit) 19 | - [Accelerating the Super-Resolution Convolutional Neural Network](#accelerating-the-super-resolution-convolutional-neural-network) 20 | 21 | ## About Accelerating the Super-Resolution Convolutional Neural Network 22 | 23 | If you're new to FSRCNN, here's an abstract straight from the paper: 24 | 25 | As a successful deep model applied in image super-resolution (SR), the Super-Resolution Convolutional Neural Network ( 26 | SRCNN) has demonstrated superior performance to the previous hand-crafted models either in speed and restoration quality. However, the high 27 | computational cost still hinders it from practical usage that demands real-time performance ( 28 | 24 fps). In this paper, we aim at accelerating the current SRCNN, and propose a compact hourglass-shape CNN structure for faster and better SR. We 29 | re-design the SRCNN structure mainly in three aspects. First, we introduce a deconvolution layer at the end of the network, then the mapping is 30 | learned directly from the original low-resolution image (without interpolation) to the high-resolution one. Second, we reformulate the mapping layer 31 | by shrinking the input feature dimension before mapping and expanding back afterwards. Third, we adopt smaller filter sizes but more mapping layers. 32 | The proposed model achieves a speed up of more than 40 times with even superior restoration quality. 
Further, we present the parameter settings that 33 | can achieve real-time performance on a generic CPU while still maintaining good performance. A corresponding transfer strategy is also proposed for 34 | fast training and testing across different upscaling factors. 35 | 36 | ## Download weights 37 | 38 | - [Google Drive](https://drive.google.com/drive/folders/17ju2HN7Y6pyPK2CC_AqnAfTOe9_3hCQ8?usp=sharing) 39 | - [Baidu Drive](https://pan.baidu.com/s/1yNs4rqIb004-NKEdKBJtYg?pwd=llot) 40 | 41 | ## Download datasets 42 | 43 | Contains DIV2K, DIV8K, Flickr2K, OST, T91, Set5, Set14, BSDS100, BSDS200, and more. 44 | 45 | - [Google Drive](https://drive.google.com/drive/folders/1A6lzGeQrFMxPqJehK9s37ce-tPDj20mD?usp=sharing) 46 | - [Baidu Drive](https://pan.baidu.com/s/1o-8Ty_7q6DiS3ykLU09IVg?pwd=llot) 47 | 48 | ## Test 49 | 50 | Modify the contents of the `config.py` file as follows. 51 | 52 | - line 29: set `upscale_factor` to the magnification factor you need. 53 | - line 31: set `mode` to `"valid"`. 54 | - line 67: set `model_path` to the path of the weights produced by training. 55 | 56 | ## Train 57 | 58 | Modify the contents of the `config.py` file as follows. 59 | 60 | - line 29: set `upscale_factor` to the magnification factor you need. 61 | - line 31: set `mode` to `"train"`. 62 | 63 | If you want to resume from weights you trained before, also modify the following. 64 | 65 | - line 47: set `start_epoch` to the number of epochs already completed in the previous run. 66 | - line 48: set `resume` to the path of the weights to load. 67 | 68 | ## Result 69 | 70 | Source of original paper results: https://arxiv.org/pdf/1608.00367v1.pdf 71 | 72 | In the following table, the value in `()` indicates the result reproduced by this project, and `-` indicates no test. 73 | 74 | | Dataset | Scale | PSNR | 75 | |:-------:|:-----:|:----------------:| 76 | | Set5 | 2 | 36.94(**37.09**) | 77 | | Set5 | 3 | 33.06(**33.06**) | 78 | | Set5 | 4 | 30.55(**30.66**) | 79 | 80 | Low Resolution / Super Resolution / High Resolution 81 | 82 | 83 | ## Credit 84 | 85 | ### Accelerating the Super-Resolution Convolutional Neural Network 86 | 87 | _Chao Dong, Chen Change Loy, Xiaoou Tang_
88 | 89 | **Abstract**
90 | As a successful deep model applied in image super-resolution (SR), the Super-Resolution Convolutional Neural Network ( 91 | SRCNN) has demonstrated superior performance to the previous hand-crafted models either in speed and restoration quality. However, the high 92 | computational cost still hinders it from practical usage that demands real-time performance ( 93 | 24 fps). In this paper, we aim at accelerating the current SRCNN, and propose a compact hourglass-shape CNN structure for faster and better SR. We 94 | re-design the SRCNN structure mainly in three aspects. First, we introduce a deconvolution layer at the end of the network, then the mapping is 95 | learned directly from the original low-resolution image (without interpolation) to the high-resolution one. Second, we reformulate the mapping layer 96 | by shrinking the input feature dimension before mapping and expanding back afterwards. Third, we adopt smaller filter sizes but more mapping layers. 97 | The proposed model achieves a speed up of more than 40 times with even superior restoration quality. Further, we present the parameter settings that 98 | can achieve real-time performance on a generic CPU while still maintaining good performance. A corresponding transfer strategy is also proposed for 99 | fast training and testing across different upscaling factors. 100 | 101 | [[Paper]](https://arxiv.org/pdf/1608.00367v1.pdf) [[Author's implements(Caffe)]](https://drive.google.com/open?id=0B7tU5Pj1dfCMWjhhaE1HR3dqcGs) 102 | 103 | ```bibtex 104 | @article{DBLP:journals/corr/DongLT16, 105 | author = {Chao Dong and 106 | Chen Change Loy and 107 | Xiaoou Tang}, 108 | title = {Accelerating the Super-Resolution Convolutional Neural Network}, 109 | journal = {CoRR}, 110 | volume = {abs/1608.00367}, 111 | year = {2016}, 112 | url = {http://arxiv.org/abs/1608.00367}, 113 | eprinttype = {arXiv}, 114 | eprint = {1608.00367}, 115 | timestamp = {Mon, 13 Aug 2018 16:47:56 +0200}, 116 | biburl = {https://dblp.org/rec/journals/corr/DongLT16.bib}, 117 | bibsource = {dblp computer science bibliography, https://dblp.org} 118 | } 119 | ``` 120 | -------------------------------------------------------------------------------- /assets/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lornatang/FSRCNN-PyTorch/e6f8d53775d93249292fb80eed0af940aa7710c5/assets/result.png -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | import random 15 | 16 | import numpy as np 17 | import torch 18 | from torch.backends import cudnn 19 | 20 | # Random seed to maintain reproducible results 21 | random.seed(0) 22 | torch.manual_seed(0) 23 | np.random.seed(0) 24 | # Use GPU for training by default 25 | device = torch.device("cuda", 0) 26 | # Turning on when the image size does not change during training can speed up training 27 | cudnn.benchmark = True 28 | # Image magnification factor 29 | upscale_factor = 2 30 | # Current configuration parameter method 31 | mode = "train" 32 | # Experiment name, easy to save weights and log files 33 | exp_name = "fsrcnn_x2" 34 | 35 | if mode == "train": 36 | # Dataset 37 | train_image_dir = f"data/T91/FSRCNN/train" 38 | valid_image_dir = f"data/T91/FSRCNN/valid" 39 | test_lr_image_dir = f"data/Set5/LRbicx{upscale_factor}" 40 | test_hr_image_dir = f"data/Set5/GTmod12" 41 | 42 | image_size = 20 43 | batch_size = 16 44 | num_workers = 4 45 | 46 | # Incremental training and migration training 47 | start_epoch = 0 48 | resume = "" 49 | 50 | # Total number of epochs 51 | epochs = 3000 52 | 53 | # SGD optimizer parameter 54 | model_lr = 1e-3 55 | model_momentum = 0.9 56 | model_weight_decay = 1e-4 57 | model_nesterov = False 58 | 59 | print_frequency = 200 60 | 61 | if mode == "valid": 62 | # Test data address 63 | lr_dir = f"data/Set5/LRbicx{upscale_factor}" 64 | sr_dir = f"results/test/{exp_name}" 65 | hr_dir = f"data/Set5/GTmod12" 66 | 67 | model_path = f"results/{exp_name}/best.pth.tar" 68 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Usage 2 | 3 | ## Download datasets 4 | 5 | Contains DIV2K, DIV8K, Flickr2K, OST, T91, Set5, Set14, BSDS100 and BSDS200, etc. 6 | 7 | - [Google Driver](https://drive.google.com/drive/folders/1A6lzGeQrFMxPqJehK9s37ce-tPDj20mD?usp=sharing) 8 | - [Baidu Driver](https://pan.baidu.com/s/1o-8Ty_7q6DiS3ykLU09IVg?pwd=llot) 9 | 10 | ## Train dataset struct information 11 | 12 | ```text 13 | - T91 14 | - FSRCNN 15 | - train 16 | - valid 17 | ``` 18 | 19 | ## Test dataset struct information 20 | 21 | ```text 22 | - Set5 23 | - GTmod12 24 | - baby.png 25 | - bird.png 26 | - ... 27 | - LRbicx4 28 | - baby.png 29 | - bird.png 30 | - ... 31 | - Set14 32 | - GTmod12 33 | - baboon.png 34 | - barbara.png 35 | - ... 36 | - LRbicx4 37 | - baboon.png 38 | - barbara.png 39 | - ... 40 | ``` 41 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | """Realize the function of dataset preparation.""" 15 | import os 16 | import queue 17 | import threading 18 | 19 | import cv2 20 | import numpy as np 21 | import torch 22 | from torch.utils.data import Dataset, DataLoader 23 | 24 | import imgproc 25 | 26 | __all__ = [ 27 | "TrainValidImageDataset", "TestImageDataset", 28 | "PrefetchGenerator", "PrefetchDataLoader", "CPUPrefetcher", "CUDAPrefetcher", 29 | ] 30 | 31 | 32 | class TrainValidImageDataset(Dataset): 33 | """Customize the data set loading function and prepare low/high resolution image data in advance. 34 | 35 | Args: 36 | image_dir (str): Train/Valid dataset address. 37 | image_size (int): High resolution image size. 38 | upscale_factor (int): Image up scale factor. 39 | mode (str): Data set loading method, the training data set is for data enhancement, and the verification data set is not for data enhancement. 40 | """ 41 | 42 | def __init__(self, image_dir: str, image_size: int, upscale_factor: int, mode: str) -> None: 43 | super(TrainValidImageDataset, self).__init__() 44 | # Get all image file names in folder 45 | self.image_file_names = [os.path.join(image_dir, image_file_name) for image_file_name in os.listdir(image_dir)] 46 | # Specify the high-resolution image size, with equal length and width 47 | self.image_size = image_size 48 | # How many times the high-resolution image is the low-resolution image 49 | self.upscale_factor = upscale_factor 50 | # Load training dataset or test dataset 51 | self.mode = mode 52 | 53 | def __getitem__(self, batch_index: int) -> [torch.Tensor, torch.Tensor]: 54 | # Read a batch of image data 55 | hr_image = cv2.imread(self.image_file_names[batch_index], cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. 56 | # Use high-resolution image to make low-resolution image 57 | lr_image = imgproc.imresize(hr_image, 1 / self.upscale_factor) 58 | 59 | if self.mode == "Train": 60 | # Data augment 61 | lr_image, hr_image = imgproc.random_crop(lr_image, hr_image, self.image_size, self.upscale_factor) 62 | lr_image, hr_image = imgproc.random_rotate(lr_image, hr_image, angles=[0, 90, 180, 270]) 63 | elif self.mode == "Valid": 64 | lr_image, hr_image = imgproc.center_crop(lr_image, hr_image, self.image_size, self.upscale_factor) 65 | else: 66 | raise ValueError("Unsupported data processing model, please use `Train` or `Valid`.") 67 | 68 | # Only extract the image data of the Y channel 69 | lr_y_image = imgproc.bgr2ycbcr(lr_image, use_y_channel=True) 70 | hr_y_image = imgproc.bgr2ycbcr(hr_image, use_y_channel=True) 71 | 72 | # Convert image data into Tensor stream format (PyTorch). 73 | # Note: The range of input and output is between [0, 1] 74 | lr_y_tensor = imgproc.image2tensor(lr_y_image, range_norm=False, half=False) 75 | hr_y_tensor = imgproc.image2tensor(hr_y_image, range_norm=False, half=False) 76 | 77 | return {"lr": lr_y_tensor, "hr": hr_y_tensor} 78 | 79 | def __len__(self) -> int: 80 | return len(self.image_file_names) 81 | 82 | 83 | class TestImageDataset(Dataset): 84 | """Define Test dataset loading methods. 85 | 86 | Args: 87 | test_lr_image_dir (str): Test dataset address for low resolution image dir. 88 | test_hr_image_dir (str): Test dataset address for high resolution image dir. 89 | upscale_factor (int): Image up scale factor. 
90 | """ 91 | 92 | def __init__(self, test_lr_image_dir: str, test_hr_image_dir: str, upscale_factor: int) -> None: 93 | super(TestImageDataset, self).__init__() 94 | # Get all image file names in folder 95 | self.lr_image_file_names = [os.path.join(test_lr_image_dir, x) for x in os.listdir(test_lr_image_dir)] 96 | self.hr_image_file_names = [os.path.join(test_hr_image_dir, x) for x in os.listdir(test_hr_image_dir)] 97 | # How many times the high-resolution image is the low-resolution image 98 | self.upscale_factor = upscale_factor 99 | 100 | def __getitem__(self, batch_index: int) -> [torch.Tensor, torch.Tensor]: 101 | # Read a batch of image data 102 | lr_image = cv2.imread(self.lr_image_file_names[batch_index], cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. 103 | hr_image = cv2.imread(self.hr_image_file_names[batch_index], cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. 104 | 105 | # Only extract the image data of the Y channel 106 | lr_y_image = imgproc.bgr2ycbcr(lr_image, use_y_channel=True) 107 | hr_y_image = imgproc.bgr2ycbcr(hr_image, use_y_channel=True) 108 | 109 | # Convert image data into Tensor stream format (PyTorch). 110 | # Note: The range of input and output is between [0, 1] 111 | lr_y_tensor = imgproc.image2tensor(lr_y_image, range_norm=False, half=False) 112 | hr_y_tensor = imgproc.image2tensor(hr_y_image, range_norm=False, half=False) 113 | 114 | return {"lr": lr_y_tensor, "hr": hr_y_tensor} 115 | 116 | def __len__(self) -> int: 117 | return len(self.lr_image_file_names) 118 | 119 | 120 | class PrefetchGenerator(threading.Thread): 121 | """A fast data prefetch generator. 122 | 123 | Args: 124 | generator: Data generator. 125 | num_data_prefetch_queue (int): How many early data load queues. 126 | """ 127 | 128 | def __init__(self, generator, num_data_prefetch_queue: int) -> None: 129 | threading.Thread.__init__(self) 130 | self.queue = queue.Queue(num_data_prefetch_queue) 131 | self.generator = generator 132 | self.daemon = True 133 | self.start() 134 | 135 | def run(self) -> None: 136 | for item in self.generator: 137 | self.queue.put(item) 138 | self.queue.put(None) 139 | 140 | def __next__(self): 141 | next_item = self.queue.get() 142 | if next_item is None: 143 | raise StopIteration 144 | return next_item 145 | 146 | def __iter__(self): 147 | return self 148 | 149 | 150 | class PrefetchDataLoader(DataLoader): 151 | """A fast data prefetch dataloader. 152 | 153 | Args: 154 | num_data_prefetch_queue (int): How many early data load queues. 155 | kwargs (dict): Other extended parameters. 156 | """ 157 | 158 | def __init__(self, num_data_prefetch_queue: int, **kwargs) -> None: 159 | self.num_data_prefetch_queue = num_data_prefetch_queue 160 | super(PrefetchDataLoader, self).__init__(**kwargs) 161 | 162 | def __iter__(self): 163 | return PrefetchGenerator(super().__iter__(), self.num_data_prefetch_queue) 164 | 165 | 166 | class CPUPrefetcher: 167 | """Use the CPU side to accelerate data reading. 168 | 169 | Args: 170 | dataloader (DataLoader): Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. 
171 | """ 172 | 173 | def __init__(self, dataloader) -> None: 174 | self.original_dataloader = dataloader 175 | self.data = iter(dataloader) 176 | 177 | def next(self): 178 | try: 179 | return next(self.data) 180 | except StopIteration: 181 | return None 182 | 183 | def reset(self): 184 | self.data = iter(self.original_dataloader) 185 | 186 | def __len__(self) -> int: 187 | return len(self.original_dataloader) 188 | 189 | 190 | class CUDAPrefetcher: 191 | """Use the CUDA side to accelerate data reading. 192 | 193 | Args: 194 | dataloader (DataLoader): Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset. 195 | device (torch.device): Specify running device. 196 | """ 197 | 198 | def __init__(self, dataloader, device: torch.device): 199 | self.batch_data = None 200 | self.original_dataloader = dataloader 201 | self.device = device 202 | 203 | self.data = iter(dataloader) 204 | self.stream = torch.cuda.Stream() 205 | self.preload() 206 | 207 | def preload(self): 208 | try: 209 | self.batch_data = next(self.data) 210 | except StopIteration: 211 | self.batch_data = None 212 | return None 213 | 214 | with torch.cuda.stream(self.stream): 215 | for k, v in self.batch_data.items(): 216 | if torch.is_tensor(v): 217 | self.batch_data[k] = self.batch_data[k].to(self.device, non_blocking=True) 218 | 219 | def next(self): 220 | torch.cuda.current_stream().wait_stream(self.stream) 221 | batch_data = self.batch_data 222 | self.preload() 223 | return batch_data 224 | 225 | def reset(self): 226 | self.data = iter(self.original_dataloader) 227 | self.preload() 228 | 229 | def __len__(self) -> int: 230 | return len(self.original_dataloader) 231 | -------------------------------------------------------------------------------- /imgproc.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | """Realize the function of processing the dataset before training.""" 15 | import math 16 | import random 17 | from typing import Any 18 | 19 | import cv2 20 | import numpy as np 21 | import torch 22 | from torchvision.transforms import functional as F 23 | 24 | __all__ = [ 25 | "image2tensor", "tensor2image", 26 | "rgb2ycbcr", "bgr2ycbcr", "ycbcr2bgr", "ycbcr2rgb", 27 | "center_crop", "random_crop", "random_rotate", "random_horizontally_flip", "random_vertically_flip", 28 | ] 29 | 30 | 31 | def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor: 32 | """Convert ``PIL.Image`` to Tensor. 33 | 34 | Args: 35 | image (np.ndarray): The image data read by ``PIL.Image`` 36 | range_norm (bool): Scale [0, 1] data to between [-1, 1] 37 | half (bool): Whether to convert torch.float32 similarly to torch.half type. 
38 | 39 | Returns: 40 | Normalized image data 41 | 42 | Examples: 43 | >>> image = cv2.imread("image.bmp", cv2.IMREAD_UNCHANGED).astype(np.float32) / 255. 44 | >>> tensor_image = image2tensor(image, range_norm=False, half=False) 45 | """ 46 | 47 | tensor = F.to_tensor(image) 48 | 49 | if range_norm: 50 | tensor = tensor.mul_(2.0).sub_(1.0) 51 | if half: 52 | tensor = tensor.half() 53 | 54 | return tensor 55 | 56 | 57 | def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any: 58 | """Converts ``torch.Tensor`` to ``PIL.Image``. 59 | 60 | Args: 61 | tensor (torch.Tensor): The image that needs to be converted to ``PIL.Image`` 62 | range_norm (bool): Scale [-1, 1] data to between [0, 1] 63 | half (bool): Whether to convert torch.float32 similarly to torch.half type. 64 | 65 | Returns: 66 | Convert image data to support PIL library 67 | 68 | Examples: 69 | >>> tensor = torch.randn([1, 3, 128, 128]) 70 | >>> image = tensor2image(tensor, range_norm=False, half=False) 71 | """ 72 | 73 | if range_norm: 74 | tensor = tensor.add_(1.0).div_(2.0) 75 | if half: 76 | tensor = tensor.half() 77 | 78 | image = tensor.squeeze_(0).permute(1, 2, 0).mul_(255).clamp_(0, 255).cpu().numpy().astype("uint8") 79 | 80 | return image 81 | 82 | 83 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py` 84 | def cubic(x: Any): 85 | """Implementation of `cubic` function in Matlab under Python language. 86 | 87 | Args: 88 | x: Element vector. 89 | 90 | Returns: 91 | Bicubic interpolation. 92 | """ 93 | 94 | absx = torch.abs(x) 95 | absx2 = absx ** 2 96 | absx3 = absx ** 3 97 | return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ( 98 | ((absx > 1) * (absx <= 2)).type_as(absx)) 99 | 100 | 101 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py` 102 | def calculate_weights_indices(in_length: int, out_length: int, scale: float, kernel_width: int, antialiasing: bool): 103 | """Implementation of `calculate_weights_indices` function in Matlab under Python language. 104 | 105 | Args: 106 | in_length (int): Input length. 107 | out_length (int): Output length. 108 | scale (float): Scale factor. 109 | kernel_width (int): Kernel width. 110 | antialiasing (bool): Whether to apply antialiasing when down-sampling operations. 111 | Caution: Bicubic down-sampling in PIL uses antialiasing by default. 112 | 113 | """ 114 | 115 | if (scale < 1) and antialiasing: 116 | # Use a modified kernel (larger kernel width) to simultaneously 117 | # interpolate and antialiasing 118 | kernel_width = kernel_width / scale 119 | 120 | # Output-space coordinates 121 | x = torch.linspace(1, out_length, out_length) 122 | 123 | # Input-space coordinates. Calculate the inverse mapping such that 0.5 124 | # in output space maps to 0.5 in input space, and 0.5 + scale in output 125 | # space maps to 1.5 in input space. 126 | u = x / scale + 0.5 * (1 - 1 / scale) 127 | 128 | # What is the left-most pixel that can be involved in the computation? 129 | left = torch.floor(u - kernel_width / 2) 130 | 131 | # What is the maximum number of pixels that can be involved in the 132 | # computation? Note: it's OK to use an extra pixel here; if the 133 | # corresponding weights are all zero, it will be eliminated at the end 134 | # of this function. 
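    # Illustrative example (an added note, not from the original source): for a x2
    # down-scaling call, scale = 0.5 with antialiasing, so kernel_width was widened
    # above from 4 to 4 / 0.5 = 8 and p = ceil(8) + 2 = 10 input pixels may
    # contribute to each output pixel.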
135 | p = math.ceil(kernel_width) + 2 136 | 137 | # The indices of the input pixels involved in computing the k-th output 138 | # pixel are in row k of the indices matrix. 139 | indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( 140 | out_length, p) 141 | 142 | # The weights used to compute the k-th output pixel are in row k of the 143 | # weights matrix. 144 | distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices 145 | 146 | # apply cubic kernel 147 | if (scale < 1) and antialiasing: 148 | weights = scale * cubic(distance_to_center * scale) 149 | else: 150 | weights = cubic(distance_to_center) 151 | 152 | # Normalize the weights matrix so that each row sums to 1. 153 | weights_sum = torch.sum(weights, 1).view(out_length, 1) 154 | weights = weights / weights_sum.expand(out_length, p) 155 | 156 | # If a column in weights is all zero, get rid of it. only consider the 157 | # first and last column. 158 | weights_zero_tmp = torch.sum((weights == 0), 0) 159 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): 160 | indices = indices.narrow(1, 1, p - 2) 161 | weights = weights.narrow(1, 1, p - 2) 162 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): 163 | indices = indices.narrow(1, 0, p - 2) 164 | weights = weights.narrow(1, 0, p - 2) 165 | weights = weights.contiguous() 166 | indices = indices.contiguous() 167 | sym_len_s = -indices.min() + 1 168 | sym_len_e = indices.max() - in_length 169 | indices = indices + sym_len_s - 1 170 | return weights, indices, int(sym_len_s), int(sym_len_e) 171 | 172 | 173 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py` 174 | def imresize(image: Any, scale_factor: float, antialiasing: bool = True) -> Any: 175 | """Implementation of `imresize` function in Matlab under Python language. 176 | 177 | Args: 178 | image: The input image. 179 | scale_factor (float): Scale factor. The same scale applies for both height and width. 180 | antialiasing (bool): Whether to apply antialiasing when down-sampling operations. 181 | Caution: Bicubic down-sampling in `PIL` uses antialiasing by default. Default: ``True``. 182 | 183 | Returns: 184 | np.ndarray: Output image with shape (c, h, w), [0, 1] range, w/o round. 
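    Examples:
        >>> # Illustrative usage (assumed), mirroring how dataset.py builds LR images
        >>> hr_image = cv2.imread("image.bmp", cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        >>> lr_image = imresize(hr_image, 1 / 2)  # bicubic x2 down-sampling, antialiasing on by default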
185 | """ 186 | squeeze_flag = False 187 | if type(image).__module__ == np.__name__: # numpy type 188 | numpy_type = True 189 | if image.ndim == 2: 190 | image = image[:, :, None] 191 | squeeze_flag = True 192 | image = torch.from_numpy(image.transpose(2, 0, 1)).float() 193 | else: 194 | numpy_type = False 195 | if image.ndim == 2: 196 | image = image.unsqueeze(0) 197 | squeeze_flag = True 198 | 199 | in_c, in_h, in_w = image.size() 200 | out_h, out_w = math.ceil(in_h * scale_factor), math.ceil(in_w * scale_factor) 201 | kernel_width = 4 202 | 203 | # get weights and indices 204 | weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale_factor, kernel_width, antialiasing) 205 | weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale_factor, kernel_width, antialiasing) 206 | # process H dimension 207 | # symmetric copying 208 | img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) 209 | img_aug.narrow(1, sym_len_hs, in_h).copy_(image) 210 | 211 | sym_patch = image[:, :sym_len_hs, :] 212 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 213 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 214 | img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) 215 | 216 | sym_patch = image[:, -sym_len_he:, :] 217 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 218 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 219 | img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) 220 | 221 | out_1 = torch.FloatTensor(in_c, out_h, in_w) 222 | kernel_width = weights_h.size(1) 223 | for i in range(out_h): 224 | idx = int(indices_h[i][0]) 225 | for j in range(in_c): 226 | out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) 227 | 228 | # process W dimension 229 | # symmetric copying 230 | out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) 231 | out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) 232 | 233 | sym_patch = out_1[:, :, :sym_len_ws] 234 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 235 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 236 | out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) 237 | 238 | sym_patch = out_1[:, :, -sym_len_we:] 239 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 240 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 241 | out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) 242 | 243 | out_2 = torch.FloatTensor(in_c, out_h, out_w) 244 | kernel_width = weights_w.size(1) 245 | for i in range(out_w): 246 | idx = int(indices_w[i][0]) 247 | for j in range(in_c): 248 | out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) 249 | 250 | if squeeze_flag: 251 | out_2 = out_2.squeeze(0) 252 | if numpy_type: 253 | out_2 = out_2.numpy() 254 | if not squeeze_flag: 255 | out_2 = out_2.transpose(1, 2, 0) 256 | 257 | return out_2 258 | 259 | 260 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py` 261 | def rgb2ycbcr(image: np.ndarray, use_y_channel: bool = False) -> np.ndarray: 262 | """Implementation of rgb2ycbcr function in Matlab under Python language. 263 | 264 | Args: 265 | image (np.ndarray): Image input in RGB format. 266 | use_y_channel (bool): Extract Y channel separately. Default: ``False``. 267 | 268 | Returns: 269 | ndarray: YCbCr image array data. 
270 | """ 271 | 272 | if use_y_channel: 273 | image = np.dot(image, [65.481, 128.553, 24.966]) + 16.0 274 | else: 275 | image = np.matmul(image, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128] 276 | 277 | image /= 255. 278 | image = image.astype(np.float32) 279 | 280 | return image 281 | 282 | 283 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py` 284 | def bgr2ycbcr(image: np.ndarray, use_y_channel: bool = False) -> np.ndarray: 285 | """Implementation of bgr2ycbcr function in Matlab under Python language. 286 | 287 | Args: 288 | image (np.ndarray): Image input in BGR format. 289 | use_y_channel (bool): Extract Y channel separately. Default: ``False``. 290 | 291 | Returns: 292 | ndarray: YCbCr image array data. 293 | """ 294 | 295 | if use_y_channel: 296 | image = np.dot(image, [24.966, 128.553, 65.481]) + 16.0 297 | else: 298 | image = np.matmul(image, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128] 299 | 300 | image /= 255. 301 | image = image.astype(np.float32) 302 | 303 | return image 304 | 305 | 306 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py` 307 | def ycbcr2rgb(image: np.ndarray) -> np.ndarray: 308 | """Implementation of ycbcr2rgb function in Matlab under Python language. 309 | 310 | Args: 311 | image (np.ndarray): Image input in YCbCr format. 312 | 313 | Returns: 314 | ndarray: RGB image array data. 315 | """ 316 | 317 | image_dtype = image.dtype 318 | image *= 255. 319 | 320 | image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621], 321 | [0, -0.00153632, 0.00791071], 322 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] 323 | 324 | image /= 255. 325 | image = image.astype(image_dtype) 326 | 327 | return image 328 | 329 | 330 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py` 331 | def ycbcr2bgr(image: np.ndarray) -> np.ndarray: 332 | """Implementation of ycbcr2bgr function in Matlab under Python language. 333 | 334 | Args: 335 | image (np.ndarray): Image input in YCbCr format. 336 | 337 | Returns: 338 | ndarray: BGR image array data. 339 | """ 340 | 341 | image_dtype = image.dtype 342 | image *= 255. 343 | 344 | image = np.matmul(image, [[0.00456621, 0.00456621, 0.00456621], 345 | [0.00791071, -0.00153632, 0], 346 | [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] 347 | 348 | image /= 255. 349 | image = image.astype(image_dtype) 350 | 351 | return image 352 | 353 | 354 | def center_crop(lr_image: np.ndarray, hr_image: np.ndarray, hr_image_size: int, upscale_factor: int) -> [np.ndarray, np.ndarray]: 355 | """Crop small image patches from one image center area. 356 | 357 | Args: 358 | lr_image (np.ndarray): The input low-resolution image for `OpenCV.imread`. 359 | hr_image (np.ndarray): The input high-resolution image for `OpenCV.imread`. 360 | hr_image_size (int): The size of the captured high-resolution image area. 361 | upscale_factor (int): Image up scale factor. 362 | 363 | Returns: 364 | np.ndarray: Small patch images. 
365 | """ 366 | 367 | hr_image_height, hr_image_width = hr_image.shape[:2] 368 | 369 | # Just need to find the top and left coordinates of the image 370 | hr_top = (hr_image_height - hr_image_size) // 2 371 | hr_left = (hr_image_width - hr_image_size) // 2 372 | 373 | # Define the LR image position 374 | lr_top = hr_top // upscale_factor 375 | lr_left = hr_left // upscale_factor 376 | lr_image_size = hr_image_size // upscale_factor 377 | 378 | # Crop image patch 379 | patch_lr_image = lr_image[lr_top:lr_top + lr_image_size, lr_left:lr_left + lr_image_size, ...] 380 | patch_hr_image = hr_image[hr_top:hr_top + hr_image_size, hr_left:hr_left + hr_image_size, ...] 381 | 382 | return patch_lr_image, patch_hr_image 383 | 384 | 385 | def random_crop(lr_image: np.ndarray, hr_image: np.ndarray, hr_image_size: int, upscale_factor: int) -> [np.ndarray, np.ndarray]: 386 | """Crop small image patches from one image. 387 | 388 | Args: 389 | lr_image (np.ndarray): The input low-resolution image for `OpenCV.imread`. 390 | hr_image (np.ndarray): The input high-resolution image for `OpenCV.imread`. 391 | hr_image_size (int): The size of the captured high-resolution image area. 392 | upscale_factor (int): Image up scale factor. 393 | 394 | Returns: 395 | np.ndarray: Small patch images. 396 | """ 397 | 398 | hr_image_height, hr_image_width = hr_image.shape[:2] 399 | 400 | # Just need to find the top and left coordinates of the image 401 | hr_top = random.randint(0, hr_image_height - hr_image_size) 402 | hr_left = random.randint(0, hr_image_width - hr_image_size) 403 | 404 | # Define the LR image position 405 | lr_top = hr_top // upscale_factor 406 | lr_left = hr_left // upscale_factor 407 | lr_image_size = hr_image_size // upscale_factor 408 | 409 | # Crop image patch 410 | patch_lr_image = lr_image[lr_top:lr_top + lr_image_size, lr_left:lr_left + lr_image_size, ...] 411 | patch_hr_image = hr_image[hr_top:hr_top + hr_image_size, hr_left:hr_left + hr_image_size, ...] 412 | 413 | return patch_lr_image, patch_hr_image 414 | 415 | 416 | def random_rotate(lr_image: np.ndarray, hr_image: np.ndarray, angles: list, lr_center=None, hr_center=None, scale_factor: float = 1.0) -> [np.ndarray, np.ndarray]: 417 | """Rotate an image randomly by a specified angle. 418 | 419 | Args: 420 | lr_image (np.ndarray): The input low-resolution image for `OpenCV.imread`. 421 | hr_image (np.ndarray): The input high-resolution image for `OpenCV.imread`. 422 | angles (list): Specify the rotation angle. 423 | lr_center (tuple[int]): Low-resolution image rotation center. If the center is None, initialize it as the center of the image. ``Default: None``. 424 | hr_center (tuple[int]): Low-resolution image rotation center. If the center is None, initialize it as the center of the image. ``Default: None``. 425 | scale_factor (float): scaling factor. Default: 1.0. 426 | 427 | Returns: 428 | np.ndarray: Rotated images. 
429 | """ 430 | 431 | lr_image_height, lr_image_width = lr_image.shape[:2] 432 | hr_image_height, hr_image_width = hr_image.shape[:2] 433 | 434 | if lr_center is None: 435 | lr_center = (lr_image_width // 2, lr_image_height // 2) 436 | if hr_center is None: 437 | hr_center = (hr_image_width // 2, hr_image_height // 2) 438 | 439 | # Random select specific angle 440 | angle = random.choice(angles) 441 | 442 | lr_matrix = cv2.getRotationMatrix2D(lr_center, angle, scale_factor) 443 | hr_matrix = cv2.getRotationMatrix2D(hr_center, angle, scale_factor) 444 | 445 | rotated_lr_image = cv2.warpAffine(lr_image, lr_matrix, (lr_image_width, lr_image_height)) 446 | rotated_hr_image = cv2.warpAffine(hr_image, hr_matrix, (hr_image_width, hr_image_height)) 447 | 448 | return rotated_lr_image, rotated_hr_image 449 | 450 | 451 | def random_horizontally_flip(lr_image: np.ndarray, hr_image: np.ndarray, p=0.5) -> [np.ndarray, np.ndarray]: 452 | """Flip an image horizontally randomly. 453 | 454 | Args: 455 | lr_image (np.ndarray): The input low-resolution image for `OpenCV.imread`. 456 | hr_image (np.ndarray): The input high-resolution image for `OpenCV.imread`. 457 | p (optional, float): rollover probability. (Default: 0.5) 458 | 459 | Returns: 460 | np.ndarray: Horizontally flip images. 461 | """ 462 | 463 | if random.random() < p: 464 | lr_image = cv2.flip(lr_image, 1) 465 | hr_image = cv2.flip(hr_image, 1) 466 | 467 | return lr_image, hr_image 468 | 469 | 470 | def random_vertically_flip(lr_image: np.ndarray, hr_image: np.ndarray, p=0.5) -> [np.ndarray, np.ndarray]: 471 | """Flip an image vertically randomly. 472 | 473 | Args: 474 | lr_image (np.ndarray): The input low-resolution image for `OpenCV.imread`. 475 | hr_image (np.ndarray): The input high-resolution image for `OpenCV.imread`. 476 | p (optional, float): rollover probability. (Default: 0.5) 477 | 478 | Returns: 479 | np.ndarray: Vertically flip images. 480 | """ 481 | 482 | if random.random() < p: 483 | lr_image = cv2.flip(lr_image, 0) 484 | hr_image = cv2.flip(hr_image, 0) 485 | 486 | return lr_image, hr_image 487 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================ 14 | """Realize the model definition function.""" 15 | from math import sqrt 16 | 17 | import torch 18 | from torch import nn 19 | 20 | 21 | class FSRCNN(nn.Module): 22 | """ 23 | 24 | Args: 25 | upscale_factor (int): Image magnification factor. 26 | """ 27 | 28 | def __init__(self, upscale_factor: int) -> None: 29 | super(FSRCNN, self).__init__() 30 | # Feature extraction layer. 31 | self.feature_extraction = nn.Sequential( 32 | nn.Conv2d(1, 56, (5, 5), (1, 1), (2, 2)), 33 | nn.PReLU(56) 34 | ) 35 | 36 | # Shrinking layer. 
37 | self.shrink = nn.Sequential( 38 | nn.Conv2d(56, 12, (1, 1), (1, 1), (0, 0)), 39 | nn.PReLU(12) 40 | ) 41 | 42 | # Mapping layer. 43 | self.map = nn.Sequential( 44 | nn.Conv2d(12, 12, (3, 3), (1, 1), (1, 1)), 45 | nn.PReLU(12), 46 | nn.Conv2d(12, 12, (3, 3), (1, 1), (1, 1)), 47 | nn.PReLU(12), 48 | nn.Conv2d(12, 12, (3, 3), (1, 1), (1, 1)), 49 | nn.PReLU(12), 50 | nn.Conv2d(12, 12, (3, 3), (1, 1), (1, 1)), 51 | nn.PReLU(12) 52 | ) 53 | 54 | # Expanding layer. 55 | self.expand = nn.Sequential( 56 | nn.Conv2d(12, 56, (1, 1), (1, 1), (0, 0)), 57 | nn.PReLU(56) 58 | ) 59 | 60 | # Deconvolution layer. 61 | self.deconv = nn.ConvTranspose2d(56, 1, (9, 9), (upscale_factor, upscale_factor), (4, 4), (upscale_factor - 1, upscale_factor - 1)) 62 | 63 | # Initialize model weights. 64 | self._initialize_weights() 65 | 66 | def forward(self, x: torch.Tensor) -> torch.Tensor: 67 | return self._forward_impl(x) 68 | 69 | # Support torch.script function. 70 | def _forward_impl(self, x: torch.Tensor) -> torch.Tensor: 71 | out = self.feature_extraction(x) 72 | out = self.shrink(out) 73 | out = self.map(out) 74 | out = self.expand(out) 75 | out = self.deconv(out) 76 | 77 | return out 78 | 79 | # The filter weight of each layer is a Gaussian distribution with zero mean and standard deviation initialized by random extraction 0.001 (deviation is 0). 80 | def _initialize_weights(self) -> None: 81 | for m in self.modules(): 82 | if isinstance(m, nn.Conv2d): 83 | nn.init.normal_(m.weight.data, mean=0.0, std=sqrt(2 / (m.out_channels * m.weight.data[0][0].numel()))) 84 | nn.init.zeros_(m.bias.data) 85 | 86 | nn.init.normal_(self.deconv.weight.data, mean=0.0, std=0.001) 87 | nn.init.zeros_(self.deconv.bias.data) 88 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | opencv-python 2 | numpy 3 | torch 4 | tqdm 5 | setuptools 6 | torchvision 7 | Pillow 8 | natsort -------------------------------------------------------------------------------- /results/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /samples/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /scripts/augment_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================== 14 | import argparse 15 | import os 16 | import shutil 17 | from multiprocessing import Pool 18 | 19 | import cv2 20 | from tqdm import tqdm 21 | 22 | import data_utils 23 | 24 | 25 | def main(args) -> None: 26 | if os.path.exists(args.output_dir): 27 | shutil.rmtree(args.output_dir) 28 | os.makedirs(args.output_dir) 29 | 30 | # Get all image paths 31 | image_file_names = os.listdir(args.images_dir) 32 | 33 | # Splitting images with multiple threads 34 | progress_bar = tqdm(total=len(image_file_names), unit="image", desc="Data augment") 35 | workers_pool = Pool(args.num_workers) 36 | for image_file_name in image_file_names: 37 | workers_pool.apply_async(worker, args=(image_file_name, args), callback=lambda arg: progress_bar.update(1)) 38 | workers_pool.close() 39 | workers_pool.join() 40 | progress_bar.close() 41 | 42 | 43 | def worker(image_file_name, args) -> None: 44 | image = cv2.imread(f"{args.images_dir}/{image_file_name}") 45 | 46 | index = 1 47 | # Data augment 48 | for scale_ratio in [1.0, 0.9, 0.8, 0.7, 0.6]: 49 | new_image = data_utils.imresize(image, scale_ratio) 50 | # Save all images 51 | cv2.imwrite(f"{args.output_dir}/{image_file_name.split('.')[-2]}_{index:02d}.{image_file_name.split('.')[-1]}", new_image) 52 | 53 | index += 1 54 | 55 | 56 | if __name__ == "__main__": 57 | parser = argparse.ArgumentParser(description="Prepare database scripts.") 58 | parser.add_argument("--images_dir", type=str, help="Path to input image directory.") 59 | parser.add_argument("--output_dir", type=str, help="Path to generator image directory.") 60 | parser.add_argument("--num_workers", type=int, help="How many threads to open at the same time.") 61 | args = parser.parse_args() 62 | 63 | main(args) 64 | -------------------------------------------------------------------------------- /scripts/data_utils.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | import math 15 | from typing import Any 16 | 17 | import cv2 18 | import numpy as np 19 | import torch 20 | 21 | 22 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py` 23 | def cubic(x: Any): 24 | """Implementation of `cubic` function in Matlab under Python language. 25 | 26 | Args: 27 | x: Element vector. 28 | 29 | Returns: 30 | Bicubic interpolation. 
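    Note:
        The kernel evaluated here is the Keys bicubic kernel with a = -0.5 (the
        same variant that MATLAB's `imresize` uses):

            W(x) = 1.5|x|^3 - 2.5|x|^2 + 1          for |x| <= 1
            W(x) = -0.5|x|^3 + 2.5|x|^2 - 4|x| + 2  for 1 < |x| <= 2
            W(x) = 0                                otherwise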
31 | """ 32 | 33 | absx = torch.abs(x) 34 | absx2 = absx ** 2 35 | absx3 = absx ** 3 36 | return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * ( 37 | ((absx > 1) * (absx <= 2)).type_as(absx)) 38 | 39 | 40 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py` 41 | def calculate_weights_indices(in_length: int, out_length: int, scale: float, kernel_width: int, antialiasing: bool): 42 | """Implementation of `calculate_weights_indices` function in Matlab under Python language. 43 | 44 | Args: 45 | in_length (int): Input length. 46 | out_length (int): Output length. 47 | scale (float): Scale factor. 48 | kernel_width (int): Kernel width. 49 | antialiasing (bool): Whether to apply antialiasing when down-sampling operations. 50 | Caution: Bicubic down-sampling in PIL uses antialiasing by default. 51 | 52 | """ 53 | 54 | if (scale < 1) and antialiasing: 55 | # Use a modified kernel (larger kernel width) to simultaneously 56 | # interpolate and antialiasing 57 | kernel_width = kernel_width / scale 58 | 59 | # Output-space coordinates 60 | x = torch.linspace(1, out_length, out_length) 61 | 62 | # Input-space coordinates. Calculate the inverse mapping such that 0.5 63 | # in output space maps to 0.5 in input space, and 0.5 + scale in output 64 | # space maps to 1.5 in input space. 65 | u = x / scale + 0.5 * (1 - 1 / scale) 66 | 67 | # What is the left-most pixel that can be involved in the computation? 68 | left = torch.floor(u - kernel_width / 2) 69 | 70 | # What is the maximum number of pixels that can be involved in the 71 | # computation? Note: it's OK to use an extra pixel here; if the 72 | # corresponding weights are all zero, it will be eliminated at the end 73 | # of this function. 74 | p = math.ceil(kernel_width) + 2 75 | 76 | # The indices of the input pixels involved in computing the k-th output 77 | # pixel are in row k of the indices matrix. 78 | indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand( 79 | out_length, p) 80 | 81 | # The weights used to compute the k-th output pixel are in row k of the 82 | # weights matrix. 83 | distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices 84 | 85 | # apply cubic kernel 86 | if (scale < 1) and antialiasing: 87 | weights = scale * cubic(distance_to_center * scale) 88 | else: 89 | weights = cubic(distance_to_center) 90 | 91 | # Normalize the weights matrix so that each row sums to 1. 92 | weights_sum = torch.sum(weights, 1).view(out_length, 1) 93 | weights = weights / weights_sum.expand(out_length, p) 94 | 95 | # If a column in weights is all zero, get rid of it. only consider the 96 | # first and last column. 
97 | weights_zero_tmp = torch.sum((weights == 0), 0) 98 | if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6): 99 | indices = indices.narrow(1, 1, p - 2) 100 | weights = weights.narrow(1, 1, p - 2) 101 | if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6): 102 | indices = indices.narrow(1, 0, p - 2) 103 | weights = weights.narrow(1, 0, p - 2) 104 | weights = weights.contiguous() 105 | indices = indices.contiguous() 106 | sym_len_s = -indices.min() + 1 107 | sym_len_e = indices.max() - in_length 108 | indices = indices + sym_len_s - 1 109 | return weights, indices, int(sym_len_s), int(sym_len_e) 110 | 111 | 112 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py` 113 | def imresize(image: Any, scale_factor: float, antialiasing: bool = True) -> Any: 114 | """Implementation of `imresize` function in Matlab under Python language. 115 | 116 | Args: 117 | image: The input image. 118 | scale_factor (float): Scale factor. The same scale applies for both height and width. 119 | antialiasing (bool): Whether to apply antialiasing when down-sampling operations. 120 | Caution: Bicubic down-sampling in `PIL` uses antialiasing by default. Default: ``True``. 121 | 122 | Returns: 123 | np.ndarray: Output image with shape (c, h, w), [0, 1] range, w/o round. 124 | """ 125 | squeeze_flag = False 126 | if type(image).__module__ == np.__name__: # numpy type 127 | numpy_type = True 128 | if image.ndim == 2: 129 | image = image[:, :, None] 130 | squeeze_flag = True 131 | image = torch.from_numpy(image.transpose(2, 0, 1)).float() 132 | else: 133 | numpy_type = False 134 | if image.ndim == 2: 135 | image = image.unsqueeze(0) 136 | squeeze_flag = True 137 | 138 | in_c, in_h, in_w = image.size() 139 | out_h, out_w = math.ceil(in_h * scale_factor), math.ceil(in_w * scale_factor) 140 | kernel_width = 4 141 | 142 | # get weights and indices 143 | weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale_factor, kernel_width, antialiasing) 144 | weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale_factor, kernel_width, antialiasing) 145 | # process H dimension 146 | # symmetric copying 147 | img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w) 148 | img_aug.narrow(1, sym_len_hs, in_h).copy_(image) 149 | 150 | sym_patch = image[:, :sym_len_hs, :] 151 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 152 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 153 | img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv) 154 | 155 | sym_patch = image[:, -sym_len_he:, :] 156 | inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long() 157 | sym_patch_inv = sym_patch.index_select(1, inv_idx) 158 | img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv) 159 | 160 | out_1 = torch.FloatTensor(in_c, out_h, in_w) 161 | kernel_width = weights_h.size(1) 162 | for i in range(out_h): 163 | idx = int(indices_h[i][0]) 164 | for j in range(in_c): 165 | out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i]) 166 | 167 | # process W dimension 168 | # symmetric copying 169 | out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we) 170 | out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1) 171 | 172 | sym_patch = out_1[:, :, :sym_len_ws] 173 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 174 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 175 | out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv) 
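# The index_select above mirrors the left border of `out_1` (symmetric boundary
# extension, as MATLAB's imresize does at image edges); the block below does the
# same for the right border, so edge pixels also see a full set of kernel taps.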
176 | 177 | sym_patch = out_1[:, :, -sym_len_we:] 178 | inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long() 179 | sym_patch_inv = sym_patch.index_select(2, inv_idx) 180 | out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv) 181 | 182 | out_2 = torch.FloatTensor(in_c, out_h, out_w) 183 | kernel_width = weights_w.size(1) 184 | for i in range(out_w): 185 | idx = int(indices_w[i][0]) 186 | for j in range(in_c): 187 | out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i]) 188 | 189 | if squeeze_flag: 190 | out_2 = out_2.squeeze(0) 191 | if numpy_type: 192 | out_2 = out_2.numpy() 193 | if not squeeze_flag: 194 | out_2 = out_2.transpose(1, 2, 0) 195 | 196 | return out_2 197 | -------------------------------------------------------------------------------- /scripts/prepare_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | import argparse 15 | import multiprocessing 16 | import os 17 | import shutil 18 | 19 | import cv2 20 | import numpy as np 21 | from tqdm import tqdm 22 | 23 | 24 | def main(args) -> None: 25 | if os.path.exists(args.output_dir): 26 | shutil.rmtree(args.output_dir) 27 | os.makedirs(args.output_dir) 28 | 29 | # Get all image paths 30 | image_file_names = os.listdir(args.images_dir) 31 | 32 | # Splitting images with multiple threads 33 | progress_bar = tqdm(total=len(image_file_names), unit="image", desc="Prepare split image") 34 | workers_pool = multiprocessing.Pool(args.num_workers) 35 | for image_file_name in image_file_names: 36 | workers_pool.apply_async(worker, args=(image_file_name, args), callback=lambda arg: progress_bar.update(1)) 37 | workers_pool.close() 38 | workers_pool.join() 39 | progress_bar.close() 40 | 41 | 42 | def worker(image_file_name, args) -> None: 43 | image = cv2.imread(f"{args.images_dir}/{image_file_name}", cv2.IMREAD_UNCHANGED) 44 | 45 | image_height, image_width = image.shape[0:2] 46 | 47 | index = 1 48 | if image_height >= args.image_size and image_width >= args.image_size: 49 | for pos_y in range(0, image_height - args.image_size + 1, args.step): 50 | for pos_x in range(0, image_width - args.image_size + 1, args.step): 51 | # Crop 52 | crop_image = image[pos_y: pos_y + args.image_size, pos_x:pos_x + args.image_size, ...] 
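# The two loops above advance in `step`-pixel strides, so each axis yields
# floor((dim - image_size) / step) + 1 crop positions; with the settings used in
# run.py (image_size=32, step=16) a 100x100 source image gives 5 x 5 = 25 patches.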
53 | crop_image = np.ascontiguousarray(crop_image)
54 | # Save the cropped patch
55 | cv2.imwrite(f"{args.output_dir}/{image_file_name.split('.')[-2]}_{index:04d}.{image_file_name.split('.')[-1]}", crop_image)
56 |
57 | index += 1
58 |
59 |
60 | if __name__ == "__main__":
61 | parser = argparse.ArgumentParser(description="Prepare dataset script.")
62 | parser.add_argument("--images_dir", type=str, help="Path to input image directory.")
63 | parser.add_argument("--output_dir", type=str, help="Path to output image directory.")
64 | parser.add_argument("--image_size", type=int, help="Low-resolution image size cropped from the raw image.")
65 | parser.add_argument("--step", type=int, help="Sliding-window step size between crops.")
66 | parser.add_argument("--num_workers", type=int, help="Number of worker processes to use.")
67 | args = parser.parse_args()
68 |
69 | main(args)
70 |
--------------------------------------------------------------------------------
/scripts/run.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | # Prepare dataset
4 | os.system("python ./augment_dataset.py --images_dir ../data/T91/original --output_dir ../data/T91/FSRCNN/original --num_workers 10")
5 | os.system("python ./prepare_dataset.py --images_dir ../data/T91/FSRCNN/original --output_dir ../data/T91/FSRCNN/train --image_size 32 --step 16 --num_workers 10")
6 |
7 | # Split train and valid
8 | os.system("python ./split_train_valid_dataset.py --train_images_dir ../data/T91/FSRCNN/train --valid_images_dir ../data/T91/FSRCNN/valid --valid_samples_ratio 0.1")
9 |
--------------------------------------------------------------------------------
/scripts/split_train_valid_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ============================================================================== 14 | import argparse 15 | import os 16 | import random 17 | import shutil 18 | 19 | from tqdm import tqdm 20 | 21 | 22 | def main(args) -> None: 23 | if not os.path.exists(args.train_images_dir): 24 | os.makedirs(args.train_images_dir) 25 | if not os.path.exists(args.valid_images_dir): 26 | os.makedirs(args.valid_images_dir) 27 | 28 | train_files = os.listdir(args.train_images_dir) 29 | valid_files = random.sample(train_files, int(len(train_files) * args.valid_samples_ratio)) 30 | 31 | process_bar = tqdm(valid_files, total=len(valid_files), unit="image", desc="Split train/valid dataset") 32 | 33 | for image_file_name in process_bar: 34 | shutil.copyfile(f"{args.train_images_dir}/{image_file_name}", f"{args.valid_images_dir}/{image_file_name}") 35 | 36 | 37 | if __name__ == "__main__": 38 | parser = argparse.ArgumentParser(description="Split train and valid dataset scripts.") 39 | parser.add_argument("--train_images_dir", type=str, help="Path to train image directory.") 40 | parser.add_argument("--valid_images_dir", type=str, help="Path to valid image directory.") 41 | parser.add_argument("--valid_samples_ratio", type=float, help="What percentage of the data is extracted from the training set into the validation set.") 42 | args = parser.parse_args() 43 | 44 | main(args) 45 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | import io 15 | import os 16 | import sys 17 | from shutil import rmtree 18 | 19 | from setuptools import Command 20 | from setuptools import find_packages 21 | from setuptools import setup 22 | 23 | # Configure library params. 24 | NAME = "fsrcnn_pytorch" 25 | DESCRIPTION = "Accelerating the Super-Resolution Convolutional Neural Network." 26 | URL = "https://github.com/Lornatang/FSRCNN-PyTorch" 27 | EMAIL = "liu_changyu@dakewe.com" 28 | AUTHOR = "Liu Goodfellow" 29 | REQUIRES_PYTHON = ">=3.8.0" 30 | VERSION = "1.2.2" 31 | 32 | # Libraries that must be installed. 33 | REQUIRED = ["torch"] 34 | 35 | # The following libraries directory need to be installed if you need to run all scripts. 36 | EXTRAS = {} 37 | 38 | # Find the current running location. 39 | here = os.path.abspath(os.path.dirname(__file__)) 40 | 41 | # About README file description. 42 | try: 43 | with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f: 44 | long_description = "\n" + f.read() 45 | except FileNotFoundError: 46 | long_description = DESCRIPTION 47 | 48 | # Set Current Library Version. 
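# VERSION is set above, so the fallback below (reading fsrcnn_pytorch/__version__.py)
# is not exercised; it only applies when VERSION is left empty.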
49 | about = {} 50 | if not VERSION: 51 | project_slug = NAME.lower().replace("-", "_").replace(" ", "_") 52 | with open(os.path.join(here, project_slug, "__version__.py")) as f: 53 | exec(f.read(), about) 54 | else: 55 | about["__version__"] = VERSION 56 | 57 | 58 | class UploadCommand(Command): 59 | description = "Build and publish the package." 60 | user_options = [] 61 | 62 | @staticmethod 63 | def status(s): 64 | print("\033[1m{0}\033[0m".format(s)) 65 | 66 | def initialize_options(self): 67 | pass 68 | 69 | def finalize_options(self): 70 | pass 71 | 72 | def run(self): 73 | try: 74 | self.status("Removing previous builds…") 75 | rmtree(os.path.join(here, "dist")) 76 | except OSError: 77 | pass 78 | 79 | self.status("Building Source and Wheel (universal) distribution…") 80 | os.system("{0} setup.py sdist bdist_wheel --universal".format(sys.executable)) 81 | 82 | self.status("Uploading the package to PyPI via Twine…") 83 | os.system("twine upload dist/*") 84 | 85 | self.status("Pushing git tags…") 86 | os.system("git tag v{0}".format(about["__version__"])) 87 | os.system("git push --tags") 88 | 89 | sys.exit() 90 | 91 | 92 | setup(name=NAME, 93 | version=about["__version__"], 94 | description=DESCRIPTION, 95 | long_description=long_description, 96 | long_description_content_type="text/markdown", 97 | author=AUTHOR, 98 | author_email=EMAIL, 99 | python_requires=REQUIRES_PYTHON, 100 | url=URL, 101 | packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]), 102 | install_requires=REQUIRED, 103 | extras_require=EXTRAS, 104 | include_package_data=True, 105 | license="Apache", 106 | classifiers=[ 107 | # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers 108 | "License :: OSI Approved :: Apache Software License", 109 | "Programming Language :: Python :: 3 :: Only" 110 | ], 111 | cmdclass={ 112 | "upload": UploadCommand, 113 | }, 114 | ) 115 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 
13 | # ============================================================================ 14 | """File description: Realize the model training function.""" 15 | import os 16 | import shutil 17 | import time 18 | from enum import Enum 19 | 20 | import torch 21 | from torch import nn 22 | from torch import optim 23 | from torch.cuda import amp 24 | from torch.utils.data import DataLoader 25 | from torch.utils.tensorboard import SummaryWriter 26 | 27 | import config 28 | from dataset import CUDAPrefetcher 29 | from dataset import TrainValidImageDataset, TestImageDataset 30 | from model import FSRCNN 31 | 32 | 33 | def main(): 34 | # Initialize training to generate network evaluation indicators 35 | best_psnr = 0.0 36 | 37 | train_prefetcher, valid_prefetcher, test_prefetcher = load_dataset() 38 | print("Load train dataset and valid dataset successfully.") 39 | 40 | model = build_model() 41 | print("Build FSRCNN model successfully.") 42 | 43 | psnr_criterion, pixel_criterion = define_loss() 44 | print("Define all loss functions successfully.") 45 | 46 | optimizer = define_optimizer(model) 47 | print("Define all optimizer functions successfully.") 48 | 49 | print("Check whether the pretrained model is restored...") 50 | if config.resume: 51 | # Load checkpoint model 52 | checkpoint = torch.load(config.resume, map_location=lambda storage, loc: storage) 53 | # Restore the parameters in the training node to this point 54 | config.start_epoch = checkpoint["epoch"] 55 | best_psnr = checkpoint["best_psnr"] 56 | # Load checkpoint state dict. Extract the fitted model weights 57 | model_state_dict = model.state_dict() 58 | new_state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict} 59 | # Overwrite the pretrained model weights to the current model 60 | model_state_dict.update(new_state_dict) 61 | model.load_state_dict(model_state_dict) 62 | # Load the optimizer model 63 | optimizer.load_state_dict(checkpoint["optimizer"]) 64 | # Load the scheduler model 65 | # scheduler.load_state_dict(checkpoint["scheduler"]) 66 | print("Loaded pretrained model weights.") 67 | 68 | # Create a folder of super-resolution experiment results 69 | samples_dir = os.path.join("samples", config.exp_name) 70 | results_dir = os.path.join("results", config.exp_name) 71 | if not os.path.exists(samples_dir): 72 | os.makedirs(samples_dir) 73 | if not os.path.exists(results_dir): 74 | os.makedirs(results_dir) 75 | 76 | # Create training process log file 77 | writer = SummaryWriter(os.path.join("samples", "logs", config.exp_name)) 78 | 79 | # Initialize the gradient scaler 80 | scaler = amp.GradScaler() 81 | 82 | for epoch in range(config.start_epoch, config.epochs): 83 | train(model, train_prefetcher, psnr_criterion, pixel_criterion, optimizer, epoch, scaler, writer) 84 | _ = validate(model, valid_prefetcher, psnr_criterion, epoch, writer, "Valid") 85 | psnr = validate(model, test_prefetcher, psnr_criterion, epoch, writer, "Test") 86 | print("\n") 87 | 88 | # Automatically save the model with the highest index 89 | is_best = psnr > best_psnr 90 | best_psnr = max(psnr, best_psnr) 91 | torch.save({"epoch": epoch + 1, 92 | "best_psnr": best_psnr, 93 | "state_dict": model.state_dict(), 94 | "optimizer": optimizer.state_dict(), 95 | "scheduler": None}, 96 | os.path.join(samples_dir, f"epoch_{epoch + 1}.pth.tar")) 97 | if is_best: 98 | shutil.copyfile(os.path.join(samples_dir, f"epoch_{epoch + 1}.pth.tar"), os.path.join(results_dir, "best.pth.tar")) 99 | if (epoch + 1) == config.epochs: 100 | 
shutil.copyfile(os.path.join(samples_dir, f"epoch_{epoch + 1}.pth.tar"), os.path.join(results_dir, "last.pth.tar")) 101 | 102 | 103 | def load_dataset() -> [CUDAPrefetcher, CUDAPrefetcher, CUDAPrefetcher]: 104 | # Load train, test and valid datasets 105 | train_datasets = TrainValidImageDataset(config.train_image_dir, config.image_size, config.upscale_factor, "Train") 106 | valid_datasets = TrainValidImageDataset(config.valid_image_dir, config.image_size, config.upscale_factor, "Valid") 107 | test_datasets = TestImageDataset(config.test_lr_image_dir, config.test_hr_image_dir, config.upscale_factor) 108 | 109 | # Generator all dataloader 110 | train_dataloader = DataLoader(train_datasets, 111 | batch_size=config.batch_size, 112 | shuffle=True, 113 | num_workers=config.num_workers, 114 | pin_memory=True, 115 | drop_last=True, 116 | persistent_workers=True) 117 | valid_dataloader = DataLoader(valid_datasets, 118 | batch_size=config.batch_size, 119 | shuffle=False, 120 | num_workers=config.num_workers, 121 | pin_memory=True, 122 | drop_last=False, 123 | persistent_workers=True) 124 | test_dataloader = DataLoader(test_datasets, 125 | batch_size=1, 126 | shuffle=False, 127 | num_workers=1, 128 | pin_memory=True, 129 | drop_last=False, 130 | persistent_workers=False) 131 | 132 | # Place all data on the preprocessing data loader 133 | train_prefetcher = CUDAPrefetcher(train_dataloader, config.device) 134 | valid_prefetcher = CUDAPrefetcher(valid_dataloader, config.device) 135 | test_prefetcher = CUDAPrefetcher(test_dataloader, config.device) 136 | 137 | return train_prefetcher, valid_prefetcher, test_prefetcher 138 | 139 | 140 | def build_model() -> nn.Module: 141 | model = FSRCNN(config.upscale_factor).to(config.device) 142 | 143 | return model 144 | 145 | 146 | def define_loss() -> [nn.MSELoss, nn.MSELoss]: 147 | psnr_criterion = nn.MSELoss().to(config.device) 148 | pixel_criterion = nn.MSELoss().to(config.device) 149 | 150 | return psnr_criterion, pixel_criterion 151 | 152 | 153 | def define_optimizer(model) -> optim.SGD: 154 | optimizer = optim.SGD([{"params": model.feature_extraction.parameters()}, 155 | {"params": model.shrink.parameters()}, 156 | {"params": model.map.parameters()}, 157 | {"params": model.expand.parameters()}, 158 | {"params": model.deconv.parameters(), "lr": config.model_lr * 0.1}], 159 | lr=config.model_lr, 160 | momentum=config.model_momentum, 161 | weight_decay=config.model_weight_decay, 162 | nesterov=config.model_nesterov) 163 | 164 | return optimizer 165 | 166 | 167 | def train(model, train_prefetcher, psnr_criterion, pixel_criterion, optimizer, epoch, scaler, writer) -> None: 168 | # Calculate how many iterations there are under epoch 169 | batches = len(train_prefetcher) 170 | 171 | batch_time = AverageMeter("Time", ":6.3f") 172 | data_time = AverageMeter("Data", ":6.3f") 173 | losses = AverageMeter("Loss", ":6.6f") 174 | psnres = AverageMeter("PSNR", ":4.2f") 175 | progress = ProgressMeter(batches, [batch_time, data_time, losses, psnres], prefix=f"Epoch: [{epoch + 1}]") 176 | 177 | # Put the generator in training mode 178 | model.train() 179 | 180 | batch_index = 0 181 | 182 | # Calculate the time it takes to test a batch of data 183 | end = time.time() 184 | # enable preload 185 | train_prefetcher.reset() 186 | batch_data = train_prefetcher.next() 187 | while batch_data is not None: 188 | # measure data loading time 189 | data_time.update(time.time() - end) 190 | 191 | lr = batch_data["lr"].to(config.device, non_blocking=True) 192 | hr = 
batch_data["hr"].to(config.device, non_blocking=True) 193 | 194 | # Initialize the generator gradient 195 | model.zero_grad() 196 | 197 | # Mixed precision training 198 | with amp.autocast(): 199 | sr = model(lr) 200 | loss = pixel_criterion(sr, hr) 201 | 202 | # Gradient zoom 203 | scaler.scale(loss).backward() 204 | # Update generator weight 205 | scaler.step(optimizer) 206 | scaler.update() 207 | 208 | # measure accuracy and record loss 209 | psnr = 10. * torch.log10(1. / psnr_criterion(sr, hr)) 210 | losses.update(loss.item(), lr.size(0)) 211 | psnres.update(psnr.item(), lr.size(0)) 212 | 213 | # measure elapsed time 214 | batch_time.update(time.time() - end) 215 | end = time.time() 216 | 217 | # Record training log information 218 | if batch_index % config.print_frequency == 0: 219 | # Writer Loss to file 220 | writer.add_scalar("Train/Loss", loss.item(), batch_index + epoch * batches + 1) 221 | progress.display(batch_index) 222 | 223 | # Preload the next batch of data 224 | batch_data = train_prefetcher.next() 225 | 226 | # After a batch of data is calculated, add 1 to the number of batches 227 | batch_index += 1 228 | 229 | 230 | def validate(model, valid_prefetcher, psnr_criterion, epoch, writer, mode) -> float: 231 | batch_time = AverageMeter("Time", ":6.3f", Summary.NONE) 232 | psnres = AverageMeter("PSNR", ":4.2f", Summary.AVERAGE) 233 | progress = ProgressMeter(len(valid_prefetcher), [batch_time, psnres], prefix=f"{mode}: ") 234 | 235 | # Put the model in verification mode 236 | model.eval() 237 | 238 | batch_index = 0 239 | 240 | # Calculate the time it takes to test a batch of data 241 | end = time.time() 242 | with torch.no_grad(): 243 | # enable preload 244 | valid_prefetcher.reset() 245 | batch_data = valid_prefetcher.next() 246 | 247 | while batch_data is not None: 248 | # measure data loading time 249 | lr = batch_data["lr"].to(config.device, non_blocking=True) 250 | hr = batch_data["hr"].to(config.device, non_blocking=True) 251 | 252 | # Mixed precision 253 | with amp.autocast(): 254 | sr = model(lr) 255 | 256 | # measure accuracy and record loss 257 | psnr = 10. * torch.log10(1. 
/ psnr_criterion(sr, hr)) 258 | psnres.update(psnr.item(), lr.size(0)) 259 | 260 | # measure elapsed time 261 | batch_time.update(time.time() - end) 262 | end = time.time() 263 | 264 | # Record training log information 265 | if batch_index % config.print_frequency == 0: 266 | progress.display(batch_index) 267 | 268 | # Preload the next batch of data 269 | batch_data = valid_prefetcher.next() 270 | 271 | # After a batch of data is calculated, add 1 to the number of batches 272 | batch_index += 1 273 | 274 | # Print average PSNR metrics 275 | progress.display_summary() 276 | 277 | if mode == "Valid": 278 | writer.add_scalar("Valid/PSNR", psnres.avg, epoch + 1) 279 | elif mode == "Test": 280 | writer.add_scalar("Test/PSNR", psnres.avg, epoch + 1) 281 | else: 282 | raise ValueError("Unsupported mode, please use `Valid` or `Test`.") 283 | 284 | return psnres.avg 285 | 286 | 287 | # Copy form "https://github.com/pytorch/examples/blob/master/imagenet/main.py" 288 | class Summary(Enum): 289 | NONE = 0 290 | AVERAGE = 1 291 | SUM = 2 292 | COUNT = 3 293 | 294 | 295 | class AverageMeter(object): 296 | """Computes and stores the average and current value""" 297 | 298 | def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE): 299 | self.name = name 300 | self.fmt = fmt 301 | self.summary_type = summary_type 302 | self.reset() 303 | 304 | def reset(self): 305 | self.val = 0 306 | self.avg = 0 307 | self.sum = 0 308 | self.count = 0 309 | 310 | def update(self, val, n=1): 311 | self.val = val 312 | self.sum += val * n 313 | self.count += n 314 | self.avg = self.sum / self.count 315 | 316 | def __str__(self): 317 | fmtstr = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})" 318 | return fmtstr.format(**self.__dict__) 319 | 320 | def summary(self): 321 | if self.summary_type is Summary.NONE: 322 | fmtstr = "" 323 | elif self.summary_type is Summary.AVERAGE: 324 | fmtstr = "{name} {avg:.2f}" 325 | elif self.summary_type is Summary.SUM: 326 | fmtstr = "{name} {sum:.2f}" 327 | elif self.summary_type is Summary.COUNT: 328 | fmtstr = "{name} {count:.2f}" 329 | else: 330 | raise ValueError(f"Invalid summary type {self.summary_type}") 331 | 332 | return fmtstr.format(**self.__dict__) 333 | 334 | 335 | class ProgressMeter(object): 336 | def __init__(self, num_batches, meters, prefix=""): 337 | self.batch_fmtstr = self._get_batch_fmtstr(num_batches) 338 | self.meters = meters 339 | self.prefix = prefix 340 | 341 | def display(self, batch): 342 | entries = [self.prefix + self.batch_fmtstr.format(batch)] 343 | entries += [str(meter) for meter in self.meters] 344 | print("\t".join(entries)) 345 | 346 | def display_summary(self): 347 | entries = [" *"] 348 | entries += [meter.summary() for meter in self.meters] 349 | print(" ".join(entries)) 350 | 351 | def _get_batch_fmtstr(self, num_batches): 352 | num_digits = len(str(num_batches // 1)) 353 | fmt = "{:" + str(num_digits) + "d}" 354 | return "[" + fmt + "/" + fmt.format(num_batches) + "]" 355 | 356 | 357 | if __name__ == "__main__": 358 | main() 359 | -------------------------------------------------------------------------------- /validate.py: -------------------------------------------------------------------------------- 1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved. 2 | # Licensed under the Apache License, Version 2.0 (the "License"); 3 | # you may not use this file except in compliance with the License. 
4 | # You may obtain a copy of the License at 5 | # 6 | # http://www.apache.org/licenses/LICENSE-2.0 7 | # 8 | # Unless required by applicable law or agreed to in writing, software 9 | # distributed under the License is distributed on an "AS IS" BASIS, 10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # ============================================================================== 14 | """File description: Realize the verification function after model training.""" 15 | import os 16 | 17 | import cv2 18 | import numpy as np 19 | import torch 20 | from natsort import natsorted 21 | 22 | import config 23 | import imgproc 24 | from model import FSRCNN 25 | 26 | 27 | def main() -> None: 28 | # Initialize the super-resolution model 29 | model = FSRCNN(config.upscale_factor).to(config.device) 30 | print("Build FSRCNN model successfully.") 31 | 32 | # Load the super-resolution model weights 33 | checkpoint = torch.load(config.model_path, map_location=lambda storage, loc: storage) 34 | model.load_state_dict(checkpoint["state_dict"]) 35 | print(f"Load FSRCNN model weights `{os.path.abspath(config.model_path)}` successfully.") 36 | 37 | # Create a folder of super-resolution experiment results 38 | results_dir = os.path.join("results", "test", config.exp_name) 39 | if not os.path.exists(results_dir): 40 | os.makedirs(results_dir) 41 | 42 | # Start the verification mode of the model. 43 | model.eval() 44 | # Turn on half-precision inference. 45 | model.half() 46 | 47 | # Initialize the image evaluation index. 48 | total_psnr = 0.0 49 | 50 | # Get a list of test image file names. 51 | file_names = natsorted(os.listdir(config.hr_dir)) 52 | # Get the number of test image files. 53 | total_files = len(file_names) 54 | 55 | for index in range(total_files): 56 | lr_image_path = os.path.join(config.lr_dir, file_names[index]) 57 | sr_image_path = os.path.join(config.sr_dir, file_names[index]) 58 | hr_image_path = os.path.join(config.hr_dir, file_names[index]) 59 | 60 | print(f"Processing `{os.path.abspath(hr_image_path)}`...") 61 | # Read LR image and HR image 62 | lr_image = cv2.imread(lr_image_path).astype(np.float32) / 255.0 63 | hr_image = cv2.imread(hr_image_path).astype(np.float32) / 255.0 64 | 65 | # Convert BGR image to YCbCr image 66 | lr_ycbcr_image = imgproc.bgr2ycbcr(lr_image, use_y_channel=False) 67 | hr_ycbcr_image = imgproc.bgr2ycbcr(hr_image, use_y_channel=False) 68 | 69 | # Split YCbCr image data 70 | lr_y_image, lr_cb_image, lr_cr_image = cv2.split(lr_ycbcr_image) 71 | hr_y_image, hr_cb_image, hr_cr_image = cv2.split(hr_ycbcr_image) 72 | 73 | # Convert Y image data convert to Y tensor data 74 | lr_y_tensor = imgproc.image2tensor(lr_y_image, range_norm=False, half=True).to(config.device).unsqueeze_(0) 75 | hr_y_tensor = imgproc.image2tensor(hr_y_image, range_norm=False, half=True).to(config.device).unsqueeze_(0) 76 | 77 | # Only reconstruct the Y channel image data. 78 | with torch.no_grad(): 79 | sr_y_tensor = model(lr_y_tensor).clamp_(0, 1.0) 80 | 81 | # Cal PSNR 82 | total_psnr += 10. * torch.log10(1. 
/ torch.mean((sr_y_tensor - hr_y_tensor) ** 2)) 83 | 84 | # Save image 85 | sr_y_image = imgproc.tensor2image(sr_y_tensor, range_norm=False, half=True) 86 | sr_y_image = sr_y_image.astype(np.float32) / 255.0 87 | sr_ycbcr_image = cv2.merge([sr_y_image, hr_cb_image, hr_cr_image]) 88 | sr_image = imgproc.ycbcr2bgr(sr_ycbcr_image) 89 | cv2.imwrite(sr_image_path, sr_image * 255.0) 90 | 91 | print(f"PSNR: {total_psnr / total_files:4.2f}dB.\n") 92 | 93 | 94 | if __name__ == "__main__": 95 | main() 96 | --------------------------------------------------------------------------------
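The validation script above expects pre-generated LR/HR image pairs laid out under config.lr_dir and config.hr_dir. For quick, single-image inference the same building blocks can be combined as in the sketch below. It is not part of the repository: the function name `super_resolve` is invented for illustration, and upscaling the chroma channels with `cv2.resize` is an assumption of this sketch (validate.py instead reuses the ground-truth HR Cb/Cr channels).

import cv2
import numpy as np
import torch

import config
import imgproc
from model import FSRCNN


def super_resolve(lr_path: str, sr_path: str) -> None:
    # Build the model and load the trained weights, exactly as validate.py does.
    model = FSRCNN(config.upscale_factor).to(config.device)
    checkpoint = torch.load(config.model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint["state_dict"])
    model.eval()
    model.half()

    # Read the low-resolution image and convert it to YCbCr, keeping all three channels.
    lr_image = cv2.imread(lr_path).astype(np.float32) / 255.0
    lr_ycbcr = imgproc.bgr2ycbcr(lr_image, use_y_channel=False)
    lr_y, lr_cb, lr_cr = cv2.split(lr_ycbcr)

    # Only the luminance channel goes through the network.
    lr_y_tensor = imgproc.image2tensor(lr_y, range_norm=False, half=True).to(config.device).unsqueeze_(0)
    with torch.no_grad():
        sr_y_tensor = model(lr_y_tensor).clamp_(0, 1.0)
    sr_y = imgproc.tensor2image(sr_y_tensor, range_norm=False, half=True).astype(np.float32) / 255.0

    # Chroma channels are simply enlarged with bicubic interpolation (an assumption of
    # this sketch; validate.py reuses the ground-truth HR chroma instead).
    new_size = (sr_y.shape[1], sr_y.shape[0])
    sr_cb = cv2.resize(lr_cb, new_size, interpolation=cv2.INTER_CUBIC)
    sr_cr = cv2.resize(lr_cr, new_size, interpolation=cv2.INTER_CUBIC)

    # Merge, convert back to BGR and save.
    sr_image = imgproc.ycbcr2bgr(cv2.merge([sr_y, sr_cb, sr_cr]))
    cv2.imwrite(sr_path, sr_image * 255.0)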