├── .gitee
│   ├── ISSUE_TEMPLATE.zh-CN.md
│   └── PULL_REQUEST_TEMPLATE.zh-CN.md
├── .gitignore
├── LICENSE
├── README.md
├── UBRFC
│   ├── Attention.py
│   ├── CR.py
│   ├── Dataset.py
│   ├── Declaration.py
│   ├── GAN.py
│   ├── Get_image.py
│   ├── Loss.py
│   ├── Metrics.py
│   ├── Option.py
│   ├── Parameter.py
│   ├── Util.py
│   ├── test.py
│   └── train.py
└── images
    ├── Attention_00.png
    ├── Dense_00.png
    ├── Haze1k_00.png
    ├── NH_00.png
    ├── Outdoor_00.png
    └── framework_00.png


/.gitee/ISSUE_TEMPLATE.zh-CN.md:
--------------------------------------------------------------------------------
1 | ### What caused this issue?
2 |
3 |
4 |
5 | ### Steps to reproduce
6 |
7 |
8 |
9 | ### Error message
10 |
11 |
12 |
13 |
14 |
--------------------------------------------------------------------------------
/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md:
--------------------------------------------------------------------------------
1 | ### 1. Description (related issues)
2 |
3 |
4 |
5 | ### 2. Suggested test schedule and test environment
6 | Suggested test completion date: xxxx.xx.xx
7 | Production release date: xxxx.xx.xx
8 | Test environment: CI environment / load-testing environment
9 | Test account:
10 |
11 | ### 3. Changes
12 | * 3.1 Related PRs
13 |
14 | * 3.2 Database and deployment notes
15 |   1. Routine update
16 |   2. Restart unicorn
17 |   3. Restart sidekiq
18 |   4. Migration tasks: are there any? If none, write "None"
19 |   5. Rake script: `bundle exec xxx RAILS_ENV = production`; if none, write "None"
20 |
21 | * 3.4 Other technical improvements (what was done, what was changed)
22 |   - Refactored the xxxx code
23 |   - Optimized the xxxx algorithm
24 |
25 |
26 | * 3.5 Deprecation notice (which fields or methods are deprecated?)
27 |
28 |
29 |
30 | * 3.6 Backward-incompatible changes (are there changes that cannot be kept backward compatible?)
31 |
32 |
33 |
34 | ### 4. Developer self-test points (what was self-tested? were all smoke test cases run?)
35 | Self-test conclusion:
36 |
37 |
38 | ### 5. Testing focus (areas QA should pay particular attention to or might overlook)
39 | Checkpoints:
40 |
41 | | Requirement | Affects the xx common module? | Requires the xx feature? | Upgrade depends on other sub-products? |
42 | |------|------------|----------|---------------|
43 | | xxx | No | Yes | No |
44 | | | | | |
45 |
46 | API testing:
47 |
48 | Performance testing:
49 |
50 | Concurrency testing:
51 |
52 | Other:
53 |
54 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Object file
2 | *.o
3 |
4 | # Ada Library Information
5 | *.ali
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Academic Free License (“AFL”) v. 3.0
2 |
3 | This Academic Free License (the "License") applies to any original work of authorship (the "Original Work") whose owner (the "Licensor") has placed the following licensing notice adjacent to the copyright notice for the Original Work:
4 |
5 | Licensed under the Academic Free License version 3.0
6 |
7 | 1) Grant of Copyright License. Licensor grants You a worldwide, royalty-free, non-exclusive, sublicensable license, for the duration of the copyright, to do the following:
8 |
9 | a) to reproduce the Original Work in copies, either alone or as part of a collective work;
10 | b) to translate, adapt, alter, transform, modify, or arrange the Original Work, thereby creating derivative works ("Derivative Works") based upon the Original Work;
11 | c) to distribute or communicate copies of the Original Work and Derivative Works to the public, under any license of your choice that does not contradict the terms and conditions, including Licensor’s reserved rights and remedies, in this Academic Free License;
12 | d) to perform the Original Work publicly; and
13 | e) to display the Original Work publicly.
14 |
15 | 2) Grant of Patent License.
Licensor grants You a worldwide, royalty-free, non-exclusive, sublicensable license, under patent claims owned or controlled by the Licensor that are embodied in the Original Work as furnished by the Licensor, for the duration of the patents, to make, use, sell, offer for sale, have made, and import the Original Work and Derivative Works. 16 | 17 | 3) Grant of Source Code License. The term "Source Code" means the preferred form of the Original Work for making modifications to it and all available documentation describing how to modify the Original Work. Licensor agrees to provide a machine-readable copy of the Source Code of the Original Work along with each copy of the Original Work that Licensor distributes. Licensor reserves the right to satisfy this obligation by placing a machine-readable copy of the Source Code in an information repository reasonably calculated to permit inexpensive and convenient access by You for as long as Licensor continues to distribute the Original Work. 18 | 19 | 4) Exclusions From License Grant. Neither the names of Licensor, nor the names of any contributors to the Original Work, nor any of their trademarks or service marks, may be used to endorse or promote products derived from this Original Work without express prior permission of the Licensor. Except as expressly stated herein, nothing in this License grants any license to Licensor’s trademarks, copyrights, patents, trade secrets or any other intellectual property. No patent license is granted to make, use, sell, offer for sale, have made, or import embodiments of any patent claims other than the licensed claims defined in Section 2. No license is granted to the trademarks of Licensor even if such marks are included in the Original Work. Nothing in this License shall be interpreted to prohibit Licensor from licensing under terms different from this License any Original Work that Licensor otherwise would have a right to license. 20 | 21 | 5) External Deployment. 
The term "External Deployment" means the use, distribution, or communication of the Original Work or Derivative Works in any way such that the Original Work or Derivative Works may be used by anyone other than You, whether those works are distributed or communicated to those persons or made available as an application intended for use over a network. As an express condition for the grants of license hereunder, You must treat any External Deployment by You of the Original Work or a Derivative Work as a distribution under section 1(c). 22 | 23 | 6) Attribution Rights. You must retain, in the Source Code of any Derivative Works that You create, all copyright, patent, or trademark notices from the Source Code of the Original Work, as well as any notices of licensing and any descriptive text identified therein as an "Attribution Notice." You must cause the Source Code for any Derivative Works that You create to carry a prominent Attribution Notice reasonably calculated to inform recipients that You have modified the Original Work. 24 | 25 | 7) Warranty of Provenance and Disclaimer of Warranty. Licensor warrants that the copyright in and to the Original Work and the patent rights granted herein by Licensor are owned by the Licensor or are sublicensed to You under the terms of this License with the permission of the contributor(s) of those copyrights and patent rights. Except as expressly stated in the immediately preceding sentence, the Original Work is provided under this License on an "AS IS" BASIS and WITHOUT WARRANTY, either express or implied, including, without limitation, the warranties of non-infringement, merchantability or fitness for a particular purpose. THE ENTIRE RISK AS TO THE QUALITY OF THE ORIGINAL WORK IS WITH YOU. This DISCLAIMER OF WARRANTY constitutes an essential part of this License. No license to the Original Work is granted by this License except under this disclaimer. 26 | 27 | 8) Limitation of Liability. 
Under no circumstances and under no legal theory, whether in tort (including negligence), contract, or otherwise, shall the Licensor be liable to anyone for any indirect, special, incidental, or consequential damages of any character arising as a result of this License or the use of the Original Work including, without limitation, damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses. This limitation of liability shall not apply to the extent applicable law prohibits such limitation. 28 | 29 | 9) Acceptance and Termination. If, at any time, You expressly assented to this License, that assent indicates your clear and irrevocable acceptance of this License and all of its terms and conditions. If You distribute or communicate copies of the Original Work or a Derivative Work, You must make a reasonable effort under the circumstances to obtain the express assent of recipients to the terms of this License. This License conditions your rights to undertake the activities listed in Section 1, including your right to create Derivative Works based upon the Original Work, and doing so without honoring these terms and conditions is prohibited by copyright law and international treaty. Nothing in this License is intended to affect copyright exceptions and limitations (including “fair use” or “fair dealing”). This License shall terminate immediately and You may no longer exercise any of the rights granted to You by this License upon your failure to honor the conditions in Section 1(c). 30 | 31 | 10) Termination for Patent Action. This License shall terminate automatically and You may no longer exercise any of the rights granted to You by this License as of the date You commence an action, including a cross-claim or counterclaim, against Licensor or any licensee alleging that the Original Work infringes a patent. 
This termination provision shall not apply for an action alleging patent infringement by combinations of the Original Work with other software or hardware. 32 | 33 | 11) Jurisdiction, Venue and Governing Law. Any action or suit relating to this License may be brought only in the courts of a jurisdiction wherein the Licensor resides or in which Licensor conducts its primary business, and under the laws of that jurisdiction excluding its conflict-of-law provisions. The application of the United Nations Convention on Contracts for the International Sale of Goods is expressly excluded. Any use of the Original Work outside the scope of this License or after its termination shall be subject to the requirements and penalties of copyright or patent law in the appropriate jurisdiction. This section shall survive the termination of this License. 34 | 35 | 12) Attorneys’ Fees. In any action to enforce the terms of this License or seeking damages relating thereto, the prevailing party shall be entitled to recover its costs and expenses, including, without limitation, reasonable attorneys' fees and costs incurred in connection with such action, including any appeal of such action. This section shall survive the termination of this License. 36 | 37 | 13) Miscellaneous. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. 38 | 39 | 14) Definition of "You" in This License. "You" throughout this License, whether in upper or lower case, means an individual or a legal entity exercising rights under, and complying with all of the terms of, this License. For legal entities, "You" includes any entity that controls, is controlled by, or is under common control with you. 
For purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 40 | 41 | 15) Right to Use. You may use the Original Work in all ways not otherwise restricted or conditioned by this License or by law, and Licensor promises not to interfere with or be responsible for such uses by You. 42 | 43 | 16) Modification of This License. This License is Copyright © 2005 Lawrence Rosen. Permission is granted to copy, distribute, or communicate this License without modification. Nothing in this License permits You to modify this License as applied to the Original Work or to Derivative Works. However, You may modify the text of this License and copy, distribute or communicate your modified version (the "Modified License") and apply it to other original works of authorship subject to the following conditions: (i) You may not indicate in any way that your Modified License is the "Academic Free License" or "AFL" and you may not use those names in the name of your Modified License; (ii) You must replace the notice specified in the first paragraph above with the notice "Licensed under " or with a notice of your own that is not confusingly similar to the notice in this License; and (iii) You may not claim that your original works are open source software unless your Modified License has been approved by Open Source Initiative (OSI) and You comply with its license review and certification process. 
44 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # UBRFC 2 | 3 | 4 | 5 | [![Contributors][contributors-shield]][contributors-url] 6 | [![Forks][forks-shield]][forks-url] 7 | [![Stargazers][stars-shield]][stars-url] 8 | [![Issues][issues-shield]][issues-url] 9 | 10 | 11 |
12 | 13 |

14 | 15 | Logo 16 | 17 |

Contrastive Bidirectional Reconstruction Framework

18 |

19 | 20 | Logo 21 | 22 |

23 |

Adaptive Fine-Grained Channel Attention

24 | 25 |

26 | Unsupervised Bidirectional Contrastive Reconstruction and Adaptive Fine-Grained Channel Attention Networks for Image Dehazing 27 |
28 | Explore the documentation for UBRFC-Net » 29 |
30 |
31 | Check Demo 32 | · 33 | Report Bug 34 | · 35 | Pull Request 36 |

37 | 38 |

39 |
40 | ## Contents
41 |
42 | - [Dependencies](#dependencies)
43 | - [Filetree](#filetree)
44 | - [Pretrained Weights and Dataset](#pretrained-weights-and-dataset)
45 | - [Train](#train)
46 | - [Test](#test)
47 | - [Clone the repo](#clone-the-repo)
48 | - [Qualitative Results](#qualitative-results)
49 |   - [Results on RESIDE-Outdoor Dehazing Challenge testing images](#results-on-reside-outdoor-dehazing-challenge-testing-images)
50 |   - [Results on NTIRE 2021 NonHomogeneous Dehazing Challenge testing images](#results-on-ntire-2021-nonhomogeneous-dehazing-challenge-testing-images)
51 |   - [Results on Dense Dehazing Challenge testing images](#results-on-dense-dehazing-challenge-testing-images)
52 |   - [Results on Statehaze1k remote sensing Dehazing Challenge testing images](#results-on-statehaze1k-remote-sensing-dehazing-challenge-testing-images)
53 | - [Copyright](#copyright)
54 | - [Thanks](#thanks)
55 |
56 | ### Dependencies
57 |
58 | 1. PyTorch 1.8.0
59 | 2. Python 3.7.1
60 | 3. CUDA 11.7
61 | 4. Ubuntu 18.04
62 |
63 | ### Filetree
64 | ```
65 | ├─README.md
66 | │
67 | ├─UBRFC
68 | │      Attention.py
69 | │      CR.py
70 | │      Dataset.py
71 | │      GAN.py
72 | │      Get_image.py
73 | │      Loss.py
74 | │      Metrics.py
75 | │      Option.py
76 | │      Parameter.py
77 | │      test.py
78 | │      Util.py
79 | │
80 | ├─images
81 | │      Attention_00.png
82 | │      Dense_00.png
83 | │      framework_00.png
84 | │      NH_00.png
85 | │      Outdoor_00.png
86 | │
87 | └─LICENSE
88 | ```
89 | ### Pretrained Weights and Dataset
90 |
91 | Download our model weights from Google Drive: https://drive.google.com/drive/folders/1fyTzElUd5JvKthlf_1o4PTcoC0mm9ar-?usp=sharing
92 |
93 | Download our test datasets from Google Drive: https://drive.google.com/drive/folders/13Al-It-4srPW7YjS-Iajl54FEtgXNYRC?usp=sharing
94 |
95 | ### Train
96 |
97 | ```shell
98 | python train.py --device 0 --train_root train_path --test_root test_path --batch_size 4
99 | # For example:
100 | python train.py --device 0 --train_root /home/Datasets/Outdoor/train/ --test_root /home/Datasets/Outdoor/test/ --batch_size 4
101 | ```
102 |
103 | ### Test
104 |
105 | ```shell
106 | python Get_image.py --device GPU_id --test_root test_path --pre_model_path model_path
107 | # For example:
108 | python Get_image.py --device 0 --test_root /home/Dense_hazy/test/ --pre_model_path ./model/best_model.pth
109 | ```
110 |
111 | ### Clone the repo
112 |
113 | ```sh
114 | git clone https://github.com/Lose-Code/UBRFC-Net.git
115 | ```
116 |
117 | ### Qualitative Results
118 |
119 | #### Results on RESIDE-Outdoor Dehazing Challenge testing images
120 |
121 | 122 |
123 | 124 | #### Results on NTIRE 2021 NonHomogeneous Dehazing Challenge testing images 125 |
126 | 127 |
128 | 129 | #### Results on Dense Dehazing Challenge testing images 130 |
131 | 132 |
133 | 134 | #### Results on Statehaze1k remote sensing Dehazing Challenge testing images 135 |
136 | 137 |
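### Dataset Layout

Declaration.py builds both data loaders with `y_name='clear'` and the default `x_name='hazy'`. As a result, the folders passed as `--train_root` and `--test_root` are expected to contain a `hazy/` sub-folder and a `clear/` sub-folder. `Datasets.__getitem__` pairs images differently depending on the mode:

- **Training:** each hazy image is paired with a randomly chosen clear image (unpaired training).
- **Testing:** the i-th hazy file is paired with the i-th clear file, so both test folders need the same number of files in matching order.

The snippet below is a minimal sketch that checks this layout and mimics the pairing rule. `check_layout` and `pick_pair` are hypothetical helpers, not part of the repository. The snippet sorts file names explicitly because `os.listdir` returns entries in arbitrary order.

```python
import os
import random
import tempfile


def check_layout(root, x_name="hazy", y_name="clear"):
    """Return sorted file-name lists for the hazy and clear sub-folders of `root`.

    Hypothetical helper: the folder names mirror what Declaration.py passes to
    CustomDatasetLoader (x_name='hazy' by default, y_name='clear').
    """
    lists = []
    for sub in (x_name, y_name):
        path = os.path.join(root, sub)
        if not os.path.isdir(path):
            raise FileNotFoundError(f"expected sub-folder: {path}")
        # os.listdir gives no ordering guarantee, so sort the names explicitly.
        lists.append(sorted(os.listdir(path)))
    return lists[0], lists[1]


def pick_pair(hazy, clear, index, is_train, rng=random):
    """Return (hazy_name, clear_name), following the rule in Datasets.__getitem__."""
    x_name = hazy[index % len(hazy)]
    if is_train:
        # Unpaired training: any clear image may be matched with this hazy one.
        y_name = clear[rng.randint(0, len(clear) - 1)]
    else:
        # Paired testing: the i-th hazy file goes with the i-th clear file.
        if len(hazy) != len(clear):
            raise ValueError("test folders must contain the same number of files")
        y_name = clear[index % len(clear)]
    return x_name, y_name


if __name__ == "__main__":
    with tempfile.TemporaryDirectory() as root:
        for sub in ("hazy", "clear"):
            os.makedirs(os.path.join(root, sub))
            for name in ("0002.png", "0001.png"):
                open(os.path.join(root, sub, name), "wb").close()
        hazy, clear = check_layout(root)
        print(pick_pair(hazy, clear, 0, is_train=False))  # ('0001.png', '0001.png')
```

Running the script prints `('0001.png', '0001.png')`. After sorting, index 0 maps to the first file of each folder, even though the files were created in the opposite order.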
138 | 139 | 140 | 141 | 142 | 143 | 144 | 145 | ### Thanks 146 | 147 | 148 | - [GitHub Emoji Cheat Sheet](https://www.webpagefx.com/tools/emoji-cheat-sheet) 149 | - [Img Shields](https://shields.io) 150 | - [Choose an Open Source License](https://choosealicense.com) 151 | - [GitHub Pages](https://pages.github.com) 152 | 153 | 154 | 155 | [contributors-shield]: https://img.shields.io/github/contributors/Lose-Code/UBRFC-Net.svg?style=flat-square 156 | [contributors-url]: https://github.com/Lose-Code/UBRFC-Net/graphs/contributors 157 | [forks-shield]: https://img.shields.io/github/forks/Lose-Code/UBRFC-Net.svg?style=flat-square 158 | [forks-url]: https://github.com/Lose-Code/UBRFC-Net/network/members 159 | [stars-shield]: https://img.shields.io/github/stars/Lose-Code/UBRFC-Net.svg?style=flat-square 160 | [stars-url]: https://github.com/Lose-Code/UBRFC-Net/stargazers 161 | [issues-shield]: https://img.shields.io/github/issues/Lose-Code/UBRFC-Net.svg?style=flat-square 162 | [issues-url]: https://github.com/Lose-Code/UBRFC-Net/issues 163 | [license-shield]: https://img.shields.io/github/license/Lose-Code/UBRFC-Net.svg?style=flat-square 164 | [license-url]: https://github.com/Lose-Code/UBRFC-Net/blob/master/LICENSE 165 | [linkedin-shield]: https://img.shields.io/badge/-LinkedIn-black.svg?style=flat-square&logo=linkedin&colorB=555 166 | [linkedin-url]: https://linkedin.com/in/shaojintian 167 | -------------------------------------------------------------------------------- /UBRFC/Attention.py: -------------------------------------------------------------------------------- 1 | import math 2 | import torch 3 | from torch import nn 4 | #from torchstat import stat # inspect network parameters 5 | 6 | class Mix(nn.Module): 7 | def __init__(self, m=-0.80): 8 | super(Mix, self).__init__() 9 | w = torch.nn.Parameter(torch.FloatTensor([m]), requires_grad=True) 10 | w = torch.nn.Parameter(w, requires_grad=True) 11 | self.w = w 12 | self.mix_block = nn.Sigmoid() 13 | 14 | def
forward(self, fea1, fea2): 15 | mix_factor = self.mix_block(self.w) 16 | out = fea1 * mix_factor.expand_as(fea1) + fea2 * (1 - mix_factor.expand_as(fea2)) 17 | return out 18 | 19 | 20 | 21 | 22 | # class Attention(nn.Module): 23 | # def __init__(self,channel,b=1, gamma=2): 24 | # super(Attention, self).__init__() 25 | # self.avg_pool = nn.AdaptiveAvgPool2d(1) # global average pooling 26 | # # 1D convolution 27 | # t = int(abs((math.log(channel, 2) + b) / gamma)) 28 | # k = t if t % 2 else t + 1 29 | # self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k / 2), bias=False) 30 | # self.fc = nn.Conv2d(channel, channel, 1, padding=0, bias=True) 31 | # self.sigmoid = nn.Sigmoid() 32 | # self.mix = Mix() 33 | # # fully connected 34 | # #self.fc = nn.Linear(channel,channel) 35 | # #self.softmax = nn.Softmax(dim=1) 36 | 37 | # def forward(self, input): 38 | # x = self.avg_pool(input) 39 | # x1 = self.conv1(x.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 40 | # out1 = self.sigmoid(x1) 41 | # out2 = self.fc(x) 42 | # out2 = self.sigmoid(out2) 43 | # out = self.mix(out1,out2) 44 | # out = self.conv1(out.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 45 | # #out = self.softmax(out) 46 | 47 | # return input*out 48 | class Attention(nn.Module): 49 | def __init__(self,channel,b=1, gamma=2): 50 | super(Attention, self).__init__() 51 | self.avg_pool = nn.AdaptiveAvgPool2d(1) # global average pooling 52 | # 1D convolution across channels; kernel size k is derived from the channel count 53 | t = int(abs((math.log(channel, 2) + b) / gamma)) 54 | k = t if t % 2 else t + 1 55 | self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k / 2), bias=False) 56 | self.fc = nn.Conv2d(channel, channel, 1, padding=0, bias=True) 57 | self.sigmoid = nn.Sigmoid() 58 | self.mix = Mix() 59 | 60 | 61 | def forward(self, input): 62 | x = self.avg_pool(input) 63 | x1 = self.conv1(x.squeeze(-1).transpose(-1, -2)).transpose(-1, -2)#(1,64,1) 64 | x2 = self.fc(x).squeeze(-1).transpose(-1, -2)#(1,1,64) 65 | out1 = torch.sum(torch.matmul(x1,x2),dim=1).unsqueeze(-1).unsqueeze(-1)#(1,64,1,1) 66 | #x1 =
x1.transpose(-1, -2).unsqueeze(-1) 67 | out1 = self.sigmoid(out1) 68 | out2 = torch.sum(torch.matmul(x2.transpose(-1, -2),x1.transpose(-1, -2)),dim=1).unsqueeze(-1).unsqueeze(-1) 69 | 70 | #out2 = self.fc(x) 71 | out2 = self.sigmoid(out2) 72 | out = self.mix(out1,out2) 73 | out = self.conv1(out.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1) 74 | out = self.sigmoid(out) 75 | 76 | return input*out 77 | 78 | if __name__ == '__main__': 79 | input = torch.rand(1,64,256,256) 80 | 81 | A = Attention(channel=64) 82 | #stat(A, input_size=[64, 1, 1]) 83 | y = A(input) 84 | print(y.size()) 85 | 86 | 87 | -------------------------------------------------------------------------------- /UBRFC/CR.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | from torch.nn import functional as F 4 | import torch.nn.functional as fnn 5 | from torch.autograd import Variable 6 | import numpy as np 7 | from torchvision import models 8 | 9 | class Vgg19(torch.nn.Module): 10 | def __init__(self, requires_grad=False): 11 | super(Vgg19, self).__init__() 12 | vgg_pretrained_features = models.vgg19(pretrained=True).features 13 | self.slice1 = torch.nn.Sequential() 14 | self.slice2 = torch.nn.Sequential() 15 | self.slice3 = torch.nn.Sequential() 16 | self.slice4 = torch.nn.Sequential() 17 | self.slice5 = torch.nn.Sequential() 18 | for x in range(2): 19 | self.slice1.add_module(str(x), vgg_pretrained_features[x]) 20 | for x in range(2, 7): 21 | self.slice2.add_module(str(x), vgg_pretrained_features[x]) 22 | for x in range(7, 12): 23 | self.slice3.add_module(str(x), vgg_pretrained_features[x]) 24 | for x in range(12, 21): 25 | self.slice4.add_module(str(x), vgg_pretrained_features[x]) 26 | for x in range(21, 30): 27 | self.slice5.add_module(str(x), vgg_pretrained_features[x]) 28 | if not requires_grad: 29 | for param in self.parameters(): 30 | param.requires_grad = False 31 | 32 | def forward(self, X): 33 | 
h_relu1 = self.slice1(X) 34 | h_relu2 = self.slice2(h_relu1) 35 | h_relu3 = self.slice3(h_relu2) 36 | h_relu4 = self.slice4(h_relu3) 37 | h_relu5 = self.slice5(h_relu4) 38 | return [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5] 39 | 40 | class ContrastLoss(nn.Module): 41 | def __init__(self, device,ablation=False): 42 | 43 | super(ContrastLoss, self).__init__() 44 | self.vgg = Vgg19().to(device) 45 | self.l1 = nn.L1Loss().to(device) 46 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0] 47 | self.ab = ablation 48 | 49 | def forward(self, x, gt, push1,push2): 50 | a_vgg, p_vgg, n_vgg,m_vgg = self.vgg(x), self.vgg(gt), self.vgg(push1),self.vgg(push2) 51 | loss = 0 52 | 53 | d_ap, d_an = 0, 0 54 | for i in range(len(a_vgg)): 55 | d_ap = self.l1(a_vgg[i], p_vgg[i].detach()) 56 | if not self.ab: 57 | d_an = self.l1(a_vgg[i], n_vgg[i].detach()) 58 | d_am = self.l1(a_vgg[i], m_vgg[i].detach()) 59 | contrastive = d_ap / (d_an + d_am + 1e-7) 60 | else: 61 | contrastive = d_ap 62 | 63 | loss += self.weights[i] * contrastive 64 | return loss 65 | -------------------------------------------------------------------------------- /UBRFC/Dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from PIL import Image 4 | from torch.utils.data import Dataset, DataLoader 5 | from torchvision.transforms import transforms 6 | 7 | 8 | class Datasets(Dataset): 9 | 10 | def __init__(self, root_dir, isTrain=True, x_name='hazy', y_name='clean'): 11 | self.root_dir = root_dir 12 | self.isTrain = isTrain 13 | self.train_X = x_name 14 | self.train_Y = y_name 15 | self.X_dir_list = sorted(os.listdir(os.path.join(self.root_dir, self.train_X))) # sorted image file names in domain X (os.listdir order is arbitrary) 16 | self.Y_dir_list = sorted(os.listdir(os.path.join(self.root_dir, self.train_Y))) # sorted image file names in domain Y, so test pairs line up by index 17 | self.transforms = self.get_transforms() 18 | 19 | def __len__(self): 20 | return len(self.X_dir_list) 21 | 22 | def __getitem__(self,
index): 23 | X_img_name = self.X_dir_list[index % len(self.X_dir_list)] 24 | 25 | if self.isTrain: 26 | ind_Y = random.randint(0, len(self.Y_dir_list) - 1) 27 | Y_img_name = self.Y_dir_list[ind_Y] 28 | else: 29 | assert len(self.X_dir_list) == len(self.Y_dir_list) 30 | Y_img_name = self.Y_dir_list[index % len(self.Y_dir_list)] 31 | #Y_img_name = X_img_name 32 | name = Y_img_name.split('.jpg')[0].split('.png')[0] 33 | X_img = Image.open(os.path.join(self.root_dir, self.train_X, X_img_name)).convert('RGB') # force 3-channel RGB (drops alpha, expands grayscale) 34 | Y_img = Image.open(os.path.join(self.root_dir, self.train_Y, Y_img_name)).convert('RGB') 35 | X = self.transforms(X_img) 36 | Y = self.transforms(Y_img) 37 | 38 | 39 | return X, Y,name 40 | 41 | def get_transforms(self, crop_size=256): 42 | 43 | if self.isTrain: 44 | all_transforms = [transforms.RandomCrop(crop_size), transforms.ToTensor()] 45 | else: 46 | all_transforms = [transforms.ToTensor()] 47 | 48 | return transforms.Compose(all_transforms) 49 | 50 | 51 | class CustomDatasetLoader(): 52 | 53 | def __init__(self, root_dir,isTrain=True,x_name='hazy', y_name='clean',batch_size=2): 54 | if isTrain: 55 | assert batch_size>=2 56 | self.dataset = Datasets(root_dir,isTrain,x_name,y_name) 57 | self.dataloader = DataLoader(self.dataset, batch_size=batch_size, pin_memory=True, shuffle=True,num_workers=4,drop_last=True) 58 | 59 | def __len__(self): 60 | """Return the number of samples in the dataset""" 61 | return len(self.dataset) 62 | 63 | def __iter__(self): 64 | """Return a batch of data""" 65 | for i, data in enumerate(self.dataloader): 66 | yield data 67 | 68 | def load_data(self): 69 | return self 70 | -------------------------------------------------------------------------------- /UBRFC/Declaration.py: -------------------------------------------------------------------------------- 1 | import os 2 | import time 3 | import torch 4 | from Option import opt 5 | from Loss import SSIMLoss 6 | from CR import ContrastLoss 7 | from Parameter import Parameter 8 | from Dataset import
CustomDatasetLoader 9 | from GAN import Generator, Discriminator 10 | 11 | lr = opt.lr 12 | d_out_size = 30 13 | device = "cuda:"+opt.device 14 | epochs = opt.epochs 15 | 16 | train_dataloader = CustomDatasetLoader(root_dir=opt.train_root, isTrain=True,batch_size=opt.batch_size,y_name='clear').load_data().dataloader 17 | test_dataloader = CustomDatasetLoader(root_dir=opt.test_root, isTrain=False,batch_size=1,y_name='clear').load_data().dataloader 18 | 19 | generator = Generator() 20 | discriminator = Discriminator() # output shape: (bz, 1, 30, 30) for 256x256 crops 21 | parameterNet = Parameter() 22 | 23 | criterionSsim = SSIMLoss() 24 | criterion = torch.nn.MSELoss() 25 | criterionP = torch.nn.L1Loss() 26 | criterionC = ContrastLoss(device,True) # ablation=True: only the positive term d_ap is used 27 | 28 | if os.path.exists(opt.model_path) and opt.is_continue: 29 | lr = torch.load(opt.model_path)["lr"] 30 | 31 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr) 32 | optimizer_T = torch.optim.Adam([ 33 | {'params': generator.parameters(), 'lr': lr}, 34 | {'params': parameterNet.parameters(), 'lr': lr} 35 | ]) 36 | 37 | timeStamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime()) 38 | 39 | generator.to(device) 40 | criterion.to(device) 41 | criterionP.to(device) 42 | parameterNet.to(device) 43 | discriminator.to(device) 44 | criterionSsim.to(device) 45 | 46 | scheduler_T = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_T, T_max=epochs, eta_min=0, last_epoch=-1) 47 | scheduler_D = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_D, T_max=epochs, eta_min=0, last_epoch=-1) 48 | 49 | if os.path.exists(opt.model_path) and opt.is_continue: 50 | print("Loading model") 51 | generator.load_state_dict(torch.load(opt.model_path)["generator_net"]) 52 | discriminator.load_state_dict(torch.load(opt.model_path)["discriminator_net"]) 53 | parameterNet.load_state_dict(torch.load(opt.model_path)["parameterNet_net"]) 54 | optimizer_T.load_state_dict(torch.load(opt.model_path)["optimizer_T"]) 55 |
optimizer_D.load_state_dict(torch.load(opt.model_path)["optimizer_D"]) 56 | iter_num = torch.load(opt.model_path)["epoch"] 57 | scheduler_T.step() 58 | scheduler_D.step() 59 | else: 60 | iter_num = -1 -------------------------------------------------------------------------------- /UBRFC/GAN.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import functools 3 | from torch import nn 4 | #from torchsummary import summary 5 | 6 | from Attention import Attention 7 | from Util import PALayer, ConvGroups, FE_Block, Fusion_Block, ResnetBlock, ConvBlock, CALayer 8 | 9 | 10 | class Generator(nn.Module): 11 | def __init__(self, ngf=64, bn=False): 12 | super(Generator, self).__init__() 13 | # Downsampling 14 | self.down1 = ResnetBlock(3, first=True) 15 | 16 | self.down2 = ResnetBlock(ngf, levels=2) 17 | 18 | self.down3 = ResnetBlock(ngf * 2, levels=2, bn=bn) 19 | 20 | self.res = nn.Sequential( 21 | ResnetBlock(ngf * 4, levels=6, down=False, bn=True) 22 | ) 23 | 24 | # Upsampling 25 | 26 | self.up1 = nn.Sequential( 27 | nn.LeakyReLU(True), # NOTE: nn.LeakyReLU's first positional argument is negative_slope, so LeakyReLU(True) uses a slope of 1.0 and acts as an identity; up2, info_up1 and info_up2 below do the same. Left unchanged to keep behaviour consistent with the provided pretrained weights. 28 | nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=3, stride=2, padding=1, 29 | output_padding=1), 30 | nn.InstanceNorm2d(ngf * 2) if not bn else nn.BatchNorm2d(ngf * 2), 31 | ) 32 | 33 | self.up2 = nn.Sequential( 34 | nn.LeakyReLU(True), 35 | nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1), 36 | nn.InstanceNorm2d(ngf), 37 | ) 38 | 39 | self.info_up1 = nn.Sequential( 40 | nn.LeakyReLU(True), 41 | nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=3, stride=2, padding=1, 42 | output_padding=1), 43 | nn.InstanceNorm2d(ngf * 2) if not bn else nn.BatchNorm2d(ngf * 2, eps=1e-5), 44 | ) 45 | 46 | self.info_up2 = nn.Sequential( 47 | nn.LeakyReLU(True), 48 | nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1), 49 | nn.InstanceNorm2d(ngf) # if not bn else nn.BatchNorm2d(ngf, eps=1e-5), 50 | ) 51 | 52 | self.up3 = nn.Sequential( 53 |
nn.ReflectionPad2d(3), 54 | nn.Conv2d(ngf, 3, kernel_size=7, padding=0), 55 | nn.Tanh()) 56 | 57 | self.pa2 = PALayer(ngf * 4) 58 | self.pa3 = PALayer(ngf * 2) 59 | self.pa4 = PALayer(ngf) 60 | 61 | # self.ca2 = CALayer(ngf * 4) 62 | # self.ca3 = CALayer(ngf * 2) 63 | # self.ca4 = CALayer(ngf) 64 | self.ca2 = Attention(ngf * 4) 65 | self.ca3 = Attention(ngf * 2) 66 | self.ca4 = Attention(ngf) 67 | 68 | self.down_dcp = ConvGroups(3, bn=bn) 69 | 70 | self.fam1 = FE_Block(ngf, ngf) 71 | self.fam2 = FE_Block(ngf, ngf * 2) 72 | self.fam3 = FE_Block(ngf * 2, ngf * 4) 73 | 74 | self.att1 = Fusion_Block(ngf) 75 | self.att2 = Fusion_Block(ngf * 2) 76 | self.att3 = Fusion_Block(ngf * 4, bn=bn) 77 | 78 | self.merge2 = nn.Sequential( 79 | ConvBlock(ngf * 2, ngf * 2, kernel_size=(3, 3), stride=1, padding=1), 80 | nn.LeakyReLU(inplace=False) 81 | ) 82 | self.merge3 = nn.Sequential( 83 | ConvBlock(ngf, ngf, kernel_size=(3, 3), stride=1, padding=1), 84 | nn.LeakyReLU(inplace=False) 85 | ) 86 | 87 | def forward(self, hazy, img=0, first=True): 88 | if not first: 89 | dcp_down1, dcp_down2, dcp_down3 = self.down_dcp(img) 90 | x_down1 = self.down1(hazy) # [bs, ngf, ngf * 4, ngf * 4][1,64,256,256] 91 | 92 | att1 = self.att1(dcp_down1, x_down1) if not first else x_down1 #[1,64,256,256] 93 | 94 | x_down2 = self.down2(x_down1) # [bs, ngf*2, ngf*2, ngf*2][1,128,128,128] 95 | att2 = self.att2(dcp_down2, x_down2) if not first else None #None 96 | fuse2 = self.fam2(att1, att2) if not first else self.fam2(att1, x_down2)#[1,128,128,128] 97 | 98 | x_down3 = self.down3(x_down2) # [bs, ngf * 4, ngf, ngf]#[1,256,64,64] 99 | att3 = self.att3(dcp_down3, x_down3) if not first else None #None 100 | fuse3 = self.fam3(fuse2, att3) if not first else self.fam3(fuse2, x_down3)#[1,256,64,64] 101 | 102 | x6 = self.pa2(self.ca2(self.res(x_down3)))#[1,256,64,64] 103 | 104 | fuse_up2 = self.info_up1(fuse3)#[1,128,128,128] 105 | fuse_up2 = self.merge2(fuse_up2 + x_down2)#[1,128,128,128] 106 | 107 | fuse_up3 = 
self.info_up2(fuse_up2)#[1,64,256,256] 108 | fuse_up3 = self.merge3(fuse_up3 + x_down1)#[1,64,256,256] 109 | 110 | x_up2 = self.up1(x6 + fuse3)#[1,128,128,128] 111 | x_up2 = self.ca3(x_up2)#[1,128,128,128] 112 | x_up2 = self.pa3(x_up2)#[1,128,128,128] 113 | 114 | x_up3 = self.up2(x_up2 + fuse_up2)#[1,64,256,256] 115 | x_up3 = self.ca4(x_up3)#[1,64,256,256] 116 | x_up3 = self.pa4(x_up3)#[1,64,256,256] 117 | 118 | x_up4 = self.up3(x_up3 + fuse_up3)#[1,3,256,256] 119 | 120 | return x_up4 121 | 122 | 123 | class Discriminator(nn.Module): 124 | """ 125 | Discriminator class 126 | """ 127 | 128 | def __init__(self, inp=3, out=1): 129 | """ 130 | Initializes the PatchGAN model with 3 layers as discriminator 131 | 132 | Args: 133 | inp: number of input image channels 134 | out: number of output image channels 135 | """ 136 | 137 | super(Discriminator, self).__init__() 138 | 139 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False) 140 | 141 | model = [ 142 | nn.Conv2d(inp, 64, kernel_size=4, stride=2, padding=1), # input 3 channels 143 | nn.LeakyReLU(0.2, True), 144 | 145 | nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=True), 146 | norm_layer(128), 147 | nn.LeakyReLU(0.2, True), 148 | 149 | nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=True), 150 | norm_layer(256), 151 | nn.LeakyReLU(0.2, True), 152 | 153 | nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1, bias=True), 154 | norm_layer(512), 155 | nn.LeakyReLU(0.2, True), 156 | 157 | nn.Conv2d(512, out, kernel_size=4, stride=1, padding=1) # output only 1 channel (prediction map) 158 | ] 159 | self.model = nn.Sequential(*model) 160 | 161 | def forward(self, input): 162 | """ 163 | Feed forward the image produced by generator through discriminator 164 | 165 | Args: 166 | input: input image 167 | 168 | Returns: 169 | outputs prediction map with 1 channel 170 | """ 171 | result = self.model(input) 172 | 173 | return result 174 | 175 | 176 | class 
discriminator(nn.Module): 177 | def __init__(self, bn=False, ngf=64): 178 | super(discriminator, self).__init__() 179 | self.net = nn.Sequential( 180 | nn.Conv2d(3, ngf, kernel_size=3, padding=1), 181 | nn.LeakyReLU(0.2), 182 | 183 | nn.ReflectionPad2d(1), 184 | nn.Conv2d(ngf, ngf, kernel_size=3, stride=2, padding=0), 185 | nn.InstanceNorm2d(ngf), 186 | nn.LeakyReLU(0.2), 187 | 188 | nn.ReflectionPad2d(1), 189 | nn.Conv2d(ngf, ngf * 2, kernel_size=3, padding=0), 190 | nn.InstanceNorm2d(ngf * 2), 191 | nn.LeakyReLU(0.2), 192 | 193 | nn.ReflectionPad2d(1), 194 | nn.Conv2d(ngf * 2, ngf * 2, kernel_size=3, stride=2, padding=0), 195 | nn.InstanceNorm2d(ngf * 2), 196 | nn.LeakyReLU(0.2), 197 | 198 | nn.ReflectionPad2d(1), 199 | nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, padding=0), 200 | nn.InstanceNorm2d(ngf * 4) if not bn else nn.BatchNorm2d(ngf * 4), 201 | nn.LeakyReLU(0.2), 202 | 203 | nn.ReflectionPad2d(1), 204 | nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=2, padding=0), 205 | nn.InstanceNorm2d(ngf * 4) if not bn else nn.BatchNorm2d(ngf * 4), 206 | nn.LeakyReLU(0.2), 207 | 208 | nn.ReflectionPad2d(1), 209 | nn.Conv2d(ngf * 4, ngf * 8, kernel_size=3, padding=0), 210 | nn.BatchNorm2d(ngf * 8) if bn else nn.InstanceNorm2d(ngf * 8), 211 | nn.LeakyReLU(0.2), 212 | 213 | nn.ReflectionPad2d(1), 214 | nn.Conv2d(ngf * 8, ngf * 8, kernel_size=3, stride=2, padding=0), 215 | nn.BatchNorm2d(ngf * 8) if bn else nn.InstanceNorm2d(ngf * 8), 216 | nn.LeakyReLU(0.2), 217 | 218 | nn.AdaptiveAvgPool2d(1), 219 | nn.Conv2d(ngf * 8, ngf * 16, kernel_size=1), 220 | nn.LeakyReLU(0.2), 221 | nn.Conv2d(ngf * 16, 1, kernel_size=1) 222 | 223 | ) 224 | 225 | def forward(self, x): 226 | batch_size = x.size(0) 227 | return torch.sigmoid(self.net(x).view(batch_size)) 228 | 229 | 230 | 231 | 232 | if __name__ == '__main__': 233 | net = Generator() 234 | #summary(net, (3,256,256), batch_size=1, device="cpu") 235 | x = torch.randn((1,3,256,256)) 236 | # 237 | #print(net) 238 | y = net(x) 239 | 
print(y.size()) -------------------------------------------------------------------------------- /UBRFC/Get_image.py: -------------------------------------------------------------------------------- 1 | import os 2 | import math 3 | import time 4 | import torch 5 | import numpy as np 6 | from tqdm import tqdm 7 | from Option import opt 8 | from GAN import Generator 9 | from Metrics import ssim, psnr 10 | from Parameter import Parameter 11 | from Dataset import CustomDatasetLoader 12 | from torchvision.utils import save_image 13 | 14 | lr = opt.lr 15 | d_out_size = 30 16 | device = "cuda:"+opt.device 17 | epochs = opt.epochs 18 | 19 | 20 | test_dataloader = CustomDatasetLoader(root_dir=opt.test_root, isTrain=False,batch_size=1).load_data().dataloader 21 | 22 | generator = Generator() 23 | 24 | parameterNet = Parameter() 25 | generator.to(device) 26 | 27 | parameterNet.to(device) 28 | 29 | model_path = opt.pre_model_path 30 | generator.load_state_dict(torch.load(model_path)["generator_net"]) 31 | parameterNet.load_state_dict(torch.load(model_path)["parameterNet_net"]) 32 | def padding_image(image, h, w): 33 | assert h >= image.size(2) 34 | assert w >= image.size(3) 35 | padding_top = (h - image.size(2)) // 2 36 | padding_down = h - image.size(2) - padding_top 37 | padding_left = (w - image.size(3)) // 2 38 | padding_right = w - image.size(3) - padding_left 39 | out = torch.nn.functional.pad(image, (padding_left, padding_right, padding_top, padding_down), mode='reflect') 40 | return out, padding_left, padding_left + image.size(3), padding_top, padding_top + image.size(2) 41 | 42 | def test(genertor_net, parameter_net, loader_test): 43 | genertor_net.eval() 44 | parameter_net.eval() 45 | clear_psnrs,clear_ssims = [],[] 46 | for i, (inputs, targets, name) in tqdm(enumerate(loader_test), total=len(loader_test), leave=False, desc="测试中"): 47 | #print(name) 48 | h, w = inputs.shape[2], inputs.shape[3] 49 | #print('h, w:',h, w) 50 | 51 | if h>w: 52 | max_h = int(math.ceil(h 
/ 512)) * 512 53 | max_w = int(math.ceil(w / 512)) * 512 54 | else: 55 | max_h = int(math.ceil(h / 256)) * 256 56 | max_w = int(math.ceil(w / 256)) * 256 57 | inputs, ori_left, ori_right, ori_top, ori_down = padding_image(inputs, max_h, max_w) 58 | 59 | inputs = inputs.to(device) 60 | targets = targets.to(device) 61 | 62 | pred = genertor_net(inputs) 63 | #print("pred.size:",pred.size()) 64 | _, dehazy_pred = parameter_net(inputs, pred) 65 | 66 | 67 | pred = pred.data[:, :, ori_top:ori_down, ori_left:ori_right] 68 | dehazy_pred = dehazy_pred.data[:, :, ori_top:ori_down, ori_left:ori_right] 69 | 70 | ssim1 = ssim(pred, targets).item() 71 | psnr1 = psnr(pred, targets) 72 | 73 | ssim11 = ssim(dehazy_pred, targets).item() 74 | psnr11 = psnr(dehazy_pred, targets) 75 | 76 | save_image(inputs.data[:1], os.path.join(opt.out_hazy_path, "%s.png" % name[0])) 77 | save_image(targets.data[:1], os.path.join(opt.out_gt_path, "%s.png" % name[0])) 78 | 79 | if psnr11 >= psnr1: 80 | save_image(dehazy_pred.data[:1], os.path.join(opt.out_clear_path, "%s.png" % name[0])) 81 | clear_ssims.append(ssim11) 82 | clear_psnrs.append(psnr11) 83 | else: 84 | save_image(pred.data[:1], os.path.join(opt.out_clear_path, "%s.png" % name[0])) 85 | clear_ssims.append(ssim1) 86 | clear_psnrs.append(psnr1) 87 | 88 | return np.mean(clear_ssims), np.mean(clear_psnrs) 89 | 90 | if __name__ == '__main__': 91 | with torch.no_grad(): 92 | print(test(generator, parameterNet, test_dataloader)) -------------------------------------------------------------------------------- /UBRFC/Loss.py: -------------------------------------------------------------------------------- 1 | # Loss functions 2 | from torch import nn 3 | from torchvision import models 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from math import exp 8 | import numpy as np 9 | from torchvision.models import vgg16 10 | import warnings 11 | 12 | warnings.filterwarnings('ignore') 13 | 14 | # 计算一维的高斯分布向量 15 | def gaussian(window_size, 
sigma): 16 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 17 | return gauss / gauss.sum() 18 | 19 | 20 | # 创建高斯核,通过两个一维高斯分布向量进行矩阵乘法得到 21 | # 可以设定channel参数拓展为3通道 22 | def create_window(window_size, channel=1): 23 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 24 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 25 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous() 26 | return window 27 | 28 | 29 | # 计算SSIM 30 | # 直接使用SSIM的公式,但是在计算均值时,不是直接求像素平均值,而是采用归一化的高斯核卷积来代替。 31 | # 在计算方差和协方差时用到了公式Var(X)=E[X^2]-E[X]^2, cov(X,Y)=E[XY]-E[X]E[Y]. 32 | # 正如前面提到的,上面求期望的操作采用高斯核卷积代替。 33 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None): 34 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh). 35 | if val_range is None: 36 | if torch.max(img1) > 128: 37 | max_val = 255 38 | else: 39 | max_val = 1 40 | 41 | if torch.min(img1) < -0.5: 42 | min_val = -1 43 | else: 44 | min_val = 0 45 | L = max_val - min_val 46 | else: 47 | L = val_range 48 | 49 | padd = 0 50 | (_, channel, height, width) = img1.size() 51 | if window is None: 52 | real_size = min(window_size, height, width) 53 | window = create_window(real_size, channel=channel).to(img1.device) 54 | 55 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel) 56 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel) 57 | 58 | mu1_sq = mu1.pow(2) 59 | mu2_sq = mu2.pow(2) 60 | mu1_mu2 = mu1 * mu2 61 | 62 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq 63 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq 64 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2 65 | 66 | C1 = (0.01 * L) ** 2 67 | C2 = (0.03 * L) ** 2 68 | 69 | v1 = 2.0 * sigma12 + C2 70 | v2 = sigma1_sq + sigma2_sq + C2 71 | cs = torch.mean(v1 / v2) # contrast 
sensitivity 72 | 73 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2) 74 | 75 | if size_average: 76 | ret = ssim_map.mean() 77 | else: 78 | ret = ssim_map.mean(1).mean(1).mean(1) 79 | 80 | if full: 81 | return ret, cs 82 | return ret 83 | 84 | 85 | # Classes to re-use window 86 | class SSIMLoss(torch.nn.Module): 87 | def __init__(self, window_size=11, size_average=True, val_range=None): 88 | super(SSIMLoss, self).__init__() 89 | self.window_size = window_size 90 | self.size_average = size_average 91 | self.val_range = val_range 92 | 93 | # Assume 1 channel for SSIM 94 | self.channel = 1 95 | self.window = create_window(window_size) 96 | 97 | def forward(self, img1, img2): 98 | (_, channel, _, _) = img1.size() 99 | 100 | if channel == self.channel and self.window.dtype == img1.dtype: 101 | window = self.window 102 | else: 103 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype) 104 | self.window = window 105 | self.channel = channel 106 | 107 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average) 108 | 109 | 110 | 111 | 112 | 113 | -------------------------------------------------------------------------------- /UBRFC/Metrics.py: -------------------------------------------------------------------------------- 1 | from math import exp 2 | import math 3 | import numpy as np 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | from torch.autograd import Variable 8 | 9 | 10 | 11 | 12 | def gaussian(window_size, sigma): 13 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)]) 14 | return gauss / gauss.sum() 15 | 16 | 17 | def create_window(window_size, channel): 18 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1) 19 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0) 20 | # torch.mul(a, b)是矩阵a和b对应位相乘,a和b的维度必须相等,比如a的维度是(1, 2),b的维度是(1, 2),返回的仍是(1, 2)的矩阵 21 | # torch.mm(a, 
b)是矩阵a和b矩阵相乘,比如a的维度是(1, 2),b的维度是(2, 3),返回的就是(1, 3)的矩阵 22 | # .t(), 求转置 23 | # unsqueeze()在二维张量前面添加2个轴,变成四维张量 24 | 25 | window = Variable(_2D_window.expand(channel, 1, window_size, window_size).contiguous()) 26 | # 把张量扩展成(channel, 1, window_size, window_size)的大小,以原来的值填充(其自身的值不变) 27 | # contiguous:view只能用在contiguous的variable上。contiguous一般与transpose,permute,view搭配使用 28 | # 即使用transpose或permute进行维度变换后,需要用contiguous()来返回一个contiguous copy,然后方可使用view对维度进行变形 29 | 30 | return window 31 | 32 | 33 | def _ssim(img1, img2, window, window_size, channel, size_average=True): 34 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel) 35 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel) 36 | mu1_sq = mu1.pow(2) 37 | mu2_sq = mu2.pow(2) 38 | mu1_mu2 = mu1 * mu2 39 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq 40 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq 41 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2 42 | C1 = 0.01 ** 2 43 | C2 = 0.03 ** 2 44 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)) 45 | 46 | if size_average: 47 | return ssim_map.mean() 48 | else: 49 | return ssim_map.mean(1).mean(1).mean(1) 50 | 51 | 52 | def ssim(img1, img2, window_size=11, size_average=True): 53 | img1 = torch.clamp(img1, min=0, max=1) 54 | img2 = torch.clamp(img2, min=0, max=1) 55 | (_, channel, _, _) = img1.size() 56 | window = create_window(window_size, channel) 57 | if img1.is_cuda: 58 | window = window.cuda(img1.get_device()) 59 | window = window.type_as(img1) 60 | return _ssim(img1, img2, window, window_size, channel, size_average) 61 | 62 | 63 | def psnr(pred, gt): 64 | pred = pred.clamp(0, 1).cpu().numpy() 65 | gt = gt.clamp(0, 1).cpu().numpy() 66 | imdff = pred - gt 67 | rmse = math.sqrt(np.mean(imdff ** 2)) 68 | if rmse == 0: 69 | return 100 70 | 
return 20 * math.log10(1.0 / rmse) 71 | 72 | 73 | if __name__ == "__main__": 74 | pass 75 | -------------------------------------------------------------------------------- /UBRFC/Option.py: -------------------------------------------------------------------------------- 1 | import os 2 | import argparse 3 | 4 | 5 | class Options: 6 | 7 | def initialize(self, parser): 8 | parser.add_argument('--train_root', default='/data/wy/datasets/Haze1k/train',help='Training path') 9 | parser.add_argument('--test_root', default='/home/wy/datasets/Haze1k/Haze1k_thin/dataset/test', help='Testing path') 10 | parser.add_argument('--batch_size', type=int, default=4, help='input batch size') 11 | parser.add_argument('--epochs', type=int, default=10000, help='Total number of training') 12 | parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate') 13 | parser.add_argument('--device', type=str, default='0', help='GPU id') 14 | parser.add_argument('--is_continue', default=False, help='Whether to continue') 15 | parser.add_argument('--model_path', default='./model/last_model.pth', help='Loading model') 16 | parser.add_argument('--pre_model_path', default='./model/best_model.pth', help='Loading model') 17 | parser.add_argument('--out_hazy_path', default='./images/input') 18 | parser.add_argument('--out_gt_path', default='./images/targets') 19 | parser.add_argument('--out_clear_path', default='./images/clear') 20 | parser.add_argument('--out_log_path', default='./Log/log_train') 21 | return parser 22 | 23 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) 24 | parser = Options().initialize(parser) 25 | opt, _ = parser.parse_known_args() 26 | 27 | 28 | if not os.path.exists('./model'): 29 | os.makedirs('./model') 30 | if not os.path.exists('./Log/max_log'): 31 | os.makedirs('./Log/max_log') 32 | if not os.path.exists(opt.out_log_path): 33 | os.makedirs(opt.out_log_path) 34 | if not os.path.exists(opt.out_hazy_path): 35 | 
os.makedirs(opt.out_hazy_path) 36 | if not os.path.exists(opt.out_gt_path): 37 | os.makedirs(opt.out_gt_path) 38 | if not os.path.exists(opt.out_clear_path): 39 | os.makedirs(opt.out_clear_path) -------------------------------------------------------------------------------- /UBRFC/Parameter.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from Attention import Attention 4 | import torch.nn.functional as F 5 | import torchvision.models as models 6 | 7 | 8 | def conv_block(in_dim, out_dim): 9 | return nn.Sequential(nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1), 10 | nn.ELU(True), 11 | nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1), 12 | nn.ELU(True), 13 | nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1, padding=0), 14 | nn.AvgPool2d(kernel_size=2, stride=2)) 15 | 16 | 17 | def deconv_block(in_dim, out_dim): 18 | return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1), 19 | nn.ELU(True), 20 | nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1), 21 | nn.ELU(True), 22 | nn.UpsamplingNearest2d(scale_factor=2)) 23 | 24 | 25 | def blockUNet1(in_c, out_c, name, transposed=False, bn=False, relu=True, dropout=False): 26 | block = nn.Sequential() 27 | if relu: 28 | block.add_module('%s_relu' % name, nn.ReLU(inplace=True)) 29 | else: 30 | block.add_module('%s_leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True)) 31 | if not transposed: 32 | block.add_module('%s_conv' % name, nn.Conv2d(in_c, out_c, 3, 1, 1, bias=False)) 33 | else: 34 | block.add_module('%s_tconv' % name, nn.ConvTranspose2d(in_c, out_c, 3, 1, 1, bias=False)) 35 | if bn: 36 | block.add_module('%s_bn' % name, nn.BatchNorm2d(out_c)) 37 | if dropout: 38 | block.add_module('%s_dropout' % name, nn.Dropout2d(0.5, inplace=True)) 39 | return block 40 | 41 | 42 | def blockUNet(in_c, out_c, name, transposed=False, bn=False, relu=True, dropout=False): 43 | block = 
nn.Sequential() 44 | if relu: 45 | block.add_module('%s_relu' % name, nn.ReLU(inplace=True)) 46 | else: 47 | block.add_module('%s_leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True)) 48 | if not transposed: 49 | block.add_module('%s_conv' % name, nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False)) 50 | else: 51 | block.add_module('%s_tconv' % name, nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=False)) 52 | if bn: 53 | block.add_module('%s_bn' % name, nn.BatchNorm2d(out_c)) 54 | if dropout: 55 | block.add_module('%s_dropout' % name, nn.Dropout2d(0.5, inplace=True)) 56 | return block 57 | 58 | 59 | class G(nn.Module): 60 | def __init__(self, input_nc, output_nc, nf): 61 | super(G, self).__init__() 62 | # input is 256 x 256 63 | layer_idx = 1 64 | name = 'layer%d' % layer_idx 65 | layer1 = nn.Sequential() 66 | layer1.add_module(name, nn.Conv2d(input_nc, nf, 4, 2, 1, bias=False)) 67 | # input is 128 x 128 68 | layer_idx += 1 69 | name = 'layer%d' % layer_idx 70 | layer2 = blockUNet(nf, nf * 2, name, transposed=False, bn=True, relu=False, dropout=False) 71 | # input is 64 x 64 72 | layer_idx += 1 73 | name = 'layer%d' % layer_idx 74 | layer3 = blockUNet(nf * 2, nf * 4, name, transposed=False, bn=True, relu=False, dropout=False) 75 | # input is 32 76 | layer_idx += 1 77 | name = 'layer%d' % layer_idx 78 | layer4 = blockUNet(nf * 4, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False) 79 | # input is 16 80 | layer_idx += 1 81 | name = 'layer%d' % layer_idx 82 | layer5 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False) 83 | # input is 8 84 | layer_idx += 1 85 | name = 'layer%d' % layer_idx 86 | layer6 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False) 87 | # input is 4 88 | layer_idx += 1 89 | name = 'layer%d' % layer_idx 90 | layer7 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False) 91 | # input is 2 x 2 92 | layer_idx += 1 93 | name = 'layer%d' % 
layer_idx 94 | layer8 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False) 95 | 96 | ## NOTE: decoder 97 | # input is 1 98 | name = 'dlayer%d' % layer_idx 99 | d_inc = nf * 8 100 | dlayer8 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=False, relu=True, dropout=True) 101 | 102 | # input is 2 103 | layer_idx -= 1 104 | name = 'dlayer%d' % layer_idx 105 | d_inc = nf * 8 * 2 106 | dlayer7 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=True, relu=True, dropout=True) 107 | # input is 4 108 | layer_idx -= 1 109 | name = 'dlayer%d' % layer_idx 110 | d_inc = nf * 8 * 2 111 | dlayer6 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=True, relu=True, dropout=True) 112 | # input is 8 113 | layer_idx -= 1 114 | name = 'dlayer%d' % layer_idx 115 | d_inc = nf * 8 * 2 116 | dlayer5 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=True, relu=True, dropout=False) 117 | # input is 16 118 | layer_idx -= 1 119 | name = 'dlayer%d' % layer_idx 120 | d_inc = nf * 8 * 2 121 | dlayer4 = blockUNet(d_inc, nf * 4, name, transposed=True, bn=True, relu=True, dropout=False) 122 | # input is 32 123 | layer_idx -= 1 124 | name = 'dlayer%d' % layer_idx 125 | d_inc = nf * 4 * 2 126 | dlayer3 = blockUNet(d_inc, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False) 127 | # input is 64 128 | layer_idx -= 1 129 | name = 'dlayer%d' % layer_idx 130 | d_inc = nf * 2 * 2 131 | dlayer2 = blockUNet(d_inc, nf, name, transposed=True, bn=True, relu=True, dropout=False) 132 | # input is 128 133 | layer_idx -= 1 134 | name = 'dlayer%d' % layer_idx 135 | dlayer1 = nn.Sequential() 136 | d_inc = nf * 2 137 | dlayer1.add_module('%s_relu' % name, nn.ReLU(inplace=True)) 138 | dlayer1.add_module('%s_tconv' % name, nn.ConvTranspose2d(d_inc, 20, 4, 2, 1, bias=False)) 139 | 140 | dlayerfinal = nn.Sequential() 141 | 142 | dlayerfinal.add_module('%s_conv' % name, nn.Conv2d(24, output_nc, 3, 1, 1, bias=False)) 143 | dlayerfinal.add_module('%s_tanh' % name, 
nn.Tanh()) 144 | 145 | self.conv1010 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm 146 | self.conv1020 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm 147 | self.conv1030 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm 148 | self.conv1040 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm 149 | 150 | self.refine3 = nn.Conv2d(20 + 4, 3, kernel_size=3, stride=1, padding=1) 151 | 152 | self.upsample = F.upsample_nearest 153 | 154 | self.layer1 = layer1 155 | self.layer2 = layer2 156 | self.layer3 = layer3 157 | self.layer4 = layer4 158 | self.layer5 = layer5 159 | self.layer6 = layer6 160 | self.layer7 = layer7 161 | self.layer8 = layer8 162 | self.dlayer8 = dlayer8 163 | self.dlayer7 = dlayer7 164 | self.dlayer6 = dlayer6 165 | self.dlayer5 = dlayer5 166 | self.dlayer4 = dlayer4 167 | self.dlayer3 = dlayer3 168 | self.dlayer2 = dlayer2 169 | self.dlayer1 = dlayer1 170 | self.dlayerfinal = dlayerfinal 171 | self.relu = nn.LeakyReLU(0.2, inplace=True) 172 | 173 | def forward(self, x): 174 | out1 = self.layer1(x) 175 | out2 = self.layer2(out1) 176 | out3 = self.layer3(out2) 177 | out4 = self.layer4(out3) 178 | out5 = self.layer5(out4) 179 | out6 = self.layer6(out5) 180 | out7 = self.layer7(out6) 181 | out8 = self.layer8(out7) 182 | dout8 = self.dlayer8(out8) 183 | dout8_out7 = torch.cat([dout8, out7], 1) 184 | dout7 = self.dlayer7(dout8_out7) 185 | dout7_out6 = torch.cat([dout7, out6], 1) 186 | dout6 = self.dlayer6(dout7_out6) 187 | dout6_out5 = torch.cat([dout6, out5], 1) 188 | dout5 = self.dlayer5(dout6_out5) 189 | dout5_out4 = torch.cat([dout5, out4], 1) 190 | dout4 = self.dlayer4(dout5_out4) 191 | dout4_out3 = torch.cat([dout4, out3], 1) 192 | dout3 = self.dlayer3(dout4_out3) 193 | dout3_out2 = torch.cat([dout3, out2], 1) 194 | dout2 = self.dlayer2(dout3_out2) 195 | dout2_out1 = torch.cat([dout2, out1], 1) 196 | dout1 = self.dlayer1(dout2_out1) 197 | 198 | shape_out = dout1.data.size() 199 | # print(shape_out) 200 | 
shape_out = shape_out[2:4] 201 | 202 | x101 = F.avg_pool2d(dout1, 16) 203 | x102 = F.avg_pool2d(dout1, 8) 204 | x103 = F.avg_pool2d(dout1, 4) 205 | x104 = F.avg_pool2d(dout1, 2) 206 | 207 | x1010 = self.upsample(self.relu(self.conv1010(x101)), size=shape_out) 208 | x1020 = self.upsample(self.relu(self.conv1020(x102)), size=shape_out) 209 | x1030 = self.upsample(self.relu(self.conv1030(x103)), size=shape_out) 210 | x1040 = self.upsample(self.relu(self.conv1040(x104)), size=shape_out) 211 | 212 | dehaze = torch.cat((x1010, x1020, x1030, x1040, dout1), 1) 213 | 214 | dout1 = self.dlayerfinal(dehaze) 215 | 216 | return dout1 217 | 218 | 219 | class G2(nn.Module): 220 | def __init__(self, input_nc, output_nc, nf): 221 | super(G2, self).__init__() 222 | # input is 256 x 256 223 | layer_idx = 1 224 | name = 'layer%d' % layer_idx 225 | layer1 = nn.Sequential() 226 | layer1.add_module(name, nn.Conv2d(input_nc, nf, 4, 2, 1, bias=False)) 227 | # input is 128 x 128 228 | layer_idx += 1 229 | name = 'layer%d' % layer_idx 230 | layer2 = blockUNet(nf, nf * 2, name, transposed=False, bn=True, relu=False, dropout=False) 231 | # input is 64 x 64 232 | layer_idx += 1 233 | name = 'layer%d' % layer_idx 234 | layer3 = blockUNet(nf * 2, nf * 4, name, transposed=False, bn=True, relu=False, dropout=False) 235 | # input is 32 236 | layer_idx += 1 237 | name = 'layer%d' % layer_idx 238 | layer4 = blockUNet(nf * 4, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False) 239 | # input is 16 240 | layer_idx += 1 241 | name = 'layer%d' % layer_idx 242 | layer5 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False) 243 | # input is 8 244 | layer_idx += 1 245 | name = 'layer%d' % layer_idx 246 | layer6 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False) 247 | # input is 4 248 | layer_idx += 1 249 | name = 'layer%d' % layer_idx 250 | layer7 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, 
dropout=False) 251 | # input is 2 x 2 252 | layer_idx += 1 253 | name = 'layer%d' % layer_idx 254 | layer8 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False) 255 | 256 | ## NOTE: decoder 257 | # input is 1 258 | name = 'dlayer%d' % layer_idx 259 | d_inc = nf * 8 260 | dlayer8 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=False, relu=True, dropout=True) 261 | 262 | # import pdb; pdb.set_trace() 263 | # input is 2 264 | layer_idx -= 1 265 | name = 'dlayer%d' % layer_idx 266 | d_inc = nf * 8 * 2 267 | dlayer7 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=True, relu=True, dropout=True) 268 | # input is 4 269 | layer_idx -= 1 270 | name = 'dlayer%d' % layer_idx 271 | d_inc = nf * 8 * 2 272 | dlayer6 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=True, relu=True, dropout=True) 273 | # input is 8 274 | layer_idx -= 1 275 | name = 'dlayer%d' % layer_idx 276 | d_inc = nf * 8 * 2 277 | dlayer5 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=True, relu=True, dropout=False) 278 | # input is 16 279 | layer_idx -= 1 280 | name = 'dlayer%d' % layer_idx 281 | d_inc = nf * 8 * 2 282 | dlayer4 = blockUNet(d_inc, nf * 4, name, transposed=True, bn=True, relu=True, dropout=False) 283 | # input is 32 284 | layer_idx -= 1 285 | name = 'dlayer%d' % layer_idx 286 | d_inc = nf * 4 * 2 287 | dlayer3 = blockUNet(d_inc, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False) 288 | # input is 64 289 | layer_idx -= 1 290 | name = 'dlayer%d' % layer_idx 291 | d_inc = nf * 2 * 2 292 | dlayer2 = blockUNet(d_inc, nf, name, transposed=True, bn=True, relu=True, dropout=False) 293 | # input is 128 294 | layer_idx -= 1 295 | name = 'dlayer%d' % layer_idx 296 | dlayer1 = nn.Sequential() 297 | d_inc = nf * 2 298 | dlayer1.add_module('%s_relu' % name, nn.ReLU(inplace=True)) 299 | dlayer1.add_module('%s_tconv' % name, nn.ConvTranspose2d(d_inc, output_nc, 4, 2, 1, bias=False)) 300 | dlayer1.add_module('%s_tanh' % name, nn.LeakyReLU(0.2, 
inplace=True)) 301 | 302 | self.layer1 = layer1 303 | self.layer2 = layer2 304 | self.layer3 = layer3 305 | self.layer4 = layer4 306 | self.layer5 = layer5 307 | self.layer6 = layer6 308 | self.layer7 = layer7 309 | self.layer8 = layer8 310 | self.ca1 = Attention(64) 311 | self.dlayer8 = dlayer8 312 | self.ca2 = Attention(128) 313 | self.dlayer7 = dlayer7 314 | self.ca3 = Attention(128) 315 | self.dlayer6 = dlayer6 316 | self.dlayer5 = dlayer5 317 | self.dlayer4 = dlayer4 318 | self.dlayer3 = dlayer3 319 | self.dlayer2 = dlayer2 320 | self.dlayer1 = dlayer1 321 | 322 | def forward(self, x): 323 | out1 = self.layer1(x) # [1,8,128,128] 324 | out2 = self.layer2(out1) # [1,16,64,64] 325 | out3 = self.layer3(out2) # [1,32,32,32] 326 | out4 = self.layer4(out3) # [1,64,16,16] 327 | out5 = self.layer5(out4) # [1,64,8,8] 328 | out6 = self.layer6(out5) # [1,64,4,4] 329 | out7 = self.layer7(out6) # [1,64,2,2] 330 | out8 = self.layer8(out7) # [1,64,1,1] 331 | 332 | out8 = self.ca1(out8) 333 | dout8 = self.dlayer8(out8) # [1,64,2,2] 334 | dout8_out7 = torch.cat([dout8, out7], 1) # [1,128,2,2] 335 | dout8_out7 = self.ca2(dout8_out7) 336 | 337 | dout7 = self.dlayer7(dout8_out7) # [1,64,4,4] 338 | dout7_out6 = torch.cat([dout7, out6], 1) # [1,128,4,4] 339 | dout7_out6 = self.ca3(dout7_out6) 340 | 341 | dout6 = self.dlayer6(dout7_out6) # [1,64,8,8] 342 | dout6_out5 = torch.cat([dout6, out5], 1) # [1,128,8,8] 343 | dout5 = self.dlayer5(dout6_out5) # [1,64,16,16] 344 | dout5_out4 = torch.cat([dout5, out4], 1) # [1,128,16,16] 345 | dout4 = self.dlayer4(dout5_out4) # [1,32,32,32] 346 | dout4_out3 = torch.cat([dout4, out3], 1) # [1,64,32,32] 347 | dout3 = self.dlayer3(dout4_out3) # [1,16,64,64] 348 | dout3_out2 = torch.cat([dout3, out2], 1) # [1,32,64,64] 349 | dout2 = self.dlayer2(dout3_out2) # [1,8,128,128] 350 | dout2_out1 = torch.cat([dout2, out1], 1) # [1,16,128,128] 351 | dout1 = self.dlayer1(dout2_out1) # [1,3,256,256] 352 | return dout1 353 | 354 | 355 | class 
BottleneckBlock(nn.Module): 356 | def __init__(self, in_planes, out_planes, dropRate=0.0): 357 | super(BottleneckBlock, self).__init__() 358 | inter_planes = out_planes * 4 359 | self.bn1 = nn.BatchNorm2d(in_planes) 360 | self.relu = nn.ReLU(inplace=True) 361 | self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1, 362 | padding=0, bias=False) 363 | self.bn2 = nn.BatchNorm2d(inter_planes) 364 | self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1, 365 | padding=1, bias=False) 366 | self.droprate = dropRate 367 | 368 | def forward(self, x): 369 | out = self.conv1(self.relu(self.bn1(x))) 370 | if self.droprate > 0: 371 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 372 | out = self.conv2(self.relu(self.bn2(out))) 373 | if self.droprate > 0: 374 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 375 | return torch.cat([x, out], 1) 376 | 377 | 378 | class TransitionBlock(nn.Module): 379 | def __init__(self, in_planes, out_planes, dropRate=0.0): 380 | super(TransitionBlock, self).__init__() 381 | self.bn1 = nn.BatchNorm2d(in_planes) 382 | self.relu = nn.ReLU(inplace=True) 383 | self.conv1 = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=1, stride=1, 384 | padding=0, bias=False) 385 | self.droprate = dropRate 386 | 387 | def forward(self, x): 388 | out = self.conv1(self.relu(self.bn1(x))) 389 | if self.droprate > 0: 390 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training) 391 | return F.upsample_nearest(out, scale_factor=2) 392 | 393 | 394 | class Dense(nn.Module): 395 | def __init__(self): 396 | super(Dense, self).__init__() 397 | 398 | ############# 256-256 ############## 399 | haze_class = models.densenet121(pretrained=True) 400 | 401 | self.conv0 = haze_class.features.conv0 402 | self.norm0 = haze_class.features.norm0 403 | self.relu0 = haze_class.features.relu0 404 | self.pool0 = haze_class.features.pool0 405 | 406 | ############# 
Block1-down 64-64 ############## 407 | self.dense_block1 = haze_class.features.denseblock1 408 | self.trans_block1 = haze_class.features.transition1 409 | 410 | ############# Block2-down 32-32 ############## 411 | self.dense_block2 = haze_class.features.denseblock2 412 | self.trans_block2 = haze_class.features.transition2 413 | 414 | ############# Block3-down 16-16 ############## 415 | self.dense_block3 = haze_class.features.denseblock3 416 | self.trans_block3 = haze_class.features.transition3 417 | 418 | ############# Block4-up 8-8 ############## 419 | 420 | self.dense_block4 = BottleneckBlock(512, 256) 421 | self.trans_block4 = TransitionBlock(768, 128) 422 | 423 | ############# Block5-up 16-16 ############## 424 | self.pa1 = Attention(channel=384) 425 | self.dense_block5 = BottleneckBlock(384, 256) 426 | self.trans_block5 = TransitionBlock(640, 128) 427 | 428 | ############# Block6-up 32-32 ############## 429 | self.pa2 = Attention(channel=256) 430 | self.dense_block6 = BottleneckBlock(256, 128) 431 | self.trans_block6 = TransitionBlock(384, 64) 432 | 433 | ############# Block7-up 64-64 ############## 434 | self.pa3 = Attention(channel=64) 435 | self.dense_block7 = BottleneckBlock(64, 64) 436 | self.trans_block7 = TransitionBlock(128, 32) 437 | 438 | ## 128 X 128 439 | ############# Block8-up c ############## 440 | self.pa4 = Attention(channel=32) 441 | self.dense_block8 = BottleneckBlock(32, 32) 442 | self.trans_block8 = TransitionBlock(64, 16) 443 | 444 | self.conv_refin = nn.Conv2d(19, 20, 3, 1, 1) 445 | self.tanh = nn.Tanh() 446 | 447 | self.conv1010 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm 448 | self.conv1020 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm 449 | self.conv1030 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm 450 | self.conv1040 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm 451 | 452 | self.refine3 = nn.Conv2d(20 + 4, 3, kernel_size=3, stride=1, padding=1) 453 | # self.refine3= 
nn.Conv2d(20+4, 3, kernel_size=7,stride=1,padding=3) 454 | 455 | self.upsample = F.upsample_nearest 456 | 457 | self.relu = nn.LeakyReLU(0.2, inplace=True) 458 | 459 | def forward(self, x): 460 | ## 256x256 x0[1,64,64,64] 461 | x0 = self.pool0(self.relu0(self.norm0(self.conv0(x)))) 462 | 463 | ## 64 X 64 x1[1,256,64,64] 464 | x1 = self.dense_block1(x0) 465 | # print x1.size() x1[1,128,32,32] 466 | x1 = self.trans_block1(x1) 467 | 468 | ### 32x32 x2[1,256,16,16] 469 | x2 = self.trans_block2(self.dense_block2(x1)) 470 | # print x2.size() 471 | 472 | ### 16 X 16 x3[1,512,8,8] 473 | x3 = self.trans_block3(self.dense_block3(x2)) 474 | 475 | # x3=Variable(x3.data,requires_grad=True) 476 | 477 | ## 8 X 8 x4[1,128,16,16] 478 | x4 = self.trans_block4(self.dense_block4(x3)) 479 | ## x42[1,384,16,16] 480 | x42 = torch.cat([x4, x2], 1) 481 | 482 | x42 = self.pa1(x42) 483 | ## 16 X 16 x5[1,128,32,32] 484 | x5 = self.trans_block5(self.dense_block5(x42)) 485 | ##x52[1,256,32,32] 486 | x52 = torch.cat([x5, x1], 1) 487 | ## 32 X 32 x6[1,64,64,64] 488 | 489 | x52 = self.pa2(x52) 490 | x6 = self.trans_block6(self.dense_block6(x52)) 491 | x6 = self.pa3(x6) 492 | ## 64 X 64 x7[1,32,128,128] 493 | x7 = self.trans_block7(self.dense_block7(x6)) 494 | x7 = self.pa4(x7) 495 | ## 128 X 128 x8[1,16,256,256] 496 | x8 = self.trans_block8(self.dense_block8(x7)) 497 | 498 | # print x8.size() 499 | # print x.size() 500 | ##x8[1,19,256,256] 501 | x8 = torch.cat([x8, x], 1) 502 | 503 | # print x8.size() 504 | ##x9[1,20,256,256] 505 | x9 = self.relu(self.conv_refin(x8)) 506 | 507 | shape_out = x9.data.size() 508 | # print(shape_out) 509 | shape_out = shape_out[2:4] 510 | ## x101[1,20,8,8] 511 | ## x102[1,20,16,16] 512 | ## x103[1,20,32,32] 513 | ## x104[1,20,64,64] 514 | x101 = F.avg_pool2d(x9, 32) 515 | x102 = F.avg_pool2d(x9, 16) 516 | x103 = F.avg_pool2d(x9, 8) 517 | x104 = F.avg_pool2d(x9, 4) 518 | ## x1010[1,1,256,256] 519 | ## x1020[1,1,256,256] 520 | x1010 = 
self.upsample(self.relu(self.conv1010(x101)), size=shape_out) 521 | x1020 = self.upsample(self.relu(self.conv1020(x102)), size=shape_out) 522 | x1030 = self.upsample(self.relu(self.conv1030(x103)), size=shape_out) 523 | x1040 = self.upsample(self.relu(self.conv1040(x104)), size=shape_out) 524 | 525 | dehaze = torch.cat((x1010, x1020, x1030, x1040, x9), 1) 526 | dehaze = self.tanh(self.refine3(dehaze)) 527 | 528 | return dehaze 529 | 530 | 531 | class Parameter(nn.Module): 532 | def __init__(self): 533 | super(Parameter, self).__init__() 534 | 535 | self.atp_est = G2(input_nc=3, output_nc=3, nf=8) 536 | 537 | self.tran_dense = Dense() 538 | self.relu = nn.LeakyReLU(0.2, inplace=True) 539 | 540 | self.tanh = nn.Tanh() 541 | 542 | self.refine1 = nn.Conv2d(6, 20, kernel_size=3, stride=1, padding=1) 543 | self.refine2 = nn.Conv2d(20, 20, kernel_size=3, stride=1, padding=1) 544 | self.threshold = nn.Threshold(0.1, 0.1) 545 | 546 | self.conv1010 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm 547 | self.conv1020 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm 548 | self.conv1030 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm 549 | self.conv1040 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm 550 | 551 | self.refine3 = nn.Conv2d(20 + 4, 3, kernel_size=3, stride=1, padding=1) 552 | 553 | self.upsample = F.upsample_nearest 554 | 555 | self.batch1 = nn.BatchNorm2d(20) 556 | 557 | def forward(self, x, y): 558 | tran = self.tran_dense(x) 559 | atp = self.atp_est(x) 560 | 561 | zz = torch.abs((tran)) + (10 ** -10) # transmission map t 562 | shape_out1 = atp.data.size() 563 | 564 | shape_out = shape_out1[2:4] 565 | atp = F.avg_pool2d(atp, shape_out1[2]) # kernel = H: one global airlight value per channel only when H == W 566 | atp = self.upsample(self.relu(atp), size=shape_out) 567 | 568 | haze = (y * zz) + atp * (1 - zz) # re-hazing: I = J * t + A * (1 - t) 569 | dehaze = (x - atp) / zz + atp # dehazing formula: J = (I - A) / t + A 570 | 571 | return haze, dehaze 572 | 
573 | 574 | if __name__ == '__main__': 575 | x = torch.rand((1, 64, 2, 2)) 576 | y = torch.rand((1, 3, 256,
256)) 577 | a=blockUNet(64, 64, "0", transposed=False, bn=True, relu=False, dropout=False) 578 | #a = Parameter() 579 | #a.eval() 580 | b = a(x) 581 | -------------------------------------------------------------------------------- /UBRFC/Util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | 4 | 5 | class ResnetBlock(nn.Module): 6 | 7 | def __init__(self, dim, down=True, first=False, levels=3, bn=False): 8 | super(ResnetBlock, self).__init__() 9 | blocks = [] 10 | for i in range(levels): 11 | blocks.append(Block(dim=dim, bn=bn)) 12 | self.res = nn.Sequential( 13 | *blocks 14 | ) if not first else None 15 | self.downsample_layer = nn.Sequential( 16 | nn.InstanceNorm2d(dim, eps=1e-6) if not bn else nn.BatchNorm2d(dim, eps=1e-6), 17 | nn.ReflectionPad2d(1), 18 | nn.Conv2d(dim, dim * 2, kernel_size=3, stride=2) 19 | ) if down else None 20 | self.stem = nn.Sequential( 21 | nn.ReflectionPad2d(3), 22 | nn.Conv2d(dim, 64, kernel_size=7), 23 | nn.InstanceNorm2d(64, eps=1e-6) 24 | ) if first else None 25 | 26 | def forward(self, x): 27 | if self.stem is not None: 28 | out = self.stem(x) 29 | return out 30 | out = x + self.res(x) 31 | if self.downsample_layer is not None: 32 | out = self.downsample_layer(out) 33 | return out 34 | 35 | 36 | class Block(nn.Module): 37 | 38 | def __init__(self, dim, bn=False): 39 | super(Block, self).__init__() 40 | 41 | conv_block = [] 42 | conv_block += [nn.ReflectionPad2d(1)] 43 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0), 44 | nn.InstanceNorm2d(dim) if not bn else nn.BatchNorm2d(dim, eps=1e-6), 45 | nn.LeakyReLU()] 46 | 47 | conv_block += [nn.ReflectionPad2d(1)] 48 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0), 49 | nn.InstanceNorm2d(dim) if not bn else nn.BatchNorm2d(dim, eps=1e-6)] 50 | 51 | self.conv_block = nn.Sequential(*conv_block) 52 | 53 | def forward(self, x): 54 | out = x + self.conv_block(x) 55 | return out 56 | 57 
| 58 | class PALayer(nn.Module): 59 | 60 | def __init__(self, channel): 61 | super(PALayer, self).__init__() 62 | self.pa = nn.Sequential( 63 | nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True), 64 | nn.ReLU(inplace=True), 65 | nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True), 66 | nn.Sigmoid() 67 | ) 68 | 69 | def forward(self, x): 70 | y = self.pa(x) 71 | return x * y 72 | 73 | 74 | class ConvBlock(nn.Module): 75 | 76 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True, 77 | bn=False, bias=False): 78 | super(ConvBlock, self).__init__() 79 | self.out_channels = out_planes 80 | self.relu = nn.LeakyReLU(inplace=False) if relu else None 81 | self.pad = nn.ReflectionPad2d(padding) 82 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0, 83 | dilation=dilation, groups=groups, bias=bias) 84 | self.bn = nn.InstanceNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if not bn else nn.BatchNorm2d( 85 | out_planes, eps=1e-5, momentum=0.01, affine=True) 86 | 87 | def forward(self, x): 88 | if self.relu is not None: 89 | x = self.relu(x) 90 | x = self.pad(x) 91 | x = self.conv(x) 92 | if self.bn is not None: 93 | x = self.bn(x) 94 | return x 95 | 96 | 97 | class Fusion_Block(nn.Module): 98 | 99 | def __init__(self, channel, bn=False, res=False): 100 | super(Fusion_Block, self).__init__() 101 | self.bn = nn.InstanceNorm2d(channel, eps=1e-5, momentum=0.01, affine=True) if not bn else nn.BatchNorm2d( 102 | channel, eps=1e-5, momentum=0.01, affine=True) 103 | self.merge = nn.Sequential( 104 | ConvBlock(channel, channel, kernel_size=(3, 3), stride=1, padding=1), 105 | nn.LeakyReLU(inplace=False) 106 | ) if not res else None 107 | self.block = ResnetBlock(channel, down=False, levels=2, bn=bn) if res else None 108 | 109 | def forward(self, o, s): 110 | o_bn = self.bn(o) if self.bn is not None else o 111 | x = o_bn + s 112 | if self.merge is not None: 113 | x = 
self.merge(x) 114 | if self.block is not None: 115 | x = self.block(x) 116 | return x 117 | 118 | 119 | class FE_Block(nn.Module): 120 | 121 | def __init__(self, plane1, plane2, res=True): 122 | super(FE_Block, self).__init__() 123 | self.dsc = ConvBlock(plane1, plane2, kernel_size=(3, 3), stride=2, padding=1, relu=False) 124 | 125 | self.merge = nn.Sequential( 126 | ConvBlock(plane2, plane2, kernel_size=(3, 3), stride=1, padding=1), 127 | nn.LeakyReLU(inplace=False) 128 | ) if not res else None 129 | self.block = ResnetBlock(plane2, down=False, levels=2) if res else None 130 | 131 | def forward(self, p, s): 132 | x = s + self.dsc(p) 133 | if self.merge is not None: 134 | x = self.merge(x) 135 | if self.block is not None: 136 | x = self.block(x) 137 | return x 138 | 139 | 140 | class Iter_Downsample(nn.Module): 141 | def __init__(self, ): 142 | super(Iter_Downsample, self).__init__() 143 | self.ds1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) 144 | self.ds2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0) 145 | 146 | def forward(self, x): 147 | x1 = self.ds1(x) 148 | x2 = self.ds2(x1) 149 | return x, x1, x2 150 | 151 | 152 | class CALayer(nn.Module): 153 | 154 | def __init__(self, channel): 155 | super(CALayer, self).__init__() 156 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 157 | self.ca = nn.Sequential( 158 | nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True), 159 | nn.ReLU(inplace=True), 160 | nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True), 161 | nn.Sigmoid() 162 | ) 163 | 164 | def forward(self, x): 165 | y = self.avg_pool(x) 166 | y = self.ca(y) 167 | return x * y 168 | 169 | 170 | class ConvGroups(nn.Module): 171 | 172 | def __init__(self, in_planes, bn=False): 173 | super(ConvGroups, self).__init__() 174 | self.iter_ds = Iter_Downsample() 175 | self.lcb1 = nn.Sequential( 176 | ConvBlock(in_planes, 16, kernel_size=(3, 3), padding=1), ConvBlock(16, 16, kernel_size=1, stride=1), 177 | ConvBlock(16, 16, kernel_size=(3, 3), padding=1, 
bn=bn), 178 | ConvBlock(16, 64, kernel_size=1, bn=bn, relu=False)) 179 | self.lcb2 = nn.Sequential( 180 | ConvBlock(in_planes, 32, kernel_size=(3, 3), padding=1), ConvBlock(32, 32, kernel_size=1), 181 | ConvBlock(32, 32, kernel_size=(3, 3), padding=1), ConvBlock(32, 32, kernel_size=1, stride=1, bn=bn), 182 | ConvBlock(32, 32, kernel_size=(3, 3), padding=1, bn=bn), 183 | ConvBlock(32, 128, kernel_size=1, bn=bn, relu=False)) 184 | self.lcb3 = nn.Sequential( 185 | ConvBlock(in_planes, 64, kernel_size=(3, 3), padding=1), ConvBlock(64, 64, kernel_size=1), 186 | ConvBlock(64, 64, kernel_size=(3, 3), padding=1), ConvBlock(64, 64, kernel_size=1, bn=bn), 187 | ConvBlock(64, 64, kernel_size=(3, 3), padding=1, bn=bn), 188 | ConvBlock(64, 256, kernel_size=1, bn=bn, relu=False)) 189 | 190 | def forward(self, x): 191 | img1, img2, img3 = self.iter_ds(x) 192 | s1 = self.lcb1(img1) 193 | s2 = self.lcb2(img2) 194 | s3 = self.lcb3(img3) 195 | return s1, s2, s3 196 | 197 | 198 | def padding_image(image, h, w): 199 | assert h >= image.size(2) 200 | assert w >= image.size(3) 201 | padding_top = (h - image.size(2)) // 2 202 | padding_down = h - image.size(2) - padding_top 203 | padding_left = (w - image.size(3)) // 2 204 | padding_right = w - image.size(3) - padding_left 205 | out = torch.nn.functional.pad(image, (padding_left, padding_right, padding_top, padding_down), mode='reflect') 206 | return out, padding_left, padding_left + image.size(3), padding_top, padding_top + image.size(2) 207 | 208 | 209 | if __name__ == '__main__': 210 | print(CALayer(3)) -------------------------------------------------------------------------------- /UBRFC/test.py: -------------------------------------------------------------------------------- 1 | import os.path 2 | import numpy as np 3 | from tqdm import tqdm 4 | from Option import opt 5 | from Declaration import device,generator,parameterNet,test_dataloader 6 | from Metrics import ssim, psnr 7 | from torchvision.utils import save_image 8 | import 
math 9 | import torch 10 | def padding_image(image, h, w): 11 | assert h >= image.size(2) 12 | assert w >= image.size(3) 13 | padding_top = (h - image.size(2)) // 2 14 | padding_down = h - image.size(2) - padding_top 15 | padding_left = (w - image.size(3)) // 2 16 | padding_right = w - image.size(3) - padding_left 17 | out = torch.nn.functional.pad(image, (padding_left, padding_right, padding_top, padding_down), mode='reflect') 18 | return out, padding_left, padding_left + image.size(3), padding_top, padding_top + image.size(2) 19 | 20 | def test(generator_net, parameter_net, loader_test): 21 | generator_net.eval() 22 | parameter_net.eval() 23 | clear_psnrs,clear_ssims = [],[] 24 | for i, (inputs, targets,name) in tqdm(enumerate(loader_test), total=len(loader_test), leave=False,desc="Testing"): 25 | 26 | h, w = inputs.shape[2], inputs.shape[3] 27 | 28 | 29 | if h>w: 30 | max_h = int(math.ceil(h / 512)) * 512 31 | max_w = int(math.ceil(w / 512)) * 512 32 | else: 33 | max_h = int(math.ceil(h / 256)) * 256 34 | max_w = int(math.ceil(w / 256)) * 256 35 | inputs, ori_left, ori_right, ori_top, ori_down = padding_image(inputs, max_h, max_w) 36 | 37 | inputs = inputs.to(device) 38 | targets = targets.to(device) 39 | 40 | pred = generator_net(inputs) 41 | _, dehazy_pred = parameter_net(inputs, pred) 42 | 43 | 44 | pred = pred.data[:, :, ori_top:ori_down, ori_left:ori_right] 45 | dehazy_pred = dehazy_pred.data[:, :, ori_top:ori_down, ori_left:ori_right] 46 | 47 | ssim1 = ssim(pred, targets).item() 48 | psnr1 = psnr(pred, targets) 49 | 50 | ssim11 = ssim(dehazy_pred, targets).item() 51 | psnr11 = psnr(dehazy_pred, targets) 52 | 53 | save_image(inputs.data[:1, :, ori_top:ori_down, ori_left:ori_right], os.path.join(opt.out_hazy_path,"%s.png" % name[0])) # crop away the reflect padding 54 | save_image(targets.data[:1], os.path.join(opt.out_gt_path,"%s.png" % name[0])) 55 | 56 | if psnr11 >= psnr1: # keep whichever output scores the higher PSNR 57 | 
save_image(dehazy_pred.data[:1], os.path.join(opt.out_clear_path,"%s.png" % name[0])) 58 | clear_ssims.append(ssim11) 59 | clear_psnrs.append(psnr11) 60 |
else: 61 | save_image(pred.data[:1], os.path.join(opt.out_clear_path,"%s.png" % name[0])) 62 | clear_ssims.append(ssim1) 63 | clear_psnrs.append(psnr1) 64 | 65 | return np.mean(clear_ssims), np.mean(clear_psnrs) 66 | 67 | if __name__ == '__main__': 68 | print(test(generator,parameterNet,test_dataloader)) -------------------------------------------------------------------------------- /UBRFC/train.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from tqdm import tqdm 3 | from test import test 4 | from Declaration import * 5 | 6 | def train(): 7 | global lr 8 | generator.train() 9 | discriminator.train() 10 | parameterNet.train() 11 | psnrs,ssims = [],[] 12 | for epoch in range(iter_num + 1, epochs): 13 | loss_total = 0 14 | for i, (x, y,_) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False, 15 | desc="epoch is %d" % epoch): 16 | 17 | x = x.to(device) 18 | y = y.to(device) 19 | 20 | real_label = torch.ones((x.size()[0], 1, d_out_size, d_out_size), requires_grad=False).to(device) 21 | fake_label = torch.zeros((x.size()[0], 1, d_out_size, d_out_size), requires_grad=False).to(device) 22 | 23 | real_out = discriminator(y) 24 | loss_real_D = criterion(real_out, real_label) 25 | 26 | fake_img = generator(x) 27 | 28 | fake_out = discriminator(fake_img.detach()) 29 | loss_fake_D = criterion(fake_out, fake_label) 30 | 31 | loss_D = (loss_real_D + loss_fake_D) / 2 32 | 33 | optimizer_D.zero_grad() 34 | loss_D.backward() 35 | optimizer_D.step() 36 | #D_loss = loss_D.item() 37 | 38 | #fake_img_ = generator(x) 39 | output = discriminator(fake_img) 40 | haze, dehaze = parameterNet(x, fake_img) 41 | 42 | loss_G = criterion(output, real_label) 43 | loss_P = criterionP(haze, x) # L1 44 | loss_Right = criterionP(fake_img, dehaze.detach()) # pull the generator output toward the dehazed estimate 45 | 46 | loss_ssim = criterionSsim(fake_img, dehaze.detach()) # structural (SSIM) term 47 | loss_C1 = criterionC(fake_img, dehaze.detach(), x, haze.detach()) # 
contrastive loss, lower branch 48 |
loss_C2 = criterionC(haze, x, fake_img.detach(), dehaze.detach()) # contrastive loss, upper branch 49 | 50 | total_loss = loss_G + loss_P + loss_Right + 0.1 * loss_ssim + loss_C1 + loss_C2 51 | 52 | optimizer_T.zero_grad() 53 | total_loss.backward() 54 | optimizer_T.step() 55 | 56 | lr = scheduler_T.get_last_lr()[0] 57 | # G_loss = loss_G.item() 58 | # P_loss = loss_P.item() 59 | loss_total = total_loss.item() 60 | 61 | if (epoch % 10 == 0 and epoch != 0) or epoch>=30: 62 | ### Evaluate on the test set 63 | with torch.no_grad(): 64 | ssim_eval, psnr_eval = test(generator, parameterNet, test_dataloader) 65 | ssims = np.append(ssims, ssim_eval) 66 | psnrs = np.append(psnrs, psnr_eval) 67 | 68 | msg_clear = f"epoch: {epoch}|lr:{lr}|train:[total loss: {loss_total:.4f}]|test:" + f'ssim:{ssim_eval:.4f}| psnr:{psnr_eval:.4f}\n' 69 | print(msg_clear) 70 | 71 | ## Append to the training log 72 | file = open('./Log/log_train/train_' + str(timeStamp) + '.txt', 'a+') 73 | file.write(msg_clear) 74 | file.close() 75 | model = {"generator_net": generator.state_dict(),"discriminator_net": discriminator.state_dict(),"parameterNet_net": parameterNet.state_dict(), 76 | "optimizer_T": optimizer_T.state_dict(),"optimizer_D": optimizer_D.state_dict(),"lr": lr, "epoch": epoch} 77 | 78 | torch.save(model, "./model/last_model.pth") 79 | if psnr_eval == psnrs[np.array(psnrs).argmax()]: 80 | max_msg = "epoch:%d, paired ssim:%f, best psnr:%f\n" % (epoch, ssim_eval, psnr_eval) 81 | files = open('./Log/max_log/max_log' + str(timeStamp) + '.txt', 'a+') 82 | files.write(max_msg) 83 | files.close() 84 | torch.save(model, "./model/epoch%d_%f_%f_best_model.pth"%(epoch,ssim_eval,psnr_eval)) 85 | 86 | scheduler_T.step() 87 | scheduler_D.step() 88 | 89 | 90 | 91 | 92 | 93 | if __name__ == '__main__': 94 | train() 95 | -------------------------------------------------------------------------------- /images/Attention_00.png: 
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lose-Code/UBRFC-Net/1e9880588649ff077ad7a1cc04faa5ae883ce47a/images/Attention_00.png -------------------------------------------------------------------------------- /images/Dense_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lose-Code/UBRFC-Net/1e9880588649ff077ad7a1cc04faa5ae883ce47a/images/Dense_00.png -------------------------------------------------------------------------------- /images/Haze1k_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lose-Code/UBRFC-Net/1e9880588649ff077ad7a1cc04faa5ae883ce47a/images/Haze1k_00.png -------------------------------------------------------------------------------- /images/NH_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lose-Code/UBRFC-Net/1e9880588649ff077ad7a1cc04faa5ae883ce47a/images/NH_00.png -------------------------------------------------------------------------------- /images/Outdoor_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lose-Code/UBRFC-Net/1e9880588649ff077ad7a1cc04faa5ae883ce47a/images/Outdoor_00.png -------------------------------------------------------------------------------- /images/framework_00.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Lose-Code/UBRFC-Net/1e9880588649ff077ad7a1cc04faa5ae883ce47a/images/framework_00.png --------------------------------------------------------------------------------
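The hard-coded channel widths in `Dense.__init__` (`UBRFC/Parameter.py`) follow from two rules:

- A `BottleneckBlock` returns `torch.cat([x, out], 1)`, so it adds `out_planes` channels to its input.
- Each skip connection concatenates an encoder feature map.

The pure-Python sketch below re-derives those constants. It starts from the DenseNet-121 transition widths, which it computes from the standard configuration: 64 initial features, growth rate 32, and dense blocks of 6, 12 and 24 layers, with each transition halving the width. The helper names are hypothetical. The pooled-branch sizes assume the 256x256 input that the code comments describe.

```python
def densenet121_transition_widths(init=64, growth=32, blocks=(6, 12, 24)):
    """Channel width after transition1..3 of DenseNet-121."""
    widths, ch = [], init
    for n in blocks:
        ch = (ch + n * growth) // 2   # dense block adds n * growth, transition halves
        widths.append(ch)
    return widths


def bottleneck_out(in_ch, out_planes):
    # BottleneckBlock concatenates its input with out_planes new channels.
    return in_ch + out_planes


def trace_dense_decoder(img_ch=3, size=256):
    x1, x2, x3 = densenet121_transition_widths()
    t = {}
    t["trans_block4_in"] = bottleneck_out(x3, 256)          # BottleneckBlock(512, 256)
    t["pa1"] = 128 + x2                                      # cat([x4, x2])
    t["trans_block5_in"] = bottleneck_out(t["pa1"], 256)     # BottleneckBlock(384, 256)
    t["pa2"] = 128 + x1                                      # cat([x5, x1])
    t["trans_block6_in"] = bottleneck_out(t["pa2"], 128)     # BottleneckBlock(256, 128)
    t["trans_block7_in"] = bottleneck_out(64, 64)            # x6 has 64 channels
    t["trans_block8_in"] = bottleneck_out(32, 32)            # x7 has 32 channels
    t["conv_refin_in"] = 16 + img_ch                         # cat([x8, x])
    t["refine3_in"] = 20 + 4                                 # x9 plus four 1-channel maps
    t["pool_sizes"] = [size // k for k in (32, 16, 8, 4)]    # avg_pool2d kernels
    return t


print(trace_dense_decoder())
```

Each value should match the corresponding constructor argument in `Dense.__init__`. A mismatch is the usual symptom of changing an encoder backbone without updating the decoder.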
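`Parameter.forward` ties the two branches together through the atmospheric scattering model:

- `haze = y * t + A * (1 - t)` re-hazes the generator output `y`.
- `dehaze = (x - A) / t + A` inverts the same model on the hazy input `x`.

Here `t = |tran| + 1e-10`. `A` is the airlight map averaged over space, passed through `LeakyReLU(0.2)` and broadcast back to full size.

The NumPy sketch below re-expresses those formulas outside PyTorch. It shows that, for a fixed `t` and `A`, the two operations are exact inverses. The function names are hypothetical, and the random arrays are stand-ins for the outputs of `Dense` and `G2`.

```python
import numpy as np


def global_airlight(atp):
    """Per-channel spatial mean of a (C, H, W) airlight map, passed through
    LeakyReLU(0.2) and broadcast back to (C, H, W); mirrors the
    avg_pool2d / LeakyReLU / upsample sequence for a square input."""
    mean = atp.mean(axis=(1, 2), keepdims=True)       # (C, 1, 1)
    mean = np.where(mean >= 0, mean, 0.2 * mean)      # LeakyReLU(0.2)
    return np.broadcast_to(mean, atp.shape)


def transmission(tran):
    # zz = |tran| + 1e-10 keeps the division in dehaze() finite.
    return np.abs(tran) + 1e-10


def rehaze(clear, tran, airlight):
    # I = J * t + A * (1 - t)
    t = transmission(tran)
    return clear * t + airlight * (1.0 - t)


def dehaze(hazy, tran, airlight):
    # J = (I - A) / t + A, the algebraic inverse of rehaze for the same t and A
    t = transmission(tran)
    return (hazy - airlight) / t + airlight


rng = np.random.default_rng(0)
clear = rng.random((3, 8, 8))                       # stand-in for generator output y
tran = rng.uniform(0.3, 1.0, (3, 8, 8))             # stand-in for Dense() output
airlight = global_airlight(rng.random((3, 8, 8)))   # stand-in for G2 output
recovered = dehaze(rehaze(clear, tran, airlight), tran, airlight)
print(np.allclose(recovered, clear))
```

In the forward pass shown above, the only guard against division by a near-zero transmission is the `1e-10` floor. `self.threshold` is constructed but not applied there, so pixels where `Dense` predicts `t` close to 0 can amplify noise in `dehaze`.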
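`Util.py` defines two gating layers that differ only in where the sigmoid gate lives:

- `CALayer` average-pools each channel to a single value and produces one weight per channel.
- `PALayer` applies its 1x1 convolutions at every pixel and produces a single-channel spatial weight that is shared by all channels.

A 1x1 convolution is a matrix product over the channel axis, so both layers can be sketched in NumPy. The function names are hypothetical, and the weights are random stand-ins rather than trained parameters.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def conv1x1(w, b, x):
    # 1x1 convolution on a (C, H, W) map: mixes channels independently at each pixel.
    return np.einsum("oc,chw->ohw", w, x) + b[:, None, None]


def channel_attention(x, w1, b1, w2, b2):
    # CALayer: global average pool -> conv C->C//8 -> ReLU -> conv C//8->C -> sigmoid.
    y = x.mean(axis=(1, 2), keepdims=True)                              # (C, 1, 1)
    y = sigmoid(conv1x1(w2, b2, np.maximum(conv1x1(w1, b1, y), 0.0)))   # (C, 1, 1)
    return x * y


def pixel_attention(x, w1, b1, w2, b2):
    # PALayer: conv C->C//8 -> ReLU -> conv C//8->1 -> sigmoid, at every pixel.
    y = sigmoid(conv1x1(w2, b2, np.maximum(conv1x1(w1, b1, x), 0.0)))   # (1, H, W)
    return x * y


rng = np.random.default_rng(0)
C, H, W = 16, 4, 4
x = rng.standard_normal((C, H, W))
ca = channel_attention(x, rng.standard_normal((C // 8, C)), rng.standard_normal(C // 8),
                       rng.standard_normal((C, C // 8)), rng.standard_normal(C))
pa = pixel_attention(x, rng.standard_normal((C // 8, C)), rng.standard_normal(C // 8),
                     rng.standard_normal((1, C // 8)), rng.standard_normal(1))
print(ca.shape, pa.shape)  # (16, 4, 4) (16, 4, 4)
```

Because each gate lies in (0, 1), both layers can only attenuate features. What differs is the gate's shape:

- The channel gate is constant across each feature map.
- The pixel gate is constant across channels at each location.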
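`test()` in `UBRFC/test.py` processes each input in four steps:

1. Reflect-pad it up to a multiple of 512 when the image is taller than wide, or 256 otherwise.
2. Run both networks.
3. Crop the predictions back using the offsets returned by `padding_image`.

The sketch below re-expresses that logic with `np.pad(mode="reflect")` on an `(N, C, H, W)` array, so the round trip can be checked without PyTorch. `target_size` is a hypothetical name for the inline rounding in `test()`.

One difference is worth knowing. PyTorch's reflect padding rejects pad amounts that are not smaller than the corresponding input dimension, whereas `np.pad` keeps reflecting. Very small inputs can therefore fail in the repository code even though this sketch accepts them.

```python
import math

import numpy as np


def target_size(h, w):
    # Mirrors test(): taller-than-wide images round up to multiples of 512,
    # all others to multiples of 256.
    base = 512 if h > w else 256
    return math.ceil(h / base) * base, math.ceil(w / base) * base


def padding_image(image, h, w):
    """NumPy version of padding_image for an (N, C, H, W) array: centre the
    image, reflect-pad it to (h, w), and return the crop bounds."""
    ih, iw = image.shape[2], image.shape[3]
    assert h >= ih and w >= iw
    top, left = (h - ih) // 2, (w - iw) // 2
    pads = ((0, 0), (0, 0), (top, h - ih - top), (left, w - iw - left))
    out = np.pad(image, pads, mode="reflect")
    return out, left, left + iw, top, top + ih


image = np.random.default_rng(0).random((1, 3, 300, 400))
H, W = target_size(300, 400)
padded, l, r, t, d = padding_image(image, H, W)
print(padded.shape)                                   # (1, 3, 512, 512)
print(np.array_equal(padded[:, :, t:d, l:r], image))  # True
```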
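For every test image, `test()` scores both the generator output `pred` and the physics-based `dehazy_pred` against the ground truth. It keeps whichever has the higher PSNR, and records the SSIM of that same output. The reported averages are therefore a per-image best of two outputs, chosen with the help of the reference image.

The sketch below shows that selection rule. `Metrics.psnr` is not shown in this section, so the `psnr` here is a standard definition that assumes images scaled to [0, 1], not necessarily the repository's version. `select_output` is a hypothetical name.

```python
import numpy as np


def psnr(pred, target, data_range=1.0):
    # Standard PSNR in dB; an assumption, since Metrics.psnr is not shown.
    mse = np.mean((pred - target) ** 2)
    if mse == 0:
        return float("inf")
    return 10.0 * np.log10(data_range ** 2 / mse)


def select_output(pred, dehazy_pred, target):
    """Return (name, psnr) of the better output, preferring dehazy_pred on
    ties, as test() does with `psnr11 >= psnr1`."""
    p_gen, p_phys = psnr(pred, target), psnr(dehazy_pred, target)
    return ("dehazy_pred", p_phys) if p_phys >= p_gen else ("pred", p_gen)


target = np.zeros((3, 4, 4))
print(select_output(np.full_like(target, 0.1), np.full_like(target, 0.05), target))
```

When comparing against methods that produce a single output, the score of each branch on its own (`psnr1` or `psnr11` averaged separately) is the like-for-like figure.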