├── Annotations.png
├── README.md
├── data.py
├── model
│   ├── HolisticAttention.py
│   ├── ResNet.py
│   ├── ResNet_models_combine.py
│   ├── ResNet_models_sep.py
│   ├── __init__.py
│   └── batchrenorm.py
├── test.py
├── train.py
└── utils.py

/Annotations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingZhang617/cascaded_rgbd_sod/95e55b3fc8895e59a9707099d6d9704a981f9876/Annotations.png
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
1 | # Cascaded RGB-D SOD with COME15K dataset (ICCV2021)
2 | This is the official implementation of the CLNet paper "RGB-D Saliency Detection via Cascaded Mutual Information Minimization".
3 | 
4 | 
5 | # COME15K RGB-D SOD Dataset
6 | We provide the COME training dataset, which includes 8,025 RGB-D image pairs for SOD training. The dataset can be found at:
7 | https://drive.google.com/drive/folders/1mGbFKlIJNeW0m7hE-1dGX0b2gcxSMXjB?usp=sharing
8 | 
9 | We further introduce two testing sets, namely COME-E and COME-H, which include 4,600 and 3,000 image pairs respectively, and can be downloaded at:
10 | https://drive.google.com/drive/folders/1w0M9YmYBzkMLijy_Blg6RMRshSvPZju-?usp=sharing
11 | 
12 | Training data (Baidu Netdisk):
13 | Link: https://pan.baidu.com/s/15vaAkGuLVYPGuuYujDuhXg Password: m2er
14 | 
15 | Testing data (Baidu Netdisk):
16 | Link: https://pan.baidu.com/s/1Ohidx48adju5gMI_hkGzug Password: dofk
17 | 
18 | # Rich Annotations
19 | ![alt text](./Annotations.png)
20 | 
21 | For both the new training and testing datasets, we provide binary ground truth annotations, instance-level annotations, and saliency rankings (0, 1, 2, 3, 4, where 4 indicates the most salient instance). We also provide the raw annotations of five different annotators, so each image has five binary saliency annotations from five different annotators.
22 | 
23 | # New benchmark
24 | With our new training dataset, we re-train existing RGB-D SOD models and test them on ten benchmark testing datasets: SSB (STERE), DES, NLPR, NJU2K, LFSD, SIP, DUT-RGBD, RedWeb-S, and our COME-E and COME-H. Please find saliency maps of the retrained models at (constantly updating):
25 | https://drive.google.com/drive/folders/1lCE8OHeqNdjhE4--yR0FFib2C5DBTgwn?usp=sharing
26 | 
27 | # Retrain existing models
28 | We retrain state-of-the-art RGB-D SOD models with our new training dataset, and the re-trained models can be found at:
29 | https://drive.google.com/drive/folders/18Tqsn3yYoYO9HH8ZNVhHOTrJ7-UWPAZs?usp=sharing
30 | 
31 | # Our model trained on the conventional training dataset (the combination of NLPR and NJU2K data) and its produced saliency maps on SSB (STERE), DES, NLPR, NJU2K, LFSD and SIP:
32 | 
33 | model: https://drive.google.com/file/d/1gUubs1eGr2fnrlgze-EFhD9XbghfAyhK/view?usp=share_link
34 | 
35 | maps: https://drive.google.com/file/d/1OPTc7NsGQq9uYBdfquLIbluMikWezYO8/view?usp=share_link
36 | 
37 | Note that, because the model is stochastic, its performance can differ slightly between training runs.
38 | Solutions to get deterministic models:
39 | 1) Instead of using the reparameterization trick as we do, you can simply follow the auto-encoder learning pipeline and map features directly to the embedding space; to achieve this, you will need to remove the variance mapping function.
40 | 2) Or you can simply define the variance as 0, leading to deterministic generation of the latent code, which is easier to implement in practice (see the sketch below).
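Below is a minimal, illustrative sketch of these options. Only `reparametrize` mirrors the released `Mutual_info_reg.reparametrize` in `model/ResNet_models_combine.py` (and `ResNet_models_sep.py`); the other helper names are ours, not part of this repository.

```python
import torch

# Stochastic latent code, as in Mutual_info_reg.reparametrize:
# z = mu + std * eps with eps ~ N(0, I), so z changes from run to run.
def reparametrize(mu, logvar):
    std = logvar.mul(0.5).exp()
    eps = torch.randn_like(std)
    return mu + std * eps

# Solution 1: auto-encoder style -- drop the variance head (the fc2_* layers)
# and use the mean projection directly as a deterministic embedding.
def deterministic_embedding(mu):
    return mu

# Solution 2: keep the sampling code but define the variance as 0,
# so the sampled latent code collapses to the mean.
def zero_variance_latent(mu, logvar):
    std = torch.zeros_like(logvar)   # variance fixed to 0
    eps = torch.randn_like(std)
    return mu + std * eps            # equals mu, hence deterministic
```

Either variant removes the sampling noise from the latent code; other sources of training randomness (data shuffling, cuDNN) still apply.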
41 | 
42 | # Our Bib:
43 | 
44 | Please cite our paper if you use our dataset or code:
45 | ```
46 | @inproceedings{cascaded_rgbd_sod,
47 |   title={RGB-D Saliency Detection via Cascaded Mutual Information Minimization},
48 |   author={Zhang, Jing and Fan, Deng-Ping and Dai, Yuchao and Yu, Xin and Zhong, Yiran and Barnes, Nick and Shao, Ling},
49 |   booktitle={International Conference on Computer Vision (ICCV)},
50 |   year={2021}
51 | }
52 | ```
53 | # Copyright
54 | Creative Commons License
This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License. 55 | 56 | # Privacy 57 | This dataset is made available for academic use only. If you find yourself or personal belongings in this dataset and feel unwell about it, please contact us and we will immediately remove the respective data. 58 | 59 | # Contact 60 | 61 | Please drop me an email for further problems or discussion: zjnwpu@gmail.com 62 | -------------------------------------------------------------------------------- /data.py: -------------------------------------------------------------------------------- 1 | import os 2 | from PIL import Image 3 | import torch.utils.data as data 4 | import torchvision.transforms as transforms 5 | import random 6 | import numpy as np 7 | from PIL import ImageEnhance 8 | 9 | 10 | # several data augumentation strategies 11 | def cv_random_flip(img, label, depth): 12 | flip_flag = random.randint(0, 1) 13 | # flip_flag2= random.randint(0,1) 14 | # left right flip 15 | if flip_flag == 1: 16 | img = img.transpose(Image.FLIP_LEFT_RIGHT) 17 | label = label.transpose(Image.FLIP_LEFT_RIGHT) 18 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT) 19 | # top bottom flip 20 | # if flip_flag2==1: 21 | # img = img.transpose(Image.FLIP_TOP_BOTTOM) 22 | # label = label.transpose(Image.FLIP_TOP_BOTTOM) 23 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM) 24 | return img, label, depth 25 | 26 | 27 | def randomCrop(image, label, depth): 28 | border = 30 29 | image_width = image.size[0] 30 | image_height = image.size[1] 31 | crop_win_width = np.random.randint(image_width - border, image_width) 32 | crop_win_height = np.random.randint(image_height - border, image_height) 33 | random_region = ( 34 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1, 35 | (image_height + crop_win_height) >> 1) 36 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region) 37 | 38 | 39 | def randomRotation(image, label, depth): 40 | mode = Image.BICUBIC 41 | if random.random() > 0.8: 42 | random_angle = np.random.randint(-15, 15) 43 | image = image.rotate(random_angle, mode) 44 | label = label.rotate(random_angle, mode) 45 | depth = depth.rotate(random_angle, mode) 46 | return image, label, depth 47 | 48 | 49 | def colorEnhance(image): 50 | bright_intensity = random.randint(5, 15) / 10.0 51 | image = ImageEnhance.Brightness(image).enhance(bright_intensity) 52 | contrast_intensity = random.randint(5, 15) / 10.0 53 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity) 54 | color_intensity = random.randint(0, 20) / 10.0 55 | image = ImageEnhance.Color(image).enhance(color_intensity) 56 | sharp_intensity = random.randint(0, 30) / 10.0 57 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity) 58 | return image 59 | 60 | 61 | def randomGaussian(image, mean=0.1, sigma=0.35): 62 | def gaussianNoisy(im, mean=mean, sigma=sigma): 63 | for _i in range(len(im)): 64 | im[_i] += random.gauss(mean, sigma) 65 | return im 66 | 67 | img = np.asarray(image) 68 | width, height = img.shape 69 | img = gaussianNoisy(img[:].flatten(), mean, sigma) 70 | img = img.reshape([width, height]) 71 | return Image.fromarray(np.uint8(img)) 72 | 73 | 74 | def randomPeper(img): 75 | img = np.array(img) 76 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1]) 77 | for i in range(noiseNum): 78 | 79 | randX = random.randint(0, img.shape[0] - 1) 80 | 81 | randY = random.randint(0, img.shape[1] - 1) 82 | 83 | if 
random.randint(0, 1) == 0: 84 | 85 | img[randX, randY] = 0 86 | 87 | else: 88 | 89 | img[randX, randY] = 255 90 | return Image.fromarray(img) 91 | 92 | 93 | # dataset for training 94 | # The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps 95 | # (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved. 96 | class SalObjDataset(data.Dataset): 97 | def __init__(self, image_root, gt_root, depth_root, trainsize): 98 | self.trainsize = trainsize 99 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 100 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg') 101 | or f.endswith('.png')] 102 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 103 | or f.endswith('.png')] 104 | self.images = sorted(self.images) 105 | self.gts = sorted(self.gts) 106 | self.depths = sorted(self.depths) 107 | self.filter_files() 108 | self.size = len(self.images) 109 | self.img_transform = transforms.Compose([ 110 | transforms.Resize((self.trainsize, self.trainsize)), 111 | transforms.ToTensor(), 112 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 113 | self.gt_transform = transforms.Compose([ 114 | transforms.Resize((self.trainsize, self.trainsize)), 115 | transforms.ToTensor()]) 116 | self.depths_transform = transforms.Compose( 117 | [transforms.Resize((self.trainsize, self.trainsize)), transforms.ToTensor()]) 118 | 119 | def __getitem__(self, index): 120 | image = self.rgb_loader(self.images[index]) 121 | gt = self.binary_loader(self.gts[index]) 122 | depth = self.rgb_loader(self.depths[index]) 123 | image, gt, depth = cv_random_flip(image, gt, depth) 124 | image, gt, depth = randomCrop(image, gt, depth) 125 | image, gt, depth = randomRotation(image, gt, depth) 126 | image = colorEnhance(image) 127 | # gt=randomGaussian(gt) 128 | gt = randomPeper(gt) 129 | image = self.img_transform(image) 130 | gt = self.gt_transform(gt) 131 | depth = self.depths_transform(depth) 132 | 133 | return image, gt, depth 134 | 135 | def filter_files(self): 136 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.images) 137 | images = [] 138 | gts = [] 139 | depths = [] 140 | for img_path, gt_path, depth_path in zip(self.images, self.gts, self.depths): 141 | img = Image.open(img_path) 142 | gt = Image.open(gt_path) 143 | depth = Image.open(depth_path) 144 | if img.size == gt.size and gt.size == depth.size: 145 | images.append(img_path) 146 | gts.append(gt_path) 147 | depths.append(depth_path) 148 | self.images = images 149 | self.gts = gts 150 | self.depths = depths 151 | 152 | def rgb_loader(self, path): 153 | with open(path, 'rb') as f: 154 | img = Image.open(f) 155 | return img.convert('RGB') 156 | 157 | def binary_loader(self, path): 158 | with open(path, 'rb') as f: 159 | img = Image.open(f) 160 | return img.convert('L') 161 | 162 | def resize(self, img, gt, depth): 163 | assert img.size == gt.size and gt.size == depth.size 164 | w, h = img.size 165 | if h < self.trainsize or w < self.trainsize: 166 | h = max(h, self.trainsize) 167 | w = max(w, self.trainsize) 168 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h), 169 | Image.NEAREST) 170 | else: 171 | return img, gt, depth 172 | 173 | def __len__(self): 174 | return self.size 175 | 176 | 177 | # dataloader for training 178 | def get_loader(image_root, gt_root, depth_root, batchsize, 
trainsize, shuffle=True, num_workers=12, pin_memory=True): 179 | dataset = SalObjDataset(image_root, gt_root, depth_root, trainsize) 180 | data_loader = data.DataLoader(dataset=dataset, 181 | batch_size=batchsize, 182 | shuffle=shuffle, 183 | num_workers=num_workers, 184 | pin_memory=pin_memory) 185 | return data_loader 186 | 187 | 188 | # test dataset and loader 189 | class test_dataset: 190 | def __init__(self, image_root, depth_root, testsize): 191 | self.testsize = testsize 192 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')] 193 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp') 194 | or f.endswith('.png')] 195 | self.images = sorted(self.images) 196 | self.depths = sorted(self.depths) 197 | self.transform = transforms.Compose([ 198 | transforms.Resize((self.testsize, self.testsize)), 199 | transforms.ToTensor(), 200 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])]) 201 | # self.gt_transform = transforms.Compose([ 202 | # transforms.Resize((self.trainsize, self.trainsize)), 203 | # transforms.ToTensor()]) 204 | self.depths_transform = transforms.Compose( 205 | [transforms.Resize((self.testsize, self.testsize)), transforms.ToTensor()]) 206 | self.size = len(self.images) 207 | self.index = 0 208 | 209 | def load_data(self): 210 | image = self.rgb_loader(self.images[self.index]) 211 | HH = image.size[0] 212 | WW = image.size[1] 213 | image = self.transform(image).unsqueeze(0) 214 | depth = self.rgb_loader(self.depths[self.index]) 215 | depth = self.depths_transform(depth).unsqueeze(0) 216 | 217 | name = self.images[self.index].split('/')[-1] 218 | # image_for_post=self.rgb_loader(self.images[self.index]) 219 | # image_for_post=image_for_post.resize(gt.size) 220 | if name.endswith('.jpg'): 221 | name = name.split('.jpg')[0] + '.png' 222 | self.index += 1 223 | self.index = self.index % self.size 224 | return image, depth, HH, WW, name 225 | 226 | def rgb_loader(self, path): 227 | with open(path, 'rb') as f: 228 | img = Image.open(f) 229 | return img.convert('RGB') 230 | 231 | def binary_loader(self, path): 232 | with open(path, 'rb') as f: 233 | img = Image.open(f) 234 | return img.convert('L') 235 | 236 | def __len__(self): 237 | return self.size 238 | 239 | -------------------------------------------------------------------------------- /model/HolisticAttention.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import torch.nn as nn 4 | from torch.nn.parameter import Parameter 5 | 6 | import numpy as np 7 | import scipy.stats as st 8 | 9 | 10 | def gkern(kernlen=16, nsig=3): 11 | interval = (2*nsig+1.)/kernlen 12 | x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1) 13 | kern1d = np.diff(st.norm.cdf(x)) 14 | kernel_raw = np.sqrt(np.outer(kern1d, kern1d)) 15 | kernel = kernel_raw/kernel_raw.sum() 16 | return kernel 17 | 18 | 19 | def min_max_norm(in_): 20 | max_ = in_.max(3)[0].max(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_) 21 | min_ = in_.min(3)[0].min(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_) 22 | in_ = in_ - min_ 23 | return in_.div(max_-min_+1e-8) 24 | 25 | 26 | class HA(nn.Module): 27 | # holistic attention module 28 | def __init__(self): 29 | super(HA, self).__init__() 30 | gaussian_kernel = np.float32(gkern(31, 4)) 31 | gaussian_kernel = gaussian_kernel[np.newaxis, np.newaxis, ...] 
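        # gkern(31, 4) builds a normalized 2D Gaussian kernel; it is registered below as a
        # Parameter and convolved with the coarse attention map in forward() (padding=15
        # preserves the spatial size for the 31x31 kernel) to spread the attention before
        # it re-weights the feature map x.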
32 | self.gaussian_kernel = Parameter(torch.from_numpy(gaussian_kernel)) 33 | 34 | def forward(self, attention, x): 35 | soft_attention = F.conv2d(attention, self.gaussian_kernel, padding=15) 36 | soft_attention = min_max_norm(soft_attention) 37 | x = torch.mul(x, soft_attention.max(attention)) 38 | return x 39 | -------------------------------------------------------------------------------- /model/ResNet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | 4 | 5 | def conv3x3(in_planes, out_planes, stride=1): 6 | """3x3 convolution with padding""" 7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 8 | padding=1, bias=False) 9 | 10 | 11 | class BasicBlock(nn.Module): 12 | expansion = 1 13 | 14 | def __init__(self, inplanes, planes, stride=1, downsample=None): 15 | super(BasicBlock, self).__init__() 16 | self.conv1 = conv3x3(inplanes, planes, stride) 17 | self.bn1 = nn.BatchNorm2d(planes) 18 | self.relu = nn.ReLU(inplace=True) 19 | self.conv2 = conv3x3(planes, planes) 20 | self.bn2 = nn.BatchNorm2d(planes) 21 | self.downsample = downsample 22 | self.stride = stride 23 | 24 | def forward(self, x): 25 | residual = x 26 | 27 | out = self.conv1(x) 28 | out = self.bn1(out) 29 | out = self.relu(out) 30 | 31 | out = self.conv2(out) 32 | out = self.bn2(out) 33 | 34 | if self.downsample is not None: 35 | residual = self.downsample(x) 36 | 37 | out += residual 38 | out = self.relu(out) 39 | 40 | return out 41 | 42 | 43 | class Bottleneck(nn.Module): 44 | expansion = 4 45 | 46 | def __init__(self, inplanes, planes, stride=1, downsample=None): 47 | super(Bottleneck, self).__init__() 48 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 49 | self.bn1 = nn.BatchNorm2d(planes) 50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 51 | padding=1, bias=False) 52 | self.bn2 = nn.BatchNorm2d(planes) 53 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False) 54 | self.bn3 = nn.BatchNorm2d(planes * 4) 55 | self.relu = nn.ReLU(inplace=True) 56 | self.downsample = downsample 57 | self.stride = stride 58 | 59 | def forward(self, x): 60 | residual = x 61 | 62 | out = self.conv1(x) 63 | out = self.bn1(out) 64 | out = self.relu(out) 65 | 66 | out = self.conv2(out) 67 | out = self.bn2(out) 68 | out = self.relu(out) 69 | 70 | out = self.conv3(out) 71 | out = self.bn3(out) 72 | 73 | if self.downsample is not None: 74 | residual = self.downsample(x) 75 | 76 | out += residual 77 | out = self.relu(out) 78 | 79 | return out 80 | 81 | 82 | class B2_ResNet(nn.Module): 83 | # ResNet50 with two branches 84 | def __init__(self): 85 | # self.inplanes = 128 86 | self.inplanes = 64 87 | super(B2_ResNet, self).__init__() 88 | 89 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 90 | bias=False) 91 | self.bn1 = nn.BatchNorm2d(64) 92 | self.relu = nn.ReLU(inplace=True) 93 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 94 | self.layer1 = self._make_layer(Bottleneck, 64, 3) 95 | self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2) 96 | self.layer3_1 = self._make_layer(Bottleneck, 256, 6, stride=2) 97 | self.layer4_1 = self._make_layer(Bottleneck, 512, 3, stride=2) 98 | 99 | self.inplanes = 512 100 | self.layer3_2 = self._make_layer(Bottleneck, 256, 6, stride=2) 101 | self.layer4_2 = self._make_layer(Bottleneck, 512, 3, stride=2) 102 | 103 | for m in self.modules(): 104 | if isinstance(m, nn.Conv2d): 105 | n = m.kernel_size[0] * 
m.kernel_size[1] * m.out_channels 106 | m.weight.data.normal_(0, math.sqrt(2. / n)) 107 | elif isinstance(m, nn.BatchNorm2d): 108 | m.weight.data.fill_(1) 109 | m.bias.data.zero_() 110 | 111 | def _make_layer(self, block, planes, blocks, stride=1): 112 | downsample = None 113 | if stride != 1 or self.inplanes != planes * block.expansion: 114 | downsample = nn.Sequential( 115 | nn.Conv2d(self.inplanes, planes * block.expansion, 116 | kernel_size=1, stride=stride, bias=False), 117 | nn.BatchNorm2d(planes * block.expansion), 118 | ) 119 | 120 | layers = [] 121 | layers.append(block(self.inplanes, planes, stride, downsample)) 122 | self.inplanes = planes * block.expansion 123 | for i in range(1, blocks): 124 | layers.append(block(self.inplanes, planes)) 125 | 126 | return nn.Sequential(*layers) 127 | 128 | def forward(self, x): 129 | x = self.conv1(x) 130 | x = self.bn1(x) 131 | x = self.relu(x) 132 | x = self.maxpool(x) 133 | 134 | x = self.layer1(x) 135 | x = self.layer2(x) 136 | x1 = self.layer3_1(x) 137 | x1 = self.layer4_1(x1) 138 | 139 | x2 = self.layer3_2(x) 140 | x2 = self.layer4_2(x2) 141 | 142 | return x1, x2 143 | -------------------------------------------------------------------------------- /model/ResNet_models_combine.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import numpy as np 5 | from model.ResNet import B2_ResNet 6 | from utils import init_weights,init_weights_orthogonal_normal 7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 8 | from torch.autograd import Variable 9 | from torch.nn import Parameter, Softmax 10 | import torch.nn.functional as F 11 | from torch.distributions import Normal, Independent, kl 12 | from model.HolisticAttention import HA 13 | import math 14 | CE = torch.nn.BCELoss(reduction='sum') 15 | cos_sim = torch.nn.CosineSimilarity(dim=1,eps=1e-8) 16 | from model.batchrenorm import BatchRenorm2d 17 | 18 | class BasicConv2d(nn.Module): 19 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 20 | super(BasicConv2d, self).__init__() 21 | self.conv = nn.Conv2d(in_planes, out_planes, 22 | kernel_size=kernel_size, stride=stride, 23 | padding=padding, dilation=dilation, bias=False) 24 | self.bn = nn.BatchNorm2d(out_planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | 27 | def forward(self, x): 28 | x = self.conv(x) 29 | x = self.bn(x) 30 | return x 31 | 32 | class Classifier_Module(nn.Module): 33 | def __init__(self,dilation_series,padding_series,NoLabels, input_channel): 34 | super(Classifier_Module, self).__init__() 35 | self.conv2d_list = nn.ModuleList() 36 | for dilation,padding in zip(dilation_series,padding_series): 37 | self.conv2d_list.append(nn.Conv2d(input_channel,NoLabels,kernel_size=3,stride=1, padding =padding, dilation = dilation,bias = True)) 38 | for m in self.conv2d_list: 39 | m.weight.data.normal_(0, 0.01) 40 | 41 | def forward(self, x): 42 | out = self.conv2d_list[0](x) 43 | for i in range(len(self.conv2d_list)-1): 44 | out += self.conv2d_list[i+1](x) 45 | return out 46 | 47 | class Mutual_info_reg(nn.Module): 48 | def __init__(self, input_channels, channels, latent_size): 49 | super(Mutual_info_reg, self).__init__() 50 | self.contracting_path = nn.ModuleList() 51 | self.input_channels = input_channels 52 | self.relu = nn.ReLU(inplace=True) 53 | self.layer1 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1) 54 | self.bn1 = 
nn.BatchNorm2d(channels) 55 | self.layer2 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1) 56 | self.bn2 = nn.BatchNorm2d(channels) 57 | self.layer3 = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1) 58 | self.layer4 = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1) 59 | 60 | self.channel = channels 61 | 62 | self.fc1_rgb1 = nn.Linear(channels * 1 * 16 * 16, latent_size) 63 | self.fc2_rgb1 = nn.Linear(channels * 1 * 16 * 16, latent_size) 64 | self.fc1_depth1 = nn.Linear(channels * 1 * 16 * 16, latent_size) 65 | self.fc2_depth1 = nn.Linear(channels * 1 * 16 * 16, latent_size) 66 | 67 | self.fc1_rgb2 = nn.Linear(channels * 1 * 22 * 22, latent_size) 68 | self.fc2_rgb2 = nn.Linear(channels * 1 * 22 * 22, latent_size) 69 | self.fc1_depth2 = nn.Linear(channels * 1 * 22 * 22, latent_size) 70 | self.fc2_depth2 = nn.Linear(channels * 1 * 22 * 22, latent_size) 71 | 72 | self.fc1_rgb3 = nn.Linear(channels * 1 * 28 * 28, latent_size) 73 | self.fc2_rgb3 = nn.Linear(channels * 1 * 28 * 28, latent_size) 74 | self.fc1_depth3 = nn.Linear(channels * 1 * 28 * 28, latent_size) 75 | self.fc2_depth3 = nn.Linear(channels * 1 * 28 * 28, latent_size) 76 | 77 | self.leakyrelu = nn.LeakyReLU() 78 | self.tanh = torch.nn.Tanh() 79 | 80 | def kl_divergence(self, posterior_latent_space, prior_latent_space): 81 | kl_div = kl.kl_divergence(posterior_latent_space, prior_latent_space) 82 | return kl_div 83 | 84 | def reparametrize(self, mu, logvar): 85 | std = logvar.mul(0.5).exp_() 86 | eps = torch.cuda.FloatTensor(std.size()).normal_() 87 | eps = Variable(eps) 88 | return eps.mul(std).add_(mu) 89 | 90 | def forward(self, rgb_feat, depth_feat): 91 | rgb_feat = self.layer3(self.leakyrelu(self.bn1(self.layer1(rgb_feat)))) 92 | depth_feat = self.layer4(self.leakyrelu(self.bn2(self.layer2(depth_feat)))) 93 | # print(rgb_feat.size()) 94 | # print(depth_feat.size()) 95 | if rgb_feat.shape[2] == 16: 96 | rgb_feat = rgb_feat.view(-1, self.channel * 1 * 16 * 16) 97 | depth_feat = depth_feat.view(-1, self.channel * 1 * 16 * 16) 98 | 99 | mu_rgb = self.fc1_rgb1(rgb_feat) 100 | logvar_rgb = self.fc2_rgb1(rgb_feat) 101 | mu_depth = self.fc1_depth1(depth_feat) 102 | logvar_depth = self.fc2_depth1(depth_feat) 103 | elif rgb_feat.shape[2] == 22: 104 | rgb_feat = rgb_feat.view(-1, self.channel * 1 * 22 * 22) 105 | depth_feat = depth_feat.view(-1, self.channel * 1 * 22 * 22) 106 | mu_rgb = self.fc1_rgb2(rgb_feat) 107 | logvar_rgb = self.fc2_rgb2(rgb_feat) 108 | mu_depth = self.fc1_depth2(depth_feat) 109 | logvar_depth = self.fc2_depth2(depth_feat) 110 | else: 111 | rgb_feat = rgb_feat.view(-1, self.channel * 1 * 28 * 28) 112 | depth_feat = depth_feat.view(-1, self.channel * 1 * 28 * 28) 113 | mu_rgb = self.fc1_rgb3(rgb_feat) 114 | logvar_rgb = self.fc2_rgb3(rgb_feat) 115 | mu_depth = self.fc1_depth3(depth_feat) 116 | logvar_depth = self.fc2_depth3(depth_feat) 117 | 118 | mu_depth = self.tanh(mu_depth) 119 | mu_rgb = self.tanh(mu_rgb) 120 | logvar_depth = self.tanh(logvar_depth) 121 | logvar_rgb = self.tanh(logvar_rgb) 122 | z_rgb = self.reparametrize(mu_rgb, logvar_rgb) 123 | dist_rgb = Independent(Normal(loc=mu_rgb, scale=torch.exp(logvar_rgb)), 1) 124 | z_depth = self.reparametrize(mu_depth, logvar_depth) 125 | dist_depth = Independent(Normal(loc=mu_depth, scale=torch.exp(logvar_depth)), 1) 126 | bi_di_kld = torch.mean(self.kl_divergence(dist_rgb, dist_depth)) + torch.mean( 127 | self.kl_divergence(dist_depth, dist_rgb)) 128 | z_rgb_norm = torch.sigmoid(z_rgb) 129 | z_depth_norm 
= torch.sigmoid(z_depth) 130 | ce_rgb_depth = CE(z_rgb_norm,z_depth_norm.detach()) 131 | ce_depth_rgb = CE(z_depth_norm, z_rgb_norm.detach()) 132 | latent_loss = ce_rgb_depth+ce_depth_rgb-bi_di_kld 133 | # latent_loss = torch.abs(cos_sim(z_rgb,z_depth)).sum() 134 | 135 | return latent_loss, z_rgb, z_depth 136 | 137 | 138 | 139 | class CAM_Module(nn.Module): 140 | """ Channel attention module""" 141 | def __init__(self): 142 | super(CAM_Module, self).__init__() 143 | self.gamma = Parameter(torch.zeros(1)) 144 | self.softmax = Softmax(dim=-1) 145 | def forward(self,x): 146 | """ 147 | inputs : 148 | x : input feature maps( B X C X H X W) 149 | returns : 150 | out : attention value + input feature 151 | attention: B X C X C 152 | """ 153 | m_batchsize, C, height, width = x.size() 154 | proj_query = x.view(m_batchsize, C, -1) 155 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1) 156 | energy = torch.bmm(proj_query, proj_key) 157 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy 158 | attention = self.softmax(energy_new) 159 | proj_value = x.view(m_batchsize, C, -1) 160 | 161 | out = torch.bmm(attention, proj_value) 162 | out = out.view(m_batchsize, C, height, width) 163 | 164 | out = self.gamma*out + x 165 | return out 166 | 167 | ## Channel Attention (CA) Layer 168 | class CALayer(nn.Module): 169 | def __init__(self, channel, reduction=16): 170 | super(CALayer, self).__init__() 171 | # global average pooling: feature --> point 172 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 173 | # feature channel downscale and upscale --> channel weight 174 | self.conv_du = nn.Sequential( 175 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 176 | nn.ReLU(inplace=True), 177 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 178 | nn.Sigmoid() 179 | ) 180 | 181 | def forward(self, x): 182 | y = self.avg_pool(x) 183 | y = self.conv_du(y) 184 | return x * y 185 | 186 | ## Residual Channel Attention Block (RCAB) 187 | 188 | 189 | 190 | class RCAB(nn.Module): 191 | def __init__( 192 | self, n_feat, kernel_size=3, reduction=16, 193 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 194 | 195 | super(RCAB, self).__init__() 196 | modules_body = [] 197 | for i in range(2): 198 | modules_body.append(self.default_conv(n_feat, n_feat, kernel_size, bias=bias)) 199 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 200 | if i == 0: modules_body.append(act) 201 | modules_body.append(CALayer(n_feat, reduction)) 202 | self.body = nn.Sequential(*modules_body) 203 | self.res_scale = res_scale 204 | 205 | def default_conv(self, in_channels, out_channels, kernel_size, bias=True): 206 | return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size // 2), bias=bias) 207 | 208 | def forward(self, x): 209 | res = self.body(x) 210 | #res = self.body(x).mul(self.res_scale) 211 | res += x 212 | return res 213 | 214 | 215 | 216 | class Saliency_feat_decoder(nn.Module): 217 | # resnet based encoder decoder 218 | def __init__(self, channel,latent_dim): 219 | super(Saliency_feat_decoder, self).__init__() 220 | self.relu = nn.ReLU(inplace=True) 221 | self.dropout = nn.Dropout(0.3) 222 | 223 | self.layer5 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 2048) 224 | self.layer6 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], 1, channel) 225 | self.layer7 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], 1, channel) 226 | 227 | self.upsample8 = 
nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 228 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 229 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 230 | 231 | self.conv1 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 256) 232 | self.conv2 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 512) 233 | self.conv3 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 1024) 234 | self.conv4 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 2048) 235 | 236 | self.spatial_axes = [2, 3] 237 | 238 | self.racb_43 = RCAB(channel * 2) 239 | self.racb_432 = RCAB(channel * 3) 240 | self.racb_4321 = RCAB(channel * 4) 241 | 242 | self.conv43 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 2*channel) 243 | self.conv432 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 3*channel) 244 | self.conv4321 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 4*channel) 245 | 246 | self.layer_depth = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], 3, channel * 4) 247 | 248 | self.rcab_z1 = RCAB(channel + latent_dim) 249 | self.conv_z1 = BasicConv2d(channel+latent_dim,channel,3,padding=1) 250 | 251 | self.rcab_z2 = RCAB(channel + latent_dim) 252 | self.conv_z2 = BasicConv2d(channel + latent_dim, channel, 3, padding=1) 253 | 254 | self.rcab_z3 = RCAB(channel + latent_dim) 255 | self.conv_z3 = BasicConv2d(channel + latent_dim, channel, 3, padding=1) 256 | 257 | self.rcab_z4 = RCAB(channel + latent_dim) 258 | self.conv_z4 = BasicConv2d(channel + latent_dim, channel, 3, padding=1) 259 | 260 | self.br1 = BatchRenorm2d(channel) 261 | self.br2 = BatchRenorm2d(channel) 262 | self.br3 = BatchRenorm2d(channel) 263 | self.br4 = BatchRenorm2d(channel) 264 | 265 | 266 | def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel): 267 | return block(dilation_series, padding_series, NoLabels, input_channel) 268 | 269 | def tile(self, a, dim, n_tile): 270 | """ 271 | This function is taken form PyTorch forum and mimics the behavior of tf.tile. 
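        In this decoder it broadcasts a latent code z of shape (B, latent_dim) over the
        spatial dimensions so that it can be concatenated with the feature maps channel-wise.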
272 | Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3 273 | """ 274 | init_dim = a.size(dim) 275 | repeat_idx = [1] * a.dim() 276 | repeat_idx[dim] = n_tile 277 | a = a.repeat(*(repeat_idx)) 278 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device) 279 | return torch.index_select(a, dim, order_index) 280 | 281 | def forward(self, x1,x2,x3,x4,z1=None,z2=None,z3=None,z4=None): 282 | conv1_feat = self.br1(self.conv1(x1)) 283 | conv2_feat = self.br2(self.conv2(x2)) 284 | conv3_feat = self.br3(self.conv3(x3)) 285 | conv4_feat = self.br4(self.conv4(x4)) 286 | 287 | if z1!=None: 288 | z1 = torch.unsqueeze(z1, 2) 289 | z1 = self.tile(z1, 2, conv1_feat.shape[self.spatial_axes[0]]) 290 | z1 = torch.unsqueeze(z1, 3) 291 | z1 = self.tile(z1, 3, conv1_feat.shape[self.spatial_axes[1]]) 292 | 293 | z2 = torch.unsqueeze(z2, 2) 294 | z2 = self.tile(z2, 2, conv2_feat.shape[self.spatial_axes[0]]) 295 | z2 = torch.unsqueeze(z2, 3) 296 | z2 = self.tile(z2, 3, conv2_feat.shape[self.spatial_axes[1]]) 297 | 298 | z3 = torch.unsqueeze(z3, 2) 299 | z3 = self.tile(z3, 2, conv3_feat.shape[self.spatial_axes[0]]) 300 | z3 = torch.unsqueeze(z3, 3) 301 | z3 = self.tile(z3, 3, conv3_feat.shape[self.spatial_axes[1]]) 302 | 303 | z4 = torch.unsqueeze(z4, 2) 304 | z4 = self.tile(z4, 2, conv4_feat.shape[self.spatial_axes[0]]) 305 | z4 = torch.unsqueeze(z4, 3) 306 | z4 = self.tile(z4, 3, conv4_feat.shape[self.spatial_axes[1]]) 307 | 308 | conv1_feat = torch.cat((conv1_feat,z1),1) 309 | conv1_feat = self.rcab_z1(conv1_feat) 310 | conv1_feat = self.conv_z1(conv1_feat) 311 | 312 | conv2_feat = torch.cat((conv2_feat, z2), 1) 313 | conv2_feat = self.rcab_z2(conv2_feat) 314 | conv2_feat = self.conv_z2(conv2_feat) 315 | 316 | conv3_feat = torch.cat((conv3_feat, z3), 1) 317 | conv3_feat = self.rcab_z3(conv3_feat) 318 | conv3_feat = self.conv_z3(conv3_feat) 319 | 320 | conv4_feat = torch.cat((conv4_feat, z4), 1) 321 | conv4_feat = self.rcab_z4(conv4_feat) 322 | conv4_feat = self.conv_z4(conv4_feat) 323 | 324 | conv4_feat = self.upsample2(conv4_feat) 325 | 326 | conv43 = torch.cat((conv4_feat, conv3_feat), 1) 327 | conv43 = self.racb_43(conv43) 328 | conv43 = self.conv43(conv43) 329 | 330 | conv43 = self.upsample2(conv43) 331 | conv432 = torch.cat((self.upsample2(conv4_feat), conv43, conv2_feat), 1) 332 | conv432 = self.racb_432(conv432) 333 | conv432 = self.conv432(conv432) 334 | 335 | conv432 = self.upsample2(conv432) 336 | conv4321 = torch.cat((self.upsample4(conv4_feat), self.upsample2(conv43), conv432, conv1_feat), 1) 337 | conv4321 = self.racb_4321(conv4321) 338 | conv4321 = self.conv4321(conv4321) 339 | 340 | sal_init = self.layer6(conv4321) 341 | 342 | return sal_init 343 | 344 | 345 | class Saliency_feat_endecoder(nn.Module): 346 | # resnet based encoder decoder 347 | def __init__(self, channel): 348 | super(Saliency_feat_endecoder, self).__init__() 349 | self.resnet_rgb = B2_ResNet() 350 | self.resnet_depth = B2_ResNet() 351 | self.relu = nn.ReLU(inplace=True) 352 | self.dropout = nn.Dropout(0.3) 353 | self.latent_dim = 6 354 | self.conv_depth1 = BasicConv2d(6, 3, kernel_size=3, padding=1) 355 | self.sal_decoder1 = Saliency_feat_decoder(channel, self.latent_dim) 356 | self.sal_decoder2 = Saliency_feat_decoder(channel, self.latent_dim) 357 | self.sal_decoder3 = Saliency_feat_decoder(channel, self.latent_dim) 358 | self.sal_decoder4 = Saliency_feat_decoder(channel, self.latent_dim) 359 | self.sal_decoder5 = Saliency_feat_decoder(channel, 
self.latent_dim) 360 | self.sal_decoder6 = Saliency_feat_decoder(channel, self.latent_dim) 361 | 362 | self.HA = HA() 363 | self.upsample05 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True) 364 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 365 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 366 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 367 | self.upsample025 = nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True) 368 | self.upsample0125 = nn.Upsample(scale_factor=0.125, mode='bilinear', align_corners=True) 369 | 370 | 371 | self.convx1_depth = nn.Conv2d(in_channels=256, out_channels=channel, kernel_size=3, padding=1) 372 | self.convx2_depth = nn.Conv2d(in_channels=512, out_channels=channel, kernel_size=3, padding=1) 373 | self.convx3_depth = nn.Conv2d(in_channels=1024, out_channels=channel, kernel_size=3, padding=1) 374 | self.convx4_depth = nn.Conv2d(in_channels=2048, out_channels=channel, kernel_size=3, padding=1) 375 | 376 | self.convx1_rgb = nn.Conv2d(in_channels=256, out_channels=channel, kernel_size=3, padding=1) 377 | self.convx2_rgb = nn.Conv2d(in_channels=512, out_channels=channel, kernel_size=3, padding=1) 378 | self.convx3_rgb = nn.Conv2d(in_channels=1024, out_channels=channel, kernel_size=3, padding=1) 379 | self.convx4_rgb = nn.Conv2d(in_channels=2048, out_channels=channel, kernel_size=3, padding=1) 380 | 381 | self.mi_level1 = Mutual_info_reg(channel,channel,self.latent_dim) 382 | self.mi_level2 = Mutual_info_reg(channel, channel, self.latent_dim) 383 | self.mi_level3 = Mutual_info_reg(channel, channel, self.latent_dim) 384 | self.mi_level4 = Mutual_info_reg(channel, channel, self.latent_dim) 385 | 386 | self.spatial_axes = [2, 3] 387 | self.final_clc = nn.Conv2d(in_channels=4, out_channels=1, kernel_size=3, padding=1) 388 | self.rcab_rgb_feat = RCAB(channel*4) 389 | self.rcab_depth_feat = RCAB(channel*4) 390 | 391 | 392 | 393 | 394 | if self.training: 395 | self.initialize_weights() 396 | 397 | def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel): 398 | return block(dilation_series, padding_series, NoLabels, input_channel) 399 | 400 | def tile(self, a, dim, n_tile): 401 | """ 402 | This function is taken form PyTorch forum and mimics the behavior of tf.tile. 
403 | Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3 404 | """ 405 | init_dim = a.size(dim) 406 | repeat_idx = [1] * a.dim() 407 | repeat_idx[dim] = n_tile 408 | a = a.repeat(*(repeat_idx)) 409 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device) 410 | return torch.index_select(a, dim, order_index) 411 | 412 | def forward(self, x,depth=None): 413 | raw_x = x 414 | x = self.resnet_rgb.conv1(x) 415 | x = self.resnet_rgb.bn1(x) 416 | x = self.resnet_rgb.relu(x) 417 | x = self.resnet_rgb.maxpool(x) 418 | x1_rgb = self.resnet_rgb.layer1(x) # 256 x 64 x 64 419 | x2_rgb = self.resnet_rgb.layer2(x1_rgb) # 512 x 32 x 32 420 | x3_rgb = self.resnet_rgb.layer3_1(x2_rgb) # 1024 x 16 x 16 421 | x4_rgb = self.resnet_rgb.layer4_1(x3_rgb) # 2048 x 8 x 8 422 | 423 | sal_init_rgb = self.sal_decoder1(x1_rgb, x2_rgb, x3_rgb, x4_rgb) 424 | x2_2_rgb = self.HA(self.upsample05(sal_init_rgb).sigmoid(), x2_rgb) 425 | x3_2_rgb = self.resnet_rgb.layer3_2(x2_2_rgb) # 1024 x 16 x 16 426 | x4_2_rgb = self.resnet_rgb.layer4_2(x3_2_rgb) # 2048 x 8 x 8 427 | sal_ref_rgb = self.sal_decoder2(x1_rgb, x2_2_rgb, x3_2_rgb, x4_2_rgb) 428 | 429 | if depth==None: 430 | return self.upsample4(sal_init_rgb), self.upsample4(sal_ref_rgb) 431 | else: 432 | x = torch.cat((raw_x,depth),1) 433 | x = self.conv_depth1(x) 434 | x = self.resnet_depth.conv1(x) 435 | x = self.resnet_depth.bn1(x) 436 | x = self.resnet_depth.relu(x) 437 | x = self.resnet_depth.maxpool(x) 438 | x1_depth = self.resnet_depth.layer1(x) # 256 x 64 x 64 439 | x2_depth = self.resnet_depth.layer2(x1_depth) # 512 x 32 x 32 440 | x3_depth = self.resnet_depth.layer3_1(x2_depth) # 1024 x 16 x 16 441 | x4_depth = self.resnet_depth.layer4_1(x3_depth) # 2048 x 8 x 8 442 | 443 | sal_init_depth = self.sal_decoder3(x1_depth, x2_depth, x3_depth, x4_depth) 444 | x2_2_depth = self.HA(self.upsample05(sal_init_depth).sigmoid(), x2_depth) 445 | x3_2_depth = self.resnet_depth.layer3_2(x2_2_depth) # 1024 x 16 x 16 446 | x4_2_depth = self.resnet_depth.layer4_2(x3_2_depth) # 2048 x 8 x 8 447 | sal_ref_depth = self.sal_decoder4(x1_depth, x2_2_depth, x3_2_depth, x4_2_depth) 448 | 449 | 450 | lat_loss1, z1_rgb, z1_depth = self.mi_level1(self.convx1_rgb(x1_rgb), self.convx1_depth(x1_depth)) 451 | lat_loss2, z2_rgb, z2_depth = self.mi_level2(self.upsample2(self.convx2_rgb(x2_2_rgb)), self.upsample2(self.convx2_depth(x2_2_depth))) 452 | lat_loss3, z3_rgb, z3_depth = self.mi_level3(self.upsample4(self.convx3_rgb(x3_2_rgb)), self.upsample4(self.convx3_depth(x3_2_depth))) 453 | lat_loss4, z4_rgb, z4_depth = self.mi_level4(self.upsample8(self.convx4_rgb(x4_2_rgb)), self.upsample8(self.convx4_depth(x4_2_depth))) 454 | 455 | lat_loss = lat_loss1+lat_loss2+lat_loss3+lat_loss4 456 | 457 | sal_mi_rgb = self.sal_decoder5(x1_rgb, x2_2_rgb, x3_2_rgb, x4_2_rgb, z1_depth,z2_depth,z3_depth,z4_depth) 458 | 459 | sal_mi_depth = self.sal_decoder6(x1_depth, x2_2_depth, x3_2_depth, x4_2_depth, z1_rgb,z2_rgb,z3_rgb,z4_rgb) 460 | 461 | final_sal = torch.cat((sal_ref_rgb,sal_ref_depth,sal_mi_rgb,sal_mi_depth),1) 462 | final_sal = self.final_clc(final_sal) 463 | 464 | return self.upsample4(sal_init_rgb), self.upsample4(sal_ref_rgb), self.upsample4(sal_init_depth), self.upsample4( 465 | sal_ref_depth), self.upsample4(sal_mi_rgb), self.upsample4(sal_mi_depth), self.upsample4(final_sal), lat_loss 466 | 467 | 468 | 469 | def initialize_weights(self): 470 | res50 = models.resnet50(pretrained=True) 471 | pretrained_dict = res50.state_dict() 
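        # Map the ImageNet weights onto the two-branch B2_ResNet: keys such as
        # 'layer3_1.*' / 'layer3_2.*' fall back to the pretrained 'layer3.*' entries,
        # and the same parameter dict initializes both resnet_rgb and resnet_depth below.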
472 | all_params = {} 473 | for k, v in self.resnet_rgb.state_dict().items(): 474 | if k in pretrained_dict.keys(): 475 | v = pretrained_dict[k] 476 | all_params[k] = v 477 | elif '_1' in k: 478 | name = k.split('_1')[0] + k.split('_1')[1] 479 | v = pretrained_dict[name] 480 | all_params[k] = v 481 | elif '_2' in k: 482 | name = k.split('_2')[0] + k.split('_2')[1] 483 | v = pretrained_dict[name] 484 | all_params[k] = v 485 | assert len(all_params.keys()) == len(self.resnet_rgb.state_dict().keys()) 486 | self.resnet_rgb.load_state_dict(all_params) 487 | self.resnet_depth.load_state_dict(all_params) 488 | -------------------------------------------------------------------------------- /model/ResNet_models_sep.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torchvision.models as models 4 | import numpy as np 5 | from model.ResNet import B2_ResNet 6 | from utils import init_weights,init_weights_orthogonal_normal 7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') 8 | from torch.autograd import Variable 9 | from torch.nn import Parameter, Softmax 10 | import torch.nn.functional as F 11 | from torch.distributions import Normal, Independent, kl 12 | from model.HolisticAttention import HA 13 | import math 14 | CE = torch.nn.BCELoss(reduction='sum') 15 | cos_sim = torch.nn.CosineSimilarity(dim=1,eps=1e-8) 16 | from model.batchrenorm import BatchRenorm2d 17 | 18 | class BasicConv2d(nn.Module): 19 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1): 20 | super(BasicConv2d, self).__init__() 21 | self.conv = nn.Conv2d(in_planes, out_planes, 22 | kernel_size=kernel_size, stride=stride, 23 | padding=padding, dilation=dilation, bias=False) 24 | self.bn = nn.BatchNorm2d(out_planes) 25 | self.relu = nn.ReLU(inplace=True) 26 | 27 | def forward(self, x): 28 | x = self.conv(x) 29 | x = self.bn(x) 30 | return x 31 | 32 | class Classifier_Module(nn.Module): 33 | def __init__(self,dilation_series,padding_series,NoLabels, input_channel): 34 | super(Classifier_Module, self).__init__() 35 | self.conv2d_list = nn.ModuleList() 36 | for dilation,padding in zip(dilation_series,padding_series): 37 | self.conv2d_list.append(nn.Conv2d(input_channel,NoLabels,kernel_size=3,stride=1, padding =padding, dilation = dilation,bias = True)) 38 | for m in self.conv2d_list: 39 | m.weight.data.normal_(0, 0.01) 40 | 41 | def forward(self, x): 42 | out = self.conv2d_list[0](x) 43 | for i in range(len(self.conv2d_list)-1): 44 | out += self.conv2d_list[i+1](x) 45 | return out 46 | 47 | class Mutual_info_reg(nn.Module): 48 | def __init__(self, input_channels, channels, latent_size): 49 | super(Mutual_info_reg, self).__init__() 50 | self.contracting_path = nn.ModuleList() 51 | self.input_channels = input_channels 52 | self.relu = nn.ReLU(inplace=True) 53 | self.layer1 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1) 54 | self.bn1 = nn.BatchNorm2d(channels) 55 | self.layer2 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1) 56 | self.bn2 = nn.BatchNorm2d(channels) 57 | self.layer3 = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1) 58 | self.layer4 = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1) 59 | 60 | self.channel = channels 61 | 62 | self.fc1_rgb1 = nn.Linear(channels * 1 * 16 * 16, latent_size) 63 | self.fc2_rgb1 = nn.Linear(channels * 1 * 16 * 16, latent_size) 64 | self.fc1_depth1 = nn.Linear(channels * 
1 * 16 * 16, latent_size) 65 | self.fc2_depth1 = nn.Linear(channels * 1 * 16 * 16, latent_size) 66 | 67 | self.fc1_rgb2 = nn.Linear(channels * 1 * 22 * 22, latent_size) 68 | self.fc2_rgb2 = nn.Linear(channels * 1 * 22 * 22, latent_size) 69 | self.fc1_depth2 = nn.Linear(channels * 1 * 22 * 22, latent_size) 70 | self.fc2_depth2 = nn.Linear(channels * 1 * 22 * 22, latent_size) 71 | 72 | self.fc1_rgb3 = nn.Linear(channels * 1 * 28 * 28, latent_size) 73 | self.fc2_rgb3 = nn.Linear(channels * 1 * 28 * 28, latent_size) 74 | self.fc1_depth3 = nn.Linear(channels * 1 * 28 * 28, latent_size) 75 | self.fc2_depth3 = nn.Linear(channels * 1 * 28 * 28, latent_size) 76 | 77 | self.leakyrelu = nn.LeakyReLU() 78 | self.tanh = torch.nn.Tanh() 79 | 80 | def kl_divergence(self, posterior_latent_space, prior_latent_space): 81 | kl_div = kl.kl_divergence(posterior_latent_space, prior_latent_space) 82 | return kl_div 83 | 84 | def reparametrize(self, mu, logvar): 85 | std = logvar.mul(0.5).exp_() 86 | eps = torch.cuda.FloatTensor(std.size()).normal_() 87 | eps = Variable(eps) 88 | return eps.mul(std).add_(mu) 89 | 90 | def forward(self, rgb_feat, depth_feat): 91 | rgb_feat = self.layer3(self.leakyrelu(self.bn1(self.layer1(rgb_feat)))) 92 | depth_feat = self.layer4(self.leakyrelu(self.bn2(self.layer2(depth_feat)))) 93 | # print(rgb_feat.size()) 94 | # print(depth_feat.size()) 95 | if rgb_feat.shape[2] == 16: 96 | rgb_feat = rgb_feat.view(-1, self.channel * 1 * 16 * 16) 97 | depth_feat = depth_feat.view(-1, self.channel * 1 * 16 * 16) 98 | 99 | mu_rgb = self.fc1_rgb1(rgb_feat) 100 | logvar_rgb = self.fc2_rgb1(rgb_feat) 101 | mu_depth = self.fc1_depth1(depth_feat) 102 | logvar_depth = self.fc2_depth1(depth_feat) 103 | elif rgb_feat.shape[2] == 22: 104 | rgb_feat = rgb_feat.view(-1, self.channel * 1 * 22 * 22) 105 | depth_feat = depth_feat.view(-1, self.channel * 1 * 22 * 22) 106 | mu_rgb = self.fc1_rgb2(rgb_feat) 107 | logvar_rgb = self.fc2_rgb2(rgb_feat) 108 | mu_depth = self.fc1_depth2(depth_feat) 109 | logvar_depth = self.fc2_depth2(depth_feat) 110 | else: 111 | rgb_feat = rgb_feat.view(-1, self.channel * 1 * 28 * 28) 112 | depth_feat = depth_feat.view(-1, self.channel * 1 * 28 * 28) 113 | mu_rgb = self.fc1_rgb3(rgb_feat) 114 | logvar_rgb = self.fc2_rgb3(rgb_feat) 115 | mu_depth = self.fc1_depth3(depth_feat) 116 | logvar_depth = self.fc2_depth3(depth_feat) 117 | 118 | mu_depth = self.tanh(mu_depth) 119 | mu_rgb = self.tanh(mu_rgb) 120 | logvar_depth = self.tanh(logvar_depth) 121 | logvar_rgb = self.tanh(logvar_rgb) 122 | z_rgb = self.reparametrize(mu_rgb, logvar_rgb) 123 | dist_rgb = Independent(Normal(loc=mu_rgb, scale=torch.exp(logvar_rgb)), 1) 124 | z_depth = self.reparametrize(mu_depth, logvar_depth) 125 | dist_depth = Independent(Normal(loc=mu_depth, scale=torch.exp(logvar_depth)), 1) 126 | bi_di_kld = torch.mean(self.kl_divergence(dist_rgb, dist_depth)) + torch.mean( 127 | self.kl_divergence(dist_depth, dist_rgb)) 128 | z_rgb_norm = torch.sigmoid(z_rgb) 129 | z_depth_norm = torch.sigmoid(z_depth) 130 | ce_rgb_depth = CE(z_rgb_norm,z_depth_norm.detach()) 131 | ce_depth_rgb = CE(z_depth_norm, z_rgb_norm.detach()) 132 | latent_loss = ce_rgb_depth+ce_depth_rgb-bi_di_kld 133 | # latent_loss = torch.abs(cos_sim(z_rgb,z_depth)).sum() 134 | 135 | return latent_loss, z_rgb, z_depth 136 | 137 | 138 | 139 | class CAM_Module(nn.Module): 140 | """ Channel attention module""" 141 | def __init__(self): 142 | super(CAM_Module, self).__init__() 143 | self.gamma = Parameter(torch.zeros(1)) 144 | self.softmax = Softmax(dim=-1) 
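    # gamma is initialized to 0, so the channel-attention branch starts as an identity
    # mapping (out = x) and its contribution is learned during training.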
145 | def forward(self,x): 146 | """ 147 | inputs : 148 | x : input feature maps( B X C X H X W) 149 | returns : 150 | out : attention value + input feature 151 | attention: B X C X C 152 | """ 153 | m_batchsize, C, height, width = x.size() 154 | proj_query = x.view(m_batchsize, C, -1) 155 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1) 156 | energy = torch.bmm(proj_query, proj_key) 157 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy 158 | attention = self.softmax(energy_new) 159 | proj_value = x.view(m_batchsize, C, -1) 160 | 161 | out = torch.bmm(attention, proj_value) 162 | out = out.view(m_batchsize, C, height, width) 163 | 164 | out = self.gamma*out + x 165 | return out 166 | 167 | ## Channel Attention (CA) Layer 168 | class CALayer(nn.Module): 169 | def __init__(self, channel, reduction=16): 170 | super(CALayer, self).__init__() 171 | # global average pooling: feature --> point 172 | self.avg_pool = nn.AdaptiveAvgPool2d(1) 173 | # feature channel downscale and upscale --> channel weight 174 | self.conv_du = nn.Sequential( 175 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True), 176 | nn.ReLU(inplace=True), 177 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True), 178 | nn.Sigmoid() 179 | ) 180 | 181 | def forward(self, x): 182 | y = self.avg_pool(x) 183 | y = self.conv_du(y) 184 | return x * y 185 | 186 | ## Residual Channel Attention Block (RCAB) 187 | 188 | 189 | 190 | class RCAB(nn.Module): 191 | def __init__( 192 | self, n_feat, kernel_size=3, reduction=16, 193 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1): 194 | 195 | super(RCAB, self).__init__() 196 | modules_body = [] 197 | for i in range(2): 198 | modules_body.append(self.default_conv(n_feat, n_feat, kernel_size, bias=bias)) 199 | if bn: modules_body.append(nn.BatchNorm2d(n_feat)) 200 | if i == 0: modules_body.append(act) 201 | modules_body.append(CALayer(n_feat, reduction)) 202 | self.body = nn.Sequential(*modules_body) 203 | self.res_scale = res_scale 204 | 205 | def default_conv(self, in_channels, out_channels, kernel_size, bias=True): 206 | return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size // 2), bias=bias) 207 | 208 | def forward(self, x): 209 | res = self.body(x) 210 | #res = self.body(x).mul(self.res_scale) 211 | res += x 212 | return res 213 | 214 | 215 | 216 | class Saliency_feat_decoder(nn.Module): 217 | # resnet based encoder decoder 218 | def __init__(self, channel,latent_dim): 219 | super(Saliency_feat_decoder, self).__init__() 220 | self.relu = nn.ReLU(inplace=True) 221 | self.dropout = nn.Dropout(0.3) 222 | 223 | self.layer5 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 2048) 224 | self.layer6 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], 1, channel) 225 | self.layer7 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], 1, channel) 226 | 227 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 228 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 229 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 230 | 231 | self.conv1 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 256) 232 | self.conv2 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 512) 233 | self.conv3 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], 
channel, 1024) 234 | self.conv4 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 2048) 235 | 236 | self.spatial_axes = [2, 3] 237 | 238 | self.racb_43 = RCAB(channel * 2) 239 | self.racb_432 = RCAB(channel * 3) 240 | self.racb_4321 = RCAB(channel * 4) 241 | 242 | self.conv43 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 2*channel) 243 | self.conv432 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 3*channel) 244 | self.conv4321 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 4*channel) 245 | 246 | self.layer_depth = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], 3, channel * 4) 247 | 248 | self.rcab_z1 = RCAB(channel + latent_dim) 249 | self.conv_z1 = BasicConv2d(channel+latent_dim,channel,3,padding=1) 250 | 251 | self.rcab_z2 = RCAB(channel + latent_dim) 252 | self.conv_z2 = BasicConv2d(channel + latent_dim, channel, 3, padding=1) 253 | 254 | self.rcab_z3 = RCAB(channel + latent_dim) 255 | self.conv_z3 = BasicConv2d(channel + latent_dim, channel, 3, padding=1) 256 | 257 | self.rcab_z4 = RCAB(channel + latent_dim) 258 | self.conv_z4 = BasicConv2d(channel + latent_dim, channel, 3, padding=1) 259 | 260 | self.br1 = BatchRenorm2d(channel) 261 | self.br2 = BatchRenorm2d(channel) 262 | self.br3 = BatchRenorm2d(channel) 263 | self.br4 = BatchRenorm2d(channel) 264 | 265 | 266 | def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel): 267 | return block(dilation_series, padding_series, NoLabels, input_channel) 268 | 269 | def tile(self, a, dim, n_tile): 270 | """ 271 | This function is taken form PyTorch forum and mimics the behavior of tf.tile. 
272 | Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3 273 | """ 274 | init_dim = a.size(dim) 275 | repeat_idx = [1] * a.dim() 276 | repeat_idx[dim] = n_tile 277 | a = a.repeat(*(repeat_idx)) 278 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device) 279 | return torch.index_select(a, dim, order_index) 280 | 281 | def forward(self, x1,x2,x3,x4,z1=None,z2=None,z3=None,z4=None): 282 | conv1_feat = self.br1(self.conv1(x1)) 283 | conv2_feat = self.br2(self.conv2(x2)) 284 | conv3_feat = self.br3(self.conv3(x3)) 285 | conv4_feat = self.br4(self.conv4(x4)) 286 | 287 | if z1!=None: 288 | z1 = torch.unsqueeze(z1, 2) 289 | z1 = self.tile(z1, 2, conv1_feat.shape[self.spatial_axes[0]]) 290 | z1 = torch.unsqueeze(z1, 3) 291 | z1 = self.tile(z1, 3, conv1_feat.shape[self.spatial_axes[1]]) 292 | 293 | z2 = torch.unsqueeze(z2, 2) 294 | z2 = self.tile(z2, 2, conv2_feat.shape[self.spatial_axes[0]]) 295 | z2 = torch.unsqueeze(z2, 3) 296 | z2 = self.tile(z2, 3, conv2_feat.shape[self.spatial_axes[1]]) 297 | 298 | z3 = torch.unsqueeze(z3, 2) 299 | z3 = self.tile(z3, 2, conv3_feat.shape[self.spatial_axes[0]]) 300 | z3 = torch.unsqueeze(z3, 3) 301 | z3 = self.tile(z3, 3, conv3_feat.shape[self.spatial_axes[1]]) 302 | 303 | z4 = torch.unsqueeze(z4, 2) 304 | z4 = self.tile(z4, 2, conv4_feat.shape[self.spatial_axes[0]]) 305 | z4 = torch.unsqueeze(z4, 3) 306 | z4 = self.tile(z4, 3, conv4_feat.shape[self.spatial_axes[1]]) 307 | 308 | conv1_feat = torch.cat((conv1_feat,z1),1) 309 | conv1_feat = self.rcab_z1(conv1_feat) 310 | conv1_feat = self.conv_z1(conv1_feat) 311 | 312 | conv2_feat = torch.cat((conv2_feat, z2), 1) 313 | conv2_feat = self.rcab_z2(conv2_feat) 314 | conv2_feat = self.conv_z2(conv2_feat) 315 | 316 | conv3_feat = torch.cat((conv3_feat, z3), 1) 317 | conv3_feat = self.rcab_z3(conv3_feat) 318 | conv3_feat = self.conv_z3(conv3_feat) 319 | 320 | conv4_feat = torch.cat((conv4_feat, z4), 1) 321 | conv4_feat = self.rcab_z4(conv4_feat) 322 | conv4_feat = self.conv_z4(conv4_feat) 323 | 324 | conv4_feat = self.upsample2(conv4_feat) 325 | 326 | conv43 = torch.cat((conv4_feat, conv3_feat), 1) 327 | conv43 = self.racb_43(conv43) 328 | conv43 = self.conv43(conv43) 329 | 330 | conv43 = self.upsample2(conv43) 331 | conv432 = torch.cat((self.upsample2(conv4_feat), conv43, conv2_feat), 1) 332 | conv432 = self.racb_432(conv432) 333 | conv432 = self.conv432(conv432) 334 | 335 | conv432 = self.upsample2(conv432) 336 | conv4321 = torch.cat((self.upsample4(conv4_feat), self.upsample2(conv43), conv432, conv1_feat), 1) 337 | conv4321 = self.racb_4321(conv4321) 338 | conv4321 = self.conv4321(conv4321) 339 | 340 | sal_init = self.layer6(conv4321) 341 | 342 | return sal_init 343 | 344 | 345 | class Saliency_feat_endecoder(nn.Module): 346 | # resnet based encoder decoder 347 | def __init__(self, channel): 348 | super(Saliency_feat_endecoder, self).__init__() 349 | self.resnet_rgb = B2_ResNet() 350 | self.resnet_depth = B2_ResNet() 351 | self.relu = nn.ReLU(inplace=True) 352 | self.dropout = nn.Dropout(0.3) 353 | self.latent_dim = 6 354 | self.conv_depth1 = BasicConv2d(6, 3, kernel_size=3, padding=1) 355 | self.sal_decoder1 = Saliency_feat_decoder(channel, self.latent_dim) 356 | self.sal_decoder2 = Saliency_feat_decoder(channel, self.latent_dim) 357 | self.sal_decoder3 = Saliency_feat_decoder(channel, self.latent_dim) 358 | self.sal_decoder4 = Saliency_feat_decoder(channel, self.latent_dim) 359 | self.sal_decoder5 = Saliency_feat_decoder(channel, 
self.latent_dim) 360 | self.sal_decoder6 = Saliency_feat_decoder(channel, self.latent_dim) 361 | 362 | self.HA = HA() 363 | self.upsample05 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True) 364 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True) 365 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) 366 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True) 367 | self.upsample025 = nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True) 368 | self.upsample0125 = nn.Upsample(scale_factor=0.125, mode='bilinear', align_corners=True) 369 | 370 | 371 | self.convx1_depth = nn.Conv2d(in_channels=256, out_channels=channel, kernel_size=3, padding=1) 372 | self.convx2_depth = nn.Conv2d(in_channels=512, out_channels=channel, kernel_size=3, padding=1) 373 | self.convx3_depth = nn.Conv2d(in_channels=1024, out_channels=channel, kernel_size=3, padding=1) 374 | self.convx4_depth = nn.Conv2d(in_channels=2048, out_channels=channel, kernel_size=3, padding=1) 375 | 376 | self.convx1_rgb = nn.Conv2d(in_channels=256, out_channels=channel, kernel_size=3, padding=1) 377 | self.convx2_rgb = nn.Conv2d(in_channels=512, out_channels=channel, kernel_size=3, padding=1) 378 | self.convx3_rgb = nn.Conv2d(in_channels=1024, out_channels=channel, kernel_size=3, padding=1) 379 | self.convx4_rgb = nn.Conv2d(in_channels=2048, out_channels=channel, kernel_size=3, padding=1) 380 | 381 | self.mi_level1 = Mutual_info_reg(channel,channel,self.latent_dim) 382 | self.mi_level2 = Mutual_info_reg(channel, channel, self.latent_dim) 383 | self.mi_level3 = Mutual_info_reg(channel, channel, self.latent_dim) 384 | self.mi_level4 = Mutual_info_reg(channel, channel, self.latent_dim) 385 | 386 | self.spatial_axes = [2, 3] 387 | self.final_clc = nn.Conv2d(in_channels=4, out_channels=1, kernel_size=3, padding=1) 388 | self.rcab_rgb_feat = RCAB(channel*4) 389 | self.rcab_depth_feat = RCAB(channel*4) 390 | 391 | 392 | 393 | 394 | if self.training: 395 | self.initialize_weights() 396 | 397 | def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel): 398 | return block(dilation_series, padding_series, NoLabels, input_channel) 399 | 400 | def tile(self, a, dim, n_tile): 401 | """ 402 | This function is taken form PyTorch forum and mimics the behavior of tf.tile. 
403 | Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3 404 | """ 405 | init_dim = a.size(dim) 406 | repeat_idx = [1] * a.dim() 407 | repeat_idx[dim] = n_tile 408 | a = a.repeat(*(repeat_idx)) 409 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device) 410 | return torch.index_select(a, dim, order_index) 411 | 412 | def forward(self, x,depth=None): 413 | raw_x = x 414 | x = self.resnet_rgb.conv1(x) 415 | x = self.resnet_rgb.bn1(x) 416 | x = self.resnet_rgb.relu(x) 417 | x = self.resnet_rgb.maxpool(x) 418 | x1_rgb = self.resnet_rgb.layer1(x) # 256 x 64 x 64 419 | x2_rgb = self.resnet_rgb.layer2(x1_rgb) # 512 x 32 x 32 420 | x3_rgb = self.resnet_rgb.layer3_1(x2_rgb) # 1024 x 16 x 16 421 | x4_rgb = self.resnet_rgb.layer4_1(x3_rgb) # 2048 x 8 x 8 422 | 423 | sal_init_rgb = self.sal_decoder1(x1_rgb, x2_rgb, x3_rgb, x4_rgb) 424 | x2_2_rgb = self.HA(self.upsample05(sal_init_rgb).sigmoid(), x2_rgb) 425 | x3_2_rgb = self.resnet_rgb.layer3_2(x2_2_rgb) # 1024 x 16 x 16 426 | x4_2_rgb = self.resnet_rgb.layer4_2(x3_2_rgb) # 2048 x 8 x 8 427 | sal_ref_rgb = self.sal_decoder2(x1_rgb, x2_2_rgb, x3_2_rgb, x4_2_rgb) 428 | 429 | if depth==None: 430 | return self.upsample4(sal_init_rgb), self.upsample4(sal_ref_rgb) 431 | else: 432 | x = depth 433 | x = self.resnet_depth.conv1(x) 434 | x = self.resnet_depth.bn1(x) 435 | x = self.resnet_depth.relu(x) 436 | x = self.resnet_depth.maxpool(x) 437 | x1_depth = self.resnet_depth.layer1(x) # 256 x 64 x 64 438 | x2_depth = self.resnet_depth.layer2(x1_depth) # 512 x 32 x 32 439 | x3_depth = self.resnet_depth.layer3_1(x2_depth) # 1024 x 16 x 16 440 | x4_depth = self.resnet_depth.layer4_1(x3_depth) # 2048 x 8 x 8 441 | 442 | sal_init_depth = self.sal_decoder3(x1_depth, x2_depth, x3_depth, x4_depth) 443 | x2_2_depth = self.HA(self.upsample05(sal_init_depth).sigmoid(), x2_depth) 444 | x3_2_depth = self.resnet_depth.layer3_2(x2_2_depth) # 1024 x 16 x 16 445 | x4_2_depth = self.resnet_depth.layer4_2(x3_2_depth) # 2048 x 8 x 8 446 | sal_ref_depth = self.sal_decoder4(x1_depth, x2_2_depth, x3_2_depth, x4_2_depth) 447 | 448 | 449 | lat_loss1, z1_rgb, z1_depth = self.mi_level1(self.convx1_rgb(x1_rgb), self.convx1_depth(x1_depth)) 450 | lat_loss2, z2_rgb, z2_depth = self.mi_level2(self.upsample2(self.convx2_rgb(x2_2_rgb)), self.upsample2(self.convx2_depth(x2_2_depth))) 451 | lat_loss3, z3_rgb, z3_depth = self.mi_level3(self.upsample4(self.convx3_rgb(x3_2_rgb)), self.upsample4(self.convx3_depth(x3_2_depth))) 452 | lat_loss4, z4_rgb, z4_depth = self.mi_level4(self.upsample8(self.convx4_rgb(x4_2_rgb)), self.upsample8(self.convx4_depth(x4_2_depth))) 453 | 454 | lat_loss = lat_loss1+lat_loss2+lat_loss3+lat_loss4 455 | 456 | sal_mi_rgb = self.sal_decoder5(x1_rgb, x2_2_rgb, x3_2_rgb, x4_2_rgb, z1_depth,z2_depth,z3_depth,z4_depth) 457 | 458 | sal_mi_depth = self.sal_decoder6(x1_depth, x2_2_depth, x3_2_depth, x4_2_depth, z1_rgb,z2_rgb,z3_rgb,z4_rgb) 459 | 460 | final_sal = torch.cat((sal_ref_rgb,sal_ref_depth,sal_mi_rgb,sal_mi_depth),1) 461 | final_sal = self.final_clc(final_sal) 462 | 463 | return self.upsample4(sal_init_rgb), self.upsample4(sal_ref_rgb), self.upsample4(sal_init_depth), self.upsample4( 464 | sal_ref_depth), self.upsample4(sal_mi_rgb), self.upsample4(sal_mi_depth), self.upsample4(final_sal), lat_loss 465 | 466 | 467 | 468 | def initialize_weights(self): 469 | res50 = models.resnet50(pretrained=True) 470 | pretrained_dict = res50.state_dict() 471 | all_params = {} 472 | for k, v in 
self.resnet_rgb.state_dict().items():
473 |             if k in pretrained_dict.keys():
474 |                 v = pretrained_dict[k]
475 |                 all_params[k] = v
476 |             elif '_1' in k:
477 |                 name = k.split('_1')[0] + k.split('_1')[1]
478 |                 v = pretrained_dict[name]
479 |                 all_params[k] = v
480 |             elif '_2' in k:
481 |                 name = k.split('_2')[0] + k.split('_2')[1]
482 |                 v = pretrained_dict[name]
483 |                 all_params[k] = v
484 |         assert len(all_params.keys()) == len(self.resnet_rgb.state_dict().keys())
485 |         self.resnet_rgb.load_state_dict(all_params)
486 |         self.resnet_depth.load_state_dict(all_params)
487 | 
--------------------------------------------------------------------------------
/model/__init__.py:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/model/batchrenorm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | 
3 | 
4 | __all__ = ["BatchRenorm1d", "BatchRenorm2d", "BatchRenorm3d"]
5 | 
6 | 
7 | class BatchRenorm(torch.jit.ScriptModule):
8 |     def __init__(
9 |         self,
10 |         num_features: int,
11 |         eps: float = 1e-3,
12 |         momentum: float = 0.01,
13 |         affine: bool = True,
14 |     ):
15 |         super().__init__()
16 |         self.register_buffer(
17 |             "running_mean", torch.zeros(num_features, dtype=torch.float)
18 |         )
19 |         self.register_buffer(
20 |             "running_std", torch.ones(num_features, dtype=torch.float)
21 |         )
22 |         self.register_buffer(
23 |             "num_batches_tracked", torch.tensor(0, dtype=torch.long)
24 |         )
25 |         self.weight = torch.nn.Parameter(
26 |             torch.ones(num_features, dtype=torch.float)
27 |         )
28 |         self.bias = torch.nn.Parameter(
29 |             torch.zeros(num_features, dtype=torch.float)
30 |         )
31 |         self.affine = affine
32 |         self.eps = eps
33 |         self.step = 0
34 |         self.momentum = momentum
35 | 
36 |     def _check_input_dim(self, x: torch.Tensor) -> None:
37 |         raise NotImplementedError()  # pragma: no cover
38 | 
39 |     @property
40 |     def rmax(self) -> torch.Tensor:
41 |         return (2 / 35000 * self.num_batches_tracked + 25 / 35).clamp_(
42 |             1.0, 3.0
43 |         )
44 | 
45 |     @property
46 |     def dmax(self) -> torch.Tensor:
47 |         return (5 / 20000 * self.num_batches_tracked - 25 / 20).clamp_(
48 |             0.0, 5.0
49 |         )
50 | 
51 |     def forward(self, x: torch.Tensor) -> torch.Tensor:
52 |         self._check_input_dim(x)
53 |         if x.dim() > 2:
54 |             x = x.transpose(1, -1)
55 |         if self.training:
56 |             dims = [i for i in range(x.dim() - 1)]
57 |             batch_mean = x.mean(dims)
58 |             batch_std = x.std(dims, unbiased=False) + self.eps
59 |             r = (
60 |                 batch_std.detach() / self.running_std.view_as(batch_std)
61 |             ).clamp_(1 / self.rmax, self.rmax)
62 |             d = (
63 |                 (batch_mean.detach() - self.running_mean.view_as(batch_mean))
64 |                 / self.running_std.view_as(batch_std)
65 |             ).clamp_(-self.dmax, self.dmax)
66 |             x = (x - batch_mean) / batch_std * r + d
67 |             self.running_mean += self.momentum * (
68 |                 batch_mean.detach() - self.running_mean
69 |             )
70 |             self.running_std += self.momentum * (
71 |                 batch_std.detach() - self.running_std
72 |             )
73 |             self.num_batches_tracked += 1
74 |         else:
75 |             x = (x - self.running_mean) / self.running_std
76 |         if self.affine:
77 |             x = self.weight * x + self.bias
78 |         if x.dim() > 2:
79 |             x = x.transpose(1, -1)
80 |         return x
81 | 
82 | 
83 | class BatchRenorm1d(BatchRenorm):
84 |     def _check_input_dim(self, x: torch.Tensor) -> None:
85 |         if x.dim() not in [2, 3]:
86 |             raise ValueError(f"expected 2D or 3D input (got {x.dim()}D input)")
87 | 
88 | 
89 | class BatchRenorm2d(BatchRenorm):
90 |     def _check_input_dim(self, x: torch.Tensor) -> None:
91 |         if x.dim() != 4:
92 |             raise ValueError(f"expected 4D input (got {x.dim()}D input)")
93 | 
94 | 
95 | class BatchRenorm3d(BatchRenorm):
96 |     def _check_input_dim(self, x: torch.Tensor) -> None:
97 |         if x.dim() != 5:
98 |             raise ValueError(f"expected 5D input (got {x.dim()}D input)")
99 | 
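Note: the Batch Renormalization layers above follow Ioffe's "Batch Renormalization" (2017). They are drop-in replacements for the corresponding `nn.BatchNorm*` layers and are mainly useful when the effective batch size is small, since the clamped (r, d) terms keep the training-time statistics close to the running statistics. A minimal usage sketch on a reasonably recent PyTorch, assuming the repository root is on `PYTHONPATH`; the surrounding conv block is illustrative and not part of this repository:

```python
import torch
from model.batchrenorm import BatchRenorm2d

# hypothetical block: BatchRenorm2d used in place of nn.BatchNorm2d
block = torch.nn.Sequential(
    torch.nn.Conv2d(3, 32, kernel_size=3, padding=1),
    BatchRenorm2d(32),   # renormalised batch stats in train(), running stats in eval()
    torch.nn.ReLU(inplace=True),
)

x = torch.randn(4, 3, 64, 64)
block.train()
y_train = block(x)       # updates running_mean / running_std, advances num_batches_tracked
block.eval()
y_eval = block(x)        # uses the frozen running statistics only
print(y_train.shape, y_eval.shape)  # torch.Size([4, 32, 64, 64]) twice
```

Because `rmax` and `dmax` start at 1 and 0 and only relax as `num_batches_tracked` grows, the layer behaves like ordinary batch normalisation early in training and gradually transitions to renormalisation, which matches the warm-up schedule suggested in the paper.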
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | import pdb, os, argparse
6 | os.environ["CUDA_VISIBLE_DEVICES"] = '1'
7 | from scipy import misc
8 | from model.ResNet_models_combine import Saliency_feat_endecoder  # model/ provides ResNet_models_combine (as used by train.py); there is no ResNet_models.py
9 | from data import test_dataset
10 | import cv2
11 | 
12 | 
13 | 
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--testsize', type=int, default=352, help='testing size')
16 | parser.add_argument('--latent_dim', type=int, default=3, help='latent dim')
17 | parser.add_argument('--feat_channel', type=int, default=64, help='reduced channel of saliency feat')
18 | opt = parser.parse_args()
19 | 
20 | dataset_path = './test/'
21 | depth_path = './test/'
22 | 
23 | generator = Saliency_feat_endecoder(channel=opt.feat_channel)
24 | generator.load_state_dict(torch.load('./models/Model_100_gen.pth'))
25 | 
26 | generator.cuda()
27 | generator.eval()
28 | 
29 | #
30 | test_datasets = ['DES', 'LFSD','NJU2K','NLPR','SIP','STERE']
31 | 
32 | for dataset in test_datasets:
33 |     save_path = './results/' + dataset + '/'
34 |     if not os.path.exists(save_path):
35 |         os.makedirs(save_path)
36 | 
37 |     image_root = dataset_path + dataset + '/RGB/'
38 |     depth_root = dataset_path + dataset + '/depth/'
39 |     test_loader = test_dataset(image_root, depth_root, opt.testsize)
40 |     for i in range(test_loader.size):
41 |         print(i)
42 |         image, depth, HH, WW, name = test_loader.load_data()
43 |         image = image.cuda()
44 |         depth = depth.cuda()
45 |         _,_,_,_,_,_,generator_pred,_ = generator.forward(image, depth)
46 |         res = generator_pred
47 |         res = F.interpolate(res, size=[WW,HH], mode='bilinear', align_corners=False)  # F.upsample is deprecated
48 |         res = res.sigmoid().data.cpu().numpy().squeeze()
49 |         res = 255 * (res - res.min()) / (res.max() - res.min() + 1e-8)
50 |         cv2.imwrite(save_path + name, res)
51 | 
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import cv2
5 | import numpy as np
6 | import pdb, os, argparse
7 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
8 | from datetime import datetime
9 | from model.ResNet_models_combine import Saliency_feat_endecoder
10 | from data import get_loader
11 | from utils import adjust_lr, AvgMeter
12 | from scipy import misc
13 | from utils import l2_regularisation
14 | 
15 | 
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--epoch', type=int, default=100, help='epoch number')
18 | parser.add_argument('--lr_gen', type=float, default=1e-4, help='learning rate')
19 | parser.add_argument('--batchsize', type=int, default=12, help='training batch size')
20 | parser.add_argument('--trainsize', type=int, default=352, help='training dataset size')
21 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
22 | parser.add_argument('--decay_rate', type=float, default=0.9, help='decay rate of learning 
rate') 23 | parser.add_argument('--decay_epoch', type=int, default=80, help='every n epochs decay learning rate') 24 | parser.add_argument('-beta1_gen', type=float, default=0.5,help='beta of Adam for generator') 25 | parser.add_argument('--weight_decay', type=float, default=0.001, help='weight_decay') 26 | parser.add_argument('--feat_channel', type=int, default=64, help='reduced channel of saliency feat') 27 | 28 | opt = parser.parse_args() 29 | print('Generator Learning Rate: {}'.format(opt.lr_gen)) 30 | # build models 31 | generator = Saliency_feat_endecoder(channel=opt.feat_channel) 32 | generator.cuda() 33 | 34 | generator_params = generator.parameters() 35 | generator_optimizer = torch.optim.Adam(generator_params, opt.lr_gen) 36 | 37 | ## load data 38 | image_root = './RGB/' 39 | gt_root = './GT/' 40 | depth_root = './depth/' 41 | 42 | train_loader = get_loader(image_root, gt_root, depth_root, batchsize=opt.batchsize, trainsize=opt.trainsize) 43 | total_step = len(train_loader) 44 | 45 | ## define loss 46 | 47 | CE = torch.nn.BCELoss() 48 | mse_loss = torch.nn.MSELoss(size_average=True, reduce=True) 49 | size_rates = [0.75,1,1.25] # multi-scale training 50 | 51 | def structure_loss(pred, mask): 52 | weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)-mask) 53 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none') 54 | wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3)) 55 | 56 | 57 | pred = torch.sigmoid(pred) 58 | inter = ((pred * mask) * weit).sum(dim=(2, 3)) 59 | union = ((pred + mask) * weit).sum(dim=(2, 3)) 60 | wiou = 1-(inter+1)/(union-inter+1) 61 | 62 | return (wbce+wiou).mean() 63 | 64 | def visualize_mi_rgb(var_map): 65 | 66 | for kk in range(var_map.shape[0]): 67 | pred_edge_kk = var_map[kk,:,:,:] 68 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze() 69 | pred_edge_kk *= 255.0 70 | pred_edge_kk = pred_edge_kk.astype(np.uint8) 71 | save_path = './temp/' 72 | name = '{:02d}_rgb_mi.png'.format(kk) 73 | cv2.imwrite(save_path + name, pred_edge_kk) 74 | 75 | def visualize_mi_depth(var_map): 76 | 77 | for kk in range(var_map.shape[0]): 78 | pred_edge_kk = var_map[kk,:,:,:] 79 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze() 80 | pred_edge_kk *= 255.0 81 | pred_edge_kk = pred_edge_kk.astype(np.uint8) 82 | save_path = './temp/' 83 | name = '{:02d}_depth_mi.png'.format(kk) 84 | cv2.imwrite(save_path + name, pred_edge_kk) 85 | 86 | ## visualize predictions and gt 87 | def visualize_rgb_init(var_map): 88 | 89 | for kk in range(var_map.shape[0]): 90 | pred_edge_kk = var_map[kk,:,:,:] 91 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze() 92 | pred_edge_kk *= 255.0 93 | pred_edge_kk = pred_edge_kk.astype(np.uint8) 94 | save_path = './temp/' 95 | name = '{:02d}_rgb_int.png'.format(kk) 96 | cv2.imwrite(save_path + name, pred_edge_kk) 97 | 98 | def visualize_depth_init(var_map): 99 | 100 | for kk in range(var_map.shape[0]): 101 | pred_edge_kk = var_map[kk,:,:,:] 102 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze() 103 | pred_edge_kk *= 255.0 104 | pred_edge_kk = pred_edge_kk.astype(np.uint8) 105 | save_path = './temp/' 106 | name = '{:02d}_depth_int.png'.format(kk) 107 | cv2.imwrite(save_path + name, pred_edge_kk) 108 | 109 | def visualize_rgb_ref(var_map): 110 | 111 | for kk in range(var_map.shape[0]): 112 | pred_edge_kk = var_map[kk,:,:,:] 113 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze() 114 | pred_edge_kk *= 255.0 115 | pred_edge_kk = pred_edge_kk.astype(np.uint8) 116 | 
save_path = './temp/' 117 | name = '{:02d}_rgb_ref.png'.format(kk) 118 | cv2.imwrite(save_path + name, pred_edge_kk) 119 | 120 | def visualize_depth_ref(var_map): 121 | 122 | for kk in range(var_map.shape[0]): 123 | pred_edge_kk = var_map[kk,:,:,:] 124 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze() 125 | pred_edge_kk *= 255.0 126 | pred_edge_kk = pred_edge_kk.astype(np.uint8) 127 | save_path = './temp/' 128 | name = '{:02d}_depth_ref.png'.format(kk) 129 | cv2.imwrite(save_path + name, pred_edge_kk) 130 | 131 | def visualize_final_rgbd(var_map): 132 | 133 | for kk in range(var_map.shape[0]): 134 | pred_edge_kk = var_map[kk,:,:,:] 135 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze() 136 | pred_edge_kk *= 255.0 137 | pred_edge_kk = pred_edge_kk.astype(np.uint8) 138 | save_path = './temp/' 139 | name = '{:02d}_rgbd.png'.format(kk) 140 | cv2.imwrite(save_path + name, pred_edge_kk) 141 | 142 | def visualize_uncertainty_prior_init(var_map): 143 | 144 | for kk in range(var_map.shape[0]): 145 | pred_edge_kk = var_map[kk,:,:,:] 146 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze() 147 | pred_edge_kk *= 255.0 148 | pred_edge_kk = pred_edge_kk.astype(np.uint8) 149 | save_path = './temp/' 150 | name = '{:02d}_prior_int.png'.format(kk) 151 | cv2.imwrite(save_path + name, pred_edge_kk) 152 | 153 | def visualize_gt(var_map): 154 | 155 | for kk in range(var_map.shape[0]): 156 | pred_edge_kk = var_map[kk,:,:,:] 157 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze() 158 | # pred_edge_kk = (pred_edge_kk - pred_edge_kk.min()) / (pred_edge_kk.max() - pred_edge_kk.min() + 1e-8) 159 | pred_edge_kk *= 255.0 160 | pred_edge_kk = pred_edge_kk.astype(np.uint8) 161 | save_path = './temp/' 162 | name = '{:02d}_gt.png'.format(kk) 163 | cv2.imwrite(save_path + name, pred_edge_kk) 164 | 165 | ## linear annealing to avoid posterior collapse 166 | def linear_annealing(init, fin, step, annealing_steps): 167 | """Linear annealing of a parameter.""" 168 | if annealing_steps == 0: 169 | return fin 170 | assert fin > init 171 | delta = fin - init 172 | annealed = min(init + delta * step / annealing_steps, fin) 173 | return annealed 174 | 175 | def visualize_all_pred(pred1,pred2,pred3,pred4,pred5,pred6,pred7,pred8): 176 | for kk in range(pred1.shape[0]): 177 | pred1_kk, pred2_kk, pred3_kk, pred4_kk, pred5_kk, pred6_kk, pred7_kk, pred8_kk = pred1[kk, :, :, :], pred2[kk, :, :, :], pred3[kk, :, :, :], pred4[kk, :, :, :], pred5[kk, :, :, :], pred6[kk, :, :, :], pred7[kk, :, :, :], pred8[kk, :, :, :] 178 | pred1_kk = (pred1_kk.detach().cpu().numpy().squeeze()*255.0).astype(np.uint8) 179 | pred2_kk = (pred2_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8) 180 | pred3_kk = (pred3_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8) 181 | pred4_kk = (pred4_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8) 182 | pred5_kk = (pred5_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8) 183 | pred6_kk = (pred6_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8) 184 | pred7_kk = (pred7_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8) 185 | pred8_kk = (pred8_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8) 186 | 187 | cat_img = cv2.hconcat([pred1_kk, pred2_kk, pred3_kk, pred4_kk, pred5_kk, pred6_kk, pred7_kk, pred8_kk]) 188 | save_path = './temp/' 189 | name = '{:02d}_gt_initR_refR_intD_refD_mR_mD_Fused.png'.format(kk) 190 | cv2.imwrite(save_path + name, cat_img) 191 | 192 | print("Let's Play!") 193 | for epoch in 
range(1, opt.epoch+1): 194 | generator.train() 195 | loss_record = AvgMeter() 196 | print('Generator Learning Rate: {}'.format(generator_optimizer.param_groups[0]['lr'])) 197 | 198 | for i, pack in enumerate(train_loader, start=1): 199 | for rate in size_rates: 200 | generator_optimizer.zero_grad() 201 | images, gts, depths = pack 202 | # print(index_batch) 203 | images = Variable(images) 204 | gts = Variable(gts) 205 | depths = Variable(depths) 206 | images = images.cuda() 207 | gts = gts.cuda() 208 | depths = depths.cuda() 209 | 210 | # multi-scale training samples 211 | trainsize = int(round(opt.trainsize * rate / 32) * 32) 212 | if rate != 1: 213 | images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear', 214 | align_corners=True) 215 | gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True) 216 | depths = F.upsample(depths, size=(trainsize, trainsize), mode='bilinear', align_corners=True) 217 | 218 | init_rgb, ref_rgb, init_depth, ref_depth, mi_rgb, mi_depth, fuse_sal, latent_loss = generator.forward(images,depths) 219 | 220 | sal_rgb_loss = structure_loss(init_rgb, gts) + structure_loss(ref_rgb, gts) + structure_loss(mi_rgb, gts) 221 | sal_depth_loss = structure_loss(init_depth, gts) + structure_loss(ref_depth, gts) + structure_loss(mi_depth, gts) 222 | sal_final_rgbd = structure_loss(fuse_sal, gts) 223 | 224 | 225 | anneal_reg = linear_annealing(0, 1, epoch, opt.epoch) 226 | latent_loss = 0.1*anneal_reg*latent_loss 227 | sal_loss = sal_rgb_loss+sal_depth_loss+sal_final_rgbd + latent_loss 228 | sal_loss.backward() 229 | generator_optimizer.step() 230 | visualize_all_pred(gts,torch.sigmoid(init_rgb),torch.sigmoid(ref_rgb),torch.sigmoid(init_depth),torch.sigmoid(ref_depth),torch.sigmoid(mi_rgb),torch.sigmoid(mi_depth),torch.sigmoid(fuse_sal)) 231 | 232 | if rate == 1: 233 | loss_record.update(sal_loss.data, opt.batchsize) 234 | 235 | 236 | if i % 10 == 0 or i == total_step: 237 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], gen Loss: {:.4f}'. 
238 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss_record.show())) 239 | # print(anneal_reg) 240 | 241 | 242 | adjust_lr(generator_optimizer, opt.lr_gen, epoch, opt.decay_rate, opt.decay_epoch) 243 | 244 | save_path = 'models/' 245 | 246 | 247 | if not os.path.exists(save_path): 248 | os.makedirs(save_path) 249 | if epoch % opt.epoch == 0: 250 | torch.save(generator.state_dict(), save_path + 'Model' + '_%d' % epoch + '_gen.pth') 251 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | 5 | def clip_gradient(optimizer, grad_clip): 6 | for group in optimizer.param_groups: 7 | for param in group['params']: 8 | if param.grad is not None: 9 | param.grad.data.clamp_(-grad_clip, grad_clip) 10 | 11 | 12 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=5): 13 | decay = decay_rate ** (epoch // decay_epoch) 14 | for param_group in optimizer.param_groups: 15 | param_group['lr'] *= decay 16 | 17 | 18 | def truncated_normal_(tensor, mean=0, std=1): 19 | size = tensor.shape 20 | tmp = tensor.new_empty(size + (4,)).normal_() 21 | valid = (tmp < 2) & (tmp > -2) 22 | ind = valid.max(-1, keepdim=True)[1] 23 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1)) 24 | tensor.data.mul_(std).add_(mean) 25 | 26 | def init_weights(m): 27 | if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d: 28 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu') 29 | #nn.init.normal_(m.weight, std=0.001) 30 | #nn.init.normal_(m.bias, std=0.001) 31 | truncated_normal_(m.bias, mean=0, std=0.001) 32 | 33 | def init_weights_orthogonal_normal(m): 34 | if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d: 35 | nn.init.orthogonal_(m.weight) 36 | truncated_normal_(m.bias, mean=0, std=0.001) 37 | #nn.init.normal_(m.bias, std=0.001) 38 | 39 | def l2_regularisation(m): 40 | l2_reg = None 41 | 42 | for W in m.parameters(): 43 | if l2_reg is None: 44 | l2_reg = W.norm(2) 45 | else: 46 | l2_reg = l2_reg + W.norm(2) 47 | return l2_reg 48 | 49 | class AvgMeter(object): 50 | def __init__(self, num=40): 51 | self.num = num 52 | self.reset() 53 | 54 | def reset(self): 55 | self.val = 0 56 | self.avg = 0 57 | self.sum = 0 58 | self.count = 0 59 | self.losses = [] 60 | 61 | def update(self, val, n=1): 62 | self.val = val 63 | self.sum += val * n 64 | self.count += n 65 | self.avg = self.sum / self.count 66 | self.losses.append(val) 67 | 68 | def show(self): 69 | a = len(self.losses) 70 | b = np.maximum(a-self.num, 0) 71 | c = self.losses[b:] 72 | #print(c) 73 | #d = torch.mean(torch.stack(c)) 74 | #print(d) 75 | return torch.mean(torch.stack(c)) 76 | 77 | # def save_mask_prediction_example(mask, pred, iter): 78 | # plt.imshow(pred[0,:,:],cmap='Greys') 79 | # plt.savefig('images/'+str(iter)+"_prediction.png") 80 | # plt.imshow(mask[0,:,:],cmap='Greys') 81 | # plt.savefig('images/'+str(iter)+"_mask.png") --------------------------------------------------------------------------------
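Note on how the decoders consume the latent codes: each `mi_level*` module returns a per-image latent vector for the RGB and depth branches (`latent_dim = 6` in `Saliency_feat_endecoder`), and `sal_decoder5`/`sal_decoder6` broadcast it over the spatial grid of every backbone feature via the `unsqueeze` + `tile` pattern in `Saliency_feat_decoder.forward`, concatenate it along the channel axis, and fuse it with the `rcab_z*`/`conv_z*` layers. The `tile` helper mimics `tf.tile`; since the tiled dimension always has size 1 here, the same broadcast can be written with `torch.repeat_interleave` (or `expand`) and no explicit index tensor. A minimal sketch of that broadcast-and-concatenate step; `spread_latent` and the tensor shapes are illustrative, not part of the repository:

```python
import torch

def spread_latent(z: torch.Tensor, feat: torch.Tensor) -> torch.Tensor:
    # z: [B, C_z] latent code, feat: [B, C, H, W] decoder feature
    B, _, H, W = feat.shape
    z = z.unsqueeze(2)                          # [B, C_z, 1]
    z = torch.repeat_interleave(z, H, dim=2)    # [B, C_z, H]
    z = z.unsqueeze(3)                          # [B, C_z, H, 1]
    z = torch.repeat_interleave(z, W, dim=3)    # [B, C_z, H, W]
    return torch.cat((feat, z), dim=1)          # [B, C + C_z, H, W]

feat = torch.randn(2, 32, 44, 44)    # hypothetical decoder feature map
z = torch.randn(2, 6)                # latent_dim = 6 as in Saliency_feat_endecoder
print(spread_latent(z, feat).shape)  # torch.Size([2, 38, 44, 44])
```

For the size-1 dimensions used here the three formulations (the repo's index_select-based `tile`, `repeat_interleave`, and `expand`) produce the same tensor; `expand` additionally avoids the copy.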