├── .DS_Store
├── LICENSE
├── README.md
├── dataset
│   ├── __pycache__
│   │   ├── dataset_defocusblur.cpython-37.pyc
│   │   ├── dataset_defocusblur.cpython-38.pyc
│   │   ├── dataset_dehaze.cpython-37.pyc
│   │   ├── dataset_dehaze_denseHaze.cpython-37.pyc
│   │   ├── dataset_dehaze_denseHaze.cpython-38.pyc
│   │   ├── dataset_demoire.cpython-37.pyc
│   │   ├── dataset_derain.cpython-37.pyc
│   │   ├── dataset_derain.cpython-38.pyc
│   │   ├── dataset_derain_drop.cpython-37.pyc
│   │   ├── dataset_derain_syn.cpython-37.pyc
│   │   ├── dataset_deshadow.cpython-37.pyc
│   │   ├── dataset_desnow.cpython-37.pyc
│   │   ├── dataset_enhance_smid.cpython-37.pyc
│   │   └── dataset_motiondeblur.cpython-37.pyc
│   ├── dataset_dehaze_denseHaze.py
│   ├── dataset_derain.py
│   └── dataset_derain_drop.py
├── evaluate_PSNR_SSIM.m
├── losses.py
├── model.py
├── options.py
├── requirements.txt
├── script
│   ├── test.sh
│   ├── train_dehaze.sh
│   ├── train_derain.sh
│   └── train_raindrop.sh
├── test
│   ├── test_denseHaze.py
│   ├── test_raindrop.py
│   └── test_spad.py
├── train
│   ├── train_dehaze.py
│   ├── train_derain.py
│   └── train_raindrop.py
├── utils
│   ├── .DS_Store
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── caculate_psnr_ssim.cpython-37.pyc
│   │   ├── caculate_psnr_ssim.cpython-38.pyc
│   │   ├── dataset_utils.cpython-37.pyc
│   │   ├── dataset_utils.cpython-38.pyc
│   │   ├── dir_utils.cpython-37.pyc
│   │   ├── dir_utils.cpython-38.pyc
│   │   ├── image_utils.cpython-37.pyc
│   │   ├── image_utils.cpython-38.pyc
│   │   ├── model_utils.cpython-37.pyc
│   │   └── model_utils.cpython-38.pyc
│   ├── antialias.py
│   ├── bundle_submissions.py
│   ├── caculate_psnr_ssim.py
│   ├── dataset_utils.py
│   ├── dir_utils.py
│   ├── image_utils.py
│   ├── loader.py
│   └── model_utils.py
└── warmup_scheduler
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   ├── __init__.cpython-38.pyc
    │   ├── scheduler.cpython-37.pyc
    │   └── scheduler.cpython-38.pyc
    ├── run.py
    └── scheduler.py
-------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files.
29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. 
If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. 
Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restoration (CVPR 2024) 2 | 3 | [Shihao Zhou](https://joshyzhou.github.io/), [Duosheng Chen](https://github.com/Calvin11311), [Jinshan Pan](https://jspan.github.io/), [Jinglei Shi](https://jingleishi.github.io/), and [Jufeng Yang](https://cv.nankai.edu.cn/) 4 | 5 | 6 | #### News 7 | - **Feb 27, 2024:** AST has been accepted to CVPR 2024 :tada: 8 | 9 |
10 | 11 | 13 | 14 | ## Package dependencies 15 | The project is built with PyTorch 1.9.0, Python 3.7, and CUDA 11.1. You can install the package dependencies with: 16 | ```bash 17 | pip install -r requirements.txt 18 | ``` 19 | ## Training 20 | ### Derain 21 | To train AST on SPAD, you can run: 22 | ```sh 23 | sh script/train_derain.sh 24 | ``` 25 | ### Dehaze 26 | To train AST on Dense-Haze, you can run: 27 | ```sh 28 | sh script/train_dehaze.sh 29 | ``` 30 | ### Raindrop 31 | To train AST on AGAN, you can run: 32 | ```sh 33 | sh script/train_raindrop.sh 34 | ``` 35 | 36 | 37 | ## Evaluation 38 | To evaluate AST, you can run: 39 | 40 | ```sh 41 | sh script/test.sh 42 | ``` 43 | To evaluate on a specific dataset, uncomment the corresponding line in the script. 44 | 45 | 46 | ## Results 47 | Experiments are performed on several image restoration tasks, including rain streak removal, raindrop removal, and haze removal. 48 | Here is a summary table for easy navigation; a note on how the metrics are computed follows the table: 49 |
| Benchmark | Pretrained model | Visual Results |
| :---: | :---: | :---: |
| SPAD | (code:h68m) | (code:wqdg) |
| AGAN | (code:astt) | (code:astt) |
| Dense-Haze | (code:astt) | (code:astt) |
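
PSNR/SSIM are computed on the Y channel (see `evaluate_PSNR_SSIM.m`). For a quick sanity check outside MATLAB, here is a minimal Python sketch of the same Y-channel PSNR; it assumes `uint8` RGB numpy arrays and skips MATLAB's intermediate rounding, so treat it as illustrative rather than the official evaluation:

```python
import numpy as np

def psnr_y(img1, img2):
    """Y-channel PSNR, following the convention of compute_psnr in evaluate_PSNR_SSIM.m."""
    def to_y(img):
        # BT.601 luma, matching MATLAB's rgb2ycbcr for uint8 inputs
        rgb = img.astype(np.float64) / 255.0
        return 16.0 + 65.481 * rgb[..., 0] + 128.553 * rgb[..., 1] + 24.966 * rgb[..., 2]
    diff = to_y(img1) - to_y(img2)
    rmse = np.sqrt(np.mean(diff ** 2))
    return 20.0 * np.log10(255.0 / rmse)
```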
74 | 75 | ## Citation 76 | If you find this project useful, please consider citing: 77 | 78 | @inproceedings{zhou2024AST, 79 | title={Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restoration}, 80 | author={Zhou, Shihao and Chen, Duosheng and Pan, Jinshan and Shi, Jinglei and Yang, Jufeng}, 81 | booktitle={CVPR}, 82 | year={2024} 83 | } 84 | 85 | ## Acknowledgement 86 | 87 | This code borrows heavily from [Uformer](https://github.com/ZhendongWang6/Uformer). 88 | -------------------------------------------------------------------------------- /dataset/dataset_dehaze_denseHaze.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from torch.utils.data import Dataset 4 | import torch 5 | from utils import is_png_file, load_img, Augment_RGB_torch 6 | import torch.nn.functional as F 7 | import random 8 | from PIL import Image 9 | import torchvision.transforms.functional as TF 10 | from natsort import natsorted 11 | from glob import glob 12 | augment = Augment_RGB_torch() 13 | transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')] 14 | 15 | 16 | def is_image_file(filename): 17 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 18 | 19 | ################################################################################################## 20 | class DataLoaderTrain(Dataset): 21 | def __init__(self, rgb_dir, img_options=None, target_transform=None): 22 | super(DataLoaderTrain, self).__init__() 23 | 24 | self.target_transform = target_transform 25 | 26 | gt_dir = 'gt' 27 | input_dir = 'input' 28 | 29 | clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir))) 30 | noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir))) 31 | 32 | self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_png_file(x)] 33 | self.noisy_filenames =
[os.path.join(rgb_dir, input_dir, x) for x in noisy_files if is_png_file(x)] 34 | 35 | self.img_options=img_options 36 | 37 | self.tar_size = len(self.clean_filenames) # get the size of target 38 | 39 | def __len__(self): 40 | return self.tar_size 41 | 42 | def __getitem__(self, index): 43 | tar_index = index % self.tar_size 44 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index]))) 45 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index]))) 46 | 47 | clean = clean.permute(2,0,1) 48 | noisy = noisy.permute(2,0,1) 49 | 50 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1] 51 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1] 52 | 53 | #Crop Input and Target 54 | ps = self.img_options['patch_size'] 55 | H = clean.shape[1] 56 | W = clean.shape[2] 57 | # r = np.random.randint(0, H - ps) if not H-ps else 0 58 | # c = np.random.randint(0, W - ps) if not H-ps else 0 59 | if H-ps==0: 60 | r=0 61 | c=0 62 | else: 63 | r = np.random.randint(0, H - ps) 64 | c = np.random.randint(0, W - ps) 65 | clean = clean[:, r:r + ps, c:c + ps] 66 | noisy = noisy[:, r:r + ps, c:c + ps] 67 | 68 | apply_trans = transforms_aug[random.getrandbits(3)] 69 | 70 | clean = getattr(augment, apply_trans)(clean) 71 | noisy = getattr(augment, apply_trans)(noisy) 72 | 73 | return clean, noisy, clean_filename, noisy_filename 74 | 75 | 76 | ################################################################################################## 77 | class DataLoaderVal(Dataset): 78 | def __init__(self, rgb_dir, img_options=None, target_transform=None): 79 | super(DataLoaderVal, self).__init__() 80 | 81 | self.target_transform = target_transform 82 | 83 | gt_dir = 'gt' 84 | input_dir = 'input' 85 | 86 | clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir))) 87 | noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir))) 88 | 89 | 90 | self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_png_file(x)] 91 | self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files if is_png_file(x)] 92 | 93 | self.img_options=img_options 94 | self.tar_size = len(self.clean_filenames) 95 | 96 | def __len__(self): 97 | return self.tar_size 98 | 99 | def __getitem__(self, index): 100 | tar_index = index % self.tar_size 101 | 102 | inp_path = self.noisy_filenames[tar_index] 103 | tar_path = self.clean_filenames[tar_index] 104 | 105 | inp_img = Image.open(inp_path) 106 | tar_img = Image.open(tar_path) 107 | 108 | # Validate on center crop 109 | if self.img_options: 110 | ps = self.img_options['patch_size'] 111 | # if ps is not None: 112 | inp_img = TF.center_crop(inp_img, (ps,ps)) 113 | tar_img = TF.center_crop(tar_img, (ps,ps)) 114 | 115 | noisy = TF.to_tensor(inp_img) 116 | clean = TF.to_tensor(tar_img) 117 | 118 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1] 119 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1] 120 | 121 | 122 | return clean, noisy, clean_filename, noisy_filename 123 | 124 | ################################################################################################## 125 | 126 | class DataLoaderTest(Dataset): 127 | def __init__(self, inp_dir, img_options): 128 | super(DataLoaderTest, self).__init__() 129 | 130 | inp_files = sorted(os.listdir(inp_dir)) 131 | self.inp_filenames = [os.path.join(inp_dir, x) for x in inp_files if is_image_file(x)] 132 | 133 | self.inp_size = len(self.inp_filenames) 134 | self.img_options = img_options 135 | 
136 | def __len__(self): 137 | return self.inp_size 138 | 139 | def __getitem__(self, index): 140 | 141 | path_inp = self.inp_filenames[index] 142 | filename = os.path.splitext(os.path.split(path_inp)[-1])[0] 143 | inp = Image.open(path_inp) 144 | 145 | inp = TF.to_tensor(inp) 146 | return inp, filename 147 | 148 | 149 | def get_training_data(rgb_dir, img_options): 150 | assert os.path.exists(rgb_dir) 151 | return DataLoaderTrain(rgb_dir, img_options, None) 152 | 153 | 154 | def get_validation_data(rgb_dir,img_options=None): 155 | assert os.path.exists(rgb_dir) 156 | return DataLoaderVal(rgb_dir, img_options, None) 157 | 158 | def get_test_data(rgb_dir, img_options=None): 159 | assert os.path.exists(rgb_dir) 160 | return DataLoaderTest(rgb_dir, img_options) -------------------------------------------------------------------------------- /dataset/dataset_derain.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from torch.utils.data import Dataset 4 | import torch 5 | from utils import is_png_file, load_img, Augment_RGB_torch 6 | import torch.nn.functional as F 7 | import random 8 | from PIL import Image 9 | import torchvision.transforms.functional as TF 10 | from natsort import natsorted 11 | from glob import glob 12 | augment = Augment_RGB_torch() 13 | transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')] 14 | 15 | 16 | def is_image_file(filename): 17 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 18 | 19 | ################################################################################################## 20 | class DataLoaderTrain(Dataset): 21 | def __init__(self, rgb_dir, img_options=None, target_transform=None): 22 | super(DataLoaderTrain, self).__init__() 23 | 24 | self.target_transform = target_transform 25 | 26 | name = 'real_world.txt' 27 | # name = 'real_test_1000.txt' 28 | self.dataset = os.path.join(rgb_dir, name) 29 | 30 | self.mat_files = open(self.dataset, 'r').readlines() 31 | 32 | self.clean_filenames = [os.path.join(rgb_dir, x.split(' ')[1][1:-1]) for x in self.mat_files ] 33 | self.noisy_filenames = [os.path.join(rgb_dir, x.split(' ')[0][1:]) for x in self.mat_files ] 34 | 35 | self.img_options=img_options 36 | 37 | self.tar_size = len(self.clean_filenames) # get the size of target 38 | 39 | def __len__(self): 40 | return self.tar_size 41 | 42 | def __getitem__(self, index): 43 | tar_index = index % self.tar_size 44 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index]))) 45 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index]))) 46 | 47 | clean = clean.permute(2,0,1) 48 | noisy = noisy.permute(2,0,1) 49 | 50 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1] 51 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1] 52 | 53 | #Crop Input and Target 54 | ps = self.img_options['patch_size'] 55 | H = clean.shape[1] 56 | W = clean.shape[2] 57 | if H-ps==0: 58 | r=0 59 | c=0 60 | else: 61 | r = np.random.randint(0, H - ps) 62 | c = np.random.randint(0, W - ps) 63 | clean = clean[:, r:r + ps, c:c + ps] 64 | noisy = noisy[:, r:r + ps, c:c + ps] 65 | 66 | apply_trans = transforms_aug[random.getrandbits(3)] 67 | 68 | clean = getattr(augment, apply_trans)(clean) 69 | noisy = getattr(augment, apply_trans)(noisy) 70 | 71 | return clean, noisy, clean_filename, noisy_filename 72 | 73 | 74 | 
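# Added note (an assumption about the SPA-Data list format, not part of the original
# source): each line of real_world.txt / real_test_1000.txt is expected to hold an
# input path and a ground-truth path separated by a single space, each with a leading
# character (e.g. '/') that the slicing strips so os.path.join(rgb_dir, ...) can
# re-root it under rgb_dir; [1:-1] on the second field additionally drops the
# trailing newline.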
################################################################################################## 75 | class DataLoaderVal(Dataset): 76 | def __init__(self, rgb_dir, target_transform=None): 77 | super(DataLoaderVal, self).__init__() 78 | 79 | self.target_transform = target_transform 80 | 81 | # name='demo.txt' 82 | name = 'real_test_1000.txt' 83 | # name = 'Real_Internet.txt' 84 | self.dataset = os.path.join(rgb_dir, name) 85 | self.mat_files = open(self.dataset, 'r').readlines() 86 | 87 | self.clean_filenames = [os.path.join(rgb_dir, x.split(' ')[1][1:-1]) for x in self.mat_files] 88 | self.noisy_filenames = [os.path.join(rgb_dir, x.split(' ')[0][1:]) for x in self.mat_files] 89 | 90 | 91 | self.tar_size = len(self.clean_filenames) 92 | 93 | def __len__(self): 94 | return self.tar_size 95 | 96 | def __getitem__(self, index): 97 | tar_index = index % self.tar_size 98 | 99 | 100 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index]))) 101 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index]))) 102 | 103 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1] 104 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1] 105 | 106 | clean = clean.permute(2,0,1) 107 | noisy = noisy.permute(2,0,1) 108 | 109 | return clean, noisy, clean_filename, noisy_filename 110 | 111 | ################################################################################################## 112 | 113 | class DataLoaderTest(Dataset): 114 | def __init__(self, inp_dir, img_options): 115 | super(DataLoaderTest, self).__init__() 116 | 117 | # name = 'Real_Internet.txt' 118 | name = 'real_test_1000.txt' 119 | self.dataset = os.path.join(inp_dir, name) 120 | self.mat_files = open(self.dataset, 'r').readlines() 121 | 122 | self.clean_filenames = [os.path.join(inp_dir, x.split(' ')[1][1:-1]) for x in self.mat_files] 123 | self.noisy_filenames = [os.path.join(inp_dir, x.split(' ')[0][1:]) for x in self.mat_files] 124 | 125 | 126 | inp_files = sorted(os.listdir(inp_dir)) 127 | self.inp_filenames = [os.path.join(inp_dir, x) for x in inp_files if is_image_file(x)] 128 | 129 | self.inp_size = len(self.clean_filenames) 130 | self.img_options = img_options 131 | 132 | def __len__(self): 133 | return self.inp_size 134 | 135 | def __getitem__(self, index): 136 | 137 | path_inp = self.inp_filenames[index] 138 | filename = os.path.splitext(os.path.split(path_inp)[-1])[0] 139 | inp = Image.open(path_inp) 140 | 141 | inp = TF.to_tensor(inp) 142 | return inp, filename 143 | 144 | 145 | def get_training_data(rgb_dir, img_options): 146 | assert os.path.exists(rgb_dir) 147 | return DataLoaderTrain(rgb_dir, img_options, None) 148 | 149 | 150 | def get_validation_data(rgb_dir): 151 | assert os.path.exists(rgb_dir) 152 | return DataLoaderVal(rgb_dir, None) 153 | 154 | def get_test_data(rgb_dir, img_options=None): 155 | assert os.path.exists(rgb_dir) 156 | return DataLoaderTest(rgb_dir, img_options) -------------------------------------------------------------------------------- /dataset/dataset_derain_drop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os 3 | from torch.utils.data import Dataset 4 | import torch 5 | from utils import is_png_file, load_img, Augment_RGB_torch 6 | import torch.nn.functional as F 7 | import random 8 | from PIL import Image 9 | import torchvision.transforms.functional as TF 10 | from natsort import natsorted 11 | from glob import glob 12 | augment = Augment_RGB_torch() 
13 | transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')] 14 | 15 | 16 | def is_image_file(filename): 17 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif']) 18 | 19 | ################################################################################################## 20 | class DataLoaderTrain(Dataset): 21 | def __init__(self, rgb_dir, img_options=None, target_transform=None): 22 | super(DataLoaderTrain, self).__init__() 23 | 24 | self.target_transform = target_transform 25 | gt_dir = 'gt' 26 | input_dir = 'data' 27 | clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir))) 28 | noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir))) 29 | 30 | self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_image_file(x)] 31 | self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files if is_image_file(x)] 32 | 33 | 34 | self.img_options=img_options 35 | 36 | self.tar_size = len(self.clean_filenames) # get the size of target 37 | 38 | def __len__(self): 39 | return self.tar_size 40 | 41 | def __getitem__(self, index): 42 | tar_index = index % self.tar_size 43 | 44 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index]))) 45 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index]))) 46 | 47 | clean = clean.permute(2,0,1) 48 | noisy = noisy.permute(2,0,1) 49 | 50 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1] 51 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1] 52 | 53 | #Crop Input and Target 54 | ps = self.img_options['patch_size'] 55 | H = clean.shape[1] 56 | W = clean.shape[2] 57 | if H-ps==0: 58 | r=0 59 | c=0 60 | else: 61 | r = np.random.randint(0, H - ps) 62 | c = np.random.randint(0, W - ps) 63 | clean = clean[:, r:r + ps, c:c + ps] 64 | noisy = noisy[:, r:r + ps, c:c + ps] 65 | 66 | apply_trans = transforms_aug[random.getrandbits(3)] 67 | 68 | clean = getattr(augment, apply_trans)(clean) 69 | noisy = getattr(augment, apply_trans)(noisy) 70 | 71 | return clean, noisy, clean_filename, noisy_filename 72 | 73 | 74 | ################################################################################################## 75 | class DataLoaderVal(Dataset): 76 | def __init__(self, rgb_dir, img_options=None, target_transform=None): 77 | super(DataLoaderVal, self).__init__() 78 | 79 | self.target_transform = target_transform 80 | gt_dir = 'gt' 81 | input_dir = 'data' 82 | 83 | clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir))) 84 | noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir))) 85 | 86 | 87 | self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_image_file(x)] 88 | self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files if is_image_file(x)] 89 | self.img_options=img_options 90 | 91 | self.tar_size = len(self.clean_filenames) 92 | 93 | def __len__(self): 94 | return self.tar_size 95 | 96 | def __getitem__(self, index): 97 | tar_index = index % self.tar_size 98 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index]))) 99 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index]))) 100 | 101 | clean = clean.permute(2,0,1) 102 | noisy = noisy.permute(2,0,1) 103 | 104 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1] 105 | noisy_filename = 
os.path.split(self.noisy_filenames[tar_index])[-1] 106 | 107 | #Crop Input and Target 108 | ps = self.img_options['patch_size'] 109 | H = clean.shape[1] 110 | W = clean.shape[2] 111 | if H-ps==0: 112 | r=0 113 | c=0 114 | else: 115 | r = (H - ps)//2 116 | c = (W - ps)//2 117 | clean = clean[:, r:r + ps, c:c + ps] 118 | noisy = noisy[:, r:r + ps, c:c + ps] 119 | 120 | return clean, noisy, clean_filename, noisy_filename 121 | 122 | ################################################################################################## 123 | 124 | class DataLoaderTest(Dataset): 125 | def __init__(self, rgb_dir): 126 | super(DataLoaderTest, self).__init__() 127 | 128 | 129 | gt_dir = 'gt' 130 | input_dir = 'data' 131 | 132 | clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir))) 133 | noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir))) 134 | 135 | 136 | self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_image_file(x)] 137 | self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files if is_image_file(x)] 138 | 139 | self.tar_size = len(self.clean_filenames) 140 | 141 | 142 | 143 | 144 | def __len__(self): 145 | return self.tar_size 146 | 147 | def __getitem__(self, index): 148 | 149 | tar_index = index % self.tar_size 150 | 151 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index]))) 152 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index]))) 153 | 154 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1] 155 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1] 156 | 157 | clean = clean.permute(2,0,1) 158 | noisy = noisy.permute(2,0,1) 159 | return clean, noisy, clean_filename, noisy_filename 160 | 161 | def get_training_data(rgb_dir, img_options): 162 | assert os.path.exists(rgb_dir) 163 | return DataLoaderTrain(rgb_dir, img_options, None) 164 | 165 | 166 | def get_validation_data(rgb_dir, img_options=None): 167 | assert os.path.exists(rgb_dir) 168 | return DataLoaderVal(rgb_dir, img_options, None) 169 | 170 | def get_test_data(rgb_dir): 171 | assert os.path.exists(rgb_dir) 172 | return DataLoaderTest(rgb_dir) -------------------------------------------------------------------------------- /evaluate_PSNR_SSIM.m: -------------------------------------------------------------------------------- 1 | clc;close all;clear all;addpath(genpath('./')); 2 | 3 | datasets = {'Rain100L'}; 4 | % datasets = {'Rain200L', 'Rain200H', 'SPA-Data'}; 5 | num_set = length(datasets); 6 | 7 | psnr_alldatasets = 0; 8 | ssim_alldatasets = 0; 9 | detail_res = zeros(1000,1); 10 | for idx_set = 1:num_set 11 | file_path = strcat('./4test/spad/AST_B/'); 12 | gt_path = strcat('./4test/real_test_1000/gt/'); 13 | path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))]; 14 | gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))]; 15 | img_num = length(path_list); 16 | 17 | total_psnr = 0; 18 | total_ssim = 0; 19 | if img_num > 0 20 | for j = 1:img_num 21 | image_name = path_list(j).name; 22 | gt_name = gt_list(j).name; 23 | input = imread(strcat(file_path,image_name)); 24 | gt = imread(strcat(gt_path, gt_name)); 25 | ssim_val = compute_ssim(input, gt); 26 | psnr_val = compute_psnr(input, gt); 27 | total_ssim = total_ssim + ssim_val; 28 | total_psnr = total_psnr + psnr_val; 29 | detail_res(j,1)=psnr_val; 30 | end 31 | end 32 | qm_psnr = total_psnr / img_num; 33 | qm_ssim = total_ssim / img_num; 34 | 35 | fprintf('For %s dataset PSNR: %f 
SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim); 36 | 37 | psnr_alldatasets = psnr_alldatasets + qm_psnr; 38 | ssim_alldatasets = ssim_alldatasets + qm_ssim; 39 | 40 | end 41 | 42 | fprintf('For all datasets PSNR: %f SSIM: %f\n', psnr_alldatasets/num_set, ssim_alldatasets/num_set); 43 | 44 | function ssim_mean=compute_ssim(img1,img2) 45 | if size(img1, 3) == 3 46 | img1 = rgb2ycbcr(img1); 47 | img1 = img1(:, :, 1); 48 | end 49 | 50 | if size(img2, 3) == 3 51 | img2 = rgb2ycbcr(img2); 52 | img2 = img2(:, :, 1); 53 | end 54 | ssim_mean = SSIM_index(img1, img2); 55 | end 56 | 57 | function psnr=compute_psnr(img1,img2) 58 | if size(img1, 3) == 3 59 | img1 = rgb2ycbcr(img1); 60 | img1 = img1(:, :, 1); 61 | end 62 | 63 | if size(img2, 3) == 3 64 | img2 = rgb2ycbcr(img2); 65 | img2 = img2(:, :, 1); 66 | end 67 | 68 | imdff = double(img1) - double(img2); 69 | imdff = imdff(:); 70 | rmse = sqrt(mean(imdff.^2)); 71 | psnr = 20*log10(255/rmse); 72 | 73 | end 74 | 75 | function [mssim, ssim_map] = SSIM_index(img1, img2, K, window, L) 76 | 77 | %======================================================================== 78 | %SSIM Index, Version 1.0 79 | %Copyright(c) 2003 Zhou Wang 80 | %All Rights Reserved. 81 | % 82 | %The author is with Howard Hughes Medical Institute, and Laboratory 83 | %for Computational Vision at Center for Neural Science and Courant 84 | %Institute of Mathematical Sciences, New York University. 85 | % 86 | %---------------------------------------------------------------------- 87 | %Permission to use, copy, or modify this software and its documentation 88 | %for educational and research purposes only and without fee is hereby 89 | %granted, provided that this copyright notice and the original authors' 90 | %names appear on all copies and supporting documentation. This program 91 | %shall not be used, rewritten, or adapted as the basis of a commercial 92 | %software or hardware product without first obtaining permission of the 93 | %authors. The authors make no representations about the suitability of 94 | %this software for any purpose. It is provided "as is" without express 95 | %or implied warranty. 96 | %---------------------------------------------------------------------- 97 | % 98 | %This is an implementation of the algorithm for calculating the 99 | %Structural SIMilarity (SSIM) index between two images. Please refer 100 | %to the following paper: 101 | % 102 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image 103 | %quality assessment: From error measurement to structural similarity" 104 | %IEEE Transactions on Image Processing, vol. 13, no. 1, Jan. 2004. 105 | % 106 | %Kindly report any suggestions or corrections to zhouwang@ieee.org 107 | % 108 | %---------------------------------------------------------------------- 109 | % 110 | %Input : (1) img1: the first image being compared 111 | % (2) img2: the second image being compared 112 | % (3) K: constants in the SSIM index formula (see the above 113 | % reference). default value: K = [0.01 0.03] 114 | % (4) window: local window for statistics (see the above 115 | % reference). default window is Gaussian given by 116 | % window = fspecial('gaussian', 11, 1.5); 117 | % (5) L: dynamic range of the images. default: L = 255 118 | % 119 | %Output: (1) mssim: the mean SSIM index value between 2 images. 120 | % If one of the images being compared is regarded as 121 | % perfect quality, then mssim can be considered as the 122 | % quality measure of the other image. 123 | % If img1 = img2, then mssim = 1.
124 | % (2) ssim_map: the SSIM index map of the test image. The map 125 | % has a smaller size than the input images. The actual size: 126 | % size(img1) - size(window) + 1. 127 | % 128 | %Default Usage: 129 | % Given 2 test images img1 and img2, whose dynamic range is 0-255 130 | % 131 | % [mssim ssim_map] = SSIM_index(img1, img2); 132 | % 133 | %Advanced Usage: 134 | % User defined parameters. For example 135 | % 136 | % K = [0.05 0.05]; 137 | % window = ones(8); 138 | % L = 100; 139 | % [mssim ssim_map] = SSIM_index(img1, img2, K, window, L); 140 | % 141 | %See the results: 142 | % 143 | % mssim %Gives the mssim value 144 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map 145 | % 146 | %======================================================================== 147 | 148 | 149 | if (nargin < 2 || nargin > 5) 150 | mssim = -Inf; 151 | ssim_map = -Inf; 152 | return; 153 | end 154 | 155 | if (size(img1) ~= size(img2)) 156 | mssim = -Inf; 157 | ssim_map = -Inf; 158 | return; 159 | end 160 | 161 | [M N] = size(img1); 162 | 163 | if (nargin == 2) 164 | if ((M < 11) || (N < 11)) 165 | mssim = -Inf; 166 | ssim_map = -Inf; 167 | return 168 | end 169 | window = fspecial('gaussian', 11, 1.5); % 170 | K(1) = 0.01; % default settings 171 | K(2) = 0.03; % 172 | L = 255; % 173 | end 174 | 175 | if (nargin == 3) 176 | if ((M < 11) || (N < 11)) 177 | mssim = -Inf; 178 | ssim_map = -Inf; 179 | return 180 | end 181 | window = fspecial('gaussian', 11, 1.5); 182 | L = 255; 183 | if (length(K) == 2) 184 | if (K(1) < 0 || K(2) < 0) 185 | mssim = -Inf; 186 | ssim_map = -Inf; 187 | return; 188 | end 189 | else 190 | mssim = -Inf; 191 | ssim_map = -Inf; 192 | return; 193 | end 194 | end 195 | 196 | if (nargin == 4) 197 | [H W] = size(window); 198 | if ((H*W) < 4 || (H > M) || (W > N)) 199 | mssim = -Inf; 200 | ssim_map = -Inf; 201 | return 202 | end 203 | L = 255; 204 | if (length(K) == 2) 205 | if (K(1) < 0 || K(2) < 0) 206 | mssim = -Inf; 207 | ssim_map = -Inf; 208 | return; 209 | end 210 | else 211 | mssim = -Inf; 212 | ssim_map = -Inf; 213 | return; 214 | end 215 | end 216 | 217 | if (nargin == 5) 218 | [H W] = size(window); 219 | if ((H*W) < 4 || (H > M) || (W > N)) 220 | mssim = -Inf; 221 | ssim_map = -Inf; 222 | return 223 | end 224 | if (length(K) == 2) 225 | if (K(1) < 0 || K(2) < 0) 226 | mssim = -Inf; 227 | ssim_map = -Inf; 228 | return; 229 | end 230 | else 231 | mssim = -Inf; 232 | ssim_map = -Inf; 233 | return; 234 | end 235 | end 236 | 237 | C1 = (K(1)*L)^2; 238 | C2 = (K(2)*L)^2; 239 | window = window/sum(sum(window)); 240 | img1 = double(img1); 241 | img2 = double(img2); 242 | 243 | mu1 = filter2(window, img1, 'valid'); 244 | mu2 = filter2(window, img2, 'valid'); 245 | mu1_sq = mu1.*mu1; 246 | mu2_sq = mu2.*mu2; 247 | mu1_mu2 = mu1.*mu2; 248 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq; 249 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq; 250 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2; 251 | 252 | if (C1 > 0 & C2 > 0) 253 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2)); 254 | else 255 | numerator1 = 2*mu1_mu2 + C1; 256 | numerator2 = 2*sigma12 + C2; 257 | denominator1 = mu1_sq + mu2_sq + C1; 258 | denominator2 = sigma1_sq + sigma2_sq + C2; 259 | ssim_map = ones(size(mu1)); 260 | index = (denominator1.*denominator2 > 0); 261 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index)); 262 |
index = (denominator1 ~= 0) & (denominator2 == 0); 263 | ssim_map(index) = numerator1(index)./denominator1(index); 264 | end 265 | 266 | mssim = mean2(ssim_map); 267 | 268 | end 269 | 270 | -------------------------------------------------------------------------------- /losses.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | 7 | 8 | def tv_loss(x, beta = 0.5, reg_coeff = 5): 9 | '''Calculates TV loss for an image `x`. 10 | 11 | Args: 12 | x: image, a torch.Tensor 13 | beta: See https://arxiv.org/abs/1412.0035 (fig. 2) to see effect of `beta` 14 | ''' 15 | dh = torch.pow(x[:,:,:,1:] - x[:,:,:,:-1], 2) 16 | dw = torch.pow(x[:,:,1:,:] - x[:,:,:-1,:], 2) 17 | a,b,c,d=x.shape 18 | return reg_coeff*(torch.sum(torch.pow(dh[:, :, :-1] + dw[:, :, :, :-1], beta))/(a*b*c*d)) 19 | 20 | class TVLoss(nn.Module): 21 | def __init__(self, tv_loss_weight=1): 22 | super(TVLoss, self).__init__() 23 | self.tv_loss_weight = tv_loss_weight 24 | 25 | def forward(self, x): 26 | batch_size = x.size()[0] 27 | h_x = x.size()[2] 28 | w_x = x.size()[3] 29 | count_h = self.tensor_size(x[:, :, 1:, :]) 30 | count_w = self.tensor_size(x[:, :, :, 1:]) 31 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum() 32 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum() 33 | return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size 34 | 35 | @staticmethod 36 | def tensor_size(t): 37 | return t.size()[1] * t.size()[2] * t.size()[3] 38 | 39 | 40 | 41 | class CharbonnierLoss(nn.Module): 42 | """Charbonnier Loss (L1)""" 43 | 44 | def __init__(self, eps=1e-3): 45 | super(CharbonnierLoss, self).__init__() 46 | self.eps = eps 47 | 48 | def forward(self, x, y): 49 | diff = x - y 50 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps)) 51 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps))) 52 | return loss 53 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.utils.checkpoint as checkpoint 4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_ 5 | import torch.nn.functional as F 6 | from einops import rearrange, repeat 7 | from einops.layers.torch import Rearrange 8 | import math 9 | import numpy as np 10 | import time 11 | from torch import einsum 12 | 13 | 14 | def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1): 15 | return nn.Conv2d( 16 | in_channels, out_channels, kernel_size, 17 | padding=(kernel_size//2), bias=bias, stride = stride) 18 | 19 | 20 | 21 | ######################################### 22 | class ConvBlock(nn.Module): 23 | def __init__(self, in_channel, out_channel, strides=1): 24 | super(ConvBlock, self).__init__() 25 | self.strides = strides 26 | self.in_channel=in_channel 27 | self.out_channel=out_channel 28 | self.block = nn.Sequential( 29 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1), 30 | nn.LeakyReLU(inplace=True), 31 | nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=strides, padding=1), 32 | nn.LeakyReLU(inplace=True), 33 | ) 34 | self.conv11 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, padding=0) 35 | 36 | def forward(self, x): 37 | out1 = self.block(x) 38 | out2 = self.conv11(x) 39 | out = out1 + out2 40 | return out 41 |
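# Added note (observation, not part of the original file): ConvBlock above is a
# residual unit -- two 3x3 conv + LeakyReLU layers in `self.block`, with a 1x1
# projection `self.conv11` on the skip path so in_channel may differ from
# out_channel, e.g.
#   ConvBlock(3, 32)(torch.randn(1, 3, 64, 64)).shape == (1, 32, 64, 64)
# Since the stacked convs apply `strides` twice while the skip conv applies it once,
# the residual add is only shape-consistent for strides=1 (the default).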
42 | class LinearProjection(nn.Module): 43 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., bias=True): 44 | super().__init__() 45 | inner_dim = dim_head * heads 46 | self.heads = heads 47 | self.to_q = nn.Linear(dim, inner_dim, bias = bias) 48 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias) 49 | self.dim = dim 50 | self.inner_dim = inner_dim 51 | 52 | def forward(self, x, attn_kv=None): 53 | B_, N, C = x.shape 54 | if attn_kv is not None: 55 | attn_kv = attn_kv.unsqueeze(0).repeat(B_,1,1) 56 | else: 57 | attn_kv = x 58 | N_kv = attn_kv.size(1) 59 | q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 60 | kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4) 61 | q = q[0] 62 | k, v = kv[0], kv[1] 63 | return q,k,v 64 | 65 | 66 | ######################################### 67 | ########### window-based self-attention ############# 68 | class WindowAttention(nn.Module): 69 | def __init__(self, dim, win_size,num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 70 | 71 | super().__init__() 72 | self.dim = dim 73 | self.win_size = win_size # Wh, Ww 74 | self.num_heads = num_heads 75 | head_dim = dim // num_heads 76 | self.scale = qk_scale or head_dim ** -0.5 77 | 78 | # define a parameter table of relative position bias 79 | self.relative_position_bias_table = nn.Parameter( 80 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 81 | 82 | # get pair-wise relative position index for each token inside the window 83 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1] 84 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1] 85 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 86 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 87 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 88 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 89 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0 90 | relative_coords[:, :, 1] += self.win_size[1] - 1 91 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1 92 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 93 | self.register_buffer("relative_position_index", relative_position_index) 94 | trunc_normal_(self.relative_position_bias_table, std=.02) 95 | 96 | if token_projection =='linear': 97 | self.qkv = LinearProjection(dim,num_heads,dim//num_heads,bias=qkv_bias) 98 | else: 99 | raise Exception("Projection error!") 100 | 101 | self.token_projection = token_projection 102 | self.attn_drop = nn.Dropout(attn_drop) 103 | self.proj = nn.Linear(dim, dim) 104 | self.proj_drop = nn.Dropout(proj_drop) 105 | 106 | self.softmax = nn.Softmax(dim=-1) 107 | 108 | def forward(self, x, attn_kv=None, mask=None): 109 | B_, N, C = x.shape 110 | q, k, v = self.qkv(x,attn_kv) 111 | q = q * self.scale 112 | attn = (q @ k.transpose(-2, -1)) 113 | 114 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 115 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH 116 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 117 | ratio = attn.size(-1)//relative_position_bias.size(-1) 118 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d = ratio) 119 | 120 | attn = attn + 
relative_position_bias.unsqueeze(0) 121 | 122 | if mask is not None: 123 | nW = mask.shape[0] 124 | mask = repeat(mask, 'nW m n -> nW m (n d)',d = ratio) 125 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N*ratio) + mask.unsqueeze(1).unsqueeze(0) 126 | attn = attn.view(-1, self.num_heads, N, N*ratio) 127 | attn = self.softmax(attn) 128 | else: 129 | attn = self.softmax(attn) 130 | 131 | attn = self.attn_drop(attn) 132 | 133 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 134 | x = self.proj(x) 135 | x = self.proj_drop(x) 136 | return x 137 | 138 | def extra_repr(self) -> str: 139 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}' 140 | 141 | ########### window-based self-attention ############# 142 | class WindowAttention_sparse(nn.Module): 143 | def __init__(self, dim, win_size,num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 144 | 145 | super().__init__() 146 | self.dim = dim 147 | self.win_size = win_size # Wh, Ww 148 | self.num_heads = num_heads 149 | head_dim = dim // num_heads 150 | self.scale = qk_scale or head_dim ** -0.5 151 | 152 | # define a parameter table of relative position bias 153 | self.relative_position_bias_table = nn.Parameter( 154 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH 155 | 156 | # get pair-wise relative position index for each token inside the window 157 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1] 158 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1] 159 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww 160 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww 161 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww 162 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2 163 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0 164 | relative_coords[:, :, 1] += self.win_size[1] - 1 165 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1 166 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww 167 | self.register_buffer("relative_position_index", relative_position_index) 168 | trunc_normal_(self.relative_position_bias_table, std=.02) 169 | 170 | if token_projection =='linear': 171 | self.qkv = LinearProjection(dim,num_heads,dim//num_heads,bias=qkv_bias) 172 | else: 173 | raise Exception("Projection error!") 174 | 175 | self.token_projection = token_projection 176 | self.attn_drop = nn.Dropout(attn_drop) 177 | self.proj = nn.Linear(dim, dim) 178 | self.proj_drop = nn.Dropout(proj_drop) 179 | 180 | self.softmax = nn.Softmax(dim=-1) 181 | self.relu = nn.ReLU() 182 | self.w = nn.Parameter(torch.ones(2)) 183 | 184 | def forward(self, x, attn_kv=None, mask=None): 185 | B_, N, C = x.shape 186 | q, k, v = self.qkv(x,attn_kv) 187 | q = q * self.scale 188 | attn = (q @ k.transpose(-2, -1)) 189 | 190 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view( 191 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH 192 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww 193 | ratio = attn.size(-1)//relative_position_bias.size(-1) 194 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d = ratio) 195 | 196 | attn = attn + relative_position_bias.unsqueeze(0) 197 | 198 | if mask is not None: 199 | nW = mask.shape[0] 200 | 
mask = repeat(mask, 'nW m n -> nW m (n d)',d = ratio) 201 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N*ratio) + mask.unsqueeze(1).unsqueeze(0) 202 | attn = attn.view(-1, self.num_heads, N, N*ratio) 203 | attn0 = self.softmax(attn) 204 | attn1 = self.relu(attn)**2#b,h,w,c 205 | else: 206 | attn0 = self.softmax(attn) 207 | attn1 = self.relu(attn)**2 208 | w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w)) 209 | w2 = torch.exp(self.w[1]) / torch.sum(torch.exp(self.w)) 210 | attn = attn0*w1+attn1*w2 211 | attn = self.attn_drop(attn) 212 | 213 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 214 | x = self.proj(x) 215 | x = self.proj_drop(x) 216 | return x 217 | 218 | def extra_repr(self) -> str: 219 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}' 220 | 221 | ########### self-attention ############# 222 | class Attention(nn.Module): 223 | def __init__(self, dim,num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.): 224 | 225 | super().__init__() 226 | self.dim = dim 227 | self.num_heads = num_heads 228 | head_dim = dim // num_heads 229 | self.scale = qk_scale or head_dim ** -0.5 230 | 231 | self.qkv = LinearProjection(dim,num_heads,dim//num_heads,bias=qkv_bias) 232 | 233 | self.token_projection = token_projection 234 | self.attn_drop = nn.Dropout(attn_drop) 235 | self.proj = nn.Linear(dim, dim) 236 | self.proj_drop = nn.Dropout(proj_drop) 237 | 238 | self.softmax = nn.Softmax(dim=-1) 239 | 240 | def forward(self, x, attn_kv=None, mask=None): 241 | B_, N, C = x.shape 242 | q, k, v = self.qkv(x,attn_kv) 243 | q = q * self.scale 244 | attn = (q @ k.transpose(-2, -1)) 245 | if mask is not None: 246 | nW = mask.shape[0] 247 | # mask = repeat(mask, 'nW m n -> nW m (n d)',d = ratio) 248 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0) 249 | attn = attn.view(-1, self.num_heads, N, N) 250 | attn = self.softmax(attn) 251 | else: 252 | attn = self.softmax(attn) 253 | 254 | attn = self.attn_drop(attn) 255 | 256 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C) 257 | x = self.proj(x) 258 | x = self.proj_drop(x) 259 | return x 260 | 261 | def extra_repr(self) -> str: 262 | return f'dim={self.dim}, num_heads={self.num_heads}' 263 | 264 | 265 | 266 | ######################################### 267 | ########### feed-forward network ############# 268 | class Mlp(nn.Module): 269 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.): 270 | super().__init__() 271 | out_features = out_features or in_features 272 | hidden_features = hidden_features or in_features 273 | self.fc1 = nn.Linear(in_features, hidden_features) 274 | self.act = act_layer() 275 | self.fc2 = nn.Linear(hidden_features, out_features) 276 | self.drop = nn.Dropout(drop) 277 | self.in_features = in_features 278 | self.hidden_features = hidden_features 279 | self.out_features = out_features 280 | 281 | def forward(self, x): 282 | x = self.fc1(x) 283 | x = self.act(x) 284 | x = self.drop(x) 285 | x = self.fc2(x) 286 | x = self.drop(x) 287 | return x 288 | 289 | 290 | class LeFF(nn.Module): 291 | def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU,drop = 0., use_eca=False): 292 | super().__init__() 293 | self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim), 294 | act_layer()) 295 | self.dwconv = nn.Sequential(nn.Conv2d(hidden_dim,hidden_dim,groups=hidden_dim,kernel_size=3,stride=1,padding=1), 296 | act_layer()) 297 | self.linear2 = 
nn.Sequential(nn.Linear(hidden_dim, dim))
298 | self.dim = dim
299 | self.hidden_dim = hidden_dim
300 | self.eca = nn.Identity()
301 | 
302 | def forward(self, x):
303 | # bs x hw x c
304 | bs, hw, c = x.size()
305 | hh = int(math.sqrt(hw))
306 | 
307 | x = self.linear1(x)
308 | 
309 | # spatial restore
310 | x = rearrange(x, ' b (h w) (c) -> b c h w ', h = hh, w = hh)
311 | # bs,hidden_dim,32x32
312 | 
313 | x = self.dwconv(x)
314 | 
315 | # flatten
316 | x = rearrange(x, ' b c h w -> b (h w) c', h = hh, w = hh)
317 | 
318 | x = self.linear2(x)
319 | x = self.eca(x)
320 | 
321 | return x
322 | 
323 | 
324 | class FRFN(nn.Module):
325 | def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU,drop = 0., use_eca=False):
326 | super().__init__()
327 | self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim*2),
328 | act_layer())
329 | self.dwconv = nn.Sequential(nn.Conv2d(hidden_dim,hidden_dim,groups=hidden_dim,kernel_size=3,stride=1,padding=1),
330 | act_layer())
331 | self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
332 | self.dim = dim
333 | self.hidden_dim = hidden_dim
334 | 
335 | self.dim_conv = self.dim // 4
336 | self.dim_untouched = self.dim - self.dim_conv
337 | self.partial_conv3 = nn.Conv2d(self.dim_conv, self.dim_conv, 3, 1, 1, bias=False)
338 | 
339 | def forward(self, x):
340 | # bs x hw x c
341 | bs, hw, c = x.size()
342 | hh = int(math.sqrt(hw))
343 | 
344 | 
345 | # spatial restore
346 | x = rearrange(x, ' b (h w) (c) -> b c h w ', h = hh, w = hh)
347 | 
348 | x1, x2 = torch.split(x, [self.dim_conv,self.dim_untouched], dim=1)
349 | x1 = self.partial_conv3(x1)
350 | x = torch.cat((x1, x2), 1)
351 | 
352 | # flatten
353 | x = rearrange(x, ' b c h w -> b (h w) c', h = hh, w = hh)
354 | 
355 | x = self.linear1(x)
356 | # gate mechanism
357 | x_1,x_2 = x.chunk(2,dim=-1)
358 | 
359 | x_1 = rearrange(x_1, ' b (h w) (c) -> b c h w ', h = hh, w = hh)
360 | x_1 = self.dwconv(x_1)
361 | x_1 = rearrange(x_1, ' b c h w -> b (h w) c', h = hh, w = hh)
362 | x = x_1 * x_2
363 | 
364 | x = self.linear2(x)
365 | # x = self.eca(x)
366 | 
367 | return x
368 | 
369 | 
370 | 
371 | #########################################
372 | ########### window operation#############
373 | def window_partition(x, win_size, dilation_rate=1):
374 | B, H, W, C = x.shape
375 | if dilation_rate !=1:
376 | x = x.permute(0,3,1,2) # B, C, H, W
377 | assert type(dilation_rate) is int, 'dilation_rate should be an int'
378 | x = F.unfold(x, kernel_size=win_size,dilation=dilation_rate,padding=4*(dilation_rate-1),stride=win_size) # B, C*Wh*Ww, H/Wh*W/Ww
379 | windows = x.permute(0,2,1).contiguous().view(-1, C, win_size, win_size) # B' ,C ,Wh ,Ww
380 | windows = windows.permute(0,2,3,1).contiguous() # B' ,Wh ,Ww ,C
381 | else:
382 | x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
383 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) # B' ,Wh ,Ww ,C
384 | return windows
385 | 
386 | def window_reverse(windows, win_size, H, W, dilation_rate=1):
387 | # B' ,Wh ,Ww ,C
388 | B = int(windows.shape[0] / (H * W / win_size / win_size))
389 | x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
390 | if dilation_rate !=1:
391 | x = x.permute(0,5,3,4,1,2).contiguous().view(B, -1, (H // win_size) * (W // win_size)) # B, C*Wh*Ww, H/Wh*W/Ww (bug fix: permute/flatten the 6-D x, not the 4-D windows, into the layout F.fold expects)
392 | x = F.fold(x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4*(dilation_rate-1),stride=win_size)
393 | else:
394 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
395 | return x
396 | 
397 | #########################################
398 | 
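# (editor's sketch, not part of the original file) window_partition and
# window_reverse are exact inverses for dilation_rate=1 whenever H and W are
# multiples of win_size, which is what TransformerBlock below relies on:
#
#     x = torch.randn(2, 64, 64, 32)                # B, H, W, C
#     wins = window_partition(x, 8)                 # (2*8*8, 8, 8, 32)
#     assert torch.equal(window_reverse(wins, 8, 64, 64), x)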
399 | # Downsample Block
400 | class Downsample(nn.Module):
401 | def __init__(self, in_channel, out_channel):
402 | super(Downsample, self).__init__()
403 | self.conv = nn.Sequential(
404 | nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1),
405 | )
406 | self.in_channel = in_channel
407 | self.out_channel = out_channel
408 | 
409 | def forward(self, x):
410 | B, L, C = x.shape
411 | 
412 | H = int(math.sqrt(L))
413 | W = int(math.sqrt(L))
414 | x = x.transpose(1, 2).contiguous().view(B, C, H, W)
415 | out = self.conv(x).flatten(2).transpose(1,2).contiguous() # B H*W C
416 | return out
417 | 
418 | # Upsample Block
419 | class Upsample(nn.Module):
420 | def __init__(self, in_channel, out_channel):
421 | super(Upsample, self).__init__()
422 | self.deconv = nn.Sequential(
423 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2),
424 | )
425 | self.in_channel = in_channel
426 | self.out_channel = out_channel
427 | 
428 | def forward(self, x):
429 | B, L, C = x.shape
430 | H = int(math.sqrt(L))
431 | W = int(math.sqrt(L))
432 | x = x.transpose(1, 2).contiguous().view(B, C, H, W)
433 | out = self.deconv(x).flatten(2).transpose(1,2).contiguous() # B H*W C
434 | return out
435 | 
436 | 
437 | # Input Projection
438 | class InputProj(nn.Module):
439 | def __init__(self, in_channel=3, out_channel=64, kernel_size=3, stride=1, norm_layer=None,act_layer=nn.LeakyReLU):
440 | super().__init__()
441 | self.proj = nn.Sequential(
442 | nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=kernel_size//2),
443 | act_layer(inplace=True)
444 | )
445 | if norm_layer is not None:
446 | self.norm = norm_layer(out_channel)
447 | else:
448 | self.norm = None
449 | self.in_channel = in_channel
450 | self.out_channel = out_channel
451 | 
452 | def forward(self, x):
453 | B, C, H, W = x.shape
454 | x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C
455 | if self.norm is not None:
456 | x = self.norm(x)
457 | return x
458 | 
459 | # Output Projection
460 | class OutputProj(nn.Module):
461 | def __init__(self, in_channel=64, out_channel=3, kernel_size=3, stride=1, norm_layer=None,act_layer=None):
462 | super().__init__()
463 | self.proj = nn.Sequential(
464 | nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=kernel_size//2),
465 | )
466 | if act_layer is not None:
467 | self.proj.add_module('act', act_layer(inplace=True)) # add_module needs a name; passing only the module raises a TypeError
468 | if norm_layer is not None:
469 | self.norm = norm_layer(out_channel)
470 | else:
471 | self.norm = None
472 | self.in_channel = in_channel
473 | self.out_channel = out_channel
474 | 
475 | def forward(self, x):
476 | B, L, C = x.shape
477 | H = int(math.sqrt(L))
478 | W = int(math.sqrt(L))
479 | x = x.transpose(1, 2).view(B, C, H, W)
480 | x = self.proj(x)
481 | if self.norm is not None:
482 | x = self.norm(x)
483 | return x
484 | 
485 | #########################################
486 | ###########Transformer Block#############
487 | class TransformerBlock(nn.Module):
488 | def __init__(self, dim, input_resolution, num_heads, win_size=8, shift_size=0,
489 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
490 | act_layer=nn.GELU, norm_layer=nn.LayerNorm,token_projection='linear',token_mlp='leff',att=True,sparseAtt=False):
491 | super().__init__()
492 | 
493 | self.att = att
494 | self.sparseAtt = sparseAtt
495 | 
496 | self.dim = dim
497 | self.input_resolution = input_resolution
498 | self.num_heads = num_heads
499 | self.win_size = win_size
500 | self.shift_size = 
shift_size
501 | self.mlp_ratio = mlp_ratio
502 | self.token_mlp = token_mlp
503 | if min(self.input_resolution) <= self.win_size: # feature map no larger than the window: attend globally, no shift
504 | self.shift_size = 0
505 | self.win_size = min(self.input_resolution)
506 | assert 0 <= self.shift_size < self.win_size, "shift_size must be in [0, win_size)"
507 | 
508 | if self.att:
509 | self.norm1 = norm_layer(dim)
510 | if self.sparseAtt:
511 | self.attn = WindowAttention_sparse(
512 | dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
513 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
514 | token_projection=token_projection)
515 | else:
516 | self.attn = WindowAttention(
517 | dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
518 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
519 | token_projection=token_projection)
520 | 
521 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
522 | self.norm2 = norm_layer(dim)
523 | mlp_hidden_dim = int(dim * mlp_ratio)
524 | if token_mlp in ['ffn','mlp']:
525 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,act_layer=act_layer, drop=drop)
526 | elif token_mlp=='leff':
527 | self.mlp = LeFF(dim,mlp_hidden_dim,act_layer=act_layer, drop=drop)
528 | elif token_mlp=='frfn':
529 | self.mlp = FRFN(dim,mlp_hidden_dim,act_layer=act_layer, drop=drop)
530 | else:
531 | raise Exception("FFN error!")
532 | 
533 | 
534 | def with_pos_embed(self, tensor, pos):
535 | return tensor if pos is None else tensor + pos
536 | 
537 | def extra_repr(self) -> str:
538 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
539 | f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
540 | 
541 | def forward(self, x, mask=None):
542 | B, L, C = x.shape
543 | H = int(math.sqrt(L))
544 | W = int(math.sqrt(L))
545 | 
546 | ## input mask
547 | if mask is not None:
548 | input_mask = F.interpolate(mask, size=(H,W)).permute(0,2,3,1)
549 | input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1
550 | attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
551 | attn_mask = attn_mask.unsqueeze(2)*attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size
552 | attn_mask = attn_mask.masked_fill(attn_mask!=0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
553 | else:
554 | attn_mask = None
555 | 
556 | ## shift mask
557 | if self.shift_size > 0:
558 | # calculate attention mask for SW-MSA
559 | shift_mask = torch.zeros((1, H, W, 1)).type_as(x)
560 | h_slices = (slice(0, -self.win_size),
561 | slice(-self.win_size, -self.shift_size),
562 | slice(-self.shift_size, None))
563 | w_slices = (slice(0, -self.win_size),
564 | slice(-self.win_size, -self.shift_size),
565 | slice(-self.shift_size, None))
566 | cnt = 0
567 | for h in h_slices:
568 | for w in w_slices:
569 | shift_mask[:, h, w, :] = cnt
570 | cnt += 1
571 | shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1
572 | shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
573 | shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(2) # nW, win_size*win_size, win_size*win_size
574 | shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill(shift_attn_mask == 0, float(0.0))
575 | attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask
576 | 
577 | 
578 | 
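# (editor's annotation) at this point `attn_mask`, when present, is an
# additive logit bias per window: 0 keeps a query-key pair, while -100
# acts as -inf and removes the pair after the softmax. The valid-pixel
# mask and the shifted-window mask contribute the same way, so they are
# simply summed.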
shortcut = x 579 | 580 | if self.att: 581 | x = self.norm1(x) 582 | x = x.view(B, H, W, C) 583 | 584 | # cyclic shift 585 | if self.shift_size > 0: 586 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)) 587 | else: 588 | shifted_x = x 589 | 590 | # partition windows 591 | x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C N*C->C 592 | x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C 593 | 594 | 595 | # W-MSA/SW-MSA 596 | attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, win_size*win_size, C 597 | 598 | # merge windows 599 | attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C) 600 | shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C 601 | 602 | # reverse cyclic shift 603 | if self.shift_size > 0: 604 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)) 605 | else: 606 | x = shifted_x 607 | x = x.view(B, H * W, C) 608 | x = shortcut + self.drop_path(x) 609 | 610 | # FFN 611 | x = x + self.drop_path(self.mlp(self.norm2(x))) 612 | del attn_mask 613 | return x 614 | 615 | 616 | ######################################### 617 | ########### Basic layer of AST ################ 618 | class BasicASTLayer(nn.Module): 619 | def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size, 620 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., 621 | drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False, 622 | token_projection='linear',token_mlp='ffn', shift_flag=True,att=False,sparseAtt=False): 623 | 624 | super().__init__() 625 | self.att = att 626 | self.sparseAtt = sparseAtt 627 | self.dim = dim 628 | self.input_resolution = input_resolution 629 | self.depth = depth 630 | self.use_checkpoint = use_checkpoint 631 | # build blocks 632 | if shift_flag: 633 | self.blocks = nn.ModuleList([ 634 | TransformerBlock(dim=dim, input_resolution=input_resolution, 635 | num_heads=num_heads, win_size=win_size, 636 | shift_size=0 if (i % 2 == 0) else win_size // 2, 637 | mlp_ratio=mlp_ratio, 638 | qkv_bias=qkv_bias, qk_scale=qk_scale, 639 | drop=drop, attn_drop=attn_drop, 640 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 641 | norm_layer=norm_layer,token_projection=token_projection,token_mlp=token_mlp,att=self.att,sparseAtt=self.sparseAtt) 642 | for i in range(depth)]) 643 | else: 644 | self.blocks = nn.ModuleList([ 645 | TransformerBlock(dim=dim, input_resolution=input_resolution, 646 | num_heads=num_heads, win_size=win_size, 647 | shift_size=0, 648 | mlp_ratio=mlp_ratio, 649 | qkv_bias=qkv_bias, qk_scale=qk_scale, 650 | drop=drop, attn_drop=attn_drop, 651 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path, 652 | norm_layer=norm_layer,token_projection=token_projection,token_mlp=token_mlp,att=self.att,sparseAtt=self.sparseAtt) 653 | for i in range(depth)]) 654 | 655 | def extra_repr(self) -> str: 656 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" 657 | 658 | def forward(self, x, mask=None): 659 | for blk in self.blocks: 660 | if self.use_checkpoint: 661 | x = checkpoint.checkpoint(blk, x) 662 | else: 663 | x = blk(x,mask) 664 | return x 665 | 666 | 667 | 668 | class AST(nn.Module): 669 | def __init__(self, img_size=256, in_chans=3, dd_in=3, 670 | embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2], 671 | win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None, 672 
| drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, 673 | norm_layer=nn.LayerNorm, patch_norm=True, 674 | use_checkpoint=False, token_projection='linear', token_mlp='leff', 675 | dowsample=Downsample, upsample=Upsample, shift_flag=True,**kwargs): 676 | super().__init__() 677 | 678 | self.num_enc_layers = len(depths)//2 679 | self.num_dec_layers = len(depths)//2 680 | self.embed_dim = embed_dim 681 | self.patch_norm = patch_norm 682 | self.mlp_ratio = mlp_ratio 683 | self.token_projection = token_projection 684 | self.mlp = token_mlp 685 | self.win_size =win_size 686 | self.reso = img_size 687 | self.pos_drop = nn.Dropout(p=drop_rate) 688 | self.dd_in = dd_in 689 | 690 | # stochastic depth 691 | enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))] 692 | conv_dpr = [drop_path_rate]*depths[4] 693 | dec_dpr = enc_dpr[::-1] 694 | 695 | # build layers 696 | 697 | # Input/Output 698 | self.input_proj = InputProj(in_channel=dd_in, out_channel=embed_dim, kernel_size=3, stride=1, act_layer=nn.LeakyReLU) 699 | self.output_proj = OutputProj(in_channel=2*embed_dim, out_channel=in_chans, kernel_size=3, stride=1) 700 | 701 | # Encoder 702 | self.encoderlayer_0 = BasicASTLayer(dim=embed_dim, 703 | output_dim=embed_dim, 704 | input_resolution=(img_size, 705 | img_size), 706 | depth=depths[0], 707 | num_heads=num_heads[0], 708 | win_size=win_size, 709 | mlp_ratio=self.mlp_ratio, 710 | qkv_bias=qkv_bias, qk_scale=qk_scale, 711 | drop=drop_rate, attn_drop=attn_drop_rate, 712 | drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])], 713 | norm_layer=norm_layer, 714 | use_checkpoint=use_checkpoint, 715 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=False,sparseAtt=False) 716 | self.dowsample_0 = dowsample(embed_dim, embed_dim*2) 717 | self.encoderlayer_1 = BasicASTLayer(dim=embed_dim*2, 718 | output_dim=embed_dim*2, 719 | input_resolution=(img_size // 2, 720 | img_size // 2), 721 | depth=depths[1], 722 | num_heads=num_heads[1], 723 | win_size=win_size, 724 | mlp_ratio=self.mlp_ratio, 725 | qkv_bias=qkv_bias, qk_scale=qk_scale, 726 | drop=drop_rate, attn_drop=attn_drop_rate, 727 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])], 728 | norm_layer=norm_layer, 729 | use_checkpoint=use_checkpoint, 730 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=False,sparseAtt=False) 731 | self.dowsample_1 = dowsample(embed_dim*2, embed_dim*4) 732 | self.encoderlayer_2 = BasicASTLayer(dim=embed_dim*4, 733 | output_dim=embed_dim*4, 734 | input_resolution=(img_size // (2 ** 2), 735 | img_size // (2 ** 2)), 736 | depth=depths[2], 737 | num_heads=num_heads[2], 738 | win_size=win_size, 739 | mlp_ratio=self.mlp_ratio, 740 | qkv_bias=qkv_bias, qk_scale=qk_scale, 741 | drop=drop_rate, attn_drop=attn_drop_rate, 742 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])], 743 | norm_layer=norm_layer, 744 | use_checkpoint=use_checkpoint, 745 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=False,sparseAtt=False) 746 | self.dowsample_2 = dowsample(embed_dim*4, embed_dim*8) 747 | self.encoderlayer_3 = BasicASTLayer(dim=embed_dim*8, 748 | output_dim=embed_dim*8, 749 | input_resolution=(img_size // (2 ** 3), 750 | img_size // (2 ** 3)), 751 | depth=depths[3], 752 | num_heads=num_heads[3], 753 | win_size=win_size, 754 | mlp_ratio=self.mlp_ratio, 755 | qkv_bias=qkv_bias, qk_scale=qk_scale, 756 | drop=drop_rate, attn_drop=attn_drop_rate, 757 | drop_path=enc_dpr[sum(depths[:3]):sum(depths[:4])], 758 | 
norm_layer=norm_layer, 759 | use_checkpoint=use_checkpoint, 760 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=False,sparseAtt=False) 761 | self.dowsample_3 = dowsample(embed_dim*8, embed_dim*16) 762 | 763 | # Bottleneck 764 | self.conv = BasicASTLayer(dim=embed_dim*16, 765 | output_dim=embed_dim*16, 766 | input_resolution=(img_size // (2 ** 4), 767 | img_size // (2 ** 4)), 768 | depth=depths[4], 769 | num_heads=num_heads[4], 770 | win_size=win_size, 771 | mlp_ratio=self.mlp_ratio, 772 | qkv_bias=qkv_bias, qk_scale=qk_scale, 773 | drop=drop_rate, attn_drop=attn_drop_rate, 774 | drop_path=conv_dpr, 775 | norm_layer=norm_layer, 776 | use_checkpoint=use_checkpoint, 777 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=True,sparseAtt=True) 778 | 779 | # Decoder 780 | self.upsample_0 = upsample(embed_dim*16, embed_dim*8) 781 | self.decoderlayer_0 = BasicASTLayer(dim=embed_dim*16, 782 | output_dim=embed_dim*16, 783 | input_resolution=(img_size // (2 ** 3), 784 | img_size // (2 ** 3)), 785 | depth=depths[5], 786 | num_heads=num_heads[5], 787 | win_size=win_size, 788 | mlp_ratio=self.mlp_ratio, 789 | qkv_bias=qkv_bias, qk_scale=qk_scale, 790 | drop=drop_rate, attn_drop=attn_drop_rate, 791 | drop_path=dec_dpr[:depths[5]], 792 | norm_layer=norm_layer, 793 | use_checkpoint=use_checkpoint, 794 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=True,sparseAtt=True) 795 | self.upsample_1 = upsample(embed_dim*16, embed_dim*4) 796 | self.decoderlayer_1 = BasicASTLayer(dim=embed_dim*8, 797 | output_dim=embed_dim*8, 798 | input_resolution=(img_size // (2 ** 2), 799 | img_size // (2 ** 2)), 800 | depth=depths[6], 801 | num_heads=num_heads[6], 802 | win_size=win_size, 803 | mlp_ratio=self.mlp_ratio, 804 | qkv_bias=qkv_bias, qk_scale=qk_scale, 805 | drop=drop_rate, attn_drop=attn_drop_rate, 806 | drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])], 807 | norm_layer=norm_layer, 808 | use_checkpoint=use_checkpoint, 809 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=True,sparseAtt=True) 810 | self.upsample_2 = upsample(embed_dim*8, embed_dim*2) 811 | self.decoderlayer_2 = BasicASTLayer(dim=embed_dim*4, 812 | output_dim=embed_dim*4, 813 | input_resolution=(img_size // 2, 814 | img_size // 2), 815 | depth=depths[7], 816 | num_heads=num_heads[7], 817 | win_size=win_size, 818 | mlp_ratio=self.mlp_ratio, 819 | qkv_bias=qkv_bias, qk_scale=qk_scale, 820 | drop=drop_rate, attn_drop=attn_drop_rate, 821 | drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])], 822 | norm_layer=norm_layer, 823 | use_checkpoint=use_checkpoint, 824 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=True,sparseAtt=True) 825 | self.upsample_3 = upsample(embed_dim*4, embed_dim) 826 | self.decoderlayer_3 = BasicASTLayer(dim=embed_dim*2, 827 | output_dim=embed_dim*2, 828 | input_resolution=(img_size, 829 | img_size), 830 | depth=depths[8], 831 | num_heads=num_heads[8], 832 | win_size=win_size, 833 | mlp_ratio=self.mlp_ratio, 834 | qkv_bias=qkv_bias, qk_scale=qk_scale, 835 | drop=drop_rate, attn_drop=attn_drop_rate, 836 | drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])], 837 | norm_layer=norm_layer, 838 | use_checkpoint=use_checkpoint, 839 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=True,sparseAtt=True) 840 | 841 | self.apply(self._init_weights) 842 | 843 | def _init_weights(self, m): 844 | if isinstance(m, nn.Linear): 845 | 
trunc_normal_(m.weight, std=.02)
846 | if isinstance(m, nn.Linear) and m.bias is not None:
847 | nn.init.constant_(m.bias, 0)
848 | elif isinstance(m, nn.LayerNorm):
849 | nn.init.constant_(m.bias, 0)
850 | nn.init.constant_(m.weight, 1.0)
851 | 
852 | @torch.jit.ignore
853 | def no_weight_decay(self):
854 | return {'absolute_pos_embed'}
855 | 
856 | @torch.jit.ignore
857 | def no_weight_decay_keywords(self):
858 | return {'relative_position_bias_table'}
859 | 
860 | def extra_repr(self) -> str:
861 | return f"embed_dim={self.embed_dim}, token_projection={self.token_projection}, token_mlp={self.mlp}, win_size={self.win_size}"
862 | 
863 | def forward(self, x, mask=None):
864 | # Input Projection
865 | y = self.input_proj(x)
866 | y = self.pos_drop(y)
867 | # Encoder
868 | conv0 = self.encoderlayer_0(y,mask=mask)
869 | pool0 = self.dowsample_0(conv0)
870 | conv1 = self.encoderlayer_1(pool0,mask=mask)
871 | pool1 = self.dowsample_1(conv1)
872 | conv2 = self.encoderlayer_2(pool1,mask=mask)
873 | pool2 = self.dowsample_2(conv2)
874 | conv3 = self.encoderlayer_3(pool2,mask=mask)
875 | pool3 = self.dowsample_3(conv3)
876 | 
877 | # Bottleneck
878 | conv4 = self.conv(pool3, mask=mask)
879 | 
880 | # Decoder
881 | up0 = self.upsample_0(conv4)
882 | deconv0 = torch.cat([up0,conv3],-1)
883 | deconv0 = self.decoderlayer_0(deconv0,mask=mask)
884 | 
885 | up1 = self.upsample_1(deconv0)
886 | deconv1 = torch.cat([up1,conv2],-1)
887 | deconv1 = self.decoderlayer_1(deconv1,mask=mask)
888 | 
889 | up2 = self.upsample_2(deconv1)
890 | deconv2 = torch.cat([up2,conv1],-1)
891 | deconv2 = self.decoderlayer_2(deconv2,mask=mask)
892 | 
893 | up3 = self.upsample_3(deconv2)
894 | deconv3 = torch.cat([up3,conv0],-1)
895 | deconv3 = self.decoderlayer_3(deconv3,mask=mask)
896 | 
897 | # Output Projection
898 | y = self.output_proj(deconv3)
899 | return x + y if self.dd_in == 3 else y
900 | 
--------------------------------------------------------------------------------
/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | class Options():
4 | """Command-line options shared by the training and test scripts."""
5 | def __init__(self):
6 | pass
7 | 
8 | def init(self, parser):
9 | # global settings
10 | parser.add_argument('--batch_size', type=int, default=32, help='batch size')
11 | parser.add_argument('--nepoch', type=int, default=250, help='training epochs')
12 | parser.add_argument('--train_workers', type=int, default=0, help='train_dataloader workers')
13 | parser.add_argument('--eval_workers', type=int, default=0, help='eval_dataloader workers')
14 | parser.add_argument('--dataset', type=str, default ='SIDD')
15 | parser.add_argument('--pretrain_weights',type=str, default='./log/AST_B/models/model_best.pth', help='path of pretrained_weights')
16 | parser.add_argument('--optimizer', type=str, default ='adamw', help='optimizer for training')
17 | parser.add_argument('--lr_initial', type=float, default=0.0002, help='initial learning rate') # the shell scripts pass --lr, which argparse resolves to this option by prefix matching
18 | parser.add_argument('--step_lr', type=int, default=50, help='step size of the StepLR scheduler')
19 | parser.add_argument('--weight_decay', type=float, default=0.02, help='weight decay')
20 | parser.add_argument('--gpu', type=str, default='0,1', help='GPUs')
21 | parser.add_argument('--arch', type=str, default ='AST_B', help='architecture')
22 | parser.add_argument('--mode', type=str, default ='denoising', help='image restoration mode')
23 | parser.add_argument('--dd_in', type=int, default=3, help='number of input channels; 3 also enables the global residual in AST.forward')
24 | 
25 | # args for saving
26 | parser.add_argument('--save_dir', 
type=str, default ='./logs/', help='save dir')
27 | parser.add_argument('--save_images', action='store_true',default=False)
28 | parser.add_argument('--env', type=str, default ='_', help='experiment tag appended to the arch name in the log dir')
29 | parser.add_argument('--checkpoint', type=int, default=50, help='save a numbered checkpoint every N epochs')
30 | 
31 | # args for Uformer
32 | parser.add_argument('--norm_layer', type=str, default ='nn.LayerNorm', help='normalize layer in transformer')
33 | parser.add_argument('--embed_dim', type=int, default=32, help='dim of embedding features')
34 | parser.add_argument('--win_size', type=int, default=8, help='window size of self-attention')
35 | parser.add_argument('--token_projection', type=str,default='linear', help='linear/conv token projection')
36 | parser.add_argument('--token_mlp', type=str,default='leff', help='ffn/leff/frfn token mlp')
37 | parser.add_argument('--att_se', action='store_true', default=False, help='se after sa')
38 | parser.add_argument('--modulator', action='store_true', default=False, help='multi-scale modulator')
39 | 
40 | # args for vit
41 | parser.add_argument('--vit_dim', type=int, default=256, help='vit hidden_dim')
42 | parser.add_argument('--vit_depth', type=int, default=12, help='vit depth')
43 | parser.add_argument('--vit_nheads', type=int, default=8, help='vit num_heads')
44 | parser.add_argument('--vit_mlp_dim', type=int, default=512, help='vit mlp_dim')
45 | parser.add_argument('--vit_patch_size', type=int, default=16, help='vit patch_size')
46 | parser.add_argument('--global_skip', action='store_true', default=False, help='global skip connection')
47 | parser.add_argument('--local_skip', action='store_true', default=False, help='local skip connection')
48 | parser.add_argument('--vit_share', action='store_true', default=False, help='share vit module')
49 | 
50 | # args for training
51 | parser.add_argument('--train_ps', type=int, default=256, help='patch size of training sample')
52 | parser.add_argument('--val_ps', type=int, default=256, help='patch size of validation sample')
53 | parser.add_argument('--resume', action='store_true',default=False)
54 | parser.add_argument('--retrain', action='store_true', default=False)
55 | parser.add_argument('--train_dir', type=str, default ='./rain_syn/DDN-Data/train/', help='dir of train data')
56 | parser.add_argument('--val_dir', type=str, default ='./rain_syn/DDN-Data/test/', help='dir of validation data')
57 | parser.add_argument('--warmup', action='store_true', default=False, help='warmup')
58 | parser.add_argument('--warmup_epochs', type=int,default=3, help='epochs for warmup')
59 | 
60 | # ddp
61 | parser.add_argument("--local_rank", type=int,default=-1,help='DDP parameter, do not modify') # no need to set this manually; torch.distributed.launch assigns it
62 | parser.add_argument("--distribute",action='store_true',help='whether to use multi-GPU training')
63 | parser.add_argument("--distribute_mode",type=str,default='DDP',help="which distributed mode to use")
64 | return parser
65 | 
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.0
2 | torchvision==0.9.0
3 | matplotlib
4 | scikit-image
5 | opencv-python
6 | yacs
7 | joblib
8 | natsort
9 | h5py
10 | tqdm
11 | einops
12 | linformer
13 | timm
14 | ptflops
15 | dataclasses
--------------------------------------------------------------------------------
/script/test.sh:
--------------------------------------------------------------------------------
1 | ##test on AGAN-Data##
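# (editor's note) --token_mlp frfn selects the FRFN feed-forward defined in
# model.py (the default in options.py is leff); the --weights paths below
# assume checkpoints saved under --save_dir by the train_*.sh scripts.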
2 | python3 test/test_raindrop.py --arch AST_B --input_dir ../dataset/raindrop/test_a/ --result_dir ./results/rain_drop/AGAN-Data/ --weights ./logs/raindrop/AGAN-Data/AST_B/models/model_best.pth --token_mlp frfn 3 | 4 | ##test on densehaze## 5 | python3 test/test_denseHaze.py --arch AST_B --input_dir ../dataset/Dense-Haze-v2/valid_dense --result_dir ./results/dehaze/DenseHaze/ --weights ./logs/dehazing/DenseHaze/AST_B/models/model_best.pth --token_mlp frfn 6 | 7 | ## test on SPAD## 8 | python3 test/test_spad.py --arch AST_B --input_dir /PATH/TO/DATASET/ --result_dir ./results/deraining/SPAD/ --weights ./pretrained/rain/model_best.pth --token_mlp frfn 9 | 10 | 11 | 12 | -------------------------------------------------------------------------------- /script/train_dehaze.sh: -------------------------------------------------------------------------------- 1 | ##train on densehaze## 2 | 3 | python3 ./train/train_dehaze.py --arch AST_B --batch_size 4 --gpu '0' --train_ps 256 --train_dir ../dataset/Dense-Haze-v2/train_dense/ --val_ps 256 --val_dir ../dataset/Dense-Haze-v2/valid_dense/ --lr 0.0002 --env _1108_s1 --mode dehaze --nepoch 2000 --dataset DenseHaze --warmup --token_mlp frfn 4 | 5 | # python3 ./train/train_dehaze.py --arch AST_B --batch_size 4 --retrain --pretrain_weights ./logs/dehazing/DenseHaze/AST_B_1108_s1/models/model_best.pth --gpu '0' --train_ps 384 --train_dir ../dataset/Dense-Haze-v2/train_dense/ --val_ps 384 --val_dir ../dataset/Dense-Haze-v2/valid_dense/ --lr 0.00012 --env _1109_s2 --mode dehaze --nepoch 1200 --token_mlp frfn --dataset DenseHaze --warmup 6 | 7 | # python3 ./train/train_dehaze.py --arch AST_B --batch_size 4 --retrain --pretrain_weights ./logs/dehazing/DenseHaze/AST_B_1109_s2/models/model_best.pth --gpu '0' --train_ps 512 --train_dir ../dataset/Dense-Haze-v2/train_dense/ --val_ps 512 --val_dir ../dataset/Dense-Haze-v2/valid_dense/ --lr 0.00008 --env _1110_s3 --mode dehaze --nepoch 800 --token_mlp frfn --dataset DenseHaze --warmup 8 | 9 | # python3 ./train/train_dehaze.py --arch AST_B --retrain --pretrain_weights ./logs/dehazing/DenseHaze/AST_B_1110_s3/models/model_best.pth --batch_size 2 --gpu '0,1' --train_ps 768 --train_dir ../dataset/Dense-Haze-v2/train_dense/ --val_ps 768 --val_dir ../dataset/Dense-Haze-v2/valid_dense/ --lr 0.00003 --env _1105_s4 --mode dehaze --nepoch 300 --token_mlp frfn --dataset DenseHaze --warmup 10 | 11 | # python3 ./train/train_dehaze.py --arch AST_B --retrain --pretrain_weights ./logs/dehazing/DenseHaze/AST_B_1105/models/model_best.pth --batch_size 2 --gpu '0,1' --train_ps 896 --train_dir ../dataset/Dense-Haze-v2/train_dense/ --val_ps 896 --val_dir ../dataset/Dense-Haze-v2/valid_dense/ --lr 0.00001 --env _1106_s5 --mode dehaze --nepoch 80 --token_mlp frfn --dataset DenseHaze --warmup 12 | -------------------------------------------------------------------------------- /script/train_derain.sh: -------------------------------------------------------------------------------- 1 | ##train on SPAD## 2 | python ./train/train_derain.py --arch AST_B --batch_size 32 --gpu 0,1 --train_ps 128 --train_dir ../derain_dataset/derain/ --env _1030_s1 --val_dir ../derain_dataset/derain/ --save_dir ./logs/ --dataset spad --warmup --token_mlp frfn --nepoch 20 --lr_initial 0.0002 3 | 4 | # python ./train/train_derain.py --arch AST_B --retrain --pretrain_weights ./logs/derain/spad/AST_B_1030_s1/models/model_best.pth --batch_size 16 --gpu 0,1 --train_ps 256 --train_dir ../derain_dataset/derain/ --env _1030_s2 --val_ps 256 --val_dir 
../derain_dataset/derain/ --save_dir ./logs/ --dataset spad --warmup --token_mlp frfn --nepoch 15 --lr_initial 0.0001 5 | -------------------------------------------------------------------------------- /script/train_raindrop.sh: -------------------------------------------------------------------------------- 1 | ##train on AGAN-Data## 2 | python ./train/train_raindrop.py --arch AST_B --batch_size 32 --gpu '0,1' --train_ps 128 --train_dir ../dataset/raindrop/train/ --val_ps 128 --val_dir ../dataset/raindrop/test_a/ --env _1102_s1 --mode derain_drop --nepoch 300 --dataset raindrop --warmup --lr_initial 0.0002 --token_mlp frfn 3 | 4 | # python ./train/train_derain_drop.py --arch AST_B --retrain --pretrain_weights ./logs/raindrop/AGAN-Data/AST_B_1102_s1/models/model_best.pth --batch_size 16 --gpu '0,1' --train_ps 256 --train_dir ../dataset/raindrop/train/ --val_ps 256 --val_dir ../dataset/raindrop/test_a/ --env _1103_s2 --mode derain_drop --nepoch 200 --dataset raindrop --warmup --lr_initial 0.0001 --token_mlp frfn 5 | 6 | # python ./train/train_derain_drop.py --arch AST_B --retrain --pretrain_weights ./logs/derain/raindrop/AST_B_1103_s2/models/model_best.pth --batch_size 8 --gpu '0,1' --train_ps 384 --train_dir ../dataset/raindrop/train/ --val_ps 384 --val_dir ../dataset/raindrop/test_a/ --env _1103_s3 --mode derain_drop --nepoch 150 --dataset raindrop --warmup --lr_initial 0.00008 --token_mlp frfn -------------------------------------------------------------------------------- /test/test_denseHaze.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys,math 3 | import argparse 4 | from tqdm import tqdm 5 | from einops import rearrange, repeat 6 | 7 | import torch.nn as nn 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import torch.nn.functional as F 11 | # from ptflops import get_model_complexity_info 12 | 13 | dir_name = os.path.dirname(os.path.abspath(__file__)) 14 | sys.path.append(os.path.join(dir_name,'../dataset/')) 15 | sys.path.append(os.path.join(dir_name,'..')) 16 | 17 | import scipy.io as sio 18 | from dataset.dataset_dehaze_denseHaze import * 19 | import utils 20 | 21 | from skimage import img_as_float32, img_as_ubyte 22 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss 23 | from skimage.metrics import structural_similarity as ssim_loss 24 | 25 | parser = argparse.ArgumentParser(description='Image dehazing evaluation on DenseHaze') 26 | parser.add_argument('--input_dir', default='/PATH/TO/DATASET/', 27 | type=str, help='Directory of validation images') 28 | parser.add_argument('--result_dir', default='./results_release/haze/densehaze/AST_B/', 29 | type=str, help='Directory for results') 30 | parser.add_argument('--weights', default='./pretrained/haze/model_best.pth', 31 | type=str, help='Path to weights') 32 | parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES') 33 | parser.add_argument('--arch', default='AST_B', type=str, help='arch') 34 | parser.add_argument('--batch_size', default=1, type=int, help='Batch size for dataloader') 35 | parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory') 36 | parser.add_argument('--embed_dim', type=int, default=32, help='number of data loading workers') 37 | parser.add_argument('--win_size', type=int, default=8, help='number of data loading workers') 38 | parser.add_argument('--token_projection', type=str,default='linear', help='linear/conv token projection') 
39 | parser.add_argument('--token_mlp', type=str,default='leff', help='ffn/leff token mlp') 40 | parser.add_argument('--query_embed', action='store_true', default=False, help='query embedding for the decoder') 41 | parser.add_argument('--dd_in', type=int, default=3, help='dd_in') 42 | 43 | # args for vit 44 | parser.add_argument('--vit_dim', type=int, default=256, help='vit hidden_dim') 45 | parser.add_argument('--vit_depth', type=int, default=12, help='vit depth') 46 | parser.add_argument('--vit_nheads', type=int, default=8, help='vit hidden_dim') 47 | parser.add_argument('--vit_mlp_dim', type=int, default=512, help='vit mlp_dim') 48 | parser.add_argument('--vit_patch_size', type=int, default=16, help='vit patch_size') 49 | parser.add_argument('--global_skip', action='store_true', default=False, help='global skip connection') 50 | parser.add_argument('--local_skip', action='store_true', default=False, help='local skip connection') 51 | parser.add_argument('--vit_share', action='store_true', default=False, help='share vit module') 52 | 53 | parser.add_argument('--train_ps', type=int, default=128, help='patch size of training sample') 54 | args = parser.parse_args() 55 | 56 | 57 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 58 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 59 | 60 | utils.mkdir(args.result_dir) 61 | 62 | test_dataset = get_validation_data(args.input_dir) 63 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=8, drop_last=False) 64 | 65 | model_restoration= utils.get_arch(args) 66 | # model_restoration = torch.nn.DataParallel(model_restoration) 67 | 68 | utils.load_checkpoint(model_restoration,args.weights) 69 | print("===>Testing using weights: ", args.weights) 70 | 71 | model_restoration.cuda() 72 | model_restoration.eval() 73 | 74 | 75 | def expand2square(timg,factor=16.0): 76 | _, _, h, w = timg.size() 77 | 78 | X = int(math.ceil(max(h,w)/float(factor))*factor) 79 | 80 | img = torch.zeros(1,3,X,X).type_as(timg) # 3, h,w 81 | mask = torch.zeros(1,1,X,X).type_as(timg) 82 | 83 | # print(img.size(),mask.size()) 84 | # print((X - h)//2, (X - h)//2+h, (X - w)//2, (X - w)//2+w) 85 | img[:,:, ((X - h)//2):((X - h)//2 + h),((X - w)//2):((X - w)//2 + w)] = timg 86 | mask[:,:, ((X - h)//2):((X - h)//2 + h),((X - w)//2):((X - w)//2 + w)].fill_(1) 87 | 88 | return img, mask 89 | from utils.image_utils import splitimage, mergeimage 90 | 91 | 92 | with torch.no_grad(): 93 | psnr_val_rgb = [] 94 | ssim_val_rgb = [] 95 | for ii, data_test in enumerate(tqdm(test_loader), 0): 96 | rgb_gt = data_test[0].numpy().squeeze().transpose((1,2,0)) 97 | # rgb_noisy, mask = expand2square(data_test[1].cuda(), factor=128) 98 | filenames = data_test[2] 99 | 100 | input_ = data_test[1].cuda() 101 | B, C, H, W = input_.shape 102 | corp_size_arg = 1152 103 | overlap_size_arg = 384 104 | split_data, starts = splitimage(input_, crop_size=corp_size_arg, overlap_size=overlap_size_arg) 105 | for i, data in enumerate(split_data): 106 | split_data[i] = model_restoration(data).cpu() 107 | restored = mergeimage(split_data, starts, crop_size=corp_size_arg, resolution=(B, C, H, W)) 108 | rgb_restored = torch.clamp(restored, 0, 1).permute(0, 2, 3, 1).numpy() 109 | 110 | psnr = psnr_loss(rgb_restored[0], rgb_gt) 111 | ssim = ssim_loss(rgb_restored[0], rgb_gt, channel_axis=2, data_range=1) 112 | 113 | psnr_val_rgb.append(psnr) 114 | ssim_val_rgb.append(ssim) 115 | print("PSNR:",psnr,", SSIM:", ssim, filenames[0], rgb_restored.shape) 116 | 
utils.save_img(os.path.join(args.result_dir,filenames[0]+'.PNG'), img_as_ubyte(rgb_restored[0])) 117 | with open(os.path.join(args.result_dir,'psnr_ssim.txt'),'a') as f: 118 | f.write(filenames[0]+'.PNG ---->'+"PSNR: %.4f, SSIM: %.4f] "% (psnr, ssim)+'\n') 119 | psnr_val_rgb = sum(psnr_val_rgb)/len(test_dataset) 120 | ssim_val_rgb = sum(ssim_val_rgb)/len(test_dataset) 121 | print("PSNR: %f, SSIM: %f " %(psnr_val_rgb,ssim_val_rgb)) 122 | with open(os.path.join(args.result_dir,'psnr_ssim.txt'),'a') as f: 123 | f.write("Arch:"+args.arch+", PSNR: %.4f, SSIM: %.4f] "% (psnr_val_rgb, ssim_val_rgb)+'\n') -------------------------------------------------------------------------------- /test/test_raindrop.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys 3 | import argparse 4 | from tqdm import tqdm 5 | from einops import rearrange, repeat 6 | 7 | import torch.nn as nn 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import torch.nn.functional as F 11 | # from ptflops import get_model_complexity_info 12 | 13 | dir_name = os.path.dirname(os.path.abspath(__file__)) 14 | sys.path.append(os.path.join(dir_name,'../dataset/')) 15 | sys.path.append(os.path.join(dir_name,'..')) 16 | 17 | # import scipy.io as sio 18 | from dataset.dataset_derain_drop import * 19 | import utils 20 | import math 21 | 22 | from skimage import img_as_float32, img_as_ubyte 23 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss 24 | from skimage.metrics import structural_similarity as ssim_loss 25 | 26 | parser = argparse.ArgumentParser(description='Image derain evaluation on spad') 27 | parser.add_argument('--input_dir', default='./dataset/AST_B/deraining/spad/val/', 28 | type=str, help='Directory of validation images') 29 | parser.add_argument('--result_dir', default='./results/deraining/spad/AST_B/se', 30 | type=str, help='Directory for results') 31 | parser.add_argument('--weights', default='./logs/deraining/spad/AST_B/models/model_best.pth', 32 | type=str, help='Path to weights') 33 | parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES') 34 | parser.add_argument('--arch', default='AST_B', type=str, help='arch') 35 | parser.add_argument('--batch_size', default=1, type=int, help='Batch size for dataloader') 36 | parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory') 37 | parser.add_argument('--embed_dim', type=int, default=32, help='number of data loading workers') 38 | parser.add_argument('--win_size', type=int, default=8, help='number of data loading workers') 39 | parser.add_argument('--token_projection', type=str,default='linear', help='linear/conv token projection') 40 | parser.add_argument('--token_mlp', type=str,default='leff', help='ffn/leff token mlp') 41 | parser.add_argument('--dd_in', type=int, default=3, help='dd_in') 42 | 43 | # args for vit 44 | parser.add_argument('--vit_dim', type=int, default=256, help='vit hidden_dim') 45 | parser.add_argument('--vit_depth', type=int, default=12, help='vit depth') 46 | parser.add_argument('--vit_nheads', type=int, default=8, help='vit hidden_dim') 47 | parser.add_argument('--vit_mlp_dim', type=int, default=512, help='vit mlp_dim') 48 | parser.add_argument('--vit_patch_size', type=int, default=16, help='vit patch_size') 49 | parser.add_argument('--global_skip', action='store_true', default=False, help='global skip connection') 50 | parser.add_argument('--local_skip', action='store_true', 
default=False, help='local skip connection')
51 | parser.add_argument('--vit_share', action='store_true', default=False, help='share vit module')
52 | 
53 | parser.add_argument('--train_ps', type=int, default=128, help='patch size of training sample')
54 | args = parser.parse_args()
55 | 
56 | 
57 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
58 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus # honor --gpus instead of a hard-coded device id
59 | 
60 | 
61 | utils.mkdir(args.result_dir)
62 | 
63 | test_dataset = get_test_data(args.input_dir)
64 | 
65 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=8, drop_last=False)
66 | 
67 | model_restoration = utils.get_arch(args)
68 | 
69 | utils.load_checkpoint(model_restoration,args.weights)
70 | print("===>Testing using weights: ", args.weights)
71 | 
72 | model_restoration.cuda()
73 | model_restoration.eval()
74 | from utils.image_utils import splitimage, mergeimage
75 | 
76 | def test_transform(v, op):
77 | v2np = v.data.cpu().numpy()
78 | if op == 'v':
79 | tfnp = v2np[:, :, :, ::-1].copy()
80 | elif op == 'h':
81 | tfnp = v2np[:, :, ::-1, :].copy()
82 | elif op == 't':
83 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
84 | 
85 | ret = torch.Tensor(tfnp).to(v.device)
86 | 
87 | return ret
88 | 
89 | def expand2square(timg,factor=16.0):
90 | _, _, h, w = timg.size()
91 | 
92 | X = int(math.ceil(max(h,w)/float(factor))*factor)
93 | 
94 | img = torch.zeros(1,3,X,X).type_as(timg) # 3, h,w
95 | mask = torch.zeros(1,1,X,X).type_as(timg)
96 | 
97 | 
98 | 
99 | img[:,:, ((X - h)//2):((X - h)//2 + h),((X - w)//2):((X - w)//2 + w)] = timg
100 | mask[:,:, ((X - h)//2):((X - h)//2 + h),((X - w)//2):((X - w)//2 + w)].fill_(1)
101 | 
102 | return img, mask
103 | 
104 | # Process data
105 | with torch.no_grad():
106 | psnr_val_rgb = []
107 | ssim_val_rgb = []
108 | for ii, data_test in enumerate(tqdm(test_loader), 0):
109 | rgb_gt = data_test[0].numpy().squeeze().transpose((1,2,0))
110 | # rgb_noisy, mask = expand2square(data_test[1].cuda(), factor=128)
111 | filenames = data_test[2]
112 | 
113 | input_ = data_test[1].cuda()
114 | B, C, H, W = input_.shape
115 | crop_size_arg = 384
116 | overlap_size_arg = 80
117 | split_data, starts = splitimage(input_, crop_size=crop_size_arg, overlap_size=overlap_size_arg)
118 | for i, data in enumerate(split_data):
119 | split_data[i] = model_restoration(data).cpu()
120 | restored = mergeimage(split_data, starts, crop_size=crop_size_arg, resolution=(B, C, H, W))
121 | rgb_restored = torch.clamp(restored, 0, 1).permute(0, 2, 3, 1).numpy()
122 | 
123 | psnr = psnr_loss(rgb_restored[0], rgb_gt)
124 | ssim = ssim_loss(rgb_restored[0], rgb_gt, channel_axis=2, data_range=1)
125 | 
126 | 
127 | psnr_val_rgb.append(psnr)
128 | ssim_val_rgb.append(ssim)
129 | print("PSNR:",psnr,", SSIM:", ssim, filenames[0], rgb_restored.shape)
130 | utils.save_img(os.path.join(args.result_dir,filenames[0]+'.PNG'), img_as_ubyte(rgb_restored[0]))
131 | with open(os.path.join(args.result_dir,'psnr_ssim.txt'),'a') as f:
132 | f.write(filenames[0]+'.PNG ---->'+"PSNR: %.4f, SSIM: %.4f] "% (psnr, ssim)+'\n')
133 | psnr_val_rgb = sum(psnr_val_rgb)/len(test_dataset)
134 | ssim_val_rgb = sum(ssim_val_rgb)/len(test_dataset)
135 | print("PSNR: %f, SSIM: %f " %(psnr_val_rgb,ssim_val_rgb))
136 | with open(os.path.join(args.result_dir,'psnr_ssim.txt'),'a') as f:
137 | f.write("Arch:"+args.arch+", PSNR: %.4f, SSIM: %.4f] "% (psnr_val_rgb, ssim_val_rgb)+'\n')
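(Editor's note) test_raindrop.py above and test_denseHaze.py restore large images tile by tile, while test_spad.py below instead pads each image to a square whose side is a multiple of 128 (four stride-2 stages give a 16x spatial reduction, times the 8x8 attention window). A minimal sketch of the shared tiling pattern, assuming only the splitimage/mergeimage helpers from utils.image_utils with the call signatures used above:

import torch
from utils.image_utils import splitimage, mergeimage

def tiled_restore(model, input_, crop_size=384, overlap_size=80):
    # Split into overlapping crops, restore each crop on the GPU,
    # then blend the crops back to the full resolution.
    B, C, H, W = input_.shape
    tiles, starts = splitimage(input_, crop_size=crop_size, overlap_size=overlap_size)
    with torch.no_grad():
        tiles = [model(t).cpu() for t in tiles]
    return mergeimage(tiles, starts, crop_size=crop_size, resolution=(B, C, H, W))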
-------------------------------------------------------------------------------- /test/test_spad.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os,sys 3 | import argparse 4 | from tqdm import tqdm 5 | from einops import rearrange, repeat 6 | 7 | import torch.nn as nn 8 | import torch 9 | from torch.utils.data import DataLoader 10 | import torch.nn.functional as F 11 | 12 | dir_name = os.path.dirname(os.path.abspath(__file__)) 13 | sys.path.append(os.path.join(dir_name,'../dataset/')) 14 | sys.path.append(os.path.join(dir_name,'..')) 15 | 16 | # import scipy.io as sio 17 | from dataset.dataset_derain import * 18 | import utils 19 | import math 20 | 21 | from skimage import img_as_float32, img_as_ubyte 22 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss 23 | from skimage.metrics import structural_similarity as ssim_loss 24 | 25 | parser = argparse.ArgumentParser(description='Image derain evaluation on spad') 26 | parser.add_argument('--input_dir', default='/PATH/TO/DATASET/', 27 | type=str, help='Directory of validation images') 28 | parser.add_argument('--result_dir', default='./results/deraining/spad/AST_B/', 29 | type=str, help='Directory for results') 30 | parser.add_argument('--weights', default='./pretrained/rain/model_best.pth', 31 | type=str, help='Path to weights') 32 | parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES') 33 | parser.add_argument('--arch', default='AST_B', type=str, help='arch') 34 | parser.add_argument('--batch_size', default=1, type=int, help='Batch size for dataloader') 35 | parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory') 36 | parser.add_argument('--embed_dim', type=int, default=32, help='number of data loading workers') 37 | parser.add_argument('--win_size', type=int, default=8, help='number of data loading workers') 38 | parser.add_argument('--token_projection', type=str,default='linear', help='linear/conv token projection') 39 | parser.add_argument('--token_mlp', type=str,default='frfn', help='ffn/leff token mlp') 40 | parser.add_argument('--dd_in', type=int, default=3, help='dd_in') 41 | 42 | # args for vit 43 | parser.add_argument('--vit_dim', type=int, default=256, help='vit hidden_dim') 44 | parser.add_argument('--vit_depth', type=int, default=12, help='vit depth') 45 | parser.add_argument('--vit_nheads', type=int, default=8, help='vit hidden_dim') 46 | parser.add_argument('--vit_mlp_dim', type=int, default=512, help='vit mlp_dim') 47 | parser.add_argument('--vit_patch_size', type=int, default=16, help='vit patch_size') 48 | parser.add_argument('--global_skip', action='store_true', default=False, help='global skip connection') 49 | parser.add_argument('--local_skip', action='store_true', default=False, help='local skip connection') 50 | parser.add_argument('--vit_share', action='store_true', default=False, help='share vit module') 51 | 52 | parser.add_argument('--train_ps', type=int, default=128, help='patch size of training sample') 53 | args = parser.parse_args() 54 | 55 | 56 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 57 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus 58 | 59 | 60 | utils.mkdir(args.result_dir) 61 | 62 | test_dataset = get_validation_data(args.input_dir) 63 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=8, drop_last=False) 64 | 65 | model_restoration= utils.get_arch(args) 66 | 67 | 
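# (editor's annotation) this script evaluates whole images after padding
# rather than tiling: expand2square(..., factor=128) below pads to a square
# multiple of 128 so every downsampled feature map stays divisible by the
# 8x8 window size, and the returned mask crops the output back afterwards.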
utils.load_checkpoint(model_restoration,args.weights)
68 | print("===>Testing using weights: ", args.weights)
69 | 
70 | model_restoration.cuda()
71 | model_restoration.eval()
72 | 
73 | 
74 | def test_transform(v, op):
75 | v2np = v.data.cpu().numpy()
76 | if op == 'v':
77 | tfnp = v2np[:, :, :, ::-1].copy()
78 | elif op == 'h':
79 | tfnp = v2np[:, :, ::-1, :].copy()
80 | elif op == 't':
81 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
82 | 
83 | ret = torch.Tensor(tfnp).to(v.device)
84 | 
85 | return ret
86 | 
87 | def expand2square(timg,factor=16.0):
88 | _, _, h, w = timg.size()
89 | 
90 | X = int(math.ceil(max(h,w)/float(factor))*factor)
91 | 
92 | img = torch.zeros(1,3,X,X).type_as(timg) # 3, h,w
93 | mask = torch.zeros(1,1,X,X).type_as(timg)
94 | 
95 | img[:,:, ((X - h)//2):((X - h)//2 + h),((X - w)//2):((X - w)//2 + w)] = timg
96 | mask[:,:, ((X - h)//2):((X - h)//2 + h),((X - w)//2):((X - w)//2 + w)].fill_(1)
97 | 
98 | return img, mask
99 | 
100 | # Process data
101 | with torch.no_grad():
102 | psnr_val_rgb = []
103 | ssim_val_rgb = []
104 | for ii, data_test in enumerate(tqdm(test_loader), 0):
105 | rgb_gt = data_test[0].numpy().squeeze().transpose((1,2,0))
106 | rgb_noisy, mask = expand2square(data_test[1].cuda(), factor=128)
107 | filenames = data_test[2]
108 | 
109 | rgb_restored = model_restoration(rgb_noisy)
110 | rgb_restored = torch.masked_select(rgb_restored,mask.bool()).reshape(1,3,rgb_gt.shape[0],rgb_gt.shape[1])
111 | rgb_restored = torch.clamp(rgb_restored,0,1).cpu().numpy().squeeze().transpose((1,2,0))
112 | 
113 | psnr = psnr_loss(rgb_restored, rgb_gt)
114 | ssim = ssim_loss(rgb_restored, rgb_gt, channel_axis=2, data_range=1) # `multichannel` was removed in newer scikit-image; match the other test scripts
115 | psnr_val_rgb.append(psnr)
116 | ssim_val_rgb.append(ssim)
117 | print("PSNR:",psnr,", SSIM:", ssim, filenames[0], rgb_restored.shape)
118 | utils.save_img(os.path.join(args.result_dir,filenames[0]+'.PNG'), img_as_ubyte(rgb_restored))
119 | with open(os.path.join(args.result_dir,'psnr_ssim.txt'),'a') as f:
120 | f.write(filenames[0]+'.PNG ---->'+"PSNR: %.4f, SSIM: %.4f] "% (psnr, ssim)+'\n')
121 | psnr_val_rgb = sum(psnr_val_rgb)/len(test_dataset)
122 | ssim_val_rgb = sum(ssim_val_rgb)/len(test_dataset)
123 | print("PSNR: %f, SSIM: %f " %(psnr_val_rgb,ssim_val_rgb))
124 | with open(os.path.join(args.result_dir,'psnr_ssim.txt'),'a') as f:
125 | f.write("Arch:"+args.arch+", PSNR: %.4f, SSIM: %.4f] "% (psnr_val_rgb, ssim_val_rgb)+'\n')
--------------------------------------------------------------------------------
/train/train_dehaze.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | 
4 | # add dir
5 | dir_name = os.path.dirname(os.path.abspath(__file__))
6 | sys.path.append(os.path.join(dir_name,'../dataset/'))
7 | sys.path.append(os.path.join(dir_name,'..'))
8 | print(sys.path)
9 | print(dir_name)
10 | 
11 | import argparse
12 | import options
13 | ######### parser ###########
14 | opt = options.Options().init(argparse.ArgumentParser(description='Image dehazing')).parse_args()
15 | print(opt)
16 | 
17 | import utils
18 | from dataset.dataset_dehaze_denseHaze import *
19 | ######### Set GPUs ###########
20 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
21 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
22 | import torch
23 | torch.backends.cudnn.benchmark = True
24 | 
25 | import torch.nn as nn
26 | import torch.optim as optim
27 | from torch.utils.data import DataLoader
28 | from natsort import natsorted
29 | import glob
30 | import random
31 | import time
32 | import numpy as np
33 | from einops 
import rearrange, repeat 34 | import datetime 35 | from pdb import set_trace as stx 36 | 37 | from losses import CharbonnierLoss 38 | 39 | from tqdm import tqdm 40 | from warmup_scheduler import GradualWarmupScheduler 41 | from torch.optim.lr_scheduler import StepLR 42 | from timm.utils import NativeScaler 43 | 44 | ######### Logs dir ########### 45 | log_dir = os.path.join(opt.save_dir, 'dehazing', opt.dataset, opt.arch+opt.env) 46 | if not os.path.exists(log_dir): 47 | os.makedirs(log_dir) 48 | logname = os.path.join(log_dir, datetime.datetime.now().isoformat()+'.txt') 49 | print("Now time is : ",datetime.datetime.now().isoformat()) 50 | result_dir = os.path.join(log_dir, 'results') 51 | model_dir = os.path.join(log_dir, 'models') 52 | utils.mkdir(result_dir) 53 | utils.mkdir(model_dir) 54 | 55 | # ######### Set Seeds ########### 56 | random.seed(1234) 57 | np.random.seed(1234) 58 | torch.manual_seed(1234) 59 | torch.cuda.manual_seed_all(1234) 60 | 61 | ######### Model ########### 62 | model_restoration = utils.get_arch(opt) 63 | 64 | with open(logname,'a') as f: 65 | f.write(str(opt)+'\n') 66 | f.write(str(model_restoration)+'\n') 67 | 68 | ######### Optimizer ########### 69 | start_epoch = 1 70 | if opt.optimizer.lower() == 'adam': 71 | optimizer = optim.Adam(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay) 72 | elif opt.optimizer.lower() == 'adamw': 73 | optimizer = optim.AdamW(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay) 74 | else: 75 | raise Exception("Error optimizer...") 76 | 77 | 78 | ######### DataParallel ########### 79 | model_restoration = torch.nn.DataParallel (model_restoration) 80 | model_restoration.cuda() 81 | 82 | 83 | ######### Scheduler ########### 84 | if opt.warmup: 85 | print("Using warmup and cosine strategy!") 86 | warmup_epochs = opt.warmup_epochs 87 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.nepoch-warmup_epochs, eta_min=1e-6) 88 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 89 | scheduler.step() 90 | else: 91 | step = 50 92 | print("Using StepLR,step={}!".format(step)) 93 | scheduler = StepLR(optimizer, step_size=step, gamma=0.5) 94 | scheduler.step() 95 | 96 | ######### Resume ########### 97 | if opt.resume: 98 | path_chk_rest = opt.pretrain_weights 99 | print("Resume from "+path_chk_rest) 100 | utils.load_checkpoint(model_restoration,path_chk_rest) 101 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1 102 | lr = utils.load_optim(optimizer, path_chk_rest) 103 | 104 | # for p in optimizer.param_groups: p['lr'] = lr 105 | # warmup = False 106 | # new_lr = lr 107 | # print('------------------------------------------------------------------------------') 108 | # print("==> Resuming Training with learning rate:",new_lr) 109 | # print('------------------------------------------------------------------------------') 110 | for i in range(1, start_epoch): 111 | scheduler.step() 112 | new_lr = scheduler.get_lr()[0] 113 | print('------------------------------------------------------------------------------') 114 | print("==> Resuming Training with learning rate:", new_lr) 115 | print('------------------------------------------------------------------------------') 116 | 117 | # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.nepoch-start_epoch+1, eta_min=1e-6) 118 | 119 | ######### Retrain ########### 120 | if 
opt.retrain: 121 | path_chk_rest = opt.pretrain_weights 122 | print("Retrain from "+path_chk_rest) 123 | utils.load_checkpoint(model_restoration,path_chk_rest) 124 | 125 | print('------------------------------------------------------------------------------') 126 | print("==> Re Training with learning rate:", opt.lr_initial) 127 | print('------------------------------------------------------------------------------') 128 | 129 | ######### Loss ########### 130 | criterion = CharbonnierLoss().cuda() 131 | 132 | ######### DataLoader ########### 133 | print('===> Loading datasets') 134 | img_options_train = {'patch_size':opt.train_ps} 135 | img_options_val ={'patch_size':opt.val_ps} 136 | train_dataset = get_training_data(opt.train_dir, img_options_train) 137 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True, 138 | num_workers=opt.train_workers, pin_memory=False, drop_last=False) 139 | val_dataset = get_validation_data(opt.val_dir,img_options_val) 140 | val_loader = DataLoader(dataset=val_dataset, batch_size=opt.batch_size, shuffle=False, 141 | num_workers=opt.eval_workers, pin_memory=False, drop_last=False) 142 | 143 | len_trainset = train_dataset.__len__() 144 | len_valset = val_dataset.__len__() 145 | print("Sizeof training set: ", len_trainset,", sizeof validation set: ", len_valset) 146 | ######### validation ########### 147 | with torch.no_grad(): 148 | model_restoration.eval() 149 | psnr_dataset = [] 150 | psnr_model_init = [] 151 | for ii, data_val in enumerate((val_loader), 0): 152 | target = data_val[0].cuda() 153 | input_ = data_val[1].cuda() 154 | with torch.cuda.amp.autocast(): 155 | restored = model_restoration(input_) 156 | restored = torch.clamp(restored,0,1) 157 | psnr_dataset.append(utils.batch_PSNR(input_, target, False).item()) 158 | psnr_model_init.append(utils.batch_PSNR(restored, target, False).item()) 159 | psnr_dataset = sum(psnr_dataset)/len_valset 160 | psnr_model_init = sum(psnr_model_init)/len_valset 161 | print('Input & GT (PSNR) -->%.4f dB'%(psnr_dataset), ', Model_init & GT (PSNR) -->%.4f dB'%(psnr_model_init)) 162 | 163 | ######### train ########### 164 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.nepoch)) 165 | best_psnr = 0 166 | best_epoch = 0 167 | best_iter = 0 168 | eval_now = len(train_loader)//4 169 | print("\nEvaluation after every {} Iterations !!!\n".format(eval_now)) 170 | 171 | loss_scaler = NativeScaler() 172 | torch.cuda.empty_cache() 173 | for epoch in range(start_epoch, opt.nepoch + 1): 174 | epoch_start_time = time.time() 175 | epoch_loss = 0 176 | train_id = 1 177 | 178 | for i, data in enumerate(tqdm(train_loader), 0): 179 | # zero_grad 180 | optimizer.zero_grad() 181 | 182 | target = data[0].cuda() 183 | input_ = data[1].cuda() 184 | 185 | if epoch>5: 186 | target, input_ = utils.MixUp_AUG().aug(target, input_) 187 | with torch.cuda.amp.autocast(): 188 | restored = model_restoration(input_) 189 | loss = criterion(restored, target) 190 | loss_scaler( 191 | loss, optimizer,parameters=model_restoration.parameters()) 192 | epoch_loss +=loss.item() 193 | 194 | #### Evaluation #### 195 | if (i+1)%eval_now==0 and i>0: 196 | with torch.no_grad(): 197 | model_restoration.eval() 198 | psnr_val_rgb = [] 199 | for ii, data_val in enumerate((val_loader), 0): 200 | target = data_val[0].cuda() 201 | input_ = data_val[1].cuda() 202 | filenames = data_val[2] 203 | with torch.cuda.amp.autocast(): 204 | restored = model_restoration(input_) 205 | restored = torch.clamp(restored,0,1) 206 | 
psnr_val_rgb.append(utils.batch_PSNR(restored, target, False).item()) 207 | 208 | psnr_val_rgb = sum(psnr_val_rgb)/len_valset 209 | 210 | if psnr_val_rgb > best_psnr: 211 | best_psnr = psnr_val_rgb 212 | best_epoch = epoch 213 | best_iter = i 214 | torch.save({'epoch': epoch, 215 | 'state_dict': model_restoration.state_dict(), 216 | 'optimizer' : optimizer.state_dict() 217 | }, os.path.join(model_dir,"model_best.pth")) 218 | 219 | print("[Ep %d it %d\t PSNR SIDD: %.4f\t] ---- [best_Ep_SIDD %d best_it_SIDD %d Best_PSNR_SIDD %.4f] " % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr)) 220 | with open(logname,'a') as f: 221 | f.write("[Ep %d it %d\t PSNR SIDD: %.4f\t] ---- [best_Ep_SIDD %d best_it_SIDD %d Best_PSNR_SIDD %.4f] " \ 222 | % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr)+'\n') 223 | model_restoration.train() 224 | torch.cuda.empty_cache() 225 | scheduler.step() 226 | 227 | print("------------------------------------------------------------------") 228 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0])) 229 | print("------------------------------------------------------------------") 230 | with open(logname,'a') as f: 231 | f.write("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0])+'\n') 232 | 233 | torch.save({'epoch': epoch, 234 | 'state_dict': model_restoration.state_dict(), 235 | 'optimizer' : optimizer.state_dict() 236 | }, os.path.join(model_dir,"model_latest.pth")) 237 | 238 | if epoch%opt.checkpoint == 0: 239 | torch.save({'epoch': epoch, 240 | 'state_dict': model_restoration.state_dict(), 241 | 'optimizer' : optimizer.state_dict() 242 | }, os.path.join(model_dir,"model_epoch_{}.pth".format(epoch))) 243 | print("Now time is : ",datetime.datetime.now().isoformat()) 244 | -------------------------------------------------------------------------------- /train/train_derain.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | dir_name = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.join(dir_name,'../dataset/')) 6 | sys.path.append(os.path.join(dir_name,'..')) 7 | print(sys.path) 8 | print(dir_name) 9 | 10 | import argparse 11 | import options 12 | ######### parser ########### 13 | opt = options.Options().init(argparse.ArgumentParser(description='Image derain')).parse_args() 14 | print(opt) 15 | 16 | import utils 17 | from dataset.dataset_derain import * 18 | ######### Set GPUs ########### 19 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 20 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu 21 | import torch 22 | torch.backends.cudnn.benchmark = True 23 | 24 | import torch.nn as nn 25 | import torch.optim as optim 26 | from torch.utils.data import DataLoader 27 | from natsort import natsorted 28 | import glob 29 | import random 30 | import time 31 | import numpy as np 32 | from einops import rearrange, repeat 33 | import datetime 34 | from pdb import set_trace as stx 35 | 36 | from losses import CharbonnierLoss 37 | 38 | from tqdm import tqdm 39 | from warmup_scheduler import GradualWarmupScheduler 40 | from torch.optim.lr_scheduler import StepLR 41 | from timm.utils import NativeScaler 42 | 43 | 44 | ######### Logs dir ########### 45 | log_dir = os.path.join(opt.save_dir, 'derain', opt.dataset, opt.arch+opt.env) 46 | if not os.path.exists(log_dir): 47 | os.makedirs(log_dir) 48 | logname = 
os.path.join(log_dir, datetime.datetime.now().isoformat()+'.txt') 49 | print("Now time is : ",datetime.datetime.now().isoformat()) 50 | result_dir = os.path.join(log_dir, 'results') 51 | model_dir = os.path.join(log_dir, 'models') 52 | utils.mkdir(result_dir) 53 | utils.mkdir(model_dir) 54 | 55 | # ######### Set Seeds ########### 56 | random.seed(1234) 57 | np.random.seed(1234) 58 | torch.manual_seed(1234) 59 | torch.cuda.manual_seed_all(1234) 60 | 61 | ######### Model ########### 62 | model_restoration = utils.get_arch(opt) 63 | 64 | with open(logname,'a') as f: 65 | f.write(str(opt)+'\n') 66 | f.write(str(model_restoration)+'\n') 67 | 68 | ######### Optimizer ########### 69 | start_epoch = 1 70 | if opt.optimizer.lower() == 'adam': 71 | optimizer = optim.Adam(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay) 72 | elif opt.optimizer.lower() == 'adamw': 73 | optimizer = optim.AdamW(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay) 74 | else: 75 | raise Exception("Error optimizer...") 76 | 77 | 78 | ######### DataParallel ########### 79 | model_restoration = torch.nn.DataParallel (model_restoration) 80 | model_restoration.cuda() 81 | 82 | 83 | ######### Scheduler ########### 84 | if opt.warmup: 85 | print("Using warmup and cosine strategy!") 86 | warmup_epochs = opt.warmup_epochs 87 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.nepoch-warmup_epochs, eta_min=1e-6) 88 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 89 | scheduler.step() 90 | else: 91 | step = 50 92 | print("Using StepLR,step={}!".format(step)) 93 | scheduler = StepLR(optimizer, step_size=step, gamma=0.5) 94 | scheduler.step() 95 | 96 | ######### Resume ########### 97 | if opt.resume: 98 | path_chk_rest = opt.pretrain_weights 99 | print("Resume from "+path_chk_rest) 100 | utils.load_checkpoint(model_restoration,path_chk_rest) 101 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1 102 | lr = utils.load_optim(optimizer, path_chk_rest) 103 | 104 | for i in range(1, start_epoch): 105 | scheduler.step() 106 | new_lr = scheduler.get_lr()[0] 107 | print('------------------------------------------------------------------------------') 108 | print("==> Resuming Training with learning rate:", new_lr) 109 | print('------------------------------------------------------------------------------') 110 | 111 | ######### Retrain ########### 112 | if opt.retrain: 113 | path_chk_rest = opt.pretrain_weights 114 | print("Retrain from "+path_chk_rest) 115 | utils.load_checkpoint(model_restoration,path_chk_rest) 116 | 117 | print('------------------------------------------------------------------------------') 118 | print("==> Re Training with learning rate:", opt.lr_initial) 119 | print('------------------------------------------------------------------------------') 120 | 121 | ######### Loss ########### 122 | criterion = CharbonnierLoss().cuda() 123 | 124 | ######### DataLoader ########### 125 | print('===> Loading datasets') 126 | img_options_train = {'patch_size':opt.train_ps} 127 | train_dataset = get_training_data(opt.train_dir, img_options_train) 128 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True, 129 | num_workers=opt.train_workers, pin_memory=False, drop_last=False) 130 | val_dataset = get_validation_data(opt.val_dir) 131 | val_loader = DataLoader(dataset=val_dataset, 
batch_size=4, shuffle=False, 132 | num_workers=opt.eval_workers, pin_memory=False, drop_last=False) 133 | 134 | len_trainset = train_dataset.__len__() 135 | len_valset = val_dataset.__len__() 136 | print("Sizeof training set: ", len_trainset,", sizeof validation set: ", len_valset) 137 | ######### validation ########### 138 | 139 | with torch.no_grad(): 140 | model_restoration.eval() 141 | psnr_dataset = [] 142 | psnr_model_init = [] 143 | for ii, data_val in enumerate((val_loader), 0): 144 | target = data_val[0].cuda() 145 | input_ = data_val[1].cuda() 146 | with torch.cuda.amp.autocast(): 147 | restored = model_restoration(input_) 148 | restored = torch.clamp(restored,0,1) 149 | psnr_dataset.append(utils.batch_PSNR(input_, target, False, 'y').item()) 150 | psnr_model_init.append(utils.batch_PSNR(restored, target, False, 'y').item()) 151 | psnr_dataset = sum(psnr_dataset)/len_valset 152 | psnr_model_init = sum(psnr_model_init)/len_valset 153 | print('Input & GT (PSNR) -->%.4f dB'%(psnr_dataset), ', Model_init & GT (PSNR) -->%.4f dB'%(psnr_model_init)) 154 | 155 | ######### train ########### 156 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.nepoch)) 157 | best_psnr = 0 158 | best_epoch = 0 159 | best_iter = 0 160 | eval_now = len(train_loader)//4 161 | print("\nEvaluation after every {} Iterations !!!\n".format(eval_now)) 162 | 163 | loss_scaler = NativeScaler() 164 | torch.cuda.empty_cache() 165 | for epoch in range(start_epoch, opt.nepoch + 1): 166 | epoch_start_time = time.time() 167 | epoch_loss = 0 168 | train_id = 1 169 | 170 | for i, data in enumerate(tqdm(train_loader), 0): 171 | # zero_grad 172 | optimizer.zero_grad() 173 | 174 | target = data[0].cuda() 175 | input_ = data[1].cuda() 176 | 177 | if epoch>5: 178 | target, input_ = utils.MixUp_AUG().aug(target, input_) 179 | with torch.cuda.amp.autocast(): 180 | restored = model_restoration(input_) 181 | loss = criterion(restored, target) 182 | loss_scaler( 183 | loss, optimizer,parameters=model_restoration.parameters()) 184 | epoch_loss +=loss.item() 185 | 186 | #### Evaluation #### 187 | if (i+1)%eval_now==0 and i>0: 188 | with torch.no_grad(): 189 | model_restoration.eval() 190 | psnr_val_rgb = [] 191 | for ii, data_val in enumerate((val_loader), 0): 192 | target = data_val[0].cuda() 193 | input_ = data_val[1].cuda() 194 | filenames = data_val[2] 195 | with torch.cuda.amp.autocast(): 196 | restored = model_restoration(input_) 197 | restored = torch.clamp(restored,0,1) 198 | psnr_val_rgb.append(utils.batch_PSNR(restored, target, False, 'y').item()) 199 | 200 | psnr_val_rgb = sum(psnr_val_rgb)/len_valset 201 | 202 | if psnr_val_rgb > best_psnr: 203 | best_psnr = psnr_val_rgb 204 | best_epoch = epoch 205 | best_iter = i 206 | torch.save({'epoch': epoch, 207 | 'state_dict': model_restoration.state_dict(), 208 | 'optimizer' : optimizer.state_dict() 209 | }, os.path.join(model_dir,"model_best.pth")) 210 | 211 | print("[Ep %d it %d\t PSNR SPAD: %.4f\t] ---- [best_Ep_SPAD %d best_it_SPAD %d Best_PSNR_SPAD %.4f] " % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr)) 212 | with open(logname,'a') as f: 213 | f.write("[Ep %d it %d\t PSNR SPAD: %.4f\t] ---- [best_Ep_SPAD %d best_it_SPAD %d Best_PSNR_SPAD %.4f] " \ 214 | % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr)+'\n') 215 | model_restoration.train() 216 | torch.cuda.empty_cache() 217 | scheduler.step() 218 | 219 | print("------------------------------------------------------------------") 220 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate 
{:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0])) 221 | print("------------------------------------------------------------------") 222 | with open(logname,'a') as f: 223 | f.write("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0])+'\n') 224 | 225 | torch.save({'epoch': epoch, 226 | 'state_dict': model_restoration.state_dict(), 227 | 'optimizer' : optimizer.state_dict() 228 | }, os.path.join(model_dir,"model_latest.pth")) 229 | 230 | if epoch%opt.checkpoint == 0: 231 | torch.save({'epoch': epoch, 232 | 'state_dict': model_restoration.state_dict(), 233 | 'optimizer' : optimizer.state_dict() 234 | }, os.path.join(model_dir,"model_epoch_{}.pth".format(epoch))) 235 | print("Now time is : ",datetime.datetime.now().isoformat()) 236 | -------------------------------------------------------------------------------- /train/train_raindrop.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | 4 | dir_name = os.path.dirname(os.path.abspath(__file__)) 5 | sys.path.append(os.path.join(dir_name,'../dataset/')) 6 | sys.path.append(os.path.join(dir_name,'..')) 7 | print(sys.path) 8 | print(dir_name) 9 | 10 | import argparse 11 | import options 12 | ######### parser ########### 13 | opt = options.Options().init(argparse.ArgumentParser(description='Image derain')).parse_args() 14 | print(opt) 15 | 16 | import utils 17 | from dataset.dataset_derain_drop import * 18 | ######### Set GPUs ########### 19 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" 20 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu 21 | import torch 22 | torch.backends.cudnn.benchmark = True 23 | 24 | import torch.nn as nn 25 | import torch.optim as optim 26 | from torch.utils.data import DataLoader 27 | from natsort import natsorted 28 | import glob 29 | import random 30 | import time 31 | import numpy as np 32 | from einops import rearrange, repeat 33 | import datetime 34 | from pdb import set_trace as stx 35 | 36 | from losses import CharbonnierLoss 37 | 38 | from tqdm import tqdm 39 | from warmup_scheduler import GradualWarmupScheduler 40 | from torch.optim.lr_scheduler import StepLR 41 | from timm.utils import NativeScaler 42 | 43 | 44 | ######### Logs dir ########### 45 | log_dir = os.path.join(opt.save_dir, 'derain', opt.dataset, opt.arch+opt.env) 46 | if not os.path.exists(log_dir): 47 | os.makedirs(log_dir) 48 | logname = os.path.join(log_dir, datetime.datetime.now().isoformat()+'.txt') 49 | print("Now time is : ",datetime.datetime.now().isoformat()) 50 | result_dir = os.path.join(log_dir, 'results') 51 | model_dir = os.path.join(log_dir, 'models') 52 | utils.mkdir(result_dir) 53 | utils.mkdir(model_dir) 54 | 55 | # ######### Set Seeds ########### 56 | random.seed(1234) 57 | np.random.seed(1234) 58 | torch.manual_seed(1234) 59 | torch.cuda.manual_seed_all(1234) 60 | 61 | ######### Model ########### 62 | model_restoration = utils.get_arch(opt) 63 | 64 | with open(logname,'a') as f: 65 | f.write(str(opt)+'\n') 66 | f.write(str(model_restoration)+'\n') 67 | 68 | ######### Optimizer ########### 69 | start_epoch = 1 70 | if opt.optimizer.lower() == 'adam': 71 | optimizer = optim.Adam(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay) 72 | elif opt.optimizer.lower() == 'adamw': 73 | optimizer = optim.AdamW(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, 
weight_decay=opt.weight_decay) 74 | else: 75 | raise Exception("Error optimizer...") 76 | 77 | 78 | ######### DataParallel ########### 79 | # os.environ["CUDA_VISIBLE_DEVICES"] = '0,1' 80 | model_restoration = torch.nn.DataParallel (model_restoration) 81 | model_restoration.cuda() 82 | 83 | 84 | ######### Scheduler ########### 85 | if opt.warmup: 86 | print("Using warmup and cosine strategy!") 87 | warmup_epochs = opt.warmup_epochs 88 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.nepoch-warmup_epochs, eta_min=1e-6) 89 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine) 90 | scheduler.step() 91 | else: 92 | step = 50 93 | print("Using StepLR,step={}!".format(step)) 94 | scheduler = StepLR(optimizer, step_size=step, gamma=0.5) 95 | scheduler.step() 96 | 97 | ######### Resume ########### 98 | if opt.resume: 99 | path_chk_rest = opt.pretrain_weights 100 | print("Resume from "+path_chk_rest) 101 | utils.load_checkpoint(model_restoration,path_chk_rest) 102 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1 103 | lr = utils.load_optim(optimizer, path_chk_rest) 104 | 105 | for i in range(1, start_epoch): 106 | scheduler.step() 107 | new_lr = scheduler.get_lr()[0] 108 | print('------------------------------------------------------------------------------') 109 | print("==> Resuming Training with learning rate:", new_lr) 110 | print('------------------------------------------------------------------------------') 111 | 112 | ######### Retrain ########### 113 | if opt.retrain: 114 | path_chk_rest = opt.pretrain_weights 115 | print("Retrain from "+path_chk_rest) 116 | utils.load_checkpoint(model_restoration,path_chk_rest) 117 | 118 | print('------------------------------------------------------------------------------') 119 | print("==> Re Training with learning rate:", opt.lr_initial) 120 | print('------------------------------------------------------------------------------') 121 | 122 | ######### Loss ########### 123 | criterion = CharbonnierLoss().cuda() 124 | 125 | ######### DataLoader ########### 126 | print('===> Loading datasets') 127 | img_options_train = {'patch_size':opt.train_ps} 128 | train_dataset = get_training_data(opt.train_dir, img_options_train) 129 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True, 130 | num_workers=opt.train_workers, pin_memory=False, drop_last=False) 131 | img_options_val = {'patch_size':opt.val_ps} 132 | val_dataset = get_validation_data(opt.val_dir,img_options_val) 133 | val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False, 134 | num_workers=opt.eval_workers, pin_memory=False, drop_last=False) 135 | 136 | len_trainset = train_dataset.__len__() 137 | len_valset = val_dataset.__len__() 138 | print("Sizeof training set: ", len_trainset,", sizeof validation set: ", len_valset) 139 | ######### validation ########### 140 | 141 | with torch.no_grad(): 142 | model_restoration.eval() 143 | psnr_dataset = [] 144 | psnr_model_init = [] 145 | for ii, data_val in enumerate((val_loader), 0): 146 | target = data_val[0].cuda() 147 | input_ = data_val[1].cuda() 148 | with torch.cuda.amp.autocast(): 149 | restored = model_restoration(input_) 150 | restored = torch.clamp(restored,0,1) 151 | psnr_dataset.append(utils.batch_PSNR(input_, target, False, 'y').item()) 152 | psnr_model_init.append(utils.batch_PSNR(restored, target, False, 'y').item()) 153 | psnr_dataset = sum(psnr_dataset)/len_valset 154 | psnr_model_init = 
sum(psnr_model_init)/len_valset 155 | print('Input & GT (PSNR) -->%.4f dB'%(psnr_dataset), ', Model_init & GT (PSNR) -->%.4f dB'%(psnr_model_init)) 156 | 157 | ######### train ########### 158 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.nepoch)) 159 | best_psnr = 0 160 | best_epoch = 0 161 | best_iter = 0 162 | eval_now = len(train_loader)//4 163 | print("\nEvaluation after every {} Iterations !!!\n".format(eval_now)) 164 | 165 | loss_scaler = NativeScaler() 166 | torch.cuda.empty_cache() 167 | for epoch in range(start_epoch, opt.nepoch + 1): 168 | epoch_start_time = time.time() 169 | epoch_loss = 0 170 | train_id = 1 171 | 172 | for i, data in enumerate(tqdm(train_loader), 0): 173 | # zero_grad 174 | optimizer.zero_grad() 175 | 176 | target = data[0].cuda() 177 | input_ = data[1].cuda() 178 | 179 | if epoch>5: 180 | target, input_ = utils.MixUp_AUG().aug(target, input_) 181 | with torch.cuda.amp.autocast(): 182 | restored = model_restoration(input_) 183 | loss = criterion(restored, target) 184 | loss_scaler( 185 | loss, optimizer,parameters=model_restoration.parameters()) 186 | epoch_loss +=loss.item() 187 | 188 | #### Evaluation #### 189 | if (i+1)%eval_now==0 and i>0: 190 | with torch.no_grad(): 191 | model_restoration.eval() 192 | psnr_val_rgb = [] 193 | for ii, data_val in enumerate((val_loader), 0): 194 | target = data_val[0].cuda() 195 | input_ = data_val[1].cuda() 196 | filenames = data_val[2] 197 | with torch.cuda.amp.autocast(): 198 | restored = model_restoration(input_) 199 | restored = torch.clamp(restored,0,1) 200 | psnr_val_rgb.append(utils.batch_PSNR(restored, target, False, 'y').item()) 201 | 202 | psnr_val_rgb = sum(psnr_val_rgb)/len_valset 203 | 204 | if psnr_val_rgb > best_psnr: 205 | best_psnr = psnr_val_rgb 206 | best_epoch = epoch 207 | best_iter = i 208 | torch.save({'epoch': epoch, 209 | 'state_dict': model_restoration.state_dict(), 210 | 'optimizer' : optimizer.state_dict() 211 | }, os.path.join(model_dir,"model_best.pth")) 212 | 213 | print("[Ep %d it %d\t PSNR SPAD: %.4f\t] ---- [best_Ep_SPAD %d best_it_SPAD %d Best_PSNR_SPAD %.4f] " % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr)) 214 | with open(logname,'a') as f: 215 | f.write("[Ep %d it %d\t PSNR SPAD: %.4f\t] ---- [best_Ep_SPAD %d best_it_SPAD %d Best_PSNR_SPAD %.4f] " \ 216 | % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr)+'\n') 217 | model_restoration.train() 218 | torch.cuda.empty_cache() 219 | scheduler.step() 220 | 221 | print("------------------------------------------------------------------") 222 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0])) 223 | print("------------------------------------------------------------------") 224 | with open(logname,'a') as f: 225 | f.write("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0])+'\n') 226 | 227 | torch.save({'epoch': epoch, 228 | 'state_dict': model_restoration.state_dict(), 229 | 'optimizer' : optimizer.state_dict() 230 | }, os.path.join(model_dir,"model_latest.pth")) 231 | 232 | if epoch%opt.checkpoint == 0: 233 | torch.save({'epoch': epoch, 234 | 'state_dict': model_restoration.state_dict(), 235 | 'optimizer' : optimizer.state_dict() 236 | }, os.path.join(model_dir,"model_epoch_{}.pth".format(epoch))) 237 | print("Now time is : ",datetime.datetime.now().isoformat()) 238 | 
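All three training loops hide the backward pass inside timm's NativeScaler: the call `loss_scaler(loss, optimizer, parameters=model_restoration.parameters())` scales the loss, runs backward, then steps the optimizer through the underlying GradScaler. A rough, self-contained sketch of what one AMP training step expands to — illustrative only, not the scripts' exact code:

import torch

scaler = torch.cuda.amp.GradScaler()

def training_step(model, criterion, optimizer, input_, target):
    optimizer.zero_grad()
    with torch.cuda.amp.autocast():
        restored = model(input_)
        loss = criterion(restored, target)
    scaler.scale(loss).backward()  # backward on the scaled loss
    scaler.step(optimizer)         # unscales gradients, then optimizer.step()
    scaler.update()                # adjusts the loss scale for the next step
    return loss.item()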
-------------------------------------------------------------------------------- /utils/.DS_Store: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/.DS_Store -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .dir_utils import * 2 | from .dataset_utils import * 3 | from .image_utils import * 4 | from .model_utils import * 5 | from .caculate_psnr_ssim import * -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/caculate_psnr_ssim.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/caculate_psnr_ssim.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/caculate_psnr_ssim.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/caculate_psnr_ssim.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataset_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/dataset_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dataset_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/dataset_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dir_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/dir_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/dir_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/dir_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/image_utils.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/image_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/image_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/image_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model_utils.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/model_utils.cpython-37.pyc -------------------------------------------------------------------------------- /utils/__pycache__/model_utils.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/model_utils.cpython-38.pyc -------------------------------------------------------------------------------- /utils/antialias.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.parallel 3 | import numpy as np 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | class Downsample(nn.Module): 8 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0): 9 | super(Downsample, self).__init__() 10 | self.filt_size = filt_size 11 | self.pad_off = pad_off 12 | self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))] 13 | self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes] 14 | self.stride = stride 15 | self.off = int((self.stride-1)/2.) 
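        # the filt_size branch below picks a row of Pascal's triangle; its
        # normalized outer product is the 2-D binomial blur applied before
        # strided subsampling, which is what suppresses aliasing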
16 |         self.channels = channels
17 | 
18 |         # print('Filter size [%i]'%filt_size)
19 |         if(self.filt_size==1):
20 |             a = np.array([1.,])
21 |         elif(self.filt_size==2):
22 |             a = np.array([1., 1.])
23 |         elif(self.filt_size==3):
24 |             a = np.array([1., 2., 1.])
25 |         elif(self.filt_size==4):
26 |             a = np.array([1., 3., 3., 1.])
27 |         elif(self.filt_size==5):
28 |             a = np.array([1., 4., 6., 4., 1.])
29 |         elif(self.filt_size==6):
30 |             a = np.array([1., 5., 10., 10., 5., 1.])
31 |         elif(self.filt_size==7):
32 |             a = np.array([1., 6., 15., 20., 15., 6., 1.])
33 | 
34 |         filt = torch.Tensor(a[:,None]*a[None,:])
35 |         filt = filt/torch.sum(filt)
36 |         self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1)))
37 | 
38 |         self.pad = get_pad_layer(pad_type)(self.pad_sizes)
39 | 
40 |     def forward(self, inp):
41 |         if(self.filt_size==1):
42 |             if(self.pad_off==0):
43 |                 return inp[:,:,::self.stride,::self.stride]
44 |             else:
45 |                 return self.pad(inp)[:,:,::self.stride,::self.stride]
46 |         else:
47 |             return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
48 | 
49 | def get_pad_layer(pad_type):
50 |     if(pad_type in ['refl','reflect']):
51 |         PadLayer = nn.ReflectionPad2d
52 |     elif(pad_type in ['repl','replicate']):
53 |         PadLayer = nn.ReplicationPad2d
54 |     elif(pad_type=='zero'):
55 |         PadLayer = nn.ZeroPad2d
56 |     else:
57 |         raise ValueError('Pad type [%s] not recognized'%pad_type)
58 |     return PadLayer
59 | 
60 | 
61 | class Downsample1D(nn.Module):
62 |     def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
63 |         super(Downsample1D, self).__init__()
64 |         self.filt_size = filt_size
65 |         self.pad_off = pad_off
66 |         self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
67 |         self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
68 |         self.stride = stride
69 |         self.off = int((self.stride - 1) / 2.)
70 |         self.channels = channels
71 | 
72 |         # print('Filter size [%i]' % filt_size)
73 |         if(self.filt_size == 1):
74 |             a = np.array([1., ])
75 |         elif(self.filt_size == 2):
76 |             a = np.array([1., 1.])
77 |         elif(self.filt_size == 3):
78 |             a = np.array([1., 2., 1.])
79 |         elif(self.filt_size == 4):
80 |             a = np.array([1., 3., 3., 1.])
81 |         elif(self.filt_size == 5):
82 |             a = np.array([1., 4., 6., 4., 1.])
83 |         elif(self.filt_size == 6):
84 |             a = np.array([1., 5., 10., 10., 5., 1.])
85 |         elif(self.filt_size == 7):
86 |             a = np.array([1., 6., 15., 20., 15., 6., 1.])
87 | 
88 |         filt = torch.Tensor(a)
89 |         filt = filt / torch.sum(filt)
90 |         self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))
91 | 
92 |         self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
93 | 
94 |     def forward(self, inp):
95 |         if(self.filt_size == 1):
96 |             if(self.pad_off == 0):
97 |                 return inp[:, :, ::self.stride]
98 |             else:
99 |                 return self.pad(inp)[:, :, ::self.stride]
100 |         else:
101 |             return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
102 | 
103 | 
104 | def get_pad_layer_1d(pad_type):
105 |     if(pad_type in ['refl', 'reflect']):
106 |         PadLayer = nn.ReflectionPad1d
107 |     elif(pad_type in ['repl', 'replicate']):
108 |         PadLayer = nn.ReplicationPad1d
109 |     elif(pad_type == 'zero'):
110 |         PadLayer = nn.ZeroPad1d
111 |     else:
112 |         raise ValueError('Pad type [%s] not recognized' % pad_type)
113 |     return PadLayer
--------------------------------------------------------------------------------
/utils/bundle_submissions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.io as sio
3 | import os
4 | import h5py
5 | 
6 | def bundle_submissions_raw(submission_folder,session):
7 |     '''
8 |     Bundles submission data for raw denoising
9 |     submission_folder Folder where denoised images reside
10 |     Output is written to /bundled/. Please submit
11 |     the content of this folder.
12 |     '''
13 | 
14 |     out_folder = os.path.join(submission_folder, session)
15 |     # out_folder = os.path.join(submission_folder, "bundled/")
16 |     try:
17 |         os.mkdir(out_folder)
18 |     except OSError: pass
19 | 
20 |     israw = True
21 |     eval_version="1.0"
22 | 
23 |     for i in range(50):
24 |         Idenoised = np.zeros((20,), dtype=object)
25 |         for bb in range(20):
26 |             filename = '%04d_%02d.mat'%(i+1,bb+1)
27 |             s = sio.loadmat(os.path.join(submission_folder,filename))
28 |             Idenoised_crop = s["Idenoised_crop"]
29 |             Idenoised[bb] = Idenoised_crop
30 |         filename = '%04d.mat'%(i+1)
31 |         sio.savemat(os.path.join(out_folder, filename),
32 |                     {"Idenoised": Idenoised,
33 |                      "israw": israw,
34 |                      "eval_version": eval_version},
35 |                     )
36 | 
37 | def bundle_submissions_srgb(submission_folder,session):
38 |     '''
39 |     Bundles submission data for sRGB denoising
40 | 
41 |     submission_folder Folder where denoised images reside
42 |     Output is written to /bundled/. Please submit
43 |     the content of this folder.
44 |     '''
45 |     out_folder = os.path.join(submission_folder, session)
46 |     # out_folder = os.path.join(submission_folder, "bundled/")
47 |     try:
48 |         os.mkdir(out_folder)
49 |     except OSError: pass
50 |     israw = False
51 |     eval_version="1.0"
52 | 
53 |     for i in range(50):
54 |         Idenoised = np.zeros((20,), dtype=object)
55 |         for bb in range(20):
56 |             filename = '%04d_%02d.mat'%(i+1,bb+1)
57 |             s = sio.loadmat(os.path.join(submission_folder,filename))
58 |             Idenoised_crop = s["Idenoised_crop"]
59 |             Idenoised[bb] = Idenoised_crop
60 |         filename = '%04d.mat'%(i+1)
61 |         sio.savemat(os.path.join(out_folder, filename),
62 |                     {"Idenoised": Idenoised,
63 |                      "israw": israw,
64 |                      "eval_version": eval_version},
65 |                     )
66 | 
67 | 
68 | 
69 | def bundle_submissions_srgb_v1(submission_folder,session):
70 |     '''
71 |     Bundles submission data for sRGB denoising
72 | 
73 |     submission_folder Folder where denoised images reside
74 |     Output is written to /bundled/. Please submit
75 |     the content of this folder.
76 |     '''
77 |     out_folder = os.path.join(submission_folder, session)
78 |     # out_folder = os.path.join(submission_folder, "bundled/")
79 |     try:
80 |         os.mkdir(out_folder)
81 |     except OSError: pass
82 |     israw = False
83 |     eval_version="1.0"
84 | 
85 |     for i in range(50):
86 |         Idenoised = np.zeros((20,), dtype=object)
87 |         for bb in range(20):
88 |             filename = '%04d_%d.mat'%(i+1,bb+1)
89 |             s = sio.loadmat(os.path.join(submission_folder,filename))
90 |             Idenoised_crop = s["Idenoised_crop"]
91 |             Idenoised[bb] = Idenoised_crop
92 |         filename = '%04d.mat'%(i+1)
93 |         sio.savemat(os.path.join(out_folder, filename),
94 |                     {"Idenoised": Idenoised,
95 |                      "israw": israw,
96 |                      "eval_version": eval_version},
97 |                     )
--------------------------------------------------------------------------------
/utils/caculate_psnr_ssim.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 | 
5 | # convert 2/3/4-dimensional torch tensor to uint8
6 | def tensor2uint(img):
7 |     img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
8 |     if img.ndim == 3:
9 |         img = np.transpose(img, (1, 2, 0))
10 |     return np.uint8((img*255.0).round())
11 | 
12 | def calculate_psnr(img1, img2, crop_border=0, input_order='HWC', test_y_channel=False):
13 |     assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
14 |     if input_order not in ['HWC', 'CHW']:
15 |         raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
16 |     img1 = reorder_image(img1, input_order=input_order)
17 |     img2 = reorder_image(img2, input_order=input_order)
18 |     img1 = img1.astype(np.float64)
19 |     img2 = img2.astype(np.float64)
20 | 
21 |     if crop_border != 0:
22 |         img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
23 |         img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
24 | 
25 |     if test_y_channel:
26 |         img1 = to_y_channel(img1)
27 |         img2 = to_y_channel(img2)
28 | 
29 |     mse = np.mean((img1 - img2) ** 2)
30 |     if mse == 0:
31 |         return float('inf')
32 |     return 20. * np.log10(255. / np.sqrt(mse))
33 | 
34 | 
35 | def _ssim(img1, img2):
36 |     C1 = (0.01 * 255) ** 2
37 |     C2 = (0.03 * 255) ** 2
38 | 
39 |     img1 = img1.astype(np.float64)
40 |     img2 = img2.astype(np.float64)
41 |     kernel = cv2.getGaussianKernel(11, 1.5)
42 |     window = np.outer(kernel, kernel.transpose())
43 | 
44 |     mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
45 |     mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
46 |     mu1_sq = mu1 ** 2
47 |     mu2_sq = mu2 ** 2
48 |     mu1_mu2 = mu1 * mu2
49 |     sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
50 |     sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
51 |     sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
52 | 
53 |     ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
54 |     return ssim_map.mean()
55 | 
56 | 
57 | def calculate_ssim(img1, img2, crop_border=0, input_order='HWC', test_y_channel=False):
58 |     assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
59 |     if img1.dtype != np.uint8:
60 |         img1 = (img1 * 255.0).round().astype(np.uint8)  # float32 to uint8
61 |     if img2.dtype != np.uint8:
62 |         img2 = (img2 * 255.0).round().astype(np.uint8)  # float32 to uint8
63 |     if input_order not in ['HWC', 'CHW']:
64 |         raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
65 |     img1 = reorder_image(img1, input_order=input_order)
66 |     img2 = reorder_image(img2, input_order=input_order)
67 |     img1 = img1.astype(np.float64)
68 |     img2 = img2.astype(np.float64)
69 | 
70 |     if crop_border != 0:
71 |         img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
72 |         img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
73 | 
74 |     if test_y_channel:
75 |         img1 = to_y_channel(img1)
76 |         img2 = to_y_channel(img2)
77 | 
78 |     ssims = []
79 |     for i in range(img1.shape[2]):
80 |         ssims.append(_ssim(img1[..., i], img2[..., i]))
81 |     return np.array(ssims).mean()
82 | 
83 | 
84 | def _blocking_effect_factor(im):
85 |     block_size = 8
86 | 
87 |     block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8)
88 |     block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8)
89 | 
90 |     horizontal_block_difference = (
91 |         (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum(
92 |         3).sum(2).sum(1)
93 |     vertical_block_difference = (
94 |         (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum(
95 |         2).sum(1)
96 | 
97 |     nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions)
98 |     nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions)
99 | 
100 |     horizontal_nonblock_difference = (
101 |         (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum(
102 |         3).sum(2).sum(1)
103 |     vertical_nonblock_difference = (
104 |         (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum(
105 |         3).sum(2).sum(1)
106 | 
107 |     n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1)
108 |     n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1)
109 |     boundary_difference = (horizontal_block_difference + vertical_block_difference) / (
110 |         n_boundary_horiz + n_boundary_vert)
111 | 
112 |     n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz
113 |     n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert
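    # compare the mean squared jump across 8x8 block boundaries with the mean
    # squared jump between all other neighbouring pixels; a larger boundary
    # difference signals visible blockiness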
114 |     nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / (
115 |         n_nonboundary_horiz + n_nonboundary_vert)
116 | 
117 |     scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]]))
118 |     bef = scaler * (boundary_difference - nonboundary_difference)
119 | 
120 |     bef[boundary_difference <= nonboundary_difference] = 0
121 |     return bef
122 | 
123 | 
124 | def calculate_psnrb(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
125 |     assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
126 |     if input_order not in ['HWC', 'CHW']:
127 |         raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
128 |     img1 = reorder_image(img1, input_order=input_order)
129 |     img2 = reorder_image(img2, input_order=input_order)
130 |     img1 = img1.astype(np.float64)
131 |     img2 = img2.astype(np.float64)
132 | 
133 |     if crop_border != 0:
134 |         img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
135 |         img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
136 | 
137 |     if test_y_channel:
138 |         img1 = to_y_channel(img1)
139 |         img2 = to_y_channel(img2)
140 | 
141 |     img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255.
142 |     img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255.
143 | 
144 |     total = 0
145 |     for c in range(img1.shape[1]):
146 |         mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none')
147 |         bef = _blocking_effect_factor(img1[:, c:c + 1, :, :])
148 | 
149 |         mse = mse.view(mse.shape[0], -1).mean(1)
150 |         total += 10 * torch.log10(1 / (mse + bef))
151 | 
152 |     return float(total) / img1.shape[1]
153 | 
154 | 
155 | def reorder_image(img, input_order='HWC'):
156 |     if input_order not in ['HWC', 'CHW']:
157 |         raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
158 |     if len(img.shape) == 2:
159 |         img = img[..., None]
160 |     if input_order == 'CHW':
161 |         img = img.transpose(1, 2, 0)
162 |     return img
163 | 
164 | 
165 | def to_y_channel(img):
166 |     img = img.astype(np.float32) / 255.
167 |     if img.ndim == 3 and img.shape[2] == 3:
168 |         img = rgb2ycbcr(img, y_only=True)
169 |         img = img[..., None]
170 |     else:
171 |         raise ValueError(f'Wrong image shape: expected 3 channels, got {img.shape}.')
172 |     return img * 255.
173 | 
174 | 
175 | def _convert_input_type_range(img):
176 |     img_type = img.dtype
177 |     img = img.astype(np.float32)
178 |     if img_type == np.float32:
179 |         pass
180 |     elif img_type == np.uint8:
181 |         img /= 255.
182 |     else:
183 |         raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
184 |     return img
185 | 
186 | 
187 | def _convert_output_type_range(img, dst_type):
188 |     if dst_type not in (np.uint8, np.float32):
189 |         raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
190 |     if dst_type == np.uint8:
191 |         img = img.round()
192 |     else:
193 |         img /= 255.
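    # uint8 outputs stay in [0, 255]; float32 outputs are rescaled to [0, 1]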
194 | return img.astype(dst_type) 195 | 196 | 197 | def rgb2ycbcr(img, y_only=False): 198 | img_type = img.dtype 199 | img = _convert_input_type_range(img) 200 | if y_only: 201 | out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0 202 | else: 203 | out_img = np.matmul( 204 | img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],[24.966, 112.0, -18.214]]) + [16, 128, 128] 205 | out_img = _convert_output_type_range(out_img, img_type) 206 | return out_img -------------------------------------------------------------------------------- /utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | 4 | ### rotate and flip 5 | class Augment_RGB_torch: 6 | def __init__(self): 7 | pass 8 | def transform0(self, torch_tensor): 9 | return torch_tensor 10 | def transform1(self, torch_tensor): 11 | torch_tensor = torch.rot90(torch_tensor, k=1, dims=[-1,-2]) 12 | return torch_tensor 13 | def transform2(self, torch_tensor): 14 | torch_tensor = torch.rot90(torch_tensor, k=2, dims=[-1,-2]) 15 | return torch_tensor 16 | def transform3(self, torch_tensor): 17 | torch_tensor = torch.rot90(torch_tensor, k=3, dims=[-1,-2]) 18 | return torch_tensor 19 | def transform4(self, torch_tensor): 20 | torch_tensor = torch_tensor.flip(-2) 21 | return torch_tensor 22 | def transform5(self, torch_tensor): 23 | torch_tensor = (torch.rot90(torch_tensor, k=1, dims=[-1,-2])).flip(-2) 24 | return torch_tensor 25 | def transform6(self, torch_tensor): 26 | torch_tensor = (torch.rot90(torch_tensor, k=2, dims=[-1,-2])).flip(-2) 27 | return torch_tensor 28 | def transform7(self, torch_tensor): 29 | torch_tensor = (torch.rot90(torch_tensor, k=3, dims=[-1,-2])).flip(-2) 30 | return torch_tensor 31 | 32 | 33 | ### mix two images 34 | class MixUp_AUG: 35 | def __init__(self): 36 | self.dist = torch.distributions.beta.Beta(torch.tensor([1.2]), torch.tensor([1.2])) 37 | 38 | def aug(self, rgb_gt, rgb_noisy): 39 | bs = rgb_gt.size(0) 40 | indices = torch.randperm(bs) 41 | rgb_gt2 = rgb_gt[indices] 42 | rgb_noisy2 = rgb_noisy[indices] 43 | 44 | lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda() 45 | 46 | rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2 47 | rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2 48 | 49 | return rgb_gt, rgb_noisy 50 | -------------------------------------------------------------------------------- /utils/dir_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from natsort import natsorted 3 | from glob import glob 4 | 5 | def mkdirs(paths): 6 | if isinstance(paths, list) and not isinstance(paths, str): 7 | for path in paths: 8 | mkdir(path) 9 | else: 10 | mkdir(paths) 11 | 12 | def mkdir(path): 13 | if not os.path.exists(path): 14 | os.makedirs(path) 15 | 16 | def get_last_path(path, session): 17 | x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1] 18 | return x -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | import pickle 4 | import cv2 5 | import math 6 | import PIL 7 | from PIL import Image 8 | import torchvision.transforms.functional as F 9 | import torch.nn.functional as f 10 | 11 | def is_numpy_file(filename): 12 | return any(filename.endswith(extension) for extension in [".npy"]) 13 | 14 | def is_image_file(filename): 15 | return any(filename.endswith(extension) for extension in 
[".jpg"]) 16 | 17 | def is_png_file(filename): 18 | return any(filename.endswith(extension) for extension in [".png"]) 19 | 20 | def is_pkl_file(filename): 21 | return any(filename.endswith(extension) for extension in [".pkl"]) 22 | 23 | def load_pkl(filename_): 24 | with open(filename_, 'rb') as f: 25 | ret_dict = pickle.load(f) 26 | return ret_dict 27 | 28 | def save_dict(dict_, filename_): 29 | with open(filename_, 'wb') as f: 30 | pickle.dump(dict_, f) 31 | 32 | def load_npy(file_path): 33 | img = np.load(file_path) 34 | img = img.astype(np.float32) 35 | img = img/255. 36 | img = img[:, :, [2, 1, 0]] 37 | return img 38 | 39 | def load_img(filepath): 40 | img = cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB) 41 | img = img.astype(np.float32) 42 | img = img/255. 43 | return img 44 | 45 | 46 | def load_reflectPad(filepath): 47 | img = Image.open(filepath).convert('RGB') 48 | factor = 128 49 | h,w = img.size 50 | H, W = ((h + factor) // factor) * factor, ((w + factor) // factor * factor) 51 | padh = H - h if h % factor != 0 else 0 52 | padw = W - w if w % factor != 0 else 0 53 | img = f.pad(img, (0, padw, 0, padh), 'reflect') 54 | img = np.float32(img) 55 | img = img / 255. 56 | return img 57 | 58 | 59 | def load_resize(filepath): 60 | img = Image.open(filepath).convert('RGB') 61 | wd_new, ht_new = img.size 62 | if ht_new > wd_new and ht_new > 1024: 63 | wd_new = int(np.ceil(wd_new * 1024 / ht_new)) 64 | ht_new = 1024 65 | elif ht_new <= wd_new and wd_new > 1024: 66 | ht_new = int(np.ceil(ht_new * 1024 / wd_new)) 67 | wd_new = 1024 68 | wd_new = int(128 * np.ceil(wd_new / 128.0)) 69 | ht_new = int(128 * np.ceil(ht_new / 128.0)) 70 | target_edge = wd_new if wd_new >= ht_new else ht_new 71 | img = img.resize((target_edge, target_edge), PIL.Image.ANTIALIAS) 72 | 73 | img = np.float32(img) 74 | img = img / 255. 75 | return img 76 | 77 | 78 | def loader4dehaze(filepath,ps=256): 79 | img = Image.open(filepath).convert('RGB') 80 | w, h = img.size 81 | if w < ps : 82 | padW = 1+(ps-w)//2 83 | img = F.pad(img, (padW,0,padW,0), 0, 'constant') 84 | if h < ps : 85 | padH = 1+(ps-h)//2 86 | img = F.pad(img, (0,padH,0,padH), 0, 'constant') 87 | img = np.float32(img) 88 | img = img / 255. 
89 | return img 90 | 91 | 92 | def save_img(filepath, img): 93 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR)) 94 | 95 | 96 | def myPSNR(tar_img, prd_img, cal_type): 97 | imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1) 98 | 99 | if cal_type == 'y': 100 | gray_coeffs = [65.738, 129.057, 25.064] 101 | convert = imdff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256 102 | imdff = imdff.mul(convert).sum(dim=1) 103 | 104 | rmse = (imdff**2).mean().sqrt() 105 | ps = 20*torch.log10(1/rmse) 106 | return ps 107 | 108 | def batch_PSNR(img1, img2, average=True, cal_type='N'): 109 | PSNR = [] 110 | for im1, im2 in zip(img1, img2): 111 | psnr = myPSNR(im1, im2, cal_type) 112 | PSNR.append(psnr) 113 | return sum(PSNR)/len(PSNR) if average else sum(PSNR) 114 | 115 | def splitimage(imgtensor, crop_size=128, overlap_size=64): 116 | _, C, H, W = imgtensor.shape 117 | hstarts = [x for x in range(0, H, crop_size - overlap_size)] 118 | while hstarts and hstarts[-1] + crop_size >= H: 119 | hstarts.pop() 120 | hstarts.append(H - crop_size) 121 | wstarts = [x for x in range(0, W, crop_size - overlap_size)] 122 | while wstarts and wstarts[-1] + crop_size >= W: 123 | wstarts.pop() 124 | wstarts.append(W - crop_size) 125 | starts = [] 126 | split_data = [] 127 | for hs in hstarts: 128 | for ws in wstarts: 129 | cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size] 130 | starts.append((hs, ws)) 131 | split_data.append(cimgdata) 132 | return split_data, starts 133 | 134 | def get_scoremap(H, W, C, B=1, is_mean=True): 135 | center_h = H / 2 136 | center_w = W / 2 137 | 138 | score = torch.ones((B, C, H, W)) 139 | if not is_mean: 140 | for h in range(H): 141 | for w in range(W): 142 | score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6)) 143 | return score 144 | 145 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)): 146 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3] 147 | tot_score = torch.zeros((B, C, H, W)) 148 | merge_img = torch.zeros((B, C, H, W)) 149 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True) 150 | for simg, cstart in zip(split_data, starts): 151 | hs, ws = cstart 152 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg 153 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap 154 | merge_img = merge_img / tot_score 155 | return merge_img -------------------------------------------------------------------------------- /utils/loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | 3 | from dataset import DataLoaderTrain, DataLoaderVal, DataLoaderTest, DataLoaderTestSR 4 | def get_training_data(rgb_dir, img_options): 5 | assert os.path.exists(rgb_dir) 6 | return DataLoaderTrain(rgb_dir, img_options, None) 7 | 8 | def get_validation_data(rgb_dir): 9 | assert os.path.exists(rgb_dir) 10 | return DataLoaderVal(rgb_dir, None) 11 | 12 | 13 | def get_test_data(rgb_dir): 14 | assert os.path.exists(rgb_dir) 15 | return DataLoaderTest(rgb_dir, None) 16 | 17 | 18 | def get_test_data_SR(rgb_dir): 19 | assert os.path.exists(rgb_dir) 20 | return DataLoaderTestSR(rgb_dir, None) -------------------------------------------------------------------------------- /utils/model_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import os 4 | from collections import OrderedDict 5 | 6 | def freeze(model): 7 | for p 
in model.parameters(): 8 | p.requires_grad=False 9 | 10 | def unfreeze(model): 11 | for p in model.parameters(): 12 | p.requires_grad=True 13 | 14 | def is_frozen(model): 15 | x = [p.requires_grad for p in model.parameters()] 16 | return not all(x) 17 | 18 | def save_checkpoint(model_dir, state, session): 19 | epoch = state['epoch'] 20 | model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session)) 21 | torch.save(state, model_out_path) 22 | 23 | def load_checkpoint(model, weights): 24 | checkpoint = torch.load(weights) 25 | try: 26 | model.load_state_dict(checkpoint["state_dict"]) 27 | except: 28 | state_dict = checkpoint["state_dict"] 29 | new_state_dict = OrderedDict() 30 | for k, v in state_dict.items(): 31 | name = k[7:] if 'module.' in k else k 32 | new_state_dict[name] = v 33 | model.load_state_dict(new_state_dict) 34 | 35 | 36 | def load_checkpoint_multigpu(model, weights): 37 | checkpoint = torch.load(weights) 38 | state_dict = checkpoint["state_dict"] 39 | new_state_dict = OrderedDict() 40 | for k, v in state_dict.items(): 41 | name = k[7:] 42 | new_state_dict[name] = v 43 | model.load_state_dict(new_state_dict) 44 | 45 | def load_start_epoch(weights): 46 | checkpoint = torch.load(weights) 47 | epoch = checkpoint["epoch"] 48 | return epoch 49 | 50 | def load_optim(optimizer, weights): 51 | checkpoint = torch.load(weights) 52 | optimizer.load_state_dict(checkpoint['optimizer']) 53 | for p in optimizer.param_groups: lr = p['lr'] 54 | return lr 55 | 56 | def get_arch(opt): 57 | from model import AST 58 | 59 | arch = opt.arch 60 | 61 | print('You choose '+arch+'...') 62 | if arch == 'AST_T': 63 | model_restoration = AST(img_size=opt.train_ps,embed_dim=16,win_size=8,token_projection='linear',token_mlp=opt.token_mlp) 64 | elif arch == 'AST_B': 65 | model_restoration = AST(img_size=opt.train_ps,embed_dim=32,win_size=8,token_projection='linear',token_mlp=opt.token_mlp, 66 | depths=[1, 2, 8, 8, 2, 8, 8, 2, 1],dd_in=opt.dd_in) 67 | else: 68 | raise Exception("Arch error!") 69 | 70 | return model_restoration -------------------------------------------------------------------------------- /warmup_scheduler/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from warmup_scheduler.scheduler import GradualWarmupScheduler 3 | -------------------------------------------------------------------------------- /warmup_scheduler/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/warmup_scheduler/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /warmup_scheduler/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/warmup_scheduler/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /warmup_scheduler/__pycache__/scheduler.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/warmup_scheduler/__pycache__/scheduler.cpython-37.pyc -------------------------------------------------------------------------------- /warmup_scheduler/__pycache__/scheduler.cpython-38.pyc: 
--------------------------------------------------------------------------------
/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | from warmup_scheduler.scheduler import GradualWarmupScheduler
3 | 
--------------------------------------------------------------------------------
/warmup_scheduler/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import StepLR
3 | from torch.optim.sgd import SGD
4 | 
5 | from warmup_scheduler import GradualWarmupScheduler
6 | 
7 | 
8 | if __name__ == '__main__':
9 |     model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
10 |     optim = SGD(model, 0.1)
11 | 
12 |     # scheduler_warmup is chained with scheduler_steplr
13 |     scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
14 |     scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
15 | 
16 |     # this zero-gradient update is needed to avoid a warning message; see issue #8
17 |     optim.zero_grad()
18 |     optim.step()
19 | 
20 |     for epoch in range(1, 20):
21 |         scheduler_warmup.step(epoch)
22 |         print(epoch, optim.param_groups[0]['lr'])
23 | 
24 |         optim.step()  # parameter update; in real training a backward pass precedes this
25 | 
--------------------------------------------------------------------------------
/warmup_scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 | 
4 | 
5 | class GradualWarmupScheduler(_LRScheduler):
6 |     """Gradually warms up (increases) the learning rate in the optimizer.
7 |     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 | 
9 |     Args:
10 |         optimizer (Optimizer): wrapped optimizer.
11 |         multiplier: target learning rate = base lr * multiplier if multiplier > 1.0; if multiplier = 1.0, the lr starts from 0 and ends at base_lr.
12 |         total_epoch: the target learning rate is reached gradually at total_epoch.
13 |         after_scheduler: after total_epoch, hand over to this scheduler (e.g. ReduceLROnPlateau).
14 |     """
15 | 
16 |     def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
17 |         self.multiplier = multiplier
18 |         if self.multiplier < 1.:
19 |             raise ValueError('multiplier should be greater than or equal to 1.')
20 |         self.total_epoch = total_epoch
21 |         self.after_scheduler = after_scheduler
22 |         self.finished = False
23 |         super(GradualWarmupScheduler, self).__init__(optimizer)
24 | 
25 |     def get_lr(self):
26 |         if self.last_epoch > self.total_epoch:
27 |             if self.after_scheduler:
28 |                 if not self.finished:
29 |                     self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
30 |                     self.finished = True
31 |                 return self.after_scheduler.get_lr()
32 |             return [base_lr * self.multiplier for base_lr in self.base_lrs]
33 | 
34 |         if self.multiplier == 1.0:
35 |             return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
36 |         else:
37 |             return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
38 | 
39 |     def step_ReduceLROnPlateau(self, metrics, epoch=None):
40 |         if epoch is None:
41 |             epoch = self.last_epoch + 1
42 |         self.last_epoch = epoch if epoch != 0 else 1  # ReduceLROnPlateau is called at the end of an epoch, whereas the others are called at the beginning
43 |         if self.last_epoch <= self.total_epoch:
44 |             warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
45 |             for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
46 |                 param_group['lr'] = lr
47 |         else:
48 |             if epoch is None:
49 |                 self.after_scheduler.step(metrics, None)
50 |             else:
51 |                 self.after_scheduler.step(metrics, epoch - self.total_epoch)
52 | 
53 |     def step(self, epoch=None, metrics=None):
54 |         if not isinstance(self.after_scheduler, ReduceLROnPlateau):
55 |             if self.finished and self.after_scheduler:
56 |                 if epoch is None:
57 |                     self.after_scheduler.step(None)
58 |                 else:
59 |                     self.after_scheduler.step(epoch - self.total_epoch)
60 |             else:
61 |                 return super(GradualWarmupScheduler, self).step(epoch)
62 |         else:
63 |             self.step_ReduceLROnPlateau(metrics, epoch)
64 | 
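65 | 
66 | 
67 | if __name__ == '__main__':
68 |     # Editor's sketch: warm up for 5 epochs (multiplier=2 ramps from base lr to
69 |     # twice the base lr), then hand over to ReduceLROnPlateau, which needs a
70 |     # metric via step(metrics=...); the constant "val loss" here is a stand-in.
71 |     import torch
72 |     params = [torch.nn.Parameter(torch.zeros(1))]
73 |     optim = torch.optim.SGD(params, lr=0.1)
74 |     plateau = ReduceLROnPlateau(optim, mode='min', factor=0.5, patience=2)
75 |     sched = GradualWarmupScheduler(optim, multiplier=2, total_epoch=5, after_scheduler=plateau)
76 |     for epoch in range(1, 15):
77 |         optim.step()  # a real loop would run forward/backward first
78 |         sched.step(epoch, metrics=1.0)  # flat metric eventually triggers plateau decay
79 |         print(epoch, optim.param_groups[0]['lr'])
--------------------------------------------------------------------------------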