├── .idea
│   ├── .gitignore
│   ├── Auto-fusion.iml
│   ├── inspectionProfiles
│   │   └── profiles_settings.xml
│   ├── misc.xml
│   ├── modules.xml
│   └── vcs.xml
├── README.md
├── data.py
├── genotypes.py
├── images
│   ├── 1.png
│   ├── 2.png
│   ├── 3.png
│   └── 4.png
├── metric.py
├── model_resnet50.py
├── model_vgg16.py
├── operations.py
├── test.py
└── utils.py

/.idea/.gitignore:
--------------------------------------------------------------------------------
1 | # Default ignored files
2 | /workspace.xml
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Auto-MSFNet
2 | 
3 | This is a PyTorch implementation of the ACM MM 2021 paper "Auto-MSFNet: Search Multi-scale Fusion Network for Salient Object Detection". The paper can be downloaded from [this link (access code: a28y)](https://pan.baidu.com/s/14mLCSHXtnuXkjCmuu_9g8A).
4 | 
5 | ## Introduction
6 | 
7 | ![images/1.png](images/1.png)
8 | 
9 | Multi-scale feature fusion plays a critical role in salient object detection. Most existing methods achieve remarkable performance by exploiting various multi-scale feature fusion strategies. However, an elegant fusion framework requires expert knowledge and experience, and relies heavily on laborious trial and error. In this paper, we propose a multi-scale feature fusion framework based on Neural Architecture Search (NAS), named Auto-MSFNet. First, we design a novel search cell, named FusionCell, to automatically decide how multi-scale features are aggregated. Rather than searching for one repeatable cell to stack, we allow different FusionCells to flexibly integrate multi-level features. Simultaneously, considering that features generated by CNNs are naturally spatial and channel-wise, we propose a new search space for efficiently focusing on the most relevant information. The search space mitigates the incomplete object structures and over-predicted foreground regions caused by progressive fusion. Second, we propose a progressive polishing loss to further obtain exquisite boundaries by penalizing misalignment of salient object boundaries. Extensive experiments on five benchmark datasets demonstrate the effectiveness of the proposed method, which achieves state-of-the-art performance on four evaluation metrics.
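For a quick sense of how the searched architecture is consumed, the sketch below (an editorial addition, not part of the original README) builds the released ResNet-50 model from the searched genotype in genotypes.py and runs a dummy forward pass. It mirrors the way test.py instantiates the network, uses the 256×256 test resolution from this repo, and assumes the repo's utils.py provides the helpers the model files import.

```python
import torch
import genotypes
from model_resnet50 import Network_Resnet50

# Build the network from the searched FusionCell genotype (as test.py does).
model = Network_Resnet50(genotypes.fusion_genotype_resnet50).eval()

x = torch.randn(1, 3, 256, 256)  # test images are resized to 256x256
with torch.no_grad():
    s2, s3, s4 = model(x)        # three side-output saliency logit maps
print(s2.shape)                  # torch.Size([1, 1, 256, 256])
```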
10 | 
11 | ## The searched FusionCell structure
12 | 
13 | ![images/2.png](images/2.png)
14 | 
15 | ## Prerequisites
16 | 
17 | - Python 3.6
18 | - PyTorch 1.6.0
19 | 
20 | ## Usage
21 | 
22 | ### 1. Download the datasets
23 | 
24 | - [PASCAL-S](http://cbi.gatech.edu/salobj/)
25 | - [ECSSD](http://www.cse.cuhk.edu.hk/leojia/projects/hsaliency/dataset.html)
26 | - [HKU-IS](https://i.cs.hku.hk/~gbli/deep_saliency.html)
27 | - [DUT-OMRON](http://saliencydetection.net/dut-omron/)
28 | - [DUTS](http://saliencydetection.net/duts/)
29 | 
30 | ### 2. Saliency maps & trained models
31 | 
32 | - Saliency maps: ResNet-50 ([Google](https://drive.google.com/file/d/1sX5NBhiFBj5SMgGvBYhPCTUsHi8XxwzA/view?usp=sharing) | [Baidu, access code: 3d22](https://pan.baidu.com/s/1eV8t5pDYnahIIV1gzhgEjg)); VGG-16 ([Google](https://drive.google.com/file/d/1N8VqS0fGzmb81f4nG66ot7sNMsIKCUkh/view?usp=sharing) | [Baidu, access code: wv61](https://pan.baidu.com/s/1ErQz8m4GH3Q4D6aDoaW14A))
33 | - Trained models: ResNet-50 ([Google](https://drive.google.com/file/d/1TkJOvCNBuOjydzW-ceJBfkyCutFbYbrc/view?usp=sharing) | [Baidu, access code: yfh8](https://pan.baidu.com/s/12S43JG4bce4cgN47D5rUnw)); VGG-16 ([Google](https://drive.google.com/file/d/1bZkU1nid_sQ8_eydRfCZOD5OCj-Vwiqk/view?usp=sharing) | [Baidu, access code: qhqs](https://pan.baidu.com/s/1pONp-yFTdLkb0KrbjvWIcQ))
34 | - Our quantitative comparisons
35 | 
36 | ![images/3.png](images/3.png)
37 | 
38 | - Our qualitative comparisons
39 | 
40 | ![images/4.png](images/4.png)
41 | 
42 | ### 3. Testing and evaluation
43 | 
44 | We use [this Python toolkit](https://github.com/lartpang/PySODEvalToolkit) to evaluate the saliency maps.
45 | 
46 | First, download the code (e.g., open it in PyCharm) and a trained checkpoint (ResNet-50 or VGG-16 based).
47 | 
48 | Second, change the relevant paths in test.py (*e.g.*, the dataset path), then run
49 | 
50 | ```shell
51 | python test.py
52 | ```
53 | 
54 | ### 4. If you find this work helpful, please cite
55 | ```bibtex
56 | @InProceedings{Miao_2021_ACM_MM,
57 |   author    = {Miao {Zhang} and Tingwei {Liu} and Yongri {Piao} and ShunYu {Yao} and Huchuan {Lu}},
58 |   title     = {Auto-MSFNet: Search Multi-scale Fusion Network for Salient Object Detection},
59 |   booktitle = {ACM Multimedia Conference 2021},
60 |   year      = {2021}
61 | }
62 | ```
63 | ### 5. For any questions, please contact tingwei@mail.dlut.edu.cn
64 | 
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os 2 | import cv2 3 | from PIL import Image 4 | import torch 5 | import torch.utils.data as data 6 | import torchvision.transforms as transforms 7 | import numpy as np 8 | from torch.utils.data import Dataset 9 | 10 | class MyTestData(Dataset): 11 | """ 12 | load images for testing 13 | root: directory/to/images/ 14 | structure: 15 | - root 16 | - images 17 | - images (images here) 18 | - masks (ground truth) 19 | """ 20 | 21 | mean = np.array([0.485, 0.456, 0.406]) 22 | std = np.array([0.229, 0.224, 0.225]) 23 | 24 | def __init__(self,img_root,gt_root,test_size,transform=True): 25 | super(MyTestData, self).__init__() 26 | self._transform = transform 27 | self.test_size = test_size 28 | img_root = img_root 29 | gt_root = gt_root 30 | 31 | file_names = os.listdir(img_root) 32 | self.img_names = [] 33 | self.gt_names = [] 34 | self.names = [] 35 | for i, name in enumerate(file_names): 36 | if not name.endswith('.jpg'): 37 | continue 38 | self.img_names.append( 39 | os.path.join(img_root, name[:-4] + '.jpg') 40 | ) 41 | self.gt_names.append( 42 |
os.path.join(gt_root,name[:-4] + '.png') 43 | ) 44 | self.names.append(name[:-4]) 45 | 46 | def __len__(self): 47 | return len(self.img_names) 48 | 49 | def __getitem__(self, index): 50 | gt_file = self.gt_names[index] 51 | gt = Image.open(gt_file).convert('L') 52 | gt = np.array(gt, dtype=np.int32) 53 | gt = gt / (gt.max() + 1e-8) 54 | gt = np.where(gt > 0.5, 1, 0) 55 | img_file = self.img_names[index] 56 | img = cv2.imread(img_file)[:,:,::-1].astype(np.float32) 57 | img = cv2.resize(img, dsize=(self.test_size, self.test_size), interpolation=cv2.INTER_LINEAR) 58 | name = img_file.split('/')[-1].split('.')[0] 59 | 60 | if self._transform: 61 | try: 62 | img, gt = self.transform(img,gt) 63 | except ValueError: 64 | print(name) 65 | return img, gt,name 66 | else: 67 | return img, gt,name 68 | 69 | def transform(self, img,gt): 70 | img = img.astype(np.float64) / 255 71 | img -= self.mean 72 | img /= self.std 73 | img = img.transpose(2, 0, 1) 74 | img = torch.from_numpy(img).float() 75 | return img,gt 76 | -------------------------------------------------------------------------------- /genotypes.py: -------------------------------------------------------------------------------- 1 | from collections import namedtuple 2 | 3 | Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat') 4 | FeafusionGenotype = namedtuple('FeafusionGenotype', ['normal','inside']) 5 | 6 | fusion_genotype_resnet50= FeafusionGenotype(normal=[('SpatialAttention', 3, 0), ('dil_conv_3x3_dil4', 0, 0), ('dil_conv_3x3_dil4', 1, 0), ('dil_conv_3x3_dil4', 2, 0), ('dil_conv_3x3', 0, 1), ('sep_conv_3x3', 1, 1), ('sep_conv_3x3', 2, 1), ('SpatialAttention', 3, 1), ('ChannelAttention', 1, 2), ('dil_conv_3x3', 2, 2), ('sep_conv_3x3_rp2', 0, 2), ('SpatialAttention', 3, 2)], inside=[('SpatialAttention', 0), ('ChannelAttention', 1), ('sep_conv_3x3', 2), ('dil_conv_3x3_rp2', 3), ('sep_conv_3x3', 4), ('sep_conv_3x3_rp2', 0), ('SpatialAttention', 1), ('ChannelAttention', 2), ('dil_conv_3x3', 3), ('sep_conv_3x3_rp2', 4), ('dil_conv_3x3', 5), ('skip_connect', 6), ('dil_conv_3x3', 0), ('dil_conv_3x3_dil4', 1), ('dil_conv_3x3', 2), ('ChannelAttention', 3), ('skip_connect', 4), ('dil_conv_3x3_rp2', 5), ('ChannelAttention', 6)]) 7 | fusion_genotype_vgg16 = FeafusionGenotype(normal=[('dil_conv_3x3', 0, 0), ('dil_conv_3x3', 2, 0), ('none', 3, 0), ('none', 1, 0), ('dil_conv_3x3_dil4', 0, 1), ('dil_conv_3x3', 1, 1), ('SpatialAttention', 2, 1), ('dil_conv_3x3_rp2', 3, 1), ('dil_conv_3x3', 0, 2), ('dil_conv_3x3', 3, 2), ('dil_conv_3x3', 2, 2), ('SpatialAttention', 1, 2)], inside=[('dil_conv_3x3_rp2', 0), ('dil_conv_3x3_rp2', 1), ('ChannelAttention', 2), ('sep_conv_3x3_rp2', 3), ('SpatialAttention', 4), ('dil_conv_3x3_rp2', 0), ('sep_conv_3x3_rp2', 1), ('dil_conv_3x3', 2), ('dil_conv_3x3_rp2', 3), ('dil_conv_3x3_rp2', 4), ('dil_conv_3x3_rp2', 5), ('SpatialAttention', 6), ('dil_conv_3x3_rp2', 0), ('sep_conv_3x3_rp2', 1), ('SpatialAttention', 2), ('dil_conv_3x3', 3), ('dil_conv_3x3_rp2', 4), ('dil_conv_3x3_rp2', 5), ('SpatialAttention', 6)]) 8 | -------------------------------------------------------------------------------- /images/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiuTingWed/Auto-MSFNet/0dfbb5598492e06c404e3e53f534ac669d7b185c/images/1.png -------------------------------------------------------------------------------- /images/2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/LiuTingWed/Auto-MSFNet/0dfbb5598492e06c404e3e53f534ac669d7b185c/images/2.png -------------------------------------------------------------------------------- /images/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiuTingWed/Auto-MSFNet/0dfbb5598492e06c404e3e53f534ac669d7b185c/images/3.png -------------------------------------------------------------------------------- /images/4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LiuTingWed/Auto-MSFNet/0dfbb5598492e06c404e3e53f534ac669d7b185c/images/4.png -------------------------------------------------------------------------------- /metric.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Time : 2020/7/7 3 | # @Author : Lart Pang 4 | # @Email : lartpang@163.com 5 | # @File : metric.py 6 | # @Project : HDFNet 7 | # @GitHub : https://github.com/lartpang 8 | 9 | import numpy as np 10 | from PIL import Image 11 | from scipy.ndimage import center_of_mass, convolve, distance_transform_edt as bwdist 12 | 13 | 14 | class CalFM(object): 15 | # Fmeasure(maxFm, meanFm)---Frequency-tuned salient region detection(CVPR 2009) 16 | def __init__(self, num, thds=255): 17 | self.precision = np.zeros((num, thds)) 18 | self.recall = np.zeros((num, thds)) 19 | self.meanF = np.zeros(num) 20 | self.idx = 0 21 | self.num = num 22 | 23 | def update(self, pred, gt): 24 | if gt.max() != 0: 25 | prediction, recall, mfmeasure = self.cal(pred, gt) 26 | self.precision[self.idx, :] = prediction 27 | self.recall[self.idx, :] = recall 28 | self.meanF[self.idx] = mfmeasure 29 | self.idx += 1 30 | 31 | def cal(self, pred, gt): 32 | ########################meanF############################## 33 | th = 2 * pred.mean() 34 | if th > 1: 35 | th = 1 36 | binary = np.zeros_like(pred) 37 | binary[pred >= th] = 1 38 | hard_gt = np.zeros_like(gt) 39 | hard_gt[gt > 0.5] = 1 40 | tp = (binary * hard_gt).sum() 41 | if tp == 0: 42 | mfmeasure = 0 43 | else: 44 | pre = tp / binary.sum() 45 | rec = tp / hard_gt.sum() 46 | mfmeasure = 1.3 * pre * rec / (0.3 * pre + rec) 47 | 48 | ########################maxF############################## 49 | pred = np.uint8(pred * 255) 50 | target = pred[gt > 0.5] 51 | nontarget = pred[gt <= 0.5] 52 | targetHist, _ = np.histogram(target, bins=range(256)) 53 | nontargetHist, _ = np.histogram(nontarget, bins=range(256)) 54 | targetHist = np.cumsum(np.flip(targetHist), axis=0) 55 | nontargetHist = np.cumsum(np.flip(nontargetHist), axis=0) 56 | precision = targetHist / (targetHist + nontargetHist + 1e-8) 57 | recall = targetHist / np.sum(gt) 58 | return precision, recall, mfmeasure 59 | 60 | def show(self): 61 | assert self.num == self.idx, f"{self.num}, {self.idx}" 62 | precision = self.precision.mean(axis=0) 63 | recall = self.recall.mean(axis=0) 64 | fmeasure = 1.3 * precision * recall / (0.3 * precision + recall + 1e-8) 65 | mmfmeasure = self.meanF.mean() 66 | return fmeasure, fmeasure.max(), mmfmeasure, precision, recall 67 | 68 | 69 | class CalMAE(object): 70 | # mean absolute error 71 | def __init__(self, num): 72 | # self.prediction = [] 73 | self.prediction = np.zeros(num) 74 | self.idx = 0 75 | self.num = num 76 | 77 | def update(self, pred, gt): 78 | self.prediction[self.idx] = self.cal(pred, gt) 79 | self.idx += 1 80 | 81 | def cal(self, pred, gt): 82 | return np.mean(np.abs(pred - gt)) 83 | 84 | def 
show(self): 85 | assert self.num == self.idx, f"{self.num}, {self.idx}" 86 | return self.prediction.mean() 87 | 88 | 89 | class CalSM(object): 90 | # Structure-measure: A new way to evaluate foreground maps (ICCV 2017) 91 | def __init__(self, num, alpha=0.5): 92 | self.prediction = np.zeros(num) 93 | self.alpha = alpha 94 | self.idx = 0 95 | self.num = num 96 | 97 | def update(self, pred, gt): 98 | gt = gt > 0.5 99 | self.prediction[self.idx] = self.cal(pred, gt) 100 | self.idx += 1 101 | 102 | def show(self): 103 | assert self.num == self.idx, f"{self.num}, {self.idx}" 104 | return self.prediction.mean() 105 | 106 | def cal(self, pred, gt): 107 | y = np.mean(gt) 108 | if y == 0: 109 | score = 1 - np.mean(pred) 110 | elif y == 1: 111 | score = np.mean(pred) 112 | else: 113 | score = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt) 114 | return score 115 | 116 | def object(self, pred, gt): 117 | fg = pred * gt 118 | bg = (1 - pred) * (1 - gt) 119 | 120 | u = np.mean(gt) 121 | return u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, np.logical_not(gt)) 122 | 123 | def s_object(self, in1, in2): 124 | x = np.mean(in1[in2]) 125 | sigma_x = np.std(in1[in2]) 126 | return 2 * x / (pow(x, 2) + 1 + sigma_x + 1e-8) 127 | 128 | def region(self, pred, gt): 129 | [y, x] = center_of_mass(gt) 130 | y = int(round(y)) + 1 131 | x = int(round(x)) + 1 132 | [gt1, gt2, gt3, gt4, w1, w2, w3, w4] = self.divideGT(gt, x, y) 133 | pred1, pred2, pred3, pred4 = self.dividePred(pred, x, y) 134 | 135 | score1 = self.ssim(pred1, gt1) 136 | score2 = self.ssim(pred2, gt2) 137 | score3 = self.ssim(pred3, gt3) 138 | score4 = self.ssim(pred4, gt4) 139 | 140 | return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4 141 | 142 | def divideGT(self, gt, x, y): 143 | h, w = gt.shape 144 | area = h * w 145 | LT = gt[0:y, 0:x] 146 | RT = gt[0:y, x:w] 147 | LB = gt[y:h, 0:x] 148 | RB = gt[y:h, x:w] 149 | 150 | w1 = x * y / area 151 | w2 = y * (w - x) / area 152 | w3 = (h - y) * x / area 153 | w4 = (h - y) * (w - x) / area 154 | 155 | return LT, RT, LB, RB, w1, w2, w3, w4 156 | 157 | def dividePred(self, pred, x, y): 158 | h, w = pred.shape 159 | LT = pred[0:y, 0:x] 160 | RT = pred[0:y, x:w] 161 | LB = pred[y:h, 0:x] 162 | RB = pred[y:h, x:w] 163 | 164 | return LT, RT, LB, RB 165 | 166 | def ssim(self, in1, in2): 167 | in2 = np.float32(in2) 168 | h, w = in1.shape 169 | N = h * w 170 | 171 | x = np.mean(in1) 172 | y = np.mean(in2) 173 | sigma_x = np.var(in1) 174 | sigma_y = np.var(in2) 175 | sigma_xy = np.sum((in1 - x) * (in2 - y)) / (N - 1) 176 | 177 | alpha = 4 * x * y * sigma_xy 178 | beta = (x * x + y * y) * (sigma_x + sigma_y) 179 | 180 | if alpha != 0: 181 | score = alpha / (beta + 1e-8) 182 | elif alpha == 0 and beta == 0: 183 | score = 1 184 | else: 185 | score = 0 186 | 187 | return score 188 | 189 | 190 | class CalEM(object): 191 | # Enhanced-alignment Measure for Binary Foreground Map Evaluation (IJCAI 2018) 192 | def __init__(self, num): 193 | self.prediction = np.zeros(num) 194 | self.idx = 0 195 | self.num = num 196 | 197 | def update(self, pred, gt): 198 | self.prediction[self.idx] = self.cal(pred, gt) 199 | self.idx += 1 200 | 201 | def cal(self, pred, gt): 202 | th = 2 * pred.mean() 203 | if th > 1: 204 | th = 1 205 | FM = np.zeros(gt.shape) 206 | FM[pred >= th] = 1 207 | FM = np.array(FM, dtype=bool) 208 | GT = np.array(gt, dtype=bool) 209 | dFM = np.double(FM) 210 | if sum(sum(np.double(GT))) == 0: 211 | enhanced_matrix = 1.0 - dFM 212 | elif sum(sum(np.double(~GT))) == 0: 213 
| enhanced_matrix = dFM 214 | else: 215 | dGT = np.double(GT) 216 | align_matrix = self.AlignmentTerm(dFM, dGT) 217 | enhanced_matrix = self.EnhancedAlignmentTerm(align_matrix) 218 | [w, h] = np.shape(GT) 219 | score = sum(sum(enhanced_matrix)) / (w * h - 1 + 1e-8) 220 | return score 221 | 222 | def AlignmentTerm(self, dFM, dGT): 223 | mu_FM = np.mean(dFM) 224 | mu_GT = np.mean(dGT) 225 | align_FM = dFM - mu_FM 226 | align_GT = dGT - mu_GT 227 | align_Matrix = 2.0 * (align_GT * align_FM) / (align_GT * align_GT + align_FM * align_FM + 1e-8) 228 | return align_Matrix 229 | 230 | def EnhancedAlignmentTerm(self, align_Matrix): 231 | enhanced = np.power(align_Matrix + 1, 2) / 4 232 | return enhanced 233 | 234 | def show(self): 235 | assert self.num == self.idx, f"{self.num}, {self.idx}" 236 | return self.prediction.mean() 237 | 238 | 239 | class CalWFM(object): 240 | def __init__(self, num, beta=1): 241 | self.scores_list = np.zeros(num) 242 | self.beta = beta 243 | self.eps = 1e-6 244 | self.idx = 0 245 | self.num = num 246 | 247 | def update(self, pred, gt): 248 | gt = gt > 0.5 249 | self.scores_list[self.idx] = 0 if gt.max() == 0 else self.cal(pred, gt) 250 | self.idx += 1 251 | 252 | def matlab_style_gauss2D(self, shape=(7, 7), sigma=5): 253 | """ 254 | 2D gaussian mask - should give the same result as MATLAB's 255 | fspecial('gaussian',[shape],[sigma]) 256 | """ 257 | m, n = [(ss - 1.0) / 2.0 for ss in shape] 258 | y, x = np.ogrid[-m : m + 1, -n : n + 1] 259 | h = np.exp(-(x * x + y * y) / (2.0 * sigma * sigma)) 260 | h[h < np.finfo(h.dtype).eps * h.max()] = 0 261 | sumh = h.sum() 262 | if sumh != 0: 263 | h /= sumh 264 | return h 265 | 266 | def cal(self, pred, gt): 267 | # [Dst,IDXT] = bwdist(dGT); 268 | Dst, Idxt = bwdist(gt == 0, return_indices=True) 269 | 270 | # %Pixel dependency 271 | # E = abs(FG-dGT); 272 | E = np.abs(pred - gt) 273 | # Et = E; 274 | # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region 275 | Et = np.copy(E) 276 | Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]] 277 | 278 | # K = fspecial('gaussian',7,5); 279 | # EA = imfilter(Et,K); 280 | # MIN_E_EA(GT & EA= 0 323 | assert gt.max() <= 1 and gt.min() >= 0 324 | 325 | self.cal_mae.update(pred, gt) 326 | self.cal_fm.update(pred, gt) 327 | self.cal_sm.update(pred, gt) 328 | self.cal_em.update(pred, gt) 329 | self.cal_wfm.update(pred, gt) 330 | 331 | def show(self): 332 | MAE = self.cal_mae.show() 333 | _, Maxf, Meanf, _, _, = self.cal_fm.show() 334 | SM = self.cal_sm.show() 335 | EM = self.cal_em.show() 336 | WFM = self.cal_wfm.show() 337 | results = { 338 | "MaxF": Maxf, 339 | "MeanF": Meanf, 340 | "WFM": WFM, 341 | "MAE": MAE, 342 | "SM": SM, 343 | "EM": EM, 344 | } 345 | return results 346 | 347 | 348 | if __name__ == "__main__": 349 | pred = Image 350 | -------------------------------------------------------------------------------- /model_resnet50.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from operations import * 5 | from torch.autograd import Variable 6 | from utils import drop_path 7 | import genotypes 8 | 9 | 10 | 11 | class Bottleneck(nn.Module): 12 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 13 | super(Bottleneck, self).__init__() 14 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 15 | self.bn1 = nn.BatchNorm2d(planes) 16 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 
padding=(3 * dilation - 1) // 2, 17 | bias=False, 18 | dilation=dilation) 19 | self.bn2 = nn.BatchNorm2d(planes) 20 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 21 | self.bn3 = nn.BatchNorm2d(planes * 4) 22 | self.downsample = downsample 23 | 24 | def forward(self, x): 25 | residual = x 26 | out = F.relu(self.bn1(self.conv1(x)), inplace=True) 27 | out = F.relu(self.bn2(self.conv2(out)), inplace=True) 28 | out = self.bn3(self.conv3(out)) 29 | if self.downsample is not None: 30 | residual = self.downsample(x) 31 | return F.relu(out + residual, inplace=True) 32 | 33 | 34 | class ResNet(nn.Module): 35 | def __init__(self): 36 | super(ResNet, self).__init__() 37 | self.inplanes = 64 38 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 39 | self.bn1 = nn.BatchNorm2d(64) 40 | self.layer1 = self.make_layer(64, 3, stride=1, dilation=1) 41 | self.layer2 = self.make_layer(128, 4, stride=2, dilation=1) 42 | self.layer3 = self.make_layer(256, 6, stride=2, dilation=1) 43 | self.layer4 = self.make_layer(512, 3, stride=2, dilation=1) 44 | 45 | out_channel = 128 46 | self.conv5 = nn.Conv2d(2048, out_channel, kernel_size=3, stride=1, padding=1) 47 | self.bn5 = nn.BatchNorm2d(out_channel) 48 | self.conv4 = nn.Conv2d(1024, out_channel, kernel_size=3, stride=1, padding=1) 49 | self.bn4 = nn.BatchNorm2d(out_channel) 50 | self.conv3 = nn.Conv2d(512, out_channel, kernel_size=3, stride=1, padding=1) 51 | self.bn3 = nn.BatchNorm2d(out_channel) 52 | self.conv2 = nn.Conv2d(256, out_channel, kernel_size=3, stride=1, padding=1) 53 | self.bn2 = nn.BatchNorm2d(out_channel) 54 | 55 | 56 | def make_layer(self, planes, blocks, stride, dilation): 57 | downsample = None 58 | if stride != 1 or self.inplanes != planes * 4: 59 | downsample = nn.Sequential(nn.Conv2d(self.inplanes, planes * 4, kernel_size=1, stride=stride, bias=False), 60 | nn.BatchNorm2d(planes * 4)) 61 | 62 | layers = [Bottleneck(self.inplanes, planes, stride, downsample, dilation=dilation)] 63 | self.inplanes = planes * 4 64 | for _ in range(1, blocks): 65 | layers.append(Bottleneck(self.inplanes, planes, dilation=dilation)) 66 | return nn.Sequential(*layers) 67 | 68 | def forward(self, x): 69 | out1 = F.relu(self.bn1(self.conv1(x)), inplace=True) 70 | out1 = F.max_pool2d(out1, kernel_size=3, stride=2, padding=1) 71 | out2 = self.layer1(out1) 72 | out3 = self.layer2(out2) 73 | out4 = self.layer3(out3) 74 | out5 = self.layer4(out4) 75 | 76 | out2 = F.relu(self.bn2(self.conv2(out2)), inplace=True) 77 | out3 = F.relu(self.bn3(self.conv3(out3)), inplace=True) 78 | out4 = F.relu(self.bn4(self.conv4(out4)), inplace=True) 79 | out5_ = F.relu(self.bn5(self.conv5(out5)), inplace=True) 80 | 81 | 82 | return out5_, out2, out3, out4 83 | 84 | class Featurefusioncell43(nn.Module): 85 | def __init__(self, standardShape, channel, op): 86 | super(Featurefusioncell43, self).__init__() 87 | self.standardShape = standardShape 88 | self._ops = op 89 | self.conv11 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) 90 | self.bn11 = nn.BatchNorm2d(channel) 91 | 92 | def forward(self, fea, lowfeature): 93 | f2 = fea[0] 94 | levelfusion = fea[2] 95 | f3 = fea[1] 96 | f4 = fea[3] 97 | 98 | assert levelfusion.size()[3] == self.standardShape 99 | if lowfeature.size()[2:] != self.standardShape: 100 | lowfeature = F.interpolate(lowfeature, self.standardShape, mode='bilinear') 101 | if f2.size()[3] != self.standardShape and f2.size() != torch.Size([]): 102 | f2 = F.interpolate(f2, self.standardShape, mode='bilinear') 103 
| if f3.size()[3] != self.standardShape and f3.size() != torch.Size([]): 104 | f3 = F.interpolate(f3, self.standardShape, mode='bilinear') 105 | if f4.size()[3] != self.standardShape: 106 | f4 = F.interpolate(f4, self.standardShape, mode='bilinear') 107 | 108 | z1 = f2 109 | z2 = f3 110 | z3 = levelfusion 111 | z4 = f4 112 | pre_note = [lowfeature] 113 | states = [z1, z2, z3, z4] 114 | offset = 0 115 | for i in range(4): 116 | if i == 0: 117 | s0 = states[i] 118 | s1 = self._ops[offset + i](pre_note[i]) 119 | add = s0 + s1 120 | pre_note.append(add) 121 | else: 122 | p1 = states[i] 123 | s0 = self._ops[offset + i](pre_note[i]) 124 | s1 = self._ops[offset + i + 1](states[i]) 125 | add = s0 + s1 + p1 126 | pre_note.append(add) 127 | offset += 1 128 | 129 | out = 0 130 | for i in range(1, 5): 131 | out += pre_note[i] 132 | out = F.relu(self.bn11(self.conv11(out)), inplace=True) 133 | 134 | return out 135 | 136 | 137 | class Featurefusioncell32(nn.Module): 138 | def __init__(self, standardShape, channel, op): 139 | super(Featurefusioncell32, self).__init__() 140 | self.standardShape = standardShape 141 | self._ops = op 142 | self.conv11 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) 143 | self.bn11 = nn.BatchNorm2d(channel) 144 | 145 | def forward(self, fea, lowfeature): 146 | f2 = fea[0] 147 | levelfusion = fea[3] 148 | f3 = fea[1] 149 | f4 = fea[2] 150 | 151 | assert levelfusion.size()[3] == self.standardShape 152 | if lowfeature.size()[2:] != self.standardShape: 153 | lowfeature = F.interpolate(lowfeature, self.standardShape, mode='bilinear') 154 | if f2.size()[3] != self.standardShape and f2.size() != torch.Size([]): 155 | f2 = F.interpolate(f2, self.standardShape, mode='bilinear') 156 | if f3.size()[3] != self.standardShape and f3.size() != torch.Size([]): 157 | f3 = F.interpolate(f3, self.standardShape, mode='bilinear') 158 | if f4.size()[3] != self.standardShape: 159 | f4 = F.interpolate(f4, self.standardShape, mode='bilinear') 160 | 161 | z1 = f2 162 | z2 = f3 163 | z3 = f4 164 | z4 = levelfusion 165 | 166 | pre_note = [lowfeature] 167 | states = [z1, z2, z3, z4] 168 | offset = 0 169 | for i in range(4): 170 | if i == 0: 171 | s0 = states[i] 172 | s1 = self._ops[offset + i](pre_note[i]) 173 | add = s0 + s1 174 | pre_note.append(add) 175 | else: 176 | p1 = states[i] 177 | s0 = self._ops[offset + i](pre_note[i]) 178 | s1 = self._ops[offset + i + 1](states[i]) 179 | add = s0 + s1 + p1 180 | pre_note.append(add) 181 | offset += 1 182 | 183 | out = 0 184 | for i in range(1, 5): 185 | out += pre_note[i] 186 | out = F.relu(self.bn11(self.conv11(out)), inplace=True) 187 | return out 188 | 189 | 190 | class Featurefusioncell54(nn.Module): 191 | def __init__(self, standardShape, channel, op): 192 | super(Featurefusioncell54, self).__init__() 193 | self.standardShape = standardShape 194 | self._ops = op 195 | self.conv11 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) 196 | self.bn11 = nn.BatchNorm2d(channel) 197 | 198 | def forward(self, fea): 199 | lowfeature = fea[0] 200 | levelfusion = fea[1] 201 | f2 = fea[2] 202 | f3 = fea[3] 203 | 204 | # if levelfusion is not None: 205 | assert levelfusion.size()[3] == self.standardShape 206 | if lowfeature.size()[2:] != self.standardShape: 207 | lowfeature = F.interpolate(lowfeature, self.standardShape, mode='bilinear') 208 | if f2.size()[3] != self.standardShape and f2.size() != torch.Size([]): 209 | f2 = F.interpolate(f2, self.standardShape, mode='bilinear') 210 | if f3.size()[3] != self.standardShape and f3.size() 
!= torch.Size([]): 211 | f3 = F.interpolate(f3, self.standardShape, mode='bilinear') 212 | 213 | z1 = lowfeature 214 | z2 = levelfusion 215 | z3 = f2 216 | z4 = f3 217 | 218 | pre_note = [z1] 219 | states = [z2, z3, z4] 220 | offset = 0 221 | for i in range(4 - 1): 222 | if i == 0: 223 | s0 = states[i] 224 | s1 = self._ops[offset + i](pre_note[i]) 225 | add = s0 + s1 226 | pre_note.append(add) 227 | else: 228 | p1 = states[i] 229 | s0 = self._ops[offset + i](pre_note[i]) 230 | s1 = self._ops[offset + i + 1](states[i]) 231 | add = s0 + s1 + p1 232 | pre_note.append(add) 233 | offset += 1 234 | 235 | out = 0 236 | for i in range(1, 4): 237 | out += pre_note[i] 238 | out = F.relu(self.bn11(self.conv11(out)), inplace=True) 239 | 240 | return out 241 | 242 | 243 | class FeatureFusion(nn.Module): 244 | 245 | def __init__(self, genotype_fusion, node=3): 246 | super(FeatureFusion, self).__init__() 247 | 248 | self._ops = nn.ModuleList() 249 | self.fnum = 4 250 | self.node = node 251 | C = 128 252 | 253 | genotype_ouside = genotype_fusion.normal 254 | genotype_inside = genotype_fusion.inside 255 | new_genotype_ouside = sorted(genotype_ouside, key=lambda x: (x[2], x[1])) 256 | op_name, op_num, _ = zip(*new_genotype_ouside) 257 | 258 | self.op_num = op_num 259 | offset = 0 260 | for i in range(self.node): 261 | for j in range(self.fnum): 262 | op = OPS[op_name[j + offset]](C, C, 1,False, True) 263 | self._ops += [op] 264 | offset += 4 265 | 266 | op_name_inside, op_num_inside = zip(*genotype_inside) 267 | 268 | k = [5, 7, 7] 269 | noteOper = [] 270 | offset = 0 271 | for i in range(self.node): 272 | self._nodes = nn.ModuleList() 273 | for j in range(k[i]): 274 | op = OPS[op_name_inside[j + offset]](C, C, 1,False, True) 275 | self._nodes += [op] 276 | noteOper.append(self._nodes) 277 | offset += k[i] 278 | 279 | self.featurefusioncell54 = Featurefusioncell54(16, C, noteOper[0]) 280 | self.featurefusioncell43 = Featurefusioncell43(32, C, noteOper[1]) 281 | self.featurefusioncell32 = Featurefusioncell32(64, C, noteOper[2]) 282 | 283 | def forward(self, out5, out2, out3, out4): 284 | 285 | states = [out5, out4, out3, out2] 286 | 287 | # per-edge feature tensors; cleared after each fusion node is traversed (4 per round, 12 edges in total) 288 | fea = [] 289 | # tensors output by each fusion node 290 | feaoutput = [] 291 | offset = 0 292 | s = 0 293 | 294 | for i in range(self.node): 295 | for j, v in enumerate(self.op_num): 296 | if j == 4: 297 | break 298 | inputFea = states[v] 299 | x2 = self._ops[offset + j](inputFea) 300 | fea.append(x2) 301 | 302 | if i == 0: 303 | new_fea = self.featurefusioncell54(fea) 304 | feaoutput.append(new_fea) 305 | fea.clear() 306 | 307 | elif i == 1: 308 | new_fea = self.featurefusioncell43(fea, feaoutput[0]) 309 | feaoutput.append(new_fea) 310 | fea.clear() 311 | 312 | elif i == 2: 313 | new_fea = self.featurefusioncell32(fea, feaoutput[1]) 314 | feaoutput.append(new_fea) 315 | fea.clear() 316 | 317 | offset += 4 318 | 319 | return feaoutput[2],feaoutput[1],feaoutput[0] 320 | 321 | def _loss(self, input, target): 322 | logits = self(input) 323 | logits = logits.squeeze(1) 324 | 325 | return self._criterion(logits, target) 326 | 327 | class vgg16(nn.Module): 328 | def __init__(self,): 329 | super(vgg16, self).__init__() 330 | 331 | # original image's size = 256*256*3 332 | 333 | # conv1 334 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) 335 | self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 336 | self.relu1_1 = nn.ReLU(inplace=True) 337 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 338 | self.bn1_2 = nn.BatchNorm2d(64,
eps=1e-05, momentum=0.1, affine=True) 339 | self.relu1_2 = nn.ReLU(inplace=True) 340 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers 341 | 342 | # conv2 343 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 344 | self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 345 | self.relu2_1 = nn.ReLU(inplace=True) 346 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 347 | self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 348 | self.relu2_2 = nn.ReLU(inplace=True) 349 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers 350 | 351 | # conv3 352 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 353 | self.bn3_1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 354 | self.relu3_1 = nn.ReLU(inplace=True) 355 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 356 | self.bn3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 357 | self.relu3_2 = nn.ReLU(inplace=True) 358 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 359 | self.bn3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 360 | self.relu3_3 = nn.ReLU(inplace=True) 361 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 4 layers 362 | 363 | # conv4 364 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 365 | self.bn4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 366 | self.relu4_1 = nn.ReLU(inplace=True) 367 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 368 | self.bn4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 369 | self.relu4_2 = nn.ReLU(inplace=True) 370 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 371 | self.bn4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 372 | self.relu4_3 = nn.ReLU(inplace=True) 373 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 4 layers 374 | 375 | # conv5 376 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 377 | self.bn5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 378 | self.relu5_1 = nn.ReLU(inplace=True) 379 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 380 | self.bn5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 381 | self.relu5_2 = nn.ReLU(inplace=True) 382 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 383 | self.bn5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 384 | self.relu5_3 = nn.ReLU(inplace=True) # 1/32 4 layers 385 | 386 | out_channel = 128 387 | self.conv5 = nn.Conv2d(512, out_channel, kernel_size=3, stride=1, padding=1) 388 | self.bn5 = nn.BatchNorm2d(out_channel) 389 | self.conv4 = nn.Conv2d(512, out_channel, kernel_size=3, stride=1, padding=1) 390 | self.bn4 = nn.BatchNorm2d(out_channel) 391 | self.conv3 = nn.Conv2d(256, out_channel, kernel_size=3, stride=1, padding=1) 392 | self.bn3 = nn.BatchNorm2d(out_channel) 393 | self.conv2 = nn.Conv2d(128, out_channel, kernel_size=3, stride=1, padding=1) 394 | self.bn2 = nn.BatchNorm2d(out_channel) 395 | 396 | def forward(self, x): 397 | h = x 398 | 399 | h = self.relu1_1(self.bn1_1(self.conv1_1(h))) 400 | h = self.relu1_2(self.bn1_2(self.conv1_2(h))) 401 | h_nopool1 = h 402 | h = self.pool1(h) 403 | # pool1 = h 404 | 405 | h = self.relu2_1(self.bn2_1(self.conv2_1(h))) 406 | h = self.relu2_2(self.bn2_2(self.conv2_2(h))) 407 | h_nopool2 = h 408 | h = self.pool2(h) 409 | # pool2 = h 410 | 411 | h = self.relu3_1(self.bn3_1(self.conv3_1(h))) 412 | h = self.relu3_2(self.bn3_2(self.conv3_2(h))) 413 | h = self.relu3_3(self.bn3_3(self.conv3_3(h))) 414 | h_nopool3 = h 415 | h = self.pool3(h) 416 | # 
pool3 = h 417 | 418 | h = self.relu4_1(self.bn4_1(self.conv4_1(h))) 419 | h = self.relu4_2(self.bn4_2(self.conv4_2(h))) 420 | h = self.relu4_3(self.bn4_3(self.conv4_3(h))) 421 | h_nopool4 = h 422 | h = self.pool4(h) 423 | 424 | h = self.relu5_1(self.bn5_1(self.conv5_1(h))) 425 | h = self.relu5_2(self.bn5_2(self.conv5_2(h))) 426 | h = self.relu5_3(self.bn5_3(self.conv5_3(h))) 427 | 428 | out2 = F.relu(self.bn2(self.conv2(h_nopool2)), inplace=True) 429 | out3 = F.relu(self.bn3(self.conv3(h_nopool3)), inplace=True) 430 | out4 = F.relu(self.bn4(self.conv4(h_nopool4)), inplace=True) 431 | out5_ = F.relu(self.bn5(self.conv5(h)), inplace=True) 432 | 433 | return out5_, out2, out3, out4 434 | 435 | class Network_Resnet50(nn.Module): 436 | 437 | def __init__(self,genotype_fusion): 438 | super(Network_Resnet50, self).__init__() 439 | self.resnet = ResNet() 440 | self.feafusion = FeatureFusion(genotype_fusion) 441 | self.conv44 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1) 442 | self.conv55 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1) 443 | self.conv66 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1) 444 | 445 | def forward(self, input): 446 | h_, h_nopool2, h_nopool3, h_nopool4 = self.resnet(input) 447 | h_nopool2,h_nopool3,h_nopool4 = self.feafusion(h_, h_nopool2, h_nopool3, h_nopool4) 448 | h_nopool2 = F.interpolate(self.conv44(h_nopool2), size=[256, 256], mode='bilinear') 449 | h_nopool3 = F.interpolate(self.conv55(h_nopool3), size=[256, 256], mode='bilinear') 450 | h_nopool4 = F.interpolate(self.conv66(h_nopool4), size=[256, 256], mode='bilinear') 451 | return h_nopool2,h_nopool3,h_nopool4 452 | -------------------------------------------------------------------------------- /model_vgg16.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from operations import * 5 | from torch.autograd import Variable 6 | from utils import drop_path 7 | 8 | 9 | class Bottleneck(nn.Module): 10 | def __init__(self, inplanes, planes, stride=1, downsample=None, dilation=1): 11 | super(Bottleneck, self).__init__() 12 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 13 | self.bn1 = nn.BatchNorm2d(planes) 14 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, padding=(3 * dilation - 1) // 2, 15 | bias=False, 16 | dilation=dilation) 17 | self.bn2 = nn.BatchNorm2d(planes) 18 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 19 | self.bn3 = nn.BatchNorm2d(planes * 4) 20 | self.downsample = downsample 21 | 22 | def forward(self, x): 23 | residual = x 24 | out = F.relu(self.bn1(self.conv1(x)), inplace=True) 25 | out = F.relu(self.bn2(self.conv2(out)), inplace=True) 26 | out = self.bn3(self.conv3(out)) 27 | if self.downsample is not None: 28 | residual = self.downsample(x) 29 | return F.relu(out + residual, inplace=True) 30 | 31 | 32 | class vgg16(nn.Module): 33 | def __init__(self,): 34 | super(vgg16, self).__init__() 35 | 36 | # conv1 37 | self.conv1_1 = nn.Conv2d(3, 64, 3, padding=1) 38 | self.bn1_1 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 39 | self.relu1_1 = nn.ReLU(inplace=True) 40 | self.conv1_2 = nn.Conv2d(64, 64, 3, padding=1) 41 | self.bn1_2 = nn.BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True) 42 | self.relu1_2 = nn.ReLU(inplace=True) 43 | self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/2 2 layers 44 | 45 | # conv2 46 | self.conv2_1 = nn.Conv2d(64, 128, 3, padding=1) 47 | 
self.bn2_1 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 48 | self.relu2_1 = nn.ReLU(inplace=True) 49 | self.conv2_2 = nn.Conv2d(128, 128, 3, padding=1) 50 | self.bn2_2 = nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True) 51 | self.relu2_2 = nn.ReLU(inplace=True) 52 | self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/4 2 layers 53 | 54 | # conv3 55 | self.conv3_1 = nn.Conv2d(128, 256, 3, padding=1) 56 | self.bn3_1 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 57 | self.relu3_1 = nn.ReLU(inplace=True) 58 | self.conv3_2 = nn.Conv2d(256, 256, 3, padding=1) 59 | self.bn3_2 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 60 | self.relu3_2 = nn.ReLU(inplace=True) 61 | self.conv3_3 = nn.Conv2d(256, 256, 3, padding=1) 62 | self.bn3_3 = nn.BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True) 63 | self.relu3_3 = nn.ReLU(inplace=True) 64 | self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/8 4 layers 65 | 66 | # conv4 67 | self.conv4_1 = nn.Conv2d(256, 512, 3, padding=1) 68 | self.bn4_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 69 | self.relu4_1 = nn.ReLU(inplace=True) 70 | self.conv4_2 = nn.Conv2d(512, 512, 3, padding=1) 71 | self.bn4_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 72 | self.relu4_2 = nn.ReLU(inplace=True) 73 | self.conv4_3 = nn.Conv2d(512, 512, 3, padding=1) 74 | self.bn4_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 75 | self.relu4_3 = nn.ReLU(inplace=True) 76 | self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True) # 1/16 4 layers 77 | 78 | # conv5 79 | self.conv5_1 = nn.Conv2d(512, 512, 3, padding=1) 80 | self.bn5_1 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 81 | self.relu5_1 = nn.ReLU(inplace=True) 82 | self.conv5_2 = nn.Conv2d(512, 512, 3, padding=1) 83 | self.bn5_2 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 84 | self.relu5_2 = nn.ReLU(inplace=True) 85 | self.conv5_3 = nn.Conv2d(512, 512, 3, padding=1) 86 | self.bn5_3 = nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True) 87 | self.relu5_3 = nn.ReLU(inplace=True) # 1/32 4 layers 88 | 89 | out_channel = 128 90 | self.conv5 = nn.Conv2d(512, out_channel, kernel_size=3, stride=1, padding=1) 91 | self.bn5 = nn.BatchNorm2d(out_channel) 92 | self.conv4 = nn.Conv2d(512, out_channel, kernel_size=3, stride=1, padding=1) 93 | self.bn4 = nn.BatchNorm2d(out_channel) 94 | self.conv3 = nn.Conv2d(256, out_channel, kernel_size=3, stride=1, padding=1) 95 | self.bn3 = nn.BatchNorm2d(out_channel) 96 | self.conv2 = nn.Conv2d(128, out_channel, kernel_size=3, stride=1, padding=1) 97 | self.bn2 = nn.BatchNorm2d(out_channel) 98 | 99 | 100 | def forward(self, x): 101 | h = x 102 | 103 | h = self.relu1_1(self.bn1_1(self.conv1_1(h))) 104 | h = self.relu1_2(self.bn1_2(self.conv1_2(h))) 105 | h_nopool1 = h 106 | h = self.pool1(h) 107 | # pool1 = h 108 | 109 | h = self.relu2_1(self.bn2_1(self.conv2_1(h))) 110 | h = self.relu2_2(self.bn2_2(self.conv2_2(h))) 111 | h_nopool2 = h 112 | h = self.pool2(h) 113 | # pool2 = h 114 | 115 | h = self.relu3_1(self.bn3_1(self.conv3_1(h))) 116 | h = self.relu3_2(self.bn3_2(self.conv3_2(h))) 117 | h = self.relu3_3(self.bn3_3(self.conv3_3(h))) 118 | h_nopool3 = h 119 | h = self.pool3(h) 120 | # pool3 = h 121 | 122 | h = self.relu4_1(self.bn4_1(self.conv4_1(h))) 123 | h = self.relu4_2(self.bn4_2(self.conv4_2(h))) 124 | h = self.relu4_3(self.bn4_3(self.conv4_3(h))) 125 | h_nopool4 = h 126 | h = self.pool4(h) 127 | 128 | h = self.relu5_1(self.bn5_1(self.conv5_1(h))) 129 | h = 
self.relu5_2(self.bn5_2(self.conv5_2(h))) 130 | h = self.relu5_3(self.bn5_3(self.conv5_3(h))) 131 | 132 | out2 = F.relu(self.bn2(self.conv2(h_nopool2)), inplace=True) 133 | out3 = F.relu(self.bn3(self.conv3(h_nopool3)), inplace=True) 134 | out4 = F.relu(self.bn4(self.conv4(h_nopool4)), inplace=True) 135 | out5_ = F.relu(self.bn5(self.conv5(h)), inplace=True) 136 | 137 | return out5_, out2, out3, out4 138 | 139 | 140 | class Featurefusioncell43(nn.Module): 141 | def __init__(self, standardShape, channel, op): 142 | super(Featurefusioncell43, self).__init__() 143 | self.standardShape = standardShape 144 | self._ops = op 145 | self.conv11 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) 146 | self.bn11 = nn.BatchNorm2d(channel) 147 | # self.initialize() 148 | 149 | def initialize(self): 150 | weight_init(self) 151 | 152 | def forward(self, fea, lowfeature): 153 | f2 = fea[0] 154 | levelfusion = fea[2] 155 | f3 = fea[1] 156 | f4 = fea[3] 157 | 158 | if levelfusion.size()[2:] != self.standardShape: 159 | levelfusion = F.interpolate(levelfusion, self.standardShape, mode='bilinear') 160 | if lowfeature.size()[2:] != self.standardShape: 161 | lowfeature = F.interpolate(lowfeature, self.standardShape, mode='bilinear') 162 | if f2.size()[3] != self.standardShape and f2.size() != torch.Size([]): 163 | f2 = F.interpolate(f2, self.standardShape, mode='bilinear') 164 | if f3.size()[3] != self.standardShape and f3.size() != torch.Size([]): 165 | f3 = F.interpolate(f3, self.standardShape, mode='bilinear') 166 | if f4.size()[3] != self.standardShape: 167 | f4 = F.interpolate(f4, self.standardShape, mode='bilinear') 168 | 169 | z1 = f2 170 | z2 = f3 171 | z3 = levelfusion 172 | z4 = f4 173 | pre_note = [lowfeature] 174 | states = [z1, z2, z3, z4] 175 | offset = 0 176 | for i in range(4): 177 | if i == 0: 178 | s0 = states[i] 179 | s1 = self._ops[offset + i](pre_note[i]) 180 | add = s0 + s1 181 | pre_note.append(add) 182 | else: 183 | p1 = states[i] 184 | s0 = self._ops[offset + i](pre_note[i]) 185 | s1 = self._ops[offset + i + 1](states[i]) 186 | add = s0 + s1 + p1 187 | pre_note.append(add) 188 | offset += 1 189 | 190 | out = 0 191 | for i in range(1, 5): 192 | out += pre_note[i] 193 | out = F.relu(self.bn11(self.conv11(out)), inplace=True) 194 | 195 | return out 196 | 197 | 198 | class Featurefusioncell32(nn.Module): 199 | def __init__(self, standardShape, channel, op): 200 | super(Featurefusioncell32, self).__init__() 201 | self.standardShape = standardShape 202 | self._ops = op 203 | self.conv11 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) 204 | self.bn11 = nn.BatchNorm2d(channel) 205 | 206 | def initialize(self): 207 | weight_init(self) 208 | 209 | def forward(self, fea, lowfeature): 210 | f2 = fea[0] 211 | levelfusion = fea[3] 212 | f3 = fea[1] 213 | f4 = fea[2] 214 | 215 | if levelfusion.size()[2:] != self.standardShape: 216 | levelfusion = F.interpolate(levelfusion, self.standardShape, mode='bilinear') 217 | if lowfeature.size()[2:] != self.standardShape: 218 | lowfeature = F.interpolate(lowfeature, self.standardShape, mode='bilinear') 219 | if f2.size()[3] != self.standardShape and f2.size() != torch.Size([]): 220 | f2 = F.interpolate(f2, self.standardShape, mode='bilinear') 221 | if f3.size()[3] != self.standardShape and f3.size() != torch.Size([]): 222 | f3 = F.interpolate(f3, self.standardShape, mode='bilinear') 223 | if f4.size()[3] != self.standardShape: 224 | f4 = F.interpolate(f4, self.standardShape, mode='bilinear') 225 | 226 | z1 = f2 227 | z2 = f3 228 
| z3 = f4 229 | z4 = levelfusion 230 | 231 | pre_note = [lowfeature] 232 | states = [z1, z2, z3, z4] 233 | offset = 0 234 | for i in range(4): 235 | if i == 0: 236 | s0 = states[i] 237 | s1 = self._ops[offset + i](pre_note[i]) 238 | add = s0 + s1 239 | pre_note.append(add) 240 | else: 241 | p1 = states[i] 242 | s0 = self._ops[offset + i](pre_note[i]) 243 | s1 = self._ops[offset + i + 1](states[i]) 244 | add = s0 + s1 + p1 245 | pre_note.append(add) 246 | offset += 1 247 | 248 | out = 0 249 | for i in range(1, 5): 250 | out += pre_note[i] 251 | out = F.relu(self.bn11(self.conv11(out)), inplace=True) 252 | return out 253 | 254 | 255 | class Featurefusioncell54(nn.Module): 256 | def __init__(self, standardShape, channel, op): 257 | super(Featurefusioncell54, self).__init__() 258 | self.standardShape = standardShape 259 | self._ops = op 260 | self.conv11 = nn.Conv2d(channel, channel, kernel_size=3, stride=1, padding=1) 261 | self.bn11 = nn.BatchNorm2d(channel) 262 | 263 | def initialize(self): 264 | weight_init(self) 265 | 266 | def forward(self, fea): 267 | lowfeature = fea[0] 268 | levelfusion = fea[1] 269 | f2 = fea[2] 270 | f3 = fea[3] 271 | 272 | if levelfusion.size()[2:] != self.standardShape: 273 | levelfusion = F.interpolate(levelfusion, self.standardShape, mode='bilinear') 274 | if lowfeature.size()[2:] != self.standardShape: 275 | lowfeature = F.interpolate(lowfeature, self.standardShape, mode='bilinear') 276 | if f2.size()[3] != self.standardShape and f2.size() != torch.Size([]): 277 | f2 = F.interpolate(f2, self.standardShape, mode='bilinear') 278 | if f3.size()[3] != self.standardShape and f3.size() != torch.Size([]): 279 | f3 = F.interpolate(f3, self.standardShape, mode='bilinear') 280 | 281 | z1 = lowfeature 282 | z2 = levelfusion 283 | z3 = f2 284 | z4 = f3 285 | 286 | pre_note = [z1] 287 | states = [z2, z3, z4] 288 | offset = 0 289 | for i in range(4 - 1): 290 | if i == 0: 291 | s0 = states[i] 292 | s1 = self._ops[offset + i](pre_note[i]) 293 | add = s0 + s1 294 | pre_note.append(add) 295 | else: 296 | p1 = states[i] 297 | s0 = self._ops[offset + i](pre_note[i]) 298 | s1 = self._ops[offset + i + 1](states[i]) 299 | add = s0 + s1 + p1 300 | pre_note.append(add) 301 | offset += 1 302 | 303 | out = 0 304 | for i in range(1, 4): 305 | out += pre_note[i] 306 | out = F.relu(self.bn11(self.conv11(out)), inplace=True) 307 | 308 | return out 309 | 310 | 311 | class FeatureFusion(nn.Module): 312 | 313 | def __init__(self, genotype_fusion, node=3): 314 | super(FeatureFusion, self).__init__() 315 | 316 | self._ops = nn.ModuleList() 317 | self.fnum = 4 318 | self.node = node 319 | # self.none_num = [] 320 | C = 128 321 | 322 | genotype_ouside = genotype_fusion.normal 323 | genotype_inside = genotype_fusion.inside 324 | new_genotype_ouside = sorted(genotype_ouside, key=lambda x: (x[2], x[1])) 325 | op_name, op_num, _ = zip(*new_genotype_ouside) 326 | 327 | self.op_num = op_num 328 | offset = 0 329 | for i in range(self.node): 330 | for j in range(self.fnum): 331 | op = OPS[op_name[j + offset]](C, C, 1,False, True) 332 | self._ops += [op] 333 | offset += 4 334 | 335 | op_name_inside, op_num_inside = zip(*genotype_inside) 336 | 337 | k = [5, 7, 7] 338 | noteOper = [] 339 | offset = 0 340 | for i in range(self.node): 341 | self._nodes = nn.ModuleList() 342 | for j in range(k[i]): 343 | op = OPS[op_name_inside[j + offset]](C, C, 1,False, True) 344 | self._nodes += [op] 345 | noteOper.append(self._nodes) 346 | offset += k[i] 347 | 348 | self.featurefusioncell54 = Featurefusioncell54(16, C, 
noteOper[0]) 349 | self.featurefusioncell43 = Featurefusioncell43(32, C, noteOper[1]) 350 | self.featurefusioncell32 = Featurefusioncell32(64, C, noteOper[2]) 351 | 352 | def forward(self, out5, out2, out3, out4): 353 | 354 | states = [out5, out4, out3, out2] 355 | 356 | fea = [] 357 | feaoutput = [] 358 | offset = 0 359 | s = 0 360 | 361 | for i in range(self.node): 362 | for j, v in enumerate(self.op_num): 363 | if j == 4: 364 | break 365 | inputFea = states[v] 366 | x2 = self._ops[offset + j](inputFea) 367 | fea.append(x2) 368 | 369 | if i == 0: 370 | new_fea = self.featurefusioncell54(fea) 371 | feaoutput.append(new_fea) 372 | fea.clear() 373 | 374 | elif i == 1: 375 | new_fea = self.featurefusioncell43(fea, feaoutput[0]) 376 | feaoutput.append(new_fea) 377 | fea.clear() 378 | 379 | elif i == 2: 380 | new_fea = self.featurefusioncell32(fea, feaoutput[1]) 381 | feaoutput.append(new_fea) 382 | fea.clear() 383 | 384 | offset += 4 385 | 386 | return feaoutput[2],feaoutput[1],feaoutput[0] 387 | 388 | def _loss(self, input, target): 389 | logits = self(input) 390 | logits = logits.squeeze(1) 391 | 392 | return self._criterion(logits, target) 393 | 394 | 395 | class Network_vgg16(nn.Module): 396 | 397 | def __init__(self, genotype_fusion): 398 | super(Network_vgg16, self).__init__() 399 | # self.vgg16 = Vgg16_RGB() 400 | self.vgg16 = vgg16() 401 | self.feafusion = FeatureFusion(genotype_fusion) 402 | self.conv44 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1) 403 | self.conv55 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1) 404 | self.conv66 = nn.Conv2d(128, 1, kernel_size=3, stride=1, padding=1) 405 | 406 | def forward(self, input): 407 | h_, h_nopool2, h_nopool3, h_nopool4 = self.vgg16(input) 408 | h_nopool2,h_nopool3,h_nopool4 = self.feafusion(h_, h_nopool2, h_nopool3, h_nopool4) 409 | h_nopool2 = F.interpolate(self.conv44(h_nopool2), size=[256, 256], mode='bilinear') 410 | h_nopool3 = F.interpolate(self.conv55(h_nopool3), size=[256, 256], mode='bilinear') 411 | h_nopool4 = F.interpolate(self.conv66(h_nopool4), size=[256, 256], mode='bilinear') 412 | return h_nopool2,h_nopool3,h_nopool4 413 | -------------------------------------------------------------------------------- /operations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | OPS = { 6 | 'none': lambda in_C, out_C, stride, upsample, affine: Zero(stride, upsample=upsample), 7 | 'skip_connect': lambda in_C, out_C, stride, upsample, affine: Identity(upsample=upsample), 8 | 'sep_conv_3x3': lambda in_C, out_C, stride, upsample, affine: SepConv(in_C, out_C, 3, stride, 1, affine=affine, 9 | upsample=upsample), 10 | 'sep_conv_3x3_rp2': lambda in_C, out_C, stride, upsample, affine: SepConvDouble(in_C, out_C, 3, stride, 1, 11 | affine=affine, upsample=upsample), 12 | 'dil_conv_3x3': lambda in_C, out_C, stride, upsample, affine: DilConv(in_C, out_C, 3, stride, 2, 2, affine=affine, 13 | upsample=upsample), 14 | 'dil_conv_3x3_rp2': lambda in_C, out_C, stride, upsample, affine: DilConvDouble(in_C, out_C, 3, stride, 2, 2, 15 | affine=affine, upsample=upsample), 16 | 'dil_conv_3x3_dil4': lambda in_C, out_C, stride, upsample, affine: DilConv(in_C, out_C, 3, stride, 4, 4, 17 | affine=affine, upsample=upsample), 18 | 19 | 'conv_3x3': lambda in_C, out_C, stride, upsample, affine: Conv(in_C, out_C, 3, stride, 1, affine=affine, 20 | upsample=upsample), 21 | 'conv_3x3_rp2': lambda in_C, out_C, stride, upsample, affine: 
ConvDouble(in_C, out_C, 3, stride, 1, affine=affine, 22 | upsample=upsample), 23 | 24 | 'SpatialAttention': lambda in_C, out_C, stride, upsample, affine: SpatialAttention(in_C,7), 25 | 'ChannelAttention': lambda in_C, out_C, stride, upsample, affine: ChannelAttention(in_C,16), 26 | 27 | } 28 | 29 | 30 | def conv3x3(in_planes, out_planes, stride): 31 | "3x3 convolution with padding" 32 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 33 | padding=1, bias=False) 34 | 35 | 36 | class ChannelAttention(nn.Module): 37 | def __init__(self, in_channels, ratio): 38 | super(ChannelAttention, self).__init__() 39 | 40 | self.in_channels = in_channels 41 | 42 | self.linear_1 = nn.Linear(self.in_channels, self.in_channels // ratio) 43 | self.linear_2 = nn.Linear(self.in_channels // ratio, self.in_channels) 44 | self.conv1 = nn.Conv2d(self.in_channels, self.in_channels, kernel_size=1, bias=False) 45 | self.bn1 = nn.BatchNorm2d(self.in_channels) 46 | def forward(self, input_): 47 | n_b, n_c, h, w = input_.size() 48 | 49 | feats = F.adaptive_avg_pool2d(input_, (1, 1)).view((n_b, n_c)) 50 | feats = F.relu(self.linear_1(feats)) 51 | feats = torch.sigmoid(self.linear_2(feats)) 52 | 53 | feats = feats.view((n_b, n_c, 1, 1)) 54 | feats = feats.expand_as(input_).clone() 55 | out = torch.mul(input_, feats) 56 | out = F.relu(self.bn1(self.conv1(out)), inplace=True) 57 | 58 | return out 59 | 60 | 61 | class SpatialAttention(nn.Module): 62 | def __init__(self,in_C, kernel_size=7): 63 | super(SpatialAttention, self).__init__() 64 | 65 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 66 | padding = 3 if kernel_size == 7 else 1 67 | self.in_channels = in_C 68 | self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False) 69 | self.sigmoid = nn.Sigmoid() 70 | self.conv11 = nn.Conv2d(self.in_channels, self.in_channels, kernel_size=1, bias=False) 71 | self.bn11 = nn.BatchNorm2d(self.in_channels) 72 | def forward(self, x): 73 | input = x 74 | avg_out = torch.mean(x, dim=1, keepdim=True) 75 | max_out, _ = torch.max(x, dim=1, keepdim=True) 76 | x = torch.cat([avg_out, max_out], dim=1) 77 | x = self.conv1(x) 78 | x = self.sigmoid(x) 79 | out = input * x 80 | 81 | out = F.relu(self.bn11(self.conv11(out)), inplace=True) 82 | 83 | return out 84 | 85 | 86 | class Conv(nn.Module): 87 | 88 | def __init__(self, C_in, C_out, kernel_size, stride, padding, upsample, affine=True): 89 | super(Conv, self).__init__() 90 | self.upsample = upsample 91 | self.up = nn.Sequential( 92 | torch.nn.ReLU(inplace=False), 93 | torch.nn.Upsample(scale_factor=2, mode='bilinear') 94 | ) 95 | self.op = nn.Sequential( 96 | nn.ReLU(inplace=False), 97 | nn.Conv2d(C_in, C_out, kernel_size=kernel_size, stride=stride, padding=padding, bias=False), 98 | nn.BatchNorm2d(C_out, affine=affine), 99 | ) 100 | 101 | def forward(self, x): 102 | if self.upsample is True: 103 | x = self.up(x) 104 | return self.op(x) 105 | 106 | 107 | class ConvDouble(nn.Module): 108 | 109 | def __init__(self, C_in, C_out, kernel_size, stride, padding, upsample, affine=True): 110 | super(ConvDouble, self).__init__() 111 | 112 | self.upsample = upsample 113 | self.up = nn.Sequential( 114 | torch.nn.ReLU(inplace=False), 115 | torch.nn.Upsample(scale_factor=2, mode='bilinear') 116 | ) 117 | 118 | self.op = nn.Sequential( 119 | nn.ReLU(inplace=False), 120 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, bias=False), 121 | nn.BatchNorm2d(C_in, affine=affine), 122 | nn.ReLU(inplace=False), 123 | nn.Conv2d(C_in, C_out, 
kernel_size=kernel_size, stride=1, padding=padding, bias=False), 124 | nn.BatchNorm2d(C_out, affine=affine), 125 | ) 126 | 127 | def forward(self, x): 128 | if self.upsample is True: 129 | x = self.up(x) 130 | return self.op(x) 131 | 132 | 133 | class ReLUConvBN(nn.Module): 134 | 135 | def __init__(self, C_in, C_out, kernel_size, stride, padding, affine=True): 136 | super(ReLUConvBN, self).__init__() 137 | self.op = nn.Sequential( 138 | nn.ReLU(inplace=False), 139 | nn.Conv2d(C_in, C_out, kernel_size, stride=stride, padding=padding, bias=False), 140 | nn.BatchNorm2d(C_out, affine=affine) 141 | ) 142 | 143 | def forward(self, x): 144 | return self.op(x) 145 | 146 | 147 | class DilConv(nn.Module): 148 | 149 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, upsample, affine=True): 150 | super(DilConv, self).__init__() 151 | 152 | self.upsample = upsample 153 | self.up = nn.Sequential( 154 | torch.nn.ReLU(inplace=False), 155 | torch.nn.Upsample(scale_factor=2, mode='bilinear') 156 | ) 157 | 158 | self.op = nn.Sequential( 159 | nn.ReLU(inplace=False), 160 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, 161 | bias=False), 162 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 163 | nn.BatchNorm2d(C_out, affine=affine), 164 | ) 165 | 166 | def forward(self, x): 167 | if self.upsample is True: 168 | x = self.up(x) 169 | return self.op(x) 170 | 171 | 172 | class DilConvDouble(nn.Module): 173 | 174 | def __init__(self, C_in, C_out, kernel_size, stride, padding, dilation, upsample, affine=True): 175 | super(DilConvDouble, self).__init__() 176 | self.upsample = upsample 177 | self.up = nn.Sequential( 178 | torch.nn.ReLU(inplace=False), 179 | torch.nn.Upsample(scale_factor=2, mode='bilinear') 180 | ) 181 | self.op = nn.Sequential( 182 | nn.ReLU(inplace=False), 183 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, 184 | bias=False), 185 | nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False), 186 | nn.BatchNorm2d(C_in, affine=affine), 187 | nn.ReLU(inplace=False), 188 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, dilation=dilation, 189 | bias=False), 190 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 191 | nn.BatchNorm2d(C_out, affine=affine), 192 | ) 193 | 194 | def forward(self, x): 195 | if self.upsample is True: 196 | x = self.up(x) 197 | return self.op(x) 198 | 199 | 200 | class SepConv(nn.Module): 201 | 202 | def __init__(self, C_in, C_out, kernel_size, stride, padding, upsample, affine=True): 203 | super(SepConv, self).__init__() 204 | 205 | self.upsample = upsample 206 | self.up = nn.Sequential( 207 | torch.nn.ReLU(inplace=False), 208 | torch.nn.Upsample(scale_factor=2, mode='bilinear') 209 | ) 210 | 211 | self.op = nn.Sequential( 212 | nn.ReLU(inplace=False), 213 | nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, bias=False), 214 | nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False), 215 | nn.BatchNorm2d(C_out, affine=affine), 216 | ) 217 | 218 | def forward(self, x): 219 | if self.upsample is True: 220 | x = self.up(x) 221 | return self.op(x) 222 | 223 | 224 | class FactorizedReduce(nn.Module): 225 | 226 | def __init__(self, C_in, C_out, affine=True): 227 | super(FactorizedReduce, self).__init__() 228 | assert C_out % 2 == 0 229 | 230 | self.up = nn.Sequential( 231 | nn.ReLU(), 232 | nn.Upsample(scale_factor=2, mode='bilinear') 233 | ) 234 | 235 | self.relu = 
class SepConvDouble(nn.Module):
    """Two stacked conv + 1x1 pointwise blocks; only the first carries the stride."""

    def __init__(self, C_in, C_out, kernel_size, stride, padding, upsample, affine=True):
        super(SepConvDouble, self).__init__()
        self.upsample = upsample
        self.up = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
        )
        self.op = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
            nn.Conv2d(C_in, C_in, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_in, affine=affine),
            nn.ReLU(inplace=False),
            nn.Conv2d(C_in, C_in, kernel_size=kernel_size, stride=1, padding=padding, bias=False),
            nn.Conv2d(C_in, C_out, kernel_size=1, padding=0, bias=False),
            nn.BatchNorm2d(C_out, affine=affine),
        )

    def forward(self, x):
        if self.upsample:
            x = self.up(x)
        return self.op(x)


class Identity(nn.Module):
    """Pass-through op; optionally upsamples 2x first."""

    def __init__(self, upsample):
        super(Identity, self).__init__()
        self.upsample = upsample
        self.up = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
        )

    def forward(self, x):
        if self.upsample:
            x = self.up(x)
        return x


class Zero(nn.Module):
    """The 'none' op. As written, the input is zeroed only on the non-upsample
    path; on the upsample path the 2x-upsampled features pass through."""

    def __init__(self, stride, upsample):
        super(Zero, self).__init__()
        self.stride = stride
        self.upsample = upsample
        self.up = nn.Sequential(
            nn.ReLU(inplace=False),
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False),
        )

    def forward(self, x):
        if self.upsample:
            x = self.up(x)
        else:
            x = x.mul(0.)
        return x
        # return x[:, :, ::self.stride, ::self.stride].mul(0.)
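Every entry in the OPS registry above shares the factory signature `(in_C, out_C, stride, upsample, affine)`, which is what lets a searched genotype instantiate operations by name; the two attention entries use only `in_C`. A minimal lookup sketch, assuming `operations.py` imports cleanly:

```python
# Instantiating a search-space op by name through the OPS registry
# (illustrative sketch, not part of operations.py).
import torch
from operations import OPS

op = OPS['ChannelAttention'](64, 64, 1, False, True)  # (in_C, out_C, stride, upsample, affine)
op.eval()
with torch.no_grad():
    out = op(torch.randn(1, 64, 16, 16))
assert out.shape == (1, 64, 16, 16)
```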
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
import os
import sys
import glob
import logging
import argparse
import time

import torch
import torch.utils
import torch.nn.functional as F
import cv2
from skimage import img_as_ubyte

import data
import genotypes
import utils
from metric import *
from model_resnet50 import Network_Resnet50
from model_vgg16 import Network_vgg16

os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

parser = argparse.ArgumentParser("test_model")
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument('--test_size', type=int, default=256, help='test image size')
parser.add_argument('--gpu', type=int, default=0, help='gpu device id')
parser.add_argument('--init_channels', type=int, default=128, help='num of init channels')
parser.add_argument('--model_path', type=str, default='./checkpoint/Auto_MSFNet_resnet50.pt',
                    help='path of pretrained checkpoint')
parser.add_argument('--backbone', type=str, default='resnet50', help='backbone network: resnet50 or vgg16')
parser.add_argument('--fu_arch', type=str, default='fusion_genotype_resnet50', help='which architecture to use')
parser.add_argument('--note', type=str, default='fusion_genotype_resnet50', help='note used to name the save dir')

args = parser.parse_args()
args.save = '{}-{}'.format(args.note, time.strftime("%Y%m%d-%H%M%S"))
utils.create_exp_dir(args.save, scripts_to_save=glob.glob('*.py'))

log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save, 'log.txt'))
fh.setFormatter(logging.Formatter(log_format))
logging.getLogger().addHandler(fh)

dataset = ['HKU-IS-WI1D', 'DUTS', 'DUT-OMRON', 'ECSSD', 'PASCAL-S']


def main():
    if not torch.cuda.is_available():
        logging.info('no gpu device available')
        sys.exit(1)

    logging.info('gpu device = %d' % args.gpu)
    logging.info("args = %s", args)
    torch.cuda.set_device(args.gpu)
    genotype_fu = eval("genotypes.%s" % args.fu_arch)
    if args.backbone == "vgg16":
        model = Network_vgg16(genotype_fu)
    elif args.backbone == "resnet50":
        model = Network_Resnet50(genotype_fu)
    else:
        raise ValueError('unknown backbone: %s' % args.backbone)

    model = model.cuda()
    utils.load(model, args.model_path)

    logging.info("param size = %fMB", utils.count_parameters_in_MB(model))

    for dataset_name in dataset:
        test_image_root = '/home/oip/testData/' + dataset_name + '/test_images/'
        test_gt_root = '/home/oip/testData/' + dataset_name + '/test_masks/'

        test_data = data.MyTestData(test_image_root, test_gt_root, args.test_size)
        test_queue = torch.utils.data.DataLoader(
            test_data,
            batch_size=args.batch_size, shuffle=False, num_workers=0, pin_memory=True)
        num_test = len(test_data)
        Fmax_measure, Fm_measure, mae, S_measure = infer(test_queue, model, dataset_name, num_test)
        logging.info('dataset_name {}'.format(dataset_name))
        logging.info('Fmax-measure %f', Fmax_measure)
        logging.info('Fm-measure %f', Fm_measure)
        logging.info('mae %f', mae)
        logging.info('S-measure %f', S_measure)
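One robustness note on `main()`: `eval("genotypes.%s" % args.fu_arch)` executes whatever string arrives on the command line. `getattr` is a drop-in equivalent for valid architecture names and fails with a plain `AttributeError` otherwise; a sketch:

```python
# Safer drop-in for eval("genotypes.%s" % args.fu_arch) (illustrative).
import genotypes

# args.fu_arch comes from the argparse setup above.
genotype_fu = getattr(genotypes, args.fu_arch)
```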
def infer(test_queue, model, dataset_name, num_test):
    model.eval()
    savepath = './prediction/' + dataset_name
    # Metric accumulators, one per measure.
    cal_fm = CalFM(num=num_test)
    cal_mae = CalMAE(num=num_test)
    cal_sm = CalSM(num=num_test)
    for step, (image, target, name) in enumerate(test_queue):
        image = image.cuda()
        target = torch.squeeze(target)  # assumes batch_size == 1
        with torch.no_grad():
            h_nopool2, _, _ = model(image)
        test_output_root = os.path.join(args.save, savepath)
        if not os.path.exists(test_output_root):
            os.makedirs(test_output_root)
        H, W = target.shape

        # Resize the prediction to the ground-truth resolution and save it.
        h_nopool2 = F.interpolate(h_nopool2, (H, W), mode='bilinear', align_corners=False)
        output_rgb = torch.squeeze(h_nopool2)
        predict_rgb = output_rgb.sigmoid().detach().cpu().numpy()
        predict_rgb = img_as_ubyte(predict_rgb)
        cv2.imwrite(test_output_root + '/' + name[0] + '.png', predict_rgb)
        target = target.cpu().detach().numpy()

        # Min-max normalize prediction and ground truth to [0, 1] before
        # computing the metrics; constant maps are scaled by 255 instead,
        # to avoid a zero denominator.
        max_pred_array = predict_rgb.max()
        min_pred_array = predict_rgb.min()
        if max_pred_array == min_pred_array:
            predict_rgb = predict_rgb / 255
        else:
            predict_rgb = (predict_rgb - min_pred_array) / (max_pred_array - min_pred_array)

        max_target = target.max()
        min_target = target.min()
        if max_target == min_target:
            target = target / 255
        else:
            target = (target - min_target) / (max_target - min_target)

        cal_fm.update(predict_rgb, target)
        cal_mae.update(predict_rgb, target)
        cal_sm.update(predict_rgb, target)

        if step % 50 == 0 or step == len(test_queue) - 1:
            logging.info(
                "TestDataSet:{} Step {:03d}/{:03d} ".format(
                    dataset_name, step, len(test_queue) - 1))
    _, maxf, mmf, _, _ = cal_fm.show()
    mae = cal_mae.show()
    sm = cal_sm.show()
    return maxf, mmf, mae, sm


if __name__ == '__main__':
    main()
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
import os
import shutil

import numpy as np
import torch
from torch.autograd import Variable


class AverageMeter(object):
    """Tracks a running average of a scalar (sum / count)."""

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt


def count_parameters_in_MB(model):
    """Parameter count in millions, skipping auxiliary heads."""
    return sum(np.prod(v.size()) for name, v in model.named_parameters()
               if "auxiliary" not in name) / 1e6


def save_checkpoint(state, is_best, save):
    filename = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, filename)
    if is_best:
        best_filename = os.path.join(save, 'model_best.pth.tar')
        shutil.copyfile(filename, best_filename)


def save(model, model_path):
    torch.save(model.state_dict(), model_path)


def load(model, model_path):
    # Loads onto the devices the checkpoint was saved from; pass
    # map_location to torch.load when restoring on a CPU-only machine.
    state_dict = torch.load(model_path)
    model.load_state_dict(state_dict)
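`infer()` min-max rescales both maps to [0, 1] before accumulating the F-measure, MAE, and S-measure, with a divide-by-255 fallback for constant maps. A worked illustration of both branches:

```python
# Worked illustration of the rescaling used in infer() (not part of test.py).
import numpy as np

pred = np.array([10., 130., 250.])
norm = (pred - pred.min()) / (pred.max() - pred.min())
print(norm)               # [0.  0.5 1. ]

flat = np.full(3, 255.0)  # degenerate all-foreground map
print(flat / 255)         # [1. 1. 1.] -- the zero-denominator fallback branch
```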
def drop_path(x, drop_prob):
    """Per-sample stochastic depth: zero a whole sample's activations with
    probability drop_prob and rescale the survivors by 1 / keep_prob."""
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # Legacy CUDA-only mask; see the device-agnostic sketch below.
        mask = Variable(torch.cuda.FloatTensor(x.size(0), 1, 1, 1).bernoulli_(keep_prob))
        x.div_(keep_prob)
        x.mul_(mask)
    return x


def create_exp_dir(path, scripts_to_save=None):
    if not os.path.exists(path):
        os.mkdir(path)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        os.mkdir(os.path.join(path, 'scripts'))
        for script in scripts_to_save:
            dst_file = os.path.join(path, 'scripts', os.path.basename(script))
            shutil.copyfile(script, dst_file)


if __name__ == '__main__':
    # Leftover local debug: recreates a previous experiment directory by hand.
    a = 'fusion_genotype_vgg16-20210707-204342'
    os.mkdir(a)
--------------------------------------------------------------------------------
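`drop_path` in utils.py hard-codes `torch.cuda.FloatTensor`, so it fails on CPU-only machines, and `Variable` has been a no-op wrapper since PyTorch 0.4. A behavior-equivalent, device-agnostic sketch (illustrative, not part of the repository):

```python
# Device-agnostic drop_path with the same semantics as utils.drop_path:
# zero whole samples with probability drop_prob, rescale survivors.
import torch

def drop_path(x: torch.Tensor, drop_prob: float) -> torch.Tensor:
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # Bernoulli mask on the same device/dtype as the input.
        mask = torch.empty(x.size(0), 1, 1, 1, dtype=x.dtype,
                           device=x.device).bernoulli_(keep_prob)
        x = x / keep_prob * mask  # out-of-place: avoids mutating the input
    return x
```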