├── .gitee
│   ├── ISSUE_TEMPLATE.zh-CN.md
│   └── PULL_REQUEST_TEMPLATE.zh-CN.md
├── .gitignore
├── LICENSE
├── README.md
├── UBRFC
│   ├── Attention.py
│   ├── CR.py
│   ├── Dataset.py
│   ├── Declaration.py
│   ├── GAN.py
│   ├── Get_image.py
│   ├── Loss.py
│   ├── Metrics.py
│   ├── Option.py
│   ├── Parameter.py
│   ├── Util.py
│   ├── test.py
│   └── train.py
└── images
    ├── Attention_00.png
    ├── Dense_00.png
    ├── Haze1k_00.png
    ├── NH_00.png
    ├── Outdoor_00.png
    └── framework_00.png
/.gitee/ISSUE_TEMPLATE.zh-CN.md:
--------------------------------------------------------------------------------
 1 | ### What caused this issue?
 2 | 
 3 | 
 4 | 
 5 | ### Steps to reproduce
 6 | 
 7 | 
 8 | 
 9 | ### Error message
10 | 
11 | 
12 | 
13 | 
14 | 
--------------------------------------------------------------------------------
/.gitee/PULL_REQUEST_TEMPLATE.zh-CN.md:
--------------------------------------------------------------------------------
 1 | ### 1. Description (related issues)
 2 | 
 3 | 
 4 | 
 5 | ### 2. Suggested test window and test address
 6 | Suggested test completion date: xxxx.xx.xx
 7 | Production release date: xxxx.xx.xx
 8 | Test address: CI environment / stress-test environment
 9 | Test account:
10 | 
11 | ### 3. Changes
12 | * 3.1 Related PR list
13 | 
14 | * 3.2 Database and deployment notes
15 |   1. Routine update
16 |   2. Restart unicorn
17 |   3. Restart sidekiq
18 |   4. Migration tasks: list any migration tasks, otherwise write "none"
19 |   5. rake scripts: `bundle exec xxx RAILS_ENV=production`; otherwise write "none"
20 | 
21 | * 3.4 Other technical improvements (what was done, what changed)
22 |   - Refactored the xxxx code
23 |   - Optimized the xxxx algorithm
24 | 
25 | 
26 | * 3.5 Deprecation notice (which fields or methods are deprecated?)
27 | 
28 | 
29 | 
30 | * 3.6 Backward-incompatible changes (are any changes not backward compatible?)
31 | 
32 | 
33 | 
34 | ### 4. Developer self-test (what was self-tested? were all smoke cases covered?)
35 | Self-test conclusion:
36 | 
37 | 
38 | ### 5. Test focus (points QA should pay special attention to or might overlook)
39 | Checkpoints:
40 | 
41 | | Requirement name | Affects shared xx module? | Needs xx feature? | Upgrade depends on other sub-products? |
42 | |------|------------|----------|---------------|
43 | | xxx  | No         | Needed   | Not needed    |
44 | |      |            |          |               |
45 | 
46 | Interface testing:
47 | 
48 | Performance testing:
49 | 
50 | Concurrency testing:
51 | 
52 | Other:
53 |
54 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | # Object file
2 | *.o
3 |
4 | # Ada Library Information
5 | *.ali
6 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Academic Free License (“AFL”) v. 3.0
2 |
3 | This Academic Free License (the "License") applies to any original work of authorship (the "Original Work") whose owner (the "Licensor") has placed the following licensing notice adjacent to the copyright notice for the Original Work:
4 |
5 | Licensed under the Academic Free License version 3.0
6 |
7 | 1) Grant of Copyright License. Licensor grants You a worldwide, royalty-free, non-exclusive, sublicensable license, for the duration of the copyright, to do the following:
8 |
9 | a) to reproduce the Original Work in copies, either alone or as part of a collective work;
10 | b) to translate, adapt, alter, transform, modify, or arrange the Original Work, thereby creating derivative works ("Derivative Works") based upon the Original Work;
11 | c) to distribute or communicate copies of the Original Work and Derivative Works to the public, under any license of your choice that does not contradict the terms and conditions, including Licensor’s reserved rights and remedies, in this Academic Free License;
12 | d) to perform the Original Work publicly; and
13 | e) to display the Original Work publicly.
14 |
15 | 2) Grant of Patent License. Licensor grants You a worldwide, royalty-free, non-exclusive, sublicensable license, under patent claims owned or controlled by the Licensor that are embodied in the Original Work as furnished by the Licensor, for the duration of the patents, to make, use, sell, offer for sale, have made, and import the Original Work and Derivative Works.
16 |
17 | 3) Grant of Source Code License. The term "Source Code" means the preferred form of the Original Work for making modifications to it and all available documentation describing how to modify the Original Work. Licensor agrees to provide a machine-readable copy of the Source Code of the Original Work along with each copy of the Original Work that Licensor distributes. Licensor reserves the right to satisfy this obligation by placing a machine-readable copy of the Source Code in an information repository reasonably calculated to permit inexpensive and convenient access by You for as long as Licensor continues to distribute the Original Work.
18 |
19 | 4) Exclusions From License Grant. Neither the names of Licensor, nor the names of any contributors to the Original Work, nor any of their trademarks or service marks, may be used to endorse or promote products derived from this Original Work without express prior permission of the Licensor. Except as expressly stated herein, nothing in this License grants any license to Licensor’s trademarks, copyrights, patents, trade secrets or any other intellectual property. No patent license is granted to make, use, sell, offer for sale, have made, or import embodiments of any patent claims other than the licensed claims defined in Section 2. No license is granted to the trademarks of Licensor even if such marks are included in the Original Work. Nothing in this License shall be interpreted to prohibit Licensor from licensing under terms different from this License any Original Work that Licensor otherwise would have a right to license.
20 |
21 | 5) External Deployment. The term "External Deployment" means the use, distribution, or communication of the Original Work or Derivative Works in any way such that the Original Work or Derivative Works may be used by anyone other than You, whether those works are distributed or communicated to those persons or made available as an application intended for use over a network. As an express condition for the grants of license hereunder, You must treat any External Deployment by You of the Original Work or a Derivative Work as a distribution under section 1(c).
22 |
23 | 6) Attribution Rights. You must retain, in the Source Code of any Derivative Works that You create, all copyright, patent, or trademark notices from the Source Code of the Original Work, as well as any notices of licensing and any descriptive text identified therein as an "Attribution Notice." You must cause the Source Code for any Derivative Works that You create to carry a prominent Attribution Notice reasonably calculated to inform recipients that You have modified the Original Work.
24 |
25 | 7) Warranty of Provenance and Disclaimer of Warranty. Licensor warrants that the copyright in and to the Original Work and the patent rights granted herein by Licensor are owned by the Licensor or are sublicensed to You under the terms of this License with the permission of the contributor(s) of those copyrights and patent rights. Except as expressly stated in the immediately preceding sentence, the Original Work is provided under this License on an "AS IS" BASIS and WITHOUT WARRANTY, either express or implied, including, without limitation, the warranties of non-infringement, merchantability or fitness for a particular purpose. THE ENTIRE RISK AS TO THE QUALITY OF THE ORIGINAL WORK IS WITH YOU. This DISCLAIMER OF WARRANTY constitutes an essential part of this License. No license to the Original Work is granted by this License except under this disclaimer.
26 |
27 | 8) Limitation of Liability. Under no circumstances and under no legal theory, whether in tort (including negligence), contract, or otherwise, shall the Licensor be liable to anyone for any indirect, special, incidental, or consequential damages of any character arising as a result of this License or the use of the Original Work including, without limitation, damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses. This limitation of liability shall not apply to the extent applicable law prohibits such limitation.
28 |
29 | 9) Acceptance and Termination. If, at any time, You expressly assented to this License, that assent indicates your clear and irrevocable acceptance of this License and all of its terms and conditions. If You distribute or communicate copies of the Original Work or a Derivative Work, You must make a reasonable effort under the circumstances to obtain the express assent of recipients to the terms of this License. This License conditions your rights to undertake the activities listed in Section 1, including your right to create Derivative Works based upon the Original Work, and doing so without honoring these terms and conditions is prohibited by copyright law and international treaty. Nothing in this License is intended to affect copyright exceptions and limitations (including “fair use” or “fair dealing”). This License shall terminate immediately and You may no longer exercise any of the rights granted to You by this License upon your failure to honor the conditions in Section 1(c).
30 |
31 | 10) Termination for Patent Action. This License shall terminate automatically and You may no longer exercise any of the rights granted to You by this License as of the date You commence an action, including a cross-claim or counterclaim, against Licensor or any licensee alleging that the Original Work infringes a patent. This termination provision shall not apply for an action alleging patent infringement by combinations of the Original Work with other software or hardware.
32 |
33 | 11) Jurisdiction, Venue and Governing Law. Any action or suit relating to this License may be brought only in the courts of a jurisdiction wherein the Licensor resides or in which Licensor conducts its primary business, and under the laws of that jurisdiction excluding its conflict-of-law provisions. The application of the United Nations Convention on Contracts for the International Sale of Goods is expressly excluded. Any use of the Original Work outside the scope of this License or after its termination shall be subject to the requirements and penalties of copyright or patent law in the appropriate jurisdiction. This section shall survive the termination of this License.
34 |
35 | 12) Attorneys’ Fees. In any action to enforce the terms of this License or seeking damages relating thereto, the prevailing party shall be entitled to recover its costs and expenses, including, without limitation, reasonable attorneys' fees and costs incurred in connection with such action, including any appeal of such action. This section shall survive the termination of this License.
36 |
37 | 13) Miscellaneous. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable.
38 |
39 | 14) Definition of "You" in This License. "You" throughout this License, whether in upper or lower case, means an individual or a legal entity exercising rights under, and complying with all of the terms of, this License. For legal entities, "You" includes any entity that controls, is controlled by, or is under common control with you. For purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
40 |
41 | 15) Right to Use. You may use the Original Work in all ways not otherwise restricted or conditioned by this License or by law, and Licensor promises not to interfere with or be responsible for such uses by You.
42 |
43 | 16) Modification of This License. This License is Copyright © 2005 Lawrence Rosen. Permission is granted to copy, distribute, or communicate this License without modification. Nothing in this License permits You to modify this License as applied to the Original Work or to Derivative Works. However, You may modify the text of this License and copy, distribute or communicate your modified version (the "Modified License") and apply it to other original works of authorship subject to the following conditions: (i) You may not indicate in any way that your Modified License is the "Academic Free License" or "AFL" and you may not use those names in the name of your Modified License; (ii) You must replace the notice specified in the first paragraph above with the notice "Licensed under " or with a notice of your own that is not confusingly similar to the notice in this License; and (iii) You may not claim that your original works are open source software unless your Modified License has been approved by Open Source Initiative (OSI) and You comply with its license review and certification process.
44 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # UBRFC
2 |
3 |
4 |
5 | [![Contributors][contributors-shield]][contributors-url]
6 | [![Forks][forks-shield]][forks-url]
7 | [![Stargazers][stars-shield]][stars-url]
8 | [![Issues][issues-shield]][issues-url]
9 |
10 |
11 |
12 |
13 |
14 |
15 |
16 |
 17 | Contrastive Bidirectional Reconstruction Framework
18 |
19 |
20 |
21 |
22 |
23 | Adaptive Fine-Grained Channel Attention
24 |
25 |
26 | Unsupervised Bidirectional Contrastive Reconstruction and Adaptive Fine-Grained Channel Attention Networks for Image Dehazing
27 |
 28 | Explore the documentation for UBRFC-Net »
29 |
30 |
31 | Check Demo
32 | ·
33 | Report Bug
34 | ·
35 | Pull Request
36 |
37 |
38 |
39 |
40 | ## Contents
41 |
 42 | - [Dependencies](#dependencies)
43 | - [Filetree](#filetree)
44 | - [Pretrained Model](#pretrained-weights-and-dataset)
45 | - [Train](#train)
46 | - [Test](#test)
47 | - [Clone the repo](#clone-the-repo)
48 | - [Qualitative Results](#qualitative-results)
49 | - [Results on RESIDE-Outdoor Dehazing Challenge testing images:](#results-on-reside-outdoor-dehazing-challenge-testing-images)
50 | - [Results on NTIRE 2021 NonHomogeneous Dehazing Challenge testing images:](#results-on-ntire-2021-nonhomogeneous-dehazing-challenge-testing-images)
51 | - [Results on Dense Dehazing Challenge testing images:](#results-on-dense-dehazing-challenge-testing-images)
52 | - [Results on Statehaze1k remote sensing Dehazing Challenge testing images:](#results-on-statehaze1k-remote-sensing-dehazing-challenge-testing-images)
53 | - [Copyright](#copyright)
54 | - [Thanks](#thanks)
55 |
 56 | ### Dependencies
57 |
 58 | 1. PyTorch 1.8.0
59 | 2. Python 3.7.1
60 | 3. CUDA 11.7
61 | 4. Ubuntu 18.04
62 |
63 | ### Filetree
64 | ```
65 | ├─README.md
66 | │
67 | ├─UBRFC
68 | │ Attention.py
69 | │ CR.py
70 | │ Dataset.py
71 | │ GAN.py
72 | │ Get_image.py
73 | │ Loss.py
74 | │ Metrics.py
75 | │ Option.py
76 | │ Parameter.py
77 | │ test.py
78 | │ Util.py
79 | │
80 | ├─images
81 | │ Attention_00.png
82 | │ Dense_00.png
83 | │ framework_00.png
84 | │ NH_00.png
85 | │ Outdoor_00.png
86 | │
87 | └─LICENSE
88 | ```
89 | ### Pretrained Weights and Dataset
90 |
91 | Download our model weights on Google: https://drive.google.com/drive/folders/1fyTzElUd5JvKthlf_1o4PTcoC0mm9ar-?usp=sharing
92 |
93 | Download our test datasets on Google: https://drive.google.com/drive/folders/13Al-It-4srPW7YjS-Iajl54FEtgXNYRC?usp=sharing
94 |
95 | ### Train
96 |
97 | ```shell
98 | python train.py --device 0 --train_root train_path --test_root test_path --batch_size 4
 99 | # for example:
100 | python train.py --device 0 --train_root /home/Datasets/Outdoor/train/ --test_root /home/Datasets/Outdoor/test/ --batch_size 4
101 | ```
102 |
103 | ### Test
104 |
105 | ```shell
106 | python Get_image.py --device GPU_id --test_root test_path --pre_model_path model_path
107 | # for example:
108 | python Get_image.py --device 0 --test_root /home/Dense_hazy/test/ --pre_model_path ./model/best_model.pth
109 | ```
110 |
111 | ### Clone the repo
112 |
113 | ```sh
114 | git clone https://github.com/Lose-Code/UBRFC-Net.git
115 | ```
116 |
117 | ### Qualitative Results
118 |
119 | #### Results on RESIDE-Outdoor Dehazing Challenge testing images
120 |
121 |
![](images/Outdoor_00.png)
122 |
123 |
124 | #### Results on NTIRE 2021 NonHomogeneous Dehazing Challenge testing images
125 |
126 |
![](images/NH_00.png)
127 |
128 |
129 | #### Results on Dense Dehazing Challenge testing images
130 |
131 |
![](images/Dense_00.png)
132 |
133 |
134 | #### Results on Statehaze1k remote sensing Dehazing Challenge testing images
135 |
136 |
![](images/Haze1k_00.png)
137 |
138 |
139 |
140 |
141 |
142 |
143 |
144 |
145 | ### Thanks
146 |
147 |
148 | - [GitHub Emoji Cheat Sheet](https://www.webpagefx.com/tools/emoji-cheat-sheet)
149 | - [Img Shields](https://shields.io)
150 | - [Choose an Open Source License](https://choosealicense.com)
151 | - [GitHub Pages](https://pages.github.com)
152 |
153 |
154 |
155 | [contributors-shield]: https://img.shields.io/github/contributors/Lose-Code/UBRFC-Net.svg?style=flat-square
156 | [contributors-url]: https://github.com/Lose-Code/UBRFC-Net/graphs/contributors
157 | [forks-shield]: https://img.shields.io/github/forks/Lose-Code/UBRFC-Net.svg?style=flat-square
158 | [forks-url]: https://github.com/Lose-Code/UBRFC-Net/network/members
159 | [stars-shield]: https://img.shields.io/github/stars/Lose-Code/UBRFC-Net.svg?style=flat-square
160 | [stars-url]: https://github.com/Lose-Code/UBRFC-Net/stargazers
161 | [issues-shield]: https://img.shields.io/github/issues/Lose-Code/UBRFC-Net.svg?style=flat-square
162 | [issues-url]: https://github.com/Lose-Code/UBRFC-Net/issues
163 | [license-shield]: https://img.shields.io/github/license/Lose-Code/UBRFC-Net.svg?style=flat-square
164 | [license-url]: https://github.com/Lose-Code/UBRFC-Net/blob/master/LICENSE
165 | [linkedin-shield]: https://img.shields.io/badge/-LinkedIn-black.svg?style=flat-square&logo=linkedin&colorB=555
166 | [linkedin-url]: https://linkedin.com/in/shaojintian
167 |
--------------------------------------------------------------------------------
/UBRFC/Attention.py:
--------------------------------------------------------------------------------
1 | import math
2 | import torch
3 | from torch import nn
  4 | #from torchstat import stat  # inspect network parameters
5 |
6 | class Mix(nn.Module):
7 | def __init__(self, m=-0.80):
8 | super(Mix, self).__init__()
  9 |         # single learnable blending weight, squashed by a sigmoid in forward
 10 |         w = torch.nn.Parameter(torch.FloatTensor([m]), requires_grad=True)
 11 |         self.w = w
12 | self.mix_block = nn.Sigmoid()
13 |
14 | def forward(self, fea1, fea2):
15 | mix_factor = self.mix_block(self.w)
16 | out = fea1 * mix_factor.expand_as(fea1) + fea2 * (1 - mix_factor.expand_as(fea2))
17 | return out
18 |
19 |
20 |
21 |
48 | class Attention(nn.Module):
49 | def __init__(self,channel,b=1, gamma=2):
50 | super(Attention, self).__init__()
 51 |         self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
 52 |         # 1-D convolution with a channel-adaptive kernel size
 53 |         t = int(abs((math.log(channel, 2) + b) / gamma))
 54 |         k = t if t % 2 else t + 1
 55 |         self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k / 2), bias=False)
56 | self.fc = nn.Conv2d(channel, channel, 1, padding=0, bias=True)
57 | self.sigmoid = nn.Sigmoid()
58 | self.mix = Mix()
59 |
60 |
61 | def forward(self, input):
62 | x = self.avg_pool(input)
63 | x1 = self.conv1(x.squeeze(-1).transpose(-1, -2)).transpose(-1, -2)#(1,64,1)
64 | x2 = self.fc(x).squeeze(-1).transpose(-1, -2)#(1,1,64)
65 | out1 = torch.sum(torch.matmul(x1,x2),dim=1).unsqueeze(-1).unsqueeze(-1)#(1,64,1,1)
66 | #x1 = x1.transpose(-1, -2).unsqueeze(-1)
67 | out1 = self.sigmoid(out1)
68 | out2 = torch.sum(torch.matmul(x2.transpose(-1, -2),x1.transpose(-1, -2)),dim=1).unsqueeze(-1).unsqueeze(-1)
69 |
70 | #out2 = self.fc(x)
71 | out2 = self.sigmoid(out2)
72 | out = self.mix(out1,out2)
73 | out = self.conv1(out.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
74 | out = self.sigmoid(out)
75 |
76 | return input*out
77 |
78 | if __name__ == '__main__':
79 | input = torch.rand(1,64,256,256)
80 |
81 | A = Attention(channel=64)
82 | #stat(A, input_size=[64, 1, 1])
83 | y = A(input)
84 | print(y.size())
85 |
86 |
87 |
--------------------------------------------------------------------------------
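The `Attention` module above derives its `Conv1d` kernel size adaptively from the channel count, in the spirit of ECA-style channel attention. A minimal standalone sketch of that rule (pure Python, no PyTorch required):

```python
import math

def adaptive_kernel_size(channel: int, b: int = 1, gamma: int = 2) -> int:
    # Same rule as Attention.__init__: derive an odd 1-D kernel size from
    # log2 of the channel count, so wider layers get a wider kernel.
    t = int(abs((math.log(channel, 2) + b) / gamma))
    return t if t % 2 else t + 1

# Wider feature maps get a larger receptive field over the channel descriptor.
for c in (16, 64, 256, 512):
    print(c, adaptive_kernel_size(c))
```

Rounding up to the nearest odd value keeps the `padding=k // 2` convolution centered on each channel position.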
/UBRFC/CR.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | from torch.nn import functional as F
4 | import torch.nn.functional as fnn
5 | from torch.autograd import Variable
6 | import numpy as np
7 | from torchvision import models
8 |
9 | class Vgg19(torch.nn.Module):
10 | def __init__(self, requires_grad=False):
11 | super(Vgg19, self).__init__()
12 | vgg_pretrained_features = models.vgg19(pretrained=True).features
13 | self.slice1 = torch.nn.Sequential()
14 | self.slice2 = torch.nn.Sequential()
15 | self.slice3 = torch.nn.Sequential()
16 | self.slice4 = torch.nn.Sequential()
17 | self.slice5 = torch.nn.Sequential()
18 | for x in range(2):
19 | self.slice1.add_module(str(x), vgg_pretrained_features[x])
20 | for x in range(2, 7):
21 | self.slice2.add_module(str(x), vgg_pretrained_features[x])
22 | for x in range(7, 12):
23 | self.slice3.add_module(str(x), vgg_pretrained_features[x])
24 | for x in range(12, 21):
25 | self.slice4.add_module(str(x), vgg_pretrained_features[x])
26 | for x in range(21, 30):
27 | self.slice5.add_module(str(x), vgg_pretrained_features[x])
28 | if not requires_grad:
29 | for param in self.parameters():
30 | param.requires_grad = False
31 |
32 | def forward(self, X):
33 | h_relu1 = self.slice1(X)
34 | h_relu2 = self.slice2(h_relu1)
35 | h_relu3 = self.slice3(h_relu2)
36 | h_relu4 = self.slice4(h_relu3)
37 | h_relu5 = self.slice5(h_relu4)
38 | return [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
39 |
40 | class ContrastLoss(nn.Module):
41 | def __init__(self, device,ablation=False):
42 |
43 | super(ContrastLoss, self).__init__()
44 | self.vgg = Vgg19().to(device)
45 | self.l1 = nn.L1Loss().to(device)
46 | self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
47 | self.ab = ablation
48 |
49 | def forward(self, x, gt, push1,push2):
50 | a_vgg, p_vgg, n_vgg,m_vgg = self.vgg(x), self.vgg(gt), self.vgg(push1),self.vgg(push2)
51 | loss = 0
52 |
53 | d_ap, d_an = 0, 0
54 | for i in range(len(a_vgg)):
55 | d_ap = self.l1(a_vgg[i], p_vgg[i].detach())
56 | if not self.ab:
57 | d_an = self.l1(a_vgg[i], n_vgg[i].detach())
58 | d_am = self.l1(a_vgg[i], m_vgg[i].detach())
59 | contrastive = d_ap / (d_an + d_am + 1e-7)
60 | else:
61 | contrastive = d_ap
62 |
63 | loss += self.weights[i] * contrastive
64 | return loss
65 |
--------------------------------------------------------------------------------
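`ContrastLoss` pulls the restored image toward the clear reference while pushing it away from two negatives in VGG feature space, via the ratio `d_ap / (d_an + d_am + 1e-7)`. That mechanism can be illustrated without the VGG backbone, using toy feature vectors and a plain L1 distance (a sketch only; the real loss averages this ratio over five weighted VGG stages):

```python
def l1(a, b):
    # Mean absolute difference, standing in for nn.L1Loss on VGG features.
    return sum(abs(x - y) for x, y in zip(a, b)) / len(a)

def contrastive(anchor, positive, neg1, neg2, eps=1e-7):
    # Pull the anchor toward the positive, push it away from both negatives,
    # mirroring d_ap / (d_an + d_am + 1e-7) in ContrastLoss.forward.
    d_ap = l1(anchor, positive)
    d_an = l1(anchor, neg1)
    d_am = l1(anchor, neg2)
    return d_ap / (d_an + d_am + eps)

# Toy "features": a good restoration sits near the clear target and far from
# the negatives, so its contrastive term is small; a poor one is penalized.
good = contrastive([0.9, 0.1], [1.0, 0.0], [0.0, 1.0], [0.5, 0.5])
bad = contrastive([0.5, 0.5], [1.0, 0.0], [0.6, 0.4], [0.4, 0.6])
print(good, bad)
```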
/UBRFC/Dataset.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from PIL import Image
4 | from torch.utils.data import Dataset, DataLoader
5 | from torchvision.transforms import transforms
6 |
7 |
8 | class Datasets(Dataset):
9 |
10 | def __init__(self, root_dir, isTrain=True, x_name='hazy', y_name='clean'):
11 | self.root_dir = root_dir
12 | self.isTrain = isTrain
13 | self.train_X = x_name
14 | self.train_Y = y_name
15 | self.X_dir_list = os.listdir(os.path.join(self.root_dir, self.train_X)) # get list of image paths in domain X
16 | self.Y_dir_list = os.listdir(os.path.join(self.root_dir, self.train_Y)) # get list of image paths in domain Y
17 | self.transforms = self.get_transforms()
18 |
19 | def __len__(self):
20 | return len(self.X_dir_list)
21 |
22 | def __getitem__(self, index):
23 | X_img_name = self.X_dir_list[index % len(self.X_dir_list)]
24 |
25 | if self.isTrain:
26 | ind_Y = random.randint(0, len(self.Y_dir_list) - 1)
27 | Y_img_name = self.Y_dir_list[ind_Y]
28 | else:
29 | assert len(self.X_dir_list) == len(self.Y_dir_list)
30 | Y_img_name = self.Y_dir_list[index % len(self.Y_dir_list)]
31 | #Y_img_name = X_img_name
32 | name = Y_img_name.split('.jpg')[0].split('.png')[0]
33 | X_img = Image.open(os.path.join(self.root_dir, self.train_X, X_img_name))
34 | Y_img = Image.open(os.path.join(self.root_dir, self.train_Y, Y_img_name))
35 | X = self.transforms(X_img)
36 | Y = self.transforms(Y_img)
37 |
38 |
39 | return X, Y,name
40 |
41 | def get_transforms(self, crop_size=256):
42 |
43 | if self.isTrain:
44 | all_transforms = [transforms.RandomCrop(crop_size), transforms.ToTensor()]
45 | else:
46 | all_transforms = [transforms.ToTensor()]
47 |
48 | return transforms.Compose(all_transforms)
49 |
50 |
51 | class CustomDatasetLoader():
52 |
53 | def __init__(self, root_dir,isTrain=True,x_name='hazy', y_name='clean',batch_size=2):
54 | if isTrain:
55 | assert batch_size>=2
56 | self.dataset = Datasets(root_dir,isTrain,x_name,y_name)
57 | self.dataloader = DataLoader(self.dataset, batch_size=batch_size, pin_memory=True, shuffle=True,num_workers=4,drop_last=True)
58 |
59 | def __len__(self):
60 | """Return the number of data in the dataset"""
61 | return len(self.dataset)
62 |
63 | def __iter__(self):
64 | """Return a batch of data"""
65 | for i, data in enumerate(self.dataloader):
66 | yield data
67 |
68 | def load_data(self):
69 | return self
70 |
--------------------------------------------------------------------------------
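`Datasets.__getitem__` recovers the sample name by chaining two `split` calls, so either a `.jpg` or a `.png` suffix is removed without checking which one is present. A tiny sketch of that trick:

```python
def strip_ext(filename: str) -> str:
    # Chained-split trick from Datasets.__getitem__: whichever of '.jpg' or
    # '.png' occurs first truncates the string; names without either
    # extension pass through unchanged.
    return filename.split('.jpg')[0].split('.png')[0]

print(strip_ext('0001.png'), strip_ext('scene_12.jpg'), strip_ext('no_ext'))
```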
/UBRFC/Declaration.py:
--------------------------------------------------------------------------------
1 | import os
2 | import time
3 | import torch
4 | from Option import opt
5 | from Loss import SSIMLoss
6 | from CR import ContrastLoss
7 | from Parameter import Parameter
8 | from Dataset import CustomDatasetLoader
9 | from GAN import Generator, Discriminator
10 |
11 | lr = opt.lr
12 | d_out_size = 30
13 | device = "cuda:"+opt.device
14 | epochs = opt.epochs
15 |
16 | train_dataloader = CustomDatasetLoader(root_dir=opt.train_root, isTrain=True,batch_size=opt.batch_size,y_name='clear').load_data().dataloader
17 | test_dataloader = CustomDatasetLoader(root_dir=opt.test_root, isTrain=False,batch_size=1,y_name='clear').load_data().dataloader
18 |
19 | generator = Generator()
 20 | discriminator = Discriminator()  # outputs [bz, 1, 30, 30]
21 | parameterNet = Parameter()
22 |
23 | criterionSsim = SSIMLoss()
24 | criterion = torch.nn.MSELoss()
25 | criterionP = torch.nn.L1Loss()
26 | criterionC = ContrastLoss(device,True)
27 |
28 | if os.path.exists(opt.model_path) and opt.is_continue:
29 | lr = torch.load(opt.model_path)["lr"]
30 |
31 | optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr)
32 | optimizer_T = torch.optim.Adam([
33 | {'params': generator.parameters(), 'lr': lr},
34 | {'params': parameterNet.parameters(), 'lr': lr}
35 | ])
36 |
37 | timeStamp = time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())
38 |
39 | generator.to(device)
40 | criterion.to(device)
41 | criterionP.to(device)
42 | parameterNet.to(device)
43 | discriminator.to(device)
44 | criterionSsim.to(device)
45 |
46 | scheduler_T = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_T, T_max=epochs, eta_min=0, last_epoch=-1)
47 | scheduler_D = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer_D, T_max=epochs, eta_min=0, last_epoch=-1)
48 |
 49 | if os.path.exists(opt.model_path) and opt.is_continue:
 50 |     print("Loading checkpoint")
 51 |     checkpoint = torch.load(opt.model_path)  # load once instead of once per key
 52 |     generator.load_state_dict(checkpoint["generator_net"])
 53 |     discriminator.load_state_dict(checkpoint["discriminator_net"])
 54 |     parameterNet.load_state_dict(checkpoint["parameterNet_net"])
 55 |     optimizer_T.load_state_dict(checkpoint["optimizer_T"])
 56 |     optimizer_D.load_state_dict(checkpoint["optimizer_D"])
 57 |     iter_num = checkpoint["epoch"]
 58 |     scheduler_T.step()
 59 |     scheduler_D.step()
 60 | else:
 61 |     iter_num = -1
--------------------------------------------------------------------------------
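Both optimizers above anneal their learning rate with `CosineAnnealingLR` over `epochs` steps down to `eta_min=0`. The schedule's closed form can be sketched without PyTorch (the `base` value below is hypothetical; the real starting rate comes from `Option.opt.lr`):

```python
import math

def cosine_lr(base_lr: float, epoch: int, t_max: int, eta_min: float = 0.0) -> float:
    # Closed form of torch.optim.lr_scheduler.CosineAnnealingLR:
    # lr_t = eta_min + (base_lr - eta_min) * (1 + cos(pi * t / T_max)) / 2
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max)) / 2

base = 1e-4  # hypothetical starting value for illustration
print([cosine_lr(base, e, 100) for e in (0, 25, 50, 75, 100)])
```

The rate starts at `base_lr`, reaches half that value at the schedule midpoint, and decays smoothly to `eta_min` at the final epoch.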
/UBRFC/GAN.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import functools
3 | from torch import nn
4 | #from torchsummary import summary
5 |
6 | from Attention import Attention
7 | from Util import PALayer, ConvGroups, FE_Block, Fusion_Block, ResnetBlock, ConvBlock, CALayer
8 |
9 |
10 | class Generator(nn.Module):
11 | def __init__(self, ngf=64, bn=False):
12 | super(Generator, self).__init__()
 13 |         # downsampling
14 | self.down1 = ResnetBlock(3, first=True)
15 |
16 | self.down2 = ResnetBlock(ngf, levels=2)
17 |
18 | self.down3 = ResnetBlock(ngf * 2, levels=2, bn=bn)
19 |
20 | self.res = nn.Sequential(
21 | ResnetBlock(ngf * 4, levels=6, down=False, bn=True)
22 | )
23 |
 24 |         # upsampling
25 |
26 | self.up1 = nn.Sequential(
27 | nn.LeakyReLU(True),
28 | nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=3, stride=2, padding=1,
29 | output_padding=1),
30 | nn.InstanceNorm2d(ngf * 2) if not bn else nn.BatchNorm2d(ngf * 2),
31 | )
32 |
33 | self.up2 = nn.Sequential(
34 | nn.LeakyReLU(True),
35 | nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1),
36 | nn.InstanceNorm2d(ngf),
37 | )
38 |
39 | self.info_up1 = nn.Sequential(
40 | nn.LeakyReLU(True),
41 | nn.ConvTranspose2d(ngf * 4, ngf * 2, kernel_size=3, stride=2, padding=1,
42 | output_padding=1),
43 | nn.InstanceNorm2d(ngf * 2) if not bn else nn.BatchNorm2d(ngf * 2, eps=1e-5),
44 | )
45 |
46 | self.info_up2 = nn.Sequential(
47 | nn.LeakyReLU(True),
48 | nn.ConvTranspose2d(ngf * 2, ngf, kernel_size=3, stride=2, padding=1, output_padding=1),
49 | nn.InstanceNorm2d(ngf) # if not bn else nn.BatchNorm2d(ngf, eps=1e-5),
50 | )
51 |
52 | self.up3 = nn.Sequential(
53 | nn.ReflectionPad2d(3),
54 | nn.Conv2d(ngf, 3, kernel_size=7, padding=0),
55 | nn.Tanh())
56 |
57 | self.pa2 = PALayer(ngf * 4)
58 | self.pa3 = PALayer(ngf * 2)
59 | self.pa4 = PALayer(ngf)
60 |
61 | # self.ca2 = CALayer(ngf * 4)
62 | # self.ca3 = CALayer(ngf * 2)
63 | # self.ca4 = CALayer(ngf)
64 | self.ca2 = Attention(ngf * 4)
65 | self.ca3 = Attention(ngf * 2)
66 | self.ca4 = Attention(ngf)
67 |
68 | self.down_dcp = ConvGroups(3, bn=bn)
69 |
70 | self.fam1 = FE_Block(ngf, ngf)
71 | self.fam2 = FE_Block(ngf, ngf * 2)
72 | self.fam3 = FE_Block(ngf * 2, ngf * 4)
73 |
74 | self.att1 = Fusion_Block(ngf)
75 | self.att2 = Fusion_Block(ngf * 2)
76 | self.att3 = Fusion_Block(ngf * 4, bn=bn)
77 |
78 | self.merge2 = nn.Sequential(
79 | ConvBlock(ngf * 2, ngf * 2, kernel_size=(3, 3), stride=1, padding=1),
80 | nn.LeakyReLU(inplace=False)
81 | )
82 | self.merge3 = nn.Sequential(
83 | ConvBlock(ngf, ngf, kernel_size=(3, 3), stride=1, padding=1),
84 | nn.LeakyReLU(inplace=False)
85 | )
86 |
87 | def forward(self, hazy, img=0, first=True):
88 | if not first:
89 | dcp_down1, dcp_down2, dcp_down3 = self.down_dcp(img)
90 | x_down1 = self.down1(hazy) # [bs, ngf, ngf * 4, ngf * 4][1,64,256,256]
91 |
92 | att1 = self.att1(dcp_down1, x_down1) if not first else x_down1 #[1,64,256,256]
93 |
94 | x_down2 = self.down2(x_down1) # [bs, ngf*2, ngf*2, ngf*2][1,128,128,128]
95 | att2 = self.att2(dcp_down2, x_down2) if not first else None #None
96 | fuse2 = self.fam2(att1, att2) if not first else self.fam2(att1, x_down2)#[1,128,128,128]
97 |
98 | x_down3 = self.down3(x_down2) # [bs, ngf * 4, ngf, ngf]#[1,256,64,64]
99 | att3 = self.att3(dcp_down3, x_down3) if not first else None #None
100 | fuse3 = self.fam3(fuse2, att3) if not first else self.fam3(fuse2, x_down3)#[1,256,64,64]
101 |
102 | x6 = self.pa2(self.ca2(self.res(x_down3)))#[1,256,64,64]
103 |
104 | fuse_up2 = self.info_up1(fuse3)#[1,128,128,128]
105 | fuse_up2 = self.merge2(fuse_up2 + x_down2)#[1,128,128,128]
106 |
107 | fuse_up3 = self.info_up2(fuse_up2)#[1,64,256,256]
108 | fuse_up3 = self.merge3(fuse_up3 + x_down1)#[1,64,256,256]
109 |
110 | x_up2 = self.up1(x6 + fuse3)#[1,128,128,128]
111 | x_up2 = self.ca3(x_up2)#[1,128,128,128]
112 | x_up2 = self.pa3(x_up2)#[1,128,128,128]
113 |
114 | x_up3 = self.up2(x_up2 + fuse_up2)#[1,64,256,256]
115 | x_up3 = self.ca4(x_up3)#[1,64,256,256]
116 | x_up3 = self.pa4(x_up3)#[1,64,256,256]
117 |
118 | x_up4 = self.up3(x_up3 + fuse_up3)#[1,3,256,256]
119 |
120 | return x_up4
121 |
122 |
123 | class Discriminator(nn.Module):
124 | """
125 | Discriminator class
126 | """
127 |
128 | def __init__(self, inp=3, out=1):
129 | """
130 |         Initializes a 70x70 PatchGAN (three stride-2 layers) as the discriminator
131 |
132 | Args:
133 | inp: number of input image channels
134 | out: number of output image channels
135 | """
136 |
137 | super(Discriminator, self).__init__()
138 |
139 | norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
140 |
141 | model = [
142 | nn.Conv2d(inp, 64, kernel_size=4, stride=2, padding=1), # input 3 channels
143 | nn.LeakyReLU(0.2, True),
144 |
145 | nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=True),
146 | norm_layer(128),
147 | nn.LeakyReLU(0.2, True),
148 |
149 | nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=True),
150 | norm_layer(256),
151 | nn.LeakyReLU(0.2, True),
152 |
153 | nn.Conv2d(256, 512, kernel_size=4, stride=1, padding=1, bias=True),
154 | norm_layer(512),
155 | nn.LeakyReLU(0.2, True),
156 |
157 | nn.Conv2d(512, out, kernel_size=4, stride=1, padding=1) # output only 1 channel (prediction map)
158 | ]
159 | self.model = nn.Sequential(*model)
160 |
161 | def forward(self, input):
162 | """
163 | Feed forward the image produced by generator through discriminator
164 |
165 | Args:
166 | input: input image
167 |
168 | Returns:
169 | outputs prediction map with 1 channel
170 | """
171 | result = self.model(input)
172 |
173 | return result
174 |
175 |
176 | class discriminator(nn.Module):
177 | def __init__(self, bn=False, ngf=64):
178 | super(discriminator, self).__init__()
179 | self.net = nn.Sequential(
180 | nn.Conv2d(3, ngf, kernel_size=3, padding=1),
181 | nn.LeakyReLU(0.2),
182 |
183 | nn.ReflectionPad2d(1),
184 | nn.Conv2d(ngf, ngf, kernel_size=3, stride=2, padding=0),
185 | nn.InstanceNorm2d(ngf),
186 | nn.LeakyReLU(0.2),
187 |
188 | nn.ReflectionPad2d(1),
189 | nn.Conv2d(ngf, ngf * 2, kernel_size=3, padding=0),
190 | nn.InstanceNorm2d(ngf * 2),
191 | nn.LeakyReLU(0.2),
192 |
193 | nn.ReflectionPad2d(1),
194 | nn.Conv2d(ngf * 2, ngf * 2, kernel_size=3, stride=2, padding=0),
195 | nn.InstanceNorm2d(ngf * 2),
196 | nn.LeakyReLU(0.2),
197 |
198 | nn.ReflectionPad2d(1),
199 | nn.Conv2d(ngf * 2, ngf * 4, kernel_size=3, padding=0),
200 | nn.InstanceNorm2d(ngf * 4) if not bn else nn.BatchNorm2d(ngf * 4),
201 | nn.LeakyReLU(0.2),
202 |
203 | nn.ReflectionPad2d(1),
204 | nn.Conv2d(ngf * 4, ngf * 4, kernel_size=3, stride=2, padding=0),
205 | nn.InstanceNorm2d(ngf * 4) if not bn else nn.BatchNorm2d(ngf * 4),
206 | nn.LeakyReLU(0.2),
207 |
208 | nn.ReflectionPad2d(1),
209 | nn.Conv2d(ngf * 4, ngf * 8, kernel_size=3, padding=0),
210 | nn.BatchNorm2d(ngf * 8) if bn else nn.InstanceNorm2d(ngf * 8),
211 | nn.LeakyReLU(0.2),
212 |
213 | nn.ReflectionPad2d(1),
214 | nn.Conv2d(ngf * 8, ngf * 8, kernel_size=3, stride=2, padding=0),
215 | nn.BatchNorm2d(ngf * 8) if bn else nn.InstanceNorm2d(ngf * 8),
216 | nn.LeakyReLU(0.2),
217 |
218 | nn.AdaptiveAvgPool2d(1),
219 | nn.Conv2d(ngf * 8, ngf * 16, kernel_size=1),
220 | nn.LeakyReLU(0.2),
221 | nn.Conv2d(ngf * 16, 1, kernel_size=1)
222 |
223 | )
224 |
225 | def forward(self, x):
226 | batch_size = x.size(0)
227 | return torch.sigmoid(self.net(x).view(batch_size))
228 |
229 |
230 |
231 |
232 | if __name__ == '__main__':
233 | net = Generator()
234 | #summary(net, (3,256,256), batch_size=1, device="cpu")
235 | x = torch.randn((1,3,256,256))
236 | #
237 | #print(net)
238 | y = net(x)
239 | print(y.size())
--------------------------------------------------------------------------------
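As a sanity check on the `Discriminator` above, the spatial size of its prediction map follows from plain convolution arithmetic. The sketch below is self-contained pure Python; `conv_out` is a helper introduced here for illustration, not part of the repo:

```python
def conv_out(n, k=4, s=2, p=1):
    # spatial size after one convolution: floor((n + 2p - k) / s) + 1
    return (n + 2 * p - k) // s + 1

n = 256
for stride in (2, 2, 2, 1, 1):  # strides of the five convs in Discriminator
    n = conv_out(n, s=stride)
print(n)  # 30
```

A 256x256 input thus yields a 30x30 one-channel prediction map, which is consistent with the `d_out_size = 30` constant in Get_image.py.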
/UBRFC/Get_image.py:
--------------------------------------------------------------------------------
1 | import os
2 | import math
3 | import time
4 | import torch
5 | import numpy as np
6 | from tqdm import tqdm
7 | from Option import opt
8 | from GAN import Generator
9 | from Metrics import ssim, psnr
10 | from Parameter import Parameter
11 | from Dataset import CustomDatasetLoader
12 | from torchvision.utils import save_image
13 |
14 | lr = opt.lr
15 | d_out_size = 30
16 | device = "cuda:"+opt.device
17 | epochs = opt.epochs
18 |
19 |
20 | test_dataloader = CustomDatasetLoader(root_dir=opt.test_root, isTrain=False, batch_size=1).load_data().dataloader
21 |
22 | generator = Generator()
23 |
24 | parameterNet = Parameter()
25 | generator.to(device)
26 |
27 | parameterNet.to(device)
28 |
29 | model_path = opt.pre_model_path
30 | generator.load_state_dict(torch.load(model_path)["generator_net"])
31 | parameterNet.load_state_dict(torch.load(model_path)["parameterNet_net"])
32 | def padding_image(image, h, w):
33 | assert h >= image.size(2)
34 | assert w >= image.size(3)
35 | padding_top = (h - image.size(2)) // 2
36 | padding_down = h - image.size(2) - padding_top
37 | padding_left = (w - image.size(3)) // 2
38 | padding_right = w - image.size(3) - padding_left
39 | out = torch.nn.functional.pad(image, (padding_left, padding_right, padding_top, padding_down), mode='reflect')
40 | return out, padding_left, padding_left + image.size(3), padding_top, padding_top + image.size(2)
41 |
42 | def test(generator_net, parameter_net, loader_test):
43 |     generator_net.eval()
44 |     parameter_net.eval()
45 |     clear_psnrs, clear_ssims = [], []
46 |     for i, (inputs, targets, name) in tqdm(enumerate(loader_test), total=len(loader_test), leave=False, desc="Testing"):
47 |         #print(name)
48 |         h, w = inputs.shape[2], inputs.shape[3]
49 |         #print('h, w:',h, w)
50 |
51 |         if h > w:
52 |             max_h = int(math.ceil(h / 512)) * 512
53 |             max_w = int(math.ceil(w / 512)) * 512
54 |         else:
55 |             max_h = int(math.ceil(h / 256)) * 256
56 |             max_w = int(math.ceil(w / 256)) * 256
57 |         inputs, ori_left, ori_right, ori_top, ori_down = padding_image(inputs, max_h, max_w)
58 |
59 |         inputs = inputs.to(device)
60 |         targets = targets.to(device)
61 |
62 |         pred = generator_net(inputs)
63 | #print("pred.size:",pred.size())
64 | _, dehazy_pred = parameter_net(inputs, pred)
65 |
66 |
67 | pred = pred.data[:, :, ori_top:ori_down, ori_left:ori_right]
68 | dehazy_pred = dehazy_pred.data[:, :, ori_top:ori_down, ori_left:ori_right]
69 |
70 | ssim1 = ssim(pred, targets).item()
71 | psnr1 = psnr(pred, targets)
72 |
73 | ssim11 = ssim(dehazy_pred, targets).item()
74 | psnr11 = psnr(dehazy_pred, targets)
75 |
76 | save_image(inputs.data[:1], os.path.join(opt.out_hazy_path, "%s.png" % name[0]))
77 | save_image(targets.data[:1], os.path.join(opt.out_gt_path, "%s.png" % name[0]))
78 |
79 | if psnr11 >= psnr1:
80 | save_image(dehazy_pred.data[:1], os.path.join(opt.out_clear_path, "%s.png" % name[0]))
81 | clear_ssims.append(ssim11)
82 | clear_psnrs.append(psnr11)
83 | else:
84 | save_image(pred.data[:1], os.path.join(opt.out_clear_path, "%s.png" % name[0]))
85 | clear_ssims.append(ssim1)
86 | clear_psnrs.append(psnr1)
87 |
88 | return np.mean(clear_ssims), np.mean(clear_psnrs)
89 |
90 | if __name__ == '__main__':
91 | with torch.no_grad():
92 | print(test(generator, parameterNet, test_dataloader))
--------------------------------------------------------------------------------
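The `padding_image` helper above reflect-pads an image to the target size and returns the crop indices that undo the padding. A minimal round-trip check (a sketch assuming PyTorch is installed; the helper is re-stated here so the snippet runs standalone):

```python
import math
import torch
import torch.nn.functional as F

def padding_image(image, h, w):
    # mirror of the helper in Get_image.py
    padding_top = (h - image.size(2)) // 2
    padding_down = h - image.size(2) - padding_top
    padding_left = (w - image.size(3)) // 2
    padding_right = w - image.size(3) - padding_left
    out = F.pad(image, (padding_left, padding_right, padding_top, padding_down), mode='reflect')
    return out, padding_left, padding_left + image.size(3), padding_top, padding_top + image.size(2)

x = torch.randn(1, 3, 300, 421)           # h < w, so test() pads to multiples of 256
max_h = int(math.ceil(300 / 256)) * 256   # 512
max_w = int(math.ceil(421 / 256)) * 256   # 512
padded, left, right, top, down = padding_image(x, max_h, max_w)
assert padded.shape == (1, 3, 512, 512)
# cropping with the returned indices recovers the original image exactly
assert torch.equal(padded[:, :, top:down, left:right], x)
```

This is why `test()` slices `pred` and `dehazy_pred` with `ori_top:ori_down, ori_left:ori_right` before computing metrics.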
/UBRFC/Loss.py:
--------------------------------------------------------------------------------
1 | # Loss functions
2 | from torch import nn
3 | from torchvision import models
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from math import exp
8 | import numpy as np
9 | from torchvision.models import vgg16
10 | import warnings
11 |
12 | warnings.filterwarnings('ignore')
13 |
14 | # Build a 1-D Gaussian distribution vector
15 | def gaussian(window_size, sigma):
16 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
17 | return gauss / gauss.sum()
18 |
19 |
20 | # Create the Gaussian kernel via matrix multiplication of two 1-D Gaussian vectors
21 | # The channel argument expands it to multi-channel (e.g. 3-channel) input
22 | def create_window(window_size, channel=1):
23 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
24 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
25 | window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()
26 | return window
27 |
28 |
29 | # Compute SSIM
30 | # Uses the SSIM formula directly, but local means are computed by convolving with a
31 | # normalized Gaussian kernel instead of plain pixel averaging; variance and covariance
32 | # use Var(X)=E[X^2]-E[X]^2 and cov(X,Y)=E[XY]-E[X]E[Y], with the same kernel as E[.]
33 | def ssim(img1, img2, window_size=11, window=None, size_average=True, full=False, val_range=None):
34 | # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
35 | if val_range is None:
36 | if torch.max(img1) > 128:
37 | max_val = 255
38 | else:
39 | max_val = 1
40 |
41 | if torch.min(img1) < -0.5:
42 | min_val = -1
43 | else:
44 | min_val = 0
45 | L = max_val - min_val
46 | else:
47 | L = val_range
48 |
49 | padd = 0
50 | (_, channel, height, width) = img1.size()
51 | if window is None:
52 | real_size = min(window_size, height, width)
53 | window = create_window(real_size, channel=channel).to(img1.device)
54 |
55 | mu1 = F.conv2d(img1, window, padding=padd, groups=channel)
56 | mu2 = F.conv2d(img2, window, padding=padd, groups=channel)
57 |
58 | mu1_sq = mu1.pow(2)
59 | mu2_sq = mu2.pow(2)
60 | mu1_mu2 = mu1 * mu2
61 |
62 | sigma1_sq = F.conv2d(img1 * img1, window, padding=padd, groups=channel) - mu1_sq
63 | sigma2_sq = F.conv2d(img2 * img2, window, padding=padd, groups=channel) - mu2_sq
64 | sigma12 = F.conv2d(img1 * img2, window, padding=padd, groups=channel) - mu1_mu2
65 |
66 | C1 = (0.01 * L) ** 2
67 | C2 = (0.03 * L) ** 2
68 |
69 | v1 = 2.0 * sigma12 + C2
70 | v2 = sigma1_sq + sigma2_sq + C2
71 | cs = torch.mean(v1 / v2) # contrast sensitivity
72 |
73 | ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
74 |
75 | if size_average:
76 | ret = ssim_map.mean()
77 | else:
78 | ret = ssim_map.mean(1).mean(1).mean(1)
79 |
80 | if full:
81 | return ret, cs
82 | return ret
83 |
84 |
85 | # Classes to re-use window
86 | class SSIMLoss(torch.nn.Module):
87 | def __init__(self, window_size=11, size_average=True, val_range=None):
88 | super(SSIMLoss, self).__init__()
89 | self.window_size = window_size
90 | self.size_average = size_average
91 | self.val_range = val_range
92 |
93 | # Assume 1 channel for SSIM
94 | self.channel = 1
95 | self.window = create_window(window_size)
96 |
97 | def forward(self, img1, img2):
98 | (_, channel, _, _) = img1.size()
99 |
100 | if channel == self.channel and self.window.dtype == img1.dtype:
101 | window = self.window
102 | else:
103 | window = create_window(self.window_size, channel).to(img1.device).type(img1.dtype)
104 | self.window = window
105 | self.channel = channel
106 |
107 | return ssim(img1, img2, window=window, window_size=self.window_size, size_average=self.size_average)
108 |
109 |
110 |
111 |
112 |
113 |
--------------------------------------------------------------------------------
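The Gaussian window that both `ssim` implementations rely on is normalized per channel, so each depthwise convolution computes a local weighted mean. A quick property check, re-stating the two small helpers from Loss.py so the snippet is self-contained:

```python
from math import exp
import torch

# re-statement of gaussian() and create_window() from Loss.py
def gaussian(window_size, sigma):
    gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2))
                          for x in range(window_size)])
    return gauss / gauss.sum()

def create_window(window_size, channel=1):
    _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
    _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
    return _2D_window.expand(channel, 1, window_size, window_size).contiguous()

w = create_window(11, channel=3)
assert w.shape == (3, 1, 11, 11)
# every depthwise kernel sums to 1, so F.conv2d(img, w, groups=channel) is a local weighted mean
assert torch.allclose(w.sum(dim=(2, 3)), torch.ones(3, 1), atol=1e-6)
```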
/UBRFC/Metrics.py:
--------------------------------------------------------------------------------
1 | from math import exp
2 | import math
3 | import numpy as np
4 |
5 | import torch
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 |
9 |
10 |
11 |
12 | def gaussian(window_size, sigma):
13 | gauss = torch.Tensor([exp(-(x - window_size // 2) ** 2 / float(2 * sigma ** 2)) for x in range(window_size)])
14 | return gauss / gauss.sum()
15 |
16 |
17 | def create_window(window_size, channel):
18 | _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
19 | _2D_window = _1D_window.mm(_1D_window.t()).float().unsqueeze(0).unsqueeze(0)
20 |     # torch.mul(a, b) multiplies element-wise; shapes must match, e.g. (1, 2) * (1, 2) -> (1, 2)
21 |     # torch.mm(a, b) is matrix multiplication, e.g. (1, 2) @ (2, 3) -> (1, 3)
22 |     # .t() takes the transpose
23 |     # unsqueeze() prepends two axes here, turning the 2-D window into a 4-D tensor
24 |
25 |     window = _2D_window.expand(channel, 1, window_size, window_size).contiguous()  # Variable is a no-op since PyTorch 0.4
26 |     # expand() broadcasts the window to (channel, 1, window_size, window_size) without copying data
27 |     # contiguous: view() only works on contiguous tensors, so contiguous() is typically paired
28 |     # with transpose/permute, which return non-contiguous views that must be copied before view()
29 |
30 | return window
31 |
32 |
33 | def _ssim(img1, img2, window, window_size, channel, size_average=True):
34 | mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
35 | mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
36 | mu1_sq = mu1.pow(2)
37 | mu2_sq = mu2.pow(2)
38 | mu1_mu2 = mu1 * mu2
39 | sigma1_sq = F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel) - mu1_sq
40 | sigma2_sq = F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel) - mu2_sq
41 | sigma12 = F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel) - mu1_mu2
42 | C1 = 0.01 ** 2
43 | C2 = 0.03 ** 2
44 | ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
45 |
46 | if size_average:
47 | return ssim_map.mean()
48 | else:
49 | return ssim_map.mean(1).mean(1).mean(1)
50 |
51 |
52 | def ssim(img1, img2, window_size=11, size_average=True):
53 | img1 = torch.clamp(img1, min=0, max=1)
54 | img2 = torch.clamp(img2, min=0, max=1)
55 | (_, channel, _, _) = img1.size()
56 | window = create_window(window_size, channel)
57 | if img1.is_cuda:
58 | window = window.cuda(img1.get_device())
59 | window = window.type_as(img1)
60 | return _ssim(img1, img2, window, window_size, channel, size_average)
61 |
62 |
63 | def psnr(pred, gt):
64 | pred = pred.clamp(0, 1).cpu().numpy()
65 | gt = gt.clamp(0, 1).cpu().numpy()
66 | imdff = pred - gt
67 | rmse = math.sqrt(np.mean(imdff ** 2))
68 | if rmse == 0:
69 | return 100
70 | return 20 * math.log10(1.0 / rmse)
71 |
72 |
73 | if __name__ == "__main__":
74 | pass
75 |
--------------------------------------------------------------------------------
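The `psnr` above assumes inputs in [0, 1] (peak signal 1.0) and caps identical images at 100 dB. A worked check, with the function mirrored here so the snippet runs standalone:

```python
import math
import numpy as np
import torch

def psnr(pred, gt):
    # mirror of Metrics.psnr; inputs are assumed to lie in [0, 1]
    pred = pred.clamp(0, 1).cpu().numpy()
    gt = gt.clamp(0, 1).cpu().numpy()
    rmse = math.sqrt(np.mean((pred - gt) ** 2))
    if rmse == 0:
        return 100          # identical images are capped at 100 dB
    return 20 * math.log10(1.0 / rmse)

a = torch.full((1, 3, 8, 8), 0.5)
b = a + 0.1                 # uniform error of 0.1 -> RMSE ~ 0.1
assert psnr(a, a) == 100
assert abs(psnr(a, b) - 20.0) < 1e-4   # 20 * log10(1 / 0.1) = 20 dB
```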
/UBRFC/Option.py:
--------------------------------------------------------------------------------
1 | import os
2 | import argparse
3 |
4 |
5 | class Options:
6 |
7 | def initialize(self, parser):
8 |         parser.add_argument('--train_root', default='/data/wy/datasets/Haze1k/train', help='Training path')
9 | parser.add_argument('--test_root', default='/home/wy/datasets/Haze1k/Haze1k_thin/dataset/test', help='Testing path')
10 | parser.add_argument('--batch_size', type=int, default=4, help='input batch size')
11 | parser.add_argument('--epochs', type=int, default=10000, help='Total number of training')
12 | parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
13 | parser.add_argument('--device', type=str, default='0', help='GPU id')
14 | parser.add_argument('--is_continue', default=False, help='Whether to continue')
15 | parser.add_argument('--model_path', default='./model/last_model.pth', help='Loading model')
16 | parser.add_argument('--pre_model_path', default='./model/best_model.pth', help='Loading model')
17 | parser.add_argument('--out_hazy_path', default='./images/input')
18 | parser.add_argument('--out_gt_path', default='./images/targets')
19 | parser.add_argument('--out_clear_path', default='./images/clear')
20 | parser.add_argument('--out_log_path', default='./Log/log_train')
21 | return parser
22 |
23 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
24 | parser = Options().initialize(parser)
25 | opt, _ = parser.parse_known_args()
26 |
27 |
28 | if not os.path.exists('./model'):
29 | os.makedirs('./model')
30 | if not os.path.exists('./Log/max_log'):
31 | os.makedirs('./Log/max_log')
32 | if not os.path.exists(opt.out_log_path):
33 | os.makedirs(opt.out_log_path)
34 | if not os.path.exists(opt.out_hazy_path):
35 | os.makedirs(opt.out_hazy_path)
36 | if not os.path.exists(opt.out_gt_path):
37 | os.makedirs(opt.out_gt_path)
38 | if not os.path.exists(opt.out_clear_path):
39 | os.makedirs(opt.out_clear_path)
--------------------------------------------------------------------------------
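Option.py uses `parse_known_args`, which collects flags it does not recognize instead of exiting with an error the way `parse_args` would. A minimal sketch of the behavior, mirroring a subset of the real flags:

```python
import argparse

# minimal sketch mirroring a subset of Options.initialize
parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--batch_size', type=int, default=4, help='input batch size')
parser.add_argument('--lr', type=float, default=0.0001, help='Learning rate')
parser.add_argument('--device', type=str, default='0', help='GPU id')

# unrecognized flags are returned in `unknown` rather than raising,
# which is convenient when several scripts share one command line
opt, unknown = parser.parse_known_args(['--lr', '0.0002', '--some_other_flag', '1'])
assert opt.lr == 0.0002
assert opt.batch_size == 4
assert unknown == ['--some_other_flag', '1']
```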
/UBRFC/Parameter.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from Attention import Attention
4 | import torch.nn.functional as F
5 | import torchvision.models as models
6 |
7 |
8 | def conv_block(in_dim, out_dim):
9 | return nn.Sequential(nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1),
10 | nn.ELU(True),
11 | nn.Conv2d(in_dim, in_dim, kernel_size=3, stride=1, padding=1),
12 | nn.ELU(True),
13 | nn.Conv2d(in_dim, out_dim, kernel_size=1, stride=1, padding=0),
14 | nn.AvgPool2d(kernel_size=2, stride=2))
15 |
16 |
17 | def deconv_block(in_dim, out_dim):
18 | return nn.Sequential(nn.Conv2d(in_dim, out_dim, kernel_size=3, stride=1, padding=1),
19 | nn.ELU(True),
20 | nn.Conv2d(out_dim, out_dim, kernel_size=3, stride=1, padding=1),
21 | nn.ELU(True),
22 | nn.UpsamplingNearest2d(scale_factor=2))
23 |
24 |
25 | def blockUNet1(in_c, out_c, name, transposed=False, bn=False, relu=True, dropout=False):
26 | block = nn.Sequential()
27 | if relu:
28 | block.add_module('%s_relu' % name, nn.ReLU(inplace=True))
29 | else:
30 | block.add_module('%s_leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True))
31 | if not transposed:
32 | block.add_module('%s_conv' % name, nn.Conv2d(in_c, out_c, 3, 1, 1, bias=False))
33 | else:
34 | block.add_module('%s_tconv' % name, nn.ConvTranspose2d(in_c, out_c, 3, 1, 1, bias=False))
35 | if bn:
36 | block.add_module('%s_bn' % name, nn.BatchNorm2d(out_c))
37 | if dropout:
38 | block.add_module('%s_dropout' % name, nn.Dropout2d(0.5, inplace=True))
39 | return block
40 |
41 |
42 | def blockUNet(in_c, out_c, name, transposed=False, bn=False, relu=True, dropout=False):
43 | block = nn.Sequential()
44 | if relu:
45 | block.add_module('%s_relu' % name, nn.ReLU(inplace=True))
46 | else:
47 | block.add_module('%s_leakyrelu' % name, nn.LeakyReLU(0.2, inplace=True))
48 | if not transposed:
49 | block.add_module('%s_conv' % name, nn.Conv2d(in_c, out_c, 4, 2, 1, bias=False))
50 | else:
51 | block.add_module('%s_tconv' % name, nn.ConvTranspose2d(in_c, out_c, 4, 2, 1, bias=False))
52 | if bn:
53 | block.add_module('%s_bn' % name, nn.BatchNorm2d(out_c))
54 | if dropout:
55 | block.add_module('%s_dropout' % name, nn.Dropout2d(0.5, inplace=True))
56 | return block
57 |
58 |
59 | class G(nn.Module):
60 | def __init__(self, input_nc, output_nc, nf):
61 | super(G, self).__init__()
62 | # input is 256 x 256
63 | layer_idx = 1
64 | name = 'layer%d' % layer_idx
65 | layer1 = nn.Sequential()
66 | layer1.add_module(name, nn.Conv2d(input_nc, nf, 4, 2, 1, bias=False))
67 | # input is 128 x 128
68 | layer_idx += 1
69 | name = 'layer%d' % layer_idx
70 | layer2 = blockUNet(nf, nf * 2, name, transposed=False, bn=True, relu=False, dropout=False)
71 | # input is 64 x 64
72 | layer_idx += 1
73 | name = 'layer%d' % layer_idx
74 | layer3 = blockUNet(nf * 2, nf * 4, name, transposed=False, bn=True, relu=False, dropout=False)
75 | # input is 32
76 | layer_idx += 1
77 | name = 'layer%d' % layer_idx
78 | layer4 = blockUNet(nf * 4, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False)
79 | # input is 16
80 | layer_idx += 1
81 | name = 'layer%d' % layer_idx
82 | layer5 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False)
83 | # input is 8
84 | layer_idx += 1
85 | name = 'layer%d' % layer_idx
86 | layer6 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False)
87 | # input is 4
88 | layer_idx += 1
89 | name = 'layer%d' % layer_idx
90 | layer7 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False)
91 | # input is 2 x 2
92 | layer_idx += 1
93 | name = 'layer%d' % layer_idx
94 | layer8 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False)
95 |
96 | ## NOTE: decoder
97 | # input is 1
98 | name = 'dlayer%d' % layer_idx
99 | d_inc = nf * 8
100 | dlayer8 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=False, relu=True, dropout=True)
101 |
102 | # input is 2
103 | layer_idx -= 1
104 | name = 'dlayer%d' % layer_idx
105 | d_inc = nf * 8 * 2
106 | dlayer7 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=True, relu=True, dropout=True)
107 | # input is 4
108 | layer_idx -= 1
109 | name = 'dlayer%d' % layer_idx
110 | d_inc = nf * 8 * 2
111 | dlayer6 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=True, relu=True, dropout=True)
112 | # input is 8
113 | layer_idx -= 1
114 | name = 'dlayer%d' % layer_idx
115 | d_inc = nf * 8 * 2
116 | dlayer5 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=True, relu=True, dropout=False)
117 | # input is 16
118 | layer_idx -= 1
119 | name = 'dlayer%d' % layer_idx
120 | d_inc = nf * 8 * 2
121 | dlayer4 = blockUNet(d_inc, nf * 4, name, transposed=True, bn=True, relu=True, dropout=False)
122 | # input is 32
123 | layer_idx -= 1
124 | name = 'dlayer%d' % layer_idx
125 | d_inc = nf * 4 * 2
126 | dlayer3 = blockUNet(d_inc, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False)
127 | # input is 64
128 | layer_idx -= 1
129 | name = 'dlayer%d' % layer_idx
130 | d_inc = nf * 2 * 2
131 | dlayer2 = blockUNet(d_inc, nf, name, transposed=True, bn=True, relu=True, dropout=False)
132 | # input is 128
133 | layer_idx -= 1
134 | name = 'dlayer%d' % layer_idx
135 | dlayer1 = nn.Sequential()
136 | d_inc = nf * 2
137 | dlayer1.add_module('%s_relu' % name, nn.ReLU(inplace=True))
138 | dlayer1.add_module('%s_tconv' % name, nn.ConvTranspose2d(d_inc, 20, 4, 2, 1, bias=False))
139 |
140 | dlayerfinal = nn.Sequential()
141 |
142 | dlayerfinal.add_module('%s_conv' % name, nn.Conv2d(24, output_nc, 3, 1, 1, bias=False))
143 | dlayerfinal.add_module('%s_tanh' % name, nn.Tanh())
144 |
145 | self.conv1010 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm
146 | self.conv1020 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm
147 | self.conv1030 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm
148 | self.conv1040 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm
149 |
150 | self.refine3 = nn.Conv2d(20 + 4, 3, kernel_size=3, stride=1, padding=1)
151 |
152 |         self.upsample = F.interpolate  # nearest by default; F.upsample_nearest is deprecated
153 |
154 | self.layer1 = layer1
155 | self.layer2 = layer2
156 | self.layer3 = layer3
157 | self.layer4 = layer4
158 | self.layer5 = layer5
159 | self.layer6 = layer6
160 | self.layer7 = layer7
161 | self.layer8 = layer8
162 | self.dlayer8 = dlayer8
163 | self.dlayer7 = dlayer7
164 | self.dlayer6 = dlayer6
165 | self.dlayer5 = dlayer5
166 | self.dlayer4 = dlayer4
167 | self.dlayer3 = dlayer3
168 | self.dlayer2 = dlayer2
169 | self.dlayer1 = dlayer1
170 | self.dlayerfinal = dlayerfinal
171 | self.relu = nn.LeakyReLU(0.2, inplace=True)
172 |
173 | def forward(self, x):
174 | out1 = self.layer1(x)
175 | out2 = self.layer2(out1)
176 | out3 = self.layer3(out2)
177 | out4 = self.layer4(out3)
178 | out5 = self.layer5(out4)
179 | out6 = self.layer6(out5)
180 | out7 = self.layer7(out6)
181 | out8 = self.layer8(out7)
182 | dout8 = self.dlayer8(out8)
183 | dout8_out7 = torch.cat([dout8, out7], 1)
184 | dout7 = self.dlayer7(dout8_out7)
185 | dout7_out6 = torch.cat([dout7, out6], 1)
186 | dout6 = self.dlayer6(dout7_out6)
187 | dout6_out5 = torch.cat([dout6, out5], 1)
188 | dout5 = self.dlayer5(dout6_out5)
189 | dout5_out4 = torch.cat([dout5, out4], 1)
190 | dout4 = self.dlayer4(dout5_out4)
191 | dout4_out3 = torch.cat([dout4, out3], 1)
192 | dout3 = self.dlayer3(dout4_out3)
193 | dout3_out2 = torch.cat([dout3, out2], 1)
194 | dout2 = self.dlayer2(dout3_out2)
195 | dout2_out1 = torch.cat([dout2, out1], 1)
196 | dout1 = self.dlayer1(dout2_out1)
197 |
198 | shape_out = dout1.data.size()
199 | # print(shape_out)
200 | shape_out = shape_out[2:4]
201 |
202 | x101 = F.avg_pool2d(dout1, 16)
203 | x102 = F.avg_pool2d(dout1, 8)
204 | x103 = F.avg_pool2d(dout1, 4)
205 | x104 = F.avg_pool2d(dout1, 2)
206 |
207 | x1010 = self.upsample(self.relu(self.conv1010(x101)), size=shape_out)
208 | x1020 = self.upsample(self.relu(self.conv1020(x102)), size=shape_out)
209 | x1030 = self.upsample(self.relu(self.conv1030(x103)), size=shape_out)
210 | x1040 = self.upsample(self.relu(self.conv1040(x104)), size=shape_out)
211 |
212 | dehaze = torch.cat((x1010, x1020, x1030, x1040, dout1), 1)
213 |
214 | dout1 = self.dlayerfinal(dehaze)
215 |
216 | return dout1
217 |
218 |
219 | class G2(nn.Module):
220 | def __init__(self, input_nc, output_nc, nf):
221 | super(G2, self).__init__()
222 | # input is 256 x 256
223 | layer_idx = 1
224 | name = 'layer%d' % layer_idx
225 | layer1 = nn.Sequential()
226 | layer1.add_module(name, nn.Conv2d(input_nc, nf, 4, 2, 1, bias=False))
227 | # input is 128 x 128
228 | layer_idx += 1
229 | name = 'layer%d' % layer_idx
230 | layer2 = blockUNet(nf, nf * 2, name, transposed=False, bn=True, relu=False, dropout=False)
231 | # input is 64 x 64
232 | layer_idx += 1
233 | name = 'layer%d' % layer_idx
234 | layer3 = blockUNet(nf * 2, nf * 4, name, transposed=False, bn=True, relu=False, dropout=False)
235 | # input is 32
236 | layer_idx += 1
237 | name = 'layer%d' % layer_idx
238 | layer4 = blockUNet(nf * 4, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False)
239 | # input is 16
240 | layer_idx += 1
241 | name = 'layer%d' % layer_idx
242 | layer5 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False)
243 | # input is 8
244 | layer_idx += 1
245 | name = 'layer%d' % layer_idx
246 | layer6 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False)
247 | # input is 4
248 | layer_idx += 1
249 | name = 'layer%d' % layer_idx
250 | layer7 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False)
251 | # input is 2 x 2
252 | layer_idx += 1
253 | name = 'layer%d' % layer_idx
254 | layer8 = blockUNet(nf * 8, nf * 8, name, transposed=False, bn=True, relu=False, dropout=False)
255 |
256 | ## NOTE: decoder
257 | # input is 1
258 | name = 'dlayer%d' % layer_idx
259 | d_inc = nf * 8
260 | dlayer8 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=False, relu=True, dropout=True)
261 |
262 | # import pdb; pdb.set_trace()
263 | # input is 2
264 | layer_idx -= 1
265 | name = 'dlayer%d' % layer_idx
266 | d_inc = nf * 8 * 2
267 | dlayer7 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=True, relu=True, dropout=True)
268 | # input is 4
269 | layer_idx -= 1
270 | name = 'dlayer%d' % layer_idx
271 | d_inc = nf * 8 * 2
272 | dlayer6 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=True, relu=True, dropout=True)
273 | # input is 8
274 | layer_idx -= 1
275 | name = 'dlayer%d' % layer_idx
276 | d_inc = nf * 8 * 2
277 | dlayer5 = blockUNet(d_inc, nf * 8, name, transposed=True, bn=True, relu=True, dropout=False)
278 | # input is 16
279 | layer_idx -= 1
280 | name = 'dlayer%d' % layer_idx
281 | d_inc = nf * 8 * 2
282 | dlayer4 = blockUNet(d_inc, nf * 4, name, transposed=True, bn=True, relu=True, dropout=False)
283 | # input is 32
284 | layer_idx -= 1
285 | name = 'dlayer%d' % layer_idx
286 | d_inc = nf * 4 * 2
287 | dlayer3 = blockUNet(d_inc, nf * 2, name, transposed=True, bn=True, relu=True, dropout=False)
288 | # input is 64
289 | layer_idx -= 1
290 | name = 'dlayer%d' % layer_idx
291 | d_inc = nf * 2 * 2
292 | dlayer2 = blockUNet(d_inc, nf, name, transposed=True, bn=True, relu=True, dropout=False)
293 | # input is 128
294 | layer_idx -= 1
295 | name = 'dlayer%d' % layer_idx
296 | dlayer1 = nn.Sequential()
297 | d_inc = nf * 2
298 | dlayer1.add_module('%s_relu' % name, nn.ReLU(inplace=True))
299 | dlayer1.add_module('%s_tconv' % name, nn.ConvTranspose2d(d_inc, output_nc, 4, 2, 1, bias=False))
300 | dlayer1.add_module('%s_tanh' % name, nn.LeakyReLU(0.2, inplace=True))
301 |
302 | self.layer1 = layer1
303 | self.layer2 = layer2
304 | self.layer3 = layer3
305 | self.layer4 = layer4
306 | self.layer5 = layer5
307 | self.layer6 = layer6
308 | self.layer7 = layer7
309 | self.layer8 = layer8
310 | self.ca1 = Attention(64)
311 | self.dlayer8 = dlayer8
312 | self.ca2 = Attention(128)
313 | self.dlayer7 = dlayer7
314 | self.ca3 = Attention(128)
315 | self.dlayer6 = dlayer6
316 | self.dlayer5 = dlayer5
317 | self.dlayer4 = dlayer4
318 | self.dlayer3 = dlayer3
319 | self.dlayer2 = dlayer2
320 | self.dlayer1 = dlayer1
321 |
322 | def forward(self, x):
323 | out1 = self.layer1(x) # [1,8,128,128]
324 | out2 = self.layer2(out1) # [1,16,64,64]
325 | out3 = self.layer3(out2) # [1,32,32,32]
326 | out4 = self.layer4(out3) # [1,64,16,16]
327 | out5 = self.layer5(out4) # [1,64,8,8]
328 | out6 = self.layer6(out5) # [1,64,4,4]
329 | out7 = self.layer7(out6) # [1,64,2,2]
330 | out8 = self.layer8(out7) # [1,64,1,1]
331 |
332 | out8 = self.ca1(out8)
333 | dout8 = self.dlayer8(out8) # [1,64,2,2]
334 | dout8_out7 = torch.cat([dout8, out7], 1) # [1,128,2,2]
335 | dout8_out7 = self.ca2(dout8_out7)
336 |
337 | dout7 = self.dlayer7(dout8_out7) # [1,64,4,4]
338 | dout7_out6 = torch.cat([dout7, out6], 1) # [1,128,4,4]
339 | dout7_out6 = self.ca3(dout7_out6)
340 |
341 | dout6 = self.dlayer6(dout7_out6) # [1,64,8,8]
342 | dout6_out5 = torch.cat([dout6, out5], 1) # [1,128,8,8]
343 | dout5 = self.dlayer5(dout6_out5) # [1,64,16,16]
344 | dout5_out4 = torch.cat([dout5, out4], 1) # [1,128,16,16]
345 | dout4 = self.dlayer4(dout5_out4) # [1,32,32,32]
346 | dout4_out3 = torch.cat([dout4, out3], 1) # [1,64,32,32]
347 | dout3 = self.dlayer3(dout4_out3) # [1,16,64,64]
348 | dout3_out2 = torch.cat([dout3, out2], 1) # [1,32,64,64]
349 | dout2 = self.dlayer2(dout3_out2) # [1,8,128,128]
350 | dout2_out1 = torch.cat([dout2, out1], 1) # [1,16,128,128]
351 | dout1 = self.dlayer1(dout2_out1) # [1,3,256,256]
352 | return dout1
353 |
354 |
355 | class BottleneckBlock(nn.Module):
356 | def __init__(self, in_planes, out_planes, dropRate=0.0):
357 | super(BottleneckBlock, self).__init__()
358 | inter_planes = out_planes * 4
359 | self.bn1 = nn.BatchNorm2d(in_planes)
360 | self.relu = nn.ReLU(inplace=True)
361 | self.conv1 = nn.Conv2d(in_planes, inter_planes, kernel_size=1, stride=1,
362 | padding=0, bias=False)
363 | self.bn2 = nn.BatchNorm2d(inter_planes)
364 | self.conv2 = nn.Conv2d(inter_planes, out_planes, kernel_size=3, stride=1,
365 | padding=1, bias=False)
366 | self.droprate = dropRate
367 |
368 | def forward(self, x):
369 | out = self.conv1(self.relu(self.bn1(x)))
370 | if self.droprate > 0:
371 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
372 | out = self.conv2(self.relu(self.bn2(out)))
373 | if self.droprate > 0:
374 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
375 | return torch.cat([x, out], 1)
376 |
377 |
378 | class TransitionBlock(nn.Module):
379 | def __init__(self, in_planes, out_planes, dropRate=0.0):
380 | super(TransitionBlock, self).__init__()
381 | self.bn1 = nn.BatchNorm2d(in_planes)
382 | self.relu = nn.ReLU(inplace=True)
383 | self.conv1 = nn.ConvTranspose2d(in_planes, out_planes, kernel_size=1, stride=1,
384 | padding=0, bias=False)
385 | self.droprate = dropRate
386 |
387 | def forward(self, x):
388 | out = self.conv1(self.relu(self.bn1(x)))
389 | if self.droprate > 0:
390 | out = F.dropout(out, p=self.droprate, inplace=False, training=self.training)
391 |         return F.interpolate(out, scale_factor=2, mode='nearest')  # F.upsample_nearest is deprecated
392 |
393 |
394 | class Dense(nn.Module):
395 | def __init__(self):
396 | super(Dense, self).__init__()
397 |
398 | ############# 256-256 ##############
399 | haze_class = models.densenet121(pretrained=True)
400 |
401 | self.conv0 = haze_class.features.conv0
402 | self.norm0 = haze_class.features.norm0
403 | self.relu0 = haze_class.features.relu0
404 | self.pool0 = haze_class.features.pool0
405 |
406 | ############# Block1-down 64-64 ##############
407 | self.dense_block1 = haze_class.features.denseblock1
408 | self.trans_block1 = haze_class.features.transition1
409 |
410 | ############# Block2-down 32-32 ##############
411 | self.dense_block2 = haze_class.features.denseblock2
412 | self.trans_block2 = haze_class.features.transition2
413 |
414 | ############# Block3-down 16-16 ##############
415 | self.dense_block3 = haze_class.features.denseblock3
416 | self.trans_block3 = haze_class.features.transition3
417 |
418 | ############# Block4-up 8-8 ##############
419 |
420 | self.dense_block4 = BottleneckBlock(512, 256)
421 | self.trans_block4 = TransitionBlock(768, 128)
422 |
423 | ############# Block5-up 16-16 ##############
424 | self.pa1 = Attention(channel=384)
425 | self.dense_block5 = BottleneckBlock(384, 256)
426 | self.trans_block5 = TransitionBlock(640, 128)
427 |
428 | ############# Block6-up 32-32 ##############
429 | self.pa2 = Attention(channel=256)
430 | self.dense_block6 = BottleneckBlock(256, 128)
431 | self.trans_block6 = TransitionBlock(384, 64)
432 |
433 | ############# Block7-up 64-64 ##############
434 | self.pa3 = Attention(channel=64)
435 | self.dense_block7 = BottleneckBlock(64, 64)
436 | self.trans_block7 = TransitionBlock(128, 32)
437 |
438 | ## 128 X 128
439 | ############# Block8-up c ##############
440 | self.pa4 = Attention(channel=32)
441 | self.dense_block8 = BottleneckBlock(32, 32)
442 | self.trans_block8 = TransitionBlock(64, 16)
443 |
444 | self.conv_refin = nn.Conv2d(19, 20, 3, 1, 1)
445 | self.tanh = nn.Tanh()
446 |
447 | self.conv1010 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm
448 | self.conv1020 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm
449 | self.conv1030 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm
450 | self.conv1040 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm
451 |
452 | self.refine3 = nn.Conv2d(20 + 4, 3, kernel_size=3, stride=1, padding=1)
453 | # self.refine3= nn.Conv2d(20+4, 3, kernel_size=7,stride=1,padding=3)
454 |
455 | self.upsample = F.upsample_nearest
456 |
457 | self.relu = nn.LeakyReLU(0.2, inplace=True)
458 |
459 | def forward(self, x):
460 | ## 256x256 x0[1,64,64,64]
461 | x0 = self.pool0(self.relu0(self.norm0(self.conv0(x))))
462 |
463 | ## 64 X 64 x1[1,256,64,64]
464 | x1 = self.dense_block1(x0)
465 | # print x1.size() x1[1,128,32,32]
466 | x1 = self.trans_block1(x1)
467 |
468 | ### 32x32 x2[1,256,16,16]
469 | x2 = self.trans_block2(self.dense_block2(x1))
470 | # print x2.size()
471 |
472 | ### 16 X 16 x3[1,512,8,8]
473 | x3 = self.trans_block3(self.dense_block3(x2))
474 |
475 | # x3=Variable(x3.data,requires_grad=True)
476 |
477 | ## 8 X 8 x4[1,128,16,16]
478 | x4 = self.trans_block4(self.dense_block4(x3))
479 | ## x42[1,384,16,16]
480 | x42 = torch.cat([x4, x2], 1)
481 |
482 | x42 = self.pa1(x42)
483 | ## 16 X 16 x5[1,128,32,32]
484 | x5 = self.trans_block5(self.dense_block5(x42))
485 | ##x52[1,256,32,32]
486 | x52 = torch.cat([x5, x1], 1)
487 | ## 32 X 32 x6[1,64,64,64]
488 |
489 | x52 = self.pa2(x52)
490 | x6 = self.trans_block6(self.dense_block6(x52))
491 | x6 = self.pa3(x6)
492 | ## 64 X 64 x7[1,32,128,128]
493 | x7 = self.trans_block7(self.dense_block7(x6))
494 | x7 = self.pa4(x7)
495 | ## 128 X 128 x8[1,16,256,256]
496 | x8 = self.trans_block8(self.dense_block8(x7))
497 |
498 | # print x8.size()
499 | # print x.size()
500 | ##x8[1,19,256,256]
501 | x8 = torch.cat([x8, x], 1)
502 |
503 | # print x8.size()
504 | ##x9[1,20,256,256]
505 | x9 = self.relu(self.conv_refin(x8))
506 |
507 | shape_out = x9.data.size()
508 | # print(shape_out)
509 | shape_out = shape_out[2:4]
510 | ## x101[1,20,8,8]
511 | ## x102[1,20,16,16]
512 | ## x103[1,20,32,32]
513 | ## x104[1,20,64,64]
514 | x101 = F.avg_pool2d(x9, 32)
515 | x102 = F.avg_pool2d(x9, 16)
516 | x103 = F.avg_pool2d(x9, 8)
517 | x104 = F.avg_pool2d(x9, 4)
518 | ## x1010[1,1,256,256]
519 | ## x1020[1,1,256,256]
520 | x1010 = self.upsample(self.relu(self.conv1010(x101)), size=shape_out)
521 | x1020 = self.upsample(self.relu(self.conv1020(x102)), size=shape_out)
522 | x1030 = self.upsample(self.relu(self.conv1030(x103)), size=shape_out)
523 | x1040 = self.upsample(self.relu(self.conv1040(x104)), size=shape_out)
524 |
525 | dehaze = torch.cat((x1010, x1020, x1030, x1040, x9), 1)
526 | dehaze = self.tanh(self.refine3(dehaze))
527 |
528 | return dehaze
529 |
530 |
531 | class Parameter(nn.Module):
532 | def __init__(self):
533 | super(Parameter, self).__init__()
534 |
535 | self.atp_est = G2(input_nc=3, output_nc=3, nf=8)
536 |
537 | self.tran_dense = Dense()
538 | self.relu = nn.LeakyReLU(0.2, inplace=True)
539 |
540 | self.tanh = nn.Tanh()
541 |
542 | self.refine1 = nn.Conv2d(6, 20, kernel_size=3, stride=1, padding=1)
543 | self.refine2 = nn.Conv2d(20, 20, kernel_size=3, stride=1, padding=1)
544 | self.threshold = nn.Threshold(0.1, 0.1)
545 |
546 | self.conv1010 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm
547 | self.conv1020 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm
548 | self.conv1030 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm
549 | self.conv1040 = nn.Conv2d(20, 1, kernel_size=1, stride=1, padding=0) # 1mm
550 |
551 | self.refine3 = nn.Conv2d(20 + 4, 3, kernel_size=3, stride=1, padding=1)
552 |
553 | self.upsample = F.upsample_nearest
554 |
555 | self.batch1 = nn.BatchNorm2d(20)
556 |
557 | def forward(self, x, y):
558 | tran = self.tran_dense(x)
559 | atp = self.atp_est(x)
560 |
561 |         zz = torch.abs(tran) + 1e-10  # transmission map t, kept away from zero
562 | shape_out1 = atp.data.size()
563 |
564 | shape_out = shape_out1[2:4]
565 | atp = F.avg_pool2d(atp, shape_out1[2])
566 | atp = self.upsample(self.relu(atp), size=shape_out)
567 |
568 | haze = (y * zz) + atp * (1 - zz)
569 |         dehaze = (x - atp) / zz + atp  # dehazing formula: J = (I - A) / t + A
570 |
571 | return haze, dehaze
572 |
573 |
574 | if __name__ == '__main__':
575 | x = torch.rand((1, 64, 2, 2))
576 | y = torch.rand((1, 3, 256, 256))
577 |     a = blockUNet(64, 64, "0", transposed=False, bn=True, relu=False, dropout=False)
578 | #a = Parameter()
579 | #a.eval()
580 | b = a(x)
581 |
--------------------------------------------------------------------------------
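`Parameter.forward` above inverts the atmospheric scattering model: a hazy image is I = J·t + A·(1 − t), so the clear image is recovered as J = (I − A)/t + A. A minimal numpy sketch of that round trip (numpy stands in for the torch tensors; shapes and values are illustrative only):

```python
import numpy as np

# Atmospheric scattering model used in Parameter.forward:
#   hazy  I = J * t + A * (1 - t)   (the `haze` branch)
#   clear J = (I - A) / t + A       (the `dehaze` branch)
rng = np.random.default_rng(0)
J = rng.random((3, 4, 4))                      # latent clear image
A = 0.8                                        # global atmospheric light
t = np.clip(rng.random((1, 4, 4)), 0.1, 1.0)   # transmission map

I = J * t + A * (1 - t)             # synthesize the hazy image
J_rec = (I - A) / (t + 1e-10) + A   # invert it, with the same 1e-10 guard as the code
```

The small additive guard on `t` mirrors `zz = torch.abs(tran) + 1e-10` and only perturbs the recovery at the 1e-9 level.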
/UBRFC/Util.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch import nn
3 |
4 |
5 | class ResnetBlock(nn.Module):
6 |
7 | def __init__(self, dim, down=True, first=False, levels=3, bn=False):
8 | super(ResnetBlock, self).__init__()
9 | blocks = []
10 | for i in range(levels):
11 | blocks.append(Block(dim=dim, bn=bn))
12 | self.res = nn.Sequential(
13 | *blocks
14 | ) if not first else None
15 | self.downsample_layer = nn.Sequential(
16 | nn.InstanceNorm2d(dim, eps=1e-6) if not bn else nn.BatchNorm2d(dim, eps=1e-6),
17 | nn.ReflectionPad2d(1),
18 | nn.Conv2d(dim, dim * 2, kernel_size=3, stride=2)
19 | ) if down else None
20 | self.stem = nn.Sequential(
21 | nn.ReflectionPad2d(3),
22 | nn.Conv2d(dim, 64, kernel_size=7),
23 | nn.InstanceNorm2d(64, eps=1e-6)
24 | ) if first else None
25 |
26 | def forward(self, x):
27 | if self.stem is not None:
28 | out = self.stem(x)
29 | return out
30 | out = x + self.res(x)
31 | if self.downsample_layer is not None:
32 | out = self.downsample_layer(out)
33 | return out
34 |
35 |
36 | class Block(nn.Module):
37 |
38 | def __init__(self, dim, bn=False):
39 | super(Block, self).__init__()
40 |
41 | conv_block = []
42 | conv_block += [nn.ReflectionPad2d(1)]
43 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0),
44 | nn.InstanceNorm2d(dim) if not bn else nn.BatchNorm2d(dim, eps=1e-6),
45 | nn.LeakyReLU()]
46 |
47 | conv_block += [nn.ReflectionPad2d(1)]
48 | conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0),
49 | nn.InstanceNorm2d(dim) if not bn else nn.BatchNorm2d(dim, eps=1e-6)]
50 |
51 | self.conv_block = nn.Sequential(*conv_block)
52 |
53 | def forward(self, x):
54 | out = x + self.conv_block(x)
55 | return out
56 |
57 |
58 | class PALayer(nn.Module):
59 |
60 | def __init__(self, channel):
61 | super(PALayer, self).__init__()
62 | self.pa = nn.Sequential(
63 | nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
64 | nn.ReLU(inplace=True),
65 | nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True),
66 | nn.Sigmoid()
67 | )
68 |
69 | def forward(self, x):
70 | y = self.pa(x)
71 | return x * y
72 |
73 |
74 | class ConvBlock(nn.Module):
75 |
76 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=True,
77 | bn=False, bias=False):
78 | super(ConvBlock, self).__init__()
79 | self.out_channels = out_planes
80 | self.relu = nn.LeakyReLU(inplace=False) if relu else None
81 | self.pad = nn.ReflectionPad2d(padding)
82 | self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=0,
83 | dilation=dilation, groups=groups, bias=bias)
84 | self.bn = nn.InstanceNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if not bn else nn.BatchNorm2d(
85 | out_planes, eps=1e-5, momentum=0.01, affine=True)
86 |
87 | def forward(self, x):
88 | if self.relu is not None:
89 | x = self.relu(x)
90 | x = self.pad(x)
91 | x = self.conv(x)
92 | if self.bn is not None:
93 | x = self.bn(x)
94 | return x
95 |
96 |
97 | class Fusion_Block(nn.Module):
98 |
99 | def __init__(self, channel, bn=False, res=False):
100 | super(Fusion_Block, self).__init__()
101 | self.bn = nn.InstanceNorm2d(channel, eps=1e-5, momentum=0.01, affine=True) if not bn else nn.BatchNorm2d(
102 | channel, eps=1e-5, momentum=0.01, affine=True)
103 | self.merge = nn.Sequential(
104 | ConvBlock(channel, channel, kernel_size=(3, 3), stride=1, padding=1),
105 | nn.LeakyReLU(inplace=False)
106 | ) if not res else None
107 | self.block = ResnetBlock(channel, down=False, levels=2, bn=bn) if res else None
108 |
109 | def forward(self, o, s):
110 | o_bn = self.bn(o) if self.bn is not None else o
111 | x = o_bn + s
112 | if self.merge is not None:
113 | x = self.merge(x)
114 | if self.block is not None:
115 | x = self.block(x)
116 | return x
117 |
118 |
119 | class FE_Block(nn.Module):
120 |
121 | def __init__(self, plane1, plane2, res=True):
122 | super(FE_Block, self).__init__()
123 | self.dsc = ConvBlock(plane1, plane2, kernel_size=(3, 3), stride=2, padding=1, relu=False)
124 |
125 | self.merge = nn.Sequential(
126 | ConvBlock(plane2, plane2, kernel_size=(3, 3), stride=1, padding=1),
127 | nn.LeakyReLU(inplace=False)
128 | ) if not res else None
129 | self.block = ResnetBlock(plane2, down=False, levels=2) if res else None
130 |
131 | def forward(self, p, s):
132 | x = s + self.dsc(p)
133 | if self.merge is not None:
134 | x = self.merge(x)
135 | if self.block is not None:
136 | x = self.block(x)
137 | return x
138 |
139 |
140 | class Iter_Downsample(nn.Module):
141 | def __init__(self, ):
142 | super(Iter_Downsample, self).__init__()
143 | self.ds1 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
144 | self.ds2 = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)
145 |
146 | def forward(self, x):
147 | x1 = self.ds1(x)
148 | x2 = self.ds2(x1)
149 | return x, x1, x2
150 |
151 |
152 | class CALayer(nn.Module):
153 |
154 | def __init__(self, channel):
155 | super(CALayer, self).__init__()
156 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
157 | self.ca = nn.Sequential(
158 | nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
159 | nn.ReLU(inplace=True),
160 | nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True),
161 | nn.Sigmoid()
162 | )
163 |
164 | def forward(self, x):
165 | y = self.avg_pool(x)
166 | y = self.ca(y)
167 | return x * y
168 |
169 |
170 | class ConvGroups(nn.Module):
171 |
172 | def __init__(self, in_planes, bn=False):
173 | super(ConvGroups, self).__init__()
174 | self.iter_ds = Iter_Downsample()
175 | self.lcb1 = nn.Sequential(
176 | ConvBlock(in_planes, 16, kernel_size=(3, 3), padding=1), ConvBlock(16, 16, kernel_size=1, stride=1),
177 | ConvBlock(16, 16, kernel_size=(3, 3), padding=1, bn=bn),
178 | ConvBlock(16, 64, kernel_size=1, bn=bn, relu=False))
179 | self.lcb2 = nn.Sequential(
180 | ConvBlock(in_planes, 32, kernel_size=(3, 3), padding=1), ConvBlock(32, 32, kernel_size=1),
181 | ConvBlock(32, 32, kernel_size=(3, 3), padding=1), ConvBlock(32, 32, kernel_size=1, stride=1, bn=bn),
182 | ConvBlock(32, 32, kernel_size=(3, 3), padding=1, bn=bn),
183 | ConvBlock(32, 128, kernel_size=1, bn=bn, relu=False))
184 | self.lcb3 = nn.Sequential(
185 | ConvBlock(in_planes, 64, kernel_size=(3, 3), padding=1), ConvBlock(64, 64, kernel_size=1),
186 | ConvBlock(64, 64, kernel_size=(3, 3), padding=1), ConvBlock(64, 64, kernel_size=1, bn=bn),
187 | ConvBlock(64, 64, kernel_size=(3, 3), padding=1, bn=bn),
188 | ConvBlock(64, 256, kernel_size=1, bn=bn, relu=False))
189 |
190 | def forward(self, x):
191 | img1, img2, img3 = self.iter_ds(x)
192 | s1 = self.lcb1(img1)
193 | s2 = self.lcb2(img2)
194 | s3 = self.lcb3(img3)
195 | return s1, s2, s3
196 |
197 |
198 | def padding_image(image, h, w):
199 | assert h >= image.size(2)
200 | assert w >= image.size(3)
201 | padding_top = (h - image.size(2)) // 2
202 | padding_down = h - image.size(2) - padding_top
203 | padding_left = (w - image.size(3)) // 2
204 | padding_right = w - image.size(3) - padding_left
205 | out = torch.nn.functional.pad(image, (padding_left, padding_right, padding_top, padding_down), mode='reflect')
206 | return out, padding_left, padding_left + image.size(3), padding_top, padding_top + image.size(2)
207 |
208 |
209 | if __name__ == '__main__':
210 | print(CALayer(3))
--------------------------------------------------------------------------------
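`padding_image` reflect-pads an image to a target size, centered, and returns the crop box that undoes the padding. A numpy analogue of the same arithmetic (`np.pad` standing in for `torch.nn.functional.pad`), checking the pad-then-crop round trip:

```python
import numpy as np

# numpy analogue of padding_image in Util.py: reflect-pad an (N, C, H, W)
# array to (h, w), centered, and return the crop box that recovers it.
def padding_image_np(image, h, w):
    assert h >= image.shape[2] and w >= image.shape[3]
    top = (h - image.shape[2]) // 2
    down = h - image.shape[2] - top
    left = (w - image.shape[3]) // 2
    right = w - image.shape[3] - left
    out = np.pad(image, ((0, 0), (0, 0), (top, down), (left, right)), mode='reflect')
    return out, left, left + image.shape[3], top, top + image.shape[2]

x = np.arange(2 * 3 * 5 * 7, dtype=float).reshape(2, 3, 5, 7)
padded, l, r, t, b = padding_image_np(x, 8, 8)
```

Note that reflect padding requires each pad width to be smaller than the corresponding input dimension, which holds here because `test.py` only pads up to the next multiple of the tile size.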
/UBRFC/test.py:
--------------------------------------------------------------------------------
1 | import os.path
2 | import numpy as np
3 | from tqdm import tqdm
4 | from Option import opt
5 | from Declaration import device, generator, parameterNet, test_dataloader
6 | from Metrics import ssim, psnr
7 | from torchvision.utils import save_image
8 | import math
9 | import torch
10 | def padding_image(image, h, w):
11 | assert h >= image.size(2)
12 | assert w >= image.size(3)
13 | padding_top = (h - image.size(2)) // 2
14 | padding_down = h - image.size(2) - padding_top
15 | padding_left = (w - image.size(3)) // 2
16 | padding_right = w - image.size(3) - padding_left
17 | out = torch.nn.functional.pad(image, (padding_left, padding_right, padding_top, padding_down), mode='reflect')
18 | return out, padding_left, padding_left + image.size(3), padding_top, padding_top + image.size(2)
19 |
20 | def test(generator_net, parameter_net, loader_test):
21 |     generator_net.eval()
22 |     parameter_net.eval()
23 |     clear_psnrs, clear_ssims = [], []
24 |     for i, (inputs, targets, name) in tqdm(enumerate(loader_test), total=len(loader_test), leave=False, desc="Testing"):
25 |
26 |         h, w = inputs.shape[2], inputs.shape[3]
27 |
28 |         # pad up to the next multiple of 512 (portrait) or 256 (landscape)
29 |         if h > w:
30 |             max_h = int(math.ceil(h / 512)) * 512
31 |             max_w = int(math.ceil(w / 512)) * 512
32 |         else:
33 |             max_h = int(math.ceil(h / 256)) * 256
34 |             max_w = int(math.ceil(w / 256)) * 256
35 |         inputs, ori_left, ori_right, ori_top, ori_down = padding_image(inputs, max_h, max_w)
36 |
37 |         inputs = inputs.to(device)
38 |         targets = targets.to(device)
39 |
40 |         pred = generator_net(inputs)
41 | _, dehazy_pred = parameter_net(inputs, pred)
42 |
43 |
44 | pred = pred.data[:, :, ori_top:ori_down, ori_left:ori_right]
45 | dehazy_pred = dehazy_pred.data[:, :, ori_top:ori_down, ori_left:ori_right]
46 |
47 | ssim1 = ssim(pred, targets).item()
48 | psnr1 = psnr(pred, targets)
49 |
50 | ssim11 = ssim(dehazy_pred, targets).item()
51 | psnr11 = psnr(dehazy_pred, targets)
52 |
53 | save_image(inputs.data[:1], os.path.join(opt.out_hazy_path,"%s.png" % name[0]))
54 | save_image(targets.data[:1], os.path.join(opt.out_gt_path,"%s.png" % name[0]))
55 |
56 | if psnr11 >= psnr1:
57 | save_image(dehazy_pred.data[:1], os.path.join(opt.out_clear_path,"%s.png" % name[0]))
58 | clear_ssims.append(ssim11)
59 | clear_psnrs.append(psnr11)
60 | else:
61 | save_image(pred.data[:1], os.path.join(opt.out_clear_path,"%s.png" % name[0]))
62 | clear_ssims.append(ssim1)
63 | clear_psnrs.append(psnr1)
64 |
65 | return np.mean(clear_ssims), np.mean(clear_psnrs)
66 |
67 | if __name__ == '__main__':
68 | print(test(generator,parameterNet,test_dataloader))
--------------------------------------------------------------------------------
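`test()` pads each input up to the next multiple of 512 (when h > w) or 256 before inference. The round-up arithmetic it uses, isolated in plain Python:

```python
import math

# Round a dimension up to the next multiple of `base`, as test() does
# before calling padding_image (base = 512 when h > w, else 256).
def round_up(n, base):
    return int(math.ceil(n / base)) * base

assert round_up(256, 256) == 256   # already a multiple: unchanged
assert round_up(257, 256) == 512   # otherwise: next multiple up
assert round_up(300, 512) == 512
```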
/UBRFC/train.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from tqdm import tqdm
3 | from test import test
4 | from Declaration import *
5 |
6 | def train():
7 | global lr
8 | generator.train()
9 | discriminator.train()
10 | parameterNet.train()
11 | psnrs,ssims = [],[]
12 | for epoch in range(iter_num + 1, epochs):
13 | loss_total = 0
14 | for i, (x, y,_) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False,
15 | desc="epoch is %d" % epoch):
16 |
17 | x = x.to(device)
18 | y = y.to(device)
19 |
20 | real_label = torch.ones((x.size()[0], 1, d_out_size, d_out_size), requires_grad=False).to(device)
21 | fake_label = torch.zeros((x.size()[0], 1, d_out_size, d_out_size), requires_grad=False).to(device)
22 |
23 | real_out = discriminator(y)
24 | loss_real_D = criterion(real_out, real_label)
25 |
26 | fake_img = generator(x)
27 |
28 | fake_out = discriminator(fake_img.detach())
29 | loss_fake_D = criterion(fake_out, fake_label)
30 |
31 | loss_D = (loss_real_D + loss_fake_D) / 2
32 |
33 | optimizer_D.zero_grad()
34 | loss_D.backward()
35 | optimizer_D.step()
36 | #D_loss = loss_D.item()
37 |
38 | #fake_img_ = generator(x)
39 | output = discriminator(fake_img)
40 | haze, dehaze = parameterNet(x, fake_img)
41 |
42 | loss_G = criterion(output, real_label)
43 | loss_P = criterionP(haze, x) # L1
44 |             loss_Right = criterionP(fake_img, dehaze.detach())  # pull the output toward the dehazed estimate
45 |
46 |             loss_ssim = criterionSsim(fake_img, dehaze.detach())  # structural similarity
47 |             loss_C1 = criterionC(fake_img, dehaze.detach(), x, haze.detach())  # contrastive, clear-image branch
48 |             loss_C2 = criterionC(haze, x, fake_img.detach(), dehaze.detach())  # contrastive, hazy-image branch
49 |
50 | total_loss = loss_G + loss_P + loss_Right + 0.1 * loss_ssim + loss_C1 + loss_C2
51 |
52 | optimizer_T.zero_grad()
53 | total_loss.backward()
54 | optimizer_T.step()
55 |
56 | lr = scheduler_T.get_last_lr()[0]
57 | # G_loss = loss_G.item()
58 | # P_loss = loss_P.item()
59 | loss_total = total_loss.item()
60 |
61 |         if (epoch % 10 == 0 and epoch != 0) or epoch >= 30:
62 |             ### evaluate on the test set
63 | with torch.no_grad():
64 | ssim_eval, psnr_eval = test(generator, parameterNet, test_dataloader)
65 | ssims = np.append(ssims, ssim_eval)
66 | psnrs = np.append(psnrs, psnr_eval)
67 |
68 |             msg_clear = f"epoch: {epoch} | lr: {lr} | train: [total loss: {loss_total:.4f}] | test: ssim: {ssim_eval:.4f} | psnr: {psnr_eval:.4f}\n"
69 | print(msg_clear)
70 |
71 |             ## save the training log
72 | file = open('./Log/log_train/train_' + str(timeStamp) + '.txt', 'a+')
73 | file.write(msg_clear)
74 | file.close()
75 | model = {"generator_net": generator.state_dict(),"discriminator_net": discriminator.state_dict(),"parameterNet_net": parameterNet.state_dict(),
76 | "optimizer_T": optimizer_T.state_dict(),"optimizer_D": optimizer_D.state_dict(),"lr": lr, "epoch": epoch}
77 |
78 | torch.save(model, "./model/last_model.pth")
79 |             if psnr_eval == np.max(psnrs):
80 |                 max_msg = "epoch:%d, ssim:%f, best psnr:%f\n" % (epoch, ssim_eval, psnr_eval)
81 | files = open('./Log/max_log/max_log' + str(timeStamp) + '.txt', 'a+')
82 | files.write(max_msg)
83 | files.close()
84 | torch.save(model, "./model/epoch%d_%f_%f_best_model.pth"%(epoch,ssim_eval,psnr_eval))
85 |
86 | scheduler_T.step()
87 | scheduler_D.step()
88 |
89 |
90 |
91 |
92 |
93 | if __name__ == '__main__':
94 | train()
95 |
--------------------------------------------------------------------------------
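The discriminator update in `train()` averages the criterion on real outputs against all-ones labels and on detached fake outputs against all-zeros labels. `criterion` comes from `Declaration.py` (not shown in this listing); the sketch below assumes a BCE-style loss purely for illustration:

```python
import numpy as np

# Sketch of the discriminator objective in train():
#   loss_D = (criterion(real_out, 1) + criterion(fake_out, 0)) / 2
# A BCE-style criterion is an assumption here; the real one lives in Declaration.py.
def bce(pred, target, eps=1e-7):
    pred = np.clip(pred, eps, 1 - eps)
    return -np.mean(target * np.log(pred) + (1 - target) * np.log(1 - pred))

real_out = np.array([0.9, 0.8])   # discriminator scores on real (clear) images
fake_out = np.array([0.2, 0.1])   # scores on detached generator outputs
loss_real_D = bce(real_out, np.ones_like(real_out))
loss_fake_D = bce(fake_out, np.zeros_like(fake_out))
loss_D = (loss_real_D + loss_fake_D) / 2  # small when D separates real from fake
```

Detaching `fake_img` before the discriminator pass (as the loop does) keeps this loss from back-propagating into the generator; the generator is then updated separately against `criterion(output, real_label)`.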
/images/Attention_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lose-Code/UBRFC-Net/1e9880588649ff077ad7a1cc04faa5ae883ce47a/images/Attention_00.png
--------------------------------------------------------------------------------
/images/Dense_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lose-Code/UBRFC-Net/1e9880588649ff077ad7a1cc04faa5ae883ce47a/images/Dense_00.png
--------------------------------------------------------------------------------
/images/Haze1k_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lose-Code/UBRFC-Net/1e9880588649ff077ad7a1cc04faa5ae883ce47a/images/Haze1k_00.png
--------------------------------------------------------------------------------
/images/NH_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lose-Code/UBRFC-Net/1e9880588649ff077ad7a1cc04faa5ae883ce47a/images/NH_00.png
--------------------------------------------------------------------------------
/images/Outdoor_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lose-Code/UBRFC-Net/1e9880588649ff077ad7a1cc04faa5ae883ce47a/images/Outdoor_00.png
--------------------------------------------------------------------------------
/images/framework_00.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Lose-Code/UBRFC-Net/1e9880588649ff077ad7a1cc04faa5ae883ce47a/images/framework_00.png
--------------------------------------------------------------------------------