├── CoordConv.py
├── LICENSE
├── README.md
├── basic.py
├── criteria.py
├── dataloaders
│   ├── __pycache__
│   │   ├── kitti_loader.cpython-36.pyc
│   │   ├── kitti_loader.cpython-38.pyc
│   │   ├── transforms.cpython-36.pyc
│   │   └── transforms.cpython-38.pyc
│   ├── calib_cam_to_cam.txt
│   ├── kitti_loader.py
│   └── transforms.py
├── helper.py
├── images
│   └── model architecture.png
├── main.py
├── main_distributed.py
├── metrics.py
├── model.py
├── utils.py
└── vis_utils.py
/CoordConv.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import numpy as np
4 |
5 | class AddCoordsNp():
6 | """Add coords to a tensor"""
7 | def __init__(self, x_dim=64, y_dim=64, with_r=False):
8 | self.x_dim = x_dim
9 | self.y_dim = y_dim
10 | self.with_r = with_r
11 |
12 | def call(self):
    13 |         """
    14 |         Returns a coordinate map of shape (x_dim, y_dim, 2); a radius channel is appended when with_r is True.
    15 |         """
16 | #batch_size_tensor = np.shape(input_tensor)[0]
17 |
18 | xx_ones = np.ones([self.x_dim], dtype=np.int32)
19 | xx_ones = np.expand_dims(xx_ones, 1)
20 |
21 | #print(xx_ones.shape)
22 |
23 | xx_range = np.expand_dims(np.arange(self.y_dim), 0)
24 | #xx_range = np.expand_dims(xx_range, 1)
25 |
26 | #print(xx_range.shape)
27 |
28 | xx_channel = np.matmul(xx_ones, xx_range)
29 | xx_channel = np.expand_dims(xx_channel, -1)
30 |
31 | yy_ones = np.ones([self.y_dim], dtype=np.int32)
32 | yy_ones = np.expand_dims(yy_ones, 0)
33 |
34 | #print(yy_ones.shape)
35 |
36 | yy_range = np.expand_dims(np.arange(self.x_dim), 1)
37 | #yy_range = np.expand_dims(yy_range, -1)
38 |
39 | #print(yy_range.shape)
40 |
41 | yy_channel = np.matmul(yy_range, yy_ones)
42 | yy_channel = np.expand_dims(yy_channel, -1)
43 |
44 | xx_channel = xx_channel.astype('float32') / (self.y_dim - 1)
45 | yy_channel = yy_channel.astype('float32') / (self.x_dim - 1)
46 |
47 | xx_channel = xx_channel*2 - 1
48 | yy_channel = yy_channel*2 - 1
49 |
50 |
51 | #xx_channel = xx_channel.repeat(batch_size_tensor, axis=0)
52 | #yy_channel = yy_channel.repeat(batch_size_tensor, axis=0)
53 |
54 | ret = np.concatenate([xx_channel, yy_channel], axis=-1)
55 |
56 | if self.with_r:
57 | rr = np.sqrt( np.square(xx_channel-0.5) + np.square(yy_channel-0.5))
58 | ret = np.concatenate([ret, rr], axis=-1)
59 |
60 | return ret
61 |
62 |
63 | # pos = AddCoordsNp(352, 1216)
    64 | # print(pos.call().shape)
--------------------------------------------------------------------------------
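
Note (added for reference, not part of the repository): a minimal usage sketch of `AddCoordsNp`, assuming the 352 x 1216 KITTI crop size used elsewhere in this repo.

```python
# Hypothetical usage sketch; run from the repo root so CoordConv.py is importable.
from CoordConv import AddCoordsNp

# Build a normalized coordinate map for a 352 x 1216 crop.
position = AddCoordsNp(x_dim=352, y_dim=1216).call()

print(position.shape)                   # (352, 1216, 2): x- and y-coordinate channels
print(position.min(), position.max())   # both channels are scaled to [-1, 1]
```
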
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2018 Fangchang Ma
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # GuideFormer
2 | This repo is the PyTorch implementation of our paper "GuideFormer: Transformers for Image Guided Depth Completion".
3 |
4 |
5 |
6 | ## Install
    7 | Our released implementation is tested on:
8 | + Ubuntu 18.04
9 | + Python 3.8.10
10 | + PyTorch 1.8.1 / torchvision 0.9.1
11 | + NVIDIA CUDA 11.0
12 | + 8x NVIDIA Tesla V100 GPUs
13 |
14 | ```bash
15 | pip install numpy matplotlib Pillow
16 | pip install scikit-image
17 | pip install opencv-contrib-python==3.4.2.17
18 | pip install einops
19 | pip install timm
20 | ```
21 |
22 | ## Data
   23 | - Download the [KITTI Depth](http://www.cvlibs.net/datasets/kitti/eval_depth.php?benchmark=depth_completion) Dataset from its website.
24 | The overall data directory is structured as follows:
25 | ```
26 | ├── kitti_depth
27 | | ├── depth
   28 | | | ├── data_depth_annotated
29 | | | | ├── train
30 | | | | ├── val
31 | | | ├── data_depth_velodyne
32 | | | | ├── train
33 | | | | ├── val
34 | | | ├── data_depth_selection
35 | | | | ├── test_depth_completion_anonymous
   36 | | | | ├── test_depth_prediction_anonymous
37 | | | | ├── val_selection_cropped
38 | ```
39 |
40 | ## Commands
41 | A complete list of training options is available with
42 | ```bash
43 | python main.py -h
44 | ```
45 |
46 | ### Training
47 | ```bash
48 | # Non-distributed GPU setting
49 | CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python main.py -b 8
50 |
51 | # Distributed GPU setting
   52 | python -m torch.distributed.launch --nproc_per_node=8 main_distributed.py -b 8
53 | ```
54 |
55 | ### Validation
56 | ```bash
57 | CUDA_VISIBLE_DEVICES="0" python main.py -b 1 --evaluate [checkpoint-path]
   58 | # evaluate the trained model on the KITTI validation set (val_selection_cropped)
59 | ```
60 |
61 | ### Test
62 | ```bash
63 | CUDA_VISIBLE_DEVICES="0" python main.py -b 1 --evaluate [checkpoint-path] --test
   64 | # generate and save results of the trained model on the KITTI test set (test_depth_completion_anonymous)
65 | ```
66 |
67 | ## Related Repositories
   68 | The original code framework is adapted from ["PENet: Precise and Efficient Depth Completion"](https://github.com/JUGGHM/PENet_ICRA2021) (which is in turn adapted from ["Self-supervised Sparse-to-Dense: Self-supervised Depth Completion from LiDAR and Monocular Camera"](https://github.com/fangchangma/self-supervised-depth-completion)).
   69 | 
   70 | Part of the utility code is adapted from ["Swin Transformer"](https://github.com/microsoft/Swin-Transformer).
71 |
72 |
73 |
--------------------------------------------------------------------------------
/basic.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | import math
5 |
6 | gks = 5
7 | pad = [i for i in range(gks*gks)]
8 | shift = torch.zeros(gks*gks, 4)
9 | for i in range(gks):
10 | for j in range(gks):
11 | top = i
12 | bottom = gks-1-i
13 | left = j
14 | right = gks-1-j
15 | pad[i*gks + j] = torch.nn.ZeroPad2d((left, right, top, bottom))
16 | #shift[i*gks + j, :] = torch.tensor([left, right, top, bottom])
   17 | mid_pad = torch.nn.ZeroPad2d(((gks-1)//2, (gks-1)//2, (gks-1)//2, (gks-1)//2))  # integer division: ZeroPad2d requires int padding
18 | zero_pad = pad[0]
19 |
20 | gks2 = 3 #guide kernel size
21 | pad2 = [i for i in range(gks2*gks2)]
22 | shift = torch.zeros(gks2*gks2, 4)
23 | for i in range(gks2):
24 | for j in range(gks2):
25 | top = i
26 | bottom = gks2-1-i
27 | left = j
28 | right = gks2-1-j
29 | pad2[i*gks2 + j] = torch.nn.ZeroPad2d((left, right, top, bottom))
30 |
31 | gks3 = 7 #guide kernel size
32 | pad3 = [i for i in range(gks3*gks3)]
33 | shift = torch.zeros(gks3*gks3, 4)
34 | for i in range(gks3):
35 | for j in range(gks3):
36 | top = i
37 | bottom = gks3-1-i
38 | left = j
39 | right = gks3-1-j
40 | pad3[i*gks3 + j] = torch.nn.ZeroPad2d((left, right, top, bottom))
41 |
42 | def weights_init(m):
43 | # Initialize filters with Gaussian random weights
44 | if isinstance(m, nn.Conv2d):
45 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
46 | m.weight.data.normal_(0, math.sqrt(2. / n))
47 | if m.bias is not None:
48 | m.bias.data.zero_()
49 | elif isinstance(m, nn.ConvTranspose2d):
50 | n = m.kernel_size[0] * m.kernel_size[1] * m.in_channels
51 | m.weight.data.normal_(0, math.sqrt(2. / n))
52 | if m.bias is not None:
53 | m.bias.data.zero_()
54 | elif isinstance(m, nn.BatchNorm2d):
55 | m.weight.data.fill_(1)
56 | m.bias.data.zero_()
57 |
58 | def convbnrelu(in_channels, out_channels, kernel_size=3,stride=1, padding=1):
59 | return nn.Sequential(
60 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
61 | nn.BatchNorm2d(out_channels),
62 | nn.ReLU(inplace=True)
63 | )
64 |
65 | def deconvbnrelu(in_channels, out_channels, kernel_size=5, stride=2, padding=2, output_padding=1):
66 | return nn.Sequential(
67 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=False),
68 | nn.BatchNorm2d(out_channels),
69 | nn.ReLU(inplace=True)
70 | )
71 |
72 | def convbn(in_channels, out_channels, kernel_size=3,stride=1, padding=1):
73 | return nn.Sequential(
74 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=False),
75 | nn.BatchNorm2d(out_channels)
76 | )
77 |
78 | def deconvbn(in_channels, out_channels, kernel_size=4, stride=2, padding=1, output_padding=0):
79 | return nn.Sequential(
80 | nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, output_padding=output_padding, bias=False),
81 | nn.BatchNorm2d(out_channels)
82 | )
83 |
84 | class BasicBlock(nn.Module):
85 | expansion = 1
86 | __constants__ = ['downsample']
87 |
88 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
89 | base_width=64, dilation=1, norm_layer=None):
90 | super(BasicBlock, self).__init__()
91 | if norm_layer is None:
92 | norm_layer = nn.BatchNorm2d
93 | #norm_layer = encoding.nn.BatchNorm2d
94 | if groups != 1 or base_width != 64:
95 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
96 | if dilation > 1:
97 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
98 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
99 | self.conv1 = conv3x3(inplanes, planes, stride)
100 | self.bn1 = norm_layer(planes)
101 | self.relu = nn.ReLU(inplace=True)
102 | self.conv2 = conv3x3(planes, planes)
103 | self.bn2 = norm_layer(planes)
104 | if stride != 1 or inplanes != planes:
105 | downsample = nn.Sequential(
106 | conv1x1(inplanes, planes, stride),
107 | norm_layer(planes),
108 | )
109 | self.downsample = downsample
110 | self.stride = stride
111 |
112 | def forward(self, x):
113 | identity = x
114 |
115 | out = self.conv1(x)
116 | out = self.bn1(out)
117 | out = self.relu(out)
118 |
119 | out = self.conv2(out)
120 | out = self.bn2(out)
121 |
122 | if self.downsample is not None:
123 | identity = self.downsample(x)
124 |
125 | out += identity
126 | out = self.relu(out)
127 |
128 | return out
129 |
130 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1, bias=False, padding=1):
131 | """3x3 convolution with padding"""
132 | if padding >= 1:
133 | padding = dilation
134 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
135 | padding=padding, groups=groups, bias=bias, dilation=dilation)
136 |
137 | def conv1x1(in_planes, out_planes, stride=1, groups=1, bias=False):
138 | """1x1 convolution"""
139 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, groups=groups, bias=bias)
140 |
141 | class SparseDownSampleClose(nn.Module):
142 | def __init__(self, stride):
143 | super(SparseDownSampleClose, self).__init__()
144 | self.pooling = nn.MaxPool2d(stride, stride)
145 | self.large_number = 600
146 | def forward(self, d, mask):
147 | encode_d = - (1-mask)*self.large_number - d
148 |
149 | d = - self.pooling(encode_d)
150 | mask_result = self.pooling(mask)
151 | d_result = d - (1-mask_result)*self.large_number
152 |
153 | return d_result, mask_result
154 |
155 | class CSPNGenerate(nn.Module):
156 | def __init__(self, in_channels, kernel_size):
157 | super(CSPNGenerate, self).__init__()
158 | self.kernel_size = kernel_size
159 | self.generate = convbn(in_channels, self.kernel_size * self.kernel_size - 1, kernel_size=3, stride=1, padding=1)
160 |
161 | def forward(self, feature):
162 |
163 | guide = self.generate(feature)
164 |
165 | #normalization
166 | guide_sum = torch.sum(guide.abs(), dim=1).unsqueeze(1)
167 | guide = torch.div(guide, guide_sum)
168 | guide_mid = (1 - torch.sum(guide, dim=1)).unsqueeze(1)
169 |
170 | #padding
171 | weight_pad = [i for i in range(self.kernel_size * self.kernel_size)]
172 | for t in range(self.kernel_size*self.kernel_size):
173 | zero_pad = 0
174 | if(self.kernel_size==3):
175 | zero_pad = pad2[t]
176 | elif(self.kernel_size==5):
177 | zero_pad = pad[t]
178 | elif(self.kernel_size==7):
179 | zero_pad = pad3[t]
180 | if(t < int((self.kernel_size*self.kernel_size-1)/2)):
181 | weight_pad[t] = zero_pad(guide[:, t:t+1, :, :])
182 | elif(t > int((self.kernel_size*self.kernel_size-1)/2)):
183 | weight_pad[t] = zero_pad(guide[:, t-1:t, :, :])
184 | else:
185 | weight_pad[t] = zero_pad(guide_mid)
186 |
187 | guide_weight = torch.cat([weight_pad[t] for t in range(self.kernel_size*self.kernel_size)], dim=1)
188 | return guide_weight
189 |
190 | class CSPN(nn.Module):
191 | def __init__(self, kernel_size):
192 | super(CSPN, self).__init__()
193 | self.kernel_size = kernel_size
194 |
195 | def forward(self, guide_weight, hn, h0):
196 |
197 | #CSPN
198 | half = int(0.5 * (self.kernel_size * self.kernel_size - 1))
199 | result_pad = [i for i in range(self.kernel_size * self.kernel_size)]
200 | for t in range(self.kernel_size*self.kernel_size):
201 | zero_pad = 0
202 | if(self.kernel_size==3):
203 | zero_pad = pad2[t]
204 | elif(self.kernel_size==5):
205 | zero_pad = pad[t]
206 | elif(self.kernel_size==7):
207 | zero_pad = pad3[t]
208 | if(t == half):
209 | result_pad[t] = zero_pad(h0)
210 | else:
211 | result_pad[t] = zero_pad(hn)
212 | guide_result = torch.cat([result_pad[t] for t in range(self.kernel_size*self.kernel_size)], dim=1)
213 | #guide_result = torch.cat([result0_pad, result1_pad, result2_pad, result3_pad,result4_pad, result5_pad, result6_pad, result7_pad, result8_pad], 1)
214 |
215 | guide_result = torch.sum((guide_weight.mul(guide_result)), dim=1)
216 | guide_result = guide_result[:, int((self.kernel_size-1)/2):-int((self.kernel_size-1)/2), int((self.kernel_size-1)/2):-int((self.kernel_size-1)/2)]
217 |
218 | return guide_result.unsqueeze(dim=1)
219 |
220 | class CSPNGenerateAccelerate(nn.Module):
221 | def __init__(self, in_channels, kernel_size):
222 | super(CSPNGenerateAccelerate, self).__init__()
223 | self.kernel_size = kernel_size
224 | self.generate = convbn(in_channels, self.kernel_size * self.kernel_size - 1, kernel_size=3, stride=1, padding=1)
225 |
226 | def forward(self, feature):
227 |
228 | guide = self.generate(feature)
229 |
230 | #normalization in standard CSPN
231 | #'''
232 | guide_sum = torch.sum(guide.abs(), dim=1).unsqueeze(1)
233 | guide = torch.div(guide, guide_sum)
234 | guide_mid = (1 - torch.sum(guide, dim=1)).unsqueeze(1)
235 | #'''
236 | #weight_pad = [i for i in range(self.kernel_size * self.kernel_size)]
237 |
238 | half1, half2 = torch.chunk(guide, 2, dim=1)
239 | output = torch.cat((half1, guide_mid, half2), dim=1)
240 | return output
241 |
242 | def kernel_trans(kernel, weight):
243 | kernel_size = int(math.sqrt(kernel.size()[1]))
244 | kernel = F.conv2d(kernel, weight, stride=1, padding=int((kernel_size-1)/2))
245 | return kernel
246 |
247 | class CSPNAccelerate(nn.Module):
248 | def __init__(self, kernel_size, dilation=1, padding=1, stride=1):
249 | super(CSPNAccelerate, self).__init__()
250 | self.kernel_size = kernel_size
251 | self.dilation = dilation
252 | self.padding = padding
253 | self.stride = stride
254 |
  255 |     def forward(self, kernel, input, input0):  # with standard CSPN, an additional input0 port is added
256 | bs = input.size()[0]
257 | h, w = input.size()[2], input.size()[3]
258 | input_im2col = F.unfold(input, self.kernel_size, self.dilation, self.padding, self.stride)
259 | kernel = kernel.reshape(bs, self.kernel_size * self.kernel_size, h * w)
260 |
261 | # standard CSPN
262 | input0 = input0.view(bs, 1, h * w)
263 | mid_index = int((self.kernel_size*self.kernel_size-1)/2)
264 | input_im2col[:, mid_index:mid_index+1, :] = input0
265 |
266 | #print(input_im2col.size(), kernel.size())
267 | output = torch.einsum('ijk,ijk->ik', (input_im2col, kernel))
268 | return output.view(bs, 1, h, w)
269 |
270 | class GeometryFeature(nn.Module):
271 | def __init__(self):
272 | super(GeometryFeature, self).__init__()
273 |
274 | def forward(self, z, vnorm, unorm, h, w, ch, cw, fh, fw):
275 | x = z*(0.5*h*(vnorm+1)-ch)/fh
276 | y = z*(0.5*w*(unorm+1)-cw)/fw
277 | return torch.cat((x, y, z),1)
278 |
279 | class BasicBlockGeo(nn.Module):
280 | expansion = 1
281 | __constants__ = ['downsample']
282 |
283 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1,
284 | base_width=64, dilation=1, norm_layer=None, geoplanes=3):
285 | super(BasicBlockGeo, self).__init__()
286 |
287 | if norm_layer is None:
288 | norm_layer = nn.BatchNorm2d
289 | #norm_layer = encoding.nn.BatchNorm2d
290 | if groups != 1 or base_width != 64:
291 | raise ValueError('BasicBlock only supports groups=1 and base_width=64')
292 | if dilation > 1:
293 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
294 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1
295 | self.conv1 = conv3x3(inplanes + geoplanes, planes, stride)
296 | self.bn1 = norm_layer(planes)
297 | self.relu = nn.ReLU(inplace=True)
298 | self.conv2 = conv3x3(planes+geoplanes, planes)
299 | self.bn2 = norm_layer(planes)
300 | if stride != 1 or inplanes != planes:
301 | downsample = nn.Sequential(
302 | conv1x1(inplanes+geoplanes, planes, stride),
303 | norm_layer(planes),
304 | )
305 | self.downsample = downsample
306 | self.stride = stride
307 |
308 | def forward(self, x, g1=None, g2=None):
309 | identity = x
310 | if g1 is not None:
311 | x = torch.cat((x, g1), 1)
312 | out = self.conv1(x)
313 | out = self.bn1(out)
314 | out = self.relu(out)
315 |
316 | if g2 is not None:
317 | out = torch.cat((g2,out), 1)
318 | out = self.conv2(out)
319 | out = self.bn2(out)
320 |
321 | if self.downsample is not None:
322 | identity = self.downsample(x)
323 |
324 | out += identity
325 | out = self.relu(out)
326 |
327 | return out
328 |
329 |
330 |
--------------------------------------------------------------------------------
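
Note (added for illustration, not part of the repository): the CSPN modules above are wired together inside model.py, which is not shown here. Under assumed tensor shapes (the feature width, map size and iteration count below are all made-up values), one refinement loop could look like the following sketch; the real model additionally refines the generated affinities with `kernel_trans` and a learned weight before propagation.

```python
# Illustrative sketch only; shapes and iteration count are assumptions,
# not the configuration of the actual GuideFormer model.
import torch
from basic import CSPNGenerateAccelerate, CSPNAccelerate

feat = torch.randn(2, 32, 88, 304)        # decoder feature map (hypothetical size)
depth_init = torch.rand(2, 1, 88, 304)    # initial dense depth prediction

generate = CSPNGenerateAccelerate(in_channels=32, kernel_size=3)
propagate = CSPNAccelerate(kernel_size=3, dilation=1, padding=1, stride=1)

kernel = generate(feat)                   # (2, 9, 88, 304) normalized affinity weights
h = depth_init
for _ in range(6):                        # a few anisotropic propagation steps
    h = propagate(kernel, h, depth_init)  # depth_init is re-injected at the kernel center
print(h.shape)                            # torch.Size([2, 1, 88, 304])
```
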
/criteria.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 |
4 | loss_names = ['l1', 'l2']
5 |
6 | class MaskedMSELoss(nn.Module):
7 | def __init__(self):
8 | super(MaskedMSELoss, self).__init__()
9 |
10 | def forward(self, pred, target):
11 | assert pred.dim() == target.dim(), "inconsistent dimensions"
12 | valid_mask = (target > 0).detach()
13 | diff = target - pred
14 | diff = diff[valid_mask]
15 | self.loss = (diff**2).mean()
16 | return self.loss
17 |
18 |
19 | class MaskedL1Loss(nn.Module):
20 | def __init__(self):
21 | super(MaskedL1Loss, self).__init__()
22 |
23 | def forward(self, pred, target, weight=None):
24 | assert pred.dim() == target.dim(), "inconsistent dimensions"
25 | valid_mask = (target > 0).detach()
26 | diff = target - pred
27 | diff = diff[valid_mask]
28 | self.loss = diff.abs().mean()
29 | return self.loss
30 |
--------------------------------------------------------------------------------
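
Note (added for reference, not part of the repository): both losses average only over pixels where the ground truth is strictly positive, so an all-zero target would yield NaN. A small sanity-check sketch:

```python
import torch
from criteria import MaskedL1Loss, MaskedMSELoss

pred = torch.rand(2, 1, 352, 1216)
gt = torch.rand(2, 1, 352, 1216)
gt[gt < 0.7] = 0.0     # emulate sparse ground truth: zero pixels are treated as invalid

print(MaskedL1Loss()(pred, gt).item())   # mean absolute error over valid pixels only
print(MaskedMSELoss()(pred, gt).item())  # mean squared error over valid pixels only
```
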
/dataloaders/__pycache__/kitti_loader.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anonymous1234321/GuideFormer/cccee1c5305977a1bc8d0b8df3f1b6ff66bd1736/dataloaders/__pycache__/kitti_loader.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/kitti_loader.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anonymous1234321/GuideFormer/cccee1c5305977a1bc8d0b8df3f1b6ff66bd1736/dataloaders/__pycache__/kitti_loader.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/transforms.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anonymous1234321/GuideFormer/cccee1c5305977a1bc8d0b8df3f1b6ff66bd1736/dataloaders/__pycache__/transforms.cpython-36.pyc
--------------------------------------------------------------------------------
/dataloaders/__pycache__/transforms.cpython-38.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anonymous1234321/GuideFormer/cccee1c5305977a1bc8d0b8df3f1b6ff66bd1736/dataloaders/__pycache__/transforms.cpython-38.pyc
--------------------------------------------------------------------------------
/dataloaders/calib_cam_to_cam.txt:
--------------------------------------------------------------------------------
1 | calib_time: 09-Jan-2012 13:57:47
2 | corner_dist: 9.950000e-02
3 | S_00: 1.392000e+03 5.120000e+02
4 | K_00: 9.842439e+02 0.000000e+00 6.900000e+02 0.000000e+00 9.808141e+02 2.331966e+02 0.000000e+00 0.000000e+00 1.000000e+00
5 | D_00: -3.728755e-01 2.037299e-01 2.219027e-03 1.383707e-03 -7.233722e-02
6 | R_00: 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00
7 | T_00: 2.573699e-16 -1.059758e-16 1.614870e-16
8 | S_rect_00: 1.242000e+03 3.750000e+02
9 | R_rect_00: 9.999239e-01 9.837760e-03 -7.445048e-03 -9.869795e-03 9.999421e-01 -4.278459e-03 7.402527e-03 4.351614e-03 9.999631e-01
10 | P_rect_00: 7.215377e+02 0.000000e+00 6.095593e+02 0.000000e+00 0.000000e+00 7.215377e+02 1.728540e+02 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00
11 | S_01: 1.392000e+03 5.120000e+02
12 | K_01: 9.895267e+02 0.000000e+00 7.020000e+02 0.000000e+00 9.878386e+02 2.455590e+02 0.000000e+00 0.000000e+00 1.000000e+00
13 | D_01: -3.644661e-01 1.790019e-01 1.148107e-03 -6.298563e-04 -5.314062e-02
14 | R_01: 9.993513e-01 1.860866e-02 -3.083487e-02 -1.887662e-02 9.997863e-01 -8.421873e-03 3.067156e-02 8.998467e-03 9.994890e-01
15 | T_01: -5.370000e-01 4.822061e-03 -1.252488e-02
16 | S_rect_01: 1.242000e+03 3.750000e+02
17 | R_rect_01: 9.996878e-01 -8.976826e-03 2.331651e-02 8.876121e-03 9.999508e-01 4.418952e-03 -2.335503e-02 -4.210612e-03 9.997184e-01
18 | P_rect_01: 7.215377e+02 0.000000e+00 6.095593e+02 -3.875744e+02 0.000000e+00 7.215377e+02 1.728540e+02 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00
19 | S_02: 1.392000e+03 5.120000e+02
20 | K_02: 9.597910e+02 0.000000e+00 6.960217e+02 0.000000e+00 9.569251e+02 2.241806e+02 0.000000e+00 0.000000e+00 1.000000e+00
21 | D_02: -3.691481e-01 1.968681e-01 1.353473e-03 5.677587e-04 -6.770705e-02
22 | R_02: 9.999758e-01 -5.267463e-03 -4.552439e-03 5.251945e-03 9.999804e-01 -3.413835e-03 4.570332e-03 3.389843e-03 9.999838e-01
23 | T_02: 5.956621e-02 2.900141e-04 2.577209e-03
24 | S_rect_02: 1.242000e+03 3.750000e+02
25 | R_rect_02: 9.998817e-01 1.511453e-02 -2.841595e-03 -1.511724e-02 9.998853e-01 -9.338510e-04 2.827154e-03 9.766976e-04 9.999955e-01
26 | P_rect_02: 7.215377e+02 0.000000e+00 6.095593e+02 4.485728e+01 0.000000e+00 7.215377e+02 1.728540e+02 2.163791e-01 0.000000e+00 0.000000e+00 1.000000e+00 2.745884e-03
27 | S_03: 1.392000e+03 5.120000e+02
28 | K_03: 9.037596e+02 0.000000e+00 6.957519e+02 0.000000e+00 9.019653e+02 2.242509e+02 0.000000e+00 0.000000e+00 1.000000e+00
29 | D_03: -3.639558e-01 1.788651e-01 6.029694e-04 -3.922424e-04 -5.382460e-02
30 | R_03: 9.995599e-01 1.699522e-02 -2.431313e-02 -1.704422e-02 9.998531e-01 -1.809756e-03 2.427880e-02 2.223358e-03 9.997028e-01
31 | T_03: -4.731050e-01 5.551470e-03 -5.250882e-03
32 | S_rect_03: 1.242000e+03 3.750000e+02
33 | R_rect_03: 9.998321e-01 -7.193136e-03 1.685599e-02 7.232804e-03 9.999712e-01 -2.293585e-03 -1.683901e-02 2.415116e-03 9.998553e-01
34 | P_rect_03: 7.215377e+02 0.000000e+00 6.095593e+02 -3.395242e+02 0.000000e+00 7.215377e+02 1.728540e+02 2.199936e+00 0.000000e+00 0.000000e+00 1.000000e+00 2.729905e-03
35 |
--------------------------------------------------------------------------------
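
Note (added for orientation, not part of the repository): `load_calib()` in dataloaders/kitti_loader.py reads the `P_rect_02` row of this file (line 26) to build the 3x3 camera matrix. A hedged re-derivation, assuming the script is run from the repo root:

```python
import numpy as np

with open("dataloaders/calib_cam_to_cam.txt") as f:
    P_rect_line = f.readlines()[25]            # the "P_rect_02: ..." row

vals = [float(v) for v in P_rect_line.split(":")[1].split()]
K = np.reshape(np.array(vals, dtype=np.float32), (3, 4))[:3, :3]
print(K)   # fx = fy = 721.5377, cx = 609.5593, cy = 172.854
```
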
/dataloaders/kitti_loader.py:
--------------------------------------------------------------------------------
1 | import os
2 | import os.path
3 | import glob
4 | import fnmatch # pattern matching
5 | import numpy as np
6 | from numpy import linalg as LA
7 | from random import choice
8 | from PIL import Image
9 | import torch
10 | import torch.utils.data as data
11 | import cv2
12 | from dataloaders import transforms
13 | import CoordConv
14 |
15 | input_options = ['d', 'rgb', 'rgbd', 'g', 'gd']
16 |
17 | def load_calib():
18 | """
19 | Temporarily hardcoding the calibration matrix using calib file from 2011_09_26
20 | """
21 | calib = open("dataloaders/calib_cam_to_cam.txt", "r")
22 | lines = calib.readlines()
23 | P_rect_line = lines[25]
24 |
25 | Proj_str = P_rect_line.split(":")[1].split(" ")[1:]
26 | Proj = np.reshape(np.array([float(p) for p in Proj_str]),
27 | (3, 4)).astype(np.float32)
28 | K = Proj[:3, :3] # camera matrix
29 |
30 | # note: we will take the center crop of the images during augmentation
31 | # that changes the optical centers, but not focal lengths
   32 |     # from width = 1242 to 1216, with a 13-pixel cut on both sides
   33 |     # from height = 375 to 352, with an 11.5-pixel cut on both sides
   34 |     K[0, 2] = K[0, 2] - 13
   35 |     K[1, 2] = K[1, 2] - 11.5
36 | return K
37 |
38 |
39 | def get_paths_and_transform(split, args):
40 | assert (args.use_d or args.use_rgb
41 | or args.use_g), 'no proper input selected'
42 |
43 | if split == "train":
44 | transform = train_transform
45 | # transform = val_transform
46 | glob_d = os.path.join(
47 | args.data_folder,
48 | 'data_depth_velodyne/train/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png'
49 | )
50 | glob_gt = os.path.join(
51 | args.data_folder,
52 | 'data_depth_annotated/train/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png'
53 | )
54 |
55 | def get_rgb_paths(p):
56 | ps = p.split('/')
57 | # date_liststr = []
58 | # date_liststr.append(ps[-5][:10])
59 | # pnew = '/'.join(date_liststr + ps[-5:-4] + ps[-2:-1] + ['data'] + ps[-1:])
60 | pnew = '/'.join(['train'] + ps[-5:-4] + ps[-2:-1] + ['data'] + ps[-1:])
61 |
62 | pnew = os.path.join(args.data_folder_rgb, pnew)
63 | return pnew
64 | elif split == "val":
65 | if args.val == "full":
66 | transform = val_transform
67 | glob_d = os.path.join(
68 | args.data_folder,
69 | 'data_depth_velodyne/val/*_sync/proj_depth/velodyne_raw/image_0[2,3]/*.png'
70 | )
71 | glob_gt = os.path.join(
72 | args.data_folder,
73 | 'data_depth_annotated/val/*_sync/proj_depth/groundtruth/image_0[2,3]/*.png'
74 | )
75 |
76 | def get_rgb_paths(p):
77 | ps = p.split('/')
78 | date_liststr = []
79 | date_liststr.append(ps[-5][:10])
80 | # pnew = '/'.join(ps[:-7] +
81 | # ['data_rgb']+ps[-6:-4]+ps[-2:-1]+['data']+ps[-1:])
82 | pnew = '/'.join(date_liststr + ps[-5:-4] + ps[-2:-1] + ['data'] + ps[-1:])
83 | pnew = os.path.join(args.data_folder_rgb, pnew)
84 | return pnew
85 |
86 | elif args.val == "select":
87 | # transform = no_transform
88 | transform = val_transform
89 | glob_d = os.path.join(
90 | args.data_folder,
91 | "data_depth_selection/val_selection_cropped/velodyne_raw/*.png")
92 | glob_gt = os.path.join(
93 | args.data_folder,
94 | "data_depth_selection/val_selection_cropped/groundtruth_depth/*.png"
95 | )
96 |
97 | def get_rgb_paths(p):
98 | return p.replace("groundtruth_depth", "image")
99 | elif split == "test_completion":
100 | transform = no_transform
101 | glob_d = os.path.join(
102 | args.data_folder,
103 | "data_depth_selection/test_depth_completion_anonymous/velodyne_raw/*.png"
104 | )
105 | glob_gt = None # "test_depth_completion_anonymous/"
106 | glob_rgb = os.path.join(
107 | args.data_folder,
108 | "data_depth_selection/test_depth_completion_anonymous/image/*.png")
109 | elif split == "test_prediction":
110 | transform = no_transform
111 | glob_d = None
112 | glob_gt = None # "test_depth_completion_anonymous/"
113 | glob_rgb = os.path.join(
114 | args.data_folder,
115 | "data_depth_selection/test_depth_prediction_anonymous/image/*.png")
116 | else:
117 | raise ValueError("Unrecognized split " + str(split))
118 |
119 | if glob_gt is not None:
120 | # train or val-full or val-select
121 | paths_d = sorted(glob.glob(glob_d))
122 | paths_gt = sorted(glob.glob(glob_gt))
123 | paths_rgb = [get_rgb_paths(p) for p in paths_gt]
124 | else:
125 | # test only has d or rgb
126 | paths_rgb = sorted(glob.glob(glob_rgb))
127 | paths_gt = [None] * len(paths_rgb)
128 | if split == "test_prediction":
129 | paths_d = [None] * len(
130 | paths_rgb) # test_prediction has no sparse depth
131 | else:
132 | paths_d = sorted(glob.glob(glob_d))
133 |
134 | if len(paths_d) == 0 and len(paths_rgb) == 0 and len(paths_gt) == 0:
135 | raise (RuntimeError("Found 0 images under {}".format(glob_gt)))
136 | if len(paths_d) == 0 and args.use_d:
137 | raise (RuntimeError("Requested sparse depth but none was found"))
138 | if len(paths_rgb) == 0 and args.use_rgb:
139 | raise (RuntimeError("Requested rgb images but none was found"))
140 | if len(paths_rgb) == 0 and args.use_g:
141 | raise (RuntimeError("Requested gray images but no rgb was found"))
142 | if len(paths_rgb) != len(paths_d) or len(paths_rgb) != len(paths_gt):
143 | print(len(paths_rgb), len(paths_d), len(paths_gt))
144 | # for i in range(999):
145 | # print("#####")
146 | # print(paths_rgb[i])
147 | # print(paths_d[i])
148 | # print(paths_gt[i])
149 | # raise (RuntimeError("Produced different sizes for datasets"))
150 | paths = {"rgb": paths_rgb, "d": paths_d, "gt": paths_gt}
151 | return paths, transform
152 |
153 |
154 | def rgb_read(filename):
155 | assert os.path.exists(filename), "file not found: {}".format(filename)
156 | img_file = Image.open(filename)
157 | # rgb_png = np.array(img_file, dtype=float) / 255.0 # scale pixels to the range [0,1]
158 | rgb_png = np.array(img_file, dtype='uint8') # in the range [0,255]
159 | img_file.close()
160 | return rgb_png
161 |
162 |
163 | def depth_read(filename):
164 | # loads depth map D from png file
165 | # and returns it as a numpy array,
166 | # for details see readme.txt
167 | assert os.path.exists(filename), "file not found: {}".format(filename)
168 | img_file = Image.open(filename)
169 | depth_png = np.array(img_file, dtype=int)
170 | img_file.close()
171 | # make sure we have a proper 16bit depth map here.. not 8bit!
172 | assert np.max(depth_png) > 255, \
173 | "np.max(depth_png)={}, path={}".format(np.max(depth_png), filename)
174 |
  175 |     depth = depth_png.astype(float) / 256.  # np.float is deprecated; use the builtin float
176 | # depth[depth_png == 0] = -1.
177 | depth = np.expand_dims(depth, -1)
178 | return depth
179 |
180 | def drop_depth_measurements(depth, prob_keep):
181 | mask = np.random.binomial(1, prob_keep, depth.shape)
182 | depth *= mask
183 | return depth
184 |
185 | def train_transform(rgb, sparse, target, position, args):
186 | # s = np.random.uniform(1.0, 1.5) # random scaling
187 | # angle = np.random.uniform(-5.0, 5.0) # random rotation degrees
188 | oheight = args.val_h
189 | owidth = args.val_w
190 |
191 | do_flip = np.random.uniform(0.0, 1.0) < 0.5 # random horizontal flip
192 |
193 | transforms_list = [
194 | # transforms.Rotate(angle),
195 | # transforms.Resize(s),
196 | transforms.BottomCrop((oheight, owidth)),
197 | transforms.HorizontalFlip(do_flip)
198 | ]
199 |
200 | # if small_training == True:
201 | # transforms_list.append(transforms.RandomCrop((rheight, rwidth)))
202 |
203 | transform_geometric = transforms.Compose(transforms_list)
204 |
205 | if sparse is not None:
206 | sparse = transform_geometric(sparse)
207 | target = transform_geometric(target)
208 | if rgb is not None:
209 | brightness = np.random.uniform(max(0, 1 - args.jitter),
210 | 1 + args.jitter)
211 | contrast = np.random.uniform(max(0, 1 - args.jitter), 1 + args.jitter)
212 | saturation = np.random.uniform(max(0, 1 - args.jitter),
213 | 1 + args.jitter)
214 | transform_rgb = transforms.Compose([
215 | transforms.ColorJitter(brightness, contrast, saturation, 0),
216 | transform_geometric
217 | ])
218 | rgb = transform_rgb(rgb)
219 | # sparse = drop_depth_measurements(sparse, 0.9)
220 |
221 | if position is not None:
222 | bottom_crop_only = transforms.Compose([transforms.BottomCrop((oheight, owidth))])
223 | position = bottom_crop_only(position)
224 |
225 | # random crop
226 | #if small_training == True:
227 | if args.not_random_crop == False:
228 | h = oheight
229 | w = owidth
230 | rheight = args.random_crop_height
231 | rwidth = args.random_crop_width
232 | # randomlize
233 | i = np.random.randint(0, h - rheight + 1)
234 | j = np.random.randint(0, w - rwidth + 1)
235 |
236 | if rgb is not None:
237 | if rgb.ndim == 3:
238 | rgb = rgb[i:i + rheight, j:j + rwidth, :]
239 | elif rgb.ndim == 2:
240 | rgb = rgb[i:i + rheight, j:j + rwidth]
241 |
242 | if sparse is not None:
243 | if sparse.ndim == 3:
244 | sparse = sparse[i:i + rheight, j:j + rwidth, :]
245 | elif sparse.ndim == 2:
246 | sparse = sparse[i:i + rheight, j:j + rwidth]
247 |
248 | if target is not None:
249 | if target.ndim == 3:
250 | target = target[i:i + rheight, j:j + rwidth, :]
251 | elif target.ndim == 2:
252 | target = target[i:i + rheight, j:j + rwidth]
253 |
254 | if position is not None:
255 | if position.ndim == 3:
256 | position = position[i:i + rheight, j:j + rwidth, :]
257 | elif position.ndim == 2:
258 | position = position[i:i + rheight, j:j + rwidth]
259 |
260 | return rgb, sparse, target, position
261 |
262 | def val_transform(rgb, sparse, target, position, args):
263 | oheight = args.val_h
264 | owidth = args.val_w
265 |
266 | transform = transforms.Compose([
267 | transforms.BottomCrop((oheight, owidth)),
268 | ])
269 | if rgb is not None:
270 | rgb = transform(rgb)
271 | if sparse is not None:
272 | sparse = transform(sparse)
273 | if target is not None:
274 | target = transform(target)
275 | if position is not None:
276 | position = transform(position)
277 |
278 | return rgb, sparse, target, position
279 |
280 |
281 | def no_transform(rgb, sparse, target, position, args):
282 | return rgb, sparse, target, position
283 |
284 |
285 | to_tensor = transforms.ToTensor()
286 | to_float_tensor = lambda x: to_tensor(x).float()
287 |
288 |
289 | def handle_gray(rgb, args):
290 | if rgb is None:
291 | return None, None
292 | if not args.use_g:
293 | return rgb, None
294 | else:
295 | img = np.array(Image.fromarray(rgb).convert('L'))
296 | img = np.expand_dims(img, -1)
297 | if not args.use_rgb:
298 | rgb_ret = None
299 | else:
300 | rgb_ret = rgb
301 | return rgb_ret, img
302 |
303 |
304 | def get_rgb_near(path, args):
305 | assert path is not None, "path is None"
306 |
307 | def extract_frame_id(filename):
308 | head, tail = os.path.split(filename)
309 | number_string = tail[0:tail.find('.')]
310 | number = int(number_string)
311 | return head, number
312 |
313 | def get_nearby_filename(filename, new_id):
314 | head, _ = os.path.split(filename)
315 | new_filename = os.path.join(head, '%010d.png' % new_id)
316 | return new_filename
317 |
318 | head, number = extract_frame_id(path)
319 | count = 0
320 | max_frame_diff = 3
321 | candidates = [
322 | i - max_frame_diff for i in range(max_frame_diff * 2 + 1)
323 | if i - max_frame_diff != 0
324 | ]
325 | while True:
326 | random_offset = choice(candidates)
327 | path_near = get_nearby_filename(path, number + random_offset)
328 | if os.path.exists(path_near):
329 | break
  330 |         count += 1
  331 |         assert count < 20, "cannot find a nearby frame in 20 trials for {}".format(path_near)
332 | return rgb_read(path_near)
333 |
334 |
335 | class KittiDepth(data.Dataset):
  336 |     """A data loader for the KITTI dataset
337 | """
338 |
339 | def __init__(self, split, args):
340 | self.args = args
341 | self.split = split
342 | paths, transform = get_paths_and_transform(split, args)
343 | self.paths = paths
344 | self.transform = transform
345 | self.K = load_calib()
346 | self.threshold_translation = 0.1
347 |
348 | def __getraw__(self, index):
349 | rgb = rgb_read(self.paths['rgb'][index]) if \
350 | (self.paths['rgb'][index] is not None and (self.args.use_rgb or self.args.use_g)) else None
351 | sparse = depth_read(self.paths['d'][index]) if \
352 | (self.paths['d'][index] is not None and self.args.use_d) else None
353 | target = depth_read(self.paths['gt'][index]) if \
354 | self.paths['gt'][index] is not None else None
355 | return rgb, sparse, target
356 |
357 | def __getitem__(self, index):
358 | rgb, sparse, target = self.__getraw__(index)
359 | position = CoordConv.AddCoordsNp(self.args.val_h, self.args.val_w)
360 | position = position.call()
361 | rgb, sparse, target, position = self.transform(rgb, sparse, target, position, self.args)
362 |
363 | rgb, gray = handle_gray(rgb, self.args)
364 | # candidates = {"rgb": rgb, "d": sparse, "gt": target, \
365 | # "g": gray, "r_mat": r_mat, "t_vec": t_vec, "rgb_near": rgb_near}
366 | candidates = {"rgb": rgb, "d": sparse, "gt": target, \
367 | "g": gray, 'position': position, 'K': self.K}
368 |
369 | items = {
370 | key: to_float_tensor(val)
371 | for key, val in candidates.items() if val is not None
372 | }
373 |
374 | return items
375 |
376 | def __len__(self):
377 | return len(self.paths['gt'])
378 |
--------------------------------------------------------------------------------
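
Note (added as a sketch, not part of the repository): `KittiDepth` is normally constructed from the argparse options in main.py (not shown here). The attribute names below mirror the ones read in this file; the data path is a placeholder and the example assumes the KITTI layout from the README is present on disk.

```python
# Hypothetical usage sketch; run from the repo root so the relative calib path resolves.
from argparse import Namespace

import torch.utils.data

from dataloaders.kitti_loader import KittiDepth

args = Namespace(
    data_folder='/path/to/kitti_depth/depth',   # placeholder; contains the data_depth_* folders
    use_rgb=True, use_d=True, use_g=False,
    val='select',                               # evaluate on val_selection_cropped
    val_h=352, val_w=1216,
)

val_dataset = KittiDepth('val', args)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=1, shuffle=False)

batch = next(iter(val_loader))
print(batch['rgb'].shape, batch['d'].shape, batch['gt'].shape, batch['position'].shape)
```
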
/dataloaders/transforms.py:
--------------------------------------------------------------------------------
1 | from __future__ import division
2 | import torch
3 | import math
4 | import random
5 |
6 | from PIL import Image, ImageOps, ImageEnhance
7 | try:
8 | import accimage
9 | except ImportError:
10 | accimage = None
11 |
12 | import numpy as np
13 | import numbers
14 | import types
15 | import collections
16 | import warnings
17 |
18 | import scipy.ndimage.interpolation as itpl
19 | import skimage.transform
20 |
21 |
22 | def _is_numpy_image(img):
23 | return isinstance(img, np.ndarray) and (img.ndim in {2, 3})
24 |
25 |
26 | def _is_pil_image(img):
27 | if accimage is not None:
28 | return isinstance(img, (Image.Image, accimage.Image))
29 | else:
30 | return isinstance(img, Image.Image)
31 |
32 |
33 | def _is_tensor_image(img):
34 | return torch.is_tensor(img) and img.ndimension() == 3
35 |
36 |
37 | def adjust_brightness(img, brightness_factor):
38 | """Adjust brightness of an Image.
39 |
40 | Args:
41 | img (PIL Image): PIL Image to be adjusted.
42 | brightness_factor (float): How much to adjust the brightness. Can be
43 | any non negative number. 0 gives a black image, 1 gives the
44 | original image while 2 increases the brightness by a factor of 2.
45 |
46 | Returns:
47 | PIL Image: Brightness adjusted image.
48 | """
49 | if not _is_pil_image(img):
50 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
51 |
52 | enhancer = ImageEnhance.Brightness(img)
53 | img = enhancer.enhance(brightness_factor)
54 | return img
55 |
56 |
57 | def adjust_contrast(img, contrast_factor):
58 | """Adjust contrast of an Image.
59 |
60 | Args:
61 | img (PIL Image): PIL Image to be adjusted.
62 | contrast_factor (float): How much to adjust the contrast. Can be any
63 | non negative number. 0 gives a solid gray image, 1 gives the
64 | original image while 2 increases the contrast by a factor of 2.
65 |
66 | Returns:
67 | PIL Image: Contrast adjusted image.
68 | """
69 | if not _is_pil_image(img):
70 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
71 |
72 | enhancer = ImageEnhance.Contrast(img)
73 | img = enhancer.enhance(contrast_factor)
74 | return img
75 |
76 |
77 | def adjust_saturation(img, saturation_factor):
78 | """Adjust color saturation of an image.
79 |
80 | Args:
81 | img (PIL Image): PIL Image to be adjusted.
82 | saturation_factor (float): How much to adjust the saturation. 0 will
83 | give a black and white image, 1 will give the original image while
84 | 2 will enhance the saturation by a factor of 2.
85 |
86 | Returns:
87 | PIL Image: Saturation adjusted image.
88 | """
89 | if not _is_pil_image(img):
90 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
91 |
92 | enhancer = ImageEnhance.Color(img)
93 | img = enhancer.enhance(saturation_factor)
94 | return img
95 |
96 |
97 | def adjust_hue(img, hue_factor):
98 | """Adjust hue of an image.
99 |
100 | The image hue is adjusted by converting the image to HSV and
101 | cyclically shifting the intensities in the hue channel (H).
102 | The image is then converted back to original image mode.
103 |
104 | `hue_factor` is the amount of shift in H channel and must be in the
105 | interval `[-0.5, 0.5]`.
106 |
107 | See https://en.wikipedia.org/wiki/Hue for more details on Hue.
108 |
109 | Args:
110 | img (PIL Image): PIL Image to be adjusted.
111 | hue_factor (float): How much to shift the hue channel. Should be in
112 | [-0.5, 0.5]. 0.5 and -0.5 give complete reversal of hue channel in
113 | HSV space in positive and negative direction respectively.
114 | 0 means no shift. Therefore, both -0.5 and 0.5 will give an image
115 | with complementary colors while 0 gives the original image.
116 |
117 | Returns:
118 | PIL Image: Hue adjusted image.
119 | """
120 | if not (-0.5 <= hue_factor <= 0.5):
121 | raise ValueError(
  122 |             'hue_factor {} is not in [-0.5, 0.5].'.format(hue_factor))
123 |
124 | if not _is_pil_image(img):
125 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
126 |
127 | input_mode = img.mode
128 | if input_mode in {'L', '1', 'I', 'F'}:
129 | return img
130 |
131 | h, s, v = img.convert('HSV').split()
132 |
133 | np_h = np.array(h, dtype=np.uint8)
134 | # uint8 addition take cares of rotation across boundaries
135 | with np.errstate(over='ignore'):
136 | np_h += np.uint8(hue_factor * 255)
137 | h = Image.fromarray(np_h, 'L')
138 |
139 | img = Image.merge('HSV', (h, s, v)).convert(input_mode)
140 | return img
141 |
142 |
143 | def adjust_gamma(img, gamma, gain=1):
144 | """Perform gamma correction on an image.
145 |
146 | Also known as Power Law Transform. Intensities in RGB mode are adjusted
147 | based on the following equation:
148 |
149 | I_out = 255 * gain * ((I_in / 255) ** gamma)
150 |
151 | See https://en.wikipedia.org/wiki/Gamma_correction for more details.
152 |
153 | Args:
154 | img (PIL Image): PIL Image to be adjusted.
155 | gamma (float): Non negative real number. gamma larger than 1 make the
156 | shadows darker, while gamma smaller than 1 make dark regions
157 | lighter.
158 | gain (float): The constant multiplier.
159 | """
160 | if not _is_pil_image(img):
161 | raise TypeError('img should be PIL Image. Got {}'.format(type(img)))
162 |
163 | if gamma < 0:
164 | raise ValueError('Gamma should be a non-negative real number')
165 |
166 | input_mode = img.mode
167 | img = img.convert('RGB')
168 |
169 | np_img = np.array(img, dtype=np.float32)
170 | np_img = 255 * gain * ((np_img / 255)**gamma)
171 | np_img = np.uint8(np.clip(np_img, 0, 255))
172 |
173 | img = Image.fromarray(np_img, 'RGB').convert(input_mode)
174 | return img
175 |
176 |
177 | class Compose(object):
178 | """Composes several transforms together.
179 |
180 | Args:
181 | transforms (list of ``Transform`` objects): list of transforms to compose.
182 |
183 | Example:
184 | >>> transforms.Compose([
185 | >>> transforms.CenterCrop(10),
186 | >>> transforms.ToTensor(),
187 | >>> ])
188 | """
189 | def __init__(self, transforms):
190 | self.transforms = transforms
191 |
192 | def __call__(self, img):
193 | for t in self.transforms:
194 | img = t(img)
195 | return img
196 |
197 |
198 | class ToTensor(object):
199 | """Convert a ``numpy.ndarray`` to tensor.
200 |
201 | Converts a numpy.ndarray (H x W x C) to a torch.FloatTensor of shape (C x H x W).
202 | """
203 | def __call__(self, img):
204 | """Convert a ``numpy.ndarray`` to tensor.
205 |
206 | Args:
207 | img (numpy.ndarray): Image to be converted to tensor.
208 |
209 | Returns:
210 | Tensor: Converted image.
211 | """
212 | if not (_is_numpy_image(img)):
213 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
214 |
215 | if isinstance(img, np.ndarray):
216 | # handle numpy array
217 | if img.ndim == 3:
218 | img = torch.from_numpy(img.transpose((2, 0, 1)).copy())
219 | elif img.ndim == 2:
220 | img = torch.from_numpy(img.copy())
221 | else:
222 | raise RuntimeError(
223 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.
224 | format(img.ndim))
225 |
226 | return img
227 |
228 |
229 | class NormalizeNumpyArray(object):
230 | """Normalize a ``numpy.ndarray`` with mean and standard deviation.
  231 |     Given mean: ``(M1,...,Mn)`` and std: ``(S1,...,Sn)`` for ``n`` channels, this transform
232 | will normalize each channel of the input ``numpy.ndarray`` i.e.
233 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
234 |
235 | Args:
236 | mean (sequence): Sequence of means for each channel.
237 | std (sequence): Sequence of standard deviations for each channel.
238 | """
239 | def __init__(self, mean, std):
240 | self.mean = mean
241 | self.std = std
242 |
243 | def __call__(self, img):
244 | """
245 | Args:
246 | img (numpy.ndarray): Image of size (H, W, C) to be normalized.
247 |
248 | Returns:
249 | Tensor: Normalized image.
250 | """
251 | if not (_is_numpy_image(img)):
252 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
253 | # TODO: make efficient
  254 |         # print(img.shape)
255 | for i in range(3):
256 | img[:, :, i] = (img[:, :, i] - self.mean[i]) / self.std[i]
257 | return img
258 |
259 |
260 | class NormalizeTensor(object):
  261 |     """Normalize a tensor image with mean and standard deviation.
  262 |     Given mean: ``(M1,...,Mn)`` and std: ``(S1,...,Sn)`` for ``n`` channels, this transform
263 | will normalize each channel of the input ``torch.*Tensor`` i.e.
264 | ``input[channel] = (input[channel] - mean[channel]) / std[channel]``
265 |
266 | Args:
267 | mean (sequence): Sequence of means for each channel.
268 | std (sequence): Sequence of standard deviations for each channel.
269 | """
270 | def __init__(self, mean, std):
271 | self.mean = mean
272 | self.std = std
273 |
274 | def __call__(self, tensor):
275 | """
276 | Args:
277 | tensor (Tensor): Tensor image of size (C, H, W) to be normalized.
278 |
279 | Returns:
280 | Tensor: Normalized Tensor image.
281 | """
282 | if not _is_tensor_image(tensor):
283 | raise TypeError('tensor is not a torch image.')
284 | # TODO: make efficient
285 | for t, m, s in zip(tensor, self.mean, self.std):
286 | t.sub_(m).div_(s)
287 | return tensor
288 |
289 |
290 | class Rotate(object):
291 | """Rotates the given ``numpy.ndarray``.
292 |
293 | Args:
294 | angle (float): The rotation angle in degrees.
295 | """
296 | def __init__(self, angle):
297 | self.angle = angle
298 |
299 | def __call__(self, img):
300 | """
301 | Args:
302 | img (numpy.ndarray (C x H x W)): Image to be rotated.
303 |
304 | Returns:
305 | img (numpy.ndarray (C x H x W)): Rotated image.
306 | """
307 |
308 | # order=0 means nearest-neighbor type interpolation
309 | return skimage.transform.rotate(img, self.angle, resize=False, order=0)
310 |
311 |
312 | class Resize(object):
  313 |     """Rescale the given ``numpy.ndarray`` by the given factor.
  314 | 
  315 |     Args:
  316 |         size (float): Desired rescale factor. Both spatial dimensions of
  317 |             the image are scaled by this value, e.g. a factor of 0.5
  318 |             halves the height and width (the constructor asserts that
  319 |             ``size`` is a float).
  320 |         interpolation (str, optional): Desired interpolation mode. Default
  321 |             is ``'nearest'`` (order-0 nearest-neighbor resampling).
  322 |     """
323 | def __init__(self, size, interpolation='nearest'):
324 | assert isinstance(size, float)
325 | self.size = size
326 | self.interpolation = interpolation
327 |
328 | def __call__(self, img):
329 | """
330 | Args:
331 | img (numpy.ndarray (C x H x W)): Image to be scaled.
332 | Returns:
333 | img (numpy.ndarray (C x H x W)): Rescaled image.
334 | """
335 | if img.ndim == 3:
336 | return skimage.transform.rescale(img, self.size, order=0)
337 | elif img.ndim == 2:
338 | return skimage.transform.rescale(img, self.size, order=0)
339 | else:
  340 |             raise RuntimeError(
341 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
342 | img.ndim))
343 |
344 |
345 | class CenterCrop(object):
346 | """Crops the given ``numpy.ndarray`` at the center.
347 |
348 | Args:
349 | size (sequence or int): Desired output size of the crop. If size is an
350 | int instead of sequence like (h, w), a square crop (size, size) is
351 | made.
352 | """
353 | def __init__(self, size):
354 | if isinstance(size, numbers.Number):
355 | self.size = (int(size), int(size))
356 | else:
357 | self.size = size
358 |
359 | @staticmethod
360 | def get_params(img, output_size):
361 | """Get parameters for ``crop`` for center crop.
362 |
363 | Args:
364 | img (numpy.ndarray (C x H x W)): Image to be cropped.
365 | output_size (tuple): Expected output size of the crop.
366 |
367 | Returns:
368 | tuple: params (i, j, h, w) to be passed to ``crop`` for center crop.
369 | """
370 | h = img.shape[0]
371 | w = img.shape[1]
372 | th, tw = output_size
373 | i = int(round((h - th) / 2.))
374 | j = int(round((w - tw) / 2.))
375 |
376 | # # randomized cropping
377 | # i = np.random.randint(i-3, i+4)
378 | # j = np.random.randint(j-3, j+4)
379 |
380 | return i, j, th, tw
381 |
382 | def __call__(self, img):
383 | """
384 | Args:
385 | img (numpy.ndarray (C x H x W)): Image to be cropped.
386 |
387 | Returns:
388 | img (numpy.ndarray (C x H x W)): Cropped image.
389 | """
390 | i, j, h, w = self.get_params(img, self.size)
391 | """
392 | i: Upper pixel coordinate.
393 | j: Left pixel coordinate.
394 | h: Height of the cropped image.
395 | w: Width of the cropped image.
396 | """
397 | if not (_is_numpy_image(img)):
398 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
399 | if img.ndim == 3:
400 | return img[i:i + h, j:j + w, :]
401 | elif img.ndim == 2:
402 | return img[i:i + h, j:j + w]
403 | else:
404 | raise RuntimeError(
405 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
406 | img.ndim))
407 |
408 |
409 | class BottomCrop(object):
410 | """Crops the given ``numpy.ndarray`` at the bottom.
411 |
412 | Args:
413 | size (sequence or int): Desired output size of the crop. If size is an
414 | int instead of sequence like (h, w), a square crop (size, size) is
415 | made.
416 | """
417 | def __init__(self, size):
418 | if isinstance(size, numbers.Number):
419 | self.size = (int(size), int(size))
420 | else:
421 | self.size = size
422 |
423 | @staticmethod
424 | def get_params(img, output_size):
425 | """Get parameters for ``crop`` for bottom crop.
426 |
427 | Args:
428 | img (numpy.ndarray (C x H x W)): Image to be cropped.
429 | output_size (tuple): Expected output size of the crop.
430 |
431 | Returns:
432 | tuple: params (i, j, h, w) to be passed to ``crop`` for bottom crop.
433 | """
434 | h = img.shape[0]
435 | w = img.shape[1]
436 | th, tw = output_size
437 | i = h - th
438 | j = int(round((w - tw) / 2.))
439 |
440 | # randomized left and right cropping
441 | # i = np.random.randint(i-3, i+4)
442 | # j = np.random.randint(j-1, j+1)
443 |
444 | return i, j, th, tw
445 |
446 | def __call__(self, img):
447 | """
448 | Args:
449 | img (numpy.ndarray (C x H x W)): Image to be cropped.
450 |
451 | Returns:
452 | img (numpy.ndarray (C x H x W)): Cropped image.
453 | """
454 | i, j, h, w = self.get_params(img, self.size)
455 | """
456 | i: Upper pixel coordinate.
457 | j: Left pixel coordinate.
458 | h: Height of the cropped image.
459 | w: Width of the cropped image.
460 | """
461 | if not (_is_numpy_image(img)):
462 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
463 | if img.ndim == 3:
464 | return img[i:i + h, j:j + w, :]
465 | elif img.ndim == 2:
466 | return img[i:i + h, j:j + w]
467 | else:
468 | raise RuntimeError(
469 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
470 | img.ndim))
471 |
472 |
473 | class RandomCrop(object):
  474 |     """Crops the given ``numpy.ndarray`` at a random location.
475 |
476 | Args:
477 | size (sequence or int): Desired output size of the crop. If size is an
478 | int instead of sequence like (h, w), a square crop (size, size) is
479 | made.
480 | """
481 | def __init__(self, size):
482 | if isinstance(size, numbers.Number):
483 | self.size = (int(size), int(size))
484 | else:
485 | self.size = size
486 |
487 | @staticmethod
488 | def get_params(img, output_size):
  489 |         """Get parameters for ``crop`` for random crop.
490 |
491 | Args:
492 | img (numpy.ndarray (C x H x W)): Image to be cropped.
493 | output_size (tuple): Expected output size of the crop.
494 |
495 | Returns:
  496 |             tuple: params (i, j, h, w) to be passed to ``crop`` for random crop.
497 | """
498 | h = img.shape[0]
499 | w = img.shape[1]
500 | th, tw = output_size
501 |
  502 |         # randomized top-left corner of the crop
503 | i = np.random.randint(0, h-th+1)
504 | j = np.random.randint(0, w-tw+1)
505 |
506 | return i, j, th, tw
507 |
508 | def __call__(self, img):
509 | """
510 | Args:
511 | img (numpy.ndarray (C x H x W)): Image to be cropped.
512 |
513 | Returns:
514 | img (numpy.ndarray (C x H x W)): Cropped image.
515 | """
516 | i, j, h, w = self.get_params(img, self.size)
517 | """
518 | i: Upper pixel coordinate.
519 | j: Left pixel coordinate.
520 | h: Height of the cropped image.
521 | w: Width of the cropped image.
522 | """
523 | if not (_is_numpy_image(img)):
524 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
525 | if img.ndim == 3:
526 | return img[i:i + h, j:j + w, :]
527 | elif img.ndim == 2:
528 | return img[i:i + h, j:j + w]
529 | else:
530 | raise RuntimeError(
531 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
532 | img.ndim))
533 |
534 |
535 | class Crop(object):
  536 |     """Crops the given ``numpy.ndarray`` to an explicit window.
  537 | 
  538 |     Args:
  539 |         crop (tuple): Crop window (x_l, x_r, y_b, y_t); the image is sliced
  540 |             as ``img[y_b:y_t, x_l:x_r]``, so the bounds are pixel indices
  541 |             with x_l < x_r and y_b < y_t.
  542 |     """
543 | def __init__(self, crop):
544 | self.crop = crop
545 |
546 | @staticmethod
547 | def get_params(img, crop):
  548 |         """Get parameters for the explicit ``crop`` window.
  549 | 
  550 |         Args:
  551 |             img (numpy.ndarray (C x H x W)): Image to be cropped.
  552 |             crop (tuple): Crop window given as (x_l, x_r, y_b, y_t).
  553 | 
  554 |         Returns:
  555 |             tuple: the validated (x_l, x_r, y_b, y_t) bounds, used by ``__call__`` to slice the image.
556 | """
557 | x_l, x_r, y_b, y_t = crop
558 | h = img.shape[0]
559 | w = img.shape[1]
560 | assert x_l >= 0 and x_l < w
561 | assert x_r >= 0 and x_r < w
562 | assert y_b >= 0 and y_b < h
563 | assert y_t >= 0 and y_t < h
564 | assert x_l < x_r and y_b < y_t
565 |
566 | return x_l, x_r, y_b, y_t
567 |
568 | def __call__(self, img):
569 | """
570 | Args:
571 | img (numpy.ndarray (C x H x W)): Image to be cropped.
572 |
573 | Returns:
574 | img (numpy.ndarray (C x H x W)): Cropped image.
575 | """
576 | x_l, x_r, y_b, y_t = self.get_params(img, self.crop)
577 | """
578 | i: Upper pixel coordinate.
579 | j: Left pixel coordinate.
580 | h: Height of the cropped image.
581 | w: Width of the cropped image.
582 | """
583 | if not (_is_numpy_image(img)):
584 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
585 | if img.ndim == 3:
586 | return img[y_b:y_t, x_l:x_r, :]
587 | elif img.ndim == 2:
588 | return img[y_b:y_t, x_l:x_r]
589 | else:
590 | raise RuntimeError(
591 | 'img should be ndarray with 2 or 3 dimensions. Got {}'.format(
592 | img.ndim))
593 |
594 |
595 | class Lambda(object):
596 | """Apply a user-defined lambda as a transform.
597 |
598 | Args:
599 | lambd (function): Lambda/function to be used for transform.
600 | """
601 | def __init__(self, lambd):
602 | assert isinstance(lambd, types.LambdaType)
603 | self.lambd = lambd
604 |
605 | def __call__(self, img):
606 | return self.lambd(img)
607 |
608 |
609 | class HorizontalFlip(object):
610 | """Horizontally flip the given ``numpy.ndarray``.
611 |
612 | Args:
613 | do_flip (boolean): whether or not do horizontal flip.
614 |
615 | """
616 | def __init__(self, do_flip):
617 | self.do_flip = do_flip
618 |
619 | def __call__(self, img):
620 | """
621 | Args:
622 | img (numpy.ndarray (C x H x W)): Image to be flipped.
623 |
624 | Returns:
625 | img (numpy.ndarray (C x H x W)): flipped image.
626 | """
627 | if not (_is_numpy_image(img)):
628 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
629 |
630 | if self.do_flip:
631 | return np.fliplr(img)
632 | else:
633 | return img
634 |
635 |
636 | class ColorJitter(object):
637 |     """Adjust the brightness, contrast, saturation and hue of an image.
638 | 
639 |     Note: unlike ``torchvision.transforms.ColorJitter``, the values passed here
640 |     are applied directly as adjustment factors (in a random order); any random
641 |     sampling of the factors is expected to be done by the caller.
642 | 
643 |     Args:
644 |         brightness (float): Brightness adjustment factor.
645 |         contrast (float): Contrast adjustment factor.
646 |         saturation (float): Saturation adjustment factor.
647 |         hue (float): Hue adjustment factor. Should be in [-0.5, 0.5].
648 |     """
649 | def __init__(self, brightness=0, contrast=0, saturation=0, hue=0):
650 | transforms = []
651 | transforms.append(
652 | Lambda(lambda img: adjust_brightness(img, brightness)))
653 | transforms.append(Lambda(lambda img: adjust_contrast(img, contrast)))
654 | transforms.append(
655 | Lambda(lambda img: adjust_saturation(img, saturation)))
656 | transforms.append(Lambda(lambda img: adjust_hue(img, hue)))
657 | np.random.shuffle(transforms)
658 | self.transform = Compose(transforms)
659 |
660 | def __call__(self, img):
661 | """
662 | Args:
663 |             img (numpy.ndarray (H x W x C)): Input image.
664 | 
665 |         Returns:
666 |             img (numpy.ndarray (H x W x C)): Color jittered image.
667 | """
668 | if not (_is_numpy_image(img)):
669 | raise TypeError('img should be ndarray. Got {}'.format(type(img)))
670 |
671 | pil = Image.fromarray(img)
672 | return np.array(self.transform(pil))
673 |
--------------------------------------------------------------------------------
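
The crop transforms above operate on H x W (x C) numpy arrays and reduce to plain array slicing. A minimal standalone sketch of that indexing (NumPy only, not part of this repo; the 352 x 1216 size is just a convenient toy value matching the KITTI validation resolution used elsewhere):

import numpy as np

# toy H x W x C image
img = np.zeros((352, 1216, 3), dtype=np.uint8)

# RandomCrop-style parameters: top-left corner (i, j) and output size (th, tw)
th, tw = 320, 1216
i = np.random.randint(0, img.shape[0] - th + 1)
j = np.random.randint(0, img.shape[1] - tw + 1)
random_crop = img[i:i + th, j:j + tw, :]
assert random_crop.shape == (th, tw, 3)

# Crop-style parameters: fixed column range [x_l, x_r) and row range [y_b, y_t)
x_l, x_r, y_b, y_t = 0, 1216, 0, 320
fixed_crop = img[y_b:y_t, x_l:x_r, :]
assert fixed_crop.shape == (y_t - y_b, x_r - x_l, 3)
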
/helper.py:
--------------------------------------------------------------------------------
1 | import math
2 | import os, time
3 | import shutil
4 | import torch
5 | import csv
6 | import vis_utils
7 | from metrics import Result
8 |
9 | fieldnames = [
10 | 'epoch', 'rmse', 'photo', 'mae', 'irmse', 'imae', 'mse', 'absrel', 'lg10',
11 | 'silog', 'squared_rel', 'delta1', 'delta2', 'delta3', 'data_time',
12 | 'gpu_time'
13 | ]
14 |
15 |
16 | class logger:
17 | def __init__(self, args, prepare=True):
18 | self.args = args
19 | output_directory = get_folder_name(args)
20 | self.output_directory = output_directory
21 | self.best_result = Result()
22 | self.best_result.set_to_worst()
23 |
24 | if not prepare:
25 | return
26 | if not os.path.exists(output_directory):
27 | os.makedirs(output_directory)
28 | self.train_csv = os.path.join(output_directory, 'train.csv')
29 | self.val_csv = os.path.join(output_directory, 'val.csv')
30 | self.best_txt = os.path.join(output_directory, 'best.txt')
31 |
32 |         # source code backup is currently disabled; only fresh csv logs are created
33 |         if args.resume == '':
34 |             print("=> creating new csv log files ...")
35 | # backup_directory = os.path.join(output_directory, "code_backup")
36 | # self.backup_directory = backup_directory
37 | # backup_source_code(backup_directory)
38 | # create new csv files with only header
39 | with open(self.train_csv, 'w') as csvfile:
40 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
41 | writer.writeheader()
42 | with open(self.val_csv, 'w') as csvfile:
43 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
44 | writer.writeheader()
45 |             print("=> finished creating csv log files.")
46 |
47 | def conditional_print(self, split, i, epoch, lr, n_set, blk_avg_meter,
48 | avg_meter):
49 | if (i + 1) % self.args.print_freq == 0:
50 | avg = avg_meter.average()
51 | blk_avg = blk_avg_meter.average()
52 | print('=> output: {}'.format(self.output_directory))
53 | print(
54 | '{split} Epoch: {0} [{1}/{2}]\tlr={lr} '
55 | 't_Data={blk_avg.data_time:.3f}({average.data_time:.3f}) '
56 | 't_GPU={blk_avg.gpu_time:.3f}({average.gpu_time:.3f})\n\t'
57 | 'RMSE={blk_avg.rmse:.2f}({average.rmse:.2f}) '
58 | 'MAE={blk_avg.mae:.2f}({average.mae:.2f}) '
59 | 'iRMSE={blk_avg.irmse:.2f}({average.irmse:.2f}) '
60 | 'iMAE={blk_avg.imae:.2f}({average.imae:.2f})\n\t'
61 | 'silog={blk_avg.silog:.2f}({average.silog:.2f}) '
62 | 'squared_rel={blk_avg.squared_rel:.2f}({average.squared_rel:.2f}) '
63 | 'Delta1={blk_avg.delta1:.3f}({average.delta1:.3f}) '
64 | 'REL={blk_avg.absrel:.3f}({average.absrel:.3f})\n\t'
65 | 'Lg10={blk_avg.lg10:.3f}({average.lg10:.3f}) '
66 | 'Photometric={blk_avg.photometric:.3f}({average.photometric:.3f}) '
67 | .format(epoch,
68 | i + 1,
69 | n_set,
70 | lr=lr,
71 | blk_avg=blk_avg,
72 | average=avg,
73 | split=split.capitalize()))
74 | blk_avg_meter.reset(False)
75 |
76 | def conditional_save_info(self, split, average_meter, epoch):
77 | avg = average_meter.average()
78 | if split == "train":
79 | csvfile_name = self.train_csv
80 | elif split == "val":
81 | csvfile_name = self.val_csv
82 | elif split == "eval":
83 | eval_filename = os.path.join(self.output_directory, 'eval.txt')
84 | self.save_single_txt(eval_filename, avg, epoch)
85 | return avg
86 | elif "test" in split:
87 | return avg
88 | else:
89 | raise ValueError("wrong split provided to logger")
90 | with open(csvfile_name, 'a') as csvfile:
91 | writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
92 | writer.writerow({
93 | 'epoch': epoch,
94 | 'rmse': avg.rmse,
95 | 'photo': avg.photometric,
96 | 'mae': avg.mae,
97 | 'irmse': avg.irmse,
98 | 'imae': avg.imae,
99 | 'mse': avg.mse,
100 | 'silog': avg.silog,
101 | 'squared_rel': avg.squared_rel,
102 | 'absrel': avg.absrel,
103 | 'lg10': avg.lg10,
104 | 'delta1': avg.delta1,
105 | 'delta2': avg.delta2,
106 | 'delta3': avg.delta3,
107 | 'gpu_time': avg.gpu_time,
108 | 'data_time': avg.data_time
109 | })
110 | return avg
111 |
112 | def save_single_txt(self, filename, result, epoch):
113 | with open(filename, 'w') as txtfile:
114 | txtfile.write(
115 | ("rank_metric={}\n" + "epoch={}\n" + "rmse={:.3f}\n" +
116 | "mae={:.3f}\n" + "silog={:.3f}\n" + "squared_rel={:.3f}\n" +
117 | "irmse={:.3f}\n" + "imae={:.3f}\n" + "mse={:.3f}\n" +
118 | "absrel={:.3f}\n" + "lg10={:.3f}\n" + "delta1={:.3f}\n" +
119 | "t_gpu={:.4f}").format(self.args.rank_metric, epoch,
120 | result.rmse, result.mae, result.silog,
121 | result.squared_rel, result.irmse,
122 | result.imae, result.mse, result.absrel,
123 | result.lg10, result.delta1,
124 | result.gpu_time))
125 |
126 | def save_best_txt(self, result, epoch):
127 | self.save_single_txt(self.best_txt, result, epoch)
128 |
129 | def _get_img_comparison_name(self, mode, epoch, is_best=False):
130 | if mode == 'eval':
131 | return self.output_directory + '/comparison_eval.png'
132 | if mode == 'val':
133 | if is_best:
134 | return self.output_directory + '/comparison_best.png'
135 | else:
136 | return self.output_directory + '/comparison_' + str(epoch) + '.png'
137 |
138 | def conditional_save_img_comparison(self, mode, i, ele, pred, epoch, predrgb=None, predg=None, extra=None, extra2=None, extrargb=None):
139 | # save 8 images for visualization
140 | if mode == 'val' or mode == 'eval':
141 | skip = 100
142 | if i == 0:
143 | self.img_merge = vis_utils.merge_into_row(ele, pred, predrgb, predg, extra, extra2, extrargb)
144 | elif i % skip == 0 and i < 8 * skip:
145 | row = vis_utils.merge_into_row(ele, pred, predrgb, predg, extra, extra2, extrargb)
146 | self.img_merge = vis_utils.add_row(self.img_merge, row)
147 | elif i == 8 * skip:
148 | filename = self._get_img_comparison_name(mode, epoch)
149 | vis_utils.save_image(self.img_merge, filename)
150 |
151 | def save_img_comparison_as_best(self, mode, epoch):
152 | if mode == 'val':
153 | filename = self._get_img_comparison_name(mode, epoch, is_best=True)
154 | vis_utils.save_image(self.img_merge, filename)
155 |
156 | def get_ranking_error(self, result):
157 | return getattr(result, self.args.rank_metric)
158 |
159 | def rank_conditional_save_best(self, mode, result, epoch):
160 | error = self.get_ranking_error(result)
161 | best_error = self.get_ranking_error(self.best_result)
162 | is_best = error < best_error
163 | if is_best and mode == "val":
164 | self.old_best_result = self.best_result
165 | self.best_result = result
166 | self.save_best_txt(result, epoch)
167 | return is_best
168 |
169 | def conditional_save_pred(self, mode, i, pred, epoch):
170 | if ("test" in mode or mode == "eval") and self.args.save_pred:
171 |
172 | # save images for visualization/ testing
173 | image_folder = os.path.join(self.output_directory,
174 | mode + "_output")
175 | if not os.path.exists(image_folder):
176 | os.makedirs(image_folder)
177 | img = torch.squeeze(pred.data.cpu()).numpy()
178 | filename = os.path.join(image_folder, '{0:010d}.png'.format(i))
179 | vis_utils.save_depth_as_uint16png(img, filename)
180 |
181 | def conditional_summarize(self, mode, avg, is_best):
182 | print("\n*\nSummary of ", mode, "round")
183 | print(''
184 | 'RMSE={average.rmse:.3f}\n'
185 | 'MAE={average.mae:.3f}\n'
186 | 'Photo={average.photometric:.3f}\n'
187 | 'iRMSE={average.irmse:.3f}\n'
188 | 'iMAE={average.imae:.3f}\n'
189 |               'squared_rel={average.squared_rel:.3f}\n'
190 |               'silog={average.silog:.3f}\n'
191 | 'Delta1={average.delta1:.3f}\n'
192 | 'REL={average.absrel:.3f}\n'
193 | 'Lg10={average.lg10:.3f}\n'
194 | 't_GPU={time:.3f}'.format(average=avg, time=avg.gpu_time))
195 | if is_best and mode == "val":
196 | print("New best model by %s (was %.3f)" %
197 | (self.args.rank_metric,
198 | self.get_ranking_error(self.old_best_result)))
199 | elif mode == "val":
200 | print("(best %s is %.3f)" %
201 | (self.args.rank_metric,
202 | self.get_ranking_error(self.best_result)))
203 | print("*\n")
204 |
205 |
206 | ignore_hidden = shutil.ignore_patterns(".", "..", ".git*", "*pycache*",
207 | "*build", "*.fuse*", "*_drive_*")
208 |
209 |
210 | def backup_source_code(backup_directory):
211 | if os.path.exists(backup_directory):
212 | shutil.rmtree(backup_directory)
213 | shutil.copytree('.', backup_directory, ignore=ignore_hidden)
214 |
215 |
216 | def adjust_learning_rate(lr_init, optimizer, epoch, args):
217 | """Sets the learning rate to the initial LR decayed by half every 5 epochs"""
218 | lr = lr_init * (0.5**(epoch // 5))
219 |
220 | for param_group in optimizer.param_groups:
221 | param_group['lr'] = lr
222 | return lr
223 |
224 | def save_checkpoint(state, is_best, epoch, output_directory):
225 | checkpoint_filename = os.path.join(output_directory,
226 | 'checkpoint-' + str(epoch) + '.pth.tar')
227 | torch.save(state, checkpoint_filename)
228 | if is_best:
229 | best_filename = os.path.join(output_directory, 'model_best.pth.tar')
230 | shutil.copyfile(checkpoint_filename, best_filename)
231 | if epoch > 0:
232 | prev_checkpoint_filename = os.path.join(
233 | output_directory, 'checkpoint-' + str(epoch - 1) + '.pth.tar')
234 | if os.path.exists(prev_checkpoint_filename):
235 | os.remove(prev_checkpoint_filename)
236 |
237 |
238 | def get_folder_name(args):
239 | current_time = time.strftime('%Y-%m-%d@%H-%M')
240 | return os.path.join(args.result,
241 | 'input={}.criterion={}.lr={}.bs={}.wd={}.jitter={}.time={}'.
242 | format(args.input, args.criterion, \
243 | args.lr, args.batch_size, args.weight_decay, \
244 | args.jitter, current_time
245 | ))
246 |
247 |
248 | avgpool = torch.nn.AvgPool2d(kernel_size=2, stride=2).cuda()  # note: moved to GPU at import time, so this module requires CUDA
249 |
250 |
251 | def multiscale(img):
252 | img1 = avgpool(img)
253 | img2 = avgpool(img1)
254 | img3 = avgpool(img2)
255 | img4 = avgpool(img3)
256 | img5 = avgpool(img4)
257 | return img5, img4, img3, img2, img1
258 |
--------------------------------------------------------------------------------
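
helper.adjust_learning_rate above implements a step schedule that halves the initial learning rate every 5 epochs (lr = lr_init * 0.5 ** (epoch // 5)). A small standalone sketch of the resulting values, using the default --lr from main.py:

lr_init = 2e-4  # default --lr in main.py
for epoch in range(0, 16, 5):
    lr = lr_init * (0.5 ** (epoch // 5))
    print(f"epoch {epoch:2d}: lr = {lr:.2e}")
# epoch  0: lr = 2.00e-04
# epoch  5: lr = 1.00e-04
# epoch 10: lr = 5.00e-05
# epoch 15: lr = 2.50e-05
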
/images/model architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/Anonymous1234321/GuideFormer/cccee1c5305977a1bc8d0b8df3f1b6ff66bd1736/images/model architecture.png
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.nn.parallel
6 | import torch.optim
7 | import torch.utils.data
8 | import time
9 |
10 | from dataloaders.kitti_loader import load_calib, input_options, KittiDepth
11 | from metrics import AverageMeter, Result
12 | import criteria
13 | import helper
14 | import vis_utils
15 |
16 | from model import GuideFormer
17 |
18 | parser = argparse.ArgumentParser(description='Sparse-to-Dense')
19 | parser.add_argument('--workers',
20 | default=8,
21 | type=int,
22 | metavar='N',
23 |                     help='number of data loading workers (default: 8)')
24 | parser.add_argument('--epochs',
25 | default=100,
26 | type=int,
27 | metavar='N',
28 | help='number of total epochs to run (default: 100)')
29 | parser.add_argument('--start-epoch',
30 | default=0,
31 | type=int,
32 | metavar='N',
33 | help='manual epoch number (useful on restarts)')
34 | parser.add_argument('--start-epoch-bias',
35 | default=0,
36 | type=int,
37 | metavar='N',
38 |                     help='manual epoch number bias (useful on restarts)')
39 | parser.add_argument('-c',
40 | '--criterion',
41 | metavar='LOSS',
42 | default='l2',
43 | choices=criteria.loss_names,
44 | help='loss function: | '.join(criteria.loss_names) +
45 | ' (default: l2)')
46 | parser.add_argument('-b',
47 | '--batch-size',
48 | default=1,
49 | type=int,
50 | help='mini-batch size (default: 1)')
51 | parser.add_argument('--lr',
52 | '--learning-rate',
53 | default=2e-4,
54 | type=float,
55 | metavar='LR',
56 | help='initial learning rate (default 2e-4)')
57 | parser.add_argument('--weight-decay',
58 | '--wd',
59 | default=1e-6,
60 | type=float,
61 | metavar='W',
62 |                     help='weight decay (default: 1e-6)')
63 | parser.add_argument('--print-freq',
64 | '-p',
65 | default=10,
66 | type=int,
67 | metavar='N',
68 | help='print frequency (default: 10)')
69 | parser.add_argument('--resume',
70 | default='',
71 | type=str,
72 | metavar='PATH',
73 | help='path to latest checkpoint (default: none)')
74 | parser.add_argument('--data-folder',
75 | default='/resources/KITTI/kitti_depth',
76 | type=str,
77 | metavar='PATH',
78 | help='data folder (default: none)')
79 | parser.add_argument('--data-folder-rgb',
80 | default='/resources/KITTI/kitti_rgb',
81 | type=str,
82 | metavar='PATH',
83 | help='data folder rgb (default: none)')
84 | parser.add_argument('--data-folder-save',
85 | default='/resources/KITTI/submit_test/',
86 | type=str,
87 | metavar='PATH',
88 |                     help='data folder for test results (default: none)')
89 | parser.add_argument('-i',
90 | '--input',
91 | type=str,
92 | default='rgbd',
93 | choices=input_options,
94 | help='input: | '.join(input_options))
95 | parser.add_argument('--val',
96 | type=str,
97 | default="select",
98 | choices=["select", "full"],
99 | help='full or select validation set')
100 | parser.add_argument('--jitter',
101 | type=float,
102 | default=0.1,
103 | help='color jitter for images')
104 | parser.add_argument('--rank-metric',
105 | type=str,
106 | default='rmse',
107 | choices=[m for m in dir(Result()) if not m.startswith('_')],
108 | help='metrics for which best result is saved')
109 |
110 | parser.add_argument('-e', '--evaluate', default='', type=str, metavar='PATH')
111 | parser.add_argument('--test', action="store_true", default=False,
112 |                     help='save results on the KITTI test set for submission')
113 | parser.add_argument('--cpu', action="store_true", default=False, help='run on cpu')
114 |
115 | #random cropping
116 | parser.add_argument('--not-random-crop', action="store_true", default=False,
117 | help='prohibit random cropping')
118 | parser.add_argument('-he', '--random-crop-height', default=320, type=int, metavar='N',
119 | help='random crop height')
120 | parser.add_argument('-w', '--random-crop-width', default=1216, type=int, metavar='N',
121 |                     help='random crop width')
122 |
123 | args = parser.parse_args()
124 | args.result = os.path.join(os.getcwd(), 'results')
125 | args.use_rgb = ('rgb' in args.input)
126 | args.use_d = 'd' in args.input
127 | args.use_g = 'g' in args.input
128 | args.val_h = 352
129 | args.val_w = 1216
130 | print(args)
131 |
132 | cuda = torch.cuda.is_available() and not args.cpu
133 | if cuda:
134 | import torch.backends.cudnn as cudnn
135 | cudnn.benchmark = True
136 | device = torch.device("cuda")
137 | else:
138 | device = torch.device("cpu")
139 | print("=> using '{}' for computation.".format(device))
140 |
141 | # define loss functions
142 | depth_criterion = criteria.MaskedMSELoss() if (
143 | args.criterion == 'l2') else criteria.MaskedL1Loss()
144 |
145 | #multi batch
146 | multi_batch_size = 1
147 | def iterate(mode, args, loader, model, optimizer, logger, epoch):
148 | actual_epoch = epoch - args.start_epoch + args.start_epoch_bias
149 |
150 | block_average_meter = AverageMeter()
151 | block_average_meter.reset(False)
152 | average_meter = AverageMeter()
153 | meters = [block_average_meter, average_meter]
154 |
155 | # switch to appropriate mode
156 | assert mode in ["train", "val", "eval", "test_prediction", "test_completion"], \
157 | "unsupported mode: {}".format(mode)
158 | if mode == 'train':
159 | model.train()
160 | lr = helper.adjust_learning_rate(args.lr, optimizer, actual_epoch, args)
161 | else:
162 | model.eval()
163 | lr = 0
164 |
165 | torch.cuda.empty_cache()
166 | avg_loss = 0
167 | for i, batch_data in enumerate(loader):
168 | # if(mode == 'train' and i == 10) or (mode == 'val' and i == 50): break
169 |
170 | dstart = time.time()
171 | batch_data = {
172 | key: val.to(device)
173 | for key, val in batch_data.items() if val is not None
174 | }
175 |
176 | gt = batch_data[
177 | 'gt'] if mode != 'test_prediction' and mode != 'test_completion' else None
178 | data_time = time.time() - dstart
179 |
180 | pred = None
181 | start = None
182 | gpu_time = 0
183 |
184 | start = time.time()
185 | cbd_pred, dbd_pred, pred = model(batch_data)
186 |
187 | if(args.evaluate):
188 | gpu_time = time.time() - start
189 |
190 | depth_loss, photometric_loss, smooth_loss, mask = 0, 0, 0, None
191 |
192 | # inter loss_param
193 | cbd_loss, dbd_loss, loss = 0, 0, 0
194 | w_cbd, w_dbd = 0, 0
195 | round1, round2 = 1, 3
196 | if(actual_epoch <= round1):
197 | w_cbd, w_dbd = 0.2, 0.2
198 | elif(actual_epoch <= round2):
199 | w_cbd, w_dbd = 0.05, 0.05
200 | else:
201 | w_cbd, w_dbd = 0, 0
202 |
203 | if mode == 'train':
204 | # Loss 1: the direct depth supervision from ground truth label
205 |             # mask=1 indicates that a pixel has no ground truth label
206 | depth_loss = depth_criterion(pred, gt)
207 | cbd_loss = depth_criterion(cbd_pred, gt)
208 | dbd_loss = depth_criterion(dbd_pred, gt)
209 | loss = (1 - w_cbd - w_dbd) * depth_loss + w_cbd * cbd_loss + w_dbd * dbd_loss
210 |
211 | avg_loss = (avg_loss * i + loss.item()) / float(i + 1)
212 |
213 | if i % multi_batch_size == 0:
214 | optimizer.zero_grad()
215 | loss.backward()
216 |
217 | if i % multi_batch_size == (multi_batch_size-1) or i==(len(loader)-1):
218 | optimizer.step()
219 | print(f"loss: {round(loss.item(), 8)} ({round(avg_loss, 8)}) epoch: {epoch + 1} {i} / {len(loader)}")
220 |
221 | if mode == "test_completion":
222 | str_i = str(i)
223 | path_i = str_i.zfill(10) + '.png'
224 | path = os.path.join(args.data_folder_save, path_i)
225 | vis_utils.save_depth_as_uint16png_upload(pred, path)
226 |
227 | if(not args.evaluate):
228 | gpu_time = time.time() - start
229 | # measure accuracy and record loss
230 | with torch.no_grad():
231 | mini_batch_size = next(iter(batch_data.values())).size(0)
232 | result = Result()
233 | if mode != 'test_prediction' and mode != 'test_completion':
234 | result.evaluate(pred.data, gt.data, photometric_loss)
235 |
236 | for m in meters:
237 | m.update(result, gpu_time, data_time, mini_batch_size)
238 |
239 | if mode != 'train':
240 | logger.conditional_print(mode, i, epoch, lr, len(loader),
241 | block_average_meter, average_meter)
242 | logger.conditional_save_img_comparison(mode, i, batch_data, pred,
243 | epoch)
244 | logger.conditional_save_pred(mode, i, pred, epoch)
245 |
246 | avg = logger.conditional_save_info(mode, average_meter, epoch)
247 | is_best = logger.rank_conditional_save_best(mode, avg, epoch)
248 | if is_best and not (mode == "train"):
249 | logger.save_img_comparison_as_best(mode, epoch)
250 | logger.conditional_summarize(mode, avg, is_best)
251 |
252 | return avg, is_best
253 |
254 | def main():
255 | global args
256 | checkpoint = None
257 | is_eval = False
258 | if args.evaluate:
259 | args_new = args
260 | if os.path.isfile(args.evaluate):
261 | print("=> loading checkpoint '{}' ... ".format(args.evaluate),
262 | end='')
263 | checkpoint = torch.load(args.evaluate, map_location=device)
264 | #args = checkpoint['args']
265 | args.start_epoch = checkpoint['epoch'] + 1
266 | args.data_folder = args_new.data_folder
267 | args.val = args_new.val
268 | is_eval = True
269 |
270 | print("Completed.")
271 | else:
272 | is_eval = True
273 | print("No model found at '{}'".format(args.evaluate))
274 | #return
275 |
276 | elif args.resume: # optionally resume from a checkpoint
277 | args_new = args
278 | if os.path.isfile(args.resume):
279 | print("=> loading checkpoint '{}' ... ".format(args.resume),
280 | end='')
281 | checkpoint = torch.load(args.resume, map_location=device)
282 |
283 | args.start_epoch = checkpoint['epoch'] + 1
284 | args.data_folder = args_new.data_folder
285 | args.val = args_new.val
286 | print("Completed. Resuming from epoch {}.".format(
287 | checkpoint['epoch']))
288 | else:
289 | print("No checkpoint found at '{}'".format(args.resume))
290 | return
291 |
292 | print("=> creating model and optimizer ... ", end='')
293 | model = GuideFormer().to(device)
294 | torch.save(model.state_dict(), 'temp.pth')
295 |
296 | model_named_params = None
297 | optimizer = None
298 |
299 | if checkpoint is not None:
300 | model.load_state_dict(checkpoint['model'], strict=False)
301 | #optimizer.load_state_dict(checkpoint['optimizer'])
302 | print("=> checkpoint state loaded.")
303 |
304 | logger = helper.logger(args)
305 | if checkpoint is not None:
306 | logger.best_result = checkpoint['best_result']
307 | del checkpoint
308 | print("=> logger created.")
309 |
310 | test_dataset = None
311 | test_loader = None
312 | if (args.test):
313 | test_dataset = KittiDepth('test_completion', args)
314 | test_loader = torch.utils.data.DataLoader(
315 | test_dataset,
316 | batch_size=1,
317 | shuffle=False,
318 | num_workers=1,
319 | pin_memory=True)
320 | iterate("test_completion", args, test_loader, model, None, logger, 0)
321 | return
322 |
323 | val_dataset = KittiDepth('val', args)
324 | val_loader = torch.utils.data.DataLoader(
325 | val_dataset,
326 | batch_size=1,
327 | shuffle=False,
328 | num_workers=2,
329 | pin_memory=True) # set batch size to be 1 for validation
330 | print("\t==> val_loader size: {}".format(len(val_loader)))
331 |
332 | if is_eval == True:
333 | for p in model.parameters():
334 | p.requires_grad = False
335 |
336 | result, is_best = iterate("val", args, val_loader, model, None, logger,
337 | args.start_epoch - 1)
338 | return
339 |
340 | model_named_params = [
341 | p for _, p in model.named_parameters() if p.requires_grad
342 | ]
343 | optimizer = torch.optim.Adam(model_named_params, lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.99))
344 | print("completed.")
345 |
346 | model = torch.nn.DataParallel(model)
347 |
348 | # Data loading code
349 | print("=> creating data loaders ... ")
350 | if not is_eval:
351 | train_dataset = KittiDepth('train', args)
352 | train_loader = torch.utils.data.DataLoader(train_dataset,
353 | batch_size=args.batch_size,
354 | shuffle=True,
355 | num_workers=args.workers,
356 | pin_memory=True,
357 | sampler=None)
358 | print("\t==> train_loader size: {}".format(len(train_loader)))
359 |
360 | print("=> starting main loop ...")
361 | for epoch in range(args.start_epoch, args.epochs):
362 | print("=> starting training epoch {} ..".format(epoch))
363 | iterate("train", args, train_loader, model, optimizer, logger, epoch) # train for one epoch
364 |
365 | # validation memory reset
366 | for p in model.parameters():
367 | p.requires_grad = False
368 | result, is_best = iterate("val", args, val_loader, model, None, logger, epoch) # evaluate on validation set
369 |
370 | for p in model.parameters():
371 | p.requires_grad = True
372 |
373 | helper.save_checkpoint({ # save checkpoint
374 | 'epoch': epoch,
375 | 'model': model.module.state_dict(),
376 | 'best_result': logger.best_result,
377 | 'optimizer' : optimizer.state_dict(),
378 | 'args' : args,
379 | }, is_best, epoch, logger.output_directory)
380 |
381 |
382 | if __name__ == '__main__':
383 | main()
384 |
--------------------------------------------------------------------------------
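
In iterate() above, the final prediction loss is blended with the two intermediate branch losses (cbd_pred, dbd_pred) using weights that decay over the first epochs. A standalone sketch of that weighting scheme, with toy scalar losses in place of the masked depth losses:

def branch_weights(actual_epoch, round1=1, round2=3):
    # same schedule as in iterate(): strong auxiliary supervision early on,
    # then reduced, then switched off
    if actual_epoch <= round1:
        return 0.2, 0.2
    elif actual_epoch <= round2:
        return 0.05, 0.05
    return 0.0, 0.0

depth_loss, cbd_loss, dbd_loss = 1.00, 1.20, 1.10  # toy values
for epoch in (0, 2, 5):
    w_cbd, w_dbd = branch_weights(epoch)
    loss = (1 - w_cbd - w_dbd) * depth_loss + w_cbd * cbd_loss + w_dbd * dbd_loss
    print(f"epoch {epoch}: w_cbd={w_cbd} w_dbd={w_dbd} loss={loss:.4f}")
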
/main_distributed.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 |
4 | import torch
5 | import torch.nn.parallel
6 | import torch.optim
7 | import torch.utils.data
8 | import time
9 |
10 | from tqdm import tqdm
11 |
12 | from dataloaders.kitti_loader import load_calib, input_options, KittiDepth
13 | from metrics import AverageMeter, Result
14 | import criteria
15 | import helper
16 | import vis_utils
17 |
18 | from model import GuideFormer
19 |
20 | # Multi-GPU and mixed-precision support
21 | # NOTE : Only 1 process per GPU is supported now
22 | import torch.multiprocessing as mp
23 | import torch.distributed as dist
24 | from torch.nn.parallel import DistributedDataParallel as DDP
25 | from torch_utils import select_device
26 | from torch.cuda import amp
27 | from torch.utils.data import DataLoader
28 | from torch.utils.data.distributed import DistributedSampler
29 |
30 |
31 | os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7"
32 | os.environ["MASTER_ADDR"] = 'localhost'
33 | os.environ["MASTER_PORT"] = '12345'
34 |
35 | LOCAL_RANK = int(os.getenv('LOCAL_RANK', -1))
36 | RANK = int(os.getenv('RANK', -1))
37 | WORLD_SIZE = int(os.getenv('WORLD_SIZE', 1))
38 |
39 |
40 | def train(args, device, checkpoint=None):
41 | cuda = torch.cuda.is_available() and not args.cpu
42 |
43 | if RANK == 0: print(args)
44 |
45 | # Prepare train dataset
46 | train_dataset = KittiDepth('train', args)
47 | train_sampler = DistributedSampler(train_dataset, num_replicas=args.num_gpus,
48 | rank=RANK)
49 | batch_size = args.batch_size // args.num_gpus
50 |
51 | train_loader = DataLoader(
52 | dataset=train_dataset, batch_size=batch_size, shuffle=False,
53 | num_workers=args.workers, pin_memory=True, sampler=train_sampler,
54 | drop_last=False)
55 |
56 | # Prepare val datatset
57 | val_dataset = KittiDepth('val', args)
58 | val_sampler = DistributedSampler(val_dataset, num_replicas=args.num_gpus,
59 | rank=RANK)
60 | val_loader = DataLoader(
61 | dataset=val_dataset, batch_size=1, shuffle=False,
62 | num_workers=args.workers, pin_memory=True, sampler=val_sampler, drop_last=False)
63 |
64 |
65 | # Network
66 | model = GuideFormer().to(device)
67 |
68 | if checkpoint is not None:
69 | model.load_state_dict(checkpoint['model'], strict=False)
70 | #optimizer.load_state_dict(checkpoint['optimizer'])
71 |
72 | if RANK == 0:
73 | print("=> checkpoint state loaded.")
74 |
75 | # Loss
76 | depth_criterion = criteria.MaskedMSELoss() if args.criterion == 'l2' \
77 | else criteria.MaskedL1Loss()
78 |
79 | # Optimizer and LR Scheduler
80 | model_named_params = [
81 | p for _, p in model.named_parameters() if p.requires_grad
82 | ]
83 | if args.optimizer == 'adam':
84 | optimizer = torch.optim.Adam(model_named_params, lr=args.lr, weight_decay=args.weight_decay, betas=(0.9, 0.999))
85 | else:
86 | optimizer = torch.optim.AdamW(model_named_params, lr=args.lr, weight_decay=args.weight_decay)
87 |
88 | # DDP
89 | if cuda and RANK != -1:
90 | # model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
91 | model = DDP(model, device_ids=[LOCAL_RANK], output_device=LOCAL_RANK)
92 |
93 | # logger
94 | logger = None
95 | if RANK == 0:
96 | logger = helper.logger(args)
97 | with open(os.path.join(helper.get_folder_name(args), 'hyperparams.txt'), 'w') as f:
98 | f.write(str(args))
99 | f.close()
100 | if checkpoint is not None:
101 | logger.best_result = checkpoint['best_result']
102 | del checkpoint
103 | print("=> logger created.")
104 |
105 | for epoch in range(args.start_epoch, args.epochs + 1):
106 |
107 | ### Train ###
108 | model.train()
109 | lr = helper.adjust_learning_rate(args.lr, optimizer, epoch, args)
110 |
111 | results_total = [torch.zeros(15, dtype=torch.float32).to(device)
112 | for _ in range(args.num_gpus)]
113 |
114 | average_part = AverageMeter()
115 |
116 | train_sampler.set_epoch(epoch)
117 |
118 | if RANK == 0:
119 | print(f'===> Epoch {epoch} / {args.epochs} | lr : {lr}')
120 |
121 | num_sample = len(train_loader) * train_loader.batch_size * args.num_gpus
122 |
123 | if RANK == 0:
124 | pbar = tqdm(total=num_sample)
125 | log_cnt = 0.0
126 | log_loss = 0.0
127 |
128 | for batch, sample in enumerate(train_loader):
129 | # if batch >= 100: break
130 |
131 | dstart = time.time()
132 |
133 | batch_data = {key: val.to(device) for key, val in sample.items()}
134 | gt = batch_data['gt'].to(device)
135 |
136 | data_time = time.time() - dstart
137 |
138 | cbd_loss, dbd_loss, loss = 0, 0, 0
139 | w_cbd, w_dbd = 0, 0
140 | round1, round2 = 1, 3
141 | if(epoch <= round1):
142 | w_cbd, w_dbd = 0.2, 0.2
143 | elif(epoch <= round2):
144 | w_cbd, w_dbd = 0.05, 0.05
145 | else:
146 | w_cbd, w_dbd = 0, 0
147 |
148 | start = time.time()
149 |
150 | optimizer.zero_grad()
151 |
152 | cbd_pred, dbd_pred, pred = model(batch_data)
153 | depth_loss = depth_criterion(pred, gt)
154 | cbd_loss = depth_criterion(cbd_pred, gt)
155 | dbd_loss = depth_criterion(dbd_pred, gt)
156 | loss = (1 - w_cbd - w_dbd) * depth_loss + w_cbd * cbd_loss + w_dbd * dbd_loss
157 |
158 | loss.backward()
159 | optimizer.step()
160 |
161 | gpu_time = time.time() - start
162 |
163 | with torch.no_grad():
164 | result = Result()
165 | result.evaluate(pred.data, gt.data)
166 | average_part.update(result, gpu_time, data_time, batch_size)
167 | if RANK == 0:
168 | log_cnt += 1
169 | log_loss += loss.item()
170 |
171 | error_str = 'Epoch {} | Loss = {:.4f}'.format(epoch, log_loss / log_cnt)
172 |
173 | pbar.set_description(error_str)
174 | pbar.update(train_loader.batch_size * args.num_gpus)
175 |
176 | dist.all_gather(results_total, average_part.average().get_result().to(device))
177 |
178 | if RANK == 0:
179 | pbar.close()
180 |
181 | average_meter = AverageMeter()
182 | result_part = Result()
183 | for result_tensor in results_total:
184 | result_part.update(*result_tensor.cpu().numpy())
185 | average_meter.update(result_part, result_part.gpu_time, result_part.data_time)
186 |
187 | avg = logger.conditional_save_info('train', average_meter, epoch)
188 | is_best = logger.rank_conditional_save_best('train', avg, epoch)
189 | logger.conditional_summarize('train', avg, is_best)
190 |
191 | ### Validation ###
192 | torch.set_grad_enabled(False)
193 | model.eval()
194 |
195 | results_total = [torch.zeros(15, dtype=torch.float32).to(device) for _ in range(args.num_gpus)]
196 |
197 | average_part = AverageMeter()
198 |
199 | num_sample = len(val_loader) * val_loader.batch_size * args.num_gpus
200 |
201 | if RANK == 0:
202 | pbar = tqdm(total=num_sample)
203 |
204 | for batch, sample in enumerate(val_loader):
205 | # if batch >= 10 : break
206 |
207 | dstart = time.time()
208 |
209 | batch_data = {key: val.to(device) for key, val in sample.items()}
210 | gt = batch_data['gt']
211 |
212 | data_time = time.time() - dstart
213 | start = time.time()
214 |
215 | cbd_pred, dbd_pred, pred = model(batch_data)
216 |
217 | gpu_time = time.time() - start
218 |
219 | with torch.no_grad():
220 | result = Result()
221 | result.evaluate(pred.data, gt.data)
222 | average_part.update(result, gpu_time, data_time, batch_size)
223 |
224 | if RANK == 0:
225 | logger.conditional_save_img_comparison('val', batch, batch_data, pred,
226 | epoch)
227 | pbar.update(val_loader.batch_size * args.num_gpus)
228 |
229 | # merge results from each gpu
230 | dist.all_gather(results_total, average_part.average().get_result().to(device))
231 |
232 | if RANK == 0:
233 | pbar.close()
234 |
235 | average_meter = AverageMeter()
236 | result_part = Result()
237 | for result_tensor in results_total:
238 | result_part.update(*result_tensor.cpu().numpy())
239 | average_meter.update(result_part, result_part.gpu_time, result_part.data_time)
240 |
241 | avg = logger.conditional_save_info('val', average_meter, epoch)
242 | is_best = logger.rank_conditional_save_best('val', avg, epoch)
243 | if is_best:
244 | logger.save_img_comparison_as_best('val', epoch)
245 | logger.conditional_summarize('val', avg, is_best)
246 |
247 | helper.save_checkpoint({ # save checkpoint
248 | 'epoch': epoch,
249 | 'model': model.module.state_dict(),
250 | 'best_result': logger.best_result,
251 | 'optimizer': optimizer.state_dict(),
252 | 'args': args,
253 | }, is_best, epoch, logger.output_directory)
254 |
255 | torch.set_grad_enabled(True)
256 |
257 |
258 | if __name__ == '__main__':
259 | parser = argparse.ArgumentParser(description='Sparse-to-Dense')
260 | parser.add_argument('--num_gpus',
261 | type=int,
262 | default=8,
263 | help='number of gpus')
264 | parser.add_argument('--workers',
265 | default=8,
266 | type=int,
267 | metavar='N',
268 |                     help='number of data loading workers (default: 8)')
269 | parser.add_argument('--epochs',
270 | default=100,
271 | type=int,
272 | metavar='N',
273 | help='number of total epochs to run (default: 100)')
274 | parser.add_argument('--start-epoch',
275 | default=0,
276 | type=int,
277 | metavar='N',
278 | help='manual epoch number (useful on restarts)')
279 | parser.add_argument('--start-epoch-bias',
280 | default=0,
281 | type=int,
282 | metavar='N',
283 |                     help='manual epoch number bias (useful on restarts)')
284 | parser.add_argument('-c',
285 | '--criterion',
286 | metavar='LOSS',
287 | default='l2',
288 | choices=criteria.loss_names,
289 | help='loss function: | '.join(criteria.loss_names) +
290 | ' (default: l2)')
291 | parser.add_argument('-b',
292 | '--batch-size',
293 | default=1,
294 | type=int,
295 | help='mini-batch size (default: 1)')
296 | parser.add_argument('--optimizer',
297 | default='adam',
298 | type=str,
299 | choices=['adam', 'adamw'],
300 | help='optimizer')
301 | parser.add_argument('--lr',
302 | '--learning-rate',
303 | default=1e-4,
304 | type=float,
305 | metavar='LR',
306 |                     help='initial learning rate (default 1e-4)')
307 | parser.add_argument('--weight-decay',
308 | '--wd',
309 | default=1e-6,
310 | type=float,
311 | metavar='W',
312 |                     help='weight decay (default: 1e-6)')
313 | parser.add_argument('--print-freq',
314 | '-p',
315 | default=10,
316 | type=int,
317 | metavar='N',
318 | help='print frequency (default: 10)')
319 | parser.add_argument('--resume',
320 | # default='./results/try10_distributed_no-amp_syncbn_lossx4_224x224_bs=6/model_best.pth.tar',
321 | default='',
322 | type=str,
323 | metavar='PATH',
324 | help='path to latest checkpoint (default: none)')
325 | parser.add_argument('--data-folder',
326 | default='/resources/KITTI/kitti_depth',
327 | type=str,
328 | metavar='PATH',
329 | help='data folder (default: none)')
330 | parser.add_argument('--data-folder-rgb',
331 | default='/resources/KITTI/kitti_rgb',
332 | type=str,
333 | metavar='PATH',
334 | help='data folder rgb (default: none)')
335 | parser.add_argument('--data-folder-save',
336 | default='/resources/KITTI/submit_test/',
337 | type=str,
338 | metavar='PATH',
339 |                     help='data folder for test results (default: none)')
340 | parser.add_argument('-i',
341 | '--input',
342 | type=str,
343 | default='rgbd',
344 | choices=input_options,
345 | help='input: | '.join(input_options))
346 | parser.add_argument('--val',
347 | type=str,
348 | default="select",
349 | choices=["select", "full"],
350 | help='full or select validation set')
351 | parser.add_argument('--jitter',
352 | type=float,
353 | default=0.1,
354 | help='color jitter for images')
355 | parser.add_argument('--rank-metric',
356 | type=str,
357 | default='rmse',
358 | choices=[m for m in dir(Result()) if not m.startswith('_')],
359 | help='metrics for which best result is saved')
360 |
361 | parser.add_argument('-e', '--evaluate', default='', type=str, metavar='PATH')
362 | parser.add_argument('--test', action="store_true", default=False,
363 |                     help='save results on the KITTI test set for submission')
364 | parser.add_argument('--cpu', action="store_true", default=False, help='run on cpu')
365 |
366 | # random cropping
367 | parser.add_argument('--not-random-crop', action="store_true", default=False,
368 | help='prohibit random cropping')
369 | parser.add_argument('-he', '--random-crop-height', default=256, type=int, metavar='N',
370 | help='random crop height')
371 | parser.add_argument('-w', '--random-crop-width', default=1216, type=int, metavar='N',
372 |                     help='random crop width')
373 |
374 |
375 | # distributed learning
376 | parser.add_argument('--device',
377 | default="0,1,2,3,4,5,6,7",
378 | help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
379 | parser.add_argument('--local_rank', type=int,
380 | default=-1,
381 | help='DDP parameter, do not modify')
382 |
383 | args = parser.parse_args()
384 | args.result = os.path.join('.', 'results')
385 | args.use_rgb = ('rgb' in args.input)
386 | args.use_d = 'd' in args.input
387 | args.use_g = 'g' in args.input
388 | args.val_h = 352 # 352
389 | args.val_w = 1216
390 |
391 | # DDP mode
392 | device = select_device(args.device, batch_size=args.batch_size)
393 | if LOCAL_RANK != -1:
394 | assert torch.cuda.device_count() > LOCAL_RANK
395 | assert args.batch_size % WORLD_SIZE == 0, '--batch-size must be multiple of CUDA device count'
396 | torch.cuda.set_device(LOCAL_RANK)
397 | device = torch.device('cuda', LOCAL_RANK)
398 | dist.init_process_group(backend='nccl', init_method='env://') # distributed backend
399 |
400 | checkpoint = None
401 | if args.resume: # optionally resume from a checkpoint
402 | args_new = args
403 | if os.path.isfile(args.resume):
404 | if RANK == 0:
405 | print("=> loading checkpoint '{}' ... ".format(args.resume),
406 | end='')
407 | checkpoint = torch.load(args.resume, map_location=device)
408 |
409 | args.start_epoch = checkpoint['epoch'] + 1
410 | args.data_folder = args_new.data_folder
411 | args.val = args_new.val
412 | if RANK == 0:
413 | print("Completed. Resuming from epoch {}.".format(
414 | checkpoint['epoch']))
415 | else:
416 | if RANK == 0:
417 | print("No checkpoint found at '{}'".format(args.resume))
418 |
419 | train(args, device, checkpoint)
420 |
421 | if WORLD_SIZE > 1 and RANK == 0:
422 | _ = [print('Destroying process group... ', end=''), dist.destroy_process_group(), print('Done.')]
423 |
--------------------------------------------------------------------------------
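
main_distributed.py assumes one process per GPU: RANK, LOCAL_RANK and WORLD_SIZE come from the launcher's environment, and the global --batch-size is split evenly across GPUs (batch_size // num_gpus in train()). A minimal sketch of that bookkeeping, runnable on its own (it falls back to a single process when the variables are unset):

import os

local_rank = int(os.getenv('LOCAL_RANK', -1))
rank = int(os.getenv('RANK', -1))
world_size = int(os.getenv('WORLD_SIZE', 1))

global_batch_size = 8  # toy value standing in for --batch-size
assert global_batch_size % world_size == 0, \
    '--batch-size must be a multiple of the number of processes'
per_gpu_batch_size = global_batch_size // world_size
print(f"rank={rank} local_rank={local_rank} "
      f"world_size={world_size} per-GPU batch size={per_gpu_batch_size}")
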
/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import math
3 | import numpy as np
4 |
5 | lg_e_10 = math.log(10)
6 |
7 |
8 | def log10(x):
9 |     """Return a new tensor with the base-10 logarithm of the elements of x."""
10 | return torch.log(x) / lg_e_10
11 |
12 |
13 | class Result(object):
14 | def __init__(self):
15 | self.irmse = 0
16 | self.imae = 0
17 | self.mse = 0
18 | self.rmse = 0
19 | self.mae = 0
20 | self.absrel = 0
21 | self.squared_rel = 0
22 | self.lg10 = 0
23 | self.delta1 = 0
24 | self.delta2 = 0
25 | self.delta3 = 0
26 | self.data_time = 0
27 | self.gpu_time = 0
28 | self.silog = 0 # Scale invariant logarithmic error [log(m)*100]
29 | self.photometric = 0
30 |
31 | def set_to_worst(self):
32 | self.irmse = np.inf
33 | self.imae = np.inf
34 | self.mse = np.inf
35 | self.rmse = np.inf
36 | self.mae = np.inf
37 | self.absrel = np.inf
38 | self.squared_rel = np.inf
39 | self.lg10 = np.inf
40 | self.silog = np.inf
41 | self.delta1 = 0
42 | self.delta2 = 0
43 | self.delta3 = 0
44 | self.data_time = 0
45 | self.gpu_time = 0
46 |
47 | def update(self, irmse, imae, mse, rmse, mae, absrel, squared_rel, lg10, \
48 | delta1, delta2, delta3, gpu_time, data_time, silog, photometric=0):
49 | self.irmse = irmse
50 | self.imae = imae
51 | self.mse = mse
52 | self.rmse = rmse
53 | self.mae = mae
54 | self.absrel = absrel
55 | self.squared_rel = squared_rel
56 | self.lg10 = lg10
57 | self.delta1 = delta1
58 | self.delta2 = delta2
59 | self.delta3 = delta3
60 | self.data_time = data_time
61 | self.gpu_time = gpu_time
62 | self.silog = silog
63 | self.photometric = photometric
64 |
65 | def evaluate(self, output, target, photometric=0):
66 | valid_mask = target > 0.1
67 |
68 | # convert from meters to mm
69 | output_mm = 1e3 * output[valid_mask]
70 | target_mm = 1e3 * target[valid_mask]
71 |
72 | abs_diff = (output_mm - target_mm).abs()
73 |
74 | self.mse = float((torch.pow(abs_diff, 2)).mean())
75 | self.rmse = math.sqrt(self.mse)
76 | self.mae = float(abs_diff.mean())
77 | self.lg10 = float((log10(output_mm) - log10(target_mm)).abs().mean())
78 | self.absrel = float((abs_diff / target_mm).mean())
79 | self.squared_rel = float(((abs_diff / target_mm)**2).mean())
80 |
81 | maxRatio = torch.max(output_mm / target_mm, target_mm / output_mm)
82 | self.delta1 = float((maxRatio < 1.25).float().mean())
83 | self.delta2 = float((maxRatio < 1.25**2).float().mean())
84 | self.delta3 = float((maxRatio < 1.25**3).float().mean())
85 | self.data_time = 0
86 | self.gpu_time = 0
87 |
88 | # silog uses meters
89 | err_log = torch.log(target[valid_mask]) - torch.log(output[valid_mask])
90 | normalized_squared_log = (err_log**2).mean()
91 | log_mean = err_log.mean()
92 | self.silog = math.sqrt(normalized_squared_log -
93 | log_mean * log_mean) * 100
94 |
95 | # convert from meters to km
96 | inv_output_km = (1e-3 * output[valid_mask])**(-1)
97 | inv_target_km = (1e-3 * target[valid_mask])**(-1)
98 | abs_inv_diff = (inv_output_km - inv_target_km).abs()
99 | self.irmse = math.sqrt((torch.pow(abs_inv_diff, 2)).mean())
100 | self.imae = float(abs_inv_diff.mean())
101 |
102 | self.photometric = float(photometric)
103 |
104 |
105 | class AverageMeter(object):
106 | def __init__(self):
107 | self.reset(time_stable=True)
108 |
109 | def reset(self, time_stable):
110 | self.count = 0.0
111 | self.sum_irmse = 0
112 | self.sum_imae = 0
113 | self.sum_mse = 0
114 | self.sum_rmse = 0
115 | self.sum_mae = 0
116 | self.sum_absrel = 0
117 | self.sum_squared_rel = 0
118 | self.sum_lg10 = 0
119 | self.sum_delta1 = 0
120 | self.sum_delta2 = 0
121 | self.sum_delta3 = 0
122 | self.sum_data_time = 0
123 | self.sum_gpu_time = 0
124 | self.sum_photometric = 0
125 | self.sum_silog = 0
126 | self.time_stable = time_stable
127 | self.time_stable_counter_init = 10
128 | self.time_stable_counter = self.time_stable_counter_init
129 |
130 | def update(self, result, gpu_time, data_time, n=1):
131 | self.count += n
132 | self.sum_irmse += n * result.irmse
133 | self.sum_imae += n * result.imae
134 | self.sum_mse += n * result.mse
135 | self.sum_rmse += n * result.rmse
136 | self.sum_mae += n * result.mae
137 | self.sum_absrel += n * result.absrel
138 | self.sum_squared_rel += n * result.squared_rel
139 | self.sum_lg10 += n * result.lg10
140 | self.sum_delta1 += n * result.delta1
141 | self.sum_delta2 += n * result.delta2
142 | self.sum_delta3 += n * result.delta3
143 | self.sum_data_time += n * data_time
144 | if self.time_stable == True and self.time_stable_counter > 0:
145 | self.time_stable_counter = self.time_stable_counter - 1
146 | else:
147 | self.sum_gpu_time += n * gpu_time
148 | self.sum_silog += n * result.silog
149 | self.sum_photometric += n * result.photometric
150 |
151 | def average(self):
152 | avg = Result()
153 | if self.time_stable == True:
154 | if self.count > 0 and self.count - self.time_stable_counter_init > 0:
155 | avg.update(
156 | self.sum_irmse / self.count, self.sum_imae / self.count,
157 | self.sum_mse / self.count, self.sum_rmse / self.count,
158 | self.sum_mae / self.count, self.sum_absrel / self.count,
159 | self.sum_squared_rel / self.count, self.sum_lg10 / self.count,
160 | self.sum_delta1 / self.count, self.sum_delta2 / self.count,
161 | self.sum_delta3 / self.count, self.sum_gpu_time / (self.count - self.time_stable_counter_init),
162 | self.sum_data_time / self.count, self.sum_silog / self.count,
163 | self.sum_photometric / self.count)
164 | elif self.count > 0:
165 | avg.update(
166 | self.sum_irmse / self.count, self.sum_imae / self.count,
167 | self.sum_mse / self.count, self.sum_rmse / self.count,
168 | self.sum_mae / self.count, self.sum_absrel / self.count,
169 | self.sum_squared_rel / self.count, self.sum_lg10 / self.count,
170 | self.sum_delta1 / self.count, self.sum_delta2 / self.count,
171 | self.sum_delta3 / self.count, 0,
172 | self.sum_data_time / self.count, self.sum_silog / self.count,
173 | self.sum_photometric / self.count)
174 | elif self.count > 0:
175 | avg.update(
176 | self.sum_irmse / self.count, self.sum_imae / self.count,
177 | self.sum_mse / self.count, self.sum_rmse / self.count,
178 | self.sum_mae / self.count, self.sum_absrel / self.count,
179 | self.sum_squared_rel / self.count, self.sum_lg10 / self.count,
180 | self.sum_delta1 / self.count, self.sum_delta2 / self.count,
181 | self.sum_delta3 / self.count, self.sum_gpu_time / self.count,
182 | self.sum_data_time / self.count, self.sum_silog / self.count,
183 | self.sum_photometric / self.count)
184 | return avg
185 |
--------------------------------------------------------------------------------
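
A small self-contained sketch of the core error computation in Result.evaluate above (valid-pixel masking at 0.1 m, metre-to-millimetre conversion, RMSE/MAE/delta1), using toy tensors:

import torch

output = torch.tensor([1.00, 2.10, 0.00, 4.20])  # predicted depth in metres
target = torch.tensor([1.10, 2.00, 0.00, 4.00])  # ground truth in metres (0 = unlabeled)

valid_mask = target > 0.1            # same validity threshold as Result.evaluate
output_mm = 1e3 * output[valid_mask]
target_mm = 1e3 * target[valid_mask]
abs_diff = (output_mm - target_mm).abs()

rmse = torch.sqrt((abs_diff ** 2).mean())
mae = abs_diff.mean()
max_ratio = torch.max(output_mm / target_mm, target_mm / output_mm)
delta1 = (max_ratio < 1.25).float().mean()
print(f"RMSE={rmse:.1f} mm  MAE={mae:.1f} mm  delta1={delta1:.2f}")
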
/model.py:
--------------------------------------------------------------------------------
1 | from basic import *
2 | from utils import *
3 |
4 | class GuideFormer(nn.Module):
5 | def __init__(self,
6 | embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2],
7 | num_heads=[4, 8, 16, 32, 16, 8, 4],
8 | win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
9 | drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
10 | norm_layer=nn.LayerNorm, token_mlp='dwc',
11 | downsample=PatchMerging, upsample=PatchExpand,
12 | use_checkpoint=False, **kwargs):
13 | super(GuideFormer, self).__init__()
14 |
15 | # GuideFormer parameters
16 | self.num_enc_layers = len(depths) // 2
17 | self.embed_dim = embed_dim
18 | self.mlp_ratio = mlp_ratio
19 | self.mlp = token_mlp
20 | self.win_size = win_size
21 |
22 | self.pos_drop = nn.Dropout(p=drop_rate)
23 |
24 | # stochastic depth
25 | enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))]
26 | conv_dpr = [drop_path_rate] * depths[3]
27 | dec_dpr = enc_dpr[::-1]
28 |
29 | # Color branch
30 | self.rgb_proj_in = InputProj(in_channels=3, out_channels=embed_dim, kernel_size=3, stride=1,
31 | act_layer=nn.GELU)
32 |
33 | self.rgb_encoder_res1 = BasicBlockGeo(inplanes=embed_dim, planes=embed_dim * 2, stride=2, geoplanes=0)
34 | self.rgb_encoder_res2 = BasicBlockGeo(inplanes=embed_dim * 2, planes=embed_dim * 4, stride=2, geoplanes=0)
35 |
36 | self.rgb_encoder_layer1 = GuideFormerLayer(dim=embed_dim * 4,
37 | out_dim=embed_dim * 4, depth=depths[0],
38 | num_heads=num_heads[0], win_size=win_size,
39 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
40 | drop=drop_rate, attn_drop=attn_drop_rate,
41 | drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])],
42 | norm_layer=norm_layer, token_mlp=token_mlp,
43 | use_checkpoint=use_checkpoint)
44 | self.rgb_downsample1 = downsample(embed_dim * 4)
45 | self.rgb_encoder_layer2 = GuideFormerLayer(dim=embed_dim * 8,
46 | out_dim=embed_dim * 8, depth=depths[1],
47 | num_heads=num_heads[1], win_size=win_size,
48 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
49 | drop=drop_rate, attn_drop=attn_drop_rate,
50 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
51 | norm_layer=norm_layer, token_mlp=token_mlp,
52 | use_checkpoint=use_checkpoint)
53 | self.rgb_downsample2 = downsample(embed_dim * 8)
54 | self.rgb_encoder_layer3 = GuideFormerLayer(dim=embed_dim * 16,
55 | out_dim=embed_dim * 16, depth=depths[2],
56 | num_heads=num_heads[2], win_size=win_size,
57 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
58 | drop=drop_rate, attn_drop=attn_drop_rate,
59 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
60 | norm_layer=norm_layer, token_mlp=token_mlp,
61 | use_checkpoint=use_checkpoint)
62 | self.rgb_downsample3 = downsample(embed_dim * 16)
63 |
64 | self.rgb_bottleneck = GuideFormerLayer(dim=embed_dim * 32,
65 | out_dim=embed_dim * 32, depth=depths[3],
66 | num_heads=num_heads[3], win_size=11,
67 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
68 | drop=drop_rate, attn_drop=attn_drop_rate,
69 | drop_path=conv_dpr,
70 | norm_layer=norm_layer, token_mlp=token_mlp,
71 | use_checkpoint=use_checkpoint)
72 |
73 | self.rgb_up3 = upsample(embed_dim * 32, embed_dim * 16)
74 | self.rgb_decoder_layer3 = GuideFormerLayer(dim=embed_dim * 16,
75 | out_dim=embed_dim * 16, depth=depths[-3],
76 | num_heads=num_heads[-3], win_size=win_size,
77 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
78 | drop=drop_rate, attn_drop=attn_drop_rate,
79 | drop_path=dec_dpr[:depths[-3]],
80 | norm_layer=norm_layer, token_mlp=token_mlp,
81 | use_checkpoint=use_checkpoint)
82 | self.rgb_up2 = upsample(embed_dim * 16, embed_dim * 8)
83 | self.rgb_decoder_layer2 = GuideFormerLayer(dim=embed_dim * 8,
84 | out_dim=embed_dim * 8, depth=depths[-2],
85 | num_heads=num_heads[-2], win_size=win_size,
86 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
87 | drop=drop_rate, attn_drop=attn_drop_rate,
88 | drop_path=dec_dpr[sum(depths[-3:-2]):sum(depths[-3:-1])],
89 | norm_layer=norm_layer, token_mlp=token_mlp,
90 | use_checkpoint=use_checkpoint)
91 | self.rgb_up1 = upsample(embed_dim * 8, embed_dim * 4)
92 | self.rgb_decoder_layer1 = GuideFormerLayer(dim=embed_dim * 4,
93 | out_dim=embed_dim * 4, depth=depths[-1],
94 | num_heads=num_heads[-1], win_size=win_size,
95 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
96 | drop=drop_rate, attn_drop=attn_drop_rate,
97 | drop_path=dec_dpr[sum(depths[-3:-1]):sum(depths[-3:])],
98 | norm_layer=norm_layer, token_mlp=token_mlp,
99 | use_checkpoint=use_checkpoint)
100 |
101 | self.rgb_decoder_deconv2 = deconvbnrelu(in_channels=embed_dim * 4, out_channels=embed_dim * 2, kernel_size=5, stride=2, padding=2, output_padding=1)
102 | self.rgb_decoder_deconv1 = deconvbnrelu(in_channels=embed_dim * 2, out_channels=embed_dim, kernel_size=5, stride=2, padding=2, output_padding=1)
103 | self.rgb_decoder_output = OutputProj(in_channels=embed_dim * 1, out_channels=2, kernel_size=3, stride=1,
104 | norm_layer=nn.BatchNorm2d, act_layer=nn.GELU)
105 |
106 | # Depth branch
107 | self.depth_proj_in = InputProj(in_channels=1, out_channels=embed_dim, kernel_size=3, stride=1,
108 | act_layer=nn.GELU)
109 |
110 | self.depth_encoder_res1 = BasicBlockGeo(inplanes=embed_dim, planes=embed_dim * 2, stride=2, geoplanes=0)
111 | self.depth_encoder_res2 = BasicBlockGeo(inplanes=embed_dim * 2, planes=embed_dim * 4, stride=2, geoplanes=0)
112 |
113 | self.depth_encoder_layer1 = GuideFormerLayer(dim=embed_dim * 4,
114 | out_dim=embed_dim * 4, depth=depths[0],
115 | num_heads=num_heads[0], win_size=win_size,
116 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
117 | drop=drop_rate, attn_drop=attn_drop_rate,
118 | drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])],
119 | norm_layer=norm_layer, token_mlp=token_mlp,
120 | use_checkpoint=use_checkpoint)
121 | self.depth_downsample1 = downsample(embed_dim * 4)
122 | self.depth_encoder_layer2 = GuideFormerLayer(dim=embed_dim * 8,
123 | out_dim=embed_dim * 8, depth=depths[1],
124 | num_heads=num_heads[1], win_size=win_size,
125 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
126 | drop=drop_rate, attn_drop=attn_drop_rate,
127 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
128 | norm_layer=norm_layer, token_mlp=token_mlp,
129 | use_checkpoint=use_checkpoint)
130 | self.depth_downsample2 = downsample(embed_dim * 8)
131 | self.depth_encoder_layer3 = GuideFormerLayer(dim=embed_dim * 16,
132 | out_dim=embed_dim * 16, depth=depths[2],
133 | num_heads=num_heads[2], win_size=win_size,
134 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
135 | drop=drop_rate, attn_drop=attn_drop_rate,
136 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
137 | norm_layer=norm_layer, token_mlp=token_mlp,
138 | use_checkpoint=use_checkpoint)
139 | self.depth_downsample3 = downsample(embed_dim * 16)
140 |
141 | self.depth_bottleneck = GuideFormerLayer(dim=embed_dim * 32,
142 | out_dim=embed_dim * 32, depth=depths[3],
143 | num_heads=num_heads[3], win_size=11,
144 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
145 | drop=drop_rate, attn_drop=attn_drop_rate,
146 | drop_path=conv_dpr,
147 | norm_layer=norm_layer, token_mlp=token_mlp,
148 | use_checkpoint=use_checkpoint)
149 |
150 | self.depth_up3 = upsample(embed_dim * 32, embed_dim * 16)
151 | self.depth_decoder_layer3 = GuideFormerLayer(dim=embed_dim * 16,
152 | out_dim=embed_dim * 16, depth=depths[-3],
153 | num_heads=num_heads[-3], win_size=win_size,
154 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
155 | drop=drop_rate, attn_drop=attn_drop_rate,
156 | drop_path=dec_dpr[:depths[-3]],
157 | norm_layer=norm_layer, token_mlp=token_mlp,
158 | use_checkpoint=use_checkpoint)
159 | self.depth_up2 = upsample(embed_dim * 16, embed_dim * 8)
160 | self.depth_decoder_layer2 = GuideFormerLayer(dim=embed_dim * 8,
161 | out_dim=embed_dim * 8, depth=depths[-2],
162 | num_heads=num_heads[-2], win_size=win_size,
163 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
164 | drop=drop_rate, attn_drop=attn_drop_rate,
165 | drop_path=dec_dpr[sum(depths[-3:-2]):sum(depths[-3:-1])],
166 | norm_layer=norm_layer, token_mlp=token_mlp,
167 | use_checkpoint=use_checkpoint)
168 | self.depth_up1 = upsample(embed_dim * 8, embed_dim * 4)
169 | self.depth_decoder_layer1 = GuideFormerLayer(dim=embed_dim * 4,
170 | out_dim=embed_dim * 4, depth=depths[-1],
171 | num_heads=num_heads[-1], win_size=win_size,
172 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
173 | drop=drop_rate, attn_drop=attn_drop_rate,
174 | drop_path=dec_dpr[sum(depths[-3:-1]):sum(depths[-3:])],
175 | norm_layer=norm_layer, token_mlp=token_mlp,
176 | use_checkpoint=use_checkpoint)
177 |
178 | self.depth_decoder_deconv2 = deconvbnrelu(in_channels=embed_dim * 4, out_channels=embed_dim * 2, kernel_size=5,
179 | stride=2, padding=2, output_padding=1)
180 | self.depth_decoder_deconv1 = deconvbnrelu(in_channels=embed_dim * 2, out_channels=embed_dim, kernel_size=5,
181 | stride=2, padding=2, output_padding=1)
182 | self.depth_decoder_output = OutputProj(in_channels=embed_dim, out_channels=2, kernel_size=3, stride=1,
183 | norm_layer=nn.BatchNorm2d, act_layer=nn.GELU)
184 |
185 |
186 | self.rgb2d_attn1 = FusionLayer(dim=embed_dim * 4,
187 | out_dim=embed_dim * 4, depth=depths[0],
188 | num_heads=num_heads[0], win_size=win_size,
189 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
190 | drop=drop_rate, attn_drop=attn_drop_rate,
191 | drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])],
192 | norm_layer=norm_layer, token_mlp=token_mlp,
193 | use_checkpoint=use_checkpoint)
194 | self.rgb2d_attn2 = FusionLayer(dim=embed_dim * 8,
195 | out_dim=embed_dim * 8, depth=depths[1],
196 | num_heads=num_heads[1], win_size=win_size,
197 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
198 | drop=drop_rate, attn_drop=attn_drop_rate,
199 | drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
200 | norm_layer=norm_layer, token_mlp=token_mlp,
201 | use_checkpoint=use_checkpoint)
202 | self.rgb2d_attn3 = FusionLayer(dim=embed_dim * 16,
203 | out_dim=embed_dim * 16, depth=depths[2],
204 | num_heads=num_heads[2], win_size=win_size,
205 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
206 | drop=drop_rate, attn_drop=attn_drop_rate,
207 | drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
208 | norm_layer=norm_layer, token_mlp=token_mlp,
209 | use_checkpoint=use_checkpoint)
210 | self.rgb2d_attn_bottleneck = FusionLayer(dim=embed_dim * 32,
211 | out_dim=embed_dim * 32, depth=depths[3],
212 | num_heads=num_heads[3], win_size=10,
213 | mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
214 | drop=drop_rate, attn_drop=attn_drop_rate,
215 | drop_path=conv_dpr,
216 | norm_layer=norm_layer, token_mlp=token_mlp,
217 | use_checkpoint=use_checkpoint)
218 |
219 | self.softmax = nn.Softmax(dim=1)
220 |
221 | weights_init(self)
222 |
223 | def forward(self, input):
224 | rgb = input['rgb']
225 | d = input['d']
226 |
227 | B, C, H, W = d.shape
228 | H1, W1 = H, W # 352(320) 1216
229 | H2, W2 = (H1 + 1) // 2, (W1 + 1) // 2 # 176(160) 608
230 | H3, W3 = (H2 + 1) // 2, (W2 + 1) // 2 # 88(80) 304
231 | H4, W4 = (H3 + 1) // 2, (W3 + 1) // 2 # 44(40) 152
232 | H5, W5 = (H4 + 1) // 2, (W4 + 1) // 2 # 22(20) 76
233 | H6, W6 = (H5 + 1) // 2, (W5 + 1) // 2 # 11(10) 38
234 |
235 | ### Color branch ###
236 | rgb_feature = self.rgb_proj_in(rgb)
237 | rgb_res1 = self.rgb_encoder_res1(rgb_feature)
238 | rgb_res2 = self.rgb_encoder_res2(rgb_res1)
239 | rgb_token0 = rgb_res2.flatten(2).transpose(1, 2).contiguous()
240 |
241 | rgb_token1 = self.rgb_encoder_layer1(rgb_token0, (H3, W3))
242 | rgb_pool1 = self.rgb_downsample1(rgb_token1, (H3, W3))
243 |
244 | rgb_token2 = self.rgb_encoder_layer2(rgb_pool1, (H4, W4))
245 | rgb_pool2 = self.rgb_downsample2(rgb_token2, (H4, W4))
246 |
247 | rgb_token3 = self.rgb_encoder_layer3(rgb_pool2, (H5, W5))
248 | rgb_pool3 = self.rgb_downsample3(rgb_token3, (H5, W5))
249 |
250 | rgb_token_bottle = self.rgb_bottleneck(rgb_pool3, (H6, W6))
251 |
252 | rgb_up3 = self.rgb_up3(rgb_token_bottle, (H6, W6), (H5, W5)) + rgb_token3
253 | rgb_feature_decoder3 = self.rgb_decoder_layer3(rgb_up3, (H5, W5))
254 |
255 | rgb_up2 = self.rgb_up2(rgb_feature_decoder3, (H5, W5), (H4, W4)) + rgb_token2
256 | rgb_feature_decoder2 = self.rgb_decoder_layer2(rgb_up2, (H4, W4))
257 |
258 | rgb_up1 = self.rgb_up1(rgb_feature_decoder2, (H4, W4), (H3, W3)) + rgb_token1
259 | rgb_feature_decoder1 = self.rgb_decoder_layer1(rgb_up1, (H3, W3))
260 |
261 | B, _, C = rgb_feature_decoder1.shape
262 | rgb_feature_decoder02 = rgb_feature_decoder1.transpose(1, 2).contiguous().view(B, C, H3, W3).contiguous() + rgb_res2
263 | rgb_feature_decoder02 = self.rgb_decoder_deconv2(rgb_feature_decoder02) + rgb_res1
264 | rgb_feature_decoder01 = self.rgb_decoder_deconv1(rgb_feature_decoder02)
265 |
266 | rgb_output = self.rgb_decoder_output(rgb_feature_decoder01)
267 | rgb_depth, rgb_conf = torch.chunk(rgb_output, 2, dim=1)
268 |
269 |
270 | ### Depth branch ###
271 | depth_feature = self.depth_proj_in(d)
272 | depth_res1 = self.depth_encoder_res1(depth_feature)
273 | depth_res2 = self.depth_encoder_res2(depth_res1)
274 | depth_token0 = depth_res2.flatten(2).transpose(1, 2).contiguous()
275 |
276 | depth_token1_cross = self.rgb2d_attn1(depth_token0, rgb_feature_decoder1, (H3, W3))
277 | depth_token1 = self.depth_encoder_layer1(depth_token1_cross, (H3, W3))
278 | depth_pool1 = self.depth_downsample1(depth_token1, (H3, W3))
279 |
280 | depth_token2_cross = self.rgb2d_attn2(depth_pool1, rgb_feature_decoder2, (H4, W4))
281 | depth_token2 = self.depth_encoder_layer2(depth_token2_cross, (H4, W4))
282 | depth_pool2 = self.depth_downsample2(depth_token2, (H4, W4))
283 |
284 | depth_token3_cross = self.rgb2d_attn3(depth_pool2, rgb_feature_decoder3, (H5, W5))
285 | depth_token3 = self.depth_encoder_layer3(depth_token3_cross, (H5, W5))
286 | depth_pool3 = self.depth_downsample3(depth_token3, (H5, W5))
287 |
288 | depth_token_bottle_cross = self.rgb2d_attn_bottleneck(depth_pool3, rgb_token_bottle, (H6, W6))
289 | depth_token_bottle = self.depth_bottleneck(depth_token_bottle_cross, (H6, W6))
290 |
291 | depth_up3 = self.depth_up3(depth_token_bottle, (H6, W6), (H5, W5)) + depth_token3
292 | depth_feature_decoder3 = self.depth_decoder_layer3(depth_up3, (H5, W5))
293 |
294 | depth_up2 = self.depth_up2(depth_feature_decoder3, (H5, W5), (H4, W4)) + depth_token2
295 | depth_feature_decoder2 = self.depth_decoder_layer2(depth_up2, (H4, W4))
296 |
297 | depth_up1 = self.depth_up1(depth_feature_decoder2, (H4, W4), (H3, W3)) + depth_token1
298 | depth_feature_decoder1 = self.depth_decoder_layer1(depth_up1, (H3, W3))
299 |
300 | B, _, C = depth_feature_decoder1.shape
301 | depth_feature_decoder02 = depth_feature_decoder1.transpose(1, 2).contiguous().view(B, C, H3, W3).contiguous() + depth_res2
302 | depth_feature_decoder02 = self.depth_decoder_deconv2(depth_feature_decoder02) + depth_res1
303 | depth_feature_decoder01 = self.depth_decoder_deconv1(depth_feature_decoder02)
304 |
305 | depth_output = self.depth_decoder_output(depth_feature_decoder01)
306 | d_depth, d_conf = torch.chunk(depth_output, 2, dim=1)
307 | # confidence-weighted fusion: softmax-normalize the two confidence maps, then blend the branch depths
308 | rgb_conf, d_conf = torch.chunk(self.softmax(torch.cat((rgb_conf, d_conf), dim=1)), 2, dim=1)
309 | output = rgb_conf * rgb_depth + d_conf * d_depth
310 |
311 | return rgb_depth, d_depth, output
312 |
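313 | # Hedged usage sketch (commented out, in the spirit of the example at the end of CoordConv.py).
314 | # The forward pass above expects a dict with a color image 'rgb' (B, 3, H, W) and a sparse
315 | # depth map 'd' (B, 1, H, W), and returns the RGB-branch depth, the depth-branch depth, and
316 | # their confidence-weighted fusion. The model class name below is a placeholder assumption,
317 | # not taken from this file; substitute the actual class defined above.
318 | # model = GuideFormerDepthModel()   # placeholder name, assumed default constructor arguments
319 | # batch = {'rgb': torch.randn(1, 3, 352, 1216),
320 | #          'd': torch.randn(1, 1, 352, 1216)}
321 | # rgb_depth, d_depth, fused = model(batch)
322 | # print(fused.shape)   # expected (1, 1, 352, 1216), matching the KITTI sizes in the comments above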
--------------------------------------------------------------------------------
/utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.utils.checkpoint as checkpoint
4 | from timm.models.layers import DropPath, to_2tuple, trunc_normal_
5 | import torch.nn.functional as F
6 | from einops import rearrange
7 | import numpy as np
8 |
9 | from basic import *
10 |
11 |
12 | def window_partition(x, win_size):
13 | B, H, W, C = x.shape
14 | x = x.view(B, H // win_size, win_size, W // win_size, win_size, C).contiguous()
15 | windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)
16 | return windows
17 |
18 |
19 | def window_reverse(windows, win_size, H, W):
20 | B = int(windows.shape[0] / (H * W / win_size / win_size))
21 | x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1).contiguous()
22 | x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
23 | return x
24 |
25 |
26 | class InputProj(nn.Module):
27 | def __init__(self, in_channels=3, out_channels=64, kernel_size=3, stride=1, norm_layer=None,
28 | act_layer=nn.GELU):
29 | super().__init__()
30 | self.proj = nn.Sequential(
31 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2),
32 | nn.BatchNorm2d(out_channels),
33 | act_layer()
34 | )
35 |
36 | def forward(self, x):
37 | B, C, H, W = x.shape
38 |
39 | x = self.proj(x)
40 |
41 | return x
42 |
43 |
44 | class OutputProj(nn.Module):
45 | def __init__(self, in_channels=64, out_channels=3, kernel_size=3, stride=1, norm_layer=None, act_layer=None):
46 | super().__init__()
47 | self.proj = nn.Sequential(
48 | nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2),
49 | nn.BatchNorm2d(out_channels),
50 | act_layer() if act_layer is not None else nn.Identity()
51 | )
52 |
53 | def forward(self, x):
54 | x = self.proj(x)
55 |
56 | return x
57 |
58 |
59 | class PatchMerging(nn.Module):
60 | """ Patch Merging Layer
61 |
62 | Args:
63 | dim (int): Number of input channels.
64 | norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
65 | """
66 |
67 | def __init__(self, dim, norm_layer=nn.LayerNorm):
68 | super().__init__()
69 | self.dim = dim
70 | self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
71 | self.norm = norm_layer(4 * dim)
72 |
73 | def forward(self, x, input_size):
74 | """ Forward function.
75 |
76 | Args:
77 | x: Input feature, tensor size (B, H*W, C).
78 | input_size: (H, W) spatial resolution of the input feature.
79 | """
80 | H, W = input_size
81 | B, L, C = x.shape
82 | assert L == H * W, "input feature has wrong size"
83 |
84 | x = x.view(B, H, W, C)
85 |
86 | # padding
87 | pad_input = (H % 2 == 1) or (W % 2 == 1)
88 | if pad_input:
89 | x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
90 |
91 | x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
92 | x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
93 | x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
94 | x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
95 | x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
96 | x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
97 |
98 | x = self.norm(x)
99 | x = self.reduction(x)
100 |
101 | return x
102 |
103 |
104 | class PatchShuffle(nn.Module):
105 | def __init__(self, dim, out_dim, dim_scale=2, norm_layer=nn.LayerNorm):
106 | super().__init__()
107 |
108 | self.dim_scale = dim_scale
109 | self.expand = nn.Linear(dim, out_dim * (dim_scale ** 2), bias=False)
110 | self.norm = norm_layer(out_dim)
111 |
112 | def forward(self, x, input_size, out_size):
113 | H, W = input_size
114 | Hout, Wout = out_size
115 | x = self.expand(x)
116 | B, L, C = x.shape
117 | assert L == H * W, "input feature has wrong size"
118 |
119 | x = x.view(B, H, W, C).contiguous()
120 | if H % self.dim_scale != 0 or W % self.dim_scale != 0:
121 | H_pad = self.dim_scale - H % self.dim_scale
122 | W_pad = self.dim_scale - W % self.dim_scale
123 | x = F.pad(x, (0, 0, 0, W_pad, 0, H_pad))
124 |
125 | x = rearrange(x, 'b h w (p1 p2 c) -> b (h p1) (w p2) c',
126 | p1=self.dim_scale, p2=self.dim_scale, c=C // (self.dim_scale ** 2))
127 | x = x[:, :Hout, :Wout, :]
128 | x = x.reshape(B, -1, C // (self.dim_scale ** 2)).contiguous()
129 | x = self.norm(x)
130 |
131 | return x
132 |
133 |
134 | class LinearProjection(nn.Module):
135 | def __init__(self, dim, heads=8, dim_head=64, dropout=0., bias=True, guide=False):
136 | super().__init__()
137 | inner_dim = dim_head * heads
138 | self.heads = heads
139 | self.proj_in = nn.Identity()
140 | self.guide = guide
141 | self.to_q = nn.Linear(dim, inner_dim, bias=bias)
142 | self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias)
143 |
144 | def forward(self, x, x_guide=None):
145 | B_, N, C = x.shape
146 | if self.guide:
147 | kv = self.to_kv(x_guide).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
148 | else:
149 | kv = self.to_kv(x).reshape(B_, N, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
150 | q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
151 | q = q[0]
152 | k, v = kv[0], kv[1]
153 | return q, k, v
154 |
155 |
156 | class FFN(nn.Module):
157 | def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
158 | super().__init__()
159 | out_features = out_features or in_features
160 | hidden_features = hidden_features or in_features
161 | self.fc1 = nn.Linear(in_features, hidden_features)
162 | self.act = act_layer()
163 | self.fc2 = nn.Linear(hidden_features, out_features)
164 | self.drop = nn.Dropout(drop)
165 |
166 | def forward(self, x, input_size=None):
167 | x = self.fc1(x)
168 | x = self.act(x)
169 | x = self.drop(x)
170 | x = self.fc2(x)
171 | x = self.drop(x)
172 | return x
173 |
174 |
175 | class DWCFF(nn.Module):
176 | def __init__(self, dim=32, out_dim=32, hidden_dim=128, act_layer=nn.GELU, drop=0.):
177 | super().__init__()
178 | self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim), act_layer())
179 | self.dwconv = nn.Sequential(nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim,
180 | kernel_size=3, stride=1, padding=1),
181 | act_layer())
182 | self.linear2 = nn.Sequential(nn.Linear(hidden_dim, out_dim))
183 |
184 | def forward(self, x, input_size):
185 | # bs x hw x c
186 | B, L, C = x.size()
187 | H, W = input_size
188 | assert H * W == L, "input H x W does not match sequence length L"
189 |
190 | x = self.linear1(x)
191 |
192 | # spatial restore
193 | x = rearrange(x, 'b (h w) c -> b c h w', h=H, w=W)  # bs, hidden_dim, H, W
194 | x = self.dwconv(x)
195 |
196 | # flatten
197 | x = rearrange(x, ' b c h w -> b (h w) c', h=H, w=W)
198 | x = self.linear2(x)
199 |
200 | return x
201 |
202 |
203 | class WindowAttention(nn.Module):
204 | def __init__(self, dim, win_size, num_heads,
205 | qkv_bias=True, qk_scale=None,
206 | attn_drop=0., proj_drop=0.,
207 | guide=False):
208 |
209 | super().__init__()
210 | self.dim = dim
211 | self.win_size = win_size # Wh, Ww
212 | self.num_heads = num_heads
213 | head_dim = dim // num_heads
214 | self.scale = qk_scale or head_dim ** -0.5
215 | self.guide = guide
216 |
217 | # define a parameter table of relative position bias
218 | self.relative_position_bias_table = nn.Parameter(
219 | torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
220 |
221 | # get pair-wise relative position index for each token inside the window
222 | coords_h = torch.arange(self.win_size[0]) # [0,...,Wh-1]
223 | coords_w = torch.arange(self.win_size[1]) # [0,...,Ww-1]
224 | coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
225 | coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
226 | relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
227 | relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
228 | relative_coords[:, :, 0] += self.win_size[0] - 1 # shift to start from 0
229 | relative_coords[:, :, 1] += self.win_size[1] - 1
230 | relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
231 | relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
232 | self.register_buffer("relative_position_index", relative_position_index)
233 |
234 | self.qkv = LinearProjection(dim, num_heads, self.dim // num_heads, bias=qkv_bias, guide=guide)
235 |
236 | self.attn_drop = nn.Dropout(attn_drop)
237 | self.proj = nn.Linear(dim, dim)
238 | self.proj_drop = nn.Dropout(proj_drop)
239 |
240 | trunc_normal_(self.relative_position_bias_table, std=.02)
241 | self.softmax = nn.Softmax(dim=-1)
242 |
243 | def forward(self, x, x_guide=None, mask=None):
244 | B_, N, C = x.shape
245 | q, k, v = self.qkv(x, x_guide)
246 | q = q * self.scale
247 | attn = (q @ k.transpose(-2, -1))
248 |
249 | relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
250 | self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1) # Wh*Ww,Wh*Ww,nH
251 | relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
252 |
253 | attn = attn + relative_position_bias.unsqueeze(0)
254 |
255 | if mask is not None:
256 | nW = mask.shape[0]
257 | attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
258 | attn = attn.view(-1, self.num_heads, N, N)
259 | attn = self.softmax(attn)
260 | else:
261 | attn = self.softmax(attn)
262 |
263 | attn = self.attn_drop(attn)
264 |
265 | x = (attn @ v).transpose(1, 2).reshape(B_, N, self.dim)
266 | x = self.proj(x)
267 | x = self.proj_drop(x)
268 |
269 | return x
270 |
271 | def extra_repr(self) -> str:
272 | return f'dim={self.dim}, win_size={self.win_size}, num_heads={self.num_heads}'
273 |
274 |
275 |
276 | class TransformerBlock(nn.Module):
277 | def __init__(self, dim, out_dim,
278 | num_heads, win_size=8, shift_size=0,
279 | mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
280 | act_layer=nn.GELU, norm_layer=nn.LayerNorm,
281 | token_mlp='dwc', guide=False):
282 | super().__init__()
283 | self.dim = out_dim
284 | self.num_heads = num_heads
285 | self.win_size = win_size
286 | self.shift_size = shift_size
287 | self.mlp_ratio = mlp_ratio
288 | self.guide = guide
289 | assert 0 <= self.shift_size < self.win_size, "shift_size must be in [0, win_size)"
290 |
291 | self.norm1 = norm_layer(out_dim)
292 | self.attn = WindowAttention(
293 | out_dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
294 | qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
295 | guide=guide)
296 |
297 | self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
298 | self.norm2 = norm_layer(out_dim)
299 | mlp_hidden_dim = int(out_dim * mlp_ratio)
300 | if token_mlp == 'dwc':
301 | self.mlp = DWCFF(out_dim, out_dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
302 |
303 | else:
304 | self.mlp = FFN(in_features=out_dim, out_features=out_dim, hidden_features=mlp_hidden_dim,
305 | act_layer=act_layer, drop=drop)
306 |
307 | self.proj_in = nn.Identity()
308 | if dim != out_dim:
309 | self.proj_in = nn.Linear(dim, out_dim)
310 |
311 | self.H = None
312 | self.W = None
313 |
314 | def extra_repr(self) -> str:
315 | return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
316 | f"win_size={self.win_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
317 |
318 | def forward(self, x, x_guide=None):
319 | x = self.proj_in(x)
320 | B, L, C = x.shape
321 | H, W = self.H, self.W
322 | assert H * W == L, "input H x W does not match sequence length L"
323 |
324 | shortcut = x
325 | x = self.norm1(x)
326 | x = x.view(B, H, W, C)
327 | if self.guide:
328 | C_guide = x_guide.size(-1)
329 | x_guide = self.norm1(x_guide).view(B, H, W, -1)
330 |
331 | # pad feature maps to multiples of window size
332 | pad_l = pad_t = 0
333 | pad_r = (self.win_size - W % self.win_size) % self.win_size
334 | pad_b = (self.win_size - H % self.win_size) % self.win_size
335 | x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
336 | if self.guide:
337 | x_guide = F.pad(x_guide, (0, 0, pad_l, pad_r, pad_t, pad_b))
338 | _, Hp, Wp, _ = x.shape
339 |
340 | if self.shift_size > 0:
341 | # calculate attention mask for SW-MSA
342 | img_mask = torch.zeros((1, Hp, Wp, 1)).type_as(x).detach() # 1 H W 1
343 | h_slices = (slice(0, -self.win_size),
344 | slice(-self.win_size, -self.shift_size),
345 | slice(-self.shift_size, None))
346 | w_slices = (slice(0, -self.win_size),
347 | slice(-self.win_size, -self.shift_size),
348 | slice(-self.shift_size, None))
349 | cnt = 0
350 | for h in h_slices:
351 | for w in w_slices:
352 | img_mask[:, h, w, :] = cnt
353 | cnt += 1
354 |
355 | mask_windows = window_partition(img_mask, self.win_size) # nW, win_size, win_size, 1
356 | mask_windows = mask_windows.view(-1, self.win_size * self.win_size)
357 | attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
358 | attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
359 | attn_mask = attn_mask.type_as(x)
360 | else:
361 | attn_mask = None
362 |
363 | # cyclic shift
364 | if self.shift_size > 0:
365 | shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
366 | if self.guide:
367 | x_guide = torch.roll(x_guide, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
368 | else:
369 | shifted_x = x
370 |
371 | # partition windows
372 | x_windows = window_partition(shifted_x, self.win_size) # nW*B, win_size, win_size, C
373 | x_windows = x_windows.view(-1, self.win_size * self.win_size, C) # nW*B, win_size*win_size, C
374 | if self.guide:
375 | x_guide = window_partition(x_guide, self.win_size) # nW*B, win_size, win_size, C
376 | x_guide = x_guide.view(-1, self.win_size * self.win_size, C_guide) # nW*B, win_size*win_size, C
377 |
378 | # W-MSA/SW-MSA
379 | attn_windows = self.attn(x_windows, x_guide=x_guide, mask=attn_mask) # nW*B, win_size*win_size, C
380 |
381 | # merge windows
382 | attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C)
383 | shifted_x = window_reverse(attn_windows, self.win_size, Hp, Wp) # B H' W' C
384 |
385 | # reverse cyclic shift
386 | if self.shift_size > 0:
387 | x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
388 | else:
389 | x = shifted_x
390 |
391 | if pad_r > 0 or pad_b > 0:
392 | x = x[:, :H, :W, :].contiguous()
393 |
394 | x = x.view(B, H * W, C)
395 |
396 | # residual connection after attention, followed by the FFN
397 | x = shortcut + self.drop_path(x)
398 | x = x + self.drop_path(self.mlp(self.norm2(x), (H, W)))
399 |
400 | del attn_mask
401 |
402 | return x
403 |
404 |
405 | class GuideFormerLayer(nn.Module):
406 | def __init__(self, dim, out_dim,
407 | depth, num_heads, win_size,
408 | mlp_ratio=4., qkv_bias=True, qk_scale=None,
409 | drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm,
410 | token_mlp='dwc', use_checkpoint=False):
411 |
412 | super().__init__()
413 | self.dim = dim
414 | self.depth = depth
415 | self.use_checkpoint = use_checkpoint
416 |
417 | # build blocks
418 | self.blocks = nn.ModuleList()
419 | for i in range(depth):
420 | if i == 0:
421 | self.blocks.append(TransformerBlock(dim=dim, out_dim=out_dim,
422 | num_heads=num_heads, win_size=win_size,
423 | shift_size=0 if (i % 2 == 0) else win_size // 2,
424 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
425 | drop=drop, attn_drop=attn_drop,
426 | drop_path=drop_path[i] if isinstance(drop_path,
427 | list) else drop_path,
428 | norm_layer=norm_layer, token_mlp=token_mlp))
429 | else:
430 | self.blocks.append(TransformerBlock(dim=out_dim, out_dim=out_dim,
431 | num_heads=num_heads, win_size=win_size,
432 | shift_size=0 if (i % 2 == 0) else win_size // 2,
433 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
434 | drop=drop, attn_drop=attn_drop,
435 | drop_path=drop_path[i] if isinstance(drop_path,
436 | list) else drop_path,
437 | norm_layer=norm_layer, token_mlp=token_mlp))
438 |
439 | def extra_repr(self) -> str:
440 | return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
441 |
442 | def forward(self, x, input_size):
443 | H, W = input_size
444 | B, L, C = x.shape
445 | assert H * W == L, "input H x W does not match sequence length L"
446 |
447 | for blk in self.blocks:
448 | blk.H, blk.W = H, W
449 | if self.use_checkpoint:
450 | x = checkpoint.checkpoint(blk, x)
451 | else:
452 | x = blk(x)
453 |
454 | return x
455 |
456 |
457 |
458 | class FusionLayer(nn.Module):
459 | def __init__(self, dim, out_dim,
460 | depth, num_heads, win_size,
461 | mlp_ratio=4., qkv_bias=True, qk_scale=None,
462 | drop=0., attn_drop=0., drop_path=0.1, norm_layer=nn.LayerNorm,
463 | token_mlp='dwc', use_checkpoint=False):
464 |
465 | super().__init__()
466 | self.dim = dim
467 | self.depth = depth
468 | self.use_checkpoint = use_checkpoint
469 |
470 | # build blocks
471 | self.blocks = nn.ModuleList()
472 | for i in range(depth):
473 | if i == 0:
474 | self.blocks.append(TransformerBlock(dim=dim, out_dim=out_dim,
475 | num_heads=num_heads, win_size=win_size,
476 | shift_size=0 if (i % 2 == 0) else win_size // 2,
477 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
478 | drop=drop, attn_drop=attn_drop,
479 | drop_path=drop_path[i] if isinstance(drop_path,
480 | list) else drop_path,
481 | norm_layer=norm_layer, token_mlp=token_mlp,
482 | guide=True))
483 | else:
484 | self.blocks.append(TransformerBlock(dim=out_dim, out_dim=out_dim,
485 | num_heads=num_heads, win_size=win_size,
486 | shift_size=0 if (i % 2 == 0) else win_size // 2,
487 | mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
488 | drop=drop, attn_drop=attn_drop,
489 | drop_path=drop_path[i] if isinstance(drop_path,
490 | list) else drop_path,
491 | norm_layer=norm_layer, token_mlp=token_mlp,
492 | guide=True))
493 |
494 | def forward(self, depth_feat, rgb_feat, input_size):
495 | H, W = input_size
496 | B, L, C = rgb_feat.shape
497 | assert H * W == L, "input H x W does not match sequence length L"
498 |
499 | x = depth_feat
500 | for blk in self.blocks:
501 | blk.H, blk.W = H, W
502 | if self.use_checkpoint:
503 | x = checkpoint.checkpoint(blk, x, rgb_feat)
504 | else:
505 | x = blk(x, x_guide=rgb_feat) # B L 2C
506 |
507 | return x
508 |
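509 | # Hedged usage sketches (commented out, mirroring the style of CoordConv.py): a round-trip
510 | # check for the window helpers above and a minimal GuideFormerLayer call. All shapes here are
511 | # illustrative assumptions chosen so that H and W are multiples of win_size (no padding needed).
512 | # x = torch.randn(2, 16, 16, 32)                        # B, H, W, C
513 | # windows = window_partition(x, 8)                      # (2 * 2 * 2, 8, 8, 32)
514 | # assert torch.equal(window_reverse(windows, 8, 16, 16), x)
515 | #
516 | # layer = GuideFormerLayer(dim=32, out_dim=32, depth=2, num_heads=4, win_size=8)
517 | # tokens = torch.randn(2, 16 * 16, 32)                  # B, H*W, C token layout used throughout
518 | # out = layer(tokens, (16, 16))                         # -> (2, 16 * 16, 32)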
--------------------------------------------------------------------------------
/vis_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | if not ("DISPLAY" in os.environ):
3 | import matplotlib as mpl
4 | mpl.use('Agg')
5 | import matplotlib.pyplot as plt
6 | from PIL import Image
7 | import numpy as np
8 | import cv2
9 |
10 | cmap = plt.cm.jet
11 | cmap2 = plt.cm.nipy_spectral
12 |
13 | def validcrop(img):
14 | ratio = 256/1216
15 | h = img.size()[2]
16 | w = img.size()[3]
17 | return img[:, :, h-int(ratio*w):, :]
18 |
19 | def depth_colorize(depth):
20 | depth = (depth - np.min(depth)) / (np.max(depth) - np.min(depth))
21 | depth = 255 * cmap(depth)[:, :, :3] # H, W, C
22 | return depth.astype('uint8')
23 |
24 | def feature_colorize(feature):
25 | feature = (feature - np.min(feature)) / ((np.max(feature) - np.min(feature)))
26 | feature = 255 * cmap2(feature)[:, :, :3]
27 | return feature.astype('uint8')
28 |
29 | def mask_vis(mask):
30 | mask = (mask - np.min(mask)) / (np.max(mask) - np.min(mask))
31 | mask = 255 * mask
32 | return mask.astype('uint8')
33 |
34 | def merge_into_row(ele, pred, predrgb=None, predg=None, extra=None, extra2=None, extrargb=None):
35 | def preprocess_depth(x):
36 | y = np.squeeze(x.data.cpu().numpy())
37 | return depth_colorize(y)
38 |
39 | # if grayscale, convert to RGB
40 | img_list = []
41 | if 'rgb' in ele:
42 | rgb = np.squeeze(ele['rgb'][0, ...].data.cpu().numpy())
43 | rgb = np.transpose(rgb, (1, 2, 0))
44 | img_list.append(rgb)
45 | elif 'g' in ele:
46 | g = np.squeeze(ele['g'][0, ...].data.cpu().numpy())
47 | g = np.array(Image.fromarray(g).convert('RGB'))
48 | img_list.append(g)
49 | if 'd' in ele:
50 | img_list.append(preprocess_depth(ele['d'][0, ...]))
51 | img_list.append(preprocess_depth(pred[0, ...]))
52 | if extrargb is not None:
53 | img_list.append(preprocess_depth(extrargb[0, ...]))
54 | if predrgb is not None:
55 | predrgb = np.squeeze(ele['rgb'][0, ...].data.cpu().numpy())
56 | predrgb = np.transpose(predrgb, (1, 2, 0))
57 | #predrgb = predrgb.astype('uint8')
58 | img_list.append(predrgb)
59 | if predg is not None:
60 | predg = np.squeeze(predg[0, ...].data.cpu().numpy())
61 | predg = mask_vis(predg)
62 | predg = np.array(Image.fromarray(predg).convert('RGB'))
63 | #predg = predg.astype('uint8')
64 | img_list.append(predg)
65 | if extra is not None:
66 | extra = np.squeeze(extra[0, ...].data.cpu().numpy())
67 | extra = mask_vis(extra)
68 | extra = np.array(Image.fromarray(extra).convert('RGB'))
69 | img_list.append(extra)
70 | if extra2 is not None:
71 | extra2 = np.squeeze(extra2[0, ...].data.cpu().numpy())
72 | extra2 = mask_vis(extra2)
73 | extra2 = np.array(Image.fromarray(extra2).convert('RGB'))
74 | img_list.append(extra2)
75 | if 'gt' in ele:
76 | img_list.append(preprocess_depth(ele['gt'][0, ...]))
77 |
78 | img_merge = np.hstack(img_list)
79 | return img_merge.astype('uint8')
80 |
81 |
82 | def add_row(img_merge, row):
83 | return np.vstack([img_merge, row])
84 |
85 |
86 | def save_image(img_merge, filename):
87 | image_to_write = cv2.cvtColor(img_merge, cv2.COLOR_RGB2BGR)
88 | cv2.imwrite(filename, image_to_write)
89 |
90 | def save_image_torch(rgb, filename):
91 | #torch2numpy
92 | rgb = validcrop(rgb)
93 | rgb = np.squeeze(rgb[0, ...].data.cpu().numpy())
94 | #print(rgb.size())
95 | rgb = np.transpose(rgb, (1, 2, 0))
96 | rgb = rgb.astype('uint8')
97 | image_to_write = cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR)
98 | cv2.imwrite(filename, image_to_write)
99 |
100 | def save_depth_as_uint16png(img, filename):
101 | #from tensor
102 | img = np.squeeze(img.data.cpu().numpy())
103 | img = (img * 256).astype('uint16')
104 | cv2.imwrite(filename, img)
105 |
106 | def save_depth_as_uint16png_upload(img, filename):
107 | #from tensor
108 | img = np.squeeze(img.data.cpu().numpy())
109 | img = (img * 256.0).astype('uint16')
110 | img_buffer = img.tobytes()
111 | imgsave = Image.new("I", img.T.shape)
112 | imgsave.frombytes(img_buffer, 'raw', "I;16")
113 | imgsave.save(filename)
114 |
115 | def save_depth_as_uint8colored(img, filename):
116 | #from tensor
117 | img = validcrop(img)
118 | img = np.squeeze(img.data.cpu().numpy())
119 | img = depth_colorize(img)
120 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
121 | cv2.imwrite(filename, img)
122 |
123 | def save_mask_as_uint8colored(img, filename, colored=True, normalized=True):
124 | img = validcrop(img)
125 | img = np.squeeze(img.data.cpu().numpy())
126 | if not normalized:
127 | img = (img - np.min(img)) / (np.max(img) - np.min(img))
128 | if colored:
129 | img = 255 * cmap(img)[:, :, :3]
130 | else:
131 | img = 255 * img
132 | img = img.astype('uint8')
133 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
134 | cv2.imwrite(filename, img)
135 |
136 | def save_feature_as_uint8colored(img, filename):
137 | img = validcrop(img)
138 | img = np.squeeze(img.data.cpu().numpy())
139 | img = feature_colorize(img)
140 | img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
141 | cv2.imwrite(filename, img)
142 |
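143 | # Hedged usage sketch (commented out): saving a predicted depth map with the helpers above.
144 | # 'pred' is assumed to be a (1, 1, H, W) torch tensor of depths in metres, e.g. the fused
145 | # output of the model's forward pass.
146 | # save_depth_as_uint16png(pred, 'pred_depth.png')        # KITTI-style 16-bit PNG (depth * 256)
147 | # save_depth_as_uint8colored(pred, 'pred_vis.png')       # validcrop + jet-colormapped visualization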
--------------------------------------------------------------------------------