├── FHENet.pth
├── FHENet.py
├── README.md
├── mobilenetv2.py
├── requirements.txt
└── test_RGBT.py
/FHENet.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/hjklearn/Rail-Defect-Detection/2aa8eabf4668cbe737e20ca606b1653d354ac45f/FHENet.pth
--------------------------------------------------------------------------------
/FHENet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from mobilenetv2 import mobilenet_v2  # absolute import: the repository is a flat layout, not a package
4 | import torch.nn.functional as F
5 | # from .van import *
6 |
7 |
8 | class BasicConv2d(nn.Module):
9 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
10 | super(BasicConv2d, self).__init__()
11 | self.conv = nn.Conv2d(in_planes, out_planes,
12 | kernel_size=kernel_size, stride=stride,
13 | padding=padding, dilation=dilation, bias=False)
14 | self.bn = nn.BatchNorm2d(out_planes)
15 | self.relu = nn.ReLU(inplace=True)
16 |
17 | def forward(self, x):
18 | x = self.conv(x)
19 | x = self.bn(x)
20 | x = self.relu(x)
21 | return x
22 |
23 | class Channel_Att(nn.Module):  # normalization-based channel attention (defined here but not used by Mirror_model)
24 | def __init__(self, channels, t=16):
25 | super(Channel_Att, self).__init__()
26 | self.channels = channels
27 |
28 | self.bn2 = nn.BatchNorm2d(self.channels, affine=True)
29 |
30 | def forward(self, x):
31 | residual = x
32 |
33 | x = self.bn2(x)
34 | weight_bn = self.bn2.weight.data.abs() / torch.sum(self.bn2.weight.data.abs())
35 | x = x.permute(0, 2, 3, 1).contiguous()
36 | x = torch.mul(weight_bn, x)
37 | x = x.permute(0, 3, 1, 2).contiguous()
38 | # x = torch.sigmoid(x)
39 |         x = torch.sigmoid(x) * residual
40 |
41 | return x
42 |
43 | class DSConv(nn.Module):  # depthwise-separable convolution with dilation (not used by Mirror_model)
44 | def __init__(self, in_channel, out_channel, rate):
45 | super(DSConv, self).__init__()
46 | self.depth = nn.Sequential(nn.Conv2d(in_channel, in_channel, kernel_size=3, padding=rate, stride=1, dilation=rate, groups=in_channel),
47 | nn.BatchNorm2d(in_channel),
48 | nn.PReLU())
49 | self.point = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=1, padding=0, stride=1),
50 | nn.BatchNorm2d(out_channel),
51 | nn.PReLU())
52 |
53 | def forward(self, x):
54 | x = self.depth(x)
55 | x = self.point(x)
56 |
57 | return x
58 |
59 |
60 |
61 | class Boundry(nn.Module):  # boundary/edge-prior extraction from the stage-1 RGB and depth features
62 | def __init__(self, in_channels):
63 | super(Boundry, self).__init__()
64 |
65 | self.conv1x1_1 = nn.Conv2d(in_channels, 2 * in_channels, 1)
66 | self.conv1x1_2 = nn.Conv2d(in_channels, 2 * in_channels, 1)
67 | self.conv1x1_3 = nn.Conv2d(1, in_channels, 1)
68 | self.conv1x1_4 = nn.Conv2d(2 * in_channels, in_channels, 1)
69 |
70 | self.max2 = nn.MaxPool2d(kernel_size=2, stride=2)
71 | self.max4 = nn.MaxPool2d(kernel_size=4, stride=4)
72 | self.max8 = nn.MaxPool2d(kernel_size=8, stride=8)
73 |
74 |
75 | self.conv3_d1 = nn.Sequential(
76 | nn.Conv2d(1, in_channels, kernel_size=(3, 1), padding=(1, 0), stride=1, dilation=(1, 1)
77 | ),
78 | nn.Conv2d(in_channels, in_channels, kernel_size=(1, 3), padding=(0, 1), stride=1, dilation=(1, 1)
79 | ),
80 | nn.BatchNorm2d(in_channels),
81 | nn.ReLU())
82 | self.conv3_d2 = nn.Sequential(
83 | nn.Conv2d(6 * in_channels, in_channels, kernel_size=(3, 1), padding=(1, 0), stride=1, dilation=(1, 1)
84 | ),
85 | nn.Conv2d(in_channels, in_channels, kernel_size=(1, 3), padding=(0, 1), stride=1, dilation=(1, 1)
86 | ),
87 | nn.BatchNorm2d(in_channels),
88 | nn.ReLU())
89 |
90 |
91 | def forward(self, rgb, d):
92 | rgb_c = self.conv1x1_1(rgb)
93 | d_c = self.conv1x1_2(d)
94 | mul1 = rgb_c.mul(d_c)
95 | # add = torch.cat([mul1, rgb_c, d_c], dim=1)
96 | # add = self.conv3_d1(add)
97 | add = mul1 + rgb_c + d_c
98 | add_c = self.conv1x1_4(add)
99 |
100 | avgmax1 = self.max2(add)
101 | avgmax3 = self.max4(add)
102 | avgmax5 = self.max8(add)
103 | max1, _ = torch.max(add, dim=1, keepdim=True)
104 | max1 = self.conv1x1_3(max1)
105 |
106 | avgmax1_up = F.interpolate(input=avgmax1, size=(add.size()[2], add.size()[3]))
107 | avgmax3_up = F.interpolate(input=avgmax3, size=(add.size()[2], add.size()[3]))
108 | avgmax5_up = F.interpolate(input=avgmax5, size=(add.size()[2], add.size()[3]))
109 |
110 | cat = torch.cat([avgmax1_up, avgmax3_up, avgmax5_up], dim=1)
111 | cat_conv = self.conv3_d2(cat)
112 | out = cat_conv + max1 + add_c
113 |
114 |
115 | return out
116 |
117 |
118 |
119 | class fusion(nn.Module):  # cross-modal fusion: mutual sigmoid gating between RGB and depth features
120 | def __init__(self, in_channels):
121 | super(fusion, self).__init__()
122 | self.sigmoid = nn.Sigmoid()
123 | self.conv3_d = nn.Sequential(
124 | nn.Conv2d(2 * in_channels, in_channels, kernel_size=(3, 1), padding=(1, 0), stride=1, dilation=(1, 1)
125 | ),
126 | nn.Conv2d(in_channels, in_channels, kernel_size=(1, 3), padding=(0, 1), stride=1, dilation=(1, 1)
127 | ),
128 | nn.BatchNorm2d(in_channels),
129 | nn.ReLU())
130 |
131 | def forward(self, rgb, t):
132 | mul_rt = rgb.mul(t)
133 |
134 | rgb_sig = self.sigmoid(rgb)
135 | t_sig = self.sigmoid(t)
136 |
137 | mul_r = rgb_sig.mul(t)
138 | add_r = mul_r + rgb
139 | mul_t = t_sig.mul(rgb)
140 | add_t = mul_t + t
141 |
142 | r_mul = add_r.mul(mul_rt)
143 | t_mul = add_t.mul(mul_rt)
144 |
145 | cat_all = torch.cat((r_mul, t_mul), dim=1)
146 | out = self.conv3_d(cat_all)
147 |
148 |
149 | return out
150 |
151 |
152 |
153 | class MFI(nn.Module):  # edge-guided refinement of the fused RGB/depth decoder features
154 | def __init__(self, in_channels):
155 | super(MFI, self).__init__()
156 |
157 | self.conv1 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
158 | self.conv2 = nn.Conv2d(in_channels, in_channels, kernel_size=1)
159 | self.avgmax = nn.AdaptiveMaxPool2d(1)
160 | self.sig = nn.Sigmoid()
161 | self.conv3_d1 = nn.Sequential(
162 | nn.Conv2d(2 * in_channels, in_channels, kernel_size=(3, 1), padding=(1, 0), stride=1, dilation=(1, 1)
163 | ),
164 | nn.Conv2d(in_channels, in_channels, kernel_size=(1, 3), padding=(0, 1), stride=1, dilation=(1, 1)
165 | ),
166 | nn.BatchNorm2d(in_channels),
167 | nn.ReLU())
168 |
169 | def forward(self, rgb, depth, edge):
170 |
171 | jian_r = rgb - edge
172 | jian_d = depth - edge
173 | jian_r = torch.abs(jian_r)
174 | jian_d = torch.abs(jian_d)
175 |
176 | mul_r_d = jian_r.mul(jian_d)
177 | add_r_d = jian_r + jian_d
178 |
179 | cat = torch.cat((mul_r_d, add_r_d), dim=1)
180 | cat_conv = self.conv3_d1(cat)
181 |
182 | out = cat_conv + edge
183 |
184 | return out
185 |
186 |
187 | class Mirror_model(nn.Module):  # FHENet: two MobileNetV2 encoders (RGB and depth) with an edge-guided decoder
188 | def __init__(self):
189 | super(Mirror_model, self).__init__()
190 |         self.layer1_rgb = mobilenet_v2().features[0:2]  # each call builds an ImageNet-pretrained MobileNetV2 and slices its features
191 | self.layer2_rgb = mobilenet_v2().features[2:4]
192 | self.layer3_rgb = mobilenet_v2().features[4:7]
193 | self.layer4_rgb = mobilenet_v2().features[7:17]
194 | self.layer5_rgb = mobilenet_v2().features[17:18]
195 |
196 | self.layer1_t = mobilenet_v2().features[0:2]
197 | self.layer2_t = mobilenet_v2().features[2:4]
198 | self.layer3_t = mobilenet_v2().features[4:7]
199 | self.layer4_t = mobilenet_v2().features[7:17]
200 | self.layer5_t = mobilenet_v2().features[17:18]
201 |
202 | self.boundary = Boundry(16)
203 |
204 |
205 |
206 | self.fusion1 = fusion(24)
207 | self.fusion2 = fusion(32)
208 | self.fusion3 = fusion(160)
209 | self.fusion4 = fusion(320)
210 |
211 | self.MFI1 = MFI(160)
212 | self.MFI2 = MFI(16)
213 | self.MFI3 = MFI(16)
214 |
215 |
216 | self.conv16_1_1 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
217 | nn.Conv2d(16, 1, 1))
218 | self.conv16_1_2 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
219 | nn.Conv2d(16, 1, 1))
220 | self.conv16_1_3 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
221 | nn.Conv2d(16, 1, 1))
222 | self.conv16_1_4 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
223 | nn.Conv2d(16, 1, 1))
224 | self.conv16_1_5 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
225 | nn.Conv2d(16, 1, 1))
226 |
227 | self.conv16_160 = nn.Sequential(nn.Conv2d(16, 160, 1),
228 | nn.Upsample(scale_factor=0.125, mode='bilinear', align_corners=True))
229 | self.conv16_32 = nn.Conv2d(16, 32, 1)
230 | self.conv16_24 = nn.Conv2d(16, 24, 1)
231 |
232 |
233 | self.conv32_16 = nn.Sequential(nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True),
234 | nn.Conv2d(32, 16, 1)
235 | )
236 | self.conv24_16 = nn.Sequential(nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
237 | nn.Conv2d(24, 16, 1)
238 | )
239 |
240 |
241 | self.conv160_16 = nn.Sequential(nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True),
242 | nn.Conv2d(160, 16, 1),
243 | )
250 |
251 |
252 |
253 | self.conv_480_160 = nn.Sequential(nn.Conv2d(480, 160, 1),
254 | nn.BatchNorm2d(160),
255 | nn.ReLU(inplace=True))
256 | self.conv_192_16 = nn.Sequential(nn.Conv2d(192, 16, 1),
257 | nn.BatchNorm2d(16),
258 | nn.ReLU(inplace=True))
259 | self.conv_40_16 = nn.Sequential(nn.Conv2d(40, 16, 1),
260 | nn.BatchNorm2d(16),
261 | nn.ReLU(inplace=True))
262 |
263 | self.conv_480_160_1 = nn.Sequential(nn.Conv2d(480, 160, 1),
264 | nn.BatchNorm2d(160),
265 | nn.ReLU(inplace=True))
266 | self.conv_192_16_1 = nn.Sequential(nn.Conv2d(192, 16, 1),
267 | nn.BatchNorm2d(16),
268 | nn.ReLU(inplace=True))
269 | self.conv_40_16_1 = nn.Sequential(nn.Conv2d(40, 16, 1),
270 | nn.BatchNorm2d(16),
271 | nn.ReLU(inplace=True))
272 |
273 | self.up2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
274 | self.up4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
275 | self.up8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
276 |
277 |
278 |
279 |
280 |
281 | def forward(self, rgb, depth):
282 | x1_rgb = self.layer1_rgb(rgb)
283 | x2_rgb = self.layer2_rgb(x1_rgb)
284 | x3_rgb = self.layer3_rgb(x2_rgb)
285 | x4_rgb = self.layer4_rgb(x3_rgb)
286 | x5_rgb = self.layer5_rgb(x4_rgb)
287 |
288 |         depth = torch.cat([depth, depth, depth], dim=1)  # replicate the 1-channel depth map to 3 channels for the MobileNetV2 encoder
289 | x1_depth = self.layer1_t(depth)
290 | x2_depth = self.layer2_t(x1_depth)
291 | x3_depth = self.layer3_t(x2_depth)
292 | x4_depth = self.layer4_t(x3_depth)
293 | x5_depth = self.layer5_t(x4_depth)
294 |
295 |
296 | edge = self.boundary(x1_rgb, x1_depth)
297 | edge_conv = self.conv16_1_4(edge)
298 | edge_160 = self.conv16_160(edge)
299 |
300 | x2_r_t = self.fusion1(x2_rgb, x2_depth)
301 | x2_rgb_en = x2_rgb + x2_r_t
302 | x2_depth_en = x2_depth + x2_r_t
303 |
304 | x3_r_t = self.fusion2(x3_rgb, x3_depth)
305 | x3_rgb_en = x3_rgb + x3_r_t
306 | x3_depth_en = x3_depth + x3_r_t
307 |
308 | x4_r_t = self.fusion3(x4_rgb, x4_depth)
309 | x4_rgb_en = x4_rgb + x4_r_t
310 | x4_depth_en = x4_depth + x4_r_t
311 |
312 | x5_r_t = self.fusion4(x5_rgb, x5_depth)
313 | x5_rgb_en = x5_rgb + x5_r_t
314 | x5_depth_en = x5_depth + x5_r_t
315 |
316 | x5_rgb_en_up2 = self.up2(x5_rgb_en)
317 | x5_depth_en_up2 = self.up2(x5_depth_en)
318 | cat_5_4_r = torch.cat((x5_rgb_en_up2, x4_rgb_en), dim=1)
319 | cat_5_4_r_480_160 = self.conv_480_160(cat_5_4_r)
320 | cat_5_4_t = torch.cat((x5_depth_en_up2, x4_depth_en), dim=1)
321 | cat_5_4_t_480_160 = self.conv_480_160_1(cat_5_4_t)
322 | add_5_4 = self.MFI1(cat_5_4_r_480_160, cat_5_4_t_480_160, edge_160)
323 | add_5_4_conv = self.conv160_16(add_5_4)
324 | f3 = add_5_4_conv.mul(edge) + edge
325 |
326 |
327 | cat_5_4_r_480_160_up2 = self.up2(cat_5_4_r_480_160)
328 | cat_5_4_3_r = torch.cat((cat_5_4_r_480_160_up2, x3_rgb_en), dim=1)
329 | cat_5_4_3_r_192_16 = self.conv_192_16(cat_5_4_3_r)
330 | cat_5_4_3_r_192_16_up4 = self.up4(cat_5_4_3_r_192_16)
331 | cat_5_4_t_480_160_up2 = self.up2(cat_5_4_t_480_160)
332 | cat_5_4_3_t = torch.cat((cat_5_4_t_480_160_up2, x3_depth_en), dim=1)
333 | cat_5_4_3_t_196_16 = self.conv_192_16_1(cat_5_4_3_t)
334 | cat_5_4_3_t_196_16_up4 = self.up4(cat_5_4_3_t_196_16)
335 | add_5_4_3 = self.MFI2(cat_5_4_3_r_192_16_up4, cat_5_4_3_t_196_16_up4, f3)
336 | f2 = add_5_4_3.mul(edge) + edge
337 |
338 |
339 | cat_5_4_3_r_192_16_up2 = self.up2(cat_5_4_3_r_192_16)
340 | cat_5_4_3_t_196_16_up2 = self.up2(cat_5_4_3_t_196_16)
341 | cat_5_4_3_2_r = torch.cat((cat_5_4_3_r_192_16_up2, x2_rgb_en), dim=1)
342 | cat_5_4_3_2_r_40_16 = self.conv_40_16(cat_5_4_3_2_r)
343 | cat_5_4_3_2_r_40_16_up2 = self.up2(cat_5_4_3_2_r_40_16)
344 | cat_5_4_3_2_t = torch.cat((cat_5_4_3_t_196_16_up2, x2_depth_en), dim=1)
345 | cat_5_4_3_2_t_40_16 = self.conv_40_16_1(cat_5_4_3_2_t)
346 | cat_5_4_3_2_t_40_16_up2 = self.up2(cat_5_4_3_2_t_40_16)
347 | add_5_4_3_2 = self.MFI3(cat_5_4_3_2_r_40_16_up2, cat_5_4_3_2_t_40_16_up2, f2)
348 | f1 = add_5_4_3_2
349 |
350 |
351 | out3 = self.conv16_1_1(add_5_4_conv)
352 | out2 = self.conv16_1_2(add_5_4_3)
353 | out1 = self.conv16_1_3(f1)
354 | edge1 = self.conv16_1_4(f2)
355 | edge2 = self.conv16_1_5(f3)
356 |
357 | return out1, out2, out3, edge_conv, edge1, edge2
358 |
359 |
360 |
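361 | if __name__ == '__main__':
362 |     # Minimal shape-check sketch, assuming 256x256 inputs (any spatial size divisible by 32
363 |     # matches the encoder strides and decoder upsampling factors above).
364 |     # Note: constructing Mirror_model downloads ImageNet MobileNetV2 weights via mobilenet_v2(pretrained=True).
365 |     model = Mirror_model().eval()
366 |     rgb = torch.randn(1, 3, 256, 256)
367 |     depth = torch.randn(1, 1, 256, 256)  # 1-channel depth; forward() replicates it to 3 channels
368 |     with torch.no_grad():
369 |         outputs = model(rgb, depth)
370 |     for out in outputs:
371 |         print(out.shape)  # each of the six predicted maps is (1, 1, 256, 256)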
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # FHENet-PyTorch
2 |
3 | The official PyTorch implementation of FHENet: **Lightweight Feature Hierarchical Exploration Network for Real-Time Rail Surface Defect Inspection in RGB-D Images** [[PDF](https://ieeexplore.ieee.org/document/10019291)]. The model structure is as follows:
4 |
5 |
6 |
7 | # Requirements
8 | Python 3.6, PyTorch 1.7.1, CUDA 10.2, TensorboardX 2.1, opencv-python.
9 | If anything goes wrong with the environment, please check requirements.txt for details.
10 |
11 | # Feature Maps
12 | Baidu [RGB-D](https://pan.baidu.com/s/1xcK303N9WScaOHdVFqsHIg?pwd=na4e) (extraction code: na4e)
13 |
14 | # Comparison of results table
15 | Table I. Evaluation metrics obtained by the compared methods. The best results are shown in bold.
16 |
17 | | Models | Sm↑ | maxEm↑ | maxFm↑ | MAE↓ |
18 | | :----: | :-------: | :-------: | :-------: | :-------: |
19 | | DCMC | 0.484 | 0.595 | 0.498 | 0.287 |
20 | | ACSD | 0.556 | 0.670 | 0.575 | 0.360 |
21 | | DF | 0.564 | 0.713 | 0.636 | 0.241 |
22 | | CDCP | 0.574 | 0.694 | 0.591 | 0.236 |
23 | | DMRA | 0.736 | 0.834 | 0.783 | 0.141 |
24 | | HAI | 0.718 | 0.829 | 0.803 | 0.171 |
25 | | S2MA | 0.775 | 0.864 | 0.817 | 0.141 |
26 | | CONET | 0.786 | 0.878 | 0.834 | 0.101 |
27 | | EMI | 0.800 | 0.876 | 0.850 | 0.104 |
28 | | CSEP | 0.814 | 0.899 | 0.866 | 0.085 |
29 | | EDR | 0.811 | 0.893 | 0.850 | 0.082 |
30 | | BBS | 0.828 | 0.909 | 0.867 | 0.074 |
31 | | DAC | 0.824 | 0.911 | 0.875 | 0.071 |
32 | | CLA | 0.835 | 0.920 | 0.878 | 0.069 |
33 | | Ours | **0.836** | **0.926** | **0.881** | **0.064** |
34 |
35 | Table II. Performance of the compared methods on the test set. The best results are shown in bold.
36 |
37 | | Models | DCMC | ACSD | DF | CDCP | DMRA | HAI | S2MA | CONET | EMI | CSEP | EDR | BBS | DAC | CLA | Ours |
38 | | :------: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: | :----: |
39 | | **Pre↑** | 66.16% | 55.93% | 78.88% | 73.07% | 80.36% | 73.90% | 76.91% | 86.85% | 82.65% | 85.29% | 85.32% | 86.27% | 86.71% | **87.27%** | 87.22% |
40 | | **Rec↑** | 25.46% | 63.88% | 31.02% | 36.14% | 74.18% | **91.67%** | 82.83% | 78.61% | 87.76% | 87.61% | 86.60% | 87.31% | 88.09% | 86.59% | 88.34% |
41 | | **F1↑** | 33.36% | 55.65% | 42.12% | 44.98% | 74.84% | 78.98% | 78.20% | 80.55% | 83.31% | 85.14% | 84.12% | 85.63% | 86.23% | 86.07% | **87.01%** |
42 | | **IOU↑** | 19.23% | 40.63% | 22.41% | 27.86% | 62.96% | 68.91% | 70.39% | 70.57% | 74.82% | 76.65% | 75.39% | 77.27% | 77.77% | 77.87% | **78.93%** |
43 |
44 | # Citation
45 |
46 | If you use FHENet in your academic work, please cite:
47 |
48 | @article{zhou2023fhenet,
49 | title={FHENet: Lightweight Feature Hierarchical Exploration Network for Real-Time Rail Surface Defect Inspection in RGB-D Images},
50 | author={Zhou, Wujie and Hong, Jiankang},
51 | journal={IEEE Transactions on Instrumentation and Measurement},
52 | year={2023},
53 | publisher={IEEE}
54 | }
55 |
56 | # Pretrained Model
57 |
58 | Download the pretrained model weights: [Baidu](https://pan.baidu.com/s/1X3iEf7yK65yraI4NYSWMTQ) (extraction code: 01xe)
59 |
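60 | # Usage
61 | 
62 | A minimal inference sketch, mirroring test_RGBT.py (assumptions: the repository's FHENet.pth stores a plain state_dict, and 256×256 inputs; adjust the checkpoint path and input size to your data):
63 | 
64 | ```python
65 | import torch
66 | from FHENet import Mirror_model
67 | 
68 | net = Mirror_model().cuda().eval()
69 | net.load_state_dict(torch.load('FHENet.pth'))
70 | 
71 | rgb = torch.randn(1, 3, 256, 256).cuda()    # replace with a preprocessed RGB image tensor
72 | depth = torch.randn(1, 1, 256, 256).cuda()  # single-channel depth map
73 | with torch.no_grad():
74 |     out1, out2, out3, edge, edge1, edge2 = net(rgb, depth)
75 | saliency = torch.sigmoid(out1)              # final defect saliency map, as in test_RGBT.py
76 | ```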
--------------------------------------------------------------------------------
/mobilenetv2.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.hub import load_state_dict_from_url  # torchvision.models.utils was removed in newer torchvision releases
3 |
4 |
5 | __all__ = ['MobileNetV2', 'mobilenet_v2']
6 |
7 |
8 | model_urls = {
9 | 'mobilenet_v2': 'https://download.pytorch.org/models/mobilenet_v2-b0353104.pth',
10 | }
11 |
12 |
13 | def _make_divisible(v, divisor, min_value=None):
14 | """
15 | This function is taken from the original tf repo.
16 | It ensures that all layers have a channel number that is divisible by 8
17 | It can be seen here:
18 | https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py
19 | :param v:
20 | :param divisor:
21 | :param min_value:
22 | :return:
23 | """
24 | if min_value is None:
25 | min_value = divisor
26 | new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
27 | # Make sure that round down does not go down by more than 10%.
28 | if new_v < 0.9 * v:
29 | new_v += divisor
30 | return new_v
31 |
32 |
33 | class ConvBNReLU(nn.Sequential):
34 | def __init__(self, in_planes, out_planes, kernel_size=3, stride=1, dilation=1, groups=1):
35 | padding = ((kernel_size - 1) * dilation + 1) // 2
36 | # padding = (kernel_size - 1) // 2
37 | super(ConvBNReLU, self).__init__(
38 | nn.Conv2d(in_planes, out_planes, kernel_size, stride, padding, dilation=dilation, groups=groups, bias=False),
39 | nn.BatchNorm2d(out_planes),
40 | nn.ReLU6(inplace=True)
41 | )
42 |
43 |
44 | class InvertedResidual(nn.Module):
45 | def __init__(self, inp, oup, stride, expand_ratio, dilation):
46 | super(InvertedResidual, self).__init__()
47 | self.stride = stride
48 | assert stride in [1, 2]
49 |
50 | hidden_dim = int(round(inp * expand_ratio))
51 | self.use_res_connect = self.stride == 1 and inp == oup
52 |
53 | layers = []
54 | if expand_ratio != 1:
55 | # pw
56 | layers.append(ConvBNReLU(inp, hidden_dim, kernel_size=1))
57 | layers.extend([
58 | # dw
59 | ConvBNReLU(hidden_dim, hidden_dim, stride=stride, dilation=dilation, groups=hidden_dim),
60 | # pw-linear
61 | nn.Conv2d(hidden_dim, oup, 1, 1, 0, bias=False),
62 | nn.BatchNorm2d(oup),
63 | ])
64 | self.conv = nn.Sequential(*layers)
65 |
66 | def forward(self, x):
67 | if self.use_res_connect:
68 | return x + self.conv(x)
69 | else:
70 | return self.conv(x)
71 |
72 |
73 | class MobileNetV2(nn.Module):
74 | def __init__(self, num_classes=1000, width_mult=1.0, inverted_residual_setting=None, round_nearest=8):
75 | """
76 | MobileNet V2 main class
77 |
78 | Args:
79 | num_classes (int): Number of classes
80 | width_mult (float): Width multiplier - adjusts number of channels in each layer by this amount
81 | inverted_residual_setting: Network structure
82 | round_nearest (int): Round the number of channels in each layer to be a multiple of this number
83 | Set to 1 to turn off rounding
84 | """
85 | super(MobileNetV2, self).__init__()
86 | block = InvertedResidual
87 | input_channel = 32
88 | last_channel = 1280
89 |
90 | if inverted_residual_setting is None:
91 | inverted_residual_setting = [
92 | # t, c, n, s, d
93 | [1, 16, 1, 1, 1],
94 | [6, 24, 2, 2, 1],
95 | [6, 32, 3, 2, 1],
96 | [6, 64, 4, 2, 1],
97 | [6, 96, 3, 1, 2],
98 | [6, 160, 3, 1, 4],
99 | [6, 320, 1, 2, 1],
100 | ]
101 |
102 | # only check the first element, assuming user knows t,c,n,s are required
103 | # if len(inverted_residual_setting) == 0 or len(inverted_residual_setting[0]) != 4:
104 | # raise ValueError("inverted_residual_setting should be non-empty "
105 | # "or a 4-element list, got {}".format(inverted_residual_setting))
106 |
107 | # building first layer
108 | input_channel = _make_divisible(input_channel * width_mult, round_nearest)
109 | self.last_channel = _make_divisible(last_channel * max(1.0, width_mult), round_nearest)
110 | features = [ConvBNReLU(3, input_channel, stride=2)]
111 | # building inverted residual blocks
112 | for t, c, n, s, d in inverted_residual_setting:
113 | output_channel = _make_divisible(c * width_mult, round_nearest)
114 | for i in range(n):
115 | stride = s if i == 0 else 1
116 | features.append(block(input_channel, output_channel, stride, expand_ratio=t, dilation=d))
117 | input_channel = output_channel
118 | # building last several layers
119 | features.append(ConvBNReLU(input_channel, self.last_channel, kernel_size=1))
120 | # make it nn.Sequential
121 | self.features = nn.Sequential(*features)
122 |
123 | # building classifier
124 | self.classifier = nn.Sequential(
125 | nn.Dropout(0.2),
126 | nn.Linear(self.last_channel, num_classes),
127 | )
128 |
129 | # weight initialization
130 | for m in self.modules():
131 | if isinstance(m, nn.Conv2d):
132 | nn.init.kaiming_normal_(m.weight, mode='fan_out')
133 | if m.bias is not None:
134 | nn.init.zeros_(m.bias)
135 | elif isinstance(m, nn.BatchNorm2d):
136 | nn.init.ones_(m.weight)
137 | nn.init.zeros_(m.bias)
138 | elif isinstance(m, nn.Linear):
139 | nn.init.normal_(m.weight, 0, 0.01)
140 | nn.init.zeros_(m.bias)
141 |
142 | def forward(self, x):
143 | x = self.features(x)
144 | x = x.mean([2, 3])
145 | x = self.classifier(x)
146 | return x
147 |
148 |
149 | def mobilenet_v2(pretrained=True, progress=True, **kwargs):
150 | """
151 | Constructs a MobileNetV2 architecture from
152 |     `"MobileNetV2: Inverted Residuals and Linear Bottlenecks" <https://arxiv.org/abs/1801.04381>`_.
153 |
154 | Args:
155 | pretrained (bool): If True, returns a model pre-trained on ImageNet
156 | progress (bool): If True, displays a progress bar of the download to stderr
157 | """
158 | model = MobileNetV2(**kwargs)
159 | if pretrained:
160 | state_dict = load_state_dict_from_url(model_urls['mobilenet_v2'],
161 | progress=progress)
162 | model.load_state_dict(state_dict)
163 |         print('Loaded ImageNet-pretrained MobileNetV2 weights')
164 | return model
165 |
166 |
167 |
168 | if __name__ == '__main__':
169 | import torch
170 | model = mobilenet_v2(pretrained=True)
171 | print(model)
172 | # model = mobilenet_v2_shallow()
173 |
174 | x = torch.randn((2, 3, 224, 224))
175 | x = model.features[0:2](x)
176 | out1 = model.features[2:4](x)
177 | out2 = model.features[4:7](out1)
178 | out3 = model.features[7:17](out2)
179 | out4 = model.features[17:18](out3)
180 | # out5 = model.features[14:17](out4)
181 | # out6 = model.features[17:18](out5)
182 | #
183 | print(x.shape)
184 | print(out1.shape)
185 | print(out2.shape)
186 | print(out3.shape)
187 | print(out4.shape)
188 | # print(out5.shape)
189 | # print(out6.shape)
190 |
191 |
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | absl-py==1.0.0
2 | addict==2.4.0
3 | apex==0.1
4 | astor==0.8.1
5 | astunparse==1.6.3
6 | attr==0.3.2
7 | backcall==0.2.0
8 | cached-property==1.5.2
9 | cachetools==4.2.4
10 | certifi==2022.9.24
11 | charset-normalizer==2.0.12
12 | clang==5.0
13 | colorama @ file:///tmp/build/80754af9/colorama_1607707115595/work
14 | cycler==0.11.0
15 | dataclasses @ file:///tmp/build/80754af9/dataclasses_1614363715916/work
16 | decorator==5.1.1
17 | einops==0.4.1
18 | flatbuffers==1.12
19 | gast==0.4.0
20 | google-auth==2.6.5
21 | google-auth-oauthlib==0.4.6
22 | google-pasta==0.2.0
23 | grpcio==1.48.2
24 | h5py==3.1.0
25 | idna==3.4
26 | imageio @ file:///tmp/build/80754af9/imageio_1617700267927/work
27 | importlib-metadata==4.8.3
28 | ipython==7.16.3
29 | ipython-genutils==0.2.0
30 | jedi==0.17.2
31 | joblib @ file:///tmp/build/80754af9/joblib_1613502643832/work
32 | keras==2.6.0
33 | Keras-Preprocessing==1.1.2
34 | kiwisolver==1.3.1
35 | Markdown==3.3.7
36 | matplotlib==3.3.4
37 | mkl-fft==1.3.0
38 | mkl-random==1.1.1
39 | mkl-service==2.3.0
40 | mmcv-full==1.2.6
41 | ninja==1.11.1
42 | nose @ file:///opt/conda/conda-bld/nose_1642704612149/work
43 | numpy==1.19.5
44 | oauthlib==3.2.1
45 | olefile==0.46
46 | opencv-python==4.5.5.62
47 | opt-einsum==3.3.0
48 | packaging==21.3
49 | paddle-bfloat==0.1.7
50 | paddlepaddle==2.4.0
51 | pandas==1.1.5
52 | parso==0.7.1
53 | pexpect==4.8.0
54 | pickleshare==0.7.5
55 | Pillow @ file:///tmp/build/80754af9/pillow_1625649052827/work
56 | portalocker @ file:///tmp/build/80754af9/portalocker_1617135543485/work
57 | progress @ file:///tmp/build/80754af9/progress_1614269494850/work
58 | prompt-toolkit==3.0.30
59 | protobuf==3.19.6
60 | ptflops==0.6.7
61 | ptyprocess==0.7.0
62 | pyasn1==0.4.8
63 | pyasn1-modules==0.2.8
64 | pydensecrf==1.0rc3
65 | Pygments==2.12.0
66 | pyparsing==3.0.6
67 | python-dateutil==2.8.2
68 | pytorch-ssim==0.1
69 | pytz==2022.1
70 | PyYAML==6.0
71 | requests==2.27.1
72 | requests-oauthlib==1.3.1
73 | rsa==4.9
74 | scikit-learn @ file:///tmp/build/80754af9/scikit-learn_1621365798935/work
75 | scipy @ file:///tmp/build/80754af9/scipy_1597686625380/work
76 | six @ file:///tmp/build/80754af9/six_1623709665295/work
77 | tb-nightly==2.9.0a20220420
78 | tensorboard==2.6.0
79 | tensorboard-data-server==0.6.1
80 | tensorboard-plugin-wit==1.8.1
81 | tensorboardX==2.4.1
82 | tensorflow==2.6.2
83 | tensorflow-estimator==2.6.0
84 | termcolor==1.1.0
85 | tf-slim==1.1.0
86 | thop==0.0.31.post2005241907
87 | threadpoolctl @ file:///Users/ktietz/demo/mc3/conda-bld/threadpoolctl_1629802263681/work
88 | timm==0.5.4
89 | torch==1.7.1
90 | torchaudio==0.7.0a0+a853dff
91 | torchvision==0.8.2
92 | tqdm @ file:///opt/conda/conda-bld/tqdm_1647339053476/work
93 | traitlets==4.3.3
94 | typing-extensions @ file:///tmp/build/80754af9/typing_extensions_1631814937681/work
95 | urllib3==1.26.12
96 | wcwidth==0.2.5
97 | Werkzeug==2.0.3
98 | wrapt==1.12.1
99 | yapf==0.32.0
100 | zipp==3.6.0
101 |
--------------------------------------------------------------------------------
/test_RGBT.py:
--------------------------------------------------------------------------------
1 | import torch as t
2 | from torch import nn
3 | # from RGBT_dataprocessing_CNet import testData1,testData2,testData3
4 | from train_test1.RGBT_dataprocessing_CNet import testData1
5 | from torch.utils.data import DataLoader
6 | import os
7 | from torch.autograd import Variable
8 | import matplotlib.pyplot as plt
9 | import torch
10 | from FHENet import Mirror_model
11 | import numpy as np
12 | from datetime import datetime
13 |
14 | test_dataloader1 = DataLoader(testData1, batch_size=1, shuffle=False, num_workers=4)
15 | net = Mirror_model()
16 |
17 | net.load_state_dict(t.load('../Pth/FHENet_RGB_D_SOD_rail.pth'))
18 |
19 | a = '../Documents/RGBT-EvaluationTools/SalMap/'
20 | b = 'Net_SOD_rail'
21 | c = ''
22 | path = a + b + c
23 |
24 | path1 = path
25 | isExist = os.path.exists(path1)
26 | if not isExist:
27 | os.makedirs(path1)
28 | else:
29 |     print('path1 already exists')
30 |
31 | with torch.no_grad():
32 | net.eval()
33 | net.cuda()
34 | test_mae = 0
35 |
36 | for i, sample in enumerate(test_dataloader1):
37 | image = sample['RGB']
38 | depth = sample['depth']
39 | label = sample['label']
40 | name = sample['name']
41 | name = "".join(name)
42 |
43 | image = Variable(image).cuda()
44 | depth = Variable(depth).cuda()
45 | label = Variable(label).cuda()
46 |
47 |
48 | out1 = net(image, depth)
49 | out = torch.sigmoid(out1[0])
50 |
51 | out_img = out.cpu().detach().numpy()
52 | out_img = out_img.squeeze()
53 |
54 |         plt.imsave(os.path.join(path1, name + '.png'), arr=out_img, cmap='gray')
55 |         print(os.path.join(path1, name + '.png'))
56 |
57 |
58 |
59 |
60 |
61 |
--------------------------------------------------------------------------------