├── LICENSE ├── README.md ├── core ├── R_MSFM.py ├── __pycache__ │ ├── R_MSFM.cpython-38.pyc │ ├── corr.cpython-38.pyc │ ├── datasets.cpython-38.pyc │ ├── depthRAFT.cpython-38.pyc │ ├── extractor.cpython-38.pyc │ ├── raft.cpython-38.pyc │ ├── triplet_attention.cpython-38.pyc │ └── update.cpython-38.pyc └── update.py ├── datasets ├── __init__.py ├── kitti_dataset.py └── mono_dataset.py ├── evaluate_depth.py ├── kitti_utils.py ├── layers.py ├── networks ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-38.pyc │ ├── depthRAFT.cpython-38.pyc │ ├── depth_decoder.cpython-38.pyc │ ├── pose_cnn.cpython-38.pyc │ ├── pose_decoder.cpython-38.pyc │ └── resnet_encoder.cpython-38.pyc ├── pose_cnn.py ├── pose_decoder.py └── resnet_encoder.py ├── options.py ├── test_simple.py ├── train.py ├── trainer.py └── utils.py /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2021 Gus 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # R-MSFM: Recurrent Multi-Scale Feature Modulation for Monocular Depth Estimating(ICCV-2021) 2 | This is the official implementation for testing depth estimation using the model proposed in 3 | >R-MSFM: Recurrent Multi-Scale Feature Modulation for Monocular Depth Estimating 4 | >Zhongkai Zhou, Xinnan Fan, Pengfei Shi, Yuanxue Xin 5 | 6 | R-MSFM can estimate a depth map from a single image. 7 | 8 | Paper is now available at [ICCV2021](https://openaccess.thecvf.com/content/ICCV2021/papers/Zhou_R-MSFM_Recurrent_Multi-Scale_Feature_Modulation_for_Monocular_Depth_Estimating_ICCV_2021_paper.pdf) 9 | 10 | ## Training code 11 | You can train R-MSFM3 with: 12 | ```shell 13 | python train.py --iters=3 14 | ``` 15 | or R-MSFM6 with 16 | ```shell 17 | python train.py --iters=6 18 | ``` 19 | or R-MSFM3-GC with 20 | ```shell 21 | python train.py --iters=3 --gc 22 | ``` 23 | or R-MSFM6-GC with 24 | ```shell 25 | python train.py --iters=6 --gc 26 | ``` 27 | 28 | ## Improved Version(T-PAMI2024) 29 | Paper is now available at [T-PAMI2024](https://ieeexplore.ieee.org/abstract/document/10574331). 30 | In this paper, we improve our R-MSFM from two aspects and achieve SOTA results. 31 | 1. 
We propose another lightweight convolutional network R-MSFMX that evolved from our R-MSFM to better address the problem of depth estimation. Our R-MSFMX takes the first three blocks from ResNet50 instead of ResNet18 in our R-MSFM and further improves the depth accuracy. 32 | 2. We promote the geometry consistent depth learning for both our R-MSFM and R-MSFMX, which prevents the depth artifacts at object borders and thus generates more consistent depth. We denote the models that perform geometry consistent depth estimation by the postfix (GC). 33 | 34 | We show the superiority of our R-MSFMX-GC as follows: 35 | ![4](https://user-images.githubusercontent.com/32475718/160613575-a924c751-7352-4429-87ff-c6f6bcc19c44.jpg) 36 | 37 | 38 | The rows (from up to bottom) are RGB images, and the results by [Monodepth2](https://github.com/nianticlabs/monodepth2), R-MSFM6, and the improved version R-MSFMX6-GC. 39 | 40 | 41 | ## R-MSFM Results 42 | ![image](https://user-images.githubusercontent.com/32475718/160614132-3e7d25cc-e3d2-4d63-a2de-4fcaf10ef04e.png) 43 | ## R-MSFMX Results 44 | ![image](https://user-images.githubusercontent.com/32475718/160617371-50e304c0-1266-4ccc-afb7-524231c43bcf.png) 45 | 46 | 47 | 48 | 49 | ## Precomputed Results 50 | We have updated all the results as follows: 51 | [results](https://drive.google.com/drive/folders/1xLglsHFVxxTlvj5UBEyK5MQ_D0dLIjbS?usp=sharing) 52 | 53 | ## Pretrained Models 54 | We have updated all the results as follows: 55 | [models](https://drive.google.com/drive/folders/1IhUsEEY-oKfgcsTX2uHuENMe7u-1Pzik?usp=sharing) 56 | 57 | ## KITTI Evaluation 58 | You can predict scaled disparity for a single image used R-MSFM3 with: 59 | ```shell 60 | python test_simple.py --image_path='path_to_image' --model_path='path_to_model' --update=3 61 | ``` 62 | or R-MSFMX3 with 63 | ```shell 64 | python test_simple.py --image_path='path_to_image' --model_path='path_to_model' --update=3 --x 65 | ``` 66 | or R-MSFM6 with: 67 | ```shell 68 | python test_simple.py --image_path='path_to_image' --model_path='path_to_model' --update=6 69 | ``` 70 | or R-MSFM6X with: 71 | ```shell 72 | python test_simple.py --image_path='path_to_image' --model_path='path_to_model' --update=6 --x 73 | ``` 74 | ## License & Acknowledgement 75 | The codes are based on [RAFT](https://github.com/princeton-vl/RAFT), [Monodepth2](https://github.com/nianticlabs/monodepth2). Please also follow their licenses. Thanks for their great works. 
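
## Running a Model from Python

The commands above cover the command-line entry points. If you would rather call a pretrained model directly from Python, the minimal sketch below mirrors the loading logic in `evaluate_depth.py`. The 640x192 input resolution, the ResNet18 backbone, and the 0.1-100 depth range are assumptions for this example; the exact values for a given checkpoint are defined in `options.py` and stored in the saved `encoder.pth`.

```python
import sys
sys.path.append('./core')

import torch
from PIL import Image
from torchvision import transforms

import networks
from R_MSFM import R_MSFM3
from layers import disp_to_depth

# Build the networks and load the pretrained weights (encoder.pth / depth.pth).
encoder = networks.ResnetEncoder(18, True)  # ResNet18 for R-MSFM; R-MSFMX uses ResNet50
decoder = R_MSFM3(False)                    # False = ResNet18 feature widths, True = R-MSFMX
encoder_dict = torch.load('path_to_model/encoder.pth', map_location='cpu')
encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in encoder.state_dict()})
decoder.load_state_dict(torch.load('path_to_model/depth.pth', map_location='cpu'))
encoder.eval()
decoder.eval()

# Resize the input to the training resolution (640x192 assumed here; the real values are
# stored as encoder_dict['width'] and encoder_dict['height']).
img = Image.open('path_to_image').convert('RGB').resize((640, 192))
x = transforms.ToTensor()(img).unsqueeze(0)

with torch.no_grad():
    outputs = decoder(encoder(x), iters=3)
    disp = outputs[("disp_up", 2)]                        # prediction from the final iteration
    scaled_disp, depth = disp_to_depth(disp, 0.1, 100.0)  # assumed min/max depth defaults
```

Note that monocular models predict depth only up to scale: `evaluate_depth.py` recovers metric scale by median-scaling against the ground truth, while stereo-trained models are instead multiplied by `STEREO_SCALE_FACTOR = 5.4`.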
76 | -------------------------------------------------------------------------------- /core/R_MSFM.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from update import BasicUpdateBlock 6 | 7 | try: 8 | autocast = torch.cuda.amp.autocast 9 | except: 10 | # dummy autocast for PyTorch < 1.6 11 | class autocast: 12 | def __init__(self, enabled): 13 | pass 14 | 15 | def __enter__(self): 16 | pass 17 | 18 | def __exit__(self, *args): 19 | pass 20 | 21 | 22 | class SepConvGRU(nn.Module): 23 | def __init__(self): 24 | super(SepConvGRU, self).__init__() 25 | hidden_dim = 128 26 | catt = 256 27 | 28 | self.convz1 = nn.Conv2d(catt, hidden_dim, (1, 3), padding=(0, 1)) 29 | self.convr1 = nn.Conv2d(catt, hidden_dim, (1, 3), padding=(0, 1)) 30 | self.convq1 = nn.Conv2d(catt, hidden_dim, (1, 3), padding=(0, 1)) 31 | 32 | self.convz2 = nn.Conv2d(catt, hidden_dim, (3, 1), padding=(1, 0)) 33 | self.convr2 = nn.Conv2d(catt, hidden_dim, (3, 1), padding=(1, 0)) 34 | self.convq2 = nn.Conv2d(catt, hidden_dim, (3, 1), padding=(1, 0)) 35 | 36 | def forward(self, h, x): 37 | # horizontal 38 | hx = torch.cat([h, x], dim=1) 39 | z = torch.sigmoid(self.convz1(hx)) 40 | r = torch.sigmoid(self.convr1(hx)) 41 | q = torch.tanh(self.convq1(torch.cat([r * h, x], dim=1))) 42 | h = (1 - z) * h + z * q 43 | 44 | # vertical 45 | hx = torch.cat([h, x], dim=1) 46 | z = torch.sigmoid(self.convz2(hx)) 47 | r = torch.sigmoid(self.convr2(hx)) 48 | q = torch.tanh(self.convq2(torch.cat([r * h, x], dim=1))) 49 | h = (1 - z) * h + z * q 50 | 51 | return h 52 | 53 | 54 | class R_MSFM3(nn.Module): 55 | def __init__(self, x): 56 | super(R_MSFM3, self).__init__() 57 | self.convX11 = torch.nn.Sequential( 58 | nn.ReflectionPad2d(1), 59 | torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=0, bias=True), 60 | torch.nn.LeakyReLU(inplace=True), 61 | nn.ReflectionPad2d(1), 62 | torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=0, bias=True), 63 | torch.nn.Tanh()) 64 | 65 | if x: 66 | self.convX21 = torch.nn.Sequential( 67 | nn.ReflectionPad2d(1), 68 | torch.nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=0, bias=True), 69 | torch.nn.Tanh()) 70 | self.convX31 = torch.nn.Sequential( 71 | nn.ReflectionPad2d(1), 72 | torch.nn.Conv2d(in_channels=512, out_channels=128, kernel_size=3, stride=1, padding=0, bias=True), 73 | torch.nn.Tanh()) 74 | else: 75 | self.convX21 = torch.nn.Sequential( 76 | nn.ReflectionPad2d(1), 77 | torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=0, bias=True), 78 | torch.nn.Tanh()) 79 | self.convX31 = torch.nn.Sequential( 80 | nn.ReflectionPad2d(1), 81 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=0, bias=True), 82 | torch.nn.Tanh()) 83 | 84 | self.sigmoid = nn.Sigmoid() 85 | self.update_block = BasicUpdateBlock() 86 | self.gruc = SepConvGRU() 87 | 88 | def upsample_depth(self, flow, mask): 89 | N, _, H, W = flow.shape 90 | mask = mask.view(N, 1, 9, 8, 8, H, W) 91 | mask = torch.softmax(mask, dim=2) 92 | 93 | up_flow = F.unfold(flow, [3, 3], padding=1) 94 | up_flow = up_flow.view(N, 1, 9, 1, 1, H, W) 95 | 96 | up_flow = torch.sum(mask * up_flow, dim=2) 97 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 98 | return up_flow.reshape(N, 1, 8 * H, 8 * W) 99 | 100 | def forward(self, features, iters=3): 101 | 102 | x1, x2, x3 = features 103 | 
disp_predictions = {} 104 | b, c, h, w = x3.shape 105 | dispFea = torch.zeros([b, 1, h, w], requires_grad=True).to(x1.device) 106 | net = torch.zeros([b, 256, h, w], requires_grad=True).to(x1.device) 107 | 108 | for itr in range(iters): 109 | if itr in [0]: 110 | corr = self.convX31(x3) 111 | elif itr in [1]: 112 | corrh = corr 113 | corr = self.convX21(x2) 114 | corr = self.gruc(corrh, corr) 115 | elif itr in [2]: 116 | corrh = corr 117 | corr = self.convX11(x1) 118 | corr = self.gruc(corrh, corr) 119 | 120 | net, up_mask, delta_disp = self.update_block(net, corr, dispFea) 121 | dispFea = dispFea + delta_disp 122 | 123 | disp = self.sigmoid(dispFea) 124 | # upsample predictions 125 | if self.training: 126 | disp_up = self.upsample_depth(disp, up_mask) 127 | disp_predictions[("disp_up", itr)] = disp_up 128 | else: 129 | if (iters-1)==itr: 130 | disp_up = self.upsample_depth(disp, up_mask) 131 | disp_predictions[("disp_up", itr)] = disp_up 132 | 133 | 134 | return disp_predictions 135 | 136 | 137 | class R_MSFM6(nn.Module): 138 | def __init__(self,x): 139 | super(R_MSFM6, self).__init__() 140 | 141 | self.convX11 = torch.nn.Sequential( 142 | nn.ReflectionPad2d(1), 143 | torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=0, bias=True), 144 | torch.nn.LeakyReLU(inplace=True), 145 | nn.ReflectionPad2d(1), 146 | torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=0, bias=True), 147 | torch.nn.Tanh()) 148 | 149 | self.convX12 = torch.nn.Sequential( 150 | nn.Conv2d(128, 128, (1, 3), padding=(0, 1)), 151 | torch.nn.Tanh(), 152 | nn.Conv2d(128, 128, (3, 1), padding=(1, 0)), 153 | torch.nn.Tanh()) 154 | 155 | 156 | if x: 157 | self.convX21 = torch.nn.Sequential( 158 | nn.ReflectionPad2d(1), 159 | torch.nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=2, padding=0, bias=True), 160 | torch.nn.Tanh()) 161 | self.convX31 = torch.nn.Sequential( 162 | nn.ReflectionPad2d(1), 163 | torch.nn.Conv2d(in_channels=512, out_channels=128, kernel_size=3, stride=1, padding=0, bias=True), 164 | torch.nn.Tanh()) 165 | else: 166 | self.convX21 = torch.nn.Sequential( 167 | nn.ReflectionPad2d(1), 168 | torch.nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=2, padding=0, bias=True), 169 | torch.nn.Tanh()) 170 | self.convX31 = torch.nn.Sequential( 171 | nn.ReflectionPad2d(1), 172 | torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=0, dilation=1, 173 | bias=True), 174 | torch.nn.Tanh()) 175 | 176 | 177 | 178 | self.convX22 = torch.nn.Sequential( 179 | nn.Conv2d(128, 128, (1, 3), padding=(0, 1)), 180 | torch.nn.Tanh(), 181 | nn.Conv2d(128, 128, (3, 1), padding=(1, 0)), 182 | torch.nn.Tanh()) 183 | 184 | self.convX32 = torch.nn.Sequential( 185 | nn.Conv2d(128, 128, (1, 3), padding=(0, 1)), 186 | torch.nn.Tanh(), 187 | nn.Conv2d(128, 128, (3, 1), padding=(1, 0)), 188 | torch.nn.Tanh()) 189 | 190 | self.sigmoid = nn.Sigmoid() 191 | self.gruc = SepConvGRU() 192 | self.update_block = BasicUpdateBlock() 193 | 194 | def upsample_depth(self, flow, mask): 195 | """ Upsample flow field [H/8, W/8, 2] -> [H, W, 2] using convex combination """ 196 | N, _, H, W = flow.shape 197 | mask = mask.view(N, 1, 9, 8, 8, H, W) 198 | mask = torch.softmax(mask, dim=2) 199 | 200 | up_flow = F.unfold(flow, [3, 3], padding=1) 201 | up_flow = up_flow.view(N, 1, 9, 1, 1, H, W) 202 | 203 | up_flow = torch.sum(mask * up_flow, dim=2) 204 | up_flow = up_flow.permute(0, 1, 4, 2, 5, 3) 205 | return up_flow.reshape(N, 1, 8 * H, 8 * 
W) 206 | 207 | def forward(self, features, iters=6): 208 | """ Estimate depth for a single image """ 209 | 210 | x1, x2, x3 = features 211 | 212 | disp_predictions = {} 213 | b, c, h, w = x3.shape 214 | dispFea = torch.zeros([b, 1, h, w], requires_grad=True).to(x1.device) 215 | net = torch.zeros([b, 256, h, w], requires_grad=True).to(x1.device) 216 | 217 | for itr in range(iters): 218 | if itr in [0]: 219 | corr = self.convX31(x3) 220 | elif itr in [1]: 221 | corrh = corr 222 | corr = self.convX32(corr) 223 | corr = self.gruc(corrh, corr) 224 | elif itr in [2]: 225 | corrh = corr 226 | corr = self.convX21(x2) 227 | corr = self.gruc(corrh, corr) 228 | elif itr in [3]: 229 | corrh = corr 230 | corr = self.convX22(corr) 231 | corr = self.gruc(corrh, corr) 232 | elif itr in [4]: 233 | corrh = corr 234 | corr = self.convX11(x1) 235 | corr = self.gruc(corrh, corr) 236 | elif itr in [5]: 237 | corrh = corr 238 | corr = self.convX12(corr) 239 | corr = self.gruc(corrh, corr) 240 | 241 | net, up_mask, delta_disp = self.update_block(net, corr, dispFea) 242 | dispFea = dispFea + delta_disp 243 | 244 | disp = self.sigmoid(dispFea) 245 | # upsample predictions 246 | 247 | if self.training: 248 | disp_up = self.upsample_depth(disp, up_mask) 249 | disp_predictions[("disp_up", itr)] = disp_up 250 | else: 251 | if (iters-1)==itr: 252 | disp_up = self.upsample_depth(disp, up_mask) 253 | disp_predictions[("disp_up", itr)] = disp_up 254 | 255 | 256 | return disp_predictions 257 | -------------------------------------------------------------------------------- /core/__pycache__/R_MSFM.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/core/__pycache__/R_MSFM.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/corr.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/core/__pycache__/corr.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/datasets.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/core/__pycache__/datasets.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/depthRAFT.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/core/__pycache__/depthRAFT.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/extractor.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/core/__pycache__/extractor.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/raft.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/core/__pycache__/raft.cpython-38.pyc 
-------------------------------------------------------------------------------- /core/__pycache__/triplet_attention.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/core/__pycache__/triplet_attention.cpython-38.pyc -------------------------------------------------------------------------------- /core/__pycache__/update.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/core/__pycache__/update.cpython-38.pyc -------------------------------------------------------------------------------- /core/update.py: -------------------------------------------------------------------------------- 1 | import sys 2 | 3 | sys.path.append('..') 4 | import torch.nn as nn 5 | import torch 6 | 7 | 8 | class ConvBlock(nn.Module): 9 | """Layer to perform a convolution followed by LeakyReLU 10 | """ 11 | 12 | def __init__(self, in_channels, out_channels): 13 | super(ConvBlock, self).__init__() 14 | 15 | self.conv = Conv3x3(in_channels, out_channels) 16 | self.nonlin = nn.LeakyReLU(inplace=True) 17 | 18 | def forward(self, x): 19 | out = self.conv(x) 20 | out = self.nonlin(out) 21 | return out 22 | 23 | 24 | class Conv3x3(nn.Module): 25 | """Layer to pad and convolve input 26 | """ 27 | 28 | def __init__(self, in_channels, out_channels, use_refl=True): 29 | super(Conv3x3, self).__init__() 30 | 31 | if use_refl: 32 | self.pad = nn.ReflectionPad2d(1) 33 | else: 34 | self.pad = nn.ZeroPad2d(1) 35 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) 36 | 37 | def forward(self, x): 38 | out = self.pad(x) 39 | out = self.conv(out) 40 | return out 41 | 42 | 43 | class dispHead(nn.Module): 44 | def __init__(self): 45 | super(dispHead, self).__init__() 46 | outD = 1 47 | 48 | self.covd1 = torch.nn.Sequential(nn.ReflectionPad2d(1), 49 | torch.nn.Conv2d(in_channels=192, out_channels=256, kernel_size=3, stride=1, 50 | padding=0, bias=True), 51 | torch.nn.LeakyReLU(inplace=True)) 52 | 53 | self.covd2 = torch.nn.Sequential(nn.ReflectionPad2d(1), 54 | torch.nn.Conv2d(in_channels=256, out_channels=outD, kernel_size=3, stride=1, 55 | padding=0, bias=True)) 56 | 57 | def forward(self, x): 58 | return self.covd2(self.covd1(x)) 59 | 60 | 61 | class BasicMotionEncoder(nn.Module): 62 | def __init__(self): 63 | super(BasicMotionEncoder, self).__init__() 64 | # inD = 1 65 | 66 | self.convc1 = ConvBlock(128, 160) 67 | self.convc2 = ConvBlock(160, 128) 68 | 69 | self.convf1 = torch.nn.Sequential( 70 | nn.ReflectionPad2d(3), 71 | torch.nn.Conv2d(in_channels=1, out_channels=64, kernel_size=7, padding=0, bias=True), 72 | torch.nn.LeakyReLU(inplace=True)) 73 | self.convf2 = torch.nn.Sequential( 74 | nn.ReflectionPad2d(1), 75 | torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, padding=0, bias=True), 76 | torch.nn.LeakyReLU(inplace=True)) 77 | 78 | self.conv = ConvBlock(128 + 32, 192 - 1) 79 | 80 | def forward(self, depth, corr): 81 | cor = self.convc1(corr) 82 | cor = self.convc2(cor) 83 | 84 | dep = self.convf1(depth) 85 | dep = self.convf2(dep) 86 | 87 | cor_depth = torch.cat([cor, dep], dim=1) 88 | out = self.conv(cor_depth) 89 | return torch.cat([out, depth], dim=1) 90 | 91 | 92 | class BasicUpdateBlock(nn.Module): 93 | def __init__(self): 94 | super(BasicUpdateBlock, self).__init__() 95 | self.encoder = BasicMotionEncoder() 96 | 97 | 
self.flow_head = dispHead() 98 | self.mask = nn.Sequential( 99 | nn.ReflectionPad2d(1), 100 | nn.Conv2d(192, 324, 3), 101 | nn.LeakyReLU(inplace=True), 102 | nn.Conv2d(324, 64 * 9, 1, padding=0)) 103 | 104 | def forward(self, net, corr, depth): 105 | net = self.encoder(depth, corr) 106 | delta_depth = self.flow_head(net) 107 | 108 | # scale mask to balence gradients 109 | mask = .25 * self.mask(net) 110 | 111 | return net, mask, delta_depth 112 | 113 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | from .kitti_dataset import KITTIRAWDataset, KITTIOdomDataset, KITTIDepthDataset 2 | -------------------------------------------------------------------------------- /datasets/kitti_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import os 10 | import skimage.transform 11 | import numpy as np 12 | import PIL.Image as pil 13 | 14 | from kitti_utils import generate_depth_map 15 | from .mono_dataset import MonoDataset 16 | 17 | 18 | class KITTIDataset(MonoDataset): 19 | """Superclass for different types of KITTI dataset loaders 20 | """ 21 | def __init__(self, *args, **kwargs): 22 | super(KITTIDataset, self).__init__(*args, **kwargs) 23 | 24 | # NOTE: Make sure your intrinsics matrix is *normalized* by the original image size 25 | self.K = np.array([[0.58, 0, 0.5, 0], 26 | [0, 1.92, 0.5, 0], 27 | [0, 0, 1, 0], 28 | [0, 0, 0, 1]], dtype=np.float32) 29 | 30 | self.full_res_shape = (1242, 375) 31 | self.side_map = {"2": 2, "3": 3, "l": 2, "r": 3} 32 | 33 | def check_depth(self): 34 | line = self.filenames[0].split() 35 | scene_name = line[0] 36 | frame_index = int(line[1]) 37 | 38 | velo_filename = os.path.join( 39 | self.data_path, 40 | scene_name, 41 | "velodyne_points/data/{:010d}.bin".format(int(frame_index))) 42 | 43 | return os.path.isfile(velo_filename) 44 | 45 | def get_color(self, folder, frame_index, side, do_flip): 46 | color = self.loader(self.get_image_path(folder, frame_index, side)) 47 | 48 | if do_flip: 49 | color = color.transpose(pil.FLIP_LEFT_RIGHT) 50 | 51 | return color 52 | 53 | 54 | class KITTIRAWDataset(KITTIDataset): 55 | """KITTI dataset which loads the original velodyne depth maps for ground truth 56 | """ 57 | def __init__(self, *args, **kwargs): 58 | super(KITTIRAWDataset, self).__init__(*args, **kwargs) 59 | 60 | def get_image_path(self, folder, frame_index, side): 61 | f_str = "{:010d}{}".format(frame_index, self.img_ext) 62 | image_path = os.path.join( 63 | self.data_path, folder, "image_0{}/data".format(self.side_map[side]), f_str) 64 | return image_path 65 | 66 | def get_depth(self, folder, frame_index, side, do_flip): 67 | calib_path = os.path.join(self.data_path, folder.split("/")[0]) 68 | 69 | velo_filename = os.path.join( 70 | self.data_path, 71 | folder, 72 | "velodyne_points/data/{:010d}.bin".format(int(frame_index))) 73 | 74 | depth_gt = generate_depth_map(calib_path, velo_filename, self.side_map[side]) 75 | depth_gt = skimage.transform.resize( 76 | depth_gt, self.full_res_shape[::-1], order=0, preserve_range=True, mode='constant') 77 | 78 
| if do_flip: 79 | depth_gt = np.fliplr(depth_gt) 80 | 81 | return depth_gt 82 | 83 | 84 | class KITTIOdomDataset(KITTIDataset): 85 | """KITTI dataset for odometry training and testing 86 | """ 87 | def __init__(self, *args, **kwargs): 88 | super(KITTIOdomDataset, self).__init__(*args, **kwargs) 89 | 90 | def get_image_path(self, folder, frame_index, side): 91 | f_str = "{:06d}{}".format(frame_index, self.img_ext) 92 | image_path = os.path.join( 93 | self.data_path, 94 | "sequences/{:02d}".format(int(folder)), 95 | "image_{}".format(self.side_map[side]), 96 | f_str) 97 | return image_path 98 | 99 | 100 | class KITTIDepthDataset(KITTIDataset): 101 | """KITTI dataset which uses the updated ground truth depth maps 102 | """ 103 | def __init__(self, *args, **kwargs): 104 | super(KITTIDepthDataset, self).__init__(*args, **kwargs) 105 | 106 | def get_image_path(self, folder, frame_index, side): 107 | f_str = "{:010d}{}".format(frame_index, self.img_ext) 108 | image_path = os.path.join( 109 | self.data_path, 110 | folder, 111 | "image_0{}/data".format(self.side_map[side]), 112 | f_str) 113 | return image_path 114 | 115 | def get_depth(self, folder, frame_index, side, do_flip): 116 | f_str = "{:010d}.png".format(frame_index) 117 | depth_path = os.path.join( 118 | self.data_path, 119 | folder, 120 | "proj_depth/groundtruth/image_0{}".format(self.side_map[side]), 121 | f_str) 122 | 123 | depth_gt = pil.open(depth_path) 124 | depth_gt = depth_gt.resize(self.full_res_shape, pil.NEAREST) 125 | depth_gt = np.array(depth_gt).astype(np.float32) / 256 126 | 127 | if do_flip: 128 | depth_gt = np.fliplr(depth_gt) 129 | 130 | return depth_gt 131 | -------------------------------------------------------------------------------- /datasets/mono_dataset.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import os 10 | import random 11 | import numpy as np 12 | import copy 13 | from PIL import Image # using pillow-simd for increased speed 14 | 15 | import torch 16 | import torch.utils.data as data 17 | from torchvision import transforms 18 | 19 | 20 | def pil_loader(path): 21 | # open path as file to avoid ResourceWarning 22 | # (https://github.com/python-pillow/Pillow/issues/835) 23 | with open(path, 'rb') as f: 24 | with Image.open(f) as img: 25 | return img.convert('RGB') 26 | 27 | 28 | class MonoDataset(data.Dataset): 29 | """Superclass for monocular dataloaders 30 | 31 | Args: 32 | data_path 33 | filenames 34 | height 35 | width 36 | frame_idxs 37 | num_scales 38 | is_train 39 | img_ext 40 | """ 41 | def __init__(self, 42 | data_path, 43 | filenames, 44 | height, 45 | width, 46 | frame_idxs, 47 | is_train=False, 48 | img_ext='.jpg'): 49 | super(MonoDataset, self).__init__() 50 | 51 | self.data_path = data_path 52 | self.filenames = filenames 53 | self.height = height 54 | self.width = width 55 | self.interp = Image.ANTIALIAS 56 | 57 | self.frame_idxs = frame_idxs 58 | 59 | self.is_train = is_train 60 | self.img_ext = img_ext 61 | 62 | self.loader = pil_loader 63 | self.to_tensor = transforms.ToTensor() 64 | 65 | # We need to specify augmentations differently in newer versions of torchvision. 
66 | # We first try the newer tuple version; if this fails we fall back to scalars 67 | try: 68 | self.brightness = (0.8, 1.2) 69 | self.contrast = (0.8, 1.2) 70 | self.saturation = (0.8, 1.2) 71 | self.hue = (-0.1, 0.1) 72 | transforms.ColorJitter( 73 | self.brightness, self.contrast, self.saturation, self.hue) 74 | except TypeError: 75 | self.brightness = 0.2 76 | self.contrast = 0.2 77 | self.saturation = 0.2 78 | self.hue = 0.1 79 | 80 | self.resize = {} 81 | i = 0 82 | s = 2 ** i 83 | self.resize[i] = transforms.Resize((self.height // s, self.width // s), 84 | interpolation=self.interp) 85 | 86 | self.load_depth = self.check_depth() 87 | 88 | def preprocess(self, inputs, color_aug): 89 | """Resize colour images to the required scales and augment if required 90 | 91 | We create the color_aug object in advance and apply the same augmentation to all 92 | images in this item. This ensures that all images input to the pose network receive the 93 | same augmentation. 94 | """ 95 | for k in list(inputs): 96 | frame = inputs[k] 97 | if "color" in k: 98 | n, im, i = k 99 | i = 0 100 | inputs[(n, im, i)] = self.resize[i](inputs[(n, im, i - 1)]) 101 | 102 | for k in list(inputs): 103 | f = inputs[k] 104 | if "color" in k: 105 | n, im, i = k 106 | inputs[(n, im, i)] = self.to_tensor(f) 107 | inputs[(n + "_aug", im, i)] = self.to_tensor(color_aug(f)) 108 | 109 | def __len__(self): 110 | return len(self.filenames) 111 | 112 | def __getitem__(self, index): 113 | """Returns a single training item from the dataset as a dictionary. 114 | 115 | Values correspond to torch tensors. 116 | Keys in the dictionary are either strings or tuples: 117 | 118 | ("color", , ) for raw colour images, 119 | ("color_aug", , ) for augmented colour images, 120 | ("K", scale) or ("inv_K", scale) for camera intrinsics, 121 | "stereo_T" for camera extrinsics, and 122 | "depth_gt" for ground truth depth maps. 123 | 124 | is either: 125 | an integer (e.g. 0, -1, or 1) representing the temporal step relative to 'index', 126 | or 127 | "s" for the opposite image in the stereo pair. 
128 | 129 | is an integer representing the scale of the image relative to the fullsize image: 130 | -1 images at native resolution as loaded from disk 131 | 0 images resized to (self.width, self.height ) 132 | 1 images resized to (self.width // 2, self.height // 2) 133 | 2 images resized to (self.width // 4, self.height // 4) 134 | 3 images resized to (self.width // 8, self.height // 8) 135 | """ 136 | inputs = {} 137 | 138 | do_color_aug = self.is_train and random.random() > 0.5 139 | do_flip = self.is_train and random.random() > 0.5 140 | 141 | line = self.filenames[index].split() 142 | folder = line[0] 143 | 144 | if len(line) == 3: 145 | frame_index = int(line[1]) 146 | else: 147 | frame_index = 0 148 | 149 | if len(line) == 3: 150 | side = line[2] 151 | else: 152 | side = None 153 | 154 | for i in self.frame_idxs: 155 | if i == "s": 156 | other_side = {"r": "l", "l": "r"}[side] 157 | inputs[("color", i, -1)] = self.get_color(folder, frame_index, other_side, do_flip) 158 | else: 159 | inputs[("color", i, -1)] = self.get_color(folder, frame_index + i, side, do_flip) 160 | 161 | # adjusting intrinsics to match each scale in the pyramid 162 | scale = 0 163 | K = self.K.copy() 164 | 165 | K[0, :] *= self.width // (2 ** scale) 166 | K[1, :] *= self.height // (2 ** scale) 167 | 168 | inv_K = np.linalg.pinv(K) 169 | 170 | inputs[("K", scale)] = torch.from_numpy(K) 171 | inputs[("inv_K", scale)] = torch.from_numpy(inv_K) 172 | 173 | if do_color_aug: 174 | color_aug = transforms.ColorJitter( 175 | self.brightness, self.contrast, self.saturation, self.hue) 176 | else: 177 | color_aug = (lambda x: x) 178 | 179 | self.preprocess(inputs, color_aug) 180 | 181 | for i in self.frame_idxs: 182 | del inputs[("color", i, -1)] 183 | del inputs[("color_aug", i, -1)] 184 | 185 | if self.load_depth: 186 | depth_gt = self.get_depth(folder, frame_index, side, do_flip) 187 | inputs["depth_gt"] = np.expand_dims(depth_gt, 0) 188 | inputs["depth_gt"] = torch.from_numpy(inputs["depth_gt"].astype(np.float32)) 189 | 190 | if "s" in self.frame_idxs: 191 | stereo_T = np.eye(4, dtype=np.float32) 192 | baseline_sign = -1 if do_flip else 1 193 | side_sign = -1 if side == "l" else 1 194 | stereo_T[0, 3] = side_sign * baseline_sign * 0.1 195 | 196 | inputs["stereo_T"] = torch.from_numpy(stereo_T) 197 | 198 | return inputs 199 | 200 | def get_color(self, folder, frame_index, side, do_flip): 201 | raise NotImplementedError 202 | 203 | def check_depth(self): 204 | raise NotImplementedError 205 | 206 | def get_depth(self, folder, frame_index, side, do_flip): 207 | raise NotImplementedError 208 | -------------------------------------------------------------------------------- /evaluate_depth.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | import sys 3 | import os 4 | import cv2 5 | import numpy as np 6 | sys.path.append('./core') 7 | import torch 8 | from torch.utils.data import DataLoader 9 | 10 | from layers import disp_to_depth 11 | from utils import readlines 12 | from options import RMSFM2Options 13 | import datasets 14 | import networks 15 | from R_MSFM import R_MSFM3,R_MSFM6 16 | cv2.setNumThreads(0) # This speeds up evaluation 5x on our unix systems (OpenCV 3.3.1) 17 | 18 | 19 | splits_dir = os.path.join(os.path.dirname(__file__), "splits") 20 | 21 | # Models which were trained with stereo supervision were trained with a nominal 22 | # baseline of 0.1 units. The KITTI rig has a baseline of 54cm. 
Therefore, 23 | # to convert our stereo predictions to real-world scale we multiply our depths by 5.4. 24 | STEREO_SCALE_FACTOR = 5.4 25 | 26 | 27 | def compute_errors(gt, pred): 28 | """Computation of error metrics between predicted and ground truth depths 29 | """ 30 | thresh = np.maximum((gt / pred), (pred / gt)) 31 | a1 = (thresh < 1.25 ).mean() 32 | a2 = (thresh < 1.25 ** 2).mean() 33 | a3 = (thresh < 1.25 ** 3).mean() 34 | 35 | rmse = (gt - pred) ** 2 36 | rmse = np.sqrt(rmse.mean()) 37 | 38 | rmse_log = (np.log(gt) - np.log(pred)) ** 2 39 | rmse_log = np.sqrt(rmse_log.mean()) 40 | 41 | abs_rel = np.mean(np.abs(gt - pred) / gt) 42 | 43 | sq_rel = np.mean(((gt - pred) ** 2) / gt) 44 | 45 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 46 | 47 | 48 | def batch_post_process_disparity(l_disp, r_disp): 49 | """Apply the disparity post-processing method as introduced in Monodepthv1 50 | """ 51 | _, h, w = l_disp.shape 52 | m_disp = 0.5 * (l_disp + r_disp) 53 | l, _ = np.meshgrid(np.linspace(0, 1, w), np.linspace(0, 1, h)) 54 | l_mask = (1.0 - np.clip(20 * (l - 0.05), 0, 1))[None, ...] 55 | r_mask = l_mask[:, :, ::-1] 56 | return r_mask * l_disp + l_mask * r_disp + (1.0 - l_mask - r_mask) * m_disp 57 | 58 | 59 | def evaluate(opt): 60 | """Evaluates a pretrained model using a specified test set 61 | """ 62 | MIN_DEPTH = 1e-3 63 | MAX_DEPTH = 80 64 | 65 | assert sum((opt.eval_mono, opt.eval_stereo)) == 1, \ 66 | "Please choose mono or stereo evaluation by setting either --eval_mono or --eval_stereo" 67 | 68 | if opt.ext_disp_to_eval is None: 69 | 70 | opt.load_weights_folder = os.path.expanduser(opt.load_weights_folder) 71 | 72 | assert os.path.isdir(opt.load_weights_folder), \ 73 | "Cannot find a folder at {}".format(opt.load_weights_folder) 74 | 75 | print("-> Loading weights from {}".format(opt.load_weights_folder)) 76 | 77 | filenames = readlines(os.path.join(splits_dir, opt.eval_split, "test_files.txt")) 78 | encoder_path = os.path.join(opt.load_weights_folder, "encoder.pth") 79 | decoder_path = os.path.join(opt.load_weights_folder, "depth.pth") 80 | 81 | encoder_dict = torch.load(encoder_path) 82 | 83 | dataset = datasets.KITTIRAWDataset(opt.data_path, filenames, 84 | encoder_dict['height'], encoder_dict['width'], 85 | [0], is_train=False) 86 | dataloader = DataLoader(dataset, 1, shuffle=False, num_workers=0, 87 | pin_memory=True, drop_last=False) 88 | 89 | encoder =networks.ResnetEncoder(opt.num_layers, True) 90 | 91 | if opt.iters = 6: 92 | depth_decoder = R_MSFM6() 93 | else: 94 | depth_decoder = R_MSFM3() 95 | 96 | model_dict = encoder.state_dict() 97 | encoder.load_state_dict({k: v for k, v in encoder_dict.items() if k in model_dict}) 98 | depth_decoder.load_state_dict(torch.load(decoder_path)) 99 | 100 | encoder.cuda() 101 | encoder.eval() 102 | depth_decoder.cuda() 103 | depth_decoder.eval() 104 | 105 | pred_disps = [] 106 | 107 | print("-> Computing predictions with size {}x{}".format( 108 | encoder_dict['width'], encoder_dict['height'])) 109 | 110 | with torch.no_grad(): 111 | for data in dataloader: 112 | input_color = data[("color", 0, 0)].cuda() 113 | 114 | if opt.post_process: 115 | # Post-processed results require each image to have two forward passes 116 | input_color = torch.cat((input_color, torch.flip(input_color, [3])), 0) 117 | 118 | output = depth_decoder(encoder(input_color), iters=opt.iters) 119 | 120 | pred_disp, _ = disp_to_depth(output[("disp_up", opt.iters-1)], opt.min_depth, opt.max_depth) 121 | pred_disp = pred_disp.cpu()[:, 0].numpy() 122 | 123 | if 
opt.post_process: 124 | N = pred_disp.shape[0] // 2 125 | pred_disp = batch_post_process_disparity(pred_disp[:N], pred_disp[N:, :, ::-1]) 126 | 127 | pred_disps.append(pred_disp) 128 | 129 | pred_disps = np.concatenate(pred_disps) 130 | 131 | else: 132 | # Load predictions from file 133 | print("-> Loading predictions from {}".format(opt.ext_disp_to_eval)) 134 | pred_disps = np.load(opt.ext_disp_to_eval) 135 | 136 | if opt.eval_eigen_to_benchmark: 137 | eigen_to_benchmark_ids = np.load( 138 | os.path.join(splits_dir, "benchmark", "eigen_to_benchmark_ids.npy")) 139 | 140 | pred_disps = pred_disps[eigen_to_benchmark_ids] 141 | 142 | if opt.save_pred_disps: 143 | output_path = os.path.join( 144 | opt.load_weights_folder, "disps_{}_split.npy".format(opt.eval_split)) 145 | print("-> Saving predicted disparities to ", output_path) 146 | np.save(output_path, pred_disps) 147 | 148 | if opt.no_eval: 149 | print("-> Evaluation disabled. Done.") 150 | quit() 151 | 152 | elif opt.eval_split == 'benchmark': 153 | save_dir = os.path.join(opt.load_weights_folder, "benchmark_predictions") 154 | print("-> Saving out benchmark predictions to {}".format(save_dir)) 155 | if not os.path.exists(save_dir): 156 | os.makedirs(save_dir) 157 | 158 | for idx in range(len(pred_disps)): 159 | disp_resized = cv2.resize(pred_disps[idx], (1216, 352)) 160 | depth = STEREO_SCALE_FACTOR / disp_resized 161 | depth = np.clip(depth, 0, 80) 162 | depth = np.uint16(depth * 256) 163 | save_path = os.path.join(save_dir, "{:010d}.png".format(idx)) 164 | cv2.imwrite(save_path, depth) 165 | 166 | print("-> No ground truth is available for the KITTI benchmark, so not evaluating. Done.") 167 | quit() 168 | 169 | gt_path = os.path.join(splits_dir, opt.eval_split, "gt_depths.npz") 170 | gt_depths = np.load(gt_path, fix_imports=True, allow_pickle=True, encoding='latin1')["data"] 171 | 172 | print("-> Evaluating") 173 | 174 | if opt.eval_stereo: 175 | print(" Stereo evaluation - " 176 | "disabling median scaling, scaling by {}".format(STEREO_SCALE_FACTOR)) 177 | opt.disable_median_scaling = True 178 | opt.pred_depth_scale_factor = STEREO_SCALE_FACTOR 179 | else: 180 | print(" Mono evaluation - using median scaling") 181 | 182 | errors = [] 183 | ratios = [] 184 | 185 | for i in range(pred_disps.shape[0]): 186 | 187 | gt_depth = gt_depths[i] 188 | gt_height, gt_width = gt_depth.shape[:2] 189 | 190 | pred_disp = pred_disps[i] 191 | pred_disp = cv2.resize(pred_disp, (gt_width, gt_height)) 192 | pred_depth = 1 / pred_disp 193 | 194 | if opt.eval_split == "eigen": 195 | mask = np.logical_and(gt_depth > MIN_DEPTH, gt_depth < MAX_DEPTH) 196 | 197 | crop = np.array([0.40810811 * gt_height, 0.99189189 * gt_height, 198 | 0.03594771 * gt_width, 0.96405229 * gt_width]).astype(np.int32) 199 | crop_mask = np.zeros(mask.shape) 200 | crop_mask[crop[0]:crop[1], crop[2]:crop[3]] = 1 201 | mask = np.logical_and(mask, crop_mask) 202 | 203 | else: 204 | mask = gt_depth > 0 205 | 206 | pred_depth = pred_depth[mask] 207 | gt_depth = gt_depth[mask] 208 | 209 | pred_depth *= opt.pred_depth_scale_factor 210 | if not opt.disable_median_scaling: 211 | ratio = np.median(gt_depth) / np.median(pred_depth) 212 | ratios.append(ratio) 213 | pred_depth *= ratio 214 | 215 | pred_depth[pred_depth < MIN_DEPTH] = MIN_DEPTH 216 | pred_depth[pred_depth > MAX_DEPTH] = MAX_DEPTH 217 | 218 | errors.append(compute_errors(gt_depth, pred_depth)) 219 | 220 | if not opt.disable_median_scaling: 221 | ratios = np.array(ratios) 222 | med = np.median(ratios) 223 | print(" Scaling ratios | 
med: {:0.3f} | std: {:0.3f}".format(med, np.std(ratios / med))) 224 | 225 | mean_errors = np.array(errors).mean(0) 226 | 227 | print("\n " + ("{:>8} | " * 7).format("abs_rel", "sq_rel", "rmse", "rmse_log", "a1", "a2", "a3")) 228 | print(("&{: 8.3f} " * 7).format(*mean_errors.tolist()) + "\\\\") 229 | print("\n-> Done!") 230 | 231 | 232 | if __name__ == "__main__": 233 | options = RMSFM2Options() 234 | evaluate(options.parse()) 235 | 236 | ''' 237 | python evaluate_depth.py --load_weights_folder /path/to/your/weights/ --eval_mono 238 | 239 | ''' 240 | -------------------------------------------------------------------------------- /kitti_utils.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import numpy as np 5 | from collections import Counter 6 | 7 | 8 | def load_velodyne_points(filename): 9 | """Load 3D point cloud from KITTI file format 10 | (adapted from https://github.com/hunse/kitti) 11 | """ 12 | points = np.fromfile(filename, dtype=np.float32).reshape(-1, 4) 13 | points[:, 3] = 1.0 # homogeneous 14 | return points 15 | 16 | 17 | def read_calib_file(path): 18 | """Read KITTI calibration file 19 | (from https://github.com/hunse/kitti) 20 | """ 21 | float_chars = set("0123456789.e+- ") 22 | data = {} 23 | with open(path, 'r') as f: 24 | for line in f.readlines(): 25 | key, value = line.split(':', 1) 26 | value = value.strip() 27 | data[key] = value 28 | if float_chars.issuperset(value): 29 | # try to cast to float array 30 | try: 31 | data[key] = np.array(list(map(float, value.split(' ')))) 32 | except ValueError: 33 | # casting error: data[key] already eq. value, so pass 34 | pass 35 | 36 | return data 37 | 38 | 39 | def sub2ind(matrixSize, rowSub, colSub): 40 | """Convert row, col matrix subscripts to linear indices 41 | """ 42 | m, n = matrixSize 43 | return rowSub * (n-1) + colSub - 1 44 | 45 | 46 | def generate_depth_map(calib_dir, velo_filename, cam=2, vel_depth=False): 47 | """Generate a depth map from velodyne data 48 | """ 49 | # load calibration files 50 | cam2cam = read_calib_file(os.path.join(calib_dir, 'calib_cam_to_cam.txt')) 51 | velo2cam = read_calib_file(os.path.join(calib_dir, 'calib_velo_to_cam.txt')) 52 | velo2cam = np.hstack((velo2cam['R'].reshape(3, 3), velo2cam['T'][..., np.newaxis])) 53 | velo2cam = np.vstack((velo2cam, np.array([0, 0, 0, 1.0]))) 54 | 55 | # get image shape 56 | im_shape = cam2cam["S_rect_02"][::-1].astype(np.int32) 57 | 58 | # compute projection matrix velodyne->image plane 59 | R_cam2rect = np.eye(4) 60 | R_cam2rect[:3, :3] = cam2cam['R_rect_00'].reshape(3, 3) 61 | P_rect = cam2cam['P_rect_0'+str(cam)].reshape(3, 4) 62 | P_velo2im = np.dot(np.dot(P_rect, R_cam2rect), velo2cam) 63 | 64 | # load velodyne points and remove all behind image plane (approximation) 65 | # each row of the velodyne data is forward, left, up, reflectance 66 | velo = load_velodyne_points(velo_filename) 67 | velo = velo[velo[:, 0] >= 0, :] 68 | 69 | # project the points to the camera 70 | velo_pts_im = np.dot(P_velo2im, velo.T).T 71 | velo_pts_im[:, :2] = velo_pts_im[:, :2] / velo_pts_im[:, 2][..., np.newaxis] 72 | 73 | if vel_depth: 74 | velo_pts_im[:, 2] = velo[:, 0] 75 | 76 | # check if in bounds 77 | # use minus 1 to get the exact same value as KITTI matlab code 78 | velo_pts_im[:, 0] = np.round(velo_pts_im[:, 0]) - 1 79 | velo_pts_im[:, 1] = np.round(velo_pts_im[:, 1]) - 1 80 | val_inds = (velo_pts_im[:, 0] >= 0) & (velo_pts_im[:, 1] >= 0) 
81 | val_inds = val_inds & (velo_pts_im[:, 0] < im_shape[1]) & (velo_pts_im[:, 1] < im_shape[0]) 82 | velo_pts_im = velo_pts_im[val_inds, :] 83 | 84 | # project to image 85 | depth = np.zeros((im_shape[:2])) 86 | depth[velo_pts_im[:, 1].astype(np.int), velo_pts_im[:, 0].astype(np.int)] = velo_pts_im[:, 2] 87 | 88 | # find the duplicate points and choose the closest depth 89 | inds = sub2ind(depth.shape, velo_pts_im[:, 1], velo_pts_im[:, 0]) 90 | dupe_inds = [item for item, count in Counter(inds).items() if count > 1] 91 | for dd in dupe_inds: 92 | pts = np.where(inds == dd)[0] 93 | x_loc = int(velo_pts_im[pts[0], 0]) 94 | y_loc = int(velo_pts_im[pts[0], 1]) 95 | depth[y_loc, x_loc] = velo_pts_im[pts, 2].min() 96 | depth[depth < 0] = 0 97 | 98 | return depth 99 | -------------------------------------------------------------------------------- /layers.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torch.nn.functional as F 14 | 15 | def disp_to_depth(disp, min_depth, max_depth): 16 | """Convert network's sigmoid output into depth prediction 17 | The formula for this conversion is given in the 'additional considerations' 18 | section of the paper. 19 | """ 20 | min_disp = 1 / max_depth 21 | max_disp = 1 / min_depth 22 | scaled_disp = min_disp + (max_disp - min_disp) * disp 23 | depth = 1 / scaled_disp 24 | return scaled_disp, depth 25 | 26 | 27 | 28 | def transformation_from_parameters(axisangle, translation, invert=False): 29 | """Convert the network's (axisangle, translation) output into a 4x4 matrix 30 | """ 31 | R = rot_from_axisangle(axisangle) 32 | t = translation.clone() 33 | 34 | if invert: 35 | R = R.transpose(1, 2) 36 | t *= -1 37 | 38 | T = get_translation_matrix(t) 39 | 40 | if invert: 41 | M = torch.matmul(R, T) 42 | else: 43 | M = torch.matmul(T, R) 44 | 45 | return M 46 | 47 | 48 | def get_translation_matrix(translation_vector): 49 | """Convert a translation vector into a 4x4 transformation matrix 50 | """ 51 | T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device) 52 | 53 | t = translation_vector.contiguous().view(-1, 3, 1) 54 | 55 | T[:, 0, 0] = 1 56 | T[:, 1, 1] = 1 57 | T[:, 2, 2] = 1 58 | T[:, 3, 3] = 1 59 | T[:, :3, 3, None] = t 60 | 61 | return T 62 | 63 | 64 | def rot_from_axisangle(vec): 65 | """Convert an axisangle rotation into a 4x4 transformation matrix 66 | (adapted from https://github.com/Wallacoloo/printipi) 67 | Input 'vec' has to be Bx1x3 68 | """ 69 | angle = torch.norm(vec, 2, 2, True) 70 | axis = vec / (angle + 1e-7) 71 | 72 | ca = torch.cos(angle) 73 | sa = torch.sin(angle) 74 | C = 1 - ca 75 | 76 | x = axis[..., 0].unsqueeze(1) 77 | y = axis[..., 1].unsqueeze(1) 78 | z = axis[..., 2].unsqueeze(1) 79 | 80 | xs = x * sa 81 | ys = y * sa 82 | zs = z * sa 83 | xC = x * C 84 | yC = y * C 85 | zC = z * C 86 | xyC = x * yC 87 | yzC = y * zC 88 | zxC = z * xC 89 | 90 | rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device) 91 | 92 | rot[:, 0, 0] = torch.squeeze(x * xC + ca) 93 | rot[:, 0, 1] = torch.squeeze(xyC - zs) 94 | rot[:, 0, 2] = torch.squeeze(zxC + ys) 95 | 
rot[:, 1, 0] = torch.squeeze(xyC + zs) 96 | rot[:, 1, 1] = torch.squeeze(y * yC + ca) 97 | rot[:, 1, 2] = torch.squeeze(yzC - xs) 98 | rot[:, 2, 0] = torch.squeeze(zxC - ys) 99 | rot[:, 2, 1] = torch.squeeze(yzC + xs) 100 | rot[:, 2, 2] = torch.squeeze(z * zC + ca) 101 | rot[:, 3, 3] = 1 102 | 103 | return rot 104 | 105 | 106 | class ConvBlock(nn.Module): 107 | """Layer to perform a convolution followed by ELU 108 | """ 109 | def __init__(self, in_channels, out_channels): 110 | super(ConvBlock, self).__init__() 111 | 112 | self.conv = Conv3x3(in_channels, out_channels) 113 | self.nonlin = nn.ELU(inplace=True) 114 | 115 | def forward(self, x): 116 | out = self.conv(x) 117 | out = self.nonlin(out) 118 | return out 119 | 120 | 121 | class Conv3x3(nn.Module): 122 | """Layer to pad and convolve input 123 | """ 124 | def __init__(self, in_channels, out_channels, use_refl=True): 125 | super(Conv3x3, self).__init__() 126 | 127 | if use_refl: 128 | self.pad = nn.ReflectionPad2d(1) 129 | else: 130 | self.pad = nn.ZeroPad2d(1) 131 | self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3) 132 | 133 | def forward(self, x): 134 | out = self.pad(x) 135 | out = self.conv(out) 136 | return out 137 | 138 | 139 | class BackprojectDepth(nn.Module): 140 | """Layer to transform a depth image into a point cloud 141 | """ 142 | def __init__(self, batch_size, height, width): 143 | super(BackprojectDepth, self).__init__() 144 | 145 | self.batch_size = batch_size 146 | self.height = height 147 | self.width = width 148 | 149 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy') 150 | self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32) 151 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords), 152 | requires_grad=False) 153 | 154 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width), 155 | requires_grad=False) 156 | 157 | self.pix_coords = torch.unsqueeze(torch.stack( 158 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0) 159 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1) 160 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1), 161 | requires_grad=False) 162 | 163 | def forward(self, depth, inv_K): 164 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords) 165 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points 166 | cam_points = torch.cat([cam_points, self.ones], 1) 167 | 168 | return cam_points 169 | 170 | 171 | class Project3D(nn.Module): 172 | """Layer which projects 3D points into a camera with intrinsics K and at position T 173 | """ 174 | def __init__(self, batch_size, height, width, eps=1e-7): 175 | super(Project3D, self).__init__() 176 | 177 | self.batch_size = batch_size 178 | self.height = height 179 | self.width = width 180 | self.eps = eps 181 | 182 | def forward(self, points, K, T): 183 | P = torch.matmul(K, T)[:, :3, :] 184 | 185 | cam_points = torch.matmul(P, points) 186 | Z = cam_points[:, 2].clamp(min=1e-3) 187 | pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps) 188 | pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width) 189 | pix_coords = pix_coords.permute(0, 2, 3, 1) 190 | pix_coords[..., 0] /= self.width - 1 191 | pix_coords[..., 1] /= self.height - 1 192 | pix_coords = (pix_coords - 0.5) * 2 193 | return pix_coords, Z.reshape(self.batch_size, 1, self.height, self.width) 194 | 195 | 196 | 197 | class BackprojectDepthLoss(nn.Module): 198 | """Layer to transform a depth image into a point cloud 
199 | """ 200 | def __init__(self, batch_size, height, width): 201 | super(BackprojectDepthLoss, self).__init__() 202 | 203 | self.batch_size = batch_size 204 | self.height = height 205 | self.width = width 206 | 207 | meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy') 208 | self.id_coords = np.stack(meshgrid, axis=0).astype(np.float32) 209 | self.id_coords = nn.Parameter(torch.from_numpy(self.id_coords), 210 | requires_grad=False) 211 | 212 | self.ones = nn.Parameter(torch.ones(self.batch_size, 1, self.height * self.width), 213 | requires_grad=False) 214 | 215 | self.pix_coords = torch.unsqueeze(torch.stack( 216 | [self.id_coords[0].view(-1), self.id_coords[1].view(-1)], 0), 0) 217 | self.pix_coords = self.pix_coords.repeat(batch_size, 1, 1) 218 | self.pix_coords = nn.Parameter(torch.cat([self.pix_coords, self.ones], 1), 219 | requires_grad=False) 220 | 221 | def forward(self, depth, inv_K, coords1=None): 222 | if coords1 == None: 223 | cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords) 224 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points 225 | cam_points = torch.cat([cam_points, self.ones], 1) 226 | return cam_points, self.id_coords 227 | else: 228 | coords1 = coords1.permute(0,3,1,2) 229 | coords1s = [] 230 | for i in range(self.batch_size): 231 | xAndy = torch.stack( 232 | [coords1[i,0].view(-1), coords1[i,1].view(-1)], 0) 233 | coords1s.append(xAndy) 234 | coords1 = torch.stack(coords1s, 0) 235 | cam_points = nn.Parameter(torch.cat([coords1, self.ones], 1), 236 | requires_grad=False) 237 | cam_points = torch.matmul(inv_K[:, :3, :3], cam_points) 238 | cam_points = depth.view(self.batch_size, 1, -1) * cam_points 239 | cam_points = torch.cat([cam_points, self.ones], 1) 240 | return cam_points 241 | 242 | 243 | class Project3DLoss(nn.Module): 244 | """Layer which projects 3D points into a camera with intrinsics K and at position T 245 | """ 246 | def __init__(self, batch_size, height, width, eps=1e-7): 247 | super(Project3DLoss, self).__init__() 248 | 249 | self.batch_size = batch_size 250 | self.height = height 251 | self.width = width 252 | self.eps = eps 253 | 254 | def forward(self, points, K, T, flow): 255 | P = torch.matmul(K, T)[:, :3, :] 256 | 257 | cam_points = torch.matmul(P, points) 258 | 259 | pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps) 260 | pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width) + flow 261 | pix_coords = pix_coords.permute(0, 2, 3, 1) 262 | pix_coordsOrg = pix_coords.clone() 263 | pix_coords[..., 0] /= self.width - 1 264 | pix_coords[..., 1] /= self.height - 1 265 | pix_coords = (pix_coords - 0.5) * 2 266 | return pix_coords, pix_coordsOrg 267 | 268 | def upsample(x): 269 | """Upsample input tensor by a factor of 2 270 | """ 271 | return F.interpolate(x, scale_factor=2, mode="nearest") 272 | 273 | 274 | def get_smooth_loss(disp, img): 275 | """Computes the smoothness loss for a disparity image 276 | The color image is used for edge-aware smoothness 277 | """ 278 | grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:]) 279 | grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :]) 280 | 281 | grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True) 282 | grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True) 283 | 284 | grad_disp_x *= torch.exp(-grad_img_x) 285 | grad_disp_y *= torch.exp(-grad_img_y) 286 | 287 | return grad_disp_x.mean() + grad_disp_y.mean() 288 | 289 | 
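# Usage sketch (an illustrative addition, not part of the original file): BackprojectDepth and
# Project3D defined above are composed to warp a source frame into the target view, which is how
# self-supervised photometric losses sample pixels. The helper name, batch size and resolution
# below are assumptions; all tensors and modules are assumed to live on the same device.
def warp_source_to_target(source_img, depth, K, inv_K, T, batch_size=1, height=192, width=640):
    backproject = BackprojectDepth(batch_size, height, width)
    project = Project3D(batch_size, height, width)
    cam_points = backproject(depth, inv_K)      # (B, 4, H*W) homogeneous 3D points from target depth
    pix_coords, _ = project(cam_points, K, T)   # (B, H, W, 2) sampling grid normalised to [-1, 1]
    return F.grid_sample(source_img, pix_coords, padding_mode="border")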
290 | class SSIM(nn.Module): 291 | """Layer to compute the SSIM loss between a pair of images 292 | """ 293 | def __init__(self): 294 | super(SSIM, self).__init__() 295 | self.mu_x_pool = nn.AvgPool2d(3, 1) 296 | self.mu_y_pool = nn.AvgPool2d(3, 1) 297 | self.sig_x_pool = nn.AvgPool2d(3, 1) 298 | self.sig_y_pool = nn.AvgPool2d(3, 1) 299 | self.sig_xy_pool = nn.AvgPool2d(3, 1) 300 | 301 | self.refl = nn.ReflectionPad2d(1) 302 | 303 | self.C1 = 0.01 ** 2 304 | self.C2 = 0.03 ** 2 305 | 306 | def forward(self, x, y): 307 | x = self.refl(x) 308 | y = self.refl(y) 309 | 310 | mu_x = self.mu_x_pool(x) 311 | mu_y = self.mu_y_pool(y) 312 | 313 | sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2 314 | sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2 315 | sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y 316 | 317 | SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2) 318 | SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2) 319 | 320 | return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1) 321 | 322 | 323 | def compute_depth_errors(gt, pred): 324 | """Computation of error metrics between predicted and ground truth depths 325 | """ 326 | thresh = torch.max((gt / pred), (pred / gt)) 327 | a1 = (thresh < 1.25 ).float().mean() 328 | a2 = (thresh < 1.25 ** 2).float().mean() 329 | a3 = (thresh < 1.25 ** 3).float().mean() 330 | 331 | rmse = (gt - pred) ** 2 332 | rmse = torch.sqrt(rmse.mean()) 333 | 334 | rmse_log = (torch.log(gt) - torch.log(pred)) ** 2 335 | rmse_log = torch.sqrt(rmse_log.mean()) 336 | 337 | abs_rel = torch.mean(torch.abs(gt - pred) / gt) 338 | 339 | sq_rel = torch.mean((gt - pred) ** 2 / gt) 340 | 341 | return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3 342 | -------------------------------------------------------------------------------- /networks/__init__.py: -------------------------------------------------------------------------------- 1 | from .resnet_encoder import ResnetEncoder,ResnetEncoder2 2 | from .pose_decoder import PoseDecoder 3 | from .pose_cnn import PoseCNN 4 | 5 | -------------------------------------------------------------------------------- /networks/__pycache__/__init__.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/networks/__pycache__/__init__.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/depthRAFT.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/networks/__pycache__/depthRAFT.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/depth_decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/networks/__pycache__/depth_decoder.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/pose_cnn.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/networks/__pycache__/pose_cnn.cpython-38.pyc -------------------------------------------------------------------------------- 
/networks/__pycache__/pose_decoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/networks/__pycache__/pose_decoder.cpython-38.pyc -------------------------------------------------------------------------------- /networks/__pycache__/resnet_encoder.cpython-38.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/jsczzzk/R-MSFM/5dd7c65e880ea9d9b317dcac79cd15ac11c33d62/networks/__pycache__/resnet_encoder.cpython-38.pyc -------------------------------------------------------------------------------- /networks/pose_cnn.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | 13 | class PoseCNN(nn.Module): 14 | def __init__(self, num_input_frames): 15 | super(PoseCNN, self).__init__() 16 | 17 | self.num_input_frames = num_input_frames 18 | 19 | self.convs = {} 20 | self.convs[0] = nn.Conv2d(3 * num_input_frames, 16, 7, 2, 3) 21 | self.convs[1] = nn.Conv2d(16, 32, 5, 2, 2) 22 | self.convs[2] = nn.Conv2d(32, 64, 3, 2, 1) 23 | self.convs[3] = nn.Conv2d(64, 128, 3, 2, 1) 24 | self.convs[4] = nn.Conv2d(128, 256, 3, 2, 1) 25 | self.convs[5] = nn.Conv2d(256, 256, 3, 2, 1) 26 | self.convs[6] = nn.Conv2d(256, 256, 3, 2, 1) 27 | 28 | self.pose_conv = nn.Conv2d(256, 6 * (num_input_frames - 1), 1) 29 | 30 | self.num_convs = len(self.convs) 31 | 32 | self.relu = nn.ReLU(True) 33 | 34 | self.net = nn.ModuleList(list(self.convs.values())) 35 | 36 | def forward(self, out): 37 | 38 | for i in range(self.num_convs): 39 | out = self.convs[i](out) 40 | out = self.relu(out) 41 | 42 | out = self.pose_conv(out) 43 | out = out.mean(3).mean(2) 44 | 45 | out = 0.01 * out.view(-1, self.num_input_frames - 1, 1, 6) 46 | 47 | axisangle = out[..., :3] 48 | translation = out[..., 3:] 49 | 50 | return axisangle, translation 51 | -------------------------------------------------------------------------------- /networks/pose_decoder.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 
6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import torch 10 | import torch.nn as nn 11 | from collections import OrderedDict 12 | 13 | 14 | class PoseDecoder(nn.Module): 15 | def __init__(self, num_ch_enc, num_input_features, num_frames_to_predict_for=None, stride=1): 16 | super(PoseDecoder, self).__init__() 17 | 18 | self.num_ch_enc = num_ch_enc 19 | self.num_input_features = num_input_features 20 | 21 | if num_frames_to_predict_for is None: 22 | num_frames_to_predict_for = num_input_features - 1 23 | self.num_frames_to_predict_for = num_frames_to_predict_for 24 | 25 | self.convs = OrderedDict() 26 | self.convs[("squeeze")] = nn.Conv2d(self.num_ch_enc[-1], 256, 1) 27 | self.convs[("pose", 0)] = nn.Conv2d(num_input_features * 256, 256, 3, stride, 1) 28 | self.convs[("pose", 1)] = nn.Conv2d(256, 256, 3, stride, 1) 29 | self.convs[("pose", 2)] = nn.Conv2d(256, 6 * num_frames_to_predict_for, 1) 30 | 31 | self.relu = nn.ReLU() 32 | 33 | self.net = nn.ModuleList(list(self.convs.values())) 34 | 35 | def forward(self, input_features): 36 | last_features = [f[-1] for f in input_features] 37 | 38 | cat_features = [self.relu(self.convs["squeeze"](f)) for f in last_features] 39 | cat_features = torch.cat(cat_features, 1) 40 | 41 | out = cat_features 42 | for i in range(3): 43 | out = self.convs[("pose", i)](out) 44 | if i != 2: 45 | out = self.relu(out) 46 | 47 | out = out.mean(3).mean(2) 48 | 49 | out = 0.01 * out.view(-1, self.num_frames_to_predict_for, 1, 6) 50 | 51 | axisangle = out[..., :3] 52 | translation = out[..., 3:] 53 | 54 | return axisangle, translation 55 | -------------------------------------------------------------------------------- /networks/resnet_encoder.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | import os 9 | import numpy as np 10 | 11 | import torch 12 | import torch.nn as nn 13 | import torchvision.models as models 14 | import torch.utils.model_zoo as model_zoo 15 | 16 | 17 | class ResNetMultiImageInput(models.ResNet): 18 | """Constructs a resnet model with varying number of input images. 
19 | Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py 20 | """ 21 | def __init__(self, block, layers, num_classes=1000, num_input_images=1): 22 | super(ResNetMultiImageInput, self).__init__(block, layers) 23 | self.inplanes = 64 24 | self.conv1 = nn.Conv2d(num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False) 25 | self.bn1 = nn.BatchNorm2d(64) 26 | self.relu = nn.ReLU(inplace=True) 27 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 28 | self.layer1 = self._make_layer(block, 64, layers[0]) 29 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 30 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 31 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 32 | 33 | for m in self.modules(): 34 | if isinstance(m, nn.Conv2d): 35 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 36 | elif isinstance(m, nn.BatchNorm2d): 37 | nn.init.constant_(m.weight, 1) 38 | nn.init.constant_(m.bias, 0) 39 | 40 | 41 | def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1): 42 | """Constructs a ResNet model. 43 | Args: 44 | num_layers (int): Number of resnet layers. Must be 18 or 50 45 | pretrained (bool): If True, returns a model pre-trained on ImageNet 46 | num_input_images (int): Number of frames stacked as input 47 | """ 48 | assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet" 49 | blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers] 50 | block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers] 51 | model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images) 52 | 53 | if pretrained: 54 | loaded = model_zoo.load_url(models.resnet.model_urls['resnet{}'.format(num_layers)]) 55 | loaded['conv1.weight'] = torch.cat( 56 | [loaded['conv1.weight']] * num_input_images, 1) / num_input_images 57 | model.load_state_dict(loaded) 58 | return model 59 | 60 | #encoder = networks.ResnetEncoder(18, False) 61 | #features = encoder(input_image) 62 | class ResnetEncoder(nn.Module): 63 | def __init__(self, num_layers, pretrained, num_input_images=1): 64 | super(ResnetEncoder, self).__init__() 65 | self.num_ch_enc = np.array([64, 64, 128, 256, 512]) 66 | resnets = {18: models.resnet18, 67 | 34: models.resnet34, 68 | 50: models.resnet50, 69 | 101: models.resnet101, 70 | 152: models.resnet152} 71 | if num_layers not in resnets: 72 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) 73 | 74 | if num_input_images > 1: 75 | self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images) 76 | else: 77 | self.encoder = resnets[num_layers](pretrained) 78 | if num_layers > 34: 79 | self.num_ch_enc[1:] *= 4 80 | 81 | def forward(self, input_image): 82 | self.features = [] 83 | x = (input_image - 0.45) / 0.225 84 | x = self.encoder.conv1(x) 85 | x = self.encoder.bn1(x) 86 | self.features.append(self.encoder.relu(x)) 87 | self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) 88 | self.features.append(self.encoder.layer2(self.features[-1])) 89 | # self.features.append(self.encoder.layer3(self.features[-1])) 90 | # self.features.append(self.encoder.layer4(self.features[-1])) 91 | 92 | return self.features 93 | class ResnetEncoder2(nn.Module): 94 | """Pytorch module for a resnet encoder 95 | """ 96 | def __init__(self, num_layers, pretrained, num_input_images=1): 97 | super(ResnetEncoder2, self).__init__() 98 | 99 | self.num_ch_enc 
= np.array([64, 64, 128, 256, 512]) 100 | 101 | resnets = {18: models.resnet18, 102 | 34: models.resnet34, 103 | 50: models.resnet50, 104 | 101: models.resnet101, 105 | 152: models.resnet152} 106 | 107 | if num_layers not in resnets: 108 | raise ValueError("{} is not a valid number of resnet layers".format(num_layers)) 109 | 110 | if num_input_images > 1: 111 | self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images) 112 | else: 113 | self.encoder = resnets[num_layers](pretrained) 114 | 115 | if num_layers > 34: 116 | self.num_ch_enc[1:] *= 4 117 | 118 | def forward(self, input_image): 119 | self.features = [] 120 | x = (input_image - 0.45) / 0.225 121 | x = self.encoder.conv1(x) 122 | x = self.encoder.bn1(x) 123 | self.features.append(self.encoder.relu(x)) 124 | self.features.append(self.encoder.layer1(self.encoder.maxpool(self.features[-1]))) 125 | self.features.append(self.encoder.layer2(self.features[-1])) 126 | self.features.append(self.encoder.layer3(self.features[-1])) 127 | self.features.append(self.encoder.layer4(self.features[-1])) 128 | 129 | return self.features 130 | -------------------------------------------------------------------------------- /options.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | 9 | import os 10 | import argparse 11 | 12 | file_dir = os.path.dirname(__file__) # the directory that options.py resides in 13 | 14 | 15 | class RMSFM2Options: 16 | def __init__(self): 17 | self.parser = argparse.ArgumentParser(description="R-MSFM2 options") 18 | 19 | # PATHS 20 | self.parser.add_argument("--data_path", 21 | type=str, 22 | help="path to the training data", 23 | default=os.path.join(file_dir, "kitti_data")) 24 | self.parser.add_argument("--log_dir", 25 | type=str, 26 | help="log directory", 27 | default=os.path.join(os.path.expanduser("./"), "tmp")) 28 | 29 | # TRAINING options 30 | self.parser.add_argument("--model_name", 31 | type=str, 32 | help="the name of the folder to save the model in", 33 | default="R-MSFM2") 34 | self.parser.add_argument("--split", 35 | type=str, 36 | help="which training split to use", 37 | choices=["eigen_zhou", "eigen_full", "odom", "benchmark"], 38 | default="benchmark") 39 | self.parser.add_argument("--num_layers", 40 | type=int, 41 | help="number of resnet layers", 42 | default=18, 43 | choices=[18, 34, 50, 101, 152]) 44 | self.parser.add_argument("--dataset", 45 | type=str, 46 | help="dataset to train on", 47 | default="kitti", 48 | choices=["kitti", "kitti_odom", "kitti_depth", "kitti_test"]) 49 | self.parser.add_argument("--png", 50 | help="if set, trains from raw KITTI png files (instead of jpgs)", 51 | action="store_true") 52 | self.parser.add_argument("--height", 53 | type=int, 54 | help="input image height", 55 | default=192) 56 | self.parser.add_argument("--width", 57 | type=int, 58 | help="input image width", 59 | default=640) 60 | self.parser.add_argument("--disparity_smoothness", 61 | type=float, 62 | help="disparity smoothness weight", 63 | default=1e-3) 64 | self.parser.add_argument("--min_depth", 65 | type=float, 66 | help="minimum depth", 67 | default=0.1) 68 | self.parser.add_argument("--max_depth", 69 | type=float, 70 
| help="maximum depth", 71 | default=100.0) 72 | self.parser.add_argument("--use_stereo", 73 | help="if set, uses stereo pair for training", 74 | action="store_true") 75 | self.parser.add_argument("--frame_ids", 76 | nargs="+", 77 | type=int, 78 | help="frames to load", 79 | default=[0, -1, 1]) 80 | self.parser.add_argument("--scales", 81 | nargs="+", 82 | type=int, 83 | help="scales used in the loss", 84 | default=[0, 1, 2, 3]) 85 | # OPTIMIZATION options 86 | self.parser.add_argument("--batch_size", 87 | type=int, 88 | help="batch size", 89 | default=12) 90 | self.parser.add_argument("--learning_rate", 91 | type=float, 92 | help="learning rate", 93 | default=1e-4) 94 | self.parser.add_argument("--num_epochs", 95 | type=int, 96 | help="number of epochs", 97 | default=20) 98 | 99 | 100 | # SYSTEM options 101 | self.parser.add_argument("--no_cuda", 102 | help="if set disables CUDA", 103 | action="store_true") 104 | self.parser.add_argument("--num_workers", 105 | type=int, 106 | help="number of dataloader workers", 107 | default=12) 108 | 109 | # LOADING options 110 | self.parser.add_argument("--load_weights_folder", 111 | type=str, 112 | help="name of model to load") 113 | self.parser.add_argument("--models_to_load", 114 | nargs="+", 115 | type=str, 116 | help="models to load", 117 | default=["pose_encoder", "depth", "pose","encoder"]) 118 | 119 | # LOGGING options 120 | self.parser.add_argument("--log_frequency", 121 | type=int, 122 | help="number of batches between each tensorboard log", 123 | default=250) 124 | self.parser.add_argument("--save_frequency", 125 | type=int, 126 | help="number of epochs between each save", 127 | default=1) 128 | 129 | # EVALUATION options 130 | self.parser.add_argument("--eval_stereo", 131 | help="if set evaluates in stereo mode", 132 | action="store_true") 133 | self.parser.add_argument("--eval_mono", 134 | help="if set evaluates in mono mode", 135 | action="store_true") 136 | self.parser.add_argument("--disable_median_scaling", 137 | help="if set disables median scaling in evaluation", 138 | action="store_true") 139 | self.parser.add_argument("--pred_depth_scale_factor", 140 | help="if set multiplies predictions by this number", 141 | type=float, 142 | default=1) 143 | self.parser.add_argument("--ext_disp_to_eval", 144 | type=str, 145 | help="optional path to a .npy disparities file to evaluate") 146 | self.parser.add_argument("--eval_split", 147 | type=str, 148 | default="eigen", 149 | choices=[ 150 | "eigen", "eigen_benchmark", "benchmark", "odom_9", "odom_10"], 151 | help="which split to run eval on") 152 | self.parser.add_argument("--save_pred_disps", 153 | help="if set saves predicted disparities", 154 | action="store_true") 155 | self.parser.add_argument("--no_eval", 156 | help="if set disables evaluation", 157 | action="store_true") 158 | self.parser.add_argument("--eval_eigen_to_benchmark", 159 | help="if set assume we are loading eigen results from npy but " 160 | "we want to evaluate using the new benchmark.", 161 | action="store_true") 162 | self.parser.add_argument("--eval_out_dir", 163 | help="if set will output the disparities to this folder", 164 | type=str) 165 | self.parser.add_argument("--post_process", 166 | help="if set will perform the flipping post processing " 167 | "from the original monodepth paper", 168 | action="store_true") 169 | self.parser.add_argument("--gc", 170 | help="if set will train-gc", 171 | action="store_true") 172 | 173 | self.parser.add_argument('--iters', type=int, default=6) 174 | 
self.parser.add_argument('--wdecay', type=float, default=.00005) 175 | self.parser.add_argument('--epsilon', type=float, default=1e-8) 176 | self.parser.add_argument('--clip', type=float, default=1.0) 177 | self.parser.add_argument('--gamma', type=float, default=0.8, help='exponential weighting') 178 | 179 | def parse(self): 180 | self.options = self.parser.parse_args() 181 | return self.options 182 | 183 | def parse_allmodel(self,model_name): 184 | a = ['--load_weights_folder', model_name] 185 | self.options = self.parser.parse_args(a) 186 | return self.options 187 | -------------------------------------------------------------------------------- /test_simple.py: -------------------------------------------------------------------------------- 1 | from __future__ import absolute_import, division, print_function 2 | 3 | import os 4 | import sys 5 | import glob 6 | sys.path.append('core') 7 | import argparse 8 | import numpy as np 9 | import PIL.Image as pil 10 | import matplotlib as mpl 11 | import matplotlib.cm as cm 12 | from R_MSFM import R_MSFM3,R_MSFM6 13 | import torch 14 | from torchvision import transforms, datasets 15 | import time 16 | import networks 17 | import time 18 | import shutil 19 | 20 | def disp_to_depth(disp, min_depth, max_depth): 21 | """Convert network's sigmoid output into depth prediction 22 | """ 23 | min_disp = 1 / max_depth 24 | max_disp = 1 / min_depth 25 | scaled_disp = min_disp + (max_disp - min_disp) * disp 26 | depth = 1 / scaled_disp 27 | return scaled_disp, depth 28 | 29 | def parse_args(): 30 | parser = argparse.ArgumentParser(description='Simple testing funtion for R-MSFM models.') 31 | parser.add_argument('--image_path', type=str,help='path to a test image or folder of images', required=True) 32 | parser.add_argument('--ext', type=str,help='image extension to search for in folder', default="jpeg") 33 | parser.add_argument('--model_path', type=str,help='path to a models.pth', default="./3M") 34 | parser.add_argument('--update', type=int,help='iterative update', default=3) 35 | parser.add_argument("--no_cuda",help='if set, disables CUDA',action='store_true') 36 | parser.add_argument("--x",help='if set, R-MSFMX',action='store_true') 37 | return parser.parse_args() 38 | 39 | 40 | def test_simple(args): 41 | """Function to predict for a single image or folder of images 42 | """ 43 | # assert args.model_name is not None, \ 44 | # "You must specify the --model_name parameter; see README.md for an example" 45 | 46 | if torch.cuda.is_available() and not args.no_cuda: 47 | device = torch.device("cuda") 48 | else: 49 | device = torch.device("cpu") 50 | 51 | # download_model_if_doesnt_exist(args.model_name) 52 | model_path = args.model_path 53 | print("-> Loading model from ", model_path) 54 | encoder_path = os.path.join(model_path, "encoder.pth") 55 | depth_decoder_path = os.path.join(model_path, "depth.pth") 56 | 57 | 58 | 59 | # LOADING PRETRAINED MODEL 60 | print(" Loading pretrained encoder") 61 | if args.x: 62 | encoder = networks.ResnetEncoder(50, False) 63 | else: 64 | encoder = networks.ResnetEncoder(18, False) 65 | encoder .load_state_dict(torch.load(encoder_path, map_location= device),False) 66 | encoder.to(device) 67 | encoder.eval() 68 | 69 | print(" Loading pretrained decoder") 70 | if args.update == 3: 71 | depth_decoder = R_MSFM3(args.x) 72 | else: 73 | depth_decoder = R_MSFM6(args.x) 74 | depth_decoder.load_state_dict(torch.load(depth_decoder_path, map_location= device)) 75 | depth_decoder.to(device) 76 | depth_decoder.eval() 77 | 78 | # FINDING 
INPUT IMAGES 79 | if os.path.isfile(args.image_path): 80 | # Only testing on a single image 81 | paths = [args.image_path] 82 | output_directory = os.path.dirname(args.image_path) 83 | elif os.path.isdir(args.image_path): 84 | # Searching folder for images 85 | paths = glob.glob(os.path.join(args.image_path, '*.{}'.format(args.ext))) 86 | output_directory=os.path.join(args.image_path,'output') 87 | else: 88 | raise Exception("Can not find args.image_path: {}".format(args.image_path)) 89 | 90 | print("-> Predicting on {:d} test images".format(len(paths))) 91 | 92 | if os.path.exists(output_directory): 93 | 94 | shutil.rmtree(output_directory) 95 | 96 | os.makedirs(output_directory) 97 | 98 | # PREDICTING ON EACH IMAGE IN TURN 99 | with torch.no_grad(): 100 | min_infer_time = 10 101 | for idx, image_path in enumerate(paths): 102 | 103 | if image_path.endswith("_disp.jpg"): 104 | # don't try to predict disparity for a disparity image! 105 | continue 106 | feed_width = 640 107 | feed_height = 192 108 | # Load image and preprocess 109 | input_image = pil.open(image_path).convert('RGB') 110 | original_width, original_height = input_image.size 111 | input_image = input_image.resize((feed_width, feed_height), pil.LANCZOS) 112 | input_image = transforms.ToTensor()(input_image).unsqueeze(0) 113 | 114 | #torch.cuda.synchronize() 115 | start = time.time() 116 | # PREDICTION 117 | input_image = input_image.to(device) 118 | features = encoder(input_image) 119 | outputs = depth_decoder(features) 120 | #torch.cuda.synchronize() 121 | end = time.time() 122 | infer_time = end-start 123 | 124 | 125 | if infer_time < min_infer_time: 126 | min_infer_time = infer_time 127 | 128 | 129 | if args.update == 3: 130 | disp = outputs[("disp_up", 2)] 131 | else: 132 | disp = outputs[("disp_up", 5)] 133 | disp_resized = torch.nn.functional.interpolate( 134 | disp, (original_height, original_width), mode="bilinear", align_corners=False) 135 | 136 | # Saving numpy file 137 | output_name = os.path.splitext(os.path.basename(image_path))[0] 138 | name_dest_npy = os.path.join(output_directory, "{}_disp.npy".format(output_name)) 139 | scaled_disp, _ = disp_to_depth(disp, 0.1, 100) 140 | #np.save(name_dest_npy, scaled_disp.cpu().numpy()) 141 | 142 | # Saving colormapped depth image 143 | disp_resized_np = disp_resized.squeeze().cpu().numpy() 144 | vmax = np.percentile(disp_resized_np, 95) 145 | normalizer = mpl.colors.Normalize(vmin=disp_resized_np.min(), vmax=vmax) 146 | mapper = cm.ScalarMappable(norm=normalizer, cmap='magma') 147 | colormapped_im = (mapper.to_rgba(disp_resized_np)[:, :, :3] * 255).astype(np.uint8) 148 | 149 | im = pil.fromarray(colormapped_im) 150 | 151 | name_dest_im = os.path.join(output_directory, "{}_disp.jpeg".format(output_name)) 152 | im.save(name_dest_im) 153 | 154 | print(" Processed {:d} of {:d} images - saved prediction to {}".format( 155 | idx + 1, len(paths), name_dest_im)) 156 | 157 | print('min_infer_time:', min_infer_time) 158 | print('-> Done!') 159 | 160 | 161 | if __name__ == '__main__': 162 | args = parse_args() 163 | test_simple(args) 164 | 165 | 166 | ''' 167 | python test_simple.py --image_path='/path/to/your/data/' --model_path='/path/to/your/model/' --update=6 168 | 169 | ''' 170 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 
2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | import numpy as np 9 | import random 10 | from trainer import Trainer 11 | from options import RMSFM2Options 12 | import torch 13 | import os 14 | 15 | options = RMSFM2Options() 16 | opts = options.parse() 17 | torch.backends.cudnn.benchmark = True 18 | def setup_seed(seed): 19 | torch.manual_seed(seed) 20 | torch.cuda.manual_seed_all(seed) 21 | np.random.seed(seed) 22 | random.seed(seed) 23 | torch.backends.cudnn.deterministic = True 24 | 25 | setup_seed(1) 26 | 27 | if __name__ == "__main__": 28 | trainer = Trainer(opts) 29 | trainer.train() 30 | 31 | ''' 32 | python train.py --gc 33 | python train.py 34 | ''' -------------------------------------------------------------------------------- /trainer.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | import sys 9 | sys.path.append('core') 10 | import numpy as np 11 | import time 12 | 13 | import torch 14 | import torch.nn.functional as F 15 | import torch.optim as optim 16 | from torch.utils.data import DataLoader 17 | from tensorboardX import SummaryWriter 18 | import json 19 | from utils import * 20 | from kitti_utils import * 21 | from layers import * 22 | from R_MSFM import R_MSFM3,R_MSFM6 23 | import datasets 24 | import networks 25 | 26 | accu = 2 27 | try: 28 | from torch.cuda.amp import GradScaler 29 | except ImportError: 30 | # dummy GradScaler for PyTorch < 1.6 31 | class GradScaler: 32 | def __init__(self, enabled=False): 33 | pass 34 | def scale(self, loss): 35 | return loss 36 | def unscale_(self, optimizer): 37 | pass 38 | def step(self, optimizer): 39 | optimizer.step() 40 | def update(self): 41 | pass 42 | 43 | class Trainer: 44 | def __init__(self, options): 45 | self.opt = options 46 | self.log_path = os.path.join(self.opt.log_dir, self.opt.model_name) 47 | self.scaler = GradScaler(enabled=False) 48 | # checking height and width are multiples of 32 49 | assert self.opt.height % 32 == 0, "'height' must be a multiple of 32" 50 | assert self.opt.width % 32 == 0, "'width' must be a multiple of 32" 51 | self.models = {} 52 | self.parameters_to_train = [] 53 | self.device = torch.device("cpu" if self.opt.no_cuda else "cuda") 54 | self.num_input_frames = len(self.opt.frame_ids) 55 | assert self.opt.frame_ids[0] == 0, "frame_ids must start with 0" 56 | 57 | self.use_pose_net = not (self.opt.use_stereo and self.opt.frame_ids == [0]) 58 | 59 | if self.opt.use_stereo: 60 | self.opt.frame_ids.append("s") 61 | 62 | self.models["encoder"] = networks.ResnetEncoder(self.opt.num_layers, True) 63 | self.models["encoder"].to(self.device) 64 | self.parameters_to_train += list(self.models["encoder"].parameters()) 65 | 66 | if self.opt.iters == 6: 67 | self.models["depth"] = R_MSFM6() 68 | else: 69 | self.models["depth"] = R_MSFM3() 70 | self.models["depth"].to(self.device) 71 | self.parameters_to_train += list(self.models["depth"].parameters()) 72 | 73 | 74 | if self.use_pose_net: 75 | self.models["pose_encoder"] = 
networks.ResnetEncoder2(18,True,2) 76 | self.models["pose_encoder"].to(self.device) 77 | self.parameters_to_train += list(self.models["pose_encoder"].parameters()) 78 | 79 | self.models["pose"] = networks.PoseDecoder( 80 | self.models["pose_encoder"].num_ch_enc, 81 | num_input_features=1, 82 | num_frames_to_predict_for=2) 83 | self.models["pose"].to(self.device) 84 | self.parameters_to_train += list(self.models["pose"].parameters()) 85 | 86 | self.optimizer =optim.AdamW(self.parameters_to_train, lr=self.opt.learning_rate, 87 | weight_decay=self.opt.wdecay) 88 | self.scheduler = optim.lr_scheduler.OneCycleLR(self.optimizer,self.opt.learning_rate, epochs=40, div_factor= 25 , 89 | steps_per_epoch=len(train_dataset)//self.opt.batch_size, pct_start= 0.1, cycle_momentum=False) 90 | 91 | print("Training model named:\n ", self.opt.model_name) 92 | print("Models and tensorboard events files are saved to:\n ", self.opt.log_dir) 93 | 94 | datasets_dict = {"kitti": datasets.KITTIRAWDataset, 95 | "kitti_odom": datasets.KITTIOdomDataset} 96 | self.dataset = datasets_dict[self.opt.dataset] 97 | 98 | fpath = os.path.join(os.path.dirname(__file__), "splits", self.opt.split, "{}_files.txt") 99 | 100 | train_filenames = readlines(fpath.format("train")) 101 | val_filenames = readlines(fpath.format("val")) 102 | img_ext = '.png' if self.opt.png else '.jpg' 103 | 104 | num_train_samples = len(train_filenames) 105 | self.num_total_steps = num_train_samples // self.opt.batch_size * self.opt.num_epochs 106 | 107 | train_dataset = self.dataset( 108 | self.opt.data_path, train_filenames, self.opt.height, self.opt.width, 109 | self.opt.frame_ids, is_train=True, img_ext=img_ext) 110 | self.train_loader = DataLoader( 111 | train_dataset, self.opt.batch_size, True, 112 | num_workers=self.opt.num_workers, pin_memory=True, drop_last=True) 113 | val_dataset = self.dataset( 114 | self.opt.data_path, val_filenames, self.opt.height, self.opt.width, 115 | self.opt.frame_ids, is_train=False, img_ext=img_ext) 116 | self.val_loader = DataLoader( 117 | val_dataset, self.opt.batch_size, True, 118 | num_workers=self.opt.num_workers, pin_memory=True, drop_last=True) 119 | self.val_iter = iter(self.val_loader) 120 | 121 | self.writers = {} 122 | for mode in ["train", "val"]: 123 | self.writers[mode] = SummaryWriter(os.path.join(self.log_path, mode)) 124 | 125 | self.ssim = SSIM() 126 | self.ssim.to(self.device) 127 | 128 | self.backproject_depth = {} 129 | self.project_3d = {} 130 | scale = 0 131 | h = self.opt.height // (2 ** scale) 132 | w = self.opt.width // (2 ** scale) 133 | 134 | self.backproject_depth[scale] = BackprojectDepth(self.opt.batch_size, h, w) 135 | self.backproject_depth[scale].cuda() 136 | 137 | self.project_3d[scale] = Project3D(self.opt.batch_size, h, w) 138 | self.project_3d[scale].cuda() 139 | 140 | self.depth_metric_names = [ 141 | "de/abs_rel", "de/sq_rel", "de/rms", "de/log_rms", "da/a1", "da/a2", "da/a3"] 142 | 143 | print("Using split:\n ", self.opt.split) 144 | print("There are {:d} training items and {:d} validation items\n".format( 145 | len(train_dataset), len(val_dataset))) 146 | 147 | self.save_opts() 148 | 149 | def set_train(self): 150 | """Convert all models to training mode 151 | """ 152 | for m in self.models.values(): 153 | m.train() 154 | 155 | def set_eval(self): 156 | """Convert all models to testing/evaluation mode 157 | """ 158 | for m in self.models.values(): 159 | m.eval() 160 | 161 | def train(self): 162 | """Run the entire training pipeline 163 | """ 164 | self.epoch = 0 165 | self.step 
= 0 166 | self.start_time = time.time() 167 | for self.epoch in range(self.opt.num_epochs): 168 | self.run_epoch() 169 | if (self.epoch + 1) % self.opt.save_frequency == 0: 170 | self.save_model() 171 | 172 | 173 | def run_epoch(self): 174 | """Run a single epoch of training and validation 175 | # """ 176 | # self.scheduler.step() 177 | 178 | print("Training") 179 | self.set_train() 180 | 181 | for batch_idx, inputs in enumerate(self.train_loader): 182 | before_op_time = time.time() 183 | 184 | outputs, losses = self.process_batch(inputs) 185 | losses["loss"] = losses["loss"]/accu 186 | losses["loss"].backward() 187 | if self.step%accu==0: 188 | torch.nn.utils.clip_grad_norm_(self.parameters_to_train, self.opt.clip) 189 | self.optimizer.step() 190 | self.optimizer.zero_grad() 191 | self.scheduler.step() 192 | Lr = self.scheduler.get_lr()[0] 193 | 194 | duration = time.time() - before_op_time 195 | 196 | # log less frequently after the first 2000 steps to save time & disk space 197 | early_phase = batch_idx % self.opt.log_frequency == 0 and self.step < 6000 198 | late_phase = self.step % 2000 == 0 199 | 200 | if early_phase or late_phase: 201 | self.log_time(batch_idx, duration, losses["loss"].cpu().data) 202 | 203 | if "depth_gt" in inputs: 204 | self.compute_depth_losses(inputs, outputs, losses) 205 | 206 | self.log("train", inputs, outputs, losses, Lr) 207 | # self.val() 208 | if self.epoch>0 and (self.step %1500) ==0 : 209 | self.save_model() 210 | self.step += 1 211 | 212 | def process_batch(self, inputs): 213 | """Pass a minibatch through the network and generate images and losses 214 | """ 215 | for key, ipt in inputs.items(): 216 | inputs[key] = ipt.to(self.device) 217 | 218 | 219 | features = self.models["encoder"](inputs["color_aug", 0, 0]) 220 | outputs = self.models["depth"](features, iters = self.opt.iters) 221 | if self.opt.gc: 222 | aa = self.models["encoder"](inputs["color_aug", -1, 0]) 223 | bb = self.models["depth"](aa, iters = self.opt.iters) 224 | outputs_1 = {} 225 | outputs_1[("disp_up", -1, 0)] = bb[("disp_up", 0)] 226 | outputs_1[("disp_up", -1, 1)] = bb[("disp_up", 1)] 227 | outputs_1[("disp_up", -1, 2)] = bb[("disp_up", 2)] 228 | outputs_1[("disp_up", -1, 3)] = bb[("disp_up", 3)] 229 | outputs_1[("disp_up", -1, 4)] = bb[("disp_up", 4)] 230 | outputs_1[("disp_up", -1, 5)] = bb[("disp_up", 5)] 231 | 232 | aa = self.models["encoder"](inputs["color_aug", 1, 0]) 233 | bb = self.models["depth"](aa, iters = self.opt.iters) 234 | outputs1 = {} 235 | outputs1[("disp_up", 1, 0)] = bb[("disp_up", 0)] 236 | outputs1[("disp_up", 1, 1)] = bb[("disp_up", 1)] 237 | outputs1[("disp_up", 1, 2)] = bb[("disp_up", 2)] 238 | outputs1[("disp_up", 1, 3)] = bb[("disp_up", 3)] 239 | outputs1[("disp_up", 1, 4)] = bb[("disp_up", 4)] 240 | outputs1[("disp_up", 1, 5)] = bb[("disp_up", 5)] 241 | outputs1.update(outputs_1) 242 | outputs.update(outputs1) 243 | 244 | 245 | if self.use_pose_net: 246 | outputs.update(self.predict_poses2(inputs)) 247 | 248 | 249 | self.generate_images_pred(inputs, outputs, iters = self.opt.iters) 250 | losses = self.compute_losses(inputs, outputs) 251 | 252 | return outputs, losses 253 | 254 | 255 | def predict_poses2(self, inputs): 256 | """Predict poses between input frames for monocular sequences. 
257 | """ 258 | outputs = {} 259 | pose_feats = {f_i: inputs["color_aug", f_i, 0] for f_i in self.opt.frame_ids} 260 | 261 | for f_i in self.opt.frame_ids[1:]: 262 | if f_i != "s": 263 | # To maintain ordering we always pass frames in temporal order 264 | if f_i < 0: 265 | pose_inputs = [pose_feats[f_i], pose_feats[0]] 266 | else: 267 | pose_inputs = [pose_feats[0], pose_feats[f_i]] 268 | 269 | pose_inputs = [self.models["pose_encoder"](torch.cat(pose_inputs, 1))] 270 | axisangle, translation = self.models["pose"](pose_inputs) 271 | outputs[("axisangle", 0, f_i)] = axisangle 272 | outputs[("translation", 0, f_i)] = translation 273 | 274 | # Invert the matrix if the frame id is negative 275 | outputs[("cam_T_cam", 0, f_i)] = transformation_from_parameters( 276 | axisangle[:, 0], translation[:, 0], invert=(f_i < 0)) 277 | 278 | return outputs 279 | 280 | def val(self): 281 | """Validate the model on a single minibatch 282 | """ 283 | self.set_eval() 284 | try: 285 | inputs = self.val_iter.next() 286 | except StopIteration: 287 | self.val_iter = iter(self.val_loader) 288 | inputs = self.val_iter.next() 289 | 290 | with torch.no_grad(): 291 | outputs, losses = self.process_batch(inputs) 292 | 293 | if "depth_gt" in inputs: 294 | self.compute_depth_losses(inputs, outputs, losses) 295 | 296 | self.log("val", inputs, outputs, losses) 297 | del inputs, outputs, losses 298 | 299 | self.set_train() 300 | 301 | def generate_images_pred(self, inputs, outputs, iters=3): 302 | 303 | for scale in range(iters): 304 | 305 | disp = outputs[("disp_up", scale)] 306 | 307 | source_scale = 0 308 | 309 | _, depth = disp_to_depth(disp, self.opt.min_depth, self.opt.max_depth) 310 | 311 | outputs[("depth", 0, scale)] = depth 312 | 313 | for i, frame_id in enumerate(self.opt.frame_ids[1:]): 314 | 315 | if frame_id == "s": 316 | T = inputs["stereo_T"] 317 | else: 318 | T = outputs[("cam_T_cam", 0, frame_id)] 319 | 320 | cam_points = self.backproject_depth[source_scale]( 321 | depth, inputs[("inv_K", source_scale)]) 322 | pix_coords, computed_depth = self.project_3d[source_scale]( 323 | cam_points, inputs[("K", source_scale)], T) 324 | outputs[("sample", frame_id, scale)] = pix_coords 325 | 326 | outputs[("color", frame_id, scale)] = F.grid_sample( 327 | inputs[("color", frame_id, source_scale)], 328 | outputs[("sample", frame_id, scale)], 329 | padding_mode="border") 330 | if self.opt.gc: 331 | outputs[("computed_depth", frame_id, scale)] = computed_depth 332 | _, outputs[("projected_depth", frame_id, scale)] = disp_to_depth(F.grid_sample( 333 | outputs[("disp_up", frame_id, scale)], 334 | outputs[("sample", frame_id, scale)], 335 | padding_mode="border"), self.opt.min_depth, self.opt.max_depth) 336 | 337 | outputs[("color_identity", frame_id, scale)] = \ 338 | inputs[("color", frame_id, source_scale)] 339 | 340 | def smooth_l1_loss_ours(self, input,target,beta=0.15): 341 | n = torch.abs(input-target) 342 | cond = n 0 483 | 484 | # garg/eigen crop 485 | crop_mask = torch.zeros_like(mask) 486 | crop_mask[:, :, 153:371, 44:1197] = 1 487 | mask = mask * crop_mask 488 | 489 | depth_gt = depth_gt[mask] 490 | depth_pred = depth_pred[mask] 491 | depth_pred *= torch.median(depth_gt) / torch.median(depth_pred) 492 | 493 | depth_pred = torch.clamp(depth_pred, min=1e-3, max=80) 494 | 495 | depth_errors = compute_depth_errors(depth_gt, depth_pred) 496 | 497 | for i, metric in enumerate(self.depth_metric_names): 498 | losses[metric] = np.array(depth_errors[i].cpu()) 499 | 500 | def log_time(self, batch_idx, duration, loss): 501 | 
"""Print a logging statement to the terminal 502 | """ 503 | samples_per_sec = self.opt.batch_size / duration 504 | time_sofar = time.time() - self.start_time 505 | training_time_left = ( 506 | self.num_total_steps / self.step - 1.0) * time_sofar if self.step > 0 else 0 507 | print_string = "epoch {:>3} | batch {:>6} | examples/s: {:5.1f}" + \ 508 | " | loss: {:.5f} | time elapsed: {} | time left: {}" 509 | print(print_string.format(self.epoch, batch_idx, samples_per_sec, loss, 510 | sec_to_hm_str(time_sofar), sec_to_hm_str(training_time_left))) 511 | 512 | 513 | def log(self, mode, inputs, outputs, losses, Lr): 514 | """Write an event to the tensorboard events file 515 | """ 516 | writer = self.writers[mode] 517 | for l, v in losses.items(): 518 | writer.add_scalar("{}".format(l), v, self.step) 519 | writer.add_scalar('learningRate', Lr, self.step) 520 | 521 | for j in range(min(4, self.opt.batch_size)): # write a maxmimum of four images 522 | s = 0 523 | writer.add_image( 524 | "color_{}_{}/{}".format(s, s, j), 525 | inputs[("color", s, s)][j].data, self.step) 526 | 527 | writer.add_image( 528 | "disp_{}/{}".format(s, j), 529 | normalize_image(outputs[("disp_up", s)][j]), self.step) 530 | 531 | def save_opts(self): 532 | """Save options to disk so we know what we ran this experiment with 533 | """ 534 | models_dir = os.path.join(self.log_path, "models") 535 | if not os.path.exists(models_dir): 536 | os.makedirs(models_dir) 537 | to_save = self.opt.__dict__.copy() 538 | 539 | with open(os.path.join(models_dir, 'opt.json'), 'w') as f: 540 | json.dump(to_save, f, indent=2) 541 | 542 | def save_model(self): 543 | """Save model weights to disk 544 | """ 545 | save_folder = os.path.join(self.log_path, "models", "weights_{}_{}".format(self.epoch, self.step)) 546 | if not os.path.exists(save_folder): 547 | os.makedirs(save_folder) 548 | 549 | for model_name, model in self.models.items(): 550 | save_path = os.path.join(save_folder, "{}.pth".format(model_name)) 551 | to_save = model.state_dict() 552 | if model_name == 'encoder': 553 | # save the sizes - these are needed at prediction time 554 | to_save['height'] = self.opt.height 555 | to_save['width'] = self.opt.width 556 | to_save['use_stereo'] = self.opt.use_stereo 557 | torch.save(to_save, save_path) 558 | 559 | save_path = os.path.join(save_folder, "{}.pth".format("adam")) 560 | torch.save(self.optimizer.state_dict(), save_path) 561 | 562 | def load_model(self): 563 | """Load model(s) from disk 564 | """ 565 | self.opt.load_weights_folder = os.path.expanduser(self.opt.load_weights_folder) 566 | 567 | assert os.path.isdir(self.opt.load_weights_folder), \ 568 | "Cannot find folder {}".format(self.opt.load_weights_folder) 569 | print("loading model from folder {}".format(self.opt.load_weights_folder)) 570 | 571 | for n in self.opt.models_to_load: 572 | print("Loading {} weights...".format(n)) 573 | path = os.path.join(self.opt.load_weights_folder, "{}.pth".format(n)) 574 | model_dict = self.models[n].state_dict() 575 | pretrained_dict = torch.load(path, map_location='cpu') 576 | pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} 577 | model_dict.update(pretrained_dict) 578 | self.models[n].load_state_dict(model_dict, strict = True) 579 | 580 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | # Copyright Niantic 2019. Patent Pending. All rights reserved. 
2 | # 3 | # This software is licensed under the terms of the Monodepth2 licence 4 | # which allows for non-commercial use only, the full terms of which are made 5 | # available in the LICENSE file. 6 | 7 | from __future__ import absolute_import, division, print_function 8 | import os 9 | import hashlib 10 | import zipfile 11 | from six.moves import urllib 12 | 13 | 14 | def readlines(filename): 15 | """Read all the lines in a text file and return as a list 16 | """ 17 | with open(filename, 'r') as f: 18 | lines = f.read().splitlines() 19 | return lines 20 | 21 | 22 | def normalize_image(x): 23 | """Rescale image pixels to span range [0, 1] 24 | """ 25 | ma = float(x.max().cpu().data) 26 | mi = float(x.min().cpu().data) 27 | d = ma - mi if ma != mi else 1e5 28 | return (x - mi) / d 29 | 30 | 31 | def sec_to_hm(t): 32 | """Convert time in seconds to time in hours, minutes and seconds 33 | e.g. 10239 -> (2, 50, 39) 34 | """ 35 | t = int(t) 36 | s = t % 60 37 | t //= 60 38 | m = t % 60 39 | t //= 60 40 | return t, m, s 41 | 42 | 43 | def sec_to_hm_str(t): 44 | """Convert time in seconds to a nice string 45 | e.g. 10239 -> '02h50m39s' 46 | """ 47 | h, m, s = sec_to_hm(t) 48 | return "{:02d}h{:02d}m{:02d}s".format(h, m, s) 49 | 50 | 51 | --------------------------------------------------------------------------------
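For reference, the loss layers defined in `layers.py` are conventionally combined into the photometric-plus-smoothness objective used in Monodepth2-style self-supervised training. The sketch below is illustrative only: the 0.85/0.15 SSIM-to-L1 split and the mean-disparity normalisation are the usual Monodepth2 choices and are assumptions here rather than code quoted from `trainer.py`, while the 1e-3 smoothness weight matches the `--disparity_smoothness` default in `options.py`.

```python
import torch
from layers import SSIM, get_smooth_loss

ssim = SSIM()

def reprojection_loss(pred, target, alpha=0.85):
    """Per-pixel photometric error: alpha * SSIM + (1 - alpha) * L1 (Monodepth2-style)."""
    l1 = torch.abs(pred - target).mean(1, keepdim=True)
    ssim_term = ssim(pred, target).mean(1, keepdim=True)
    return alpha * ssim_term + (1 - alpha) * l1

def smoothness_loss(disp, color, weight=1e-3):
    """Edge-aware smoothness on mean-normalised disparity, weighted by --disparity_smoothness."""
    mean_disp = disp.mean(2, True).mean(3, True)
    return weight * get_smooth_loss(disp / (mean_disp + 1e-7), color)

if __name__ == "__main__":
    target = torch.rand(2, 3, 192, 640)   # dummy reference frame
    pred = torch.rand(2, 3, 192, 640)     # dummy warped source frame
    disp = torch.rand(2, 1, 192, 640)     # dummy predicted disparity
    total = reprojection_loss(pred, target).mean() + smoothness_loss(disp, target)
    print(total.item())
```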