├── .gitignore
├── LICENSE
├── README.md
├── VERSION
├── assets
│   ├── dataset.jpg
│   ├── real.jpg
│   └── synthetic.jpg
├── basicsr
│   ├── data
│   │   ├── __init__.py
│   │   ├── data_sampler.py
│   │   ├── data_util.py
│   │   ├── lap_dataset.py
│   │   ├── prefetch_dataloader.py
│   │   └── transforms.py
│   ├── metrics
│   │   ├── __init__.py
│   │   ├── gmsd.py
│   │   ├── metric_util.py
│   │   └── psnr_ssim.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── archs
│   │   │   ├── __init__.py
│   │   │   ├── jarnet_arch.py
│   │   │   └── nafnet_arch.py
│   │   ├── base_model.py
│   │   ├── image_restoration_model.py
│   │   ├── losses
│   │   │   ├── __init__.py
│   │   │   ├── loss_util.py
│   │   │   └── losses.py
│   │   └── lr_scheduler.py
│   ├── test.py
│   ├── train_jarnet.py
│   ├── train_others.py
│   └── utils
│       ├── __init__.py
│       ├── bundle_submissions.py
│       ├── create_lmdb.py
│       ├── dist_util.py
│       ├── download_util.py
│       ├── face_util.py
│       ├── file_client.py
│       ├── flow_util.py
│       ├── img_util.py
│       ├── lmdb_util.py
│       ├── logger.py
│       ├── matlab_functions.py
│       ├── misc.py
│       └── options.py
├── options
│   ├── test_JARNet_LAP.yml
│   ├── test_NAFNet_LAP.yml
│   ├── train_JARNet_LAP.yml
│   └── train_NAFNet_LAP.yml
├── requirements.txt
├── setup.cfg
└── setup.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # ignored folders
2 | datasets/*
3 | experiments/*
4 | results/*
5 | tb_logger/*
6 | basicsr/tb_logger/*
7 | wandb/*
8 | tmp/*
9 |
10 | docs/api
11 | scripts/__init__.py
12 |
13 | *.DS_Store
14 | .idea
15 |
16 | # ignored files
17 | version.py
18 |
19 | # ignored files with suffix
20 | *.html
21 | *.png
22 | *.jpeg
23 | *.jpg
24 | *.gif
25 | *.pth
26 | *.zip
27 | *.npy
28 | *.tif
29 |
30 | # Byte-compiled / optimized / DLL files
31 | __pycache__/
32 | *.py[cod]
33 | *$py.class
34 |
35 | # C extensions
36 | *.so
37 |
38 | # Distribution / packaging
39 | .Python
40 | build/
41 | develop-eggs/
42 | dist/
43 | downloads/
44 | eggs/
45 | .eggs/
46 | lib/
47 | lib64/
48 | parts/
49 | sdist/
50 | var/
51 | wheels/
52 | share/python-wheels/
53 | *.egg-info/
54 | .installed.cfg
55 | *.egg
56 | MANIFEST
57 |
58 | # PyInstaller
59 | # Usually these files are written by a python script from a template
60 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
61 | *.manifest
62 | *.spec
63 |
64 | # Installer logs
65 | pip-log.txt
66 | pip-delete-this-directory.txt
67 |
68 | # Unit test / coverage reports
69 | htmlcov/
70 | .tox/
71 | .nox/
72 | .coverage
73 | .coverage.*
74 | .cache
75 | nosetests.xml
76 | coverage.xml
77 | *.cover
78 | *.py,cover
79 | .hypothesis/
80 | .pytest_cache/
81 | cover/
82 |
83 | # Translations
84 | *.mo
85 | *.pot
86 |
87 | # Django stuff:
88 | *.log
89 | local_settings.py
90 | db.sqlite3
91 | db.sqlite3-journal
92 |
93 | # Flask stuff:
94 | instance/
95 | .webassets-cache
96 |
97 | # Scrapy stuff:
98 | .scrapy
99 |
100 | # Sphinx documentation
101 | docs/_build/
102 |
103 | # PyBuilder
104 | .pybuilder/
105 | target/
106 |
107 | # Jupyter Notebook
108 | .ipynb_checkpoints
109 |
110 | # IPython
111 | profile_default/
112 | ipython_config.py
113 |
114 | # pyenv
115 | # For a library or package, you might want to ignore these files since the code is
116 | # intended to run in multiple environments; otherwise, check them in:
117 | # .python-version
118 |
119 | # pipenv
120 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
121 | # However, in case of collaboration, if having platform-specific dependencies or dependencies
122 | # having no cross-platform support, pipenv may install dependencies that don't work, or not
123 | # install all needed dependencies.
124 | #Pipfile.lock
125 |
126 | # poetry
127 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
128 | # This is especially recommended for binary packages to ensure reproducibility, and is more
129 | # commonly ignored for libraries.
130 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
131 | #poetry.lock
132 |
133 | # pdm
134 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
135 | #pdm.lock
136 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
137 | # in version control.
138 | # https://pdm.fming.dev/#use-with-ide
139 | .pdm.toml
140 |
141 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
142 | __pypackages__/
143 |
144 | # Celery stuff
145 | celerybeat-schedule
146 | celerybeat.pid
147 |
148 | # SageMath parsed files
149 | *.sage.py
150 |
151 | # Environments
152 | .env
153 | .venv
154 | env/
155 | venv/
156 | ENV/
157 | env.bak/
158 | venv.bak/
159 |
160 | # Spyder project settings
161 | .spyderproject
162 | .spyproject
163 |
164 | # Rope project settings
165 | .ropeproject
166 |
167 | # mkdocs documentation
168 | /site
169 |
170 | # mypy
171 | .mypy_cache/
172 | .dmypy.json
173 | dmypy.json
174 |
175 | # Pyre type checker
176 | .pyre/
177 |
178 | # pytype static type analyzer
179 | .pytype/
180 |
181 | # Cython debug symbols
182 | cython_debug/
183 |
184 | # PyCharm
185 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
186 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
187 | # and can be added to the global gitignore or merged into this file. For a more nuclear
188 | # option (not recommended) you can uncomment the following to ignore the entire idea folder.
189 | #.idea/
190 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Deep Linear Array Pushbroom Image Restoration: A Degradation Pipeline and Jitter-Aware Restoration Network
2 | Zida Chen*, Ziran Zhang*, Haoying Li, Menghao Li, Yueting Chen, Qi Li, Huajun Feng, Zhihai Xu, Shiqi Chen
3 |
4 | State Key Laboratory of Extreme Photonics and Instrumentation, Zhejiang University
5 |
6 | ## News
7 | **2024.01.17** Our train/test code, LAP dataset and pre-trained model are available now.
8 |
9 | **2023.12.09** Our paper is accepted by **AAAI 2024**! [Paper Link](https://arxiv.org/abs/2401.08171)
10 |
11 | ## Get Started
12 | ### Data Preparation
13 |
14 | You can download our synthetic LAP [training](https://drive.google.com/file/d/1n8PHoG7RYrp3uvDU-5DXF4FFl_KHqZt1/view?usp=sharing) and [evaluation](https://drive.google.com/file/d/18hLBPKpbpSBRUpnRcM5zzZySyRurMk81/view?usp=sharing) datasets.
15 |
16 | Please unzip them and configure the dataset path in your config file (e.g. options/test_JARNet_LAP.yml).
17 |
18 | Note: You need to reserve at least **200GB** of disk space for the data.
19 |
20 |
21 | ### Environment
22 |
23 | - Install the packages in your environment (Python >= 3.7):
24 | ```
25 | pip install -r requirements.txt
26 | ```
27 |
28 | - Set up the [BasicSR](https://github.com/XPixelGroup/BasicSR) environment by running:
29 | ```
30 | python setup.py develop --no_cuda_ext
31 | ```
32 |
33 | - To choose the GPU id, modify the following line in the training/evaluation script:
34 | ```
35 | os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # replace "0" with your GPU id
36 | ```
37 |
38 | ### Train
39 |
40 | You can train the JARNet by using:
41 | ```
42 | python basicsr/train_jarnet.py -opt options/train_JARNet_LAP.yml
43 | ```
44 |
45 | You can train other restoration models (e.g. NAFNet) by using:
46 | ```
47 | python basicsr/train_others.py -opt options/train_NAFNet_LAP.yml
48 | ```
49 |
50 | You should fill in the dataset path in the training config (i.e. YAML) file.
51 |
52 | ### Evaluation
53 |
54 | You can evaluate the JARNet by using:
55 | ```
56 | python basicsr/test.py -opt options/test_JARNet_LAP.yml
57 | ```
58 |
59 | You should fill in the dataset and pre-trained model paths in the test config (i.e. YAML) file.
60 |
61 | You can run JARNet inference with the [pre-trained model](https://drive.google.com/file/d/1IjFgrXD4YBmRaWPX-WukgvzRdJHrZqZ6/view?usp=sharing) on our LAP evaluation dataset.
62 |
63 | ## Visualization Comparisons
64 | ### Samples of Synthetic LAP Dataset
65 | ![Samples of the synthetic LAP dataset](assets/dataset.jpg)
66 |
67 | ### Synthetic LAP Image Restoration Results
68 | ![Synthetic LAP image restoration results](assets/synthetic.jpg)
69 |
70 | ### Real-World LAP Image Restoration Results
71 | ![Real-world LAP image restoration results](assets/real.jpg)
72 |
73 | ## Citation
74 |
75 | If this repo helps you, please consider citing our work:
76 |
77 | ```
78 | @inproceedings{chen2024deep,
79 | title={Deep Linear Array Pushbroom Image Restoration: A Degradation Pipeline and Jitter-Aware Restoration Network},
80 | author={Zida Chen and Ziran Zhang and Haoying Li and Menghao Li and Yueting Chen and Qi Li and Huajun Feng and Zhihai Xu and Shiqi Chen},
81 | booktitle={AAAI},
82 | year={2024}
83 | }
84 | ```
85 |
86 | ## Contact
87 | If you have any questions, please contact zd_chen@zju.edu.cn.
88 |
89 |
90 | ## Acknowledgment
91 | Our code is based on the [BasicSR](https://github.com/XPixelGroup/BasicSR) toolbox.
--------------------------------------------------------------------------------
/VERSION:
--------------------------------------------------------------------------------
1 | 1.2.0
2 |
--------------------------------------------------------------------------------
/assets/dataset.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JHW2000/JARNet/839c8986ead2951114de03598949af45f904d753/assets/dataset.jpg
--------------------------------------------------------------------------------
/assets/real.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JHW2000/JARNet/839c8986ead2951114de03598949af45f904d753/assets/real.jpg
--------------------------------------------------------------------------------
/assets/synthetic.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JHW2000/JARNet/839c8986ead2951114de03598949af45f904d753/assets/synthetic.jpg
--------------------------------------------------------------------------------
/basicsr/data/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | import numpy as np
3 | import random
4 | import torch
5 | import torch.utils.data
6 | from functools import partial
7 | from os import path as osp
8 |
9 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader
10 | from basicsr.utils import get_root_logger, scandir
11 | from basicsr.utils.dist_util import get_dist_info
12 |
13 | __all__ = ['create_dataset', 'create_dataloader']
14 |
15 | # automatically scan and import dataset modules
16 | # scan all the files under the data folder with '_dataset' in file names
17 | data_folder = osp.dirname(osp.abspath(__file__))
18 | dataset_filenames = [
19 | osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
20 | if v.endswith('_dataset.py')
21 | ]
22 | # import all the dataset modules
23 | _dataset_modules = [
24 | importlib.import_module(f'basicsr.data.{file_name}')
25 | for file_name in dataset_filenames
26 | ]
27 |
28 |
29 | def create_dataset(dataset_opt):
30 | """Create dataset.
31 |
32 | Args:
33 |         dataset_opt (dict): Configuration for dataset. It contains:
34 | name (str): Dataset name.
35 | type (str): Dataset type.
36 | """
37 | dataset_type = dataset_opt['type']
38 |
39 | # dynamic instantiation
40 | for module in _dataset_modules:
41 | dataset_cls = getattr(module, dataset_type, None)
42 | if dataset_cls is not None:
43 | break
44 | if dataset_cls is None:
45 | raise ValueError(f'Dataset {dataset_type} is not found.')
46 |     print("cls:", dataset_opt)
47 |
48 | dataset = dataset_cls(dataset_opt)
49 |
50 | logger = get_root_logger()
51 | logger.info(
52 | f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
53 | 'is created.')
54 | return dataset
55 |
56 |
57 | def create_dataloader(dataset,
58 | dataset_opt,
59 | num_gpu=1,
60 | dist=False,
61 | sampler=None,
62 | seed=None):
63 | """Create dataloader.
64 |
65 | Args:
66 | dataset (torch.utils.data.Dataset): Dataset.
67 | dataset_opt (dict): Dataset options. It contains the following keys:
68 | phase (str): 'train' or 'val'.
69 | num_worker_per_gpu (int): Number of workers for each GPU.
70 | batch_size_per_gpu (int): Training batch size for each GPU.
71 | num_gpu (int): Number of GPUs. Used only in the train phase.
72 | Default: 1.
73 | dist (bool): Whether in distributed training. Used only in the train
74 | phase. Default: False.
75 | sampler (torch.utils.data.sampler): Data sampler. Default: None.
76 | seed (int | None): Seed. Default: None
77 | """
78 | phase = dataset_opt['phase']
79 | rank, _ = get_dist_info()
80 | if phase == 'train':
81 | if dist: # distributed training
82 | batch_size = dataset_opt['batch_size_per_gpu']
83 | num_workers = dataset_opt['num_worker_per_gpu']
84 | else: # non-distributed training
85 | multiplier = 1 if num_gpu == 0 else num_gpu
86 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier
87 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier
88 | dataloader_args = dict(
89 | dataset=dataset,
90 | batch_size=batch_size,
91 | shuffle=False,
92 | num_workers=num_workers,
93 | sampler=sampler,
94 | drop_last=True)
95 | if sampler is None:
96 | dataloader_args['shuffle'] = True
97 | dataloader_args['worker_init_fn'] = partial(
98 | worker_init_fn, num_workers=num_workers, rank=rank,
99 | seed=seed) if seed is not None else None
100 | elif phase in ['val', 'test']: # validation
101 | dataloader_args = dict(
102 | dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
103 | else:
104 | raise ValueError(f'Wrong dataset phase: {phase}. '
105 | "Supported ones are 'train', 'val' and 'test'.")
106 |
107 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False)
108 |
109 | prefetch_mode = dataset_opt.get('prefetch_mode')
110 | if prefetch_mode == 'cpu': # CPUPrefetcher
111 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1)
112 | logger = get_root_logger()
113 | logger.info(f'Use {prefetch_mode} prefetch dataloader: '
114 | f'num_prefetch_queue = {num_prefetch_queue}')
115 | return PrefetchDataLoader(
116 | num_prefetch_queue=num_prefetch_queue, **dataloader_args)
117 | else:
118 | # prefetch_mode=None: Normal dataloader
119 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher
120 | return torch.utils.data.DataLoader(**dataloader_args)
121 |
122 |
123 | def worker_init_fn(worker_id, num_workers, rank, seed):
124 | # Set the worker seed to num_workers * rank + worker_id + seed
125 | worker_seed = num_workers * rank + worker_id + seed
126 | np.random.seed(worker_seed)
127 | random.seed(worker_seed)
128 |
--------------------------------------------------------------------------------
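A minimal usage sketch of the two factory functions above, assuming the LAP dataset layout; the path and option values are placeholders, not shipped defaults:
```
from basicsr.data import create_dataset, create_dataloader

# Placeholder options; keys follow the docstrings above and the LAPDataSet class.
dataset_opt = {
    'name': 'TrainSet',                      # LAPDataSet crops patches only for 'TrainSet'
    'type': 'LAPDataSet',                    # resolved by the '_dataset.py' module scan above
    'dataroot_gt': './datasets/LAP/train',   # placeholder path
    'phase': 'train',
    'scale': 1,
    'gt_size': 256,
    'batch_size_per_gpu': 8,
    'num_worker_per_gpu': 4,
}

train_set = create_dataset(dataset_opt)
train_loader = create_dataloader(train_set, dataset_opt, num_gpu=1, dist=False,
                                 sampler=None, seed=10)
```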
/basicsr/data/data_sampler.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch.utils.data.sampler import Sampler
4 |
5 |
6 | class EnlargedSampler(Sampler):
7 | """Sampler that restricts data loading to a subset of the dataset.
8 |
9 | Modified from torch.utils.data.distributed.DistributedSampler
10 |     Support enlarging the dataset for iteration-based training, to save
11 |     time when restarting the dataloader after each epoch.
12 |
13 | Args:
14 | dataset (torch.utils.data.Dataset): Dataset used for sampling.
15 | num_replicas (int | None): Number of processes participating in
16 | the training. It is usually the world_size.
17 | rank (int | None): Rank of the current process within num_replicas.
18 | ratio (int): Enlarging ratio. Default: 1.
19 | """
20 |
21 | def __init__(self, dataset, num_replicas, rank, ratio=1):
22 | self.dataset = dataset
23 | self.num_replicas = num_replicas
24 | self.rank = rank
25 | self.epoch = 0
26 | self.num_samples = math.ceil(
27 | len(self.dataset) * ratio / self.num_replicas)
28 | self.total_size = self.num_samples * self.num_replicas
29 |
30 | def __iter__(self):
31 | # deterministically shuffle based on epoch
32 | g = torch.Generator()
33 | g.manual_seed(self.epoch)
34 | indices = torch.randperm(self.total_size, generator=g).tolist()
35 |
36 | dataset_size = len(self.dataset)
37 | indices = [v % dataset_size for v in indices]
38 |
39 | # subsample
40 | indices = indices[self.rank:self.total_size:self.num_replicas]
41 | assert len(indices) == self.num_samples
42 |
43 | return iter(indices)
44 |
45 | def __len__(self):
46 | return self.num_samples
47 |
48 | def set_epoch(self, epoch):
49 | self.epoch = epoch
50 |
--------------------------------------------------------------------------------
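A short sketch of how this sampler is driven during iteration-based training; `train_set`, `dataset_opt` and `total_epochs` are placeholders carried over from the sketch above:
```
from basicsr.data import create_dataloader
from basicsr.data.data_sampler import EnlargedSampler

# ratio=100 virtually repeats the dataset 100x per epoch, so the dataloader
# is restarted far less often during iteration-based training.
sampler = EnlargedSampler(train_set, num_replicas=1, rank=0, ratio=100)
train_loader = create_dataloader(train_set, dataset_opt, sampler=sampler, seed=10)

for epoch in range(total_epochs):
    sampler.set_epoch(epoch)   # reseeds the deterministic shuffle
    for batch in train_loader:
        ...                    # training step
```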
/basicsr/data/lap_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from pathlib import Path
4 | from torch.utils import data as data
5 | from basicsr.data.transforms import paired_random_crop, four_random_crop
6 | from basicsr.utils import img2tensor
7 | import glob
8 | import os
9 | import torchvision.transforms as transforms
10 | import torch.nn.functional as F
11 |
12 | def unsqueeze_twice(x):
13 | return x.unsqueeze(0).unsqueeze(0)
14 |
15 | def warp(img,jit):
16 | jit = torch.from_numpy(-jit)
17 |     # img = torch.from_numpy(img).float()  # fails: torch cannot convert uint16 arrays
18 |     img = torch.from_numpy(img.astype(np.int32)).float()  # np.int32 or np.float32 (not np.int16) avoids val PSNR errors
19 | h, w = img.shape
20 | grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w))
21 | grid = torch.stack((grid_x, grid_y), 2).type_as(img)
22 | grid.requires_grad = False
23 |
24 | grid_flow = grid + jit
25 | grid_flow = grid_flow.unsqueeze(0)
26 | grid_flow = grid_flow[:, :h, :w, :]
27 | grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0
28 | grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0
29 | grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim = 3)
30 |
31 | img_tensor = unsqueeze_twice(img)
32 |     img_subdivision = F.grid_sample(img_tensor, grid_flow,
33 |                 mode='bilinear', padding_mode='reflection', align_corners=True)  # mode='nearest' is an alternative
34 | img = np.array(img_subdivision).astype(int)[0,0,:,:]
35 | return img
36 |
37 | class LAPDataSet(data.Dataset):
38 | def __init__(self, opt):
39 | super(LAPDataSet, self).__init__()
40 | self.opt = opt
41 | print(opt)
42 |         self.path = sorted(glob.glob(str(Path(opt['dataroot_gt']) / '*.*')))
43 |         self.transform = transforms.Compose([transforms.ToTensor()])
44 |
45 | def __getitem__(self, index):
46 | index = int(index)
47 | data_dict = np.load(self.path[index % len(self.path)], allow_pickle=True )
48 | data_dict = data_dict.item()
49 |
50 | img = data_dict["img_LAP"]
51 | gt = data_dict["img_gt"]
52 | flow = data_dict["jit_information_noise"]
53 | flow_gt = data_dict["jit_information"]
54 |
55 | img = np.expand_dims(img, axis=2)
56 | gt = np.expand_dims(gt, axis=2)
57 |
58 | # normalization
59 | img = img.astype(np.float32) / 255. / 255.
60 | gt = gt.astype(np.float32) / 255. / 255.
61 |
62 | if self.opt['name'] == 'TrainSet':
63 | scale = self.opt['scale']
64 | gt_size = self.opt['gt_size']
65 | # gt, img = paired_random_crop(gt, img, gt_size, scale, None)
66 | gt, img, flow, flow_gt = four_random_crop(gt, img, flow, flow_gt, gt_size, scale, None)
67 |
68 | img = img2tensor(img)
69 | gt = img2tensor(gt)
70 | flow = img2tensor(flow)
71 | flow_gt = img2tensor(flow_gt)
72 |
73 |         return {'lq': img, 'gt': gt, 'flow': flow, 'flow_gt': flow_gt, 'lq_path': self.path[index % len(self.path)]}
74 |
75 | def __len__(self):
76 | return len(self.path)
77 |
78 |
79 | class LAPDataSetNoWarp(data.Dataset):
80 | def __init__(self, opt):
81 | super(LAPDataSetNoWarp, self).__init__()
82 | self.opt = opt
83 | print(opt)
84 |         self.path = sorted(glob.glob(str(Path(opt['dataroot_gt']) / '*.*')))
85 |         self.transform = transforms.Compose([transforms.ToTensor()])
86 |
87 | def __getitem__(self, index):
88 | index = int(index)
89 | data_dict = np.load(self.path[index % len(self.path)], allow_pickle=True )
90 | data_dict = data_dict.item()
91 |
92 | img = data_dict["img_LAP"]
93 | gt = data_dict["img_gt"]
94 |
95 | img = np.expand_dims(img, axis=2)
96 | gt = np.expand_dims(gt, axis=2)
97 |
98 | # normalization
99 | img = img.astype(np.float32) / 255. / 255.
100 | gt = gt.astype(np.float32) / 255. / 255.
101 |
102 | if self.opt['name'] == 'TrainSet':
103 | scale = self.opt['scale']
104 | gt_size = self.opt['gt_size']
105 | gt, img = paired_random_crop(gt, img, gt_size, scale, None)
106 |
107 | img = img2tensor(img)
108 | gt = img2tensor(gt)
109 |
110 |         return {'lq': img, 'gt': gt, 'lq_path': self.path[index % len(self.path)]}
111 |
112 | def __len__(self):
113 | return len(self.path)
114 |
115 |
116 | class LAPDataSetWarp(data.Dataset):
117 | def __init__(self, opt):
118 | super(LAPDataSetWarp, self).__init__()
119 | self.opt = opt
120 | print(opt)
121 |         self.path = sorted(glob.glob(str(Path(opt['dataroot_gt']) / '*.*')))
122 |         self.transform = transforms.Compose([transforms.ToTensor()])
123 |
124 | def __getitem__(self, index):
125 | index = int(index)
126 | data_dict = np.load(self.path[index % len(self.path)], allow_pickle=True )
127 | data_dict = data_dict.item()
128 |
129 | img = data_dict["img_LAP"]
130 | gt = data_dict["img_gt"]
131 | flow = data_dict["jit_information_noise"]
132 | flow_gt = data_dict["jit_information"]
133 |
134 | img = warp(img,flow)
135 |
136 | img = np.expand_dims(img, axis=2)
137 | gt = np.expand_dims(gt, axis=2)
138 |
139 | # normalization
140 | img = img.astype(np.float32) / 255. / 255.
141 | gt = gt.astype(np.float32) / 255. / 255.
142 |
143 | if self.opt['name'] == 'TrainSet':
144 | scale = self.opt['scale']
145 | gt_size = self.opt['gt_size']
146 | # gt, img = paired_random_crop(gt, img, gt_size, scale, None)
147 | gt, img, flow, flow_gt = four_random_crop(gt, img, flow, flow_gt, gt_size, scale, None)
148 |
149 | img = img2tensor(img)
150 | gt = img2tensor(gt)
151 | flow = img2tensor(flow)
152 | flow_gt = img2tensor(flow_gt)
153 |
154 |         return {'lq': img, 'gt': gt, 'flow': flow, 'flow_gt': flow_gt, 'lq_path': self.path[index % len(self.path)]}
155 |
156 | def __len__(self):
157 | return len(self.path)
--------------------------------------------------------------------------------
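A self-contained sketch (synthetic data, shapes assumed from the code above) of the inverse warp that `LAPDataSetWarp` applies before cropping:
```
import numpy as np
from basicsr.data.lap_dataset import warp

h, w = 64, 64
img = (np.random.rand(h, w) * 65535).astype(np.uint16)  # 16-bit LAP frame
jit = np.zeros((h, w, 2), dtype=np.float32)
jit[..., 0] = 2.0  # constant 2-pixel x-jitter; jit[..., 1] is the y component

# warp() negates the flow and resamples with grid_sample, compensating the jitter
img_compensated = warp(img, jit)  # (h, w) int array
```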
/basicsr/data/prefetch_dataloader.py:
--------------------------------------------------------------------------------
1 | import queue as Queue
2 | import threading
3 | import torch
4 | from torch.utils.data import DataLoader
5 |
6 |
7 | class PrefetchGenerator(threading.Thread):
8 | """A general prefetch generator.
9 |
10 | Ref:
11 | https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12 |
13 | Args:
14 | generator: Python generator.
15 |         num_prefetch_queue (int): Length of the prefetch queue.
16 | """
17 |
18 | def __init__(self, generator, num_prefetch_queue):
19 | threading.Thread.__init__(self)
20 | self.queue = Queue.Queue(num_prefetch_queue)
21 | self.generator = generator
22 | self.daemon = True
23 | self.start()
24 |
25 | def run(self):
26 | for item in self.generator:
27 | self.queue.put(item)
28 | self.queue.put(None)
29 |
30 | def __next__(self):
31 | next_item = self.queue.get()
32 | if next_item is None:
33 | raise StopIteration
34 | return next_item
35 |
36 | def __iter__(self):
37 | return self
38 |
39 |
40 | class PrefetchDataLoader(DataLoader):
41 | """Prefetch version of dataloader.
42 |
43 | Ref:
44 | https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45 |
46 | TODO:
47 | Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48 | ddp.
49 |
50 | Args:
51 |         num_prefetch_queue (int): Length of the prefetch queue.
52 | kwargs (dict): Other arguments for dataloader.
53 | """
54 |
55 | def __init__(self, num_prefetch_queue, **kwargs):
56 | self.num_prefetch_queue = num_prefetch_queue
57 | super(PrefetchDataLoader, self).__init__(**kwargs)
58 |
59 | def __iter__(self):
60 | return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61 |
62 |
63 | class CPUPrefetcher():
64 | """CPU prefetcher.
65 |
66 | Args:
67 | loader: Dataloader.
68 | """
69 |
70 | def __init__(self, loader):
71 | self.ori_loader = loader
72 | self.loader = iter(loader)
73 |
74 | def next(self):
75 | try:
76 | return next(self.loader)
77 | except StopIteration:
78 | return None
79 |
80 | def reset(self):
81 | self.loader = iter(self.ori_loader)
82 |
83 |
84 | class CUDAPrefetcher():
85 | """CUDA prefetcher.
86 |
87 | Ref:
88 | https://github.com/NVIDIA/apex/issues/304#
89 |
90 |     It may consume more GPU memory.
91 |
92 | Args:
93 | loader: Dataloader.
94 | opt (dict): Options.
95 | """
96 |
97 | def __init__(self, loader, opt):
98 | self.ori_loader = loader
99 | self.loader = iter(loader)
100 | self.opt = opt
101 | self.stream = torch.cuda.Stream()
102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu')
103 | self.preload()
104 |
105 | def preload(self):
106 | try:
107 | self.batch = next(self.loader) # self.batch is a dict
108 | except StopIteration:
109 | self.batch = None
110 | return None
111 | # put tensors to gpu
112 | with torch.cuda.stream(self.stream):
113 | for k, v in self.batch.items():
114 | if torch.is_tensor(v):
115 | self.batch[k] = self.batch[k].to(
116 | device=self.device, non_blocking=True)
117 |
118 | def next(self):
119 | torch.cuda.current_stream().wait_stream(self.stream)
120 | batch = self.batch
121 | self.preload()
122 | return batch
123 |
124 | def reset(self):
125 | self.loader = iter(self.ori_loader)
126 | self.preload()
127 |
--------------------------------------------------------------------------------
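A runnable sketch of the prefetcher calling convention shared by `CPUPrefetcher` and `CUDAPrefetcher`: `next()` yields batches until the epoch is exhausted (returns None), and `reset()` rewinds. The toy loader stands in for the real training dataloader:
```
import torch
from torch.utils.data import DataLoader, TensorDataset
from basicsr.data.prefetch_dataloader import CPUPrefetcher

loader = DataLoader(TensorDataset(torch.arange(8).float()), batch_size=2)

prefetcher = CPUPrefetcher(loader)
batch = prefetcher.next()
while batch is not None:
    # consume the batch (for the real loaders, a dict with 'lq'/'gt' keys)
    batch = prefetcher.next()
prefetcher.reset()  # rewind for the next epoch
```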
/basicsr/data/transforms.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import random
3 | import numpy as np
4 |
5 | def mod_crop(img, scale):
6 | """Mod crop images, used during testing.
7 |
8 | Args:
9 | img (ndarray): Input image.
10 | scale (int): Scale factor.
11 |
12 | Returns:
13 | ndarray: Result image.
14 | """
15 | img = img.copy()
16 | if img.ndim in (2, 3):
17 | h, w = img.shape[0], img.shape[1]
18 | h_remainder, w_remainder = h % scale, w % scale
19 | img = img[:h - h_remainder, :w - w_remainder, ...]
20 | else:
21 | raise ValueError(f'Wrong img ndim: {img.ndim}.')
22 | return img
23 |
24 | def paired_random_crop(img_gts, img_lqs, lq_patch_size, scale, gt_path):
25 | """Paired random crop.
26 |
27 | It crops lists of lq and gt images with corresponding locations.
28 |
29 | Args:
30 | img_gts (list[ndarray] | ndarray): GT images. Note that all images
31 | should have the same shape. If the input is an ndarray, it will
32 | be transformed to a list containing itself.
33 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
34 | should have the same shape. If the input is an ndarray, it will
35 | be transformed to a list containing itself.
36 | lq_patch_size (int): LQ patch size.
37 | scale (int): Scale factor.
38 | gt_path (str): Path to ground-truth.
39 |
40 | Returns:
41 | list[ndarray] | ndarray: GT images and LQ images. If returned results
42 | only have one element, just return ndarray.
43 | """
44 |
45 | if not isinstance(img_gts, list):
46 | img_gts = [img_gts]
47 | if not isinstance(img_lqs, list):
48 | img_lqs = [img_lqs]
49 |
50 | h_lq, w_lq, _ = img_lqs[0].shape
51 | h_gt, w_gt, _ = img_gts[0].shape
52 | gt_patch_size = int(lq_patch_size * scale)
53 |
54 | if h_gt != h_lq * scale or w_gt != w_lq * scale:
55 | raise ValueError(
56 | f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
57 | f'multiplication of LQ ({h_lq}, {w_lq}).')
58 | if h_lq < lq_patch_size or w_lq < lq_patch_size:
59 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
60 | f'({lq_patch_size}, {lq_patch_size}). '
61 | f'Please remove {gt_path}.')
62 |
63 | # randomly choose top and left coordinates for lq patch
64 | top = random.randint(0, h_lq - lq_patch_size)
65 | left = random.randint(0, w_lq - lq_patch_size)
66 |
67 | # crop lq patch
68 | img_lqs = [
69 | v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
70 | for v in img_lqs
71 | ]
72 |
73 | # crop corresponding gt patch
74 | top_gt, left_gt = int(top * scale), int(left * scale)
75 | img_gts = [
76 | v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
77 | for v in img_gts
78 | ]
79 | if len(img_gts) == 1:
80 | img_gts = img_gts[0]
81 | if len(img_lqs) == 1:
82 | img_lqs = img_lqs[0]
83 | return img_gts, img_lqs
84 |
85 | def four_random_crop(img_gts, img_lqs, flow, flow_gt, lq_patch_size, scale, gt_path):
86 |     """Random crop of four aligned inputs.
87 |
88 |     It crops lists of lq/gt images and their noisy/clean flows at the same locations.
89 |
90 | Args:
91 | img_gts (list[ndarray] | ndarray): GT images. Note that all images
92 | should have the same shape. If the input is an ndarray, it will
93 | be transformed to a list containing itself.
94 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images
95 | should have the same shape. If the input is an ndarray, it will
96 | be transformed to a list containing itself.
97 | lq_patch_size (int): LQ patch size.
98 | scale (int): Scale factor.
99 | gt_path (str): Path to ground-truth.
100 |
101 | Returns:
102 | list[ndarray] | ndarray: GT images and LQ images. If returned results
103 | only have one element, just return ndarray.
104 | """
105 |
106 | if not isinstance(img_gts, list):
107 | img_gts = [img_gts]
108 | if not isinstance(img_lqs, list):
109 | img_lqs = [img_lqs]
110 |     if not isinstance(flow, list):
111 |         flow = [flow]
112 |     if not isinstance(flow_gt, list):
113 |         flow_gt = [flow_gt]
114 |
115 | h_lq, w_lq, _ = img_lqs[0].shape
116 | h_gt, w_gt, _ = img_gts[0].shape
117 | gt_patch_size = int(lq_patch_size * scale)
118 |
119 | if h_gt != h_lq * scale or w_gt != w_lq * scale:
120 | raise ValueError(
121 | f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
122 | f'multiplication of LQ ({h_lq}, {w_lq}).')
123 | if h_lq < lq_patch_size or w_lq < lq_patch_size:
124 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
125 | f'({lq_patch_size}, {lq_patch_size}). '
126 | f'Please remove {gt_path}.')
127 |
128 | # randomly choose top and left coordinates for lq patch
129 | top = random.randint(0, h_lq - lq_patch_size)
130 | left = random.randint(0, w_lq - lq_patch_size)
131 |
132 | # crop lq patch
133 | img_lqs = [
134 | v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
135 | for v in img_lqs
136 | ]
137 |
138 | flow = [
139 | v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
140 | for v in flow
141 | ]
142 |
143 | # crop corresponding gt patch
144 | top_gt, left_gt = int(top * scale), int(left * scale)
145 | img_gts = [
146 | v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
147 | for v in img_gts
148 | ]
149 |
150 | flow_gt = [
151 | v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
152 | for v in flow_gt
153 | ]
154 | if len(img_gts) == 1:
155 | img_gts = img_gts[0]
156 | if len(img_lqs) == 1:
157 | img_lqs = img_lqs[0]
158 |
159 | if len(flow) == 1:
160 | flow = flow[0]
161 | if len(flow_gt) == 1:
162 | flow_gt = flow_gt[0]
163 | return img_gts, img_lqs, flow, flow_gt
164 |
165 | def paired_random_crop_DP(img_lqLs, img_lqRs, img_gts, gt_patch_size, scale, gt_path):
166 | if not isinstance(img_gts, list):
167 | img_gts = [img_gts]
168 | if not isinstance(img_lqLs, list):
169 | img_lqLs = [img_lqLs]
170 | if not isinstance(img_lqRs, list):
171 | img_lqRs = [img_lqRs]
172 |
173 | h_lq, w_lq, _ = img_lqLs[0].shape
174 | h_gt, w_gt, _ = img_gts[0].shape
175 | lq_patch_size = gt_patch_size // scale
176 |
177 | if h_gt != h_lq * scale or w_gt != w_lq * scale:
178 | raise ValueError(
179 | f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
180 | f'multiplication of LQ ({h_lq}, {w_lq}).')
181 | if h_lq < lq_patch_size or w_lq < lq_patch_size:
182 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
183 | f'({lq_patch_size}, {lq_patch_size}). '
184 | f'Please remove {gt_path}.')
185 |
186 | # randomly choose top and left coordinates for lq patch
187 | top = random.randint(0, h_lq - lq_patch_size)
188 | left = random.randint(0, w_lq - lq_patch_size)
189 |
190 | # crop lq patch
191 | img_lqLs = [
192 | v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
193 | for v in img_lqLs
194 | ]
195 |
196 | img_lqRs = [
197 | v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
198 | for v in img_lqRs
199 | ]
200 |
201 | # crop corresponding gt patch
202 | top_gt, left_gt = int(top * scale), int(left * scale)
203 | img_gts = [
204 | v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
205 | for v in img_gts
206 | ]
207 | if len(img_gts) == 1:
208 | img_gts = img_gts[0]
209 | if len(img_lqLs) == 1:
210 | img_lqLs = img_lqLs[0]
211 | if len(img_lqRs) == 1:
212 | img_lqRs = img_lqRs[0]
213 | return img_lqLs, img_lqRs, img_gts
214 |
215 |
216 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
217 | """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
218 |
219 | We use vertical flip and transpose for rotation implementation.
220 | All the images in the list use the same augmentation.
221 |
222 | Args:
223 | imgs (list[ndarray] | ndarray): Images to be augmented. If the input
224 | is an ndarray, it will be transformed to a list.
225 | hflip (bool): Horizontal flip. Default: True.
226 |         rotation (bool): Rotation. Default: True.
227 |         flows (list[ndarray]): Flows to be augmented. If the input is an
228 | ndarray, it will be transformed to a list.
229 | Dimension is (h, w, 2). Default: None.
230 | return_status (bool): Return the status of flip and rotation.
231 | Default: False.
232 |
233 | Returns:
234 | list[ndarray] | ndarray: Augmented images and flows. If returned
235 | results only have one element, just return ndarray.
236 |
237 | """
238 | hflip = hflip and random.random() < 0.5
239 | vflip = rotation and random.random() < 0.5
240 | rot90 = rotation and random.random() < 0.5
241 |
242 | def _augment(img):
243 | if hflip: # horizontal
244 | cv2.flip(img, 1, img)
245 | if vflip: # vertical
246 | cv2.flip(img, 0, img)
247 | if rot90:
248 | img = img.transpose(1, 0, 2)
249 | return img
250 |
251 | def _augment_flow(flow):
252 | if hflip: # horizontal
253 | cv2.flip(flow, 1, flow)
254 | flow[:, :, 0] *= -1
255 | if vflip: # vertical
256 | cv2.flip(flow, 0, flow)
257 | flow[:, :, 1] *= -1
258 | if rot90:
259 | flow = flow.transpose(1, 0, 2)
260 | flow = flow[:, :, [1, 0]]
261 | return flow
262 |
263 | if not isinstance(imgs, list):
264 | imgs = [imgs]
265 | imgs = [_augment(img) for img in imgs]
266 | if len(imgs) == 1:
267 | imgs = imgs[0]
268 |
269 | if flows is not None:
270 | if not isinstance(flows, list):
271 | flows = [flows]
272 | flows = [_augment_flow(flow) for flow in flows]
273 | if len(flows) == 1:
274 | flows = flows[0]
275 | return imgs, flows
276 | else:
277 | if return_status:
278 | return imgs, (hflip, vflip, rot90)
279 | else:
280 | return imgs
281 |
282 |
283 | def img_rotate(img, angle, center=None, scale=1.0):
284 | """Rotate image.
285 |
286 | Args:
287 | img (ndarray): Image to be rotated.
288 | angle (float): Rotation angle in degrees. Positive values mean
289 | counter-clockwise rotation.
290 | center (tuple[int]): Rotation center. If the center is None,
291 | initialize it as the center of the image. Default: None.
292 | scale (float): Isotropic scale factor. Default: 1.0.
293 | """
294 | (h, w) = img.shape[:2]
295 |
296 | if center is None:
297 | center = (w // 2, h // 2)
298 |
299 | matrix = cv2.getRotationMatrix2D(center, angle, scale)
300 | rotated_img = cv2.warpAffine(img, matrix, (w, h))
301 | return rotated_img
302 |
303 | def data_augmentation(image, mode):
304 | """
305 | Performs data augmentation of the input image
306 | Input:
307 | image: a cv2 (OpenCV) image
308 | mode: int. Choice of transformation to apply to the image
309 | 0 - no transformation
310 | 1 - flip up and down
311 |         2 - rotate 90 degrees counterclockwise
312 |         3 - rotate 90 degrees and flip up and down
313 |         4 - rotate 180 degrees
314 |         5 - rotate 180 degrees and flip
315 |         6 - rotate 270 degrees
316 |         7 - rotate 270 degrees and flip
317 | """
318 | if mode == 0:
319 | # original
320 | out = image
321 | elif mode == 1:
322 | # flip up and down
323 | out = np.flipud(image)
324 | elif mode == 2:
325 |         # rotate 90 degrees counterclockwise
326 | out = np.rot90(image)
327 | elif mode == 3:
328 | # rotate 90 degree and flip up and down
329 | out = np.rot90(image)
330 | out = np.flipud(out)
331 | elif mode == 4:
332 | # rotate 180 degree
333 | out = np.rot90(image, k=2)
334 | elif mode == 5:
335 | # rotate 180 degree and flip
336 | out = np.rot90(image, k=2)
337 | out = np.flipud(out)
338 | elif mode == 6:
339 | # rotate 270 degree
340 | out = np.rot90(image, k=3)
341 | elif mode == 7:
342 | # rotate 270 degree and flip
343 | out = np.rot90(image, k=3)
344 | out = np.flipud(out)
345 | else:
346 | raise Exception('Invalid choice of image transformation')
347 |
348 | return out
349 |
350 | def random_augmentation(*args):
351 | out = []
352 | flag_aug = random.randint(0,7)
353 | for data in args:
354 | out.append(data_augmentation(data, flag_aug).copy())
355 | return out
356 |
--------------------------------------------------------------------------------
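A small sketch (dummy single-channel arrays) of how the cropping and augmentation helpers above compose:
```
import numpy as np
from basicsr.data.transforms import paired_random_crop, random_augmentation

lq = np.random.rand(32, 32, 1).astype(np.float32)
gt = np.random.rand(32, 32, 1).astype(np.float32)

# paired_random_crop keeps the LQ/GT patches aligned (scale=1 for LAP data)
gt_patch, lq_patch = paired_random_crop(gt, lq, 16, 1, None)

# random_augmentation applies the same random flip/rotation mode to every input
gt_aug, lq_aug = random_augmentation(gt_patch, lq_patch)
```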
/basicsr/metrics/__init__.py:
--------------------------------------------------------------------------------
1 | from .psnr_ssim import calculate_psnr, calculate_ssim
2 | from .gmsd import calculate_gmsd
3 |
4 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_gmsd']
5 |
--------------------------------------------------------------------------------
/basicsr/metrics/gmsd.py:
--------------------------------------------------------------------------------
1 | r"""GMSD Metric
2 |
3 | Created by: https://github.com/dingkeyan93/IQA-optimization/blob/master/IQA_pytorch/GMSD.py
4 |
5 | Modified by: Jiadi Mo (https://github.com/JiadiMo)
6 |
7 | Refer to:
8 | Matlab code from https://www4.comp.polyu.edu.hk/~cslzhang/IQA/GMSD/GMSD.m;
9 |
10 | """
11 |
12 | import torch
13 | from torch import nn
14 | from torch.nn import functional as F
15 |
16 | # from pyiqa.utils.color_util import to_y_channel
17 | # from pyiqa.utils.registry import ARCH_REGISTRY
18 |
19 |
20 | def calculate_gmsd(
21 | x: torch.Tensor,
22 | y: torch.Tensor,
23 | T: int = 170,
24 | channels: int = 1, # 3
25 | test_y_channel: bool = False, # True
26 | ) -> float:
27 | r"""GMSD metric.
28 | Args:
29 | x: A distortion tensor. Shape :math:`(N, C, H, W)`.
30 | y: A reference tensor. Shape :math:`(N, C, H, W)`.
31 | T: A positive constant that supplies numerical stability.
32 | channels: Number of channels.
33 | test_y_channel: bool, whether to use y channel on ycbcr.
34 | """
35 | if test_y_channel:
36 | x = to_y_channel(x, 255)
37 | y = to_y_channel(y, 255)
38 | channels = 1
39 | else:
40 | x = x * 255.
41 | y = y * 255.
42 |
43 | dx = (torch.Tensor([[1, 0, -1], [1, 0, -1], [1, 0, -1]]) / 3.).unsqueeze(0).unsqueeze(0).repeat(channels, 1, 1,
44 | 1).to(x)
45 | dy = (torch.Tensor([[1, 1, 1], [0, 0, 0], [-1, -1, -1]]) / 3.).unsqueeze(0).unsqueeze(0).repeat(channels, 1, 1,
46 | 1).to(x)
47 | aveKernel = torch.ones(channels, 1, 2, 2).to(x) / 4.
48 |
49 | Y1 = F.conv2d(x, aveKernel, stride=2, padding=0, groups=channels)
50 | Y2 = F.conv2d(y, aveKernel, stride=2, padding=0, groups=channels)
51 |
52 | IxY1 = F.conv2d(Y1, dx, stride=1, padding=1, groups=channels)
53 | IyY1 = F.conv2d(Y1, dy, stride=1, padding=1, groups=channels)
54 | gradientMap1 = torch.sqrt(IxY1**2 + IyY1**2 + 1e-12)
55 |
56 | IxY2 = F.conv2d(Y2, dx, stride=1, padding=1, groups=channels)
57 | IyY2 = F.conv2d(Y2, dy, stride=1, padding=1, groups=channels)
58 | gradientMap2 = torch.sqrt(IxY2**2 + IyY2**2 + 1e-12)
59 |
60 | quality_map = (2 * gradientMap1 * gradientMap2 + T) / (gradientMap1**2 + gradientMap2**2 + T)
61 | score = torch.std(quality_map.view(quality_map.shape[0], -1), dim=1)
62 |
63 | return score.item()
64 |
65 |
66 | # @ARCH_REGISTRY.register()
67 | class GMSD(nn.Module):
68 | r'''Gradient Magnitude Similarity Deviation Metric.
69 | Args:
70 | channels: Number of channels.
71 | test_y_channel: bool, whether to use y channel on ycbcr.
72 | Reference:
73 | Xue, Wufeng, Lei Zhang, Xuanqin Mou, and Alan C. Bovik.
74 | "Gradient magnitude similarity deviation: A highly efficient
75 | perceptual image quality index." IEEE Transactions on Image
76 | Processing 23, no. 2 (2013): 684-695.
77 | '''
78 |
79 | def __init__(self, channels: int = 3, test_y_channel: bool = True) -> None:
80 | super(GMSD, self).__init__()
81 | self.channels = channels
82 | self.test_y_channel = test_y_channel
83 |
84 | def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
85 | r"""Args:
86 | x: A distortion tensor. Shape :math:`(N, C, H, W)`.
87 | y: A reference tensor. Shape :math:`(N, C, H, W)`.
88 | Order of input is important.
89 | """
90 | assert x.shape == y.shape, f'Input and reference images should have the same shape, but got {x.shape} and {y.shape}'
91 |         score = calculate_gmsd(x, y, channels=self.channels, test_y_channel=self.test_y_channel)
92 |
93 | return score
--------------------------------------------------------------------------------
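A quick sanity check for `calculate_gmsd` with random tensors: identical inputs score 0 (no gradient-magnitude deviation), and the score grows with distortion:
```
import torch
from basicsr.metrics import calculate_gmsd

ref = torch.rand(1, 1, 64, 64)
dist = (ref + 0.1 * torch.randn_like(ref)).clamp(0, 1)

print(calculate_gmsd(ref, ref))   # ~0.0
print(calculate_gmsd(dist, ref))  # > 0, larger for stronger distortion
```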
/basicsr/metrics/metric_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from basicsr.utils.matlab_functions import bgr2ycbcr
4 |
5 |
6 | def reorder_image(img, input_order='HWC'):
7 | """Reorder images to 'HWC' order.
8 |
9 | If the input_order is (h, w), return (h, w, 1);
10 | If the input_order is (c, h, w), return (h, w, c);
11 | If the input_order is (h, w, c), return as it is.
12 |
13 | Args:
14 | img (ndarray): Input image.
15 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
16 | If the input image shape is (h, w), input_order will not have
17 | effects. Default: 'HWC'.
18 |
19 | Returns:
20 | ndarray: reordered image.
21 | """
22 |
23 | if input_order not in ['HWC', 'CHW']:
24 | raise ValueError(
25 | f'Wrong input_order {input_order}. Supported input_orders are '
26 | "'HWC' and 'CHW'")
27 | if len(img.shape) == 2:
28 | img = img[..., None]
29 | if input_order == 'CHW':
30 | img = img.transpose(1, 2, 0)
31 | return img
32 |
33 |
34 | def to_y_channel(img):
35 | """Change to Y channel of YCbCr.
36 |
37 | Args:
38 | img (ndarray): Images with range [0, 255].
39 |
40 | Returns:
41 | (ndarray): Images with range [0, 255] (float type) without round.
42 | """
43 | img = img.astype(np.float32) / 255.
44 | if img.ndim == 3 and img.shape[2] == 3:
45 | img = bgr2ycbcr(img, y_only=True)
46 | img = img[..., None]
47 | return img * 255.
48 |
--------------------------------------------------------------------------------
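A tiny demonstration (dummy data) of the two helpers above:
```
import numpy as np
from basicsr.metrics.metric_util import reorder_image, to_y_channel

chw = np.random.rand(3, 8, 8)
hwc = reorder_image(chw, input_order='CHW')        # -> (8, 8, 3)

y = to_y_channel((hwc * 255.).astype(np.float32))  # -> (8, 8, 1), range [0, 255]
```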
/basicsr/metrics/psnr_ssim.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | from basicsr.metrics.metric_util import reorder_image, to_y_channel
5 | import skimage.metrics
6 | import torch
7 |
8 |
9 | def calculate_psnr(img1,
10 | img2,
11 | crop_border,
12 | input_order='HWC',
13 | test_y_channel=False):
14 | """Calculate PSNR (Peak Signal-to-Noise Ratio).
15 |
16 | Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
17 |
18 | Args:
19 | img1 (ndarray/tensor): Images with range [0, 255]/[0, 1].
20 | img2 (ndarray/tensor): Images with range [0, 255]/[0, 1].
21 | crop_border (int): Cropped pixels in each edge of an image. These
22 | pixels are not involved in the PSNR calculation.
23 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
24 | Default: 'HWC'.
25 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
26 |
27 | Returns:
28 | float: psnr result.
29 | """
30 |
31 | assert img1.shape == img2.shape, (
32 |         f'Image shapes are different: {img1.shape}, {img2.shape}.')
33 | if input_order not in ['HWC', 'CHW']:
34 | raise ValueError(
35 | f'Wrong input_order {input_order}. Supported input_orders are '
36 | '"HWC" and "CHW"')
37 |     if isinstance(img1, torch.Tensor):
38 |         if len(img1.shape) == 4:
39 |             img1 = img1.squeeze(0)
40 |         img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
41 |     if isinstance(img2, torch.Tensor):
42 |         if len(img2.shape) == 4:
43 |             img2 = img2.squeeze(0)
44 |         img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)
45 |
46 | img1 = reorder_image(img1, input_order=input_order)
47 | img2 = reorder_image(img2, input_order=input_order)
48 | img1 = img1.astype(np.float64)
49 | img2 = img2.astype(np.float64)
50 |
51 | if crop_border != 0:
52 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
53 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
54 |
55 | if test_y_channel:
56 | img1 = to_y_channel(img1)
57 | img2 = to_y_channel(img2)
58 |
59 | mse = np.mean((img1 - img2)**2)
60 | if mse == 0:
61 | return float('inf')
62 | max_value = 1. if img1.max() <= 1 else 255.
63 | return 20. * np.log10(max_value / np.sqrt(mse))
64 |
65 |
66 | def _ssim(img1, img2):
67 | """Calculate SSIM (structural similarity) for one channel images.
68 |
69 | It is called by func:`calculate_ssim`.
70 |
71 | Args:
72 | img1 (ndarray): Images with range [0, 255] with order 'HWC'.
73 | img2 (ndarray): Images with range [0, 255] with order 'HWC'.
74 |
75 | Returns:
76 | float: ssim result.
77 | """
78 |
79 | C1 = (0.01 * 255)**2
80 | C2 = (0.03 * 255)**2
81 |
82 | img1 = img1.astype(np.float64)
83 | img2 = img2.astype(np.float64)
84 | kernel = cv2.getGaussianKernel(11, 1.5)
85 | window = np.outer(kernel, kernel.transpose())
86 |
87 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
88 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
89 | mu1_sq = mu1**2
90 | mu2_sq = mu2**2
91 | mu1_mu2 = mu1 * mu2
92 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
93 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
94 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
95 |
96 | ssim_map = ((2 * mu1_mu2 + C1) *
97 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
98 | (sigma1_sq + sigma2_sq + C2))
99 | return ssim_map.mean()
100 |
101 | def prepare_for_ssim(img, k):
102 | import torch
103 | with torch.no_grad():
104 | img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
105 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k//2, padding_mode='reflect')
106 | conv.weight.requires_grad = False
107 | conv.weight[:, :, :, :] = 1. / (k * k)
108 |
109 | img = conv(img)
110 |
111 | img = img.squeeze(0).squeeze(0)
112 | img = img[0::k, 0::k]
113 | return img.detach().cpu().numpy()
114 |
115 | def prepare_for_ssim_rgb(img, k):
116 | import torch
117 | with torch.no_grad():
118 | img = torch.from_numpy(img).float() #HxWx3
119 |
120 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect')
121 | conv.weight.requires_grad = False
122 | conv.weight[:, :, :, :] = 1. / (k * k)
123 |
124 | new_img = []
125 |
126 | for i in range(3):
127 | new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k])
128 |
129 | return torch.stack(new_img, dim=2).detach().cpu().numpy()
130 |
131 | def _3d_gaussian_calculator(img, conv3d):
132 | out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
133 | return out
134 |
135 | def _generate_3d_gaussian_kernel():
136 | kernel = cv2.getGaussianKernel(11, 1.5)
137 | window = np.outer(kernel, kernel.transpose())
138 | kernel_3 = cv2.getGaussianKernel(11, 1.5)
139 | kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0))
140 | conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate')
141 | conv3d.weight.requires_grad = False
142 | conv3d.weight[0, 0, :, :, :] = kernel
143 | return conv3d
144 |
145 | def _ssim_3d(img1, img2, max_value):
146 |     """Calculate SSIM (structural similarity) with a 3D Gaussian window
147 |     over the full (H, W, C) volume. It is called by :func:`calculate_ssim`.
148 | 
149 |     Args:
150 |         img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
151 |         img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
152 | 
153 |     Returns:
154 |         float: ssim result.
155 |     """
156 |     assert len(img1.shape) == 3 and len(img2.shape) == 3
157 | 
158 |     C1 = (0.01 * max_value) ** 2
159 | C2 = (0.03 * max_value) ** 2
160 | img1 = img1.astype(np.float64)
161 | img2 = img2.astype(np.float64)
162 |
163 | kernel = _generate_3d_gaussian_kernel().cuda()
164 |
165 | img1 = torch.tensor(img1).float().cuda()
166 | img2 = torch.tensor(img2).float().cuda()
167 |
168 |
169 | mu1 = _3d_gaussian_calculator(img1, kernel)
170 | mu2 = _3d_gaussian_calculator(img2, kernel)
171 |
172 | mu1_sq = mu1 ** 2
173 | mu2_sq = mu2 ** 2
174 | mu1_mu2 = mu1 * mu2
175 | sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq
176 | sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq
177 | sigma12 = _3d_gaussian_calculator(img1*img2, kernel) - mu1_mu2
178 |
179 | ssim_map = ((2 * mu1_mu2 + C1) *
180 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
181 | (sigma1_sq + sigma2_sq + C2))
182 | return float(ssim_map.mean())
183 |
184 | def _ssim_cly(img1, img2):
185 |     """Calculate SSIM (structural similarity) for single-channel images
186 |     of shape (H, W), e.g. the Y channel of YCbCr.
187 | 
188 |     It is called by :func:`calculate_ssim`.
189 | 
190 |     Args:
191 |         img1 (ndarray): Images with range [0, 255] of shape (H, W).
192 |         img2 (ndarray): Images with range [0, 255] of shape (H, W).
193 | 
194 |     Returns:
195 |         float: ssim result.
196 |     """
197 |     assert len(img1.shape) == 2 and len(img2.shape) == 2
198 |     C1 = (0.01 * 255)**2
199 | C2 = (0.03 * 255)**2
200 | img1 = img1.astype(np.float64)
201 | img2 = img2.astype(np.float64)
202 |
203 | kernel = cv2.getGaussianKernel(11, 1.5)
204 | # print(kernel)
205 | window = np.outer(kernel, kernel.transpose())
206 |
207 | bt = cv2.BORDER_REPLICATE
208 |
209 | mu1 = cv2.filter2D(img1, -1, window, borderType=bt)
210 | mu2 = cv2.filter2D(img2, -1, window,borderType=bt)
211 |
212 | mu1_sq = mu1**2
213 | mu2_sq = mu2**2
214 | mu1_mu2 = mu1 * mu2
215 | sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq
216 | sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq
217 | sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2
218 |
219 | ssim_map = ((2 * mu1_mu2 + C1) *
220 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
221 | (sigma1_sq + sigma2_sq + C2))
222 | return ssim_map.mean()
223 |
224 |
225 | def calculate_ssim(img1,
226 | img2,
227 | crop_border,
228 | input_order='HWC',
229 | test_y_channel=False):
230 | """Calculate SSIM (structural similarity).
231 |
232 | Ref:
233 | Image quality assessment: From error visibility to structural similarity
234 |
235 |     The Y-channel path matches the official released MATLAB code in
236 |     https://ece.uwaterloo.ca/~z70wang/research/ssim/.
237 | 
238 |     For three-channel images, SSIM is computed with a 3D Gaussian window
239 |     over the full (H, W, C) volume (see `_ssim_3d`).
240 | 
241 | Args:
242 | img1 (ndarray): Images with range [0, 255].
243 | img2 (ndarray): Images with range [0, 255].
244 | crop_border (int): Cropped pixels in each edge of an image. These
245 | pixels are not involved in the SSIM calculation.
246 | input_order (str): Whether the input order is 'HWC' or 'CHW'.
247 | Default: 'HWC'.
248 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
249 |
250 | Returns:
251 | float: ssim result.
252 | """
253 |
254 |     assert img1.shape == img2.shape, (
255 |         f'Image shapes are different: {img1.shape}, {img2.shape}.')
256 | if input_order not in ['HWC', 'CHW']:
257 | raise ValueError(
258 | f'Wrong input_order {input_order}. Supported input_orders are '
259 | '"HWC" and "CHW"')
260 |
261 |     if isinstance(img1, torch.Tensor):
262 |         if len(img1.shape) == 4:
263 |             img1 = img1.squeeze(0)
264 |         img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
265 |     if isinstance(img2, torch.Tensor):
266 |         if len(img2.shape) == 4:
267 |             img2 = img2.squeeze(0)
268 |         img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)
269 |
270 | img1 = reorder_image(img1, input_order=input_order)
271 | img2 = reorder_image(img2, input_order=input_order)
272 |
273 | img1 = img1.astype(np.float64)
274 | img2 = img2.astype(np.float64)
275 |
276 | if crop_border != 0:
277 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
278 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
279 |
280 | if test_y_channel:
281 | img1 = to_y_channel(img1)
282 | img2 = to_y_channel(img2)
283 | return _ssim_cly(img1[..., 0], img2[..., 0])
284 |
285 |
286 |     ssims = []
287 |     max_value = 1 if img1.max() <= 1 else 255
288 |     with torch.no_grad():
289 |         final_ssim = _ssim_3d(img1, img2, max_value)
290 |         ssims.append(final_ssim)
291 | 
292 |     return np.array(ssims).mean()
293 | 
--------------------------------------------------------------------------------
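A minimal usage sketch for the two public metrics above. Note that the color path of `calculate_ssim` moves data to the GPU inside `_ssim_3d`, so this example sticks to the CPU-only Y-channel path; the noise level and crop are arbitrary:

```python
import numpy as np
from basicsr.metrics.psnr_ssim import calculate_psnr, calculate_ssim

gt = np.random.randint(0, 256, size=(64, 64, 3)).astype(np.float64)
noisy = np.clip(gt + np.random.normal(0, 5, gt.shape), 0, 255)

# range is detected automatically: max > 1 implies a [0, 255] signal
psnr = calculate_psnr(noisy, gt, crop_border=4)
# test_y_channel=True routes through _ssim_cly and stays on the CPU
ssim = calculate_ssim(noisy, gt, crop_border=4, test_y_channel=True)
print(f'PSNR: {psnr:.2f} dB, SSIM(Y): {ssim:.4f}')
```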
/basicsr/models/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from os import path as osp
3 |
4 | from basicsr.utils import get_root_logger, scandir
5 |
6 | # automatically scan and import model modules
7 | # scan all the files under the 'models' folder and collect files ending with
8 | # '_model.py'
9 | model_folder = osp.dirname(osp.abspath(__file__))
10 | model_filenames = [
11 | osp.splitext(osp.basename(v))[0] for v in scandir(model_folder)
12 | if v.endswith('_model.py')
13 | ]
14 | # import all the model modules
15 | _model_modules = [
16 | importlib.import_module(f'basicsr.models.{file_name}')
17 | for file_name in model_filenames
18 | ]
19 |
20 |
21 | def create_model(opt):
22 | """Create model.
23 |
24 | Args:
25 |         opt (dict): Configuration. It contains:
26 | model_type (str): Model type.
27 | """
28 | model_type = opt['model_type']
29 |
30 | # dynamic instantiation
31 | for module in _model_modules:
32 | model_cls = getattr(module, model_type, None)
33 | if model_cls is not None:
34 | break
35 | if model_cls is None:
36 | raise ValueError(f'Model {model_type} is not found.')
37 |
38 | model = model_cls(opt)
39 |
40 | logger = get_root_logger()
41 | logger.info(f'Model [{model.__class__.__name__}] is created.')
42 | return model
43 |
--------------------------------------------------------------------------------
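The lookup `create_model` performs is just a `getattr` scan over the auto-imported `*_model.py` modules. A self-contained sketch of the same mechanism with stand-in modules (all names here are invented for illustration):

```python
import types

# stand-ins for the auto-imported *_model.py modules
mod_a = types.ModuleType('mod_a')
mod_b = types.ModuleType('mod_b')

class PlainModel:
    """Pretend this class is defined inside mod_a."""
    def __init__(self, opt):
        self.opt = opt

mod_a.PlainModel = PlainModel

def create(modules, cls_name, opt):
    # same lookup create_model performs: first module exposing cls_name wins
    for m in modules:
        cls_ = getattr(m, cls_name, None)
        if cls_ is not None:
            return cls_(opt)
    raise ValueError(f'{cls_name} is not found.')

model = create([mod_b, mod_a], 'PlainModel', {'model_type': 'PlainModel'})
print(type(model).__name__)  # PlainModel
```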
/basicsr/models/archs/__init__.py:
--------------------------------------------------------------------------------
1 | import importlib
2 | from os import path as osp
3 |
4 | from basicsr.utils import scandir
5 |
6 | # automatically scan and import arch modules
7 | # scan all the files under the 'archs' folder and collect files ending with
8 | # '_arch.py'
9 | arch_folder = osp.dirname(osp.abspath(__file__))
10 | arch_filenames = [
11 | osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder)
12 | if v.endswith('_arch.py')
13 | ]
14 | # import all the arch modules
15 | _arch_modules = [
16 | importlib.import_module(f'basicsr.models.archs.{file_name}')
17 | for file_name in arch_filenames
18 | ]
19 |
20 |
21 | def dynamic_instantiation(modules, cls_type, opt):
22 | """Dynamically instantiate class.
23 |
24 | Args:
25 | modules (list[importlib modules]): List of modules from importlib
26 | files.
27 | cls_type (str): Class type.
28 | opt (dict): Class initialization kwargs.
29 |
30 | Returns:
31 | class: Instantiated class.
32 | """
33 |
34 | for module in modules:
35 | cls_ = getattr(module, cls_type, None)
36 | if cls_ is not None:
37 | break
38 | if cls_ is None:
39 | raise ValueError(f'{cls_type} is not found.')
40 | return cls_(**opt)
41 |
42 |
43 | def define_network(opt):
44 | network_type = opt.pop('type')
45 | net = dynamic_instantiation(_arch_modules, network_type, opt)
46 | return net
47 |
--------------------------------------------------------------------------------
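A usage sketch for `define_network`: the `'type'` key selects a class name from the scanned `*_arch.py` files (here `NAFNet`, defined below in `nafnet_arch.py`), and the remaining keys become constructor kwargs. The block counts are small arbitrary values, not the repo's training configuration:

```python
import torch
from basicsr.models.archs import define_network

net_opt = {
    'type': 'NAFNet',        # class name defined in nafnet_arch.py
    'img_channel': 1,
    'width': 16,
    'middle_blk_num': 1,
    'enc_blk_nums': [1, 1],
    'dec_blk_nums': [1, 1],
}
net = define_network(net_opt)   # note: define_network pops 'type' from the dict
with torch.no_grad():
    out = net(torch.randn(1, 1, 64, 64))
print(out.shape)                # torch.Size([1, 1, 64, 64])
```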
/basicsr/models/archs/nafnet_arch.py:
--------------------------------------------------------------------------------
1 | # ------------------------------------------------------------------------
2 | # Copyright (c) 2022 megvii-model. All Rights Reserved.
3 | # ------------------------------------------------------------------------
4 |
5 | '''
6 | Simple Baselines for Image Restoration
7 |
8 | @article{chen2022simple,
9 | title={Simple Baselines for Image Restoration},
10 | author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian},
11 | journal={arXiv preprint arXiv:2204.04676},
12 | year={2022}
13 | }
14 | '''
15 |
16 | import torch
17 | import torch.nn as nn
18 | import numpy as np
19 | import torch.nn.functional as F
20 | # from basicsr.models.archs.arch_util import LayerNorm2d
21 | # from basicsr.models.archs.local_arch import Local_Base
22 |
23 |
24 | class LayerNormFunction(torch.autograd.Function):
25 |
26 | @staticmethod
27 | def forward(ctx, x, weight, bias, eps):
28 | ctx.eps = eps
29 | N, C, H, W = x.size()
30 | mu = x.mean(1, keepdim=True)
31 | var = (x - mu).pow(2).mean(1, keepdim=True)
32 | y = (x - mu) / (var + eps).sqrt()
33 | ctx.save_for_backward(y, var, weight)
34 | y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
35 | return y
36 |
37 | @staticmethod
38 | def backward(ctx, grad_output):
39 | eps = ctx.eps
40 |
41 | N, C, H, W = grad_output.size()
42 |         y, var, weight = ctx.saved_tensors
43 | g = grad_output * weight.view(1, C, 1, 1)
44 | mean_g = g.mean(dim=1, keepdim=True)
45 |
46 | mean_gy = (g * y).mean(dim=1, keepdim=True)
47 | gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
48 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
49 | dim=0), None
50 |
51 | class LayerNorm2d(nn.Module):
52 |
53 | def __init__(self, channels, eps=1e-6):
54 | super(LayerNorm2d, self).__init__()
55 | self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
56 | self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
57 | self.eps = eps
58 |
59 | def forward(self, x):
60 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
61 |
62 |
63 |
64 | class AvgPool2d(nn.Module):
65 | def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None):
66 | super().__init__()
67 | self.kernel_size = kernel_size
68 | self.base_size = base_size
69 | self.auto_pad = auto_pad
70 |
71 | # only used for fast implementation
72 | self.fast_imp = fast_imp
73 | self.rs = [5, 4, 3, 2, 1]
74 | self.max_r1 = self.rs[0]
75 | self.max_r2 = self.rs[0]
76 | self.train_size = train_size
77 |
78 | def extra_repr(self):
79 |         return 'kernel_size={}, base_size={}, fast_imp={}'.format(
80 |             self.kernel_size, self.base_size, self.fast_imp)
81 |
82 | def forward(self, x):
83 | if self.kernel_size is None and self.base_size:
84 | train_size = self.train_size
85 | if isinstance(self.base_size, int):
86 | self.base_size = (self.base_size, self.base_size)
87 | self.kernel_size = list(self.base_size)
88 | self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2] # self.base_size[0]
89 | self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1] # self.base_size[1]
90 | #print(self.kernel_size, self.base_size, train_size)
91 | # only used for fast implementation
92 | self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2])
93 | self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1])
94 |
95 | #print(self.extra_repr())
96 | if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1):
97 | #print("Yes~")
98 | return F.adaptive_avg_pool2d(x, 1)
99 |
100 | if self.fast_imp: # Non-equivalent implementation but faster
101 | h, w = x.shape[2:]
102 | if self.kernel_size[0] >= h and self.kernel_size[1] >= w:
103 | out = F.adaptive_avg_pool2d(x, 1)
104 | else:
105 | r1 = [r for r in self.rs if h % r == 0][0]
106 | r2 = [r for r in self.rs if w % r == 0][0]
107 | # reduction_constraint
108 | r1 = min(self.max_r1, r1)
109 | r2 = min(self.max_r2, r2)
110 | s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2)
111 | n, c, h, w = s.shape
112 | k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2)
113 | out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2)
114 | out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2))
115 | else:
116 | n, c, h, w = x.shape
117 | s = x.cumsum(dim=-1).cumsum_(dim=-2)
118 | s = torch.nn.functional.pad(s, (1, 0, 1, 0)) # pad 0 for convenience
119 | k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1])
120 | s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:]
121 | out = s4 + s1 - s2 - s3
122 | out = out / (k1 * k2)
123 |
124 | if self.auto_pad:
125 | n, c, h, w = x.shape
126 | _h, _w = out.shape[2:]
127 | # print(x.shape, self.kernel_size)
128 | pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2)
129 | out = torch.nn.functional.pad(out, pad2d, mode='replicate')
130 |
131 | return out
132 |
133 | def replace_layers(model, base_size, train_size, fast_imp, **kwargs):
134 | for n, m in model.named_children():
135 | if len(list(m.children())) > 0:
136 | ## compound module, go inside it
137 | replace_layers(m, base_size, train_size, fast_imp, **kwargs)
138 |
139 | if isinstance(m, nn.AdaptiveAvgPool2d):
140 | pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size)
141 | assert m.output_size == 1
142 | setattr(model, n, pool)
143 |
144 |
145 | class Local_Base():
146 | def convert(self, *args, train_size, **kwargs):
147 | replace_layers(self, *args, train_size=train_size, **kwargs)
148 | imgs = torch.rand(train_size)
149 | with torch.no_grad():
150 | self.forward(imgs)
151 |
152 | class SimpleGate(nn.Module):
153 | def forward(self, x):
154 | x1, x2 = x.chunk(2, dim=1)
155 | return x1 * x2
156 |
157 | class NAFBlock(nn.Module):
158 | def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.):
159 | super().__init__()
160 | dw_channel = c * DW_Expand
161 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
162 | self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel,
163 | bias=True)
164 | self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
165 |
166 | # Simplified Channel Attention
167 | self.sca = nn.Sequential(
168 | nn.AdaptiveAvgPool2d(1),
169 | nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1,
170 | groups=1, bias=True),
171 | )
172 |
173 | # SimpleGate
174 | self.sg = SimpleGate()
175 |
176 | ffn_channel = FFN_Expand * c
177 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
178 | self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True)
179 |
180 | self.norm1 = LayerNorm2d(c)
181 | self.norm2 = LayerNorm2d(c)
182 |
183 | self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
184 | self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
185 |
186 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
187 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
188 |
189 | def forward(self, inp):
190 | x = inp
191 |
192 | x = self.norm1(x)
193 |
194 | x = self.conv1(x)
195 | x = self.conv2(x)
196 | x = self.sg(x)
197 | #print('--',x.shape, self.sca(x).shape)
198 | x = x * self.sca(x)
199 | x = self.conv3(x)
200 |
201 | x = self.dropout1(x)
202 |
203 | y = inp + x * self.beta
204 |
205 | x = self.conv4(self.norm2(y))
206 | x = self.sg(x)
207 | x = self.conv5(x)
208 |
209 | x = self.dropout2(x)
210 |
211 | return y + x * self.gamma
212 |
213 |
214 | class NAFNet(nn.Module):
215 |
216 | def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[]):
217 | super().__init__()
218 |
219 | self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1,
220 | bias=True)
221 | self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1,
222 | bias=True)
223 |
224 | self.encoders = nn.ModuleList()
225 | self.decoders = nn.ModuleList()
226 | self.middle_blks = nn.ModuleList()
227 | self.ups = nn.ModuleList()
228 | self.downs = nn.ModuleList()
229 |
230 | chan = width
231 | for num in enc_blk_nums:
232 | self.encoders.append(
233 | nn.Sequential(
234 | *[NAFBlock(chan) for _ in range(num)]
235 | )
236 | )
237 | self.downs.append(
238 | nn.Conv2d(chan, 2*chan, 2, 2)
239 | )
240 | chan = chan * 2
241 |
242 | self.middle_blks = \
243 | nn.Sequential(
244 | *[NAFBlock(chan) for _ in range(middle_blk_num)]
245 | )
246 |
247 | for num in dec_blk_nums:
248 | self.ups.append(
249 | nn.Sequential(
250 | nn.Conv2d(chan, chan * 2, 1, bias=False),
251 | nn.PixelShuffle(2)
252 | )
253 | )
254 | chan = chan // 2
255 | self.decoders.append(
256 | nn.Sequential(
257 | *[NAFBlock(chan) for _ in range(num)]
258 | )
259 | )
260 |
261 | self.padder_size = 2 ** len(self.encoders)
262 |
263 | def forward(self, inp):
264 | B, C, H, W = inp.shape
265 | inp = self.check_image_size(inp)
266 |
267 | x = self.intro(inp)
268 |
269 | encs = []
270 |
271 | for encoder, down in zip(self.encoders, self.downs):
272 | x = encoder(x)
273 | encs.append(x)
274 | x = down(x)
275 |
276 | x = self.middle_blks(x)
277 |
278 | for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
279 | x = up(x)
280 | x = x + enc_skip
281 | x = decoder(x)
282 |
283 | x = self.ending(x)
284 | x = x + inp
285 |
286 | return x[:, :, :H, :W]
287 |
288 | def check_image_size(self, x):
289 | _, _, h, w = x.size()
290 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
291 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
292 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
293 | return x
294 |
295 | class NAFNetLocal(Local_Base, NAFNet):
296 | def __init__(self, *args, train_size=(1, 1, 128, 128), base_size=(192, 192), fast_imp=False, **kwargs):
297 | Local_Base.__init__(self)
298 | NAFNet.__init__(self, *args, **kwargs)
299 |
300 | N, C, H, W = train_size
301 | base_size = (int(H * 1.5), int(W * 1.5))
302 |
303 | self.eval()
304 | with torch.no_grad():
305 | self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp)
306 |
307 |
308 | if __name__ == '__main__':
309 | img_channel = 1
310 | enc_blks = [1, 1, 1, 28]
311 | middle_blk_num = 1
312 | dec_blks = [1, 1, 1, 1]
313 | width = 64
314 |
315 | net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num,
316 | enc_blk_nums=enc_blks, dec_blk_nums=dec_blks)
317 |
318 | inp_shape = (1, 256, 256)
319 |
320 | from ptflops import get_model_complexity_info
321 |
322 | macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False)
323 |
324 | params = float(params[:-1])
325 | macs = float(macs[:-4])
326 |
327 | print(macs, params)
--------------------------------------------------------------------------------
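A short sketch showing why `check_image_size` exists: with four encoder stages the network pads inputs up to multiples of 2^4 = 16, and `forward` crops the result back to the original size (the tiny width and block counts are for illustration only):

```python
import torch
from basicsr.models.archs.nafnet_arch import NAFNet

# four encoder stages -> padder_size = 2**4 = 16
net = NAFNet(img_channel=3, width=8, middle_blk_num=1,
             enc_blk_nums=[1, 1, 1, 1], dec_blk_nums=[1, 1, 1, 1])
with torch.no_grad():
    out = net(torch.randn(1, 3, 50, 70))   # 50 and 70 are not multiples of 16
print(out.shape)                            # torch.Size([1, 3, 50, 70])
```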
/basicsr/models/losses/__init__.py:
--------------------------------------------------------------------------------
1 | from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss,
2 |                      PerceptualLoss, FFTLoss)
3 | 
4 | __all__ = ['L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss',
5 |            'PerceptualLoss', 'FFTLoss']
--------------------------------------------------------------------------------
/basicsr/models/losses/loss_util.py:
--------------------------------------------------------------------------------
1 | import functools
2 | from torch.nn import functional as F
3 |
4 |
5 | def reduce_loss(loss, reduction):
6 | """Reduce loss as specified.
7 |
8 | Args:
9 | loss (Tensor): Elementwise loss tensor.
10 | reduction (str): Options are 'none', 'mean' and 'sum'.
11 |
12 | Returns:
13 | Tensor: Reduced loss tensor.
14 | """
15 | reduction_enum = F._Reduction.get_enum(reduction)
16 | # none: 0, elementwise_mean:1, sum: 2
17 | if reduction_enum == 0:
18 | return loss
19 | elif reduction_enum == 1:
20 | return loss.mean()
21 | else:
22 | return loss.sum()
23 |
24 |
25 | def weight_reduce_loss(loss, weight=None, reduction='mean'):
26 | """Apply element-wise weight and reduce loss.
27 |
28 | Args:
29 | loss (Tensor): Element-wise loss.
30 | weight (Tensor): Element-wise weights. Default: None.
31 | reduction (str): Same as built-in losses of PyTorch. Options are
32 | 'none', 'mean' and 'sum'. Default: 'mean'.
33 |
34 | Returns:
35 | Tensor: Loss values.
36 | """
37 | # if weight is specified, apply element-wise weight
38 | if weight is not None:
39 | assert weight.dim() == loss.dim()
40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1)
41 | loss = loss * weight
42 |
43 | # if weight is not specified or reduction is sum, just reduce the loss
44 | if weight is None or reduction == 'sum':
45 | loss = reduce_loss(loss, reduction)
46 | # if reduction is mean, then compute mean over weight region
47 | elif reduction == 'mean':
48 | if weight.size(1) > 1:
49 | weight = weight.sum()
50 | else:
51 | weight = weight.sum() * loss.size(1)
52 | loss = loss.sum() / weight
53 |
54 | return loss
55 |
56 |
57 | def weighted_loss(loss_func):
58 | """Create a weighted version of a given loss function.
59 |
60 | To use this decorator, the loss function must have the signature like
61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute
62 | element-wise loss without any reduction. This decorator will add weight
63 | and reduction arguments to the function. The decorated function will have
64 | the signature like `loss_func(pred, target, weight=None, reduction='mean',
65 | **kwargs)`.
66 |
67 | :Example:
68 |
69 | >>> import torch
70 | >>> @weighted_loss
71 | >>> def l1_loss(pred, target):
72 | >>> return (pred - target).abs()
73 |
74 | >>> pred = torch.Tensor([0, 2, 3])
75 | >>> target = torch.Tensor([1, 1, 1])
76 | >>> weight = torch.Tensor([1, 0, 1])
77 |
78 | >>> l1_loss(pred, target)
79 | tensor(1.3333)
80 | >>> l1_loss(pred, target, weight)
81 | tensor(1.5000)
82 | >>> l1_loss(pred, target, reduction='none')
83 | tensor([1., 1., 2.])
84 | >>> l1_loss(pred, target, weight, reduction='sum')
85 | tensor(3.)
86 | """
87 |
88 | @functools.wraps(loss_func)
89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs):
90 | # get element-wise loss
91 | loss = loss_func(pred, target, **kwargs)
92 | loss = weight_reduce_loss(loss, weight, reduction)
93 | return loss
94 |
95 | return wrapper
96 |
--------------------------------------------------------------------------------
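A sketch of the decorator in its intended 4-D (N, C, H, W) setting, with an (N, 1, H, W) mask that zeroes out one sample (values chosen for illustration):

```python
import torch
from basicsr.models.losses.loss_util import weighted_loss

@weighted_loss
def l1_loss(pred, target):
    return (pred - target).abs()

pred = torch.zeros(2, 3, 4, 4)
target = torch.ones(2, 3, 4, 4)
target[1] = 3.0   # second sample has a larger error

# (N, 1, H, W) weight mask: only the first sample contributes
weight = torch.cat([torch.ones(1, 1, 4, 4), torch.zeros(1, 1, 4, 4)])

print(l1_loss(pred, target))           # tensor(2.)  plain mean over both samples
print(l1_loss(pred, target, weight))   # tensor(1.)  mean over the weighted region only
```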
/basicsr/models/losses/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn as nn
3 | from torch.nn import functional as F
4 | import numpy as np
5 | from basicsr.models.losses.loss_util import weighted_loss
6 |
7 | _reduction_modes = ['none', 'mean', 'sum']
8 |
9 |
10 | @weighted_loss
11 | def l1_loss(pred, target):
12 | return F.l1_loss(pred, target, reduction='none')
13 |
14 |
15 | @weighted_loss
16 | def mse_loss(pred, target):
17 | return F.mse_loss(pred, target, reduction='none')
18 |
19 |
20 | class L1Loss(nn.Module):
21 | """L1 (mean absolute error, MAE) loss.
22 |
23 | Args:
24 | loss_weight (float): Loss weight for L1 loss. Default: 1.0.
25 | reduction (str): Specifies the reduction to apply to the output.
26 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
27 | """
28 |
29 | def __init__(self, loss_weight=1.0, reduction='mean'):
30 | super(L1Loss, self).__init__()
31 | if reduction not in ['none', 'mean', 'sum']:
32 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
33 | f'Supported ones are: {_reduction_modes}')
34 |
35 | self.loss_weight = loss_weight
36 | self.reduction = reduction
37 |
38 | def forward(self, pred, target, weight=None, **kwargs):
39 | """
40 | Args:
41 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
42 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
43 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
44 | weights. Default: None.
45 | """
46 | return self.loss_weight * l1_loss(
47 | pred, target, weight, reduction=self.reduction)
48 |
49 | class MSELoss(nn.Module):
50 | """MSE (L2) loss.
51 |
52 | Args:
53 | loss_weight (float): Loss weight for MSE loss. Default: 1.0.
54 | reduction (str): Specifies the reduction to apply to the output.
55 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
56 | """
57 |
58 | def __init__(self, loss_weight=1.0, reduction='mean'):
59 | super(MSELoss, self).__init__()
60 | if reduction not in ['none', 'mean', 'sum']:
61 | raise ValueError(f'Unsupported reduction mode: {reduction}. '
62 | f'Supported ones are: {_reduction_modes}')
63 |
64 | self.loss_weight = loss_weight
65 | self.reduction = reduction
66 |
67 | def forward(self, pred, target, weight=None, **kwargs):
68 | """
69 | Args:
70 | pred (Tensor): of shape (N, C, H, W). Predicted tensor.
71 | target (Tensor): of shape (N, C, H, W). Ground truth tensor.
72 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise
73 | weights. Default: None.
74 | """
75 | return self.loss_weight * mse_loss(
76 | pred, target, weight, reduction=self.reduction)
77 |
78 | class PSNRLoss(nn.Module):
79 |
80 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False):
81 | super(PSNRLoss, self).__init__()
82 | assert reduction == 'mean'
83 | self.loss_weight = loss_weight
84 | self.scale = 10 / np.log(10)
85 | self.toY = toY
86 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1)
87 | self.first = True
88 |
89 | def forward(self, pred, target):
90 | assert len(pred.size()) == 4
91 | if self.toY:
92 | if self.first:
93 | self.coef = self.coef.to(pred.device)
94 | self.first = False
95 |
96 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
97 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16.
98 |
99 | pred, target = pred / 255., target / 255.
100 | 
101 |         # equals the negative PSNR in dB (for inputs in [0, 1]), so minimizing it maximizes PSNR
102 |
103 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean()
104 |
105 | class CharbonnierLoss(nn.Module):
106 | """Charbonnier Loss (L1)"""
107 |
108 | def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3):
109 | super(CharbonnierLoss, self).__init__()
110 | self.eps = eps
111 | self.loss_weight = loss_weight
112 |
113 | def forward(self, x, y):
114 | diff = x - y
115 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
116 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
117 | return loss * self.loss_weight
118 |
119 | class FFTLoss(nn.Module):
120 | """L1 loss in frequency domain with FFT.
121 |
122 | Args:
123 | loss_weight (float): Loss weight for FFT loss. Default: 1.0.
124 | reduction (str): Specifies the reduction to apply to the output.
125 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'.
126 | """
127 |
128 | def __init__(self, loss_weight=1.0, reduction='mean'):
129 | super(FFTLoss, self).__init__()
130 | if reduction not in ['none', 'mean', 'sum']:
131 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}')
132 |
133 | self.loss_weight = loss_weight
134 | self.reduction = reduction
135 |
136 | def forward(self, pred, target, weight=None, **kwargs):
137 | """
138 | Args:
139 | pred (Tensor): of shape (..., C, H, W). Predicted tensor.
140 | target (Tensor): of shape (..., C, H, W). Ground truth tensor.
141 | weight (Tensor, optional): of shape (..., C, H, W). Element-wise
142 | weights. Default: None.
143 | """
144 |
145 | pred_fft = torch.fft.fft2(pred, dim=(-2, -1))
146 | pred_fft = torch.stack([pred_fft.real, pred_fft.imag], dim=-1)
147 | target_fft = torch.fft.fft2(target, dim=(-2, -1))
148 | target_fft = torch.stack([target_fft.real, target_fft.imag], dim=-1)
149 | return self.loss_weight * l1_loss(pred_fft, target_fft, weight, reduction=self.reduction)
150 |
151 | import torchvision.models as models
152 | class PerceptualLoss(nn.Module):
153 |
154 | def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-6):
155 | super(PerceptualLoss, self).__init__()
156 | def contentFunc():
157 | conv_3_3_layer = 14
158 | cnn = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features
159 | model = nn.Sequential()
160 | for i,layer in enumerate(list(cnn)):
161 | model.add_module(str(i),layer)
162 | if i == conv_3_3_layer:
163 | break
164 | return model
165 | self.contentFunc = contentFunc()
166 | self.loss_weight = loss_weight
167 | self.reduction = reduction
168 | def forward(self, fakeIm, realIm):
169 | fakeIm_ = torch.cat((fakeIm, fakeIm,fakeIm), 1)
170 | realIm_ = torch.cat((realIm, realIm,realIm), 1)
171 |         f_fake = self.contentFunc(fakeIm_)
172 |         f_real = self.contentFunc(realIm_)
173 | f_real_no_grad = f_real.detach()
174 |         # MSE between VGG19 conv3_3 feature maps; the real branch is detached above
175 | loss = mse_loss(f_fake, f_real_no_grad, reduction=self.reduction)
176 | return loss * self.loss_weight
177 |
--------------------------------------------------------------------------------
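A quick sketch exercising two of the losses above on random tensors; note that `PSNRLoss` evaluates to the negative PSNR (in dB, for inputs in [0, 1]), so it decreases as quality improves:

```python
import torch
from basicsr.models.losses import PSNRLoss, FFTLoss

pred = torch.rand(2, 1, 32, 32)
gt = pred + 0.05 * torch.randn_like(pred)

# (10 / ln 10) * ln(MSE) averaged over the batch = -PSNR in dB
print(PSNRLoss()(pred, gt))

# L1 penalty on the stacked real/imaginary 2-D FFT coefficients
print(FFTLoss(loss_weight=0.1)(pred, gt))
```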
/basicsr/models/lr_scheduler.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from torch.optim.lr_scheduler import _LRScheduler
4 | import torch
5 |
6 |
7 | class MultiStepRestartLR(_LRScheduler):
8 | """ MultiStep with restarts learning rate scheme.
9 |
10 | Args:
11 | optimizer (torch.nn.optimizer): Torch optimizer.
12 | milestones (list): Iterations that will decrease learning rate.
13 | gamma (float): Decrease ratio. Default: 0.1.
14 | restarts (list): Restart iterations. Default: [0].
15 | restart_weights (list): Restart weights at each restart iteration.
16 | Default: [1].
17 | last_epoch (int): Used in _LRScheduler. Default: -1.
18 | """
19 |
20 | def __init__(self,
21 | optimizer,
22 | milestones,
23 | gamma=0.1,
24 | restarts=(0, ),
25 | restart_weights=(1, ),
26 | last_epoch=-1):
27 | self.milestones = Counter(milestones)
28 | self.gamma = gamma
29 | self.restarts = restarts
30 | self.restart_weights = restart_weights
31 | assert len(self.restarts) == len(
32 | self.restart_weights), 'restarts and their weights do not match.'
33 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
34 |
35 | def get_lr(self):
36 | if self.last_epoch in self.restarts:
37 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
38 | return [
39 | group['initial_lr'] * weight
40 | for group in self.optimizer.param_groups
41 | ]
42 | if self.last_epoch not in self.milestones:
43 | return [group['lr'] for group in self.optimizer.param_groups]
44 | return [
45 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
46 | for group in self.optimizer.param_groups
47 | ]
48 |
49 | class LinearLR(_LRScheduler):
50 |     """Linearly decay the learning rate from its initial value to zero
51 |     over total_iter iterations.
52 | 
53 |     Args:
54 |         optimizer (torch.nn.optimizer): Torch optimizer.
55 |         total_iter (int): Total number of training iterations.
56 |         last_epoch (int): Used in _LRScheduler. Default: -1.
57 |     """
58 |
59 | def __init__(self,
60 | optimizer,
61 | total_iter,
62 | last_epoch=-1):
63 | self.total_iter = total_iter
64 | super(LinearLR, self).__init__(optimizer, last_epoch)
65 |
66 | def get_lr(self):
67 | process = self.last_epoch / self.total_iter
68 | weight = (1 - process)
69 | # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups])
70 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
71 |
72 | class VibrateLR(_LRScheduler):
73 |     """Learning rate that oscillates with a triangular wave whose
74 |     amplitude decays over total_iter iterations.
75 | 
76 |     Args:
77 |         optimizer (torch.nn.optimizer): Torch optimizer.
78 |         total_iter (int): Total number of training iterations.
79 |         last_epoch (int): Used in _LRScheduler. Default: -1.
80 |     """
81 |
82 | def __init__(self,
83 | optimizer,
84 | total_iter,
85 | last_epoch=-1):
86 | self.total_iter = total_iter
87 | super(VibrateLR, self).__init__(optimizer, last_epoch)
88 |
89 | def get_lr(self):
90 | process = self.last_epoch / self.total_iter
91 |
92 | f = 0.1
93 | if process < 3 / 8:
94 | f = 1 - process * 8 / 3
95 | elif process < 5 / 8:
96 | f = 0.2
97 |
98 | T = self.total_iter // 80
99 | Th = T // 2
100 |
101 | t = self.last_epoch % T
102 |
103 | f2 = t / Th
104 | if t >= Th:
105 | f2 = 2 - f2
106 |
107 | weight = f * f2
108 |
109 | if self.last_epoch < Th:
110 | weight = max(0.1, weight)
111 |
112 | # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2))
113 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
114 |
115 | def get_position_from_periods(iteration, cumulative_period):
116 | """Get the position from a period list.
117 |
118 | It will return the index of the right-closest number in the period list.
119 | For example, the cumulative_period = [100, 200, 300, 400],
120 | if iteration == 50, return 0;
121 | if iteration == 210, return 2;
122 | if iteration == 300, return 2.
123 |
124 | Args:
125 | iteration (int): Current iteration.
126 | cumulative_period (list[int]): Cumulative period list.
127 |
128 | Returns:
129 | int: The position of the right-closest number in the period list.
130 | """
131 | for i, period in enumerate(cumulative_period):
132 | # print("12230033",i, period,iteration )
133 | if iteration <= period:
134 | return i
135 |
136 |
137 | class CosineAnnealingRestartLR(_LRScheduler):
138 | """ Cosine annealing with restarts learning rate scheme.
139 |
140 | An example of config:
141 | periods = [10, 10, 10, 10]
142 | restart_weights = [1, 0.5, 0.5, 0.5]
143 | eta_min=1e-7
144 |
145 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
146 | scheduler will restart with the weights in restart_weights.
147 |
148 | Args:
149 | optimizer (torch.nn.optimizer): Torch optimizer.
150 |         periods (list): Period for each cosine annealing cycle.
151 | restart_weights (list): Restart weights at each restart iteration.
152 | Default: [1].
153 |         eta_min (float): The minimum lr. Default: 0.
154 | last_epoch (int): Used in _LRScheduler. Default: -1.
155 | """
156 |
157 | def __init__(self,
158 | optimizer,
159 | periods,
160 | restart_weights=(1, ),
161 | eta_min=0,
162 | last_epoch=-1):
163 | self.periods = periods
164 | self.restart_weights = restart_weights
165 | self.eta_min = eta_min
166 | assert (len(self.periods) == len(self.restart_weights)
167 | ), 'periods and restart_weights should have the same length.'
168 | self.cumulative_period = [
169 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
170 | ]
171 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
172 |
173 | def get_lr(self):
174 | idx = get_position_from_periods(self.last_epoch,
175 | self.cumulative_period)
176 | current_weight = self.restart_weights[idx]
177 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
178 | current_period = self.periods[idx]
179 |
180 | return [
181 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
182 | (1 + math.cos(math.pi * (
183 | (self.last_epoch - nearest_restart) / current_period)))
184 | for base_lr in self.base_lrs
185 | ]
186 |
187 | class CosineAnnealingRestartCyclicLR(_LRScheduler):
188 | """ Cosine annealing with restarts learning rate scheme.
189 | An example of config:
190 | periods = [10, 10, 10, 10]
191 | restart_weights = [1, 0.5, 0.5, 0.5]
192 | eta_min=1e-7
193 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
194 | scheduler will restart with the weights in restart_weights.
195 | Args:
196 | optimizer (torch.nn.optimizer): Torch optimizer.
197 |         periods (list): Period for each cosine annealing cycle.
198 | restart_weights (list): Restart weights at each restart iteration.
199 | Default: [1].
200 |         eta_mins (list): The minimum lr of each cycle. Default: (0, ).
201 | last_epoch (int): Used in _LRScheduler. Default: -1.
202 | """
203 |
204 | def __init__(self,
205 | optimizer,
206 | periods,
207 | restart_weights=(1, ),
208 | eta_mins=(0, ),
209 | last_epoch=-1):
210 | self.periods = periods
211 | self.restart_weights = restart_weights
212 | self.eta_mins = eta_mins
213 | assert (len(self.periods) == len(self.restart_weights)
214 | ), 'periods and restart_weights should have the same length.'
215 | self.cumulative_period = [
216 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
217 | ]
218 | super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch)
219 |
220 | def get_lr(self):
221 | idx = get_position_from_periods(self.last_epoch,
222 | self.cumulative_period)
223 | # print("12230033",idx,self.last_epoch,self.cumulative_period)
224 | current_weight = self.restart_weights[idx]
225 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
226 | current_period = self.periods[idx]
227 | eta_min = self.eta_mins[idx]
228 |
229 | return [
230 | eta_min + current_weight * 0.5 * (base_lr - eta_min) *
231 | (1 + math.cos(math.pi * (
232 | (self.last_epoch - nearest_restart) / current_period)))
233 | for base_lr in self.base_lrs
234 | ]
235 |
236 |
237 | 
238 | if __name__ == '__main__':
239 |     # quick visual sanity check of the cyclic cosine schedule
240 |     net = torch.nn.Linear(1, 1)
241 |     optimizer_G = torch.optim.Adam(net.parameters(), lr=0.001)
242 |     scheduler = CosineAnnealingRestartCyclicLR(optimizer_G, (10, ))
243 |     y = []
244 |     for epoch in range(10):
245 |         y.append(optimizer_G.param_groups[0]['lr'])
246 |         scheduler.step()
247 |     print(y)
248 |     import matplotlib.pyplot as plt
249 |     plt.figure(dpi=300)
250 |     plt.plot(y)
251 |     plt.title('CosineAnnealingRestartCyclicLR')
252 |     plt.savefig('a.png')
--------------------------------------------------------------------------------
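A sketch of the non-cyclic `CosineAnnealingRestartLR` with two cycles, showing the restart weight halving the peak lr of the second cycle (periods and weights here are arbitrary):

```python
import torch
from basicsr.models.lr_scheduler import CosineAnnealingRestartLR

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=1e-3)
# two cosine cycles of 5 iterations; the second restarts at half the base lr
sched = CosineAnnealingRestartLR(opt, periods=[5, 5],
                                 restart_weights=[1, 0.5], eta_min=1e-7)
for it in range(10):
    print(it, f"{opt.param_groups[0]['lr']:.2e}")
    sched.step()
```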
/basicsr/test.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import torch
3 | from os import path as osp
4 | import os
5 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
6 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # choose your gpu id
7 |
8 | import random, argparse
9 |
10 | from basicsr.data import create_dataloader, create_dataset
11 | from basicsr.models import create_model
12 | from basicsr.utils import (get_env_info, get_root_logger, get_time_str,
13 | make_exp_dirs, set_random_seed)
14 | from basicsr.utils.options import dict2str, parse
15 | from basicsr.utils.dist_util import get_dist_info, init_dist
16 |
17 | def parse_options(is_train=True):
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument(
20 | '-opt', default="options/test_JARNet_LAP.yml",
21 | type=str)#, required=True, help='Path to option YAML file.')
22 | parser.add_argument(
23 | '--launcher',
24 | choices=['none', 'pytorch', 'slurm'],
25 | default='none',
26 | help='job launcher')
27 | parser.add_argument('--local_rank', type=int, default=0)
28 | args = parser.parse_args()
29 | opt = parse(args.opt, is_train=is_train)
30 |
31 | # distributed settings
32 | if args.launcher == 'none':
33 | opt['dist'] = False
34 | print('Disable distributed.', flush=True)
35 | else:
36 | opt['dist'] = True
37 | if args.launcher == 'slurm' and 'dist_params' in opt:
38 | init_dist(args.launcher, **opt['dist_params'])
39 | else:
40 | init_dist(args.launcher)
41 | print('init dist .. ', args.launcher)
42 |
43 | opt['rank'], opt['world_size'] = get_dist_info()
44 |
45 | # random seed
46 | seed = opt.get('manual_seed')
47 | if seed is None:
48 | seed = random.randint(1, 10000)
49 | opt['manual_seed'] = seed
50 | set_random_seed(seed + opt['rank'])
51 |
52 | return opt
53 |
54 | def test_pipeline(root_path):
55 |     # parse options, set distributed setting, set random seed
56 |     opt = parse_options(is_train=False)
57 | 
58 |
59 | torch.backends.cudnn.benchmark = True
60 | # torch.backends.cudnn.deterministic = True
61 |
62 | # mkdir and initialize loggers
63 | make_exp_dirs(opt)
64 | log_file = osp.join(opt['path']['log'],
65 | f"test_{opt['name']}_{get_time_str()}.log")
66 | logger = get_root_logger(
67 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
68 | logger.info(get_env_info())
69 | logger.info(dict2str(opt))
70 |
71 | # create test dataset and dataloader
72 | test_loaders = []
73 | for phase, dataset_opt in sorted(opt['datasets'].items()):
74 | test_set = create_dataset(dataset_opt)
75 | test_loader = create_dataloader(
76 | test_set,
77 | dataset_opt,
78 | num_gpu=opt['num_gpu'],
79 | dist=opt['dist'],
80 | sampler=None,
81 | seed=opt['manual_seed'])
82 | logger.info(
83 | f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
84 | test_loaders.append(test_loader)
85 |
86 | # create model
87 | model = create_model(opt)
88 |
89 | for test_loader in test_loaders:
90 | test_set_name = test_loader.dataset.opt['name']
91 | logger.info(f'Testing {test_set_name}...')
92 | rgb2bgr = opt['val'].get('rgb2bgr', True)
93 |         # whether to use uint8 images to compute metrics
94 | use_image = opt['val'].get('use_image', True)
95 | model.validation(
96 | test_loader,
97 | current_iter=opt['name'],
98 | tb_logger=None,
99 | save_img=opt['val']['save_img'],
100 | rgb2bgr=rgb2bgr, use_image=use_image)
101 |
102 |
103 | if __name__ == '__main__':
104 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir))
105 | test_pipeline(root_path)
106 |
--------------------------------------------------------------------------------
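`test_pipeline` reads only a handful of top-level keys from the parsed YAML. A sketch of the minimal structure those accesses imply; every value below is a placeholder, and `ImageRestorationModel` is an assumed class name based on the file `image_restoration_model.py`, not confirmed by this listing:

```python
# key names mirror the accesses in parse_options/test_pipeline above;
# values are placeholders, not the repo's actual YAML contents
opt = {
    'name': 'JARNet_LAP_test',
    'model_type': 'ImageRestorationModel',   # hypothetical class name
    'num_gpu': 1,
    'dist': False,                           # normally set by parse_options
    'manual_seed': 10,
    'path': {'log': 'results/JARNet_LAP_test'},
    'datasets': {
        'test': {'name': 'LAP-test'},        # plus whatever the dataset class needs
    },
    'val': {'save_img': True, 'rgb2bgr': True, 'use_image': True},
}
```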
/basicsr/train_jarnet.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # choose your gpu id
4 | import torch
5 | import logging
6 | import math
7 | import time
8 | import random
9 | import numpy as np
10 | import datetime
11 | import argparse
12 | from basicsr.utils import (MessageLogger,make_exp_dirs,mkdir_and_rename,get_root_logger,
13 | get_env_info,init_wandb_logger,init_tb_logger,
14 | get_time_str,check_resume,set_random_seed)
15 | from os import path as osp
16 | from basicsr.utils.options import dict2str
17 | from basicsr.data import create_dataloader, create_dataset
18 | from basicsr.data.data_sampler import EnlargedSampler
19 | from basicsr.models import create_model
20 | from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
21 | from basicsr.utils.dist_util import get_dist_info, init_dist
22 | from basicsr.utils.options import dict2str, parse
23 |
24 | def parse_options(is_train=True):
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument(
27 | '-opt', default="options/train_JARNet_LAP.yml",
28 | type=str)#, required=True, help='Path to option YAML file.')
29 | parser.add_argument(
30 | '--launcher',
31 | choices=['none', 'pytorch', 'slurm'],
32 | default='none',
33 | help='job launcher')
34 | parser.add_argument('--local-rank', type=int, default=0)
35 | args = parser.parse_args()
36 | opt = parse(args.opt, is_train=is_train)
37 |
38 | # distributed settings
39 | if args.launcher == 'none':
40 | opt['dist'] = False
41 | print('Disable distributed.', flush=True)
42 | else:
43 | opt['dist'] = True
44 | if args.launcher == 'slurm' and 'dist_params' in opt:
45 | init_dist(args.launcher, **opt['dist_params'])
46 | else:
47 | init_dist(args.launcher)
48 | print('init dist .. ', args.launcher)
49 |
50 | opt['rank'], opt['world_size'] = get_dist_info()
51 |
52 | # random seed
53 | seed = opt.get('manual_seed')
54 | if seed is None:
55 | seed = random.randint(1, 10000)
56 | opt['manual_seed'] = seed
57 | set_random_seed(seed + opt['rank'])
58 |
59 | return opt
60 |
61 | def init_loggers(opt):
62 | log_file = osp.join(opt['path']['log'],
63 | f"train_{opt['name']}_{get_time_str()}.log")
64 | logger = get_root_logger(
65 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
66 | logger.info(get_env_info())
67 | logger.info(dict2str(opt))
68 |
69 | # initialize wandb logger before tensorboard logger to allow proper sync:
70 | if (opt['logger'].get('wandb')
71 | is not None) and (opt['logger']['wandb'].get('project')
72 | is not None) and ('debug' not in opt['name']):
73 | assert opt['logger'].get('use_tb_logger') is True, (
74 | 'should turn on tensorboard when using wandb')
75 | init_wandb_logger(opt)
76 | tb_logger = None
77 | if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
78 | tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
79 | return logger, tb_logger
80 |
81 | def create_train_val_dataloader(opt, logger):
82 | # create train and val dataloaders
83 | train_loader, val_loader = None, None
84 | for phase, dataset_opt in opt['datasets'].items():
85 | if phase == 'train':
86 | 
87 | dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
88 | train_set = create_dataset(dataset_opt)
89 | train_sampler = EnlargedSampler(train_set, opt['world_size'],
90 | opt['rank'], dataset_enlarge_ratio)
91 | # train_sampler = None
92 | train_loader = create_dataloader(
93 | train_set,
94 | dataset_opt,
95 | num_gpu=opt['num_gpu'],
96 | dist=opt['dist'],
97 | sampler=train_sampler,
98 | seed=opt['manual_seed'])
99 |
100 | num_iter_per_epoch = math.ceil(
101 | len(train_set) * dataset_enlarge_ratio /
102 | (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
103 | total_iters = int(opt['train']['total_iter'])
104 | # print("----------",num_iter_per_epoch,len(train_set) )
105 | total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
106 | logger.info(
107 | 'Training statistics:'
108 | f'\n\tNumber of train images: {len(train_set)}'
109 | f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
110 | f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
111 | f'\n\tWorld size (gpu number): {opt["world_size"]}'
112 | f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
113 | f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
114 | 
115 | elif phase == 'val':
116 | val_set = create_dataset(dataset_opt)
117 | val_loader = create_dataloader(
118 | val_set,
119 | dataset_opt,
120 | num_gpu=opt['num_gpu'],
121 | dist=opt['dist'],
122 | sampler=None,
123 | seed=opt['manual_seed'])
124 | logger.info(
125 | f'Number of val images/folders in {dataset_opt["name"]}: '
126 | f'{len(val_set)}')
127 | else:
128 | raise ValueError(f'Dataset phase {phase} is not recognized.')
129 |
130 | return train_loader, train_sampler, val_loader, total_epochs, total_iters
131 |
132 |
133 | def main():
134 |     # parse options, set distributed setting, set random seed
135 | opt = parse_options(is_train=True)
136 |
137 | torch.backends.cudnn.benchmark = True
138 | # automatic resume ..
139 | state_folder_path = 'experiments/{}/training_states/'.format(opt['name'])
140 |
141 |     try:
142 |         states = os.listdir(state_folder_path)
143 |     except FileNotFoundError:
144 |         states = []
145 |
146 | resume_state = None
147 | if len(states) > 0:
148 | max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states]))
149 | resume_state = os.path.join(state_folder_path, max_state_file)
150 | opt['path']['resume_state'] = resume_state
151 |
152 | # load resume states if necessary
153 | if opt['path'].get('resume_state'):
154 | device_id = torch.cuda.current_device()
155 | resume_state = torch.load(
156 | opt['path']['resume_state'],
157 | map_location=lambda storage, loc: storage.cuda(device_id))
158 | else:
159 | resume_state = None
160 |
161 | # mkdir for experiments and logger
162 | if resume_state is None:
163 | make_exp_dirs(opt)
164 | if opt['logger'].get('use_tb_logger') and 'debug' not in opt[
165 | 'name'] and opt['rank'] == 0:
166 | mkdir_and_rename(osp.join('tb_logger', opt['name']))
167 |
168 | # initialize loggers
169 | logger, tb_logger = init_loggers(opt)
170 |
171 | # create train and validation dataloaders
172 | result = create_train_val_dataloader(opt, logger)
173 | train_loader, train_sampler, val_loader, total_epochs, total_iters = result
174 |
175 | # create model
176 | if resume_state: # resume training
177 | check_resume(opt, resume_state['iter'])
178 | model = create_model(opt)
179 | model.resume_training(resume_state) # handle optimizers and schedulers
180 | logger.info(f"Resuming training from epoch: {resume_state['epoch']}, "
181 | f"iter: {resume_state['iter']}.")
182 | start_epoch = resume_state['epoch']
183 | current_iter = resume_state['iter']
184 | else:
185 | model = create_model(opt)
186 | start_epoch = 0
187 | current_iter = 0
188 |
189 | # create message logger (formatted outputs)
190 | msg_logger = MessageLogger(opt, current_iter, tb_logger)
191 |
192 | # dataloader prefetcher
193 | prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
194 | if prefetch_mode is None or prefetch_mode == 'cpu':
195 | prefetcher = CPUPrefetcher(train_loader)
196 | elif prefetch_mode == 'cuda':
197 | prefetcher = CUDAPrefetcher(train_loader, opt)
198 | logger.info(f'Use {prefetch_mode} prefetch dataloader')
199 | if opt['datasets']['train'].get('pin_memory') is not True:
200 | raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
201 | else:
202 | raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.'
203 | "Supported ones are: None, 'cuda', 'cpu'.")
204 |
205 |
206 | # training
207 | logger.info(
208 | f'Start training from epoch: {start_epoch}, iter: {current_iter}')
209 | data_time, iter_time = time.time(), time.time()
210 | start_time = time.time()
211 |
212 | iters = opt['datasets']['train'].get('iters')
213 | batch_size = opt['datasets']['train'].get('batch_size_per_gpu')
214 | mini_batch_sizes = opt['datasets']['train'].get('mini_batch_sizes')
215 | gt_size = opt['datasets']['train'].get('gt_size')
216 | mini_gt_sizes = opt['datasets']['train'].get('gt_sizes')
217 |
218 | groups = np.array([sum(iters[0:i + 1]) for i in range(0, len(iters))])
219 |
220 | logger_j = [True] * len(groups)
221 |
222 | scale = opt['scale']
223 |
224 | epoch = start_epoch
225 | while current_iter <= total_iters:
226 | train_sampler.set_epoch(epoch)
227 | prefetcher.reset()
228 | train_data = prefetcher.next()
229 |
230 | while train_data is not None:
231 | data_time = time.time() - data_time
232 |
233 | current_iter += 1
234 | if current_iter > total_iters:
235 | break
236 | # update learning rate
237 | model.update_learning_rate(
238 | current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
239 |
240 | ### ------Progressive learning ---------------------
241 |             j = (current_iter <= groups).nonzero()[0]
242 | if len(j) == 0:
243 | bs_j = len(groups) - 1
244 | else:
245 | bs_j = j[0]
246 |
247 | mini_gt_size = mini_gt_sizes[bs_j]
248 | mini_batch_size = mini_batch_sizes[bs_j]
249 |
250 | if logger_j[bs_j]:
251 | logger.info('\n Updating Patch_Size to {} and Batch_Size to {} \n'.format(mini_gt_size, mini_batch_size*torch.cuda.device_count()))
252 | logger_j[bs_j] = False
253 |
254 | lq = train_data['lq']
255 | gt = train_data['gt']
256 | flow = train_data['flow']
257 | flow_gt = train_data['flow_gt']
258 | 
259 |             # progressive learning: when the current stage uses a smaller
260 |             # mini-batch, randomly subsample the batch (including the flow
261 |             # maps); when it uses a smaller patch size, take the same random
262 |             # crop from lq, gt, flow and flow_gt.
263 | 
264 | 
265 | if mini_batch_size < batch_size:
266 | indices = random.sample(range(0, batch_size), k=mini_batch_size)
267 | lq = lq[indices]
268 | gt = gt[indices]
269 | flow = flow[indices]
270 | flow_gt = flow_gt[indices]
271 |
272 | if mini_gt_size < gt_size:
273 | x0 = int((gt_size - mini_gt_size) * random.random())
274 | y0 = int((gt_size - mini_gt_size) * random.random())
275 | x1 = x0 + mini_gt_size
276 | y1 = y0 + mini_gt_size
277 | lq = lq[:,:,x0:x1,y0:y1]
278 | gt = gt[:,:,x0*scale:x1*scale,y0*scale:y1*scale]
279 | flow = flow[:,:,x0:x1,y0:y1]
280 | flow_gt = flow_gt[:,:,x0:x1,y0:y1]
281 | ###-------------------------------------------
282 |
283 | model.feed_train_data({'lq': lq, 'gt':gt,'flow':flow,"flow_gt":flow_gt})
284 | model.optimize_parameters(current_iter)
285 |
286 | iter_time = time.time() - iter_time
287 | # log
288 | if current_iter % opt['logger']['print_freq'] == 0:
289 | log_vars = {'epoch': epoch, 'iter': current_iter}
290 | log_vars.update({'lrs': model.get_current_learning_rate()})
291 | log_vars.update({'time': iter_time, 'data_time': data_time})
292 | log_vars.update(model.get_current_log())
293 | msg_logger(log_vars)
294 |
295 | # save models and training states
296 | if current_iter % opt['logger']['save_checkpoint_freq'] == 0 or (current_iter > 300000 and current_iter % 10000 == 0):
297 | logger.info('Saving models and training states.')
298 | model.save(epoch, current_iter)
299 |
300 | # validation
301 | if opt.get('val') is not None and (current_iter %
302 | opt['val']['val_freq'] == 0 or (current_iter > 300000 and current_iter % 10000 == 0)):
303 | rgb2bgr = opt['val'].get('rgb2bgr', True)
304 |                 # whether to use uint8 images to compute metrics
305 | use_image = opt['val'].get('use_image', True)
306 | model.validation(val_loader, current_iter, tb_logger,
307 | opt['val']['save_img'], rgb2bgr, use_image )
308 |
309 | data_time = time.time()
310 | iter_time = time.time()
311 | train_data = prefetcher.next()
313 | # end of iter
314 | epoch += 1
315 |
316 | # end of epoch
317 |
318 | consumed_time = str(
319 | datetime.timedelta(seconds=int(time.time() - start_time)))
320 | logger.info(f'End of training. Time consumed: {consumed_time}')
321 | logger.info('Save the latest model.')
322 | model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
323 | if opt.get('val') is not None:
324 | model.validation(val_loader, current_iter, tb_logger,
325 | opt['val']['save_last_img'])
326 | if tb_logger:
327 | tb_logger.close()
328 |
329 | if __name__ == "__main__":
330 | main()
--------------------------------------------------------------------------------
/basicsr/train_others.py:
--------------------------------------------------------------------------------
1 | import os
2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # choose your gpu id
4 | import torch
5 | import logging
6 | import math
7 | import time
8 | import random
9 | import numpy as np
10 | import datetime
11 | import argparse
12 | from basicsr.utils import (MessageLogger, make_exp_dirs, mkdir_and_rename,
13 | get_root_logger, get_env_info, init_wandb_logger,
14 | init_tb_logger, get_time_str, check_resume, set_random_seed)
15 | from os import path as osp
16 | from basicsr.utils.options import dict2str
17 | from basicsr.data import create_dataloader, create_dataset
18 | from basicsr.data.data_sampler import EnlargedSampler
19 | from basicsr.models import create_model
20 | from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
21 | from basicsr.utils.dist_util import get_dist_info, init_dist
22 | from basicsr.utils.options import parse
23 |
24 | def parse_options(is_train=True):
25 | parser = argparse.ArgumentParser()
26 | parser.add_argument(
27 | '-opt', default="options/train_NAFNet_LAP.yml",
28 | type=str, help='Path to option YAML file.')
29 | parser.add_argument(
30 | '--launcher',
31 | choices=['none', 'pytorch', 'slurm'],
32 | default='none',
33 | help='job launcher')
34 | parser.add_argument('--local-rank', type=int, default=0)
35 | args = parser.parse_args()
36 | opt = parse(args.opt, is_train=is_train)
37 |
38 | # distributed settings
39 | if args.launcher == 'none':
40 | opt['dist'] = False
41 | print('Disable distributed.', flush=True)
42 | else:
43 | opt['dist'] = True
44 | if args.launcher == 'slurm' and 'dist_params' in opt:
45 | init_dist(args.launcher, **opt['dist_params'])
46 | else:
47 | init_dist(args.launcher)
48 | print('init dist .. ', args.launcher)
49 |
50 | opt['rank'], opt['world_size'] = get_dist_info()
51 |
52 | # random seed
53 | seed = opt.get('manual_seed')
54 | if seed is None:
55 | seed = random.randint(1, 10000)
56 | opt['manual_seed'] = seed
57 | set_random_seed(seed + opt['rank'])
58 |
59 | return opt
60 |
61 | def init_loggers(opt):
62 | log_file = osp.join(opt['path']['log'],
63 | f"train_{opt['name']}_{get_time_str()}.log")
64 | logger = get_root_logger(
65 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
66 | logger.info(get_env_info())
67 | logger.info(dict2str(opt))
68 |
69 | # initialize wandb logger before tensorboard logger to allow proper sync:
70 | if (opt['logger'].get('wandb')
71 | is not None) and (opt['logger']['wandb'].get('project')
72 | is not None) and ('debug' not in opt['name']):
73 | assert opt['logger'].get('use_tb_logger') is True, (
74 | 'should turn on tensorboard when using wandb')
75 | init_wandb_logger(opt)
76 | tb_logger = None
77 | if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
78 | tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name']))
79 | return logger, tb_logger
80 |
81 | def create_train_val_dataloader(opt, logger):
82 | # create train and val dataloaders
83 | train_loader, val_loader = None, None
84 | for phase, dataset_opt in opt['datasets'].items():
85 | if phase == 'train':
86 | dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1)
87 | train_set = create_dataset(dataset_opt)
88 | train_sampler = EnlargedSampler(train_set, opt['world_size'],
89 | opt['rank'], dataset_enlarge_ratio)
90 | train_loader = create_dataloader(
91 | train_set,
92 | dataset_opt,
93 | num_gpu=opt['num_gpu'],
94 | dist=opt['dist'],
95 | sampler=train_sampler,
96 | seed=opt['manual_seed'])
97 |
98 | num_iter_per_epoch = math.ceil(
99 | len(train_set) * dataset_enlarge_ratio /
100 | (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
101 | total_iters = int(opt['train']['total_iter'])
102 | total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
103 | logger.info(
104 | 'Training statistics:'
105 | f'\n\tNumber of train images: {len(train_set)}'
106 | f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
107 | f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
108 | f'\n\tWorld size (gpu number): {opt["world_size"]}'
109 | f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
110 | f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
111 |
112 | elif phase == 'val':
113 | val_set = create_dataset(dataset_opt)
114 | val_loader = create_dataloader(
115 | val_set,
116 | dataset_opt,
117 | num_gpu=opt['num_gpu'],
118 | dist=opt['dist'],
119 | sampler=None,
120 | seed=opt['manual_seed'])
121 | logger.info(
122 | f'Number of val images/folders in {dataset_opt["name"]}: '
123 | f'{len(val_set)}')
124 | else:
125 | raise ValueError(f'Dataset phase {phase} is not recognized.')
126 |
127 | return train_loader, train_sampler, val_loader, total_epochs, total_iters
128 |
129 |
130 | def main():
131 | # parse options, set distributed settings and random seed
132 | opt = parse_options(is_train=True)
133 |
134 | torch.backends.cudnn.benchmark = True
135 | # automatic resume ..
136 | state_folder_path = 'experiments/{}/training_states/'.format(opt['name'])
137 |
138 | try:
139 | states = os.listdir(state_folder_path)
140 | except FileNotFoundError:
141 | states = []
142 |
143 | resume_state = None
144 | if len(states) > 0:
145 | max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states]))
146 | resume_state = os.path.join(state_folder_path, max_state_file)
147 | opt['path']['resume_state'] = resume_state
148 |
149 | # load resume states if necessary
150 | if opt['path'].get('resume_state'):
151 | device_id = torch.cuda.current_device()
152 | resume_state = torch.load(
153 | opt['path']['resume_state'],
154 | map_location=lambda storage, loc: storage.cuda(device_id))
155 | else:
156 | resume_state = None
157 |
158 | # mkdir for experiments and logger
159 | if resume_state is None:
160 | make_exp_dirs(opt)
161 | if opt['logger'].get('use_tb_logger') and 'debug' not in opt[
162 | 'name'] and opt['rank'] == 0:
163 | mkdir_and_rename(osp.join('tb_logger', opt['name']))
164 |
165 | # initialize loggers
166 | logger, tb_logger = init_loggers(opt)
167 |
168 | # create train and validation dataloaders
169 | result = create_train_val_dataloader(opt, logger)
170 | train_loader, train_sampler, val_loader, total_epochs, total_iters = result
171 |
172 | # create model
173 | if resume_state: # resume training
174 | check_resume(opt, resume_state['iter'])
175 | model = create_model(opt)
176 | model.resume_training(resume_state) # handle optimizers and schedulers
177 | logger.info(f"Resuming training from epoch: {resume_state['epoch']}, "
178 | f"iter: {resume_state['iter']}.")
179 | start_epoch = resume_state['epoch']
180 | current_iter = resume_state['iter']
181 | else:
182 | model = create_model(opt)
183 | start_epoch = 0
184 | current_iter = 0
185 |
186 | # create message logger (formatted outputs)
187 | msg_logger = MessageLogger(opt, current_iter, tb_logger)
188 |
189 | # dataloader prefetcher
190 | prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
191 | if prefetch_mode is None or prefetch_mode == 'cpu':
192 | prefetcher = CPUPrefetcher(train_loader)
193 | elif prefetch_mode == 'cuda':
194 | prefetcher = CUDAPrefetcher(train_loader, opt)
195 | logger.info(f'Use {prefetch_mode} prefetch dataloader')
196 | if opt['datasets']['train'].get('pin_memory') is not True:
197 | raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
198 | else:
199 | raise ValueError(f"Wrong prefetch_mode {prefetch_mode}. "
200 | "Supported ones are: None, 'cuda', 'cpu'.")
201 |
202 |
203 | # training
204 | logger.info(
205 | f'Start training from epoch: {start_epoch}, iter: {current_iter}')
206 | data_time, iter_time = time.time(), time.time()
207 | start_time = time.time()
208 |
209 | iters = opt['datasets']['train'].get('iters')
210 | batch_size = opt['datasets']['train'].get('batch_size_per_gpu')
211 | mini_batch_sizes = opt['datasets']['train'].get('mini_batch_sizes')
212 | gt_size = opt['datasets']['train'].get('gt_size')
213 | mini_gt_sizes = opt['datasets']['train'].get('gt_sizes')
214 |
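215 | # cumulative iteration counts: groups[k] marks the last iteration
216 | # of progressive stage k (same schedule logic as train_jarnet.py)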
215 | groups = np.array([sum(iters[0:i + 1]) for i in range(0, len(iters))])
216 |
217 | logger_j = [True] * len(groups)
218 |
219 | scale = opt['scale']
220 |
221 | epoch = start_epoch
222 | while current_iter <= total_iters:
223 | train_sampler.set_epoch(epoch)
224 | prefetcher.reset()
225 | train_data = prefetcher.next()
226 |
227 | while train_data is not None:
228 | data_time = time.time() - data_time
229 |
230 | current_iter += 1
231 | if current_iter > total_iters:
232 | break
233 | # update learning rate
234 | model.update_learning_rate(
235 | current_iter, warmup_iter=opt['train'].get('warmup_iter', -1))
236 |
237 | ### ------Progressive learning ---------------------
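238 | # bs_j indexes the first stage boundary current_iter has not passed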
238 | j = np.nonzero(current_iter <= groups)[0]
239 | if len(j) == 0:
240 | bs_j = len(groups) - 1
241 | else:
242 | bs_j = j[0]
243 |
244 | mini_gt_size = mini_gt_sizes[bs_j]
245 | mini_batch_size = mini_batch_sizes[bs_j]
246 |
247 | if logger_j[bs_j]:
248 | logger.info('\n Updating Patch_Size to {} and Batch_Size to {} \n'.format(mini_gt_size, mini_batch_size*torch.cuda.device_count()))
249 | logger_j[bs_j] = False
250 |
251 | lq = train_data['lq']
252 | gt = train_data['gt']
253 |
254 | if mini_batch_size < batch_size:
255 | indices = random.sample(range(0, batch_size), k=mini_batch_size)
256 | lq = lq[indices]
257 | gt = gt[indices]
258 |
264 | if mini_gt_size < gt_size:
265 | x0 = int((gt_size - mini_gt_size) * random.random())
266 | y0 = int((gt_size - mini_gt_size) * random.random())
267 | x1 = x0 + mini_gt_size
268 | y1 = y0 + mini_gt_size
269 | lq = lq[:,:,x0:x1,y0:y1]
270 | gt = gt[:,:,x0*scale:x1*scale,y0*scale:y1*scale]
271 | ###-------------------------------------------
272 |
273 | model.feed_train_data({'lq': lq, 'gt': gt})
274 | model.optimize_parameters(current_iter)
275 |
276 | iter_time = time.time() - iter_time
277 | # log
278 | if current_iter % opt['logger']['print_freq'] == 0:
279 | log_vars = {'epoch': epoch, 'iter': current_iter}
280 | log_vars.update({'lrs': model.get_current_learning_rate()})
281 | log_vars.update({'time': iter_time, 'data_time': data_time})
282 | log_vars.update(model.get_current_log())
283 | msg_logger(log_vars)
284 |
285 | # save models and training states
286 | if current_iter % opt['logger']['save_checkpoint_freq'] == 0 or (current_iter > 300000 and current_iter % 10000 == 0):
287 | logger.info('Saving models and training states.')
288 | model.save(epoch, current_iter)
289 |
290 | # validation
291 | if opt.get('val') is not None and ((current_iter %
292 | opt['val']['val_freq'] == 0) or (current_iter > 300000 and current_iter % 10000 == 0)):
293 | rgb2bgr = opt['val'].get('rgb2bgr', True)
294 | # whether to use uint8 images to compute metrics
295 | use_image = opt['val'].get('use_image', True)
296 | model.validation(val_loader, current_iter, tb_logger,
297 | opt['val']['save_img'], rgb2bgr, use_image )
298 |
299 | data_time = time.time()
300 | iter_time = time.time()
301 | train_data = prefetcher.next()
302 |
303 | epoch += 1
304 |
305 | # end of epoch
306 |
307 | consumed_time = str(
308 | datetime.timedelta(seconds=int(time.time() - start_time)))
309 | logger.info(f'End of training. Time consumed: {consumed_time}')
310 | logger.info('Save the latest model.')
311 | model.save(epoch=-1, current_iter=-1) # -1 stands for the latest
312 | if opt.get('val') is not None:
313 | model.validation(val_loader, current_iter, tb_logger,
314 | opt['val']['save_last_img'])
315 | if tb_logger:
316 | tb_logger.close()
317 |
318 | if __name__ == "__main__":
319 | main()
--------------------------------------------------------------------------------
/basicsr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .file_client import FileClient
2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP
3 | from .logger import (MessageLogger, get_env_info, get_root_logger,
4 | init_tb_logger, init_wandb_logger)
5 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
6 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt)
7 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k)
8 |
9 | __all__ = [
10 | # file_client.py
11 | 'FileClient',
12 | # img_util.py
13 | 'img2tensor',
14 | 'tensor2img',
15 | 'imfrombytes',
16 | 'imwrite',
17 | 'crop_border',
18 | # logger.py
19 | 'MessageLogger',
20 | 'init_tb_logger',
21 | 'init_wandb_logger',
22 | 'get_root_logger',
23 | 'get_env_info',
24 | # misc.py
25 | 'set_random_seed',
26 | 'get_time_str',
27 | 'mkdir_and_rename',
28 | 'make_exp_dirs',
29 | 'scandir',
30 | 'check_resume',
31 | 'sizeof_fmt',
32 | 'padding',
33 | 'padding_DP',
34 | 'imfrombytesDP',
35 | 'create_lmdb_for_reds',
36 | 'create_lmdb_for_gopro',
37 | 'create_lmdb_for_rain13k',
38 | ]
39 |
--------------------------------------------------------------------------------
/basicsr/utils/bundle_submissions.py:
--------------------------------------------------------------------------------
1 | # Author: Tobias Plötz, TU Darmstadt (tobias.ploetz@visinf.tu-darmstadt.de)
2 |
3 | # This file is part of the implementation as described in the CVPR 2017 paper:
4 | # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs.
5 | # Please see the file LICENSE.txt for the license governing this code.
6 |
7 |
8 | import numpy as np
9 | import scipy.io as sio
10 | import os
11 | import h5py
12 |
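13 | # usage sketch (hypothetical paths):
14 | #   bundle_submissions_raw('./results/dnd/raw/', 'bundled')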
13 | def bundle_submissions_raw(submission_folder,session):
14 | '''
15 | Bundles submission data for raw denoising
16 |
17 | submission_folder Folder where denoised images reside
18 |
19 | Output is written to <submission_folder>/<session>/. Please submit
20 | the content of this folder.
21 | '''
22 |
23 | out_folder = os.path.join(submission_folder, session)
24 | os.makedirs(out_folder, exist_ok=True)
28 |
29 | israw = True
30 | eval_version="1.0"
31 |
32 | for i in range(50):
33 | Idenoised = np.zeros((20,), dtype=object)
34 | for bb in range(20):
35 | filename = '%04d_%02d.mat'%(i+1,bb+1)
36 | s = sio.loadmat(os.path.join(submission_folder,filename))
37 | Idenoised_crop = s["Idenoised_crop"]
38 | Idenoised[bb] = Idenoised_crop
39 | filename = '%04d.mat'%(i+1)
40 | sio.savemat(os.path.join(out_folder, filename),
41 | {"Idenoised": Idenoised,
42 | "israw": israw,
43 | "eval_version": eval_version},
44 | )
45 |
46 | def bundle_submissions_srgb(submission_folder,session):
47 | '''
48 | Bundles submission data for sRGB denoising
49 |
50 | submission_folder Folder where denoised images reside
51 |
52 | Output is written to <submission_folder>/<session>/. Please submit
53 | the content of this folder.
54 | '''
55 | out_folder = os.path.join(submission_folder, session)
56 | os.makedirs(out_folder, exist_ok=True)
60 | israw = False
61 | eval_version="1.0"
62 |
63 | for i in range(50):
64 | Idenoised = np.zeros((20,), dtype=object)
65 | for bb in range(20):
66 | filename = '%04d_%02d.mat'%(i+1,bb+1)
67 | s = sio.loadmat(os.path.join(submission_folder,filename))
68 | Idenoised_crop = s["Idenoised_crop"]
69 | Idenoised[bb] = Idenoised_crop
70 | filename = '%04d.mat'%(i+1)
71 | sio.savemat(os.path.join(out_folder, filename),
72 | {"Idenoised": Idenoised,
73 | "israw": israw,
74 | "eval_version": eval_version},
75 | )
76 |
77 |
78 |
79 | def bundle_submissions_srgb_v1(submission_folder,session):
80 | '''
81 | Bundles submission data for sRGB denoising
82 |
83 | submission_folder Folder where denoised images reside
84 |
85 | Output is written to <submission_folder>/<session>/. Please submit
86 | the content of this folder.
87 | '''
88 | out_folder = os.path.join(submission_folder, session)
89 | os.makedirs(out_folder, exist_ok=True)
93 | israw = False
94 | eval_version="1.0"
95 |
96 | for i in range(50):
97 | Idenoised = np.zeros((20,), dtype=object)
98 | for bb in range(20):
99 | filename = '%04d_%d.mat'%(i+1,bb+1)
100 | s = sio.loadmat(os.path.join(submission_folder,filename))
101 | Idenoised_crop = s["Idenoised_crop"]
102 | Idenoised[bb] = Idenoised_crop
103 | filename = '%04d.mat'%(i+1)
104 | sio.savemat(os.path.join(out_folder, filename),
105 | {"Idenoised": Idenoised,
106 | "israw": israw,
107 | "eval_version": eval_version},
108 | )
--------------------------------------------------------------------------------
/basicsr/utils/create_lmdb.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import scipy.io as scio
4 | from os import path as osp
5 | from tqdm import tqdm
6 | from basicsr.utils import scandir
7 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs
6 |
7 | def prepare_keys(folder_path, suffix='png'):
8 | """Prepare image path list and keys for DIV2K dataset.
9 |
10 | Args:
11 | folder_path (str): Folder path.
12 |
13 | Returns:
14 | list[str]: Image path list.
15 | list[str]: Key list.
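16 |
17 | Example (hypothetical folder):
18 | >>> paths, keys = prepare_keys('./datasets/demo', suffix='png')
19 | >>> # '0001.png' yields key '0001'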
16 | """
17 | print('Reading image path list ...')
18 | img_path_list = sorted(
19 | list(scandir(folder_path, suffix=suffix, recursive=False)))
20 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)]
21 |
22 | return img_path_list, keys
23 |
24 | def create_lmdb_for_reds():
25 | folder_path = './datasets/REDS/val/sharp_300'
26 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb'
27 | img_path_list, keys = prepare_keys(folder_path, 'png')
28 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
29 | #
30 | folder_path = './datasets/REDS/val/blur_300'
31 | lmdb_path = './datasets/REDS/val/blur_300.lmdb'
32 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
33 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
34 |
35 | folder_path = './datasets/REDS/train/train_sharp'
36 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb'
37 | img_path_list, keys = prepare_keys(folder_path, 'png')
38 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
39 |
40 | folder_path = './datasets/REDS/train/train_blur_jpeg'
41 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb'
42 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
43 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
44 |
45 |
46 | def create_lmdb_for_gopro():
47 | folder_path = './datasets/GoPro/train/blur_crops'
48 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb'
49 |
50 | img_path_list, keys = prepare_keys(folder_path, 'png')
51 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
52 |
53 | folder_path = './datasets/GoPro/train/sharp_crops'
54 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb'
55 |
56 | img_path_list, keys = prepare_keys(folder_path, 'png')
57 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
58 |
59 | folder_path = './datasets/GoPro/test/target'
60 | lmdb_path = './datasets/GoPro/test/target.lmdb'
61 |
62 | img_path_list, keys = prepare_keys(folder_path, 'png')
63 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
64 |
65 | folder_path = './datasets/GoPro/test/input'
66 | lmdb_path = './datasets/GoPro/test/input.lmdb'
67 |
68 | img_path_list, keys = prepare_keys(folder_path, 'png')
69 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
70 |
71 | def create_lmdb_for_rain13k():
72 | folder_path = './datasets/Rain13k/train/input'
73 | lmdb_path = './datasets/Rain13k/train/input.lmdb'
74 |
75 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
76 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
77 |
78 | folder_path = './datasets/Rain13k/train/target'
79 | lmdb_path = './datasets/Rain13k/train/target.lmdb'
80 |
81 | img_path_list, keys = prepare_keys(folder_path, 'jpg')
82 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
83 |
84 | def create_lmdb_for_SIDD():
85 | folder_path = './datasets/SIDD/train/input_crops'
86 | lmdb_path = './datasets/SIDD/train/input_crops.lmdb'
87 |
88 | img_path_list, keys = prepare_keys(folder_path, 'PNG')
89 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
90 |
91 | folder_path = './datasets/SIDD/train/gt_crops'
92 | lmdb_path = './datasets/SIDD/train/gt_crops.lmdb'
93 |
94 | img_path_list, keys = prepare_keys(folder_path, 'PNG')
95 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
96 |
97 | #for val
98 | folder_path = './datasets/SIDD/val/input_crops'
99 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb'
100 | mat_path = './datasets/SIDD/ValidationNoisyBlocksSrgb.mat'
101 | if not osp.exists(folder_path):
102 | os.makedirs(folder_path)
103 | assert osp.exists(mat_path)
104 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb']
105 | N, B, H, W, C = data.shape
106 | data = data.reshape(N*B, H, W, C)
107 | for i in tqdm(range(N*B)):
108 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR))
109 | img_path_list, keys = prepare_keys(folder_path, 'png')
110 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
111 |
112 | folder_path = './datasets/SIDD/val/gt_crops'
113 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb'
114 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat'
115 | if not osp.exists(folder_path):
116 | os.makedirs(folder_path)
117 | assert osp.exists(mat_path)
118 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb']
119 | N, B, H, W, C = data.shape
120 | data = data.reshape(N*B, H, W, C)
121 | for i in tqdm(range(N*B)):
122 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR))
123 | img_path_list, keys = prepare_keys(folder_path, 'png')
124 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys)
125 |
--------------------------------------------------------------------------------
/basicsr/utils/dist_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501
2 | import functools
3 | import os
4 | import subprocess
5 | import torch
6 | import torch.distributed as dist
7 | import torch.multiprocessing as mp
8 |
9 |
10 | def init_dist(launcher, backend='nccl', **kwargs):
11 | if mp.get_start_method(allow_none=True) is None:
12 | mp.set_start_method('spawn')
13 | if launcher == 'pytorch':
14 | _init_dist_pytorch(backend, **kwargs)
15 | elif launcher == 'slurm':
16 | _init_dist_slurm(backend, **kwargs)
17 | else:
18 | raise ValueError(f'Invalid launcher type: {launcher}')
19 |
20 |
21 | def _init_dist_pytorch(backend, **kwargs):
22 | rank = int(os.environ['RANK'])
23 | num_gpus = torch.cuda.device_count()
24 | torch.cuda.set_device(rank % num_gpus)
25 | dist.init_process_group(backend=backend, **kwargs)
26 |
27 |
28 | def _init_dist_slurm(backend, port=None):
29 | """Initialize slurm distributed training environment.
30 |
31 | If argument ``port`` is not specified, then the master port will be system
32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
33 | environment variable, then a default port ``29500`` will be used.
34 |
35 | Args:
36 | backend (str): Backend of torch.distributed.
37 | port (int, optional): Master port. Defaults to None.
38 | """
39 | proc_id = int(os.environ['SLURM_PROCID'])
40 | ntasks = int(os.environ['SLURM_NTASKS'])
41 | node_list = os.environ['SLURM_NODELIST']
42 | num_gpus = torch.cuda.device_count()
43 | torch.cuda.set_device(proc_id % num_gpus)
44 | addr = subprocess.getoutput(
45 | f'scontrol show hostname {node_list} | head -n1')
46 | # specify master port
47 | if port is not None:
48 | os.environ['MASTER_PORT'] = str(port)
49 | elif 'MASTER_PORT' in os.environ:
50 | pass # use MASTER_PORT in the environment variable
51 | else:
52 | # 29500 is torch.distributed default port
53 | os.environ['MASTER_PORT'] = '29500'
54 | os.environ['MASTER_ADDR'] = addr
55 | os.environ['WORLD_SIZE'] = str(ntasks)
56 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus)
57 | os.environ['RANK'] = str(proc_id)
58 | dist.init_process_group(backend=backend)
59 |
60 |
61 | def get_dist_info():
62 | if dist.is_available():
63 | initialized = dist.is_initialized()
64 | else:
65 | initialized = False
66 | if initialized:
67 | rank = dist.get_rank()
68 | world_size = dist.get_world_size()
69 | else:
70 | rank = 0
71 | world_size = 1
72 | return rank, world_size
73 |
74 |
75 | def master_only(func):
76 |
77 | @functools.wraps(func)
78 | def wrapper(*args, **kwargs):
79 | rank, _ = get_dist_info()
80 | if rank == 0:
81 | return func(*args, **kwargs)
82 |
83 | return wrapper
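84 | # usage sketch: restrict a helper to the master process, e.g.
85 | #   @master_only
86 | #   def save_checkpoint(state): ...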
84 |
--------------------------------------------------------------------------------
/basicsr/utils/download_util.py:
--------------------------------------------------------------------------------
1 | import math
2 | import requests
3 | from tqdm import tqdm
4 |
5 | from .misc import sizeof_fmt
6 |
7 |
8 | def download_file_from_google_drive(file_id, save_path):
9 | """Download files from google drive.
10 |
11 | Ref:
12 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501
13 |
14 | Args:
15 | file_id (str): File id.
16 | save_path (str): Save path.
17 | """
18 |
19 | session = requests.Session()
20 | URL = 'https://docs.google.com/uc?export=download'
21 | params = {'id': file_id}
22 |
23 | response = session.get(URL, params=params, stream=True)
24 | token = get_confirm_token(response)
25 | if token:
26 | params['confirm'] = token
27 | response = session.get(URL, params=params, stream=True)
28 |
29 | # get file size
30 | response_file_size = session.get(
31 | URL, params=params, stream=True, headers={'Range': 'bytes=0-2'})
32 | if 'Content-Range' in response_file_size.headers:
33 | file_size = int(
34 | response_file_size.headers['Content-Range'].split('/')[1])
35 | else:
36 | file_size = None
37 |
38 | save_response_content(response, save_path, file_size)
39 |
40 |
41 | def get_confirm_token(response):
42 | for key, value in response.cookies.items():
43 | if key.startswith('download_warning'):
44 | return value
45 | return None
46 |
47 |
48 | def save_response_content(response,
49 | destination,
50 | file_size=None,
51 | chunk_size=32768):
52 | if file_size is not None:
53 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk')
54 |
55 | readable_file_size = sizeof_fmt(file_size)
56 | else:
57 | pbar = None
58 |
59 | with open(destination, 'wb') as f:
60 | downloaded_size = 0
61 | for chunk in response.iter_content(chunk_size):
62 | downloaded_size += chunk_size
63 | if pbar is not None:
64 | pbar.update(1)
65 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} '
66 | f'/ {readable_file_size}')
67 | if chunk: # filter out keep-alive new chunks
68 | f.write(chunk)
69 | if pbar is not None:
70 | pbar.close()
71 |
--------------------------------------------------------------------------------
/basicsr/utils/face_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import os
4 | import torch
5 | from skimage import transform as trans
6 |
7 | from basicsr.utils import imwrite
8 |
9 | try:
10 | import dlib
11 | except ImportError:
12 | print('Please install dlib before testing face restoration.'
13 | 'Reference: https://github.com/davisking/dlib')
14 |
15 |
16 | class FaceRestorationHelper(object):
17 | """Helper for the face restoration pipeline."""
18 |
19 | def __init__(self, upscale_factor, face_size=512):
20 | self.upscale_factor = upscale_factor
21 | self.face_size = (face_size, face_size)
22 |
23 | # standard 5 landmarks for FFHQ faces with 1024 x 1024
24 | self.face_template = np.array([[686.77227723, 488.62376238],
25 | [586.77227723, 493.59405941],
26 | [337.91089109, 488.38613861],
27 | [437.95049505, 493.51485149],
28 | [513.58415842, 678.5049505]])
29 | self.face_template = self.face_template / (1024 // face_size)
30 | # for estimating the 2D similarity transformation
31 | self.similarity_trans = trans.SimilarityTransform()
32 |
33 | self.all_landmarks_5 = []
34 | self.all_landmarks_68 = []
35 | self.affine_matrices = []
36 | self.inverse_affine_matrices = []
37 | self.cropped_faces = []
38 | self.restored_faces = []
39 | self.save_png = True
40 |
41 | def init_dlib(self, detection_path, landmark5_path, landmark68_path):
42 | """Initialize the dlib detectors and predictors."""
43 | self.face_detector = dlib.cnn_face_detection_model_v1(detection_path)
44 | self.shape_predictor_5 = dlib.shape_predictor(landmark5_path)
45 | self.shape_predictor_68 = dlib.shape_predictor(landmark68_path)
46 |
47 | def free_dlib_gpu_memory(self):
48 | del self.face_detector
49 | del self.shape_predictor_5
50 | del self.shape_predictor_68
51 |
52 | def read_input_image(self, img_path):
53 | # self.input_img is Numpy array, (h, w, c) with RGB order
54 | self.input_img = dlib.load_rgb_image(img_path)
55 |
56 | def detect_faces(self,
57 | img_path,
58 | upsample_num_times=1,
59 | only_keep_largest=False):
60 | """
61 | Args:
62 | img_path (str): Image path.
63 | upsample_num_times (int): Upsamples the image before running the
64 | face detector
65 |
66 | Returns:
67 | int: Number of detected faces.
68 | """
69 | self.read_input_image(img_path)
70 | det_faces = self.face_detector(self.input_img, upsample_num_times)
71 | if len(det_faces) == 0:
72 | print('No face detected. Try to increase upsample_num_times.')
73 | else:
74 | if only_keep_largest:
75 | print('Detected several faces; only keeping the largest.')
76 | face_areas = []
77 | for i in range(len(det_faces)):
78 | face_area = (det_faces[i].rect.right() -
79 | det_faces[i].rect.left()) * (
80 | det_faces[i].rect.bottom() -
81 | det_faces[i].rect.top())
82 | face_areas.append(face_area)
83 | largest_idx = face_areas.index(max(face_areas))
84 | self.det_faces = [det_faces[largest_idx]]
85 | else:
86 | self.det_faces = det_faces
87 | return len(self.det_faces)
88 |
89 | def get_face_landmarks_5(self):
90 | for face in self.det_faces:
91 | shape = self.shape_predictor_5(self.input_img, face.rect)
92 | landmark = np.array([[part.x, part.y] for part in shape.parts()])
93 | self.all_landmarks_5.append(landmark)
94 | return len(self.all_landmarks_5)
95 |
96 | def get_face_landmarks_68(self):
97 | """Get 68 densemarks for cropped images.
98 |
99 | Should only have one face at most in the cropped image.
100 | """
101 | num_detected_face = 0
102 | for idx, face in enumerate(self.cropped_faces):
103 | # face detection
104 | det_face = self.face_detector(face, 1) # TODO: can we remove it?
105 | if len(det_face) == 0:
106 | print(f'Cannot find faces in cropped image with index {idx}.')
107 | self.all_landmarks_68.append(None)
108 | else:
109 | if len(det_face) > 1:
110 | print('Detected several faces in the cropped face. Use the '
111 | 'largest one. Note that it will also cause overlap '
112 | 'during paste_faces_to_input_image.')
113 | face_areas = []
114 | for i in range(len(det_face)):
115 | face_area = (det_face[i].rect.right() -
116 | det_face[i].rect.left()) * (
117 | det_face[i].rect.bottom() -
118 | det_face[i].rect.top())
119 | face_areas.append(face_area)
120 | largest_idx = face_areas.index(max(face_areas))
121 | face_rect = det_face[largest_idx].rect
122 | else:
123 | face_rect = det_face[0].rect
124 | shape = self.shape_predictor_68(face, face_rect)
125 | landmark = np.array([[part.x, part.y]
126 | for part in shape.parts()])
127 | self.all_landmarks_68.append(landmark)
128 | num_detected_face += 1
129 |
130 | return num_detected_face
131 |
132 | def warp_crop_faces(self,
133 | save_cropped_path=None,
134 | save_inverse_affine_path=None):
135 | """Get affine matrix, warp and cropped faces.
136 |
137 | Also get inverse affine matrix for post-processing.
138 | """
139 | for idx, landmark in enumerate(self.all_landmarks_5):
140 | # use 5 landmarks to get affine matrix
141 | self.similarity_trans.estimate(landmark, self.face_template)
142 | affine_matrix = self.similarity_trans.params[0:2, :]
143 | self.affine_matrices.append(affine_matrix)
144 | # warp and crop faces
145 | cropped_face = cv2.warpAffine(self.input_img, affine_matrix,
146 | self.face_size)
147 | self.cropped_faces.append(cropped_face)
148 | # save the cropped face
149 | if save_cropped_path is not None:
150 | path, ext = os.path.splitext(save_cropped_path)
151 | if self.save_png:
152 | save_path = f'{path}_{idx:02d}.png'
153 | else:
154 | save_path = f'{path}_{idx:02d}{ext}'
155 |
156 | imwrite(
157 | cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path)
158 |
159 | # get inverse affine matrix
160 | self.similarity_trans.estimate(self.face_template,
161 | landmark * self.upscale_factor)
162 | inverse_affine = self.similarity_trans.params[0:2, :]
163 | self.inverse_affine_matrices.append(inverse_affine)
164 | # save inverse affine matrices
165 | if save_inverse_affine_path is not None:
166 | path, _ = os.path.splitext(save_inverse_affine_path)
167 | save_path = f'{path}_{idx:02d}.pth'
168 | torch.save(inverse_affine, save_path)
169 |
170 | def add_restored_face(self, face):
171 | self.restored_faces.append(face)
172 |
173 | def paste_faces_to_input_image(self, save_path):
174 | # operate in the BGR order
175 | input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR)
176 | h, w, _ = input_img.shape
177 | h_up, w_up = h * self.upscale_factor, w * self.upscale_factor
178 | # simply resize the background
179 | upsample_img = cv2.resize(input_img, (w_up, h_up))
180 | assert len(self.restored_faces) == len(self.inverse_affine_matrices), (
181 | 'length of restored_faces and affine_matrices are different.')
182 | for restored_face, inverse_affine in zip(self.restored_faces,
183 | self.inverse_affine_matrices):
184 | inv_restored = cv2.warpAffine(restored_face, inverse_affine,
185 | (w_up, h_up))
186 | mask = np.ones((*self.face_size, 3), dtype=np.float32)
187 | inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up))
188 | # remove the black borders
189 | inv_mask_erosion = cv2.erode(
190 | inv_mask,
191 | np.ones((2 * self.upscale_factor, 2 * self.upscale_factor),
192 | np.uint8))
193 | inv_restored_remove_border = inv_mask_erosion * inv_restored
194 | total_face_area = np.sum(inv_mask_erosion) // 3
195 | # compute the fusion edge based on the area of face
196 | w_edge = int(total_face_area**0.5) // 20
197 | erosion_radius = w_edge * 2
198 | inv_mask_center = cv2.erode(
199 | inv_mask_erosion,
200 | np.ones((erosion_radius, erosion_radius), np.uint8))
201 | blur_size = w_edge * 2
202 | inv_soft_mask = cv2.GaussianBlur(inv_mask_center,
203 | (blur_size + 1, blur_size + 1), 0)
204 | upsample_img = inv_soft_mask * inv_restored_remove_border + (
205 | 1 - inv_soft_mask) * upsample_img
206 | if self.save_png:
207 | save_path = save_path.replace('.jpg',
208 | '.png').replace('.jpeg', '.png')
209 | imwrite(upsample_img.astype(np.uint8), save_path)
210 |
211 | def clean_all(self):
212 | self.all_landmarks_5 = []
213 | self.all_landmarks_68 = []
214 | self.restored_faces = []
215 | self.affine_matrices = []
216 | self.cropped_faces = []
217 | self.inverse_affine_matrices = []
218 |
--------------------------------------------------------------------------------
/basicsr/utils/file_client.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501
2 | from abc import ABCMeta, abstractmethod
3 |
4 |
5 | class BaseStorageBackend(metaclass=ABCMeta):
6 | """Abstract class of storage backends.
7 |
8 | All backends need to implement two apis: ``get()`` and ``get_text()``.
9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file
10 | as texts.
11 | """
12 |
13 | @abstractmethod
14 | def get(self, filepath):
15 | pass
16 |
17 | @abstractmethod
18 | def get_text(self, filepath):
19 | pass
20 |
21 |
22 | class MemcachedBackend(BaseStorageBackend):
23 | """Memcached storage backend.
24 |
25 | Attributes:
26 | server_list_cfg (str): Config file for memcached server list.
27 | client_cfg (str): Config file for memcached client.
28 | sys_path (str | None): Additional path to be appended to `sys.path`.
29 | Default: None.
30 | """
31 |
32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None):
33 | if sys_path is not None:
34 | import sys
35 | sys.path.append(sys_path)
36 | try:
37 | import mc
38 | except ImportError:
39 | raise ImportError(
40 | 'Please install memcached to enable MemcachedBackend.')
41 |
42 | self.server_list_cfg = server_list_cfg
43 | self.client_cfg = client_cfg
44 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg,
45 | self.client_cfg)
46 | # mc.pyvector serves as a pointer to a memory cache
47 | self._mc_buffer = mc.pyvector()
48 |
49 | def get(self, filepath):
50 | filepath = str(filepath)
51 | import mc
52 | self._client.Get(filepath, self._mc_buffer)
53 | value_buf = mc.ConvertBuffer(self._mc_buffer)
54 | return value_buf
55 |
56 | def get_text(self, filepath):
57 | raise NotImplementedError
58 |
59 |
60 | class HardDiskBackend(BaseStorageBackend):
61 | """Raw hard disks storage backend."""
62 |
63 | def get(self, filepath):
64 | filepath = str(filepath)
65 | with open(filepath, 'rb') as f:
66 | value_buf = f.read()
67 | return value_buf
68 |
69 | def get_text(self, filepath):
70 | filepath = str(filepath)
71 | with open(filepath, 'r') as f:
72 | value_buf = f.read()
73 | return value_buf
74 |
75 |
76 | class LmdbBackend(BaseStorageBackend):
77 | """Lmdb storage backend.
78 |
79 | Args:
80 | db_paths (str | list[str]): Lmdb database paths.
81 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'.
82 | readonly (bool, optional): Lmdb environment parameter. If True,
83 | disallow any write operations. Default: True.
84 | lock (bool, optional): Lmdb environment parameter. If False, when
85 | concurrent access occurs, do not lock the database. Default: False.
86 | readahead (bool, optional): Lmdb environment parameter. If False,
87 | disable the OS filesystem readahead mechanism, which may improve
88 | random read performance when a database is larger than RAM.
89 | Default: False.
90 |
91 | Attributes:
92 | db_paths (list): Lmdb database path.
93 | _client (list): A list of several lmdb envs.
94 | """
95 |
96 | def __init__(self,
97 | db_paths,
98 | client_keys='default',
99 | readonly=True,
100 | lock=False,
101 | readahead=False,
102 | **kwargs):
103 | try:
104 | import lmdb
105 | except ImportError:
106 | raise ImportError('Please install lmdb to enable LmdbBackend.')
107 |
108 | if isinstance(client_keys, str):
109 | client_keys = [client_keys]
110 |
111 | if isinstance(db_paths, list):
112 | self.db_paths = [str(v) for v in db_paths]
113 | elif isinstance(db_paths, str):
114 | self.db_paths = [str(db_paths)]
115 | assert len(client_keys) == len(self.db_paths), (
116 | 'client_keys and db_paths should have the same length, '
117 | f'but received {len(client_keys)} and {len(self.db_paths)}.')
118 |
119 | self._client = {}
120 |
121 | for client, path in zip(client_keys, self.db_paths):
122 | self._client[client] = lmdb.open(
123 | path,
124 | readonly=readonly,
125 | lock=lock,
126 | readahead=readahead,
127 | map_size=8*1024*10485760,
128 | # max_readers=1,
129 | **kwargs)
130 |
131 | def get(self, filepath, client_key):
132 | """Get values according to the filepath from one lmdb named client_key.
133 |
134 | Args:
135 | filepath (str | obj:`Path`): Here, filepath is the lmdb key.
136 | client_key (str): Used for distinguishing different lmdb envs.
137 | """
138 | filepath = str(filepath)
139 | assert client_key in self._client, (f'client_key {client_key} is not '
140 | 'in lmdb clients.')
141 | client = self._client[client_key]
142 | with client.begin(write=False) as txn:
143 | value_buf = txn.get(filepath.encode('ascii'))
144 | return value_buf
145 |
146 | def get_text(self, filepath):
147 | raise NotImplementedError
148 |
149 |
150 | class FileClient(object):
151 | """A general file client to access files in different backend.
152 |
153 | The client loads a file or text from a specified backend given its
154 | path and returns it as a binary stream. It can also register other
155 | backend accessors with a given name and backend class.
156 |
157 | Attributes:
158 | backend (str): The storage backend type. Options are "disk",
159 | "memcached" and "lmdb".
160 | client (:obj:`BaseStorageBackend`): The backend object.
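161 |
162 | Example (hypothetical path):
163 | >>> client = FileClient(backend='disk')
164 | >>> img_bytes = client.get('path/to/img.png')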
161 | """
162 |
163 | _backends = {
164 | 'disk': HardDiskBackend,
165 | 'memcached': MemcachedBackend,
166 | 'lmdb': LmdbBackend,
167 | }
168 |
169 | def __init__(self, backend='disk', **kwargs):
170 | if backend not in self._backends:
171 | raise ValueError(
172 | f'Backend {backend} is not supported. Currently supported ones'
173 | f' are {list(self._backends.keys())}')
174 | self.backend = backend
175 | self.client = self._backends[backend](**kwargs)
176 |
177 | def get(self, filepath, client_key='default'):
178 | # client_key is used only for lmdb, where different fileclients have
179 | # different lmdb environments.
180 | if self.backend == 'lmdb':
181 | return self.client.get(filepath, client_key)
182 | else:
183 | return self.client.get(filepath)
184 |
185 | def get_text(self, filepath):
186 | return self.client.get_text(filepath)
187 |
--------------------------------------------------------------------------------
/basicsr/utils/flow_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
2 | import cv2
3 | import numpy as np
4 | import os
5 |
6 |
7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
8 | """Read an optical flow map.
9 |
10 | Args:
11 | flow_path (ndarray or str): Flow path.
12 | quantize (bool): whether to read quantized pair, if set to True,
13 | remaining args will be passed to :func:`dequantize_flow`.
14 | concat_axis (int): The axis that dx and dy are concatenated,
15 | can be either 0 or 1. Ignored if quantize is False.
16 |
17 | Returns:
18 | ndarray: Optical flow represented as a (h, w, 2) numpy array
19 | """
20 | if quantize:
21 | assert concat_axis in [0, 1]
22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
23 | if cat_flow.ndim != 2:
24 | raise IOError(f'{flow_path} is not a valid quantized flow file, '
25 | f'its dimension is {cat_flow.ndim}.')
26 | assert cat_flow.shape[concat_axis] % 2 == 0
27 | dx, dy = np.split(cat_flow, 2, axis=concat_axis)
28 | flow = dequantize_flow(dx, dy, *args, **kwargs)
29 | else:
30 | with open(flow_path, 'rb') as f:
31 | try:
32 | header = f.read(4).decode('utf-8')
33 | except Exception:
34 | raise IOError(f'Invalid flow file: {flow_path}')
35 | else:
36 | if header != 'PIEH':
37 | raise IOError(f'Invalid flow file: {flow_path}, '
38 | 'header does not contain PIEH')
39 |
40 | w = np.fromfile(f, np.int32, 1).squeeze()
41 | h = np.fromfile(f, np.int32, 1).squeeze()
42 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
43 |
44 | return flow.astype(np.float32)
45 |
46 |
47 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
48 | """Write optical flow to file.
49 |
50 | If the flow is not quantized, it will be saved as a .flo file losslessly,
51 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
52 | will be concatenated horizontally into a single image if quantize is True.)
53 |
54 | Args:
55 | flow (ndarray): (h, w, 2) array of optical flow.
56 | filename (str): Output filepath.
57 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg
58 | images. If set to True, remaining args will be passed to
59 | :func:`quantize_flow`.
60 | concat_axis (int): The axis that dx and dy are concatenated,
61 | can be either 0 or 1. Ignored if quantize is False.
62 | """
63 | if not quantize:
64 | with open(filename, 'wb') as f:
65 | f.write('PIEH'.encode('utf-8'))
66 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
67 | flow = flow.astype(np.float32)
68 | flow.tofile(f)
69 | f.flush()
70 | else:
71 | assert concat_axis in [0, 1]
72 | dx, dy = quantize_flow(flow, *args, **kwargs)
73 | dxdy = np.concatenate((dx, dy), axis=concat_axis)
74 | os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
75 | cv2.imwrite(filename, dxdy)
76 |
77 |
78 | def quantize_flow(flow, max_val=0.02, norm=True):
79 | """Quantize flow to [0, 255].
80 |
81 | After this step, the size of flow will be much smaller, and can be
82 | dumped as jpeg images.
83 |
84 | Args:
85 | flow (ndarray): (h, w, 2) array of optical flow.
86 | max_val (float): Maximum value of flow, values beyond
87 | [-max_val, max_val] will be truncated.
88 | norm (bool): Whether to divide flow values by image width/height.
89 |
90 | Returns:
91 | tuple[ndarray]: Quantized dx and dy.
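92 |
93 | Example (illustrative values):
94 | >>> flow = np.zeros((4, 4, 2), np.float32)
95 | >>> dx, dy = quantize_flow(flow)  # zero flow maps to level 127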
92 | """
93 | h, w, _ = flow.shape
94 | dx = flow[..., 0]
95 | dy = flow[..., 1]
96 | if norm:
97 | dx = dx / w # avoid inplace operations
98 | dy = dy / h
99 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization.
100 | flow_comps = [
101 | quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy]
102 | ]
103 | return tuple(flow_comps)
104 |
105 |
106 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True):
107 | """Recover from quantized flow.
108 |
109 | Args:
110 | dx (ndarray): Quantized dx.
111 | dy (ndarray): Quantized dy.
112 | max_val (float): Maximum value used when quantizing.
113 | denorm (bool): Whether to multiply flow values with width/height.
114 |
115 | Returns:
116 | ndarray: Dequantized flow.
117 | """
118 | assert dx.shape == dy.shape
119 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1)
120 |
121 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]]
122 |
123 | if denorm:
124 | dx *= dx.shape[1]
125 | dy *= dx.shape[0]
126 | flow = np.dstack((dx, dy))
127 | return flow
128 |
129 |
130 | def quantize(arr, min_val, max_val, levels, dtype=np.int64):
131 | """Quantize an array of (-inf, inf) to [0, levels-1].
132 |
133 | Args:
134 | arr (ndarray): Input array.
135 | min_val (scalar): Minimum value to be clipped.
136 | max_val (scalar): Maximum value to be clipped.
137 | levels (int): Quantization levels.
138 | dtype (np.type): The type of the quantized array.
139 |
140 | Returns:
141 | ndarray: Quantized array.
142 | """
143 | if not (isinstance(levels, int) and levels > 1):
144 | raise ValueError(
145 | f'levels must be a positive integer, but got {levels}')
146 | if min_val >= max_val:
147 | raise ValueError(
148 | f'min_val ({min_val}) must be smaller than max_val ({max_val})')
149 |
150 | arr = np.clip(arr, min_val, max_val) - min_val
151 | quantized_arr = np.minimum(
152 | np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1)
153 |
154 | return quantized_arr
155 |
156 |
157 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64):
158 | """Dequantize an array.
159 |
160 | Args:
161 | arr (ndarray): Input array.
162 | min_val (scalar): Minimum value to be clipped.
163 | max_val (scalar): Maximum value to be clipped.
164 | levels (int): Quantization levels.
165 | dtype (np.type): The type of the dequantized array.
166 |
167 | Returns:
168 | ndarray: Dequantized array.
169 | """
170 | if not (isinstance(levels, int) and levels > 1):
171 | raise ValueError(
172 | f'levels must be a positive integer, but got {levels}')
173 | if min_val >= max_val:
174 | raise ValueError(
175 | f'min_val ({min_val}) must be smaller than max_val ({max_val})')
176 |
177 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val -
178 | min_val) / levels + min_val
179 |
180 | return dequantized_arr
181 |
--------------------------------------------------------------------------------
/basicsr/utils/img_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import math
3 | import numpy as np
4 | import os
5 | import torch
6 | from torchvision.utils import make_grid
7 |
8 |
9 | def img2tensor(imgs, bgr2rgb=True, float32=True):
10 | """Numpy array to tensor.
11 |
12 | Args:
13 | imgs (list[ndarray] | ndarray): Input images.
14 | bgr2rgb (bool): Whether to change bgr to rgb.
15 | float32 (bool): Whether to change to float32.
16 |
17 | Returns:
18 | list[tensor] | tensor: Tensor images. If returned results only have
19 | one element, just return tensor.
20 | """
21 |
22 | def _totensor(img, bgr2rgb, float32):
23 | if img.shape[2] == 3 and bgr2rgb:
24 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
25 | img = torch.from_numpy(img.transpose(2, 0, 1))
26 | if float32:
27 | img = img.float()
28 | return img
29 |
30 | if isinstance(imgs, list):
31 | return [_totensor(img, bgr2rgb, float32) for img in imgs]
32 | else:
33 | return _totensor(imgs, bgr2rgb, float32)
34 |
35 |
36 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
37 | """Convert torch Tensors into image numpy arrays.
38 |
39 | After clamping to [min, max], values will be normalized to [0, 1].
40 |
41 | Args:
42 | tensor (Tensor or list[Tensor]): Accept shapes:
43 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
44 | 2) 3D Tensor of shape (3/1 x H x W);
45 | 3) 2D Tensor of shape (H x W).
46 | Tensor channel should be in RGB order.
47 | rgb2bgr (bool): Whether to change rgb to bgr.
48 | out_type (numpy type): output types. If ``np.uint8``, transform outputs
49 | to uint8 type with range [0, 255]; otherwise, float type with
50 | range [0, 1]. Default: ``np.uint8``.
51 | min_max (tuple[int]): min and max values for clamp.
52 |
53 | Returns:
54 | (ndarray or list[ndarray]): 3D ndarray of shape (H x W x C) OR 2D
55 | ndarray of shape (H x W). The channel order is BGR.
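56 |
57 | Example:
58 | >>> t = torch.rand(1, 3, 8, 8)  # 4D mini-batch with B=1
59 | >>> img = tensor2img(t)  # (8, 8, 3) uint8 array, BGR order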
56 | """
57 | if not (torch.is_tensor(tensor) or
58 | (isinstance(tensor, list)
59 | and all(torch.is_tensor(t) for t in tensor))):
60 | raise TypeError(
61 | f'tensor or list of tensors expected, got {type(tensor)}')
62 |
63 | if torch.is_tensor(tensor):
64 | tensor = [tensor]
65 | result = []
66 | for _tensor in tensor:
67 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
68 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
69 |
70 | n_dim = _tensor.dim()
71 | if n_dim == 4:
72 | img_np = make_grid(
73 | _tensor, nrow=int(math.sqrt(_tensor.size(0))),
74 | normalize=False).numpy()
75 | img_np = img_np.transpose(1, 2, 0)
76 | if rgb2bgr:
77 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
78 | elif n_dim == 3:
79 | img_np = _tensor.numpy()
80 | img_np = img_np.transpose(1, 2, 0)
81 | if img_np.shape[2] == 1: # gray image
82 | img_np = np.squeeze(img_np, axis=2)
83 | else:
84 | if rgb2bgr:
85 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
86 | elif n_dim == 2:
87 | img_np = _tensor.numpy()
88 | else:
89 | raise TypeError('Only support 4D, 3D or 2D tensor. '
90 | f'But received with dimension: {n_dim}')
91 | if out_type == np.uint8:
92 | # Unlike MATLAB, numpy.uint8() WILL NOT round by default.
93 | img_np = (img_np * 255.0).round()
94 | img_np = img_np.astype(out_type)
95 | result.append(img_np)
96 | if len(result) == 1:
97 | result = result[0]
98 | return result
99 |
100 |
101 | def imfrombytes(content, flag='color', float32=False):
102 | """Read an image from bytes.
103 |
104 | Args:
105 | content (bytes): Image bytes got from files or other streams.
106 | flag (str): Flags specifying the color type of a loaded image,
107 | candidates are `color`, `grayscale` and `unchanged`.
108 | float32 (bool): Whether to change to float32., If True, will also norm
109 | to [0, 1]. Default: False.
110 |
111 | Returns:
112 | ndarray: Loaded image array.
113 | """
114 | img_np = np.frombuffer(content, np.uint8)
115 | imread_flags = {
116 | 'color': cv2.IMREAD_COLOR,
117 | 'grayscale': cv2.IMREAD_GRAYSCALE,
118 | 'unchanged': cv2.IMREAD_UNCHANGED
119 | }
120 | if img_np is None:
121 | raise ValueError('Image buffer is empty.')
122 | img = cv2.imdecode(img_np, imread_flags[flag])
123 | if img is None:
124 | raise ValueError(f'Failed to decode image with flag: {flag}')
125 | if float32:
126 | img = img.astype(np.float32) / 255.
127 | return img
128 |
129 | def imfrombytesDP(content, flag='color', float32=False):
130 | """Read an image from bytes.
131 |
132 | Args:
133 | content (bytes): Image bytes got from files or other streams.
134 | flag (str): Flags specifying the color type of a loaded image,
135 | candidates are `color`, `grayscale` and `unchanged`.
136 | float32 (bool): Whether to change to float32., If True, will also norm
137 | to [0, 1]. Default: False.
138 |
139 | Returns:
140 | ndarray: Loaded image array.
141 | """
142 | img_np = np.frombuffer(content, np.uint8)
143 | if img_np is None:
144 | raise Exception('None .. !!!')
145 | img = cv2.imdecode(img_np, cv2.IMREAD_UNCHANGED)
146 | if float32:
147 | img = img.astype(np.float32) / 65535.
148 | return img
149 |
150 | def padding(img_lq, img_gt, gt_size):
151 | h, w, _ = img_lq.shape
152 |
153 | h_pad = max(0, gt_size - h)
154 | w_pad = max(0, gt_size - w)
155 |
156 | if h_pad == 0 and w_pad == 0:
157 | return img_lq, img_gt
158 |
159 | img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
160 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
161 | # print('img_lq', img_lq.shape, img_gt.shape)
162 | if img_lq.ndim == 2:
163 | img_lq = np.expand_dims(img_lq, axis=2)
164 | if img_gt.ndim == 2:
165 | img_gt = np.expand_dims(img_gt, axis=2)
166 | return img_lq, img_gt
167 |
168 | def padding_DP(img_lqL, img_lqR, img_gt, gt_size):
169 | h, w, _ = img_gt.shape
170 |
171 | h_pad = max(0, gt_size - h)
172 | w_pad = max(0, gt_size - w)
173 |
174 | if h_pad == 0 and w_pad == 0:
175 | return img_lqL, img_lqR, img_gt
176 |
177 | img_lqL = cv2.copyMakeBorder(img_lqL, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
178 | img_lqR = cv2.copyMakeBorder(img_lqR, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
179 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT)
180 | # print('img_lq', img_lq.shape, img_gt.shape)
181 | return img_lqL, img_lqR, img_gt
182 |
183 | def imwrite(img, file_path, params=None, auto_mkdir=True):
184 | """Write image to file.
185 |
186 | Args:
187 | img (ndarray): Image array to be written.
188 | file_path (str): Image file path.
189 | params (None or list): Same as opencv's :func:`imwrite` interface.
190 | auto_mkdir (bool): If the parent folder of `file_path` does not exist,
191 | whether to create it automatically.
192 |
193 | Returns:
194 | bool: Successful or not.
195 | """
196 | if auto_mkdir:
197 | dir_name = os.path.abspath(os.path.dirname(file_path))
198 | os.makedirs(dir_name, exist_ok=True)
199 | return cv2.imwrite(file_path, img, params)
200 |
201 |
202 | def crop_border(imgs, crop_border):
203 | """Crop borders of images.
204 |
205 | Args:
206 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c).
207 | crop_border (int): Crop border for each end of height and width.
208 |
209 | Returns:
210 | list[ndarray]: Cropped images.
211 | """
212 | if crop_border == 0:
213 | return imgs
214 | else:
215 | if isinstance(imgs, list):
216 | return [
217 | v[crop_border:-crop_border, crop_border:-crop_border, ...]
218 | for v in imgs
219 | ]
220 | else:
221 | return imgs[crop_border:-crop_border, crop_border:-crop_border,
222 | ...]
223 |
--------------------------------------------------------------------------------
/basicsr/utils/lmdb_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import lmdb
3 | import sys
4 | from multiprocessing import Pool
5 | from os import path as osp
6 | from tqdm import tqdm
7 |
8 |
9 | def make_lmdb_from_imgs(data_path,
10 | lmdb_path,
11 | img_path_list,
12 | keys,
13 | batch=5000,
14 | compress_level=1,
15 | multiprocessing_read=False,
16 | n_thread=40,
17 | map_size=None):
18 | """Make lmdb from images.
19 |
20 | Contents of lmdb. The file structure is:
21 | example.lmdb
22 | ├── data.mdb
23 | ├── lock.mdb
24 | ├── meta_info.txt
25 |
26 | The data.mdb and lock.mdb are standard lmdb files and you can refer to
27 | https://lmdb.readthedocs.io/en/release/ for more details.
28 |
29 | The meta_info.txt is a specified txt file to record the meta information
30 | of our datasets. It will be automatically created when preparing
31 | datasets by our provided dataset tools.
32 | Each line in the txt file records 1)image name (with extension),
33 | 2)image shape, and 3)compression level, separated by a white space.
34 |
35 | For example, the meta information could be:
36 | `000_00000000.png (720,1280,3) 1`, which means:
37 | 1) image name (with extension): 000_00000000.png;
38 | 2) image shape: (720,1280,3);
39 | 3) compression level: 1
40 |
41 | We use the image name without extension as the lmdb key.
42 |
43 | If `multiprocessing_read` is True, it will read all the images to memory
44 | using multiprocessing. Thus, your server needs to have enough memory.
45 |
46 | Args:
47 | data_path (str): Data path for reading images.
48 | lmdb_path (str): Lmdb save path.
49 |         img_path_list (list[str]): Image path list.
50 |         keys (list[str]): Used as lmdb keys.
51 | batch (int): After processing batch images, lmdb commits.
52 | Default: 5000.
53 | compress_level (int): Compress level when encoding images. Default: 1.
54 | multiprocessing_read (bool): Whether use multiprocessing to read all
55 | the images to memory. Default: False.
56 |         n_thread (int): Thread number for multiprocessing read. Default: 40.
57 | map_size (int | None): Map size for lmdb env. If None, use the
58 | estimated size from images. Default: None
59 | """
60 |
61 | assert len(img_path_list) == len(keys), (
62 | 'img_path_list and keys should have the same length, '
63 | f'but got {len(img_path_list)} and {len(keys)}')
64 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
65 |     print(f'Total images: {len(img_path_list)}')
66 | if not lmdb_path.endswith('.lmdb'):
67 | raise ValueError("lmdb_path must end with '.lmdb'.")
68 | if osp.exists(lmdb_path):
69 | print(f'Folder {lmdb_path} already exists. Exit.')
70 | sys.exit(1)
71 |
72 | if multiprocessing_read:
73 | # read all the images to memory (multiprocessing)
74 | dataset = {} # use dict to keep the order for multiprocessing
75 | shapes = {}
76 | print(f'Read images with multiprocessing, #thread: {n_thread} ...')
77 | pbar = tqdm(total=len(img_path_list), unit='image')
78 |
79 | def callback(arg):
80 | """get the image data and update pbar."""
81 | key, dataset[key], shapes[key] = arg
82 | pbar.update(1)
83 | pbar.set_description(f'Read {key}')
84 |
85 | pool = Pool(n_thread)
86 | for path, key in zip(img_path_list, keys):
87 | pool.apply_async(
88 | read_img_worker,
89 | args=(osp.join(data_path, path), key, compress_level),
90 | callback=callback)
91 | pool.close()
92 | pool.join()
93 | pbar.close()
94 | print(f'Finish reading {len(img_path_list)} images.')
95 |
96 | # create lmdb environment
97 | if map_size is None:
98 | # obtain data size for one image
99 | img = cv2.imread(
100 | osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED)
101 | _, img_byte = cv2.imencode(
102 | '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
103 | data_size_per_img = img_byte.nbytes
104 | print('Data size per image is: ', data_size_per_img)
105 | data_size = data_size_per_img * len(img_path_list)
106 | map_size = data_size * 10
107 |
108 | env = lmdb.open(lmdb_path, map_size=map_size)
109 |
110 | # write data to lmdb
111 | pbar = tqdm(total=len(img_path_list), unit='chunk')
112 | txn = env.begin(write=True)
113 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
114 | for idx, (path, key) in enumerate(zip(img_path_list, keys)):
115 | pbar.update(1)
116 | pbar.set_description(f'Write {key}')
117 | key_byte = key.encode('ascii')
118 | if multiprocessing_read:
119 | img_byte = dataset[key]
120 | h, w, c = shapes[key]
121 | else:
122 | _, img_byte, img_shape = read_img_worker(
123 | osp.join(data_path, path), key, compress_level)
124 | h, w, c = img_shape
125 |
126 | txn.put(key_byte, img_byte)
127 | # write meta information
128 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n')
129 | if idx % batch == 0:
130 | txn.commit()
131 | txn = env.begin(write=True)
132 | pbar.close()
133 | txn.commit()
134 | env.close()
135 | txt_file.close()
136 | print('\nFinish writing lmdb.')
137 |
138 |
139 | def read_img_worker(path, key, compress_level):
140 | """Read image worker.
141 |
142 | Args:
143 | path (str): Image path.
144 | key (str): Image key.
145 | compress_level (int): Compress level when encoding images.
146 |
147 | Returns:
148 | str: Image key.
149 |         bytes: Encoded image bytes.
150 | tuple[int]: Image shape.
151 | """
152 |
153 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED)
154 | if img.ndim == 2:
155 | h, w = img.shape
156 | c = 1
157 | else:
158 | h, w, c = img.shape
159 | _, img_byte = cv2.imencode('.png', img,
160 | [cv2.IMWRITE_PNG_COMPRESSION, compress_level])
161 | return (key, img_byte, (h, w, c))
162 |
163 |
164 | class LmdbMaker():
165 | """LMDB Maker.
166 |
167 | Args:
168 | lmdb_path (str): Lmdb save path.
169 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB.
170 | batch (int): After processing batch images, lmdb commits.
171 | Default: 5000.
172 | compress_level (int): Compress level when encoding images. Default: 1.
173 | """
174 |
175 | def __init__(self,
176 | lmdb_path,
177 | map_size=1024**4,
178 | batch=5000,
179 | compress_level=1):
180 | if not lmdb_path.endswith('.lmdb'):
181 | raise ValueError("lmdb_path must end with '.lmdb'.")
182 | if osp.exists(lmdb_path):
183 | print(f'Folder {lmdb_path} already exists. Exit.')
184 | sys.exit(1)
185 |
186 | self.lmdb_path = lmdb_path
187 | self.batch = batch
188 | self.compress_level = compress_level
189 | self.env = lmdb.open(lmdb_path, map_size=map_size)
190 | self.txn = self.env.begin(write=True)
191 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w')
192 | self.counter = 0
193 |
194 | def put(self, img_byte, key, img_shape):
195 | self.counter += 1
196 | key_byte = key.encode('ascii')
197 | self.txn.put(key_byte, img_byte)
198 | # write meta information
199 | h, w, c = img_shape
200 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n')
201 | if self.counter % self.batch == 0:
202 | self.txn.commit()
203 | self.txn = self.env.begin(write=True)
204 |
205 | def close(self):
206 | self.txn.commit()
207 | self.env.close()
208 | self.txt_file.close()
209 |
--------------------------------------------------------------------------------
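
A sketch of building an LMDB with make_lmdb_from_imgs above, assuming a
hypothetical folder datasets/demo of PNG images; keys are the image names
without extension, matching the meta_info.txt convention described in the
docstring.

from os import path as osp

from basicsr.utils.lmdb_util import make_lmdb_from_imgs
from basicsr.utils.misc import scandir

data_path = 'datasets/demo'  # hypothetical image folder
img_path_list = sorted(scandir(data_path, suffix='.png', recursive=True))
keys = [osp.splitext(p)[0] for p in img_path_list]

make_lmdb_from_imgs(
    data_path,
    'datasets/demo.lmdb',  # the target path must end with '.lmdb'
    img_path_list,
    keys,
    batch=5000,
    compress_level=1,
    multiprocessing_read=False)

When the image list is not known up front, LmdbMaker above provides the same
put/commit behavior incrementally through put() and close().
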
/basicsr/utils/logger.py:
--------------------------------------------------------------------------------
1 | import datetime
2 | import logging
3 | import time
4 |
5 | from .dist_util import get_dist_info, master_only
6 |
7 | initialized_logger = {}
8 |
9 |
10 | class MessageLogger():
11 | """Message logger for printing.
12 |
13 | Args:
14 | opt (dict): Config. It contains the following keys:
15 | name (str): Exp name.
16 |             logger (dict): Contains 'print_freq' (int) for logger interval.
17 | train (dict): Contains 'total_iter' (int) for total iters.
18 | use_tb_logger (bool): Use tensorboard logger.
19 | start_iter (int): Start iter. Default: 1.
20 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None.
21 | """
22 |
23 | def __init__(self, opt, start_iter=1, tb_logger=None):
24 | self.exp_name = opt['name']
25 | self.interval = opt['logger']['print_freq']
26 | self.start_iter = start_iter
27 | self.max_iters = opt['train']['total_iter']
28 | self.use_tb_logger = opt['logger']['use_tb_logger']
29 | self.tb_logger = tb_logger
30 | self.start_time = time.time()
31 | self.logger = get_root_logger()
32 |
33 | @master_only
34 | def __call__(self, log_vars):
35 | """Format logging message.
36 |
37 | Args:
38 | log_vars (dict): It contains the following keys:
39 | epoch (int): Epoch number.
40 | iter (int): Current iter.
41 | lrs (list): List for learning rates.
42 |
43 | time (float): Iter time.
44 | data_time (float): Data time for each iter.
45 | """
46 | # epoch, iter, learning rates
47 | epoch = log_vars.pop('epoch')
48 | current_iter = log_vars.pop('iter')
49 | lrs = log_vars.pop('lrs')
50 |
51 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(')
52 | for v in lrs:
53 | message += f'{v:.3e},'
54 | message += ')] '
55 |
56 | # time and estimated time
57 | if 'time' in log_vars.keys():
58 | iter_time = log_vars.pop('time')
59 | data_time = log_vars.pop('data_time')
60 |
61 | total_time = time.time() - self.start_time
62 | time_sec_avg = total_time / (current_iter - self.start_iter + 1)
63 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1)
64 | eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
65 | message += f'[eta: {eta_str}, '
66 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] '
67 |
68 | # other items, especially losses
69 | for k, v in log_vars.items():
70 | message += f'{k}: {v:.4e} '
71 | # tensorboard logger
72 | if self.use_tb_logger and 'debug' not in self.exp_name:
73 | if k.startswith('l_'):
74 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter)
75 | else:
76 | self.tb_logger.add_scalar(k, v, current_iter)
77 | self.logger.info(message)
78 |
79 |
80 | @master_only
81 | def init_tb_logger(log_dir):
82 | from torch.utils.tensorboard import SummaryWriter
83 | tb_logger = SummaryWriter(log_dir=log_dir)
84 | return tb_logger
85 |
86 |
87 | @master_only
88 | def init_wandb_logger(opt):
89 |     """We currently use wandb only to sync the tensorboard log."""
90 | import wandb
91 | logger = logging.getLogger('basicsr')
92 |
93 | project = opt['logger']['wandb']['project']
94 | resume_id = opt['logger']['wandb'].get('resume_id')
95 | if resume_id:
96 | wandb_id = resume_id
97 | resume = 'allow'
98 | logger.warning(f'Resume wandb logger with id={wandb_id}.')
99 | else:
100 | wandb_id = wandb.util.generate_id()
101 | resume = 'never'
102 |
103 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True)
104 |
105 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.')
106 |
107 |
108 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None):
109 | """Get the root logger.
110 |
111 | The logger will be initialized if it has not been initialized. By default a
112 | StreamHandler will be added. If `log_file` is specified, a FileHandler will
113 | also be added.
114 |
115 | Args:
116 |         logger_name (str): Root logger name. Default: 'basicsr'.
117 |         log_level (int): The root logger level. Note that only the process of
118 |             rank 0 is affected, while other processes will set the level to
119 |             "Error" and be silent most of the time.
120 |         log_file (str | None): The log filename. If specified, a FileHandler
121 |             will be added to the root logger.
122 |
123 | Returns:
124 | logging.Logger: The root logger.
125 | """
126 | logger = logging.getLogger(logger_name)
127 | # if the logger has been initialized, just return it
128 | if logger_name in initialized_logger:
129 | return logger
130 |
131 | format_str = '%(asctime)s %(levelname)s: %(message)s'
132 | stream_handler = logging.StreamHandler()
133 | stream_handler.setFormatter(logging.Formatter(format_str))
134 | logger.addHandler(stream_handler)
135 | logger.propagate = False
136 | rank, _ = get_dist_info()
137 | if rank != 0:
138 | logger.setLevel('ERROR')
139 | elif log_file is not None:
140 | logger.setLevel(log_level)
141 | # add file handler
142 | file_handler = logging.FileHandler(log_file, 'w')
143 | file_handler.setFormatter(logging.Formatter(format_str))
144 | file_handler.setLevel(log_level)
145 | logger.addHandler(file_handler)
146 | initialized_logger[logger_name] = True
147 | return logger
148 |
149 |
150 | def get_env_info():
151 | """Get environment information.
152 |
153 | Currently, only log the software version.
154 | """
155 | import torch
156 | import torchvision
157 |
158 | from basicsr.version import __version__
159 | msg = r"""
160 | ____ _ _____ ____
161 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \
162 | / __ |/ __ `// ___// // ___/\__ \ / /_/ /
163 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/
164 | /_____/ \__,_//____//_/ \___//____//_/ |_|
165 | ______ __ __ __ __
166 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / /
167 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / /
168 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
169 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
170 | """
171 | msg += ('\nVersion Information: '
172 | f'\n\tBasicSR: {__version__}'
173 | f'\n\tPyTorch: {torch.__version__}'
174 | f'\n\tTorchVision: {torchvision.__version__}')
175 | return msg
--------------------------------------------------------------------------------
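
A minimal sketch wiring the logging helpers above together; the opt dict is a
hypothetical stand-in for a parsed YAML config.

import logging

from basicsr.utils.logger import MessageLogger, get_root_logger, init_tb_logger

opt = {
    'name': 'demo_experiment',
    'logger': {'print_freq': 1000, 'use_tb_logger': True},
    'train': {'total_iter': 450000},
}

logger = get_root_logger(log_level=logging.INFO, log_file='train.log')
tb_logger = init_tb_logger(log_dir='tb_logger/demo_experiment')
msg_logger = MessageLogger(opt, start_iter=1, tb_logger=tb_logger)

# Called every print_freq iterations in the training loop; epoch/iter/lrs are
# popped first, and the remaining items are printed and sent to tensorboard.
msg_logger({'epoch': 1, 'iter': 1000, 'lrs': [3e-4],
            'time': 0.25, 'data_time': 0.01, 'l_pix': 0.0123})
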
/basicsr/utils/misc.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | import random
4 | import time
5 | import torch
6 | from os import path as osp
7 |
8 | from .dist_util import master_only
9 | from .logger import get_root_logger
10 |
11 |
12 | def set_random_seed(seed):
13 | """Set random seeds."""
14 | random.seed(seed)
15 | np.random.seed(seed)
16 | torch.manual_seed(seed)
17 | torch.cuda.manual_seed(seed)
18 | torch.cuda.manual_seed_all(seed)
19 |
20 |
21 | def get_time_str():
22 | return time.strftime('%Y%m%d_%H%M%S', time.localtime())
23 |
24 |
25 | def mkdir_and_rename(path):
26 |     """Make directories. If the path exists, rename it with a timestamp and create a new one.
27 |
28 | Args:
29 | path (str): Folder path.
30 | """
31 | if osp.exists(path):
32 | new_name = path + '_archived_' + get_time_str()
33 | print(f'Path already exists. Rename it to {new_name}', flush=True)
34 | os.rename(path, new_name)
35 | os.makedirs(path, exist_ok=True)
36 |
37 |
38 | @master_only
39 | def make_exp_dirs(opt):
40 | """Make dirs for experiments."""
41 | path_opt = opt['path'].copy()
42 | if opt['is_train']:
43 | mkdir_and_rename(path_opt.pop('experiments_root'))
44 | else:
45 | mkdir_and_rename(path_opt.pop('results_root'))
46 | for key, path in path_opt.items():
47 | if ('strict_load' not in key) and ('pretrain_network'
48 | not in key) and ('resume'
49 | not in key):
50 | os.makedirs(path, exist_ok=True)
51 |
52 |
53 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
54 |     """Scan a directory to find files of interest.
55 |
56 | Args:
57 | dir_path (str): Path of the directory.
58 | suffix (str | tuple(str), optional): File suffix that we are
59 | interested in. Default: None.
60 | recursive (bool, optional): If set to True, recursively scan the
61 | directory. Default: False.
62 | full_path (bool, optional): If set to True, include the dir_path.
63 | Default: False.
64 |
65 | Returns:
66 |         A generator over the files of interest, with relative paths.
67 | """
68 |
69 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
70 | raise TypeError('"suffix" must be a string or tuple of strings')
71 |
72 | root = dir_path
73 |
74 | def _scandir(dir_path, suffix, recursive):
75 | for entry in os.scandir(dir_path):
76 | if not entry.name.startswith('.') and entry.is_file():
77 | if full_path:
78 | return_path = entry.path
79 | else:
80 | return_path = osp.relpath(entry.path, root)
81 |
82 | if suffix is None:
83 | yield return_path
84 | elif return_path.endswith(suffix):
85 | yield return_path
86 | else:
87 | if recursive:
88 | yield from _scandir(
89 | entry.path, suffix=suffix, recursive=recursive)
90 | else:
91 | continue
92 |
93 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
94 |
95 | def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False):
96 |     """Scan a directory to find files of interest.
97 |
98 | Args:
99 | dir_path (str): Path of the directory.
100 | keywords (str | tuple(str), optional): File keywords that we are
101 | interested in. Default: None.
102 | recursive (bool, optional): If set to True, recursively scan the
103 | directory. Default: False.
104 | full_path (bool, optional): If set to True, include the dir_path.
105 | Default: False.
106 |
107 | Returns:
108 |         A generator over the files of interest, with relative paths.
109 | """
110 |
111 | if (keywords is not None) and not isinstance(keywords, (str, tuple)):
112 | raise TypeError('"keywords" must be a string or tuple of strings')
113 |
114 | root = dir_path
115 |
116 | def _scandir(dir_path, keywords, recursive):
117 | for entry in os.scandir(dir_path):
118 | if not entry.name.startswith('.') and entry.is_file():
119 | if full_path:
120 | return_path = entry.path
121 | else:
122 | return_path = osp.relpath(entry.path, root)
123 |
124 | if keywords is None:
125 | yield return_path
126 |             elif return_path.find(keywords) != -1:  # -1 means not found
127 | yield return_path
128 | else:
129 | if recursive:
130 | yield from _scandir(
131 | entry.path, keywords=keywords, recursive=recursive)
132 | else:
133 | continue
134 |
135 | return _scandir(dir_path, keywords=keywords, recursive=recursive)
136 |
137 | def check_resume(opt, resume_iter):
138 | """Check resume states and pretrain_network paths.
139 |
140 | Args:
141 | opt (dict): Options.
142 | resume_iter (int): Resume iteration.
143 | """
144 | logger = get_root_logger()
145 | if opt['path']['resume_state']:
146 | # get all the networks
147 | networks = [key for key in opt.keys() if key.startswith('network_')]
148 | flag_pretrain = False
149 | for network in networks:
150 | if opt['path'].get(f'pretrain_{network}') is not None:
151 | flag_pretrain = True
152 | if flag_pretrain:
153 | logger.warning(
154 | 'pretrain_network path will be ignored during resuming.')
155 | # set pretrained model paths
156 | for network in networks:
157 | name = f'pretrain_{network}'
158 | basename = network.replace('network_', '')
159 | if opt['path'].get('ignore_resume_networks') is None or (
160 | basename not in opt['path']['ignore_resume_networks']):
161 | opt['path'][name] = osp.join(
162 | opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
163 | logger.info(f"Set {name} to {opt['path'][name]}")
164 |
165 |
166 | def sizeof_fmt(size, suffix='B'):
167 | """Get human readable file size.
168 |
169 | Args:
170 | size (int): File size.
171 | suffix (str): Suffix. Default: 'B'.
172 |
173 | Return:
174 |         str: Formatted file size.
175 | """
176 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
177 | if abs(size) < 1024.0:
178 | return f'{size:3.1f} {unit}{suffix}'
179 | size /= 1024.0
180 | return f'{size:3.1f} Y{suffix}'
181 |
--------------------------------------------------------------------------------
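
A short sketch of the directory and formatting helpers above, assuming a
hypothetical datasets/demo folder.

from basicsr.utils.misc import scandir, sizeof_fmt

# scandir() returns a lazy generator of paths relative to the given root.
for rel_path in scandir('datasets/demo', suffix='.png', recursive=True):
    print(rel_path)

print(sizeof_fmt(123456789))  # -> '117.7 MB'
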
/basicsr/utils/options.py:
--------------------------------------------------------------------------------
1 | import yaml
2 | from collections import OrderedDict
3 | from os import path as osp
4 |
5 |
6 | def ordered_yaml():
7 | """Support OrderedDict for yaml.
8 |
9 | Returns:
10 | yaml Loader and Dumper.
11 | """
12 | try:
13 | from yaml import CDumper as Dumper
14 | from yaml import CLoader as Loader
15 | except ImportError:
16 | from yaml import Dumper, Loader
17 |
18 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG
19 |
20 | def dict_representer(dumper, data):
21 | return dumper.represent_dict(data.items())
22 |
23 | def dict_constructor(loader, node):
24 | return OrderedDict(loader.construct_pairs(node))
25 |
26 | Dumper.add_representer(OrderedDict, dict_representer)
27 | Loader.add_constructor(_mapping_tag, dict_constructor)
28 | return Loader, Dumper
29 |
30 |
31 | def parse(opt_path, is_train=True):
32 | """Parse option file.
33 |
34 | Args:
35 | opt_path (str): Option file path.
36 |         is_train (bool): Whether in training mode. Default: True.
37 |
38 | Returns:
39 | (dict): Options.
40 | """
41 | with open(opt_path, mode='r') as f:
42 | Loader, _ = ordered_yaml()
43 | opt = yaml.load(f, Loader=Loader)
44 |
45 | opt['is_train'] = is_train
46 |
47 | # datasets
48 | for phase, dataset in opt['datasets'].items():
49 | # for several datasets, e.g., test_1, test_2
50 | phase = phase.split('_')[0]
51 | dataset['phase'] = phase
52 | if 'scale' in opt:
53 | dataset['scale'] = opt['scale']
54 | if dataset.get('dataroot_gt') is not None:
55 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt'])
56 | if dataset.get('dataroot_lq') is not None:
57 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq'])
58 |
59 | # paths
60 | for key, val in opt['path'].items():
61 | if (val is not None) and ('resume_state' in key
62 | or 'pretrain_network' in key):
63 | opt['path'][key] = osp.expanduser(val)
64 | opt['path']['root'] = osp.abspath(
65 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir))
66 | if is_train:
67 | experiments_root = osp.join(opt['path']['root'], 'experiments',
68 | opt['name'])
69 | opt['path']['experiments_root'] = experiments_root
70 | opt['path']['models'] = osp.join(experiments_root, 'models')
71 | opt['path']['training_states'] = osp.join(experiments_root,
72 | 'training_states')
73 | opt['path']['log'] = experiments_root
74 | opt['path']['visualization'] = osp.join(experiments_root,
75 | 'visualization')
76 |
77 | # change some options for debug mode
78 | if 'debug' in opt['name']:
79 | if 'val' in opt:
80 | opt['val']['val_freq'] = 8
81 | opt['logger']['print_freq'] = 1
82 | opt['logger']['save_checkpoint_freq'] = 8
83 | else: # test
84 | results_root = osp.join(opt['path']['root'], 'results', opt['name'])
85 | opt['path']['results_root'] = results_root
86 | opt['path']['log'] = results_root
87 | opt['path']['visualization'] = osp.join(results_root, 'visualization')
88 |
89 | return opt
90 |
91 |
92 | def dict2str(opt, indent_level=1):
93 | """dict to string for printing options.
94 |
95 | Args:
96 | opt (dict): Option dict.
97 | indent_level (int): Indent level. Default: 1.
98 |
99 | Return:
100 | (str): Option string for printing.
101 | """
102 | msg = '\n'
103 | for k, v in opt.items():
104 | if isinstance(v, dict):
105 | msg += ' ' * (indent_level * 2) + k + ':['
106 | msg += dict2str(v, indent_level + 1)
107 | msg += ' ' * (indent_level * 2) + ']\n'
108 | else:
109 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n'
110 | return msg
111 |
--------------------------------------------------------------------------------
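
A sketch of loading one of the option files below with parse(), assuming it is
run from the repository root: ordered_yaml() preserves the key order of the
YAML file, and parse() fills in derived paths such as experiments_root.

from basicsr.utils.options import dict2str, parse

opt = parse('options/train_JARNet_LAP.yml', is_train=True)
print(opt['path']['experiments_root'])  # <repo root>/experiments/<name>
print(dict2str(opt))
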
/options/test_JARNet_LAP.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: test_JARNet_LAP_20240114
3 | model_type: ImageMlossFlowModel
4 | scale: 1
5 | num_gpu: 1 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | test:
11 | name: ValSet
12 | type: LAPDataSet
13 | dataroot_gt: ~ # Please fill out the path to the LAP evaluation dataset.
14 | io_backend:
15 | type: disk
16 |
17 | network_g:
18 | type: JARNet
19 | img_channel: 1
20 | width: 32
21 | middle_blk_num: 4
22 | enc_blk_nums: [4, 4, 4, 4]
23 | dec_blk_nums: [4, 4, 4, 4]
24 | window_size_e_fft: [64, -1, -1, -1]
25 |
26 | # path
27 | path:
28 | pretrain_network_g: ~ # Please fill out the path to the pre-trained model.
29 | strict_load_g: true
30 | resume_state: ~
31 |
32 | # validation settings
33 | val:
34 | save_img: false
35 |   use_image: true # PSNR/SSIM differ slightly between true and false; use true for PSNR and SSIM, and false for gmsd (otherwise an error occurs)
36 |
37 | metrics:
38 | psnr:
39 | type: calculate_psnr
40 | crop_border: 0
41 | test_y_channel: false
42 | ssim:
43 | type: calculate_ssim
44 | crop_border: 0
45 | test_y_channel: false
46 | # gmsd:
47 | # type: calculate_gmsd
48 |
49 | # dist training settings
50 | dist_params:
51 | backend: nccl
52 | port: 29500
53 |
--------------------------------------------------------------------------------
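
This option file is consumed by the test entry point (basicsr/test.py).
Assuming the usual BasicSR-style command line (an assumption; check the
repository README for the exact invocation), evaluation would be launched as:

python basicsr/test.py -opt options/test_JARNet_LAP.yml
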
/options/test_NAFNet_LAP.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: test_NAFNet_LAP_20240114
3 | model_type: ImageMlossModel
4 | scale: 1
5 | num_gpu: 1 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | test:
11 | name: ValSet
12 | type: LAPDataSetNoWarp
13 | dataroot_gt: ~ # Please fill out the path to the LAP evaluation dataset.
14 | io_backend:
15 | type: disk
16 |
17 | network_g:
18 | type: NAFNet
19 | img_channel: 1
20 | width: 64
21 | middle_blk_num: 1
22 | enc_blk_nums: [1, 1, 1, 28]
23 | dec_blk_nums: [1, 1, 1, 1]
24 |
25 | # path
26 | path:
27 | pretrain_network_g: ~
28 | strict_load_g: true
29 | resume_state: ~
30 |
31 | # validation settings
32 | val:
33 |   save_img: false
34 |   use_image: true # PSNR/SSIM differ slightly between true and false; use true for PSNR and SSIM, and false for gmsd (otherwise an error occurs)
35 |
36 | metrics:
37 | psnr:
38 | type: calculate_psnr
39 | crop_border: 0
40 | test_y_channel: false
41 | ssim:
42 | type: calculate_ssim
43 | crop_border: 0
44 | test_y_channel: false
45 | # gmsd:
46 | # type: calculate_gmsd
47 |
48 | # dist training settings
49 | dist_params:
50 | backend: nccl
51 | port: 29500
52 |
--------------------------------------------------------------------------------
/options/train_JARNet_LAP.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: train_JARNet_LAP_20240114
3 | model_type: ImageMlossFlowModel
4 | scale: 1
5 | num_gpu: 1 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: LAPDataSet
13 | dataroot_gt: ~ # Please fill out the path to the LAP training dataset.
14 | filename_tmpl: '{}'
15 | io_backend:
16 | type: disk
17 |
18 | # data loader
19 | use_shuffle: true
20 | num_worker_per_gpu: 4
21 | batch_size_per_gpu: 4
22 |
23 | mini_batch_sizes: [4]
24 | iters: [450000]
25 | gt_size: 128
26 | gt_sizes: [128]
27 |
28 | dataset_enlarge_ratio: 1
29 | prefetch_mode: ~
30 |
31 | val:
32 | name: ValSet
33 | type: LAPDataSet
34 | dataroot_gt: ~ # Please fill out the path to the LAP evaluation dataset.
35 | io_backend:
36 | type: disk
37 |
38 | network_g:
39 | type: JARNet
40 | img_channel: 1
41 | width: 32
42 | middle_blk_num: 4
43 | enc_blk_nums: [4, 4, 4, 4]
44 | dec_blk_nums: [4, 4, 4, 4]
45 | window_size_e_fft: [64, -1, -1, -1]
46 |
47 | # path
48 | path:
49 | pretrain_network_g: ~
50 | strict_load_g: true
51 | resume_state: ~
52 |
53 | # training settings
54 | train:
55 | total_iter: 450000
56 | warmup_iter: -1 # no warm up
57 | use_grad_clip: true
58 |
59 | scheduler:
60 | type: TrueCosineAnnealingLR
61 | T_max: 450000
62 | eta_min: !!float 1e-7
63 |
64 | mixing_augs:
65 | mixup: false
66 | mixup_beta: 1.2
67 | use_identity: true
68 |
69 | optim_g:
70 | type: AdamW
71 | lr: !!float 3e-4
72 | weight_decay: !!float 1e-3
73 | betas: [0.9, 0.999]
74 |
75 | # losses
76 | pixel_opt:
77 | type: L1Loss
78 | loss_weight: 1
79 | reduction: mean
80 |
81 | flow_opt:
82 | type: PerceptualLoss
83 | loss_weight: 0.0001
84 | reduction: mean
85 |
86 | msfr_opt:
87 | type: FFTLoss
88 | loss_weight: 0.1
89 | reduction: mean
90 |
91 | flow_weight: 0.1
92 |
93 | # validation settings
94 | val:
95 | window_size: 8
96 | val_freq: !!float 5e4
97 | save_img: false
98 | save_last_img: true
99 | rgb2bgr: false
100 | use_image: true
101 |
102 | metrics:
103 | psnr:
104 | type: calculate_psnr
105 | crop_border: 0
106 | test_y_channel: false
107 | ssim:
108 | type: calculate_ssim
109 | crop_border: 0
110 | test_y_channel: false
111 |
112 | # logging settings
113 | logger:
114 | print_freq: 1000
115 | save_checkpoint_freq: !!float 5e4
116 | use_tb_logger: true
117 | wandb:
118 | project: ~
119 | resume_id: ~
120 |
121 | # dist training settings
122 | dist_params:
123 | backend: nccl
124 | port: 29500
125 |
--------------------------------------------------------------------------------
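
The scheduler block above pairs TrueCosineAnnealingLR with T_max: 450000 and
eta_min: 1e-7, starting from lr: 3e-4. Assuming it follows the standard
cosine-annealing curve applied per iteration (an assumption; see
basicsr/models/lr_scheduler.py for the exact implementation), the learning
rate at iteration t would be:

import math

def cosine_lr(t, base_lr=3e-4, eta_min=1e-7, t_max=450000):
    # Standard cosine annealing: starts at base_lr, decays to eta_min.
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * t / t_max))

print(cosine_lr(0))       # 3.0e-4 at the first iteration
print(cosine_lr(225000))  # ~1.5e-4 at the schedule midpoint
print(cosine_lr(450000))  # 1.0e-7 (eta_min) at total_iter
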
/options/train_NAFNet_LAP.yml:
--------------------------------------------------------------------------------
1 | # general settings
2 | name: train_NAFNet_LAP_20240114
3 | model_type: ImageMlossModel
4 | scale: 1
5 | num_gpu: 1 # set num_gpu: 0 for cpu mode
6 | manual_seed: 100
7 |
8 | # dataset and data loader settings
9 | datasets:
10 | train:
11 | name: TrainSet
12 | type: LAPDataSetNoWarp
13 | dataroot_gt: ~ # Please fill out the path to the LAP training dataset.
14 | filename_tmpl: '{}'
15 | io_backend:
16 | type: disk
17 |
18 | use_shuffle: true
19 | num_worker_per_gpu: 4
20 | batch_size_per_gpu: 4
21 |
22 | mini_batch_sizes: [4]
23 | iters: [450000]
24 | gt_size: 128
25 | gt_sizes: [128]
26 |
27 | dataset_enlarge_ratio: 1
28 | prefetch_mode: ~
29 |
30 | val:
31 | name: ValSet
32 | type: LAPDataSetNoWarp
33 | dataroot_gt: ~ # Please fill out the path to the LAP evaluation dataset.
34 | io_backend:
35 | type: disk
36 |
37 | network_g:
38 | type: NAFNet
39 | img_channel: 1
40 | width: 64
41 | middle_blk_num: 1
42 | enc_blk_nums: [1, 1, 1, 28]
43 | dec_blk_nums: [1, 1, 1, 1]
44 |
45 | # path
46 | path:
47 | pretrain_network_g: ~
48 | strict_load_g: true
49 | resume_state: ~
50 |
51 | # training settings
52 | train:
53 | total_iter: 450000
54 | warmup_iter: -1 # no warm up
55 | use_grad_clip: true
56 |
57 | scheduler:
58 | type: TrueCosineAnnealingLR
59 | T_max: 450000
60 | eta_min: !!float 1e-7
61 |
62 | mixing_augs:
63 | mixup: false
64 | mixup_beta: 1.2
65 | use_identity: true
66 |
67 | optim_g:
68 | type: AdamW
69 | lr: !!float 3e-4
70 | weight_decay: !!float 1e-3
71 | betas: [0.9, 0.9]
72 |
73 | pixel_opt:
74 | type: PSNRLoss
75 | loss_weight: 1
76 | reduction: mean
77 |
78 | # validation settings
79 | val:
80 | window_size: 8
81 | val_freq: !!float 5e4
82 | save_img: false
83 | save_last_img: true
84 | rgb2bgr: false
85 | use_image: true
86 |
87 | metrics:
88 | psnr:
89 | type: calculate_psnr
90 | crop_border: 0
91 | test_y_channel: false
92 | ssim:
93 | type: calculate_ssim
94 | crop_border: 0
95 | test_y_channel: false
96 |
97 | # logging settings
98 | logger:
99 | print_freq: 1000
100 | save_checkpoint_freq: !!float 5e4
101 | use_tb_logger: true
102 | wandb:
103 | project: ~
104 | resume_id: ~
105 |
106 | # dist training settings
107 | dist_params:
108 | backend: nccl
109 | port: 29500
110 |
--------------------------------------------------------------------------------
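
Unlike the JARNet config above, this one trains with PSNRLoss. Assuming the
common formulation (negative PSNR over images normalized to [0, 1]; an
assumption, see basicsr/models/losses/losses.py), minimizing the loss
maximizes PSNR directly:

import torch

def psnr_loss(pred, target, eps=1e-8):
    # For images in [0, 1], PSNR = -10 * log10(MSE); the loss is its negation,
    # so lower (more negative) values mean higher PSNR.
    mse = torch.mean((pred - target) ** 2) + eps
    return 10.0 * torch.log10(mse)

pred = torch.rand(1, 1, 64, 64)
target = torch.rand(1, 1, 64, 64)
print(psnr_loss(pred, target))
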
/requirements.txt:
--------------------------------------------------------------------------------
1 | einops
2 | h5py
3 | lmdb
4 | matplotlib
5 | numpy
6 | opencv-python
7 | ptflops
8 | pyyaml
9 | scikit-image
10 | scipy
11 | torch
12 | torchvision
13 | tqdm
14 | wandb
--------------------------------------------------------------------------------
/setup.cfg:
--------------------------------------------------------------------------------
1 | [flake8]
2 | ignore =
3 | # line break before binary operator (W503)
4 | W503,
5 | # line break after binary operator (W504)
6 | W504,
7 | max-line-length=79
8 |
9 | [yapf]
10 | based_on_style = pep8
11 | blank_line_before_nested_class_or_def = true
12 | split_before_expression_after_opening_paren = true
13 |
14 | [isort]
15 | line_length = 79
16 | multi_line_output = 0
17 | known_standard_library = pkg_resources,setuptools
18 | known_first_party = basicsr
19 | known_third_party = PIL,cv2,lmdb,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml
20 | no_lines_before = STDLIB,LOCALFOLDER
21 | default_section = THIRDPARTY
22 |
--------------------------------------------------------------------------------
/setup.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env python
2 |
3 | from setuptools import find_packages, setup
4 |
5 | import os
6 | import subprocess
7 | import sys
8 | import time
9 | import torch
10 | from torch.utils.cpp_extension import (BuildExtension, CppExtension,
11 | CUDAExtension)
12 |
13 | version_file = 'basicsr/version.py'
14 |
15 |
16 | def readme():
17 | return ''
18 | # with open('README.md', encoding='utf-8') as f:
19 | # content = f.read()
20 | # return content
21 |
22 |
23 | def get_git_hash():
24 |
25 | def _minimal_ext_cmd(cmd):
26 | # construct minimal environment
27 | env = {}
28 | for k in ['SYSTEMROOT', 'PATH', 'HOME']:
29 | v = os.environ.get(k)
30 | if v is not None:
31 | env[k] = v
32 | # LANGUAGE is used on win32
33 | env['LANGUAGE'] = 'C'
34 | env['LANG'] = 'C'
35 | env['LC_ALL'] = 'C'
36 | out = subprocess.Popen(
37 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0]
38 | return out
39 |
40 | try:
41 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD'])
42 | sha = out.strip().decode('ascii')
43 | except OSError:
44 | sha = 'unknown'
45 |
46 | return sha
47 |
48 |
49 | def get_hash():
50 | if os.path.exists('.git'):
51 | sha = get_git_hash()[:7]
52 | elif os.path.exists(version_file):
53 | try:
54 | from basicsr.version import __version__
55 | sha = __version__.split('+')[-1]
56 | except ImportError:
57 | raise ImportError('Unable to get git version')
58 | else:
59 | sha = 'unknown'
60 |
61 | return sha
62 |
63 |
64 | def write_version_py():
65 | content = """# GENERATED VERSION FILE
66 | # TIME: {}
67 | __version__ = '{}'
68 | short_version = '{}'
69 | version_info = ({})
70 | """
71 | sha = get_hash()
72 | with open('VERSION', 'r') as f:
73 | SHORT_VERSION = f.read().strip()
74 | VERSION_INFO = ', '.join(
75 | [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')])
76 | VERSION = SHORT_VERSION + '+' + sha
77 |
78 | version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION,
79 | VERSION_INFO)
80 | with open(version_file, 'w') as f:
81 | f.write(version_file_str)
82 |
83 |
84 | def get_version():
85 | with open(version_file, 'r') as f:
86 | exec(compile(f.read(), version_file, 'exec'))
87 | return locals()['__version__']
88 |
89 |
90 | def make_cuda_ext(name, module, sources, sources_cuda=None):
91 | if sources_cuda is None:
92 | sources_cuda = []
93 | define_macros = []
94 | extra_compile_args = {'cxx': []}
95 |
96 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1':
97 | define_macros += [('WITH_CUDA', None)]
98 | extension = CUDAExtension
99 | extra_compile_args['nvcc'] = [
100 | '-D__CUDA_NO_HALF_OPERATORS__',
101 | '-D__CUDA_NO_HALF_CONVERSIONS__',
102 | '-D__CUDA_NO_HALF2_OPERATORS__',
103 | ]
104 | sources += sources_cuda
105 | else:
106 | print(f'Compiling {name} without CUDA')
107 | extension = CppExtension
108 |
109 | return extension(
110 | name=f'{module}.{name}',
111 | sources=[os.path.join(*module.split('.'), p) for p in sources],
112 | define_macros=define_macros,
113 | extra_compile_args=extra_compile_args)
114 |
115 |
116 | def get_requirements(filename='requirements.txt'):
117 |     return []  # dependency parsing disabled; kept for reference (cf. readme() above)
118 |     # here = os.path.dirname(os.path.realpath(__file__))
119 |     # with open(os.path.join(here, filename), 'r') as f:
120 |     #     requires = [line.replace('\n', '') for line in f.readlines()]
121 |     # return requires
122 |
123 |
124 | if __name__ == '__main__':
125 | if '--no_cuda_ext' in sys.argv:
126 | ext_modules = []
127 | sys.argv.remove('--no_cuda_ext')
128 | else:
129 | ext_modules = [
130 | make_cuda_ext(
131 | name='deform_conv_ext',
132 | module='basicsr.models.ops.dcn',
133 | sources=['src/deform_conv_ext.cpp'],
134 | sources_cuda=[
135 | 'src/deform_conv_cuda.cpp',
136 | 'src/deform_conv_cuda_kernel.cu'
137 | ]),
138 | make_cuda_ext(
139 | name='fused_act_ext',
140 | module='basicsr.models.ops.fused_act',
141 | sources=['src/fused_bias_act.cpp'],
142 | sources_cuda=['src/fused_bias_act_kernel.cu']),
143 | make_cuda_ext(
144 | name='upfirdn2d_ext',
145 | module='basicsr.models.ops.upfirdn2d',
146 | sources=['src/upfirdn2d.cpp'],
147 | sources_cuda=['src/upfirdn2d_kernel.cu']),
148 | ]
149 |
150 | write_version_py()
151 | setup(
152 | name='basicsr',
153 | version=get_version(),
154 | description='Open Source Image and Video Super-Resolution Toolbox',
155 | long_description=readme(),
156 | author='Xintao Wang',
157 | author_email='xintao.wang@outlook.com',
158 | keywords='computer vision, restoration, super resolution',
159 | url='https://github.com/xinntao/BasicSR',
160 | packages=find_packages(
161 | exclude=('options', 'datasets', 'experiments', 'results',
162 | 'tb_logger', 'wandb')),
163 | classifiers=[
164 | 'Development Status :: 4 - Beta',
165 | 'License :: OSI Approved :: Apache Software License',
166 | 'Operating System :: OS Independent',
167 | 'Programming Language :: Python :: 3',
168 | 'Programming Language :: Python :: 3.7',
169 | 'Programming Language :: Python :: 3.8',
170 | ],
171 | license='Apache License 2.0',
172 | setup_requires=['cython', 'numpy'],
173 | install_requires=get_requirements(),
174 | ext_modules=ext_modules,
175 | cmdclass={'build_ext': BuildExtension},
176 | zip_safe=False)
177 |
--------------------------------------------------------------------------------
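
Since setup.py removes a --no_cuda_ext flag from sys.argv before calling
setup(), the package can be installed without compiling the deform_conv /
fused_act / upfirdn2d CUDA extensions:

python setup.py develop --no_cuda_ext
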