├── .gitignore
├── README.md
├── datasets
│   ├── __init__.py
│   ├── data_io.py
│   ├── dtu_yao.py
│   └── dtu_yao_eval.py
├── eval.py
├── eval.sh
├── evaluations
│   └── dtu
│       ├── BaseEval2Obj_web.m
│       ├── BaseEvalMain_web.m
│       ├── ComputeStat_web.m
│       ├── MaxDistCP.m
│       ├── PointCompareMain.m
│       ├── plyread.m
│       └── reducePts_haa.m
├── lists
│   └── dtu
│       ├── test.txt
│       ├── train.txt
│       └── val.txt
├── models
│   ├── __init__.py
│   ├── module.py
│   └── mvsnet.py
├── train.py
├── train.sh
└── utils.py
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | outputs
3 | *.bin
4 | checkpoints
5 | __pycache__
6 | 
7 | 
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # An Unofficial PyTorch Implementation of MVSNet
2 | 
3 | [MVSNet: Depth Inference for Unstructured Multi-view Stereo](https://arxiv.org/abs/1804.02505). Yao Yao, Zixin Luo, Shiwei Li, Tian Fang, Long Quan. ECCV 2018. MVSNet is a deep learning architecture for depth map inference from unstructured multi-view images.
4 | 
5 | This is an unofficial PyTorch implementation of MVSNet.
6 | 
7 | ## How to Use
8 | 
9 | ### Environment
10 | * python 3.6 (Anaconda)
11 | * pytorch 1.0.1
12 | 
13 | ### Training
14 | 
15 | * Download the preprocessed [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view) (fixed training cameras, from the [original MVSNet](https://github.com/YoYo000/MVSNet)) and unzip it as the ``MVS_TRAINING`` folder
16 | * In ``train.sh``, set ``MVS_TRAINING`` to your training data path
17 | * Create a log directory named ``checkpoints``
18 | * Train MVSNet: ``./train.sh``
19 | 
20 | ### Testing
21 | 
22 | * Download the preprocessed [DTU testing data](https://drive.google.com/open?id=135oKPefcPTsdtLRzoDAQtPpHuoIrpRI_) (from the [original MVSNet](https://github.com/YoYo000/MVSNet)) and unzip it as the ``DTU_TESTING`` folder, which should contain one ``cams`` folder, one ``images`` folder and one ``pair.txt`` file.
23 | * In ``eval.sh``, set ``DTU_TESTING`` to your testing data path and ``CKPT_FILE`` to your checkpoint file. You can also download my [pretrained model](https://drive.google.com/file/d/1j2I_LNKb9JeCl6wdA7hh8z1WgVQZfLU9/view?usp=sharing).
24 | * Test MVSNet: ``./eval.sh``
25 | 
26 | ### Fusion
27 | 
28 | In ``eval.py``, I implemented a simple version of depth map fusion. Contributions that improve the code are welcome.
29 | 
30 | 
31 | ## Results on DTU
32 | 
33 | |                       | Acc.   | Comp.  | Overall |
34 | |-----------------------|--------|--------|---------|
35 | | MVSNet (D=256)        | 0.396  | 0.527  | 0.462   |
36 | | PyTorch-MVSNet (D=192)| 0.4492 | 0.3796 | 0.4144  |
37 | 
38 | Due to the memory limit, the model is trained with ``D=192`` only; the fusion code also differs from the original repo.
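For reference, the fusion step in ``eval.py`` keeps a pixel of the reference depth map only if its photometric confidence exceeds 0.8 and at least 3 source views agree with it geometrically, i.e. the pixel reprojects back to within 1 pixel and within 1% relative depth of its original estimate. Below is a minimal NumPy sketch of this criterion; the function and variable names are illustrative only, not part of the repo:

```python
import numpy as np

# thresholds matching check_geometric_consistency / filter_depth in eval.py
PHOTO_THRESH = 0.8    # minimum photometric confidence
PIXEL_THRESH = 1.0    # maximum reprojection error (pixels)
DEPTH_THRESH = 0.01   # maximum relative depth difference
MIN_VIEWS = 3         # source views that must agree geometrically

def geo_consistent(depth_ref, depth_reproj, x_reproj, y_reproj, x_ref, y_ref):
    # consistent with one source view: the pixel reprojects back within
    # PIXEL_THRESH pixels and DEPTH_THRESH relative depth of its estimate
    dist = np.sqrt((x_reproj - x_ref) ** 2 + (y_reproj - y_ref) ** 2)
    rel_diff = np.abs(depth_reproj - depth_ref) / depth_ref
    return np.logical_and(dist < PIXEL_THRESH, rel_diff < DEPTH_THRESH)

def fusion_mask(confidence, geo_masks):
    # keep pixels that are photometrically confident and geometrically
    # consistent in at least MIN_VIEWS source views
    geo_count = np.sum(np.stack(geo_masks).astype(np.int32), axis=0)
    return np.logical_and(confidence > PHOTO_THRESH, geo_count >= MIN_VIEWS)
```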
39 | -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import importlib 2 | 3 | 4 | # find the dataset definition by name, for example dtu_yao (dtu_yao.py) 5 | def find_dataset_def(dataset_name): 6 | module_name = 'datasets.{}'.format(dataset_name) 7 | module = importlib.import_module(module_name) 8 | return getattr(module, "MVSDataset") 9 | -------------------------------------------------------------------------------- /datasets/data_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import sys 4 | 5 | 6 | def read_pfm(filename): 7 | file = open(filename, 'rb') 8 | color = None 9 | width = None 10 | height = None 11 | scale = None 12 | endian = None 13 | 14 | header = file.readline().decode('utf-8').rstrip() 15 | if header == 'PF': 16 | color = True 17 | elif header == 'Pf': 18 | color = False 19 | else: 20 | raise Exception('Not a PFM file.') 21 | 22 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 23 | if dim_match: 24 | width, height = map(int, dim_match.groups()) 25 | else: 26 | raise Exception('Malformed PFM header.') 27 | 28 | scale = float(file.readline().rstrip()) 29 | if scale < 0: # little-endian 30 | endian = '<' 31 | scale = -scale 32 | else: 33 | endian = '>' # big-endian 34 | 35 | data = np.fromfile(file, endian + 'f') 36 | shape = (height, width, 3) if color else (height, width) 37 | 38 | data = np.reshape(data, shape) 39 | data = np.flipud(data) 40 | file.close() 41 | return data, scale 42 | 43 | 44 | def save_pfm(filename, image, scale=1): 45 | file = open(filename, "wb") 46 | color = None 47 | 48 | image = np.flipud(image) 49 | 50 | if image.dtype.name != 'float32': 51 | raise Exception('Image dtype must be float32.') 52 | 53 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 54 | color = True 55 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 56 | color = False 57 | else: 58 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 59 | 60 | file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8')) 61 | file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8')) 62 | 63 | endian = image.dtype.byteorder 64 | 65 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 66 | scale = -scale 67 | 68 | file.write(('%f\n' % scale).encode('utf-8')) 69 | 70 | image.tofile(file) 71 | file.close() 72 | -------------------------------------------------------------------------------- /datasets/dtu_yao.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | from datasets.data_io import * 6 | 7 | 8 | # the DTU dataset preprocessed by Yao Yao (only for training) 9 | class MVSDataset(Dataset): 10 | def __init__(self, datapath, listfile, mode, nviews, ndepths=192, interval_scale=1.06, **kwargs): 11 | super(MVSDataset, self).__init__() 12 | self.datapath = datapath 13 | self.listfile = listfile 14 | self.mode = mode 15 | self.nviews = nviews 16 | self.ndepths = ndepths 17 | self.interval_scale = interval_scale 18 | 19 | assert self.mode in ["train", "val", "test"] 20 | self.metas = self.build_list() 21 | 22 | def build_list(self): 23 | metas = [] 24 | with open(self.listfile) as f: 25 | scans = 
f.readlines()
26 |             scans = [line.rstrip() for line in scans]
27 | 
28 |         # scans
29 |         for scan in scans:
30 |             pair_file = "Cameras/pair.txt"
31 |             # read the pair file
32 |             with open(os.path.join(self.datapath, pair_file)) as f:
33 |                 num_viewpoint = int(f.readline())
34 |                 # viewpoints (49)
35 |                 for view_idx in range(num_viewpoint):
36 |                     ref_view = int(f.readline().rstrip())
37 |                     src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
38 |                     # light conditions 0-6
39 |                     for light_idx in range(7):
40 |                         metas.append((scan, light_idx, ref_view, src_views))
41 |         print("dataset", self.mode, "metas:", len(metas))
42 |         return metas
43 | 
44 |     def __len__(self):
45 |         return len(self.metas)
46 | 
47 |     def read_cam_file(self, filename):
48 |         with open(filename) as f:
49 |             lines = f.readlines()
50 |             lines = [line.rstrip() for line in lines]
51 |         # extrinsics: line [1,5), 4x4 matrix
52 |         extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
53 |         # intrinsics: line [7-10), 3x3 matrix
54 |         intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
55 |         # depth_min & depth_interval: line 11
56 |         depth_min = float(lines[11].split()[0])
57 |         depth_interval = float(lines[11].split()[1]) * self.interval_scale
58 |         return intrinsics, extrinsics, depth_min, depth_interval
59 | 
60 |     def read_img(self, filename):
61 |         img = Image.open(filename)
62 |         # scale 0~255 to 0~1
63 |         np_img = np.array(img, dtype=np.float32) / 255.
64 |         return np_img
65 | 
66 |     def read_depth(self, filename):
67 |         # read pfm depth file
68 |         return np.array(read_pfm(filename)[0], dtype=np.float32)
69 | 
70 |     def __getitem__(self, idx):
71 |         meta = self.metas[idx]
72 |         scan, light_idx, ref_view, src_views = meta
73 |         # use only the reference view and first nviews-1 source views
74 |         view_ids = [ref_view] + src_views[:self.nviews - 1]
75 | 
76 |         imgs = []
77 |         mask = None
78 |         depth = None
79 |         depth_values = None
80 |         proj_matrices = []
81 | 
82 |         for i, vid in enumerate(view_ids):
83 |             # NOTE that the id in image file names is from 1 to 49 (not 0~48)
84 |             img_filename = os.path.join(self.datapath,
85 |                                         'Rectified/{}_train/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx))
86 |             mask_filename = os.path.join(self.datapath, 'Depths/{}_train/depth_visual_{:0>4}.png'.format(scan, vid))
87 |             depth_filename = os.path.join(self.datapath, 'Depths/{}_train/depth_map_{:0>4}.pfm'.format(scan, vid))
88 |             proj_mat_filename = os.path.join(self.datapath, 'Cameras/train/{:0>8}_cam.txt'.format(vid))
89 | 
90 |             imgs.append(self.read_img(img_filename))
91 |             intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename)
92 | 
93 |             # multiply intrinsics and extrinsics to get projection matrix
94 |             proj_mat = extrinsics.copy()
95 |             proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
96 |             proj_matrices.append(proj_mat)
97 | 
98 |             if i == 0:  # reference view
99 |                 depth_values = np.arange(depth_min, depth_interval * (self.ndepths - 0.5) + depth_min, depth_interval,
100 |                                          dtype=np.float32)  # the -0.5 keeps exactly ndepths samples despite float rounding, matching dtu_yao_eval.py
101 |                 mask = self.read_img(mask_filename)
102 |                 depth = self.read_depth(depth_filename)
103 | 
104 |         imgs = np.stack(imgs).transpose([0, 3, 1, 2])
105 |         proj_matrices = np.stack(proj_matrices)
106 | 
107 |         return {"imgs": imgs,
108 |                 "proj_matrices": proj_matrices,
109 |                 "depth": depth,
110 |                 "depth_values": depth_values,
111 |                 "mask": mask}
112 | 
113 | 
114 | if __name__ == "__main__":
115 |     # some testing code, just IGNORE it
116 |     dataset = MVSDataset("/home/xyguo/dataset/dtu_mvs/processed/mvs_training/dtu/", '../lists/dtu/train.txt', 'train',
MVSDataset("/home/xyguo/dataset/dtu_mvs/processed/mvs_training/dtu/", '../lists/dtu/train.txt', 'train', 117 | 3, 128) 118 | item = dataset[50] 119 | 120 | dataset = MVSDataset("/home/xyguo/dataset/dtu_mvs/processed/mvs_training/dtu/", '../lists/dtu/val.txt', 'val', 3, 121 | 128) 122 | item = dataset[50] 123 | 124 | dataset = MVSDataset("/home/xyguo/dataset/dtu_mvs/processed/mvs_training/dtu/", '../lists/dtu/test.txt', 'test', 5, 125 | 128) 126 | item = dataset[50] 127 | 128 | # test homography here 129 | print(item.keys()) 130 | print("imgs", item["imgs"].shape) 131 | print("depth", item["depth"].shape) 132 | print("depth_values", item["depth_values"].shape) 133 | print("mask", item["mask"].shape) 134 | 135 | ref_img = item["imgs"][0].transpose([1, 2, 0])[::4, ::4] 136 | src_imgs = [item["imgs"][i].transpose([1, 2, 0])[::4, ::4] for i in range(1, 5)] 137 | ref_proj_mat = item["proj_matrices"][0] 138 | src_proj_mats = [item["proj_matrices"][i] for i in range(1, 5)] 139 | mask = item["mask"] 140 | depth = item["depth"] 141 | 142 | height = ref_img.shape[0] 143 | width = ref_img.shape[1] 144 | xx, yy = np.meshgrid(np.arange(0, width), np.arange(0, height)) 145 | print("yy", yy.max(), yy.min()) 146 | yy = yy.reshape([-1]) 147 | xx = xx.reshape([-1]) 148 | X = np.vstack((xx, yy, np.ones_like(xx))) 149 | D = depth.reshape([-1]) 150 | print("X", "D", X.shape, D.shape) 151 | 152 | X = np.vstack((X * D, np.ones_like(xx))) 153 | X = np.matmul(np.linalg.inv(ref_proj_mat), X) 154 | X = np.matmul(src_proj_mats[0], X) 155 | X /= X[2] 156 | X = X[:2] 157 | 158 | yy = X[0].reshape([height, width]).astype(np.float32) 159 | xx = X[1].reshape([height, width]).astype(np.float32) 160 | import cv2 161 | 162 | warped = cv2.remap(src_imgs[0], yy, xx, interpolation=cv2.INTER_LINEAR) 163 | warped[mask[:, :] < 0.5] = 0 164 | 165 | cv2.imwrite('../tmp0.png', ref_img[:, :, ::-1] * 255) 166 | cv2.imwrite('../tmp1.png', warped[:, :, ::-1] * 255) 167 | cv2.imwrite('../tmp2.png', src_imgs[0][:, :, ::-1] * 255) 168 | -------------------------------------------------------------------------------- /datasets/dtu_yao_eval.py: -------------------------------------------------------------------------------- 1 | from torch.utils.data import Dataset 2 | import numpy as np 3 | import os 4 | from PIL import Image 5 | from datasets.data_io import * 6 | 7 | 8 | # the DTU dataset preprocessed by Yao Yao (only for training) 9 | class MVSDataset(Dataset): 10 | def __init__(self, datapath, listfile, mode, nviews, ndepths=192, interval_scale=1.06, **kwargs): 11 | super(MVSDataset, self).__init__() 12 | self.datapath = datapath 13 | self.listfile = listfile 14 | self.mode = mode 15 | self.nviews = nviews 16 | self.ndepths = ndepths 17 | self.interval_scale = interval_scale 18 | 19 | assert self.mode == "test" 20 | self.metas = self.build_list() 21 | 22 | def build_list(self): 23 | metas = [] 24 | with open(self.listfile) as f: 25 | scans = f.readlines() 26 | scans = [line.rstrip() for line in scans] 27 | 28 | # scans 29 | for scan in scans: 30 | pair_file = "{}/pair.txt".format(scan) 31 | # read the pair file 32 | with open(os.path.join(self.datapath, pair_file)) as f: 33 | num_viewpoint = int(f.readline()) 34 | # viewpoints (49) 35 | for view_idx in range(num_viewpoint): 36 | ref_view = int(f.readline().rstrip()) 37 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 38 | metas.append((scan, ref_view, src_views)) 39 | print("dataset", self.mode, "metas:", len(metas)) 40 | return metas 41 | 42 | def __len__(self): 43 | 
44 | 
45 |     def read_cam_file(self, filename):
46 |         with open(filename) as f:
47 |             lines = f.readlines()
48 |             lines = [line.rstrip() for line in lines]
49 |         # extrinsics: line [1,5), 4x4 matrix
50 |         extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
51 |         # intrinsics: line [7-10), 3x3 matrix
52 |         intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
53 |         intrinsics[:2, :] /= 4  # the network predicts depth at 1/4 of the input resolution, so scale the intrinsics to match
54 |         # depth_min & depth_interval: line 11
55 |         depth_min = float(lines[11].split()[0])
56 |         depth_interval = float(lines[11].split()[1]) * self.interval_scale
57 |         return intrinsics, extrinsics, depth_min, depth_interval
58 | 
59 |     def read_img(self, filename):
60 |         img = Image.open(filename)
61 |         # scale 0~255 to 0~1
62 |         np_img = np.array(img, dtype=np.float32) / 255.
63 |         assert np_img.shape[:2] == (1200, 1600)
64 |         # crop to (1184, 1600)
65 |         np_img = np_img[:-16, :]  # do not need to modify intrinsics if cropping the bottom part
66 |         return np_img
67 | 
68 |     def read_depth(self, filename):
69 |         # read pfm depth file
70 |         return np.array(read_pfm(filename)[0], dtype=np.float32)
71 | 
72 |     def __getitem__(self, idx):
73 |         meta = self.metas[idx]
74 |         scan, ref_view, src_views = meta
75 |         # use only the reference view and first nviews-1 source views
76 |         view_ids = [ref_view] + src_views[:self.nviews - 1]
77 | 
78 |         imgs = []
79 |         mask = None
80 |         depth = None
81 |         depth_values = None
82 |         proj_matrices = []
83 | 
84 |         for i, vid in enumerate(view_ids):
85 |             img_filename = os.path.join(self.datapath, '{}/images/{:0>8}.jpg'.format(scan, vid))
86 |             proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid))
87 | 
88 |             imgs.append(self.read_img(img_filename))
89 |             intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename)
90 | 
91 |             # multiply intrinsics and extrinsics to get projection matrix
92 |             proj_mat = extrinsics.copy()
93 |             proj_mat[:3, :4] = np.matmul(intrinsics, proj_mat[:3, :4])
94 |             proj_matrices.append(proj_mat)
95 | 
96 |             if i == 0:  # reference view
97 |                 depth_values = np.arange(depth_min, depth_interval * (self.ndepths - 0.5) + depth_min, depth_interval,
98 |                                          dtype=np.float32)  # the -0.5 keeps exactly ndepths samples despite float rounding
99 | 
100 |         imgs = np.stack(imgs).transpose([0, 3, 1, 2])
101 |         proj_matrices = np.stack(proj_matrices)
102 | 
103 |         return {"imgs": imgs,
104 |                 "proj_matrices": proj_matrices,
105 |                 "depth_values": depth_values,
106 |                 "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"}
107 | 
108 | 
109 | if __name__ == "__main__":
110 |     # some testing code, just IGNORE it
111 |     dataset = MVSDataset("/home/xyguo/dataset/dtu_mvs/processed/mvs_testing/dtu/", '../lists/dtu/test.txt', 'test', 5,
112 |                          128)
113 |     item = dataset[50]
114 |     for key, value in item.items():
115 |         print(key, type(value))
--------------------------------------------------------------------------------
/eval.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.parallel
6 | import torch.backends.cudnn as cudnn
7 | import torch.optim as optim
8 | from torch.utils.data import DataLoader
9 | from torch.autograd import Variable
10 | import torch.nn.functional as F
11 | import numpy as np
12 | import time
13 | from datasets import find_dataset_def
14 | from models import *
15 | from utils import *
16 | import sys
17 | from datasets.data_io import read_pfm, save_pfm
18 | import cv2
19 | from plyfile import PlyData, PlyElement  # external dependency for PLY point cloud I/O (pip install plyfile)
20 | from PIL import Image
21 | 
22 | cudnn.benchmark = True
23 | 
24 | parser = argparse.ArgumentParser(description='Predict depth, filter, and fuse. May be different from the original implementation')
25 | parser.add_argument('--model', default='mvsnet', help='select model')
26 | 
27 | parser.add_argument('--dataset', default='dtu_yao_eval', help='select dataset')
28 | parser.add_argument('--testpath', help='testing data path')
29 | parser.add_argument('--testlist', help='testing scan list')
30 | 
31 | parser.add_argument('--batch_size', type=int, default=1, help='testing batch size')
32 | parser.add_argument('--numdepth', type=int, default=192, help='the number of depth values')
33 | parser.add_argument('--interval_scale', type=float, default=1.06, help='the depth interval scale')
34 | 
35 | parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint')
36 | parser.add_argument('--outdir', default='./outputs', help='output dir')
37 | parser.add_argument('--display', action='store_true', help='display depth images and masks')
38 | 
39 | # parse arguments and check
40 | args = parser.parse_args()
41 | print("argv:", sys.argv[1:])
42 | print_args(args)
43 | 
44 | 
45 | # read intrinsics and extrinsics
46 | def read_camera_parameters(filename):
47 |     with open(filename) as f:
48 |         lines = f.readlines()
49 |         lines = [line.rstrip() for line in lines]
50 |     # extrinsics: line [1,5), 4x4 matrix
51 |     extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4))
52 |     # intrinsics: line [7-10), 3x3 matrix
53 |     intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3))
54 |     # TODO: assume the feature is 1/4 of the original image size
55 |     intrinsics[:2, :] /= 4
56 |     return intrinsics, extrinsics
57 | 
58 | 
59 | # read an image
60 | def read_img(filename):
61 |     img = Image.open(filename)
62 |     # scale 0~255 to 0~1
63 |     np_img = np.array(img, dtype=np.float32) / 255.
64 |     return np_img
65 | 
66 | 
67 | # read a binary mask
68 | def read_mask(filename):
69 |     return read_img(filename) > 0.5
70 | 
71 | 
72 | # save a binary mask
73 | def save_mask(filename, mask):
74 |     assert mask.dtype == np.bool_  # use np.bool_: the bare np.bool alias was deprecated and later removed from NumPy
75 |     mask = mask.astype(np.uint8) * 255
76 |     Image.fromarray(mask).save(filename)
77 | 
78 | 
79 | # read a pair file, [(ref_view1, [src_view1-1, ...]), (ref_view2, [src_view2-1, ...]), ...]
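# pair.txt format: the first line gives the number of viewpoints; then, for each
# viewpoint, a line with the reference view id is followed by a line of the form
# "num_src src_id_1 score_1 src_id_2 score_2 ...", from which only the source ids
# (every other token, i.e. [1::2]) are kept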
80 | def read_pair_file(filename): 81 | data = [] 82 | with open(filename) as f: 83 | num_viewpoint = int(f.readline()) 84 | # 49 viewpoints 85 | for view_idx in range(num_viewpoint): 86 | ref_view = int(f.readline().rstrip()) 87 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 88 | data.append((ref_view, src_views)) 89 | return data 90 | 91 | 92 | # run MVS model to save depth maps and confidence maps 93 | def save_depth(): 94 | # dataset, dataloader 95 | MVSDataset = find_dataset_def(args.dataset) 96 | test_dataset = MVSDataset(args.testpath, args.testlist, "test", 5, args.numdepth, args.interval_scale) 97 | TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False) 98 | 99 | # model 100 | model = MVSNet(refine=False) 101 | model = nn.DataParallel(model) 102 | model.cuda() 103 | 104 | # load checkpoint file specified by args.loadckpt 105 | print("loading model {}".format(args.loadckpt)) 106 | state_dict = torch.load(args.loadckpt) 107 | model.load_state_dict(state_dict['model']) 108 | model.eval() 109 | 110 | with torch.no_grad(): 111 | for batch_idx, sample in enumerate(TestImgLoader): 112 | sample_cuda = tocuda(sample) 113 | outputs = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"]) 114 | outputs = tensor2numpy(outputs) 115 | del sample_cuda 116 | print('Iter {}/{}'.format(batch_idx, len(TestImgLoader))) 117 | filenames = sample["filename"] 118 | 119 | # save depth maps and confidence maps 120 | for filename, depth_est, photometric_confidence in zip(filenames, outputs["depth"], 121 | outputs["photometric_confidence"]): 122 | depth_filename = os.path.join(args.outdir, filename.format('depth_est', '.pfm')) 123 | confidence_filename = os.path.join(args.outdir, filename.format('confidence', '.pfm')) 124 | os.makedirs(depth_filename.rsplit('/', 1)[0], exist_ok=True) 125 | os.makedirs(confidence_filename.rsplit('/', 1)[0], exist_ok=True) 126 | # save depth maps 127 | save_pfm(depth_filename, depth_est) 128 | # save confidence maps 129 | save_pfm(confidence_filename, photometric_confidence) 130 | 131 | 132 | # project the reference point cloud into the source view, then project back 133 | def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): 134 | width, height = depth_ref.shape[1], depth_ref.shape[0] 135 | ## step1. project reference pixels to the source view 136 | # reference view x, y 137 | x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) 138 | x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1]) 139 | # reference 3D space 140 | xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref), 141 | np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1])) 142 | # source 3D space 143 | xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)), 144 | np.vstack((xyz_ref, np.ones_like(x_ref))))[:3] 145 | # source view x, y 146 | K_xyz_src = np.matmul(intrinsics_src, xyz_src) 147 | xy_src = K_xyz_src[:2] / K_xyz_src[2:3] 148 | 149 | ## step2. 
reproject the source view points with source view depth estimation
150 |     # find the depth estimation of the source view
151 |     x_src = xy_src[0].reshape([height, width]).astype(np.float32)
152 |     y_src = xy_src[1].reshape([height, width]).astype(np.float32)
153 |     sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR)
154 |     # mask = sampled_depth_src > 0
155 | 
156 |     # source 3D space
157 |     # NOTE that we should use the sampled source-view depth here to project back
158 |     xyz_src = np.matmul(np.linalg.inv(intrinsics_src),
159 |                         np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1]))
160 |     # reference 3D space
161 |     xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)),
162 |                                 np.vstack((xyz_src, np.ones_like(x_ref))))[:3]
163 |     # reference view x, y, depth of the reprojected points
164 |     depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32)
165 |     K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected)
166 |     xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3]
167 |     x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32)
168 |     y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32)
169 | 
170 |     return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src
171 | 
172 | 
173 | def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src):
174 |     width, height = depth_ref.shape[1], depth_ref.shape[0]
175 |     x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height))
176 |     depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref,
177 |                                                                                                  depth_src, intrinsics_src, extrinsics_src)
178 |     # check |p_reproj-p_1| < 1
179 |     dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2)
180 | 
181 |     # check |d_reproj-d_1| / d_1 < 0.01
182 |     depth_diff = np.abs(depth_reprojected - depth_ref)
183 |     relative_depth_diff = depth_diff / depth_ref
184 | 
185 |     mask = np.logical_and(dist < 1, relative_depth_diff < 0.01)
186 |     depth_reprojected[~mask] = 0
187 | 
188 |     return mask, depth_reprojected, x2d_src, y2d_src
189 | 
190 | 
191 | def filter_depth(scan_folder, out_folder, plyfilename):
192 |     # the pair file
193 |     pair_file = os.path.join(scan_folder, "pair.txt")
194 |     # for the final point cloud
195 |     vertexs = []
196 |     vertex_colors = []
197 | 
198 |     pair_data = read_pair_file(pair_file)
199 |     nviews = len(pair_data)
200 |     # TODO: hardcoded size
201 |     # used_mask = [np.zeros([296, 400], dtype=np.bool) for _ in range(nviews)]
202 | 
203 |     # for each reference view and the corresponding source views
204 |     for ref_view, src_views in pair_data:
205 |         # load the camera parameters
206 |         ref_intrinsics, ref_extrinsics = read_camera_parameters(
207 |             os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(ref_view)))
208 |         # load the reference image
209 |         ref_img = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(ref_view)))
210 |         # load the estimated depth of the reference view
211 |         ref_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0]
212 |         # load the photometric mask of the reference view
213 |         confidence = read_pfm(os.path.join(out_folder, 'confidence/{:0>8}.pfm'.format(ref_view)))[0]
214 |         photo_mask = confidence > 0.8
215 | 
216 |         all_srcview_depth_ests = []
217 |         all_srcview_x = []
218 |         all_srcview_y = []
219 |         all_srcview_geomask = []
220 | 
221 |         # compute the
geometric mask 222 | geo_mask_sum = 0 223 | for src_view in src_views: 224 | # camera parameters of the source view 225 | src_intrinsics, src_extrinsics = read_camera_parameters( 226 | os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(src_view))) 227 | # the estimated depth of the source view 228 | src_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))[0] 229 | 230 | geo_mask, depth_reprojected, x2d_src, y2d_src = check_geometric_consistency(ref_depth_est, ref_intrinsics, ref_extrinsics, 231 | src_depth_est, 232 | src_intrinsics, src_extrinsics) 233 | geo_mask_sum += geo_mask.astype(np.int32) 234 | all_srcview_depth_ests.append(depth_reprojected) 235 | all_srcview_x.append(x2d_src) 236 | all_srcview_y.append(y2d_src) 237 | all_srcview_geomask.append(geo_mask) 238 | 239 | depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1) 240 | # at least 3 source views matched 241 | geo_mask = geo_mask_sum >= 3 242 | final_mask = np.logical_and(photo_mask, geo_mask) 243 | 244 | os.makedirs(os.path.join(out_folder, "mask"), exist_ok=True) 245 | save_mask(os.path.join(out_folder, "mask/{:0>8}_photo.png".format(ref_view)), photo_mask) 246 | save_mask(os.path.join(out_folder, "mask/{:0>8}_geo.png".format(ref_view)), geo_mask) 247 | save_mask(os.path.join(out_folder, "mask/{:0>8}_final.png".format(ref_view)), final_mask) 248 | 249 | print("processing {}, ref-view{:0>2}, photo/geo/final-mask:{}/{}/{}".format(scan_folder, ref_view, 250 | photo_mask.mean(), 251 | geo_mask.mean(), final_mask.mean())) 252 | 253 | if args.display: 254 | import cv2 255 | cv2.imshow('ref_img', ref_img[:, :, ::-1]) 256 | cv2.imshow('ref_depth', ref_depth_est / 800) 257 | cv2.imshow('ref_depth * photo_mask', ref_depth_est * photo_mask.astype(np.float32) / 800) 258 | cv2.imshow('ref_depth * geo_mask', ref_depth_est * geo_mask.astype(np.float32) / 800) 259 | cv2.imshow('ref_depth * mask', ref_depth_est * final_mask.astype(np.float32) / 800) 260 | cv2.waitKey(0) 261 | 262 | height, width = depth_est_averaged.shape[:2] 263 | x, y = np.meshgrid(np.arange(0, width), np.arange(0, height)) 264 | # valid_points = np.logical_and(final_mask, ~used_mask[ref_view]) 265 | valid_points = final_mask 266 | print("valid_points", valid_points.mean()) 267 | x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[valid_points] 268 | color = ref_img[1:-16:4, 1::4, :][valid_points] # hardcoded for DTU dataset 269 | xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics), 270 | np.vstack((x, y, np.ones_like(x))) * depth) 271 | xyz_world = np.matmul(np.linalg.inv(ref_extrinsics), 272 | np.vstack((xyz_ref, np.ones_like(x))))[:3] 273 | vertexs.append(xyz_world.transpose((1, 0))) 274 | vertex_colors.append((color * 255).astype(np.uint8)) 275 | 276 | # # set used_mask[ref_view] 277 | # used_mask[ref_view][...] 
= True
278 |         # for idx, src_view in enumerate(src_views):
279 |         #     src_mask = np.logical_and(final_mask, all_srcview_geomask[idx])
280 |         #     src_y = all_srcview_y[idx].astype(np.int)
281 |         #     src_x = all_srcview_x[idx].astype(np.int)
282 |         #     used_mask[src_view][src_y[src_mask], src_x[src_mask]] = True
283 | 
284 |     vertexs = np.concatenate(vertexs, axis=0)
285 |     vertex_colors = np.concatenate(vertex_colors, axis=0)
286 |     vertexs = np.array([tuple(v) for v in vertexs], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
287 |     vertex_colors = np.array([tuple(v) for v in vertex_colors], dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
288 | 
289 |     vertex_all = np.empty(len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr)
290 |     for prop in vertexs.dtype.names:
291 |         vertex_all[prop] = vertexs[prop]
292 |     for prop in vertex_colors.dtype.names:
293 |         vertex_all[prop] = vertex_colors[prop]
294 | 
295 |     el = PlyElement.describe(vertex_all, 'vertex')
296 |     PlyData([el]).write(plyfilename)
297 |     print("saving the final model to", plyfilename)
298 | 
299 | 
300 | if __name__ == '__main__':
301 |     # step1. save all the depth maps and the masks in outputs directory
302 |     save_depth()
303 | 
304 |     with open(args.testlist) as f:
305 |         scans = f.readlines()
306 |         scans = [line.rstrip() for line in scans]
307 | 
308 |     for scan in scans:
309 |         scan_id = int(scan[4:])
310 |         scan_folder = os.path.join(args.testpath, scan)
311 |         out_folder = os.path.join(args.outdir, scan)
312 |         # step2. filter saved depth maps with photometric confidence maps and geometric constraints
313 |         filter_depth(scan_folder, out_folder, os.path.join(args.outdir, 'mvsnet{:0>3}_l3.ply'.format(scan_id)))
314 | 
--------------------------------------------------------------------------------
/eval.sh:
--------------------------------------------------------------------------------
1 | #!/usr/bin/env bash
2 | DTU_TESTING="/home/xyguo/dataset.ssd/dtu_mvs/processed/mvs_testing/dtu/"
3 | CKPT_FILE="./checkpoints/d192/model_000014.ckpt"
4 | python eval.py --dataset=dtu_yao_eval --batch_size=1 --testpath=$DTU_TESTING --testlist lists/dtu/test.txt --loadckpt $CKPT_FILE $@
5 | 
--------------------------------------------------------------------------------
/evaluations/dtu/BaseEval2Obj_web.m:
--------------------------------------------------------------------------------
1 | function BaseEval2Obj_web(BaseEval,method_string,outputPath)
2 | 
3 | if(nargin<3)
4 |     outputPath='./';
5 | end
6 | 
7 | % tresshold for coloring alpha channel in the range of 0-10 mm
8 | dist_tresshold=10;
9 | 
10 | cSet=BaseEval.cSet;
11 | 
12 | Qdata=BaseEval.Qdata;
13 | alpha=min(BaseEval.Ddata,dist_tresshold)/dist_tresshold;
14 | 
15 | fid=fopen([outputPath method_string '2Stl_' num2str(cSet) '.obj'],'w+');
16 | 
17 | for cP=1:size(Qdata,2)
18 |     if(BaseEval.DataInMask(cP))
19 |         C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold)
20 |     else
21 |         C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points outside the mask (which are not included in the analysis)
22 |     end
23 |     fprintf(fid,'v %f %f %f %f %f %f\n',[Qdata(1,cP) Qdata(2,cP) Qdata(3,cP) C(1) C(2) C(3)]);
24 | end
25 | fclose(fid);
26 | 
27 | disp('Data2Stl saved as obj')
28 | 
29 | Qstl=BaseEval.Qstl;
30 | fid=fopen([outputPath 'Stl2' method_string '_' num2str(cSet) '.obj'],'w+');
31 | 
32 | alpha=min(BaseEval.Dstl,dist_tresshold)/dist_tresshold;
33 | 
34 | for cP=1:size(Qstl,2)
35 |     if(BaseEval.StlAbovePlane(cP))
36 |         C=[1 0 0]*alpha(cP)+[1 1
1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold) 37 | else 38 | C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points below plane (which are not included in the analysis) 39 | end 40 | fprintf(fid,'v %f %f %f %f %f %f\n',[Qstl(1,cP) Qstl(2,cP) Qstl(3,cP) C(1) C(2) C(3)]); 41 | end 42 | fclose(fid); 43 | 44 | disp('Stl2Data saved as obj') -------------------------------------------------------------------------------- /evaluations/dtu/BaseEvalMain_web.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | format compact 4 | clc 5 | 6 | % script to calculate distances have been measured for all included scans (UsedSets) 7 | 8 | dataPath='/home/xyguo/dataset/dtu_mvs/SampleSet/MVS Data/'; 9 | plyPath='/home/xyguo/code/mvsnet_pytorch/outputs/'; 10 | resultsPath='/home/xyguo/code/mvsnet_pytorch/outputs/'; 11 | 12 | method_string='mvsnet'; 13 | light_string='l3'; % l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6) 14 | representation_string='Points'; %mvs representation 'Points' or 'Surfaces' 15 | 16 | switch representation_string 17 | case 'Points' 18 | eval_string='_Eval_'; %results naming 19 | settings_string=''; 20 | end 21 | 22 | % get sets used in evaluation 23 | UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118]; 24 | 25 | dst=0.2; %Min dist between points when reducing 26 | 27 | for cIdx=1:length(UsedSets) 28 | %Data set number 29 | cSet = UsedSets(cIdx) 30 | %input data name 31 | DataInName=[plyPath sprintf('/%s%03d_%s%s.ply',lower(method_string),cSet,light_string,settings_string)] 32 | 33 | %results name 34 | EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat'] 35 | 36 | %check if file is already computed 37 | if(~exist(EvalName,'file')) 38 | disp(DataInName); 39 | 40 | time=clock;time(4:5), drawnow 41 | 42 | tic 43 | Mesh = plyread(DataInName); 44 | Qdata=[Mesh.vertex.x Mesh.vertex.y Mesh.vertex.z]'; 45 | toc 46 | 47 | BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath); 48 | 49 | disp('Saving results'), drawnow 50 | toc 51 | save(EvalName,'BaseEval'); 52 | toc 53 | 54 | % write obj-file of evaluation 55 | % BaseEval2Obj_web(BaseEval,method_string, resultsPath) 56 | % toc 57 | time=clock;time(4:5), drawnow 58 | 59 | BaseEval.MaxDist=20; %outlier threshold of 20 mm 60 | 61 | BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane 62 | BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl=Low(1) & Qfrom(2,:)>=Low(2) & Qfrom(3,:)>=Low(3) &... 18 | Qfrom(1,:)=Low(1) & Qto(2,:)>=Low(2) & Qto(3,:)>=Low(3) &... 
25 | Qto(1,:)3)] 49 | end 50 | 51 | -------------------------------------------------------------------------------- /evaluations/dtu/PointCompareMain.m: -------------------------------------------------------------------------------- 1 | function BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath) 2 | % evaluation function the calculates the distantes from the reference data (stl) to the evalution points (Qdata) and the 3 | % distances from the evaluation points to the reference 4 | 5 | tic 6 | % reduce points 0.2 mm neighbourhood density 7 | Qdata=reducePts_haa(Qdata,dst); 8 | toc 9 | 10 | StlInName=[dataPath '/Points/stl/stl' sprintf('%03d',cSet) '_total.ply']; 11 | 12 | StlMesh = plyread(StlInName); %STL points already reduced 0.2 mm neighbourhood density 13 | Qstl=[StlMesh.vertex.x StlMesh.vertex.y StlMesh.vertex.z]'; 14 | 15 | %Load Mask (ObsMask) and Bounding box (BB) and Resolution (Res) 16 | Margin=10; 17 | MaskName=[dataPath '/ObsMask/ObsMask' num2str(cSet) '_' num2str(Margin) '.mat']; 18 | load(MaskName) 19 | 20 | MaxDist=60; 21 | disp('Computing Data 2 Stl distances') 22 | Ddata = MaxDistCP(Qstl,Qdata,BB,MaxDist); 23 | toc 24 | 25 | disp('Computing Stl 2 Data distances') 26 | Dstl=MaxDistCP(Qdata,Qstl,BB,MaxDist); 27 | disp('Distances computed') 28 | toc 29 | 30 | %use mask 31 | %From Get mask - inverted & modified. 32 | One=ones(1,size(Qdata,2)); 33 | Qv=(Qdata-BB(1,:)'*One)/Res+1; 34 | Qv=round(Qv); 35 | 36 | Midx1=find(Qv(1,:)>0 & Qv(1,:)<=size(ObsMask,1) & Qv(2,:)>0 & Qv(2,:)<=size(ObsMask,2) & Qv(3,:)>0 & Qv(3,:)<=size(ObsMask,3)); 37 | MidxA=sub2ind(size(ObsMask),Qv(1,Midx1),Qv(2,Midx1),Qv(3,Midx1)); 38 | Midx2=find(ObsMask(MidxA)); 39 | 40 | BaseEval.DataInMask(1:size(Qv,2))=false; 41 | BaseEval.DataInMask(Midx1(Midx2))=true; %If Data is within the mask 42 | 43 | BaseEval.cSet=cSet; 44 | BaseEval.Margin=Margin; %Margin of masks 45 | BaseEval.dst=dst; %Min dist between points when reducing 46 | BaseEval.Qdata=Qdata; %Input data points 47 | BaseEval.Ddata=Ddata; %distance from data to stl 48 | BaseEval.Qstl=Qstl; %Input stl points 49 | BaseEval.Dstl=Dstl; %Distance from the stl to data 50 | 51 | load([dataPath '/ObsMask/Plane' num2str(cSet)],'P') 52 | BaseEval.GroundPlane=P; % Plane used to destinguise which Stl points are 'used' 53 | BaseEval.StlAbovePlane=(P'*[Qstl;ones(1,size(Qstl,2))])>0; %Is stl above 'ground plane' 54 | BaseEval.Time=clock; %Time when computation is finished 55 | 56 | 57 | 58 | 59 | -------------------------------------------------------------------------------- /evaluations/dtu/plyread.m: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | function [Elements,varargout] = plyread(Path,Str) 3 | %PLYREAD Read a PLY 3D data file. 4 | % [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file 5 | % FILENAME and returns a structure DATA. The fields in this structure 6 | % are defined by the PLY header; each element type is a field and each 7 | % element property is a subfield. If the file contains any comments, 8 | % they are returned in a cell string array COMMENTS. 9 | % 10 | % [TRI,PTS] = PLYREAD(FILENAME,'tri') or 11 | % [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex 12 | % and face data into triangular connectivity and vertex arrays. The 13 | % mesh can then be displayed using the TRISURF command. 
14 | % 15 | % Note: This function is slow for large mesh files (+50K faces), 16 | % especially when reading data with list type properties. 17 | % 18 | % Example: 19 | % [Tri,Pts] = PLYREAD('cow.ply','tri'); 20 | % trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3)); 21 | % colormap(gray); axis equal; 22 | % 23 | % See also: PLYWRITE 24 | 25 | % Pascal Getreuer 2004 26 | 27 | [fid,Msg] = fopen(Path,'rt'); % open file in read text mode 28 | 29 | if fid == -1, error(Msg); end 30 | 31 | Buf = fscanf(fid,'%s',1); 32 | if ~strcmp(Buf,'ply') 33 | fclose(fid); 34 | error('Not a PLY file.'); 35 | end 36 | 37 | 38 | %%% read header %%% 39 | 40 | Position = ftell(fid); 41 | Format = ''; 42 | NumComments = 0; 43 | Comments = {}; % for storing any file comments 44 | NumElements = 0; 45 | NumProperties = 0; 46 | Elements = []; % structure for holding the element data 47 | ElementCount = []; % number of each type of element in file 48 | PropertyTypes = []; % corresponding structure recording property types 49 | ElementNames = {}; % list of element names in the order they are stored in the file 50 | PropertyNames = []; % structure of lists of property names 51 | 52 | while 1 53 | Buf = fgetl(fid); % read one line from file 54 | BufRem = Buf; 55 | Token = {}; 56 | Count = 0; 57 | 58 | while ~isempty(BufRem) % split line into tokens 59 | [tmp,BufRem] = strtok(BufRem); 60 | 61 | if ~isempty(tmp) 62 | Count = Count + 1; % count tokens 63 | Token{Count} = tmp; 64 | end 65 | end 66 | 67 | if Count % parse line 68 | switch lower(Token{1}) 69 | case 'format' % read data format 70 | if Count >= 2 71 | Format = lower(Token{2}); 72 | 73 | if Count == 3 & ~strcmp(Token{3},'1.0') 74 | fclose(fid); 75 | error('Only PLY format version 1.0 supported.'); 76 | end 77 | end 78 | case 'comment' % read file comment 79 | NumComments = NumComments + 1; 80 | Comments{NumComments} = ''; 81 | for i = 2:Count 82 | Comments{NumComments} = [Comments{NumComments},Token{i},' ']; 83 | end 84 | case 'element' % element name 85 | if Count >= 3 86 | if isfield(Elements,Token{2}) 87 | fclose(fid); 88 | error(['Duplicate element name, ''',Token{2},'''.']); 89 | end 90 | 91 | NumElements = NumElements + 1; 92 | NumProperties = 0; 93 | Elements = setfield(Elements,Token{2},[]); 94 | PropertyTypes = setfield(PropertyTypes,Token{2},[]); 95 | ElementNames{NumElements} = Token{2}; 96 | PropertyNames = setfield(PropertyNames,Token{2},{}); 97 | CurElement = Token{2}; 98 | ElementCount(NumElements) = str2double(Token{3}); 99 | 100 | if isnan(ElementCount(NumElements)) 101 | fclose(fid); 102 | error(['Bad element definition: ',Buf]); 103 | end 104 | else 105 | error(['Bad element definition: ',Buf]); 106 | end 107 | case 'property' % element property 108 | if ~isempty(CurElement) & Count >= 3 109 | NumProperties = NumProperties + 1; 110 | eval(['tmp=isfield(Elements.',CurElement,',Token{Count});'],... 111 | 'fclose(fid);error([''Error reading property: '',Buf])'); 112 | 113 | if tmp 114 | error(['Duplicate property name, ''',CurElement,'.',Token{2},'''.']); 115 | end 116 | 117 | % add property subfield to Elements 118 | eval(['Elements.',CurElement,'.',Token{Count},'=[];'], ... 119 | 'fclose(fid);error([''Error reading property: '',Buf])'); 120 | % add property subfield to PropertyTypes and save type 121 | eval(['PropertyTypes.',CurElement,'.',Token{Count},'={Token{2:Count-1}};'], ... 122 | 'fclose(fid);error([''Error reading property: '',Buf])'); 123 | % record property name order 124 | eval(['PropertyNames.',CurElement,'{NumProperties}=Token{Count};'], ... 
125 | 'fclose(fid);error([''Error reading property: '',Buf])'); 126 | else 127 | fclose(fid); 128 | 129 | if isempty(CurElement) 130 | error(['Property definition without element definition: ',Buf]); 131 | else 132 | error(['Bad property definition: ',Buf]); 133 | end 134 | end 135 | case 'end_header' % end of header, break from while loop 136 | break; 137 | end 138 | end 139 | end 140 | 141 | %%% set reading for specified data format %%% 142 | 143 | if isempty(Format) 144 | warning('Data format unspecified, assuming ASCII.'); 145 | Format = 'ascii'; 146 | end 147 | 148 | switch Format 149 | case 'ascii' 150 | Format = 0; 151 | case 'binary_little_endian' 152 | Format = 1; 153 | case 'binary_big_endian' 154 | Format = 2; 155 | otherwise 156 | fclose(fid); 157 | error(['Data format ''',Format,''' not supported.']); 158 | end 159 | 160 | if ~Format 161 | Buf = fscanf(fid,'%f'); % read the rest of the file as ASCII data 162 | BufOff = 1; 163 | else 164 | % reopen the file in read binary mode 165 | fclose(fid); 166 | 167 | if Format == 1 168 | fid = fopen(Path,'r','ieee-le.l64'); % little endian 169 | else 170 | fid = fopen(Path,'r','ieee-be.l64'); % big endian 171 | end 172 | 173 | % find the end of the header again (using ftell on the old handle doesn't give the correct position) 174 | BufSize = 8192; 175 | Buf = [blanks(10),char(fread(fid,BufSize,'uchar')')]; 176 | i = []; 177 | tmp = -11; 178 | 179 | while isempty(i) 180 | i = findstr(Buf,['end_header',13,10]); % look for end_header + CR/LF 181 | i = [i,findstr(Buf,['end_header',10])]; % look for end_header + LF 182 | 183 | if isempty(i) 184 | tmp = tmp + BufSize; 185 | Buf = [Buf(BufSize+1:BufSize+10),char(fread(fid,BufSize,'uchar')')]; 186 | end 187 | end 188 | 189 | % seek to just after the line feed 190 | fseek(fid,i + tmp + 11 + (Buf(i + 10) == 13),-1); 191 | end 192 | 193 | 194 | %%% read element data %%% 195 | 196 | % PLY and MATLAB data types (for fread) 197 | PlyTypeNames = {'char','uchar','short','ushort','int','uint','float','double', ... 
198 | 'char8','uchar8','short16','ushort16','int32','uint32','float32','double64'}; 199 | MatlabTypeNames = {'schar','uchar','int16','uint16','int32','uint32','single','double'}; 200 | SizeOf = [1,1,2,2,4,4,4,8]; % size in bytes of each type 201 | 202 | for i = 1:NumElements 203 | % get current element property information 204 | eval(['CurPropertyNames=PropertyNames.',ElementNames{i},';']); 205 | eval(['CurPropertyTypes=PropertyTypes.',ElementNames{i},';']); 206 | NumProperties = size(CurPropertyNames,2); 207 | 208 | % fprintf('Reading %s...\n',ElementNames{i}); 209 | 210 | if ~Format %%% read ASCII data %%% 211 | for j = 1:NumProperties 212 | Token = getfield(CurPropertyTypes,CurPropertyNames{j}); 213 | 214 | if strcmpi(Token{1},'list') 215 | Type(j) = 1; 216 | else 217 | Type(j) = 0; 218 | end 219 | end 220 | 221 | % parse buffer 222 | if ~any(Type) 223 | % no list types 224 | Data = reshape(Buf(BufOff:BufOff+ElementCount(i)*NumProperties-1),NumProperties,ElementCount(i))'; 225 | BufOff = BufOff + ElementCount(i)*NumProperties; 226 | else 227 | ListData = cell(NumProperties,1); 228 | 229 | for k = 1:NumProperties 230 | ListData{k} = cell(ElementCount(i),1); 231 | end 232 | 233 | % list type 234 | for j = 1:ElementCount(i) 235 | for k = 1:NumProperties 236 | if ~Type(k) 237 | Data(j,k) = Buf(BufOff); 238 | BufOff = BufOff + 1; 239 | else 240 | tmp = Buf(BufOff); 241 | ListData{k}{j} = Buf(BufOff+(1:tmp))'; 242 | BufOff = BufOff + tmp + 1; 243 | end 244 | end 245 | end 246 | end 247 | else %%% read binary data %%% 248 | % translate PLY data type names to MATLAB data type names 249 | ListFlag = 0; % = 1 if there is a list type 250 | SameFlag = 1; % = 1 if all types are the same 251 | 252 | for j = 1:NumProperties 253 | Token = getfield(CurPropertyTypes,CurPropertyNames{j}); 254 | 255 | if ~strcmp(Token{1},'list') % non-list type 256 | tmp = rem(strmatch(Token{1},PlyTypeNames,'exact')-1,8)+1; 257 | 258 | if ~isempty(tmp) 259 | TypeSize(j) = SizeOf(tmp); 260 | Type{j} = MatlabTypeNames{tmp}; 261 | TypeSize2(j) = 0; 262 | Type2{j} = ''; 263 | 264 | SameFlag = SameFlag & strcmp(Type{1},Type{j}); 265 | else 266 | fclose(fid); 267 | error(['Unknown property data type, ''',Token{1},''', in ', ... 268 | ElementNames{i},'.',CurPropertyNames{j},'.']); 269 | end 270 | else % list type 271 | if length(Token) == 3 272 | ListFlag = 1; 273 | SameFlag = 0; 274 | tmp = rem(strmatch(Token{2},PlyTypeNames,'exact')-1,8)+1; 275 | tmp2 = rem(strmatch(Token{3},PlyTypeNames,'exact')-1,8)+1; 276 | 277 | if ~isempty(tmp) & ~isempty(tmp2) 278 | TypeSize(j) = SizeOf(tmp); 279 | Type{j} = MatlabTypeNames{tmp}; 280 | TypeSize2(j) = SizeOf(tmp2); 281 | Type2{j} = MatlabTypeNames{tmp2}; 282 | else 283 | fclose(fid); 284 | error(['Unknown property data type, ''list ',Token{2},' ',Token{3},''', in ', ... 
285 | ElementNames{i},'.',CurPropertyNames{j},'.']); 286 | end 287 | else 288 | fclose(fid); 289 | error(['Invalid list syntax in ',ElementNames{i},'.',CurPropertyNames{j},'.']); 290 | end 291 | end 292 | end 293 | 294 | % read file 295 | if ~ListFlag 296 | if SameFlag 297 | % no list types, all the same type (fast) 298 | Data = fread(fid,[NumProperties,ElementCount(i)],Type{1})'; 299 | else 300 | % no list types, mixed type 301 | Data = zeros(ElementCount(i),NumProperties); 302 | 303 | for j = 1:ElementCount(i) 304 | for k = 1:NumProperties 305 | Data(j,k) = fread(fid,1,Type{k}); 306 | end 307 | end 308 | end 309 | else 310 | ListData = cell(NumProperties,1); 311 | 312 | for k = 1:NumProperties 313 | ListData{k} = cell(ElementCount(i),1); 314 | end 315 | 316 | if NumProperties == 1 317 | BufSize = 512; 318 | SkipNum = 4; 319 | j = 0; 320 | 321 | % list type, one property (fast if lists are usually the same length) 322 | while j < ElementCount(i) 323 | Position = ftell(fid); 324 | % read in BufSize count values, assuming all counts = SkipNum 325 | [Buf,BufSize] = fread(fid,BufSize,Type{1},SkipNum*TypeSize2(1)); 326 | Miss = find(Buf ~= SkipNum); % find first count that is not SkipNum 327 | fseek(fid,Position + TypeSize(1),-1); % seek back to after first count 328 | 329 | if isempty(Miss) % all counts are SkipNum 330 | Buf = fread(fid,[SkipNum,BufSize],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; 331 | fseek(fid,-TypeSize(1),0); % undo last skip 332 | 333 | for k = 1:BufSize 334 | ListData{1}{j+k} = Buf(k,:); 335 | end 336 | 337 | j = j + BufSize; 338 | BufSize = floor(1.5*BufSize); 339 | else 340 | if Miss(1) > 1 % some counts are SkipNum 341 | Buf2 = fread(fid,[SkipNum,Miss(1)-1],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; 342 | 343 | for k = 1:Miss(1)-1 344 | ListData{1}{j+k} = Buf2(k,:); 345 | end 346 | 347 | j = j + k; 348 | end 349 | 350 | % read in the list with the missed count 351 | SkipNum = Buf(Miss(1)); 352 | j = j + 1; 353 | ListData{1}{j} = fread(fid,[1,SkipNum],Type2{1}); 354 | BufSize = ceil(0.6*BufSize); 355 | end 356 | end 357 | else 358 | % list type(s), multiple properties (slow) 359 | Data = zeros(ElementCount(i),NumProperties); 360 | 361 | for j = 1:ElementCount(i) 362 | for k = 1:NumProperties 363 | if isempty(Type2{k}) 364 | Data(j,k) = fread(fid,1,Type{k}); 365 | else 366 | tmp = fread(fid,1,Type{k}); 367 | ListData{k}{j} = fread(fid,[1,tmp],Type2{k}); 368 | end 369 | end 370 | end 371 | end 372 | end 373 | end 374 | 375 | % put data into Elements structure 376 | for k = 1:NumProperties 377 | if (~Format & ~Type(k)) | (Format & isempty(Type2{k})) 378 | eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=Data(:,k);']); 379 | else 380 | eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=ListData{k};']); 381 | end 382 | end 383 | end 384 | 385 | clear Data ListData; 386 | fclose(fid); 387 | 388 | if (nargin > 1 & strcmpi(Str,'Tri')) | nargout > 2 389 | % find vertex element field 390 | Name = {'vertex','Vertex','point','Point','pts','Pts'}; 391 | Names = []; 392 | 393 | for i = 1:length(Name) 394 | if any(strcmp(ElementNames,Name{i})) 395 | Names = getfield(PropertyNames,Name{i}); 396 | Name = Name{i}; 397 | break; 398 | end 399 | end 400 | 401 | if any(strcmp(Names,'x')) & any(strcmp(Names,'y')) & any(strcmp(Names,'z')) 402 | eval(['varargout{1}=[Elements.',Name,'.x,Elements.',Name,'.y,Elements.',Name,'.z];']); 403 | else 404 | varargout{1} = zeros(1,3); 405 | end 406 | 407 | varargout{2} = Elements; 408 | varargout{3} = Comments; 409 | Elements 
= []; 410 | 411 | % find face element field 412 | Name = {'face','Face','poly','Poly','tri','Tri'}; 413 | Names = []; 414 | 415 | for i = 1:length(Name) 416 | if any(strcmp(ElementNames,Name{i})) 417 | Names = getfield(PropertyNames,Name{i}); 418 | Name = Name{i}; 419 | break; 420 | end 421 | end 422 | 423 | if ~isempty(Names) 424 | % find vertex indices property subfield 425 | PropertyName = {'vertex_indices','vertex_indexes','vertex_index','indices','indexes'}; 426 | 427 | for i = 1:length(PropertyName) 428 | if any(strcmp(Names,PropertyName{i})) 429 | PropertyName = PropertyName{i}; 430 | break; 431 | end 432 | end 433 | 434 | if ~iscell(PropertyName) 435 | % convert face index lists to triangular connectivity 436 | eval(['FaceIndices=varargout{2}.',Name,'.',PropertyName,';']); 437 | N = length(FaceIndices); 438 | Elements = zeros(N*2,3); 439 | Extra = 0; 440 | 441 | for k = 1:N 442 | Elements(k,:) = FaceIndices{k}(1:3); 443 | 444 | for j = 4:length(FaceIndices{k}) 445 | Extra = Extra + 1; 446 | Elements(N + Extra,:) = [Elements(k,[1,j-1]),FaceIndices{k}(j)]; 447 | end 448 | end 449 | Elements = Elements(1:N+Extra,:) + 1; 450 | end 451 | end 452 | else 453 | varargout{1} = Comments; 454 | end -------------------------------------------------------------------------------- /evaluations/dtu/reducePts_haa.m: -------------------------------------------------------------------------------- 1 | function [ptsOut,indexSet] = reducePts_haa(pts, dst) 2 | 3 | %Reduces a point set, pts, in a stochastic manner, such that the minimum sdistance 4 | % between points is 'dst'. Writen by abd, edited by haa, then by raje 5 | 6 | nPoints=size(pts,2); 7 | 8 | indexSet=true(nPoints,1); 9 | RandOrd=randperm(nPoints); 10 | 11 | %tic 12 | NS = KDTreeSearcher(pts'); 13 | %toc 14 | 15 | % search the KNTree for close neighbours in a chunk-wise fashion to save memory if point cloud is really big 16 | Chunks=1:min(4e6,nPoints-1):nPoints; 17 | Chunks(end)=nPoints; 18 | 19 | for cChunk=1:(length(Chunks)-1) 20 | Range=Chunks(cChunk):Chunks(cChunk+1); 21 | 22 | idx = rangesearch(NS,pts(:,RandOrd(Range))',dst); 23 | 24 | for i = 1:size(idx,1) 25 | id =RandOrd(i-1+Chunks(cChunk)); 26 | if (indexSet(id)) 27 | indexSet(idx{i}) = 0; 28 | indexSet(id) = 1; 29 | end 30 | end 31 | end 32 | 33 | ptsOut = pts(:,indexSet); 34 | 35 | disp(['downsample factor: ' num2str(nPoints/sum(indexSet))]); 36 | -------------------------------------------------------------------------------- /lists/dtu/test.txt: -------------------------------------------------------------------------------- 1 | scan1 2 | scan4 3 | scan9 4 | scan10 5 | scan11 6 | scan12 7 | scan13 8 | scan15 9 | scan23 10 | scan24 11 | scan29 12 | scan32 13 | scan33 14 | scan34 15 | scan48 16 | scan49 17 | scan62 18 | scan75 19 | scan77 20 | scan110 21 | scan114 22 | scan118 -------------------------------------------------------------------------------- /lists/dtu/train.txt: -------------------------------------------------------------------------------- 1 | scan2 2 | scan6 3 | scan7 4 | scan8 5 | scan14 6 | scan16 7 | scan18 8 | scan19 9 | scan20 10 | scan22 11 | scan30 12 | scan31 13 | scan36 14 | scan39 15 | scan41 16 | scan42 17 | scan44 18 | scan45 19 | scan46 20 | scan47 21 | scan50 22 | scan51 23 | scan52 24 | scan53 25 | scan55 26 | scan57 27 | scan58 28 | scan60 29 | scan61 30 | scan63 31 | scan64 32 | scan65 33 | scan68 34 | scan69 35 | scan70 36 | scan71 37 | scan72 38 | scan74 39 | scan76 40 | scan83 41 | scan84 42 | scan85 43 | scan87 44 | scan88 45 | scan89 46 | 
scan90 47 | scan91 48 | scan92 49 | scan93 50 | scan94 51 | scan95 52 | scan96 53 | scan97 54 | scan98 55 | scan99 56 | scan100 57 | scan101 58 | scan102 59 | scan103 60 | scan104 61 | scan105 62 | scan107 63 | scan108 64 | scan109 65 | scan111 66 | scan112 67 | scan113 68 | scan115 69 | scan116 70 | scan119 71 | scan120 72 | scan121 73 | scan122 74 | scan123 75 | scan124 76 | scan125 77 | scan126 78 | scan127 79 | scan128 -------------------------------------------------------------------------------- /lists/dtu/val.txt: -------------------------------------------------------------------------------- 1 | scan3 2 | scan5 3 | scan17 4 | scan21 5 | scan28 6 | scan35 7 | scan37 8 | scan38 9 | scan40 10 | scan43 11 | scan56 12 | scan59 13 | scan66 14 | scan67 15 | scan82 16 | scan86 17 | scan106 18 | scan117 -------------------------------------------------------------------------------- /models/__init__.py: -------------------------------------------------------------------------------- 1 | from models.mvsnet import MVSNet, mvsnet_loss -------------------------------------------------------------------------------- /models/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvBnReLU(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): 8 | super(ConvBnReLU, self).__init__() 9 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) 10 | self.bn = nn.BatchNorm2d(out_channels) 11 | 12 | def forward(self, x): 13 | return F.relu(self.bn(self.conv(x)), inplace=True) 14 | 15 | 16 | class ConvBn(nn.Module): 17 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): 18 | super(ConvBn, self).__init__() 19 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) 20 | self.bn = nn.BatchNorm2d(out_channels) 21 | 22 | def forward(self, x): 23 | return self.bn(self.conv(x)) 24 | 25 | 26 | class ConvBnReLU3D(nn.Module): 27 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): 28 | super(ConvBnReLU3D, self).__init__() 29 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) 30 | self.bn = nn.BatchNorm3d(out_channels) 31 | 32 | def forward(self, x): 33 | return F.relu(self.bn(self.conv(x)), inplace=True) 34 | 35 | 36 | class ConvBn3D(nn.Module): 37 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): 38 | super(ConvBn3D, self).__init__() 39 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False) 40 | self.bn = nn.BatchNorm3d(out_channels) 41 | 42 | def forward(self, x): 43 | return self.bn(self.conv(x)) 44 | 45 | 46 | class BasicBlock(nn.Module): 47 | def __init__(self, in_channels, out_channels, stride, downsample=None): 48 | super(BasicBlock, self).__init__() 49 | 50 | self.conv1 = ConvBnReLU(in_channels, out_channels, kernel_size=3, stride=stride, pad=1) 51 | self.conv2 = ConvBn(out_channels, out_channels, kernel_size=3, stride=1, pad=1) 52 | 53 | self.downsample = downsample 54 | self.stride = stride 55 | 56 | def forward(self, x): 57 | out = self.conv1(x) 58 | out = self.conv2(out) 59 | if self.downsample is not None: 60 | x = self.downsample(x) 61 | out += x 62 | return out 63 | 64 | 65 | class Hourglass3d(nn.Module): 66 | def __init__(self, channels): 67 | 
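        # encoder-decoder on the cost volume: conv1a/conv2a halve the resolution twice,
        # dconv2/dconv1 (transposed convs) upsample it back, and each upsampled volume
        # is fused with a 1x1 ConvBn "redir" skip connection before the ReLU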
super(Hourglass3d, self).__init__() 68 | 69 | self.conv1a = ConvBnReLU3D(channels, channels * 2, kernel_size=3, stride=2, pad=1) 70 | self.conv1b = ConvBnReLU3D(channels * 2, channels * 2, kernel_size=3, stride=1, pad=1) 71 | 72 | self.conv2a = ConvBnReLU3D(channels * 2, channels * 4, kernel_size=3, stride=2, pad=1) 73 | self.conv2b = ConvBnReLU3D(channels * 4, channels * 4, kernel_size=3, stride=1, pad=1) 74 | 75 | self.dconv2 = nn.Sequential( 76 | nn.ConvTranspose3d(channels * 4, channels * 2, kernel_size=3, padding=1, output_padding=1, stride=2, 77 | bias=False), 78 | nn.BatchNorm3d(channels * 2)) 79 | 80 | self.dconv1 = nn.Sequential( 81 | nn.ConvTranspose3d(channels * 2, channels, kernel_size=3, padding=1, output_padding=1, stride=2, 82 | bias=False), 83 | nn.BatchNorm3d(channels)) 84 | 85 | self.redir1 = ConvBn3D(channels, channels, kernel_size=1, stride=1, pad=0) 86 | self.redir2 = ConvBn3D(channels * 2, channels * 2, kernel_size=1, stride=1, pad=0) 87 | 88 | def forward(self, x): 89 | conv1 = self.conv1b(self.conv1a(x)) 90 | conv2 = self.conv2b(self.conv2a(conv1)) 91 | dconv2 = F.relu(self.dconv2(conv2) + self.redir2(conv1), inplace=True) 92 | dconv1 = F.relu(self.dconv1(dconv2) + self.redir1(x), inplace=True) 93 | return dconv1 94 | 95 | 96 | def homo_warping(src_fea, src_proj, ref_proj, depth_values): 97 | # src_fea: [B, C, H, W] 98 | # src_proj: [B, 4, 4] 99 | # ref_proj: [B, 4, 4] 100 | # depth_values: [B, Ndepth] 101 | # out: [B, C, Ndepth, H, W] 102 | batch, channels = src_fea.shape[0], src_fea.shape[1] 103 | num_depth = depth_values.shape[1] 104 | height, width = src_fea.shape[2], src_fea.shape[3] 105 | 106 | with torch.no_grad(): 107 | proj = torch.matmul(src_proj, torch.inverse(ref_proj)) 108 | rot = proj[:, :3, :3] # [B,3,3] 109 | trans = proj[:, :3, 3:4] # [B,3,1] 110 | 111 | y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=src_fea.device), 112 | torch.arange(0, width, dtype=torch.float32, device=src_fea.device)]) 113 | y, x = y.contiguous(), x.contiguous() 114 | y, x = y.view(height * width), x.view(height * width) 115 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] 116 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W] 117 | rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W] 118 | rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.view(batch, 1, num_depth, 119 | 1) # [B, 3, Ndepth, H*W] 120 | proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W] 121 | proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W] 122 | proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1 123 | proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1 124 | proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2] 125 | grid = proj_xy 126 | 127 | warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2), mode='bilinear', 128 | padding_mode='zeros') 129 | warped_src_fea = warped_src_fea.view(batch, channels, num_depth, height, width) 130 | 131 | return warped_src_fea 132 | 133 | 134 | # p: probability volume [B, D, H, W] 135 | # depth_values: discrete depth values [B, D] 136 | def depth_regression(p, depth_values): 137 | depth_values = depth_values.view(*depth_values.shape, 1, 1) 138 | depth = torch.sum(p * depth_values, 1) 139 | return depth 140 | 141 | 142 | if __name__ == "__main__": 143 | # some testing code, just IGNORE it 144 | from datasets import find_dataset_def 145 | from 
torch.utils.data import DataLoader 146 | import numpy as np 147 | import cv2 148 | 149 | MVSDataset = find_dataset_def("dtu_yao") 150 | dataset = MVSDataset("/home/xyguo/dataset/dtu_mvs/processed/mvs_training/dtu/", '../lists/dtu/train.txt', 'train', 151 | 3, 256) 152 | dataloader = DataLoader(dataset, batch_size=2) 153 | item = next(iter(dataloader)) 154 | 155 | imgs = item["imgs"][:, :, :, ::4, ::4].cuda() 156 | proj_matrices = item["proj_matrices"].cuda() 157 | mask = item["mask"].cuda() 158 | depth = item["depth"].cuda() 159 | depth_values = item["depth_values"].cuda() 160 | 161 | imgs = torch.unbind(imgs, 1) 162 | proj_matrices = torch.unbind(proj_matrices, 1) 163 | ref_img, src_imgs = imgs[0], imgs[1:] 164 | ref_proj, src_projs = proj_matrices[0], proj_matrices[1:] 165 | 166 | warped_imgs = homo_warping(src_imgs[0], src_projs[0], ref_proj, depth_values) 167 | 168 | cv2.imwrite('../tmp/ref.png', ref_img.permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255) 169 | cv2.imwrite('../tmp/src.png', src_imgs[0].permute([0, 2, 3, 1])[0].detach().cpu().numpy()[:, :, ::-1] * 255) 170 | 171 | for i in range(warped_imgs.shape[2]): 172 | warped_img = warped_imgs[:, :, i, :, :].permute([0, 2, 3, 1]).contiguous() 173 | img_np = warped_img[0].detach().cpu().numpy() 174 | cv2.imwrite('../tmp/tmp{}.png'.format(i), img_np[:, :, ::-1] * 255) 175 | 176 | 177 | # generate gt 178 | def tocpu(x): 179 | return x.detach().cpu().numpy().copy() 180 | 181 | 182 | ref_img = tocpu(ref_img)[0].transpose([1, 2, 0]) 183 | src_imgs = [tocpu(x)[0].transpose([1, 2, 0]) for x in src_imgs] 184 | ref_proj_mat = tocpu(ref_proj)[0] 185 | src_proj_mats = [tocpu(x)[0] for x in src_projs] 186 | mask = tocpu(mask)[0] 187 | depth = tocpu(depth)[0] 188 | depth_values = tocpu(depth_values)[0] 189 | 190 | for i, D in enumerate(depth_values): 191 | height = ref_img.shape[0] 192 | width = ref_img.shape[1] 193 | xx, yy = np.meshgrid(np.arange(0, width), np.arange(0, height)) 194 | print("yy", yy.max(), yy.min()) 195 | yy = yy.reshape([-1]) 196 | xx = xx.reshape([-1]) 197 | X = np.vstack((xx, yy, np.ones_like(xx))) 198 | # D = depth.reshape([-1]) 199 | # print("X", "D", X.shape, D.shape) 200 | 201 | X = np.vstack((X * D, np.ones_like(xx))) 202 | X = np.matmul(np.linalg.inv(ref_proj_mat), X) 203 | X = np.matmul(src_proj_mats[0], X) 204 | X /= X[2] 205 | X = X[:2] 206 | 207 | yy = X[0].reshape([height, width]).astype(np.float32) 208 | xx = X[1].reshape([height, width]).astype(np.float32) 209 | 210 | warped = cv2.remap(src_imgs[0], yy, xx, interpolation=cv2.INTER_LINEAR) 211 | # warped[mask[:, :] < 0.5] = 0 212 | 213 | cv2.imwrite('../tmp/tmp{}_gt.png'.format(i), warped[:, :, ::-1] * 255) 214 | -------------------------------------------------------------------------------- /models/mvsnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from .module import * 5 | 6 | 7 | class FeatureNet(nn.Module): 8 | def __init__(self): 9 | super(FeatureNet, self).__init__() 10 | self.inplanes = 32 11 | 12 | self.conv0 = ConvBnReLU(3, 8, 3, 1, 1) 13 | self.conv1 = ConvBnReLU(8, 8, 3, 1, 1) 14 | 15 | self.conv2 = ConvBnReLU(8, 16, 5, 2, 2) 16 | self.conv3 = ConvBnReLU(16, 16, 3, 1, 1) 17 | self.conv4 = ConvBnReLU(16, 16, 3, 1, 1) 18 | 19 | self.conv5 = ConvBnReLU(16, 32, 5, 2, 2) 20 | self.conv6 = ConvBnReLU(32, 32, 3, 1, 1) 21 | self.feature = nn.Conv2d(32, 32, 3, 1, 1) 22 | 23 | def forward(self, x): 24 | x = 
self.conv1(self.conv0(x)) 25 | x = self.conv4(self.conv3(self.conv2(x))) 26 | x = self.feature(self.conv6(self.conv5(x))) 27 | return x 28 | 29 | 30 | class CostRegNet(nn.Module): 31 | def __init__(self): 32 | super(CostRegNet, self).__init__() 33 | self.conv0 = ConvBnReLU3D(32, 8) 34 | 35 | self.conv1 = ConvBnReLU3D(8, 16, stride=2) 36 | self.conv2 = ConvBnReLU3D(16, 16) 37 | 38 | self.conv3 = ConvBnReLU3D(16, 32, stride=2) 39 | self.conv4 = ConvBnReLU3D(32, 32) 40 | 41 | self.conv5 = ConvBnReLU3D(32, 64, stride=2) 42 | self.conv6 = ConvBnReLU3D(64, 64) 43 | 44 | self.conv7 = nn.Sequential( 45 | nn.ConvTranspose3d(64, 32, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), 46 | nn.BatchNorm3d(32), 47 | nn.ReLU(inplace=True)) 48 | 49 | self.conv9 = nn.Sequential( 50 | nn.ConvTranspose3d(32, 16, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), 51 | nn.BatchNorm3d(16), 52 | nn.ReLU(inplace=True)) 53 | 54 | self.conv11 = nn.Sequential( 55 | nn.ConvTranspose3d(16, 8, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), 56 | nn.BatchNorm3d(8), 57 | nn.ReLU(inplace=True)) 58 | 59 | self.prob = nn.Conv3d(8, 1, 3, stride=1, padding=1) 60 | 61 | def forward(self, x): 62 | conv0 = self.conv0(x) 63 | conv2 = self.conv2(self.conv1(conv0)) 64 | conv4 = self.conv4(self.conv3(conv2)) 65 | x = self.conv6(self.conv5(conv4)) 66 | x = conv4 + self.conv7(x) 67 | x = conv2 + self.conv9(x) 68 | x = conv0 + self.conv11(x) 69 | x = self.prob(x) 70 | return x 71 | 72 | 73 | class RefineNet(nn.Module): 74 | def __init__(self): 75 | super(RefineNet, self).__init__() 76 | self.conv1 = ConvBnReLU(4, 32) 77 | self.conv2 = ConvBnReLU(32, 32) 78 | self.conv3 = ConvBnReLU(32, 32) 79 | self.res = ConvBnReLU(32, 1) 80 | 81 | def forward(self, img, depth_init): 82 | concat = torch.cat((img, depth_init), dim=1) 83 | depth_residual = self.res(self.conv3(self.conv2(self.conv1(concat)))) 84 | depth_refined = depth_init + depth_residual 85 | return depth_refined 86 | 87 | 88 | class MVSNet(nn.Module): 89 | def __init__(self, refine=True): 90 | super(MVSNet, self).__init__() 91 | self.refine = refine 92 | 93 | self.feature = FeatureNet() 94 | self.cost_regularization = CostRegNet() 95 | if self.refine: 96 | self.refine_network = RefineNet() 97 | 98 | def forward(self, imgs, proj_matrices, depth_values): 99 | imgs = torch.unbind(imgs, 1) 100 | proj_matrices = torch.unbind(proj_matrices, 1) 101 | assert len(imgs) == len(proj_matrices), "Different number of images and projection matrices" 102 | img_height, img_width = imgs[0].shape[2], imgs[0].shape[3] 103 | num_depth = depth_values.shape[1] 104 | num_views = len(imgs) 105 | 106 | # step 1. feature extraction 107 | # in: images; out: 32-channel feature maps 108 | features = [self.feature(img) for img in imgs] 109 | ref_feature, src_features = features[0], features[1:] 110 | ref_proj, src_projs = proj_matrices[0], proj_matrices[1:] 111 | 112 | # step 2. differentiable homography, build cost volume 113 | ref_volume = ref_feature.unsqueeze(2).repeat(1, 1, num_depth, 1, 1) 114 | volume_sum = ref_volume 115 | volume_sq_sum = ref_volume ** 2 116 | del ref_volume 117 | for src_fea, src_proj in zip(src_features, src_projs): 118 | # warped features 119 | warped_volume = homo_warping(src_fea, src_proj, ref_proj, depth_values) 120 | if self.training: 121 | volume_sum = volume_sum + warped_volume 122 | volume_sq_sum = volume_sq_sum + warped_volume ** 2 123 | else: 124 | # TODO: this is only a temporary solution to save memory, better way? 
125 | volume_sum += warped_volume 126 | volume_sq_sum += warped_volume.pow_(2) # the memory of warped_volume has been modified in place 127 | del warped_volume 128 | # aggregate multiple feature volumes by variance 129 | volume_variance = volume_sq_sum.div_(num_views).sub_(volume_sum.div_(num_views).pow_(2)) 130 | 131 | # step 3. cost volume regularization 132 | cost_reg = self.cost_regularization(volume_variance) 133 | # cost_reg = F.upsample(cost_reg, [num_depth * 4, img_height, img_width], mode='trilinear') 134 | cost_reg = cost_reg.squeeze(1) 135 | prob_volume = F.softmax(cost_reg, dim=1) 136 | depth = depth_regression(prob_volume, depth_values=depth_values) 137 | 138 | with torch.no_grad(): 139 | # photometric confidence 140 | prob_volume_sum4 = 4 * F.avg_pool3d(F.pad(prob_volume.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)), (4, 1, 1), stride=1, padding=0).squeeze(1) 141 | depth_index = depth_regression(prob_volume, depth_values=torch.arange(num_depth, device=prob_volume.device, dtype=torch.float)).long() 142 | photometric_confidence = torch.gather(prob_volume_sum4, 1, depth_index.unsqueeze(1)).squeeze(1) 143 | 144 | # step 4. depth map refinement 145 | if not self.refine: 146 | return {"depth": depth, "photometric_confidence": photometric_confidence} 147 | else: 148 | refined_depth = self.refine_network(imgs[0], depth.unsqueeze(1)).squeeze(1) # RefineNet takes (img, depth_init) and concatenates them internally 149 | return {"depth": depth, "refined_depth": refined_depth, "photometric_confidence": photometric_confidence} 150 | 151 | 152 | def mvsnet_loss(depth_est, depth_gt, mask): 153 | mask = mask > 0.5 154 | return F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='mean') 155 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.parallel 6 | import torch.backends.cudnn as cudnn 7 | import torch.optim as optim 8 | from torch.utils.data import DataLoader 9 | from torch.autograd import Variable 10 | import torch.nn.functional as F 11 | import numpy as np 12 | import time 13 | from tensorboardX import SummaryWriter 14 | from datasets import find_dataset_def 15 | from models import * 16 | from utils import * 17 | import gc 18 | import sys 19 | import datetime 20 | 21 | cudnn.benchmark = True 22 | 23 | parser = argparse.ArgumentParser(description='A PyTorch Implementation of MVSNet') 24 | parser.add_argument('--mode', default='train', help='train or test', choices=['train', 'test', 'profile']) 25 | parser.add_argument('--model', default='mvsnet', help='select model') 26 | 27 | parser.add_argument('--dataset', default='dtu_yao', help='select dataset') 28 | parser.add_argument('--trainpath', help='train datapath') 29 | parser.add_argument('--testpath', help='test datapath') 30 | parser.add_argument('--trainlist', help='train list') 31 | parser.add_argument('--testlist', help='test list') 32 | 33 | parser.add_argument('--epochs', type=int, default=16, help='number of epochs to train') 34 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 35 | parser.add_argument('--lrepochs', type=str, default="10,12,14:2", help='epoch ids to downscale lr and the downscale rate') 36 | parser.add_argument('--wd', type=float, default=0.0, help='weight decay') 37 | 38 | parser.add_argument('--batch_size', type=int, default=12, help='train batch size') 39 | parser.add_argument('--numdepth', type=int, default=192, help='the number of depth values') 40 
| parser.add_argument('--interval_scale', type=float, default=1.06, help='the scale of the depth interval') 41 | 42 | parser.add_argument('--loadckpt', default=None, help='load a specific checkpoint') 43 | parser.add_argument('--logdir', default='./checkpoints/debug', help='the directory to save checkpoints/logs') 44 | parser.add_argument('--resume', action='store_true', help='continue to train the model') 45 | 46 | parser.add_argument('--summary_freq', type=int, default=20, help='print and summary frequency') 47 | parser.add_argument('--save_freq', type=int, default=1, help='save checkpoint frequency') 48 | parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed') 49 | 50 | # parse arguments and check 51 | args = parser.parse_args() 52 | if args.resume: 53 | assert args.mode == "train" 54 | assert args.loadckpt is None 55 | if args.testpath is None: 56 | args.testpath = args.trainpath 57 | 58 | torch.manual_seed(args.seed) 59 | torch.cuda.manual_seed(args.seed) 60 | 61 | # create logger for mode "train" 62 | if args.mode == "train": 63 | if not os.path.isdir(args.logdir): 64 | os.mkdir(args.logdir) 65 | 66 | current_time_str = str(datetime.datetime.now().strftime('%Y%m%d_%H%M%S')) 67 | print("current time", current_time_str) 68 | 69 | print("creating new summary file") 70 | logger = SummaryWriter(args.logdir) 71 | 72 | print("argv:", sys.argv[1:]) 73 | print_args(args) 74 | 75 | # dataset, dataloader 76 | MVSDataset = find_dataset_def(args.dataset) 77 | train_dataset = MVSDataset(args.trainpath, args.trainlist, "train", 3, args.numdepth, args.interval_scale) 78 | test_dataset = MVSDataset(args.testpath, args.testlist, "test", 5, args.numdepth, args.interval_scale) 79 | TrainImgLoader = DataLoader(train_dataset, args.batch_size, shuffle=True, num_workers=8, drop_last=True) 80 | TestImgLoader = DataLoader(test_dataset, args.batch_size, shuffle=False, num_workers=4, drop_last=False) 81 | 82 | # model, optimizer 83 | model = MVSNet(refine=False) 84 | if args.mode in ["train", "test"]: 85 | model = nn.DataParallel(model) 86 | model.cuda() 87 | model_loss = mvsnet_loss 88 | optimizer = optim.Adam(model.parameters(), lr=args.lr, betas=(0.9, 0.999), weight_decay=args.wd) 89 | 90 | # load parameters 91 | start_epoch = 0 92 | if (args.mode == "train" and args.resume) or (args.mode == "test" and not args.loadckpt): 93 | saved_models = [fn for fn in os.listdir(args.logdir) if fn.endswith(".ckpt")] 94 | saved_models = sorted(saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0])) 95 | # use the latest checkpoint file 96 | loadckpt = os.path.join(args.logdir, saved_models[-1]) 97 | print("resuming", loadckpt) 98 | state_dict = torch.load(loadckpt) 99 | model.load_state_dict(state_dict['model']) 100 | optimizer.load_state_dict(state_dict['optimizer']) 101 | start_epoch = state_dict['epoch'] + 1 102 | elif args.loadckpt: 103 | # load checkpoint file specified by args.loadckpt 104 | print("loading model {}".format(args.loadckpt)) 105 | state_dict = torch.load(args.loadckpt) 106 | model.load_state_dict(state_dict['model']) 107 | print("start at epoch {}".format(start_epoch)) 108 | print('Number of model parameters: {}'.format(sum([p.data.nelement() for p in model.parameters()]))) 109 | 110 | 111 | # main function 112 | def train(): 113 | milestones = [int(epoch_idx) for epoch_idx in args.lrepochs.split(':')[0].split(',')] 114 | lr_gamma = 1 / float(args.lrepochs.split(':')[1]) 115 | lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, 
gamma=lr_gamma, 116 | last_epoch=start_epoch - 1) 117 | 118 | for epoch_idx in range(start_epoch, args.epochs): 119 | print('Epoch {}:'.format(epoch_idx)) 120 | lr_scheduler.step() 121 | global_step = len(TrainImgLoader) * epoch_idx 122 | 123 | # training 124 | for batch_idx, sample in enumerate(TrainImgLoader): 125 | start_time = time.time() 126 | global_step = len(TrainImgLoader) * epoch_idx + batch_idx 127 | do_summary = global_step % args.summary_freq == 0 128 | loss, scalar_outputs, image_outputs = train_sample(sample, detailed_summary=do_summary) 129 | if do_summary: 130 | save_scalars(logger, 'train', scalar_outputs, global_step) 131 | save_images(logger, 'train', image_outputs, global_step) 132 | del scalar_outputs, image_outputs 133 | print( 134 | 'Epoch {}/{}, Iter {}/{}, train loss = {:.3f}, time = {:.3f}'.format(epoch_idx, args.epochs, batch_idx, 135 | len(TrainImgLoader), loss, 136 | time.time() - start_time)) 137 | 138 | # checkpoint 139 | if (epoch_idx + 1) % args.save_freq == 0: 140 | torch.save({ 141 | 'epoch': epoch_idx, 142 | 'model': model.state_dict(), 143 | 'optimizer': optimizer.state_dict()}, 144 | "{}/model_{:0>6}.ckpt".format(args.logdir, epoch_idx)) 145 | 146 | # testing 147 | avg_test_scalars = DictAverageMeter() 148 | for batch_idx, sample in enumerate(TestImgLoader): 149 | start_time = time.time() 150 | global_step = len(TrainImgLoader) * epoch_idx + batch_idx 151 | do_summary = global_step % args.summary_freq == 0 152 | loss, scalar_outputs, image_outputs = test_sample(sample, detailed_summary=do_summary) 153 | if do_summary: 154 | save_scalars(logger, 'test', scalar_outputs, global_step) 155 | save_images(logger, 'test', image_outputs, global_step) 156 | avg_test_scalars.update(scalar_outputs) 157 | del scalar_outputs, image_outputs 158 | print('Epoch {}/{}, Iter {}/{}, test loss = {:.3f}, time = {:.3f}'.format(epoch_idx, args.epochs, batch_idx, 159 | len(TestImgLoader), loss, 160 | time.time() - start_time)) 161 | save_scalars(logger, 'fulltest', avg_test_scalars.mean(), global_step) 162 | print("avg_test_scalars:", avg_test_scalars.mean()) 163 | # gc.collect() 164 | 165 | 166 | def test(): 167 | avg_test_scalars = DictAverageMeter() 168 | for batch_idx, sample in enumerate(TestImgLoader): 169 | start_time = time.time() 170 | loss, scalar_outputs, image_outputs = test_sample(sample, detailed_summary=True) 171 | avg_test_scalars.update(scalar_outputs) 172 | del scalar_outputs, image_outputs 173 | print('Iter {}/{}, test loss = {:.3f}, time = {:.3f}'.format(batch_idx, len(TestImgLoader), loss, 174 | time.time() - start_time)) 175 | if batch_idx % 100 == 0: 176 | print("Iter {}/{}, test results = {}".format(batch_idx, len(TestImgLoader), avg_test_scalars.mean())) 177 | print("final", avg_test_scalars.mean()) 178 | 179 | 180 | def train_sample(sample, detailed_summary=False): 181 | model.train() 182 | optimizer.zero_grad() 183 | 184 | sample_cuda = tocuda(sample) 185 | depth_gt = sample_cuda["depth"] 186 | mask = sample_cuda["mask"] 187 | 188 | outputs = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"]) 189 | depth_est = outputs["depth"] 190 | 191 | loss = model_loss(depth_est, depth_gt, mask) 192 | loss.backward() 193 | optimizer.step() 194 | 195 | scalar_outputs = {"loss": loss} 196 | image_outputs = {"depth_est": depth_est * mask, "depth_gt": sample["depth"], 197 | "ref_img": sample["imgs"][:, 0], 198 | "mask": sample["mask"]} 199 | if detailed_summary: 200 | image_outputs["errormap"] = (depth_est - depth_gt).abs() * mask 201 | 
scalar_outputs["abs_depth_error"] = AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5) 202 | scalar_outputs["thres2mm_error"] = Thres_metrics(depth_est, depth_gt, mask > 0.5, 2) 203 | scalar_outputs["thres4mm_error"] = Thres_metrics(depth_est, depth_gt, mask > 0.5, 4) 204 | scalar_outputs["thres8mm_error"] = Thres_metrics(depth_est, depth_gt, mask > 0.5, 8) 205 | 206 | return tensor2float(loss), tensor2float(scalar_outputs), image_outputs 207 | 208 | 209 | @make_nograd_func 210 | def test_sample(sample, detailed_summary=True): 211 | model.eval() 212 | 213 | sample_cuda = tocuda(sample) 214 | depth_gt = sample_cuda["depth"] 215 | mask = sample_cuda["mask"] 216 | 217 | outputs = model(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"]) 218 | depth_est = outputs["depth"] 219 | 220 | loss = model_loss(depth_est, depth_gt, mask) 221 | 222 | scalar_outputs = {"loss": loss} 223 | image_outputs = {"depth_est": depth_est * mask, "depth_gt": sample["depth"], 224 | "ref_img": sample["imgs"][:, 0], 225 | "mask": sample["mask"]} 226 | if detailed_summary: 227 | image_outputs["errormap"] = (depth_est - depth_gt).abs() * mask 228 | 229 | scalar_outputs["abs_depth_error"] = AbsDepthError_metrics(depth_est, depth_gt, mask > 0.5) 230 | scalar_outputs["thres2mm_error"] = Thres_metrics(depth_est, depth_gt, mask > 0.5, 2) 231 | scalar_outputs["thres4mm_error"] = Thres_metrics(depth_est, depth_gt, mask > 0.5, 4) 232 | scalar_outputs["thres8mm_error"] = Thres_metrics(depth_est, depth_gt, mask > 0.5, 8) 233 | 234 | return tensor2float(loss), tensor2float(scalar_outputs), image_outputs 235 | 236 | 237 | def profile(): 238 | warmup_iter = 5 239 | iter_dataloader = iter(TestImgLoader) 240 | 241 | @make_nograd_func 242 | def do_iteration(): 243 | torch.cuda.synchronize() 244 | torch.cuda.synchronize() 245 | start_time = time.perf_counter() 246 | test_sample(next(iter_dataloader), detailed_summary=True) 247 | torch.cuda.synchronize() 248 | end_time = time.perf_counter() 249 | return end_time - start_time 250 | 251 | for i in range(warmup_iter): 252 | t = do_iteration() 253 | print('WarpUp Iter {}, time = {:.4f}'.format(i, t)) 254 | 255 | with torch.autograd.profiler.profile(enabled=True, use_cuda=True) as prof: 256 | for i in range(5): 257 | t = do_iteration() 258 | print('Profile Iter {}, time = {:.4f}'.format(i, t)) 259 | time.sleep(0.02) 260 | 261 | if prof is not None: 262 | # print(prof) 263 | trace_fn = 'chrome-trace.bin' 264 | prof.export_chrome_trace(trace_fn) 265 | print("chrome trace file is written to: ", trace_fn) 266 | 267 | 268 | if __name__ == '__main__': 269 | if args.mode == "train": 270 | train() 271 | elif args.mode == "test": 272 | test() 273 | elif args.mode == "profile": 274 | profile() 275 | -------------------------------------------------------------------------------- /train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | MVS_TRAINING="/home/xyguo/dataset.ssd/dtu_mvs/processed/mvs_training/dtu/" 3 | python train.py --dataset=dtu_yao --batch_size=4 --trainpath=$MVS_TRAINING --trainlist lists/dtu/train.txt --testlist lists/dtu/test.txt --numdepth=192 --logdir ./checkpoints/d192 $@ 4 | -------------------------------------------------------------------------------- /utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torchvision.utils as vutils 3 | import torch 4 | import torch.nn.functional as F 5 | 6 | 7 | # print arguments 8 
| def print_args(args): 9 | print("################################ args ################################") 10 | for k, v in args.__dict__.items(): 11 | print("{0: <10}\t{1: <30}\t{2: <20}".format(k, str(v), str(type(v)))) 12 | print("########################################################################") 13 | 14 | 15 | # torch.no_grad wrapper for functions 16 | def make_nograd_func(func): 17 | def wrapper(*f_args, **f_kwargs): 18 | with torch.no_grad(): 19 | ret = func(*f_args, **f_kwargs) 20 | return ret 21 | 22 | return wrapper 23 | 24 | 25 | # convert a function into recursive style to handle nested dict/list/tuple variables 26 | def make_recursive_func(func): 27 | def wrapper(vars): 28 | if isinstance(vars, list): 29 | return [wrapper(x) for x in vars] 30 | elif isinstance(vars, tuple): 31 | return tuple([wrapper(x) for x in vars]) 32 | elif isinstance(vars, dict): 33 | return {k: wrapper(v) for k, v in vars.items()} 34 | else: 35 | return func(vars) 36 | 37 | return wrapper 38 | 39 | 40 | @make_recursive_func 41 | def tensor2float(vars): 42 | if isinstance(vars, float): 43 | return vars 44 | elif isinstance(vars, torch.Tensor): 45 | return vars.data.item() 46 | else: 47 | raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars))) 48 | 49 | 50 | @make_recursive_func 51 | def tensor2numpy(vars): 52 | if isinstance(vars, np.ndarray): 53 | return vars 54 | elif isinstance(vars, torch.Tensor): 55 | return vars.detach().cpu().numpy().copy() 56 | else: 57 | raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) 58 | 59 | 60 | @make_recursive_func 61 | def tocuda(vars): 62 | if isinstance(vars, torch.Tensor): 63 | return vars.cuda() 64 | elif isinstance(vars, str): 65 | return vars 66 | else: 67 | raise NotImplementedError("invalid input type {} for tocuda".format(type(vars))) 68 | 69 | 70 | def save_scalars(logger, mode, scalar_dict, global_step): 71 | scalar_dict = tensor2float(scalar_dict) 72 | for key, value in scalar_dict.items(): 73 | if not isinstance(value, (list, tuple)): 74 | name = '{}/{}'.format(mode, key) 75 | logger.add_scalar(name, value, global_step) 76 | else: 77 | for idx in range(len(value)): 78 | name = '{}/{}_{}'.format(mode, key, idx) 79 | logger.add_scalar(name, value[idx], global_step) 80 | 81 | 82 | def save_images(logger, mode, images_dict, global_step): 83 | images_dict = tensor2numpy(images_dict) 84 | 85 | def preprocess(name, img): 86 | if not (len(img.shape) == 3 or len(img.shape) == 4): 87 | raise NotImplementedError("invalid img shape {}:{} in save_images".format(name, img.shape)) 88 | if len(img.shape) == 3: 89 | img = img[:, np.newaxis, :, :] 90 | img = torch.from_numpy(img[:1]) 91 | return vutils.make_grid(img, padding=0, nrow=1, normalize=True, scale_each=True) 92 | 93 | for key, value in images_dict.items(): 94 | if not isinstance(value, (list, tuple)): 95 | name = '{}/{}'.format(mode, key) 96 | logger.add_image(name, preprocess(name, value), global_step) 97 | else: 98 | for idx in range(len(value)): 99 | name = '{}/{}_{}'.format(mode, key, idx) 100 | logger.add_image(name, preprocess(name, value[idx]), global_step) 101 | 102 | 103 | class DictAverageMeter(object): 104 | def __init__(self): 105 | self.data = {} 106 | self.count = 0 107 | 108 | def update(self, new_input): 109 | self.count += 1 110 | if len(self.data) == 0: 111 | for k, v in new_input.items(): 112 | if not isinstance(v, float): 113 | raise NotImplementedError("invalid data {}: {}".format(k, type(v))) 114 | self.data[k] = 
v 115 | else: 116 | for k, v in new_input.items(): 117 | if not isinstance(v, float): 118 | raise NotImplementedError("invalid data {}: {}".format(k, type(v))) 119 | self.data[k] += v 120 | 121 | def mean(self): 122 | return {k: v / self.count for k, v in self.data.items()} 123 | 124 | 125 | # a wrapper to compute metrics for each image individually 126 | def compute_metrics_for_each_image(metric_func): 127 | def wrapper(depth_est, depth_gt, mask, *args): 128 | batch_size = depth_gt.shape[0] 129 | results = [] 130 | # compute result one by one 131 | for idx in range(batch_size): 132 | ret = metric_func(depth_est[idx], depth_gt[idx], mask[idx], *args) 133 | results.append(ret) 134 | return torch.stack(results).mean() 135 | 136 | return wrapper 137 | 138 | 139 | @make_nograd_func 140 | @compute_metrics_for_each_image 141 | def Thres_metrics(depth_est, depth_gt, mask, thres): 142 | assert isinstance(thres, (int, float)) 143 | depth_est, depth_gt = depth_est[mask], depth_gt[mask] 144 | errors = torch.abs(depth_est - depth_gt) 145 | err_mask = errors > thres 146 | return torch.mean(err_mask.float()) 147 | 148 | 149 | # NOTE: please do not use this to build up training loss 150 | @make_nograd_func 151 | @compute_metrics_for_each_image 152 | def AbsDepthError_metrics(depth_est, depth_gt, mask): 153 | depth_est, depth_gt = depth_est[mask], depth_gt[mask] 154 | return torch.mean((depth_est - depth_gt).abs()) 155 | --------------------------------------------------------------------------------