├── .idea
│   ├── .gitignore
│   ├── Auto-fusion.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── data.py
├── genotypes.py
├── images
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   └── 4.png
├── metric.py
├── model_resnet50.py
├── model_vgg16.py
├── operations.py
├── test.py
└── utils.py
/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Auto-MSFNet
2 |
3 | This is a PyTorch implementation of the ACM MM 2021 paper "Auto-MSFNet: Search Multi-scale Fusion Network for Salient Object Detection". The paper can be downloaded from [this link (extraction code: a28y)](https://pan.baidu.com/s/14mLCSHXtnuXkjCmuu_9g8A).
4 |
5 | ## Introduction
6 |
7 | 
8 |
9 | Multi-scale feature fusion plays a critical role in salient object detection, and most existing methods achieve remarkable performance by exploiting various fusion strategies. However, an elegant fusion framework requires expert knowledge and experience, relying heavily on laborious trial and error. In this paper, we propose a multi-scale feature fusion framework based on Neural Architecture Search (NAS), named Auto-MSFNet. First, we design a novel search cell, named FusionCell, to automatically decide how multi-scale features are aggregated. Rather than searching for a single repeatable cell to stack, we allow different FusionCells to flexibly integrate multi-level features. Simultaneously, considering that features generated by CNNs are naturally organized spatially and channel-wise, we propose a new search space for efficiently focusing on the most relevant information. The search space mitigates the incomplete object structures or over-predicted foreground regions caused by progressive fusion. Second, we propose a progressive polishing loss that obtains finer boundaries by penalizing the misalignment of salient object boundaries. Extensive experiments on five benchmark datasets demonstrate the effectiveness of the proposed method, which achieves state-of-the-art performance on four evaluation metrics.
10 |
11 | ## The searched FusionCell structure
12 |
13 | 
14 |
15 | ## Prerequisites
16 |
17 | - Python 3.6
18 | - PyTorch 1.6.0
19 |
20 | ## Usage
21 |
22 | ### 1. Download the datasets
23 |
24 | - [PASCAL-S](http://cbi.gatech.edu/salobj/)
25 | - [ECSSD](http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html)
26 | - [HKU-IS](https://i.cs.hku.hk/~gbli/deep_saliency.html)
27 | - [DUT-OMRON](http://saliencydetection.net/dut-omron/)
28 | - [DUTS](http://saliencydetection.net/duts/)
29 |
30 | ### 2. Saliency maps & Trained model
31 |
32 | - saliency maps: ResNet-50 ([Google](https://drive.google.com/file/d/1sX5NBhiFBj5SMgGvBYhPCTUsHi8XxwzA/view?usp=sharing) | [Baidu, extraction code: 3d22](https://pan.baidu.com/s/1eV8t5pDYnahIIV1gzhgEjg)); VGG-16 ([Google](https://drive.google.com/file/d/1N8VqS0fGzmb81f4nG66ot7sNMsIKCUkh/view?usp=sharing) | [Baidu, extraction code: wv61](https://pan.baidu.com/s/1ErQz8m4GH3Q4D6aDoaW14A))
33 | - trained model: ResNet-50 ([Google](https://drive.google.com/file/d/1TkJOvCNBuOjydzW-ceJBfkyCutFbYbrc/view?usp=sharing) | [Baidu, extraction code: yfh8](https://pan.baidu.com/s/12S43JG4bce4cgN47D5rUnw)); VGG-16 ([Google](https://drive.google.com/file/d/1bZkU1nid_sQ8_eydRfCZOD5OCj-Vwiqk/view?usp=sharing) | [Baidu, extraction code: qhqs](https://pan.baidu.com/s/1pONp-yFTdLkb0KrbjvWIcQ))
34 | - Our quantitative comparisons
35 |
36 | 
37 |
38 | - Our qualitative comparisons
39 |
40 | 
41 |
42 | ### 3. Testing and Evaluation
43 |
44 | We use [this Python toolkit](https://github.com/lartpang/PySODEvalToolkit) to evaluate the saliency maps.
45 |
46 | First, download PyCharm (or any Python environment) and the trained checkpoint (ResNet-50 or VGG-16 based).
47 |
48 | Second, change some paths in test.py (*e.g.*, the dataset path), then
49 |
50 | ```bash
51 | python test.py
52 | ```
53 |
54 | ### 4. If you find this work helpful, please cite
55 | ```bibtex
56 | @InProceedings{Miao_2021_ACM_MM,
57 | author = {Miao {Zhang} and Tingwei {Liu} and Yongri {Piao} and ShunYu {Yao} and Huchuan {Lu}},
58 | title = {Auto-MSFNet: Search Multi-scale Fusion Network for Salient Object Detection},
59 | booktitle = {ACM Multimedia Conference 2021},
60 | year = {2021}
61 | }
62 | ```
63 | ### 5. For any questions, please contact tingwei@mail.dlut.edu.cn
64 |
--------------------------------------------------------------------------------
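For readers who prefer to script the test step rather than run it from an IDE, the following is a minimal sketch of what test.py boils down to, pieced together from the modules in this repository. The paths, the checkpoint filename, and the assumption that the first returned map is the final prediction are all placeholders to adapt:

```python
import torch
from torch.utils.data import DataLoader

from data import MyTestData
from genotypes import fusion_genotype_resnet50
from model_resnet50 import Network_Resnet50

# Hypothetical paths -- point these at your dataset and downloaded checkpoint.
img_root = "/path/to/DUTS-TE/images"
gt_root = "/path/to/DUTS-TE/masks"
ckpt_path = "/path/to/auto_msfnet_resnet50.pth"

net = Network_Resnet50(fusion_genotype_resnet50).cuda().eval()
net.load_state_dict(torch.load(ckpt_path, map_location="cuda"))

loader = DataLoader(MyTestData(img_root, gt_root, test_size=256), batch_size=1)
with torch.no_grad():
    for img, gt, name in loader:
        # The network returns three side outputs; treating the first as the
        # final saliency map is an assumption based on model_resnet50.py.
        pred, _, _ = net(img.cuda())
        pred = torch.sigmoid(pred).squeeze().cpu().numpy()
```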
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | from PIL import Image
4 | import torch
5 | import torch.utils.data as data
6 | import torchvision.transforms as transforms
7 | import numpy as np
8 | from torch.utils.data import Dataset
9 |
10 | class MyTestData(Dataset):
11 | """
12 | load images for testing
13 | root: directory/to/images/
14 | structure:
15 | - root
16 | - images
17 | - images (images here)
18 | - masks (ground truth)
19 | """
20 |
21 | mean = np.array([0.485, 0.456, 0.406])
22 | std = np.array([0.229, 0.224, 0.225])
23 |
24 | def __init__(self,img_root,gt_root,test_size,transform=True):
25 | super(MyTestData, self).__init__()
26 | self._transform = transform
27 | self.test_size = test_size
28 | img_root = img_root
29 | gt_root = gt_root
30 |
31 | file_names = os.listdir(img_root)
32 | self.img_names = []
33 | self.gt_names = []
34 | self.names = []
35 | for i, name in enumerate(file_names):
36 | if not name.endswith('.jpg'):
37 | continue
38 | self.img_names.append(
39 | os.path.join(img_root, name[:-4] + '.jpg')
40 | )
41 | self.gt_names.append(
42 | os.path.join(gt_root,name[:-4] + '.png')
43 | )
44 | self.names.append(name[:-4])
45 |
46 | def __len__(self):
47 | return len(self.img_names)
48 |
49 | def __getitem__(self, index):
50 | gt_file = self.gt_names[index]
51 | gt = Image.open(gt_file).convert('L')
52 | gt = np.array(gt, dtype=np.int32)
53 | gt = gt / (gt.max() + 1e-8)
54 | gt = np.where(gt > 0.5, 1, 0)
55 | img_file = self.img_names[index]
56 | img = cv2.imread(img_file)[:,:,::-1].astype(np.float32)
57 | img = cv2.resize(img, dsize=(self.test_size, self.test_size), interpolation=cv2.INTER_LINEAR)
58 | name = img_file.split('/')[-1].split('.')[0]
59 |
60 | if self._transform:
61 | try:
62 | img, gt = self.transform(img,gt)
63 | except ValueError:
64 | print(name)
65 | return img, gt,name
66 | else:
67 | return img, gt,name
68 |
69 | def transform(self, img,gt):
70 | img = img.astype(np.float64) / 255
71 | img -= self.mean
72 | img /= self.std
73 | img = img.transpose(2, 0, 1)
74 | img = torch.from_numpy(img).float()
75 | return img,gt
76 |
--------------------------------------------------------------------------------
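A quick sanity check of the loader above (the directory paths are placeholders; images are expected as `<name>.jpg` with matching `<name>.png` masks):

```python
from torch.utils.data import DataLoader
from data import MyTestData

dataset = MyTestData("/path/to/images", "/path/to/masks", test_size=256)
loader = DataLoader(dataset, batch_size=1, shuffle=False)
for img, gt, name in loader:
    # img is resized to 256x256 and normalized to ImageNet statistics;
    # gt stays at its original resolution as a binary {0, 1} array.
    print(name[0], img.shape, gt.shape)
    break
```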
/genotypes.py:
--------------------------------------------------------------------------------
1 | from collections import namedtuple
2 |
3 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')
4 | FeafusionGenotype = namedtuple('FeafusionGenotype', ['normal','inside'])
5 |
6 | fusion_genotype_resnet50= FeafusionGenotype(normal=[('SpatialAttention', 3, 0), ('dil_conv_3x3_dil4', 0, 0), ('dil_conv_3x3_dil4', 1, 0), ('dil_conv_3x3_dil4', 2, 0), ('dil_conv_3x3', 0, 1), ('sep_conv_3x3', 1, 1), ('sep_conv_3x3', 2, 1), ('SpatialAttention', 3, 1), ('ChannelAttention', 1, 2), ('dil_conv_3x3', 2, 2), ('sep_conv_3x3_rp2', 0, 2), ('SpatialAttention', 3, 2)], inside=[('SpatialAttention', 0), ('ChannelAttention', 1), ('sep_conv_3x3', 2), ('dil_conv_3x3_rp2', 3), ('sep_conv_3x3', 4), ('sep_conv_3x3_rp2', 0), ('SpatialAttention', 1), ('ChannelAttention', 2), ('dil_conv_3x3', 3), ('sep_conv_3x3_rp2', 4), ('dil_conv_3x3', 5), ('skip_connect', 6), ('dil_conv_3x3', 0), ('dil_conv_3x3_dil4', 1), ('dil_conv_3x3', 2), ('ChannelAttention', 3), ('skip_connect', 4), ('dil_conv_3x3_rp2', 5), ('ChannelAttention', 6)])
7 | fusion_genotype_vgg16 = FeafusionGenotype(normal=[('dil_conv_3x3', 0, 0), ('dil_conv_3x3', 2, 0), ('none', 3, 0), ('none', 1, 0), ('dil_conv_3x3_dil4', 0, 1), ('dil_conv_3x3', 1, 1), ('SpatialAttention', 2, 1), ('dil_conv_3x3_rp2', 3, 1), ('dil_conv_3x3', 0, 2), ('dil_conv_3x3', 3, 2), ('dil_conv_3x3', 2, 2), ('SpatialAttention', 1, 2)], inside=[('dil_conv_3x3_rp2', 0), ('dil_conv_3x3_rp2', 1), ('ChannelAttention', 2), ('sep_conv_3x3_rp2', 3), ('SpatialAttention', 4), ('dil_conv_3x3_rp2', 0), ('sep_conv_3x3_rp2', 1), ('dil_conv_3x3', 2), ('dil_conv_3x3_rp2', 3), ('dil_conv_3x3_rp2', 4), ('dil_conv_3x3_rp2', 5), ('SpatialAttention', 6), ('dil_conv_3x3_rp2', 0), ('sep_conv_3x3_rp2', 1), ('SpatialAttention', 2), ('dil_conv_3x3', 3), ('dil_conv_3x3_rp2', 4), ('dil_conv_3x3_rp2', 5), ('SpatialAttention', 6)])
8 |
--------------------------------------------------------------------------------
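To make the encoding concrete: each entry of `normal` reads as a triple `(op_name, input_feature_index, fusion_cell_index)`, and each entry of `inside` as a pair `(op_name, edge_index)`; this interpretation follows how `FeatureFusion` in model_resnet50.py consumes the tuples. A small sketch that prints the searched outer wiring:

```python
from genotypes import fusion_genotype_resnet50

# FeatureFusion sorts the outer genotype by (cell, input) before building ops.
for op_name, feat_idx, cell_idx in sorted(fusion_genotype_resnet50.normal,
                                          key=lambda g: (g[2], g[1])):
    print(f"FusionCell {cell_idx}: input feature {feat_idx} -> {op_name}")
```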
/images/1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiuTingWed/Auto-MSFNet/0dfbb5598492e06c404e3e53f534ac669d7b185c/images/1.png
--------------------------------------------------------------------------------
/images/2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiuTingWed/Auto-MSFNet/0dfbb5598492e06c404e3e53f534ac669d7b185c/images/2.png
--------------------------------------------------------------------------------
/images/3.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiuTingWed/Auto-MSFNet/0dfbb5598492e06c404e3e53f534ac669d7b185c/images/3.png
--------------------------------------------------------------------------------
/images/4.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LiuTingWed/Auto-MSFNet/0dfbb5598492e06c404e3e53f534ac669d7b185c/images/4.png
--------------------------------------------------------------------------------
/metric.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Time : 2020/7/7
3 | # @Author : Lart Pang
4 | # @Email : lartpang@163.com
5 | # @File : metric.py
6 | # @Project : HDFNet
7 | # @GitHub : https://github.com/lartpang
8 |
9 | import numpy as np
10 | from PIL import Image
11 | from scipy.ndimage import center_of_mass, convolve, distance_transform_edt as bwdist
12 |
13 |
14 | class CalFM(object):
15 | # Fmeasure(maxFm, meanFm)---Frequency-tuned salient region detection(CVPR 2009)
16 | def __init__(self, num, thds=255):
17 | self.precision = np.zeros((num, thds))
18 | self.recall = np.zeros((num, thds))
19 | self.meanF = np.zeros(num)
20 | self.idx = 0
21 | self.num = num
22 |
23 | def update(self, pred, gt):
24 | if gt.max() != 0:
25 | prediction, recall, mfmeasure = self.cal(pred, gt)
26 | self.precision[self.idx, :] = prediction
27 | self.recall[self.idx, :] = recall
28 | self.meanF[self.idx] = mfmeasure
29 | self.idx += 1
30 |
31 | def cal(self, pred, gt):
32 | ########################meanF##############################
33 | th = 2 * pred.mean()
34 | if th > 1:
35 | th = 1
36 | binary = np.zeros_like(pred)
37 | binary[pred >= th] = 1
38 | hard_gt = np.zeros_like(gt)
39 | hard_gt[gt > 0.5] = 1
40 | tp = (binary * hard_gt).sum()
41 | if tp == 0:
42 | mfmeasure = 0
43 | else:
44 | pre = tp / binary.sum()
45 | rec = tp / hard_gt.sum()
46 | mfmeasure = 1.3 * pre * rec / (0.3 * pre + rec)
47 |
48 | ########################maxF##############################
49 | pred = np.uint8(pred * 255)
50 | target = pred[gt > 0.5]
51 | nontarget = pred[gt <= 0.5]
52 | targetHist, _ = np.histogram(target, bins=range(256))
53 | nontargetHist, _ = np.histogram(nontarget, bins=range(256))
54 | targetHist = np.cumsum(np.flip(targetHist), axis=0)
55 | nontargetHist = np.cumsum(np.flip(nontargetHist), axis=0)
56 | precision = targetHist / (targetHist + nontargetHist + 1e-8)
57 | recall = targetHist / np.sum(gt)
58 | return precision, recall, mfmeasure
59 |
60 | def show(self):
61 | assert self.num == self.idx, f"{self.num}, {self.idx}"
62 | precision = self.precision.mean(axis=0)
63 | recall = self.recall.mean(axis=0)
64 | fmeasure = 1.3 * precision * recall / (0.3 * precision + recall + 1e-8)
65 | mmfmeasure = self.meanF.mean()
66 | return fmeasure, fmeasure.max(), mmfmeasure, precision, recall
67 |
68 |
69 | class CalMAE(object):
70 | # mean absolute error
71 | def __init__(self, num):
72 | # self.prediction = []
73 | self.prediction = np.zeros(num)
74 | self.idx = 0
75 | self.num = num
76 |
77 | def update(self, pred, gt):
78 | self.prediction[self.idx] = self.cal(pred, gt)
79 | self.idx += 1
80 |
81 | def cal(self, pred, gt):
82 | return np.mean(np.abs(pred - gt))
83 |
84 | def show(self):
85 | assert self.num == self.idx, f"{self.num}, {self.idx}"
86 | return self.prediction.mean()
87 |
88 |
89 | class CalSM(object):
90 | # Structure-measure: A new way to evaluate foreground maps (ICCV 2017)
91 | def __init__(self, num, alpha=0.5):
92 | self.prediction = np.zeros(num)
93 | self.alpha = alpha
94 | self.idx = 0
95 | self.num = num
96 |
97 | def update(self, pred, gt):
98 | gt = gt > 0.5
99 | self.prediction[self.idx] = self.cal(pred, gt)
100 | self.idx += 1
101 |
102 | def show(self):
103 | assert self.num == self.idx, f"{self.num}, {self.idx}"
104 | return self.prediction.mean()
105 |
106 | def cal(self, pred, gt):
107 | y = np.mean(gt)
108 | if y == 0:
109 | score = 1 - np.mean(pred)
110 | elif y == 1:
111 | score = np.mean(pred)
112 | else:
113 | score = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt)
114 | return score
115 |
116 | def object(self, pred, gt):
117 | fg = pred * gt
118 | bg = (1 - pred) * (1 - gt)
119 |
120 | u = np.mean(gt)
121 | return u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, np.logical_not(gt))
122 |
123 | def s_object(self, in1, in2):
124 | x = np.mean(in1[in2])
125 | sigma_x = np.std(in1[in2])
126 | return 2 * x / (pow(x, 2) + 1 + sigma_x + 1e-8)
127 |
128 | def region(self, pred, gt):
129 | [y, x] = center_of_mass(gt)
130 | y = int(round(y)) + 1
131 | x = int(round(x)) + 1
132 | [gt1, gt2, gt3, gt4, w1, w2, w3, w4] = self.divideGT(gt, x, y)
133 | pred1, pred2, pred3, pred4 = self.dividePred(pred, x, y)
134 |
135 | score1 = self.ssim(pred1, gt1)
136 | score2 = self.ssim(pred2, gt2)
137 | score3 = self.ssim(pred3, gt3)
138 | score4 = self.ssim(pred4, gt4)
139 |
140 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4
141 |
142 | def divideGT(self, gt, x, y):
143 | h, w = gt.shape
144 | area = h * w
145 | LT = gt[0:y, 0:x]
146 | RT = gt[0:y, x:w]
147 | LB = gt[y:h, 0:x]
148 | RB = gt[y:h, x:w]
149 |
150 | w1 = x * y / area
151 | w2 = y * (w - x) / area
152 | w3 = (h - y) * x / area
153 | w4 = (h - y) * (w - x) / area
154 |
155 | return LT, RT, LB, RB, w1, w2, w3, w4
156 |
157 | def dividePred(self, pred, x, y):
158 | h, w = pred.shape
159 | LT = pred[0:y, 0:x]
160 | RT = pred[0:y, x:w]
161 | LB = pred[y:h, 0:x]
162 | RB = pred[y:h, x:w]
163 |
164 | return LT, RT, LB, RB
165 |
166 | def ssim(self, in1, in2):
167 | in2 = np.float32(in2)
168 | h, w = in1.shape
169 | N = h * w
170 |
171 | x = np.mean(in1)
172 | y = np.mean(in2)
173 | sigma_x = np.var(in1)
174 | sigma_y = np.var(in2)
175 | sigma_xy = np.sum((in1 - x) * (in2 - y)) / (N - 1)
176 |
177 | alpha = 4 * x * y * sigma_xy
178 | beta = (x * x + y * y) * (sigma_x + sigma_y)
179 |
180 | if alpha != 0:
181 | score = alpha / (beta + 1e-8)
182 | elif alpha == 0 and beta == 0:
183 | score = 1
184 | else:
185 | score = 0
186 |
187 | return score
188 |
189 |
190 | class CalEM(object):
191 | # Enhanced-alignment Measure for Binary Foreground Map Evaluation (IJCAI 2018)
192 | def __init__(self, num):
193 | self.prediction = np.zeros(num)
194 | self.idx = 0
195 | self.num = num
196 |
197 | def update(self, pred, gt):
198 | self.prediction[self.idx] = self.cal(pred, gt)
199 | self.idx += 1
200 |
201 | def cal(self, pred, gt):
202 | th = 2 * pred.mean()
203 | if th > 1:
204 | th = 1
205 | FM = np.zeros(gt.shape)
206 | FM[pred >= th] = 1
207 | FM = np.array(FM, dtype=bool)
208 | GT = np.array(gt, dtype=bool)
209 | dFM = np.double(FM)
210 | if sum(sum(np.double(GT))) == 0:
211 | enhanced_matrix = 1.0 - dFM
212 | elif sum(sum(np.double(~GT))) == 0:
213 | enhanced_matrix = dFM
214 | else:
215 | dGT = np.double(GT)
216 | align_matrix = self.AlignmentTerm(dFM, dGT)
217 | enhanced_matrix = self.EnhancedAlignmentTerm(align_matrix)
218 | [w, h] = np.shape(GT)
219 | score = sum(sum(enhanced_matrix)) / (w * h - 1 + 1e-8)
220 | return score
221 |
222 | def AlignmentTerm(self, dFM, dGT):
223 | mu_FM = np.mean(dFM)
224 | mu_GT = np.mean(dGT)
225 | align_FM = dFM - mu_FM
226 | align_GT = dGT - mu_GT
227 | align_Matrix = 2.0 * (align_GT * align_FM) / (align_GT * align_GT + align_FM * align_FM + 1e-8)
228 | return align_Matrix
229 |
230 | def EnhancedAlignmentTerm(self, align_Matrix):
231 | enhanced = np.power(align_Matrix + 1, 2) / 4
232 | return enhanced
233 |
234 | def show(self):
235 | assert self.num == self.idx, f"{self.num}, {self.idx}"
236 | return self.prediction.mean()
237 |
238 |
239 | class CalWFM(object):
240 | def __init__(self, num, beta=1):
241 | self.scores_list = np.zeros(num)
242 | self.beta = beta
243 | self.eps = 1e-6
244 | self.idx = 0
245 | self.num = num
246 |
247 | def update(self, pred, gt):
248 | gt = gt > 0.5
249 | self.scores_list[self.idx] = 0 if gt.max() == 0 else self.cal(pred, gt)
250 | self.idx += 1
251 |
252 | def matlab_style_gauss2D(self, shape=(7, 7), sigma=5):
253 | """
254 | 2D gaussian mask - should give the same result as MATLAB's
255 | fspecial('gaussian',[shape],[sigma])
256 | """
257 | m, n = [(ss - 1.0) / 2.0 for ss in shape]
258 | y, x = np.ogrid[-m : m + 1, -n : n + 1]
259 | h = np.exp(-(x * x + y * y) / (2.0 * sigma * sigma))
260 | h[h < np.finfo(h.dtype).eps * h.max()] = 0
261 | sumh = h.sum()
262 | if sumh != 0:
263 | h /= sumh
264 | return h
265 |
266 | def cal(self, pred, gt):
267 | # [Dst,IDXT] = bwdist(dGT);
268 | Dst, Idxt = bwdist(gt == 0, return_indices=True)
269 |
270 | # %Pixel dependency
271 | # E = abs(FG-dGT);
272 | E = np.abs(pred - gt)
273 | # Et = E;
274 | # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region
275 | Et = np.copy(E)
276 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]]
277 |
278 | # K = fspecial('gaussian',7,5);
279 | # EA = imfilter(Et,K);
280 |         # MIN_E_EA(GT & EA<E) = EA(GT & EA<E);
281 |         # NOTE: the original lines from here down to the asserts in
282 |         # CalTotalMetric.update were lost in extraction; they are restored
283 |         # below from the upstream HDFNet metric.py (an assumption).
284 |         K = self.matlab_style_gauss2D((7, 7), sigma=5)
285 |         EA = convolve(Et, weights=K, mode="constant", cval=0)
286 |         MIN_E_EA = np.where(gt & (EA < E), EA, E)
287 |
288 |         # %Pixel importance
289 |         # B = ones(size(GT));
290 |         # B(~GT) = 2-1*exp(log(1-0.5)/5.*Dst(~GT));
291 |         B = np.where(gt == 0, 2 - np.exp(np.log(0.5) / 5 * Dst), np.ones_like(gt))
292 |         Ew = MIN_E_EA * B
293 |
294 |         TPw = np.sum(gt) - np.sum(Ew[gt == 1])
295 |         FPw = np.sum(Ew[gt == 0])
296 |
297 |         R = 1 - np.mean(Ew[gt == 1])  # weighted recall
298 |         P = TPw / (self.eps + TPw + FPw)  # weighted precision
299 |
300 |         # Q = (1+Beta^2)*(R*P)./(eps+R+(Beta.*P));
301 |         Q = (1 + self.beta) * R * P / (self.eps + R + self.beta * P)
302 |
303 |         return Q
304 |
305 |     def show(self):
306 |         assert self.num == self.idx, f"{self.num}, {self.idx}"
307 |         return self.scores_list.mean()
308 |
309 |
310 | class CalTotalMetric(object):
311 |     __slots__ = ["cal_mae", "cal_fm", "cal_sm", "cal_em", "cal_wfm"]
312 |
313 |     def __init__(self, num, beta_for_wfm=1):
314 |         self.cal_mae = CalMAE(num=num)
315 |         self.cal_fm = CalFM(num=num)
316 |         self.cal_sm = CalSM(num=num)
317 |         self.cal_em = CalEM(num=num)
318 |         self.cal_wfm = CalWFM(num=num, beta=beta_for_wfm)
319 |
320 |     def update(self, pred, gt):
321 |         assert pred.ndim == gt.ndim and pred.shape == gt.shape
322 |         assert pred.max() <= 1 and pred.min() >= 0
323 | assert gt.max() <= 1 and gt.min() >= 0
324 |
325 | self.cal_mae.update(pred, gt)
326 | self.cal_fm.update(pred, gt)
327 | self.cal_sm.update(pred, gt)
328 | self.cal_em.update(pred, gt)
329 | self.cal_wfm.update(pred, gt)
330 |
331 | def show(self):
332 | MAE = self.cal_mae.show()
333 | _, Maxf, Meanf, _, _, = self.cal_fm.show()
334 | SM = self.cal_sm.show()
335 | EM = self.cal_em.show()
336 | WFM = self.cal_wfm.show()
337 | results = {
338 | "MaxF": Maxf,
339 | "MeanF": Meanf,
340 | "WFM": WFM,
341 | "MAE": MAE,
342 | "SM": SM,
343 | "EM": EM,
344 | }
345 | return results
346 |
347 |
348 | if __name__ == "__main__":
349 | pred = Image
350 |
--------------------------------------------------------------------------------
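A minimal end-to-end check of the metric classes (random maps, so the scores themselves are meaningless; this also assumes the `CalTotalMetric` interface restored above):

```python
import numpy as np
from metric import CalTotalMetric

rng = np.random.default_rng(0)
preds = [rng.random((64, 64)) for _ in range(4)]
gts = [(rng.random((64, 64)) > 0.5).astype(np.float64) for _ in range(4)]

cal = CalTotalMetric(num=len(preds))
for pred, gt in zip(preds, gts):
    cal.update(pred, gt)  # pred and gt must both lie in [0, 1]
print(cal.show())  # dict with MaxF, MeanF, WFM, MAE, SM, EM
```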
/model_resnet50.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from operations import *
5 | from torch.autograd import Variable
6 | from utils import drop_path
7 | import genotypes
8 |
9 |
10 |
11 | class Bottleneck(nn.Module):
12 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
13 | super(Bottleneck, self).__init__()
14 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
15 | self.bn1 = nn.BatchNorm2d(planes)
16 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=(3 * dilation - 1) // 2,
17 | bias=False,
18 | dilation=dilation)
19 | self.bn2 = nn.BatchNorm2d(planes)
20 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
21 | self.bn3 = nn.BatchNorm2d(planes * 4)
22 | self.downsample = downsample
23 |
24 | def forward(self, x):
25 | residual = x
26 | out = F.relu(self.bn1(self.conv1(x)), inplace=True)
27 | out = F.relu(self.bn2(self.conv2(out)), inplace=True)
28 | out = self.bn3(self.conv3(out))
29 | if self.downsample is not None:
30 | residual = self.downsample(x)
31 | return F.relu(out + residual, inplace=True)
32 |
33 |
34 | class ResNet(nn.Module):
35 | def __init__(self):
36 | super(ResNet, self).__init__()
37 | self.inplanes = 64
38 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
39 | self.bn1 = nn.BatchNorm2d(64)
40 | self.layer1 = self.make_layer(64, 3, stride=1, dilation=1)
41 | self.layer2 = self.make_layer(128, 4, stride=2, dilation=1)
42 | self.layer3 = self.make_layer(256, 6, stride=2, dilation=1)
43 | self.layer4 = self.make_layer(512, 3, stride=2, dilation=1)
44 |
45 | out_channel = 128
46 | self.conv5 = nn.Conv2d(2048, out_channel, kernel_size=3, stride=1, padding=1)
47 | self.bn5 = nn.BatchNorm2d(out_channel)
48 | self.conv4 = nn.Conv2d(1024, out_channel, kernel_size=3, stride=1, padding=1)
49 | self.bn4 = nn.BatchNorm2d(out_channel)
50 | self.conv3 = nn.Conv2d(512, out_channel, kernel_size=3, stride=1, padding=1)
51 | self.bn3 = nn.BatchNorm2d(out_channel)
52 | self.conv2 = nn.Conv2d(256, out_channel, kernel_size=3, stride=1, padding=1)
53 | self.bn2 = nn.BatchNorm2d(out_channel)
54 |
55 |
56 | def make_layer(self, planes, blocks, stride, dilation):
57 | downsample = None
58 | if stride != 1 or self.inplanes != planes * 4:
59 | downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * 4, kernel_size=1, stride=stride, bias=False),
60 | nn.BatchNorm2d(planes * 4))
61 |
62 | layers = [Bottleneck(self.inplanes, planes, stride, downsample, dilation=dilation)]
63 | self.inplanes = planes * 4
64 | for _ in range(1, blocks):
65 | layers.append(Bottleneck(self.inplanes, planes, dilation=dilation))
66 | return nn.Sequential(*layers)
67 |
68 | def forward(self, x):
69 | out1 = F.relu(self.bn1(self.conv1(x)), inplace=True)
70 | out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1)
71 | out2 = self.layer1(out1)
72 | out3 = self.layer2(out2)
73 | out4 = self.layer3(out3)
74 | out5 = self.layer4(out4)
75 |
76 | out2 = F.relu(self.bn2(self.conv2(out2)), inplace=True)
77 | out3 = F.relu(self.bn3(self.conv3(out3)), inplace=True)
78 | out4 = F.relu(self.bn4(self.conv4(out4)), inplace=True)
79 | out5_ = F.relu(self.bn5(self.conv5(out5)), inplace=True)
80 |
81 |
82 | return out5_, out2, out3, out4
83 |
84 | class Featurefusioncell43(nn.Module):
85 | def __init__(self, standardShape, channel, op):
86 | super(Featurefusioncell43, self).__init__()
87 | self.standardShape = standardShape
88 | self._ops = op
89 | self.conv11 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
90 | self.bn11 = nn.BatchNorm2d(channel)
91 |
92 | def forward(self, fea, lowfeature):
93 | f2 = fea[0]
94 | levelfusion = fea[2]
95 | f3 = fea[1]
96 | f4 = fea[3]
97 |
98 | assert levelfusion.size()[3] == self.standardShape
99 | if lowfeature.size()[2:] != self.standardShape:
100 | lowfeature = F.interpolate(lowfeature, self.standardShape, mode='bilinear')
101 | if f2.size()[3] != self.standardShape and f2.size() != torch.Size([]):
102 | f2 = F.interpolate(f2, self.standardShape, mode='bilinear')
103 | if f3.size()[3] != self.standardShape and f3.size() != torch.Size([]):
104 | f3 = F.interpolate(f3, self.standardShape, mode='bilinear')
105 | if f4.size()[3] != self.standardShape:
106 | f4 = F.interpolate(f4, self.standardShape, mode='bilinear')
107 |
108 | z1 = f2
109 | z2 = f3
110 | z3 = levelfusion
111 | z4 = f4
112 | pre_note = [lowfeature]
113 | states = [z1, z2, z3, z4]
114 | offset = 0
115 | for i in range(4):
116 | if i == 0:
117 | s0 = states[i]
118 | s1 = self._ops[offset + i](pre_note[i])
119 | add = s0 + s1
120 | pre_note.append(add)
121 | else:
122 | p1 = states[i]
123 | s0 = self._ops[offset + i](pre_note[i])
124 | s1 = self._ops[offset + i + 1](states[i])
125 | add = s0 + s1 + p1
126 | pre_note.append(add)
127 | offset += 1
128 |
129 | out = 0
130 | for i in range(1, 5):
131 | out += pre_note[i]
132 | out = F.relu(self.bn11(self.conv11(out)), inplace=True)
133 |
134 | return out
135 |
136 |
137 | class Featurefusioncell32(nn.Module):
138 | def __init__(self, standardShape, channel, op):
139 | super(Featurefusioncell32, self).__init__()
140 | self.standardShape = standardShape
141 | self._ops = op
142 | self.conv11 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
143 | self.bn11 = nn.BatchNorm2d(channel)
144 |
145 | def forward(self, fea, lowfeature):
146 | f2 = fea[0]
147 | levelfusion = fea[3]
148 | f3 = fea[1]
149 | f4 = fea[2]
150 |
151 | assert levelfusion.size()[3] == self.standardShape
152 | if lowfeature.size()[2:] != self.standardShape:
153 | lowfeature = F.interpolate(lowfeature, self.standardShape, mode='bilinear')
154 | if f2.size()[3] != self.standardShape and f2.size() != torch.Size([]):
155 | f2 = F.interpolate(f2, self.standardShape, mode='bilinear')
156 | if f3.size()[3] != self.standardShape and f3.size() != torch.Size([]):
157 | f3 = F.interpolate(f3, self.standardShape, mode='bilinear')
158 | if f4.size()[3] != self.standardShape:
159 | f4 = F.interpolate(f4, self.standardShape, mode='bilinear')
160 |
161 | z1 = f2
162 | z2 = f3
163 | z3 = f4
164 | z4 = levelfusion
165 |
166 | pre_note = [lowfeature]
167 | states = [z1, z2, z3, z4]
168 | offset = 0
169 | for i in range(4):
170 | if i == 0:
171 | s0 = states[i]
172 | s1 = self._ops[offset + i](pre_note[i])
173 | add = s0 + s1
174 | pre_note.append(add)
175 | else:
176 | p1 = states[i]
177 | s0 = self._ops[offset + i](pre_note[i])
178 | s1 = self._ops[offset + i + 1](states[i])
179 | add = s0 + s1 + p1
180 | pre_note.append(add)
181 | offset += 1
182 |
183 | out = 0
184 | for i in range(1, 5):
185 | out += pre_note[i]
186 | out = F.relu(self.bn11(self.conv11(out)), inplace=True)
187 | return out
188 |
189 |
190 | class Featurefusioncell54(nn.Module):
191 | def __init__(self, standardShape, channel, op):
192 | super(Featurefusioncell54, self).__init__()
193 | self.standardShape = standardShape
194 | self._ops = op
195 | self.conv11 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
196 | self.bn11 = nn.BatchNorm2d(channel)
197 |
198 | def forward(self, fea):
199 | lowfeature = fea[0]
200 | levelfusion = fea[1]
201 | f2 = fea[2]
202 | f3 = fea[3]
203 |
204 | # if levelfusion is not None:
205 | assert levelfusion.size()[3] == self.standardShape
206 | if lowfeature.size()[2:] != self.standardShape:
207 | lowfeature = F.interpolate(lowfeature, self.standardShape, mode='bilinear')
208 | if f2.size()[3] != self.standardShape and f2.size() != torch.Size([]):
209 | f2 = F.interpolate(f2, self.standardShape, mode='bilinear')
210 | if f3.size()[3] != self.standardShape and f3.size() != torch.Size([]):
211 | f3 = F.interpolate(f3, self.standardShape, mode='bilinear')
212 |
213 | z1 = lowfeature
214 | z2 = levelfusion
215 | z3 = f2
216 | z4 = f3
217 |
218 | pre_note = [z1]
219 | states = [z2, z3, z4]
220 | offset = 0
221 | for i in range(4 - 1):
222 | if i == 0:
223 | s0 = states[i]
224 | s1 = self._ops[offset + i](pre_note[i])
225 | add = s0 + s1
226 | pre_note.append(add)
227 | else:
228 | p1 = states[i]
229 | s0 = self._ops[offset + i](pre_note[i])
230 | s1 = self._ops[offset + i + 1](states[i])
231 | add = s0 + s1 + p1
232 | pre_note.append(add)
233 | offset += 1
234 |
235 | out = 0
236 | for i in range(1, 4):
237 | out += pre_note[i]
238 | out = F.relu(self.bn11(self.conv11(out)), inplace=True)
239 |
240 | return out
241 |
242 |
243 | class FeatureFusion(nn.Module):
244 |
245 | def __init__(self, genotype_fusion, node=3):
246 | super(FeatureFusion, self).__init__()
247 |
248 | self._ops = nn.ModuleList()
249 | self.fnum = 4
250 | self.node = node
251 | C = 128
252 |
253 | genotype_outside = genotype_fusion.normal
254 | genotype_inside = genotype_fusion.inside
255 | new_genotype_outside = sorted(genotype_outside, key=lambda x: (x[2], x[1]))
256 | op_name, op_num, _ = zip(*new_genotype_outside)
257 |
258 | self.op_num = op_num
259 | offset = 0
260 | for i in range(self.node):
261 | for j in range(self.fnum):
262 | op = OPS[op_name[j + offset]](C, C, 1,False, True)
263 | self._ops += [op]
264 | offset += 4
265 |
266 | op_name_inside, op_num_inside = zip(*genotype_inside)
267 |
268 | k = [5, 7, 7]
269 | noteOper = []
270 | offset = 0
271 | for i in range(self.node):
272 | self._nodes = nn.ModuleList()
273 | for j in range(k[i]):
274 | op = OPS[op_name_inside[j + offset]](C, C, 1,False, True)
275 | self._nodes += [op]
276 | noteOper.append(self._nodes)
277 | offset += k[i]
278 |
279 | self.featurefusioncell54 = Featurefusioncell54(16, C, noteOper[0])
280 | self.featurefusioncell43 = Featurefusioncell43(32, C, noteOper[1])
281 | self.featurefusioncell32 = Featurefusioncell32(64, C, noteOper[2])
282 |
283 | def forward(self, out5, out2, out3, out4):
284 |
285 | states = [out5, out4, out3, out2]
286 |
287 | # Features transformed along each edge; cleared after every fusion node (4 per node, 12 edges in total)
288 | fea = []
289 | # Output tensors of each fusion node
290 | feaoutput = []
291 | offset = 0
292 | s = 0
293 |
294 | for i in range(self.node):
295 | for j, v in enumerate(self.op_num):
296 | if j == 4:
297 | break
298 | inputFea = states[v]
299 | x2 = self._ops[offset + j](inputFea)
300 | fea.append(x2)
301 |
302 | if i == 0:
303 | new_fea = self.featurefusioncell54(fea)
304 | feaoutput.append(new_fea)
305 | fea.clear()
306 |
307 | elif i == 1:
308 | new_fea = self.featurefusioncell43(fea, feaoutput[0])
309 | feaoutput.append(new_fea)
310 | fea.clear()
311 |
312 | elif i == 2:
313 | new_fea = self.featurefusioncell32(fea, feaoutput[1])
314 | feaoutput.append(new_fea)
315 | fea.clear()
316 |
317 | offset += 4
318 |
319 | return feaoutput[2],feaoutput[1],feaoutput[0]
320 |
321 | def _loss(self, input, target):
322 | logits = self(input)
323 | logits = logits.squeeze(1)
324 |
325 | return self._criterion(logits, target)
326 |
327 | class vgg16(nn.Module):
328 | def __init__(self,):
329 | super(vgg16, self).__init__()
330 |
331 | # original image's size = 256*256*3
332 |
333 | # conv1
334 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1)
335 | self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
336 | self.relu1_1 = nn.ReLU(inplace=True)
337 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
338 | self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
339 | self.relu1_2 = nn.ReLU(inplace=True)
340 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers
341 |
342 | # conv2
343 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
344 | self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
345 | self.relu2_1 = nn.ReLU(inplace=True)
346 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
347 | self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
348 | self.relu2_2 = nn.ReLU(inplace=True)
349 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers
350 |
351 | # conv3
352 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
353 | self.bn3_1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
354 | self.relu3_1 = nn.ReLU(inplace=True)
355 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
356 | self.bn3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
357 | self.relu3_2 = nn.ReLU(inplace=True)
358 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
359 | self.bn3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
360 | self.relu3_3 = nn.ReLU(inplace=True)
361 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 4 layers
362 |
363 | # conv4
364 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
365 | self.bn4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
366 | self.relu4_1 = nn.ReLU(inplace=True)
367 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
368 | self.bn4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
369 | self.relu4_2 = nn.ReLU(inplace=True)
370 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
371 | self.bn4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
372 | self.relu4_3 = nn.ReLU(inplace=True)
373 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 4 layers
374 |
375 | # conv5
376 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
377 | self.bn5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
378 | self.relu5_1 = nn.ReLU(inplace=True)
379 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
380 | self.bn5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
381 | self.relu5_2 = nn.ReLU(inplace=True)
382 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
383 | self.bn5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
384 | self.relu5_3 = nn.ReLU(inplace=True) # 1/32 4 layers
385 |
386 | out_channel = 128
387 | self.conv5 = nn.Conv2d(512, out_channel, kernel_size=3, stride=1, padding=1)
388 | self.bn5 = nn.BatchNorm2d(out_channel)
389 | self.conv4 = nn.Conv2d(512, out_channel, kernel_size=3, stride=1, padding=1)
390 | self.bn4 = nn.BatchNorm2d(out_channel)
391 | self.conv3 = nn.Conv2d(256, out_channel, kernel_size=3, stride=1, padding=1)
392 | self.bn3 = nn.BatchNorm2d(out_channel)
393 | self.conv2 = nn.Conv2d(128, out_channel, kernel_size=3, stride=1, padding=1)
394 | self.bn2 = nn.BatchNorm2d(out_channel)
395 |
396 | def forward(self, x):
397 | h = x
398 |
399 | h = self.relu1_1(self.bn1_1(self.conv1_1(h)))
400 | h = self.relu1_2(self.bn1_2(self.conv1_2(h)))
401 | h_nopool1 = h
402 | h = self.pool1(h)
403 | # pool1 = h
404 |
405 | h = self.relu2_1(self.bn2_1(self.conv2_1(h)))
406 | h = self.relu2_2(self.bn2_2(self.conv2_2(h)))
407 | h_nopool2 = h
408 | h = self.pool2(h)
409 | # pool2 = h
410 |
411 | h = self.relu3_1(self.bn3_1(self.conv3_1(h)))
412 | h = self.relu3_2(self.bn3_2(self.conv3_2(h)))
413 | h = self.relu3_3(self.bn3_3(self.conv3_3(h)))
414 | h_nopool3 = h
415 | h = self.pool3(h)
416 | # pool3 = h
417 |
418 | h = self.relu4_1(self.bn4_1(self.conv4_1(h)))
419 | h = self.relu4_2(self.bn4_2(self.conv4_2(h)))
420 | h = self.relu4_3(self.bn4_3(self.conv4_3(h)))
421 | h_nopool4 = h
422 | h = self.pool4(h)
423 |
424 | h = self.relu5_1(self.bn5_1(self.conv5_1(h)))
425 | h = self.relu5_2(self.bn5_2(self.conv5_2(h)))
426 | h = self.relu5_3(self.bn5_3(self.conv5_3(h)))
427 |
428 | out2 = F.relu(self.bn2(self.conv2(h_nopool2)), inplace=True)
429 | out3 = F.relu(self.bn3(self.conv3(h_nopool3)), inplace=True)
430 | out4 = F.relu(self.bn4(self.conv4(h_nopool4)), inplace=True)
431 | out5_ = F.relu(self.bn5(self.conv5(h)), inplace=True)
432 |
433 | return out5_, out2, out3, out4
434 |
435 | class Network_Resnet50(nn.Module):
436 |
437 | def __init__(self,genotype_fusion):
438 | super(Network_Resnet50, self).__init__()
439 | self.resnet = ResNet()
440 | self.feafusion = FeatureFusion(genotype_fusion)
441 | self.conv44 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1)
442 | self.conv55 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1)
443 | self.conv66 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1)
444 |
445 | def forward(self, input):
446 | h_, h_nopool2, h_nopool3, h_nopool4 = self.resnet(input)
447 | h_nopool2,h_nopool3,h_nopool4 = self.feafusion(h_, h_nopool2, h_nopool3, h_nopool4)
448 | h_nopool2 = F.interpolate(self.conv44(h_nopool2), size=[256, 256], mode='bilinear')
449 | h_nopool3 = F.interpolate(self.conv55(h_nopool3), size=[256, 256], mode='bilinear')
450 | h_nopool4 = F.interpolate(self.conv66(h_nopool4), size=[256, 256], mode='bilinear')
451 | return h_nopool2,h_nopool3,h_nopool4
452 |
--------------------------------------------------------------------------------
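A smoke test for the ResNet-50 variant with a dummy batch; 256×256 matches the size the decoder upsamples to via `F.interpolate(..., size=[256, 256])` in the forward pass:

```python
import torch
from genotypes import fusion_genotype_resnet50
from model_resnet50 import Network_Resnet50

net = Network_Resnet50(fusion_genotype_resnet50).eval()
with torch.no_grad():
    p2, p3, p4 = net(torch.randn(1, 3, 256, 256))
print(p2.shape, p3.shape, p4.shape)  # each torch.Size([1, 1, 256, 256])
```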
/model_vgg16.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from operations import *
5 | from torch.autograd import Variable
6 | from utils import drop_path
7 |
8 |
9 | class Bottleneck(nn.Module):
10 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1):
11 | super(Bottleneck, self).__init__()
12 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
13 | self.bn1 = nn.BatchNorm2d(planes)
14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=(3 * dilation - 1) // 2,
15 | bias=False,
16 | dilation=dilation)
17 | self.bn2 = nn.BatchNorm2d(planes)
18 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
19 | self.bn3 = nn.BatchNorm2d(planes * 4)
20 | self.downsample = downsample
21 |
22 | def forward(self, x):
23 | residual = x
24 | out = F.relu(self.bn1(self.conv1(x)), inplace=True)
25 | out = F.relu(self.bn2(self.conv2(out)), inplace=True)
26 | out = self.bn3(self.conv3(out))
27 | if self.downsample is not None:
28 | residual = self.downsample(x)
29 | return F.relu(out + residual, inplace=True)
30 |
31 |
32 | class vgg16(nn.Module):
33 | def __init__(self,):
34 | super(vgg16, self).__init__()
35 |
36 | # conv1
37 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1)
38 | self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
39 | self.relu1_1 = nn.ReLU(inplace=True)
40 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1)
41 | self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
42 | self.relu1_2 = nn.ReLU(inplace=True)
43 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers
44 |
45 | # conv2
46 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1)
47 | self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
48 | self.relu2_1 = nn.ReLU(inplace=True)
49 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1)
50 | self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
51 | self.relu2_2 = nn.ReLU(inplace=True)
52 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers
53 |
54 | # conv3
55 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1)
56 | self.bn3_1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
57 | self.relu3_1 = nn.ReLU(inplace=True)
58 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1)
59 | self.bn3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
60 | self.relu3_2 = nn.ReLU(inplace=True)
61 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1)
62 | self.bn3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)
63 | self.relu3_3 = nn.ReLU(inplace=True)
64 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 4 layers
65 |
66 | # conv4
67 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1)
68 | self.bn4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
69 | self.relu4_1 = nn.ReLU(inplace=True)
70 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1)
71 | self.bn4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
72 | self.relu4_2 = nn.ReLU(inplace=True)
73 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1)
74 | self.bn4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
75 | self.relu4_3 = nn.ReLU(inplace=True)
76 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 4 layers
77 |
78 | # conv5
79 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1)
80 | self.bn5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
81 | self.relu5_1 = nn.ReLU(inplace=True)
82 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1)
83 | self.bn5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
84 | self.relu5_2 = nn.ReLU(inplace=True)
85 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1)
86 | self.bn5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)
87 | self.relu5_3 = nn.ReLU(inplace=True) # 1/32 4 layers
88 |
89 | out_channel = 128
90 | self.conv5 = nn.Conv2d(512, out_channel, kernel_size=3, stride=1, padding=1)
91 | self.bn5 = nn.BatchNorm2d(out_channel)
92 | self.conv4 = nn.Conv2d(512, out_channel, kernel_size=3, stride=1, padding=1)
93 | self.bn4 = nn.BatchNorm2d(out_channel)
94 | self.conv3 = nn.Conv2d(256, out_channel, kernel_size=3, stride=1, padding=1)
95 | self.bn3 = nn.BatchNorm2d(out_channel)
96 | self.conv2 = nn.Conv2d(128, out_channel, kernel_size=3, stride=1, padding=1)
97 | self.bn2 = nn.BatchNorm2d(out_channel)
98 |
99 |
100 | def forward(self, x):
101 | h = x
102 |
103 | h = self.relu1_1(self.bn1_1(self.conv1_1(h)))
104 | h = self.relu1_2(self.bn1_2(self.conv1_2(h)))
105 | h_nopool1 = h
106 | h = self.pool1(h)
107 | # pool1 = h
108 |
109 | h = self.relu2_1(self.bn2_1(self.conv2_1(h)))
110 | h = self.relu2_2(self.bn2_2(self.conv2_2(h)))
111 | h_nopool2 = h
112 | h = self.pool2(h)
113 | # pool2 = h
114 |
115 | h = self.relu3_1(self.bn3_1(self.conv3_1(h)))
116 | h = self.relu3_2(self.bn3_2(self.conv3_2(h)))
117 | h = self.relu3_3(self.bn3_3(self.conv3_3(h)))
118 | h_nopool3 = h
119 | h = self.pool3(h)
120 | # pool3 = h
121 |
122 | h = self.relu4_1(self.bn4_1(self.conv4_1(h)))
123 | h = self.relu4_2(self.bn4_2(self.conv4_2(h)))
124 | h = self.relu4_3(self.bn4_3(self.conv4_3(h)))
125 | h_nopool4 = h
126 | h = self.pool4(h)
127 |
128 | h = self.relu5_1(self.bn5_1(self.conv5_1(h)))
129 | h = self.relu5_2(self.bn5_2(self.conv5_2(h)))
130 | h = self.relu5_3(self.bn5_3(self.conv5_3(h)))
131 |
132 | out2 = F.relu(self.bn2(self.conv2(h_nopool2)), inplace=True)
133 | out3 = F.relu(self.bn3(self.conv3(h_nopool3)), inplace=True)
134 | out4 = F.relu(self.bn4(self.conv4(h_nopool4)), inplace=True)
135 | out5_ = F.relu(self.bn5(self.conv5(h)), inplace=True)
136 |
137 | return out5_, out2, out3, out4
138 |
139 |
140 | class Featurefusioncell43(nn.Module):
141 | def __init__(self, standardShape, channel, op):
142 | super(Featurefusioncell43, self).__init__()
143 | self.standardShape = standardShape
144 | self._ops = op
145 | self.conv11 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
146 | self.bn11 = nn.BatchNorm2d(channel)
147 | # self.initialize()
148 |
149 | def initialize(self):
150 | weight_init(self)
151 |
152 | def forward(self, fea, lowfeature):
153 | f2 = fea[0]
154 | levelfusion = fea[2]
155 | f3 = fea[1]
156 | f4 = fea[3]
157 |
158 | if levelfusion.size()[2:] != self.standardShape:
159 | levelfusion = F.interpolate(levelfusion, self.standardShape, mode='bilinear')
160 | if lowfeature.size()[2:] != self.standardShape:
161 | lowfeature = F.interpolate(lowfeature, self.standardShape, mode='bilinear')
162 | if f2.size()[3] != self.standardShape and f2.size() != torch.Size([]):
163 | f2 = F.interpolate(f2, self.standardShape, mode='bilinear')
164 | if f3.size()[3] != self.standardShape and f3.size() != torch.Size([]):
165 | f3 = F.interpolate(f3, self.standardShape, mode='bilinear')
166 | if f4.size()[3] != self.standardShape:
167 | f4 = F.interpolate(f4, self.standardShape, mode='bilinear')
168 |
169 | z1 = f2
170 | z2 = f3
171 | z3 = levelfusion
172 | z4 = f4
173 | pre_note = [lowfeature]
174 | states = [z1, z2, z3, z4]
175 | offset = 0
176 | for i in range(4):
177 | if i == 0:
178 | s0 = states[i]
179 | s1 = self._ops[offset + i](pre_note[i])
180 | add = s0 + s1
181 | pre_note.append(add)
182 | else:
183 | p1 = states[i]
184 | s0 = self._ops[offset + i](pre_note[i])
185 | s1 = self._ops[offset + i + 1](states[i])
186 | add = s0 + s1 + p1
187 | pre_note.append(add)
188 | offset += 1
189 |
190 | out = 0
191 | for i in range(1, 5):
192 | out += pre_note[i]
193 | out = F.relu(self.bn11(self.conv11(out)), inplace=True)
194 |
195 | return out
196 |
197 |
198 | class Featurefusioncell32(nn.Module):
199 | def __init__(self, standardShape, channel, op):
200 | super(Featurefusioncell32, self).__init__()
201 | self.standardShape = standardShape
202 | self._ops = op
203 | self.conv11 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
204 | self.bn11 = nn.BatchNorm2d(channel)
205 |
206 | def initialize(self):
207 | weight_init(self)
208 |
209 | def forward(self, fea, lowfeature):
210 | f2 = fea[0]
211 | levelfusion = fea[3]
212 | f3 = fea[1]
213 | f4 = fea[2]
214 |
215 | if levelfusion.size()[2:] != self.standardShape:
216 | levelfusion = F.interpolate(levelfusion, self.standardShape, mode='bilinear')
217 | if lowfeature.size()[2:] != self.standardShape:
218 | lowfeature = F.interpolate(lowfeature, self.standardShape, mode='bilinear')
219 | if f2.size()[3] != self.standardShape and f2.size() != torch.Size([]):
220 | f2 = F.interpolate(f2, self.standardShape, mode='bilinear')
221 | if f3.size()[3] != self.standardShape and f3.size() != torch.Size([]):
222 | f3 = F.interpolate(f3, self.standardShape, mode='bilinear')
223 | if f4.size()[3] != self.standardShape:
224 | f4 = F.interpolate(f4, self.standardShape, mode='bilinear')
225 |
226 | z1 = f2
227 | z2 = f3
228 | z3 = f4
229 | z4 = levelfusion
230 |
231 | pre_note = [lowfeature]
232 | states = [z1, z2, z3, z4]
233 | offset = 0
234 | for i in range(4):
235 | if i == 0:
236 | s0 = states[i]
237 | s1 = self._ops[offset + i](pre_note[i])
238 | add = s0 + s1
239 | pre_note.append(add)
240 | else:
241 | p1 = states[i]
242 | s0 = self._ops[offset + i](pre_note[i])
243 | s1 = self._ops[offset + i + 1](states[i])
244 | add = s0 + s1 + p1
245 | pre_note.append(add)
246 | offset += 1
247 |
248 | out = 0
249 | for i in range(1, 5):
250 | out += pre_note[i]
251 | out = F.relu(self.bn11(self.conv11(out)), inplace=True)
252 | return out
253 |
254 |
255 | class Featurefusioncell54(nn.Module):
256 | def __init__(self, standardShape, channel, op):
257 | super(Featurefusioncell54, self).__init__()
258 | self.standardShape = standardShape
259 | self._ops = op
260 | self.conv11 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1)
261 | self.bn11 = nn.BatchNorm2d(channel)
262 |
263 | def initialize(self):
264 | weight_init(self)
265 |
266 | def forward(self, fea):
267 | lowfeature = fea[0]
268 | levelfusion = fea[1]
269 | f2 = fea[2]
270 | f3 = fea[3]
271 |
272 | if levelfusion.size()[2:] != self.standardShape:
273 | levelfusion = F.interpolate(levelfusion, self.standardShape, mode='bilinear')
274 | if lowfeature.size()[2:] != self.standardShape:
275 | lowfeature = F.interpolate(lowfeature, self.standardShape, mode='bilinear')
276 | if f2.size()[3] != self.standardShape and f2.size() != torch.Size([]):
277 | f2 = F.interpolate(f2, self.standardShape, mode='bilinear')
278 | if f3.size()[3] != self.standardShape and f3.size() != torch.Size([]):
279 | f3 = F.interpolate(f3, self.standardShape, mode='bilinear')
280 |
281 | z1 = lowfeature
282 | z2 = levelfusion
283 | z3 = f2
284 | z4 = f3
285 |
286 | pre_note = [z1]
287 | states = [z2, z3, z4]
288 | offset = 0
289 | for i in range(4 - 1):
290 | if i == 0:
291 | s0 = states[i]
292 | s1 = self._ops[offset + i](pre_note[i])
293 | add = s0 + s1
294 | pre_note.append(add)
295 | else:
296 | p1 = states[i]
297 | s0 = self._ops[offset + i](pre_note[i])
298 | s1 = self._ops[offset + i + 1](states[i])
299 | add = s0 + s1 + p1
300 | pre_note.append(add)
301 | offset += 1
302 |
303 | out = 0
304 | for i in range(1, 4):
305 | out += pre_note[i]
306 | out = F.relu(self.bn11(self.conv11(out)), inplace=True)
307 |
308 | return out
309 |
310 |
311 | class FeatureFusion(nn.Module):
312 |
313 | def __init__(self, genotype_fusion, node=3):
314 | super(FeatureFusion, self).__init__()
315 |
316 | self._ops = nn.ModuleList()
317 | self.fnum = 4
318 | self.node = node
319 | # self.none_num = []
320 | C = 128
321 |
322 | genotype_outside = genotype_fusion.normal
323 | genotype_inside = genotype_fusion.inside
324 | new_genotype_outside = sorted(genotype_outside, key=lambda x: (x[2], x[1]))
325 | op_name, op_num, _ = zip(*new_genotype_outside)
326 |
327 | self.op_num = op_num
328 | offset = 0
329 | for i in range(self.node):
330 | for j in range(self.fnum):
331 | op = OPS[op_name[j + offset]](C, C, 1,False, True)
332 | self._ops += [op]
333 | offset += 4
334 |
335 | op_name_inside, op_num_inside = zip(*genotype_inside)
336 |
337 | k = [5, 7, 7]
338 | noteOper = []
339 | offset = 0
340 | for i in range(self.node):
341 | self._nodes = nn.ModuleList()
342 | for j in range(k[i]):
343 | op = OPS[op_name_inside[j + offset]](C, C, 1,False, True)
344 | self._nodes += [op]
345 | noteOper.append(self._nodes)
346 | offset += k[i]
347 |
348 | self.featurefusioncell54 = Featurefusioncell54(16, C, noteOper[0])
349 | self.featurefusioncell43 = Featurefusioncell43(32, C, noteOper[1])
350 | self.featurefusioncell32 = Featurefusioncell32(64, C, noteOper[2])
351 |
352 | def forward(self, out5, out2, out3, out4):
353 |
354 | states = [out5, out4, out3, out2]
355 |
356 | fea = []
357 | feaoutput = []
358 | offset = 0
359 | s = 0
360 |
361 | for i in range(self.node):
362 | for j, v in enumerate(self.op_num):
363 | if j == 4:
364 | break
365 | inputFea = states[v]
366 | x2 = self._ops[offset + j](inputFea)
367 | fea.append(x2)
368 |
369 | if i == 0:
370 | new_fea = self.featurefusioncell54(fea)
371 | feaoutput.append(new_fea)
372 | fea.clear()
373 |
374 | elif i == 1:
375 | new_fea = self.featurefusioncell43(fea, feaoutput[0])
376 | feaoutput.append(new_fea)
377 | fea.clear()
378 |
379 | elif i == 2:
380 | new_fea = self.featurefusioncell32(fea, feaoutput[1])
381 | feaoutput.append(new_fea)
382 | fea.clear()
383 |
384 | offset += 4
385 |
386 | return feaoutput[2],feaoutput[1],feaoutput[0]
387 |
388 | def _loss(self, input, target):
389 | logits = self(input)
390 | logits = logits.squeeze(1)
391 |
392 | return self._criterion(logits, target)
393 |
394 |
395 | class Network_vgg16(nn.Module):
396 |
397 | def __init__(self, genotype_fusion):
398 | super(Network_vgg16, self).__init__()
399 | # self.vgg16 = Vgg16_RGB()
400 | self.vgg16 = vgg16()
401 | self.feafusion = FeatureFusion(genotype_fusion)
402 | self.conv44 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1)
403 | self.conv55 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1)
404 | self.conv66 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1)
405 |
406 | def forward(self, input):
407 | h_, h_nopool2, h_nopool3, h_nopool4 = self.vgg16(input)
408 | h_nopool2,h_nopool3,h_nopool4 = self.feafusion(h_, h_nopool2, h_nopool3, h_nopool4)
409 | h_nopool2 = F.interpolate(self.conv44(h_nopool2), size=[256, 256], mode='bilinear')
410 | h_nopool3 = F.interpolate(self.conv55(h_nopool3), size=[256, 256], mode='bilinear')
411 | h_nopool4 = F.interpolate(self.conv66(h_nopool4), size=[256, 256], mode='bilinear')
412 | return h_nopool2,h_nopool3,h_nopool4
413 |
--------------------------------------------------------------------------------
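The VGG-16 variant exposes the same interface, so the same smoke test applies:

```python
import torch
from genotypes import fusion_genotype_vgg16
from model_vgg16 import Network_vgg16

net = Network_vgg16(fusion_genotype_vgg16).eval()
with torch.no_grad():
    p2, p3, p4 = net(torch.randn(1, 3, 256, 256))
print(p2.shape)  # torch.Size([1, 1, 256, 256])
```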
/operations.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | OPS = {
6 | 'none': lambda in_C, out_C, stride, upsample, affine: Zero(stride, upsample=upsample),
7 | 'skip_connect': lambda in_C, out_C, stride, upsample, affine: Identity(upsample=upsample),
8 | 'sep_conv_3x3': lambda in_C, out_C, stride, upsample, affine: SepConv(in_C, out_C, 3, stride, 1, affine=affine,
9 | upsample=upsample),
10 | 'sep_conv_3x3_rp2': lambda in_C, out_C, stride, upsample, affine: SepConvDouble(in_C, out_C, 3, stride, 1,
11 | affine=affine, upsample=upsample),
12 | 'dil_conv_3x3': lambda in_C, out_C, stride, upsample, affine: DilConv(in_C, out_C, 3, stride, 2, 2, affine=affine,
13 | upsample=upsample),
14 | 'dil_conv_3x3_rp2': lambda in_C, out_C, stride, upsample, affine: DilConvDouble(in_C, out_C, 3, stride, 2, 2,
15 | affine=affine, upsample=upsample),
16 | 'dil_conv_3x3_dil4': lambda in_C, out_C, stride, upsample, affine: DilConv(in_C, out_C, 3, stride, 4, 4,
17 | affine=affine, upsample=upsample),
18 |
19 | 'conv_3x3': lambda in_C, out_C, stride, upsample, affine: Conv(in_C, out_C, 3, stride, 1, affine=affine,
20 | upsample=upsample),
21 | 'conv_3x3_rp2': lambda in_C, out_C, stride, upsample, affine: ConvDouble(in_C, out_C, 3, stride, 1, affine=affine,
22 | upsample=upsample),
23 |
24 | 'SpatialAttention': lambda in_C, out_C, stride, upsample, affine: SpatialAttention(in_C,7),
25 | 'ChannelAttention': lambda in_C, out_C, stride, upsample, affine: ChannelAttention(in_C,16),
26 |
27 | }
28 |
29 |
30 | def conv3x3(in_planes, out_planes, stride):
31 | "3x3 convolution with padding"
32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
33 | padding=1, bias=False)
34 |
35 |
36 | class ChannelAttention(nn.Module):
37 | def __init__(self, in_channels, ratio):
38 | super(ChannelAttention, self).__init__()
39 |
40 | self.in_channels = in_channels
41 |
42 | self.linear_1 = nn.Linear(self.in_channels, self.in_channels // ratio)
43 | self.linear_2 = nn.Linear(self.in_channels // ratio, self.in_channels)
44 | self.conv1 = nn.Conv2d(self.in_channels, self.in_channels, kernel_size=1, bias=False)
45 | self.bn1 = nn.BatchNorm2d(self.in_channels)
46 | def forward(self, input_):
47 | n_b, n_c, h, w = input_.size()
48 |
49 | feats = F.adaptive_avg_pool2d(input_, (1, 1)).view((n_b, n_c))
50 | feats = F.relu(self.linear_1(feats))
51 | feats = torch.sigmoid(self.linear_2(feats))
52 |
53 | feats = feats.view((n_b, n_c, 1, 1))
54 | feats = feats.expand_as(input_).clone()
55 | out = torch.mul(input_, feats)
56 | out = F.relu(self.bn1(self.conv1(out)), inplace=True)
57 |
58 | return out
59 |
60 |
61 | class SpatialAttention(nn.Module):
62 | def __init__(self,in_C, kernel_size=7):
63 | super(SpatialAttention, self).__init__()
64 |
65 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
66 | padding = 3 if kernel_size == 7 else 1
67 | self.in_channels = in_C
68 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
69 | self.sigmoid = nn.Sigmoid()
70 | self.conv11 = nn.Conv2d(self.in_channels, self.in_channels, kernel_size=1, bias=False)
71 | self.bn11 = nn.BatchNorm2d(self.in_channels)
72 | def forward(self, x):
73 | input = x
74 | avg_out = torch.mean(x, dim=1, keepdim=True)
75 | max_out, _ = torch.max(x, dim=1, keepdim=True)
76 | x = torch.cat([avg_out, max_out], dim=1)
77 | x = self.conv1(x)
78 | x = self.sigmoid(x)
79 | out = input * x
80 |
81 | out = F.relu(self.bn11(self.conv11(out)), inplace=True)
82 |
83 | return out
84 |
85 |
86 | class Conv(nn.Module):
87 |
88 | def __init__(self, C_in, C_out, kernel_size, stride, padding, upsample, affine=True):
89 | super(Conv, self).__init__()
90 | self.upsample = upsample
91 | self.up = nn.Sequential(
92 | torch.nn.ReLU(inplace=False),
93 | torch.nn.Upsample(scale_factor=2, mode='bilinear')
94 | )
95 | self.op = nn.Sequential(
96 | nn.ReLU(inplace=False),
97 | nn.Conv2d(C_in, C_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
98 | nn.BatchNorm2d(C_out, affine=affine),
99 | )
100 |
101 | def forward(self, x):
102 | if self.upsample is True:
103 | x = self.up(x)
104 | return self.op(x)
105 |
106 |
107 | class ConvDouble(nn.Module):
108 |
109 | def __init__(self, C_in, C_out, kernel_size, stride, padding, upsample, affine=True):
110 | super(ConvDouble, self).__init__()
111 |
112 | self.upsample = upsample
113 | self.up = nn.Sequential(
114 | torch.nn.ReLU(inplace=False),
115 | torch.nn.Upsample(scale_factor=2, mode='bilinear')
116 | )
117 |
118 | self.op = nn.Sequential(
119 | nn.ReLU(inplace=False),
120 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
121 | nn.BatchNorm2d(C_in, affine=affine),
122 | nn.ReLU(inplace=False),
123 | nn.Conv2d(C_in, C_out, kernel_size=kernel_size, stride=1, padding=padding, bias=False),
124 | nn.BatchNorm2d(C_out, affine=affine),
125 | )
126 |
127 | def forward(self, x):
128 | if self.upsample is True:
129 | x = self.up(x)
130 | return self.op(x)
131 |
132 |
133 | class ReLUConvBN(nn.Module):
134 |
135 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True):
136 | super(ReLUConvBN, self).__init__()
137 | self.op = nn.Sequential(
138 | nn.ReLU(inplace=False),
139 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False),
140 | nn.BatchNorm2d(C_out, affine=affine)
141 | )
142 |
143 | def forward(self, x):
144 | return self.op(x)
145 |
146 |
147 | class DilConv(nn.Module):
148 |
149 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, upsample, affine=True):
150 | super(DilConv, self).__init__()
151 |
152 | self.upsample = upsample
153 | self.up = nn.Sequential(
154 | torch.nn.ReLU(inplace=False),
155 | torch.nn.Upsample(scale_factor=2, mode='bilinear')
156 | )
157 |
158 | self.op = nn.Sequential(
159 | nn.ReLU(inplace=False),
160 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
161 | bias=False),
162 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
163 | nn.BatchNorm2d(C_out, affine=affine),
164 | )
165 |
166 | def forward(self, x):
167 |         if self.upsample:
168 | x = self.up(x)
169 | return self.op(x)
170 |
171 |
172 | class DilConvDouble(nn.Module):
173 |
174 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, upsample, affine=True):
175 | super(DilConvDouble, self).__init__()
176 | self.upsample = upsample
177 | self.up = nn.Sequential(
178 | torch.nn.ReLU(inplace=False),
179 | torch.nn.Upsample(scale_factor=2, mode='bilinear')
180 | )
181 | self.op = nn.Sequential(
182 | nn.ReLU(inplace=False),
183 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation,
184 | bias=False),
185 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
186 | nn.BatchNorm2d(C_in, affine=affine),
187 | nn.ReLU(inplace=False),
188 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation,
189 | bias=False),
190 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
191 | nn.BatchNorm2d(C_out, affine=affine),
192 | )
193 |
194 | def forward(self, x):
195 |         if self.upsample:
196 | x = self.up(x)
197 | return self.op(x)
198 |
199 |
200 | class SepConv(nn.Module):
201 |
202 | def __init__(self, C_in, C_out, kernel_size, stride, padding, upsample, affine=True):
203 | super(SepConv, self).__init__()
204 |
205 | self.upsample = upsample
206 | self.up = nn.Sequential(
207 | torch.nn.ReLU(inplace=False),
208 | torch.nn.Upsample(scale_factor=2, mode='bilinear')
209 | )
210 |
211 | self.op = nn.Sequential(
212 | nn.ReLU(inplace=False),
213 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
214 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
215 | nn.BatchNorm2d(C_out, affine=affine),
216 | )
217 |
218 | def forward(self, x):
219 |         if self.upsample:
220 | x = self.up(x)
221 | return self.op(x)
222 |
223 |
224 | class FactorizedReduce(nn.Module):
225 |
226 | def __init__(self, C_in, C_out, affine=True):
227 | super(FactorizedReduce, self).__init__()
228 | assert C_out % 2 == 0
229 |
230 | self.up = nn.Sequential(
231 | nn.ReLU(),
232 | nn.Upsample(scale_factor=2, mode='bilinear')
233 | )
234 |
235 | self.relu = nn.ReLU(inplace=False)
236 | self.conv_1 = nn.Conv2d(C_in, C_out, 1, stride=1, padding=0, bias=False)
237 | self.bn = nn.BatchNorm2d(C_out, affine=affine)
238 |
239 |     def forward(self, x):
240 |         # despite the name, this op upsamples 2x, then applies a 1x1 conv + BN
241 |         x = self.up(x)
242 |         x = self.relu(x)  # redundant after the ReLU inside self.up, but harmless
243 |         out = self.bn(self.conv_1(x))
244 |         return out
245 |
246 |
247 | class SepConvDouble(nn.Module):
248 |
249 | def __init__(self, C_in, C_out, kernel_size, stride, padding, upsample, affine=True):
250 | super(SepConvDouble, self).__init__()
251 |
252 | self.upsample = upsample
253 | self.up = nn.Sequential(
254 | torch.nn.ReLU(inplace=False),
255 | torch.nn.Upsample(scale_factor=2, mode='bilinear')
256 | )
257 |
258 | self.op = nn.Sequential(
259 | nn.ReLU(inplace=False),
260 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
261 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
262 | nn.BatchNorm2d(C_in, affine=affine),
263 | nn.ReLU(inplace=False),
264 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, bias=False),
265 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
266 | nn.BatchNorm2d(C_out, affine=affine),
267 | )
268 |
269 | def forward(self, x):
270 |         if self.upsample:
271 | x = self.up(x)
272 | return self.op(x)
273 |
274 |
275 |
276 | class Identity(nn.Module):
277 |
278 | def __init__(self, upsample):
279 | super(Identity, self).__init__()
280 | self.upsample = upsample
281 | self.up = nn.Sequential(
282 | torch.nn.ReLU(inplace=False),
283 | torch.nn.Upsample(scale_factor=2, mode='bilinear')
284 | )
285 |
286 |     def forward(self, x):
287 |         if self.upsample:
288 |             x = self.up(x)
289 |         return x
290 |
291 |
292 | class Zero(nn.Module):
293 |
294 | def __init__(self, stride, upsample):
295 | super(Zero, self).__init__()
296 | self.stride = stride
297 | self.upsample = upsample
298 | self.up = nn.Sequential(
299 | torch.nn.ReLU(inplace=False),
300 | torch.nn.Upsample(scale_factor=2, mode='bilinear')
301 | )
302 |
303 |     def forward(self, x):
304 |         # a Zero op must output zeros on both the plain and the upsampled
305 |         # path; upsampling only fixes the spatial size of the zero tensor
306 |         if self.upsample:
307 |             x = self.up(x)
308 |         return x.mul(0.)
309 |         # return x[:, :, ::self.stride, ::self.stride].mul(0.)
310 |
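311 |
312 | # Usage sketch (illustrative, assuming a 64-channel feature map): the two
313 | # attention ops preserve tensor shape, while the searched conv ops double
314 | # the spatial size when built with upsample=True.
315 | #
316 | #   x = torch.randn(2, 64, 32, 32)
317 | #   ChannelAttention(64, ratio=4)(x).shape           # (2, 64, 32, 32)
318 | #   SpatialAttention(64)(x).shape                    # (2, 64, 32, 32)
319 | #   Conv(64, 64, 3, 1, 1, upsample=True)(x).shape    # (2, 64, 64, 64)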
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import sys
3 | import glob
4 | import numpy as np
5 | import torch
6 | import utils as utils
7 | import logging
8 | import argparse
9 | import torch.nn as nn
10 | import genotypes
11 | import torch.utils
12 | import torchvision.datasets as dset
13 | import torch.backends.cudnn as cudnn
14 | import data
15 | import time
16 | import torch.nn.functional as F
17 |
18 | from matplotlib.pyplot import imsave
19 | from torch.autograd import Variable
20 | from model_resnet50 import Network_Resnet50 as Network_Resnet50
21 | from model_vgg16 import Network_vgg16 as Network_vgg16
22 |
23 | from PIL import Image
24 | from torchvision.transforms import transforms
25 | from metric import *
26 | from skimage import img_as_ubyte
27 |
28 | import cv2
29 |
30 | os.environ['CUDA_VISIBLE_DEVICES']='0,1'
31 | parser = argparse.ArgumentParser("test_model")
32 | parser.add_argument('--batch_size', type=int, default=1, help='batch size')
33 | parser.add_argument('--test_size', type=int, default=256, help='test image size')
34 | parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
35 | parser.add_argument('--init_channels', type=int, default=128, help='num of init channels')
36 | parser.add_argument('--model_path', type=str, default='./checkpoint/Auto_MSFNet_resnet50.pt',
37 | help='path of pretrained checkpoint')
38 | parser.add_argument('--backbone', type=str, default='resnet50', help='backbone network: vgg16 or resnet50')
39 | parser.add_argument('--fu_arch', type=str, default='fusion_genotype_resnet50', help='which architecture to use')
40 | parser.add_argument('--note', type=str, default='fusion_genotype_resnet50', help='note appended to the experiment folder name')
41 |
42 | args = parser.parse_args()
43 | args.save = '{}-{}'.format(args.note, time.strftime("%Y%m%d-%H%M%S"))
44 | utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))
45 |
46 | log_format = '%(asctime)s %(message)s'
47 | logging.basicConfig(stream=sys.stdout, level=logging.INFO,
48 | format=log_format, datefmt='%m/%d %I:%M:%S %p')
49 | fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
50 | fh.setFormatter(logging.Formatter(log_format))
51 | logging.getLogger().addHandler(fh)
52 |
53 | datasets = ['HKU-IS-WI1D', 'DUTS', 'DUT-OMRON', 'ECSSD', 'PASCAL-S']
54 |
55 |
56 | def main():
57 | if not torch.cuda.is_available():
58 | logging.info('no gpu device available')
59 | sys.exit(1)
60 |
61 | logging.info('gpu device = %d' % args.gpu)
62 | logging.info("args = %s", args)
63 | torch.cuda.set_device(args.gpu)
64 |     genotype_fu = getattr(genotypes, args.fu_arch)  # look up the searched genotype by name
65 |     if args.backbone == "vgg16":
66 |         model = Network_vgg16(genotype_fu)
67 |     elif args.backbone == "resnet50":
68 |         model = Network_Resnet50(genotype_fu)
69 |     else: raise ValueError('unsupported backbone: %s' % args.backbone)
70 | model = model.cuda()
71 | utils.load(model, args.model_path)
72 |
73 | logging.info("param size = %fMB", utils.count_parameters_in_MB(model))
74 |
75 |     for dataset_name in datasets:
76 | test_image_root = '/home/oip/testData/' + dataset_name + '/test_images/'
77 | test_gt_root = '/home/oip/testData/' + dataset_name + '/test_masks/'
78 |
79 | test_data = data.MyTestData(test_image_root, test_gt_root, args.test_size)
80 | test_queue = torch.utils.data.DataLoader(
81 | test_data,
82 | batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True)
83 | num_test = len(test_data)
84 | Fmax_measure, Fm_measure, mae, S_measure = infer(test_queue, model, dataset_name, num_test)
85 | logging.info('dataset_name {}'.format(dataset_name))
86 |         logging.info('Fmax-measure %f', Fmax_measure)
87 |         logging.info('Fm-measure %f', Fm_measure)
88 | logging.info('mae %f', mae)
89 | logging.info('S-measure %f', S_measure)
90 |
91 |
92 | def infer(test_queue, model, dataset_name, num_test):
93 | model.eval()
94 | savepath = './prediction/' + dataset_name
95 |     cal_fm = CalFM(num=num_test)  # metric accumulators, one slot per test image
96 | cal_mae = CalMAE(num=num_test)
97 | cal_sm = CalSM(num=num_test)
98 |     for step, (image, target, name) in enumerate(test_queue):
99 |         image = image.cuda()
100 |         target = torch.squeeze(target)
101 |         with torch.no_grad():
102 |             h_nopool2, _, _ = model(image)
103 | test_output_root = os.path.join(args.save, savepath)
104 | if not os.path.exists(test_output_root):
105 | os.makedirs(test_output_root)
106 |         H, W = target.shape
107 |
108 |         h_nopool2 = F.interpolate(h_nopool2, (H, W), mode='bilinear')
109 |         output_rgb = torch.squeeze(h_nopool2)
110 |         predict_rgb = output_rgb.sigmoid().detach().cpu().numpy()
111 | predict_rgb = img_as_ubyte(predict_rgb)
112 | cv2.imwrite(test_output_root + '/' + name[0] + '.png', predict_rgb)
113 | target = target.cpu().detach().numpy()
114 | max_pred_array = predict_rgb.max()
115 | min_pred_array = predict_rgb.min()
116 |
117 | if max_pred_array == min_pred_array:
118 | predict_rgb = predict_rgb / 255
119 | else:
120 | predict_rgb = (predict_rgb - min_pred_array) / (max_pred_array - min_pred_array)
121 |
122 | max_target = target.max()
123 | min_target = target.min()
124 | if max_target == min_target:
125 | target = target / 255
126 | else:
127 | target = (target - min_target) / (max_target - min_target)
128 |
129 | cal_fm.update(predict_rgb, target)
130 | cal_mae.update(predict_rgb, target)
131 | cal_sm.update(predict_rgb, target)
132 |
133 |
134 | if step % 50 == 0 or step == len(test_queue) - 1:
135 | logging.info(
136 | "TestDataSet:{} Step {:03d}/{:03d} ".format(
137 | dataset_name, step, len(test_queue) - 1))
138 | _, maxf, mmf, _, _ = cal_fm.show()
139 | mae = cal_mae.show()
140 | sm = cal_sm.show()
141 |     return maxf, mmf, mae, sm
142 |
143 |
144 | if __name__ == '__main__':
145 | main()
146 |
147 |
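148 | # Note: infer() assumes each dataset directory contains test_images/ and
149 | # test_masks/ subfolders, and that batch_size stays 1 so torch.squeeze(target)
150 | # yields a 2-D (H, W) mask; adjust the hard-coded roots in main() as needed.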
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import torch
4 | import shutil
5 | import torchvision.transforms as transforms
6 |
7 | import torch.nn as nn
8 |
9 | class AverageMeter(object):
10 |
11 | def __init__(self):
12 | self.reset()
13 |
14 | def reset(self):
15 | self.avg = 0
16 | self.sum = 0
17 | self.cnt = 0
18 |
19 | def update(self, val, n=1):
20 | self.sum += val * n
21 | self.cnt += n
22 | self.avg = self.sum / self.cnt
23 |
24 |
25 | def count_parameters_in_MB(model):
26 |     return sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
27 |
28 |
29 | def save_checkpoint(state, is_best, save):
30 | filename = os.path.join(save, 'checkpoint.pth.tar')
31 | torch.save(state, filename)
32 | if is_best:
33 | best_filename = os.path.join(save, 'model_best.pth.tar')
34 | shutil.copyfile(filename, best_filename)
35 |
36 |
37 | def save(model, model_path):
38 | torch.save(model.state_dict(), model_path)
39 |
40 |
41 | def load(model, model_path):
42 |     state_dict = torch.load(model_path)
43 |     model.load_state_dict(state_dict)
44 |
45 |
46 | def drop_path(x, drop_prob):
47 |     # zero whole samples and rescale survivors by 1/keep_prob (E[x] unchanged)
48 |     if drop_prob > 0.:
49 |         keep_prob = 1. - drop_prob
50 |         mask = torch.empty(x.size(0), 1, 1, 1, device=x.device).bernoulli_(keep_prob)
51 |         x.div_(keep_prob).mul_(mask)
52 |     return x
53 |
54 |
55 | def create_exp_dir(path, scripts_to_save=None):
56 | if not os.path.exists(path):
57 | os.mkdir(path)
58 | print('Experiment dir : {}'.format(path))
59 |
60 | if scripts_to_save is not None:
61 | os.mkdir(os.path.join(path, 'scripts'))
62 | for script in scripts_to_save:
63 | dst_file = os.path.join(path, 'scripts', os.path.basename(script))
64 | shutil.copyfile(script, dst_file)
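65 |
66 | # Usage sketch (illustrative): drop_path leaves the expectation of x
67 | # unchanged, because surviving samples are rescaled by 1/keep_prob.
68 | #
69 | #   x = torch.ones(1000, 1, 1, 1)
70 | #   drop_path(x, drop_prob=0.3).mean()   # ~1.0 on average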
--------------------------------------------------------------------------------