├── README.md ├── data.py ├── data_aug.m ├── image ├── ACCoNet.png └── table.png ├── model ├── ACCoNet_Res_models.py ├── ACCoNet_VGG_models.py ├── ResNet50.py ├── __init__.py └── vgg.py ├── pytorch_iou ├── __init__.py ├── __init__.pyc └── __pycache__ │ └── __init__.cpython-36.pyc ├── test_ACCoNet.py ├── train_ACCoNet.py └── utils.py /README.md: -------------------------------------------------------------------------------- 1 | # ACCoNet 2 | This project provides the code and results for 'Adjacent Context Coordination Network for Salient Object Detection in Optical Remote Sensing Images', IEEE TCYB, 2023. [IEEE link](https://ieeexplore.ieee.org/document/9756652) and [arxiv link](https://arxiv.org/abs/2203.13664) [Homepage](https://mathlee.github.io/) 3 | 4 | # Network Architecture 5 |
6 | ![Image](https://github.com/MathLee/ACCoNet/blob/main/image/ACCoNet.png) 7 | 
8 | 9 | 10 | # Requirements 11 | python 2.7 + pytorch 0.4.0 or 12 | 13 | python 3.7 + pytorch 1.9.0 14 | 15 | 16 | # Saliency maps 17 | We provide saliency maps of our ACCoNet ([VGG_backbone](https://pan.baidu.com/s/11KzUltnKIwbYFbEXtud2gQ) (code: gr06) and [ResNet_backbone](https://pan.baidu.com/s/1_ksAXbRrMWupToCxcSDa8g) (code: 1hpn)) on ORSSD, EORSSD, and additional [ORSI-4199](https://github.com/wchao1213/ORSI-SOD) datasets. 18 | 19 | ![Image](https://github.com/MathLee/ACCoNet/blob/main/image/table.png) 20 | 21 | # Training 22 | 23 | We provide the code for ACCoNet_VGG and ACCoNet_ResNet, please modify '--is_ResNet' and the paths of datasets in train_ACCoNet.py. 24 | 25 | For ACCoNet_VGG, please modify paths of [VGG backbone](https://pan.baidu.com/s/1YQxKZ-y2C4EsqrgKNI7qrw) (code: ego5) in /model/vgg.py. 26 | 27 | data_aug.m is used for data augmentation. 28 | 29 | 30 | # Pre-trained model and testing 31 | 1. Download the following pre-trained models and put them in /models. 32 | 33 | 2. Modify paths of pre-trained models and datasets. 34 | 35 | 3. Run test_ACCoNet.py. 36 | 37 | ORSSD: [ACCoNet_VGG](https://pan.baidu.com/s/1mPb7oyaz9OVKs3T9v4xCmw) (code: 1bsg); [ACCoNet_ResNet](https://pan.baidu.com/s/1UhHLxgBvMgD66jz2SKgclw) (code: mv91). 38 | 39 | EORSSD: [ACCoNet_VGG](https://pan.baidu.com/s/1R2mFox8rEyxH1DTTnMinLA) (code: i016); [ACCoNet_ResNet](https://pan.baidu.com/s/1-TkZcxR6fBNYWKljhL1Qrg) (code: ak5m). 40 | 41 | ORSI-4199: [ACCoNet_VGG](https://pan.baidu.com/s/1WUVmVCwICBEM3gUJxQ5pkw) (code: qv05); [ACCoNet_ResNet](https://pan.baidu.com/s/1I4RWaLDx4ukK8_11y1AEtw) (code: art7). 42 | 43 | 44 | # Evaluation Tool 45 | You can use the [evaluation tool (MATLAB version)](https://github.com/MathLee/MatlabEvaluationTools) to evaluate the above saliency maps. 46 | 47 | 48 | # [ORSI-SOD_Summary](https://github.com/MathLee/ORSI-SOD_Summary) 49 | 50 | # Citation 51 | @ARTICLE{Li_2023_ACCoNet, 52 | author = {Gongyang Li and Zhi Liu and Dan Zeng and Weisi Lin and Haibin Ling}, 53 | title = {Adjacent Context Coordination Network for Salient Object Detection in Optical Remote Sensing Images}, 54 | journal = {IEEE Transactions on Cybernetics}, 55 | volume = {53}, 56 | number = {1}, 57 | pages = {526-538}, 58 | year = {2023}, 59 | month = {Jan.}, 60 | } 61 | 62 | 63 | If you encounter any problems with the code, want to report bugs, etc. 64 | 65 | Please contact me at lllmiemie@163.com or ligongyang@shu.edu.cn. 
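
For reference, the testing procedure described above reduces to loading a checkpoint and running one forward pass per image. The following is a minimal inference sketch distilled from test_ACCoNet.py; the checkpoint and dataset paths are placeholders to be replaced with your own, and only the first side output `s1` is used as the final saliency map.

```python
# Minimal inference sketch (paths/checkpoint names are placeholders; see test_ACCoNet.py for the full script).
import torch
import torch.nn.functional as F
from model.ACCoNet_VGG_models import ACCoNet_VGG
from data import test_dataset

model = ACCoNet_VGG()
model.load_state_dict(torch.load('./models/ACCoNet_VGG/ACCoNet_VGG.pth.54'))  # placeholder checkpoint
model.cuda().eval()

loader = test_dataset('./dataset/test_dataset/ORSSD/image/', './dataset/test_dataset/ORSSD/GT/', testsize=256)
with torch.no_grad():
    for _ in range(loader.size):
        image, gt, name = loader.load_data()
        s1 = model(image.cuda())[0]          # first (full-resolution) side output
        sal = torch.sigmoid(s1)              # saliency map in [0, 1]
        sal = F.interpolate(sal, size=gt.size[::-1], mode='bilinear', align_corners=False)
```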
66 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import random 6 | import numpy as np 7 | from PIL import ImageEnhance 8 | 9 | 10 | #several data augumentation strategies 11 | def cv_random_flip(img, label): 12 | flip_flag = random.randint(0, 1) 13 | # flip_flag2= random.randint(0,1) 14 | #left right flip 15 | if flip_flag == 1: 16 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 17 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 18 | #top bottom flip 19 | # if flip_flag2==1: 20 | # img = img.transpose(Image.FLIP_TOP_BOTTOM) 21 | # label = label.transpose(Image.FLIP_TOP_BOTTOM) 22 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM) 23 | return img, label 24 | def randomCrop(image, label): 25 | border=30 26 | image_width = image.size[0] 27 | image_height = image.size[1] 28 | crop_win_width = np.random.randint(image_width-border , image_width) 29 | crop_win_height = np.random.randint(image_height-border , image_height) 30 | random_region = ( 31 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 32 | (image_height + crop_win_height) >> 1) 33 | return image.crop(random_region), label.crop(random_region) 34 | def randomRotation(image,label): 35 | mode=Image.BICUBIC 36 | if random.random()>0.8: 37 | random_angle = np.random.randint(-15, 15) 38 | image=image.rotate(random_angle, mode) 39 | label=label.rotate(random_angle, mode) 40 | return image,label 41 | def colorEnhance(image): 42 | bright_intensity=random.randint(5,15)/10.0 43 | image=ImageEnhance.Brightness(image).enhance(bright_intensity) 44 | contrast_intensity=random.randint(5,15)/10.0 45 | image=ImageEnhance.Contrast(image).enhance(contrast_intensity) 46 | color_intensity=random.randint(0,20)/10.0 47 | image=ImageEnhance.Color(image).enhance(color_intensity) 48 | sharp_intensity=random.randint(0,30)/10.0 49 | image=ImageEnhance.Sharpness(image).enhance(sharp_intensity) 50 | return image 51 | def randomGaussian(image, mean=0.1, sigma=0.35): 52 | def gaussianNoisy(im, mean=mean, sigma=sigma): 53 | for _i in range(len(im)): 54 | im[_i] += random.gauss(mean, sigma) 55 | return im 56 | img = np.asarray(image) 57 | width, height = img.shape 58 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 59 | img = img.reshape([width, height]) 60 | return Image.fromarray(np.uint8(img)) 61 | def randomPeper(img): 62 | 63 | img=np.array(img) 64 | noiseNum=int(0.0015*img.shape[0]*img.shape[1]) 65 | for i in range(noiseNum): 66 | 67 | randX=random.randint(0,img.shape[0]-1) 68 | 69 | randY=random.randint(0,img.shape[1]-1) 70 | 71 | if random.randint(0,1)==0: 72 | 73 | img[randX,randY]=0 74 | 75 | else: 76 | 77 | img[randX,randY]=255 78 | return Image.fromarray(img) 79 | 80 | 81 | class SalObjDataset(data.Dataset): 82 | def __init__(self, image_root, gt_root, trainsize): 83 | self.trainsize = trainsize 84 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 85 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 86 | or f.endswith('.png')] 87 | self.images = sorted(self.images) 88 | # self.depths = sorted(self.depths) 89 | self.gts = sorted(self.gts) 90 | self.filter_files() 91 | self.size = len(self.images) 92 | self.img_transform = transforms.Compose([ 93 | 
transforms.Resize((self.trainsize, self.trainsize)), 94 | transforms.ToTensor(), 95 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 96 | 97 | self.gt_transform = transforms.Compose([ 98 | transforms.Resize((self.trainsize, self.trainsize)), 99 | transforms.ToTensor()]) 100 | 101 | def __getitem__(self, index): 102 | image = self.rgb_loader(self.images[index]) 103 | gt = self.binary_loader(self.gts[index]) 104 | # image,gt =cv_randop0;l....... 105 | image = self.img_transform(image) 106 | gt = self.gt_transform(gt) 107 | return image, gt 108 | 109 | def filter_files(self): 110 | assert len(self.images) == len(self.gts) 111 | images = [] 112 | # depths = [] 113 | gts = [] 114 | for img_path, gt_path in zip(self.images, self.gts): 115 | img = Image.open(img_path) 116 | gt = Image.open(gt_path) 117 | if img.size == gt.size: 118 | images.append(img_path) 119 | gts.append(gt_path) 120 | self.images = images 121 | self.gts = gts 122 | 123 | def rgb_loader(self, path): 124 | with open(path, 'rb') as f: 125 | img = Image.open(f) 126 | return img.convert('RGB') 127 | 128 | def binary_loader(self, path): 129 | with open(path, 'rb') as f: 130 | img = Image.open(f) 131 | # return img.convert('1') 132 | return img.convert('L') 133 | 134 | def resize(self, img, gt): 135 | assert img.size == gt.size 136 | w, h = img.size 137 | if h < self.trainsize or w < self.trainsize: 138 | h = max(h, self.trainsize) 139 | w = max(w, self.trainsize) 140 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST) 141 | else: 142 | return img, gt 143 | 144 | def __len__(self): 145 | return self.size 146 | 147 | 148 | def get_loader(image_root, gt_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True): 149 | 150 | dataset = SalObjDataset(image_root, gt_root, trainsize) 151 | data_loader = data.DataLoader(dataset=dataset, 152 | batch_size=batchsize, 153 | shuffle=shuffle, 154 | num_workers=num_workers, 155 | pin_memory=pin_memory) 156 | return data_loader 157 | 158 | 159 | class test_dataset: 160 | def __init__(self, image_root, gt_root, testsize): 161 | self.testsize = testsize 162 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 163 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 164 | or f.endswith('.png')] 165 | self.images = sorted(self.images) 166 | # self.depths = sorted(self.depths) 167 | self.gts = sorted(self.gts) 168 | self.img_transform = transforms.Compose([ 169 | transforms.Resize((self.testsize, self.testsize)), 170 | transforms.ToTensor(), 171 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 172 | self.gt_transform = transforms.ToTensor() 173 | self.size = len(self.images) 174 | self.index = 0 175 | 176 | def load_data(self): 177 | image = self.rgb_loader(self.images[self.index]) 178 | image = self.img_transform(image).unsqueeze(0) 179 | gt = self.binary_loader(self.gts[self.index]) 180 | name = self.images[self.index].split('/')[-1] 181 | if name.endswith('.jpg'): 182 | name = name.split('.jpg')[0] + '.png' 183 | self.index += 1 184 | return image, gt, name 185 | 186 | def rgb_loader(self, path): 187 | with open(path, 'rb') as f: 188 | img = Image.open(f) 189 | return img.convert('RGB') 190 | 191 | def binary_loader(self, path): 192 | with open(path, 'rb') as f: 193 | img = Image.open(f) 194 | return img.convert('L') -------------------------------------------------------------------------------- /data_aug.m: 
-------------------------------------------------------------------------------- 1 | clear; 2 | clc; 3 | close all; 4 | 5 | imPath = '/home/lgy/桌面/ORSI_SOD/dataset/EORSSD/train/image/'; 6 | GtPath = '/home/lgy/桌面/ORSI_SOD/dataset/EORSSD/train/GT/'; 7 | EGPath = '/home/lgy/桌面/ORSI_SOD/dataset/EORSSD/train/edge/'; 8 | 9 | images = dir([GtPath '*.png']); 10 | imagesNum = length(images); 11 | 12 | for i = 1 : imagesNum 13 | im_name = images(i).name(1:end-4); 14 | 15 | gt = imread(fullfile(GtPath, [im_name '.png'])); 16 | im = imread(fullfile(imPath, [im_name '.jpg'])); 17 | eg = imread(fullfile(EGPath, [im_name '.png'])); 18 | 19 | im_1 = imrotate(im,90); 20 | gt_1 = imrotate(gt,90); 21 | eg_1 = imrotate(eg,90); 22 | imwrite(im_1, fullfile(imPath, [im_name '_90.jpg'])); 23 | imwrite(gt_1, fullfile(GtPath, [im_name '_90.png'])); 24 | imwrite(eg_1, fullfile(EGPath, [im_name '_90.png'])); 25 | 26 | im_2 = imrotate(im, 180); 27 | gt_2 = imrotate(gt, 180); 28 | eg_2 = imrotate(eg, 180); 29 | imwrite(im_2, fullfile(imPath, [im_name '_180.jpg'])); 30 | imwrite(gt_2, fullfile(GtPath, [im_name '_180.png'])); 31 | imwrite(eg_2, fullfile(EGPath, [im_name '_180.png'])); 32 | 33 | im_3 = imrotate(im, 270); 34 | gt_3 = imrotate(gt, 270); 35 | eg_3 = imrotate(eg, 270); 36 | imwrite(im_3, fullfile(imPath, [im_name '_270.jpg'])); 37 | imwrite(gt_3, fullfile(GtPath, [im_name '_270.png'])); 38 | imwrite(eg_3, fullfile(EGPath, [im_name '_270.png'])); 39 | 40 | 41 | fl_im = fliplr(im); 42 | fl_gt = fliplr(gt); 43 | fl_eg = fliplr(eg); 44 | imwrite(fl_im, fullfile(imPath, [im_name '_fl.jpg'])); 45 | imwrite(fl_gt, fullfile(GtPath, [im_name '_fl.png'])); 46 | imwrite(fl_eg, fullfile(EGPath, [im_name '_fl.png'])); 47 | 48 | im_1 = imrotate(fl_im,90); 49 | gt_1 = imrotate(fl_gt,90); 50 | eg_1 = imrotate(fl_eg,90); 51 | imwrite(im_1, fullfile(imPath, [im_name '_fl90.jpg'])); 52 | imwrite(gt_1, fullfile(GtPath, [im_name '_fl90.png'])); 53 | imwrite(eg_1, fullfile(EGPath, [im_name '_fl90.png'])); 54 | 55 | im_2 = imrotate(fl_im, 180); 56 | gt_2 = imrotate(fl_gt, 180); 57 | eg_2 = imrotate(fl_eg, 180); 58 | imwrite(im_2, fullfile(imPath, [im_name '_fl180.jpg'])); 59 | imwrite(gt_2, fullfile(GtPath, [im_name '_fl180.png'])); 60 | imwrite(eg_2, fullfile(EGPath, [im_name '_fl180.png'])); 61 | 62 | im_3 = imrotate(fl_im, 270); 63 | gt_3 = imrotate(fl_gt, 270); 64 | eg_3 = imrotate(fl_eg, 270); 65 | imwrite(im_3, fullfile(imPath, [im_name '_fl270.jpg'])); 66 | imwrite(gt_3, fullfile(GtPath, [im_name '_fl270.png'])); 67 | imwrite(eg_3, fullfile(EGPath, [im_name '_fl270.png'])); 68 | 69 | end 70 | 71 | -------------------------------------------------------------------------------- /image/ACCoNet.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/ACCoNet/d37bd207286b9e522803d3695e7180e4941ae117/image/ACCoNet.png -------------------------------------------------------------------------------- /image/table.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/ACCoNet/d37bd207286b9e522803d3695e7180e4941ae117/image/table.png -------------------------------------------------------------------------------- /model/ACCoNet_Res_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from ResNet50 import Backbone_ResNet50_in3 5 | 6 | 7 | class BasicConv2d(nn.Module): 8 | 
def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 9 | super(BasicConv2d, self).__init__() 10 | self.conv = nn.Conv2d(in_planes, out_planes, 11 | kernel_size=kernel_size, stride=stride, 12 | padding=padding, dilation=dilation, bias=False) 13 | self.bn = nn.BatchNorm2d(out_planes) 14 | self.relu = nn.ReLU(inplace=True) 15 | 16 | def forward(self, x): 17 | x = self.conv(x) 18 | x = self.bn(x) 19 | x = self.relu(x) 20 | return x 21 | 22 | class TransBasicConv2d(nn.Module): 23 | def __init__(self, in_planes, out_planes, kernel_size=2, stride=2, padding=0, dilation=1, bias=False): 24 | super(TransBasicConv2d, self).__init__() 25 | self.Deconv = nn.ConvTranspose2d(in_planes, out_planes, 26 | kernel_size=kernel_size, stride=stride, 27 | padding=padding, dilation=dilation, bias=False) 28 | self.bn = nn.BatchNorm2d(out_planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | 31 | def forward(self, x): 32 | x = self.Deconv(x) 33 | x = self.bn(x) 34 | x = self.relu(x) 35 | return x 36 | 37 | 38 | class ChannelAttention(nn.Module): 39 | def __init__(self, in_planes, ratio=16): 40 | super(ChannelAttention, self).__init__() 41 | 42 | self.max_pool = nn.AdaptiveMaxPool2d(1) 43 | 44 | self.fc1 = nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 45 | self.relu1 = nn.ReLU() 46 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 47 | 48 | self.sigmoid = nn.Sigmoid() 49 | 50 | def forward(self, x): 51 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 52 | out = max_out 53 | return self.sigmoid(out) 54 | 55 | 56 | class SpatialAttention(nn.Module): 57 | def __init__(self, kernel_size=7): 58 | super(SpatialAttention, self).__init__() 59 | 60 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 61 | padding = 3 if kernel_size == 7 else 1 62 | 63 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False) 64 | self.sigmoid = nn.Sigmoid() 65 | 66 | def forward(self, x): 67 | max_out, _ = torch.max(x, dim=1, keepdim=True) 68 | x = max_out 69 | x = self.conv1(x) 70 | return self.sigmoid(x) 71 | 72 | # for conv5 73 | class ACCoM_5(nn.Module): 74 | def __init__(self, cur_channel): 75 | super(ACCoM_5, self).__init__() 76 | self.relu = nn.ReLU(True) 77 | 78 | # current conv 79 | self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1) 80 | self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2) 81 | self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3) 82 | self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4) 83 | 84 | self.cur_all = BasicConv2d(4*cur_channel, cur_channel, 3, padding=1) 85 | self.cur_all_ca = ChannelAttention(cur_channel) 86 | self.cur_all_sa = SpatialAttention() 87 | 88 | # previous conv 89 | self.downsample2 = nn.MaxPool2d(2, stride=2) 90 | self.pre_sa = SpatialAttention() 91 | 92 | # for m in self.modules(): 93 | # if isinstance(m, nn.Conv2d): 94 | # m.weight.data.normal_(std=0.01) 95 | # m.bias.data.fill_(0) 96 | 97 | def forward(self, x_pre, x_cur): 98 | # current conv 99 | x_cur_1 = self.cur_b1(x_cur) 100 | x_cur_2 = self.cur_b2(x_cur) 101 | x_cur_3 = self.cur_b3(x_cur) 102 | x_cur_4 = self.cur_b4(x_cur) 103 | x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1)) 104 | cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all)) 105 | cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca)) 106 | 107 | # previois conv 108 | x_pre = self.downsample2(x_pre) 109 | pre_sa = x_cur_all.mul(self.pre_sa(x_pre)) 110 | 
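# fusion: attention-refined current-level features, plus x_cur_all gated by spatial attention from the (downsampled) previous level, plus the residual input x_cur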
111 | x_LocAndGlo = cur_all_sa + pre_sa + x_cur 112 | 113 | return x_LocAndGlo 114 | 115 | # for conv1 116 | class ACCoM_1(nn.Module): 117 | def __init__(self, cur_channel): 118 | super(ACCoM_1, self).__init__() 119 | self.relu = nn.ReLU(True) 120 | 121 | # current conv 122 | self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1) 123 | self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2) 124 | self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3) 125 | self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4) 126 | 127 | self.cur_all = BasicConv2d(4*cur_channel, cur_channel, 3, padding=1) 128 | self.cur_all_ca = ChannelAttention(cur_channel) 129 | self.cur_all_sa = SpatialAttention() 130 | 131 | # latter conv 132 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 133 | self.lat_sa = SpatialAttention() 134 | 135 | # for m in self.modules(): 136 | # if isinstance(m, nn.Conv2d): 137 | # m.weight.data.normal_(std=0.01) 138 | # m.bias.data.fill_(0) 139 | 140 | def forward(self, x_cur, x_lat): 141 | # current conv 142 | x_cur_1 = self.cur_b1(x_cur) 143 | x_cur_2 = self.cur_b2(x_cur) 144 | x_cur_3 = self.cur_b3(x_cur) 145 | x_cur_4 = self.cur_b4(x_cur) 146 | x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1)) 147 | cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all)) 148 | cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca)) 149 | 150 | # latter conv 151 | x_lat = self.upsample2(x_lat) 152 | lat_sa = x_cur_all.mul(self.lat_sa(x_lat)) 153 | 154 | x_LocAndGlo = cur_all_sa + lat_sa + x_cur 155 | 156 | return x_LocAndGlo 157 | 158 | # for conv2/3/4 159 | class ACCoM(nn.Module): 160 | def __init__(self, cur_channel): 161 | super(ACCoM, self).__init__() 162 | self.relu = nn.ReLU(True) 163 | 164 | # current conv 165 | self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1) 166 | self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2) 167 | self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3) 168 | self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4) 169 | 170 | self.cur_all = BasicConv2d(4 * cur_channel, cur_channel, 3, padding=1) 171 | self.cur_all_ca = ChannelAttention(cur_channel) 172 | self.cur_all_sa = SpatialAttention() 173 | 174 | # previous conv 175 | self.downsample2 = nn.MaxPool2d(2, stride=2) 176 | self.pre_sa = SpatialAttention() 177 | 178 | # latter conv 179 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 180 | self.lat_sa = SpatialAttention() 181 | 182 | # for m in self.modules(): 183 | # if isinstance(m, nn.Conv2d): 184 | # m.weight.data.normal_(std=0.01) 185 | # m.bias.data.fill_(0) 186 | 187 | def forward(self, x_pre, x_cur, x_lat): 188 | # current conv 189 | x_cur_1 = self.cur_b1(x_cur) 190 | x_cur_2 = self.cur_b2(x_cur) 191 | x_cur_3 = self.cur_b3(x_cur) 192 | x_cur_4 = self.cur_b4(x_cur) 193 | x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1)) 194 | cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all)) 195 | cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca)) 196 | 197 | # previois conv 198 | x_pre = self.downsample2(x_pre) 199 | pre_sa = x_cur_all.mul(self.pre_sa(x_pre)) 200 | 201 | # latter conv 202 | x_lat = self.upsample2(x_lat) 203 | lat_sa = x_cur_all.mul(self.lat_sa(x_lat)) 204 | 205 | x_LocAndGlo = cur_all_sa + pre_sa + lat_sa + x_cur 206 | 207 | return x_LocAndGlo 
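# BAB_Decoder: stacked 3x3 convs with dilated side branches; the branch outputs are concatenated and fused by conv_fuse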
208 | 209 | 210 | class BAB_Decoder(nn.Module): 211 | def __init__(self, channel_1=1024, channel_2=512, channel_3=256, dilation_1=3, dilation_2=2): 212 | super(BAB_Decoder, self).__init__() 213 | 214 | self.conv1 = BasicConv2d(channel_1, channel_2, 3, padding=1) 215 | self.conv1_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_1, dilation=dilation_1) 216 | self.conv2 = BasicConv2d(channel_2, channel_2, 3, padding=1) 217 | self.conv2_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_2, dilation=dilation_2) 218 | self.conv3 = BasicConv2d(channel_2, channel_2, 3, padding=1) 219 | self.conv_fuse = BasicConv2d(channel_2*3, channel_3, 3, padding=1) 220 | 221 | def forward(self, x): 222 | x1 = self.conv1(x) 223 | x1_dila = self.conv1_Dila(x1) 224 | 225 | x2 = self.conv2(x1) 226 | x2_dila = self.conv2_Dila(x2) 227 | 228 | x3 = self.conv3(x2) 229 | 230 | x_fuse = self.conv_fuse(torch.cat((x1_dila, x2_dila, x3), 1)) 231 | 232 | return x_fuse 233 | 234 | 235 | class decoder(nn.Module): 236 | def __init__(self, channel=512): 237 | super(decoder, self).__init__() 238 | self.relu = nn.ReLU(True) 239 | 240 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 241 | 242 | self.decoder5 = nn.Sequential( 243 | BAB_Decoder(512, 512, 512, 3, 2), 244 | nn.Dropout(0.5), 245 | TransBasicConv2d(512, 512, kernel_size=2, stride=2, 246 | padding=0, dilation=1, bias=False) 247 | ) 248 | self.S5 = nn.Conv2d(512, 1, 3, stride=1, padding=1) 249 | 250 | self.decoder4 = nn.Sequential( 251 | BAB_Decoder(1024, 512, 256, 3, 2), 252 | nn.Dropout(0.5), 253 | TransBasicConv2d(256, 256, kernel_size=2, stride=2, 254 | padding=0, dilation=1, bias=False) 255 | ) 256 | self.S4 = nn.Conv2d(256, 1, 3, stride=1, padding=1) 257 | 258 | self.decoder3 = nn.Sequential( 259 | BAB_Decoder(512, 256, 128, 5, 3), 260 | nn.Dropout(0.5), 261 | TransBasicConv2d(128, 128, kernel_size=2, stride=2, 262 | padding=0, dilation=1, bias=False) 263 | ) 264 | self.S3 = nn.Conv2d(128, 1, 3, stride=1, padding=1) 265 | 266 | self.decoder2 = nn.Sequential( 267 | BAB_Decoder(256, 128, 64, 5, 3), 268 | nn.Dropout(0.5), 269 | TransBasicConv2d(64, 64, kernel_size=2, stride=2, 270 | padding=0, dilation=1, bias=False) 271 | ) 272 | self.S2 = nn.Conv2d(64, 1, 3, stride=1, padding=1) 273 | 274 | self.decoder1 = nn.Sequential( 275 | BAB_Decoder(128, 64, 32, 5, 3) 276 | ) 277 | self.S1 = nn.Conv2d(32, 1, 3, stride=1, padding=1) 278 | 279 | 280 | def forward(self, x5, x4, x3, x2, x1): 281 | # x5: 1/16, 512; x4: 1/8, 512; x3: 1/4, 256; x2: 1/2, 128; x1: 1/1, 64 282 | x5_up = self.decoder5(x5) 283 | s5 = self.S5(x5_up) 284 | 285 | x4_up = self.decoder4(torch.cat((x4, x5_up), 1)) 286 | s4 = self.S4(x4_up) 287 | 288 | x3_up = self.decoder3(torch.cat((x3, x4_up), 1)) 289 | s3 = self.S3(x3_up) 290 | 291 | x2_up = self.decoder2(torch.cat((x2, x3_up), 1)) 292 | s2 = self.S2(x2_up) 293 | 294 | x1_up = self.decoder1(torch.cat((x1, x2_up), 1)) 295 | s1 = self.S1(x1_up) 296 | 297 | return s1, s2, s3, s4, s5 298 | 299 | 300 | class ACCoNet_Res(nn.Module): 301 | def __init__(self, channel=32): 302 | super(ACCoNet_Res, self).__init__() 303 | #Backbone model 304 | # ---- ResNet50 Backbone ---- 305 | ( 306 | self.encoder1, 307 | self.encoder2, 308 | self.encoder4, 309 | self.encoder8, 310 | self.encoder16, 311 | ) = Backbone_ResNet50_in3() 312 | 313 | # Lateral layers 314 | self.lateral_conv0 = BasicConv2d(64, 64, 3, stride=1, padding=1) 315 | self.lateral_conv1 = BasicConv2d(256, 128, 3, stride=1, padding=1) 316 | self.lateral_conv2 = 
BasicConv2d(512, 256, 3, stride=1, padding=1) 317 | self.lateral_conv3 = BasicConv2d(1024, 512, 3, stride=1, padding=1) 318 | self.lateral_conv4 = BasicConv2d(2048, 512, 3, stride=1, padding=1) 319 | 320 | self.ACCoM5 = ACCoM_5(512) 321 | self.ACCoM4 = ACCoM(512) 322 | self.ACCoM3 = ACCoM(256) 323 | self.ACCoM2 = ACCoM(128) 324 | self.ACCoM1 = ACCoM_1(64) 325 | 326 | # self.agg2_rgbd = aggregation(channel) 327 | self.decoder_rgb = decoder(512) 328 | 329 | self.upsample16 = nn.Upsample(scale_factor=16, mode='bilinear', align_corners=True) 330 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 331 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 332 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 333 | 334 | self.sigmoid = nn.Sigmoid() 335 | 336 | 337 | def forward(self, x_rgb): 338 | x0 = self.encoder1(x_rgb) 339 | x1 = self.encoder2(x0) 340 | x2 = self.encoder4(x1) 341 | x3 = self.encoder8(x2) 342 | x4 = self.encoder16(x3) 343 | 344 | x1_rgb = self.lateral_conv0(x0) 345 | x2_rgb = self.lateral_conv1(x1) 346 | x3_rgb = self.lateral_conv2(x2) 347 | x4_rgb = self.lateral_conv3(x3) 348 | x5_rgb = self.lateral_conv4(x4) 349 | 350 | 351 | # up means update 352 | x5_ACCoM = self.ACCoM5(x4_rgb, x5_rgb) 353 | x4_ACCoM = self.ACCoM4(x3_rgb, x4_rgb, x5_rgb) 354 | x3_ACCoM = self.ACCoM3(x2_rgb, x3_rgb, x4_rgb) 355 | x2_ACCoM = self.ACCoM2(x1_rgb, x2_rgb, x3_rgb) 356 | x1_ACCoM = self.ACCoM1(x1_rgb, x2_rgb) 357 | 358 | s1, s2, s3, s4, s5 = self.decoder_rgb(x5_ACCoM, x4_ACCoM, x3_ACCoM, x2_ACCoM, x1_ACCoM) 359 | # At test phase, we can use the HA to post-processing our saliency map 360 | s1 = self.upsample2(s1) 361 | s2 = self.upsample2(s2) 362 | s3 = self.upsample4(s3) 363 | s4 = self.upsample8(s4) 364 | s5 = self.upsample16(s5) 365 | 366 | return s1, s2, s3, s4, s5, self.sigmoid(s1), self.sigmoid(s2), self.sigmoid(s3), self.sigmoid(s4), self.sigmoid(s5) -------------------------------------------------------------------------------- /model/ACCoNet_VGG_models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from vgg import VGG 6 | 7 | class BasicConv2d(nn.Module): 8 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 9 | super(BasicConv2d, self).__init__() 10 | self.conv = nn.Conv2d(in_planes, out_planes, 11 | kernel_size=kernel_size, stride=stride, 12 | padding=padding, dilation=dilation, bias=False) 13 | self.bn = nn.BatchNorm2d(out_planes) 14 | self.relu = nn.ReLU(inplace=True) 15 | 16 | def forward(self, x): 17 | x = self.conv(x) 18 | x = self.bn(x) 19 | x = self.relu(x) 20 | return x 21 | 22 | class TransBasicConv2d(nn.Module): 23 | def __init__(self, in_planes, out_planes, kernel_size=2, stride=2, padding=0, dilation=1, bias=False): 24 | super(TransBasicConv2d, self).__init__() 25 | self.Deconv = nn.ConvTranspose2d(in_planes, out_planes, 26 | kernel_size=kernel_size, stride=stride, 27 | padding=padding, dilation=dilation, bias=False) 28 | self.bn = nn.BatchNorm2d(out_planes) 29 | self.relu = nn.ReLU(inplace=True) 30 | 31 | def forward(self, x): 32 | x = self.Deconv(x) 33 | x = self.bn(x) 34 | x = self.relu(x) 35 | return x 36 | 37 | 38 | class ChannelAttention(nn.Module): 39 | def __init__(self, in_planes, ratio=16): 40 | super(ChannelAttention, self).__init__() 41 | 42 | self.max_pool = nn.AdaptiveMaxPool2d(1) 43 | 44 | self.fc1 = 
nn.Conv2d(in_planes, in_planes // 16, 1, bias=False) 45 | self.relu1 = nn.ReLU() 46 | self.fc2 = nn.Conv2d(in_planes // 16, in_planes, 1, bias=False) 47 | 48 | self.sigmoid = nn.Sigmoid() 49 | 50 | def forward(self, x): 51 | max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x)))) 52 | out = max_out 53 | return self.sigmoid(out) 54 | 55 | 56 | class SpatialAttention(nn.Module): 57 | def __init__(self, kernel_size=7): 58 | super(SpatialAttention, self).__init__() 59 | 60 | assert kernel_size in (3, 7), 'kernel size must be 3 or 7' 61 | padding = 3 if kernel_size == 7 else 1 62 | 63 | self.conv1 = nn.Conv2d(1, 1, kernel_size, padding=padding, bias=False) 64 | self.sigmoid = nn.Sigmoid() 65 | 66 | def forward(self, x): 67 | max_out, _ = torch.max(x, dim=1, keepdim=True) 68 | x = max_out 69 | x = self.conv1(x) 70 | return self.sigmoid(x) 71 | 72 | # for conv5 73 | class ACCoM5(nn.Module): 74 | def __init__(self, cur_channel): 75 | super(ACCoM5, self).__init__() 76 | self.relu = nn.ReLU(True) 77 | 78 | # current conv 79 | self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1) 80 | self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2) 81 | self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3) 82 | self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4) 83 | 84 | self.cur_all = BasicConv2d(4*cur_channel, cur_channel, 3, padding=1) 85 | self.cur_all_ca = ChannelAttention(cur_channel) 86 | self.cur_all_sa = SpatialAttention() 87 | 88 | # previous conv 89 | self.downsample2 = nn.MaxPool2d(2, stride=2) 90 | self.pre_sa = SpatialAttention() 91 | 92 | def forward(self, x_pre, x_cur): 93 | # current conv 94 | x_cur_1 = self.cur_b1(x_cur) 95 | x_cur_2 = self.cur_b2(x_cur) 96 | x_cur_3 = self.cur_b3(x_cur) 97 | x_cur_4 = self.cur_b4(x_cur) 98 | x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1)) 99 | cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all)) 100 | cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca)) 101 | 102 | # previois conv 103 | x_pre = self.downsample2(x_pre) 104 | pre_sa = x_cur_all.mul(self.pre_sa(x_pre)) 105 | 106 | x_LocAndGlo = cur_all_sa + pre_sa + x_cur 107 | 108 | return x_LocAndGlo 109 | 110 | # for conv1 111 | class ACCoM1(nn.Module): 112 | def __init__(self, cur_channel): 113 | super(ACCoM1, self).__init__() 114 | self.relu = nn.ReLU(True) 115 | 116 | # current conv 117 | self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1) 118 | self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2) 119 | self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3) 120 | self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4) 121 | 122 | self.cur_all = BasicConv2d(4*cur_channel, cur_channel, 3, padding=1) 123 | self.cur_all_ca = ChannelAttention(cur_channel) 124 | self.cur_all_sa = SpatialAttention() 125 | 126 | # latter conv 127 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 128 | self.lat_sa = SpatialAttention() 129 | 130 | def forward(self, x_cur, x_lat): 131 | # current conv 132 | x_cur_1 = self.cur_b1(x_cur) 133 | x_cur_2 = self.cur_b2(x_cur) 134 | x_cur_3 = self.cur_b3(x_cur) 135 | x_cur_4 = self.cur_b4(x_cur) 136 | x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1)) 137 | cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all)) 138 | cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca)) 139 | 140 | # latter conv 141 | 
x_lat = self.upsample2(x_lat) 142 | lat_sa = x_cur_all.mul(self.lat_sa(x_lat)) 143 | 144 | x_LocAndGlo = cur_all_sa + lat_sa + x_cur 145 | 146 | return x_LocAndGlo 147 | 148 | # for conv2/3/4 149 | class ACCoM(nn.Module): 150 | def __init__(self, cur_channel): 151 | super(ACCoM, self).__init__() 152 | self.relu = nn.ReLU(True) 153 | 154 | # current conv 155 | self.cur_b1 = BasicConv2d(cur_channel, cur_channel, 3, padding=1, dilation=1) 156 | self.cur_b2 = BasicConv2d(cur_channel, cur_channel, 3, padding=2, dilation=2) 157 | self.cur_b3 = BasicConv2d(cur_channel, cur_channel, 3, padding=3, dilation=3) 158 | self.cur_b4 = BasicConv2d(cur_channel, cur_channel, 3, padding=4, dilation=4) 159 | 160 | self.cur_all = BasicConv2d(4 * cur_channel, cur_channel, 3, padding=1) 161 | self.cur_all_ca = ChannelAttention(cur_channel) 162 | self.cur_all_sa = SpatialAttention() 163 | 164 | # previous conv 165 | self.downsample2 = nn.MaxPool2d(2, stride=2) 166 | self.pre_sa = SpatialAttention() 167 | 168 | # latter conv 169 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 170 | self.lat_sa = SpatialAttention() 171 | 172 | def forward(self, x_pre, x_cur, x_lat): 173 | # current conv 174 | x_cur_1 = self.cur_b1(x_cur) 175 | x_cur_2 = self.cur_b2(x_cur) 176 | x_cur_3 = self.cur_b3(x_cur) 177 | x_cur_4 = self.cur_b4(x_cur) 178 | x_cur_all = self.cur_all(torch.cat((x_cur_1, x_cur_2, x_cur_3, x_cur_4), 1)) 179 | cur_all_ca = x_cur_all.mul(self.cur_all_ca(x_cur_all)) 180 | cur_all_sa = x_cur_all.mul(self.cur_all_sa(cur_all_ca)) 181 | 182 | # previois conv 183 | x_pre = self.downsample2(x_pre) 184 | pre_sa = x_cur_all.mul(self.pre_sa(x_pre)) 185 | 186 | # latter conv 187 | x_lat = self.upsample2(x_lat) 188 | lat_sa = x_cur_all.mul(self.lat_sa(x_lat)) 189 | 190 | x_LocAndGlo = cur_all_sa + pre_sa + lat_sa + x_cur 191 | 192 | return x_LocAndGlo 193 | 194 | 195 | class BAB_Decoder(nn.Module): 196 | def __init__(self, channel_1=1024, channel_2=512, channel_3=256, dilation_1=3, dilation_2=2): 197 | super(BAB_Decoder, self).__init__() 198 | 199 | self.conv1 = BasicConv2d(channel_1, channel_2, 3, padding=1) 200 | self.conv1_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_1, dilation=dilation_1) 201 | self.conv2 = BasicConv2d(channel_2, channel_2, 3, padding=1) 202 | self.conv2_Dila = BasicConv2d(channel_2, channel_2, 3, padding=dilation_2, dilation=dilation_2) 203 | self.conv3 = BasicConv2d(channel_2, channel_2, 3, padding=1) 204 | self.conv_fuse = BasicConv2d(channel_2*3, channel_3, 3, padding=1) 205 | 206 | def forward(self, x): 207 | x1 = self.conv1(x) 208 | x1_dila = self.conv1_Dila(x1) 209 | 210 | x2 = self.conv2(x1) 211 | x2_dila = self.conv2_Dila(x2) 212 | 213 | x3 = self.conv3(x2) 214 | 215 | x_fuse = self.conv_fuse(torch.cat((x1_dila, x2_dila, x3), 1)) 216 | 217 | return x_fuse 218 | 219 | 220 | class decoder(nn.Module): 221 | def __init__(self, channel=512): 222 | super(decoder, self).__init__() 223 | self.relu = nn.ReLU(True) 224 | 225 | self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 226 | 227 | self.decoder5 = nn.Sequential( 228 | BAB_Decoder(512, 512, 512, 3, 2), 229 | nn.Dropout(0.5), 230 | TransBasicConv2d(512, 512, kernel_size=2, stride=2, 231 | padding=0, dilation=1, bias=False) 232 | ) 233 | self.S5 = nn.Conv2d(512, 1, 3, stride=1, padding=1) 234 | 235 | self.decoder4 = nn.Sequential( 236 | BAB_Decoder(1024, 512, 256, 3, 2), 237 | nn.Dropout(0.5), 238 | TransBasicConv2d(256, 256, kernel_size=2, stride=2, 239 | padding=0, 
dilation=1, bias=False) 240 | ) 241 | self.S4 = nn.Conv2d(256, 1, 3, stride=1, padding=1) 242 | 243 | self.decoder3 = nn.Sequential( 244 | BAB_Decoder(512, 256, 128, 5, 3), 245 | nn.Dropout(0.5), 246 | TransBasicConv2d(128, 128, kernel_size=2, stride=2, 247 | padding=0, dilation=1, bias=False) 248 | ) 249 | self.S3 = nn.Conv2d(128, 1, 3, stride=1, padding=1) 250 | 251 | self.decoder2 = nn.Sequential( 252 | BAB_Decoder(256, 128, 64, 5, 3), 253 | nn.Dropout(0.5), 254 | TransBasicConv2d(64, 64, kernel_size=2, stride=2, 255 | padding=0, dilation=1, bias=False) 256 | ) 257 | self.S2 = nn.Conv2d(64, 1, 3, stride=1, padding=1) 258 | 259 | self.decoder1 = nn.Sequential( 260 | BAB_Decoder(128, 64, 32, 5, 3) 261 | ) 262 | self.S1 = nn.Conv2d(32, 1, 3, stride=1, padding=1) 263 | 264 | 265 | def forward(self, x5, x4, x3, x2, x1): 266 | # x5: 1/16, 512; x4: 1/8, 512; x3: 1/4, 256; x2: 1/2, 128; x1: 1/1, 64 267 | x5_up = self.decoder5(x5) 268 | s5 = self.S5(x5_up) 269 | 270 | x4_up = self.decoder4(torch.cat((x4, x5_up), 1)) 271 | s4 = self.S4(x4_up) 272 | 273 | x3_up = self.decoder3(torch.cat((x3, x4_up), 1)) 274 | s3 = self.S3(x3_up) 275 | 276 | x2_up = self.decoder2(torch.cat((x2, x3_up), 1)) 277 | s2 = self.S2(x2_up) 278 | 279 | x1_up = self.decoder1(torch.cat((x1, x2_up), 1)) 280 | s1 = self.S1(x1_up) 281 | 282 | return s1, s2, s3, s4, s5 283 | 284 | 285 | class ACCoNet_VGG(nn.Module): 286 | def __init__(self, channel=32): 287 | super(ACCoNet_VGG, self).__init__() 288 | #Backbone model 289 | self.vgg = VGG('rgb') 290 | 291 | self.ACCoM5 = ACCoM5(512) 292 | self.ACCoM4 = ACCoM(512) 293 | self.ACCoM3 = ACCoM(256) 294 | self.ACCoM2 = ACCoM(128) 295 | self.ACCoM1 = ACCoM1(64) 296 | 297 | # self.agg2_rgbd = aggregation(channel) 298 | self.decoder_rgb = decoder(512) 299 | 300 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 301 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 302 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 303 | 304 | self.sigmoid = nn.Sigmoid() 305 | 306 | 307 | def forward(self, x_rgb): 308 | x1_rgb = self.vgg.conv1(x_rgb) 309 | x2_rgb = self.vgg.conv2(x1_rgb) 310 | x3_rgb = self.vgg.conv3(x2_rgb) 311 | x4_rgb = self.vgg.conv4(x3_rgb) 312 | x5_rgb = self.vgg.conv5(x4_rgb) 313 | 314 | # up means update 315 | x5_ACCoM = self.ACCoM5(x4_rgb, x5_rgb) 316 | x4_ACCoM = self.ACCoM4(x3_rgb, x4_rgb, x5_rgb) 317 | x3_ACCoM = self.ACCoM3(x2_rgb, x3_rgb, x4_rgb) 318 | x2_ACCoM = self.ACCoM2(x1_rgb, x2_rgb, x3_rgb) 319 | x1_ACCoM = self.ACCoM1(x1_rgb, x2_rgb) 320 | 321 | s1, s2, s3, s4, s5 = self.decoder_rgb(x5_ACCoM, x4_ACCoM, x3_ACCoM, x2_ACCoM, x1_ACCoM) 322 | 323 | s3 = self.upsample2(s3) 324 | s4 = self.upsample4(s4) 325 | s5 = self.upsample8(s5) 326 | 327 | return s1, s2, s3, s4, s5, self.sigmoid(s1), self.sigmoid(s2), self.sigmoid(s3), self.sigmoid(s4), self.sigmoid(s5) -------------------------------------------------------------------------------- /model/ResNet50.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import torch.nn as nn 4 | import torch.utils.model_zoo as model_zoo 5 | 6 | __all__ = ["Backbone_ResNet50_in3"] 7 | 8 | model_urls = { 9 | "resnet18": "https://download.pytorch.org/models/resnet18-5c106cde.pth", 10 | "resnet34": "https://download.pytorch.org/models/resnet34-333f7ec4.pth", 11 | "resnet50": "https://download.pytorch.org/models/resnet50-19c8e357.pth", 12 | "resnet101": 
"https://download.pytorch.org/models/resnet101-5d3b4d8f.pth", 13 | "resnet152": "https://download.pytorch.org/models/resnet152-b121ed2d.pth", 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | """3x3 convolution with padding""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) 20 | 21 | 22 | def conv1x1(in_planes, out_planes, stride=1): 23 | """1x1 convolution""" 24 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 25 | 26 | 27 | class BasicBlock(nn.Module): 28 | expansion = 1 29 | 30 | def __init__(self, inplanes, planes, stride=1, downsample=None): 31 | super(BasicBlock, self).__init__() 32 | self.conv1 = conv3x3(inplanes, planes, stride) 33 | self.bn1 = nn.BatchNorm2d(planes) 34 | self.relu = nn.ReLU(inplace=True) 35 | self.conv2 = conv3x3(planes, planes) 36 | self.bn2 = nn.BatchNorm2d(planes) 37 | self.downsample = downsample 38 | self.stride = stride 39 | 40 | def forward(self, x): 41 | identity = x 42 | 43 | out = self.conv1(x) 44 | out = self.bn1(out) 45 | out = self.relu(out) 46 | 47 | out = self.conv2(out) 48 | out = self.bn2(out) 49 | 50 | if self.downsample is not None: 51 | identity = self.downsample(x) 52 | 53 | out += identity 54 | out = self.relu(out) 55 | 56 | return out 57 | 58 | 59 | class Bottleneck(nn.Module): 60 | expansion = 4 61 | 62 | def __init__(self, inplanes, planes, stride=1, downsample=None): 63 | super(Bottleneck, self).__init__() 64 | self.conv1 = conv1x1(inplanes, planes) 65 | self.bn1 = nn.BatchNorm2d(planes) 66 | self.conv2 = conv3x3(planes, planes, stride) 67 | self.bn2 = nn.BatchNorm2d(planes) 68 | self.conv3 = conv1x1(planes, planes * self.expansion) 69 | self.bn3 = nn.BatchNorm2d(planes * self.expansion) 70 | self.relu = nn.ReLU(inplace=True) 71 | self.downsample = downsample 72 | self.stride = stride 73 | 74 | def forward(self, x): 75 | identity = x 76 | 77 | out = self.conv1(x) 78 | out = self.bn1(out) 79 | out = self.relu(out) 80 | 81 | out = self.conv2(out) 82 | out = self.bn2(out) 83 | out = self.relu(out) 84 | 85 | out = self.conv3(out) 86 | out = self.bn3(out) 87 | 88 | if self.downsample is not None: 89 | identity = self.downsample(x) 90 | 91 | out += identity 92 | out = self.relu(out) 93 | 94 | return out 95 | 96 | 97 | class ResNet(nn.Module): 98 | def __init__(self, block, layers, zero_init_residual=False): 99 | super(ResNet, self).__init__() 100 | self.inplanes = 64 101 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False) 102 | self.bn1 = nn.BatchNorm2d(64) 103 | self.relu = nn.ReLU(inplace=True) 104 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 105 | self.layer1 = self._make_layer(block, 64, layers[0]) 106 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 107 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) # 6 108 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) # 3 109 | 110 | for m in self.modules(): 111 | if isinstance(m, nn.Conv2d): 112 | nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") 113 | elif isinstance(m, nn.BatchNorm2d): 114 | nn.init.constant_(m.weight, 1) 115 | nn.init.constant_(m.bias, 0) 116 | 117 | # Zero-initialize the last BN in each residual branch, 118 | # so that the residual branch starts with zeros, and each residual block behaves like an identity. 
119 | # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677 120 | if zero_init_residual: 121 | for m in self.modules(): 122 | if isinstance(m, Bottleneck): 123 | nn.init.constant_(m.bn3.weight, 0) 124 | elif isinstance(m, BasicBlock): 125 | nn.init.constant_(m.bn2.weight, 0) 126 | 127 | def _make_layer(self, block, planes, blocks, stride=1): 128 | downsample = None 129 | if stride != 1 or self.inplanes != planes * block.expansion: 130 | downsample = nn.Sequential( 131 | conv1x1(self.inplanes, planes * block.expansion, stride), 132 | nn.BatchNorm2d(planes * block.expansion), 133 | ) 134 | 135 | layers = [] 136 | layers.append(block(self.inplanes, planes, stride, downsample)) 137 | self.inplanes = planes * block.expansion 138 | for _ in range(1, blocks): 139 | layers.append(block(self.inplanes, planes)) 140 | 141 | return nn.Sequential(*layers) 142 | 143 | def forward(self, x): 144 | x = self.conv1(x) 145 | x = self.bn1(x) 146 | x = self.relu(x) 147 | x = self.maxpool(x) 148 | 149 | x = self.layer1(x) 150 | x = self.layer2(x) 151 | x = self.layer3(x) 152 | x = self.layer4(x) 153 | 154 | return x 155 | 156 | 157 | def resnet18(pretrained=False, **kwargs): 158 | """Constructs a ResNet-18 model. 159 | 160 | Args: 161 | pretrained (bool): If True, returns a model pre-trained on ImageNet 162 | """ 163 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 164 | if pretrained: 165 | pretrained_dict = model_zoo.load_url(model_urls["resnet18"]) 166 | 167 | model_dict = model.state_dict() 168 | # 1. filter out unnecessary keys 169 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 170 | # 2. overwrite entries in the existing state dict 171 | model_dict.update(pretrained_dict) 172 | # 3. load the new state dict 173 | model.load_state_dict(model_dict) 174 | return model 175 | 176 | 177 | def resnet34(pretrained=False, **kwargs): 178 | """Constructs a ResNet-34 model. 179 | 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 184 | if pretrained: 185 | pretrained_dict = model_zoo.load_url(model_urls["resnet34"]) 186 | 187 | model_dict = model.state_dict() 188 | # 1. filter out unnecessary keys 189 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 190 | # 2. overwrite entries in the existing state dict 191 | model_dict.update(pretrained_dict) 192 | # 3. load the new state dict 193 | model.load_state_dict(model_dict) 194 | return model 195 | 196 | 197 | def resnet50(pretrained=False, **kwargs): 198 | """Constructs a ResNet-50 model. 199 | 200 | Args: 201 | pretrained (bool): If True, returns a model pre-trained on ImageNet 202 | """ 203 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 204 | 205 | if pretrained: 206 | pretrained_dict = model_zoo.load_url(model_urls["resnet50"]) 207 | 208 | model_dict = model.state_dict() 209 | # 1. filter out unnecessary keys 210 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 211 | # 2. overwrite entries in the existing state dict 212 | model_dict.update(pretrained_dict) 213 | # 3. load the new state dict 214 | model.load_state_dict(model_dict) 215 | 216 | return model 217 | 218 | 219 | def resnet101(pretrained=False, **kwargs): 220 | """Constructs a ResNet-101 model. 
221 | 222 | Args: 223 | pretrained (bool): If True, returns a model pre-trained on ImageNet 224 | """ 225 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 226 | if pretrained: 227 | pretrained_dict = model_zoo.load_url(model_urls["resnet101"]) 228 | 229 | model_dict = model.state_dict() 230 | # 1. filter out unnecessary keys 231 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 232 | # 2. overwrite entries in the existing state dict 233 | model_dict.update(pretrained_dict) 234 | # 3. load the new state dict 235 | model.load_state_dict(model_dict) 236 | return model 237 | 238 | 239 | def resnet152(pretrained=False, **kwargs): 240 | """Constructs a ResNet-152 model. 241 | 242 | Args: 243 | pretrained (bool): If True, returns a model pre-trained on ImageNet 244 | """ 245 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 246 | 247 | if pretrained: 248 | pretrained_dict = model_zoo.load_url(model_urls["resnet152"]) 249 | 250 | model_dict = model.state_dict() 251 | # 1. filter out unnecessary keys 252 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 253 | # 2. overwrite entries in the existing state dict 254 | model_dict.update(pretrained_dict) 255 | # 3. load the new state dict 256 | model.load_state_dict(model_dict) 257 | 258 | return model 259 | 260 | 261 | def Backbone_ResNet50_in3(): 262 | net = resnet50(pretrained=True) 263 | div_2 = nn.Sequential(*list(net.children())[:3]) 264 | div_4 = nn.Sequential(*list(net.children())[3:5]) 265 | div_8 = net.layer2 266 | div_16 = net.layer3 267 | div_32 = net.layer4 268 | 269 | return div_2, div_4, div_8, div_16, div_32 270 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /model/vgg.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class VGG(nn.Module): 6 | # pooling layer at the front of block 7 | def __init__(self, mode = 'rgb'): 8 | super(VGG, self).__init__() 9 | 10 | conv1 = nn.Sequential() 11 | conv1.add_module('conv1_1', nn.Conv2d(3, 64, 3, 1, 1)) 12 | conv1.add_module('bn1_1', nn.BatchNorm2d(64)) 13 | conv1.add_module('relu1_1', nn.ReLU(inplace=True)) 14 | conv1.add_module('conv1_2', nn.Conv2d(64, 64, 3, 1, 1)) 15 | conv1.add_module('bn1_2', nn.BatchNorm2d(64)) 16 | conv1.add_module('relu1_2', nn.ReLU(inplace=True)) 17 | 18 | self.conv1 = conv1 19 | conv2 = nn.Sequential() 20 | conv2.add_module('pool1', nn.MaxPool2d(2, stride=2)) 21 | conv2.add_module('conv2_1', nn.Conv2d(64, 128, 3, 1, 1)) 22 | conv2.add_module('bn2_1', nn.BatchNorm2d(128)) 23 | conv2.add_module('relu2_1', nn.ReLU()) 24 | conv2.add_module('conv2_2', nn.Conv2d(128, 128, 3, 1, 1)) 25 | conv2.add_module('bn2_2', nn.BatchNorm2d(128)) 26 | conv2.add_module('relu2_2', nn.ReLU()) 27 | self.conv2 = conv2 28 | 29 | conv3 = nn.Sequential() 30 | conv3.add_module('pool2', nn.MaxPool2d(2, stride=2)) 31 | conv3.add_module('conv3_1', nn.Conv2d(128, 256, 3, 1, 1)) 32 | conv3.add_module('bn3_1', nn.BatchNorm2d(256)) 33 | conv3.add_module('relu3_1', nn.ReLU()) 34 | conv3.add_module('conv3_2', nn.Conv2d(256, 256, 3, 1, 1)) 35 | conv3.add_module('bn3_2', nn.BatchNorm2d(256)) 36 | conv3.add_module('relu3_2', nn.ReLU()) 37 | conv3.add_module('conv3_3', nn.Conv2d(256, 256, 3, 1, 1)) 38 | conv3.add_module('bn3_3', 
nn.BatchNorm2d(256)) 39 | conv3.add_module('relu3_3', nn.ReLU()) 40 | self.conv3 = conv3 41 | 42 | conv4 = nn.Sequential() 43 | conv4.add_module('pool3_1', nn.MaxPool2d(2, stride=2)) 44 | conv4.add_module('conv4_1', nn.Conv2d(256, 512, 3, 1, 1)) 45 | conv4.add_module('bn4_1', nn.BatchNorm2d(512)) 46 | conv4.add_module('relu4_1', nn.ReLU()) 47 | conv4.add_module('conv4_2', nn.Conv2d(512, 512, 3, 1, 1)) 48 | conv4.add_module('bn4_2', nn.BatchNorm2d(512)) 49 | conv4.add_module('relu4_2', nn.ReLU()) 50 | conv4.add_module('conv4_3', nn.Conv2d(512, 512, 3, 1, 1)) 51 | conv4.add_module('bn4_3', nn.BatchNorm2d(512)) 52 | conv4.add_module('relu4_3', nn.ReLU()) 53 | self.conv4 = conv4 54 | 55 | conv5 = nn.Sequential() 56 | conv5.add_module('pool4', nn.MaxPool2d(2, stride=2)) 57 | conv5.add_module('conv5_1', nn.Conv2d(512, 512, 3, 1, 1)) 58 | conv5.add_module('bn5_1', nn.BatchNorm2d(512)) 59 | conv5.add_module('relu5_1', nn.ReLU()) 60 | conv5.add_module('conv5_2', nn.Conv2d(512, 512, 3, 1, 1)) 61 | conv5.add_module('bn5_2', nn.BatchNorm2d(512)) 62 | conv5.add_module('relu5_2', nn.ReLU()) 63 | conv5.add_module('conv5_3', nn.Conv2d(512, 512, 3, 1, 1)) 64 | conv5.add_module('bn5_2', nn.BatchNorm2d(512)) 65 | conv5.add_module('relu5_3', nn.ReLU()) 66 | self.conv5 = conv5 67 | 68 | pre_train = torch.load('/home/lgy/20210206_ORSI_SOD/model/vgg16-397923af.pth') 69 | self._initialize_weights(pre_train) 70 | 71 | def forward(self, x): 72 | x = self.conv1(x) 73 | x = self.conv2(x) 74 | x = self.conv3(x) 75 | x = self.conv4(x) 76 | x = self.conv5(x) 77 | 78 | return x 79 | 80 | def _initialize_weights(self, pre_train): 81 | keys = list(pre_train.keys()) 82 | 83 | self.conv1.conv1_1.weight.data.copy_(pre_train[keys[0]]) 84 | self.conv1.conv1_2.weight.data.copy_(pre_train[keys[2]]) 85 | self.conv2.conv2_1.weight.data.copy_(pre_train[keys[4]]) 86 | self.conv2.conv2_2.weight.data.copy_(pre_train[keys[6]]) 87 | self.conv3.conv3_1.weight.data.copy_(pre_train[keys[8]]) 88 | self.conv3.conv3_2.weight.data.copy_(pre_train[keys[10]]) 89 | self.conv3.conv3_3.weight.data.copy_(pre_train[keys[12]]) 90 | self.conv4.conv4_1.weight.data.copy_(pre_train[keys[14]]) 91 | self.conv4.conv4_2.weight.data.copy_(pre_train[keys[16]]) 92 | self.conv4.conv4_3.weight.data.copy_(pre_train[keys[18]]) 93 | self.conv5.conv5_1.weight.data.copy_(pre_train[keys[20]]) 94 | self.conv5.conv5_2.weight.data.copy_(pre_train[keys[22]]) 95 | self.conv5.conv5_3.weight.data.copy_(pre_train[keys[24]]) 96 | 97 | self.conv1.conv1_1.bias.data.copy_(pre_train[keys[1]]) 98 | self.conv1.conv1_2.bias.data.copy_(pre_train[keys[3]]) 99 | self.conv2.conv2_1.bias.data.copy_(pre_train[keys[5]]) 100 | self.conv2.conv2_2.bias.data.copy_(pre_train[keys[7]]) 101 | self.conv3.conv3_1.bias.data.copy_(pre_train[keys[9]]) 102 | self.conv3.conv3_2.bias.data.copy_(pre_train[keys[11]]) 103 | self.conv3.conv3_3.bias.data.copy_(pre_train[keys[13]]) 104 | self.conv4.conv4_1.bias.data.copy_(pre_train[keys[15]]) 105 | self.conv4.conv4_2.bias.data.copy_(pre_train[keys[17]]) 106 | self.conv4.conv4_3.bias.data.copy_(pre_train[keys[19]]) 107 | self.conv5.conv5_1.bias.data.copy_(pre_train[keys[21]]) 108 | self.conv5.conv5_2.bias.data.copy_(pre_train[keys[23]]) 109 | self.conv5.conv5_3.bias.data.copy_(pre_train[keys[25]]) 110 | -------------------------------------------------------------------------------- /pytorch_iou/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | from torch.autograd 
import Variable 6 | import numpy as np 7 | 8 | def _iou(pred, target, size_average = True): 9 | 10 | b = pred.shape[0] 11 | IoU = 0.0 12 | for i in range(0,b): 13 | #compute the IoU of the foreground 14 | Iand1 = torch.sum(target[i,:,:,:]*pred[i,:,:,:]) 15 | Ior1 = torch.sum(target[i,:,:,:]) + torch.sum(pred[i,:,:,:])-Iand1 16 | IoU1 = Iand1/Ior1 17 | 18 | #IoU loss is (1-IoU1) 19 | IoU = IoU + (1-IoU1) 20 | 21 | return IoU/b 22 | 23 | class IOU(torch.nn.Module): 24 | def __init__(self, size_average = True): 25 | super(IOU, self).__init__() 26 | self.size_average = size_average 27 | 28 | def forward(self, pred, target): 29 | 30 | return _iou(pred, target, self.size_average) 31 | -------------------------------------------------------------------------------- /pytorch_iou/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/ACCoNet/d37bd207286b9e522803d3695e7180e4941ae117/pytorch_iou/__init__.pyc -------------------------------------------------------------------------------- /pytorch_iou/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MathLee/ACCoNet/d37bd207286b9e522803d3695e7180e4941ae117/pytorch_iou/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /test_ACCoNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | import numpy as np 5 | import pdb, os, argparse 6 | from scipy import misc 7 | import time 8 | 9 | from model.ACCoNet_VGG_models import ACCoNet_VGG 10 | from model.ACCoNet_Res_models import ACCoNet_Res 11 | from data import test_dataset 12 | 13 | torch.cuda.set_device(0) 14 | parser = argparse.ArgumentParser() 15 | parser.add_argument('--testsize', type=int, default=256, help='testing size') 16 | parser.add_argument('--is_ResNet', type=bool, default=True, help='VGG or ResNet backbone') 17 | opt = parser.parse_args() 18 | 19 | dataset_path = './dataset/test_dataset/' 20 | 21 | if opt.is_ResNet: 22 | model = ACCoNet_Res() 23 | model.load_state_dict(torch.load('./models/ACCoNet_ResNet/ACCoNet_Res.pth.39')) 24 | else: 25 | model = ACCoNet_VGG() 26 | model.load_state_dict(torch.load('./models/ACCoNet_VGG/ACCoNet_VGG.pth.54')) 27 | 28 | model.cuda() 29 | model.eval() 30 | 31 | # test_datasets = ['EORSSD'] 32 | test_datasets = ['ORSSD'] 33 | 34 | for dataset in test_datasets: 35 | if opt.is_ResNet: 36 | save_path = './results/ResNet50/' + dataset + '/' 37 | else: 38 | save_path = './results/VGG/' + dataset + '/' 39 | if not os.path.exists(save_path): 40 | os.makedirs(save_path) 41 | image_root = dataset_path + dataset + '/image/' 42 | print(dataset) 43 | gt_root = dataset_path + dataset + '/GT/' 44 | test_loader = test_dataset(image_root, gt_root, opt.testsize) 45 | time_sum = 0 46 | for i in range(test_loader.size): 47 | image, gt, name = test_loader.load_data() 48 | gt = np.asarray(gt, np.float32) 49 | gt /= (gt.max() + 1e-8) 50 | image = image.cuda() 51 | time_start = time.time() 52 | res, s2, s3, s4, s5, s1_sig, s2_sig, s3_sig, s4_sig, s5_sig = model(image) 53 | time_end = time.time() 54 | time_sum = time_sum+(time_end-time_start) 55 | res = F.upsample(res, size=gt.shape, mode='bilinear', align_corners=False) 56 | res = res.sigmoid().data.cpu().numpy().squeeze() 57 | res = (res - res.min()) / (res.max() - res.min() + 1e-8) 58 | 
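# note: scipy.misc.imsave was deprecated in SciPy 1.0 and removed in 1.2; with newer SciPy, imageio.imwrite(save_path + name, (res * 255).astype(np.uint8)) is a common replacement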
misc.imsave(save_path+name, res) 59 | if i == test_loader.size-1: 60 | print('Running time {:.5f}'.format(time_sum/test_loader.size)) 61 | print('Average speed: {:.4f} fps'.format(test_loader.size/time_sum)) -------------------------------------------------------------------------------- /train_ACCoNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.autograd import Variable 5 | 6 | import numpy as np 7 | import pdb, os, argparse 8 | from datetime import datetime 9 | 10 | from model.ACCoNet_VGG_models import ACCoNet_VGG 11 | from model.ACCoNet_Res_models import ACCoNet_Res 12 | from data import get_loader 13 | from utils import clip_gradient, adjust_lr 14 | 15 | import pytorch_iou 16 | 17 | torch.cuda.set_device(0) 18 | parser = argparse.ArgumentParser() 19 | parser.add_argument('--epoch', type=int, default=40, help='epoch number') 20 | parser.add_argument('--lr', type=float, default=1e-4, help='learning rate') 21 | # For VGG, batchsize is 6; for ResNet, batchsize is 8. 22 | parser.add_argument('--batchsize', type=int, default=6, help='training batch size') 23 | parser.add_argument('--trainsize', type=int, default=256, help='training dataset size') 24 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin') 25 | parser.add_argument('--is_ResNet', type=bool, default=False, help='VGG or ResNet backbone') 26 | parser.add_argument('--decay_rate', type=float, default=0.1, help='decay rate of learning rate') 27 | parser.add_argument('--decay_epoch', type=int, default=30, help='every n epochs decay learning rate') 28 | opt = parser.parse_args() 29 | 30 | print('Learning Rate: {} ResNet: {}'.format(opt.lr, opt.is_ResNet)) 31 | # build models 32 | if opt.is_ResNet: 33 | model = ACCoNet_Res() 34 | else: 35 | model = ACCoNet_VGG() 36 | 37 | model.cuda() 38 | params = model.parameters() 39 | optimizer = torch.optim.Adam(params, opt.lr) 40 | 41 | image_root = './dataset/train_dataset/ORSSD/train/image/' 42 | gt_root = './dataset/train_dataset/ORSSD/train/GT/' 43 | # image_root = './dataset/train_dataset/EORSSD/train/image/' 44 | # gt_root = './dataset/train_dataset/EORSSD/train/GT/' 45 | train_loader = get_loader(image_root, gt_root, batchsize=opt.batchsize, trainsize=opt.trainsize) 46 | total_step = len(train_loader) 47 | 48 | CE = torch.nn.BCEWithLogitsLoss() 49 | IOU = pytorch_iou.IOU(size_average = True) 50 | 51 | def train(train_loader, model, optimizer, epoch): 52 | model.train() 53 | for i, pack in enumerate(train_loader, start=1): 54 | optimizer.zero_grad() 55 | images, gts = pack 56 | images = Variable(images) 57 | gts = Variable(gts) 58 | images = images.cuda() 59 | gts = gts.cuda() 60 | 61 | s1, s2, s3, s4, s5, s1_sig, s2_sig, s3_sig, s4_sig, s5_sig = model(images) 62 | 63 | loss1 = CE(s1, gts) + IOU(s1_sig, gts) 64 | loss2 = CE(s2, gts) + IOU(s2_sig, gts) 65 | loss3 = CE(s3, gts) + IOU(s3_sig, gts) 66 | loss4 = CE(s4, gts) + IOU(s4_sig, gts) 67 | loss5 = CE(s5, gts) + IOU(s5_sig, gts) 68 | 69 | loss = loss1 + loss2 + loss3 + loss4 + loss5 70 | 71 | loss.backward() 72 | 73 | clip_gradient(optimizer, opt.clip) 74 | optimizer.step() 75 | 76 | if i % 20 == 0 or i == total_step: 77 | print( 78 | '{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], Learning Rate: {}, Loss: {:.4f}, Loss1: {:.4f}, Loss2: {:.4f}'. 
79 | format(datetime.now(), epoch, opt.epoch, i, total_step, opt.lr * opt.decay_rate ** (epoch // opt.decay_epoch), loss.data, loss1.data, loss2.data)) 80 | 81 | if opt.is_ResNet: 82 | save_path = 'models/ACCoNet_Res/' 83 | else: 84 | save_path = 'models/ACCoNet_VGG/' 85 | 86 | if not os.path.exists(save_path): 87 | os.makedirs(save_path) 88 | if (epoch+1) % 5 == 0: 89 | if opt.is_ResNet: 90 | torch.save(model.state_dict(), save_path + 'ACCoNet_ResNet.pth' + '.%d' % epoch) 91 | else: 92 | torch.save(model.state_dict(), save_path + 'ACCoNet_VGG.pth' + '.%d' % epoch) 93 | 94 | print("Let's go!") 95 | for epoch in range(1, opt.epoch): 96 | adjust_lr(optimizer, opt.lr, epoch, opt.decay_rate, opt.decay_epoch) 97 | train(train_loader, model, optimizer, epoch) 98 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | def clip_gradient(optimizer, grad_clip): 2 | for group in optimizer.param_groups: 3 | for param in group['params']: 4 | if param.grad is not None: 5 | param.grad.data.clamp_(-grad_clip, grad_clip) 6 | 7 | 8 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=30): 9 | decay = decay_rate ** (epoch // decay_epoch) 10 | for param_group in optimizer.param_groups: 11 | param_group['lr'] = init_lr*decay 12 | print('decay_epoch: {}, Current_LR: {}'.format(decay_epoch, init_lr*decay)) --------------------------------------------------------------------------------
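# Example: with the defaults in train_ACCoNet.py (lr=1e-4, decay_rate=0.1, decay_epoch=30),
# decay = 0.1 ** (epoch // 30), so adjust_lr keeps the learning rate at 1e-4 for epochs 0-29
# and drops it to 1e-5 from epoch 30 onwards.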