├── LICENSE
├── LICENSE.md
├── README.md
├── assets
│   ├── DRH.png
│   ├── Dn&DRL.png
│   ├── SR&DN.png
│   ├── adaptir_logo.png
│   ├── classicSR.png
│   ├── imgsli1.png
│   ├── imgsli2.png
│   ├── imgsli3.png
│   ├── imgsli4.png
│   ├── imgsli5.png
│   ├── imgsli6.png
│   ├── imgsli7.png
│   ├── low-light.png
│   ├── pipeline.png
│   └── scalabiltity.png
├── data_dir
│   ├── noisy
│   │   └── denoise.txt
│   └── rainy
│       └── rainTrain.txt
├── environment.yaml
├── net
│   ├── common.py
│   ├── edt.py
│   └── ipt.py
├── options.py
├── test.py
├── train.py
├── utils
│   ├── __init__.py
│   ├── common.py
│   ├── dataset_utils.py
│   ├── degradation_utils.py
│   ├── image_io.py
│   ├── image_utils.py
│   ├── imresize.py
│   ├── loss_utils.py
│   ├── schedulers.py
│   └── val_utils.py
└── val_options.py
/LICENSE:
--------------------------------------------------------------------------------
1 | Apache License
2 | Version 2.0, January 2004
3 | http://www.apache.org/licenses/
4 |
5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6 |
7 | 1. Definitions.
8 |
9 | "License" shall mean the terms and conditions for use, reproduction,
10 | and distribution as defined by Sections 1 through 9 of this document.
11 |
12 | "Licensor" shall mean the copyright owner or entity authorized by
13 | the copyright owner that is granting the License.
14 |
15 | "Legal Entity" shall mean the union of the acting entity and all
16 | other entities that control, are controlled by, or are under common
17 | control with that entity. For the purposes of this definition,
18 | "control" means (i) the power, direct or indirect, to cause the
19 | direction or management of such entity, whether by contract or
20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the
21 | outstanding shares, or (iii) beneficial ownership of such entity.
22 |
23 | "You" (or "Your") shall mean an individual or Legal Entity
24 | exercising permissions granted by this License.
25 |
26 | "Source" form shall mean the preferred form for making modifications,
27 | including but not limited to software source code, documentation
28 | source, and configuration files.
29 |
30 | "Object" form shall mean any form resulting from mechanical
31 | transformation or translation of a Source form, including but
32 | not limited to compiled object code, generated documentation,
33 | and conversions to other media types.
34 |
35 | "Work" shall mean the work of authorship, whether in Source or
36 | Object form, made available under the License, as indicated by a
37 | copyright notice that is included in or attached to the work
38 | (an example is provided in the Appendix below).
39 |
40 | "Derivative Works" shall mean any work, whether in Source or Object
41 | form, that is based on (or derived from) the Work and for which the
42 | editorial revisions, annotations, elaborations, or other modifications
43 | represent, as a whole, an original work of authorship. For the purposes
44 | of this License, Derivative Works shall not include works that remain
45 | separable from, or merely link (or bind by name) to the interfaces of,
46 | the Work and Derivative Works thereof.
47 |
48 | "Contribution" shall mean any work of authorship, including
49 | the original version of the Work and any modifications or additions
50 | to that Work or Derivative Works thereof, that is intentionally
51 | submitted to Licensor for inclusion in the Work by the copyright owner
52 | or by an individual or Legal Entity authorized to submit on behalf of
53 | the copyright owner. For the purposes of this definition, "submitted"
54 | means any form of electronic, verbal, or written communication sent
55 | to the Licensor or its representatives, including but not limited to
56 | communication on electronic mailing lists, source code control systems,
57 | and issue tracking systems that are managed by, or on behalf of, the
58 | Licensor for the purpose of discussing and improving the Work, but
59 | excluding communication that is conspicuously marked or otherwise
60 | designated in writing by the copyright owner as "Not a Contribution."
61 |
62 | "Contributor" shall mean Licensor and any individual or Legal Entity
63 | on behalf of whom a Contribution has been received by Licensor and
64 | subsequently incorporated within the Work.
65 |
66 | 2. Grant of Copyright License. Subject to the terms and conditions of
67 | this License, each Contributor hereby grants to You a perpetual,
68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69 | copyright license to reproduce, prepare Derivative Works of,
70 | publicly display, publicly perform, sublicense, and distribute the
71 | Work and such Derivative Works in Source or Object form.
72 |
73 | 3. Grant of Patent License. Subject to the terms and conditions of
74 | this License, each Contributor hereby grants to You a perpetual,
75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76 | (except as stated in this section) patent license to make, have made,
77 | use, offer to sell, sell, import, and otherwise transfer the Work,
78 | where such license applies only to those patent claims licensable
79 | by such Contributor that are necessarily infringed by their
80 | Contribution(s) alone or by combination of their Contribution(s)
81 | with the Work to which such Contribution(s) was submitted. If You
82 | institute patent litigation against any entity (including a
83 | cross-claim or counterclaim in a lawsuit) alleging that the Work
84 | or a Contribution incorporated within the Work constitutes direct
85 | or contributory patent infringement, then any patent licenses
86 | granted to You under this License for that Work shall terminate
87 | as of the date such litigation is filed.
88 |
89 | 4. Redistribution. You may reproduce and distribute copies of the
90 | Work or Derivative Works thereof in any medium, with or without
91 | modifications, and in Source or Object form, provided that You
92 | meet the following conditions:
93 |
94 | (a) You must give any other recipients of the Work or
95 | Derivative Works a copy of this License; and
96 |
97 | (b) You must cause any modified files to carry prominent notices
98 | stating that You changed the files; and
99 |
100 | (c) You must retain, in the Source form of any Derivative Works
101 | that You distribute, all copyright, patent, trademark, and
102 | attribution notices from the Source form of the Work,
103 | excluding those notices that do not pertain to any part of
104 | the Derivative Works; and
105 |
106 | (d) If the Work includes a "NOTICE" text file as part of its
107 | distribution, then any Derivative Works that You distribute must
108 | include a readable copy of the attribution notices contained
109 | within such NOTICE file, excluding those notices that do not
110 | pertain to any part of the Derivative Works, in at least one
111 | of the following places: within a NOTICE text file distributed
112 | as part of the Derivative Works; within the Source form or
113 | documentation, if provided along with the Derivative Works; or,
114 | within a display generated by the Derivative Works, if and
115 | wherever such third-party notices normally appear. The contents
116 | of the NOTICE file are for informational purposes only and
117 | do not modify the License. You may add Your own attribution
118 | notices within Derivative Works that You distribute, alongside
119 | or as an addendum to the NOTICE text from the Work, provided
120 | that such additional attribution notices cannot be construed
121 | as modifying the License.
122 |
123 | You may add Your own copyright statement to Your modifications and
124 | may provide additional or different license terms and conditions
125 | for use, reproduction, or distribution of Your modifications, or
126 | for any such Derivative Works as a whole, provided Your use,
127 | reproduction, and distribution of the Work otherwise complies with
128 | the conditions stated in this License.
129 |
130 | 5. Submission of Contributions. Unless You explicitly state otherwise,
131 | any Contribution intentionally submitted for inclusion in the Work
132 | by You to the Licensor shall be under the terms and conditions of
133 | this License, without any additional terms or conditions.
134 | Notwithstanding the above, nothing herein shall supersede or modify
135 | the terms of any separate license agreement you may have executed
136 | with Licensor regarding such Contributions.
137 |
138 | 6. Trademarks. This License does not grant permission to use the trade
139 | names, trademarks, service marks, or product names of the Licensor,
140 | except as required for reasonable and customary use in describing the
141 | origin of the Work and reproducing the content of the NOTICE file.
142 |
143 | 7. Disclaimer of Warranty. Unless required by applicable law or
144 | agreed to in writing, Licensor provides the Work (and each
145 | Contributor provides its Contributions) on an "AS IS" BASIS,
146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147 | implied, including, without limitation, any warranties or conditions
148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149 | PARTICULAR PURPOSE. You are solely responsible for determining the
150 | appropriateness of using or redistributing the Work and assume any
151 | risks associated with Your exercise of permissions under this License.
152 |
153 | 8. Limitation of Liability. In no event and under no legal theory,
154 | whether in tort (including negligence), contract, or otherwise,
155 | unless required by applicable law (such as deliberate and grossly
156 | negligent acts) or agreed to in writing, shall any Contributor be
157 | liable to You for damages, including any direct, indirect, special,
158 | incidental, or consequential damages of any character arising as a
159 | result of this License or out of the use or inability to use the
160 | Work (including but not limited to damages for loss of goodwill,
161 | work stoppage, computer failure or malfunction, or any and all
162 | other commercial damages or losses), even if such Contributor
163 | has been advised of the possibility of such damages.
164 |
165 | 9. Accepting Warranty or Additional Liability. While redistributing
166 | the Work or Derivative Works thereof, You may choose to offer,
167 | and charge a fee for, acceptance of support, warranty, indemnity,
168 | or other liability obligations and/or rights consistent with this
169 | License. However, in accepting such obligations, You may act only
170 | on Your own behalf and on Your sole responsibility, not on behalf
171 | of any other Contributor, and only if You agree to indemnify,
172 | defend, and hold each Contributor harmless for any liability
173 | incurred by, or claims asserted against, such Contributor by reason
174 | of your accepting any such warranty or additional liability.
175 |
176 | END OF TERMS AND CONDITIONS
177 |
178 | APPENDIX: How to apply the Apache License to your work.
179 |
180 | To apply the Apache License to your work, attach the following
181 | boilerplate notice, with the fields enclosed by brackets "[]"
182 | replaced with your own identifying information. (Don't include
183 | the brackets!) The text should be enclosed in the appropriate
184 | comment syntax for the file format. We also recommend that a
185 | file or class name and description of purpose be included on the
186 | same "printed page" as the copyright notice for easier
187 | identification within third-party archives.
188 |
189 | Copyright [yyyy] [name of copyright owner]
190 |
191 | Licensed under the Apache License, Version 2.0 (the "License");
192 | you may not use this file except in compliance with the License.
193 | You may obtain a copy of the License at
194 |
195 | http://www.apache.org/licenses/LICENSE-2.0
196 |
197 | Unless required by applicable law or agreed to in writing, software
198 | distributed under the License is distributed on an "AS IS" BASIS,
199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200 | See the License for the specific language governing permissions and
201 | limitations under the License.
202 |
--------------------------------------------------------------------------------
/LICENSE.md:
--------------------------------------------------------------------------------
1 | ## ACADEMIC PUBLIC LICENSE
2 |
3 | ### Permissions
4 | :heavy_check_mark: Non-Commercial use
5 | :heavy_check_mark: Modification
6 | :heavy_check_mark: Distribution
7 | :heavy_check_mark: Private use
8 |
9 | ### Limitations
10 | :x: Commercial Use
11 | :x: Liability
12 | :x: Warranty
13 |
14 | ### Conditions
15 | :information_source: License and copyright notice
16 | :information_source: Same License
17 |
18 | PromptIR is free for use in noncommercial settings: at academic institutions for teaching and research use, and at non-profit research organizations.
19 | You can use PromptIR in your research, academic work, non-commercial work, projects and personal work. We only ask you to credit us appropriately.
20 |
21 | You have the right to use the software, to distribute copies, to receive source code, to change the software and distribute your modifications or the modified software.
22 | If you distribute verbatim or modified copies of this software, they must be distributed under this license.
23 | This license guarantees that you're safe when using PromptIR in your work, for teaching or research.
24 | This license guarantees that PromptIR will remain available free of charge for nonprofit use.
25 | You can modify PromptIR to your purposes, and you can also share your modifications.
26 |
27 | If you would like to use PromptIR in commercial settings, contact us so we can discuss options. Send an email to vaishnav.potlapalli@mbzuai.ac.ae
28 |
29 |
30 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 |
5 | ## Parameter Efficient Adaptation for Image Restoration with Heterogeneous Mixture-of-Experts
6 |
7 | ## [[Paper](https://arxiv.org/pdf/2312.08881.pdf)]
8 |
9 | [Hang Guo](https://github.com/csguoh), [Tao Dai](https://cstaodai.com/), [Yuanchao Bai](https://scholar.google.com/citations?user=hjYIFZcAAAAJ&hl=zh-CN), Bin Chen, Xudong Ren, Zexuan Zhu, [Shu-Tao Xia](https://scholar.google.com/citations?hl=zh-CN&user=koAXTXgAAAAJ)
10 |
11 |
12 | > **Abstract:** Designing single-task image restoration models for a specific degradation has seen great success in recent years. To achieve generalized image restoration, all-in-one methods have recently been proposed and shown potential to handle multiple restoration tasks with a single model. Despite the promising results, the existing all-in-one paradigm still suffers from high computational costs as well as limited generalization on unseen degradations. In this work, we introduce an alternative solution to improve the generalization of image restoration models. Drawing inspiration from recent advancements in Parameter Efficient Transfer Learning (PETL), we aim to tune only a small number of parameters to adapt pre-trained restoration models to various tasks. However, current PETL methods fail to generalize across varied restoration tasks due to their homogeneous representation nature. To this end, we propose AdaptIR, a Mixture-of-Experts (MoE) with an orthogonal multi-branch design that captures local-spatial, global-spatial, and channel representation bases, followed by adaptive base combination to obtain heterogeneous representations for different degradations. Extensive experiments demonstrate that our AdaptIR achieves stable performance on single-degradation tasks and excels in hybrid-degradation tasks, while fine-tuning only 0.6% of the parameters for 8 hours.
13 |
14 |
15 |
16 |
17 |
18 |
19 | ⭐If this work is helpful for you, please help star this repo. Thanks!🤗
20 |
21 |
22 |
23 | ## 📑 Contents
24 |
25 | - [Visual Results](#visual_results)
26 | - [News](#news)
27 | - [TODO](#todo)
28 | - [Results](#results)
29 | - [Citation](#cite)
30 |
31 |
32 | ## :eyes: Visual Results on Different Restoration Tasks
33 | [Interactive comparison 1](https://imgsli.com/MjI1Njk3) [Interactive comparison 2](https://imgsli.com/MjI1NzIx) [Interactive comparison 3](https://imgsli.com/MjI1NzEx) [Interactive comparison 4](https://imgsli.com/MjI1NzAw)
34 |
35 | [Interactive comparison 5](https://imgsli.com/MjI1NzAz) [Interactive comparison 6](https://imgsli.com/MjI1NzAx) [Interactive comparison 7](https://imgsli.com/MjI1NzE2)
36 |
37 |
38 |
39 | ## 🆕 News
40 |
41 | - **2023-12-12:** arXiv paper available.
42 | - **2023-12-16:** This repo is released.
43 | - **2024-09-28:** 😊 Our AdaptIR was accepted by NeurIPS 2024!
44 | - **2024-10-19:** 🔈The code is available now, enjoy yourself!
45 | - **2025-01-13:** Updated the README file with detailed instructions.
46 |
47 |
48 | ## ☑️ TODO
49 |
50 | - [x] arXiv version
51 | - [x] Release code
52 | - [x] More detailed introductions of README file
53 | - [ ] Further improvements
54 |
55 |
56 | ## 🥇 Results
57 |
58 | We achieve state-of-the-art adaptation performance on various downstream image restoration tasks. Detailed results can be found in the paper.
59 |
60 |
61 | Evaluation on Second-order Degradation (LR4&Noise30) (click to expand)
62 |
63 |
64 |
65 |
66 |
67 |
68 |
69 |
70 | Evaluation on Classic SR (click to expand)
71 |
72 |
73 |
74 |
75 |
76 |
77 |
78 |
79 | Evaluation on Denoise&DerainL (click to expand)
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 | Evaluation on Heavy Rain Streak Removal (click to expand)
89 |
90 |
91 |
92 |
93 |
94 |
95 |
96 |
97 | Evaluation on Low-light Image Enhancement (click to expand)
98 |
99 |
100 |
101 |
102 |
103 |
104 |
105 |
106 |
107 | Evaluation on Model Scalability (click to expand)
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 |
118 | ## Datasets & Models Preparation
119 |
120 | ### Datasets
121 |
122 | Since this work involves various restoration tasks, you may collect the training and testing datasets you need from existing repos, such as [BasicSR](https://github.com/XPixelGroup/BasicSR/blob/master/docs/DatasetPreparation.md), [Restormer](https://github.com/swz30/Restormer/tree/main), and [PromptIR](https://github.com/va1shn9v/PromptIR/blob/main/INSTALL.md).
123 |
124 |
125 |
126 |
127 | ### Pre-trained weights
128 |
129 | - IPT pre-trained models
130 |   Download the `IPT_pretrain` weights from this [link](https://drive.google.com/drive/folders/1MVSdUX0YBExauG0fFz4ANiWTrq9xZEj7) provided in the [IPT repo](https://github.com/huawei-noah/Pretrained-IPT).
131 |
132 |
133 | - EDT pre-trained models
134 |   Download `SRx2x3x4_EDTB_ImageNet200K.pth` from this [link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155137927_link_cuhk_edu_hk/Eikt_wPDrIFCpVpiU0zYNu0BwOhQIHgNWuH1FYZbxZhq_w?e=bVEVeW) provided in the [EDT repo](https://github.com/fenglinglwb/EDT). A quick load check for both downloads is sketched below.
135 |
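
After downloading, you can sanity-check that both checkpoints deserialize. A minimal sketch is shown below; the file names and locations are assumptions rather than paths required by this repo, so place the files wherever `./options.py` expects them:

```
import torch

# Hedged sketch: verify the downloaded pre-trained checkpoints can be loaded.
# The paths below are placeholders, not locations mandated by the code.
for path in ["./IPT_pretrain/IPT_pretrain.pt", "./SRx2x3x4_EDTB_ImageNet200K.pth"]:
    try:
        state = torch.load(path, map_location="cpu")
        n = len(state.get("state_dict", state)) if isinstance(state, dict) else "n/a"
        print(f"{path}: loaded ({n} entries)")
    except FileNotFoundError:
        print(f"{path}: not found -- check the download location")
```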
136 |
137 |
138 | ## Training
139 |
140 | Our AdaptIR adapts pretrained models to various unseen downstream tasks, including hybrid degradation (`lr4_noise30`, `lr4_jpeg30`), image SR (`sr_2`, `sr_3`, `sr_4`), image denoising (`denoise_30`, `denoise_50`), image deraining (`derainL`, `derainH`), and low-light image enhancement (`low_light`).
141 |
142 | Adjust the `de_type` parameter in `./options.py` to train a model for a specific downstream task (a sketch of how this string selects the task branch follows the command below). Note that only the very lightweight AdaptIR module is tuned, so downstream adaptation takes only about 8 hours.
143 |
144 | **A single RTX 3090 GPU with 24GB of memory is enough for training.**
145 |
146 | You can simply run the following command to start training with the default parameters:
147 |
148 | ```
149 | python train.py
150 | ```
151 |
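For reference, the `de_type` string is what `net/ipt.py` uses to pick the task-specific head/tail branch through its `TASK_MAP`. A minimal sketch of that lookup follows; the mapping is copied from `net/ipt.py`, and everything else is illustrative:

```
# Minimal sketch of how the `de_type` option selects the IPT task branch.
# TASK_MAP is copied verbatim from net/ipt.py; the rest is illustrative.
TASK_MAP = {'lr4_noise30': 2, 'lr4_jpeg30': 2, 'sr_2': 0, 'sr_3': 1, 'sr_4': 2,
            'derainH': 3, 'derainL': 3, 'denoise_30': 4, 'denoise_50': 5, 'low_light': 5}

de_type = 'derainL'            # the value set in ./options.py
task_idx = TASK_MAP[de_type]   # index of the head/tail pair used inside IPT
print(de_type, '-> branch', task_idx)   # derainL -> branch 3
```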
152 |
153 | ## Testing
154 |
155 | After training, the downstream weights are saved under `./train_ckpt`. You can load this checkpoint to evaluate performance on the unseen downstream tasks (see the sketch below).
156 |
157 | ```
158 | python test.py
159 | ```
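
If you want to inspect the saved checkpoint before running `test.py`, a hedged sketch could look like the following; the file name under `./train_ckpt` is an assumption:

```
import torch

# Hedged sketch: inspect a checkpoint from ./train_ckpt (file name assumed).
ckpt = torch.load('./train_ckpt/adaptir_derainL.pth', map_location='cpu')
state = ckpt.get('state_dict', ckpt) if isinstance(ckpt, dict) else ckpt
# Only the lightweight AdaptIR parameters (~0.6% of the model) are tuned.
adaptir_keys = [k for k in state if 'adaptir' in k.lower()]
print(f'{len(adaptir_keys)} AdaptIR tensors out of {len(state)} total')
```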
160 |
161 |
162 |
163 |
164 | ## 🥰 Citation
165 |
166 | Please cite us if our work is useful for your research.
167 |
168 | ```
169 | @inproceedings{guoparameter,
170 | title={Parameter Efficient Adaptation for Image Restoration with Heterogeneous Mixture-of-Experts},
171 | author={Guo, Hang and Dai, Tao and Bai, Yuanchao and Chen, Bin and Ren, Xudong and Zhu, Zexuan and Xia, Shu-Tao},
172 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems}
173 | }
174 | ```
175 |
176 | ## License
177 |
178 | This project is released under the [Apache 2.0 license](LICENSE).
179 |
180 | ## Acknowledgement
181 |
182 | This code is based on [AirNet](https://github.com/XLearning-SCU/2022-CVPR-AirNet), [IPT](https://github.com/huawei-noah/Pretrained-IPT) and [EDT](https://github.com/fenglinglwb/EDT). Thanks for their awesome work.
183 |
184 | ## Contact
185 |
186 | If you have any questions, feel free to reach me at cshguo@gmail.com.
187 |
188 |
--------------------------------------------------------------------------------
/assets/DRH.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/DRH.png
--------------------------------------------------------------------------------
/assets/Dn&DRL.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/Dn&DRL.png
--------------------------------------------------------------------------------
/assets/SR&DN.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/SR&DN.png
--------------------------------------------------------------------------------
/assets/adaptir_logo.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/adaptir_logo.png
--------------------------------------------------------------------------------
/assets/classicSR.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/classicSR.png
--------------------------------------------------------------------------------
/assets/imgsli1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli1.png
--------------------------------------------------------------------------------
/assets/imgsli2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli2.png
--------------------------------------------------------------------------------
/assets/imgsli3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli3.png
--------------------------------------------------------------------------------
/assets/imgsli4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli4.png
--------------------------------------------------------------------------------
/assets/imgsli5.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli5.png
--------------------------------------------------------------------------------
/assets/imgsli6.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli6.png
--------------------------------------------------------------------------------
/assets/imgsli7.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli7.png
--------------------------------------------------------------------------------
/assets/low-light.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/low-light.png
--------------------------------------------------------------------------------
/assets/pipeline.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/pipeline.png
--------------------------------------------------------------------------------
/assets/scalabiltity.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/scalabiltity.png
--------------------------------------------------------------------------------
/data_dir/rainy/rainTrain.txt:
--------------------------------------------------------------------------------
1 | rainy/rain-100.png
2 | rainy/rain-101.png
3 | rainy/rain-102.png
4 | rainy/rain-103.png
5 | rainy/rain-104.png
6 | rainy/rain-105.png
7 | rainy/rain-106.png
8 | rainy/rain-107.png
9 | rainy/rain-108.png
10 | rainy/rain-109.png
11 | rainy/rain-10.png
12 | rainy/rain-110.png
13 | rainy/rain-111.png
14 | rainy/rain-112.png
15 | rainy/rain-113.png
16 | rainy/rain-114.png
17 | rainy/rain-115.png
18 | rainy/rain-116.png
19 | rainy/rain-117.png
20 | rainy/rain-118.png
21 | rainy/rain-119.png
22 | rainy/rain-11.png
23 | rainy/rain-120.png
24 | rainy/rain-121.png
25 | rainy/rain-122.png
26 | rainy/rain-123.png
27 | rainy/rain-124.png
28 | rainy/rain-125.png
29 | rainy/rain-126.png
30 | rainy/rain-127.png
31 | rainy/rain-128.png
32 | rainy/rain-129.png
33 | rainy/rain-12.png
34 | rainy/rain-130.png
35 | rainy/rain-131.png
36 | rainy/rain-132.png
37 | rainy/rain-133.png
38 | rainy/rain-134.png
39 | rainy/rain-135.png
40 | rainy/rain-136.png
41 | rainy/rain-137.png
42 | rainy/rain-138.png
43 | rainy/rain-139.png
44 | rainy/rain-13.png
45 | rainy/rain-140.png
46 | rainy/rain-141.png
47 | rainy/rain-142.png
48 | rainy/rain-143.png
49 | rainy/rain-144.png
50 | rainy/rain-145.png
51 | rainy/rain-146.png
52 | rainy/rain-147.png
53 | rainy/rain-148.png
54 | rainy/rain-149.png
55 | rainy/rain-14.png
56 | rainy/rain-150.png
57 | rainy/rain-151.png
58 | rainy/rain-152.png
59 | rainy/rain-153.png
60 | rainy/rain-154.png
61 | rainy/rain-155.png
62 | rainy/rain-156.png
63 | rainy/rain-157.png
64 | rainy/rain-158.png
65 | rainy/rain-159.png
66 | rainy/rain-15.png
67 | rainy/rain-160.png
68 | rainy/rain-161.png
69 | rainy/rain-162.png
70 | rainy/rain-163.png
71 | rainy/rain-164.png
72 | rainy/rain-165.png
73 | rainy/rain-166.png
74 | rainy/rain-167.png
75 | rainy/rain-168.png
76 | rainy/rain-169.png
77 | rainy/rain-16.png
78 | rainy/rain-170.png
79 | rainy/rain-171.png
80 | rainy/rain-172.png
81 | rainy/rain-173.png
82 | rainy/rain-174.png
83 | rainy/rain-175.png
84 | rainy/rain-176.png
85 | rainy/rain-177.png
86 | rainy/rain-178.png
87 | rainy/rain-179.png
88 | rainy/rain-17.png
89 | rainy/rain-180.png
90 | rainy/rain-181.png
91 | rainy/rain-182.png
92 | rainy/rain-183.png
93 | rainy/rain-184.png
94 | rainy/rain-185.png
95 | rainy/rain-186.png
96 | rainy/rain-187.png
97 | rainy/rain-188.png
98 | rainy/rain-189.png
99 | rainy/rain-18.png
100 | rainy/rain-190.png
101 | rainy/rain-191.png
102 | rainy/rain-192.png
103 | rainy/rain-193.png
104 | rainy/rain-194.png
105 | rainy/rain-195.png
106 | rainy/rain-196.png
107 | rainy/rain-197.png
108 | rainy/rain-198.png
109 | rainy/rain-199.png
110 | rainy/rain-19.png
111 | rainy/rain-1.png
112 | rainy/rain-200.png
113 | rainy/rain-20.png
114 | rainy/rain-21.png
115 | rainy/rain-22.png
116 | rainy/rain-23.png
117 | rainy/rain-24.png
118 | rainy/rain-25.png
119 | rainy/rain-26.png
120 | rainy/rain-27.png
121 | rainy/rain-28.png
122 | rainy/rain-29.png
123 | rainy/rain-2.png
124 | rainy/rain-30.png
125 | rainy/rain-31.png
126 | rainy/rain-32.png
127 | rainy/rain-33.png
128 | rainy/rain-34.png
129 | rainy/rain-35.png
130 | rainy/rain-36.png
131 | rainy/rain-37.png
132 | rainy/rain-38.png
133 | rainy/rain-39.png
134 | rainy/rain-3.png
135 | rainy/rain-40.png
136 | rainy/rain-41.png
137 | rainy/rain-42.png
138 | rainy/rain-43.png
139 | rainy/rain-44.png
140 | rainy/rain-45.png
141 | rainy/rain-46.png
142 | rainy/rain-47.png
143 | rainy/rain-48.png
144 | rainy/rain-49.png
145 | rainy/rain-4.png
146 | rainy/rain-50.png
147 | rainy/rain-51.png
148 | rainy/rain-52.png
149 | rainy/rain-53.png
150 | rainy/rain-54.png
151 | rainy/rain-55.png
152 | rainy/rain-56.png
153 | rainy/rain-57.png
154 | rainy/rain-58.png
155 | rainy/rain-59.png
156 | rainy/rain-5.png
157 | rainy/rain-60.png
158 | rainy/rain-61.png
159 | rainy/rain-62.png
160 | rainy/rain-63.png
161 | rainy/rain-64.png
162 | rainy/rain-65.png
163 | rainy/rain-66.png
164 | rainy/rain-67.png
165 | rainy/rain-68.png
166 | rainy/rain-69.png
167 | rainy/rain-6.png
168 | rainy/rain-70.png
169 | rainy/rain-71.png
170 | rainy/rain-72.png
171 | rainy/rain-73.png
172 | rainy/rain-74.png
173 | rainy/rain-75.png
174 | rainy/rain-76.png
175 | rainy/rain-77.png
176 | rainy/rain-78.png
177 | rainy/rain-79.png
178 | rainy/rain-7.png
179 | rainy/rain-80.png
180 | rainy/rain-81.png
181 | rainy/rain-82.png
182 | rainy/rain-83.png
183 | rainy/rain-84.png
184 | rainy/rain-85.png
185 | rainy/rain-86.png
186 | rainy/rain-87.png
187 | rainy/rain-88.png
188 | rainy/rain-89.png
189 | rainy/rain-8.png
190 | rainy/rain-90.png
191 | rainy/rain-91.png
192 | rainy/rain-92.png
193 | rainy/rain-93.png
194 | rainy/rain-94.png
195 | rainy/rain-95.png
196 | rainy/rain-96.png
197 | rainy/rain-97.png
198 | rainy/rain-98.png
199 | rainy/rain-99.png
200 | rainy/rain-9.png
--------------------------------------------------------------------------------
/environment.yaml:
--------------------------------------------------------------------------------
1 | name: adaptir
2 | channels:
3 | - pytorch
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 | - _libgcc_mutex=0.1=main
8 | - _openmp_mutex=5.1=1_gnu
9 | - abseil-cpp=20211102.0=hd4dd3e8_0
10 | - absl-py=1.4.0=py38h06a4308_0
11 | - blas=1.0=mkl
12 | - blinker=1.4=py38h06a4308_0
13 | - bottleneck=1.3.5=py38h7deecbd_0
14 | - brotlipy=0.7.0=py38h0a891b7_1004
15 | - bzip2=1.0.8=h7f98852_4
16 | - c-ares=1.19.0=h5eee18b_0
17 | - ca-certificates=2023.08.22=h06a4308_0
18 | - cachetools=4.2.2=pyhd3eb1b0_0
19 | - certifi=2023.7.22=py38h06a4308_0
20 | - cffi=1.15.0=py38h3931269_0
21 | - chardet=3.0.4=py38h06a4308_1003
22 | - charset-normalizer=3.1.0=pyhd8ed1ab_0
23 | - cryptography=37.0.2=py38h2b5fc30_0
24 | - cudatoolkit=11.6.0=hecad31d_10
25 | - ffmpeg=4.3=hf484d3e_0
26 | - freetype=2.10.4=h0708190_1
27 | - giflib=5.2.1=h5eee18b_3
28 | - gmp=6.2.1=h58526e2_0
29 | - gnutls=3.6.13=h85f3911_1
30 | - google-auth=2.6.0=pyhd3eb1b0_0
31 | - google-auth-oauthlib=0.5.2=py38h06a4308_0
32 | - grpc-cpp=1.48.2=h5bf31a4_0
33 | - grpcio=1.48.2=py38h5bf31a4_0
34 | - idna=3.4=pyhd8ed1ab_0
35 | - intel-openmp=2021.4.0=h06a4308_3561
36 | - jpeg=9e=h166bdaf_1
37 | - lame=3.100=h7f98852_1001
38 | - lcms2=2.12=hddcbb42_0
39 | - ld_impl_linux-64=2.38=h1181459_1
40 | - libffi=3.4.4=h6a678d5_0
41 | - libgcc-ng=11.2.0=h1234567_1
42 | - libgomp=11.2.0=h1234567_1
43 | - libiconv=1.17=h166bdaf_0
44 | - libpng=1.6.37=h21135ba_2
45 | - libprotobuf=3.20.3=he621ea3_0
46 | - libstdcxx-ng=11.2.0=h1234567_1
47 | - libtiff=4.2.0=hecacb30_2
48 | - libwebp=1.2.2=h55f646e_0
49 | - libwebp-base=1.2.2=h7f98852_1
50 | - lz4-c=1.9.3=h9c3ff4c_1
51 | - markdown=3.4.1=py38h06a4308_0
52 | - markupsafe=2.1.1=py38h7f8727e_0
53 | - mkl=2021.4.0=h06a4308_640
54 | - mkl-service=2.4.0=py38h95df7f1_0
55 | - mkl_fft=1.3.1=py38h8666266_1
56 | - mkl_random=1.2.2=py38h1abd341_0
57 | - ncurses=6.4=h6a678d5_0
58 | - nettle=3.6=he412f7d_0
59 | - numexpr=2.8.4=py38he184ba9_0
60 | - numpy=1.24.3=py38h14f4228_0
61 | - numpy-base=1.24.3=py38h31eccc5_0
62 | - oauthlib=3.2.2=py38h06a4308_0
63 | - olefile=0.46=pyh9f0ad1d_1
64 | - openh264=2.1.1=h780b84a_0
65 | - openjpeg=2.4.0=hb52868f_1
66 | - openssl=1.1.1w=h7f8727e_0
67 | - packaging=23.1=py38h06a4308_0
68 | - pandas=2.0.3=py38h1128e8f_0
69 | - pip=23.0.1=py38h06a4308_0
70 | - pyasn1=0.4.8=pyhd3eb1b0_0
71 | - pyasn1-modules=0.2.8=py_0
72 | - pycparser=2.21=pyhd8ed1ab_0
73 | - pyjwt=2.4.0=py38h06a4308_0
74 | - pyopenssl=22.0.0=pyhd8ed1ab_1
75 | - pysocks=1.7.1=pyha2e5f31_6
76 | - python=3.8.16=h7a1cb2a_3
77 | - python-dateutil=2.8.2=pyhd3eb1b0_0
78 | - python-tzdata=2023.3=pyhd3eb1b0_0
79 | - python_abi=3.8=2_cp38
80 | - pytorch=1.12.1=py3.8_cuda11.6_cudnn8.3.2_0
81 | - pytorch-mutex=1.0=cuda
82 | - re2=2022.04.01=h295c915_0
83 | - readline=8.2=h5eee18b_0
84 | - requests-oauthlib=1.3.0=py_0
85 | - rsa=4.7.2=pyhd3eb1b0_1
86 | - setuptools=66.0.0=py38h06a4308_0
87 | - six=1.16.0=pyh6c4a22f_0
88 | - sqlite=3.41.2=h5eee18b_0
89 | - tensorboard=2.12.1=py38h06a4308_0
90 | - tensorboard-data-server=0.7.0=py38h52d8a92_0
91 | - tensorboard-plugin-wit=1.8.1=py38h06a4308_0
92 | - tk=8.6.12=h1ccaba5_0
93 | - torchaudio=0.12.1=py38_cu116
94 | - torchvision=0.13.1=py38_cu116
95 | - tqdm=4.65.0=py38hb070fc8_0
96 | - urllib3=1.26.15=pyhd8ed1ab_0
97 | - werkzeug=2.2.3=py38h06a4308_0
98 | - wheel=0.38.4=py38h06a4308_0
99 | - xz=5.4.2=h5eee18b_0
100 | - zipp=3.11.0=py38h06a4308_0
101 | - zlib=1.2.13=h5eee18b_0
102 | - zstd=1.5.2=ha4553b6_0
103 | - pip:
104 | - accelerate==0.25.0
105 | - addict==2.4.0
106 | - aiohttp==3.8.5
107 | - aiosignal==1.3.1
108 | - antlr4-python3-runtime==4.9.3
109 | - anyio==3.7.1
110 | - appdirs==1.4.4
111 | - argon2-cffi==23.1.0
112 | - argon2-cffi-bindings==21.2.0
113 | - arrow==1.2.3
114 | - astor==0.8.1
115 | - asttokens==2.2.1
116 | - async-lru==2.0.4
117 | - async-timeout==4.0.2
118 | - attrs==23.1.0
119 | - avalanche-lib==0.4.0
120 | - babel==2.14.0
121 | - backcall==0.2.0
122 | - beautifulsoup4==4.12.2
123 | - bleach==6.1.0
124 | - blessed==1.20.0
125 | - bs4==0.0.1
126 | - click==8.1.6
127 | - cloudpickle==3.0.0
128 | - comm==0.2.0
129 | - contourpy==1.1.0
130 | - croniter==1.3.15
131 | - ctrl-benchmark==0.0.4
132 | - cycler==0.11.0
133 | - cython==3.0.6
134 | - dateutils==0.6.12
135 | - debugpy==1.8.0
136 | - decorator==5.1.1
137 | - deepdiff==6.3.1
138 | - defusedxml==0.7.1
139 | - dill==0.3.7
140 | - docker-pycreds==0.4.0
141 | - easydict==1.10
142 | - editdistance==0.6.2
143 | - einops==0.6.1
144 | - exceptiongroup==1.1.2
145 | - executing==1.2.0
146 | - fastapi==0.88.0
147 | - fastjsonschema==2.19.0
148 | - filelock==3.12.0
149 | - fonttools==4.41.1
150 | - fqdn==1.5.1
151 | - frozenlist==1.4.0
152 | - fsspec==2023.5.0
153 | - fvcore==0.1.5.post20221221
154 | - gdown==4.7.1
155 | - gitdb==4.0.10
156 | - gitpython==3.1.32
157 | - gputil==1.4.0
158 | - gym==0.26.2
159 | - gym-notices==0.0.8
160 | - h11==0.14.0
161 | - higher==0.2.1
162 | - httpcore==1.0.2
163 | - httpx==0.26.0
164 | - huggingface-hub==0.19.4
165 | - imageio==2.28.1
166 | - imgaug==0.4.0
167 | - importlib-metadata==7.1.0
168 | - importlib-resources==6.0.0
169 | - inquirer==3.1.3
170 | - iopath==0.1.10
171 | - ipykernel==6.27.1
172 | - ipython==8.12.2
173 | - ipywidgets==8.1.1
174 | - isoduration==20.11.0
175 | - itsdangerous==2.1.2
176 | - jedi==0.18.2
177 | - jinja2==3.1.2
178 | - joblib==1.3.2
179 | - json5==0.9.14
180 | - jsonpointer==2.4
181 | - jsonschema==4.20.0
182 | - jsonschema-specifications==2023.11.2
183 | - jupyter==1.0.0
184 | - jupyter-client==8.6.0
185 | - jupyter-console==6.6.3
186 | - jupyter-core==5.5.1
187 | - jupyter-events==0.9.0
188 | - jupyter-lsp==2.2.1
189 | - jupyter-server==2.12.1
190 | - jupyter-server-terminals==0.5.0
191 | - jupyterlab==4.0.9
192 | - jupyterlab-pygments==0.3.0
193 | - jupyterlab-server==2.25.2
194 | - jupyterlab-widgets==3.0.9
195 | - kiwisolver==1.4.4
196 | - lazy-loader==0.2
197 | - lightning==2.0.1
198 | - lightning-cloud==0.5.37
199 | - lightning-utilities==0.9.0
200 | - lmdb==1.4.1
201 | - lpips==0.1.4
202 | - lvis==0.5.3
203 | - markdown-it-py==3.0.0
204 | - matplotlib==3.7.2
205 | - matplotlib-inline==0.1.6
206 | - mdurl==0.1.2
207 | - mistune==3.0.2
208 | - mmengine==0.10.3
209 | - multidict==6.0.4
210 | - nbclient==0.9.0
211 | - nbconvert==7.13.0
212 | - nbformat==5.9.2
213 | - nest-asyncio==1.5.8
214 | - networkx==2.8.8
215 | - notebook==7.0.6
216 | - notebook-shim==0.2.3
217 | - omegaconf==2.3.0
218 | - opencv-python==4.7.0.72
219 | - opencv-python-headless==4.10.0.84
220 | - opt-einsum==3.3.0
221 | - ordered-set==4.1.0
222 | - overrides==7.4.0
223 | - paddlepaddle-gpu==2.5.2.post116
224 | - pandocfilters==1.5.0
225 | - parso==0.8.3
226 | - pathtools==0.1.2
227 | - patsy==0.5.4
228 | - peft==0.5.0
229 | - petl==1.7.14
230 | - pexpect==4.8.0
231 | - pickleshare==0.7.5
232 | - pillow==9.5.0
233 | - pkgutil-resolve-name==1.3.10
234 | - platformdirs==4.1.0
235 | - plotly==5.18.0
236 | - portalocker==2.8.2
237 | - prometheus-client==0.19.0
238 | - prompt-toolkit==3.0.38
239 | - protobuf==3.20.3
240 | - psutil==5.9.5
241 | - ptflops==0.7
242 | - ptyprocess==0.7.0
243 | - pure-eval==0.2.2
244 | - pycocotools==2.0.7
245 | - pydantic==1.10.11
246 | - pydot==1.4.2
247 | - pygments==2.15.1
248 | - pyparsing==3.0.9
249 | - python-editor==1.0.4
250 | - python-json-logger==2.0.7
251 | - python-multipart==0.0.6
252 | - pytorch-lightning==2.0.1
253 | - pytorchcv==0.0.67
254 | - pytz==2023.3
255 | - pywavelets==1.4.1
256 | - pyyaml==6.0
257 | - pyzmq==25.1.2
258 | - qtconsole==5.5.1
259 | - qtpy==2.4.1
260 | - quadprog==0.1.11
261 | - readchar==4.0.5
262 | - referencing==0.32.0
263 | - regex==2023.10.3
264 | - requests==2.32.3
265 | - rfc3339-validator==0.1.4
266 | - rfc3986-validator==0.1.1
267 | - rich==13.4.2
268 | - rpds-py==0.15.2
269 | - safetensors==0.3.1
270 | - scikit-image==0.20.0
271 | - scikit-learn==0.24.2
272 | - scipy==1.9.1
273 | - send2trash==1.8.2
274 | - sentry-sdk==1.28.1
275 | - setproctitle==1.3.2
276 | - shapely==2.0.1
277 | - smmap==5.0.0
278 | - sniffio==1.3.0
279 | - soupsieve==2.4.1
280 | - stack-data==0.6.2
281 | - starlette==0.22.0
282 | - starsessions==1.3.0
283 | - statsmodels==0.14.1
284 | - tabulate==0.9.0
285 | - tenacity==8.2.3
286 | - tensorboardx==2.6
287 | - termcolor==2.4.0
288 | - terminado==0.18.0
289 | - thop==0.1.1-2209072238
290 | - threadpoolctl==3.2.0
291 | - tifffile==2023.4.12
292 | - timm==0.4.12
293 | - tinycss2==1.2.1
294 | - tokenizers==0.15.0
295 | - tomli==2.0.1
296 | - torchmetrics==1.0.1
297 | - tornado==6.4
298 | - traitlets==5.9.0
299 | - transformers==4.36.1
300 | - typing-extensions==4.4.0
301 | - uri-template==1.3.0
302 | - uvicorn==0.23.1
303 | - wcwidth==0.2.6
304 | - webcolors==1.13
305 | - webencodings==0.5.1
306 | - websocket-client==1.6.1
307 | - websockets==11.0.3
308 | - widgetsnbextension==4.0.9
309 | - yacs==0.1.8
310 | - yapf==0.40.2
311 | - yarl==1.9.2
312 |
313 |
--------------------------------------------------------------------------------
/net/common.py:
--------------------------------------------------------------------------------
1 | # 2021.05.07-Changed for IPT
2 | # Huawei Technologies Co., Ltd.
3 |
4 | import math
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 | def default_conv(in_channels, out_channels, kernel_size, bias=True):
11 | return nn.Conv2d(
12 | in_channels, out_channels, kernel_size,
13 | padding=(kernel_size//2), bias=bias)
14 |
15 | class MeanShift(nn.Conv2d):
16 | def __init__(
17 | self, rgb_range,
18 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1):
19 |
20 | super(MeanShift, self).__init__(3, 3, kernel_size=1)
21 | std = torch.Tensor(rgb_std)
22 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
23 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
24 | for p in self.parameters():
25 | p.requires_grad = False
26 |
27 | class BasicBlock(nn.Sequential):
28 | def __init__(
29 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False,
30 | bn=True, act=nn.ReLU(True)):
31 |
32 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)]
33 | if bn:
34 | m.append(nn.BatchNorm2d(out_channels))
35 | if act is not None:
36 | m.append(act)
37 |
38 | super(BasicBlock, self).__init__(*m)
39 |
40 | class ResBlock(nn.Module):
41 | def __init__(
42 | self, conv, n_feats, kernel_size,
43 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
44 |
45 | super(ResBlock, self).__init__()
46 | m = []
47 | for i in range(2):
48 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias))
49 | if bn:
50 | m.append(nn.BatchNorm2d(n_feats))
51 | if i == 0:
52 | m.append(act)
53 |
54 | self.body = nn.Sequential(*m)
55 | self.res_scale = res_scale
56 |
57 | def forward(self, x):
58 | res = self.body(x).mul(self.res_scale)
59 | res += x
60 |
61 | return res
62 |
63 | class Upsampler(nn.Sequential):
64 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True):
65 |
66 | m = []
67 | if (scale & (scale - 1)) == 0: # Is scale = 2^n?
68 | for _ in range(int(math.log(scale, 2))):
69 | m.append(conv(n_feats, 4 * n_feats, 3, bias))
70 | m.append(nn.PixelShuffle(2))
71 | if bn:
72 | m.append(nn.BatchNorm2d(n_feats))
73 | if act == 'relu':
74 | m.append(nn.ReLU(True))
75 | elif act == 'prelu':
76 | m.append(nn.PReLU(n_feats))
77 |
78 | elif scale == 3:
79 | m.append(conv(n_feats, 9 * n_feats, 3, bias))
80 | m.append(nn.PixelShuffle(3))
81 | if bn:
82 | m.append(nn.BatchNorm2d(n_feats))
83 | if act == 'relu':
84 | m.append(nn.ReLU(True))
85 | elif act == 'prelu':
86 | m.append(nn.PReLU(n_feats))
87 | else:
88 | raise NotImplementedError
89 |
90 | super(Upsampler, self).__init__(*m)
91 |
92 |
--------------------------------------------------------------------------------
/net/ipt.py:
--------------------------------------------------------------------------------
1 | # 2021.05.07-Changed for IPT
2 | # Huawei Technologies Co., Ltd.
3 | import os
4 |
5 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
6 |
7 | from net import common
8 | import math
9 | import torch
10 | import torch.nn.functional as F
11 | from torch import nn, Tensor
12 | import copy
13 | from torchvision import ops
14 | from matplotlib import pyplot as plt
15 | import numpy as np
16 | from functools import partial, reduce
17 | from operator import mul
18 |
19 |
20 | class LayerNorm2d(nn.Module):
21 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
22 | super().__init__()
23 | self.weight = nn.Parameter(torch.ones(num_channels))
24 | self.bias = nn.Parameter(torch.zeros(num_channels))
25 | self.eps = eps
26 |
27 | def forward(self, x: torch.Tensor) -> torch.Tensor:
28 | u = x.mean(1, keepdim=True)
29 | s = (x - u).pow(2).mean(1, keepdim=True)
30 | x = (x - u) / torch.sqrt(s + self.eps)
31 | x = self.weight[:, None, None] * x + self.bias[:, None, None]
32 | return x
33 |
34 |
35 | def tensor_to_image(tensor):
36 | if tensor.max() > 1:
37 | tensor = tensor / 255.
38 | # Ensure the input tensor is on the CPU and in the range [0, 1]
39 | tensor = tensor.cpu()
40 | # tensor = tensor.clamp(0, 1)
41 |
42 | # Convert the tensor to a NumPy array
43 | image = tensor.squeeze(0).permute(1, 2, 0).numpy()
44 | plt.imshow(image)
45 | plt.savefig('./test.jpg')
46 |
47 |
48 | class IPT(nn.Module):
49 | def __init__(self, args):
50 | super(IPT, self).__init__()
51 | self.TASK_MAP = {'lr4_noise30': 2, 'lr4_jpeg30': 2, 'sr_2': 0, 'sr_3': 1, 'sr_4': 2,
52 | 'derainH': 3,'derainL': 3, 'denoise_30': 4, 'denoise_50': 5, 'low_light': 5, }
53 | if isinstance(args.de_type,list):
54 | self.task_idx = None
55 | else:
56 | self.task_idx = self.TASK_MAP[args.de_type] if type(args.de_type) is not list else 5
57 | conv = common.default_conv
58 | self.scales = [2, 3, 4, 1, 1, 1]
59 | n_feats = 64
60 | self.patch_size = 48
61 | n_colors = 3
62 | kernel_size = 3
63 | rgb_range = 255
64 | act = nn.ReLU(True)
65 |
66 | self.sub_mean = common.MeanShift(rgb_range)
67 | self.add_mean = common.MeanShift(rgb_range, sign=1)
68 |
69 | self.head = nn.ModuleList([
70 | nn.Sequential(
71 | conv(n_colors, n_feats, kernel_size),
72 | common.ResBlock(conv, n_feats, 5, act=act),
73 | common.ResBlock(conv, n_feats, 5, act=act)
74 | ) for _ in self.scales
75 | ])
76 |
77 | self.body = VisionTransformer(img_dim=48, patch_dim=3,
78 | num_channels=n_feats, embedding_dim=n_feats * 3 * 3,
79 | num_heads=12, num_layers=12,
80 | hidden_dim=n_feats * 3 * 3 * 4, num_queries=len(self.scales),
81 | dropout_rate=0, mlp=False, pos_every=False,
82 | no_pos=False, no_norm=False)
83 |
84 | self.tail = nn.ModuleList([
85 | nn.Sequential(
86 | common.Upsampler(conv, s, n_feats, act=False),
87 | conv(n_feats, 3, kernel_size)
88 | ) for s in self.scales
89 | ])
90 |
91 | def forward(self, x, de_id=None):
92 | x = x * 255.
93 | if not self.training:
94 | return self.forward_chop(x) / 255.
95 | else:
96 | return self.forward_train(x, de_id) / 255.
97 |
98 | def forward_train(self, x,de_id=None):
99 | if de_id is not None:
100 | self.task_idx = self.TASK_MAP[de_id[0]] if type(de_id[0]) is not list else 5
101 | x = self.sub_mean(x)
102 | x = self.head[self.task_idx](x)
103 |
104 | res = self.body(x, self.task_idx)
105 | res += x
106 |
107 | x = self.tail[self.task_idx](res)
108 | x = self.add_mean(x)
109 |
110 | return x
111 |
112 | def set_scale(self, task_idx):
113 | self.task_idx = task_idx
114 |
115 | def forward_chop(self, x):
116 |         x.cpu()  # note: .cpu() is not assigned, so x keeps its original device
117 | batchsize = 64
118 | h, w = x.size()[-2:]
119 | padsize = int(self.patch_size)
120 | shave = int(self.patch_size / 2)
121 |
122 | scale = self.scales[self.task_idx]
123 |
124 | h_cut = (h - padsize) % (int(shave / 2))
125 | w_cut = (w - padsize) % (int(shave / 2))
126 |
127 | x_unfold = torch.nn.functional.unfold(x, padsize, stride=int(shave / 2)).transpose(0,
128 | 2).contiguous() # [num_patch, 48*48*3, N]
129 |
130 | x_hw_cut = x[..., (h - padsize):, (w - padsize):]
131 | y_hw_cut = self.forward_train(x_hw_cut.cuda()).cpu()
132 |
133 | x_h_cut = x[..., (h - padsize):, :]
134 | x_w_cut = x[..., :, (w - padsize):]
135 | y_h_cut = self.cut_h(x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize)
136 | y_w_cut = self.cut_w(x_w_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize)
137 |
138 | x_h_top = x[..., :padsize, :]
139 | x_w_top = x[..., :, :padsize]
140 | y_h_top = self.cut_h(x_h_top, h, w, h_cut, w_cut, padsize, shave, scale, batchsize)
141 | y_w_top = self.cut_w(x_w_top, h, w, h_cut, w_cut, padsize, shave, scale, batchsize)
142 |
143 | x_unfold = x_unfold.view(x_unfold.size(0), -1, padsize, padsize)
144 | y_unfold = []
145 |
146 | x_range = x_unfold.size(0) // batchsize + (x_unfold.size(0) % batchsize != 0)
147 |         x_unfold.cuda()  # note: .cuda() is not assigned, so x_unfold keeps its original device
148 | for i in range(x_range):
149 | y_unfold.append(self.forward_train(x_unfold[i * batchsize:(i + 1) * batchsize, ...]).cpu())
150 | y_unfold = torch.cat(y_unfold, dim=0)
151 |
152 | y = torch.nn.functional.fold(y_unfold.view(y_unfold.size(0), -1, 1).transpose(0, 2).contiguous(),
153 | ((h - h_cut) * scale, (w - w_cut) * scale), padsize * scale,
154 | stride=int(shave / 2 * scale))
155 |
156 | y[..., :padsize * scale, :] = y_h_top
157 | y[..., :, :padsize * scale] = y_w_top
158 |
159 | y_unfold = y_unfold[..., int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale),
160 | int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale)].contiguous()
161 | y_inter = torch.nn.functional.fold(y_unfold.view(y_unfold.size(0), -1, 1).transpose(0, 2).contiguous(),
162 | ((h - h_cut - shave) * scale, (w - w_cut - shave) * scale),
163 | padsize * scale - shave * scale, stride=int(shave / 2 * scale))
164 |
165 | y_ones = torch.ones(y_inter.shape, dtype=y_inter.dtype)
166 | divisor = torch.nn.functional.fold(
167 | torch.nn.functional.unfold(y_ones, padsize * scale - shave * scale, stride=int(shave / 2 * scale)),
168 | ((h - h_cut - shave) * scale, (w - w_cut - shave) * scale), padsize * scale - shave * scale,
169 | stride=int(shave / 2 * scale))
170 |
171 | y_inter = y_inter / divisor
172 |
173 | y[..., int(shave / 2 * scale):(h - h_cut) * scale - int(shave / 2 * scale),
174 | int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)] = y_inter
175 |
176 | y = torch.cat([y[..., :y.size(2) - int((padsize - h_cut) / 2 * scale), :],
177 | y_h_cut[..., int((padsize - h_cut) / 2 * scale + 0.5):, :]], dim=2)
178 | y_w_cat = torch.cat([y_w_cut[..., :y_w_cut.size(2) - int((padsize - h_cut) / 2 * scale), :],
179 | y_hw_cut[..., int((padsize - h_cut) / 2 * scale + 0.5):, :]], dim=2)
180 | y = torch.cat([y[..., :, :y.size(3) - int((padsize - w_cut) / 2 * scale)],
181 | y_w_cat[..., :, int((padsize - w_cut) / 2 * scale + 0.5):]], dim=3)
182 | return y.cuda()
183 |
184 | def cut_h(self, x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize):
185 |
186 | x_h_cut_unfold = torch.nn.functional.unfold(x_h_cut, padsize, stride=int(shave / 2)).transpose(0,
187 | 2).contiguous()
188 |
189 | x_h_cut_unfold = x_h_cut_unfold.view(x_h_cut_unfold.size(0), -1, padsize, padsize)
190 | x_range = x_h_cut_unfold.size(0) // batchsize + (x_h_cut_unfold.size(0) % batchsize != 0)
191 | y_h_cut_unfold = []
192 | x_h_cut_unfold.cuda()
193 | for i in range(x_range):
194 | y_h_cut_unfold.append(self.forward_train(x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, ...]).cpu())
195 | y_h_cut_unfold = torch.cat(y_h_cut_unfold, dim=0)
196 |
197 | y_h_cut = torch.nn.functional.fold(
198 | y_h_cut_unfold.view(y_h_cut_unfold.size(0), -1, 1).transpose(0, 2).contiguous(),
199 | (padsize * scale, (w - w_cut) * scale), padsize * scale, stride=int(shave / 2 * scale))
200 | y_h_cut_unfold = y_h_cut_unfold[..., :,
201 | int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale)].contiguous()
202 | y_h_cut_inter = torch.nn.functional.fold(
203 | y_h_cut_unfold.view(y_h_cut_unfold.size(0), -1, 1).transpose(0, 2).contiguous(),
204 | (padsize * scale, (w - w_cut - shave) * scale), (padsize * scale, padsize * scale - shave * scale),
205 | stride=int(shave / 2 * scale))
206 |
207 | y_ones = torch.ones(y_h_cut_inter.shape, dtype=y_h_cut_inter.dtype)
208 | divisor = torch.nn.functional.fold(
209 | torch.nn.functional.unfold(y_ones, (padsize * scale, padsize * scale - shave * scale),
210 | stride=int(shave / 2 * scale)), (padsize * scale, (w - w_cut - shave) * scale),
211 | (padsize * scale, padsize * scale - shave * scale), stride=int(shave / 2 * scale))
212 | y_h_cut_inter = y_h_cut_inter / divisor
213 |
214 | y_h_cut[..., :, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)] = y_h_cut_inter
215 | return y_h_cut
216 |
217 | def cut_w(self, x_w_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize):
218 | x_w_cut_unfold = torch.nn.functional.unfold(x_w_cut, padsize, stride=int(shave / 2)).transpose(0,
219 | 2).contiguous()
220 |
221 | x_w_cut_unfold = x_w_cut_unfold.view(x_w_cut_unfold.size(0), -1, padsize, padsize)
222 | x_range = x_w_cut_unfold.size(0) // batchsize + (x_w_cut_unfold.size(0) % batchsize != 0)
223 | y_w_cut_unfold = []
224 | x_w_cut_unfold.cuda()
225 | for i in range(x_range):
226 | y_w_cut_unfold.append(self.forward_train(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, ...]).cpu())
227 | y_w_cut_unfold = torch.cat(y_w_cut_unfold, dim=0)
228 |
229 | y_w_cut = torch.nn.functional.fold(
230 | y_w_cut_unfold.view(y_w_cut_unfold.size(0), -1, 1).transpose(0, 2).contiguous(),
231 | ((h - h_cut) * scale, padsize * scale), padsize * scale, stride=int(shave / 2 * scale))
232 | y_w_cut_unfold = y_w_cut_unfold[..., int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale),
233 | :].contiguous()
234 | y_w_cut_inter = torch.nn.functional.fold(
235 | y_w_cut_unfold.view(y_w_cut_unfold.size(0), -1, 1).transpose(0, 2).contiguous(),
236 | ((h - h_cut - shave) * scale, padsize * scale), (padsize * scale - shave * scale, padsize * scale),
237 | stride=int(shave / 2 * scale))
238 |
239 | y_ones = torch.ones(y_w_cut_inter.shape, dtype=y_w_cut_inter.dtype)
240 | divisor = torch.nn.functional.fold(
241 | torch.nn.functional.unfold(y_ones, (padsize * scale - shave * scale, padsize * scale),
242 | stride=int(shave / 2 * scale)), ((h - h_cut - shave) * scale, padsize * scale),
243 | (padsize * scale - shave * scale, padsize * scale), stride=int(shave / 2 * scale))
244 | y_w_cut_inter = y_w_cut_inter / divisor
245 |
246 | y_w_cut[..., int(shave / 2 * scale):(h - h_cut) * scale - int(shave / 2 * scale), :] = y_w_cut_inter
247 | return y_w_cut
248 |
249 |
250 | class VisionTransformer(nn.Module):
251 | def __init__(
252 | self,
253 | img_dim,
254 | patch_dim,
255 | num_channels,
256 | embedding_dim,
257 | num_heads,
258 | num_layers,
259 | hidden_dim,
260 | num_queries,
261 | positional_encoding_type="learned",
262 | dropout_rate=0,
263 | no_norm=False,
264 | mlp=False,
265 | pos_every=False,
266 | no_pos=False
267 | ):
268 | super(VisionTransformer, self).__init__()
269 |
270 | assert embedding_dim % num_heads == 0
271 | assert img_dim % patch_dim == 0
272 | self.no_norm = no_norm
273 | self.mlp = mlp
274 | self.embedding_dim = embedding_dim
275 | self.num_heads = num_heads
276 | self.patch_dim = patch_dim
277 | self.num_channels = num_channels
278 |
279 | self.img_dim = img_dim
280 | self.pos_every = pos_every
281 | self.num_patches = int((img_dim // patch_dim) ** 2)
282 | self.seq_length = self.num_patches
283 | self.flatten_dim = patch_dim * patch_dim * num_channels
284 |
285 | self.out_dim = patch_dim * patch_dim * num_channels
286 |
287 | self.no_pos = no_pos
288 |
289 | if self.mlp == False:
290 | self.linear_encoding = nn.Linear(self.flatten_dim, embedding_dim)
291 | self.mlp_head = nn.Sequential(
292 | nn.Linear(embedding_dim, hidden_dim),
293 | nn.Dropout(dropout_rate),
294 | nn.ReLU(),
295 | nn.Linear(hidden_dim, self.out_dim),
296 | nn.Dropout(dropout_rate)
297 | )
298 |
299 | self.query_embed = nn.Embedding(num_queries, embedding_dim * self.seq_length)
300 |
301 | encoder_layer = TransformerEncoderLayer(embedding_dim, num_heads, hidden_dim, dropout_rate, self.no_norm)
302 | self.encoder = TransformerEncoder(encoder_layer, num_layers)
303 |
304 | decoder_layer = TransformerDecoderLayer(embedding_dim, num_heads, hidden_dim, dropout_rate, self.no_norm)
305 | self.decoder = TransformerDecoder(decoder_layer, num_layers)
306 |
307 | if not self.no_pos:
308 | self.position_encoding = LearnedPositionalEncoding(
309 | self.seq_length, self.embedding_dim, self.seq_length
310 | )
311 |
312 | self.dropout_layer1 = nn.Dropout(dropout_rate)
313 |
314 | if no_norm:
315 | for m in self.modules():
316 | if isinstance(m, nn.Linear):
317 | nn.init.normal_(m.weight, std=1 / m.weight.size(1))
318 |
319 | def forward(self, x, query_idx, con=False):
320 |
321 | x = torch.nn.functional.unfold(x, self.patch_dim, stride=self.patch_dim).transpose(1, 2).transpose(0,
322 | 1).contiguous()
323 |
324 | if self.mlp == False:
325 | x = self.dropout_layer1(self.linear_encoding(x)) + x
326 |
327 | query_embed = self.query_embed.weight[query_idx].view(-1, 1, self.embedding_dim).repeat(1, x.size(1), 1)
328 | else:
329 | query_embed = None
330 |
331 | if not self.no_pos:
332 | pos = self.position_encoding(x).transpose(0, 1)
333 |
334 | if self.pos_every:
335 | x = self.encoder(x, pos=pos)
336 | x = self.decoder(x, x, pos=pos, query_pos=query_embed)
337 | elif self.no_pos:
338 | x = self.encoder(x)
339 | x = self.decoder(x, x, query_pos=query_embed)
340 | else: # here
341 | x = self.encoder(x + pos)
342 | x = self.decoder(x, x, query_pos=query_embed)
343 |
344 | if self.mlp == False:
345 | x = self.mlp_head(x) + x
346 |
347 | x = x.transpose(0, 1).contiguous().view(x.size(1), -1, self.flatten_dim)
348 |
349 | if con:
350 | con_x = x
351 | x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), int(self.img_dim), self.patch_dim,
352 | stride=self.patch_dim)
353 | return x, con_x
354 |
355 | x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), int(self.img_dim), self.patch_dim,
356 | stride=self.patch_dim)
357 |
358 | return x
359 |
360 |
361 | class LearnedPositionalEncoding(nn.Module):
362 | def __init__(self, max_position_embeddings, embedding_dim, seq_length):
363 | super(LearnedPositionalEncoding, self).__init__()
364 | self.pe = nn.Embedding(max_position_embeddings, embedding_dim)
365 | self.seq_length = seq_length
366 |
367 | self.register_buffer(
368 | "position_ids", torch.arange(self.seq_length).expand((1, -1))
369 | )
370 |
371 | def forward(self, x, position_ids=None):
372 | if position_ids is None:
373 | position_ids = self.position_ids[:, : self.seq_length]
374 |
375 | position_embeddings = self.pe(position_ids)
376 | return position_embeddings
377 |
378 |
379 | class TransformerEncoder(nn.Module):
380 |
381 | def __init__(self, encoder_layer, num_layers):
382 | super().__init__()
383 | self.layers = _get_clones(encoder_layer, num_layers)
384 | self.num_layers = num_layers
385 |         d_model = 576  # note: unused local variable
386 |
387 | def forward(self, src, pos=None):
388 | output = src
389 | for idx, layer in enumerate(self.layers):
390 | output = layer(output, pos=pos)
391 | return output
392 |
393 |
394 | class TransformerEncoderLayer(nn.Module):
395 |
396 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, no_norm=False,
397 | activation="relu"):
398 | super().__init__()
399 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)
400 | # Implementation of Feedforward model
401 | self.linear1 = nn.Linear(d_model, dim_feedforward)
402 | self.dropout = nn.Dropout(dropout)
403 | self.linear2 = nn.Linear(dim_feedforward, d_model)
404 |
405 | self.norm1 = nn.LayerNorm(d_model) if not no_norm else nn.Identity()
406 | self.norm2 = nn.LayerNorm(d_model) if not no_norm else nn.Identity()
407 | self.dropout1 = nn.Dropout(dropout)
408 | self.dropout2 = nn.Dropout(dropout)
409 | self.activation = _get_activation_fn(activation)
410 |
411 | self.adaptir = AdaptIR(d_model)
412 |
413 | nn.init.kaiming_uniform_(self.self_attn.in_proj_weight, a=math.sqrt(5))
414 |
415 | def with_pos_embed(self, tensor, pos):
416 | return tensor if pos is None else tensor + pos
417 |
418 | def forward(self, src, pos=None):
419 | src2 = self.norm1(src)
420 | q = k = self.with_pos_embed(src2, pos)
421 | src2 = self.self_attn(q, k, src2)
422 |
423 | src = src + self.dropout1(src2[0])
424 | src2 = self.norm2(src)
425 | adapt = self.adaptir(src2)
426 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
427 | src = src + self.dropout2(src2+adapt)
428 | return src
429 |
430 |
431 | class TransformerDecoder(nn.Module):
432 | def __init__(self, decoder_layer, num_layers):
433 | super().__init__()
434 | self.layers = _get_clones(decoder_layer, num_layers)
435 | self.num_layers = num_layers
436 |
437 | def forward(self, tgt, memory, pos=None, query_pos=None):
438 | output = tgt
439 | for idx, layer in enumerate(self.layers):
440 | output = layer(output, memory, pos=pos, query_pos=query_pos)
441 | return output
442 |
443 |
444 | class TransformerDecoderLayer(nn.Module):
445 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, no_norm=False,
446 | activation="relu"):
447 | super().__init__()
448 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)
449 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False)
450 | # Implementation of Feedforward model
451 | self.linear1 = nn.Linear(d_model, dim_feedforward)
452 | self.dropout = nn.Dropout(dropout)
453 | self.linear2 = nn.Linear(dim_feedforward, d_model)
454 |
455 | self.norm1 = nn.LayerNorm(d_model) if not no_norm else nn.Identity()
456 | self.norm2 = nn.LayerNorm(d_model) if not no_norm else nn.Identity()
457 | self.norm3 = nn.LayerNorm(d_model) if not no_norm else nn.Identity()
458 | self.dropout1 = nn.Dropout(dropout)
459 | self.dropout2 = nn.Dropout(dropout)
460 | self.dropout3 = nn.Dropout(dropout)
461 |
462 | self.activation = _get_activation_fn(activation)
463 |
464 | self.adaptir = AdaptIR(d_model)
465 |
466 | def with_pos_embed(self, tensor, pos):
467 | if pos is not None and pos.shape[0] < tensor.shape[0]:
468 | pos = torch.cat(
469 | [torch.zeros(tensor.shape[0] - pos.shape[0], tensor.shape[1], tensor.shape[2]).to(pos.device), pos],
470 | dim=0)
471 | return tensor if pos is None else tensor + pos
472 |
473 | def forward(self, tgt, memory, pos=None, query_pos=None):
474 | tgt2 = self.norm1(tgt)
475 | q = k = self.with_pos_embed(tgt2, query_pos)
476 | tgt2 = self.self_attn(q, k, value=tgt2)[0]
477 | tgt = tgt + self.dropout1(tgt2)
478 |
479 | tgt2 = self.norm2(tgt)
480 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
481 | key=self.with_pos_embed(memory, pos),
482 | value=memory)[0]
483 | tgt = tgt + self.dropout2(tgt2)
484 |
485 | tgt2 = self.norm3(tgt)
486 | adapt = self.adaptir(tgt2)
487 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
488 |
489 | tgt = tgt + self.dropout3(tgt2+adapt)
490 | return tgt
491 |
492 |
493 | def _get_clones(module, N):
494 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
495 |
496 |
497 | def _get_activation_fn(activation):
498 | """Return an activation function given a string"""
499 | if activation == "relu":
500 | return F.relu
501 | if activation == "gelu":
502 | return F.gelu
503 | if activation == "glu":
504 | return F.glu
505 |     raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
506 |
507 |
508 | class AdaptIR(nn.Module):
509 | def __init__(self, d_model):
510 | super(AdaptIR, self).__init__()
511 | self.hidden = d_model // 24
512 | self.rank = self.hidden // 2
513 | self.kernel_size = 3
514 | self.group = self.hidden
515 | self.head = nn.Conv2d(d_model, self.hidden, 1, 1)
516 |
517 | self.BN = nn.BatchNorm2d(self.hidden)
518 |
519 | self.conv_weight_A = nn.Parameter(torch.randn(self.hidden, self.rank))
520 | self.conv_weight_B = nn.Parameter(
521 | torch.randn(self.rank, self.hidden // self.group * self.kernel_size * self.kernel_size))
522 | self.conv_bias = nn.Parameter(torch.zeros(self.hidden))
523 | nn.init.kaiming_uniform_(self.conv_weight_A, a=math.sqrt(5))
524 | nn.init.kaiming_uniform_(self.conv_weight_B, a=math.sqrt(5))
525 |
526 | self.amp_fuse = nn.Conv2d(self.hidden, self.hidden, 1, 1, groups=self.hidden)
527 | self.pha_fuse = nn.Conv2d(self.hidden, self.hidden, 1, 1, groups=self.hidden)
528 | nn.init.ones_(self.pha_fuse.weight)
529 | nn.init.ones_(self.amp_fuse.weight)
530 | nn.init.zeros_(self.amp_fuse.bias)
531 | nn.init.zeros_(self.pha_fuse.bias)
532 |
533 | self.compress = nn.Conv2d(self.hidden, 1, 1, 1)
534 | self.proj = nn.Sequential(
535 | nn.Linear(self.hidden, self.hidden // 2),
536 | nn.GELU(),
537 | nn.Linear(self.hidden // 2, self.hidden),
538 | )
539 | self.tail = nn.Conv2d(self.hidden, d_model, 1, 1, bias=False)
540 | nn.init.zeros_(self.tail.weight)
541 |
542 |
543 | self.channel_interaction = nn.Sequential(
544 | nn.AdaptiveAvgPool2d(1),
545 | nn.Conv2d(self.hidden, self.hidden // 8, kernel_size=1),
546 | nn.GELU(),
547 | nn.Conv2d(self.hidden // 8,self.hidden, kernel_size=1)
548 | )
549 | nn.init.zeros_(self.channel_interaction[3].weight)
550 | nn.init.zeros_(self.channel_interaction[3].bias)
551 |
552 |
553 | self.spatial_interaction = nn.Conv2d(self.hidden, 1, kernel_size=1)
554 | nn.init.zeros_(self.spatial_interaction.weight)
555 | nn.init.zeros_(self.spatial_interaction.bias)
556 |
557 |
558 | def forward(self, x):
559 | L, N, C = x.shape
560 |         H = W = int(math.sqrt(L))  # assumes a square token grid: L == H * W with H == W
561 | x = x.view(H, W, N, C).permute(2, 3, 0, 1).contiguous() # N,C,H,W
562 | x = self.BN(self.head(x))
563 |
564 | global_x = torch.fft.rfft2(x, dim=(2, 3), norm='ortho')
565 | mag_x = torch.abs(global_x)
566 | pha_x = torch.angle(global_x)
567 | Mag = self.amp_fuse(mag_x)
568 | Pha = self.pha_fuse(pha_x)
569 | real = Mag * torch.cos(Pha)
570 | imag = Mag * torch.sin(Pha)
571 | global_x = torch.complex(real, imag)
572 | global_x = torch.fft.irfft2(global_x, s=(H, W), dim=(2, 3), norm='ortho') # N,C,H,W
573 | global_x = torch.abs(global_x)
574 |
575 | conv_weight = (self.conv_weight_A @ self.conv_weight_B) \
576 | .view(self.hidden, self.hidden // self.group, self.kernel_size, self.kernel_size).contiguous()
577 | local_x = F.conv2d(x, weight=conv_weight, bias=self.conv_bias, stride=1, padding=1, groups=self.group)
578 |
579 | score = self.compress(x).view(N, 1, H * W).permute(0, 2, 1).contiguous() # N,HW,1
580 | score = F.softmax(score, dim=1)
581 | out = x.view(N, self.hidden, H * W) # N,C,HW
582 | out = out @ score # N,C,1
583 | out = out.permute(2, 0, 1) # 1,N,C
584 | out = self.proj(out)
585 | channel_score = out.permute(1, 2, 0).unsqueeze(-1).contiguous() # N,C,1,1
586 |
587 | channel_gate = self.channel_interaction(global_x).sigmoid()
588 | spatial_gate = self.spatial_interaction(local_x).sigmoid()
589 | spatial_x = channel_gate*local_x+spatial_gate*global_x
590 |
591 | x = self.tail(channel_score*spatial_x)
592 | x = x.view(N, C, H * W).permute(2, 0, 1).contiguous()
593 | return x
594 |
595 |
596 |
--------------------------------------------------------------------------------
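A minimal shape-check sketch for the AdaptIR adapter defined in net/ipt.py above (an illustration, not code from the repository): AdaptIR.forward expects a sequence-first tensor of shape (L, N, C) whose length L is a perfect square, and returns a tensor of the same shape so it can be added residually to the feed-forward branch.

    import torch
    from net.ipt import AdaptIR  # assumes the repository root is on the import path

    adapter = AdaptIR(d_model=576)      # hidden width = 576 // 24 = 24 channels
    tokens = torch.randn(144, 2, 576)   # L=144 (a 12x12 patch grid), batch N=2, C=576
    out = adapter(tokens)
    print(out.shape)                    # torch.Size([144, 2, 576]) -- same shape as the input
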
/options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 |
5 | # Input Parameters
6 | parser.add_argument('--cuda', type=int, default=0)
7 | parser.add_argument('--arch', type=str, default='IPT',choices=['IPT','EDT'])
8 | parser.add_argument('--de_type', nargs='+', default='lr4_noise30',
9 |                     choices=['lr4_noise30', 'lr4_jpeg30', 'sr_2', 'sr_3', 'sr_4', 'denoise_30',
10 |                              'denoise_50', 'derainL', 'derainH', 'low_light'],
11 |                     help='which type of degradation to train and test on.')
12 | parser.add_argument('--patch_size', type=int, default=48, help='patch size of the input.')
13 | parser.add_argument('--epochs', type=int, default=500, help='maximum number of training epochs.')
14 | parser.add_argument('--val_every_n_epoch', type=int, default=25)
15 | parser.add_argument('--batch_size', type=int,default=16,help="Batch size to use per GPU")
16 | parser.add_argument('--lr', type=float,default=1e-4,help="learning rate")
17 |
18 | parser.add_argument('--num_workers', type=int, default=8, help='number of workers.')
19 | # path
20 | parser.add_argument('--data_file_dir', type=str, default='./data_dir/', help='directory containing the data file lists (e.g., the clean-image list used for denoising).')
21 | parser.add_argument('--dataset_dir', type=str, default='/data/guohang/dataset',
22 |                     help='root directory where the training datasets are saved.')
23 | parser.add_argument('--output_path', type=str, default="./output/", help='output save path')
24 | parser.add_argument("--wblogger",type=str,default=None,help = "Determine to log to wandb or not and the project name")
25 | parser.add_argument("--ckpt_dir",type=str,default="./train_ckpt",help = "Name of the Directory where the checkpoint is to be saved")
26 | parser.add_argument("--num_gpus",type=int,default=4,help = "Number of GPUs to use for training")
27 |
28 | options = parser.parse_args()
29 |
30 |
--------------------------------------------------------------------------------
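A hedged sketch of consuming the parser above outside of train.py (for example in a notebook). options.py calls parse_args() at import time, so the flags are injected into sys.argv before the import; the flag values below are illustrative only.

    import sys
    sys.argv = ['train.py', '--arch', 'IPT', '--de_type', 'derainH', '--num_gpus', '1']
    from options import options as opt
    print(opt.arch, opt.de_type, opt.num_gpus)   # IPT ['derainH'] 1
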
/test.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import subprocess
3 | from tqdm import tqdm
4 | import numpy as np
5 | import torch
6 | from torch.utils.data import DataLoader
7 | import os
8 | import torch.nn as nn
9 | import torch.optim as optim
10 | from utils.dataset_utils import DenoiseTestDataset, DerainLowlightDataset,SRHybridTestDataset
11 | from utils.val_utils import AverageMeter, compute_psnr_ssim
12 | from utils.image_io import save_image_tensor
13 | from net.ipt import IPT
14 | import lightning.pytorch as pl
15 | import torch.nn.functional as F
16 | from utils.schedulers import LinearWarmupCosineAnnealingLR
17 | from utils.common import calculate_psnr_ssim
18 | from net.edt import EDT
19 | device = torch.device('cuda')
20 | from matplotlib import pyplot as plt
21 |
22 |
23 |
24 | class MultiTaskIRModel(pl.LightningModule):
25 | def __init__(self,args):
26 | super().__init__()
27 | self.args = args
28 | if args.arch == 'IPT':
29 | self.net = IPT(args)
30 | state_dict = torch.load('/data/guohang/pretrained/IPT_pretrain.pt')
31 | self.net.load_state_dict(state_dict, strict=False)
32 | elif args.arch == 'EDT':
33 | self.net = EDT(args)
34 | state_dict = torch.load('/data/guohang/pretrained/SRx2x3x4_EDTB_ImageNet200K.pth')
35 | self.net.load_state_dict(state_dict, strict=False)
36 |         self.loss_fn = nn.L1Loss()  # mirrors train.py; used by training_step below
37 |
38 | def forward(self,x):
39 | return self.net(x)
40 |
41 | def training_step(self, batch, batch_idx):
42 | ([clean_name, de_id], degrad_patch, clean_patch) = batch
43 | restored = self.net(degrad_patch)
44 |
45 | loss = self.loss_fn(restored,clean_patch)
46 | self.log("train_loss", loss)
47 | return loss
48 |
49 | def lr_scheduler_step(self,scheduler,metric):
50 | scheduler.step(self.current_epoch)
51 |
52 |
53 | def configure_optimizers(self):
54 | optimizer = optim.AdamW(self.parameters(), lr=1e-5)
55 | scheduler = LinearWarmupCosineAnnealingLR(optimizer=optimizer,warmup_epochs=15,max_epochs=150)
56 |
57 | return [optimizer],[scheduler]
58 |
59 |
60 |
61 | def test_Denoise(net, dataset, sigma=15):
62 | dataset.set_sigma(sigma)
63 | testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0)
64 |
65 | psnr = AverageMeter()
66 | ssim = AverageMeter()
67 |
68 | with torch.no_grad():
69 | for ([clean_name], degrad_patch, clean_patch) in tqdm(testloader):
70 | degrad_patch, clean_patch = degrad_patch.to(device), clean_patch.to(device)
71 | restored = net(degrad_patch)
72 | temp_psnr, temp_ssim, N = compute_psnr_ssim(restored, clean_patch)
73 | psnr.update(temp_psnr, N)
74 | ssim.update(temp_ssim, N)
75 |
76 | print("Denoise sigma=%d: psnr: %.2f, ssim: %.4f" % (sigma, psnr.avg, ssim.avg))
77 |
78 |
79 |
80 | def test_Derain_LowLight(net, dataset, task="derain"):
81 | dataset.set_dataset(task)
82 | testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0)
83 |
84 | psnr = AverageMeter()
85 | ssim = AverageMeter()
86 |
87 | with torch.no_grad():
88 | for ([degraded_name], degrad_patch, clean_patch) in tqdm(testloader):
89 | degrad_patch, clean_patch = degrad_patch.to(device), clean_patch.to(device)
90 | restored = net(degrad_patch)
91 | to_y = True if 'derain' in task else False
92 | temp_psnr, temp_ssim, N = compute_psnr_ssim(restored, clean_patch,to_y=to_y)
93 | psnr.update(temp_psnr, N)
94 | ssim.update(temp_ssim, N)
95 |
96 | print("PSNR: %.2f, SSIM: %.4f" % (psnr.avg, ssim.avg))
97 |
98 |
99 |
100 |
101 | def test_SR(net,dataset,scale):
102 | dataset.set_scale(scale)
103 | testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0)
104 |
105 | psnr = AverageMeter()
106 | ssim = AverageMeter()
107 |
108 | with torch.no_grad():
109 | for ([clean_name], degrad_patch, clean_patch) in tqdm(testloader):
110 | degrad_patch, clean_patch = degrad_patch.to(device), clean_patch.to(device)
111 | restored = net(degrad_patch)
112 |
113 | temp_psnr, temp_ssim, N = compute_psnr_ssim(restored, clean_patch,to_y=True,bd=scale)
114 |
115 | psnr.update(temp_psnr, N)
116 | ssim.update(temp_ssim, N)
117 |
118 | print("SR scale=%d: psnr: %.2f, ssim: %.4f" % (scale, psnr.avg, ssim.avg))
119 |
120 |
121 |
122 |
123 |
124 | def test_hybrid_degradation(net,dataset,scale):
125 | dataset.set_scale(scale)
126 | testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0)
127 | psnr = AverageMeter()
128 | ssim = AverageMeter()
129 |
130 | with torch.no_grad():
131 | for ([clean_name], degrad_patch, clean_patch) in tqdm(testloader):
132 | degrad_patch, clean_patch = degrad_patch.to(device), clean_patch.to(device)
133 | restored = net(degrad_patch)
134 | temp_psnr, temp_ssim, N = compute_psnr_ssim(restored, clean_patch,to_y=True,bd=scale)
135 | psnr.update(temp_psnr, N)
136 | ssim.update(temp_ssim, N)
137 | print("SR scale=%d: psnr: %.2f, ssim: %.4f" % (scale, psnr.avg, ssim.avg))
138 |
139 |
140 |
141 |
142 | if __name__ == '__main__':
143 | parser = argparse.ArgumentParser()
144 | # Input Parameters
145 | parser.add_argument('--cuda', type=int, default=0)
146 | parser.add_argument('--arch', type=str, default='IPT', choices=['IPT', 'EDT'])
147 | parser.add_argument('--de_type', nargs='+', default='lr4_noise30',
148 |                         choices=['lr4_noise30', 'lr4_jpeg30', 'sr_2', 'sr_3', 'sr_4', 'denoise_30',
149 |                                  'denoise_50', 'derainL', 'derainH', 'low_light'],
150 |                         help='which type of degradation to train and test on.')
151 | parser.add_argument('--output_path', type=str, default="./test_output/", help='output save path')
152 | parser.add_argument('--base_path', type=str, default="/data/guohang/dataset", help='save path of test noisy images')
153 | parser.add_argument('--ckpt_name', type=str, default='/data/guohang/AdaptIR/train_ckpt/last.ckpt', help='checkpoint save path')
154 | testopt = parser.parse_args()
155 |
156 |
157 |
158 | np.random.seed(0)
159 | torch.manual_seed(0)
160 | torch.cuda.set_device(testopt.cuda)
161 |
162 |
163 | ckpt_path = testopt.ckpt_name
164 |
165 | denoise_splits = ["ColorDN/Urban100HQ"]
166 | derainH_splits = ["Rain100H/"]
167 | derainL_splits = ["Rain100L/"]
168 | low_light_splits = ["LOLv1/Test"]
169 | hybrid_splits = ['Set5','Set14','Urban100','B100','Manga109']
170 | sr_splits = ['Set5','Set14','Urban100','B100','Manga109']
171 |
172 | denoise_tests = []
173 | derain_tests = []
174 | sr_tests = []
175 |
176 |
177 | print("CKPT name : {}".format(ckpt_path))
178 |
179 |
180 | net = MultiTaskIRModel(testopt)
181 | ckpt = torch.load(ckpt_path)
182 | net.load_state_dict(ckpt['state_dict'])
183 | net.eval()
184 | net.to(device)
185 |
186 |
187 | if 'denoise' in testopt.de_type:
188 | base_path = testopt.base_path
189 | for i in denoise_splits:
190 | testopt.denoise_path = os.path.join(base_path, i)
191 | denoise_testset = DenoiseTestDataset(testopt)
192 | denoise_tests.append(denoise_testset)
193 | for testset,name in zip(denoise_tests,denoise_splits):
194 | if 'denoise_30' in testopt.de_type:
195 | print('Start {} testing Sigma=30...'.format(name))
196 | test_Denoise(net, testset, sigma=30)
197 | if 'denoise_50' in testopt.de_type:
198 | print('Start {} testing Sigma=50...'.format(name))
199 | test_Denoise(net, testset, sigma=50)
200 |
201 | elif 'derainL' in testopt.de_type:
202 | print('Start testing light rain streak removal...')
203 | derain_base_path = testopt.base_path
204 | for name in derainL_splits:
205 | print('Start testing {} rain streak removal...'.format(name))
206 | testopt.derain_path = os.path.join(derain_base_path,name)
207 | derain_set = DerainLowlightDataset(testopt)
208 | test_Derain_LowLight(net, derain_set, task="derainL")
209 |
210 | elif 'derainH' in testopt.de_type:
211 | print('Start testing heavy rain streak removal...')
212 | derain_base_path = testopt.base_path
213 | for name in derainH_splits:
214 | print('Start testing {} rain streak removal...'.format(name))
215 | testopt.derain_path = os.path.join(derain_base_path,name)
216 | derain_set = DerainLowlightDataset(testopt)
217 | test_Derain_LowLight(net, derain_set, task="derainH")
218 |
219 | elif 'low_light' in testopt.de_type:
220 |         print('Start testing low-light enhancement...')
221 | low_light_base_path = testopt.base_path
222 | for name in low_light_splits:
223 | print('Start testing {} low light enhancement...'.format(name))
224 | testopt.low_light_path = os.path.join(low_light_base_path,name)
225 | lowlight_set = DerainLowlightDataset(testopt)
226 | test_Derain_LowLight(net, lowlight_set, task="low_light")
227 |
228 |
229 | elif 'sr' in testopt.de_type:
230 | print('Start testing super-resolution...')
231 | sr_base_path = testopt.base_path
232 | for name in sr_splits:
233 | print('Start testing {} super-resolution...'.format(name))
234 | testopt.sr_path = os.path.join(sr_base_path,'ARTSR',name,'HR')
235 | sr_set = SRHybridTestDataset(testopt)
236 | sr_tests.append(sr_set)
237 | for testset,name in zip(sr_tests,sr_splits):
238 | if 'sr_2' in testopt.de_type:
239 | print('Start {} testing SRx2...'.format(name))
240 | test_SR(net, testset,scale=2)
241 | if 'sr_3' in testopt.de_type:
242 | print('Start {} testing SRx3...'.format(name))
243 | test_SR(net, testset,scale=3)
244 | if 'sr_4' in testopt.de_type:
245 | print('Start {} testing SRx4...'.format(name))
246 | test_SR(net, testset,scale=4)
247 |
248 |
249 | elif 'lr4_noise30' in testopt.de_type:
250 |         print('Start testing hybrid degradation (LR4 + Noise30)...')
251 | sr_base_path = testopt.base_path
252 | for name in sr_splits:
253 | print('Start testing {} LR4+Noise30...'.format(name))
254 | testopt.sr_path = os.path.join(sr_base_path,'ARTSR',name,'HR')
255 | sr_set = SRHybridTestDataset(testopt)
256 | sr_tests.append(sr_set)
257 | test_SR(net, sr_set, scale=4)
258 |
259 |
260 | elif 'lr4_jpeg30' in testopt.de_type:
261 |         print('Start testing hybrid degradation (LR4 + JPEG30)...')
262 | sr_base_path = testopt.base_path
263 | for name in sr_splits:
264 | print('Start testing {} LR4+JPEG30...'.format(name))
265 | testopt.sr_path = os.path.join(sr_base_path,'ARTSR',name,'HR')
266 | sr_set = SRHybridTestDataset(testopt)
267 | sr_tests.append(sr_set)
268 | test_SR(net, sr_set, scale=4)
269 |
270 |
271 | else:
272 | raise NotImplementedError
273 |
274 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import torch.nn as nn
4 | import torch.optim as optim
5 | from torch.utils.data import DataLoader
6 | from utils.dataset_utils import PromptTrainDataset
7 | from net.ipt import IPT
8 | from net.edt import EDT
9 | from utils.schedulers import LinearWarmupCosineAnnealingLR
10 | from torch.optim.lr_scheduler import MultiStepLR
11 | import numpy as np
12 | from options import options as opt
13 | import lightning.pytorch as pl
14 | from lightning.pytorch.loggers import WandbLogger,TensorBoardLogger
15 | from lightning.pytorch.callbacks import ModelCheckpoint
16 | from utils.val_utils import AverageMeter, compute_psnr_ssim
17 |
18 |
19 | class MultiTaskIRModel(pl.LightningModule):
20 | def __init__(self,args):
21 | super().__init__()
22 | print('load from pretrained model...')
23 | self.args = args
24 | if args.arch == 'IPT':
25 | self.net = IPT(args)
26 | state_dict = torch.load('/data/guohang/pretrained/IPT_pretrain.pt')
27 | self.net.load_state_dict(state_dict, strict=False)
28 | elif args.arch == 'EDT':
29 | self.net = EDT(args)
30 | state_dict = torch.load('/data/guohang/pretrained/SRx2x3x4_EDTB_ImageNet200K.pth')
31 | self.net.load_state_dict(state_dict, strict=False)
32 |         print('freeze backbone parameters: only the AdaptIR adapter modules are trainable (head and tail are NOT trainable)')
33 | for name, param in self.net.named_parameters():
34 | if "adaptir" not in name:
35 | param.requires_grad = False
36 |
37 | # for name, param in self.net.named_parameters():
38 | # if param.requires_grad:
39 | # print(name)
40 |
41 | self.loss_fn = nn.L1Loss()
42 | self.save_hyperparameters()
43 |
44 |
45 | def forward(self,x):
46 | return self.net(x)
47 |
48 | def training_step(self, batch, batch_idx):
49 | ([clean_name, de_id], degrad_patch, clean_patch) = batch
50 | restored = self.net(degrad_patch,de_id)
51 | loss = self.loss_fn(restored,clean_patch)
52 | self.log("train_loss", loss)
53 | return loss
54 |
55 |
56 |
57 | # def validation_step(self,batch, batch_idx):
58 | # ([clean_name], degrad_patch, clean_patch) = batch
59 | # restored = self.net(degrad_patch)
60 | # psnr_to_y = False if 'denoise' in self.args.de_type else True
61 | # temp_psnr, temp_ssim, N = compute_psnr_ssim(restored, clean_patch,to_y=psnr_to_y)
62 | # self.log('psnr',temp_psnr,on_epoch=True,sync_dist=True)
63 | # return temp_psnr
64 |
65 |
66 |
67 | def lr_scheduler_step(self,scheduler,metric):
68 | scheduler.step()
69 |
70 |
71 | def configure_optimizers(self):
72 | optimizer = optim.AdamW(self.parameters(), lr=self.args.lr)
73 | scheduler = MultiStepLR(optimizer, milestones=[250,400,450,475], gamma=0.5)
74 | return [optimizer],[scheduler]
75 |
76 |
77 | def main():
78 | logger = TensorBoardLogger(save_dir = "./logs")
79 | trainset = PromptTrainDataset(opt)
80 | checkpoint_callback = ModelCheckpoint(dirpath = opt.ckpt_dir,
81 | save_last=True,)
82 |
83 | trainloader = DataLoader(trainset, batch_size=opt.batch_size, pin_memory=True, shuffle=True,
84 | drop_last=True, num_workers=opt.num_workers)
85 |
86 | model = MultiTaskIRModel(opt)
87 |
88 | trainer = pl.Trainer(max_epochs=opt.epochs,accelerator="gpu",
89 | devices=opt.num_gpus,
90 | logger=logger,callbacks=[checkpoint_callback])
91 |
92 | trainer.fit(model=model, train_dataloaders=trainloader)
93 |
94 |
95 | if __name__ == '__main__':
96 | main()
97 |
98 |
99 |
100 |
--------------------------------------------------------------------------------
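A small sanity-check sketch (not part of the repository) for the freezing logic in MultiTaskIRModel above. It assumes the default options are sufficient to build the chosen backbone and that the hard-coded pretrained checkpoint path in __init__ exists; after construction, only parameters whose names contain "adaptir" should still require gradients.

    from options import options as opt
    from train import MultiTaskIRModel

    model = MultiTaskIRModel(opt)   # loads the frozen backbone and the trainable AdaptIR adapters
    trainable = [n for n, p in model.net.named_parameters() if p.requires_grad]
    assert all('adaptir' in n for n in trainable)
    print(len(trainable), 'trainable tensors,',
          sum(p.numel() for _, p in model.net.named_parameters() if p.requires_grad), 'parameters')
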
/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/utils/__init__.py
--------------------------------------------------------------------------------
/utils/common.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from datetime import datetime
3 | import logging
4 | import math
5 | import numpy as np
6 | import os
7 | import random
8 | from shutil import get_terminal_size
9 | import sys
10 | import time
11 |
12 | import torch
13 | from torchvision.utils import make_grid
14 |
15 |
16 | def mkdir(path):
17 | if not os.path.exists(path):
18 | os.makedirs(path)
19 |
20 |
21 | def get_timestamp():
22 | return datetime.now().strftime('%y%m%d-%H%M%S')
23 |
24 |
25 | def setup_logger(logger_name, save_dir, phase, level=logging.INFO, screen=False, to_file=False):
26 | lg = logging.getLogger(logger_name)
27 | formatter = logging.Formatter('[%(asctime)s.%(msecs)03d - %(levelname)s]: %(message)s',
28 | datefmt='%y-%m-%d %H:%M:%S')
29 | lg.setLevel(level)
30 | if to_file:
31 | log_file = os.path.join(save_dir, '{}_{}.txt'.format(get_timestamp(),phase))
32 | fh = logging.FileHandler(log_file, mode='w')
33 | fh.setFormatter(formatter)
34 | lg.addHandler(fh)
35 | if screen:
36 | sh = logging.StreamHandler()
37 | sh.setFormatter(formatter)
38 | lg.addHandler(sh)
39 |
40 |
41 | def init_random_seed(seed=0):
42 | np.random.seed(seed)
43 | random.seed(seed)
44 |
45 | # default values of benchmark and deterministic are False
46 | # # for reproducibility
47 | # torch.backends.cudnn.benchmark = False
48 | # torch.backends.cudnn.deterministic = True
49 | # for speed
50 | torch.backends.cudnn.benchmark = True
51 | torch.backends.cudnn.deterministic = False
52 |
53 | torch.manual_seed(seed)
54 | torch.cuda.manual_seed(seed)
55 | torch.cuda.manual_seed_all(seed)
56 |
57 |
58 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)):
59 | # 4D: grid (B, C, H, W), 3D: (C, H, W), 2D: (H, W)
60 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max)
61 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0])
62 |
63 | n_dim = tensor.dim()
64 | if n_dim == 4:
65 | n_img = len(tensor)
66 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), padding=0, normalize=False).numpy()
67 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
68 | elif n_dim == 3:
69 | img_np = tensor.numpy()
70 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0))
71 | elif n_dim == 2:
72 | img_np = tensor.numpy()
73 | else:
74 | raise TypeError(
75 | 'Only support 4D, 3D and 2D tensor. But received with dimension: {:d}'.format(n_dim))
76 |
77 | if out_type == np.uint8:
78 | img_np = (img_np * 255.0).round()
79 |
80 | return img_np.astype(out_type)
81 |
82 |
83 | def calculate_psnr_ssim(img1, img2, to_y=True, bd=0):
84 | img1 = img1.astype(np.float64)
85 | img2 = img2.astype(np.float64)
86 | if to_y:
87 | img1 = bgr2ycbcr(img1 / 255.0, only_y=True) * 255.0
88 | img2 = bgr2ycbcr(img2 / 255.0, only_y=True) * 255.0
89 | if bd != 0:
90 | img1 = img1[bd:-bd, bd:-bd]
91 | img2 = img2[bd:-bd, bd:-bd]
92 | psnr = calculate_psnr(img1, img2)
93 | ssim = calculate_ssim(img1, img2)
94 |
95 | return psnr, ssim
96 |
97 | def calculate_psnr(img1, img2):
98 | # img1 and img2 have range [0, 255]
99 | img1 = img1.astype(np.float64)
100 | img2 = img2.astype(np.float64)
101 | mse = np.mean((img1 - img2) ** 2)
102 | if mse == 0:
103 | return float('inf')
104 |
105 | return 20 * math.log10(255.0 / math.sqrt(mse))
106 |
107 |
108 | def calculate_ssim(img1, img2):
109 | C1 = (0.01 * 255) ** 2
110 | C2 = (0.03 * 255) ** 2
111 |
112 | img1 = img1.astype(np.float64)
113 | img2 = img2.astype(np.float64)
114 | kernel = cv2.getGaussianKernel(11, 1.5)
115 | window = np.outer(kernel, kernel.transpose())
116 |
117 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
118 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
119 | mu1_sq = mu1 ** 2
120 | mu2_sq = mu2 ** 2
121 | mu1_mu2 = mu1 * mu2
122 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
123 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
124 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
125 |
126 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
127 |
128 | return ssim_map.mean()
129 |
130 |
131 | def calc_psnr(sr, hr, scale=2, rgb_range=1.0, benchmark=True):
132 | # benchmark: to Y channel
133 | diff = (sr - hr).data.div(rgb_range)
134 | if benchmark:
135 | shave = scale
136 | if diff.size(1) > 1:
137 | convert = diff.new(1, 3, 1, 1)
138 | convert[0, 0, 0, 0] = 65.738
139 | convert[0, 1, 0, 0] = 129.057
140 | convert[0, 2, 0, 0] = 25.064
141 | diff.mul_(convert).div_(256)
142 | diff = diff.sum(dim=1, keepdim=True)
143 | else:
144 | shave = scale + 6
145 |
146 | valid = diff[:, :, shave:-shave, shave:-shave]
147 | mse = valid.pow(2).mean()
148 |
149 | return -10 * math.log10(mse)
150 |
151 |
152 | def rgb2ycbcr(img, only_y=True):
153 | '''same as matlab rgb2ycbcr
154 | only_y: only return Y channel
155 | Input:
156 | uint8, [0, 255]
157 | float, [0, 1]
158 | '''
159 | in_img_type = img.dtype
160 |     img = img.astype(np.float32)  # astype() returns a copy; assign it back
161 | if in_img_type != np.uint8:
162 | img *= 255.
163 | # convert
164 | if only_y:
165 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
166 | else:
167 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
168 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
169 | if in_img_type == np.uint8:
170 | rlt = rlt.round()
171 | else:
172 | rlt /= 255.
173 | return rlt.astype(in_img_type)
174 |
175 |
176 | def bgr2ycbcr(img, only_y=True):
177 | '''bgr version of rgb2ycbcr
178 | only_y: only return Y channel
179 | Input:
180 | uint8, [0, 255]
181 | float, [0, 1]
182 | '''
183 | in_img_type = img.dtype
184 |     img = img.astype(np.float32)  # astype() returns a copy; assign it back
185 | if in_img_type != np.uint8:
186 | img *= 255.
187 | # convert
188 | if only_y:
189 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
190 | else:
191 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
192 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
193 | if in_img_type == np.uint8:
194 | rlt = rlt.round()
195 | else:
196 | rlt /= 255.
197 | return rlt.astype(in_img_type)
198 |
199 |
200 | def ycbcr2rgb(img):
201 | '''same as matlab ycbcr2rgb
202 | Input:
203 | uint8, [0, 255]
204 | float, [0, 1]
205 | '''
206 | in_img_type = img.dtype
207 |     img = img.astype(np.float32)  # astype() returns a copy; assign it back
208 | if in_img_type != np.uint8:
209 | img *= 255.
210 | # convert
211 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
212 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
213 | if in_img_type == np.uint8:
214 | rlt = rlt.round()
215 | else:
216 | rlt /= 255.
217 | return rlt.astype(in_img_type)
218 |
219 |
220 | def flipx4_forward(model, inp):
221 | """Flip testing with X4 self ensemble, i.e., normal, flip H, flip W, flip H and W
222 | Args:
223 | model (PyTorch model)
224 | inp (Tensor): inputs defined by the model
225 |
226 | Returns:
227 | output (Tensor): outputs of the model. float, in CPU
228 | """
229 | with torch.no_grad():
230 | # normal
231 | output_f = model(inp)
232 |
233 | # flip W
234 | output = model(torch.flip(inp, (-1, )))
235 | output_f = output_f + torch.flip(output, (-1, ))
236 | # flip H
237 | output = model(torch.flip(inp, (-2, )))
238 | output_f = output_f + torch.flip(output, (-2, ))
239 | # flip both H and W
240 | output = model(torch.flip(inp, (-2, -1)))
241 | output_f = output_f + torch.flip(output, (-2, -1))
242 |
243 | output_f = output_f / 4
244 |
245 | return output_f
246 |
247 |
248 | def flipRotx8_forward(model, inp):
249 | output_f = flipx4_forward(model, inp)
250 |
251 | # rot 90
252 | output = flipx4_forward(model, inp.permute(0, 1, 3, 2))
253 | output_f = output_f + output.permute(0, 1, 3, 2)
254 |
255 | output_f = output_f / 2
256 |
257 | return output_f
258 |
259 |
260 | class ProgressBar(object):
261 | '''A progress bar which can print the progress
262 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py
263 | '''
264 |
265 | def __init__(self, task_num=0, bar_width=50, start=True):
266 | self.task_num = task_num
267 | max_bar_width = self._get_max_bar_width()
268 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width)
269 | self.completed = 0
270 | if start:
271 | self.start()
272 |
273 | def _get_max_bar_width(self):
274 | terminal_width, _ = get_terminal_size()
275 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50)
276 | if max_bar_width < 10:
277 |             print('terminal width is too small ({}), please consider widening the terminal for better '
278 |                   'progress bar visualization'.format(terminal_width))
279 | max_bar_width = 10
280 | return max_bar_width
281 |
282 | def start(self):
283 | if self.task_num > 0:
284 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format(
285 | ' ' * self.bar_width, self.task_num, 'Start...'))
286 | else:
287 | sys.stdout.write('completed: 0, elapsed: 0s')
288 | sys.stdout.flush()
289 | self.start_time = time.time()
290 |
291 | def update(self, msg='In progress...'):
292 | self.completed += 1
293 | elapsed = time.time() - self.start_time
294 | fps = self.completed / elapsed
295 | if self.task_num > 0:
296 | percentage = self.completed / float(self.task_num)
297 | eta = int(elapsed * (1 - percentage) / percentage + 0.5)
298 | mark_width = int(self.bar_width * percentage)
299 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width)
300 | sys.stdout.write('\033[2F') # cursor up 2 lines
301 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display)
302 | sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format(
303 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg))
304 | else:
305 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format(
306 | self.completed, int(elapsed + 0.5), fps))
307 | sys.stdout.flush()
308 |
309 |
310 | def scandir(dir_path, suffix=None, recursive=False, full_path=False):
311 | """Scan a directory to find the interested files.
312 |
313 | Args:
314 | dir_path (str): Path of the directory.
315 | suffix (str | tuple(str), optional): File suffix that we are
316 | interested in. Default: None.
317 | recursive (bool, optional): If set to True, recursively scan the
318 | directory. Default: False.
319 | full_path (bool, optional): If set to True, include the dir_path.
320 | Default: False.
321 |
322 | Returns:
323 | A generator for all the interested files with relative pathes.
324 | """
325 |
326 | if (suffix is not None) and not isinstance(suffix, (str, tuple)):
327 | raise TypeError('"suffix" must be a string or tuple of strings')
328 |
329 | root = dir_path
330 |
331 | def _scandir(dir_path, suffix, recursive):
332 | for entry in os.scandir(dir_path):
333 | if not entry.name.startswith('.') and entry.is_file():
334 | if full_path:
335 | return_path = entry.path
336 | else:
337 | return_path = os.path.relpath(entry.path, root)
338 |
339 | if suffix is None:
340 | yield return_path
341 | elif return_path.endswith(suffix):
342 | yield return_path
343 | else:
344 | if recursive:
345 | yield from _scandir(
346 | entry.path, suffix=suffix, recursive=recursive)
347 | else:
348 | continue
349 |
350 | return _scandir(dir_path, suffix=suffix, recursive=recursive)
351 |
--------------------------------------------------------------------------------
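A minimal usage sketch for the metric helpers in utils/common.py above, run on synthetic arrays; the numbers it prints are illustrative, not reported results.

    import numpy as np
    from utils.common import calculate_psnr, calculate_ssim

    gt = np.random.randint(0, 256, (64, 64, 3)).astype(np.float64)   # fake ground truth, range [0, 255]
    noisy = np.clip(gt + np.random.randn(64, 64, 3) * 10, 0, 255)    # fake restored image
    print('PSNR: %.2f  SSIM: %.4f' % (calculate_psnr(gt, noisy), calculate_ssim(gt, noisy)))
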
/utils/dataset_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | import copy
4 | from PIL import Image
5 | import numpy as np
6 | import cv2
7 | from torch.utils.data import Dataset, DataLoader
8 | from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor
9 | import torch
10 |
11 | from utils.image_utils import random_augmentation, crop_img
12 | from utils.degradation_utils import Degradation
13 |
14 | def add_jpg_compression(img, quality=30):
15 | """Add JPG compression artifacts.
16 |
17 | Args:
18 |         img (Numpy array): Input image, shape (h, w, c), range [0, 255], uint8.
19 |         quality (int): JPG compression quality. 0 for lowest quality, 100 for
20 |             best quality. Default: 30.
21 |
22 |     Returns:
23 |         (Numpy array): Returned image after JPG compression, shape (h, w, c),
24 |             range [0, 255], uint8.
25 | """
26 | encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
27 | _, encimg = cv2.imencode('.jpg', img, encode_param)
28 | img = np.clip(cv2.imdecode(encimg, 1), 0, 255).astype(np.uint8)
29 | return img
30 |
31 |
32 | class PromptTrainDataset(Dataset):
33 | def __init__(self, args):
34 | super(PromptTrainDataset, self).__init__()
35 | self.args = args
36 | self.D = Degradation(args)
37 | self.de_temp = 0
38 | self.de_type = self.args.de_type
39 | print(self.de_type)
40 | self._init_ids()
41 | self._merge_ids()
42 |
43 | self.crop_transform = Compose([
44 | ToPILImage(),
45 | RandomCrop(args.patch_size),
46 | ]) # only used in denoise
47 |
48 | self.toTensor = ToTensor()
49 |
50 | def _init_ids(self):
51 | if 'sr' in self.de_type:
52 | self._init_sr_idx()
53 | if 'lr4_noise30' in self.de_type:
54 | self._init_lr_dn_idx()
55 | if 'lr4_jpeg30' in self.de_type:
56 | self._init_lr_jpeg_idx()
57 | if 'denoise_30' in self.de_type or 'denoise_50' in self.de_type:
58 | self._init_dn_ids()
59 | if 'derainL' in self.de_type:
60 | self._init_rs_ids(mode='L')
61 | if 'derainH' in self.de_type:
62 | self._init_rs_ids(mode='H')
63 | if 'low_light' in self.de_type:
64 | self._init_low_light_ids()
65 |
66 | def _init_lr_dn_idx(self):
67 | clean_ids = []
68 | div2k = sorted(os.listdir(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR'))[:800]
69 | filckr2k = os.listdir(self.args.dataset_dir + '/Flickr2K/Flickr2K_HR')
70 | clean_ids += [os.path.join(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR', file) for file in div2k]
71 | clean_ids += [os.path.join(self.args.dataset_dir + '/Flickr2K/Flickr2K_HR', file) for file in filckr2k] # total -- 3450
72 |
73 |
74 | self.lr_dn_ids = [{"clean_id": x, "de_type": self.de_type} for x in clean_ids]
75 | random.shuffle(self.lr_dn_ids)
76 | self.num_clean = len(self.lr_dn_ids)
77 | print("Total LR4&DN30 Ids : {}".format(self.num_clean))
78 |
79 | def _init_lr_jpeg_idx(self):
80 | clean_ids = []
81 | div2k = sorted(os.listdir(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR'))[:800]
82 | filckr2k = os.listdir(self.args.dataset_dir + '/Flickr2K/Flickr2K_HR')
83 | clean_ids += [os.path.join(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR', file) for file in div2k]
84 | clean_ids += [os.path.join(self.args.dataset_dir + '/Flickr2K/Flickr2K_HR', file) for file in filckr2k] # total -- 3450
85 |
86 |
87 | self.lr_jpeg_ids = [{"clean_id": x, "de_type": self.de_type} for x in clean_ids]
88 | random.shuffle(self.lr_jpeg_ids)
89 |
90 | self.num_clean = len(self.lr_jpeg_ids)
91 | print("Total LR4&JPEG30 Ids : {}".format(self.num_clean))
92 |
93 |
94 |
95 | def _init_sr_idx(self):
96 | clean_ids = []
97 | div2k = sorted(os.listdir(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR'))[:800]
98 | filckr2k = os.listdir(self.args.dataset_dir + '/Flickr2K/Flickr2K_HR')
99 | clean_ids += [os.path.join(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR', file) for file in div2k]
100 | clean_ids += [os.path.join(self.args.dataset_dir + '/Flickr2K/Flickr2K_HR', file) for file in filckr2k] # total -- 3450
101 |
102 | if 'sr_2' in self.de_type:
103 | self.sr2_ids = [{"clean_id": x, "de_type": self.de_type} for x in clean_ids]
104 | random.shuffle(self.sr2_ids)
105 | if 'sr_3' in self.de_type:
106 | self.sr3_ids = [{"clean_id": x, "de_type": self.de_type} for x in clean_ids]
107 | random.shuffle(self.sr3_ids)
108 | if 'sr_4' in self.de_type:
109 | self.sr4_ids = [{"clean_id": x, "de_type": self.de_type} for x in clean_ids]
110 | random.shuffle(self.sr4_ids)
111 |
112 | self.num_clean = len(clean_ids)
113 | print("Total SR Ids : {}".format(self.num_clean))
114 |
115 | def _init_low_light_ids(self):
116 | temp_ids = []
117 | temp_ids += os.listdir(os.path.join(self.args.dataset_dir, 'LOLv1', 'Train','input'))
118 | new_ids = []
119 | for name in temp_ids:
120 | if 'DS' not in name:
121 | new_ids.append(name)
122 | temp_ids = new_ids
123 | self.low_light_ids = [{"clean_id": os.path.join(self.args.dataset_dir, 'LOLv1', 'Train','input', x), "de_type": self.de_type} for x in temp_ids] * 16
124 | random.shuffle(self.low_light_ids)
125 | self.low_light_counter = 0
126 | self.num_low_light = len(self.low_light_ids)
127 | print("Total Low-light Ids : {}".format(self.num_low_light))
128 |
129 |
130 | def _init_dn_ids(self):
131 | ref_file = self.args.data_file_dir + "noisy/denoise.txt"
132 | temp_ids = []
133 | temp_ids+= [id_.strip() for id_ in open(ref_file)]
134 | clean_ids = []
135 | name_list = os.listdir(self.args.dataset_dir + '/WED')+os.listdir(self.args.dataset_dir + '/BSD400')
136 | clean_ids += [id_ for id_ in name_list if id_.strip() in temp_ids]
137 | tmp=[]
138 | for elem in clean_ids:
139 | if 'bmp' in elem:
140 | tmp.append(self.args.dataset_dir + '/WED/'+elem)
141 | else:
142 | tmp.append(self.args.dataset_dir + '/BSD400/'+elem)
143 | clean_ids = tmp
144 | # add DIV2K and Filkr2K
145 | div2k = sorted(os.listdir(self.args.dataset_dir+'/DIV2K/DIV2K_train_HR'))[:800]
146 | filckr2k = os.listdir(self.args.dataset_dir+'/Flickr2K/Flickr2K_HR')
147 | clean_ids += [os.path.join(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR',file) for file in div2k]
148 | clean_ids += [os.path.join(self.args.dataset_dir+'/Flickr2K/Flickr2K_HR',file) for file in filckr2k] # total -- 8194
149 |
150 | if 'denoise_30' in self.de_type:
151 | self.s30_ids = [{"clean_id": x,"de_type":'denoise_30'} for x in clean_ids]
152 | random.shuffle(self.s30_ids)
153 | self.s30_counter = 0
154 | if 'denoise_50' in self.de_type:
155 | self.s50_ids = [{"clean_id": x,"de_type":'denoise_50'} for x in clean_ids]
156 | random.shuffle(self.s50_ids)
157 | self.s50_counter = 0
158 |
159 | self.num_clean = len(clean_ids)
160 | print("Total Denoise Ids : {}".format(self.num_clean))
161 |
162 | def _init_rs_ids(self,mode):
163 | if mode == 'H':
164 | temp_ids = []
165 | rain_path = self.args.dataset_dir + '/RainTrainH'
166 | rain13k = os.listdir(rain_path)
167 | temp_ids += [os.path.join(rain_path, file.replace('norain', 'rain')) for file in rain13k if 'norain-' in file] * 4
168 |
169 | self.rsH_ids = [{"clean_id": x, "de_type": 'derainH'} for x in temp_ids]
170 | random.shuffle(self.rsH_ids)
171 | self.rlH_counter = 0
172 | self.num_rlH = len(self.rsH_ids)
173 | print("Total Heavy Rainy Ids : {}".format(self.num_rlH))
174 |
175 | else:
176 | temp_ids = []
177 | rain_path = self.args.dataset_dir + '/RainTrainL'
178 | rain13k = os.listdir(rain_path)
179 | temp_ids += [os.path.join(rain_path, file.replace('norain', 'rain')) for file in rain13k if 'norain-' in file] * 24
180 |
181 | self.rsL_ids = [{"clean_id": x, "de_type": 'derainL'} for x in temp_ids]
182 | random.shuffle(self.rsL_ids)
183 | self.rlL_counter = 0
184 | self.num_rlL = len(self.rsL_ids)
185 | print("Total Light Rainy Ids : {}".format(self.num_rlL))
186 |
187 | def _crop_patch(self, img_1, img_2, s_hr=1):
188 | H = img_1.shape[0]
189 | W = img_1.shape[1]
190 | ind_H = random.randint(0, H - self.args.patch_size)
191 | ind_W = random.randint(0, W - self.args.patch_size)
192 |
193 | patch_1 = img_1[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size]
194 | patch_2 = img_2[ind_H * s_hr:(ind_H + self.args.patch_size) * s_hr,
195 | ind_W * s_hr:(ind_W + self.args.patch_size) * s_hr]
196 |
197 | return patch_1, patch_2
198 |
199 | def _get_gt_name(self, rainy_name):
200 | gt_name = rainy_name.split("rain-")[0] + 'norain-' + rainy_name.split('rain-')[-1]
201 | return gt_name
202 |
203 | def _get_lr_name(self, hr_name):
204 | scale = self.de_type.split('_')[-1]
205 | base_name = os.path.basename(hr_name).split('.')[0]
206 | lr_path = os.path.join(hr_name.split('HR')[0] + 'LR_bicubic', 'X' + scale, base_name + 'x' + scale + '.png')
207 | return lr_path
208 |
209 | def _get_hybrid_name(self, hr_name):
210 | scale = '4'
211 | base_name = os.path.basename(hr_name).split('.')[0]
212 | lr_path = os.path.join(hr_name.split('HR')[0] + 'LR_bicubic', 'X' + scale, base_name + 'x' + scale + '.png')
213 | return lr_path
214 |
215 |
216 | def _get_normal_light_name(self, low_light_name):
217 | normal_light_name = low_light_name.replace('input', 'target')
218 | return normal_light_name
219 |
220 |
221 |
222 | def _merge_ids(self):
223 | self.sample_ids = []
224 | if 'lr4_noise30' in self.de_type:
225 | self.sample_ids += self.lr_dn_ids
226 | if 'lr4_jpeg30' in self.de_type:
227 | self.sample_ids += self.lr_jpeg_ids
228 | if 'denoise_30' in self.de_type:
229 | self.sample_ids += self.s30_ids
230 | if 'denoise_50' in self.de_type:
231 | self.sample_ids += self.s50_ids
232 | if "derainL" in self.de_type:
233 | self.sample_ids += self.rsL_ids
234 | if "derainH" in self.de_type:
235 | self.sample_ids += self.rsH_ids
236 | if "low_light" in self.de_type:
237 | self.sample_ids += self.low_light_ids
238 | if 'sr_2' in self.de_type:
239 | self.sample_ids += self.sr2_ids
240 | if 'sr_3' in self.de_type:
241 | self.sample_ids += self.sr3_ids
242 | if 'sr_4' in self.de_type:
243 | self.sample_ids += self.sr4_ids
244 |
245 | random.shuffle(self.sample_ids)
246 | print(len(self.sample_ids))
247 |
248 |
249 | def __getitem__(self, idx):
250 | sample = self.sample_ids[idx]
251 | de_id = sample["de_type"]
252 | if 'denoise' in de_id: # denoise
253 | clean_id = sample["clean_id"]
254 | clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16)
255 | clean_patch = self.crop_transform(clean_img)
256 | clean_patch = np.array(clean_patch)
257 | clean_name = clean_id.split("/")[-1].split('.')[0]
258 | clean_patch = random_augmentation(clean_patch)[0]
259 | degrad_patch = self.D.single_degrade(clean_patch, de_id)
260 |
261 | if 'derain' in de_id:
262 | # Rain Streak Removal
263 | degrad_img = crop_img(np.array(Image.open(sample["clean_id"]).convert('RGB')), base=16)
264 | clean_name = self._get_gt_name(sample["clean_id"])
265 | clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)
266 | degrad_patch, clean_patch = random_augmentation(*self._crop_patch(degrad_img, clean_img))
267 |
268 | if 'sr' in de_id:
269 | scale = int(self.de_type.split('_')[-1])
270 | hr_img = np.array(Image.open(sample["clean_id"]).convert('RGB'))
271 | clean_name = sample['clean_id'].split('/')[-1].split('.')[0]
272 | lr_name = self._get_hybrid_name(sample["clean_id"])
273 | lr_img = np.array(Image.open(lr_name).convert('RGB'))
274 | degrad_patch, clean_patch = random_augmentation(*self._crop_patch(lr_img, hr_img, s_hr=scale))
275 |
276 | if 'lr4_noise30' in de_id:
277 | scale = int(self.de_type[2])
278 | hr_img = np.array(Image.open(sample["clean_id"]).convert('RGB'))
279 | clean_name = sample['clean_id'].split('/')[-1].split('.')[0]
280 | lr_name = self._get_hybrid_name(sample["clean_id"])
281 | lr_img = np.array(Image.open(lr_name).convert('RGB'))
282 | degrad_patch, clean_patch = random_augmentation(*self._crop_patch(lr_img, hr_img, s_hr=scale))
283 | degrad_patch = self.D._add_gaussian_noise(degrad_patch,sigma=30)[0]
284 |
285 | if 'lr4_jpeg30' in de_id:
286 | scale = int(self.de_type[2])
287 | hr_img = np.array(Image.open(sample["clean_id"]).convert('RGB'))
288 | clean_name = sample['clean_id'].split('/')[-1].split('.')[0]
289 | lr_name = self._get_lr_name(sample["clean_id"])
290 | lr_img = np.array(Image.open(lr_name).convert('RGB'))
291 | degrad_patch, clean_patch = random_augmentation(*self._crop_patch(lr_img, hr_img, s_hr=scale))
292 | degrad_patch = add_jpg_compression(degrad_patch)
293 |
294 | if 'low_light' in de_id:
295 | degrad_img = crop_img(np.array(Image.open(sample["clean_id"]).convert('RGB')), base=16)
296 | clean_name = self._get_normal_light_name(sample["clean_id"])
297 | clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16)
298 | degrad_patch, clean_patch = random_augmentation(*self._crop_patch(degrad_img, clean_img))
299 |
300 | clean_patch = self.toTensor(clean_patch)
301 | degrad_patch = self.toTensor(degrad_patch)
302 |
303 | return [clean_name, de_id], degrad_patch, clean_patch
304 |
305 |
306 | def __len__(self):
307 | return len(self.sample_ids)
308 |
309 |
310 | class DenoiseTestDataset(Dataset):
311 | def __init__(self, args):
312 | super(DenoiseTestDataset, self).__init__()
313 | self.args = copy.deepcopy(args)
314 | self.clean_ids = []
315 | self.sigma = 15
316 |
317 | self._init_clean_ids()
318 |
319 | self.toTensor = ToTensor()
320 |
321 | def _init_clean_ids(self):
322 | name_list = os.listdir(self.args.denoise_path)
323 | self.clean_ids += [os.path.join(self.args.denoise_path, id_) for id_ in name_list]
324 | self.clean_ids = sorted(self.clean_ids)
325 | self.num_clean = len(self.clean_ids)
326 |
327 | def _add_gaussian_noise(self, clean_patch):
328 | noise = np.random.randn(*clean_patch.shape)
329 | noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255)
330 | return noisy_patch, clean_patch
331 |
332 | def set_sigma(self, sigma):
333 | self.sigma = sigma
334 |
335 | def __getitem__(self, clean_id):
336 | # clean_img = crop_img(np.array(Image.open(self.clean_ids[clean_id]).convert('RGB')), base=16)
337 | clean_img = np.array(Image.open(self.clean_ids[clean_id]).convert('RGB')).astype(np.float32)
338 | clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0]
339 | noisy_img = self._add_gaussian_noise(clean_img)[0]
340 | # clean_img, noisy_img = self.toTensor(clean_img), self.toTensor(noisy_img)
341 | clean_img, noisy_img = torch.from_numpy(np.transpose(clean_img / 255., (2, 0, 1))).float(), torch.from_numpy(
342 | np.transpose(noisy_img / 255., (2, 0, 1))).float()
343 |
344 | return [clean_name], noisy_img, clean_img
345 |
346 | def tile_degrad(input_, tile=128, tile_overlap=0):
347 | sigma_dict = {0: 0, 1: 15, 2: 25, 3: 50}
348 | b, c, h, w = input_.shape
349 | tile = min(tile, h, w)
350 | assert tile % 8 == 0, "tile size should be multiple of 8"
351 |
352 | stride = tile - tile_overlap
353 | h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
354 | w_idx_list = list(range(0, w - tile, stride)) + [w - tile]
355 | E = torch.zeros(b, c, h, w).type_as(input_)
356 | W = torch.zeros_like(E)
357 | s = 0
358 | for h_idx in h_idx_list:
359 | for w_idx in w_idx_list:
360 | in_patch = input_[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
361 | out_patch = in_patch
362 | # out_patch = model(in_patch)
363 | out_patch_mask = torch.ones_like(in_patch)
364 |
365 | E[..., h_idx:(h_idx + tile), w_idx:(w_idx + tile)].add_(out_patch)
366 | W[..., h_idx:(h_idx + tile), w_idx:(w_idx + tile)].add_(out_patch_mask)
367 | restored = E.div_(W)
368 |
369 | restored = torch.clamp(restored, 0, 1)
370 | return restored
371 |
372 | def __len__(self):
373 | return self.num_clean
374 |
375 |
376 | class DerainLowlightDataset(Dataset):
377 | def __init__(self, args):
378 | super(DerainLowlightDataset, self).__init__()
379 | self.ids = []
380 | self.args = copy.deepcopy(args)
381 | self.task = self.args.de_type
382 | self.toTensor = ToTensor()
383 | self.set_dataset(self.task)
384 |
385 | def _add_gaussian_noise(self, clean_patch):
386 | noise = np.random.randn(*clean_patch.shape)
387 | noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255).astype(np.uint8)
388 | return noisy_patch, clean_patch
389 |
390 | def _init_input_ids(self):
391 | if 'derain' in self.task: # derain
392 | self.ids = []
393 | name_list = os.listdir(os.path.join(self.args.derain_path, 'rainy'))
394 | # print(name_list)
395 | self.ids += [os.path.join(self.args.derain_path, 'rainy', id_) for id_ in name_list]
396 | elif self.task == 'low_light':
397 | self.ids = []
398 | name_list = os.listdir(os.path.join(self.args.low_light_path, 'input'))
399 | new_ids = []
400 | for name in name_list:
401 | if 'DS' not in name:
402 | new_ids.append(name)
403 | name_list = new_ids
404 | self.ids += [os.path.join(self.args.low_light_path, 'input', id_) for id_ in name_list]
405 |
406 | self.length = len(self.ids)
407 |
408 | def _get_gt_path(self, degraded_name):
409 | if 'derain' in self.task:
410 | filename = degraded_name.split('-')[-1]
411 | gt_name = os.path.join(self.args.derain_path, 'norain-' + filename)
412 | elif self.task == 'low_light':
413 | gt_name = degraded_name.replace('input', 'target')
414 | return gt_name
415 |
416 | def set_dataset(self, task):
417 | self._init_input_ids()
418 |
419 | def __getitem__(self, idx):
420 | degraded_path = self.ids[idx]
421 | clean_path = self._get_gt_path(degraded_path)
422 |
423 | degraded_img = np.array(Image.open(degraded_path).convert('RGB'))
424 |
425 | clean_img = np.array(Image.open(clean_path).convert('RGB'))
426 |
427 | clean_img, degraded_img = self.toTensor(clean_img), self.toTensor(degraded_img)
428 | degraded_name = degraded_path.split('/')[-1][:-4]
429 |
430 | return [degraded_name], degraded_img, clean_img
431 |
432 | def __len__(self):
433 | return self.length
434 |
435 |
436 | class SRHybridTestDataset(Dataset):
437 | def __init__(self, args):
438 | super(SRHybridTestDataset, self).__init__()
439 | self.args = copy.deepcopy(args)
440 | self.clean_ids = []
441 | self.toTensor = ToTensor()
442 |
443 | def set_scale(self, s):
444 | self.scale = s
445 | self._init_sr_ids()
446 |
447 | def _init_sr_ids(self):
448 | hr_path = self.args.sr_path
449 | name_list = os.listdir(hr_path)
450 |         self.clean_ids = [os.path.join(hr_path, id_) for id_ in name_list]  # assign (not +=) so repeated set_scale() calls do not duplicate ids
451 | self.clean_ids = sorted(self.clean_ids)
452 | self.num_clean = len(self.clean_ids)
453 |
454 | def __getitem__(self, clean_id):
455 | hr_img = np.array(Image.open(self.clean_ids[clean_id]).convert('RGB'))
456 |
457 | file_name, ext = os.path.splitext(os.path.basename(self.clean_ids[clean_id]))
458 | if 'Manga109' not in self.clean_ids[clean_id]:
459 | lr_path = os.path.join(os.path.dirname(self.clean_ids[clean_id]).replace('HR', 'LR_bicubic'),
460 | 'X{}/{}x{}{}'.format(self.scale, file_name, self.scale, ext))
461 | else:
462 | lr_path = os.path.join(os.path.dirname(self.clean_ids[clean_id]).replace('HR', 'LR_bicubic'),
463 | 'X{}/{}_LRBI_x{}{}'.format(self.scale, file_name, self.scale, ext))
464 | lr_img = np.array(Image.open(lr_path).convert('RGB'))
465 |
466 |
467 | if self.args.de_type == 'lr4_noise30':
468 | sigma=30
469 | noise = np.random.randn(*lr_img.shape)
470 | lr_img = np.clip(lr_img + noise * sigma, 0, 255).astype(np.uint8)
471 | if self.args.de_type == 'lr4_jpeg30':
472 | lr_img = add_jpg_compression(lr_img)
473 |
474 |
475 | ih, iw = lr_img.shape[:2]
476 | hr_img = hr_img[0:ih * self.scale, 0:iw * self.scale]
477 | hr_img, lr_img = self.toTensor(hr_img), self.toTensor(lr_img)
478 |
479 | return [lr_path.split('/')[-1]], lr_img, hr_img
480 |
481 | def __len__(self):
482 | return self.num_clean
483 |
484 |
485 |
486 |
487 |
--------------------------------------------------------------------------------
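A short sketch of the add_jpg_compression helper defined at the top of utils/dataset_utils.py, on a synthetic uint8 image; it assumes the dependencies from environment.yaml (OpenCV, PyTorch, torchvision) are installed so the module imports cleanly.

    import numpy as np
    from utils.dataset_utils import add_jpg_compression

    img = np.random.randint(0, 256, (64, 64, 3), dtype=np.uint8)   # uint8 image in [0, 255]
    jpeg_img = add_jpg_compression(img, quality=30)                # re-encode at JPEG quality 30
    print(jpeg_img.shape, jpeg_img.dtype)                          # (64, 64, 3) uint8
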
/utils/degradation_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor, Grayscale
3 |
4 | from PIL import Image
5 | import random
6 | import numpy as np
7 |
8 | from utils.image_utils import crop_img
9 |
10 |
11 | class Degradation(object):
12 | def __init__(self, args):
13 | super(Degradation, self).__init__()
14 | self.args = args
15 | self.toTensor = ToTensor()
16 | self.crop_transform = Compose([
17 | ToPILImage(),
18 | RandomCrop(args.patch_size),
19 | ])
20 |
21 | def _add_gaussian_noise(self, clean_patch, sigma):
22 | # noise = torch.randn(*(clean_patch.shape))
23 | # clean_patch = self.toTensor(clean_patch)
24 | noise = np.random.randn(*clean_patch.shape)
25 | noisy_patch = np.clip(clean_patch + noise * sigma, 0, 255).astype(np.uint8)
26 | return noisy_patch, clean_patch
27 |
28 | def _degrade_by_type(self, clean_patch, degrade_type):
29 | if degrade_type == 'denoise_15':
30 | # denoise sigma=15
31 | degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=15)
32 | elif degrade_type == 'denoise_25':
33 | # denoise sigma=25
34 | degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=25)
35 | elif degrade_type == 'denoise_30':
36 | # denoise sigma=30
37 | degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=30)
38 | elif degrade_type == 'denoise_50':
39 | # denoise sigma=50
40 |             degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=50)
41 |         else:
42 |             raise ValueError('unsupported degrade_type: {}'.format(degrade_type))
43 |         return degraded_patch, clean_patch
44 |     def degrade(self, clean_patch_1, clean_patch_2, degrade_type=None):
45 |         if degrade_type is None:
46 |             # random.randint(0, 3) returned an int that never matches the string keys
47 |             # used by _degrade_by_type; sample one of the supported noise levels instead
48 |             degrade_type = random.choice(['denoise_15', 'denoise_25', 'denoise_30', 'denoise_50'])
49 |
50 | degrad_patch_1, _ = self._degrade_by_type(clean_patch_1, degrade_type)
51 | degrad_patch_2, _ = self._degrade_by_type(clean_patch_2, degrade_type)
52 | return degrad_patch_1, degrad_patch_2
53 |
54 |     def single_degrade(self, clean_patch, degrade_type=None):
55 |         if degrade_type is None:
56 |             # same fix as in degrade(): use a string key that _degrade_by_type understands
57 |             degrade_type = random.choice(['denoise_15', 'denoise_25', 'denoise_30', 'denoise_50'])
58 |
59 |
60 | degrad_patch_1, _ = self._degrade_by_type(clean_patch, degrade_type)
61 | return degrad_patch_1
62 |
--------------------------------------------------------------------------------
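A minimal usage sketch of the Degradation class above (illustrative). The SimpleNamespace stand-in for the training args is an assumption; only patch_size is read by the constructor.

from types import SimpleNamespace

import numpy as np

from utils.degradation_utils import Degradation

args = SimpleNamespace(patch_size=128)        # placeholder args object
degrader = Degradation(args)

# a stand-in H x W x C uint8 clean patch
clean_patch = np.random.randint(0, 256, (128, 128, 3), dtype=np.uint8)
noisy_patch = degrader.single_degrade(clean_patch, degrade_type='denoise_25')
print(noisy_patch.shape, noisy_patch.dtype)   # (128, 128, 3) uint8
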
/utils/image_io.py:
--------------------------------------------------------------------------------
1 | import glob
2 |
3 | import torch
4 | import torchvision
5 | import matplotlib
6 | import matplotlib.pyplot as plt
7 | import numpy as np
8 | from PIL import Image
9 |
10 | # import skvideo.io
11 |
12 | matplotlib.use('agg')
13 |
14 |
15 | def prepare_hazy_image(file_name):
16 | img_pil = crop_image(get_image(file_name, -1)[0], d=32)
17 | return pil_to_np(img_pil)
18 |
19 |
20 | def prepare_gt_img(file_name, SOTS=True):
21 | if SOTS:
22 | img_pil = crop_image(crop_a_image(get_image(file_name, -1)[0], d=10), d=32)
23 | else:
24 | img_pil = crop_image(get_image(file_name, -1)[0], d=32)
25 |
26 | return pil_to_np(img_pil)
27 |
28 |
29 | def crop_a_image(img, d=10):
30 | bbox = [
31 | int((d)),
32 | int((d)),
33 | int((img.size[0] - d)),
34 | int((img.size[1] - d)),
35 | ]
36 | img_cropped = img.crop(bbox)
37 | return img_cropped
38 |
39 |
40 | def crop_image(img, d=32):
41 | """
42 | Make dimensions divisible by d
43 |
44 | :param pil img:
45 | :param d:
46 | :return:
47 | """
48 |
49 | new_size = (img.size[0] - img.size[0] % d,
50 | img.size[1] - img.size[1] % d)
51 |
52 | bbox = [
53 | int((img.size[0] - new_size[0]) / 2),
54 | int((img.size[1] - new_size[1]) / 2),
55 | int((img.size[0] + new_size[0]) / 2),
56 | int((img.size[1] + new_size[1]) / 2),
57 | ]
58 |
59 | img_cropped = img.crop(bbox)
60 | return img_cropped
61 |
62 |
63 | def crop_np_image(img_np, d=32):
64 | return torch_to_np(crop_torch_image(np_to_torch(img_np), d))
65 |
66 |
67 | def crop_torch_image(img, d=32):
68 | """
69 | Make dimensions divisible by d
70 | image is [1, 3, W, H] or [3, W, H]
71 | :param pil img:
72 | :param d:
73 | :return:
74 | """
75 | new_size = (img.shape[-2] - img.shape[-2] % d,
76 | img.shape[-1] - img.shape[-1] % d)
77 | pad = ((img.shape[-2] - new_size[-2]) // 2, (img.shape[-1] - new_size[-1]) // 2)
78 |
79 | if len(img.shape) == 4:
80 | return img[:, :, pad[-2]: pad[-2] + new_size[-2], pad[-1]: pad[-1] + new_size[-1]]
81 | assert len(img.shape) == 3
82 | return img[:, pad[-2]: pad[-2] + new_size[-2], pad[-1]: pad[-1] + new_size[-1]]
83 |
84 |
85 | def get_params(opt_over, net, net_input, downsampler=None):
86 | """
87 | Returns parameters that we want to optimize over.
88 | :param opt_over: comma separated list, e.g. "net,input" or "net"
89 | :param net: network
90 | :param net_input: torch.Tensor that stores input `z`
91 | :param downsampler:
92 | :return:
93 | """
94 |
95 | opt_over_list = opt_over.split(',')
96 | params = []
97 |
98 | for opt in opt_over_list:
99 |
100 | if opt == 'net':
101 | params += [x for x in net.parameters()]
102 | elif opt == 'down':
103 | assert downsampler is not None
104 | params = [x for x in downsampler.parameters()]
105 | elif opt == 'input':
106 | net_input.requires_grad = True
107 | params += [net_input]
108 |         else:
109 |             assert False, "unknown optimization target '{}' in opt_over".format(opt)
110 |
111 | return params
112 |
113 |
114 | def get_image_grid(images_np, nrow=8):
115 | """
116 | Creates a grid from a list of images by concatenating them.
117 | :param images_np:
118 | :param nrow:
119 | :return:
120 | """
121 | images_torch = [torch.from_numpy(x).type(torch.FloatTensor) for x in images_np]
122 | torch_grid = torchvision.utils.make_grid(images_torch, nrow)
123 |
124 | return torch_grid.numpy()
125 |
126 |
127 | def plot_image_grid(name, images_np, interpolation='lanczos', output_path="output/"):
128 | """
129 | Draws images in a grid
130 |
131 |     Args:
132 |         images_np: list of exactly two images, each an np.array of size 3xHxW or 1xHxW
133 |         interpolation: interpolation used in plt.imshow
134 |         output_path: directory the resulting figure is written to
135 |     """
136 | assert len(images_np) == 2
137 | n_channels = max(x.shape[0] for x in images_np)
138 | assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels"
139 |
140 | images_np = [x if (x.shape[0] == n_channels) else np.concatenate([x, x, x], axis=0) for x in images_np]
141 |
142 | grid = get_image_grid(images_np, 2)
143 |
144 | if images_np[0].shape[0] == 1:
145 | plt.imshow(grid[0], cmap='gray', interpolation=interpolation)
146 | else:
147 | plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation)
148 |
149 | plt.savefig(output_path + "{}.png".format(name))
150 |
151 |
152 | def save_image_np(name, image_np, output_path="output/"):
153 | p = np_to_pil(image_np)
154 | p.save(output_path + "{}.png".format(name))
155 |
156 |
157 | def save_image_tensor(image_tensor, output_path="output/"):
158 | image_np = torch_to_np(image_tensor)
159 | # print(image_np.shape)
160 | p = np_to_pil(image_np)
161 | p.save(output_path)
162 |
163 |
164 | def video_to_images(file_name, name):
165 | video = prepare_video(file_name)
166 | for i, f in enumerate(video):
167 | save_image(name + "_{0:03d}".format(i), f)
168 |
169 |
170 | def images_to_video(images_dir, name, gray=True):
171 | num = len(glob.glob(images_dir + "/*.jpg"))
172 | c = []
173 | for i in range(num):
174 | if gray:
175 | img = prepare_gray_image(images_dir + "/" + name + "_{}.jpg".format(i))
176 | else:
177 | img = prepare_image(images_dir + "/" + name + "_{}.jpg".format(i))
178 | print(img.shape)
179 | c.append(img)
180 | save_video(name, np.array(c))
181 |
182 |
183 | def save_heatmap(name, image_np):
184 | cmap = plt.get_cmap('jet')
185 |
186 | rgba_img = cmap(image_np)
187 | rgb_img = np.delete(rgba_img, 3, 2)
188 | save_image(name, rgb_img.transpose(2, 0, 1))
189 |
190 |
191 | def save_graph(name, graph_list, output_path="output/"):
192 | plt.clf()
193 | plt.plot(graph_list)
194 | plt.savefig(output_path + name + ".png")
195 |
196 |
197 | def create_augmentations(np_image):
198 | """
199 | convention: original, left, upside-down, right, rot1, rot2, rot3
200 | :param np_image:
201 | :return:
202 | """
203 | aug = [np_image.copy(), np.rot90(np_image, 1, (1, 2)).copy(),
204 | np.rot90(np_image, 2, (1, 2)).copy(), np.rot90(np_image, 3, (1, 2)).copy()]
205 | flipped = np_image[:, ::-1, :].copy()
206 | aug += [flipped.copy(), np.rot90(flipped, 1, (1, 2)).copy(), np.rot90(flipped, 2, (1, 2)).copy(),
207 | np.rot90(flipped, 3, (1, 2)).copy()]
208 | return aug
209 |
210 |
211 | def create_video_augmentations(np_video):
212 | """
213 | convention: original, left, upside-down, right, rot1, rot2, rot3
214 | :param np_video:
215 | :return:
216 | """
217 | aug = [np_video.copy(), np.rot90(np_video, 1, (2, 3)).copy(),
218 | np.rot90(np_video, 2, (2, 3)).copy(), np.rot90(np_video, 3, (2, 3)).copy()]
219 | flipped = np_video[:, :, ::-1, :].copy()
220 | aug += [flipped.copy(), np.rot90(flipped, 1, (2, 3)).copy(), np.rot90(flipped, 2, (2, 3)).copy(),
221 | np.rot90(flipped, 3, (2, 3)).copy()]
222 | return aug
223 |
224 |
225 | def save_graphs(name, graph_dict, output_path="output/"):
226 | """
227 |
228 | :param name:
229 | :param dict graph_dict: a dict from the name of the list to the list itself.
230 | :return:
231 | """
232 | plt.clf()
233 | fig, ax = plt.subplots()
234 | for k, v in graph_dict.items():
235 | ax.plot(v, label=k)
236 | # ax.semilogy(v, label=k)
237 | ax.set_xlabel('iterations')
238 | # ax.set_ylabel(name)
239 | ax.set_ylabel('MSE-loss')
240 | # ax.set_ylabel('PSNR')
241 | plt.legend()
242 | plt.savefig(output_path + name + ".png")
243 |
244 |
245 | def load(path):
246 | """Load PIL image."""
247 | img = Image.open(path)
248 | return img
249 |
250 |
251 | def get_image(path, imsize=-1):
252 |     """Load an image and resize to a specific size.
253 |
254 | Args:
255 | path: path to image
256 | imsize: tuple or scalar with dimensions; -1 for `no resize`
257 | """
258 | img = load(path)
259 | if isinstance(imsize, int):
260 | imsize = (imsize, imsize)
261 |
262 | if imsize[0] != -1 and img.size != imsize:
263 | if imsize[0] > img.size[0]:
264 | img = img.resize(imsize, Image.BICUBIC)
265 | else:
266 |             img = img.resize(imsize, Image.LANCZOS)  # ANTIALIAS was removed in Pillow 10; LANCZOS is the same filter
267 |
268 | img_np = pil_to_np(img)
269 | # 3*460*620
270 | # print(np.shape(img_np))
271 |
272 | return img, img_np
273 |
274 |
275 | def prepare_gt(file_name):
276 | """
277 |     Loads the ground-truth image, shaves a 10-pixel border, and makes its dimensions divisible by 32.
278 | :param file_name:
279 | :return: the numpy representation of the image
280 | """
281 | img = get_image(file_name, -1)
282 | # print(img[0].size)
283 |
284 | img_pil = img[0].crop([10, 10, img[0].size[0] - 10, img[0].size[1] - 10])
285 |
286 | img_pil = crop_image(img_pil, d=32)
287 |
288 | # img_pil = get_image(file_name, -1)[0]
289 | # print(img_pil.size)
290 | return pil_to_np(img_pil)
291 |
292 |
293 | def prepare_image(file_name):
294 | """
295 |     Loads an image and crops it so its dimensions are divisible by 16.
296 | :param file_name:
297 | :return: the numpy representation of the image
298 | """
299 | img = get_image(file_name, -1)
300 | # print(img[0].size)
301 | # img_pil = img[0]
302 | img_pil = crop_image(img[0], d=16)
303 | # img_pil = get_image(file_name, -1)[0]
304 | # print(img_pil.size)
305 | return pil_to_np(img_pil)
306 |
307 |
308 | # def prepare_video(file_name, folder="output/"):
309 | # data = skvideo.io.vread(folder + file_name)
310 | # return crop_torch_image(data.transpose(0, 3, 1, 2).astype(np.float32) / 255.)[:35]
311 | #
312 | #
313 | # def save_video(name, video_np, output_path="output/"):
314 | # outputdata = video_np * 255
315 | # outputdata = outputdata.astype(np.uint8)
316 | # skvideo.io.vwrite(output_path + "{}.mp4".format(name), outputdata.transpose(0, 2, 3, 1))
317 |
318 |
319 | def prepare_gray_image(file_name):
320 | img = prepare_image(file_name)
321 | return np.array([np.mean(img, axis=0)])
322 |
323 |
324 | def pil_to_np(img_PIL, with_transpose=True):
325 | """
326 | Converts image in PIL format to np.array.
327 |
328 | From W x H x C [0...255] to C x W x H [0..1]
329 | """
330 | ar = np.array(img_PIL)
331 | if len(ar.shape) == 3 and ar.shape[-1] == 4:
332 | ar = ar[:, :, :3]
333 | # this is alpha channel
334 | if with_transpose:
335 | if len(ar.shape) == 3:
336 | ar = ar.transpose(2, 0, 1)
337 | else:
338 | ar = ar[None, ...]
339 |
340 | return ar.astype(np.float32) / 255.
341 |
342 |
343 | def median(img_np_list):
344 | """
345 | assumes C x W x H [0..1]
346 | :param img_np_list:
347 | :return:
348 | """
349 | assert len(img_np_list) > 0
350 | l = len(img_np_list)
351 | shape = img_np_list[0].shape
352 | result = np.zeros(shape)
353 | for c in range(shape[0]):
354 | for w in range(shape[1]):
355 | for h in range(shape[2]):
356 | result[c, w, h] = sorted(i[c, w, h] for i in img_np_list)[l // 2]
357 | return result
358 |
359 |
360 | def average(img_np_list):
361 | """
362 | assumes C x W x H [0..1]
363 | :param img_np_list:
364 | :return:
365 | """
366 | assert len(img_np_list) > 0
367 | l = len(img_np_list)
368 | shape = img_np_list[0].shape
369 | result = np.zeros(shape)
370 | for i in img_np_list:
371 | result += i
372 | return result / l
373 |
374 |
375 | def np_to_pil(img_np):
376 | """
377 | Converts image in np.array format to PIL image.
378 |
379 | From C x W x H [0..1] to W x H x C [0...255]
380 | :param img_np:
381 | :return:
382 | """
383 | ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
384 |
385 | if img_np.shape[0] == 1:
386 | ar = ar[0]
387 | else:
388 | assert img_np.shape[0] == 3, img_np.shape
389 | ar = ar.transpose(1, 2, 0)
390 |
391 | return Image.fromarray(ar)
392 |
393 |
394 | def np_to_torch(img_np):
395 | """
396 | Converts image in numpy.array to torch.Tensor.
397 |
398 | From C x W x H [0..1] to C x W x H [0..1]
399 |
400 | :param img_np:
401 | :return:
402 | """
403 | return torch.from_numpy(img_np)[None, :]
404 |
405 |
406 | def torch_to_np(img_var):
407 | """
408 | Converts an image in torch.Tensor format to np.array.
409 |
410 | From 1 x C x W x H [0..1] to C x W x H [0..1]
411 | :param img_var:
412 | :return:
413 | """
414 | return img_var.detach().cpu().numpy()[0]
415 |
--------------------------------------------------------------------------------
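A short, self-contained sketch of the conversion helpers in utils/image_io.py (illustrative); the random array stands in for a real image file.

import numpy as np
from PIL import Image

from utils.image_io import crop_image, pil_to_np, np_to_pil, np_to_torch, torch_to_np

img_pil = Image.fromarray(np.random.randint(0, 256, (257, 321, 3), dtype=np.uint8))
img_pil = crop_image(img_pil, d=32)        # center-crop so both sides are multiples of 32
img_np = pil_to_np(img_pil)                # (3, 256, 320) float32 in [0, 1]
img_t = np_to_torch(img_np)                # adds a batch dimension -> (1, 3, 256, 320)
restored = np_to_pil(torch_to_np(img_t))   # back to a PIL image
print(img_np.shape, restored.size)
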
/utils/image_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Created on 2020/9/8
3 |
4 | @author: Boyun Li
5 | """
6 | import os
7 | import numpy as np
8 | import torch
9 | import random
10 | import torch.nn as nn
11 | from torch.nn import init
12 | from PIL import Image
13 |
14 | class EdgeComputation(nn.Module):
15 | def __init__(self, test=False):
16 | super(EdgeComputation, self).__init__()
17 | self.test = test
18 | def forward(self, x):
19 | if self.test:
20 | x_diffx = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1])
21 | x_diffy = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :])
22 |
23 | # y = torch.Tensor(x.size()).cuda()
24 | y = torch.Tensor(x.size())
25 | y.fill_(0)
26 | y[:, :, :, 1:] += x_diffx
27 | y[:, :, :, :-1] += x_diffx
28 | y[:, :, 1:, :] += x_diffy
29 | y[:, :, :-1, :] += x_diffy
30 | y = torch.sum(y, 1, keepdim=True) / 3
31 | y /= 4
32 | return y
33 | else:
34 | x_diffx = torch.abs(x[:, :, 1:] - x[:, :, :-1])
35 | x_diffy = torch.abs(x[:, 1:, :] - x[:, :-1, :])
36 |
37 | y = torch.Tensor(x.size())
38 | y.fill_(0)
39 | y[:, :, 1:] += x_diffx
40 | y[:, :, :-1] += x_diffx
41 | y[:, 1:, :] += x_diffy
42 | y[:, :-1, :] += x_diffy
43 | y = torch.sum(y, 0) / 3
44 | y /= 4
45 | return y.unsqueeze(0)
46 |
47 |
48 | # randomly crop a patch from image
49 | def crop_patch(im, pch_size):
50 | H = im.shape[0]
51 | W = im.shape[1]
52 | ind_H = random.randint(0, H - pch_size)
53 | ind_W = random.randint(0, W - pch_size)
54 | pch = im[ind_H:ind_H + pch_size, ind_W:ind_W + pch_size]
55 | return pch
56 |
57 |
58 | # crop an image to the multiple of base
59 | def crop_img(image, base=64):
60 | h = image.shape[0]
61 | w = image.shape[1]
62 | crop_h = h % base
63 | crop_w = w % base
64 | return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :]
65 |
66 |
67 | # image (H, W, C) -> patches (B, H, W, C)
68 | def slice_image2patches(image, patch_size=64, overlap=0):
69 | assert image.shape[0] % patch_size == 0 and image.shape[1] % patch_size == 0
70 | H = image.shape[0]
71 | W = image.shape[1]
72 | patches = []
73 | image_padding = np.pad(image, ((overlap, overlap), (overlap, overlap), (0, 0)), mode='edge')
74 | for h in range(H // patch_size):
75 | for w in range(W // patch_size):
76 | idx_h = [h * patch_size, (h + 1) * patch_size + overlap]
77 | idx_w = [w * patch_size, (w + 1) * patch_size + overlap]
78 | patches.append(np.expand_dims(image_padding[idx_h[0]:idx_h[1], idx_w[0]:idx_w[1], :], axis=0))
79 | return np.concatenate(patches, axis=0)
80 |
81 |
82 | # patches (B, H, W, C) -> image (H, W, C)
83 | def splice_patches2image(patches, image_size, overlap=0):
84 | assert len(image_size) > 1
85 | assert patches.shape[-3] == patches.shape[-2]
86 | H = image_size[0]
87 | W = image_size[1]
88 | patch_size = patches.shape[-2] - overlap
89 | image = np.zeros(image_size)
90 | idx = 0
91 | for h in range(H // patch_size):
92 | for w in range(W // patch_size):
93 | image[h * patch_size:(h + 1) * patch_size, w * patch_size:(w + 1) * patch_size, :] = patches[idx,
94 | overlap:patch_size + overlap,
95 | overlap:patch_size + overlap,
96 | :]
97 | idx += 1
98 | return image
99 |
100 |
101 | # def data_augmentation(image, mode):
102 | # if mode == 0:
103 | # # original
104 | # out = image.numpy()
105 | # elif mode == 1:
106 | # # flip up and down
107 | # out = np.flipud(image)
108 | # elif mode == 2:
109 | # # rotate counterwise 90 degree
110 | # out = np.rot90(image, axes=(1, 2))
111 | # elif mode == 3:
112 | # # rotate 90 degree and flip up and down
113 | # out = np.rot90(image, axes=(1, 2))
114 | # out = np.flipud(out)
115 | # elif mode == 4:
116 | # # rotate 180 degree
117 | # out = np.rot90(image, k=2, axes=(1, 2))
118 | # elif mode == 5:
119 | # # rotate 180 degree and flip
120 | # out = np.rot90(image, k=2, axes=(1, 2))
121 | # out = np.flipud(out)
122 | # elif mode == 6:
123 | # # rotate 270 degree
124 | # out = np.rot90(image, k=3, axes=(1, 2))
125 | # elif mode == 7:
126 | # # rotate 270 degree and flip
127 | # out = np.rot90(image, k=3, axes=(1, 2))
128 | # out = np.flipud(out)
129 | # else:
130 | # raise Exception('Invalid choice of image transformation')
131 | # return out
132 |
133 | def data_augmentation(image, mode):
134 | if mode == 0:
135 | # original
136 | out = image.numpy()
137 | elif mode == 1:
138 | # flip up and down
139 | out = np.flipud(image)
140 | elif mode == 2:
141 | # rotate counterwise 90 degree
142 | out = np.rot90(image)
143 | elif mode == 3:
144 | # rotate 90 degree and flip up and down
145 | out = np.rot90(image)
146 | out = np.flipud(out)
147 | elif mode == 4:
148 | # rotate 180 degree
149 | out = np.rot90(image, k=2)
150 | elif mode == 5:
151 | # rotate 180 degree and flip
152 | out = np.rot90(image, k=2)
153 | out = np.flipud(out)
154 | elif mode == 6:
155 | # rotate 270 degree
156 | out = np.rot90(image, k=3)
157 | elif mode == 7:
158 | # rotate 270 degree and flip
159 | out = np.rot90(image, k=3)
160 | out = np.flipud(out)
161 | else:
162 | raise Exception('Invalid choice of image transformation')
163 | return out
164 |
165 |
166 | def random_augmentation(*args):
167 | out = []
168 | flag_aug = random.randint(1, 7)
169 | for data in args:
170 | out.append(data_augmentation(data, flag_aug).copy())
171 | return out
172 |
173 |
174 | def weights_init_normal_(m):
175 |     classname = m.__class__.__name__
176 |     if classname.find('Conv') != -1:
177 |         init.uniform_(m.weight.data, 0.0, 0.02)  # in-place initializers replace the deprecated no-underscore aliases
178 |     elif classname.find('Linear') != -1:
179 |         init.uniform_(m.weight.data, 0.0, 0.02)
180 |     elif classname.find('BatchNorm2d') != -1:
181 |         init.uniform_(m.weight.data, 1.0, 0.02)
182 |         init.constant_(m.bias.data, 0.0)
183 | 
184 | 
185 | def weights_init_normal(m):
186 |     classname = m.__class__.__name__
187 |     if classname.find('Conv2d') != -1:
188 |         m.apply(weights_init_normal_)
189 |     elif classname.find('Linear') != -1:
190 |         init.uniform_(m.weight.data, 0.0, 0.02)
191 |     elif classname.find('BatchNorm2d') != -1:
192 |         init.uniform_(m.weight.data, 1.0, 0.02)
193 |         init.constant_(m.bias.data, 0.0)
194 | 
195 | 
196 | def weights_init_xavier(m):
197 |     classname = m.__class__.__name__
198 |     if classname.find('Conv') != -1:
199 |         init.xavier_normal_(m.weight.data, gain=1)
200 |     elif classname.find('Linear') != -1:
201 |         init.xavier_normal_(m.weight.data, gain=1)
202 |     elif classname.find('BatchNorm2d') != -1:
203 |         init.uniform_(m.weight.data, 1.0, 0.02)
204 |         init.constant_(m.bias.data, 0.0)
205 | 
206 | 
207 | def weights_init_kaiming(m):
208 |     classname = m.__class__.__name__
209 |     if classname.find('Conv') != -1:
210 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
211 |     elif classname.find('Linear') != -1:
212 |         init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
213 |     elif classname.find('BatchNorm2d') != -1:
214 |         init.uniform_(m.weight.data, 1.0, 0.02)
215 |         init.constant_(m.bias.data, 0.0)
216 | 
217 | 
218 | def weights_init_orthogonal(m):
219 |     classname = m.__class__.__name__
220 |     print(classname)
221 |     if classname.find('Conv') != -1:
222 |         init.orthogonal_(m.weight.data, gain=1)
223 |     elif classname.find('Linear') != -1:
224 |         init.orthogonal_(m.weight.data, gain=1)
225 |     elif classname.find('BatchNorm2d') != -1:
226 |         init.uniform_(m.weight.data, 1.0, 0.02)
227 |         init.constant_(m.bias.data, 0.0)
228 |
229 |
230 | def init_weights(net, init_type='normal'):
231 | print('initialization method [%s]' % init_type)
232 | if init_type == 'normal':
233 | net.apply(weights_init_normal)
234 | elif init_type == 'xavier':
235 | net.apply(weights_init_xavier)
236 | elif init_type == 'kaiming':
237 | net.apply(weights_init_kaiming)
238 | elif init_type == 'orthogonal':
239 | net.apply(weights_init_orthogonal)
240 | else:
241 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
242 |
243 |
244 | def np_to_torch(img_np):
245 | """
246 | Converts image in numpy.array to torch.Tensor.
247 |
248 | From C x W x H [0..1] to C x W x H [0..1]
249 |
250 | :param img_np:
251 | :return:
252 | """
253 | return torch.from_numpy(img_np)[None, :]
254 |
255 |
256 | def torch_to_np(img_var):
257 | """
258 | Converts an image in torch.Tensor format to np.array.
259 |
260 | From 1 x C x W x H [0..1] to C x W x H [0..1]
261 | :param img_var:
262 | :return:
263 | """
264 | return img_var.detach().cpu().numpy()
265 | # return img_var.detach().cpu().numpy()[0]
266 |
267 |
268 | def save_image(name, image_np, output_path="output/normal/"):
269 |     if not os.path.exists(output_path):
270 |         os.makedirs(output_path, exist_ok=True)  # mkdir fails when the parent directory does not exist
271 |
272 | p = np_to_pil(image_np)
273 | p.save(output_path + "{}.png".format(name))
274 |
275 |
276 | def np_to_pil(img_np):
277 | """
278 | Converts image in np.array format to PIL image.
279 |
280 | From C x W x H [0..1] to W x H x C [0...255]
281 | :param img_np:
282 | :return:
283 | """
284 | ar = np.clip(img_np * 255, 0, 255).astype(np.uint8)
285 |
286 | if img_np.shape[0] == 1:
287 | ar = ar[0]
288 | else:
289 | assert img_np.shape[0] == 3, img_np.shape
290 | ar = ar.transpose(1, 2, 0)
291 |
292 | return Image.fromarray(ar)
--------------------------------------------------------------------------------
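A small sketch of the patch helpers in utils/image_utils.py (illustrative), using a random H x W x C array in place of a real image.

import numpy as np

from utils.image_utils import crop_img, crop_patch, random_augmentation

image = np.random.randint(0, 256, (481, 321, 3), dtype=np.uint8)
image = crop_img(image, base=64)          # crop H and W down to multiples of 64 -> (448, 320, 3)
patch = crop_patch(image, pch_size=128)   # random 128 x 128 crop
aug_patch, = random_augmentation(patch)   # one of the seven flip/rotate modes, returned as a copy
print(image.shape, patch.shape, aug_patch.shape)
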
/utils/imresize.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.ndimage import correlate, center_of_mass, shift  # the filters/measurements/interpolation submodules are deprecated
3 | from math import pi
4 |
5 |
6 | def imresize(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False):
7 | # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa
8 | scale_factor, output_shape = fix_scale_and_size(im.shape, output_shape, scale_factor)
9 |
10 | # For a given numeric kernel case, just do convolution and sub-sampling (downscaling only)
11 | if type(kernel) == np.ndarray and scale_factor[0] <= 1:
12 | return numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag)
13 |
14 | # Choose interpolation method, each method has the matching kernel size
15 | method, kernel_width = {
16 | "cubic": (cubic, 4.0),
17 | "lanczos2": (lanczos2, 4.0),
18 | "lanczos3": (lanczos3, 6.0),
19 | "box": (box, 1.0),
20 | "linear": (linear, 2.0),
21 | None: (cubic, 4.0) # set default interpolation method as cubic
22 | }.get(kernel)
23 |
24 | # Antialiasing is only used when downscaling
25 | antialiasing *= (scale_factor[0] < 1)
26 |
27 | # Sort indices of dimensions according to scale of each dimension. since we are going dim by dim this is efficient
28 | sorted_dims = np.argsort(np.array(scale_factor)).tolist()
29 |
30 | # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction
31 | out_im = np.copy(im)
32 | for dim in sorted_dims:
33 | # No point doing calculations for scale-factor 1. nothing will happen anyway
34 | if scale_factor[dim] == 1.0:
35 | continue
36 |
37 | # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the
38 | # weights that multiply the values there to get its result.
39 | weights, field_of_view = contributions(im.shape[dim], output_shape[dim], scale_factor[dim],
40 | method, kernel_width, antialiasing)
41 |
42 | # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim
43 | out_im = resize_along_dim(out_im, dim, weights, field_of_view)
44 |
45 | return out_im
46 |
47 |
48 | def fix_scale_and_size(input_shape, output_shape, scale_factor):
49 |     # First fix the scale-factor (if given) into the standardized form the function expects: a list of scale
50 |     # factors with the same length as the number of input dimensions
51 | if scale_factor is not None:
52 | # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it.
53 | if np.isscalar(scale_factor):
54 | scale_factor = [scale_factor, scale_factor]
55 |
56 | # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales
57 | scale_factor = list(scale_factor)
58 | scale_factor.extend([1] * (len(input_shape) - len(scale_factor)))
59 |
60 | # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size
61 | # to all the unspecified dimensions
62 | if output_shape is not None:
63 | output_shape = list(np.uint(np.array(output_shape))) + list(input_shape[len(output_shape):])
64 |
65 |     # Dealing with the case of a non-given scale-factor, calculating according to output-shape. Note that this is
66 | # sub-optimal, because there can be different scales to the same output-shape.
67 | if scale_factor is None:
68 | scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape)
69 |
70 | # Dealing with missing output-shape. calculating according to scale-factor
71 | if output_shape is None:
72 | output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor)))
73 |
74 | return scale_factor, output_shape
75 |
76 |
77 | def contributions(in_length, out_length, scale, kernel, kernel_width, antialiasing):
78 | # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied
79 | # such that each position from the field_of_view will be multiplied with a matching filter from the
80 | # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers
81 | # around it. This is only done for one dimension of the image.
82 |
83 | # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of
84 | # 1/sf. this means filtering is more 'low-pass filter'.
85 | fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel
86 | kernel_width *= 1.0 / scale if antialiasing else 1.0
87 |
88 | # These are the coordinates of the output image
89 | out_coordinates = np.arange(1, out_length+1)
90 |
91 | # These are the matching positions of the output-coordinates on the input image coordinates.
92 | # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels:
93 | # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel.
94 | # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to
95 | # the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big
96 | # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor).
97 | # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is
98 | # at d=1, and so on (d = p - 0.5). we calculate (d_new = d_old / sf) which means:
99 | # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf)
100 | match_coordinates = 1.0 * out_coordinates / scale + 0.5 * (1 - 1.0 / scale)
101 |
102 | # This is the left boundary to start multiplying the filter from, it depends on the size of the filter
103 | left_boundary = np.floor(match_coordinates - kernel_width / 2)
104 |
105 | # Kernel width needs to be enlarged because when covering has sub-pixel borders, it must 'see' the pixel centers
106 | # of the pixels it only covered a part from. So we add one pixel at each side to consider (weights can zeroize them)
107 | expanded_kernel_width = np.ceil(kernel_width) + 2
108 |
109 | # Determine a set of field_of_view for each each output position, these are the pixels in the input image
110 | # that the pixel in the output image 'sees'. We get a matrix whos horizontal dim is the output pixels (big) and the
111 | # vertical dim is the pixels it 'sees' (kernel_size + 2)
112 | field_of_view = np.squeeze(np.uint(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1))
113 |
114 | # Assign weight to each pixel in the field of view. A matrix whos horizontal dim is the output pixels and the
115 | # vertical dim is a list of weights matching to the pixel in the field of view (that are specified in
116 | # 'field_of_view')
117 | weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1)
118 |
119 | # Normalize weights to sum up to 1. be careful from dividing by 0
120 | sum_weights = np.sum(weights, axis=1)
121 | sum_weights[sum_weights == 0] = 1.0
122 | weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1)
123 |
124 | # We use this mirror structure as a trick for reflection padding at the boundaries
125 | mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1))))
126 | field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])]
127 |
128 | # Get rid of weights and pixel positions that are of zero weight
129 | non_zero_out_pixels = np.nonzero(np.any(weights, axis=0))
130 | weights = np.squeeze(weights[:, non_zero_out_pixels])
131 | field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels])
132 |
133 | # Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size
134 | return weights, field_of_view
135 |
136 |
137 | def resize_along_dim(im, dim, weights, field_of_view):
138 | # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize
139 | tmp_im = np.swapaxes(im, dim, 0)
140 |
141 | # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for
142 | # tmp_im[field_of_view.T], (bsxfun style)
143 | weights = np.reshape(weights.T, list(weights.T.shape) + (np.ndim(im) - 1) * [1])
144 |
145 | # This is a bit of a complicated multiplication: tmp_im[field_of_view.T] is a tensor of order image_dims+1.
146 | # for each pixel in the output-image it matches the positions the influence it from the input image (along 1 dim
147 | # only, this is why it only adds 1 dim to the shape). We then multiply, for each pixel, its set of positions with
148 | # the matching set of weights. we do this by this big tensor element-wise multiplication (MATLAB bsxfun style:
149 | # matching dims are multiplied element-wise while singletons mean that the matching dim is all multiplied by the
150 | # same number
151 | tmp_out_im = np.sum(tmp_im[field_of_view.T] * weights, axis=0)
152 |
153 | # Finally we swap back the axes to the original order
154 | return np.swapaxes(tmp_out_im, dim, 0)
155 |
156 |
157 | def numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag):
158 | # See kernel_shift function to understand what this is
159 | if kernel_shift_flag:
160 | kernel = kernel_shift(kernel, scale_factor)
161 |
162 | # First run a correlation (convolution with flipped kernel)
163 | out_im = np.zeros_like(im)
164 | for channel in range(np.ndim(im)):
165 |         out_im[:, :, channel] = correlate(im[:, :, channel], kernel)
166 |
167 | # Then subsample and return
168 | return out_im[np.round(np.linspace(0, im.shape[0] - 1 / scale_factor[0], output_shape[0])).astype(int)[:, None],
169 | np.round(np.linspace(0, im.shape[1] - 1 / scale_factor[1], output_shape[1])).astype(int), :]
170 |
171 |
172 | def kernel_shift(kernel, sf):
173 | # There are two reasons for shifting the kernel:
174 | # 1. Center of mass is not in the center of the kernel which creates ambiguity. There is no possible way to know
175 | # the degradation process included shifting so we always assume center of mass is center of the kernel.
176 | # 2. We further shift kernel center so that top left result pixel corresponds to the middle of the sfXsf first
177 | # pixels. Default is for odd size to be in the middle of the first pixel and for even sized kernel to be at the
178 |     #    top left corner of the first pixel. That is why a different shift size is needed for odd and even sizes.
179 | # Given that these two conditions are fulfilled, we are happy and aligned, the way to test it is as follows:
180 | # The input image, when interpolated (regular bicubic) is exactly aligned with ground truth.
181 |
182 | # First calculate the current center of mass for the kernel
183 |     current_center_of_mass = center_of_mass(kernel)
184 |
185 | # The second ("+ 0.5 * ....") is for applying condition 2 from the comments above
186 | wanted_center_of_mass = np.array(kernel.shape) / 2 + 0.5 * (sf - (kernel.shape[0] % 2))
187 |
188 | # Define the shift vector for the kernel shifting (x,y)
189 | shift_vec = wanted_center_of_mass - current_center_of_mass
190 |
191 | # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift
192 | # (biggest shift among dims + 1 for safety)
193 |     kernel = np.pad(kernel, int(np.ceil(np.max(shift_vec))) + 1, 'constant')  # np.int was removed in NumPy 1.24
194 |
195 | # Finally shift the kernel and return
196 |     return shift(kernel, shift_vec)
197 |
198 |
199 | # These next functions are all interpolation methods. x is the distance from the left pixel center
200 |
201 |
202 | def cubic(x):
203 | absx = np.abs(x)
204 | absx2 = absx ** 2
205 | absx3 = absx ** 3
206 | return ((1.5*absx3 - 2.5*absx2 + 1) * (absx <= 1) +
207 | (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * ((1 < absx) & (absx <= 2)))
208 |
209 |
210 | def lanczos2(x):
211 | return (((np.sin(pi*x) * np.sin(pi*x/2) + np.finfo(np.float32).eps) /
212 | ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps))
213 | * (abs(x) < 2))
214 |
215 |
216 | def box(x):
217 | return ((-0.5 <= x) & (x < 0.5)) * 1.0
218 |
219 |
220 | def lanczos3(x):
221 | return (((np.sin(pi*x) * np.sin(pi*x/3) + np.finfo(np.float32).eps) /
222 | ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps))
223 | * (abs(x) < 3))
224 |
225 |
226 | def linear(x):
227 | return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1))
228 |
229 |
230 | def np_imresize(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False):
231 | return np.clip(imresize(im.transpose(1, 2, 0), scale_factor, output_shape, kernel, antialiasing,
232 | kernel_shift_flag).transpose(2, 0, 1), 0, 1)
--------------------------------------------------------------------------------
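A brief usage sketch of the MATLAB-style resizer above (illustrative). Inputs are assumed to be floats in [0, 1]; imresize works on H x W x C arrays, while np_imresize wraps it for C x H x W layout.

import numpy as np

from utils.imresize import imresize, np_imresize

hr = np.random.rand(128, 128, 3)                 # H x W x C, values in [0, 1]
lr = imresize(hr, scale_factor=0.25)             # antialiased bicubic downscale -> (32, 32, 3)

hr_chw = hr.transpose(2, 0, 1)
lr_chw = np_imresize(hr_chw, scale_factor=0.25)  # same resize for C x H x W input, clipped to [0, 1]
print(lr.shape, lr_chw.shape)
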
/utils/loss_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.nn.functional import mse_loss
4 |
5 |
6 | class GANLoss(nn.Module):
7 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0,
8 | tensor=torch.FloatTensor):
9 | super(GANLoss, self).__init__()
10 | self.real_label = target_real_label
11 | self.fake_label = target_fake_label
12 | self.real_label_var = None
13 | self.fake_label_var = None
14 | self.Tensor = tensor
15 | if use_lsgan:
16 | self.loss = nn.MSELoss()
17 | else:
18 | self.loss = nn.BCELoss()
19 |
20 | def get_target_tensor(self, input, target_is_real):
21 | target_tensor = None
22 | if target_is_real:
23 | create_label = ((self.real_label_var is None) or(self.real_label_var.numel() != input.numel()))
24 | # pdb.set_trace()
25 | if create_label:
26 | real_tensor = self.Tensor(input.size()).fill_(self.real_label)
27 | # self.real_label_var = Variable(real_tensor, requires_grad=False)
28 | # self.real_label_var = torch.Tensor(real_tensor)
29 | self.real_label_var = real_tensor
30 | target_tensor = self.real_label_var
31 | else:
32 | # pdb.set_trace()
33 | create_label = ((self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel()))
34 | if create_label:
35 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label)
36 | # self.fake_label_var = Variable(fake_tensor, requires_grad=False)
37 | # self.fake_label_var = torch.Tensor(fake_tensor)
38 | self.fake_label_var = fake_tensor
39 | target_tensor = self.fake_label_var
40 | return target_tensor
41 |
42 | def __call__(self, input, target_is_real):
43 | target_tensor = self.get_target_tensor(input, target_is_real)
44 | # pdb.set_trace()
45 | return self.loss(input, target_tensor)
46 |
47 |
--------------------------------------------------------------------------------
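A minimal sketch of how GANLoss is typically driven (illustrative); the discriminator output shape below is an arbitrary stand-in for, e.g., a PatchGAN score map.

import torch

from utils.loss_utils import GANLoss

criterion = GANLoss(use_lsgan=True)      # least-squares GAN objective (MSE on the score map)
d_out_real = torch.rand(4, 1, 30, 30)    # placeholder discriminator outputs
d_out_fake = torch.rand(4, 1, 30, 30)

loss_d = criterion(d_out_real, True) + criterion(d_out_fake, False)  # discriminator update
loss_g = criterion(d_out_fake, True)                                 # generator tries to look "real"
print(loss_d.item(), loss_g.item())
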
/utils/schedulers.py:
--------------------------------------------------------------------------------
1 | import math
2 | from collections import Counter
3 | from torch.optim.lr_scheduler import _LRScheduler
4 | import torch
5 | import warnings
6 | from typing import List
7 |
8 | from torch import nn
9 | from torch.optim import Adam, Optimizer
10 |
11 | class MultiStepRestartLR(_LRScheduler):
12 | """ MultiStep with restarts learning rate scheme.
13 |
14 | Args:
15 | optimizer (torch.nn.optimizer): Torch optimizer.
16 | milestones (list): Iterations that will decrease learning rate.
17 | gamma (float): Decrease ratio. Default: 0.1.
18 | restarts (list): Restart iterations. Default: [0].
19 | restart_weights (list): Restart weights at each restart iteration.
20 | Default: [1].
21 | last_epoch (int): Used in _LRScheduler. Default: -1.
22 | """
23 |
24 | def __init__(self,
25 | optimizer,
26 | milestones,
27 | gamma=0.1,
28 | restarts=(0, ),
29 | restart_weights=(1, ),
30 | last_epoch=-1):
31 | self.milestones = Counter(milestones)
32 | self.gamma = gamma
33 | self.restarts = restarts
34 | self.restart_weights = restart_weights
35 | assert len(self.restarts) == len(
36 | self.restart_weights), 'restarts and their weights do not match.'
37 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch)
38 |
39 | def get_lr(self):
40 | if self.last_epoch in self.restarts:
41 | weight = self.restart_weights[self.restarts.index(self.last_epoch)]
42 | return [
43 | group['initial_lr'] * weight
44 | for group in self.optimizer.param_groups
45 | ]
46 | if self.last_epoch not in self.milestones:
47 | return [group['lr'] for group in self.optimizer.param_groups]
48 | return [
49 | group['lr'] * self.gamma**self.milestones[self.last_epoch]
50 | for group in self.optimizer.param_groups
51 | ]
52 |
53 | class LinearLR(_LRScheduler):
54 |     """Linearly decays the learning rate from its initial value to zero over total_iter iterations.
55 | 
56 |     Args:
57 |         optimizer (torch.nn.optimizer): Torch optimizer.
58 |         total_iter (int): Total number of iterations over which the learning
59 |             rate is decayed linearly to zero.
60 |         last_epoch (int): Used in _LRScheduler. Default: -1.
61 |     """
62 |
63 | def __init__(self,
64 | optimizer,
65 | total_iter,
66 | last_epoch=-1):
67 | self.total_iter = total_iter
68 | super(LinearLR, self).__init__(optimizer, last_epoch)
69 |
70 | def get_lr(self):
71 | process = self.last_epoch / self.total_iter
72 | weight = (1 - process)
73 | # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups])
74 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
75 |
76 | class VibrateLR(_LRScheduler):
77 |     """Oscillating ("vibrate") learning rate schedule with a decaying envelope.
78 | 
79 |     Args:
80 |         optimizer (torch.nn.optimizer): Torch optimizer.
81 |         total_iter (int): Total number of iterations; the schedule is split into
82 |             short triangular cycles whose amplitude decays as training progresses.
83 |         last_epoch (int): Used in _LRScheduler. Default: -1.
84 |     """
85 |
86 | def __init__(self,
87 | optimizer,
88 | total_iter,
89 | last_epoch=-1):
90 | self.total_iter = total_iter
91 | super(VibrateLR, self).__init__(optimizer, last_epoch)
92 |
93 | def get_lr(self):
94 | process = self.last_epoch / self.total_iter
95 |
96 | f = 0.1
97 | if process < 3 / 8:
98 | f = 1 - process * 8 / 3
99 | elif process < 5 / 8:
100 | f = 0.2
101 |
102 | T = self.total_iter // 80
103 | Th = T // 2
104 |
105 | t = self.last_epoch % T
106 |
107 | f2 = t / Th
108 | if t >= Th:
109 | f2 = 2 - f2
110 |
111 | weight = f * f2
112 |
113 | if self.last_epoch < Th:
114 | weight = max(0.1, weight)
115 |
116 | # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2))
117 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups]
118 |
119 | def get_position_from_periods(iteration, cumulative_period):
120 | """Get the position from a period list.
121 |
122 | It will return the index of the right-closest number in the period list.
123 | For example, the cumulative_period = [100, 200, 300, 400],
124 | if iteration == 50, return 0;
125 | if iteration == 210, return 2;
126 | if iteration == 300, return 2.
127 |
128 | Args:
129 | iteration (int): Current iteration.
130 | cumulative_period (list[int]): Cumulative period list.
131 |
132 | Returns:
133 | int: The position of the right-closest number in the period list.
134 | """
135 | for i, period in enumerate(cumulative_period):
136 | if iteration <= period:
137 | return i
138 |
139 |
140 | class CosineAnnealingRestartLR(_LRScheduler):
141 | """ Cosine annealing with restarts learning rate scheme.
142 |
143 | An example of config:
144 | periods = [10, 10, 10, 10]
145 | restart_weights = [1, 0.5, 0.5, 0.5]
146 | eta_min=1e-7
147 |
148 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
149 | scheduler will restart with the weights in restart_weights.
150 |
151 | Args:
152 | optimizer (torch.nn.optimizer): Torch optimizer.
153 |         periods (list): Period for each cosine annealing cycle.
154 |         restart_weights (list): Restart weights at each restart iteration.
155 |             Default: [1].
156 |         eta_min (float): The minimum lr. Default: 0.
157 | last_epoch (int): Used in _LRScheduler. Default: -1.
158 | """
159 |
160 | def __init__(self,
161 | optimizer,
162 | periods,
163 | restart_weights=(1, ),
164 | eta_min=0,
165 | last_epoch=-1):
166 | self.periods = periods
167 | self.restart_weights = restart_weights
168 | self.eta_min = eta_min
169 | assert (len(self.periods) == len(self.restart_weights)
170 | ), 'periods and restart_weights should have the same length.'
171 | self.cumulative_period = [
172 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
173 | ]
174 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
175 |
176 | def get_lr(self):
177 | idx = get_position_from_periods(self.last_epoch,
178 | self.cumulative_period)
179 | current_weight = self.restart_weights[idx]
180 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
181 | current_period = self.periods[idx]
182 |
183 | return [
184 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
185 | (1 + math.cos(math.pi * (
186 | (self.last_epoch - nearest_restart) / current_period)))
187 | for base_lr in self.base_lrs
188 | ]
189 |
190 | class CosineAnnealingRestartCyclicLR(_LRScheduler):
191 | """ Cosine annealing with restarts learning rate scheme.
192 | An example of config:
193 | periods = [10, 10, 10, 10]
194 | restart_weights = [1, 0.5, 0.5, 0.5]
195 | eta_min=1e-7
196 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the
197 | scheduler will restart with the weights in restart_weights.
198 | Args:
199 | optimizer (torch.nn.optimizer): Torch optimizer.
200 |         periods (list): Period for each cosine annealing cycle.
201 |         restart_weights (list): Restart weights at each restart iteration.
202 |             Default: [1].
203 |         eta_mins (list): The minimum lr of each cycle. Default: [0].
204 | last_epoch (int): Used in _LRScheduler. Default: -1.
205 | """
206 |
207 | def __init__(self,
208 | optimizer,
209 | periods,
210 | restart_weights=(1, ),
211 | eta_mins=(0, ),
212 | last_epoch=-1):
213 | self.periods = periods
214 | self.restart_weights = restart_weights
215 | self.eta_mins = eta_mins
216 | assert (len(self.periods) == len(self.restart_weights)
217 | ), 'periods and restart_weights should have the same length.'
218 | self.cumulative_period = [
219 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
220 | ]
221 | super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch)
222 |
223 | def get_lr(self):
224 | idx = get_position_from_periods(self.last_epoch,
225 | self.cumulative_period)
226 | current_weight = self.restart_weights[idx]
227 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1]
228 | current_period = self.periods[idx]
229 | eta_min = self.eta_mins[idx]
230 |
231 | return [
232 | eta_min + current_weight * 0.5 * (base_lr - eta_min) *
233 | (1 + math.cos(math.pi * (
234 | (self.last_epoch - nearest_restart) / current_period)))
235 | for base_lr in self.base_lrs
236 | ]
237 |
238 |
239 | class LinearWarmupCosineAnnealingLR(_LRScheduler):
240 | """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr
241 | and base_lr followed by a cosine annealing schedule between base_lr and eta_min.
242 | .. warning::
243 | It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR`
244 | after each iteration as calling it after each epoch will keep the starting lr at
245 | warmup_start_lr for the first epoch which is 0 in most cases.
246 | .. warning::
247 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING.
248 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of
249 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing
250 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling
251 | train and validation methods.
252 | Example:
253 | >>> layer = nn.Linear(10, 1)
254 | >>> optimizer = Adam(layer.parameters(), lr=0.02)
255 | >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40)
256 | >>> #
257 | >>> # the default case
258 | >>> for epoch in range(40):
259 | ... # train(...)
260 | ... # validate(...)
261 | ... scheduler.step()
262 | >>> #
263 | >>> # passing epoch param case
264 | >>> for epoch in range(40):
265 | ... scheduler.step(epoch)
266 | ... # train(...)
267 | ... # validate(...)
268 | """
269 |
270 | def __init__(
271 | self,
272 | optimizer: Optimizer,
273 | warmup_epochs: int,
274 | max_epochs: int,
275 | warmup_start_lr: float = 0.0,
276 | eta_min: float = 0.0,
277 | last_epoch: int = -1,
278 | ) -> None:
279 | """
280 | Args:
281 | optimizer (Optimizer): Wrapped optimizer.
282 | warmup_epochs (int): Maximum number of iterations for linear warmup
283 | max_epochs (int): Maximum number of iterations
284 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0.
285 | eta_min (float): Minimum learning rate. Default: 0.
286 | last_epoch (int): The index of last epoch. Default: -1.
287 | """
288 | self.warmup_epochs = warmup_epochs
289 | self.max_epochs = max_epochs
290 | self.warmup_start_lr = warmup_start_lr
291 | self.eta_min = eta_min
292 |
293 | super().__init__(optimizer, last_epoch)
294 |
295 | def get_lr(self) -> List[float]:
296 | """Compute learning rate using chainable form of the scheduler."""
297 | if not self._get_lr_called_within_step:
298 | warnings.warn(
299 | "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.",
300 | UserWarning,
301 | )
302 |
303 | if self.last_epoch == 0:
304 | return [self.warmup_start_lr] * len(self.base_lrs)
305 | if self.last_epoch < self.warmup_epochs:
306 | return [
307 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
308 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
309 | ]
310 | if self.last_epoch == self.warmup_epochs:
311 | return self.base_lrs
312 | if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0:
313 | return [
314 | group["lr"]
315 | + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2
316 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
317 | ]
318 |
319 | return [
320 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
321 | / (
322 | 1
323 | + math.cos(
324 | math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs)
325 | )
326 | )
327 | * (group["lr"] - self.eta_min)
328 | + self.eta_min
329 | for group in self.optimizer.param_groups
330 | ]
331 |
332 | def _get_closed_form_lr(self) -> List[float]:
333 | """Called when epoch is passed as a param to the `step` function of the scheduler."""
334 | if self.last_epoch < self.warmup_epochs:
335 | return [
336 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1)
337 | for base_lr in self.base_lrs
338 | ]
339 |
340 | return [
341 | self.eta_min
342 | + 0.5
343 | * (base_lr - self.eta_min)
344 | * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs)))
345 | for base_lr in self.base_lrs
346 | ]
347 |
348 |
349 | # warmup + decay as a function
350 | def linear_warmup_decay(warmup_steps, total_steps, cosine=True, linear=False):
351 | """Linear warmup for warmup_steps, optionally with cosine annealing or linear decay to 0 at total_steps."""
352 | assert not (linear and cosine)
353 |
354 | def fn(step):
355 | if step < warmup_steps:
356 | return float(step) / float(max(1, warmup_steps))
357 |
358 | if not (cosine or linear):
359 | # no decay
360 | return 1.0
361 |
362 | progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps))
363 | if cosine:
364 | # cosine decay
365 | return 0.5 * (1.0 + math.cos(math.pi * progress))
366 |
367 | # linear decay
368 | return 1.0 - progress
369 |
370 | return fn
371 |
--------------------------------------------------------------------------------
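A short sketch of wiring the schedulers above to an optimizer (illustrative; the model and iteration counts are placeholders).

import torch
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR

from utils.schedulers import CosineAnnealingRestartCyclicLR, linear_warmup_decay

model = torch.nn.Linear(10, 1)
optimizer = Adam(model.parameters(), lr=2e-4)

# two cosine cycles of 1000 and 3000 iterations, the second restarting at half weight
scheduler = CosineAnnealingRestartCyclicLR(optimizer, periods=[1000, 3000],
                                           restart_weights=[1, 0.5], eta_mins=[1e-6, 1e-7])
for _ in range(4000):
    optimizer.step()      # the actual training step would go here
    scheduler.step()

# functional alternative: 500 warmup steps, then cosine decay to zero at step 4000
warmup_scheduler = LambdaLR(Adam(model.parameters(), lr=2e-4),
                            lr_lambda=linear_warmup_decay(500, 4000, cosine=True))
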
/utils/val_utils.py:
--------------------------------------------------------------------------------
1 | import time
2 | import numpy as np
3 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity
4 | import cv2
5 | import math
6 |
7 | class AverageMeter():
8 | """ Computes and stores the average and current value """
9 |
10 | def __init__(self):
11 | self.reset()
12 |
13 | def reset(self):
14 | """ Reset all statistics """
15 | self.val = 0
16 | self.avg = 0
17 | self.sum = 0
18 | self.count = 0
19 |
20 | def update(self, val, n=1):
21 | """ Update statistics """
22 | self.val = val
23 | self.sum += val * n
24 | self.count += n
25 | self.avg = self.sum / self.count
26 |
27 |
28 | def accuracy(output, target, topk=(1,)):
29 | """ Computes the precision@k for the specified values of k """
30 | maxk = max(topk)
31 | batch_size = target.size(0)
32 |
33 | _, pred = output.topk(maxk, 1, True, True)
34 | pred = pred.t()
35 | # one-hot case
36 | if target.ndimension() > 1:
37 | target = target.max(1)[1]
38 |
39 | correct = pred.eq(target.view(1, -1).expand_as(pred))
40 |
41 | res = []
42 | for k in topk:
43 | correct_k = correct[:k].view(-1).float().sum(0)
44 | res.append(correct_k.mul_(1.0 / batch_size))
45 |
46 | return res
47 |
48 |
49 |
50 | def compute_psnr_ssim(recoverd, clean,to_y=False,bd=0):
51 | # shape: [1,C,H,W] range:[0-1] mode:rgb
52 | assert recoverd.shape == clean.shape
53 | recoverd = recoverd*255.
54 | clean = clean*255.
55 | recoverd = np.clip(recoverd.detach().cpu().numpy(), 0, 255)
56 | clean = np.clip(clean.detach().cpu().numpy(), 0, 255)
57 |
58 | recoverd = recoverd.transpose(0, 2, 3, 1)[0]
59 | clean = clean.transpose(0, 2, 3, 1)[0]
60 | recoverd = recoverd.astype(np.float64).round()
61 | clean = clean.astype(np.float64).round()
62 | if to_y:
63 | recoverd = rgb2ycbcr(recoverd / 255.0, only_y=True) * 255.0
64 | clean = rgb2ycbcr(clean / 255.0, only_y=True) * 255.0
65 |
66 | if bd != 0:
67 | recoverd= recoverd[bd:-bd, bd:-bd]
68 | clean = clean[bd:-bd, bd:-bd]
69 | psnr = calculate_psnr(recoverd, clean)
70 | ssim = calculate_ssim(recoverd, clean)
71 |
72 | return psnr, ssim, 1
73 |
74 |
75 |
76 |
77 |
78 | class timer():
79 | def __init__(self):
80 | self.acc = 0
81 | self.tic()
82 |
83 | def tic(self):
84 | self.t0 = time.time()
85 |
86 | def toc(self):
87 | return time.time() - self.t0
88 |
89 | def hold(self):
90 | self.acc += self.toc()
91 |
92 | def release(self):
93 | ret = self.acc
94 | self.acc = 0
95 |
96 | return ret
97 |
98 | def reset(self):
99 | self.acc = 0
100 |
101 |
102 |
103 |
104 |
105 | def calculate_psnr(img1, img2):
106 | # img1 and img2 have range [0, 255]
107 | img1 = img1.astype(np.float64)
108 | img2 = img2.astype(np.float64)
109 | mse = np.mean((img1 - img2) ** 2)
110 | if mse == 0:
111 | return float('inf')
112 |
113 | return 20 * math.log10(255.0 / math.sqrt(mse))
114 |
115 |
116 | def calculate_ssim(img1, img2):
117 | C1 = (0.01 * 255) ** 2
118 | C2 = (0.03 * 255) ** 2
119 |
120 | img1 = img1.astype(np.float64)
121 | img2 = img2.astype(np.float64)
122 | kernel = cv2.getGaussianKernel(11, 1.5)
123 | window = np.outer(kernel, kernel.transpose())
124 |
125 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]
126 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
127 | mu1_sq = mu1 ** 2
128 | mu2_sq = mu2 ** 2
129 | mu1_mu2 = mu1 * mu2
130 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
131 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
132 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
133 |
134 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
135 |
136 | return ssim_map.mean()
137 |
138 |
139 | def calc_psnr(sr, hr, scale=2, rgb_range=1.0, benchmark=True):
140 | # benchmark: to Y channel
141 | diff = (sr - hr).data.div(rgb_range)
142 | if benchmark:
143 | shave = scale
144 | if diff.size(1) > 1:
145 | convert = diff.new(1, 3, 1, 1)
146 | convert[0, 0, 0, 0] = 65.738
147 | convert[0, 1, 0, 0] = 129.057
148 | convert[0, 2, 0, 0] = 25.064
149 | diff.mul_(convert).div_(256)
150 | diff = diff.sum(dim=1, keepdim=True)
151 | else:
152 | shave = scale + 6
153 |
154 | valid = diff[:, :, shave:-shave, shave:-shave]
155 | mse = valid.pow(2).mean()
156 |
157 | return -10 * math.log10(mse)
158 |
159 |
160 | def rgb2ycbcr(img, only_y=True):
161 | '''same as matlab rgb2ycbcr
162 | only_y: only return Y channel
163 | Input:
164 | uint8, [0, 255]
165 | float, [0, 1]
166 | '''
167 | in_img_type = img.dtype
168 |     img = img.astype(np.float32)  # astype returns a copy; it must be assigned back
169 | if in_img_type != np.uint8:
170 | img *= 255.
171 | # convert
172 | if only_y:
173 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0
174 | else:
175 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786],
176 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128]
177 | if in_img_type == np.uint8:
178 | rlt = rlt.round()
179 | else:
180 | rlt /= 255.
181 | return rlt.astype(in_img_type)
182 |
183 |
184 | def bgr2ycbcr(img, only_y=True):
185 | '''bgr version of rgb2ycbcr
186 | only_y: only return Y channel
187 | Input:
188 | uint8, [0, 255]
189 | float, [0, 1]
190 | '''
191 | in_img_type = img.dtype
192 |     img = img.astype(np.float32)
193 | if in_img_type != np.uint8:
194 | img *= 255.
195 | # convert
196 | if only_y:
197 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0
198 | else:
199 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786],
200 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128]
201 | if in_img_type == np.uint8:
202 | rlt = rlt.round()
203 | else:
204 | rlt /= 255.
205 | return rlt.astype(in_img_type)
206 |
207 |
208 | def ycbcr2rgb(img):
209 | '''same as matlab ycbcr2rgb
210 | Input:
211 | uint8, [0, 255]
212 | float, [0, 1]
213 | '''
214 | in_img_type = img.dtype
215 |     img = img.astype(np.float32)
216 | if in_img_type != np.uint8:
217 | img *= 255.
218 | # convert
219 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
220 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836]
221 | if in_img_type == np.uint8:
222 | rlt = rlt.round()
223 | else:
224 | rlt /= 255.
225 | return rlt.astype(in_img_type)
226 |
--------------------------------------------------------------------------------
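A small sketch of the evaluation helpers above (illustrative); tensors are assumed to be 1 x C x H x W, RGB, in [0, 1], as noted in compute_psnr_ssim.

import torch

from utils.val_utils import AverageMeter, compute_psnr_ssim

psnr_meter, ssim_meter = AverageMeter(), AverageMeter()

restored = torch.rand(1, 3, 64, 64)   # placeholder network output
clean = torch.rand(1, 3, 64, 64)      # placeholder ground truth

# to_y=True evaluates on the Y channel (MATLAB convention); bd shaves a bd-pixel border first
psnr, ssim, n = compute_psnr_ssim(restored, clean, to_y=True, bd=4)
psnr_meter.update(psnr, n)
ssim_meter.update(ssim, n)
print(psnr_meter.avg, ssim_meter.avg)
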
/val_options.py:
--------------------------------------------------------------------------------
1 | import argparse
2 |
3 | parser = argparse.ArgumentParser()
4 | parser.add_argument('--cuda', type=int, default=0)
5 | parser.add_argument('--base_path', type=str, default="/data/guohang/dataset", help='root directory of the test datasets')
6 | parser.add_argument('--denoise_path', type=str, default='/data/guohang/dataset/CBSD68/original_png',
7 |                     help='path to the clean images of the denoising test set')
8 | parser.add_argument('--derain_path', type=str, default="/data/guohang/dataset/Rain100L",
9 |                     help='path to the deraining test set')
10 | parser.add_argument('--sr_path', type=str, default="/data/guohang/dataset/SR/Set5/HR",
11 |                     help='path to the sr dataset for validation')
11 | parser.add_argument('--output_path', type=str, default="./test_output/", help='output save path')
12 |
13 | testopt = parser.parse_args()
14 |
15 |
--------------------------------------------------------------------------------