├── Annotations.png
├── README.md
├── data.py
├── model
│   ├── HolisticAttention.py
│   ├── ResNet.py
│   ├── ResNet_models_combine.py
│   ├── ResNet_models_sep.py
│   ├── __init__.py
│   └── batchrenorm.py
├── test.py
├── train.py
└── utils.py
/Annotations.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/JingZhang617/cascaded_rgbd_sod/95e55b3fc8895e59a9707099d6d9704a981f9876/Annotations.png
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Cascaded RGB-D SOD with COME15K dataset (ICCV2021)
 2 | This is the official implementation of the CLNet paper "RGB-D Saliency Detection via Cascaded Mutual Information Minimization".
3 |
4 |
5 | # COME15K RGB-D SOD Dataset
 6 | We provide the COME training dataset, which includes 8,025 RGB-D image pairs for SOD training. The dataset can be found at:
7 | https://drive.google.com/drive/folders/1mGbFKlIJNeW0m7hE-1dGX0b2gcxSMXjB?usp=sharing
8 |
 9 | We further introduce two testing sets, COME-E and COME-H, which contain 4,600 and 3,000 image pairs respectively and can be downloaded at:
10 | https://drive.google.com/drive/folders/1w0M9YmYBzkMLijy_Blg6RMRshSvPZju-?usp=sharing
11 |
12 | Training data (Baidu Netdisk):
13 | Link: https://pan.baidu.com/s/15vaAkGuLVYPGuuYujDuhXg Password: m2er
14 |
15 | Testing data (Baidu Netdisk):
16 | Link: https://pan.baidu.com/s/1Ohidx48adju5gMI_hkGzug Password: dofk
17 |
18 | # Rich Annotations
19 | 
20 |
21 | For both the new training and testing datasets, we provide binary ground-truth annotations, instance-level annotations, and saliency ranking (0,1,2,3,4, where 4 indicates the most salient instance). We also provide the raw annotations of five different annotators, so each image has five binary saliency annotations from five different annotators.
22 |
23 | # New benchmark
24 | With our new training dataset, we re-train existing RGB-D SOD models and test them on ten benchmark testing datasets: SSB (STERE), DES, NLPR, NJU2K, LFSD, SIP, DUT-RGBD, RedWeb-S, and our COME-E and COME-H. Saliency maps of the re-trained models are available at (constantly updating):
25 | https://drive.google.com/drive/folders/1lCE8OHeqNdjhE4--yR0FFib2C5DBTgwn?usp=sharing
26 |
27 | # Retrain existing models
28 | We retrain state-of-the-art RGB-D SOD models with our new training dataset, and the re-trained models can be found at:
29 | https://drive.google.com/drive/folders/18Tqsn3yYoYO9HH8ZNVhHOTrJ7-UWPAZs?usp=sharing
30 |
31 | # Our model trained on the conventional training dataset (the combination of NLPR and NJU2K data), and the produced saliency maps on SSB (STERE), DES, NLPR, NJU2K, LFSD, SIP:
32 |
33 | model: https://drive.google.com/file/d/1gUubs1eGr2fnrlgze-EFhD9XbghfAyhK/view?usp=share_link
34 |
35 | maps: https://drive.google.com/file/d/1OPTc7NsGQq9uYBdfquLIbluMikWezYO8/view?usp=share_link
36 |
37 | Note that, because the model is stochastic, it can perform slightly differently each time it is trained.
38 | Solutions for obtaining deterministic models (see the sketch below):
39 | 1) instead of the reparameterization trick we use, you can follow a plain auto-encoder pipeline and map features directly to the embedding space; to achieve this, you will need to remove the variance mapping function;
40 | 2) or you can simply define the variance as 0, leading to deterministic generation of the latent code, which is easier to implement in practice.
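
A minimal sketch of option 2, assuming the `reparametrize(mu, logvar)` interface used in `model/ResNet_models_combine.py` (the `deterministic` flag below is illustrative, not part of the released code):

```python
import torch

def reparametrize(mu, logvar, deterministic=True):
    # Option 2: treat the variance as 0, so the latent code collapses to the mean
    # and repeated trainings no longer differ because of latent sampling.
    if deterministic:
        return mu
    # Original stochastic behaviour: z = mu + eps * std, with eps ~ N(0, I).
    std = logvar.mul(0.5).exp()
    eps = torch.randn_like(std)
    return mu + eps * std
```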
41 |
42 | # Our Bib:
43 |
44 | Please cite our paper if you use this work:
45 | ```
46 | @inproceedings{cascaded_rgbd_sod,
47 | title={RGB-D Saliency Detection via Cascaded Mutual Information Minimization},
48 | author={Zhang, Jing and Fan, Deng-Ping and Dai, Yuchao and Yu, Xin and Zhong, Yiran and Barnes, Nick and Shao, Ling},
49 | booktitle={International Conference on Computer Vision (ICCV)},
50 | year={2021}
51 | }
52 | ```
53 | # Copyright
54 | This work is licensed under a Creative Commons Attribution-NonCommercial-ShareAlike 3.0 Unported License.
55 |
56 | # Privacy
57 | This dataset is made available for academic use only. If you find yourself or your personal belongings in this dataset and feel uncomfortable about it, please contact us and we will immediately remove the respective data.
58 |
59 | # Contact
60 |
61 | Please drop me an email for questions or further discussion: zjnwpu@gmail.com
62 |
--------------------------------------------------------------------------------
/data.py:
--------------------------------------------------------------------------------
1 | import os
2 | from PIL import Image
3 | import torch.utils.data as data
4 | import torchvision.transforms as transforms
5 | import random
6 | import numpy as np
7 | from PIL import ImageEnhance
8 |
9 |
10 | # several data augmentation strategies
11 | def cv_random_flip(img, label, depth):
12 | flip_flag = random.randint(0, 1)
13 | # flip_flag2= random.randint(0,1)
14 | # left right flip
15 | if flip_flag == 1:
16 | img = img.transpose(Image.FLIP_LEFT_RIGHT)
17 | label = label.transpose(Image.FLIP_LEFT_RIGHT)
18 | depth = depth.transpose(Image.FLIP_LEFT_RIGHT)
19 | # top bottom flip
20 | # if flip_flag2==1:
21 | # img = img.transpose(Image.FLIP_TOP_BOTTOM)
22 | # label = label.transpose(Image.FLIP_TOP_BOTTOM)
23 | # depth = depth.transpose(Image.FLIP_TOP_BOTTOM)
24 | return img, label, depth
25 |
26 |
27 | def randomCrop(image, label, depth):
28 | border = 30
29 | image_width = image.size[0]
30 | image_height = image.size[1]
31 | crop_win_width = np.random.randint(image_width - border, image_width)
32 | crop_win_height = np.random.randint(image_height - border, image_height)
33 | random_region = (
34 | (image_width - crop_win_width) >> 1, (image_height - crop_win_height) >> 1, (image_width + crop_win_width) >> 1,
35 | (image_height + crop_win_height) >> 1)
36 | return image.crop(random_region), label.crop(random_region), depth.crop(random_region)
37 |
38 |
39 | def randomRotation(image, label, depth):
40 | mode = Image.BICUBIC
41 | if random.random() > 0.8:
42 | random_angle = np.random.randint(-15, 15)
43 | image = image.rotate(random_angle, mode)
44 | label = label.rotate(random_angle, mode)
45 | depth = depth.rotate(random_angle, mode)
46 | return image, label, depth
47 |
48 |
49 | def colorEnhance(image):
50 | bright_intensity = random.randint(5, 15) / 10.0
51 | image = ImageEnhance.Brightness(image).enhance(bright_intensity)
52 | contrast_intensity = random.randint(5, 15) / 10.0
53 | image = ImageEnhance.Contrast(image).enhance(contrast_intensity)
54 | color_intensity = random.randint(0, 20) / 10.0
55 | image = ImageEnhance.Color(image).enhance(color_intensity)
56 | sharp_intensity = random.randint(0, 30) / 10.0
57 | image = ImageEnhance.Sharpness(image).enhance(sharp_intensity)
58 | return image
59 |
60 |
61 | def randomGaussian(image, mean=0.1, sigma=0.35):
62 | def gaussianNoisy(im, mean=mean, sigma=sigma):
63 | for _i in range(len(im)):
64 | im[_i] += random.gauss(mean, sigma)
65 | return im
66 |
67 | img = np.asarray(image)
68 | width, height = img.shape
69 | img = gaussianNoisy(img[:].flatten(), mean, sigma)
70 | img = img.reshape([width, height])
71 | return Image.fromarray(np.uint8(img))
72 |
73 |
74 | def randomPeper(img):
75 | img = np.array(img)
76 | noiseNum = int(0.0015 * img.shape[0] * img.shape[1])
77 | for i in range(noiseNum):
78 |
79 | randX = random.randint(0, img.shape[0] - 1)
80 |
81 | randY = random.randint(0, img.shape[1] - 1)
82 |
83 | if random.randint(0, 1) == 0:
84 |
85 | img[randX, randY] = 0
86 |
87 | else:
88 |
89 | img[randX, randY] = 255
90 | return Image.fromarray(img)
91 |
92 |
93 | # dataset for training
94 | # The current loader is not using the normalized depth maps for training and test. If you use the normalized depth maps
95 | # (e.g., 0 represents background and 1 represents foreground.), the performance will be further improved.
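# A minimal sketch of the normalization suggested above (hypothetical helper, not
# used by the loader below): min-max scale each depth map to [0, 1]; depending on
# the sensor convention you may also need to invert it so foreground maps to 1.
def normalize_depth(depth_pil):
    d = np.asarray(depth_pil, dtype=np.float32)
    d = (d - d.min()) / (d.max() - d.min() + 1e-8)
    return Image.fromarray(np.uint8(d * 255))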
96 | class SalObjDataset(data.Dataset):
97 | def __init__(self, image_root, gt_root, depth_root, trainsize):
98 | self.trainsize = trainsize
99 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
100 | self.gts = [gt_root + f for f in os.listdir(gt_root) if f.endswith('.jpg')
101 | or f.endswith('.png')]
102 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
103 | or f.endswith('.png')]
104 | self.images = sorted(self.images)
105 | self.gts = sorted(self.gts)
106 | self.depths = sorted(self.depths)
107 | self.filter_files()
108 | self.size = len(self.images)
109 | self.img_transform = transforms.Compose([
110 | transforms.Resize((self.trainsize, self.trainsize)),
111 | transforms.ToTensor(),
112 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
113 | self.gt_transform = transforms.Compose([
114 | transforms.Resize((self.trainsize, self.trainsize)),
115 | transforms.ToTensor()])
116 | self.depths_transform = transforms.Compose(
117 | [transforms.Resize((self.trainsize, self.trainsize)), transforms.ToTensor()])
118 |
119 | def __getitem__(self, index):
120 | image = self.rgb_loader(self.images[index])
121 | gt = self.binary_loader(self.gts[index])
122 | depth = self.rgb_loader(self.depths[index])
123 | image, gt, depth = cv_random_flip(image, gt, depth)
124 | image, gt, depth = randomCrop(image, gt, depth)
125 | image, gt, depth = randomRotation(image, gt, depth)
126 | image = colorEnhance(image)
127 | # gt=randomGaussian(gt)
128 | gt = randomPeper(gt)
129 | image = self.img_transform(image)
130 | gt = self.gt_transform(gt)
131 | depth = self.depths_transform(depth)
132 |
133 | return image, gt, depth
134 |
135 | def filter_files(self):
136 | assert len(self.images) == len(self.gts) and len(self.gts) == len(self.depths)
137 | images = []
138 | gts = []
139 | depths = []
140 | for img_path, gt_path, depth_path in zip(self.images, self.gts, self.depths):
141 | img = Image.open(img_path)
142 | gt = Image.open(gt_path)
143 | depth = Image.open(depth_path)
144 | if img.size == gt.size and gt.size == depth.size:
145 | images.append(img_path)
146 | gts.append(gt_path)
147 | depths.append(depth_path)
148 | self.images = images
149 | self.gts = gts
150 | self.depths = depths
151 |
152 | def rgb_loader(self, path):
153 | with open(path, 'rb') as f:
154 | img = Image.open(f)
155 | return img.convert('RGB')
156 |
157 | def binary_loader(self, path):
158 | with open(path, 'rb') as f:
159 | img = Image.open(f)
160 | return img.convert('L')
161 |
162 | def resize(self, img, gt, depth):
163 | assert img.size == gt.size and gt.size == depth.size
164 | w, h = img.size
165 | if h < self.trainsize or w < self.trainsize:
166 | h = max(h, self.trainsize)
167 | w = max(w, self.trainsize)
168 | return img.resize((w, h), Image.BILINEAR), gt.resize((w, h), Image.NEAREST), depth.resize((w, h),
169 | Image.NEAREST)
170 | else:
171 | return img, gt, depth
172 |
173 | def __len__(self):
174 | return self.size
175 |
176 |
177 | # dataloader for training
178 | def get_loader(image_root, gt_root, depth_root, batchsize, trainsize, shuffle=True, num_workers=12, pin_memory=True):
179 | dataset = SalObjDataset(image_root, gt_root, depth_root, trainsize)
180 | data_loader = data.DataLoader(dataset=dataset,
181 | batch_size=batchsize,
182 | shuffle=shuffle,
183 | num_workers=num_workers,
184 | pin_memory=pin_memory)
185 | return data_loader
186 |
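# Example usage (hypothetical paths, for illustration only):
# train_loader = get_loader('./COME-TR/RGB/', './COME-TR/GT/', './COME-TR/depths/',
#                           batchsize=10, trainsize=352)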
187 |
188 | # test dataset and loader
189 | class test_dataset:
190 | def __init__(self, image_root, depth_root, testsize):
191 | self.testsize = testsize
192 | self.images = [image_root + f for f in os.listdir(image_root) if f.endswith('.jpg')]
193 | self.depths = [depth_root + f for f in os.listdir(depth_root) if f.endswith('.bmp')
194 | or f.endswith('.png')]
195 | self.images = sorted(self.images)
196 | self.depths = sorted(self.depths)
197 | self.transform = transforms.Compose([
198 | transforms.Resize((self.testsize, self.testsize)),
199 | transforms.ToTensor(),
200 | transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
201 | # self.gt_transform = transforms.Compose([
202 | # transforms.Resize((self.trainsize, self.trainsize)),
203 | # transforms.ToTensor()])
204 | self.depths_transform = transforms.Compose(
205 | [transforms.Resize((self.testsize, self.testsize)), transforms.ToTensor()])
206 | self.size = len(self.images)
207 | self.index = 0
208 |
209 | def load_data(self):
210 | image = self.rgb_loader(self.images[self.index])
211 | HH = image.size[0]  # PIL size is (width, height), so HH is the original width
212 | WW = image.size[1]  # and WW is the original height
213 | image = self.transform(image).unsqueeze(0)
214 | depth = self.rgb_loader(self.depths[self.index])
215 | depth = self.depths_transform(depth).unsqueeze(0)
216 |
217 | name = self.images[self.index].split('/')[-1]
218 | # image_for_post=self.rgb_loader(self.images[self.index])
219 | # image_for_post=image_for_post.resize(gt.size)
220 | if name.endswith('.jpg'):
221 | name = name.split('.jpg')[0] + '.png'
222 | self.index += 1
223 | self.index = self.index % self.size
224 | return image, depth, HH, WW, name
225 |
226 | def rgb_loader(self, path):
227 | with open(path, 'rb') as f:
228 | img = Image.open(f)
229 | return img.convert('RGB')
230 |
231 | def binary_loader(self, path):
232 | with open(path, 'rb') as f:
233 | img = Image.open(f)
234 | return img.convert('L')
235 |
236 | def __len__(self):
237 | return self.size
238 |
239 |
--------------------------------------------------------------------------------
/model/HolisticAttention.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import torch.nn as nn
4 | from torch.nn.parameter import Parameter
5 |
6 | import numpy as np
7 | import scipy.stats as st
8 |
9 |
10 | def gkern(kernlen=16, nsig=3):
11 | interval = (2*nsig+1.)/kernlen
12 | x = np.linspace(-nsig-interval/2., nsig+interval/2., kernlen+1)
13 | kern1d = np.diff(st.norm.cdf(x))
14 | kernel_raw = np.sqrt(np.outer(kern1d, kern1d))
15 | kernel = kernel_raw/kernel_raw.sum()
16 | return kernel
17 |
18 |
19 | def min_max_norm(in_):
20 | max_ = in_.max(3)[0].max(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
21 | min_ = in_.min(3)[0].min(2)[0].unsqueeze(2).unsqueeze(3).expand_as(in_)
22 | in_ = in_ - min_
23 | return in_.div(max_-min_+1e-8)
24 |
25 |
26 | class HA(nn.Module):
27 | # holistic attention module
28 | def __init__(self):
29 | super(HA, self).__init__()
30 | gaussian_kernel = np.float32(gkern(31, 4))
31 | gaussian_kernel = gaussian_kernel[np.newaxis, np.newaxis, ...]
32 | self.gaussian_kernel = Parameter(torch.from_numpy(gaussian_kernel))
33 |
34 | def forward(self, attention, x):
35 | soft_attention = F.conv2d(attention, self.gaussian_kernel, padding=15)  # blur the attention map with the 31x31 Gaussian kernel
36 | soft_attention = min_max_norm(soft_attention)  # rescale the blurred map to [0, 1]
37 | x = torch.mul(x, soft_attention.max(attention))  # re-weight features by the element-wise max of blurred and original attention
38 | return x
39 |
--------------------------------------------------------------------------------
/model/ResNet.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import math
3 |
4 |
5 | def conv3x3(in_planes, out_planes, stride=1):
6 | """3x3 convolution with padding"""
7 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
8 | padding=1, bias=False)
9 |
10 |
11 | class BasicBlock(nn.Module):
12 | expansion = 1
13 |
14 | def __init__(self, inplanes, planes, stride=1, downsample=None):
15 | super(BasicBlock, self).__init__()
16 | self.conv1 = conv3x3(inplanes, planes, stride)
17 | self.bn1 = nn.BatchNorm2d(planes)
18 | self.relu = nn.ReLU(inplace=True)
19 | self.conv2 = conv3x3(planes, planes)
20 | self.bn2 = nn.BatchNorm2d(planes)
21 | self.downsample = downsample
22 | self.stride = stride
23 |
24 | def forward(self, x):
25 | residual = x
26 |
27 | out = self.conv1(x)
28 | out = self.bn1(out)
29 | out = self.relu(out)
30 |
31 | out = self.conv2(out)
32 | out = self.bn2(out)
33 |
34 | if self.downsample is not None:
35 | residual = self.downsample(x)
36 |
37 | out += residual
38 | out = self.relu(out)
39 |
40 | return out
41 |
42 |
43 | class Bottleneck(nn.Module):
44 | expansion = 4
45 |
46 | def __init__(self, inplanes, planes, stride=1, downsample=None):
47 | super(Bottleneck, self).__init__()
48 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
49 | self.bn1 = nn.BatchNorm2d(planes)
50 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
51 | padding=1, bias=False)
52 | self.bn2 = nn.BatchNorm2d(planes)
53 | self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
54 | self.bn3 = nn.BatchNorm2d(planes * 4)
55 | self.relu = nn.ReLU(inplace=True)
56 | self.downsample = downsample
57 | self.stride = stride
58 |
59 | def forward(self, x):
60 | residual = x
61 |
62 | out = self.conv1(x)
63 | out = self.bn1(out)
64 | out = self.relu(out)
65 |
66 | out = self.conv2(out)
67 | out = self.bn2(out)
68 | out = self.relu(out)
69 |
70 | out = self.conv3(out)
71 | out = self.bn3(out)
72 |
73 | if self.downsample is not None:
74 | residual = self.downsample(x)
75 |
76 | out += residual
77 | out = self.relu(out)
78 |
79 | return out
80 |
81 |
82 | class B2_ResNet(nn.Module):
83 | # ResNet50 with two branches
84 | def __init__(self):
85 | # self.inplanes = 128
86 | self.inplanes = 64
87 | super(B2_ResNet, self).__init__()
88 |
89 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
90 | bias=False)
91 | self.bn1 = nn.BatchNorm2d(64)
92 | self.relu = nn.ReLU(inplace=True)
93 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
94 | self.layer1 = self._make_layer(Bottleneck, 64, 3)
95 | self.layer2 = self._make_layer(Bottleneck, 128, 4, stride=2)
96 | self.layer3_1 = self._make_layer(Bottleneck, 256, 6, stride=2)
97 | self.layer4_1 = self._make_layer(Bottleneck, 512, 3, stride=2)
98 |
99 | self.inplanes = 512
100 | self.layer3_2 = self._make_layer(Bottleneck, 256, 6, stride=2)
101 | self.layer4_2 = self._make_layer(Bottleneck, 512, 3, stride=2)
102 |
103 | for m in self.modules():
104 | if isinstance(m, nn.Conv2d):
105 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
106 | m.weight.data.normal_(0, math.sqrt(2. / n))
107 | elif isinstance(m, nn.BatchNorm2d):
108 | m.weight.data.fill_(1)
109 | m.bias.data.zero_()
110 |
111 | def _make_layer(self, block, planes, blocks, stride=1):
112 | downsample = None
113 | if stride != 1 or self.inplanes != planes * block.expansion:
114 | downsample = nn.Sequential(
115 | nn.Conv2d(self.inplanes, planes * block.expansion,
116 | kernel_size=1, stride=stride, bias=False),
117 | nn.BatchNorm2d(planes * block.expansion),
118 | )
119 |
120 | layers = []
121 | layers.append(block(self.inplanes, planes, stride, downsample))
122 | self.inplanes = planes * block.expansion
123 | for i in range(1, blocks):
124 | layers.append(block(self.inplanes, planes))
125 |
126 | return nn.Sequential(*layers)
127 |
128 | def forward(self, x):
129 | x = self.conv1(x)
130 | x = self.bn1(x)
131 | x = self.relu(x)
132 | x = self.maxpool(x)
133 |
134 | x = self.layer1(x)
135 | x = self.layer2(x)
136 | x1 = self.layer3_1(x)
137 | x1 = self.layer4_1(x1)
138 |
139 | x2 = self.layer3_2(x)
140 | x2 = self.layer4_2(x2)
141 |
142 | return x1, x2
143 |
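# A quick shape check (sketch, not part of the original file): with a 3x256x256
# input, both branches end in 2048-channel maps at 1/32 resolution (8x8 here).
if __name__ == '__main__':
    import torch
    net = B2_ResNet()
    y1, y2 = net(torch.randn(1, 3, 256, 256))
    print(y1.shape, y2.shape)  # torch.Size([1, 2048, 8, 8]) for both branches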
--------------------------------------------------------------------------------
/model/ResNet_models_combine.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 | import numpy as np
5 | from model.ResNet import B2_ResNet
6 | from utils import init_weights,init_weights_orthogonal_normal
7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8 | from torch.autograd import Variable
9 | from torch.nn import Parameter, Softmax
10 | import torch.nn.functional as F
11 | from torch.distributions import Normal, Independent, kl
12 | from model.HolisticAttention import HA
13 | import math
14 | CE = torch.nn.BCELoss(reduction='sum')
15 | cos_sim = torch.nn.CosineSimilarity(dim=1,eps=1e-8)
16 | from model.batchrenorm import BatchRenorm2d
17 |
18 | class BasicConv2d(nn.Module):
19 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
20 | super(BasicConv2d, self).__init__()
21 | self.conv = nn.Conv2d(in_planes, out_planes,
22 | kernel_size=kernel_size, stride=stride,
23 | padding=padding, dilation=dilation, bias=False)
24 | self.bn = nn.BatchNorm2d(out_planes)
25 | self.relu = nn.ReLU(inplace=True)
26 |
27 | def forward(self, x):
28 | x = self.conv(x)
29 | x = self.bn(x)
30 | return x
31 |
32 | class Classifier_Module(nn.Module):
33 | def __init__(self,dilation_series,padding_series,NoLabels, input_channel):
34 | super(Classifier_Module, self).__init__()
35 | self.conv2d_list = nn.ModuleList()
36 | for dilation,padding in zip(dilation_series,padding_series):
37 | self.conv2d_list.append(nn.Conv2d(input_channel,NoLabels,kernel_size=3,stride=1, padding =padding, dilation = dilation,bias = True))
38 | for m in self.conv2d_list:
39 | m.weight.data.normal_(0, 0.01)
40 |
41 | def forward(self, x):
42 | out = self.conv2d_list[0](x)
43 | for i in range(len(self.conv2d_list)-1):
44 | out += self.conv2d_list[i+1](x)
45 | return out
46 |
47 | class Mutual_info_reg(nn.Module):
48 | def __init__(self, input_channels, channels, latent_size):
49 | super(Mutual_info_reg, self).__init__()
50 | self.contracting_path = nn.ModuleList()
51 | self.input_channels = input_channels
52 | self.relu = nn.ReLU(inplace=True)
53 | self.layer1 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1)
54 | self.bn1 = nn.BatchNorm2d(channels)
55 | self.layer2 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1)
56 | self.bn2 = nn.BatchNorm2d(channels)
57 | self.layer3 = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1)
58 | self.layer4 = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1)
59 |
60 | self.channel = channels
61 |
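# Three sets of fully connected heads follow, one per possible spatial size of the
# downsampled feature maps (16x16, 22x22 or 28x28); forward() selects the pair
# that matches the incoming resolution.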
62 | self.fc1_rgb1 = nn.Linear(channels * 1 * 16 * 16, latent_size)
63 | self.fc2_rgb1 = nn.Linear(channels * 1 * 16 * 16, latent_size)
64 | self.fc1_depth1 = nn.Linear(channels * 1 * 16 * 16, latent_size)
65 | self.fc2_depth1 = nn.Linear(channels * 1 * 16 * 16, latent_size)
66 |
67 | self.fc1_rgb2 = nn.Linear(channels * 1 * 22 * 22, latent_size)
68 | self.fc2_rgb2 = nn.Linear(channels * 1 * 22 * 22, latent_size)
69 | self.fc1_depth2 = nn.Linear(channels * 1 * 22 * 22, latent_size)
70 | self.fc2_depth2 = nn.Linear(channels * 1 * 22 * 22, latent_size)
71 |
72 | self.fc1_rgb3 = nn.Linear(channels * 1 * 28 * 28, latent_size)
73 | self.fc2_rgb3 = nn.Linear(channels * 1 * 28 * 28, latent_size)
74 | self.fc1_depth3 = nn.Linear(channels * 1 * 28 * 28, latent_size)
75 | self.fc2_depth3 = nn.Linear(channels * 1 * 28 * 28, latent_size)
76 |
77 | self.leakyrelu = nn.LeakyReLU()
78 | self.tanh = torch.nn.Tanh()
79 |
80 | def kl_divergence(self, posterior_latent_space, prior_latent_space):
81 | kl_div = kl.kl_divergence(posterior_latent_space, prior_latent_space)
82 | return kl_div
83 |
84 | def reparametrize(self, mu, logvar):
85 | std = logvar.mul(0.5).exp_()
86 | eps = torch.randn_like(std)  # device-agnostic sampling; the original torch.cuda.FloatTensor(...).normal_() required CUDA
87 | eps = Variable(eps)
88 | return eps.mul(std).add_(mu)
89 |
90 | def forward(self, rgb_feat, depth_feat):
91 | rgb_feat = self.layer3(self.leakyrelu(self.bn1(self.layer1(rgb_feat))))
92 | depth_feat = self.layer4(self.leakyrelu(self.bn2(self.layer2(depth_feat))))
93 | # print(rgb_feat.size())
94 | # print(depth_feat.size())
95 | if rgb_feat.shape[2] == 16:
96 | rgb_feat = rgb_feat.view(-1, self.channel * 1 * 16 * 16)
97 | depth_feat = depth_feat.view(-1, self.channel * 1 * 16 * 16)
98 |
99 | mu_rgb = self.fc1_rgb1(rgb_feat)
100 | logvar_rgb = self.fc2_rgb1(rgb_feat)
101 | mu_depth = self.fc1_depth1(depth_feat)
102 | logvar_depth = self.fc2_depth1(depth_feat)
103 | elif rgb_feat.shape[2] == 22:
104 | rgb_feat = rgb_feat.view(-1, self.channel * 1 * 22 * 22)
105 | depth_feat = depth_feat.view(-1, self.channel * 1 * 22 * 22)
106 | mu_rgb = self.fc1_rgb2(rgb_feat)
107 | logvar_rgb = self.fc2_rgb2(rgb_feat)
108 | mu_depth = self.fc1_depth2(depth_feat)
109 | logvar_depth = self.fc2_depth2(depth_feat)
110 | else:
111 | rgb_feat = rgb_feat.view(-1, self.channel * 1 * 28 * 28)
112 | depth_feat = depth_feat.view(-1, self.channel * 1 * 28 * 28)
113 | mu_rgb = self.fc1_rgb3(rgb_feat)
114 | logvar_rgb = self.fc2_rgb3(rgb_feat)
115 | mu_depth = self.fc1_depth3(depth_feat)
116 | logvar_depth = self.fc2_depth3(depth_feat)
117 |
118 | mu_depth = self.tanh(mu_depth)
119 | mu_rgb = self.tanh(mu_rgb)
120 | logvar_depth = self.tanh(logvar_depth)
121 | logvar_rgb = self.tanh(logvar_rgb)
122 | z_rgb = self.reparametrize(mu_rgb, logvar_rgb)
123 | dist_rgb = Independent(Normal(loc=mu_rgb, scale=torch.exp(logvar_rgb)), 1)
124 | z_depth = self.reparametrize(mu_depth, logvar_depth)
125 | dist_depth = Independent(Normal(loc=mu_depth, scale=torch.exp(logvar_depth)), 1)
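# Symmetric KL between the RGB and depth latent Gaussians, plus symmetric
# cross-entropy between the sigmoid-squashed latent codes (each direction detaches
# its target); together these form the latent regularizer returned below.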
126 | bi_di_kld = torch.mean(self.kl_divergence(dist_rgb, dist_depth)) + torch.mean(
127 | self.kl_divergence(dist_depth, dist_rgb))
128 | z_rgb_norm = torch.sigmoid(z_rgb)
129 | z_depth_norm = torch.sigmoid(z_depth)
130 | ce_rgb_depth = CE(z_rgb_norm,z_depth_norm.detach())
131 | ce_depth_rgb = CE(z_depth_norm, z_rgb_norm.detach())
132 | latent_loss = ce_rgb_depth+ce_depth_rgb-bi_di_kld
133 | # latent_loss = torch.abs(cos_sim(z_rgb,z_depth)).sum()
134 |
135 | return latent_loss, z_rgb, z_depth
136 |
137 |
138 |
139 | class CAM_Module(nn.Module):
140 | """ Channel attention module"""
141 | def __init__(self):
142 | super(CAM_Module, self).__init__()
143 | self.gamma = Parameter(torch.zeros(1))
144 | self.softmax = Softmax(dim=-1)
145 | def forward(self,x):
146 | """
147 | inputs :
148 | x : input feature maps( B X C X H X W)
149 | returns :
150 | out : attention value + input feature
151 | attention: B X C X C
152 | """
153 | m_batchsize, C, height, width = x.size()
154 | proj_query = x.view(m_batchsize, C, -1)
155 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
156 | energy = torch.bmm(proj_query, proj_key)
157 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
158 | attention = self.softmax(energy_new)
159 | proj_value = x.view(m_batchsize, C, -1)
160 |
161 | out = torch.bmm(attention, proj_value)
162 | out = out.view(m_batchsize, C, height, width)
163 |
164 | out = self.gamma*out + x
165 | return out
166 |
167 | ## Channel Attention (CA) Layer
168 | class CALayer(nn.Module):
169 | def __init__(self, channel, reduction=16):
170 | super(CALayer, self).__init__()
171 | # global average pooling: feature --> point
172 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
173 | # feature channel downscale and upscale --> channel weight
174 | self.conv_du = nn.Sequential(
175 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
176 | nn.ReLU(inplace=True),
177 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
178 | nn.Sigmoid()
179 | )
180 |
181 | def forward(self, x):
182 | y = self.avg_pool(x)
183 | y = self.conv_du(y)
184 | return x * y
185 |
186 | ## Residual Channel Attention Block (RCAB)
187 |
188 |
189 |
190 | class RCAB(nn.Module):
191 | def __init__(
192 | self, n_feat, kernel_size=3, reduction=16,
193 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
194 |
195 | super(RCAB, self).__init__()
196 | modules_body = []
197 | for i in range(2):
198 | modules_body.append(self.default_conv(n_feat, n_feat, kernel_size, bias=bias))
199 | if bn: modules_body.append(nn.BatchNorm2d(n_feat))
200 | if i == 0: modules_body.append(act)
201 | modules_body.append(CALayer(n_feat, reduction))
202 | self.body = nn.Sequential(*modules_body)
203 | self.res_scale = res_scale
204 |
205 | def default_conv(self, in_channels, out_channels, kernel_size, bias=True):
206 | return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size // 2), bias=bias)
207 |
208 | def forward(self, x):
209 | res = self.body(x)
210 | #res = self.body(x).mul(self.res_scale)
211 | res += x
212 | return res
213 |
214 |
215 |
216 | class Saliency_feat_decoder(nn.Module):
217 | # resnet based encoder decoder
218 | def __init__(self, channel,latent_dim):
219 | super(Saliency_feat_decoder, self).__init__()
220 | self.relu = nn.ReLU(inplace=True)
221 | self.dropout = nn.Dropout(0.3)
222 |
223 | self.layer5 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 2048)
224 | self.layer6 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], 1, channel)
225 | self.layer7 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], 1, channel)
226 |
227 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
228 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
229 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
230 |
231 | self.conv1 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 256)
232 | self.conv2 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 512)
233 | self.conv3 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 1024)
234 | self.conv4 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 2048)
235 |
236 | self.spatial_axes = [2, 3]
237 |
238 | self.racb_43 = RCAB(channel * 2)
239 | self.racb_432 = RCAB(channel * 3)
240 | self.racb_4321 = RCAB(channel * 4)
241 |
242 | self.conv43 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 2*channel)
243 | self.conv432 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 3*channel)
244 | self.conv4321 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 4*channel)
245 |
246 | self.layer_depth = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], 3, channel * 4)
247 |
248 | self.rcab_z1 = RCAB(channel + latent_dim)
249 | self.conv_z1 = BasicConv2d(channel+latent_dim,channel,3,padding=1)
250 |
251 | self.rcab_z2 = RCAB(channel + latent_dim)
252 | self.conv_z2 = BasicConv2d(channel + latent_dim, channel, 3, padding=1)
253 |
254 | self.rcab_z3 = RCAB(channel + latent_dim)
255 | self.conv_z3 = BasicConv2d(channel + latent_dim, channel, 3, padding=1)
256 |
257 | self.rcab_z4 = RCAB(channel + latent_dim)
258 | self.conv_z4 = BasicConv2d(channel + latent_dim, channel, 3, padding=1)
259 |
260 | self.br1 = BatchRenorm2d(channel)
261 | self.br2 = BatchRenorm2d(channel)
262 | self.br3 = BatchRenorm2d(channel)
263 | self.br4 = BatchRenorm2d(channel)
264 |
265 |
266 | def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel):
267 | return block(dilation_series, padding_series, NoLabels, input_channel)
268 |
269 | def tile(self, a, dim, n_tile):
270 | """
271 | This function is taken from the PyTorch forum and mimics the behavior of tf.tile.
272 | Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3
273 | """
274 | init_dim = a.size(dim)
275 | repeat_idx = [1] * a.dim()
276 | repeat_idx[dim] = n_tile
277 | a = a.repeat(*(repeat_idx))
278 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device)
279 | return torch.index_select(a, dim, order_index)
280 |
281 | def forward(self, x1,x2,x3,x4,z1=None,z2=None,z3=None,z4=None):
282 | conv1_feat = self.br1(self.conv1(x1))
283 | conv2_feat = self.br2(self.conv2(x2))
284 | conv3_feat = self.br3(self.conv3(x3))
285 | conv4_feat = self.br4(self.conv4(x4))
286 |
287 | if z1!=None:
288 | z1 = torch.unsqueeze(z1, 2)
289 | z1 = self.tile(z1, 2, conv1_feat.shape[self.spatial_axes[0]])
290 | z1 = torch.unsqueeze(z1, 3)
291 | z1 = self.tile(z1, 3, conv1_feat.shape[self.spatial_axes[1]])
292 |
293 | z2 = torch.unsqueeze(z2, 2)
294 | z2 = self.tile(z2, 2, conv2_feat.shape[self.spatial_axes[0]])
295 | z2 = torch.unsqueeze(z2, 3)
296 | z2 = self.tile(z2, 3, conv2_feat.shape[self.spatial_axes[1]])
297 |
298 | z3 = torch.unsqueeze(z3, 2)
299 | z3 = self.tile(z3, 2, conv3_feat.shape[self.spatial_axes[0]])
300 | z3 = torch.unsqueeze(z3, 3)
301 | z3 = self.tile(z3, 3, conv3_feat.shape[self.spatial_axes[1]])
302 |
303 | z4 = torch.unsqueeze(z4, 2)
304 | z4 = self.tile(z4, 2, conv4_feat.shape[self.spatial_axes[0]])
305 | z4 = torch.unsqueeze(z4, 3)
306 | z4 = self.tile(z4, 3, conv4_feat.shape[self.spatial_axes[1]])
307 |
308 | conv1_feat = torch.cat((conv1_feat,z1),1)
309 | conv1_feat = self.rcab_z1(conv1_feat)
310 | conv1_feat = self.conv_z1(conv1_feat)
311 |
312 | conv2_feat = torch.cat((conv2_feat, z2), 1)
313 | conv2_feat = self.rcab_z2(conv2_feat)
314 | conv2_feat = self.conv_z2(conv2_feat)
315 |
316 | conv3_feat = torch.cat((conv3_feat, z3), 1)
317 | conv3_feat = self.rcab_z3(conv3_feat)
318 | conv3_feat = self.conv_z3(conv3_feat)
319 |
320 | conv4_feat = torch.cat((conv4_feat, z4), 1)
321 | conv4_feat = self.rcab_z4(conv4_feat)
322 | conv4_feat = self.conv_z4(conv4_feat)
323 |
324 | conv4_feat = self.upsample2(conv4_feat)
325 |
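# Coarse-to-fine aggregation: the deepest features are repeatedly upsampled and
# concatenated with shallower ones (4 -> 43 -> 432 -> 4321), each fusion refined
# by an RCAB before the final single-channel prediction from layer6.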
326 | conv43 = torch.cat((conv4_feat, conv3_feat), 1)
327 | conv43 = self.racb_43(conv43)
328 | conv43 = self.conv43(conv43)
329 |
330 | conv43 = self.upsample2(conv43)
331 | conv432 = torch.cat((self.upsample2(conv4_feat), conv43, conv2_feat), 1)
332 | conv432 = self.racb_432(conv432)
333 | conv432 = self.conv432(conv432)
334 |
335 | conv432 = self.upsample2(conv432)
336 | conv4321 = torch.cat((self.upsample4(conv4_feat), self.upsample2(conv43), conv432, conv1_feat), 1)
337 | conv4321 = self.racb_4321(conv4321)
338 | conv4321 = self.conv4321(conv4321)
339 |
340 | sal_init = self.layer6(conv4321)
341 |
342 | return sal_init
343 |
344 |
345 | class Saliency_feat_endecoder(nn.Module):
346 | # resnet based encoder decoder
347 | def __init__(self, channel):
348 | super(Saliency_feat_endecoder, self).__init__()
349 | self.resnet_rgb = B2_ResNet()
350 | self.resnet_depth = B2_ResNet()
351 | self.relu = nn.ReLU(inplace=True)
352 | self.dropout = nn.Dropout(0.3)
353 | self.latent_dim = 6
354 | self.conv_depth1 = BasicConv2d(6, 3, kernel_size=3, padding=1)
355 | self.sal_decoder1 = Saliency_feat_decoder(channel, self.latent_dim)
356 | self.sal_decoder2 = Saliency_feat_decoder(channel, self.latent_dim)
357 | self.sal_decoder3 = Saliency_feat_decoder(channel, self.latent_dim)
358 | self.sal_decoder4 = Saliency_feat_decoder(channel, self.latent_dim)
359 | self.sal_decoder5 = Saliency_feat_decoder(channel, self.latent_dim)
360 | self.sal_decoder6 = Saliency_feat_decoder(channel, self.latent_dim)
361 |
362 | self.HA = HA()
363 | self.upsample05 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
364 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
365 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
366 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
367 | self.upsample025 = nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True)
368 | self.upsample0125 = nn.Upsample(scale_factor=0.125, mode='bilinear', align_corners=True)
369 |
370 |
371 | self.convx1_depth = nn.Conv2d(in_channels=256, out_channels=channel, kernel_size=3, padding=1)
372 | self.convx2_depth = nn.Conv2d(in_channels=512, out_channels=channel, kernel_size=3, padding=1)
373 | self.convx3_depth = nn.Conv2d(in_channels=1024, out_channels=channel, kernel_size=3, padding=1)
374 | self.convx4_depth = nn.Conv2d(in_channels=2048, out_channels=channel, kernel_size=3, padding=1)
375 |
376 | self.convx1_rgb = nn.Conv2d(in_channels=256, out_channels=channel, kernel_size=3, padding=1)
377 | self.convx2_rgb = nn.Conv2d(in_channels=512, out_channels=channel, kernel_size=3, padding=1)
378 | self.convx3_rgb = nn.Conv2d(in_channels=1024, out_channels=channel, kernel_size=3, padding=1)
379 | self.convx4_rgb = nn.Conv2d(in_channels=2048, out_channels=channel, kernel_size=3, padding=1)
380 |
381 | self.mi_level1 = Mutual_info_reg(channel,channel,self.latent_dim)
382 | self.mi_level2 = Mutual_info_reg(channel, channel, self.latent_dim)
383 | self.mi_level3 = Mutual_info_reg(channel, channel, self.latent_dim)
384 | self.mi_level4 = Mutual_info_reg(channel, channel, self.latent_dim)
385 |
386 | self.spatial_axes = [2, 3]
387 | self.final_clc = nn.Conv2d(in_channels=4, out_channels=1, kernel_size=3, padding=1)
388 | self.rcab_rgb_feat = RCAB(channel*4)
389 | self.rcab_depth_feat = RCAB(channel*4)
390 |
391 |
392 |
393 |
394 | if self.training:
395 | self.initialize_weights()
396 |
397 | def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel):
398 | return block(dilation_series, padding_series, NoLabels, input_channel)
399 |
400 | def tile(self, a, dim, n_tile):
401 | """
402 | This function is taken from the PyTorch forum and mimics the behavior of tf.tile.
403 | Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3
404 | """
405 | init_dim = a.size(dim)
406 | repeat_idx = [1] * a.dim()
407 | repeat_idx[dim] = n_tile
408 | a = a.repeat(*(repeat_idx))
409 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device)
410 | return torch.index_select(a, dim, order_index)
411 |
412 | def forward(self, x,depth=None):
413 | raw_x = x
414 | x = self.resnet_rgb.conv1(x)
415 | x = self.resnet_rgb.bn1(x)
416 | x = self.resnet_rgb.relu(x)
417 | x = self.resnet_rgb.maxpool(x)
418 | x1_rgb = self.resnet_rgb.layer1(x) # 256 x 64 x 64
419 | x2_rgb = self.resnet_rgb.layer2(x1_rgb) # 512 x 32 x 32
420 | x3_rgb = self.resnet_rgb.layer3_1(x2_rgb) # 1024 x 16 x 16
421 | x4_rgb = self.resnet_rgb.layer4_1(x3_rgb) # 2048 x 8 x 8
422 |
423 | sal_init_rgb = self.sal_decoder1(x1_rgb, x2_rgb, x3_rgb, x4_rgb)
424 | x2_2_rgb = self.HA(self.upsample05(sal_init_rgb).sigmoid(), x2_rgb)
425 | x3_2_rgb = self.resnet_rgb.layer3_2(x2_2_rgb) # 1024 x 16 x 16
426 | x4_2_rgb = self.resnet_rgb.layer4_2(x3_2_rgb) # 2048 x 8 x 8
427 | sal_ref_rgb = self.sal_decoder2(x1_rgb, x2_2_rgb, x3_2_rgb, x4_2_rgb)
428 |
429 | if depth==None:
430 | return self.upsample4(sal_init_rgb), self.upsample4(sal_ref_rgb)
431 | else:
432 | x = torch.cat((raw_x,depth),1)
433 | x = self.conv_depth1(x)
434 | x = self.resnet_depth.conv1(x)
435 | x = self.resnet_depth.bn1(x)
436 | x = self.resnet_depth.relu(x)
437 | x = self.resnet_depth.maxpool(x)
438 | x1_depth = self.resnet_depth.layer1(x) # 256 x 64 x 64
439 | x2_depth = self.resnet_depth.layer2(x1_depth) # 512 x 32 x 32
440 | x3_depth = self.resnet_depth.layer3_1(x2_depth) # 1024 x 16 x 16
441 | x4_depth = self.resnet_depth.layer4_1(x3_depth) # 2048 x 8 x 8
442 |
443 | sal_init_depth = self.sal_decoder3(x1_depth, x2_depth, x3_depth, x4_depth)
444 | x2_2_depth = self.HA(self.upsample05(sal_init_depth).sigmoid(), x2_depth)
445 | x3_2_depth = self.resnet_depth.layer3_2(x2_2_depth) # 1024 x 16 x 16
446 | x4_2_depth = self.resnet_depth.layer4_2(x3_2_depth) # 2048 x 8 x 8
447 | sal_ref_depth = self.sal_decoder4(x1_depth, x2_2_depth, x3_2_depth, x4_2_depth)
448 |
449 |
450 | lat_loss1, z1_rgb, z1_depth = self.mi_level1(self.convx1_rgb(x1_rgb), self.convx1_depth(x1_depth))
451 | lat_loss2, z2_rgb, z2_depth = self.mi_level2(self.upsample2(self.convx2_rgb(x2_2_rgb)), self.upsample2(self.convx2_depth(x2_2_depth)))
452 | lat_loss3, z3_rgb, z3_depth = self.mi_level3(self.upsample4(self.convx3_rgb(x3_2_rgb)), self.upsample4(self.convx3_depth(x3_2_depth)))
453 | lat_loss4, z4_rgb, z4_depth = self.mi_level4(self.upsample8(self.convx4_rgb(x4_2_rgb)), self.upsample8(self.convx4_depth(x4_2_depth)))
454 |
455 | lat_loss = lat_loss1+lat_loss2+lat_loss3+lat_loss4
456 |
457 | sal_mi_rgb = self.sal_decoder5(x1_rgb, x2_2_rgb, x3_2_rgb, x4_2_rgb, z1_depth,z2_depth,z3_depth,z4_depth)
458 |
459 | sal_mi_depth = self.sal_decoder6(x1_depth, x2_2_depth, x3_2_depth, x4_2_depth, z1_rgb,z2_rgb,z3_rgb,z4_rgb)
460 |
461 | final_sal = torch.cat((sal_ref_rgb,sal_ref_depth,sal_mi_rgb,sal_mi_depth),1)
462 | final_sal = self.final_clc(final_sal)
463 |
464 | return self.upsample4(sal_init_rgb), self.upsample4(sal_ref_rgb), self.upsample4(sal_init_depth), self.upsample4(
465 | sal_ref_depth), self.upsample4(sal_mi_rgb), self.upsample4(sal_mi_depth), self.upsample4(final_sal), lat_loss
466 |
467 |
468 |
469 | def initialize_weights(self):
470 | res50 = models.resnet50(pretrained=True)
471 | pretrained_dict = res50.state_dict()
472 | all_params = {}
473 | for k, v in self.resnet_rgb.state_dict().items():
474 | if k in pretrained_dict.keys():
475 | v = pretrained_dict[k]
476 | all_params[k] = v
477 | elif '_1' in k:
478 | name = k.split('_1')[0] + k.split('_1')[1]
479 | v = pretrained_dict[name]
480 | all_params[k] = v
481 | elif '_2' in k:
482 | name = k.split('_2')[0] + k.split('_2')[1]
483 | v = pretrained_dict[name]
484 | all_params[k] = v
485 | assert len(all_params.keys()) == len(self.resnet_rgb.state_dict().keys())
486 | self.resnet_rgb.load_state_dict(all_params)
487 | self.resnet_depth.load_state_dict(all_params)
488 |
--------------------------------------------------------------------------------
/model/ResNet_models_sep.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torchvision.models as models
4 | import numpy as np
5 | from model.ResNet import B2_ResNet
6 | from utils import init_weights,init_weights_orthogonal_normal
7 | device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
8 | from torch.autograd import Variable
9 | from torch.nn import Parameter, Softmax
10 | import torch.nn.functional as F
11 | from torch.distributions import Normal, Independent, kl
12 | from model.HolisticAttention import HA
13 | import math
14 | CE = torch.nn.BCELoss(reduction='sum')
15 | cos_sim = torch.nn.CosineSimilarity(dim=1,eps=1e-8)
16 | from model.batchrenorm import BatchRenorm2d
17 |
18 | class BasicConv2d(nn.Module):
19 | def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1):
20 | super(BasicConv2d, self).__init__()
21 | self.conv = nn.Conv2d(in_planes, out_planes,
22 | kernel_size=kernel_size, stride=stride,
23 | padding=padding, dilation=dilation, bias=False)
24 | self.bn = nn.BatchNorm2d(out_planes)
25 | self.relu = nn.ReLU(inplace=True)
26 |
27 | def forward(self, x):
28 | x = self.conv(x)
29 | x = self.bn(x)
30 | return x
31 |
32 | class Classifier_Module(nn.Module):
33 | def __init__(self,dilation_series,padding_series,NoLabels, input_channel):
34 | super(Classifier_Module, self).__init__()
35 | self.conv2d_list = nn.ModuleList()
36 | for dilation,padding in zip(dilation_series,padding_series):
37 | self.conv2d_list.append(nn.Conv2d(input_channel,NoLabels,kernel_size=3,stride=1, padding =padding, dilation = dilation,bias = True))
38 | for m in self.conv2d_list:
39 | m.weight.data.normal_(0, 0.01)
40 |
41 | def forward(self, x):
42 | out = self.conv2d_list[0](x)
43 | for i in range(len(self.conv2d_list)-1):
44 | out += self.conv2d_list[i+1](x)
45 | return out
46 |
47 | class Mutual_info_reg(nn.Module):
48 | def __init__(self, input_channels, channels, latent_size):
49 | super(Mutual_info_reg, self).__init__()
50 | self.contracting_path = nn.ModuleList()
51 | self.input_channels = input_channels
52 | self.relu = nn.ReLU(inplace=True)
53 | self.layer1 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1)
54 | self.bn1 = nn.BatchNorm2d(channels)
55 | self.layer2 = nn.Conv2d(input_channels, channels, kernel_size=4, stride=2, padding=1)
56 | self.bn2 = nn.BatchNorm2d(channels)
57 | self.layer3 = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1)
58 | self.layer4 = nn.Conv2d(channels, channels, kernel_size=4, stride=2, padding=1)
59 |
60 | self.channel = channels
61 |
62 | self.fc1_rgb1 = nn.Linear(channels * 1 * 16 * 16, latent_size)
63 | self.fc2_rgb1 = nn.Linear(channels * 1 * 16 * 16, latent_size)
64 | self.fc1_depth1 = nn.Linear(channels * 1 * 16 * 16, latent_size)
65 | self.fc2_depth1 = nn.Linear(channels * 1 * 16 * 16, latent_size)
66 |
67 | self.fc1_rgb2 = nn.Linear(channels * 1 * 22 * 22, latent_size)
68 | self.fc2_rgb2 = nn.Linear(channels * 1 * 22 * 22, latent_size)
69 | self.fc1_depth2 = nn.Linear(channels * 1 * 22 * 22, latent_size)
70 | self.fc2_depth2 = nn.Linear(channels * 1 * 22 * 22, latent_size)
71 |
72 | self.fc1_rgb3 = nn.Linear(channels * 1 * 28 * 28, latent_size)
73 | self.fc2_rgb3 = nn.Linear(channels * 1 * 28 * 28, latent_size)
74 | self.fc1_depth3 = nn.Linear(channels * 1 * 28 * 28, latent_size)
75 | self.fc2_depth3 = nn.Linear(channels * 1 * 28 * 28, latent_size)
76 |
77 | self.leakyrelu = nn.LeakyReLU()
78 | self.tanh = torch.nn.Tanh()
79 |
80 | def kl_divergence(self, posterior_latent_space, prior_latent_space):
81 | kl_div = kl.kl_divergence(posterior_latent_space, prior_latent_space)
82 | return kl_div
83 |
84 | def reparametrize(self, mu, logvar):
85 | std = logvar.mul(0.5).exp_()
86 | eps = torch.randn_like(std)  # device-agnostic sampling; the original torch.cuda.FloatTensor(...).normal_() required CUDA
87 | eps = Variable(eps)
88 | return eps.mul(std).add_(mu)
89 |
90 | def forward(self, rgb_feat, depth_feat):
91 | rgb_feat = self.layer3(self.leakyrelu(self.bn1(self.layer1(rgb_feat))))
92 | depth_feat = self.layer4(self.leakyrelu(self.bn2(self.layer2(depth_feat))))
93 | # print(rgb_feat.size())
94 | # print(depth_feat.size())
95 | if rgb_feat.shape[2] == 16:
96 | rgb_feat = rgb_feat.view(-1, self.channel * 1 * 16 * 16)
97 | depth_feat = depth_feat.view(-1, self.channel * 1 * 16 * 16)
98 |
99 | mu_rgb = self.fc1_rgb1(rgb_feat)
100 | logvar_rgb = self.fc2_rgb1(rgb_feat)
101 | mu_depth = self.fc1_depth1(depth_feat)
102 | logvar_depth = self.fc2_depth1(depth_feat)
103 | elif rgb_feat.shape[2] == 22:
104 | rgb_feat = rgb_feat.view(-1, self.channel * 1 * 22 * 22)
105 | depth_feat = depth_feat.view(-1, self.channel * 1 * 22 * 22)
106 | mu_rgb = self.fc1_rgb2(rgb_feat)
107 | logvar_rgb = self.fc2_rgb2(rgb_feat)
108 | mu_depth = self.fc1_depth2(depth_feat)
109 | logvar_depth = self.fc2_depth2(depth_feat)
110 | else:
111 | rgb_feat = rgb_feat.view(-1, self.channel * 1 * 28 * 28)
112 | depth_feat = depth_feat.view(-1, self.channel * 1 * 28 * 28)
113 | mu_rgb = self.fc1_rgb3(rgb_feat)
114 | logvar_rgb = self.fc2_rgb3(rgb_feat)
115 | mu_depth = self.fc1_depth3(depth_feat)
116 | logvar_depth = self.fc2_depth3(depth_feat)
117 |
118 | mu_depth = self.tanh(mu_depth)
119 | mu_rgb = self.tanh(mu_rgb)
120 | logvar_depth = self.tanh(logvar_depth)
121 | logvar_rgb = self.tanh(logvar_rgb)
122 | z_rgb = self.reparametrize(mu_rgb, logvar_rgb)
123 | dist_rgb = Independent(Normal(loc=mu_rgb, scale=torch.exp(logvar_rgb)), 1)
124 | z_depth = self.reparametrize(mu_depth, logvar_depth)
125 | dist_depth = Independent(Normal(loc=mu_depth, scale=torch.exp(logvar_depth)), 1)
126 | bi_di_kld = torch.mean(self.kl_divergence(dist_rgb, dist_depth)) + torch.mean(
127 | self.kl_divergence(dist_depth, dist_rgb))
128 | z_rgb_norm = torch.sigmoid(z_rgb)
129 | z_depth_norm = torch.sigmoid(z_depth)
130 | ce_rgb_depth = CE(z_rgb_norm,z_depth_norm.detach())
131 | ce_depth_rgb = CE(z_depth_norm, z_rgb_norm.detach())
132 | latent_loss = ce_rgb_depth+ce_depth_rgb-bi_di_kld
133 | # latent_loss = torch.abs(cos_sim(z_rgb,z_depth)).sum()
134 |
135 | return latent_loss, z_rgb, z_depth
136 |
137 |
138 |
139 | class CAM_Module(nn.Module):
140 | """ Channel attention module"""
141 | def __init__(self):
142 | super(CAM_Module, self).__init__()
143 | self.gamma = Parameter(torch.zeros(1))
144 | self.softmax = Softmax(dim=-1)
145 | def forward(self,x):
146 | """
147 | inputs :
148 | x : input feature maps( B X C X H X W)
149 | returns :
150 | out : attention value + input feature
151 | attention: B X C X C
152 | """
153 | m_batchsize, C, height, width = x.size()
154 | proj_query = x.view(m_batchsize, C, -1)
155 | proj_key = x.view(m_batchsize, C, -1).permute(0, 2, 1)
156 | energy = torch.bmm(proj_query, proj_key)
157 | energy_new = torch.max(energy, -1, keepdim=True)[0].expand_as(energy)-energy
158 | attention = self.softmax(energy_new)
159 | proj_value = x.view(m_batchsize, C, -1)
160 |
161 | out = torch.bmm(attention, proj_value)
162 | out = out.view(m_batchsize, C, height, width)
163 |
164 | out = self.gamma*out + x
165 | return out
166 |
167 | ## Channel Attention (CA) Layer
168 | class CALayer(nn.Module):
169 | def __init__(self, channel, reduction=16):
170 | super(CALayer, self).__init__()
171 | # global average pooling: feature --> point
172 | self.avg_pool = nn.AdaptiveAvgPool2d(1)
173 | # feature channel downscale and upscale --> channel weight
174 | self.conv_du = nn.Sequential(
175 | nn.Conv2d(channel, channel // reduction, 1, padding=0, bias=True),
176 | nn.ReLU(inplace=True),
177 | nn.Conv2d(channel // reduction, channel, 1, padding=0, bias=True),
178 | nn.Sigmoid()
179 | )
180 |
181 | def forward(self, x):
182 | y = self.avg_pool(x)
183 | y = self.conv_du(y)
184 | return x * y
185 |
186 | ## Residual Channel Attention Block (RCAB)
187 |
188 |
189 |
190 | class RCAB(nn.Module):
191 | def __init__(
192 | self, n_feat, kernel_size=3, reduction=16,
193 | bias=True, bn=False, act=nn.ReLU(True), res_scale=1):
194 |
195 | super(RCAB, self).__init__()
196 | modules_body = []
197 | for i in range(2):
198 | modules_body.append(self.default_conv(n_feat, n_feat, kernel_size, bias=bias))
199 | if bn: modules_body.append(nn.BatchNorm2d(n_feat))
200 | if i == 0: modules_body.append(act)
201 | modules_body.append(CALayer(n_feat, reduction))
202 | self.body = nn.Sequential(*modules_body)
203 | self.res_scale = res_scale
204 |
205 | def default_conv(self, in_channels, out_channels, kernel_size, bias=True):
206 | return nn.Conv2d(in_channels, out_channels, kernel_size,padding=(kernel_size // 2), bias=bias)
207 |
208 | def forward(self, x):
209 | res = self.body(x)
210 | #res = self.body(x).mul(self.res_scale)
211 | res += x
212 | return res
213 |
214 |
215 |
216 | class Saliency_feat_decoder(nn.Module):
217 | # resnet based encoder decoder
218 | def __init__(self, channel,latent_dim):
219 | super(Saliency_feat_decoder, self).__init__()
220 | self.relu = nn.ReLU(inplace=True)
221 | self.dropout = nn.Dropout(0.3)
222 |
223 | self.layer5 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 2048)
224 | self.layer6 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], 1, channel)
225 | self.layer7 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], 1, channel)
226 |
227 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
228 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
229 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
230 |
231 | self.conv1 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 256)
232 | self.conv2 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 512)
233 | self.conv3 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 1024)
234 | self.conv4 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 2048)
235 |
236 | self.spatial_axes = [2, 3]
237 |
238 | self.racb_43 = RCAB(channel * 2)
239 | self.racb_432 = RCAB(channel * 3)
240 | self.racb_4321 = RCAB(channel * 4)
241 |
242 | self.conv43 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 2*channel)
243 | self.conv432 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 3*channel)
244 | self.conv4321 = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], channel, 4*channel)
245 |
246 | self.layer_depth = self._make_pred_layer(Classifier_Module, [6, 12, 18, 24], [6, 12, 18, 24], 3, channel * 4)
247 |
248 | self.rcab_z1 = RCAB(channel + latent_dim)
249 | self.conv_z1 = BasicConv2d(channel+latent_dim,channel,3,padding=1)
250 |
251 | self.rcab_z2 = RCAB(channel + latent_dim)
252 | self.conv_z2 = BasicConv2d(channel + latent_dim, channel, 3, padding=1)
253 |
254 | self.rcab_z3 = RCAB(channel + latent_dim)
255 | self.conv_z3 = BasicConv2d(channel + latent_dim, channel, 3, padding=1)
256 |
257 | self.rcab_z4 = RCAB(channel + latent_dim)
258 | self.conv_z4 = BasicConv2d(channel + latent_dim, channel, 3, padding=1)
259 |
260 | self.br1 = BatchRenorm2d(channel)
261 | self.br2 = BatchRenorm2d(channel)
262 | self.br3 = BatchRenorm2d(channel)
263 | self.br4 = BatchRenorm2d(channel)
264 |
265 |
266 | def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel):
267 | return block(dilation_series, padding_series, NoLabels, input_channel)
268 |
269 | def tile(self, a, dim, n_tile):
270 | """
271 | This function is taken from the PyTorch forum and mimics the behavior of tf.tile.
272 | Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3
273 | """
274 | init_dim = a.size(dim)
275 | repeat_idx = [1] * a.dim()
276 | repeat_idx[dim] = n_tile
277 | a = a.repeat(*(repeat_idx))
278 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device)
279 | return torch.index_select(a, dim, order_index)
280 |
281 | def forward(self, x1,x2,x3,x4,z1=None,z2=None,z3=None,z4=None):
282 | conv1_feat = self.br1(self.conv1(x1))
283 | conv2_feat = self.br2(self.conv2(x2))
284 | conv3_feat = self.br3(self.conv3(x3))
285 | conv4_feat = self.br4(self.conv4(x4))
286 |
287 | if z1!=None:
288 | z1 = torch.unsqueeze(z1, 2)
289 | z1 = self.tile(z1, 2, conv1_feat.shape[self.spatial_axes[0]])
290 | z1 = torch.unsqueeze(z1, 3)
291 | z1 = self.tile(z1, 3, conv1_feat.shape[self.spatial_axes[1]])
292 |
293 | z2 = torch.unsqueeze(z2, 2)
294 | z2 = self.tile(z2, 2, conv2_feat.shape[self.spatial_axes[0]])
295 | z2 = torch.unsqueeze(z2, 3)
296 | z2 = self.tile(z2, 3, conv2_feat.shape[self.spatial_axes[1]])
297 |
298 | z3 = torch.unsqueeze(z3, 2)
299 | z3 = self.tile(z3, 2, conv3_feat.shape[self.spatial_axes[0]])
300 | z3 = torch.unsqueeze(z3, 3)
301 | z3 = self.tile(z3, 3, conv3_feat.shape[self.spatial_axes[1]])
302 |
303 | z4 = torch.unsqueeze(z4, 2)
304 | z4 = self.tile(z4, 2, conv4_feat.shape[self.spatial_axes[0]])
305 | z4 = torch.unsqueeze(z4, 3)
306 | z4 = self.tile(z4, 3, conv4_feat.shape[self.spatial_axes[1]])
307 |
308 | conv1_feat = torch.cat((conv1_feat,z1),1)
309 | conv1_feat = self.rcab_z1(conv1_feat)
310 | conv1_feat = self.conv_z1(conv1_feat)
311 |
312 | conv2_feat = torch.cat((conv2_feat, z2), 1)
313 | conv2_feat = self.rcab_z2(conv2_feat)
314 | conv2_feat = self.conv_z2(conv2_feat)
315 |
316 | conv3_feat = torch.cat((conv3_feat, z3), 1)
317 | conv3_feat = self.rcab_z3(conv3_feat)
318 | conv3_feat = self.conv_z3(conv3_feat)
319 |
320 | conv4_feat = torch.cat((conv4_feat, z4), 1)
321 | conv4_feat = self.rcab_z4(conv4_feat)
322 | conv4_feat = self.conv_z4(conv4_feat)
323 |
324 | conv4_feat = self.upsample2(conv4_feat)
325 |
326 | conv43 = torch.cat((conv4_feat, conv3_feat), 1)
327 | conv43 = self.racb_43(conv43)
328 | conv43 = self.conv43(conv43)
329 |
330 | conv43 = self.upsample2(conv43)
331 | conv432 = torch.cat((self.upsample2(conv4_feat), conv43, conv2_feat), 1)
332 | conv432 = self.racb_432(conv432)
333 | conv432 = self.conv432(conv432)
334 |
335 | conv432 = self.upsample2(conv432)
336 | conv4321 = torch.cat((self.upsample4(conv4_feat), self.upsample2(conv43), conv432, conv1_feat), 1)
337 | conv4321 = self.racb_4321(conv4321)
338 | conv4321 = self.conv4321(conv4321)
339 |
340 | sal_init = self.layer6(conv4321)
341 |
342 | return sal_init
343 |
344 |
345 | class Saliency_feat_endecoder(nn.Module):
346 | # resnet based encoder decoder
347 | def __init__(self, channel):
348 | super(Saliency_feat_endecoder, self).__init__()
349 | self.resnet_rgb = B2_ResNet()
350 | self.resnet_depth = B2_ResNet()
351 | self.relu = nn.ReLU(inplace=True)
352 | self.dropout = nn.Dropout(0.3)
353 | self.latent_dim = 6
354 | self.conv_depth1 = BasicConv2d(6, 3, kernel_size=3, padding=1)
355 | self.sal_decoder1 = Saliency_feat_decoder(channel, self.latent_dim)
356 | self.sal_decoder2 = Saliency_feat_decoder(channel, self.latent_dim)
357 | self.sal_decoder3 = Saliency_feat_decoder(channel, self.latent_dim)
358 | self.sal_decoder4 = Saliency_feat_decoder(channel, self.latent_dim)
359 | self.sal_decoder5 = Saliency_feat_decoder(channel, self.latent_dim)
360 | self.sal_decoder6 = Saliency_feat_decoder(channel, self.latent_dim)
361 |
362 | self.HA = HA()
363 | self.upsample05 = nn.Upsample(scale_factor=0.5, mode='bilinear', align_corners=True)
364 | self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear', align_corners=True)
365 | self.upsample2 = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
366 | self.upsample8 = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)
367 | self.upsample025 = nn.Upsample(scale_factor=0.25, mode='bilinear', align_corners=True)
368 | self.upsample0125 = nn.Upsample(scale_factor=0.125, mode='bilinear', align_corners=True)
369 |
370 |
371 | self.convx1_depth = nn.Conv2d(in_channels=256, out_channels=channel, kernel_size=3, padding=1)
372 | self.convx2_depth = nn.Conv2d(in_channels=512, out_channels=channel, kernel_size=3, padding=1)
373 | self.convx3_depth = nn.Conv2d(in_channels=1024, out_channels=channel, kernel_size=3, padding=1)
374 | self.convx4_depth = nn.Conv2d(in_channels=2048, out_channels=channel, kernel_size=3, padding=1)
375 |
376 | self.convx1_rgb = nn.Conv2d(in_channels=256, out_channels=channel, kernel_size=3, padding=1)
377 | self.convx2_rgb = nn.Conv2d(in_channels=512, out_channels=channel, kernel_size=3, padding=1)
378 | self.convx3_rgb = nn.Conv2d(in_channels=1024, out_channels=channel, kernel_size=3, padding=1)
379 | self.convx4_rgb = nn.Conv2d(in_channels=2048, out_channels=channel, kernel_size=3, padding=1)
380 |
381 | self.mi_level1 = Mutual_info_reg(channel,channel,self.latent_dim)
382 | self.mi_level2 = Mutual_info_reg(channel, channel, self.latent_dim)
383 | self.mi_level3 = Mutual_info_reg(channel, channel, self.latent_dim)
384 | self.mi_level4 = Mutual_info_reg(channel, channel, self.latent_dim)
385 |
386 | self.spatial_axes = [2, 3]
387 | self.final_clc = nn.Conv2d(in_channels=4, out_channels=1, kernel_size=3, padding=1)
388 | self.rcab_rgb_feat = RCAB(channel*4)
389 | self.rcab_depth_feat = RCAB(channel*4)
390 |
391 |
392 |
393 |
394 | if self.training:
395 | self.initialize_weights()
396 |
397 | def _make_pred_layer(self, block, dilation_series, padding_series, NoLabels, input_channel):
398 | return block(dilation_series, padding_series, NoLabels, input_channel)
399 |
400 | def tile(self, a, dim, n_tile):
401 | """
402 |         This function is taken from the PyTorch forum and mimics the behavior of tf.tile.
403 | Source: https://discuss.pytorch.org/t/how-to-tile-a-tensor/13853/3
404 | """
405 | init_dim = a.size(dim)
406 | repeat_idx = [1] * a.dim()
407 | repeat_idx[dim] = n_tile
408 | a = a.repeat(*(repeat_idx))
409 | order_index = torch.LongTensor(np.concatenate([init_dim * np.arange(n_tile) + i for i in range(init_dim)])).to(device)
410 | return torch.index_select(a, dim, order_index)
411 |
412 | def forward(self, x,depth=None):
413 | raw_x = x
414 | x = self.resnet_rgb.conv1(x)
415 | x = self.resnet_rgb.bn1(x)
416 | x = self.resnet_rgb.relu(x)
417 | x = self.resnet_rgb.maxpool(x)
418 | x1_rgb = self.resnet_rgb.layer1(x) # 256 x 64 x 64
419 | x2_rgb = self.resnet_rgb.layer2(x1_rgb) # 512 x 32 x 32
420 | x3_rgb = self.resnet_rgb.layer3_1(x2_rgb) # 1024 x 16 x 16
421 | x4_rgb = self.resnet_rgb.layer4_1(x3_rgb) # 2048 x 8 x 8
422 |
423 | sal_init_rgb = self.sal_decoder1(x1_rgb, x2_rgb, x3_rgb, x4_rgb)
424 | x2_2_rgb = self.HA(self.upsample05(sal_init_rgb).sigmoid(), x2_rgb)
425 | x3_2_rgb = self.resnet_rgb.layer3_2(x2_2_rgb) # 1024 x 16 x 16
426 | x4_2_rgb = self.resnet_rgb.layer4_2(x3_2_rgb) # 2048 x 8 x 8
427 | sal_ref_rgb = self.sal_decoder2(x1_rgb, x2_2_rgb, x3_2_rgb, x4_2_rgb)
428 |
429 |         if depth is None:
430 | return self.upsample4(sal_init_rgb), self.upsample4(sal_ref_rgb)
431 | else:
432 | x = depth
433 | x = self.resnet_depth.conv1(x)
434 | x = self.resnet_depth.bn1(x)
435 | x = self.resnet_depth.relu(x)
436 | x = self.resnet_depth.maxpool(x)
437 | x1_depth = self.resnet_depth.layer1(x) # 256 x 64 x 64
438 | x2_depth = self.resnet_depth.layer2(x1_depth) # 512 x 32 x 32
439 | x3_depth = self.resnet_depth.layer3_1(x2_depth) # 1024 x 16 x 16
440 | x4_depth = self.resnet_depth.layer4_1(x3_depth) # 2048 x 8 x 8
441 |
442 | sal_init_depth = self.sal_decoder3(x1_depth, x2_depth, x3_depth, x4_depth)
443 | x2_2_depth = self.HA(self.upsample05(sal_init_depth).sigmoid(), x2_depth)
444 | x3_2_depth = self.resnet_depth.layer3_2(x2_2_depth) # 1024 x 16 x 16
445 | x4_2_depth = self.resnet_depth.layer4_2(x3_2_depth) # 2048 x 8 x 8
446 | sal_ref_depth = self.sal_decoder4(x1_depth, x2_2_depth, x3_2_depth, x4_2_depth)
447 |
448 |
449 | lat_loss1, z1_rgb, z1_depth = self.mi_level1(self.convx1_rgb(x1_rgb), self.convx1_depth(x1_depth))
450 | lat_loss2, z2_rgb, z2_depth = self.mi_level2(self.upsample2(self.convx2_rgb(x2_2_rgb)), self.upsample2(self.convx2_depth(x2_2_depth)))
451 | lat_loss3, z3_rgb, z3_depth = self.mi_level3(self.upsample4(self.convx3_rgb(x3_2_rgb)), self.upsample4(self.convx3_depth(x3_2_depth)))
452 | lat_loss4, z4_rgb, z4_depth = self.mi_level4(self.upsample8(self.convx4_rgb(x4_2_rgb)), self.upsample8(self.convx4_depth(x4_2_depth)))
453 |
454 | lat_loss = lat_loss1+lat_loss2+lat_loss3+lat_loss4
455 |
456 | sal_mi_rgb = self.sal_decoder5(x1_rgb, x2_2_rgb, x3_2_rgb, x4_2_rgb, z1_depth,z2_depth,z3_depth,z4_depth)
457 |
458 | sal_mi_depth = self.sal_decoder6(x1_depth, x2_2_depth, x3_2_depth, x4_2_depth, z1_rgb,z2_rgb,z3_rgb,z4_rgb)
459 |
460 | final_sal = torch.cat((sal_ref_rgb,sal_ref_depth,sal_mi_rgb,sal_mi_depth),1)
461 | final_sal = self.final_clc(final_sal)
462 |
463 | return self.upsample4(sal_init_rgb), self.upsample4(sal_ref_rgb), self.upsample4(sal_init_depth), self.upsample4(
464 | sal_ref_depth), self.upsample4(sal_mi_rgb), self.upsample4(sal_mi_depth), self.upsample4(final_sal), lat_loss
465 |
466 |
467 |
468 | def initialize_weights(self):
469 | res50 = models.resnet50(pretrained=True)
470 | pretrained_dict = res50.state_dict()
471 | all_params = {}
472 | for k, v in self.resnet_rgb.state_dict().items():
473 | if k in pretrained_dict.keys():
474 | v = pretrained_dict[k]
475 | all_params[k] = v
476 | elif '_1' in k:
477 | name = k.split('_1')[0] + k.split('_1')[1]
478 | v = pretrained_dict[name]
479 | all_params[k] = v
480 | elif '_2' in k:
481 | name = k.split('_2')[0] + k.split('_2')[1]
482 | v = pretrained_dict[name]
483 | all_params[k] = v
484 | assert len(all_params.keys()) == len(self.resnet_rgb.state_dict().keys())
485 | self.resnet_rgb.load_state_dict(all_params)
486 | self.resnet_depth.load_state_dict(all_params)
487 |
--------------------------------------------------------------------------------
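Note (illustration only, not part of the repository sources): the decoder's `tile` helper above broadcasts a sampled latent code `z` of shape `[batch, latent_dim]` over the spatial grid of a feature map before the channel-wise concatenation with the decoder features. A minimal, self-contained sketch of that behaviour, with purely illustrative shapes (`latent_dim = 6`, a 22x22 feature map), could look like this:

```python
# Sketch of the tile() broadcasting used before concatenating a latent code
# with a decoder feature map. Shapes are illustrative only.
import torch

def tile(a, dim, n_tile):
    # same logic as the tile() method in the decoder, written as a free function
    init_dim = a.size(dim)
    repeat_idx = [1] * a.dim()
    repeat_idx[dim] = n_tile
    a = a.repeat(*repeat_idx)
    order_index = torch.cat([init_dim * torch.arange(n_tile) + i for i in range(init_dim)])
    return torch.index_select(a, dim, order_index)

z = torch.randn(2, 6)              # [batch, latent_dim]; latent_dim = 6 as in Saliency_feat_endecoder
feat = torch.randn(2, 32, 22, 22)  # [batch, channel, H, W] decoder feature (made-up size)

z = tile(z.unsqueeze(2), 2, feat.shape[2])  # -> [2, 6, 22]
z = tile(z.unsqueeze(3), 3, feat.shape[3])  # -> [2, 6, 22, 22]
fused = torch.cat((feat, z), 1)             # -> [2, 38, 22, 22], then fed to an RCAB + conv in the decoder
print(fused.shape)
```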
/model/__init__.py:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/model/batchrenorm.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | __all__ = ["BatchRenorm1d", "BatchRenorm2d", "BatchRenorm3d"]
5 |
6 |
7 | class BatchRenorm(torch.jit.ScriptModule):
8 | def __init__(
9 | self,
10 | num_features: int,
11 | eps: float = 1e-3,
12 | momentum: float = 0.01,
13 | affine: bool = True,
14 | ):
15 | super().__init__()
16 | self.register_buffer(
17 | "running_mean", torch.zeros(num_features, dtype=torch.float)
18 | )
19 | self.register_buffer(
20 | "running_std", torch.ones(num_features, dtype=torch.float)
21 | )
22 | self.register_buffer(
23 | "num_batches_tracked", torch.tensor(0, dtype=torch.long)
24 | )
25 | self.weight = torch.nn.Parameter(
26 | torch.ones(num_features, dtype=torch.float)
27 | )
28 | self.bias = torch.nn.Parameter(
29 | torch.zeros(num_features, dtype=torch.float)
30 | )
31 | self.affine = affine
32 | self.eps = eps
33 | self.step = 0
34 | self.momentum = momentum
35 |
36 | def _check_input_dim(self, x: torch.Tensor) -> None:
37 | raise NotImplementedError() # pragma: no cover
38 |
39 | @property
40 | def rmax(self) -> torch.Tensor:
41 | return (2 / 35000 * self.num_batches_tracked + 25 / 35).clamp_(
42 | 1.0, 3.0
43 | )
44 |
45 | @property
46 | def dmax(self) -> torch.Tensor:
47 | return (5 / 20000 * self.num_batches_tracked - 25 / 20).clamp_(
48 | 0.0, 5.0
49 | )
50 |
51 | def forward(self, x: torch.Tensor) -> torch.Tensor:
52 | self._check_input_dim(x)
53 | if x.dim() > 2:
54 | x = x.transpose(1, -1)
55 | if self.training:
56 | dims = [i for i in range(x.dim() - 1)]
57 | batch_mean = x.mean(dims)
58 | batch_std = x.std(dims, unbiased=False) + self.eps
59 | r = (
60 | batch_std.detach() / self.running_std.view_as(batch_std)
61 | ).clamp_(1 / self.rmax, self.rmax)
62 | d = (
63 | (batch_mean.detach() - self.running_mean.view_as(batch_mean))
64 | / self.running_std.view_as(batch_std)
65 | ).clamp_(-self.dmax, self.dmax)
66 | x = (x - batch_mean) / batch_std * r + d
67 | self.running_mean += self.momentum * (
68 | batch_mean.detach() - self.running_mean
69 | )
70 | self.running_std += self.momentum * (
71 | batch_std.detach() - self.running_std
72 | )
73 | self.num_batches_tracked += 1
74 | else:
75 | x = (x - self.running_mean) / self.running_std
76 | if self.affine:
77 | x = self.weight * x + self.bias
78 | if x.dim() > 2:
79 | x = x.transpose(1, -1)
80 | return x
81 |
82 |
83 | class BatchRenorm1d(BatchRenorm):
84 | def _check_input_dim(self, x: torch.Tensor) -> None:
85 | if x.dim() not in [2, 3]:
86 |             raise ValueError(f"expected 2D or 3D input (got {x.dim()}D input)")
87 |
88 |
89 | class BatchRenorm2d(BatchRenorm):
90 | def _check_input_dim(self, x: torch.Tensor) -> None:
91 | if x.dim() != 4:
92 |             raise ValueError(f"expected 4D input (got {x.dim()}D input)")
93 |
94 |
95 | class BatchRenorm3d(BatchRenorm):
96 | def _check_input_dim(self, x: torch.Tensor) -> None:
97 | if x.dim() != 5:
98 |             raise ValueError(f"expected 5D input (got {x.dim()}D input)")
99 |
--------------------------------------------------------------------------------
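A minimal usage sketch for the BatchRenorm2d module above (assumptions: the repository root is on PYTHONPATH; layer sizes are arbitrary). It can stand in for `nn.BatchNorm2d`: in training mode it normalizes with batch statistics corrected by the r/d terms and updates the running statistics, while in eval mode it uses the running statistics only.

```python
# Sketch only: BatchRenorm2d as a drop-in replacement for nn.BatchNorm2d.
import torch
import torch.nn as nn
from model.batchrenorm import BatchRenorm2d

block = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1),
    BatchRenorm2d(16),
    nn.ReLU(inplace=True),
)

x = torch.randn(4, 3, 32, 32)
block.train()
y_train = block(x)   # batch statistics + r/d correction; running_mean / running_std are updated
block.eval()
y_eval = block(x)    # running statistics only
print(y_train.shape, y_eval.shape)
```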
/test.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import numpy as np
5 | import pdb, os, argparse
6 | os.environ["CUDA_VISIBLE_DEVICES"] = '1'
7 | from scipy import misc
8 | from model.ResNet_models_combine import Saliency_feat_endecoder
9 | from data import test_dataset
10 | import cv2
11 |
12 |
13 |
14 | parser = argparse.ArgumentParser()
15 | parser.add_argument('--testsize', type=int, default=352, help='testing size')
16 | parser.add_argument('--latent_dim', type=int, default=3, help='latent dim')
17 | parser.add_argument('--feat_channel', type=int, default=64, help='reduced channel of saliency feat')
18 | opt = parser.parse_args()
19 |
20 | dataset_path = './test/'
21 | depth_path = './test/'
22 |
23 | generator = Saliency_feat_endecoder(channel=opt.feat_channel)
24 | generator.load_state_dict(torch.load('./models/Model_100_gen.pth'))
25 |
26 | generator.cuda()
27 | generator.eval()
28 |
29 | #
30 | test_datasets = ['DES', 'LFSD','NJU2K','NLPR','SIP','STERE']
31 |
32 | for dataset in test_datasets:
33 | save_path = './results/' + dataset + '/'
34 | if not os.path.exists(save_path):
35 | os.makedirs(save_path)
36 |
37 | image_root = dataset_path + dataset + '/RGB/'
38 | depth_root = dataset_path + dataset + '/depth/'
39 | test_loader = test_dataset(image_root, depth_root, opt.testsize)
40 | for i in range(test_loader.size):
41 | print(i)
42 | image, depth, HH, WW, name = test_loader.load_data()
43 | image = image.cuda()
44 | depth = depth.cuda()
45 | _,_,_,_,_,_,generator_pred,_ = generator.forward(image, depth)
46 | res = generator_pred
47 |         res = F.interpolate(res, size=[WW, HH], mode='bilinear', align_corners=False)
48 | res = res.sigmoid().data.cpu().numpy().squeeze()
49 | res = 255 * (res - res.min()) / (res.max() - res.min() + 1e-8)
50 | cv2.imwrite(save_path + name, res)
51 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | from torch.autograd import Variable
4 | import cv2
5 | import numpy as np
6 | import pdb, os, argparse
7 | os.environ["CUDA_VISIBLE_DEVICES"] = '0'
8 | from datetime import datetime
9 | from model.ResNet_models_combine import Saliency_feat_endecoder
10 | from data import get_loader
11 | from utils import adjust_lr, AvgMeter
12 | from scipy import misc
13 | from utils import l2_regularisation
14 |
15 |
16 | parser = argparse.ArgumentParser()
17 | parser.add_argument('--epoch', type=int, default=100, help='epoch number')
18 | parser.add_argument('--lr_gen', type=float, default=1e-4, help='learning rate')
19 | parser.add_argument('--batchsize', type=int, default=12, help='training batch size')
20 | parser.add_argument('--trainsize', type=int, default=352, help='training dataset size')
21 | parser.add_argument('--clip', type=float, default=0.5, help='gradient clipping margin')
22 | parser.add_argument('--decay_rate', type=float, default=0.9, help='decay rate of learning rate')
23 | parser.add_argument('--decay_epoch', type=int, default=80, help='every n epochs decay learning rate')
24 | parser.add_argument('--beta1_gen', type=float, default=0.5, help='beta of Adam for generator')
25 | parser.add_argument('--weight_decay', type=float, default=0.001, help='weight_decay')
26 | parser.add_argument('--feat_channel', type=int, default=64, help='reduced channel of saliency feat')
27 |
28 | opt = parser.parse_args()
29 | print('Generator Learning Rate: {}'.format(opt.lr_gen))
30 | # build models
31 | generator = Saliency_feat_endecoder(channel=opt.feat_channel)
32 | generator.cuda()
33 |
34 | generator_params = generator.parameters()
35 | generator_optimizer = torch.optim.Adam(generator_params, opt.lr_gen)
36 |
37 | ## load data
38 | image_root = './RGB/'
39 | gt_root = './GT/'
40 | depth_root = './depth/'
41 |
42 | train_loader = get_loader(image_root, gt_root, depth_root, batchsize=opt.batchsize, trainsize=opt.trainsize)
43 | total_step = len(train_loader)
44 |
45 | ## define loss
46 |
47 | CE = torch.nn.BCELoss()
48 | mse_loss = torch.nn.MSELoss(reduction='mean')
49 | size_rates = [0.75,1,1.25] # multi-scale training
50 |
51 | def structure_loss(pred, mask):
52 | weit = 1+5*torch.abs(F.avg_pool2d(mask, kernel_size=31, stride=1, padding=15)-mask)
53 | wbce = F.binary_cross_entropy_with_logits(pred, mask, reduction='none')
54 | wbce = (weit*wbce).sum(dim=(2,3))/weit.sum(dim=(2,3))
55 |
56 |
57 | pred = torch.sigmoid(pred)
58 | inter = ((pred * mask) * weit).sum(dim=(2, 3))
59 | union = ((pred + mask) * weit).sum(dim=(2, 3))
60 | wiou = 1-(inter+1)/(union-inter+1)
61 |
62 | return (wbce+wiou).mean()
63 |
64 | def visualize_mi_rgb(var_map):
65 |
66 | for kk in range(var_map.shape[0]):
67 | pred_edge_kk = var_map[kk,:,:,:]
68 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze()
69 | pred_edge_kk *= 255.0
70 | pred_edge_kk = pred_edge_kk.astype(np.uint8)
71 | save_path = './temp/'
72 | name = '{:02d}_rgb_mi.png'.format(kk)
73 | cv2.imwrite(save_path + name, pred_edge_kk)
74 |
75 | def visualize_mi_depth(var_map):
76 |
77 | for kk in range(var_map.shape[0]):
78 | pred_edge_kk = var_map[kk,:,:,:]
79 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze()
80 | pred_edge_kk *= 255.0
81 | pred_edge_kk = pred_edge_kk.astype(np.uint8)
82 | save_path = './temp/'
83 | name = '{:02d}_depth_mi.png'.format(kk)
84 | cv2.imwrite(save_path + name, pred_edge_kk)
85 |
86 | ## visualize predictions and gt
87 | def visualize_rgb_init(var_map):
88 |
89 | for kk in range(var_map.shape[0]):
90 | pred_edge_kk = var_map[kk,:,:,:]
91 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze()
92 | pred_edge_kk *= 255.0
93 | pred_edge_kk = pred_edge_kk.astype(np.uint8)
94 | save_path = './temp/'
95 | name = '{:02d}_rgb_int.png'.format(kk)
96 | cv2.imwrite(save_path + name, pred_edge_kk)
97 |
98 | def visualize_depth_init(var_map):
99 |
100 | for kk in range(var_map.shape[0]):
101 | pred_edge_kk = var_map[kk,:,:,:]
102 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze()
103 | pred_edge_kk *= 255.0
104 | pred_edge_kk = pred_edge_kk.astype(np.uint8)
105 | save_path = './temp/'
106 | name = '{:02d}_depth_int.png'.format(kk)
107 | cv2.imwrite(save_path + name, pred_edge_kk)
108 |
109 | def visualize_rgb_ref(var_map):
110 |
111 | for kk in range(var_map.shape[0]):
112 | pred_edge_kk = var_map[kk,:,:,:]
113 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze()
114 | pred_edge_kk *= 255.0
115 | pred_edge_kk = pred_edge_kk.astype(np.uint8)
116 | save_path = './temp/'
117 | name = '{:02d}_rgb_ref.png'.format(kk)
118 | cv2.imwrite(save_path + name, pred_edge_kk)
119 |
120 | def visualize_depth_ref(var_map):
121 |
122 | for kk in range(var_map.shape[0]):
123 | pred_edge_kk = var_map[kk,:,:,:]
124 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze()
125 | pred_edge_kk *= 255.0
126 | pred_edge_kk = pred_edge_kk.astype(np.uint8)
127 | save_path = './temp/'
128 | name = '{:02d}_depth_ref.png'.format(kk)
129 | cv2.imwrite(save_path + name, pred_edge_kk)
130 |
131 | def visualize_final_rgbd(var_map):
132 |
133 | for kk in range(var_map.shape[0]):
134 | pred_edge_kk = var_map[kk,:,:,:]
135 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze()
136 | pred_edge_kk *= 255.0
137 | pred_edge_kk = pred_edge_kk.astype(np.uint8)
138 | save_path = './temp/'
139 | name = '{:02d}_rgbd.png'.format(kk)
140 | cv2.imwrite(save_path + name, pred_edge_kk)
141 |
142 | def visualize_uncertainty_prior_init(var_map):
143 |
144 | for kk in range(var_map.shape[0]):
145 | pred_edge_kk = var_map[kk,:,:,:]
146 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze()
147 | pred_edge_kk *= 255.0
148 | pred_edge_kk = pred_edge_kk.astype(np.uint8)
149 | save_path = './temp/'
150 | name = '{:02d}_prior_int.png'.format(kk)
151 | cv2.imwrite(save_path + name, pred_edge_kk)
152 |
153 | def visualize_gt(var_map):
154 |
155 | for kk in range(var_map.shape[0]):
156 | pred_edge_kk = var_map[kk,:,:,:]
157 | pred_edge_kk = pred_edge_kk.detach().cpu().numpy().squeeze()
158 | # pred_edge_kk = (pred_edge_kk - pred_edge_kk.min()) / (pred_edge_kk.max() - pred_edge_kk.min() + 1e-8)
159 | pred_edge_kk *= 255.0
160 | pred_edge_kk = pred_edge_kk.astype(np.uint8)
161 | save_path = './temp/'
162 | name = '{:02d}_gt.png'.format(kk)
163 | cv2.imwrite(save_path + name, pred_edge_kk)
164 |
165 | ## linear annealing to avoid posterior collapse
166 | def linear_annealing(init, fin, step, annealing_steps):
167 | """Linear annealing of a parameter."""
168 | if annealing_steps == 0:
169 | return fin
170 | assert fin > init
171 | delta = fin - init
172 | annealed = min(init + delta * step / annealing_steps, fin)
173 | return annealed
174 |
175 | def visualize_all_pred(pred1,pred2,pred3,pred4,pred5,pred6,pred7,pred8):
176 | for kk in range(pred1.shape[0]):
177 | pred1_kk, pred2_kk, pred3_kk, pred4_kk, pred5_kk, pred6_kk, pred7_kk, pred8_kk = pred1[kk, :, :, :], pred2[kk, :, :, :], pred3[kk, :, :, :], pred4[kk, :, :, :], pred5[kk, :, :, :], pred6[kk, :, :, :], pred7[kk, :, :, :], pred8[kk, :, :, :]
178 | pred1_kk = (pred1_kk.detach().cpu().numpy().squeeze()*255.0).astype(np.uint8)
179 | pred2_kk = (pred2_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8)
180 | pred3_kk = (pred3_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8)
181 | pred4_kk = (pred4_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8)
182 | pred5_kk = (pred5_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8)
183 | pred6_kk = (pred6_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8)
184 | pred7_kk = (pred7_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8)
185 | pred8_kk = (pred8_kk.detach().cpu().numpy().squeeze() * 255.0).astype(np.uint8)
186 |
187 | cat_img = cv2.hconcat([pred1_kk, pred2_kk, pred3_kk, pred4_kk, pred5_kk, pred6_kk, pred7_kk, pred8_kk])
188 | save_path = './temp/'
189 | name = '{:02d}_gt_initR_refR_intD_refD_mR_mD_Fused.png'.format(kk)
190 | cv2.imwrite(save_path + name, cat_img)
191 |
192 | print("Let's Play!")
193 | for epoch in range(1, opt.epoch+1):
194 | generator.train()
195 | loss_record = AvgMeter()
196 | print('Generator Learning Rate: {}'.format(generator_optimizer.param_groups[0]['lr']))
197 |
198 | for i, pack in enumerate(train_loader, start=1):
199 | for rate in size_rates:
200 | generator_optimizer.zero_grad()
201 | images, gts, depths = pack
202 | # print(index_batch)
203 | images = Variable(images)
204 | gts = Variable(gts)
205 | depths = Variable(depths)
206 | images = images.cuda()
207 | gts = gts.cuda()
208 | depths = depths.cuda()
209 |
210 | # multi-scale training samples
211 | trainsize = int(round(opt.trainsize * rate / 32) * 32)
212 | if rate != 1:
213 | images = F.upsample(images, size=(trainsize, trainsize), mode='bilinear',
214 | align_corners=True)
215 | gts = F.upsample(gts, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
216 | depths = F.upsample(depths, size=(trainsize, trainsize), mode='bilinear', align_corners=True)
217 |
218 | init_rgb, ref_rgb, init_depth, ref_depth, mi_rgb, mi_depth, fuse_sal, latent_loss = generator.forward(images,depths)
219 |
220 | sal_rgb_loss = structure_loss(init_rgb, gts) + structure_loss(ref_rgb, gts) + structure_loss(mi_rgb, gts)
221 | sal_depth_loss = structure_loss(init_depth, gts) + structure_loss(ref_depth, gts) + structure_loss(mi_depth, gts)
222 | sal_final_rgbd = structure_loss(fuse_sal, gts)
223 |
224 |
225 | anneal_reg = linear_annealing(0, 1, epoch, opt.epoch)
226 | latent_loss = 0.1*anneal_reg*latent_loss
227 | sal_loss = sal_rgb_loss+sal_depth_loss+sal_final_rgbd + latent_loss
228 | sal_loss.backward()
229 | generator_optimizer.step()
230 | visualize_all_pred(gts,torch.sigmoid(init_rgb),torch.sigmoid(ref_rgb),torch.sigmoid(init_depth),torch.sigmoid(ref_depth),torch.sigmoid(mi_rgb),torch.sigmoid(mi_depth),torch.sigmoid(fuse_sal))
231 |
232 | if rate == 1:
233 | loss_record.update(sal_loss.data, opt.batchsize)
234 |
235 |
236 | if i % 10 == 0 or i == total_step:
237 | print('{} Epoch [{:03d}/{:03d}], Step [{:04d}/{:04d}], gen Loss: {:.4f}'.
238 | format(datetime.now(), epoch, opt.epoch, i, total_step, loss_record.show()))
239 | # print(anneal_reg)
240 |
241 |
242 | adjust_lr(generator_optimizer, opt.lr_gen, epoch, opt.decay_rate, opt.decay_epoch)
243 |
244 | save_path = 'models/'
245 |
246 |
247 | if not os.path.exists(save_path):
248 | os.makedirs(save_path)
249 | if epoch % opt.epoch == 0:
250 | torch.save(generator.state_dict(), save_path + 'Model' + '_%d' % epoch + '_gen.pth')
251 |
--------------------------------------------------------------------------------
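A small sketch of how train.py weights the mutual-information (latent) loss: `linear_annealing` ramps a factor from 0 to 1 over the training epochs, and the latent term enters the total loss as `0.1 * anneal_reg * latent_loss`, so the MI regularizer is phased in gradually to avoid posterior collapse. The latent-loss value below is a made-up placeholder used only to show the weighting.

```python
# Sketch only: the annealed weighting of the latent / mutual-information loss.
def linear_annealing(init, fin, step, annealing_steps):
    """Linear annealing of a parameter (same logic as in train.py)."""
    if annealing_steps == 0:
        return fin
    assert fin > init
    delta = fin - init
    return min(init + delta * step / annealing_steps, fin)

total_epochs = 100                      # matches the default --epoch in train.py
for epoch in [1, 25, 50, 100]:
    anneal_reg = linear_annealing(0, 1, epoch, total_epochs)
    latent_loss = 2.0                   # placeholder for the summed lat_loss of the four MI levels
    weighted = 0.1 * anneal_reg * latent_loss
    print(epoch, round(anneal_reg, 2), round(weighted, 3))
```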
/utils.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import numpy as np
4 |
5 | def clip_gradient(optimizer, grad_clip):
6 | for group in optimizer.param_groups:
7 | for param in group['params']:
8 | if param.grad is not None:
9 | param.grad.data.clamp_(-grad_clip, grad_clip)
10 |
11 |
12 | def adjust_lr(optimizer, init_lr, epoch, decay_rate=0.1, decay_epoch=5):
13 | decay = decay_rate ** (epoch // decay_epoch)
14 | for param_group in optimizer.param_groups:
15 | param_group['lr'] *= decay
16 |
17 |
18 | def truncated_normal_(tensor, mean=0, std=1):
19 | size = tensor.shape
20 | tmp = tensor.new_empty(size + (4,)).normal_()
21 | valid = (tmp < 2) & (tmp > -2)
22 | ind = valid.max(-1, keepdim=True)[1]
23 | tensor.data.copy_(tmp.gather(-1, ind).squeeze(-1))
24 | tensor.data.mul_(std).add_(mean)
25 |
26 | def init_weights(m):
27 | if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
28 | nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')
29 | #nn.init.normal_(m.weight, std=0.001)
30 | #nn.init.normal_(m.bias, std=0.001)
31 | truncated_normal_(m.bias, mean=0, std=0.001)
32 |
33 | def init_weights_orthogonal_normal(m):
34 | if type(m) == nn.Conv2d or type(m) == nn.ConvTranspose2d:
35 | nn.init.orthogonal_(m.weight)
36 | truncated_normal_(m.bias, mean=0, std=0.001)
37 | #nn.init.normal_(m.bias, std=0.001)
38 |
39 | def l2_regularisation(m):
40 | l2_reg = None
41 |
42 | for W in m.parameters():
43 | if l2_reg is None:
44 | l2_reg = W.norm(2)
45 | else:
46 | l2_reg = l2_reg + W.norm(2)
47 | return l2_reg
48 |
49 | class AvgMeter(object):
50 | def __init__(self, num=40):
51 | self.num = num
52 | self.reset()
53 |
54 | def reset(self):
55 | self.val = 0
56 | self.avg = 0
57 | self.sum = 0
58 | self.count = 0
59 | self.losses = []
60 |
61 | def update(self, val, n=1):
62 | self.val = val
63 | self.sum += val * n
64 | self.count += n
65 | self.avg = self.sum / self.count
66 | self.losses.append(val)
67 |
68 | def show(self):
69 | a = len(self.losses)
70 | b = np.maximum(a-self.num, 0)
71 | c = self.losses[b:]
72 | #print(c)
73 | #d = torch.mean(torch.stack(c))
74 | #print(d)
75 | return torch.mean(torch.stack(c))
76 |
77 | # def save_mask_prediction_example(mask, pred, iter):
78 | # plt.imshow(pred[0,:,:],cmap='Greys')
79 | # plt.savefig('images/'+str(iter)+"_prediction.png")
80 | # plt.imshow(mask[0,:,:],cmap='Greys')
81 | # plt.savefig('images/'+str(iter)+"_mask.png")
--------------------------------------------------------------------------------