├── GBDF.yaml ├── LeRes ├── Resnet.py ├── Resnext_torch.py ├── __init__.py ├── multi_depth_model_woauxi.py ├── net_tools.py ├── network_auxi.py ├── spvcnn_classsification.py ├── spvcnn_utils.py └── test_utils.py ├── MiDaS ├── base_model.py ├── blocks.py ├── hubconf.py ├── midas_net.py └── transforms.py ├── README.md ├── SGR ├── DepthNet.py ├── networks.py ├── resnet.py └── syncbn │ ├── make_ext.sh │ ├── modules │ ├── __init__.py │ ├── __init__.pyc │ ├── functional │ │ ├── __init__.py │ │ ├── __init__.pyc │ │ ├── _syncbn │ │ │ ├── __init__.py │ │ │ ├── __init__.pyc │ │ │ ├── _ext │ │ │ │ ├── __init__.py │ │ │ │ ├── __init__.pyc │ │ │ │ └── syncbn │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── __init__.pyc │ │ │ │ │ └── _syncbn.so │ │ │ ├── build.py │ │ │ └── src │ │ │ │ ├── common.h │ │ │ │ ├── syncbn.cpp │ │ │ │ ├── syncbn.cu │ │ │ │ ├── syncbn.cu.h │ │ │ │ ├── syncbn.cu.o │ │ │ │ └── syncbn.h │ │ ├── syncbn.py │ │ └── syncbn.pyc │ └── nn │ │ ├── __init__.py │ │ ├── __init__.pyc │ │ ├── syncbn.py │ │ └── syncbn.pyc │ ├── requirements.txt │ └── test.py ├── dpt ├── __init__.py ├── base_model.py ├── blocks.py ├── midas_net.py ├── models.py ├── transforms.py ├── util │ ├── __init__.py │ ├── io.py │ ├── misc.py │ └── pallete.py ├── vit.py └── weights │ └── .placeholder ├── eval.py ├── figures ├── 1.gif ├── 2.gif └── 3.gif ├── input ├── 1.png ├── 2.png └── 3.png ├── newcrfs ├── dataloaders │ ├── __init__.py │ ├── dataloader.py │ └── dataloader_kittipred.py ├── networks │ ├── NewCRFDepth.py │ ├── __init__.py │ ├── newcrf_layers.py │ ├── newcrf_utils.py │ ├── swin_transformer.py │ └── uper_crf_head.py └── utils.py ├── run.py ├── train.py └── utils ├── func.py ├── guided_f.py ├── hypersim.py ├── middleburry2021.py ├── model.py └── multiscopic.py /GBDF.yaml: -------------------------------------------------------------------------------- 1 | name: GBDF 2 | channels: 3 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main 4 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/ 5 | - https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/ 6 | - defaults 7 | dependencies: 8 | - _libgcc_mutex=0.1=main 9 | - _openmp_mutex=5.1=1_gnu 10 | - ca-certificates=2022.10.11=h06a4308_0 11 | - certifi=2022.9.24=py38h06a4308_0 12 | - ld_impl_linux-64=2.38=h1181459_1 13 | - libffi=3.4.2=h6a678d5_6 14 | - libgcc-ng=11.2.0=h1234567_1 15 | - libgomp=11.2.0=h1234567_1 16 | - libstdcxx-ng=11.2.0=h1234567_1 17 | - ncurses=6.3=h5eee18b_3 18 | - openssl=1.1.1s=h7f8727e_0 19 | - pip=22.2.2=py38h06a4308_0 20 | - python=3.8.15=h3fd9d12_0 21 | - readline=8.2=h5eee18b_0 22 | - setuptools=65.5.0=py38h06a4308_0 23 | - sqlite=3.40.0=h5082296_0 24 | - tk=8.6.12=h1ccaba5_0 25 | - wheel=0.37.1=pyhd3eb1b0_0 26 | - xz=5.2.6=h5eee18b_0 27 | - zlib=1.2.13=h5eee18b_0 28 | - pip: 29 | - argparse==1.4.0 30 | - click==8.1.3 31 | - contourpy==1.0.6 32 | - cycler==0.11.0 33 | - fonttools==4.38.0 34 | - h5df==0.1.5 35 | - h5py==3.7.0 36 | - imageio==2.22.4 37 | - kiwisolver==1.4.4 38 | - matplotlib==3.6.2 39 | - networkx==2.8.8 40 | - numpy==1.23.5 41 | - opencv-python==4.6.0.66 42 | - packaging==21.3 43 | - pandas==1.5.2 44 | - pillow==9.3.0 45 | - plyfile==0.7.4 46 | - pyparsing==3.0.9 47 | - python-dateutil==2.8.2 48 | - pytz==2022.6 49 | - pywavelets==1.4.1 50 | - scikit-image==0.19.3 51 | - scipy==1.9.3 52 | - six==1.16.0 53 | - sklearn==0.0.post1 54 | - tifffile==2022.10.10 55 | - typing-extensions==4.4.0 56 | 57 | -------------------------------------------------------------------------------- /LeRes/Resnet.py: 
-------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn as NN 3 | 4 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 5 | 'resnet152'] 6 | 7 | 8 | model_urls = { 9 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 10 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 11 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 12 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 13 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 14 | } 15 | 16 | 17 | def conv3x3(in_planes, out_planes, stride=1): 18 | """3x3 convolution with padding""" 19 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 20 | padding=1, bias=False) 21 | 22 | 23 | class BasicBlock(nn.Module): 24 | expansion = 1 25 | 26 | def __init__(self, inplanes, planes, stride=1, downsample=None): 27 | super(BasicBlock, self).__init__() 28 | self.conv1 = conv3x3(inplanes, planes, stride) 29 | self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d 30 | self.relu = nn.ReLU(inplace=True) 31 | self.conv2 = conv3x3(planes, planes) 32 | self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d 33 | self.downsample = downsample 34 | self.stride = stride 35 | 36 | def forward(self, x): 37 | residual = x 38 | 39 | out = self.conv1(x) 40 | out = self.bn1(out) 41 | out = self.relu(out) 42 | 43 | out = self.conv2(out) 44 | out = self.bn2(out) 45 | 46 | if self.downsample is not None: 47 | residual = self.downsample(x) 48 | 49 | out += residual 50 | out = self.relu(out) 51 | 52 | return out 53 | 54 | 55 | class Bottleneck(nn.Module): 56 | expansion = 4 57 | 58 | def __init__(self, inplanes, planes, stride=1, downsample=None): 59 | super(Bottleneck, self).__init__() 60 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 61 | self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d 62 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 63 | padding=1, bias=False) 64 | self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d 65 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 66 | self.bn3 = NN.BatchNorm2d(planes * self.expansion) #NN.BatchNorm2d 67 | self.relu = nn.ReLU(inplace=True) 68 | self.downsample = downsample 69 | self.stride = stride 70 | 71 | def forward(self, x): 72 | residual = x 73 | 74 | out = self.conv1(x) 75 | out = self.bn1(out) 76 | out = self.relu(out) 77 | 78 | out = self.conv2(out) 79 | out = self.bn2(out) 80 | out = self.relu(out) 81 | 82 | out = self.conv3(out) 83 | out = self.bn3(out) 84 | 85 | if self.downsample is not None: 86 | residual = self.downsample(x) 87 | 88 | out += residual 89 | out = self.relu(out) 90 | 91 | return out 92 | 93 | 94 | class ResNet(nn.Module): 95 | 96 | def __init__(self, block, layers, num_classes=1000): 97 | self.inplanes = 64 98 | super(ResNet, self).__init__() 99 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 100 | bias=False) 101 | self.bn1 = NN.BatchNorm2d(64) #NN.BatchNorm2d 102 | self.relu = nn.ReLU(inplace=True) 103 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 104 | self.layer1 = self._make_layer(block, 64, layers[0]) 105 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 106 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 107 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 108 | #self.avgpool = nn.AvgPool2d(7, stride=1) 
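# note: the classification head (avgpool / fc below) is intentionally left disabled; this backbone acts as a multi-scale feature extractor, and forward() collects and returns the layer1-layer4 outputs instead of class logits.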
109 | #self.fc = nn.Linear(512 * block.expansion, num_classes) 110 | 111 | for m in self.modules(): 112 | if isinstance(m, nn.Conv2d): 113 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 114 | elif isinstance(m, nn.BatchNorm2d): 115 | nn.init.constant_(m.weight, 1) 116 | nn.init.constant_(m.bias, 0) 117 | 118 | def _make_layer(self, block, planes, blocks, stride=1): 119 | downsample = None 120 | if stride != 1 or self.inplanes != planes * block.expansion: 121 | downsample = nn.Sequential( 122 | nn.Conv2d(self.inplanes, planes * block.expansion, 123 | kernel_size=1, stride=stride, bias=False), 124 | NN.BatchNorm2d(planes * block.expansion), #NN.BatchNorm2d 125 | ) 126 | 127 | layers = [] 128 | layers.append(block(self.inplanes, planes, stride, downsample)) 129 | self.inplanes = planes * block.expansion 130 | for i in range(1, blocks): 131 | layers.append(block(self.inplanes, planes)) 132 | 133 | return nn.Sequential(*layers) 134 | 135 | def forward(self, x): 136 | features = [] 137 | 138 | x = self.conv1(x) 139 | x = self.bn1(x) 140 | x = self.relu(x) 141 | x = self.maxpool(x) 142 | 143 | x = self.layer1(x) 144 | features.append(x) 145 | x = self.layer2(x) 146 | features.append(x) 147 | x = self.layer3(x) 148 | features.append(x) 149 | x = self.layer4(x) 150 | features.append(x) 151 | 152 | return features 153 | 154 | 155 | def resnet18(pretrained=True, **kwargs): 156 | """Constructs a ResNet-18 model. 157 | Args: 158 | pretrained (bool): If True, returns a model pre-trained on ImageNet 159 | """ 160 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 161 | return model 162 | 163 | 164 | def resnet34(pretrained=True, **kwargs): 165 | """Constructs a ResNet-34 model. 166 | Args: 167 | pretrained (bool): If True, returns a model pre-trained on ImageNet 168 | """ 169 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 170 | return model 171 | 172 | 173 | def resnet50(pretrained=True, **kwargs): 174 | """Constructs a ResNet-50 model. 175 | Args: 176 | pretrained (bool): If True, returns a model pre-trained on ImageNet 177 | """ 178 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 179 | 180 | return model 181 | 182 | 183 | def resnet101(pretrained=True, **kwargs): 184 | """Constructs a ResNet-101 model. 185 | Args: 186 | pretrained (bool): If True, returns a model pre-trained on ImageNet 187 | """ 188 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 189 | 190 | return model 191 | 192 | 193 | def resnet152(pretrained=True, **kwargs): 194 | """Constructs a ResNet-152 model. 
195 | Args: 196 | pretrained (bool): If True, returns a model pre-trained on ImageNet 197 | """ 198 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 199 | return model 200 | -------------------------------------------------------------------------------- /LeRes/Resnext_torch.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # coding: utf-8 3 | import torch.nn as nn 4 | 5 | try: 6 | from urllib import urlretrieve 7 | except ImportError: 8 | from urllib.request import urlretrieve 9 | 10 | __all__ = ['resnext101_32x8d'] 11 | 12 | 13 | model_urls = { 14 | 'resnext50_32x4d': 'https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pth', 15 | 'resnext101_32x8d': 'https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pth', 16 | } 17 | 18 | 19 | def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1): 20 | """3x3 convolution with padding""" 21 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 22 | padding=dilation, groups=groups, bias=False, dilation=dilation) 23 | 24 | 25 | def conv1x1(in_planes, out_planes, stride=1): 26 | """1x1 convolution""" 27 | return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False) 28 | 29 | 30 | class BasicBlock(nn.Module): 31 | expansion = 1 32 | 33 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 34 | base_width=64, dilation=1, norm_layer=None): 35 | super(BasicBlock, self).__init__() 36 | if norm_layer is None: 37 | norm_layer = nn.BatchNorm2d 38 | if groups != 1 or base_width != 64: 39 | raise ValueError('BasicBlock only supports groups=1 and base_width=64') 40 | if dilation > 1: 41 | raise NotImplementedError("Dilation > 1 not supported in BasicBlock") 42 | # Both self.conv1 and self.downsample layers downsample the input when stride != 1 43 | self.conv1 = conv3x3(inplanes, planes, stride) 44 | self.bn1 = norm_layer(planes) 45 | self.relu = nn.ReLU(inplace=True) 46 | self.conv2 = conv3x3(planes, planes) 47 | self.bn2 = norm_layer(planes) 48 | self.downsample = downsample 49 | self.stride = stride 50 | 51 | def forward(self, x): 52 | identity = x 53 | 54 | out = self.conv1(x) 55 | out = self.bn1(out) 56 | out = self.relu(out) 57 | 58 | out = self.conv2(out) 59 | out = self.bn2(out) 60 | 61 | if self.downsample is not None: 62 | identity = self.downsample(x) 63 | 64 | out += identity 65 | out = self.relu(out) 66 | 67 | return out 68 | 69 | 70 | class Bottleneck(nn.Module): 71 | 72 | 73 | expansion = 4 74 | 75 | def __init__(self, inplanes, planes, stride=1, downsample=None, groups=1, 76 | base_width=64, dilation=1, norm_layer=None): 77 | super(Bottleneck, self).__init__() 78 | if norm_layer is None: 79 | norm_layer = nn.BatchNorm2d 80 | width = int(planes * (base_width / 64.)) * groups 81 | # Both self.conv2 and self.downsample layers downsample the input when stride != 1 82 | self.conv1 = conv1x1(inplanes, width) 83 | self.bn1 = norm_layer(width) 84 | self.conv2 = conv3x3(width, width, stride, groups, dilation) 85 | self.bn2 = norm_layer(width) 86 | self.conv3 = conv1x1(width, planes * self.expansion) 87 | self.bn3 = norm_layer(planes * self.expansion) 88 | self.relu = nn.ReLU(inplace=True) 89 | self.downsample = downsample 90 | self.stride = stride 91 | 92 | def forward(self, x): 93 | identity = x 94 | 95 | out = self.conv1(x) 96 | out = self.bn1(out) 97 | out = self.relu(out) 98 | 99 | out = self.conv2(out) 100 | out = self.bn2(out) 101 | out = self.relu(out) 102 | 103 | out = self.conv3(out) 104 | 
out = self.bn3(out) 105 | 106 | if self.downsample is not None: 107 | identity = self.downsample(x) 108 | 109 | out += identity 110 | out = self.relu(out) 111 | 112 | return out 113 | 114 | 115 | class ResNet(nn.Module): 116 | 117 | def __init__(self, block, layers, num_classes=1000, zero_init_residual=False, 118 | groups=1, width_per_group=64, replace_stride_with_dilation=None, 119 | norm_layer=None): 120 | super(ResNet, self).__init__() 121 | if norm_layer is None: 122 | norm_layer = nn.BatchNorm2d 123 | self._norm_layer = norm_layer 124 | 125 | self.inplanes = 64 126 | self.dilation = 1 127 | if replace_stride_with_dilation is None: 128 | # each element in the tuple indicates if we should replace 129 | # the 2x2 stride with a dilated convolution instead 130 | replace_stride_with_dilation = [False, False, False] 131 | if len(replace_stride_with_dilation) != 3: 132 | raise ValueError("replace_stride_with_dilation should be None " 133 | "or a 3-element tuple, got {}".format(replace_stride_with_dilation)) 134 | self.groups = groups 135 | self.base_width = width_per_group 136 | self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3, 137 | bias=False) 138 | self.bn1 = norm_layer(self.inplanes) 139 | self.relu = nn.ReLU(inplace=True) 140 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 141 | self.layer1 = self._make_layer(block, 64, layers[0]) 142 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, 143 | dilate=replace_stride_with_dilation[0]) 144 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 145 | dilate=replace_stride_with_dilation[1]) 146 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 147 | dilate=replace_stride_with_dilation[2]) 148 | #self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) 149 | #self.fc = nn.Linear(512 * block.expansion, num_classes) 150 | 151 | for m in self.modules(): 152 | if isinstance(m, nn.Conv2d): 153 | nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 154 | elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 155 | nn.init.constant_(m.weight, 1) 156 | nn.init.constant_(m.bias, 0) 157 | 158 | if zero_init_residual: 159 | for m in self.modules(): 160 | if isinstance(m, Bottleneck): 161 | nn.init.constant_(m.bn3.weight, 0) 162 | elif isinstance(m, BasicBlock): 163 | nn.init.constant_(m.bn2.weight, 0) 164 | 165 | def _make_layer(self, block, planes, blocks, stride=1, dilate=False): 166 | norm_layer = self._norm_layer 167 | downsample = None 168 | previous_dilation = self.dilation 169 | if dilate: 170 | self.dilation *= stride 171 | stride = 1 172 | if stride != 1 or self.inplanes != planes * block.expansion: 173 | downsample = nn.Sequential( 174 | conv1x1(self.inplanes, planes * block.expansion, stride), 175 | norm_layer(planes * block.expansion), 176 | ) 177 | 178 | layers = [] 179 | layers.append(block(self.inplanes, planes, stride, downsample, self.groups, 180 | self.base_width, previous_dilation, norm_layer)) 181 | self.inplanes = planes * block.expansion 182 | for _ in range(1, blocks): 183 | layers.append(block(self.inplanes, planes, groups=self.groups, 184 | base_width=self.base_width, dilation=self.dilation, 185 | norm_layer=norm_layer)) 186 | 187 | return nn.Sequential(*layers) 188 | 189 | def _forward_impl(self, x): 190 | # See note [TorchScript super()] 191 | features = [] 192 | x = self.conv1(x) 193 | x = self.bn1(x) 194 | x = self.relu(x) 195 | x = self.maxpool(x) 196 | 197 | x = self.layer1(x) 198 | features.append(x) 199 | 200 | x = self.layer2(x) 201 | 
features.append(x) 202 | 203 | x = self.layer3(x) 204 | features.append(x) 205 | 206 | x = self.layer4(x) 207 | features.append(x) 208 | 209 | #x = self.avgpool(x) 210 | #x = torch.flatten(x, 1) 211 | #x = self.fc(x) 212 | 213 | return features 214 | 215 | def forward(self, x): 216 | return self._forward_impl(x) 217 | 218 | 219 | 220 | def resnext101_32x8d(pretrained=True, **kwargs): 221 | """Constructs a ResNet-152 model. 222 | Args: 223 | pretrained (bool): If True, returns a model pre-trained on ImageNet 224 | """ 225 | kwargs['groups'] = 32 226 | kwargs['width_per_group'] = 8 227 | 228 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 229 | return model 230 | 231 | 232 | 233 | if __name__ == '__main__': 234 | import torch 235 | model = resnext101_32x8d(True).cuda() 236 | 237 | rgb = torch.rand((2, 3, 256, 256)).cuda() 238 | out = model(rgb) 239 | print(len(out)) 240 | 241 | -------------------------------------------------------------------------------- /LeRes/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /LeRes/multi_depth_model_woauxi.py: -------------------------------------------------------------------------------- 1 | from LeRes import network_auxi as network 2 | from LeRes.net_tools import get_func 3 | import torch 4 | import torch.nn as nn 5 | from collections import OrderedDict 6 | 7 | 8 | def strip_prefix_if_present(state_dict=None, prefix="module."): 9 | if state_dict is None: 10 | depth_dict='./res101.pth' 11 | depth_dict = torch.load(depth_dict) 12 | state_dict = depth_dict['depth_model'] 13 | keys = sorted(state_dict.keys()) 14 | if not all(key.startswith(prefix) for key in keys): 15 | return state_dict 16 | stripped_state_dict = OrderedDict() 17 | for key, value in state_dict.items(): 18 | stripped_state_dict[key.replace(prefix, "")] = value 19 | return stripped_state_dict 20 | 21 | class RelDepthModel(nn.Module): 22 | def __init__(self, backbone='resnet50'): 23 | super(RelDepthModel, self).__init__() 24 | if backbone == 'resnet50': 25 | encoder = 'resnet50_stride32' 26 | elif backbone == 'resnext101': 27 | encoder = 'resnext101_stride32x8d' 28 | self.depth_model = DepthModel(encoder) 29 | 30 | def inference(self, rgb): 31 | with torch.no_grad(): 32 | input = rgb.cuda() 33 | depth = self.depth_model(input) 34 | pred_depth_out = depth - depth.min() + 0.01 35 | return pred_depth_out 36 | 37 | def check_feature(self, rgb): 38 | with torch.no_grad(): 39 | input = rgb.cuda() 40 | feature = self.depth_model(input) 41 | return feature 42 | 43 | 44 | class DepthModel(nn.Module): 45 | def __init__(self, encoder): 46 | super(DepthModel, self).__init__() 47 | backbone = network.__name__.split('.')[-1] + '.' + encoder 48 | self.encoder_modules = get_func(backbone)() 49 | self.decoder_modules = network.Decoder() 50 | 51 | def forward(self, x): 52 | lateral_out = self.encoder_modules(x) 53 | out_logit = self.decoder_modules(lateral_out) 54 | return out_logit 55 | # return lateral_out -------------------------------------------------------------------------------- /LeRes/net_tools.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | import torch 3 | import os 4 | from collections import OrderedDict 5 | 6 | 7 | def get_func(func_name): 8 | """Helper to return a function object by name. 
func_name must identify a 9 | function in this module or the path to a function relative to the base 10 | 'modeling' module. 11 | """ 12 | if func_name == '': 13 | return None 14 | try: 15 | parts = func_name.split('.') 16 | # Refers to a function in this module 17 | if len(parts) == 1: 18 | return globals()[parts[0]] 19 | # Otherwise, assume we're referencing a module under modeling 20 | module_name = 'LeRes.' + '.'.join(parts[:-1]) 21 | module = importlib.import_module(module_name) 22 | return getattr(module, parts[-1]) 23 | except Exception: 24 | print('Failed to f1ind function: %s', func_name) 25 | raise 26 | 27 | def load_ckpt(args, depth_model, shift_model, focal_model): 28 | """ 29 | Load checkpoint. 30 | """ 31 | if os.path.isfile(args.load_ckpt): 32 | print("loading checkpoint %s" % args.load_ckpt) 33 | checkpoint = torch.load(args.load_ckpt) 34 | if shift_model is not None: 35 | shift_model.load_state_dict(strip_prefix_if_present(checkpoint['shift_model'], 'module.'), 36 | strict=True) 37 | if focal_model is not None: 38 | focal_model.load_state_dict(strip_prefix_if_present(checkpoint['focal_model'], 'module.'), 39 | strict=True) 40 | depth_model.load_state_dict(strip_prefix_if_present(checkpoint['depth_model'], "module."), 41 | strict=True) 42 | del checkpoint 43 | torch.cuda.empty_cache() 44 | 45 | 46 | def strip_prefix_if_present(state_dict, prefix): 47 | keys = sorted(state_dict.keys()) 48 | if not all(key.startswith(prefix) for key in keys): 49 | return state_dict 50 | stripped_state_dict = OrderedDict() 51 | for key, value in state_dict.items(): 52 | stripped_state_dict[key.replace(prefix, "")] = value 53 | return stripped_state_dict -------------------------------------------------------------------------------- /LeRes/spvcnn_classsification.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torchsparse.nn as spnn 3 | from torchsparse.point_tensor import PointTensor 4 | from LeRes.spvcnn_utils import * 5 | __all__ = ['SPVCNN_CLASSIFICATION'] 6 | 7 | 8 | 9 | class BasicConvolutionBlock(nn.Module): 10 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1): 11 | super().__init__() 12 | self.net = nn.Sequential( 13 | spnn.Conv3d(inc, 14 | outc, 15 | kernel_size=ks, 16 | dilation=dilation, 17 | stride=stride), 18 | spnn.BatchNorm(outc), 19 | spnn.ReLU(True)) 20 | 21 | def forward(self, x): 22 | out = self.net(x) 23 | return out 24 | 25 | 26 | class BasicDeconvolutionBlock(nn.Module): 27 | def __init__(self, inc, outc, ks=3, stride=1): 28 | super().__init__() 29 | self.net = nn.Sequential( 30 | spnn.Conv3d(inc, 31 | outc, 32 | kernel_size=ks, 33 | stride=stride, 34 | transpose=True), 35 | spnn.BatchNorm(outc), 36 | spnn.ReLU(True)) 37 | 38 | def forward(self, x): 39 | return self.net(x) 40 | 41 | 42 | class ResidualBlock(nn.Module): 43 | def __init__(self, inc, outc, ks=3, stride=1, dilation=1): 44 | super().__init__() 45 | self.net = nn.Sequential( 46 | spnn.Conv3d(inc, 47 | outc, 48 | kernel_size=ks, 49 | dilation=dilation, 50 | stride=stride), spnn.BatchNorm(outc), 51 | spnn.ReLU(True), 52 | spnn.Conv3d(outc, 53 | outc, 54 | kernel_size=ks, 55 | dilation=dilation, 56 | stride=1), 57 | spnn.BatchNorm(outc) 58 | ) 59 | 60 | self.downsample = nn.Sequential() if (inc == outc and stride == 1) else \ 61 | nn.Sequential( 62 | spnn.Conv3d(inc, outc, kernel_size=1, dilation=1, stride=stride), 63 | spnn.BatchNorm(outc) 64 | ) 65 | 66 | self.relu = spnn.ReLU(True) 67 | 68 | def forward(self, x): 69 | out = 
self.relu(self.net(x) + self.downsample(x)) 70 | return out 71 | 72 | 73 | class SPVCNN_CLASSIFICATION(nn.Module): 74 | def __init__(self, **kwargs): 75 | super().__init__() 76 | 77 | cr = kwargs.get('cr', 1.0) 78 | cs = [32, 32, 64, 128, 256, 256, 128, 96, 96] 79 | cs = [int(cr * x) for x in cs] 80 | 81 | if 'pres' in kwargs and 'vres' in kwargs: 82 | self.pres = kwargs['pres'] 83 | self.vres = kwargs['vres'] 84 | 85 | self.stem = nn.Sequential( 86 | spnn.Conv3d(kwargs['input_channel'], cs[0], kernel_size=3, stride=1), 87 | spnn.BatchNorm(cs[0]), 88 | spnn.ReLU(True), 89 | spnn.Conv3d(cs[0], cs[0], kernel_size=3, stride=1), 90 | spnn.BatchNorm(cs[0]), 91 | spnn.ReLU(True)) 92 | 93 | self.stage1 = nn.Sequential( 94 | BasicConvolutionBlock(cs[0], cs[0], ks=2, stride=2, dilation=1), 95 | ResidualBlock(cs[0], cs[1], ks=3, stride=1, dilation=1), 96 | ResidualBlock(cs[1], cs[1], ks=3, stride=1, dilation=1), 97 | ) 98 | 99 | self.stage2 = nn.Sequential( 100 | BasicConvolutionBlock(cs[1], cs[1], ks=2, stride=2, dilation=1), 101 | ResidualBlock(cs[1], cs[2], ks=3, stride=1, dilation=1), 102 | ResidualBlock(cs[2], cs[2], ks=3, stride=1, dilation=1), 103 | ) 104 | 105 | self.stage3 = nn.Sequential( 106 | BasicConvolutionBlock(cs[2], cs[2], ks=2, stride=2, dilation=1), 107 | ResidualBlock(cs[2], cs[3], ks=3, stride=1, dilation=1), 108 | ResidualBlock(cs[3], cs[3], ks=3, stride=1, dilation=1), 109 | ) 110 | 111 | self.stage4 = nn.Sequential( 112 | BasicConvolutionBlock(cs[3], cs[3], ks=2, stride=2, dilation=1), 113 | ResidualBlock(cs[3], cs[4], ks=3, stride=1, dilation=1), 114 | ResidualBlock(cs[4], cs[4], ks=3, stride=1, dilation=1), 115 | ) 116 | self.avg_pool = spnn.GlobalAveragePooling() 117 | self.classifier = nn.Sequential(nn.Linear(cs[4], kwargs['num_classes'])) 118 | self.point_transforms = nn.ModuleList([ 119 | nn.Sequential( 120 | nn.Linear(cs[0], cs[4]), 121 | nn.BatchNorm1d(cs[4]), 122 | nn.ReLU(True), 123 | ), 124 | ]) 125 | 126 | self.weight_initialization() 127 | self.dropout = nn.Dropout(0.3, True) 128 | 129 | def weight_initialization(self): 130 | for m in self.modules(): 131 | if isinstance(m, nn.BatchNorm1d): 132 | nn.init.constant_(m.weight, 1) 133 | nn.init.constant_(m.bias, 0) 134 | 135 | def forward(self, x): 136 | # x: SparseTensor z: PointTensor 137 | z = PointTensor(x.F, x.C.float()) 138 | 139 | x0 = initial_voxelize(z, self.pres, self.vres) 140 | 141 | x0 = self.stem(x0) 142 | z0 = voxel_to_point(x0, z, nearest=False) 143 | z0.F = z0.F 144 | 145 | x1 = point_to_voxel(x0, z0) 146 | x1 = self.stage1(x1) 147 | x2 = self.stage2(x1) 148 | x3 = self.stage3(x2) 149 | x4 = self.stage4(x3) 150 | z1 = voxel_to_point(x4, z0) 151 | z1.F = z1.F + self.point_transforms[0](z0.F) 152 | y1 = point_to_voxel(x4, z1) 153 | pool = self.avg_pool(y1) 154 | out = self.classifier(pool) 155 | return out 156 | 157 | 158 | -------------------------------------------------------------------------------- /LeRes/spvcnn_utils.py: -------------------------------------------------------------------------------- 1 | import torchsparse.nn.functional as spf 2 | from torchsparse.point_tensor import PointTensor 3 | from torchsparse.utils.kernel_region import * 4 | from torchsparse.utils.helpers import * 5 | 6 | 7 | __all__ = ['initial_voxelize', 'point_to_voxel', 'voxel_to_point'] 8 | 9 | 10 | # z: PointTensor 11 | # return: SparseTensor 12 | def initial_voxelize(z, init_res, after_res): 13 | new_float_coord = torch.cat( 14 | [(z.C[:, :3] * init_res) / after_res, z.C[:, -1].view(-1, 1)], 1) 15 | 16 | pc_hash = 
spf.sphash(torch.floor(new_float_coord).int()) 17 | sparse_hash = torch.unique(pc_hash) 18 | idx_query = spf.sphashquery(pc_hash, sparse_hash) 19 | counts = spf.spcount(idx_query.int(), len(sparse_hash)) 20 | 21 | inserted_coords = spf.spvoxelize(torch.floor(new_float_coord), idx_query, 22 | counts) 23 | inserted_coords = torch.round(inserted_coords).int() 24 | inserted_feat = spf.spvoxelize(z.F, idx_query, counts) 25 | 26 | new_tensor = SparseTensor(inserted_feat, inserted_coords, 1) 27 | new_tensor.check() 28 | z.additional_features['idx_query'][1] = idx_query 29 | z.additional_features['counts'][1] = counts 30 | z.C = new_float_coord 31 | 32 | return new_tensor 33 | 34 | 35 | # x: SparseTensor, z: PointTensor 36 | # return: SparseTensor 37 | def point_to_voxel(x, z): 38 | if z.additional_features is None or z.additional_features.get('idx_query') is None\ 39 | or z.additional_features['idx_query'].get(x.s) is None: 40 | #pc_hash = hash_gpu(torch.floor(z.C).int()) 41 | pc_hash = spf.sphash( 42 | torch.cat([ 43 | torch.floor(z.C[:, :3] / x.s).int() * x.s, 44 | z.C[:, -1].int().view(-1, 1) 45 | ], 1)) 46 | sparse_hash = spf.sphash(x.C) 47 | idx_query = spf.sphashquery(pc_hash, sparse_hash) 48 | counts = spf.spcount(idx_query.int(), x.C.shape[0]) 49 | z.additional_features['idx_query'][x.s] = idx_query 50 | z.additional_features['counts'][x.s] = counts 51 | else: 52 | idx_query = z.additional_features['idx_query'][x.s] 53 | counts = z.additional_features['counts'][x.s] 54 | 55 | inserted_feat = spf.spvoxelize(z.F, idx_query, counts) 56 | new_tensor = SparseTensor(inserted_feat, x.C, x.s) 57 | new_tensor.coord_maps = x.coord_maps 58 | new_tensor.kernel_maps = x.kernel_maps 59 | 60 | return new_tensor 61 | 62 | 63 | # x: SparseTensor, z: PointTensor 64 | # return: PointTensor 65 | def voxel_to_point(x, z, nearest=False): 66 | if z.idx_query is None or z.weights is None or z.idx_query.get( 67 | x.s) is None or z.weights.get(x.s) is None: 68 | kr = KernelRegion(2, x.s, 1) 69 | off = kr.get_kernel_offset().to(z.F.device) 70 | #old_hash = kernel_hash_gpu(torch.floor(z.C).int(), off) 71 | old_hash = spf.sphash( 72 | torch.cat([ 73 | torch.floor(z.C[:, :3] / x.s).int() * x.s, 74 | z.C[:, -1].int().view(-1, 1) 75 | ], 1), off) 76 | pc_hash = spf.sphash(x.C.to(z.F.device)) 77 | idx_query = spf.sphashquery(old_hash, pc_hash) 78 | weights = spf.calc_ti_weights(z.C, idx_query, 79 | scale=x.s).transpose(0, 1).contiguous() 80 | idx_query = idx_query.transpose(0, 1).contiguous() 81 | if nearest: 82 | weights[:, 1:] = 0. 
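# nearest mode: keep only the first neighbour's interpolation weight; the remaining voxel indices are marked invalid (-1) just below.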
83 | idx_query[:, 1:] = -1 84 | new_feat = spf.spdevoxelize(x.F, idx_query, weights) 85 | new_tensor = PointTensor(new_feat, 86 | z.C, 87 | idx_query=z.idx_query, 88 | weights=z.weights) 89 | new_tensor.additional_features = z.additional_features 90 | new_tensor.idx_query[x.s] = idx_query 91 | new_tensor.weights[x.s] = weights 92 | z.idx_query[x.s] = idx_query 93 | z.weights[x.s] = weights 94 | 95 | else: 96 | new_feat = spf.spdevoxelize(x.F, z.idx_query.get(x.s), z.weights.get(x.s)) 97 | new_tensor = PointTensor(new_feat, 98 | z.C, 99 | idx_query=z.idx_query, 100 | weights=z.weights) 101 | new_tensor.additional_features = z.additional_features 102 | 103 | return new_tensor 104 | 105 | 106 | -------------------------------------------------------------------------------- /LeRes/test_utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | from torchsparse import SparseTensor 4 | from torchsparse.utils import sparse_collate_fn, sparse_quantize 5 | from plyfile import PlyData, PlyElement 6 | import os 7 | 8 | def init_image_coor(height, width, u0=None, v0=None): 9 | u0 = width / 2.0 if u0 is None else u0 10 | v0 = height / 2.0 if v0 is None else v0 11 | 12 | x_row = np.arange(0, width) 13 | x = np.tile(x_row, (height, 1)) 14 | x = x.astype(np.float32) 15 | u_u0 = x - u0 16 | 17 | y_col = np.arange(0, height) 18 | y = np.tile(y_col, (width, 1)).T 19 | y = y.astype(np.float32) 20 | v_v0 = y - v0 21 | return u_u0, v_v0 22 | 23 | def depth_to_pcd(depth, u_u0, v_v0, f, invalid_value=0): 24 | mask_invalid = depth <= invalid_value 25 | depth[mask_invalid] = 0.0 26 | x = u_u0 / f * depth 27 | y = v_v0 / f * depth 28 | z = depth 29 | pcd = np.stack([x, y, z], axis=2) 30 | return pcd, ~mask_invalid 31 | 32 | def pcd_to_sparsetensor(pcd, mask_valid, voxel_size=0.01, num_points=100000): 33 | pcd_valid = pcd[mask_valid] 34 | block_ = pcd_valid 35 | block = np.zeros_like(block_) 36 | block[:, :3] = block_[:, :3] 37 | 38 | pc_ = np.round(block_[:, :3] / voxel_size) 39 | pc_ -= pc_.min(0, keepdims=1) 40 | feat_ = block 41 | 42 | # transfer point cloud to voxels 43 | inds = sparse_quantize(pc_, 44 | feat_, 45 | return_index=True, 46 | return_invs=False) 47 | if len(inds) > num_points: 48 | inds = np.random.choice(inds, num_points, replace=False) 49 | 50 | pc = pc_[inds] 51 | feat = feat_[inds] 52 | lidar = SparseTensor(feat, pc) 53 | feed_dict = [{'lidar': lidar}] 54 | inputs = sparse_collate_fn(feed_dict) 55 | return inputs 56 | 57 | def pcd_uv_to_sparsetensor(pcd, u_u0, v_v0, mask_valid, f= 500.0, voxel_size=0.01, mask_side=None, num_points=100000): 58 | if mask_side is not None: 59 | mask_valid = mask_valid & mask_side 60 | pcd_valid = pcd[mask_valid] 61 | u_u0_valid = u_u0[mask_valid][:, np.newaxis] / f 62 | v_v0_valid = v_v0[mask_valid][:, np.newaxis] / f 63 | 64 | block_ = np.concatenate([pcd_valid, u_u0_valid, v_v0_valid], axis=1) 65 | block = np.zeros_like(block_) 66 | block[:, :] = block_[:, :] 67 | 68 | 69 | pc_ = np.round(block_[:, :3] / voxel_size) 70 | pc_ -= pc_.min(0, keepdims=1) 71 | feat_ = block 72 | 73 | # transfer point cloud to voxels 74 | inds = sparse_quantize(pc_, 75 | feat_, 76 | return_index=True, 77 | return_invs=False) 78 | if len(inds) > num_points: 79 | inds = np.random.choice(inds, num_points, replace=False) 80 | 81 | pc = pc_[inds] 82 | feat = feat_[inds] 83 | lidar = SparseTensor(feat, pc) 84 | feed_dict = [{'lidar': lidar}] 85 | inputs = sparse_collate_fn(feed_dict) 86 | return inputs 87 | 88 | 89 
| def refine_focal_one_step(depth, focal, model, u0, v0): 90 | # reconstruct PCD from depth 91 | u_u0, v_v0 = init_image_coor(depth.shape[0], depth.shape[1], u0=u0, v0=v0) 92 | pcd, mask_valid = depth_to_pcd(depth, u_u0, v_v0, f=focal, invalid_value=0) 93 | # input for the voxelnet 94 | feed_dict = pcd_uv_to_sparsetensor(pcd, u_u0, v_v0, mask_valid, f=focal, voxel_size=0.005, mask_side=None) 95 | inputs = feed_dict['lidar'].cuda() 96 | 97 | outputs = model(inputs) 98 | return outputs 99 | 100 | def refine_shift_one_step(depth_wshift, model, focal, u0, v0): 101 | # reconstruct PCD from depth 102 | u_u0, v_v0 = init_image_coor(depth_wshift.shape[0], depth_wshift.shape[1], u0=u0, v0=v0) 103 | pcd_wshift, mask_valid = depth_to_pcd(depth_wshift, u_u0, v_v0, f=focal, invalid_value=0) 104 | # input for the voxelnet 105 | feed_dict = pcd_to_sparsetensor(pcd_wshift, mask_valid, voxel_size=0.01) 106 | inputs = feed_dict['lidar'].cuda() 107 | 108 | outputs = model(inputs) 109 | return outputs 110 | 111 | def refine_focal(depth, focal, model, u0, v0): 112 | last_scale = 1 113 | focal_tmp = np.copy(focal) 114 | for i in range(1): 115 | scale = refine_focal_one_step(depth, focal_tmp, model, u0, v0) 116 | focal_tmp = focal_tmp / scale.item() 117 | last_scale = last_scale * scale 118 | return torch.tensor([[last_scale]]) 119 | 120 | def refine_shift(depth_wshift, model, focal, u0, v0): 121 | depth_wshift_tmp = np.copy(depth_wshift) 122 | last_shift = 0 123 | for i in range(1): 124 | shift = refine_shift_one_step(depth_wshift_tmp, model, focal, u0, v0) 125 | shift = shift if shift.item() < 0.7 else torch.tensor([[0.7]]) 126 | depth_wshift_tmp -= shift.item() 127 | last_shift += shift.item() 128 | return torch.tensor([[last_shift]]) 129 | 130 | def reconstruct_3D(depth, f): 131 | """ 132 | Reconstruct depth to 3D pointcloud with the provided focal length. 133 | Return: 134 | pcd: N X 3 array, point cloud 135 | """ 136 | cu = depth.shape[1] / 2 137 | cv = depth.shape[0] / 2 138 | width = depth.shape[1] 139 | height = depth.shape[0] 140 | row = np.arange(0, width, 1) 141 | u = np.array([row for i in np.arange(height)]) 142 | col = np.arange(0, height, 1) 143 | v = np.array([col for i in np.arange(width)]) 144 | v = v.transpose(1, 0) 145 | 146 | if f > 1e5: 147 | print('Infinit focal length!!!') 148 | x = u - cu 149 | y = v - cv 150 | z = depth / depth.max() * x.max() 151 | else: 152 | x = (u - cu) * depth / f 153 | y = (v - cv) * depth / f 154 | z = depth 155 | 156 | x = np.reshape(x, (width * height, 1)).astype(np.float) 157 | y = np.reshape(y, (width * height, 1)).astype(np.float) 158 | z = np.reshape(z, (width * height, 1)).astype(np.float) 159 | pcd = np.concatenate((x, y, z), axis=1) 160 | pcd = pcd.astype(np.int) 161 | return pcd 162 | 163 | def save_point_cloud(pcd, rgb, filename, binary=True): 164 | """Save an RGB point cloud as a PLY file. 
165 | 166 | :paras 167 | @pcd: Nx3 matrix, the XYZ coordinates 168 | @rgb: NX3 matrix, the rgb colors for each 3D point 169 | """ 170 | assert pcd.shape[0] == rgb.shape[0] 171 | 172 | if rgb is None: 173 | gray_concat = np.tile(np.array([128], dtype=np.uint8), (pcd.shape[0], 3)) 174 | points_3d = np.hstack((pcd, gray_concat)) 175 | else: 176 | points_3d = np.hstack((pcd, rgb)) 177 | python_types = (float, float, float, int, int, int) 178 | npy_types = [('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), 179 | ('blue', 'u1')] 180 | if binary is True: 181 | # Format into NumPy structured array 182 | vertices = [] 183 | for row_idx in range(points_3d.shape[0]): 184 | cur_point = points_3d[row_idx] 185 | vertices.append(tuple(dtype(point) for dtype, point in zip(python_types, cur_point))) 186 | vertices_array = np.array(vertices, dtype=npy_types) 187 | el = PlyElement.describe(vertices_array, 'vertex') 188 | 189 | # Write 190 | PlyData([el]).write(filename) 191 | else: 192 | x = np.squeeze(points_3d[:, 0]) 193 | y = np.squeeze(points_3d[:, 1]) 194 | z = np.squeeze(points_3d[:, 2]) 195 | r = np.squeeze(points_3d[:, 3]) 196 | g = np.squeeze(points_3d[:, 4]) 197 | b = np.squeeze(points_3d[:, 5]) 198 | 199 | ply_head = 'ply\n' \ 200 | 'format ascii 1.0\n' \ 201 | 'element vertex %d\n' \ 202 | 'property float x\n' \ 203 | 'property float y\n' \ 204 | 'property float z\n' \ 205 | 'property uchar red\n' \ 206 | 'property uchar green\n' \ 207 | 'property uchar blue\n' \ 208 | 'end_header' % r.shape[0] 209 | # ---- Save ply data to disk 210 | np.savetxt(filename, np.column_stack((x, y, z, r, g, b)), fmt="%d %d %d %d %d %d", header=ply_head, comments='') 211 | 212 | def reconstruct_depth(depth, rgb, dir, pcd_name, focal): 213 | """ 214 | para disp: disparity, [h, w] 215 | para rgb: rgb image, [h, w, 3], in rgb format 216 | """ 217 | rgb = np.squeeze(rgb) 218 | depth = np.squeeze(depth) 219 | 220 | mask = depth < 1e-8 221 | depth[mask] = 0 222 | depth = depth / depth.max() * 10000 223 | 224 | pcd = reconstruct_3D(depth, f=focal) 225 | rgb_n = np.reshape(rgb, (-1, 3)) 226 | save_point_cloud(pcd, rgb_n, os.path.join(dir, pcd_name + '.ply')) 227 | 228 | 229 | def recover_metric_depth(pred, gt): 230 | if type(pred).__module__ == torch.__name__: 231 | pred = pred.cpu().numpy() 232 | if type(gt).__module__ == torch.__name__: 233 | gt = gt.cpu().numpy() 234 | gt = gt.squeeze() 235 | pred = pred.squeeze() 236 | mask = (gt > 1e-8) & (pred > 1e-8) 237 | 238 | gt_mask = gt[mask] 239 | pred_mask = pred[mask] 240 | a, b = np.polyfit(pred_mask, gt_mask, deg=1) 241 | pred_metric = a * pred + b 242 | return pred_metric 243 | -------------------------------------------------------------------------------- /MiDaS/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class BaseModel(torch.nn.Module): 6 | def load(self, path): 7 | """Load model from file. 
8 | 9 | Args: 10 | path (str): file path 11 | """ 12 | parameters = torch.load(path) 13 | 14 | if "optimizer" in parameters: 15 | parameters = parameters["model"] 16 | 17 | self.load_state_dict(parameters) 18 | -------------------------------------------------------------------------------- /MiDaS/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from MiDaS.hubconf import resnext101_32x8d_wsl 4 | 5 | 6 | def _make_encoder(features, use_pretrained): 7 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 8 | scratch = _make_scratch([256, 512, 1024, 2048], features) 9 | 10 | return pretrained, scratch 11 | 12 | 13 | def _make_resnet_backbone(resnet): 14 | pretrained = nn.Module() 15 | pretrained.layer1 = nn.Sequential( 16 | resnet.conv1, resnet.bn1, resnet.relu, resnet.maxpool, resnet.layer1 17 | ) 18 | 19 | pretrained.layer2 = resnet.layer2 20 | pretrained.layer3 = resnet.layer3 21 | pretrained.layer4 = resnet.layer4 22 | 23 | return pretrained 24 | 25 | 26 | def _make_pretrained_resnext101_wsl(use_pretrained): 27 | # resnet = torch.hub.load('facebookresearch/WSL-Images[:main]', 'resnext101_32x16d_wsl') 28 | resnet = resnext101_32x8d_wsl() 29 | return _make_resnet_backbone(resnet) 30 | 31 | 32 | def _make_scratch(in_shape, out_shape): 33 | scratch = nn.Module() 34 | 35 | scratch.layer1_rn = nn.Conv2d( 36 | in_shape[0], out_shape, kernel_size=3, stride=1, padding=1, bias=False 37 | ) 38 | scratch.layer2_rn = nn.Conv2d( 39 | in_shape[1], out_shape, kernel_size=3, stride=1, padding=1, bias=False 40 | ) 41 | scratch.layer3_rn = nn.Conv2d( 42 | in_shape[2], out_shape, kernel_size=3, stride=1, padding=1, bias=False 43 | ) 44 | scratch.layer4_rn = nn.Conv2d( 45 | in_shape[3], out_shape, kernel_size=3, stride=1, padding=1, bias=False 46 | ) 47 | return scratch 48 | 49 | 50 | class Interpolate(nn.Module): 51 | """Interpolation module. 52 | """ 53 | 54 | def __init__(self, scale_factor, mode): 55 | """Init. 56 | 57 | Args: 58 | scale_factor (float): scaling 59 | mode (str): interpolation mode 60 | """ 61 | super(Interpolate, self).__init__() 62 | 63 | self.interp = nn.functional.interpolate 64 | self.scale_factor = scale_factor 65 | self.mode = mode 66 | 67 | def forward(self, x): 68 | """Forward pass. 69 | 70 | Args: 71 | x (tensor): input 72 | 73 | Returns: 74 | tensor: interpolated data 75 | """ 76 | 77 | x = self.interp( 78 | x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False 79 | ) 80 | 81 | return x 82 | 83 | 84 | class ResidualConvUnit(nn.Module): 85 | """Residual convolution module. 86 | """ 87 | 88 | def __init__(self, features): 89 | """Init. 90 | 91 | Args: 92 | features (int): number of features 93 | """ 94 | super().__init__() 95 | 96 | self.conv1 = nn.Conv2d( 97 | features, features, kernel_size=3, stride=1, padding=1, bias=True 98 | ) 99 | 100 | self.conv2 = nn.Conv2d( 101 | features, features, kernel_size=3, stride=1, padding=1, bias=True 102 | ) 103 | 104 | self.relu = nn.ReLU(inplace=True) 105 | 106 | def forward(self, x): 107 | """Forward pass. 108 | 109 | Args: 110 | x (tensor): input 111 | 112 | Returns: 113 | tensor: output 114 | """ 115 | out = self.relu(x) 116 | out = self.conv1(out) 117 | out = self.relu(out) 118 | out = self.conv2(out) 119 | 120 | return out + x 121 | 122 | 123 | class FeatureFusionBlock(nn.Module): 124 | """Feature fusion block. 125 | """ 126 | 127 | def __init__(self, features): 128 | """Init. 
129 | 130 | Args: 131 | features (int): number of features 132 | """ 133 | super(FeatureFusionBlock, self).__init__() 134 | 135 | self.resConfUnit1 = ResidualConvUnit(features) 136 | self.resConfUnit2 = ResidualConvUnit(features) 137 | 138 | def forward(self, *xs): 139 | """Forward pass. 140 | 141 | Returns: 142 | tensor: output 143 | """ 144 | output = xs[0] 145 | 146 | if len(xs) == 2: 147 | output += self.resConfUnit1(xs[1]) 148 | 149 | output = self.resConfUnit2(output) 150 | 151 | output = nn.functional.interpolate( 152 | output, scale_factor=2, mode="bilinear", align_corners=True 153 | ) 154 | 155 | return output 156 | -------------------------------------------------------------------------------- /MiDaS/hubconf.py: -------------------------------------------------------------------------------- 1 | 2 | dependencies = ['torch', 'torchvision'] 3 | 4 | from torch.hub import load_state_dict_from_url 5 | from torchvision.models.resnet import ResNet, Bottleneck 6 | 7 | 8 | model_urls = { 9 | 'resnext101_32x8d': 'https://download.pytorch.org/models/ig_resnext101_32x8-c38310e5.pth', 10 | 'resnext101_32x16d': 'https://download.pytorch.org/models/ig_resnext101_32x16-c6f796b0.pth', 11 | 'resnext101_32x32d': 'https://download.pytorch.org/models/ig_resnext101_32x32-e4b90b00.pth', 12 | 'resnext101_32x48d': 'https://download.pytorch.org/models/ig_resnext101_32x48-3e41cc8a.pth', 13 | } 14 | 15 | 16 | def _resnext(arch, block, layers, pretrained, progress, **kwargs): 17 | model = ResNet(block, layers, **kwargs) 18 | state_dict = load_state_dict_from_url(model_urls[arch], progress=progress) 19 | model.load_state_dict(state_dict) 20 | return model 21 | 22 | 23 | def resnext101_32x8d_wsl(progress=True, **kwargs): 24 | """ 25 | Args: 26 | progress (bool): If True, displays a progress bar of the download to stderr. 27 | """ 28 | kwargs['groups'] = 32 29 | kwargs['width_per_group'] = 8 30 | return _resnext('resnext101_32x8d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 31 | 32 | 33 | def resnext101_32x16d_wsl(progress=True, **kwargs): 34 | """ 35 | Args: 36 | progress (bool): If True, displays a progress bar of the download to stderr. 37 | """ 38 | kwargs['groups'] = 32 39 | kwargs['width_per_group'] = 16 40 | return _resnext('resnext101_32x16d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 41 | 42 | 43 | def resnext101_32x32d_wsl(progress=True, **kwargs): 44 | """ 45 | Args: 46 | progress (bool): If True, displays a progress bar of the download to stderr. 47 | """ 48 | kwargs['groups'] = 32 49 | kwargs['width_per_group'] = 32 50 | return _resnext('resnext101_32x32d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) 51 | 52 | 53 | def resnext101_32x48d_wsl(progress=True, **kwargs): 54 | """ 55 | Args: 56 | progress (bool): If True, displays a progress bar of the download to stderr. 57 | """ 58 | kwargs['groups'] = 32 59 | kwargs['width_per_group'] = 48 60 | return _resnext('resnext101_32x48d', Bottleneck, [3, 4, 23, 3], True, progress, **kwargs) -------------------------------------------------------------------------------- /MiDaS/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 2 | """ 3 | import torch 4 | import torch.nn as nn 5 | 6 | from MiDaS.base_model import BaseModel 7 | from MiDaS.blocks import FeatureFusionBlock, Interpolate, _make_encoder 8 | 9 | 10 | class MidasNet(BaseModel): 11 | """Network for monocular depth estimation. 
12 | """ 13 | 14 | def __init__(self, path=None, features=256, non_negative=True): 15 | """Init. 16 | 17 | Args: 18 | path (str, optional): Path to saved model. Defaults to None. 19 | features (int, optional): Number of features. Defaults to 256. 20 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 21 | """ 22 | # print("Loading weights: ", path) 23 | 24 | super(MidasNet, self).__init__() 25 | 26 | use_pretrained = False if path else True 27 | 28 | self.pretrained, self.scratch = _make_encoder(features, use_pretrained) 29 | 30 | self.scratch.refinenet4 = FeatureFusionBlock(features) 31 | self.scratch.refinenet3 = FeatureFusionBlock(features) 32 | self.scratch.refinenet2 = FeatureFusionBlock(features) 33 | self.scratch.refinenet1 = FeatureFusionBlock(features) 34 | 35 | self.scratch.output_conv = nn.Sequential( 36 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 37 | Interpolate(scale_factor=2, mode="bilinear"), 38 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 39 | nn.ReLU(True), 40 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 41 | nn.ReLU(True) if non_negative else nn.Identity(), 42 | ) 43 | 44 | if path: 45 | self.load(path) 46 | 47 | def forward(self, x): 48 | """Forward pass. 49 | 50 | Args: 51 | x (tensor): input data (image) 52 | 53 | Returns: 54 | tensor: depth 55 | """ 56 | 57 | layer_1 = self.pretrained.layer1(x) 58 | layer_2 = self.pretrained.layer2(layer_1) 59 | layer_3 = self.pretrained.layer3(layer_2) 60 | layer_4 = self.pretrained.layer4(layer_3) 61 | 62 | layer_1_rn = self.scratch.layer1_rn(layer_1) 63 | layer_2_rn = self.scratch.layer2_rn(layer_2) 64 | layer_3_rn = self.scratch.layer3_rn(layer_3) 65 | layer_4_rn = self.scratch.layer4_rn(layer_4) 66 | 67 | path_4 = self.scratch.refinenet4(layer_4_rn) 68 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 69 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 70 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 71 | 72 | out = self.scratch.output_conv(path_1) 73 | 74 | return torch.squeeze(out, dim=1) 75 | -------------------------------------------------------------------------------- /MiDaS/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height). 
50 | """ 51 | 52 | def __init__( 53 | self, 54 | width, 55 | height, 56 | resize_target=True, 57 | keep_aspect_ratio=False, 58 | ensure_multiple_of=1, 59 | resize_method="lower_bound", 60 | image_interpolation_method=cv2.INTER_AREA, 61 | ): 62 | """Init. 63 | 64 | Args: 65 | width (int): desired output width 66 | height (int): desired output height 67 | resize_target (bool, optional): 68 | True: Resize the full sample (image, mask, target). 69 | False: Resize image only. 70 | Defaults to True. 71 | keep_aspect_ratio (bool, optional): 72 | True: Keep the aspect ratio of the input sample. 73 | Output sample might not have the given width and height, and 74 | resize behaviour depends on the parameter 'resize_method'. 75 | Defaults to False. 76 | ensure_multiple_of (int, optional): 77 | Output width and height is constrained to be multiple of this parameter. 78 | Defaults to 1. 79 | resize_method (str, optional): 80 | "lower_bound": Output will be at least as large as the given size. 81 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 82 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 83 | Defaults to "lower_bound". 84 | """ 85 | self.__width = width 86 | self.__height = height 87 | 88 | self.__resize_target = resize_target 89 | self.__keep_aspect_ratio = keep_aspect_ratio 90 | self.__multiple_of = ensure_multiple_of 91 | self.__resize_method = resize_method 92 | self.__image_interpolation_method = image_interpolation_method 93 | 94 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 95 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 96 | 97 | if max_val is not None and y > max_val: 98 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 99 | 100 | if y < min_val: 101 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 102 | 103 | return y 104 | 105 | def get_size(self, width, height): 106 | # determine new height and width 107 | scale_height = self.__height / height 108 | scale_width = self.__width / width 109 | 110 | if self.__keep_aspect_ratio: 111 | if self.__resize_method == "lower_bound": 112 | # scale such that output size is lower bound 113 | if scale_width > scale_height: 114 | # fit width 115 | scale_height = scale_width 116 | else: 117 | # fit height 118 | scale_width = scale_height 119 | elif self.__resize_method == "upper_bound": 120 | # scale such that output size is upper bound 121 | if scale_width < scale_height: 122 | # fit width 123 | scale_height = scale_width 124 | else: 125 | # fit height 126 | scale_width = scale_height 127 | elif self.__resize_method == "minimal": 128 | # scale as least as possbile 129 | if abs(1 - scale_width) < abs(1 - scale_height): 130 | # fit width 131 | scale_height = scale_width 132 | else: 133 | # fit height 134 | scale_width = scale_height 135 | else: 136 | raise ValueError( 137 | f"resize_method {self.__resize_method} not implemented" 138 | ) 139 | 140 | if self.__resize_method == "lower_bound": 141 | new_height = self.constrain_to_multiple_of( 142 | scale_height * height, min_val=self.__height 143 | ) 144 | new_width = self.constrain_to_multiple_of( 145 | scale_width * width, min_val=self.__width 146 | ) 147 | elif self.__resize_method == "upper_bound": 148 | new_height = self.constrain_to_multiple_of( 149 | scale_height * height, max_val=self.__height 150 | ) 151 | new_width = self.constrain_to_multiple_of( 152 | scale_width * width, max_val=self.__width 
153 | ) 154 | elif self.__resize_method == "minimal": 155 | new_height = self.constrain_to_multiple_of(scale_height * height) 156 | new_width = self.constrain_to_multiple_of(scale_width * width) 157 | else: 158 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 159 | 160 | return (new_width, new_height) 161 | 162 | def __call__(self, sample): 163 | width, height = self.get_size( 164 | sample["image"].shape[1], sample["image"].shape[0] 165 | ) 166 | 167 | # resize sample 168 | sample["image"] = cv2.resize( 169 | sample["image"], 170 | (width, height), 171 | interpolation=self.__image_interpolation_method, 172 | ) 173 | 174 | if self.__resize_target: 175 | if "disparity" in sample: 176 | sample["disparity"] = cv2.resize( 177 | sample["disparity"], 178 | (width, height), 179 | interpolation=cv2.INTER_NEAREST, 180 | ) 181 | 182 | if "depth" in sample: 183 | sample["depth"] = cv2.resize( 184 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 185 | ) 186 | 187 | sample["mask"] = cv2.resize( 188 | sample["mask"].astype(np.float32), 189 | (width, height), 190 | interpolation=cv2.INTER_NEAREST, 191 | ) 192 | sample["mask"] = sample["mask"].astype(bool) 193 | 194 | return sample 195 | 196 | 197 | class NormalizeImage(object): 198 | """Normlize image by given mean and std. 199 | """ 200 | 201 | def __init__(self, mean, std): 202 | self.__mean = mean 203 | self.__std = std 204 | 205 | def __call__(self, sample): 206 | sample["image"] = (sample["image"] - self.__mean) / self.__std 207 | 208 | return sample 209 | 210 | 211 | class PrepareForNet(object): 212 | """Prepare sample for usage as network input. 213 | """ 214 | 215 | def __init__(self): 216 | pass 217 | 218 | def __call__(self, sample): 219 | image = np.transpose(sample["image"], (2, 0, 1)) 220 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 221 | 222 | if "mask" in sample: 223 | sample["mask"] = sample["mask"].astype(np.float32) 224 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 225 | 226 | if "disparity" in sample: 227 | disparity = sample["disparity"].astype(np.float32) 228 | sample["disparity"] = np.ascontiguousarray(disparity) 229 | 230 | if "depth" in sample: 231 | depth = sample["depth"].astype(np.float32) 232 | sample["depth"] = np.ascontiguousarray(depth) 233 | 234 | return sample 235 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Multi-resolution Monocular Depth Map Fusion by Self-supervised Gradient-based Composition 2 | 3 | This repository contains code and models for our [paper](https://arxiv.org/pdf/2212.01538.pdf): 4 | 5 | > [1] Yaqiao Dai, Renjiao Yi, Chenyang Zhu, Hongjun He, Kai Xu, Multi-resolution Monocular Depth Map Fusion by Self-supervised Gradient-based Composition, AAAI 2023 6 | 7 | ![](./figures/1.gif) 8 | 9 | ![](./figures/2.gif) 10 | 11 | ![](./figures/3.gif) 12 | 13 | ### Changelog 14 | 15 | * [November 2022] Initial release of code and models 16 | 17 | ### Setup 18 | 19 | 1) Download the code. 20 | ```shell 21 | git clone https://github.com/YuiNsky/Gradient-based-depth-map-fusion.git 22 | cd Gradient-based-depth-map-fusion 23 | ``` 24 | 25 | 26 | 27 | 28 | 2. Set up dependencies: 29 | 2.1 Create conda virtual environment. 30 | 31 | ```shell 32 | conda env create -f GBDF.yaml 33 | conda activate GBDF 34 | ``` 35 | 36 | 2.2 Install pytorch in virtual environment. 
37 | ```shell 38 | pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html 39 | ``` 40 | 41 | 42 | 43 | 3. Download the fusion model [model_dict.pt](https://github.com/YuiNsky/Gradient-based-depth-map-fusion/releases/download/v1.0/model_dict.pt) and place it in the folder `models`. 44 | 45 | 46 | 47 | 48 | 4. Download one or more pretrained backbone models. 49 | 50 | LeRes: [res50.pth](https://cloudstor.aarnet.edu.au/plus/s/VVQayrMKPlpVkw9) or [res101.pth](https://cloudstor.aarnet.edu.au/plus/s/lTIJF4vrvHCAI31), place it in the folder `LeRes`. 51 | 52 | DPT: [dpt_hybrid-midas-501f0c75.pt](https://github.com/intel-isl/DPT/releases/download/1_0/dpt_hybrid-midas-501f0c75.pt), place it in the folder `dpt/weights`. 53 | 54 | SGR: [model.pth.tar](https://drive.google.com/file/d/1p8c8-nUTNry5usQmGdTC2TrwWrp3dQ0y/view?usp=sharing), place it in the folder `SGR`. 55 | 56 | MiDaS: [model.pt](https://drive.google.com/file/d/1nqW_Hwj86kslfsXR7EnXpEWdO2csz1cC), place it in the folder `MiDaS`. 57 | 58 | NeWCRFs: [model_nyu.ckpt](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/newcrfs/models/model_nyu.ckpt), place it in the folder `newcrfs`. 59 | 60 | 61 | 62 | 63 | 5. The code was tested with Python 3.8, PyTorch 1.9.1 and OpenCV 4.6.0. 64 | 65 | ### Usage 66 | 67 | 1) Place one or more input images in the folder `input`. 68 | 69 | 2) Run our model with a monocular depth estimation method: 70 | 71 | ```shell 72 | python run.py -p LeRes50 73 | ``` 74 | 75 | 76 | 3) The results are written to the folder `output`; each result combines the input image, the backbone prediction and our prediction. 77 | 78 | Use the flag `-p` to switch between different backbones. Possible options are `LeRes50` (default), `LeRes101`, `SGR`, `MiDaS`, `DPT` and `NeWCRFs`. 79 | 80 | ### Evaluation 81 | 82 | Our evaluation covers three published high-resolution datasets: Multiscopic, Middleburry2021 and Hypersim. 83 | 84 | To evaluate our model on Multiscopic, you can download this dataset [here](https://sites.google.com/view/multiscopic). You need to download the test dataset, rename it to `multiscopic` and place it in the folder `datasets`. 85 | 86 | To evaluate our model on Middleburry2021, you can download this dataset [here](https://vision.middlebury.edu/stereo/data/scenes2021/zip/all.zip). You need to unzip the dataset, rename it to `2021mobile` and place it in the folder `datasets`. 87 | 88 | To evaluate our model on Hypersim, you can download the whole dataset [here](https://github.com/apple/ml-hypersim/blob/main/code/python/tools/dataset_download_images.py). We also provide the evaluation subsets [hypersim](https://shanghaitecheducn-my.sharepoint.com/:u:/g/personal/chenky12022_shanghaitech_edu_cn/EZcASVNppkNIo34mSBiXUjAByGyg4HCEXW0voRdnmT-sQg?e=1opopk). You need to download the subsets and place them in the folder `datasets`. 89 | 90 | 91 | 92 | Then you can evaluate our fusion model with a specified monocular depth estimation method and dataset: 93 | 94 | ```shell 95 | python eval.py -p LeRes50 -d middleburry2021 96 | ``` 97 | 98 | Use the flag `-p` to switch between different backbones. Possible options are `LeRes50` (default), `SGR`, `MiDaS`, `DPT` and `NeWCRFs`. 99 | 100 | Use the flag `-d` to switch between different datasets. Possible options are `middleburry2021` (default), `multiscopic` and `hypersim`.
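For intuition, the sketch below illustrates the general idea of gradient-based composition: keep the coarse scene structure of a low-resolution depth prediction and graft on the high-frequency detail of a high-resolution prediction. It is only a simplified, hand-written illustration with a hypothetical function name; it is not the learned fusion model (`model_dict.pt`) that `run.py` and `eval.py` use.

```python
# Simplified illustration only: not the learned fusion network in this repository.
import cv2
import numpy as np

def naive_gradient_composition(depth_low, depth_high, sigma=5.0):
    """Fuse two depth maps that have already been resized to the same resolution.

    depth_low:  upsampled low-resolution prediction (reliable global structure)
    depth_high: high-resolution prediction (reliable fine detail, noisier structure)
    """
    depth_low = depth_low.astype(np.float32)
    depth_high = depth_high.astype(np.float32)
    # low frequencies (overall depth layout) come from the low-resolution prediction
    base = cv2.GaussianBlur(depth_low, (0, 0), sigma)
    # high frequencies (edges and fine gradients) come from the high-resolution prediction
    detail = depth_high - cv2.GaussianBlur(depth_high, (0, 0), sigma)
    return base + detail
```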
101 | 102 | ### Training 103 | 104 | Our model was trained with the LeRes backbone on the HR-WSI dataset. We preprocess the dataset with a guided filter and, based on Canny edge detection, select high-quality results as our training data. You need to download the preprocessed dataset [HR](https://shanghaitecheducn-my.sharepoint.com/:u:/g/personal/chenky12022_shanghaitech_edu_cn/EQWICYjodhFCjiimYsLSUDABI5-sYddf6MleupjU0RRPWQ?e=GahsRB) and place it in the folder `datasets`. 105 | 106 | Then you can train our fusion model on a GPU: 107 | 108 | ```shell 109 | python train.py 110 | ``` 111 | 112 | 113 | ### Citation 114 | 115 | Please cite our paper if you use this code. 116 | ``` 117 | @article{dai2022multi, 118 | title={Multi-resolution Monocular Depth Map Fusion by Self-supervised Gradient-based Composition}, 119 | author={Dai, Yaqiao and Yi, Renjiao and Zhu, Chenyang and He, Hongjun and Xu, Kai}, 120 | journal={arXiv preprint arXiv:2212.01538}, 121 | year={2022} 122 | } 123 | ``` 124 | 125 | ### License 126 | 127 | MIT License 128 | -------------------------------------------------------------------------------- /SGR/DepthNet.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | import torchvision 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | import torch.nn.init as init 7 | 8 | import sys 9 | 10 | 11 | import SGR.resnet as resnet 12 | 13 | from SGR.networks import * 14 | 15 | class Decoder(nn.Module): 16 | def __init__(self, inchannels = [256, 512, 1024, 2048], midchannels = [256, 256, 256, 512], upfactors = [2,2,2,2], outchannels = 1): 17 | super(Decoder, self).__init__() 18 | self.inchannels = inchannels 19 | self.midchannels = midchannels 20 | self.upfactors = upfactors 21 | self.outchannels = outchannels 22 | 23 | self.conv = FTB(inchannels=self.inchannels[3], midchannels=self.midchannels[3]) 24 | self.conv1 = nn.Conv2d(in_channels=self.midchannels[3], out_channels=self.midchannels[2], kernel_size=3, padding=1, stride=1, bias=True) 25 | self.upsample = nn.Upsample(scale_factor=self.upfactors[3], mode='bilinear', align_corners=True) 26 | 27 | self.ffm2 = FFM(inchannels=self.inchannels[2], midchannels=self.midchannels[2], outchannels = self.midchannels[2], upfactor=self.upfactors[2]) 28 | self.ffm1 = FFM(inchannels=self.inchannels[1], midchannels=self.midchannels[1], outchannels = self.midchannels[1], upfactor=self.upfactors[1]) 29 | self.ffm0 = FFM(inchannels=self.inchannels[0], midchannels=self.midchannels[0], outchannels = self.midchannels[0], upfactor=self.upfactors[0]) 30 | 31 | self.outconv = AO(inchannels=self.inchannels[0], outchannels=self.outchannels, upfactor=2) 32 | 33 | self._init_params() 34 | 35 | def _init_params(self): 36 | for m in self.modules(): 37 | if isinstance(m, nn.Conv2d): 38 | #init.kaiming_normal_(m.weight, mode='fan_out') 39 | init.normal_(m.weight, std=0.01) 40 | #init.xavier_normal_(m.weight) 41 | if m.bias is not None: 42 | init.constant_(m.bias, 0) 43 | elif isinstance(m, nn.ConvTranspose2d): 44 | #init.kaiming_normal_(m.weight, mode='fan_out') 45 | init.normal_(m.weight, std=0.01) 46 | #init.xavier_normal_(m.weight) 47 | if m.bias is not None: 48 | init.constant_(m.bias, 0) 49 | elif isinstance(m, NN.BatchNorm2d): #NN.BatchNorm2d 50 | init.constant_(m.weight, 1) 51 | init.constant_(m.bias, 0) 52 | elif isinstance(m, nn.Linear): 53 | init.normal_(m.weight, std=0.01) 54 | if m.bias is not None: 55 | init.constant_(m.bias, 0) 56 | 57 | def forward(self, features): 58 | _,_,h,w =
features[3].size() 59 | x = self.conv(features[3]) 60 | x = self.conv1(x) 61 | x = self.upsample(x) 62 | 63 | x = self.ffm2(features[2], x) 64 | x = self.ffm1(features[1], x) 65 | x = self.ffm0(features[0], x) 66 | 67 | #----------------------------------------- 68 | x = self.outconv(x) 69 | 70 | return x 71 | 72 | class DepthNet(nn.Module): 73 | __factory = { 74 | 18: resnet.resnet18, 75 | 34: resnet.resnet34, 76 | 50: resnet.resnet50, 77 | 101: resnet.resnet101, 78 | 152: resnet.resnet152 79 | } 80 | def __init__(self, 81 | backbone='resnet', 82 | depth=50, 83 | pretrained=True, 84 | inchannels=[256, 512, 1024, 2048], 85 | midchannels=[256, 256, 256, 512], 86 | upfactors=[2, 2, 2, 2], 87 | outchannels=1): 88 | super(DepthNet, self).__init__() 89 | self.backbone = backbone 90 | self.depth = depth 91 | self.pretrained = pretrained 92 | self.inchannels = inchannels 93 | self.midchannels = midchannels 94 | self.upfactors = upfactors 95 | self.outchannels = outchannels 96 | 97 | # Build model 98 | if self.depth not in DepthNet.__factory: 99 | raise KeyError("Unsupported depth:", self.depth) 100 | self.encoder = DepthNet.__factory[depth](pretrained=pretrained) 101 | 102 | self.decoder = Decoder(inchannels=self.inchannels, midchannels=self.midchannels, upfactors=self.upfactors, outchannels=self.outchannels) 103 | 104 | def forward(self, x): 105 | x = self.encoder(x) # 1/4, 1/8, 1/16, 1/32 106 | x = self.decoder(x) 107 | 108 | return x 109 | 110 | if __name__ == '__main__': 111 | net = DepthNet(depth=50, pretrained=True) 112 | print(net) 113 | inputs = torch.ones(4,3,128,128) 114 | out = net(inputs) 115 | print(out.size()) 116 | -------------------------------------------------------------------------------- /SGR/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.init as init 4 | import sys 5 | import torch.nn as NN 6 | 7 | 8 | class FTB(nn.Module): 9 | def __init__(self, inchannels, midchannels=512): 10 | super(FTB, self).__init__() 11 | self.in1 = inchannels 12 | self.mid = midchannels 13 | 14 | self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True) 15 | # NN.BatchNorm2d 16 | self.conv_branch = nn.Sequential(nn.ReLU(inplace=True),\ 17 | nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),\ 18 | NN.BatchNorm2d(num_features=self.mid),\ 19 | nn.ReLU(inplace=True),\ 20 | nn.Conv2d(in_channels=self.mid, out_channels= self.mid, kernel_size=3, padding=1, stride=1, bias=True)) 21 | self.relu = nn.ReLU(inplace=True) 22 | 23 | self.init_params() 24 | 25 | def forward(self, x): 26 | x = self.conv1(x) 27 | x = x + self.conv_branch(x) 28 | x = self.relu(x) 29 | 30 | return x 31 | 32 | def init_params(self): 33 | for m in self.modules(): 34 | if isinstance(m, nn.Conv2d): 35 | #init.kaiming_normal_(m.weight, mode='fan_out') 36 | init.normal_(m.weight, std=0.01) 37 | # init.xavier_normal_(m.weight) 38 | if m.bias is not None: 39 | init.constant_(m.bias, 0) 40 | elif isinstance(m, nn.ConvTranspose2d): 41 | #init.kaiming_normal_(m.weight, mode='fan_out') 42 | init.normal_(m.weight, std=0.01) 43 | # init.xavier_normal_(m.weight) 44 | if m.bias is not None: 45 | init.constant_(m.bias, 0) 46 | elif isinstance(m, NN.BatchNorm2d): #NN.BatchNorm2d 47 | init.constant_(m.weight, 1) 48 | init.constant_(m.bias, 0) 49 | elif isinstance(m, nn.Linear): 50 | init.normal_(m.weight, std=0.01) 51 | if m.bias is not 
None: 52 | init.constant_(m.bias, 0) 53 | 54 | 55 | class FFM(nn.Module): 56 | def __init__(self, inchannels, midchannels, outchannels, upfactor=2): 57 | super(FFM, self).__init__() 58 | self.inchannels = inchannels 59 | self.midchannels = midchannels 60 | self.outchannels = outchannels 61 | self.upfactor = upfactor 62 | 63 | self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels) 64 | self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels) 65 | 66 | self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) 67 | 68 | self.init_params() 69 | 70 | def forward(self, low_x, high_x): 71 | x = self.ftb1(low_x) 72 | x = x + high_x 73 | x = self.ftb2(x) 74 | x = self.upsample(x) 75 | 76 | return x 77 | 78 | def init_params(self): 79 | for m in self.modules(): 80 | if isinstance(m, nn.Conv2d): 81 | #init.kaiming_normal_(m.weight, mode='fan_out') 82 | init.normal_(m.weight, std=0.01) 83 | #init.xavier_normal_(m.weight) 84 | if m.bias is not None: 85 | init.constant_(m.bias, 0) 86 | elif isinstance(m, nn.ConvTranspose2d): 87 | #init.kaiming_normal_(m.weight, mode='fan_out') 88 | init.normal_(m.weight, std=0.01) 89 | #init.xavier_normal_(m.weight) 90 | if m.bias is not None: 91 | init.constant_(m.bias, 0) 92 | elif isinstance(m, NN.BatchNorm2d): #NN.Batchnorm2d 93 | init.constant_(m.weight, 1) 94 | init.constant_(m.bias, 0) 95 | elif isinstance(m, nn.Linear): 96 | init.normal_(m.weight, std=0.01) 97 | if m.bias is not None: 98 | init.constant_(m.bias, 0) 99 | 100 | 101 | class AO(nn.Module): 102 | # Adaptive output module 103 | def __init__(self, inchannels, outchannels, upfactor=2): 104 | super(AO, self).__init__() 105 | self.inchannels = inchannels 106 | self.outchannels = outchannels 107 | self.upfactor = upfactor 108 | 109 | self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\ 110 | NN.BatchNorm2d(num_features=self.inchannels//2),\ 111 | nn.ReLU(inplace=True),\ 112 | nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True),\ 113 | nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)) 114 | 115 | self.init_params() 116 | 117 | def forward(self, x): 118 | x = self.adapt_conv(x) 119 | return x 120 | 121 | def init_params(self): 122 | for m in self.modules(): 123 | if isinstance(m, nn.Conv2d): 124 | #init.kaiming_normal_(m.weight, mode='fan_out') 125 | init.normal_(m.weight, std=0.01) 126 | #init.xavier_normal_(m.weight) 127 | if m.bias is not None: 128 | init.constant_(m.bias, 0) 129 | elif isinstance(m, nn.ConvTranspose2d): 130 | #init.kaiming_normal_(m.weight, mode='fan_out') 131 | init.normal_(m.weight, std=0.01) 132 | #init.xavier_normal_(m.weight) 133 | if m.bias is not None: 134 | init.constant_(m.bias, 0) 135 | elif isinstance(m, NN.BatchNorm2d): #NN.Batchnorm2d 136 | init.constant_(m.weight, 1) 137 | init.constant_(m.bias, 0) 138 | elif isinstance(m, nn.Linear): 139 | init.normal_(m.weight, std=0.01) 140 | if m.bias is not None: 141 | init.constant_(m.bias, 0) 142 | -------------------------------------------------------------------------------- /SGR/resnet.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import math 3 | import torch.utils.model_zoo as model_zoo 4 | import torchvision 5 | 6 | import sys 7 | from torch import nn as NN 8 | 9 | 10 | __all__ = ['ResNet', 
'resnet18', 'resnet34', 'resnet50', 'resnet101', 11 | 'resnet152'] 12 | 13 | 14 | model_urls = { 15 | 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 16 | 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 17 | 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 18 | 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 19 | 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth', 20 | } 21 | 22 | 23 | def conv3x3(in_planes, out_planes, stride=1): 24 | """3x3 convolution with padding""" 25 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 26 | padding=1, bias=False) 27 | 28 | 29 | class BasicBlock(nn.Module): 30 | expansion = 1 31 | 32 | def __init__(self, inplanes, planes, stride=1, downsample=None): 33 | super(BasicBlock, self).__init__() 34 | self.conv1 = conv3x3(inplanes, planes, stride) 35 | self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d 36 | self.relu = nn.ReLU(inplace=True) 37 | self.conv2 = conv3x3(planes, planes) 38 | self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d 39 | self.downsample = downsample 40 | self.stride = stride 41 | 42 | def forward(self, x): 43 | residual = x 44 | 45 | out = self.conv1(x) 46 | out = self.bn1(out) 47 | out = self.relu(out) 48 | 49 | out = self.conv2(out) 50 | out = self.bn2(out) 51 | 52 | if self.downsample is not None: 53 | residual = self.downsample(x) 54 | 55 | out += residual 56 | out = self.relu(out) 57 | 58 | return out 59 | 60 | 61 | class Bottleneck(nn.Module): 62 | expansion = 4 63 | 64 | def __init__(self, inplanes, planes, stride=1, downsample=None): 65 | super(Bottleneck, self).__init__() 66 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 67 | self.bn1 = NN.BatchNorm2d(planes) #NN.BatchNorm2d 68 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride, 69 | padding=1, bias=False) 70 | self.bn2 = NN.BatchNorm2d(planes) #NN.BatchNorm2d 71 | self.conv3 = nn.Conv2d(planes, planes * self.expansion, kernel_size=1, bias=False) 72 | self.bn3 = NN.BatchNorm2d(planes * self.expansion) #NN.BatchNorm2d 73 | self.relu = nn.ReLU(inplace=True) 74 | self.downsample = downsample 75 | self.stride = stride 76 | 77 | def forward(self, x): 78 | residual = x 79 | 80 | out = self.conv1(x) 81 | out = self.bn1(out) 82 | out = self.relu(out) 83 | 84 | out = self.conv2(out) 85 | out = self.bn2(out) 86 | out = self.relu(out) 87 | 88 | out = self.conv3(out) 89 | out = self.bn3(out) 90 | 91 | if self.downsample is not None: 92 | residual = self.downsample(x) 93 | 94 | out += residual 95 | out = self.relu(out) 96 | 97 | return out 98 | 99 | 100 | class ResNet(nn.Module): 101 | 102 | def __init__(self, block, layers, num_classes=1000): 103 | self.inplanes = 64 104 | super(ResNet, self).__init__() 105 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 106 | bias=False) 107 | self.bn1 = NN.BatchNorm2d(64) #NN.BatchNorm2d 108 | self.relu = nn.ReLU(inplace=True) 109 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 110 | self.layer1 = self._make_layer(block, 64, layers[0]) 111 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2) 112 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2) 113 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2) 114 | #self.avgpool = nn.AvgPool2d(7, stride=1) 115 | #self.fc = nn.Linear(512 * block.expansion, num_classes) 116 | 117 | for m in self.modules(): 118 | if isinstance(m, nn.Conv2d): 119 | 
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 120 | elif isinstance(m, nn.BatchNorm2d): 121 | nn.init.constant_(m.weight, 1) 122 | nn.init.constant_(m.bias, 0) 123 | 124 | def _make_layer(self, block, planes, blocks, stride=1): 125 | downsample = None 126 | if stride != 1 or self.inplanes != planes * block.expansion: 127 | downsample = nn.Sequential( 128 | nn.Conv2d(self.inplanes, planes * block.expansion, 129 | kernel_size=1, stride=stride, bias=False), 130 | NN.BatchNorm2d(planes * block.expansion), #NN.BatchNorm2d 131 | ) 132 | 133 | layers = [] 134 | layers.append(block(self.inplanes, planes, stride, downsample)) 135 | self.inplanes = planes * block.expansion 136 | for i in range(1, blocks): 137 | layers.append(block(self.inplanes, planes)) 138 | 139 | return nn.Sequential(*layers) 140 | 141 | def forward(self, x): 142 | features = [] 143 | 144 | x = self.conv1(x) 145 | x = self.bn1(x) 146 | x = self.relu(x) 147 | x = self.maxpool(x) 148 | 149 | x = self.layer1(x) 150 | features.append(x) 151 | x = self.layer2(x) 152 | features.append(x) 153 | x = self.layer3(x) 154 | features.append(x) 155 | x = self.layer4(x) 156 | features.append(x) 157 | 158 | return features 159 | 160 | 161 | def resnet18(pretrained=True, **kwargs): 162 | """Constructs a ResNet-18 model. 163 | Args: 164 | pretrained (bool): If True, returns a model pre-trained on ImageNet 165 | """ 166 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 167 | if pretrained: 168 | pretrained_model = torchvision.models.resnet18(pretrained=True) 169 | pretrained_dict = pretrained_model.state_dict() 170 | model_dict = model.state_dict() 171 | pretrained_dict = {k:v for k, v in pretrained_dict.items() if k in model_dict} 172 | model_dict.update(pretrained_dict) 173 | model.load_state_dict(model_dict) 174 | 175 | return model 176 | 177 | 178 | def resnet34(pretrained=True, **kwargs): 179 | """Constructs a ResNet-34 model. 180 | Args: 181 | pretrained (bool): If True, returns a model pre-trained on ImageNet 182 | """ 183 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 184 | if pretrained: 185 | pretrained_model = torchvision.models.resnet34(pretrained=True) 186 | pretrained_dict = pretrained_model.state_dict() 187 | model_dict = model.state_dict() 188 | pretrained_dict = {k:v for k, v in pretrained_dict.items() if k in model_dict} 189 | model_dict.update(pretrained_dict) 190 | model.load_state_dict(model_dict) 191 | 192 | return model 193 | 194 | 195 | def resnet50(pretrained=True, **kwargs): 196 | """Constructs a ResNet-50 model. 197 | Args: 198 | pretrained (bool): If True, returns a model pre-trained on ImageNet 199 | """ 200 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 201 | if pretrained: 202 | pretrained_model = torchvision.models.resnet50(pretrained=True) 203 | pretrained_dict = pretrained_model.state_dict() 204 | model_dict = model.state_dict() 205 | pretrained_dict = {k:v for k, v in pretrained_dict.items() if k in model_dict} 206 | model_dict.update(pretrained_dict) 207 | model.load_state_dict(model_dict) 208 | 209 | return model 210 | 211 | 212 | def resnet101(pretrained=True, **kwargs): 213 | """Constructs a ResNet-101 model. 
214 | Args: 215 | pretrained (bool): If True, returns a model pre-trained on ImageNet 216 | """ 217 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 218 | if pretrained: 219 | pretrained_model = torchvision.models.resnet101(pretrained=True) 220 | pretrained_dict = pretrained_model.state_dict() 221 | model_dict = model.state_dict() 222 | pretrained_dict = {k:v for k, v in pretrained_dict.items() if k in model_dict} 223 | model_dict.update(pretrained_dict) 224 | model.load_state_dict(model_dict) 225 | 226 | return model 227 | 228 | 229 | def resnet152(pretrained=True, **kwargs): 230 | """Constructs a ResNet-152 model. 231 | Args: 232 | pretrained (bool): If True, returns a model pre-trained on ImageNet 233 | """ 234 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 235 | if pretrained: 236 | pretrained_model = torchvision.models.resnet152(pretrained=True) 237 | pretrained_dict = pretrained_model.state_dict() 238 | model_dict = model.state_dict() 239 | pretrained_dict = {k:v for k, v in pretrained_dict.items() if k in model_dict} 240 | model_dict.update(pretrained_dict) 241 | model.load_state_dict(model_dict) 242 | 243 | return model 244 | -------------------------------------------------------------------------------- /SGR/syncbn/make_ext.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 3 | PYTHON_CMD=${PYTHON_CMD:=python} 4 | CUDA_PATH=/usr/local/cuda-8.0 5 | CUDA_INCLUDE_DIR=/usr/local/cuda-8.0/include 6 | GENCODE="-gencode arch=compute_61,code=sm_61 \ 7 | -gencode arch=compute_52,code=sm_52 \ 8 | -gencode arch=compute_52,code=compute_52" 9 | NVCCOPT="-std=c++11 -x cu --expt-extended-lambda -O3 -Xcompiler -fPIC" 10 | 11 | ROOTDIR=$PWD 12 | echo "========= Build BatchNorm2dSync =========" 13 | if [ -z "$1" ]; then TORCH=$($PYTHON_CMD -c "import os; import torch; print(os.path.dirname(torch.__file__))"); else TORCH="$1"; fi 14 | cd modules/functional/_syncbn/src 15 | $CUDA_PATH/bin/nvcc -c -o syncbn.cu.o syncbn.cu $NVCCOPT $GENCODE -I $CUDA_INCLUDE_DIR 16 | cd ../ 17 | $PYTHON_CMD build.py 18 | cd $ROOTDIR 19 | 20 | # END 21 | echo "========= Build Complete =========" 22 | -------------------------------------------------------------------------------- /SGR/syncbn/modules/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/SGR/syncbn/modules/__init__.py -------------------------------------------------------------------------------- /SGR/syncbn/modules/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/SGR/syncbn/modules/__init__.pyc -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncbn import batchnorm2d_sync 2 | -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/SGR/syncbn/modules/functional/__init__.pyc 
-------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/_syncbn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/SGR/syncbn/modules/functional/_syncbn/__init__.py -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/_syncbn/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/SGR/syncbn/modules/functional/_syncbn/__init__.pyc -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/_syncbn/_ext/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/SGR/syncbn/modules/functional/_syncbn/_ext/__init__.py -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/_syncbn/_ext/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/SGR/syncbn/modules/functional/_syncbn/_ext/__init__.pyc -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/_syncbn/_ext/syncbn/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._syncbn import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/_syncbn/_ext/syncbn/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/SGR/syncbn/modules/functional/_syncbn/_ext/syncbn/__init__.pyc -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/_syncbn/_ext/syncbn/_syncbn.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/SGR/syncbn/modules/functional/_syncbn/_ext/syncbn/_syncbn.so -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/_syncbn/build.py: -------------------------------------------------------------------------------- 1 | import os 2 | from torch.utils.ffi import create_extension 3 | 4 | sources = ['src/syncbn.cpp'] 5 | headers = ['src/syncbn.h'] 6 | extra_objects = ['src/syncbn.cu.o'] 7 | with_cuda = True 8 | 9 | this_file = os.path.dirname(os.path.realpath(__file__)) 10 | extra_objects = [os.path.join(this_file, fname) for fname in extra_objects] 11 | 12 | ffi = create_extension( 13 | 
'_ext.syncbn', 14 | headers=headers, 15 | sources=sources, 16 | relative_to=__file__, 17 | with_cuda=with_cuda, 18 | extra_objects=extra_objects, 19 | extra_compile_args=["-std=c++11"] 20 | ) 21 | 22 | if __name__ == '__main__': 23 | ffi.build() 24 | -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/_syncbn/src/common.h: -------------------------------------------------------------------------------- 1 | #ifndef __COMMON__ 2 | #define __COMMON__ 3 | #include 4 | 5 | /* 6 | * General settings 7 | */ 8 | const int WARP_SIZE = 32; 9 | const int MAX_BLOCK_SIZE = 512; 10 | 11 | /* 12 | * Utility functions 13 | */ 14 | template 15 | __device__ __forceinline__ T WARP_SHFL_XOR( 16 | T value, int laneMask, int width = warpSize, unsigned int mask = 0xffffffff) { 17 | #if CUDART_VERSION >= 9000 18 | return __shfl_xor_sync(mask, value, laneMask, width); 19 | #else 20 | return __shfl_xor(value, laneMask, width); 21 | #endif 22 | } 23 | 24 | __device__ __forceinline__ int getMSB(int val) { return 31 - __clz(val); } 25 | 26 | static int getNumThreads(int nElem) { 27 | int threadSizes[5] = {32, 64, 128, 256, MAX_BLOCK_SIZE}; 28 | for (int i = 0; i != 5; ++i) { 29 | if (nElem <= threadSizes[i]) { 30 | return threadSizes[i]; 31 | } 32 | } 33 | return MAX_BLOCK_SIZE; 34 | } 35 | 36 | 37 | #endif -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/_syncbn/src/syncbn.cpp: -------------------------------------------------------------------------------- 1 | // All functions assume that input and output tensors are already initialized 2 | // and have the correct dimensions 3 | #include 4 | 5 | extern THCState *state; 6 | 7 | void get_sizes(const THCudaTensor *t, int *N, int *C, int *S) { 8 | // Get sizes 9 | *S = 1; 10 | *N = THCudaTensor_size(state, t, 0); 11 | *C = THCudaTensor_size(state, t, 1); 12 | if (THCudaTensor_nDimension(state, t) > 2) { 13 | for (int i = 2; i < THCudaTensor_nDimension(state, t); ++i) { 14 | *S *= THCudaTensor_size(state, t, i); 15 | } 16 | } 17 | } 18 | 19 | // Forward definition of implementation functions 20 | extern "C" { 21 | int _syncbn_sum_sqsum_cuda(int N, int C, int S, 22 | const float *x, float *sum, float *sqsum, 23 | cudaStream_t stream); 24 | int _syncbn_forward_cuda( 25 | int N, int C, int S, float *z, const float *x, 26 | const float *gamma, const float *beta, 27 | const float *mean, const float *var, float eps, cudaStream_t stream); 28 | int _syncbn_backward_xhat_cuda( 29 | int N, int C, int S, const float *dz, const float *x, 30 | const float *mean, const float *var, float *sum_dz, float *sum_dz_xhat, 31 | float eps, cudaStream_t stream); 32 | int _syncbn_backward_cuda( 33 | int N, int C, int S, const float *dz, const float *x, 34 | const float *gamma, const float *beta, 35 | const float *mean, const float *var, 36 | const float *sum_dz, const float *sum_dz_xhat, 37 | float *dx, float *dgamma, float *dbeta, 38 | float eps, cudaStream_t stream); 39 | } 40 | 41 | extern "C" int syncbn_sum_sqsum_cuda( 42 | const THCudaTensor *x, THCudaTensor *sum, THCudaTensor *sqsum) { 43 | cudaStream_t stream = THCState_getCurrentStream(state); 44 | 45 | int S, N, C; 46 | get_sizes(x, &N, &C, &S); 47 | 48 | // Get pointers 49 | const float *x_data = THCudaTensor_data(state, x); 50 | float *sum_data = THCudaTensor_data(state, sum); 51 | float *sqsum_data = THCudaTensor_data(state, sqsum); 52 | 53 | return _syncbn_sum_sqsum_cuda(N, C, S, x_data, sum_data, sqsum_data, 
stream); 54 | } 55 | 56 | extern "C" int syncbn_forward_cuda( 57 | THCudaTensor *z, const THCudaTensor *x, 58 | const THCudaTensor *gamma, const THCudaTensor *beta, 59 | const THCudaTensor *mean, const THCudaTensor *var, float eps){ 60 | cudaStream_t stream = THCState_getCurrentStream(state); 61 | 62 | int S, N, C; 63 | get_sizes(x, &N, &C, &S); 64 | 65 | // Get pointers 66 | float *z_data = THCudaTensor_data(state, z); 67 | const float *x_data = THCudaTensor_data(state, x); 68 | const float *gamma_data = THCudaTensor_nDimension(state, gamma) != 0 ? 69 | THCudaTensor_data(state, gamma) : 0; 70 | const float *beta_data = THCudaTensor_nDimension(state, beta) != 0 ? 71 | THCudaTensor_data(state, beta) : 0; 72 | const float *mean_data = THCudaTensor_data(state, mean); 73 | const float *var_data = THCudaTensor_data(state, var); 74 | 75 | return _syncbn_forward_cuda( 76 | N, C, S, z_data, x_data, gamma_data, beta_data, 77 | mean_data, var_data, eps, stream); 78 | 79 | } 80 | 81 | extern "C" int syncbn_backward_xhat_cuda( 82 | const THCudaTensor *dz, const THCudaTensor *x, 83 | const THCudaTensor *mean, const THCudaTensor *var, 84 | THCudaTensor *sum_dz, THCudaTensor *sum_dz_xhat, float eps) { 85 | cudaStream_t stream = THCState_getCurrentStream(state); 86 | 87 | int S, N, C; 88 | get_sizes(dz, &N, &C, &S); 89 | 90 | // Get pointers 91 | const float *dz_data = THCudaTensor_data(state, dz); 92 | const float *x_data = THCudaTensor_data(state, x); 93 | const float *mean_data = THCudaTensor_data(state, mean); 94 | const float *var_data = THCudaTensor_data(state, var); 95 | float *sum_dz_data = THCudaTensor_data(state, sum_dz); 96 | float *sum_dz_xhat_data = THCudaTensor_data(state, sum_dz_xhat); 97 | 98 | return _syncbn_backward_xhat_cuda( 99 | N, C, S, dz_data, x_data, mean_data, var_data, 100 | sum_dz_data, sum_dz_xhat_data, eps, stream); 101 | 102 | } 103 | extern "C" int syncbn_backard_cuda( 104 | const THCudaTensor *dz, const THCudaTensor *x, 105 | const THCudaTensor *gamma, const THCudaTensor *beta, 106 | const THCudaTensor *mean, const THCudaTensor *var, 107 | const THCudaTensor *sum_dz, const THCudaTensor *sum_dz_xhat, 108 | THCudaTensor *dx, THCudaTensor *dgamma, THCudaTensor *dbeta, float eps) { 109 | cudaStream_t stream = THCState_getCurrentStream(state); 110 | 111 | int S, N, C; 112 | get_sizes(dz, &N, &C, &S); 113 | 114 | // Get pointers 115 | const float *dz_data = THCudaTensor_data(state, dz); 116 | const float *x_data = THCudaTensor_data(state, x); 117 | const float *gamma_data = THCudaTensor_nDimension(state, gamma) != 0 ? 118 | THCudaTensor_data(state, gamma) : 0; 119 | const float *beta_data = THCudaTensor_nDimension(state, beta) != 0 ? 120 | THCudaTensor_data(state, beta) : 0; 121 | const float *mean_data = THCudaTensor_data(state, mean); 122 | const float *var_data = THCudaTensor_data(state, var); 123 | const float *sum_dz_data = THCudaTensor_data(state, sum_dz); 124 | const float *sum_dz_xhat_data = THCudaTensor_data(state, sum_dz_xhat); 125 | float *dx_data = THCudaTensor_nDimension(state, dx) != 0 ? 126 | THCudaTensor_data(state, dx) : 0; 127 | float *dgamma_data = THCudaTensor_nDimension(state, dgamma) != 0 ? 128 | THCudaTensor_data(state, dgamma) : 0; 129 | float *dbeta_data = THCudaTensor_nDimension(state, dbeta) != 0 ? 
130 | THCudaTensor_data(state, dbeta) : 0; 131 | 132 | return _syncbn_backward_cuda( 133 | N, C, S, dz_data, x_data, gamma_data, beta_data, 134 | mean_data, var_data, sum_dz_data, sum_dz_xhat_data, 135 | dx_data, dgamma_data, dbeta_data, eps, stream); 136 | } -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/_syncbn/src/syncbn.cu.h: -------------------------------------------------------------------------------- 1 | #ifndef __SYNCBN__ 2 | #define __SYNCBN__ 3 | 4 | /* 5 | * Exported functions 6 | */ 7 | extern "C" int _syncbn_sum_sqsum_cuda(int N, int C, int S, const float *x, 8 | float *sum, float *sqsum, 9 | cudaStream_t stream); 10 | extern "C" int _syncbn_forward_cuda( 11 | int N, int C, int S, float *z, const float *x, 12 | const float *gamma, const float *beta, const float *mean, const float *var, 13 | float eps, cudaStream_t stream); 14 | extern "C" int _syncbn_backward_xhat_cuda( 15 | int N, int C, int S, const float *dz, const float *x, 16 | const float *mean, const float *var, float *sum_dz, float *sum_dz_xhat, 17 | float eps, cudaStream_t stream); 18 | extern "C" int _syncbn_backward_cuda( 19 | int N, int C, int S, const float *dz, const float *x, 20 | const float *gamma, const float *beta, const float *mean, const float *var, 21 | const float *sum_dz, const float *sum_dz_xhat, 22 | float *dx, float *dweight, float *dbias, 23 | float eps, cudaStream_t stream); 24 | 25 | 26 | #endif 27 | -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/_syncbn/src/syncbn.cu.o: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/SGR/syncbn/modules/functional/_syncbn/src/syncbn.cu.o -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/_syncbn/src/syncbn.h: -------------------------------------------------------------------------------- 1 | int syncbn_sum_sqsum_cuda( 2 | const THCudaTensor *x, THCudaTensor *sum, THCudaTensor *sqsum); 3 | int syncbn_forward_cuda( 4 | THCudaTensor *z, const THCudaTensor *x, 5 | const THCudaTensor *gamma, const THCudaTensor *beta, 6 | const THCudaTensor *mean, const THCudaTensor *var, float eps); 7 | int syncbn_backward_xhat_cuda( 8 | const THCudaTensor *dz, const THCudaTensor *x, 9 | const THCudaTensor *mean, const THCudaTensor *var, 10 | THCudaTensor *sum_dz, THCudaTensor *sum_dz_xhat, 11 | float eps); 12 | int syncbn_backard_cuda( 13 | const THCudaTensor *dz, const THCudaTensor *x, 14 | const THCudaTensor *gamma, const THCudaTensor *beta, 15 | const THCudaTensor *mean, const THCudaTensor *var, 16 | const THCudaTensor *sum_dz, const THCudaTensor *sum_dz_xhat, 17 | THCudaTensor *dx, THCudaTensor *dgamma, THCudaTensor *dbeta, float eps); 18 | -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/syncbn.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | BatchNorm2dSync with multi-gpu 5 | 6 | 7 | /*****************************************************************************/ 8 | """ 9 | from __future__ import absolute_import 10 | from __future__ import division 11 | from __future__ import print_function 12 | 13 | import torch.cuda.comm as comm 14 | from 
torch.autograd import Function 15 | from torch.autograd.function import once_differentiable 16 | 17 | from ._syncbn._ext import syncbn as _lib_bn 18 | 19 | 20 | def _count_samples(x): 21 | count = 1 22 | for i, s in enumerate(x.size()): 23 | if i != 1: 24 | count *= s 25 | return count 26 | 27 | 28 | def _check_contiguous(*args): 29 | if not all([mod is None or mod.is_contiguous() for mod in args]): 30 | raise ValueError("Non-contiguous input") 31 | 32 | 33 | class BatchNorm2dSyncFunc(Function): 34 | 35 | @classmethod 36 | def forward(cls, ctx, x, weight, bias, running_mean, running_var, 37 | extra, compute_stats=True, momentum=0.1, eps=1e-05): 38 | # Save context 39 | if extra is not None: 40 | cls._parse_extra(ctx, extra) 41 | ctx.compute_stats = compute_stats 42 | ctx.momentum = momentum 43 | ctx.eps = eps 44 | if ctx.compute_stats: 45 | N = _count_samples(x) * (ctx.master_queue.maxsize + 1) 46 | assert N > 1 47 | num_features = running_mean.size(0) 48 | # 1. compute sum(x) and sum(x^2) 49 | xsum = x.new().resize_(num_features) 50 | xsqsum = x.new().resize_(num_features) 51 | _check_contiguous(x, xsum, xsqsum) 52 | _lib_bn.syncbn_sum_sqsum_cuda(x.detach(), xsum, xsqsum) 53 | if ctx.is_master: 54 | xsums, xsqsums = [xsum], [xsqsum] 55 | # master : gatther all sum(x) and sum(x^2) from slaves 56 | for _ in range(ctx.master_queue.maxsize): 57 | xsum_w, xsqsum_w = ctx.master_queue.get() 58 | ctx.master_queue.task_done() 59 | xsums.append(xsum_w) 60 | xsqsums.append(xsqsum_w) 61 | xsum = comm.reduce_add(xsums) 62 | xsqsum = comm.reduce_add(xsqsums) 63 | mean = xsum / N 64 | sumvar = xsqsum - xsum * mean 65 | var = sumvar / N 66 | uvar = sumvar / (N - 1) 67 | # master : broadcast global mean, variance to all slaves 68 | tensors = comm.broadcast_coalesced( 69 | (mean, uvar, var), [mean.get_device()] + ctx.worker_ids) 70 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 71 | queue.put(ts) 72 | else: 73 | # slave : send sum(x) and sum(x^2) to master 74 | ctx.master_queue.put((xsum, xsqsum)) 75 | # slave : get global mean and variance 76 | mean, uvar, var = ctx.worker_queue.get() 77 | ctx.worker_queue.task_done() 78 | 79 | # Update running stats 80 | running_mean.mul_((1 - ctx.momentum)).add_(ctx.momentum * mean) 81 | running_var.mul_((1 - ctx.momentum)).add_(ctx.momentum * uvar) 82 | ctx.N = N 83 | ctx.save_for_backward(x, weight, bias, mean, var) 84 | else: 85 | mean, var = running_mean, running_var 86 | 87 | output = x.new().resize_as_(x) 88 | _check_contiguous(output, x, mean, var, weight, bias) 89 | # do batch norm forward 90 | _lib_bn.syncbn_forward_cuda( 91 | output, x, weight if weight is not None else x.new(), 92 | bias if bias is not None else x.new(), mean, var, ctx.eps) 93 | return output 94 | 95 | @staticmethod 96 | @once_differentiable 97 | def backward(ctx, dz): 98 | x, weight, bias, mean, var = ctx.saved_tensors 99 | dz = dz.contiguous() 100 | if ctx.needs_input_grad[0]: 101 | dx = dz.new().resize_as_(dz) 102 | else: 103 | dx = None 104 | if ctx.needs_input_grad[1]: 105 | dweight = dz.new().resize_as_(mean).zero_() 106 | else: 107 | dweight = None 108 | if ctx.needs_input_grad[2]: 109 | dbias = dz.new().resize_as_(mean).zero_() 110 | else: 111 | dbias = None 112 | _check_contiguous(x, dz, weight, bias, mean, var) 113 | 114 | # 1. 
compute \sum(\frac{dJ}{dy_i}) and \sum(\frac{dJ}{dy_i}*\hat{x_i}) 115 | num_features = mean.size(0) 116 | sum_dz = x.new().resize_(num_features) 117 | sum_dz_xhat = x.new().resize_(num_features) 118 | _check_contiguous(sum_dz, sum_dz_xhat) 119 | _lib_bn.syncbn_backward_xhat_cuda( 120 | dz, x, mean, var, sum_dz, sum_dz_xhat, ctx.eps) 121 | if ctx.is_master: 122 | sum_dzs, sum_dz_xhats = [sum_dz], [sum_dz_xhat] 123 | # master : gatther from slaves 124 | for _ in range(ctx.master_queue.maxsize): 125 | sum_dz_w, sum_dz_xhat_w = ctx.master_queue.get() 126 | ctx.master_queue.task_done() 127 | sum_dzs.append(sum_dz_w) 128 | sum_dz_xhats.append(sum_dz_xhat_w) 129 | # master : compute global stats 130 | sum_dz = comm.reduce_add(sum_dzs) 131 | sum_dz_xhat = comm.reduce_add(sum_dz_xhats) 132 | sum_dz /= ctx.N 133 | sum_dz_xhat /= ctx.N 134 | # master : broadcast global stats 135 | tensors = comm.broadcast_coalesced( 136 | (sum_dz, sum_dz_xhat), [mean.get_device()] + ctx.worker_ids) 137 | for ts, queue in zip(tensors[1:], ctx.worker_queues): 138 | queue.put(ts) 139 | else: 140 | # slave : send to master 141 | ctx.master_queue.put((sum_dz, sum_dz_xhat)) 142 | # slave : get global stats 143 | sum_dz, sum_dz_xhat = ctx.worker_queue.get() 144 | ctx.worker_queue.task_done() 145 | 146 | # do batch norm backward 147 | _lib_bn.syncbn_backard_cuda( 148 | dz, x, weight if weight is not None else dz.new(), 149 | bias if bias is not None else dz.new(), 150 | mean, var, sum_dz, sum_dz_xhat, 151 | dx if dx is not None else dz.new(), 152 | dweight if dweight is not None else dz.new(), 153 | dbias if dbias is not None else dz.new(), ctx.eps) 154 | 155 | return dx, dweight, dbias, None, None, None, \ 156 | None, None, None, None, None 157 | 158 | @staticmethod 159 | def _parse_extra(ctx, extra): 160 | ctx.is_master = extra["is_master"] 161 | if ctx.is_master: 162 | ctx.master_queue = extra["master_queue"] 163 | ctx.worker_queues = extra["worker_queues"] 164 | ctx.worker_ids = extra["worker_ids"] 165 | else: 166 | ctx.master_queue = extra["master_queue"] 167 | ctx.worker_queue = extra["worker_queue"] 168 | 169 | batchnorm2d_sync = BatchNorm2dSyncFunc.apply 170 | 171 | __all__ = ["batchnorm2d_sync"] 172 | -------------------------------------------------------------------------------- /SGR/syncbn/modules/functional/syncbn.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/SGR/syncbn/modules/functional/syncbn.pyc -------------------------------------------------------------------------------- /SGR/syncbn/modules/nn/__init__.py: -------------------------------------------------------------------------------- 1 | from .syncbn import * 2 | -------------------------------------------------------------------------------- /SGR/syncbn/modules/nn/__init__.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/SGR/syncbn/modules/nn/__init__.pyc -------------------------------------------------------------------------------- /SGR/syncbn/modules/nn/syncbn.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | BatchNorm2dSync with multi-gpu 5 | 6 | 
/*****************************************************************************/ 7 | """ 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | try: 13 | # python 3 14 | from queue import Queue 15 | except ImportError: 16 | # python 2 17 | from Queue import Queue 18 | 19 | import torch 20 | import torch.nn as nn 21 | from modules.functional import batchnorm2d_sync 22 | 23 | 24 | class BatchNorm2d(nn.BatchNorm2d): 25 | """ 26 | BatchNorm2d with automatic multi-GPU Sync 27 | """ 28 | 29 | def __init__(self, num_features, eps=1e-5, momentum=0.1, affine=True, 30 | track_running_stats=True): 31 | super(BatchNorm2d, self).__init__( 32 | num_features, eps=eps, momentum=momentum, affine=affine, 33 | track_running_stats=track_running_stats) 34 | self.devices = list(range(torch.cuda.device_count())) 35 | if len(self.devices) > 1: 36 | # Initialize queues 37 | self.worker_ids = self.devices[1:] 38 | self.master_queue = Queue(len(self.worker_ids)) 39 | self.worker_queues = [Queue(1) for _ in self.worker_ids] 40 | 41 | def forward(self, x): 42 | compute_stats = self.training or not self.track_running_stats 43 | if compute_stats and len(self.devices) > 1: 44 | if x.get_device() == self.devices[0]: 45 | # Master mode 46 | extra = { 47 | "is_master": True, 48 | "master_queue": self.master_queue, 49 | "worker_queues": self.worker_queues, 50 | "worker_ids": self.worker_ids 51 | } 52 | else: 53 | # Worker mode 54 | extra = { 55 | "is_master": False, 56 | "master_queue": self.master_queue, 57 | "worker_queue": self.worker_queues[ 58 | self.worker_ids.index(x.get_device())] 59 | } 60 | return batchnorm2d_sync(x, self.weight, self.bias, 61 | self.running_mean, self.running_var, 62 | extra, compute_stats, self.momentum, 63 | self.eps) 64 | return super(BatchNorm2d, self).forward(x) 65 | 66 | def __repr__(self): 67 | """repr""" 68 | rep = '{name}({num_features}, eps={eps}, momentum={momentum},' \ 69 | ' affine={affine}, devices={devices})' 70 | return rep.format(name=self.__class__.__name__, **self.__dict__) 71 | -------------------------------------------------------------------------------- /SGR/syncbn/modules/nn/syncbn.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/SGR/syncbn/modules/nn/syncbn.pyc -------------------------------------------------------------------------------- /SGR/syncbn/requirements.txt: -------------------------------------------------------------------------------- 1 | future 2 | cffi 3 | -------------------------------------------------------------------------------- /SGR/syncbn/test.py: -------------------------------------------------------------------------------- 1 | """ 2 | /*****************************************************************************/ 3 | 4 | Test for BatchNorm2dSync with multi-gpu 5 | 6 | /*****************************************************************************/ 7 | """ 8 | from __future__ import absolute_import 9 | from __future__ import division 10 | from __future__ import print_function 11 | 12 | import sys 13 | import numpy as np 14 | import torch 15 | from torch import nn 16 | from torch.nn import functional as F 17 | sys.path.append("./") 18 | from modules import nn as NN 19 | 20 | torch.backends.cudnn.deterministic = True 21 | 22 | 23 | def init_weight(model): 24 | for m in model.modules(): 25 | if isinstance(m, nn.Conv2d): 26 
| n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 27 | m.weight.data.normal_(0, np.sqrt(2. / n)) 28 | elif isinstance(m, NN.BatchNorm2d) or isinstance(m, nn.BatchNorm2d): 29 | m.weight.data.fill_(1) 30 | m.bias.data.zero_() 31 | elif isinstance(m, nn.Linear): 32 | m.bias.data.zero_() 33 | 34 | num_gpu = torch.cuda.device_count() 35 | print("num_gpu={}".format(num_gpu)) 36 | if num_gpu < 2: 37 | print("No multi-gpu found. NN.BatchNorm2d will act as normal nn.BatchNorm2d") 38 | 39 | m1 = nn.Sequential( 40 | nn.Conv2d(3, 3, 1, 1, bias=False), 41 | nn.BatchNorm2d(3), 42 | nn.ReLU(inplace=True), 43 | nn.Conv2d(3, 3, 1, 1, bias=False), 44 | nn.BatchNorm2d(3), 45 | ).cuda() 46 | torch.manual_seed(123) 47 | init_weight(m1) 48 | m2 = nn.Sequential( 49 | nn.Conv2d(3, 3, 1, 1, bias=False), 50 | NN.BatchNorm2d(3), 51 | nn.ReLU(inplace=True), 52 | nn.Conv2d(3, 3, 1, 1, bias=False), 53 | NN.BatchNorm2d(3), 54 | ).cuda() 55 | torch.manual_seed(123) 56 | init_weight(m2) 57 | m2 = nn.DataParallel(m2, device_ids=range(num_gpu)) 58 | o1 = torch.optim.SGD(m1.parameters(), 1e-3) 59 | o2 = torch.optim.SGD(m2.parameters(), 1e-3) 60 | y = torch.ones(num_gpu).float().cuda() 61 | torch.manual_seed(123) 62 | for _ in range(100): 63 | x = torch.rand(num_gpu, 3, 2, 2).cuda() 64 | o1.zero_grad() 65 | z1 = m1(x) 66 | l1 = F.mse_loss(z1.mean(-1).mean(-1).mean(-1), y) 67 | l1.backward() 68 | o1.step() 69 | o2.zero_grad() 70 | z2 = m2(x) 71 | l2 = F.mse_loss(z2.mean(-1).mean(-1).mean(-1), y) 72 | l2.backward() 73 | o2.step() 74 | print(m2.module[1].bias.grad - m1[1].bias.grad) 75 | print(m2.module[1].weight.grad - m1[1].weight.grad) 76 | print(m2.module[-1].bias.grad - m1[-1].bias.grad) 77 | print(m2.module[-1].weight.grad - m1[-1].weight.grad) 78 | m2 = m2.module 79 | print("===============================") 80 | print("m1(nn.BatchNorm2d) running_mean", 81 | m1[1].running_mean, m1[-1].running_mean) 82 | print("m2(NN.BatchNorm2d) running_mean", 83 | m2[1].running_mean, m2[-1].running_mean) 84 | print("m1(nn.BatchNorm2d) running_var", m1[1].running_var, m1[-1].running_var) 85 | print("m2(NN.BatchNorm2d) running_var", m2[1].running_var, m2[-1].running_var) 86 | print("m1(nn.BatchNorm2d) weight", m1[1].weight, m1[-1].weight) 87 | print("m2(NN.BatchNorm2d) weight", m2[1].weight, m2[-1].weight) 88 | print("m1(nn.BatchNorm2d) bias", m1[1].bias, m1[-1].bias) 89 | print("m2(NN.BatchNorm2d) bias", m2[1].bias, m2[-1].bias) 90 | -------------------------------------------------------------------------------- /dpt/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/dpt/__init__.py -------------------------------------------------------------------------------- /dpt/base_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | class BaseModel(torch.nn.Module): 5 | def load(self, path): 6 | """Load model from file. 
7 | 8 | Args: 9 | path (str): file path 10 | """ 11 | parameters = torch.load(path, map_location=torch.device("cpu")) 12 | 13 | if "optimizer" in parameters: 14 | parameters = parameters["model"] 15 | 16 | self.load_state_dict(parameters) 17 | -------------------------------------------------------------------------------- /dpt/blocks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | from .vit import ( 5 | _make_pretrained_vitb_rn50_384, 6 | _make_pretrained_vitl16_384, 7 | _make_pretrained_vitb16_384, 8 | forward_vit, 9 | ) 10 | 11 | 12 | def _make_encoder( 13 | backbone, 14 | features, 15 | use_pretrained, 16 | groups=1, 17 | expand=False, 18 | exportable=True, 19 | hooks=None, 20 | use_vit_only=False, 21 | use_readout="ignore", 22 | enable_attention_hooks=False, 23 | ): 24 | if backbone == "vitl16_384": 25 | pretrained = _make_pretrained_vitl16_384( 26 | use_pretrained, 27 | hooks=hooks, 28 | use_readout=use_readout, 29 | enable_attention_hooks=enable_attention_hooks, 30 | ) 31 | scratch = _make_scratch( 32 | [256, 512, 1024, 1024], features, groups=groups, expand=expand 33 | ) # ViT-L/16 - 85.0% Top1 (backbone) 34 | elif backbone == "vitb_rn50_384": 35 | pretrained = _make_pretrained_vitb_rn50_384( 36 | use_pretrained, 37 | hooks=hooks, 38 | use_vit_only=use_vit_only, 39 | use_readout=use_readout, 40 | enable_attention_hooks=enable_attention_hooks, 41 | ) 42 | scratch = _make_scratch( 43 | [256, 512, 768, 768], features, groups=groups, expand=expand 44 | ) # ViT-H/16 - 85.0% Top1 (backbone) 45 | elif backbone == "vitb16_384": 46 | pretrained = _make_pretrained_vitb16_384( 47 | use_pretrained, 48 | hooks=hooks, 49 | use_readout=use_readout, 50 | enable_attention_hooks=enable_attention_hooks, 51 | ) 52 | scratch = _make_scratch( 53 | [96, 192, 384, 768], features, groups=groups, expand=expand 54 | ) # ViT-B/16 - 84.6% Top1 (backbone) 55 | elif backbone == "resnext101_wsl": 56 | pretrained = _make_pretrained_resnext101_wsl(use_pretrained) 57 | scratch = _make_scratch( 58 | [256, 512, 1024, 2048], features, groups=groups, expand=expand 59 | ) # efficientnet_lite3 60 | else: 61 | print(f"Backbone '{backbone}' not implemented") 62 | assert False 63 | 64 | return pretrained, scratch 65 | 66 | 67 | def _make_scratch(in_shape, out_shape, groups=1, expand=False): 68 | scratch = nn.Module() 69 | 70 | out_shape1 = out_shape 71 | out_shape2 = out_shape 72 | out_shape3 = out_shape 73 | out_shape4 = out_shape 74 | if expand == True: 75 | out_shape1 = out_shape 76 | out_shape2 = out_shape * 2 77 | out_shape3 = out_shape * 4 78 | out_shape4 = out_shape * 8 79 | 80 | scratch.layer1_rn = nn.Conv2d( 81 | in_shape[0], 82 | out_shape1, 83 | kernel_size=3, 84 | stride=1, 85 | padding=1, 86 | bias=False, 87 | groups=groups, 88 | ) 89 | scratch.layer2_rn = nn.Conv2d( 90 | in_shape[1], 91 | out_shape2, 92 | kernel_size=3, 93 | stride=1, 94 | padding=1, 95 | bias=False, 96 | groups=groups, 97 | ) 98 | scratch.layer3_rn = nn.Conv2d( 99 | in_shape[2], 100 | out_shape3, 101 | kernel_size=3, 102 | stride=1, 103 | padding=1, 104 | bias=False, 105 | groups=groups, 106 | ) 107 | scratch.layer4_rn = nn.Conv2d( 108 | in_shape[3], 109 | out_shape4, 110 | kernel_size=3, 111 | stride=1, 112 | padding=1, 113 | bias=False, 114 | groups=groups, 115 | ) 116 | 117 | return scratch 118 | 119 | 120 | def _make_resnet_backbone(resnet): 121 | pretrained = nn.Module() 122 | pretrained.layer1 = nn.Sequential( 123 | resnet.conv1, resnet.bn1, 
resnet.relu, resnet.maxpool, resnet.layer1 124 | ) 125 | 126 | pretrained.layer2 = resnet.layer2 127 | pretrained.layer3 = resnet.layer3 128 | pretrained.layer4 = resnet.layer4 129 | 130 | return pretrained 131 | 132 | 133 | def _make_pretrained_resnext101_wsl(use_pretrained): 134 | resnet = torch.hub.load("facebookresearch/WSL-Images", "resnext101_32x8d_wsl") 135 | return _make_resnet_backbone(resnet) 136 | 137 | 138 | class Interpolate(nn.Module): 139 | """Interpolation module.""" 140 | 141 | def __init__(self, scale_factor, mode, align_corners=False): 142 | """Init. 143 | 144 | Args: 145 | scale_factor (float): scaling 146 | mode (str): interpolation mode 147 | """ 148 | super(Interpolate, self).__init__() 149 | 150 | self.interp = nn.functional.interpolate 151 | self.scale_factor = scale_factor 152 | self.mode = mode 153 | self.align_corners = align_corners 154 | 155 | def forward(self, x): 156 | """Forward pass. 157 | 158 | Args: 159 | x (tensor): input 160 | 161 | Returns: 162 | tensor: interpolated data 163 | """ 164 | 165 | x = self.interp( 166 | x, 167 | scale_factor=self.scale_factor, 168 | mode=self.mode, 169 | align_corners=self.align_corners, 170 | ) 171 | 172 | return x 173 | 174 | 175 | class ResidualConvUnit(nn.Module): 176 | """Residual convolution module.""" 177 | 178 | def __init__(self, features): 179 | """Init. 180 | 181 | Args: 182 | features (int): number of features 183 | """ 184 | super().__init__() 185 | 186 | self.conv1 = nn.Conv2d( 187 | features, features, kernel_size=3, stride=1, padding=1, bias=True 188 | ) 189 | 190 | self.conv2 = nn.Conv2d( 191 | features, features, kernel_size=3, stride=1, padding=1, bias=True 192 | ) 193 | 194 | self.relu = nn.ReLU(inplace=True) 195 | 196 | def forward(self, x): 197 | """Forward pass. 198 | 199 | Args: 200 | x (tensor): input 201 | 202 | Returns: 203 | tensor: output 204 | """ 205 | out = self.relu(x) 206 | out = self.conv1(out) 207 | out = self.relu(out) 208 | out = self.conv2(out) 209 | 210 | return out + x 211 | 212 | 213 | class FeatureFusionBlock(nn.Module): 214 | """Feature fusion block.""" 215 | 216 | def __init__(self, features): 217 | """Init. 218 | 219 | Args: 220 | features (int): number of features 221 | """ 222 | super(FeatureFusionBlock, self).__init__() 223 | 224 | self.resConfUnit1 = ResidualConvUnit(features) 225 | self.resConfUnit2 = ResidualConvUnit(features) 226 | 227 | def forward(self, *xs): 228 | """Forward pass. 229 | 230 | Returns: 231 | tensor: output 232 | """ 233 | output = xs[0] 234 | 235 | if len(xs) == 2: 236 | output += self.resConfUnit1(xs[1]) 237 | 238 | output = self.resConfUnit2(output) 239 | 240 | output = nn.functional.interpolate( 241 | output, scale_factor=2, mode="bilinear", align_corners=True 242 | ) 243 | 244 | return output 245 | 246 | 247 | class ResidualConvUnit_custom(nn.Module): 248 | """Residual convolution module.""" 249 | 250 | def __init__(self, features, activation, bn): 251 | """Init. 
252 | 253 | Args: 254 | features (int): number of features 255 | """ 256 | super().__init__() 257 | 258 | self.bn = bn 259 | 260 | self.groups = 1 261 | 262 | self.conv1 = nn.Conv2d( 263 | features, 264 | features, 265 | kernel_size=3, 266 | stride=1, 267 | padding=1, 268 | bias=not self.bn, 269 | groups=self.groups, 270 | ) 271 | 272 | self.conv2 = nn.Conv2d( 273 | features, 274 | features, 275 | kernel_size=3, 276 | stride=1, 277 | padding=1, 278 | bias=not self.bn, 279 | groups=self.groups, 280 | ) 281 | 282 | if self.bn == True: 283 | self.bn1 = nn.BatchNorm2d(features) 284 | self.bn2 = nn.BatchNorm2d(features) 285 | 286 | self.activation = activation 287 | 288 | self.skip_add = nn.quantized.FloatFunctional() 289 | 290 | def forward(self, x): 291 | """Forward pass. 292 | 293 | Args: 294 | x (tensor): input 295 | 296 | Returns: 297 | tensor: output 298 | """ 299 | 300 | out = self.activation(x) 301 | out = self.conv1(out) 302 | if self.bn == True: 303 | out = self.bn1(out) 304 | 305 | out = self.activation(out) 306 | out = self.conv2(out) 307 | if self.bn == True: 308 | out = self.bn2(out) 309 | 310 | if self.groups > 1: 311 | out = self.conv_merge(out) 312 | 313 | return self.skip_add.add(out, x) 314 | 315 | # return out + x 316 | 317 | 318 | class FeatureFusionBlock_custom(nn.Module): 319 | """Feature fusion block.""" 320 | 321 | def __init__( 322 | self, 323 | features, 324 | activation, 325 | deconv=False, 326 | bn=False, 327 | expand=False, 328 | align_corners=True, 329 | ): 330 | """Init. 331 | 332 | Args: 333 | features (int): number of features 334 | """ 335 | super(FeatureFusionBlock_custom, self).__init__() 336 | 337 | self.deconv = deconv 338 | self.align_corners = align_corners 339 | 340 | self.groups = 1 341 | 342 | self.expand = expand 343 | out_features = features 344 | if self.expand == True: 345 | out_features = features // 2 346 | 347 | self.out_conv = nn.Conv2d( 348 | features, 349 | out_features, 350 | kernel_size=1, 351 | stride=1, 352 | padding=0, 353 | bias=True, 354 | groups=1, 355 | ) 356 | 357 | self.resConfUnit1 = ResidualConvUnit_custom(features, activation, bn) 358 | self.resConfUnit2 = ResidualConvUnit_custom(features, activation, bn) 359 | 360 | self.skip_add = nn.quantized.FloatFunctional() 361 | 362 | def forward(self, *xs): 363 | """Forward pass. 364 | 365 | Returns: 366 | tensor: output 367 | """ 368 | output = xs[0] 369 | 370 | if len(xs) == 2: 371 | res = self.resConfUnit1(xs[1]) 372 | output = self.skip_add.add(output, res) 373 | # output += res 374 | 375 | output = self.resConfUnit2(output) 376 | 377 | output = nn.functional.interpolate( 378 | output, scale_factor=2, mode="bilinear", align_corners=self.align_corners 379 | ) 380 | 381 | output = self.out_conv(output) 382 | 383 | return output 384 | -------------------------------------------------------------------------------- /dpt/midas_net.py: -------------------------------------------------------------------------------- 1 | """MidashNet: Network for monocular depth estimation trained by mixing several datasets. 
2 | This file contains code that is adapted from 3 | https://github.com/thomasjpfan/pytorch_refinenet/blob/master/pytorch_refinenet/refinenet/refinenet_4cascade.py 4 | """ 5 | import torch 6 | import torch.nn as nn 7 | 8 | from .base_model import BaseModel 9 | from .blocks import FeatureFusionBlock, Interpolate, _make_encoder 10 | 11 | 12 | class MidasNet_large(BaseModel): 13 | """Network for monocular depth estimation.""" 14 | 15 | def __init__(self, path=None, features=256, non_negative=True): 16 | """Init. 17 | 18 | Args: 19 | path (str, optional): Path to saved model. Defaults to None. 20 | features (int, optional): Number of features. Defaults to 256. 21 | backbone (str, optional): Backbone network for encoder. Defaults to resnet50 22 | """ 23 | print("Loading weights: ", path) 24 | 25 | super(MidasNet_large, self).__init__() 26 | 27 | use_pretrained = False if path is None else True 28 | 29 | self.pretrained, self.scratch = _make_encoder( 30 | backbone="resnext101_wsl", features=features, use_pretrained=use_pretrained 31 | ) 32 | 33 | self.scratch.refinenet4 = FeatureFusionBlock(features) 34 | self.scratch.refinenet3 = FeatureFusionBlock(features) 35 | self.scratch.refinenet2 = FeatureFusionBlock(features) 36 | self.scratch.refinenet1 = FeatureFusionBlock(features) 37 | 38 | self.scratch.output_conv = nn.Sequential( 39 | nn.Conv2d(features, 128, kernel_size=3, stride=1, padding=1), 40 | Interpolate(scale_factor=2, mode="bilinear"), 41 | nn.Conv2d(128, 32, kernel_size=3, stride=1, padding=1), 42 | nn.ReLU(True), 43 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 44 | nn.ReLU(True) if non_negative else nn.Identity(), 45 | ) 46 | 47 | if path: 48 | self.load(path) 49 | 50 | def forward(self, x): 51 | """Forward pass. 52 | 53 | Args: 54 | x (tensor): input data (image) 55 | 56 | Returns: 57 | tensor: depth 58 | """ 59 | 60 | layer_1 = self.pretrained.layer1(x) 61 | layer_2 = self.pretrained.layer2(layer_1) 62 | layer_3 = self.pretrained.layer3(layer_2) 63 | layer_4 = self.pretrained.layer4(layer_3) 64 | 65 | layer_1_rn = self.scratch.layer1_rn(layer_1) 66 | layer_2_rn = self.scratch.layer2_rn(layer_2) 67 | layer_3_rn = self.scratch.layer3_rn(layer_3) 68 | layer_4_rn = self.scratch.layer4_rn(layer_4) 69 | 70 | path_4 = self.scratch.refinenet4(layer_4_rn) 71 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 72 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 73 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 74 | 75 | out = self.scratch.output_conv(path_1) 76 | 77 | return torch.squeeze(out, dim=1) 78 | -------------------------------------------------------------------------------- /dpt/models.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .base_model import BaseModel 6 | from .blocks import ( 7 | FeatureFusionBlock, 8 | FeatureFusionBlock_custom, 9 | Interpolate, 10 | _make_encoder, 11 | forward_vit, 12 | ) 13 | 14 | 15 | def _make_fusion_block(features, use_bn): 16 | return FeatureFusionBlock_custom( 17 | features, 18 | nn.ReLU(False), 19 | deconv=False, 20 | bn=use_bn, 21 | expand=False, 22 | align_corners=True, 23 | ) 24 | 25 | 26 | class DPT(BaseModel): 27 | def __init__( 28 | self, 29 | head, 30 | features=256, 31 | backbone="vitb_rn50_384", 32 | readout="project", 33 | channels_last=False, 34 | use_bn=False, 35 | enable_attention_hooks=False, 36 | ): 37 | 38 | super(DPT, self).__init__() 39 | 40 | self.channels_last = 
channels_last 41 | 42 | hooks = { 43 | "vitb_rn50_384": [0, 1, 8, 11], 44 | "vitb16_384": [2, 5, 8, 11], 45 | "vitl16_384": [5, 11, 17, 23], 46 | } 47 | 48 | # Instantiate backbone and reassemble blocks 49 | self.pretrained, self.scratch = _make_encoder( 50 | backbone, 51 | features, 52 | False, # Set to true of you want to train from scratch, uses ImageNet weights 53 | groups=1, 54 | expand=False, 55 | exportable=False, 56 | hooks=hooks[backbone], 57 | use_readout=readout, 58 | enable_attention_hooks=enable_attention_hooks, 59 | ) 60 | 61 | self.scratch.refinenet1 = _make_fusion_block(features, use_bn) 62 | self.scratch.refinenet2 = _make_fusion_block(features, use_bn) 63 | self.scratch.refinenet3 = _make_fusion_block(features, use_bn) 64 | self.scratch.refinenet4 = _make_fusion_block(features, use_bn) 65 | 66 | self.scratch.output_conv = head 67 | 68 | def forward(self, x): 69 | if self.channels_last == True: 70 | x.contiguous(memory_format=torch.channels_last) 71 | 72 | layer_1, layer_2, layer_3, layer_4 = forward_vit(self.pretrained, x) 73 | 74 | layer_1_rn = self.scratch.layer1_rn(layer_1) 75 | layer_2_rn = self.scratch.layer2_rn(layer_2) 76 | layer_3_rn = self.scratch.layer3_rn(layer_3) 77 | layer_4_rn = self.scratch.layer4_rn(layer_4) 78 | 79 | path_4 = self.scratch.refinenet4(layer_4_rn) 80 | path_3 = self.scratch.refinenet3(path_4, layer_3_rn) 81 | path_2 = self.scratch.refinenet2(path_3, layer_2_rn) 82 | path_1 = self.scratch.refinenet1(path_2, layer_1_rn) 83 | 84 | out = self.scratch.output_conv(path_1) 85 | 86 | return out 87 | 88 | 89 | class DPTDepthModel(DPT): 90 | def __init__( 91 | self, path=None, non_negative=True, scale=1.0, shift=0.0, invert=False, **kwargs 92 | ): 93 | features = kwargs["features"] if "features" in kwargs else 256 94 | 95 | self.scale = scale 96 | self.shift = shift 97 | self.invert = invert 98 | 99 | head = nn.Sequential( 100 | nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1), 101 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 102 | nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1), 103 | nn.ReLU(True), 104 | nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0), 105 | nn.ReLU(True) if non_negative else nn.Identity(), 106 | nn.Identity(), 107 | ) 108 | 109 | super().__init__(head, **kwargs) 110 | 111 | if path is not None: 112 | self.load(path) 113 | 114 | def forward(self, x): 115 | inv_depth = super().forward(x).squeeze(dim=1) 116 | 117 | if self.invert: 118 | depth = self.scale * inv_depth + self.shift 119 | depth[depth < 1e-8] = 1e-8 120 | depth = 1.0 / depth 121 | return depth 122 | else: 123 | return inv_depth 124 | 125 | 126 | class DPTSegmentationModel(DPT): 127 | def __init__(self, num_classes, path=None, **kwargs): 128 | 129 | features = kwargs["features"] if "features" in kwargs else 256 130 | 131 | kwargs["use_bn"] = True 132 | 133 | head = nn.Sequential( 134 | nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), 135 | nn.BatchNorm2d(features), 136 | nn.ReLU(True), 137 | nn.Dropout(0.1, False), 138 | nn.Conv2d(features, num_classes, kernel_size=1), 139 | Interpolate(scale_factor=2, mode="bilinear", align_corners=True), 140 | ) 141 | 142 | super().__init__(head, **kwargs) 143 | 144 | self.auxlayer = nn.Sequential( 145 | nn.Conv2d(features, features, kernel_size=3, padding=1, bias=False), 146 | nn.BatchNorm2d(features), 147 | nn.ReLU(True), 148 | nn.Dropout(0.1, False), 149 | nn.Conv2d(features, num_classes, kernel_size=1), 150 | ) 151 | 152 | if path is not None: 153 
| self.load(path) 154 | -------------------------------------------------------------------------------- /dpt/transforms.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def apply_min_size(sample, size, image_interpolation_method=cv2.INTER_AREA): 7 | """Rezise the sample to ensure the given size. Keeps aspect ratio. 8 | 9 | Args: 10 | sample (dict): sample 11 | size (tuple): image size 12 | 13 | Returns: 14 | tuple: new size 15 | """ 16 | shape = list(sample["disparity"].shape) 17 | 18 | if shape[0] >= size[0] and shape[1] >= size[1]: 19 | return sample 20 | 21 | scale = [0, 0] 22 | scale[0] = size[0] / shape[0] 23 | scale[1] = size[1] / shape[1] 24 | 25 | scale = max(scale) 26 | 27 | shape[0] = math.ceil(scale * shape[0]) 28 | shape[1] = math.ceil(scale * shape[1]) 29 | 30 | # resize 31 | sample["image"] = cv2.resize( 32 | sample["image"], tuple(shape[::-1]), interpolation=image_interpolation_method 33 | ) 34 | 35 | sample["disparity"] = cv2.resize( 36 | sample["disparity"], tuple(shape[::-1]), interpolation=cv2.INTER_NEAREST 37 | ) 38 | sample["mask"] = cv2.resize( 39 | sample["mask"].astype(np.float32), 40 | tuple(shape[::-1]), 41 | interpolation=cv2.INTER_NEAREST, 42 | ) 43 | sample["mask"] = sample["mask"].astype(bool) 44 | 45 | return tuple(shape) 46 | 47 | 48 | class Resize(object): 49 | """Resize sample to given size (width, height).""" 50 | 51 | def __init__( 52 | self, 53 | width, 54 | height, 55 | resize_target=True, 56 | keep_aspect_ratio=False, 57 | ensure_multiple_of=1, 58 | resize_method="lower_bound", 59 | image_interpolation_method=cv2.INTER_AREA, 60 | ): 61 | """Init. 62 | 63 | Args: 64 | width (int): desired output width 65 | height (int): desired output height 66 | resize_target (bool, optional): 67 | True: Resize the full sample (image, mask, target). 68 | False: Resize image only. 69 | Defaults to True. 70 | keep_aspect_ratio (bool, optional): 71 | True: Keep the aspect ratio of the input sample. 72 | Output sample might not have the given width and height, and 73 | resize behaviour depends on the parameter 'resize_method'. 74 | Defaults to False. 75 | ensure_multiple_of (int, optional): 76 | Output width and height is constrained to be multiple of this parameter. 77 | Defaults to 1. 78 | resize_method (str, optional): 79 | "lower_bound": Output will be at least as large as the given size. 80 | "upper_bound": Output will be at max as large as the given size. (Output size might be smaller than given size.) 81 | "minimal": Scale as least as possible. (Output size might be smaller than given size.) 82 | Defaults to "lower_bound". 
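        image_interpolation_method (int, optional): OpenCV interpolation flag used for the image. Defaults to cv2.INTER_AREA.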
83 | """ 84 | self.__width = width 85 | self.__height = height 86 | 87 | self.__resize_target = resize_target 88 | self.__keep_aspect_ratio = keep_aspect_ratio 89 | self.__multiple_of = ensure_multiple_of 90 | self.__resize_method = resize_method 91 | self.__image_interpolation_method = image_interpolation_method 92 | 93 | def constrain_to_multiple_of(self, x, min_val=0, max_val=None): 94 | y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int) 95 | 96 | if max_val is not None and y > max_val: 97 | y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int) 98 | 99 | if y < min_val: 100 | y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int) 101 | 102 | return y 103 | 104 | def get_size(self, width, height): 105 | # determine new height and width 106 | scale_height = self.__height / height 107 | scale_width = self.__width / width 108 | 109 | if self.__keep_aspect_ratio: 110 | if self.__resize_method == "lower_bound": 111 | # scale such that output size is lower bound 112 | if scale_width > scale_height: 113 | # fit width 114 | scale_height = scale_width 115 | else: 116 | # fit height 117 | scale_width = scale_height 118 | elif self.__resize_method == "upper_bound": 119 | # scale such that output size is upper bound 120 | if scale_width < scale_height: 121 | # fit width 122 | scale_height = scale_width 123 | else: 124 | # fit height 125 | scale_width = scale_height 126 | elif self.__resize_method == "minimal": 127 | # scale as least as possbile 128 | if abs(1 - scale_width) < abs(1 - scale_height): 129 | # fit width 130 | scale_height = scale_width 131 | else: 132 | # fit height 133 | scale_width = scale_height 134 | else: 135 | raise ValueError( 136 | f"resize_method {self.__resize_method} not implemented" 137 | ) 138 | 139 | if self.__resize_method == "lower_bound": 140 | new_height = self.constrain_to_multiple_of( 141 | scale_height * height, min_val=self.__height 142 | ) 143 | new_width = self.constrain_to_multiple_of( 144 | scale_width * width, min_val=self.__width 145 | ) 146 | elif self.__resize_method == "upper_bound": 147 | new_height = self.constrain_to_multiple_of( 148 | scale_height * height, max_val=self.__height 149 | ) 150 | new_width = self.constrain_to_multiple_of( 151 | scale_width * width, max_val=self.__width 152 | ) 153 | elif self.__resize_method == "minimal": 154 | new_height = self.constrain_to_multiple_of(scale_height * height) 155 | new_width = self.constrain_to_multiple_of(scale_width * width) 156 | else: 157 | raise ValueError(f"resize_method {self.__resize_method} not implemented") 158 | 159 | return (new_width, new_height) 160 | 161 | def __call__(self, sample): 162 | width, height = self.get_size( 163 | sample["image"].shape[1], sample["image"].shape[0] 164 | ) 165 | 166 | # resize sample 167 | sample["image"] = cv2.resize( 168 | sample["image"], 169 | (width, height), 170 | interpolation=self.__image_interpolation_method, 171 | ) 172 | 173 | if self.__resize_target: 174 | if "disparity" in sample: 175 | sample["disparity"] = cv2.resize( 176 | sample["disparity"], 177 | (width, height), 178 | interpolation=cv2.INTER_NEAREST, 179 | ) 180 | 181 | if "depth" in sample: 182 | sample["depth"] = cv2.resize( 183 | sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST 184 | ) 185 | 186 | sample["mask"] = cv2.resize( 187 | sample["mask"].astype(np.float32), 188 | (width, height), 189 | interpolation=cv2.INTER_NEAREST, 190 | ) 191 | sample["mask"] = sample["mask"].astype(bool) 192 | 193 | return sample 194 | 195 
| 196 | class NormalizeImage(object): 197 | """Normlize image by given mean and std.""" 198 | 199 | def __init__(self, mean, std): 200 | self.__mean = mean 201 | self.__std = std 202 | 203 | def __call__(self, sample): 204 | sample["image"] = (sample["image"] - self.__mean) / self.__std 205 | 206 | return sample 207 | 208 | 209 | class PrepareForNet(object): 210 | """Prepare sample for usage as network input.""" 211 | 212 | def __init__(self): 213 | pass 214 | 215 | def __call__(self, sample): 216 | image = np.transpose(sample["image"], (2, 0, 1)) 217 | sample["image"] = np.ascontiguousarray(image).astype(np.float32) 218 | 219 | if "mask" in sample: 220 | sample["mask"] = sample["mask"].astype(np.float32) 221 | sample["mask"] = np.ascontiguousarray(sample["mask"]) 222 | 223 | if "disparity" in sample: 224 | disparity = sample["disparity"].astype(np.float32) 225 | sample["disparity"] = np.ascontiguousarray(disparity) 226 | 227 | if "depth" in sample: 228 | depth = sample["depth"].astype(np.float32) 229 | sample["depth"] = np.ascontiguousarray(depth) 230 | 231 | return sample 232 | -------------------------------------------------------------------------------- /dpt/util/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/dpt/util/__init__.py -------------------------------------------------------------------------------- /dpt/util/io.py: -------------------------------------------------------------------------------- 1 | """Utils for monoDepth. 2 | """ 3 | import sys 4 | import re 5 | import numpy as np 6 | import cv2 7 | import torch 8 | 9 | from PIL import Image 10 | 11 | 12 | from .pallete import get_mask_pallete 13 | 14 | def read_pfm(path): 15 | """Read pfm file. 16 | 17 | Args: 18 | path (str): path to file 19 | 20 | Returns: 21 | tuple: (data, scale) 22 | """ 23 | with open(path, "rb") as file: 24 | 25 | color = None 26 | width = None 27 | height = None 28 | scale = None 29 | endian = None 30 | 31 | header = file.readline().rstrip() 32 | if header.decode("ascii") == "PF": 33 | color = True 34 | elif header.decode("ascii") == "Pf": 35 | color = False 36 | else: 37 | raise Exception("Not a PFM file: " + path) 38 | 39 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("ascii")) 40 | if dim_match: 41 | width, height = list(map(int, dim_match.groups())) 42 | else: 43 | raise Exception("Malformed PFM header.") 44 | 45 | scale = float(file.readline().decode("ascii").rstrip()) 46 | if scale < 0: 47 | # little-endian 48 | endian = "<" 49 | scale = -scale 50 | else: 51 | # big-endian 52 | endian = ">" 53 | 54 | data = np.fromfile(file, endian + "f") 55 | shape = (height, width, 3) if color else (height, width) 56 | 57 | data = np.reshape(data, shape) 58 | data = np.flipud(data) 59 | 60 | return data, scale 61 | 62 | 63 | def write_pfm(path, image, scale=1): 64 | """Write pfm file. 65 | 66 | Args: 67 | path (str): pathto file 68 | image (array): data 69 | scale (int, optional): Scale. Defaults to 1. 
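    Note: per the PFM convention, a negative scale is written when the data is little-endian (see the endianness check below).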
70 | """ 71 | 72 | with open(path, "wb") as file: 73 | color = None 74 | 75 | if image.dtype.name != "float32": 76 | raise Exception("Image dtype must be float32.") 77 | 78 | image = np.flipud(image) 79 | 80 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 81 | color = True 82 | elif ( 83 | len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1 84 | ): # greyscale 85 | color = False 86 | else: 87 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 88 | 89 | file.write("PF\n" if color else "Pf\n".encode()) 90 | file.write("%d %d\n".encode() % (image.shape[1], image.shape[0])) 91 | 92 | endian = image.dtype.byteorder 93 | 94 | if endian == "<" or endian == "=" and sys.byteorder == "little": 95 | scale = -scale 96 | 97 | file.write("%f\n".encode() % scale) 98 | 99 | image.tofile(file) 100 | 101 | 102 | def read_image(path): 103 | """Read image and output RGB image (0-1). 104 | 105 | Args: 106 | path (str): path to file 107 | 108 | Returns: 109 | array: RGB image (0-1) 110 | """ 111 | img = cv2.imread(path) 112 | 113 | if img.ndim == 2: 114 | img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR) 115 | 116 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0 117 | 118 | return img 119 | 120 | 121 | def resize_image(img): 122 | """Resize image and make it fit for network. 123 | 124 | Args: 125 | img (array): image 126 | 127 | Returns: 128 | tensor: data ready for network 129 | """ 130 | height_orig = img.shape[0] 131 | width_orig = img.shape[1] 132 | 133 | if width_orig > height_orig: 134 | scale = width_orig / 384 135 | else: 136 | scale = height_orig / 384 137 | 138 | height = (np.ceil(height_orig / scale / 32) * 32).astype(int) 139 | width = (np.ceil(width_orig / scale / 32) * 32).astype(int) 140 | 141 | img_resized = cv2.resize(img, (width, height), interpolation=cv2.INTER_AREA) 142 | 143 | img_resized = ( 144 | torch.from_numpy(np.transpose(img_resized, (2, 0, 1))).contiguous().float() 145 | ) 146 | img_resized = img_resized.unsqueeze(0) 147 | 148 | return img_resized 149 | 150 | 151 | def resize_depth(depth, width, height): 152 | """Resize depth map and bring to CPU (numpy). 153 | 154 | Args: 155 | depth (tensor): depth 156 | width (int): image width 157 | height (int): image height 158 | 159 | Returns: 160 | array: processed depth 161 | """ 162 | depth = torch.squeeze(depth[0, :, :, :]).to("cpu") 163 | 164 | depth_resized = cv2.resize( 165 | depth.numpy(), (width, height), interpolation=cv2.INTER_CUBIC 166 | ) 167 | 168 | return depth_resized 169 | 170 | 171 | def write_depth(path, depth, bits=1, absolute_depth=False): 172 | """Write depth map to pfm and png file. 
173 | 174 | Args: 175 | path (str): filepath without extension 176 | depth (array): depth 177 | """ 178 | write_pfm(path + ".pfm", depth.astype(np.float32)) 179 | 180 | if absolute_depth: 181 | out = depth 182 | else: 183 | depth_min = depth.min() 184 | depth_max = depth.max() 185 | 186 | max_val = (2 ** (8 * bits)) - 1 187 | 188 | if depth_max - depth_min > np.finfo("float").eps: 189 | out = max_val * (depth - depth_min) / (depth_max - depth_min) 190 | else: 191 | out = np.zeros(depth.shape, dtype=depth.dtype) 192 | 193 | if bits == 1: 194 | cv2.imwrite(path + ".png", out.astype("uint8"), [cv2.IMWRITE_PNG_COMPRESSION, 0]) 195 | elif bits == 2: 196 | cv2.imwrite(path + ".png", out.astype("uint16"), [cv2.IMWRITE_PNG_COMPRESSION, 0]) 197 | 198 | return 199 | 200 | 201 | def write_segm_img(path, image, labels, palette="detail", alpha=0.5): 202 | """Write depth map to pfm and png file. 203 | 204 | Args: 205 | path (str): filepath without extension 206 | image (array): input image 207 | labels (array): labeling of the image 208 | """ 209 | 210 | mask = get_mask_pallete(labels, "ade20k") 211 | 212 | img = Image.fromarray(np.uint8(255*image)).convert("RGBA") 213 | seg = mask.convert("RGBA") 214 | 215 | out = Image.blend(img, seg, alpha) 216 | 217 | out.save(path + ".png") 218 | 219 | return 220 | -------------------------------------------------------------------------------- /dpt/util/misc.py: -------------------------------------------------------------------------------- 1 | import matplotlib.pyplot as plt 2 | 3 | from dpt.vit import get_mean_attention_map 4 | 5 | def visualize_attention(input, model, prediction, model_type): 6 | input = (input + 1.0)/2.0 7 | 8 | attn1 = model.pretrained.attention["attn_1"] 9 | attn2 = model.pretrained.attention["attn_2"] 10 | attn3 = model.pretrained.attention["attn_3"] 11 | attn4 = model.pretrained.attention["attn_4"] 12 | 13 | plt.subplot(3,4,1), plt.imshow(input.squeeze().permute(1,2,0)), plt.title("Input", fontsize=8), plt.axis("off") 14 | plt.subplot(3,4,2), plt.imshow(prediction), plt.set_cmap("inferno"), plt.title("Prediction", fontsize=8), plt.axis("off") 15 | 16 | if model_type == "dpt_hybrid": 17 | h = [3,6,9,12] 18 | else: 19 | h = [6,12,18,24] 20 | 21 | # upper left 22 | plt.subplot(345), 23 | ax1 = plt.imshow(get_mean_attention_map(attn1, 1, input.shape)) 24 | plt.ylabel("Upper left corner", fontsize=8) 25 | plt.title(f"Layer {h[0]}", fontsize=8) 26 | gc = plt.gca() 27 | gc.axes.xaxis.set_ticklabels([]) 28 | gc.axes.yaxis.set_ticklabels([]) 29 | gc.axes.xaxis.set_ticks([]) 30 | gc.axes.yaxis.set_ticks([]) 31 | 32 | 33 | plt.subplot(346), 34 | plt.imshow(get_mean_attention_map(attn2, 1, input.shape)) 35 | plt.title(f"Layer {h[1]}", fontsize=8) 36 | plt.axis("off"), 37 | 38 | plt.subplot(347), 39 | plt.imshow(get_mean_attention_map(attn3, 1, input.shape)) 40 | plt.title(f"Layer {h[2]}", fontsize=8) 41 | plt.axis("off"), 42 | 43 | 44 | plt.subplot(348), 45 | plt.imshow(get_mean_attention_map(attn4, 1, input.shape)) 46 | plt.title(f"Layer {h[3]}", fontsize=8) 47 | plt.axis("off"), 48 | 49 | 50 | # lower right 51 | plt.subplot(3,4,9), plt.imshow(get_mean_attention_map(attn1, -1, input.shape)) 52 | plt.ylabel("Lower right corner", fontsize=8) 53 | gc = plt.gca() 54 | gc.axes.xaxis.set_ticklabels([]) 55 | gc.axes.yaxis.set_ticklabels([]) 56 | gc.axes.xaxis.set_ticks([]) 57 | gc.axes.yaxis.set_ticks([]) 58 | 59 | plt.subplot(3,4,10), plt.imshow(get_mean_attention_map(attn2, -1, input.shape)), plt.axis("off") 60 | plt.subplot(3,4,11), 
plt.imshow(get_mean_attention_map(attn3, -1, input.shape)), plt.axis("off") 61 | plt.subplot(3,4,12), plt.imshow(get_mean_attention_map(attn4, -1, input.shape)), plt.axis("off") 62 | plt.tight_layout() 63 | plt.show() 64 | -------------------------------------------------------------------------------- /dpt/util/pallete.py: -------------------------------------------------------------------------------- 1 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 | ## Created by: Hang Zhang 3 | ## ECE Department, Rutgers University 4 | ## Email: zhang.hang@rutgers.edu 5 | ## Copyright (c) 2017 6 | ## 7 | ## This source code is licensed under the MIT-style license found in the 8 | ## LICENSE file in the root directory of this source tree 9 | ##+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 10 | 11 | from PIL import Image 12 | 13 | def get_mask_pallete(npimg, dataset='detail'): 14 | """Get image color pallete for visualizing masks""" 15 | # recovery boundary 16 | if dataset == 'pascal_voc': 17 | npimg[npimg==21] = 255 18 | # put colormap 19 | out_img = Image.fromarray(npimg.squeeze().astype('uint8')) 20 | if dataset == 'ade20k': 21 | out_img.putpalette(adepallete) 22 | elif dataset == 'citys': 23 | out_img.putpalette(citypallete) 24 | elif dataset in ('detail', 'pascal_voc', 'pascal_aug'): 25 | out_img.putpalette(vocpallete) 26 | return out_img 27 | 28 | def _get_voc_pallete(num_cls): 29 | n = num_cls 30 | pallete = [0]*(n*3) 31 | for j in range(0,n): 32 | lab = j 33 | pallete[j*3+0] = 0 34 | pallete[j*3+1] = 0 35 | pallete[j*3+2] = 0 36 | i = 0 37 | while (lab > 0): 38 | pallete[j*3+0] |= (((lab >> 0) & 1) << (7-i)) 39 | pallete[j*3+1] |= (((lab >> 1) & 1) << (7-i)) 40 | pallete[j*3+2] |= (((lab >> 2) & 1) << (7-i)) 41 | i = i + 1 42 | lab >>= 3 43 | return pallete 44 | 45 | vocpallete = _get_voc_pallete(256) 46 | 47 | adepallete = [0,0,0,120,120,120,180,120,120,6,230,230,80,50,50,4,200,3,120,120,80,140,140,140,204,5,255,230,230,230,4,250,7,224,5,255,235,255,7,150,5,61,120,120,70,8,255,51,255,6,82,143,255,140,204,255,4,255,51,7,204,70,3,0,102,200,61,230,250,255,6,51,11,102,255,255,7,71,255,9,224,9,7,230,220,220,220,255,9,92,112,9,255,8,255,214,7,255,224,255,184,6,10,255,71,255,41,10,7,255,255,224,255,8,102,8,255,255,61,6,255,194,7,255,122,8,0,255,20,255,8,41,255,5,153,6,51,255,235,12,255,160,150,20,0,163,255,140,140,140,250,10,15,20,255,0,31,255,0,255,31,0,255,224,0,153,255,0,0,0,255,255,71,0,0,235,255,0,173,255,31,0,255,11,200,200,255,82,0,0,255,245,0,61,255,0,255,112,0,255,133,255,0,0,255,163,0,255,102,0,194,255,0,0,143,255,51,255,0,0,82,255,0,255,41,0,255,173,10,0,255,173,255,0,0,255,153,255,92,0,255,0,255,255,0,245,255,0,102,255,173,0,255,0,20,255,184,184,0,31,255,0,255,61,0,71,255,255,0,204,0,255,194,0,255,82,0,10,255,0,112,255,51,0,255,0,194,255,0,122,255,0,255,163,255,153,0,0,255,10,255,112,0,143,255,0,82,0,255,163,255,0,255,235,0,8,184,170,133,0,255,0,255,92,184,0,255,255,0,31,0,184,255,0,214,255,255,0,112,92,255,0,0,224,255,112,224,255,70,184,160,163,0,255,153,0,255,71,255,0,255,0,163,255,204,0,255,0,143,0,255,235,133,255,0,255,0,235,245,0,255,255,0,122,255,245,0,10,190,212,214,255,0,0,204,255,20,0,255,255,255,0,0,153,255,0,41,255,0,255,204,41,0,255,41,255,0,173,0,255,0,245,255,71,0,255,122,0,255,0,255,184,0,92,255,184,255,0,0,133,255,255,214,0,25,194,194,102,255,0,92,0,255] 48 | 49 | citypallete = [ 50 | 
128,64,128,244,35,232,70,70,70,102,102,156,190,153,153,153,153,153,250,170,30,220,220,0,107,142,35,152,251,152,70,130,180,220,20,60,255,0,0,0,0,142,0,0,70,0,60,100,0,80,100,0,0,230,119,11,32,128,192,0,0,64,128,128,64,128,0,192,128,128,192,128,64,64,0,192,64,0,64,192,0,192,192,0,64,64,128,192,64,128,64,192,128,192,192,128,0,0,64,128,0,64,0,128,64,128,128,64,0,0,192,128,0,192,0,128,192,128,128,192,64,0,64,192,0,64,64,128,64,192,128,64,64,0,192,192,0,192,64,128,192,192,128,192,0,64,64,128,64,64,0,192,64,128,192,64,0,64,192,128,64,192,0,192,192,128,192,192,64,64,64,192,64,64,64,192,64,192,192,64,64,64,192,192,64,192,64,192,192,192,192,192,32,0,0,160,0,0,32,128,0,160,128,0,32,0,128,160,0,128,32,128,128,160,128,128,96,0,0,224,0,0,96,128,0,224,128,0,96,0,128,224,0,128,96,128,128,224,128,128,32,64,0,160,64,0,32,192,0,160,192,0,32,64,128,160,64,128,32,192,128,160,192,128,96,64,0,224,64,0,96,192,0,224,192,0,96,64,128,224,64,128,96,192,128,224,192,128,32,0,64,160,0,64,32,128,64,160,128,64,32,0,192,160,0,192,32,128,192,160,128,192,96,0,64,224,0,64,96,128,64,224,128,64,96,0,192,224,0,192,96,128,192,224,128,192,32,64,64,160,64,64,32,192,64,160,192,64,32,64,192,160,64,192,32,192,192,160,192,192,96,64,64,224,64,64,96,192,64,224,192,64,96,64,192,224,64,192,96,192,192,224,192,192,0,32,0,128,32,0,0,160,0,128,160,0,0,32,128,128,32,128,0,160,128,128,160,128,64,32,0,192,32,0,64,160,0,192,160,0,64,32,128,192,32,128,64,160,128,192,160,128,0,96,0,128,96,0,0,224,0,128,224,0,0,96,128,128,96,128,0,224,128,128,224,128,64,96,0,192,96,0,64,224,0,192,224,0,64,96,128,192,96,128,64,224,128,192,224,128,0,32,64,128,32,64,0,160,64,128,160,64,0,32,192,128,32,192,0,160,192,128,160,192,64,32,64,192,32,64,64,160,64,192,160,64,64,32,192,192,32,192,64,160,192,192,160,192,0,96,64,128,96,64,0,224,64,128,224,64,0,96,192,128,96,192,0,224,192,128,224,192,64,96,64,192,96,64,64,224,64,192,224,64,64,96,192,192,96,192,64,224,192,192,224,192,32,32,0,160,32,0,32,160,0,160,160,0,32,32,128,160,32,128,32,160,128,160,160,128,96,32,0,224,32,0,96,160,0,224,160,0,96,32,128,224,32,128,96,160,128,224,160,128,32,96,0,160,96,0,32,224,0,160,224,0,32,96,128,160,96,128,32,224,128,160,224,128,96,96,0,224,96,0,96,224,0,224,224,0,96,96,128,224,96,128,96,224,128,224,224,128,32,32,64,160,32,64,32,160,64,160,160,64,32,32,192,160,32,192,32,160,192,160,160,192,96,32,64,224,32,64,96,160,64,224,160,64,96,32,192,224,32,192,96,160,192,224,160,192,32,96,64,160,96,64,32,224,64,160,224,64,32,96,192,160,96,192,32,224,192,160,224,192,96,96,64,224,96,64,96,224,64,224,224,64,96,96,192,224,96,192,96,224,192,0,0,0] 51 | -------------------------------------------------------------------------------- /dpt/weights/.placeholder: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/dpt/weights/.placeholder -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import torch 4 | import argparse 5 | import numpy as np 6 | from tqdm import trange 7 | from utils.func import * 8 | from SGR import DepthNet as SGRnet 9 | from MiDaS.midas_net import MidasNet 10 | from utils.model import Gradient_FusionModel 11 | from torch.optim import lr_scheduler, AdamW 12 | import torchvision.transforms as transforms 13 | from LeRes.multi_depth_model_woauxi import strip_prefix_if_present, RelDepthModel 14 | from 
utils.middleburry2021 import middleburry 15 | from utils.multiscopic import multiscopic 16 | from utils.hypersim import hypersim 17 | from torchvision.transforms import Compose 18 | from dpt.models import DPTDepthModel 19 | from dpt.midas_net import MidasNet_large 20 | from dpt.transforms import Resize, NormalizeImage, PrepareForNet 21 | from newcrfs.networks.NewCRFDepth import NewCRFDepth 22 | from torch.autograd import Variable 23 | 24 | 25 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 26 | os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE' 27 | 28 | 29 | def run(args): 30 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 31 | size = 224 32 | if args.pred_model == 'LeRes50': 33 | Depth_model = RelDepthModel(backbone='resnet50') 34 | depth_dict = './LeRes/res50.pth' 35 | depth_dict = torch.load(depth_dict) 36 | Depth_model.load_state_dict(strip_prefix_if_present(depth_dict['depth_model'], "module."), strict=True) 37 | model_flag = 1 38 | 39 | elif args.pred_model == 'SGR': 40 | Depth_model = SGRnet.DepthNet() 41 | if device == torch.device("cuda"): 42 | Depth_model = torch.nn.DataParallel(Depth_model, device_ids=[0]).cuda() 43 | else: 44 | print('sgr model can not run correctly without cpu') 45 | exit() 46 | depth_dict = torch.load('./SGR/model.pth.tar') 47 | Depth_model.load_state_dict(depth_dict['state_dict']) 48 | model_flag = 2 49 | 50 | elif args.pred_model == 'MiDaS': 51 | Depth_model = MidasNet('./MiDaS/model.pt', non_negative=True) 52 | model_flag = 3 53 | size = 192 54 | 55 | elif args.pred_model == 'dpt': 56 | torch.backends.cudnn.enabled = True 57 | torch.backends.cudnn.benchmark = True 58 | Depth_model = DPTDepthModel( 59 | path="dpt/weights/dpt_hybrid-midas-501f0c75.pt", 60 | backbone="vitb_rn50_384", 61 | non_negative=True, 62 | enable_attention_hooks=False, 63 | ) 64 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 65 | transform_low = Compose( 66 | [Resize( 67 | 384, 68 | 384, 69 | resize_target=None, 70 | keep_aspect_ratio=True, 71 | ensure_multiple_of=32, 72 | resize_method="minimal", 73 | image_interpolation_method=cv2.INTER_CUBIC, 74 | ), 75 | normalization, 76 | PrepareForNet(),]) 77 | 78 | transform_high = Compose( 79 | [Resize( 80 | 384*3, 81 | 384*3, 82 | resize_target=None, 83 | keep_aspect_ratio=True, 84 | ensure_multiple_of=32, 85 | resize_method="minimal", 86 | image_interpolation_method=cv2.INTER_CUBIC, 87 | ), 88 | normalization, 89 | PrepareForNet(),]) 90 | if device == torch.device("cuda"): 91 | Depth_model = Depth_model.to(memory_format=torch.channels_last) 92 | Depth_model = Depth_model.half() 93 | model_flag = 4 94 | 95 | elif args.pred_model == 'newcrfs': 96 | max_depth = 1000 97 | checkpoint_path = './newcrfs/model_nyu.ckpt' 98 | Depth_model = NewCRFDepth(version='large07', inv_depth=True, max_depth=max_depth) 99 | Depth_model = torch.nn.DataParallel(Depth_model) 100 | checkpoint = torch.load(checkpoint_path) 101 | Depth_model.load_state_dict(checkpoint['model']) 102 | model_flag = 5 103 | 104 | else: 105 | print('no such model') 106 | exit() 107 | 108 | Fuse_model = Gradient_FusionModel(dict_path=args.model_weights) 109 | 110 | Fuse_model.to(device) 111 | Depth_model.to(device) 112 | Fuse_model = Fuse_model.eval() 113 | Depth_model = Depth_model.eval() 114 | 115 | if args.eval_dataset == 'middleburry2021': 116 | dataset = middleburry() 117 | elif args.eval_dataset == 'multiscopic': 118 | dataset = multiscopic() 119 | elif args.eval_dataset == 'hypersim': 120 | dataset = hypersim() 121 | else: 122 | print('no 
such dataset') 123 | exit() 124 | 125 | # while dataset.index != dataset.num-1: 126 | for i in trange(dataset.num): 127 | img, depth, val_mask = dataset.getitem() 128 | if model_flag == 4: 129 | img = img.astype('float32')/255.0 130 | low_img = transform_low({"image": img})["image"] 131 | high_img = transform_high({"image": img})["image"] 132 | elif model_flag == 5: 133 | img = img.astype('float32')/255.0 134 | low_img = cv2.resize(img, (640, 480)) 135 | high_img = cv2.resize(img, (640*3, 480*3)) 136 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 137 | low_img = np.expand_dims(low_img, axis=0) 138 | low_img = np.transpose(low_img, (0, 3, 1, 2)) 139 | low_img = Variable(normalize(torch.from_numpy(low_img)).float()).cuda() 140 | high_img = np.expand_dims(high_img, axis=0) 141 | high_img = np.transpose(high_img, (0, 3, 1, 2)) 142 | high_img = Variable(normalize(torch.from_numpy(high_img)).float()).cuda() 143 | else: 144 | low_img, high_img = scale_image(img, size, device) 145 | 146 | with torch.no_grad(): 147 | if model_flag == 1: 148 | low_dep = Depth_model.inference(low_img) 149 | high_dep = Depth_model.inference(high_img) 150 | 151 | elif model_flag == 2: 152 | low_dep = Depth_model.forward(low_img) 153 | high_dep = Depth_model.forward(high_img) 154 | low_dep = low_dep.max() - low_dep 155 | high_dep = high_dep.max() - high_dep 156 | 157 | elif model_flag == 3: 158 | low_dep = Depth_model.forward(low_img).unsqueeze(0) 159 | high_dep = Depth_model.forward(high_img).unsqueeze(0) 160 | low_dep = low_dep.max() - low_dep 161 | high_dep = high_dep.max() - high_dep 162 | 163 | elif model_flag == 4: 164 | sample = torch.from_numpy(low_img).to(device).unsqueeze(0) 165 | sample = sample.to(memory_format=torch.channels_last) 166 | sample = sample.half() 167 | low_dep = Depth_model.forward(sample) 168 | low_dep = (torch.nn.functional.interpolate( 169 | low_dep.unsqueeze(1), 170 | size=img.shape[:2], 171 | mode="bicubic", 172 | align_corners=False,)).float() 173 | sample = torch.from_numpy(high_img).to(device).unsqueeze(0) 174 | sample = sample.to(memory_format=torch.channels_last) 175 | sample = sample.half() 176 | high_dep = Depth_model.forward(sample) 177 | high_dep = (torch.nn.functional.interpolate( 178 | high_dep.unsqueeze(1), 179 | size=img.shape[:2], 180 | mode="bicubic", 181 | align_corners=False,)).float() 182 | low_dep = low_dep.max() - low_dep 183 | high_dep = high_dep.max() - high_dep 184 | 185 | elif model_flag == 5: 186 | low_dep = Depth_model(low_img) 187 | high_dep = Depth_model(high_img) 188 | 189 | low_dep, high_dep, fusion = Fuse_model.inference(low_dep, high_dep) 190 | dataset.compute_error(fusion, depth, val_mask) 191 | 192 | print('Results:') 193 | print('sq_rel = ', np.nanmean(dataset.sq_rel)) 194 | print('rms = ', np.nanmean(dataset.rms)) 195 | print('log10 = ', np.nanmean(dataset.log10)) 196 | print('thr1 = ', np.nanmean(dataset.thr1)) 197 | print('thr2 = ', np.nanmean(dataset.thr2)) 198 | 199 | 200 | 201 | if __name__ == '__main__': 202 | parser = argparse.ArgumentParser() 203 | 204 | parser.add_argument('-m', '--model_weights', 205 | default='./models/model_dict.pt', 206 | help='path to the trained weights of model' 207 | ) 208 | 209 | parser.add_argument('-p', '--pred_model', 210 | default='LeRes50', 211 | help='model type: LeRes50, SGR ,MiDaS, dpt or newcrfs' 212 | ) 213 | 214 | parser.add_argument('-d', '--eval_dataset', 215 | default='middleburry2021', 216 | help='dataset: multiscopic, middleburry2021 or hypersim' 217 | ) 218 | 
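    # Example invocations (a minimal sketch; the weight files used by the chosen
    # backbone, e.g. ./LeRes/res50.pth and ./models/model_dict.pt, are assumed
    # to be downloaded beforehand):
    #   python eval.py -p LeRes50 -d middleburry2021
    #   python eval.py -p dpt -d hypersim -m ./models/model_dict.pt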
219 | args = parser.parse_args() 220 | run(args) -------------------------------------------------------------------------------- /figures/1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/figures/1.gif -------------------------------------------------------------------------------- /figures/2.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/figures/2.gif -------------------------------------------------------------------------------- /figures/3.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/figures/3.gif -------------------------------------------------------------------------------- /input/1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/input/1.png -------------------------------------------------------------------------------- /input/2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/input/2.png -------------------------------------------------------------------------------- /input/3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/input/3.png -------------------------------------------------------------------------------- /newcrfs/dataloaders/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/newcrfs/dataloaders/__init__.py -------------------------------------------------------------------------------- /newcrfs/networks/NewCRFDepth.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .swin_transformer import SwinTransformer 6 | from .newcrf_layers import NewCRF 7 | from .uper_crf_head import PSP 8 | ######################################################################################################################## 9 | 10 | 11 | class NewCRFDepth(nn.Module): 12 | """ 13 | Depth network based on neural window FC-CRFs architecture. 
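    The version string encodes the Swin backbone variant and attention window size, e.g. 'tiny07', 'base07' or 'large07' (the setting used in run.py/eval.py): the last two digits are parsed as the window size and the prefix selects the embedding dims, depths and head counts.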
14 | """ 15 | def __init__(self, version=None, inv_depth=False, pretrained=None, 16 | frozen_stages=-1, min_depth=0.1, max_depth=100.0, **kwargs): 17 | super().__init__() 18 | 19 | self.inv_depth = inv_depth 20 | self.with_auxiliary_head = False 21 | self.with_neck = False 22 | 23 | norm_cfg = dict(type='BN', requires_grad=True) 24 | # norm_cfg = dict(type='GN', requires_grad=True, num_groups=8) 25 | 26 | window_size = int(version[-2:]) 27 | 28 | if version[:-2] == 'base': 29 | embed_dim = 128 30 | depths = [2, 2, 18, 2] 31 | num_heads = [4, 8, 16, 32] 32 | in_channels = [128, 256, 512, 1024] 33 | elif version[:-2] == 'large': 34 | embed_dim = 192 35 | depths = [2, 2, 18, 2] 36 | num_heads = [6, 12, 24, 48] 37 | in_channels = [192, 384, 768, 1536] 38 | elif version[:-2] == 'tiny': 39 | embed_dim = 96 40 | depths = [2, 2, 6, 2] 41 | num_heads = [3, 6, 12, 24] 42 | in_channels = [96, 192, 384, 768] 43 | 44 | backbone_cfg = dict( 45 | embed_dim=embed_dim, 46 | depths=depths, 47 | num_heads=num_heads, 48 | window_size=window_size, 49 | ape=False, 50 | drop_path_rate=0.3, 51 | patch_norm=True, 52 | use_checkpoint=False, 53 | frozen_stages=frozen_stages 54 | ) 55 | 56 | embed_dim = 512 57 | decoder_cfg = dict( 58 | in_channels=in_channels, 59 | in_index=[0, 1, 2, 3], 60 | pool_scales=(1, 2, 3, 6), 61 | channels=embed_dim, 62 | dropout_ratio=0.0, 63 | num_classes=32, 64 | norm_cfg=norm_cfg, 65 | align_corners=False 66 | ) 67 | 68 | self.backbone = SwinTransformer(**backbone_cfg) 69 | v_dim = decoder_cfg['num_classes']*4 70 | win = 7 71 | crf_dims = [128, 256, 512, 1024] 72 | v_dims = [64, 128, 256, embed_dim] 73 | self.crf3 = NewCRF(input_dim=in_channels[3], embed_dim=crf_dims[3], window_size=win, v_dim=v_dims[3], num_heads=32) 74 | self.crf2 = NewCRF(input_dim=in_channels[2], embed_dim=crf_dims[2], window_size=win, v_dim=v_dims[2], num_heads=16) 75 | self.crf1 = NewCRF(input_dim=in_channels[1], embed_dim=crf_dims[1], window_size=win, v_dim=v_dims[1], num_heads=8) 76 | self.crf0 = NewCRF(input_dim=in_channels[0], embed_dim=crf_dims[0], window_size=win, v_dim=v_dims[0], num_heads=4) 77 | 78 | self.decoder = PSP(**decoder_cfg) 79 | self.disp_head1 = DispHead(input_dim=crf_dims[0]) 80 | 81 | self.up_mode = 'bilinear' 82 | if self.up_mode == 'mask': 83 | self.mask_head = nn.Sequential( 84 | nn.Conv2d(crf_dims[0], 64, 3, padding=1), 85 | nn.ReLU(inplace=True), 86 | nn.Conv2d(64, 16*9, 1, padding=0)) 87 | 88 | self.min_depth = min_depth 89 | self.max_depth = max_depth 90 | 91 | self.init_weights(pretrained=pretrained) 92 | 93 | def init_weights(self, pretrained=None): 94 | """Initialize the weights in backbone and heads. 95 | 96 | Args: 97 | pretrained (str, optional): Path to pre-trained weights. 98 | Defaults to None. 
99 | """ 100 | print(f'== Load encoder backbone from: {pretrained}') 101 | self.backbone.init_weights(pretrained=pretrained) 102 | self.decoder.init_weights() 103 | if self.with_auxiliary_head: 104 | if isinstance(self.auxiliary_head, nn.ModuleList): 105 | for aux_head in self.auxiliary_head: 106 | aux_head.init_weights() 107 | else: 108 | self.auxiliary_head.init_weights() 109 | 110 | def upsample_mask(self, disp, mask): 111 | """ Upsample disp [H/4, W/4, 1] -> [H, W, 1] using convex combination """ 112 | N, _, H, W = disp.shape 113 | mask = mask.view(N, 1, 9, 4, 4, H, W) 114 | mask = torch.softmax(mask, dim=2) 115 | 116 | up_disp = F.unfold(disp, kernel_size=3, padding=1) 117 | up_disp = up_disp.view(N, 1, 9, 1, 1, H, W) 118 | 119 | up_disp = torch.sum(mask * up_disp, dim=2) 120 | up_disp = up_disp.permute(0, 1, 4, 2, 5, 3) 121 | return up_disp.reshape(N, 1, 4*H, 4*W) 122 | 123 | def forward(self, imgs): 124 | 125 | feats = self.backbone(imgs) 126 | if self.with_neck: 127 | feats = self.neck(feats) 128 | 129 | ppm_out = self.decoder(feats) 130 | 131 | e3 = self.crf3(feats[3], ppm_out) 132 | e3 = nn.PixelShuffle(2)(e3) 133 | e2 = self.crf2(feats[2], e3) 134 | e2 = nn.PixelShuffle(2)(e2) 135 | e1 = self.crf1(feats[1], e2) 136 | e1 = nn.PixelShuffle(2)(e1) 137 | e0 = self.crf0(feats[0], e1) 138 | 139 | if self.up_mode == 'mask': 140 | mask = self.mask_head(e0) 141 | d1 = self.disp_head1(e0, 1) 142 | d1 = self.upsample_mask(d1, mask) 143 | else: 144 | d1 = self.disp_head1(e0, 4) 145 | 146 | depth = d1 * self.max_depth 147 | 148 | return depth 149 | 150 | 151 | class DispHead(nn.Module): 152 | def __init__(self, input_dim=100): 153 | super(DispHead, self).__init__() 154 | # self.norm1 = nn.BatchNorm2d(input_dim) 155 | self.conv1 = nn.Conv2d(input_dim, 1, 3, padding=1) 156 | # self.relu = nn.ReLU(inplace=True) 157 | self.sigmoid = nn.Sigmoid() 158 | 159 | def forward(self, x, scale): 160 | # x = self.relu(self.norm1(x)) 161 | x = self.sigmoid(self.conv1(x)) 162 | if scale > 1: 163 | x = upsample(x, scale_factor=scale) 164 | return x 165 | 166 | 167 | class DispUnpack(nn.Module): 168 | def __init__(self, input_dim=100, hidden_dim=128): 169 | super(DispUnpack, self).__init__() 170 | self.conv1 = nn.Conv2d(input_dim, hidden_dim, 3, padding=1) 171 | self.conv2 = nn.Conv2d(hidden_dim, 16, 3, padding=1) 172 | self.relu = nn.ReLU(inplace=True) 173 | self.sigmoid = nn.Sigmoid() 174 | self.pixel_shuffle = nn.PixelShuffle(4) 175 | 176 | def forward(self, x, output_size): 177 | x = self.relu(self.conv1(x)) 178 | x = self.sigmoid(self.conv2(x)) # [b, 16, h/4, w/4] 179 | # x = torch.reshape(x, [x.shape[0], 1, x.shape[2]*4, x.shape[3]*4]) 180 | x = self.pixel_shuffle(x) 181 | 182 | return x 183 | 184 | 185 | def upsample(x, scale_factor=2, mode="bilinear", align_corners=False): 186 | """Upsample input tensor by a factor of 2 187 | """ 188 | return F.interpolate(x, scale_factor=scale_factor, mode=mode, align_corners=align_corners) -------------------------------------------------------------------------------- /newcrfs/networks/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/YuiNsky/Gradient-based-depth-map-fusion/ef5414cdd876b13215df1bfcceb4c2d1e5676ec2/newcrfs/networks/__init__.py -------------------------------------------------------------------------------- /newcrfs/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import 
torch.distributed as dist 4 | from torch.utils.data import Sampler 5 | from torchvision import transforms 6 | 7 | import os, sys 8 | import numpy as np 9 | import math 10 | import torch 11 | 12 | 13 | def convert_arg_line_to_args(arg_line): 14 | for arg in arg_line.split(): 15 | if not arg.strip(): 16 | continue 17 | yield arg 18 | 19 | 20 | def block_print(): 21 | sys.stdout = open(os.devnull, 'w') 22 | 23 | 24 | def enable_print(): 25 | sys.stdout = sys.__stdout__ 26 | 27 | 28 | def get_num_lines(file_path): 29 | f = open(file_path, 'r') 30 | lines = f.readlines() 31 | f.close() 32 | return len(lines) 33 | 34 | 35 | def colorize(value, vmin=None, vmax=None, cmap='Greys'): 36 | value = value.cpu().numpy()[:, :, :] 37 | value = np.log10(value) 38 | 39 | vmin = value.min() if vmin is None else vmin 40 | vmax = value.max() if vmax is None else vmax 41 | 42 | if vmin != vmax: 43 | value = (value - vmin) / (vmax - vmin) 44 | else: 45 | value = value*0. 46 | 47 | cmapper = matplotlib.cm.get_cmap(cmap) 48 | value = cmapper(value, bytes=True) 49 | 50 | img = value[:, :, :3] 51 | 52 | return img.transpose((2, 0, 1)) 53 | 54 | 55 | def normalize_result(value, vmin=None, vmax=None): 56 | value = value.cpu().numpy()[0, :, :] 57 | 58 | vmin = value.min() if vmin is None else vmin 59 | vmax = value.max() if vmax is None else vmax 60 | 61 | if vmin != vmax: 62 | value = (value - vmin) / (vmax - vmin) 63 | else: 64 | value = value * 0. 65 | 66 | return np.expand_dims(value, 0) 67 | 68 | 69 | inv_normalize = transforms.Normalize( 70 | mean=[-0.485/0.229, -0.456/0.224, -0.406/0.225], 71 | std=[1/0.229, 1/0.224, 1/0.225] 72 | ) 73 | 74 | 75 | eval_metrics = ['silog', 'abs_rel', 'log10', 'rms', 'sq_rel', 'log_rms', 'd1', 'd2', 'd3'] 76 | 77 | 78 | def compute_errors(gt, pred): 79 | thresh = np.maximum((gt / pred), (pred / gt)) 80 | d1 = (thresh < 1.25).mean() 81 | d2 = (thresh < 1.25 ** 2).mean() 82 | d3 = (thresh < 1.25 ** 3).mean() 83 | 84 | rms = (gt - pred) ** 2 85 | rms = np.sqrt(rms.mean()) 86 | 87 | log_rms = (np.log(gt) - np.log(pred)) ** 2 88 | log_rms = np.sqrt(log_rms.mean()) 89 | 90 | abs_rel = np.mean(np.abs(gt - pred) / gt) 91 | sq_rel = np.mean(((gt - pred) ** 2) / gt) 92 | 93 | err = np.log(pred) - np.log(gt) 94 | silog = np.sqrt(np.mean(err ** 2) - np.mean(err) ** 2) * 100 95 | 96 | err = np.abs(np.log10(pred) - np.log10(gt)) 97 | log10 = np.mean(err) 98 | 99 | return [silog, abs_rel, log10, rms, sq_rel, log_rms, d1, d2, d3] 100 | 101 | 102 | class silog_loss(nn.Module): 103 | def __init__(self, variance_focus): 104 | super(silog_loss, self).__init__() 105 | self.variance_focus = variance_focus 106 | 107 | def forward(self, depth_est, depth_gt, mask): 108 | d = torch.log(depth_est[mask]) - torch.log(depth_gt[mask]) 109 | return torch.sqrt((d ** 2).mean() - self.variance_focus * (d.mean() ** 2)) * 10.0 110 | 111 | 112 | def flip_lr(image): 113 | """ 114 | Flip image horizontally 115 | 116 | Parameters 117 | ---------- 118 | image : torch.Tensor [B,3,H,W] 119 | Image to be flipped 120 | 121 | Returns 122 | ------- 123 | image_flipped : torch.Tensor [B,3,H,W] 124 | Flipped image 125 | """ 126 | assert image.dim() == 4, 'You need to provide a [B,C,H,W] image to flip' 127 | return torch.flip(image, [3]) 128 | 129 | 130 | def fuse_inv_depth(inv_depth, inv_depth_hat, method='mean'): 131 | """ 132 | Fuse inverse depth and flipped inverse depth maps 133 | 134 | Parameters 135 | ---------- 136 | inv_depth : torch.Tensor [B,1,H,W] 137 | Inverse depth map 138 | inv_depth_hat : torch.Tensor [B,1,H,W] 139 
| Flipped inverse depth map produced from a flipped image 140 | method : str 141 | Method that will be used to fuse the inverse depth maps 142 | 143 | Returns 144 | ------- 145 | fused_inv_depth : torch.Tensor [B,1,H,W] 146 | Fused inverse depth map 147 | """ 148 | if method == 'mean': 149 | return 0.5 * (inv_depth + inv_depth_hat) 150 | elif method == 'max': 151 | return torch.max(inv_depth, inv_depth_hat) 152 | elif method == 'min': 153 | return torch.min(inv_depth, inv_depth_hat) 154 | else: 155 | raise ValueError('Unknown post-process method {}'.format(method)) 156 | 157 | 158 | def post_process_depth(depth, depth_flipped, method='mean'): 159 | """ 160 | Post-process an inverse and flipped inverse depth map 161 | 162 | Parameters 163 | ---------- 164 | inv_depth : torch.Tensor [B,1,H,W] 165 | Inverse depth map 166 | inv_depth_flipped : torch.Tensor [B,1,H,W] 167 | Inverse depth map produced from a flipped image 168 | method : str 169 | Method that will be used to fuse the inverse depth maps 170 | 171 | Returns 172 | ------- 173 | inv_depth_pp : torch.Tensor [B,1,H,W] 174 | Post-processed inverse depth map 175 | """ 176 | B, C, H, W = depth.shape 177 | inv_depth_hat = flip_lr(depth_flipped) 178 | inv_depth_fused = fuse_inv_depth(depth, inv_depth_hat, method=method) 179 | xs = torch.linspace(0., 1., W, device=depth.device, 180 | dtype=depth.dtype).repeat(B, C, H, 1) 181 | mask = 1.0 - torch.clamp(20. * (xs - 0.05), 0., 1.) 182 | mask_hat = flip_lr(mask) 183 | return mask_hat * depth + mask * inv_depth_hat + \ 184 | (1.0 - mask - mask_hat) * inv_depth_fused 185 | 186 | 187 | class DistributedSamplerNoEvenlyDivisible(Sampler): 188 | """Sampler that restricts data loading to a subset of the dataset. 189 | 190 | It is especially useful in conjunction with 191 | :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each 192 | process can pass a DistributedSampler instance as a DataLoader sampler, 193 | and load a subset of the original dataset that is exclusive to it. 194 | 195 | .. note:: 196 | Dataset is assumed to be of constant size. 197 | 198 | Arguments: 199 | dataset: Dataset used for sampling. 200 | num_replicas (optional): Number of processes participating in 201 | distributed training. 202 | rank (optional): Rank of the current process within num_replicas. 
203 | shuffle (optional): If true (default), sampler will shuffle the indices 204 | """ 205 | 206 | def __init__(self, dataset, num_replicas=None, rank=None, shuffle=True): 207 | if num_replicas is None: 208 | if not dist.is_available(): 209 | raise RuntimeError("Requires distributed package to be available") 210 | num_replicas = dist.get_world_size() 211 | if rank is None: 212 | if not dist.is_available(): 213 | raise RuntimeError("Requires distributed package to be available") 214 | rank = dist.get_rank() 215 | self.dataset = dataset 216 | self.num_replicas = num_replicas 217 | self.rank = rank 218 | self.epoch = 0 219 | num_samples = int(math.floor(len(self.dataset) * 1.0 / self.num_replicas)) 220 | rest = len(self.dataset) - num_samples * self.num_replicas 221 | if self.rank < rest: 222 | num_samples += 1 223 | self.num_samples = num_samples 224 | self.total_size = len(dataset) 225 | # self.total_size = self.num_samples * self.num_replicas 226 | self.shuffle = shuffle 227 | 228 | def __iter__(self): 229 | # deterministically shuffle based on epoch 230 | g = torch.Generator() 231 | g.manual_seed(self.epoch) 232 | if self.shuffle: 233 | indices = torch.randperm(len(self.dataset), generator=g).tolist() 234 | else: 235 | indices = list(range(len(self.dataset))) 236 | 237 | # add extra samples to make it evenly divisible 238 | # indices += indices[:(self.total_size - len(indices))] 239 | # assert len(indices) == self.total_size 240 | 241 | # subsample 242 | indices = indices[self.rank:self.total_size:self.num_replicas] 243 | self.num_samples = len(indices) 244 | # assert len(indices) == self.num_samples 245 | 246 | return iter(indices) 247 | 248 | def __len__(self): 249 | return self.num_samples 250 | 251 | def set_epoch(self, epoch): 252 | self.epoch = epoch -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import torch 5 | import argparse 6 | import numpy as np 7 | from tqdm import trange 8 | from SGR import DepthNet as SGRnet 9 | from MiDaS.midas_net import MidasNet 10 | from utils.model import Gradient_FusionModel 11 | from utils.func import scale_image, save_orig, visual_crfs 12 | from LeRes.multi_depth_model_woauxi import strip_prefix_if_present, RelDepthModel 13 | from dpt.models import DPTDepthModel 14 | from dpt.transforms import Resize, NormalizeImage, PrepareForNet 15 | from newcrfs.networks.NewCRFDepth import NewCRFDepth 16 | from torchvision.transforms import Compose 17 | import torchvision.transforms as transforms 18 | from torch.autograd import Variable 19 | 20 | 21 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 22 | torch.backends.cudnn.enabled = True 23 | torch.backends.cudnn.benchmark = True 24 | 25 | 26 | def run(args): 27 | size = 224 28 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 29 | if args.pred_model == 'LeRes50': 30 | Depth_model = RelDepthModel(backbone='resnet50') 31 | depth_dict = './LeRes/res50.pth' 32 | depth_dict = torch.load(depth_dict) 33 | Depth_model.load_state_dict(strip_prefix_if_present(depth_dict['depth_model'], "module."), strict=True) 34 | model_flag = 1 35 | 36 | elif args.pred_model == 'LeRes101': 37 | Depth_model = RelDepthModel(backbone='resnext101') 38 | depth_dict='./LeRes/res101.pth' 39 | depth_dict = torch.load(depth_dict) 40 | Depth_model.load_state_dict(strip_prefix_if_present(depth_dict['depth_model'], "module."), strict=True) 41 | model_flag = 2 42 
| 43 | elif args.pred_model == 'SGR': 44 | Depth_model = SGRnet.DepthNet() 45 | if device == torch.device("cuda"): 46 | Depth_model = torch.nn.DataParallel(Depth_model, device_ids=[0]).cuda() 47 | else: 48 | print('sgr model can not run correctly without gpu') 49 | exit() 50 | depth_dict = torch.load('./SGR/model.pth.tar') 51 | Depth_model.load_state_dict(depth_dict['state_dict']) 52 | model_flag = 3 53 | 54 | elif args.pred_model == 'MiDaS': 55 | Depth_model = MidasNet('./MiDaS/model.pt', non_negative=True) 56 | model_flag = 4 57 | size = 192 58 | 59 | elif args.pred_model == 'DPT': 60 | torch.backends.cudnn.enabled = True 61 | torch.backends.cudnn.benchmark = True 62 | Depth_model = DPTDepthModel( 63 | path="./dpt/weights/dpt_hybrid-midas-501f0c75.pt", 64 | backbone="vitb_rn50_384", 65 | non_negative=True, 66 | enable_attention_hooks=False, 67 | ) 68 | normalization = NormalizeImage(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) 69 | transform_low = Compose( 70 | [Resize( 71 | 384, 72 | 384, 73 | resize_target=None, 74 | keep_aspect_ratio=True, 75 | ensure_multiple_of=32, 76 | resize_method="minimal", 77 | image_interpolation_method=cv2.INTER_CUBIC, 78 | ), 79 | normalization, 80 | PrepareForNet(),]) 81 | 82 | transform_high = Compose( 83 | [Resize( 84 | 384*3, 85 | 384*3, 86 | resize_target=None, 87 | keep_aspect_ratio=True, 88 | ensure_multiple_of=32, 89 | resize_method="minimal", 90 | image_interpolation_method=cv2.INTER_CUBIC, 91 | ), 92 | normalization, 93 | PrepareForNet(),]) 94 | if device == torch.device("cuda"): 95 | Depth_model = Depth_model.to(memory_format=torch.channels_last) 96 | Depth_model = Depth_model.half() 97 | model_flag = 5 98 | 99 | elif args.pred_model == 'NeWCRFs': 100 | checkpoint_path = './newcrfs/model_nyu.ckpt' 101 | Depth_model = NewCRFDepth(version='large07', inv_depth=True, max_depth=1000) 102 | Depth_model = torch.nn.DataParallel(Depth_model) 103 | checkpoint = torch.load(checkpoint_path) 104 | Depth_model.load_state_dict(checkpoint['model']) 105 | model_flag = 6 106 | 107 | else: 108 | print('no such model') 109 | exit() 110 | 111 | Fusion_model = Gradient_FusionModel(dict_path=args.model_weights) 112 | 113 | Depth_model.to(device) 114 | Depth_model.eval() 115 | Fusion_model.to(device) 116 | Fusion_model.eval() 117 | 118 | img_names = glob.glob(os.path.join(args.input_path, "*")) 119 | img_names.sort() 120 | for index in trange(len(img_names)): 121 | img_loc = img_names[index] 122 | img = cv2.imread(img_loc) 123 | 124 | if model_flag == 5: 125 | img = img.astype('float32')/255.0 126 | low_img = transform_low({"image": img})["image"] 127 | high_img = transform_high({"image": img})["image"] 128 | elif model_flag == 6: 129 | img = img.astype('float32')/255.0 130 | low_img = cv2.resize(img, (640, 480)) 131 | high_img = cv2.resize(img, (640*3, 480*3)) 132 | normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) 133 | low_img = np.expand_dims(low_img, axis=0) 134 | low_img = np.transpose(low_img, (0, 3, 1, 2)) 135 | low_img = Variable(normalize(torch.from_numpy(low_img)).float()).cuda() 136 | high_img = np.expand_dims(high_img, axis=0) 137 | high_img = np.transpose(high_img, (0, 3, 1, 2)) 138 | high_img = Variable(normalize(torch.from_numpy(high_img)).float()).cuda() 139 | else: 140 | low_img, high_img = scale_image(img, size, device) 141 | 142 | with torch.no_grad(): 143 | if model_flag == 1 or model_flag == 2: 144 | low_dep = Depth_model.inference(low_img) 145 | high_dep = Depth_model.inference(high_img) 146 | 147 | elif 
model_flag == 3: 148 | low_dep = Depth_model.forward(low_img) 149 | high_dep = Depth_model.forward(high_img) 150 | low_dep = low_dep.max() - low_dep 151 | high_dep = high_dep.max() - high_dep 152 | 153 | elif model_flag == 4: 154 | low_dep = Depth_model.forward(low_img).unsqueeze(0) 155 | high_dep = Depth_model.forward(high_img).unsqueeze(0) 156 | low_dep = low_dep.max() - low_dep 157 | high_dep = high_dep.max() - high_dep 158 | 159 | elif model_flag == 5: 160 | sample = torch.from_numpy(low_img).to(device).unsqueeze(0) 161 | sample = sample.to(memory_format=torch.channels_last) 162 | sample = sample.half() 163 | low_dep = Depth_model.forward(sample) 164 | low_dep = (torch.nn.functional.interpolate( 165 | low_dep.unsqueeze(1), 166 | size=img.shape[:2], 167 | mode="bicubic", 168 | align_corners=False,)).float() 169 | sample = torch.from_numpy(high_img).to(device).unsqueeze(0) 170 | sample = sample.to(memory_format=torch.channels_last) 171 | sample = sample.half() 172 | high_dep = Depth_model.forward(sample) 173 | high_dep = (torch.nn.functional.interpolate( 174 | high_dep.unsqueeze(1), 175 | size=img.shape[:2], 176 | mode="bicubic", 177 | align_corners=False,)).float() 178 | low_dep = low_dep.max() - low_dep 179 | high_dep = high_dep.max() - high_dep 180 | 181 | elif model_flag == 6: 182 | low_dep = Depth_model(low_img) 183 | high_dep = Depth_model(high_img) 184 | low_dep, high_dep = visual_crfs(low_dep, high_dep) 185 | 186 | low_dep, high_dep, pred = Fusion_model.inference(low_dep, high_dep) 187 | save_orig(img, f'{args.output_path}/{args.pred_model}_{index}.jpg', low_dep, pred, model_flag) 188 | 189 | 190 | 191 | if __name__ == '__main__': 192 | parser = argparse.ArgumentParser() 193 | 194 | parser.add_argument('-i', '--input_path', 195 | default='./input', 196 | help='folder with input images' 197 | ) 198 | 199 | parser.add_argument('-o', '--output_path', 200 | default='./output', 201 | help='folder for output images' 202 | ) 203 | 204 | parser.add_argument('-m', '--model_weights', 205 | default='./models/model_dict.pt', 206 | help='path to the trained weights of model' 207 | ) 208 | 209 | parser.add_argument('-p', '--pred_model', 210 | default='LeRes50', 211 | help='model type: LeRes50, LeRes101, SGR ,MiDaS, DPT or NeWCRFs' 212 | ) 213 | 214 | args = parser.parse_args() 215 | run(args) 216 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os 2 | import h5py 3 | import torch 4 | import numpy as np 5 | import torch.utils.data as dataloader 6 | from utils.model import Gradient_FusionModel 7 | from torch.optim import lr_scheduler, AdamW 8 | import torchvision.transforms as transforms 9 | 10 | 11 | os.environ['CUDA_LAUNCH_BLOCKING'] = '1' 12 | os.environ["HDF5_USE_FILE_LOCKING"] = 'FALSE' 13 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 14 | 15 | class DataFromH5File(dataloader.Dataset): 16 | def __init__(self, filepath): 17 | h5File = h5py.File(filepath, 'r', swmr=True) 18 | self.photo = h5File['photo'][:] 19 | self.depth = h5File['depth'][:] 20 | 21 | def __getitem__(self, idx): 22 | global device 23 | transform_= transforms.Compose([transforms.ToTensor(), transforms.Resize(size=(2048, 2048))]) 24 | inputs = self.depth[idx] 25 | photo = self.photo[idx] 26 | with torch.no_grad(): 27 | low = transform_(inputs[0].astype('float32')).to(device) 28 | high = transform_(inputs[1].astype('float32')).to(device) 29 | guided = 
transform_(inputs[2].astype('float32')).to(device) 30 | return low, high, guided, photo 31 | 32 | def __len__(self): 33 | return self.photo.shape[0] 34 | 35 | 36 | def Train_HR(Fuse_model, optimizer, scheduler): 37 | index_list = [i+1 for i in range(36)] 38 | for index in index_list: 39 | data_loc = f'./datasets/HR/hq_HR_{index}.hdf5' 40 | data_set = DataFromH5File(data_loc) 41 | train_loader = dataloader.DataLoader(dataset=data_set, batch_size=2, shuffle=True, num_workers=0, pin_memory=False) 42 | for step, (low, high, guided, photo) in enumerate(train_loader): 43 | optimizer.zero_grad() 44 | loss = Fuse_model.cret(low.clone().detach(), high.clone().detach(), guided.clone().detach()) 45 | loss.backward() 46 | optimizer.step() 47 | if Fuse_model.total_step % 100 == 0: 48 | Fuse_model.record() 49 | Fuse_model.evaluate(low, high, guided, photo) 50 | scheduler.step() 51 | return Fuse_model, optimizer, scheduler 52 | 53 | 54 | def run(): 55 | global device 56 | log_path = 'logs/test' 57 | Fuse_model = Gradient_FusionModel(log_path=log_path) 58 | Fuse_model.to(device) 59 | 60 | optimizer = AdamW(filter(lambda p: p.requires_grad, Fuse_model.Fuse.parameters()), lr=(1e-4) * 1) 61 | scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99) 62 | for i in range(2): 63 | Fuse_model, optimizer, scheduler = Train_HR(Fuse_model, optimizer, scheduler) 64 | state = {'net': Fuse_model.Fuse.state_dict()} 65 | torch.save(state, log_path +f'/model_dict_{i}.pt') 66 | print(f'finish') 67 | 68 | 69 | if __name__ == '__main__': 70 | run() 71 | -------------------------------------------------------------------------------- /utils/func.py: -------------------------------------------------------------------------------- 1 | import cv2 2 | import torch 3 | import numpy as np 4 | import torchvision.transforms as transforms 5 | import matplotlib.pyplot as plt 6 | from utils.guided_f import guided_filter 7 | 8 | 9 | def shift_scale(pred, gt, mask_=None): 10 | gt_flat = gt.flatten() 11 | pred_flat = pred.flatten() 12 | mask_valid = np.ones_like(gt_flat) 13 | mask_valid[gt_flat == 0] = 0 14 | if mask_ is not None: 15 | mask_ = mask_.flatten() 16 | mask_valid[mask_ == 0] = 0 17 | 18 | gt_valid = np.array([gt_flat[i] for i in range(len(mask_valid)) if mask_valid[i]]) 19 | pred_valid = np.array([pred_flat[i] for i in range(len(mask_valid)) if mask_valid[i]]) 20 | 21 | para_s = np.polyfit(pred_valid, gt_valid, deg=1) 22 | pred = np.polyval(para_s, pred) 23 | return pred 24 | 25 | 26 | def generate_gf(low_dep, high_dep): 27 | r = int(high_dep.shape[0] / 12) - 1 28 | enhanced = guided_filter(high_dep, low_dep, r, 1e-12) 29 | return enhanced 30 | 31 | 32 | def visual_crfs(low_dep, high_dep): 33 | low_dep[:, :, :, :5] = low_dep.min() 34 | low_dep[:, :, :, -5:] = low_dep.min() 35 | low_dep[:, :, :5, :] = low_dep.min() 36 | low_dep[:, :, -5:, :] = low_dep.min() 37 | high_dep[:, :, :, :15] = high_dep.min() 38 | high_dep[:, :, :, -15:] = high_dep.min() 39 | high_dep[:, :, :15, :] = high_dep.min() 40 | high_dep[:, :, -15:, :] = high_dep.min() 41 | return low_dep, high_dep 42 | 43 | 44 | def img2Tensor(img, scale, model_input_size=224): 45 | transformer = transforms.Compose([transforms.ToTensor(), transforms.Resize(size=(model_input_size*scale, model_input_size*scale)),\ 46 | transforms.Normalize((0.485, 0.456, 0.406) , (0.229, 0.224, 0.225))]) 47 | tens = transformer(img.copy()) 48 | return tens[None, :, :, :] 49 | 50 | 51 | def scale_image(img, size, device): 52 | scale_list = [2, 6] 53 | tensor_list = [] 54 | 55 | for 
scale in scale_list: 56 | tensor = img2Tensor(img, scale, size) 57 | if device == torch.device("cuda"): 58 | tensor_list.append(tensor.cuda()) 59 | else: 60 | tensor_list.append(tensor) 61 | return tensor_list 62 | 63 | 64 | def save_orig(input_rgb, img_loc, low_dep, pred, model_flag): 65 | if model_flag >= 5: 66 | input_rgb = (input_rgb * 255).astype('uint8') 67 | input_rgb = cv2.resize(input_rgb, None, fx=0.5, fy=0.5) 68 | h, w, _ = input_rgb.shape 69 | low_dep = cv2.resize(low_dep.cpu().detach().numpy().squeeze(), (w, h)) 70 | pred = cv2.resize(pred.cpu().detach().numpy().squeeze(), (w, h)) 71 | low_dep = 255 - (low_dep - low_dep.min()) / (low_dep.max() - low_dep.min()) * 255 72 | pred = 255 - (pred - pred.min()) / (pred.max() - pred.min()) * 255 73 | result = np.hstack((low_dep, low_dep, pred)) 74 | plt.imsave(img_loc, result, cmap='inferno') 75 | result = cv2.imread(img_loc) 76 | result[:h, :w, :] = input_rgb 77 | cv2.imwrite(img_loc, result) 78 | -------------------------------------------------------------------------------- /utils/guided_f.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import scipy as sp 3 | import scipy.ndimage 4 | import matplotlib.pyplot as plt 5 | 6 | def box(img, r): 7 | """ O(1) box filter 8 | img - >= 2d image 9 | r - radius of box filter 10 | """ 11 | (rows, cols) = img.shape[:2] 12 | imDst = np.zeros_like(img) 13 | 14 | 15 | tile = [1] * img.ndim 16 | tile[0] = r 17 | imCum = np.cumsum(img, 0) 18 | imDst[0:r+1, :, ...] = imCum[r:2*r+1, :, ...] 19 | imDst[r+1:rows-r, :, ...] = imCum[2*r+1:rows, :, ...] - imCum[0:rows-2*r-1, :, ...] 20 | imDst[rows-r:rows, :, ...] = np.tile(imCum[rows-1:rows, :, ...], tile) - imCum[rows-2*r-1:rows-r-1, :, ...] 21 | 22 | tile = [1] * img.ndim 23 | tile[1] = r 24 | imCum = np.cumsum(imDst, 1) 25 | imDst[:, 0:r+1, ...] = imCum[:, r:2*r+1, ...] 26 | imDst[:, r+1:cols-r, ...] = imCum[:, 2*r+1 : cols, ...] - imCum[:, 0 : cols-2*r-1, ...] 27 | imDst[:, cols-r: cols, ...] = np.tile(imCum[:, cols-1:cols, ...], tile) - imCum[:, cols-2*r-1 : cols-r-1, ...] 
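    # (Added note, not part of the original file.) The two passes above replace an
    # explicit (2r+1) x (2r+1) window sum with cumulative-sum differences: the first
    # pass sums each column over a vertical window of 2r+1 rows, the second does the
    # same horizontally, so the cost per pixel is O(1) regardless of the radius r.
    # Windows that reach past the image border are simply truncated (np.tile
    # replicates the last cumulative row/column so the border differences stay well
    # formed), which is why the callers divide by N = box(np.ones((h, w)), r) to
    # normalise by the true window size at every pixel.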
28 | 29 | return imDst 30 | 31 | def _gf_color(I, p, r, eps, s=None): 32 | """ Color guided filter 33 | I - guide image (rgb) 34 | p - filtering input (single channel) 35 | r - window radius 36 | eps - regularization (roughly, variance of non-edge noise) 37 | s - subsampling factor for fast guided filter 38 | """ 39 | fullI = I 40 | fullP = p 41 | if s is not None: 42 | I = sp.ndimage.zoom(fullI, [1/s, 1/s, 1], order=1) 43 | p = sp.ndimage.zoom(fullP, [1/s, 1/s], order=1) 44 | r = round(r / s) 45 | 46 | h, w = p.shape[:2] 47 | N = box(np.ones((h, w)), r) 48 | 49 | mI_r = box(I[:,:,0], r) / N 50 | mI_g = box(I[:,:,1], r) / N 51 | mI_b = box(I[:,:,2], r) / N 52 | 53 | mP = box(p, r) / N 54 | 55 | # mean of I * p 56 | mIp_r = box(I[:,:,0]*p, r) / N 57 | mIp_g = box(I[:,:,1]*p, r) / N 58 | mIp_b = box(I[:,:,2]*p, r) / N 59 | 60 | # per-patch covariance of (I, p) 61 | covIp_r = mIp_r - mI_r * mP 62 | covIp_g = mIp_g - mI_g * mP 63 | covIp_b = mIp_b - mI_b * mP 64 | 65 | # symmetric covariance matrix of I in each patch: 66 | # rr rg rb 67 | # rg gg gb 68 | # rb gb bb 69 | var_I_rr = box(I[:,:,0] * I[:,:,0], r) / N - mI_r * mI_r; 70 | var_I_rg = box(I[:,:,0] * I[:,:,1], r) / N - mI_r * mI_g; 71 | var_I_rb = box(I[:,:,0] * I[:,:,2], r) / N - mI_r * mI_b; 72 | 73 | var_I_gg = box(I[:,:,1] * I[:,:,1], r) / N - mI_g * mI_g; 74 | var_I_gb = box(I[:,:,1] * I[:,:,2], r) / N - mI_g * mI_b; 75 | 76 | var_I_bb = box(I[:,:,2] * I[:,:,2], r) / N - mI_b * mI_b; 77 | 78 | a = np.zeros((h, w, 3)) 79 | for i in range(h): 80 | for j in range(w): 81 | sig = np.array([ 82 | [var_I_rr[i,j], var_I_rg[i,j], var_I_rb[i,j]], 83 | [var_I_rg[i,j], var_I_gg[i,j], var_I_gb[i,j]], 84 | [var_I_rb[i,j], var_I_gb[i,j], var_I_bb[i,j]] 85 | ]) 86 | covIp = np.array([covIp_r[i,j], covIp_g[i,j], covIp_b[i,j]]) 87 | a[i,j,:] = np.linalg.solve(sig + eps * np.eye(3), covIp) 88 | 89 | b = mP - a[:,:,0] * mI_r - a[:,:,1] * mI_g - a[:,:,2] * mI_b 90 | 91 | meanA = box(a, r) / N[...,np.newaxis] 92 | meanB = box(b, r) / N 93 | 94 | if s is not None: 95 | meanA = sp.ndimage.zoom(meanA, [s, s, 1], order=1) 96 | meanB = sp.ndimage.zoom(meanB, [s, s], order=1) 97 | 98 | q = np.sum(meanA * fullI, axis=2) + meanB 99 | 100 | return q 101 | 102 | 103 | def _gf_gray(I, p, r, eps, s=None): 104 | """ grayscale (fast) guided filter 105 | I - guide image (1 channel) 106 | p - filter input (1 channel) 107 | r - window raidus 108 | eps - regularization (roughly, allowable variance of non-edge noise) 109 | s - subsampling factor for fast guided filter 110 | """ 111 | if s is not None: 112 | Isub = sp.ndimage.zoom(I, 1/s, order=1) 113 | Psub = sp.ndimage.zoom(p, 1/s, order=1) 114 | r = round(r / s) 115 | else: 116 | Isub = I 117 | Psub = p 118 | 119 | 120 | (rows, cols) = Isub.shape 121 | 122 | N = box(np.ones([rows, cols]), r) 123 | 124 | meanI = box(Isub, r) / N 125 | meanP = box(Psub, r) / N 126 | corrI = box(Isub * Isub, r) / N 127 | corrIp = box(Isub * Psub, r) / N 128 | varI = corrI - meanI * meanI 129 | covIp = corrIp - meanI * meanP 130 | 131 | 132 | a = covIp / (varI + eps) 133 | b = meanP - a * meanI 134 | 135 | meanA = box(a, r) / N 136 | meanB = box(b, r) / N 137 | 138 | if s is not None: 139 | meanA = sp.ndimage.zoom(meanA, s, order=1) 140 | meanB = sp.ndimage.zoom(meanB, s, order=1) 141 | 142 | q = meanA * I + meanB 143 | return q 144 | 145 | 146 | def _gf_colorgray(I, p, r, eps, s=None): 147 | """ automatically choose color or gray guided filter based on I's shape """ 148 | if I.ndim == 2 or I.shape[2] == 1: 149 | return _gf_gray(I, p, r, 
eps, s) 150 | elif I.ndim == 3 and I.shape[2] == 3: 151 | return _gf_color(I, p, r, eps, s) 152 | else: 153 | print("Invalid guide dimensions:", I.shape) 154 | 155 | 156 | def guided_filter(I, p, r, eps, s=None): 157 | """ run a guided filter per-channel on filtering input p 158 | I - guide image (1 or 3 channel) 159 | p - filter input (n channel) 160 | r - window raidus 161 | eps - regularization (roughly, allowable variance of non-edge noise) 162 | s - subsampling factor for fast guided filter 163 | """ 164 | if p.ndim == 2: 165 | p3 = p[:,:,np.newaxis] 166 | 167 | out = np.zeros_like(p3) 168 | for ch in range(p3.shape[2]): 169 | out[:,:,ch] = _gf_colorgray(I, p3[:,:,ch], r, eps, s) 170 | return np.squeeze(out) if p.ndim == 2 else out 171 | 172 | 173 | def test_gf(low, high): 174 | low = imageio.imread('14low.jpg').astype(np.float32) / 255 175 | high = imageio.imread('14high.jpg').astype(np.float32) / 255 176 | 177 | r = 160 178 | eps = 0.000000001 179 | dep_smoothed = guided_filter(high, low, r, eps) 180 | 181 | -------------------------------------------------------------------------------- /utils/hypersim.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import h5py 5 | from pylab import * 6 | import numpy as np 7 | from utils.func import * 8 | 9 | 10 | def convetDep(file): 11 | intWidth, intHeight, fltFocal = 1024, 768, 886.81 12 | npyDistance = file 13 | npyImageplaneX = np.linspace((-0.5 * intWidth) + 0.5, (0.5 * intWidth) - 0.5, intWidth).reshape(1, intWidth).repeat(intHeight, 0).astype(np.float32)[:, :, None] 14 | npyImageplaneY = np.linspace((-0.5 * intHeight) + 0.5, (0.5 * intHeight) - 0.5, intHeight).reshape(intHeight, 1).repeat(intWidth, 1).astype(np.float32)[:, :, None] 15 | npyImageplaneZ = np.full([intHeight, intWidth, 1], fltFocal, np.float32) 16 | npyImageplane = np.concatenate([npyImageplaneX, npyImageplaneY, npyImageplaneZ], 2) 17 | 18 | npyDepth = npyDistance / np.linalg.norm(npyImageplane, 2, 2) * fltFocal 19 | return npyDepth 20 | 21 | 22 | def compute_global_errors(gt, pred): 23 | gt=gt[gt!=0] 24 | pred=pred[pred!=0] 25 | 26 | mask2 = gt > 1e-8 27 | mask3 = pred > 1e-8 28 | mask2 = mask2 & mask3 29 | 30 | gt = gt[mask2] 31 | pred = pred[mask2] 32 | 33 | #compute global relative errors 34 | thresh = np.maximum((gt / pred), (pred / gt)) 35 | thr1 = (thresh < 1.25 ).mean() 36 | thr2 = (thresh < 1.25 ** 2).mean() 37 | rmse = (gt - pred) ** 2 38 | rmse = np.sqrt(rmse.mean()) 39 | log10 = np.mean(np.abs(np.log10(gt) - np.log10(pred))) 40 | sq_rel = np.mean(((gt - pred)**2) / gt) 41 | 42 | return sq_rel, rmse, log10, thr1, thr2 43 | 44 | 45 | class hypersim(): 46 | def __init__(self, path= './datasets/hypersim'): 47 | self.input_path = path 48 | self.num = 286 49 | self.rms = np.zeros(self.num, np.float32) 50 | self.log10 = np.zeros(self.num, np.float32) 51 | self.sq_rel = np.zeros(self.num, np.float32) 52 | self.thr1 = np.zeros(self.num, np.float32) 53 | self.thr2 = np.zeros(self.num, np.float32) 54 | self.d3r_rel = np.zeros(self.num, np.float32) 55 | self.ord_rel = np.zeros(self.num, np.float32) 56 | self.img_names = glob.glob(os.path.join(self.input_path, "*")) 57 | self.index = -1 58 | self.name_list = [] 59 | 60 | 61 | def getitem(self): 62 | if not len(self.name_list): 63 | self.sub_set = self.img_names.pop() 64 | img_loc = self.sub_set + '/images/scene_cam_00_final_hdf5/' 65 | name_list = glob.glob(os.path.join(img_loc, "*")) 66 | self.name_list = [name for name in name_list if 
'color' in name] 67 | file_loc = self.name_list.pop() 68 | 69 | dep_loc = file_loc.replace('scene_cam_00_final_hdf5', 'scene_cam_00_geometry_hdf5') 70 | dep_loc = dep_loc.replace('color', 'depth_meters') 71 | entity_loc = dep_loc.replace('depth_meters', 'render_entity_id') 72 | 73 | with h5py.File(file_loc, "r") as f: rgb_color = f["dataset"][:].astype('float32') 74 | with h5py.File(entity_loc, "r") as f: render_entity_id = f["dataset"][:].astype('int32') 75 | # assert all(render_entity_id != 0) 76 | 77 | gamma = .5/2.2 # standard gamma correction exponent 78 | inv_gamma = 1.0/gamma 79 | percentile = 90 # we want this percentile brightness value in the unmodified image... 80 | brightness_nth_percentile_desired = 0.8 # ...to be this bright after scaling 81 | valid_mask = render_entity_id != -1 82 | if count_nonzero(valid_mask) == 0: 83 | scale = 1.0 # if there are no valid pixels, then set scale to 1.0 84 | else: 85 | brightness = 0.3*rgb_color[:,:,0] + 0.59*rgb_color[:,:,1] + 0.11*rgb_color[:,:,2] # "CCIR601 YIQ" method for computing brightness 86 | brightness_valid = brightness[valid_mask] 87 | eps = 0.0001 # if the kth percentile brightness value in the unmodified image is less than this, set the scale to 0.0 to avoid divide-by-zero 88 | brightness_nth_percentile_current = np.percentile(brightness_valid, percentile) 89 | if brightness_nth_percentile_current < eps: 90 | scale = 0.0 91 | else: 92 | scale = np.power(brightness_nth_percentile_desired, inv_gamma) / brightness_nth_percentile_current 93 | rgb_color_tm = np.power(np.maximum(scale*rgb_color,0), gamma) 94 | 95 | img = rgb_color_tm 96 | img = img/img.max() * 255 97 | img_bgr = img[:, :, ::-1] 98 | img_bgr = img_bgr.astype('uint8') 99 | 100 | with h5py.File(dep_loc, "r") as f: 101 | dep = f["dataset"][:] 102 | 103 | nan = float('nan') 104 | dep[np.isnan(dep)] = 0 105 | dep = dep.astype('float32') 106 | val_mask = np.ones_like(dep) 107 | val_mask[dep==0]=0 108 | self.index += 1 109 | return img_bgr, dep, val_mask 110 | 111 | def compute_error(self, target, depth, val_mask): 112 | target = target.cpu().numpy().squeeze() 113 | h, w = depth.shape 114 | val_mask = cv2.resize(val_mask, (w, h)) 115 | target = cv2.resize(target, (w, h)) 116 | target = shift_scale(target.astype('float64'), depth.astype('float64'), val_mask) 117 | 118 | pred = target.copy() 119 | pred_org = pred.copy() 120 | pred_invalid = pred.copy() 121 | pred_invalid[pred_invalid!=0]=1 122 | mask_missing = depth.copy() # Mask for further missing depth values in depth map 123 | mask_missing[mask_missing!=0]=1 124 | mask_valid = mask_missing*pred_invalid # Combine masks 125 | depth_valid = depth*mask_valid 126 | gt = depth_valid 127 | gt_vec = gt.flatten() 128 | pred = pred*mask_valid 129 | pred_vec = pred.flatten() 130 | self.sq_rel[self.index], self.rms[self.index], self.log10[self.index], self.thr1[self.index], self.thr2[self.index] = compute_global_errors(gt_vec,pred_vec) 131 | -------------------------------------------------------------------------------- /utils/middleburry2021.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import csv 4 | import cv2 5 | import glob 6 | import numpy as np 7 | from utils.func import * 8 | 9 | 10 | 11 | def read_calib(calib_file_path): 12 | with open(calib_file_path, 'r') as calib_file: 13 | calib = {} 14 | csv_reader = csv.reader(calib_file, delimiter='=') 15 | for attr, value in csv_reader: 16 | calib.setdefault(attr, value) 17 | return calib 18 | 19 | 20 | def 
read_pfm(pfm_file_path): 21 | with open(pfm_file_path, 'rb') as pfm_file: 22 | header = pfm_file.readline().decode().rstrip() 23 | channels = 3 if header == 'PF' else 1 24 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', pfm_file.readline().decode('utf-8')) 25 | if dim_match: 26 | width, height = map(int, dim_match.groups()) 27 | else: 28 | raise Exception("Malformed PFM header.") 29 | scale = float(pfm_file.readline().decode().rstrip()) 30 | if scale < 0: 31 | endian = '<' # littel endian 32 | scale = -scale 33 | else: 34 | endian = '>' # big endian 35 | dispariy = np.fromfile(pfm_file, endian + 'f') 36 | return dispariy, [(height, width, channels), scale] 37 | 38 | 39 | def create_depth_map(pfm_file_path, calib_file_path=None): 40 | dispariy, [shape,scale] = read_pfm(pfm_file_path) 41 | if calib_file_path is None: 42 | raise Exception("Loss calibration information.") 43 | else: 44 | calib = read_calib(calib_file_path) 45 | fx = float(calib['cam0'].split(' ')[0].lstrip('[')) 46 | base_line = float(calib['baseline']) 47 | doffs = float(calib['doffs']) 48 | depth_map = fx*base_line / (dispariy / scale + doffs) 49 | depth_map = np.reshape(depth_map, newshape=shape) 50 | depth_map = np.flipud(depth_map).astype('uint8') 51 | return depth_map 52 | 53 | 54 | def compute_global_errors(gt, pred): 55 | gt=gt[gt!=0] 56 | pred=pred[pred!=0] 57 | 58 | mask2 = gt > 1e-8 59 | mask3 = pred > 1e-8 60 | mask2 = mask2 & mask3 61 | 62 | gt = gt[mask2] 63 | pred = pred[mask2] 64 | 65 | #compute global relative errors 66 | thresh = np.maximum((gt / pred), (pred / gt)) 67 | thr1 = (thresh < 1.25 ).mean() 68 | thr2 = (thresh < 1.25 ** 2).mean() 69 | rmse = (gt - pred) ** 2 70 | rmse = np.sqrt(rmse.mean()) 71 | log10 = np.mean(np.abs(np.log10(gt) - np.log10(pred))) 72 | sq_rel = np.mean(((gt - pred)**2) / gt) 73 | 74 | return sq_rel, rmse, log10, thr1, thr2 75 | 76 | 77 | class middleburry(): 78 | def __init__(self, path= './datasets/2021mobile'): 79 | self.input_path = path 80 | self.num = 48 81 | self.rms = np.zeros(self.num, np.float32) 82 | self.log10 = np.zeros(self.num, np.float32) 83 | self.sq_rel = np.zeros(self.num, np.float32) 84 | self.thr1 = np.zeros(self.num, np.float32) 85 | self.thr2 = np.zeros(self.num, np.float32) 86 | self.d3r_rel = np.zeros(self.num, np.float32) 87 | self.ord_rel = np.zeros(self.num, np.float32) 88 | self.img_names = glob.glob(os.path.join(self.input_path, "*")) 89 | self.img_names.sort() 90 | self.img_names.reverse() 91 | self.index = -1 92 | self.dex = 0 93 | 94 | 95 | def getitem(self): 96 | if self.dex == 0: 97 | self.img_loc = self.img_names.pop() 98 | img = cv2.imread(self.img_loc+f'/im{self.dex}.png')[:, :, ::-1] 99 | dep = create_depth_map(f'{self.img_loc}/disp{self.dex}.pfm', f'{self.img_loc}/calib.txt')[:, :, 0] 100 | val_mask = np.ones_like(dep) 101 | val_mask[dep==0] = 0 102 | self.dex = (self.dex + 1) % 2 103 | self.index += 1 104 | return img, dep, val_mask 105 | 106 | def compute_error(self, target, depth, val_mask): 107 | target = target.cpu().numpy().squeeze() 108 | h, w = depth.shape 109 | val_mask = cv2.resize(val_mask, (w, h)) 110 | target = cv2.resize(target, (w, h)) 111 | target = shift_scale(target.astype('float64'), depth.astype('float64'), val_mask) 112 | 113 | pred = target.copy() 114 | pred_org = pred.copy() 115 | pred_invalid = pred.copy() 116 | pred_invalid[pred_invalid!=0]=1 117 | mask_missing = depth.copy() # Mask for further missing depth values in depth map 118 | mask_missing[mask_missing!=0]=1 119 | mask_valid = mask_missing*pred_invalid # Combine 
masks 120 | depth_valid = depth*mask_valid 121 | gt = depth_valid 122 | gt_vec = gt.flatten() 123 | pred = pred*mask_valid 124 | pred_vec = pred.flatten() 125 | self.sq_rel[self.index], self.rms[self.index], self.log10[self.index], self.thr1[self.index], self.thr2[self.index] = compute_global_errors(gt_vec,pred_vec) 126 | -------------------------------------------------------------------------------- /utils/multiscopic.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import glob 4 | import numpy as np 5 | from utils.func import * 6 | 7 | 8 | def compute_global_errors(gt, pred): 9 | gt=gt[gt!=0] 10 | pred=pred[pred!=0] 11 | 12 | mask2 = gt > 1e-8 13 | mask3 = pred > 1e-8 14 | mask2 = mask2 & mask3 15 | 16 | gt = gt[mask2] 17 | pred = pred[mask2] 18 | 19 | #compute global relative errors 20 | thresh = np.maximum((gt / pred), (pred / gt)) 21 | thr1 = (thresh < 1.25 ).mean() 22 | thr2 = (thresh < 1.25 ** 2).mean() 23 | rmse = (gt - pred) ** 2 24 | rmse = np.sqrt(rmse.mean()) 25 | log10 = np.mean(np.abs(np.log10(gt) - np.log10(pred))) 26 | sq_rel = np.mean(((gt - pred)**2) / gt) 27 | 28 | return sq_rel, rmse, log10, thr1, thr2 29 | 30 | 31 | class multiscopic(): 32 | def __init__(self, path= './datasets/multiscopic/test_b'): 33 | self.input_path = path 34 | self.num = 500 35 | self.rms = np.zeros(self.num, np.float32) 36 | self.log10 = np.zeros(self.num, np.float32) 37 | self.sq_rel = np.zeros(self.num, np.float32) 38 | self.thr1 = np.zeros(self.num, np.float32) 39 | self.thr2 = np.zeros(self.num, np.float32) 40 | self.d3r_rel = np.zeros(self.num, np.float32) 41 | self.ord_rel = np.zeros(self.num, np.float32) 42 | self.img_names = glob.glob(os.path.join(self.input_path, "*")) 43 | self.img_names.sort() 44 | self.index = -1 45 | self.dex = 0 46 | 47 | 48 | def getitem(self): 49 | if self.dex == 0: 50 | self.img_loc = self.img_names.pop() 51 | img = cv2.imread(self.img_loc+f'/view{self.dex}.png')[:, :, ::-1] 52 | dep_loc = self.img_loc + f'/disp{self.dex}.png' 53 | depth = cv2.imread(dep_loc)[:, :, 0] 54 | 55 | val_mask = np.ones_like(depth) 56 | val_mask[depth==0]=0 57 | depth = depth.max() - depth 58 | 59 | self.dex = (self.dex + 1) % 5 60 | self.index += 1 61 | return img, depth, val_mask 62 | 63 | def compute_error(self, target, depth, val_mask): 64 | target = target.cpu().numpy().squeeze() 65 | h, w = depth.shape 66 | val_mask = cv2.resize(val_mask, (w, h)) 67 | target = cv2.resize(target, (w, h)) 68 | target = shift_scale(target.astype('float64'), depth.astype('float64'), val_mask) 69 | 70 | pred = target.copy() 71 | pred_org = pred.copy() 72 | pred_invalid = pred.copy() 73 | pred_invalid[pred_invalid!=0]=1 74 | mask_missing = depth.copy() # Mask for further missing depth values in depth map 75 | mask_missing[mask_missing!=0]=1 76 | mask_valid = mask_missing*pred_invalid # Combine masks 77 | depth_valid = depth*mask_valid 78 | gt = depth_valid 79 | gt_vec = gt.flatten() 80 | pred = pred*mask_valid 81 | pred_vec = pred.flatten() 82 | self.sq_rel[self.index], self.rms[self.index], self.log10[self.index], self.thr1[self.index], self.thr2[self.index] = compute_global_errors(gt_vec,pred_vec) 83 | --------------------------------------------------------------------------------
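A quick worked example of the disparity-to-depth conversion performed by create_depth_map() in utils/middleburry2021.py; the calibration numbers below are illustrative stand-ins, not values read from any calib.txt:

    fx, baseline, doffs, scale = 1758.23, 111.53, 0.0, 1.0  # hypothetical focal length (px) and baseline (mm)
    d = 64.0                                                 # disparity in pixels
    Z = fx * baseline / (d / scale + doffs)                  # same formula as create_depth_map()
    print(round(Z))                                          # ~3064 mm, i.e. roughly 3 m from the camera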
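And a minimal end-to-end sketch of the classical guided-filter fusion and scale-shift alignment helpers defined above (not the learned Gradient_FusionModel path used by run.py). It assumes two single-channel depth predictions and a ground-truth map of the same resolution, stored as hypothetical low.npy / high.npy / gt.npy, and is run from the repository root so the utils package is importable:

    import numpy as np
    from utils.func import generate_gf, shift_scale

    low_dep = np.load('low.npy').astype('float32')    # coarse but globally consistent prediction
    high_dep = np.load('high.npy').astype('float32')  # detailed but noisier prediction

    # generate_gf() wraps guided_filter(high_dep, low_dep, r, 1e-12): the high-resolution
    # map acts as the guide and transfers its fine gradients onto the coarse layout.
    fused = generate_gf(low_dep, high_dep)

    # The evaluation classes (hypersim / middleburry / multiscopic) align predictions to
    # the ground truth with a least-squares scale and shift (np.polyfit, deg=1) before
    # computing sq_rel, RMSE, log10 and the delta thresholds.
    gt = np.load('gt.npy').astype('float64')
    aligned = shift_scale(fused.astype('float64'), gt)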