├── README.md ├── data.py ├── data_aug.m ├── images ├── MCCM.png ├── MCCNet.png └── table.png ├── model ├── MCCNet_models.py ├── __init__.py └── vgg.py ├── pytorch_fm ├── __init__.py └── __init__.pyc ├── pytorch_iou ├── __init__.py ├── __init__.pyc └── __pycache__ │ └── __init__.cpython-36.pyc ├── test_MCCNet.py ├── train_MCCNet.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # MCCNet 2 | This project provides the code and results for 'Multi-Content Complementation Network for Salient Object Detection in Optical Remote Sensing Images', IEEE TGRS, vol. 60, pp. 1-13, 2022. [IEEE link](https://ieeexplore.ieee.org/document/9631225) and [arxiv link](https://arxiv.org/abs/2112.01932) [Homepage](https://mathlee.github.io/) 3 | 4 | 5 | # Network Architecture 6 |
7 | 8 |
9 | 10 | # Multi-Content Complementation Module (MCCM) 11 |
12 | 13 |
14 | 15 | 16 | # Requirements 17 | Python 2.7 + PyTorch 0.4.0, or 18 | 19 | Python 3.7 + PyTorch 1.9.0 20 | 21 | 22 | # Saliency maps 23 | We provide the saliency maps and [measure results (.mat)](https://pan.baidu.com/s/1l4GPBcPYCO9atbgDwbkfUw) (code: i9d0) of [all compared methods](https://pan.baidu.com/s/1TP6An1VWygGUy4uvojL0bg) (code: 5np3) and [our MCCNet](https://pan.baidu.com/s/10JIKL2Q48RvBGeT2pmPfDA) (code: 3pvq) on the ORSSD and EORSSD datasets. 24 | 25 | We also provide [saliency maps of our MCCNet](https://pan.baidu.com/s/1dz-GeELIqMdzKlPvzETixA) (code: 413m) on the recently published [ORSI-4199](https://github.com/wchao1213/ORSI-SOD) dataset. 26 | 27 | ![Image](https://github.com/MathLee/MCCNet/blob/main/images/table.png) 28 | 29 | # Training 30 | 31 | We generate the edge ground truth with [sal2edge.m](https://github.com/JXingZhao/EGNet/blob/master/sal2edge.m) from [EGNet](https://github.com/JXingZhao/EGNet), and use data_aug.m for data augmentation (the dataset layout assumed by the scripts is sketched at the end of this README). 32 | 33 | Modify the paths of the [VGG backbone](https://pan.baidu.com/s/1YQxKZ-y2C4EsqrgKNI7qrw) (code: ego5) and the datasets, then run train_MCCNet.py. 34 | 35 | 36 | # Pre-trained models and testing 37 | Download one of the following pre-trained models, modify the paths of the pre-trained model and the datasets, then run test_MCCNet.py. 38 | 39 | [ORSSD](https://pan.baidu.com/s/1LdUE8F11r61r8wk3Y9wPLA) (code: awqr) 40 | 41 | [EORSSD](https://pan.baidu.com/s/14LrEt1LW5QmZvkhsgbKgfg) (code: wm3p) 42 | 43 | [ORSI-4199](https://pan.baidu.com/s/1hmANQp9cslyPuDE-3NlqAg) (code: 336a) 44 | 45 | 46 | # Evaluation Tool 47 | You can use the [evaluation tool (MATLAB version)](https://github.com/MathLee/MatlabEvaluationTools) to evaluate the above saliency maps. 48 | 49 | 50 | # [ORSI-SOD_Summary](https://github.com/MathLee/ORSI-SOD_Summary) 51 | 52 | # Citation 53 | @ARTICLE{Li_2022_MCCNet, 54 | author = {Gongyang Li and Zhi Liu and Weisi Lin and Haibin Ling}, 55 | title = {Multi-Content Complementation Network for Salient Object Detection in Optical Remote Sensing Images}, 56 | journal = {IEEE Transactions on Geoscience and Remote Sensing}, 57 | volume = {60}, 58 | pages = {1-13}, 59 | year = {2022}, 60 | } 61 | 62 | 63 | If you encounter any problems with the code or want to report a bug, please contact me at lllmiemie@163.com or ligongyang@shu.edu.cn. 64 | 65 |
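# Dataset layout
For reference, the default paths in train_MCCNet.py and test_MCCNet.py assume the layout below. Both scripts define these paths as plain variables near the top, so adjust them to your setup; the directory names here are taken from those scripts, and edge/ holds the output of sal2edge.m:

```
dataset/
├── train_dataset/
│   └── ORSSD/                # or EORSSD
│       └── train/
│           ├── image/        # *.jpg inputs
│           ├── GT/           # *.png saliency masks
│           └── edge/         # *.png edge maps generated by sal2edge.m
└── test_dataset/
    └── ORSSD/                # or EORSSD, ORSI-4199
        ├── image/
        └── GT/
```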
66 | -------------------------------------------------------------------------------- /data.py: --------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 |
9 |
10 | # several data augmentation strategies
11 | def cv_random_flip(img, label):
12 | flip_flag = random.randint(0, 1)
13 | # flip_flag2= random.randint(0,1)
14 | # left-right flip
15 | if flip_flag == 1:
16 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
17 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
18 | # top-bottom flip
19 | # if flip_flag2==1:
20 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
21 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
22 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM)
23 | return img, label
24 | def randomCrop(image, label):
25 | border = 30
26 | image_width = image.size[0]
27 | image_height = image.size[1]
28 | crop_win_width = np.random.randint(image_width - border, image_width)
29 | crop_win_height = np.random.randint(image_height - border, image_height)
30 | random_region = (
31 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
32 | (image_height + crop_win_height) >> 1)
33 | return image.crop(random_region), label.crop(random_region)
34 | def randomRotation(image, label):
35 | mode = Image.BICUBIC
36 | if random.random() > 0.8:
37 | random_angle = np.random.randint(-15, 15)
38 | image = image.rotate(random_angle, mode)
39 | label = label.rotate(random_angle, mode)
40 | return image, label
41 | def colorEnhance(image):
42 | bright_intensity = random.randint(5, 15) / 10.0
43 | image = ImageEnhance.Brightness(image).enhance(bright_intensity)
44 | contrast_intensity = random.randint(5, 15) / 10.0
45 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
46 | color_intensity = random.randint(0, 20) / 10.0
47 | image = ImageEnhance.Color(image).enhance(color_intensity)
48 | sharp_intensity = random.randint(0, 30) / 10.0
49 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
50 | return image
51 | def randomGaussian(image, mean=0.1, sigma=0.35):
52 | def gaussianNoisy(im, mean=mean, sigma=sigma):
53 | for _i in range(len(im)):
54 | im[_i] += random.gauss(mean, sigma)
55 | return im
56 | img = np.asarray(image)
57 | width, height = img.shape
58 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
59 | img = img.reshape([width, height])
60 | return Image.fromarray(np.uint8(img))
61 | def randomPeper(img):
62 |
63 | img = np.array(img)
64 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
65 | for i in range(noiseNum):
66 |
67 | randX = random.randint(0, img.shape[0] - 1)
68 |
69 | randY = random.randint(0, img.shape[1] - 1)
70 |
71 | if random.randint(0, 1) == 0:
72 |
73 | img[randX, randY] = 0
74 |
75 | else:
76 |
77 | img[randX, randY] = 255
78 | return Image.fromarray(img)
79 |
80 |
81 | class SalObjDataset(data.Dataset):
82 | def __init__(self, image_root, gt_root, edge_root, trainsize):
83 | self.trainsize = trainsize
84 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
85 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
86 | or f.endswith('.png')]
87 | self.edges = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.jpg')
88 | or f.endswith('.png')]  # edge maps are listed and loaded from edge_root
89 | self.images = sorted(self.images)
90 | # self.depths = sorted(self.depths)
91 | self.gts = sorted(self.gts)
92 | self.edges =
sorted(self.edges) 93 | self.filter_files() 94 | self.size = len(self.images) 95 | self.img_transform = transforms.Compose([ 96 | transforms.Resize((self.trainsize, self.trainsize)), 97 | transforms.ToTensor(), 98 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 99 | 100 | self.gt_transform = transforms.Compose([ 101 | transforms.Resize((self.trainsize, self.trainsize)), 102 | transforms.ToTensor()]) 103 | 104 | self.edge_transform = transforms.Compose([ 105 | transforms.Resize((self.trainsize, self.trainsize)), 106 | transforms.ToTensor()]) 107 | 108 | def __getitem__(self, index): 109 | image = self.rgb_loader(self.images[index]) 110 | gt = self.binary_loader(self.gts[index]) 111 | edge = self.binary_loader(self.edges[index]) 112 | # image,gt =cv_random_flip(image,gt) 113 | # image,gt =randomCrop(image, gt) 114 | # image,gt =randomRotation(image, gt) 115 | # image=colorEnhance(image) 116 | # gt=randomPeper(gt) 117 | image = self.img_transform(image) 118 | gt = self.gt_transform(gt) 119 | edge = self.gt_transform(edge) 120 | return image, gt, edge 121 | 122 | def filter_files(self): 123 | assert len(self.images) == len(self.gts) 124 | images = [] 125 | # depths = [] 126 | gts = [] 127 | edges = [] 128 | for img_path, gt_path, edge_path in zip(self.images, self.gts, self.edges): 129 | img = Image.open(img_path) 130 | gt = Image.open(gt_path) 131 | edge = Image.open(edge_path) 132 | if img.size == gt.size: 133 | images.append(img_path) 134 | gts.append(gt_path) 135 | edges.append(edge_path) 136 | self.images = images 137 | self.gts = gts 138 | self.edges = edges 139 | 140 | def rgb_loader(self, path): 141 | with open(path, 'rb') as f: 142 | img = Image.open(f) 143 | return img.convert('RGB') 144 | 145 | def binary_loader(self, path): 146 | with open(path, 'rb') as f: 147 | img = Image.open(f) 148 | # return img.convert('1') 149 | return img.convert('L') 150 | 151 | def resize(self, img, gt, edge): 152 | assert img.size == gt.size 153 | w, h = img.size 154 | if h < self.trainsize or w < self.trainsize: 155 | h = max(h, self.trainsize) 156 | w = max(w, self.trainsize) 157 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), edge.resize((w, h), Image.NEAREST) 158 | else: 159 | return img, gt, edge 160 | 161 | def __len__(self): 162 | return self.size 163 | 164 | 165 | def get_loader(image_root, gt_root, edge_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True): 166 | 167 | dataset = SalObjDataset(image_root, gt_root, edge_root, trainsize) 168 | data_loader = data.DataLoader(dataset=dataset, 169 | batch_size=batchsize, 170 | shuffle=shuffle, 171 | num_workers=num_workers, 172 | pin_memory=pin_memory) 173 | return data_loader 174 | 175 | 176 | class test_dataset: 177 | def __init__(self, image_root, gt_root, testsize): 178 | self.testsize = testsize 179 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 180 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 181 | or f.endswith('.png')] 182 | self.images = sorted(self.images) 183 | # self.depths = sorted(self.depths) 184 | self.gts = sorted(self.gts) 185 | self.img_transform = transforms.Compose([ 186 | transforms.Resize((self.testsize, self.testsize)), 187 | transforms.ToTensor(), 188 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 189 | self.gt_transform = transforms.ToTensor() 190 | self.size = len(self.images) 191 | self.index = 0 192 | 193 | def load_data(self): 194 | image = 
self.rgb_loader(self.images[self.index]) 195 | image = self.img_transform(image).unsqueeze(0) 196 | gt = self.binary_loader(self.gts[self.index]) 197 | name = self.images[self.index].split('/')[-1] 198 | if name.endswith('.jpg'): 199 | name = name.split('.jpg')[0] + '.png' 200 | self.index += 1 201 | return image, gt, name 202 | 203 | def rgb_loader(self, path): 204 | with open(path, 'rb') as f: 205 | img = Image.open(f) 206 | return img.convert('RGB') 207 | 208 | def binary_loader(self, path): 209 | with open(path, 'rb') as f: 210 | img = Image.open(f) 211 | return img.convert('L') 212 | -------------------------------------------------------------------------------- /data_aug.m: -------------------------------------------------------------------------------- 1 | clear; 2 | clc; 3 | close all; 4 | 5 | imPath = '/home/lgy/桌面/ORSI_SOD/dataset/EORSSD/train/image/'; 6 | GtPath = '/home/lgy/桌面/ORSI_SOD/dataset/EORSSD/train/GT/'; 7 | EGPath = '/home/lgy/桌面/ORSI_SOD/dataset/EORSSD/train/edge/'; 8 | 9 | images = dir([GtPath '*.png']); 10 | imagesNum = length(images); 11 | 12 | for i = 1 : imagesNum 13 | im_name = images(i).name(1:end-4); 14 | 15 | gt = imread(fullfile(GtPath, [im_name '.png'])); 16 | im = imread(fullfile(imPath, [im_name '.jpg'])); 17 | eg = imread(fullfile(EGPath, [im_name '.png'])); 18 | 19 | im_1 = imrotate(im,90); 20 | gt_1 = imrotate(gt,90); 21 | eg_1 = imrotate(eg,90); 22 | imwrite(im_1, fullfile(imPath, [im_name '_90.jpg'])); 23 | imwrite(gt_1, fullfile(GtPath, [im_name '_90.png'])); 24 | imwrite(eg_1, fullfile(EGPath, [im_name '_90.png'])); 25 | 26 | im_2 = imrotate(im, 180); 27 | gt_2 = imrotate(gt, 180); 28 | eg_2 = imrotate(eg, 180); 29 | imwrite(im_2, fullfile(imPath, [im_name '_180.jpg'])); 30 | imwrite(gt_2, fullfile(GtPath, [im_name '_180.png'])); 31 | imwrite(eg_2, fullfile(EGPath, [im_name '_180.png'])); 32 | 33 | im_3 = imrotate(im, 270); 34 | gt_3 = imrotate(gt, 270); 35 | eg_3 = imrotate(eg, 270); 36 | imwrite(im_3, fullfile(imPath, [im_name '_270.jpg'])); 37 | imwrite(gt_3, fullfile(GtPath, [im_name '_270.png'])); 38 | imwrite(eg_3, fullfile(EGPath, [im_name '_270.png'])); 39 | 40 | 41 | fl_im = fliplr(im); 42 | fl_gt = fliplr(gt); 43 | fl_eg = fliplr(eg); 44 | imwrite(fl_im, fullfile(imPath, [im_name '_fl.jpg'])); 45 | imwrite(fl_gt, fullfile(GtPath, [im_name '_fl.png'])); 46 | imwrite(fl_eg, fullfile(EGPath, [im_name '_fl.png'])); 47 | 48 | im_1 = imrotate(fl_im,90); 49 | gt_1 = imrotate(fl_gt,90); 50 | eg_1 = imrotate(fl_eg,90); 51 | imwrite(im_1, fullfile(imPath, [im_name '_fl90.jpg'])); 52 | imwrite(gt_1, fullfile(GtPath, [im_name '_fl90.png'])); 53 | imwrite(eg_1, fullfile(EGPath, [im_name '_fl90.png'])); 54 | 55 | im_2 = imrotate(fl_im, 180); 56 | gt_2 = imrotate(fl_gt, 180); 57 | eg_2 = imrotate(fl_eg, 180); 58 | imwrite(im_2, fullfile(imPath, [im_name '_fl180.jpg'])); 59 | imwrite(gt_2, fullfile(GtPath, [im_name '_fl180.png'])); 60 | imwrite(eg_2, fullfile(EGPath, [im_name '_fl180.png'])); 61 | 62 | im_3 = imrotate(fl_im, 270); 63 | gt_3 = imrotate(fl_gt, 270); 64 | eg_3 = imrotate(fl_eg, 270); 65 | imwrite(im_3, fullfile(imPath, [im_name '_fl270.jpg'])); 66 | imwrite(gt_3, fullfile(GtPath, [im_name '_fl270.png'])); 67 | imwrite(eg_3, fullfile(EGPath, [im_name '_fl270.png'])); 68 | 69 | end 70 | -------------------------------------------------------------------------------- /images/MCCM.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/MathLee/MCCNet/a9b1876267412ee795acfa94b4921b7f8940fe27/images/MCCM.png
-------------------------------------------------------------------------------- /images/MCCNet.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/MCCNet/a9b1876267412ee795acfa94b4921b7f8940fe27/images/MCCNet.png
-------------------------------------------------------------------------------- /images/table.png: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/MCCNet/a9b1876267412ee795acfa94b4921b7f8940fe27/images/table.png
-------------------------------------------------------------------------------- /model/MCCNet_models.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import os
6 |
7 | from model.vgg import VGG  # absolute import, so the train/test scripts run from the repo root resolve it under Python 3
8 |
9 |
10 | class BasicConv2d(nn.Module):
11 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
12 | super(BasicConv2d, self).__init__()
13 | self.conv = nn.Conv2d(in_planes, out_planes,
14 | kernel_size=kernel_size, stride=stride,
15 | padding=padding, dilation=dilation, bias=False)
16 | self.bn = nn.BatchNorm2d(out_planes)
17 | self.relu = nn.ReLU(inplace=True)
18 |
19 | def forward(self, x):
20 | x = self.conv(x)
21 | x = self.bn(x)
22 | x = self.relu(x)
23 | return x
24 |
25 |
26 | class TransBasicConv2d(nn.Module):
27 | def __init__(self, in_planes, out_planes, kernel_size=2, stride=2, padding=0, dilation=1, bias=False):
28 | super(TransBasicConv2d, self).__init__()
29 | self.Deconv = nn.ConvTranspose2d(in_planes, out_planes,
30 | kernel_size=kernel_size, stride=stride,
31 | padding=padding, dilation=dilation, bias=False)
32 | self.bn = nn.BatchNorm2d(out_planes)
33 | self.relu = nn.ReLU(inplace=True)
34 |
35 | def forward(self, x):
36 | x = self.Deconv(x)
37 | x = self.bn(x)
38 | x = self.relu(x)
39 | return x
40 |
41 |
42 | class ChannelAttention(nn.Module):
43 | def __init__(self, in_planes, ratio=16):
44 | super(ChannelAttention, self).__init__()
45 |
46 | self.max_pool = nn.AdaptiveMaxPool2d(1)
47 |
48 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)  # ratio controls the bottleneck width
49 | self.relu1 = nn.ReLU()
50 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
51 |
52 | self.sigmoid = nn.Sigmoid()
53 |
54 | def forward(self, x):
55 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
56 | out = max_out
57 | return self.sigmoid(out)
58 |
59 |
60 | class SpatialAttention(nn.Module):
61 | def __init__(self, kernel_size=7):
62 | super(SpatialAttention, self).__init__()
63 |
64 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
65 | padding = 3 if kernel_size == 7 else 1
66 |
67 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
68 | self.sigmoid = nn.Sigmoid()
69 |
70 | def forward(self, x):
71 | max_out, _ = torch.max(x, dim=1, keepdim=True)
72 | x = max_out
73 | x = self.conv1(x)
74 | return self.sigmoid(x)
75 |
76 |
77 | class SpatialAttention_no_s(nn.Module):
78 | def __init__(self, kernel_size=7):
79 | super(SpatialAttention_no_s, self).__init__()
80 |
81 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
82 | padding = 3 if kernel_size == 7 else 1
83 |
84 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
85 | # self.sigmoid = nn.Sigmoid()
86 |
87 | def forward(self,
x): 88 | max_out, _ = torch.max(x, dim=1, keepdim=True) 89 | x = max_out 90 | x = self.conv1(x) 91 | return x 92 | 93 | 94 | class MCCM(nn.Module): 95 | def __init__(self, cur_channel): 96 | super(MCCM, self).__init__() 97 | self.relu = nn.ReLU(True) 98 | 99 | self.ca = ChannelAttention(cur_channel) 100 | self.sa_fg = SpatialAttention_no_s() 101 | self.sa_edge = SpatialAttention_no_s() 102 | self.sigmoid = nn.Sigmoid() 103 | self.FE_conv = BasicConv2d(cur_channel, cur_channel, 3, padding=1) 104 | self.BG_conv = BasicConv2d(cur_channel, cur_channel, 3, padding=1) 105 | 106 | self.global_avg_pool = nn.AdaptiveAvgPool2d(1) 107 | self.conv1 = BasicConv2d(cur_channel, cur_channel, 1) 108 | self.sa_ic = SpatialAttention() 109 | self.IC_conv = BasicConv2d(cur_channel, cur_channel, 3, padding=1) 110 | 111 | self.FE_B_I_conv = BasicConv2d(3 * cur_channel, cur_channel, 3, padding=1) 112 | 113 | def forward(self, x): 114 | x_ca = x.mul(self.ca(x)) 115 | # Foreground attention 116 | x_sa_fg = self.sa_fg(x_ca) 117 | # Edge attention 118 | x_edge = self.sa_edge(x_ca) 119 | # Foreground and Edge (FE) feature 120 | x_fg_edge = self.FE_conv(x_ca.mul(self.sigmoid(x_sa_fg) + self.sigmoid(x_edge))) 121 | 122 | # Background feature 123 | x_bg = self.BG_conv(x_ca.mul(1 - self.sigmoid(x_sa_fg) - self.sigmoid(x_edge))) 124 | 125 | # Image-level content 126 | in_size = x.shape[2:] 127 | x_gap = self.conv1(self.global_avg_pool(x)) 128 | x_up = F.interpolate(x_gap, size=in_size, mode="bilinear", align_corners=True) 129 | x_ic = self.IC_conv(x.mul(self.sa_ic(x_up))) 130 | 131 | x_RE_B_I = self.FE_B_I_conv(torch.cat((x_fg_edge, x_bg, x_ic), 1)) 132 | 133 | return (x + x_RE_B_I), x_edge 134 | 135 | class decoder(nn.Module): 136 | def __init__(self, channel=512): 137 | super(decoder, self).__init__() 138 | self.relu = nn.ReLU(True) 139 | 140 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 141 | 142 | self.decoder5 = nn.Sequential( 143 | BasicConv2d(channel, 512, 3, padding=1), 144 | BasicConv2d(512, 512, 3, padding=1), 145 | BasicConv2d(512, 512, 3, padding=1), 146 | nn.Dropout(0.5), 147 | TransBasicConv2d(512, 512, kernel_size=2, stride=2, 148 | padding=0, dilation=1, bias=False) 149 | ) 150 | self.S5 = nn.Conv2d(512, 1, 3, stride=1, padding=1) 151 | 152 | self.decoder4 = nn.Sequential( 153 | BasicConv2d(1024, 512, 3, padding=1), 154 | BasicConv2d(512, 512, 3, padding=1), 155 | BasicConv2d(512, 256, 3, padding=1), 156 | nn.Dropout(0.5), 157 | TransBasicConv2d(256, 256, kernel_size=2, stride=2, 158 | padding=0, dilation=1, bias=False) 159 | ) 160 | self.S4 = nn.Conv2d(256, 1, 3, stride=1, padding=1) 161 | 162 | self.decoder3 = nn.Sequential( 163 | BasicConv2d(512, 256, 3, padding=1), 164 | BasicConv2d(256, 256, 3, padding=1), 165 | BasicConv2d(256, 128, 3, padding=1), 166 | nn.Dropout(0.5), 167 | TransBasicConv2d(128, 128, kernel_size=2, stride=2, 168 | padding=0, dilation=1, bias=False) 169 | ) 170 | self.S3 = nn.Conv2d(128, 1, 3, stride=1, padding=1) 171 | 172 | self.decoder2 = nn.Sequential( 173 | BasicConv2d(256, 128, 3, padding=1), 174 | BasicConv2d(128, 64, 3, padding=1), 175 | nn.Dropout(0.5), 176 | TransBasicConv2d(64, 64, kernel_size=2, stride=2, 177 | padding=0, dilation=1, bias=False) 178 | ) 179 | self.S2 = nn.Conv2d(64, 1, 3, stride=1, padding=1) 180 | 181 | self.decoder1 = nn.Sequential( 182 | BasicConv2d(128, 64, 3, padding=1), 183 | BasicConv2d(64, 32, 3, padding=1), 184 | ) 185 | self.S1 = nn.Conv2d(32, 1, 3, stride=1, padding=1) 186 | 187 | def forward(self, x5, x4, 
x3, x2, x1): 188 | # x5: 1/16, 512; x4: 1/8, 512; x3: 1/4, 256; x2: 1/2, 128; x1: 1/1, 64 189 | x5_up = self.decoder5(x5) 190 | s5 = self.S5(x5_up) 191 | 192 | x4_up = self.decoder4(torch.cat((x4, x5_up), 1)) 193 | s4 = self.S4(x4_up) 194 | 195 | x3_up = self.decoder3(torch.cat((x3, x4_up), 1)) 196 | s3 = self.S3(x3_up) 197 | 198 | x2_up = self.decoder2(torch.cat((x2, x3_up), 1)) 199 | s2 = self.S2(x2_up) 200 | 201 | x1_up = self.decoder1(torch.cat((x1, x2_up), 1)) 202 | s1 = self.S1(x1_up) 203 | 204 | return s1, s2, s3, s4, s5 205 | 206 | 207 | class MCCNet_VGG(nn.Module): 208 | def __init__(self, channel=32): 209 | super(MCCNet_VGG, self).__init__() 210 | # Backbone model 211 | self.vgg = VGG('rgb') 212 | 213 | self.MCCM5 = MCCM(512) 214 | self.MCCM4 = MCCM(512) 215 | self.MCCM3 = MCCM(256) 216 | self.MCCM2 = MCCM(128) 217 | self.MCCM1 = MCCM(64) 218 | 219 | self.decoder_rgb = decoder(512) 220 | 221 | self.upsample16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True) 222 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 223 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 224 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 225 | 226 | self.sigmoid = nn.Sigmoid() 227 | 228 | def forward(self, x_rgb): 229 | x1_rgb = self.vgg.conv1(x_rgb) 230 | x2_rgb = self.vgg.conv2(x1_rgb) 231 | x3_rgb = self.vgg.conv3(x2_rgb) 232 | x4_rgb = self.vgg.conv4(x3_rgb) 233 | x5_rgb = self.vgg.conv5(x4_rgb) 234 | 235 | # LG means Local and Global, i.e., adjacent context information 236 | x5_MCCM, eg5 = self.MCCM5(x5_rgb) 237 | x4_MCCM, eg4 = self.MCCM4(x4_rgb) 238 | x3_MCCM, eg3 = self.MCCM3(x3_rgb) 239 | x2_MCCM, eg2 = self.MCCM2(x2_rgb) 240 | x1_MCCM, eg1 = self.MCCM1(x1_rgb) 241 | 242 | s1, s2, s3, s4, s5 = self.decoder_rgb(x5_MCCM, x4_MCCM, x3_MCCM, x2_MCCM, x1_MCCM) 243 | 244 | s3 = self.upsample2(s3) 245 | s4 = self.upsample4(s4) 246 | s5 = self.upsample8(s5) 247 | 248 | eg2 = self.upsample2(eg2) 249 | eg3 = self.upsample4(eg3) 250 | eg4 = self.upsample8(eg4) 251 | eg5 = self.upsample16(eg5) 252 | 253 | return s1, s2, s3, s4, s5, self.sigmoid(s1), self.sigmoid(s2), self.sigmoid(s3), self.sigmoid(s4), self.sigmoid( 254 | s5), eg1, eg2, eg3, eg4, eg5 255 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class VGG(nn.Module): 6 | # pooling layer at the front of block 7 | def __init__(self, mode = 'rgb'): 8 | super(VGG, self).__init__() 9 | 10 | conv1 = nn.Sequential() 11 | conv1.add_module('conv1_1', nn.Conv2d(3, 64, 3, 1, 1)) 12 | conv1.add_module('bn1_1', nn.BatchNorm2d(64)) 13 | conv1.add_module('relu1_1', nn.ReLU(inplace=True)) 14 | conv1.add_module('conv1_2', nn.Conv2d(64, 64, 3, 1, 1)) 15 | conv1.add_module('bn1_2', nn.BatchNorm2d(64)) 16 | conv1.add_module('relu1_2', nn.ReLU(inplace=True)) 17 | 18 | self.conv1 = conv1 19 | conv2 = nn.Sequential() 20 | conv2.add_module('pool1', nn.MaxPool2d(2, stride=2)) 21 | conv2.add_module('conv2_1', nn.Conv2d(64, 128, 3, 1, 1)) 22 | conv2.add_module('bn2_1', nn.BatchNorm2d(128)) 23 | conv2.add_module('relu2_1', nn.ReLU()) 24 | conv2.add_module('conv2_2', nn.Conv2d(128, 128, 3, 
1, 1))
25 | conv2.add_module('bn2_2', nn.BatchNorm2d(128))
26 | conv2.add_module('relu2_2', nn.ReLU())
27 | self.conv2 = conv2
28 |
29 | conv3 = nn.Sequential()
30 | conv3.add_module('pool2', nn.MaxPool2d(2, stride=2))
31 | conv3.add_module('conv3_1', nn.Conv2d(128, 256, 3, 1, 1))
32 | conv3.add_module('bn3_1', nn.BatchNorm2d(256))
33 | conv3.add_module('relu3_1', nn.ReLU())
34 | conv3.add_module('conv3_2', nn.Conv2d(256, 256, 3, 1, 1))
35 | conv3.add_module('bn3_2', nn.BatchNorm2d(256))
36 | conv3.add_module('relu3_2', nn.ReLU())
37 | conv3.add_module('conv3_3', nn.Conv2d(256, 256, 3, 1, 1))
38 | conv3.add_module('bn3_3', nn.BatchNorm2d(256))
39 | conv3.add_module('relu3_3', nn.ReLU())
40 | self.conv3 = conv3
41 |
42 | conv4 = nn.Sequential()
43 | conv4.add_module('pool3_1', nn.MaxPool2d(2, stride=2))
44 | conv4.add_module('conv4_1', nn.Conv2d(256, 512, 3, 1, 1))
45 | conv4.add_module('bn4_1', nn.BatchNorm2d(512))
46 | conv4.add_module('relu4_1', nn.ReLU())
47 | conv4.add_module('conv4_2', nn.Conv2d(512, 512, 3, 1, 1))
48 | conv4.add_module('bn4_2', nn.BatchNorm2d(512))
49 | conv4.add_module('relu4_2', nn.ReLU())
50 | conv4.add_module('conv4_3', nn.Conv2d(512, 512, 3, 1, 1))
51 | conv4.add_module('bn4_3', nn.BatchNorm2d(512))
52 | conv4.add_module('relu4_3', nn.ReLU())
53 | self.conv4 = conv4
54 |
55 | conv5 = nn.Sequential()
56 | conv5.add_module('pool4', nn.MaxPool2d(2, stride=2))
57 | conv5.add_module('conv5_1', nn.Conv2d(512, 512, 3, 1, 1))
58 | conv5.add_module('bn5_1', nn.BatchNorm2d(512))
59 | conv5.add_module('relu5_1', nn.ReLU())
60 | conv5.add_module('conv5_2', nn.Conv2d(512, 512, 3, 1, 1))
61 | conv5.add_module('bn5_2', nn.BatchNorm2d(512))
62 | conv5.add_module('relu5_2', nn.ReLU())
63 | conv5.add_module('conv5_3', nn.Conv2d(512, 512, 3, 1, 1))
64 | conv5.add_module('bn5_3', nn.BatchNorm2d(512))  # name must be unique: reusing 'bn5_2' would silently replace the earlier BatchNorm in the Sequential
65 | conv5.add_module('relu5_3', nn.ReLU())
66 | self.conv5 = conv5
67 |
68 | pre_train = torch.load('/home/lgy/20210206_ORSI_SOD/model/vgg16-397923af.pth')  # ImageNet-pretrained VGG-16 weights; point this at your local copy
69 | self._initialize_weights(pre_train)
70 |
71 | def forward(self, x):
72 | x = self.conv1(x)
73 | x = self.conv2(x)
74 | x = self.conv3(x)
75 | x = self.conv4(x)
76 | x = self.conv5(x)
77 |
78 | return x
79 |
80 | def _initialize_weights(self, pre_train):
81 | keys = list(pre_train.keys())  # list() so the keys can be indexed under Python 3
82 |
83 | self.conv1.conv1_1.weight.data.copy_(pre_train[keys[0]])
84 | self.conv1.conv1_2.weight.data.copy_(pre_train[keys[2]])
85 | self.conv2.conv2_1.weight.data.copy_(pre_train[keys[4]])
86 | self.conv2.conv2_2.weight.data.copy_(pre_train[keys[6]])
87 | self.conv3.conv3_1.weight.data.copy_(pre_train[keys[8]])
88 | self.conv3.conv3_2.weight.data.copy_(pre_train[keys[10]])
89 | self.conv3.conv3_3.weight.data.copy_(pre_train[keys[12]])
90 | self.conv4.conv4_1.weight.data.copy_(pre_train[keys[14]])
91 | self.conv4.conv4_2.weight.data.copy_(pre_train[keys[16]])
92 | self.conv4.conv4_3.weight.data.copy_(pre_train[keys[18]])
93 | self.conv5.conv5_1.weight.data.copy_(pre_train[keys[20]])
94 | self.conv5.conv5_2.weight.data.copy_(pre_train[keys[22]])
95 | self.conv5.conv5_3.weight.data.copy_(pre_train[keys[24]])
96 |
97 | self.conv1.conv1_1.bias.data.copy_(pre_train[keys[1]])
98 | self.conv1.conv1_2.bias.data.copy_(pre_train[keys[3]])
99 | self.conv2.conv2_1.bias.data.copy_(pre_train[keys[5]])
100 | self.conv2.conv2_2.bias.data.copy_(pre_train[keys[7]])
101 | self.conv3.conv3_1.bias.data.copy_(pre_train[keys[9]])
102 | self.conv3.conv3_2.bias.data.copy_(pre_train[keys[11]])
103 | self.conv3.conv3_3.bias.data.copy_(pre_train[keys[13]])
104 |
self.conv4.conv4_1.bias.data.copy_(pre_train[keys[15]]) 105 | self.conv4.conv4_2.bias.data.copy_(pre_train[keys[17]]) 106 | self.conv4.conv4_3.bias.data.copy_(pre_train[keys[19]]) 107 | self.conv5.conv5_1.bias.data.copy_(pre_train[keys[21]]) 108 | self.conv5.conv5_2.bias.data.copy_(pre_train[keys[23]]) 109 | self.conv5.conv5_3.bias.data.copy_(pre_train[keys[25]]) 110 | -------------------------------------------------------------------------------- /pytorch_fm/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | 9 | 10 | # class FLoss(torch.nn.Module): 11 | # def __init__(self, beta=0.3, log_like=False): 12 | # super(FLoss, self).__init__() 13 | # self.beta = beta 14 | # self.log_like = log_like 15 | # 16 | # def forward(self, prediction, target): 17 | # EPS = 1e-10 18 | # floss = 0.0 19 | # N = prediction.shape[0] 20 | # for i in range(0, N): 21 | # TP = (prediction[i, :, :, :] * target[i, :, :, :]) 22 | # H = self.beta * target[i, :, :, :] + prediction[i, :, :, :] 23 | # fm = (1 + self.beta) * TP / (H + EPS) 24 | # if self.log_like: 25 | # floss = floss - torch.log(fm) 26 | # else: 27 | # floss = floss + (1 - fm) 28 | # 29 | # return floss / N 30 | 31 | class FLoss(torch.nn.Module): 32 | def __init__(self, beta=0.3, log_like=False): 33 | super(FLoss, self).__init__() 34 | self.beta = beta 35 | self.log_like = log_like 36 | 37 | def forward(self, prediction, target): 38 | EPS = 1e-10 39 | N = prediction.size(0) 40 | TP = (prediction * target).view(N, -1).sum(dim=1) 41 | H = self.beta * target.view(N, -1).sum(dim=1) + prediction.view(N, -1).sum(dim=1) 42 | fmeasure = (1 + self.beta) * TP / (H + EPS) 43 | if self.log_like: 44 | floss = -torch.log(fmeasure) 45 | else: 46 | floss = (1 - fmeasure) 47 | return floss.mean() -------------------------------------------------------------------------------- /pytorch_fm/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/MCCNet/a9b1876267412ee795acfa94b4921b7f8940fe27/pytorch_fm/__init__.pyc -------------------------------------------------------------------------------- /pytorch_iou/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd import Variable 6 | import numpy as np 7 | 8 | def _iou(pred, target, size_average = True): 9 | 10 | b = pred.shape[0] 11 | IoU = 0.0 12 | for i in range(0,b): 13 | #compute the IoU of the foreground 14 | Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:]) 15 | Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1 16 | IoU1 = Iand1/Ior1 17 | 18 | #IoU loss is (1-IoU1) 19 | IoU = IoU + (1-IoU1) 20 | 21 | return IoU/b 22 | 23 | class IOU(torch.nn.Module): 24 | def __init__(self, size_average = True): 25 | super(IOU, self).__init__() 26 | self.size_average = size_average 27 | 28 | def forward(self, pred, target): 29 | 30 | return _iou(pred, target, self.size_average) 31 | -------------------------------------------------------------------------------- /pytorch_iou/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/MCCNet/a9b1876267412ee795acfa94b4921b7f8940fe27/pytorch_iou/__init__.pyc 
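For orientation, train_MCCNet.py combines BCEWithLogitsLoss with the IOU and FLoss modules above at every side output. Below is a minimal sketch of that combination on a single side output; the tensors, shapes, and batch size are illustrative placeholders, not part of the repository:

```python
import torch
import pytorch_fm
import pytorch_iou

# hypothetical tensors standing in for one decoder side output and its ground truth
logits = torch.randn(8, 1, 256, 256, requires_grad=True)  # raw (pre-sigmoid) saliency logits
gts = torch.randint(0, 2, (8, 1, 256, 256)).float()       # binary saliency masks

bce = torch.nn.BCEWithLogitsLoss()          # applied to raw logits
iou = pytorch_iou.IOU(size_average=True)    # applied to sigmoid probabilities
fm = pytorch_fm.FLoss(beta=0.3)             # applied to sigmoid probabilities

probs = torch.sigmoid(logits)
loss = bce(logits, gts) + iou(probs, gts) + fm(probs, gts)
loss.backward()
```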
-------------------------------------------------------------------------------- /pytorch_iou/__pycache__/__init__.cpython-36.pyc: --------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/MCCNet/a9b1876267412ee795acfa94b4921b7f8940fe27/pytorch_iou/__pycache__/__init__.cpython-36.pyc
-------------------------------------------------------------------------------- /test_MCCNet.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import numpy as np
5 | import pdb, os, argparse
6 | import imageio  # replaces scipy.misc.imsave, which was removed from SciPy
7 | import time
8 |
9 | from model.MCCNet_models import MCCNet_VGG
10 | from data import test_dataset
11 |
12 | torch.cuda.set_device(1)
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--testsize', type=int, default=256, help='testing size')
15 | opt = parser.parse_args()
16 |
17 | dataset_path = './dataset/test_dataset/'
18 |
19 | model = MCCNet_VGG()
20 | model.load_state_dict(torch.load('./models/MCCNet_VGG/MCCNet_VGG.pth.34'))
21 |
22 | model.cuda()
23 | model.eval()
24 |
25 | # test_datasets = ['EORSSD']
26 | test_datasets = ['ORSSD']
27 |
28 | for dataset in test_datasets:
29 | save_path = './results/VGG/' + dataset + '/'
30 | if not os.path.exists(save_path):
31 | os.makedirs(save_path)
32 | image_root = dataset_path + dataset + '/image/'
33 | print(dataset)
34 | gt_root = dataset_path + dataset + '/GT/'
35 | test_loader = test_dataset(image_root, gt_root, opt.testsize)
36 | time_sum = 0
37 | for i in range(test_loader.size):
38 | image, gt, name = test_loader.load_data()
39 | gt = np.asarray(gt, np.float32)
40 | gt /= (gt.max() + 1e-8)
41 | image = image.cuda()
42 | time_start = time.time()
43 | res, s2, s3, s4, s5, s1_sig, s2_sig, s3_sig, s4_sig, s5_sig, eg1, eg2, eg3, eg4, eg5 = model(image)
44 | time_end = time.time()
45 | time_sum = time_sum + (time_end - time_start)
46 | res = F.interpolate(res, size=gt.shape, mode='bilinear', align_corners=False)  # F.upsample is deprecated in favor of F.interpolate
47 | res = res.sigmoid().data.cpu().numpy().squeeze()
48 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
49 | imageio.imwrite(save_path + name, (res * 255).astype(np.uint8))  # imageio expects uint8, unlike the old misc.imsave
50 | if i == test_loader.size - 1:
51 | print('Running time {:.5f}'.format(time_sum / test_loader.size))
52 | print('Average speed: {:.4f} fps'.format(test_loader.size / time_sum))
-------------------------------------------------------------------------------- /train_MCCNet.py: --------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 |
6 | import numpy as np
7 | import pdb, os, argparse
8 | from datetime import datetime
9 |
10 | from model.MCCNet_models import MCCNet_VGG
11 | from data import get_loader
12 | from utils import clip_gradient, adjust_lr
13 |
14 | import pytorch_iou
15 | import pytorch_fm
16 |
17 | torch.cuda.set_device(1)
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--epoch', type=int, default=40, help='epoch number')
20 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
21 | parser.add_argument('--batchsize', type=int, default=8, help='training batch size')
22 | parser.add_argument('--trainsize', type=int, default=256, help='training dataset size')
23 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
24 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay
rate of learning rate') 25 | parser.add_argument('--decay_epoch', type=int, default=30, help='every n epochs decay learning rate') 26 | opt = parser.parse_args() 27 | 28 | print('Learning Rate: {}'.format(opt.lr)) 29 | # build models 30 | model = MCCNet_VGG() 31 | 32 | model.cuda() 33 | params = model.parameters() 34 | optimizer = torch.optim.Adam(params, opt.lr) 35 | 36 | image_root = './dataset/train_dataset/ORSSD/train/image/' 37 | gt_root = './dataset/train_dataset/ORSSD/train/GT/' 38 | edge_root = './dataset/train_dataset/ORSSD/train/edge/' 39 | # image_root = './dataset/train_dataset/EORSSD/train/image/' 40 | # gt_root = './dataset/train_dataset/EORSSD/train/GT/' 41 | # edge_root = './dataset/train_dataset/EORSSD/train/edge/' 42 | train_loader = get_loader(image_root, gt_root, edge_root, batchsize=opt.batchsize, trainsize=opt.trainsize) 43 | total_step = len(train_loader) 44 | 45 | CE = torch.nn.BCEWithLogitsLoss() 46 | IOU = pytorch_iou.IOU(size_average = True) 47 | floss = pytorch_fm.FLoss() 48 | 49 | def train(train_loader, model, optimizer, epoch): 50 | model.train() 51 | for i, pack in enumerate(train_loader, start=1): 52 | optimizer.zero_grad() 53 | images, gts, edges = pack 54 | images = Variable(images) 55 | gts = Variable(gts) 56 | edges = Variable(edges) 57 | images = images.cuda() 58 | gts = gts.cuda() 59 | edges = edges.cuda() 60 | 61 | s1, s2, s3, s4, s5, s1_sig, s2_sig, s3_sig, s4_sig, s5_sig, eg1, eg2, eg3, eg4, eg5 = model(images) 62 | # bce+iou+fmloss 63 | loss1 = CE(s1, gts) + IOU(s1_sig, gts) + CE(eg1, edges) + floss(s1_sig, gts) 64 | loss2 = CE(s2, gts) + IOU(s2_sig, gts) + CE(eg2, edges) + floss(s2_sig, gts) 65 | loss3 = CE(s3, gts) + IOU(s3_sig, gts) + CE(eg3, edges) + floss(s3_sig, gts) 66 | loss4 = CE(s4, gts) + IOU(s4_sig, gts) + CE(eg4, edges) + floss(s4_sig, gts) 67 | loss5 = CE(s5, gts) + IOU(s5_sig, gts) + CE(eg5, edges) + floss(s5_sig, gts) 68 | 69 | loss = loss1 + loss2 + loss3 + loss4 + loss5 70 | 71 | loss.backward() 72 | 73 | clip_gradient(optimizer, opt.clip) 74 | optimizer.step() 75 | 76 | if i % 20 == 0 or i == total_step: 77 | print( 78 | '{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Learning Rate: {}, Loss: {:.4f}, Loss1: {:.4f}, Loss2: {:.4f}'. 
79 | format(datetime.now(), epoch, opt.epoch, i, total_step, 80 | opt.lr * opt.decay_rate ** (epoch // opt.decay_epoch), loss.data, loss1.data, 81 | loss2.data)) 82 | 83 | save_path = 'models/MCCNet_VGG/' 84 | if not os.path.exists(save_path): 85 | os.makedirs(save_path) 86 | if (epoch+1) % 5 == 0: 87 | torch.save(model.state_dict(), save_path + 'MCCNet_VGG.pth' + '.%d' % epoch) 88 | 89 | print("Let's go!") 90 | for epoch in range(1, opt.epoch): 91 | adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch) 92 | train(train_loader, model, optimizer, epoch) 93 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | def clip_gradient(optimizer, grad_clip): 2 | for group in optimizer.param_groups: 3 | for param in group['params']: 4 | if param.grad is not None: 5 | param.grad.data.clamp_(-grad_clip, grad_clip) 6 | 7 | 8 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): 9 | decay = decay_rate ** (epoch // decay_epoch) 10 | for param_group in optimizer.param_groups: 11 | param_group['lr'] = init_lr*decay 12 | print('decay_epoch: {}, Current_LR: {}'.format(decay_epoch, init_lr*decay)) --------------------------------------------------------------------------------
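For completeness, here is a minimal single-image inference sketch distilled from test_MCCNet.py. It assumes a trained checkpoint (e.g., MCCNet_VGG.pth.34 from the Baidu links in the README), that the VGG-16 weight path hard-coded in model/vgg.py has been adjusted to a valid local copy, and a local test image; demo.jpg and demo_sal.png are placeholder names:

```python
import numpy as np
import torch
from PIL import Image
import torchvision.transforms as transforms

from model.MCCNet_models import MCCNet_VGG

ckpt_path = './models/MCCNet_VGG/MCCNet_VGG.pth.34'  # placeholder checkpoint path
img_path = './demo.jpg'                              # placeholder input image

model = MCCNet_VGG()
model.load_state_dict(torch.load(ckpt_path, map_location='cpu'))
model.eval()

# preprocessing identical to test_dataset in data.py
tf = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])

x = tf(Image.open(img_path).convert('RGB')).unsqueeze(0)

with torch.no_grad():
    s1 = model(x)[0]                      # first output: full-resolution saliency logits
    sal = torch.sigmoid(s1)[0, 0].numpy()
    sal = (sal - sal.min()) / (sal.max() - sal.min() + 1e-8)  # min-max normalization, as in test_MCCNet.py

Image.fromarray((sal * 255).astype(np.uint8)).save('demo_sal.png')
```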