├── .DS_Store
├── LICENSE
├── README.md
├── dataset
│   ├── __pycache__
│   │   ├── dataset_defocusblur.cpython-37.pyc
│   │   ├── dataset_defocusblur.cpython-38.pyc
│   │   ├── dataset_dehaze.cpython-37.pyc
│   │   ├── dataset_dehaze_denseHaze.cpython-37.pyc
│   │   ├── dataset_dehaze_denseHaze.cpython-38.pyc
│   │   ├── dataset_demoire.cpython-37.pyc
│   │   ├── dataset_derain.cpython-37.pyc
│   │   ├── dataset_derain.cpython-38.pyc
│   │   ├── dataset_derain_drop.cpython-37.pyc
│   │   ├── dataset_derain_syn.cpython-37.pyc
│   │   ├── dataset_deshadow.cpython-37.pyc
│   │   ├── dataset_desnow.cpython-37.pyc
│   │   ├── dataset_enhance_smid.cpython-37.pyc
│   │   └── dataset_motiondeblur.cpython-37.pyc
│   ├── dataset_dehaze_denseHaze.py
│   ├── dataset_derain.py
│   └── dataset_derain_drop.py
├── evaluate_PSNR_SSIM.m
├── losses.py
├── model.py
├── options.py
├── requirements.txt
├── script
│   ├── test.sh
│   ├── train_dehaze.sh
│   ├── train_derain.sh
│   └── train_raindrop.sh
├── test
│   ├── test_denseHaze.py
│   ├── test_raindrop.py
│   └── test_spad.py
├── train
│   ├── train_dehaze.py
│   ├── train_derain.py
│   └── train_raindrop.py
├── utils
│   ├── .DS_Store
│   ├── __init__.py
│   ├── __pycache__
│   │   ├── __init__.cpython-37.pyc
│   │   ├── __init__.cpython-38.pyc
│   │   ├── caculate_psnr_ssim.cpython-37.pyc
│   │   ├── caculate_psnr_ssim.cpython-38.pyc
│   │   ├── dataset_utils.cpython-37.pyc
│   │   ├── dataset_utils.cpython-38.pyc
│   │   ├── dir_utils.cpython-37.pyc
│   │   ├── dir_utils.cpython-38.pyc
│   │   ├── image_utils.cpython-37.pyc
│   │   ├── image_utils.cpython-38.pyc
│   │   ├── model_utils.cpython-37.pyc
│   │   └── model_utils.cpython-38.pyc
│   ├── antialias.py
│   ├── bundle_submissions.py
│   ├── caculate_psnr_ssim.py
│   ├── dataset_utils.py
│   ├── dir_utils.py
│   ├── image_utils.py
│   ├── loader.py
│   └── model_utils.py
└── warmup_scheduler
    ├── __init__.py
    ├── __pycache__
    │   ├── __init__.cpython-37.pyc
    │   ├── __init__.cpython-38.pyc
    │   ├── scheduler.cpython-37.pyc
    │   └── scheduler.cpython-38.pyc
    ├── run.py
    └── scheduler.py
/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/.DS_Store
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restoration (CVPR 2024)
2 |
3 | [Shihao Zhou](https://joshyzhou.github.io/), [Duosheng Chen](https://github.com/Calvin11311), [Jinshan Pan](https://jspan.github.io/), [Jinglei Shi](https://jingleishi.github.io/), and [Jufeng Yang](https://cv.nankai.edu.cn/)
4 |
5 |
6 | #### News
7 | - **Feb 27, 2024:** AST has been accepted to CVPR 2024 :tada:
8 |
9 |
10 |
11 |
13 |
14 | ## Package dependencies
15 | The project is built with PyTorch 1.9.0, Python 3.7, and CUDA 11.1. You can install the package dependencies with:
16 | ```bash
17 | pip install -r requirements.txt
18 | ```
19 | ## Training
20 | ### Derain
21 | To train AST on SPAD, you can run:
22 | ```sh
23 | sh script/train_derain.sh
24 | ```
25 | ### Dehaze
26 | To train AST on Densehaze, you can run:
27 | ```sh
28 | sh script/train_dehaze.sh
29 | ```
30 | ### Raindrop
31 | To train AST on AGAN, you can run:
32 | ```sh
33 | sh script/train_raindrop.sh
34 | ```
35 |
36 |
37 | ## Evaluation
38 | To evaluate AST, you can run:
39 |
40 | ```sh
41 | sh script/test.sh
42 | ```
43 | To evaluate on a specific dataset, uncomment the corresponding line in the script.
44 |
45 |
46 | ## Results
47 | Experiments are performed on several image restoration tasks, including rain streak removal, raindrop removal, and haze removal.
48 | Here is a summary table containing hyperlinks for easy navigation:
49 |
50 |
74 |
75 | ## Citation
76 | If you find this project useful, please consider citing:
77 |
78 | @inproceedings{zhou2024AST,
79 | title={Adapt or Perish: Adaptive Sparse Transformer with Attentive Feature Refinement for Image Restoration},
80 | author={Zhou, Shihao and Chen, Duosheng and Pan, Jinshan and Shi, Jinglei and Yang, Jufeng},
81 | booktitle={CVPR},
82 | year={2024}
83 | }
84 |
85 | ## Acknowledgement
86 |
87 | This code borrows heavily from [Uformer](https://github.com/ZhendongWang6/Uformer).
88 |
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_defocusblur.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_defocusblur.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_defocusblur.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_defocusblur.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_dehaze.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_dehaze.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_dehaze_denseHaze.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_dehaze_denseHaze.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_dehaze_denseHaze.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_dehaze_denseHaze.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_demoire.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_demoire.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_derain.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_derain.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_derain.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_derain.cpython-38.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_derain_drop.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_derain_drop.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_derain_syn.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_derain_syn.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_deshadow.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_deshadow.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_desnow.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_desnow.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_enhance_smid.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_enhance_smid.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/__pycache__/dataset_motiondeblur.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/dataset/__pycache__/dataset_motiondeblur.cpython-37.pyc
--------------------------------------------------------------------------------
/dataset/dataset_dehaze_denseHaze.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from torch.utils.data import Dataset
4 | import torch
5 | from utils import is_png_file, load_img, Augment_RGB_torch
6 | import torch.nn.functional as F
7 | import random
8 | from PIL import Image
9 | import torchvision.transforms.functional as TF
10 | from natsort import natsorted
11 | from glob import glob
12 | augment = Augment_RGB_torch()
13 | transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')]
14 |
15 |
16 | def is_image_file(filename):
17 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
18 |
19 | ##################################################################################################
20 | class DataLoaderTrain(Dataset):
21 | def __init__(self, rgb_dir, img_options=None, target_transform=None):
22 | super(DataLoaderTrain, self).__init__()
23 |
24 | self.target_transform = target_transform
25 |
26 | gt_dir = 'gt'
27 | input_dir = 'input'
28 |
29 | clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir)))
30 | noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
31 |
32 | self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_png_file(x)]
33 | self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files if is_png_file(x)]
34 |
35 | self.img_options=img_options
36 |
37 | self.tar_size = len(self.clean_filenames) # get the size of target
38 |
39 | def __len__(self):
40 | return self.tar_size
41 |
42 | def __getitem__(self, index):
43 | tar_index = index % self.tar_size
44 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index])))
45 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index])))
46 |
47 | clean = clean.permute(2,0,1)
48 | noisy = noisy.permute(2,0,1)
49 |
50 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1]
51 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]
52 |
53 | #Crop Input and Target
54 | ps = self.img_options['patch_size']
55 | H = clean.shape[1]
56 | W = clean.shape[2]
57 |         # Sample a random top-left corner for the crop. H == ps and
58 |         # W == ps are handled separately so np.random.randint never
59 |         # sees an empty range (randint(0, 0) raises ValueError).
60 |         if H - ps == 0:
61 |             r = 0
62 |         else:
63 |             r = np.random.randint(0, H - ps)
64 |         c = 0 if W - ps == 0 else np.random.randint(0, W - ps)
65 | clean = clean[:, r:r + ps, c:c + ps]
66 | noisy = noisy[:, r:r + ps, c:c + ps]
67 |
68 | apply_trans = transforms_aug[random.getrandbits(3)]
69 |
70 | clean = getattr(augment, apply_trans)(clean)
71 | noisy = getattr(augment, apply_trans)(noisy)
72 |
73 | return clean, noisy, clean_filename, noisy_filename
74 |
75 |
76 | ##################################################################################################
77 | class DataLoaderVal(Dataset):
78 | def __init__(self, rgb_dir, img_options=None, target_transform=None):
79 | super(DataLoaderVal, self).__init__()
80 |
81 | self.target_transform = target_transform
82 |
83 | gt_dir = 'gt'
84 | input_dir = 'input'
85 |
86 | clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir)))
87 | noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
88 |
89 |
90 | self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_png_file(x)]
91 | self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files if is_png_file(x)]
92 |
93 | self.img_options=img_options
94 | self.tar_size = len(self.clean_filenames)
95 |
96 | def __len__(self):
97 | return self.tar_size
98 |
99 | def __getitem__(self, index):
100 | tar_index = index % self.tar_size
101 |
102 | inp_path = self.noisy_filenames[tar_index]
103 | tar_path = self.clean_filenames[tar_index]
104 |
105 | inp_img = Image.open(inp_path)
106 | tar_img = Image.open(tar_path)
107 |
108 | # Validate on center crop
109 | if self.img_options:
110 | ps = self.img_options['patch_size']
111 | # if ps is not None:
112 | inp_img = TF.center_crop(inp_img, (ps,ps))
113 | tar_img = TF.center_crop(tar_img, (ps,ps))
114 |
115 | noisy = TF.to_tensor(inp_img)
116 | clean = TF.to_tensor(tar_img)
117 |
118 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1]
119 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]
120 |
121 |
122 | return clean, noisy, clean_filename, noisy_filename
123 |
124 | ##################################################################################################
125 |
126 | class DataLoaderTest(Dataset):
127 | def __init__(self, inp_dir, img_options):
128 | super(DataLoaderTest, self).__init__()
129 |
130 | inp_files = sorted(os.listdir(inp_dir))
131 | self.inp_filenames = [os.path.join(inp_dir, x) for x in inp_files if is_image_file(x)]
132 |
133 | self.inp_size = len(self.inp_filenames)
134 | self.img_options = img_options
135 |
136 | def __len__(self):
137 | return self.inp_size
138 |
139 | def __getitem__(self, index):
140 |
141 | path_inp = self.inp_filenames[index]
142 | filename = os.path.splitext(os.path.split(path_inp)[-1])[0]
143 | inp = Image.open(path_inp)
144 |
145 | inp = TF.to_tensor(inp)
146 | return inp, filename
147 |
148 |
149 | def get_training_data(rgb_dir, img_options):
150 | assert os.path.exists(rgb_dir)
151 | return DataLoaderTrain(rgb_dir, img_options, None)
152 |
153 |
154 | def get_validation_data(rgb_dir,img_options=None):
155 | assert os.path.exists(rgb_dir)
156 | return DataLoaderVal(rgb_dir, img_options, None)
157 |
158 | def get_test_data(rgb_dir, img_options=None):
159 | assert os.path.exists(rgb_dir)
160 | return DataLoaderTest(rgb_dir, img_options)
--------------------------------------------------------------------------------
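
A minimal usage sketch (not part of the repository) of the Dense-Haze loaders above, assuming a hypothetical data root laid out with the `gt/` and `input/` subfolders that `DataLoaderTrain` expects, and that `dataset/` is importable from the repo root:

```python
from torch.utils.data import DataLoader
from dataset.dataset_dehaze_denseHaze import get_training_data

# Hypothetical data root with the layout DataLoaderTrain expects:
#   ./datasets/DenseHaze/train/gt/*.png     (clean targets)
#   ./datasets/DenseHaze/train/input/*.png  (hazy inputs)
train_set = get_training_data('./datasets/DenseHaze/train',
                              img_options={'patch_size': 128})
train_loader = DataLoader(train_set, batch_size=8, shuffle=True, num_workers=4)

for clean, noisy, clean_name, noisy_name in train_loader:
    # clean, noisy: (B, 3, 128, 128) float tensors in [0, 1]
    break
```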
/dataset/dataset_derain.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from torch.utils.data import Dataset
4 | import torch
5 | from utils import is_png_file, load_img, Augment_RGB_torch
6 | import torch.nn.functional as F
7 | import random
8 | from PIL import Image
9 | import torchvision.transforms.functional as TF
10 | from natsort import natsorted
11 | from glob import glob
12 | augment = Augment_RGB_torch()
13 | transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')]
14 |
15 |
16 | def is_image_file(filename):
17 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
18 |
19 | ##################################################################################################
20 | class DataLoaderTrain(Dataset):
21 | def __init__(self, rgb_dir, img_options=None, target_transform=None):
22 | super(DataLoaderTrain, self).__init__()
23 |
24 | self.target_transform = target_transform
25 |
26 | name = 'real_world.txt'
27 | # name = 'real_test_1000.txt'
28 | self.dataset = os.path.join(rgb_dir, name)
29 |
30 | self.mat_files = open(self.dataset, 'r').readlines()
31 |
32 | self.clean_filenames = [os.path.join(rgb_dir, x.split(' ')[1][1:-1]) for x in self.mat_files ]
33 | self.noisy_filenames = [os.path.join(rgb_dir, x.split(' ')[0][1:]) for x in self.mat_files ]
34 |
35 | self.img_options=img_options
36 |
37 | self.tar_size = len(self.clean_filenames) # get the size of target
38 |
39 | def __len__(self):
40 | return self.tar_size
41 |
42 | def __getitem__(self, index):
43 | tar_index = index % self.tar_size
44 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index])))
45 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index])))
46 |
47 | clean = clean.permute(2,0,1)
48 | noisy = noisy.permute(2,0,1)
49 |
50 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1]
51 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]
52 |
53 | #Crop Input and Target
54 | ps = self.img_options['patch_size']
55 | H = clean.shape[1]
56 | W = clean.shape[2]
57 |         if H - ps == 0:
58 |             r = 0
59 |         else:
60 |             r = np.random.randint(0, H - ps)
61 |         # guard W == ps as well; np.random.randint(0, 0) raises ValueError
62 |         c = 0 if W - ps == 0 else np.random.randint(0, W - ps)
63 | clean = clean[:, r:r + ps, c:c + ps]
64 | noisy = noisy[:, r:r + ps, c:c + ps]
65 |
66 | apply_trans = transforms_aug[random.getrandbits(3)]
67 |
68 | clean = getattr(augment, apply_trans)(clean)
69 | noisy = getattr(augment, apply_trans)(noisy)
70 |
71 | return clean, noisy, clean_filename, noisy_filename
72 |
73 |
74 | ##################################################################################################
75 | class DataLoaderVal(Dataset):
76 | def __init__(self, rgb_dir, target_transform=None):
77 | super(DataLoaderVal, self).__init__()
78 |
79 | self.target_transform = target_transform
80 |
81 | # name='demo.txt'
82 | name = 'real_test_1000.txt'
83 | # name = 'Real_Internet.txt'
84 | self.dataset = os.path.join(rgb_dir, name)
85 | self.mat_files = open(self.dataset, 'r').readlines()
86 |
87 | self.clean_filenames = [os.path.join(rgb_dir, x.split(' ')[1][1:-1]) for x in self.mat_files]
88 | self.noisy_filenames = [os.path.join(rgb_dir, x.split(' ')[0][1:]) for x in self.mat_files]
89 |
90 |
91 | self.tar_size = len(self.clean_filenames)
92 |
93 | def __len__(self):
94 | return self.tar_size
95 |
96 | def __getitem__(self, index):
97 | tar_index = index % self.tar_size
98 |
99 |
100 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index])))
101 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index])))
102 |
103 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1]
104 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]
105 |
106 | clean = clean.permute(2,0,1)
107 | noisy = noisy.permute(2,0,1)
108 |
109 | return clean, noisy, clean_filename, noisy_filename
110 |
111 | ##################################################################################################
112 |
113 | class DataLoaderTest(Dataset):
114 | def __init__(self, inp_dir, img_options):
115 | super(DataLoaderTest, self).__init__()
116 |
117 | # name = 'Real_Internet.txt'
118 | name = 'real_test_1000.txt'
119 | self.dataset = os.path.join(inp_dir, name)
120 | self.mat_files = open(self.dataset, 'r').readlines()
121 |
122 | self.clean_filenames = [os.path.join(inp_dir, x.split(' ')[1][1:-1]) for x in self.mat_files]
123 | self.noisy_filenames = [os.path.join(inp_dir, x.split(' ')[0][1:]) for x in self.mat_files]
124 |
125 |
126 | inp_files = sorted(os.listdir(inp_dir))
127 | self.inp_filenames = [os.path.join(inp_dir, x) for x in inp_files if is_image_file(x)]
128 |
129 | self.inp_size = len(self.clean_filenames)
130 | self.img_options = img_options
131 |
132 | def __len__(self):
133 | return self.inp_size
134 |
135 | def __getitem__(self, index):
136 |
137 |         path_inp = self.noisy_filenames[index]  # from the list file; inp_filenames is scanned from inp_dir and may not match inp_size
138 | filename = os.path.splitext(os.path.split(path_inp)[-1])[0]
139 | inp = Image.open(path_inp)
140 |
141 | inp = TF.to_tensor(inp)
142 | return inp, filename
143 |
144 |
145 | def get_training_data(rgb_dir, img_options):
146 | assert os.path.exists(rgb_dir)
147 | return DataLoaderTrain(rgb_dir, img_options, None)
148 |
149 |
150 | def get_validation_data(rgb_dir):
151 | assert os.path.exists(rgb_dir)
152 | return DataLoaderVal(rgb_dir, None)
153 |
154 | def get_test_data(rgb_dir, img_options=None):
155 | assert os.path.exists(rgb_dir)
156 | return DataLoaderTest(rgb_dir, img_options)
--------------------------------------------------------------------------------
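
The index slicing in the loaders above (`x.split(' ')[0][1:]` and `x.split(' ')[1][1:-1]`) implies that each line of `real_world.txt` stores a space-separated rainy/clean pair whose paths start with `/` and end with a newline. A small sketch with a hypothetical list-file line makes the parsing explicit:

```python
import os

rgb_dir = './datasets/SPA-Data'                  # hypothetical data root
line = '/rain/img_0001.png /gt/img_0001.png\n'   # hypothetical list-file line

noisy_rel = line.split(' ')[0][1:]    # strip the leading '/'
clean_rel = line.split(' ')[1][1:-1]  # strip the leading '/' and trailing '\n'

noisy_path = os.path.join(rgb_dir, noisy_rel)  # ./datasets/SPA-Data/rain/img_0001.png
clean_path = os.path.join(rgb_dir, clean_rel)  # ./datasets/SPA-Data/gt/img_0001.png
```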
/dataset/dataset_derain_drop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os
3 | from torch.utils.data import Dataset
4 | import torch
5 | from utils import is_png_file, load_img, Augment_RGB_torch
6 | import torch.nn.functional as F
7 | import random
8 | from PIL import Image
9 | import torchvision.transforms.functional as TF
10 | from natsort import natsorted
11 | from glob import glob
12 | augment = Augment_RGB_torch()
13 | transforms_aug = [method for method in dir(augment) if callable(getattr(augment, method)) if not method.startswith('_')]
14 |
15 |
16 | def is_image_file(filename):
17 | return any(filename.endswith(extension) for extension in ['jpeg', 'JPEG', 'jpg', 'png', 'JPG', 'PNG', 'gif'])
18 |
19 | ##################################################################################################
20 | class DataLoaderTrain(Dataset):
21 | def __init__(self, rgb_dir, img_options=None, target_transform=None):
22 | super(DataLoaderTrain, self).__init__()
23 |
24 | self.target_transform = target_transform
25 | gt_dir = 'gt'
26 | input_dir = 'data'
27 | clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir)))
28 | noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
29 |
30 | self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_image_file(x)]
31 | self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files if is_image_file(x)]
32 |
33 |
34 | self.img_options=img_options
35 |
36 | self.tar_size = len(self.clean_filenames) # get the size of target
37 |
38 | def __len__(self):
39 | return self.tar_size
40 |
41 | def __getitem__(self, index):
42 | tar_index = index % self.tar_size
43 |
44 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index])))
45 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index])))
46 |
47 | clean = clean.permute(2,0,1)
48 | noisy = noisy.permute(2,0,1)
49 |
50 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1]
51 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]
52 |
53 | #Crop Input and Target
54 | ps = self.img_options['patch_size']
55 | H = clean.shape[1]
56 | W = clean.shape[2]
57 |         if H - ps == 0:
58 |             r = 0
59 |         else:
60 |             r = np.random.randint(0, H - ps)
61 |         # guard W == ps as well; np.random.randint(0, 0) raises ValueError
62 |         c = 0 if W - ps == 0 else np.random.randint(0, W - ps)
63 | clean = clean[:, r:r + ps, c:c + ps]
64 | noisy = noisy[:, r:r + ps, c:c + ps]
65 |
66 | apply_trans = transforms_aug[random.getrandbits(3)]
67 |
68 | clean = getattr(augment, apply_trans)(clean)
69 | noisy = getattr(augment, apply_trans)(noisy)
70 |
71 | return clean, noisy, clean_filename, noisy_filename
72 |
73 |
74 | ##################################################################################################
75 | class DataLoaderVal(Dataset):
76 | def __init__(self, rgb_dir, img_options=None, target_transform=None):
77 | super(DataLoaderVal, self).__init__()
78 |
79 | self.target_transform = target_transform
80 | gt_dir = 'gt'
81 | input_dir = 'data'
82 |
83 | clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir)))
84 | noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
85 |
86 |
87 | self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_image_file(x)]
88 | self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files if is_image_file(x)]
89 | self.img_options=img_options
90 |
91 | self.tar_size = len(self.clean_filenames)
92 |
93 | def __len__(self):
94 | return self.tar_size
95 |
96 | def __getitem__(self, index):
97 | tar_index = index % self.tar_size
98 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index])))
99 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index])))
100 |
101 | clean = clean.permute(2,0,1)
102 | noisy = noisy.permute(2,0,1)
103 |
104 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1]
105 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]
106 |
107 | #Crop Input and Target
108 | ps = self.img_options['patch_size']
109 | H = clean.shape[1]
110 | W = clean.shape[2]
111 | if H-ps==0:
112 | r=0
113 | c=0
114 | else:
115 | r = (H - ps)//2
116 | c = (W - ps)//2
117 | clean = clean[:, r:r + ps, c:c + ps]
118 | noisy = noisy[:, r:r + ps, c:c + ps]
119 |
120 | return clean, noisy, clean_filename, noisy_filename
121 |
122 | ##################################################################################################
123 |
124 | class DataLoaderTest(Dataset):
125 | def __init__(self, rgb_dir):
126 | super(DataLoaderTest, self).__init__()
127 |
128 |
129 | gt_dir = 'gt'
130 | input_dir = 'data'
131 |
132 | clean_files = sorted(os.listdir(os.path.join(rgb_dir, gt_dir)))
133 | noisy_files = sorted(os.listdir(os.path.join(rgb_dir, input_dir)))
134 |
135 |
136 | self.clean_filenames = [os.path.join(rgb_dir, gt_dir, x) for x in clean_files if is_image_file(x)]
137 | self.noisy_filenames = [os.path.join(rgb_dir, input_dir, x) for x in noisy_files if is_image_file(x)]
138 |
139 | self.tar_size = len(self.clean_filenames)
140 |
141 |
142 |
143 |
144 | def __len__(self):
145 | return self.tar_size
146 |
147 | def __getitem__(self, index):
148 |
149 | tar_index = index % self.tar_size
150 |
151 | clean = torch.from_numpy(np.float32(load_img(self.clean_filenames[tar_index])))
152 | noisy = torch.from_numpy(np.float32(load_img(self.noisy_filenames[tar_index])))
153 |
154 | clean_filename = os.path.split(self.clean_filenames[tar_index])[-1]
155 | noisy_filename = os.path.split(self.noisy_filenames[tar_index])[-1]
156 |
157 | clean = clean.permute(2,0,1)
158 | noisy = noisy.permute(2,0,1)
159 | return clean, noisy, clean_filename, noisy_filename
160 |
161 | def get_training_data(rgb_dir, img_options):
162 | assert os.path.exists(rgb_dir)
163 | return DataLoaderTrain(rgb_dir, img_options, None)
164 |
165 |
166 | def get_validation_data(rgb_dir, img_options=None):
167 | assert os.path.exists(rgb_dir)
168 | return DataLoaderVal(rgb_dir, img_options, None)
169 |
170 | def get_test_data(rgb_dir):
171 | assert os.path.exists(rgb_dir)
172 | return DataLoaderTest(rgb_dir)
--------------------------------------------------------------------------------
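
All three dataset files draw one augmentation via `transforms_aug[random.getrandbits(3)]` (a uniform index in [0, 7], matching the eight public methods of `Augment_RGB_torch`) and apply the same transform to both images so the degraded/clean pair stays spatially aligned. A toy sketch of the pattern, using stand-in transforms rather than the real `Augment_RGB_torch`:

```python
import random
import torch

# Stand-ins for the eight public methods of utils.Augment_RGB_torch
def identity(x): return x
def hflip(x):    return torch.flip(x, dims=[2])        # flip W of a (C, H, W) tensor
def vflip(x):    return torch.flip(x, dims=[1])        # flip H
def rot90(x):    return torch.rot90(x, 1, dims=[1, 2])

transforms_aug = [identity, hflip, vflip, rot90] * 2   # padded to 8 entries for the sketch

clean = torch.rand(3, 64, 64)
noisy = clean + 0.1 * torch.randn(3, 64, 64)

apply_trans = transforms_aug[random.getrandbits(3)]    # uniform over the 8 slots
clean, noisy = apply_trans(clean), apply_trans(noisy)  # same geometric op on both
```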
/evaluate_PSNR_SSIM.m:
--------------------------------------------------------------------------------
1 | clc;close all;clear all;addpath(genpath('./'));
2 |
3 | datasets = {'Rain100L'};
4 | % datasets = {'Rain200L', 'Rain200H', 'SPA-Data'};
5 | num_set = length(datasets);
6 |
7 | psnr_alldatasets = 0;
8 | ssim_alldatasets = 0;
9 | detail_res = zeros(1000,1);
10 | for idx_set = 1:num_set
11 | file_path = strcat('./4test/spad/AST_B/');
12 | gt_path = strcat('./4test/real_test_1000/gt/');
13 | path_list = [dir(strcat(file_path,'*.jpg')); dir(strcat(file_path,'*.png'))];
14 | gt_list = [dir(strcat(gt_path,'*.jpg')); dir(strcat(gt_path,'*.png'))];
15 | img_num = length(path_list);
16 |
17 | total_psnr = 0;
18 | total_ssim = 0;
19 | if img_num > 0
20 | for j = 1:img_num
21 | image_name = path_list(j).name;
22 | gt_name = gt_list(j).name;
23 | input = imread(strcat(file_path,image_name));
24 | gt = imread(strcat(gt_path, gt_name));
25 | ssim_val = compute_ssim(input, gt);
26 | psnr_val = compute_psnr(input, gt);
27 | total_ssim = total_ssim + ssim_val;
28 | total_psnr = total_psnr + psnr_val;
29 | detail_res(j,1)=psnr_val;
30 | end
31 | end
32 | qm_psnr = total_psnr / img_num;
33 | qm_ssim = total_ssim / img_num;
34 |
35 | fprintf('For %s dataset PSNR: %f SSIM: %f\n', datasets{idx_set}, qm_psnr, qm_ssim);
36 |
37 | psnr_alldatasets = psnr_alldatasets + qm_psnr;
38 | ssim_alldatasets = ssim_alldatasets + qm_ssim;
39 |
40 | end
41 |
42 | fprintf('For all datasets PSNR: %f SSIM: %f\n', psnr_alldatasets/num_set, ssim_alldatasets/num_set);
43 |
44 | function ssim_mean=compute_ssim(img1,img2)
45 | if size(img1, 3) == 3
46 | img1 = rgb2ycbcr(img1);
47 | img1 = img1(:, :, 1);
48 | end
49 |
50 | if size(img2, 3) == 3
51 | img2 = rgb2ycbcr(img2);
52 | img2 = img2(:, :, 1);
53 | end
54 | ssim_mean = SSIM_index(img1, img2);
55 | end
56 |
57 | function psnr=compute_psnr(img1,img2)
58 | if size(img1, 3) == 3
59 | img1 = rgb2ycbcr(img1);
60 | img1 = img1(:, :, 1);
61 | end
62 |
63 | if size(img2, 3) == 3
64 | img2 = rgb2ycbcr(img2);
65 | img2 = img2(:, :, 1);
66 | end
67 |
68 | imdff = double(img1) - double(img2);
69 | imdff = imdff(:);
70 | rmse = sqrt(mean(imdff.^2));
71 | psnr = 20*log10(255/rmse);
72 |
73 | end
74 |
75 | function [mssim, ssim_map] = SSIM_index(img1, img2, K, window, L)
76 |
77 | %========================================================================
78 | %SSIM Index, Version 1.0
79 | %Copyright(c) 2003 Zhou Wang
80 | %All Rights Reserved.
81 | %
82 | %The author is with Howard Hughes Medical Institute, and Laboratory
83 | %for Computational Vision at Center for Neural Science and Courant
84 | %Institute of Mathematical Sciences, New York University.
85 | %
86 | %----------------------------------------------------------------------
87 | %Permission to use, copy, or modify this software and its documentation
88 | %for educational and research purposes only and without fee is hereby
89 | %granted, provided that this copyright notice and the original authors'
90 | %names appear on all copies and supporting documentation. This program
91 | %shall not be used, rewritten, or adapted as the basis of a commercial
92 | %software or hardware product without first obtaining permission of the
93 | %authors. The authors make no representations about the suitability of
94 | %this software for any purpose. It is provided "as is" without express
95 | %or implied warranty.
96 | %----------------------------------------------------------------------
97 | %
98 | %This is an implementation of the algorithm for calculating the
99 | %Structural SIMilarity (SSIM) index between two images. Please refer
100 | %to the following paper:
101 | %
102 | %Z. Wang, A. C. Bovik, H. R. Sheikh, and E. P. Simoncelli, "Image
103 | %quality assessment: From error measurement to structural similarity"
104 | %IEEE Transactions on Image Processing, vol. 13, no. 1, Jan. 2004.
105 | %
106 | %Kindly report any suggestions or corrections to zhouwang@ieee.org
107 | %
108 | %----------------------------------------------------------------------
109 | %
110 | %Input : (1) img1: the first image being compared
111 | % (2) img2: the second image being compared
112 | % (3) K: constants in the SSIM index formula (see the above
113 | %            reference). default value: K = [0.01 0.03]
114 | %        (4) window: local window for statistics (see the above
115 | %            reference). default window is Gaussian given by
116 | % window = fspecial('gaussian', 11, 1.5);
117 | % (5) L: dynamic range of the images. default: L = 255
118 | %
119 | %Output: (1) mssim: the mean SSIM index value between 2 images.
120 | % If one of the images being compared is regarded as
121 | % perfect quality, then mssim can be considered as the
122 | % quality measure of the other image.
123 | % If img1 = img2, then mssim = 1.
124 | % (2) ssim_map: the SSIM index map of the test image. The map
125 | % has a smaller size than the input images. The actual size:
126 | % size(img1) - size(window) + 1.
127 | %
128 | %Default Usage:
129 | % Given 2 test images img1 and img2, whose dynamic range is 0-255
130 | %
131 | % [mssim ssim_map] = ssim_index(img1, img2);
132 | %
133 | %Advanced Usage:
134 | % User defined parameters. For example
135 | %
136 | % K = [0.05 0.05];
137 | % window = ones(8);
138 | % L = 100;
139 | % [mssim ssim_map] = ssim_index(img1, img2, K, window, L);
140 | %
141 | %See the results:
142 | %
143 | % mssim %Gives the mssim value
144 | % imshow(max(0, ssim_map).^4) %Shows the SSIM index map
145 | %
146 | %========================================================================
147 |
148 |
149 | if (nargin < 2 || nargin > 5)
150 |    mssim = -Inf;
151 | ssim_map = -Inf;
152 | return;
153 | end
154 |
155 | if (any(size(img1) ~= size(img2)))
156 |    mssim = -Inf;
157 | ssim_map = -Inf;
158 | return;
159 | end
160 |
161 | [M N] = size(img1);
162 |
163 | if (nargin == 2)
164 | if ((M < 11) || (N < 11))
165 | 	   mssim = -Inf;
166 | ssim_map = -Inf;
167 | return
168 | end
169 | window = fspecial('gaussian', 11, 1.5); %
170 | K(1) = 0.01; % default settings
171 | K(2) = 0.03; %
172 | L = 255; %
173 | end
174 |
175 | if (nargin == 3)
176 | if ((M < 11) || (N < 11))
177 | 	   mssim = -Inf;
178 | ssim_map = -Inf;
179 | return
180 | end
181 | window = fspecial('gaussian', 11, 1.5);
182 | L = 255;
183 | if (length(K) == 2)
184 | if (K(1) < 0 || K(2) < 0)
185 | 		  mssim = -Inf;
186 | ssim_map = -Inf;
187 | return;
188 | end
189 | else
190 | 	   mssim = -Inf;
191 | ssim_map = -Inf;
192 | return;
193 | end
194 | end
195 |
196 | if (nargin == 4)
197 | [H W] = size(window);
198 | if ((H*W) < 4 || (H > M) || (W > N))
199 | 	   mssim = -Inf;
200 | ssim_map = -Inf;
201 | return
202 | end
203 | L = 255;
204 | if (length(K) == 2)
205 | if (K(1) < 0 || K(2) < 0)
206 | 		  mssim = -Inf;
207 | ssim_map = -Inf;
208 | return;
209 | end
210 | else
211 | 	   mssim = -Inf;
212 | ssim_map = -Inf;
213 | return;
214 | end
215 | end
216 |
217 | if (nargin == 5)
218 | [H W] = size(window);
219 | if ((H*W) < 4 || (H > M) || (W > N))
220 | 	   mssim = -Inf;
221 | ssim_map = -Inf;
222 | return
223 | end
224 | if (length(K) == 2)
225 | if (K(1) < 0 || K(2) < 0)
226 | 		  mssim = -Inf;
227 | ssim_map = -Inf;
228 | return;
229 | end
230 | else
231 | 	   mssim = -Inf;
232 | ssim_map = -Inf;
233 | return;
234 | end
235 | end
236 |
237 | C1 = (K(1)*L)^2;
238 | C2 = (K(2)*L)^2;
239 | window = window/sum(sum(window));
240 | img1 = double(img1);
241 | img2 = double(img2);
242 |
243 | mu1 = filter2(window, img1, 'valid');
244 | mu2 = filter2(window, img2, 'valid');
245 | mu1_sq = mu1.*mu1;
246 | mu2_sq = mu2.*mu2;
247 | mu1_mu2 = mu1.*mu2;
248 | sigma1_sq = filter2(window, img1.*img1, 'valid') - mu1_sq;
249 | sigma2_sq = filter2(window, img2.*img2, 'valid') - mu2_sq;
250 | sigma12 = filter2(window, img1.*img2, 'valid') - mu1_mu2;
251 |
252 | if (C1 > 0 & C2 > 0)
253 | ssim_map = ((2*mu1_mu2 + C1).*(2*sigma12 + C2))./((mu1_sq + mu2_sq + C1).*(sigma1_sq + sigma2_sq + C2));
254 | else
255 | numerator1 = 2*mu1_mu2 + C1;
256 | numerator2 = 2*sigma12 + C2;
257 | denominator1 = mu1_sq + mu2_sq + C1;
258 | denominator2 = sigma1_sq + sigma2_sq + C2;
259 | ssim_map = ones(size(mu1));
260 | index = (denominator1.*denominator2 > 0);
261 | ssim_map(index) = (numerator1(index).*numerator2(index))./(denominator1(index).*denominator2(index));
262 | index = (denominator1 ~= 0) & (denominator2 == 0);
263 | ssim_map(index) = numerator1(index)./denominator1(index);
264 | end
265 |
266 | mssim = mean2(ssim_map);
267 |
268 | end
269 |
270 |
--------------------------------------------------------------------------------
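
For reference, a NumPy sketch of `compute_psnr` above: MATLAB's `rgb2ycbcr` maps uint8 RGB to a BT.601 studio-swing luma channel (range 16 to 235), and PSNR is then computed on that channel with a 255 peak. The float-valued Y here skips MATLAB's uint8 rounding, so results can differ slightly from the script:

```python
import numpy as np

def rgb_to_y(img):
    """BT.601 studio-swing luma, matching rgb2ycbcr on uint8 RGB (no rounding)."""
    img = img.astype(np.float64)
    return 16.0 + (65.481 * img[..., 0]
                   + 128.553 * img[..., 1]
                   + 24.966 * img[..., 2]) / 255.0

def psnr_y(img1, img2):
    """Mirror of compute_psnr: PSNR on the Y channel with peak value 255."""
    diff = rgb_to_y(img1) - rgb_to_y(img2)
    rmse = np.sqrt(np.mean(diff ** 2))
    return 20.0 * np.log10(255.0 / rmse)
```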
/losses.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 |
7 |
8 | def tv_loss(x, beta = 0.5, reg_coeff = 5):
9 | '''Calculates TV loss for an image `x`.
10 |
11 | Args:
12 |         x: image, a torch.Tensor of shape (N, C, H, W)
13 | beta: See https://arxiv.org/abs/1412.0035 (fig. 2) to see effect of `beta`
14 | '''
15 | dh = torch.pow(x[:,:,:,1:] - x[:,:,:,:-1], 2)
16 | dw = torch.pow(x[:,:,1:,:] - x[:,:,:-1,:], 2)
17 | a,b,c,d=x.shape
18 | return reg_coeff*(torch.sum(torch.pow(dh[:, :, :-1] + dw[:, :, :, :-1], beta))/(a*b*c*d))
19 |
20 | class TVLoss(nn.Module):
21 | def __init__(self, tv_loss_weight=1):
22 | super(TVLoss, self).__init__()
23 | self.tv_loss_weight = tv_loss_weight
24 |
25 | def forward(self, x):
26 | batch_size = x.size()[0]
27 | h_x = x.size()[2]
28 | w_x = x.size()[3]
29 | count_h = self.tensor_size(x[:, :, 1:, :])
30 | count_w = self.tensor_size(x[:, :, :, 1:])
31 | h_tv = torch.pow((x[:, :, 1:, :] - x[:, :, :h_x - 1, :]), 2).sum()
32 | w_tv = torch.pow((x[:, :, :, 1:] - x[:, :, :, :w_x - 1]), 2).sum()
33 | return self.tv_loss_weight * 2 * (h_tv / count_h + w_tv / count_w) / batch_size
34 |
35 | @staticmethod
36 | def tensor_size(t):
37 | return t.size()[1] * t.size()[2] * t.size()[3]
38 |
39 |
40 |
41 | class CharbonnierLoss(nn.Module):
42 | """Charbonnier Loss (L1)"""
43 |
44 | def __init__(self, eps=1e-3):
45 | super(CharbonnierLoss, self).__init__()
46 | self.eps = eps
47 |
48 | def forward(self, x, y):
49 | diff = x - y
50 | # loss = torch.sum(torch.sqrt(diff * diff + self.eps))
51 | loss = torch.mean(torch.sqrt((diff * diff) + (self.eps*self.eps)))
52 | return loss
53 |
--------------------------------------------------------------------------------
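
`CharbonnierLoss` averages sqrt(d^2 + eps^2) over all elements, a smooth approximation of L1 that stays differentiable at d = 0 and behaves like |d| once |d| >> eps. A minimal usage sketch, with random tensors standing in for the network output and ground truth:

```python
import torch
from losses import CharbonnierLoss

criterion = CharbonnierLoss(eps=1e-3)

restored = torch.rand(4, 3, 128, 128, requires_grad=True)  # stand-in network output
target = torch.rand(4, 3, 128, 128)                        # stand-in ground truth

loss = criterion(restored, target)
loss.backward()  # gradients flow through the smooth sqrt
```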
/model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.checkpoint as checkpoint
4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
5 | import torch.nn.functional as F
6 | from einops import rearrange, repeat
7 | from einops.layers.torch import Rearrange
8 | import math
9 | import numpy as np
10 | import time
11 | from torch import einsum
12 |
13 |
14 | def conv(in_channels, out_channels, kernel_size, bias=False, stride = 1):
15 | return nn.Conv2d(
16 | in_channels, out_channels, kernel_size,
17 | padding=(kernel_size//2), bias=bias, stride = stride)
18 |
19 |
20 |
21 | #########################################
22 | class ConvBlock(nn.Module):
23 | def __init__(self, in_channel, out_channel, strides=1):
24 | super(ConvBlock, self).__init__()
25 | self.strides = strides
26 | self.in_channel=in_channel
27 | self.out_channel=out_channel
28 | self.block = nn.Sequential(
29 | nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1),
30 | nn.LeakyReLU(inplace=True),
31 | nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=strides, padding=1),
32 | nn.LeakyReLU(inplace=True),
33 | )
34 | self.conv11 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, padding=0)
35 |
36 | def forward(self, x):
37 | out1 = self.block(x)
38 | out2 = self.conv11(x)
39 | out = out1 + out2
40 | return out
41 |
42 | class LinearProjection(nn.Module):
43 | def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0., bias=True):
44 | super().__init__()
45 | inner_dim = dim_head * heads
46 | self.heads = heads
47 | self.to_q = nn.Linear(dim, inner_dim, bias = bias)
48 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias = bias)
49 | self.dim = dim
50 | self.inner_dim = inner_dim
51 |
52 | def forward(self, x, attn_kv=None):
53 | B_, N, C = x.shape
54 | if attn_kv is not None:
55 | attn_kv = attn_kv.unsqueeze(0).repeat(B_,1,1)
56 | else:
57 | attn_kv = x
58 | N_kv = attn_kv.size(1)
59 | q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
60 | kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
61 | q = q[0]
62 | k, v = kv[0], kv[1]
63 | return q,k,v
64 |
65 |
66 | #########################################
67 | ########### window-based self-attention #############
68 | class WindowAttention(nn.Module):
69 | def __init__(self, dim, win_size,num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
70 |
71 | super().__init__()
72 | self.dim = dim
73 | self.win_size = win_size # Wh, Ww
74 | self.num_heads = num_heads
75 | head_dim = dim // num_heads
76 | self.scale = qk_scale or head_dim ** -0.5
77 |
78 | # define a parameter table of relative position bias
79 | self.relative_position_bias_table = nn.Parameter(
80 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
81 |
82 | # get pair-wise relative position index for each token inside the window
83 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1]
84 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1]
85 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
86 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
87 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
88 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
89 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0
90 | relative_coords[:, :, 1] += self.win_size[1] - 1
91 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
92 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
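        # Worked example for win_size == (2, 2): the flattened coords are
        # (0,0), (0,1), (1,0), (1,1). After the shifts above, each pairwise
        # offset (dy, dx) in [0,2]x[0,2] maps to index 3*dy + dx, a unique
        # value in [0, 8] indexing the (2*2-1)*(2*2-1) = 9-row bias table.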
93 | self.register_buffer("relative_position_index", relative_position_index)
94 | trunc_normal_(self.relative_position_bias_table, std=.02)
95 |
96 | if token_projection =='linear':
97 | self.qkv = LinearProjection(dim,num_heads,dim//num_heads,bias=qkv_bias)
98 | else:
99 | raise Exception("Projection error!")
100 |
101 | self.token_projection = token_projection
102 | self.attn_drop = nn.Dropout(attn_drop)
103 | self.proj = nn.Linear(dim, dim)
104 | self.proj_drop = nn.Dropout(proj_drop)
105 |
106 | self.softmax = nn.Softmax(dim=-1)
107 |
108 | def forward(self, x, attn_kv=None, mask=None):
109 | B_, N, C = x.shape
110 | q, k, v = self.qkv(x,attn_kv)
111 | q = q * self.scale
112 | attn = (q @ k.transpose(-2, -1))
113 |
114 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
115 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH
116 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
117 | ratio = attn.size(-1)//relative_position_bias.size(-1)
118 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d = ratio)
119 |
120 | attn = attn + relative_position_bias.unsqueeze(0)
121 |
122 | if mask is not None:
123 | nW = mask.shape[0]
124 | mask = repeat(mask, 'nW m n -> nW m (n d)',d = ratio)
125 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N*ratio) + mask.unsqueeze(1).unsqueeze(0)
126 | attn = attn.view(-1, self.num_heads, N, N*ratio)
127 | attn = self.softmax(attn)
128 | else:
129 | attn = self.softmax(attn)
130 |
131 | attn = self.attn_drop(attn)
132 |
133 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
134 | x = self.proj(x)
135 | x = self.proj_drop(x)
136 | return x
137 |
138 | def extra_repr(self) -> str:
139 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}'
140 |
141 | ########### window-based self-attention #############
142 | class WindowAttention_sparse(nn.Module):
143 | def __init__(self, dim, win_size,num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
144 |
145 | super().__init__()
146 | self.dim = dim
147 | self.win_size = win_size # Wh, Ww
148 | self.num_heads = num_heads
149 | head_dim = dim // num_heads
150 | self.scale = qk_scale or head_dim ** -0.5
151 |
152 | # define a parameter table of relative position bias
153 | self.relative_position_bias_table = nn.Parameter(
154 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
155 |
156 | # get pair-wise relative position index for each token inside the window
157 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1]
158 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1]
159 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
160 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
161 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
162 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
163 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0
164 | relative_coords[:, :, 1] += self.win_size[1] - 1
165 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
166 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
167 | self.register_buffer("relative_position_index", relative_position_index)
168 | trunc_normal_(self.relative_position_bias_table, std=.02)
169 |
170 | if token_projection =='linear':
171 | self.qkv = LinearProjection(dim,num_heads,dim//num_heads,bias=qkv_bias)
172 | else:
173 | raise Exception("Projection error!")
174 |
175 | self.token_projection = token_projection
176 | self.attn_drop = nn.Dropout(attn_drop)
177 | self.proj = nn.Linear(dim, dim)
178 | self.proj_drop = nn.Dropout(proj_drop)
179 |
180 | self.softmax = nn.Softmax(dim=-1)
181 | self.relu = nn.ReLU()
182 | self.w = nn.Parameter(torch.ones(2))
183 |
184 | def forward(self, x, attn_kv=None, mask=None):
185 | B_, N, C = x.shape
186 | q, k, v = self.qkv(x,attn_kv)
187 | q = q * self.scale
188 | attn = (q @ k.transpose(-2, -1))
189 |
190 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
191 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH
192 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
193 | ratio = attn.size(-1)//relative_position_bias.size(-1)
194 | relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d = ratio)
195 |
196 | attn = attn + relative_position_bias.unsqueeze(0)
197 |
198 | if mask is not None:
199 | nW = mask.shape[0]
200 | mask = repeat(mask, 'nW m n -> nW m (n d)',d = ratio)
201 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N*ratio) + mask.unsqueeze(1).unsqueeze(0)
202 | attn = attn.view(-1, self.num_heads, N, N*ratio)
203 | attn0 = self.softmax(attn)
204 | attn1 = self.relu(attn)**2  # B, heads, N, N*ratio
205 | else:
206 | attn0 = self.softmax(attn)
207 | attn1 = self.relu(attn)**2
208 | w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w))
209 | w2 = torch.exp(self.w[1]) / torch.sum(torch.exp(self.w))
210 | attn = attn0*w1+attn1*w2
211 | attn = self.attn_drop(attn)
212 |
213 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
214 | x = self.proj(x)
215 | x = self.proj_drop(x)
216 | return x
217 |
218 | def extra_repr(self) -> str:
219 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}'
220 |
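# Illustration (toy shapes; not executed by the class above): the sparse attention
# blends the dense softmax map with a ReLU()**2 map via a learned two-way softmax
# over self.w:
#
#   attn = torch.randn(2, 4, 64, 64)            # B, heads, N, N
#   w = torch.softmax(torch.ones(2), dim=0)     # same weights as the exp/sum form above
#   blended = w[0] * attn.softmax(-1) + w[1] * torch.relu(attn) ** 2
#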
221 | ########### self-attention #############
222 | class Attention(nn.Module):
223 | def __init__(self, dim,num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
224 |
225 | super().__init__()
226 | self.dim = dim
227 | self.num_heads = num_heads
228 | head_dim = dim // num_heads
229 | self.scale = qk_scale or head_dim ** -0.5
230 |
231 | self.qkv = LinearProjection(dim,num_heads,dim//num_heads,bias=qkv_bias)
232 |
233 | self.token_projection = token_projection
234 | self.attn_drop = nn.Dropout(attn_drop)
235 | self.proj = nn.Linear(dim, dim)
236 | self.proj_drop = nn.Dropout(proj_drop)
237 |
238 | self.softmax = nn.Softmax(dim=-1)
239 |
240 | def forward(self, x, attn_kv=None, mask=None):
241 | B_, N, C = x.shape
242 | q, k, v = self.qkv(x,attn_kv)
243 | q = q * self.scale
244 | attn = (q @ k.transpose(-2, -1))
245 | if mask is not None:
246 | nW = mask.shape[0]
247 | # mask = repeat(mask, 'nW m n -> nW m (n d)',d = ratio)
248 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
249 | attn = attn.view(-1, self.num_heads, N, N)
250 | attn = self.softmax(attn)
251 | else:
252 | attn = self.softmax(attn)
253 |
254 | attn = self.attn_drop(attn)
255 |
256 | x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
257 | x = self.proj(x)
258 | x = self.proj_drop(x)
259 | return x
260 |
261 | def extra_repr(self) -> str:
262 | return f'dim={self.dim}, num_heads={self.num_heads}'
263 |
264 |
265 |
266 | #########################################
267 | ########### feed-forward network #############
268 | class Mlp(nn.Module):
269 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
270 | super().__init__()
271 | out_features = out_features or in_features
272 | hidden_features = hidden_features or in_features
273 | self.fc1 = nn.Linear(in_features, hidden_features)
274 | self.act = act_layer()
275 | self.fc2 = nn.Linear(hidden_features, out_features)
276 | self.drop = nn.Dropout(drop)
277 | self.in_features = in_features
278 | self.hidden_features = hidden_features
279 | self.out_features = out_features
280 |
281 | def forward(self, x):
282 | x = self.fc1(x)
283 | x = self.act(x)
284 | x = self.drop(x)
285 | x = self.fc2(x)
286 | x = self.drop(x)
287 | return x
288 |
289 |
290 | class LeFF(nn.Module):
291 | def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU,drop = 0., use_eca=False):
292 | super().__init__()
293 | self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim),
294 | act_layer())
295 | self.dwconv = nn.Sequential(nn.Conv2d(hidden_dim,hidden_dim,groups=hidden_dim,kernel_size=3,stride=1,padding=1),
296 | act_layer())
297 | self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
298 | self.dim = dim
299 | self.hidden_dim = hidden_dim
300 | self.eca = nn.Identity()  # eca disabled: use_eca is accepted but not used
301 |
302 | def forward(self, x):
303 | # bs x hw x c
304 | bs, hw, c = x.size()
305 | hh = int(math.sqrt(hw))
306 |
307 | x = self.linear1(x)
308 |
309 | # spatial restore
310 | x = rearrange(x, ' b (h w) (c) -> b c h w ', h = hh, w = hh)
311 | # bs, hidden_dim, hh, hh
312 |
313 | x = self.dwconv(x)
314 |
315 | # flatten
316 | x = rearrange(x, ' b c h w -> b (h w) c', h = hh, w = hh)
317 |
318 | x = self.linear2(x)
319 | x = self.eca(x)
320 |
321 | return x
322 |
323 |
324 | class FRFN(nn.Module):
325 | def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU,drop = 0., use_eca=False):
326 | super().__init__()
327 | self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim*2),
328 | act_layer())
329 | self.dwconv = nn.Sequential(nn.Conv2d(hidden_dim,hidden_dim,groups=hidden_dim,kernel_size=3,stride=1,padding=1),
330 | act_layer())
331 | self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
332 | self.dim = dim
333 | self.hidden_dim = hidden_dim
334 |
335 | self.dim_conv = self.dim // 4
336 | self.dim_untouched = self.dim - self.dim_conv
337 | self.partial_conv3 = nn.Conv2d(self.dim_conv, self.dim_conv, 3, 1, 1, bias=False)
338 |
339 | def forward(self, x):
340 | # bs x hw x c
341 | bs, hw, c = x.size()
342 | hh = int(math.sqrt(hw))
343 |
344 |
345 | # spatial restore
346 | x = rearrange(x, ' b (h w) (c) -> b c h w ', h = hh, w = hh)
347 |
348 | x1, x2 = torch.split(x, [self.dim_conv, self.dim_untouched], dim=1)
349 | x1 = self.partial_conv3(x1)
350 | x = torch.cat((x1, x2), 1)
351 |
352 | # flatten
353 | x = rearrange(x, ' b c h w -> b (h w) c', h = hh, w = hh)
354 |
355 | x = self.linear1(x)
356 | # gate mechanism
357 | x_1, x_2 = x.chunk(2, dim=-1)
358 |
359 | x_1 = rearrange(x_1, ' b (h w) (c) -> b c h w ', h = hh, w = hh)
360 | x_1 = self.dwconv(x_1)
361 | x_1 = rearrange(x_1, ' b c h w -> b (h w) c', h = hh, w = hh)
362 | x = x_1 * x_2
363 |
364 | x = self.linear2(x)
365 | # x = self.eca(x)
366 |
367 | return x
368 |
369 |
370 |
371 | #########################################
372 | ########### window operation#############
373 | def window_partition(x, win_size, dilation_rate=1):
374 | B, H, W, C = x.shape
375 | if dilation_rate !=1:
376 | x = x.permute(0,3,1,2) # B, C, H, W
377 | assert isinstance(dilation_rate, int), 'dilation_rate should be an int'
378 | x = F.unfold(x, kernel_size=win_size,dilation=dilation_rate,padding=4*(dilation_rate-1),stride=win_size) # B, C*Wh*Ww, H/Wh*W/Ww
379 | windows = x.permute(0,2,1).contiguous().view(-1, C, win_size, win_size) # B' ,C ,Wh ,Ww
380 | windows = windows.permute(0,2,3,1).contiguous() # B' ,Wh ,Ww ,C
381 | else:
382 | x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
383 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C) # B' ,Wh ,Ww ,C
384 | return windows
385 |
386 | def window_reverse(windows, win_size, H, W, dilation_rate=1):
387 | # B' ,Wh ,Ww ,C
388 | B = int(windows.shape[0] / (H * W / win_size / win_size))
389 | x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
390 | if dilation_rate !=1:
391 | x = windows.permute(0,5,3,4,1,2).contiguous() # B, C*Wh*Ww, H/Wh*W/Ww
392 | x = F.fold(x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4*(dilation_rate-1),stride=win_size)
393 | else:
394 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
395 | return x
396 |
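# Round-trip check (illustrative, non-dilated path): partitioning then reversing
# recovers the input exactly:
#
#   x = torch.randn(2, 16, 16, 3)    # B, H, W, C
#   assert torch.equal(window_reverse(window_partition(x, 8), 8, 16, 16), x)
#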
397 | #########################################
398 |
399 | # Downsample Block
400 | class Downsample(nn.Module):
401 | def __init__(self, in_channel, out_channel):
402 | super(Downsample, self).__init__()
403 | self.conv = nn.Sequential(
404 | nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1),
405 | )
406 | self.in_channel = in_channel
407 | self.out_channel = out_channel
408 |
409 | def forward(self, x):
410 | B, L, C = x.shape
411 |
412 | H = int(math.sqrt(L))
413 | W = int(math.sqrt(L))
414 | x = x.transpose(1, 2).contiguous().view(B, C, H, W)
415 | out = self.conv(x).flatten(2).transpose(1,2).contiguous() # B H*W C
416 | return out
417 |
418 | # Upsample Block
419 | class Upsample(nn.Module):
420 | def __init__(self, in_channel, out_channel):
421 | super(Upsample, self).__init__()
422 | self.deconv = nn.Sequential(
423 | nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2),
424 | )
425 | self.in_channel = in_channel
426 | self.out_channel = out_channel
427 |
428 | def forward(self, x):
429 | B, L, C = x.shape
430 | H = int(math.sqrt(L))
431 | W = int(math.sqrt(L))
432 | x = x.transpose(1, 2).contiguous().view(B, C, H, W)
433 | out = self.deconv(x).flatten(2).transpose(1,2).contiguous() # B H*W C
434 | return out
435 |
436 |
437 | # Input Projection
438 | class InputProj(nn.Module):
439 | def __init__(self, in_channel=3, out_channel=64, kernel_size=3, stride=1, norm_layer=None,act_layer=nn.LeakyReLU):
440 | super().__init__()
441 | self.proj = nn.Sequential(
442 | nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=kernel_size//2),
443 | act_layer(inplace=True)
444 | )
445 | if norm_layer is not None:
446 | self.norm = norm_layer(out_channel)
447 | else:
448 | self.norm = None
449 | self.in_channel = in_channel
450 | self.out_channel = out_channel
451 |
452 | def forward(self, x):
453 | B, C, H, W = x.shape
454 | x = self.proj(x).flatten(2).transpose(1, 2).contiguous() # B H*W C
455 | if self.norm is not None:
456 | x = self.norm(x)
457 | return x
458 |
459 | # Output Projection
460 | class OutputProj(nn.Module):
461 | def __init__(self, in_channel=64, out_channel=3, kernel_size=3, stride=1, norm_layer=None,act_layer=None):
462 | super().__init__()
463 | self.proj = nn.Sequential(
464 | nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride, padding=kernel_size//2),
465 | )
466 | if act_layer is not None:
467 | self.proj.add_module('act', act_layer(inplace=True))  # add_module takes (name, module)
468 | if norm_layer is not None:
469 | self.norm = norm_layer(out_channel)
470 | else:
471 | self.norm = None
472 | self.in_channel = in_channel
473 | self.out_channel = out_channel
474 |
475 | def forward(self, x):
476 | B, L, C = x.shape
477 | H = int(math.sqrt(L))
478 | W = int(math.sqrt(L))
479 | x = x.transpose(1, 2).view(B, C, H, W)
480 | x = self.proj(x)
481 | if self.norm is not None:
482 | x = self.norm(x)
483 | return x
484 |
485 | #########################################
486 | ###########Transformer Block#############
487 | class TransformerBlock(nn.Module):
488 | def __init__(self, dim, input_resolution, num_heads, win_size=8, shift_size=0,
489 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
490 | act_layer=nn.GELU, norm_layer=nn.LayerNorm,token_projection='linear',token_mlp='leff',att=True,sparseAtt=False):
491 | super().__init__()
492 |
493 | self.att = att
494 | self.sparseAtt = sparseAtt
495 |
496 | self.dim = dim
497 | self.input_resolution = input_resolution
498 | self.num_heads = num_heads
499 | self.win_size = win_size
500 | self.shift_size = shift_size
501 | self.mlp_ratio = mlp_ratio
502 | self.token_mlp = token_mlp
503 | if min(self.input_resolution) <= self.win_size:
504 | self.shift_size = 0
505 | self.win_size = min(self.input_resolution)
506 | assert 0 <= self.shift_size < self.win_size, "shift_size must be in [0, win_size)"
507 |
508 | if self.att:
509 | self.norm1 = norm_layer(dim)
510 | if self.sparseAtt:
511 | self.attn = WindowAttention_sparse(
512 | dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
513 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
514 | token_projection=token_projection)
515 | else:
516 | self.attn = WindowAttention(
517 | dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
518 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
519 | token_projection=token_projection)
520 |
521 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
522 | self.norm2 = norm_layer(dim)
523 | mlp_hidden_dim = int(dim * mlp_ratio)
524 | if token_mlp in ['ffn','mlp']:
525 | self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim,act_layer=act_layer, drop=drop)
526 | elif token_mlp=='leff':
527 | self.mlp = LeFF(dim,mlp_hidden_dim,act_layer=act_layer, drop=drop)
528 | elif token_mlp=='frfn':
529 | self.mlp = FRFN(dim,mlp_hidden_dim,act_layer=act_layer, drop=drop)
530 | else:
531 | raise Exception("Unsupported token_mlp: {}".format(token_mlp))
532 |
533 |
534 | def with_pos_embed(self, tensor, pos):
535 | return tensor if pos is None else tensor + pos
536 |
537 | def extra_repr(self) -> str:
538 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
539 | f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
540 |
541 | def forward(self, x, mask=None):
542 | B, L, C = x.shape
543 | H = int(math.sqrt(L))
544 | W = int(math.sqrt(L))
545 |
546 | ## input mask
547 | if mask is not None:
548 | input_mask = F.interpolate(mask, size=(H,W)).permute(0,2,3,1)
549 | input_mask_windows = window_partition(input_mask, self.win_size) # nW, win_size, win_size, 1
550 | attn_mask = input_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
551 | attn_mask = attn_mask.unsqueeze(2)*attn_mask.unsqueeze(1) # nW, win_size*win_size, win_size*win_size
552 | attn_mask = attn_mask.masked_fill(attn_mask!=0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
553 | else:
554 | attn_mask = None
555 |
556 | ## shift mask
557 | if self.shift_size > 0:
558 | # calculate attention mask for SW-MSA
559 | shift_mask = torch.zeros((1, H, W, 1)).type_as(x)
560 | h_slices = (slice(0, -self.win_size),
561 | slice(-self.win_size, -self.shift_size),
562 | slice(-self.shift_size, None))
563 | w_slices = (slice(0, -self.win_size),
564 | slice(-self.win_size, -self.shift_size),
565 | slice(-self.shift_size, None))
566 | cnt = 0
567 | for h in h_slices:
568 | for w in w_slices:
569 | shift_mask[:, h, w, :] = cnt
570 | cnt += 1
571 | shift_mask_windows = window_partition(shift_mask, self.win_size) # nW, win_size, win_size, 1
572 | shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size) # nW, win_size*win_size
573 | shift_attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(2) # nW, win_size*win_size, win_size*win_size
574 | shift_attn_mask = shift_attn_mask.masked_fill(shift_attn_mask != 0, float(-100.0)).masked_fill(shift_attn_mask == 0, float(0.0))
575 | attn_mask = attn_mask + shift_attn_mask if attn_mask is not None else shift_attn_mask
576 |
577 |
578 | shortcut = x
579 |
580 | if self.att:
581 | x = self.norm1(x)
582 | x = x.view(B, H, W, C)
583 |
584 | # cyclic shift
585 | if self.shift_size > 0:
586 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
587 | else:
588 | shifted_x = x
589 |
590 | # partition windows
591 | x_windows = window_partition(shifted_x, self.win_size)  # nW*B, win_size, win_size, C
592 | x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C
593 |
594 |
595 | # W-MSA/SW-MSA
596 | attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, win_size*win_size, C
597 |
598 | # merge windows
599 | attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C)
600 | shifted_x = window_reverse(attn_windows, self.win_size, H, W) # B H' W' C
601 |
602 | # reverse cyclic shift
603 | if self.shift_size > 0:
604 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
605 | else:
606 | x = shifted_x
607 | x = x.view(B, H * W, C)
608 | x = shortcut + self.drop_path(x)
609 |
610 | # FFN
611 | x = x + self.drop_path(self.mlp(self.norm2(x)))
612 | del attn_mask
613 | return x
614 |
615 |
616 | #########################################
617 | ########### Basic layer of AST ################
618 | class BasicASTLayer(nn.Module):
619 | def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size,
620 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
621 | drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
622 | token_projection='linear',token_mlp='ffn', shift_flag=True,att=False,sparseAtt=False):
623 |
624 | super().__init__()
625 | self.att = att
626 | self.sparseAtt = sparseAtt
627 | self.dim = dim
628 | self.input_resolution = input_resolution
629 | self.depth = depth
630 | self.use_checkpoint = use_checkpoint
631 | # build blocks
632 | if shift_flag:
633 | self.blocks = nn.ModuleList([
634 | TransformerBlock(dim=dim, input_resolution=input_resolution,
635 | num_heads=num_heads, win_size=win_size,
636 | shift_size=0 if (i % 2 == 0) else win_size // 2,
637 | mlp_ratio=mlp_ratio,
638 | qkv_bias=qkv_bias, qk_scale=qk_scale,
639 | drop=drop, attn_drop=attn_drop,
640 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
641 | norm_layer=norm_layer,token_projection=token_projection,token_mlp=token_mlp,att=self.att,sparseAtt=self.sparseAtt)
642 | for i in range(depth)])
643 | else:
644 | self.blocks = nn.ModuleList([
645 | TransformerBlock(dim=dim, input_resolution=input_resolution,
646 | num_heads=num_heads, win_size=win_size,
647 | shift_size=0,
648 | mlp_ratio=mlp_ratio,
649 | qkv_bias=qkv_bias, qk_scale=qk_scale,
650 | drop=drop, attn_drop=attn_drop,
651 | drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
652 | norm_layer=norm_layer,token_projection=token_projection,token_mlp=token_mlp,att=self.att,sparseAtt=self.sparseAtt)
653 | for i in range(depth)])
654 |
655 | def extra_repr(self) -> str:
656 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
657 |
658 | def forward(self, x, mask=None):
659 | for blk in self.blocks:
660 | if self.use_checkpoint:
661 | x = checkpoint.checkpoint(blk, x, mask)  # pass mask through when checkpointing
662 | else:
663 | x = blk(x,mask)
664 | return x
665 |
666 |
667 |
668 | class AST(nn.Module):
669 | def __init__(self, img_size=256, in_chans=3, dd_in=3,
670 | embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2], num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2],
671 | win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
672 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
673 | norm_layer=nn.LayerNorm, patch_norm=True,
674 | use_checkpoint=False, token_projection='linear', token_mlp='leff',
675 | dowsample=Downsample, upsample=Upsample, shift_flag=True,**kwargs):
676 | super().__init__()
677 |
678 | self.num_enc_layers = len(depths)//2
679 | self.num_dec_layers = len(depths)//2
680 | self.embed_dim = embed_dim
681 | self.patch_norm = patch_norm
682 | self.mlp_ratio = mlp_ratio
683 | self.token_projection = token_projection
684 | self.mlp = token_mlp
685 | self.win_size =win_size
686 | self.reso = img_size
687 | self.pos_drop = nn.Dropout(p=drop_rate)
688 | self.dd_in = dd_in
689 |
690 | # stochastic depth
691 | enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))]
692 | conv_dpr = [drop_path_rate]*depths[4]
693 | dec_dpr = enc_dpr[::-1]
694 |
695 | # build layers
696 |
697 | # Input/Output
698 | self.input_proj = InputProj(in_channel=dd_in, out_channel=embed_dim, kernel_size=3, stride=1, act_layer=nn.LeakyReLU)
699 | self.output_proj = OutputProj(in_channel=2*embed_dim, out_channel=in_chans, kernel_size=3, stride=1)
700 |
701 | # Encoder
702 | self.encoderlayer_0 = BasicASTLayer(dim=embed_dim,
703 | output_dim=embed_dim,
704 | input_resolution=(img_size,
705 | img_size),
706 | depth=depths[0],
707 | num_heads=num_heads[0],
708 | win_size=win_size,
709 | mlp_ratio=self.mlp_ratio,
710 | qkv_bias=qkv_bias, qk_scale=qk_scale,
711 | drop=drop_rate, attn_drop=attn_drop_rate,
712 | drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])],
713 | norm_layer=norm_layer,
714 | use_checkpoint=use_checkpoint,
715 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=False,sparseAtt=False)
716 | self.dowsample_0 = dowsample(embed_dim, embed_dim*2)
717 | self.encoderlayer_1 = BasicASTLayer(dim=embed_dim*2,
718 | output_dim=embed_dim*2,
719 | input_resolution=(img_size // 2,
720 | img_size // 2),
721 | depth=depths[1],
722 | num_heads=num_heads[1],
723 | win_size=win_size,
724 | mlp_ratio=self.mlp_ratio,
725 | qkv_bias=qkv_bias, qk_scale=qk_scale,
726 | drop=drop_rate, attn_drop=attn_drop_rate,
727 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
728 | norm_layer=norm_layer,
729 | use_checkpoint=use_checkpoint,
730 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=False,sparseAtt=False)
731 | self.dowsample_1 = dowsample(embed_dim*2, embed_dim*4)
732 | self.encoderlayer_2 = BasicASTLayer(dim=embed_dim*4,
733 | output_dim=embed_dim*4,
734 | input_resolution=(img_size // (2 ** 2),
735 | img_size // (2 ** 2)),
736 | depth=depths[2],
737 | num_heads=num_heads[2],
738 | win_size=win_size,
739 | mlp_ratio=self.mlp_ratio,
740 | qkv_bias=qkv_bias, qk_scale=qk_scale,
741 | drop=drop_rate, attn_drop=attn_drop_rate,
742 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
743 | norm_layer=norm_layer,
744 | use_checkpoint=use_checkpoint,
745 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=False,sparseAtt=False)
746 | self.dowsample_2 = dowsample(embed_dim*4, embed_dim*8)
747 | self.encoderlayer_3 = BasicASTLayer(dim=embed_dim*8,
748 | output_dim=embed_dim*8,
749 | input_resolution=(img_size // (2 ** 3),
750 | img_size // (2 ** 3)),
751 | depth=depths[3],
752 | num_heads=num_heads[3],
753 | win_size=win_size,
754 | mlp_ratio=self.mlp_ratio,
755 | qkv_bias=qkv_bias, qk_scale=qk_scale,
756 | drop=drop_rate, attn_drop=attn_drop_rate,
757 | drop_path=enc_dpr[sum(depths[:3]):sum(depths[:4])],
758 | norm_layer=norm_layer,
759 | use_checkpoint=use_checkpoint,
760 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=False,sparseAtt=False)
761 | self.dowsample_3 = dowsample(embed_dim*8, embed_dim*16)
762 |
763 | # Bottleneck
764 | self.conv = BasicASTLayer(dim=embed_dim*16,
765 | output_dim=embed_dim*16,
766 | input_resolution=(img_size // (2 ** 4),
767 | img_size // (2 ** 4)),
768 | depth=depths[4],
769 | num_heads=num_heads[4],
770 | win_size=win_size,
771 | mlp_ratio=self.mlp_ratio,
772 | qkv_bias=qkv_bias, qk_scale=qk_scale,
773 | drop=drop_rate, attn_drop=attn_drop_rate,
774 | drop_path=conv_dpr,
775 | norm_layer=norm_layer,
776 | use_checkpoint=use_checkpoint,
777 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=True,sparseAtt=True)
778 |
779 | # Decoder
780 | self.upsample_0 = upsample(embed_dim*16, embed_dim*8)
781 | self.decoderlayer_0 = BasicASTLayer(dim=embed_dim*16,
782 | output_dim=embed_dim*16,
783 | input_resolution=(img_size // (2 ** 3),
784 | img_size // (2 ** 3)),
785 | depth=depths[5],
786 | num_heads=num_heads[5],
787 | win_size=win_size,
788 | mlp_ratio=self.mlp_ratio,
789 | qkv_bias=qkv_bias, qk_scale=qk_scale,
790 | drop=drop_rate, attn_drop=attn_drop_rate,
791 | drop_path=dec_dpr[:depths[5]],
792 | norm_layer=norm_layer,
793 | use_checkpoint=use_checkpoint,
794 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=True,sparseAtt=True)
795 | self.upsample_1 = upsample(embed_dim*16, embed_dim*4)
796 | self.decoderlayer_1 = BasicASTLayer(dim=embed_dim*8,
797 | output_dim=embed_dim*8,
798 | input_resolution=(img_size // (2 ** 2),
799 | img_size // (2 ** 2)),
800 | depth=depths[6],
801 | num_heads=num_heads[6],
802 | win_size=win_size,
803 | mlp_ratio=self.mlp_ratio,
804 | qkv_bias=qkv_bias, qk_scale=qk_scale,
805 | drop=drop_rate, attn_drop=attn_drop_rate,
806 | drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])],
807 | norm_layer=norm_layer,
808 | use_checkpoint=use_checkpoint,
809 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=True,sparseAtt=True)
810 | self.upsample_2 = upsample(embed_dim*8, embed_dim*2)
811 | self.decoderlayer_2 = BasicASTLayer(dim=embed_dim*4,
812 | output_dim=embed_dim*4,
813 | input_resolution=(img_size // 2,
814 | img_size // 2),
815 | depth=depths[7],
816 | num_heads=num_heads[7],
817 | win_size=win_size,
818 | mlp_ratio=self.mlp_ratio,
819 | qkv_bias=qkv_bias, qk_scale=qk_scale,
820 | drop=drop_rate, attn_drop=attn_drop_rate,
821 | drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])],
822 | norm_layer=norm_layer,
823 | use_checkpoint=use_checkpoint,
824 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=True,sparseAtt=True)
825 | self.upsample_3 = upsample(embed_dim*4, embed_dim)
826 | self.decoderlayer_3 = BasicASTLayer(dim=embed_dim*2,
827 | output_dim=embed_dim*2,
828 | input_resolution=(img_size,
829 | img_size),
830 | depth=depths[8],
831 | num_heads=num_heads[8],
832 | win_size=win_size,
833 | mlp_ratio=self.mlp_ratio,
834 | qkv_bias=qkv_bias, qk_scale=qk_scale,
835 | drop=drop_rate, attn_drop=attn_drop_rate,
836 | drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])],
837 | norm_layer=norm_layer,
838 | use_checkpoint=use_checkpoint,
839 | token_projection=token_projection,token_mlp=token_mlp,shift_flag=shift_flag,att=True,sparseAtt=True)
840 |
841 | self.apply(self._init_weights)
842 |
843 | def _init_weights(self, m):
844 | if isinstance(m, nn.Linear):
845 | trunc_normal_(m.weight, std=.02)
846 | if isinstance(m, nn.Linear) and m.bias is not None:
847 | nn.init.constant_(m.bias, 0)
848 | elif isinstance(m, nn.LayerNorm):
849 | nn.init.constant_(m.bias, 0)
850 | nn.init.constant_(m.weight, 1.0)
851 |
852 | @torch.jit.ignore
853 | def no_weight_decay(self):
854 | return {'absolute_pos_embed'}
855 |
856 | @torch.jit.ignore
857 | def no_weight_decay_keywords(self):
858 | return {'relative_position_bias_table'}
859 |
860 | def extra_repr(self) -> str:
861 | return f"embed_dim={self.embed_dim}, token_projection={self.token_projection}, token_mlp={self.mlp},win_size={self.win_size}"
862 |
863 | def forward(self, x, mask=None):
864 | # Input Projection
865 | y = self.input_proj(x)
866 | y = self.pos_drop(y)
867 | #Encoder
868 | conv0 = self.encoderlayer_0(y,mask=mask)
869 | pool0 = self.dowsample_0(conv0)
870 | conv1 = self.encoderlayer_1(pool0,mask=mask)
871 | pool1 = self.dowsample_1(conv1)
872 | conv2 = self.encoderlayer_2(pool1,mask=mask)
873 | pool2 = self.dowsample_2(conv2)
874 | conv3 = self.encoderlayer_3(pool2,mask=mask)
875 | pool3 = self.dowsample_3(conv3)
876 |
877 | # Bottleneck
878 | conv4 = self.conv(pool3, mask=mask)
879 |
880 | #Decoder
881 | up0 = self.upsample_0(conv4)
882 | deconv0 = torch.cat([up0,conv3],-1)
883 | deconv0 = self.decoderlayer_0(deconv0,mask=mask)
884 |
885 | up1 = self.upsample_1(deconv0)
886 | deconv1 = torch.cat([up1,conv2],-1)
887 | deconv1 = self.decoderlayer_1(deconv1,mask=mask)
888 |
889 | up2 = self.upsample_2(deconv1)
890 | deconv2 = torch.cat([up2,conv1],-1)
891 | deconv2 = self.decoderlayer_2(deconv2,mask=mask)
892 |
893 | up3 = self.upsample_3(deconv2)
894 | deconv3 = torch.cat([up3,conv0],-1)
895 | deconv3 = self.decoderlayer_3(deconv3,mask=mask)
896 |
897 | # Output Projection
898 | y = self.output_proj(deconv3)
899 | return x + y if self.dd_in ==3 else y
900 |
--------------------------------------------------------------------------------
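
Usage note: a minimal sketch of driving the AST model defined above (assuming this listing is saved as model.py and the packages from requirements.txt are installed):

    import torch
    from model import AST

    # Defaults mirror the AST_B configuration used by the scripts in script/:
    # 256x256 RGB input, embed_dim=32, FRFN token MLP.
    net = AST(img_size=256, in_chans=3, dd_in=3, embed_dim=32, token_mlp='frfn').eval()
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        y = net(x)      # dd_in == 3, so the model returns the residual sum x + correction
    print(y.shape)      # torch.Size([1, 3, 256, 256])
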
/options.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | class Options():
4 | """docstring for Options"""
5 | def __init__(self):
6 | pass
7 |
8 | def init(self, parser):
9 | # global settings
10 | parser.add_argument('--batch_size', type=int, default=32, help='batch size')
11 | parser.add_argument('--nepoch', type=int, default=250, help='training epochs')
12 | parser.add_argument('--train_workers', type=int, default=0, help='train_dataloader workers')
13 | parser.add_argument('--eval_workers', type=int, default=0, help='eval_dataloader workers')
14 | parser.add_argument('--dataset', type=str, default ='SIDD')
15 | parser.add_argument('--pretrain_weights',type=str, default='./log/AST_B/models/model_best.pth', help='path of pretrained_weights')
16 | parser.add_argument('--optimizer', type=str, default ='adamw', help='optimizer for training')
17 | parser.add_argument('--lr_initial', type=float, default=0.0002, help='initial learning rate')
18 | parser.add_argument('--step_lr', type=int, default=50, help='step size for the StepLR scheduler')
19 | parser.add_argument('--weight_decay', type=float, default=0.02, help='weight decay')
20 | parser.add_argument('--gpu', type=str, default='0,1', help='GPUs')
21 | parser.add_argument('--arch', type=str, default='AST_B', help='architecture')
22 | parser.add_argument('--mode', type=str, default ='denoising', help='image restoration mode')
23 | parser.add_argument('--dd_in', type=int, default=3, help='dd_in')
24 |
25 | # args for saving
26 | parser.add_argument('--save_dir', type=str, default ='./logs/', help='save dir')
27 | parser.add_argument('--save_images', action='store_true',default=False)
28 | parser.add_argument('--env', type=str, default ='_', help='env')
29 | parser.add_argument('--checkpoint', type=int, default=50, help='checkpoint')
30 |
31 | # args for Uformer
32 | parser.add_argument('--norm_layer', type=str, default ='nn.LayerNorm', help='normalize layer in transformer')
33 | parser.add_argument('--embed_dim', type=int, default=32, help='dim of embedding features')
34 | parser.add_argument('--win_size', type=int, default=8, help='window size of self-attention')
35 | parser.add_argument('--token_projection', type=str,default='linear', help='linear/conv token projection')
36 | parser.add_argument('--token_mlp', type=str,default='leff', help='ffn/leff/frfn token mlp')
37 | parser.add_argument('--att_se', action='store_true', default=False, help='se after sa')
38 | parser.add_argument('--modulator', action='store_true', default=False, help='multi-scale modulator')
39 |
40 | # args for vit
41 | parser.add_argument('--vit_dim', type=int, default=256, help='vit hidden_dim')
42 | parser.add_argument('--vit_depth', type=int, default=12, help='vit depth')
43 | parser.add_argument('--vit_nheads', type=int, default=8, help='vit number of heads')
44 | parser.add_argument('--vit_mlp_dim', type=int, default=512, help='vit mlp_dim')
45 | parser.add_argument('--vit_patch_size', type=int, default=16, help='vit patch_size')
46 | parser.add_argument('--global_skip', action='store_true', default=False, help='global skip connection')
47 | parser.add_argument('--local_skip', action='store_true', default=False, help='local skip connection')
48 | parser.add_argument('--vit_share', action='store_true', default=False, help='share vit module')
49 |
50 | # args for training
51 | parser.add_argument('--train_ps', type=int, default=256, help='patch size of training sample')
52 | parser.add_argument('--val_ps', type=int, default=256, help='patch size of validation sample')
53 | parser.add_argument('--resume', action='store_true',default=False)
54 | parser.add_argument('--retrain', action='store_true', default=False)
55 | parser.add_argument('--train_dir', type=str, default ='./rain_syn/DDN-Data/train/', help='dir of train data')
56 | parser.add_argument('--val_dir', type=str, default ='./rain_syn/DDN-Data/test/', help='dir of validation data')
57 | parser.add_argument('--warmup', action='store_true', default=False, help='warmup')
58 | parser.add_argument('--warmup_epochs', type=int,default=3, help='epochs for warmup')
59 |
60 | # ddp
61 | parser.add_argument("--local_rank", type=int,default=-1,help='DDP parameter, do not modify')#不需要赋值,启动命令 torch.distributed.launch会自动赋值
62 | parser.add_argument("--distribute",action='store_true',help='whether using multi gpu train')
63 | parser.add_argument("--distribute_mode",type=str,default='DDP',help="using which mode to ")
64 | return parser
65 |
--------------------------------------------------------------------------------
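
Usage note: the training and test entry points build their argument parsers through Options, mirroring train/train_dehaze.py:

    import argparse
    import options

    # Attach all shared options to a fresh parser, then parse the command line.
    opt = options.Options().init(argparse.ArgumentParser(description='Image dehazing')).parse_args()
    print(opt.arch, opt.lr_initial, opt.token_mlp)
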
/requirements.txt:
--------------------------------------------------------------------------------
1 | torch==1.8.0
2 | torchvision==0.9.0
3 | matplotlib
4 | scikit-image
5 | opencv-python
6 | yacs
7 | joblib
8 | natsort
9 | h5py
10 | tqdm
11 | einops
12 | linformer
13 | timm
14 | ptflops
15 | dataclasses
--------------------------------------------------------------------------------
/script/test.sh:
--------------------------------------------------------------------------------
1 | ##test on AGAN-Data##
2 | python3 test/test_raindrop.py --arch AST_B --input_dir ../dataset/raindrop/test_a/ --result_dir ./results/rain_drop/AGAN-Data/ --weights ./logs/raindrop/AGAN-Data/AST_B/models/model_best.pth --token_mlp frfn
3 |
4 | ##test on densehaze##
5 | python3 test/test_denseHaze.py --arch AST_B --input_dir ../dataset/Dense-Haze-v2/valid_dense --result_dir ./results/dehaze/DenseHaze/ --weights ./logs/dehazing/DenseHaze/AST_B/models/model_best.pth --token_mlp frfn
6 |
7 | ## test on SPAD##
8 | python3 test/test_spad.py --arch AST_B --input_dir /PATH/TO/DATASET/ --result_dir ./results/deraining/SPAD/ --weights ./pretrained/rain/model_best.pth --token_mlp frfn
9 |
10 |
11 |
12 |
--------------------------------------------------------------------------------
/script/train_dehaze.sh:
--------------------------------------------------------------------------------
1 | ##train on densehaze##
2 |
3 | python3 ./train/train_dehaze.py --arch AST_B --batch_size 4 --gpu '0' --train_ps 256 --train_dir ../dataset/Dense-Haze-v2/train_dense/ --val_ps 256 --val_dir ../dataset/Dense-Haze-v2/valid_dense/ --lr_initial 0.0002 --env _1108_s1 --mode dehaze --nepoch 2000 --dataset DenseHaze --warmup --token_mlp frfn
4 |
5 | # python3 ./train/train_dehaze.py --arch AST_B --batch_size 4 --retrain --pretrain_weights ./logs/dehazing/DenseHaze/AST_B_1108_s1/models/model_best.pth --gpu '0' --train_ps 384 --train_dir ../dataset/Dense-Haze-v2/train_dense/ --val_ps 384 --val_dir ../dataset/Dense-Haze-v2/valid_dense/ --lr_initial 0.00012 --env _1109_s2 --mode dehaze --nepoch 1200 --token_mlp frfn --dataset DenseHaze --warmup
6 |
7 | # python3 ./train/train_dehaze.py --arch AST_B --batch_size 4 --retrain --pretrain_weights ./logs/dehazing/DenseHaze/AST_B_1109_s2/models/model_best.pth --gpu '0' --train_ps 512 --train_dir ../dataset/Dense-Haze-v2/train_dense/ --val_ps 512 --val_dir ../dataset/Dense-Haze-v2/valid_dense/ --lr_initial 0.00008 --env _1110_s3 --mode dehaze --nepoch 800 --token_mlp frfn --dataset DenseHaze --warmup
8 |
9 | # python3 ./train/train_dehaze.py --arch AST_B --retrain --pretrain_weights ./logs/dehazing/DenseHaze/AST_B_1110_s3/models/model_best.pth --batch_size 2 --gpu '0,1' --train_ps 768 --train_dir ../dataset/Dense-Haze-v2/train_dense/ --val_ps 768 --val_dir ../dataset/Dense-Haze-v2/valid_dense/ --lr_initial 0.00003 --env _1105_s4 --mode dehaze --nepoch 300 --token_mlp frfn --dataset DenseHaze --warmup
10 |
11 | # python3 ./train/train_dehaze.py --arch AST_B --retrain --pretrain_weights ./logs/dehazing/DenseHaze/AST_B_1105/models/model_best.pth --batch_size 2 --gpu '0,1' --train_ps 896 --train_dir ../dataset/Dense-Haze-v2/train_dense/ --val_ps 896 --val_dir ../dataset/Dense-Haze-v2/valid_dense/ --lr_initial 0.00001 --env _1106_s5 --mode dehaze --nepoch 80 --token_mlp frfn --dataset DenseHaze --warmup
12 |
--------------------------------------------------------------------------------
/script/train_derain.sh:
--------------------------------------------------------------------------------
1 | ##train on SPAD##
2 | python ./train/train_derain.py --arch AST_B --batch_size 32 --gpu 0,1 --train_ps 128 --train_dir ../derain_dataset/derain/ --env _1030_s1 --val_dir ../derain_dataset/derain/ --save_dir ./logs/ --dataset spad --warmup --token_mlp frfn --nepoch 20 --lr_initial 0.0002
3 |
4 | # python ./train/train_derain.py --arch AST_B --retrain --pretrain_weights ./logs/derain/spad/AST_B_1030_s1/models/model_best.pth --batch_size 16 --gpu 0,1 --train_ps 256 --train_dir ../derain_dataset/derain/ --env _1030_s2 --val_ps 256 --val_dir ../derain_dataset/derain/ --save_dir ./logs/ --dataset spad --warmup --token_mlp frfn --nepoch 15 --lr_initial 0.0001
5 |
--------------------------------------------------------------------------------
/script/train_raindrop.sh:
--------------------------------------------------------------------------------
1 | ##train on AGAN-Data##
2 | python ./train/train_raindrop.py --arch AST_B --batch_size 32 --gpu '0,1' --train_ps 128 --train_dir ../dataset/raindrop/train/ --val_ps 128 --val_dir ../dataset/raindrop/test_a/ --env _1102_s1 --mode derain_drop --nepoch 300 --dataset raindrop --warmup --lr_initial 0.0002 --token_mlp frfn
3 |
4 | # python ./train/train_raindrop.py --arch AST_B --retrain --pretrain_weights ./logs/raindrop/AGAN-Data/AST_B_1102_s1/models/model_best.pth --batch_size 16 --gpu '0,1' --train_ps 256 --train_dir ../dataset/raindrop/train/ --val_ps 256 --val_dir ../dataset/raindrop/test_a/ --env _1103_s2 --mode derain_drop --nepoch 200 --dataset raindrop --warmup --lr_initial 0.0001 --token_mlp frfn
5 |
6 | # python ./train/train_raindrop.py --arch AST_B --retrain --pretrain_weights ./logs/derain/raindrop/AST_B_1103_s2/models/model_best.pth --batch_size 8 --gpu '0,1' --train_ps 384 --train_dir ../dataset/raindrop/train/ --val_ps 384 --val_dir ../dataset/raindrop/test_a/ --env _1103_s3 --mode derain_drop --nepoch 150 --dataset raindrop --warmup --lr_initial 0.00008 --token_mlp frfn
--------------------------------------------------------------------------------
/test/test_denseHaze.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys,math
3 | import argparse
4 | from tqdm import tqdm
5 | from einops import rearrange, repeat
6 |
7 | import torch.nn as nn
8 | import torch
9 | from torch.utils.data import DataLoader
10 | import torch.nn.functional as F
11 | # from ptflops import get_model_complexity_info
12 |
13 | dir_name = os.path.dirname(os.path.abspath(__file__))
14 | sys.path.append(os.path.join(dir_name,'../dataset/'))
15 | sys.path.append(os.path.join(dir_name,'..'))
16 |
17 | import scipy.io as sio
18 | from dataset.dataset_dehaze_denseHaze import *
19 | import utils
20 |
21 | from skimage import img_as_float32, img_as_ubyte
22 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss
23 | from skimage.metrics import structural_similarity as ssim_loss
24 |
25 | parser = argparse.ArgumentParser(description='Image dehazing evaluation on DenseHaze')
26 | parser.add_argument('--input_dir', default='/PATH/TO/DATASET/',
27 | type=str, help='Directory of validation images')
28 | parser.add_argument('--result_dir', default='./results_release/haze/densehaze/AST_B/',
29 | type=str, help='Directory for results')
30 | parser.add_argument('--weights', default='./pretrained/haze/model_best.pth',
31 | type=str, help='Path to weights')
32 | parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES')
33 | parser.add_argument('--arch', default='AST_B', type=str, help='arch')
34 | parser.add_argument('--batch_size', default=1, type=int, help='Batch size for dataloader')
35 | parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory')
36 | parser.add_argument('--embed_dim', type=int, default=32, help='dim of embedding features')
37 | parser.add_argument('--win_size', type=int, default=8, help='window size of self-attention')
38 | parser.add_argument('--token_projection', type=str,default='linear', help='linear/conv token projection')
39 | parser.add_argument('--token_mlp', type=str,default='leff', help='ffn/leff/frfn token mlp')
40 | parser.add_argument('--query_embed', action='store_true', default=False, help='query embedding for the decoder')
41 | parser.add_argument('--dd_in', type=int, default=3, help='dd_in')
42 |
43 | # args for vit
44 | parser.add_argument('--vit_dim', type=int, default=256, help='vit hidden_dim')
45 | parser.add_argument('--vit_depth', type=int, default=12, help='vit depth')
46 | parser.add_argument('--vit_nheads', type=int, default=8, help='vit number of heads')
47 | parser.add_argument('--vit_mlp_dim', type=int, default=512, help='vit mlp_dim')
48 | parser.add_argument('--vit_patch_size', type=int, default=16, help='vit patch_size')
49 | parser.add_argument('--global_skip', action='store_true', default=False, help='global skip connection')
50 | parser.add_argument('--local_skip', action='store_true', default=False, help='local skip connection')
51 | parser.add_argument('--vit_share', action='store_true', default=False, help='share vit module')
52 |
53 | parser.add_argument('--train_ps', type=int, default=128, help='patch size of training sample')
54 | args = parser.parse_args()
55 |
56 |
57 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
58 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
59 |
60 | utils.mkdir(args.result_dir)
61 |
62 | test_dataset = get_validation_data(args.input_dir)
63 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=8, drop_last=False)
64 |
65 | model_restoration= utils.get_arch(args)
66 | # model_restoration = torch.nn.DataParallel(model_restoration)
67 |
68 | utils.load_checkpoint(model_restoration,args.weights)
69 | print("===>Testing using weights: ", args.weights)
70 |
71 | model_restoration.cuda()
72 | model_restoration.eval()
73 |
74 |
75 | def expand2square(timg,factor=16.0):
76 | _, _, h, w = timg.size()
77 |
78 | X = int(math.ceil(max(h,w)/float(factor))*factor)
79 |
80 | img = torch.zeros(1,3,X,X).type_as(timg) # 3, h,w
81 | mask = torch.zeros(1,1,X,X).type_as(timg)
82 |
83 | # print(img.size(),mask.size())
84 | # print((X - h)//2, (X - h)//2+h, (X - w)//2, (X - w)//2+w)
85 | img[:,:, ((X - h)//2):((X - h)//2 + h),((X - w)//2):((X - w)//2 + w)] = timg
86 | mask[:,:, ((X - h)//2):((X - h)//2 + h),((X - w)//2):((X - w)//2 + w)].fill_(1)
87 |
88 | return img, mask
89 | from utils.image_utils import splitimage, mergeimage
90 |
91 |
92 | with torch.no_grad():
93 | psnr_val_rgb = []
94 | ssim_val_rgb = []
95 | for ii, data_test in enumerate(tqdm(test_loader), 0):
96 | rgb_gt = data_test[0].numpy().squeeze().transpose((1,2,0))
97 | # rgb_noisy, mask = expand2square(data_test[1].cuda(), factor=128)
98 | filenames = data_test[2]
99 |
100 | input_ = data_test[1].cuda()
101 | B, C, H, W = input_.shape
102 | crop_size_arg = 1152
103 | overlap_size_arg = 384
104 | split_data, starts = splitimage(input_, crop_size=crop_size_arg, overlap_size=overlap_size_arg)
105 | for i, data in enumerate(split_data):
106 | split_data[i] = model_restoration(data).cpu()
107 | restored = mergeimage(split_data, starts, crop_size=crop_size_arg, resolution=(B, C, H, W))
108 | rgb_restored = torch.clamp(restored, 0, 1).permute(0, 2, 3, 1).numpy()
109 |
110 | psnr = psnr_loss(rgb_restored[0], rgb_gt)
111 | ssim = ssim_loss(rgb_restored[0], rgb_gt, channel_axis=2, data_range=1)
112 |
113 | psnr_val_rgb.append(psnr)
114 | ssim_val_rgb.append(ssim)
115 | print("PSNR:",psnr,", SSIM:", ssim, filenames[0], rgb_restored.shape)
116 | utils.save_img(os.path.join(args.result_dir,filenames[0]+'.PNG'), img_as_ubyte(rgb_restored[0]))
117 | with open(os.path.join(args.result_dir,'psnr_ssim.txt'),'a') as f:
118 | f.write(filenames[0]+'.PNG ----> '+"PSNR: %.4f, SSIM: %.4f" % (psnr, ssim)+'\n')
119 | psnr_val_rgb = sum(psnr_val_rgb)/len(test_dataset)
120 | ssim_val_rgb = sum(ssim_val_rgb)/len(test_dataset)
121 | print("PSNR: %f, SSIM: %f " %(psnr_val_rgb,ssim_val_rgb))
122 | with open(os.path.join(args.result_dir,'psnr_ssim.txt'),'a') as f:
123 | f.write("Arch:"+args.arch+", PSNR: %.4f, SSIM: %.4f] "% (psnr_val_rgb, ssim_val_rgb)+'\n')
--------------------------------------------------------------------------------
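
Note on tiling: splitimage/mergeimage (imported from utils/image_utils.py, not reproduced in this listing) let the 1152-pixel crops above cover inputs of arbitrary size. A self-contained sketch of the same idea — overlapping tiles whose predictions are averaged where they overlap; an illustration under simplified assumptions, not the repo's implementation:

    import torch

    def tiled_forward(model, x, crop=256, overlap=64):
        """Run model on overlapping crops of x (B, C, H, W), averaging overlaps."""
        B, C, H, W = x.shape
        assert H >= crop and W >= crop, "input must be at least crop x crop"
        stride = crop - overlap
        out = torch.zeros_like(x)
        weight = torch.zeros(1, 1, H, W)
        tops = list(range(0, H - crop + 1, stride))
        lefts = list(range(0, W - crop + 1, stride))
        if tops[-1] + crop < H:      # make the last tiles reach the borders
            tops.append(H - crop)
        if lefts[-1] + crop < W:
            lefts.append(W - crop)
        for top in tops:
            for left in lefts:
                tile = x[:, :, top:top + crop, left:left + crop]
                out[:, :, top:top + crop, left:left + crop] += model(tile)
                weight[:, :, top:top + crop, left:left + crop] += 1
        return out / weight

    # e.g. restored = tiled_forward(model_restoration, input_, crop=1152, overlap=384)
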
/test/test_raindrop.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys
3 | import argparse
4 | from tqdm import tqdm
5 | from einops import rearrange, repeat
6 |
7 | import torch.nn as nn
8 | import torch
9 | from torch.utils.data import DataLoader
10 | import torch.nn.functional as F
11 | # from ptflops import get_model_complexity_info
12 |
13 | dir_name = os.path.dirname(os.path.abspath(__file__))
14 | sys.path.append(os.path.join(dir_name,'../dataset/'))
15 | sys.path.append(os.path.join(dir_name,'..'))
16 |
17 | # import scipy.io as sio
18 | from dataset.dataset_derain_drop import *
19 | import utils
20 | import math
21 |
22 | from skimage import img_as_float32, img_as_ubyte
23 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss
24 | from skimage.metrics import structural_similarity as ssim_loss
25 |
26 | parser = argparse.ArgumentParser(description='Raindrop removal evaluation on AGAN-Data')
27 | parser.add_argument('--input_dir', default='./dataset/AST_B/deraining/spad/val/',
28 | type=str, help='Directory of validation images')
29 | parser.add_argument('--result_dir', default='./results/deraining/spad/AST_B/se',
30 | type=str, help='Directory for results')
31 | parser.add_argument('--weights', default='./logs/deraining/spad/AST_B/models/model_best.pth',
32 | type=str, help='Path to weights')
33 | parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES')
34 | parser.add_argument('--arch', default='AST_B', type=str, help='arch')
35 | parser.add_argument('--batch_size', default=1, type=int, help='Batch size for dataloader')
36 | parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory')
37 | parser.add_argument('--embed_dim', type=int, default=32, help='dim of embedding features')
38 | parser.add_argument('--win_size', type=int, default=8, help='window size of self-attention')
39 | parser.add_argument('--token_projection', type=str,default='linear', help='linear/conv token projection')
40 | parser.add_argument('--token_mlp', type=str,default='leff', help='ffn/leff/frfn token mlp')
41 | parser.add_argument('--dd_in', type=int, default=3, help='dd_in')
42 |
43 | # args for vit
44 | parser.add_argument('--vit_dim', type=int, default=256, help='vit hidden_dim')
45 | parser.add_argument('--vit_depth', type=int, default=12, help='vit depth')
46 | parser.add_argument('--vit_nheads', type=int, default=8, help='vit number of heads')
47 | parser.add_argument('--vit_mlp_dim', type=int, default=512, help='vit mlp_dim')
48 | parser.add_argument('--vit_patch_size', type=int, default=16, help='vit patch_size')
49 | parser.add_argument('--global_skip', action='store_true', default=False, help='global skip connection')
50 | parser.add_argument('--local_skip', action='store_true', default=False, help='local skip connection')
51 | parser.add_argument('--vit_share', action='store_true', default=False, help='share vit module')
52 |
53 | parser.add_argument('--train_ps', type=int, default=128, help='patch size of training sample')
54 | args = parser.parse_args()
55 |
56 |
57 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
58 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
59 |
60 |
61 | utils.mkdir(args.result_dir)
62 |
63 | test_dataset = get_test_data(args.input_dir)
64 |
65 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=8, drop_last=False)
66 |
67 | model_restoration= utils.get_arch(args)
68 |
69 | utils.load_checkpoint(model_restoration,args.weights)
70 | print("===>Testing using weights: ", args.weights)
71 |
72 | model_restoration.cuda()
73 | model_restoration.eval()
74 | from utils.image_utils import splitimage, mergeimage
75 |
76 | def test_transform(v, op):
77 | v2np = v.data.cpu().numpy()
78 | if op == 'v':
79 | tfnp = v2np[:, :, :, ::-1].copy()
80 | elif op == 'h':
81 | tfnp = v2np[:, :, ::-1, :].copy()
82 | elif op == 't':
83 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
84 |
85 | ret = torch.Tensor(tfnp).to(v.device)
86 |
87 | return ret
88 |
89 | def expand2square(timg,factor=16.0):
90 | _, _, h, w = timg.size()
91 |
92 | X = int(math.ceil(max(h,w)/float(factor))*factor)
93 |
94 | img = torch.zeros(1,3,X,X).type_as(timg) # 3, h,w
95 | mask = torch.zeros(1,1,X,X).type_as(timg)
96 |
97 | # print(img.size(),mask.size())
98 | # print((X - h)//2, (X - h)//2+h, (X - w)//2, (X - w)//2+w)
99 | img[:,:, ((X - h)//2):((X - h)//2 + h),((X - w)//2):((X - w)//2 + w)] = timg
100 | mask[:,:, ((X - h)//2):((X - h)//2 + h),((X - w)//2):((X - w)//2 + w)].fill_(1)
101 |
102 | return img, mask
103 |
104 | # # Process data
105 | with torch.no_grad():
106 | psnr_val_rgb = []
107 | ssim_val_rgb = []
108 | for ii, data_test in enumerate(tqdm(test_loader), 0):
109 | rgb_gt = data_test[0].numpy().squeeze().transpose((1,2,0))
110 | # rgb_noisy, mask = expand2square(data_test[1].cuda(), factor=128)
111 | filenames = data_test[2]
112 |
113 | input_ = data_test[1].cuda()
114 | B, C, H, W = input_.shape
115 | crop_size_arg = 384
116 | overlap_size_arg = 80
117 | split_data, starts = splitimage(input_, crop_size=crop_size_arg, overlap_size=overlap_size_arg)
118 | for i, data in enumerate(split_data):
119 | split_data[i] = model_restoration(data).cpu()
120 | restored = mergeimage(split_data, starts, crop_size=crop_size_arg, resolution=(B, C, H, W))
121 | rgb_restored = torch.clamp(restored, 0, 1).permute(0, 2, 3, 1).numpy()
122 |
123 | psnr = psnr_loss(rgb_restored[0], rgb_gt)
124 | ssim = ssim_loss(rgb_restored[0], rgb_gt, channel_axis=2, data_range=1)
125 |
126 |
127 | psnr_val_rgb.append(psnr)
128 | ssim_val_rgb.append(ssim)
129 | print("PSNR:",psnr,", SSIM:", ssim, filenames[0], rgb_restored.shape)
130 | utils.save_img(os.path.join(args.result_dir,filenames[0]+'.PNG'), img_as_ubyte(rgb_restored[0]))
131 | with open(os.path.join(args.result_dir,'psnr_ssim.txt'),'a') as f:
132 | f.write(filenames[0]+'.PNG ----> '+"PSNR: %.4f, SSIM: %.4f" % (psnr, ssim)+'\n')
133 | psnr_val_rgb = sum(psnr_val_rgb)/len(test_dataset)
134 | ssim_val_rgb = sum(ssim_val_rgb)/len(test_dataset)
135 | print("PSNR: %f, SSIM: %f " %(psnr_val_rgb,ssim_val_rgb))
136 | with open(os.path.join(args.result_dir,'psnr_ssim.txt'),'a') as f:
137 | f.write("Arch:"+args.arch+", PSNR: %.4f, SSIM: %.4f] "% (psnr_val_rgb, ssim_val_rgb)+'\n')
--------------------------------------------------------------------------------
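
Note: test_transform above implements the three involutive ops (vertical flip 'v', horizontal flip 'h', transpose 't') used by the classic x8 self-ensemble, although this script never calls it. A hypothetical sketch of how it could be wired up, assuming test_transform as defined in the file above (self_ensemble itself is not a function in this repo; the transposed variants assume square inputs or a model tolerant of a swapped aspect ratio):

    import torch

    def self_ensemble(model, x):
        """Average model outputs over all 8 flip/transpose variants of x."""
        outs = []
        for ops in ([], ['v'], ['h'], ['v', 'h'], ['t'], ['t', 'v'], ['t', 'h'], ['t', 'v', 'h']):
            t = x
            for op in ops:
                t = test_transform(t, op)
            y = model(t)
            for op in reversed(ops):    # each op is its own inverse
                y = test_transform(y, op)
            outs.append(y)
        return torch.stack(outs).mean(0)
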
/test/test_spad.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import os,sys
3 | import argparse
4 | from tqdm import tqdm
5 | from einops import rearrange, repeat
6 |
7 | import torch.nn as nn
8 | import torch
9 | from torch.utils.data import DataLoader
10 | import torch.nn.functional as F
11 |
12 | dir_name = os.path.dirname(os.path.abspath(__file__))
13 | sys.path.append(os.path.join(dir_name,'../dataset/'))
14 | sys.path.append(os.path.join(dir_name,'..'))
15 |
16 | # import scipy.io as sio
17 | from dataset.dataset_derain import *
18 | import utils
19 | import math
20 |
21 | from skimage import img_as_float32, img_as_ubyte
22 | from skimage.metrics import peak_signal_noise_ratio as psnr_loss
23 | from skimage.metrics import structural_similarity as ssim_loss
24 |
25 | parser = argparse.ArgumentParser(description='Image derain evaluation on spad')
26 | parser.add_argument('--input_dir', default='/PATH/TO/DATASET/',
27 | type=str, help='Directory of validation images')
28 | parser.add_argument('--result_dir', default='./results/deraining/spad/AST_B/',
29 | type=str, help='Directory for results')
30 | parser.add_argument('--weights', default='./pretrained/rain/model_best.pth',
31 | type=str, help='Path to weights')
32 | parser.add_argument('--gpus', default='0', type=str, help='CUDA_VISIBLE_DEVICES')
33 | parser.add_argument('--arch', default='AST_B', type=str, help='arch')
34 | parser.add_argument('--batch_size', default=1, type=int, help='Batch size for dataloader')
35 | parser.add_argument('--save_images', action='store_true', help='Save denoised images in result directory')
36 | parser.add_argument('--embed_dim', type=int, default=32, help='dim of embedding features')
37 | parser.add_argument('--win_size', type=int, default=8, help='window size of self-attention')
38 | parser.add_argument('--token_projection', type=str,default='linear', help='linear/conv token projection')
39 | parser.add_argument('--token_mlp', type=str,default='frfn', help='ffn/leff/frfn token mlp')
40 | parser.add_argument('--dd_in', type=int, default=3, help='dd_in')
41 |
42 | # args for vit
43 | parser.add_argument('--vit_dim', type=int, default=256, help='vit hidden_dim')
44 | parser.add_argument('--vit_depth', type=int, default=12, help='vit depth')
45 | parser.add_argument('--vit_nheads', type=int, default=8, help='vit number of heads')
46 | parser.add_argument('--vit_mlp_dim', type=int, default=512, help='vit mlp_dim')
47 | parser.add_argument('--vit_patch_size', type=int, default=16, help='vit patch_size')
48 | parser.add_argument('--global_skip', action='store_true', default=False, help='global skip connection')
49 | parser.add_argument('--local_skip', action='store_true', default=False, help='local skip connection')
50 | parser.add_argument('--vit_share', action='store_true', default=False, help='share vit module')
51 |
52 | parser.add_argument('--train_ps', type=int, default=128, help='patch size of training sample')
53 | args = parser.parse_args()
54 |
55 |
56 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
57 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
58 |
59 |
60 | utils.mkdir(args.result_dir)
61 |
62 | test_dataset = get_validation_data(args.input_dir)
63 | test_loader = DataLoader(dataset=test_dataset, batch_size=1, shuffle=False, num_workers=8, drop_last=False)
64 |
65 | model_restoration= utils.get_arch(args)
66 |
67 | utils.load_checkpoint(model_restoration,args.weights)
68 | print("===>Testing using weights: ", args.weights)
69 |
70 | model_restoration.cuda()
71 | model_restoration.eval()
72 |
73 |
74 | def test_transform(v, op):
75 | v2np = v.data.cpu().numpy()
76 | if op == 'v':
77 | tfnp = v2np[:, :, :, ::-1].copy()
78 | elif op == 'h':
79 | tfnp = v2np[:, :, ::-1, :].copy()
80 | elif op == 't':
81 | tfnp = v2np.transpose((0, 1, 3, 2)).copy()
82 |
83 | ret = torch.Tensor(tfnp).to(v.device)
84 |
85 | return ret
86 |
87 | def expand2square(timg,factor=16.0):
88 | _, _, h, w = timg.size()
89 |
90 | X = int(math.ceil(max(h,w)/float(factor))*factor)
91 |
92 | img = torch.zeros(1,3,X,X).type_as(timg) # 3, h,w
93 | mask = torch.zeros(1,1,X,X).type_as(timg)
94 |
95 | img[:,:, ((X - h)//2):((X - h)//2 + h),((X - w)//2):((X - w)//2 + w)] = timg
96 | mask[:,:, ((X - h)//2):((X - h)//2 + h),((X - w)//2):((X - w)//2 + w)].fill_(1)
97 |
98 | return img, mask
99 |
100 | # # Process data
101 | with torch.no_grad():
102 | psnr_val_rgb = []
103 | ssim_val_rgb = []
104 | for ii, data_test in enumerate(tqdm(test_loader), 0):
105 | rgb_gt = data_test[0].numpy().squeeze().transpose((1,2,0))
106 | rgb_noisy, mask = expand2square(data_test[1].cuda(), factor=128)
107 | filenames = data_test[2]
108 |
109 | rgb_restored = model_restoration(rgb_noisy)
110 | rgb_restored = torch.masked_select(rgb_restored,mask.bool()).reshape(1,3,rgb_gt.shape[0],rgb_gt.shape[1])
111 | rgb_restored = torch.clamp(rgb_restored,0,1).cpu().numpy().squeeze().transpose((1,2,0))
112 |
113 | psnr = psnr_loss(rgb_restored, rgb_gt)
114 | ssim = ssim_loss(rgb_restored, rgb_gt, channel_axis=2, data_range=1)
115 | psnr_val_rgb.append(psnr)
116 | ssim_val_rgb.append(ssim)
117 | print("PSNR:",psnr,", SSIM:", ssim, filenames[0], rgb_restored.shape)
118 | utils.save_img(os.path.join(args.result_dir,filenames[0]+'.PNG'), img_as_ubyte(rgb_restored))
119 | with open(os.path.join(args.result_dir,'psnr_ssim.txt'),'a') as f:
120 | f.write(filenames[0]+'.PNG ----> '+"PSNR: %.4f, SSIM: %.4f" % (psnr, ssim)+'\n')
121 | psnr_val_rgb = sum(psnr_val_rgb)/len(test_dataset)
122 | ssim_val_rgb = sum(ssim_val_rgb)/len(test_dataset)
123 | print("PSNR: %f, SSIM: %f " %(psnr_val_rgb,ssim_val_rgb))
124 | with open(os.path.join(args.result_dir,'psnr_ssim.txt'),'a') as f:
125 | f.write("Arch:"+args.arch+", PSNR: %.4f, SSIM: %.4f] "% (psnr_val_rgb, ssim_val_rgb)+'\n')
--------------------------------------------------------------------------------
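
Note: expand2square pads the input to a centered square whose side is the next multiple of factor and returns the padded image together with a mask of the valid region; the loop above then recovers the original pixels with masked_select. A toy round trip, assuming expand2square as defined in the file above:

    import torch

    img = torch.rand(1, 3, 321, 481)
    padded, mask = expand2square(img, factor=128)    # -> 1 x 3 x 512 x 512
    recovered = torch.masked_select(padded, mask.bool()).reshape(1, 3, 321, 481)
    assert torch.equal(recovered, img)
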
/train/train_dehaze.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | # add dir
5 | dir_name = os.path.dirname(os.path.abspath(__file__))
6 | sys.path.append(os.path.join(dir_name,'../dataset/'))
7 | sys.path.append(os.path.join(dir_name,'..'))
8 | print(sys.path)
9 | print(dir_name)
10 |
11 | import argparse
12 | import options
13 | ######### parser ###########
14 | opt = options.Options().init(argparse.ArgumentParser(description='Image dehazing')).parse_args()
15 | print(opt)
16 |
17 | import utils
18 | from dataset.dataset_dehaze_denseHaze import *
19 | ######### Set GPUs ###########
20 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
21 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
22 | import torch
23 | torch.backends.cudnn.benchmark = True
24 |
25 | import torch.nn as nn
26 | import torch.optim as optim
27 | from torch.utils.data import DataLoader
28 | from natsort import natsorted
29 | import glob
30 | import random
31 | import time
32 | import numpy as np
33 | from einops import rearrange, repeat
34 | import datetime
35 | from pdb import set_trace as stx
36 |
37 | from losses import CharbonnierLoss
38 |
39 | from tqdm import tqdm
40 | from warmup_scheduler import GradualWarmupScheduler
41 | from torch.optim.lr_scheduler import StepLR
42 | from timm.utils import NativeScaler
43 |
44 | ######### Logs dir ###########
45 | log_dir = os.path.join(opt.save_dir, 'dehazing', opt.dataset, opt.arch+opt.env)
46 | if not os.path.exists(log_dir):
47 | os.makedirs(log_dir)
48 | logname = os.path.join(log_dir, datetime.datetime.now().isoformat()+'.txt')
49 | print("Now time is : ",datetime.datetime.now().isoformat())
50 | result_dir = os.path.join(log_dir, 'results')
51 | model_dir = os.path.join(log_dir, 'models')
52 | utils.mkdir(result_dir)
53 | utils.mkdir(model_dir)
54 |
55 | # ######### Set Seeds ###########
56 | random.seed(1234)
57 | np.random.seed(1234)
58 | torch.manual_seed(1234)
59 | torch.cuda.manual_seed_all(1234)
60 |
61 | ######### Model ###########
62 | model_restoration = utils.get_arch(opt)
63 |
64 | with open(logname,'a') as f:
65 | f.write(str(opt)+'\n')
66 | f.write(str(model_restoration)+'\n')
67 |
68 | ######### Optimizer ###########
69 | start_epoch = 1
70 | if opt.optimizer.lower() == 'adam':
71 | optimizer = optim.Adam(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay)
72 | elif opt.optimizer.lower() == 'adamw':
73 | optimizer = optim.AdamW(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay)
74 | else:
75 |     raise ValueError("Unsupported optimizer: " + opt.optimizer)
76 |
77 |
78 | ######### DataParallel ###########
79 | model_restoration = torch.nn.DataParallel(model_restoration)
80 | model_restoration.cuda()
81 |
82 |
83 | ######### Scheduler ###########
84 | if opt.warmup:
85 | print("Using warmup and cosine strategy!")
86 | warmup_epochs = opt.warmup_epochs
87 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.nepoch-warmup_epochs, eta_min=1e-6)
88 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
89 | scheduler.step()
90 | else:
91 | step = 50
92 | print("Using StepLR,step={}!".format(step))
93 | scheduler = StepLR(optimizer, step_size=step, gamma=0.5)
94 | scheduler.step()
95 |
96 | ######### Resume ###########
97 | if opt.resume:
98 | path_chk_rest = opt.pretrain_weights
99 | print("Resume from "+path_chk_rest)
100 | utils.load_checkpoint(model_restoration,path_chk_rest)
101 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1
102 | lr = utils.load_optim(optimizer, path_chk_rest)
103 |
104 | # for p in optimizer.param_groups: p['lr'] = lr
105 | # warmup = False
106 | # new_lr = lr
107 | # print('------------------------------------------------------------------------------')
108 | # print("==> Resuming Training with learning rate:",new_lr)
109 | # print('------------------------------------------------------------------------------')
110 | for i in range(1, start_epoch):
111 | scheduler.step()
112 | new_lr = scheduler.get_lr()[0]
113 | print('------------------------------------------------------------------------------')
114 | print("==> Resuming Training with learning rate:", new_lr)
115 | print('------------------------------------------------------------------------------')
116 |
117 | # scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.nepoch-start_epoch+1, eta_min=1e-6)
118 |
119 | ######### Retrain ###########
120 | if opt.retrain:
121 | path_chk_rest = opt.pretrain_weights
122 | print("Retrain from "+path_chk_rest)
123 | utils.load_checkpoint(model_restoration,path_chk_rest)
124 |
125 | print('------------------------------------------------------------------------------')
126 | print("==> Re Training with learning rate:", opt.lr_initial)
127 | print('------------------------------------------------------------------------------')
128 |
129 | ######### Loss ###########
130 | criterion = CharbonnierLoss().cuda()
131 |
132 | ######### DataLoader ###########
133 | print('===> Loading datasets')
134 | img_options_train = {'patch_size':opt.train_ps}
135 | img_options_val ={'patch_size':opt.val_ps}
136 | train_dataset = get_training_data(opt.train_dir, img_options_train)
137 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True,
138 | num_workers=opt.train_workers, pin_memory=False, drop_last=False)
139 | val_dataset = get_validation_data(opt.val_dir,img_options_val)
140 | val_loader = DataLoader(dataset=val_dataset, batch_size=opt.batch_size, shuffle=False,
141 | num_workers=opt.eval_workers, pin_memory=False, drop_last=False)
142 |
143 | len_trainset = train_dataset.__len__()
144 | len_valset = val_dataset.__len__()
145 | print("Sizeof training set: ", len_trainset,", sizeof validation set: ", len_valset)
146 | ######### validation ###########
147 | with torch.no_grad():
148 | model_restoration.eval()
149 | psnr_dataset = []
150 | psnr_model_init = []
151 | for ii, data_val in enumerate((val_loader), 0):
152 | target = data_val[0].cuda()
153 | input_ = data_val[1].cuda()
154 | with torch.cuda.amp.autocast():
155 | restored = model_restoration(input_)
156 | restored = torch.clamp(restored,0,1)
157 | psnr_dataset.append(utils.batch_PSNR(input_, target, False).item())
158 | psnr_model_init.append(utils.batch_PSNR(restored, target, False).item())
159 | psnr_dataset = sum(psnr_dataset)/len_valset
160 | psnr_model_init = sum(psnr_model_init)/len_valset
161 | print('Input & GT (PSNR) -->%.4f dB'%(psnr_dataset), ', Model_init & GT (PSNR) -->%.4f dB'%(psnr_model_init))
162 |
163 | ######### train ###########
164 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.nepoch))
165 | best_psnr = 0
166 | best_epoch = 0
167 | best_iter = 0
168 | eval_now = len(train_loader)//4
169 | print("\nEvaluation after every {} Iterations !!!\n".format(eval_now))
170 |
171 | loss_scaler = NativeScaler()
172 | torch.cuda.empty_cache()
173 | for epoch in range(start_epoch, opt.nepoch + 1):
174 | epoch_start_time = time.time()
175 | epoch_loss = 0
176 | train_id = 1
177 |
178 | for i, data in enumerate(tqdm(train_loader), 0):
179 | # zero_grad
180 | optimizer.zero_grad()
181 |
182 | target = data[0].cuda()
183 | input_ = data[1].cuda()
184 |
185 | if epoch>5:
186 | target, input_ = utils.MixUp_AUG().aug(target, input_)
187 | with torch.cuda.amp.autocast():
188 | restored = model_restoration(input_)
189 | loss = criterion(restored, target)
190 | loss_scaler(
191 | loss, optimizer,parameters=model_restoration.parameters())
192 | epoch_loss +=loss.item()
193 |
194 | #### Evaluation ####
195 | if (i+1)%eval_now==0 and i>0:
196 | with torch.no_grad():
197 | model_restoration.eval()
198 | psnr_val_rgb = []
199 | for ii, data_val in enumerate((val_loader), 0):
200 | target = data_val[0].cuda()
201 | input_ = data_val[1].cuda()
202 | filenames = data_val[2]
203 | with torch.cuda.amp.autocast():
204 | restored = model_restoration(input_)
205 | restored = torch.clamp(restored,0,1)
206 | psnr_val_rgb.append(utils.batch_PSNR(restored, target, False).item())
207 |
208 | psnr_val_rgb = sum(psnr_val_rgb)/len_valset
209 |
210 | if psnr_val_rgb > best_psnr:
211 | best_psnr = psnr_val_rgb
212 | best_epoch = epoch
213 | best_iter = i
214 | torch.save({'epoch': epoch,
215 | 'state_dict': model_restoration.state_dict(),
216 | 'optimizer' : optimizer.state_dict()
217 | }, os.path.join(model_dir,"model_best.pth"))
218 |
219 | print("[Ep %d it %d\t PSNR SIDD: %.4f\t] ---- [best_Ep_SIDD %d best_it_SIDD %d Best_PSNR_SIDD %.4f] " % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr))
220 | with open(logname,'a') as f:
221 | f.write("[Ep %d it %d\t PSNR SIDD: %.4f\t] ---- [best_Ep_SIDD %d best_it_SIDD %d Best_PSNR_SIDD %.4f] " \
222 | % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr)+'\n')
223 | model_restoration.train()
224 | torch.cuda.empty_cache()
225 | scheduler.step()
226 |
227 | print("------------------------------------------------------------------")
228 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0]))
229 | print("------------------------------------------------------------------")
230 | with open(logname,'a') as f:
231 | f.write("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0])+'\n')
232 |
233 | torch.save({'epoch': epoch,
234 | 'state_dict': model_restoration.state_dict(),
235 | 'optimizer' : optimizer.state_dict()
236 | }, os.path.join(model_dir,"model_latest.pth"))
237 |
238 | if epoch%opt.checkpoint == 0:
239 | torch.save({'epoch': epoch,
240 | 'state_dict': model_restoration.state_dict(),
241 | 'optimizer' : optimizer.state_dict()
242 | }, os.path.join(model_dir,"model_epoch_{}.pth".format(epoch)))
243 | print("Now time is : ",datetime.datetime.now().isoformat())
244 |
--------------------------------------------------------------------------------
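
All three training scripts pick between two schedules: GradualWarmupScheduler (vendored under warmup_scheduler/) ramps the learning rate linearly over warmup_epochs and then hands off to cosine annealing, while the fallback halves the rate every 50 epochs via StepLR. A hand-rolled sketch of the warmup-plus-cosine shape, assuming multiplier=1 as in the scripts; the base LR and epoch counts here are illustrative, not the repo's defaults:

    import torch
    from torch import optim

    params = [torch.nn.Parameter(torch.zeros(1))]   # stand-in for the model
    base_lr, nepoch, warmup_epochs = 2e-4, 250, 3   # illustrative values
    optimizer = optim.AdamW(params, lr=base_lr)
    cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, nepoch - warmup_epochs, eta_min=1e-6)

    for epoch in range(1, nepoch + 1):
        if epoch <= warmup_epochs:
            for g in optimizer.param_groups:        # linear ramp from 0 up to base_lr
                g['lr'] = base_lr * epoch / warmup_epochs
        else:
            cosine.step()                           # cosine decay down to eta_min
        # ... one training epoch at optimizer.param_groups[0]['lr'] ...
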
/train/train_derain.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | dir_name = os.path.dirname(os.path.abspath(__file__))
5 | sys.path.append(os.path.join(dir_name,'../dataset/'))
6 | sys.path.append(os.path.join(dir_name,'..'))
7 | print(sys.path)
8 | print(dir_name)
9 |
10 | import argparse
11 | import options
12 | ######### parser ###########
13 | opt = options.Options().init(argparse.ArgumentParser(description='Image derain')).parse_args()
14 | print(opt)
15 |
16 | import utils
17 | from dataset.dataset_derain import *
18 | ######### Set GPUs ###########
19 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
20 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
21 | import torch
22 | torch.backends.cudnn.benchmark = True
23 |
24 | import torch.nn as nn
25 | import torch.optim as optim
26 | from torch.utils.data import DataLoader
27 | from natsort import natsorted
28 | import glob
29 | import random
30 | import time
31 | import numpy as np
32 | from einops import rearrange, repeat
33 | import datetime
34 | from pdb import set_trace as stx
35 |
36 | from losses import CharbonnierLoss
37 |
38 | from tqdm import tqdm
39 | from warmup_scheduler import GradualWarmupScheduler
40 | from torch.optim.lr_scheduler import StepLR
41 | from timm.utils import NativeScaler
42 |
43 |
44 | ######### Logs dir ###########
45 | log_dir = os.path.join(opt.save_dir, 'derain', opt.dataset, opt.arch+opt.env)
46 | if not os.path.exists(log_dir):
47 | os.makedirs(log_dir)
48 | logname = os.path.join(log_dir, datetime.datetime.now().isoformat()+'.txt')
49 | print("Now time is : ",datetime.datetime.now().isoformat())
50 | result_dir = os.path.join(log_dir, 'results')
51 | model_dir = os.path.join(log_dir, 'models')
52 | utils.mkdir(result_dir)
53 | utils.mkdir(model_dir)
54 |
55 | # ######### Set Seeds ###########
56 | random.seed(1234)
57 | np.random.seed(1234)
58 | torch.manual_seed(1234)
59 | torch.cuda.manual_seed_all(1234)
60 |
61 | ######### Model ###########
62 | model_restoration = utils.get_arch(opt)
63 |
64 | with open(logname,'a') as f:
65 | f.write(str(opt)+'\n')
66 | f.write(str(model_restoration)+'\n')
67 |
68 | ######### Optimizer ###########
69 | start_epoch = 1
70 | if opt.optimizer.lower() == 'adam':
71 | optimizer = optim.Adam(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay)
72 | elif opt.optimizer.lower() == 'adamw':
73 | optimizer = optim.AdamW(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay)
74 | else:
75 |     raise ValueError("Unsupported optimizer: " + opt.optimizer)
76 |
77 |
78 | ######### DataParallel ###########
79 | model_restoration = torch.nn.DataParallel(model_restoration)
80 | model_restoration.cuda()
81 |
82 |
83 | ######### Scheduler ###########
84 | if opt.warmup:
85 | print("Using warmup and cosine strategy!")
86 | warmup_epochs = opt.warmup_epochs
87 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.nepoch-warmup_epochs, eta_min=1e-6)
88 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
89 | scheduler.step()
90 | else:
91 | step = 50
92 | print("Using StepLR,step={}!".format(step))
93 | scheduler = StepLR(optimizer, step_size=step, gamma=0.5)
94 | scheduler.step()
95 |
96 | ######### Resume ###########
97 | if opt.resume:
98 | path_chk_rest = opt.pretrain_weights
99 | print("Resume from "+path_chk_rest)
100 | utils.load_checkpoint(model_restoration,path_chk_rest)
101 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1
102 | lr = utils.load_optim(optimizer, path_chk_rest)
103 |
104 | for i in range(1, start_epoch):
105 | scheduler.step()
106 | new_lr = scheduler.get_lr()[0]
107 | print('------------------------------------------------------------------------------')
108 | print("==> Resuming Training with learning rate:", new_lr)
109 | print('------------------------------------------------------------------------------')
110 |
111 | ######### Retrain ###########
112 | if opt.retrain:
113 | path_chk_rest = opt.pretrain_weights
114 | print("Retrain from "+path_chk_rest)
115 | utils.load_checkpoint(model_restoration,path_chk_rest)
116 |
117 | print('------------------------------------------------------------------------------')
118 | print("==> Re Training with learning rate:", opt.lr_initial)
119 | print('------------------------------------------------------------------------------')
120 |
121 | ######### Loss ###########
122 | criterion = CharbonnierLoss().cuda()
123 |
124 | ######### DataLoader ###########
125 | print('===> Loading datasets')
126 | img_options_train = {'patch_size':opt.train_ps}
127 | train_dataset = get_training_data(opt.train_dir, img_options_train)
128 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True,
129 | num_workers=opt.train_workers, pin_memory=False, drop_last=False)
130 | val_dataset = get_validation_data(opt.val_dir)
131 | val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False,
132 | num_workers=opt.eval_workers, pin_memory=False, drop_last=False)
133 |
134 | len_trainset = train_dataset.__len__()
135 | len_valset = val_dataset.__len__()
136 | print("Sizeof training set: ", len_trainset,", sizeof validation set: ", len_valset)
137 | ######### validation ###########
138 |
139 | with torch.no_grad():
140 | model_restoration.eval()
141 | psnr_dataset = []
142 | psnr_model_init = []
143 | for ii, data_val in enumerate((val_loader), 0):
144 | target = data_val[0].cuda()
145 | input_ = data_val[1].cuda()
146 | with torch.cuda.amp.autocast():
147 | restored = model_restoration(input_)
148 | restored = torch.clamp(restored,0,1)
149 | psnr_dataset.append(utils.batch_PSNR(input_, target, False, 'y').item())
150 | psnr_model_init.append(utils.batch_PSNR(restored, target, False, 'y').item())
151 | psnr_dataset = sum(psnr_dataset)/len_valset
152 | psnr_model_init = sum(psnr_model_init)/len_valset
153 | print('Input & GT (PSNR) -->%.4f dB'%(psnr_dataset), ', Model_init & GT (PSNR) -->%.4f dB'%(psnr_model_init))
154 |
155 | ######### train ###########
156 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.nepoch))
157 | best_psnr = 0
158 | best_epoch = 0
159 | best_iter = 0
160 | eval_now = len(train_loader)//4
161 | print("\nEvaluation after every {} Iterations !!!\n".format(eval_now))
162 |
163 | loss_scaler = NativeScaler()
164 | torch.cuda.empty_cache()
165 | for epoch in range(start_epoch, opt.nepoch + 1):
166 | epoch_start_time = time.time()
167 | epoch_loss = 0
168 | train_id = 1
169 |
170 | for i, data in enumerate(tqdm(train_loader), 0):
171 | # zero_grad
172 | optimizer.zero_grad()
173 |
174 | target = data[0].cuda()
175 | input_ = data[1].cuda()
176 |
177 | if epoch>5:
178 | target, input_ = utils.MixUp_AUG().aug(target, input_)
179 | with torch.cuda.amp.autocast():
180 | restored = model_restoration(input_)
181 | loss = criterion(restored, target)
182 | loss_scaler(
183 | loss, optimizer,parameters=model_restoration.parameters())
184 | epoch_loss +=loss.item()
185 |
186 | #### Evaluation ####
187 | if (i+1)%eval_now==0 and i>0:
188 | with torch.no_grad():
189 | model_restoration.eval()
190 | psnr_val_rgb = []
191 | for ii, data_val in enumerate((val_loader), 0):
192 | target = data_val[0].cuda()
193 | input_ = data_val[1].cuda()
194 | filenames = data_val[2]
195 | with torch.cuda.amp.autocast():
196 | restored = model_restoration(input_)
197 | restored = torch.clamp(restored,0,1)
198 | psnr_val_rgb.append(utils.batch_PSNR(restored, target, False, 'y').item())
199 |
200 | psnr_val_rgb = sum(psnr_val_rgb)/len_valset
201 |
202 | if psnr_val_rgb > best_psnr:
203 | best_psnr = psnr_val_rgb
204 | best_epoch = epoch
205 | best_iter = i
206 | torch.save({'epoch': epoch,
207 | 'state_dict': model_restoration.state_dict(),
208 | 'optimizer' : optimizer.state_dict()
209 | }, os.path.join(model_dir,"model_best.pth"))
210 |
211 | print("[Ep %d it %d\t PSNR SPAD: %.4f\t] ---- [best_Ep_SPAD %d best_it_SPAD %d Best_PSNR_SPAD %.4f] " % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr))
212 | with open(logname,'a') as f:
213 | f.write("[Ep %d it %d\t PSNR SPAD: %.4f\t] ---- [best_Ep_SPAD %d best_it_SPAD %d Best_PSNR_SPAD %.4f] " \
214 | % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr)+'\n')
215 | model_restoration.train()
216 | torch.cuda.empty_cache()
217 | scheduler.step()
218 |
219 | print("------------------------------------------------------------------")
220 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0]))
221 | print("------------------------------------------------------------------")
222 | with open(logname,'a') as f:
223 | f.write("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0])+'\n')
224 |
225 | torch.save({'epoch': epoch,
226 | 'state_dict': model_restoration.state_dict(),
227 | 'optimizer' : optimizer.state_dict()
228 | }, os.path.join(model_dir,"model_latest.pth"))
229 |
230 | if epoch%opt.checkpoint == 0:
231 | torch.save({'epoch': epoch,
232 | 'state_dict': model_restoration.state_dict(),
233 | 'optimizer' : optimizer.state_dict()
234 | }, os.path.join(model_dir,"model_epoch_{}.pth".format(epoch)))
235 | print("Now time is : ",datetime.datetime.now().isoformat())
236 |
--------------------------------------------------------------------------------
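
Note that deraining is evaluated with cal_type='y', i.e. PSNR on the BT.601 luminance channel, a common convention for rain benchmarks (the dehazing script uses plain RGB PSNR instead). A standalone sketch mirroring the 'y' branch of utils.image_utils.myPSNR:

    import torch

    def psnr_y(tar, prd):
        # PSNR on the BT.601 luma channel, as in utils.image_utils.myPSNR(cal_type='y')
        diff = torch.clamp(prd, 0, 1) - torch.clamp(tar, 0, 1)
        coeffs = diff.new_tensor([65.738, 129.057, 25.064]).view(1, 3, 1, 1) / 256
        diff = diff.mul(coeffs).sum(dim=1)
        rmse = (diff ** 2).mean().sqrt()
        return 20 * torch.log10(1 / rmse)

    gt = torch.rand(1, 3, 64, 64)
    noisy = (gt + 0.05 * torch.randn_like(gt)).clamp(0, 1)
    print(float(psnr_y(gt, noisy)))   # roughly 30 dB for 5% Gaussian noise
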
/train/train_raindrop.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 |
4 | dir_name = os.path.dirname(os.path.abspath(__file__))
5 | sys.path.append(os.path.join(dir_name,'../dataset/'))
6 | sys.path.append(os.path.join(dir_name,'..'))
7 | print(sys.path)
8 | print(dir_name)
9 |
10 | import argparse
11 | import options
12 | ######### parser ###########
13 | opt = options.Options().init(argparse.ArgumentParser(description='Raindrop removal')).parse_args()
14 | print(opt)
15 |
16 | import utils
17 | from dataset.dataset_derain_drop import *
18 | ######### Set GPUs ###########
19 | os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
20 | os.environ["CUDA_VISIBLE_DEVICES"] = opt.gpu
21 | import torch
22 | torch.backends.cudnn.benchmark = True
23 |
24 | import torch.nn as nn
25 | import torch.optim as optim
26 | from torch.utils.data import DataLoader
27 | from natsort import natsorted
28 | import glob
29 | import random
30 | import time
31 | import numpy as np
32 | from einops import rearrange, repeat
33 | import datetime
34 | from pdb import set_trace as stx
35 |
36 | from losses import CharbonnierLoss
37 |
38 | from tqdm import tqdm
39 | from warmup_scheduler import GradualWarmupScheduler
40 | from torch.optim.lr_scheduler import StepLR
41 | from timm.utils import NativeScaler
42 |
43 |
44 | ######### Logs dir ###########
45 | log_dir = os.path.join(opt.save_dir, 'derain', opt.dataset, opt.arch+opt.env)
46 | if not os.path.exists(log_dir):
47 | os.makedirs(log_dir)
48 | logname = os.path.join(log_dir, datetime.datetime.now().isoformat()+'.txt')
49 | print("Now time is : ",datetime.datetime.now().isoformat())
50 | result_dir = os.path.join(log_dir, 'results')
51 | model_dir = os.path.join(log_dir, 'models')
52 | utils.mkdir(result_dir)
53 | utils.mkdir(model_dir)
54 |
55 | # ######### Set Seeds ###########
56 | random.seed(1234)
57 | np.random.seed(1234)
58 | torch.manual_seed(1234)
59 | torch.cuda.manual_seed_all(1234)
60 |
61 | ######### Model ###########
62 | model_restoration = utils.get_arch(opt)
63 |
64 | with open(logname,'a') as f:
65 | f.write(str(opt)+'\n')
66 | f.write(str(model_restoration)+'\n')
67 |
68 | ######### Optimizer ###########
69 | start_epoch = 1
70 | if opt.optimizer.lower() == 'adam':
71 | optimizer = optim.Adam(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay)
72 | elif opt.optimizer.lower() == 'adamw':
73 | optimizer = optim.AdamW(model_restoration.parameters(), lr=opt.lr_initial, betas=(0.9, 0.999),eps=1e-8, weight_decay=opt.weight_decay)
74 | else:
75 |     raise ValueError("Unsupported optimizer: " + opt.optimizer)
76 |
77 |
78 | ######### DataParallel ###########
79 | # os.environ["CUDA_VISIBLE_DEVICES"] = '0,1'
80 | model_restoration = torch.nn.DataParallel(model_restoration)
81 | model_restoration.cuda()
82 |
83 |
84 | ######### Scheduler ###########
85 | if opt.warmup:
86 | print("Using warmup and cosine strategy!")
87 | warmup_epochs = opt.warmup_epochs
88 | scheduler_cosine = optim.lr_scheduler.CosineAnnealingLR(optimizer, opt.nepoch-warmup_epochs, eta_min=1e-6)
89 | scheduler = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=warmup_epochs, after_scheduler=scheduler_cosine)
90 | scheduler.step()
91 | else:
92 | step = 50
93 | print("Using StepLR,step={}!".format(step))
94 | scheduler = StepLR(optimizer, step_size=step, gamma=0.5)
95 | scheduler.step()
96 |
97 | ######### Resume ###########
98 | if opt.resume:
99 | path_chk_rest = opt.pretrain_weights
100 | print("Resume from "+path_chk_rest)
101 | utils.load_checkpoint(model_restoration,path_chk_rest)
102 | start_epoch = utils.load_start_epoch(path_chk_rest) + 1
103 | lr = utils.load_optim(optimizer, path_chk_rest)
104 |
105 | for i in range(1, start_epoch):
106 | scheduler.step()
107 | new_lr = scheduler.get_lr()[0]
108 | print('------------------------------------------------------------------------------')
109 | print("==> Resuming Training with learning rate:", new_lr)
110 | print('------------------------------------------------------------------------------')
111 |
112 | ######### Retrain ###########
113 | if opt.retrain:
114 | path_chk_rest = opt.pretrain_weights
115 | print("Retrain from "+path_chk_rest)
116 | utils.load_checkpoint(model_restoration,path_chk_rest)
117 |
118 | print('------------------------------------------------------------------------------')
119 | print("==> Re Training with learning rate:", opt.lr_initial)
120 | print('------------------------------------------------------------------------------')
121 |
122 | ######### Loss ###########
123 | criterion = CharbonnierLoss().cuda()
124 |
125 | ######### DataLoader ###########
126 | print('===> Loading datasets')
127 | img_options_train = {'patch_size':opt.train_ps}
128 | train_dataset = get_training_data(opt.train_dir, img_options_train)
129 | train_loader = DataLoader(dataset=train_dataset, batch_size=opt.batch_size, shuffle=True,
130 | num_workers=opt.train_workers, pin_memory=False, drop_last=False)
131 | img_options_val = {'patch_size':opt.val_ps}
132 | val_dataset = get_validation_data(opt.val_dir,img_options_val)
133 | val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False,
134 | num_workers=opt.eval_workers, pin_memory=False, drop_last=False)
135 |
136 | len_trainset = train_dataset.__len__()
137 | len_valset = val_dataset.__len__()
138 | print("Sizeof training set: ", len_trainset,", sizeof validation set: ", len_valset)
139 | ######### validation ###########
140 |
141 | with torch.no_grad():
142 | model_restoration.eval()
143 | psnr_dataset = []
144 | psnr_model_init = []
145 | for ii, data_val in enumerate((val_loader), 0):
146 | target = data_val[0].cuda()
147 | input_ = data_val[1].cuda()
148 | with torch.cuda.amp.autocast():
149 | restored = model_restoration(input_)
150 | restored = torch.clamp(restored,0,1)
151 | psnr_dataset.append(utils.batch_PSNR(input_, target, False, 'y').item())
152 | psnr_model_init.append(utils.batch_PSNR(restored, target, False, 'y').item())
153 | psnr_dataset = sum(psnr_dataset)/len_valset
154 | psnr_model_init = sum(psnr_model_init)/len_valset
155 | print('Input & GT (PSNR) -->%.4f dB'%(psnr_dataset), ', Model_init & GT (PSNR) -->%.4f dB'%(psnr_model_init))
156 |
157 | ######### train ###########
158 | print('===> Start Epoch {} End Epoch {}'.format(start_epoch,opt.nepoch))
159 | best_psnr = 0
160 | best_epoch = 0
161 | best_iter = 0
162 | eval_now = len(train_loader)//4
163 | print("\nEvaluation after every {} Iterations !!!\n".format(eval_now))
164 |
165 | loss_scaler = NativeScaler()
166 | torch.cuda.empty_cache()
167 | for epoch in range(start_epoch, opt.nepoch + 1):
168 | epoch_start_time = time.time()
169 | epoch_loss = 0
170 | train_id = 1
171 |
172 | for i, data in enumerate(tqdm(train_loader), 0):
173 | # zero_grad
174 | optimizer.zero_grad()
175 |
176 | target = data[0].cuda()
177 | input_ = data[1].cuda()
178 |
179 | if epoch>5:
180 | target, input_ = utils.MixUp_AUG().aug(target, input_)
181 | with torch.cuda.amp.autocast():
182 | restored = model_restoration(input_)
183 | loss = criterion(restored, target)
184 | loss_scaler(
185 | loss, optimizer,parameters=model_restoration.parameters())
186 | epoch_loss +=loss.item()
187 |
188 | #### Evaluation ####
189 | if (i+1)%eval_now==0 and i>0:
190 | with torch.no_grad():
191 | model_restoration.eval()
192 | psnr_val_rgb = []
193 | for ii, data_val in enumerate((val_loader), 0):
194 | target = data_val[0].cuda()
195 | input_ = data_val[1].cuda()
196 | filenames = data_val[2]
197 | with torch.cuda.amp.autocast():
198 | restored = model_restoration(input_)
199 | restored = torch.clamp(restored,0,1)
200 | psnr_val_rgb.append(utils.batch_PSNR(restored, target, False, 'y').item())
201 |
202 | psnr_val_rgb = sum(psnr_val_rgb)/len_valset
203 |
204 | if psnr_val_rgb > best_psnr:
205 | best_psnr = psnr_val_rgb
206 | best_epoch = epoch
207 | best_iter = i
208 | torch.save({'epoch': epoch,
209 | 'state_dict': model_restoration.state_dict(),
210 | 'optimizer' : optimizer.state_dict()
211 | }, os.path.join(model_dir,"model_best.pth"))
212 |
213 | print("[Ep %d it %d\t PSNR SPAD: %.4f\t] ---- [best_Ep_SPAD %d best_it_SPAD %d Best_PSNR_SPAD %.4f] " % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr))
214 | with open(logname,'a') as f:
215 | f.write("[Ep %d it %d\t PSNR SPAD: %.4f\t] ---- [best_Ep_SPAD %d best_it_SPAD %d Best_PSNR_SPAD %.4f] " \
216 | % (epoch, i, psnr_val_rgb,best_epoch,best_iter,best_psnr)+'\n')
217 | model_restoration.train()
218 | torch.cuda.empty_cache()
219 | scheduler.step()
220 |
221 | print("------------------------------------------------------------------")
222 | print("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0]))
223 | print("------------------------------------------------------------------")
224 | with open(logname,'a') as f:
225 | f.write("Epoch: {}\tTime: {:.4f}\tLoss: {:.4f}\tLearningRate {:.6f}".format(epoch, time.time()-epoch_start_time,epoch_loss, scheduler.get_lr()[0])+'\n')
226 |
227 | torch.save({'epoch': epoch,
228 | 'state_dict': model_restoration.state_dict(),
229 | 'optimizer' : optimizer.state_dict()
230 | }, os.path.join(model_dir,"model_latest.pth"))
231 |
232 | if epoch%opt.checkpoint == 0:
233 | torch.save({'epoch': epoch,
234 | 'state_dict': model_restoration.state_dict(),
235 | 'optimizer' : optimizer.state_dict()
236 | }, os.path.join(model_dir,"model_epoch_{}.pth".format(epoch)))
237 | print("Now time is : ",datetime.datetime.now().isoformat())
238 |
--------------------------------------------------------------------------------
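
One detail shared by the --resume branch of all three scripts: after restoring the weights and optimizer state, they replay scheduler.step() once per completed epoch so the learning rate picks up exactly where training stopped. A minimal illustration with the StepLR branch; the numbers are illustrative:

    import torch
    from torch import optim
    from torch.optim.lr_scheduler import StepLR

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = optim.Adam(params, lr=1e-3)
    scheduler = StepLR(optimizer, step_size=50, gamma=0.5)

    start_epoch = 101                 # e.g. checkpoint['epoch'] + 1
    for _ in range(1, start_epoch):   # replay the 100 completed epochs
        scheduler.step()
    print(optimizer.param_groups[0]['lr'])   # 1e-3 * 0.5**2 = 2.5e-4
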
/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/.DS_Store
--------------------------------------------------------------------------------
/utils/__init__.py:
--------------------------------------------------------------------------------
1 | from .dir_utils import *
2 | from .dataset_utils import *
3 | from .image_utils import *
4 | from .model_utils import *
5 | from .caculate_psnr_ssim import *
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/caculate_psnr_ssim.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/caculate_psnr_ssim.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/caculate_psnr_ssim.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/caculate_psnr_ssim.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dataset_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/dataset_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dataset_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/dataset_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dir_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/dir_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/dir_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/dir_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/image_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/image_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/image_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/image_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/model_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/model_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/utils/__pycache__/model_utils.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/utils/__pycache__/model_utils.cpython-38.pyc
--------------------------------------------------------------------------------
/utils/antialias.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.parallel
3 | import numpy as np
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | class Downsample(nn.Module):
8 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
9 | super(Downsample, self).__init__()
10 | self.filt_size = filt_size
11 | self.pad_off = pad_off
12 | self.pad_sizes = [int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)), int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2))]
13 | self.pad_sizes = [pad_size+pad_off for pad_size in self.pad_sizes]
14 | self.stride = stride
15 | self.off = int((self.stride-1)/2.)
16 | self.channels = channels
17 |
18 | # print('Filter size [%i]'%filt_size)
19 | if(self.filt_size==1):
20 | a = np.array([1.,])
21 | elif(self.filt_size==2):
22 | a = np.array([1., 1.])
23 | elif(self.filt_size==3):
24 | a = np.array([1., 2., 1.])
25 | elif(self.filt_size==4):
26 | a = np.array([1., 3., 3., 1.])
27 | elif(self.filt_size==5):
28 | a = np.array([1., 4., 6., 4., 1.])
29 | elif(self.filt_size==6):
30 | a = np.array([1., 5., 10., 10., 5., 1.])
31 | elif(self.filt_size==7):
32 | a = np.array([1., 6., 15., 20., 15., 6., 1.])
33 |
34 | filt = torch.Tensor(a[:,None]*a[None,:])
35 | filt = filt/torch.sum(filt)
36 | self.register_buffer('filt', filt[None,None,:,:].repeat((self.channels,1,1,1)))
37 |
38 | self.pad = get_pad_layer(pad_type)(self.pad_sizes)
39 |
40 | def forward(self, inp):
41 | if(self.filt_size==1):
42 | if(self.pad_off==0):
43 | return inp[:,:,::self.stride,::self.stride]
44 | else:
45 | return self.pad(inp)[:,:,::self.stride,::self.stride]
46 | else:
47 | return F.conv2d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
48 |
49 | def get_pad_layer(pad_type):
50 | if(pad_type in ['refl','reflect']):
51 | PadLayer = nn.ReflectionPad2d
52 | elif(pad_type in ['repl','replicate']):
53 | PadLayer = nn.ReplicationPad2d
54 | elif(pad_type=='zero'):
55 | PadLayer = nn.ZeroPad2d
56 | else:
57 |         raise ValueError('Pad type [%s] not recognized' % pad_type)
58 | return PadLayer
59 |
60 |
61 | class Downsample1D(nn.Module):
62 | def __init__(self, pad_type='reflect', filt_size=3, stride=2, channels=None, pad_off=0):
63 | super(Downsample1D, self).__init__()
64 | self.filt_size = filt_size
65 | self.pad_off = pad_off
66 | self.pad_sizes = [int(1. * (filt_size - 1) / 2), int(np.ceil(1. * (filt_size - 1) / 2))]
67 | self.pad_sizes = [pad_size + pad_off for pad_size in self.pad_sizes]
68 | self.stride = stride
69 | self.off = int((self.stride - 1) / 2.)
70 | self.channels = channels
71 |
72 | # print('Filter size [%i]' % filt_size)
73 | if(self.filt_size == 1):
74 | a = np.array([1., ])
75 | elif(self.filt_size == 2):
76 | a = np.array([1., 1.])
77 | elif(self.filt_size == 3):
78 | a = np.array([1., 2., 1.])
79 | elif(self.filt_size == 4):
80 | a = np.array([1., 3., 3., 1.])
81 | elif(self.filt_size == 5):
82 | a = np.array([1., 4., 6., 4., 1.])
83 | elif(self.filt_size == 6):
84 | a = np.array([1., 5., 10., 10., 5., 1.])
85 | elif(self.filt_size == 7):
86 | a = np.array([1., 6., 15., 20., 15., 6., 1.])
87 |
88 | filt = torch.Tensor(a)
89 | filt = filt / torch.sum(filt)
90 | self.register_buffer('filt', filt[None, None, :].repeat((self.channels, 1, 1)))
91 |
92 | self.pad = get_pad_layer_1d(pad_type)(self.pad_sizes)
93 |
94 | def forward(self, inp):
95 | if(self.filt_size == 1):
96 | if(self.pad_off == 0):
97 | return inp[:, :, ::self.stride]
98 | else:
99 | return self.pad(inp)[:, :, ::self.stride]
100 | else:
101 | return F.conv1d(self.pad(inp), self.filt, stride=self.stride, groups=inp.shape[1])
102 |
103 |
104 | def get_pad_layer_1d(pad_type):
105 | if(pad_type in ['refl', 'reflect']):
106 | PadLayer = nn.ReflectionPad1d
107 | elif(pad_type in ['repl', 'replicate']):
108 | PadLayer = nn.ReplicationPad1d
109 | elif(pad_type == 'zero'):
110 | PadLayer = nn.ZeroPad1d
111 | else:
112 |         raise ValueError('Pad type [%s] not recognized' % pad_type)
113 | return PadLayer
--------------------------------------------------------------------------------
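
Downsample above is blur-pool style anti-aliased downsampling: a normalized binomial filter is applied depthwise (groups equal to the channel count) before the stride, suppressing the aliasing a bare stride-2 operation would introduce. A quick usage sketch, assuming the import path matches the repo layout:

    import torch
    from utils.antialias import Downsample

    x = torch.randn(2, 16, 64, 64)
    down = Downsample(pad_type='reflect', filt_size=3, stride=2, channels=16)
    print(down(x).shape)   # torch.Size([2, 16, 32, 32])
    print(down.filt[0, 0]) # the 3x3 binomial kernel [[1,2,1],[2,4,2],[1,2,1]] / 16
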
/utils/bundle_submissions.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import scipy.io as sio
3 | import os
4 | import h5py
5 |
6 | def bundle_submissions_raw(submission_folder,session):
7 | '''
8 | Bundles submission data for raw denoising
9 | submission_folder Folder where denoised images reside
10 | Output is written to /bundled/. Please submit
11 | the content of this folder.
12 | '''
13 |
14 | out_folder = os.path.join(submission_folder, session)
15 | # out_folder = os.path.join(submission_folder, "bundled/")
16 | try:
17 | os.mkdir(out_folder)
18 |     except FileExistsError: pass
19 |
20 | israw = True
21 | eval_version="1.0"
22 |
23 | for i in range(50):
24 |         Idenoised = np.zeros((20,), dtype=object)  # np.object was removed in NumPy 1.24
25 | for bb in range(20):
26 | filename = '%04d_%02d.mat'%(i+1,bb+1)
27 | s = sio.loadmat(os.path.join(submission_folder,filename))
28 | Idenoised_crop = s["Idenoised_crop"]
29 | Idenoised[bb] = Idenoised_crop
30 | filename = '%04d.mat'%(i+1)
31 | sio.savemat(os.path.join(out_folder, filename),
32 | {"Idenoised": Idenoised,
33 | "israw": israw,
34 | "eval_version": eval_version},
35 | )
36 |
37 | def bundle_submissions_srgb(submission_folder,session):
38 | '''
39 | Bundles submission data for sRGB denoising
40 |
41 | submission_folder Folder where denoised images reside
42 | Output is written to /bundled/. Please submit
43 | the content of this folder.
44 | '''
45 | out_folder = os.path.join(submission_folder, session)
46 | # out_folder = os.path.join(submission_folder, "bundled/")
47 | try:
48 | os.mkdir(out_folder)
49 |     except FileExistsError: pass
50 | israw = False
51 | eval_version="1.0"
52 |
53 | for i in range(50):
54 |         Idenoised = np.zeros((20,), dtype=object)
55 | for bb in range(20):
56 | filename = '%04d_%02d.mat'%(i+1,bb+1)
57 | s = sio.loadmat(os.path.join(submission_folder,filename))
58 | Idenoised_crop = s["Idenoised_crop"]
59 | Idenoised[bb] = Idenoised_crop
60 | filename = '%04d.mat'%(i+1)
61 | sio.savemat(os.path.join(out_folder, filename),
62 | {"Idenoised": Idenoised,
63 | "israw": israw,
64 | "eval_version": eval_version},
65 | )
66 |
67 |
68 |
69 | def bundle_submissions_srgb_v1(submission_folder,session):
70 | '''
71 | Bundles submission data for sRGB denoising
72 |
73 | submission_folder Folder where denoised images reside
74 | Output is written to /bundled/. Please submit
75 | the content of this folder.
76 | '''
77 | out_folder = os.path.join(submission_folder, session)
78 | # out_folder = os.path.join(submission_folder, "bundled/")
79 | try:
80 | os.mkdir(out_folder)
81 |     except FileExistsError: pass
82 | israw = False
83 | eval_version="1.0"
84 |
85 | for i in range(50):
86 |         Idenoised = np.zeros((20,), dtype=object)
87 | for bb in range(20):
88 | filename = '%04d_%d.mat'%(i+1,bb+1)
89 | s = sio.loadmat(os.path.join(submission_folder,filename))
90 | Idenoised_crop = s["Idenoised_crop"]
91 | Idenoised[bb] = Idenoised_crop
92 | filename = '%04d.mat'%(i+1)
93 | sio.savemat(os.path.join(out_folder, filename),
94 | {"Idenoised": Idenoised,
95 | "israw": israw,
96 | "eval_version": eval_version},
97 | )
--------------------------------------------------------------------------------
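
These helpers assemble benchmark submissions: each of the 50 scenes contributes 20 denoised crops saved as '%04d_%02d.mat' files containing an Idenoised_crop array, and the bundler packs them into a single '%04d.mat' per scene under <submission_folder>/<session>/. Hypothetical usage; the paths are placeholders:

    from utils.bundle_submissions import bundle_submissions_srgb

    # expects ./results/dnd/0001_01.mat ... 0050_20.mat to exist already
    bundle_submissions_srgb('./results/dnd/', 'bundled')
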
/utils/caculate_psnr_ssim.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | import numpy as np
3 | import torch
4 |
5 | # convert 2/3/4-dimensional torch tensor to uint
6 | def tensor2uint(img):
7 | img = img.data.squeeze().float().clamp_(0, 1).cpu().numpy()
8 | if img.ndim == 3:
9 | img = np.transpose(img, (1, 2, 0))
10 | return np.uint8((img*255.0).round())
11 |
12 | def calculate_psnr(img1, img2, crop_border=0, input_order='HWC', test_y_channel=False):
13 |     assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
14 | if input_order not in ['HWC', 'CHW']:
15 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
16 | img1 = reorder_image(img1, input_order=input_order)
17 | img2 = reorder_image(img2, input_order=input_order)
18 | img1 = img1.astype(np.float64)
19 | img2 = img2.astype(np.float64)
20 |
21 | if crop_border != 0:
22 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
23 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
24 |
25 | if test_y_channel:
26 | img1 = to_y_channel(img1)
27 | img2 = to_y_channel(img2)
28 |
29 | mse = np.mean((img1 - img2) ** 2)
30 | if mse == 0:
31 | return float('inf')
32 | return 20. * np.log10(255. / np.sqrt(mse))
33 |
34 |
35 | def _ssim(img1, img2):
36 | C1 = (0.01 * 255) ** 2
37 | C2 = (0.03 * 255) ** 2
38 |
39 | img1 = img1.astype(np.float64)
40 | img2 = img2.astype(np.float64)
41 | kernel = cv2.getGaussianKernel(11, 1.5)
42 | window = np.outer(kernel, kernel.transpose())
43 |
44 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
45 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
46 | mu1_sq = mu1 ** 2
47 | mu2_sq = mu2 ** 2
48 | mu1_mu2 = mu1 * mu2
49 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
50 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
51 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
52 |
53 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
54 | return ssim_map.mean()
55 |
56 |
57 | def calculate_ssim(img1, img2, crop_border=0, input_order='HWC', test_y_channel=False):
58 |     assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
59 |     if img1.dtype != np.uint8:
60 |         img1 = (img1 * 255.0).round().astype(np.uint8)  # float32 to uint8
61 |     if img2.dtype != np.uint8:
62 |         img2 = (img2 * 255.0).round().astype(np.uint8)  # float32 to uint8
63 | if input_order not in ['HWC', 'CHW']:
64 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
65 | img1 = reorder_image(img1, input_order=input_order)
66 | img2 = reorder_image(img2, input_order=input_order)
67 | img1 = img1.astype(np.float64)
68 | img2 = img2.astype(np.float64)
69 |
70 | if crop_border != 0:
71 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
72 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
73 |
74 | if test_y_channel:
75 | img1 = to_y_channel(img1)
76 | img2 = to_y_channel(img2)
77 |
78 | ssims = []
79 | for i in range(img1.shape[2]):
80 | ssims.append(_ssim(img1[..., i], img2[..., i]))
81 | return np.array(ssims).mean()
82 |
83 |
84 | def _blocking_effect_factor(im):
85 | block_size = 8
86 |
87 | block_horizontal_positions = torch.arange(7, im.shape[3] - 1, 8)
88 | block_vertical_positions = torch.arange(7, im.shape[2] - 1, 8)
89 |
90 | horizontal_block_difference = (
91 | (im[:, :, :, block_horizontal_positions] - im[:, :, :, block_horizontal_positions + 1]) ** 2).sum(
92 | 3).sum(2).sum(1)
93 | vertical_block_difference = (
94 | (im[:, :, block_vertical_positions, :] - im[:, :, block_vertical_positions + 1, :]) ** 2).sum(3).sum(
95 | 2).sum(1)
96 |
97 | nonblock_horizontal_positions = np.setdiff1d(torch.arange(0, im.shape[3] - 1), block_horizontal_positions)
98 | nonblock_vertical_positions = np.setdiff1d(torch.arange(0, im.shape[2] - 1), block_vertical_positions)
99 |
100 | horizontal_nonblock_difference = (
101 | (im[:, :, :, nonblock_horizontal_positions] - im[:, :, :, nonblock_horizontal_positions + 1]) ** 2).sum(
102 | 3).sum(2).sum(1)
103 | vertical_nonblock_difference = (
104 | (im[:, :, nonblock_vertical_positions, :] - im[:, :, nonblock_vertical_positions + 1, :]) ** 2).sum(
105 | 3).sum(2).sum(1)
106 |
107 | n_boundary_horiz = im.shape[2] * (im.shape[3] // block_size - 1)
108 | n_boundary_vert = im.shape[3] * (im.shape[2] // block_size - 1)
109 | boundary_difference = (horizontal_block_difference + vertical_block_difference) / (
110 | n_boundary_horiz + n_boundary_vert)
111 |
112 | n_nonboundary_horiz = im.shape[2] * (im.shape[3] - 1) - n_boundary_horiz
113 | n_nonboundary_vert = im.shape[3] * (im.shape[2] - 1) - n_boundary_vert
114 | nonboundary_difference = (horizontal_nonblock_difference + vertical_nonblock_difference) / (
115 | n_nonboundary_horiz + n_nonboundary_vert)
116 |
117 | scaler = np.log2(block_size) / np.log2(min([im.shape[2], im.shape[3]]))
118 | bef = scaler * (boundary_difference - nonboundary_difference)
119 |
120 | bef[boundary_difference <= nonboundary_difference] = 0
121 | return bef
122 |
123 |
124 | def calculate_psnrb(img1, img2, crop_border, input_order='HWC', test_y_channel=False):
125 |     assert img1.shape == img2.shape, (f'Image shapes are different: {img1.shape}, {img2.shape}.')
126 | if input_order not in ['HWC', 'CHW']:
127 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
128 | img1 = reorder_image(img1, input_order=input_order)
129 | img2 = reorder_image(img2, input_order=input_order)
130 | img1 = img1.astype(np.float64)
131 | img2 = img2.astype(np.float64)
132 |
133 | if crop_border != 0:
134 | img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
135 | img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
136 |
137 | if test_y_channel:
138 | img1 = to_y_channel(img1)
139 | img2 = to_y_channel(img2)
140 |
141 | img1 = torch.from_numpy(img1).permute(2, 0, 1).unsqueeze(0) / 255.
142 | img2 = torch.from_numpy(img2).permute(2, 0, 1).unsqueeze(0) / 255.
143 |
144 | total = 0
145 | for c in range(img1.shape[1]):
146 | mse = torch.nn.functional.mse_loss(img1[:, c:c + 1, :, :], img2[:, c:c + 1, :, :], reduction='none')
147 | bef = _blocking_effect_factor(img1[:, c:c + 1, :, :])
148 |
149 | mse = mse.view(mse.shape[0], -1).mean(1)
150 | total += 10 * torch.log10(1 / (mse + bef))
151 |
152 | return float(total) / img1.shape[1]
153 |
154 |
155 | def reorder_image(img, input_order='HWC'):
156 | if input_order not in ['HWC', 'CHW']:
157 | raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' "'HWC' and 'CHW'")
158 | if len(img.shape) == 2:
159 | img = img[..., None]
160 | if input_order == 'CHW':
161 | img = img.transpose(1, 2, 0)
162 | return img
163 |
164 |
165 | def to_y_channel(img):
166 | img = img.astype(np.float32) / 255.
167 | if img.ndim == 3 and img.shape[2] == 3:
168 | img = rgb2ycbcr(img, y_only=True)
169 | img = img[..., None]
170 | else:
171 | raise ValueError(f'Wrong image shape [2]: {img.shape[2]}.')
172 | return img * 255.
173 |
174 |
175 | def _convert_input_type_range(img):
176 | img_type = img.dtype
177 | img = img.astype(np.float32)
178 | if img_type == np.float32:
179 | pass
180 | elif img_type == np.uint8:
181 | img /= 255.
182 | else:
183 | raise TypeError('The img type should be np.float32 or np.uint8, ' f'but got {img_type}')
184 | return img
185 |
186 |
187 | def _convert_output_type_range(img, dst_type):
188 | if dst_type not in (np.uint8, np.float32):
189 | raise TypeError('The dst_type should be np.float32 or np.uint8, ' f'but got {dst_type}')
190 | if dst_type == np.uint8:
191 | img = img.round()
192 | else:
193 | img /= 255.
194 | return img.astype(dst_type)
195 |
196 |
197 | def rgb2ycbcr(img, y_only=False):
198 | img_type = img.dtype
199 | img = _convert_input_type_range(img)
200 | if y_only:
201 | out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
202 | else:
203 | out_img = np.matmul(
204 | img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],[24.966, 112.0, -18.214]]) + [16, 128, 128]
205 | out_img = _convert_output_type_range(out_img, img_type)
206 | return out_img
--------------------------------------------------------------------------------
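
As a sanity check on calculate_psnr above: for 8-bit inputs PSNR = 20*log10(255/RMSE), so a uniform error of 10 gray levels gives 20*log10(25.5) ≈ 28.13 dB. A sketch assuming the module is importable from the repo root:

    import numpy as np
    from utils.caculate_psnr_ssim import calculate_psnr

    a = np.full((32, 32, 3), 100, dtype=np.uint8)
    b = np.full((32, 32, 3), 110, dtype=np.uint8)  # constant error of 10 levels -> MSE = 100
    print(calculate_psnr(a, b))                    # ~28.13
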
/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import os
3 |
4 | ### rotate and flip
5 | class Augment_RGB_torch:
6 | def __init__(self):
7 | pass
8 | def transform0(self, torch_tensor):
9 | return torch_tensor
10 | def transform1(self, torch_tensor):
11 | torch_tensor = torch.rot90(torch_tensor, k=1, dims=[-1,-2])
12 | return torch_tensor
13 | def transform2(self, torch_tensor):
14 | torch_tensor = torch.rot90(torch_tensor, k=2, dims=[-1,-2])
15 | return torch_tensor
16 | def transform3(self, torch_tensor):
17 | torch_tensor = torch.rot90(torch_tensor, k=3, dims=[-1,-2])
18 | return torch_tensor
19 | def transform4(self, torch_tensor):
20 | torch_tensor = torch_tensor.flip(-2)
21 | return torch_tensor
22 | def transform5(self, torch_tensor):
23 | torch_tensor = (torch.rot90(torch_tensor, k=1, dims=[-1,-2])).flip(-2)
24 | return torch_tensor
25 | def transform6(self, torch_tensor):
26 | torch_tensor = (torch.rot90(torch_tensor, k=2, dims=[-1,-2])).flip(-2)
27 | return torch_tensor
28 | def transform7(self, torch_tensor):
29 | torch_tensor = (torch.rot90(torch_tensor, k=3, dims=[-1,-2])).flip(-2)
30 | return torch_tensor
31 |
32 |
33 | ### mix two images
34 | class MixUp_AUG:
35 | def __init__(self):
36 | self.dist = torch.distributions.beta.Beta(torch.tensor([1.2]), torch.tensor([1.2]))
37 |
38 | def aug(self, rgb_gt, rgb_noisy):
39 | bs = rgb_gt.size(0)
40 | indices = torch.randperm(bs)
41 | rgb_gt2 = rgb_gt[indices]
42 | rgb_noisy2 = rgb_noisy[indices]
43 |
44 | lam = self.dist.rsample((bs,1)).view(-1,1,1,1).cuda()
45 |
46 | rgb_gt = lam * rgb_gt + (1-lam) * rgb_gt2
47 | rgb_noisy = lam * rgb_noisy + (1-lam) * rgb_noisy2
48 |
49 | return rgb_gt, rgb_noisy
50 |
--------------------------------------------------------------------------------
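
MixUp_AUG blends each sample with a random partner from the same batch, drawing lam from Beta(1.2, 1.2) and applying the same coefficient to input and target so the degraded/clean correspondence is preserved; the training scripts enable it after epoch 5. The core of aug(), restated without the .cuda() call so it runs anywhere:

    import torch

    dist = torch.distributions.beta.Beta(torch.tensor([1.2]), torch.tensor([1.2]))
    rgb_gt, rgb_noisy = torch.rand(4, 3, 8, 8), torch.rand(4, 3, 8, 8)

    indices = torch.randperm(rgb_gt.size(0))   # random partner within the batch
    lam = dist.rsample((rgb_gt.size(0), 1)).view(-1, 1, 1, 1)
    rgb_gt = lam * rgb_gt + (1 - lam) * rgb_gt[indices]
    rgb_noisy = lam * rgb_noisy + (1 - lam) * rgb_noisy[indices]
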
/utils/dir_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from natsort import natsorted
3 | from glob import glob
4 |
5 | def mkdirs(paths):
6 | if isinstance(paths, list) and not isinstance(paths, str):
7 | for path in paths:
8 | mkdir(path)
9 | else:
10 | mkdir(paths)
11 |
12 | def mkdir(path):
13 | if not os.path.exists(path):
14 | os.makedirs(path)
15 |
16 | def get_last_path(path, session):
17 | x = natsorted(glob(os.path.join(path,'*%s'%session)))[-1]
18 | return x
--------------------------------------------------------------------------------
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import pickle
4 | import cv2
5 | import math
6 | import PIL
7 | from PIL import Image
8 | import torchvision.transforms.functional as F
9 | import torch.nn.functional as f
10 |
11 | def is_numpy_file(filename):
12 | return any(filename.endswith(extension) for extension in [".npy"])
13 |
14 | def is_image_file(filename):
15 | return any(filename.endswith(extension) for extension in [".jpg"])
16 |
17 | def is_png_file(filename):
18 | return any(filename.endswith(extension) for extension in [".png"])
19 |
20 | def is_pkl_file(filename):
21 | return any(filename.endswith(extension) for extension in [".pkl"])
22 |
23 | def load_pkl(filename_):
24 | with open(filename_, 'rb') as f:
25 | ret_dict = pickle.load(f)
26 | return ret_dict
27 |
28 | def save_dict(dict_, filename_):
29 | with open(filename_, 'wb') as f:
30 | pickle.dump(dict_, f)
31 |
32 | def load_npy(file_path):
33 | img = np.load(file_path)
34 | img = img.astype(np.float32)
35 | img = img/255.
36 | img = img[:, :, [2, 1, 0]]
37 | return img
38 |
39 | def load_img(filepath):
40 | img = cv2.cvtColor(cv2.imread(filepath), cv2.COLOR_BGR2RGB)
41 | img = img.astype(np.float32)
42 | img = img/255.
43 | return img
44 |
45 |
46 | def load_reflectPad(filepath):
47 |     img = Image.open(filepath).convert('RGB')
48 |     factor = 128
49 |     w, h = img.size  # PIL .size is (width, height)
50 |     padh = (factor - h % factor) % factor
51 |     padw = (factor - w % factor) % factor
52 |     # torch F.pad operates on tensors, so convert before reflect-padding
53 |     img = torch.from_numpy(np.float32(img) / 255.).permute(2, 0, 1).unsqueeze(0)
54 |     img = f.pad(img, (0, padw, 0, padh), 'reflect')
55 |     img = img.squeeze(0).permute(1, 2, 0).numpy()
56 |     return img
57 |
58 |
59 | def load_resize(filepath):
60 | img = Image.open(filepath).convert('RGB')
61 | wd_new, ht_new = img.size
62 | if ht_new > wd_new and ht_new > 1024:
63 | wd_new = int(np.ceil(wd_new * 1024 / ht_new))
64 | ht_new = 1024
65 | elif ht_new <= wd_new and wd_new > 1024:
66 | ht_new = int(np.ceil(ht_new * 1024 / wd_new))
67 | wd_new = 1024
68 | wd_new = int(128 * np.ceil(wd_new / 128.0))
69 | ht_new = int(128 * np.ceil(ht_new / 128.0))
70 | target_edge = wd_new if wd_new >= ht_new else ht_new
71 |     img = img.resize((target_edge, target_edge), Image.LANCZOS)  # Image.ANTIALIAS was removed in Pillow 10
72 |
73 | img = np.float32(img)
74 | img = img / 255.
75 | return img
76 |
77 |
78 | def loader4dehaze(filepath,ps=256):
79 | img = Image.open(filepath).convert('RGB')
80 | w, h = img.size
81 | if w < ps :
82 | padW = 1+(ps-w)//2
83 | img = F.pad(img, (padW,0,padW,0), 0, 'constant')
84 | if h < ps :
85 | padH = 1+(ps-h)//2
86 | img = F.pad(img, (0,padH,0,padH), 0, 'constant')
87 | img = np.float32(img)
88 | img = img / 255.
89 | return img
90 |
91 |
92 | def save_img(filepath, img):
93 | cv2.imwrite(filepath,cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
94 |
95 |
96 | def myPSNR(tar_img, prd_img, cal_type):
97 | imdff = torch.clamp(prd_img,0,1) - torch.clamp(tar_img,0,1)
98 |
99 | if cal_type == 'y':
100 | gray_coeffs = [65.738, 129.057, 25.064]
101 | convert = imdff.new_tensor(gray_coeffs).view(1, 3, 1, 1) / 256
102 | imdff = imdff.mul(convert).sum(dim=1)
103 |
104 | rmse = (imdff**2).mean().sqrt()
105 | ps = 20*torch.log10(1/rmse)
106 | return ps
107 |
108 | def batch_PSNR(img1, img2, average=True, cal_type='N'):
109 | PSNR = []
110 | for im1, im2 in zip(img1, img2):
111 | psnr = myPSNR(im1, im2, cal_type)
112 | PSNR.append(psnr)
113 | return sum(PSNR)/len(PSNR) if average else sum(PSNR)
114 |
115 | def splitimage(imgtensor, crop_size=128, overlap_size=64):
116 | _, C, H, W = imgtensor.shape
117 | hstarts = [x for x in range(0, H, crop_size - overlap_size)]
118 | while hstarts and hstarts[-1] + crop_size >= H:
119 | hstarts.pop()
120 | hstarts.append(H - crop_size)
121 | wstarts = [x for x in range(0, W, crop_size - overlap_size)]
122 | while wstarts and wstarts[-1] + crop_size >= W:
123 | wstarts.pop()
124 | wstarts.append(W - crop_size)
125 | starts = []
126 | split_data = []
127 | for hs in hstarts:
128 | for ws in wstarts:
129 | cimgdata = imgtensor[:, :, hs:hs + crop_size, ws:ws + crop_size]
130 | starts.append((hs, ws))
131 | split_data.append(cimgdata)
132 | return split_data, starts
133 |
134 | def get_scoremap(H, W, C, B=1, is_mean=True):
135 | center_h = H / 2
136 | center_w = W / 2
137 |
138 | score = torch.ones((B, C, H, W))
139 | if not is_mean:
140 | for h in range(H):
141 | for w in range(W):
142 | score[:, :, h, w] = 1.0 / (math.sqrt((h - center_h) ** 2 + (w - center_w) ** 2 + 1e-6))
143 | return score
144 |
145 | def mergeimage(split_data, starts, crop_size = 128, resolution=(1, 3, 128, 128)):
146 | B, C, H, W = resolution[0], resolution[1], resolution[2], resolution[3]
147 | tot_score = torch.zeros((B, C, H, W))
148 | merge_img = torch.zeros((B, C, H, W))
149 | scoremap = get_scoremap(crop_size, crop_size, C, B=B, is_mean=True)
150 | for simg, cstart in zip(split_data, starts):
151 | hs, ws = cstart
152 | merge_img[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap * simg
153 | tot_score[:, :, hs:hs + crop_size, ws:ws + crop_size] += scoremap
154 | merge_img = merge_img / tot_score
155 | return merge_img
--------------------------------------------------------------------------------
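
Note on the two functions above: splitimage/mergeimage implement overlap-tiled
inference. Crops of crop_size are taken every (crop_size - overlap_size)
pixels, restored independently, and blended back, with get_scoremap supplying
per-pixel blend weights (uniform here, since it is called with is_mean=True,
so overlaps are simply averaged; the other branch would favor pixels near each
crop's center). A minimal sketch of the full loop, assuming these helpers live
in utils/image_utils.py as the tree suggests, that the input is at least
crop_size on each side, and that `model` is any restoration network mapping
(B, C, h, w) to (B, C, h, w):

    import torch
    from utils.image_utils import splitimage, mergeimage

    @torch.no_grad()
    def tiled_forward(model, img, crop_size=128, overlap_size=64):
        B, C, H, W = img.shape
        tiles, starts = splitimage(img, crop_size=crop_size, overlap_size=overlap_size)
        # mergeimage accumulates into CPU tensors, so move each tile back
        restored = [model(t).cpu() for t in tiles]
        return mergeimage(restored, starts, crop_size=crop_size,
                          resolution=(B, C, H, W))
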
/utils/loader.py:
--------------------------------------------------------------------------------
1 | import os
2 |
3 | from dataset import DataLoaderTrain, DataLoaderVal, DataLoaderTest, DataLoaderTestSR
4 | def get_training_data(rgb_dir, img_options):
5 | assert os.path.exists(rgb_dir)
6 | return DataLoaderTrain(rgb_dir, img_options, None)
7 |
8 | def get_validation_data(rgb_dir):
9 | assert os.path.exists(rgb_dir)
10 | return DataLoaderVal(rgb_dir, None)
11 |
12 |
13 | def get_test_data(rgb_dir):
14 | assert os.path.exists(rgb_dir)
15 | return DataLoaderTest(rgb_dir, None)
16 |
17 |
18 | def get_test_data_SR(rgb_dir):
19 | assert os.path.exists(rgb_dir)
20 | return DataLoaderTestSR(rgb_dir, None)
--------------------------------------------------------------------------------
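
These wrappers mainly centralize the existence checks; the returned objects
are ordinary PyTorch datasets. A hedged usage sketch: the {'patch_size': 128}
img_options dict is an assumption about what DataLoaderTrain expects (check
the dataset modules for the exact keys), and the directory paths are
hypothetical:

    from torch.utils.data import DataLoader
    from utils.loader import get_training_data, get_validation_data

    train_set = get_training_data('./datasets/derain/train', {'patch_size': 128})
    train_loader = DataLoader(train_set, batch_size=8, shuffle=True,
                              num_workers=4, pin_memory=True)
    val_loader = DataLoader(get_validation_data('./datasets/derain/val'),
                            batch_size=1, shuffle=False)
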
/utils/model_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import os
4 | from collections import OrderedDict
5 |
6 | def freeze(model):
7 | for p in model.parameters():
8 | p.requires_grad=False
9 |
10 | def unfreeze(model):
11 | for p in model.parameters():
12 | p.requires_grad=True
13 |
14 | def is_frozen(model):
15 |     x = [p.requires_grad for p in model.parameters()]
16 |     return not all(x)  # True as soon as any parameter has requires_grad=False
17 |
18 | def save_checkpoint(model_dir, state, session):
19 | epoch = state['epoch']
20 | model_out_path = os.path.join(model_dir,"model_epoch_{}_{}.pth".format(epoch,session))
21 | torch.save(state, model_out_path)
22 |
23 | def load_checkpoint(model, weights):
24 |     checkpoint = torch.load(weights)
25 |     try:
26 |         model.load_state_dict(checkpoint["state_dict"])
27 |     except RuntimeError:  # saved from DataParallel: strip the 'module.' prefix
28 |         state_dict = checkpoint["state_dict"]
29 |         new_state_dict = OrderedDict()
30 |         for k, v in state_dict.items():
31 |             name = k[7:] if k.startswith('module.') else k
32 |             new_state_dict[name] = v
33 |         model.load_state_dict(new_state_dict)
34 |
35 |
36 | def load_checkpoint_multigpu(model, weights):
37 | checkpoint = torch.load(weights)
38 | state_dict = checkpoint["state_dict"]
39 | new_state_dict = OrderedDict()
40 | for k, v in state_dict.items():
41 |         name = k[7:]  # drop the 'module.' prefix added by DataParallel
42 | new_state_dict[name] = v
43 | model.load_state_dict(new_state_dict)
44 |
45 | def load_start_epoch(weights):
46 | checkpoint = torch.load(weights)
47 | epoch = checkpoint["epoch"]
48 | return epoch
49 |
50 | def load_optim(optimizer, weights):
51 | checkpoint = torch.load(weights)
52 | optimizer.load_state_dict(checkpoint['optimizer'])
53 |     lr = optimizer.param_groups[-1]['lr']  # lr of the last param group
54 | return lr
55 |
56 | def get_arch(opt):
57 | from model import AST
58 |
59 | arch = opt.arch
60 |
61 |     print('You chose ' + arch + '...')
62 | if arch == 'AST_T':
63 | model_restoration = AST(img_size=opt.train_ps,embed_dim=16,win_size=8,token_projection='linear',token_mlp=opt.token_mlp)
64 | elif arch == 'AST_B':
65 | model_restoration = AST(img_size=opt.train_ps,embed_dim=32,win_size=8,token_projection='linear',token_mlp=opt.token_mlp,
66 | depths=[1, 2, 8, 8, 2, 8, 8, 2, 1],dd_in=opt.dd_in)
67 | else:
68 |         raise ValueError('Unknown architecture: ' + arch)
69 |
70 | return model_restoration
--------------------------------------------------------------------------------
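
Together these helpers cover the resume-from-checkpoint path used by the
training scripts. A sketch of that flow, with hedges: the opt fields mirror
options.py, but the exact values (token_mlp, dd_in) and the checkpoint path
are illustrative assumptions:

    import torch.optim as optim
    from types import SimpleNamespace
    from utils.model_utils import (get_arch, load_checkpoint,
                                   load_start_epoch, load_optim)

    opt = SimpleNamespace(arch='AST_B', train_ps=128, token_mlp='leff', dd_in=3)
    model = get_arch(opt)
    optimizer = optim.AdamW(model.parameters(), lr=2e-4)

    ckpt = './log/AST_B/models/model_latest.pth'   # hypothetical path
    load_checkpoint(model, ckpt)            # handles DataParallel 'module.' keys
    start_epoch = load_start_epoch(ckpt) + 1
    lr = load_optim(optimizer, ckpt)        # restores optimizer state, returns its lr
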
/warmup_scheduler/__init__.py:
--------------------------------------------------------------------------------
1 |
2 | from warmup_scheduler.scheduler import GradualWarmupScheduler
3 |
--------------------------------------------------------------------------------
/warmup_scheduler/__pycache__/__init__.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/warmup_scheduler/__pycache__/__init__.cpython-37.pyc
--------------------------------------------------------------------------------
/warmup_scheduler/__pycache__/__init__.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/warmup_scheduler/__pycache__/__init__.cpython-38.pyc
--------------------------------------------------------------------------------
/warmup_scheduler/__pycache__/scheduler.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/warmup_scheduler/__pycache__/scheduler.cpython-37.pyc
--------------------------------------------------------------------------------
/warmup_scheduler/__pycache__/scheduler.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/joshyZhou/AST/8d6420f1708ead3cb5dda84f67bd5bb91917cc5d/warmup_scheduler/__pycache__/scheduler.cpython-38.pyc
--------------------------------------------------------------------------------
/warmup_scheduler/run.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.optim.lr_scheduler import StepLR, ExponentialLR
3 | from torch.optim.sgd import SGD
4 |
5 | from warmup_scheduler import GradualWarmupScheduler
6 |
7 |
8 | if __name__ == '__main__':
9 | model = [torch.nn.Parameter(torch.randn(2, 2, requires_grad=True))]
10 | optim = SGD(model, 0.1)
11 |
12 |     # scheduler_warmup is chained with scheduler_steplr
13 | scheduler_steplr = StepLR(optim, step_size=10, gamma=0.1)
14 | scheduler_warmup = GradualWarmupScheduler(optim, multiplier=1, total_epoch=5, after_scheduler=scheduler_steplr)
15 |
16 | # this zero gradient update is needed to avoid a warning message, issue #8.
17 | optim.zero_grad()
18 | optim.step()
19 |
20 | for epoch in range(1, 20):
21 | scheduler_warmup.step(epoch)
22 | print(epoch, optim.param_groups[0]['lr'])
23 |
24 |     optim.step()  # parameter update; in real training this follows loss.backward()
25 |
--------------------------------------------------------------------------------
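
The dummy zero_grad/step pair above only silences PyTorch's "scheduler before
optimizer" warning in this toy loop. In real training the per-epoch order is:
forward, loss.backward(), optimizer.step(), then the scheduler step. A hedged
skeleton (model, criterion and train_loader are placeholders, not repository
code):

    for epoch in range(1, total_epochs + 1):
        for x, y in train_loader:
            optim.zero_grad()
            loss = criterion(model(x), y)
            loss.backward()
            optim.step()
        scheduler_warmup.step()   # advance the learning rate once per epoch
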
/warmup_scheduler/scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 |
5 | class GradualWarmupScheduler(_LRScheduler):
6 |     """Gradually warm up (increase) the learning rate in the optimizer.
7 |     Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 |
9 |     Args:
10 |         optimizer (Optimizer): Wrapped optimizer.
11 |         multiplier: target learning rate = base lr * multiplier if multiplier > 1.0; if multiplier == 1.0, the lr starts at 0 and ramps up to base_lr.
12 |         total_epoch: the target learning rate is reached gradually at total_epoch.
13 |         after_scheduler: scheduler to switch to after total_epoch (e.g. ReduceLROnPlateau).
14 |     """
15 |
16 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
17 | self.multiplier = multiplier
18 | if self.multiplier < 1.:
19 |             raise ValueError('multiplier should be greater than or equal to 1.')
20 | self.total_epoch = total_epoch
21 | self.after_scheduler = after_scheduler
22 | self.finished = False
23 | super(GradualWarmupScheduler, self).__init__(optimizer)
24 |
25 | def get_lr(self):
26 | if self.last_epoch > self.total_epoch:
27 | if self.after_scheduler:
28 | if not self.finished:
29 | self.after_scheduler.base_lrs = [base_lr * self.multiplier for base_lr in self.base_lrs]
30 | self.finished = True
31 | return self.after_scheduler.get_lr()
32 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
33 |
34 | if self.multiplier == 1.0:
35 | return [base_lr * (float(self.last_epoch) / self.total_epoch) for base_lr in self.base_lrs]
36 | else:
37 | return [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
38 |
39 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
40 | if epoch is None:
41 | epoch = self.last_epoch + 1
42 | self.last_epoch = epoch if epoch != 0 else 1 # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
43 | if self.last_epoch <= self.total_epoch:
44 | warmup_lr = [base_lr * ((self.multiplier - 1.) * self.last_epoch / self.total_epoch + 1.) for base_lr in self.base_lrs]
45 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
46 | param_group['lr'] = lr
47 |         else:
48 |             # `epoch` was normalized above and cannot be None here, so defer
49 |             # directly to the after_scheduler, shifting the epoch count so
50 |             # that it starts from zero once the warmup phase is over
51 |             self.after_scheduler.step(metrics, epoch - self.total_epoch)
52 |
53 | def step(self, epoch=None, metrics=None):
54 |         if not isinstance(self.after_scheduler, ReduceLROnPlateau):
55 | if self.finished and self.after_scheduler:
56 | if epoch is None:
57 | self.after_scheduler.step(None)
58 | else:
59 | self.after_scheduler.step(epoch - self.total_epoch)
60 | else:
61 | return super(GradualWarmupScheduler, self).step(epoch)
62 | else:
63 | self.step_ReduceLROnPlateau(metrics, epoch)
64 |
--------------------------------------------------------------------------------
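
A common way to use this class is to chain it with a cosine decay: ramp the
learning rate up linearly for a few epochs, then hand off. A minimal sketch;
the epoch counts and learning rates are illustrative, not necessarily the
settings the training scripts use:

    import torch
    from torch.optim.lr_scheduler import CosineAnnealingLR
    from warmup_scheduler import GradualWarmupScheduler

    params = [torch.nn.Parameter(torch.zeros(1))]
    optimizer = torch.optim.Adam(params, lr=2e-4)
    cosine = CosineAnnealingLR(optimizer, T_max=247, eta_min=1e-6)
    scheduler = GradualWarmupScheduler(optimizer, multiplier=1,
                                       total_epoch=3, after_scheduler=cosine)

    for epoch in range(1, 251):
        optimizer.step()    # placeholder for one epoch of training
        scheduler.step()    # linear warmup for 3 epochs, cosine afterwards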