├── README.md
├── data.py
├── data_aug.m
├── images
│   ├── MCCM.png
│   ├── MCCNet.png
│   └── table.png
├── model
│   ├── MCCNet_models.py
│   ├── __init__.py
│   └── vgg.py
├── pytorch_fm
│   ├── __init__.py
│   └── __init__.pyc
├── pytorch_iou
│   ├── __init__.py
│   ├── __init__.pyc
│   └── __pycache__
│       └── __init__.cpython-36.pyc
├── test_MCCNet.py
├── train_MCCNet.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # MCCNet
2 | This project provides the code and results for 'Multi-Content Complementation Network for Salient Object Detection in Optical Remote Sensing Images', IEEE TGRS, vol. 60, pp. 1-13, 2022. [IEEE link](https://ieeexplore.ieee.org/document/9631225) | [arXiv link](https://arxiv.org/abs/2112.01932) | [Homepage](https://mathlee.github.io/)
3 |
4 |
5 | # Network Architecture
6 |
7 | ![MCCNet architecture](./images/MCCNet.png)
8 |
9 |
10 | # Multi-Content Complementation Module (MCCM)
11 |
12 | ![MCCM](./images/MCCM.png)
13 |
14 |
15 |
16 | # Requirements
17 | Python 2.7 + PyTorch 0.4.0, or
18 |
19 | Python 3.7 + PyTorch 1.9.0
20 |
21 |
22 | # Saliency maps
23 | We provide saliency maps and [measure results (.mat)](https://pan.baidu.com/s/1l4GPBcPYCO9atbgDwbkfUw) (code: i9d0) of [all compared methods](https://pan.baidu.com/s/1TP6An1VWygGUy4uvojL0bg) (code: 5np3) and [our MCCNet](https://pan.baidu.com/s/10JIKL2Q48RvBGeT2pmPfDA) (code: 3pvq) on ORSSD and EORSSD datasets.
24 |
25 | We also provide [saliency maps of our MCCNet](https://pan.baidu.com/s/1dz-GeELIqMdzKlPvzETixA) (code: 413m) on the recently published [ORSI-4199](https://github.com/wchao1213/ORSI-SOD) dataset.
26 |
27 | ![Quantitative comparison](./images/table.png)
28 |
29 | # Training
30 |
31 | We generate the edge ground truth using [sal2edge.m](https://github.com/JXingZhao/EGNet/blob/master/sal2edge.m) from [EGNet](https://github.com/JXingZhao/EGNet), and use data_aug.m for data augmentation.
32 |
33 | Modify the paths of the [VGG backbone](https://pan.baidu.com/s/1YQxKZ-y2C4EsqrgKNI7qrw) (code: ego5) and the datasets, then run train_MCCNet.py.
34 |
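35 | train_MCCNet.py expects the layout below (shown for ORSSD; EORSSD is analogous):
36 |
37 | ```
38 | dataset/train_dataset/ORSSD/train/
39 | ├── image/   # *.jpg inputs
40 | ├── GT/      # *.png saliency ground truth
41 | └── edge/    # *.png edge ground truth (from sal2edge.m)
42 | ```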
35 |
36 | # Pre-trained model and testing
37 | Download the pre-trained model for the dataset you want to test, modify the paths of the pre-trained model and the datasets, then run test_MCCNet.py.
38 |
39 | [ORSSD](https://pan.baidu.com/s/1LdUE8F11r61r8wk3Y9wPLA) (code: awqr)
40 |
41 | [EORSSD](https://pan.baidu.com/s/14LrEt1LW5QmZvkhsgbKgfg) (code: wm3p)
42 |
43 | [ORSI-4199](https://pan.baidu.com/s/1hmANQp9cslyPuDE-3NlqAg) (code: 336a)
44 |
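45 | By default, test_MCCNet.py reads images from ./dataset/test_dataset/<dataset>/image/ and writes the saliency maps to ./results/VGG/<dataset>/:
46 |
47 | ```
48 | python test_MCCNet.py --testsize 256
49 | ```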
45 |
46 | # Evaluation Tool
47 | You can use the [evaluation tool (MATLAB version)](https://github.com/MathLee/MatlabEvaluationTools) to evaluate the above saliency maps.
48 |
49 |
50 | # [ORSI-SOD_Summary](https://github.com/MathLee/ORSI-SOD_Summary)
51 |
52 | # Citation
53 |     @ARTICLE{Li_2022_MCCNet,
54 |              author = {Gongyang Li and Zhi Liu and Weisi Lin and Haibin Ling},
55 |              title = {Multi-Content Complementation Network for Salient Object Detection in Optical Remote Sensing Images},
56 |              journal = {IEEE Transactions on Geoscience and Remote Sensing},
57 |              volume = {60},
58 |              pages = {1-13},
59 |              year = {2022},
60 |             }
61 |
62 |
63 | If you encounter any problems with the code or want to report bugs, please contact me at lllmiemie@163.com or ligongyang@shu.edu.cn.
64 |
65 |
66 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 |
9 |
10 | # several data augmentation strategies
11 | def cv_random_flip(img, label):
12 | flip_flag = random.randint(0, 1)
13 | # flip_flag2= random.randint(0,1)
14 | #left right flip
15 | if flip_flag == 1:
16 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
17 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
18 | #top bottom flip
19 | # if flip_flag2==1:
20 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
21 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
22 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM)
23 | return img, label
24 | def randomCrop(image, label):
25 | border=30
26 | image_width = image.size[0]
27 | image_height = image.size[1]
28 | crop_win_width = np.random.randint(image_width-border , image_width)
29 | crop_win_height = np.random.randint(image_height-border , image_height)
30 | random_region = (
31 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
32 | (image_height + crop_win_height) >> 1)
33 | return image.crop(random_region), label.crop(random_region)
34 | def randomRotation(image,label):
35 | mode=Image.BICUBIC
36 | if random.random()>0.8:
37 | random_angle = np.random.randint(-15, 15)
38 | image=image.rotate(random_angle, mode)
39 | label=label.rotate(random_angle, mode)
40 | return image,label
41 | def colorEnhance(image):
42 | bright_intensity=random.randint(5,15)/10.0
43 | image=ImageEnhance.Brightness(image).enhance(bright_intensity)
44 | contrast_intensity=random.randint(5,15)/10.0
45 | image=ImageEnhance.Contrast(image).enhance(contrast_intensity)
46 | color_intensity=random.randint(0,20)/10.0
47 | image=ImageEnhance.Color(image).enhance(color_intensity)
48 | sharp_intensity=random.randint(0,30)/10.0
49 | image=ImageEnhance.Sharpness(image).enhance(sharp_intensity)
50 | return image
51 | def randomGaussian(image, mean=0.1, sigma=0.35):
52 | def gaussianNoisy(im, mean=mean, sigma=sigma):
53 | for _i in range(len(im)):
54 | im[_i] += random.gauss(mean, sigma)
55 | return im
56 | img = np.asarray(image)
57 | height, width = img.shape  # numpy arrays are (rows, cols), i.e. (height, width)
58 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
59 | img = img.reshape([height, width])
60 | return Image.fromarray(np.uint8(img))
61 | def randomPeper(img):
62 |
63 | img=np.array(img)
64 | noiseNum=int(0.0015*img.shape[0]*img.shape[1])
65 | for i in range(noiseNum):
66 |
67 | randX=random.randint(0,img.shape[0]-1)
68 |
69 | randY=random.randint(0,img.shape[1]-1)
70 |
71 | if random.randint(0,1)==0:
72 |
73 | img[randX,randY]=0
74 |
75 | else:
76 |
77 | img[randX,randY]=255
78 | return Image.fromarray(img)
79 |
80 |
81 | class SalObjDataset(data.Dataset):
82 | def __init__(self, image_root, gt_root, edge_root, trainsize):
83 | self.trainsize = trainsize
84 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
85 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
86 | or f.endswith('.png')]
87 | self.edges = [edge_root + f for f in os.listdir(edge_root) if f.endswith('.jpg')
88 | or f.endswith('.png')]
89 | self.images = sorted(self.images)
90 | # self.depths = sorted(self.depths)
91 | self.gts = sorted(self.gts)
92 | self.edges = sorted(self.edges)
93 | self.filter_files()
94 | self.size = len(self.images)
95 | self.img_transform = transforms.Compose([
96 | transforms.Resize((self.trainsize, self.trainsize)),
97 | transforms.ToTensor(),
98 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
99 |
100 | self.gt_transform = transforms.Compose([
101 | transforms.Resize((self.trainsize, self.trainsize)),
102 | transforms.ToTensor()])
103 |
104 | self.edge_transform = transforms.Compose([
105 | transforms.Resize((self.trainsize, self.trainsize)),
106 | transforms.ToTensor()])
107 |
108 | def __getitem__(self, index):
109 | image = self.rgb_loader(self.images[index])
110 | gt = self.binary_loader(self.gts[index])
111 | edge = self.binary_loader(self.edges[index])
112 | # image,gt =cv_random_flip(image,gt)
113 | # image,gt =randomCrop(image, gt)
114 | # image,gt =randomRotation(image, gt)
115 | # image=colorEnhance(image)
116 | # gt=randomPeper(gt)
117 | image = self.img_transform(image)
118 | gt = self.gt_transform(gt)
119 | edge = self.edge_transform(edge)
120 | return image, gt, edge
121 |
122 | def filter_files(self):
123 | assert len(self.images) == len(self.gts)
124 | images = []
125 | # depths = []
126 | gts = []
127 | edges = []
128 | for img_path, gt_path, edge_path in zip(self.images, self.gts, self.edges):
129 | img = Image.open(img_path)
130 | gt = Image.open(gt_path)
131 | edge = Image.open(edge_path)
132 | if img.size == gt.size:
133 | images.append(img_path)
134 | gts.append(gt_path)
135 | edges.append(edge_path)
136 | self.images = images
137 | self.gts = gts
138 | self.edges = edges
139 |
140 | def rgb_loader(self, path):
141 | with open(path, 'rb') as f:
142 | img = Image.open(f)
143 | return img.convert('RGB')
144 |
145 | def binary_loader(self, path):
146 | with open(path, 'rb') as f:
147 | img = Image.open(f)
148 | # return img.convert('1')
149 | return img.convert('L')
150 |
151 | def resize(self, img, gt, edge):
152 | assert img.size == gt.size
153 | w, h = img.size
154 | if h < self.trainsize or w < self.trainsize:
155 | h = max(h, self.trainsize)
156 | w = max(w, self.trainsize)
157 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), edge.resize((w, h), Image.NEAREST)
158 | else:
159 | return img, gt, edge
160 |
161 | def __len__(self):
162 | return self.size
163 |
164 |
165 | def get_loader(image_root, gt_root, edge_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True):
166 |
167 | dataset = SalObjDataset(image_root, gt_root, edge_root, trainsize)
168 | data_loader = data.DataLoader(dataset=dataset,
169 | batch_size=batchsize,
170 | shuffle=shuffle,
171 | num_workers=num_workers,
172 | pin_memory=pin_memory)
173 | return data_loader
174 |
175 |
176 | class test_dataset:
177 | def __init__(self, image_root, gt_root, testsize):
178 | self.testsize = testsize
179 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
180 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
181 | or f.endswith('.png')]
182 | self.images = sorted(self.images)
183 | # self.depths = sorted(self.depths)
184 | self.gts = sorted(self.gts)
185 | self.img_transform = transforms.Compose([
186 | transforms.Resize((self.testsize, self.testsize)),
187 | transforms.ToTensor(),
188 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
189 | self.gt_transform = transforms.ToTensor()
190 | self.size = len(self.images)
191 | self.index = 0
192 |
193 | def load_data(self):
194 | image = self.rgb_loader(self.images[self.index])
195 | image = self.img_transform(image).unsqueeze(0)
196 | gt = self.binary_loader(self.gts[self.index])
197 | name = self.images[self.index].split('/')[-1]
198 | if name.endswith('.jpg'):
199 | name = name.split('.jpg')[0] + '.png'
200 | self.index += 1
201 | return image, gt, name
202 |
203 | def rgb_loader(self, path):
204 | with open(path, 'rb') as f:
205 | img = Image.open(f)
206 | return img.convert('RGB')
207 |
208 | def binary_loader(self, path):
209 | with open(path, 'rb') as f:
210 | img = Image.open(f)
211 | return img.convert('L')
212 |
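213 |
214 | # Usage sketch (paths are illustrative): build the training loader and iterate
215 | # over (image, gt, edge) batches, as done in train_MCCNet.py:
216 | # train_loader = get_loader('./train/image/', './train/GT/', './train/edge/',
217 | #                           batchsize=8, trainsize=256)
218 | # for images, gts, edges in train_loader:
219 | #     ...  # forward/backward pass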
--------------------------------------------------------------------------------
/data_aug.m:
--------------------------------------------------------------------------------
1 | clear;
2 | clc;
3 | close all;
4 |
5 | imPath = '/home/lgy/Desktop/ORSI_SOD/dataset/EORSSD/train/image/';
6 | GtPath = '/home/lgy/Desktop/ORSI_SOD/dataset/EORSSD/train/GT/';
7 | EGPath = '/home/lgy/Desktop/ORSI_SOD/dataset/EORSSD/train/edge/';
8 |
9 | images = dir([GtPath '*.png']);
10 | imagesNum = length(images);
11 |
12 | for i = 1 : imagesNum
13 | im_name = images(i).name(1:end-4);
14 |
15 | gt = imread(fullfile(GtPath, [im_name '.png']));
16 | im = imread(fullfile(imPath, [im_name '.jpg']));
17 | eg = imread(fullfile(EGPath, [im_name '.png']));
18 |
19 | im_1 = imrotate(im,90);
20 | gt_1 = imrotate(gt,90);
21 | eg_1 = imrotate(eg,90);
22 | imwrite(im_1, fullfile(imPath, [im_name '_90.jpg']));
23 | imwrite(gt_1, fullfile(GtPath, [im_name '_90.png']));
24 | imwrite(eg_1, fullfile(EGPath, [im_name '_90.png']));
25 |
26 | im_2 = imrotate(im, 180);
27 | gt_2 = imrotate(gt, 180);
28 | eg_2 = imrotate(eg, 180);
29 | imwrite(im_2, fullfile(imPath, [im_name '_180.jpg']));
30 | imwrite(gt_2, fullfile(GtPath, [im_name '_180.png']));
31 | imwrite(eg_2, fullfile(EGPath, [im_name '_180.png']));
32 |
33 | im_3 = imrotate(im, 270);
34 | gt_3 = imrotate(gt, 270);
35 | eg_3 = imrotate(eg, 270);
36 | imwrite(im_3, fullfile(imPath, [im_name '_270.jpg']));
37 | imwrite(gt_3, fullfile(GtPath, [im_name '_270.png']));
38 | imwrite(eg_3, fullfile(EGPath, [im_name '_270.png']));
39 |
40 |
41 | fl_im = fliplr(im);
42 | fl_gt = fliplr(gt);
43 | fl_eg = fliplr(eg);
44 | imwrite(fl_im, fullfile(imPath, [im_name '_fl.jpg']));
45 | imwrite(fl_gt, fullfile(GtPath, [im_name '_fl.png']));
46 | imwrite(fl_eg, fullfile(EGPath, [im_name '_fl.png']));
47 |
48 | im_1 = imrotate(fl_im,90);
49 | gt_1 = imrotate(fl_gt,90);
50 | eg_1 = imrotate(fl_eg,90);
51 | imwrite(im_1, fullfile(imPath, [im_name '_fl90.jpg']));
52 | imwrite(gt_1, fullfile(GtPath, [im_name '_fl90.png']));
53 | imwrite(eg_1, fullfile(EGPath, [im_name '_fl90.png']));
54 |
55 | im_2 = imrotate(fl_im, 180);
56 | gt_2 = imrotate(fl_gt, 180);
57 | eg_2 = imrotate(fl_eg, 180);
58 | imwrite(im_2, fullfile(imPath, [im_name '_fl180.jpg']));
59 | imwrite(gt_2, fullfile(GtPath, [im_name '_fl180.png']));
60 | imwrite(eg_2, fullfile(EGPath, [im_name '_fl180.png']));
61 |
62 | im_3 = imrotate(fl_im, 270);
63 | gt_3 = imrotate(fl_gt, 270);
64 | eg_3 = imrotate(fl_eg, 270);
65 | imwrite(im_3, fullfile(imPath, [im_name '_fl270.jpg']));
66 | imwrite(gt_3, fullfile(GtPath, [im_name '_fl270.png']));
67 | imwrite(eg_3, fullfile(EGPath, [im_name '_fl270.png']));
68 |
69 | end
70 |
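71 | % Note: each input sample yields seven augmented copies (rotations by 90, 180
72 | % and 270 degrees, a horizontal flip, and the three rotations of the flip),
73 | % i.e. an 8x larger training set.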
--------------------------------------------------------------------------------
/images/MCCM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/MCCNet/a9b1876267412ee795acfa94b4921b7f8940fe27/images/MCCM.png
--------------------------------------------------------------------------------
/images/MCCNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/MCCNet/a9b1876267412ee795acfa94b4921b7f8940fe27/images/MCCNet.png
--------------------------------------------------------------------------------
/images/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/MCCNet/a9b1876267412ee795acfa94b4921b7f8940fe27/images/table.png
--------------------------------------------------------------------------------
/model/MCCNet_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import numpy as np
5 | import os
6 |
7 | from model.vgg import VGG
8 |
9 |
10 | class BasicConv2d(nn.Module):
11 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
12 | super(BasicConv2d, self).__init__()
13 | self.conv = nn.Conv2d(in_planes, out_planes,
14 | kernel_size=kernel_size, stride=stride,
15 | padding=padding, dilation=dilation, bias=False)
16 | self.bn = nn.BatchNorm2d(out_planes)
17 | self.relu = nn.ReLU(inplace=True)
18 |
19 | def forward(self, x):
20 | x = self.conv(x)
21 | x = self.bn(x)
22 | x = self.relu(x)
23 | return x
24 |
25 |
26 | class TransBasicConv2d(nn.Module):
27 | def __init__(self, in_planes, out_planes, kernel_size=2, stride=2, padding=0, dilation=1, bias=False):
28 | super(TransBasicConv2d, self).__init__()
29 | self.Deconv = nn.ConvTranspose2d(in_planes, out_planes,
30 | kernel_size=kernel_size, stride=stride,
31 | padding=padding, dilation=dilation, bias=False)
32 | self.bn = nn.BatchNorm2d(out_planes)
33 | self.relu = nn.ReLU(inplace=True)
34 |
35 | def forward(self, x):
36 | x = self.Deconv(x)
37 | x = self.bn(x)
38 | x = self.relu(x)
39 | return x
40 |
41 |
42 | class ChannelAttention(nn.Module):
43 | def __init__(self, in_planes, ratio=16):
44 | super(ChannelAttention, self).__init__()
45 |
46 | self.max_pool = nn.AdaptiveMaxPool2d(1)
47 |
48 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
49 | self.relu1 = nn.ReLU()
50 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
51 |
52 | self.sigmoid = nn.Sigmoid()
53 |
54 | def forward(self, x):
55 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
56 | out = max_out
57 | return self.sigmoid(out)
58 |
59 |
60 | class SpatialAttention(nn.Module):
61 | def __init__(self, kernel_size=7):
62 | super(SpatialAttention, self).__init__()
63 |
64 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
65 | padding = 3 if kernel_size == 7 else 1
66 |
67 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
68 | self.sigmoid = nn.Sigmoid()
69 |
70 | def forward(self, x):
71 | max_out, _ = torch.max(x, dim=1, keepdim=True)
72 | x = max_out
73 | x = self.conv1(x)
74 | return self.sigmoid(x)
75 |
76 |
77 | class SpatialAttention_no_s(nn.Module):
78 | def __init__(self, kernel_size=7):
79 | super(SpatialAttention_no_s, self).__init__()
80 |
81 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
82 | padding = 3 if kernel_size == 7 else 1
83 |
84 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
85 | # self.sigmoid = nn.Sigmoid()
86 |
87 | def forward(self, x):
88 | max_out, _ = torch.max(x, dim=1, keepdim=True)
89 | x = max_out
90 | x = self.conv1(x)
91 | return x
92 |
93 |
94 | class MCCM(nn.Module):
95 | def __init__(self, cur_channel):
96 | super(MCCM, self).__init__()
97 | self.relu = nn.ReLU(True)
98 |
99 | self.ca = ChannelAttention(cur_channel)
100 | self.sa_fg = SpatialAttention_no_s()
101 | self.sa_edge = SpatialAttention_no_s()
102 | self.sigmoid = nn.Sigmoid()
103 | self.FE_conv = BasicConv2d(cur_channel, cur_channel, 3, padding=1)
104 | self.BG_conv = BasicConv2d(cur_channel, cur_channel, 3, padding=1)
105 |
106 | self.global_avg_pool = nn.AdaptiveAvgPool2d(1)
107 | self.conv1 = BasicConv2d(cur_channel, cur_channel, 1)
108 | self.sa_ic = SpatialAttention()
109 | self.IC_conv = BasicConv2d(cur_channel, cur_channel, 3, padding=1)
110 |
111 | self.FE_B_I_conv = BasicConv2d(3 * cur_channel, cur_channel, 3, padding=1)
112 |
113 | def forward(self, x):
114 | x_ca = x.mul(self.ca(x))
115 | # Foreground attention
116 | x_sa_fg = self.sa_fg(x_ca)
117 | # Edge attention
118 | x_edge = self.sa_edge(x_ca)
119 | # Foreground and Edge (FE) feature
120 | x_fg_edge = self.FE_conv(x_ca.mul(self.sigmoid(x_sa_fg) + self.sigmoid(x_edge)))
121 |
122 | # Background feature
123 | x_bg = self.BG_conv(x_ca.mul(1 - self.sigmoid(x_sa_fg) - self.sigmoid(x_edge)))
124 |
125 | # Image-level content
126 | in_size = x.shape[2:]
127 | x_gap = self.conv1(self.global_avg_pool(x))
128 | x_up = F.interpolate(x_gap, size=in_size, mode="bilinear", align_corners=True)
129 | x_ic = self.IC_conv(x.mul(self.sa_ic(x_up)))
130 |
131 | x_RE_B_I = self.FE_B_I_conv(torch.cat((x_fg_edge, x_bg, x_ic), 1))
132 |
133 | return (x + x_RE_B_I), x_edge
134 |
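135 | # A shape-contract sketch for MCCM (illustrative values): the channel count is
136 | # preserved, and a single-channel edge logit map is returned alongside the
137 | # enhanced feature. eval() keeps BatchNorm happy on a batch of one:
138 | # feat, edge = MCCM(256).eval()(torch.randn(1, 256, 64, 64))
139 | # feat: (1, 256, 64, 64); edge: (1, 1, 64, 64)
140 |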
135 | class decoder(nn.Module):
136 | def __init__(self, channel=512):
137 | super(decoder, self).__init__()
138 | self.relu = nn.ReLU(True)
139 |
140 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
141 |
142 | self.decoder5 = nn.Sequential(
143 | BasicConv2d(channel, 512, 3, padding=1),
144 | BasicConv2d(512, 512, 3, padding=1),
145 | BasicConv2d(512, 512, 3, padding=1),
146 | nn.Dropout(0.5),
147 | TransBasicConv2d(512, 512, kernel_size=2, stride=2,
148 | padding=0, dilation=1, bias=False)
149 | )
150 | self.S5 = nn.Conv2d(512, 1, 3, stride=1, padding=1)
151 |
152 | self.decoder4 = nn.Sequential(
153 | BasicConv2d(1024, 512, 3, padding=1),
154 | BasicConv2d(512, 512, 3, padding=1),
155 | BasicConv2d(512, 256, 3, padding=1),
156 | nn.Dropout(0.5),
157 | TransBasicConv2d(256, 256, kernel_size=2, stride=2,
158 | padding=0, dilation=1, bias=False)
159 | )
160 | self.S4 = nn.Conv2d(256, 1, 3, stride=1, padding=1)
161 |
162 | self.decoder3 = nn.Sequential(
163 | BasicConv2d(512, 256, 3, padding=1),
164 | BasicConv2d(256, 256, 3, padding=1),
165 | BasicConv2d(256, 128, 3, padding=1),
166 | nn.Dropout(0.5),
167 | TransBasicConv2d(128, 128, kernel_size=2, stride=2,
168 | padding=0, dilation=1, bias=False)
169 | )
170 | self.S3 = nn.Conv2d(128, 1, 3, stride=1, padding=1)
171 |
172 | self.decoder2 = nn.Sequential(
173 | BasicConv2d(256, 128, 3, padding=1),
174 | BasicConv2d(128, 64, 3, padding=1),
175 | nn.Dropout(0.5),
176 | TransBasicConv2d(64, 64, kernel_size=2, stride=2,
177 | padding=0, dilation=1, bias=False)
178 | )
179 | self.S2 = nn.Conv2d(64, 1, 3, stride=1, padding=1)
180 |
181 | self.decoder1 = nn.Sequential(
182 | BasicConv2d(128, 64, 3, padding=1),
183 | BasicConv2d(64, 32, 3, padding=1),
184 | )
185 | self.S1 = nn.Conv2d(32, 1, 3, stride=1, padding=1)
186 |
187 | def forward(self, x5, x4, x3, x2, x1):
188 | # x5: 1/16, 512; x4: 1/8, 512; x3: 1/4, 256; x2: 1/2, 128; x1: 1/1, 64
189 | x5_up = self.decoder5(x5)
190 | s5 = self.S5(x5_up)
191 |
192 | x4_up = self.decoder4(torch.cat((x4, x5_up), 1))
193 | s4 = self.S4(x4_up)
194 |
195 | x3_up = self.decoder3(torch.cat((x3, x4_up), 1))
196 | s3 = self.S3(x3_up)
197 |
198 | x2_up = self.decoder2(torch.cat((x2, x3_up), 1))
199 | s2 = self.S2(x2_up)
200 |
201 | x1_up = self.decoder1(torch.cat((x1, x2_up), 1))
202 | s1 = self.S1(x1_up)
203 |
204 | return s1, s2, s3, s4, s5
205 |
206 |
207 | class MCCNet_VGG(nn.Module):
208 | def __init__(self, channel=32):
209 | super(MCCNet_VGG, self).__init__()
210 | # Backbone model
211 | self.vgg = VGG('rgb')
212 |
213 | self.MCCM5 = MCCM(512)
214 | self.MCCM4 = MCCM(512)
215 | self.MCCM3 = MCCM(256)
216 | self.MCCM2 = MCCM(128)
217 | self.MCCM1 = MCCM(64)
218 |
219 | self.decoder_rgb = decoder(512)
220 |
221 | self.upsample16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
222 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
223 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
224 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
225 |
226 | self.sigmoid = nn.Sigmoid()
227 |
228 | def forward(self, x_rgb):
229 | x1_rgb = self.vgg.conv1(x_rgb)
230 | x2_rgb = self.vgg.conv2(x1_rgb)
231 | x3_rgb = self.vgg.conv3(x2_rgb)
232 | x4_rgb = self.vgg.conv4(x3_rgb)
233 | x5_rgb = self.vgg.conv5(x4_rgb)
234 |
235 | # Multi-content complementation at each level; each MCCM also predicts an edge map
236 | x5_MCCM, eg5 = self.MCCM5(x5_rgb)
237 | x4_MCCM, eg4 = self.MCCM4(x4_rgb)
238 | x3_MCCM, eg3 = self.MCCM3(x3_rgb)
239 | x2_MCCM, eg2 = self.MCCM2(x2_rgb)
240 | x1_MCCM, eg1 = self.MCCM1(x1_rgb)
241 |
242 | s1, s2, s3, s4, s5 = self.decoder_rgb(x5_MCCM, x4_MCCM, x3_MCCM, x2_MCCM, x1_MCCM)
243 |
244 | s3 = self.upsample2(s3)
245 | s4 = self.upsample4(s4)
246 | s5 = self.upsample8(s5)
247 |
248 | eg2 = self.upsample2(eg2)
249 | eg3 = self.upsample4(eg3)
250 | eg4 = self.upsample8(eg4)
251 | eg5 = self.upsample16(eg5)
252 |
253 | return s1, s2, s3, s4, s5, self.sigmoid(s1), self.sigmoid(s2), self.sigmoid(s3), self.sigmoid(s4), self.sigmoid(
254 | s5), eg1, eg2, eg3, eg4, eg5
255 |
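256 | # Smoke-test sketch (assumes the VGG weights path in model/vgg.py exists):
257 | # net = MCCNet_VGG().eval()
258 | # with torch.no_grad():
259 | #     outs = net(torch.randn(1, 3, 256, 256))
260 | # outs[0] is the full-resolution saliency logit map, shape (1, 1, 256, 256).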
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class VGG(nn.Module):
6 | # pooling layer at the front of block
7 | def __init__(self, mode = 'rgb'):
8 | super(VGG, self).__init__()
9 |
10 | conv1 = nn.Sequential()
11 | conv1.add_module('conv1_1', nn.Conv2d(3, 64, 3, 1, 1))
12 | conv1.add_module('bn1_1', nn.BatchNorm2d(64))
13 | conv1.add_module('relu1_1', nn.ReLU(inplace=True))
14 | conv1.add_module('conv1_2', nn.Conv2d(64, 64, 3, 1, 1))
15 | conv1.add_module('bn1_2', nn.BatchNorm2d(64))
16 | conv1.add_module('relu1_2', nn.ReLU(inplace=True))
17 |
18 | self.conv1 = conv1
19 | conv2 = nn.Sequential()
20 | conv2.add_module('pool1', nn.MaxPool2d(2, stride=2))
21 | conv2.add_module('conv2_1', nn.Conv2d(64, 128, 3, 1, 1))
22 | conv2.add_module('bn2_1', nn.BatchNorm2d(128))
23 | conv2.add_module('relu2_1', nn.ReLU())
24 | conv2.add_module('conv2_2', nn.Conv2d(128, 128, 3, 1, 1))
25 | conv2.add_module('bn2_2', nn.BatchNorm2d(128))
26 | conv2.add_module('relu2_2', nn.ReLU())
27 | self.conv2 = conv2
28 |
29 | conv3 = nn.Sequential()
30 | conv3.add_module('pool2', nn.MaxPool2d(2, stride=2))
31 | conv3.add_module('conv3_1', nn.Conv2d(128, 256, 3, 1, 1))
32 | conv3.add_module('bn3_1', nn.BatchNorm2d(256))
33 | conv3.add_module('relu3_1', nn.ReLU())
34 | conv3.add_module('conv3_2', nn.Conv2d(256, 256, 3, 1, 1))
35 | conv3.add_module('bn3_2', nn.BatchNorm2d(256))
36 | conv3.add_module('relu3_2', nn.ReLU())
37 | conv3.add_module('conv3_3', nn.Conv2d(256, 256, 3, 1, 1))
38 | conv3.add_module('bn3_3', nn.BatchNorm2d(256))
39 | conv3.add_module('relu3_3', nn.ReLU())
40 | self.conv3 = conv3
41 |
42 | conv4 = nn.Sequential()
43 | conv4.add_module('pool3_1', nn.MaxPool2d(2, stride=2))
44 | conv4.add_module('conv4_1', nn.Conv2d(256, 512, 3, 1, 1))
45 | conv4.add_module('bn4_1', nn.BatchNorm2d(512))
46 | conv4.add_module('relu4_1', nn.ReLU())
47 | conv4.add_module('conv4_2', nn.Conv2d(512, 512, 3, 1, 1))
48 | conv4.add_module('bn4_2', nn.BatchNorm2d(512))
49 | conv4.add_module('relu4_2', nn.ReLU())
50 | conv4.add_module('conv4_3', nn.Conv2d(512, 512, 3, 1, 1))
51 | conv4.add_module('bn4_3', nn.BatchNorm2d(512))
52 | conv4.add_module('relu4_3', nn.ReLU())
53 | self.conv4 = conv4
54 |
55 | conv5 = nn.Sequential()
56 | conv5.add_module('pool4', nn.MaxPool2d(2, stride=2))
57 | conv5.add_module('conv5_1', nn.Conv2d(512, 512, 3, 1, 1))
58 | conv5.add_module('bn5_1', nn.BatchNorm2d(512))
59 | conv5.add_module('relu5_1', nn.ReLU())
60 | conv5.add_module('conv5_2', nn.Conv2d(512, 512, 3, 1, 1))
61 | conv5.add_module('bn5_2', nn.BatchNorm2d(512))
62 | conv5.add_module('relu5_2', nn.ReLU())
63 | conv5.add_module('conv5_3', nn.Conv2d(512, 512, 3, 1, 1))
64 | conv5.add_module('bn5_3', nn.BatchNorm2d(512))
65 | conv5.add_module('relu5_3', nn.ReLU())
66 | self.conv5 = conv5
67 |
68 | pre_train = torch.load('/home/lgy/20210206_ORSI_SOD/model/vgg16-397923af.pth')
69 | self._initialize_weights(pre_train)
70 |
71 | def forward(self, x):
72 | x = self.conv1(x)
73 | x = self.conv2(x)
74 | x = self.conv3(x)
75 | x = self.conv4(x)
76 | x = self.conv5(x)
77 |
78 | return x
79 |
80 | def _initialize_weights(self, pre_train):
81 | keys = list(pre_train.keys())  # list() so that indexing works under Python 3
82 |
83 | self.conv1.conv1_1.weight.data.copy_(pre_train[keys[0]])
84 | self.conv1.conv1_2.weight.data.copy_(pre_train[keys[2]])
85 | self.conv2.conv2_1.weight.data.copy_(pre_train[keys[4]])
86 | self.conv2.conv2_2.weight.data.copy_(pre_train[keys[6]])
87 | self.conv3.conv3_1.weight.data.copy_(pre_train[keys[8]])
88 | self.conv3.conv3_2.weight.data.copy_(pre_train[keys[10]])
89 | self.conv3.conv3_3.weight.data.copy_(pre_train[keys[12]])
90 | self.conv4.conv4_1.weight.data.copy_(pre_train[keys[14]])
91 | self.conv4.conv4_2.weight.data.copy_(pre_train[keys[16]])
92 | self.conv4.conv4_3.weight.data.copy_(pre_train[keys[18]])
93 | self.conv5.conv5_1.weight.data.copy_(pre_train[keys[20]])
94 | self.conv5.conv5_2.weight.data.copy_(pre_train[keys[22]])
95 | self.conv5.conv5_3.weight.data.copy_(pre_train[keys[24]])
96 |
97 | self.conv1.conv1_1.bias.data.copy_(pre_train[keys[1]])
98 | self.conv1.conv1_2.bias.data.copy_(pre_train[keys[3]])
99 | self.conv2.conv2_1.bias.data.copy_(pre_train[keys[5]])
100 | self.conv2.conv2_2.bias.data.copy_(pre_train[keys[7]])
101 | self.conv3.conv3_1.bias.data.copy_(pre_train[keys[9]])
102 | self.conv3.conv3_2.bias.data.copy_(pre_train[keys[11]])
103 | self.conv3.conv3_3.bias.data.copy_(pre_train[keys[13]])
104 | self.conv4.conv4_1.bias.data.copy_(pre_train[keys[15]])
105 | self.conv4.conv4_2.bias.data.copy_(pre_train[keys[17]])
106 | self.conv4.conv4_3.bias.data.copy_(pre_train[keys[19]])
107 | self.conv5.conv5_1.bias.data.copy_(pre_train[keys[21]])
108 | self.conv5.conv5_2.bias.data.copy_(pre_train[keys[23]])
109 | self.conv5.conv5_3.bias.data.copy_(pre_train[keys[25]])
110 |
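111 | # Note: 'vgg16-397923af.pth' is the standard torchvision VGG-16 checkpoint; its
112 | # state_dict lists the 13 conv weight/bias pairs in the order indexed above.
113 | # A sketch for obtaining it locally (assumes torchvision is installed):
114 | # import torch, torchvision
115 | # torch.save(torchvision.models.vgg16(pretrained=True).state_dict(),
116 | #            'vgg16-397923af.pth')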
--------------------------------------------------------------------------------
/pytorch_fm/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import numpy as np
7 |
8 |
9 |
10 | # class FLoss(torch.nn.Module):
11 | # def __init__(self, beta=0.3, log_like=False):
12 | # super(FLoss, self).__init__()
13 | # self.beta = beta
14 | # self.log_like = log_like
15 | #
16 | # def forward(self, prediction, target):
17 | # EPS = 1e-10
18 | # floss = 0.0
19 | # N = prediction.shape[0]
20 | # for i in range(0, N):
21 | # TP = (prediction[i, :, :, :] * target[i, :, :, :])
22 | # H = self.beta * target[i, :, :, :] + prediction[i, :, :, :]
23 | # fm = (1 + self.beta) * TP / (H + EPS)
24 | # if self.log_like:
25 | # floss = floss - torch.log(fm)
26 | # else:
27 | # floss = floss + (1 - fm)
28 | #
29 | # return floss / N
30 |
31 | class FLoss(torch.nn.Module):
32 | def __init__(self, beta=0.3, log_like=False):
33 | super(FLoss, self).__init__()
34 | self.beta = beta
35 | self.log_like = log_like
36 |
37 | def forward(self, prediction, target):
38 | EPS = 1e-10
39 | N = prediction.size(0)
40 | TP = (prediction * target).view(N, -1).sum(dim=1)
41 | H = self.beta * target.view(N, -1).sum(dim=1) + prediction.view(N, -1).sum(dim=1)
42 | fmeasure = (1 + self.beta) * TP / (H + EPS)
43 | if self.log_like:
44 | floss = -torch.log(fmeasure)
45 | else:
46 | floss = (1 - fmeasure)
47 | return floss.mean()
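48 |
49 | # Usage sketch: FLoss expects predictions already squashed to [0, 1]; in
50 | # train_MCCNet.py it is applied to the sigmoid outputs, e.g.:
51 | # criterion = FLoss(beta=0.3)
52 | # loss = criterion(pred_sig, gts)  # pred_sig, gts: (N, 1, H, W) in [0, 1]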
--------------------------------------------------------------------------------
/pytorch_fm/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/MCCNet/a9b1876267412ee795acfa94b4921b7f8940fe27/pytorch_fm/__init__.pyc
--------------------------------------------------------------------------------
/pytorch_iou/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import numpy as np
7 |
8 | def _iou(pred, target, size_average = True):
9 |
10 | b = pred.shape[0]
11 | IoU = 0.0
12 | for i in range(0,b):
13 | #compute the IoU of the foreground
14 | Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:])
15 | Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1
16 | IoU1 = Iand1/Ior1
17 |
18 | #IoU loss is (1-IoU1)
19 | IoU = IoU + (1-IoU1)
20 |
21 | return IoU/b
22 |
23 | class IOU(torch.nn.Module):
24 | def __init__(self, size_average = True):
25 | super(IOU, self).__init__()
26 | self.size_average = size_average
27 |
28 | def forward(self, pred, target):
29 |
30 | return _iou(pred, target, self.size_average)
31 |
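32 | # Usage sketch: _iou also expects probabilities in [0, 1]; in train_MCCNet.py the
33 | # module is combined with the BCE and F-measure terms, e.g.:
34 | # iou_loss = IOU(size_average=True)(pred_sig, gts)  # scalar, 1 - mean IoU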
--------------------------------------------------------------------------------
/pytorch_iou/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/MCCNet/a9b1876267412ee795acfa94b4921b7f8940fe27/pytorch_iou/__init__.pyc
--------------------------------------------------------------------------------
/pytorch_iou/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/MCCNet/a9b1876267412ee795acfa94b4921b7f8940fe27/pytorch_iou/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/test_MCCNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import numpy as np
5 | import pdb, os, argparse
6 | from scipy import misc
7 | import time
8 |
9 | from model.MCCNet_models import MCCNet_VGG
10 | from data import test_dataset
11 |
12 | torch.cuda.set_device(1)
13 | parser = argparse.ArgumentParser()
14 | parser.add_argument('--testsize', type=int, default=256, help='testing size')
15 | opt = parser.parse_args()
16 |
17 | dataset_path = './dataset/test_dataset/'
18 |
19 | model = MCCNet_VGG()
20 | model.load_state_dict(torch.load('./models/MCCNet_VGG/MCCNet_VGG.pth.34'))
21 |
22 | model.cuda()
23 | model.eval()
24 |
25 | # test_datasets = ['EORSSD']
26 | test_datasets = ['ORSSD']
27 |
28 | for dataset in test_datasets:
29 | save_path = './results/VGG/' + dataset + '/'
30 | if not os.path.exists(save_path):
31 | os.makedirs(save_path)
32 | image_root = dataset_path + dataset + '/image/'
33 | print(dataset)
34 | gt_root = dataset_path + dataset + '/GT/'
35 | test_loader = test_dataset(image_root, gt_root, opt.testsize)
36 | time_sum = 0
37 | for i in range(test_loader.size):
38 | image, gt, name = test_loader.load_data()
39 | gt = np.asarray(gt, np.float32)
40 | gt /= (gt.max() + 1e-8)
41 | image = image.cuda()
42 | time_start = time.time()
43 | res, s2, s3, s4, s5, s1_sig, s2_sig, s3_sig, s4_sig, s5_sig, eg1, eg2, eg3, eg4, eg5 = model(image)
44 | time_end = time.time()
45 | time_sum = time_sum+(time_end-time_start)
46 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False)
47 | res = res.sigmoid().data.cpu().numpy().squeeze()
48 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
49 | misc.imsave(save_path+name, res)
50 | if i == test_loader.size-1:
51 | print('Running time {:.5f}'.format(time_sum/test_loader.size))
52 | print('Average speed: {:.4f} fps'.format(test_loader.size/time_sum))
53 |
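54 | # Compatibility notes (assumptions, not part of the original code):
55 | # - scipy.misc.imsave was removed in SciPy >= 1.2; under the Python 3.7 setup you
56 | #   can save with imageio instead, e.g.
57 | #   imageio.imwrite(save_path + name, (res * 255).astype(np.uint8))
58 | # - F.upsample is deprecated in recent PyTorch; F.interpolate accepts the same
59 | #   arguments here.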
--------------------------------------------------------------------------------
/train_MCCNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 |
6 | import numpy as np
7 | import pdb, os, argparse
8 | from datetime import datetime
9 |
10 | from model.MCCNet_models import MCCNet_VGG
11 | from data import get_loader
12 | from utils import clip_gradient, adjust_lr
13 |
14 | import pytorch_iou
15 | import pytorch_fm
16 |
17 | torch.cuda.set_device(1)
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--epoch', type=int, default=40, help='epoch number')
20 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
21 | parser.add_argument('--batchsize', type=int, default=8, help='training batch size')
22 | parser.add_argument('--trainsize', type=int, default=256, help='training dataset size')
23 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
24 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
25 | parser.add_argument('--decay_epoch', type=int, default=30, help='every n epochs decay learning rate')
26 | opt = parser.parse_args()
27 |
28 | print('Learning Rate: {}'.format(opt.lr))
29 | # build models
30 | model = MCCNet_VGG()
31 |
32 | model.cuda()
33 | params = model.parameters()
34 | optimizer = torch.optim.Adam(params, opt.lr)
35 |
36 | image_root = './dataset/train_dataset/ORSSD/train/image/'
37 | gt_root = './dataset/train_dataset/ORSSD/train/GT/'
38 | edge_root = './dataset/train_dataset/ORSSD/train/edge/'
39 | # image_root = './dataset/train_dataset/EORSSD/train/image/'
40 | # gt_root = './dataset/train_dataset/EORSSD/train/GT/'
41 | # edge_root = './dataset/train_dataset/EORSSD/train/edge/'
42 | train_loader = get_loader(image_root, gt_root, edge_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
43 | total_step = len(train_loader)
44 |
45 | CE = torch.nn.BCEWithLogitsLoss()
46 | IOU = pytorch_iou.IOU(size_average = True)
47 | floss = pytorch_fm.FLoss()
48 |
49 | def train(train_loader, model, optimizer, epoch):
50 | model.train()
51 | for i, pack in enumerate(train_loader, start=1):
52 | optimizer.zero_grad()
53 | images, gts, edges = pack
54 | images = Variable(images)
55 | gts = Variable(gts)
56 | edges = Variable(edges)
57 | images = images.cuda()
58 | gts = gts.cuda()
59 | edges = edges.cuda()
60 |
61 | s1, s2, s3, s4, s5, s1_sig, s2_sig, s3_sig, s4_sig, s5_sig, eg1, eg2, eg3, eg4, eg5 = model(images)
62 | # bce+iou+fmloss
63 | loss1 = CE(s1, gts) + IOU(s1_sig, gts) + CE(eg1, edges) + floss(s1_sig, gts)
64 | loss2 = CE(s2, gts) + IOU(s2_sig, gts) + CE(eg2, edges) + floss(s2_sig, gts)
65 | loss3 = CE(s3, gts) + IOU(s3_sig, gts) + CE(eg3, edges) + floss(s3_sig, gts)
66 | loss4 = CE(s4, gts) + IOU(s4_sig, gts) + CE(eg4, edges) + floss(s4_sig, gts)
67 | loss5 = CE(s5, gts) + IOU(s5_sig, gts) + CE(eg5, edges) + floss(s5_sig, gts)
68 |
69 | loss = loss1 + loss2 + loss3 + loss4 + loss5
70 |
71 | loss.backward()
72 |
73 | clip_gradient(optimizer, opt.clip)
74 | optimizer.step()
75 |
76 | if i % 20 == 0 or i == total_step:
77 | print(
78 | '{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Learning Rate: {}, Loss: {:.4f}, Loss1: {:.4f}, Loss2: {:.4f}'.
79 | format(datetime.now(), epoch, opt.epoch, i, total_step,
80 | opt.lr * opt.decay_rate ** (epoch // opt.decay_epoch), loss.data, loss1.data,
81 | loss2.data))
82 |
83 | save_path = 'models/MCCNet_VGG/'
84 | if not os.path.exists(save_path):
85 | os.makedirs(save_path)
86 | if (epoch+1) % 5 == 0:
87 | torch.save(model.state_dict(), save_path + 'MCCNet_VGG.pth' + '.%d' % epoch)
88 |
89 | print("Let's go!")
90 | for epoch in range(1, opt.epoch):
91 | adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
92 | train(train_loader, model, optimizer, epoch)
93 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | def clip_gradient(optimizer, grad_clip):
2 | for group in optimizer.param_groups:
3 | for param in group['params']:
4 | if param.grad is not None:
5 | param.grad.data.clamp_(-grad_clip, grad_clip)
6 |
7 |
8 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
9 | decay = decay_rate ** (epoch // decay_epoch)
10 | for param_group in optimizer.param_groups:
11 | param_group['lr'] = init_lr*decay
12 | print('decay_epoch: {}, Current_LR: {}'.format(decay_epoch, init_lr*decay))
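13 |
14 | # Usage sketch (as in train_MCCNet.py): clamp gradients after backward() and
15 | # decay the learning rate once per epoch:
16 | # loss.backward()
17 | # clip_gradient(optimizer, grad_clip=0.5)
18 | # optimizer.step()
19 | # adjust_lr(optimizer, init_lr=1e-4, epoch=epoch, decay_rate=0.1, decay_epoch=30)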
--------------------------------------------------------------------------------