├── .gitignore
├── README.md
└── modeling
    ├── __init__.py
    ├── backbone
    │   ├── __init__.py
    │   └── psmnet_backbone.py
    ├── conf_meausre
    │   ├── __init__.py
    │   └── conf_net.py
    ├── cost_aggregation
    │   ├── __init__.py
    │   ├── acfnet_cost.py
    │   └── utils
    │       ├── __init__.py
    │       └── hourglass.py
    ├── cost_computation
    │   ├── __init__.py
    │   └── cat_fms.py
    ├── disp_prediction
    │   ├── __init__.py
    │   ├── faster_soft_argmin.py
    │   └── soft_argmin.py
    ├── layers
    │   ├── __init__.py
    │   └── basic_layers.py
    ├── loss
    │   ├── __init__.py
    │   ├── conf_nll_loss.py
    │   ├── smooth_l1_loss.py
    │   ├── stereo_focal_loss.py
    │   └── utils
    │       ├── __init__.py
    │       └── disp2prob.py
    └── models
        ├── AcfNet.py
        └── __init__.py
/.gitignore:
--------------------------------------------------------------------------------
1 | # Byte-compiled / optimized / DLL files
2 | __pycache__/
3 | *.py[cod]
4 | *$py.class
5 |
6 | # C extensions
7 | *.so
8 |
9 | # Distribution / packaging
10 | .Python
11 | build/
12 | develop-eggs/
13 | dist/
14 | downloads/
15 | eggs/
16 | .eggs/
17 | lib/
18 | lib64/
19 | parts/
20 | sdist/
21 | var/
22 | wheels/
23 | *.egg-info/
24 | .installed.cfg
25 | *.egg
26 | MANIFEST
27 |
28 | # PyInstaller
29 | # Usually these files are written by a python script from a template
30 | # before PyInstaller builds the exe, so as to inject date/other infos into it.
31 | *.manifest
32 | *.spec
33 |
34 | # Installer logs
35 | pip-log.txt
36 | pip-delete-this-directory.txt
37 |
38 | # Unit test / coverage reports
39 | htmlcov/
40 | .tox/
41 | .coverage
42 | .coverage.*
43 | .cache
44 | nosetests.xml
45 | coverage.xml
46 | *.cover
47 | .hypothesis/
48 | .pytest_cache/
49 |
50 | # Translations
51 | *.mo
52 | *.pot
53 |
54 | # Django stuff:
55 | *.log
56 | local_settings.py
57 | db.sqlite3
58 |
59 | # Flask stuff:
60 | instance/
61 | .webassets-cache
62 |
63 | # Scrapy stuff:
64 | .scrapy
65 |
66 | # Sphinx documentation
67 | docs/_build/
68 |
69 | # PyBuilder
70 | target/
71 |
72 | # Jupyter Notebook
73 | .ipynb_checkpoints
74 |
75 | # pyenv
76 | .python-version
77 |
78 | # celery beat schedule file
79 | celerybeat-schedule
80 |
81 | # SageMath parsed files
82 | *.sage.py
83 |
84 | # Environments
85 | .env
86 | .venv
87 | env/
88 | venv/
89 | ENV/
90 | env.bak/
91 | venv.bak/
92 |
93 | # Spyder project settings
94 | .spyderproject
95 | .spyproject
96 |
97 | # Rope project settings
98 | .ropeproject
99 |
100 | # mkdocs documentation
101 | /site
102 |
103 | # mypy
104 | .mypy_cache/
105 |
106 | # editor settings
107 | .vscode
108 | .idea
109 |
110 | # custom
111 | *.pkl
112 | *.pkl.json
113 | *.log.json
114 |
115 | .DS_Store
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # AcfNet
2 | This repository contains the code (in PyTorch) for "[Adaptive Unimodal Cost Volume Filtering for Deep Stereo Matching](https://arxiv.org/abs/1909.03751)", accepted at AAAI 2020.
3 |
4 | ## Notes
5 |
6 | * We have provided all modules used in our AcfNet; they can be directly integrated into our recently released [DenseMatchingBenchmark](https://github.com/DeepMotionAIResearch/DenseMatchingBenchmark) for training. A minimal usage sketch is given below.
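
## Usage

The snippet below is a minimal, illustrative sketch of how the released modules compose into the AcfNet pipeline (feature extraction -> cost computation -> cost aggregation -> disparity regression), run from the repository root. It performs a forward pass with random weights on a random 256x512 image pair; the input size and variable names are ours, chosen for illustration (the height and width must be divisible by the 1/4 feature stride and large enough for the backbone's pooling branches), not part of a released API.

```python
import torch

from modeling.backbone.psmnet_backbone import PsmBb
from modeling.cost_computation import cat_fms
from modeling.cost_aggregation.acfnet_cost import AcfCost
from modeling.disp_prediction import faster_soft_argmin

max_disp = 192

backbone = PsmBb(in_planes=3, batchNorm=True)
aggregation = AcfCost(max_disp=max_disp, in_planes=64, batchNorm=True)
predictor = faster_soft_argmin(max_disp)

left = torch.rand(1, 3, 256, 512)   # (B, C, H, W) left image
right = torch.rand(1, 3, 256, 512)  # (B, C, H, W) right image

with torch.no_grad():
    # 1/4-resolution feature maps, 32 channels each
    l_fms, r_fms = backbone(left, right)
    # concatenation-based cost volume: (B, 64, max_disp//4, H//4, W//4)
    raw_cost = cat_fms(l_fms, r_fms, max_disp // 4)
    # three regularized cost volumes, each upsampled to (B, max_disp, H, W)
    costs = aggregation(raw_cost)
    # soft-argmin disparity regression on the final volume: (B, 1, H, W)
    disp = predictor(costs[0])
```

For training, `StereoFocalLoss`, `DispSmoothL1Loss` and `ConfidenceNllLoss` from `modeling/loss` are combined as in `modeling/models/AcfNet.py`.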
7 |
8 |
9 |
10 | ## Requirements:
11 | - PyTorch 1.1+, Python 3.5+, CUDA 10.0+
12 |
13 | ## Reference:
14 |
15 | If you find the code useful, please cite our paper:
16 |
17 |     @article{zhang2019adaptive,
18 |       title={Adaptive Unimodal Cost Volume Filtering for Deep Stereo Matching},
19 |       author={Zhang, Youmin and Chen, Yimin and Bai, Xiao and Zhou, Jun and Yu, Kun and Li, Zhiwei and Yang, Kuiyuan},
20 |       journal={arXiv preprint arXiv:1909.03751},
21 |       year={2019}
22 |     }
23 |
--------------------------------------------------------------------------------
/modeling/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/youmi-zym/AcfNet/82798abd27636fe5e28fab0a8ed6460c9ad9c2b1/modeling/__init__.py
--------------------------------------------------------------------------------
/modeling/backbone/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/youmi-zym/AcfNet/82798abd27636fe5e28fab0a8ed6460c9ad9c2b1/modeling/backbone/__init__.py
--------------------------------------------------------------------------------
/modeling/backbone/psmnet_backbone.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ..layers.basic_layers import conv_bn, conv_bn_relu, BasicBlock
6 |
7 |
8 | class PsmBb(nn.Module):
9 |     """
10 |     Backbone proposed in PSMNet.
11 |     Args:
12 |         in_planes, (int): the number of input channels
13 |         batchNorm, (bool): whether to use batch normalization layers, default True
14 |     Inputs:
15 |         l_img, (torch.Tensor): left image
16 |         r_img, (torch.Tensor): right image
17 |     Outputs:
18 |         l_fms, (torch.Tensor): left image feature maps
19 |         r_fms, (torch.Tensor): right image feature maps
20 |     """
21 |     def __init__(self, in_planes=3, batchNorm=True):
22 |         super(PsmBb, self).__init__()
23 |         self.in_planes = in_planes
24 |         self.batchNorm = batchNorm
25 |
26 |         self.firstconv = nn.Sequential(
27 |             conv_bn_relu(self.batchNorm, self.in_planes, 32, 3, 2, 1, 1, bias=False),
28 |             conv_bn_relu(self.batchNorm, 32, 32, 3, 1, 1, 1, bias=False),
29 |             conv_bn_relu(self.batchNorm, 32, 32, 3, 1, 1, 1, bias=False),
30 |         )
31 |
32 |         # For building Basic Block
33 |         self.in_planes = 32
34 |
35 |         self.layer1 = self._make_layer(self.batchNorm, BasicBlock, 32, 3, 1, 1, 1)
36 |         self.layer2 = self._make_layer(self.batchNorm, BasicBlock, 64, 16, 2, 1, 1)
37 |         self.layer3 = self._make_layer(self.batchNorm, BasicBlock, 128, 3, 1, 1, 1)
38 |         self.layer4 = self._make_layer(self.batchNorm, BasicBlock, 128, 3, 1, 2, 2)
39 |
40 |         self.branch1 = nn.Sequential(
41 |             nn.AvgPool2d((64, 64), stride=(64, 64)),
42 |             conv_bn_relu(self.batchNorm, 128, 32, 1, 1, 0, 1, bias=False),
43 |         )
44 |         self.branch2 = nn.Sequential(
45 |             nn.AvgPool2d((32, 32), stride=(32, 32)),
46 |             conv_bn_relu(self.batchNorm, 128, 32, 1, 1, 0, 1, bias=False),
47 |         )
48 |         self.branch3 = nn.Sequential(
49 |             nn.AvgPool2d((16, 16), stride=(16, 16)),
50 |             conv_bn_relu(self.batchNorm, 128, 32, 1, 1, 0, 1, bias=False),
51 |         )
52 |         self.branch4 = nn.Sequential(
53 |             nn.AvgPool2d((8, 8), stride=(8, 8)),
54 |             conv_bn_relu(self.batchNorm, 128, 32, 1, 1, 0, 1, bias=False),
55 |         )
56 |         self.lastconv = nn.Sequential(
57 |             conv_bn_relu(self.batchNorm, 320, 128, 3, 1, 1, 1, bias=False),
58 |             nn.Conv2d(128, 32, kernel_size=1, padding=0, stride=1, dilation=1, bias=False)
59 |         )
60 |
61 |     def _make_layer(self, batchNorm, block,
out_planes, blocks, stride, padding, dilation): 62 | downsample = None 63 | if stride != 1 or self.in_planes != out_planes * block.expansion: 64 | downsample = conv_bn( 65 | batchNorm, self.in_planes, out_planes * block.expansion, 66 | kernel_size=1, stride=stride, padding=0, dilation=1 67 | ) 68 | 69 | layers = [] 70 | layers.append( 71 | block(batchNorm, self.in_planes, out_planes, stride, downsample, padding, dilation) 72 | ) 73 | self.in_planes = out_planes * block.expansion 74 | for i in range(1, blocks): 75 | layers.append( 76 | block(batchNorm, self.in_planes, out_planes, 1, None, padding, dilation) 77 | ) 78 | 79 | return nn.Sequential(*layers) 80 | 81 | def backbone(self, x): 82 | output_2_0 = self.firstconv(x) 83 | output_2_1 = self.layer1(output_2_0) 84 | output_4_0 = self.layer2(output_2_1) 85 | output_4_1 = self.layer3(output_4_0) 86 | output_8 = self.layer4(output_4_1) 87 | 88 | output_branch1 = self.branch1(output_8) 89 | output_branch1 = F.interpolate( 90 | output_branch1, (output_8.size()[2], output_8.size()[3]), 91 | mode='bilinear', align_corners=True 92 | ) 93 | 94 | output_branch2 = self.branch2(output_8) 95 | output_branch2 = F.interpolate( 96 | output_branch2, (output_8.size()[2], output_8.size()[3]), 97 | mode='bilinear', align_corners=True 98 | ) 99 | 100 | output_branch3 = self.branch3(output_8) 101 | output_branch3 = F.interpolate( 102 | output_branch3, (output_8.size()[2], output_8.size()[3]), 103 | mode='bilinear', align_corners=True 104 | ) 105 | 106 | output_branch4 = self.branch4(output_8) 107 | output_branch4 = F.interpolate( 108 | output_branch4, (output_8.size()[2], output_8.size()[3]), 109 | mode='bilinear', align_corners=True 110 | ) 111 | 112 | output_feature = torch.cat( 113 | (output_4_0, output_8, output_branch4, output_branch3, output_branch2, output_branch1), 1) 114 | output_feature = self.lastconv(output_feature) 115 | 116 | return output_feature 117 | 118 | def forward(self, *input): 119 | if len(input) != 2: 120 | raise ValueError( 121 | 'expected input length 2 (got {} length input)'.format(len(input)) 122 | ) 123 | l_img, r_img = input 124 | 125 | l_fms = self.backbone(l_img) 126 | r_fms = self.backbone(r_img) 127 | 128 | return l_fms, r_fms 129 | -------------------------------------------------------------------------------- /modeling/conf_meausre/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/youmi-zym/AcfNet/82798abd27636fe5e28fab0a8ed6460c9ad9c2b1/modeling/conf_meausre/__init__.py -------------------------------------------------------------------------------- /modeling/conf_meausre/conf_net.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from ..layers.basic_layers import conv_bn_relu 6 | 7 | class ConfidenceEstimation(nn.Module): 8 | """ 9 | Args: 10 | in_planes, (int): usually cost volume used to calculate confidence map with $in_planes$ in Channel Dimension 11 | batchNorm, (bool): whether use batch normalization layer, default True 12 | Inputs: 13 | cost, (Tensor): cost volume in (BatchSize, in_planes, Height, Width) layout 14 | Outputs: 15 | confCost, (Tensor): in (BatchSize, 1, Height, Width) layout 16 | """ 17 | 18 | def __init__(self, in_planes, batchNorm=True): 19 | super(ConfidenceEstimation, self).__init__() 20 | 21 | self.in_planes = in_planes 22 | self.sec_in_planes = int(self.in_planes//3) 23 | self.sec_in_planes = 
self.sec_in_planes if self.sec_in_planes > 0 else 1
24 |
25 |         self.conf_net = nn.Sequential(conv_bn_relu(batchNorm, self.in_planes, self.sec_in_planes, 3, 1, 1, bias=False),
26 |                                       nn.Conv2d(self.sec_in_planes, 1, 1, 1, 0, bias=False))
27 |
28 |     def forward(self, cost):
29 |         assert cost.shape[1] == self.in_planes
30 |
31 |         confCost = self.conf_net(cost)
32 |
33 |         return confCost
--------------------------------------------------------------------------------
/modeling/cost_aggregation/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/youmi-zym/AcfNet/82798abd27636fe5e28fab0a8ed6460c9ad9c2b1/modeling/cost_aggregation/__init__.py
--------------------------------------------------------------------------------
/modeling/cost_aggregation/acfnet_cost.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ..layers.basic_layers import conv3d_bn, conv3d_bn_relu
6 | from .utils.hourglass import Hourglass
7 |
8 | class AcfCost(nn.Module):
9 |     """
10 |     It is the same as the PSMNet cost aggregation network.
11 |     Args:
12 |         max_disp, (int): max disparity
13 |         in_planes, (int): the number of channels of the raw cost volume
14 |         batchNorm, (bool): whether to use batch normalization layers, default True
15 |     Outputs:
16 |         cost_volume, (list of Tensor): cost volumes in (BatchSize, MaxDisparity, Height, Width) layout
17 |     """
18 |
19 |     def __init__(self, max_disp, in_planes=64, batchNorm=True):
20 |         super(AcfCost, self).__init__()
21 |         self.max_disp = max_disp
22 |         self.in_planes = in_planes
23 |         self.batchNorm = batchNorm
24 |
25 |         self.dres0 = nn.Sequential(
26 |             conv3d_bn_relu(self.batchNorm, self.in_planes, 32, 3, 1, 1),
27 |             conv3d_bn_relu(self.batchNorm, 32, 32, 3, 1, 1),
28 |         )
29 |         self.dres1 = nn.Sequential(
30 |             conv3d_bn_relu(self.batchNorm, 32, 32, 3, 1, 1),
31 |             conv3d_bn(self.batchNorm, 32, 32, 3, 1, 1)
32 |         )
33 |         self.dres2 = Hourglass(in_planes=32, batchNorm=self.batchNorm)
34 |         self.dres3 = Hourglass(in_planes=32, batchNorm=self.batchNorm)
35 |         self.dres4 = Hourglass(in_planes=32, batchNorm=self.batchNorm)
36 |
37 |         self.classif1 = nn.Sequential(
38 |             conv3d_bn_relu(self.batchNorm, 32, 32, 3, 1, 1),
39 |             nn.Conv3d(32, 1, kernel_size=3, stride=1, padding=1, bias=False),
40 |         )
41 |         self.classif2 = nn.Sequential(
42 |             conv3d_bn_relu(self.batchNorm, 32, 32, 3, 1, 1),
43 |             nn.Conv3d(32, 1, kernel_size=3, stride=1, padding=1, bias=False),
44 |         )
45 |         self.classif3 = nn.Sequential(
46 |             conv3d_bn_relu(self.batchNorm, 32, 32, 3, 1, 1),
47 |             nn.Conv3d(32, 1, kernel_size=3, stride=1, padding=1, bias=False)
48 |         )
49 |
50 |         self.deconv1 = nn.ConvTranspose3d(1, 1, 8, 4, 2, bias=False)
51 |         self.deconv2 = nn.ConvTranspose3d(1, 1, 8, 4, 2, bias=False)
52 |         self.deconv3 = nn.ConvTranspose3d(1, 1, 8, 4, 2, bias=False)
53 |
54 |     def forward(self, raw_cost):
55 |         B, C, D, H, W = raw_cost.shape
56 |         # concat_fms: (BatchSize, Channels*2, MaxDisparity/4, Height/4, Width/4)
57 |         cost0 = self.dres0(raw_cost)
58 |         cost0 = self.dres1(cost0) + cost0
59 |
60 |         out1, pre1, post1 = self.dres2(cost0, None, None)
61 |         out1 = out1 + cost0
62 |
63 |         out2, pre2, post2 = self.dres3(out1, pre1, post1)
64 |         out2 = out2 + cost0
65 |
66 |         out3, pre3, post3 = self.dres4(out2, pre2, post2)
67 |         out3 = out3 + cost0
68 |
69 |         cost1 = self.classif1(out1)
70 |         cost2 = self.classif2(out2) + cost1
71 |         cost3 = self.classif3(out3) + cost2
72 |
73 |         # (BatchSize, 1, MaxDisparity, Height, Width)
74 |         full_h, full_w = H * 4, W * 4
75 |
76 |         cost1 = self.deconv1(cost1, [self.max_disp, full_h, full_w])
77 |         cost2 = self.deconv2(cost2, [self.max_disp, full_h, full_w])
78 |         cost3 = self.deconv3(cost3, [self.max_disp, full_h, full_w])
79 |
80 |         # (BatchSize, MaxDisparity, Height, Width)
81 |         cost1 = torch.squeeze(cost1, 1)
82 |         cost2 = torch.squeeze(cost2, 1)
83 |         cost3 = torch.squeeze(cost3, 1)
84 |
85 |         return [cost3, cost2, cost1]
--------------------------------------------------------------------------------
/modeling/cost_aggregation/utils/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/youmi-zym/AcfNet/82798abd27636fe5e28fab0a8ed6460c9ad9c2b1/modeling/cost_aggregation/utils/__init__.py
--------------------------------------------------------------------------------
/modeling/cost_aggregation/utils/hourglass.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from ...layers.basic_layers import conv3d_bn, conv3d_bn_relu, conv_bn_relu, deconv3d_bn
6 |
7 |
8 | class Hourglass(nn.Module):
9 |     def __init__(self, in_planes, batchNorm=True):
10 |         super(Hourglass, self).__init__()
11 |         self.batchNorm = batchNorm
12 |
13 |         self.conv1 = conv3d_bn_relu(
14 |             self.batchNorm, in_planes, in_planes * 2,
15 |             kernel_size=3, stride=2, padding=1, bias=False
16 |         )
17 |
18 |         self.conv2 = conv3d_bn(
19 |             self.batchNorm, in_planes * 2, in_planes * 2,
20 |             kernel_size=3, stride=1, padding=1, bias=False
21 |         )
22 |
23 |         self.conv3 = conv3d_bn_relu(
24 |             self.batchNorm, in_planes * 2, in_planes * 2,
25 |             kernel_size=3, stride=2, padding=1, bias=False
26 |         )
27 |         self.conv4 = conv3d_bn_relu(
28 |             self.batchNorm, in_planes * 2, in_planes * 2,
29 |             kernel_size=3, stride=1, padding=1, bias=False
30 |         )
31 |         self.conv5 = deconv3d_bn(
32 |             self.batchNorm, in_planes * 2, in_planes * 2,
33 |             kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
34 |         )
35 |         self.conv6 = deconv3d_bn(
36 |             self.batchNorm, in_planes * 2, in_planes,
37 |             kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
38 |         )
39 |
40 |     def forward(self, x, presqu, postsqu):
41 |         # in:1/4, out:1/8
42 |         out = self.conv1(x)
43 |         # in:1/8, out:1/8
44 |         pre = self.conv2(out)
45 |         if postsqu is not None:
46 |             pre = F.relu(pre + postsqu, inplace=True)
47 |         else:
48 |             pre = F.relu(pre, inplace=True)
49 |
50 |         # in:1/8, out:1/16
51 |         out = self.conv3(pre)
52 |         # in:1/16, out:1/16
53 |         out = self.conv4(out)
54 |
55 |         # in:1/16, out:1/8
56 |         if presqu is not None:
57 |             post = F.relu(self.conv5(out) + presqu, inplace=True)
58 |         else:
59 |             post = F.relu(self.conv5(out) + pre, inplace=True)
60 |
61 |         # in:1/8, out:1/4
62 |         out = self.conv6(post)
63 |
64 |         return out, pre, post
65 |
66 |
--------------------------------------------------------------------------------
/modeling/cost_computation/__init__.py:
--------------------------------------------------------------------------------
1 | from .cat_fms import cat_fms
--------------------------------------------------------------------------------
/modeling/cost_computation/cat_fms.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | def cat_fms(reference_fm, target_fm, max_disp, start_disp=0, dilation=1):
5 |     """
6 |     Based on the reference feature, shift the target feature within [start
disp, end disp] to form the cost volume 7 | Details please refer to GC-Net, generate disparity [start disp, end disp] 8 | Args: 9 | max_disp, (int): under the scale of feature used, often equals to (end disp - start disp + 1), the max searching range of disparity 10 | start_disp (int): the start searching disparity index, usually be 0 11 | dilation (int): the step between near disparity index 12 | 13 | Inputs: 14 | reference_fm, (Tensor): reference feature, i.e. left image feature, in [BatchSize, Channel, Height, Width] layout 15 | target_fm, (Tensor): target feature, i.e. right image feature, in [BatchSize, Channel, Height, Width] layout 16 | 17 | Output: 18 | concat_fm, (Tensor): the formed cost volume, in [BatchSize, Channel*2, disp_sample_number, Height, Width] layout 19 | """ 20 | end_disp = start_disp + max_disp - 1 21 | disp_sample_number = (max_disp+dilation-1)//dilation 22 | 23 | device = reference_fm.device 24 | N, C, H, W = reference_fm.shape 25 | concat_fm = torch.zeros(N, C * 2, disp_sample_number, H, W).to(device) 26 | 27 | # PSMNet cost-volume construction method 28 | idx = 0 29 | for i in range(start_disp, end_disp+1, dilation): 30 | if i > 0: 31 | concat_fm[:, :C, idx, :, i:] = reference_fm[:, :, :, i:] 32 | concat_fm[:, C:, idx, :, i:] = target_fm[:, :, :, :-i] 33 | elif i==0: 34 | concat_fm[:, :C, idx, :, :] = reference_fm 35 | concat_fm[:, C:, idx, :, :] = target_fm 36 | else: 37 | concat_fm[:, :C, idx, :, :i] = reference_fm[:, :, :, :i] 38 | concat_fm[:, C:, idx, :, :i] = target_fm[:, :, :, abs(i):] 39 | idx = idx + 1 40 | 41 | concat_fm = concat_fm.contiguous() 42 | return concat_fm 43 | -------------------------------------------------------------------------------- /modeling/disp_prediction/__init__.py: -------------------------------------------------------------------------------- 1 | from .soft_argmin import soft_argmin 2 | from .faster_soft_argmin import faster_soft_argmin 3 | 4 | __all__ = ['soft_argmin', 'faster_soft_argmin'] 5 | -------------------------------------------------------------------------------- /modeling/disp_prediction/faster_soft_argmin.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class faster_soft_argmin(nn.Module): 7 | """ 8 | A faster implementation of soft argmin. 
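    It precomputes the disparity indices as the frozen weight of a (disp_sample_number, 1, 1) nn.Conv3d, so the expectation over disparities runs as a single fixed convolution instead of rebuilding an index tensor on every forward pass.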
9 |     For details, refer to modeling.disp_prediction.soft_argmin.
10 |     Args:
11 |         max_disp, (int): under the scale of the feature used, often equal to (end disp - start disp + 1), the max searching range of disparity
12 |         start_disp (int): the start searching disparity index, usually 0
13 |         dilation (int): the step between nearby disparity indices
14 |         normalize (bool): whether to apply softmax on cost_volume, default True
15 |         temperature (float, int): a temperature factor multiplied with cost_volume,
16 |             details: https://bouthilx.wordpress.com/2013/04/21/a-soft-argmax/
17 |
18 |     Inputs:
19 |         cost_volume (Tensor): the matching cost after regularization, in [B, disp_sample_number, H, W] layout
20 |
21 |     Returns:
22 |         disp_map (Tensor): a disparity map regressed from the cost volume, in [B, 1, H, W] layout
23 |     """
24 |     def __init__(self, max_disp, start_disp=0, dilation=1):
25 |         super(faster_soft_argmin, self).__init__()
26 |         self.max_disp = max_disp
27 |         self.start_disp = start_disp
28 |         self.dilation = dilation
29 |         self.end_disp = start_disp + max_disp - 1
30 |         self.disp_sample_number = (max_disp + dilation - 1) // dilation
31 |
32 |         self.disp_regression = nn.Conv3d(1, 1, (self.disp_sample_number, 1, 1), 1, 0, bias=False)
33 |
34 |         # compute disparity index: (1, 1, disp_sample_number, 1, 1)
35 |         disp_index = torch.linspace(self.start_disp, self.end_disp, self.disp_sample_number).repeat(1, 1, 1, 1, 1).permute(0, 1, 4, 2, 3).contiguous()
36 |         self.disp_regression.weight.data = disp_index
37 |         self.disp_regression.weight.requires_grad = False
38 |
39 |     def forward(self, cost_volume, normalize=True, temperature=1.0):
40 |         cost_volume = cost_volume * temperature
41 |         if normalize:
42 |             prob_volume = F.softmax(cost_volume, dim=1)
43 |         else:
44 |             prob_volume = cost_volume
45 |
46 |         # [B, disp_sample_number, H, W] -> [B, 1, disp_sample_number, H, W]
47 |         prob_volume = prob_volume.unsqueeze(1)
48 |
49 |         disp_map = self.disp_regression(prob_volume)
50 |
51 |         # [B, 1, 1, H, W] -> [B, 1, H, W]
52 |         disp_map = disp_map.squeeze(1)
53 |
54 |         return disp_map
55 |
56 |
--------------------------------------------------------------------------------
/modeling/disp_prediction/soft_argmin.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 |
4 |
5 | def soft_argmin(cost_volume, max_disp, start_disp=0, dilation=1, normalize=True, temperature=1.0):
6 |     # type: (torch.Tensor, int, int, int, bool, [float, int]) -> torch.Tensor
7 |     r"""Implementation of soft argmin proposed by GC-Net.
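    Disparity is regressed as the expectation over candidates, disp = sum_d d * softmax(temperature * cost_volume)_d, with d running from start_disp to end_disp.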
8 |     Args:
9 |         max_disp, (int): under the scale of the feature used, often equal to (end disp - start disp + 1), the max searching range of disparity
10 |         start_disp (int): the start searching disparity index, usually 0
11 |         dilation (int): the step between nearby disparity indices
12 |         normalize (bool): whether to apply softmax on cost_volume, default True
13 |         temperature (float, int): a temperature factor multiplied with cost_volume,
14 |             details: https://bouthilx.wordpress.com/2013/04/21/a-soft-argmax/
15 |
16 |     Inputs:
17 |         cost_volume (Tensor): the matching cost after regularization, in [B, disp_sample_number, H, W] layout
18 |
19 |     Returns:
20 |         disp_map (Tensor): a disparity map regressed from the cost volume, in [B, 1, H, W] layout
21 |     """
22 |     if cost_volume.dim() != 4:
23 |         raise ValueError(
24 |             'expected 4D input (got {}D input)'.format(cost_volume.dim())
25 |         )
26 |     end_disp = start_disp + max_disp - 1
27 |     disp_sample_number = (max_disp + dilation - 1) // dilation
28 |
29 |     # grab cost volume shape
30 |     N, D, H, W = cost_volume.shape
31 |
32 |     assert disp_sample_number == D, \
33 |         "The number of disparity samples should be the same as the size of the cost volume's channel dimension!"
34 |
35 |     # generate disparity indexes
36 |     disp_index = torch.linspace(start_disp, end_disp, disp_sample_number).to(cost_volume.device)
37 |     disp_index = disp_index.repeat(N, H, W, 1).permute(0, 3, 1, 2).contiguous()
38 |
39 |     # compute probability volume
40 |     # prob_volume: (BatchSize, disp_sample_number, Height, Width)
41 |     cost_volume = cost_volume * temperature
42 |     if normalize:
43 |         prob_volume = F.softmax(cost_volume, dim=1)
44 |     else:
45 |         prob_volume = cost_volume
46 |
47 |     # compute disparity: (BatchSize, 1, Height, Width)
48 |     disp_map = torch.sum(prob_volume * disp_index, dim=1, keepdim=True)
49 |
50 |     return disp_map
51 |
--------------------------------------------------------------------------------
/modeling/layers/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/youmi-zym/AcfNet/82798abd27636fe5e28fab0a8ed6460c9ad9c2b1/modeling/layers/__init__.py
--------------------------------------------------------------------------------
/modeling/layers/basic_layers.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 | from torch.nn.modules.utils import _pair, _triple
5 |
6 | def consistent_padding_with_dilation(padding, dilation, dim=2):
7 |     assert dim == 2 or dim == 3, 'Convolution layer only supports 2D and 3D'
8 |     if dim == 2:
9 |         padding = _pair(padding)
10 |         dilation = _pair(dilation)
11 |     else:  # dim == 3
12 |         padding = _triple(padding)
13 |         dilation = _triple(dilation)
14 |
15 |     padding = list(padding)
16 |     for d in range(dim):
17 |         padding[d] = dilation[d] if dilation[d] > 1 else padding[d]
18 |     padding = tuple(padding)
19 |
20 |     return padding, dilation
21 |
22 |
23 | def conv_bn(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, bias=True):
24 |     padding, dilation = consistent_padding_with_dilation(padding, dilation, dim=2)
25 |     if batchNorm:
26 |         return nn.Sequential(
27 |             nn.Conv2d(
28 |                 in_planes, out_planes, kernel_size=kernel_size,
29 |                 stride=stride, padding=padding,
30 |                 dilation=dilation, bias=bias),
31 |             nn.BatchNorm2d(out_planes),
32 |         )
33 |     else:
34 |         return nn.Sequential(
35 |             nn.Conv2d(
36 |                 in_planes, out_planes, kernel_size=kernel_size,
37 |
stride=stride, padding=padding, 38 | dilation=dilation, bias=bias), 39 | ) 40 | 41 | def deconv_bn(batchNorm, in_planes, out_planes, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True): 42 | if batchNorm: 43 | return nn.Sequential( 44 | nn.ConvTranspose2d( 45 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, 46 | padding=padding, output_padding=output_padding, bias=bias), 47 | nn.BatchNorm2d(out_planes), 48 | ) 49 | else: 50 | return nn.Sequential( 51 | nn.ConvTranspose2d( 52 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, 53 | padding=padding, output_padding=output_padding, bias=bias 54 | ), 55 | ) 56 | 57 | def conv3d_bn(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, bias=True): 58 | padding, dilation = consistent_padding_with_dilation(padding, dilation, dim=3) 59 | if batchNorm: 60 | return nn.Sequential( 61 | nn.Conv3d( 62 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, 63 | padding=padding, dilation=dilation, bias=bias), 64 | nn.BatchNorm3d(out_planes), 65 | ) 66 | else: 67 | return nn.Sequential( 68 | nn.Conv3d( 69 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, 70 | padding=padding, dilation=dilation, bias=bias 71 | ), 72 | ) 73 | 74 | 75 | def deconv3d_bn(batchNorm, in_planes, out_planes, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True): 76 | if batchNorm: 77 | return nn.Sequential( 78 | nn.ConvTranspose3d( 79 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, 80 | padding=padding, output_padding=output_padding, bias=bias), 81 | nn.BatchNorm3d(out_planes), 82 | ) 83 | else: 84 | return nn.Sequential( 85 | nn.ConvTranspose3d( 86 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, 87 | padding=padding, output_padding=output_padding, bias=bias 88 | ), 89 | ) 90 | 91 | 92 | def conv_bn_relu(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, bias=True): 93 | padding, dilation = consistent_padding_with_dilation(padding, dilation, dim=2) 94 | if batchNorm: 95 | return nn.Sequential( 96 | nn.Conv2d( 97 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, 98 | padding=padding, dilation=dilation, bias=bias), 99 | nn.BatchNorm2d(out_planes), 100 | nn.ReLU(inplace=True), 101 | ) 102 | else: 103 | return nn.Sequential( 104 | nn.Conv2d( 105 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, 106 | padding=padding, dilation=dilation, bias=bias), 107 | nn.ReLU(inplace=True), 108 | ) 109 | 110 | def deconv_bn_relu(batchNorm, in_planes, out_planes, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True): 111 | if batchNorm: 112 | return nn.Sequential( 113 | nn.ConvTranspose2d( 114 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, 115 | padding=padding, output_padding=output_padding, bias=bias), 116 | nn.BatchNorm2d(out_planes), 117 | nn.ReLU(inplace=True), 118 | ) 119 | else: 120 | return nn.Sequential( 121 | nn.ConvTranspose2d( 122 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, 123 | padding=padding, output_padding=output_padding, bias=bias 124 | ), 125 | nn.ReLU(inplace=True), 126 | ) 127 | 128 | def conv3d_bn_relu(batchNorm, in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1, bias=True): 129 | padding, dilation = consistent_padding_with_dilation(padding, dilation, dim=3) 130 | if batchNorm: 131 | return nn.Sequential( 132 | nn.Conv3d( 133 | in_planes, out_planes, kernel_size=kernel_size, stride=stride, 134 | padding=padding, 
dilation=dilation, bias=bias),
135 |             nn.BatchNorm3d(out_planes),
136 |             nn.ReLU(inplace=True),
137 |         )
138 |     else:
139 |         return nn.Sequential(
140 |             nn.Conv3d(
141 |                 in_planes, out_planes, kernel_size=kernel_size, stride=stride,
142 |                 padding=padding, dilation=dilation, bias=bias
143 |             ),
144 |             nn.ReLU(inplace=True),
145 |         )
146 |
147 |
148 | def deconv3d_bn_relu(batchNorm, in_planes, out_planes, kernel_size=4, stride=2, padding=1, output_padding=0, bias=True):
149 |     if batchNorm:
150 |         return nn.Sequential(
151 |             nn.ConvTranspose3d(
152 |                 in_planes, out_planes, kernel_size=kernel_size, stride=stride,
153 |                 padding=padding, output_padding=output_padding, bias=bias),
154 |             nn.BatchNorm3d(out_planes),
155 |             nn.ReLU(inplace=True),
156 |         )
157 |     else:
158 |         return nn.Sequential(
159 |             nn.ConvTranspose3d(
160 |                 in_planes, out_planes, kernel_size=kernel_size, stride=stride,
161 |                 padding=padding, output_padding=output_padding, bias=bias
162 |             ),
163 |             nn.ReLU(inplace=True),
164 |         )
165 |
166 |
167 |
168 | class BasicBlock(nn.Module):
169 |     expansion = 1
170 |
171 |     def __init__(self, batchNorm, in_planes, out_planes, stride, downsample, padding, dilation):
172 |         super(BasicBlock, self).__init__()
173 |         self.conv1 = conv_bn_relu(
174 |             batchNorm=batchNorm, in_planes=in_planes, out_planes=out_planes,
175 |             kernel_size=3, stride=stride, padding=padding, dilation=dilation, bias=False
176 |         )
177 |         self.conv2 = conv_bn(
178 |             batchNorm=batchNorm, in_planes=out_planes, out_planes=out_planes,
179 |             kernel_size=3, stride=1, padding=padding, dilation=dilation, bias=False
180 |         )
181 |         self.downsample = downsample
182 |         self.stride = stride
183 |
184 |     def forward(self, x):
185 |         out = self.conv1(x)
186 |         out = self.conv2(out)
187 |
188 |         if self.downsample is not None:
189 |             x = self.downsample(x)
190 |         out += x
191 |
192 |         return out
--------------------------------------------------------------------------------
/modeling/loss/__init__.py:
--------------------------------------------------------------------------------
1 | from .conf_nll_loss import ConfidenceNllLoss
2 | from .smooth_l1_loss import DispSmoothL1Loss
3 | from .stereo_focal_loss import StereoFocalLoss
4 |
5 | __all__ = ['ConfidenceNllLoss', 'DispSmoothL1Loss', 'StereoFocalLoss']
--------------------------------------------------------------------------------
/modeling/loss/conf_nll_loss.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | class ConfidenceNllLoss(object):
6 |     """
7 |
8 |     Args:
9 |         weights (list of float or None): weight for each scale of estConf.
10 |         max_disp (int): the maximum disparity. default: 192
11 |         sparse (bool): whether the ground-truth disparity is sparse; for example, KITTI is sparse, but SceneFlow is not. default: False
12 |
13 |     Inputs:
14 |         estConf (Tensor or list of Tensor): the estimated confidence map, in (BatchSize, 1, Height, Width) layout.
15 |         gtDisp (Tensor): the ground truth disparity map, in (BatchSize, 1, Height, Width) layout.
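    The per-level loss is the masked mean of -logsigmoid(estConf): a negative log-likelihood that pushes the predicted confidence towards 1 at pixels whose ground-truth disparity lies in (0, max_disp).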
16 | 17 | Outputs: 18 | loss (dict), the loss of each level 19 | """ 20 | 21 | def __init__(self, max_disp, weights=None, sparse=False): 22 | self.max_disp = max_disp 23 | self.weights = weights 24 | self.sparse = sparse 25 | if sparse: 26 | # sparse disparity ==> max_pooling 27 | self.scale_func = F.adaptive_max_pool2d 28 | else: 29 | # dense disparity ==> avg_pooling 30 | self.scale_func = F.adaptive_avg_pool2d 31 | 32 | def loss_per_level(self, estConf, gtDisp): 33 | N, C, H, W = estConf.shape 34 | scaled_gtDisp = gtDisp 35 | scale = 1.0 36 | if gtDisp.shape[-2] != H or gtDisp.shape[-1] != W: 37 | # compute scale per level and scale gtDisp 38 | scale = gtDisp.shape[-1] / (W * 1.0) 39 | scaled_gtDisp = gtDisp / scale 40 | scaled_gtDisp = self.scale_func(scaled_gtDisp, (H, W)) 41 | 42 | # mask for valid disparity 43 | # gt zero and lt max disparity 44 | mask = (scaled_gtDisp > 0) & (scaled_gtDisp < (self.max_disp / scale)) 45 | mask = mask.detach_().type_as(gtDisp) 46 | # NLL loss 47 | loss = (-1.0 * F.logsigmoid(estConf) * mask).mean() 48 | 49 | return loss 50 | 51 | def __call__(self, estConf, gtDisp): 52 | if not isinstance(estConf, (list, tuple)): 53 | estConf = [estConf] 54 | 55 | if self.weights is None: 56 | self.weights = [1.0] * len(estConf) 57 | 58 | # compute loss for per level 59 | loss_all_level = [ 60 | self.loss_per_level(est_conf_per_lvl, gtDisp) 61 | for est_conf_per_lvl in estConf 62 | ] 63 | # re-weight loss per level 64 | weighted_loss_all_level = dict() 65 | for i, loss_per_level in enumerate(loss_all_level): 66 | name = "conf_loss_lvl{}".format(i) 67 | weighted_loss_all_level[name] = self.weights[i] * loss_per_level 68 | 69 | return weighted_loss_all_level 70 | 71 | def __repr__(self): 72 | repr_str = '{}\n'.format(self.__class__.__name__) 73 | repr_str += ' ' * 4 + 'Max Disparity: {}\n'.format(self.max_disp) 74 | repr_str += ' ' * 4 + 'Loss weight: {}\n'.format(self.weights) 75 | repr_str += ' ' * 4 + 'Disparity is sparse: {}\n'.format(self.sparse) 76 | 77 | return repr_str 78 | 79 | @property 80 | def name(self): 81 | return 'ConfidenceNLLLoss' -------------------------------------------------------------------------------- /modeling/loss/smooth_l1_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class DispSmoothL1Loss(object): 6 | """ 7 | 8 | Args: 9 | max_disp, (int): the max of Disparity. default: 192 10 | start_disp, (int): the start searching disparity index, usually be 0 11 | weights, (list of float or None): weight for each scale of estCost. 12 | sparse, (bool): whether the ground-truth disparity is sparse, for example, KITTI is sparse, but SceneFlow is not. default: False 13 | 14 | Inputs: 15 | estDisp, (Tensor or list of Tensor): the estimated disparity map, in (BatchSize, 1, Height, Width) layout. 16 | gtDisp, (Tensor): the ground truth disparity map, in (BatchSize, 1, Height, Width) layout. 
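    At valid pixels (gtDisp in (start_disp, max_disp)), the per-level loss is the standard smooth L1 penalty on x = estDisp - gtDisp: 0.5 * x^2 if |x| < 1, and |x| - 0.5 otherwise.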
17 | 18 | Outputs: 19 | loss, (dict), the loss of each level 20 | """ 21 | 22 | def __init__(self, max_disp, start_disp=0, weights=None, sparse=False): 23 | self.max_disp = max_disp 24 | self.weights = weights 25 | self.start_disp = start_disp 26 | self.sparse = sparse 27 | if sparse: 28 | # sparse disparity ==> max_pooling 29 | self.scale_func = F.adaptive_max_pool2d 30 | else: 31 | # dense disparity ==> avg_pooling 32 | self.scale_func = F.adaptive_avg_pool2d 33 | 34 | def loss_per_level(self, estDisp, gtDisp): 35 | N, C, H, W = estDisp.shape 36 | scaled_gtDisp = gtDisp 37 | scale = 1.0 38 | if gtDisp.shape[-2] != H or gtDisp.shape[-1] != W: 39 | # compute scale per level and scale gtDisp 40 | scale = gtDisp.shape[-1] / (W * 1.0) 41 | scaled_gtDisp = gtDisp / scale 42 | scaled_gtDisp = self.scale_func(scaled_gtDisp, (H, W)) 43 | 44 | # mask for valid disparity 45 | # (start disparity, max disparity / scale) 46 | # Attention: the invalid disparity of KITTI is set as 0, be sure to mask it out 47 | mask = (scaled_gtDisp > self.start_disp) & (scaled_gtDisp < (self.max_disp / scale)) 48 | if mask.sum() < 1.0: 49 | print('SmoothL1 loss: there is no point\'s disparity is in ({},{})!'.format(self.start_disp, self.max_disp / scale)) 50 | loss = (torch.abs(estDisp - scaled_gtDisp) * mask.float()).mean() 51 | return loss 52 | 53 | # smooth l1 loss 54 | loss = F.smooth_l1_loss(estDisp[mask], scaled_gtDisp[mask], reduction='mean') 55 | 56 | return loss 57 | 58 | def __call__(self, estDisp, gtDisp): 59 | if not isinstance(estDisp, (list, tuple)): 60 | estDisp = [estDisp] 61 | 62 | if self.weights is None: 63 | self.weights = [1.0] * len(estDisp) 64 | 65 | # compute loss for per level 66 | loss_all_level = [] 67 | for est_disp_per_lvl in estDisp: 68 | loss_all_level.append( 69 | self.loss_per_level(est_disp_per_lvl, gtDisp) 70 | ) 71 | 72 | # re-weight loss per level 73 | weighted_loss_all_level = dict() 74 | for i, loss_per_level in enumerate(loss_all_level): 75 | name = "l1_loss_lvl{}".format(i) 76 | weighted_loss_all_level[name] = self.weights[i] * loss_per_level 77 | 78 | return weighted_loss_all_level 79 | 80 | def __repr__(self): 81 | repr_str = '{}\n'.format(self.__class__.__name__) 82 | repr_str += ' ' * 4 + 'Max Disparity: {}\n'.format(self.max_disp) 83 | repr_str += ' ' * 4 + 'Start disparity: {}\n'.format(self.start_disp) 84 | repr_str += ' ' * 4 + 'Loss weight: {}\n'.format(self.weights) 85 | repr_str += ' ' * 4 + 'Disparity is sparse: {}\n'.format(self.sparse) 86 | 87 | return repr_str 88 | 89 | @property 90 | def name(self): 91 | return 'SmoothL1Loss' -------------------------------------------------------------------------------- /modeling/loss/stereo_focal_loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | class StereoFocalLoss(object): 6 | """ 7 | Under the same start disparity and maximum disparity, calculating all estimated cost volumes' loss 8 | Args: 9 | max_disp, (int): the max of Disparity. default: 192 10 | start_disp, (int): the start searching disparity index, usually be 0 11 | dilation (int): the step between near disparity index, it mainly used in gt probability volume generation 12 | weights, (list of float or None): weight for each scale of estCost. 13 | focal_coefficient, (float): stereo focal loss coefficient, details please refer to paper. 
default: 0.0
14 |         sparse, (bool): whether the ground-truth disparity is sparse; for example, KITTI is sparse, but SceneFlow is not. default: False
15 |
16 |     Inputs:
17 |         estCost, (Tensor or list of Tensor): the estimated cost volume, in (BatchSize, max_disp, Height, Width) layout
18 |         gtDisp, (Tensor): the ground truth disparity map, in (BatchSize, 1, Height, Width) layout.
19 |         variance, (Tensor or list of Tensor): the variance of the distribution; for details please refer to the paper. In (BatchSize, 1, Height, Width) layout.
20 |
21 |     Outputs:
22 |         loss, (dict): the loss of each level
23 |
24 |     ..Note:
25 |         Before calculating the loss, estCost shouldn't be normalized,
26 |         because softmax is applied here for normalization
27 |     """
28 |
29 |     def __init__(self, max_disp=192, start_disp=0, dilation=1, weights=None, focal_coefficient=0.0, sparse=False):
30 |         self.max_disp = max_disp
31 |         self.start_disp = start_disp
32 |         self.dilation = dilation
33 |         self.weights = weights
34 |         self.focal_coefficient = focal_coefficient
35 |         self.sparse = sparse
36 |         if sparse:
37 |             # sparse disparity ==> max_pooling
38 |             self.scale_func = F.adaptive_max_pool2d
39 |         else:
40 |             # dense disparity ==> avg_pooling
41 |             self.scale_func = F.adaptive_avg_pool2d
42 |
43 |     def loss_per_level(self, estCost, gtDisp, variance, dilation):
44 |         from .utils import LaplaceDisp2Prob, GaussianDisp2Prob, OneHotDisp2Prob
45 |         N, C, H, W = estCost.shape
46 |         scaled_gtDisp = gtDisp.clone()
47 |         scale = 1.0
48 |         if gtDisp.shape[-2] != H or gtDisp.shape[-1] != W:
49 |             # compute scale per level and scale gtDisp
50 |             scale = gtDisp.shape[-1] / (W * 1.0)
51 |             scaled_gtDisp = gtDisp.clone() / scale
52 |
53 |             scaled_gtDisp = self.scale_func(scaled_gtDisp, (H, W))
54 |
55 |         # mask for valid disparity
56 |         # (start_disp, max disparity / scale)
57 |         # Attention: the invalid disparity of KITTI is set as 0, be sure to mask it out
58 |         lower_bound = self.start_disp
59 |         upper_bound = lower_bound + int(self.max_disp / scale)
60 |         mask = (scaled_gtDisp > lower_bound) & (scaled_gtDisp < upper_bound)
61 |         mask = mask.detach_().type_as(scaled_gtDisp)
62 |         if mask.sum() < 1.0:
63 |             print('Stereo focal loss: there is no point whose '
64 |                   'disparity is in [{},{})!'.format(lower_bound, upper_bound))
65 |             scaled_gtProb = torch.zeros_like(estCost)  # let this sample have loss with 0
66 |         else:
67 |             # transfer disparity map to probability map
68 |             mask_scaled_gtDisp = scaled_gtDisp * mask
69 |             scaled_gtProb = LaplaceDisp2Prob(int(self.max_disp / scale), mask_scaled_gtDisp, variance=variance,
70 |                                              start_disp=self.start_disp, dilation=dilation).getProb()
71 |
72 |         # stereo focal loss
73 |         estProb = F.log_softmax(estCost, dim=1)
74 |         weight = (1.0 - scaled_gtProb).pow(-self.focal_coefficient).type_as(scaled_gtProb)
75 |         loss = -((scaled_gtProb * estProb) * weight * mask.float()).sum(dim=1, keepdim=True).mean()
76 |
77 |         return loss
78 |
79 |     def __call__(self, estCost, gtDisp, variance):
80 |         if not isinstance(estCost, (list, tuple)):
81 |             estCost = [estCost]
82 |
83 |         if self.weights is None:
84 |             self.weights = 1.0
85 |
86 |         if not isinstance(self.weights, (list, tuple)):
87 |             self.weights = [self.weights] * len(estCost)
88 |
89 |         if not isinstance(self.dilation, (list, tuple)):
90 |             self.dilation = [self.dilation] * len(estCost)
91 |
92 |         if not isinstance(variance, (list, tuple)):
93 |             variance = [variance] * len(estCost)
94 |
95 |         # compute loss for per level
96 |         loss_all_level = []
97 |         for est_cost_per_lvl, var, dt in zip(estCost, variance,
self.dilation): 98 | loss_all_level.append( 99 | self.loss_per_level(est_cost_per_lvl, gtDisp, var, dt)) 100 | 101 | # re-weight loss per level 102 | weighted_loss_all_level = dict() 103 | for i, loss_per_level in enumerate(loss_all_level): 104 | name = "stereo_focal_loss_lvl{}".format(i) 105 | weighted_loss_all_level[name] = self.weights[i] * loss_per_level 106 | 107 | return weighted_loss_all_level 108 | 109 | def __repr__(self): 110 | repr_str = '{}\n'.format(self.__class__.__name__) 111 | repr_str += ' ' * 4 + 'Max Disparity: {}\n'.format(self.max_disp) 112 | repr_str += ' ' * 4 + 'Start disparity: {}\n'.format(self.start_disp) 113 | repr_str += ' ' * 4 + 'Dilation rate: {}\n'.format(self.dilation) 114 | repr_str += ' ' * 4 + 'Loss weight: {}\n'.format(self.weights) 115 | repr_str += ' ' * 4 + 'Focal coefficient: {}\n'.format(self.focal_coefficient) 116 | repr_str += ' ' * 4 + 'Disparity is sparse: {}\n'.format(self.sparse) 117 | 118 | return repr_str 119 | 120 | @property 121 | def name(self): 122 | return 'StereoFocalLoss' -------------------------------------------------------------------------------- /modeling/loss/utils/__init__.py: -------------------------------------------------------------------------------- 1 | from .disp2prob import LaplaceDisp2Prob, GaussianDisp2Prob, OneHotDisp2Prob 2 | 3 | __all__ = ['LaplaceDisp2Prob', 'GaussianDisp2Prob', 'OneHotDisp2Prob',] -------------------------------------------------------------------------------- /modeling/loss/utils/disp2prob.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | def isNaN(x): 8 | return x != x 9 | 10 | 11 | class Disp2Prob(object): 12 | """ 13 | Convert disparity map to matching probability volume 14 | Args: 15 | maxDisp, (int): the maximum of disparity 16 | gtDisp, (torch.Tensor): in (..., Height, Width) layout 17 | start_disp (int): the start searching disparity index, usually be 0 18 | dilation (int): the step between near disparity index 19 | 20 | Outputs: 21 | probability, (torch.Tensor): in [BatchSize, maxDisp, Height, Width] layout 22 | 23 | 24 | """ 25 | def __init__(self, maxDisp, gtDisp, start_disp=0, dilation=1): 26 | 27 | if not isinstance(maxDisp, int): 28 | raise TypeError('int is expected, got {}'.format(type(maxDisp))) 29 | 30 | if not torch.is_tensor(gtDisp): 31 | raise TypeError('tensor is expected, got {}'.format(type(gtDisp))) 32 | 33 | if not isinstance(start_disp, int): 34 | raise TypeError('int is expected, got {}'.format(type(start_disp))) 35 | 36 | if not isinstance(dilation, int): 37 | raise TypeError('int is expected, got {}'.format(type(dilation))) 38 | 39 | if gtDisp.dim() == 2: # single image H x W 40 | gtDisp = gtDisp.view(1, 1, gtDisp.size(0), gtDisp.size(1)) 41 | 42 | if gtDisp.dim() == 3: # multi image B x H x W 43 | gtDisp = gtDisp.view(gtDisp.size(0), 1, gtDisp.size(1), gtDisp.size(2)) 44 | 45 | if gtDisp.dim() == 4: 46 | if gtDisp.size(1) == 1: # mult image B x 1 x H x W 47 | gtDisp = gtDisp 48 | else: 49 | raise ValueError('2nd dimension size should be 1, got {}'.format(gtDisp.size(1))) 50 | 51 | self.gtDisp = gtDisp 52 | self.maxDisp = maxDisp 53 | self.start_disp = start_disp 54 | self.dilation = dilation 55 | self.end_disp = start_disp + maxDisp - 1 56 | self.disp_sample_number = (maxDisp + dilation -1) // dilation 57 | self.eps = 1e-40 58 | 59 | def getProb(self): 60 | # [BatchSize, 1, Height, Width] 61 | b, c, h, w = self.gtDisp.shape 62 | assert c == 
1 63 | 64 | # if start_disp = 0, dilation = 1, then generate disparity candidates as [0, 1, 2, ... , maxDisp-1] 65 | index = torch.linspace(self.start_disp, self.end_disp, self.disp_sample_number) 66 | index = index.to(self.gtDisp.device) 67 | 68 | # [BatchSize, maxDisp, Height, Width] 69 | self.index = index.repeat(b, h, w, 1).permute(0, 3, 1, 2).contiguous() 70 | 71 | # the gtDisp must be (start_disp, end_disp), otherwise, we have to mask it out 72 | mask = (self.gtDisp > self.start_disp) & (self.gtDisp < self.end_disp) 73 | mask = mask.detach().type_as(self.gtDisp) 74 | self.gtDisp = self.gtDisp * mask 75 | 76 | probability = self.calProb() 77 | 78 | # let the outliers' probability to be 0 79 | # in case divide or log 0, we plus a tiny constant value 80 | probability = probability * mask + self.eps 81 | 82 | # in case probability is NaN 83 | if isNaN(probability.min()) or isNaN(probability.max()): 84 | print('Probability ==> min: {}, max: {}'.format(probability.min(), probability.max())) 85 | print('Disparity Ground Truth after mask out ==> min: {}, max: {}'.format(self.gtDisp.min(), 86 | self.gtDisp.max())) 87 | raise ValueError(" \'probability contains NaN!") 88 | 89 | return probability 90 | 91 | def kick_invalid_half(self): 92 | distance = self.gtDisp - self.index 93 | invalid_index = distance < 0 94 | # after softmax, the valid index with value 1e6 will approximately get 0 95 | distance[invalid_index] = 1e6 96 | return distance 97 | 98 | def calProb(self): 99 | raise NotImplementedError 100 | 101 | 102 | class LaplaceDisp2Prob(Disp2Prob): 103 | # variance is the diversity of the Laplace distribution 104 | def __init__(self, maxDisp, gtDisp, variance=1, start_disp=0, dilation=1): 105 | super(LaplaceDisp2Prob, self).__init__(maxDisp, gtDisp, start_disp, dilation) 106 | self.variance = variance 107 | 108 | def calProb(self): 109 | # 1/N * exp( - (d - d{gt}) / var), N is normalization factor, [BatchSize, maxDisp, Height, Width] 110 | scaled_distance = ((-torch.abs(self.index - self.gtDisp)) / self.variance) 111 | probability = F.softmax(scaled_distance, dim=1) 112 | 113 | return probability 114 | 115 | 116 | class GaussianDisp2Prob(Disp2Prob): 117 | # variance is the variance of the Gaussian distribution 118 | def __init__(self, maxDisp, gtDisp, variance=1, start_disp=0, dilation=1): 119 | super(GaussianDisp2Prob, self).__init__(maxDisp, gtDisp, start_disp, dilation) 120 | self.variance = variance 121 | 122 | def calProb(self): 123 | # 1/N * exp( - (d - d{gt})^2 / b), N is normalization factor, [BatchSize, maxDisp, Height, Width] 124 | distance = (torch.abs(self.index - self.gtDisp)) 125 | scaled_distance = (- distance.pow(2.0) / self.variance) 126 | probability = F.softmax(scaled_distance, dim=1) 127 | 128 | return probability 129 | 130 | class OneHotDisp2Prob(Disp2Prob): 131 | # variance is the variance of the OneHot distribution 132 | def __init__(self, maxDisp, gtDisp, variance=1, start_disp=0, dilation=1): 133 | super(OneHotDisp2Prob, self).__init__(maxDisp, gtDisp, start_disp, dilation) 134 | self.variance = variance 135 | 136 | def getProb(self): 137 | 138 | # |d - d{gt}| < variance, [BatchSize, maxDisp, Height, Width] 139 | probability = torch.lt(torch.abs(self.index - self.gtDisp), self.variance).type_as(self.gtDisp) 140 | 141 | return probability 142 | -------------------------------------------------------------------------------- /modeling/models/AcfNet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn 
as nn
3 | import torch.nn.functional as F
4 |
5 | from ..backbone.psmnet_backbone import PsmBb
6 | from ..cost_computation import cat_fms
7 | from ..cost_aggregation.acfnet_cost import AcfCost
8 | from ..conf_meausre.conf_net import ConfidenceEstimation
9 | from ..disp_prediction import faster_soft_argmin
10 | from ..loss import StereoFocalLoss, ConfidenceNllLoss, DispSmoothL1Loss
11 | from dmb.data.datasets.evaluation.stereo.eval import remove_padding, do_evaluation, do_occlusion_evaluation
12 |
13 | class AcfNet(nn.Module):
14 |     """
15 |     Args:
16 |         cfg, (dict): the model configuration
17 |     Inputs:
18 |         batch, (dict): the input batch of the model;
19 |             it must contain the keys 'leftImage', 'rightImage', 'original_size',
20 |             and optionally 'leftDisp', 'rightDisp'.
21 |             Images and disparity maps are of torch.Tensor type;
22 |             original_size is the size of the original, unprocessed image, i.e. (H, W), e.g. (540, 960) in SceneFlow
23 |     Outputs:
24 |         result, (dict): the result grouped in a dict.
25 |     """
26 |
27 |     def __init__(self, cfg):
28 |         super(AcfNet, self).__init__()
29 |         self.cfg = cfg.copy()
30 |
31 |         self.batchNorm = cfg.model.batchNorm
32 |
33 |         self.max_disp = cfg.model.max_disp
34 |
35 |         # image feature extraction
36 |         self.backbone_in_planes = cfg.model.backbone.in_planes
37 |         self.backbone = PsmBb(in_planes=self.backbone_in_planes, batchNorm=self.batchNorm)
38 |
39 |         # matching cost computation
40 |         self.cost_computation = cat_fms
41 |
42 |         # matching cost aggregation
43 |         self.cost_aggregation_in_planes = cfg.model.cost_aggregation.in_planes
44 |         self.cost_aggregation = AcfCost(max_disp=self.max_disp,
45 |                                         in_planes=self.cost_aggregation_in_planes,
46 |                                         batchNorm=self.batchNorm)
47 |
48 |         # confidence learning
49 |         self.conf_est_net = nn.ModuleList([
50 |             ConfidenceEstimation(in_planes=self.max_disp, batchNorm=self.batchNorm) for i in range(3)])
51 |         self.confidence_coefficient = cfg.model.confidence.coefficient
52 |         self.confidence_init_value = cfg.model.confidence.init_value
53 |
54 |         # calculate loss
55 |         self.weights = cfg.model.loss.weights
56 |         self.sparse = cfg.data.sparse
57 |
58 |         # focal loss
59 |         self.focal_coefficient = cfg.model.loss.focal_coefficient
60 |         self.loss_variance = cfg.model.loss.variance
61 |         self.focal_loss_evaluator = \
62 |             StereoFocalLoss(max_disp=self.max_disp, weights=self.weights,
63 |                             focal_coefficient=self.focal_coefficient, sparse=self.sparse)
64 |
65 |         # smooth l1 loss
66 |         self.l1_loss_evaluator = DispSmoothL1Loss(self.max_disp, weights=self.weights,
67 |                                                   sparse=self.sparse)
68 |         self.l1_loss_weight = cfg.model.loss.l1_loss_weight
69 |
70 |         # nll loss
71 |         self.conf_loss_evaluator = \
72 |             ConfidenceNllLoss(max_disp=self.max_disp, weights=self.weights, sparse=self.sparse)
73 |         self.conf_loss_weight = cfg.model.loss.conf_loss_weight
74 |
75 |         # disparity regression
76 |         # Attention: faster soft argmin contains a nn.Conv3d with fixed value
77 |         # and cannot be initialized with other initialization methods, e.g.
Xavier, Kaiming initialization 78 | self.sa_temperature = cfg.model.disparity_prediction.sa_temperature 79 | self.disp_predictor = faster_soft_argmin(self.max_disp) 80 | 81 | 82 | def forward(self, batch): 83 | ref_image, target_image = batch['leftImage'], batch['rightImage'] 84 | target = batch['leftDisp'] if 'leftDisp' in batch else None 85 | 86 | ref_fm, target_fm = self.backbone(ref_image, target_image) 87 | 88 | raw_cost = self.cost_computation(ref_fm, target_fm, int(self.max_disp // 4)) 89 | costs = self.cost_aggregation(raw_cost) 90 | 91 | confidence_costs = [cen(c) for c, cen in zip(costs, self.conf_est_net)] 92 | confidences = [torch.sigmoid(c) for c in confidence_costs] 93 | 94 | variances = [self.confidence_coefficient * (1 - conf) + self.confidence_init_value for conf in confidences] 95 | 96 | disps = [self.disp_predictor(cost, temperature=self.sa_temperature) for cost in costs] 97 | 98 | 99 | if self.training: 100 | assert target is not None, "Ground truth disparity map should be given" 101 | losses = {} 102 | focal_losses = self.focal_loss_evaluator(costs, target, variances) 103 | losses.update(focal_losses) 104 | 105 | l1_losses = self.l1_loss_evaluator(disps, target) 106 | l1_losses = {k: v * self.l1_loss_weight for k, v in zip(l1_losses.keys(), l1_losses.values())} 107 | losses.update(l1_losses) 108 | 109 | nll_losses = self.conf_loss_evaluator(confidence_costs, target) 110 | nll_losses = {k : v * self.conf_loss_weight for k, v in zip(nll_losses.keys(), nll_losses.values())} 111 | losses.update(nll_losses) 112 | 113 | return losses 114 | else: 115 | confidences = remove_padding(confidences, batch['original_size']) 116 | disps = remove_padding(disps, batch['original_size']) 117 | 118 | error_dict = {} 119 | if target is not None: 120 | target = remove_padding(target, batch['original_size']) 121 | error_dict = do_evaluation(disps[0], target, 122 | self.cfg.model.eval.lower_bound, 123 | self.cfg.model.eval.upper_bound) 124 | 125 | if self.cfg.model.eval.eval_occlusion and 'leftDisp' in batch and 'rightDisp' in batch: 126 | batch['leftDisp'] = remove_padding(batch['leftDisp'], batch['original_size']) 127 | batch['rightDisp'] = remove_padding(batch['rightDisp'], batch['original_size']) 128 | occ_error_dict = do_occlusion_evaluation(disps[0], batch['leftDisp'], batch['rightDisp'], 129 | self.cfg.model.eval.lower_bound, 130 | self.cfg.model.eval.upper_bound) 131 | error_dict.update(occ_error_dict) 132 | 133 | 134 | result = {'Disparity': disps, 135 | 'GroundTruth': target, 136 | 'Confidence': confidences, 137 | 'Error': error_dict, 138 | } 139 | 140 | if self.cfg.model.eval.is_cost_return: 141 | if self.cfg.model.eval.is_cost_to_cpu: 142 | costs = [cost.cpu() for cost in costs] 143 | result['Cost'] = costs 144 | 145 | return result 146 | -------------------------------------------------------------------------------- /modeling/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/youmi-zym/AcfNet/82798abd27636fe5e28fab0a8ed6460c9ad9c2b1/modeling/models/__init__.py --------------------------------------------------------------------------------