├── .gitignore
├── LICENSE
├── README.md
├── assets
│   └── result.png
├── config.py
├── data
│   └── README.md
├── dataset.py
├── imgproc.py
├── model.py
├── requirements.txt
├── results
│   └── .gitkeep
├── samples
│   └── .gitkeep
├── scripts
│   ├── augment_dataset.py
│   ├── data_utils.py
│   ├── prepare_dataset.py
│   ├── run.py
│   └── split_train_valid_dataset.py
├── setup.py
├── train.py
└── validate.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | pip-wheel-metadata/
24 | share/python-wheels/
25 | *.egg-info/
26 | .installed.cfg
27 | *.egg
28 | MANIFEST
29 |
30 | # PyInstaller
31 | # Usually these files are written by a python script from a template
32 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
33 | *.manifest
34 | *.spec
35 |
36 | # Installer logs
37 | pip-log.txt
38 | pip-delete-this-directory.txt
39 |
40 | # Unit test / coverage reports
41 | htmlcov/
42 | .tox/
43 | .nox/
44 | .coverage
45 | .coverage.*
46 | .cache
47 | nosetests.xml
48 | coverage.xml
49 | *.cover
50 | .hypothesis/
51 | .pytest_cache/
52 |
53 | # Translations
54 | *.mo
55 | *.pot
56 |
57 | # Django stuff:
58 | *.log
59 | local_settings.py
60 | db.sqlite3
61 | db.sqlite3-journal
62 |
63 | # Flask stuff:
64 | instance/
65 | .webassets-cache
66 |
67 | # Scrapy stuff:
68 | .scrapy
69 |
70 | # Sphinx documentation
71 | docs/_build/
72 |
73 | # PyBuilder
74 | target/
75 |
76 | # Jupyter Notebook
77 | .ipynb_checkpoints
78 |
79 | # IPython
80 | profile_default/
81 | ipython_config.py
82 |
83 | # pyenv
84 | .python-version
85 |
86 | # pipenv
87 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
88 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
89 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
90 | # install all needed dependencies.
91 | #Pipfile.lock
92 |
93 | # celery beat schedule file
94 | celerybeat-schedule
95 |
96 | # SageMath parsed files
97 | *.sage.py
98 |
99 | # Environments
100 | .env
101 | .venv
102 | env/
103 | venv/
104 | ENV/
105 | env.bak/
106 | venv.bak/
107 |
108 | # Spyder project settings
109 | .spyderproject
110 | .spyproject
111 |
112 | # Rope project settings
113 | .ropeproject
114 |
115 | # mkdocs documentation
116 | /site
117 | # mypy
118 | .mypy_cache/
119 | .dmypy.json
120 | dmypy.json
121 |
122 | # Pyre type checker
123 | .pyre/
124 |
125 | # custom
126 | .idea
127 | .vscode
128 |
129 | # Mac configure file.
130 | .DS_Store
131 |
132 | # Program run create directory.
133 | data
134 | results
135 | samples
136 |
137 | # Program run create file.
138 | *.bmp
139 | *.png
140 | *.mp4
141 | *.zip
142 | *.csv
143 | *.pth
144 | *.mdb
145 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FSRCNN-PyTorch
2 |
3 | ## Overview
4 |
5 | This repository contains an op-for-op PyTorch reimplementation of [Accelerating the Super-Resolution Convolutional Neural Network](https://arxiv.org/abs/1608.00367v1).
6 |
7 | ## Table of contents
8 |
9 | - [FSRCNN-PyTorch](#fsrcnn-pytorch)
10 | - [Overview](#overview)
11 | - [Table of contents](#table-of-contents)
12 | - [About Accelerating the Super-Resolution Convolutional Neural Network](#about-accelerating-the-super-resolution-convolutional-neural-network)
13 | - [Download weights](#download-weights)
14 | - [Download datasets](#download-datasets)
15 | - [Test](#test)
16 | - [Train](#train)
17 | - [Result](#result)
18 | - [Credit](#credit)
19 | - [Accelerating the Super-Resolution Convolutional Neural Network](#accelerating-the-super-resolution-convolutional-neural-network)
20 |
21 | ## About Accelerating the Super-Resolution Convolutional Neural Network
22 |
23 | If you're new to FSRCNN, here's an abstract straight from the paper:
24 |
25 | As a successful deep model applied in image super-resolution (SR), the Super-Resolution Convolutional Neural Network
26 | (SRCNN) has demonstrated superior performance to the previous hand-crafted models either in speed and restoration quality. However, the high
27 | computational cost still hinders it from practical usage that demands real-time performance
28 | (24 fps). In this paper, we aim at accelerating the current SRCNN, and propose a compact hourglass-shape CNN structure for faster and better SR. We
29 | re-design the SRCNN structure mainly in three aspects. First, we introduce a deconvolution layer at the end of the network, then the mapping is
30 | learned directly from the original low-resolution image (without interpolation) to the high-resolution one. Second, we reformulate the mapping layer
31 | by shrinking the input feature dimension before mapping and expanding back afterwards. Third, we adopt smaller filter sizes but more mapping layers.
32 | The proposed model achieves a speed up of more than 40 times with even superior restoration quality. Further, we present the parameter settings that
33 | can achieve real-time performance on a generic CPU while still maintaining good performance. A corresponding transfer strategy is also proposed for
34 | fast training and testing across different upscaling factors.
35 |
36 | ## Download weights
37 |
38 | - [Google Drive](https://drive.google.com/drive/folders/17ju2HN7Y6pyPK2CC_AqnAfTOe9_3hCQ8?usp=sharing)
39 | - [Baidu Netdisk](https://pan.baidu.com/s/1yNs4rqIb004-NKEdKBJtYg?pwd=llot)
40 |
41 | ## Download datasets
42 |
43 | Contains DIV2K, DIV8K, Flickr2K, OST, T91, Set5, Set14, BSDS100 and BSDS200, etc.
44 |
45 | - [Google Drive](https://drive.google.com/drive/folders/1A6lzGeQrFMxPqJehK9s37ce-tPDj20mD?usp=sharing)
46 | - [Baidu Netdisk](https://pan.baidu.com/s/1o-8Ty_7q6DiS3ykLU09IVg?pwd=llot)
47 |
48 | ## Test
49 |
50 | Modify the contents of `config.py` as follows.
51 |
52 | - line 29: `upscale_factor` — set to the magnification factor you need.
53 | - line 31: `mode` — set to `"valid"`.
54 | - line 67: `model_path` — set to the path of the trained weights.
55 |
56 | ## Train
57 |
58 | Modify the contents of `config.py` as follows.
59 |
60 | - line 29: `upscale_factor` — set to the magnification factor you need.
61 | - line 31: `mode` — set to `"train"`.
62 |
63 | To resume from weights you trained earlier, also modify `config.py` as follows.
64 |
65 | - line 47: `start_epoch` — set to the epoch at which the previous run stopped.
66 | - line 48: `resume` — set to the path of the weights to load.
67 |
68 | ## Result
69 |
70 | Source of original paper results: https://arxiv.org/pdf/1608.00367v1.pdf
71 |
72 | In the following table, the value in parentheses `()` is the result obtained by this project, and `-` means not tested.
73 |
74 | | Dataset | Scale | PSNR |
75 | |:-------:|:-----:|:----------------:|
76 | | Set5 | 2 | 36.94(**37.09**) |
77 | | Set5 | 3 | 33.06(**33.06**) |
78 | | Set5 | 4 | 30.55(**30.66**) |
79 |
80 | Low Resolution / Super Resolution / High Resolution
81 |
82 |
83 | ## Credit
84 |
85 | ### Accelerating the Super-Resolution Convolutional Neural Network
86 |
87 | _Chao Dong, Chen Change Loy, Xiaoou Tang_
88 |
89 | **Abstract**
90 | As a successful deep model applied in image super-resolution (SR), the Super-Resolution Convolutional Neural Network
91 | (SRCNN) has demonstrated superior performance to the previous hand-crafted models either in speed and restoration quality. However, the high
92 | computational cost still hinders it from practical usage that demands real-time performance
93 | (24 fps). In this paper, we aim at accelerating the current SRCNN, and propose a compact hourglass-shape CNN structure for faster and better SR. We
94 | re-design the SRCNN structure mainly in three aspects. First, we introduce a deconvolution layer at the end of the network, then the mapping is
95 | learned directly from the original low-resolution image (without interpolation) to the high-resolution one. Second, we reformulate the mapping layer
96 | by shrinking the input feature dimension before mapping and expanding back afterwards. Third, we adopt smaller filter sizes but more mapping layers.
97 | The proposed model achieves a speed up of more than 40 times with even superior restoration quality. Further, we present the parameter settings that
98 | can achieve real-time performance on a generic CPU while still maintaining good performance. A corresponding transfer strategy is also proposed for
99 | fast training and testing across different upscaling factors.
100 |
101 | [[Paper]](https://arxiv.org/pdf/1608.00367v1.pdf) [[Author's implementation (Caffe)]](https://drive.google.com/open?id=0B7tU5Pj1dfCMWjhhaE1HR3dqcGs)
102 |
103 | ```bibtex
104 | @article{DBLP:journals/corr/DongLT16,
105 | author = {Chao Dong and
106 | Chen Change Loy and
107 | Xiaoou Tang},
108 | title = {Accelerating the Super-Resolution Convolutional Neural Network},
109 | journal = {CoRR},
110 | volume = {abs/1608.00367},
111 | year = {2016},
112 | url = {http://arxiv.org/abs/1608.00367},
113 | eprinttype = {arXiv},
114 | eprint = {1608.00367},
115 | timestamp = {Mon, 13 Aug 2018 16:47:56 +0200},
116 | biburl = {https://dblp.org/rec/journals/corr/DongLT16.bib},
117 | bibsource = {dblp computer science bibliography, https://dblp.org}
118 | }
119 | ```
120 |
--------------------------------------------------------------------------------
/assets/result.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lornatang/FSRCNN-PyTorch/e6f8d53775d93249292fb80eed0af940aa7710c5/assets/result.png
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import random
15 |
16 | import numpy as np
17 | import torch
18 | from torch.backends import cudnn
19 |
# Random seed to maintain reproducible results
random.seed(0)
torch.manual_seed(0)
np.random.seed(0)
# Run on the first CUDA GPU (there is no CPU fallback)
device = torch.device("cuda", 0)
# cuDNN autotuning: speeds up training when the input size stays fixed between iterations
cudnn.benchmark = True
# Image magnification factor
upscale_factor = 2
# Current configuration mode: "train" or "valid"
mode = "train"
# Experiment name, used in the results/weights paths below
exp_name = "fsrcnn_x2"

if mode == "train":
    # Dataset directories
    train_image_dir = "data/T91/FSRCNN/train"
    valid_image_dir = "data/T91/FSRCNN/valid"
    test_lr_image_dir = f"data/Set5/LRbicx{upscale_factor}"
    test_hr_image_dir = "data/Set5/GTmod12"

    image_size = 20
    batch_size = 16
    num_workers = 4

    # Incremental training: epoch to start from and weights to resume from ("" = none)
    start_epoch = 0
    resume = ""

    # Total number of epochs
    epochs = 3000

    # SGD optimizer parameters
    model_lr = 1e-3
    model_momentum = 0.9
    model_weight_decay = 1e-4
    model_nesterov = False

    print_frequency = 200

if mode == "valid":
    # Test data address
    lr_dir = f"data/Set5/LRbicx{upscale_factor}"
    sr_dir = f"results/test/{exp_name}"
    hr_dir = "data/Set5/GTmod12"

    model_path = f"results/{exp_name}/best.pth.tar"
68 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Usage
2 |
3 | ## Download datasets
4 |
5 | Contains DIV2K, DIV8K, Flickr2K, OST, T91, Set5, Set14, BSDS100 and BSDS200, etc.
6 |
7 | - [Google Drive](https://drive.google.com/drive/folders/1A6lzGeQrFMxPqJehK9s37ce-tPDj20mD?usp=sharing)
8 | - [Baidu Netdisk](https://pan.baidu.com/s/1o-8Ty_7q6DiS3ykLU09IVg?pwd=llot)
9 |
10 | ## Training dataset structure
11 |
12 | ```text
13 | - T91
14 | - FSRCNN
15 | - train
16 | - valid
17 | ```
18 |
19 | ## Test dataset structure
20 |
21 | ```text
22 | - Set5
23 | - GTmod12
24 | - baby.png
25 | - bird.png
26 | - ...
27 | - LRbicx4
28 | - baby.png
29 | - bird.png
30 | - ...
31 | - Set14
32 | - GTmod12
33 | - baboon.png
34 | - barbara.png
35 | - ...
36 | - LRbicx4
37 | - baboon.png
38 | - barbara.png
39 | - ...
40 | ```
41 |
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | """Realize the function of dataset preparation."""
15 | import os
16 | import queue
17 | import threading
18 |
19 | import cv2
20 | import numpy as np
21 | import torch
22 | from torch.utils.data import Dataset, DataLoader
23 |
24 | import imgproc
25 |
# Public API of this module.
__all__ = [
    "TrainValidImageDataset",
    "TestImageDataset",
    "PrefetchGenerator",
    "PrefetchDataLoader",
    "CPUPrefetcher",
    "CUDAPrefetcher",
]
30 |
31 |
class TrainValidImageDataset(Dataset):
    """Customize the data set loading function and prepare low/high resolution image data in advance.

    Every item is produced from one high-resolution (HR) image file; the matching
    low-resolution (LR) image is synthesized on the fly by downscaling the HR image
    by ``1 / upscale_factor`` with ``imgproc.imresize``.

    Args:
        image_dir (str): Train/Valid dataset address (directory of HR images).
        image_size (int): High resolution image size (side length of the square HR crop).
        upscale_factor (int): Image up scale factor.
        mode (str): ``"Train"`` applies random crop + random rotation (data augmentation);
            ``"Valid"`` applies a deterministic center crop only. Any other value makes
            ``__getitem__`` raise ``ValueError``.
    """

    def __init__(self, image_dir: str, image_size: int, upscale_factor: int, mode: str) -> None:
        super(TrainValidImageDataset, self).__init__()
        # Get all image file names in folder. ``os.listdir`` order is arbitrary, so sort it
        # to keep the index -> file mapping (and hence validation runs) reproducible.
        self.image_file_names = sorted(
            os.path.join(image_dir, image_file_name) for image_file_name in os.listdir(image_dir)
        )
        # Specify the high-resolution image size, with equal length and width
        self.image_size = image_size
        # How many times the high-resolution image is the low-resolution image
        self.upscale_factor = upscale_factor
        # "Train" (augmented) or "Valid" (center crop only)
        self.mode = mode

    def __getitem__(self, batch_index: int) -> dict:
        """Return the Y-channel LR/HR tensor pair for the image at ``batch_index``.

        Returns:
            dict: ``{"lr": lr_y_tensor, "hr": hr_y_tensor}``.

        Raises:
            OSError: If OpenCV cannot read the image file.
            ValueError: If ``self.mode`` is neither ``"Train"`` nor ``"Valid"``.
        """
        image_path = self.image_file_names[batch_index]
        # Read a single HR image; cv2.imread returns None (it does not raise) on failure.
        raw_image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        if raw_image is None:
            raise OSError(f"Failed to read image file: {image_path}")
        # Scale to float32 in [0, 1] (assumes 8-bit images — TODO confirm for 16-bit inputs).
        hr_image = raw_image.astype(np.float32) / 255.
        # Use high-resolution image to make low-resolution image
        lr_image = imgproc.imresize(hr_image, 1 / self.upscale_factor)

        if self.mode == "Train":
            # Data augment
            lr_image, hr_image = imgproc.random_crop(lr_image, hr_image, self.image_size, self.upscale_factor)
            lr_image, hr_image = imgproc.random_rotate(lr_image, hr_image, angles=[0, 90, 180, 270])
        elif self.mode == "Valid":
            lr_image, hr_image = imgproc.center_crop(lr_image, hr_image, self.image_size, self.upscale_factor)
        else:
            raise ValueError("Unsupported data processing mode, please use `Train` or `Valid`.")

        # Only extract the image data of the Y channel
        lr_y_image = imgproc.bgr2ycbcr(lr_image, use_y_channel=True)
        hr_y_image = imgproc.bgr2ycbcr(hr_image, use_y_channel=True)

        # Convert image data into Tensor stream format (PyTorch).
        # Note: The range of input and output is between [0, 1]
        lr_y_tensor = imgproc.image2tensor(lr_y_image, range_norm=False, half=False)
        hr_y_tensor = imgproc.image2tensor(hr_y_image, range_norm=False, half=False)

        return {"lr": lr_y_tensor, "hr": hr_y_tensor}

    def __len__(self) -> int:
        return len(self.image_file_names)
81 |
82 |
class TestImageDataset(Dataset):
    """Define Test dataset loading methods.

    Low-resolution (LR) and high-resolution (HR) images are paired by position after
    sorting each directory listing, so both directories are expected to contain the
    same file names (e.g. ``baby.png`` in both ``LRbicx4`` and ``GTmod12``).

    Args:
        test_lr_image_dir (str): Test dataset address for low resolution image dir.
        test_hr_image_dir (str): Test dataset address for high resolution image dir.
        upscale_factor (int): Image up scale factor.
    """

    def __init__(self, test_lr_image_dir: str, test_hr_image_dir: str, upscale_factor: int) -> None:
        super(TestImageDataset, self).__init__()
        # ``os.listdir`` returns entries in arbitrary order and the two directories are
        # listed independently, so without sorting index i could pair an LR image with
        # the HR image of a different scene. Sorting both lists aligns them by file name.
        self.lr_image_file_names = sorted(
            os.path.join(test_lr_image_dir, x) for x in os.listdir(test_lr_image_dir)
        )
        self.hr_image_file_names = sorted(
            os.path.join(test_hr_image_dir, x) for x in os.listdir(test_hr_image_dir)
        )
        # How many times the high-resolution image is the low-resolution image
        self.upscale_factor = upscale_factor

    @staticmethod
    def _read_image(image_path: str) -> np.ndarray:
        """Load an image with OpenCV and scale it to float32 by dividing by 255."""
        image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
        # cv2.imread signals failure by returning None rather than raising.
        if image is None:
            raise OSError(f"Failed to read image file: {image_path}")
        return image.astype(np.float32) / 255.

    def __getitem__(self, batch_index: int) -> dict:
        """Return the Y-channel LR/HR tensor pair at ``batch_index``.

        Returns:
            dict: ``{"lr": lr_y_tensor, "hr": hr_y_tensor}``.

        Raises:
            OSError: If OpenCV cannot read either image file.
        """
        lr_image = self._read_image(self.lr_image_file_names[batch_index])
        hr_image = self._read_image(self.hr_image_file_names[batch_index])

        # Only extract the image data of the Y channel
        lr_y_image = imgproc.bgr2ycbcr(lr_image, use_y_channel=True)
        hr_y_image = imgproc.bgr2ycbcr(hr_image, use_y_channel=True)

        # Convert image data into Tensor stream format (PyTorch).
        # Note: The range of input and output is between [0, 1]
        lr_y_tensor = imgproc.image2tensor(lr_y_image, range_norm=False, half=False)
        hr_y_tensor = imgproc.image2tensor(hr_y_image, range_norm=False, half=False)

        return {"lr": lr_y_tensor, "hr": hr_y_tensor}

    def __len__(self) -> int:
        return len(self.lr_image_file_names)
118 |
119 |
class PrefetchGenerator(threading.Thread):
    """A fast data prefetch generator.

    Iterates ``generator`` on a daemon background thread and buffers its items in a
    bounded ``queue.Queue``; iterating this object consumes those items in order.
    If ``generator`` raises, the exception is re-raised in the consuming thread
    instead of leaving the consumer blocked forever on an empty queue.

    Args:
        generator: Data generator (any iterable).
        num_data_prefetch_queue (int): How many early data load queues, i.e. the
            ``maxsize`` of the buffer queue (``<= 0`` means unbounded).
    """

    # Private end-of-stream marker. A dedicated object (rather than ``None``) means a
    # generator that legitimately yields ``None`` is not cut short.
    _END = object()

    class _Failure:
        """Carries an exception raised by the producer thread to the consumer."""

        __slots__ = ("error",)

        def __init__(self, error: BaseException) -> None:
            self.error = error

    def __init__(self, generator, num_data_prefetch_queue: int) -> None:
        threading.Thread.__init__(self)
        self.queue = queue.Queue(num_data_prefetch_queue)
        self.generator = generator
        # Set once the end marker (or a forwarded error) has been consumed, so further
        # ``next()`` calls stop immediately instead of blocking on an empty queue.
        self._prefetch_exhausted = False
        self.daemon = True
        self.start()

    def run(self) -> None:
        """Producer loop: push every item, then the end marker (even on failure)."""
        try:
            for item in self.generator:
                self.queue.put(item)
        except Exception as error:
            self.queue.put(self._Failure(error))
        finally:
            self.queue.put(self._END)

    def __next__(self):
        if self._prefetch_exhausted:
            raise StopIteration
        next_item = self.queue.get()
        if next_item is self._END:
            self._prefetch_exhausted = True
            raise StopIteration
        if isinstance(next_item, self._Failure):
            # The producer has stopped; surface its error here, in the caller's thread.
            self._prefetch_exhausted = True
            raise next_item.error
        return next_item

    def __iter__(self):
        return self
148 |
149 |
class PrefetchDataLoader(DataLoader):
    """DataLoader whose iterator pre-fetches batches on a background thread.

    Args:
        num_data_prefetch_queue (int): Maximum number of batches buffered ahead
            of the consumer by the prefetch thread.
        kwargs (dict): Forwarded unchanged to ``torch.utils.data.DataLoader``.
    """

    def __init__(self, num_data_prefetch_queue: int, **kwargs) -> None:
        self.num_data_prefetch_queue = num_data_prefetch_queue
        super().__init__(**kwargs)

    def __iter__(self):
        # Wrap the regular DataLoader iterator so a worker thread fills the queue.
        base_iterator = super().__iter__()
        return PrefetchGenerator(base_iterator, self.num_data_prefetch_queue)
164 |
165 |
class CPUPrefetcher:
    """Step through a data loader on the CPU, yielding ``None`` once it runs out.

    Args:
        dataloader (DataLoader): Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
    """

    def __init__(self, dataloader) -> None:
        self.original_dataloader = dataloader
        self.reset()

    def next(self):
        """Return the next batch, or ``None`` when the loader is exhausted."""
        return next(self.data, None)

    def reset(self):
        """Restart iteration from the beginning of the wrapped loader."""
        self.data = iter(self.original_dataloader)

    def __len__(self) -> int:
        return len(self.original_dataloader)
188 |
189 |
class CUDAPrefetcher:
    """Use the CUDA side to accelerate data reading.

    While the caller works on the current batch, the next batch (a dict) is copied
    to ``device`` asynchronously on a dedicated CUDA side stream. ``next()`` makes the
    current stream wait for that copy before handing the batch out.

    Args:
        dataloader (DataLoader): Data loader. Combines a dataset and a sampler, and provides an iterable over the given dataset.
        device (torch.device): Specify running device.
    """

    def __init__(self, dataloader, device: torch.device):
        self.batch_data = None
        self.original_dataloader = dataloader
        self.device = device

        self.data = iter(dataloader)
        # Side stream on which the non-blocking host-to-device copies are issued.
        self.stream = torch.cuda.Stream()
        self.preload()

    def preload(self):
        """Fetch the next batch and start copying its tensor values to ``self.device``."""
        try:
            self.batch_data = next(self.data)
        except StopIteration:
            self.batch_data = None
            return None

        with torch.cuda.stream(self.stream):
            for k, v in self.batch_data.items():
                if torch.is_tensor(v):
                    self.batch_data[k] = self.batch_data[k].to(self.device, non_blocking=True)

    def next(self):
        """Return the preloaded batch (``None`` when exhausted) and start loading the next one."""
        current_stream = torch.cuda.current_stream()
        current_stream.wait_stream(self.stream)
        batch_data = self.batch_data
        if batch_data is not None:
            for v in batch_data.values():
                # The tensors were allocated on the side stream but will be used on the
                # current stream; record that use so the caching allocator does not hand
                # their memory to a later side-stream copy while kernels still read it.
                if torch.is_tensor(v) and v.is_cuda:
                    v.record_stream(current_stream)
        self.preload()
        return batch_data

    def reset(self):
        """Restart iteration over the data loader and preload its first batch."""
        self.data = iter(self.original_dataloader)
        self.preload()

    def __len__(self) -> int:
        return len(self.original_dataloader)
231 |
--------------------------------------------------------------------------------
/imgproc.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | """Realize the function of processing the dataset before training."""
15 | import math
16 | import random
17 | from typing import Any
18 |
19 | import cv2
20 | import numpy as np
21 | import torch
22 | from torchvision.transforms import functional as F
23 |
# Public API of this module (the names exported by ``from imgproc import *``).
__all__ = [
    "image2tensor",
    "tensor2image",
    "rgb2ycbcr",
    "bgr2ycbcr",
    "ycbcr2bgr",
    "ycbcr2rgb",
    "center_crop",
    "random_crop",
    "random_rotate",
    "random_horizontally_flip",
    "random_vertically_flip",
]
29 |
30 |
def image2tensor(image: np.ndarray, range_norm: bool, half: bool) -> torch.Tensor:
    """Convert an image array (or ``PIL.Image``) to a (C, H, W) tensor via ``F.to_tensor``.

    ``F.to_tensor`` scales ``uint8`` data to [0, 1] and leaves float data unscaled. The
    input ``image`` is never modified.

    Args:
        image (np.ndarray): Image data, e.g. an ``OpenCV`` array in (H, W, C) or (H, W) layout.
        range_norm (bool): Scale [0, 1] data to between [-1, 1]
        half (bool): Whether to convert torch.float32 similarly to torch.half type.

    Returns:
        Normalized image data

    Examples:
        >>> image = cv2.imread("image.bmp", cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
        >>> tensor_image = image2tensor(image, range_norm=False, half=False)
    """

    tensor = F.to_tensor(image)

    if range_norm:
        # Out-of-place on purpose: for float ndarrays ``F.to_tensor`` can return a tensor
        # that shares memory with ``image`` (no copy is needed for e.g. single-channel
        # input), so an in-place ``mul_``/``sub_`` would rewrite the caller's array.
        tensor = tensor * 2.0 - 1.0
    if half:
        tensor = tensor.half()

    return tensor
55 |
56 |
def tensor2image(tensor: torch.Tensor, range_norm: bool, half: bool) -> Any:
    """Convert an image tensor to a ``uint8`` NumPy array in (H, W, C) layout.

    The input tensor is left unchanged (all operations are out-of-place).

    Args:
        tensor (torch.Tensor): Image tensor of shape (1, C, H, W) or (C, H, W).
        range_norm (bool): Scale [-1, 1] data to between [0, 1]
        half (bool): Whether to convert torch.float32 similarly to torch.half type.

    Returns:
        np.ndarray: ``uint8`` image of shape (H, W, C). Values are multiplied by 255,
        clamped to [0, 255] and truncated (not rounded) by the ``uint8`` cast.

    Examples:
        >>> tensor = torch.randn([1, 3, 128, 128])
        >>> image = tensor2image(tensor, range_norm=False, half=False)
    """

    if range_norm:
        tensor = (tensor + 1.0) / 2.0
    if half:
        tensor = tensor.half()

    # squeeze/permute return views, so in-place ops here would write through to the
    # caller's tensor; use out-of-place ops and detach before leaving autograd.
    image = tensor.squeeze(0).permute(1, 2, 0).mul(255).clamp(0, 255).detach().cpu().numpy().astype("uint8")

    return image
81 |
82 |
83 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
def cubic(x: Any):
    """Evaluate MATLAB's bicubic interpolation kernel element-wise.

    Piecewise cubic in ``|x|``: one polynomial on ``|x| <= 1``, another on
    ``1 < |x| <= 2``, and zero elsewhere.

    Args:
        x: Tensor of sample offsets.

    Returns:
        Tensor of kernel values with the same shape and dtype as ``x``.
    """

    abs_x = torch.abs(x)
    abs_x2 = abs_x ** 2
    abs_x3 = abs_x ** 3

    near_mask = (abs_x <= 1).type_as(abs_x)
    far_mask = ((abs_x > 1) & (abs_x <= 2)).type_as(abs_x)

    near_part = (1.5 * abs_x3 - 2.5 * abs_x2 + 1) * near_mask
    far_part = (-0.5 * abs_x3 + 2.5 * abs_x2 - 4 * abs_x + 2) * far_mask
    return near_part + far_part
99 |
100 |
101 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
def calculate_weights_indices(in_length: int, out_length: int, scale: float, kernel_width: int, antialiasing: bool):
    """Implementation of `calculate_weights_indices` function in Matlab under Python language.

    For one image axis, compute the bicubic weights and the input positions that each
    output sample combines. ``imresize`` calls this once for the height axis and once
    for the width axis.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel_width (int): Kernel width (widened to ``kernel_width / scale`` when
            down-sampling with antialiasing).
        antialiasing (bool): Whether to apply antialiasing when down-sampling operations.
            Caution: Bicubic down-sampling in PIL uses antialiasing by default.

    Returns:
        tuple: ``(weights, indices, sym_len_s, sym_len_e)``:

        - ``weights``: ``(out_length, k)`` float tensor of kernel weights; rows are
          normalized to sum to 1 before the first/last columns may be trimmed.
        - ``indices``: ``(out_length, k)`` float tensor of input positions, shifted so that
          0 is the first element of the axis after ``sym_len_s`` extra samples have been
          prepended (``imresize`` fills them by mirroring).
        - ``sym_len_s`` (int): Number of samples needed before the start of the axis.
        - ``sym_len_e`` (int): Number of samples needed past the end of the axis.
    """

    if (scale < 1) and antialiasing:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialiasing
        kernel_width = kernel_width / scale

    # Output-space coordinates (1-based, MATLAB style)
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5 + scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    p = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
        out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel
    if (scale < 1) and antialiasing:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. only consider the
    # first and last column.
    # NOTE(review): ``weights_zero_tmp`` counts zero entries per column, and the checks
    # below fire whenever that count is non-zero (``math.isclose(n, 0)`` with the default
    # ``abs_tol=0`` is just ``n == 0``), i.e. when the column has *any* exact zero, not
    # only when it is entirely zero. Also, the first ``narrow`` keeps columns 1..p-2
    # (dropping the last column too), and ``p - 2`` is not updated afterwards, so if only
    # the last column triggers, the last two columns are dropped. Presumably inherited
    # from the referenced BasicSR code; confirm against MATLAB ``imresize`` before
    # changing, since resize outputs would shift.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Indices are 1-based positions into the axis: positions < 1 need samples before
    # the start, positions > in_length need samples past the end.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    # Shift so the smallest index becomes 0 in the padded axis.
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
171 |
172 |
173 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
def imresize(image: Any, scale_factor: float, antialiasing: bool = True) -> Any:
    """Implementation of `imresize` function in Matlab under Python language.

    The image is resized along the height axis first and then along the width axis.
    Before each pass, the axis is extended at both ends by mirrored copies of its
    border samples so the bicubic kernel has support at the edges.

    Args:
        image: The input image. A ``np.ndarray`` is treated as (h, w, c) or (h, w);
            any other input (e.g. ``torch.Tensor``) as (c, h, w) or (h, w).
        scale_factor (float): Scale factor. The same scale applies for both height and width.
            Output sizes are ``ceil(in_size * scale_factor)``.
        antialiasing (bool): Whether to apply antialiasing when down-sampling operations.
            Caution: Bicubic down-sampling in `PIL` uses antialiasing by default. Default: ``True``.

    Returns:
        The resized float32 image in the same container type and layout as the input:
        a ``np.ndarray`` of shape (h, w, c) / (h, w), or a tensor of shape (c, h, w) / (h, w).
        Values are neither rounded nor clipped (the cubic kernel has negative lobes, so
        results may slightly overshoot the input's value range).
    """
    squeeze_flag = False
    if type(image).__module__ == np.__name__:  # numpy type
        numpy_type = True
        if image.ndim == 2:
            image = image[:, :, None]
            squeeze_flag = True
        # (h, w, c) ndarray -> (c, h, w) float32 tensor
        image = torch.from_numpy(image.transpose(2, 0, 1)).float()
    else:
        numpy_type = False
        if image.ndim == 2:
            image = image.unsqueeze(0)
            squeeze_flag = True

    in_c, in_h, in_w = image.size()
    out_h, out_w = math.ceil(in_h * scale_factor), math.ceil(in_w * scale_factor)
    kernel_width = 4

    # get weights and indices
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale_factor, kernel_width, antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale_factor, kernel_width, antialiasing)

    # process H dimension
    # symmetric copying
    img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
    img_aug.narrow(1, sym_len_hs, in_h).copy_(image)

    sym_patch = image[:, :sym_len_hs, :]
    inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(1, inv_idx)
    img_aug.narrow(1, 0, sym_len_hs).copy_(sym_patch_inv)

    # Guard: ``x[:, -0:, :]`` selects the WHOLE axis in Python, so a zero-length end
    # pad would try to copy the full image into an empty slice.
    if sym_len_he > 0:
        sym_patch = image[:, -sym_len_he:, :]
        inv_idx = torch.arange(sym_patch.size(1) - 1, -1, -1).long()
        sym_patch_inv = sym_patch.index_select(1, inv_idx)
        img_aug.narrow(1, sym_len_hs + in_h, sym_len_he).copy_(sym_patch_inv)

    out_1 = torch.FloatTensor(in_c, out_h, in_w)
    kernel_width = weights_h.size(1)
    for i in range(out_h):
        idx = int(indices_h[i][0])
        for j in range(in_c):
            out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])

    # process W dimension
    # symmetric copying
    out_1_aug = torch.FloatTensor(in_c, out_h, in_w + sym_len_ws + sym_len_we)
    out_1_aug.narrow(2, sym_len_ws, in_w).copy_(out_1)

    sym_patch = out_1[:, :, :sym_len_ws]
    inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
    sym_patch_inv = sym_patch.index_select(2, inv_idx)
    out_1_aug.narrow(2, 0, sym_len_ws).copy_(sym_patch_inv)

    # Same ``-0:`` guard as for the H dimension.
    if sym_len_we > 0:
        sym_patch = out_1[:, :, -sym_len_we:]
        inv_idx = torch.arange(sym_patch.size(2) - 1, -1, -1).long()
        sym_patch_inv = sym_patch.index_select(2, inv_idx)
        out_1_aug.narrow(2, sym_len_ws + in_w, sym_len_we).copy_(sym_patch_inv)

    out_2 = torch.FloatTensor(in_c, out_h, out_w)
    kernel_width = weights_w.size(1)
    for i in range(out_w):
        idx = int(indices_w[i][0])
        for j in range(in_c):
            out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])

    if squeeze_flag:
        out_2 = out_2.squeeze(0)
    if numpy_type:
        out_2 = out_2.numpy()
        if not squeeze_flag:
            # back to (h, w, c) for ndarray input
            out_2 = out_2.transpose(1, 2, 0)

    return out_2
258 |
259 |
260 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
def rgb2ycbcr(image: np.ndarray, use_y_channel: bool = False) -> np.ndarray:
    """Convert an RGB image to YCbCr the way MATLAB's ``rgb2ycbcr`` does (BT.601 weights).

    With channels-last input in [0, 1], the coefficients produce 8-bit-style values
    (Y in [16, 235]), which are then divided by 255.

    Args:
        image (np.ndarray): Image input in RGB format, channels last.
        use_y_channel (bool): Extract Y channel separately. Default: ``False``.

    Returns:
        ndarray: float32 YCbCr image; only the Y plane (channel axis dropped) when
        ``use_y_channel`` is set.
    """

    if use_y_channel:
        # Luma only: weighted sum over the channel axis.
        ycbcr = np.dot(image, [65.481, 128.553, 24.966]) + 16.0
    else:
        transform = [[65.481, -37.797, 112.0],
                     [128.553, -74.203, -93.786],
                     [24.966, 112.0, -18.214]]
        ycbcr = np.matmul(image, transform) + [16, 128, 128]

    return (ycbcr / 255.).astype(np.float32)
281 |
282 |
283 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
def bgr2ycbcr(image: np.ndarray, use_y_channel: bool = False) -> np.ndarray:
    """Convert a BGR (OpenCV channel order) image to YCbCr like MATLAB's ``rgb2ycbcr``.

    Same coefficients as ``rgb2ycbcr`` with the input channel order reversed. With
    channels-last input in [0, 1], Y lands in [16, 235] before the final division by 255.

    Args:
        image (np.ndarray): Image input in BGR format, channels last.
        use_y_channel (bool): Extract Y channel separately. Default: ``False``.

    Returns:
        ndarray: float32 YCbCr image (output channel order Y, Cb, Cr); only the Y plane
        (channel axis dropped) when ``use_y_channel`` is set.
    """

    if use_y_channel:
        # Luma only: weighted sum over the (B, G, R) channel axis.
        ycbcr = np.dot(image, [24.966, 128.553, 65.481]) + 16.0
    else:
        transform = [[24.966, 112.0, -18.214],
                     [128.553, -74.203, -93.786],
                     [65.481, -37.797, 112.0]]
        ycbcr = np.matmul(image, transform) + [16, 128, 128]

    return (ycbcr / 255.).astype(np.float32)
304 |
305 |
306 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
def ycbcr2rgb(image: np.ndarray) -> np.ndarray:
    """Implementation of ycbcr2rgb function in Matlab under Python language.

    Inverse of ``rgb2ycbcr``'s full conversion: expects channels-last YCbCr scaled to
    [0, 1] (as ``rgb2ycbcr`` returns) and yields RGB in roughly [0, 1]. The input array
    is not modified.

    Args:
        image (np.ndarray): Image input in YCbCr format.

    Returns:
        ndarray: RGB image array data, cast back to the input's dtype.
    """

    image_dtype = image.dtype
    # Scale a copy: an in-place ``image *= 255.`` would silently modify the caller's array.
    scaled = image * 255.

    rgb = np.matmul(scaled, [[0.00456621, 0.00456621, 0.00456621],
                             [0, -0.00153632, 0.00791071],
                             [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]

    rgb /= 255.
    return rgb.astype(image_dtype)
328 |
329 |
330 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
def ycbcr2bgr(image: np.ndarray) -> np.ndarray:
    """Implementation of ycbcr2bgr function in Matlab under Python language.

    Inverse of ``bgr2ycbcr``'s full conversion: expects channels-last YCbCr scaled to
    [0, 1] (as ``bgr2ycbcr`` returns) and yields BGR in roughly [0, 1]. The input array
    is not modified.

    Args:
        image (np.ndarray): Image input in YCbCr format.

    Returns:
        ndarray: BGR image array data, cast back to the input's dtype.
    """

    image_dtype = image.dtype
    # Scale a copy: an in-place ``image *= 255.`` would silently modify the caller's array.
    scaled = image * 255.

    bgr = np.matmul(scaled, [[0.00456621, 0.00456621, 0.00456621],
                             [0.00791071, -0.00153632, 0],
                             [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921]

    bgr /= 255.
    return bgr.astype(image_dtype)
352 |
353 |
def center_crop(lr_image: np.ndarray, hr_image: np.ndarray, hr_image_size: int, upscale_factor: int) -> [np.ndarray, np.ndarray]:
    """Crop aligned square patches from the center of an LR/HR image pair.

    The HR patch is ``hr_image_size`` pixels square with its top-left corner at
    ``((H - size) // 2, (W - size) // 2)``. The LR patch is
    ``hr_image_size // upscale_factor`` pixels square, starting at that corner
    floor-divided by ``upscale_factor``.

    Args:
        lr_image (np.ndarray): The input low-resolution image for `OpenCV.imread`.
        hr_image (np.ndarray): The input high-resolution image for `OpenCV.imread`.
        hr_image_size (int): The size of the captured high-resolution image area.
        upscale_factor (int): Image up scale factor.

    Returns:
        (np.ndarray, np.ndarray): ``(patch_lr_image, patch_hr_image)``.
    """

    hr_height, hr_width = hr_image.shape[:2]
    lr_patch_size = hr_image_size // upscale_factor

    hr_y = (hr_height - hr_image_size) // 2
    hr_x = (hr_width - hr_image_size) // 2
    lr_y, lr_x = hr_y // upscale_factor, hr_x // upscale_factor

    lr_patch = lr_image[lr_y:lr_y + lr_patch_size, lr_x:lr_x + lr_patch_size, ...]
    hr_patch = hr_image[hr_y:hr_y + hr_image_size, hr_x:hr_x + hr_image_size, ...]
    return lr_patch, hr_patch
383 |
384 |
def random_crop(lr_image: np.ndarray, hr_image: np.ndarray, hr_image_size: int, upscale_factor: int) -> [np.ndarray, np.ndarray]:
    """Crop aligned square patches at a random position of an LR/HR image pair.

    The HR top-left corner is drawn uniformly with ``random.randint`` (row first,
    then column). The LR patch is ``hr_image_size // upscale_factor`` pixels square,
    starting at that corner floor-divided by ``upscale_factor``.

    Args:
        lr_image (np.ndarray): The input low-resolution image for `OpenCV.imread`.
        hr_image (np.ndarray): The input high-resolution image for `OpenCV.imread`.
        hr_image_size (int): The size of the captured high-resolution image area.
        upscale_factor (int): Image up scale factor.

    Returns:
        (np.ndarray, np.ndarray): ``(patch_lr_image, patch_hr_image)``.
    """

    hr_height, hr_width = hr_image.shape[:2]
    lr_patch_size = hr_image_size // upscale_factor

    # Row is drawn before column; keep this order so seeded runs stay reproducible.
    hr_y = random.randint(0, hr_height - hr_image_size)
    hr_x = random.randint(0, hr_width - hr_image_size)
    lr_y, lr_x = hr_y // upscale_factor, hr_x // upscale_factor

    lr_patch = lr_image[lr_y:lr_y + lr_patch_size, lr_x:lr_x + lr_patch_size, ...]
    hr_patch = hr_image[hr_y:hr_y + hr_image_size, hr_x:hr_x + hr_image_size, ...]
    return lr_patch, hr_patch
414 |
415 |
def random_rotate(lr_image: np.ndarray, hr_image: np.ndarray, angles: list, lr_center=None, hr_center=None, scale_factor: float = 1.0) -> [np.ndarray, np.ndarray]:
    """Rotate an LR/HR image pair by one randomly chosen angle.

    Args:
        lr_image (np.ndarray): The input low-resolution image for `OpenCV.imread`.
        hr_image (np.ndarray): The input high-resolution image for `OpenCV.imread`.
        angles (list): Candidate rotation angles (degrees, as ``cv2.getRotationMatrix2D``
            expects); one is drawn with ``random.choice`` and applied to both images.
        lr_center (tuple[int]): Low-resolution image rotation center ``(x, y)``. If None,
            ``(width // 2, height // 2)`` of the LR image is used. ``Default: None``.
        hr_center (tuple[int]): High-resolution image rotation center ``(x, y)``. If None,
            ``(width // 2, height // 2)`` of the HR image is used. ``Default: None``.
        scale_factor (float): Isotropic scale passed to ``cv2.getRotationMatrix2D``. Default: 1.0.

    Returns:
        (np.ndarray, np.ndarray): The rotated ``(lr_image, hr_image)``, each keeping its input size.
    """

    lr_height, lr_width = lr_image.shape[:2]
    hr_height, hr_width = hr_image.shape[:2]

    if lr_center is None:
        lr_center = (lr_width // 2, lr_height // 2)
    if hr_center is None:
        hr_center = (hr_width // 2, hr_height // 2)

    # A single draw shared by both images keeps the LR/HR pair aligned.
    angle = random.choice(angles)

    def _rotate(image, center, width, height):
        matrix = cv2.getRotationMatrix2D(center, angle, scale_factor)
        return cv2.warpAffine(image, matrix, (width, height))

    rotated_lr = _rotate(lr_image, lr_center, lr_width, lr_height)
    rotated_hr = _rotate(hr_image, hr_center, hr_width, hr_height)
    return rotated_lr, rotated_hr
449 |
450 |
def random_horizontally_flip(lr_image: np.ndarray, hr_image: np.ndarray, p=0.5) -> [np.ndarray, np.ndarray]:
    """Mirror an LR/HR image pair left-to-right with probability ``p``.

    Both images are flipped together (or neither) so the pair stays aligned.

    Args:
        lr_image (np.ndarray): The input low-resolution image for `OpenCV.imread`.
        hr_image (np.ndarray): The input high-resolution image for `OpenCV.imread`.
        p (optional, float): Flip probability. (Default: 0.5)

    Returns:
        (np.ndarray, np.ndarray): The possibly flipped ``(lr_image, hr_image)``.
    """

    if not random.random() < p:
        return lr_image, hr_image

    # flipCode=1 flips around the vertical axis (left <-> right).
    return cv2.flip(lr_image, 1), cv2.flip(hr_image, 1)
468 |
469 |
def random_vertically_flip(lr_image: np.ndarray, hr_image: np.ndarray, p=0.5) -> [np.ndarray, np.ndarray]:
    """Mirror an LR/HR image pair top-to-bottom with probability ``p``.

    Both images are flipped together (or neither) so the pair stays aligned.

    Args:
        lr_image (np.ndarray): The input low-resolution image for `OpenCV.imread`.
        hr_image (np.ndarray): The input high-resolution image for `OpenCV.imread`.
        p (optional, float): Flip probability. (Default: 0.5)

    Returns:
        (np.ndarray, np.ndarray): The possibly flipped ``(lr_image, hr_image)``.
    """

    if not random.random() < p:
        return lr_image, hr_image

    # flipCode=0 flips around the horizontal axis (top <-> bottom).
    return cv2.flip(lr_image, 0), cv2.flip(hr_image, 0)
487 |
--------------------------------------------------------------------------------
/model.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ============================================================================
14 | """Realize the model definition function."""
15 | from math import sqrt
16 |
17 | import torch
18 | from torch import nn
19 |
20 |
class FSRCNN(nn.Module):
    """FSRCNN super-resolution network on single-channel (e.g. Y) images.

    Pipeline: 5x5 feature extraction (1 -> 56 channels) -> 1x1 shrinking (56 -> 12)
    -> four 3x3 mapping layers (12 -> 12) -> 1x1 expanding (12 -> 56) -> 9x9
    transposed convolution with stride ``upscale_factor`` (56 -> 1). Every convolution
    except the final transposed one is followed by a per-channel PReLU.

    Args:
        upscale_factor (int): Image magnification factor.
    """

    def __init__(self, upscale_factor: int) -> None:
        super().__init__()
        # Feature extraction layer.
        self.feature_extraction = nn.Sequential(
            nn.Conv2d(1, 56, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2)),
            nn.PReLU(56),
        )

        # Shrinking layer.
        self.shrink = nn.Sequential(
            nn.Conv2d(56, 12, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            nn.PReLU(12),
        )

        # Mapping layer: four (conv, PReLU) pairs, registered as map.0 .. map.7.
        mapping_layers = []
        for _ in range(4):
            mapping_layers.append(nn.Conv2d(12, 12, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)))
            mapping_layers.append(nn.PReLU(12))
        self.map = nn.Sequential(*mapping_layers)

        # Expanding layer.
        self.expand = nn.Sequential(
            nn.Conv2d(12, 56, kernel_size=(1, 1), stride=(1, 1), padding=(0, 0)),
            nn.PReLU(56),
        )

        # Deconvolution layer: output size is exactly upscale_factor times the input.
        self.deconv = nn.ConvTranspose2d(56, 1,
                                         kernel_size=(9, 9),
                                         stride=(upscale_factor, upscale_factor),
                                         padding=(4, 4),
                                         output_padding=(upscale_factor - 1, upscale_factor - 1))

        # Initialize model weights.
        self._initialize_weights()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._forward_impl(x)

    # Support torch.script function.
    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
        features = self.feature_extraction(x)
        return self.deconv(self.expand(self.map(self.shrink(features))))

    def _initialize_weights(self) -> None:
        """Zero-mean Gaussian init.

        Each ``nn.Conv2d`` gets std ``sqrt(2 / (out_channels * kh * kw))``; the final
        transposed convolution gets std 0.001. All biases are set to zero.
        """
        for module in self.modules():
            if not isinstance(module, nn.Conv2d):
                continue
            fan_out = module.out_channels * module.weight.data[0][0].numel()
            nn.init.normal_(module.weight.data, mean=0.0, std=sqrt(2 / fan_out))
            nn.init.zeros_(module.bias.data)

        nn.init.normal_(self.deconv.weight.data, mean=0.0, std=0.001)
        nn.init.zeros_(self.deconv.bias.data)
88 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | opencv-python
2 | numpy
3 | torch
4 | tqdm
5 | setuptools
6 | torchvision
7 | Pillow
8 | natsort
--------------------------------------------------------------------------------
/results/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/samples/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/scripts/augment_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import argparse
15 | import os
16 | import shutil
17 | from multiprocessing import Pool
18 |
19 | import cv2
20 | from tqdm import tqdm
21 |
22 | import data_utils
23 |
24 |
def main(args) -> None:
    """Run ``worker`` on every entry of ``args.images_dir`` using a process pool.

    ``args.output_dir`` is deleted (if it exists) and recreated first. One task per
    directory entry is submitted to a ``multiprocessing.Pool`` of ``args.num_workers``
    processes; failures are reported per file instead of being silently dropped.
    """
    if os.path.exists(args.output_dir):
        shutil.rmtree(args.output_dir)
    os.makedirs(args.output_dir)

    # Get all image paths
    image_file_names = os.listdir(args.images_dir)

    # Process images with multiple worker processes
    progress_bar = tqdm(total=len(image_file_names), unit="image", desc="Data augment")
    workers_pool = Pool(args.num_workers)
    for image_file_name in image_file_names:
        workers_pool.apply_async(
            worker,
            args=(image_file_name, args),
            callback=lambda _: progress_bar.update(1),
            # Without an error_callback, an exception raised in ``worker`` is stored in the
            # discarded AsyncResult and lost. The default argument binds the current name
            # (avoids the late-binding closure pitfall).
            error_callback=lambda error, name=image_file_name: progress_bar.write(
                f"Failed to augment `{name}`: {error!r}"),
        )
    workers_pool.close()
    workers_pool.join()
    progress_bar.close()
41 |
42 |
def worker(image_file_name, args) -> None:
    """Save ``data_utils.imresize`` copies of one image at scales 1.0, 0.9, 0.8, 0.7 and 0.6.

    Outputs go to ``args.output_dir`` as ``<stem>_<NN><ext>`` with ``NN`` = 01..05, where
    ``<stem>`` and ``<ext>`` come from ``os.path.splitext(image_file_name)``.

    Raises:
        ValueError: If OpenCV cannot read the input file.
    """
    image = cv2.imread(f"{args.images_dir}/{image_file_name}")
    if image is None:
        # cv2.imread signals unreadable / non-image files by returning None, not by raising.
        raise ValueError(f"Can not read image file `{image_file_name}`.")

    # splitext keeps stems that contain dots intact ("a.b.png" -> "a.b", ".png"), so
    # such files no longer lose their prefix or collide with each other.
    file_stem, file_extension = os.path.splitext(image_file_name)

    # Data augment
    for index, scale_ratio in enumerate([1.0, 0.9, 0.8, 0.7, 0.6], start=1):
        new_image = data_utils.imresize(image, scale_ratio)
        # Save all images
        cv2.imwrite(f"{args.output_dir}/{file_stem}_{index:02d}{file_extension}", new_image)
54 |
55 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare database scripts.")
    # Both directories are mandatory: without them main() would fail later with a
    # TypeError on None instead of a clear usage message.
    parser.add_argument("--images_dir", type=str, required=True, help="Path to input image directory.")
    parser.add_argument("--output_dir", type=str, required=True, help="Path to generator image directory.")
    # Optional: when omitted (None), multiprocessing.Pool uses os.cpu_count() processes.
    parser.add_argument("--num_workers", type=int, help="How many threads to open at the same time.")
    args = parser.parse_args()

    main(args)
64 |
--------------------------------------------------------------------------------
/scripts/data_utils.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import math
15 | from typing import Any
16 |
17 | import cv2
18 | import numpy as np
19 | import torch
20 |
21 |
22 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
def cubic(x: Any):
    """Evaluate MATLAB's bicubic interpolation kernel element-wise.

    Piecewise cubic in ``|x|``: one polynomial on ``|x| <= 1``, another on
    ``1 < |x| <= 2``, and zero elsewhere.

    Args:
        x: Tensor of sample offsets.

    Returns:
        Tensor of kernel values with the same shape and dtype as ``x``.
    """

    abs_x = torch.abs(x)
    abs_x2 = abs_x ** 2
    abs_x3 = abs_x ** 3

    near_mask = (abs_x <= 1).type_as(abs_x)
    far_mask = ((abs_x > 1) & (abs_x <= 2)).type_as(abs_x)

    near_part = (1.5 * abs_x3 - 2.5 * abs_x2 + 1) * near_mask
    far_part = (-0.5 * abs_x3 + 2.5 * abs_x2 - 4 * abs_x + 2) * far_mask
    return near_part + far_part
38 |
39 |
40 | # Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
def calculate_weights_indices(in_length: int, out_length: int, scale: float, kernel_width: int, antialiasing: bool):
    """Implementation of `calculate_weights_indices` function in Matlab under Python language.

    For one image axis, compute the bicubic weights and the input positions that each
    output sample combines. ``imresize`` calls this once for the height axis and once
    for the width axis.

    Args:
        in_length (int): Input length.
        out_length (int): Output length.
        scale (float): Scale factor.
        kernel_width (int): Kernel width (widened to ``kernel_width / scale`` when
            down-sampling with antialiasing).
        antialiasing (bool): Whether to apply antialiasing when down-sampling operations.
            Caution: Bicubic down-sampling in PIL uses antialiasing by default.

    Returns:
        tuple: ``(weights, indices, sym_len_s, sym_len_e)``:

        - ``weights``: ``(out_length, k)`` float tensor of kernel weights; rows are
          normalized to sum to 1 before the first/last columns may be trimmed.
        - ``indices``: ``(out_length, k)`` float tensor of input positions, shifted so that
          0 is the first element of the axis after ``sym_len_s`` extra samples have been
          prepended (``imresize`` fills them by mirroring).
        - ``sym_len_s`` (int): Number of samples needed before the start of the axis.
        - ``sym_len_e`` (int): Number of samples needed past the end of the axis.
    """

    if (scale < 1) and antialiasing:
        # Use a modified kernel (larger kernel width) to simultaneously
        # interpolate and antialiasing
        kernel_width = kernel_width / scale

    # Output-space coordinates (1-based, MATLAB style)
    x = torch.linspace(1, out_length, out_length)

    # Input-space coordinates. Calculate the inverse mapping such that 0.5
    # in output space maps to 0.5 in input space, and 0.5 + scale in output
    # space maps to 1.5 in input space.
    u = x / scale + 0.5 * (1 - 1 / scale)

    # What is the left-most pixel that can be involved in the computation?
    left = torch.floor(u - kernel_width / 2)

    # What is the maximum number of pixels that can be involved in the
    # computation? Note: it's OK to use an extra pixel here; if the
    # corresponding weights are all zero, it will be eliminated at the end
    # of this function.
    p = math.ceil(kernel_width) + 2

    # The indices of the input pixels involved in computing the k-th output
    # pixel are in row k of the indices matrix.
    indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
        out_length, p)

    # The weights used to compute the k-th output pixel are in row k of the
    # weights matrix.
    distance_to_center = u.view(out_length, 1).expand(out_length, p) - indices

    # apply cubic kernel
    if (scale < 1) and antialiasing:
        weights = scale * cubic(distance_to_center * scale)
    else:
        weights = cubic(distance_to_center)

    # Normalize the weights matrix so that each row sums to 1.
    weights_sum = torch.sum(weights, 1).view(out_length, 1)
    weights = weights / weights_sum.expand(out_length, p)

    # If a column in weights is all zero, get rid of it. only consider the
    # first and last column.
    # NOTE(review): ``weights_zero_tmp`` counts zero entries per column, and the checks
    # below fire whenever that count is non-zero (``math.isclose(n, 0)`` with the default
    # ``abs_tol=0`` is just ``n == 0``), i.e. when the column has *any* exact zero, not
    # only when it is entirely zero. Also, the first ``narrow`` keeps columns 1..p-2
    # (dropping the last column too), and ``p - 2`` is not updated afterwards, so if only
    # the last column triggers, the last two columns are dropped. Presumably inherited
    # from the referenced BasicSR code; confirm against MATLAB ``imresize`` before
    # changing, since resize outputs would shift.
    weights_zero_tmp = torch.sum((weights == 0), 0)
    if not math.isclose(weights_zero_tmp[0], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 1, p - 2)
        weights = weights.narrow(1, 1, p - 2)
    if not math.isclose(weights_zero_tmp[-1], 0, rel_tol=1e-6):
        indices = indices.narrow(1, 0, p - 2)
        weights = weights.narrow(1, 0, p - 2)
    weights = weights.contiguous()
    indices = indices.contiguous()
    # Indices are 1-based positions into the axis: positions < 1 need samples before
    # the start, positions > in_length need samples past the end.
    sym_len_s = -indices.min() + 1
    sym_len_e = indices.max() - in_length
    # Shift so the smallest index becomes 0 in the padded axis.
    indices = indices + sym_len_s - 1
    return weights, indices, int(sym_len_s), int(sym_len_e)
110 |
111 |
# Code reference `https://github.com/xinntao/BasicSR/blob/master/basicsr/utils/matlab_functions.py`
def imresize(image: Any, scale_factor: float, antialiasing: bool = True) -> Any:
    """Implementation of `imresize` function in Matlab under Python language.

    Bicubic resampling is applied separably: first along the height, then along
    the width, with symmetric (mirror) padding at the image borders.

    Args:
        image: The input image. Either a ``np.ndarray`` of shape (h, w) or (h, w, c),
            or a ``torch.Tensor`` of shape (h, w) or (c, h, w).
        scale_factor (float): Scale factor. The same scale applies for both height and width.
        antialiasing (bool): Whether to apply antialiasing when down-sampling operations.
            Caution: Bicubic down-sampling in `PIL` uses antialiasing by default. Default: ``True``.

    Returns:
        The resized image in the same container type and layout as ``image``
        (numpy: (h, w) or (h, w, c); torch: (h, w) or (c, h, w)), holding float32
        values that are neither clipped nor rounded.
    """
    squeeze_flag = image.ndim == 2
    numpy_type = isinstance(image, np.ndarray)
    if numpy_type:
        if squeeze_flag:
            image = image[:, :, None]
        # ascontiguousarray: torch.from_numpy rejects negative strides (e.g. `img[..., ::-1]`).
        image = torch.from_numpy(np.ascontiguousarray(image.transpose(2, 0, 1))).float()
    elif squeeze_flag:
        image = image.unsqueeze(0)

    in_c, in_h, in_w = image.size()
    out_h, out_w = math.ceil(in_h * scale_factor), math.ceil(in_w * scale_factor)
    kernel_width = 4

    # get weights and indices
    weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale_factor, kernel_width, antialiasing)
    weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale_factor, kernel_width, antialiasing)

    # Resize along H (dim 1) first, then along W (dim 2) of the intermediate result.
    out_1 = _imresize_along_dim(image, 1, out_h, weights_h, indices_h, sym_len_hs, sym_len_he)
    out_2 = _imresize_along_dim(out_1, 2, out_w, weights_w, indices_w, sym_len_ws, sym_len_we)

    if squeeze_flag:
        out_2 = out_2.squeeze(0)
    if numpy_type:
        out_2 = out_2.numpy()
        if not squeeze_flag:
            # Back from (c, h, w) to the numpy (h, w, c) layout of the input.
            out_2 = out_2.transpose(1, 2, 0)

    return out_2


def _imresize_symmetric_pad(tensor, dim, pad_start, pad_end):
    """Return a float32 copy of ``tensor`` mirror-padded by ``pad_start`` / ``pad_end`` along ``dim``."""
    length = tensor.size(dim)
    shape = list(tensor.shape)
    shape[dim] = length + pad_start + pad_end
    padded = torch.empty(shape, dtype=torch.float32)
    padded.narrow(dim, pad_start, length).copy_(tensor)
    # Guard both ends: a zero-length pad must be skipped (`x[-0:]` would select the whole tensor).
    if pad_start > 0:
        padded.narrow(dim, 0, pad_start).copy_(tensor.narrow(dim, 0, pad_start).flip(dim))
    if pad_end > 0:
        padded.narrow(dim, pad_start + length, pad_end).copy_(tensor.narrow(dim, length - pad_end, pad_end).flip(dim))
    return padded


def _imresize_along_dim(tensor, dim, out_length, weights, indices, pad_start, pad_end):
    """Resample a (c, h, w) tensor along ``dim`` (1 = height, 2 = width) with per-output-sample ``weights``."""
    padded = _imresize_symmetric_pad(tensor, dim, pad_start, pad_end)
    taps = weights.size(1)
    shape = list(tensor.shape)
    shape[dim] = out_length
    out = torch.empty(shape, dtype=torch.float32)
    for i in range(out_length):
        # indices were already shifted into the padded coordinate system by calculate_weights_indices.
        start = int(indices[i][0])
        window = padded.narrow(dim, start, taps)
        # Contract the tap axis with this sample's weights for all channels at once.
        out.select(dim, i).copy_(torch.tensordot(window, weights[i], dims=([dim], [0])))
    return out
197 |
--------------------------------------------------------------------------------
/scripts/prepare_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import argparse
15 | import multiprocessing
16 | import os
17 | import shutil
18 |
19 | import cv2
20 | import numpy as np
21 | from tqdm import tqdm
22 |
23 |
def main(args) -> None:
    """Crop every image in ``args.images_dir`` into patches under ``args.output_dir`` with a process pool.

    Any existing ``args.output_dir`` is deleted and recreated first. Failures in
    individual workers are reported on the progress bar instead of being lost.
    """
    if os.path.exists(args.output_dir):
        shutil.rmtree(args.output_dir)
    os.makedirs(args.output_dir)

    # Get all image paths
    image_file_names = os.listdir(args.images_dir)

    progress_bar = tqdm(total=len(image_file_names), unit="image", desc="Prepare split image")

    def report_failure(image_file_name, error) -> None:
        # Without an error_callback, apply_async keeps the exception in an AsyncResult nobody reads,
        # so a failing worker would vanish silently and the bar would never reach 100%.
        progress_bar.write(f"Failed to process `{image_file_name}`: {error!r}")
        progress_bar.update(1)

    workers_pool = multiprocessing.Pool(args.num_workers)
    for image_file_name in image_file_names:
        workers_pool.apply_async(worker,
                                 args=(image_file_name, args),
                                 callback=lambda _: progress_bar.update(1),
                                 error_callback=lambda error, name=image_file_name: report_failure(name, error))
    workers_pool.close()
    workers_pool.join()
    progress_bar.close()
40 |
41 |
def worker(image_file_name, args) -> None:
    """Cut one image into square crops and write them to ``args.output_dir``.

    Crops are ``args.image_size`` x ``args.image_size`` windows taken every
    ``args.step`` pixels in both directions. An image smaller than one window
    produces no crops. The k-th crop of ``name.ext`` is saved as ``name_{k:04d}.ext``.

    Raises:
        ValueError: If the file cannot be decoded as an image.
    """
    image_path = os.path.join(args.images_dir, image_file_name)
    image = cv2.imread(image_path, cv2.IMREAD_UNCHANGED)
    if image is None:
        # cv2.imread returns None for unreadable / non-image files instead of raising.
        raise ValueError(f"Cannot read image `{image_path}`.")

    image_height, image_width = image.shape[0:2]
    if image_height < args.image_size or image_width < args.image_size:
        return

    # splitext keeps every dot of the stem (`a.b.png` -> `a.b`, `.png`), unlike split('.')[-2].
    stem, extension = os.path.splitext(image_file_name)

    index = 1
    for pos_y in range(0, image_height - args.image_size + 1, args.step):
        for pos_x in range(0, image_width - args.image_size + 1, args.step):
            crop_image = np.ascontiguousarray(image[pos_y:pos_y + args.image_size, pos_x:pos_x + args.image_size, ...])
            cv2.imwrite(os.path.join(args.output_dir, f"{stem}_{index:04d}{extension}"), crop_image)
            index += 1
58 |
59 |
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Prepare database scripts.")
    # Required: without them each worker fails on a None value (and the failure is easy to miss).
    parser.add_argument("--images_dir", type=str, required=True, help="Path to input image directory.")
    parser.add_argument("--output_dir", type=str, required=True, help="Path to generator image directory.")
    parser.add_argument("--image_size", type=int, required=True, help="Low-resolution image size from raw image.")
    parser.add_argument("--step", type=int, required=True, help="Crop image similar to sliding window.")
    # Optional: multiprocessing.Pool(None) uses os.cpu_count() worker processes.
    parser.add_argument("--num_workers", type=int, help="How many threads to open at the same time.")
    args = parser.parse_args()

    if args.image_size <= 0 or args.step <= 0:
        parser.error("--image_size and --step must be positive integers.")

    main(args)
70 |
--------------------------------------------------------------------------------
/scripts/run.py:
--------------------------------------------------------------------------------
"""Run the T91 preprocessing pipeline for FSRCNN: augment, crop into patches, split train/valid."""
import os
import subprocess
import sys

# All steps use paths relative to this `scripts/` directory, so run them from here
# regardless of the directory the pipeline is launched from.
SCRIPTS_DIR = os.path.dirname(os.path.abspath(__file__))

PIPELINE = [
    # Prepare dataset
    ["./augment_dataset.py", "--images_dir", "../data/T91/original", "--output_dir", "../data/T91/FSRCNN/original",
     "--num_workers", "10"],
    ["./prepare_dataset.py", "--images_dir", "../data/T91/FSRCNN/original", "--output_dir", "../data/T91/FSRCNN/train",
     "--image_size", "32", "--step", "16", "--num_workers", "10"],
    # Split train and valid
    ["./split_train_valid_dataset.py", "--train_images_dir", "../data/T91/FSRCNN/train",
     "--valid_images_dir", "../data/T91/FSRCNN/valid", "--valid_samples_ratio", "0.1"],
]


def main() -> None:
    """Run each pipeline step in order, stopping at the first one that fails."""
    for step in PIPELINE:
        # sys.executable keeps every step on the current interpreter/venv (a bare `python` may differ);
        # check=True raises CalledProcessError so a failed step does not feed bad data to the next one.
        subprocess.run([sys.executable, *step], cwd=SCRIPTS_DIR, check=True)


if __name__ == "__main__":
    main()
9 |
--------------------------------------------------------------------------------
/scripts/split_train_valid_dataset.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import argparse
15 | import os
16 | import random
17 | import shutil
18 |
19 | from tqdm import tqdm
20 |
21 |
def main(args) -> None:
    """Move a random ``valid_samples_ratio`` fraction of the training images into the validation directory.

    Both directories are created if missing. The images are moved, not copied, so the
    validation set is held out of the training set. When ``args.seed`` is set (optional
    attribute), the same files are chosen on every run.
    """
    os.makedirs(args.train_images_dir, exist_ok=True)
    os.makedirs(args.valid_images_dir, exist_ok=True)

    # Sort so that a fixed seed selects the same files regardless of os.listdir() order.
    train_files = sorted(os.listdir(args.train_images_dir))
    seed = getattr(args, "seed", None)
    rng = random if seed is None else random.Random(seed)
    valid_files = rng.sample(train_files, int(len(train_files) * args.valid_samples_ratio))

    process_bar = tqdm(valid_files, total=len(valid_files), unit="image", desc="Split train/valid dataset")

    for image_file_name in process_bar:
        # Move (not copy): copying would leave every validation image inside the training set.
        shutil.move(os.path.join(args.train_images_dir, image_file_name),
                    os.path.join(args.valid_images_dir, image_file_name))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Split train and valid dataset scripts.")
    parser.add_argument("--train_images_dir", type=str, required=True, help="Path to train image directory.")
    parser.add_argument("--valid_images_dir", type=str, required=True, help="Path to valid image directory.")
    parser.add_argument("--valid_samples_ratio", type=float, required=True, help="What percentage of the data is extracted from the training set into the validation set.")
    parser.add_argument("--seed", type=int, default=None, help="Random seed for a reproducible split. Default: unseeded.")
    args = parser.parse_args()

    if not 0.0 <= args.valid_samples_ratio <= 1.0:
        parser.error("--valid_samples_ratio must be in [0, 1].")

    main(args)
45 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | import io
15 | import os
16 | import sys
17 | from shutil import rmtree
18 |
19 | from setuptools import Command
20 | from setuptools import find_packages
21 | from setuptools import setup
22 |
# Package metadata consumed by the `setup()` call.
NAME = "fsrcnn_pytorch"
DESCRIPTION = "Accelerating the Super-Resolution Convolutional Neural Network."
URL = "https://github.com/Lornatang/FSRCNN-PyTorch"
EMAIL = "liu_changyu@dakewe.com"
AUTHOR = "Liu Goodfellow"
REQUIRES_PYTHON = ">=3.8.0"
VERSION = "1.2.2"

# Hard runtime dependencies installed together with the package.
REQUIRED = ["torch"]

# Optional dependency groups passed to `extras_require`; none are declared yet.
EXTRAS = {}

# Absolute directory that contains this setup script.
here = os.path.abspath(os.path.dirname(__file__))

# Long description: the README text prefixed by a newline, or DESCRIPTION when the README is absent.
try:
    with io.open(os.path.join(here, "README.md"), encoding="utf-8") as f:
        long_description = f"\n{f.read()}"
except FileNotFoundError:
    long_description = DESCRIPTION

# Resolve the version: use the hard-coded VERSION when set, otherwise exec `<slug>/__version__.py` into `about`.
about = {}
if VERSION:
    about["__version__"] = VERSION
else:
    project_slug = NAME.replace("-", "_").replace(" ", "_").lower()
    version_file = os.path.join(here, project_slug, "__version__.py")
    with open(version_file) as f:
        exec(f.read(), about)
56 |
57 |
class UploadCommand(Command):
    """`python setup.py upload`: rebuild `dist/`, upload it with twine, then create and push a version tag.

    Every external command is checked, so a failed build aborts before anything
    is uploaded or tagged.
    """

    description = "Build and publish the package."
    user_options = []

    @staticmethod
    def status(s):
        """Print ``s`` in bold (ANSI escape codes)."""
        print("\033[1m{0}\033[0m".format(s))

    def initialize_options(self):
        pass

    def finalize_options(self):
        pass

    @staticmethod
    def _run_checked(command):
        """Run ``command`` (an argv list) from the project root; raise CalledProcessError on non-zero exit."""
        import subprocess

        subprocess.check_call(command, cwd=here)

    def run(self):
        import glob

        try:
            self.status("Removing previous builds…")
            rmtree(os.path.join(here, "dist"))
        except OSError:
            pass

        # The package is Python-3 only (REQUIRES_PYTHON, classifiers), so no `--universal` (py2.py3) wheel.
        self.status("Building Source and Wheel distribution…")
        self._run_checked([sys.executable, "setup.py", "sdist", "bdist_wheel"])

        self.status("Uploading the package to PyPI via Twine…")
        # Expand `dist/*` here instead of relying on a shell glob.
        self._run_checked(["twine", "upload", *sorted(glob.glob(os.path.join(here, "dist", "*")))])

        self.status("Pushing git tags…")
        self._run_checked(["git", "tag", "v{0}".format(about["__version__"])])
        self._run_checked(["git", "push", "--tags"])

        sys.exit()
90 |
91 |
# Register the distribution; `python setup.py upload` is routed to UploadCommand via cmdclass.
setup(
    # Identity and description.
    name=NAME,
    version=about["__version__"],
    description=DESCRIPTION,
    long_description=long_description,
    long_description_content_type="text/markdown",
    author=AUTHOR,
    author_email=EMAIL,
    url=URL,
    license="Apache",
    # Contents and dependencies.
    python_requires=REQUIRES_PYTHON,
    # NOTE(review): the repository tree shows no directory containing an `__init__.py`,
    # so find_packages() may return an empty list — confirm what should be installed.
    packages=find_packages(exclude=["tests", "*.tests", "*.tests.*", "tests.*"]),
    install_requires=REQUIRED,
    extras_require=EXTRAS,
    include_package_data=True,
    classifiers=[
        # Full list: https://pypi.python.org/pypi?%3Aaction=list_classifiers
        "License :: OSI Approved :: Apache Software License",
        "Programming Language :: Python :: 3 :: Only"
    ],
    # Custom commands.
    cmdclass={"upload": UploadCommand},
)
114 | )
115 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ============================================================================
14 | """File description: Realize the model training function."""
15 | import os
16 | import shutil
17 | import time
18 | from enum import Enum
19 |
20 | import torch
21 | from torch import nn
22 | from torch import optim
23 | from torch.cuda import amp
24 | from torch.utils.data import DataLoader
25 | from torch.utils.tensorboard import SummaryWriter
26 |
27 | import config
28 | from dataset import CUDAPrefetcher
29 | from dataset import TrainValidImageDataset, TestImageDataset
30 | from model import FSRCNN
31 |
32 |
def main():
    """Train FSRCNN, evaluating on the valid and test sets after every epoch.

    A checkpoint is written to ``samples/<exp_name>/epoch_<n>.pth.tar`` every epoch.
    The checkpoint with the highest Test PSNR so far is copied to
    ``results/<exp_name>/best.pth.tar``, and the final epoch's checkpoint to
    ``results/<exp_name>/last.pth.tar``. When ``config.resume`` is set, the model,
    optimizer, start epoch and best PSNR are restored from it first.
    """
    # Initialize training to generate network evaluation indicators
    best_psnr = 0.0

    train_prefetcher, valid_prefetcher, test_prefetcher = load_dataset()
    print("Load train dataset and valid dataset successfully.")

    model = build_model()
    print("Build FSRCNN model successfully.")

    psnr_criterion, pixel_criterion = define_loss()
    print("Define all loss functions successfully.")

    optimizer = define_optimizer(model)
    print("Define all optimizer functions successfully.")

    print("Check whether the pretrained model is restored...")
    if config.resume:
        # Load checkpoint model
        checkpoint = torch.load(config.resume, map_location=lambda storage, loc: storage)
        # Restore the parameters in the training node to this point
        config.start_epoch = checkpoint["epoch"]
        best_psnr = checkpoint["best_psnr"]
        # Keep only checkpoint weights whose names exist in the current model.
        model_state_dict = model.state_dict()
        new_state_dict = {k: v for k, v in checkpoint["state_dict"].items() if k in model_state_dict}
        # Overwrite the pretrained model weights to the current model
        model_state_dict.update(new_state_dict)
        model.load_state_dict(model_state_dict)
        # Load the optimizer model
        optimizer.load_state_dict(checkpoint["optimizer"])
        print("Loaded pretrained model weights.")

    # Create a folder of super-resolution experiment results
    samples_dir = os.path.join("samples", config.exp_name)
    results_dir = os.path.join("results", config.exp_name)
    os.makedirs(samples_dir, exist_ok=True)
    os.makedirs(results_dir, exist_ok=True)

    # Create training process log file
    writer = SummaryWriter(os.path.join("samples", "logs", config.exp_name))

    # Initialize the gradient scaler
    scaler = amp.GradScaler()

    try:
        for epoch in range(config.start_epoch, config.epochs):
            train(model, train_prefetcher, psnr_criterion, pixel_criterion, optimizer, epoch, scaler, writer)
            _ = validate(model, valid_prefetcher, psnr_criterion, epoch, writer, "Valid")
            # NOTE(review): the "best" checkpoint is selected by Test PSNR, not Valid PSNR — confirm intended.
            psnr = validate(model, test_prefetcher, psnr_criterion, epoch, writer, "Test")
            print("\n")

            # Automatically save the model with the highest index
            is_best = psnr > best_psnr
            best_psnr = max(psnr, best_psnr)
            checkpoint_path = os.path.join(samples_dir, f"epoch_{epoch + 1}.pth.tar")
            torch.save({"epoch": epoch + 1,
                        "best_psnr": best_psnr,
                        "state_dict": model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                        "scheduler": None},
                       checkpoint_path)
            if is_best:
                shutil.copyfile(checkpoint_path, os.path.join(results_dir, "best.pth.tar"))
            if (epoch + 1) == config.epochs:
                shutil.copyfile(checkpoint_path, os.path.join(results_dir, "last.pth.tar"))
    finally:
        # Flush buffered TensorBoard events, also when training is interrupted.
        writer.close()
101 |
102 |
def load_dataset() -> tuple:
    """Build the train, valid and test data pipelines.

    Returns:
        tuple: ``(train_prefetcher, valid_prefetcher, test_prefetcher)``. Each is a
        ``CUDAPrefetcher`` that wraps a ``DataLoader`` and targets ``config.device``.
    """
    # Load train, test and valid datasets
    train_datasets = TrainValidImageDataset(config.train_image_dir, config.image_size, config.upscale_factor, "Train")
    valid_datasets = TrainValidImageDataset(config.valid_image_dir, config.image_size, config.upscale_factor, "Valid")
    test_datasets = TestImageDataset(config.test_lr_image_dir, config.test_hr_image_dir, config.upscale_factor)

    # DataLoader raises ValueError for persistent_workers=True with num_workers == 0,
    # so keep workers alive between epochs only when worker processes exist.
    persistent = config.num_workers > 0

    train_dataloader = DataLoader(train_datasets,
                                  batch_size=config.batch_size,
                                  shuffle=True,
                                  num_workers=config.num_workers,
                                  pin_memory=True,
                                  drop_last=True,
                                  persistent_workers=persistent)
    valid_dataloader = DataLoader(valid_datasets,
                                  batch_size=config.batch_size,
                                  shuffle=False,
                                  num_workers=config.num_workers,
                                  pin_memory=True,
                                  drop_last=False,
                                  persistent_workers=persistent)
    # Test images may differ in size, hence batch_size=1.
    test_dataloader = DataLoader(test_datasets,
                                 batch_size=1,
                                 shuffle=False,
                                 num_workers=1,
                                 pin_memory=True,
                                 drop_last=False,
                                 persistent_workers=False)

    # Place all data on the preprocessing data loader
    train_prefetcher = CUDAPrefetcher(train_dataloader, config.device)
    valid_prefetcher = CUDAPrefetcher(valid_dataloader, config.device)
    test_prefetcher = CUDAPrefetcher(test_dataloader, config.device)

    return train_prefetcher, valid_prefetcher, test_prefetcher
138 |
139 |
def build_model() -> nn.Module:
    """Instantiate FSRCNN for ``config.upscale_factor`` and move it onto ``config.device``."""
    return FSRCNN(config.upscale_factor).to(config.device)
144 |
145 |
def define_loss() -> tuple:
    """Create the MSE criteria used for PSNR measurement and for the training loss.

    Returns:
        tuple: ``(psnr_criterion, pixel_criterion)``, two separate ``nn.MSELoss``
        instances moved to ``config.device``.
    """
    psnr_criterion = nn.MSELoss().to(config.device)
    pixel_criterion = nn.MSELoss().to(config.device)

    return psnr_criterion, pixel_criterion
151 |
152 |
def define_optimizer(model) -> optim.SGD:
    """Build the SGD optimizer over FSRCNN's sub-networks.

    The deconvolution layer gets a learning rate of ``0.1 * config.model_lr``; every
    other group uses ``config.model_lr``. Momentum, weight decay and Nesterov come from ``config``.
    """
    shared_lr_layers = (model.feature_extraction, model.shrink, model.map, model.expand)
    param_groups = [{"params": layer.parameters()} for layer in shared_lr_layers]
    param_groups.append({"params": model.deconv.parameters(), "lr": config.model_lr * 0.1})

    return optim.SGD(param_groups,
                     lr=config.model_lr,
                     momentum=config.model_momentum,
                     weight_decay=config.model_weight_decay,
                     nesterov=config.model_nesterov)
165 |
166 |
def train(model, train_prefetcher, psnr_criterion, pixel_criterion, optimizer, epoch, scaler, writer) -> None:
    """Run one mixed-precision training epoch over ``train_prefetcher``.

    Args:
        model: The network being trained; switched to train mode here.
        train_prefetcher: ``CUDAPrefetcher`` yielding dicts with ``"lr"`` and ``"hr"`` tensors.
        psnr_criterion: MSE criterion used only to report PSNR as ``10 * log10(1 / mse)``.
        pixel_criterion: Loss that is back-propagated.
        optimizer: Optimizer stepped through ``scaler``.
        epoch (int): Zero-based epoch index (used for display and the TensorBoard step).
        scaler: ``amp.GradScaler`` for loss scaling.
        writer: TensorBoard ``SummaryWriter`` receiving ``"Train/Loss"``.
    """
    # Calculate how many iterations there are under epoch
    batches = len(train_prefetcher)

    batch_time = AverageMeter("Time", ":6.3f")
    data_time = AverageMeter("Data", ":6.3f")
    losses = AverageMeter("Loss", ":6.6f")
    psnres = AverageMeter("PSNR", ":4.2f")
    progress = ProgressMeter(batches, [batch_time, data_time, losses, psnres], prefix=f"Epoch: [{epoch + 1}]")

    # Put the generator in training mode
    model.train()

    batch_index = 0

    end = time.time()
    # enable preload
    train_prefetcher.reset()
    batch_data = train_prefetcher.next()
    while batch_data is not None:
        # measure data loading time
        data_time.update(time.time() - end)

        lr = batch_data["lr"].to(config.device, non_blocking=True)
        hr = batch_data["hr"].to(config.device, non_blocking=True)

        # Initialize the generator gradient
        model.zero_grad()

        # Mixed precision training
        with amp.autocast():
            sr = model(lr)
            loss = pixel_criterion(sr, hr)

        # Scale the loss to avoid fp16 gradient underflow, then step and update the scale.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # The PSNR is a metric only: computing it without autograd avoids building an unused graph on `sr`.
        with torch.no_grad():
            psnr = 10. * torch.log10(1. / psnr_criterion(sr, hr))
        loss_value = loss.item()
        losses.update(loss_value, lr.size(0))
        psnres.update(psnr.item(), lr.size(0))

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()

        # Record training log information
        if batch_index % config.print_frequency == 0:
            writer.add_scalar("Train/Loss", loss_value, batch_index + epoch * batches + 1)
            progress.display(batch_index)

        # Preload the next batch of data
        batch_data = train_prefetcher.next()

        batch_index += 1
228 |
229 |
def validate(model, valid_prefetcher, psnr_criterion, epoch, writer, mode) -> float:
    """Evaluate ``model`` on every batch of ``valid_prefetcher`` and log the mean PSNR.

    Args:
        model: The super-resolution network; switched to eval mode here.
        valid_prefetcher: ``CUDAPrefetcher`` yielding dicts with ``"lr"`` and ``"hr"`` tensors.
        psnr_criterion: MSE criterion used to derive PSNR as ``10 * log10(1 / mse)``.
        epoch (int): Zero-based epoch index; logged as ``epoch + 1``.
        writer: TensorBoard ``SummaryWriter`` receiving ``"<mode>/PSNR"``.
        mode (str): ``"Valid"`` or ``"Test"``.

    Returns:
        float: Batch-size-weighted average PSNR over all batches.

    Raises:
        ValueError: If ``mode`` is neither ``"Valid"`` nor ``"Test"``; checked before any evaluation work.
    """
    # Fail fast instead of finding a bad mode only after a full evaluation pass.
    if mode not in ("Valid", "Test"):
        raise ValueError("Unsupported mode, please use `Valid` or `Test`.")

    batch_time = AverageMeter("Time", ":6.3f", Summary.NONE)
    psnres = AverageMeter("PSNR", ":4.2f", Summary.AVERAGE)
    progress = ProgressMeter(len(valid_prefetcher), [batch_time, psnres], prefix=f"{mode}: ")

    # Put the model in verification mode
    model.eval()

    batch_index = 0

    end = time.time()
    with torch.no_grad():
        # enable preload
        valid_prefetcher.reset()
        batch_data = valid_prefetcher.next()

        while batch_data is not None:
            lr = batch_data["lr"].to(config.device, non_blocking=True)
            hr = batch_data["hr"].to(config.device, non_blocking=True)

            # Mixed precision
            with amp.autocast():
                sr = model(lr)

            # measure accuracy
            psnr = 10. * torch.log10(1. / psnr_criterion(sr, hr))
            psnres.update(psnr.item(), lr.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if batch_index % config.print_frequency == 0:
                progress.display(batch_index)

            # Preload the next batch of data
            batch_data = valid_prefetcher.next()

            batch_index += 1

    # Print average PSNR metrics
    progress.display_summary()

    # Tag is "Valid/PSNR" or "Test/PSNR".
    writer.add_scalar(f"{mode}/PSNR", psnres.avg, epoch + 1)

    return psnres.avg
285 |
286 |
# Copied from "https://github.com/pytorch/examples/blob/master/imagenet/main.py"
class Summary(Enum):
    """Which statistic an ``AverageMeter`` prints in its end-of-run summary."""

    NONE = 0  # print nothing
    AVERAGE = 1  # print the running average
    SUM = 2  # print the accumulated sum
    COUNT = 3  # print the number of samples seen
293 |
294 |
class AverageMeter(object):
    """Track the latest value of a metric and its sample-weighted running average.

    Args:
        name: Label printed before the values.
        fmt: Format spec (including the leading ``:``) applied to ``val`` and ``avg`` in ``str()``.
        summary_type: ``Summary`` member selecting what ``summary()`` reports.
    """

    def __init__(self, name, fmt=":f", summary_type=Summary.AVERAGE):
        self.name = name
        self.fmt = fmt
        self.summary_type = summary_type
        self.reset()

    def reset(self):
        """Zero the latest value, average, sum and sample count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the mean over ``n`` samples and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Produces e.g. "{name} {val:6.3f} ({avg:6.3f})" before filling in the attributes.
        template = f"{{name}} {{val{self.fmt}}} ({{avg{self.fmt}}})"
        return template.format(**self.__dict__)

    def summary(self):
        """Return the one-line summary selected by ``summary_type``."""
        templates = (
            (Summary.NONE, ""),
            (Summary.AVERAGE, "{name} {avg:.2f}"),
            (Summary.SUM, "{name} {sum:.2f}"),
            (Summary.COUNT, "{name} {count:.2f}"),
        )
        for kind, template in templates:
            if self.summary_type is kind:
                return template.format(**self.__dict__)
        raise ValueError(f"Invalid summary type {self.summary_type}")
333 |
334 |
class ProgressMeter(object):
    """Print per-batch progress lines and an end-of-run summary for a list of meters.

    Args:
        num_batches: Total number of batches (an int), used to size the batch counter.
        meters: Objects whose ``str()`` is printed per batch and whose ``summary()`` is printed at the end.
        prefix: Text placed before the batch counter.
    """

    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print the prefix, ``[batch/total]`` and every meter, tab-separated."""
        header = self.prefix + self.batch_fmtstr.format(batch)
        print("\t".join([header, *(str(meter) for meter in self.meters)]))

    def display_summary(self):
        """Print `` *`` followed by each meter's summary, space-separated."""
        print(" ".join([" *", *(meter.summary() for meter in self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        # Pad the current batch number to the width of the total, e.g. "[{:3d}/120]".
        width = len(str(num_batches // 1))
        field = f"{{:{width}d}}"
        return f"[{field}/{field.format(num_batches)}]"


if __name__ == "__main__":
    main()
359 |
--------------------------------------------------------------------------------
/validate.py:
--------------------------------------------------------------------------------
1 | # Copyright 2021 Dakewe Biotech Corporation. All Rights Reserved.
2 | # Licensed under the Apache License, Version 2.0 (the "License");
3 | # you may not use this file except in compliance with the License.
4 | # You may obtain a copy of the License at
5 | #
6 | # http://www.apache.org/licenses/LICENSE-2.0
7 | #
8 | # Unless required by applicable law or agreed to in writing, software
9 | # distributed under the License is distributed on an "AS IS" BASIS,
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | # ==============================================================================
14 | """File description: Realize the verification function after model training."""
15 | import os
16 |
17 | import cv2
18 | import numpy as np
19 | import torch
20 | from natsort import natsorted
21 |
22 | import config
23 | import imgproc
24 | from model import FSRCNN
25 |
26 |
def main() -> None:
    """Evaluate a trained FSRCNN checkpoint on the test images and print the mean Y-channel PSNR.

    Each HR image in ``config.hr_dir`` is paired with the file of the same name in
    ``config.lr_dir``. Only the Y channel is super-resolved. The SR result is merged
    with the HR Cb/Cr channels and written to ``config.sr_dir``.

    Raises:
        FileNotFoundError: If an LR or HR image cannot be read.
    """
    # Initialize the super-resolution model
    model = FSRCNN(config.upscale_factor).to(config.device)
    print("Build FSRCNN model successfully.")

    # Load the super-resolution model weights
    checkpoint = torch.load(config.model_path, map_location=lambda storage, loc: storage)
    model.load_state_dict(checkpoint["state_dict"])
    print(f"Load FSRCNN model weights `{os.path.abspath(config.model_path)}` successfully.")

    # Create a folder of super-resolution experiment results
    results_dir = os.path.join("results", "test", config.exp_name)
    os.makedirs(results_dir, exist_ok=True)
    # SR images are saved to config.sr_dir (not results_dir), so make sure it exists;
    # cv2.imwrite typically just returns False for a missing directory.
    os.makedirs(config.sr_dir, exist_ok=True)

    # Start the verification mode of the model.
    model.eval()
    # Turn on half-precision inference.
    model.half()

    # Get a list of test image file names.
    file_names = natsorted(os.listdir(config.hr_dir))
    total_files = len(file_names)
    if total_files == 0:
        print(f"No test images found in `{os.path.abspath(config.hr_dir)}`.")
        return

    total_psnr = 0.0
    for file_name in file_names:
        lr_image_path = os.path.join(config.lr_dir, file_name)
        sr_image_path = os.path.join(config.sr_dir, file_name)
        hr_image_path = os.path.join(config.hr_dir, file_name)

        print(f"Processing `{os.path.abspath(hr_image_path)}`...")
        # cv2.imread returns None (rather than raising) for missing or unreadable files.
        lr_image = cv2.imread(lr_image_path)
        hr_image = cv2.imread(hr_image_path)
        if lr_image is None or hr_image is None:
            raise FileNotFoundError(f"Cannot read `{lr_image_path}` or `{hr_image_path}`.")
        lr_image = lr_image.astype(np.float32) / 255.0
        hr_image = hr_image.astype(np.float32) / 255.0

        # Convert BGR image to YCbCr image
        lr_ycbcr_image = imgproc.bgr2ycbcr(lr_image, use_y_channel=False)
        hr_ycbcr_image = imgproc.bgr2ycbcr(hr_image, use_y_channel=False)

        # Only the LR luma is needed; the HR chroma is reused to colourise the output.
        lr_y_image, _, _ = cv2.split(lr_ycbcr_image)
        hr_y_image, hr_cb_image, hr_cr_image = cv2.split(hr_ycbcr_image)

        # Convert Y image data convert to Y tensor data
        lr_y_tensor = imgproc.image2tensor(lr_y_image, range_norm=False, half=True).to(config.device).unsqueeze_(0)
        hr_y_tensor = imgproc.image2tensor(hr_y_image, range_norm=False, half=True).to(config.device).unsqueeze_(0)

        # Only reconstruct the Y channel image data.
        with torch.no_grad():
            sr_y_tensor = model(lr_y_tensor).clamp_(0, 1.0)

        # Compute the MSE in float32: squared errors of ~1e-4 lose precision (or underflow) in fp16.
        mse = torch.mean((sr_y_tensor.float() - hr_y_tensor.float()) ** 2)
        total_psnr += 10. * torch.log10(1. / mse).item()

        # Save image
        sr_y_image = imgproc.tensor2image(sr_y_tensor, range_norm=False, half=True)
        sr_y_image = sr_y_image.astype(np.float32) / 255.0
        sr_ycbcr_image = cv2.merge([sr_y_image, hr_cb_image, hr_cr_image])
        sr_image = imgproc.ycbcr2bgr(sr_ycbcr_image)
        if not cv2.imwrite(sr_image_path, sr_image * 255.0):
            print(f"Warning: failed to write `{os.path.abspath(sr_image_path)}`.")

    print(f"PSNR: {total_psnr / total_files:4.2f}dB.\n")


if __name__ == "__main__":
    main()
96 |
--------------------------------------------------------------------------------