├── DifferentiablePatchMatch
│   ├── README.md
│   ├── demo_script.py
│   └── models
│       ├── .gitignore
│       ├── __init__.py
│       ├── config.py
│       ├── feature_extractor.py
│       ├── image_reconstruction.py
│       └── patch_match.py
├── LICENSE
├── NOTICE
├── README.md
├── deeppruner
│   ├── .gitignore
│   ├── README.md
│   ├── dataloader
│   │   ├── .gitignore
│   │   ├── __init__.py
│   │   ├── kitti_collector.py
│   │   ├── kitti_loader.py
│   │   ├── kitti_submission_collector.py
│   │   ├── preprocess.py
│   │   ├── readpfm.py
│   │   ├── sceneflow_collector.py
│   │   └── sceneflow_loader.py
│   ├── finetune_kitti.py
│   ├── loss_evaluation.py
│   ├── models
│   │   ├── .gitignore
│   │   ├── __init__.py
│   │   ├── config.py
│   │   ├── deeppruner.py
│   │   ├── feature_extractor_best.py
│   │   ├── feature_extractor_fast.py
│   │   ├── patch_match.py
│   │   ├── submodules.py
│   │   ├── submodules2d.py
│   │   ├── submodules3d.py
│   │   └── utils.py
│   ├── setup_logging.py
│   ├── submission_kitti.py
│   └── train_sceneflow.py
└── readme_images
    ├── CRP.png
    ├── DPM.png
    ├── DPM_filters.png
    ├── DPM_reconstruction
    │   ├── DPM
    │   │   ├── Never back down 1032.png
    │   │   ├── Never back down 1979.png
    │   │   ├── Never back down 2046.png
    │   │   ├── cheetah1.png
    │   │   └── face1.png
    │   ├── ImageA
    │   │   ├── Never back down 1032.png
    │   │   ├── Never back down 1979.png
    │   │   ├── Never back down 2046.png
    │   │   ├── cheetah1.png
    │   │   └── face1.png
    │   ├── ImageB
    │   │   ├── Never back down 1213.png
    │   │   ├── Never back down 2040.png
    │   │   ├── Never back down 2066.png
    │   │   ├── cheetah2.png
    │   │   └── face2.png
    │   └── PM
    │       ├── Never back down 1032.png
    │       ├── Never back down 1979.png
    │       ├── Never back down 2046.png
    │       ├── cheetah1.png
    │       └── face.png
    ├── DeepPruner.png
    ├── KITTI_test_set.png
    ├── kitti_results.png
    ├── original_patch_match.png
    ├── original_patch_match_steps.png
    ├── rob.png
    ├── rob_results.png
    ├── sceneflow.png
    ├── sceneflow_results.png
    ├── softmax.png
    └── uncertainty_vis.png
/DifferentiablePatchMatch/README.md:
--------------------------------------------------------------------------------
1 | # Differentiable Patch-Match
2 |
3 | ##### Table of Contents
4 | [Patch Match](#PatchMatch)
5 | [Differentiable Patch Match](#DifferentiablePatchMatch)
6 | [Differentiable Patch Match vs Patch Match for Image Reconstruction](#Comparison)
7 | [Run Command](#run_command)
8 | [Citation](#citation)
9 |
10 |
11 | ## PatchMatch
12 |
13 | Patch Match ([Barnes et al.](https://gfx.cs.princeton.edu/pubs/Barnes_2009_PAR/)) was originally introduced as an efficient way to find dense correspondences across images for structural editing. The key idea behind it is that a **large number of random samples often lead to good guesses**. Additionally, **neighboring pixels usually have coherent matches**. Therefore, once a good match is found, it can efficiently propagate the information to the neighbors.
14 |
15 | | **Patch Match Overview** | **Patch Match Steps** |
16 | | :----------------------------------------------------------- | ------------------------------------------------------------ |
17 | | *(figure omitted; see readme_images/original_patch_match.png)* | *(figure omitted; see readme_images/original_patch_match_steps.png)* |
18 |
19 |
20 |
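The loop illustrated above can be written down compactly. Below is a toy, hypothetical sketch of the original (non-differentiable) PatchMatch loop over a nearest-neighbour field; it is **not** the code in this repository, and names such as `patch_match`, `feat_a`, and `cost` are illustrative only. For brevity it matches per-pixel feature vectors rather than full patches.

```python
# Toy CPU sketch of classic PatchMatch (random init, propagation, random search).
# Illustrative only: matches per-pixel feature vectors instead of full patches.
import numpy as np


def patch_match(feat_a, feat_b, iters=5, search_radius=8, seed=0):
    rng = np.random.default_rng(seed)
    h, w, _ = feat_a.shape
    hb, wb, _ = feat_b.shape

    # Random initialization: every pixel of A gets a random candidate location in B.
    nnf = np.stack([rng.integers(0, hb, (h, w)), rng.integers(0, wb, (h, w))], axis=-1)

    def cost(y, x, cand):
        return float(np.sum((feat_a[y, x] - feat_b[cand[0], cand[1]]) ** 2))

    for it in range(iters):
        # Alternate the scan order so good matches propagate in both directions.
        step = 1 if it % 2 == 0 else -1
        ys = range(h) if step == 1 else range(h - 1, -1, -1)
        xs = range(w) if step == 1 else range(w - 1, -1, -1)
        for y in ys:
            for x in xs:
                best = nnf[y, x].copy()
                best_cost = cost(y, x, best)

                # Propagation: adopt an already-visited neighbour's match if it is better.
                for dy, dx in ((-step, 0), (0, -step)):
                    ny, nx = y + dy, x + dx
                    if 0 <= ny < h and 0 <= nx < w:
                        cand = nnf[ny, nx]
                        c = cost(y, x, cand)
                        if c < best_cost:
                            best, best_cost = cand.copy(), c

                # Random search: sample around the current best with a shrinking radius.
                radius = search_radius
                while radius >= 1:
                    cand = best + rng.integers(-radius, radius + 1, size=2)
                    cand[0] = min(max(cand[0], 0), hb - 1)
                    cand[1] = min(max(cand[1], 0), wb - 1)
                    c = cost(y, x, cand)
                    if c < best_cost:
                        best, best_cost = cand, c
                    radius //= 2

                nnf[y, x] = best
    return nnf


# Example: a 16x16 nearest-neighbour field between two random "feature maps".
nnf = patch_match(np.random.rand(16, 16, 3), np.random.rand(16, 16, 3))
```

The Differentiable PatchMatch described next keeps this overall structure, but replaces the hard per-pixel selection with a soft, gradient-friendly one.
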
21 | ## Differentiable PatchMatch
22 |
23 | 1. In our work, we unroll generalized **PatchMatch as a recurrent neural network**, where each unrolling step is equivalent to one iteration of the algorithm. **This is important as it allows us to train our full model end-to-end.**
24 | 2. Specifically, we design the following layers:
25 | 1. **Particle sampling layer**: for each pixel i, we randomly generate k disparity values from the uniform distribution over the predicted/pre-defined search space.
26 | 2. **Propagation layer**: particles from adjacent pixels are propagated together through convolution with a predefined one-hot filter pattern, which encodes the fact that we allow each pixel to propagate particles to its 4-neighbours.
27 | 3. **Evaluation layer**: for each pixel i, matching scores are computed by taking the inner product between the left feature and the right feature: s_{i,j} = <f_0(i), f_1(i + d_{i,j})>. The best k disparity values for each pixel are carried over to the next iteration.
28 |
29 |
30 |
31 | We replace the non-differentiable argmax during evaluation with a differentiable softmax (see the sketch after this list).
32 |
33 |
34 |
35 |
36 |
37 |
38 | 
39 |
40 |
41 |
42 |
43 |
44 |
45 | 3. For usage of the Differentiable Patch Match for stereo matching, refer to our paper.
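
As a small, hypothetical illustration of that soft selection (a simplified sketch, not the exact `Evaluate` module in `models/patch_match.py`): candidate offsets are weighted by a temperature-scaled softmax over their matching scores, so the selected offset stays differentiable with respect to the scores.

```python
# Minimal sketch of soft (differentiable) offset selection, assuming `scores`
# holds one matching score per candidate offset for every pixel (B x K x H x W).
import torch
import torch.nn.functional as F


def soft_select(offsets, scores, temperature=1e7):
    # Hard PatchMatch would pick: offsets.gather(1, scores.argmax(dim=1, keepdim=True))
    weights = F.softmax(scores * temperature, dim=1)      # differentiable surrogate for argmax
    return (offsets * weights).sum(dim=1, keepdim=True)   # expected ("soft") offset per pixel


# Example: 4 candidate offsets per pixel on a 2x3 map; gradients reach the scores.
offsets = torch.randn(1, 4, 2, 3)
scores = torch.randn(1, 4, 2, 3, requires_grad=True)
soft_select(offsets, scores).sum().backward()
```

With a large temperature the softmax weights approach a one-hot vector, which is why the `softmax_temperature` in `models/config.py` is set to a very large value.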
46 |
47 |
48 |
49 |
50 | ## Differentiable Patch Match vs Patch Match for Image Reconstruction
51 |
52 | 1. In this section, we compare the proposed Differentiable Patch Match with the original Patch Match for the image reconstruction task.
53 | 2. Given two images, Image A and Image B, we reconstruct Image A from Image B, based on **Differentiable PatchMatch / PatchMatch** "A-to-B" dense patch mappings.
54 | 3. Following are the per-iteration comparison results between the Differentiable Patch Match and the original Patch Match.
55 |
56 |
57 |
58 |
59 |
60 |
61 |
62 |
63 |
64 |
65 |
66 | ### *Images*
67 |
68 |
69 | | Image A | Image B |
70 | | :-----: | :-----: |
71 | | *(see readme_images/DPM_reconstruction/ImageA)* | *(see readme_images/DPM_reconstruction/ImageB)* |
76 |
77 |
78 |
79 |
80 |
81 |
82 |
83 |
84 |
85 |
86 |
87 |
88 |
89 |
90 |
91 |
92 | ### *Reconstructed A from B*
93 |
94 |
95 | | Differentiable Patch Match | Patch Match |
96 | | :-----: | :-----: |
97 | | *(see readme_images/DPM_reconstruction/DPM)* | *(see readme_images/DPM_reconstruction/PM)* |
102 |
103 |
104 |
105 |
106 |
107 |
108 |
109 |
110 |
111 |
112 |
113 |
114 |
115 |
116 |
117 | ### RUN COMMAND
118 |
119 | 1. Create a base_dir with two folders:
120 | > base_directory \
121 | > |----- image_1 \
122 | > |----- image_2
123 |
124 |
125 |
126 | > python demo_script.py \
127 | > --base_dir <path of the base directory where the images are stored> \
128 | > --save_dir <save directory for reconstructed images>
129 |
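The reconstruction model can also be driven directly from Python. The following is a minimal sketch distilled from `demo_script.py`; the image paths are placeholders and a CUDA-capable GPU is assumed, as in the demo script.

```python
import numpy as np
import torch
import torchvision.transforms as transforms
from PIL import Image

from models.image_reconstruction import ImageReconstruction

# Build the Differentiable PatchMatch reconstruction model (no learned weights).
# A CUDA device is assumed, as in demo_script.py.
model = ImageReconstruction().cuda()

# Load the two images to match (placeholder paths).
image_a = np.asarray(Image.open('image_1/example.png').convert('RGB'))
image_b = np.asarray(Image.open('image_2/example.png').convert('RGB'))

tensor_a = transforms.ToTensor()(image_a).unsqueeze(0).cuda().float()
tensor_b = transforms.ToTensor()(image_b).unsqueeze(0).cuda().float()

with torch.no_grad():
    # Reconstruct image A from image B via the A-to-B nearest-neighbour field.
    reconstruction = model(tensor_a, tensor_b)
```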
130 |
131 |
132 |
133 |
134 | ## Citation
135 |
136 |
137 |
138 | If you use our source code, or our paper, please consider citing the following:
139 |
140 | > @inproceedings{Duggal2019ICCV,
141 | > title = {DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch},
142 | > author = {Shivam Duggal and Shenlong Wang and Wei-Chiu Ma and Rui Hu and Raquel Urtasun},
143 | > booktitle = {ICCV},
144 | > year = {2019}
145 | > }
146 |
147 |
148 |
149 | Correspondence to Shivam Duggal, Shenlong Wang, Wei-Chiu Ma
150 |
--------------------------------------------------------------------------------
/DifferentiablePatchMatch/demo_script.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | from PIL import Image
18 | import torch
19 | import random
20 | import skimage
21 | import numpy as np
22 | from models.image_reconstruction import ImageReconstruction
23 | import os
24 | import torchvision.transforms as transforms
25 | import argparse
26 | import matplotlib.pyplot as plt
27 |
28 | parser = argparse.ArgumentParser(description='Differentiable PatchMatch')
29 | parser.add_argument('--base_dir', default='./',
30 | help='path of base directory where images are stored.')
31 | parser.add_argument('--save_dir', default='./',
32 | help='save directory')
33 | parser.add_argument('--no-cuda', action='store_true', default=False,
34 | help='disables CUDA')
35 | parser.add_argument('--seed', type=int, default=1, metavar='S',
36 | help='random seed (default: 1)')
37 |
38 | args = parser.parse_args()
39 | torch.backends.cudnn.benchmark=True
40 | args.cuda = not args.no_cuda and torch.cuda.is_available()
41 |
42 | if args.cuda:
43 | torch.manual_seed(args.seed)
44 | np.random.seed(args.seed)
45 | random.seed(args.seed)
46 | torch.cuda.manual_seed(args.seed)
47 | torch.backends.cudnn.deterministic = True
48 |
49 | model = ImageReconstruction()
50 |
51 | if args.cuda:
52 | model.cuda()
53 |
54 |
55 | def main():
56 |
57 | base_dir = args.base_dir
58 | for file1, file2 in zip(sorted(os.listdir(base_dir+'/image_1')), sorted(os.listdir(base_dir+'/image_2'))):
59 |
60 | image_1_image_path = base_dir + '/image_1/' + file1
61 | image_2_image_path = base_dir + '/image_2/' + file2
62 |
63 | image_1 = np.asarray(Image.open(image_1_image_path).convert('RGB'))
64 | image_2 = np.asarray(Image.open(image_2_image_path).convert('RGB'))
65 |
66 | image_1 = transforms.ToTensor()(image_1).unsqueeze(0).float().to('cuda' if args.cuda else 'cpu')
67 | image_2 = transforms.ToTensor()(image_2).unsqueeze(0).float().to('cuda' if args.cuda else 'cpu')
68 |
69 | reconstruction = model(image_1, image_2)
70 |
71 | plt.imsave(os.path.join(args.save_dir, image_1_image_path.split('/')[-1]),
72 | np.asarray(reconstruction[0].permute(1,2,0).data.cpu()*256).astype('uint16'))
73 |
74 |
75 |
76 | if __name__ == '__main__':
77 | with torch.no_grad():
78 | main()
--------------------------------------------------------------------------------
/DifferentiablePatchMatch/models/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 |
--------------------------------------------------------------------------------
/DifferentiablePatchMatch/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/DifferentiablePatchMatch/models/__init__.py
--------------------------------------------------------------------------------
/DifferentiablePatchMatch/models/config.py:
--------------------------------------------------------------------------------
1 |
2 | # ---------------------------------------------------------------------------
3 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
4 | #
5 | # Copyright (c) 2019 Uber Technologies, Inc.
6 | #
7 | # Licensed under the Uber Non-Commercial License (the "License");
8 | # you may not use this file except in compliance with the License.
9 | # You may obtain a copy of the License at the root directory of this project.
10 | #
11 | # See the License for the specific language governing permissions and
12 | # limitations under the License.
13 | #
14 | # Written by Shivam Duggal
15 | # ---------------------------------------------------------------------------
16 |
17 | from __future__ import print_function
18 |
19 | class obj(object):
20 | def __init__(self, d):
21 | for key, value in d.items():
22 | if isinstance(value, (list, tuple)):
23 | setattr(self, key, [obj(x) if isinstance(x, dict) else x for x in value])
24 | else:
25 | setattr(self, key, obj(value) if isinstance(value, dict) else value)
26 |
27 | config = {
28 | "patch_match_args": {
29 | # sample count refers to random sampling stage of generalized PM.
30 | # Number of random samples generated: (sample_count+1) * (sample_count+1)
31 | # we generate (sample_count+1) samples in x direction, and (sample_count+1) samples in y direction,
32 | # and then perform a meshgrid-like operation to generate (sample_count+1) * (sample_count+1) samples.
33 | "sample_count": 1,
34 |
35 | "iteration_count": 21,
36 | "propagation_filter_size": 3,
37 | "propagation_type": "faster_filter_3_propagation", # set to None to use the generic Propagation module (configurable filter size) instead of the fast fixed-size-3 version
38 | "softmax_temperature": 10000000000, # softmax temperature for evaluation. A larger temperature leads to sharper (closer to argmax) output.
39 | "random_search_window_size": [100,100], # search range around evaluated offsets after every iteration.
40 | "evaluation_type": "softmax"
41 | },
42 |
43 | "feature_extractor_filter_size": 7
44 |
45 | }
46 |
47 |
48 | config = obj(config)
--------------------------------------------------------------------------------
/DifferentiablePatchMatch/models/feature_extractor.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 |
22 | class feature_extractor(nn.Module):
23 | def __init__(self, filter_size):
24 | super(feature_extractor, self).__init__()
25 |
26 | self.filter_size = filter_size
27 |
28 | def forward(self, left_input, right_input):
29 | """
30 | Feature Extractor
31 |
32 | Description: Aggregates the RGB values from the neighbouring pixels in the window (filter_size * filter_size).
33 | No weights are learnt for this feature extractor.
34 |
35 | Args:
36 | :param left_input: Left Image
37 | :param right_input: Right Image
38 |
39 | Returns:
40 | :left_features: Left Image features
41 | :right_features: Right Image features
42 | :one_hot_filter: Convolution filters used to aggregate neighbouring RGB values onto the center pixel.
43 | one_hot_filter.shape = (filter_size * filter_size, 1, 1, filter_size, filter_size)
44 | """
45 |
46 | device = left_input.get_device()
47 |
48 | label = torch.arange(0, self.filter_size * self.filter_size, device=device).repeat(
49 | self.filter_size * self.filter_size).view(
50 | self.filter_size * self.filter_size, 1, 1, self.filter_size, self.filter_size)
51 |
52 | one_hot_filter = torch.zeros_like(label).scatter_(0, label, 1).float()
53 |
54 | left_features = F.conv3d(left_input.unsqueeze(1), one_hot_filter,
55 | padding=(0, self.filter_size // 2, self.filter_size // 2))
56 | right_features = F.conv3d(right_input.unsqueeze(1), one_hot_filter,
57 | padding=(0, self.filter_size // 2, self.filter_size // 2))
58 |
59 | left_features = left_features.view(left_features.size()[0],
60 | left_features.size()[1] * left_features.size()[2],
61 | left_features.size()[3],
62 | left_features.size()[4])
63 |
64 | right_features = right_features.view(right_features.size()[0],
65 | right_features.size()[1] * right_features.size()[2],
66 | right_features.size()[3],
67 | right_features.size()[4])
68 |
69 | return left_features, right_features, one_hot_filter
70 |
--------------------------------------------------------------------------------
/DifferentiablePatchMatch/models/image_reconstruction.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 | from models.patch_match import PatchMatch
21 | from models.feature_extractor import feature_extractor
22 | from models.config import config as args
23 |
24 |
25 | class Reconstruct(nn.Module):
26 | def __init__(self, filter_size):
27 | super(Reconstruct, self).__init__()
28 | self.filter_size = filter_size
29 |
30 | def forward(self, right_input, offset_x, offset_y, x_coordinate, y_coordinate, neighbour_extraction_filter):
31 | """
32 | Reconstruct the left image using the NNF (represented by the offsets and the xy coordinates).
33 | We apply patch voting on the offset field before reconstruction in order to
34 | generate a smooth reconstruction.
35 | Args:
36 | :right_input: Right Image
37 | :offset_x: horizontal offset to generate the NNF.
38 | :offset_y: vertical offset to generate the NNF.
39 | :x_coordinate: X coordinate
40 | :y_coordinate: Y coordinate
41 |
42 | Returns:
43 | :reconstruction: Reconstruction of the left image (warped from the right image)
44 | """
45 |
46 | pad_size = self.filter_size // 2
47 | smooth_offset_x = nn.ReflectionPad2d(
48 | (pad_size, pad_size, pad_size, pad_size))(offset_x)
49 | smooth_offset_y = nn.ReflectionPad2d(
50 | (pad_size, pad_size, pad_size, pad_size))(offset_y)
51 |
52 | smooth_offset_x = F.conv2d(smooth_offset_x,
53 | neighbour_extraction_filter,
54 | padding=(pad_size, pad_size))[:, :, pad_size:-pad_size, pad_size:-pad_size]
55 |
56 | smooth_offset_y = F.conv2d(smooth_offset_y,
57 | neighbour_extraction_filter,
58 | padding=(pad_size, pad_size))[:, :, pad_size:-pad_size, pad_size:-pad_size]
59 |
60 | coord_x = torch.clamp(
61 | x_coordinate - smooth_offset_x,
62 | min=0,
63 | max=smooth_offset_x.size()[3] - 1)
64 |
65 | coord_y = torch.clamp(
66 | y_coordinate - smooth_offset_y,
67 | min=0,
68 | max=smooth_offset_x.size()[2] - 1)
69 |
70 | coord_x -= coord_x.size()[3] / 2
71 | coord_x /= (coord_x.size()[3] / 2)
72 |
73 | coord_y -= coord_y.size()[2] / 2
74 | coord_y /= (coord_y.size()[2] / 2)
75 |
76 | grid = torch.cat((coord_x.unsqueeze(4), coord_y.unsqueeze(4)), dim=4)
77 | grid = grid.view(grid.size()[0] * grid.size()[1], grid.size()[2], grid.size()[3], grid.size()[4])
78 | reconstruction = F.grid_sample(right_input.repeat(grid.size()[0], 1, 1, 1), grid)
79 | reconstruction = torch.mean(reconstruction, dim=0).unsqueeze(0)
80 |
81 | return reconstruction
82 |
83 |
84 | class ImageReconstruction(nn.Module):
85 | def __init__(self):
86 | super(ImageReconstruction, self).__init__()
87 |
88 | self.patch_match = PatchMatch(args.patch_match_args)
89 |
90 | filter_size = args.feature_extractor_filter_size
91 | self.feature_extractor = feature_extractor(filter_size)
92 | self.reconstruct = Reconstruct(filter_size)
93 |
94 | def forward(self, left_input, right_input):
95 | """
96 | ImageReconstruction:
97 | Description: This class performs the task of reconstructing the left image using the data of the other image,
98 | by finding correspondences (NNF) between the two images.
99 | The images can be any random images with some overlap between the two to assist
100 | the correspondence matching.
101 | For feature_extractor, we just use the RGB features of a (self.filter_size * self.filter_size) patch
102 | around each pixel.
103 | For finding the correspondences, we use the Differentiable PatchMatch.
104 | ** Note: There is no assumption of rectification between the two images. **
105 | ** Note: The words 'left' and 'right' do not have any significance.**
106 |
107 |
108 | Args:
109 | :left_input: Left Image (Image 1)
110 | :right_input: Right Image (Image 2)
111 |
112 | Returns:
113 | :reconstruction: Reconstructed left image.
114 | """
115 |
116 | left_features, right_features, neighbour_extraction_filter = self.feature_extractor(left_input, right_input)
117 | offset_x, offset_y, x_coordinate, y_coordinate = self.patch_match(left_features, right_features)
118 |
119 | reconstruction = self.reconstruct(right_input,
120 | offset_x, offset_y,
121 | x_coordinate, y_coordinate,
122 | neighbour_extraction_filter.squeeze(1))
123 |
124 | return reconstruction
125 |
--------------------------------------------------------------------------------
/DifferentiablePatchMatch/models/patch_match.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 |
22 | class RandomSampler(nn.Module):
23 | def __init__(self, device, number_of_samples):
24 | super(RandomSampler, self).__init__()
25 |
26 | # Number of offset samples generated by this function: (number_of_samples+1) * (number_of_samples+1)
27 | # we generate (number_of_samples+1) samples in x direction, and (number_of_samples+1) samples in y direction,
28 | # and then perform a meshgrid-like operation to generate (number_of_samples+1) * (number_of_samples+1) samples
29 | self.number_of_samples = number_of_samples
30 | self.range_multiplier = torch.arange(0.0, number_of_samples + 1, 1, device=device).view(
31 | number_of_samples + 1, 1, 1)
32 |
33 | def forward(self, min_offset_x, max_offset_x, min_offset_y, max_offset_y):
34 | """
35 | Random Sampler:
36 | Given the search range per pixel (defined by: [[lx(i), ux(i)], [ly(i), uy(i)]]),
37 | where lx = lower_bound of the horizontal offset,
38 | ux = upper_bound of the horizontal offset,
39 | ly = lower_bound of the vertical offset,
40 | uy = upper_bound of the vertical offset, for all pixels i.)
41 | random sampler generates samples from this search range.
42 | First the search range is discretized into `number_of_samples` buckets,
43 | then a random sample is generated from each random bucket.
44 | ** Discretization is done in both xy directions. ** (similar to meshgrid)
45 |
46 | Args:
47 | :min_offset_x: Min horizontal offset of the search range.
48 | :max_offset_x: Max horizontal offset of the search range.
49 | :min_offset_y: Min vertical offset of the search range.
50 | :max_offset_y: Max vertical offset of the search range.
51 | Returns:
52 | :offset_x: samples representing offset in the horizontal direction.
53 | :offset_y: samples representing offset in the vertical direction.
54 | """
55 |
56 | device = min_offset_x.get_device()
57 | noise = torch.rand(min_offset_x.repeat(1, self.number_of_samples + 1, 1, 1).size(), device=device)
58 |
59 | offset_x = min_offset_x + ((max_offset_x - min_offset_x) / (self.number_of_samples + 1)) * \
60 | (self.range_multiplier + noise)
61 | offset_y = min_offset_y + ((max_offset_y - min_offset_y) / (self.number_of_samples + 1)) * \
62 | (self.range_multiplier + noise)
63 |
64 | offset_x = offset_x.unsqueeze_(1).expand(-1, offset_y.size()[1], -1, -1, -1)
65 | offset_x = offset_x.contiguous().view(
66 | offset_x.size()[0], offset_x.size()[1] * offset_x.size()[2], offset_x.size()[3], offset_x.size()[4])
67 |
68 | offset_y = offset_y.unsqueeze_(2).expand(-1, -1, offset_y.size()[1], -1, -1)
69 | offset_y = offset_y.contiguous().view(
70 | offset_y.size()[0], offset_y.size()[1] * offset_y.size()[2], offset_y.size()[3], offset_y.size()[4])
71 |
72 | return offset_x, offset_y
73 |
74 |
75 | class Evaluate(nn.Module):
76 | def __init__(self, left_features, filter_size, evaluation_type='softmax', temperature=10000):
77 | super(Evaluate, self).__init__()
78 | self.temperature = temperature
79 | self.filter_size = filter_size
80 | self.softmax = torch.nn.Softmax(dim=1)
81 | self.evaluation_type = evaluation_type
82 |
83 | device = left_features.get_device()
84 | self.left_x_coordinate = torch.arange(0.0, left_features.size()[3], device=device).repeat(
85 | left_features.size()[2]).view(left_features.size()[2], left_features.size()[3])
86 |
87 | self.left_x_coordinate = torch.clamp(self.left_x_coordinate, min=0, max=left_features.size()[3] - 1)
88 | self.left_x_coordinate = self.left_x_coordinate.expand(left_features.size()[0], -1, -1).unsqueeze(1)
89 |
90 | self.left_y_coordinate = torch.arange(0.0, left_features.size()[2], device=device).unsqueeze(1).repeat(
91 | 1, left_features.size()[3]).view(left_features.size()[2], left_features.size()[3])
92 |
93 | self.left_y_coordinate = torch.clamp(self.left_y_coordinate, min=0, max=left_features.size()[2] - 1)  # clamp y by the image height (dim 2)
94 | self.left_y_coordinate = self.left_y_coordinate.expand(left_features.size()[0], -1, -1).unsqueeze(1)
95 |
96 | def forward(self, left_features, right_features, offset_x, offset_y):
97 | """
98 | PatchMatch Evaluation Block
99 | Description: For each pixel i, matching scores are computed by taking the inner product between the
100 | left feature and the right feature: score(i,j) = <feature_left(i), feature_right(i + disparity(i,j))>
101 | for all candidates j. The best k disparity value for each pixel is carried towards the next iteration.
102 |
103 | As per implementation,
104 | the complete disparity search range is discretized into intervals in
105 | DisparityInitialization() function. Corresponding to each disparity interval, we have multiple samples
106 | to evaluate. The best disparity sample per interval is the output of the function.
107 |
108 | Args:
109 | :left_features: Left Image Feature Map
110 | :right_features: Right Image Feature Map
111 | :offset_x: samples representing offset in the horizontal direction.
112 | :offset_y: samples representing offset in the vertical direction.
113 |
114 | Returns:
115 | :offset_x: horizontal offset evaluated as the best offset to generate NNF.
116 | :offset_y: vertical offset evaluated as the best offset to generate NNF.
117 |
118 | """
119 |
120 | right_x_coordinate = torch.clamp(self.left_x_coordinate - offset_x, min=0, max=left_features.size()[3] - 1)
121 | right_y_coordinate = torch.clamp(self.left_y_coordinate - offset_y, min=0, max=left_features.size()[2] - 1)
122 |
123 | right_x_coordinate -= right_x_coordinate.size()[3] / 2
124 | right_x_coordinate /= (right_x_coordinate.size()[3] / 2)
125 | right_y_coordinate -= right_y_coordinate.size()[2] / 2
126 | right_y_coordinate /= (right_y_coordinate.size()[2] / 2)
127 |
128 | samples = torch.cat((right_x_coordinate.unsqueeze(4), right_y_coordinate.unsqueeze(4)), dim=4)
129 | samples = samples.view(samples.size()[0] * samples.size()[1],
130 | samples.size()[2],
131 | samples.size()[3],
132 | samples.size()[4])
133 |
134 | offset_strength = torch.mean(-1.0 * (torch.abs(left_features.expand(
135 | offset_x.size()[1], -1, -1, -1) - F.grid_sample(right_features.expand(
136 | offset_x.size()[1], -1, -1, -1), samples))), dim=1) * self.temperature
137 |
138 | offset_strength = offset_strength.view(left_features.size()[0],
139 | offset_strength.size()[0] // left_features.size()[0],
140 | offset_strength.size()[1],
141 | offset_strength.size()[2])
142 |
143 | if self.evaluation_type == "softmax":
144 | offset_strength = torch.softmax(offset_strength, dim=1)
145 | offset_x = torch.sum(offset_x * offset_strength, dim=1).unsqueeze(1)
146 | offset_y = torch.sum(offset_y * offset_strength, dim=1).unsqueeze(1)
147 | else:
148 | offset_strength = torch.argmax(offset_strength, dim=1).unsqueeze(1)
149 | offset_x = torch.gather(offset_x, index=offset_strength, dim=1)
150 | offset_y = torch.gather(offset_y, index=offset_strength, dim=1)
151 |
152 | return offset_x, offset_y
153 |
154 |
155 | class Propagation(nn.Module):
156 | def __init__(self, device, filter_size):
157 | super(Propagation, self).__init__()
158 | self.filter_size = filter_size
159 | label = torch.arange(0, self.filter_size, device=device).repeat(self.filter_size).view(
160 | self.filter_size, 1, 1, 1, self.filter_size)
161 |
162 | self.one_hot_filter_h = torch.zeros_like(label).scatter_(0, label, 1).float()
163 |
164 | label = torch.arange(0, self.filter_size, device=device).repeat(self.filter_size).view(
165 | self.filter_size, 1, 1, self.filter_size, 1).long()
166 |
167 | self.one_hot_filter_v = torch.zeros_like(label).scatter_(0, label, 1).float()
168 |
169 | def forward(self, offset_x, offset_y, device, propagation_type="horizontal"):  # device kept for interface parity with PropagationFaster and the call in PatchMatch.forward
170 | """
171 | PatchMatch Propagation Block
172 | Description: Particles from adjacent pixels are propagated together through convolution with a
173 | one-hot filter, which encodes the fact that we allow each pixel
174 | to propagate particles to its 4-neighbours.
175 | Args:
176 | :offset_x: samples representing offset in the horizontal direction.
177 | :offset_y: samples representing offset in the vertical direction.
178 | :device: Cuda/ CPU device
179 | :propagation_type (default:"horizontal"): In order to be memory efficient, we use separable convolutions
180 | for propagation.
181 |
182 | Returns:
183 | :aggregated_offset_x: Horizontal offset samples aggregated from the neighbours.
184 | :aggregated_offset_y: Vertical offset samples aggregated from the neighbours.
185 |
186 | """
187 |
188 | offset_x = offset_x.view(offset_x.size()[0], 1, offset_x.size()[1], offset_x.size()[2], offset_x.size()[3])
189 | offset_y = offset_y.view(offset_y.size()[0], 1, offset_y.size()[1], offset_y.size()[2], offset_y.size()[3])
190 |
191 | if propagation_type == "horizontal":
192 | aggregated_offset_x = F.conv3d(offset_x, self.one_hot_filter_h, padding=(0, 0, self.filter_size // 2))
193 | aggregated_offset_y = F.conv3d(offset_y, self.one_hot_filter_h, padding=(0, 0, self.filter_size // 2))
194 |
195 | else:
196 | aggregated_offset_x = F.conv3d(offset_x, self.one_hot_filter_v, padding=(0, self.filter_size // 2, 0))
197 | aggregated_offset_y = F.conv3d(offset_y, self.one_hot_filter_v, padding=(0, self.filter_size // 2, 0))
198 |
199 | aggregated_offset_x = aggregated_offset_x.permute([0, 2, 1, 3, 4])
200 | aggregated_offset_x = aggregated_offset_x.contiguous().view(
201 | aggregated_offset_x.size()[0],
202 | aggregated_offset_x.size()[1] * aggregated_offset_x.size()[2],
203 | aggregated_offset_x.size()[3],
204 | aggregated_offset_x.size()[4])
205 |
206 | aggregated_offset_y = aggregated_offset_y.permute([0, 2, 1, 3, 4])
207 | aggregated_offset_y = aggregated_offset_y.contiguous().view(
208 | aggregated_offset_y.size()[0],
209 | aggregated_offset_y.size()[1] * aggregated_offset_y.size()[2],
210 | aggregated_offset_y.size()[3],
211 | aggregated_offset_y.size()[4])
212 |
213 | return aggregated_offset_x, aggregated_offset_y
214 |
215 |
216 | class PropagationFaster(nn.Module):
217 | def __init__(self):
218 | super(PropagationFaster, self).__init__()
219 |
220 | def forward(self, offset_x, offset_y, device, propagation_type="horizontal"):
221 | """
222 | Faster version of PatchMatch Propagation Block
223 | This version uses a fixed propagation filter size of size 3. This implementation is not recommended
224 | and is used only to do the propagation faster.
225 |
226 | Description: Particles from adjacent pixels are propagated together through convolution with a
227 | one-hot filter, which encodes the fact that we allow each pixel
228 | to propagate particles to its 4-neighbours.
229 | Args:
230 | :offset_x: samples representing offset in the horizontal direction.
231 | :offset_y: samples representing offset in the vertical direction.
232 | :device: Cuda/ CPU device
233 | :propagation_type (default:"horizontal"): In order to be memory efficient, we use separable convolutions
234 | for propagation.
235 |
236 | Returns:
237 | :aggregated_offset_x: Horizontal offset samples aggregated from the neighbours.
238 | :aggregated_offset_y: Vertical offset samples aggregated from the neighbours.
239 |
240 | """
241 |
242 | self.vertical_zeros = torch.zeros((offset_x.size()[0], offset_x.size()[1], 1, offset_x.size()[3])).to(device)
243 | self.horizontal_zeros = torch.zeros((offset_x.size()[0], offset_x.size()[1], offset_x.size()[2], 1)).to(device)
244 |
245 | if propagation_type == "horizontal":
246 | offset_x = torch.cat((torch.cat((self.horizontal_zeros, offset_x[:, :, :, :-1]), dim=3),
247 | offset_x,
248 | torch.cat((offset_x[:, :, :, 1:], self.horizontal_zeros), dim=3)), dim=1)
249 | offset_y = torch.cat((torch.cat((self.horizontal_zeros, offset_y[:, :, :, :-1]), dim=3),
250 | offset_y,
251 | torch.cat((offset_y[:, :, :, 1:], self.horizontal_zeros), dim=3)), dim=1)
252 |
253 | else:
254 | offset_x = torch.cat((torch.cat((self.vertical_zeros, offset_x[:, :, :-1, :]), dim=2),
255 | offset_x,
256 | torch.cat((offset_x[:, :, 1:, :], self.vertical_zeros), dim=2)), dim=1)
257 | offset_y = torch.cat((torch.cat((self.vertical_zeros, offset_y[:, :, :-1, :]), dim=2),
258 | offset_y,
259 | torch.cat((offset_y[:, :, 1:, :], self.vertical_zeros), dim=2)), dim=1)
260 |
261 | return offset_x, offset_y
262 |
263 |
264 | class PatchMatch(nn.Module):
265 | def __init__(self, patch_match_args):
266 | super(PatchMatch, self).__init__()
267 | self.propagation_filter_size = patch_match_args.propagation_filter_size
268 | self.number_of_samples = patch_match_args.sample_count
269 | self.iteration_count = patch_match_args.iteration_count
270 | self.evaluation_type = patch_match_args.evaluation_type
271 | self.softmax_temperature = patch_match_args.softmax_temperature
272 | self.propagation_type = patch_match_args.propagation_type
273 |
274 | self.window_size_x = patch_match_args.random_search_window_size[0]
275 | self.window_size_y = patch_match_args.random_search_window_size[1]
276 |
277 | def forward(self, left_features, right_features):
278 | """
279 | Differentiable PatchMatch Block
280 | Description: In this work, we unroll generalized PatchMatch as a recurrent neural network,
281 | where each unrolling step is equivalent to one iteration of the algorithm.
282 | This is important as it allows us to train our full model end-to-end.
283 | Specifically, we design the following layers:
284 | - Initialization or Particle Sampling
285 | - Propagation
286 | - Evaluation
287 | Args:
288 | :left_features: Left Image feature map
289 | :right_features: Right image feature map
290 |
291 | Returns:
292 | :offset_x: offset for each pixel in the left_features corresponding to the
293 | right_features in the horizontal direction.
294 | :offset_y: offset for each pixel in the left_features corresponding to the
295 | right_features in the vertical direction.
296 |
297 | :x_coordinate: X coordinate corresponding to each pixel.
298 | :y_coordinate: Y coordinate corresponding to each pixel.
299 |
300 | (The offsets and the xy coordinates returned are used to generate the NNF later for reconstruction.)
301 |
302 | """
303 |
304 | device = left_features.get_device()
305 | if self.propagation_type == "faster_filter_3_propagation":
306 | self.propagation = PropagationFaster()
307 | else:
308 | self.propagation = Propagation(device, self.propagation_filter_size)
309 |
310 | self.evaluate = Evaluate(left_features, self.propagation_filter_size,
311 | self.evaluation_type, self.softmax_temperature)
312 | self.uniform_sampler = RandomSampler(device, self.number_of_samples)
313 |
314 | min_offset_x = torch.zeros((left_features.size()[0], 1, left_features.size()[2],
315 | left_features.size()[3])).to(device) - left_features.size()[3]
316 | max_offset_x = min_offset_x + 2 * left_features.size()[3]
317 | min_offset_y = min_offset_x + left_features.size()[3] - left_features.size()[2]
318 | max_offset_y = min_offset_y + 2 * left_features.size()[2]
319 |
320 | for prop_iter in range(self.iteration_count):
321 | offset_x, offset_y = self.uniform_sampler(min_offset_x, max_offset_x,
322 | min_offset_y, max_offset_y)
323 |
324 | offset_x, offset_y = self.propagation(offset_x, offset_y, device, "horizontal")
325 | offset_x, offset_y = self.evaluate(left_features,
326 | right_features,
327 | offset_x, offset_y)
328 |
329 | offset_x, offset_y = self.propagation(offset_x, offset_y, device, "vertical")
330 | offset_x, offset_y = self.evaluate(left_features,
331 | right_features,
332 | offset_x, offset_y)
333 |
334 | min_offset_x = torch.clamp(offset_x - self.window_size_x // 2, min=-left_features.size()[3],
335 | max=left_features.size()[3])
336 | max_offset_x = torch.clamp(offset_x + self.window_size_x // 2, min=-left_features.size()[3],
337 | max=left_features.size()[3])
338 | min_offset_y = torch.clamp(offset_y - self.window_size_y // 2, min=-left_features.size()[2],
339 | max=left_features.size()[2])
340 | max_offset_y = torch.clamp(offset_y + self.window_size_y // 2, min=-left_features.size()[2],
341 | max=left_features.size()[2])
342 |
343 | return offset_x, offset_y, self.evaluate.left_x_coordinate, self.evaluate.left_y_coordinate
344 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by the text below.
2 |
3 | "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
4 |
5 | "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
6 |
7 | "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
8 |
9 | "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
10 |
11 | "Work" shall mean the work of authorship, whether in Source or Object form, made available under this License.
12 |
13 | This License governs use of the accompanying Work, and your use of the Work constitutes acceptance of this License.
14 |
15 | You may use this Work for any non-commercial purpose, subject to the restrictions in this License. Some purposes which can be non-commercial are teaching, academic research, and personal experimentation. You may also distribute this Work with books or other teaching materials, or publish the Work on websites, that are intended to teach the use of the Work.
16 |
17 | You may not use or distribute this Work, or any derivative works, outputs, or results from the Work, in any form for commercial purposes. Non-exhaustive examples of commercial purposes would be running business operations, licensing, leasing, or selling the Work, or distributing the Work for use with commercial products.
18 |
19 | You may modify this Work and distribute the modified Work for non-commercial purposes, however, you may not grant rights to the Work or derivative works that are broader than or in conflict with those provided by this License. For example, you may not distribute modifications of the Work under terms that would permit commercial use, or under terms that purport to require the Work or derivative works to be sublicensed to others.
20 |
21 | In return, we require that you agree:
22 |
23 | 1. Not to remove any copyright or other notices from the Work.
24 |
25 | 2. That if you distribute the Work in Source or Object form, you will include a verbatim copy of this License.
26 |
27 | 3. That if you distribute derivative works of the Work in Source form, you do so only under a license that includes all of the provisions of this License and is not in conflict with this License, and if you distribute derivative works of the Work solely in Object form you do so only under a license that complies with this License.
28 |
29 | 4. That if you have modified the Work or created derivative works from the Work, and distribute such modifications or derivative works, you will cause the modified files to carry prominent notices so that recipients know that they are not receiving the original Work. Such notices must state: (i) that you have changed the Work; and (ii) the date of any changes.
30 |
31 | 5. If you publicly use the Work or any output or result of the Work, you will provide a notice with such use that provides any person who uses, views, accesses, interacts with, or is otherwise exposed to the Work (i) with information of the nature of the Work, (ii) with a link to the Work, and (iii) a notice that the Work is available under this License.
32 |
33 | 6. THAT THE WORK COMES "AS IS", WITH NO WARRANTIES. THIS MEANS NO EXPRESS, IMPLIED OR STATUTORY WARRANTY, INCLUDING WITHOUT LIMITATION, WARRANTIES OF MERCHANTABILITY OR FITNESS FOR A PARTICULAR PURPOSE OR ANY WARRANTY OF TITLE OR NON-INFRINGEMENT. ALSO, YOU MUST PASS THIS DISCLAIMER ON WHENEVER YOU DISTRIBUTE THE WORK OR DERIVATIVE WORKS.
34 |
35 | 7. THAT NEITHER UBER TECHNOLOGIES, INC. NOR ANY OF ITS AFFILIATES, SUPPLIERS, SUCCESSORS, NOR ASSIGNS WILL BE LIABLE FOR ANY DAMAGES RELATED TO THE WORK OR THIS LICENSE, INCLUDING DIRECT, INDIRECT, SPECIAL, CONSEQUENTIAL OR INCIDENTAL DAMAGES, TO THE MAXIMUM EXTENT THE LAW PERMITS, NO MATTER WHAT LEGAL THEORY IT IS BASED ON. ALSO, YOU MUST PASS THIS LIMITATION OF LIABILITY ON WHENEVER YOU DISTRIBUTE THE WORK OR DERIVATIVE WORKS.
36 |
37 | 8. That if you sue anyone over patents that you think may apply to the Work or anyone's use of the Work, your license to the Work ends automatically.
38 |
39 | 9. That your rights under the License end automatically if you breach it in any way.
40 |
41 | 10. Uber Technologies, Inc. reserves all rights not expressly granted to you in this License.
42 |
--------------------------------------------------------------------------------
/NOTICE:
--------------------------------------------------------------------------------
1 | "DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch" includes derived work from:
2 |
3 | 1. PSMNet (https://github.com/JiaRenChang/PSMNet)
4 |
5 | MIT License
6 |
7 | Copyright (c) 2018 Jia-Ren Chang
8 |
9 | Permission is hereby granted, free of charge, to any person obtaining a copy
10 | of this software and associated documentation files (the "Software"), to deal
11 | in the Software without restriction, including without limitation the rights
12 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 | copies of the Software, and to permit persons to whom the Software is
14 | furnished to do so, subject to the following conditions:
15 |
16 | The above copyright notice and this permission notice shall be included in all
17 | copies or substantial portions of the Software.
18 |
19 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
25 | SOFTWARE.
26 |
27 |
28 | The derived work can be found in the files:
29 |
30 | 1. DeepPruner/dataloader/readpfm.py
31 | 2. DeepPruner/dataloader/preprocess.py
32 | 3. DeepPruner/dataloader/kitti_submission_collector.py
33 | 4. DeepPruner/dataloader/kitti_loader.py
34 | 5. DeepPruner/dataloader/kitti_collector.py
35 | 6. DeepPruner/dataloader/sceneflow_loader.py
36 | 7. DeepPruner/dataloader/sceneflow_collector.py
37 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
2 |
3 | This repository releases code for our paper [DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch](https://arxiv.org/abs/1909.05845).
4 |
5 | ##### Table of Contents
6 | [DeepPruner](#DeepPruner)
7 | [Differentiable Patch Match](#DifferentiablePatchMatch)
8 | [Requirements (Major Dependencies)](#Requirements)
9 | [Citation](#Citation)
10 |
11 |
12 |
13 | ### **DeepPruner**
14 |
15 | + An efficient "Real Time Stereo Matching" algorithm, which takes two images as input and outputs a disparity (or depth) map.
16 |
17 |
18 |
19 | 
20 |
21 |
22 |
23 | + Results/ Metrics:
24 |
25 | + [**KITTI**](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo): **Results competitive to SOTA, while being real-time (8x faster than SOTA). SOTA among published real-time algorithms**.
26 |
27 |
28 | 
29 | 
30 | 
31 |
32 |
33 | + [**ETH3D**](https://www.eth3d.net/low_res_two_view?mask=all&metric=bad-2-0): **SOTA among all ROB entries**.
34 |
35 | + **SceneFlow**: **2nd among all published algorithms, while being 8x faster than the 1st.**
36 |
37 |
38 |
39 |
40 |
41 |
42 |
43 | + [**Robust Vision Challenge**](http://www.robustvision.net/index.php): **Overall ranking 1st**.
44 |
45 |
46 |
47 |
48 |
49 | + Runtime: **62ms** (for DeepPruner-fast), **180ms** (for DeepPruner-best)
50 |
51 | + Cuda Memory Requirements: **805MB** (for DeepPruner-best)
52 |
53 |
54 |
55 |
56 | ### **Differentiable Patch Match**
57 | + A fast algorithm for finding dense nearest-neighbor correspondences between image patches.
58 | Differentiable version of the generalized Patch Match algorithm. ([Barnes et al.](https://gfx.cs.princeton.edu/pubs/Barnes_2010_TGP/index.php))
59 |
60 |
61 |
62 |
63 |
64 | More details in the corresponding folder README.
65 |
66 |
67 |
68 | ## Requirements (Major Dependencies)
69 | + Pytorch (0.4.1+)
70 | + Python2.7
71 | + torchvision (0.2.0+)
72 |
73 |
74 |
75 |
76 | ## Citation
77 |
78 | If you use our source code, or our paper, please consider citing the following:
79 | > @inproceedings{Duggal2019ICCV,
80 | title = {DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch},
81 | author = {Shivam Duggal and Shenlong Wang and Wei-Chiu Ma and Rui Hu and Raquel Urtasun},
82 | booktitle = {ICCV},
83 | year = {2019}
84 | }
85 |
86 | Correspondence to Shivam Duggal, Shenlong Wang, Wei-Chiu Ma
87 |
--------------------------------------------------------------------------------
/deeppruner/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.log
--------------------------------------------------------------------------------
/deeppruner/README.md:
--------------------------------------------------------------------------------
1 | # DeepPruner
2 |
3 | This is the code repository for **DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch**.
4 | 
5 |
6 |
7 | ##### Table of Contents
8 |
9 | [Requirements](#Requirements)
10 | [License](#License)
11 | [Model Weights](#Weights)
12 | [Training and Evaluation](#TrainingEvaluation)
13 | - [KITTI](#KITTI)
14 | - [Sceneflow](#Sceneflow)
15 | - [Robust Vision Challenge](#ROB)
16 |
17 |
18 | ## Requirements
19 |
20 | + Pytorch (0.4.1+)
21 | + python (2.7)
22 | + scikit-image
23 | + tensorboardX
24 | + torchvision (0.2.0+)
25 |
26 |
27 | ## License
28 |
29 | + The source code for DeepPruner and Differentiable PatchMatch is © Uber, 2018-2019, and is licensed under the Uber Non-Commercial License.
30 | + The trained model weights for DeepPruner are released under the Creative Commons Attribution-NonCommercial-ShareAlike 3.0 License.
31 |
32 |
33 |
34 | ## Model Weights
35 |
36 | DeepPruner was first trained on the Sceneflow dataset and then fine-tuned on the KITTI dataset (combined 394 image pairs from KITTI 2012 and KITTI 2015).
37 |
38 | + DeepPruner-fast (KITTI)
39 | + DeepPruner-best (KITTI)
40 | + DeepPruner-fast (Sceneflow)
41 | + DeepPruner-best (Sceneflow)
42 |
43 |
44 |
45 | ## Training and Evaluation
46 |
47 | **NOTE:** We allow users to modify a number of model parameters and training settings for their own purposes. You may need to retrain the model with the modified parameters.
48 | Check **'models/config.py'** for more details.
49 |
50 |
51 |
52 | ### KITTI Stereo 2012/2015:
53 |
54 | KITTI 2015 has 200 stereo-pairs with ground truth disparities. We used 160 out of these 200 for training and the remaining 40 for validation. The training set was further augmented by 194 stereo image pairs from KITTI 2012 dataset.
55 |
56 | #### Setup:
57 |
58 | 1. Download the KITTI 2012 and KITTI 2015 datasets.
59 | 2. Split the KITTI Stereo 2015 training dataset into "training" (160 pairs) and "validation" (40 pairs), following the same directory structure as the original dataset. **Make sure to have the following structure**:
60 |
61 | > training_directory_stereo_2015 \
62 | > |----- image_2 \
63 | > |----- image_3 \
64 | > |----- disp_occ_0 \
65 | > val_directory_stereo_2015 \
66 | > |----- image_2 \
67 | > |----- image_3 \
68 | > |----- disp_occ_0 \
69 | > train_directory_stereo_2012 \
70 | > |----- colored_0 \
71 | > |----- colored_1 \
72 | > |----- disp_occ \
73 | > test_directory \
74 | > |----- image_2 \
75 | > |----- image_3
76 |
77 | 3. Note that any other dataset could be used for training, validation and testing. The directory structure should be the same.
78 |
79 | #### Training Command:
80 | 1. Like previous works, we fine-tuned the pre-trained Sceneflow model on KITTI dataset.
81 | 2. Training Command:
82 |
83 | > python finetune_kitti.py \\\
84 | > --loadmodel <path to the pre-trained Sceneflow model> \\\
85 | > --savemodel <directory to save the trained models> \\\
86 | > --train_datapath_2015 <KITTI stereo 2015 training directory> \\\
87 | > --val_datapath_2015 <KITTI stereo 2015 validation directory> \\\
88 | > --datapath_2012 <KITTI stereo 2012 training directory>
89 |
90 | 3. Training command arguments:
91 |
92 | + --loadmodel (default: None): If not set, the model would train from scratch.
93 | + --savemodel (default: './'): If not set, the script will save the trained models after every epoch in the same directory.
94 | + --train_datapath_2015 (default: None): If not set, KITTI stereo 2015 dataset won't be used for training.
95 | + --val_datapath_2015 (default: None): If not set, the script would fail. The validation dataset should have at least one image to run.
96 | + --datapath_2012 (default: None): If not set, KITTI stereo 2012 dataset won't be used for training.
97 | 4. Training and validation tensorboard runs will be saved in the **'./runs/'** directory.
98 | #### Evaluation Command (On Any Dataset):
99 | 1. We used KITTI 2015 Stereo testing set for evaluation. **(Note any other dataset could be used.)**
100 | 2. To evaluate DeepPruner on any dataset, just create a base directory like:
101 | > test_directory \
102 | > |----- image_2 \
103 | > |----- image_3
104 |
105 | image_2 folder holds the left images, while image_3 folder holds the right images.
106 |
107 | 3. For evaluation, update the "mode" parameter in "models/config.py" to "evaluation".
108 |
109 | 4. Evaluation command:
110 |
111 | > python submission_kitti.py \
112 | > --loadmodel <path to the trained model> \
113 | > --save_dir <directory to save the predicted disparity maps> \
114 | > --datapath <test directory (containing image_2 and image_3 folders)>
115 |
116 | #### Metrics/ Results:
117 |
118 | 1. The metrics used for evaluation are same as provided by the [KITTI Stereo benchmark](http://www.cvlibs.net/datasets/kitti/eval_scene_flow.php?benchmark=stereo).
119 | 2. The quantitative results obtained by DeepPruner are as follows. (Note: the standings in the tables below are as of March 2019.)
120 |
121 |
122 |
123 |
124 |
125 | 3. Alongside learning the disparity (or depth) maps, DeepPruner is able to predict the uncertain regions (occluded regions, bushy regions, object edges) efficiently. Since the uncertainty in prediction correlates well with the error in the disparity maps (Figure 7), such uncertainty can be used in other downstream tasks.
126 |
127 | 4. Qualitative results are as follows:
128 |
129 |
130 | 
131 | 
132 | 
133 |
134 |
135 |
136 |
137 | ### Sceneflow:
138 |
139 | #### Setup:
140 |
141 | 1. Download [Sceneflow dataset](https://lmb.informatik.uni-freiburg.de/resources/datasets/SceneFlowDatasets.en.html#downloads), which consists of FlyingThings3D, Driving and Monkaa (RGB images (cleanpass) and disparity).
142 |
143 | 2. We followed the same directory structure as of the downloaded data. Check **dataloader/sceneflow_collector.py**.
144 |
145 | #### Training/ Evaluation Command:
146 |
147 | > python train_sceneflow.py \\\
148 | > --loadmodel <path to a pre-trained model (optional)> \\\
149 | > --save_dir <save directory> \\\
150 | > --savemodel <directory to save the trained models> \\\
151 | > --datapath_monkaa <Monkaa dataset directory> \\\
152 | > --datapath_flying <FlyingThings3D dataset directory> \\\
153 | > --datapath_driving <Driving dataset directory>
154 |
155 | #### Metrics/ Results:
156 |
157 | 1. We used EPE (end-point error) as one of the metrics.
158 |
159 |
160 |
161 |
162 |
163 |
164 |
165 |
166 |
167 |
168 |
169 |
170 |
171 |
172 |
173 | ### Robust Vision Challenge:
174 |
175 | #### Details:
176 |
177 | 1. The goal of the Robust Vision Challenge is to foster the development of vision systems that are robust and
178 | consequently perform well on a variety of datasets with different characteristics.
179 | Please refer to the Robust Vision Challenge for more details.
180 |
181 | 2. We used the pre-trained Sceneflow model and then jointly fine-tuned the model on the KITTI, ETH3D and Middlebury datasets.
182 |
183 | #### Setup
184 |
185 | 1. Dataloader and setup details coming soon.
186 |
187 | #### Metrics/ Results:
188 |
189 | Check DeepPruner_ROB on KITTI benchmark. \
190 | Check DeepPruner_ROB on ETH3D benchmark. \
191 | Check DeepPruner_ROB on MiddleburyV3 benchmark. \
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
--------------------------------------------------------------------------------
/deeppruner/dataloader/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
--------------------------------------------------------------------------------
/deeppruner/dataloader/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/deeppruner/dataloader/__init__.py
--------------------------------------------------------------------------------
/deeppruner/dataloader/kitti_collector.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch.utils.data as data
18 | import os
19 | import glob
20 |
21 |
22 | def datacollector(train_filepath, val_filepath, filepath_2012):
23 |
24 | left_fold = 'image_2/'
25 | right_fold = 'image_3/'
26 | disp_L = 'disp_occ_0/'
27 | disp_R = 'disp_occ_1/'
28 |
29 | left_fold_2012 = 'colored_0/'
30 | right_fold_2012 = 'colored_1/'
31 | disp_L_2012 = 'disp_occ/'
32 |
33 | left_train = []
34 | right_train = []
35 | disp_train_L = []
36 |
37 | left_val = []
38 | right_val = []
39 | disp_val_L = []
40 |
41 |
42 | if train_filepath is not None:
43 | left_train = sorted(glob.glob(os.path.join(train_filepath, left_fold, '*.png')))
44 | right_train = sorted(glob.glob(os.path.join(train_filepath, right_fold, '*.png')))
45 | disp_train_L = sorted(glob.glob(os.path.join(train_filepath, disp_L, '*.png')))
46 |
47 | if filepath_2012 is not None:
48 | left_train +=sorted(glob.glob(os.path.join(filepath_2012, left_fold_2012, '*_10.png')))
49 | right_train += sorted(glob.glob(os.path.join(filepath_2012, right_fold_2012, '*_10.png')))
50 | disp_train_L += sorted(glob.glob(os.path.join(filepath_2012, disp_L_2012, '*_10.png')))
51 |
52 | if val_filepath is not None:
53 | left_val = sorted(glob.glob(os.path.join(val_filepath, left_fold, '*.png')))
54 | right_val = sorted(glob.glob(os.path.join(val_filepath, right_fold, '*.png')))
55 | disp_val_L = sorted(glob.glob(os.path.join(val_filepath, disp_L, '*.png')))
56 |
57 | return left_train, right_train, disp_train_L, left_val, right_val, disp_val_L
58 |
--------------------------------------------------------------------------------
/deeppruner/dataloader/kitti_loader.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch.utils.data as data
18 | import random
19 | from PIL import Image
20 | import numpy as np
21 | from dataloader import preprocess
22 |
23 |
24 | def default_loader(path):
25 | return Image.open(path).convert('RGB')
26 |
27 |
28 | def disparity_loader(path):
29 | return Image.open(path)
30 |
31 | # train/ validation image crop size constants
32 | DEFAULT_TRAIN_IMAGE_HEIGHT = 256
33 | DEFAULT_TRAIN_IMAGE_WIDTH = 512
34 |
35 | DEFAULT_VAL_IMAGE_HEIGHT = 320
36 | DEFAULT_VAL_IMAGE_WIDTH = 1216
37 |
38 |
39 | class KITTILoader(data.Dataset):
40 | def __init__(self, left_images, right_images, left_disparity, training, loader=default_loader, dploader=disparity_loader):
41 |
42 | self.left_img = left_images
43 | self.right_img = right_images
44 | self.left_disp = left_disparity
45 | self.loader = loader
46 | self.dploader = dploader
47 | self.training = training
48 |
49 | def __getitem__(self, index):
50 | left_img = self.left_img[index]
51 | right_img = self.right_img[index]
52 | left_disp = self.left_disp[index]
53 |
54 | left_img = self.loader(left_img)
55 | right_img = self.loader(right_img)
56 | left_disp = self.dploader(left_disp)
57 | w, h = left_img.size
58 |
59 |
60 | if self.training:
61 | th, tw = DEFAULT_TRAIN_IMAGE_HEIGHT, DEFAULT_TRAIN_IMAGE_WIDTH
62 | x1 = random.randint(0, w - tw)
63 | y1 = random.randint(0, h - th)
64 |
65 | else:
66 | th, tw = DEFAULT_VAL_IMAGE_HEIGHT, DEFAULT_VAL_IMAGE_WIDTH
67 | x1 = w - DEFAULT_VAL_IMAGE_WIDTH
68 | y1 = h - DEFAULT_VAL_IMAGE_HEIGHT
69 |
70 | left_img = left_img.crop((x1, y1, x1 + tw, y1 + th))
71 | right_img = right_img.crop((x1, y1, x1 + tw, y1 + th))
72 | left_disp = left_disp.crop((x1, y1, x1 + tw, y1 + th))
73 | left_disp = np.ascontiguousarray(left_disp, dtype=np.float32) / 256
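# KITTI stores disparity ground truth as uint16 PNGs scaled by 256, hence the division
# above; pixels with value 0 carry no ground truth and are masked out during training.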
74 |
75 |
76 | processed = preprocess.get_transform()
77 | left_img = processed(left_img)
78 | right_img = processed(right_img)
79 |
80 | return left_img, right_img, left_disp
81 |
82 | def __len__(self):
83 | return len(self.left_img)
84 |
--------------------------------------------------------------------------------
/deeppruner/dataloader/kitti_submission_collector.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import os
18 |
19 |
20 | def datacollector(filepath):
21 | left_fold = 'image_2/'
22 | right_fold = 'image_3/'
23 | disp = 'disp_occ_0/'
24 |
25 | image = [img for img in sorted(os.listdir(os.path.join(filepath,left_fold))) if img.find('.png') > -1]
26 |
27 | left_test = [os.path.join(filepath, left_fold, img) for img in image]
28 | right_test = [os.path.join(filepath, right_fold, img) for img in image]
29 |
30 | return left_test, right_test
31 |
--------------------------------------------------------------------------------
/deeppruner/dataloader/preprocess.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torchvision.transforms as transforms
18 |
19 | __imagenet_stats = {'mean': [0.485, 0.456, 0.406],
20 | 'std': [0.229, 0.224, 0.225]}
21 |
22 |
23 | def get_transform():
24 |
25 | normalize = __imagenet_stats
26 | t_list = [
27 | transforms.ToTensor(),
28 | transforms.Normalize(**normalize),
29 | ]
30 |
31 | return transforms.Compose(t_list)
32 |
--------------------------------------------------------------------------------
/deeppruner/dataloader/readpfm.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import re
18 | import numpy as np
19 | import sys
20 |
21 |
22 | def readPFM(file):
23 | file = open(file, 'rb')
24 |
25 | color = None
26 | width = None
27 | height = None
28 | scale = None
29 | endian = None
30 |
31 | header = file.readline().rstrip()
32 | if header == 'PF':
33 | color = True
34 | elif header == 'Pf':
35 | color = False
36 | else:
37 | raise Exception('Not a PFM file.')
38 |
39 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline())
40 | if dim_match:
41 | width, height = map(int, dim_match.groups())
42 | else:
43 | raise Exception('Malformed PFM header.')
44 |
45 | scale = float(file.readline().rstrip())
46 | if scale < 0: # little-endian
47 | endian = '<'
48 | scale = -scale
49 | else:
50 | endian = '>' # big-endian
51 |
52 | data = np.fromfile(file, endian + 'f')
53 | shape = (height, width, 3) if color else (height, width)
54 |
55 | data = np.reshape(data, shape)
56 | data = np.flipud(data)
57 | return data, scale
58 |
--------------------------------------------------------------------------------
/deeppruner/dataloader/sceneflow_collector.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch.utils.data as data
18 | from PIL import Image
19 | import os
20 | import os.path
21 | import logging
22 |
23 | IMG_EXTENSIONS = [
24 | '.jpg', '.JPG', '.jpeg', '.JPEG',
25 | '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
26 | ]
27 |
28 |
29 | def is_image_file(filename):
30 | return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
31 |
32 |
33 | def dataloader(filepath_monkaa, filepath_flying, filepath_driving):
34 |
35 | try:
36 | monkaa_path = os.path.join(filepath_monkaa, 'monkaa_frames_cleanpass')
37 | monkaa_disp = os.path.join(filepath_monkaa, 'monkaa_disparity')
38 | monkaa_dir = os.listdir(monkaa_path)
39 |
40 | all_left_img = []
41 | all_right_img = []
42 | all_left_disp = []
43 | test_left_img = []
44 | test_right_img = []
45 | test_left_disp = []
46 |
47 | for dd in monkaa_dir:
48 | for im in os.listdir(os.path.join(monkaa_path, dd, 'left')):
49 | if is_image_file(os.path.join(monkaa_path, dd, 'left', im)) and is_image_file(
50 | os.path.join(monkaa_path, dd, 'right', im)):
51 | all_left_img.append(os.path.join(monkaa_path, dd, 'left', im))
52 | all_left_disp.append(os.path.join(monkaa_disp, dd, 'left', im.split(".")[0] + '.pfm'))
53 | all_right_img.append(os.path.join(monkaa_path, dd, 'right', im))
54 |
55 | except:
56 | logging.error("Error while collecting the Monkaa split; Monkaa might not be loaded correctly.")
57 | raise Exception('Monkaa dataset couldn\'t be loaded correctly.')
58 |
59 |
60 | try:
61 | flying_path = os.path.join(filepath_flying, 'frames_cleanpass')
62 | flying_disp = os.path.join(filepath_flying, 'disparity')
63 | flying_dir = flying_path + '/TRAIN/'
64 | subdir = ['A', 'B', 'C']
65 |
66 | for ss in subdir:
67 | flying = os.listdir(os.path.join(flying_dir, ss))
68 |
69 | for ff in flying:
70 | imm_l = os.listdir(os.path.join(flying_dir, ss, ff, 'left'))
71 | for im in imm_l:
72 | if is_image_file(os.path.join(flying_dir, ss, ff, 'left', im)):
73 | all_left_img.append(os.path.join(flying_dir, ss, ff, 'left', im))
74 |
75 | all_left_disp.append(os.path.join(flying_disp, 'TRAIN', ss, ff, 'left', im.split(".")[0] + '.pfm'))
76 |
77 | if is_image_file(os.path.join(flying_dir, ss, ff, 'right', im)):
78 | all_right_img.append(os.path.join(flying_dir, ss, ff, 'right', im))
79 |
80 | flying_dir = flying_path + '/TEST/'
81 | subdir = ['A', 'B', 'C']
82 |
83 | for ss in subdir:
84 | flying = os.listdir(os.path.join(flying_dir, ss))
85 |
86 | for ff in flying:
87 | imm_l = os.listdir(os.path.join(flying_dir, ss, ff, 'left'))
88 | for im in imm_l:
89 | if is_image_file(os.path.join(flying_dir, ss, ff, 'left', im)):
90 | test_left_img.append(os.path.join(flying_dir, ss, ff, 'left', im))
91 |
92 | test_left_disp.append(os.path.join(flying_disp, 'TEST', ss, ff, 'left', im.split(".")[0] + '.pfm'))
93 |
94 | if is_image_file(os.path.join(flying_dir, ss, ff, 'right', im)):
95 | test_right_img.append(os.path.join(flying_dir, ss, ff, 'right', im))
96 |
97 | except:
98 | logging.error("Error while collecting the Flying Things split; Flying Things might not be loaded correctly.")
99 | raise Exception('Flying Things dataset couldn\'t be loaded correctly.')
100 |
101 | try:
102 | driving_dir = os.path.join(filepath_driving, 'driving_frames_cleanpass/')
103 | driving_disp = os.path.join(filepath_driving, 'driving_disparity/')
104 |
105 | subdir1 = ['35mm_focallength', '15mm_focallength']
106 | subdir2 = ['scene_backwards', 'scene_forwards']
107 | subdir3 = ['fast', 'slow']
108 |
109 | for i in subdir1:
110 | for j in subdir2:
111 | for k in subdir3:
112 | imm_l = os.listdir(os.path.join(driving_dir, i, j, k, 'left'))
113 | for im in imm_l:
114 | if is_image_file(os.path.join(driving_dir, i, j, k, 'left', im)):
115 | all_left_img.append(os.path.join(driving_dir, i, j, k, 'left', im))
116 | all_left_disp.append(os.path.join(driving_disp, i, j, k, 'left', im.split(".")[0] + '.pfm'))
117 |
118 | if is_image_file(os.path.join(driving_dir, i, j, k, 'right', im)):
119 | all_right_img.append(os.path.join(driving_dir, i, j, k, 'right', im))
120 | except:
121 | logging.error("Error while collecting the Driving split; Driving might not be loaded correctly.")
122 | raise Exception('Driving dataset couldn\'t be loaded correctly.')
123 |
124 | return all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp
125 |
--------------------------------------------------------------------------------
/deeppruner/dataloader/sceneflow_loader.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch.utils.data as data
18 | import random
19 | from PIL import Image
20 | from dataloader import preprocess
21 | from dataloader import readpfm as rp
22 | import numpy as np
23 | import math
24 |
25 | # train/ validation image crop size constants
26 | DEFAULT_TRAIN_IMAGE_HEIGHT = 256
27 | DEFAULT_TRAIN_IMAGE_WIDTH = 512
28 |
29 | def default_loader(path):
30 | return Image.open(path).convert('RGB')
31 |
32 |
33 | def disparity_loader(path):
34 | return rp.readPFM(path)
35 |
36 |
37 | class SceneflowLoader(data.Dataset):
38 | def __init__(self, left_images, right_images, left_disparity, downsample_scale, training, loader=default_loader, dploader=disparity_loader):
39 |
40 | self.left_images = left_images
41 | self.right_images = right_images
42 | self.left_disparity = left_disparity
43 | self.loader = loader
44 | self.dploader = dploader
45 | self.training = training
46 |
47 | # downsample_scale denotes the maximum factor by which the image features are downsampled
48 | # by the network.
49 | # Since the image size used for evaluation may not be divisible by downsample_scale,
50 | # we pad the image with zeros so that it becomes divisible, and later remove the extra padding.
51 | self.downsample_scale = downsample_scale
52 |
53 | def __getitem__(self, index):
54 | left_img = self.left_images[index]
55 | right_img = self.right_images[index]
56 | left_disp = self.left_disparity[index]
57 |
58 | left_img = self.loader(left_img)
59 | right_img = self.loader(right_img)
60 | left_disp, left_scale = self.dploader(left_disp)
61 | left_disp = np.ascontiguousarray(left_disp, dtype=np.float32)
62 |
63 | if self.training:
64 | w, h = left_img.size
65 | th, tw = DEFAULT_TRAIN_IMAGE_HEIGHT, DEFAULT_TRAIN_IMAGE_WIDTH
66 |
67 | x1 = random.randint(0, w - tw)
68 | y1 = random.randint(0, h - th)
69 |
70 | left_img = left_img.crop((x1, y1, x1 + tw, y1 + th))
71 | right_img = right_img.crop((x1, y1, x1 + tw, y1 + th))
72 | left_disp = left_disp[y1:y1 + th, x1:x1 + tw]
73 |
74 | processed = preprocess.get_transform()
75 | left_img = processed(left_img)
76 | right_img = processed(right_img)
77 |
78 | return left_img, right_img, left_disp
79 | else:
80 | w, h = left_img.size
81 |
82 | dw = w + (self.downsample_scale - (w%self.downsample_scale + (w%self.downsample_scale==0)*self.downsample_scale))
83 | dh = h + (self.downsample_scale - (h%self.downsample_scale + (h%self.downsample_scale==0)*self.downsample_scale))
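# For example, with downsample_scale = 8 a 960x540 image gives dw = 960 (already
# divisible) and dh = 544; the (dw - w, dh - h) offsets returned below let the caller
# strip the padded border again after inference.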
84 |
85 | left_img = left_img.crop((w - dw, h - dh, w, h))
86 | right_img = right_img.crop((w - dw, h - dh, w, h))
87 |
88 | processed = preprocess.get_transform()
89 | left_img = processed(left_img)
90 | right_img = processed(right_img)
91 |
92 | return left_img, right_img, left_disp, dw-w, dh-h
93 |
94 | def __len__(self):
95 | return len(self.left_images)
96 |
--------------------------------------------------------------------------------
/deeppruner/finetune_kitti.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 |
17 | from __future__ import print_function
18 | import argparse
19 | import os
20 | import random
21 | from collections import namedtuple
22 | import torch
23 | import torch.nn as nn
24 | import torch.nn.parallel
25 | import torch.backends.cudnn as cudnn
26 | import torch.optim as optim
27 | import torch.utils.data
28 | from torch.autograd import Variable
29 | import skimage
30 | import skimage.transform
31 | import numpy as np
32 | from dataloader import kitti_collector as ls
33 | from dataloader import kitti_loader as DA
34 | from models.deeppruner import DeepPruner
35 | from tensorboardX import SummaryWriter
36 | from torchvision import transforms
37 | from loss_evaluation import loss_evaluation
38 | from models.config import config as config_args
39 | import matplotlib.pyplot as plt
40 | import logging
41 | from setup_logging import setup_logging
42 |
43 | parser = argparse.ArgumentParser(description='DeepPruner')
44 | parser.add_argument('--train_datapath_2015', default=None,
45 | help='training data path of KITTI 2015')
46 | parser.add_argument('--datapath_2012', default=None,
47 | help='data path of KITTI 2012 (all used for training)')
48 | parser.add_argument('--val_datapath_2015', default=None,
49 | help='validation data path of KITTI 2015')
50 | parser.add_argument('--epochs', type=int, default=1040,
51 | help='number of epochs to train')
52 | parser.add_argument('--loadmodel', default=None,
53 | help='load model')
54 | parser.add_argument('--savemodel', default='./',
55 | help='save model')
56 | parser.add_argument('--logging_filename', default='./finetune_kitti.log',
57 | help='filename for logs')
58 | parser.add_argument('--no-cuda', action='store_true', default=False,
59 | help='disables CUDA training')
60 | parser.add_argument('--seed', type=int, default=1, metavar='S',
61 | help='random seed (default: 1)')
62 |
63 | args = parser.parse_args()
64 | args.cuda = not args.no_cuda and torch.cuda.is_available()
65 | torch.manual_seed(args.seed)
66 | if args.cuda:
67 | torch.manual_seed(args.seed)
68 | np.random.seed(args.seed)
69 | random.seed(args.seed)
70 | torch.cuda.manual_seed(args.seed)
71 | torch.backends.cudnn.deterministic = True
72 |
73 |
74 | args.cost_aggregator_scale = config_args.cost_aggregator_scale
75 |
76 | setup_logging(args.logging_filename)
77 |
78 | all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = ls.datacollector(
79 | args.train_datapath_2015, args.val_datapath_2015, args.datapath_2012)
80 |
81 |
82 | TrainImgLoader = torch.utils.data.DataLoader(
83 | DA.KITTILoader(all_left_img, all_right_img, all_left_disp, True),
84 | batch_size=16, shuffle=True, num_workers=8, drop_last=False)
85 |
86 | TestImgLoader = torch.utils.data.DataLoader(
87 | DA.KITTILoader(test_left_img, test_right_img, test_left_disp, False),
88 | batch_size=8, shuffle=False, num_workers=4, drop_last=False)
89 |
90 | model = DeepPruner()
91 | writer = SummaryWriter()
92 | model = nn.DataParallel(model)
93 |
94 |
95 | if args.cuda:
96 | model.cuda()
97 |
98 | if args.loadmodel is not None:
99 | logging.info("loading model...")
100 | state_dict = torch.load(args.loadmodel)
101 | model.load_state_dict(state_dict['state_dict'], strict=True)
102 |
103 |
104 | logging.info('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
105 | optimizer = optim.Adam(model.parameters(), lr=0.1, betas=(0.9, 0.999))
106 |
107 |
108 | def train(imgL, imgR, disp_L, iteration, epoch):
109 | if epoch >= 800:
110 | model.eval()
111 | else:
112 | model.train()
113 |
114 | imgL = Variable(torch.FloatTensor(imgL))
115 | imgR = Variable(torch.FloatTensor(imgR))
116 | disp_L = Variable(torch.FloatTensor(disp_L))
117 |
118 | if args.cuda:
119 | imgL, imgR, disp_true = imgL.cuda(), imgR.cuda(), disp_L.cuda()
120 |
121 | mask = (disp_true > 0)
122 | mask.detach_()
123 |
124 | optimizer.zero_grad()
125 | result = model(imgL,imgR)
126 |
127 | loss, _ = loss_evaluation(result, disp_true, mask, args.cost_aggregator_scale)
128 |
129 | loss.backward()
130 | optimizer.step()
131 |
132 | return loss.item()
133 |
134 |
135 |
136 | def test(imgL,imgR,disp_L,iteration):
137 |
138 | model.eval()
139 | with torch.no_grad():
140 | imgL = Variable(torch.FloatTensor(imgL))
141 | imgR = Variable(torch.FloatTensor(imgR))
142 | disp_L = Variable(torch.FloatTensor(disp_L))
143 |
144 | if args.cuda:
145 | imgL, imgR, disp_true = imgL.cuda(), imgR.cuda(), disp_L.cuda()
146 |
147 | mask = (disp_true > 0)
148 | mask.detach_()
149 |
150 | optimizer.zero_grad()
151 |
152 | result = model(imgL,imgR)
153 | loss, output_disparity = loss_evaluation(result, disp_true, mask, args.cost_aggregator_scale)
154 |
155 | #computing 3-px error: (source psmnet)#
156 | true_disp = disp_true.data.cpu()
157 | disp_true = true_disp
158 | pred_disp = output_disparity.data.cpu()
159 |
160 | index = np.argwhere(true_disp>0)
161 | disp_true[index[0][:], index[1][:], index[2][:]] = np.abs(true_disp[index[0][:], index[1][:], index[2][:]]-pred_disp[index[0][:], index[1][:], index[2][:]])
162 | correct = (disp_true[index[0][:], index[1][:], index[2][:]] < 3)|(disp_true[index[0][:], index[1][:], index[2][:]] < true_disp[index[0][:], index[1][:], index[2][:]]*0.05)
163 | torch.cuda.empty_cache()
164 |
165 | loss = 1-(float(torch.sum(correct))/float(len(index[0])))
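# The quantity returned here is the standard KITTI 3-px error: the fraction of
# valid pixels whose absolute disparity error is >= 3 px and also >= 5% of the
# ground-truth disparity.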
166 |
167 | return loss
168 |
169 |
170 | def adjust_learning_rate(optimizer, epoch):
171 | if epoch <= 500:
172 | lr = 0.0001
173 | elif epoch<=1000:
174 | lr = 0.00005
175 | else:
176 | lr = 0.00001
177 | logging.info('learning rate = %.5f' %(lr))
178 | for param_group in optimizer.param_groups:
179 | param_group['lr'] = lr
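# main() calls adjust_learning_rate at the start of every epoch, so the lr passed to
# Adam at construction time is overwritten before the first update; the effective
# schedule is 1e-4 (epoch <= 500), 5e-5 (epoch <= 1000) and 1e-5 afterwards.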
180 |
181 |
182 | def main():
183 |
184 | for epoch in range(0, args.epochs):
185 | total_train_loss = 0
186 | total_test_loss = 0
187 | adjust_learning_rate(optimizer,epoch)
188 |
189 | if epoch %1==0 and epoch!=0:
190 | for batch_idx, (imgL, imgR, disp_L) in enumerate(TestImgLoader):
191 | test_loss = test(imgL,imgR,disp_L,batch_idx)
192 | total_test_loss += test_loss
193 | logging.info('Iter %d 3-px error in val = %.3f \n' %(batch_idx, test_loss))
194 |
195 | logging.info('epoch %d total test loss = %.3f' %(epoch, total_test_loss/len(TestImgLoader)))
196 | writer.add_scalar("val-loss",total_test_loss/len(TestImgLoader),epoch)
197 |
198 | for batch_idx, (imgL_crop, imgR_crop, disp_crop_L) in enumerate(TrainImgLoader):
199 | loss = train(imgL_crop,imgR_crop,disp_crop_L,batch_idx,epoch)
200 | total_train_loss += loss
201 | logging.info('Iter %d training loss = %.3f \n' %(batch_idx, loss))
202 |
203 | logging.info('epoch %d total training loss = %.3f' %(epoch, total_train_loss/len(TrainImgLoader)))
204 | writer.add_scalar("loss",total_train_loss/len(TrainImgLoader),epoch)
205 |
206 | # SAVE
207 | if epoch%1==0:
208 | savefilename = args.savemodel+'finetune_'+str(epoch)+'.tar'
209 | torch.save({
210 | 'epoch': epoch,
211 | 'state_dict': model.state_dict(),
212 | 'train_loss': total_train_loss,
213 | 'test_loss': total_test_loss,
214 | }, savefilename)
215 |
216 |
217 | if __name__ == '__main__':
218 | main()
219 |
--------------------------------------------------------------------------------
/deeppruner/loss_evaluation.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch.nn.functional as F
18 | from collections import namedtuple
19 | import logging
20 |
21 | loss_weights = {
22 | 'alpha_super_refined': 1.6,
23 | 'alpha_refined': 1.3,
24 | 'alpha_ca': 1.0,
25 | 'alpha_quantile': 1.0,
26 | 'alpha_min_max': 0.7
27 | }
28 |
29 | loss_weights = namedtuple('loss_weights', loss_weights.keys())(*loss_weights.values())
30 |
31 | def loss_evaluation(result, disp_true, mask, cost_aggregator_scale=4):
32 |
33 | # forces min_disparity to be equal to or slightly lower than the true disparity
34 | quantile_mask1 = ((disp_true[mask] - result[-1][mask]) < 0).float()
35 | quantile_loss1 = (disp_true[mask] - result[-1][mask]) * (0.05 - quantile_mask1)
36 | quantile_min_disparity_loss = quantile_loss1.mean()
37 |
38 | # forces max_disparity to be equal to or slightly larger than the true disparity
39 | quantile_mask2 = ((disp_true[mask] - result[-2][mask]) < 0).float()
40 | quantile_loss2 = (disp_true[mask] - result[-2][mask]) * (0.95 - quantile_mask2)
41 | quantile_max_disparity_loss = quantile_loss2.mean()
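# Both quantile terms above are instances of the pinball loss
# L_tau(e) = e * (tau - 1[e < 0]) with e = disp_true - prediction:
# tau = 0.05 keeps the predicted min_disparity at or just below the ground truth,
# and tau = 0.95 keeps the predicted max_disparity at or just above it.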
42 |
43 | min_disparity_loss = F.smooth_l1_loss(result[-1][mask], disp_true[mask], size_average=True)
44 | max_disparity_loss = F.smooth_l1_loss(result[-2][mask], disp_true[mask], size_average=True)
45 | ca_depth_loss = F.smooth_l1_loss(result[-3][mask], disp_true[mask], size_average=True)
46 | refined_depth_loss = F.smooth_l1_loss(result[-4][mask], disp_true[mask], size_average=True)
47 |
48 | logging.info("============== evaluated losses ==================")
49 | if cost_aggregator_scale == 8:
50 | refined_depth_loss_1 = F.smooth_l1_loss(result[-5][mask], disp_true[mask], size_average=True)
51 | loss = (loss_weights.alpha_super_refined * refined_depth_loss_1)
52 | output_disparity = result[-5]
53 | logging.info('refined_depth_loss_1: %.6f', refined_depth_loss_1)
54 | else:
55 | loss = 0
56 | output_disparity = result[-4]
57 |
58 | loss += (loss_weights.alpha_refined * refined_depth_loss) + \
59 | (loss_weights.alpha_ca * ca_depth_loss) + \
60 | (loss_weights.alpha_quantile * (quantile_max_disparity_loss + quantile_min_disparity_loss)) + \
61 | (loss_weights.alpha_min_max * (min_disparity_loss + max_disparity_loss))
62 |
63 | logging.info('refined_depth_loss: %.6f' % refined_depth_loss)
64 | logging.info('ca_depth_loss: %.6f' % ca_depth_loss)
65 | logging.info('quantile_loss_max_disparity: %.6f' % quantile_max_disparity_loss)
66 | logging.info('quantile_loss_min_disparity: %.6f' % quantile_min_disparity_loss)
67 | logging.info('max_disparity_loss: %.6f' % max_disparity_loss)
68 | logging.info('min_disparity_loss: %.6f' % min_disparity_loss)
69 | logging.info("==================================================\n")
70 |
71 | return loss, output_disparity
72 |
--------------------------------------------------------------------------------
/deeppruner/models/.gitignore:
--------------------------------------------------------------------------------
1 | *.pyc
2 | *.log
--------------------------------------------------------------------------------
/deeppruner/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/deeppruner/models/__init__.py
--------------------------------------------------------------------------------
/deeppruner/models/config.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 |
17 | from __future__ import print_function
18 |
19 | class obj(object):
20 | def __init__(self, d):
21 | for key, value in d.items():
22 | if isinstance(value, (list, tuple)):
23 | setattr(self, key, [obj(x) if isinstance(x, dict) else x for x in value])
24 | else:
25 | setattr(self, key, obj(value) if isinstance(value, dict) else value)
26 |
27 |
28 | config = {
29 | "max_disp": 192,
30 | "cost_aggregator_scale": 4, # for DeepPruner-fast change this to 8.
31 | "mode": "training", # for evaluation/ submission, change this to evaluation.
32 |
33 |
34 | # The code allows the user to change the feature extractor to any feature extractor of their choice.
35 | # The only requirements of the feature extractor are:
36 | # 1. For cost_aggregator_scale == 4:
37 | # features at downsample-level X4 (feature_extractor_ca_level)
38 | # and downsample-level X2 (feature_extractor_refinement_level) should be the output.
39 | # For cost_aggregator_scale == 8:
40 | # features at downsample-level X8 (feature_extractor_ca_level),
41 | # downsample-level X4 (feature_extractor_refinement_level),
42 | # downsample-level X2 (feature_extractor_refinement_level_1) should be the output,
43 |
44 | # 2. If the feature extractor is modified, change the "feature_extractor_outplanes_*" key in the config
45 | # accordingly.
46 |
47 | "feature_extractor_ca_level_outplanes": 32,
48 | "feature_extractor_refinement_level_outplanes": 32, # for DeepPruner-fast change this to 64.
49 | "feature_extractor_refinement_level_1_outplanes": 32,
50 |
51 | "patch_match_args": {
52 | "sample_count": 12,
53 | "iteration_count": 2,
54 | "propagation_filter_size": 3
55 | },
56 |
57 | "post_CRP_sample_count": 7,
58 | "post_CRP_sampler_type": "uniform", #change to patch_match for Sceneflow model.
59 |
60 | "hourglass_inplanes": 16
61 | }
62 |
63 | config = obj(config)
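# After the wrapping above, nested keys are available as attributes,
# e.g. config.max_disp == 192 and config.patch_match_args.sample_count == 12.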
64 |
--------------------------------------------------------------------------------
/deeppruner/models/deeppruner.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | from models.submodules3d import MinDisparityPredictor, MaxDisparityPredictor, CostAggregator
18 | from models.submodules2d import RefinementNet
19 | from models.submodules import SubModule, conv_relu, convbn_2d_lrelu, convbn_3d_lrelu
20 | from models.utils import SpatialTransformer, UniformSampler
21 | from models.patch_match import PatchMatch
22 | import torch
23 | import torch.nn as nn
24 | import torch.nn.functional as F
25 | from models.config import config as args
26 |
27 | class DeepPruner(SubModule):
28 | def __init__(self):
29 | super(DeepPruner, self).__init__()
30 | self.scale = args.cost_aggregator_scale
31 | self.max_disp = args.max_disp // self.scale
32 | self.mode = args.mode
33 |
34 | self.patch_match_args = args.patch_match_args
35 | self.patch_match_sample_count = self.patch_match_args.sample_count
36 | self.patch_match_iteration_count = self.patch_match_args.iteration_count
37 | self.patch_match_propagation_filter_size = self.patch_match_args.propagation_filter_size
38 |
39 | self.post_CRP_sample_count = args.post_CRP_sample_count
40 | self.post_CRP_sampler_type = args.post_CRP_sampler_type
41 | hourglass_inplanes = args.hourglass_inplanes
42 |
43 | # refinement input features are composed of:
44 | # left image low level features +
45 | # CA output features + CA output disparity
46 |
47 | if self.scale == 8:
48 | from models.feature_extractor_fast import feature_extraction
49 | refinement_inplanes_1 = args.feature_extractor_refinement_level_1_outplanes + 1
50 | self.refinement_net1 = RefinementNet(refinement_inplanes_1)
51 | else:
52 | from models.feature_extractor_best import feature_extraction
53 |
54 | refinement_inplanes = args.feature_extractor_refinement_level_outplanes + self.post_CRP_sample_count + 2 + 1
55 | self.refinement_net = RefinementNet(refinement_inplanes)
56 |
57 | # cost_aggregator_inplanes are composed of:
58 | # left and right image features from feature_extractor (ca_level) +
59 | # features from min/max predictors +
60 | # min_disparity + max_disparity + disparity_samples
61 |
62 | cost_aggregator_inplanes = 2 * (args.feature_extractor_ca_level_outplanes +
63 | self.patch_match_sample_count + 2) + 1
64 | self.cost_aggregator = CostAggregator(cost_aggregator_inplanes, hourglass_inplanes)
65 |
66 | self.feature_extraction = feature_extraction()
67 | self.min_disparity_predictor = MinDisparityPredictor(hourglass_inplanes)
68 | self.max_disparity_predictor = MaxDisparityPredictor(hourglass_inplanes)
69 | self.spatial_transformer = SpatialTransformer()
70 | self.patch_match = PatchMatch(self.patch_match_propagation_filter_size)
71 | self.uniform_sampler = UniformSampler()
72 |
73 | # Confidence Range Predictor(CRP) input features are composed of:
74 | # left and right image features from feature_extractor (ca_level) +
75 | # disparity_samples
76 |
77 | CRP_feature_count = 2 * args.feature_extractor_ca_level_outplanes + 1
78 | self.dres0 = nn.Sequential(convbn_3d_lrelu(CRP_feature_count, 64, 3, 1, 1),
79 | convbn_3d_lrelu(64, 32, 3, 1, 1))
80 |
81 | self.dres1 = nn.Sequential(convbn_3d_lrelu(32, 32, 3, 1, 1),
82 | convbn_3d_lrelu(32, hourglass_inplanes, 3, 1, 1))
83 |
84 | self.min_disparity_conv = conv_relu(1, 1, 5, 1, 2)
85 | self.max_disparity_conv = conv_relu(1, 1, 5, 1, 2)
86 | self.ca_disparity_conv = conv_relu(1, 1, 5, 1, 2)
87 |
88 | self.ca_features_conv = convbn_2d_lrelu(self.post_CRP_sample_count + 2,
89 | self.post_CRP_sample_count + 2, 5, 1, 2, dilation=1, bias=True)
90 | self.min_disparity_features_conv = convbn_2d_lrelu(self.patch_match_sample_count + 2,
91 | self.patch_match_sample_count + 2, 5, 1, 2, dilation=1, bias=True)
92 | self.max_disparity_features_conv = convbn_2d_lrelu(self.patch_match_sample_count + 2,
93 | self.patch_match_sample_count + 2, 5, 1, 2, dilation=1, bias=True)
94 |
95 | self.weight_init()
96 |
97 |
98 | def generate_search_range(self, left_input, sample_count, stage,
99 | input_min_disparity=None, input_max_disparity=None):
100 | """
101 | Description: Generates the disparity search range depending upon the stage at which it is called.
102 | If stage is "pre" (Pre-PatchMatch and Pre-ConfidenceRangePredictor), the search range is
103 | the entire disparity search range.
104 | If stage is "post" (Post-ConfidenceRangePredictor), then the ConfidenceRangePredictor search range
105 | is adjusted for maximum efficiency.
106 | Args:
107 | :left_input: Left Image Features
108 | :sample_count: number of samples to be generated from the search range. Used to adjust the search range.
109 | :stage: "pre"(Pre-PatchMatch) or "post"(Post-ConfidenceRangePredictor)
110 | :input_min_disparity (default:None): ConfidenceRangePredictor disparity lowerbound (for stage=="post")
111 | :input_max_disparity (default:None): ConfidenceRangePredictor disparity upperbound (for stage=="post")
112 |
113 | Returns:
114 | :min_disparity: Lower bound of disparity search range
115 | :max_disparity: Upper bound of disparity search range.
116 | """
117 |
118 | device = left_input.get_device()
119 | if stage == "pre":
120 | min_disparity = torch.zeros((left_input.size()[0], 1, left_input.size()[2], left_input.size()[3]),
121 | device=device)
122 | max_disparity = torch.zeros((left_input.size()[0], 1, left_input.size()[2], left_input.size()[3]),
123 | device=device) + self.max_disp
124 |
125 | else:
126 | min_disparity1 = torch.min(input_min_disparity, input_max_disparity)
127 | max_disparity1 = torch.max(input_min_disparity, input_max_disparity)
128 |
129 | # if (max_disparity1 - min_disparity1) > sample_count:
130 | # sample uniformly "sample_count" number of samples from (min_disparity1, max_disparity1)
131 | # else:
132 | # stretch min_disparity1 and max_disparity1 such that (max_disparity1 - min_disparity1) == sample_count
133 |
134 | min_disparity = torch.clamp(min_disparity1 - torch.clamp((
135 | sample_count - max_disparity1 + min_disparity1), min=0) / 2.0, min=0, max=self.max_disp)
136 | max_disparity = torch.clamp(max_disparity1 + torch.clamp(
137 | sample_count - max_disparity1 + min_disparity, min=0), min=0, max=self.max_disp)
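# Worked example of the stretching above: with sample_count = 7, a predicted range
# (min_disparity1, max_disparity1) = (10, 13) is narrower than 7 samples and is
# widened to (8, 15) (subject to clamping to [0, max_disp]); ranges already wider
# than sample_count are left unchanged.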
138 |
139 | return min_disparity, max_disparity
140 |
141 | def generate_disparity_samples(self, left_input, right_input, min_disparity,
142 | max_disparity, sample_count=12, sampler_type="patch_match"):
143 | """
144 | Description: Generates "sample_count" number of disparity samples from the
145 | search range (min_disparity, max_disparity)
146 | Samples are generated either uniformly from the search range
147 | or are generated using PatchMatch.
148 |
149 | Args:
150 | :left_input: Left Image features.
151 | :right_input: Right Image features.
152 | :min_disparity: LowerBound of the disparity search range.
153 | :max_disparity: UpperBound of the disparity search range.
154 | :sample_count (default:12): Number of samples to be generated from the input search range.
155 | :sampler_type (default:"patch_match"): samples are generated either using
156 | "patch_match" or "uniform" sampler.
157 | Returns:
158 | :disparity_samples:
159 | """
160 | if sampler_type == "patch_match":
161 | disparity_samples = self.patch_match(left_input, right_input, min_disparity,
162 | max_disparity, sample_count, self.patch_match_iteration_count)
163 | else:
164 | disparity_samples = self.uniform_sampler(min_disparity, max_disparity, sample_count)
165 |
166 | disparity_samples = torch.cat((torch.floor(min_disparity), disparity_samples, torch.ceil(max_disparity)),
167 | dim=1).long()
168 | return disparity_samples
169 |
170 | def cost_volume_generator(self, left_input, right_input, disparity_samples):
171 | """
172 | Description: Generates cost-volume using left image features, disparity samples
173 | and warped right image features.
174 | Args:
175 | :left_input: Left Image features
176 | :right_input: Right Image features
177 | :disparity_samples: Disparity samples
178 |
179 | Returns:
180 | :cost_volume:
181 | :disparity_samples:
182 | :left_feature_map:
183 | """
184 |
185 | right_feature_map, left_feature_map = self.spatial_transformer(left_input,
186 | right_input, disparity_samples)
187 | disparity_samples = disparity_samples.unsqueeze(1).float()
188 |
189 | cost_volume = torch.cat((left_feature_map, right_feature_map, disparity_samples), dim=1)
190 |
191 | return cost_volume, disparity_samples, left_feature_map
192 |
193 | def confidence_range_predictor(self, cost_volume, disparity_samples):
194 | """
195 | Description: The original search space for all pixels is identical. However, in practice, for each
196 | pixel, the highly probable disparities lie in a narrow region. Using the small subset
197 | of disparities estimated from the PatchMatch stage, we have sufficient information to
198 | predict the range in which the true disparity lies. We thus exploit a confidence range
199 | prediction network to adjust the search space for each pixel.
200 |
201 | Args:
202 | :cost_volume: Input Cost-Volume
203 | :disparity_samples: Initial Disparity samples.
204 |
205 | Returns:
206 | :min_disparity: ConfidenceRangePredictor disparity lowerbound
207 | :max_disparity: ConfidenceRangePredictor disparity upperbound
208 | :min_disparity_features: features from ConfidenceRangePredictor-Min
209 | :max_disparity_features: features from ConfidenceRangePredictor-Max
210 | """
211 | # cost-volume bottleneck layers
212 | cost_volume = self.dres0(cost_volume)
213 | cost_volume = self.dres1(cost_volume)
214 |
215 | min_disparity, min_disparity_features = self.min_disparity_predictor(cost_volume,
216 | disparity_samples.squeeze(1))
217 |
218 | max_disparity, max_disparity_features = self.max_disparity_predictor(cost_volume,
219 | disparity_samples.squeeze(1))
220 |
221 | min_disparity = self.min_disparity_conv(min_disparity)
222 | max_disparity = self.max_disparity_conv(max_disparity)
223 | min_disparity_features = self.min_disparity_features_conv(min_disparity_features)
224 | max_disparity_features = self.max_disparity_features_conv(max_disparity_features)
225 |
226 | return min_disparity, max_disparity, min_disparity_features, max_disparity_features
227 |
228 | def forward(self, left_input, right_input):
229 | """
230 | DeepPruner
231 | Description: DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
232 |
233 | Args:
234 | :left_input: Left Stereo Image
235 | :right_input: Right Stereo Image
236 | Returns:
237 | outputs depend on args.mode ("evaluation" or "training"), and
238 | also on args.cost_aggregator_scale (8 or 4)
239 |
240 | All possible outputs can be:
241 | :refined_disparity_1: DeepPruner disparity output after Refinement1 stage
242 |                       (only when args.cost_aggregator_scale==8).
243 | :refined_disparity: DeepPruner disparity output after Refinement stage.
244 | :ca_disparity: DeepPruner disparity output after 3D-Cost Aggregation stage.
245 | :max_disparity: DeepPruner disparity by Confidence Range Predictor (Max)
246 | :min_disparity: DeepPruner disparity by Confidence Range Predictor (Min)
247 |
248 | """
249 |
250 | if self.scale == 8:
251 | left_spp_features, left_low_level_features, left_low_level_features_1 = self.feature_extraction(left_input)
252 | right_spp_features, right_low_level_features, _ = self.feature_extraction(
253 | right_input)
254 | else:
255 | left_spp_features, left_low_level_features = self.feature_extraction(left_input)
256 | right_spp_features, right_low_level_features = self.feature_extraction(right_input)
257 |
258 | min_disparity, max_disparity = self.generate_search_range(
259 | left_spp_features,
260 | sample_count=self.patch_match_sample_count, stage="pre")
261 |
262 | disparity_samples = self.generate_disparity_samples(
263 | left_spp_features,
264 | right_spp_features, min_disparity, max_disparity,
265 | sample_count=self.patch_match_sample_count, sampler_type="patch_match")
266 |
267 | cost_volume, disparity_samples, _ = self.cost_volume_generator(left_spp_features,
268 | right_spp_features,
269 | disparity_samples)
270 |
271 | min_disparity, max_disparity, min_disparity_features, max_disparity_features = \
272 | self.confidence_range_predictor(cost_volume, disparity_samples)
273 |
274 | stretched_min_disparity, stretched_max_disparity = self.generate_search_range(
275 | left_spp_features,
276 | sample_count=self.post_CRP_sample_count, stage='post',
277 | input_min_disparity=min_disparity, input_max_disparity=max_disparity)
278 |
279 | disparity_samples = self.generate_disparity_samples(
280 | left_spp_features,
281 | right_spp_features, stretched_min_disparity, stretched_max_disparity,
282 | sample_count=self.post_CRP_sample_count, sampler_type=self.post_CRP_sampler_type)
283 |
284 | cost_volume, disparity_samples, expanded_left_feature_map = self.cost_volume_generator(
285 | left_spp_features,
286 | right_spp_features,
287 | disparity_samples)
288 |
289 | min_disparity_features = min_disparity_features.unsqueeze(2).expand(-1, -1,
290 | expanded_left_feature_map.size()[2], -1, -1)
291 | max_disparity_features = max_disparity_features.unsqueeze(2).expand(-1, -1,
292 | expanded_left_feature_map.size()[2], -1, -1)
293 |
294 | cost_volume = torch.cat((cost_volume, min_disparity_features, max_disparity_features), dim=1)
295 | ca_disparity, ca_features = self.cost_aggregator(cost_volume, disparity_samples.squeeze(1))
296 |
297 | ca_disparity = F.interpolate(ca_disparity * 2, scale_factor=(2, 2), mode='bilinear')
298 | ca_features = F.interpolate(ca_features, scale_factor=(2, 2), mode='bilinear')
299 | ca_disparity = self.ca_disparity_conv(ca_disparity)
300 | ca_features = self.ca_features_conv(ca_features)
301 |
302 | refinement_net_input = torch.cat((left_low_level_features, ca_features, ca_disparity), dim=1)
303 | refined_disparity = self.refinement_net(refinement_net_input, ca_disparity)
304 |
305 | refined_disparity = F.interpolate(refined_disparity * 2, scale_factor=(2, 2), mode='bilinear')
306 |
307 | if self.scale == 8:
308 | refinement_net_input = torch.cat((left_low_level_features_1, refined_disparity), dim=1)
309 | refined_disparity_1 = self.refinement_net1(refinement_net_input, refined_disparity)
310 |
311 | if self.mode == 'evaluation':
312 | if self.scale == 8:
313 | refined_disparity_1 = F.interpolate(refined_disparity_1 * 2, scale_factor=(2, 2),
314 | mode='bilinear').squeeze(1)
315 | return refined_disparity_1
316 | return refined_disparity.squeeze(1)
317 |
318 | min_disparity = F.interpolate(min_disparity * self.scale, scale_factor=(self.scale, self.scale),
319 | mode='bilinear').squeeze(1)
320 | max_disparity = F.interpolate(max_disparity * self.scale, scale_factor=(self.scale, self.scale),
321 | mode='bilinear').squeeze(1)
322 | ca_disparity = F.interpolate(ca_disparity * (self.scale // 2),
323 | scale_factor=((self.scale // 2), (self.scale // 2)), mode='bilinear').squeeze(1)
324 |
325 | if self.scale == 8:
326 | refined_disparity = F.interpolate(refined_disparity * 2, scale_factor=(2, 2), mode='bilinear').squeeze(1)
327 | refined_disparity_1 = F.interpolate(refined_disparity_1 * 2,
328 | scale_factor=(2, 2), mode='bilinear').squeeze(1)
329 |
330 | return refined_disparity_1, refined_disparity, ca_disparity, max_disparity, min_disparity
331 |
332 | return refined_disparity.squeeze(1), ca_disparity, max_disparity, min_disparity
333 |
--------------------------------------------------------------------------------
/deeppruner/models/feature_extractor_best.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch
18 | import torch.nn as nn
19 | import torch.utils.data
20 | import torch.nn.functional as F
21 | from models.submodules import BasicBlock, convbn_relu
22 |
23 |
24 | class feature_extraction(nn.Module):
25 | def __init__(self):
26 | super(feature_extraction, self).__init__()
27 | self.inplanes = 32
28 | self.firstconv = nn.Sequential(convbn_relu(3, 32, 3, 2, 1, 1),
29 | convbn_relu(32, 32, 3, 1, 1, 1),
30 | convbn_relu(32, 32, 3, 1, 1, 1))
31 |
32 | self.layer1 = self._make_layer(BasicBlock, 32, 3, 1, 1, 1)
33 | self.layer2 = self._make_layer(BasicBlock, 64, 16, 2, 1, 1)
34 | self.layer3 = self._make_layer(BasicBlock, 128, 3, 1, 1, 1)
35 | self.layer4 = self._make_layer(BasicBlock, 128, 3, 1, 1, 2)
36 |
37 | self.branch1 = nn.Sequential(nn.AvgPool2d((64, 64), stride=(64, 64)),
38 | convbn_relu(128, 32, 1, 1, 0, 1))
39 |
40 | self.branch2 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32, 32)),
41 | convbn_relu(128, 32, 1, 1, 0, 1))
42 |
43 | self.branch3 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16, 16)),
44 | convbn_relu(128, 32, 1, 1, 0, 1))
45 |
46 | self.branch4 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8, 8)),
47 | convbn_relu(128, 32, 1, 1, 0, 1))
48 |
49 | self.lastconv = nn.Sequential(convbn_relu(320, 128, 3, 1, 1, 1),
50 | nn.Conv2d(128, 32, kernel_size=1, padding=0, stride=1, bias=False))
51 |
52 | def _make_layer(self, block, planes, blocks, stride, pad, dilation):
53 | downsample = None
54 | if stride != 1 or self.inplanes != planes * block.expansion:
55 | downsample = nn.Sequential(
56 | nn.Conv2d(self.inplanes, planes * block.expansion,
57 | kernel_size=1, stride=stride, bias=False),
58 | nn.BatchNorm2d(planes * block.expansion),)
59 |
60 | layers = []
61 | layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation))
62 | self.inplanes = planes * block.expansion
63 | for i in range(1, blocks):
64 | layers.append(block(self.inplanes, planes, 1, None, pad, dilation))
65 |
66 | return nn.Sequential(*layers)
67 |
68 | def forward(self, input):
69 | """
70 | Feature Extractor
71 | Description: The goal of the feature extraction network is to produce a reliable pixel-wise
72 | feature representation from the input image. Specifically, we employ four residual blocks
73 | and use X2 dilated convolution for the last block to enlarge the receptive field.
74 | We then apply spatial pyramid pooling to build a 4-level pyramid feature.
75 | Through multi-scale information, the model is able to capture large context while
76 | maintaining a high spatial resolution. The size of the final feature map is 1/4 of
77 | the original input image size. We share the parameters for the left and right feature network.
78 |
79 | Args:
80 | :input: Input image (RGB)
81 |
82 | Returns:
83 | :output_feature: spp_features (downsampled X4)
84 | :output1: low_level_features (downsampled X2)
85 | """
86 |
87 | output0 = self.firstconv(input)
88 | output1 = self.layer1(output0)
89 | output_raw = self.layer2(output1)
90 | output = self.layer3(output_raw)
91 | output_skip = self.layer4(output)
92 |
93 | output_branch1 = self.branch1(output_skip)
94 | output_branch1 = F.upsample(output_branch1, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear')
95 |
96 | output_branch2 = self.branch2(output_skip)
97 | output_branch2 = F.upsample(output_branch2, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear')
98 |
99 | output_branch3 = self.branch3(output_skip)
100 | output_branch3 = F.upsample(output_branch3, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear')
101 |
102 | output_branch4 = self.branch4(output_skip)
103 | output_branch4 = F.upsample(output_branch4, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear')
104 |
105 | output_feature = torch.cat(
106 | (output_raw,
107 | output_skip,
108 | output_branch4,
109 | output_branch3,
110 | output_branch2,
111 | output_branch1),
112 | 1)
113 | output_feature = self.lastconv(output_feature)
114 |
115 | return output_feature, output1
116 |
--------------------------------------------------------------------------------
/deeppruner/models/feature_extractor_fast.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch
18 | import torch.nn as nn
19 | import torch.utils.data
20 | import torch.nn.functional as F
21 | from models.submodules import BasicBlock, convbn_relu
22 |
23 |
24 | class feature_extraction(nn.Module):
25 | def __init__(self):
26 | super(feature_extraction, self).__init__()
27 | self.inplanes = 32
28 | self.firstconv = nn.Sequential(convbn_relu(3, 32, 3, 2, 1, 1),
29 | convbn_relu(32, 32, 3, 1, 1, 1),
30 | convbn_relu(32, 32, 3, 1, 1, 1))
31 |
32 | self.layer1 = self._make_layer(BasicBlock, 32, 3, 1, 1, 1)
33 | self.layer2 = self._make_layer(BasicBlock, 64, 16, 2, 1, 1)
34 | self.layer3 = self._make_layer(BasicBlock, 128, 3, 2, 1, 1)
35 | self.layer4 = self._make_layer(BasicBlock, 128, 3, 1, 1, 1)
36 |
37 | self.branch2 = nn.Sequential(nn.AvgPool2d((32, 32), stride=(32, 32)),
38 | convbn_relu(128, 32, 1, 1, 0, 1))
39 |
40 | self.branch3 = nn.Sequential(nn.AvgPool2d((16, 16), stride=(16, 16)),
41 | convbn_relu(128, 32, 1, 1, 0, 1))
42 |
43 | self.branch4 = nn.Sequential(nn.AvgPool2d((8, 8), stride=(8, 8)),
44 | convbn_relu(128, 32, 1, 1, 0, 1))
45 |
46 | self.lastconv = nn.Sequential(convbn_relu(352, 128, 3, 1, 1, 1),
47 | nn.Conv2d(128, 32, kernel_size=1, padding=0, stride=1, bias=False))
48 |
49 | def _make_layer(self, block, planes, blocks, stride, pad, dilation):
50 | downsample = None
51 | if stride != 1 or self.inplanes != planes * block.expansion:
52 | downsample = nn.Sequential(
53 | nn.Conv2d(self.inplanes, planes * block.expansion,
54 | kernel_size=1, stride=stride, bias=False),
55 | nn.BatchNorm2d(planes * block.expansion),)
56 |
57 | layers = []
58 | layers.append(block(self.inplanes, planes, stride, downsample, pad, dilation))
59 | self.inplanes = planes * block.expansion
60 | for i in range(1, blocks):
61 | layers.append(block(self.inplanes, planes, 1, None, pad, dilation))
62 |
63 | return nn.Sequential(*layers)
64 |
65 | def forward(self, input):
66 | """
67 | Feature Extractor
68 | Description: The goal of the feature extraction network is to produce a reliable pixel-wise
69 | feature representation from the input image. Specifically, we employ four residual blocks
70 | and downsample once more in the third block to enlarge the receptive field.
71 | We then apply spatial pyramid pooling to build a 3-level pyramid feature.
72 | Through multi-scale information, the model is able to capture large context while
73 | maintaining a high spatial resolution. The size of the final feature map is 1/8 of
74 | the original input image size. We share the parameters for the left and right feature network.
75 |
76 | Args:
77 | :input: Input image (RGB)
78 |
79 | Returns:
80 | :output_feature: spp_features (downsampled X8)
81 | :output_raw: features (downsampled X4)
82 | :output1: low_level_features (downsampled X2)
83 | """
84 |
85 | output0 = self.firstconv(input)
86 | output1 = self.layer1(output0)
87 | output_raw = self.layer2(output1)
88 | output = self.layer3(output_raw)
89 | output_skip = self.layer4(output)
90 |
91 | output_branch2 = self.branch2(output_skip)
92 | output_branch2 = F.upsample(output_branch2, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear')
93 |
94 | output_branch3 = self.branch3(output_skip)
95 | output_branch3 = F.upsample(output_branch3, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear')
96 |
97 | output_branch4 = self.branch4(output_skip)
98 | output_branch4 = F.upsample(output_branch4, (output_skip.size()[2], output_skip.size()[3]), mode='bilinear')
99 |
100 | output_feature = torch.cat((output, output_skip, output_branch4, output_branch3, output_branch2), 1)
101 | output_feature = self.lastconv(output_feature)
102 |
103 | return output_feature, output_raw, output1
104 |
--------------------------------------------------------------------------------
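A note on the spatial-pyramid-pooling pattern in the feature extractor above: each branch average-pools the coarse feature map at a different window size, compresses the pooled map to 32 channels with a 1x1 convolution, upsamples back to the feature-map size, and everything is concatenated before the final convolutions. Below is a minimal, self-contained sketch of that pattern; the channel counts, pool sizes, and module name are illustrative assumptions, not the exact layers above.

```python
# Minimal sketch of the spatial-pyramid-pooling pattern (assumed shapes and
# names; not the repository's exact modules).
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySPP(nn.Module):
    def __init__(self, in_channels=128, branch_channels=32, pool_sizes=(8, 16, 32)):
        super(TinySPP, self).__init__()
        # One avg-pool + 1x1 conv per pooling scale to compress pooled features.
        self.branches = nn.ModuleList([
            nn.Sequential(nn.AvgPool2d(p, stride=p),
                          nn.Conv2d(in_channels, branch_channels, 1, bias=False),
                          nn.ReLU(inplace=True))
            for p in pool_sizes])

    def forward(self, x):
        h, w = x.size()[2:]
        # Pool at each scale, compress, then upsample back to the input size.
        pyramid = [F.interpolate(b(x), size=(h, w), mode='bilinear', align_corners=False)
                   for b in self.branches]
        # Concatenate the original features with all pyramid levels.
        return torch.cat([x] + pyramid, dim=1)

if __name__ == '__main__':
    feats = torch.randn(1, 128, 32, 64)   # e.g. a coarse feature map
    print(TinySPP()(feats).shape)         # torch.Size([1, 224, 32, 64])
```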
/deeppruner/models/patch_match.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch
18 | import torch.nn as nn
19 | import torch.nn.functional as F
20 |
21 |
22 | class DisparityInitialization(nn.Module):
23 |
24 | def __init__(self):
25 | super(DisparityInitialization, self).__init__()
26 |
27 | def forward(self, min_disparity, max_disparity, number_of_intervals=10):
28 | """
29 | PatchMatch Initialization Block
30 |         Description: Rather than allowing each sample/particle to reside in the full disparity space,
31 |                      we divide the search space into 'number_of_intervals' intervals, and force the
32 |                      i-th particle to be in the i-th interval. This guarantees the diversity of the
33 | particles and helps improve accuracy for later computations.
34 |
35 | As per implementation,
36 | this function divides the complete disparity search space into multiple intervals.
37 |
38 | Args:
39 | :min_disparity: Min Disparity of the disparity search range.
40 | :max_disparity: Max Disparity of the disparity search range.
41 | :number_of_intervals (default: 10): Number of samples to be generated.
42 | Returns:
43 |             :interval_noise: Random value between 0 and 1. Represents the offset from interval_min_disparity.
44 | :interval_min_disparity: disparity_sample = interval_min_disparity + interval_noise
45 | :multiplier: 1.0 / number_of_intervals
46 | """
47 |
48 | device = min_disparity.get_device()
49 |
50 | multiplier = 1.0 / number_of_intervals
51 | range_multiplier = torch.arange(0.0, 1, multiplier, device=device).view(number_of_intervals, 1, 1)
52 | range_multiplier = range_multiplier.repeat(1, min_disparity.size()[2], min_disparity.size()[3])
53 |
54 | interval_noise = min_disparity.new_empty(min_disparity.size()[0], number_of_intervals, min_disparity.size()[2],
55 | min_disparity.size()[3]).uniform_(0, 1)
56 | interval_min_disparity = min_disparity + (max_disparity - min_disparity) * range_multiplier
57 |
58 | return interval_noise, interval_min_disparity, multiplier
59 |
60 |
61 | class Evaluate(nn.Module):
62 | def __init__(self, filter_size=3, temperature=7):
63 | super(Evaluate, self).__init__()
64 | self.temperature = temperature
65 | self.filter_size = filter_size
66 | self.softmax = torch.nn.Softmax(dim=1)
67 |
68 | def forward(self, left_input, right_input, disparity_samples, normalized_disparity_samples):
69 | """
70 | PatchMatch Evaluation Block
71 | Description: For each pixel i, matching scores are computed by taking the inner product between the
72 |                          left feature and the right feature: score(i,j) = <feature_left(i), feature_right(i - disparity(i,j))>
73 |                          for all candidates j. The best k disparity values for each pixel are carried towards the next iteration.
74 |
75 | As per implementation,
76 | the complete disparity search range is discretized into intervals in
77 | DisparityInitialization() function. Corresponding to each disparity interval, we have multiple samples
78 | to evaluate. The best disparity sample per interval is the output of the function.
79 |
80 | Args:
81 | :left_input: Left Image Feature Map
82 | :right_input: Right Image Feature Map
83 | :disparity_samples: Disparity Samples to be evaluated. For each pixel, we have
84 | ("number of intervals" X "number_of_samples_per_intervals") samples.
85 |
86 | :normalized_disparity_samples:
87 | Returns:
88 | :disparity_samples: Evaluated disparity sample, one per disparity interval.
89 |             :normalized_disparity_samples: Evaluated normalized disparity sample, one per disparity interval.
90 | """
91 | device = left_input.get_device()
92 | left_y_coordinate = torch.arange(0.0, left_input.size()[3], device=device).repeat(
93 | left_input.size()[2]).view(left_input.size()[2], left_input.size()[3])
94 |
95 | left_y_coordinate = torch.clamp(left_y_coordinate, min=0, max=left_input.size()[3] - 1)
96 | left_y_coordinate = left_y_coordinate.expand(left_input.size()[0], -1, -1)
97 |
98 | right_feature_map = right_input.expand(disparity_samples.size()[1], -1, -1, -1, -1).permute([1, 2, 0, 3, 4])
99 | left_feature_map = left_input.expand(disparity_samples.size()[1], -1, -1, -1, -1).permute([1, 2, 0, 3, 4])
100 |
101 | disparity_sample_strength = disparity_samples.new(disparity_samples.size()[0],
102 | disparity_samples.size()[1],
103 | disparity_samples.size()[2],
104 | disparity_samples.size()[3])
105 |
106 | right_y_coordinate = left_y_coordinate.expand(
107 | disparity_samples.size()[1], -1, -1, -1).permute([1, 0, 2, 3]).float()
108 | right_y_coordinate = right_y_coordinate - disparity_samples
109 | right_y_coordinate = torch.clamp(right_y_coordinate, min=0, max=right_input.size()[3] - 1)
110 |
111 | warped_right_feature_map = torch.gather(right_feature_map,
112 | dim=4,
113 | index=right_y_coordinate.expand(
114 | right_input.size()[1], -1, -1, -1, -1).permute([1, 0, 2, 3, 4]).long())
115 |
116 | disparity_sample_strength = torch.mean(left_feature_map * warped_right_feature_map, dim=1) * self.temperature
117 |
118 | disparity_sample_strength = disparity_sample_strength.view(
119 | disparity_sample_strength.size()[0],
120 | disparity_sample_strength.size()[1] // (self.filter_size),
121 | (self.filter_size),
122 | disparity_sample_strength.size()[2],
123 | disparity_sample_strength.size()[3])
124 |
125 | disparity_samples = disparity_samples.view(disparity_samples.size()[0],
126 | disparity_samples.size()[1] // (self.filter_size),
127 | (self.filter_size),
128 | disparity_samples.size()[2],
129 | disparity_samples.size()[3])
130 |
131 | normalized_disparity_samples = normalized_disparity_samples.view(
132 | normalized_disparity_samples.size()[0],
133 | normalized_disparity_samples.size()[1] // (self.filter_size),
134 | (self.filter_size),
135 | normalized_disparity_samples.size()[2],
136 | normalized_disparity_samples.size()[3])
137 |
138 | disparity_sample_strength = disparity_sample_strength.permute([0, 2, 1, 3, 4])
139 | disparity_samples = disparity_samples.permute([0, 2, 1, 3, 4])
140 | normalized_disparity_samples = normalized_disparity_samples.permute([0, 2, 1, 3, 4])
141 |
142 | disparity_sample_strength = torch.softmax(disparity_sample_strength, dim=1)
143 | disparity_samples = torch.sum(disparity_samples * disparity_sample_strength, dim=1)
144 | normalized_disparity_samples = torch.sum(normalized_disparity_samples * disparity_sample_strength, dim=1)
145 |
146 | return normalized_disparity_samples, disparity_samples
147 |
148 |
149 | class Propagation(nn.Module):
150 | def __init__(self, filter_size=3):
151 | super(Propagation, self).__init__()
152 | self.filter_size = filter_size
153 |
154 | def forward(self, disparity_samples, device, propagation_type="horizontal"):
155 | """
156 | PatchMatch Propagation Block
157 | Description: Particles from adjacent pixels are propagated together through convolution with a
158 |                      pre-defined one-hot filter pattern, which encodes the fact that we allow each pixel
159 | to propagate particles to its 4-neighbours.
160 |
161 | As per implementation, the complete disparity search range is discretized into intervals in
162 | DisparityInitialization() function.
163 |                      Now, propagation of samples from neighbouring pixels is done per interval. This implies that after
164 |                      propagation, the number of samples per pixel = (filter_size X number_of_intervals).
165 |
166 | Args:
167 |             :disparity_samples: Disparity samples to be propagated (one sample per interval per pixel).
168 | :device: Cuda device
169 | :propagation_type (default:"horizontal"): In order to be memory efficient, we use separable convolutions
170 |                                                      for propagation.
171 |
172 | Returns:
173 | :aggregated_disparity_samples: Disparity Samples aggregated from the neighbours.
174 |
175 | """
176 |
177 | disparity_samples = disparity_samples.view(disparity_samples.size()[0],
178 | 1,
179 | disparity_samples.size()[1],
180 | disparity_samples.size()[2],
181 | disparity_samples.size()[3])
182 |
183 |         if propagation_type == "horizontal":
184 | label = torch.arange(0, self.filter_size, device=device).repeat(self.filter_size).view(
185 | self.filter_size, 1, 1, 1, self.filter_size)
186 |
187 | one_hot_filter = torch.zeros_like(label).scatter_(0, label, 1).float()
188 | aggregated_disparity_samples = F.conv3d(disparity_samples,
189 | one_hot_filter, padding=(0, 0, self.filter_size // 2))
190 |
191 | else:
192 | label = torch.arange(0, self.filter_size, device=device).repeat(self.filter_size).view(
193 | self.filter_size, 1, 1, self.filter_size, 1).long()
194 |
195 | one_hot_filter = torch.zeros_like(label).scatter_(0, label, 1).float()
196 | aggregated_disparity_samples = F.conv3d(disparity_samples,
197 | one_hot_filter, padding=(0, self.filter_size // 2, 0))
198 |
199 | aggregated_disparity_samples = aggregated_disparity_samples.permute([0, 2, 1, 3, 4])
200 | aggregated_disparity_samples = aggregated_disparity_samples.contiguous().view(
201 | aggregated_disparity_samples.size()[0],
202 | aggregated_disparity_samples.size()[1] * aggregated_disparity_samples.size()[2],
203 | aggregated_disparity_samples.size()[3],
204 | aggregated_disparity_samples.size()[4])
205 |
206 | return aggregated_disparity_samples
207 |
208 |
209 | class PatchMatch(nn.Module):
210 | def __init__(self, propagation_filter_size=3):
211 | super(PatchMatch, self).__init__()
212 |
213 | self.propagation_filter_size = propagation_filter_size
214 | self.propagation = Propagation(filter_size=propagation_filter_size)
215 | self.disparity_initialization = DisparityInitialization()
216 | self.evaluate = Evaluate(filter_size=propagation_filter_size)
217 |
218 | def forward(self, left_input, right_input, min_disparity, max_disparity, sample_count=10, iteration_count=3):
219 | """
220 |         Differentiable PatchMatch Block
221 |         Description: In this work, we unroll generalized PatchMatch as a recurrent neural network,
222 |                      where each unrolling step is equivalent to one iteration of the algorithm.
223 |                      This is important as it allows us to train our full model end-to-end.
224 |                      Specifically, we design the following layers:
225 |                          - Initialization or Particle Sampling
226 | - Propagation
227 | - Evaluation
228 | Args:
229 | :left_input: Left Image feature map
230 | :right_input: Right image feature map
231 | :min_disparity: Min of the disparity search range
232 | :max_disparity: Max of the disparity search range
233 | :sample_count (default:10): Number of disparity samples per pixel. (similar to generalized PatchMatch)
234 | :iteration_count (default:3) : Number of PatchMatch iterations
235 |
236 | Returns:
237 | :disparity_samples: For each pixel, this function returns "sample_count" disparity samples.
238 | """
239 |
240 | device = left_input.get_device()
241 | min_disparity = torch.floor(min_disparity)
242 | max_disparity = torch.ceil(max_disparity)
243 |
244 | # normalized_disparity_samples: Disparity samples normalized by the corresponding interval size.
245 | # i.e (disparity_sample - interval_min_disparity) / interval_size
246 |
247 | normalized_disparity_samples, min_disp_tensor, multiplier = self.disparity_initialization(
248 | min_disparity, max_disparity, sample_count)
249 | min_disp_tensor = min_disp_tensor.unsqueeze(2).repeat(1, 1, self.propagation_filter_size, 1, 1).view(
250 | min_disp_tensor.size()[0],
251 | min_disp_tensor.size()[1] * self.propagation_filter_size,
252 | min_disp_tensor.size()[2],
253 | min_disp_tensor.size()[3])
254 |
255 | for prop_iter in range(iteration_count):
256 | normalized_disparity_samples = self.propagation(normalized_disparity_samples, device, propagation_type="horizontal")
257 | disparity_samples = normalized_disparity_samples * \
258 | (max_disparity - min_disparity) * multiplier + min_disp_tensor
259 |
260 | normalized_disparity_samples, disparity_samples = self.evaluate(left_input,
261 | right_input,
262 | disparity_samples,
263 | normalized_disparity_samples)
264 |
265 | normalized_disparity_samples = self.propagation(normalized_disparity_samples, device, propagation_type="vertical")
266 | disparity_samples = normalized_disparity_samples * \
267 | (max_disparity - min_disparity) * multiplier + min_disp_tensor
268 |
269 | normalized_disparity_samples, disparity_samples = self.evaluate(left_input,
270 | right_input,
271 | disparity_samples,
272 | normalized_disparity_samples)
273 |
274 | return disparity_samples
275 |
--------------------------------------------------------------------------------
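To make the Propagation block above concrete: the convolution weight is a fixed one-hot filter, so each output channel simply copies the particle of one horizontal (or vertical) neighbour, and the result is a stack of {left-neighbour, own, right-neighbour} samples for every pixel. The following is a minimal sketch under assumed toy shapes (one interval, a single row of six pixels).

```python
# Minimal sketch of the one-hot propagation filter used in Propagation above.
import torch
import torch.nn.functional as F

filter_size = 3
label = torch.arange(0, filter_size).repeat(filter_size).view(
    filter_size, 1, 1, 1, filter_size)
one_hot_filter = torch.zeros_like(label).scatter_(0, label, 1).float()
# one_hot_filter[k, 0, 0, 0, k] == 1, everything else is 0.

samples = torch.arange(1., 7.).view(1, 1, 1, 1, 6)   # (B, 1, intervals=1, H=1, W=6)
propagated = F.conv3d(samples, one_hot_filter, padding=(0, 0, filter_size // 2))
print(propagated.squeeze())
# tensor([[0., 1., 2., 3., 4., 5.],    <- sample of the left neighbour
#         [1., 2., 3., 4., 5., 6.],    <- pixel's own sample
#         [2., 3., 4., 5., 6., 0.]])   <- sample of the right neighbour
```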
/deeppruner/models/submodules.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch.nn as nn
18 | import math
19 |
20 |
21 | def convbn_2d_lrelu(in_planes, out_planes, kernel_size, stride, pad, dilation=1, bias=False):
22 | return nn.Sequential(
23 | nn.Conv2d(in_planes, out_planes, kernel_size=(kernel_size, kernel_size),
24 | stride=(stride, stride), padding=(pad, pad), dilation=(dilation, dilation), bias=bias),
25 | nn.BatchNorm2d(out_planes),
26 | nn.LeakyReLU(0.1, inplace=True))
27 |
28 |
29 | def convbn_3d_lrelu(in_planes, out_planes, kernel_size, stride, pad):
30 |
31 | return nn.Sequential(nn.Conv3d(in_planes, out_planes, kernel_size=kernel_size, padding=(pad, pad, pad),
32 | stride=(1, stride, stride), bias=False),
33 | nn.BatchNorm3d(out_planes),
34 | nn.LeakyReLU(0.1, inplace=True))
35 |
36 |
37 | def conv_relu(in_planes, out_planes, kernel_size, stride, pad, bias=True):
38 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size, stride, pad, bias=bias),
39 | nn.ReLU(inplace=True))
40 |
41 |
42 | def convbn(in_planes, out_planes, kernel_size, stride, pad, dilation):
43 |
44 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
45 | padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False),
46 | nn.BatchNorm2d(out_planes))
47 |
48 |
49 | def convbn_relu(in_planes, out_planes, kernel_size, stride, pad, dilation):
50 |
51 | return nn.Sequential(nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
52 | padding=dilation if dilation > 1 else pad, dilation=dilation, bias=False),
53 | nn.BatchNorm2d(out_planes),
54 | nn.ReLU(inplace=True))
55 |
56 |
57 | def convbn_transpose_3d(inplanes, outplanes, kernel_size, padding, output_padding, stride, bias):
58 | return nn.Sequential(nn.ConvTranspose3d(inplanes, outplanes, kernel_size, padding=padding,
59 | output_padding=output_padding, stride=stride, bias=bias),
60 | nn.BatchNorm3d(outplanes))
61 |
62 |
63 | class BasicBlock(nn.Module):
64 | expansion = 1
65 |
66 | def __init__(self, inplanes, planes, stride, downsample, pad, dilation):
67 | super(BasicBlock, self).__init__()
68 |
69 | self.conv1 = convbn_relu(inplanes, planes, 3, stride, pad, dilation)
70 | self.conv2 = convbn(planes, planes, 3, 1, pad, dilation)
71 |
72 | self.downsample = downsample
73 | self.stride = stride
74 |
75 | def forward(self, x):
76 | out = self.conv1(x)
77 | out = self.conv2(out)
78 |
79 | if self.downsample is not None:
80 | x = self.downsample(x)
81 |
82 | out += x
83 |
84 | return out
85 |
86 |
87 | class SubModule(nn.Module):
88 | def __init__(self):
89 | super(SubModule, self).__init__()
90 |
91 | def weight_init(self):
92 | for m in self.modules():
93 | if isinstance(m, nn.Conv2d):
94 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
95 | m.weight.data.normal_(0, math.sqrt(2. / n))
96 | elif isinstance(m, nn.Conv3d):
97 | n = m.kernel_size[0] * m.kernel_size[1] * m.kernel_size[2] * m.out_channels
98 | m.weight.data.normal_(0, math.sqrt(2. / n))
99 | elif isinstance(m, nn.BatchNorm2d):
100 | m.weight.data.fill_(1)
101 | m.bias.data.zero_()
102 | elif isinstance(m, nn.BatchNorm3d):
103 | m.weight.data.fill_(1)
104 | m.bias.data.zero_()
105 | elif isinstance(m, nn.Linear):
106 | m.bias.data.zero_()
107 |
--------------------------------------------------------------------------------
/deeppruner/models/submodules2d.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch.nn as nn
18 | from models.submodules import SubModule, convbn_2d_lrelu
19 |
20 |
21 | class RefinementNet(SubModule):
22 | def __init__(self, inplanes):
23 | super(RefinementNet, self).__init__()
24 |
25 | self.conv1 = nn.Sequential(
26 | convbn_2d_lrelu(inplanes, 32, kernel_size=3, stride=1, pad=1),
27 | convbn_2d_lrelu(32, 32, kernel_size=3, stride=1, pad=1, dilation=1),
28 | convbn_2d_lrelu(32, 32, kernel_size=3, stride=1, pad=1, dilation=1),
29 | convbn_2d_lrelu(32, 16, kernel_size=3, stride=1, pad=2, dilation=2),
30 | convbn_2d_lrelu(16, 16, kernel_size=3, stride=1, pad=4, dilation=4),
31 | convbn_2d_lrelu(16, 16, kernel_size=3, stride=1, pad=1, dilation=1))
32 |
33 | self.classif1 = nn.Conv2d(16, 1, kernel_size=3, padding=1, stride=1, bias=False)
34 | self.relu = nn.ReLU(inplace=True)
35 |
36 | self.weight_init()
37 |
38 | def forward(self, input, disparity):
39 | """
40 | Refinement Block
41 | Description: The network takes left image convolutional features from the second residual block
42 | of the feature network and the current disparity estimation as input.
43 | It then outputs the finetuned disparity prediction. The low-level feature
44 | information serves as a guidance to reduce noise and improve the quality of the final
45 | disparity map, especially on sharp boundaries.
46 |
47 | Args:
48 | :input: Input features composed of left image low-level features, cost-aggregator features, and
49 | cost-aggregator disparity.
50 |
51 | :disparity: predicted disparity
52 | """
53 |
54 | output0 = self.conv1(input)
55 | output0 = self.classif1(output0)
56 | output = self.relu(output0 + disparity)
57 |
58 | return output
--------------------------------------------------------------------------------
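As a quick illustration of the refinement step above: the module effectively predicts a one-channel correction from the concatenated features and adds it to the current disparity, with a ReLU keeping the result non-negative. A toy sketch with made-up tensors standing in for the network outputs follows.

```python
# Toy sketch of the residual refinement in RefinementNet above (made-up tensors).
import torch
import torch.nn.functional as F

disparity = torch.rand(1, 1, 48, 96) * 50.0   # current disparity estimate
correction = torch.randn(1, 1, 48, 96) * 0.5  # stands in for classif1(conv1(input))
refined = F.relu(disparity + correction)      # finetuned, non-negative disparity
print(refined.shape)                          # torch.Size([1, 1, 48, 96])
```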
/deeppruner/models/submodules3d.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch
18 | import torch.nn as nn
19 | import torch.utils.data
20 | from models.submodules import SubModule, convbn_3d_lrelu, convbn_transpose_3d
21 |
22 |
23 | class HourGlass(SubModule):
24 | def __init__(self, inplanes=16):
25 | super(HourGlass, self).__init__()
26 |
27 | self.conv1 = convbn_3d_lrelu(inplanes, inplanes * 2, kernel_size=3, stride=2, pad=1)
28 | self.conv2 = convbn_3d_lrelu(inplanes * 2, inplanes * 2, kernel_size=3, stride=1, pad=1)
29 |
30 | self.conv1_1 = convbn_3d_lrelu(inplanes * 2, inplanes * 4, kernel_size=3, stride=2, pad=1)
31 | self.conv2_1 = convbn_3d_lrelu(inplanes * 4, inplanes * 4, kernel_size=3, stride=1, pad=1)
32 |
33 | self.conv3 = convbn_3d_lrelu(inplanes * 4, inplanes * 8, kernel_size=3, stride=2, pad=1)
34 | self.conv4 = convbn_3d_lrelu(inplanes * 8, inplanes * 8, kernel_size=3, stride=1, pad=1)
35 |
36 | self.conv5 = convbn_transpose_3d(inplanes * 8, inplanes * 4, kernel_size=3, padding=1,
37 | output_padding=(0, 1, 1), stride=(1, 2, 2), bias=False)
38 | self.conv6 = convbn_transpose_3d(inplanes * 4, inplanes * 2, kernel_size=3, padding=1,
39 | output_padding=(0, 1, 1), stride=(1, 2, 2), bias=False)
40 | self.conv7 = convbn_transpose_3d(inplanes * 2, inplanes, kernel_size=3, padding=1,
41 | output_padding=(0, 1, 1), stride=(1, 2, 2), bias=False)
42 |
43 | self.last_conv3d_layer = nn.Sequential(
44 | convbn_3d_lrelu(inplanes, inplanes * 2, 3, 1, 1),
45 | nn.Conv3d(inplanes * 2, 1, kernel_size=3, padding=1, stride=1, bias=False))
46 |
47 | self.softmax = nn.Softmax(dim=1)
48 |
49 | self.weight_init()
50 |
51 |
52 | class MaxDisparityPredictor(HourGlass):
53 |
54 | def __init__(self, hourglass_inplanes=16):
55 | super(MaxDisparityPredictor, self).__init__(hourglass_inplanes)
56 |
57 | def forward(self, input, input_disparity):
58 | """
59 | Confidence Range Prediction (Max Disparity):
60 | Description: The network has a convolutional encoder-decoder structure. It takes the sparse
61 | disparity estimations from the differentiable PatchMatch, the left image and the warped right image
62 | (warped according to the sparse disparity estimations) as input and outputs the upper bound of
63 | the confidence range for each pixel i.
64 | Args:
65 | :input: Left and Warped right Image features as Cost Volume.
66 | :input_disparity: PatchMatch predicted disparity samples.
67 | Returns:
68 |             :disparity_output: Max Disparity of the reduced disparity search range.
69 | :feature_output: High-level features of the MaxDisparityPredictor
70 | """
71 |
72 | output0 = self.conv1(input)
73 | output0_a = self.conv2(output0) + output0
74 |
75 | output0 = self.conv1_1(output0_a)
76 | output0_c = self.conv2_1(output0) + output0
77 |
78 | output0 = self.conv3(output0_c)
79 | output0 = self.conv4(output0) + output0
80 |
81 | output1 = self.conv5(output0) + output0_c
82 | output1 = self.conv6(output1) + output0_a
83 | output1 = self.conv7(output1)
84 |
85 | output2 = self.last_conv3d_layer(output1).squeeze(1)
86 | feature_output = output2
87 |
88 | confidence_output = self.softmax(output2)
89 | disparity_output = torch.sum(confidence_output * input_disparity, dim=1).unsqueeze(1)
90 |
91 | return disparity_output, feature_output
92 |
93 |
94 | class MinDisparityPredictor(HourGlass):
95 |
96 | def __init__(self, hourglass_inplanes=16):
97 | super(MinDisparityPredictor, self).__init__(hourglass_inplanes)
98 |
99 | def forward(self, input, input_disparity):
100 | """
101 | Confidence Range Prediction (Min Disparity):
102 | Description: The network has a convolutional encoder-decoder structure. It takes the sparse
103 | disparity estimations from the differentiable PatchMatch, the left image and the warped right image
104 | (warped according to the sparse disparity estimations) as input and outputs the lower bound of
105 | the confidence range for each pixel i.
106 | Args:
107 | :input: Left and Warped right Image features as Cost Volume.
108 | :input_disparity: PatchMatch predicted disparity samples.
109 | Returns:
110 |             :disparity_output: Min Disparity of the reduced disparity search range.
111 |             :feature_output: High-level features of the MinDisparityPredictor
112 | """
113 |
114 | output0 = self.conv1(input)
115 | output0_a = self.conv2(output0) + output0
116 |
117 | output0 = self.conv1_1(output0_a)
118 | output0_c = self.conv2_1(output0) + output0
119 |
120 | output0 = self.conv3(output0_c)
121 | output0 = self.conv4(output0) + output0
122 |
123 | output1 = self.conv5(output0) + output0_c
124 | output1 = self.conv6(output1) + output0_a
125 | output1 = self.conv7(output1)
126 |
127 | output2 = self.last_conv3d_layer(output1).squeeze(1)
128 | feature_output = output2
129 |
130 | confidence_output = self.softmax(output2)
131 | disparity_output = torch.sum(confidence_output * input_disparity, dim=1).unsqueeze(1)
132 |
133 | return disparity_output, feature_output
134 |
135 |
136 | class CostAggregator(HourGlass):
137 |
138 | def __init__(self, cost_aggregator_inplanes, hourglass_inplanes=16):
139 | super(CostAggregator, self).__init__(inplanes=16)
140 |
141 | self.dres0 = nn.Sequential(convbn_3d_lrelu(cost_aggregator_inplanes, 64, 3, 1, 1),
142 | convbn_3d_lrelu(64, 32, 3, 1, 1))
143 |
144 | self.dres1 = nn.Sequential(convbn_3d_lrelu(32, 32, 3, 1, 1),
145 | convbn_3d_lrelu(32, hourglass_inplanes, 3, 1, 1))
146 |
147 | def forward(self, input, input_disparity):
148 | """
149 | 3D Cost Aggregator
150 | Description: Based on the predicted range in the pruning module,
151 | we build the 3D cost volume estimator and conduct spatial aggregation.
152 | Following common practice, we take the left image, the warped right image and corresponding disparities
153 |                      as input and output the cost over the disparity range at the size B X R X H X W, where R is the number
154 |                      of disparities per pixel. Compared to prior work, our R is more than 10 times smaller, making
155 |                      this module very efficient. Soft arg-max is again used to predict the disparity value,
156 | so that our approach is end-to-end trainable.
157 |
158 | Args:
159 | :input: Cost-Volume composed of left image features, warped right image features,
160 |                     Confidence Range Predictor features and input disparity samples.
161 |
162 | :input_disparity: input disparity samples.
163 |
164 | Returns:
165 | :disparity_output: Predicted disparity
166 | :feature_output: High-level features of 3d-Cost Aggregator
167 |
168 | """
169 |
170 | output0 = self.dres0(input)
171 | output0_b = self.dres1(output0)
172 |
173 | output0 = self.conv1(output0_b)
174 | output0_a = self.conv2(output0) + output0
175 |
176 | output0 = self.conv1_1(output0_a)
177 | output0_c = self.conv2_1(output0) + output0
178 |
179 | output0 = self.conv3(output0_c)
180 | output0 = self.conv4(output0) + output0
181 |
182 | output1 = self.conv5(output0) + output0_c
183 | output1 = self.conv6(output1) + output0_a
184 | output1 = self.conv7(output1) + output0_b
185 |
186 | output2 = self.last_conv3d_layer(output1).squeeze(1)
187 | feature_output = output2
188 |
189 | confidence_output = self.softmax(output2)
190 | disparity_output = torch.sum(confidence_output * input_disparity, dim=1)
191 |
192 | return disparity_output.unsqueeze(1), feature_output
193 |
--------------------------------------------------------------------------------
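All three modules above end with the same soft arg-max: a softmax over the per-candidate scores followed by a confidence-weighted sum of the candidate disparities, which keeps the disparity prediction differentiable. A tiny numeric sketch with made-up values:

```python
# Numeric sketch of the soft arg-max used at the end of the modules above.
import torch

scores = torch.tensor([[0.1, 2.0, 0.3]])          # (B=1, R=3) matching scores
candidates = torch.tensor([[10.0, 12.0, 14.0]])   # (B=1, R=3) disparity samples

confidence = torch.softmax(scores, dim=1)
disparity = torch.sum(confidence * candidates, dim=1, keepdim=True)
print(confidence)   # most of the weight falls on the 12.0 candidate
print(disparity)    # a value close to 12, differentiable w.r.t. the scores
```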
/deeppruner/models/utils.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import torch
18 | import torch.nn as nn
19 |
20 |
21 | class UniformSampler(nn.Module):
22 | def __init__(self):
23 | super(UniformSampler, self).__init__()
24 |
25 | def forward(self, min_disparity, max_disparity, number_of_samples=10):
26 | """
27 | Uniform Sampler
28 | Description: The Confidence Range Predictor predicts a reduced disparity search range R(i) = [l(i), u(i)]
29 |             for each pixel i. We then generate disparity samples from this reduced search range for Cost Aggregation
30 |             or the second stage of PatchMatch. From experiments, we found uniform sampling to work better.
31 |
32 | Args:
33 | :min_disparity: lower bound of disparity search range (predicted by Confidence Range Predictor)
34 |             :max_disparity: upper bound of disparity search range (predicted by Confidence Range Predictor)
35 |             :number_of_samples (default:10): number of samples to be generated.
36 | Returns:
37 | :sampled_disparities: Uniformly generated disparity samples from the input search range.
38 | """
39 |
40 | device = min_disparity.get_device()
41 |
42 | multiplier = (max_disparity - min_disparity) / (number_of_samples + 1)
43 | range_multiplier = torch.arange(1.0, number_of_samples + 1, 1, device=device).view(number_of_samples, 1, 1)
44 | sampled_disparities = min_disparity + multiplier * range_multiplier
45 |
46 | return sampled_disparities
47 |
48 |
49 | class SpatialTransformer(nn.Module):
50 | def __init__(self):
51 | super(SpatialTransformer, self).__init__()
52 |
53 | def forward(self, left_input, right_input, disparity_samples):
54 | """
55 | Disparity Sample Cost Evaluator
56 | Description:
57 |             Given the left image features, right image features and the disparity samples, generates:
58 |                 - Per-sample cost, computed as the scalar product <left_feature, warped_right_feature>.
59 |                 - Warped right image features
60 |
61 | Args:
62 | :left_input: Left Image Features
63 | :right_input: Right Image Features
64 |             :disparity_samples: Disparity Samples generated by PatchMatch
65 |
66 | Returns:
67 |             :disparity_samples_strength_1: Cost associated with each disparity sample.
68 |             :warped_right_feature_map: right image features warped according to input disparity.
69 | :left_feature_map: expanded left image features.
70 | """
71 |
72 | device = left_input.get_device()
73 | left_y_coordinate = torch.arange(0.0, left_input.size()[3], device=device).repeat(left_input.size()[2])
74 | left_y_coordinate = left_y_coordinate.view(left_input.size()[2], left_input.size()[3])
75 | left_y_coordinate = torch.clamp(left_y_coordinate, min=0, max=left_input.size()[3] - 1)
76 | left_y_coordinate = left_y_coordinate.expand(left_input.size()[0], -1, -1)
77 |
78 | right_feature_map = right_input.expand(disparity_samples.size()[1], -1, -1, -1, -1).permute([1, 2, 0, 3, 4])
79 | left_feature_map = left_input.expand(disparity_samples.size()[1], -1, -1, -1, -1).permute([1, 2, 0, 3, 4])
80 |
81 | disparity_samples = disparity_samples.float()
82 |
83 | right_y_coordinate = left_y_coordinate.expand(
84 | disparity_samples.size()[1], -1, -1, -1).permute([1, 0, 2, 3]) - disparity_samples
85 |
86 | right_y_coordinate_1 = right_y_coordinate
87 | right_y_coordinate = torch.clamp(right_y_coordinate, min=0, max=right_input.size()[3] - 1)
88 |
89 | warped_right_feature_map = torch.gather(right_feature_map, dim=4, index=right_y_coordinate.expand(
90 | right_input.size()[1], -1, -1, -1, -1).permute([1, 0, 2, 3, 4]).long())
91 |
92 | right_y_coordinate_1 = right_y_coordinate_1.unsqueeze(1)
93 | warped_right_feature_map = (1 - ((right_y_coordinate_1 < 0) +
94 | (right_y_coordinate_1 > right_input.size()[3] - 1)).float()) * \
95 | (warped_right_feature_map) + torch.zeros_like(warped_right_feature_map)
96 |
97 | return warped_right_feature_map, left_feature_map
98 |
--------------------------------------------------------------------------------
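The sampling rule in UniformSampler above places the samples strictly inside the predicted range, spaced by (max - min) / (number_of_samples + 1). A small numeric sketch with an assumed scalar range:

```python
# Numeric sketch of the uniform sampling rule in UniformSampler above.
import torch

min_disparity = torch.tensor(4.0)
max_disparity = torch.tensor(20.0)
number_of_samples = 7

step = (max_disparity - min_disparity) / (number_of_samples + 1)
k = torch.arange(1.0, number_of_samples + 1)
samples = min_disparity + step * k
print(samples)   # tensor([ 6.,  8., 10., 12., 14., 16., 18.])
```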
/deeppruner/setup_logging.py:
--------------------------------------------------------------------------------
1 | import logging
2 |
3 | def setup_logging(filename):
4 |
5 | log_format = '%(filename)s: %(message)s'
6 | logging.basicConfig(format=log_format, level=logging.INFO)
7 |
8 | file_handler = logging.FileHandler(filename)
9 | file_handler.setFormatter(logging.Formatter(fmt=log_format))
10 | file_handler.setLevel(logging.INFO)
11 | logging.getLogger().addHandler(file_handler)
12 |
--------------------------------------------------------------------------------
/deeppruner/submission_kitti.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import argparse
18 | import os
19 | import random
20 | import torch
21 | import torch.nn as nn
22 | import torch.nn.parallel
23 | import torch.backends.cudnn as cudnn
24 | import torch.optim as optim
25 | import torch.utils.data
26 | from torch.autograd import Variable
27 | import torch.nn.functional as F
28 | import skimage.io
29 | import numpy as np
30 | import logging
31 | from dataloader import kitti_submission_collector as ls
32 | from dataloader import preprocess
33 | from PIL import Image
34 | from models.deeppruner import DeepPruner
35 | from models.config import config as config_args
36 | from setup_logging import setup_logging
37 |
38 | parser = argparse.ArgumentParser(description='DeepPruner')
39 | parser.add_argument('--datapath', default='/',
40 | help='datapath')
41 | parser.add_argument('--loadmodel', default=None,
42 | help='load model')
43 | parser.add_argument('--save_dir', default='./',
44 | help='save directory')
45 | parser.add_argument('--logging_filename', default='./submission_kitti.log',
46 | help='filename for logs')
47 | parser.add_argument('--no-cuda', action='store_true', default=False,
48 |                     help='disables CUDA')
49 | parser.add_argument('--seed', type=int, default=1, metavar='S',
50 | help='random seed (default: 1)')
51 |
52 | args = parser.parse_args()
53 | torch.backends.cudnn.benchmark = True
54 | args.cuda = not args.no_cuda and torch.cuda.is_available()
55 |
56 | args.cost_aggregator_scale = config_args.cost_aggregator_scale
57 | args.downsample_scale = args.cost_aggregator_scale * 8.0
58 |
59 | setup_logging(args.logging_filename)
60 |
61 | if args.cuda:
62 | torch.manual_seed(args.seed)
63 | np.random.seed(args.seed)
64 | random.seed(args.seed)
65 | torch.cuda.manual_seed(args.seed)
66 | torch.backends.cudnn.deterministic = True
67 |
68 |
69 | test_left_img, test_right_img = ls.datacollector(args.datapath)
70 |
71 | model = DeepPruner()
72 | model = nn.DataParallel(model)
73 |
74 | if args.cuda:
75 | model.cuda()
76 |
77 | logging.info('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
78 |
79 |
80 | if args.loadmodel is not None:
81 | logging.info("loading model...")
82 | state_dict = torch.load(args.loadmodel)
83 | model.load_state_dict(state_dict['state_dict'], strict=True)
84 |
85 |
86 | def test(imgL, imgR):
87 | model.eval()
88 | with torch.no_grad():
89 | imgL = Variable(torch.FloatTensor(imgL))
90 | imgR = Variable(torch.FloatTensor(imgR))
91 |
92 | if args.cuda:
93 | imgL, imgR = imgL.cuda(), imgR.cuda()
94 |
95 | refined_disparity = model(imgL, imgR)
96 | return refined_disparity
97 |
98 |
99 | def main():
100 |
101 | for left_image_path, right_image_path in zip(test_left_img, test_right_img):
102 | imgL = np.asarray(Image.open(left_image_path))
103 | imgR = np.asarray(Image.open(right_image_path))
104 |
105 | processed = preprocess.get_transform()
106 | imgL = processed(imgL).numpy()
107 | imgR = processed(imgR).numpy()
108 |
109 | imgL = np.reshape(imgL, [1, 3, imgL.shape[1], imgL.shape[2]])
110 | imgR = np.reshape(imgR, [1, 3, imgR.shape[1], imgR.shape[2]])
111 |
112 | w = imgL.shape[3]
113 | h = imgL.shape[2]
114 | dw = int(args.downsample_scale - (w%args.downsample_scale + (w%args.downsample_scale==0)*args.downsample_scale))
115 | dh = int(args.downsample_scale - (h%args.downsample_scale + (h%args.downsample_scale==0)*args.downsample_scale))
116 |
117 | top_pad = dh
118 | left_pad = dw
119 | imgL = np.lib.pad(imgL, ((0, 0), (0, 0), (top_pad, 0), (0, left_pad)), mode='constant', constant_values=0)
120 | imgR = np.lib.pad(imgR, ((0, 0), (0, 0), (top_pad, 0), (0, left_pad)), mode='constant', constant_values=0)
121 |
122 | disparity = test(imgL, imgR)
123 | disparity = disparity[0, top_pad:, :-left_pad].data.cpu().numpy()
124 | skimage.io.imsave(os.path.join(args.save_dir, left_image_path.split('/')
125 | [-1]), (disparity * 256).astype('uint16'))
126 |
127 | logging.info("Disparity for {} generated at: {}".format(left_image_path, os.path.join(args.save_dir,
128 | left_image_path.split('/')[-1])))
129 |
130 |
131 | if __name__ == '__main__':
132 | main()
133 |
--------------------------------------------------------------------------------
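The padding arithmetic in main() above makes the input height and width multiples of args.downsample_scale (padding the top of the height axis and the far end of the width axis) and then crops the prediction back. A small sketch of that computation, assuming a downsample scale of 64 and a typical KITTI image size:

```python
# Sketch of the padding computation used in main() above (assumed scale of 64).
def required_pad(size, downsample_scale=64):
    # Amount to add so that `size` becomes a multiple of `downsample_scale`
    # (zero when it already is).
    remainder = size % downsample_scale
    return 0 if remainder == 0 else downsample_scale - remainder

h, w = 375, 1242                          # a typical KITTI image size
top_pad, side_pad = required_pad(h), required_pad(w)
print(top_pad, side_pad)                  # 9 38  -> padded size 384 x 1280
assert (h + top_pad) % 64 == 0 and (w + side_pad) % 64 == 0
```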
/deeppruner/train_sceneflow.py:
--------------------------------------------------------------------------------
1 | # ---------------------------------------------------------------------------
2 | # DeepPruner: Learning Efficient Stereo Matching via Differentiable PatchMatch
3 | #
4 | # Copyright (c) 2019 Uber Technologies, Inc.
5 | #
6 | # Licensed under the Uber Non-Commercial License (the "License");
7 | # you may not use this file except in compliance with the License.
8 | # You may obtain a copy of the License at the root directory of this project.
9 | #
10 | # See the License for the specific language governing permissions and
11 | # limitations under the License.
12 | #
13 | # Written by Shivam Duggal
14 | # ---------------------------------------------------------------------------
15 |
16 | from __future__ import print_function
17 | import argparse
18 | import os
19 | import random
20 | import torch
21 | import torch.nn as nn
22 | import torch.backends.cudnn as cudnn
23 | import torch.optim as optim
24 | import torch.utils.data
25 | from torch.autograd import Variable
26 | import numpy as np
27 | from dataloader import sceneflow_collector as lt
28 | from dataloader import sceneflow_loader as DA
29 | from models.deeppruner import DeepPruner
30 | from loss_evaluation import loss_evaluation
31 | from tensorboardX import SummaryWriter
32 | import skimage
33 | import time
34 | import logging
35 | from models.config import config as config_args
36 | from setup_logging import setup_logging
37 |
38 | parser = argparse.ArgumentParser(description='DeepPruner')
39 | parser.add_argument('--datapath_monkaa', default='/',
40 | help='datapath for sceneflow monkaa dataset')
41 | parser.add_argument('--datapath_flying', default='/',
42 | help='datapath for sceneflow flying dataset')
43 | parser.add_argument('--datapath_driving', default='/',
44 | help='datapath for sceneflow driving dataset')
45 | parser.add_argument('--epochs', type=int, default=100,
46 | help='number of epochs to train')
47 | parser.add_argument('--loadmodel', default=None,
48 | help='load model')
49 | parser.add_argument('--save_dir', default='./',
50 | help='save directory')
51 | parser.add_argument('--savemodel', default='./',
52 | help='save model')
53 | parser.add_argument('--logging_filename', default='./train_sceneflow.log',
54 |                     help='filename for logs')
55 | parser.add_argument('--no-cuda', action='store_true', default=False,
56 |                     help='disables CUDA training')
57 | parser.add_argument('--seed', type=int, default=1, metavar='S',
58 | help='random seed (default: 1)')
59 |
60 | args = parser.parse_args()
61 | args.cuda = not args.no_cuda and torch.cuda.is_available()
62 |
63 | torch.manual_seed(args.seed)
64 | if args.cuda:
65 | torch.manual_seed(args.seed)
66 | np.random.seed(args.seed)
67 | random.seed(args.seed)
68 | torch.cuda.manual_seed(args.seed)
69 | torch.backends.cudnn.deterministic = True
70 |
71 |
72 | args.cost_aggregator_scale = config_args.cost_aggregator_scale
73 | args.maxdisp = config_args.max_disp
74 |
75 | setup_logging(args.logging_filename)
76 |
77 |
78 | all_left_img, all_right_img, all_left_disp, test_left_img, test_right_img, test_left_disp = lt.dataloader(args.datapath_monkaa,
79 | args.datapath_flying, args.datapath_driving)
80 |
81 |
82 | TrainImgLoader = torch.utils.data.DataLoader(
83 | DA.SceneflowLoader(all_left_img, all_right_img, all_left_disp, args.cost_aggregator_scale*8.0, True),
84 | batch_size=1, shuffle=True, num_workers=8, drop_last=False)
85 |
86 | TestImgLoader = torch.utils.data.DataLoader(
87 | DA.SceneflowLoader(test_left_img, test_right_img, test_left_disp, args.cost_aggregator_scale*8.0, False),
88 | batch_size=1, shuffle=False, num_workers=4, drop_last=False)
89 |
90 |
91 | model = DeepPruner()
92 | writer = SummaryWriter()
93 |
94 | if args.cuda:
95 | model = nn.DataParallel(model)
96 | model.cuda()
97 |
98 |
99 | if args.loadmodel is not None:
100 | state_dict = torch.load(args.loadmodel)
101 | model.load_state_dict(state_dict['state_dict'], strict=True)
102 |
103 | logging.info('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()])))
104 |
105 | optimizer = optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999))
106 |
107 |
108 | def train(imgL, imgR, disp_L, iteration):
109 | model.train()
110 | imgL = Variable(torch.FloatTensor(imgL))
111 | imgR = Variable(torch.FloatTensor(imgR))
112 | disp_L = Variable(torch.FloatTensor(disp_L))
113 |
114 | if args.cuda:
115 | imgL, imgR, disp_true = imgL.cuda(), imgR.cuda(), disp_L.cuda()
116 |
117 | mask = disp_true < args.maxdisp
118 | mask.detach_()
119 |
120 | optimizer.zero_grad()
121 | result = model(imgL, imgR)
122 |
123 | loss, _ = loss_evaluation(result, disp_true, mask, args.cost_aggregator_scale)
124 |
125 | loss.backward()
126 | optimizer.step()
127 |
128 | return loss.item()
129 |
130 |
131 | def test(imgL, imgR, disp_L, iteration, pad_w, pad_h):
132 |
133 | model.eval()
134 | with torch.no_grad():
135 | imgL = Variable(torch.FloatTensor(imgL))
136 | imgR = Variable(torch.FloatTensor(imgR))
137 | disp_L = Variable(torch.FloatTensor(disp_L))
138 |
139 | if args.cuda:
140 | imgL, imgR, disp_true = imgL.cuda(), imgR.cuda(), disp_L.cuda()
141 |
142 | mask = disp_true < args.maxdisp
143 | mask.detach_()
144 |
145 | if len(disp_true[mask]) == 0:
146 |             logging.info("invalid GT disparity...")
147 | return 0, 0
148 |
149 | optimizer.zero_grad()
150 |
151 | result = model(imgL, imgR)
152 | output = []
153 | for ind in range(len(result)):
154 | output.append(result[ind][:, pad_h:, pad_w:])
155 | result = output
156 |
157 | loss, output_disparity = loss_evaluation(result, disp_true, mask, args.cost_aggregator_scale)
158 | epe_loss = torch.mean(torch.abs(output_disparity[mask] - disp_true[mask]))
159 |
160 | return loss.item(), epe_loss.item()
161 |
162 |
163 | def adjust_learning_rate(optimizer, epoch):
164 | if epoch <= 20:
165 | lr = 0.001
166 | elif epoch <= 40:
167 | lr = 0.0007
168 | elif epoch <= 60:
169 | lr = 0.0003
170 | else:
171 | lr = 0.0001
172 | for param_group in optimizer.param_groups:
173 | param_group['lr'] = lr
174 |
175 |
176 | def main():
177 | for epoch in range(0, args.epochs):
178 | total_train_loss = 0
179 | total_test_loss = 0
180 | total_epe_loss = 0
181 | adjust_learning_rate(optimizer, epoch)
182 |
183 | if epoch % 1 == 0 and epoch != 0:
184 | logging.info("testing...")
185 | for batch_idx, (imgL, imgR, disp_L, pad_w, pad_h) in enumerate(TestImgLoader):
186 | start_time = time.time()
187 | test_loss, epe_loss = test(imgL, imgR, disp_L, batch_idx, int(pad_w[0].item()), int(pad_h[0].item()))
188 | total_test_loss += test_loss
189 | total_epe_loss += epe_loss
190 |
191 |                 logging.info('Iter %d EPE in val = %.3f, time = %.2f \n' %
192 | (batch_idx, epe_loss, time.time() - start_time))
193 |
194 | writer.add_scalar("val-loss-iter", test_loss, epoch * 4370 + batch_idx)
195 | writer.add_scalar("val-epe-loss-iter", epe_loss, epoch * 4370 + batch_idx)
196 |
197 | logging.info('epoch %d total test loss = %.3f' % (epoch, total_test_loss / len(TestImgLoader)))
198 | writer.add_scalar("val-loss", total_test_loss / len(TestImgLoader), epoch)
199 | logging.info('epoch %d total epe loss = %.3f' % (epoch, total_epe_loss / len(TestImgLoader)))
200 | writer.add_scalar("epe-loss", total_epe_loss / len(TestImgLoader), epoch)
201 |
202 | for batch_idx, (imgL_crop, imgR_crop, disp_crop_L) in enumerate(TrainImgLoader):
203 | start_time = time.time()
204 | loss = train(imgL_crop, imgR_crop, disp_crop_L, batch_idx)
205 | total_train_loss += loss
206 |
207 | writer.add_scalar("loss-iter", loss, batch_idx + 35454 * epoch)
208 | logging.info('Iter %d training loss = %.3f , time = %.2f \n' % (batch_idx, loss, time.time() - start_time))
209 |
210 | logging.info('epoch %d total training loss = %.3f' % (epoch, total_train_loss / len(TrainImgLoader)))
211 | writer.add_scalar("loss", total_train_loss / len(TrainImgLoader), epoch)
212 |
213 | # SAVE
214 | if epoch % 1 == 0:
215 | savefilename = args.savemodel + 'finetune_' + str(epoch) + '.tar'
216 | torch.save({
217 | 'epoch': epoch,
218 | 'state_dict': model.state_dict(),
219 | 'train_loss': total_train_loss,
220 | 'test_loss': total_test_loss,
221 | }, savefilename)
222 |
223 |
224 | if __name__ == '__main__':
225 | main()
226 |
--------------------------------------------------------------------------------
/readme_images/CRP.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/CRP.png
--------------------------------------------------------------------------------
/readme_images/DPM.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM.png
--------------------------------------------------------------------------------
/readme_images/DPM_filters.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_filters.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/DPM/Never back down 1032.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/DPM/Never back down 1032.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/DPM/Never back down 1979.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/DPM/Never back down 1979.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/DPM/Never back down 2046.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/DPM/Never back down 2046.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/DPM/cheetah1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/DPM/cheetah1.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/DPM/face1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/DPM/face1.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/ImageA/Never back down 1032.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageA/Never back down 1032.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/ImageA/Never back down 1979.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageA/Never back down 1979.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/ImageA/Never back down 2046.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageA/Never back down 2046.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/ImageA/cheetah1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageA/cheetah1.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/ImageA/face1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageA/face1.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/ImageB/Never back down 1213.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageB/Never back down 1213.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/ImageB/Never back down 2040.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageB/Never back down 2040.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/ImageB/Never back down 2066.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageB/Never back down 2066.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/ImageB/cheetah2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageB/cheetah2.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/ImageB/face2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/ImageB/face2.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/PM/Never back down 1032.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/PM/Never back down 1032.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/PM/Never back down 1979.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/PM/Never back down 1979.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/PM/Never back down 2046.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/PM/Never back down 2046.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/PM/cheetah1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/PM/cheetah1.png
--------------------------------------------------------------------------------
/readme_images/DPM_reconstruction/PM/face.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DPM_reconstruction/PM/face.png
--------------------------------------------------------------------------------
/readme_images/DeepPruner.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/DeepPruner.png
--------------------------------------------------------------------------------
/readme_images/KITTI_test_set.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/KITTI_test_set.png
--------------------------------------------------------------------------------
/readme_images/kitti_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/kitti_results.png
--------------------------------------------------------------------------------
/readme_images/original_patch_match.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/original_patch_match.png
--------------------------------------------------------------------------------
/readme_images/original_patch_match_steps.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/original_patch_match_steps.png
--------------------------------------------------------------------------------
/readme_images/rob.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/rob.png
--------------------------------------------------------------------------------
/readme_images/rob_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/rob_results.png
--------------------------------------------------------------------------------
/readme_images/sceneflow.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/sceneflow.png
--------------------------------------------------------------------------------
/readme_images/sceneflow_results.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/sceneflow_results.png
--------------------------------------------------------------------------------
/readme_images/softmax.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/softmax.png
--------------------------------------------------------------------------------
/readme_images/uncertainty_vis.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/uber-research/DeepPruner/40b188cf954577e21d5068db2be2bedc6b0e8781/readme_images/uncertainty_vis.png
--------------------------------------------------------------------------------