├── README.md
├── data.py
├── data_aug.m
├── image
│   ├── ACCoNet.png
│   └── table.png
├── model
│   ├── ACCoNet_Res_models.py
│   ├── ACCoNet_VGG_models.py
│   ├── ResNet50.py
│   ├── __init__.py
│   └── vgg.py
├── pytorch_iou
│   ├── __init__.py
│   ├── __init__.pyc
│   └── __pycache__
│       └── __init__.cpython-36.pyc
├── test_ACCoNet.py
├── train_ACCoNet.py
└── utils.py
/README.md:
--------------------------------------------------------------------------------
1 | # ACCoNet
2 | This project provides the code and results for 'Adjacent Context Coordination Network for Salient Object Detection in Optical Remote Sensing Images', IEEE TCYB, 2023. ([IEEE link](https://ieeexplore.ieee.org/document/9756652) | [arXiv link](https://arxiv.org/abs/2203.13664) | [Homepage](https://mathlee.github.io/))
3 |
4 | # Network Architecture
5 |
6 | ![ACCoNet architecture](./image/ACCoNet.png)
7 |
8 |
9 |
10 | # Requirements
11 | python 2.7 + pytorch 0.4.0 or
12 |
13 | python 3.7 + pytorch 1.9.0
14 |
15 |
16 | # Saliency maps
17 | We provide saliency maps of our ACCoNet ([VGG_backbone](https://pan.baidu.com/s/11KzUltnKIwbYFbEXtud2gQ) (code: gr06) and [ResNet_backbone](https://pan.baidu.com/s/1_ksAXbRrMWupToCxcSDa8g) (code: 1hpn)) on ORSSD, EORSSD, and the additional [ORSI-4199](https://github.com/wchao1213/ORSI-SOD) dataset.
18 |
19 | ![table](./image/table.png)
20 |
21 | # Training
22 |
23 | We provide the code for ACCoNet_VGG and ACCoNet_ResNet. Please set '--is_ResNet' and the dataset paths in train_ACCoNet.py (a minimal sketch of the training setup is shown below).
24 |
25 | For ACCoNet_VGG, please modify the path of the [VGG backbone](https://pan.baidu.com/s/1YQxKZ-y2C4EsqrgKNI7qrw) (code: ego5) in /model/vgg.py.
26 |
27 | data_aug.m is used for data augmentation.
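For orientation, the snippet below is a minimal sketch of how train_ACCoNet.py builds the model, optimizer, and data loader; the dataset paths are placeholders that should point to your own data (batchsize is 6 for the VGG backbone and 8 for ResNet).

```python
# minimal sketch mirroring train_ACCoNet.py; replace the paths with your own
import torch
from model.ACCoNet_VGG_models import ACCoNet_VGG
from data import get_loader

model = ACCoNet_VGG().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
train_loader = get_loader('./dataset/train_dataset/ORSSD/train/image/',
                          './dataset/train_dataset/ORSSD/train/GT/',
                          batchsize=6, trainsize=256)
```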
28 |
29 |
30 | # Pre-trained model and testing
31 | 1. Download the following pre-trained models and put them in /models.
32 |
33 | 2. Modify paths of pre-trained models and datasets.
34 |
35 | 3. Run test_ACCoNet.py.
36 |
37 | ORSSD: [ACCoNet_VGG](https://pan.baidu.com/s/1mPb7oyaz9OVKs3T9v4xCmw) (code: 1bsg); [ACCoNet_ResNet](https://pan.baidu.com/s/1UhHLxgBvMgD66jz2SKgclw) (code: mv91).
38 |
39 | EORSSD: [ACCoNet_VGG](https://pan.baidu.com/s/1R2mFox8rEyxH1DTTnMinLA) (code: i016); [ACCoNet_ResNet](https://pan.baidu.com/s/1-TkZcxR6fBNYWKljhL1Qrg) (code: ak5m).
40 |
41 | ORSI-4199: [ACCoNet_VGG](https://pan.baidu.com/s/1WUVmVCwICBEM3gUJxQ5pkw) (code: qv05); [ACCoNet_ResNet](https://pan.baidu.com/s/1I4RWaLDx4ukK8_11y1AEtw) (code: art7).
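For reference, the snippet below is a minimal inference sketch following test_ACCoNet.py; the checkpoint and dataset paths are placeholders.

```python
# minimal sketch mirroring test_ACCoNet.py; replace the paths with your own
import torch
from model.ACCoNet_Res_models import ACCoNet_Res
from data import test_dataset

model = ACCoNet_Res().cuda()
model.load_state_dict(torch.load('./models/ACCoNet_ResNet/ACCoNet_Res.pth.39'))
model.eval()

loader = test_dataset('./dataset/test_dataset/ORSSD/image/',
                      './dataset/test_dataset/ORSSD/GT/', 256)
with torch.no_grad():
    image, gt, name = loader.load_data()   # one image per call
    outputs = model(image.cuda())          # 5 side-output logits + their sigmoid maps
    saliency = torch.sigmoid(outputs[0]).squeeze().cpu().numpy()  # outputs[0] is the full-resolution prediction
```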
42 |
43 |
44 | # Evaluation Tool
45 | You can use the [evaluation tool (MATLAB version)](https://github.com/MathLee/MatlabEvaluationTools) to evaluate the above saliency maps.
46 |
47 |
48 | # [ORSI-SOD_Summary](https://github.com/MathLee/ORSI-SOD_Summary)
49 |
50 | # Citation
51 |         @ARTICLE{Li_2023_ACCoNet,
52 |                  author = {Gongyang Li and Zhi Liu and Dan Zeng and Weisi Lin and Haibin Ling},
53 |                  title = {Adjacent Context Coordination Network for Salient Object Detection in Optical Remote Sensing Images},
54 |                  journal = {IEEE Transactions on Cybernetics},
55 |                  volume = {53},
56 |                  number = {1},
57 |                  pages = {526-538},
58 |                  year = {2023},
59 |                  month = {Jan.},
60 |         }
61 |
62 |
63 | If you encounter any problems with the code or want to report bugs,
64 |
65 | please contact me at lllmiemie@163.com or ligongyang@shu.edu.cn.
66 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 |
9 |
10 | # several data augmentation strategies
11 | def cv_random_flip(img, label):
12 | flip_flag = random.randint(0, 1)
13 | # flip_flag2= random.randint(0,1)
14 | #left right flip
15 | if flip_flag == 1:
16 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
17 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
18 | #top bottom flip
19 | # if flip_flag2==1:
20 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
21 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
22 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM)
23 | return img, label
24 | def randomCrop(image, label):
25 | border=30
26 | image_width = image.size[0]
27 | image_height = image.size[1]
28 | crop_win_width = np.random.randint(image_width-border , image_width)
29 | crop_win_height = np.random.randint(image_height-border , image_height)
30 | random_region = (
31 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
32 | (image_height + crop_win_height) >> 1)
33 | return image.crop(random_region), label.crop(random_region)
34 | def randomRotation(image,label):
35 | mode=Image.BICUBIC
36 | if random.random()>0.8:
37 | random_angle = np.random.randint(-15, 15)
38 | image=image.rotate(random_angle, mode)
39 | label=label.rotate(random_angle, mode)
40 | return image,label
41 | def colorEnhance(image):
42 | bright_intensity=random.randint(5,15)/10.0
43 | image=ImageEnhance.Brightness(image).enhance(bright_intensity)
44 | contrast_intensity=random.randint(5,15)/10.0
45 | image=ImageEnhance.Contrast(image).enhance(contrast_intensity)
46 | color_intensity=random.randint(0,20)/10.0
47 | image=ImageEnhance.Color(image).enhance(color_intensity)
48 | sharp_intensity=random.randint(0,30)/10.0
49 | image=ImageEnhance.Sharpness(image).enhance(sharp_intensity)
50 | return image
51 | def randomGaussian(image, mean=0.1, sigma=0.35):
52 | def gaussianNoisy(im, mean=mean, sigma=sigma):
53 | for _i in range(len(im)):
54 | im[_i] += random.gauss(mean, sigma)
55 | return im
56 | img = np.asarray(image)
57 | width, height = img.shape
58 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
59 | img = img.reshape([width, height])
60 | return Image.fromarray(np.uint8(img))
61 | def randomPeper(img):
62 |
63 | img=np.array(img)
64 | noiseNum=int(0.0015*img.shape[0]*img.shape[1])
65 | for i in range(noiseNum):
66 |
67 | randX=random.randint(0,img.shape[0]-1)
68 |
69 | randY=random.randint(0,img.shape[1]-1)
70 |
71 | if random.randint(0,1)==0:
72 |
73 | img[randX,randY]=0
74 |
75 | else:
76 |
77 | img[randX,randY]=255
78 | return Image.fromarray(img)
79 |
80 |
81 | class SalObjDataset(data.Dataset):
82 | def __init__(self, image_root, gt_root, trainsize):
83 | self.trainsize = trainsize
84 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
85 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
86 | or f.endswith('.png')]
87 | self.images = sorted(self.images)
88 | # self.depths = sorted(self.depths)
89 | self.gts = sorted(self.gts)
90 | self.filter_files()
91 | self.size = len(self.images)
92 | self.img_transform = transforms.Compose([
93 | transforms.Resize((self.trainsize, self.trainsize)),
94 | transforms.ToTensor(),
95 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
96 |
97 | self.gt_transform = transforms.Compose([
98 | transforms.Resize((self.trainsize, self.trainsize)),
99 | transforms.ToTensor()])
100 |
101 | def __getitem__(self, index):
102 | image = self.rgb_loader(self.images[index])
103 | gt = self.binary_loader(self.gts[index])
104 | # the augmentation functions above (cv_random_flip, randomCrop, randomRotation, etc.) could be applied here before the transforms
105 | image = self.img_transform(image)
106 | gt = self.gt_transform(gt)
107 | return image, gt
108 |
109 | def filter_files(self):
110 | assert len(self.images) == len(self.gts)
111 | images = []
112 | # depths = []
113 | gts = []
114 | for img_path, gt_path in zip(self.images, self.gts):
115 | img = Image.open(img_path)
116 | gt = Image.open(gt_path)
117 | if img.size == gt.size:
118 | images.append(img_path)
119 | gts.append(gt_path)
120 | self.images = images
121 | self.gts = gts
122 |
123 | def rgb_loader(self, path):
124 | with open(path, 'rb') as f:
125 | img = Image.open(f)
126 | return img.convert('RGB')
127 |
128 | def binary_loader(self, path):
129 | with open(path, 'rb') as f:
130 | img = Image.open(f)
131 | # return img.convert('1')
132 | return img.convert('L')
133 |
134 | def resize(self, img, gt):
135 | assert img.size == gt.size
136 | w, h = img.size
137 | if h < self.trainsize or w < self.trainsize:
138 | h = max(h, self.trainsize)
139 | w = max(w, self.trainsize)
140 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST)
141 | else:
142 | return img, gt
143 |
144 | def __len__(self):
145 | return self.size
146 |
147 |
148 | def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True):
149 |
150 | dataset = SalObjDataset(image_root, gt_root, trainsize)
151 | data_loader = data.DataLoader(dataset=dataset,
152 | batch_size=batchsize,
153 | shuffle=shuffle,
154 | num_workers=num_workers,
155 | pin_memory=pin_memory)
156 | return data_loader
157 |
158 |
159 | class test_dataset:
160 | def __init__(self, image_root, gt_root, testsize):
161 | self.testsize = testsize
162 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
163 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
164 | or f.endswith('.png')]
165 | self.images = sorted(self.images)
166 | # self.depths = sorted(self.depths)
167 | self.gts = sorted(self.gts)
168 | self.img_transform = transforms.Compose([
169 | transforms.Resize((self.testsize, self.testsize)),
170 | transforms.ToTensor(),
171 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
172 | self.gt_transform = transforms.ToTensor()
173 | self.size = len(self.images)
174 | self.index = 0
175 |
176 | def load_data(self):
177 | image = self.rgb_loader(self.images[self.index])
178 | image = self.img_transform(image).unsqueeze(0)
179 | gt = self.binary_loader(self.gts[self.index])
180 | name = self.images[self.index].split('/')[-1]
181 | if name.endswith('.jpg'):
182 | name = name.split('.jpg')[0] + '.png'
183 | self.index += 1
184 | return image, gt, name
185 |
186 | def rgb_loader(self, path):
187 | with open(path, 'rb') as f:
188 | img = Image.open(f)
189 | return img.convert('RGB')
190 |
191 | def binary_loader(self, path):
192 | with open(path, 'rb') as f:
193 | img = Image.open(f)
194 | return img.convert('L')
--------------------------------------------------------------------------------
/data_aug.m:
--------------------------------------------------------------------------------
1 | clear;
2 | clc;
3 | close all;
4 |
5 | imPath = '/home/lgy/桌面/ORSI_SOD/dataset/EORSSD/train/image/';
6 | GtPath = '/home/lgy/桌面/ORSI_SOD/dataset/EORSSD/train/GT/';
7 | EGPath = '/home/lgy/桌面/ORSI_SOD/dataset/EORSSD/train/edge/';
8 |
9 | images = dir([GtPath '*.png']);
10 | imagesNum = length(images);
11 |
12 | for i = 1 : imagesNum
13 | im_name = images(i).name(1:end-4);
14 |
15 | gt = imread(fullfile(GtPath, [im_name '.png']));
16 | im = imread(fullfile(imPath, [im_name '.jpg']));
17 | eg = imread(fullfile(EGPath, [im_name '.png']));
18 |
19 | im_1 = imrotate(im,90);
20 | gt_1 = imrotate(gt,90);
21 | eg_1 = imrotate(eg,90);
22 | imwrite(im_1, fullfile(imPath, [im_name '_90.jpg']));
23 | imwrite(gt_1, fullfile(GtPath, [im_name '_90.png']));
24 | imwrite(eg_1, fullfile(EGPath, [im_name '_90.png']));
25 |
26 | im_2 = imrotate(im, 180);
27 | gt_2 = imrotate(gt, 180);
28 | eg_2 = imrotate(eg, 180);
29 | imwrite(im_2, fullfile(imPath, [im_name '_180.jpg']));
30 | imwrite(gt_2, fullfile(GtPath, [im_name '_180.png']));
31 | imwrite(eg_2, fullfile(EGPath, [im_name '_180.png']));
32 |
33 | im_3 = imrotate(im, 270);
34 | gt_3 = imrotate(gt, 270);
35 | eg_3 = imrotate(eg, 270);
36 | imwrite(im_3, fullfile(imPath, [im_name '_270.jpg']));
37 | imwrite(gt_3, fullfile(GtPath, [im_name '_270.png']));
38 | imwrite(eg_3, fullfile(EGPath, [im_name '_270.png']));
39 |
40 |
41 | fl_im = fliplr(im);
42 | fl_gt = fliplr(gt);
43 | fl_eg = fliplr(eg);
44 | imwrite(fl_im, fullfile(imPath, [im_name '_fl.jpg']));
45 | imwrite(fl_gt, fullfile(GtPath, [im_name '_fl.png']));
46 | imwrite(fl_eg, fullfile(EGPath, [im_name '_fl.png']));
47 |
48 | im_1 = imrotate(fl_im,90);
49 | gt_1 = imrotate(fl_gt,90);
50 | eg_1 = imrotate(fl_eg,90);
51 | imwrite(im_1, fullfile(imPath, [im_name '_fl90.jpg']));
52 | imwrite(gt_1, fullfile(GtPath, [im_name '_fl90.png']));
53 | imwrite(eg_1, fullfile(EGPath, [im_name '_fl90.png']));
54 |
55 | im_2 = imrotate(fl_im, 180);
56 | gt_2 = imrotate(fl_gt, 180);
57 | eg_2 = imrotate(fl_eg, 180);
58 | imwrite(im_2, fullfile(imPath, [im_name '_fl180.jpg']));
59 | imwrite(gt_2, fullfile(GtPath, [im_name '_fl180.png']));
60 | imwrite(eg_2, fullfile(EGPath, [im_name '_fl180.png']));
61 |
62 | im_3 = imrotate(fl_im, 270);
63 | gt_3 = imrotate(fl_gt, 270);
64 | eg_3 = imrotate(fl_eg, 270);
65 | imwrite(im_3, fullfile(imPath, [im_name '_fl270.jpg']));
66 | imwrite(gt_3, fullfile(GtPath, [im_name '_fl270.png']));
67 | imwrite(eg_3, fullfile(EGPath, [im_name '_fl270.png']));
68 |
69 | end
70 |
71 |
--------------------------------------------------------------------------------
/image/ACCoNet.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/ACCoNet/d37bd207286b9e522803d3695e7180e4941ae117/image/ACCoNet.png
--------------------------------------------------------------------------------
/image/table.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/ACCoNet/d37bd207286b9e522803d3695e7180e4941ae117/image/table.png
--------------------------------------------------------------------------------
/model/ACCoNet_Res_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from ResNet50 import Backbone_ResNet50_in3
5 |
6 |
7 | class BasicConv2d(nn.Module):
8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
9 | super(BasicConv2d, self).__init__()
10 | self.conv = nn.Conv2d(in_planes, out_planes,
11 | kernel_size=kernel_size, stride=stride,
12 | padding=padding, dilation=dilation, bias=False)
13 | self.bn = nn.BatchNorm2d(out_planes)
14 | self.relu = nn.ReLU(inplace=True)
15 |
16 | def forward(self, x):
17 | x = self.conv(x)
18 | x = self.bn(x)
19 | x = self.relu(x)
20 | return x
21 |
22 | class TransBasicConv2d(nn.Module):
23 | def __init__(self, in_planes, out_planes, kernel_size=2, stride=2, padding=0, dilation=1, bias=False):
24 | super(TransBasicConv2d, self).__init__()
25 | self.Deconv = nn.ConvTranspose2d(in_planes, out_planes,
26 | kernel_size=kernel_size, stride=stride,
27 | padding=padding, dilation=dilation, bias=False)
28 | self.bn = nn.BatchNorm2d(out_planes)
29 | self.relu = nn.ReLU(inplace=True)
30 |
31 | def forward(self, x):
32 | x = self.Deconv(x)
33 | x = self.bn(x)
34 | x = self.relu(x)
35 | return x
36 |
37 |
38 | class ChannelAttention(nn.Module):
39 | def __init__(self, in_planes, ratio=16):
40 | super(ChannelAttention, self).__init__()
41 |
42 | self.max_pool = nn.AdaptiveMaxPool2d(1)
43 |
44 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
45 | self.relu1 = nn.ReLU()
46 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
47 |
48 | self.sigmoid = nn.Sigmoid()
49 |
50 | def forward(self, x):
51 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
52 | out = max_out
53 | return self.sigmoid(out)
54 |
55 |
56 | class SpatialAttention(nn.Module):
57 | def __init__(self, kernel_size=7):
58 | super(SpatialAttention, self).__init__()
59 |
60 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
61 | padding = 3 if kernel_size == 7 else 1
62 |
63 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
64 | self.sigmoid = nn.Sigmoid()
65 |
66 | def forward(self, x):
67 | max_out, _ = torch.max(x, dim=1, keepdim=True)
68 | x = max_out
69 | x = self.conv1(x)
70 | return self.sigmoid(x)
71 |
72 | # for conv5
73 | class ACCoM_5(nn.Module):
74 | def __init__(self, cur_channel):
75 | super(ACCoM_5, self).__init__()
76 | self.relu = nn.ReLU(True)
77 |
78 | # current conv
79 | self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1)
80 | self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2)
81 | self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3)
82 | self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4)
83 |
84 | self.cur_all = BasicConv2d(4*cur_channel, cur_channel, 3, padding=1)
85 | self.cur_all_ca = ChannelAttention(cur_channel)
86 | self.cur_all_sa = SpatialAttention()
87 |
88 | # previous conv
89 | self.downsample2 = nn.MaxPool2d(2, stride=2)
90 | self.pre_sa = SpatialAttention()
91 |
92 | # for m in self.modules():
93 | # if isinstance(m, nn.Conv2d):
94 | # m.weight.data.normal_(std=0.01)
95 | # m.bias.data.fill_(0)
96 |
97 | def forward(self, x_pre, x_cur):
98 | # current conv
99 | x_cur_1 = self.cur_b1(x_cur)
100 | x_cur_2 = self.cur_b2(x_cur)
101 | x_cur_3 = self.cur_b3(x_cur)
102 | x_cur_4 = self.cur_b4(x_cur)
103 | x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1))
104 | cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all))
105 | cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca))
106 |
107 | # previous conv
108 | x_pre = self.downsample2(x_pre)
109 | pre_sa = x_cur_all.mul(self.pre_sa(x_pre))
110 |
111 | x_LocAndGlo = cur_all_sa + pre_sa + x_cur
112 |
113 | return x_LocAndGlo
114 |
115 | # for conv1
116 | class ACCoM_1(nn.Module):
117 | def __init__(self, cur_channel):
118 | super(ACCoM_1, self).__init__()
119 | self.relu = nn.ReLU(True)
120 |
121 | # current conv
122 | self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1)
123 | self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2)
124 | self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3)
125 | self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4)
126 |
127 | self.cur_all = BasicConv2d(4*cur_channel, cur_channel, 3, padding=1)
128 | self.cur_all_ca = ChannelAttention(cur_channel)
129 | self.cur_all_sa = SpatialAttention()
130 |
131 | # latter conv
132 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
133 | self.lat_sa = SpatialAttention()
134 |
135 | # for m in self.modules():
136 | # if isinstance(m, nn.Conv2d):
137 | # m.weight.data.normal_(std=0.01)
138 | # m.bias.data.fill_(0)
139 |
140 | def forward(self, x_cur, x_lat):
141 | # current conv
142 | x_cur_1 = self.cur_b1(x_cur)
143 | x_cur_2 = self.cur_b2(x_cur)
144 | x_cur_3 = self.cur_b3(x_cur)
145 | x_cur_4 = self.cur_b4(x_cur)
146 | x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1))
147 | cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all))
148 | cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca))
149 |
150 | # latter conv
151 | x_lat = self.upsample2(x_lat)
152 | lat_sa = x_cur_all.mul(self.lat_sa(x_lat))
153 |
154 | x_LocAndGlo = cur_all_sa + lat_sa + x_cur
155 |
156 | return x_LocAndGlo
157 |
158 | # for conv2/3/4
159 | class ACCoM(nn.Module):
160 | def __init__(self, cur_channel):
161 | super(ACCoM, self).__init__()
162 | self.relu = nn.ReLU(True)
163 |
164 | # current conv
165 | self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1)
166 | self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2)
167 | self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3)
168 | self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4)
169 |
170 | self.cur_all = BasicConv2d(4 * cur_channel, cur_channel, 3, padding=1)
171 | self.cur_all_ca = ChannelAttention(cur_channel)
172 | self.cur_all_sa = SpatialAttention()
173 |
174 | # previous conv
175 | self.downsample2 = nn.MaxPool2d(2, stride=2)
176 | self.pre_sa = SpatialAttention()
177 |
178 | # latter conv
179 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
180 | self.lat_sa = SpatialAttention()
181 |
182 | # for m in self.modules():
183 | # if isinstance(m, nn.Conv2d):
184 | # m.weight.data.normal_(std=0.01)
185 | # m.bias.data.fill_(0)
186 |
187 | def forward(self, x_pre, x_cur, x_lat):
188 | # current conv
189 | x_cur_1 = self.cur_b1(x_cur)
190 | x_cur_2 = self.cur_b2(x_cur)
191 | x_cur_3 = self.cur_b3(x_cur)
192 | x_cur_4 = self.cur_b4(x_cur)
193 | x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1))
194 | cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all))
195 | cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca))
196 |
197 | # previous conv
198 | x_pre = self.downsample2(x_pre)
199 | pre_sa = x_cur_all.mul(self.pre_sa(x_pre))
200 |
201 | # latter conv
202 | x_lat = self.upsample2(x_lat)
203 | lat_sa = x_cur_all.mul(self.lat_sa(x_lat))
204 |
205 | x_LocAndGlo = cur_all_sa + pre_sa + lat_sa + x_cur
206 |
207 | return x_LocAndGlo
208 |
209 |
210 | class BAB_Decoder(nn.Module):
211 | def __init__(self, channel_1=1024, channel_2=512, channel_3=256, dilation_1=3, dilation_2=2):
212 | super(BAB_Decoder, self).__init__()
213 |
214 | self.conv1 = BasicConv2d(channel_1, channel_2, 3, padding=1)
215 | self.conv1_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_1, dilation=dilation_1)
216 | self.conv2 = BasicConv2d(channel_2, channel_2, 3, padding=1)
217 | self.conv2_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_2, dilation=dilation_2)
218 | self.conv3 = BasicConv2d(channel_2, channel_2, 3, padding=1)
219 | self.conv_fuse = BasicConv2d(channel_2*3, channel_3, 3, padding=1)
220 |
221 | def forward(self, x):
222 | x1 = self.conv1(x)
223 | x1_dila = self.conv1_Dila(x1)
224 |
225 | x2 = self.conv2(x1)
226 | x2_dila = self.conv2_Dila(x2)
227 |
228 | x3 = self.conv3(x2)
229 |
230 | x_fuse = self.conv_fuse(torch.cat((x1_dila, x2_dila, x3), 1))
231 |
232 | return x_fuse
233 |
234 |
235 | class decoder(nn.Module):
236 | def __init__(self, channel=512):
237 | super(decoder, self).__init__()
238 | self.relu = nn.ReLU(True)
239 |
240 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
241 |
242 | self.decoder5 = nn.Sequential(
243 | BAB_Decoder(512, 512, 512, 3, 2),
244 | nn.Dropout(0.5),
245 | TransBasicConv2d(512, 512, kernel_size=2, stride=2,
246 | padding=0, dilation=1, bias=False)
247 | )
248 | self.S5 = nn.Conv2d(512, 1, 3, stride=1, padding=1)
249 |
250 | self.decoder4 = nn.Sequential(
251 | BAB_Decoder(1024, 512, 256, 3, 2),
252 | nn.Dropout(0.5),
253 | TransBasicConv2d(256, 256, kernel_size=2, stride=2,
254 | padding=0, dilation=1, bias=False)
255 | )
256 | self.S4 = nn.Conv2d(256, 1, 3, stride=1, padding=1)
257 |
258 | self.decoder3 = nn.Sequential(
259 | BAB_Decoder(512, 256, 128, 5, 3),
260 | nn.Dropout(0.5),
261 | TransBasicConv2d(128, 128, kernel_size=2, stride=2,
262 | padding=0, dilation=1, bias=False)
263 | )
264 | self.S3 = nn.Conv2d(128, 1, 3, stride=1, padding=1)
265 |
266 | self.decoder2 = nn.Sequential(
267 | BAB_Decoder(256, 128, 64, 5, 3),
268 | nn.Dropout(0.5),
269 | TransBasicConv2d(64, 64, kernel_size=2, stride=2,
270 | padding=0, dilation=1, bias=False)
271 | )
272 | self.S2 = nn.Conv2d(64, 1, 3, stride=1, padding=1)
273 |
274 | self.decoder1 = nn.Sequential(
275 | BAB_Decoder(128, 64, 32, 5, 3)
276 | )
277 | self.S1 = nn.Conv2d(32, 1, 3, stride=1, padding=1)
278 |
279 |
280 | def forward(self, x5, x4, x3, x2, x1):
281 | # x5: 1/16, 512; x4: 1/8, 512; x3: 1/4, 256; x2: 1/2, 128; x1: 1/1, 64
282 | x5_up = self.decoder5(x5)
283 | s5 = self.S5(x5_up)
284 |
285 | x4_up = self.decoder4(torch.cat((x4, x5_up), 1))
286 | s4 = self.S4(x4_up)
287 |
288 | x3_up = self.decoder3(torch.cat((x3, x4_up), 1))
289 | s3 = self.S3(x3_up)
290 |
291 | x2_up = self.decoder2(torch.cat((x2, x3_up), 1))
292 | s2 = self.S2(x2_up)
293 |
294 | x1_up = self.decoder1(torch.cat((x1, x2_up), 1))
295 | s1 = self.S1(x1_up)
296 |
297 | return s1, s2, s3, s4, s5
298 |
299 |
300 | class ACCoNet_Res(nn.Module):
301 | def __init__(self, channel=32):
302 | super(ACCoNet_Res, self).__init__()
303 | #Backbone model
304 | # ---- ResNet50 Backbone ----
305 | (
306 | self.encoder1,
307 | self.encoder2,
308 | self.encoder4,
309 | self.encoder8,
310 | self.encoder16,
311 | ) = Backbone_ResNet50_in3()
312 |
313 | # Lateral layers
314 | self.lateral_conv0 = BasicConv2d(64, 64, 3, stride=1, padding=1)
315 | self.lateral_conv1 = BasicConv2d(256, 128, 3, stride=1, padding=1)
316 | self.lateral_conv2 = BasicConv2d(512, 256, 3, stride=1, padding=1)
317 | self.lateral_conv3 = BasicConv2d(1024, 512, 3, stride=1, padding=1)
318 | self.lateral_conv4 = BasicConv2d(2048, 512, 3, stride=1, padding=1)
319 |
320 | self.ACCoM5 = ACCoM_5(512)
321 | self.ACCoM4 = ACCoM(512)
322 | self.ACCoM3 = ACCoM(256)
323 | self.ACCoM2 = ACCoM(128)
324 | self.ACCoM1 = ACCoM_1(64)
325 |
326 | # self.agg2_rgbd = aggregation(channel)
327 | self.decoder_rgb = decoder(512)
328 |
329 | self.upsample16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True)
330 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
331 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
332 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
333 |
334 | self.sigmoid = nn.Sigmoid()
335 |
336 |
337 | def forward(self, x_rgb):
338 | x0 = self.encoder1(x_rgb)
339 | x1 = self.encoder2(x0)
340 | x2 = self.encoder4(x1)
341 | x3 = self.encoder8(x2)
342 | x4 = self.encoder16(x3)
343 |
344 | x1_rgb = self.lateral_conv0(x0)
345 | x2_rgb = self.lateral_conv1(x1)
346 | x3_rgb = self.lateral_conv2(x2)
347 | x4_rgb = self.lateral_conv3(x3)
348 | x5_rgb = self.lateral_conv4(x4)
349 |
350 |
351 | # up means update
352 | x5_ACCoM = self.ACCoM5(x4_rgb, x5_rgb)
353 | x4_ACCoM = self.ACCoM4(x3_rgb, x4_rgb, x5_rgb)
354 | x3_ACCoM = self.ACCoM3(x2_rgb, x3_rgb, x4_rgb)
355 | x2_ACCoM = self.ACCoM2(x1_rgb, x2_rgb, x3_rgb)
356 | x1_ACCoM = self.ACCoM1(x1_rgb, x2_rgb)
357 |
358 | s1, s2, s3, s4, s5 = self.decoder_rgb(x5_ACCoM, x4_ACCoM, x3_ACCoM, x2_ACCoM, x1_ACCoM)
359 | # upsample the side outputs back to the input resolution
360 | s1 = self.upsample2(s1)
361 | s2 = self.upsample2(s2)
362 | s3 = self.upsample4(s3)
363 | s4 = self.upsample8(s4)
364 | s5 = self.upsample16(s5)
365 |
366 | return s1, s2, s3, s4, s5, self.sigmoid(s1), self.sigmoid(s2), self.sigmoid(s3), self.sigmoid(s4), self.sigmoid(s5)
--------------------------------------------------------------------------------
/model/ACCoNet_VGG_models.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from vgg import VGG
6 |
7 | class BasicConv2d(nn.Module):
8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
9 | super(BasicConv2d, self).__init__()
10 | self.conv = nn.Conv2d(in_planes, out_planes,
11 | kernel_size=kernel_size, stride=stride,
12 | padding=padding, dilation=dilation, bias=False)
13 | self.bn = nn.BatchNorm2d(out_planes)
14 | self.relu = nn.ReLU(inplace=True)
15 |
16 | def forward(self, x):
17 | x = self.conv(x)
18 | x = self.bn(x)
19 | x = self.relu(x)
20 | return x
21 |
22 | class TransBasicConv2d(nn.Module):
23 | def __init__(self, in_planes, out_planes, kernel_size=2, stride=2, padding=0, dilation=1, bias=False):
24 | super(TransBasicConv2d, self).__init__()
25 | self.Deconv = nn.ConvTranspose2d(in_planes, out_planes,
26 | kernel_size=kernel_size, stride=stride,
27 | padding=padding, dilation=dilation, bias=False)
28 | self.bn = nn.BatchNorm2d(out_planes)
29 | self.relu = nn.ReLU(inplace=True)
30 |
31 | def forward(self, x):
32 | x = self.Deconv(x)
33 | x = self.bn(x)
34 | x = self.relu(x)
35 | return x
36 |
37 |
38 | class ChannelAttention(nn.Module):
39 | def __init__(self, in_planes, ratio=16):
40 | super(ChannelAttention, self).__init__()
41 |
42 | self.max_pool = nn.AdaptiveMaxPool2d(1)
43 |
44 | self.fc1 = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
45 | self.relu1 = nn.ReLU()
46 | self.fc2 = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
47 |
48 | self.sigmoid = nn.Sigmoid()
49 |
50 | def forward(self, x):
51 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
52 | out = max_out
53 | return self.sigmoid(out)
54 |
55 |
56 | class SpatialAttention(nn.Module):
57 | def __init__(self, kernel_size=7):
58 | super(SpatialAttention, self).__init__()
59 |
60 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
61 | padding = 3 if kernel_size == 7 else 1
62 |
63 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False)
64 | self.sigmoid = nn.Sigmoid()
65 |
66 | def forward(self, x):
67 | max_out, _ = torch.max(x, dim=1, keepdim=True)
68 | x = max_out
69 | x = self.conv1(x)
70 | return self.sigmoid(x)
71 |
72 | # for conv5
73 | class ACCoM5(nn.Module):
74 | def __init__(self, cur_channel):
75 | super(ACCoM5, self).__init__()
76 | self.relu = nn.ReLU(True)
77 |
78 | # current conv
79 | self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1)
80 | self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2)
81 | self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3)
82 | self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4)
83 |
84 | self.cur_all = BasicConv2d(4*cur_channel, cur_channel, 3, padding=1)
85 | self.cur_all_ca = ChannelAttention(cur_channel)
86 | self.cur_all_sa = SpatialAttention()
87 |
88 | # previous conv
89 | self.downsample2 = nn.MaxPool2d(2, stride=2)
90 | self.pre_sa = SpatialAttention()
91 |
92 | def forward(self, x_pre, x_cur):
93 | # current conv
94 | x_cur_1 = self.cur_b1(x_cur)
95 | x_cur_2 = self.cur_b2(x_cur)
96 | x_cur_3 = self.cur_b3(x_cur)
97 | x_cur_4 = self.cur_b4(x_cur)
98 | x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1))
99 | cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all))
100 | cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca))
101 |
102 | # previous conv
103 | x_pre = self.downsample2(x_pre)
104 | pre_sa = x_cur_all.mul(self.pre_sa(x_pre))
105 |
106 | x_LocAndGlo = cur_all_sa + pre_sa + x_cur
107 |
108 | return x_LocAndGlo
109 |
110 | # for conv1
111 | class ACCoM1(nn.Module):
112 | def __init__(self, cur_channel):
113 | super(ACCoM1, self).__init__()
114 | self.relu = nn.ReLU(True)
115 |
116 | # current conv
117 | self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1)
118 | self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2)
119 | self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3)
120 | self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4)
121 |
122 | self.cur_all = BasicConv2d(4*cur_channel, cur_channel, 3, padding=1)
123 | self.cur_all_ca = ChannelAttention(cur_channel)
124 | self.cur_all_sa = SpatialAttention()
125 |
126 | # latter conv
127 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
128 | self.lat_sa = SpatialAttention()
129 |
130 | def forward(self, x_cur, x_lat):
131 | # current conv
132 | x_cur_1 = self.cur_b1(x_cur)
133 | x_cur_2 = self.cur_b2(x_cur)
134 | x_cur_3 = self.cur_b3(x_cur)
135 | x_cur_4 = self.cur_b4(x_cur)
136 | x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1))
137 | cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all))
138 | cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca))
139 |
140 | # latter conv
141 | x_lat = self.upsample2(x_lat)
142 | lat_sa = x_cur_all.mul(self.lat_sa(x_lat))
143 |
144 | x_LocAndGlo = cur_all_sa + lat_sa + x_cur
145 |
146 | return x_LocAndGlo
147 |
148 | # for conv2/3/4
149 | class ACCoM(nn.Module):
150 | def __init__(self, cur_channel):
151 | super(ACCoM, self).__init__()
152 | self.relu = nn.ReLU(True)
153 |
154 | # current conv
155 | self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1)
156 | self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2)
157 | self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3)
158 | self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4)
159 |
160 | self.cur_all = BasicConv2d(4 * cur_channel, cur_channel, 3, padding=1)
161 | self.cur_all_ca = ChannelAttention(cur_channel)
162 | self.cur_all_sa = SpatialAttention()
163 |
164 | # previous conv
165 | self.downsample2 = nn.MaxPool2d(2, stride=2)
166 | self.pre_sa = SpatialAttention()
167 |
168 | # latter conv
169 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
170 | self.lat_sa = SpatialAttention()
171 |
172 | def forward(self, x_pre, x_cur, x_lat):
173 | # current conv
174 | x_cur_1 = self.cur_b1(x_cur)
175 | x_cur_2 = self.cur_b2(x_cur)
176 | x_cur_3 = self.cur_b3(x_cur)
177 | x_cur_4 = self.cur_b4(x_cur)
178 | x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1))
179 | cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all))
180 | cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca))
181 |
182 | # previous conv
183 | x_pre = self.downsample2(x_pre)
184 | pre_sa = x_cur_all.mul(self.pre_sa(x_pre))
185 |
186 | # latter conv
187 | x_lat = self.upsample2(x_lat)
188 | lat_sa = x_cur_all.mul(self.lat_sa(x_lat))
189 |
190 | x_LocAndGlo = cur_all_sa + pre_sa + lat_sa + x_cur
191 |
192 | return x_LocAndGlo
193 |
194 |
195 | class BAB_Decoder(nn.Module):
196 | def __init__(self, channel_1=1024, channel_2=512, channel_3=256, dilation_1=3, dilation_2=2):
197 | super(BAB_Decoder, self).__init__()
198 |
199 | self.conv1 = BasicConv2d(channel_1, channel_2, 3, padding=1)
200 | self.conv1_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_1, dilation=dilation_1)
201 | self.conv2 = BasicConv2d(channel_2, channel_2, 3, padding=1)
202 | self.conv2_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_2, dilation=dilation_2)
203 | self.conv3 = BasicConv2d(channel_2, channel_2, 3, padding=1)
204 | self.conv_fuse = BasicConv2d(channel_2*3, channel_3, 3, padding=1)
205 |
206 | def forward(self, x):
207 | x1 = self.conv1(x)
208 | x1_dila = self.conv1_Dila(x1)
209 |
210 | x2 = self.conv2(x1)
211 | x2_dila = self.conv2_Dila(x2)
212 |
213 | x3 = self.conv3(x2)
214 |
215 | x_fuse = self.conv_fuse(torch.cat((x1_dila, x2_dila, x3), 1))
216 |
217 | return x_fuse
218 |
219 |
220 | class decoder(nn.Module):
221 | def __init__(self, channel=512):
222 | super(decoder, self).__init__()
223 | self.relu = nn.ReLU(True)
224 |
225 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
226 |
227 | self.decoder5 = nn.Sequential(
228 | BAB_Decoder(512, 512, 512, 3, 2),
229 | nn.Dropout(0.5),
230 | TransBasicConv2d(512, 512, kernel_size=2, stride=2,
231 | padding=0, dilation=1, bias=False)
232 | )
233 | self.S5 = nn.Conv2d(512, 1, 3, stride=1, padding=1)
234 |
235 | self.decoder4 = nn.Sequential(
236 | BAB_Decoder(1024, 512, 256, 3, 2),
237 | nn.Dropout(0.5),
238 | TransBasicConv2d(256, 256, kernel_size=2, stride=2,
239 | padding=0, dilation=1, bias=False)
240 | )
241 | self.S4 = nn.Conv2d(256, 1, 3, stride=1, padding=1)
242 |
243 | self.decoder3 = nn.Sequential(
244 | BAB_Decoder(512, 256, 128, 5, 3),
245 | nn.Dropout(0.5),
246 | TransBasicConv2d(128, 128, kernel_size=2, stride=2,
247 | padding=0, dilation=1, bias=False)
248 | )
249 | self.S3 = nn.Conv2d(128, 1, 3, stride=1, padding=1)
250 |
251 | self.decoder2 = nn.Sequential(
252 | BAB_Decoder(256, 128, 64, 5, 3),
253 | nn.Dropout(0.5),
254 | TransBasicConv2d(64, 64, kernel_size=2, stride=2,
255 | padding=0, dilation=1, bias=False)
256 | )
257 | self.S2 = nn.Conv2d(64, 1, 3, stride=1, padding=1)
258 |
259 | self.decoder1 = nn.Sequential(
260 | BAB_Decoder(128, 64, 32, 5, 3)
261 | )
262 | self.S1 = nn.Conv2d(32, 1, 3, stride=1, padding=1)
263 |
264 |
265 | def forward(self, x5, x4, x3, x2, x1):
266 | # x5: 1/16, 512; x4: 1/8, 512; x3: 1/4, 256; x2: 1/2, 128; x1: 1/1, 64
267 | x5_up = self.decoder5(x5)
268 | s5 = self.S5(x5_up)
269 |
270 | x4_up = self.decoder4(torch.cat((x4, x5_up), 1))
271 | s4 = self.S4(x4_up)
272 |
273 | x3_up = self.decoder3(torch.cat((x3, x4_up), 1))
274 | s3 = self.S3(x3_up)
275 |
276 | x2_up = self.decoder2(torch.cat((x2, x3_up), 1))
277 | s2 = self.S2(x2_up)
278 |
279 | x1_up = self.decoder1(torch.cat((x1, x2_up), 1))
280 | s1 = self.S1(x1_up)
281 |
282 | return s1, s2, s3, s4, s5
283 |
284 |
285 | class ACCoNet_VGG(nn.Module):
286 | def __init__(self, channel=32):
287 | super(ACCoNet_VGG, self).__init__()
288 | #Backbone model
289 | self.vgg = VGG('rgb')
290 |
291 | self.ACCoM5 = ACCoM5(512)
292 | self.ACCoM4 = ACCoM(512)
293 | self.ACCoM3 = ACCoM(256)
294 | self.ACCoM2 = ACCoM(128)
295 | self.ACCoM1 = ACCoM1(64)
296 |
297 | # self.agg2_rgbd = aggregation(channel)
298 | self.decoder_rgb = decoder(512)
299 |
300 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
301 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
302 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
303 |
304 | self.sigmoid = nn.Sigmoid()
305 |
306 |
307 | def forward(self, x_rgb):
308 | x1_rgb = self.vgg.conv1(x_rgb)
309 | x2_rgb = self.vgg.conv2(x1_rgb)
310 | x3_rgb = self.vgg.conv3(x2_rgb)
311 | x4_rgb = self.vgg.conv4(x3_rgb)
312 | x5_rgb = self.vgg.conv5(x4_rgb)
313 |
314 | # up means update
315 | x5_ACCoM = self.ACCoM5(x4_rgb, x5_rgb)
316 | x4_ACCoM = self.ACCoM4(x3_rgb, x4_rgb, x5_rgb)
317 | x3_ACCoM = self.ACCoM3(x2_rgb, x3_rgb, x4_rgb)
318 | x2_ACCoM = self.ACCoM2(x1_rgb, x2_rgb, x3_rgb)
319 | x1_ACCoM = self.ACCoM1(x1_rgb, x2_rgb)
320 |
321 | s1, s2, s3, s4, s5 = self.decoder_rgb(x5_ACCoM, x4_ACCoM, x3_ACCoM, x2_ACCoM, x1_ACCoM)
322 |
323 | s3 = self.upsample2(s3)
324 | s4 = self.upsample4(s4)
325 | s5 = self.upsample8(s5)
326 |
327 | return s1, s2, s3, s4, s5, self.sigmoid(s1), self.sigmoid(s2), self.sigmoid(s3), self.sigmoid(s4), self.sigmoid(s5)
--------------------------------------------------------------------------------
/model/ResNet50.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 |
3 | import torch.nn as nn
4 | import torch.utils.model_zoo as model_zoo
5 |
6 | __all__ = ["Backbone_ResNet50_in3"]
7 |
8 | model_urls = {
9 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth",
10 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth",
11 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth",
12 | "resnet101": "https://download.pytorch.org/models/resnet101-5d3b4d8f.pth",
13 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth",
14 | }
15 |
16 |
17 | def conv3x3(in_planes, out_planes, stride=1):
18 | """3x3 convolution with padding"""
19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False)
20 |
21 |
22 | def conv1x1(in_planes, out_planes, stride=1):
23 | """1x1 convolution"""
24 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
25 |
26 |
27 | class BasicBlock(nn.Module):
28 | expansion = 1
29 |
30 | def __init__(self, inplanes, planes, stride=1, downsample=None):
31 | super(BasicBlock, self).__init__()
32 | self.conv1 = conv3x3(inplanes, planes, stride)
33 | self.bn1 = nn.BatchNorm2d(planes)
34 | self.relu = nn.ReLU(inplace=True)
35 | self.conv2 = conv3x3(planes, planes)
36 | self.bn2 = nn.BatchNorm2d(planes)
37 | self.downsample = downsample
38 | self.stride = stride
39 |
40 | def forward(self, x):
41 | identity = x
42 |
43 | out = self.conv1(x)
44 | out = self.bn1(out)
45 | out = self.relu(out)
46 |
47 | out = self.conv2(out)
48 | out = self.bn2(out)
49 |
50 | if self.downsample is not None:
51 | identity = self.downsample(x)
52 |
53 | out += identity
54 | out = self.relu(out)
55 |
56 | return out
57 |
58 |
59 | class Bottleneck(nn.Module):
60 | expansion = 4
61 |
62 | def __init__(self, inplanes, planes, stride=1, downsample=None):
63 | super(Bottleneck, self).__init__()
64 | self.conv1 = conv1x1(inplanes, planes)
65 | self.bn1 = nn.BatchNorm2d(planes)
66 | self.conv2 = conv3x3(planes, planes, stride)
67 | self.bn2 = nn.BatchNorm2d(planes)
68 | self.conv3 = conv1x1(planes, planes * self.expansion)
69 | self.bn3 = nn.BatchNorm2d(planes * self.expansion)
70 | self.relu = nn.ReLU(inplace=True)
71 | self.downsample = downsample
72 | self.stride = stride
73 |
74 | def forward(self, x):
75 | identity = x
76 |
77 | out = self.conv1(x)
78 | out = self.bn1(out)
79 | out = self.relu(out)
80 |
81 | out = self.conv2(out)
82 | out = self.bn2(out)
83 | out = self.relu(out)
84 |
85 | out = self.conv3(out)
86 | out = self.bn3(out)
87 |
88 | if self.downsample is not None:
89 | identity = self.downsample(x)
90 |
91 | out += identity
92 | out = self.relu(out)
93 |
94 | return out
95 |
96 |
97 | class ResNet(nn.Module):
98 | def __init__(self, block, layers, zero_init_residual=False):
99 | super(ResNet, self).__init__()
100 | self.inplanes = 64
101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
102 | self.bn1 = nn.BatchNorm2d(64)
103 | self.relu = nn.ReLU(inplace=True)
104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
105 | self.layer1 = self._make_layer(block, 64, layers[0])
106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) # 6
108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) # 3
109 |
110 | for m in self.modules():
111 | if isinstance(m, nn.Conv2d):
112 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
113 | elif isinstance(m, nn.BatchNorm2d):
114 | nn.init.constant_(m.weight, 1)
115 | nn.init.constant_(m.bias, 0)
116 |
117 | # Zero-initialize the last BN in each residual branch,
118 | # so that the residual branch starts with zeros, and each residual block behaves like an identity.
119 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
120 | if zero_init_residual:
121 | for m in self.modules():
122 | if isinstance(m, Bottleneck):
123 | nn.init.constant_(m.bn3.weight, 0)
124 | elif isinstance(m, BasicBlock):
125 | nn.init.constant_(m.bn2.weight, 0)
126 |
127 | def _make_layer(self, block, planes, blocks, stride=1):
128 | downsample = None
129 | if stride != 1 or self.inplanes != planes * block.expansion:
130 | downsample = nn.Sequential(
131 | conv1x1(self.inplanes, planes * block.expansion, stride),
132 | nn.BatchNorm2d(planes * block.expansion),
133 | )
134 |
135 | layers = []
136 | layers.append(block(self.inplanes, planes, stride, downsample))
137 | self.inplanes = planes * block.expansion
138 | for _ in range(1, blocks):
139 | layers.append(block(self.inplanes, planes))
140 |
141 | return nn.Sequential(*layers)
142 |
143 | def forward(self, x):
144 | x = self.conv1(x)
145 | x = self.bn1(x)
146 | x = self.relu(x)
147 | x = self.maxpool(x)
148 |
149 | x = self.layer1(x)
150 | x = self.layer2(x)
151 | x = self.layer3(x)
152 | x = self.layer4(x)
153 |
154 | return x
155 |
156 |
157 | def resnet18(pretrained=False, **kwargs):
158 | """Constructs a ResNet-18 model.
159 |
160 | Args:
161 | pretrained (bool): If True, returns a model pre-trained on ImageNet
162 | """
163 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs)
164 | if pretrained:
165 | pretrained_dict = model_zoo.load_url(model_urls["resnet18"])
166 |
167 | model_dict = model.state_dict()
168 | # 1. filter out unnecessary keys
169 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
170 | # 2. overwrite entries in the existing state dict
171 | model_dict.update(pretrained_dict)
172 | # 3. load the new state dict
173 | model.load_state_dict(model_dict)
174 | return model
175 |
176 |
177 | def resnet34(pretrained=False, **kwargs):
178 | """Constructs a ResNet-34 model.
179 |
180 | Args:
181 | pretrained (bool): If True, returns a model pre-trained on ImageNet
182 | """
183 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs)
184 | if pretrained:
185 | pretrained_dict = model_zoo.load_url(model_urls["resnet34"])
186 |
187 | model_dict = model.state_dict()
188 | # 1. filter out unnecessary keys
189 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
190 | # 2. overwrite entries in the existing state dict
191 | model_dict.update(pretrained_dict)
192 | # 3. load the new state dict
193 | model.load_state_dict(model_dict)
194 | return model
195 |
196 |
197 | def resnet50(pretrained=False, **kwargs):
198 | """Constructs a ResNet-50 model.
199 |
200 | Args:
201 | pretrained (bool): If True, returns a model pre-trained on ImageNet
202 | """
203 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
204 |
205 | if pretrained:
206 | pretrained_dict = model_zoo.load_url(model_urls["resnet50"])
207 |
208 | model_dict = model.state_dict()
209 | # 1. filter out unnecessary keys
210 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
211 | # 2. overwrite entries in the existing state dict
212 | model_dict.update(pretrained_dict)
213 | # 3. load the new state dict
214 | model.load_state_dict(model_dict)
215 |
216 | return model
217 |
218 |
219 | def resnet101(pretrained=False, **kwargs):
220 | """Constructs a ResNet-101 model.
221 |
222 | Args:
223 | pretrained (bool): If True, returns a model pre-trained on ImageNet
224 | """
225 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs)
226 | if pretrained:
227 | pretrained_dict = model_zoo.load_url(model_urls["resnet101"])
228 |
229 | model_dict = model.state_dict()
230 | # 1. filter out unnecessary keys
231 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
232 | # 2. overwrite entries in the existing state dict
233 | model_dict.update(pretrained_dict)
234 | # 3. load the new state dict
235 | model.load_state_dict(model_dict)
236 | return model
237 |
238 |
239 | def resnet152(pretrained=False, **kwargs):
240 | """Constructs a ResNet-152 model.
241 |
242 | Args:
243 | pretrained (bool): If True, returns a model pre-trained on ImageNet
244 | """
245 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs)
246 |
247 | if pretrained:
248 | pretrained_dict = model_zoo.load_url(model_urls["resnet152"])
249 |
250 | model_dict = model.state_dict()
251 | # 1. filter out unnecessary keys
252 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
253 | # 2. overwrite entries in the existing state dict
254 | model_dict.update(pretrained_dict)
255 | # 3. load the new state dict
256 | model.load_state_dict(model_dict)
257 |
258 | return model
259 |
260 |
261 | def Backbone_ResNet50_in3():
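# Split an ImageNet-pretrained ResNet-50 into five stages whose outputs have
# 64/256/512/1024/2048 channels at strides 2/4/8/16/32; ACCoNet_Res consumes
# them as encoder1/encoder2/encoder4/encoder8/encoder16.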
262 | net = resnet50(pretrained=True)
263 | div_2 = nn.Sequential(*list(net.children())[:3])
264 | div_4 = nn.Sequential(*list(net.children())[3:5])
265 | div_8 = net.layer2
266 | div_16 = net.layer3
267 | div_32 = net.layer4
268 |
269 | return div_2, div_4, div_8, div_16, div_32
270 |
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/vgg.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 |
5 | class VGG(nn.Module):
6 | # pooling layer at the front of block
7 | def __init__(self, mode = 'rgb'):
8 | super(VGG, self).__init__()
9 |
10 | conv1 = nn.Sequential()
11 | conv1.add_module('conv1_1', nn.Conv2d(3, 64, 3, 1, 1))
12 | conv1.add_module('bn1_1', nn.BatchNorm2d(64))
13 | conv1.add_module('relu1_1', nn.ReLU(inplace=True))
14 | conv1.add_module('conv1_2', nn.Conv2d(64, 64, 3, 1, 1))
15 | conv1.add_module('bn1_2', nn.BatchNorm2d(64))
16 | conv1.add_module('relu1_2', nn.ReLU(inplace=True))
17 |
18 | self.conv1 = conv1
19 | conv2 = nn.Sequential()
20 | conv2.add_module('pool1', nn.MaxPool2d(2, stride=2))
21 | conv2.add_module('conv2_1', nn.Conv2d(64, 128, 3, 1, 1))
22 | conv2.add_module('bn2_1', nn.BatchNorm2d(128))
23 | conv2.add_module('relu2_1', nn.ReLU())
24 | conv2.add_module('conv2_2', nn.Conv2d(128, 128, 3, 1, 1))
25 | conv2.add_module('bn2_2', nn.BatchNorm2d(128))
26 | conv2.add_module('relu2_2', nn.ReLU())
27 | self.conv2 = conv2
28 |
29 | conv3 = nn.Sequential()
30 | conv3.add_module('pool2', nn.MaxPool2d(2, stride=2))
31 | conv3.add_module('conv3_1', nn.Conv2d(128, 256, 3, 1, 1))
32 | conv3.add_module('bn3_1', nn.BatchNorm2d(256))
33 | conv3.add_module('relu3_1', nn.ReLU())
34 | conv3.add_module('conv3_2', nn.Conv2d(256, 256, 3, 1, 1))
35 | conv3.add_module('bn3_2', nn.BatchNorm2d(256))
36 | conv3.add_module('relu3_2', nn.ReLU())
37 | conv3.add_module('conv3_3', nn.Conv2d(256, 256, 3, 1, 1))
38 | conv3.add_module('bn3_3', nn.BatchNorm2d(256))
39 | conv3.add_module('relu3_3', nn.ReLU())
40 | self.conv3 = conv3
41 |
42 | conv4 = nn.Sequential()
43 | conv4.add_module('pool3_1', nn.MaxPool2d(2, stride=2))
44 | conv4.add_module('conv4_1', nn.Conv2d(256, 512, 3, 1, 1))
45 | conv4.add_module('bn4_1', nn.BatchNorm2d(512))
46 | conv4.add_module('relu4_1', nn.ReLU())
47 | conv4.add_module('conv4_2', nn.Conv2d(512, 512, 3, 1, 1))
48 | conv4.add_module('bn4_2', nn.BatchNorm2d(512))
49 | conv4.add_module('relu4_2', nn.ReLU())
50 | conv4.add_module('conv4_3', nn.Conv2d(512, 512, 3, 1, 1))
51 | conv4.add_module('bn4_3', nn.BatchNorm2d(512))
52 | conv4.add_module('relu4_3', nn.ReLU())
53 | self.conv4 = conv4
54 |
55 | conv5 = nn.Sequential()
56 | conv5.add_module('pool4', nn.MaxPool2d(2, stride=2))
57 | conv5.add_module('conv5_1', nn.Conv2d(512, 512, 3, 1, 1))
58 | conv5.add_module('bn5_1', nn.BatchNorm2d(512))
59 | conv5.add_module('relu5_1', nn.ReLU())
60 | conv5.add_module('conv5_2', nn.Conv2d(512, 512, 3, 1, 1))
61 | conv5.add_module('bn5_2', nn.BatchNorm2d(512))
62 | conv5.add_module('relu5_2', nn.ReLU())
63 | conv5.add_module('conv5_3', nn.Conv2d(512, 512, 3, 1, 1))
64 | conv5.add_module('bn5_2', nn.BatchNorm2d(512))
65 | conv5.add_module('relu5_3', nn.ReLU())
66 | self.conv5 = conv5
67 |
68 | pre_train = torch.load('/home/lgy/20210206_ORSI_SOD/model/vgg16-397923af.pth')
69 | self._initialize_weights(pre_train)
70 |
71 | def forward(self, x):
72 | x = self.conv1(x)
73 | x = self.conv2(x)
74 | x = self.conv3(x)
75 | x = self.conv4(x)
76 | x = self.conv5(x)
77 |
78 | return x
79 |
80 | def _initialize_weights(self, pre_train):
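# Copy the convolution weights/biases from the torchvision VGG-16 checkpoint
# (vgg16-397923af.pth); its state_dict lists conv weight/bias pairs in order,
# so indices 0-25 cover the 13 conv layers. The BatchNorm layers added above
# keep their default initialization.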
81 | keys = list(pre_train.keys())  # list() so the keys can be indexed under Python 3
82 |
83 | self.conv1.conv1_1.weight.data.copy_(pre_train[keys[0]])
84 | self.conv1.conv1_2.weight.data.copy_(pre_train[keys[2]])
85 | self.conv2.conv2_1.weight.data.copy_(pre_train[keys[4]])
86 | self.conv2.conv2_2.weight.data.copy_(pre_train[keys[6]])
87 | self.conv3.conv3_1.weight.data.copy_(pre_train[keys[8]])
88 | self.conv3.conv3_2.weight.data.copy_(pre_train[keys[10]])
89 | self.conv3.conv3_3.weight.data.copy_(pre_train[keys[12]])
90 | self.conv4.conv4_1.weight.data.copy_(pre_train[keys[14]])
91 | self.conv4.conv4_2.weight.data.copy_(pre_train[keys[16]])
92 | self.conv4.conv4_3.weight.data.copy_(pre_train[keys[18]])
93 | self.conv5.conv5_1.weight.data.copy_(pre_train[keys[20]])
94 | self.conv5.conv5_2.weight.data.copy_(pre_train[keys[22]])
95 | self.conv5.conv5_3.weight.data.copy_(pre_train[keys[24]])
96 |
97 | self.conv1.conv1_1.bias.data.copy_(pre_train[keys[1]])
98 | self.conv1.conv1_2.bias.data.copy_(pre_train[keys[3]])
99 | self.conv2.conv2_1.bias.data.copy_(pre_train[keys[5]])
100 | self.conv2.conv2_2.bias.data.copy_(pre_train[keys[7]])
101 | self.conv3.conv3_1.bias.data.copy_(pre_train[keys[9]])
102 | self.conv3.conv3_2.bias.data.copy_(pre_train[keys[11]])
103 | self.conv3.conv3_3.bias.data.copy_(pre_train[keys[13]])
104 | self.conv4.conv4_1.bias.data.copy_(pre_train[keys[15]])
105 | self.conv4.conv4_2.bias.data.copy_(pre_train[keys[17]])
106 | self.conv4.conv4_3.bias.data.copy_(pre_train[keys[19]])
107 | self.conv5.conv5_1.bias.data.copy_(pre_train[keys[21]])
108 | self.conv5.conv5_2.bias.data.copy_(pre_train[keys[23]])
109 | self.conv5.conv5_3.bias.data.copy_(pre_train[keys[25]])
110 |
--------------------------------------------------------------------------------
/pytorch_iou/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import torch
4 | import torch.nn.functional as F
5 | from torch.autograd import Variable
6 | import numpy as np
7 |
8 | def _iou(pred, target, size_average = True):
9 |
10 | b = pred.shape[0]
11 | IoU = 0.0
12 | for i in range(0,b):
13 | #compute the IoU of the foreground
14 | Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:])
15 | Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1
16 | IoU1 = Iand1/Ior1
17 |
18 | #IoU loss is (1-IoU1)
19 | IoU = IoU + (1-IoU1)
20 |
21 | return IoU/b
22 |
23 | class IOU(torch.nn.Module):
24 | def __init__(self, size_average = True):
25 | super(IOU, self).__init__()
26 | self.size_average = size_average
27 |
28 | def forward(self, pred, target):
29 |
30 | return _iou(pred, target, self.size_average)
31 |
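# A small usage sketch (following how train_ACCoNet.py combines the losses): the IoU
# term expects sigmoid probabilities, while BCEWithLogitsLoss takes the raw logits.
if __name__ == "__main__":
    logits = torch.randn(2, 1, 256, 256)
    gts = torch.randint(0, 2, (2, 1, 256, 256)).float()
    bce = torch.nn.BCEWithLogitsLoss()
    iou = IOU(size_average=True)
    loss = bce(logits, gts) + iou(torch.sigmoid(logits), gts)
    print(loss.item())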
--------------------------------------------------------------------------------
/pytorch_iou/__init__.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/ACCoNet/d37bd207286b9e522803d3695e7180e4941ae117/pytorch_iou/__init__.pyc
--------------------------------------------------------------------------------
/pytorch_iou/__pycache__/__init__.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MathLee/ACCoNet/d37bd207286b9e522803d3695e7180e4941ae117/pytorch_iou/__pycache__/__init__.cpython-36.pyc
--------------------------------------------------------------------------------
/test_ACCoNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 | import numpy as np
5 | import pdb, os, argparse
6 | from scipy import misc
7 | import time
8 |
9 | from model.ACCoNet_VGG_models import ACCoNet_VGG
10 | from model.ACCoNet_Res_models import ACCoNet_Res
11 | from data import test_dataset
12 |
13 | torch.cuda.set_device(0)
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--testsize', type=int, default=256, help='testing size')
16 | parser.add_argument('--is_ResNet', type=bool, default=True, help='VGG or ResNet backbone')
17 | opt = parser.parse_args()
18 |
19 | dataset_path = './dataset/test_dataset/'
20 |
21 | if opt.is_ResNet:
22 | model = ACCoNet_Res()
23 | model.load_state_dict(torch.load('./models/ACCoNet_ResNet/ACCoNet_Res.pth.39'))
24 | else:
25 | model = ACCoNet_VGG()
26 | model.load_state_dict(torch.load('./models/ACCoNet_VGG/ACCoNet_VGG.pth.54'))
27 |
28 | model.cuda()
29 | model.eval()
30 |
31 | # test_datasets = ['EORSSD']
32 | test_datasets = ['ORSSD']
33 |
34 | for dataset in test_datasets:
35 | if opt.is_ResNet:
36 | save_path = './results/ResNet50/' + dataset + '/'
37 | else:
38 | save_path = './results/VGG/' + dataset + '/'
39 | if not os.path.exists(save_path):
40 | os.makedirs(save_path)
41 | image_root = dataset_path + dataset + '/image/'
42 | print(dataset)
43 | gt_root = dataset_path + dataset + '/GT/'
44 | test_loader = test_dataset(image_root, gt_root, opt.testsize)
45 | time_sum = 0
46 | for i in range(test_loader.size):
47 | image, gt, name = test_loader.load_data()
48 | gt = np.asarray(gt, np.float32)
49 | gt /= (gt.max() + 1e-8)
50 | image = image.cuda()
51 | time_start = time.time()
52 | res, s2, s3, s4, s5, s1_sig, s2_sig, s3_sig, s4_sig, s5_sig = model(image)
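# only the first output (the full-resolution side output) is kept as the final saliency map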
53 | time_end = time.time()
54 | time_sum = time_sum+(time_end-time_start)
55 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False)
56 | res = res.sigmoid().data.cpu().numpy().squeeze()
57 | res = (res - res.min()) / (res.max() - res.min() + 1e-8)
58 | misc.imsave(save_path+name, res)
59 | if i == test_loader.size-1:
60 | print('Running time {:.5f}'.format(time_sum/test_loader.size))
61 | print('Average speed: {:.4f} fps'.format(test_loader.size/time_sum))
--------------------------------------------------------------------------------
/train_ACCoNet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.autograd import Variable
5 |
6 | import numpy as np
7 | import pdb, os, argparse
8 | from datetime import datetime
9 |
10 | from model.ACCoNet_VGG_models import ACCoNet_VGG
11 | from model.ACCoNet_Res_models import ACCoNet_Res
12 | from data import get_loader
13 | from utils import clip_gradient, adjust_lr
14 |
15 | import pytorch_iou
16 |
17 | torch.cuda.set_device(0)
18 | parser = argparse.ArgumentParser()
19 | parser.add_argument('--epoch', type=int, default=40, help='epoch number')
20 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
21 | # For VGG, the batchsize is 6; for ResNet, it is 8.
22 | parser.add_argument('--batchsize', type=int, default=6, help='training batch size')
23 | parser.add_argument('--trainsize', type=int, default=256, help='training dataset size')
24 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
25 | parser.add_argument('--is_ResNet', type=bool, default=False, help='VGG or ResNet backbone')
26 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate')
27 | parser.add_argument('--decay_epoch', type=int, default=30, help='every n epochs decay learning rate')
28 | opt = parser.parse_args()
29 |
30 | print('Learning Rate: {} ResNet: {}'.format(opt.lr, opt.is_ResNet))
31 | # build models
32 | if opt.is_ResNet:
33 | model = ACCoNet_Res()
34 | else:
35 | model = ACCoNet_VGG()
36 |
37 | model.cuda()
38 | params = model.parameters()
39 | optimizer = torch.optim.Adam(params, opt.lr)
40 |
41 | image_root = './dataset/train_dataset/ORSSD/train/image/'
42 | gt_root = './dataset/train_dataset/ORSSD/train/GT/'
43 | # image_root = './dataset/train_dataset/EORSSD/train/image/'
44 | # gt_root = './dataset/train_dataset/EORSSD/train/GT/'
45 | train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
46 | total_step = len(train_loader)
47 |
48 | CE = torch.nn.BCEWithLogitsLoss()
49 | IOU = pytorch_iou.IOU(size_average = True)
50 |
51 | def train(train_loader, model, optimizer, epoch):
52 | model.train()
53 | for i, pack in enumerate(train_loader, start=1):
54 | optimizer.zero_grad()
55 | images, gts = pack
56 | images = Variable(images)
57 | gts = Variable(gts)
58 | images = images.cuda()
59 | gts = gts.cuda()
60 |
61 | s1, s2, s3, s4, s5, s1_sig, s2_sig, s3_sig, s4_sig, s5_sig = model(images)
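# s1-s5 are side-output logits from the five decoder stages (deep supervision);
# the *_sig tensors are their sigmoid saliency maps, used by the IoU loss below.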
62 |
63 | loss1 = CE(s1, gts) + IOU(s1_sig, gts)
64 | loss2 = CE(s2, gts) + IOU(s2_sig, gts)
65 | loss3 = CE(s3, gts) + IOU(s3_sig, gts)
66 | loss4 = CE(s4, gts) + IOU(s4_sig, gts)
67 | loss5 = CE(s5, gts) + IOU(s5_sig, gts)
68 |
69 | loss = loss1 + loss2 + loss3 + loss4 + loss5
70 |
71 | loss.backward()
72 |
73 | clip_gradient(optimizer, opt.clip)
74 | optimizer.step()
75 |
76 | if i % 20 == 0 or i == total_step:
77 | print(
78 | '{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Learning Rate: {}, Loss: {:.4f}, Loss_s1: {:.4f}, Loss_s2: {:.4f}'.
79 | format(datetime.now(), epoch, opt.epoch, i, total_step, opt.lr * opt.decay_rate ** (epoch // opt.decay_epoch), loss.data, loss1.data, loss2.data))
80 |
81 | if opt.is_ResNet:
82 | save_path = 'models/ACCoNet_Res/'
83 | else:
84 | save_path = 'models/ACCoNet_VGG/'
85 |
86 | if not os.path.exists(save_path):
87 | os.makedirs(save_path)
88 | if (epoch+1) % 5 == 0:
89 | if opt.is_ResNet:
90 | torch.save(model.state_dict(), save_path + 'ACCoNet_ResNet.pth' + '.%d' % epoch)
91 | else:
92 | torch.save(model.state_dict(), save_path + 'ACCoNet_VGG.pth' + '.%d' % epoch)
93 |
94 | print("Let's go!")
95 | for epoch in range(1, opt.epoch):
96 | adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch)
97 | train(train_loader, model, optimizer, epoch)
98 |
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | def clip_gradient(optimizer, grad_clip):
2 | for group in optimizer.param_groups:
3 | for param in group['params']:
4 | if param.grad is not None:
5 | param.grad.data.clamp_(-grad_clip, grad_clip)
6 |
7 |
8 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30):
9 | decay = decay_rate ** (epoch // decay_epoch)
10 | for param_group in optimizer.param_groups:
11 | param_group['lr'] = init_lr*decay
12 | print('decay_epoch: {}, Current_LR: {}'.format(decay_epoch, init_lr*decay))
--------------------------------------------------------------------------------