├── .gitignore ├── LICENSE ├── README.md ├── VERSION ├── assets ├── dataset.jpg ├── real.jpg └── synthetic.jpg ├── basicsr ├── data │ ├── __init__.py │ ├── data_sampler.py │ ├── data_util.py │ ├── lap_dataset.py │ ├── prefetch_dataloader.py │ └── transforms.py ├── metrics │ ├── __init__.py │ ├── gmsd.py │ ├── metric_util.py │ └── psnr_ssim.py ├── models │ ├── __init__.py │ ├── archs │ │ ├── __init__.py │ │ ├── jarnet_arch.py │ │ └── nafnet_arch.py │ ├── base_model.py │ ├── image_restoration_model.py │ ├── losses │ │ ├── __init__.py │ │ ├── loss_util.py │ │ └── losses.py │ └── lr_scheduler.py ├── test.py ├── train_jarnet.py ├── train_others.py └── utils │ ├── __init__.py │ ├── bundle_submissions.py │ ├── create_lmdb.py │ ├── dist_util.py │ ├── download_util.py │ ├── face_util.py │ ├── file_client.py │ ├── flow_util.py │ ├── img_util.py │ ├── lmdb_util.py │ ├── logger.py │ ├── matlab_functions.py │ ├── misc.py │ └── options.py ├── options ├── test_JARNet_LAP.yml ├── test_NAFNet_LAP.yml ├── train_JARNet_LAP.yml └── train_NAFNet_LAP.yml ├── requirements.txt ├── setup.cfg └── setup.py /.gitignore: -------------------------------------------------------------------------------- 1 | # ignored folders 2 | datasets/* 3 | experiments/* 4 | results/* 5 | tb_logger/* 6 | basicsr/tb_logger/* 7 | wandb/* 8 | tmp/* 9 | 10 | docs/api 11 | scripts/__init__.py 12 | 13 | *.DS_Store 14 | .idea 15 | 16 | # ignored files 17 | version.py 18 | 19 | # ignored files with suffix 20 | *.html 21 | *.png 22 | *.jpeg 23 | *.jpg 24 | *.gif 25 | *.pth 26 | *.zip 27 | *.npy 28 | *.tif 29 | 30 | # Byte-compiled / optimized / DLL files 31 | __pycache__/ 32 | *.py[cod] 33 | *$py.class 34 | 35 | # C extensions 36 | *.so 37 | 38 | # Distribution / packaging 39 | .Python 40 | build/ 41 | develop-eggs/ 42 | dist/ 43 | downloads/ 44 | eggs/ 45 | .eggs/ 46 | lib/ 47 | lib64/ 48 | parts/ 49 | sdist/ 50 | var/ 51 | wheels/ 52 | share/python-wheels/ 53 | *.egg-info/ 54 | .installed.cfg 55 | *.egg 56 | MANIFEST 57 | 58 | # PyInstaller 59 | # Usually these files are written by a python script from a template 60 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 61 | *.manifest 62 | *.spec 63 | 64 | # Installer logs 65 | pip-log.txt 66 | pip-delete-this-directory.txt 67 | 68 | # Unit test / coverage reports 69 | htmlcov/ 70 | .tox/ 71 | .nox/ 72 | .coverage 73 | .coverage.* 74 | .cache 75 | nosetests.xml 76 | coverage.xml 77 | *.cover 78 | *.py,cover 79 | .hypothesis/ 80 | .pytest_cache/ 81 | cover/ 82 | 83 | # Translations 84 | *.mo 85 | *.pot 86 | 87 | # Django stuff: 88 | *.log 89 | local_settings.py 90 | db.sqlite3 91 | db.sqlite3-journal 92 | 93 | # Flask stuff: 94 | instance/ 95 | .webassets-cache 96 | 97 | # Scrapy stuff: 98 | .scrapy 99 | 100 | # Sphinx documentation 101 | docs/_build/ 102 | 103 | # PyBuilder 104 | .pybuilder/ 105 | target/ 106 | 107 | # Jupyter Notebook 108 | .ipynb_checkpoints 109 | 110 | # IPython 111 | profile_default/ 112 | ipython_config.py 113 | 114 | # pyenv 115 | # For a library or package, you might want to ignore these files since the code is 116 | # intended to run in multiple environments; otherwise, check them in: 117 | # .python-version 118 | 119 | # pipenv 120 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
121 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 122 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 123 | # install all needed dependencies. 124 | #Pipfile.lock 125 | 126 | # poetry 127 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 128 | # This is especially recommended for binary packages to ensure reproducibility, and is more 129 | # commonly ignored for libraries. 130 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 131 | #poetry.lock 132 | 133 | # pdm 134 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 135 | #pdm.lock 136 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 137 | # in version control. 138 | # https://pdm.fming.dev/#use-with-ide 139 | .pdm.toml 140 | 141 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 142 | __pypackages__/ 143 | 144 | # Celery stuff 145 | celerybeat-schedule 146 | celerybeat.pid 147 | 148 | # SageMath parsed files 149 | *.sage.py 150 | 151 | # Environments 152 | .env 153 | .venv 154 | env/ 155 | venv/ 156 | ENV/ 157 | env.bak/ 158 | venv.bak/ 159 | 160 | # Spyder project settings 161 | .spyderproject 162 | .spyproject 163 | 164 | # Rope project settings 165 | .ropeproject 166 | 167 | # mkdocs documentation 168 | /site 169 | 170 | # mypy 171 | .mypy_cache/ 172 | .dmypy.json 173 | dmypy.json 174 | 175 | # Pyre type checker 176 | .pyre/ 177 | 178 | # pytype static type analyzer 179 | .pytype/ 180 | 181 | # Cython debug symbols 182 | cython_debug/ 183 | 184 | # PyCharm 185 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 186 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 187 | # and can be added to the global gitignore or merged into this file. For a more nuclear 188 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 189 | #.idea/ 190 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 
25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 |
-------------------------------------------------------------------------------- /README.md: --------------------------------------------------------------------------------
1 | # Deep Linear Array Pushbroom Image Restoration: A Degradation Pipeline and Jitter-Aware Restoration Network
2 | Zida Chen*, Ziran Zhang*, Haoying Li, Menghao Li, Yueting Chen, Qi Li, Huajun Feng, Zhihai Xu, Shiqi Chen
3 |
4 | State Key Laboratory of Extreme Photonics and Instrumentation, Zhejiang University
5 |
6 | ## News
7 | **2024.01.17** Our train/test code, LAP dataset and pre-trained model are now available.
8 |
9 | **2023.12.09** Our paper is accepted by **AAAI 2024**! [Paper Link](https://arxiv.org/abs/2401.08171)
10 |
11 | ## Get Started
12 | ### Data Preparation
13 |
14 | You can download our synthetic LAP [training](https://drive.google.com/file/d/1n8PHoG7RYrp3uvDU-5DXF4FFl_KHqZt1/view?usp=sharing) and [evaluation](https://drive.google.com/file/d/18hLBPKpbpSBRUpnRcM5zzZySyRurMk81/view?usp=sharing) datasets.
15 |
16 | Please unzip them and configure the dataset paths in your config files (e.g. options/test_JARNet_LAP.yml).
17 |
18 | Note: You need to reserve at least **200GB** of disk space to store the data.
19 |
20 |
21 | ### Environment
22 |
23 | - Install the required packages in your environment (Python >= 3.7):
24 | ```
25 | pip install -r requirements.txt
26 | ```
27 |
28 | - Set up the [BasicSR](https://github.com/XPixelGroup/BasicSR) environment by running:
29 | ```
30 | python setup.py develop --no_cuda_ext
31 | ```
32 |
33 | - To choose the GPU id, modify the following line in the train/evaluation script:
34 | ```
35 | os.environ["CUDA_VISIBLE_DEVICES"] = {your gpu id}
36 | ```
37 |
38 | ### Train
39 |
40 | You can train JARNet with:
41 | ```
42 | python basicsr/train_jarnet.py -opt options/train_JARNet_LAP.yml
43 | ```
44 |
45 | You can train other restoration models (e.g. NAFNet) with:
46 | ```
47 | python basicsr/train_others.py -opt options/train_NAFNet_LAP.yml
48 | ```
49 |
50 | You should fill in the dataset path in the training config (i.e. .yml) file.
51 |
52 | ### Evaluation
53 |
54 | You can evaluate JARNet with:
55 | ```
56 | python basicsr/test.py -opt options/test_JARNet_LAP.yml
57 | ```
58 |
59 | You should fill in the dataset and pre-trained model paths in the test config (i.e. .yml) file.
60 |
61 | You can run JARNet inference on our LAP evaluation dataset with the [pre-trained model](https://drive.google.com/file/d/1IjFgrXD4YBmRaWPX-WukgvzRdJHrZqZ6/view?usp=sharing).
62 |
63 | ## Visualization Comparisons
64 | ### Samples of Synthetic LAP Dataset
65 | ![Samples of the synthetic LAP dataset](assets/dataset.jpg)
66 |
67 | ### Synthetic LAP Image Restoration Results
68 | ![Synthetic LAP restoration results](assets/synthetic.jpg)
69 |
70 | ### Real-World LAP Image Restoration Results
71 | ![Real-world LAP restoration results](assets/real.jpg)
72 |
73 | ## Citation
74 |
75 | If this repo helps you, please consider citing our work:
76 |
77 | ```
78 | @inproceedings{chen2024deep,
79 |   title={Deep Linear Array Pushbroom Image Restoration: A Degradation Pipeline and Jitter-Aware Restoration Network},
80 |   author={Zida Chen and Ziran Zhang and Haoying Li and Menghao Li and Yueting Chen and Qi Li and Huajun Feng and Zhihai Xu and Shiqi Chen},
81 |   booktitle={AAAI},
82 |   year={2024}
83 | }
84 | ```
85 |
86 | ## Contact
87 | If you have any questions, please contact zd_chen@zju.edu.cn.
88 |
89 |
90 | ## Acknowledgment
91 | Our code is based on the [BasicSR](https://github.com/XPixelGroup/BasicSR) toolbox.
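
## Quick Sanity Check
As a quick check that the environment is set up, the bundled metrics can be called directly. This is a minimal sketch with dummy arrays (not part of the training/evaluation pipeline); note that `calculate_ssim` expects a CUDA-capable device:
```
import numpy as np
from basicsr.metrics import calculate_psnr, calculate_ssim

img1 = np.random.rand(64, 64, 1) * 255.  # dummy single-channel HWC image in [0, 255]
img2 = img1.copy()

print(calculate_psnr(img1, img2, crop_border=0))  # inf for identical images
print(calculate_ssim(img1, img2, crop_border=0))  # 1.0 for identical images
```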
-------------------------------------------------------------------------------- /VERSION: --------------------------------------------------------------------------------
1 | 1.2.0
2 |
-------------------------------------------------------------------------------- /assets/dataset.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/JHW2000/JARNet/839c8986ead2951114de03598949af45f904d753/assets/dataset.jpg
-------------------------------------------------------------------------------- /assets/real.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/JHW2000/JARNet/839c8986ead2951114de03598949af45f904d753/assets/real.jpg
-------------------------------------------------------------------------------- /assets/synthetic.jpg: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/JHW2000/JARNet/839c8986ead2951114de03598949af45f904d753/assets/synthetic.jpg
-------------------------------------------------------------------------------- /basicsr/data/__init__.py: --------------------------------------------------------------------------------
1 | import importlib
2 | import numpy as np
3 | import random
4 | import torch
5 | import torch.utils.data
6 | from functools import partial
7 | from os import path as osp
8 |
9 | from basicsr.data.prefetch_dataloader import PrefetchDataLoader
10 | from basicsr.utils import get_root_logger, scandir
11 | from basicsr.utils.dist_util import get_dist_info
12 |
13 | __all__ = ['create_dataset', 'create_dataloader']
14 |
15 | # automatically scan and import dataset modules
16 | # scan all the files under the data folder with '_dataset' in file names
17 | data_folder = osp.dirname(osp.abspath(__file__))
18 | dataset_filenames = [
19 |     osp.splitext(osp.basename(v))[0] for v in scandir(data_folder)
20 |     if v.endswith('_dataset.py')
21 | ]
22 | # import all the dataset modules
23 | _dataset_modules = [
24 |     importlib.import_module(f'basicsr.data.{file_name}')
25 |     for file_name in dataset_filenames
26 | ]
27 |
28 |
29 | def create_dataset(dataset_opt):
30 |     """Create dataset.
31 |
32 |     Args:
33 |         dataset_opt (dict): Configuration for dataset. It contains:
34 |             name (str): Dataset name.
35 |             type (str): Dataset type.
36 |     """
37 |     dataset_type = dataset_opt['type']
38 |
39 |     # dynamic instantiation
40 |     for module in _dataset_modules:
41 |         dataset_cls = getattr(module, dataset_type, None)
42 |         if dataset_cls is not None:
43 |             break
44 |     if dataset_cls is None:
45 |         raise ValueError(f'Dataset {dataset_type} is not found.')
46 |     print("cls:", dataset_opt)  # debug output of the dataset options
47 |
48 |     dataset = dataset_cls(dataset_opt)
49 |
50 |     logger = get_root_logger()
51 |     logger.info(
52 |         f'Dataset {dataset.__class__.__name__} - {dataset_opt["name"]} '
53 |         'is created.')
54 |     return dataset
55 |
56 |
57 | def create_dataloader(dataset,
58 |                       dataset_opt,
59 |                       num_gpu=1,
60 |                       dist=False,
61 |                       sampler=None,
62 |                       seed=None):
63 |     """Create dataloader.
64 |
65 |     Args:
66 |         dataset (torch.utils.data.Dataset): Dataset.
67 |         dataset_opt (dict): Dataset options. It contains the following keys:
68 |             phase (str): 'train' or 'val'.
69 |             num_worker_per_gpu (int): Number of workers for each GPU.
70 |             batch_size_per_gpu (int): Training batch size for each GPU.
71 |         num_gpu (int): Number of GPUs. Used only in the train phase.
72 |             Default: 1.
73 |         dist (bool): Whether in distributed training.
Used only in the train 74 | phase. Default: False. 75 | sampler (torch.utils.data.sampler): Data sampler. Default: None. 76 | seed (int | None): Seed. Default: None 77 | """ 78 | phase = dataset_opt['phase'] 79 | rank, _ = get_dist_info() 80 | if phase == 'train': 81 | if dist: # distributed training 82 | batch_size = dataset_opt['batch_size_per_gpu'] 83 | num_workers = dataset_opt['num_worker_per_gpu'] 84 | else: # non-distributed training 85 | multiplier = 1 if num_gpu == 0 else num_gpu 86 | batch_size = dataset_opt['batch_size_per_gpu'] * multiplier 87 | num_workers = dataset_opt['num_worker_per_gpu'] * multiplier 88 | dataloader_args = dict( 89 | dataset=dataset, 90 | batch_size=batch_size, 91 | shuffle=False, 92 | num_workers=num_workers, 93 | sampler=sampler, 94 | drop_last=True) 95 | if sampler is None: 96 | dataloader_args['shuffle'] = True 97 | dataloader_args['worker_init_fn'] = partial( 98 | worker_init_fn, num_workers=num_workers, rank=rank, 99 | seed=seed) if seed is not None else None 100 | elif phase in ['val', 'test']: # validation 101 | dataloader_args = dict( 102 | dataset=dataset, batch_size=1, shuffle=False, num_workers=0) 103 | else: 104 | raise ValueError(f'Wrong dataset phase: {phase}. ' 105 | "Supported ones are 'train', 'val' and 'test'.") 106 | 107 | dataloader_args['pin_memory'] = dataset_opt.get('pin_memory', False) 108 | 109 | prefetch_mode = dataset_opt.get('prefetch_mode') 110 | if prefetch_mode == 'cpu': # CPUPrefetcher 111 | num_prefetch_queue = dataset_opt.get('num_prefetch_queue', 1) 112 | logger = get_root_logger() 113 | logger.info(f'Use {prefetch_mode} prefetch dataloader: ' 114 | f'num_prefetch_queue = {num_prefetch_queue}') 115 | return PrefetchDataLoader( 116 | num_prefetch_queue=num_prefetch_queue, **dataloader_args) 117 | else: 118 | # prefetch_mode=None: Normal dataloader 119 | # prefetch_mode='cuda': dataloader for CUDAPrefetcher 120 | return torch.utils.data.DataLoader(**dataloader_args) 121 | 122 | 123 | def worker_init_fn(worker_id, num_workers, rank, seed): 124 | # Set the worker seed to num_workers * rank + worker_id + seed 125 | worker_seed = num_workers * rank + worker_id + seed 126 | np.random.seed(worker_seed) 127 | random.seed(worker_seed) 128 | -------------------------------------------------------------------------------- /basicsr/data/data_sampler.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch.utils.data.sampler import Sampler 4 | 5 | 6 | class EnlargedSampler(Sampler): 7 | """Sampler that restricts data loading to a subset of the dataset. 8 | 9 | Modified from torch.utils.data.distributed.DistributedSampler 10 | Support enlarging the dataset for iteration-based training, for saving 11 | time when restart the dataloader after each epoch 12 | 13 | Args: 14 | dataset (torch.utils.data.Dataset): Dataset used for sampling. 15 | num_replicas (int | None): Number of processes participating in 16 | the training. It is usually the world_size. 17 | rank (int | None): Rank of the current process within num_replicas. 18 | ratio (int): Enlarging ratio. Default: 1. 
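            A ratio greater than 1 virtually enlarges the dataset: indices are
            drawn modulo len(dataset) in __iter__, so each replica yields
            ceil(len(dataset) * ratio / num_replicas) samples per pass.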
19 | """ 20 | 21 | def __init__(self, dataset, num_replicas, rank, ratio=1): 22 | self.dataset = dataset 23 | self.num_replicas = num_replicas 24 | self.rank = rank 25 | self.epoch = 0 26 | self.num_samples = math.ceil( 27 | len(self.dataset) * ratio / self.num_replicas) 28 | self.total_size = self.num_samples * self.num_replicas 29 | 30 | def __iter__(self): 31 | # deterministically shuffle based on epoch 32 | g = torch.Generator() 33 | g.manual_seed(self.epoch) 34 | indices = torch.randperm(self.total_size, generator=g).tolist() 35 | 36 | dataset_size = len(self.dataset) 37 | indices = [v % dataset_size for v in indices] 38 | 39 | # subsample 40 | indices = indices[self.rank:self.total_size:self.num_replicas] 41 | assert len(indices) == self.num_samples 42 | 43 | return iter(indices) 44 | 45 | def __len__(self): 46 | return self.num_samples 47 | 48 | def set_epoch(self, epoch): 49 | self.epoch = epoch 50 | -------------------------------------------------------------------------------- /basicsr/data/lap_dataset.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | from pathlib import Path 4 | from torch.utils import data as data 5 | from basicsr.data.transforms import paired_random_crop, four_random_crop 6 | from basicsr.utils import img2tensor 7 | import glob 8 | import os 9 | import torchvision.transforms as transforms 10 | import torch.nn.functional as F 11 | 12 | def unsqueeze_twice(x): 13 | return x.unsqueeze(0).unsqueeze(0) 14 | 15 | def warp(img,jit): 16 | jit = torch.from_numpy(-jit) 17 | # img = torch.from_numpy(img).float() # error: uint16 convert don't support 18 | img = torch.from_numpy(img.astype(np.int32)).float() # val psnr error, np.int32 pr np.float32 instead of np.int16 19 | h, w = img.shape 20 | grid_y, grid_x = torch.meshgrid(torch.arange(0, h), torch.arange(0, w)) 21 | grid = torch.stack((grid_x, grid_y), 2).type_as(img) 22 | grid.requires_grad = False 23 | 24 | grid_flow = grid + jit 25 | grid_flow = grid_flow.unsqueeze(0) 26 | grid_flow = grid_flow[:, :h, :w, :] 27 | grid_flow_x = 2.0 * grid_flow[:, :, :, 0] / max(w - 1, 1) - 1.0 28 | grid_flow_y = 2.0 * grid_flow[:, :, :, 1] / max(h - 1, 1) - 1.0 29 | grid_flow = torch.stack((grid_flow_x, grid_flow_y), dim = 3) 30 | 31 | img_tensor = unsqueeze_twice(img) 32 | img_subdivision = F.grid_sample(img_tensor, grid_flow, 33 | mode = 'bilinear', padding_mode = "reflection", align_corners = True) # nearest 34 | img = np.array(img_subdivision).astype(int)[0,0,:,:] 35 | return img 36 | 37 | class LAPDataSet(data.Dataset): 38 | def __init__(self, opt): 39 | super(LAPDataSet, self).__init__() 40 | self.opt = opt 41 | print(opt) 42 | self.path = sorted(glob.glob(os.path.join(Path(opt['dataroot_gt'])) + "/*.*")) 43 | self.transform = transforms.Compose( [transforms.ToTensor(),]) 44 | 45 | def __getitem__(self, index): 46 | index = int(index) 47 | data_dict = np.load(self.path[index % len(self.path)], allow_pickle=True ) 48 | data_dict = data_dict.item() 49 | 50 | img = data_dict["img_LAP"] 51 | gt = data_dict["img_gt"] 52 | flow = data_dict["jit_information_noise"] 53 | flow_gt = data_dict["jit_information"] 54 | 55 | img = np.expand_dims(img, axis=2) 56 | gt = np.expand_dims(gt, axis=2) 57 | 58 | # normalization 59 | img = img.astype(np.float32) / 255. / 255. 60 | gt = gt.astype(np.float32) / 255. / 255. 
61 | 62 | if self.opt['name'] == 'TrainSet': 63 | scale = self.opt['scale'] 64 | gt_size = self.opt['gt_size'] 65 | # gt, img = paired_random_crop(gt, img, gt_size, scale, None) 66 | gt, img, flow, flow_gt = four_random_crop(gt, img, flow, flow_gt, gt_size, scale, None) 67 | 68 | img = img2tensor(img) 69 | gt = img2tensor(gt) 70 | flow = img2tensor(flow) 71 | flow_gt = img2tensor(flow_gt) 72 | 73 | return {'lq': img, 'gt': gt, 'flow': flow,"flow_gt": flow_gt,"lq_path": self.path[index % len(self.path)]} 74 | 75 | def __len__(self): 76 | return len(self.path) 77 | 78 | 79 | class LAPDataSetNoWarp(data.Dataset): 80 | def __init__(self, opt): 81 | super(LAPDataSetNoWarp, self).__init__() 82 | self.opt = opt 83 | print(opt) 84 | self.path = sorted(glob.glob(os.path.join(Path(opt['dataroot_gt'])) + "/*.*")) 85 | self.transform = transforms.Compose( [transforms.ToTensor(),]) 86 | 87 | def __getitem__(self, index): 88 | index = int(index) 89 | data_dict = np.load(self.path[index % len(self.path)], allow_pickle=True ) 90 | data_dict = data_dict.item() 91 | 92 | img = data_dict["img_LAP"] 93 | gt = data_dict["img_gt"] 94 | 95 | img = np.expand_dims(img, axis=2) 96 | gt = np.expand_dims(gt, axis=2) 97 | 98 | # normalization 99 | img = img.astype(np.float32) / 255. / 255. 100 | gt = gt.astype(np.float32) / 255. / 255. 101 | 102 | if self.opt['name'] == 'TrainSet': 103 | scale = self.opt['scale'] 104 | gt_size = self.opt['gt_size'] 105 | gt, img = paired_random_crop(gt, img, gt_size, scale, None) 106 | 107 | img = img2tensor(img) 108 | gt = img2tensor(gt) 109 | 110 | return {'lq': img, 'gt': gt,"lq_path": self.path[index % len(self.path)]} 111 | 112 | def __len__(self): 113 | return len(self.path) 114 | 115 | 116 | class LAPDataSetWarp(data.Dataset): 117 | def __init__(self, opt): 118 | super(LAPDataSetWarp, self).__init__() 119 | self.opt = opt 120 | print(opt) 121 | self.path = sorted(glob.glob(os.path.join(Path(opt['dataroot_gt'])) + "/*.*")) 122 | self.transform = transforms.Compose( [transforms.ToTensor(),]) 123 | 124 | def __getitem__(self, index): 125 | index = int(index) 126 | data_dict = np.load(self.path[index % len(self.path)], allow_pickle=True ) 127 | data_dict = data_dict.item() 128 | 129 | img = data_dict["img_LAP"] 130 | gt = data_dict["img_gt"] 131 | flow = data_dict["jit_information_noise"] 132 | flow_gt = data_dict["jit_information"] 133 | 134 | img = warp(img,flow) 135 | 136 | img = np.expand_dims(img, axis=2) 137 | gt = np.expand_dims(gt, axis=2) 138 | 139 | # normalization 140 | img = img.astype(np.float32) / 255. / 255. 141 | gt = gt.astype(np.float32) / 255. / 255. 
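        # NOTE: unlike LAPDataSet, the LQ image here was pre-warped with the noisy jitter
        # flow (img = warp(img, flow) above), so the network receives a coarsely corrected input.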
142 |
143 |         if self.opt['name'] == 'TrainSet':
144 |             scale = self.opt['scale']
145 |             gt_size = self.opt['gt_size']
146 |             # gt, img = paired_random_crop(gt, img, gt_size, scale, None)
147 |             gt, img, flow, flow_gt = four_random_crop(gt, img, flow, flow_gt, gt_size, scale, None)
148 |
149 |         img = img2tensor(img)
150 |         gt = img2tensor(gt)
151 |         flow = img2tensor(flow)
152 |         flow_gt = img2tensor(flow_gt)
153 |
154 |         return {'lq': img, 'gt': gt, 'flow': flow, 'flow_gt': flow_gt, 'lq_path': self.path[index % len(self.path)]}
155 |
156 |     def __len__(self):
157 |         return len(self.path)
-------------------------------------------------------------------------------- /basicsr/data/prefetch_dataloader.py: --------------------------------------------------------------------------------
1 | import queue as Queue
2 | import threading
3 | import torch
4 | from torch.utils.data import DataLoader
5 |
6 |
7 | class PrefetchGenerator(threading.Thread):
8 |     """A general prefetch generator.
9 |
10 |     Ref:
11 |     https://stackoverflow.com/questions/7323664/python-generator-pre-fetch
12 |
13 |     Args:
14 |         generator: Python generator.
15 |         num_prefetch_queue (int): Number of prefetch queue.
16 |     """
17 |
18 |     def __init__(self, generator, num_prefetch_queue):
19 |         threading.Thread.__init__(self)
20 |         self.queue = Queue.Queue(num_prefetch_queue)
21 |         self.generator = generator
22 |         self.daemon = True
23 |         self.start()
24 |
25 |     def run(self):
26 |         for item in self.generator:
27 |             self.queue.put(item)
28 |         self.queue.put(None)
29 |
30 |     def __next__(self):
31 |         next_item = self.queue.get()
32 |         if next_item is None:
33 |             raise StopIteration
34 |         return next_item
35 |
36 |     def __iter__(self):
37 |         return self
38 |
39 |
40 | class PrefetchDataLoader(DataLoader):
41 |     """Prefetch version of dataloader.
42 |
43 |     Ref:
44 |     https://github.com/IgorSusmelj/pytorch-styleguide/issues/5#
45 |
46 |     TODO:
47 |     Need to test on single gpu and ddp (multi-gpu). There is a known issue in
48 |     ddp.
49 |
50 |     Args:
51 |         num_prefetch_queue (int): Number of prefetch queue.
52 |         kwargs (dict): Other arguments for dataloader.
53 |     """
54 |
55 |     def __init__(self, num_prefetch_queue, **kwargs):
56 |         self.num_prefetch_queue = num_prefetch_queue
57 |         super(PrefetchDataLoader, self).__init__(**kwargs)
58 |
59 |     def __iter__(self):
60 |         return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
61 |
62 |
63 | class CPUPrefetcher():
64 |     """CPU prefetcher.
65 |
66 |     Args:
67 |         loader: Dataloader.
68 |     """
69 |
70 |     def __init__(self, loader):
71 |         self.ori_loader = loader
72 |         self.loader = iter(loader)
73 |
74 |     def next(self):
75 |         try:
76 |             return next(self.loader)
77 |         except StopIteration:
78 |             return None
79 |
80 |     def reset(self):
81 |         self.loader = iter(self.ori_loader)
82 |
83 |
84 | class CUDAPrefetcher():
85 |     """CUDA prefetcher.
86 |
87 |     Ref:
88 |     https://github.com/NVIDIA/apex/issues/304#
89 |
90 |     It may consume more GPU memory.
91 |
92 |     Args:
93 |         loader: Dataloader.
94 |         opt (dict): Options.
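
    Batches are copied to the GPU on a dedicated CUDA stream one step ahead
    of consumption (see preload() and next() below).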
95 | """ 96 | 97 | def __init__(self, loader, opt): 98 | self.ori_loader = loader 99 | self.loader = iter(loader) 100 | self.opt = opt 101 | self.stream = torch.cuda.Stream() 102 | self.device = torch.device('cuda' if opt['num_gpu'] != 0 else 'cpu') 103 | self.preload() 104 | 105 | def preload(self): 106 | try: 107 | self.batch = next(self.loader) # self.batch is a dict 108 | except StopIteration: 109 | self.batch = None 110 | return None 111 | # put tensors to gpu 112 | with torch.cuda.stream(self.stream): 113 | for k, v in self.batch.items(): 114 | if torch.is_tensor(v): 115 | self.batch[k] = self.batch[k].to( 116 | device=self.device, non_blocking=True) 117 | 118 | def next(self): 119 | torch.cuda.current_stream().wait_stream(self.stream) 120 | batch = self.batch 121 | self.preload() 122 | return batch 123 | 124 | def reset(self): 125 | self.loader = iter(self.ori_loader) 126 | self.preload() 127 | -------------------------------------------------------------------------------- /basicsr/data/transforms.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import random 3 | import numpy as np 4 | 5 | def mod_crop(img, scale): 6 | """Mod crop images, used during testing. 7 | 8 | Args: 9 | img (ndarray): Input image. 10 | scale (int): Scale factor. 11 | 12 | Returns: 13 | ndarray: Result image. 14 | """ 15 | img = img.copy() 16 | if img.ndim in (2, 3): 17 | h, w = img.shape[0], img.shape[1] 18 | h_remainder, w_remainder = h % scale, w % scale 19 | img = img[:h - h_remainder, :w - w_remainder, ...] 20 | else: 21 | raise ValueError(f'Wrong img ndim: {img.ndim}.') 22 | return img 23 | 24 | def paired_random_crop(img_gts, img_lqs, lq_patch_size, scale, gt_path): 25 | """Paired random crop. 26 | 27 | It crops lists of lq and gt images with corresponding locations. 28 | 29 | Args: 30 | img_gts (list[ndarray] | ndarray): GT images. Note that all images 31 | should have the same shape. If the input is an ndarray, it will 32 | be transformed to a list containing itself. 33 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 34 | should have the same shape. If the input is an ndarray, it will 35 | be transformed to a list containing itself. 36 | lq_patch_size (int): LQ patch size. 37 | scale (int): Scale factor. 38 | gt_path (str): Path to ground-truth. 39 | 40 | Returns: 41 | list[ndarray] | ndarray: GT images and LQ images. If returned results 42 | only have one element, just return ndarray. 43 | """ 44 | 45 | if not isinstance(img_gts, list): 46 | img_gts = [img_gts] 47 | if not isinstance(img_lqs, list): 48 | img_lqs = [img_lqs] 49 | 50 | h_lq, w_lq, _ = img_lqs[0].shape 51 | h_gt, w_gt, _ = img_gts[0].shape 52 | gt_patch_size = int(lq_patch_size * scale) 53 | 54 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 55 | raise ValueError( 56 | f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', 57 | f'multiplication of LQ ({h_lq}, {w_lq}).') 58 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 59 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 60 | f'({lq_patch_size}, {lq_patch_size}). ' 61 | f'Please remove {gt_path}.') 62 | 63 | # randomly choose top and left coordinates for lq patch 64 | top = random.randint(0, h_lq - lq_patch_size) 65 | left = random.randint(0, w_lq - lq_patch_size) 66 | 67 | # crop lq patch 68 | img_lqs = [ 69 | v[top:top + lq_patch_size, left:left + lq_patch_size, ...] 
70 | for v in img_lqs 71 | ] 72 | 73 | # crop corresponding gt patch 74 | top_gt, left_gt = int(top * scale), int(left * scale) 75 | img_gts = [ 76 | v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] 77 | for v in img_gts 78 | ] 79 | if len(img_gts) == 1: 80 | img_gts = img_gts[0] 81 | if len(img_lqs) == 1: 82 | img_lqs = img_lqs[0] 83 | return img_gts, img_lqs 84 | 85 | def four_random_crop(img_gts, img_lqs,flow,flow_gt, lq_patch_size, scale, gt_path): 86 | """Paired random crop. 87 | 88 | It crops lists of lq and gt images with corresponding locations. 89 | 90 | Args: 91 | img_gts (list[ndarray] | ndarray): GT images. Note that all images 92 | should have the same shape. If the input is an ndarray, it will 93 | be transformed to a list containing itself. 94 | img_lqs (list[ndarray] | ndarray): LQ images. Note that all images 95 | should have the same shape. If the input is an ndarray, it will 96 | be transformed to a list containing itself. 97 | lq_patch_size (int): LQ patch size. 98 | scale (int): Scale factor. 99 | gt_path (str): Path to ground-truth. 100 | 101 | Returns: 102 | list[ndarray] | ndarray: GT images and LQ images. If returned results 103 | only have one element, just return ndarray. 104 | """ 105 | 106 | if not isinstance(img_gts, list): 107 | img_gts = [img_gts] 108 | if not isinstance(img_lqs, list): 109 | img_lqs = [img_lqs] 110 | # if not isinstance(img_gts, list): 111 | flow = [flow] 112 | # if not isinstance(img_lqs, list): 113 | flow_gt = [flow_gt] 114 | 115 | h_lq, w_lq, _ = img_lqs[0].shape 116 | h_gt, w_gt, _ = img_gts[0].shape 117 | gt_patch_size = int(lq_patch_size * scale) 118 | 119 | if h_gt != h_lq * scale or w_gt != w_lq * scale: 120 | raise ValueError( 121 | f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', 122 | f'multiplication of LQ ({h_lq}, {w_lq}).') 123 | if h_lq < lq_patch_size or w_lq < lq_patch_size: 124 | raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size ' 125 | f'({lq_patch_size}, {lq_patch_size}). ' 126 | f'Please remove {gt_path}.') 127 | 128 | # randomly choose top and left coordinates for lq patch 129 | top = random.randint(0, h_lq - lq_patch_size) 130 | left = random.randint(0, w_lq - lq_patch_size) 131 | 132 | # crop lq patch 133 | img_lqs = [ 134 | v[top:top + lq_patch_size, left:left + lq_patch_size, ...] 135 | for v in img_lqs 136 | ] 137 | 138 | flow = [ 139 | v[top:top + lq_patch_size, left:left + lq_patch_size, ...] 140 | for v in flow 141 | ] 142 | 143 | # crop corresponding gt patch 144 | top_gt, left_gt = int(top * scale), int(left * scale) 145 | img_gts = [ 146 | v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] 147 | for v in img_gts 148 | ] 149 | 150 | flow_gt = [ 151 | v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] 
152 |         for v in flow_gt
153 |     ]
154 |     if len(img_gts) == 1:
155 |         img_gts = img_gts[0]
156 |     if len(img_lqs) == 1:
157 |         img_lqs = img_lqs[0]
158 |
159 |     if len(flow) == 1:
160 |         flow = flow[0]
161 |     if len(flow_gt) == 1:
162 |         flow_gt = flow_gt[0]
163 |     return img_gts, img_lqs, flow, flow_gt
164 |
165 | def paired_random_crop_DP(img_lqLs, img_lqRs, img_gts, gt_patch_size, scale, gt_path):
166 |     if not isinstance(img_gts, list):
167 |         img_gts = [img_gts]
168 |     if not isinstance(img_lqLs, list):
169 |         img_lqLs = [img_lqLs]
170 |     if not isinstance(img_lqRs, list):
171 |         img_lqRs = [img_lqRs]
172 |
173 |     h_lq, w_lq, _ = img_lqLs[0].shape
174 |     h_gt, w_gt, _ = img_gts[0].shape
175 |     lq_patch_size = gt_patch_size // scale
176 |
177 |     if h_gt != h_lq * scale or w_gt != w_lq * scale:
178 |         raise ValueError(
179 |             f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
180 |             f'multiplication of LQ ({h_lq}, {w_lq}).')
181 |     if h_lq < lq_patch_size or w_lq < lq_patch_size:
182 |         raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
183 |                          f'({lq_patch_size}, {lq_patch_size}). '
184 |                          f'Please remove {gt_path}.')
185 |
186 |     # randomly choose top and left coordinates for lq patch
187 |     top = random.randint(0, h_lq - lq_patch_size)
188 |     left = random.randint(0, w_lq - lq_patch_size)
189 |
190 |     # crop lq patch
191 |     img_lqLs = [
192 |         v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
193 |         for v in img_lqLs
194 |     ]
195 |
196 |     img_lqRs = [
197 |         v[top:top + lq_patch_size, left:left + lq_patch_size, ...]
198 |         for v in img_lqRs
199 |     ]
200 |
201 |     # crop corresponding gt patch
202 |     top_gt, left_gt = int(top * scale), int(left * scale)
203 |     img_gts = [
204 |         v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...]
205 |         for v in img_gts
206 |     ]
207 |     if len(img_gts) == 1:
208 |         img_gts = img_gts[0]
209 |     if len(img_lqLs) == 1:
210 |         img_lqLs = img_lqLs[0]
211 |     if len(img_lqRs) == 1:
212 |         img_lqRs = img_lqRs[0]
213 |     return img_lqLs, img_lqRs, img_gts
214 |
215 |
216 | def augment(imgs, hflip=True, rotation=True, flows=None, return_status=False):
217 |     """Augment: horizontal flips OR rotate (0, 90, 180, 270 degrees).
218 |
219 |     We use vertical flip and transpose for rotation implementation.
220 |     All the images in the list use the same augmentation.
221 |
222 |     Args:
223 |         imgs (list[ndarray] | ndarray): Images to be augmented. If the input
224 |             is an ndarray, it will be transformed to a list.
225 |         hflip (bool): Horizontal flip. Default: True.
226 |         rotation (bool): Rotation. Default: True.
227 |         flows (list[ndarray]): Flows to be augmented. If the input is an
228 |             ndarray, it will be transformed to a list.
229 |             Dimension is (h, w, 2). Default: None.
230 |         return_status (bool): Return the status of flip and rotation.
231 |             Default: False.
232 |
233 |     Returns:
234 |         list[ndarray] | ndarray: Augmented images and flows. If returned
235 |             results only have one element, just return ndarray.
236 | 237 | """ 238 | hflip = hflip and random.random() < 0.5 239 | vflip = rotation and random.random() < 0.5 240 | rot90 = rotation and random.random() < 0.5 241 | 242 | def _augment(img): 243 | if hflip: # horizontal 244 | cv2.flip(img, 1, img) 245 | if vflip: # vertical 246 | cv2.flip(img, 0, img) 247 | if rot90: 248 | img = img.transpose(1, 0, 2) 249 | return img 250 | 251 | def _augment_flow(flow): 252 | if hflip: # horizontal 253 | cv2.flip(flow, 1, flow) 254 | flow[:, :, 0] *= -1 255 | if vflip: # vertical 256 | cv2.flip(flow, 0, flow) 257 | flow[:, :, 1] *= -1 258 | if rot90: 259 | flow = flow.transpose(1, 0, 2) 260 | flow = flow[:, :, [1, 0]] 261 | return flow 262 | 263 | if not isinstance(imgs, list): 264 | imgs = [imgs] 265 | imgs = [_augment(img) for img in imgs] 266 | if len(imgs) == 1: 267 | imgs = imgs[0] 268 | 269 | if flows is not None: 270 | if not isinstance(flows, list): 271 | flows = [flows] 272 | flows = [_augment_flow(flow) for flow in flows] 273 | if len(flows) == 1: 274 | flows = flows[0] 275 | return imgs, flows 276 | else: 277 | if return_status: 278 | return imgs, (hflip, vflip, rot90) 279 | else: 280 | return imgs 281 | 282 | 283 | def img_rotate(img, angle, center=None, scale=1.0): 284 | """Rotate image. 285 | 286 | Args: 287 | img (ndarray): Image to be rotated. 288 | angle (float): Rotation angle in degrees. Positive values mean 289 | counter-clockwise rotation. 290 | center (tuple[int]): Rotation center. If the center is None, 291 | initialize it as the center of the image. Default: None. 292 | scale (float): Isotropic scale factor. Default: 1.0. 293 | """ 294 | (h, w) = img.shape[:2] 295 | 296 | if center is None: 297 | center = (w // 2, h // 2) 298 | 299 | matrix = cv2.getRotationMatrix2D(center, angle, scale) 300 | rotated_img = cv2.warpAffine(img, matrix, (w, h)) 301 | return rotated_img 302 | 303 | def data_augmentation(image, mode): 304 | """ 305 | Performs data augmentation of the input image 306 | Input: 307 | image: a cv2 (OpenCV) image 308 | mode: int. 
Choice of transformation to apply to the image
309 |             0 - no transformation
310 |             1 - flip up and down
311 |             2 - rotate counterclockwise 90 degrees
312 |             3 - rotate 90 degrees and flip up and down
313 |             4 - rotate 180 degrees
314 |             5 - rotate 180 degrees and flip
315 |             6 - rotate 270 degrees
316 |             7 - rotate 270 degrees and flip
317 |     """
318 |     if mode == 0:
319 |         # original
320 |         out = image
321 |     elif mode == 1:
322 |         # flip up and down
323 |         out = np.flipud(image)
324 |     elif mode == 2:
325 |         # rotate counterclockwise 90 degrees
326 |         out = np.rot90(image)
327 |     elif mode == 3:
328 |         # rotate 90 degrees and flip up and down
329 |         out = np.rot90(image)
330 |         out = np.flipud(out)
331 |     elif mode == 4:
332 |         # rotate 180 degrees
333 |         out = np.rot90(image, k=2)
334 |     elif mode == 5:
335 |         # rotate 180 degrees and flip
336 |         out = np.rot90(image, k=2)
337 |         out = np.flipud(out)
338 |     elif mode == 6:
339 |         # rotate 270 degrees
340 |         out = np.rot90(image, k=3)
341 |     elif mode == 7:
342 |         # rotate 270 degrees and flip
343 |         out = np.rot90(image, k=3)
344 |         out = np.flipud(out)
345 |     else:
346 |         raise Exception('Invalid choice of image transformation')
347 |
348 |     return out
349 |
350 | def random_augmentation(*args):
351 |     out = []
352 |     flag_aug = random.randint(0, 7)
353 |     for data in args:
354 |         out.append(data_augmentation(data, flag_aug).copy())
355 |     return out
356 |
-------------------------------------------------------------------------------- /basicsr/metrics/__init__.py: --------------------------------------------------------------------------------
1 | from .psnr_ssim import calculate_psnr, calculate_ssim
2 | from .gmsd import calculate_gmsd
3 |
4 | __all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_gmsd']
5 |
-------------------------------------------------------------------------------- /basicsr/metrics/gmsd.py: --------------------------------------------------------------------------------
1 | r"""GMSD Metric
2 |
3 | Created by: https://github.com/dingkeyan93/IQA-optimization/blob/master/IQA_pytorch/GMSD.py
4 |
5 | Modified by: Jiadi Mo (https://github.com/JiadiMo)
6 |
7 | Refer to:
8 |     Matlab code from https://www4.comp.polyu.edu.hk/~cslzhang/IQA/GMSD/GMSD.m;
9 |
10 | """
11 |
12 | import torch
13 | from torch import nn
14 | from torch.nn import functional as F
15 |
16 | # from pyiqa.utils.color_util import to_y_channel
17 | # from pyiqa.utils.registry import ARCH_REGISTRY
18 |
19 |
20 | def calculate_gmsd(
21 |         x: torch.Tensor,
22 |         y: torch.Tensor,
23 |         T: int = 170,
24 |         channels: int = 1,  # 3
25 |         test_y_channel: bool = False,  # True
26 | ) -> torch.Tensor:
27 |     r"""GMSD metric.
28 |     Args:
29 |         x: A distortion tensor. Shape :math:`(N, C, H, W)`.
30 |         y: A reference tensor. Shape :math:`(N, C, H, W)`.
31 |         T: A positive constant that supplies numerical stability.
32 |         channels: Number of channels.
33 |         test_y_channel: bool, whether to use y channel on ycbcr.
34 |     """
35 |     if test_y_channel:
36 |         x = to_y_channel(x, 255)  # requires the pyiqa color_util import commented out above; unused with the default test_y_channel=False
37 |         y = to_y_channel(y, 255)
38 |         channels = 1
39 |     else:
40 |         x = x * 255.
41 |         y = y * 255.
42 |
43 |     dx = (torch.Tensor([[1, 0, -1], [1, 0, -1], [1, 0, -1]]) / 3.).unsqueeze(0).unsqueeze(0).repeat(channels, 1, 1,
44 |                                                                                                     1).to(x)
45 |     dy = (torch.Tensor([[1, 1, 1], [0, 0, 0], [-1, -1, -1]]) / 3.).unsqueeze(0).unsqueeze(0).repeat(channels, 1, 1,
46 |                                                                                                     1).to(x)
47 |     aveKernel = torch.ones(channels, 1, 2, 2).to(x) / 4.
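    # NOTE: aveKernel implements 2x2 average pooling as a grouped convolution; applied with
    # stride 2 below, it downsamples both inputs, mirroring the reference MATLAB GMSD code.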
48 |
49 |     Y1 = F.conv2d(x, aveKernel, stride=2, padding=0, groups=channels)
50 |     Y2 = F.conv2d(y, aveKernel, stride=2, padding=0, groups=channels)
51 |
52 |     IxY1 = F.conv2d(Y1, dx, stride=1, padding=1, groups=channels)
53 |     IyY1 = F.conv2d(Y1, dy, stride=1, padding=1, groups=channels)
54 |     gradientMap1 = torch.sqrt(IxY1**2 + IyY1**2 + 1e-12)
55 |
56 |     IxY2 = F.conv2d(Y2, dx, stride=1, padding=1, groups=channels)
57 |     IyY2 = F.conv2d(Y2, dy, stride=1, padding=1, groups=channels)
58 |     gradientMap2 = torch.sqrt(IxY2**2 + IyY2**2 + 1e-12)
59 |
60 |     quality_map = (2 * gradientMap1 * gradientMap2 + T) / (gradientMap1**2 + gradientMap2**2 + T)
61 |     score = torch.std(quality_map.view(quality_map.shape[0], -1), dim=1)
62 |
63 |     return score.item()
64 |
65 |
66 | # @ARCH_REGISTRY.register()
67 | class GMSD(nn.Module):
68 |     r'''Gradient Magnitude Similarity Deviation Metric.
69 |     Args:
70 |         channels: Number of channels.
71 |         test_y_channel: bool, whether to use y channel on ycbcr.
72 |     Reference:
73 |         Xue, Wufeng, Lei Zhang, Xuanqin Mou, and Alan C. Bovik.
74 |         "Gradient magnitude similarity deviation: A highly efficient
75 |         perceptual image quality index." IEEE Transactions on Image
76 |         Processing 23, no. 2 (2013): 684-695.
77 |     '''
78 |
79 |     def __init__(self, channels: int = 3, test_y_channel: bool = True) -> None:
80 |         super(GMSD, self).__init__()
81 |         self.channels = channels
82 |         self.test_y_channel = test_y_channel
83 |
84 |     def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
85 |         r"""Args:
86 |             x: A distortion tensor. Shape :math:`(N, C, H, W)`.
87 |             y: A reference tensor. Shape :math:`(N, C, H, W)`.
88 |             Order of input is important.
89 |         """
90 |         assert x.shape == y.shape, f'Input and reference images should have the same shape, but got {x.shape} and {y.shape}'
91 |         score = calculate_gmsd(x, y, channels=self.channels, test_y_channel=self.test_y_channel)
92 |
93 |         return score
-------------------------------------------------------------------------------- /basicsr/metrics/metric_util.py: --------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | from basicsr.utils.matlab_functions import bgr2ycbcr
4 |
5 |
6 | def reorder_image(img, input_order='HWC'):
7 |     """Reorder images to 'HWC' order.
8 |
9 |     If the input_order is (h, w), return (h, w, 1);
10 |     If the input_order is (c, h, w), return (h, w, c);
11 |     If the input_order is (h, w, c), return as it is.
12 |
13 |     Args:
14 |         img (ndarray): Input image.
15 |         input_order (str): Whether the input order is 'HWC' or 'CHW'.
16 |             If the input image shape is (h, w), input_order will not have
17 |             effects. Default: 'HWC'.
18 |
19 |     Returns:
20 |         ndarray: reordered image.
21 |     """
22 |
23 |     if input_order not in ['HWC', 'CHW']:
24 |         raise ValueError(
25 |             f'Wrong input_order {input_order}. Supported input_orders are '
26 |             "'HWC' and 'CHW'")
27 |     if len(img.shape) == 2:
28 |         img = img[..., None]
29 |     if input_order == 'CHW':
30 |         img = img.transpose(1, 2, 0)
31 |     return img
32 |
33 |
34 | def to_y_channel(img):
35 |     """Change to Y channel of YCbCr.
36 |
37 |     Args:
38 |         img (ndarray): Images with range [0, 255].
39 |
40 |     Returns:
41 |         (ndarray): Images with range [0, 255] (float type) without round.
42 |     """
43 |     img = img.astype(np.float32) / 255.
44 |     if img.ndim == 3 and img.shape[2] == 3:
45 |         img = bgr2ycbcr(img, y_only=True)
46 |         img = img[..., None]
47 |     return img * 255.
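
# Minimal usage sketch (illustrative shapes only):
#   >>> import numpy as np
#   >>> reorder_image(np.zeros((3, 64, 64)), input_order='CHW').shape
#   (64, 64, 3)
#   >>> to_y_channel(np.zeros((8, 8, 3), dtype=np.uint8)).shape
#   (8, 8, 1)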
48 |
-------------------------------------------------------------------------------- /basicsr/metrics/psnr_ssim.py: --------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 |
4 | from basicsr.metrics.metric_util import reorder_image, to_y_channel
5 | import skimage.metrics
6 | import torch
7 |
8 |
9 | def calculate_psnr(img1,
10 |                    img2,
11 |                    crop_border,
12 |                    input_order='HWC',
13 |                    test_y_channel=False):
14 |     """Calculate PSNR (Peak Signal-to-Noise Ratio).
15 |
16 |     Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
17 |
18 |     Args:
19 |         img1 (ndarray/tensor): Images with range [0, 255]/[0, 1].
20 |         img2 (ndarray/tensor): Images with range [0, 255]/[0, 1].
21 |         crop_border (int): Cropped pixels in each edge of an image. These
22 |             pixels are not involved in the PSNR calculation.
23 |         input_order (str): Whether the input order is 'HWC' or 'CHW'.
24 |             Default: 'HWC'.
25 |         test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
26 |
27 |     Returns:
28 |         float: psnr result.
29 |     """
30 |
31 |     assert img1.shape == img2.shape, (
32 |         f'Image shapes are different: {img1.shape}, {img2.shape}.')
33 |     if input_order not in ['HWC', 'CHW']:
34 |         raise ValueError(
35 |             f'Wrong input_order {input_order}. Supported input_orders are '
36 |             '"HWC" and "CHW"')
37 |     if type(img1) == torch.Tensor:
38 |         if len(img1.shape) == 4:
39 |             img1 = img1.squeeze(0)
40 |         img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
41 |     if type(img2) == torch.Tensor:
42 |         if len(img2.shape) == 4:
43 |             img2 = img2.squeeze(0)
44 |         img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)
45 |
46 |     img1 = reorder_image(img1, input_order=input_order)
47 |     img2 = reorder_image(img2, input_order=input_order)
48 |     img1 = img1.astype(np.float64)
49 |     img2 = img2.astype(np.float64)
50 |
51 |     if crop_border != 0:
52 |         img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
53 |         img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
54 |
55 |     if test_y_channel:
56 |         img1 = to_y_channel(img1)
57 |         img2 = to_y_channel(img2)
58 |
59 |     mse = np.mean((img1 - img2)**2)
60 |     if mse == 0:
61 |         return float('inf')
62 |     max_value = 1. if img1.max() <= 1 else 255.
63 |     return 20. * np.log10(max_value / np.sqrt(mse))
64 |
65 |
66 | def _ssim(img1, img2):
67 |     """Calculate SSIM (structural similarity) for one channel images.
68 |
69 |     It is called by func:`calculate_ssim`.
70 |
71 |     Args:
72 |         img1 (ndarray): Images with range [0, 255] with order 'HWC'.
73 |         img2 (ndarray): Images with range [0, 255] with order 'HWC'.
74 |
75 |     Returns:
76 |         float: ssim result.
77 | """ 78 | 79 | C1 = (0.01 * 255)**2 80 | C2 = (0.03 * 255)**2 81 | 82 | img1 = img1.astype(np.float64) 83 | img2 = img2.astype(np.float64) 84 | kernel = cv2.getGaussianKernel(11, 1.5) 85 | window = np.outer(kernel, kernel.transpose()) 86 | 87 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 88 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 89 | mu1_sq = mu1**2 90 | mu2_sq = mu2**2 91 | mu1_mu2 = mu1 * mu2 92 | sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq 93 | sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq 94 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 95 | 96 | ssim_map = ((2 * mu1_mu2 + C1) * 97 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 98 | (sigma1_sq + sigma2_sq + C2)) 99 | return ssim_map.mean() 100 | 101 | def prepare_for_ssim(img, k): 102 | import torch 103 | with torch.no_grad(): 104 | img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float() 105 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k//2, padding_mode='reflect') 106 | conv.weight.requires_grad = False 107 | conv.weight[:, :, :, :] = 1. / (k * k) 108 | 109 | img = conv(img) 110 | 111 | img = img.squeeze(0).squeeze(0) 112 | img = img[0::k, 0::k] 113 | return img.detach().cpu().numpy() 114 | 115 | def prepare_for_ssim_rgb(img, k): 116 | import torch 117 | with torch.no_grad(): 118 | img = torch.from_numpy(img).float() #HxWx3 119 | 120 | conv = torch.nn.Conv2d(1, 1, k, stride=1, padding=k // 2, padding_mode='reflect') 121 | conv.weight.requires_grad = False 122 | conv.weight[:, :, :, :] = 1. / (k * k) 123 | 124 | new_img = [] 125 | 126 | for i in range(3): 127 | new_img.append(conv(img[:, :, i].unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)[0::k, 0::k]) 128 | 129 | return torch.stack(new_img, dim=2).detach().cpu().numpy() 130 | 131 | def _3d_gaussian_calculator(img, conv3d): 132 | out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0) 133 | return out 134 | 135 | def _generate_3d_gaussian_kernel(): 136 | kernel = cv2.getGaussianKernel(11, 1.5) 137 | window = np.outer(kernel, kernel.transpose()) 138 | kernel_3 = cv2.getGaussianKernel(11, 1.5) 139 | kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0)) 140 | conv3d = torch.nn.Conv3d(1, 1, (11, 11, 11), stride=1, padding=(5, 5, 5), bias=False, padding_mode='replicate') 141 | conv3d.weight.requires_grad = False 142 | conv3d.weight[0, 0, :, :, :] = kernel 143 | return conv3d 144 | 145 | def _ssim_3d(img1, img2, max_value): 146 | assert len(img1.shape) == 3 and len(img2.shape) == 3 147 | """Calculate SSIM (structural similarity) for one channel images. 148 | 149 | It is called by func:`calculate_ssim`. 150 | 151 | Args: 152 | img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. 153 | img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'. 154 | 155 | Returns: 156 | float: ssim result. 
157 | """ 158 | C1 = (0.01 * max_value) ** 2 159 | C2 = (0.03 * max_value) ** 2 160 | img1 = img1.astype(np.float64) 161 | img2 = img2.astype(np.float64) 162 | 163 | kernel = _generate_3d_gaussian_kernel().cuda() 164 | 165 | img1 = torch.tensor(img1).float().cuda() 166 | img2 = torch.tensor(img2).float().cuda() 167 | 168 | 169 | mu1 = _3d_gaussian_calculator(img1, kernel) 170 | mu2 = _3d_gaussian_calculator(img2, kernel) 171 | 172 | mu1_sq = mu1 ** 2 173 | mu2_sq = mu2 ** 2 174 | mu1_mu2 = mu1 * mu2 175 | sigma1_sq = _3d_gaussian_calculator(img1 ** 2, kernel) - mu1_sq 176 | sigma2_sq = _3d_gaussian_calculator(img2 ** 2, kernel) - mu2_sq 177 | sigma12 = _3d_gaussian_calculator(img1*img2, kernel) - mu1_mu2 178 | 179 | ssim_map = ((2 * mu1_mu2 + C1) * 180 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 181 | (sigma1_sq + sigma2_sq + C2)) 182 | return float(ssim_map.mean()) 183 | 184 | def _ssim_cly(img1, img2): 185 | assert len(img1.shape) == 2 and len(img2.shape) == 2 186 | """Calculate SSIM (structural similarity) for one channel images. 187 | 188 | It is called by func:`calculate_ssim`. 189 | 190 | Args: 191 | img1 (ndarray): Images with range [0, 255] with order 'HWC'. 192 | img2 (ndarray): Images with range [0, 255] with order 'HWC'. 193 | 194 | Returns: 195 | float: ssim result. 196 | """ 197 | 198 | C1 = (0.01 * 255)**2 199 | C2 = (0.03 * 255)**2 200 | img1 = img1.astype(np.float64) 201 | img2 = img2.astype(np.float64) 202 | 203 | kernel = cv2.getGaussianKernel(11, 1.5) 204 | # print(kernel) 205 | window = np.outer(kernel, kernel.transpose()) 206 | 207 | bt = cv2.BORDER_REPLICATE 208 | 209 | mu1 = cv2.filter2D(img1, -1, window, borderType=bt) 210 | mu2 = cv2.filter2D(img2, -1, window,borderType=bt) 211 | 212 | mu1_sq = mu1**2 213 | mu2_sq = mu2**2 214 | mu1_mu2 = mu1 * mu2 215 | sigma1_sq = cv2.filter2D(img1**2, -1, window, borderType=bt) - mu1_sq 216 | sigma2_sq = cv2.filter2D(img2**2, -1, window, borderType=bt) - mu2_sq 217 | sigma12 = cv2.filter2D(img1 * img2, -1, window, borderType=bt) - mu1_mu2 218 | 219 | ssim_map = ((2 * mu1_mu2 + C1) * 220 | (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * 221 | (sigma1_sq + sigma2_sq + C2)) 222 | return ssim_map.mean() 223 | 224 | 225 | def calculate_ssim(img1, 226 | img2, 227 | crop_border, 228 | input_order='HWC', 229 | test_y_channel=False): 230 | """Calculate SSIM (structural similarity). 231 | 232 | Ref: 233 | Image quality assessment: From error visibility to structural similarity 234 | 235 | The results are the same as that of the official released MATLAB code in 236 | https://ece.uwaterloo.ca/~z70wang/research/ssim/. 237 | 238 | For three-channel images, SSIM is calculated for each channel and then 239 | averaged. 240 | 241 | Args: 242 | img1 (ndarray): Images with range [0, 255]. 243 | img2 (ndarray): Images with range [0, 255]. 244 | crop_border (int): Cropped pixels in each edge of an image. These 245 | pixels are not involved in the SSIM calculation. 246 | input_order (str): Whether the input order is 'HWC' or 'CHW'. 247 | Default: 'HWC'. 248 | test_y_channel (bool): Test on Y channel of YCbCr. Default: False. 249 | 250 | Returns: 251 | float: ssim result. 252 | """ 253 | 254 | assert img1.shape == img2.shape, ( 255 | f'Image shapes are differnet: {img1.shape}, {img2.shape}.') 256 | if input_order not in ['HWC', 'CHW']: 257 | raise ValueError( 258 | f'Wrong input_order {input_order}. 
Supported input_orders are ' 260 | '"HWC" and "CHW"') 261 | 262 | if type(img1) == torch.Tensor: 263 | if len(img1.shape) == 4: 264 | img1 = img1.squeeze(0) 265 | img1 = img1.detach().cpu().numpy().transpose(1,2,0) 266 | if type(img2) == torch.Tensor: 267 | if len(img2.shape) == 4: 268 | img2 = img2.squeeze(0) 269 | img2 = img2.detach().cpu().numpy().transpose(1,2,0) 270 | 271 | img1 = reorder_image(img1, input_order=input_order) 272 | img2 = reorder_image(img2, input_order=input_order) 273 | 274 | img1 = img1.astype(np.float64) 275 | img2 = img2.astype(np.float64) 276 | 277 | if crop_border != 0: 278 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...] 279 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...] 280 | 281 | if test_y_channel: 282 | img1 = to_y_channel(img1) 283 | img2 = to_y_channel(img2) 284 | return _ssim_cly(img1[..., 0], img2[..., 0]) 285 | 286 | 287 | ssims = [] 288 | # ssims_before = [] 289 | 290 | # skimage_before = skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True) 291 | # print('.._skimage', 292 | # skimage.metrics.structural_similarity(img1, img2, data_range=255., multichannel=True)) 293 | max_value = 1 if img1.max() <= 1 else 255 294 | with torch.no_grad(): 295 | final_ssim = _ssim_3d(img1, img2, max_value) 296 | ssims.append(final_ssim) 297 | 298 | # for i in range(img1.shape[2]): 299 | # ssims_before.append(_ssim(img1, img2)) 300 | 301 | # print('..ssim mean , new {:.4f} and before {:.4f} .... skimage before {:.4f}'.format(np.array(ssims).mean(), np.array(ssims_before).mean(), skimage_before)) 302 | # ssims.append(skimage.metrics.structural_similarity(img1[..., i], img2[..., i], multichannel=False)) 303 | 304 | return np.array(ssims).mean() 305 | -------------------------------------------------------------------------------- /basicsr/models/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import get_root_logger, scandir 5 | 6 | # automatically scan and import model modules 7 | # scan all the files under the 'models' folder and collect files ending with 8 | # '_model.py' 9 | model_folder = osp.dirname(osp.abspath(__file__)) 10 | model_filenames = [ 11 | osp.splitext(osp.basename(v))[0] for v in scandir(model_folder) 12 | if v.endswith('_model.py') 13 | ] 14 | # import all the model modules 15 | _model_modules = [ 16 | importlib.import_module(f'basicsr.models.{file_name}') 17 | for file_name in model_filenames 18 | ] 19 | 20 | 21 | def create_model(opt): 22 | """Create model. 23 | 24 | Args: 25 | opt (dict): Configuration. It contains: 26 | model_type (str): Model type.
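        Returns:
            The instantiated model. Only classes defined in the *_model.py
            modules imported above can be resolved.

        Illustrative call (the concrete model_type comes from the option YAML;
        ImageRestorationModel is assumed here from image_restoration_model.py):
            >>> model = create_model({'model_type': 'ImageRestorationModel', ...})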
27 | """ 28 | model_type = opt['model_type'] 29 | 30 | # dynamic instantiation 31 | for module in _model_modules: 32 | model_cls = getattr(module, model_type, None) 33 | if model_cls is not None: 34 | break 35 | if model_cls is None: 36 | raise ValueError(f'Model {model_type} is not found.') 37 | 38 | model = model_cls(opt) 39 | 40 | logger = get_root_logger() 41 | logger.info(f'Model [{model.__class__.__name__}] is created.') 42 | return model 43 | -------------------------------------------------------------------------------- /basicsr/models/archs/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | from os import path as osp 3 | 4 | from basicsr.utils import scandir 5 | 6 | # automatically scan and import arch modules 7 | # scan all the files under the 'archs' folder and collect files ending with 8 | # '_arch.py' 9 | arch_folder = osp.dirname(osp.abspath(__file__)) 10 | arch_filenames = [ 11 | osp.splitext(osp.basename(v))[0] for v in scandir(arch_folder) 12 | if v.endswith('_arch.py') 13 | ] 14 | # import all the arch modules 15 | _arch_modules = [ 16 | importlib.import_module(f'basicsr.models.archs.{file_name}') 17 | for file_name in arch_filenames 18 | ] 19 | 20 | 21 | def dynamic_instantiation(modules, cls_type, opt): 22 | """Dynamically instantiate class. 23 | 24 | Args: 25 | modules (list[importlib modules]): List of modules from importlib 26 | files. 27 | cls_type (str): Class type. 28 | opt (dict): Class initialization kwargs. 29 | 30 | Returns: 31 | class: Instantiated class. 32 | """ 33 | 34 | for module in modules: 35 | cls_ = getattr(module, cls_type, None) 36 | if cls_ is not None: 37 | break 38 | if cls_ is None: 39 | raise ValueError(f'{cls_type} is not found.') 40 | return cls_(**opt) 41 | 42 | 43 | def define_network(opt): 44 | network_type = opt.pop('type') 45 | net = dynamic_instantiation(_arch_modules, network_type, opt) 46 | return net 47 | -------------------------------------------------------------------------------- /basicsr/models/archs/nafnet_arch.py: -------------------------------------------------------------------------------- 1 | # ------------------------------------------------------------------------ 2 | # Copyright (c) 2022 megvii-model. All Rights Reserved. 
3 | # ------------------------------------------------------------------------ 4 | 5 | ''' 6 | Simple Baselines for Image Restoration 7 | 8 | @article{chen2022simple, 9 | title={Simple Baselines for Image Restoration}, 10 | author={Chen, Liangyu and Chu, Xiaojie and Zhang, Xiangyu and Sun, Jian}, 11 | journal={arXiv preprint arXiv:2204.04676}, 12 | year={2022} 13 | } 14 | ''' 15 | 16 | import torch 17 | import torch.nn as nn 18 | import numpy as np 19 | import torch.nn.functional as F 20 | # from basicsr.models.archs.arch_util import LayerNorm2d 21 | # from basicsr.models.archs.local_arch import Local_Base 22 | 23 | 24 | class LayerNormFunction(torch.autograd.Function): 25 | 26 | @staticmethod 27 | def forward(ctx, x, weight, bias, eps): 28 | ctx.eps = eps 29 | N, C, H, W = x.size() 30 | mu = x.mean(1, keepdim=True) 31 | var = (x - mu).pow(2).mean(1, keepdim=True) 32 | y = (x - mu) / (var + eps).sqrt() 33 | ctx.save_for_backward(y, var, weight) 34 | y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1) 35 | return y 36 | 37 | @staticmethod 38 | def backward(ctx, grad_output): 39 | eps = ctx.eps 40 | 41 | N, C, H, W = grad_output.size() 42 | y, var, weight = ctx.saved_variables 43 | g = grad_output * weight.view(1, C, 1, 1) 44 | mean_g = g.mean(dim=1, keepdim=True) 45 | 46 | mean_gy = (g * y).mean(dim=1, keepdim=True) 47 | gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g) 48 | return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum( 49 | dim=0), None 50 | 51 | class LayerNorm2d(nn.Module): 52 | 53 | def __init__(self, channels, eps=1e-6): 54 | super(LayerNorm2d, self).__init__() 55 | self.register_parameter('weight', nn.Parameter(torch.ones(channels))) 56 | self.register_parameter('bias', nn.Parameter(torch.zeros(channels))) 57 | self.eps = eps 58 | 59 | def forward(self, x): 60 | return LayerNormFunction.apply(x, self.weight, self.bias, self.eps) 61 | 62 | 63 | 64 | class AvgPool2d(nn.Module): 65 | def __init__(self, kernel_size=None, base_size=None, auto_pad=True, fast_imp=False, train_size=None): 66 | super().__init__() 67 | self.kernel_size = kernel_size 68 | self.base_size = base_size 69 | self.auto_pad = auto_pad 70 | 71 | # only used for fast implementation 72 | self.fast_imp = fast_imp 73 | self.rs = [5, 4, 3, 2, 1] 74 | self.max_r1 = self.rs[0] 75 | self.max_r2 = self.rs[0] 76 | self.train_size = train_size 77 | 78 | def extra_repr(self): 79 | return 'kernel_size={}, base_size={}, stride={}, fast_imp={}'.format( 80 | self.kernel_size, self.base_size, self.kernel_size, self.fast_imp) 81 | 82 | def forward(self, x): 83 | if self.kernel_size is None and self.base_size: 84 | train_size = self.train_size 85 | if isinstance(self.base_size, int): 86 | self.base_size = (self.base_size, self.base_size) 87 | self.kernel_size = list(self.base_size) 88 | self.kernel_size[0] = x.shape[2] * self.base_size[0] // train_size[-2] # self.base_size[0] 89 | self.kernel_size[1] = x.shape[3] * self.base_size[1] // train_size[-1] # self.base_size[1] 90 | #print(self.kernel_size, self.base_size, train_size) 91 | # only used for fast implementation 92 | self.max_r1 = max(1, self.rs[0] * x.shape[2] // train_size[-2]) 93 | self.max_r2 = max(1, self.rs[0] * x.shape[3] // train_size[-1]) 94 | 95 | #print(self.extra_repr()) 96 | if self.kernel_size[0] >= x.size(-2) and self.kernel_size[1] >= x.size(-1): 97 | #print("Yes~") 98 | return F.adaptive_avg_pool2d(x, 1) 99 | 100 | if self.fast_imp: # Non-equivalent implementation but faster 101 | h, w = 
x.shape[2:] 102 | if self.kernel_size[0] >= h and self.kernel_size[1] >= w: 103 | out = F.adaptive_avg_pool2d(x, 1) 104 | else: 105 | r1 = [r for r in self.rs if h % r == 0][0] 106 | r2 = [r for r in self.rs if w % r == 0][0] 107 | # reduction_constraint 108 | r1 = min(self.max_r1, r1) 109 | r2 = min(self.max_r2, r2) 110 | s = x[:, :, ::r1, ::r2].cumsum(dim=-1).cumsum(dim=-2) 111 | n, c, h, w = s.shape 112 | k1, k2 = min(h - 1, self.kernel_size[0] // r1), min(w - 1, self.kernel_size[1] // r2) 113 | out = (s[:, :, :-k1, :-k2] - s[:, :, :-k1, k2:] - s[:, :, k1:, :-k2] + s[:, :, k1:, k2:]) / (k1 * k2) 114 | out = torch.nn.functional.interpolate(out, scale_factor=(r1, r2)) 115 | else: 116 | n, c, h, w = x.shape 117 | s = x.cumsum(dim=-1).cumsum_(dim=-2) 118 | s = torch.nn.functional.pad(s, (1, 0, 1, 0)) # pad 0 for convenience 119 | k1, k2 = min(h, self.kernel_size[0]), min(w, self.kernel_size[1]) 120 | s1, s2, s3, s4 = s[:, :, :-k1, :-k2], s[:, :, :-k1, k2:], s[:, :, k1:, :-k2], s[:, :, k1:, k2:] 121 | out = s4 + s1 - s2 - s3 122 | out = out / (k1 * k2) 123 | 124 | if self.auto_pad: 125 | n, c, h, w = x.shape 126 | _h, _w = out.shape[2:] 127 | # print(x.shape, self.kernel_size) 128 | pad2d = ((w - _w) // 2, (w - _w + 1) // 2, (h - _h) // 2, (h - _h + 1) // 2) 129 | out = torch.nn.functional.pad(out, pad2d, mode='replicate') 130 | 131 | return out 132 | 133 | def replace_layers(model, base_size, train_size, fast_imp, **kwargs): 134 | for n, m in model.named_children(): 135 | if len(list(m.children())) > 0: 136 | ## compound module, go inside it 137 | replace_layers(m, base_size, train_size, fast_imp, **kwargs) 138 | 139 | if isinstance(m, nn.AdaptiveAvgPool2d): 140 | pool = AvgPool2d(base_size=base_size, fast_imp=fast_imp, train_size=train_size) 141 | assert m.output_size == 1 142 | setattr(model, n, pool) 143 | 144 | 145 | class Local_Base(): 146 | def convert(self, *args, train_size, **kwargs): 147 | replace_layers(self, *args, train_size=train_size, **kwargs) 148 | imgs = torch.rand(train_size) 149 | with torch.no_grad(): 150 | self.forward(imgs) 151 | 152 | class SimpleGate(nn.Module): 153 | def forward(self, x): 154 | x1, x2 = x.chunk(2, dim=1) 155 | return x1 * x2 156 | 157 | class NAFBlock(nn.Module): 158 | def __init__(self, c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.): 159 | super().__init__() 160 | dw_channel = c * DW_Expand 161 | self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 162 | self.conv2 = nn.Conv2d(in_channels=dw_channel, out_channels=dw_channel, kernel_size=3, padding=1, stride=1, groups=dw_channel, 163 | bias=True) 164 | self.conv3 = nn.Conv2d(in_channels=dw_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 165 | 166 | # Simplified Channel Attention 167 | self.sca = nn.Sequential( 168 | nn.AdaptiveAvgPool2d(1), 169 | nn.Conv2d(in_channels=dw_channel // 2, out_channels=dw_channel // 2, kernel_size=1, padding=0, stride=1, 170 | groups=1, bias=True), 171 | ) 172 | 173 | # SimpleGate 174 | self.sg = SimpleGate() 175 | 176 | ffn_channel = FFN_Expand * c 177 | self.conv4 = nn.Conv2d(in_channels=c, out_channels=ffn_channel, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 178 | self.conv5 = nn.Conv2d(in_channels=ffn_channel // 2, out_channels=c, kernel_size=1, padding=0, stride=1, groups=1, bias=True) 179 | 180 | self.norm1 = LayerNorm2d(c) 181 | self.norm2 = LayerNorm2d(c) 182 | 183 | self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. 
else nn.Identity() 184 | self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity() 185 | 186 | self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 187 | self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True) 188 | 189 | def forward(self, inp): 190 | x = inp 191 | 192 | x = self.norm1(x) 193 | 194 | x = self.conv1(x) 195 | x = self.conv2(x) 196 | x = self.sg(x) 197 | #print('--',x.shape, self.sca(x).shape) 198 | x = x * self.sca(x) 199 | x = self.conv3(x) 200 | 201 | x = self.dropout1(x) 202 | 203 | y = inp + x * self.beta 204 | 205 | x = self.conv4(self.norm2(y)) 206 | x = self.sg(x) 207 | x = self.conv5(x) 208 | 209 | x = self.dropout2(x) 210 | 211 | return y + x * self.gamma 212 | 213 | 214 | class NAFNet(nn.Module): 215 | 216 | def __init__(self, img_channel=3, width=16, middle_blk_num=1, enc_blk_nums=[], dec_blk_nums=[]): 217 | super().__init__() 218 | 219 | self.intro = nn.Conv2d(in_channels=img_channel, out_channels=width, kernel_size=3, padding=1, stride=1, groups=1, 220 | bias=True) 221 | self.ending = nn.Conv2d(in_channels=width, out_channels=img_channel, kernel_size=3, padding=1, stride=1, groups=1, 222 | bias=True) 223 | 224 | self.encoders = nn.ModuleList() 225 | self.decoders = nn.ModuleList() 226 | self.middle_blks = nn.ModuleList() 227 | self.ups = nn.ModuleList() 228 | self.downs = nn.ModuleList() 229 | 230 | chan = width 231 | for num in enc_blk_nums: 232 | self.encoders.append( 233 | nn.Sequential( 234 | *[NAFBlock(chan) for _ in range(num)] 235 | ) 236 | ) 237 | self.downs.append( 238 | nn.Conv2d(chan, 2*chan, 2, 2) 239 | ) 240 | chan = chan * 2 241 | 242 | self.middle_blks = \ 243 | nn.Sequential( 244 | *[NAFBlock(chan) for _ in range(middle_blk_num)] 245 | ) 246 | 247 | for num in dec_blk_nums: 248 | self.ups.append( 249 | nn.Sequential( 250 | nn.Conv2d(chan, chan * 2, 1, bias=False), 251 | nn.PixelShuffle(2) 252 | ) 253 | ) 254 | chan = chan // 2 255 | self.decoders.append( 256 | nn.Sequential( 257 | *[NAFBlock(chan) for _ in range(num)] 258 | ) 259 | ) 260 | 261 | self.padder_size = 2 ** len(self.encoders) 262 | 263 | def forward(self, inp): 264 | B, C, H, W = inp.shape 265 | inp = self.check_image_size(inp) 266 | 267 | x = self.intro(inp) 268 | 269 | encs = [] 270 | 271 | for encoder, down in zip(self.encoders, self.downs): 272 | x = encoder(x) 273 | encs.append(x) 274 | x = down(x) 275 | 276 | x = self.middle_blks(x) 277 | 278 | for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]): 279 | x = up(x) 280 | x = x + enc_skip 281 | x = decoder(x) 282 | 283 | x = self.ending(x) 284 | x = x + inp 285 | 286 | return x[:, :, :H, :W] 287 | 288 | def check_image_size(self, x): 289 | _, _, h, w = x.size() 290 | mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size 291 | mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size 292 | x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h)) 293 | return x 294 | 295 | class NAFNetLocal(Local_Base, NAFNet): 296 | def __init__(self, *args, train_size=(1, 1, 128, 128), base_size=(192, 192), fast_imp=False, **kwargs): 297 | Local_Base.__init__(self) 298 | NAFNet.__init__(self, *args, **kwargs) 299 | 300 | N, C, H, W = train_size 301 | base_size = (int(H * 1.5), int(W * 1.5)) 302 | 303 | self.eval() 304 | with torch.no_grad(): 305 | self.convert(base_size=base_size, train_size=train_size, fast_imp=fast_imp) 306 | 307 | 308 | if __name__ == '__main__': 309 | img_channel = 1 310 | enc_blks = [1, 1, 1, 28] 311 | middle_blk_num = 1 312 | 
dec_blks = [1, 1, 1, 1] 313 | width = 64 314 | 315 | net = NAFNet(img_channel=img_channel, width=width, middle_blk_num=middle_blk_num, 316 | enc_blk_nums=enc_blks, dec_blk_nums=dec_blks) 317 | 318 | inp_shape = (1, 256, 256) 319 | 320 | from ptflops import get_model_complexity_info 321 | 322 | macs, params = get_model_complexity_info(net, inp_shape, verbose=False, print_per_layer_stat=False) 323 | 324 | params = float(params[:-1]) 325 | macs = float(macs[:-4]) 326 | 327 | print(macs, params) -------------------------------------------------------------------------------- /basicsr/models/losses/__init__.py: -------------------------------------------------------------------------------- 1 | from .losses import (L1Loss, MSELoss, PSNRLoss, CharbonnierLoss,PerceptualLoss,FFTLoss) 2 | 3 | __all__ = [ 4 | 'L1Loss', 'MSELoss', 'PSNRLoss', 'CharbonnierLoss','PerceptualLoss','FFTLoss' 5 | ] -------------------------------------------------------------------------------- /basicsr/models/losses/loss_util.py: -------------------------------------------------------------------------------- 1 | import functools 2 | from torch.nn import functional as F 3 | 4 | 5 | def reduce_loss(loss, reduction): 6 | """Reduce loss as specified. 7 | 8 | Args: 9 | loss (Tensor): Elementwise loss tensor. 10 | reduction (str): Options are 'none', 'mean' and 'sum'. 11 | 12 | Returns: 13 | Tensor: Reduced loss tensor. 14 | """ 15 | reduction_enum = F._Reduction.get_enum(reduction) 16 | # none: 0, elementwise_mean:1, sum: 2 17 | if reduction_enum == 0: 18 | return loss 19 | elif reduction_enum == 1: 20 | return loss.mean() 21 | else: 22 | return loss.sum() 23 | 24 | 25 | def weight_reduce_loss(loss, weight=None, reduction='mean'): 26 | """Apply element-wise weight and reduce loss. 27 | 28 | Args: 29 | loss (Tensor): Element-wise loss. 30 | weight (Tensor): Element-wise weights. Default: None. 31 | reduction (str): Same as built-in losses of PyTorch. Options are 32 | 'none', 'mean' and 'sum'. Default: 'mean'. 33 | 34 | Returns: 35 | Tensor: Loss values. 36 | """ 37 | # if weight is specified, apply element-wise weight 38 | if weight is not None: 39 | assert weight.dim() == loss.dim() 40 | assert weight.size(1) == 1 or weight.size(1) == loss.size(1) 41 | loss = loss * weight 42 | 43 | # if weight is not specified or reduction is sum, just reduce the loss 44 | if weight is None or reduction == 'sum': 45 | loss = reduce_loss(loss, reduction) 46 | # if reduction is mean, then compute mean over weight region 47 | elif reduction == 'mean': 48 | if weight.size(1) > 1: 49 | weight = weight.sum() 50 | else: 51 | weight = weight.sum() * loss.size(1) 52 | loss = loss.sum() / weight 53 | 54 | return loss 55 | 56 | 57 | def weighted_loss(loss_func): 58 | """Create a weighted version of a given loss function. 59 | 60 | To use this decorator, the loss function must have the signature like 61 | `loss_func(pred, target, **kwargs)`. The function only needs to compute 62 | element-wise loss without any reduction. This decorator will add weight 63 | and reduction arguments to the function. The decorated function will have 64 | the signature like `loss_func(pred, target, weight=None, reduction='mean', 65 | **kwargs)`. 
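    In other words, the wrapper first evaluates the element-wise loss and then
    delegates the optional weighting and the reduction to weight_reduce_loss
    defined above.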
66 | 67 | :Example: 68 | 69 | >>> import torch 70 | >>> @weighted_loss 71 | >>> def l1_loss(pred, target): 72 | >>> return (pred - target).abs() 73 | 74 | >>> pred = torch.Tensor([0, 2, 3]) 75 | >>> target = torch.Tensor([1, 1, 1]) 76 | >>> weight = torch.Tensor([1, 0, 1]) 77 | 78 | >>> l1_loss(pred, target) 79 | tensor(1.3333) 80 | >>> l1_loss(pred, target, weight) 81 | tensor(1.5000) 82 | >>> l1_loss(pred, target, reduction='none') 83 | tensor([1., 1., 2.]) 84 | >>> l1_loss(pred, target, weight, reduction='sum') 85 | tensor(3.) 86 | """ 87 | 88 | @functools.wraps(loss_func) 89 | def wrapper(pred, target, weight=None, reduction='mean', **kwargs): 90 | # get element-wise loss 91 | loss = loss_func(pred, target, **kwargs) 92 | loss = weight_reduce_loss(loss, weight, reduction) 93 | return loss 94 | 95 | return wrapper 96 | -------------------------------------------------------------------------------- /basicsr/models/losses/losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn as nn 3 | from torch.nn import functional as F 4 | import numpy as np 5 | from basicsr.models.losses.loss_util import weighted_loss 6 | 7 | _reduction_modes = ['none', 'mean', 'sum'] 8 | 9 | 10 | @weighted_loss 11 | def l1_loss(pred, target): 12 | return F.l1_loss(pred, target, reduction='none') 13 | 14 | 15 | @weighted_loss 16 | def mse_loss(pred, target): 17 | return F.mse_loss(pred, target, reduction='none') 18 | 19 | 20 | class L1Loss(nn.Module): 21 | """L1 (mean absolute error, MAE) loss. 22 | 23 | Args: 24 | loss_weight (float): Loss weight for L1 loss. Default: 1.0. 25 | reduction (str): Specifies the reduction to apply to the output. 26 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 27 | """ 28 | 29 | def __init__(self, loss_weight=1.0, reduction='mean'): 30 | super(L1Loss, self).__init__() 31 | if reduction not in ['none', 'mean', 'sum']: 32 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' 33 | f'Supported ones are: {_reduction_modes}') 34 | 35 | self.loss_weight = loss_weight 36 | self.reduction = reduction 37 | 38 | def forward(self, pred, target, weight=None, **kwargs): 39 | """ 40 | Args: 41 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 42 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 43 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 44 | weights. Default: None. 45 | """ 46 | return self.loss_weight * l1_loss( 47 | pred, target, weight, reduction=self.reduction) 48 | 49 | class MSELoss(nn.Module): 50 | """MSE (L2) loss. 51 | 52 | Args: 53 | loss_weight (float): Loss weight for MSE loss. Default: 1.0. 54 | reduction (str): Specifies the reduction to apply to the output. 55 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 56 | """ 57 | 58 | def __init__(self, loss_weight=1.0, reduction='mean'): 59 | super(MSELoss, self).__init__() 60 | if reduction not in ['none', 'mean', 'sum']: 61 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' 62 | f'Supported ones are: {_reduction_modes}') 63 | 64 | self.loss_weight = loss_weight 65 | self.reduction = reduction 66 | 67 | def forward(self, pred, target, weight=None, **kwargs): 68 | """ 69 | Args: 70 | pred (Tensor): of shape (N, C, H, W). Predicted tensor. 71 | target (Tensor): of shape (N, C, H, W). Ground truth tensor. 72 | weight (Tensor, optional): of shape (N, C, H, W). Element-wise 73 | weights. Default: None. 
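        Returns:
            Tensor: loss_weight * reduced MSE between pred and target.

        Example (illustrative):
            >>> loss_fn = MSELoss()
            >>> float(loss_fn(torch.zeros(1, 3, 8, 8), torch.ones(1, 3, 8, 8)))
            1.0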
74 | """ 75 | return self.loss_weight * mse_loss( 76 | pred, target, weight, reduction=self.reduction) 77 | 78 | class PSNRLoss(nn.Module): 79 | 80 | def __init__(self, loss_weight=1.0, reduction='mean', toY=False): 81 | super(PSNRLoss, self).__init__() 82 | assert reduction == 'mean' 83 | self.loss_weight = loss_weight 84 | self.scale = 10 / np.log(10) 85 | self.toY = toY 86 | self.coef = torch.tensor([65.481, 128.553, 24.966]).reshape(1, 3, 1, 1) 87 | self.first = True 88 | 89 | def forward(self, pred, target): 90 | assert len(pred.size()) == 4 91 | if self.toY: 92 | if self.first: 93 | self.coef = self.coef.to(pred.device) 94 | self.first = False 95 | 96 | pred = (pred * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 97 | target = (target * self.coef).sum(dim=1).unsqueeze(dim=1) + 16. 98 | 99 | pred, target = pred / 255., target / 255. 100 | pass 101 | assert len(pred.size()) == 4 102 | 103 | return self.loss_weight * self.scale * torch.log(((pred - target) ** 2).mean(dim=(1, 2, 3)) + 1e-8).mean() 104 | 105 | class CharbonnierLoss(nn.Module): 106 | """Charbonnier Loss (L1)""" 107 | 108 | def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-3): 109 | super(CharbonnierLoss, self).__init__() 110 | self.eps = eps 111 | self.loss_weight = loss_weight 112 | 113 | def forward(self, x, y): 114 | diff = x - y 115 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 116 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) 117 | return loss * self.loss_weight 118 | 119 | class FFTLoss(nn.Module): 120 | """L1 loss in frequency domain with FFT. 121 | 122 | Args: 123 | loss_weight (float): Loss weight for FFT loss. Default: 1.0. 124 | reduction (str): Specifies the reduction to apply to the output. 125 | Supported choices are 'none' | 'mean' | 'sum'. Default: 'mean'. 126 | """ 127 | 128 | def __init__(self, loss_weight=1.0, reduction='mean'): 129 | super(FFTLoss, self).__init__() 130 | if reduction not in ['none', 'mean', 'sum']: 131 | raise ValueError(f'Unsupported reduction mode: {reduction}. ' f'Supported ones are: {_reduction_modes}') 132 | 133 | self.loss_weight = loss_weight 134 | self.reduction = reduction 135 | 136 | def forward(self, pred, target, weight=None, **kwargs): 137 | """ 138 | Args: 139 | pred (Tensor): of shape (..., C, H, W). Predicted tensor. 140 | target (Tensor): of shape (..., C, H, W). Ground truth tensor. 141 | weight (Tensor, optional): of shape (..., C, H, W). Element-wise 142 | weights. Default: None. 
143 | """ 144 | 145 | pred_fft = torch.fft.fft2(pred, dim=(-2, -1)) 146 | pred_fft = torch.stack([pred_fft.real, pred_fft.imag], dim=-1) 147 | target_fft = torch.fft.fft2(target, dim=(-2, -1)) 148 | target_fft = torch.stack([target_fft.real, target_fft.imag], dim=-1) 149 | return self.loss_weight * l1_loss(pred_fft, target_fft, weight, reduction=self.reduction) 150 | 151 | import torchvision.models as models 152 | class PerceptualLoss(nn.Module): 153 | 154 | def __init__(self, loss_weight=1.0, reduction='mean', eps=1e-6): 155 | super(PerceptualLoss, self).__init__() 156 | def contentFunc(): 157 | conv_3_3_layer = 14 158 | cnn = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features 159 | model = nn.Sequential() 160 | for i,layer in enumerate(list(cnn)): 161 | model.add_module(str(i),layer) 162 | if i == conv_3_3_layer: 163 | break 164 | return model 165 | self.contentFunc = contentFunc() 166 | self.loss_weight = loss_weight 167 | self.reduction = reduction 168 | def forward(self, fakeIm, realIm): 169 | fakeIm_ = torch.cat((fakeIm, fakeIm,fakeIm), 1) 170 | realIm_ = torch.cat((realIm, realIm,realIm), 1) 171 | f_fake = self.contentFunc.forward(fakeIm_) 172 | f_real = self.contentFunc.forward(realIm_) 173 | f_real_no_grad = f_real.detach() 174 | # loss = torch.nn.MSELoss(f_fake, f_real_no_grad) 175 | loss = mse_loss(f_fake, f_real_no_grad, reduction=self.reduction) 176 | return loss * self.loss_weight 177 | -------------------------------------------------------------------------------- /basicsr/models/lr_scheduler.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | import torch 5 | 6 | 7 | class MultiStepRestartLR(_LRScheduler): 8 | """ MultiStep with restarts learning rate scheme. 9 | 10 | Args: 11 | optimizer (torch.nn.optimizer): Torch optimizer. 12 | milestones (list): Iterations that will decrease learning rate. 13 | gamma (float): Decrease ratio. Default: 0.1. 14 | restarts (list): Restart iterations. Default: [0]. 15 | restart_weights (list): Restart weights at each restart iteration. 16 | Default: [1]. 17 | last_epoch (int): Used in _LRScheduler. Default: -1. 18 | """ 19 | 20 | def __init__(self, 21 | optimizer, 22 | milestones, 23 | gamma=0.1, 24 | restarts=(0, ), 25 | restart_weights=(1, ), 26 | last_epoch=-1): 27 | self.milestones = Counter(milestones) 28 | self.gamma = gamma 29 | self.restarts = restarts 30 | self.restart_weights = restart_weights 31 | assert len(self.restarts) == len( 32 | self.restart_weights), 'restarts and their weights do not match.' 33 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 34 | 35 | def get_lr(self): 36 | if self.last_epoch in self.restarts: 37 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 38 | return [ 39 | group['initial_lr'] * weight 40 | for group in self.optimizer.param_groups 41 | ] 42 | if self.last_epoch not in self.milestones: 43 | return [group['lr'] for group in self.optimizer.param_groups] 44 | return [ 45 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 46 | for group in self.optimizer.param_groups 47 | ] 48 | 49 | class LinearLR(_LRScheduler): 50 | """ 51 | 52 | Args: 53 | optimizer (torch.nn.optimizer): Torch optimizer. 54 | milestones (list): Iterations that will decrease learning rate. 55 | gamma (float): Decrease ratio. Default: 0.1. 56 | last_epoch (int): Used in _LRScheduler. Default: -1. 
57 | """ 58 | 59 | def __init__(self, 60 | optimizer, 61 | total_iter, 62 | last_epoch=-1): 63 | self.total_iter = total_iter 64 | super(LinearLR, self).__init__(optimizer, last_epoch) 65 | 66 | def get_lr(self): 67 | process = self.last_epoch / self.total_iter 68 | weight = (1 - process) 69 | # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups]) 70 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups] 71 | 72 | class VibrateLR(_LRScheduler): 73 | """ 74 | 75 | Args: 76 | optimizer (torch.nn.optimizer): Torch optimizer. 77 | milestones (list): Iterations that will decrease learning rate. 78 | gamma (float): Decrease ratio. Default: 0.1. 79 | last_epoch (int): Used in _LRScheduler. Default: -1. 80 | """ 81 | 82 | def __init__(self, 83 | optimizer, 84 | total_iter, 85 | last_epoch=-1): 86 | self.total_iter = total_iter 87 | super(VibrateLR, self).__init__(optimizer, last_epoch) 88 | 89 | def get_lr(self): 90 | process = self.last_epoch / self.total_iter 91 | 92 | f = 0.1 93 | if process < 3 / 8: 94 | f = 1 - process * 8 / 3 95 | elif process < 5 / 8: 96 | f = 0.2 97 | 98 | T = self.total_iter // 80 99 | Th = T // 2 100 | 101 | t = self.last_epoch % T 102 | 103 | f2 = t / Th 104 | if t >= Th: 105 | f2 = 2 - f2 106 | 107 | weight = f * f2 108 | 109 | if self.last_epoch < Th: 110 | weight = max(0.1, weight) 111 | 112 | # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2)) 113 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups] 114 | 115 | def get_position_from_periods(iteration, cumulative_period): 116 | """Get the position from a period list. 117 | 118 | It will return the index of the right-closest number in the period list. 119 | For example, the cumulative_period = [100, 200, 300, 400], 120 | if iteration == 50, return 0; 121 | if iteration == 210, return 2; 122 | if iteration == 300, return 2. 123 | 124 | Args: 125 | iteration (int): Current iteration. 126 | cumulative_period (list[int]): Cumulative period list. 127 | 128 | Returns: 129 | int: The position of the right-closest number in the period list. 130 | """ 131 | for i, period in enumerate(cumulative_period): 132 | # print("12230033",i, period,iteration ) 133 | if iteration <= period: 134 | return i 135 | 136 | 137 | class CosineAnnealingRestartLR(_LRScheduler): 138 | """ Cosine annealing with restarts learning rate scheme. 139 | 140 | An example of config: 141 | periods = [10, 10, 10, 10] 142 | restart_weights = [1, 0.5, 0.5, 0.5] 143 | eta_min=1e-7 144 | 145 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 146 | scheduler will restart with the weights in restart_weights. 147 | 148 | Args: 149 | optimizer (torch.nn.optimizer): Torch optimizer. 150 | periods (list): Period for each cosine anneling cycle. 151 | restart_weights (list): Restart weights at each restart iteration. 152 | Default: [1]. 153 | eta_min (float): The mimimum lr. Default: 0. 154 | last_epoch (int): Used in _LRScheduler. Default: -1. 155 | """ 156 | 157 | def __init__(self, 158 | optimizer, 159 | periods, 160 | restart_weights=(1, ), 161 | eta_min=0, 162 | last_epoch=-1): 163 | self.periods = periods 164 | self.restart_weights = restart_weights 165 | self.eta_min = eta_min 166 | assert (len(self.periods) == len(self.restart_weights) 167 | ), 'periods and restart_weights should have the same length.' 
168 | self.cumulative_period = [ 169 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) 170 | ] 171 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 172 | 173 | def get_lr(self): 174 | idx = get_position_from_periods(self.last_epoch, 175 | self.cumulative_period) 176 | current_weight = self.restart_weights[idx] 177 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 178 | current_period = self.periods[idx] 179 | 180 | return [ 181 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 182 | (1 + math.cos(math.pi * ( 183 | (self.last_epoch - nearest_restart) / current_period))) 184 | for base_lr in self.base_lrs 185 | ] 186 | 187 | class CosineAnnealingRestartCyclicLR(_LRScheduler): 188 | """ Cosine annealing with restarts learning rate scheme. 189 | An example of config: 190 | periods = [10, 10, 10, 10] 191 | restart_weights = [1, 0.5, 0.5, 0.5] 192 | eta_min=1e-7 193 | It has four cycles, each with 10 iterations. At the 10th, 20th and 30th 194 | iterations, the scheduler will restart with the weights in restart_weights. 195 | Args: 196 | optimizer (torch.nn.optimizer): Torch optimizer. 197 | periods (list): Period for each cosine annealing cycle. 198 | restart_weights (list): Restart weights at each restart iteration. 199 | Default: [1]. 200 | eta_mins (tuple[float]): The minimum lr of each cycle. Default: (0, ). 201 | last_epoch (int): Used in _LRScheduler. Default: -1. 202 | """ 203 | 204 | def __init__(self, 205 | optimizer, 206 | periods, 207 | restart_weights=(1, ), 208 | eta_mins=(0, ), 209 | last_epoch=-1): 210 | self.periods = periods 211 | self.restart_weights = restart_weights 212 | self.eta_mins = eta_mins 213 | assert (len(self.periods) == len(self.restart_weights) 214 | ), 'periods and restart_weights should have the same length.'
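        # Unlike CosineAnnealingRestartLR above, eta_mins holds one minimum lr
        # per cycle; get_lr below indexes it with the active cycle.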
215 | self.cumulative_period = [ 216 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) 217 | ] 218 | super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch) 219 | 220 | def get_lr(self): 221 | idx = get_position_from_periods(self.last_epoch, 222 | self.cumulative_period) 223 | # print("12230033",idx,self.last_epoch,self.cumulative_period) 224 | current_weight = self.restart_weights[idx] 225 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 226 | current_period = self.periods[idx] 227 | eta_min = self.eta_mins[idx] 228 | 229 | return [ 230 | eta_min + current_weight * 0.5 * (base_lr - eta_min) * 231 | (1 + math.cos(math.pi * ( 232 | (self.last_epoch - nearest_restart) / current_period))) 233 | for base_lr in self.base_lrs 234 | ] 235 | 236 | 237 | # ## Test code for the schedulers above 238 | # import torch 239 | # import torch.nn.functional as F 240 | # class Net(torch.nn.Module): 241 | # def __init__(self,n_feature,n_hidden,n_output): 242 | # super(Net,self).__init__() 243 | # self.hidden = torch.nn.Linear(n_feature,n_hidden) 244 | # self.predict = torch.nn.Linear(n_hidden,n_output) 245 | 246 | # def forward(self,x): 247 | # x = F.relu(self.hidden(x)) 248 | # x = self.predict(x) 249 | # return x 250 | # net = Net(1,10,1) 251 | # optimizer_G = torch.optim.Adam(net.parameters(), lr=0.001) 252 | # scheduler = CosineAnnealingRestartCyclicLR(optimizer_G, (10,)) 253 | # y = [] 254 | # for epoch in range(10): 255 | # lr = optimizer_G.state_dict()['param_groups'][0]['lr'] 256 | # y.append(lr) 257 | # scheduler.step() 258 | # print(y) 259 | # import matplotlib.pyplot as plt 260 | # plt.figure(dpi=300) 261 | # plt.plot(y) 262 | # plt.title('CosineAnnealingRestartCyclicLR') 263 | # plt.savefig("a.png") -------------------------------------------------------------------------------- /basicsr/test.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch 3 | from os import path as osp 4 | import os 5 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 6 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # choose your gpu id 7 | 8 | import random, argparse 9 | 10 | from basicsr.data import create_dataloader, create_dataset 11 | from basicsr.models import create_model 12 | from basicsr.utils import (get_env_info, get_root_logger, get_time_str, 13 | make_exp_dirs, set_random_seed) 14 | from basicsr.utils.options import dict2str, parse 15 | from basicsr.utils.dist_util import get_dist_info, init_dist 16 | 17 | def parse_options(is_train=True): 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument( 20 | '-opt', default="options/test_JARNet_LAP.yml", 21 | type=str)#, required=True, help='Path to option YAML file.') 22 | parser.add_argument( 23 | '--launcher', 24 | choices=['none', 'pytorch', 'slurm'], 25 | default='none', 26 | help='job launcher') 27 | parser.add_argument('--local_rank', type=int, default=0) 28 | args = parser.parse_args() 29 | opt = parse(args.opt, is_train=is_train) 30 | 31 | # distributed settings 32 | if args.launcher == 'none': 33 | opt['dist'] = False 34 | print('Disable distributed.', flush=True) 35 | else: 36 | opt['dist'] = True 37 | if args.launcher == 'slurm' and 'dist_params' in opt: 38 | init_dist(args.launcher, **opt['dist_params']) 39 | else: 40 | init_dist(args.launcher) 41 | print('init dist ..
', args.launcher) 42 | 43 | opt['rank'], opt['world_size'] = get_dist_info() 44 | 45 | # random seed 46 | seed = opt.get('manual_seed') 47 | if seed is None: 48 | seed = random.randint(1, 10000) 49 | opt['manual_seed'] = seed 50 | set_random_seed(seed + opt['rank']) 51 | 52 | return opt 53 | 54 | def test_pipeline(root_path): 55 | # parse options, set distributed setting, set random seed 56 | #opt, _ = parse_options(root_path, is_train=False) 57 | opt = parse_options(is_train=False) 58 | 59 | torch.backends.cudnn.benchmark = True 60 | # torch.backends.cudnn.deterministic = True 61 | 62 | # mkdir and initialize loggers 63 | make_exp_dirs(opt) 64 | log_file = osp.join(opt['path']['log'], 65 | f"test_{opt['name']}_{get_time_str()}.log") 66 | logger = get_root_logger( 67 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 68 | logger.info(get_env_info()) 69 | logger.info(dict2str(opt)) 70 | 71 | # create test dataset and dataloader 72 | test_loaders = [] 73 | for phase, dataset_opt in sorted(opt['datasets'].items()): 74 | test_set = create_dataset(dataset_opt) 75 | test_loader = create_dataloader( 76 | test_set, 77 | dataset_opt, 78 | num_gpu=opt['num_gpu'], 79 | dist=opt['dist'], 80 | sampler=None, 81 | seed=opt['manual_seed']) 82 | logger.info( 83 | f"Number of test images in {dataset_opt['name']}: {len(test_set)}") 84 | test_loaders.append(test_loader) 85 | 86 | # create model 87 | model = create_model(opt) 88 | 89 | for test_loader in test_loaders: 90 | test_set_name = test_loader.dataset.opt['name'] 91 | logger.info(f'Testing {test_set_name}...') 92 | rgb2bgr = opt['val'].get('rgb2bgr', True) 93 | # whether to use uint8 images to compute metrics 94 | use_image = opt['val'].get('use_image', True) 95 | model.validation( 96 | test_loader, 97 | current_iter=opt['name'], 98 | tb_logger=None, 99 | save_img=opt['val']['save_img'], 100 | rgb2bgr=rgb2bgr, use_image=use_image) 101 | 102 | 103 | if __name__ == '__main__': 104 | root_path = osp.abspath(osp.join(__file__, osp.pardir, osp.pardir)) 105 | test_pipeline(root_path) 106 | -------------------------------------------------------------------------------- /basicsr/train_jarnet.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # choose your gpu id 4 | import torch 5 | import logging 6 | import math 7 | import time 8 | import random 9 | import numpy as np 10 | import datetime 11 | import argparse 12 | from basicsr.utils import (MessageLogger,make_exp_dirs,mkdir_and_rename,get_root_logger, 13 | get_env_info,init_wandb_logger,init_tb_logger, 14 | get_time_str,check_resume,set_random_seed) 15 | from os import path as osp 16 | from basicsr.utils.options import dict2str 17 | from basicsr.data import create_dataloader, create_dataset 18 | from basicsr.data.data_sampler import EnlargedSampler 19 | from basicsr.models import create_model 20 | from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher 21 | from basicsr.utils.dist_util import get_dist_info, init_dist 22 | from basicsr.utils.options import dict2str, parse 23 | 24 | def parse_options(is_train=True): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | '-opt', default="options/train_JARNet_LAP.yml", 28 | type=str)#, required=True, help='Path to option YAML file.') 29 | parser.add_argument( 30 | '--launcher', 31 | choices=['none', 'pytorch', 'slurm'], 32 | default='none', 33 | help='job launcher') 34 |
parser.add_argument('--local-rank', type=int, default=0) 35 | args = parser.parse_args() 36 | opt = parse(args.opt, is_train=is_train) 37 | 38 | # distributed settings 39 | if args.launcher == 'none': 40 | opt['dist'] = False 41 | print('Disable distributed.', flush=True) 42 | else: 43 | opt['dist'] = True 44 | if args.launcher == 'slurm' and 'dist_params' in opt: 45 | init_dist(args.launcher, **opt['dist_params']) 46 | else: 47 | init_dist(args.launcher) 48 | print('init dist .. ', args.launcher) 49 | 50 | opt['rank'], opt['world_size'] = get_dist_info() 51 | 52 | # random seed 53 | seed = opt.get('manual_seed') 54 | if seed is None: 55 | seed = random.randint(1, 10000) 56 | opt['manual_seed'] = seed 57 | set_random_seed(seed + opt['rank']) 58 | 59 | return opt 60 | 61 | def init_loggers(opt): 62 | log_file = osp.join(opt['path']['log'], 63 | f"train_{opt['name']}_{get_time_str()}.log") 64 | logger = get_root_logger( 65 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 66 | logger.info(get_env_info()) 67 | logger.info(dict2str(opt)) 68 | 69 | # initialize wandb logger before tensorboard logger to allow proper sync: 70 | if (opt['logger'].get('wandb') 71 | is not None) and (opt['logger']['wandb'].get('project') 72 | is not None) and ('debug' not in opt['name']): 73 | assert opt['logger'].get('use_tb_logger') is True, ( 74 | 'should turn on tensorboard when using wandb') 75 | init_wandb_logger(opt) 76 | tb_logger = None 77 | if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']: 78 | tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) 79 | return logger, tb_logger 80 | 81 | def create_train_val_dataloader(opt, logger): 82 | # create train and val dataloaders 83 | train_loader, val_loader = None, None 84 | for phase, dataset_opt in opt['datasets'].items(): 85 | if phase == 'train': 86 | # if 0: 87 | dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) 88 | train_set = create_dataset(dataset_opt) 89 | train_sampler = EnlargedSampler(train_set, opt['world_size'], 90 | opt['rank'], dataset_enlarge_ratio) 91 | # train_sampler = None 92 | train_loader = create_dataloader( 93 | train_set, 94 | dataset_opt, 95 | num_gpu=opt['num_gpu'], 96 | dist=opt['dist'], 97 | sampler=train_sampler, 98 | seed=opt['manual_seed']) 99 | 100 | num_iter_per_epoch = math.ceil( 101 | len(train_set) * dataset_enlarge_ratio / 102 | (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) 103 | total_iters = int(opt['train']['total_iter']) 104 | # print("----------",num_iter_per_epoch,len(train_set) ) 105 | total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) 106 | logger.info( 107 | 'Training statistics:' 108 | f'\n\tNumber of train images: {len(train_set)}' 109 | f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' 110 | f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' 111 | f'\n\tWorld size (gpu number): {opt["world_size"]}' 112 | f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' 113 | f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') 114 | # elif 1: 115 | elif phase == 'val': 116 | val_set = create_dataset(dataset_opt) 117 | val_loader = create_dataloader( 118 | val_set, 119 | dataset_opt, 120 | num_gpu=opt['num_gpu'], 121 | dist=opt['dist'], 122 | sampler=None, 123 | seed=opt['manual_seed']) 124 | logger.info( 125 | f'Number of val images/folders in {dataset_opt["name"]}: ' 126 | f'{len(val_set)}') 127 | else: 128 | raise ValueError(f'Dataset phase {phase} is not recognized.') 129 | 130 | return train_loader, 
train_sampler, val_loader, total_epochs, total_iters 131 | 132 | 133 | def main(): 134 | # parse options, set distributed setting, set random seed 135 | opt = parse_options(is_train=True) 136 | 137 | torch.backends.cudnn.benchmark = True 138 | # automatic resume .. 139 | state_folder_path = 'experiments/{}/training_states/'.format(opt['name']) 140 | 141 | try: 142 | states = os.listdir(state_folder_path) 143 | except: 144 | states = [] 145 | 146 | resume_state = None 147 | if len(states) > 0: 148 | max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states])) 149 | resume_state = os.path.join(state_folder_path, max_state_file) 150 | opt['path']['resume_state'] = resume_state 151 | 152 | # load resume states if necessary 153 | if opt['path'].get('resume_state'): 154 | device_id = torch.cuda.current_device() 155 | resume_state = torch.load( 156 | opt['path']['resume_state'], 157 | map_location=lambda storage, loc: storage.cuda(device_id)) 158 | else: 159 | resume_state = None 160 | 161 | # mkdir for experiments and logger 162 | if resume_state is None: 163 | make_exp_dirs(opt) 164 | if opt['logger'].get('use_tb_logger') and 'debug' not in opt[ 165 | 'name'] and opt['rank'] == 0: 166 | mkdir_and_rename(osp.join('tb_logger', opt['name'])) 167 | 168 | # initialize loggers 169 | logger, tb_logger = init_loggers(opt) 170 | 171 | # create train and validation dataloaders 172 | result = create_train_val_dataloader(opt, logger) 173 | train_loader, train_sampler, val_loader, total_epochs, total_iters = result 174 | 175 | # create model 176 | if resume_state: # resume training 177 | check_resume(opt, resume_state['iter']) 178 | model = create_model(opt) 179 | model.resume_training(resume_state) # handle optimizers and schedulers 180 | logger.info(f"Resuming training from epoch: {resume_state['epoch']}, " 181 | f"iter: {resume_state['iter']}.") 182 | start_epoch = resume_state['epoch'] 183 | current_iter = resume_state['iter'] 184 | else: 185 | model = create_model(opt) 186 | start_epoch = 0 187 | current_iter = 0 188 | 189 | # create message logger (formatted outputs) 190 | msg_logger = MessageLogger(opt, current_iter, tb_logger) 191 | 192 | # dataloader prefetcher 193 | prefetch_mode = opt['datasets']['train'].get('prefetch_mode') 194 | if prefetch_mode is None or prefetch_mode == 'cpu': 195 | prefetcher = CPUPrefetcher(train_loader) 196 | elif prefetch_mode == 'cuda': 197 | prefetcher = CUDAPrefetcher(train_loader, opt) 198 | logger.info(f'Use {prefetch_mode} prefetch dataloader') 199 | if opt['datasets']['train'].get('pin_memory') is not True: 200 | raise ValueError('Please set pin_memory=True for CUDAPrefetcher.') 201 | else: 202 | raise ValueError(f'Wrong prefetch_mode {prefetch_mode}.'
203 | "Supported ones are: None, 'cuda', 'cpu'.") 204 | 205 | 206 | # training 207 | logger.info( 208 | f'Start training from epoch: {start_epoch}, iter: {current_iter}') 209 | data_time, iter_time = time.time(), time.time() 210 | start_time = time.time() 211 | 212 | iters = opt['datasets']['train'].get('iters') 213 | batch_size = opt['datasets']['train'].get('batch_size_per_gpu') 214 | mini_batch_sizes = opt['datasets']['train'].get('mini_batch_sizes') 215 | gt_size = opt['datasets']['train'].get('gt_size') 216 | mini_gt_sizes = opt['datasets']['train'].get('gt_sizes') 217 | 218 | groups = np.array([sum(iters[0:i + 1]) for i in range(0, len(iters))]) 219 | 220 | logger_j = [True] * len(groups) 221 | 222 | scale = opt['scale'] 223 | 224 | epoch = start_epoch 225 | while current_iter <= total_iters: 226 | train_sampler.set_epoch(epoch) 227 | prefetcher.reset() 228 | train_data = prefetcher.next() 229 | 230 | while train_data is not None: 231 | data_time = time.time() - data_time 232 | 233 | current_iter += 1 234 | if current_iter > total_iters: 235 | break 236 | # update learning rate 237 | model.update_learning_rate( 238 | current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) 239 | 240 | ### ------Progressive learning --------------------- 241 | j = ((current_iter>groups) !=True).nonzero()[0] 242 | if len(j) == 0: 243 | bs_j = len(groups) - 1 244 | else: 245 | bs_j = j[0] 246 | 247 | mini_gt_size = mini_gt_sizes[bs_j] 248 | mini_batch_size = mini_batch_sizes[bs_j] 249 | 250 | if logger_j[bs_j]: 251 | logger.info('\n Updating Patch_Size to {} and Batch_Size to {} \n'.format(mini_gt_size, mini_batch_size*torch.cuda.device_count())) 252 | logger_j[bs_j] = False 253 | 254 | lq = train_data['lq'] 255 | gt = train_data['gt'] 256 | flow = train_data['flow'] 257 | flow_gt = train_data['flow_gt'] 258 | # print(lq.shape) 259 | 260 | # if mini_batch_size < batch_size: 261 | # indices = random.sample(range(0, batch_size), k=mini_batch_size) 262 | # lq = lq[indices] 263 | # gt = gt[indices] 264 | 265 | if mini_batch_size < batch_size: 266 | indices = random.sample(range(0, batch_size), k=mini_batch_size) 267 | lq = lq[indices] 268 | gt = gt[indices] 269 | flow = flow[indices] 270 | flow_gt = flow_gt[indices] 271 | 272 | if mini_gt_size < gt_size: 273 | x0 = int((gt_size - mini_gt_size) * random.random()) 274 | y0 = int((gt_size - mini_gt_size) * random.random()) 275 | x1 = x0 + mini_gt_size 276 | y1 = y0 + mini_gt_size 277 | lq = lq[:,:,x0:x1,y0:y1] 278 | gt = gt[:,:,x0*scale:x1*scale,y0*scale:y1*scale] 279 | flow = flow[:,:,x0:x1,y0:y1] 280 | flow_gt = flow_gt[:,:,x0:x1,y0:y1] 281 | ###------------------------------------------- 282 | 283 | model.feed_train_data({'lq': lq, 'gt':gt,'flow':flow,"flow_gt":flow_gt}) 284 | model.optimize_parameters(current_iter) 285 | 286 | iter_time = time.time() - iter_time 287 | # log 288 | if current_iter % opt['logger']['print_freq'] == 0: 289 | log_vars = {'epoch': epoch, 'iter': current_iter} 290 | log_vars.update({'lrs': model.get_current_learning_rate()}) 291 | log_vars.update({'time': iter_time, 'data_time': data_time}) 292 | log_vars.update(model.get_current_log()) 293 | msg_logger(log_vars) 294 | 295 | # save models and training states 296 | if current_iter % opt['logger']['save_checkpoint_freq'] == 0 or (current_iter > 300000 and current_iter % 10000 == 0): 297 | logger.info('Saving models and training states.') 298 | model.save(epoch, current_iter) 299 | 300 | # validation 301 | if opt.get('val') is not None and (current_iter % 302 | 
opt['val']['val_freq'] == 0 or (current_iter > 300000 and current_iter % 10000 == 0)): 303 | rgb2bgr = opt['val'].get('rgb2bgr', True) 304 | # whether to use uint8 images to compute metrics 305 | use_image = opt['val'].get('use_image', True) 306 | model.validation(val_loader, current_iter, tb_logger, 307 | opt['val']['save_img'], rgb2bgr, use_image ) 308 | 309 | data_time = time.time() 310 | iter_time = time.time() 311 | train_data = prefetcher.next() 312 | 313 | # end of iter 314 | epoch += 1 315 | 316 | # end of epoch 317 | 318 | consumed_time = str( 319 | datetime.timedelta(seconds=int(time.time() - start_time))) 320 | logger.info(f'End of training. Time consumed: {consumed_time}') 321 | logger.info('Save the latest model.') 322 | model.save(epoch=-1, current_iter=-1) # -1 stands for the latest 323 | if opt.get('val') is not None: 324 | model.validation(val_loader, current_iter, tb_logger, 325 | opt['val']['save_last_img']) 326 | if tb_logger: 327 | tb_logger.close() 328 | 329 | if __name__ == "__main__": 330 | main() -------------------------------------------------------------------------------- /basicsr/train_others.py: -------------------------------------------------------------------------------- 1 | import os 2 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 3 | os.environ["CUDA_VISIBLE_DEVICES"] = "0" # choose your gpu id 4 | import torch 5 | import logging 6 | import math 7 | import time 8 | import random 9 | import numpy as np 10 | import datetime 11 | import argparse 12 | from basicsr.utils import (MessageLogger,make_exp_dirs,mkdir_and_rename,get_root_logger, 13 | get_env_info,init_wandb_logger,init_tb_logger, 14 | get_time_str,check_resume,set_random_seed) 15 | from os import path as osp 16 | from basicsr.utils.options import dict2str 17 | from basicsr.data import create_dataloader, create_dataset 18 | from basicsr.data.data_sampler import EnlargedSampler 19 | from basicsr.models import create_model 20 | from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher 21 | from basicsr.utils.dist_util import get_dist_info, init_dist 22 | from basicsr.utils.options import dict2str, parse 23 | 24 | def parse_options(is_train=True): 25 | parser = argparse.ArgumentParser() 26 | parser.add_argument( 27 | '-opt', default="options/train_NAFNet_LAP.yml", 28 | type=str)#, required=True, help='Path to option YAML file.') 29 | parser.add_argument( 30 | '--launcher', 31 | choices=['none', 'pytorch', 'slurm'], 32 | default='none', 33 | help='job launcher') 34 | parser.add_argument('--local-rank', type=int, default=0) 35 | args = parser.parse_args() 36 | opt = parse(args.opt, is_train=is_train) 37 | 38 | # distributed settings 39 | if args.launcher == 'none': 40 | opt['dist'] = False 41 | print('Disable distributed.', flush=True) 42 | else: 43 | opt['dist'] = True 44 | if args.launcher == 'slurm' and 'dist_params' in opt: 45 | init_dist(args.launcher, **opt['dist_params']) 46 | else: 47 | init_dist(args.launcher) 48 | print('init dist ..
', args.launcher) 49 | 50 | opt['rank'], opt['world_size'] = get_dist_info() 51 | 52 | # random seed 53 | seed = opt.get('manual_seed') 54 | if seed is None: 55 | seed = random.randint(1, 10000) 56 | opt['manual_seed'] = seed 57 | set_random_seed(seed + opt['rank']) 58 | 59 | return opt 60 | 61 | def init_loggers(opt): 62 | log_file = osp.join(opt['path']['log'], 63 | f"train_{opt['name']}_{get_time_str()}.log") 64 | logger = get_root_logger( 65 | logger_name='basicsr', log_level=logging.INFO, log_file=log_file) 66 | logger.info(get_env_info()) 67 | logger.info(dict2str(opt)) 68 | 69 | # initialize wandb logger before tensorboard logger to allow proper sync: 70 | if (opt['logger'].get('wandb') 71 | is not None) and (opt['logger']['wandb'].get('project') 72 | is not None) and ('debug' not in opt['name']): 73 | assert opt['logger'].get('use_tb_logger') is True, ( 74 | 'should turn on tensorboard when using wandb') 75 | init_wandb_logger(opt) 76 | tb_logger = None 77 | if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']: 78 | tb_logger = init_tb_logger(log_dir=osp.join('tb_logger', opt['name'])) 79 | return logger, tb_logger 80 | 81 | def create_train_val_dataloader(opt, logger): 82 | # create train and val dataloaders 83 | train_loader, val_loader = None, None 84 | for phase, dataset_opt in opt['datasets'].items(): 85 | if phase == 'train': 86 | dataset_enlarge_ratio = dataset_opt.get('dataset_enlarge_ratio', 1) 87 | train_set = create_dataset(dataset_opt) 88 | train_sampler = EnlargedSampler(train_set, opt['world_size'], 89 | opt['rank'], dataset_enlarge_ratio) 90 | train_loader = create_dataloader( 91 | train_set, 92 | dataset_opt, 93 | num_gpu=opt['num_gpu'], 94 | dist=opt['dist'], 95 | sampler=train_sampler, 96 | seed=opt['manual_seed']) 97 | 98 | num_iter_per_epoch = math.ceil( 99 | len(train_set) * dataset_enlarge_ratio / 100 | (dataset_opt['batch_size_per_gpu'] * opt['world_size'])) 101 | total_iters = int(opt['train']['total_iter']) 102 | total_epochs = math.ceil(total_iters / (num_iter_per_epoch)) 103 | logger.info( 104 | 'Training statistics:' 105 | f'\n\tNumber of train images: {len(train_set)}' 106 | f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}' 107 | f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}' 108 | f'\n\tWorld size (gpu number): {opt["world_size"]}' 109 | f'\n\tRequire iter number per epoch: {num_iter_per_epoch}' 110 | f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.') 111 | 112 | elif phase == 'val': 113 | val_set = create_dataset(dataset_opt) 114 | val_loader = create_dataloader( 115 | val_set, 116 | dataset_opt, 117 | num_gpu=opt['num_gpu'], 118 | dist=opt['dist'], 119 | sampler=None, 120 | seed=opt['manual_seed']) 121 | logger.info( 122 | f'Number of val images/folders in {dataset_opt["name"]}: ' 123 | f'{len(val_set)}') 124 | else: 125 | raise ValueError(f'Dataset phase {phase} is not recognized.') 126 | 127 | return train_loader, train_sampler, val_loader, total_epochs, total_iters 128 | 129 | 130 | def main(): 131 | # parse options, set distributed setting, set random seed 132 | opt = parse_options(is_train=True) 133 | 134 | torch.backends.cudnn.benchmark = True 135 | # automatic resume ..
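    # Automatic resume: list 'experiments/<name>/training_states/', parse the
    # '<iter>.state' filenames, and resume from the checkpoint with the largest
    # iteration number if any exist.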
136 | state_folder_path = 'experiments/{}/training_states/'.format(opt['name'])
137 | 
138 | try:
139 | states = os.listdir(state_folder_path)
140 | except FileNotFoundError:
141 | states = []
142 | 
143 | resume_state = None
144 | if len(states) > 0:
145 | max_state_file = '{}.state'.format(max([int(x[0:-6]) for x in states]))
146 | resume_state = os.path.join(state_folder_path, max_state_file)
147 | opt['path']['resume_state'] = resume_state
148 | 
149 | # load resume states if necessary
150 | if opt['path'].get('resume_state'):
151 | device_id = torch.cuda.current_device()
152 | resume_state = torch.load(
153 | opt['path']['resume_state'],
154 | map_location=lambda storage, loc: storage.cuda(device_id))
155 | else:
156 | resume_state = None
157 | 
158 | # mkdir for experiments and logger
159 | if resume_state is None:
160 | make_exp_dirs(opt)
161 | if opt['logger'].get('use_tb_logger') and 'debug' not in opt[
162 | 'name'] and opt['rank'] == 0:
163 | mkdir_and_rename(osp.join('tb_logger', opt['name']))
164 | 
165 | # initialize loggers
166 | logger, tb_logger = init_loggers(opt)
167 | 
168 | # create train and validation dataloaders
169 | result = create_train_val_dataloader(opt, logger)
170 | train_loader, train_sampler, val_loader, total_epochs, total_iters = result
171 | 
172 | # create model
173 | if resume_state:  # resume training
174 | check_resume(opt, resume_state['iter'])
175 | model = create_model(opt)
176 | model.resume_training(resume_state)  # handle optimizers and schedulers
177 | logger.info(f"Resuming training from epoch: {resume_state['epoch']}, "
178 | f"iter: {resume_state['iter']}.")
179 | start_epoch = resume_state['epoch']
180 | current_iter = resume_state['iter']
181 | else:
182 | model = create_model(opt)
183 | start_epoch = 0
184 | current_iter = 0
185 | 
186 | # create message logger (formatted outputs)
187 | msg_logger = MessageLogger(opt, current_iter, tb_logger)
188 | 
189 | # dataloader prefetcher
190 | prefetch_mode = opt['datasets']['train'].get('prefetch_mode')
191 | if prefetch_mode is None or prefetch_mode == 'cpu':
192 | prefetcher = CPUPrefetcher(train_loader)
193 | elif prefetch_mode == 'cuda':
194 | prefetcher = CUDAPrefetcher(train_loader, opt)
195 | logger.info(f'Use {prefetch_mode} prefetch dataloader')
196 | if opt['datasets']['train'].get('pin_memory') is not True:
197 | raise ValueError('Please set pin_memory=True for CUDAPrefetcher.')
198 | else:
199 | raise ValueError(f'Wrong prefetch_mode {prefetch_mode}. '
200 | "Supported ones are: None, 'cuda', 'cpu'.")
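# Note (illustrative sketch): the progressive-learning stages below are read from
# the training YAML. For a hypothetical config with
#     iters: [92000, 64000, 48000]        # iterations spent in each stage
#     mini_batch_sizes: [8, 5, 4]         # per-gpu batch size for each stage
#     gt_sizes: [160, 192, 256]           # training patch size for each stage
# the cumulative bounds become groups = np.array([92000, 156000, 204000]), and a
# given current_iter maps to the first stage whose bound it has not yet passed,
# e.g. current_iter = 100000 -> bs_j = 1 -> patch size 192, mini batch size 5.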
200 | "Supported ones are: None, 'cuda', 'cpu'.") 201 | 202 | 203 | # training 204 | logger.info( 205 | f'Start training from epoch: {start_epoch}, iter: {current_iter}') 206 | data_time, iter_time = time.time(), time.time() 207 | start_time = time.time() 208 | 209 | iters = opt['datasets']['train'].get('iters') 210 | batch_size = opt['datasets']['train'].get('batch_size_per_gpu') 211 | mini_batch_sizes = opt['datasets']['train'].get('mini_batch_sizes') 212 | gt_size = opt['datasets']['train'].get('gt_size') 213 | mini_gt_sizes = opt['datasets']['train'].get('gt_sizes') 214 | 215 | groups = np.array([sum(iters[0:i + 1]) for i in range(0, len(iters))]) 216 | 217 | logger_j = [True] * len(groups) 218 | 219 | scale = opt['scale'] 220 | 221 | epoch = start_epoch 222 | while current_iter <= total_iters: 223 | train_sampler.set_epoch(epoch) 224 | prefetcher.reset() 225 | train_data = prefetcher.next() 226 | 227 | while train_data is not None: 228 | data_time = time.time() - data_time 229 | 230 | current_iter += 1 231 | if current_iter > total_iters: 232 | break 233 | # update learning rate 234 | model.update_learning_rate( 235 | current_iter, warmup_iter=opt['train'].get('warmup_iter', -1)) 236 | 237 | ### ------Progressive learning --------------------- 238 | j = ((current_iter>groups) !=True).nonzero()[0] 239 | if len(j) == 0: 240 | bs_j = len(groups) - 1 241 | else: 242 | bs_j = j[0] 243 | 244 | mini_gt_size = mini_gt_sizes[bs_j] 245 | mini_batch_size = mini_batch_sizes[bs_j] 246 | 247 | if logger_j[bs_j]: 248 | logger.info('\n Updating Patch_Size to {} and Batch_Size to {} \n'.format(mini_gt_size, mini_batch_size*torch.cuda.device_count())) 249 | logger_j[bs_j] = False 250 | 251 | lq = train_data['lq'] 252 | gt = train_data['gt'] 253 | 254 | if mini_batch_size < batch_size: 255 | indices = random.sample(range(0, batch_size), k=mini_batch_size) 256 | lq = lq[indices] 257 | gt = gt[indices] 258 | 259 | # if mini_batch_size < batch_size: 260 | # indices = random.sample(range(0, batch_size), k=mini_batch_size) 261 | # lq = lq[indices] 262 | # gt = gt[indices] 263 | 264 | if mini_gt_size < gt_size: 265 | x0 = int((gt_size - mini_gt_size) * random.random()) 266 | y0 = int((gt_size - mini_gt_size) * random.random()) 267 | x1 = x0 + mini_gt_size 268 | y1 = y0 + mini_gt_size 269 | lq = lq[:,:,x0:x1,y0:y1] 270 | gt = gt[:,:,x0*scale:x1*scale,y0*scale:y1*scale] 271 | ###------------------------------------------- 272 | 273 | model.feed_train_data({'lq': lq, 'gt':gt}) 274 | model.optimize_parameters(current_iter) 275 | 276 | iter_time = time.time() - iter_time 277 | # log 278 | if current_iter % opt['logger']['print_freq'] == 0: 279 | log_vars = {'epoch': epoch, 'iter': current_iter} 280 | log_vars.update({'lrs': model.get_current_learning_rate()}) 281 | log_vars.update({'time': iter_time, 'data_time': data_time}) 282 | log_vars.update(model.get_current_log()) 283 | msg_logger(log_vars) 284 | 285 | # save models and training states 286 | if current_iter % opt['logger']['save_checkpoint_freq'] == 0 or (current_iter > 300000 and current_iter % 10000 == 0): 287 | logger.info('Saving models and training states.') 288 | model.save(epoch, current_iter) 289 | 290 | # validation 291 | if opt.get('val') is not None and ((current_iter % 292 | opt['val']['val_freq'] == 0) or (current_iter > 300000 and current_iter % 10000 == 0)): 293 | rgb2bgr = opt['val'].get('rgb2bgr', True) 294 | # wheather use uint8 image to compute metrics 295 | use_image = opt['val'].get('use_image', True) 296 | model.validation(val_loader, 
293 | 
294 | data_time = time.time()
295 | iter_time = time.time()
296 | train_data = prefetcher.next()
297 | 
298 | epoch += 1
299 | 
300 | # end of epoch
301 | 
302 | consumed_time = str(
303 | datetime.timedelta(seconds=int(time.time() - start_time)))
304 | logger.info(f'End of training. Time consumed: {consumed_time}')
305 | logger.info('Save the latest model.')
306 | model.save(epoch=-1, current_iter=-1)  # -1 stands for the latest
307 | if opt.get('val') is not None:
308 | model.validation(val_loader, current_iter, tb_logger,
309 | opt['val']['save_last_img'])
310 | if tb_logger:
311 | tb_logger.close()
312 | 
313 | if __name__ == "__main__":
314 | main()
--------------------------------------------------------------------------------
/basicsr/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .file_client import FileClient
2 | from .img_util import crop_border, imfrombytes, img2tensor, imwrite, tensor2img, padding, padding_DP, imfrombytesDP
3 | from .logger import (MessageLogger, get_env_info, get_root_logger,
4 | init_tb_logger, init_wandb_logger)
5 | from .misc import (check_resume, get_time_str, make_exp_dirs, mkdir_and_rename,
6 | scandir, scandir_SIDD, set_random_seed, sizeof_fmt)
7 | from .create_lmdb import (create_lmdb_for_reds, create_lmdb_for_gopro, create_lmdb_for_rain13k)
8 | 
9 | __all__ = [
10 | # file_client.py
11 | 'FileClient',
12 | # img_util.py
13 | 'img2tensor',
14 | 'tensor2img',
15 | 'imfrombytes',
16 | 'imwrite',
17 | 'crop_border',
18 | # logger.py
19 | 'MessageLogger',
20 | 'init_tb_logger',
21 | 'init_wandb_logger',
22 | 'get_root_logger',
23 | 'get_env_info',
24 | # misc.py
25 | 'set_random_seed',
26 | 'get_time_str',
27 | 'mkdir_and_rename',
28 | 'make_exp_dirs',
29 | 'scandir',
30 | 'scandir_SIDD',
31 | 'check_resume',
32 | 'sizeof_fmt',
33 | 'padding',
34 | 'padding_DP',
35 | 'imfrombytesDP',
36 | 'create_lmdb_for_reds',
37 | 'create_lmdb_for_gopro',
38 | 'create_lmdb_for_rain13k',
39 | ]
40 | 
--------------------------------------------------------------------------------
/basicsr/utils/bundle_submissions.py:
--------------------------------------------------------------------------------
1 | # Author: Tobias Plötz, TU Darmstadt (tobias.ploetz@visinf.tu-darmstadt.de)
2 | 
3 | # This file is part of the implementation as described in the CVPR 2017 paper:
4 | # Tobias Plötz and Stefan Roth, Benchmarking Denoising Algorithms with Real Photographs.
5 | # Please see the file LICENSE.txt for the license governing this code.
6 | 
7 | 
8 | import numpy as np
9 | import scipy.io as sio
10 | import os
11 | import h5py
12 | 
13 | def bundle_submissions_raw(submission_folder, session):
14 | '''
15 | Bundles submission data for raw denoising
16 | 
17 | submission_folder Folder where denoised images reside
18 | 
19 | Output is written to <submission_folder>/<session>/. Please submit
20 | the content of this folder.
21 | '''
22 | 
23 | out_folder = os.path.join(submission_folder, session)
24 | # out_folder = os.path.join(submission_folder, "bundled/")
25 | try:
26 | os.mkdir(out_folder)
27 | except OSError: pass
28 | 
29 | israw = True
30 | eval_version = "1.0"
31 | 
32 | for i in range(50):
33 | Idenoised = np.zeros((20,), dtype=object)
34 | for bb in range(20):
35 | filename = '%04d_%02d.mat' % (i + 1, bb + 1)
36 | s = sio.loadmat(os.path.join(submission_folder, filename))
37 | Idenoised_crop = s["Idenoised_crop"]
38 | Idenoised[bb] = Idenoised_crop
39 | filename = '%04d.mat' % (i + 1)
40 | sio.savemat(os.path.join(out_folder, filename),
41 | {"Idenoised": Idenoised,
42 | "israw": israw,
43 | "eval_version": eval_version},
44 | )
45 | 
46 | def bundle_submissions_srgb(submission_folder, session):
47 | '''
48 | Bundles submission data for sRGB denoising
49 | 
50 | submission_folder Folder where denoised images reside
51 | 
52 | Output is written to <submission_folder>/<session>/. Please submit
53 | the content of this folder.
54 | '''
55 | out_folder = os.path.join(submission_folder, session)
56 | # out_folder = os.path.join(submission_folder, "bundled/")
57 | try:
58 | os.mkdir(out_folder)
59 | except OSError: pass
60 | israw = False
61 | eval_version = "1.0"
62 | 
63 | for i in range(50):
64 | Idenoised = np.zeros((20,), dtype=object)
65 | for bb in range(20):
66 | filename = '%04d_%02d.mat' % (i + 1, bb + 1)
67 | s = sio.loadmat(os.path.join(submission_folder, filename))
68 | Idenoised_crop = s["Idenoised_crop"]
69 | Idenoised[bb] = Idenoised_crop
70 | filename = '%04d.mat' % (i + 1)
71 | sio.savemat(os.path.join(out_folder, filename),
72 | {"Idenoised": Idenoised,
73 | "israw": israw,
74 | "eval_version": eval_version},
75 | )
76 | 
77 | 
78 | 
79 | def bundle_submissions_srgb_v1(submission_folder, session):
80 | '''
81 | Bundles submission data for sRGB denoising
82 | 
83 | submission_folder Folder where denoised images reside
84 | 
85 | Output is written to <submission_folder>/<session>/. Please submit
86 | the content of this folder.
87 | '''
88 | out_folder = os.path.join(submission_folder, session)
89 | # out_folder = os.path.join(submission_folder, "bundled/")
90 | try:
91 | os.mkdir(out_folder)
92 | except OSError: pass
93 | israw = False
94 | eval_version = "1.0"
95 | 
96 | for i in range(50):
97 | Idenoised = np.zeros((20,), dtype=object)
98 | for bb in range(20):
99 | filename = '%04d_%d.mat' % (i + 1, bb + 1)
100 | s = sio.loadmat(os.path.join(submission_folder, filename))
101 | Idenoised_crop = s["Idenoised_crop"]
102 | Idenoised[bb] = Idenoised_crop
103 | filename = '%04d.mat' % (i + 1)
104 | sio.savemat(os.path.join(out_folder, filename),
105 | {"Idenoised": Idenoised,
106 | "israw": israw,
107 | "eval_version": eval_version},
108 | )
--------------------------------------------------------------------------------
/basicsr/utils/create_lmdb.py:
--------------------------------------------------------------------------------
1 | import argparse, cv2, os  # cv2/os needed by create_lmdb_for_SIDD below
2 | import scipy.io as scio
3 | from os import path as osp
4 | from tqdm import tqdm
5 | from basicsr.utils import scandir
6 | from basicsr.utils.lmdb_util import make_lmdb_from_imgs
7 | 
8 | def prepare_keys(folder_path, suffix='png'):
9 | """Prepare image path list and lmdb keys for a folder of images.
10 | 
11 | Args:
12 | folder_path (str): Folder path.
13 | suffix (str): Image suffix. Default: 'png'.
14 | 
15 | Returns: (list[str], list[str]): Image path list and key list.
16 | """ 17 | print('Reading image path list ...') 18 | img_path_list = sorted( 19 | list(scandir(folder_path, suffix=suffix, recursive=False))) 20 | keys = [img_path.split('.{}'.format(suffix))[0] for img_path in sorted(img_path_list)] 21 | 22 | return img_path_list, keys 23 | 24 | def create_lmdb_for_reds(): 25 | folder_path = './datasets/REDS/val/sharp_300' 26 | lmdb_path = './datasets/REDS/val/sharp_300.lmdb' 27 | img_path_list, keys = prepare_keys(folder_path, 'png') 28 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 29 | # 30 | folder_path = './datasets/REDS/val/blur_300' 31 | lmdb_path = './datasets/REDS/val/blur_300.lmdb' 32 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 33 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 34 | 35 | folder_path = './datasets/REDS/train/train_sharp' 36 | lmdb_path = './datasets/REDS/train/train_sharp.lmdb' 37 | img_path_list, keys = prepare_keys(folder_path, 'png') 38 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 39 | 40 | folder_path = './datasets/REDS/train/train_blur_jpeg' 41 | lmdb_path = './datasets/REDS/train/train_blur_jpeg.lmdb' 42 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 43 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 44 | 45 | 46 | def create_lmdb_for_gopro(): 47 | folder_path = './datasets/GoPro/train/blur_crops' 48 | lmdb_path = './datasets/GoPro/train/blur_crops.lmdb' 49 | 50 | img_path_list, keys = prepare_keys(folder_path, 'png') 51 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 52 | 53 | folder_path = './datasets/GoPro/train/sharp_crops' 54 | lmdb_path = './datasets/GoPro/train/sharp_crops.lmdb' 55 | 56 | img_path_list, keys = prepare_keys(folder_path, 'png') 57 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 58 | 59 | folder_path = './datasets/GoPro/test/target' 60 | lmdb_path = './datasets/GoPro/test/target.lmdb' 61 | 62 | img_path_list, keys = prepare_keys(folder_path, 'png') 63 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 64 | 65 | folder_path = './datasets/GoPro/test/input' 66 | lmdb_path = './datasets/GoPro/test/input.lmdb' 67 | 68 | img_path_list, keys = prepare_keys(folder_path, 'png') 69 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 70 | 71 | def create_lmdb_for_rain13k(): 72 | folder_path = './datasets/Rain13k/train/input' 73 | lmdb_path = './datasets/Rain13k/train/input.lmdb' 74 | 75 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 76 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 77 | 78 | folder_path = './datasets/Rain13k/train/target' 79 | lmdb_path = './datasets/Rain13k/train/target.lmdb' 80 | 81 | img_path_list, keys = prepare_keys(folder_path, 'jpg') 82 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 83 | 84 | def create_lmdb_for_SIDD(): 85 | folder_path = './datasets/SIDD/train/input_crops' 86 | lmdb_path = './datasets/SIDD/train/input_crops.lmdb' 87 | 88 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 89 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 90 | 91 | folder_path = './datasets/SIDD/train/gt_crops' 92 | lmdb_path = './datasets/SIDD/train/gt_crops.lmdb' 93 | 94 | img_path_list, keys = prepare_keys(folder_path, 'PNG') 95 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 96 | 97 | #for val 98 | folder_path = './datasets/SIDD/val/input_crops' 99 | lmdb_path = './datasets/SIDD/val/input_crops.lmdb' 100 | mat_path = 
'./datasets/SIDD/ValidationNoisyBlocksSrgb.mat' 101 | if not osp.exists(folder_path): 102 | os.makedirs(folder_path) 103 | assert osp.exists(mat_path) 104 | data = scio.loadmat(mat_path)['ValidationNoisyBlocksSrgb'] 105 | N, B, H ,W, C = data.shape 106 | data = data.reshape(N*B, H, W, C) 107 | for i in tqdm(range(N*B)): 108 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 109 | img_path_list, keys = prepare_keys(folder_path, 'png') 110 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 111 | 112 | folder_path = './datasets/SIDD/val/gt_crops' 113 | lmdb_path = './datasets/SIDD/val/gt_crops.lmdb' 114 | mat_path = './datasets/SIDD/ValidationGtBlocksSrgb.mat' 115 | if not osp.exists(folder_path): 116 | os.makedirs(folder_path) 117 | assert osp.exists(mat_path) 118 | data = scio.loadmat(mat_path)['ValidationGtBlocksSrgb'] 119 | N, B, H ,W, C = data.shape 120 | data = data.reshape(N*B, H, W, C) 121 | for i in tqdm(range(N*B)): 122 | cv2.imwrite(osp.join(folder_path, 'ValidationBlocksSrgb_{}.png'.format(i)), cv2.cvtColor(data[i,...], cv2.COLOR_RGB2BGR)) 123 | img_path_list, keys = prepare_keys(folder_path, 'png') 124 | make_lmdb_from_imgs(folder_path, lmdb_path, img_path_list, keys) 125 | -------------------------------------------------------------------------------- /basicsr/utils/dist_util.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/dist_utils.py # noqa: E501 2 | import functools 3 | import os 4 | import subprocess 5 | import torch 6 | import torch.distributed as dist 7 | import torch.multiprocessing as mp 8 | 9 | 10 | def init_dist(launcher, backend='nccl', **kwargs): 11 | if mp.get_start_method(allow_none=True) is None: 12 | mp.set_start_method('spawn') 13 | if launcher == 'pytorch': 14 | _init_dist_pytorch(backend, **kwargs) 15 | elif launcher == 'slurm': 16 | _init_dist_slurm(backend, **kwargs) 17 | else: 18 | raise ValueError(f'Invalid launcher type: {launcher}') 19 | 20 | 21 | def _init_dist_pytorch(backend, **kwargs): 22 | rank = int(os.environ['RANK']) 23 | num_gpus = torch.cuda.device_count() 24 | torch.cuda.set_device(rank % num_gpus) 25 | dist.init_process_group(backend=backend, **kwargs) 26 | 27 | 28 | def _init_dist_slurm(backend, port=None): 29 | """Initialize slurm distributed training environment. 30 | 31 | If argument ``port`` is not specified, then the master port will be system 32 | environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system 33 | environment variable, then a default port ``29500`` will be used. 34 | 35 | Args: 36 | backend (str): Backend of torch.distributed. 37 | port (int, optional): Master port. Defaults to None. 
38 | """ 39 | proc_id = int(os.environ['SLURM_PROCID']) 40 | ntasks = int(os.environ['SLURM_NTASKS']) 41 | node_list = os.environ['SLURM_NODELIST'] 42 | num_gpus = torch.cuda.device_count() 43 | torch.cuda.set_device(proc_id % num_gpus) 44 | addr = subprocess.getoutput( 45 | f'scontrol show hostname {node_list} | head -n1') 46 | # specify master port 47 | if port is not None: 48 | os.environ['MASTER_PORT'] = str(port) 49 | elif 'MASTER_PORT' in os.environ: 50 | pass # use MASTER_PORT in the environment variable 51 | else: 52 | # 29500 is torch.distributed default port 53 | os.environ['MASTER_PORT'] = '29500' 54 | os.environ['MASTER_ADDR'] = addr 55 | os.environ['WORLD_SIZE'] = str(ntasks) 56 | os.environ['LOCAL_RANK'] = str(proc_id % num_gpus) 57 | os.environ['RANK'] = str(proc_id) 58 | dist.init_process_group(backend=backend) 59 | 60 | 61 | def get_dist_info(): 62 | if dist.is_available(): 63 | initialized = dist.is_initialized() 64 | else: 65 | initialized = False 66 | if initialized: 67 | rank = dist.get_rank() 68 | world_size = dist.get_world_size() 69 | else: 70 | rank = 0 71 | world_size = 1 72 | return rank, world_size 73 | 74 | 75 | def master_only(func): 76 | 77 | @functools.wraps(func) 78 | def wrapper(*args, **kwargs): 79 | rank, _ = get_dist_info() 80 | if rank == 0: 81 | return func(*args, **kwargs) 82 | 83 | return wrapper 84 | -------------------------------------------------------------------------------- /basicsr/utils/download_util.py: -------------------------------------------------------------------------------- 1 | import math 2 | import requests 3 | from tqdm import tqdm 4 | 5 | from .misc import sizeof_fmt 6 | 7 | 8 | def download_file_from_google_drive(file_id, save_path): 9 | """Download files from google drive. 10 | 11 | Ref: 12 | https://stackoverflow.com/questions/25010369/wget-curl-large-file-from-google-drive # noqa E501 13 | 14 | Args: 15 | file_id (str): File id. 16 | save_path (str): Save path. 
17 | """ 18 | 19 | session = requests.Session() 20 | URL = 'https://docs.google.com/uc?export=download' 21 | params = {'id': file_id} 22 | 23 | response = session.get(URL, params=params, stream=True) 24 | token = get_confirm_token(response) 25 | if token: 26 | params['confirm'] = token 27 | response = session.get(URL, params=params, stream=True) 28 | 29 | # get file size 30 | response_file_size = session.get( 31 | URL, params=params, stream=True, headers={'Range': 'bytes=0-2'}) 32 | if 'Content-Range' in response_file_size.headers: 33 | file_size = int( 34 | response_file_size.headers['Content-Range'].split('/')[1]) 35 | else: 36 | file_size = None 37 | 38 | save_response_content(response, save_path, file_size) 39 | 40 | 41 | def get_confirm_token(response): 42 | for key, value in response.cookies.items(): 43 | if key.startswith('download_warning'): 44 | return value 45 | return None 46 | 47 | 48 | def save_response_content(response, 49 | destination, 50 | file_size=None, 51 | chunk_size=32768): 52 | if file_size is not None: 53 | pbar = tqdm(total=math.ceil(file_size / chunk_size), unit='chunk') 54 | 55 | readable_file_size = sizeof_fmt(file_size) 56 | else: 57 | pbar = None 58 | 59 | with open(destination, 'wb') as f: 60 | downloaded_size = 0 61 | for chunk in response.iter_content(chunk_size): 62 | downloaded_size += chunk_size 63 | if pbar is not None: 64 | pbar.update(1) 65 | pbar.set_description(f'Download {sizeof_fmt(downloaded_size)} ' 66 | f'/ {readable_file_size}') 67 | if chunk: # filter out keep-alive new chunks 68 | f.write(chunk) 69 | if pbar is not None: 70 | pbar.close() 71 | -------------------------------------------------------------------------------- /basicsr/utils/face_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import numpy as np 3 | import os 4 | import torch 5 | from skimage import transform as trans 6 | 7 | from basicsr.utils import imwrite 8 | 9 | try: 10 | import dlib 11 | except ImportError: 12 | print('Please install dlib before testing face restoration.' 
13 | 'Reference: https://github.com/davisking/dlib') 14 | 15 | 16 | class FaceRestorationHelper(object): 17 | """Helper for the face restoration pipeline.""" 18 | 19 | def __init__(self, upscale_factor, face_size=512): 20 | self.upscale_factor = upscale_factor 21 | self.face_size = (face_size, face_size) 22 | 23 | # standard 5 landmarks for FFHQ faces with 1024 x 1024 24 | self.face_template = np.array([[686.77227723, 488.62376238], 25 | [586.77227723, 493.59405941], 26 | [337.91089109, 488.38613861], 27 | [437.95049505, 493.51485149], 28 | [513.58415842, 678.5049505]]) 29 | self.face_template = self.face_template / (1024 // face_size) 30 | # for estimation the 2D similarity transformation 31 | self.similarity_trans = trans.SimilarityTransform() 32 | 33 | self.all_landmarks_5 = [] 34 | self.all_landmarks_68 = [] 35 | self.affine_matrices = [] 36 | self.inverse_affine_matrices = [] 37 | self.cropped_faces = [] 38 | self.restored_faces = [] 39 | self.save_png = True 40 | 41 | def init_dlib(self, detection_path, landmark5_path, landmark68_path): 42 | """Initialize the dlib detectors and predictors.""" 43 | self.face_detector = dlib.cnn_face_detection_model_v1(detection_path) 44 | self.shape_predictor_5 = dlib.shape_predictor(landmark5_path) 45 | self.shape_predictor_68 = dlib.shape_predictor(landmark68_path) 46 | 47 | def free_dlib_gpu_memory(self): 48 | del self.face_detector 49 | del self.shape_predictor_5 50 | del self.shape_predictor_68 51 | 52 | def read_input_image(self, img_path): 53 | # self.input_img is Numpy array, (h, w, c) with RGB order 54 | self.input_img = dlib.load_rgb_image(img_path) 55 | 56 | def detect_faces(self, 57 | img_path, 58 | upsample_num_times=1, 59 | only_keep_largest=False): 60 | """ 61 | Args: 62 | img_path (str): Image path. 63 | upsample_num_times (int): Upsamples the image before running the 64 | face detector 65 | 66 | Returns: 67 | int: Number of detected faces. 68 | """ 69 | self.read_input_image(img_path) 70 | det_faces = self.face_detector(self.input_img, upsample_num_times) 71 | if len(det_faces) == 0: 72 | print('No face detected. Try to increase upsample_num_times.') 73 | else: 74 | if only_keep_largest: 75 | print('Detect several faces and only keep the largest.') 76 | face_areas = [] 77 | for i in range(len(det_faces)): 78 | face_area = (det_faces[i].rect.right() - 79 | det_faces[i].rect.left()) * ( 80 | det_faces[i].rect.bottom() - 81 | det_faces[i].rect.top()) 82 | face_areas.append(face_area) 83 | largest_idx = face_areas.index(max(face_areas)) 84 | self.det_faces = [det_faces[largest_idx]] 85 | else: 86 | self.det_faces = det_faces 87 | return len(self.det_faces) 88 | 89 | def get_face_landmarks_5(self): 90 | for face in self.det_faces: 91 | shape = self.shape_predictor_5(self.input_img, face.rect) 92 | landmark = np.array([[part.x, part.y] for part in shape.parts()]) 93 | self.all_landmarks_5.append(landmark) 94 | return len(self.all_landmarks_5) 95 | 96 | def get_face_landmarks_68(self): 97 | """Get 68 densemarks for cropped images. 98 | 99 | Should only have one face at most in the cropped image. 100 | """ 101 | num_detected_face = 0 102 | for idx, face in enumerate(self.cropped_faces): 103 | # face detection 104 | det_face = self.face_detector(face, 1) # TODO: can we remove it? 105 | if len(det_face) == 0: 106 | print(f'Cannot find faces in cropped image with index {idx}.') 107 | self.all_landmarks_68.append(None) 108 | else: 109 | if len(det_face) > 1: 110 | print('Detect several faces in the cropped face. Use the ' 111 | ' largest one. 
Note that it will also cause overlap ' 112 | 'during paste_faces_to_input_image.') 113 | face_areas = [] 114 | for i in range(len(det_face)): 115 | face_area = (det_face[i].rect.right() - 116 | det_face[i].rect.left()) * ( 117 | det_face[i].rect.bottom() - 118 | det_face[i].rect.top()) 119 | face_areas.append(face_area) 120 | largest_idx = face_areas.index(max(face_areas)) 121 | face_rect = det_face[largest_idx].rect 122 | else: 123 | face_rect = det_face[0].rect 124 | shape = self.shape_predictor_68(face, face_rect) 125 | landmark = np.array([[part.x, part.y] 126 | for part in shape.parts()]) 127 | self.all_landmarks_68.append(landmark) 128 | num_detected_face += 1 129 | 130 | return num_detected_face 131 | 132 | def warp_crop_faces(self, 133 | save_cropped_path=None, 134 | save_inverse_affine_path=None): 135 | """Get affine matrix, warp and cropped faces. 136 | 137 | Also get inverse affine matrix for post-processing. 138 | """ 139 | for idx, landmark in enumerate(self.all_landmarks_5): 140 | # use 5 landmarks to get affine matrix 141 | self.similarity_trans.estimate(landmark, self.face_template) 142 | affine_matrix = self.similarity_trans.params[0:2, :] 143 | self.affine_matrices.append(affine_matrix) 144 | # warp and crop faces 145 | cropped_face = cv2.warpAffine(self.input_img, affine_matrix, 146 | self.face_size) 147 | self.cropped_faces.append(cropped_face) 148 | # save the cropped face 149 | if save_cropped_path is not None: 150 | path, ext = os.path.splitext(save_cropped_path) 151 | if self.save_png: 152 | save_path = f'{path}_{idx:02d}.png' 153 | else: 154 | save_path = f'{path}_{idx:02d}{ext}' 155 | 156 | imwrite( 157 | cv2.cvtColor(cropped_face, cv2.COLOR_RGB2BGR), save_path) 158 | 159 | # get inverse affine matrix 160 | self.similarity_trans.estimate(self.face_template, 161 | landmark * self.upscale_factor) 162 | inverse_affine = self.similarity_trans.params[0:2, :] 163 | self.inverse_affine_matrices.append(inverse_affine) 164 | # save inverse affine matrices 165 | if save_inverse_affine_path is not None: 166 | path, _ = os.path.splitext(save_inverse_affine_path) 167 | save_path = f'{path}_{idx:02d}.pth' 168 | torch.save(inverse_affine, save_path) 169 | 170 | def add_restored_face(self, face): 171 | self.restored_faces.append(face) 172 | 173 | def paste_faces_to_input_image(self, save_path): 174 | # operate in the BGR order 175 | input_img = cv2.cvtColor(self.input_img, cv2.COLOR_RGB2BGR) 176 | h, w, _ = input_img.shape 177 | h_up, w_up = h * self.upscale_factor, w * self.upscale_factor 178 | # simply resize the background 179 | upsample_img = cv2.resize(input_img, (w_up, h_up)) 180 | assert len(self.restored_faces) == len(self.inverse_affine_matrices), ( 181 | 'length of restored_faces and affine_matrices are different.') 182 | for restored_face, inverse_affine in zip(self.restored_faces, 183 | self.inverse_affine_matrices): 184 | inv_restored = cv2.warpAffine(restored_face, inverse_affine, 185 | (w_up, h_up)) 186 | mask = np.ones((*self.face_size, 3), dtype=np.float32) 187 | inv_mask = cv2.warpAffine(mask, inverse_affine, (w_up, h_up)) 188 | # remove the black borders 189 | inv_mask_erosion = cv2.erode( 190 | inv_mask, 191 | np.ones((2 * self.upscale_factor, 2 * self.upscale_factor), 192 | np.uint8)) 193 | inv_restored_remove_border = inv_mask_erosion * inv_restored 194 | total_face_area = np.sum(inv_mask_erosion) // 3 195 | # compute the fusion edge based on the area of face 196 | w_edge = int(total_face_area**0.5) // 20 197 | erosion_radius = w_edge * 2 198 | 
inv_mask_center = cv2.erode( 199 | inv_mask_erosion, 200 | np.ones((erosion_radius, erosion_radius), np.uint8)) 201 | blur_size = w_edge * 2 202 | inv_soft_mask = cv2.GaussianBlur(inv_mask_center, 203 | (blur_size + 1, blur_size + 1), 0) 204 | upsample_img = inv_soft_mask * inv_restored_remove_border + ( 205 | 1 - inv_soft_mask) * upsample_img 206 | if self.save_png: 207 | save_path = save_path.replace('.jpg', 208 | '.png').replace('.jpeg', '.png') 209 | imwrite(upsample_img.astype(np.uint8), save_path) 210 | 211 | def clean_all(self): 212 | self.all_landmarks_5 = [] 213 | self.all_landmarks_68 = [] 214 | self.restored_faces = [] 215 | self.affine_matrices = [] 216 | self.cropped_faces = [] 217 | self.inverse_affine_matrices = [] 218 | -------------------------------------------------------------------------------- /basicsr/utils/file_client.py: -------------------------------------------------------------------------------- 1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/fileio/file_client.py # noqa: E501 2 | from abc import ABCMeta, abstractmethod 3 | 4 | 5 | class BaseStorageBackend(metaclass=ABCMeta): 6 | """Abstract class of storage backends. 7 | 8 | All backends need to implement two apis: ``get()`` and ``get_text()``. 9 | ``get()`` reads the file as a byte stream and ``get_text()`` reads the file 10 | as texts. 11 | """ 12 | 13 | @abstractmethod 14 | def get(self, filepath): 15 | pass 16 | 17 | @abstractmethod 18 | def get_text(self, filepath): 19 | pass 20 | 21 | 22 | class MemcachedBackend(BaseStorageBackend): 23 | """Memcached storage backend. 24 | 25 | Attributes: 26 | server_list_cfg (str): Config file for memcached server list. 27 | client_cfg (str): Config file for memcached client. 28 | sys_path (str | None): Additional path to be appended to `sys.path`. 29 | Default: None. 30 | """ 31 | 32 | def __init__(self, server_list_cfg, client_cfg, sys_path=None): 33 | if sys_path is not None: 34 | import sys 35 | sys.path.append(sys_path) 36 | try: 37 | import mc 38 | except ImportError: 39 | raise ImportError( 40 | 'Please install memcached to enable MemcachedBackend.') 41 | 42 | self.server_list_cfg = server_list_cfg 43 | self.client_cfg = client_cfg 44 | self._client = mc.MemcachedClient.GetInstance(self.server_list_cfg, 45 | self.client_cfg) 46 | # mc.pyvector servers as a point which points to a memory cache 47 | self._mc_buffer = mc.pyvector() 48 | 49 | def get(self, filepath): 50 | filepath = str(filepath) 51 | import mc 52 | self._client.Get(filepath, self._mc_buffer) 53 | value_buf = mc.ConvertBuffer(self._mc_buffer) 54 | return value_buf 55 | 56 | def get_text(self, filepath): 57 | raise NotImplementedError 58 | 59 | 60 | class HardDiskBackend(BaseStorageBackend): 61 | """Raw hard disks storage backend.""" 62 | 63 | def get(self, filepath): 64 | filepath = str(filepath) 65 | with open(filepath, 'rb') as f: 66 | value_buf = f.read() 67 | return value_buf 68 | 69 | def get_text(self, filepath): 70 | filepath = str(filepath) 71 | with open(filepath, 'r') as f: 72 | value_buf = f.read() 73 | return value_buf 74 | 75 | 76 | class LmdbBackend(BaseStorageBackend): 77 | """Lmdb storage backend. 78 | 79 | Args: 80 | db_paths (str | list[str]): Lmdb database paths. 81 | client_keys (str | list[str]): Lmdb client keys. Default: 'default'. 82 | readonly (bool, optional): Lmdb environment parameter. If True, 83 | disallow any write operations. Default: True. 84 | lock (bool, optional): Lmdb environment parameter. 
If False, when
85 | concurrent access occurs, do not lock the database. Default: False.
86 | readahead (bool, optional): Lmdb environment parameter. If False,
87 | disable the OS filesystem readahead mechanism, which may improve
88 | random read performance when a database is larger than RAM.
89 | Default: False.
90 | 
91 | Attributes:
92 | db_paths (list): Lmdb database path.
93 | _client (list): A list of several lmdb envs.
94 | """
95 | 
96 | def __init__(self,
97 | db_paths,
98 | client_keys='default',
99 | readonly=True,
100 | lock=False,
101 | readahead=False,
102 | **kwargs):
103 | try:
104 | import lmdb
105 | except ImportError:
106 | raise ImportError('Please install lmdb to enable LmdbBackend.')
107 | 
108 | if isinstance(client_keys, str):
109 | client_keys = [client_keys]
110 | 
111 | if isinstance(db_paths, list):
112 | self.db_paths = [str(v) for v in db_paths]
113 | elif isinstance(db_paths, str):
114 | self.db_paths = [str(db_paths)]
115 | assert len(client_keys) == len(self.db_paths), (
116 | 'client_keys and db_paths should have the same length, '
117 | f'but received {len(client_keys)} and {len(self.db_paths)}.')
118 | 
119 | self._client = {}
120 | 
121 | for client, path in zip(client_keys, self.db_paths):
122 | self._client[client] = lmdb.open(
123 | path,
124 | readonly=readonly,
125 | lock=lock,
126 | readahead=readahead,
127 | map_size=8 * 1024 * 10485760,  # 80 GiB
128 | # max_readers=1,
129 | **kwargs)
130 | 
131 | def get(self, filepath, client_key):
132 | """Get values according to the filepath from one lmdb named client_key.
133 | 
134 | Args:
135 | filepath (str | obj:`Path`): Here, filepath is the lmdb key.
136 | client_key (str): Used for distinguishing different lmdb envs.
137 | """
138 | filepath = str(filepath)
139 | assert client_key in self._client, (f'client_key {client_key} is not '
140 | 'in lmdb clients.')
141 | client = self._client[client_key]
142 | with client.begin(write=False) as txn:
143 | value_buf = txn.get(filepath.encode('ascii'))
144 | return value_buf
145 | 
146 | def get_text(self, filepath):
147 | raise NotImplementedError
148 | 
149 | 
150 | class FileClient(object):
151 | """A general file client to access files in different backends.
152 | 
153 | The client loads a file or text in a specified backend from its path
154 | and returns it as a binary file. It can also register other backend
155 | accessors with a given name and backend class.
156 | 
157 | Attributes:
158 | backend (str): The storage backend type. Options are "disk",
159 | "memcached" and "lmdb".
160 | client (:obj:`BaseStorageBackend`): The backend object.
161 | """
162 | 
163 | _backends = {
164 | 'disk': HardDiskBackend,
165 | 'memcached': MemcachedBackend,
166 | 'lmdb': LmdbBackend,
167 | }
168 | 
169 | def __init__(self, backend='disk', **kwargs):
170 | if backend not in self._backends:
171 | raise ValueError(
172 | f'Backend {backend} is not supported. Currently supported ones'
173 | f' are {list(self._backends.keys())}')
174 | self.backend = backend
175 | self.client = self._backends[backend](**kwargs)
176 | 
177 | def get(self, filepath, client_key='default'):
178 | # client_key is used only for lmdb, where different fileclients have
179 | # different lmdb environments.
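# Example (hypothetical paths, for illustration only):
#   client = FileClient('lmdb', db_paths=['./datasets/foo.lmdb'], client_keys=['gt'])
#   img_bytes = client.get('000001', client_key='gt')    # lmdb keys need a client_key
#   disk_client = FileClient('disk')
#   img_bytes = disk_client.get('./datasets/foo/000001.png')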
180 | if self.backend == 'lmdb':
181 | return self.client.get(filepath, client_key)
182 | else:
183 | return self.client.get(filepath)
184 | 
185 | def get_text(self, filepath):
186 | return self.client.get_text(filepath)
187 | 
--------------------------------------------------------------------------------
/basicsr/utils/flow_util.py:
--------------------------------------------------------------------------------
1 | # Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
2 | import cv2
3 | import numpy as np
4 | import os
5 | 
6 | 
7 | def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
8 | """Read an optical flow map.
9 | 
10 | Args:
11 | flow_path (str): Flow path.
12 | quantize (bool): whether to read quantized pair, if set to True,
13 | remaining args will be passed to :func:`dequantize_flow`.
14 | concat_axis (int): The axis that dx and dy are concatenated,
15 | can be either 0 or 1. Ignored if quantize is False.
16 | 
17 | Returns:
18 | ndarray: Optical flow represented as a (h, w, 2) numpy array
19 | """
20 | if quantize:
21 | assert concat_axis in [0, 1]
22 | cat_flow = cv2.imread(flow_path, cv2.IMREAD_UNCHANGED)
23 | if cat_flow.ndim != 2:
24 | raise IOError(f'{flow_path} is not a valid quantized flow file, '
25 | f'its dimension is {cat_flow.ndim}.')
26 | assert cat_flow.shape[concat_axis] % 2 == 0
27 | dx, dy = np.split(cat_flow, 2, axis=concat_axis)
28 | flow = dequantize_flow(dx, dy, *args, **kwargs)
29 | else:
30 | with open(flow_path, 'rb') as f:
31 | try:
32 | header = f.read(4).decode('utf-8')
33 | except Exception:
34 | raise IOError(f'Invalid flow file: {flow_path}')
35 | else:
36 | if header != 'PIEH':
37 | raise IOError(f'Invalid flow file: {flow_path}, '
38 | 'header does not contain PIEH')
39 | 
40 | w = np.fromfile(f, np.int32, 1).squeeze()
41 | h = np.fromfile(f, np.int32, 1).squeeze()
42 | flow = np.fromfile(f, np.float32, w * h * 2).reshape((h, w, 2))
43 | 
44 | return flow.astype(np.float32)
45 | 
46 | 
47 | def flowwrite(flow, filename, quantize=False, concat_axis=0, *args, **kwargs):
48 | """Write optical flow to file.
49 | 
50 | If the flow is not quantized, it will be saved as a .flo file losslessly,
51 | otherwise a jpeg image which is lossy but of much smaller size. (dx and dy
52 | will be concatenated along ``concat_axis`` into a single image if quantize is True.)
53 | 
54 | Args:
55 | flow (ndarray): (h, w, 2) array of optical flow.
56 | filename (str): Output filepath.
57 | quantize (bool): Whether to quantize the flow and save it to 2 jpeg
58 | images. If set to True, remaining args will be passed to
59 | :func:`quantize_flow`.
60 | concat_axis (int): The axis that dx and dy are concatenated,
61 | can be either 0 or 1. Ignored if quantize is False.
62 | """
63 | if not quantize:
64 | with open(filename, 'wb') as f:
65 | f.write('PIEH'.encode('utf-8'))
66 | np.array([flow.shape[1], flow.shape[0]], dtype=np.int32).tofile(f)
67 | flow = flow.astype(np.float32)
68 | flow.tofile(f)
69 | f.flush()
70 | else:
71 | assert concat_axis in [0, 1]
72 | dx, dy = quantize_flow(flow, *args, **kwargs)
73 | dxdy = np.concatenate((dx, dy), axis=concat_axis)
74 | os.makedirs(os.path.dirname(filename) or '.', exist_ok=True)
75 | cv2.imwrite(filename, dxdy)  # cv2.imwrite expects (filename, img)
76 | 
77 | 
78 | def quantize_flow(flow, max_val=0.02, norm=True):
79 | """Quantize flow to [0, 255].
80 | 
81 | After this step, the size of flow will be much smaller, and can be
82 | dumped as jpeg images.
83 | 
84 | Args:
85 | flow (ndarray): (h, w, 2) array of optical flow.
86 | max_val (float): Maximum value of flow, values beyond 87 | [-max_val, max_val] will be truncated. 88 | norm (bool): Whether to divide flow values by image width/height. 89 | 90 | Returns: 91 | tuple[ndarray]: Quantized dx and dy. 92 | """ 93 | h, w, _ = flow.shape 94 | dx = flow[..., 0] 95 | dy = flow[..., 1] 96 | if norm: 97 | dx = dx / w # avoid inplace operations 98 | dy = dy / h 99 | # use 255 levels instead of 256 to make sure 0 is 0 after dequantization. 100 | flow_comps = [ 101 | quantize(d, -max_val, max_val, 255, np.uint8) for d in [dx, dy] 102 | ] 103 | return tuple(flow_comps) 104 | 105 | 106 | def dequantize_flow(dx, dy, max_val=0.02, denorm=True): 107 | """Recover from quantized flow. 108 | 109 | Args: 110 | dx (ndarray): Quantized dx. 111 | dy (ndarray): Quantized dy. 112 | max_val (float): Maximum value used when quantizing. 113 | denorm (bool): Whether to multiply flow values with width/height. 114 | 115 | Returns: 116 | ndarray: Dequantized flow. 117 | """ 118 | assert dx.shape == dy.shape 119 | assert dx.ndim == 2 or (dx.ndim == 3 and dx.shape[-1] == 1) 120 | 121 | dx, dy = [dequantize(d, -max_val, max_val, 255) for d in [dx, dy]] 122 | 123 | if denorm: 124 | dx *= dx.shape[1] 125 | dy *= dx.shape[0] 126 | flow = np.dstack((dx, dy)) 127 | return flow 128 | 129 | 130 | def quantize(arr, min_val, max_val, levels, dtype=np.int64): 131 | """Quantize an array of (-inf, inf) to [0, levels-1]. 132 | 133 | Args: 134 | arr (ndarray): Input array. 135 | min_val (scalar): Minimum value to be clipped. 136 | max_val (scalar): Maximum value to be clipped. 137 | levels (int): Quantization levels. 138 | dtype (np.type): The type of the quantized array. 139 | 140 | Returns: 141 | tuple: Quantized array. 142 | """ 143 | if not (isinstance(levels, int) and levels > 1): 144 | raise ValueError( 145 | f'levels must be a positive integer, but got {levels}') 146 | if min_val >= max_val: 147 | raise ValueError( 148 | f'min_val ({min_val}) must be smaller than max_val ({max_val})') 149 | 150 | arr = np.clip(arr, min_val, max_val) - min_val 151 | quantized_arr = np.minimum( 152 | np.floor(levels * arr / (max_val - min_val)).astype(dtype), levels - 1) 153 | 154 | return quantized_arr 155 | 156 | 157 | def dequantize(arr, min_val, max_val, levels, dtype=np.float64): 158 | """Dequantize an array. 159 | 160 | Args: 161 | arr (ndarray): Input array. 162 | min_val (scalar): Minimum value to be clipped. 163 | max_val (scalar): Maximum value to be clipped. 164 | levels (int): Quantization levels. 165 | dtype (np.type): The type of the dequantized array. 166 | 167 | Returns: 168 | tuple: Dequantized array. 169 | """ 170 | if not (isinstance(levels, int) and levels > 1): 171 | raise ValueError( 172 | f'levels must be a positive integer, but got {levels}') 173 | if min_val >= max_val: 174 | raise ValueError( 175 | f'min_val ({min_val}) must be smaller than max_val ({max_val})') 176 | 177 | dequantized_arr = (arr + 0.5).astype(dtype) * (max_val - 178 | min_val) / levels + min_val 179 | 180 | return dequantized_arr 181 | -------------------------------------------------------------------------------- /basicsr/utils/img_util.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import math 3 | import numpy as np 4 | import os 5 | import torch 6 | from torchvision.utils import make_grid 7 | 8 | 9 | def img2tensor(imgs, bgr2rgb=True, float32=True): 10 | """Numpy array to tensor. 11 | 12 | Args: 13 | imgs (list[ndarray] | ndarray): Input images. 
14 | bgr2rgb (bool): Whether to change bgr to rgb.
15 | float32 (bool): Whether to change to float32.
16 | 
17 | Returns:
18 | list[tensor] | tensor: Tensor images. If returned results only have
19 | one element, just return tensor.
20 | """
21 | 
22 | def _totensor(img, bgr2rgb, float32):
23 | if img.shape[2] == 3 and bgr2rgb:
24 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
25 | img = torch.from_numpy(img.transpose(2, 0, 1))
26 | if float32:
27 | img = img.float()
28 | return img
29 | 
30 | if isinstance(imgs, list):
31 | return [_totensor(img, bgr2rgb, float32) for img in imgs]
32 | else:
33 | return _totensor(imgs, bgr2rgb, float32)
34 | 
35 | 
36 | def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
37 | """Convert torch Tensors into image numpy arrays.
38 | 
39 | After clamping to [min, max], values will be normalized to [0, 1].
40 | 
41 | Args:
42 | tensor (Tensor or list[Tensor]): Accept shapes:
43 | 1) 4D mini-batch Tensor of shape (B x 3/1 x H x W);
44 | 2) 3D Tensor of shape (3/1 x H x W);
45 | 3) 2D Tensor of shape (H x W).
46 | Tensor channel should be in RGB order.
47 | rgb2bgr (bool): Whether to change rgb to bgr.
48 | out_type (numpy type): output types. If ``np.uint8``, transform outputs
49 | to uint8 type with range [0, 255]; otherwise, float type with
50 | range [0, 1]. Default: ``np.uint8``.
51 | min_max (tuple[int]): min and max values for clamp.
52 | 
53 | Returns:
54 | (ndarray or list): 3D ndarray of shape (H x W x C) OR 2D ndarray of
55 | shape (H x W). The channel order is BGR.
56 | """
57 | if not (torch.is_tensor(tensor) or
58 | (isinstance(tensor, list)
59 | and all(torch.is_tensor(t) for t in tensor))):
60 | raise TypeError(
61 | f'tensor or list of tensors expected, got {type(tensor)}')
62 | 
63 | if torch.is_tensor(tensor):
64 | tensor = [tensor]
65 | result = []
66 | for _tensor in tensor:
67 | _tensor = _tensor.squeeze(0).float().detach().cpu().clamp_(*min_max)
68 | _tensor = (_tensor - min_max[0]) / (min_max[1] - min_max[0])
69 | 
70 | n_dim = _tensor.dim()
71 | if n_dim == 4:
72 | img_np = make_grid(
73 | _tensor, nrow=int(math.sqrt(_tensor.size(0))),
74 | normalize=False).numpy()
75 | img_np = img_np.transpose(1, 2, 0)
76 | if rgb2bgr:
77 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
78 | elif n_dim == 3:
79 | img_np = _tensor.numpy()
80 | img_np = img_np.transpose(1, 2, 0)
81 | if img_np.shape[2] == 1:  # gray image
82 | img_np = np.squeeze(img_np, axis=2)
83 | else:
84 | if rgb2bgr:
85 | img_np = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
86 | elif n_dim == 2:
87 | img_np = _tensor.numpy()
88 | else:
89 | raise TypeError('Only support 4D, 3D or 2D tensor. '
90 | f'But received with dimension: {n_dim}')
91 | if out_type == np.uint8:
92 | # Unlike MATLAB, numpy's uint8 cast will NOT round by default.
93 | img_np = (img_np * 255.0).round()
94 | img_np = img_np.astype(out_type)
95 | result.append(img_np)
96 | if len(result) == 1:
97 | result = result[0]
98 | return result
99 | 
100 | 
101 | def imfrombytes(content, flag='color', float32=False):
102 | """Read an image from bytes.
103 | 
104 | Args:
105 | content (bytes): Image bytes got from files or other streams.
106 | flag (str): Flags specifying the color type of a loaded image,
107 | candidates are `color`, `grayscale` and `unchanged`.
108 | float32 (bool): Whether to change to float32. If True, will also
109 | normalize values to [0, 1]. Default: False.
110 | 
111 | Returns:
112 | ndarray: Loaded image array.
113 | """
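# Example (hypothetical path, for illustration only): bytes fetched via a
# FileClient can be decoded into a float BGR image in [0, 1]:
#   file_client = FileClient('disk')
#   img_bytes = file_client.get('./datasets/example/0001.png')
#   img = imfrombytes(img_bytes, flag='color', float32=True)  # HWC, BGR order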
113 | """ 114 | img_np = np.frombuffer(content, np.uint8) 115 | imread_flags = { 116 | 'color': cv2.IMREAD_COLOR, 117 | 'grayscale': cv2.IMREAD_GRAYSCALE, 118 | 'unchanged': cv2.IMREAD_UNCHANGED 119 | } 120 | if img_np is None: 121 | raise Exception('None .. !!!') 122 | print("12230033neibu1",img_np) 123 | img = cv2.imdecode(img_np, imread_flags[flag]) 124 | print("12230033neibu",img) 125 | if float32: 126 | img = img.astype(np.float32) / 255. 127 | return img 128 | 129 | def imfrombytesDP(content, flag='color', float32=False): 130 | """Read an image from bytes. 131 | 132 | Args: 133 | content (bytes): Image bytes got from files or other streams. 134 | flag (str): Flags specifying the color type of a loaded image, 135 | candidates are `color`, `grayscale` and `unchanged`. 136 | float32 (bool): Whether to change to float32., If True, will also norm 137 | to [0, 1]. Default: False. 138 | 139 | Returns: 140 | ndarray: Loaded image array. 141 | """ 142 | img_np = np.frombuffer(content, np.uint8) 143 | if img_np is None: 144 | raise Exception('None .. !!!') 145 | img = cv2.imdecode(img_np, cv2.IMREAD_UNCHANGED) 146 | if float32: 147 | img = img.astype(np.float32) / 65535. 148 | return img 149 | 150 | def padding(img_lq, img_gt, gt_size): 151 | h, w, _ = img_lq.shape 152 | 153 | h_pad = max(0, gt_size - h) 154 | w_pad = max(0, gt_size - w) 155 | 156 | if h_pad == 0 and w_pad == 0: 157 | return img_lq, img_gt 158 | 159 | img_lq = cv2.copyMakeBorder(img_lq, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 160 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 161 | # print('img_lq', img_lq.shape, img_gt.shape) 162 | if img_lq.ndim == 2: 163 | img_lq = np.expand_dims(img_lq, axis=2) 164 | if img_gt.ndim == 2: 165 | img_gt = np.expand_dims(img_gt, axis=2) 166 | return img_lq, img_gt 167 | 168 | def padding_DP(img_lqL, img_lqR, img_gt, gt_size): 169 | h, w, _ = img_gt.shape 170 | 171 | h_pad = max(0, gt_size - h) 172 | w_pad = max(0, gt_size - w) 173 | 174 | if h_pad == 0 and w_pad == 0: 175 | return img_lqL, img_lqR, img_gt 176 | 177 | img_lqL = cv2.copyMakeBorder(img_lqL, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 178 | img_lqR = cv2.copyMakeBorder(img_lqR, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 179 | img_gt = cv2.copyMakeBorder(img_gt, 0, h_pad, 0, w_pad, cv2.BORDER_REFLECT) 180 | # print('img_lq', img_lq.shape, img_gt.shape) 181 | return img_lqL, img_lqR, img_gt 182 | 183 | def imwrite(img, file_path, params=None, auto_mkdir=True): 184 | """Write image to file. 185 | 186 | Args: 187 | img (ndarray): Image array to be written. 188 | file_path (str): Image file path. 189 | params (None or list): Same as opencv's :func:`imwrite` interface. 190 | auto_mkdir (bool): If the parent folder of `file_path` does not exist, 191 | whether to create it automatically. 192 | 193 | Returns: 194 | bool: Successful or not. 195 | """ 196 | if auto_mkdir: 197 | dir_name = os.path.abspath(os.path.dirname(file_path)) 198 | os.makedirs(dir_name, exist_ok=True) 199 | return cv2.imwrite(file_path, img, params) 200 | 201 | 202 | def crop_border(imgs, crop_border): 203 | """Crop borders of images. 204 | 205 | Args: 206 | imgs (list[ndarray] | ndarray): Images with shape (h, w, c). 207 | crop_border (int): Crop border for each end of height and weight. 208 | 209 | Returns: 210 | list[ndarray]: Cropped images. 211 | """ 212 | if crop_border == 0: 213 | return imgs 214 | else: 215 | if isinstance(imgs, list): 216 | return [ 217 | v[crop_border:-crop_border, crop_border:-crop_border, ...] 
218 | for v in imgs
219 | ]
220 | else:
221 | return imgs[crop_border:-crop_border, crop_border:-crop_border,
222 | ...]
223 | 
--------------------------------------------------------------------------------
/basicsr/utils/lmdb_util.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import lmdb
3 | import sys
4 | from multiprocessing import Pool
5 | from os import path as osp
6 | from tqdm import tqdm
7 | 
8 | 
9 | def make_lmdb_from_imgs(data_path,
10 | lmdb_path,
11 | img_path_list,
12 | keys,
13 | batch=5000,
14 | compress_level=1,
15 | multiprocessing_read=False,
16 | n_thread=40,
17 | map_size=None):
18 | """Make lmdb from images.
19 | 
20 | Contents of the lmdb. The file structure is:
21 | example.lmdb
22 | ├── data.mdb
23 | ├── lock.mdb
24 | ├── meta_info.txt
25 | 
26 | The data.mdb and lock.mdb are standard lmdb files and you can refer to
27 | https://lmdb.readthedocs.io/en/release/ for more details.
28 | 
29 | The meta_info.txt is a specified txt file to record the meta information
30 | of our datasets. It will be automatically created when preparing
31 | datasets by our provided dataset tools.
32 | Each line in the txt file records 1) image name (with extension),
33 | 2) image shape, and 3) compression level, separated by a white space.
34 | 
35 | For example, the meta information could be:
36 | `000_00000000.png (720,1280,3) 1`, which means:
37 | 1) image name (with extension): 000_00000000.png;
38 | 2) image shape: (720,1280,3);
39 | 3) compression level: 1
40 | 
41 | We use the image name without extension as the lmdb key.
42 | 
43 | If `multiprocessing_read` is True, it will read all the images to memory
44 | using multiprocessing. Thus, your server needs to have enough memory.
45 | 
46 | Args:
47 | data_path (str): Data path for reading images.
48 | lmdb_path (str): Lmdb save path.
49 | img_path_list (list[str]): Image path list.
50 | keys (list[str]): Used for lmdb keys.
51 | batch (int): After processing batch images, lmdb commits.
52 | Default: 5000.
53 | compress_level (int): Compress level when encoding images. Default: 1.
54 | multiprocessing_read (bool): Whether use multiprocessing to read all
55 | the images to memory. Default: False.
56 | n_thread (int): For multiprocessing.
57 | map_size (int | None): Map size for lmdb env. If None, use the
58 | estimated size from images. Default: None
59 | """
60 | 
61 | assert len(img_path_list) == len(keys), (
62 | 'img_path_list and keys should have the same length, '
63 | f'but got {len(img_path_list)} and {len(keys)}')
64 | print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
65 | print(f'Total images: {len(img_path_list)}')
66 | if not lmdb_path.endswith('.lmdb'):
67 | raise ValueError("lmdb_path must end with '.lmdb'.")
68 | if osp.exists(lmdb_path):
69 | print(f'Folder {lmdb_path} already exists. 
Exit.') 70 | sys.exit(1) 71 | 72 | if multiprocessing_read: 73 | # read all the images to memory (multiprocessing) 74 | dataset = {} # use dict to keep the order for multiprocessing 75 | shapes = {} 76 | print(f'Read images with multiprocessing, #thread: {n_thread} ...') 77 | pbar = tqdm(total=len(img_path_list), unit='image') 78 | 79 | def callback(arg): 80 | """get the image data and update pbar.""" 81 | key, dataset[key], shapes[key] = arg 82 | pbar.update(1) 83 | pbar.set_description(f'Read {key}') 84 | 85 | pool = Pool(n_thread) 86 | for path, key in zip(img_path_list, keys): 87 | pool.apply_async( 88 | read_img_worker, 89 | args=(osp.join(data_path, path), key, compress_level), 90 | callback=callback) 91 | pool.close() 92 | pool.join() 93 | pbar.close() 94 | print(f'Finish reading {len(img_path_list)} images.') 95 | 96 | # create lmdb environment 97 | if map_size is None: 98 | # obtain data size for one image 99 | img = cv2.imread( 100 | osp.join(data_path, img_path_list[0]), cv2.IMREAD_UNCHANGED) 101 | _, img_byte = cv2.imencode( 102 | '.png', img, [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 103 | data_size_per_img = img_byte.nbytes 104 | print('Data size per image is: ', data_size_per_img) 105 | data_size = data_size_per_img * len(img_path_list) 106 | map_size = data_size * 10 107 | 108 | env = lmdb.open(lmdb_path, map_size=map_size) 109 | 110 | # write data to lmdb 111 | pbar = tqdm(total=len(img_path_list), unit='chunk') 112 | txn = env.begin(write=True) 113 | txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 114 | for idx, (path, key) in enumerate(zip(img_path_list, keys)): 115 | pbar.update(1) 116 | pbar.set_description(f'Write {key}') 117 | key_byte = key.encode('ascii') 118 | if multiprocessing_read: 119 | img_byte = dataset[key] 120 | h, w, c = shapes[key] 121 | else: 122 | _, img_byte, img_shape = read_img_worker( 123 | osp.join(data_path, path), key, compress_level) 124 | h, w, c = img_shape 125 | 126 | txn.put(key_byte, img_byte) 127 | # write meta information 128 | txt_file.write(f'{key}.png ({h},{w},{c}) {compress_level}\n') 129 | if idx % batch == 0: 130 | txn.commit() 131 | txn = env.begin(write=True) 132 | pbar.close() 133 | txn.commit() 134 | env.close() 135 | txt_file.close() 136 | print('\nFinish writing lmdb.') 137 | 138 | 139 | def read_img_worker(path, key, compress_level): 140 | """Read image worker. 141 | 142 | Args: 143 | path (str): Image path. 144 | key (str): Image key. 145 | compress_level (int): Compress level when encoding images. 146 | 147 | Returns: 148 | str: Image key. 149 | byte: Image byte. 150 | tuple[int]: Image shape. 151 | """ 152 | 153 | img = cv2.imread(path, cv2.IMREAD_UNCHANGED) 154 | if img.ndim == 2: 155 | h, w = img.shape 156 | c = 1 157 | else: 158 | h, w, c = img.shape 159 | _, img_byte = cv2.imencode('.png', img, 160 | [cv2.IMWRITE_PNG_COMPRESSION, compress_level]) 161 | return (key, img_byte, (h, w, c)) 162 | 163 | 164 | class LmdbMaker(): 165 | """LMDB Maker. 166 | 167 | Args: 168 | lmdb_path (str): Lmdb save path. 169 | map_size (int): Map size for lmdb env. Default: 1024 ** 4, 1TB. 170 | batch (int): After processing batch images, lmdb commits. 171 | Default: 5000. 172 | compress_level (int): Compress level when encoding images. Default: 1. 
173 | """ 174 | 175 | def __init__(self, 176 | lmdb_path, 177 | map_size=1024**4, 178 | batch=5000, 179 | compress_level=1): 180 | if not lmdb_path.endswith('.lmdb'): 181 | raise ValueError("lmdb_path must end with '.lmdb'.") 182 | if osp.exists(lmdb_path): 183 | print(f'Folder {lmdb_path} already exists. Exit.') 184 | sys.exit(1) 185 | 186 | self.lmdb_path = lmdb_path 187 | self.batch = batch 188 | self.compress_level = compress_level 189 | self.env = lmdb.open(lmdb_path, map_size=map_size) 190 | self.txn = self.env.begin(write=True) 191 | self.txt_file = open(osp.join(lmdb_path, 'meta_info.txt'), 'w') 192 | self.counter = 0 193 | 194 | def put(self, img_byte, key, img_shape): 195 | self.counter += 1 196 | key_byte = key.encode('ascii') 197 | self.txn.put(key_byte, img_byte) 198 | # write meta information 199 | h, w, c = img_shape 200 | self.txt_file.write(f'{key}.png ({h},{w},{c}) {self.compress_level}\n') 201 | if self.counter % self.batch == 0: 202 | self.txn.commit() 203 | self.txn = self.env.begin(write=True) 204 | 205 | def close(self): 206 | self.txn.commit() 207 | self.env.close() 208 | self.txt_file.close() 209 | -------------------------------------------------------------------------------- /basicsr/utils/logger.py: -------------------------------------------------------------------------------- 1 | import datetime 2 | import logging 3 | import time 4 | 5 | from .dist_util import get_dist_info, master_only 6 | 7 | initialized_logger = {} 8 | 9 | 10 | class MessageLogger(): 11 | """Message logger for printing. 12 | 13 | Args: 14 | opt (dict): Config. It contains the following keys: 15 | name (str): Exp name. 16 | logger (dict): Contains 'print_freq' (str) for logger interval. 17 | train (dict): Contains 'total_iter' (int) for total iters. 18 | use_tb_logger (bool): Use tensorboard logger. 19 | start_iter (int): Start iter. Default: 1. 20 | tb_logger (obj:`tb_logger`): Tensorboard logger. Default: None. 21 | """ 22 | 23 | def __init__(self, opt, start_iter=1, tb_logger=None): 24 | self.exp_name = opt['name'] 25 | self.interval = opt['logger']['print_freq'] 26 | self.start_iter = start_iter 27 | self.max_iters = opt['train']['total_iter'] 28 | self.use_tb_logger = opt['logger']['use_tb_logger'] 29 | self.tb_logger = tb_logger 30 | self.start_time = time.time() 31 | self.logger = get_root_logger() 32 | 33 | @master_only 34 | def __call__(self, log_vars): 35 | """Format logging message. 36 | 37 | Args: 38 | log_vars (dict): It contains the following keys: 39 | epoch (int): Epoch number. 40 | iter (int): Current iter. 41 | lrs (list): List for learning rates. 42 | 43 | time (float): Iter time. 44 | data_time (float): Data time for each iter. 
45 | """ 46 | # epoch, iter, learning rates 47 | epoch = log_vars.pop('epoch') 48 | current_iter = log_vars.pop('iter') 49 | lrs = log_vars.pop('lrs') 50 | 51 | message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, ' f'iter:{current_iter:8,d}, lr:(') 52 | for v in lrs: 53 | message += f'{v:.3e},' 54 | message += ')] ' 55 | 56 | # time and estimated time 57 | if 'time' in log_vars.keys(): 58 | iter_time = log_vars.pop('time') 59 | data_time = log_vars.pop('data_time') 60 | 61 | total_time = time.time() - self.start_time 62 | time_sec_avg = total_time / (current_iter - self.start_iter + 1) 63 | eta_sec = time_sec_avg * (self.max_iters - current_iter - 1) 64 | eta_str = str(datetime.timedelta(seconds=int(eta_sec))) 65 | message += f'[eta: {eta_str}, ' 66 | message += f'time (data): {iter_time:.3f} ({data_time:.3f})] ' 67 | 68 | # other items, especially losses 69 | for k, v in log_vars.items(): 70 | message += f'{k}: {v:.4e} ' 71 | # tensorboard logger 72 | if self.use_tb_logger and 'debug' not in self.exp_name: 73 | if k.startswith('l_'): 74 | self.tb_logger.add_scalar(f'losses/{k}', v, current_iter) 75 | else: 76 | self.tb_logger.add_scalar(k, v, current_iter) 77 | self.logger.info(message) 78 | 79 | 80 | @master_only 81 | def init_tb_logger(log_dir): 82 | from torch.utils.tensorboard import SummaryWriter 83 | tb_logger = SummaryWriter(log_dir=log_dir) 84 | return tb_logger 85 | 86 | 87 | @master_only 88 | def init_wandb_logger(opt): 89 | """We now only use wandb to sync tensorboard log.""" 90 | import wandb 91 | logger = logging.getLogger('basicsr') 92 | 93 | project = opt['logger']['wandb']['project'] 94 | resume_id = opt['logger']['wandb'].get('resume_id') 95 | if resume_id: 96 | wandb_id = resume_id 97 | resume = 'allow' 98 | logger.warning(f'Resume wandb logger with id={wandb_id}.') 99 | else: 100 | wandb_id = wandb.util.generate_id() 101 | resume = 'never' 102 | 103 | wandb.init(id=wandb_id, resume=resume, name=opt['name'], config=opt, project=project, sync_tensorboard=True) 104 | 105 | logger.info(f'Use wandb logger with id={wandb_id}; project={project}.') 106 | 107 | 108 | def get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=None): 109 | """Get the root logger. 110 | 111 | The logger will be initialized if it has not been initialized. By default a 112 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 113 | also be added. 114 | 115 | Args: 116 | logger_name (str): root logger name. Default: 'basicsr'. 117 | log_file (str | None): The log filename. If specified, a FileHandler 118 | will be added to the root logger. 119 | log_level (int): The root logger level. Note that only the process of 120 | rank 0 is affected, while other processes will set the level to 121 | "Error" and be silent most of the time. 122 | 123 | Returns: 124 | logging.Logger: The root logger. 
125 | """ 126 | logger = logging.getLogger(logger_name) 127 | # if the logger has been initialized, just return it 128 | if logger_name in initialized_logger: 129 | return logger 130 | 131 | format_str = '%(asctime)s %(levelname)s: %(message)s' 132 | stream_handler = logging.StreamHandler() 133 | stream_handler.setFormatter(logging.Formatter(format_str)) 134 | logger.addHandler(stream_handler) 135 | logger.propagate = False 136 | rank, _ = get_dist_info() 137 | if rank != 0: 138 | logger.setLevel('ERROR') 139 | elif log_file is not None: 140 | logger.setLevel(log_level) 141 | # add file handler 142 | file_handler = logging.FileHandler(log_file, 'w') 143 | file_handler.setFormatter(logging.Formatter(format_str)) 144 | file_handler.setLevel(log_level) 145 | logger.addHandler(file_handler) 146 | initialized_logger[logger_name] = True 147 | return logger 148 | 149 | 150 | def get_env_info(): 151 | """Get environment information. 152 | 153 | Currently, only log the software version. 154 | """ 155 | import torch 156 | import torchvision 157 | 158 | from basicsr.version import __version__ 159 | msg = r""" 160 | ____ _ _____ ____ 161 | / __ ) ____ _ _____ (_)_____/ ___/ / __ \ 162 | / __ |/ __ `// ___// // ___/\__ \ / /_/ / 163 | / /_/ // /_/ /(__ )/ // /__ ___/ // _, _/ 164 | /_____/ \__,_//____//_/ \___//____//_/ |_| 165 | ______ __ __ __ __ 166 | / ____/____ ____ ____/ / / / __ __ _____ / /__ / / 167 | / / __ / __ \ / __ \ / __ / / / / / / // ___// //_/ / / 168 | / /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/ 169 | \____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_) 170 | """ 171 | msg += ('\nVersion Information: ' 172 | f'\n\tBasicSR: {__version__}' 173 | f'\n\tPyTorch: {torch.__version__}' 174 | f'\n\tTorchVision: {torchvision.__version__}') 175 | return msg -------------------------------------------------------------------------------- /basicsr/utils/misc.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | import random 4 | import time 5 | import torch 6 | from os import path as osp 7 | 8 | from .dist_util import master_only 9 | from .logger import get_root_logger 10 | 11 | 12 | def set_random_seed(seed): 13 | """Set random seeds.""" 14 | random.seed(seed) 15 | np.random.seed(seed) 16 | torch.manual_seed(seed) 17 | torch.cuda.manual_seed(seed) 18 | torch.cuda.manual_seed_all(seed) 19 | 20 | 21 | def get_time_str(): 22 | return time.strftime('%Y%m%d_%H%M%S', time.localtime()) 23 | 24 | 25 | def mkdir_and_rename(path): 26 | """mkdirs. If path exists, rename it with timestamp and create a new one. 27 | 28 | Args: 29 | path (str): Folder path. 30 | """ 31 | if osp.exists(path): 32 | new_name = path + '_archived_' + get_time_str() 33 | print(f'Path already exists. Rename it to {new_name}', flush=True) 34 | os.rename(path, new_name) 35 | os.makedirs(path, exist_ok=True) 36 | 37 | 38 | @master_only 39 | def make_exp_dirs(opt): 40 | """Make dirs for experiments.""" 41 | path_opt = opt['path'].copy() 42 | if opt['is_train']: 43 | mkdir_and_rename(path_opt.pop('experiments_root')) 44 | else: 45 | mkdir_and_rename(path_opt.pop('results_root')) 46 | for key, path in path_opt.items(): 47 | if ('strict_load' not in key) and ('pretrain_network' 48 | not in key) and ('resume' 49 | not in key): 50 | os.makedirs(path, exist_ok=True) 51 | 52 | 53 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 54 | """Scan a directory to find the interested files. 
55 | 56 | Args: 57 | dir_path (str): Path of the directory. 58 | suffix (str | tuple(str), optional): File suffix that we are 59 | interested in. Default: None. 60 | recursive (bool, optional): If set to True, recursively scan the 61 | directory. Default: False. 62 | full_path (bool, optional): If set to True, include the dir_path. 63 | Default: False. 64 | 65 | Returns: 66 | A generator for all the interested files with relative paths. 67 | """ 68 | 69 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 70 | raise TypeError('"suffix" must be a string or tuple of strings') 71 | 72 | root = dir_path 73 | 74 | def _scandir(dir_path, suffix, recursive): 75 | for entry in os.scandir(dir_path): 76 | if not entry.name.startswith('.') and entry.is_file(): 77 | if full_path: 78 | return_path = entry.path 79 | else: 80 | return_path = osp.relpath(entry.path, root) 81 | 82 | if suffix is None: 83 | yield return_path 84 | elif return_path.endswith(suffix): 85 | yield return_path 86 | else: 87 | if recursive: 88 | yield from _scandir( 89 | entry.path, suffix=suffix, recursive=recursive) 90 | else: 91 | continue 92 | 93 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 94 | 95 | def scandir_SIDD(dir_path, keywords=None, recursive=False, full_path=False): 96 | """Scan a directory to find the interested files. 97 | 98 | Args: 99 | dir_path (str): Path of the directory. 100 | keywords (str | tuple(str), optional): File keywords that we are 101 | interested in. Default: None. 102 | recursive (bool, optional): If set to True, recursively scan the 103 | directory. Default: False. 104 | full_path (bool, optional): If set to True, include the dir_path. 105 | Default: False. 106 | 107 | Returns: 108 | A generator for all the interested files with relative paths. 109 | """ 110 | 111 | if (keywords is not None) and not isinstance(keywords, (str, tuple)): 112 | raise TypeError('"keywords" must be a string or tuple of strings') 113 | 114 | root = dir_path 115 | 116 | def _scandir(dir_path, keywords, recursive): 117 | for entry in os.scandir(dir_path): 118 | if not entry.name.startswith('.') and entry.is_file(): 119 | if full_path: 120 | return_path = entry.path 121 | else: 122 | return_path = osp.relpath(entry.path, root) 123 | 124 | if keywords is None: 125 | yield return_path 126 | elif return_path.find(keywords) != -1:  # match anywhere in the path, including position 0 127 | yield return_path 128 | else: 129 | if recursive: 130 | yield from _scandir( 131 | entry.path, keywords=keywords, recursive=recursive) 132 | else: 133 | continue 134 | 135 | return _scandir(dir_path, keywords=keywords, recursive=recursive) 136 | 137 | def check_resume(opt, resume_iter): 138 | """Check resume states and pretrain_network paths. 139 | 140 | Args: 141 | opt (dict): Options. 142 | resume_iter (int): Resume iteration.
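Example:
    A sketch of the intended flow (the iteration count is illustrative):

    >>> # with opt['path']['resume_state'] set, every network_* entry has
    >>> # its pretrain path rewritten to <models>/net_<name>_<iter>.pth
    >>> check_resume(opt, resume_iter=400000)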
143 | """ 144 | logger = get_root_logger() 145 | if opt['path']['resume_state']: 146 | # get all the networks 147 | networks = [key for key in opt.keys() if key.startswith('network_')] 148 | flag_pretrain = False 149 | for network in networks: 150 | if opt['path'].get(f'pretrain_{network}') is not None: 151 | flag_pretrain = True 152 | if flag_pretrain: 153 | logger.warning( 154 | 'pretrain_network path will be ignored during resuming.') 155 | # set pretrained model paths 156 | for network in networks: 157 | name = f'pretrain_{network}' 158 | basename = network.replace('network_', '') 159 | if opt['path'].get('ignore_resume_networks') is None or ( 160 | basename not in opt['path']['ignore_resume_networks']): 161 | opt['path'][name] = osp.join( 162 | opt['path']['models'], f'net_{basename}_{resume_iter}.pth') 163 | logger.info(f"Set {name} to {opt['path'][name]}") 164 | 165 | 166 | def sizeof_fmt(size, suffix='B'): 167 | """Get human readable file size. 168 | 169 | Args: 170 | size (int): File size. 171 | suffix (str): Suffix. Default: 'B'. 172 | 173 | Return: 174 | str: Formated file siz. 175 | """ 176 | for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']: 177 | if abs(size) < 1024.0: 178 | return f'{size:3.1f} {unit}{suffix}' 179 | size /= 1024.0 180 | return f'{size:3.1f} Y{suffix}' 181 | -------------------------------------------------------------------------------- /basicsr/utils/options.py: -------------------------------------------------------------------------------- 1 | import yaml 2 | from collections import OrderedDict 3 | from os import path as osp 4 | 5 | 6 | def ordered_yaml(): 7 | """Support OrderedDict for yaml. 8 | 9 | Returns: 10 | yaml Loader and Dumper. 11 | """ 12 | try: 13 | from yaml import CDumper as Dumper 14 | from yaml import CLoader as Loader 15 | except ImportError: 16 | from yaml import Dumper, Loader 17 | 18 | _mapping_tag = yaml.resolver.BaseResolver.DEFAULT_MAPPING_TAG 19 | 20 | def dict_representer(dumper, data): 21 | return dumper.represent_dict(data.items()) 22 | 23 | def dict_constructor(loader, node): 24 | return OrderedDict(loader.construct_pairs(node)) 25 | 26 | Dumper.add_representer(OrderedDict, dict_representer) 27 | Loader.add_constructor(_mapping_tag, dict_constructor) 28 | return Loader, Dumper 29 | 30 | 31 | def parse(opt_path, is_train=True): 32 | """Parse option file. 33 | 34 | Args: 35 | opt_path (str): Option file path. 36 | is_train (str): Indicate whether in training or not. Default: True. 37 | 38 | Returns: 39 | (dict): Options. 
40 | """ 41 | with open(opt_path, mode='r') as f: 42 | Loader, _ = ordered_yaml() 43 | opt = yaml.load(f, Loader=Loader) 44 | 45 | opt['is_train'] = is_train 46 | 47 | # datasets 48 | for phase, dataset in opt['datasets'].items(): 49 | # for several datasets, e.g., test_1, test_2 50 | phase = phase.split('_')[0] 51 | dataset['phase'] = phase 52 | if 'scale' in opt: 53 | dataset['scale'] = opt['scale'] 54 | if dataset.get('dataroot_gt') is not None: 55 | dataset['dataroot_gt'] = osp.expanduser(dataset['dataroot_gt']) 56 | if dataset.get('dataroot_lq') is not None: 57 | dataset['dataroot_lq'] = osp.expanduser(dataset['dataroot_lq']) 58 | 59 | # paths 60 | for key, val in opt['path'].items(): 61 | if (val is not None) and ('resume_state' in key 62 | or 'pretrain_network' in key): 63 | opt['path'][key] = osp.expanduser(val) 64 | opt['path']['root'] = osp.abspath( 65 | osp.join(__file__, osp.pardir, osp.pardir, osp.pardir)) 66 | if is_train: 67 | experiments_root = osp.join(opt['path']['root'], 'experiments', 68 | opt['name']) 69 | opt['path']['experiments_root'] = experiments_root 70 | opt['path']['models'] = osp.join(experiments_root, 'models') 71 | opt['path']['training_states'] = osp.join(experiments_root, 72 | 'training_states') 73 | opt['path']['log'] = experiments_root 74 | opt['path']['visualization'] = osp.join(experiments_root, 75 | 'visualization') 76 | 77 | # change some options for debug mode 78 | if 'debug' in opt['name']: 79 | if 'val' in opt: 80 | opt['val']['val_freq'] = 8 81 | opt['logger']['print_freq'] = 1 82 | opt['logger']['save_checkpoint_freq'] = 8 83 | else: # test 84 | results_root = osp.join(opt['path']['root'], 'results', opt['name']) 85 | opt['path']['results_root'] = results_root 86 | opt['path']['log'] = results_root 87 | opt['path']['visualization'] = osp.join(results_root, 'visualization') 88 | 89 | return opt 90 | 91 | 92 | def dict2str(opt, indent_level=1): 93 | """dict to string for printing options. 94 | 95 | Args: 96 | opt (dict): Option dict. 97 | indent_level (int): Indent level. Default: 1. 98 | 99 | Return: 100 | (str): Option string for printing. 101 | """ 102 | msg = '\n' 103 | for k, v in opt.items(): 104 | if isinstance(v, dict): 105 | msg += ' ' * (indent_level * 2) + k + ':[' 106 | msg += dict2str(v, indent_level + 1) 107 | msg += ' ' * (indent_level * 2) + ']\n' 108 | else: 109 | msg += ' ' * (indent_level * 2) + k + ': ' + str(v) + '\n' 110 | return msg 111 | -------------------------------------------------------------------------------- /options/test_JARNet_LAP.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: test_JARNet_LAP_20240114 3 | model_type: ImageMlossFlowModel 4 | scale: 1 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | test: 11 | name: ValSet 12 | type: LAPDataSet 13 | dataroot_gt: ~ # Please fill out the path to the LAP evaluation dataset. 14 | io_backend: 15 | type: disk 16 | 17 | network_g: 18 | type: JARNet 19 | img_channel: 1 20 | width: 32 21 | middle_blk_num: 4 22 | enc_blk_nums: [4, 4, 4, 4] 23 | dec_blk_nums: [4, 4, 4, 4] 24 | window_size_e_fft: [64, -1, -1, -1] 25 | 26 | # path 27 | path: 28 | pretrain_network_g: ~ # Please fill out the path to the pre-trained model. 
29 | strict_load_g: true 30 | resume_state: ~ 31 | 32 | # validation settings 33 | val: 34 | save_img: false 35 | use_image: true # true and false give slightly different PSNR/SSIM values; we use true for PSNR and SSIM, and false for gmsd (otherwise an error occurs) 36 | 37 | metrics: 38 | psnr: 39 | type: calculate_psnr 40 | crop_border: 0 41 | test_y_channel: false 42 | ssim: 43 | type: calculate_ssim 44 | crop_border: 0 45 | test_y_channel: false 46 | # gmsd: 47 | # type: calculate_gmsd 48 | 49 | # dist training settings 50 | dist_params: 51 | backend: nccl 52 | port: 29500 53 | -------------------------------------------------------------------------------- /options/test_NAFNet_LAP.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: test_NAFNet_LAP_20240114 3 | model_type: ImageMlossModel 4 | scale: 1 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | test: 11 | name: ValSet 12 | type: LAPDataSetNoWarp 13 | dataroot_gt: ~ # Please fill out the path to the LAP evaluation dataset. 14 | io_backend: 15 | type: disk 16 | 17 | network_g: 18 | type: NAFNet 19 | img_channel: 1 20 | width: 64 21 | middle_blk_num: 1 22 | enc_blk_nums: [1, 1, 1, 28] 23 | dec_blk_nums: [1, 1, 1, 1] 24 | 25 | # path 26 | path: 27 | pretrain_network_g: ~ 28 | strict_load_g: true 29 | resume_state: ~ 30 | 31 | # validation settings 32 | val: 33 | save_img: false 34 | use_image: true # true and false give slightly different PSNR/SSIM values; we use true for PSNR and SSIM, and false for gmsd (otherwise an error occurs) 35 | 36 | metrics: 37 | psnr: 38 | type: calculate_psnr 39 | crop_border: 0 40 | test_y_channel: false 41 | ssim: 42 | type: calculate_ssim 43 | crop_border: 0 44 | test_y_channel: false 45 | # gmsd: 46 | # type: calculate_gmsd 47 | 48 | # dist training settings 49 | dist_params: 50 | backend: nccl 51 | port: 29500 52 | -------------------------------------------------------------------------------- /options/train_JARNet_LAP.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_JARNet_LAP_20240114 3 | model_type: ImageMlossFlowModel 4 | scale: 1 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: LAPDataSet 13 | dataroot_gt: ~ # Please fill out the path to the LAP training dataset. 14 | filename_tmpl: '{}' 15 | io_backend: 16 | type: disk 17 | 18 | # data loader 19 | use_shuffle: true 20 | num_worker_per_gpu: 4 21 | batch_size_per_gpu: 4 22 | 23 | mini_batch_sizes: [4] 24 | iters: [450000] 25 | gt_size: 128 26 | gt_sizes: [128] 27 | 28 | dataset_enlarge_ratio: 1 29 | prefetch_mode: ~ 30 | 31 | val: 32 | name: ValSet 33 | type: LAPDataSet 34 | dataroot_gt: ~ # Please fill out the path to the LAP evaluation dataset.
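# e.g. (an illustrative layout only):
# dataroot_gt: ./datasets/LAP/val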
35 | io_backend: 36 | type: disk 37 | 38 | network_g: 39 | type: JARNet 40 | img_channel: 1 41 | width: 32 42 | middle_blk_num: 4 43 | enc_blk_nums: [4, 4, 4, 4] 44 | dec_blk_nums: [4, 4, 4, 4] 45 | window_size_e_fft: [64, -1, -1, -1] 46 | 47 | # path 48 | path: 49 | pretrain_network_g: ~ 50 | strict_load_g: true 51 | resume_state: ~ 52 | 53 | # training settings 54 | train: 55 | total_iter: 450000 56 | warmup_iter: -1 # no warm up 57 | use_grad_clip: true 58 | 59 | scheduler: 60 | type: TrueCosineAnnealingLR 61 | T_max: 450000 62 | eta_min: !!float 1e-7 63 | 64 | mixing_augs: 65 | mixup: false 66 | mixup_beta: 1.2 67 | use_identity: true 68 | 69 | optim_g: 70 | type: AdamW 71 | lr: !!float 3e-4 72 | weight_decay: !!float 1e-3 73 | betas: [0.9, 0.999] 74 | 75 | # losses 76 | pixel_opt: 77 | type: L1Loss 78 | loss_weight: 1 79 | reduction: mean 80 | 81 | flow_opt: 82 | type: PerceptualLoss 83 | loss_weight: 0.0001 84 | reduction: mean 85 | 86 | msfr_opt: 87 | type: FFTLoss 88 | loss_weight: 0.1 89 | reduction: mean 90 | 91 | flow_weight: 0.1 92 | 93 | # validation settings 94 | val: 95 | window_size: 8 96 | val_freq: !!float 5e4 97 | save_img: false 98 | save_last_img: true 99 | rgb2bgr: false 100 | use_image: true 101 | 102 | metrics: 103 | psnr: 104 | type: calculate_psnr 105 | crop_border: 0 106 | test_y_channel: false 107 | ssim: 108 | type: calculate_ssim 109 | crop_border: 0 110 | test_y_channel: false 111 | 112 | # logging settings 113 | logger: 114 | print_freq: 1000 115 | save_checkpoint_freq: !!float 5e4 116 | use_tb_logger: true 117 | wandb: 118 | project: ~ 119 | resume_id: ~ 120 | 121 | # dist training settings 122 | dist_params: 123 | backend: nccl 124 | port: 29500 125 | -------------------------------------------------------------------------------- /options/train_NAFNet_LAP.yml: -------------------------------------------------------------------------------- 1 | # general settings 2 | name: train_NAFNet_LAP_20240114 3 | model_type: ImageMlossModel 4 | scale: 1 5 | num_gpu: 1 # set num_gpu: 0 for cpu mode 6 | manual_seed: 100 7 | 8 | # dataset and data loader settings 9 | datasets: 10 | train: 11 | name: TrainSet 12 | type: LAPDataSetNoWarp 13 | dataroot_gt: ~ # Please fill out the path to the LAP training dataset. 14 | filename_tmpl: '{}' 15 | io_backend: 16 | type: disk 17 | 18 | use_shuffle: true 19 | num_worker_per_gpu: 4 20 | batch_size_per_gpu: 4 21 | 22 | mini_batch_sizes: [4] 23 | iters: [450000] 24 | gt_size: 128 25 | gt_sizes: [128] 26 | 27 | dataset_enlarge_ratio: 1 28 | prefetch_mode: ~ 29 | 30 | val: 31 | name: ValSet 32 | type: LAPDataSetNoWarp 33 | dataroot_gt: ~ # Please fill out the path to the LAP evaluation dataset. 
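# 'disk' reads image files straight from dataroot_gt; a dataset packed with
# basicsr/utils/create_lmdb.py would instead be read with 'type: lmdb'
# (see basicsr/utils/file_client.py for the backends actually implemented).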
34 | io_backend: 35 | type: disk 36 | 37 | network_g: 38 | type: NAFNet 39 | img_channel: 1 40 | width: 64 41 | middle_blk_num: 1 42 | enc_blk_nums: [1, 1, 1, 28] 43 | dec_blk_nums: [1, 1, 1, 1] 44 | 45 | # path 46 | path: 47 | pretrain_network_g: ~ 48 | strict_load_g: true 49 | resume_state: ~ 50 | 51 | # training settings 52 | train: 53 | total_iter: 450000 54 | warmup_iter: -1 # no warm up 55 | use_grad_clip: true 56 | 57 | scheduler: 58 | type: TrueCosineAnnealingLR 59 | T_max: 450000 60 | eta_min: !!float 1e-7 61 | 62 | mixing_augs: 63 | mixup: false 64 | mixup_beta: 1.2 65 | use_identity: true 66 | 67 | optim_g: 68 | type: AdamW 69 | lr: !!float 3e-4 70 | weight_decay: !!float 1e-3 71 | betas: [0.9, 0.9] 72 | 73 | pixel_opt: 74 | type: PSNRLoss 75 | loss_weight: 1 76 | reduction: mean 77 | 78 | # validation settings 79 | val: 80 | window_size: 8 81 | val_freq: !!float 5e4 82 | save_img: false 83 | save_last_img: true 84 | rgb2bgr: false 85 | use_image: true 86 | 87 | metrics: 88 | psnr: 89 | type: calculate_psnr 90 | crop_border: 0 91 | test_y_channel: false 92 | ssim: 93 | type: calculate_ssim 94 | crop_border: 0 95 | test_y_channel: false 96 | 97 | # logging settings 98 | logger: 99 | print_freq: 1000 100 | save_checkpoint_freq: !!float 5e4 101 | use_tb_logger: true 102 | wandb: 103 | project: ~ 104 | resume_id: ~ 105 | 106 | # dist training settings 107 | dist_params: 108 | backend: nccl 109 | port: 29500 110 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | einops 2 | h5py 3 | lmdb 4 | matplotlib 5 | numpy 6 | opencv-python 7 | ptflops 8 | pyyaml 9 | scikit-image 10 | scipy 11 | torch 12 | torchvision 13 | tqdm 14 | wandb -------------------------------------------------------------------------------- /setup.cfg: -------------------------------------------------------------------------------- 1 | [flake8] 2 | ignore = 3 | # line break before binary operator (W503) 4 | W503, 5 | # line break after binary operator (W504) 6 | W504, 7 | max-line-length=79 8 | 9 | [yapf] 10 | based_on_style = pep8 11 | blank_line_before_nested_class_or_def = true 12 | split_before_expression_after_opening_paren = true 13 | 14 | [isort] 15 | line_length = 79 16 | multi_line_output = 0 17 | known_standard_library = pkg_resources,setuptools 18 | known_first_party = basicsr 19 | known_third_party = PIL,cv2,lmdb,numpy,requests,scipy,skimage,torch,torchvision,tqdm,yaml 20 | no_lines_before = STDLIB,LOCALFOLDER 21 | default_section = THIRDPARTY 22 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | 3 | from setuptools import find_packages, setup 4 | 5 | import os 6 | import subprocess 7 | import sys 8 | import time 9 | import torch 10 | from torch.utils.cpp_extension import (BuildExtension, CppExtension, 11 | CUDAExtension) 12 | 13 | version_file = 'basicsr/version.py' 14 | 15 | 16 | def readme(): 17 | return '' 18 | # with open('README.md', encoding='utf-8') as f: 19 | # content = f.read() 20 | # return content 21 | 22 | 23 | def get_git_hash(): 24 | 25 | def _minimal_ext_cmd(cmd): 26 | # construct minimal environment 27 | env = {} 28 | for k in ['SYSTEMROOT', 'PATH', 'HOME']: 29 | v = os.environ.get(k) 30 | if v is not None: 31 | env[k] = v 32 | # LANGUAGE is used on win32 33 | env['LANGUAGE'] = 'C' 34 | env['LANG'] = 'C' 
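# pin the C locale (LANGUAGE/LANG/LC_ALL) so git's output parses reliably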
35 | env['LC_ALL'] = 'C' 36 | out = subprocess.Popen( 37 | cmd, stdout=subprocess.PIPE, env=env).communicate()[0] 38 | return out 39 | 40 | try: 41 | out = _minimal_ext_cmd(['git', 'rev-parse', 'HEAD']) 42 | sha = out.strip().decode('ascii') 43 | except OSError: 44 | sha = 'unknown' 45 | 46 | return sha 47 | 48 | 49 | def get_hash(): 50 | if os.path.exists('.git'): 51 | sha = get_git_hash()[:7] 52 | elif os.path.exists(version_file): 53 | try: 54 | from basicsr.version import __version__ 55 | sha = __version__.split('+')[-1] 56 | except ImportError: 57 | raise ImportError('Unable to get git version') 58 | else: 59 | sha = 'unknown' 60 | 61 | return sha 62 | 63 | 64 | def write_version_py(): 65 | content = """# GENERATED VERSION FILE 66 | # TIME: {} 67 | __version__ = '{}' 68 | short_version = '{}' 69 | version_info = ({}) 70 | """ 71 | sha = get_hash() 72 | with open('VERSION', 'r') as f: 73 | SHORT_VERSION = f.read().strip() 74 | VERSION_INFO = ', '.join( 75 | [x if x.isdigit() else f'"{x}"' for x in SHORT_VERSION.split('.')]) 76 | VERSION = SHORT_VERSION + '+' + sha 77 | 78 | version_file_str = content.format(time.asctime(), VERSION, SHORT_VERSION, 79 | VERSION_INFO) 80 | with open(version_file, 'w') as f: 81 | f.write(version_file_str) 82 | 83 | 84 | def get_version(): 85 | with open(version_file, 'r') as f: 86 | exec(compile(f.read(), version_file, 'exec')) 87 | return locals()['__version__'] 88 | 89 | 90 | def make_cuda_ext(name, module, sources, sources_cuda=None): 91 | if sources_cuda is None: 92 | sources_cuda = [] 93 | define_macros = [] 94 | extra_compile_args = {'cxx': []} 95 | 96 | if torch.cuda.is_available() or os.getenv('FORCE_CUDA', '0') == '1': 97 | define_macros += [('WITH_CUDA', None)] 98 | extension = CUDAExtension 99 | extra_compile_args['nvcc'] = [ 100 | '-D__CUDA_NO_HALF_OPERATORS__', 101 | '-D__CUDA_NO_HALF_CONVERSIONS__', 102 | '-D__CUDA_NO_HALF2_OPERATORS__', 103 | ] 104 | sources += sources_cuda 105 | else: 106 | print(f'Compiling {name} without CUDA') 107 | extension = CppExtension 108 | 109 | return extension( 110 | name=f'{module}.{name}', 111 | sources=[os.path.join(*module.split('.'), p) for p in sources], 112 | define_macros=define_macros, 113 | extra_compile_args=extra_compile_args) 114 | 115 | 116 | def get_requirements(filename='requirements.txt'): 117 | return [] 118 | here = os.path.dirname(os.path.realpath(__file__)) 119 | with open(os.path.join(here, filename), 'r') as f: 120 | requires = [line.replace('\n', '') for line in f.readlines()] 121 | return requires 122 | 123 | 124 | if __name__ == '__main__': 125 | if '--no_cuda_ext' in sys.argv: 126 | ext_modules = [] 127 | sys.argv.remove('--no_cuda_ext') 128 | else: 129 | ext_modules = [ 130 | make_cuda_ext( 131 | name='deform_conv_ext', 132 | module='basicsr.models.ops.dcn', 133 | sources=['src/deform_conv_ext.cpp'], 134 | sources_cuda=[ 135 | 'src/deform_conv_cuda.cpp', 136 | 'src/deform_conv_cuda_kernel.cu' 137 | ]), 138 | make_cuda_ext( 139 | name='fused_act_ext', 140 | module='basicsr.models.ops.fused_act', 141 | sources=['src/fused_bias_act.cpp'], 142 | sources_cuda=['src/fused_bias_act_kernel.cu']), 143 | make_cuda_ext( 144 | name='upfirdn2d_ext', 145 | module='basicsr.models.ops.upfirdn2d', 146 | sources=['src/upfirdn2d.cpp'], 147 | sources_cuda=['src/upfirdn2d_kernel.cu']), 148 | ] 149 | 150 | write_version_py() 151 | setup( 152 | name='basicsr', 153 | version=get_version(), 154 | description='Open Source Image and Video Super-Resolution Toolbox', 155 | long_description=readme(), 156 | 
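# note: readme() and get_requirements() are stubbed out above (returning ''
# and [] respectively), so no long description or install_requires ship.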
author='Xintao Wang', 157 | author_email='xintao.wang@outlook.com', 158 | keywords='computer vision, restoration, super resolution', 159 | url='https://github.com/xinntao/BasicSR', 160 | packages=find_packages( 161 | exclude=('options', 'datasets', 'experiments', 'results', 162 | 'tb_logger', 'wandb')), 163 | classifiers=[ 164 | 'Development Status :: 4 - Beta', 165 | 'License :: OSI Approved :: Apache Software License', 166 | 'Operating System :: OS Independent', 167 | 'Programming Language :: Python :: 3', 168 | 'Programming Language :: Python :: 3.7', 169 | 'Programming Language :: Python :: 3.8', 170 | ], 171 | license='Apache License 2.0', 172 | setup_requires=['cython', 'numpy'], 173 | install_requires=get_requirements(), 174 | ext_modules=ext_modules, 175 | cmdclass={'build_ext': BuildExtension}, 176 | zip_safe=False) 177 | --------------------------------------------------------------------------------
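Taken together, the utilities above cover the usual preprocessing path: scan a folder with scandir, PNG-encode each image with read_img_worker, and commit it to an LMDB with LmdbMaker. A minimal sketch (the dataset paths are illustrative; it assumes these helpers live in basicsr/utils/lmdb_util.py and basicsr/utils/misc.py, as the repository tree suggests):

from os import path as osp

from basicsr.utils.lmdb_util import LmdbMaker, read_img_worker
from basicsr.utils.misc import scandir

data_root = 'datasets/LAP/train'  # illustrative dataset root
maker = LmdbMaker('datasets/LAP/train.lmdb')
for rel_path in scandir(data_root, suffix='.png', recursive=True):
    key = rel_path.rsplit('.', 1)[0]  # image key = relative path sans extension
    # read_img_worker returns (key, png_encoded_bytes, (h, w, c))
    _, img_byte, shape = read_img_worker(osp.join(data_root, rel_path), key, 1)
    maker.put(img_byte, key, shape)
maker.close()  # commits the final partial batch and writes meta_info.txt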