├── LICENSE ├── LICENSE.md ├── README.md ├── assets ├── DRH.png ├── Dn&DRL.png ├── SR&DN.png ├── adaptir_logo.png ├── classicSR.png ├── imgsli1.png ├── imgsli2.png ├── imgsli3.png ├── imgsli4.png ├── imgsli5.png ├── imgsli6.png ├── imgsli7.png ├── low-light.png ├── pipeline.png └── scalabiltity.png ├── data_dir ├── noisy │ └── denoise.txt └── rainy │ └── rainTrain.txt ├── environment.yaml ├── net ├── common.py ├── edt.py └── ipt.py ├── options.py ├── test.py ├── train.py ├── utils ├── __init__.py ├── common.py ├── dataset_utils.py ├── degradation_utils.py ├── image_io.py ├── image_utils.py ├── imresize.py ├── loss_utils.py ├── schedulers.py └── val_utils.py └── val_options.py /LICENSE: -------------------------------------------------------------------------------- 1 | Apache License 2 | Version 2.0, January 2004 3 | http://www.apache.org/licenses/ 4 | 5 | TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 6 | 7 | 1. Definitions. 8 | 9 | "License" shall mean the terms and conditions for use, reproduction, 10 | and distribution as defined by Sections 1 through 9 of this document. 11 | 12 | "Licensor" shall mean the copyright owner or entity authorized by 13 | the copyright owner that is granting the License. 14 | 15 | "Legal Entity" shall mean the union of the acting entity and all 16 | other entities that control, are controlled by, or are under common 17 | control with that entity. For the purposes of this definition, 18 | "control" means (i) the power, direct or indirect, to cause the 19 | direction or management of such entity, whether by contract or 20 | otherwise, or (ii) ownership of fifty percent (50%) or more of the 21 | outstanding shares, or (iii) beneficial ownership of such entity. 22 | 23 | "You" (or "Your") shall mean an individual or Legal Entity 24 | exercising permissions granted by this License. 25 | 26 | "Source" form shall mean the preferred form for making modifications, 27 | including but not limited to software source code, documentation 28 | source, and configuration files. 29 | 30 | "Object" form shall mean any form resulting from mechanical 31 | transformation or translation of a Source form, including but 32 | not limited to compiled object code, generated documentation, 33 | and conversions to other media types. 34 | 35 | "Work" shall mean the work of authorship, whether in Source or 36 | Object form, made available under the License, as indicated by a 37 | copyright notice that is included in or attached to the work 38 | (an example is provided in the Appendix below). 39 | 40 | "Derivative Works" shall mean any work, whether in Source or Object 41 | form, that is based on (or derived from) the Work and for which the 42 | editorial revisions, annotations, elaborations, or other modifications 43 | represent, as a whole, an original work of authorship. For the purposes 44 | of this License, Derivative Works shall not include works that remain 45 | separable from, or merely link (or bind by name) to the interfaces of, 46 | the Work and Derivative Works thereof. 47 | 48 | "Contribution" shall mean any work of authorship, including 49 | the original version of the Work and any modifications or additions 50 | to that Work or Derivative Works thereof, that is intentionally 51 | submitted to Licensor for inclusion in the Work by the copyright owner 52 | or by an individual or Legal Entity authorized to submit on behalf of 53 | the copyright owner. 
For the purposes of this definition, "submitted" 54 | means any form of electronic, verbal, or written communication sent 55 | to the Licensor or its representatives, including but not limited to 56 | communication on electronic mailing lists, source code control systems, 57 | and issue tracking systems that are managed by, or on behalf of, the 58 | Licensor for the purpose of discussing and improving the Work, but 59 | excluding communication that is conspicuously marked or otherwise 60 | designated in writing by the copyright owner as "Not a Contribution." 61 | 62 | "Contributor" shall mean Licensor and any individual or Legal Entity 63 | on behalf of whom a Contribution has been received by Licensor and 64 | subsequently incorporated within the Work. 65 | 66 | 2. Grant of Copyright License. Subject to the terms and conditions of 67 | this License, each Contributor hereby grants to You a perpetual, 68 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 69 | copyright license to reproduce, prepare Derivative Works of, 70 | publicly display, publicly perform, sublicense, and distribute the 71 | Work and such Derivative Works in Source or Object form. 72 | 73 | 3. Grant of Patent License. Subject to the terms and conditions of 74 | this License, each Contributor hereby grants to You a perpetual, 75 | worldwide, non-exclusive, no-charge, royalty-free, irrevocable 76 | (except as stated in this section) patent license to make, have made, 77 | use, offer to sell, sell, import, and otherwise transfer the Work, 78 | where such license applies only to those patent claims licensable 79 | by such Contributor that are necessarily infringed by their 80 | Contribution(s) alone or by combination of their Contribution(s) 81 | with the Work to which such Contribution(s) was submitted. If You 82 | institute patent litigation against any entity (including a 83 | cross-claim or counterclaim in a lawsuit) alleging that the Work 84 | or a Contribution incorporated within the Work constitutes direct 85 | or contributory patent infringement, then any patent licenses 86 | granted to You under this License for that Work shall terminate 87 | as of the date such litigation is filed. 88 | 89 | 4. Redistribution. 
You may reproduce and distribute copies of the 90 | Work or Derivative Works thereof in any medium, with or without 91 | modifications, and in Source or Object form, provided that You 92 | meet the following conditions: 93 | 94 | (a) You must give any other recipients of the Work or 95 | Derivative Works a copy of this License; and 96 | 97 | (b) You must cause any modified files to carry prominent notices 98 | stating that You changed the files; and 99 | 100 | (c) You must retain, in the Source form of any Derivative Works 101 | that You distribute, all copyright, patent, trademark, and 102 | attribution notices from the Source form of the Work, 103 | excluding those notices that do not pertain to any part of 104 | the Derivative Works; and 105 | 106 | (d) If the Work includes a "NOTICE" text file as part of its 107 | distribution, then any Derivative Works that You distribute must 108 | include a readable copy of the attribution notices contained 109 | within such NOTICE file, excluding those notices that do not 110 | pertain to any part of the Derivative Works, in at least one 111 | of the following places: within a NOTICE text file distributed 112 | as part of the Derivative Works; within the Source form or 113 | documentation, if provided along with the Derivative Works; or, 114 | within a display generated by the Derivative Works, if and 115 | wherever such third-party notices normally appear. The contents 116 | of the NOTICE file are for informational purposes only and 117 | do not modify the License. You may add Your own attribution 118 | notices within Derivative Works that You distribute, alongside 119 | or as an addendum to the NOTICE text from the Work, provided 120 | that such additional attribution notices cannot be construed 121 | as modifying the License. 122 | 123 | You may add Your own copyright statement to Your modifications and 124 | may provide additional or different license terms and conditions 125 | for use, reproduction, or distribution of Your modifications, or 126 | for any such Derivative Works as a whole, provided Your use, 127 | reproduction, and distribution of the Work otherwise complies with 128 | the conditions stated in this License. 129 | 130 | 5. Submission of Contributions. Unless You explicitly state otherwise, 131 | any Contribution intentionally submitted for inclusion in the Work 132 | by You to the Licensor shall be under the terms and conditions of 133 | this License, without any additional terms or conditions. 134 | Notwithstanding the above, nothing herein shall supersede or modify 135 | the terms of any separate license agreement you may have executed 136 | with Licensor regarding such Contributions. 137 | 138 | 6. Trademarks. This License does not grant permission to use the trade 139 | names, trademarks, service marks, or product names of the Licensor, 140 | except as required for reasonable and customary use in describing the 141 | origin of the Work and reproducing the content of the NOTICE file. 142 | 143 | 7. Disclaimer of Warranty. Unless required by applicable law or 144 | agreed to in writing, Licensor provides the Work (and each 145 | Contributor provides its Contributions) on an "AS IS" BASIS, 146 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or 147 | implied, including, without limitation, any warranties or conditions 148 | of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A 149 | PARTICULAR PURPOSE. 
You are solely responsible for determining the 150 | appropriateness of using or redistributing the Work and assume any 151 | risks associated with Your exercise of permissions under this License. 152 | 153 | 8. Limitation of Liability. In no event and under no legal theory, 154 | whether in tort (including negligence), contract, or otherwise, 155 | unless required by applicable law (such as deliberate and grossly 156 | negligent acts) or agreed to in writing, shall any Contributor be 157 | liable to You for damages, including any direct, indirect, special, 158 | incidental, or consequential damages of any character arising as a 159 | result of this License or out of the use or inability to use the 160 | Work (including but not limited to damages for loss of goodwill, 161 | work stoppage, computer failure or malfunction, or any and all 162 | other commercial damages or losses), even if such Contributor 163 | has been advised of the possibility of such damages. 164 | 165 | 9. Accepting Warranty or Additional Liability. While redistributing 166 | the Work or Derivative Works thereof, You may choose to offer, 167 | and charge a fee for, acceptance of support, warranty, indemnity, 168 | or other liability obligations and/or rights consistent with this 169 | License. However, in accepting such obligations, You may act only 170 | on Your own behalf and on Your sole responsibility, not on behalf 171 | of any other Contributor, and only if You agree to indemnify, 172 | defend, and hold each Contributor harmless for any liability 173 | incurred by, or claims asserted against, such Contributor by reason 174 | of your accepting any such warranty or additional liability. 175 | 176 | END OF TERMS AND CONDITIONS 177 | 178 | APPENDIX: How to apply the Apache License to your work. 179 | 180 | To apply the Apache License to your work, attach the following 181 | boilerplate notice, with the fields enclosed by brackets "[]" 182 | replaced with your own identifying information. (Don't include 183 | the brackets!) The text should be enclosed in the appropriate 184 | comment syntax for the file format. We also recommend that a 185 | file or class name and description of purpose be included on the 186 | same "printed page" as the copyright notice for easier 187 | identification within third-party archives. 188 | 189 | Copyright [yyyy] [name of copyright owner] 190 | 191 | Licensed under the Apache License, Version 2.0 (the "License"); 192 | you may not use this file except in compliance with the License. 193 | You may obtain a copy of the License at 194 | 195 | http://www.apache.org/licenses/LICENSE-2.0 196 | 197 | Unless required by applicable law or agreed to in writing, software 198 | distributed under the License is distributed on an "AS IS" BASIS, 199 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 200 | See the License for the specific language governing permissions and 201 | limitations under the License. 
202 | -------------------------------------------------------------------------------- /LICENSE.md: -------------------------------------------------------------------------------- 1 | ## ACADEMIC PUBLIC LICENSE 2 | 3 | ### Permissions 4 | :heavy_check_mark: Non-Commercial use 5 | :heavy_check_mark: Modification 6 | :heavy_check_mark: Distribution 7 | :heavy_check_mark: Private use 8 | 9 | ### Limitations 10 | :x: Commercial Use 11 | :x: Liability 12 | :x: Warranty 13 | 14 | ### Conditions 15 | :information_source: License and copyright notice 16 | :information_source: Same License 17 | 18 | PromptIR is free for use in noncommercial settings: at academic institutions for teaching and research use, and at non-profit research organizations. 19 | You can use PromptIR in your research, academic work, non-commercial work, projects and personal work. We only ask you to credit us appropriately. 20 | 21 | You have the right to use the software, to distribute copies, to receive source code, to change the software and distribute your modifications or the modified software. 22 | If you distribute verbatim or modified copies of this software, they must be distributed under this license. 23 | This license guarantees that you're safe when using PromptIR in your work, for teaching or research. 24 | This license guarantees that PromptIR will remain available free of charge for nonprofit use. 25 | You can modify PromptIR to your purposes, and you can also share your modifications. 26 | 27 | If you would like to use PromptIR in commercial settings, contact us so we can discuss options. Send an email to vaishnav.potlapalli@mbzuai.ac.ae 28 | 29 | 30 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | 3 |

4 | 5 | ## Parameter Efficient Adaptation for Image Restoration with Heterogeneous Mixture-of-Experts 6 | 7 | ## [[Paper](https://arxiv.org/pdf/2312.08881.pdf)] 8 | 9 | [Hang Guo](https://github.com/csguoh), [Tao Dai](https://cstaodai.com/), [Yuanchao Bai](https://scholar.google.com/citations?user=hjYIFZcAAAAJ&hl=zh-CN), Bin Chen, Xudong Ren, Zexuan Zhu, [Shu-Tao Xia](https://scholar.google.com/citations?hl=zh-CN&user=koAXTXgAAAAJ) 10 | 11 | 12 | > **Abstract:** Designing single-task image restoration models for specific degradation has seen great success in recent years. To achieve generalized image restoration, all-in-one methods have recently been proposed and shown potential for multiple restoration tasks using one single model. Despite the promising results, the existing all-in-one paradigm still suffers from high computational costs as well as limited generalization on unseen degradations. In this work, we introduce an alternative solution to improve the generalization of image restoration models. Drawing inspiration from recent advancements in Parameter Efficient Transfer Learning (PETL), we aim to tune only a small number of parameters to adapt pre-trained restoration models to various tasks. However, current PETL methods fail to generalize across varied restoration tasks due to their homogeneous representation nature. To this end, we propose AdaptIR, a Mixture-of-Experts (MoE) with orthogonal multi-branch design to capture local spatial, global spatial, and channel representation bases, followed by adaptive base combination to obtain heterogeneous representation for different degradations. Extensive experiments demonstrate that our AdaptIR achieves stable performance on single-degradation tasks, and excels in hybrid-degradation tasks, with fine-tuning only 0.6% parameters for 8 hours. 13 | 14 | 15 |

16 | 17 |

18 | 19 | ⭐If this work is helpful for you, please help star this repo. Thanks!🤗 20 | 21 | 22 | 23 | ## 📑 Contents 24 | 25 | - [Visual Results](#visual_results) 26 | - [News](#news) 27 | - [TODO](#todo) 28 | - [Results](#results) 29 | - [Citation](#cite) 30 | 31 | 32 | ## :eyes:Visual Results On Different Restoration Tasks 33 | [](https://imgsli.com/MjI1Njk3) [](https://imgsli.com/MjI1NzIx) [](https://imgsli.com/MjI1NzEx) [](https://imgsli.com/MjI1NzAw) 34 | 35 | [](https://imgsli.com/MjI1NzAz) [](https://imgsli.com/MjI1NzAx) [](https://imgsli.com/MjI1NzE2) 36 | 37 | 38 | 39 | ## 🆕 News 40 | 41 | - **2023-12-12:** arXiv paper available. 42 | - **2023-12-16:** This repo is released. 43 | - **2023-09-28:** 😊Our AdaptIR was accepted by NeurIPS2024! 44 | - **2024-10-19:** 🔈The code is available now, enjoy yourself! 45 | - **2025-01-13:** Updated README file with detailed instruciton. 46 | 47 | 48 | ## ☑️ TODO 49 | 50 | - [x] arXiv version 51 | - [x] Release code 52 | - [x] More detailed introductions of README file 53 | - [ ] Further improvements 54 | 55 | 56 | ## 🥇 Results 57 | 58 | We achieve state-of-the-art adaptation performance on various downstream image restoration tasks. Detailed results can be found in the paper. 59 | 60 |
61 | Evaluation on Second-order Degradation (LR4&Noise30) (click to expand) 62 | 63 |

64 | 65 |

66 |
67 | 68 | 69 |
70 | Evaluation on Classic SR (click to expand) 71 | 72 |

73 | 74 |

75 |
76 | 77 | 78 |
79 | Evaluation on Denoise&DerainL (click to expand) 80 | 81 |

82 | 83 |

84 |
85 | 86 | 87 |
88 | Evaluation on Heavy Rain Streak Removal (click to expand) 89 | 90 |

91 | 92 |

93 |
94 | 95 | 96 |
97 | Evaluation on Low-light Image Enhancement (click to expand) 98 | 99 |

100 | 101 |

102 | 103 |
104 | 105 | 106 |
107 | Evaluation on Model Scalability (click to expand) 108 | 109 |

110 | 111 |

112 |
113 | 114 | 115 | 116 | 117 | 118 | ## Datasets & Models Preparation 119 | 120 | ### Datasets 121 | 122 | Since this work involves various restoration tasks, you can collect the training and testing datasets you need from existing repos, such as [Basicsr](https://github.com/XPixelGroup/BasicSR/blob/master/docs/DatasetPreparation.md), [Restormer](https://github.com/swz30/Restormer/tree/main), and [PromptIR](https://github.com/va1shn9v/PromptIR/blob/main/INSTALL.md). 123 | 124 | 125 | 126 | 127 | ### Pre-trained weights 128 | 129 | - IPT pre-trained models 130 | Download the `IPT_pretrain` weights from this [link](https://drive.google.com/drive/folders/1MVSdUX0YBExauG0fFz4ANiWTrq9xZEj7) provided by the [IPT repo](https://github.com/huawei-noah/Pretrained-IPT). 131 | 132 | 133 | - EDT pre-trained models 134 | Download `SRx2x3x4_EDTB_ImageNet200K.pth` from this [link](https://mycuhk-my.sharepoint.com/:f:/g/personal/1155137927_link_cuhk_edu_hk/Eikt_wPDrIFCpVpiU0zYNu0BwOhQIHgNWuH1FYZbxZhq_w?e=bVEVeW) provided by the [EDT repo](https://github.com/fenglinglwb/EDT). 135 | 136 | 137 | 138 | ## Training 139 | 140 | Our AdaptIR can adapt the pre-trained models to various unseen downstream tasks, including hybrid degradation (`lr4_noise30`, `lr4_jpeg30`), image SR (`sr_2`, `sr_3`, `sr_4`), image denoising (`denoise_30`, `denoise_50`), image deraining (`derainL`, `derainH`), and low-light image enhancement (`low_light`). 141 | 142 | You can adjust the `de_type` parameter in the `./options.py` file to train a specific downstream model (a command-line example is sketched after the Acknowledgement section below). Note that only the very lightweight AdaptIR module is tuned, so downstream adaptation takes only about 8 hours. 143 | 144 | **A single RTX 3090 with 24GB memory is enough for training.** 145 | 146 | You can simply run the following command to start training with our default parameters: 147 | 148 | ``` 149 | python train.py 150 | ``` 151 | 152 | 153 | ## Testing 154 | 155 | After training, the downstream weights can be found in the `./train_ckpt` directory. You can load this checkpoint to evaluate performance on the downstream tasks. 156 | 157 | ``` 158 | python test.py 159 | ``` 160 | 161 | 162 | 163 | 164 | ## 🥰 Citation 165 | 166 | Please cite us if our work is useful for your research. 167 | 168 | ``` 169 | @inproceedings{guoparameter, 170 | title={Parameter Efficient Adaptation for Image Restoration with Heterogeneous Mixture-of-Experts}, 171 | author={Guo, Hang and Dai, Tao and Bai, Yuanchao and Chen, Bin and Ren, Xudong and Zhu, Zexuan and Xia, Shu-Tao}, 172 | booktitle={The Thirty-eighth Annual Conference on Neural Information Processing Systems} 173 | } 174 | ``` 175 | 176 | ## License 177 | 178 | This project is released under the [Apache 2.0 license](LICENSE). 179 | 180 | ## Acknowledgement 181 | 182 | This code is based on [AirNet](https://github.com/XLearning-SCU/2022-CVPR-AirNet), [IPT](https://github.com/huawei-noah/Pretrained-IPT) and [EDT](https://github.com/fenglinglwb/EDT). Thanks for their awesome work.
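
> **Note (supplementary usage sketch, not part of the original instructions):** the downstream task selected via `de_type` in the Training section can also be passed from the command line, assuming `train.py` consumes the argparse options defined in `./options.py`, e.g.:

```
python train.py --de_type derainL --batch_size 16 --lr 1e-4
```
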
183 | 184 | ## Contact 185 | 186 | If you have any questions, feel free to approach me at cshguo@gmail.com 187 | 188 | -------------------------------------------------------------------------------- /assets/DRH.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/DRH.png -------------------------------------------------------------------------------- /assets/Dn&DRL.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/Dn&DRL.png -------------------------------------------------------------------------------- /assets/SR&DN.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/SR&DN.png -------------------------------------------------------------------------------- /assets/adaptir_logo.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/adaptir_logo.png -------------------------------------------------------------------------------- /assets/classicSR.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/classicSR.png -------------------------------------------------------------------------------- /assets/imgsli1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli1.png -------------------------------------------------------------------------------- /assets/imgsli2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli2.png -------------------------------------------------------------------------------- /assets/imgsli3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli3.png -------------------------------------------------------------------------------- /assets/imgsli4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli4.png -------------------------------------------------------------------------------- /assets/imgsli5.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli5.png -------------------------------------------------------------------------------- /assets/imgsli6.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli6.png -------------------------------------------------------------------------------- /assets/imgsli7.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/imgsli7.png -------------------------------------------------------------------------------- /assets/low-light.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/low-light.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/pipeline.png -------------------------------------------------------------------------------- /assets/scalabiltity.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/assets/scalabiltity.png -------------------------------------------------------------------------------- /data_dir/rainy/rainTrain.txt: -------------------------------------------------------------------------------- 1 | rainy/rain-100.png 2 | rainy/rain-101.png 3 | rainy/rain-102.png 4 | rainy/rain-103.png 5 | rainy/rain-104.png 6 | rainy/rain-105.png 7 | rainy/rain-106.png 8 | rainy/rain-107.png 9 | rainy/rain-108.png 10 | rainy/rain-109.png 11 | rainy/rain-10.png 12 | rainy/rain-110.png 13 | rainy/rain-111.png 14 | rainy/rain-112.png 15 | rainy/rain-113.png 16 | rainy/rain-114.png 17 | rainy/rain-115.png 18 | rainy/rain-116.png 19 | rainy/rain-117.png 20 | rainy/rain-118.png 21 | rainy/rain-119.png 22 | rainy/rain-11.png 23 | rainy/rain-120.png 24 | rainy/rain-121.png 25 | rainy/rain-122.png 26 | rainy/rain-123.png 27 | rainy/rain-124.png 28 | rainy/rain-125.png 29 | rainy/rain-126.png 30 | rainy/rain-127.png 31 | rainy/rain-128.png 32 | rainy/rain-129.png 33 | rainy/rain-12.png 34 | rainy/rain-130.png 35 | rainy/rain-131.png 36 | rainy/rain-132.png 37 | rainy/rain-133.png 38 | rainy/rain-134.png 39 | rainy/rain-135.png 40 | rainy/rain-136.png 41 | rainy/rain-137.png 42 | rainy/rain-138.png 43 | rainy/rain-139.png 44 | rainy/rain-13.png 45 | rainy/rain-140.png 46 | rainy/rain-141.png 47 | rainy/rain-142.png 48 | rainy/rain-143.png 49 | rainy/rain-144.png 50 | rainy/rain-145.png 51 | rainy/rain-146.png 52 | rainy/rain-147.png 53 | rainy/rain-148.png 54 | rainy/rain-149.png 55 | rainy/rain-14.png 56 | rainy/rain-150.png 57 | rainy/rain-151.png 58 | rainy/rain-152.png 59 | rainy/rain-153.png 60 | rainy/rain-154.png 61 | rainy/rain-155.png 62 | rainy/rain-156.png 63 | rainy/rain-157.png 64 | rainy/rain-158.png 65 | rainy/rain-159.png 66 | rainy/rain-15.png 67 | rainy/rain-160.png 68 | rainy/rain-161.png 69 | rainy/rain-162.png 70 | rainy/rain-163.png 71 | rainy/rain-164.png 72 | rainy/rain-165.png 73 | rainy/rain-166.png 74 | rainy/rain-167.png 75 | rainy/rain-168.png 76 | rainy/rain-169.png 77 | rainy/rain-16.png 78 | rainy/rain-170.png 79 | rainy/rain-171.png 80 | rainy/rain-172.png 81 | rainy/rain-173.png 82 | rainy/rain-174.png 83 | rainy/rain-175.png 84 | rainy/rain-176.png 85 | rainy/rain-177.png 86 | rainy/rain-178.png 87 | rainy/rain-179.png 88 | rainy/rain-17.png 89 | rainy/rain-180.png 90 | rainy/rain-181.png 91 | rainy/rain-182.png 92 | rainy/rain-183.png 93 | rainy/rain-184.png 94 | rainy/rain-185.png 95 | rainy/rain-186.png 96 | rainy/rain-187.png 97 | rainy/rain-188.png 98 | rainy/rain-189.png 99 | rainy/rain-18.png 
100 | rainy/rain-190.png 101 | rainy/rain-191.png 102 | rainy/rain-192.png 103 | rainy/rain-193.png 104 | rainy/rain-194.png 105 | rainy/rain-195.png 106 | rainy/rain-196.png 107 | rainy/rain-197.png 108 | rainy/rain-198.png 109 | rainy/rain-199.png 110 | rainy/rain-19.png 111 | rainy/rain-1.png 112 | rainy/rain-200.png 113 | rainy/rain-20.png 114 | rainy/rain-21.png 115 | rainy/rain-22.png 116 | rainy/rain-23.png 117 | rainy/rain-24.png 118 | rainy/rain-25.png 119 | rainy/rain-26.png 120 | rainy/rain-27.png 121 | rainy/rain-28.png 122 | rainy/rain-29.png 123 | rainy/rain-2.png 124 | rainy/rain-30.png 125 | rainy/rain-31.png 126 | rainy/rain-32.png 127 | rainy/rain-33.png 128 | rainy/rain-34.png 129 | rainy/rain-35.png 130 | rainy/rain-36.png 131 | rainy/rain-37.png 132 | rainy/rain-38.png 133 | rainy/rain-39.png 134 | rainy/rain-3.png 135 | rainy/rain-40.png 136 | rainy/rain-41.png 137 | rainy/rain-42.png 138 | rainy/rain-43.png 139 | rainy/rain-44.png 140 | rainy/rain-45.png 141 | rainy/rain-46.png 142 | rainy/rain-47.png 143 | rainy/rain-48.png 144 | rainy/rain-49.png 145 | rainy/rain-4.png 146 | rainy/rain-50.png 147 | rainy/rain-51.png 148 | rainy/rain-52.png 149 | rainy/rain-53.png 150 | rainy/rain-54.png 151 | rainy/rain-55.png 152 | rainy/rain-56.png 153 | rainy/rain-57.png 154 | rainy/rain-58.png 155 | rainy/rain-59.png 156 | rainy/rain-5.png 157 | rainy/rain-60.png 158 | rainy/rain-61.png 159 | rainy/rain-62.png 160 | rainy/rain-63.png 161 | rainy/rain-64.png 162 | rainy/rain-65.png 163 | rainy/rain-66.png 164 | rainy/rain-67.png 165 | rainy/rain-68.png 166 | rainy/rain-69.png 167 | rainy/rain-6.png 168 | rainy/rain-70.png 169 | rainy/rain-71.png 170 | rainy/rain-72.png 171 | rainy/rain-73.png 172 | rainy/rain-74.png 173 | rainy/rain-75.png 174 | rainy/rain-76.png 175 | rainy/rain-77.png 176 | rainy/rain-78.png 177 | rainy/rain-79.png 178 | rainy/rain-7.png 179 | rainy/rain-80.png 180 | rainy/rain-81.png 181 | rainy/rain-82.png 182 | rainy/rain-83.png 183 | rainy/rain-84.png 184 | rainy/rain-85.png 185 | rainy/rain-86.png 186 | rainy/rain-87.png 187 | rainy/rain-88.png 188 | rainy/rain-89.png 189 | rainy/rain-8.png 190 | rainy/rain-90.png 191 | rainy/rain-91.png 192 | rainy/rain-92.png 193 | rainy/rain-93.png 194 | rainy/rain-94.png 195 | rainy/rain-95.png 196 | rainy/rain-96.png 197 | rainy/rain-97.png 198 | rainy/rain-98.png 199 | rainy/rain-99.png 200 | rainy/rain-9.png -------------------------------------------------------------------------------- /environment.yaml: -------------------------------------------------------------------------------- 1 | name: adaptir 2 | channels: 3 | - pytorch 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | - _libgcc_mutex=0.1=main 8 | - _openmp_mutex=5.1=1_gnu 9 | - abseil-cpp=20211102.0=hd4dd3e8_0 10 | - absl-py=1.4.0=py38h06a4308_0 11 | - blas=1.0=mkl 12 | - blinker=1.4=py38h06a4308_0 13 | - bottleneck=1.3.5=py38h7deecbd_0 14 | - brotlipy=0.7.0=py38h0a891b7_1004 15 | - bzip2=1.0.8=h7f98852_4 16 | - c-ares=1.19.0=h5eee18b_0 17 | - ca-certificates=2023.08.22=h06a4308_0 18 | - cachetools=4.2.2=pyhd3eb1b0_0 19 | - certifi=2023.7.22=py38h06a4308_0 20 | - cffi=1.15.0=py38h3931269_0 21 | - chardet=3.0.4=py38h06a4308_1003 22 | - charset-normalizer=3.1.0=pyhd8ed1ab_0 23 | - cryptography=37.0.2=py38h2b5fc30_0 24 | - cudatoolkit=11.6.0=hecad31d_10 25 | - ffmpeg=4.3=hf484d3e_0 26 | - freetype=2.10.4=h0708190_1 27 | - giflib=5.2.1=h5eee18b_3 28 | - gmp=6.2.1=h58526e2_0 29 | - gnutls=3.6.13=h85f3911_1 30 | - google-auth=2.6.0=pyhd3eb1b0_0 31 | - 
google-auth-oauthlib=0.5.2=py38h06a4308_0 32 | - grpc-cpp=1.48.2=h5bf31a4_0 33 | - grpcio=1.48.2=py38h5bf31a4_0 34 | - idna=3.4=pyhd8ed1ab_0 35 | - intel-openmp=2021.4.0=h06a4308_3561 36 | - jpeg=9e=h166bdaf_1 37 | - lame=3.100=h7f98852_1001 38 | - lcms2=2.12=hddcbb42_0 39 | - ld_impl_linux-64=2.38=h1181459_1 40 | - libffi=3.4.4=h6a678d5_0 41 | - libgcc-ng=11.2.0=h1234567_1 42 | - libgomp=11.2.0=h1234567_1 43 | - libiconv=1.17=h166bdaf_0 44 | - libpng=1.6.37=h21135ba_2 45 | - libprotobuf=3.20.3=he621ea3_0 46 | - libstdcxx-ng=11.2.0=h1234567_1 47 | - libtiff=4.2.0=hecacb30_2 48 | - libwebp=1.2.2=h55f646e_0 49 | - libwebp-base=1.2.2=h7f98852_1 50 | - lz4-c=1.9.3=h9c3ff4c_1 51 | - markdown=3.4.1=py38h06a4308_0 52 | - markupsafe=2.1.1=py38h7f8727e_0 53 | - mkl=2021.4.0=h06a4308_640 54 | - mkl-service=2.4.0=py38h95df7f1_0 55 | - mkl_fft=1.3.1=py38h8666266_1 56 | - mkl_random=1.2.2=py38h1abd341_0 57 | - ncurses=6.4=h6a678d5_0 58 | - nettle=3.6=he412f7d_0 59 | - numexpr=2.8.4=py38he184ba9_0 60 | - numpy=1.24.3=py38h14f4228_0 61 | - numpy-base=1.24.3=py38h31eccc5_0 62 | - oauthlib=3.2.2=py38h06a4308_0 63 | - olefile=0.46=pyh9f0ad1d_1 64 | - openh264=2.1.1=h780b84a_0 65 | - openjpeg=2.4.0=hb52868f_1 66 | - openssl=1.1.1w=h7f8727e_0 67 | - packaging=23.1=py38h06a4308_0 68 | - pandas=2.0.3=py38h1128e8f_0 69 | - pip=23.0.1=py38h06a4308_0 70 | - pyasn1=0.4.8=pyhd3eb1b0_0 71 | - pyasn1-modules=0.2.8=py_0 72 | - pycparser=2.21=pyhd8ed1ab_0 73 | - pyjwt=2.4.0=py38h06a4308_0 74 | - pyopenssl=22.0.0=pyhd8ed1ab_1 75 | - pysocks=1.7.1=pyha2e5f31_6 76 | - python=3.8.16=h7a1cb2a_3 77 | - python-dateutil=2.8.2=pyhd3eb1b0_0 78 | - python-tzdata=2023.3=pyhd3eb1b0_0 79 | - python_abi=3.8=2_cp38 80 | - pytorch=1.12.1=py3.8_cuda11.6_cudnn8.3.2_0 81 | - pytorch-mutex=1.0=cuda 82 | - re2=2022.04.01=h295c915_0 83 | - readline=8.2=h5eee18b_0 84 | - requests-oauthlib=1.3.0=py_0 85 | - rsa=4.7.2=pyhd3eb1b0_1 86 | - setuptools=66.0.0=py38h06a4308_0 87 | - six=1.16.0=pyh6c4a22f_0 88 | - sqlite=3.41.2=h5eee18b_0 89 | - tensorboard=2.12.1=py38h06a4308_0 90 | - tensorboard-data-server=0.7.0=py38h52d8a92_0 91 | - tensorboard-plugin-wit=1.8.1=py38h06a4308_0 92 | - tk=8.6.12=h1ccaba5_0 93 | - torchaudio=0.12.1=py38_cu116 94 | - torchvision=0.13.1=py38_cu116 95 | - tqdm=4.65.0=py38hb070fc8_0 96 | - urllib3=1.26.15=pyhd8ed1ab_0 97 | - werkzeug=2.2.3=py38h06a4308_0 98 | - wheel=0.38.4=py38h06a4308_0 99 | - xz=5.4.2=h5eee18b_0 100 | - zipp=3.11.0=py38h06a4308_0 101 | - zlib=1.2.13=h5eee18b_0 102 | - zstd=1.5.2=ha4553b6_0 103 | - pip: 104 | - accelerate==0.25.0 105 | - addict==2.4.0 106 | - aiohttp==3.8.5 107 | - aiosignal==1.3.1 108 | - antlr4-python3-runtime==4.9.3 109 | - anyio==3.7.1 110 | - appdirs==1.4.4 111 | - argon2-cffi==23.1.0 112 | - argon2-cffi-bindings==21.2.0 113 | - arrow==1.2.3 114 | - astor==0.8.1 115 | - asttokens==2.2.1 116 | - async-lru==2.0.4 117 | - async-timeout==4.0.2 118 | - attrs==23.1.0 119 | - avalanche-lib==0.4.0 120 | - babel==2.14.0 121 | - backcall==0.2.0 122 | - beautifulsoup4==4.12.2 123 | - bleach==6.1.0 124 | - blessed==1.20.0 125 | - bs4==0.0.1 126 | - click==8.1.6 127 | - cloudpickle==3.0.0 128 | - comm==0.2.0 129 | - contourpy==1.1.0 130 | - croniter==1.3.15 131 | - ctrl-benchmark==0.0.4 132 | - cycler==0.11.0 133 | - cython==3.0.6 134 | - dateutils==0.6.12 135 | - debugpy==1.8.0 136 | - decorator==5.1.1 137 | - deepdiff==6.3.1 138 | - defusedxml==0.7.1 139 | - dill==0.3.7 140 | - docker-pycreds==0.4.0 141 | - easydict==1.10 142 | - editdistance==0.6.2 143 | - einops==0.6.1 144 | - 
exceptiongroup==1.1.2 145 | - executing==1.2.0 146 | - fastapi==0.88.0 147 | - fastjsonschema==2.19.0 148 | - filelock==3.12.0 149 | - fonttools==4.41.1 150 | - fqdn==1.5.1 151 | - frozenlist==1.4.0 152 | - fsspec==2023.5.0 153 | - fvcore==0.1.5.post20221221 154 | - gdown==4.7.1 155 | - gitdb==4.0.10 156 | - gitpython==3.1.32 157 | - gputil==1.4.0 158 | - gym==0.26.2 159 | - gym-notices==0.0.8 160 | - h11==0.14.0 161 | - higher==0.2.1 162 | - httpcore==1.0.2 163 | - httpx==0.26.0 164 | - huggingface-hub==0.19.4 165 | - imageio==2.28.1 166 | - imgaug==0.4.0 167 | - importlib-metadata==7.1.0 168 | - importlib-resources==6.0.0 169 | - inquirer==3.1.3 170 | - iopath==0.1.10 171 | - ipykernel==6.27.1 172 | - ipython==8.12.2 173 | - ipywidgets==8.1.1 174 | - isoduration==20.11.0 175 | - itsdangerous==2.1.2 176 | - jedi==0.18.2 177 | - jinja2==3.1.2 178 | - joblib==1.3.2 179 | - json5==0.9.14 180 | - jsonpointer==2.4 181 | - jsonschema==4.20.0 182 | - jsonschema-specifications==2023.11.2 183 | - jupyter==1.0.0 184 | - jupyter-client==8.6.0 185 | - jupyter-console==6.6.3 186 | - jupyter-core==5.5.1 187 | - jupyter-events==0.9.0 188 | - jupyter-lsp==2.2.1 189 | - jupyter-server==2.12.1 190 | - jupyter-server-terminals==0.5.0 191 | - jupyterlab==4.0.9 192 | - jupyterlab-pygments==0.3.0 193 | - jupyterlab-server==2.25.2 194 | - jupyterlab-widgets==3.0.9 195 | - kiwisolver==1.4.4 196 | - lazy-loader==0.2 197 | - lightning==2.0.1 198 | - lightning-cloud==0.5.37 199 | - lightning-utilities==0.9.0 200 | - lmdb==1.4.1 201 | - lpips==0.1.4 202 | - lvis==0.5.3 203 | - markdown-it-py==3.0.0 204 | - matplotlib==3.7.2 205 | - matplotlib-inline==0.1.6 206 | - mdurl==0.1.2 207 | - mistune==3.0.2 208 | - mmengine==0.10.3 209 | - multidict==6.0.4 210 | - nbclient==0.9.0 211 | - nbconvert==7.13.0 212 | - nbformat==5.9.2 213 | - nest-asyncio==1.5.8 214 | - networkx==2.8.8 215 | - notebook==7.0.6 216 | - notebook-shim==0.2.3 217 | - omegaconf==2.3.0 218 | - opencv-python==4.7.0.72 219 | - opencv-python-headless==4.10.0.84 220 | - opt-einsum==3.3.0 221 | - ordered-set==4.1.0 222 | - overrides==7.4.0 223 | - paddlepaddle-gpu==2.5.2.post116 224 | - pandocfilters==1.5.0 225 | - parso==0.8.3 226 | - pathtools==0.1.2 227 | - patsy==0.5.4 228 | - peft==0.5.0 229 | - petl==1.7.14 230 | - pexpect==4.8.0 231 | - pickleshare==0.7.5 232 | - pillow==9.5.0 233 | - pkgutil-resolve-name==1.3.10 234 | - platformdirs==4.1.0 235 | - plotly==5.18.0 236 | - portalocker==2.8.2 237 | - prometheus-client==0.19.0 238 | - prompt-toolkit==3.0.38 239 | - protobuf==3.20.3 240 | - psutil==5.9.5 241 | - ptflops==0.7 242 | - ptyprocess==0.7.0 243 | - pure-eval==0.2.2 244 | - pycocotools==2.0.7 245 | - pydantic==1.10.11 246 | - pydot==1.4.2 247 | - pygments==2.15.1 248 | - pyparsing==3.0.9 249 | - python-editor==1.0.4 250 | - python-json-logger==2.0.7 251 | - python-multipart==0.0.6 252 | - pytorch-lightning==2.0.1 253 | - pytorchcv==0.0.67 254 | - pytz==2023.3 255 | - pywavelets==1.4.1 256 | - pyyaml==6.0 257 | - pyzmq==25.1.2 258 | - qtconsole==5.5.1 259 | - qtpy==2.4.1 260 | - quadprog==0.1.11 261 | - readchar==4.0.5 262 | - referencing==0.32.0 263 | - regex==2023.10.3 264 | - requests==2.32.3 265 | - rfc3339-validator==0.1.4 266 | - rfc3986-validator==0.1.1 267 | - rich==13.4.2 268 | - rpds-py==0.15.2 269 | - safetensors==0.3.1 270 | - scikit-image==0.20.0 271 | - scikit-learn==0.24.2 272 | - scipy==1.9.1 273 | - send2trash==1.8.2 274 | - sentry-sdk==1.28.1 275 | - setproctitle==1.3.2 276 | - shapely==2.0.1 277 | - smmap==5.0.0 278 | - 
sniffio==1.3.0 279 | - soupsieve==2.4.1 280 | - stack-data==0.6.2 281 | - starlette==0.22.0 282 | - starsessions==1.3.0 283 | - statsmodels==0.14.1 284 | - tabulate==0.9.0 285 | - tenacity==8.2.3 286 | - tensorboardx==2.6 287 | - termcolor==2.4.0 288 | - terminado==0.18.0 289 | - thop==0.1.1-2209072238 290 | - threadpoolctl==3.2.0 291 | - tifffile==2023.4.12 292 | - timm==0.4.12 293 | - tinycss2==1.2.1 294 | - tokenizers==0.15.0 295 | - tomli==2.0.1 296 | - torchmetrics==1.0.1 297 | - tornado==6.4 298 | - traitlets==5.9.0 299 | - transformers==4.36.1 300 | - typing-extensions==4.4.0 301 | - uri-template==1.3.0 302 | - uvicorn==0.23.1 303 | - wcwidth==0.2.6 304 | - webcolors==1.13 305 | - webencodings==0.5.1 306 | - websocket-client==1.6.1 307 | - websockets==11.0.3 308 | - widgetsnbextension==4.0.9 309 | - yacs==0.1.8 310 | - yapf==0.40.2 311 | - yarl==1.9.2 312 | 313 | -------------------------------------------------------------------------------- /net/common.py: -------------------------------------------------------------------------------- 1 | # 2021.05.07-Changed for IPT 2 | # Huawei Technologies Co., Ltd. 3 | 4 | import math 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | def default_conv(in_channels, out_channels, kernel_size, bias=True): 11 | return nn.Conv2d( 12 | in_channels, out_channels, kernel_size, 13 | padding=(kernel_size//2), bias=bias) 14 | 15 | class MeanShift(nn.Conv2d): 16 | def __init__( 17 | self, rgb_range, 18 | rgb_mean=(0.4488, 0.4371, 0.4040), rgb_std=(1.0, 1.0, 1.0), sign=-1): 19 | 20 | super(MeanShift, self).__init__(3, 3, kernel_size=1) 21 | std = torch.Tensor(rgb_std) 22 | self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1) 23 | self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std 24 | for p in self.parameters(): 25 | p.requires_grad = False 26 | 27 | class BasicBlock(nn.Sequential): 28 | def __init__( 29 | self, conv, in_channels, out_channels, kernel_size, stride=1, bias=False, 30 | bn=True, act=nn.ReLU(True)): 31 | 32 | m = [conv(in_channels, out_channels, kernel_size, bias=bias)] 33 | if bn: 34 | m.append(nn.BatchNorm2d(out_channels)) 35 | if act is not None: 36 | m.append(act) 37 | 38 | super(BasicBlock, self).__init__(*m) 39 | 40 | class ResBlock(nn.Module): 41 | def __init__( 42 | self, conv, n_feats, kernel_size, 43 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 44 | 45 | super(ResBlock, self).__init__() 46 | m = [] 47 | for i in range(2): 48 | m.append(conv(n_feats, n_feats, kernel_size, bias=bias)) 49 | if bn: 50 | m.append(nn.BatchNorm2d(n_feats)) 51 | if i == 0: 52 | m.append(act) 53 | 54 | self.body = nn.Sequential(*m) 55 | self.res_scale = res_scale 56 | 57 | def forward(self, x): 58 | res = self.body(x).mul(self.res_scale) 59 | res += x 60 | 61 | return res 62 | 63 | class Upsampler(nn.Sequential): 64 | def __init__(self, conv, scale, n_feats, bn=False, act=False, bias=True): 65 | 66 | m = [] 67 | if (scale & (scale - 1)) == 0: # Is scale = 2^n? 
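            # Added note: a power-of-two scale factor is realised as repeated x2 stages,
            # each a conv to 4*n_feats channels followed by PixelShuffle(2); scale 3 uses
            # a single conv to 9*n_feats channels plus PixelShuffle(3) in the branch
            # below, and any other scale raises NotImplementedError.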
68 | for _ in range(int(math.log(scale, 2))): 69 | m.append(conv(n_feats, 4 * n_feats, 3, bias)) 70 | m.append(nn.PixelShuffle(2)) 71 | if bn: 72 | m.append(nn.BatchNorm2d(n_feats)) 73 | if act == 'relu': 74 | m.append(nn.ReLU(True)) 75 | elif act == 'prelu': 76 | m.append(nn.PReLU(n_feats)) 77 | 78 | elif scale == 3: 79 | m.append(conv(n_feats, 9 * n_feats, 3, bias)) 80 | m.append(nn.PixelShuffle(3)) 81 | if bn: 82 | m.append(nn.BatchNorm2d(n_feats)) 83 | if act == 'relu': 84 | m.append(nn.ReLU(True)) 85 | elif act == 'prelu': 86 | m.append(nn.PReLU(n_feats)) 87 | else: 88 | raise NotImplementedError 89 | 90 | super(Upsampler, self).__init__(*m) 91 | 92 | -------------------------------------------------------------------------------- /net/ipt.py: -------------------------------------------------------------------------------- 1 | # 2021.05.07-Changed for IPT 2 | # Huawei Technologies Co., Ltd. 3 | import os 4 | 5 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved 6 | 7 | from net import common 8 | import math 9 | import torch 10 | import torch.nn.functional as F 11 | from torch import nn, Tensor 12 | import copy 13 | from torchvision import ops 14 | from matplotlib import pyplot as plt 15 | import numpy as np 16 | from functools import partial, reduce 17 | from operator import mul 18 | 19 | 20 | class LayerNorm2d(nn.Module): 21 | def __init__(self, num_channels: int, eps: float = 1e-6) -> None: 22 | super().__init__() 23 | self.weight = nn.Parameter(torch.ones(num_channels)) 24 | self.bias = nn.Parameter(torch.zeros(num_channels)) 25 | self.eps = eps 26 | 27 | def forward(self, x: torch.Tensor) -> torch.Tensor: 28 | u = x.mean(1, keepdim=True) 29 | s = (x - u).pow(2).mean(1, keepdim=True) 30 | x = (x - u) / torch.sqrt(s + self.eps) 31 | x = self.weight[:, None, None] * x + self.bias[:, None, None] 32 | return x 33 | 34 | 35 | def tensor_to_image(tensor): 36 | if tensor.max() > 1: 37 | tensor = tensor / 255. 
38 | # Ensure the input tensor is on the CPU and in the range [0, 1] 39 | tensor = tensor.cpu() 40 | # tensor = tensor.clamp(0, 1) 41 | 42 | # Convert the tensor to a NumPy array 43 | image = tensor.squeeze(0).permute(1, 2, 0).numpy() 44 | plt.imshow(image) 45 | plt.savefig('./test.jpg') 46 | 47 | 48 | class IPT(nn.Module): 49 | def __init__(self, args): 50 | super(IPT, self).__init__() 51 | self.TASK_MAP = {'lr4_noise30': 2, 'lr4_jpeg30': 2, 'sr_2': 0, 'sr_3': 1, 'sr_4': 2, 52 | 'derainH': 3,'derainL': 3, 'denoise_30': 4, 'denoise_50': 5, 'low_light': 5, } 53 | if isinstance(args.de_type,list): 54 | self.task_idx = None 55 | else: 56 | self.task_idx = self.TASK_MAP[args.de_type] if type(args.de_type) is not list else 5 57 | conv = common.default_conv 58 | self.scales = [2, 3, 4, 1, 1, 1] 59 | n_feats = 64 60 | self.patch_size = 48 61 | n_colors = 3 62 | kernel_size = 3 63 | rgb_range = 255 64 | act = nn.ReLU(True) 65 | 66 | self.sub_mean = common.MeanShift(rgb_range) 67 | self.add_mean = common.MeanShift(rgb_range, sign=1) 68 | 69 | self.head = nn.ModuleList([ 70 | nn.Sequential( 71 | conv(n_colors, n_feats, kernel_size), 72 | common.ResBlock(conv, n_feats, 5, act=act), 73 | common.ResBlock(conv, n_feats, 5, act=act) 74 | ) for _ in self.scales 75 | ]) 76 | 77 | self.body = VisionTransformer(img_dim=48, patch_dim=3, 78 | num_channels=n_feats, embedding_dim=n_feats * 3 * 3, 79 | num_heads=12, num_layers=12, 80 | hidden_dim=n_feats * 3 * 3 * 4, num_queries=len(self.scales), 81 | dropout_rate=0, mlp=False, pos_every=False, 82 | no_pos=False, no_norm=False) 83 | 84 | self.tail = nn.ModuleList([ 85 | nn.Sequential( 86 | common.Upsampler(conv, s, n_feats, act=False), 87 | conv(n_feats, 3, kernel_size) 88 | ) for s in self.scales 89 | ]) 90 | 91 | def forward(self, x, de_id=None): 92 | x = x * 255. 93 | if not self.training: 94 | return self.forward_chop(x) / 255. 95 | else: 96 | return self.forward_train(x, de_id) / 255. 
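    # ------------------------------------------------------------------
    # Added explanatory note (not part of the original IPT code): during
    # training the inputs are fixed-size crops (48x48 by default), so
    # `forward_train` is called directly on the batch. At test time,
    # `forward_chop` below tiles an arbitrarily sized image into
    # overlapping 48x48 patches with F.unfold, restores each patch with
    # `forward_train`, and stitches the outputs back with F.fold; the
    # overlapping regions are divided by a fold(unfold(ones)) divisor map
    # so that every output pixel is the average of all patches covering it.
    # ------------------------------------------------------------------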
97 | 98 | def forward_train(self, x,de_id=None): 99 | if de_id is not None: 100 | self.task_idx = self.TASK_MAP[de_id[0]] if type(de_id[0]) is not list else 5 101 | x = self.sub_mean(x) 102 | x = self.head[self.task_idx](x) 103 | 104 | res = self.body(x, self.task_idx) 105 | res += x 106 | 107 | x = self.tail[self.task_idx](res) 108 | x = self.add_mean(x) 109 | 110 | return x 111 | 112 | def set_scale(self, task_idx): 113 | self.task_idx = task_idx 114 | 115 | def forward_chop(self, x): 116 | x.cpu() 117 | batchsize = 64 118 | h, w = x.size()[-2:] 119 | padsize = int(self.patch_size) 120 | shave = int(self.patch_size / 2) 121 | 122 | scale = self.scales[self.task_idx] 123 | 124 | h_cut = (h - padsize) % (int(shave / 2)) 125 | w_cut = (w - padsize) % (int(shave / 2)) 126 | 127 | x_unfold = torch.nn.functional.unfold(x, padsize, stride=int(shave / 2)).transpose(0, 128 | 2).contiguous() # [num_patch, 48*48*3, N] 129 | 130 | x_hw_cut = x[..., (h - padsize):, (w - padsize):] 131 | y_hw_cut = self.forward_train(x_hw_cut.cuda()).cpu() 132 | 133 | x_h_cut = x[..., (h - padsize):, :] 134 | x_w_cut = x[..., :, (w - padsize):] 135 | y_h_cut = self.cut_h(x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize) 136 | y_w_cut = self.cut_w(x_w_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize) 137 | 138 | x_h_top = x[..., :padsize, :] 139 | x_w_top = x[..., :, :padsize] 140 | y_h_top = self.cut_h(x_h_top, h, w, h_cut, w_cut, padsize, shave, scale, batchsize) 141 | y_w_top = self.cut_w(x_w_top, h, w, h_cut, w_cut, padsize, shave, scale, batchsize) 142 | 143 | x_unfold = x_unfold.view(x_unfold.size(0), -1, padsize, padsize) 144 | y_unfold = [] 145 | 146 | x_range = x_unfold.size(0) // batchsize + (x_unfold.size(0) % batchsize != 0) 147 | x_unfold.cuda() 148 | for i in range(x_range): 149 | y_unfold.append(self.forward_train(x_unfold[i * batchsize:(i + 1) * batchsize, ...]).cpu()) 150 | y_unfold = torch.cat(y_unfold, dim=0) 151 | 152 | y = torch.nn.functional.fold(y_unfold.view(y_unfold.size(0), -1, 1).transpose(0, 2).contiguous(), 153 | ((h - h_cut) * scale, (w - w_cut) * scale), padsize * scale, 154 | stride=int(shave / 2 * scale)) 155 | 156 | y[..., :padsize * scale, :] = y_h_top 157 | y[..., :, :padsize * scale] = y_w_top 158 | 159 | y_unfold = y_unfold[..., int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale), 160 | int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale)].contiguous() 161 | y_inter = torch.nn.functional.fold(y_unfold.view(y_unfold.size(0), -1, 1).transpose(0, 2).contiguous(), 162 | ((h - h_cut - shave) * scale, (w - w_cut - shave) * scale), 163 | padsize * scale - shave * scale, stride=int(shave / 2 * scale)) 164 | 165 | y_ones = torch.ones(y_inter.shape, dtype=y_inter.dtype) 166 | divisor = torch.nn.functional.fold( 167 | torch.nn.functional.unfold(y_ones, padsize * scale - shave * scale, stride=int(shave / 2 * scale)), 168 | ((h - h_cut - shave) * scale, (w - w_cut - shave) * scale), padsize * scale - shave * scale, 169 | stride=int(shave / 2 * scale)) 170 | 171 | y_inter = y_inter / divisor 172 | 173 | y[..., int(shave / 2 * scale):(h - h_cut) * scale - int(shave / 2 * scale), 174 | int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)] = y_inter 175 | 176 | y = torch.cat([y[..., :y.size(2) - int((padsize - h_cut) / 2 * scale), :], 177 | y_h_cut[..., int((padsize - h_cut) / 2 * scale + 0.5):, :]], dim=2) 178 | y_w_cat = torch.cat([y_w_cut[..., :y_w_cut.size(2) - int((padsize - h_cut) / 2 * scale), :], 179 | y_hw_cut[..., int((padsize - 
h_cut) / 2 * scale + 0.5):, :]], dim=2) 180 | y = torch.cat([y[..., :, :y.size(3) - int((padsize - w_cut) / 2 * scale)], 181 | y_w_cat[..., :, int((padsize - w_cut) / 2 * scale + 0.5):]], dim=3) 182 | return y.cuda() 183 | 184 | def cut_h(self, x_h_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize): 185 | 186 | x_h_cut_unfold = torch.nn.functional.unfold(x_h_cut, padsize, stride=int(shave / 2)).transpose(0, 187 | 2).contiguous() 188 | 189 | x_h_cut_unfold = x_h_cut_unfold.view(x_h_cut_unfold.size(0), -1, padsize, padsize) 190 | x_range = x_h_cut_unfold.size(0) // batchsize + (x_h_cut_unfold.size(0) % batchsize != 0) 191 | y_h_cut_unfold = [] 192 | x_h_cut_unfold.cuda() 193 | for i in range(x_range): 194 | y_h_cut_unfold.append(self.forward_train(x_h_cut_unfold[i * batchsize:(i + 1) * batchsize, ...]).cpu()) 195 | y_h_cut_unfold = torch.cat(y_h_cut_unfold, dim=0) 196 | 197 | y_h_cut = torch.nn.functional.fold( 198 | y_h_cut_unfold.view(y_h_cut_unfold.size(0), -1, 1).transpose(0, 2).contiguous(), 199 | (padsize * scale, (w - w_cut) * scale), padsize * scale, stride=int(shave / 2 * scale)) 200 | y_h_cut_unfold = y_h_cut_unfold[..., :, 201 | int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale)].contiguous() 202 | y_h_cut_inter = torch.nn.functional.fold( 203 | y_h_cut_unfold.view(y_h_cut_unfold.size(0), -1, 1).transpose(0, 2).contiguous(), 204 | (padsize * scale, (w - w_cut - shave) * scale), (padsize * scale, padsize * scale - shave * scale), 205 | stride=int(shave / 2 * scale)) 206 | 207 | y_ones = torch.ones(y_h_cut_inter.shape, dtype=y_h_cut_inter.dtype) 208 | divisor = torch.nn.functional.fold( 209 | torch.nn.functional.unfold(y_ones, (padsize * scale, padsize * scale - shave * scale), 210 | stride=int(shave / 2 * scale)), (padsize * scale, (w - w_cut - shave) * scale), 211 | (padsize * scale, padsize * scale - shave * scale), stride=int(shave / 2 * scale)) 212 | y_h_cut_inter = y_h_cut_inter / divisor 213 | 214 | y_h_cut[..., :, int(shave / 2 * scale):(w - w_cut) * scale - int(shave / 2 * scale)] = y_h_cut_inter 215 | return y_h_cut 216 | 217 | def cut_w(self, x_w_cut, h, w, h_cut, w_cut, padsize, shave, scale, batchsize): 218 | x_w_cut_unfold = torch.nn.functional.unfold(x_w_cut, padsize, stride=int(shave / 2)).transpose(0, 219 | 2).contiguous() 220 | 221 | x_w_cut_unfold = x_w_cut_unfold.view(x_w_cut_unfold.size(0), -1, padsize, padsize) 222 | x_range = x_w_cut_unfold.size(0) // batchsize + (x_w_cut_unfold.size(0) % batchsize != 0) 223 | y_w_cut_unfold = [] 224 | x_w_cut_unfold.cuda() 225 | for i in range(x_range): 226 | y_w_cut_unfold.append(self.forward_train(x_w_cut_unfold[i * batchsize:(i + 1) * batchsize, ...]).cpu()) 227 | y_w_cut_unfold = torch.cat(y_w_cut_unfold, dim=0) 228 | 229 | y_w_cut = torch.nn.functional.fold( 230 | y_w_cut_unfold.view(y_w_cut_unfold.size(0), -1, 1).transpose(0, 2).contiguous(), 231 | ((h - h_cut) * scale, padsize * scale), padsize * scale, stride=int(shave / 2 * scale)) 232 | y_w_cut_unfold = y_w_cut_unfold[..., int(shave / 2 * scale):padsize * scale - int(shave / 2 * scale), 233 | :].contiguous() 234 | y_w_cut_inter = torch.nn.functional.fold( 235 | y_w_cut_unfold.view(y_w_cut_unfold.size(0), -1, 1).transpose(0, 2).contiguous(), 236 | ((h - h_cut - shave) * scale, padsize * scale), (padsize * scale - shave * scale, padsize * scale), 237 | stride=int(shave / 2 * scale)) 238 | 239 | y_ones = torch.ones(y_w_cut_inter.shape, dtype=y_w_cut_inter.dtype) 240 | divisor = torch.nn.functional.fold( 241 | torch.nn.functional.unfold(y_ones, 
(padsize * scale - shave * scale, padsize * scale), 242 | stride=int(shave / 2 * scale)), ((h - h_cut - shave) * scale, padsize * scale), 243 | (padsize * scale - shave * scale, padsize * scale), stride=int(shave / 2 * scale)) 244 | y_w_cut_inter = y_w_cut_inter / divisor 245 | 246 | y_w_cut[..., int(shave / 2 * scale):(h - h_cut) * scale - int(shave / 2 * scale), :] = y_w_cut_inter 247 | return y_w_cut 248 | 249 | 250 | class VisionTransformer(nn.Module): 251 | def __init__( 252 | self, 253 | img_dim, 254 | patch_dim, 255 | num_channels, 256 | embedding_dim, 257 | num_heads, 258 | num_layers, 259 | hidden_dim, 260 | num_queries, 261 | positional_encoding_type="learned", 262 | dropout_rate=0, 263 | no_norm=False, 264 | mlp=False, 265 | pos_every=False, 266 | no_pos=False 267 | ): 268 | super(VisionTransformer, self).__init__() 269 | 270 | assert embedding_dim % num_heads == 0 271 | assert img_dim % patch_dim == 0 272 | self.no_norm = no_norm 273 | self.mlp = mlp 274 | self.embedding_dim = embedding_dim 275 | self.num_heads = num_heads 276 | self.patch_dim = patch_dim 277 | self.num_channels = num_channels 278 | 279 | self.img_dim = img_dim 280 | self.pos_every = pos_every 281 | self.num_patches = int((img_dim // patch_dim) ** 2) 282 | self.seq_length = self.num_patches 283 | self.flatten_dim = patch_dim * patch_dim * num_channels 284 | 285 | self.out_dim = patch_dim * patch_dim * num_channels 286 | 287 | self.no_pos = no_pos 288 | 289 | if self.mlp == False: 290 | self.linear_encoding = nn.Linear(self.flatten_dim, embedding_dim) 291 | self.mlp_head = nn.Sequential( 292 | nn.Linear(embedding_dim, hidden_dim), 293 | nn.Dropout(dropout_rate), 294 | nn.ReLU(), 295 | nn.Linear(hidden_dim, self.out_dim), 296 | nn.Dropout(dropout_rate) 297 | ) 298 | 299 | self.query_embed = nn.Embedding(num_queries, embedding_dim * self.seq_length) 300 | 301 | encoder_layer = TransformerEncoderLayer(embedding_dim, num_heads, hidden_dim, dropout_rate, self.no_norm) 302 | self.encoder = TransformerEncoder(encoder_layer, num_layers) 303 | 304 | decoder_layer = TransformerDecoderLayer(embedding_dim, num_heads, hidden_dim, dropout_rate, self.no_norm) 305 | self.decoder = TransformerDecoder(decoder_layer, num_layers) 306 | 307 | if not self.no_pos: 308 | self.position_encoding = LearnedPositionalEncoding( 309 | self.seq_length, self.embedding_dim, self.seq_length 310 | ) 311 | 312 | self.dropout_layer1 = nn.Dropout(dropout_rate) 313 | 314 | if no_norm: 315 | for m in self.modules(): 316 | if isinstance(m, nn.Linear): 317 | nn.init.normal_(m.weight, std=1 / m.weight.size(1)) 318 | 319 | def forward(self, x, query_idx, con=False): 320 | 321 | x = torch.nn.functional.unfold(x, self.patch_dim, stride=self.patch_dim).transpose(1, 2).transpose(0, 322 | 1).contiguous() 323 | 324 | if self.mlp == False: 325 | x = self.dropout_layer1(self.linear_encoding(x)) + x 326 | 327 | query_embed = self.query_embed.weight[query_idx].view(-1, 1, self.embedding_dim).repeat(1, x.size(1), 1) 328 | else: 329 | query_embed = None 330 | 331 | if not self.no_pos: 332 | pos = self.position_encoding(x).transpose(0, 1) 333 | 334 | if self.pos_every: 335 | x = self.encoder(x, pos=pos) 336 | x = self.decoder(x, x, pos=pos, query_pos=query_embed) 337 | elif self.no_pos: 338 | x = self.encoder(x) 339 | x = self.decoder(x, x, query_pos=query_embed) 340 | else: # here 341 | x = self.encoder(x + pos) 342 | x = self.decoder(x, x, query_pos=query_embed) 343 | 344 | if self.mlp == False: 345 | x = self.mlp_head(x) + x 346 | 347 | x = x.transpose(0, 
1).contiguous().view(x.size(1), -1, self.flatten_dim) 348 | 349 | if con: 350 | con_x = x 351 | x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), int(self.img_dim), self.patch_dim, 352 | stride=self.patch_dim) 353 | return x, con_x 354 | 355 | x = torch.nn.functional.fold(x.transpose(1, 2).contiguous(), int(self.img_dim), self.patch_dim, 356 | stride=self.patch_dim) 357 | 358 | return x 359 | 360 | 361 | class LearnedPositionalEncoding(nn.Module): 362 | def __init__(self, max_position_embeddings, embedding_dim, seq_length): 363 | super(LearnedPositionalEncoding, self).__init__() 364 | self.pe = nn.Embedding(max_position_embeddings, embedding_dim) 365 | self.seq_length = seq_length 366 | 367 | self.register_buffer( 368 | "position_ids", torch.arange(self.seq_length).expand((1, -1)) 369 | ) 370 | 371 | def forward(self, x, position_ids=None): 372 | if position_ids is None: 373 | position_ids = self.position_ids[:, : self.seq_length] 374 | 375 | position_embeddings = self.pe(position_ids) 376 | return position_embeddings 377 | 378 | 379 | class TransformerEncoder(nn.Module): 380 | 381 | def __init__(self, encoder_layer, num_layers): 382 | super().__init__() 383 | self.layers = _get_clones(encoder_layer, num_layers) 384 | self.num_layers = num_layers 385 | d_model = 576 386 | 387 | def forward(self, src, pos=None): 388 | output = src 389 | for idx, layer in enumerate(self.layers): 390 | output = layer(output, pos=pos) 391 | return output 392 | 393 | 394 | class TransformerEncoderLayer(nn.Module): 395 | 396 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, no_norm=False, 397 | activation="relu"): 398 | super().__init__() 399 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False) 400 | # Implementation of Feedforward model 401 | self.linear1 = nn.Linear(d_model, dim_feedforward) 402 | self.dropout = nn.Dropout(dropout) 403 | self.linear2 = nn.Linear(dim_feedforward, d_model) 404 | 405 | self.norm1 = nn.LayerNorm(d_model) if not no_norm else nn.Identity() 406 | self.norm2 = nn.LayerNorm(d_model) if not no_norm else nn.Identity() 407 | self.dropout1 = nn.Dropout(dropout) 408 | self.dropout2 = nn.Dropout(dropout) 409 | self.activation = _get_activation_fn(activation) 410 | 411 | self.adaptir = AdaptIR(d_model) 412 | 413 | nn.init.kaiming_uniform_(self.self_attn.in_proj_weight, a=math.sqrt(5)) 414 | 415 | def with_pos_embed(self, tensor, pos): 416 | return tensor if pos is None else tensor + pos 417 | 418 | def forward(self, src, pos=None): 419 | src2 = self.norm1(src) 420 | q = k = self.with_pos_embed(src2, pos) 421 | src2 = self.self_attn(q, k, src2) 422 | 423 | src = src + self.dropout1(src2[0]) 424 | src2 = self.norm2(src) 425 | adapt = self.adaptir(src2) 426 | src2 = self.linear2(self.dropout(self.activation(self.linear1(src2)))) 427 | src = src + self.dropout2(src2+adapt) 428 | return src 429 | 430 | 431 | class TransformerDecoder(nn.Module): 432 | def __init__(self, decoder_layer, num_layers): 433 | super().__init__() 434 | self.layers = _get_clones(decoder_layer, num_layers) 435 | self.num_layers = num_layers 436 | 437 | def forward(self, tgt, memory, pos=None, query_pos=None): 438 | output = tgt 439 | for idx, layer in enumerate(self.layers): 440 | output = layer(output, memory, pos=pos, query_pos=query_pos) 441 | return output 442 | 443 | 444 | class TransformerDecoderLayer(nn.Module): 445 | def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, no_norm=False, 446 | activation="relu"): 447 | super().__init__() 
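        # Added note: like TransformerEncoderLayer above, this decoder layer keeps the
        # original IPT design but attaches a lightweight AdaptIR adapter
        # (`self.adaptir`, created below); in `forward` its output is added to the FFN
        # output before the final residual connection, so only the adapter needs to be
        # trained for downstream adaptation.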
448 | self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False) 449 | self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, bias=False) 450 | # Implementation of Feedforward model 451 | self.linear1 = nn.Linear(d_model, dim_feedforward) 452 | self.dropout = nn.Dropout(dropout) 453 | self.linear2 = nn.Linear(dim_feedforward, d_model) 454 | 455 | self.norm1 = nn.LayerNorm(d_model) if not no_norm else nn.Identity() 456 | self.norm2 = nn.LayerNorm(d_model) if not no_norm else nn.Identity() 457 | self.norm3 = nn.LayerNorm(d_model) if not no_norm else nn.Identity() 458 | self.dropout1 = nn.Dropout(dropout) 459 | self.dropout2 = nn.Dropout(dropout) 460 | self.dropout3 = nn.Dropout(dropout) 461 | 462 | self.activation = _get_activation_fn(activation) 463 | 464 | self.adaptir = AdaptIR(d_model) 465 | 466 | def with_pos_embed(self, tensor, pos): 467 | if pos is not None and pos.shape[0] < tensor.shape[0]: 468 | pos = torch.cat( 469 | [torch.zeros(tensor.shape[0] - pos.shape[0], tensor.shape[1], tensor.shape[2]).to(pos.device), pos], 470 | dim=0) 471 | return tensor if pos is None else tensor + pos 472 | 473 | def forward(self, tgt, memory, pos=None, query_pos=None): 474 | tgt2 = self.norm1(tgt) 475 | q = k = self.with_pos_embed(tgt2, query_pos) 476 | tgt2 = self.self_attn(q, k, value=tgt2)[0] 477 | tgt = tgt + self.dropout1(tgt2) 478 | 479 | tgt2 = self.norm2(tgt) 480 | tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos), 481 | key=self.with_pos_embed(memory, pos), 482 | value=memory)[0] 483 | tgt = tgt + self.dropout2(tgt2) 484 | 485 | tgt2 = self.norm3(tgt) 486 | adapt = self.adaptir(tgt2) 487 | tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) 488 | 489 | tgt = tgt + self.dropout3(tgt2+adapt) 490 | return tgt 491 | 492 | 493 | def _get_clones(module, N): 494 | return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) 495 | 496 | 497 | def _get_activation_fn(activation): 498 | """Return an activation function given a string""" 499 | if activation == "relu": 500 | return F.relu 501 | if activation == "gelu": 502 | return F.gelu 503 | if activation == "glu": 504 | return F.glu 505 | raise RuntimeError(F"activation should be relu/gelu, not {activation}.") 506 | 507 | 508 | class AdaptIR(nn.Module): 509 | def __init__(self, d_model): 510 | super(AdaptIR, self).__init__() 511 | self.hidden = d_model // 24 512 | self.rank = self.hidden // 2 513 | self.kernel_size = 3 514 | self.group = self.hidden 515 | self.head = nn.Conv2d(d_model, self.hidden, 1, 1) 516 | 517 | self.BN = nn.BatchNorm2d(self.hidden) 518 | 519 | self.conv_weight_A = nn.Parameter(torch.randn(self.hidden, self.rank)) 520 | self.conv_weight_B = nn.Parameter( 521 | torch.randn(self.rank, self.hidden // self.group * self.kernel_size * self.kernel_size)) 522 | self.conv_bias = nn.Parameter(torch.zeros(self.hidden)) 523 | nn.init.kaiming_uniform_(self.conv_weight_A, a=math.sqrt(5)) 524 | nn.init.kaiming_uniform_(self.conv_weight_B, a=math.sqrt(5)) 525 | 526 | self.amp_fuse = nn.Conv2d(self.hidden, self.hidden, 1, 1, groups=self.hidden) 527 | self.pha_fuse = nn.Conv2d(self.hidden, self.hidden, 1, 1, groups=self.hidden) 528 | nn.init.ones_(self.pha_fuse.weight) 529 | nn.init.ones_(self.amp_fuse.weight) 530 | nn.init.zeros_(self.amp_fuse.bias) 531 | nn.init.zeros_(self.pha_fuse.bias) 532 | 533 | self.compress = nn.Conv2d(self.hidden, 1, 1, 1) 534 | self.proj = nn.Sequential( 535 | nn.Linear(self.hidden, self.hidden // 2), 536 | nn.GELU(), 537 | 
nn.Linear(self.hidden // 2, self.hidden), 538 | ) 539 | self.tail = nn.Conv2d(self.hidden, d_model, 1, 1, bias=False) 540 | nn.init.zeros_(self.tail.weight) 541 | 542 | 543 | self.channel_interaction = nn.Sequential( 544 | nn.AdaptiveAvgPool2d(1), 545 | nn.Conv2d(self.hidden, self.hidden // 8, kernel_size=1), 546 | nn.GELU(), 547 | nn.Conv2d(self.hidden // 8,self.hidden, kernel_size=1) 548 | ) 549 | nn.init.zeros_(self.channel_interaction[3].weight) 550 | nn.init.zeros_(self.channel_interaction[3].bias) 551 | 552 | 553 | self.spatial_interaction = nn.Conv2d(self.hidden, 1, kernel_size=1) 554 | nn.init.zeros_(self.spatial_interaction.weight) 555 | nn.init.zeros_(self.spatial_interaction.bias) 556 | 557 | 558 | def forward(self, x): 559 | L, N, C = x.shape 560 | H = W = int(math.sqrt(L)) 561 | x = x.view(H, W, N, C).permute(2, 3, 0, 1).contiguous() # N,C,H,W 562 | x = self.BN(self.head(x)) 563 | 564 | global_x = torch.fft.rfft2(x, dim=(2, 3), norm='ortho') 565 | mag_x = torch.abs(global_x) 566 | pha_x = torch.angle(global_x) 567 | Mag = self.amp_fuse(mag_x) 568 | Pha = self.pha_fuse(pha_x) 569 | real = Mag * torch.cos(Pha) 570 | imag = Mag * torch.sin(Pha) 571 | global_x = torch.complex(real, imag) 572 | global_x = torch.fft.irfft2(global_x, s=(H, W), dim=(2, 3), norm='ortho') # N,C,H,W 573 | global_x = torch.abs(global_x) 574 | 575 | conv_weight = (self.conv_weight_A @ self.conv_weight_B) \ 576 | .view(self.hidden, self.hidden // self.group, self.kernel_size, self.kernel_size).contiguous() 577 | local_x = F.conv2d(x, weight=conv_weight, bias=self.conv_bias, stride=1, padding=1, groups=self.group) 578 | 579 | score = self.compress(x).view(N, 1, H * W).permute(0, 2, 1).contiguous() # N,HW,1 580 | score = F.softmax(score, dim=1) 581 | out = x.view(N, self.hidden, H * W) # N,C,HW 582 | out = out @ score # N,C,1 583 | out = out.permute(2, 0, 1) # 1,N,C 584 | out = self.proj(out) 585 | channel_score = out.permute(1, 2, 0).unsqueeze(-1).contiguous() # N,C,1,1 586 | 587 | channel_gate = self.channel_interaction(global_x).sigmoid() 588 | spatial_gate = self.spatial_interaction(local_x).sigmoid() 589 | spatial_x = channel_gate*local_x+spatial_gate*global_x 590 | 591 | x = self.tail(channel_score*spatial_x) 592 | x = x.view(N, C, H * W).permute(2, 0, 1).contiguous() 593 | return x 594 | 595 | 596 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | 5 | # Input Parameters 6 | parser.add_argument('--cuda', type=int, default=0) 7 | parser.add_argument('--arch', type=str, default='IPT',choices=['IPT','EDT']) 8 | parser.add_argument('--de_type', nargs='+', default='lr4_noise30', 9 | choices=['lr4_noise30', 'lr4_jpeg30', 'sr_2','sr_3','sr_4', 'denoise_30', 10 | 'denoise_50', 'derainL', 'derainH', 'low_light',], 11 | help='which types of degradation to train and test on.') 12 | parser.add_argument('--patch_size', type=int, default=48, help='patch size of the input.') 13 | parser.add_argument('--epochs', type=int, default=500, help='maximum number of epochs to train the total model.') 14 | parser.add_argument('--val_every_n_epoch', type=int, default=25) 15 | parser.add_argument('--batch_size', type=int,default=16,help="Batch size to use per GPU") 16 | parser.add_argument('--lr', type=float,default=1e-4,help="learning rate") 17 | 18 | parser.add_argument('--num_workers', type=int, default=8, help='number of workers.') 19
| # path 20 | parser.add_argument('--data_file_dir', type=str, default='./data_dir/', help='directory of the data list files, e.g. noisy/denoise.txt.') 21 | parser.add_argument('--dataset_dir', type=str, default='/data/guohang/dataset', 22 | help='root directory of the training datasets.') 23 | parser.add_argument('--output_path', type=str, default="./output/", help='output save path') 24 | parser.add_argument("--wblogger",type=str,default=None,help = "wandb project name; set to enable wandb logging") 25 | parser.add_argument("--ckpt_dir",type=str,default="./train_ckpt",help = "directory where training checkpoints are saved") 26 | parser.add_argument("--num_gpus",type=int,default=4,help = "Number of GPUs to use for training") 27 | 28 | options = parser.parse_args() 29 | 30 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import subprocess 3 | from tqdm import tqdm 4 | import numpy as np 5 | import torch 6 | from torch.utils.data import DataLoader 7 | import os 8 | import torch.nn as nn 9 | import torch.optim as optim 10 | from utils.dataset_utils import DenoiseTestDataset, DerainLowlightDataset,SRHybridTestDataset 11 | from utils.val_utils import AverageMeter, compute_psnr_ssim 12 | from utils.image_io import save_image_tensor 13 | from net.ipt import IPT 14 | import lightning.pytorch as pl 15 | import torch.nn.functional as F 16 | from utils.schedulers import LinearWarmupCosineAnnealingLR 17 | from utils.common import calculate_psnr_ssim 18 | from net.edt import EDT 19 | device = torch.device('cuda') 20 | from matplotlib import pyplot as plt 21 | 22 | 23 | 24 | class MultiTaskIRModel(pl.LightningModule): 25 | def __init__(self,args): 26 | super().__init__() 27 | self.args = args 28 | if args.arch == 'IPT': 29 | self.net = IPT(args) 30 | state_dict = torch.load('/data/guohang/pretrained/IPT_pretrain.pt') 31 | self.net.load_state_dict(state_dict, strict=False) 32 | elif args.arch == 'EDT': 33 | self.net = EDT(args) 34 | state_dict = torch.load('/data/guohang/pretrained/SRx2x3x4_EDTB_ImageNet200K.pth') 35 | self.net.load_state_dict(state_dict, strict=False) 36 | 37 | 38 | def forward(self,x): 39 | return self.net(x) 40 | 41 | def training_step(self, batch, batch_idx): 42 | ([clean_name, de_id], degrad_patch, clean_patch) = batch 43 | restored = self.net(degrad_patch) 44 | 45 | loss = self.loss_fn(restored,clean_patch) 46 | self.log("train_loss", loss) 47 | return loss 48 | 49 | def lr_scheduler_step(self,scheduler,metric): 50 | scheduler.step(self.current_epoch) 51 | 52 | 53 | def configure_optimizers(self): 54 | optimizer = optim.AdamW(self.parameters(), lr=1e-5) 55 | scheduler = LinearWarmupCosineAnnealingLR(optimizer=optimizer,warmup_epochs=15,max_epochs=150) 56 | 57 | return [optimizer],[scheduler] 58 | 59 | 60 | 61 | def test_Denoise(net, dataset, sigma=15): 62 | dataset.set_sigma(sigma) 63 | testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0) 64 | 65 | psnr = AverageMeter() 66 | ssim = AverageMeter() 67 | 68 | with torch.no_grad(): 69 | for ([clean_name], degrad_patch, clean_patch) in tqdm(testloader): 70 | degrad_patch, clean_patch = degrad_patch.to(device), clean_patch.to(device) 71 | restored = net(degrad_patch) 72 | temp_psnr, temp_ssim, N = compute_psnr_ssim(restored, clean_patch) 73 | psnr.update(temp_psnr, N) 74 | ssim.update(temp_ssim, N) 75 | 76 | print("Denoise sigma=%d: psnr: %.2f, ssim: 
%.4f" % (sigma, psnr.avg, ssim.avg)) 77 | 78 | 79 | 80 | def test_Derain_LowLight(net, dataset, task="derain"): 81 | dataset.set_dataset(task) 82 | testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0) 83 | 84 | psnr = AverageMeter() 85 | ssim = AverageMeter() 86 | 87 | with torch.no_grad(): 88 | for ([degraded_name], degrad_patch, clean_patch) in tqdm(testloader): 89 | degrad_patch, clean_patch = degrad_patch.to(device), clean_patch.to(device) 90 | restored = net(degrad_patch) 91 | to_y = True if 'derain' in task else False 92 | temp_psnr, temp_ssim, N = compute_psnr_ssim(restored, clean_patch,to_y=to_y) 93 | psnr.update(temp_psnr, N) 94 | ssim.update(temp_ssim, N) 95 | 96 | print("PSNR: %.2f, SSIM: %.4f" % (psnr.avg, ssim.avg)) 97 | 98 | 99 | 100 | 101 | def test_SR(net,dataset,scale): 102 | dataset.set_scale(scale) 103 | testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0) 104 | 105 | psnr = AverageMeter() 106 | ssim = AverageMeter() 107 | 108 | with torch.no_grad(): 109 | for ([clean_name], degrad_patch, clean_patch) in tqdm(testloader): 110 | degrad_patch, clean_patch = degrad_patch.to(device), clean_patch.to(device) 111 | restored = net(degrad_patch) 112 | 113 | temp_psnr, temp_ssim, N = compute_psnr_ssim(restored, clean_patch,to_y=True,bd=scale) 114 | 115 | psnr.update(temp_psnr, N) 116 | ssim.update(temp_ssim, N) 117 | 118 | print("SR scale=%d: psnr: %.2f, ssim: %.4f" % (scale, psnr.avg, ssim.avg)) 119 | 120 | 121 | 122 | 123 | 124 | def test_hybrid_degradation(net,dataset,scale): 125 | dataset.set_scale(scale) 126 | testloader = DataLoader(dataset, batch_size=1, pin_memory=True, shuffle=False, num_workers=0) 127 | psnr = AverageMeter() 128 | ssim = AverageMeter() 129 | 130 | with torch.no_grad(): 131 | for ([clean_name], degrad_patch, clean_patch) in tqdm(testloader): 132 | degrad_patch, clean_patch = degrad_patch.to(device), clean_patch.to(device) 133 | restored = net(degrad_patch) 134 | temp_psnr, temp_ssim, N = compute_psnr_ssim(restored, clean_patch,to_y=True,bd=scale) 135 | psnr.update(temp_psnr, N) 136 | ssim.update(temp_ssim, N) 137 | print("SR scale=%d: psnr: %.2f, ssim: %.4f" % (scale, psnr.avg, ssim.avg)) 138 | 139 | 140 | 141 | 142 | if __name__ == '__main__': 143 | parser = argparse.ArgumentParser() 144 | # Input Parameters 145 | parser.add_argument('--cuda', type=int, default=0) 146 | parser.add_argument('--arch', type=str, default='IPT', choices=['IPT', 'EDT']) 147 | parser.add_argument('--de_type', nargs='+', default='lr4_noise30', 148 | choices=['lr4_noise30', 'lr4_jpeg30', 'sr_2', 'sr_3', 'sr_4', 'denoise_30', 149 | 'denoise_50', 'derainL', 'derainH', 'low_light'], 150 | help='which types of degradation to train and test on.') 151 | parser.add_argument('--output_path', type=str, default="./test_output/", help='output save path') 152 | parser.add_argument('--base_path', type=str, default="/data/guohang/dataset", help='base directory of the test datasets') 153 | parser.add_argument('--ckpt_name', type=str, default='/data/guohang/AdaptIR/train_ckpt/last.ckpt', help='path of the checkpoint to load') 154 | testopt = parser.parse_args() 155 | 156 | 157 | 158 | np.random.seed(0) 159 | torch.manual_seed(0) 160 | torch.cuda.set_device(testopt.cuda) 161 | 162 | 163 | ckpt_path = testopt.ckpt_name 164 | 165 | denoise_splits = ["ColorDN/Urban100HQ"] 166 | derainH_splits = ["Rain100H/"] 167 | derainL_splits = ["Rain100L/"] 168 | low_light_splits = ["LOLv1/Test"] 169 | hybrid_splits =
['Set5','Set14','Urban100','B100','Manga109'] 170 | sr_splits = ['Set5','Set14','Urban100','B100','Manga109'] 171 | 172 | denoise_tests = [] 173 | derain_tests = [] 174 | sr_tests = [] 175 | 176 | 177 | print("CKPT name : {}".format(ckpt_path)) 178 | 179 | 180 | net = MultiTaskIRModel(testopt) 181 | ckpt = torch.load(ckpt_path) 182 | net.load_state_dict(ckpt['state_dict']) 183 | net.eval() 184 | net.to(device) 185 | 186 | 187 | if 'denoise' in testopt.de_type: 188 | base_path = testopt.base_path 189 | for i in denoise_splits: 190 | testopt.denoise_path = os.path.join(base_path, i) 191 | denoise_testset = DenoiseTestDataset(testopt) 192 | denoise_tests.append(denoise_testset) 193 | for testset,name in zip(denoise_tests,denoise_splits): 194 | if 'denoise_30' in testopt.de_type: 195 | print('Start {} testing Sigma=30...'.format(name)) 196 | test_Denoise(net, testset, sigma=30) 197 | if 'denoise_50' in testopt.de_type: 198 | print('Start {} testing Sigma=50...'.format(name)) 199 | test_Denoise(net, testset, sigma=50) 200 | 201 | elif 'derainL' in testopt.de_type: 202 | print('Start testing light rain streak removal...') 203 | derain_base_path = testopt.base_path 204 | for name in derainL_splits: 205 | print('Start testing {} rain streak removal...'.format(name)) 206 | testopt.derain_path = os.path.join(derain_base_path,name) 207 | derain_set = DerainLowlightDataset(testopt) 208 | test_Derain_LowLight(net, derain_set, task="derainL") 209 | 210 | elif 'derainH' in testopt.de_type: 211 | print('Start testing heavy rain streak removal...') 212 | derain_base_path = testopt.base_path 213 | for name in derainH_splits: 214 | print('Start testing {} rain streak removal...'.format(name)) 215 | testopt.derain_path = os.path.join(derain_base_path,name) 216 | derain_set = DerainLowlightDataset(testopt) 217 | test_Derain_LowLight(net, derain_set, task="derainH") 218 | 219 | elif 'low_light' in testopt.de_type: 220 | print('Start testing low-light enhancement...') 221 | low_light_base_path = testopt.base_path 222 | for name in low_light_splits: 223 | print('Start testing {} low light enhancement...'.format(name)) 224 | testopt.low_light_path = os.path.join(low_light_base_path,name) 225 | lowlight_set = DerainLowlightDataset(testopt) 226 | test_Derain_LowLight(net, lowlight_set, task="low_light") 227 | 228 | 229 | elif 'sr' in testopt.de_type: 230 | print('Start testing super-resolution...') 231 | sr_base_path = testopt.base_path 232 | for name in sr_splits: 233 | print('Start testing {} super-resolution...'.format(name)) 234 | testopt.sr_path = os.path.join(sr_base_path,'ARTSR',name,'HR') 235 | sr_set = SRHybridTestDataset(testopt) 236 | sr_tests.append(sr_set) 237 | for testset,name in zip(sr_tests,sr_splits): 238 | if 'sr_2' in testopt.de_type: 239 | print('Start {} testing SRx2...'.format(name)) 240 | test_SR(net, testset,scale=2) 241 | if 'sr_3' in testopt.de_type: 242 | print('Start {} testing SRx3...'.format(name)) 243 | test_SR(net, testset,scale=3) 244 | if 'sr_4' in testopt.de_type: 245 | print('Start {} testing SRx4...'.format(name)) 246 | test_SR(net, testset,scale=4) 247 | 248 | 249 | elif 'lr4_noise30' in testopt.de_type: 250 | print('Start testing LR4+Noise30 hybrid degradation...') 251 | sr_base_path = testopt.base_path 252 | for name in sr_splits: 253 | print('Start testing {} LR4+Noise30...'.format(name)) 254 | testopt.sr_path = os.path.join(sr_base_path,'ARTSR',name,'HR') 255 | sr_set = SRHybridTestDataset(testopt) 256 | sr_tests.append(sr_set) 257 | test_SR(net, sr_set, scale=4) 258 | 259 | 260 | 
elif 'lr4_jpeg30' in testopt.de_type: 261 | print('Start testing LR4+JPEG30 hybrid degradation...') 262 | sr_base_path = testopt.base_path 263 | for name in sr_splits: 264 | print('Start testing {} LR4+JPEG30...'.format(name)) 265 | testopt.sr_path = os.path.join(sr_base_path,'ARTSR',name,'HR') 266 | sr_set = SRHybridTestDataset(testopt) 267 | sr_tests.append(sr_set) 268 | test_SR(net, sr_set, scale=4) 269 | 270 | 271 | else: 272 | raise NotImplementedError 273 | 274 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import torch.nn as nn 4 | import torch.optim as optim 5 | from torch.utils.data import DataLoader 6 | from utils.dataset_utils import PromptTrainDataset 7 | from net.ipt import IPT 8 | from net.edt import EDT 9 | from utils.schedulers import LinearWarmupCosineAnnealingLR 10 | from torch.optim.lr_scheduler import MultiStepLR 11 | import numpy as np 12 | from options import options as opt 13 | import lightning.pytorch as pl 14 | from lightning.pytorch.loggers import WandbLogger,TensorBoardLogger 15 | from lightning.pytorch.callbacks import ModelCheckpoint 16 | from utils.val_utils import AverageMeter, compute_psnr_ssim 17 | 18 | 19 | class MultiTaskIRModel(pl.LightningModule): 20 | def __init__(self,args): 21 | super().__init__() 22 | print('Loading pretrained model...') 23 | self.args = args 24 | if args.arch == 'IPT': 25 | self.net = IPT(args) 26 | state_dict = torch.load('/data/guohang/pretrained/IPT_pretrain.pt') 27 | self.net.load_state_dict(state_dict, strict=False) 28 | elif args.arch == 'EDT': 29 | self.net = EDT(args) 30 | state_dict = torch.load('/data/guohang/pretrained/SRx2x3x4_EDTB_ImageNet200K.pth') 31 | self.net.load_state_dict(state_dict, strict=False) 32 | print('Freezing all parameters except the AdaptIR adapter; head and tail are NOT trainable') 33 | for name, param in self.net.named_parameters(): 34 | if "adaptir" not in name: 35 | param.requires_grad = False 36 | 37 | # for name, param in self.net.named_parameters(): 38 | # if param.requires_grad: 39 | # print(name) 40 | 41 | self.loss_fn = nn.L1Loss() 42 | self.save_hyperparameters() 43 | 44 | 45 | def forward(self,x): 46 | return self.net(x) 47 | 48 | def training_step(self, batch, batch_idx): 49 | ([clean_name, de_id], degrad_patch, clean_patch) = batch 50 | restored = self.net(degrad_patch,de_id) 51 | loss = self.loss_fn(restored,clean_patch) 52 | self.log("train_loss", loss) 53 | return loss 54 | 55 | 56 | 57 | # def validation_step(self,batch, batch_idx): 58 | # ([clean_name], degrad_patch, clean_patch) = batch 59 | # restored = self.net(degrad_patch) 60 | # psnr_to_y = False if 'denoise' in self.args.de_type else True 61 | # temp_psnr, temp_ssim, N = compute_psnr_ssim(restored, clean_patch,to_y=psnr_to_y) 62 | # self.log('psnr',temp_psnr,on_epoch=True,sync_dist=True) 63 | # return temp_psnr 64 | 65 | 66 | 67 | def lr_scheduler_step(self,scheduler,metric): 68 | scheduler.step() 69 | 70 | 71 | def configure_optimizers(self): 72 | optimizer = optim.AdamW(self.parameters(), lr=self.args.lr) 73 | scheduler = MultiStepLR(optimizer, milestones=[250,400,450,475], gamma=0.5) 74 | return [optimizer],[scheduler] 75 | 76 | 77 | def main(): 78 | logger = TensorBoardLogger(save_dir = "./logs") 79 | trainset = PromptTrainDataset(opt) 80 | checkpoint_callback = ModelCheckpoint(dirpath = opt.ckpt_dir, 81 | save_last=True,) 82 | 83 | trainloader = DataLoader(trainset,
batch_size=opt.batch_size, pin_memory=True, shuffle=True, 84 | drop_last=True, num_workers=opt.num_workers) 85 | 86 | model = MultiTaskIRModel(opt) 87 | 88 | trainer = pl.Trainer(max_epochs=opt.epochs,accelerator="gpu", 89 | devices=opt.num_gpus, 90 | logger=logger,callbacks=[checkpoint_callback]) 91 | 92 | trainer.fit(model=model, train_dataloaders=trainloader) 93 | 94 | 95 | if __name__ == '__main__': 96 | main() 97 | 98 | 99 | 100 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/csguoh/AdaptIR/02b3490717e9fbdec5874a63ce619615d51444c8/utils/__init__.py -------------------------------------------------------------------------------- /utils/common.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | from datetime import datetime 3 | import logging 4 | import math 5 | import numpy as np 6 | import os 7 | import random 8 | from shutil import get_terminal_size 9 | import sys 10 | import time 11 | 12 | import torch 13 | from torchvision.utils import make_grid 14 | 15 | 16 | def mkdir(path): 17 | if not os.path.exists(path): 18 | os.makedirs(path) 19 | 20 | 21 | def get_timestamp(): 22 | return datetime.now().strftime('%y%m%d-%H%M%S') 23 | 24 | 25 | def setup_logger(logger_name, save_dir, phase, level=logging.INFO, screen=False, to_file=False): 26 | lg = logging.getLogger(logger_name) 27 | formatter = logging.Formatter('[%(asctime)s.%(msecs)03d - %(levelname)s]: %(message)s', 28 | datefmt='%y-%m-%d %H:%M:%S') 29 | lg.setLevel(level) 30 | if to_file: 31 | log_file = os.path.join(save_dir, '{}_{}.txt'.format(get_timestamp(),phase)) 32 | fh = logging.FileHandler(log_file, mode='w') 33 | fh.setFormatter(formatter) 34 | lg.addHandler(fh) 35 | if screen: 36 | sh = logging.StreamHandler() 37 | sh.setFormatter(formatter) 38 | lg.addHandler(sh) 39 | 40 | 41 | def init_random_seed(seed=0): 42 | np.random.seed(seed) 43 | random.seed(seed) 44 | 45 | # default values of benchmark and deterministic are False 46 | # # for reproducibility 47 | # torch.backends.cudnn.benchmark = False 48 | # torch.backends.cudnn.deterministic = True 49 | # for speed 50 | torch.backends.cudnn.benchmark = True 51 | torch.backends.cudnn.deterministic = False 52 | 53 | torch.manual_seed(seed) 54 | torch.cuda.manual_seed(seed) 55 | torch.cuda.manual_seed_all(seed) 56 | 57 | 58 | def tensor2img(tensor, out_type=np.uint8, min_max=(0, 1)): 59 | # 4D: grid (B, C, H, W), 3D: (C, H, W), 2D: (H, W) 60 | tensor = tensor.squeeze().float().cpu().clamp_(*min_max) 61 | tensor = (tensor - min_max[0]) / (min_max[1] - min_max[0]) 62 | 63 | n_dim = tensor.dim() 64 | if n_dim == 4: 65 | n_img = len(tensor) 66 | img_np = make_grid(tensor, nrow=int(math.sqrt(n_img)), padding=0, normalize=False).numpy() 67 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) 68 | elif n_dim == 3: 69 | img_np = tensor.numpy() 70 | img_np = np.transpose(img_np[[2, 1, 0], :, :], (1, 2, 0)) 71 | elif n_dim == 2: 72 | img_np = tensor.numpy() 73 | else: 74 | raise TypeError( 75 | 'Only support 4D, 3D and 2D tensor. 
But received with dimension: {:d}'.format(n_dim)) 76 | 77 | if out_type == np.uint8: 78 | img_np = (img_np * 255.0).round() 79 | 80 | return img_np.astype(out_type) 81 | 82 | 83 | def calculate_psnr_ssim(img1, img2, to_y=True, bd=0): 84 | img1 = img1.astype(np.float64) 85 | img2 = img2.astype(np.float64) 86 | if to_y: 87 | img1 = bgr2ycbcr(img1 / 255.0, only_y=True) * 255.0 88 | img2 = bgr2ycbcr(img2 / 255.0, only_y=True) * 255.0 89 | if bd != 0: 90 | img1 = img1[bd:-bd, bd:-bd] 91 | img2 = img2[bd:-bd, bd:-bd] 92 | psnr = calculate_psnr(img1, img2) 93 | ssim = calculate_ssim(img1, img2) 94 | 95 | return psnr, ssim 96 | 97 | def calculate_psnr(img1, img2): 98 | # img1 and img2 have range [0, 255] 99 | img1 = img1.astype(np.float64) 100 | img2 = img2.astype(np.float64) 101 | mse = np.mean((img1 - img2) ** 2) 102 | if mse == 0: 103 | return float('inf') 104 | 105 | return 20 * math.log10(255.0 / math.sqrt(mse)) 106 | 107 | 108 | def calculate_ssim(img1, img2): 109 | C1 = (0.01 * 255) ** 2 110 | C2 = (0.03 * 255) ** 2 111 | 112 | img1 = img1.astype(np.float64) 113 | img2 = img2.astype(np.float64) 114 | kernel = cv2.getGaussianKernel(11, 1.5) 115 | window = np.outer(kernel, kernel.transpose()) 116 | 117 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 118 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 119 | mu1_sq = mu1 ** 2 120 | mu2_sq = mu2 ** 2 121 | mu1_mu2 = mu1 * mu2 122 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 123 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 124 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 125 | 126 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 127 | 128 | return ssim_map.mean() 129 | 130 | 131 | def calc_psnr(sr, hr, scale=2, rgb_range=1.0, benchmark=True): 132 | # benchmark: to Y channel 133 | diff = (sr - hr).data.div(rgb_range) 134 | if benchmark: 135 | shave = scale 136 | if diff.size(1) > 1: 137 | convert = diff.new(1, 3, 1, 1) 138 | convert[0, 0, 0, 0] = 65.738 139 | convert[0, 1, 0, 0] = 129.057 140 | convert[0, 2, 0, 0] = 25.064 141 | diff.mul_(convert).div_(256) 142 | diff = diff.sum(dim=1, keepdim=True) 143 | else: 144 | shave = scale + 6 145 | 146 | valid = diff[:, :, shave:-shave, shave:-shave] 147 | mse = valid.pow(2).mean() 148 | 149 | return -10 * math.log10(mse) 150 | 151 | 152 | def rgb2ycbcr(img, only_y=True): 153 | '''same as matlab rgb2ycbcr 154 | only_y: only return Y channel 155 | Input: 156 | uint8, [0, 255] 157 | float, [0, 1] 158 | ''' 159 | in_img_type = img.dtype 160 | img.astype(np.float32) 161 | if in_img_type != np.uint8: 162 | img *= 255. 163 | # convert 164 | if only_y: 165 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 166 | else: 167 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], 168 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] 169 | if in_img_type == np.uint8: 170 | rlt = rlt.round() 171 | else: 172 | rlt /= 255. 173 | return rlt.astype(in_img_type) 174 | 175 | 176 | def bgr2ycbcr(img, only_y=True): 177 | '''bgr version of rgb2ycbcr 178 | only_y: only return Y channel 179 | Input: 180 | uint8, [0, 255] 181 | float, [0, 1] 182 | ''' 183 | in_img_type = img.dtype 184 | img.astype(np.float32) 185 | if in_img_type != np.uint8: 186 | img *= 255. 
187 | # convert 188 | if only_y: 189 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 190 | else: 191 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 192 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 193 | if in_img_type == np.uint8: 194 | rlt = rlt.round() 195 | else: 196 | rlt /= 255. 197 | return rlt.astype(in_img_type) 198 | 199 | 200 | def ycbcr2rgb(img): 201 | '''same as matlab ycbcr2rgb 202 | Input: 203 | uint8, [0, 255] 204 | float, [0, 1] 205 | ''' 206 | in_img_type = img.dtype 207 | img.astype(np.float32) 208 | if in_img_type != np.uint8: 209 | img *= 255. 210 | # convert 211 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], 212 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] 213 | if in_img_type == np.uint8: 214 | rlt = rlt.round() 215 | else: 216 | rlt /= 255. 217 | return rlt.astype(in_img_type) 218 | 219 | 220 | def flipx4_forward(model, inp): 221 | """Flip testing with X4 self ensemble, i.e., normal, flip H, flip W, flip H and W 222 | Args: 223 | model (PyTorch model) 224 | inp (Tensor): inputs defined by the model 225 | 226 | Returns: 227 | output (Tensor): outputs of the model. float, in CPU 228 | """ 229 | with torch.no_grad(): 230 | # normal 231 | output_f = model(inp) 232 | 233 | # flip W 234 | output = model(torch.flip(inp, (-1, ))) 235 | output_f = output_f + torch.flip(output, (-1, )) 236 | # flip H 237 | output = model(torch.flip(inp, (-2, ))) 238 | output_f = output_f + torch.flip(output, (-2, )) 239 | # flip both H and W 240 | output = model(torch.flip(inp, (-2, -1))) 241 | output_f = output_f + torch.flip(output, (-2, -1)) 242 | 243 | output_f = output_f / 4 244 | 245 | return output_f 246 | 247 | 248 | def flipRotx8_forward(model, inp): 249 | output_f = flipx4_forward(model, inp) 250 | 251 | # rot 90 252 | output = flipx4_forward(model, inp.permute(0, 1, 3, 2)) 253 | output_f = output_f + output.permute(0, 1, 3, 2) 254 | 255 | output_f = output_f / 2 256 | 257 | return output_f 258 | 259 | 260 | class ProgressBar(object): 261 | '''A progress bar which can print the progress 262 | modified from https://github.com/hellock/cvbase/blob/master/cvbase/progress.py 263 | ''' 264 | 265 | def __init__(self, task_num=0, bar_width=50, start=True): 266 | self.task_num = task_num 267 | max_bar_width = self._get_max_bar_width() 268 | self.bar_width = (bar_width if bar_width <= max_bar_width else max_bar_width) 269 | self.completed = 0 270 | if start: 271 | self.start() 272 | 273 | def _get_max_bar_width(self): 274 | terminal_width, _ = get_terminal_size() 275 | max_bar_width = min(int(terminal_width * 0.6), terminal_width - 50) 276 | if max_bar_width < 10: 277 | print('terminal width is too small ({}), please consider widen the terminal for better ' 278 | 'progressbar visualization'.format(terminal_width)) 279 | max_bar_width = 10 280 | return max_bar_width 281 | 282 | def start(self): 283 | if self.task_num > 0: 284 | sys.stdout.write('[{}] 0/{}, elapsed: 0s, ETA:\n{}\n'.format( 285 | ' ' * self.bar_width, self.task_num, 'Start...')) 286 | else: 287 | sys.stdout.write('completed: 0, elapsed: 0s') 288 | sys.stdout.flush() 289 | self.start_time = time.time() 290 | 291 | def update(self, msg='In progress...'): 292 | self.completed += 1 293 | elapsed = time.time() - self.start_time 294 | fps = self.completed / elapsed 295 | if self.task_num > 0: 296 | percentage = self.completed / float(self.task_num) 297 | eta = int(elapsed * (1 - percentage) / percentage + 0.5) 
298 | mark_width = int(self.bar_width * percentage) 299 | bar_chars = '>' * mark_width + '-' * (self.bar_width - mark_width) 300 | sys.stdout.write('\033[2F') # cursor up 2 lines 301 | sys.stdout.write('\033[J') # clean the output (remove extra chars since last display) 302 | sys.stdout.write('[{}] {}/{}, {:.1f} task/s, elapsed: {}s, ETA: {:5}s\n{}\n'.format( 303 | bar_chars, self.completed, self.task_num, fps, int(elapsed + 0.5), eta, msg)) 304 | else: 305 | sys.stdout.write('completed: {}, elapsed: {}s, {:.1f} tasks/s'.format( 306 | self.completed, int(elapsed + 0.5), fps)) 307 | sys.stdout.flush() 308 | 309 | 310 | def scandir(dir_path, suffix=None, recursive=False, full_path=False): 311 | """Scan a directory to find the files of interest. 312 | 313 | Args: 314 | dir_path (str): Path of the directory. 315 | suffix (str | tuple(str), optional): File suffix that we are 316 | interested in. Default: None. 317 | recursive (bool, optional): If set to True, recursively scan the 318 | directory. Default: False. 319 | full_path (bool, optional): If set to True, include the dir_path. 320 | Default: False. 321 | 322 | Returns: 323 | A generator for all the files of interest with relative paths. 324 | """ 325 | 326 | if (suffix is not None) and not isinstance(suffix, (str, tuple)): 327 | raise TypeError('"suffix" must be a string or tuple of strings') 328 | 329 | root = dir_path 330 | 331 | def _scandir(dir_path, suffix, recursive): 332 | for entry in os.scandir(dir_path): 333 | if not entry.name.startswith('.') and entry.is_file(): 334 | if full_path: 335 | return_path = entry.path 336 | else: 337 | return_path = os.path.relpath(entry.path, root) 338 | 339 | if suffix is None: 340 | yield return_path 341 | elif return_path.endswith(suffix): 342 | yield return_path 343 | else: 344 | if recursive: 345 | yield from _scandir( 346 | entry.path, suffix=suffix, recursive=recursive) 347 | else: 348 | continue 349 | 350 | return _scandir(dir_path, suffix=suffix, recursive=recursive) 351 | -------------------------------------------------------------------------------- /utils/dataset_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | import copy 4 | from PIL import Image 5 | import numpy as np 6 | import cv2 7 | from torch.utils.data import Dataset, DataLoader 8 | from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor 9 | import torch 10 | 11 | from utils.image_utils import random_augmentation, crop_img 12 | from utils.degradation_utils import Degradation 13 | 14 | def add_jpg_compression(img, quality=30): 15 | """Add JPG compression artifacts. 16 | 17 | Args: 18 | img (Numpy array): Input image, shape (h, w, c), range [0, 255], uint8. 19 | quality (float): JPG compression quality. 0 for lowest quality, 100 for 20 | best quality. Default: 30. 21 | 22 | Returns: 23 | (Numpy array): Returned image after JPG, shape (h, w, c), range [0, 255], 24 | uint8.
25 | """ 26 | encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality] 27 | _, encimg = cv2.imencode('.jpg', img, encode_param) 28 | img = np.clip(cv2.imdecode(encimg, 1), 0, 255).astype(np.uint8) 29 | return img 30 | 31 | 32 | class PromptTrainDataset(Dataset): 33 | def __init__(self, args): 34 | super(PromptTrainDataset, self).__init__() 35 | self.args = args 36 | self.D = Degradation(args) 37 | self.de_temp = 0 38 | self.de_type = self.args.de_type 39 | print(self.de_type) 40 | self._init_ids() 41 | self._merge_ids() 42 | 43 | self.crop_transform = Compose([ 44 | ToPILImage(), 45 | RandomCrop(args.patch_size), 46 | ]) # only used in denoise 47 | 48 | self.toTensor = ToTensor() 49 | 50 | def _init_ids(self): 51 | if 'sr' in self.de_type: 52 | self._init_sr_idx() 53 | if 'lr4_noise30' in self.de_type: 54 | self._init_lr_dn_idx() 55 | if 'lr4_jpeg30' in self.de_type: 56 | self._init_lr_jpeg_idx() 57 | if 'denoise_30' in self.de_type or 'denoise_50' in self.de_type: 58 | self._init_dn_ids() 59 | if 'derainL' in self.de_type: 60 | self._init_rs_ids(mode='L') 61 | if 'derainH' in self.de_type: 62 | self._init_rs_ids(mode='H') 63 | if 'low_light' in self.de_type: 64 | self._init_low_light_ids() 65 | 66 | def _init_lr_dn_idx(self): 67 | clean_ids = [] 68 | div2k = sorted(os.listdir(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR'))[:800] 69 | filckr2k = os.listdir(self.args.dataset_dir + '/Flickr2K/Flickr2K_HR') 70 | clean_ids += [os.path.join(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR', file) for file in div2k] 71 | clean_ids += [os.path.join(self.args.dataset_dir + '/Flickr2K/Flickr2K_HR', file) for file in filckr2k] # total -- 3450 72 | 73 | 74 | self.lr_dn_ids = [{"clean_id": x, "de_type": self.de_type} for x in clean_ids] 75 | random.shuffle(self.lr_dn_ids) 76 | self.num_clean = len(self.lr_dn_ids) 77 | print("Total LR4&DN30 Ids : {}".format(self.num_clean)) 78 | 79 | def _init_lr_jpeg_idx(self): 80 | clean_ids = [] 81 | div2k = sorted(os.listdir(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR'))[:800] 82 | filckr2k = os.listdir(self.args.dataset_dir + '/Flickr2K/Flickr2K_HR') 83 | clean_ids += [os.path.join(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR', file) for file in div2k] 84 | clean_ids += [os.path.join(self.args.dataset_dir + '/Flickr2K/Flickr2K_HR', file) for file in filckr2k] # total -- 3450 85 | 86 | 87 | self.lr_jpeg_ids = [{"clean_id": x, "de_type": self.de_type} for x in clean_ids] 88 | random.shuffle(self.lr_jpeg_ids) 89 | 90 | self.num_clean = len(self.lr_jpeg_ids) 91 | print("Total LR4&JPEG30 Ids : {}".format(self.num_clean)) 92 | 93 | 94 | 95 | def _init_sr_idx(self): 96 | clean_ids = [] 97 | div2k = sorted(os.listdir(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR'))[:800] 98 | filckr2k = os.listdir(self.args.dataset_dir + '/Flickr2K/Flickr2K_HR') 99 | clean_ids += [os.path.join(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR', file) for file in div2k] 100 | clean_ids += [os.path.join(self.args.dataset_dir + '/Flickr2K/Flickr2K_HR', file) for file in filckr2k] # total -- 3450 101 | 102 | if 'sr_2' in self.de_type: 103 | self.sr2_ids = [{"clean_id": x, "de_type": self.de_type} for x in clean_ids] 104 | random.shuffle(self.sr2_ids) 105 | if 'sr_3' in self.de_type: 106 | self.sr3_ids = [{"clean_id": x, "de_type": self.de_type} for x in clean_ids] 107 | random.shuffle(self.sr3_ids) 108 | if 'sr_4' in self.de_type: 109 | self.sr4_ids = [{"clean_id": x, "de_type": self.de_type} for x in clean_ids] 110 | random.shuffle(self.sr4_ids) 111 | 112 | self.num_clean = len(clean_ids) 
113 | print("Total SR Ids : {}".format(self.num_clean)) 114 | 115 | def _init_low_light_ids(self): 116 | temp_ids = [] 117 | temp_ids += os.listdir(os.path.join(self.args.dataset_dir, 'LOLv1', 'Train','input')) 118 | new_ids = [] 119 | for name in temp_ids: 120 | if 'DS' not in name: 121 | new_ids.append(name) 122 | temp_ids = new_ids 123 | self.low_light_ids = [{"clean_id": os.path.join(self.args.dataset_dir, 'LOLv1', 'Train','input', x), "de_type": self.de_type} for x in temp_ids] * 16 124 | random.shuffle(self.low_light_ids) 125 | self.low_light_counter = 0 126 | self.num_low_light = len(self.low_light_ids) 127 | print("Total Low-light Ids : {}".format(self.num_low_light)) 128 | 129 | 130 | def _init_dn_ids(self): 131 | ref_file = self.args.data_file_dir + "noisy/denoise.txt" 132 | temp_ids = [] 133 | temp_ids+= [id_.strip() for id_ in open(ref_file)] 134 | clean_ids = [] 135 | name_list = os.listdir(self.args.dataset_dir + '/WED')+os.listdir(self.args.dataset_dir + '/BSD400') 136 | clean_ids += [id_ for id_ in name_list if id_.strip() in temp_ids] 137 | tmp=[] 138 | for elem in clean_ids: 139 | if 'bmp' in elem: 140 | tmp.append(self.args.dataset_dir + '/WED/'+elem) 141 | else: 142 | tmp.append(self.args.dataset_dir + '/BSD400/'+elem) 143 | clean_ids = tmp 144 | # add DIV2K and Filkr2K 145 | div2k = sorted(os.listdir(self.args.dataset_dir+'/DIV2K/DIV2K_train_HR'))[:800] 146 | filckr2k = os.listdir(self.args.dataset_dir+'/Flickr2K/Flickr2K_HR') 147 | clean_ids += [os.path.join(self.args.dataset_dir + '/DIV2K/DIV2K_train_HR',file) for file in div2k] 148 | clean_ids += [os.path.join(self.args.dataset_dir+'/Flickr2K/Flickr2K_HR',file) for file in filckr2k] # total -- 8194 149 | 150 | if 'denoise_30' in self.de_type: 151 | self.s30_ids = [{"clean_id": x,"de_type":'denoise_30'} for x in clean_ids] 152 | random.shuffle(self.s30_ids) 153 | self.s30_counter = 0 154 | if 'denoise_50' in self.de_type: 155 | self.s50_ids = [{"clean_id": x,"de_type":'denoise_50'} for x in clean_ids] 156 | random.shuffle(self.s50_ids) 157 | self.s50_counter = 0 158 | 159 | self.num_clean = len(clean_ids) 160 | print("Total Denoise Ids : {}".format(self.num_clean)) 161 | 162 | def _init_rs_ids(self,mode): 163 | if mode == 'H': 164 | temp_ids = [] 165 | rain_path = self.args.dataset_dir + '/RainTrainH' 166 | rain13k = os.listdir(rain_path) 167 | temp_ids += [os.path.join(rain_path, file.replace('norain', 'rain')) for file in rain13k if 'norain-' in file] * 4 168 | 169 | self.rsH_ids = [{"clean_id": x, "de_type": 'derainH'} for x in temp_ids] 170 | random.shuffle(self.rsH_ids) 171 | self.rlH_counter = 0 172 | self.num_rlH = len(self.rsH_ids) 173 | print("Total Heavy Rainy Ids : {}".format(self.num_rlH)) 174 | 175 | else: 176 | temp_ids = [] 177 | rain_path = self.args.dataset_dir + '/RainTrainL' 178 | rain13k = os.listdir(rain_path) 179 | temp_ids += [os.path.join(rain_path, file.replace('norain', 'rain')) for file in rain13k if 'norain-' in file] * 24 180 | 181 | self.rsL_ids = [{"clean_id": x, "de_type": 'derainL'} for x in temp_ids] 182 | random.shuffle(self.rsL_ids) 183 | self.rlL_counter = 0 184 | self.num_rlL = len(self.rsL_ids) 185 | print("Total Light Rainy Ids : {}".format(self.num_rlL)) 186 | 187 | def _crop_patch(self, img_1, img_2, s_hr=1): 188 | H = img_1.shape[0] 189 | W = img_1.shape[1] 190 | ind_H = random.randint(0, H - self.args.patch_size) 191 | ind_W = random.randint(0, W - self.args.patch_size) 192 | 193 | patch_1 = img_1[ind_H:ind_H + self.args.patch_size, ind_W:ind_W + self.args.patch_size] 194 | 
patch_2 = img_2[ind_H * s_hr:(ind_H + self.args.patch_size) * s_hr, 195 | ind_W * s_hr:(ind_W + self.args.patch_size) * s_hr] 196 | 197 | return patch_1, patch_2 198 | 199 | def _get_gt_name(self, rainy_name): 200 | gt_name = rainy_name.split("rain-")[0] + 'norain-' + rainy_name.split('rain-')[-1] 201 | return gt_name 202 | 203 | def _get_lr_name(self, hr_name): 204 | scale = self.de_type.split('_')[-1] 205 | base_name = os.path.basename(hr_name).split('.')[0] 206 | lr_path = os.path.join(hr_name.split('HR')[0] + 'LR_bicubic', 'X' + scale, base_name + 'x' + scale + '.png') 207 | return lr_path 208 | 209 | def _get_hybrid_name(self, hr_name): 210 | scale = '4' 211 | base_name = os.path.basename(hr_name).split('.')[0] 212 | lr_path = os.path.join(hr_name.split('HR')[0] + 'LR_bicubic', 'X' + scale, base_name + 'x' + scale + '.png') 213 | return lr_path 214 | 215 | 216 | def _get_normal_light_name(self, low_light_name): 217 | normal_light_name = low_light_name.replace('input', 'target') 218 | return normal_light_name 219 | 220 | 221 | 222 | def _merge_ids(self): 223 | self.sample_ids = [] 224 | if 'lr4_noise30' in self.de_type: 225 | self.sample_ids += self.lr_dn_ids 226 | if 'lr4_jpeg30' in self.de_type: 227 | self.sample_ids += self.lr_jpeg_ids 228 | if 'denoise_30' in self.de_type: 229 | self.sample_ids += self.s30_ids 230 | if 'denoise_50' in self.de_type: 231 | self.sample_ids += self.s50_ids 232 | if "derainL" in self.de_type: 233 | self.sample_ids += self.rsL_ids 234 | if "derainH" in self.de_type: 235 | self.sample_ids += self.rsH_ids 236 | if "low_light" in self.de_type: 237 | self.sample_ids += self.low_light_ids 238 | if 'sr_2' in self.de_type: 239 | self.sample_ids += self.sr2_ids 240 | if 'sr_3' in self.de_type: 241 | self.sample_ids += self.sr3_ids 242 | if 'sr_4' in self.de_type: 243 | self.sample_ids += self.sr4_ids 244 | 245 | random.shuffle(self.sample_ids) 246 | print(len(self.sample_ids)) 247 | 248 | 249 | def __getitem__(self, idx): 250 | sample = self.sample_ids[idx] 251 | de_id = sample["de_type"] 252 | if 'denoise' in de_id: # denoise 253 | clean_id = sample["clean_id"] 254 | clean_img = crop_img(np.array(Image.open(clean_id).convert('RGB')), base=16) 255 | clean_patch = self.crop_transform(clean_img) 256 | clean_patch = np.array(clean_patch) 257 | clean_name = clean_id.split("/")[-1].split('.')[0] 258 | clean_patch = random_augmentation(clean_patch)[0] 259 | degrad_patch = self.D.single_degrade(clean_patch, de_id) 260 | 261 | if 'derain' in de_id: 262 | # Rain Streak Removal 263 | degrad_img = crop_img(np.array(Image.open(sample["clean_id"]).convert('RGB')), base=16) 264 | clean_name = self._get_gt_name(sample["clean_id"]) 265 | clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16) 266 | degrad_patch, clean_patch = random_augmentation(*self._crop_patch(degrad_img, clean_img)) 267 | 268 | if 'sr' in de_id: 269 | scale = int(self.de_type.split('_')[-1]) 270 | hr_img = np.array(Image.open(sample["clean_id"]).convert('RGB')) 271 | clean_name = sample['clean_id'].split('/')[-1].split('.')[0] 272 | lr_name = self._get_hybrid_name(sample["clean_id"]) 273 | lr_img = np.array(Image.open(lr_name).convert('RGB')) 274 | degrad_patch, clean_patch = random_augmentation(*self._crop_patch(lr_img, hr_img, s_hr=scale)) 275 | 276 | if 'lr4_noise30' in de_id: 277 | scale = int(self.de_type[2]) 278 | hr_img = np.array(Image.open(sample["clean_id"]).convert('RGB')) 279 | clean_name = sample['clean_id'].split('/')[-1].split('.')[0] 280 | lr_name = 
self._get_hybrid_name(sample["clean_id"]) 281 | lr_img = np.array(Image.open(lr_name).convert('RGB')) 282 | degrad_patch, clean_patch = random_augmentation(*self._crop_patch(lr_img, hr_img, s_hr=scale)) 283 | degrad_patch = self.D._add_gaussian_noise(degrad_patch,sigma=30)[0] 284 | 285 | if 'lr4_jpeg30' in de_id: 286 | scale = int(self.de_type[2]) 287 | hr_img = np.array(Image.open(sample["clean_id"]).convert('RGB')) 288 | clean_name = sample['clean_id'].split('/')[-1].split('.')[0] 289 | lr_name = self._get_lr_name(sample["clean_id"]) 290 | lr_img = np.array(Image.open(lr_name).convert('RGB')) 291 | degrad_patch, clean_patch = random_augmentation(*self._crop_patch(lr_img, hr_img, s_hr=scale)) 292 | degrad_patch = add_jpg_compression(degrad_patch) 293 | 294 | if 'low_light' in de_id: 295 | degrad_img = crop_img(np.array(Image.open(sample["clean_id"]).convert('RGB')), base=16) 296 | clean_name = self._get_normal_light_name(sample["clean_id"]) 297 | clean_img = crop_img(np.array(Image.open(clean_name).convert('RGB')), base=16) 298 | degrad_patch, clean_patch = random_augmentation(*self._crop_patch(degrad_img, clean_img)) 299 | 300 | clean_patch = self.toTensor(clean_patch) 301 | degrad_patch = self.toTensor(degrad_patch) 302 | 303 | return [clean_name, de_id], degrad_patch, clean_patch 304 | 305 | 306 | def __len__(self): 307 | return len(self.sample_ids) 308 | 309 | 310 | class DenoiseTestDataset(Dataset): 311 | def __init__(self, args): 312 | super(DenoiseTestDataset, self).__init__() 313 | self.args = copy.deepcopy(args) 314 | self.clean_ids = [] 315 | self.sigma = 15 316 | 317 | self._init_clean_ids() 318 | 319 | self.toTensor = ToTensor() 320 | 321 | def _init_clean_ids(self): 322 | name_list = os.listdir(self.args.denoise_path) 323 | self.clean_ids += [os.path.join(self.args.denoise_path, id_) for id_ in name_list] 324 | self.clean_ids = sorted(self.clean_ids) 325 | self.num_clean = len(self.clean_ids) 326 | 327 | def _add_gaussian_noise(self, clean_patch): 328 | noise = np.random.randn(*clean_patch.shape) 329 | noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255) 330 | return noisy_patch, clean_patch 331 | 332 | def set_sigma(self, sigma): 333 | self.sigma = sigma 334 | 335 | def __getitem__(self, clean_id): 336 | # clean_img = crop_img(np.array(Image.open(self.clean_ids[clean_id]).convert('RGB')), base=16) 337 | clean_img = np.array(Image.open(self.clean_ids[clean_id]).convert('RGB')).astype(np.float32) 338 | clean_name = self.clean_ids[clean_id].split("/")[-1].split('.')[0] 339 | noisy_img = self._add_gaussian_noise(clean_img)[0] 340 | # clean_img, noisy_img = self.toTensor(clean_img), self.toTensor(noisy_img) 341 | clean_img, noisy_img = torch.from_numpy(np.transpose(clean_img / 255., (2, 0, 1))).float(), torch.from_numpy( 342 | np.transpose(noisy_img / 255., (2, 0, 1))).float() 343 | 344 | return [clean_name], noisy_img, clean_img 345 | 346 | def tile_degrad(input_, tile=128, tile_overlap=0): 347 | sigma_dict = {0: 0, 1: 15, 2: 25, 3: 50} 348 | b, c, h, w = input_.shape 349 | tile = min(tile, h, w) 350 | assert tile % 8 == 0, "tile size should be multiple of 8" 351 | 352 | stride = tile - tile_overlap 353 | h_idx_list = list(range(0, h - tile, stride)) + [h - tile] 354 | w_idx_list = list(range(0, w - tile, stride)) + [w - tile] 355 | E = torch.zeros(b, c, h, w).type_as(input_) 356 | W = torch.zeros_like(E) 357 | s = 0 358 | for h_idx in h_idx_list: 359 | for w_idx in w_idx_list: 360 | in_patch = input_[..., h_idx:h_idx + tile, w_idx:w_idx + tile] 361 | out_patch = 
in_patch 362 | # out_patch = model(in_patch) 363 | out_patch_mask = torch.ones_like(in_patch) 364 | 365 | E[..., h_idx:(h_idx + tile), w_idx:(w_idx + tile)].add_(out_patch) 366 | W[..., h_idx:(h_idx + tile), w_idx:(w_idx + tile)].add_(out_patch_mask) 367 | restored = E.div_(W) 368 | 369 | restored = torch.clamp(restored, 0, 1) 370 | return restored 371 | 372 | def __len__(self): 373 | return self.num_clean 374 | 375 | 376 | class DerainLowlightDataset(Dataset): 377 | def __init__(self, args): 378 | super(DerainLowlightDataset, self).__init__() 379 | self.ids = [] 380 | self.args = copy.deepcopy(args) 381 | self.task = self.args.de_type 382 | self.toTensor = ToTensor() 383 | self.set_dataset(self.task) 384 | 385 | def _add_gaussian_noise(self, clean_patch): 386 | noise = np.random.randn(*clean_patch.shape) 387 | noisy_patch = np.clip(clean_patch + noise * self.sigma, 0, 255).astype(np.uint8) 388 | return noisy_patch, clean_patch 389 | 390 | def _init_input_ids(self): 391 | if 'derain' in self.task: # derain 392 | self.ids = [] 393 | name_list = os.listdir(os.path.join(self.args.derain_path, 'rainy')) 394 | # print(name_list) 395 | self.ids += [os.path.join(self.args.derain_path, 'rainy', id_) for id_ in name_list] 396 | elif self.task == 'low_light': 397 | self.ids = [] 398 | name_list = os.listdir(os.path.join(self.args.low_light_path, 'input')) 399 | new_ids = [] 400 | for name in name_list: 401 | if 'DS' not in name: 402 | new_ids.append(name) 403 | name_list = new_ids 404 | self.ids += [os.path.join(self.args.low_light_path, 'input', id_) for id_ in name_list] 405 | 406 | self.length = len(self.ids) 407 | 408 | def _get_gt_path(self, degraded_name): 409 | if 'derain' in self.task: 410 | filename = degraded_name.split('-')[-1] 411 | gt_name = os.path.join(self.args.derain_path, 'norain-' + filename) 412 | elif self.task == 'low_light': 413 | gt_name = degraded_name.replace('input', 'target') 414 | return gt_name 415 | 416 | def set_dataset(self, task): 417 | self._init_input_ids() 418 | 419 | def __getitem__(self, idx): 420 | degraded_path = self.ids[idx] 421 | clean_path = self._get_gt_path(degraded_path) 422 | 423 | degraded_img = np.array(Image.open(degraded_path).convert('RGB')) 424 | 425 | clean_img = np.array(Image.open(clean_path).convert('RGB')) 426 | 427 | clean_img, degraded_img = self.toTensor(clean_img), self.toTensor(degraded_img) 428 | degraded_name = degraded_path.split('/')[-1][:-4] 429 | 430 | return [degraded_name], degraded_img, clean_img 431 | 432 | def __len__(self): 433 | return self.length 434 | 435 | 436 | class SRHybridTestDataset(Dataset): 437 | def __init__(self, args): 438 | super(SRHybridTestDataset, self).__init__() 439 | self.args = copy.deepcopy(args) 440 | self.clean_ids = [] 441 | self.toTensor = ToTensor() 442 | 443 | def set_scale(self, s): 444 | self.scale = s 445 | self._init_sr_ids() 446 | 447 | def _init_sr_ids(self): 448 | hr_path = self.args.sr_path 449 | name_list = os.listdir(hr_path) 450 | self.clean_ids += [os.path.join(hr_path, id_) for id_ in name_list] 451 | self.clean_ids = sorted(self.clean_ids) 452 | self.num_clean = len(self.clean_ids) 453 | 454 | def __getitem__(self, clean_id): 455 | hr_img = np.array(Image.open(self.clean_ids[clean_id]).convert('RGB')) 456 | 457 | file_name, ext = os.path.splitext(os.path.basename(self.clean_ids[clean_id])) 458 | if 'Manga109' not in self.clean_ids[clean_id]: 459 | lr_path = os.path.join(os.path.dirname(self.clean_ids[clean_id]).replace('HR', 'LR_bicubic'), 460 | 'X{}/{}x{}{}'.format(self.scale, 
file_name, self.scale, ext)) 461 | else: 462 | lr_path = os.path.join(os.path.dirname(self.clean_ids[clean_id]).replace('HR', 'LR_bicubic'), 463 | 'X{}/{}_LRBI_x{}{}'.format(self.scale, file_name, self.scale, ext)) 464 | lr_img = np.array(Image.open(lr_path).convert('RGB')) 465 | 466 | 467 | if self.args.de_type == 'lr4_noise30': 468 | sigma=30 469 | noise = np.random.randn(*lr_img.shape) 470 | lr_img = np.clip(lr_img + noise * sigma, 0, 255).astype(np.uint8) 471 | if self.args.de_type == 'lr4_jpeg30': 472 | lr_img = add_jpg_compression(lr_img) 473 | 474 | 475 | ih, iw = lr_img.shape[:2] 476 | hr_img = hr_img[0:ih * self.scale, 0:iw * self.scale] 477 | hr_img, lr_img = self.toTensor(hr_img), self.toTensor(lr_img) 478 | 479 | return [lr_path.split('/')[-1]], lr_img, hr_img 480 | 481 | def __len__(self): 482 | return self.num_clean 483 | 484 | 485 | 486 | 487 | -------------------------------------------------------------------------------- /utils/degradation_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torchvision.transforms import ToPILImage, Compose, RandomCrop, ToTensor, Grayscale 3 | 4 | from PIL import Image 5 | import random 6 | import numpy as np 7 | 8 | from utils.image_utils import crop_img 9 | 10 | 11 | class Degradation(object): 12 | def __init__(self, args): 13 | super(Degradation, self).__init__() 14 | self.args = args 15 | self.toTensor = ToTensor() 16 | self.crop_transform = Compose([ 17 | ToPILImage(), 18 | RandomCrop(args.patch_size), 19 | ]) 20 | 21 | def _add_gaussian_noise(self, clean_patch, sigma): 22 | # noise = torch.randn(*(clean_patch.shape)) 23 | # clean_patch = self.toTensor(clean_patch) 24 | noise = np.random.randn(*clean_patch.shape) 25 | noisy_patch = np.clip(clean_patch + noise * sigma, 0, 255).astype(np.uint8) 26 | return noisy_patch, clean_patch 27 | 28 | def _degrade_by_type(self, clean_patch, degrade_type): 29 | if degrade_type == 'denoise_15': 30 | # denoise sigma=15 31 | degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=15) 32 | elif degrade_type == 'denoise_25': 33 | # denoise sigma=25 34 | degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=25) 35 | elif degrade_type == 'denoise_30': 36 | # denoise sigma=30 37 | degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=30) 38 | elif degrade_type == 'denoise_50': 39 | # denoise sigma=50 40 | degraded_patch, clean_patch = self._add_gaussian_noise(clean_patch, sigma=50) 41 | 42 | return degraded_patch, clean_patch 43 | 44 | def degrade(self, clean_patch_1, clean_patch_2, degrade_type=None): 45 | if degrade_type == None: 46 | degrade_type = random.randint(0, 3) 47 | else: 48 | degrade_type = degrade_type 49 | 50 | degrad_patch_1, _ = self._degrade_by_type(clean_patch_1, degrade_type) 51 | degrad_patch_2, _ = self._degrade_by_type(clean_patch_2, degrade_type) 52 | return degrad_patch_1, degrad_patch_2 53 | 54 | def single_degrade(self,clean_patch,degrade_type = None): 55 | if degrade_type == None: 56 | degrade_type = random.randint(0, 3) 57 | else: 58 | degrade_type = degrade_type 59 | 60 | degrad_patch_1, _ = self._degrade_by_type(clean_patch, degrade_type) 61 | return degrad_patch_1 62 | -------------------------------------------------------------------------------- /utils/image_io.py: -------------------------------------------------------------------------------- 1 | import glob 2 | 3 | import torch 4 | import torchvision 5 | import matplotlib 6 | import 
matplotlib.pyplot as plt 7 | import numpy as np 8 | from PIL import Image 9 | 10 | # import skvideo.io 11 | 12 | matplotlib.use('agg') 13 | 14 | 15 | def prepare_hazy_image(file_name): 16 | img_pil = crop_image(get_image(file_name, -1)[0], d=32) 17 | return pil_to_np(img_pil) 18 | 19 | 20 | def prepare_gt_img(file_name, SOTS=True): 21 | if SOTS: 22 | img_pil = crop_image(crop_a_image(get_image(file_name, -1)[0], d=10), d=32) 23 | else: 24 | img_pil = crop_image(get_image(file_name, -1)[0], d=32) 25 | 26 | return pil_to_np(img_pil) 27 | 28 | 29 | def crop_a_image(img, d=10): 30 | bbox = [ 31 | int((d)), 32 | int((d)), 33 | int((img.size[0] - d)), 34 | int((img.size[1] - d)), 35 | ] 36 | img_cropped = img.crop(bbox) 37 | return img_cropped 38 | 39 | 40 | def crop_image(img, d=32): 41 | """ 42 | Make dimensions divisible by d 43 | 44 | :param pil img: 45 | :param d: 46 | :return: 47 | """ 48 | 49 | new_size = (img.size[0] - img.size[0] % d, 50 | img.size[1] - img.size[1] % d) 51 | 52 | bbox = [ 53 | int((img.size[0] - new_size[0]) / 2), 54 | int((img.size[1] - new_size[1]) / 2), 55 | int((img.size[0] + new_size[0]) / 2), 56 | int((img.size[1] + new_size[1]) / 2), 57 | ] 58 | 59 | img_cropped = img.crop(bbox) 60 | return img_cropped 61 | 62 | 63 | def crop_np_image(img_np, d=32): 64 | return torch_to_np(crop_torch_image(np_to_torch(img_np), d)) 65 | 66 | 67 | def crop_torch_image(img, d=32): 68 | """ 69 | Make dimensions divisible by d 70 | image is [1, 3, W, H] or [3, W, H] 71 | :param pil img: 72 | :param d: 73 | :return: 74 | """ 75 | new_size = (img.shape[-2] - img.shape[-2] % d, 76 | img.shape[-1] - img.shape[-1] % d) 77 | pad = ((img.shape[-2] - new_size[-2]) // 2, (img.shape[-1] - new_size[-1]) // 2) 78 | 79 | if len(img.shape) == 4: 80 | return img[:, :, pad[-2]: pad[-2] + new_size[-2], pad[-1]: pad[-1] + new_size[-1]] 81 | assert len(img.shape) == 3 82 | return img[:, pad[-2]: pad[-2] + new_size[-2], pad[-1]: pad[-1] + new_size[-1]] 83 | 84 | 85 | def get_params(opt_over, net, net_input, downsampler=None): 86 | """ 87 | Returns parameters that we want to optimize over. 88 | :param opt_over: comma separated list, e.g. "net,input" or "net" 89 | :param net: network 90 | :param net_input: torch.Tensor that stores input `z` 91 | :param downsampler: 92 | :return: 93 | """ 94 | 95 | opt_over_list = opt_over.split(',') 96 | params = [] 97 | 98 | for opt in opt_over_list: 99 | 100 | if opt == 'net': 101 | params += [x for x in net.parameters()] 102 | elif opt == 'down': 103 | assert downsampler is not None 104 | params = [x for x in downsampler.parameters()] 105 | elif opt == 'input': 106 | net_input.requires_grad = True 107 | params += [net_input] 108 | else: 109 | assert False, 'what is it?' 110 | 111 | return params 112 | 113 | 114 | def get_image_grid(images_np, nrow=8): 115 | """ 116 | Creates a grid from a list of images by concatenating them. 
117 | :param images_np: 118 | :param nrow: 119 | :return: 120 | """ 121 | images_torch = [torch.from_numpy(x).type(torch.FloatTensor) for x in images_np] 122 | torch_grid = torchvision.utils.make_grid(images_torch, nrow) 123 | 124 | return torch_grid.numpy() 125 | 126 | 127 | def plot_image_grid(name, images_np, interpolation='lanczos', output_path="output/"): 128 | """ 129 | Draws images in a grid 130 | 131 | Args: 132 | images_np: list of images, each image is np.array of size 3xHxW or 1xHxW 133 | nrow: how many images will be in one row 134 | interpolation: interpolation used in plt.imshow 135 | """ 136 | assert len(images_np) == 2 137 | n_channels = max(x.shape[0] for x in images_np) 138 | assert (n_channels == 3) or (n_channels == 1), "images should have 1 or 3 channels" 139 | 140 | images_np = [x if (x.shape[0] == n_channels) else np.concatenate([x, x, x], axis=0) for x in images_np] 141 | 142 | grid = get_image_grid(images_np, 2) 143 | 144 | if images_np[0].shape[0] == 1: 145 | plt.imshow(grid[0], cmap='gray', interpolation=interpolation) 146 | else: 147 | plt.imshow(grid.transpose(1, 2, 0), interpolation=interpolation) 148 | 149 | plt.savefig(output_path + "{}.png".format(name)) 150 | 151 | 152 | def save_image_np(name, image_np, output_path="output/"): 153 | p = np_to_pil(image_np) 154 | p.save(output_path + "{}.png".format(name)) 155 | 156 | 157 | def save_image_tensor(image_tensor, output_path="output/"): 158 | image_np = torch_to_np(image_tensor) 159 | # print(image_np.shape) 160 | p = np_to_pil(image_np) 161 | p.save(output_path) 162 | 163 | 164 | def video_to_images(file_name, name): 165 | video = prepare_video(file_name) 166 | for i, f in enumerate(video): 167 | save_image(name + "_{0:03d}".format(i), f) 168 | 169 | 170 | def images_to_video(images_dir, name, gray=True): 171 | num = len(glob.glob(images_dir + "/*.jpg")) 172 | c = [] 173 | for i in range(num): 174 | if gray: 175 | img = prepare_gray_image(images_dir + "/" + name + "_{}.jpg".format(i)) 176 | else: 177 | img = prepare_image(images_dir + "/" + name + "_{}.jpg".format(i)) 178 | print(img.shape) 179 | c.append(img) 180 | save_video(name, np.array(c)) 181 | 182 | 183 | def save_heatmap(name, image_np): 184 | cmap = plt.get_cmap('jet') 185 | 186 | rgba_img = cmap(image_np) 187 | rgb_img = np.delete(rgba_img, 3, 2) 188 | save_image(name, rgb_img.transpose(2, 0, 1)) 189 | 190 | 191 | def save_graph(name, graph_list, output_path="output/"): 192 | plt.clf() 193 | plt.plot(graph_list) 194 | plt.savefig(output_path + name + ".png") 195 | 196 | 197 | def create_augmentations(np_image): 198 | """ 199 | convention: original, left, upside-down, right, rot1, rot2, rot3 200 | :param np_image: 201 | :return: 202 | """ 203 | aug = [np_image.copy(), np.rot90(np_image, 1, (1, 2)).copy(), 204 | np.rot90(np_image, 2, (1, 2)).copy(), np.rot90(np_image, 3, (1, 2)).copy()] 205 | flipped = np_image[:, ::-1, :].copy() 206 | aug += [flipped.copy(), np.rot90(flipped, 1, (1, 2)).copy(), np.rot90(flipped, 2, (1, 2)).copy(), 207 | np.rot90(flipped, 3, (1, 2)).copy()] 208 | return aug 209 | 210 | 211 | def create_video_augmentations(np_video): 212 | """ 213 | convention: original, left, upside-down, right, rot1, rot2, rot3 214 | :param np_video: 215 | :return: 216 | """ 217 | aug = [np_video.copy(), np.rot90(np_video, 1, (2, 3)).copy(), 218 | np.rot90(np_video, 2, (2, 3)).copy(), np.rot90(np_video, 3, (2, 3)).copy()] 219 | flipped = np_video[:, :, ::-1, :].copy() 220 | aug += [flipped.copy(), np.rot90(flipped, 1, (2, 3)).copy(), 
np.rot90(flipped, 2, (2, 3)).copy(), 221 | np.rot90(flipped, 3, (2, 3)).copy()] 222 | return aug 223 | 224 | 225 | def save_graphs(name, graph_dict, output_path="output/"): 226 | """ 227 | 228 | :param name: 229 | :param dict graph_dict: a dict from the name of the list to the list itself. 230 | :return: 231 | """ 232 | plt.clf() 233 | fig, ax = plt.subplots() 234 | for k, v in graph_dict.items(): 235 | ax.plot(v, label=k) 236 | # ax.semilogy(v, label=k) 237 | ax.set_xlabel('iterations') 238 | # ax.set_ylabel(name) 239 | ax.set_ylabel('MSE-loss') 240 | # ax.set_ylabel('PSNR') 241 | plt.legend() 242 | plt.savefig(output_path + name + ".png") 243 | 244 | 245 | def load(path): 246 | """Load PIL image.""" 247 | img = Image.open(path) 248 | return img 249 | 250 | 251 | def get_image(path, imsize=-1): 252 | """Load an image and resize to a cpecific size. 253 | 254 | Args: 255 | path: path to image 256 | imsize: tuple or scalar with dimensions; -1 for `no resize` 257 | """ 258 | img = load(path) 259 | if isinstance(imsize, int): 260 | imsize = (imsize, imsize) 261 | 262 | if imsize[0] != -1 and img.size != imsize: 263 | if imsize[0] > img.size[0]: 264 | img = img.resize(imsize, Image.BICUBIC) 265 | else: 266 | img = img.resize(imsize, Image.ANTIALIAS) 267 | 268 | img_np = pil_to_np(img) 269 | # 3*460*620 270 | # print(np.shape(img_np)) 271 | 272 | return img, img_np 273 | 274 | 275 | def prepare_gt(file_name): 276 | """ 277 | loads makes it divisible 278 | :param file_name: 279 | :return: the numpy representation of the image 280 | """ 281 | img = get_image(file_name, -1) 282 | # print(img[0].size) 283 | 284 | img_pil = img[0].crop([10, 10, img[0].size[0] - 10, img[0].size[1] - 10]) 285 | 286 | img_pil = crop_image(img_pil, d=32) 287 | 288 | # img_pil = get_image(file_name, -1)[0] 289 | # print(img_pil.size) 290 | return pil_to_np(img_pil) 291 | 292 | 293 | def prepare_image(file_name): 294 | """ 295 | loads makes it divisible 296 | :param file_name: 297 | :return: the numpy representation of the image 298 | """ 299 | img = get_image(file_name, -1) 300 | # print(img[0].size) 301 | # img_pil = img[0] 302 | img_pil = crop_image(img[0], d=16) 303 | # img_pil = get_image(file_name, -1)[0] 304 | # print(img_pil.size) 305 | return pil_to_np(img_pil) 306 | 307 | 308 | # def prepare_video(file_name, folder="output/"): 309 | # data = skvideo.io.vread(folder + file_name) 310 | # return crop_torch_image(data.transpose(0, 3, 1, 2).astype(np.float32) / 255.)[:35] 311 | # 312 | # 313 | # def save_video(name, video_np, output_path="output/"): 314 | # outputdata = video_np * 255 315 | # outputdata = outputdata.astype(np.uint8) 316 | # skvideo.io.vwrite(output_path + "{}.mp4".format(name), outputdata.transpose(0, 2, 3, 1)) 317 | 318 | 319 | def prepare_gray_image(file_name): 320 | img = prepare_image(file_name) 321 | return np.array([np.mean(img, axis=0)]) 322 | 323 | 324 | def pil_to_np(img_PIL, with_transpose=True): 325 | """ 326 | Converts image in PIL format to np.array. 327 | 328 | From W x H x C [0...255] to C x W x H [0..1] 329 | """ 330 | ar = np.array(img_PIL) 331 | if len(ar.shape) == 3 and ar.shape[-1] == 4: 332 | ar = ar[:, :, :3] 333 | # this is alpha channel 334 | if with_transpose: 335 | if len(ar.shape) == 3: 336 | ar = ar.transpose(2, 0, 1) 337 | else: 338 | ar = ar[None, ...] 339 | 340 | return ar.astype(np.float32) / 255. 
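# Illustrative round trip for pil_to_np (above) and np_to_pil (defined further down in this
# file) -- an editorial sketch; 'example.png' is a placeholder path, not a repository file:
#   img_np  = pil_to_np(Image.open('example.png').convert('RGB'))  # channel-first float32 in [0, 1]
#   img_pil = np_to_pil(img_np)                                    # back to a uint8 PIL image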
341 | 342 | 343 | def median(img_np_list): 344 | """ 345 | assumes C x W x H [0..1] 346 | :param img_np_list: 347 | :return: 348 | """ 349 | assert len(img_np_list) > 0 350 | l = len(img_np_list) 351 | shape = img_np_list[0].shape 352 | result = np.zeros(shape) 353 | for c in range(shape[0]): 354 | for w in range(shape[1]): 355 | for h in range(shape[2]): 356 | result[c, w, h] = sorted(i[c, w, h] for i in img_np_list)[l // 2] 357 | return result 358 | 359 | 360 | def average(img_np_list): 361 | """ 362 | assumes C x W x H [0..1] 363 | :param img_np_list: 364 | :return: 365 | """ 366 | assert len(img_np_list) > 0 367 | l = len(img_np_list) 368 | shape = img_np_list[0].shape 369 | result = np.zeros(shape) 370 | for i in img_np_list: 371 | result += i 372 | return result / l 373 | 374 | 375 | def np_to_pil(img_np): 376 | """ 377 | Converts image in np.array format to PIL image. 378 | 379 | From C x W x H [0..1] to W x H x C [0...255] 380 | :param img_np: 381 | :return: 382 | """ 383 | ar = np.clip(img_np * 255, 0, 255).astype(np.uint8) 384 | 385 | if img_np.shape[0] == 1: 386 | ar = ar[0] 387 | else: 388 | assert img_np.shape[0] == 3, img_np.shape 389 | ar = ar.transpose(1, 2, 0) 390 | 391 | return Image.fromarray(ar) 392 | 393 | 394 | def np_to_torch(img_np): 395 | """ 396 | Converts image in numpy.array to torch.Tensor. 397 | 398 | From C x W x H [0..1] to C x W x H [0..1] 399 | 400 | :param img_np: 401 | :return: 402 | """ 403 | return torch.from_numpy(img_np)[None, :] 404 | 405 | 406 | def torch_to_np(img_var): 407 | """ 408 | Converts an image in torch.Tensor format to np.array. 409 | 410 | From 1 x C x W x H [0..1] to C x W x H [0..1] 411 | :param img_var: 412 | :return: 413 | """ 414 | return img_var.detach().cpu().numpy()[0] 415 | -------------------------------------------------------------------------------- /utils/image_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Created on 2020/9/8 3 | 4 | @author: Boyun Li 5 | """ 6 | import os 7 | import numpy as np 8 | import torch 9 | import random 10 | import torch.nn as nn 11 | from torch.nn import init 12 | from PIL import Image 13 | 14 | class EdgeComputation(nn.Module): 15 | def __init__(self, test=False): 16 | super(EdgeComputation, self).__init__() 17 | self.test = test 18 | def forward(self, x): 19 | if self.test: 20 | x_diffx = torch.abs(x[:, :, :, 1:] - x[:, :, :, :-1]) 21 | x_diffy = torch.abs(x[:, :, 1:, :] - x[:, :, :-1, :]) 22 | 23 | # y = torch.Tensor(x.size()).cuda() 24 | y = torch.Tensor(x.size()) 25 | y.fill_(0) 26 | y[:, :, :, 1:] += x_diffx 27 | y[:, :, :, :-1] += x_diffx 28 | y[:, :, 1:, :] += x_diffy 29 | y[:, :, :-1, :] += x_diffy 30 | y = torch.sum(y, 1, keepdim=True) / 3 31 | y /= 4 32 | return y 33 | else: 34 | x_diffx = torch.abs(x[:, :, 1:] - x[:, :, :-1]) 35 | x_diffy = torch.abs(x[:, 1:, :] - x[:, :-1, :]) 36 | 37 | y = torch.Tensor(x.size()) 38 | y.fill_(0) 39 | y[:, :, 1:] += x_diffx 40 | y[:, :, :-1] += x_diffx 41 | y[:, 1:, :] += x_diffy 42 | y[:, :-1, :] += x_diffy 43 | y = torch.sum(y, 0) / 3 44 | y /= 4 45 | return y.unsqueeze(0) 46 | 47 | 48 | # randomly crop a patch from image 49 | def crop_patch(im, pch_size): 50 | H = im.shape[0] 51 | W = im.shape[1] 52 | ind_H = random.randint(0, H - pch_size) 53 | ind_W = random.randint(0, W - pch_size) 54 | pch = im[ind_H:ind_H + pch_size, ind_W:ind_W + pch_size] 55 | return pch 56 | 57 | 58 | # crop an image to the multiple of base 59 | def crop_img(image, base=64): 60 | h = image.shape[0] 61 | w = 
image.shape[1] 62 | crop_h = h % base 63 | crop_w = w % base 64 | return image[crop_h // 2:h - crop_h + crop_h // 2, crop_w // 2:w - crop_w + crop_w // 2, :] 65 | 66 | 67 | # image (H, W, C) -> patches (B, H, W, C) 68 | def slice_image2patches(image, patch_size=64, overlap=0): 69 | assert image.shape[0] % patch_size == 0 and image.shape[1] % patch_size == 0 70 | H = image.shape[0] 71 | W = image.shape[1] 72 | patches = [] 73 | image_padding = np.pad(image, ((overlap, overlap), (overlap, overlap), (0, 0)), mode='edge') 74 | for h in range(H // patch_size): 75 | for w in range(W // patch_size): 76 | idx_h = [h * patch_size, (h + 1) * patch_size + overlap] 77 | idx_w = [w * patch_size, (w + 1) * patch_size + overlap] 78 | patches.append(np.expand_dims(image_padding[idx_h[0]:idx_h[1], idx_w[0]:idx_w[1], :], axis=0)) 79 | return np.concatenate(patches, axis=0) 80 | 81 | 82 | # patches (B, H, W, C) -> image (H, W, C) 83 | def splice_patches2image(patches, image_size, overlap=0): 84 | assert len(image_size) > 1 85 | assert patches.shape[-3] == patches.shape[-2] 86 | H = image_size[0] 87 | W = image_size[1] 88 | patch_size = patches.shape[-2] - overlap 89 | image = np.zeros(image_size) 90 | idx = 0 91 | for h in range(H // patch_size): 92 | for w in range(W // patch_size): 93 | image[h * patch_size:(h + 1) * patch_size, w * patch_size:(w + 1) * patch_size, :] = patches[idx, 94 | overlap:patch_size + overlap, 95 | overlap:patch_size + overlap, 96 | :] 97 | idx += 1 98 | return image 99 | 100 | 101 | # def data_augmentation(image, mode): 102 | # if mode == 0: 103 | # # original 104 | # out = image.numpy() 105 | # elif mode == 1: 106 | # # flip up and down 107 | # out = np.flipud(image) 108 | # elif mode == 2: 109 | # # rotate counterwise 90 degree 110 | # out = np.rot90(image, axes=(1, 2)) 111 | # elif mode == 3: 112 | # # rotate 90 degree and flip up and down 113 | # out = np.rot90(image, axes=(1, 2)) 114 | # out = np.flipud(out) 115 | # elif mode == 4: 116 | # # rotate 180 degree 117 | # out = np.rot90(image, k=2, axes=(1, 2)) 118 | # elif mode == 5: 119 | # # rotate 180 degree and flip 120 | # out = np.rot90(image, k=2, axes=(1, 2)) 121 | # out = np.flipud(out) 122 | # elif mode == 6: 123 | # # rotate 270 degree 124 | # out = np.rot90(image, k=3, axes=(1, 2)) 125 | # elif mode == 7: 126 | # # rotate 270 degree and flip 127 | # out = np.rot90(image, k=3, axes=(1, 2)) 128 | # out = np.flipud(out) 129 | # else: 130 | # raise Exception('Invalid choice of image transformation') 131 | # return out 132 | 133 | def data_augmentation(image, mode): 134 | if mode == 0: 135 | # original 136 | out = image.numpy() 137 | elif mode == 1: 138 | # flip up and down 139 | out = np.flipud(image) 140 | elif mode == 2: 141 | # rotate counterwise 90 degree 142 | out = np.rot90(image) 143 | elif mode == 3: 144 | # rotate 90 degree and flip up and down 145 | out = np.rot90(image) 146 | out = np.flipud(out) 147 | elif mode == 4: 148 | # rotate 180 degree 149 | out = np.rot90(image, k=2) 150 | elif mode == 5: 151 | # rotate 180 degree and flip 152 | out = np.rot90(image, k=2) 153 | out = np.flipud(out) 154 | elif mode == 6: 155 | # rotate 270 degree 156 | out = np.rot90(image, k=3) 157 | elif mode == 7: 158 | # rotate 270 degree and flip 159 | out = np.rot90(image, k=3) 160 | out = np.flipud(out) 161 | else: 162 | raise Exception('Invalid choice of image transformation') 163 | return out 164 | 165 | 166 | def random_augmentation(*args): 167 | out = [] 168 | flag_aug = random.randint(1, 7) 169 | for data in args: 170 | 
out.append(data_augmentation(data, flag_aug).copy()) 171 | return out 172 | 173 | 174 | def weights_init_normal_(m): 175 | classname = m.__class__.__name__ 176 | if classname.find('Conv') != -1: 177 | init.uniform(m.weight.data, 0.0, 0.02) 178 | elif classname.find('Linear') != -1: 179 | init.uniform(m.weight.data, 0.0, 0.02) 180 | elif classname.find('BatchNorm2d') != -1: 181 | init.uniform(m.weight.data, 1.0, 0.02) 182 | init.constant(m.bias.data, 0.0) 183 | 184 | 185 | def weights_init_normal(m): 186 | classname = m.__class__.__name__ 187 | if classname.find('Conv2d') != -1: 188 | m.apply(weights_init_normal_) 189 | elif classname.find('Linear') != -1: 190 | init.uniform(m.weight.data, 0.0, 0.02) 191 | elif classname.find('BatchNorm2d') != -1: 192 | init.uniform(m.weight.data, 1.0, 0.02) 193 | init.constant(m.bias.data, 0.0) 194 | 195 | 196 | def weights_init_xavier(m): 197 | classname = m.__class__.__name__ 198 | if classname.find('Conv') != -1: 199 | init.xavier_normal(m.weight.data, gain=1) 200 | elif classname.find('Linear') != -1: 201 | init.xavier_normal(m.weight.data, gain=1) 202 | elif classname.find('BatchNorm2d') != -1: 203 | init.uniform(m.weight.data, 1.0, 0.02) 204 | init.constant(m.bias.data, 0.0) 205 | 206 | 207 | def weights_init_kaiming(m): 208 | classname = m.__class__.__name__ 209 | if classname.find('Conv') != -1: 210 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 211 | elif classname.find('Linear') != -1: 212 | init.kaiming_normal(m.weight.data, a=0, mode='fan_in') 213 | elif classname.find('BatchNorm2d') != -1: 214 | init.uniform(m.weight.data, 1.0, 0.02) 215 | init.constant(m.bias.data, 0.0) 216 | 217 | 218 | def weights_init_orthogonal(m): 219 | classname = m.__class__.__name__ 220 | print(classname) 221 | if classname.find('Conv') != -1: 222 | init.orthogonal(m.weight.data, gain=1) 223 | elif classname.find('Linear') != -1: 224 | init.orthogonal(m.weight.data, gain=1) 225 | elif classname.find('BatchNorm2d') != -1: 226 | init.uniform(m.weight.data, 1.0, 0.02) 227 | init.constant(m.bias.data, 0.0) 228 | 229 | 230 | def init_weights(net, init_type='normal'): 231 | print('initialization method [%s]' % init_type) 232 | if init_type == 'normal': 233 | net.apply(weights_init_normal) 234 | elif init_type == 'xavier': 235 | net.apply(weights_init_xavier) 236 | elif init_type == 'kaiming': 237 | net.apply(weights_init_kaiming) 238 | elif init_type == 'orthogonal': 239 | net.apply(weights_init_orthogonal) 240 | else: 241 | raise NotImplementedError('initialization method [%s] is not implemented' % init_type) 242 | 243 | 244 | def np_to_torch(img_np): 245 | """ 246 | Converts image in numpy.array to torch.Tensor. 247 | 248 | From C x W x H [0..1] to C x W x H [0..1] 249 | 250 | :param img_np: 251 | :return: 252 | """ 253 | return torch.from_numpy(img_np)[None, :] 254 | 255 | 256 | def torch_to_np(img_var): 257 | """ 258 | Converts an image in torch.Tensor format to np.array. 259 | 260 | From 1 x C x W x H [0..1] to C x W x H [0..1] 261 | :param img_var: 262 | :return: 263 | """ 264 | return img_var.detach().cpu().numpy() 265 | # return img_var.detach().cpu().numpy()[0] 266 | 267 | 268 | def save_image(name, image_np, output_path="output/normal/"): 269 | if not os.path.exists(output_path): 270 | os.mkdir(output_path) 271 | 272 | p = np_to_pil(image_np) 273 | p.save(output_path + "{}.png".format(name)) 274 | 275 | 276 | def np_to_pil(img_np): 277 | """ 278 | Converts image in np.array format to PIL image. 
279 | 280 | From C x W x H [0..1] to W x H x C [0...255] 281 | :param img_np: 282 | :return: 283 | """ 284 | ar = np.clip(img_np * 255, 0, 255).astype(np.uint8) 285 | 286 | if img_np.shape[0] == 1: 287 | ar = ar[0] 288 | else: 289 | assert img_np.shape[0] == 3, img_np.shape 290 | ar = ar.transpose(1, 2, 0) 291 | 292 | return Image.fromarray(ar) -------------------------------------------------------------------------------- /utils/imresize.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.ndimage import filters, measurements, interpolation 3 | from math import pi 4 | 5 | 6 | def imresize(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False): 7 | # First standardize values and fill missing arguments (if needed) by deriving scale from output shape or vice versa 8 | scale_factor, output_shape = fix_scale_and_size(im.shape, output_shape, scale_factor) 9 | 10 | # For a given numeric kernel case, just do convolution and sub-sampling (downscaling only) 11 | if type(kernel) == np.ndarray and scale_factor[0] <= 1: 12 | return numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag) 13 | 14 | # Choose interpolation method, each method has the matching kernel size 15 | method, kernel_width = { 16 | "cubic": (cubic, 4.0), 17 | "lanczos2": (lanczos2, 4.0), 18 | "lanczos3": (lanczos3, 6.0), 19 | "box": (box, 1.0), 20 | "linear": (linear, 2.0), 21 | None: (cubic, 4.0) # set default interpolation method as cubic 22 | }.get(kernel) 23 | 24 | # Antialiasing is only used when downscaling 25 | antialiasing *= (scale_factor[0] < 1) 26 | 27 | # Sort indices of dimensions according to scale of each dimension. since we are going dim by dim this is efficient 28 | sorted_dims = np.argsort(np.array(scale_factor)).tolist() 29 | 30 | # Iterate over dimensions to calculate local weights for resizing and resize each time in one direction 31 | out_im = np.copy(im) 32 | for dim in sorted_dims: 33 | # No point doing calculations for scale-factor 1. nothing will happen anyway 34 | if scale_factor[dim] == 1.0: 35 | continue 36 | 37 | # for each coordinate (along 1 dim), calculate which coordinates in the input image affect its result and the 38 | # weights that multiply the values there to get its result. 39 | weights, field_of_view = contributions(im.shape[dim], output_shape[dim], scale_factor[dim], 40 | method, kernel_width, antialiasing) 41 | 42 | # Use the affecting position values and the set of weights to calculate the result of resizing along this 1 dim 43 | out_im = resize_along_dim(out_im, dim, weights, field_of_view) 44 | 45 | return out_im 46 | 47 | 48 | def fix_scale_and_size(input_shape, output_shape, scale_factor): 49 | # First fixing the scale-factor (if given) to be standardized the function expects (a list of scale factors in the 50 | # same size as the number of input dimensions) 51 | if scale_factor is not None: 52 | # By default, if scale-factor is a scalar we assume 2d resizing and duplicate it. 
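        # e.g. scale_factor=2 for an (H, W, C) input becomes [2, 2] here and then
        # [2, 2, 1] after the extend() below, so only the spatial dimensions are resized.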
53 | if np.isscalar(scale_factor): 54 | scale_factor = [scale_factor, scale_factor] 55 | 56 | # We extend the size of scale-factor list to the size of the input by assigning 1 to all the unspecified scales 57 | scale_factor = list(scale_factor) 58 | scale_factor.extend([1] * (len(input_shape) - len(scale_factor))) 59 | 60 | # Fixing output-shape (if given): extending it to the size of the input-shape, by assigning the original input-size 61 | # to all the unspecified dimensions 62 | if output_shape is not None: 63 | output_shape = list(np.uint(np.array(output_shape))) + list(input_shape[len(output_shape):]) 64 | 65 | # Dealing with the case of non-give scale-factor, calculating according to output-shape. note that this is 66 | # sub-optimal, because there can be different scales to the same output-shape. 67 | if scale_factor is None: 68 | scale_factor = 1.0 * np.array(output_shape) / np.array(input_shape) 69 | 70 | # Dealing with missing output-shape. calculating according to scale-factor 71 | if output_shape is None: 72 | output_shape = np.uint(np.ceil(np.array(input_shape) * np.array(scale_factor))) 73 | 74 | return scale_factor, output_shape 75 | 76 | 77 | def contributions(in_length, out_length, scale, kernel, kernel_width, antialiasing): 78 | # This function calculates a set of 'filters' and a set of field_of_view that will later on be applied 79 | # such that each position from the field_of_view will be multiplied with a matching filter from the 80 | # 'weights' based on the interpolation method and the distance of the sub-pixel location from the pixel centers 81 | # around it. This is only done for one dimension of the image. 82 | 83 | # When anti-aliasing is activated (default and only for downscaling) the receptive field is stretched to size of 84 | # 1/sf. this means filtering is more 'low-pass filter'. 85 | fixed_kernel = (lambda arg: scale * kernel(scale * arg)) if antialiasing else kernel 86 | kernel_width *= 1.0 / scale if antialiasing else 1.0 87 | 88 | # These are the coordinates of the output image 89 | out_coordinates = np.arange(1, out_length+1) 90 | 91 | # These are the matching positions of the output-coordinates on the input image coordinates. 92 | # Best explained by example: say we have 4 horizontal pixels for HR and we downscale by SF=2 and get 2 pixels: 93 | # [1,2,3,4] -> [1,2]. Remember each pixel number is the middle of the pixel. 94 | # The scaling is done between the distances and not pixel numbers (the right boundary of pixel 4 is transformed to 95 | # the right boundary of pixel 2. pixel 1 in the small image matches the boundary between pixels 1 and 2 in the big 96 | # one and not to pixel 2. This means the position is not just multiplication of the old pos by scale-factor). 97 | # So if we measure distance from the left border, middle of pixel 1 is at distance d=0.5, border between 1 and 2 is 98 | # at d=1, and so on (d = p - 0.5). we calculate (d_new = d_old / sf) which means: 99 | # (p_new-0.5 = (p_old-0.5) / sf) -> p_new = p_old/sf + 0.5 * (1-1/sf) 100 | match_coordinates = 1.0 * out_coordinates / scale + 0.5 * (1 - 1.0 / scale) 101 | 102 | # This is the left boundary to start multiplying the filter from, it depends on the size of the filter 103 | left_boundary = np.floor(match_coordinates - kernel_width / 2) 104 | 105 | # Kernel width needs to be enlarged because when covering has sub-pixel borders, it must 'see' the pixel centers 106 | # of the pixels it only covered a part from. 
So we add one pixel at each side to consider (weights can zeroize them) 107 | expanded_kernel_width = np.ceil(kernel_width) + 2 108 | 109 | # Determine a set of field_of_view for each each output position, these are the pixels in the input image 110 | # that the pixel in the output image 'sees'. We get a matrix whos horizontal dim is the output pixels (big) and the 111 | # vertical dim is the pixels it 'sees' (kernel_size + 2) 112 | field_of_view = np.squeeze(np.uint(np.expand_dims(left_boundary, axis=1) + np.arange(expanded_kernel_width) - 1)) 113 | 114 | # Assign weight to each pixel in the field of view. A matrix whos horizontal dim is the output pixels and the 115 | # vertical dim is a list of weights matching to the pixel in the field of view (that are specified in 116 | # 'field_of_view') 117 | weights = fixed_kernel(1.0 * np.expand_dims(match_coordinates, axis=1) - field_of_view - 1) 118 | 119 | # Normalize weights to sum up to 1. be careful from dividing by 0 120 | sum_weights = np.sum(weights, axis=1) 121 | sum_weights[sum_weights == 0] = 1.0 122 | weights = 1.0 * weights / np.expand_dims(sum_weights, axis=1) 123 | 124 | # We use this mirror structure as a trick for reflection padding at the boundaries 125 | mirror = np.uint(np.concatenate((np.arange(in_length), np.arange(in_length - 1, -1, step=-1)))) 126 | field_of_view = mirror[np.mod(field_of_view, mirror.shape[0])] 127 | 128 | # Get rid of weights and pixel positions that are of zero weight 129 | non_zero_out_pixels = np.nonzero(np.any(weights, axis=0)) 130 | weights = np.squeeze(weights[:, non_zero_out_pixels]) 131 | field_of_view = np.squeeze(field_of_view[:, non_zero_out_pixels]) 132 | 133 | # Final products are the relative positions and the matching weights, both are output_size X fixed_kernel_size 134 | return weights, field_of_view 135 | 136 | 137 | def resize_along_dim(im, dim, weights, field_of_view): 138 | # To be able to act on each dim, we swap so that dim 0 is the wanted dim to resize 139 | tmp_im = np.swapaxes(im, dim, 0) 140 | 141 | # We add singleton dimensions to the weight matrix so we can multiply it with the big tensor we get for 142 | # tmp_im[field_of_view.T], (bsxfun style) 143 | weights = np.reshape(weights.T, list(weights.T.shape) + (np.ndim(im) - 1) * [1]) 144 | 145 | # This is a bit of a complicated multiplication: tmp_im[field_of_view.T] is a tensor of order image_dims+1. 146 | # for each pixel in the output-image it matches the positions the influence it from the input image (along 1 dim 147 | # only, this is why it only adds 1 dim to the shape). We then multiply, for each pixel, its set of positions with 148 | # the matching set of weights. 
we do this by this big tensor element-wise multiplication (MATLAB bsxfun style: 149 | # matching dims are multiplied element-wise while singletons mean that the matching dim is all multiplied by the 150 | # same number 151 | tmp_out_im = np.sum(tmp_im[field_of_view.T] * weights, axis=0) 152 | 153 | # Finally we swap back the axes to the original order 154 | return np.swapaxes(tmp_out_im, dim, 0) 155 | 156 | 157 | def numeric_kernel(im, kernel, scale_factor, output_shape, kernel_shift_flag): 158 | # See kernel_shift function to understand what this is 159 | if kernel_shift_flag: 160 | kernel = kernel_shift(kernel, scale_factor) 161 | 162 | # First run a correlation (convolution with flipped kernel) 163 | out_im = np.zeros_like(im) 164 | for channel in range(np.ndim(im)): 165 | out_im[:, :, channel] = filters.correlate(im[:, :, channel], kernel) 166 | 167 | # Then subsample and return 168 | return out_im[np.round(np.linspace(0, im.shape[0] - 1 / scale_factor[0], output_shape[0])).astype(int)[:, None], 169 | np.round(np.linspace(0, im.shape[1] - 1 / scale_factor[1], output_shape[1])).astype(int), :] 170 | 171 | 172 | def kernel_shift(kernel, sf): 173 | # There are two reasons for shifting the kernel: 174 | # 1. Center of mass is not in the center of the kernel which creates ambiguity. There is no possible way to know 175 | # the degradation process included shifting so we always assume center of mass is center of the kernel. 176 | # 2. We further shift kernel center so that top left result pixel corresponds to the middle of the sfXsf first 177 | # pixels. Default is for odd size to be in the middle of the first pixel and for even sized kernel to be at the 178 | # top left corner of the first pixel. that is why different shift size needed between od and even size. 179 | # Given that these two conditions are fulfilled, we are happy and aligned, the way to test it is as follows: 180 | # The input image, when interpolated (regular bicubic) is exactly aligned with ground truth. 181 | 182 | # First calculate the current center of mass for the kernel 183 | current_center_of_mass = measurements.center_of_mass(kernel) 184 | 185 | # The second ("+ 0.5 * ....") is for applying condition 2 from the comments above 186 | wanted_center_of_mass = np.array(kernel.shape) / 2 + 0.5 * (sf - (kernel.shape[0] % 2)) 187 | 188 | # Define the shift vector for the kernel shifting (x,y) 189 | shift_vec = wanted_center_of_mass - current_center_of_mass 190 | 191 | # Before applying the shift, we first pad the kernel so that nothing is lost due to the shift 192 | # (biggest shift among dims + 1 for safety) 193 | kernel = np.pad(kernel, np.int(np.ceil(np.max(shift_vec))) + 1, 'constant') 194 | 195 | # Finally shift the kernel and return 196 | return interpolation.shift(kernel, shift_vec) 197 | 198 | 199 | # These next functions are all interpolation methods. 
x is the distance from the left pixel center 200 | 201 | 202 | def cubic(x): 203 | absx = np.abs(x) 204 | absx2 = absx ** 2 205 | absx3 = absx ** 3 206 | return ((1.5*absx3 - 2.5*absx2 + 1) * (absx <= 1) + 207 | (-0.5*absx3 + 2.5*absx2 - 4*absx + 2) * ((1 < absx) & (absx <= 2))) 208 | 209 | 210 | def lanczos2(x): 211 | return (((np.sin(pi*x) * np.sin(pi*x/2) + np.finfo(np.float32).eps) / 212 | ((pi**2 * x**2 / 2) + np.finfo(np.float32).eps)) 213 | * (abs(x) < 2)) 214 | 215 | 216 | def box(x): 217 | return ((-0.5 <= x) & (x < 0.5)) * 1.0 218 | 219 | 220 | def lanczos3(x): 221 | return (((np.sin(pi*x) * np.sin(pi*x/3) + np.finfo(np.float32).eps) / 222 | ((pi**2 * x**2 / 3) + np.finfo(np.float32).eps)) 223 | * (abs(x) < 3)) 224 | 225 | 226 | def linear(x): 227 | return (x + 1) * ((-1 <= x) & (x < 0)) + (1 - x) * ((0 <= x) & (x <= 1)) 228 | 229 | 230 | def np_imresize(im, scale_factor=None, output_shape=None, kernel=None, antialiasing=True, kernel_shift_flag=False): 231 | return np.clip(imresize(im.transpose(1, 2, 0), scale_factor, output_shape, kernel, antialiasing, 232 | kernel_shift_flag).transpose(2, 0, 1), 0, 1) -------------------------------------------------------------------------------- /utils/loss_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.nn.functional import mse_loss 4 | 5 | 6 | class GANLoss(nn.Module): 7 | def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0, 8 | tensor=torch.FloatTensor): 9 | super(GANLoss, self).__init__() 10 | self.real_label = target_real_label 11 | self.fake_label = target_fake_label 12 | self.real_label_var = None 13 | self.fake_label_var = None 14 | self.Tensor = tensor 15 | if use_lsgan: 16 | self.loss = nn.MSELoss() 17 | else: 18 | self.loss = nn.BCELoss() 19 | 20 | def get_target_tensor(self, input, target_is_real): 21 | target_tensor = None 22 | if target_is_real: 23 | create_label = ((self.real_label_var is None) or(self.real_label_var.numel() != input.numel())) 24 | # pdb.set_trace() 25 | if create_label: 26 | real_tensor = self.Tensor(input.size()).fill_(self.real_label) 27 | # self.real_label_var = Variable(real_tensor, requires_grad=False) 28 | # self.real_label_var = torch.Tensor(real_tensor) 29 | self.real_label_var = real_tensor 30 | target_tensor = self.real_label_var 31 | else: 32 | # pdb.set_trace() 33 | create_label = ((self.fake_label_var is None) or (self.fake_label_var.numel() != input.numel())) 34 | if create_label: 35 | fake_tensor = self.Tensor(input.size()).fill_(self.fake_label) 36 | # self.fake_label_var = Variable(fake_tensor, requires_grad=False) 37 | # self.fake_label_var = torch.Tensor(fake_tensor) 38 | self.fake_label_var = fake_tensor 39 | target_tensor = self.fake_label_var 40 | return target_tensor 41 | 42 | def __call__(self, input, target_is_real): 43 | target_tensor = self.get_target_tensor(input, target_is_real) 44 | # pdb.set_trace() 45 | return self.loss(input, target_tensor) 46 | 47 | -------------------------------------------------------------------------------- /utils/schedulers.py: -------------------------------------------------------------------------------- 1 | import math 2 | from collections import Counter 3 | from torch.optim.lr_scheduler import _LRScheduler 4 | import torch 5 | import warnings 6 | from typing import List 7 | 8 | from torch import nn 9 | from torch.optim import Adam, Optimizer 10 | 11 | class MultiStepRestartLR(_LRScheduler): 12 | """ MultiStep with restarts 
learning rate scheme. 13 | 14 | Args: 15 | optimizer (torch.nn.optimizer): Torch optimizer. 16 | milestones (list): Iterations that will decrease learning rate. 17 | gamma (float): Decrease ratio. Default: 0.1. 18 | restarts (list): Restart iterations. Default: [0]. 19 | restart_weights (list): Restart weights at each restart iteration. 20 | Default: [1]. 21 | last_epoch (int): Used in _LRScheduler. Default: -1. 22 | """ 23 | 24 | def __init__(self, 25 | optimizer, 26 | milestones, 27 | gamma=0.1, 28 | restarts=(0, ), 29 | restart_weights=(1, ), 30 | last_epoch=-1): 31 | self.milestones = Counter(milestones) 32 | self.gamma = gamma 33 | self.restarts = restarts 34 | self.restart_weights = restart_weights 35 | assert len(self.restarts) == len( 36 | self.restart_weights), 'restarts and their weights do not match.' 37 | super(MultiStepRestartLR, self).__init__(optimizer, last_epoch) 38 | 39 | def get_lr(self): 40 | if self.last_epoch in self.restarts: 41 | weight = self.restart_weights[self.restarts.index(self.last_epoch)] 42 | return [ 43 | group['initial_lr'] * weight 44 | for group in self.optimizer.param_groups 45 | ] 46 | if self.last_epoch not in self.milestones: 47 | return [group['lr'] for group in self.optimizer.param_groups] 48 | return [ 49 | group['lr'] * self.gamma**self.milestones[self.last_epoch] 50 | for group in self.optimizer.param_groups 51 | ] 52 | 53 | class LinearLR(_LRScheduler): 54 | """ 55 | 56 | Args: 57 | optimizer (torch.nn.optimizer): Torch optimizer. 58 | milestones (list): Iterations that will decrease learning rate. 59 | gamma (float): Decrease ratio. Default: 0.1. 60 | last_epoch (int): Used in _LRScheduler. Default: -1. 61 | """ 62 | 63 | def __init__(self, 64 | optimizer, 65 | total_iter, 66 | last_epoch=-1): 67 | self.total_iter = total_iter 68 | super(LinearLR, self).__init__(optimizer, last_epoch) 69 | 70 | def get_lr(self): 71 | process = self.last_epoch / self.total_iter 72 | weight = (1 - process) 73 | # print('get lr ', [weight * group['initial_lr'] for group in self.optimizer.param_groups]) 74 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups] 75 | 76 | class VibrateLR(_LRScheduler): 77 | """ 78 | 79 | Args: 80 | optimizer (torch.nn.optimizer): Torch optimizer. 81 | milestones (list): Iterations that will decrease learning rate. 82 | gamma (float): Decrease ratio. Default: 0.1. 83 | last_epoch (int): Used in _LRScheduler. Default: -1. 84 | """ 85 | 86 | def __init__(self, 87 | optimizer, 88 | total_iter, 89 | last_epoch=-1): 90 | self.total_iter = total_iter 91 | super(VibrateLR, self).__init__(optimizer, last_epoch) 92 | 93 | def get_lr(self): 94 | process = self.last_epoch / self.total_iter 95 | 96 | f = 0.1 97 | if process < 3 / 8: 98 | f = 1 - process * 8 / 3 99 | elif process < 5 / 8: 100 | f = 0.2 101 | 102 | T = self.total_iter // 80 103 | Th = T // 2 104 | 105 | t = self.last_epoch % T 106 | 107 | f2 = t / Th 108 | if t >= Th: 109 | f2 = 2 - f2 110 | 111 | weight = f * f2 112 | 113 | if self.last_epoch < Th: 114 | weight = max(0.1, weight) 115 | 116 | # print('f {}, T {}, Th {}, t {}, f2 {}'.format(f, T, Th, t, f2)) 117 | return [weight * group['initial_lr'] for group in self.optimizer.param_groups] 118 | 119 | def get_position_from_periods(iteration, cumulative_period): 120 | """Get the position from a period list. 121 | 122 | It will return the index of the right-closest number in the period list. 
123 | For example, the cumulative_period = [100, 200, 300, 400], 124 | if iteration == 50, return 0; 125 | if iteration == 210, return 2; 126 | if iteration == 300, return 2. 127 | 128 | Args: 129 | iteration (int): Current iteration. 130 | cumulative_period (list[int]): Cumulative period list. 131 | 132 | Returns: 133 | int: The position of the right-closest number in the period list. 134 | """ 135 | for i, period in enumerate(cumulative_period): 136 | if iteration <= period: 137 | return i 138 | 139 | 140 | class CosineAnnealingRestartLR(_LRScheduler): 141 | """ Cosine annealing with restarts learning rate scheme. 142 | 143 | An example of config: 144 | periods = [10, 10, 10, 10] 145 | restart_weights = [1, 0.5, 0.5, 0.5] 146 | eta_min=1e-7 147 | 148 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 149 | scheduler will restart with the weights in restart_weights. 150 | 151 | Args: 152 | optimizer (torch.nn.optimizer): Torch optimizer. 153 | periods (list): Period for each cosine anneling cycle. 154 | restart_weights (list): Restart weights at each restart iteration. 155 | Default: [1]. 156 | eta_min (float): The mimimum lr. Default: 0. 157 | last_epoch (int): Used in _LRScheduler. Default: -1. 158 | """ 159 | 160 | def __init__(self, 161 | optimizer, 162 | periods, 163 | restart_weights=(1, ), 164 | eta_min=0, 165 | last_epoch=-1): 166 | self.periods = periods 167 | self.restart_weights = restart_weights 168 | self.eta_min = eta_min 169 | assert (len(self.periods) == len(self.restart_weights) 170 | ), 'periods and restart_weights should have the same length.' 171 | self.cumulative_period = [ 172 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) 173 | ] 174 | super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch) 175 | 176 | def get_lr(self): 177 | idx = get_position_from_periods(self.last_epoch, 178 | self.cumulative_period) 179 | current_weight = self.restart_weights[idx] 180 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 181 | current_period = self.periods[idx] 182 | 183 | return [ 184 | self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) * 185 | (1 + math.cos(math.pi * ( 186 | (self.last_epoch - nearest_restart) / current_period))) 187 | for base_lr in self.base_lrs 188 | ] 189 | 190 | class CosineAnnealingRestartCyclicLR(_LRScheduler): 191 | """ Cosine annealing with restarts learning rate scheme. 192 | An example of config: 193 | periods = [10, 10, 10, 10] 194 | restart_weights = [1, 0.5, 0.5, 0.5] 195 | eta_min=1e-7 196 | It has four cycles, each has 10 iterations. At 10th, 20th, 30th, the 197 | scheduler will restart with the weights in restart_weights. 198 | Args: 199 | optimizer (torch.nn.optimizer): Torch optimizer. 200 | periods (list): Period for each cosine anneling cycle. 201 | restart_weights (list): Restart weights at each restart iteration. 202 | Default: [1]. 203 | eta_min (float): The mimimum lr. Default: 0. 204 | last_epoch (int): Used in _LRScheduler. Default: -1. 205 | """ 206 | 207 | def __init__(self, 208 | optimizer, 209 | periods, 210 | restart_weights=(1, ), 211 | eta_mins=(0, ), 212 | last_epoch=-1): 213 | self.periods = periods 214 | self.restart_weights = restart_weights 215 | self.eta_mins = eta_mins 216 | assert (len(self.periods) == len(self.restart_weights) 217 | ), 'periods and restart_weights should have the same length.' 
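        # cumulative_period[i] = sum(periods[:i + 1]); get_lr() uses it (via
        # get_position_from_periods) to locate the cosine cycle of the current epoch.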
218 | self.cumulative_period = [ 219 | sum(self.periods[0:i + 1]) for i in range(0, len(self.periods)) 220 | ] 221 | super(CosineAnnealingRestartCyclicLR, self).__init__(optimizer, last_epoch) 222 | 223 | def get_lr(self): 224 | idx = get_position_from_periods(self.last_epoch, 225 | self.cumulative_period) 226 | current_weight = self.restart_weights[idx] 227 | nearest_restart = 0 if idx == 0 else self.cumulative_period[idx - 1] 228 | current_period = self.periods[idx] 229 | eta_min = self.eta_mins[idx] 230 | 231 | return [ 232 | eta_min + current_weight * 0.5 * (base_lr - eta_min) * 233 | (1 + math.cos(math.pi * ( 234 | (self.last_epoch - nearest_restart) / current_period))) 235 | for base_lr in self.base_lrs 236 | ] 237 | 238 | 239 | class LinearWarmupCosineAnnealingLR(_LRScheduler): 240 | """Sets the learning rate of each parameter group to follow a linear warmup schedule between warmup_start_lr 241 | and base_lr followed by a cosine annealing schedule between base_lr and eta_min. 242 | .. warning:: 243 | It is recommended to call :func:`.step()` for :class:`LinearWarmupCosineAnnealingLR` 244 | after each iteration as calling it after each epoch will keep the starting lr at 245 | warmup_start_lr for the first epoch which is 0 in most cases. 246 | .. warning:: 247 | passing epoch to :func:`.step()` is being deprecated and comes with an EPOCH_DEPRECATION_WARNING. 248 | It calls the :func:`_get_closed_form_lr()` method for this scheduler instead of 249 | :func:`get_lr()`. Though this does not change the behavior of the scheduler, when passing 250 | epoch param to :func:`.step()`, the user should call the :func:`.step()` function before calling 251 | train and validation methods. 252 | Example: 253 | >>> layer = nn.Linear(10, 1) 254 | >>> optimizer = Adam(layer.parameters(), lr=0.02) 255 | >>> scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=40) 256 | >>> # 257 | >>> # the default case 258 | >>> for epoch in range(40): 259 | ... # train(...) 260 | ... # validate(...) 261 | ... scheduler.step() 262 | >>> # 263 | >>> # passing epoch param case 264 | >>> for epoch in range(40): 265 | ... scheduler.step(epoch) 266 | ... # train(...) 267 | ... # validate(...) 268 | """ 269 | 270 | def __init__( 271 | self, 272 | optimizer: Optimizer, 273 | warmup_epochs: int, 274 | max_epochs: int, 275 | warmup_start_lr: float = 0.0, 276 | eta_min: float = 0.0, 277 | last_epoch: int = -1, 278 | ) -> None: 279 | """ 280 | Args: 281 | optimizer (Optimizer): Wrapped optimizer. 282 | warmup_epochs (int): Maximum number of iterations for linear warmup 283 | max_epochs (int): Maximum number of iterations 284 | warmup_start_lr (float): Learning rate to start the linear warmup. Default: 0. 285 | eta_min (float): Minimum learning rate. Default: 0. 286 | last_epoch (int): The index of last epoch. Default: -1. 
287 | """ 288 | self.warmup_epochs = warmup_epochs 289 | self.max_epochs = max_epochs 290 | self.warmup_start_lr = warmup_start_lr 291 | self.eta_min = eta_min 292 | 293 | super().__init__(optimizer, last_epoch) 294 | 295 | def get_lr(self) -> List[float]: 296 | """Compute learning rate using chainable form of the scheduler.""" 297 | if not self._get_lr_called_within_step: 298 | warnings.warn( 299 | "To get the last learning rate computed by the scheduler, " "please use `get_last_lr()`.", 300 | UserWarning, 301 | ) 302 | 303 | if self.last_epoch == 0: 304 | return [self.warmup_start_lr] * len(self.base_lrs) 305 | if self.last_epoch < self.warmup_epochs: 306 | return [ 307 | group["lr"] + (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 308 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 309 | ] 310 | if self.last_epoch == self.warmup_epochs: 311 | return self.base_lrs 312 | if (self.last_epoch - 1 - self.max_epochs) % (2 * (self.max_epochs - self.warmup_epochs)) == 0: 313 | return [ 314 | group["lr"] 315 | + (base_lr - self.eta_min) * (1 - math.cos(math.pi / (self.max_epochs - self.warmup_epochs))) / 2 316 | for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups) 317 | ] 318 | 319 | return [ 320 | (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 321 | / ( 322 | 1 323 | + math.cos( 324 | math.pi * (self.last_epoch - self.warmup_epochs - 1) / (self.max_epochs - self.warmup_epochs) 325 | ) 326 | ) 327 | * (group["lr"] - self.eta_min) 328 | + self.eta_min 329 | for group in self.optimizer.param_groups 330 | ] 331 | 332 | def _get_closed_form_lr(self) -> List[float]: 333 | """Called when epoch is passed as a param to the `step` function of the scheduler.""" 334 | if self.last_epoch < self.warmup_epochs: 335 | return [ 336 | self.warmup_start_lr + self.last_epoch * (base_lr - self.warmup_start_lr) / (self.warmup_epochs - 1) 337 | for base_lr in self.base_lrs 338 | ] 339 | 340 | return [ 341 | self.eta_min 342 | + 0.5 343 | * (base_lr - self.eta_min) 344 | * (1 + math.cos(math.pi * (self.last_epoch - self.warmup_epochs) / (self.max_epochs - self.warmup_epochs))) 345 | for base_lr in self.base_lrs 346 | ] 347 | 348 | 349 | # warmup + decay as a function 350 | def linear_warmup_decay(warmup_steps, total_steps, cosine=True, linear=False): 351 | """Linear warmup for warmup_steps, optionally with cosine annealing or linear decay to 0 at total_steps.""" 352 | assert not (linear and cosine) 353 | 354 | def fn(step): 355 | if step < warmup_steps: 356 | return float(step) / float(max(1, warmup_steps)) 357 | 358 | if not (cosine or linear): 359 | # no decay 360 | return 1.0 361 | 362 | progress = float(step - warmup_steps) / float(max(1, total_steps - warmup_steps)) 363 | if cosine: 364 | # cosine decay 365 | return 0.5 * (1.0 + math.cos(math.pi * progress)) 366 | 367 | # linear decay 368 | return 1.0 - progress 369 | 370 | return fn 371 | -------------------------------------------------------------------------------- /utils/val_utils.py: -------------------------------------------------------------------------------- 1 | import time 2 | import numpy as np 3 | from skimage.metrics import peak_signal_noise_ratio, structural_similarity 4 | import cv2 5 | import math 6 | 7 | class AverageMeter(): 8 | """ Computes and stores the average and current value """ 9 | 10 | def __init__(self): 11 | self.reset() 12 | 13 | def reset(self): 14 | """ Reset all statistics """ 15 | self.val = 0 16 | self.avg = 0 17 | 
self.sum = 0 18 | self.count = 0 19 | 20 | def update(self, val, n=1): 21 | """ Update statistics """ 22 | self.val = val 23 | self.sum += val * n 24 | self.count += n 25 | self.avg = self.sum / self.count 26 | 27 | 28 | def accuracy(output, target, topk=(1,)): 29 | """ Computes the precision@k for the specified values of k """ 30 | maxk = max(topk) 31 | batch_size = target.size(0) 32 | 33 | _, pred = output.topk(maxk, 1, True, True) 34 | pred = pred.t() 35 | # one-hot case 36 | if target.ndimension() > 1: 37 | target = target.max(1)[1] 38 | 39 | correct = pred.eq(target.view(1, -1).expand_as(pred)) 40 | 41 | res = [] 42 | for k in topk: 43 | correct_k = correct[:k].view(-1).float().sum(0) 44 | res.append(correct_k.mul_(1.0 / batch_size)) 45 | 46 | return res 47 | 48 | 49 | 50 | def compute_psnr_ssim(recoverd, clean,to_y=False,bd=0): 51 | # shape: [1,C,H,W] range:[0-1] mode:rgb 52 | assert recoverd.shape == clean.shape 53 | recoverd = recoverd*255. 54 | clean = clean*255. 55 | recoverd = np.clip(recoverd.detach().cpu().numpy(), 0, 255) 56 | clean = np.clip(clean.detach().cpu().numpy(), 0, 255) 57 | 58 | recoverd = recoverd.transpose(0, 2, 3, 1)[0] 59 | clean = clean.transpose(0, 2, 3, 1)[0] 60 | recoverd = recoverd.astype(np.float64).round() 61 | clean = clean.astype(np.float64).round() 62 | if to_y: 63 | recoverd = rgb2ycbcr(recoverd / 255.0, only_y=True) * 255.0 64 | clean = rgb2ycbcr(clean / 255.0, only_y=True) * 255.0 65 | 66 | if bd != 0: 67 | recoverd= recoverd[bd:-bd, bd:-bd] 68 | clean = clean[bd:-bd, bd:-bd] 69 | psnr = calculate_psnr(recoverd, clean) 70 | ssim = calculate_ssim(recoverd, clean) 71 | 72 | return psnr, ssim, 1 73 | 74 | 75 | 76 | 77 | 78 | class timer(): 79 | def __init__(self): 80 | self.acc = 0 81 | self.tic() 82 | 83 | def tic(self): 84 | self.t0 = time.time() 85 | 86 | def toc(self): 87 | return time.time() - self.t0 88 | 89 | def hold(self): 90 | self.acc += self.toc() 91 | 92 | def release(self): 93 | ret = self.acc 94 | self.acc = 0 95 | 96 | return ret 97 | 98 | def reset(self): 99 | self.acc = 0 100 | 101 | 102 | 103 | 104 | 105 | def calculate_psnr(img1, img2): 106 | # img1 and img2 have range [0, 255] 107 | img1 = img1.astype(np.float64) 108 | img2 = img2.astype(np.float64) 109 | mse = np.mean((img1 - img2) ** 2) 110 | if mse == 0: 111 | return float('inf') 112 | 113 | return 20 * math.log10(255.0 / math.sqrt(mse)) 114 | 115 | 116 | def calculate_ssim(img1, img2): 117 | C1 = (0.01 * 255) ** 2 118 | C2 = (0.03 * 255) ** 2 119 | 120 | img1 = img1.astype(np.float64) 121 | img2 = img2.astype(np.float64) 122 | kernel = cv2.getGaussianKernel(11, 1.5) 123 | window = np.outer(kernel, kernel.transpose()) 124 | 125 | mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] 126 | mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5] 127 | mu1_sq = mu1 ** 2 128 | mu2_sq = mu2 ** 2 129 | mu1_mu2 = mu1 * mu2 130 | sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq 131 | sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq 132 | sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2 133 | 134 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 135 | 136 | return ssim_map.mean() 137 | 138 | 139 | def calc_psnr(sr, hr, scale=2, rgb_range=1.0, benchmark=True): 140 | # benchmark: to Y channel 141 | diff = (sr - hr).data.div(rgb_range) 142 | if benchmark: 143 | shave = scale 144 | if diff.size(1) > 1: 145 | convert = diff.new(1, 3, 1, 1) 146 | convert[0, 0, 0, 0] = 65.738 147 | 
convert[0, 1, 0, 0] = 129.057 148 | convert[0, 2, 0, 0] = 25.064 149 | diff.mul_(convert).div_(256) 150 | diff = diff.sum(dim=1, keepdim=True) 151 | else: 152 | shave = scale + 6 153 | 154 | valid = diff[:, :, shave:-shave, shave:-shave] 155 | mse = valid.pow(2).mean() 156 | 157 | return -10 * math.log10(mse) 158 | 159 | 160 | def rgb2ycbcr(img, only_y=True): 161 | '''same as matlab rgb2ycbcr 162 | only_y: only return Y channel 163 | Input: 164 | uint8, [0, 255] 165 | float, [0, 1] 166 | ''' 167 | in_img_type = img.dtype 168 | img.astype(np.float32) 169 | if in_img_type != np.uint8: 170 | img *= 255. 171 | # convert 172 | if only_y: 173 | rlt = np.dot(img, [65.481, 128.553, 24.966]) / 255.0 + 16.0 174 | else: 175 | rlt = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], 176 | [24.966, 112.0, -18.214]]) / 255.0 + [16, 128, 128] 177 | if in_img_type == np.uint8: 178 | rlt = rlt.round() 179 | else: 180 | rlt /= 255. 181 | return rlt.astype(in_img_type) 182 | 183 | 184 | def bgr2ycbcr(img, only_y=True): 185 | '''bgr version of rgb2ycbcr 186 | only_y: only return Y channel 187 | Input: 188 | uint8, [0, 255] 189 | float, [0, 1] 190 | ''' 191 | in_img_type = img.dtype 192 | img.astype(np.float32) 193 | if in_img_type != np.uint8: 194 | img *= 255. 195 | # convert 196 | if only_y: 197 | rlt = np.dot(img, [24.966, 128.553, 65.481]) / 255.0 + 16.0 198 | else: 199 | rlt = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], 200 | [65.481, -37.797, 112.0]]) / 255.0 + [16, 128, 128] 201 | if in_img_type == np.uint8: 202 | rlt = rlt.round() 203 | else: 204 | rlt /= 255. 205 | return rlt.astype(in_img_type) 206 | 207 | 208 | def ycbcr2rgb(img): 209 | '''same as matlab ycbcr2rgb 210 | Input: 211 | uint8, [0, 255] 212 | float, [0, 1] 213 | ''' 214 | in_img_type = img.dtype 215 | img.astype(np.float32) 216 | if in_img_type != np.uint8: 217 | img *= 255. 218 | # convert 219 | rlt = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], 220 | [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] 221 | if in_img_type == np.uint8: 222 | rlt = rlt.round() 223 | else: 224 | rlt /= 255. 225 | return rlt.astype(in_img_type) 226 | -------------------------------------------------------------------------------- /val_options.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | 3 | parser = argparse.ArgumentParser() 4 | parser.add_argument('--cuda', type=int, default=0) 5 | parser.add_argument('--base_path', type=str, default="/data/guohang/dataset", help='save path of test noisy images') 6 | parser.add_argument('--denoise_path',type=str,default='/data/guohang/dataset/CBSD68/original_png') 7 | parser.add_argument('--derain_path', type=str, default="/data/guohang/dataset/Rain100L", 8 | help='save path of test raining images') 9 | parser.add_argument('--sr_path', type=str, default="/data/guohang/dataset/SR/Set5/HR", 10 | help='path to the sr dataset for validation') 11 | parser.add_argument('--output_path', type=str, default="./test_output/", help='output save path') 12 | 13 | testopt = parser.parse_args() 14 | 15 | 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | 32 | 33 | 34 | --------------------------------------------------------------------------------
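A minimal usage sketch tying the utilities above together (an editorial addition, not a file in the repository): it degrades a clean image with Gaussian noise at sigma=25, mirroring the 'denoise_25' setting in utils/degradation_utils.py, and scores the result with compute_psnr_ssim from utils/val_utils.py. The image path 'example.png' is a placeholder.

import numpy as np
from PIL import Image

from utils.image_utils import crop_img, np_to_torch
from utils.val_utils import compute_psnr_ssim

# Load an RGB image and crop it to a multiple of 16 ('example.png' is a placeholder path).
clean = crop_img(np.array(Image.open('example.png').convert('RGB')), base=16)  # H x W x 3, uint8

sigma = 25  # same noise level as the 'denoise_25' degradation type
noisy = np.clip(clean + np.random.randn(*clean.shape) * sigma, 0, 255).astype(np.uint8)

def to_tensor(im):
    # H x W x C uint8 -> 1 x C x H x W float tensor in [0, 1], the layout compute_psnr_ssim expects
    return np_to_torch(im.transpose(2, 0, 1).astype(np.float32) / 255.)

psnr, ssim, _ = compute_psnr_ssim(to_tensor(noisy), to_tensor(clean))
print('PSNR: {:.2f} dB  SSIM: {:.4f}'.format(psnr, ssim))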