├── .gitignore
├── LICENSE
├── MANIFEST.in
├── README.md
├── rmvsnet
│   ├── RMVSNet-pretrained.pth
│   ├── __init__.py
│   ├── gru.py
│   ├── rmvsnet.py
│   ├── unet_ds2gn.py
│   └── warping.py
└── setup.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
*~
.*
*.pyc
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2019 Jae Yong Lee

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
--------------------------------------------------------------------------------
/MANIFEST.in:
--------------------------------------------------------------------------------
include rmvsnet/RMVSNet-pretrained.pth
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# R-MVSNet PyTorch
PyTorch implementation of R-MVSNet.

This repo uses PyTorch-ported weights from the original authors' TensorFlow implementation.

## Installation

```
pip install git+https://github.com/leejaeyong7/rmvsnet-pytorch.git
```

## Usage

```python
import torch
from rmvsnet import RMVSNet

'''
Args:
    images: Nx3xHxW tensor; H and W should be multiples of 16.
            The first image is the reference view.
    intrinsics: Nx3x3 tensor
    extrinsics: Nx4x4 tensor
    depth_start: float
    depth_interval: float
    depth_num: int

    depth hypotheses are computed by: depth_start + range(depth_num) * depth_interval

Returns:
    depths: tensor of shape (H/4)x(W/4)
    probs: tensor of shape (H/4)x(W/4)
'''

model = RMVSNet()

# optional: put model on a GPU
model.to(torch.device('cuda:0'))

depths, probs = model(images, intrinsics, extrinsics, depth_start, depth_interval, depth_num)
```
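For reference, a minimal end-to-end sketch with dummy inputs. The camera matrices and depth range below are illustrative placeholders, not calibrated data:

```python
import torch
from rmvsnet import RMVSNet

N, H, W = 3, 128, 160            # 3 views; H and W are multiples of 16
images = torch.rand(N, 3, H, W)  # first image is the reference view
intrinsics = torch.eye(3).expand(N, 3, 3).contiguous()
extrinsics = torch.eye(4).expand(N, 4, 4).contiguous()

model = RMVSNet()
model.eval()
with torch.no_grad():
    depths, probs = model(images, intrinsics, extrinsics,
                          depth_start=425.0, depth_interval=2.5, depth_num=48)
print(depths.shape, probs.shape)  # both (H/4) x (W/4) = 32 x 40
```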
## About
This is a custom PyTorch port of the [original MVSNet implementation in TensorFlow](https://github.com/YoYo000/MVSNet).
We use the same weights that the authors provided (GRU + DTU).

## Reference

```
@article{yao2019recurrent,
  title={Recurrent MVSNet for High-resolution Multi-view Stereo Depth Inference},
  author={Yao, Yao and Luo, Zixin and Li, Shiwei and Shen, Tianwei and Fang, Tian and Quan, Long},
  journal={Computer Vision and Pattern Recognition (CVPR)},
  year={2019}
}
```
--------------------------------------------------------------------------------
/rmvsnet/RMVSNet-pretrained.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/leejaeyong7/RMVSNet-Pytorch/c87b97b5c4338250159d03cc30796d2e93515821/rmvsnet/RMVSNet-pretrained.pth
--------------------------------------------------------------------------------
/rmvsnet/__init__.py:
--------------------------------------------------------------------------------
# -*- coding: utf-8 -*-
""" R-MVSNet package export. """
from .rmvsnet import RMVSNet
name = "rmvsnet"

__all__ = ['RMVSNet']
--------------------------------------------------------------------------------
/rmvsnet/gru.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn


class GRU(nn.Module):
    '''Convolutional GRU cell with group-normalized gates.'''
    def __init__(self, input_channel, output_channel, kernel_size):
        super(GRU, self).__init__()

        # filters used for gates
        gru_input_channel = input_channel + output_channel
        self.output_channel = output_channel

        self.gate_conv = nn.Conv2d(gru_input_channel, output_channel * 2, kernel_size, padding=1)
        self.reset_gate_norm = nn.GroupNorm(1, output_channel, 1e-5, True)
        self.update_gate_norm = nn.GroupNorm(1, output_channel, 1e-5, True)

        # filters used for outputs
        self.output_conv = nn.Conv2d(gru_input_channel, output_channel, kernel_size, padding=1)
        self.output_norm = nn.GroupNorm(1, output_channel, 1e-5, True)

        self.activation = nn.Tanh()

    def gates(self, x, h):
        # x = N x C_in x H x W
        # h = N x C_out x H x W

        # c = N x (C_in + C_out) x H x W
        c = torch.cat((x, h), dim=1)
        f = self.gate_conv(c)

        # r = reset gate, u = update gate
        # both are N x C_out x H x W
        C = f.shape[1]
        r, u = torch.split(f, C // 2, 1)

        rn = self.reset_gate_norm(r)
        un = self.update_gate_norm(u)
        rns = torch.sigmoid(rn)
        uns = torch.sigmoid(un)
        return rns, uns

    def output(self, x, h, r, u):
        f = torch.cat((x, r * h), dim=1)
        o = self.output_conv(f)
        on = self.output_norm(o)
        return on

    def forward(self, x, h=None):
        N, C, H, W = x.shape
        HC = self.output_channel
        if h is None:
            h = torch.zeros((N, HC, H, W), dtype=torch.float, device=x.device)
        r, u = self.gates(x, h)
        o = self.output(x, h, r, u)
        y = self.activation(o)
        return u * h + (1 - u) * y
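
# Illustrative usage sketch (assumption, not part of the original file): the
# cell maps an N x C_in x H x W input plus an optional hidden state to an
# N x C_out x H x W output, which doubles as the next hidden state. Note the
# hard-coded padding=1 assumes kernel_size=3, as used throughout this repo.
if __name__ == '__main__':
    cell = GRU(input_channel=32, output_channel=16, kernel_size=3)
    x = torch.rand(1, 32, 32, 40)
    h1 = cell(x)        # hidden state defaults to zeros when omitted
    h2 = cell(x, h1)    # recurrent step reuses the previous state
    assert h2.shape == (1, 16, 32, 40)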
--------------------------------------------------------------------------------
/rmvsnet/rmvsnet.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
from os import path

from .gru import GRU
from .unet_ds2gn import UNetDS2GN

from .warping import get_homographies, warp_homographies


class RMVSNet(nn.Module):
    def __init__(self):
        super(RMVSNet, self).__init__()
        # setup network modules
        self.feature_extractor = UNetDS2GN()

        gru_input_size = self.feature_extractor.output_size
        gru1_output_size = 16
        gru2_output_size = 4
        gru3_output_size = 2
        self.gru1 = GRU(gru_input_size, gru1_output_size, 3)
        self.gru2 = GRU(gru1_output_size, gru2_output_size, 3)
        self.gru3 = GRU(gru2_output_size, gru3_output_size, 3)

        self.prob_conv = nn.Conv2d(2, 1, 3, 1, 1)

        # load ported weights onto CPU; move with .to(device) as needed
        file_path = path.dirname(path.abspath(__file__))
        pretrained_weights_file = path.join(file_path,
                                            'RMVSNet-pretrained.pth')
        pretrained_weights = torch.load(pretrained_weights_file, map_location='cpu')
        self.load_state_dict(pretrained_weights)

    def compute_cost_volume(self, warped):
        '''
        warped: N x C x H x W features (reference view + warped source views)

        returns: 1 x C x H x W variance cost across the view dimension
        '''
        warped_sq = warped ** 2
        av_warped = warped.mean(0)
        av_warped_sq = warped_sq.mean(0)
        cost = av_warped_sq - (av_warped ** 2)

        return cost.unsqueeze(0)

    def compute_depth(self, prob_volume, depth_start, depth_interval, depth_num):
        '''
        prob_volume: 1 x D x H x W

        returns an H x W depth map and an H x W probability map via a
        per-pixel argmax over the depth dimension
        '''
        _, M, H, W = prob_volume.shape
        # probs, indices = 1 x H x W
        probs, indices = prob_volume.max(1)
        depth_range = depth_start + torch.arange(depth_num).float() * depth_interval
        depth_range = depth_range.to(prob_volume.device)
        depths = torch.index_select(depth_range, 0, indices.flatten())
        depth_image = depths.view(H, W)
        prob_image = probs.view(H, W)

        return depth_image, prob_image

    def forward(self, images, intrinsics, extrinsics, depth_start, depth_interval, depth_num):
        '''
        Sweeps over the depth hypotheses, regularizes the per-plane cost maps
        with the stacked GRUs, and returns an (H/4)x(W/4) depth map and
        probability map.
        '''
        N, C, IH, IW = images.shape
        f = self.feature_extractor(images)

        Hs = get_homographies(f, intrinsics, extrinsics, depth_start, depth_interval, depth_num)

        # GRU hidden states, carried across depth planes
        cost_1 = None
        cost_2 = None
        cost_3 = None
        depth_costs = []

        for d in range(depth_num):
            # warp all source-view features onto plane d of the reference view
            ref_f = f[:1]
            warped = warp_homographies(f[1:], Hs[1:, d])
            all_f = torch.cat((ref_f, warped), 0)

            # cost_d = 1 x C x H x W
            cost_d = self.compute_cost_volume(all_f)
            cost_1 = self.gru1(-cost_d, cost_1)
            cost_2 = self.gru2(cost_1, cost_2)
            cost_3 = self.gru3(cost_2, cost_3)

            reg_cost = self.prob_conv(cost_3)
            depth_costs.append(reg_cost)

        prob_volume = torch.cat(depth_costs, 1)
        softmax_probs = torch.softmax(prob_volume, 1)

        # compute depth map from prob / depth values
        return self.compute_depth(softmax_probs, depth_start, depth_interval, depth_num)
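
# Note (explanatory sketch, not part of the original file): compute_cost_volume
# is the standard variance metric Var(f) = E[f^2] - (E[f])^2 taken across the
# view dimension. For two views with feature values 1.0 and 3.0 at some pixel
# and channel: mean = 2.0, mean of squares = 5.0, so the cost is
# 5.0 - 2.0^2 = 1.0; identical features across views give zero cost, i.e.
# perfect photo-consistency at that depth hypothesis.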
--------------------------------------------------------------------------------
/rmvsnet/unet_ds2gn.py:
--------------------------------------------------------------------------------
import torch
import torch.nn as nn
import torch.nn.functional as NF


def CGR(kernel_size, input_channel, output_channel, strides):
    '''Conv2d + GroupNorm + ReLU block.'''
    pad = (kernel_size - 1) // 2
    G = max(1, output_channel // 8)
    return nn.Sequential(
        nn.Conv2d(input_channel, output_channel, kernel_size, strides, pad, bias=False),
        nn.GroupNorm(G, output_channel),
        nn.ReLU(inplace=True)
    )


def DGR(kernel_size, input_channel, output_channel, strides):
    '''ConvTranspose2d + GroupNorm + ReLU block (upsampling).

    With stride 2 and this padding, the transposed conv outputs (2H-1) x (2W-1)
    maps; the forward pass pads them by one pixel before concatenating with
    the skip connections.
    '''
    pad = (kernel_size - 1) // 2
    G = max(1, output_channel // 8)
    return nn.Sequential(
        nn.ConvTranspose2d(input_channel, output_channel, kernel_size, strides, pad, bias=False),
        nn.GroupNorm(G, output_channel),
        nn.ReLU(inplace=True)
    )


class UNetDS2GN(nn.Module):
    '''2D U-Net feature extractor with group normalization.

    input : N x 3 x H x W images
    output: N x 32 x H/4 x W/4 features
    '''
    def __init__(self):
        super(UNetDS2GN, self).__init__()
        input_channel = 3
        base_channel = 8

        # encoder (downsampling) path
        self.conv1_0 = CGR(3, input_channel, base_channel * 2, 2)
        self.conv2_0 = CGR(3, base_channel * 2, base_channel * 4, 2)
        self.conv3_0 = CGR(3, base_channel * 4, base_channel * 8, 2)
        self.conv4_0 = CGR(3, base_channel * 8, base_channel * 16, 2)

        self.conv0_1 = CGR(3, input_channel, base_channel, 1)
        self.conv0_2 = CGR(3, base_channel, base_channel, 1)

        self.conv1_1 = CGR(3, base_channel * 2, base_channel * 2, 1)
        self.conv1_2 = CGR(3, base_channel * 2, base_channel * 2, 1)
        self.conv2_1 = CGR(3, base_channel * 4, base_channel * 4, 1)
        self.conv2_2 = CGR(3, base_channel * 4, base_channel * 4, 1)
        self.conv3_1 = CGR(3, base_channel * 8, base_channel * 8, 1)
        self.conv3_2 = CGR(3, base_channel * 8, base_channel * 8, 1)
        self.conv4_1 = CGR(3, base_channel * 16, base_channel * 16, 1)
        self.conv4_2 = CGR(3, base_channel * 16, base_channel * 16, 1)

        # decoder (upsampling) path with skip connections
        self.conv5_0 = DGR(3, base_channel * 16, base_channel * 8, 2)

        # conv5_0 + conv3_2
        self.conv5_1 = CGR(3, base_channel * 16, base_channel * 8, 1)
        self.conv5_2 = CGR(3, base_channel * 8, base_channel * 8, 1)
        self.conv6_0 = DGR(3, base_channel * 8, base_channel * 4, 2)

        # conv6_0 + conv2_2
        self.conv6_1 = CGR(3, base_channel * 8, base_channel * 4, 1)
        self.conv6_2 = CGR(3, base_channel * 4, base_channel * 4, 1)
        self.conv7_0 = DGR(3, base_channel * 4, base_channel * 2, 2)

        # conv7_0 + conv1_2
        self.conv7_1 = CGR(3, base_channel * 4, base_channel * 2, 1)
        self.conv7_2 = CGR(3, base_channel * 2, base_channel * 2, 1)
        self.conv8_0 = DGR(3, base_channel * 2, base_channel, 2)

        # conv8_0 + conv0_2
        self.conv8_1 = CGR(3, base_channel * 2, base_channel, 1)
        self.conv8_2 = CGR(3, base_channel * 1, base_channel, 1)

        # final two stride-2 convolutions bring the output to 1/4 resolution
        self.conv9_0 = CGR(5, base_channel, base_channel * 2, 2)
        self.conv9_1 = CGR(3, base_channel * 2, base_channel * 2, 1)
        self.conv9_2 = CGR(3, base_channel * 2, base_channel * 2, 1)
        self.conv10_0 = CGR(5, base_channel * 2, base_channel * 4, 2)
        self.conv10_1 = CGR(3, base_channel * 4, base_channel * 4, 1)
        self.conv10_2 = nn.Conv2d(base_channel * 4, base_channel * 4, 3, 1, 1, bias=False)

        self.output_size = base_channel * 4

    def forward(self, x):
        f0_1 = self.conv0_1(x)
        f0_2 = self.conv0_2(f0_1)

        f1_0 = self.conv1_0(x)
        f2_0 = self.conv2_0(f1_0)
        f3_0 = self.conv3_0(f2_0)
        f4_0 = self.conv4_0(f3_0)

        f1_1 = self.conv1_1(f1_0)
        f1_2 = self.conv1_2(f1_1)

        f2_1 = self.conv2_1(f2_0)
        f2_2 = self.conv2_2(f2_1)

        f3_1 = self.conv3_1(f3_0)
        f3_2 = self.conv3_2(f3_1)

        f4_1 = self.conv4_1(f4_0)
        f4_2 = self.conv4_2(f4_1)

        # upsample, padding each (2H-1) x (2W-1) map to 2H x 2W before concat
        f5_0 = self.conv5_0(f4_2)
        f5_0 = NF.pad(f5_0, (0, 1, 0, 1))

        cat5_0 = torch.cat((f5_0, f3_2), dim=1)
        f5_1 = self.conv5_1(cat5_0)
        f5_2 = self.conv5_2(f5_1)
        f6_0 = self.conv6_0(f5_2)
        f6_0 = NF.pad(f6_0, (0, 1, 0, 1))

        cat6_0 = torch.cat((f6_0, f2_2), dim=1)
        f6_1 = self.conv6_1(cat6_0)
        f6_2 = self.conv6_2(f6_1)
        f7_0 = self.conv7_0(f6_2)
        f7_0 = NF.pad(f7_0, (0, 1, 0, 1))

        cat7_0 = torch.cat((f7_0, f1_2), dim=1)
        f7_1 = self.conv7_1(cat7_0)
        f7_2 = self.conv7_2(f7_1)
        f8_0 = self.conv8_0(f7_2)
        f8_0 = NF.pad(f8_0, (0, 1, 0, 1))

        cat8_0 = torch.cat((f8_0, f0_2), dim=1)
        f8_1 = self.conv8_1(cat8_0)
        f8_2 = self.conv8_2(f8_1)
        f9_0 = self.conv9_0(f8_2)
        f9_1 = self.conv9_1(f9_0)
        f9_2 = self.conv9_2(f9_1)
        f10_0 = self.conv10_0(f9_2)
        f10_1 = self.conv10_1(f10_0)
        f10_2 = self.conv10_2(f10_1)

        return f10_2
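
# Illustrative usage sketch (assumption, not part of the original file): the
# encoder downsamples to 1/16 resolution, the decoder restores full resolution
# via skip connections, and conv9_0/conv10_0 downsample twice more, so the
# output is N x 32 x H/4 x W/4 (H and W must be multiples of 16).
if __name__ == '__main__':
    net = UNetDS2GN()
    x = torch.rand(2, 3, 64, 80)
    f = net(x)
    assert f.shape == (2, net.output_size, 16, 20)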
--------------------------------------------------------------------------------
/rmvsnet/warping.py:
--------------------------------------------------------------------------------
import torch
import torch.nn.functional as NF


def get_pixel_grids(width, height):
    '''Returns homogeneous coordinates of every pixel center.

    Given width and height, creates a mesh grid and returns the homogeneous
    image coordinates as a (width*height) x 3 tensor, in row-major (H, W)
    order; individual coordinates are [x, y, 1].

    Arguments:
        width {Number} -- width of the pixel grid
        height {Number} -- height of the pixel grid

    Returns:
        torch.Tensor -- (width*height) x 3 tensor of [x, y, 1] rows
    '''
    # pixel centers: from 0.5 to w-0.5
    x_coords = torch.linspace(0.5, width - 0.5, width)
    # from 0.5 to h-0.5
    y_coords = torch.linspace(0.5, height - 0.5, height)
    y_grid_coords, x_grid_coords = torch.meshgrid([y_coords, x_coords])
    x_grid_coords = x_grid_coords.contiguous().view(-1)
    y_grid_coords = y_grid_coords.contiguous().view(-1)
    ones = torch.ones(x_grid_coords.shape)
    return torch.stack([
        x_grid_coords,
        y_grid_coords,
        ones
    ], 1)


def get_homographies(features, intrinsics, extrinsics, min_dist, interval, num_planes):
    '''Computes N x M x 3 x 3 plane-sweep homographies for N views and M
    fronto-parallel depth planes.'''
    M = num_planes
    N, C, IH, IW = features.shape

    # depth hypotheses: min_dist + m * interval for m in [0, M)
    depths = torch.arange(num_planes).float() * interval + min_dist
    depths = depths.to(features.device)

    # source K, R, t
    # K = N x 1 x 3 x 3
    # R = N x 1 x 3 x 3
    # t = N x 1 x 3 x 1
    # features are at 1/4 of the input resolution, so scale the intrinsics
    src_Ks = intrinsics / 4
    src_Ks[:, 2, 2] = 1
    src_Rs = extrinsics[:, :3, :3]
    src_ts = extrinsics[:, :3, 3:]
    src_Ks = src_Ks.unsqueeze(1)
    src_Rs = src_Rs.unsqueeze(1)
    src_ts = src_ts.unsqueeze(1)
    src_Rts = src_Rs.transpose(2, 3)
    src_Cs = -src_Rts.matmul(src_ts)  # camera centers
    src_KIs = torch.inverse(src_Ks)

    # reference K, R, t (view 0)
    ref_K = src_Ks[:1]
    ref_R = src_Rs[:1]
    ref_t = src_ts[:1]
    ref_Rt = src_Rts[:1]
    ref_KI = src_KIs[:1]
    ref_C = src_Cs[:1]

    fronto_direction = ref_R[:, :, 2:3, :3]  # N x 1 x 1 x 3
    rel_C = src_Cs - ref_C                   # N x 1 x 3 x 1

    # (N x 1 x 3 x 1) . (N x 1 x 1 x 3) => N x 1 x 3 x 3, broadcast over M planes
    depth_mat = depths.view(1, M, 1, 1)
    trans_mat = torch.eye(3, device=features.device).view(1, 1, 3, 3) - rel_C.matmul(fronto_direction) / depth_mat
    return src_Ks.matmul(src_Rs).matmul(trans_mat).matmul(ref_Rt).matmul(ref_KI)
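
# Note (explanatory sketch, not part of the original file): the return value
# above is the standard fronto-parallel plane-sweep homography for the plane
# at depth d:
#
#   H_i(d) = K_i R_i (I - (C_i - C_ref) n_ref^T / d) R_ref^T K_ref^{-1}
#
# where n_ref is the reference camera's principal axis (the third row of
# R_ref). It maps reference-image pixel coordinates onto source view i for
# points lying on that plane.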
def warp_pixel_grid(homographies, pixel_grid):
    '''
    Warps a homogeneous pixel grid by a single homography.
    Argument:
        - pixel_grid: (W*H) x 3 tensor of homogeneous pixel coordinates
          [(0.5, 0.5, 1), (1.5, 0.5, 1), ...]
        - homographies: 3 x 3 homography
    Returns:
        - 2 x (W*H) tensor of warped non-homogeneous coordinates
    '''
    # (3 x 3) @ (3 x HW) => 3 x HW
    homo_trans_grids = torch.matmul(homographies, pixel_grid.t())

    # homogeneous => non-homogeneous: divide x, y rows by the scale row
    homo_trans_coords = homo_trans_grids[:2]
    homo_trans_scale = homo_trans_grids[2:]
    return homo_trans_coords / homo_trans_scale


def warp_pixel_grids(homographies, pixel_grid):
    '''
    Warps a homogeneous pixel grid by per-view homographies.
    Argument:
        - pixel_grid: (W*H) x 3 tensor of homogeneous pixel coordinates
        - homographies: N x 3 x 3 tensor of homographies, one per view
    Returns:
        - N x 2 x (W*H) tensor of warped non-homogeneous coordinates
    '''
    # (N x 3 x 3) @ (3 x HW) => N x 3 x HW
    homo_trans_grids = torch.matmul(homographies, pixel_grid.t())

    # homogeneous => non-homogeneous
    homo_trans_coords = homo_trans_grids[:, :2]
    homo_trans_scale = homo_trans_grids[:, 2:]
    return homo_trans_coords / homo_trans_scale


def warp_pixel_grids_all(homographies, pixel_grid):
    '''
    Warps a homogeneous pixel grid by per-view, per-plane homographies.
    Argument:
        - pixel_grid: (W*H) x 3 tensor of homogeneous pixel coordinates
        - homographies: N x M x 3 x 3 tensor of homographies for N views
          and M planes
    Returns:
        - N x M x 2 x (W*H) tensor of warped non-homogeneous coordinates
    '''
    # (N x M x 3 x 3) @ (3 x HW) => N x M x 3 x HW
    homo_trans_grids = torch.matmul(homographies, pixel_grid.t())

    # homogeneous => non-homogeneous: the scale is the third row
    homo_trans_coords = homo_trans_grids[:, :, :2]
    homo_trans_scale = homo_trans_grids[:, :, 2:]
    return homo_trans_coords / homo_trans_scale
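
# Note (explanatory sketch, not part of the original file): all three variants
# implement the same projective warp. For a pixel p = (x, y, 1)^T and a
# homography H,
#
#   (x', y', w')^T = H p,   warped coordinate = (x'/w', y'/w')
#
# they differ only in how many leading batch dimensions the homography tensor
# carries.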
def warp_homography(features, homographies):
    '''
    Warps a single feature map by a single homography.

    1. Create a homogeneous pixel grid for the feature map
    2. Warp the grid by the homography
    3. Sample features at the warped coordinates with bilinear interpolation

    features: C x H x W, homographies: 3 x 3
    Returns: C x H x W warped features
    '''
    C, H, W = features.shape

    # pixel_grid = (HW) x 3, in (x, y, 1) format
    pixel_grid = get_pixel_grids(W, H)
    pixel_grid = pixel_grid.to(features.device)

    # (3 x 3) . (3 x HW) => 2 x HW warped coordinates
    warped_pixel_grids = warp_pixel_grid(homographies, pixel_grid)

    # sample / interpolate features at the warped coordinates
    warped_features = warp_feature(features, warped_pixel_grids)
    return warped_features


def warp_homographies(features, homographies):
    '''
    Warps each source feature map by its per-view homography for one
    depth plane.

    features: N x C x H x W, homographies: N x 3 x 3
    Returns: N x C x H x W warped features
    '''
    N, C, H, W = features.shape

    # pixel_grid = (HW) x 3, in (x, y, 1) format
    pixel_grid = get_pixel_grids(W, H)
    pixel_grid = pixel_grid.to(features.device)

    # (N x 3 x 3) . (3 x HW) => N x 2 x HW warped coordinates
    warped_pixel_grids = warp_pixel_grids(homographies, pixel_grid)

    # sample / interpolate features at the warped coordinates
    warped_features = warp_features(features, warped_pixel_grids)
    return warped_features


def warp_homographies_all(features, homographies):
    '''
    Warps each source feature map by its homographies for all planes at once.

    features: N x C x H x W, homographies: N x M x 3 x 3
    Returns: N x C x M x H x W warped features
    '''
    N, C, H, W = features.shape
    N, M, _, _ = homographies.shape

    # pixel_grid = (HW) x 3, in (x, y, 1) format
    pixel_grid = get_pixel_grids(W, H)
    pixel_grid = pixel_grid.to(features.device)

    # (N x M x 3 x 3) . (3 x HW) => N x M x 2 x HW warped coordinates
    warped_pixel_grids = warp_pixel_grids_all(homographies, pixel_grid)

    # sample / interpolate features at the warped coordinates
    warped_features = warp_features_all(features, warped_pixel_grids)
    return warped_features
def warp_features_old(features, warped_pixel_grids):
    '''
    Manual bilinear sampling (superseded by the grid_sample-based versions
    below).
    Argument:
        - features: N x C x H x W
        - warped_pixel_grids: N x M x 2 x HW, x/y coordinates into each of
          the N images for M planes
    Returns:
        - N x C x M x H x W warped features
    '''
    N, C, H, W = features.shape
    N, M, _, HW = warped_pixel_grids.shape
    feats = features.view(N, C, HW)

    # N x 1 x MHW
    x = warped_pixel_grids.narrow(2, 0, 1).contiguous().view(N, 1, -1)
    y = warped_pixel_grids.narrow(2, 1, 1).contiguous().view(N, 1, -1)

    # in-bounds masks, N x 1 x MHW
    xm = ((x >= 0) & (x < W)).float()
    ym = ((y >= 0) & (y < H)).float()

    # corner indices of the bilinear footprint
    x0 = (x - 0.499).long()
    y0 = (y - 0.499).long()
    x1 = x0 + 1
    y1 = y0 + 1

    x0.clamp_(0, W - 1)
    y0.clamp_(0, H - 1)
    x1.clamp_(0, W - 1)
    y1.clamp_(0, H - 1)

    # flat indices of the four corners, N x C x MHW
    coord_a = (y1 * W + x1).repeat(1, C, 1)
    coord_b = (y1 * W + x0).repeat(1, C, 1)
    coord_c = (y0 * W + x1).repeat(1, C, 1)
    coord_d = (y0 * W + x0).repeat(1, C, 1)
    pixel_values_a = feats.gather(2, coord_a)
    pixel_values_b = feats.gather(2, coord_b)
    pixel_values_c = feats.gather(2, coord_c)
    pixel_values_d = feats.gather(2, coord_d)

    # bilinear weights, N x 1 x MHW
    x0 = x0.float()
    y0 = y0.float()
    x1 = x1.float()
    y1 = y1.float()

    dy1 = (y1 - y).clamp(0, 1)
    dx1 = (x1 - x).clamp(0, 1)
    dy0 = (y - y0).clamp(0, 1)
    dx0 = (x - x0).clamp(0, 1)

    area_a = (dx1 * dy1) * xm * ym
    area_b = (dx0 * dy1) * xm * ym
    area_c = (dx1 * dy0) * xm * ym
    area_d = (dx0 * dy0) * xm * ym

    # weighted sum of the four corners, N x C x MHW
    va = area_a * pixel_values_a
    vb = area_b * pixel_values_b
    vc = area_c * pixel_values_c
    vd = area_d * pixel_values_d

    # N x C x M x H x W
    return (va + vb + vc + vd).view(N, C, M, H, W)


def warp_feature(features, warped_pixel_grids):
    '''
    Samples a single feature map at warped pixel coordinates.
    Argument:
        - features: C x H x W
        - warped_pixel_grids: 2 x HW, x/y coordinates in the warped plane
    Returns:
        - C x H x W warped features
    '''
    C, H, W = features.shape
    _, HW = warped_pixel_grids.shape

    # 1 x 1 x HW x 2
    warped_sample_coord = warped_pixel_grids.t().contiguous().view(1, 1, -1, 2)

    # normalize pixel coordinates to [-1, 1] (align_corners=False convention)
    grid_sample_coord = torch.zeros_like(warped_sample_coord)
    grid_sample_coord[:, :, :, 0] = warped_sample_coord[:, :, :, 0] / (W / 2) - 1
    grid_sample_coord[:, :, :, 1] = warped_sample_coord[:, :, :, 1] / (H / 2) - 1
    grid_sample_coord.clamp_(-2, 2)

    sampled = NF.grid_sample(features.unsqueeze(0), grid_sample_coord, align_corners=False)
    # sampled = 1 x C x 1 x HW
    return sampled.view(C, H, W)
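
# Note (explanatory sketch, not part of the original file): with pixel centers
# at x = i + 0.5, the normalization x / (W/2) - 1 maps center i to
# (2i + 1)/W - 1, which is exactly grid_sample's align_corners=False
# convention; coordinates outside the image land outside [-1, 1] and sample
# to zero under the default 'zeros' padding mode.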
def warp_features(features, warped_pixel_grids):
    '''
    Samples each feature map at its warped pixel coordinates.
    Argument:
        - features: N x C x H x W
        - warped_pixel_grids: N x 2 x HW, x/y coordinates per view
    Returns:
        - N x C x H x W warped features
    '''
    N, C, H, W = features.shape
    N, _, HW = warped_pixel_grids.shape

    # N x 1 x HW x 2
    warped_sample_coord = warped_pixel_grids.permute(0, 2, 1).unsqueeze(1)

    # normalize pixel coordinates to [-1, 1] (align_corners=False convention)
    grid_sample_coord = torch.zeros_like(warped_sample_coord)
    grid_sample_coord[:, :, :, 0] = warped_sample_coord[:, :, :, 0] / (W / 2) - 1
    grid_sample_coord[:, :, :, 1] = warped_sample_coord[:, :, :, 1] / (H / 2) - 1
    grid_sample_coord.clamp_(-2, 2)

    sampled = NF.grid_sample(features, grid_sample_coord, align_corners=False)
    # sampled = N x C x 1 x HW
    return sampled.view(N, C, H, W)


def warp_features_all(features, warped_pixel_grids):
    '''
    Samples each feature map at its warped coordinates for all planes.
    Argument:
        - features: N x C x H x W
        - warped_pixel_grids: N x M x 2 x HW, x/y coordinates per view
          and plane
    Returns:
        - N x C x M x H x W warped features
    '''
    N, C, H, W = features.shape
    N, M, _, HW = warped_pixel_grids.shape

    # N x M x HW x 2
    warped_sample_coord = warped_pixel_grids.permute(0, 1, 3, 2)

    # normalize pixel coordinates to [-1, 1] (align_corners=False convention)
    grid_sample_coord = torch.zeros_like(warped_sample_coord)
    grid_sample_coord[:, :, :, 0] = warped_sample_coord[:, :, :, 0] / (W / 2) - 1
    grid_sample_coord[:, :, :, 1] = warped_sample_coord[:, :, :, 1] / (H / 2) - 1
    grid_sample_coord.clamp_(-2, 2)

    # grid is N x M x HW x 2, so sampled is N x C x M x HW
    sampled = NF.grid_sample(features, grid_sample_coord, align_corners=False)
    return sampled.view(N, C, M, H, W)
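
# Illustrative sanity check (assumption, not part of the original file):
# identity homographies should reproduce the input features up to
# interpolation error, since the warped grid then coincides with the
# original pixel centers.
if __name__ == '__main__':
    feats = torch.rand(2, 32, 16, 20)
    eye = torch.eye(3).expand(2, 3, 3)
    out = warp_homographies(feats, eye)
    assert torch.allclose(out, feats, atol=1e-5)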
"License :: OSI Approved :: MIT License", 20 | "Operating System :: OS Independent", 21 | ], 22 | ) 23 | --------------------------------------------------------------------------------