├── .gitignore
├── README.md
├── compute_score.py
├── config.py
├── libs
│   ├── __init__.py
│   ├── dataset.py
│   ├── saliency_metric.py
│   └── util.py
├── main.py
├── nets
│   ├── __init__.py
│   ├── drn.py
│   └── sodnet.py
└── paper
    └── UMNet-cvpr2022.pdf

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
.idea*
__pycache__
*.txt
events.*
*.swp
*.tar

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# UMNet
The PyTorch implementation of the CVPR 2022 paper [Multi-Source Uncertainty Mining for Deep Unsupervised Saliency Detection](https://openaccess.thecvf.com/content/CVPR2022/papers/Wang_Multi-Source_Uncertainty_Mining_for_Deep_Unsupervised_Saliency_Detection_CVPR_2022_paper.pdf).

# Trained Model, Test Data and Results

Please download the trained model, test data and SOD results from [Baidu Cloud](https://pan.baidu.com/s/10YDn5tiexLx4iUjE8zP4wg?pwd=tmzw) (password: tmzw).

# Requirements
• Python 3.7

• PyTorch 1.6.1

• torchvision

• numpy

• Pillow

• Cython

# Run
1. Download the [trained model](https://pan.baidu.com/s/10YDn5tiexLx4iUjE8zP4wg?pwd=tmzw) and [test datasets](https://pan.baidu.com/s/10YDn5tiexLx4iUjE8zP4wg?pwd=tmzw) (including DUTS-TE, OMRON, ECSSD, and HKU-IS). Uncompress them and put them in the repository root (a layout check is sketched after this list).
2. Set the paths of the test sets and the trained model in [config.py](https://github.com/yifanw90/UMNet/blob/main/config.py); the default settings already point to the layout above.
3. Run [main.py](https://github.com/yifanw90/UMNet/blob/main/main.py) to obtain the predicted saliency maps. The results are saved under `save_path` (see [config.py](https://github.com/yifanw90/UMNet/blob/main/config.py)). You can also download our saliency results from [Baidu Cloud](https://pan.baidu.com/s/10YDn5tiexLx4iUjE8zP4wg?pwd=tmzw).
4. Run [compute_score.py](https://github.com/yifanw90/UMNet/blob/main/compute_score.py) to obtain the evaluation scores of the predictions in terms of MAE, Fmax, Sm, and Em. The evaluation code is adapted from https://github.com/Xiaoqi-Zhao-DLUT/GateNet-RGB-Saliency.
5. Make sure the ground-truth and prediction paths in [compute_score.py](https://github.com/yifanw90/UMNet/blob/main/compute_score.py) are valid.
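Before step 3, you can confirm that the downloaded files sit where the default [config.py](https://github.com/yifanw90/UMNet/blob/main/config.py) expects them. A minimal sketch (`check_layout.py` is a hypothetical helper, not part of this repository):

```python
# check_layout.py -- hypothetical helper; assumes the default paths in config.py.
import os

from config import config

# The trained model downloaded from Baidu Cloud.
assert os.path.isfile(config.model_path), 'missing model: %s' % config.model_path

# One sub-folder per test set under Test_Data (each with a gt/ folder for evaluation).
for name in config.datasets:
    folder = os.path.join(config.data_path, name)
    assert os.path.isdir(folder), 'missing test set: %s' % folder

os.makedirs(config.save_path, exist_ok=True)
print('Layout OK; now run: python main.py')
```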
# Train
Note: Our method is trained largely following the training setup of [DeepUSPS](https://github.com/sally20921/DeepUSPS). We use the 2,500 MSRA-B training images for network training.
1. Four traditional SOD methods (MC, HS, DSR, and RBD) are used to generate pseudo labels for the training data, which are then refined with the first stage of DeepUSPS.
2. The four kinds of refined pseudo labels are used for multi-source network learning with our [training code](https://pan.baidu.com/s/18Xsq7MJ_hCNNCLM5Eyxt0g?pwd=a4hh) (extract code: a4hh).

--------------------------------------------------------------------------------
/compute_score.py:
--------------------------------------------------------------------------------
import os

import numpy as np

from libs.util import eval_dataset
from libs.saliency_metric import cal_mae, cal_fm, cal_sm, cal_em, cal_wfm

data_path_gt = 'Test_Data'       # ground-truth root
data_path_pre = 'result_UMNet'   # predicted saliency map root

datasets = ['DUTS-TE', 'OMRON', 'ECSSD', 'HKU-IS']  # test dataset names
# datasets = ['DUTS-TE']

for dataset in datasets:
    gt_path = os.path.join(data_path_gt, dataset, 'gt')
    pre_path = os.path.join(data_path_pre, dataset)
    test_loader = eval_dataset(pre_path, gt_path)
    mae, fm, sm, em = cal_mae(), cal_fm(test_loader.size), cal_sm(), cal_em()

    for i in range(test_loader.size):
        print('Computing scores for %d / %d' % (i + 1, test_loader.size))
        sal, gt = test_loader.load_data()
        # assert sal.size == gt.size
        if sal.size != gt.size:
            x, y = gt.size
            sal = sal.resize((x, y))
        gt = np.asarray(gt, np.float32)
        gt /= (gt.max() + 1e-8)  # scale gt from [0, 255] to [0, 1]
        gt[gt > 0.5] = 1
        gt[gt != 1] = 0          # binarize gt with a threshold of 0.5
        res = np.array(sal)
        if res.max() == res.min():
            res = res / 255.0
        else:
            res = (res - res.min()) / (res.max() - res.min())  # normalize res to [0, 1]
        mae.update(res, gt)
        sm.update(res, gt)
        fm.update(res, gt)
        em.update(res, gt)

    MAE = mae.show()
    maxf, _, _, _ = fm.show()
    sm = sm.show()
    em = em.show()
    print('dataset: {} MAE: {:.4f} maxF: {:.4f} Sm: {:.4f} Em: {:.4f}'.format(dataset, MAE, maxf, sm, em))

--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
import os

from easydict import EasyDict as edict

os.environ['CUDA_VISIBLE_DEVICES'] = '4'

config = edict()

### Input
config.input_size = [320, 320]
# Mean/std statistics for images loaded with PIL Image.open (RGB, scaled to [0, 1]).
config.input_mean = [0.49227863, 0.46342391, 0.39742668]
config.input_std = [0.22730138, 0.22451538, 0.22985159]

# config.snapshot = 'models/UMNet_trained.pth'
config.model_path = 'trained_model/UMNet.pth'

config.data_path = 'Test_Data'

config.datasets = ['DUTS-TE', 'OMRON', 'ECSSD', 'HKU-IS']  # dataset names
# config.datasets = ['DUTS-TE']

config.save_path = 'result_UMNet'  # output directory for predicted saliency maps

--------------------------------------------------------------------------------
/libs/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/UMNet/0d2608ab97665cff236676026c8c8ef41874b5b6/libs/__init__.py
--------------------------------------------------------------------------------
/libs/dataset.py:
--------------------------------------------------------------------------------
from __future__ import division

import torch
import torch.nn.functional as F
import torchvision.transforms as T


class Transform(object):
    """Preprocessing pipeline: PIL image -> normalized, resized float tensor."""
    def __init__(self, config):
        transforms = []
        transforms.append(ToTensor())
        transforms.append(Normalize(config.input_mean, config.input_std))
        transforms.append(Resize(config.input_size))
        self.transforms = T.Compose(transforms)

    def __call__(self, batch):
        batch = self.transforms(batch)
        return batch


class ToTensor(object):
    """Converts a PIL.Image or numpy.ndarray (H x W x C) in the range
    [0, 255] to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0].
    Grayscale inputs are replicated to three channels.
    """
    def __init__(self):
        self.to_tensor = T.ToTensor()
        self.dtype = torch.float32

    def __call__(self, image):
        image = self.to_tensor(image).type(self.dtype)  # C, H, W

        if image.shape[0] == 1:
            image = torch.cat((image, image, image), dim=0)

        return image


class Normalize(object):
    def __init__(self, mean, std):
        self.normalize = T.Normalize(mean, std, inplace=True)

    def __call__(self, image):
        return self.normalize(image)


class Resize(object):
    def __init__(self, input_size):
        self.input_size = input_size

    def __call__(self, image):
        # Add and remove a batch dimension so F.interpolate works on a single image.
        return F.interpolate(image[None], size=self.input_size, mode='bilinear', align_corners=False)[0]
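A minimal usage sketch for `Transform` above (the image path is made up for illustration; `config` comes from config.py):

```python
# Hypothetical usage of the Transform pipeline; the image path is illustrative only.
from PIL import Image

from config import config
from libs.dataset import Transform

transform = Transform(config)
image = Image.open('Test_Data/ECSSD/0001.jpg').convert('RGB')
tensor = transform(image)    # shape (3, 320, 320), normalized
batch = tensor.unsqueeze(0)  # add the batch dimension before feeding the network
print(batch.shape)           # torch.Size([1, 3, 320, 320])
```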
--------------------------------------------------------------------------------
/libs/saliency_metric.py:
--------------------------------------------------------------------------------
import numpy as np
from scipy import ndimage
from scipy.ndimage import convolve, distance_transform_edt as bwdist


class cal_fm(object):
    # F-measure (maxF, meanF) -- Frequency-tuned salient region detection (CVPR 2009)
    def __init__(self, num, thds=255):
        self.num = num
        self.thds = thds
        self.precision = np.zeros((self.num, self.thds))
        self.recall = np.zeros((self.num, self.thds))
        self.meanF = np.zeros((self.num, 1))
        self.idx = 0

    def update(self, pred, gt):
        if gt.max() != 0:
            # Precision/recall at 255 thresholds, plus the adaptive-threshold F-measure for this image.
            prediction, recall, Fmeasure_temp = self.cal(pred, gt)
            self.precision[self.idx, :] = prediction
            self.recall[self.idx, :] = recall
            self.meanF[self.idx, :] = Fmeasure_temp
        self.idx += 1

    def cal(self, pred, gt):
        ######################## meanF ##############################
        # Adaptive threshold: twice the mean saliency value, capped at 1.
        th = 2 * pred.mean()
        if th > 1:
            th = 1
        binary = np.zeros_like(pred)
        binary[pred >= th] = 1
        hard_gt = np.zeros_like(gt)
        hard_gt[gt > 0.5] = 1
        tp = (binary * hard_gt).sum()
        if tp == 0:
            meanF = 0
        else:
            pre = tp / binary.sum()
            rec = tp / hard_gt.sum()
            meanF = 1.3 * pre * rec / (0.3 * pre + rec)
        ######################## maxF ##############################
        pred = np.uint8(pred * 255)
        target = pred[gt > 0.5]
        nontarget = pred[gt <= 0.5]
        targetHist, _ = np.histogram(target, bins=range(256))
        nontargetHist, _ = np.histogram(nontarget, bins=range(256))
        targetHist = np.cumsum(np.flip(targetHist), axis=0)
        nontargetHist = np.cumsum(np.flip(nontargetHist), axis=0)
        precision = targetHist / (targetHist + nontargetHist + 1e-8)
        recall = targetHist / np.sum(gt)
        return precision, recall, meanF

    def show(self):
        assert self.num == self.idx
        precision = self.precision.mean(axis=0)
        recall = self.recall.mean(axis=0)
        fmeasure = 1.3 * precision * recall / (0.3 * precision + recall + 1e-8)
        fmeasure_avg = self.meanF.mean(axis=0)
        return fmeasure.max(), fmeasure_avg[0], precision, recall


class cal_mae(object):
    # Mean absolute error.
    def __init__(self):
        self.prediction = []

    def update(self, pred, gt):
        score = self.cal(pred, gt)
        self.prediction.append(score)

    def cal(self, pred, gt):
        return np.mean(np.abs(pred - gt))

    def show(self):
        return np.mean(self.prediction)


class cal_sm(object):
    # Structure-measure: A new way to evaluate foreground maps (ICCV 2017)
    def __init__(self, alpha=0.5):
        self.prediction = []
        self.alpha = alpha

    def update(self, pred, gt):
        gt = gt > 0.5
        score = self.cal(pred, gt)
        self.prediction.append(score)

    def show(self):
        return np.mean(self.prediction)

    def cal(self, pred, gt):
        y = np.mean(gt)
        if y == 0:
            score = 1 - np.mean(pred)
        elif y == 1:
            score = np.mean(pred)
        else:
            score = self.alpha * self.object(pred, gt) + (1 - self.alpha) * self.region(pred, gt)
        return score

    def object(self, pred, gt):
        fg = pred * gt
        bg = (1 - pred) * (1 - gt)

        u = np.mean(gt)
        return u * self.s_object(fg, gt) + (1 - u) * self.s_object(bg, np.logical_not(gt))

    def s_object(self, in1, in2):
        x = np.mean(in1[in2])
        sigma_x = np.std(in1[in2])
        return 2 * x / (pow(x, 2) + 1 + sigma_x + 1e-8)

    def region(self, pred, gt):
        [y, x] = ndimage.center_of_mass(gt)
        y = int(round(y)) + 1
        x = int(round(x)) + 1
        [gt1, gt2, gt3, gt4, w1, w2, w3, w4] = self.divideGT(gt, x, y)
        pred1, pred2, pred3, pred4 = self.dividePred(pred, x, y)

        score1 = self.ssim(pred1, gt1)
        score2 = self.ssim(pred2, gt2)
        score3 = self.ssim(pred3, gt3)
        score4 = self.ssim(pred4, gt4)

        return w1 * score1 + w2 * score2 + w3 * score3 + w4 * score4

    def divideGT(self, gt, x, y):
        h, w = gt.shape
        area = h * w
        LT = gt[0:y, 0:x]
        RT = gt[0:y, x:w]
        LB = gt[y:h, 0:x]
        RB = gt[y:h, x:w]

        w1 = x * y / area
        w2 = y * (w - x) / area
        w3 = (h - y) * x / area
        w4 = (h - y) * (w - x) / area

        return LT, RT, LB, RB, w1, w2, w3, w4

    def dividePred(self, pred, x, y):
        h, w = pred.shape
        LT = pred[0:y, 0:x]
        RT = pred[0:y, x:w]
        LB = pred[y:h, 0:x]
        RB = pred[y:h, x:w]

        return LT, RT, LB, RB

    def ssim(self, in1, in2):
        in2 = np.float32(in2)
        h, w = in1.shape
        N = h * w

        x = np.mean(in1)
        y = np.mean(in2)
        sigma_x = np.var(in1)
        sigma_y = np.var(in2)
        sigma_xy = np.sum((in1 - x) * (in2 - y)) / (N - 1)

        alpha = 4 * x * y * sigma_xy
        beta = (x * x + y * y) * (sigma_x + sigma_y)

        if alpha != 0:
            score = alpha / (beta + 1e-8)
        elif alpha == 0 and beta == 0:
            score = 1
        else:
            score = 0

        return score


class cal_em(object):
    # Enhanced-alignment Measure for Binary Foreground Map Evaluation (IJCAI 2018)
    def __init__(self):
        self.prediction = []

    def update(self, pred, gt):
        score = self.cal(pred, gt)
        self.prediction.append(score)

    def cal(self, pred, gt):
        th = 2 * pred.mean()
        if th > 1:
            th = 1
        FM = np.zeros(gt.shape)
        FM[pred >= th] = 1
        FM = np.array(FM, dtype=bool)
        GT = np.array(gt, dtype=bool)
        dFM = np.double(FM)
        if (sum(sum(np.double(GT))) == 0):
            enhanced_matrix = 1.0 - dFM
        elif (sum(sum(np.double(~GT))) == 0):
            enhanced_matrix = dFM
        else:
            dGT = np.double(GT)
            align_matrix = self.AlignmentTerm(dFM, dGT)
            enhanced_matrix = self.EnhancedAlignmentTerm(align_matrix)
        [w, h] = np.shape(GT)
        score = sum(sum(enhanced_matrix)) / (w * h - 1 + 1e-8)
        return score

    def AlignmentTerm(self, dFM, dGT):
        mu_FM = np.mean(dFM)
        mu_GT = np.mean(dGT)
        align_FM = dFM - mu_FM
        align_GT = dGT - mu_GT
        align_Matrix = 2. * (align_GT * align_FM) / (align_GT * align_GT + align_FM * align_FM + 1e-8)
        return align_Matrix

    def EnhancedAlignmentTerm(self, align_Matrix):
        enhanced = np.power(align_Matrix + 1, 2) / 4
        return enhanced

    def show(self):
        return np.mean(self.prediction)
class cal_wfm(object):
    # Weighted F-measure: How to Evaluate Foreground Maps? (CVPR 2014)
    def __init__(self, beta=1):
        self.beta = beta
        self.eps = 1e-6
        self.scores_list = []

    def update(self, pred, gt):
        assert pred.ndim == gt.ndim and pred.shape == gt.shape
        assert pred.max() <= 1 and pred.min() >= 0
        assert gt.max() <= 1 and gt.min() >= 0

        gt = gt > 0.5
        if gt.max() == 0:
            score = 0
        else:
            score = self.cal(pred, gt)
        self.scores_list.append(score)

    def matlab_style_gauss2D(self, shape=(7, 7), sigma=5):
        """
        2D gaussian mask - should give the same result as MATLAB's
        fspecial('gaussian',[shape],[sigma])
        """
        m, n = [(ss - 1.) / 2. for ss in shape]
        y, x = np.ogrid[-m:m + 1, -n:n + 1]
        h = np.exp(-(x * x + y * y) / (2. * sigma * sigma))
        h[h < np.finfo(h.dtype).eps * h.max()] = 0
        sumh = h.sum()
        if sumh != 0:
            h /= sumh
        return h

    def cal(self, pred, gt):
        # [Dst,IDXT] = bwdist(dGT);
        Dst, Idxt = bwdist(gt == 0, return_indices=True)

        # %Pixel dependency
        # E = abs(FG-dGT);
        E = np.abs(pred - gt)
        # Et = E;
        # Et(~GT)=Et(IDXT(~GT)); %To deal correctly with the edges of the foreground region
        Et = np.copy(E)
        Et[gt == 0] = Et[Idxt[0][gt == 0], Idxt[1][gt == 0]]

        # K = fspecial('gaussian',7,5);
        # EA = imfilter(Et,K);
        # MIN_E_EA(GT & EA<E) = EA(GT & EA<E);
        K = self.matlab_style_gauss2D((7, 7), sigma=5)
        EA = convolve(Et, weights=K, mode='constant', cval=0)
        MIN_E_EA = np.where(gt & (EA < E), EA, E)

        # %Pixel importance
        # B = ones(size(GT)); B(~GT) = 2-1*exp(log(1-0.5)/5.*Dst(~GT));
        B = np.where(gt == 0, 2 - np.exp(np.log(0.5) / 5 * Dst), np.ones_like(gt))
        Ew = MIN_E_EA * B

        TPw = np.sum(gt) - np.sum(Ew[gt == 1])
        FPw = np.sum(Ew[gt == 0])

        R = 1 - np.mean(Ew[gt])               # weighted recall
        P = TPw / (self.eps + TPw + FPw)      # weighted precision

        Q = (1 + self.beta ** 2) * (R * P) / (self.eps + R + (self.beta * P))
        return Q

    def show(self):
        return np.mean(self.scores_list)
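A quick smoke test for the metric classes above, on synthetic arrays (shapes and inputs are made up; compute_score.py shows the real evaluation loop):

```python
# Hypothetical smoke test for the metric classes; real usage is in compute_score.py.
import numpy as np

from libs.saliency_metric import cal_mae, cal_fm, cal_sm, cal_em

pred = np.random.rand(64, 64).astype(np.float32)         # saliency map in [0, 1]
gt = (np.random.rand(64, 64) > 0.5).astype(np.float32)   # binary ground truth

mae, fm, sm, em = cal_mae(), cal_fm(num=1), cal_sm(), cal_em()
for metric in (mae, fm, sm, em):
    metric.update(pred, gt)

maxf, meanf, _, _ = fm.show()
print('MAE %.4f  maxF %.4f  Sm %.4f  Em %.4f' % (mae.show(), maxf, sm.show(), em.show()))
```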
--------------------------------------------------------------------------------
/nets/drn.py:
--------------------------------------------------------------------------------
        if num_classes > 0:
            self.avgpool = nn.AvgPool2d(pool_size)
            self.fc = nn.Conv2d(self.out_dim, num_classes, kernel_size=1,
                                stride=1, padding=0, bias=True)
        for m in self.modules():
            if isinstance(m, nn.Conv2d):  # initialize the convolution weights
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):  # initialize the BN weights and biases
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1,
                    new_level=True, residual=True):
        assert dilation == 1 or dilation % 2 == 0
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                BatchNorm(planes * block.expansion),
            )

        layers = list()
        layers.append(block(
            self.inplanes, planes, stride, downsample,
            dilation=(1, 1) if dilation == 1 else (
                dilation // 2 if new_level else dilation, dilation),
            residual=residual))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes, residual=residual,
                                dilation=(dilation, dilation)))

        return nn.Sequential(*layers)

    def _make_conv_layers(self, channels, convs, stride=1, dilation=1):
        modules = []
        for i in range(convs):
            modules.extend([
                nn.Conv2d(self.inplanes, channels, kernel_size=3,
                          stride=stride if i == 0 else 1,
                          padding=dilation, bias=False, dilation=dilation),
                BatchNorm(channels),
                nn.ReLU(inplace=True)])
            self.inplanes = channels
        return nn.Sequential(*modules)

    def forward(self, x):
        y = list()
        if self.arch == 'C':
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.relu(x)
        elif self.arch == 'D':
            x = self.layer0(x)

        x = self.layer1(x)
        y.append(x)
        x = self.layer2(x)
        y.append(x)

        x = self.layer3(x)
        y.append(x)

        x = self.layer4(x)
        y.append(x)

        x = self.layer5(x)
        y.append(x)

        if self.layer6 is not None:
            x = self.layer6(x)
            y.append(x)

        if self.layer7 is not None:
            x = self.layer7(x)
            y.append(x)

        if self.layer8 is not None:
            x = self.layer8(x)
            y.append(x)

        # if self.out_map:
        #     x = self.fc(x)
        # else:
        #     x = self.avgpool(x)
        #     x = self.fc(x)
        #     x = x.view(x.size(0), -1)

        if self.out_middle:
            return x, y
        else:
            return x
class DRN_A(nn.Module):

    def __init__(self, block, layers, num_classes=1000):
        self.inplanes = 64
        super(DRN_A, self).__init__()
        self.out_dim = 512 * block.expansion
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=1,
                                       dilation=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=1,
                                       dilation=4)
        self.avgpool = nn.AvgPool2d(28, stride=1)
        self.fc = nn.Linear(512 * block.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
                m.weight.data.normal_(0, math.sqrt(2. / n))
            elif isinstance(m, BatchNorm):
                m.weight.data.fill_(1)
                m.bias.data.zero_()

    def _make_layer(self, block, planes, blocks, stride=1, dilation=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for i in range(1, blocks):
            layers.append(block(self.inplanes, planes,
                                dilation=(dilation, dilation)))

        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        return x


def drn_a_50(pretrained=False, **kwargs):
    model = DRN_A(Bottleneck, [3, 4, 6, 3], **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['resnet50']))
    return model


def drn_c_26(pretrained=False, **kwargs):
    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='C', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-c-26']))
    return model


def drn_c_42(pretrained=False, **kwargs):
    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-c-42']))
    return model


def drn_c_58(pretrained=False, **kwargs):
    model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='C', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-c-58']))
    return model


def drn_d_22(pretrained=False, **kwargs):
    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 1, 1], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-22']))
    return model


def drn_d_24(pretrained=False, **kwargs):
    model = DRN(BasicBlock, [1, 1, 2, 2, 2, 2, 2, 2], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-24']))
    return model


def drn_d_38(pretrained=False, **kwargs):
    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-38']))
    return model


def drn_d_40(pretrained=False, **kwargs):
    model = DRN(BasicBlock, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-40']))
    return model


def drn_d_54(pretrained=False, **kwargs):
    model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 1, 1], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-54']))
    return model


def drn_d_56(pretrained=False, **kwargs):
    model = DRN(Bottleneck, [1, 1, 3, 4, 6, 3, 2, 2], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-56']))
    return model


def drn_d_105(pretrained=False, out_middle=False, **kwargs):
    model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 1, 1], arch='D', out_middle=out_middle, **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-105']))
    return model


def drn_d_107(pretrained=False, **kwargs):
    model = DRN(Bottleneck, [1, 1, 3, 4, 23, 3, 2, 2], arch='D', **kwargs)
    if pretrained:
        model.load_state_dict(model_zoo.load_url(model_urls['drn-d-107']))
    return model
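SODNet in the next file instantiates this backbone with `out_middle=True`, in which case `DRN.forward` returns both the final feature map and the list of per-stage features. A minimal sketch of that contract (input resolution taken from config.py; shape comments assume the stride-8 DRN-D design):

```python
# Minimal sketch of how SODNet (next file) consumes the DRN backbone.
import torch

from nets import drn

backbone = drn.drn_d_105(pretrained=False, out_middle=True, num_classes=1000)
x = torch.randn(1, 3, 320, 320)  # input resolution from config.py
feat, middle = backbone(x)       # out_middle=True -> (final feature map, stage outputs)
print(feat.shape)                # expected (1, 512, 40, 40) for a stride-8 DRN-D
print(len(middle))               # one entry per backbone stage
```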
--------------------------------------------------------------------------------
/nets/sodnet.py:
--------------------------------------------------------------------------------
import torch.nn.functional as F
from torch import nn

from nets import drn

try:
    from nn.modules import batchnormsync
except ImportError:
    pass


class SODNet(nn.Module):
    def __init__(self, backbone_name='drn_d_105', num_flabel=4, pretrained_model=None, pretrained=False):
        super(SODNet, self).__init__()
        # Build the base DRN backbone; out_middle=True also exposes intermediate features.
        self.backbone = drn.__dict__.get(backbone_name)(pretrained=pretrained, out_middle=True, num_classes=1000)
        self.conv2 = nn.Conv2d(512, 1, 3, padding=1, stride=1, bias=True)
        self.sigmoid = nn.Sigmoid()

    def forward(self, im, target_size):
        ### saliency network
        x, _ = self.backbone(im)  # input: im of shape [320, 320]
        x = self.conv2(x)
        x = F.interpolate(x, size=im.shape[2:], mode='bilinear', align_corners=False)  # resize to the input resolution

        pre = self.sigmoid(x)
        pre = F.interpolate(pre, size=target_size, mode='bilinear', align_corners=False)  # resize to the original image resolution
        pre = (pre - pre.min()) / (pre.max() - pre.min())  # min-max normalize to [0, 1]
        return pre

--------------------------------------------------------------------------------
/paper/UMNet-cvpr2022.pdf:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/yifanw90/UMNet/0d2608ab97665cff236676026c8c8ef41874b5b6/paper/UMNet-cvpr2022.pdf
--------------------------------------------------------------------------------
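For completeness, a minimal end-to-end inference sketch assembled from the files above (the checkpoint layout and image path are assumptions; main.py is the authoritative driver):

```python
# Hypothetical end-to-end inference; main.py is the authoritative driver.
import torch
from PIL import Image

from config import config
from libs.dataset import Transform
from nets.sodnet import SODNet

model = SODNet(backbone_name='drn_d_105')
# Assumes the checkpoint stores a plain state_dict; adjust if it is wrapped.
model.load_state_dict(torch.load(config.model_path, map_location='cpu'))
model.eval()

image = Image.open('example.jpg').convert('RGB')  # hypothetical input image
w, h = image.size
batch = Transform(config)(image).unsqueeze(0)

with torch.no_grad():
    sal = model(batch, target_size=(h, w))        # (1, 1, h, w) saliency map in [0, 1]
```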