├── DifferentiablePatchMatch ├── README.md ├── demo_script.py └── models │ ├── .gitignore │ ├── __init__.py │ ├── config.py │ ├── feature_extractor.py │ ├── image_reconstruction.py │ └── patch_match.py ├── LICENSE ├── NOTICE ├── README.md ├── deeppruner ├── .gitignore ├── README.md ├── dataloader │ ├── .gitignore │ ├── __init__.py │ ├── kitti_collector.py │ ├── kitti_loader.py │ ├── kitti_submission_collector.py │ ├── preprocess.py │ ├── readpfm.py │ ├── sceneflow_collector.py │ └── sceneflow_loader.py ├── finetune_kitti.py ├── loss_evaluation.py ├── models │ ├── .gitignore │ ├── __init__.py │ ├── config.py │ ├── deeppruner.py │ ├── feature_extractor_best.py │ ├── feature_extractor_fast.py │ ├── patch_match.py │ ├── submodules.py │ ├── submodules2d.py │ ├── submodules3d.py │ └── utils.py ├── setup_logging.py ├── submission_kitti.py └── train_sceneflow.py └── readme_images ├── CRP.png ├── DPM.png ├── DPM_filters.png ├── DPM_reconstruction ├── DPM │ ├── Never back down 1032.png │ ├── Never back down 1979.png │ ├── Never back down 2046.png │ ├── cheetah1.png │ └── face1.png ├── ImageA │ ├── Never back down 1032.png │ ├── Never back down 1979.png │ ├── Never back down 2046.png │ ├── cheetah1.png │ └── face1.png ├── ImageB │ ├── Never back down 1213.png │ ├── Never back down 2040.png │ ├── Never back down 2066.png │ ├── cheetah2.png │ └── face2.png └── PM │ ├── Never back down 1032.png │ ├── Never back down 1979.png │ ├── Never back down 2046.png │ ├── cheetah1.png │ └── face.png ├── DeepPruner.png ├── KITTI_test_set.png ├── kitti_results.png ├── original_patch_match.png ├── original_patch_match_steps.png ├── rob.png ├── rob_results.png ├── sceneflow.png ├── sceneflow_results.png ├── softmax.png └── uncertainty_vis.png /DifferentiablePatchMatch/README.md: -------------------------------------------------------------------------------- 1 | # Differentiable Patch-Match 2 | 3 | ##### Table of Contents 4 | [Patch Match](#PatchMatch) 5 | [Differentiable Patch Match](#DifferentiablePatchMatch) 6 | [Differentiable Patch Match vs Patch Match for Image Reconstruction](#Comparison) 7 | [Run Command](#run_command) 8 | [Citation](#citation) 9 | 10 | 11 | ## PatchMatch 12 | 13 | Patch Match ([Barnes et al.](https://gfx.cs.princeton.edu/pubs/Barnes_2009_PAR/)) was originally introduced as an efficient way to find dense correspondences across images for structural editing. The key idea behind it is that, a **large number of random samples often lead to good guesses**. Additionally, **neighboring pixels usually have coherent matches**. Therefore, once a good match is found, it can efficiently propagate the information to the neighbors. 14 | 15 | | **Patch Match Overview** | **Patch Match Steps** | 16 | | :----------------------------------------------------------- | ------------------------------------------------------------ | 17 | | | | 18 | 19 | 20 | 21 | ## Differentiable PatchMatch 22 | 23 | 1. In our work, we unroll generalized **PatchMatch as a recurrent neural network**, where each unrolling step is equivalent to each iteration of the algorithm. **This is important as it allow us to train our full model end-to-end. ** 24 | 2. Specifically, we design the following layers: 25 | 1. **Particle sampling layer**: for each pixel i, we randomly generate k disparity values from the uniform distribution over predicted/pre-defined search space. 26 | 2. 
**Propagation layer**: particles from adjacent pixels are propagated together through convolution with a predefined one-hot filter pattern, which encodes the fact that we allow each pixel to propagate particles to its 4-neighbours (an illustrative sketch follows the figures below). 27 | 3. **Evaluation layer**: for each pixel i, matching scores are computed by taking the inner product between the left and right features: s(i,j) = ⟨f₀(i), f₁(i + d(i,j))⟩, where d(i,j) is the j-th disparity candidate at pixel i. The best k disparity values for each pixel are carried forward to the next iteration. 28 | 29 | 30 | 31 | We replace the non-differentiable argmax during evaluation with a differentiable softmax (soft argmax). 32 | 33 |
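The snippet below is a minimal, illustrative sketch of this soft-argmax evaluation; the tensor names and shapes are assumptions for illustration, not the repository implementation (see `models/patch_match.py` for the real code):

```python
import torch

def soft_argmax_evaluation(scores, disparity_candidates, temperature=10.0):
    # scores, disparity_candidates: (B, k, H, W) -- one matching score per candidate per pixel.
    # A hard argmax over the k candidates is non-differentiable; a softmax-weighted average
    # of the candidates lets gradients reach every candidate instead.
    weights = torch.softmax(scores * temperature, dim=1)  # larger temperature -> sharper weighting
    return torch.sum(weights * disparity_candidates, dim=1, keepdim=True)  # (B, 1, H, W)
```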

34 | 35 |

36 | 37 |

38 | 39 |
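Below is a rough sketch of the propagation layer described above, written for a single (horizontal) direction. The tensor shapes and helper name are assumptions for illustration; the repository implementation in `models/patch_match.py` uses separable 3D convolutions instead:

```python
import torch
import torch.nn.functional as F

def propagate_candidates_horizontally(offsets):
    # offsets: (B, k, H, W) candidate offsets ("particles") per pixel.
    # Convolving with fixed 1x3 one-hot filters lets every pixel also see the candidates
    # of its left and right neighbours, yielding (B, 3k, H, W) candidates to evaluate.
    b, k, h, w = offsets.shape
    filters = torch.zeros(3, 1, 1, 3, device=offsets.device)
    filters[0, 0, 0, 0] = 1.0  # candidate propagated from the left neighbour
    filters[1, 0, 0, 1] = 1.0  # the pixel's own candidate
    filters[2, 0, 0, 2] = 1.0  # candidate propagated from the right neighbour
    out = F.conv2d(offsets.view(b * k, 1, h, w), filters, padding=(0, 1))  # (B*k, 3, H, W)
    return out.view(b, 3 * k, h, w)
```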

40 | 41 | 42 | 43 | 44 | 45 | 3. For usage of Differentiable Patch Match for stereo matching, refer to out paper. 46 | 47 | 48 | 49 | 50 | ## Differentiable Patch Match vs Patch Match for Image Reconstruction 51 | 52 | 1. In this section, we compare the proposed Differentiable Patch Match with the original Patch Match for the image reconstruction task. 53 | 2. Given two images, Image A and Image B, we reconstruct Image A from Image B, based on **Differentiable PatchMatch / PatchMatch** "A-to-B" dense patch mappings. 54 | 3. Following are the per-iteration comparison results between the Differentiable Patch Match and the original Patch Match. 55 | 56 | 57 |   58 |   59 |   60 |   61 |   62 |   63 |   64 |   65 | 66 | ### *Images* 67 | 68 | 69 | | Image A | Image B | 70 | | :-----: | :-----: | 71 | | Never back down 1032 | Never back down 1213 | 72 | | Never back down 1979 | Never back down 2040 | 73 | | Never back down 2046 | Never back down 2066 | 74 | | cheetah | cheetah2 | 75 | | cheetah | cheetah2 | 76 | 77 | 78 | 79 | 80 | 81 | 82 |   83 |   84 |   85 |   86 |   87 |   88 |   89 |   90 | 91 | 92 | ### *Reconstructed A from B* 93 | 94 | 95 | | Differentiable Patch Match | Patch Match | 96 | | :-----: | :-----: | 97 | | Never back down 1032 | Never back down 1213 | 98 | | Never back down 1032 | Never back down 1213 | 99 | | Never back down 1032 | Never back down 1213 | 100 | | Never back down 1032 | Never back down 1213 | 101 | | Never back down 1032 | Never back down 1213 | 102 | 103 | 104 | 105 |   106 |   107 |   108 |   109 |   110 |   111 |   112 |   113 | 114 | 115 | 116 | 117 | ### RUN COMMAND 118 | 119 | 1. Create a base_dir with two folders: 120 | base_directory 121 | |----- image_1 122 | |---- image_2 123 | 124 | 125 | 126 | > python demo_script.py \ 127 | > --base_dir 128 | > --save_dir < save directory for reconstructed images> 129 | 130 | 131 | 132 | 133 | 134 | ## Citation 135 | 136 | 137 | 138 | If you use our source code, or our paper, please consider citing the following: 139 | 140 | > @inproceedings{Duggal2019ICCV, 141 | > title = {DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch}, 142 | > author = {Shivam Duggal and Shenlong Wang and Wei-Chiu Ma and Rui Hu and Raquel Urtasun}, > 143 | > booktitle = {ICCV}, 144 | > year = {2019} 145 | > } 146 | 147 | 148 | 149 | Correspondences to Shivam Duggal , Shenlong Wang , Wei-Chiu Ma 150 | -------------------------------------------------------------------------------- /DifferentiablePatchMatch/demo_script.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | from PIL import Image 18 | import torch 19 | import random 20 | import skimage 21 | import numpy as np 22 | from models.image_reconstruction import ImageReconstruction 23 | import os 24 | import torchvision.transforms as transforms 25 | import argparse 26 | import matplotlib.pyplot as plt 27 | 28 | parser = argparse.ArgumentParser(description='Differentiable PatchMatch') 29 | parser.add_argument('--base_dir', default='./', 30 | help='path of base directory where images are stored.') 31 | parser.add_argument('--save_dir', default='./', 32 | help='save directory') 33 | parser.add_argument('--no-cuda', action='store_true', default=False, 34 | help='enables CUDA training') 35 | parser.add_argument('--seed', type=int, default=1, metavar='S', 36 | help='random seed (default: 1)') 37 | 38 | args = parser.parse_args() 39 | torch.backends.cudnn.benchmark=True 40 | args.cuda = not args.no_cuda and torch.cuda.is_available() 41 | 42 | if args.cuda: 43 | torch.manual_seed(args.seed) 44 | np.random.seed(args.seed) 45 | random.seed(args.seed) 46 | torch.cuda.manual_seed(args.seed) 47 | torch.backends.cudnn.deterministic = True 48 | 49 | model = ImageReconstruction() 50 | 51 | if args.cuda: 52 | model.cuda() 53 | 54 | 55 | def main(): 56 | 57 | base_dir = args.base_dir 58 | for file1, file2 in zip(sorted(os.listdir(base_dir+'/image_1')), sorted(os.listdir(base_dir+'/image_2'))): 59 | 60 | image_1_image_path = base_dir + '/image_1/' + file1 61 | image_2_image_path = base_dir + '/image_2/' + file2 62 | 63 | image_1 = np.asarray(Image.open(image_1_image_path).convert('RGB')) 64 | image_2 = np.asarray(Image.open(image_2_image_path).convert('RGB')) 65 | 66 | image_1 = transforms.ToTensor()(image_1).unsqueeze(0).cuda().float() 67 | image_2 = transforms.ToTensor()(image_2).unsqueeze(0).cuda().float() 68 | 69 | reconstruction = model(image_1, image_2) 70 | 71 | plt.imsave(os.path.join(args.save_dir, image_1_image_path.split('/')[-1]), 72 | np.asarray(reconstruction[0].permute(1,2,0).data.cpu()*256).astype('uint16')) 73 | 74 | 75 | 76 | if __name__ == '__main__': 77 | with torch.no_grad(): 78 | main() -------------------------------------------------------------------------------- /DifferentiablePatchMatch/models/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | -------------------------------------------------------------------------------- /DifferentiablePatchMatch/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/DifferentiablePatchMatch/models/__init__.py -------------------------------------------------------------------------------- /DifferentiablePatchMatch/models/config.py: -------------------------------------------------------------------------------- 1 | 2 | # --------------------------------------------------------------------------- 3 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 4 | # 5 | # Copyright (c) 2019 Uber Technologies, Inc. 6 | # 7 | # Licensed under the Uber Non-Commercial License (the "License"); 8 | # you may not use this file except in compliance with the License. 9 | # You may obtain a copy of the License at the root directory of this project. 
10 | # 11 | # See the License for the specific language governing permissions and 12 | # limitations under the License. 13 | # 14 | # Written by Shivam Duggal 15 | # --------------------------------------------------------------------------- 16 | 17 | from __future__ import print_function 18 | 19 | class obj(object): 20 | def __init__(self, d): 21 | for key, value in d.items(): 22 | if isinstance(value, (list, tuple)): 23 | setattr(self, key, [obj(x) if isinstance(x, dict) else x for x in value]) 24 | else: 25 | setattr(self, key, obj(value) if isinstance(value, dict) else value) 26 | 27 | config = { 28 | "patch_match_args": { 29 | # sample count refers to random sampling stage of generalized PM. 30 | # Number of random samples generated: (sample_count+1) * (sample_count+1) 31 | # we generate (sample_count+1) samples in x direction, and (sample_count+1) samples in y direction, 32 | # and then perform meshgrid like opertaion to generate (sample_count+1) * (sample_count+1) samples. 33 | "sample_count": 1, 34 | 35 | "iteration_count": 21, 36 | "propagation_filter_size": 3, 37 | "propagation_type": "faster_filter_3_propagation", # for better code for PM propagation, set it to None 38 | "softmax_temperature": 10000000000, # softmax temperature for evaluation. Larger temp. lead to sharper output. 39 | "random_search_window_size": [100,100], # search range around evaluated offsets after every iteration. 40 | "evaluation_type": "softmax" 41 | }, 42 | 43 | "feature_extractor_filter_size": 7 44 | 45 | } 46 | 47 | 48 | config = obj(config) -------------------------------------------------------------------------------- /DifferentiablePatchMatch/models/feature_extractor.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | class feature_extractor(nn.Module): 23 | def __init__(self, filter_size): 24 | super(feature_extractor, self).__init__() 25 | 26 | self.filter_size = filter_size 27 | 28 | def forward(self, left_input, right_input): 29 | """ 30 | Feature Extractor 31 | 32 | Description: Aggregates the RGB values from the neighbouring pixels in the window (filter_size * filter_size). 33 | No weights are learnt for this feature extractor. 34 | 35 | Args: 36 | :param left_input: Left Image 37 | :param right_input: Right Image 38 | 39 | Returns: 40 | :left_features: Left Image features 41 | :right_features: Right Image features 42 | :one_hot_filter: Convolution filter used to aggregate neighbour RGB features to the center pixel. 
43 | one_hot_filter.shape = (filter_size * filter_size) 44 | """ 45 | 46 | device = left_input.get_device() 47 | 48 | label = torch.arange(0, self.filter_size * self.filter_size, device=device).repeat( 49 | self.filter_size * self.filter_size).view( 50 | self.filter_size * self.filter_size, 1, 1, self.filter_size, self.filter_size) 51 | 52 | one_hot_filter = torch.zeros_like(label).scatter_(0, label, 1).float() 53 | 54 | left_features = F.conv3d(left_input.unsqueeze(1), one_hot_filter, 55 | padding=(0, self.filter_size // 2, self.filter_size // 2)) 56 | right_features = F.conv3d(right_input.unsqueeze(1), one_hot_filter, 57 | padding=(0, self.filter_size // 2, self.filter_size // 2)) 58 | 59 | left_features = left_features.view(left_features.size()[0], 60 | left_features.size()[1] * left_features.size()[2], 61 | left_features.size()[3], 62 | left_features.size()[4]) 63 | 64 | right_features = right_features.view(right_features.size()[0], 65 | right_features.size()[1] * right_features.size()[2], 66 | right_features.size()[3], 67 | right_features.size()[4]) 68 | 69 | return left_features, right_features, one_hot_filter 70 | -------------------------------------------------------------------------------- /DifferentiablePatchMatch/models/image_reconstruction.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | from models.patch_match import PatchMatch 21 | from models.feature_extractor import feature_extractor 22 | from models.config import config as args 23 | 24 | 25 | class Reconstruct(nn.Module): 26 | def __init__(self, filter_size): 27 | super(Reconstruct, self).__init__() 28 | self.filter_size = filter_size 29 | 30 | def forward(self, right_input, offset_x, offset_y, x_coordinate, y_coordinate, neighbour_extraction_filter): 31 | """ 32 | Reconstruct the left image using the NNF(NNF represented by the offsets and the xy_coordinates) 33 | We did Patch Voting on the offset field, before reconstruction, in order to 34 | generate smooth reconstruction. 35 | Args: 36 | :right_input: Right Image 37 | :offset_x: horizontal offset to generate the NNF. 38 | :offset_y: vertical offset to generate the NNF. 
39 | :x_coordinate: X coordinate 40 | :y_coordinate: Y coordinate 41 | 42 | Returns: 43 | :reconstruction: Right image reconstruction 44 | """ 45 | 46 | pad_size = self.filter_size // 2 47 | smooth_offset_x = nn.ReflectionPad2d( 48 | (pad_size, pad_size, pad_size, pad_size))(offset_x) 49 | smooth_offset_y = nn.ReflectionPad2d( 50 | (pad_size, pad_size, pad_size, pad_size))(offset_y) 51 | 52 | smooth_offset_x = F.conv2d(smooth_offset_x, 53 | neighbour_extraction_filter, 54 | padding=(pad_size, pad_size))[:, :, pad_size:-pad_size, pad_size:-pad_size] 55 | 56 | smooth_offset_y = F.conv2d(smooth_offset_y, 57 | neighbour_extraction_filter, 58 | padding=(pad_size, pad_size))[:, :, pad_size:-pad_size, pad_size:-pad_size] 59 | 60 | coord_x = torch.clamp( 61 | x_coordinate - smooth_offset_x, 62 | min=0, 63 | max=smooth_offset_x.size()[3] - 1) 64 | 65 | coord_y = torch.clamp( 66 | y_coordinate - smooth_offset_y, 67 | min=0, 68 | max=smooth_offset_x.size()[2] - 1) 69 | 70 | coord_x -= coord_x.size()[3] / 2 71 | coord_x /= (coord_x.size()[3] / 2) 72 | 73 | coord_y -= coord_y.size()[2] / 2 74 | coord_y /= (coord_y.size()[2] / 2) 75 | 76 | grid = torch.cat((coord_x.unsqueeze(4), coord_y.unsqueeze(4)), dim=4) 77 | grid = grid.view(grid.size()[0] * grid.size()[1], grid.size()[2], grid.size()[3], grid.size()[4]) 78 | reconstruction = F.grid_sample(right_input.repeat(grid.size()[0], 1, 1, 1), grid) 79 | reconstruction = torch.mean(reconstruction, dim=0).unsqueeze(0) 80 | 81 | return reconstruction 82 | 83 | 84 | class ImageReconstruction(nn.Module): 85 | def __init__(self): 86 | super(ImageReconstruction, self).__init__() 87 | 88 | self.patch_match = PatchMatch(args.patch_match_args) 89 | 90 | filter_size = args.feature_extractor_filter_size 91 | self.feature_extractor = feature_extractor(filter_size) 92 | self.reconstruct = Reconstruct(filter_size) 93 | 94 | def forward(self, left_input, right_input): 95 | """ 96 | ImageReconstruction: 97 | Description: This class performs the task of reconstruction the left image using the data of the other image,, 98 | by fidning correspondences (nnf) between the two fields. 99 | The images acan be any random images with some overlap between the two to assist 100 | the correspondence matching. 101 | For feature_extractor, we just use the RGB features of a (self.filter_size * self.filter_size) patch 102 | around each pixel. 103 | For finding the correspondences, we use the Differentiable PatchMatch. 104 | ** Note: There is no assumption of rectification between the two images. ** 105 | ** Note: The words 'left' and 'right' do not have any significance.** 106 | 107 | 108 | Args: 109 | :left_input: Left Image (Image 1) 110 | :right_input: Right Image (Image 2) 111 | 112 | Returns: 113 | :reconstruction: Reconstructed left image. 
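        Shape note (for reference, as exercised by demo_script.py): left_input and
        right_input are (1, 3, H, W) RGB tensors, and the returned reconstruction has
        the same (1, 3, H, W) shape.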
114 | """ 115 | 116 | left_features, right_features, neighbour_extraction_filter = self.feature_extractor(left_input, right_input) 117 | offset_x, offset_y, x_coordinate, y_coordinate = self.patch_match(left_features, right_features) 118 | 119 | reconstruction = self.reconstruct(right_input, 120 | offset_x, offset_y, 121 | x_coordinate, y_coordinate, 122 | neighbour_extraction_filter.squeeze(1)) 123 | 124 | return reconstruction 125 | -------------------------------------------------------------------------------- /DifferentiablePatchMatch/models/patch_match.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | class RandomSampler(nn.Module): 23 | def __init__(self, device, number_of_samples): 24 | super(RandomSampler, self).__init__() 25 | 26 | # Number of offset samples generated by this function: (number_of_samples+1) * (number_of_samples+1) 27 | # we generate (number_of_samples+1) samples in x direction, and (number_of_samples+1) samples in y direction, 28 | # and then perform meshgrid like opertaion to generate (number_of_samples+1) * (number_of_samples+1) samples 29 | self.number_of_samples = number_of_samples 30 | self.range_multiplier = torch.arange(0.0, number_of_samples + 1, 1, device=device).view( 31 | number_of_samples + 1, 1, 1) 32 | 33 | def forward(self, min_offset_x, max_offset_x, min_offset_y, max_offset_y): 34 | """ 35 | Random Sampler: 36 | Given the search range per pixel (defined by: [[lx(i), ux(i)], [ly(i), uy(i)]]), 37 | where lx = lower_bound of the hoizontal offset, 38 | ux = upper_bound of the horizontal offset, 39 | ly = lower_bound of the vertical offset, 40 | uy = upper_bound of teh vertical offset, for all pixel i. ) 41 | random sampler generates samples from this search range. 42 | First the search range is discretized into `number_of_samples` buckets, 43 | then a random sample is generated from each random bucket. 44 | ** Discretization is done in both xy directions. ** (similar to meshgrid) 45 | 46 | Args: 47 | :min_offset_x: Min horizontal offset of the search range. 48 | :max_offset_x: Max horizontal offset of the search range. 49 | :min_offset_y: Min vertical offset of the search range. 50 | :max_offset_y: Max vertical offset of the search range. 51 | Returns: 52 | :offset_x: samples representing offset in the horizontal direction. 53 | :offset_y: samples representing offset in the vertical direction. 
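        Example (illustrative): with number_of_samples = 1 and min/max offset maps of
        shape (B, 1, H, W), the returned offset_x and offset_y each have shape
        (B, 4, H, W), i.e. (number_of_samples + 1)^2 candidate offsets per pixel.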
54 | """ 55 | 56 | device = min_offset_x.get_device() 57 | noise = torch.rand(min_offset_x.repeat(1, self.number_of_samples + 1, 1, 1).size(), device=device) 58 | 59 | offset_x = min_offset_x + ((max_offset_x - min_offset_x) / (self.number_of_samples + 1)) * \ 60 | (self.range_multiplier + noise) 61 | offset_y = min_offset_y + ((max_offset_y - min_offset_y) / (self.number_of_samples + 1)) * \ 62 | (self.range_multiplier + noise) 63 | 64 | offset_x = offset_x.unsqueeze_(1).expand(-1, offset_y.size()[1], -1, -1, -1) 65 | offset_x = offset_x.contiguous().view( 66 | offset_x.size()[0], offset_x.size()[1] * offset_x.size()[2], offset_x.size()[3], offset_x.size()[4]) 67 | 68 | offset_y = offset_y.unsqueeze_(2).expand(-1, -1, offset_y.size()[1], -1, -1) 69 | offset_y = offset_y.contiguous().view( 70 | offset_y.size()[0], offset_y.size()[1] * offset_y.size()[2], offset_y.size()[3], offset_y.size()[4]) 71 | 72 | return offset_x, offset_y 73 | 74 | 75 | class Evaluate(nn.Module): 76 | def __init__(self, left_features, filter_size, evaluation_type='softmax', temperature=10000): 77 | super(Evaluate, self).__init__() 78 | self.temperature = temperature 79 | self.filter_size = filter_size 80 | self.softmax = torch.nn.Softmax(dim=1) 81 | self.evaluation_type = evaluation_type 82 | 83 | device = left_features.get_device() 84 | self.left_x_coordinate = torch.arange(0.0, left_features.size()[3], device=device).repeat( 85 | left_features.size()[2]).view(left_features.size()[2], left_features.size()[3]) 86 | 87 | self.left_x_coordinate = torch.clamp(self.left_x_coordinate, min=0, max=left_features.size()[3] - 1) 88 | self.left_x_coordinate = self.left_x_coordinate.expand(left_features.size()[0], -1, -1).unsqueeze(1) 89 | 90 | self.left_y_coordinate = torch.arange(0.0, left_features.size()[2], device=device).unsqueeze(1).repeat( 91 | 1, left_features.size()[3]).view(left_features.size()[2], left_features.size()[3]) 92 | 93 | self.left_y_coordinate = torch.clamp(self.left_y_coordinate, min=0, max=left_features.size()[3] - 1) 94 | self.left_y_coordinate = self.left_y_coordinate.expand(left_features.size()[0], -1, -1).unsqueeze(1) 95 | 96 | def forward(self, left_features, right_features, offset_x, offset_y): 97 | """ 98 | PatchMatch Evaluation Block 99 | Description: For each pixel i, matching scores are computed by taking the inner product between the 100 | left feature and the right feature: score(i,j) = feature_left(i), feature_right(i+disparity(i,j)) 101 | for all candidates j. The best k disparity value for each pixel is carried towards the next iteration. 102 | 103 | As per implementation, 104 | the complete disparity search range is discretized into intervals in 105 | DisparityInitialization() function. Corresponding to each disparity interval, we have multiple samples 106 | to evaluate. The best disparity sample per interval is the output of the function. 107 | 108 | Args: 109 | :left_features: Left Image Feature Map 110 | :right_features: Right Image Feature Map 111 | :offset_x: samples representing offset in the horizontal direction. 112 | :offset_y: samples representing offset in the vertical direction. 113 | 114 | Returns: 115 | :offset_x: horizontal offset evaluated as the best offset to generate NNF. 116 | :offset_y: vertical offset evaluated as the best offset to generate NNF. 
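        Note: when evaluation_type == "softmax", the returned offsets are softmax-weighted
        averages of the candidate offsets (a soft argmax), which keeps this step
        differentiable; otherwise a hard argmax selects the single best candidate per pixel.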
117 | 118 | """ 119 | 120 | right_x_coordinate = torch.clamp(self.left_x_coordinate - offset_x, min=0, max=left_features.size()[3] - 1) 121 | right_y_coordinate = torch.clamp(self.left_y_coordinate - offset_y, min=0, max=left_features.size()[2] - 1) 122 | 123 | right_x_coordinate -= right_x_coordinate.size()[3] / 2 124 | right_x_coordinate /= (right_x_coordinate.size()[3] / 2) 125 | right_y_coordinate -= right_y_coordinate.size()[2] / 2 126 | right_y_coordinate /= (right_y_coordinate.size()[2] / 2) 127 | 128 | samples = torch.cat((right_x_coordinate.unsqueeze(4), right_y_coordinate.unsqueeze(4)), dim=4) 129 | samples = samples.view(samples.size()[0] * samples.size()[1], 130 | samples.size()[2], 131 | samples.size()[3], 132 | samples.size()[4]) 133 | 134 | offset_strength = torch.mean(-1.0 * (torch.abs(left_features.expand( 135 | offset_x.size()[1], -1, -1, -1) - F.grid_sample(right_features.expand( 136 | offset_x.size()[1], -1, -1, -1), samples))), dim=1) * self.temperature 137 | 138 | offset_strength = offset_strength.view(left_features.size()[0], 139 | offset_strength.size()[0] // left_features.size()[0], 140 | offset_strength.size()[1], 141 | offset_strength.size()[2]) 142 | 143 | if self.evaluation_type == "softmax": 144 | offset_strength = torch.softmax(offset_strength, dim=1) 145 | offset_x = torch.sum(offset_x * offset_strength, dim=1).unsqueeze(1) 146 | offset_y = torch.sum(offset_y * offset_strength, dim=1).unsqueeze(1) 147 | else: 148 | offset_strength = torch.argmax(offset_strength, dim=1).unsqueeze(1) 149 | offset_x = torch.gather(offset_x, index=offset_strength, dim=1) 150 | offset_y = torch.gather(offset_y, index=offset_strength, dim=1) 151 | 152 | return offset_x, offset_y 153 | 154 | 155 | class Propagation(nn.Module): 156 | def __init__(self, device, filter_size): 157 | super(Propagation, self).__init__() 158 | self.filter_size = filter_size 159 | label = torch.arange(0, self.filter_size, device=device).repeat(self.filter_size).view( 160 | self.filter_size, 1, 1, 1, self.filter_size) 161 | 162 | self.one_hot_filter_h = torch.zeros_like(label).scatter_(0, label, 1).float() 163 | 164 | label = torch.arange(0, self.filter_size, device=device).repeat(self.filter_size).view( 165 | self.filter_size, 1, 1, self.filter_size, 1).long() 166 | 167 | self.one_hot_filter_v = torch.zeros_like(label).scatter_(0, label, 1).float() 168 | 169 | def forward(self, offset_x, offset_y, propagation_type="horizontal"): 170 | """ 171 | PatchMatch Propagation Block 172 | Description: Particles from adjacent pixels are propagated together through convolution with a 173 | one-hot filter, which en-codes the fact that we allow each pixel 174 | to propagate particles to its 4-neighbours. 175 | Args: 176 | :offset_x: samples representing offset in the horizontal direction. 177 | :offset_y: samples representing offset in the vertical direction. 178 | :device: Cuda/ CPU device 179 | :propagation_type (default:"horizontal"): In order to be memory efficient, we use separable convolutions 180 | for propagtaion. 181 | 182 | Returns: 183 | :aggregated_offset_x: Horizontal offset samples aggregated from the neighbours. 184 | :aggregated_offset_y: Vertical offset samples aggregated from the neighbours. 
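        Output size note: with propagation filter size f and k particles per pixel on input,
        the aggregated outputs carry k * f particles per pixel (each pixel's own candidates
        plus those gathered from its neighbours along the chosen propagation direction).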
185 | 186 | """ 187 | 188 | offset_x = offset_x.view(offset_x.size()[0], 1, offset_x.size()[1], offset_x.size()[2], offset_x.size()[3]) 189 | offset_y = offset_y.view(offset_y.size()[0], 1, offset_y.size()[1], offset_y.size()[2], offset_y.size()[3]) 190 | 191 | if propagation_type is "horizontal": 192 | aggregated_offset_x = F.conv3d(offset_x, self.one_hot_filter_h, padding=(0, 0, self.filter_size // 2)) 193 | aggregated_offset_y = F.conv3d(offset_y, self.one_hot_filter_h, padding=(0, 0, self.filter_size // 2)) 194 | 195 | else: 196 | aggregated_offset_x = F.conv3d(offset_x, self.one_hot_filter_v, padding=(0, self.filter_size // 2, 0)) 197 | aggregated_offset_y = F.conv3d(offset_y, self.one_hot_filter_v, padding=(0, self.filter_size // 2, 0)) 198 | 199 | aggregated_offset_x = aggregated_offset_x.permute([0, 2, 1, 3, 4]) 200 | aggregated_offset_x = aggregated_offset_x.contiguous().view( 201 | aggregated_offset_x.size()[0], 202 | aggregated_offset_x.size()[1] * aggregated_offset_x.size()[2], 203 | aggregated_offset_x.size()[3], 204 | aggregated_offset_x.size()[4]) 205 | 206 | aggregated_offset_y = aggregated_offset_y.permute([0, 2, 1, 3, 4]) 207 | aggregated_offset_y = aggregated_offset_y.contiguous().view( 208 | aggregated_offset_y.size()[0], 209 | aggregated_offset_y.size()[1] * aggregated_offset_y.size()[2], 210 | aggregated_offset_y.size()[3], 211 | aggregated_offset_y.size()[4]) 212 | 213 | return aggregated_offset_x, aggregated_offset_y 214 | 215 | 216 | class PropagationFaster(nn.Module): 217 | def __init__(self): 218 | super(PropagationFaster, self).__init__() 219 | 220 | def forward(self, offset_x, offset_y, device, propagation_type="horizontal"): 221 | """ 222 | Faster version of PatchMatch Propagation Block 223 | This version uses a fixed propagation filter size of size 3. This implementation is not recommended 224 | and is used only to do the propagation faster. 225 | 226 | Description: Particles from adjacent pixels are propagated together through convolution with a 227 | one-hot filter, which en-codes the fact that we allow each pixel 228 | to propagate particles to its 4-neighbours. 229 | Args: 230 | :offset_x: samples representing offset in the horizontal direction. 231 | :offset_y: samples representing offset in the vertical direction. 232 | :device: Cuda/ CPU device 233 | :propagation_type (default:"horizontal"): In order to be memory efficient, we use separable convolutions 234 | for propagtaion. 235 | 236 | Returns: 237 | :aggregated_offset_x: Horizontal offset samples aggregated from the neighbours. 238 | :aggregated_offset_y: Vertical offset samples aggregated from the neighbours. 
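        Output size note: this variant hard-codes a filter size of 3, so with k particles per
        pixel on input the returned tensors carry 3 * k particles per pixel (one copy shifted
        in from each neighbour along the chosen direction, plus the pixel's own candidates).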
239 | 240 | """ 241 | 242 | self.vertical_zeros = torch.zeros((offset_x.size()[0], offset_x.size()[1], 1, offset_x.size()[3])).to(device) 243 | self.horizontal_zeros = torch.zeros((offset_x.size()[0], offset_x.size()[1], offset_x.size()[2], 1)).to(device) 244 | 245 | if propagation_type is "horizontal": 246 | offset_x = torch.cat((torch.cat((self.horizontal_zeros, offset_x[:, :, :, :-1]), dim=3), 247 | offset_x, 248 | torch.cat((offset_x[:, :, :, 1:], self.horizontal_zeros), dim=3)), dim=1) 249 | offset_y = torch.cat((torch.cat((self.horizontal_zeros, offset_y[:, :, :, :-1]), dim=3), 250 | offset_y, 251 | torch.cat((offset_y[:, :, :, 1:], self.horizontal_zeros), dim=3)), dim=1) 252 | 253 | else: 254 | offset_x = torch.cat((torch.cat((self.vertical_zeros, offset_x[:, :, :-1, :]), dim=2), 255 | offset_x, 256 | torch.cat((offset_x[:, :, 1:, :], self.vertical_zeros), dim=2)), dim=1) 257 | offset_y = torch.cat((torch.cat((self.vertical_zeros, offset_y[:, :, :-1, :]), dim=2), 258 | offset_y, 259 | torch.cat((offset_y[:, :, 1:, :], self.vertical_zeros), dim=2)), dim=1) 260 | 261 | return offset_x, offset_y 262 | 263 | 264 | class PatchMatch(nn.Module): 265 | def __init__(self, patch_match_args): 266 | super(PatchMatch, self).__init__() 267 | self.propagation_filter_size = patch_match_args.propagation_filter_size 268 | self.number_of_samples = patch_match_args.sample_count 269 | self.iteration_count = patch_match_args.iteration_count 270 | self.evaluation_type = patch_match_args.evaluation_type 271 | self.softmax_temperature = patch_match_args.softmax_temperature 272 | self.propagation_type = patch_match_args.propagation_type 273 | 274 | self.window_size_x = patch_match_args.random_search_window_size[0] 275 | self.window_size_y = patch_match_args.random_search_window_size[1] 276 | 277 | def forward(self, left_features, right_features): 278 | """ 279 | Differential PatchMatch Block 280 | Description: In this work, we unroll generalized PatchMatch as a recurrent neural network, 281 | where each unrolling step is equivalent to each iteration of the algorithm. 282 | This is important as it allow us to train our full model end-to-end. 283 | Specifically, we design the following layers: 284 | - Initialization or Paticle Sampling 285 | - Propagation 286 | - Evaluation 287 | Args: 288 | :left_features: Left Image feature map 289 | :right_features: Right image feature map 290 | 291 | Returns: 292 | :offset_x: offset for each pixel in the left_features corresponding to the 293 | right_features in the horizontal direction. 294 | :offset_y: offset for each pixel in the left_features corresponding to the 295 | right_features in the vertical direction. 296 | 297 | :x_coordinate: X coordinate corresponding to each pxiel. 298 | :y_coordinate: Y coordinate corresponding to each pxiel. 299 | 300 | (Offsets and the xy_cooridnates returned are used to generated the NNF field later for reconstruction.) 
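        Iteration note: each of the iteration_count unrolled steps runs particle sampling,
        horizontal propagation followed by evaluation, then vertical propagation followed by
        evaluation, and finally shrinks the random-search window around the current best
        offsets for the next iteration.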
301 | 302 | """ 303 | 304 | device = left_features.get_device() 305 | if self.propagation_type is "faster_filter_3_propagation": 306 | self.propagation = PropagationFaster() 307 | else: 308 | self.propagation = Propagation(device, self.propagation_filter_size) 309 | 310 | self.evaluate = Evaluate(left_features, self.propagation_filter_size, 311 | self.evaluation_type, self.softmax_temperature) 312 | self.uniform_sampler = RandomSampler(device, self.number_of_samples) 313 | 314 | min_offset_x = torch.zeros((left_features.size()[0], 1, left_features.size()[2], 315 | left_features.size()[3])).to(device) - left_features.size()[3] 316 | max_offset_x = min_offset_x + 2 * left_features.size()[3] 317 | min_offset_y = min_offset_x + left_features.size()[3] - left_features.size()[2] 318 | max_offset_y = min_offset_y + 2 * left_features.size()[2] 319 | 320 | for prop_iter in range(self.iteration_count): 321 | offset_x, offset_y = self.uniform_sampler(min_offset_x, max_offset_x, 322 | min_offset_y, max_offset_y) 323 | 324 | offset_x, offset_y = self.propagation(offset_x, offset_y, device, "horizontal") 325 | offset_x, offset_y = self.evaluate(left_features, 326 | right_features, 327 | offset_x, offset_y) 328 | 329 | offset_x, offset_y = self.propagation(offset_x, offset_y, device, "vertical") 330 | offset_x, offset_y = self.evaluate(left_features, 331 | right_features, 332 | offset_x, offset_y) 333 | 334 | min_offset_x = torch.clamp(offset_x - self.window_size_x // 2, min=-left_features.size()[3], 335 | max=left_features.size()[3]) 336 | max_offset_x = torch.clamp(offset_x + self.window_size_x // 2, min=-left_features.size()[3], 337 | max=left_features.size()[3]) 338 | min_offset_y = torch.clamp(offset_y - self.window_size_y // 2, min=-left_features.size()[2], 339 | max=left_features.size()[2]) 340 | max_offset_y = torch.clamp(offset_y + self.window_size_y // 2, min=-left_features.size()[2], 341 | max=left_features.size()[2]) 342 | 343 | return offset_x, offset_y, self.evaluate.left_x_coordinate, self.evaluate.left_y_coordinate 344 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by the text below. 2 | 3 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. 4 | 5 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. 6 | 7 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 8 | 9 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. 10 | 11 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under this License. 
12 | 13 | This License governs use of the accompanying Work, and your use of the Work constitutes acceptance of this License. 14 | 15 | You may use this Work for any non-commercial purpose, subject to the restrictions in this License. Some purposes which can be non-commercial are teaching, academic research, and personal experimentation. You may also distribute this Work with books or other teaching materials, or publish the Work on websites, that are intended to teach the use of the Work. 16 | 17 | You may not use or distribute this Work, or any derivative works, outputs, or results from the Work, in any form for commercial purposes. Non-exhaustive examples of commercial purposes would be running business operations, licensing, leasing, or selling the Work, or distributing the Work for use with commercial products. 18 | 19 | You may modify this Work and distribute the modified Work for non-commercial purposes, however, you may not grant rights to the Work or derivative works that are broader than or in conflict with those provided by this License. For example, you may not distribute modifications of the Work under terms that would permit commercial use, or under terms that purport to require the Work or derivative works to be sublicensed to others. 20 | 21 | In return, we require that you agree: 22 | 23 | 1. Not to remove any copyright or other notices from the Work. 24 | 25 | 2. That if you distribute the Work in Source or Object form, you will include a verbatim copy of this License. 26 | 27 | 3. That if you distribute derivative works of the Work in Source form, you do so only under a license that includes all of the provisions of this License and is not in conflict with this License, and if you distribute derivative works of the Work solely in Object form you do so only under a license that complies with this License. 28 | 29 | 4. That if you have modified the Work or created derivative works from the Work, and distribute such modifications or derivative works, you will cause the modified files to carry prominent notices so that recipients know that they are not receiving the original Work. Such notices must state: (i) that you have changed the Work; and (ii) the date of any changes. 30 | 31 | 5. If you publicly use the Work or any output or result of the Work, you will provide a notice with such use that provides any person who uses, views, accesses, interacts with, or is otherwise exposed to the Work (i) with information of the nature of the Work, (ii) with a link to the Work, and (iii) a notice that the Work is available under this License. 32 | 33 | 6. THAT THE WORK COMES "AS IS", WITH NO WARRANTIES. THIS MEANS NO EXPRESS, IMPLIED OR STATUTORY WARRANTY, INCLUDING WITHOUT LIMITATION, WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE OR ANY WARRANTY OF TITLE OR NON-INFRINGEMENT. ALSO, YOU MUST PASS THIS DISCLAIMER ON WHENEVER YOU DISTRIBUTE THE WORK OR DERIVATIVE WORKS. 34 | 35 | 7. THAT NEITHER UBER TECHNOLOGIES, INC. NOR ANY OF ITS AFFILIATES, SUPPLIERS, SUCCESSORS, NOR ASSIGNS WILL BE LIABLE FOR ANY DAMAGES RELATED TO THE WORK OR THIS LICENSE, INCLUDING DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL OR INCIDENTAL DAMAGES, TO THE MAXIMUM EXTENT THE LAW PERMITS, NO MATTER WHAT LEGAL THEORY IT IS BASED ON. ALSO, YOU MUST PASS THIS LIMITATION OF LIABILITY ON WHENEVER YOU DISTRIBUTE THE WORK OR DERIVATIVE WORKS. 36 | 37 | 8. That if you sue anyone over patents that you think may apply to the Work or anyone's use of the Work, your license to the Work ends automatically. 
38 | 39 | 9. That your rights under the License end automatically if you breach it in any way. 40 | 41 | 10. Uber Technologies, Inc. reserves all rights not expressly granted to you in this License. 42 | -------------------------------------------------------------------------------- /NOTICE: -------------------------------------------------------------------------------- 1 | "DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch" includes derivied work from: 2 | 3 | 1. PSMNet (https://github.com/JiaRenChang/PSMNet) 4 | 5 | MIT License 6 | 7 | Copyright (c) 2018 Jia-Ren Chang 8 | 9 | Permission is hereby granted, free of charge, to any person obtaining a copy 10 | of this software and associated documentation files (the "Software"), to deal 11 | in the Software without restriction, including without limitation the rights 12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 13 | copies of the Software, and to permit persons to whom the Software is 14 | furnished to do so, subject to the following conditions: 15 | 16 | The above copyright notice and this permission notice shall be included in all 17 | copies or substantial portions of the Software. 18 | 19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 25 | SOFTWARE. 26 | 27 | 28 | The derived work can be found in the files: 29 | 30 | 1. DeepPruner/dataloader/readpfm.py 31 | 2. DeepPruner/dataloader/preprocess.py 32 | 3. DeepPruner/dataloader/kitti_submission_collector.py 33 | 4. DeepPruner/dataloader/kitti_loader.py 34 | 5. DeepPruner/dataloader/kitti_collector.py 35 | 6. DeepPruner/dataloader/sceneflow_loader.py 36 | 7. DeepPruner/dataloader/sceneflow_collector.py 37 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 2 | 3 | This repository releases code for our paper [DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch](https://arxiv.org/abs/1909.05845). 4 | 5 | ##### Table of Contents 6 | [DeepPruner](#DeepPruner) 7 | [Differentiable Patch Match](#DifferentiablePatchMatch) 8 | [Requirements (Major Dependencies)](#Requirements) 9 | [Citation](#Citation) 10 | 11 | 12 | 13 | ### **DeepPruner** 14 | 15 | + An efficient "Real Time Stereo Matching" algorithm, which takes as input 2 images and outputs a disparity (or depth) map. 16 | 17 | 18 | 19 | ![](readme_images/DeepPruner.png) 20 | 21 | 22 | 23 | + Results/ Metrics: 24 | 25 | + [**KITTI**](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo): **Results competitive to SOTA, while being real-time (8x faster than SOTA). SOTA among published real-time algorithms**. 26 | 27 | 28 | ![](readme_images/KITTI_test_set.png) 29 | ![](readme_images/CRP.png) 30 | ![](readme_images/uncertainty_vis.png) 31 | 32 | 33 | + [**ETH3D**](https://www.eth3d.net/low_res_two_view?mask=all&metric=bad-2-0): **SOTA among all ROB entries**. 
34 | 35 | + **SceneFlow**: **2nd among all published algorithms, while being 8x faster than the 1st.** 36 | 37 |

38 | 39 |

40 | 41 | 42 | 43 | + [**Robust Vision Challenge**](http://www.robustvision.net/index.php): **Overall ranking 1st**. 44 | 45 |

46 | 47 |

48 | 49 | + Runtime: **62ms** (for DeepPruner-fast), **180ms** (for DeepPruner-best) 50 | 51 | + CUDA Memory Requirements: **805MB** (for DeepPruner-best) 52 | 53 | 54 | 55 | 56 | ### **Differentiable Patch Match** 57 | + A fast algorithm for finding dense nearest-neighbor correspondences between patches of image regions. 58 | A differentiable version of the generalized PatchMatch algorithm ([Barnes et al.](https://gfx.cs.princeton.edu/pubs/Barnes_2010_TGP/index.php)). 59 | 60 |

61 | 62 |

63 | 64 | More details in the corresponding folder README. 65 | 66 | 67 | 68 | ## Requirements (Major Dependencies) 69 | + Pytorch (0.4.1+) 70 | + Python2.7 71 | + torchvision (0.2.0+) 72 | 73 | 74 | 75 | 76 | ## Citation 77 | 78 | If you use our source code, or our paper, please consider citing the following: 79 | > @inproceedings{Duggal2019ICCV, 80 | title = {DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch}, 81 | author = {Shivam Duggal and Shenlong Wang and Wei-Chiu Ma and Rui Hu and Raquel Urtasun}, 82 | booktitle = {ICCV}, 83 | year = {2019} 84 | } 85 | 86 | Correspondences to Shivam Duggal , Shenlong Wang , Wei-Chiu Ma 87 | -------------------------------------------------------------------------------- /deeppruner/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.log -------------------------------------------------------------------------------- /deeppruner/README.md: -------------------------------------------------------------------------------- 1 | # DeepPruner 2 | 3 | This is the code repository for **DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch**. 4 | ![](../readme_images/DeepPruner.png) 5 | 6 | 7 | ##### Table of Contents 8 | 9 | [Requirements](#Requirements) 10 | [License](#License) 11 | [Model Weights](#Weights) 12 | [Training and Evaluation](#TrainingEvaluation) 13 | - [KITTI](#KITTI) 14 | - [Sceneflow](#Sceneflow) 15 | - [Robust Vision Challenge](#ROB) 16 | 17 | 18 | ## Requirements 19 | 20 | + Pytorch (0.4.1+) 21 | + python (2.7) 22 | + scikit-image 23 | + tensorboardX 24 | + torchvision (0.2.0+) 25 | 26 | 27 | ## License 28 | 29 | + The souce code for DeepPruner and Differentiable PatchMatch are released under the © Uber, 2018-2019. Licensed under the Uber Non-Commercial License. 30 | + The trained model-weights for DeepPruner are released under the license Creative Commons Attribution-NonCommercial-ShareAlike 3.0 License. 31 | 32 | 33 | 34 | ## Model Weights 35 | 36 | DeepPruner was first trained on Sceneflow dataset and then finetuned on KITTI (Combined 394 images of KITTI-2012 and KITTI-2015) dataset. 37 | 38 | + DeepPruner-fast (KITTI) 39 | + DeepPruner-best (KITTI) 40 | + DeepPruner-fast (Sceneflow) 41 | + DeepPruner-best (Sceneflow) 42 | 43 | 44 | 45 | ## Training and Evaluation 46 | 47 | **NOTE:** We allow the users to modify a bunch of model parameters and training setting for their own purpose. You may need to retain the model with the modified parameters. 48 | Check **'models/config.py'** for more details. 49 | 50 | 51 | 52 | ### KITTI Stereo 2012/2015: 53 | 54 | KITTI 2015 has 200 stereo-pairs with ground truth disparities. We used 160 out of these 200 for training and the remaining 40 for validation. The training set was further augmented by 194 stereo image pairs from KITTI 2012 dataset. 55 | 56 | #### Setup: 57 | 58 | 1. Download the KITTI 2012 and KITTI 2015 datasets. 59 | 2. Split KITTI Stereo 2015 training dataset into "training" (160 pairs) and "validation" (40 pairs), following the same directory structure as of the original dataset. 
**Make sure to have the following structure**: 60 | 61 | > training_directory_stereo_2015 \ 62 | > |----- image_2 \ 63 | > |----- image_3 \ 64 | > |----- disp_occ_0 \ 65 | > val_directory_stereo_2015 \ 66 | > |----- image_2 \ 67 | > |----- image_3 \ 68 | > |----- disp_occ_0 \ 69 | > train_directory_stereo_2012 \ 70 | > |----- colored_0 \ 71 | > |----- colored_1 \ 72 | > |----- disp_occ \ 73 | > test_directory \ 74 | > |----- image_2 \ 75 | > |----- image_3 76 | 77 | 3. Note that, any other dataset could be used for training, validation and testing. The directory structure should be same. 78 | 79 | #### Training Command: 80 | 1. Like previous works, we fine-tuned the pre-trained Sceneflow model on KITTI dataset. 81 | 2. Training Command: 82 | 83 | > python finetune_kitti.py \\\ 84 | > --loadmodel \\\ 85 | > --savemodel \\\ 86 | > --train_datapath_2015 \\\ 87 | > --val_datapath_2015 \\\ 88 | > --datapath_2012 89 | 90 | 3. Training command arguments: 91 | 92 | + --loadmodel (default: None): If not set, the model would train from scratch. 93 | + --savemodel (default: './'): If not set, the script will save the trained models after every epoch in the same directory. 94 | + --train_datapath_2015 (default: None): If not set, KITTI stereo 2015 dataset won't be used for training. 95 | + --val_datapath_2015 (default: None): If not set, the script would fail. The validation dataset should have atleast one image to run. 96 | + --datapath_2012 (default: None): If not set, KITTI stereo 2012 dataset won't be used for training.4. Training and validation tensorboard runs will be saved in **'./runs/' directory**. 97 | 98 | #### Evaluation Command (On Any Dataset): 99 | 1. We used KITTI 2015 Stereo testing set for evaluation. **(Note any other dataset could be used.)** 100 | 2. To evaluate DeepPruner on any dataset, just create a base directory like: 101 | > test_directory \ 102 | > |----- image_2 \ 103 | > |----- image_3 104 | 105 | image_2 folder holds the left images, while image_3 folder holds the right images. 106 | 107 | 3. For evaluation, update the "mode" parameter in "models.config.py" to "evaluation". 108 | 109 | 4. Evaluation command: 110 | 111 | > python submission_kitti.py \ 112 | > --loadmodel \ 113 | > --save_dir \ 114 | > --datapath 115 | 116 | #### Metrics/ Results: 117 | 118 | 1. The metrics used for evaluation are same as provided by the [KITTI Stereo benchmark](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo). 119 | 2. The quantitaive results obtained by DeepPruner are as follows: (Note the standings in the Tables below are at the time of March 2019.) 120 | 121 |

122 | 123 |

124 | 125 | 3. Alongside learning the disparity (or depth maps), DeepPruner is able to predict the uncertain regions (occluded regions , bushy regions, object edges) efficiently . Since the uncertainty in prediction correlates well with the error in the disparity maps (Figure 7.), such uncertainty can be used in other downstream tasks. 126 | 127 | 4. Qualitative results are as follows: 128 | 129 | 130 | ![](../readme_images/KITTI_test_set.png) 131 | ![](../readme_images/CRP.png) 132 | ![](../readme_images/uncertainty_vis.png) 133 | 134 | 135 | 136 | 137 | ### Sceneflow: 138 | 139 | #### Setup: 140 | 141 | 1. Download [Sceneflow dataset](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html#downloads), which consists of FlyingThings3D, Driving and Monkaa (RGB images (cleanpass) and disparity). 142 | 143 | 2. We followed the same directory structure as of the downloaded data. Check **dataloader/sceneflow_collector.py**. 144 | 145 | #### Training/ Evaluation Command: 146 | 147 | > python train_sceneflow.py \\\ 148 | > --loadmodel \\\ 149 | > --save_dir \\\ 150 | > --savemodel \\\ 151 | > --datapath_monkaa \\\ 152 | > --datapath_flying \\\ 153 | > --datapath_driving 154 | 155 | #### Metrics/ Results: 156 | 157 | 1. We used EPE(end point error) as one of the metrics. 158 | 159 |
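For reference, EPE (end-point error) is the mean absolute difference between the predicted and ground-truth disparity over valid pixels. A minimal sketch follows; the validity-mask convention here is an assumption for illustration, not the repository's exact evaluation code:

```python
import torch

def end_point_error(pred_disp, gt_disp, max_disp=192):
    # Average absolute disparity error over valid ground-truth pixels.
    mask = (gt_disp > 0) & (gt_disp < max_disp)  # assumed convention: 0 marks missing ground truth
    return torch.mean(torch.abs(pred_disp[mask] - gt_disp[mask]))
```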

160 | 161 |

162 | 163 |

164 | 165 |

166 | 167 | 168 | 169 | 170 | 171 | 172 | 173 | ### Robust Vision Challenge: 174 | 175 | #### Details: 176 | 177 | 1. The goal of the Robust Vision Challenge is to foster the development of vision systems that are robust and 178 | consequently perform well on a variety of datasets with different characteristics. 179 | Please refer to the Robust Vision Challenge website for more details. 180 | 181 | 2. We used the pre-trained Sceneflow model and then jointly fine-tuned it on the KITTI, ETH3D and Middlebury datasets. 182 | 183 | #### Setup 184 | 185 | 1. Dataloader and setup details coming soon. 186 | 187 | #### Metrics/ Results: 188 | 189 | Check DeepPruner_ROB on the KITTI benchmark. \ 190 | Check DeepPruner_ROB on the ETH3D benchmark. \ 191 | Check DeepPruner_ROB on the MiddleburyV3 benchmark. 192 | 193 | 194 |

195 | 196 |

197 | 198 |

199 | 200 |

201 | -------------------------------------------------------------------------------- /deeppruner/dataloader/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc -------------------------------------------------------------------------------- /deeppruner/dataloader/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/deeppruner/dataloader/__init__.py -------------------------------------------------------------------------------- /deeppruner/dataloader/kitti_collector.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch.utils.data as data 18 | import os 19 | import glob 20 | 21 | 22 | def datacollector(train_filepath, val_filepath, filepath_2012): 23 | 24 | left_fold = 'image_2/' 25 | right_fold = 'image_3/' 26 | disp_L = 'disp_occ_0/' 27 | disp_R = 'disp_occ_1/' 28 | 29 | left_fold_2012 = 'colored_0/' 30 | right_fold_2012 = 'colored_1/' 31 | disp_L_2012 = 'disp_occ/' 32 | 33 | left_train = [] 34 | right_train = [] 35 | disp_train_L = [] 36 | 37 | left_val = [] 38 | right_val = [] 39 | disp_val_L = [] 40 | 41 | 42 | if train_filepath is not None: 43 | left_train = sorted(glob.glob(os.path.join(train_filepath, left_fold, '*.png'))) 44 | right_train = sorted(glob.glob(os.path.join(train_filepath, right_fold, '*.png'))) 45 | disp_train_L = sorted(glob.glob(os.path.join(train_filepath, disp_L, '*.png'))) 46 | 47 | if filepath_2012 is not None: 48 | left_train +=sorted(glob.glob(os.path.join(filepath_2012, left_fold_2012, '*_10.png'))) 49 | right_train += sorted(glob.glob(os.path.join(filepath_2012, right_fold_2012, '*_10.png'))) 50 | disp_train_L += sorted(glob.glob(os.path.join(filepath_2012, disp_L_2012, '*_10.png'))) 51 | 52 | if val_filepath is not None: 53 | left_val = sorted(glob.glob(os.path.join(val_filepath, left_fold, '*.png'))) 54 | right_val = sorted(glob.glob(os.path.join(val_filepath, right_fold, '*.png'))) 55 | disp_val_L = sorted(glob.glob(os.path.join(val_filepath, disp_L, '*.png'))) 56 | 57 | return left_train, right_train, disp_train_L, left_val, right_val, disp_val_L 58 | -------------------------------------------------------------------------------- /deeppruner/dataloader/kitti_loader.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 
5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch.utils.data as data 18 | import random 19 | from PIL import Image 20 | import numpy as np 21 | from dataloader import preprocess 22 | 23 | 24 | def default_loader(path): 25 | return Image.open(path).convert('RGB') 26 | 27 | 28 | def disparity_loader(path): 29 | return Image.open(path) 30 | 31 | # train/ validation image crop size constants 32 | DEFAULT_TRAIN_IMAGE_HEIGHT = 256 33 | DEFAULT_TRAIN_IMAGE_WIDTH = 512 34 | 35 | DEFAULT_VAL_IMAGE_HEIGHT = 320 36 | DEFAULT_VAL_IMAGE_WIDTH = 1216 37 | 38 | 39 | class KITTILoader(data.Dataset): 40 | def __init__(self, left_images, right_images, left_disparity, training, loader=default_loader, dploader=disparity_loader): 41 | 42 | self.left_img = left_images 43 | self.right_img = right_images 44 | self.left_disp = left_disparity 45 | self.loader = loader 46 | self.dploader = dploader 47 | self.training = training 48 | 49 | def __getitem__(self, index): 50 | left_img = self.left_img[index] 51 | right_img = self.right_img[index] 52 | left_disp = self.left_disp[index] 53 | 54 | left_img = self.loader(left_img) 55 | right_img = self.loader(right_img) 56 | left_disp = self.dploader(left_disp) 57 | w, h = left_img.size 58 | 59 | 60 | if self.training: 61 | th, tw = DEFAULT_TRAIN_IMAGE_HEIGHT, DEFAULT_TRAIN_IMAGE_WIDTH 62 | x1 = random.randint(0, w - tw) 63 | y1 = random.randint(0, h - th) 64 | 65 | else: 66 | th, tw = DEFAULT_VAL_IMAGE_HEIGHT, DEFAULT_VAL_IMAGE_WIDTH 67 | x1 = w - DEFAULT_VAL_IMAGE_WIDTH 68 | y1 = h - DEFAULT_VAL_IMAGE_HEIGHT 69 | 70 | left_img = left_img.crop((x1, y1, x1 + tw, y1 + th)) 71 | right_img = right_img.crop((x1, y1, x1 + tw, y1 + th)) 72 | left_disp = left_disp.crop((x1, y1, x1 + tw, y1 + th)) 73 | left_disp = np.ascontiguousarray(left_disp, dtype=np.float32) / 256 74 | 75 | 76 | processed = preprocess.get_transform() 77 | left_img = processed(left_img) 78 | right_img = processed(right_img) 79 | 80 | return left_img, right_img, left_disp 81 | 82 | def __len__(self): 83 | return len(self.left_img) 84 | -------------------------------------------------------------------------------- /deeppruner/dataloader/kitti_submission_collector.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import os 18 | 19 | 20 | def datacollector(filepath): 21 | left_fold = 'image_2/' 22 | right_fold = 'image_3/' 23 | disp = 'disp_occ_0/' 24 | 25 | image = [img for img in sorted(os.listdir(os.path.join(filepath,left_fold))) if img.find('.png') > -1] 26 | 27 | left_test = [os.path.join(filepath, left_fold, img) for img in image] 28 | right_test = [os.path.join(filepath, right_fold, img) for img in image] 29 | 30 | return left_test, right_test 31 | -------------------------------------------------------------------------------- /deeppruner/dataloader/preprocess.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torchvision.transforms as transforms 18 | 19 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406], 20 | 'std': [0.229, 0.224, 0.225]} 21 | 22 | 23 | def get_transform(): 24 | 25 | normalize = __imagenet_stats 26 | t_list = [ 27 | transforms.ToTensor(), 28 | transforms.Normalize(**normalize), 29 | ] 30 | 31 | return transforms.Compose(t_list) 32 | -------------------------------------------------------------------------------- /deeppruner/dataloader/readpfm.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import re 18 | import numpy as np 19 | import sys 20 | 21 | 22 | def readPFM(file): 23 | file = open(file, 'rb') 24 | 25 | color = None 26 | width = None 27 | height = None 28 | scale = None 29 | endian = None 30 | 31 | header = file.readline().rstrip() 32 | if header == 'PF': 33 | color = True 34 | elif header == 'Pf': 35 | color = False 36 | else: 37 | raise Exception('Not a PFM file.') 38 | 39 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline()) 40 | if dim_match: 41 | width, height = map(int, dim_match.groups()) 42 | else: 43 | raise Exception('Malformed PFM header.') 44 | 45 | scale = float(file.readline().rstrip()) 46 | if scale < 0: # little-endian 47 | endian = '<' 48 | scale = -scale 49 | else: 50 | endian = '>' # big-endian 51 | 52 | data = np.fromfile(file, endian + 'f') 53 | shape = (height, width, 3) if color else (height, width) 54 | 55 | data = np.reshape(data, shape) 56 | data = np.flipud(data) 57 | return data, scale 58 | -------------------------------------------------------------------------------- /deeppruner/dataloader/sceneflow_collector.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch.utils.data as data 18 | from PIL import Image 19 | import os 20 | import os.path 21 | import logging 22 | 23 | IMG_EXTENSIONS = [ 24 | '.jpg', '.JPG', '.jpeg', '.JPEG', 25 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP', 26 | ] 27 | 28 | 29 | def is_image_file(filename): 30 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS) 31 | 32 | 33 | def dataloader(filepath_monkaa, filepath_flying, filepath_driving): 34 | 35 | try: 36 | monkaa_path = os.path.join(filepath_monkaa, 'monkaa_frames_cleanpass') 37 | monkaa_disp = os.path.join(filepath_monkaa, 'monkaa_disparity') 38 | monkaa_dir = os.listdir(monkaa_path) 39 | 40 | all_left_img = [] 41 | all_right_img = [] 42 | all_left_disp = [] 43 | test_left_img = [] 44 | test_right_img = [] 45 | test_left_disp = [] 46 | 47 | for dd in monkaa_dir: 48 | for im in os.listdir(os.path.join(monkaa_path, dd, 'left')): 49 | if is_image_file(os.path.join(monkaa_path, dd, 'left', im)) and is_image_file( 50 | os.path.join(monkaa_path, dd, 'right', im)): 51 | all_left_img.append(os.path.join(monkaa_path, dd, 'left', im)) 52 | all_left_disp.append(os.path.join(monkaa_disp, dd, 'left', im.split(".")[0] + '.pfm')) 53 | all_right_img.append(os.path.join(monkaa_path, dd, 'right', im)) 54 | 55 | except: 56 | logging.error("Some error in Monkaa, Monkaa might not be loaded correctly in this case...") 57 | raise Exception('Monkaa dataset couldn\'t be loaded correctly.') 58 | 59 | 60 | try: 61 | flying_path = os.path.join(filepath_flying, 'frames_cleanpass') 62 | flying_disp = os.path.join(filepath_flying, 'disparity') 63 | flying_dir = flying_path + '/TRAIN/' 64 | subdir = ['A', 'B', 'C'] 65 | 66 | for ss in subdir: 67 | flying = os.listdir(os.path.join(flying_dir, ss)) 68 | 69 | for ff in flying: 70 | imm_l = os.listdir(os.path.join(flying_dir, ss, ff, 'left')) 71 | for im in imm_l: 72 | if is_image_file(os.path.join(flying_dir, ss, ff, 'left', im)): 73 | all_left_img.append(os.path.join(flying_dir, ss, ff, 'left', im)) 74 | 75 | all_left_disp.append(os.path.join(flying_disp, 'TRAIN', ss, ff, 'left', im.split(".")[0] + '.pfm')) 76 | 77 | if is_image_file(os.path.join(flying_dir, ss, ff, 'right', im)): 78 | all_right_img.append(os.path.join(flying_dir, ss, ff, 'right', im)) 79 | 80 | flying_dir = flying_path + '/TEST/' 81 | subdir = ['A', 'B', 'C'] 82 | 83 | for ss in subdir: 84 | flying = os.listdir(os.path.join(flying_dir, ss)) 85 | 86 | for ff in flying: 87 | imm_l = os.listdir(os.path.join(flying_dir, ss, ff, 'left')) 88 | for im in imm_l: 89 | if is_image_file(os.path.join(flying_dir, ss, ff, 'left', im)): 90 | test_left_img.append(os.path.join(flying_dir, ss, ff, 'left', im)) 91 | 92 | test_left_disp.append(os.path.join(flying_disp, 'TEST', ss, ff, 'left', im.split(".")[0] + '.pfm')) 93 | 94 | if is_image_file(os.path.join(flying_dir, ss, ff, 'right', im)): 95 | test_right_img.append(os.path.join(flying_dir, ss, ff, 'right', im)) 96 | 97 | except: 98 | logging.error("Some error in Flying Things, Flying Things might not be loaded correctly in this case...") 99 | raise Exception('Flying Things dataset couldn\'t be loaded correctly.') 100 | 101 | try: 102 | driving_dir = os.path.join(filepath_driving, 'driving_frames_cleanpass/') 103 | driving_disp = os.path.join(filepath_driving, 'driving_disparity/') 104 | 105 | subdir1 = ['35mm_focallength', 
'15mm_focallength'] 106 | subdir2 = ['scene_backwards', 'scene_forwards'] 107 | subdir3 = ['fast', 'slow'] 108 | 109 | for i in subdir1: 110 | for j in subdir2: 111 | for k in subdir3: 112 | imm_l = os.listdir(os.path.join(driving_dir, i, j, k, 'left')) 113 | for im in imm_l: 114 | if is_image_file(os.path.join(driving_dir, i, j, k, 'left', im)): 115 | all_left_img.append(os.path.join(driving_dir, i, j, k, 'left', im)) 116 | all_left_disp.append(os.path.join(driving_disp, i, j, k, 'left', im.split(".")[0] + '.pfm')) 117 | 118 | if is_image_file(os.path.join(driving_dir, i, j, k, 'right', im)): 119 | all_right_img.append(os.path.join(driving_dir, i, j, k, 'right', im)) 120 | except: 121 | logging.error("Some error in Driving, Driving might not be loaded correctly in this case...") 122 | raise Exception('Driving dataset couldn\'t be loaded correctly.') 123 | 124 | return all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp 125 | -------------------------------------------------------------------------------- /deeppruner/dataloader/sceneflow_loader.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch.utils.data as data 18 | import random 19 | from PIL import Image 20 | from dataloader import preprocess 21 | from dataloader import readpfm as rp 22 | import numpy as np 23 | import math 24 | 25 | # train/ validation image crop size constants 26 | DEFAULT_TRAIN_IMAGE_HEIGHT = 256 27 | DEFAULT_TRAIN_IMAGE_WIDTH = 512 28 | 29 | def default_loader(path): 30 | return Image.open(path).convert('RGB') 31 | 32 | 33 | def disparity_loader(path): 34 | return rp.readPFM(path) 35 | 36 | 37 | class SceneflowLoader(data.Dataset): 38 | def __init__(self, left_images, right_images, left_disparity, downsample_scale, training, loader=default_loader, dploader=disparity_loader): 39 | 40 | self.left_images = left_images 41 | self.right_images = right_images 42 | self.left_disparity = left_disparity 43 | self.loader = loader 44 | self.dploader = dploader 45 | self.training = training 46 | 47 | # downsample_scale denotes maximum times the image features are downsampled 48 | # by the network. 49 | # Since the image size used for evaluation may not be divisible by the downsample_scale, 50 | # we pad it with zeros, so that it becomes divible and later unpad the extra zeros. 
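        # Illustrative numbers (not from the original source), assuming downsample_scale = 8
        # and an input of w = 960, h = 540:
        #   dw = 960 + (8 - (960 % 8 + (960 % 8 == 0) * 8)) = 960   -> width already divisible
        #   dh = 540 + (8 - (540 % 8 + (540 % 8 == 0) * 8)) = 544   -> 4 rows of padding needed
        # The crop box (w - dw, h - dh, w, h) then extends 4 rows above the image, which PIL
        # fills with zeros; the returned offsets (dw - w, dh - h) let the caller strip the
        # padding again after inference.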
51 | self.downsample_scale = downsample_scale 52 | 53 | def __getitem__(self, index): 54 | left_img = self.left_images[index] 55 | right_img = self.right_images[index] 56 | left_disp = self.left_disparity[index] 57 | 58 | left_img = self.loader(left_img) 59 | right_img = self.loader(right_img) 60 | left_disp, left_scale = self.dploader(left_disp) 61 | left_disp = np.ascontiguousarray(left_disp, dtype=np.float32) 62 | 63 | if self.training: 64 | w, h = left_img.size 65 | th, tw = DEFAULT_TRAIN_IMAGE_HEIGHT, DEFAULT_TRAIN_IMAGE_WIDTH 66 | 67 | x1 = random.randint(0, w - tw) 68 | y1 = random.randint(0, h - th) 69 | 70 | left_img = left_img.crop((x1, y1, x1 + tw, y1 + th)) 71 | right_img = right_img.crop((x1, y1, x1 + tw, y1 + th)) 72 | left_disp = left_disp[y1:y1 + th, x1:x1 + tw] 73 | 74 | processed = preprocess.get_transform() 75 | left_img = processed(left_img) 76 | right_img = processed(right_img) 77 | 78 | return left_img, right_img, left_disp 79 | else: 80 | w, h = left_img.size 81 | 82 | dw = w + (self.downsample_scale - (w%self.downsample_scale + (w%self.downsample_scale==0)*self.downsample_scale)) 83 | dh = h + (self.downsample_scale - (h%self.downsample_scale + (h%self.downsample_scale==0)*self.downsample_scale)) 84 | 85 | left_img = left_img.crop((w - dw, h - dh, w, h)) 86 | right_img = right_img.crop((w - dw, h - dh, w, h)) 87 | 88 | processed = preprocess.get_transform() 89 | left_img = processed(left_img) 90 | right_img = processed(right_img) 91 | 92 | return left_img, right_img, left_disp, dw-w, dh-h 93 | 94 | def __len__(self): 95 | return len(self.left_images) 96 | -------------------------------------------------------------------------------- /deeppruner/finetune_kitti.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | 17 | from __future__ import print_function 18 | import argparse 19 | import os 20 | import random 21 | from collections import namedtuple 22 | import torch 23 | import torch.nn as nn 24 | import torch.nn.parallel 25 | import torch.backends.cudnn as cudnn 26 | import torch.optim as optim 27 | import torch.utils.data 28 | from torch.autograd import Variable 29 | import skimage 30 | import skimage.transform 31 | import numpy as np 32 | from dataloader import kitti_collector as ls 33 | from dataloader import kitti_loader as DA 34 | from models.deeppruner import DeepPruner 35 | from tensorboardX import SummaryWriter 36 | from torchvision import transforms 37 | from loss_evaluation import loss_evaluation 38 | from models.config import config as config_args 39 | import matplotlib.pyplot as plt 40 | import logging 41 | from setup_logging import setup_logging 42 | 43 | parser = argparse.ArgumentParser(description='DeepPruner') 44 | parser.add_argument('--train_datapath_2015', default=None, 45 | help='training data path of KITTI 2015') 46 | parser.add_argument('--datapath_2012', default=None, 47 | help='data path of KITTI 2012 (all used for training)') 48 | parser.add_argument('--val_datapath_2015', default=None, 49 | help='validation data path of KITTI 2015') 50 | parser.add_argument('--epochs', type=int, default=1040, 51 | help='number of epochs to train') 52 | parser.add_argument('--loadmodel', default=None, 53 | help='load model') 54 | parser.add_argument('--savemodel', default='./', 55 | help='save model') 56 | parser.add_argument('--logging_filename', default='./finetune_kitti.log', 57 | help='filename for logs') 58 | parser.add_argument('--no-cuda', action='store_true', default=False, 59 | help='enables CUDA training') 60 | parser.add_argument('--seed', type=int, default=1, metavar='S', 61 | help='random seed (default: 1)') 62 | 63 | args = parser.parse_args() 64 | args.cuda = not args.no_cuda and torch.cuda.is_available() 65 | torch.manual_seed(args.seed) 66 | if args.cuda: 67 | torch.manual_seed(args.seed) 68 | np.random.seed(args.seed) 69 | random.seed(args.seed) 70 | torch.cuda.manual_seed(args.seed) 71 | torch.backends.cudnn.deterministic = True 72 | 73 | 74 | args.cost_aggregator_scale = config_args.cost_aggregator_scale 75 | 76 | setup_logging(args.logging_filename) 77 | 78 | all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.datacollector( 79 | args.train_datapath_2015, args.val_datapath_2015, args.datapath_2012) 80 | 81 | 82 | TrainImgLoader = torch.utils.data.DataLoader( 83 | DA.KITTILoader(all_left_img, all_right_img, all_left_disp, True), 84 | batch_size=16, shuffle=True, num_workers=8, drop_last=False) 85 | 86 | TestImgLoader = torch.utils.data.DataLoader( 87 | DA.KITTILoader(test_left_img, test_right_img, test_left_disp, False), 88 | batch_size=8, shuffle=False, num_workers=4, drop_last=False) 89 | 90 | model = DeepPruner() 91 | writer = SummaryWriter() 92 | model = nn.DataParallel(model) 93 | 94 | 95 | if args.cuda: 96 | model.cuda() 97 | 98 | if args.loadmodel is not None: 99 | logging.info("loading model...") 100 | state_dict = torch.load(args.loadmodel) 101 | model.load_state_dict(state_dict['state_dict'], strict=True) 102 | 103 | 104 | logging.info('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 105 | optimizer = optim.Adam(model.parameters(), lr=0.1, 
betas=(0.9, 0.999)) 106 | 107 | 108 | def train(imgL, imgR, disp_L, iteration, epoch): 109 | if epoch >= 800: 110 | model.eval() 111 | else: 112 | model.train() 113 | 114 | imgL = Variable(torch.FloatTensor(imgL)) 115 | imgR = Variable(torch.FloatTensor(imgR)) 116 | disp_L = Variable(torch.FloatTensor(disp_L)) 117 | 118 | if args.cuda: 119 | imgL, imgR, disp_true = imgL.cuda(), imgR.cuda(), disp_L.cuda() 120 | 121 | mask = (disp_true > 0) 122 | mask.detach_() 123 | 124 | optimizer.zero_grad() 125 | result = model(imgL,imgR) 126 | 127 | loss, _ = loss_evaluation(result, disp_true, mask, args.cost_aggregator_scale) 128 | 129 | loss.backward() 130 | optimizer.step() 131 | 132 | return loss.item() 133 | 134 | 135 | 136 | def test(imgL,imgR,disp_L,iteration): 137 | 138 | model.eval() 139 | with torch.no_grad(): 140 | imgL = Variable(torch.FloatTensor(imgL)) 141 | imgR = Variable(torch.FloatTensor(imgR)) 142 | disp_L = Variable(torch.FloatTensor(disp_L)) 143 | 144 | if args.cuda: 145 | imgL, imgR, disp_true = imgL.cuda(), imgR.cuda(), disp_L.cuda() 146 | 147 | mask = (disp_true > 0) 148 | mask.detach_() 149 | 150 | optimizer.zero_grad() 151 | 152 | result = model(imgL,imgR) 153 | loss, output_disparity = loss_evaluation(result, disp_true, mask, args.cost_aggregator_scale) 154 | 155 | #computing 3-px error: (source psmnet)# 156 | true_disp = disp_true.data.cpu() 157 | disp_true = true_disp 158 | pred_disp = output_disparity.data.cpu() 159 | 160 | index = np.argwhere(true_disp>0) 161 | disp_true[index[0][:], index[1][:], index[2][:]] = np.abs(true_disp[index[0][:], index[1][:], index[2][:]]-pred_disp[index[0][:], index[1][:], index[2][:]]) 162 | correct = (disp_true[index[0][:], index[1][:], index[2][:]] < 3)|(disp_true[index[0][:], index[1][:], index[2][:]] < true_disp[index[0][:], index[1][:], index[2][:]]*0.05) 163 | torch.cuda.empty_cache() 164 | 165 | loss = 1-(float(torch.sum(correct))/float(len(index[0]))) 166 | 167 | return loss 168 | 169 | 170 | def adjust_learning_rate(optimizer, epoch): 171 | if epoch <= 500: 172 | lr = 0.0001 173 | elif epoch<=1000: 174 | lr = 0.00005 175 | else: 176 | lr = 0.00001 177 | logging.info('learning rate = %.5f' %(lr)) 178 | for param_group in optimizer.param_groups: 179 | param_group['lr'] = lr 180 | 181 | 182 | def main(): 183 | 184 | for epoch in range(0, args.epochs): 185 | total_train_loss = 0 186 | total_test_loss = 0 187 | adjust_learning_rate(optimizer,epoch) 188 | 189 | if epoch %1==0 and epoch!=0: 190 | for batch_idx, (imgL, imgR, disp_L) in enumerate(TestImgLoader): 191 | test_loss = test(imgL,imgR,disp_L,batch_idx) 192 | total_test_loss += test_loss 193 | logging.info('Iter %d 3-px error in val = %.3f \n' %(batch_idx, test_loss)) 194 | 195 | logging.info('epoch %d total test loss = %.3f' %(epoch, total_test_loss/len(TestImgLoader))) 196 | writer.add_scalar("val-loss",total_test_loss/len(TestImgLoader),epoch) 197 | 198 | for batch_idx, (imgL_crop, imgR_crop, disp_crop_L) in enumerate(TrainImgLoader): 199 | loss = train(imgL_crop,imgR_crop,disp_crop_L,batch_idx,epoch) 200 | total_train_loss += loss 201 | logging.info('Iter %d training loss = %.3f \n' %(batch_idx, loss)) 202 | 203 | logging.info('epoch %d total training loss = %.3f' %(epoch, total_train_loss/len(TrainImgLoader))) 204 | writer.add_scalar("loss",total_train_loss/len(TrainImgLoader),epoch) 205 | 206 | # SAVE 207 | if epoch%1==0: 208 | savefilename = args.savemodel+'finetune_'+str(epoch)+'.tar' 209 | torch.save({ 210 | 'epoch': epoch, 211 | 'state_dict': model.state_dict(), 212 | 
'train_loss': total_train_loss, 213 | 'test_loss': total_test_loss, 214 | }, savefilename) 215 | 216 | 217 | if __name__ == '__main__': 218 | main() 219 | -------------------------------------------------------------------------------- /deeppruner/loss_evaluation.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch.nn.functional as F 18 | from collections import namedtuple 19 | import logging 20 | 21 | loss_weights = { 22 | 'alpha_super_refined': 1.6, 23 | 'alpha_refined': 1.3, 24 | 'alpha_ca': 1.0, 25 | 'alpha_quantile': 1.0, 26 | 'alpha_min_max': 0.7 27 | } 28 | 29 | loss_weights = namedtuple('loss_weights', loss_weights.keys())(*loss_weights.values()) 30 | 31 | def loss_evaluation(result, disp_true, mask, cost_aggregator_scale=4): 32 | 33 | # forces min_disparity to be equal or slightly lower than the true disparity 34 | quantile_mask1 = ((disp_true[mask] - result[-1][mask]) < 0).float() 35 | quantile_loss1 = (disp_true[mask] - result[-1][mask]) * (0.05 - quantile_mask1) 36 | quantile_min_disparity_loss = quantile_loss1.mean() 37 | 38 | # forces max_disparity to be equal or slightly larger than the true disparity 39 | quantile_mask2 = ((disp_true[mask] - result[-2][mask]) < 0).float() 40 | quantile_loss2 = (disp_true[mask] - result[-2][mask]) * (0.95 - quantile_mask2) 41 | quantile_max_disparity_loss = quantile_loss2.mean() 42 | 43 | min_disparity_loss = F.smooth_l1_loss(result[-1][mask], disp_true[mask], size_average=True) 44 | max_disparity_loss = F.smooth_l1_loss(result[-2][mask], disp_true[mask], size_average=True) 45 | ca_depth_loss = F.smooth_l1_loss(result[-3][mask], disp_true[mask], size_average=True) 46 | refined_depth_loss = F.smooth_l1_loss(result[-4][mask], disp_true[mask], size_average=True) 47 | 48 | logging.info("============== evaluated losses ==================") 49 | if cost_aggregator_scale == 8: 50 | refined_depth_loss_1 = F.smooth_l1_loss(result[-5][mask], disp_true[mask], size_average=True) 51 | loss = (loss_weights.alpha_super_refined * refined_depth_loss_1) 52 | output_disparity = result[-5] 53 | logging.info('refined_depth_loss_1: %.6f', refined_depth_loss_1) 54 | else: 55 | loss = 0 56 | output_disparity = result[-4] 57 | 58 | loss += (loss_weights.alpha_refined * refined_depth_loss) + \ 59 | (loss_weights.alpha_ca * ca_depth_loss) + \ 60 | (loss_weights.alpha_quantile * (quantile_max_disparity_loss + quantile_min_disparity_loss)) + \ 61 | (loss_weights.alpha_min_max * (min_disparity_loss + max_disparity_loss)) 62 | 63 | logging.info('refined_depth_loss: %.6f' % refined_depth_loss) 64 | logging.info('ca_depth_loss: %.6f' % ca_depth_loss) 65 | logging.info('quantile_loss_max_disparity: %.6f' % quantile_max_disparity_loss) 66 | logging.info('quantile_loss_min_disparity: %.6f' % 
quantile_min_disparity_loss) 67 | logging.info('max_disparity_loss: %.6f' % max_disparity_loss) 68 | logging.info('min_disparity_loss: %.6f' % min_disparity_loss) 69 | logging.info("==================================================\n") 70 | 71 | return loss, output_disparity 72 | -------------------------------------------------------------------------------- /deeppruner/models/.gitignore: -------------------------------------------------------------------------------- 1 | *.pyc 2 | *.log -------------------------------------------------------------------------------- /deeppruner/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/deeppruner/models/__init__.py -------------------------------------------------------------------------------- /deeppruner/models/config.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | 17 | from __future__ import print_function 18 | 19 | class obj(object): 20 | def __init__(self, d): 21 | for key, value in d.items(): 22 | if isinstance(value, (list, tuple)): 23 | setattr(self, key, [obj(x) if isinstance(x, dict) else x for x in value]) 24 | else: 25 | setattr(self, key, obj(value) if isinstance(value, dict) else value) 26 | 27 | 28 | config = { 29 | "max_disp": 192, 30 | "cost_aggregator_scale": 4, # for DeepPruner-fast change this to 8. 31 | "mode": "training", # for evaluation/ submission, change this to evaluation. 32 | 33 | 34 | # The code allows the user to change the feature extrcator to any feature extractor of their choice. 35 | # The only requirements of the feature extractor are: 36 | # 1. For cost_aggregator_scale == 4: 37 | # features at downsample-level X4 (feature_extractor_ca_level) 38 | # and downsample-level X2 (feature_extractor_refinement_level) should be the output. 39 | # For cost_aggregator_scale == 8: 40 | # features at downsample-level X8 (feature_extractor_ca_level), 41 | # downsample-level X4 (feature_extractor_refinement_level), 42 | # downsample-level X2 (feature_extractor_refinement_level_1) should be the output, 43 | 44 | # 2. If the feature extractor is modified, change the "feature_extractor_outplanes_*" key in the config 45 | # accordingly. 46 | 47 | "feature_extractor_ca_level_outplanes": 32, 48 | "feature_extractor_refinement_level_outplanes": 32, # for DeepPruner-fast change this to 64. 49 | "feature_extractor_refinement_level_1_outplanes": 32, 50 | 51 | "patch_match_args": { 52 | "sample_count": 12, 53 | "iteration_count": 2, 54 | "propagation_filter_size": 3 55 | }, 56 | 57 | "post_CRP_sample_count": 7, 58 | "post_CRP_sampler_type": "uniform", #change to patch_match for Sceneflow model. 
59 | 60 | "hourglass_inplanes": 16 61 | } 62 | 63 | config = obj(config) 64 | -------------------------------------------------------------------------------- /deeppruner/models/deeppruner.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | from models.submodules3d import MinDisparityPredictor, MaxDisparityPredictor, CostAggregator 18 | from models.submodules2d import RefinementNet 19 | from models.submodules import SubModule, conv_relu, convbn_2d_lrelu, convbn_3d_lrelu 20 | from models.utils import SpatialTransformer, UniformSampler 21 | from models.patch_match import PatchMatch 22 | import torch 23 | import torch.nn as nn 24 | import torch.nn.functional as F 25 | from models.config import config as args 26 | 27 | class DeepPruner(SubModule): 28 | def __init__(self): 29 | super(DeepPruner, self).__init__() 30 | self.scale = args.cost_aggregator_scale 31 | self.max_disp = args.max_disp // self.scale 32 | self.mode = args.mode 33 | 34 | self.patch_match_args = args.patch_match_args 35 | self.patch_match_sample_count = self.patch_match_args.sample_count 36 | self.patch_match_iteration_count = self.patch_match_args.iteration_count 37 | self.patch_match_propagation_filter_size = self.patch_match_args.propagation_filter_size 38 | 39 | self.post_CRP_sample_count = args.post_CRP_sample_count 40 | self.post_CRP_sampler_type = args.post_CRP_sampler_type 41 | hourglass_inplanes = args.hourglass_inplanes 42 | 43 | # refinement input features are composed of: 44 | # left image low level features + 45 | # CA output features + CA output disparity 46 | 47 | if self.scale == 8: 48 | from models.feature_extractor_fast import feature_extraction 49 | refinement_inplanes_1 = args.feature_extractor_refinement_level_1_outplanes + 1 50 | self.refinement_net1 = RefinementNet(refinement_inplanes_1) 51 | else: 52 | from models.feature_extractor_best import feature_extraction 53 | 54 | refinement_inplanes = args.feature_extractor_refinement_level_outplanes + self.post_CRP_sample_count + 2 + 1 55 | self.refinement_net = RefinementNet(refinement_inplanes) 56 | 57 | # cost_aggregator_inplanes are composed of: 58 | # left and right image features from feature_extractor (ca_level) + 59 | # features from min/max predictors + 60 | # min_disparity + max_disparity + disparity_samples 61 | 62 | cost_aggregator_inplanes = 2 * (args.feature_extractor_ca_level_outplanes + 63 | self.patch_match_sample_count + 2) + 1 64 | self.cost_aggregator = CostAggregator(cost_aggregator_inplanes, hourglass_inplanes) 65 | 66 | self.feature_extraction = feature_extraction() 67 | self.min_disparity_predictor = MinDisparityPredictor(hourglass_inplanes) 68 | self.max_disparity_predictor = MaxDisparityPredictor(hourglass_inplanes) 69 | self.spatial_transformer = SpatialTransformer() 70 | 
self.patch_match = PatchMatch(self.patch_match_propagation_filter_size) 71 | self.uniform_sampler = UniformSampler() 72 | 73 | # Confidence Range Predictor(CRP) input features are composed of: 74 | # left and right image features from feature_extractor (ca_level) + 75 | # disparity_samples 76 | 77 | CRP_feature_count = 2 * args.feature_extractor_ca_level_outplanes + 1 78 | self.dres0 = nn.Sequential(convbn_3d_lrelu(CRP_feature_count, 64, 3, 1, 1), 79 | convbn_3d_lrelu(64, 32, 3, 1, 1)) 80 | 81 | self.dres1 = nn.Sequential(convbn_3d_lrelu(32, 32, 3, 1, 1), 82 | convbn_3d_lrelu(32, hourglass_inplanes, 3, 1, 1)) 83 | 84 | self.min_disparity_conv = conv_relu(1, 1, 5, 1, 2) 85 | self.max_disparity_conv = conv_relu(1, 1, 5, 1, 2) 86 | self.ca_disparity_conv = conv_relu(1, 1, 5, 1, 2) 87 | 88 | self.ca_features_conv = convbn_2d_lrelu(self.post_CRP_sample_count + 2, 89 | self. post_CRP_sample_count + 2, 5, 1, 2, dilation=1, bias=True) 90 | self.min_disparity_features_conv = convbn_2d_lrelu(self.patch_match_sample_count + 2, 91 | self.patch_match_sample_count + 2, 5, 1, 2, dilation=1, bias=True) 92 | self.max_disparity_features_conv = convbn_2d_lrelu(self.patch_match_sample_count + 2, 93 | self.patch_match_sample_count + 2, 5, 1, 2, dilation=1, bias=True) 94 | 95 | self.weight_init() 96 | 97 | 98 | def generate_search_range(self, left_input, sample_count, stage, 99 | input_min_disparity=None, input_max_disparity=None): 100 | """ 101 | Description: Generates the disparity search range depending upon the stage it is called. 102 | If stage is "pre" (Pre-PatchMatch and Pre-ConfidenceRangePredictor), the search range is 103 | the entire disparity search range. 104 | If stage is "post" (Post-ConfidenceRangePredictor), then the ConfidenceRangePredictor search range 105 | is adjusted for maximum efficiency. 106 | Args: 107 | :left_input: Left Image Features 108 | :sample_count: number of samples to be generated from the search range. Used to adjust the search range. 109 | :stage: "pre"(Pre-PatchMatch) or "post"(Post-ConfidenceRangePredictor) 110 | :input_min_disparity (default:None): ConfidenceRangePredictor disparity lowerbound (for stage=="post") 111 | :input_max_disparity (default:None): ConfidenceRangePredictor disparity upperbound (for stage=="post") 112 | 113 | Returns: 114 | :min_disparity: Lower bound of disparity search range 115 | :max_disparity: Upper bound of disaprity search range. 
116 | """ 117 | 118 | device = left_input.get_device() 119 | if stage is "pre": 120 | min_disparity = torch.zeros((left_input.size()[0], 1, left_input.size()[2], left_input.size()[3]), 121 | device=device) 122 | max_disparity = torch.zeros((left_input.size()[0], 1, left_input.size()[2], left_input.size()[3]), 123 | device=device) + self.max_disp 124 | 125 | else: 126 | min_disparity1 = torch.min(input_min_disparity, input_max_disparity) 127 | max_disparity1 = torch.max(input_min_disparity, input_max_disparity) 128 | 129 | # if (max_disparity1 - min_disparity1) > sample_count: 130 | # sample uniformly "sample_count" number of samples from (min_disparity1, max_disparity1) 131 | # else: 132 | # stretch min_disparity1 and max_disparity1 such that (max_disparity1 - min_disparity1) == sample_count 133 | 134 | min_disparity = torch.clamp(min_disparity1 - torch.clamp(( 135 | sample_count - max_disparity1 + min_disparity1), min=0) / 2.0, min=0, max=self.max_disp) 136 | max_disparity = torch.clamp(max_disparity1 + torch.clamp( 137 | sample_count - max_disparity1 + min_disparity, min=0), min=0, max=self.max_disp) 138 | 139 | return min_disparity, max_disparity 140 | 141 | def generate_disparity_samples(self, left_input, right_input, min_disparity, 142 | max_disparity, sample_count=12, sampler_type="patch_match"): 143 | """ 144 | Description: Generates "sample_count" number of disparity samples from the 145 | search range (min_disparity, max_disparity) 146 | Samples are generated either uniformly from the search range 147 | or are generated using PatchMatch. 148 | 149 | Args: 150 | :left_input: Left Image features. 151 | :right_input: Right Image features. 152 | :min_disparity: LowerBound of the disaprity search range. 153 | :max_disparity: UpperBound of the disparity search range. 154 | :sample_count (default:12): Number of samples to be generated from the input search range. 155 | :sampler_type (default:"patch_match"): samples are generated either using 156 | "patch_match" or "uniform" sampler. 157 | Returns: 158 | :disparity_samples: 159 | """ 160 | if sampler_type is "patch_match": 161 | disparity_samples = self.patch_match(left_input, right_input, min_disparity, 162 | max_disparity, sample_count, self.patch_match_iteration_count) 163 | else: 164 | disparity_samples = self.uniform_sampler(min_disparity, max_disparity, sample_count) 165 | 166 | disparity_samples = torch.cat((torch.floor(min_disparity), disparity_samples, torch.ceil(max_disparity)), 167 | dim=1).long() 168 | return disparity_samples 169 | 170 | def cost_volume_generator(self, left_input, right_input, disparity_samples): 171 | """ 172 | Description: Generates cost-volume using left image features, disaprity samples 173 | and warped right image features. 174 | Args: 175 | :left_input: Left Image fetaures 176 | :right_input: Right Image features 177 | :disparity_samples: Disaprity samples 178 | 179 | Returns: 180 | :cost_volume: 181 | :disaprity_samples: 182 | :left_feature_map: 183 | """ 184 | 185 | right_feature_map, left_feature_map = self.spatial_transformer(left_input, 186 | right_input, disparity_samples) 187 | disparity_samples = disparity_samples.unsqueeze(1).float() 188 | 189 | cost_volume = torch.cat((left_feature_map, right_feature_map, disparity_samples), dim=1) 190 | 191 | return cost_volume, disparity_samples, left_feature_map 192 | 193 | def confidence_range_predictor(self, cost_volume, disparity_samples): 194 | """ 195 | Description: The original search space for all pixels is identical. 
However, in practice, for each 196 | pixel, the highly probable disparities lie in a narrow region. Using the small subset 197 | of disparities estimated from the PatchMatch stage, we have sufficient information to 198 | predict the range in which the true disparity lies. We thus exploit a confidence range 199 | prediction network to adjust the search space for each pixel. 200 | 201 | Args: 202 | :cost_volume: Input Cost-Volume 203 | :disparity_samples: Initial Disparity samples. 204 | 205 | Returns: 206 | :min_disparity: ConfidenceRangePredictor disparity lowerbound 207 | :max_disparity: ConfidenceRangePredictor disparity upperbound 208 | :min_disparity_features: features from ConfidenceRangePredictor-Min 209 | :max_disparity_features: features from ConfidenceRangePredictor-Max 210 | """ 211 | # cost-volume bottleneck layers 212 | cost_volume = self.dres0(cost_volume) 213 | cost_volume = self.dres1(cost_volume) 214 | 215 | min_disparity, min_disparity_features = self.min_disparity_predictor(cost_volume, 216 | disparity_samples.squeeze(1)) 217 | 218 | max_disparity, max_disparity_features = self.max_disparity_predictor(cost_volume, 219 | disparity_samples.squeeze(1)) 220 | 221 | min_disparity = self.min_disparity_conv(min_disparity) 222 | max_disparity = self.max_disparity_conv(max_disparity) 223 | min_disparity_features = self.min_disparity_features_conv(min_disparity_features) 224 | max_disparity_features = self.max_disparity_features_conv(max_disparity_features) 225 | 226 | return min_disparity, max_disparity, min_disparity_features, max_disparity_features 227 | 228 | def forward(self, left_input, right_input): 229 | """ 230 | DeepPruner 231 | Description: DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 232 | 233 | Args: 234 | :left_input: Left Stereo Image 235 | :right_input: Right Stereo Image 236 | Returns: 237 | outputs depend of args.mode ("evaluation or "training"), and 238 | also on args.cost_aggregator_scale (8 or 4) 239 | 240 | All possible outputs can be: 241 | :refined_disparity_1: DeepPruner disparity output after Refinement1 stage. 242 | s (only when args.cost_aggregator_scale==8) 243 | :refined_disparity: DeepPruner disparity output after Refinement stage. 244 | :ca_disparity: DeepPruner disparity output after 3D-Cost Aggregation stage. 
245 | :max_disparity: DeepPruner disparity by Confidence Range Predictor (Max) 246 | :min_disparity: DeepPruner disparity by Confidence Range Predictor (Min) 247 | 248 | """ 249 | 250 | if self.scale == 8: 251 | left_spp_features, left_low_level_features, left_low_level_features_1 = self.feature_extraction(left_input) 252 | right_spp_features, right_low_level_features, _ = self.feature_extraction( 253 | right_input) 254 | else: 255 | left_spp_features, left_low_level_features = self.feature_extraction(left_input) 256 | right_spp_features, right_low_level_features = self.feature_extraction(right_input) 257 | 258 | min_disparity, max_disparity = self.generate_search_range( 259 | left_spp_features, 260 | sample_count=self.patch_match_sample_count, stage="pre") 261 | 262 | disparity_samples = self.generate_disparity_samples( 263 | left_spp_features, 264 | right_spp_features, min_disparity, max_disparity, 265 | sample_count=self.patch_match_sample_count, sampler_type="patch_match") 266 | 267 | cost_volume, disparity_samples, _ = self.cost_volume_generator(left_spp_features, 268 | right_spp_features, 269 | disparity_samples) 270 | 271 | min_disparity, max_disparity, min_disparity_features, max_disparity_features = \ 272 | self.confidence_range_predictor(cost_volume, disparity_samples) 273 | 274 | stretched_min_disparity, stretched_max_disparity = self.generate_search_range( 275 | left_spp_features, 276 | sample_count=self.post_CRP_sample_count, stage='post', 277 | input_min_disparity=min_disparity, input_max_disparity=max_disparity) 278 | 279 | disparity_samples = self.generate_disparity_samples( 280 | left_spp_features, 281 | right_spp_features, stretched_min_disparity, stretched_max_disparity, 282 | sample_count=self.post_CRP_sample_count, sampler_type=self.post_CRP_sampler_type) 283 | 284 | cost_volume, disparity_samples, expanded_left_feature_map = self.cost_volume_generator( 285 | left_spp_features, 286 | right_spp_features, 287 | disparity_samples) 288 | 289 | min_disparity_features = min_disparity_features.unsqueeze(2).expand(-1, -1, 290 | expanded_left_feature_map.size()[2], -1, -1) 291 | max_disparity_features = max_disparity_features.unsqueeze(2).expand(-1, -1, 292 | expanded_left_feature_map.size()[2], -1, -1) 293 | 294 | cost_volume = torch.cat((cost_volume, min_disparity_features, max_disparity_features), dim=1) 295 | ca_disparity, ca_features = self.cost_aggregator(cost_volume, disparity_samples.squeeze(1)) 296 | 297 | ca_disparity = F.interpolate(ca_disparity * 2, scale_factor=(2, 2), mode='bilinear') 298 | ca_features = F.interpolate(ca_features, scale_factor=(2, 2), mode='bilinear') 299 | ca_disparity = self.ca_disparity_conv(ca_disparity) 300 | ca_features = self.ca_features_conv(ca_features) 301 | 302 | refinement_net_input = torch.cat((left_low_level_features, ca_features, ca_disparity), dim=1) 303 | refined_disparity = self.refinement_net(refinement_net_input, ca_disparity) 304 | 305 | refined_disparity = F.interpolate(refined_disparity * 2, scale_factor=(2, 2), mode='bilinear') 306 | 307 | if self.scale == 8: 308 | refinement_net_input = torch.cat((left_low_level_features_1, refined_disparity), dim=1) 309 | refined_disparity_1 = self.refinement_net1(refinement_net_input, refined_disparity) 310 | 311 | if self.mode == 'evaluation': 312 | if self.scale == 8: 313 | refined_disparity_1 = F.interpolate(refined_disparity_1 * 2, scale_factor=(2, 2), 314 | mode='bilinear').squeeze(1) 315 | return refined_disparity_1 316 | return refined_disparity.squeeze(1) 317 | 318 | 
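        # In training mode, the intermediate disparities below are upsampled to the input
        # resolution and returned together. loss_evaluation.py consumes them by position:
        # result[-1] = min_disparity, result[-2] = max_disparity, result[-3] = ca_disparity,
        # result[-4] = refined_disparity, and result[-5] = refined_disparity_1 when
        # cost_aggregator_scale == 8.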
min_disparity = F.interpolate(min_disparity * self.scale, scale_factor=(self.scale, self.scale), 319 | mode='bilinear').squeeze(1) 320 | max_disparity = F.interpolate(max_disparity * self.scale, scale_factor=(self.scale, self.scale), 321 | mode='bilinear').squeeze(1) 322 | ca_disparity = F.interpolate(ca_disparity * (self.scale // 2), 323 | scale_factor=((self.scale // 2), (self.scale // 2)), mode='bilinear').squeeze(1) 324 | 325 | if self.scale == 8: 326 | refined_disparity = F.interpolate(refined_disparity * 2, scale_factor=(2, 2), mode='bilinear').squeeze(1) 327 | refined_disparity_1 = F.interpolate(refined_disparity_1 * 2, 328 | scale_factor=(2, 2), mode='bilinear').squeeze(1) 329 | 330 | return refined_disparity_1, refined_disparity, ca_disparity, max_disparity, min_disparity 331 | 332 | return refined_disparity.squeeze(1), ca_disparity, max_disparity, min_disparity 333 | -------------------------------------------------------------------------------- /deeppruner/models/feature_extractor_best.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch 18 | import torch.nn as nn 19 | import torch.utils.data 20 | import torch.nn.functional as F 21 | from models.submodules import BasicBlock, convbn_relu 22 | 23 | 24 | class feature_extraction(nn.Module): 25 | def __init__(self): 26 | super(feature_extraction, self).__init__() 27 | self.inplanes = 32 28 | self.firstconv = nn.Sequential(convbn_relu(3, 32, 3, 2, 1, 1), 29 | convbn_relu(32, 32, 3, 1, 1, 1), 30 | convbn_relu(32, 32, 3, 1, 1, 1)) 31 | 32 | self.layer1 = self._make_layer(BasicBlock, 32, 3, 1, 1, 1) 33 | self.layer2 = self._make_layer(BasicBlock, 64, 16, 2, 1, 1) 34 | self.layer3 = self._make_layer(BasicBlock, 128, 3, 1, 1, 1) 35 | self.layer4 = self._make_layer(BasicBlock, 128, 3, 1, 1, 2) 36 | 37 | self.branch1 = nn.Sequential(nn.AvgPool2d((64, 64), stride=(64, 64)), 38 | convbn_relu(128, 32, 1, 1, 0, 1)) 39 | 40 | self.branch2 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32, 32)), 41 | convbn_relu(128, 32, 1, 1, 0, 1)) 42 | 43 | self.branch3 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16, 16)), 44 | convbn_relu(128, 32, 1, 1, 0, 1)) 45 | 46 | self.branch4 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8, 8)), 47 | convbn_relu(128, 32, 1, 1, 0, 1)) 48 | 49 | self.lastconv = nn.Sequential(convbn_relu(320, 128, 3, 1, 1, 1), 50 | nn.Conv2d(128, 32, kernel_size=1, padding=0, stride=1, bias=False)) 51 | 52 | def _make_layer(self, block, planes, blocks, stride, pad, dilation): 53 | downsample = None 54 | if stride != 1 or self.inplanes != planes * block.expansion: 55 | downsample = nn.Sequential( 56 | nn.Conv2d(self.inplanes, planes * block.expansion, 57 | kernel_size=1, stride=stride, bias=False), 58 | nn.BatchNorm2d(planes * block.expansion),) 59 | 60 | layers = 
[] 61 | layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation)) 62 | self.inplanes = planes * block.expansion 63 | for i in range(1, blocks): 64 | layers.append(block(self.inplanes, planes, 1, None, pad, dilation)) 65 | 66 | return nn.Sequential(*layers) 67 | 68 | def forward(self, input): 69 | """ 70 | Feature Extractor 71 | Description: The goal of the feature extraction network is to produce a reliable pixel-wise 72 | feature representation from the input image. Specifically, we employ four residual blocks 73 | and use X2 dilated convolution for the last block to enlarge the receptive field. 74 | We then apply spatial pyramid pooling to build a 4-level pyramid feature. 75 | Through multi-scale information, the model is able to capture large context while 76 | maintaining a high spatial resolution. The size of the final feature map is 1/4 of 77 | the originalinput image size. We share the parameters for the left and right feature network. 78 | 79 | Args: 80 | :input: Input image (RGB) 81 | 82 | Returns: 83 | :output_feature: spp_features (downsampled X4) 84 | :output1: low_level_features (downsampled X2) 85 | """ 86 | 87 | output0 = self.firstconv(input) 88 | output1 = self.layer1(output0) 89 | output_raw = self.layer2(output1) 90 | output = self.layer3(output_raw) 91 | output_skip = self.layer4(output) 92 | 93 | output_branch1 = self.branch1(output_skip) 94 | output_branch1 = F.upsample(output_branch1, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 95 | 96 | output_branch2 = self.branch2(output_skip) 97 | output_branch2 = F.upsample(output_branch2, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 98 | 99 | output_branch3 = self.branch3(output_skip) 100 | output_branch3 = F.upsample(output_branch3, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 101 | 102 | output_branch4 = self.branch4(output_skip) 103 | output_branch4 = F.upsample(output_branch4, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 104 | 105 | output_feature = torch.cat( 106 | (output_raw, 107 | output_skip, 108 | output_branch4, 109 | output_branch3, 110 | output_branch2, 111 | output_branch1), 112 | 1) 113 | output_feature = self.lastconv(output_feature) 114 | 115 | return output_feature, output1 116 | -------------------------------------------------------------------------------- /deeppruner/models/feature_extractor_fast.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch 18 | import torch.nn as nn 19 | import torch.utils.data 20 | import torch.nn.functional as F 21 | from models.submodules import BasicBlock, convbn_relu 22 | 23 | 24 | class feature_extraction(nn.Module): 25 | def __init__(self): 26 | super(feature_extraction, self).__init__() 27 | self.inplanes = 32 28 | self.firstconv = nn.Sequential(convbn_relu(3, 32, 3, 2, 1, 1), 29 | convbn_relu(32, 32, 3, 1, 1, 1), 30 | convbn_relu(32, 32, 3, 1, 1, 1)) 31 | 32 | self.layer1 = self._make_layer(BasicBlock, 32, 3, 1, 1, 1) 33 | self.layer2 = self._make_layer(BasicBlock, 64, 16, 2, 1, 1) 34 | self.layer3 = self._make_layer(BasicBlock, 128, 3, 2, 1, 1) 35 | self.layer4 = self._make_layer(BasicBlock, 128, 3, 1, 1, 1) 36 | 37 | self.branch2 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32, 32)), 38 | convbn_relu(128, 32, 1, 1, 0, 1)) 39 | 40 | self.branch3 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16, 16)), 41 | convbn_relu(128, 32, 1, 1, 0, 1)) 42 | 43 | self.branch4 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8, 8)), 44 | convbn_relu(128, 32, 1, 1, 0, 1)) 45 | 46 | self.lastconv = nn.Sequential(convbn_relu(352, 128, 3, 1, 1, 1), 47 | nn.Conv2d(128, 32, kernel_size=1, padding=0, stride=1, bias=False)) 48 | 49 | def _make_layer(self, block, planes, blocks, stride, pad, dilation): 50 | downsample = None 51 | if stride != 1 or self.inplanes != planes * block.expansion: 52 | downsample = nn.Sequential( 53 | nn.Conv2d(self.inplanes, planes * block.expansion, 54 | kernel_size=1, stride=stride, bias=False), 55 | nn.BatchNorm2d(planes * block.expansion),) 56 | 57 | layers = [] 58 | layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation)) 59 | self.inplanes = planes * block.expansion 60 | for i in range(1, blocks): 61 | layers.append(block(self.inplanes, planes, 1, None, pad, dilation)) 62 | 63 | return nn.Sequential(*layers) 64 | 65 | def forward(self, input): 66 | """ 67 | Feature Extractor 68 | Description: The goal of the feature extraction network is to produce a reliable pixel-wise 69 | feature representation from the input image. Specifically, we employ four residual blocks 70 | and use X2 dilated convolution for the last block to enlarge the receptive field. 71 | We then apply spatial pyramid pooling to build a 4-level pyramid feature. 72 | Through multi-scale information, the model is able to capture large context while 73 | maintaining a high spatial resolution. The size of the final feature map is 1/4 of 74 | the originalinput image size. We share the parameters for the left and right feature network. 
75 | 76 | Args: 77 | :input: Input image (RGB) 78 | 79 | Returns: 80 | :output_feature: spp_features (downsampled X8) 81 | :output_raw: features (downsampled X4) 82 | :output1: low_level_features (downsampled X2) 83 | """ 84 | 85 | output0 = self.firstconv(input) 86 | output1 = self.layer1(output0) 87 | output_raw = self.layer2(output1) 88 | output = self.layer3(output_raw) 89 | output_skip = self.layer4(output) 90 | 91 | output_branch2 = self.branch2(output_skip) 92 | output_branch2 = F.upsample(output_branch2, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 93 | 94 | output_branch3 = self.branch3(output_skip) 95 | output_branch3 = F.upsample(output_branch3, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 96 | 97 | output_branch4 = self.branch4(output_skip) 98 | output_branch4 = F.upsample(output_branch4, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear') 99 | 100 | output_feature = torch.cat((output, output_skip, output_branch4, output_branch3, output_branch2), 1) 101 | output_feature = self.lastconv(output_feature) 102 | 103 | return output_feature, output_raw, output1 104 | -------------------------------------------------------------------------------- /deeppruner/models/patch_match.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch 18 | import torch.nn as nn 19 | import torch.nn.functional as F 20 | 21 | 22 | class DisparityInitialization(nn.Module): 23 | 24 | def __init__(self): 25 | super(DisparityInitialization, self).__init__() 26 | 27 | def forward(self, min_disparity, max_disparity, number_of_intervals=10): 28 | """ 29 | PatchMatch Initialization Block 30 | Description: Rather than allowing each sample/ particle to reside in the full disparity space, 31 | we divide the search space into 'number_of_intervals' intervals, and force the 32 | i-th particle to be in a i-th interval. This guarantees the diversity of the 33 | particles and helps improve accuracy for later computations. 34 | 35 | As per implementation, 36 | this function divides the complete disparity search space into multiple intervals. 37 | 38 | Args: 39 | :min_disparity: Min Disparity of the disparity search range. 40 | :max_disparity: Max Disparity of the disparity search range. 41 | :number_of_intervals (default: 10): Number of samples to be generated. 42 | Returns: 43 | :interval_noise: Random value between 0-1. Represents offset of the from the interval_min_disparity. 
44 | :interval_min_disparity: disparity_sample = interval_min_disparity + interval_noise 45 | :multiplier: 1.0 / number_of_intervals 46 | """ 47 | 48 | device = min_disparity.get_device() 49 | 50 | multiplier = 1.0 / number_of_intervals 51 | range_multiplier = torch.arange(0.0, 1, multiplier, device=device).view(number_of_intervals, 1, 1) 52 | range_multiplier = range_multiplier.repeat(1, min_disparity.size()[2], min_disparity.size()[3]) 53 | 54 | interval_noise = min_disparity.new_empty(min_disparity.size()[0], number_of_intervals, min_disparity.size()[2], 55 | min_disparity.size()[3]).uniform_(0, 1) 56 | interval_min_disparity = min_disparity + (max_disparity - min_disparity) * range_multiplier 57 | 58 | return interval_noise, interval_min_disparity, multiplier 59 | 60 | 61 | class Evaluate(nn.Module): 62 | def __init__(self, filter_size=3, temperature=7): 63 | super(Evaluate, self).__init__() 64 | self.temperature = temperature 65 | self.filter_size = filter_size 66 | self.softmax = torch.nn.Softmax(dim=1) 67 | 68 | def forward(self, left_input, right_input, disparity_samples, normalized_disparity_samples): 69 | """ 70 | PatchMatch Evaluation Block 71 | Description: For each pixel i, matching scores are computed by taking the inner product between the 72 | left feature and the right feature: score(i,j) = <feature_left(i), feature_right(i + disparity(i,j))> 73 | for all candidates j. The best k disparity values for each pixel are carried towards the next iteration. 74 | 75 | As per implementation, 76 | the complete disparity search range is discretized into intervals in the 77 | DisparityInitialization() function. Corresponding to each disparity interval, we have multiple samples 78 | to evaluate. The best disparity sample per interval is the output of the function. 79 | 80 | Args: 81 | :left_input: Left Image Feature Map 82 | :right_input: Right Image Feature Map 83 | :disparity_samples: Disparity Samples to be evaluated. For each pixel, we have 84 | ("number of intervals" X "number_of_samples_per_intervals") samples. 85 | 86 | :normalized_disparity_samples: Disparity samples normalized by the corresponding interval size, i.e. (disparity_sample - interval_min_disparity) / interval_size. 87 | Returns: 88 | :disparity_samples: Evaluated disparity sample, one per disparity interval. 89 | :normalized_disparity_samples: Evaluated normalized disparity sample, one per disparity interval.
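        Illustrative sketch of the soft-argmax used below (for one interval with candidates d_1..d_k,
        scores s_1..s_k and temperature t):
            w_j   = exp(t * s_j) / sum_m exp(t * s_m)
            d_out = sum_j w_j * d_j
        i.e. the "best" sample per interval is a differentiable, score-weighted average rather than
        a hard argmax.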
90 | """ 91 | device = left_input.get_device() 92 | left_y_coordinate = torch.arange(0.0, left_input.size()[3], device=device).repeat( 93 | left_input.size()[2]).view(left_input.size()[2], left_input.size()[3]) 94 | 95 | left_y_coordinate = torch.clamp(left_y_coordinate, min=0, max=left_input.size()[3] - 1) 96 | left_y_coordinate = left_y_coordinate.expand(left_input.size()[0], -1, -1) 97 | 98 | right_feature_map = right_input.expand(disparity_samples.size()[1], -1, -1, -1, -1).permute([1, 2, 0, 3, 4]) 99 | left_feature_map = left_input.expand(disparity_samples.size()[1], -1, -1, -1, -1).permute([1, 2, 0, 3, 4]) 100 | 101 | disparity_sample_strength = disparity_samples.new(disparity_samples.size()[0], 102 | disparity_samples.size()[1], 103 | disparity_samples.size()[2], 104 | disparity_samples.size()[3]) 105 | 106 | right_y_coordinate = left_y_coordinate.expand( 107 | disparity_samples.size()[1], -1, -1, -1).permute([1, 0, 2, 3]).float() 108 | right_y_coordinate = right_y_coordinate - disparity_samples 109 | right_y_coordinate = torch.clamp(right_y_coordinate, min=0, max=right_input.size()[3] - 1) 110 | 111 | warped_right_feature_map = torch.gather(right_feature_map, 112 | dim=4, 113 | index=right_y_coordinate.expand( 114 | right_input.size()[1], -1, -1, -1, -1).permute([1, 0, 2, 3, 4]).long()) 115 | 116 | disparity_sample_strength = torch.mean(left_feature_map * warped_right_feature_map, dim=1) * self.temperature 117 | 118 | disparity_sample_strength = disparity_sample_strength.view( 119 | disparity_sample_strength.size()[0], 120 | disparity_sample_strength.size()[1] // (self.filter_size), 121 | (self.filter_size), 122 | disparity_sample_strength.size()[2], 123 | disparity_sample_strength.size()[3]) 124 | 125 | disparity_samples = disparity_samples.view(disparity_samples.size()[0], 126 | disparity_samples.size()[1] // (self.filter_size), 127 | (self.filter_size), 128 | disparity_samples.size()[2], 129 | disparity_samples.size()[3]) 130 | 131 | normalized_disparity_samples = normalized_disparity_samples.view( 132 | normalized_disparity_samples.size()[0], 133 | normalized_disparity_samples.size()[1] // (self.filter_size), 134 | (self.filter_size), 135 | normalized_disparity_samples.size()[2], 136 | normalized_disparity_samples.size()[3]) 137 | 138 | disparity_sample_strength = disparity_sample_strength.permute([0, 2, 1, 3, 4]) 139 | disparity_samples = disparity_samples.permute([0, 2, 1, 3, 4]) 140 | normalized_disparity_samples = normalized_disparity_samples.permute([0, 2, 1, 3, 4]) 141 | 142 | disparity_sample_strength = torch.softmax(disparity_sample_strength, dim=1) 143 | disparity_samples = torch.sum(disparity_samples * disparity_sample_strength, dim=1) 144 | normalized_disparity_samples = torch.sum(normalized_disparity_samples * disparity_sample_strength, dim=1) 145 | 146 | return normalized_disparity_samples, disparity_samples 147 | 148 | 149 | class Propagation(nn.Module): 150 | def __init__(self, filter_size=3): 151 | super(Propagation, self).__init__() 152 | self.filter_size = filter_size 153 | 154 | def forward(self, disparity_samples, device, propagation_type="horizontal"): 155 | """ 156 | PatchMatch Propagation Block 157 | Description: Particles from adjacent pixels are propagated together through convolution with a 158 | pre-defined one-hot filter pattern, which en-codes the fact that we allow each pixel 159 | to propagate particles to its 4-neighbours. 
160 | 161 | As per implementation, the complete disparity search range is discretized into intervals in 162 | DisparityInitialization() function. 163 | Now, propagation of samples from neighbouring pixels, is done per interval. This implies that after 164 | propagation, number of samples per pixel = (filter_size X number_of_intervals) 165 | 166 | Args: 167 | :disparity_samples: 168 | :device: Cuda device 169 | :propagation_type (default:"horizontal"): In order to be memory efficient, we use separable convolutions 170 | for propagtaion. 171 | 172 | Returns: 173 | :aggregated_disparity_samples: Disparity Samples aggregated from the neighbours. 174 | 175 | """ 176 | 177 | disparity_samples = disparity_samples.view(disparity_samples.size()[0], 178 | 1, 179 | disparity_samples.size()[1], 180 | disparity_samples.size()[2], 181 | disparity_samples.size()[3]) 182 | 183 | if propagation_type is "horizontal": 184 | label = torch.arange(0, self.filter_size, device=device).repeat(self.filter_size).view( 185 | self.filter_size, 1, 1, 1, self.filter_size) 186 | 187 | one_hot_filter = torch.zeros_like(label).scatter_(0, label, 1).float() 188 | aggregated_disparity_samples = F.conv3d(disparity_samples, 189 | one_hot_filter, padding=(0, 0, self.filter_size // 2)) 190 | 191 | else: 192 | label = torch.arange(0, self.filter_size, device=device).repeat(self.filter_size).view( 193 | self.filter_size, 1, 1, self.filter_size, 1).long() 194 | 195 | one_hot_filter = torch.zeros_like(label).scatter_(0, label, 1).float() 196 | aggregated_disparity_samples = F.conv3d(disparity_samples, 197 | one_hot_filter, padding=(0, self.filter_size // 2, 0)) 198 | 199 | aggregated_disparity_samples = aggregated_disparity_samples.permute([0, 2, 1, 3, 4]) 200 | aggregated_disparity_samples = aggregated_disparity_samples.contiguous().view( 201 | aggregated_disparity_samples.size()[0], 202 | aggregated_disparity_samples.size()[1] * aggregated_disparity_samples.size()[2], 203 | aggregated_disparity_samples.size()[3], 204 | aggregated_disparity_samples.size()[4]) 205 | 206 | return aggregated_disparity_samples 207 | 208 | 209 | class PatchMatch(nn.Module): 210 | def __init__(self, propagation_filter_size=3): 211 | super(PatchMatch, self).__init__() 212 | 213 | self.propagation_filter_size = propagation_filter_size 214 | self.propagation = Propagation(filter_size=propagation_filter_size) 215 | self.disparity_initialization = DisparityInitialization() 216 | self.evaluate = Evaluate(filter_size=propagation_filter_size) 217 | 218 | def forward(self, left_input, right_input, min_disparity, max_disparity, sample_count=10, iteration_count=3): 219 | """ 220 | Differntail PatchMatch Block 221 | Description: In this work, we unroll generalized PatchMatch as a recurrent neural network, 222 | where each unrolling step is equivalent to each iteration of the algorithm. 223 | This is important as it allow us to train our full model end-to-end. 224 | Specifically, we design the following layers: 225 | - Initialization or Paticle Sampling 226 | - Propagation 227 | - Evaluation 228 | Args: 229 | :left_input: Left Image feature map 230 | :right_input: Right image feature map 231 | :min_disparity: Min of the disparity search range 232 | :max_disparity: Max of the disparity search range 233 | :sample_count (default:10): Number of disparity samples per pixel. 
(similar to generalized PatchMatch) 234 | :iteration_count (default:3) : Number of PatchMatch iterations 235 | 236 | Returns: 237 | :disparity_samples: For each pixel, this function returns "sample_count" disparity samples. 238 | """ 239 | 240 | device = left_input.get_device() 241 | min_disparity = torch.floor(min_disparity) 242 | max_disparity = torch.ceil(max_disparity) 243 | 244 | # normalized_disparity_samples: Disparity samples normalized by the corresponding interval size. 245 | # i.e (disparity_sample - interval_min_disparity) / interval_size 246 | 247 | normalized_disparity_samples, min_disp_tensor, multiplier = self.disparity_initialization( 248 | min_disparity, max_disparity, sample_count) 249 | min_disp_tensor = min_disp_tensor.unsqueeze(2).repeat(1, 1, self.propagation_filter_size, 1, 1).view( 250 | min_disp_tensor.size()[0], 251 | min_disp_tensor.size()[1] * self.propagation_filter_size, 252 | min_disp_tensor.size()[2], 253 | min_disp_tensor.size()[3]) 254 | 255 | for prop_iter in range(iteration_count): 256 | normalized_disparity_samples = self.propagation(normalized_disparity_samples, device, propagation_type="horizontal") 257 | disparity_samples = normalized_disparity_samples * \ 258 | (max_disparity - min_disparity) * multiplier + min_disp_tensor 259 | 260 | normalized_disparity_samples, disparity_samples = self.evaluate(left_input, 261 | right_input, 262 | disparity_samples, 263 | normalized_disparity_samples) 264 | 265 | normalized_disparity_samples = self.propagation(normalized_disparity_samples, device, propagation_type="vertical") 266 | disparity_samples = normalized_disparity_samples * \ 267 | (max_disparity - min_disparity) * multiplier + min_disp_tensor 268 | 269 | normalized_disparity_samples, disparity_samples = self.evaluate(left_input, 270 | right_input, 271 | disparity_samples, 272 | normalized_disparity_samples) 273 | 274 | return disparity_samples 275 | -------------------------------------------------------------------------------- /deeppruner/models/submodules.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch.nn as nn 18 | import math 19 | 20 | 21 | def convbn_2d_lrelu(in_planes, out_planes, kernel_size, stride, pad, dilation=1, bias=False): 22 | return nn.Sequential( 23 | nn.Conv2d(in_planes, out_planes, kernel_size=(kernel_size, kernel_size), 24 | stride=(stride, stride), padding=(pad, pad), dilation=(dilation, dilation), bias=bias), 25 | nn.BatchNorm2d(out_planes), 26 | nn.LeakyReLU(0.1, inplace=True)) 27 | 28 | 29 | def convbn_3d_lrelu(in_planes, out_planes, kernel_size, stride, pad): 30 | 31 | return nn.Sequential(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=(pad, pad, pad), 32 | stride=(1, stride, stride), bias=False), 33 | nn.BatchNorm3d(out_planes), 34 | nn.LeakyReLU(0.1, inplace=True)) 35 | 36 | 37 | def conv_relu(in_planes, out_planes, kernel_size, stride, pad, bias=True): 38 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad, bias=bias), 39 | nn.ReLU(inplace=True)) 40 | 41 | 42 | def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation): 43 | 44 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 45 | padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False), 46 | nn.BatchNorm2d(out_planes)) 47 | 48 | 49 | def convbn_relu(in_planes, out_planes, kernel_size, stride, pad, dilation): 50 | 51 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, 52 | padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False), 53 | nn.BatchNorm2d(out_planes), 54 | nn.ReLU(inplace=True)) 55 | 56 | 57 | def convbn_transpose_3d(inplanes, outplanes, kernel_size, padding, output_padding, stride, bias): 58 | return nn.Sequential(nn.ConvTranspose3d(inplanes, outplanes, kernel_size, padding=padding, 59 | output_padding=output_padding, stride=stride, bias=bias), 60 | nn.BatchNorm3d(outplanes)) 61 | 62 | 63 | class BasicBlock(nn.Module): 64 | expansion = 1 65 | 66 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation): 67 | super(BasicBlock, self).__init__() 68 | 69 | self.conv1 = convbn_relu(inplanes, planes, 3, stride, pad, dilation) 70 | self.conv2 = convbn(planes, planes, 3, 1, pad, dilation) 71 | 72 | self.downsample = downsample 73 | self.stride = stride 74 | 75 | def forward(self, x): 76 | out = self.conv1(x) 77 | out = self.conv2(out) 78 | 79 | if self.downsample is not None: 80 | x = self.downsample(x) 81 | 82 | out += x 83 | 84 | return out 85 | 86 | 87 | class SubModule(nn.Module): 88 | def __init__(self): 89 | super(SubModule, self).__init__() 90 | 91 | def weight_init(self): 92 | for m in self.modules(): 93 | if isinstance(m, nn.Conv2d): 94 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 95 | m.weight.data.normal_(0, math.sqrt(2. / n)) 96 | elif isinstance(m, nn.Conv3d): 97 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels 98 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 99 | elif isinstance(m, nn.BatchNorm2d): 100 | m.weight.data.fill_(1) 101 | m.bias.data.zero_() 102 | elif isinstance(m, nn.BatchNorm3d): 103 | m.weight.data.fill_(1) 104 | m.bias.data.zero_() 105 | elif isinstance(m, nn.Linear): 106 | m.bias.data.zero_() 107 | -------------------------------------------------------------------------------- /deeppruner/models/submodules2d.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch.nn as nn 18 | from models.submodules import SubModule, convbn_2d_lrelu 19 | 20 | 21 | class RefinementNet(SubModule): 22 | def __init__(self, inplanes): 23 | super(RefinementNet, self).__init__() 24 | 25 | self.conv1 = nn.Sequential( 26 | convbn_2d_lrelu(inplanes, 32, kernel_size=3, stride=1, pad=1), 27 | convbn_2d_lrelu(32, 32, kernel_size=3, stride=1, pad=1, dilation=1), 28 | convbn_2d_lrelu(32, 32, kernel_size=3, stride=1, pad=1, dilation=1), 29 | convbn_2d_lrelu(32, 16, kernel_size=3, stride=1, pad=2, dilation=2), 30 | convbn_2d_lrelu(16, 16, kernel_size=3, stride=1, pad=4, dilation=4), 31 | convbn_2d_lrelu(16, 16, kernel_size=3, stride=1, pad=1, dilation=1)) 32 | 33 | self.classif1 = nn.Conv2d(16, 1, kernel_size=3, padding=1, stride=1, bias=False) 34 | self.relu = nn.ReLU(inplace=True) 35 | 36 | self.weight_init() 37 | 38 | def forward(self, input, disparity): 39 | """ 40 | Refinement Block 41 | Description: The network takes left image convolutional features from the second residual block 42 | of the feature network and the current disparity estimation as input. 43 | It then outputs the finetuned disparity prediction. The low-level feature 44 | information serves as a guidance to reduce noise and improve the quality of the final 45 | disparity map, especially on sharp boundaries. 46 | 47 | Args: 48 | :input: Input features composed of left image low-level features, cost-aggregator features, and 49 | cost-aggregator disparity. 50 | 51 | :disparity: predicted disparity 52 | """ 53 | 54 | output0 = self.conv1(input) 55 | output0 = self.classif1(output0) 56 | output = self.relu(output0 + disparity) 57 | 58 | return output -------------------------------------------------------------------------------- /deeppruner/models/submodules3d.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 
9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch 18 | import torch.nn as nn 19 | import torch.utils.data 20 | from models.submodules import SubModule, convbn_3d_lrelu, convbn_transpose_3d 21 | 22 | 23 | class HourGlass(SubModule): 24 | def __init__(self, inplanes=16): 25 | super(HourGlass, self).__init__() 26 | 27 | self.conv1 = convbn_3d_lrelu(inplanes, inplanes * 2, kernel_size=3, stride=2, pad=1) 28 | self.conv2 = convbn_3d_lrelu(inplanes * 2, inplanes * 2, kernel_size=3, stride=1, pad=1) 29 | 30 | self.conv1_1 = convbn_3d_lrelu(inplanes * 2, inplanes * 4, kernel_size=3, stride=2, pad=1) 31 | self.conv2_1 = convbn_3d_lrelu(inplanes * 4, inplanes * 4, kernel_size=3, stride=1, pad=1) 32 | 33 | self.conv3 = convbn_3d_lrelu(inplanes * 4, inplanes * 8, kernel_size=3, stride=2, pad=1) 34 | self.conv4 = convbn_3d_lrelu(inplanes * 8, inplanes * 8, kernel_size=3, stride=1, pad=1) 35 | 36 | self.conv5 = convbn_transpose_3d(inplanes * 8, inplanes * 4, kernel_size=3, padding=1, 37 | output_padding=(0, 1, 1), stride=(1, 2, 2), bias=False) 38 | self.conv6 = convbn_transpose_3d(inplanes * 4, inplanes * 2, kernel_size=3, padding=1, 39 | output_padding=(0, 1, 1), stride=(1, 2, 2), bias=False) 40 | self.conv7 = convbn_transpose_3d(inplanes * 2, inplanes, kernel_size=3, padding=1, 41 | output_padding=(0, 1, 1), stride=(1, 2, 2), bias=False) 42 | 43 | self.last_conv3d_layer = nn.Sequential( 44 | convbn_3d_lrelu(inplanes, inplanes * 2, 3, 1, 1), 45 | nn.Conv3d(inplanes * 2, 1, kernel_size=3, padding=1, stride=1, bias=False)) 46 | 47 | self.softmax = nn.Softmax(dim=1) 48 | 49 | self.weight_init() 50 | 51 | 52 | class MaxDisparityPredictor(HourGlass): 53 | 54 | def __init__(self, hourglass_inplanes=16): 55 | super(MaxDisparityPredictor, self).__init__(hourglass_inplanes) 56 | 57 | def forward(self, input, input_disparity): 58 | """ 59 | Confidence Range Prediction (Max Disparity): 60 | Description: The network has a convolutional encoder-decoder structure. It takes the sparse 61 | disparity estimations from the differentiable PatchMatch, the left image and the warped right image 62 | (warped according to the sparse disparity estimations) as input and outputs the upper bound of 63 | the confidence range for each pixel i. 64 | Args: 65 | :input: Left and Warped right Image features as Cost Volume. 66 | :input_disparity: PatchMatch predicted disparity samples. 67 | Returns: 68 | :disparity_output: Max Disparity of the reduced disaprity search range. 
69 | :feature_output: High-level features of the MaxDisparityPredictor 70 | """ 71 | 72 | output0 = self.conv1(input) 73 | output0_a = self.conv2(output0) + output0 74 | 75 | output0 = self.conv1_1(output0_a) 76 | output0_c = self.conv2_1(output0) + output0 77 | 78 | output0 = self.conv3(output0_c) 79 | output0 = self.conv4(output0) + output0 80 | 81 | output1 = self.conv5(output0) + output0_c 82 | output1 = self.conv6(output1) + output0_a 83 | output1 = self.conv7(output1) 84 | 85 | output2 = self.last_conv3d_layer(output1).squeeze(1) 86 | feature_output = output2 87 | 88 | confidence_output = self.softmax(output2) 89 | disparity_output = torch.sum(confidence_output * input_disparity, dim=1).unsqueeze(1) 90 | 91 | return disparity_output, feature_output 92 | 93 | 94 | class MinDisparityPredictor(HourGlass): 95 | 96 | def __init__(self, hourglass_inplanes=16): 97 | super(MinDisparityPredictor, self).__init__(hourglass_inplanes) 98 | 99 | def forward(self, input, input_disparity): 100 | """ 101 | Confidence Range Prediction (Min Disparity): 102 | Description: The network has a convolutional encoder-decoder structure. It takes the sparse 103 | disparity estimations from the differentiable PatchMatch, the left image and the warped right image 104 | (warped according to the sparse disparity estimations) as input and outputs the lower bound of 105 | the confidence range for each pixel i. 106 | Args: 107 | :input: Left and Warped right Image features as Cost Volume. 108 | :input_disparity: PatchMatch predicted disparity samples. 109 | Returns: 110 | :disparity_output: Min Disparity of the reduced disaprity search range. 111 | :feature_output: High-level features of the MaxDisparityPredictor 112 | """ 113 | 114 | output0 = self.conv1(input) 115 | output0_a = self.conv2(output0) + output0 116 | 117 | output0 = self.conv1_1(output0_a) 118 | output0_c = self.conv2_1(output0) + output0 119 | 120 | output0 = self.conv3(output0_c) 121 | output0 = self.conv4(output0) + output0 122 | 123 | output1 = self.conv5(output0) + output0_c 124 | output1 = self.conv6(output1) + output0_a 125 | output1 = self.conv7(output1) 126 | 127 | output2 = self.last_conv3d_layer(output1).squeeze(1) 128 | feature_output = output2 129 | 130 | confidence_output = self.softmax(output2) 131 | disparity_output = torch.sum(confidence_output * input_disparity, dim=1).unsqueeze(1) 132 | 133 | return disparity_output, feature_output 134 | 135 | 136 | class CostAggregator(HourGlass): 137 | 138 | def __init__(self, cost_aggregator_inplanes, hourglass_inplanes=16): 139 | super(CostAggregator, self).__init__(inplanes=16) 140 | 141 | self.dres0 = nn.Sequential(convbn_3d_lrelu(cost_aggregator_inplanes, 64, 3, 1, 1), 142 | convbn_3d_lrelu(64, 32, 3, 1, 1)) 143 | 144 | self.dres1 = nn.Sequential(convbn_3d_lrelu(32, 32, 3, 1, 1), 145 | convbn_3d_lrelu(32, hourglass_inplanes, 3, 1, 1)) 146 | 147 | def forward(self, input, input_disparity): 148 | """ 149 | 3D Cost Aggregator 150 | Description: Based on the predicted range in the pruning module, 151 | we build the 3D cost volume estimator and conduct spatial aggregation. 152 | Following common practice, we take the left image, the warped right image and corresponding disparities 153 | as input and output the cost over the disparity range at the size B X R X H X W , where R is the number 154 | of disparities per pixel. Compared to prior work, our R is more than 10 times smaller, making 155 | this module very efficient. 
Soft-arg max is again used to predict the disparity value , 156 | so that our approach is end-to-end trainable. 157 | 158 | Args: 159 | :input: Cost-Volume composed of left image features, warped right image features, 160 | Confidence range Predictor features and input disparity samples/ 161 | 162 | :input_disparity: input disparity samples. 163 | 164 | Returns: 165 | :disparity_output: Predicted disparity 166 | :feature_output: High-level features of 3d-Cost Aggregator 167 | 168 | """ 169 | 170 | output0 = self.dres0(input) 171 | output0_b = self.dres1(output0) 172 | 173 | output0 = self.conv1(output0_b) 174 | output0_a = self.conv2(output0) + output0 175 | 176 | output0 = self.conv1_1(output0_a) 177 | output0_c = self.conv2_1(output0) + output0 178 | 179 | output0 = self.conv3(output0_c) 180 | output0 = self.conv4(output0) + output0 181 | 182 | output1 = self.conv5(output0) + output0_c 183 | output1 = self.conv6(output1) + output0_a 184 | output1 = self.conv7(output1) + output0_b 185 | 186 | output2 = self.last_conv3d_layer(output1).squeeze(1) 187 | feature_output = output2 188 | 189 | confidence_output = self.softmax(output2) 190 | disparity_output = torch.sum(confidence_output * input_disparity, dim=1) 191 | 192 | return disparity_output.unsqueeze(1), feature_output 193 | -------------------------------------------------------------------------------- /deeppruner/models/utils.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import torch 18 | import torch.nn as nn 19 | 20 | 21 | class UniformSampler(nn.Module): 22 | def __init__(self): 23 | super(UniformSampler, self).__init__() 24 | 25 | def forward(self, min_disparity, max_disparity, number_of_samples=10): 26 | """ 27 | Uniform Sampler 28 | Description: The Confidence Range Predictor predicts a reduced disparity search range R(i) = [l(i), u(i)] 29 | for each pixel i. We then, generate disparity samples from this reduced search range for Cost Aggregation 30 | or second stage of Patch Match. From experiments, we found Uniform sampling to work better. 31 | 32 | Args: 33 | :min_disparity: lower bound of disparity search range (predicted by Confidence Range Predictor) 34 | :max_disparity: upper bound of disparity range predictor (predicted by Confidence Range Predictor) 35 | :number_of_samples (default:10): number of samples to be genearted. 36 | Returns: 37 | :sampled_disparities: Uniformly generated disparity samples from the input search range. 
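        Illustrative example (values assumed): with min_disparity = 10, max_disparity = 34 and
        number_of_samples = 10, the step is (34 - 10) / 11 ~= 2.18 and the returned samples are
        10 + 2.18 * j for j = 1..10, i.e. spaced uniformly strictly inside the predicted range.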
38 | """ 39 | 40 | device = min_disparity.get_device() 41 | 42 | multiplier = (max_disparity - min_disparity) / (number_of_samples + 1) 43 | range_multiplier = torch.arange(1.0, number_of_samples + 1, 1, device=device).view(number_of_samples, 1, 1) 44 | sampled_disparities = min_disparity + multiplier * range_multiplier 45 | 46 | return sampled_disparities 47 | 48 | 49 | class SpatialTransformer(nn.Module): 50 | def __init__(self): 51 | super(SpatialTransformer, self).__init__() 52 | 53 | def forward(self, left_input, right_input, disparity_samples): 54 | """ 55 | Disparity Sample Cost Evaluator 56 | Description: 57 | Given the left image features, right iamge features and teh disparity samples, generates: 58 | - Per sample cost as , <.,.> denotes scalar-product. 59 | - Warped righ image features 60 | 61 | Args: 62 | :left_input: Left Image Features 63 | :right_input: Right Image Features 64 | :disparity_samples: Disparity Samples genearted by PatchMatch 65 | 66 | Returns: 67 | :disparity_samples_strength_1: Cost associated with each disaprity sample. 68 | :warped_right_feature_map: right iamge features warped according to input disparity. 69 | :left_feature_map: expanded left image features. 70 | """ 71 | 72 | device = left_input.get_device() 73 | left_y_coordinate = torch.arange(0.0, left_input.size()[3], device=device).repeat(left_input.size()[2]) 74 | left_y_coordinate = left_y_coordinate.view(left_input.size()[2], left_input.size()[3]) 75 | left_y_coordinate = torch.clamp(left_y_coordinate, min=0, max=left_input.size()[3] - 1) 76 | left_y_coordinate = left_y_coordinate.expand(left_input.size()[0], -1, -1) 77 | 78 | right_feature_map = right_input.expand(disparity_samples.size()[1], -1, -1, -1, -1).permute([1, 2, 0, 3, 4]) 79 | left_feature_map = left_input.expand(disparity_samples.size()[1], -1, -1, -1, -1).permute([1, 2, 0, 3, 4]) 80 | 81 | disparity_samples = disparity_samples.float() 82 | 83 | right_y_coordinate = left_y_coordinate.expand( 84 | disparity_samples.size()[1], -1, -1, -1).permute([1, 0, 2, 3]) - disparity_samples 85 | 86 | right_y_coordinate_1 = right_y_coordinate 87 | right_y_coordinate = torch.clamp(right_y_coordinate, min=0, max=right_input.size()[3] - 1) 88 | 89 | warped_right_feature_map = torch.gather(right_feature_map, dim=4, index=right_y_coordinate.expand( 90 | right_input.size()[1], -1, -1, -1, -1).permute([1, 0, 2, 3, 4]).long()) 91 | 92 | right_y_coordinate_1 = right_y_coordinate_1.unsqueeze(1) 93 | warped_right_feature_map = (1 - ((right_y_coordinate_1 < 0) + 94 | (right_y_coordinate_1 > right_input.size()[3] - 1)).float()) * \ 95 | (warped_right_feature_map) + torch.zeros_like(warped_right_feature_map) 96 | 97 | return warped_right_feature_map, left_feature_map 98 | -------------------------------------------------------------------------------- /deeppruner/setup_logging.py: -------------------------------------------------------------------------------- 1 | import logging 2 | 3 | def setup_logging(filename): 4 | 5 | log_format = '%(filename)s: %(message)s' 6 | logging.basicConfig(format=log_format, level=logging.INFO) 7 | 8 | file_handler = logging.FileHandler(filename) 9 | file_handler.setFormatter(logging.Formatter(fmt=log_format)) 10 | file_handler.setLevel(logging.INFO) 11 | logging.getLogger().addHandler(file_handler) 12 | -------------------------------------------------------------------------------- /deeppruner/submission_kitti.py: -------------------------------------------------------------------------------- 1 | # 
--------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import argparse 18 | import os 19 | import random 20 | import torch 21 | import torch.nn as nn 22 | import torch.nn.parallel 23 | import torch.backends.cudnn as cudnn 24 | import torch.optim as optim 25 | import torch.utils.data 26 | from torch.autograd import Variable 27 | import torch.nn.functional as F 28 | import skimage.io 29 | import numpy as np 30 | import logging 31 | from dataloader import kitti_submission_collector as ls 32 | from dataloader import preprocess 33 | from PIL import Image 34 | from models.deeppruner import DeepPruner 35 | from models.config import config as config_args 36 | from setup_logging import setup_logging 37 | 38 | parser = argparse.ArgumentParser(description='DeepPruner') 39 | parser.add_argument('--datapath', default='/', 40 | help='datapath') 41 | parser.add_argument('--loadmodel', default=None, 42 | help='load model') 43 | parser.add_argument('--save_dir', default='./', 44 | help='save directory') 45 | parser.add_argument('--logging_filename', default='./submission_kitti.log', 46 | help='filename for logs') 47 | parser.add_argument('--no-cuda', action='store_true', default=False, 48 | help='enables CUDA training') 49 | parser.add_argument('--seed', type=int, default=1, metavar='S', 50 | help='random seed (default: 1)') 51 | 52 | args = parser.parse_args() 53 | torch.backends.cudnn.benchmark = True 54 | args.cuda = not args.no_cuda and torch.cuda.is_available() 55 | 56 | args.cost_aggregator_scale = config_args.cost_aggregator_scale 57 | args.downsample_scale = args.cost_aggregator_scale * 8.0 58 | 59 | setup_logging(args.logging_filename) 60 | 61 | if args.cuda: 62 | torch.manual_seed(args.seed) 63 | np.random.seed(args.seed) 64 | random.seed(args.seed) 65 | torch.cuda.manual_seed(args.seed) 66 | torch.backends.cudnn.deterministic = True 67 | 68 | 69 | test_left_img, test_right_img = ls.datacollector(args.datapath) 70 | 71 | model = DeepPruner() 72 | model = nn.DataParallel(model) 73 | 74 | if args.cuda: 75 | model.cuda() 76 | 77 | logging.info('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 78 | 79 | 80 | if args.loadmodel is not None: 81 | logging.info("loading model...") 82 | state_dict = torch.load(args.loadmodel) 83 | model.load_state_dict(state_dict['state_dict'], strict=True) 84 | 85 | 86 | def test(imgL, imgR): 87 | model.eval() 88 | with torch.no_grad(): 89 | imgL = Variable(torch.FloatTensor(imgL)) 90 | imgR = Variable(torch.FloatTensor(imgR)) 91 | 92 | if args.cuda: 93 | imgL, imgR = imgL.cuda(), imgR.cuda() 94 | 95 | refined_disparity = model(imgL, imgR) 96 | return refined_disparity 97 | 98 | 99 | def main(): 100 | 101 | for left_image_path, right_image_path in zip(test_left_img, test_right_img): 102 | imgL = np.asarray(Image.open(left_image_path)) 103 | imgR = 
np.asarray(Image.open(right_image_path)) 104 | 105 | processed = preprocess.get_transform() 106 | imgL = processed(imgL).numpy() 107 | imgR = processed(imgR).numpy() 108 | 109 | imgL = np.reshape(imgL, [1, 3, imgL.shape[1], imgL.shape[2]]) 110 | imgR = np.reshape(imgR, [1, 3, imgR.shape[1], imgR.shape[2]]) 111 | 112 | w = imgL.shape[3] 113 | h = imgL.shape[2] 114 | dw = int(args.downsample_scale - (w%args.downsample_scale + (w%args.downsample_scale==0)*args.downsample_scale)) 115 | dh = int(args.downsample_scale - (h%args.downsample_scale + (h%args.downsample_scale==0)*args.downsample_scale)) 116 | 117 | top_pad = dh 118 | left_pad = dw 119 | imgL = np.lib.pad(imgL, ((0, 0), (0, 0), (top_pad, 0), (0, left_pad)), mode='constant', constant_values=0) 120 | imgR = np.lib.pad(imgR, ((0, 0), (0, 0), (top_pad, 0), (0, left_pad)), mode='constant', constant_values=0) 121 | 122 | disparity = test(imgL, imgR) 123 | disparity = disparity[0, top_pad:, :-left_pad].data.cpu().numpy() 124 | skimage.io.imsave(os.path.join(args.save_dir, left_image_path.split('/') 125 | [-1]), (disparity * 256).astype('uint16')) 126 | 127 | logging.info("Disparity for {} generated at: {}".format(left_image_path, os.path.join(args.save_dir, 128 | left_image_path.split('/')[-1]))) 129 | 130 | 131 | if __name__ == '__main__': 132 | main() 133 | -------------------------------------------------------------------------------- /deeppruner/train_sceneflow.py: -------------------------------------------------------------------------------- 1 | # --------------------------------------------------------------------------- 2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch 3 | # 4 | # Copyright (c) 2019 Uber Technologies, Inc. 5 | # 6 | # Licensed under the Uber Non-Commercial License (the "License"); 7 | # you may not use this file except in compliance with the License. 8 | # You may obtain a copy of the License at the root directory of this project. 9 | # 10 | # See the License for the specific language governing permissions and 11 | # limitations under the License. 
12 | # 13 | # Written by Shivam Duggal 14 | # --------------------------------------------------------------------------- 15 | 16 | from __future__ import print_function 17 | import argparse 18 | import os 19 | import random 20 | import torch 21 | import torch.nn as nn 22 | import torch.backends.cudnn as cudnn 23 | import torch.optim as optim 24 | import torch.utils.data 25 | from torch.autograd import Variable 26 | import numpy as np 27 | from dataloader import sceneflow_collector as lt 28 | from dataloader import sceneflow_loader as DA 29 | from models.deeppruner import DeepPruner 30 | from loss_evaluation import loss_evaluation 31 | from tensorboardX import SummaryWriter 32 | import skimage 33 | import time 34 | import logging 35 | from models.config import config as config_args 36 | from setup_logging import setup_logging 37 | 38 | parser = argparse.ArgumentParser(description='DeepPruner') 39 | parser.add_argument('--datapath_monkaa', default='/', 40 | help='datapath for sceneflow monkaa dataset') 41 | parser.add_argument('--datapath_flying', default='/', 42 | help='datapath for sceneflow flying dataset') 43 | parser.add_argument('--datapath_driving', default='/', 44 | help='datapath for sceneflow driving dataset') 45 | parser.add_argument('--epochs', type=int, default=100, 46 | help='number of epochs to train') 47 | parser.add_argument('--loadmodel', default=None, 48 | help='load model') 49 | parser.add_argument('--save_dir', default='./', 50 | help='save directory') 51 | parser.add_argument('--savemodel', default='./', 52 | help='save model') 53 | parser.add_argument('--logging_filename', default='./train_sceneflow.log', 54 | help='save model') 55 | parser.add_argument('--no-cuda', action='store_true', default=False, 56 | help='enables CUDA training') 57 | parser.add_argument('--seed', type=int, default=1, metavar='S', 58 | help='random seed (default: 1)') 59 | 60 | args = parser.parse_args() 61 | args.cuda = not args.no_cuda and torch.cuda.is_available() 62 | 63 | torch.manual_seed(args.seed) 64 | if args.cuda: 65 | torch.manual_seed(args.seed) 66 | np.random.seed(args.seed) 67 | random.seed(args.seed) 68 | torch.cuda.manual_seed(args.seed) 69 | torch.backends.cudnn.deterministic = True 70 | 71 | 72 | args.cost_aggregator_scale = config_args.cost_aggregator_scale 73 | args.maxdisp = config_args.max_disp 74 | 75 | setup_logging(args.logging_filename) 76 | 77 | 78 | all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = lt.dataloader(args.datapath_monkaa, 79 | args.datapath_flying, args.datapath_driving) 80 | 81 | 82 | TrainImgLoader = torch.utils.data.DataLoader( 83 | DA.SceneflowLoader(all_left_img, all_right_img, all_left_disp, args.cost_aggregator_scale*8.0, True), 84 | batch_size=1, shuffle=True, num_workers=8, drop_last=False) 85 | 86 | TestImgLoader = torch.utils.data.DataLoader( 87 | DA.SceneflowLoader(test_left_img, test_right_img, test_left_disp, args.cost_aggregator_scale*8.0, False), 88 | batch_size=1, shuffle=False, num_workers=4, drop_last=False) 89 | 90 | 91 | model = DeepPruner() 92 | writer = SummaryWriter() 93 | 94 | if args.cuda: 95 | model = nn.DataParallel(model) 96 | model.cuda() 97 | 98 | 99 | if args.loadmodel is not None: 100 | state_dict = torch.load(args.loadmodel) 101 | model.load_state_dict(state_dict['state_dict'], strict=True) 102 | 103 | logging.info('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 104 | 105 | optimizer = optim.Adam(model.parameters(), lr=0.001, 
betas=(0.9, 0.999)) 106 | 107 | 108 | def train(imgL, imgR, disp_L, iteration): 109 | model.train() 110 | imgL = Variable(torch.FloatTensor(imgL)) 111 | imgR = Variable(torch.FloatTensor(imgR)) 112 | disp_L = Variable(torch.FloatTensor(disp_L)) 113 | 114 | if args.cuda: 115 | imgL, imgR, disp_true = imgL.cuda(), imgR.cuda(), disp_L.cuda() 116 | 117 | mask = disp_true < args.maxdisp 118 | mask.detach_() 119 | 120 | optimizer.zero_grad() 121 | result = model(imgL, imgR) 122 | 123 | loss, _ = loss_evaluation(result, disp_true, mask, args.cost_aggregator_scale) 124 | 125 | loss.backward() 126 | optimizer.step() 127 | 128 | return loss.item() 129 | 130 | 131 | def test(imgL, imgR, disp_L, iteration, pad_w, pad_h): 132 | 133 | model.eval() 134 | with torch.no_grad(): 135 | imgL = Variable(torch.FloatTensor(imgL)) 136 | imgR = Variable(torch.FloatTensor(imgR)) 137 | disp_L = Variable(torch.FloatTensor(disp_L)) 138 | 139 | if args.cuda: 140 | imgL, imgR, disp_true = imgL.cuda(), imgR.cuda(), disp_L.cuda() 141 | 142 | mask = disp_true < args.maxdisp 143 | mask.detach_() 144 | 145 | if len(disp_true[mask]) == 0: 146 | logging.info("invalid GT disaprity...") 147 | return 0, 0 148 | 149 | optimizer.zero_grad() 150 | 151 | result = model(imgL, imgR) 152 | output = [] 153 | for ind in range(len(result)): 154 | output.append(result[ind][:, pad_h:, pad_w:]) 155 | result = output 156 | 157 | loss, output_disparity = loss_evaluation(result, disp_true, mask, args.cost_aggregator_scale) 158 | epe_loss = torch.mean(torch.abs(output_disparity[mask] - disp_true[mask])) 159 | 160 | return loss.item(), epe_loss.item() 161 | 162 | 163 | def adjust_learning_rate(optimizer, epoch): 164 | if epoch <= 20: 165 | lr = 0.001 166 | elif epoch <= 40: 167 | lr = 0.0007 168 | elif epoch <= 60: 169 | lr = 0.0003 170 | else: 171 | lr = 0.0001 172 | for param_group in optimizer.param_groups: 173 | param_group['lr'] = lr 174 | 175 | 176 | def main(): 177 | for epoch in range(0, args.epochs): 178 | total_train_loss = 0 179 | total_test_loss = 0 180 | total_epe_loss = 0 181 | adjust_learning_rate(optimizer, epoch) 182 | 183 | if epoch % 1 == 0 and epoch != 0: 184 | logging.info("testing...") 185 | for batch_idx, (imgL, imgR, disp_L, pad_w, pad_h) in enumerate(TestImgLoader): 186 | start_time = time.time() 187 | test_loss, epe_loss = test(imgL, imgR, disp_L, batch_idx, int(pad_w[0].item()), int(pad_h[0].item())) 188 | total_test_loss += test_loss 189 | total_epe_loss += epe_loss 190 | 191 | logging.info('Iter %d 3-px error in val = %.3f, time = %.2f \n' % 192 | (batch_idx, epe_loss, time.time() - start_time)) 193 | 194 | writer.add_scalar("val-loss-iter", test_loss, epoch * 4370 + batch_idx) 195 | writer.add_scalar("val-epe-loss-iter", epe_loss, epoch * 4370 + batch_idx) 196 | 197 | logging.info('epoch %d total test loss = %.3f' % (epoch, total_test_loss / len(TestImgLoader))) 198 | writer.add_scalar("val-loss", total_test_loss / len(TestImgLoader), epoch) 199 | logging.info('epoch %d total epe loss = %.3f' % (epoch, total_epe_loss / len(TestImgLoader))) 200 | writer.add_scalar("epe-loss", total_epe_loss / len(TestImgLoader), epoch) 201 | 202 | for batch_idx, (imgL_crop, imgR_crop, disp_crop_L) in enumerate(TrainImgLoader): 203 | start_time = time.time() 204 | loss = train(imgL_crop, imgR_crop, disp_crop_L, batch_idx) 205 | total_train_loss += loss 206 | 207 | writer.add_scalar("loss-iter", loss, batch_idx + 35454 * epoch) 208 | logging.info('Iter %d training loss = %.3f , time = %.2f \n' % (batch_idx, loss, time.time() - 
start_time)) 209 | 210 | logging.info('epoch %d total training loss = %.3f' % (epoch, total_train_loss / len(TrainImgLoader))) 211 | writer.add_scalar("loss", total_train_loss / len(TrainImgLoader), epoch) 212 | 213 | # SAVE 214 | if epoch % 1 == 0: 215 | savefilename = args.savemodel + 'finetune_' + str(epoch) + '.tar' 216 | torch.save({ 217 | 'epoch': epoch, 218 | 'state_dict': model.state_dict(), 219 | 'train_loss': total_train_loss, 220 | 'test_loss': total_test_loss, 221 | }, savefilename) 222 | 223 | 224 | if __name__ == '__main__': 225 | main() 226 | -------------------------------------------------------------------------------- /readme_images/CRP.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/CRP.png -------------------------------------------------------------------------------- /readme_images/DPM.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM.png -------------------------------------------------------------------------------- /readme_images/DPM_filters.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_filters.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/DPM/Never back down 1032.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/DPM/Never back down 1032.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/DPM/Never back down 1979.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/DPM/Never back down 1979.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/DPM/Never back down 2046.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/DPM/Never back down 2046.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/DPM/cheetah1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/DPM/cheetah1.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/DPM/face1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/DPM/face1.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/ImageA/Never back down 1032.png: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageA/Never back down 1032.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/ImageA/Never back down 1979.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageA/Never back down 1979.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/ImageA/Never back down 2046.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageA/Never back down 2046.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/ImageA/cheetah1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageA/cheetah1.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/ImageA/face1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageA/face1.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/ImageB/Never back down 1213.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageB/Never back down 1213.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/ImageB/Never back down 2040.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageB/Never back down 2040.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/ImageB/Never back down 2066.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageB/Never back down 2066.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/ImageB/cheetah2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageB/cheetah2.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/ImageB/face2.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageB/face2.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/PM/Never back down 1032.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/PM/Never back down 1032.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/PM/Never back down 1979.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/PM/Never back down 1979.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/PM/Never back down 2046.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/PM/Never back down 2046.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/PM/cheetah1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/PM/cheetah1.png -------------------------------------------------------------------------------- /readme_images/DPM_reconstruction/PM/face.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/PM/face.png -------------------------------------------------------------------------------- /readme_images/DeepPruner.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DeepPruner.png -------------------------------------------------------------------------------- /readme_images/KITTI_test_set.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/KITTI_test_set.png -------------------------------------------------------------------------------- /readme_images/kitti_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/kitti_results.png -------------------------------------------------------------------------------- /readme_images/original_patch_match.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/original_patch_match.png -------------------------------------------------------------------------------- /readme_images/original_patch_match_steps.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/original_patch_match_steps.png -------------------------------------------------------------------------------- /readme_images/rob.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/rob.png -------------------------------------------------------------------------------- /readme_images/rob_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/rob_results.png -------------------------------------------------------------------------------- /readme_images/sceneflow.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/sceneflow.png -------------------------------------------------------------------------------- /readme_images/sceneflow_results.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/sceneflow_results.png -------------------------------------------------------------------------------- /readme_images/softmax.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/softmax.png -------------------------------------------------------------------------------- /readme_images/uncertainty_vis.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/uncertainty_vis.png --------------------------------------------------------------------------------