├── README.md ├── assets ├── motivation.png └── pipeline.png ├── datasets ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── blendedmvs.cpython-37.pyc │ ├── data_io.cpython-37.pyc │ ├── dtu_yao.cpython-37.pyc │ └── general_eval.cpython-37.pyc ├── blendedmvs.py ├── data_io.py ├── dtu_yao.py ├── general_eval.py └── lists │ ├── blendedmvs │ ├── all_list.txt │ ├── training_list.txt │ └── validation_list.txt │ └── dtu │ ├── single.txt │ ├── test.txt │ ├── train.txt │ ├── trainval.txt │ └── val.txt ├── filter ├── __init__.py ├── __pycache__ │ ├── __init__.cpython-37.pyc │ ├── dypcd_tanks.cpython-37.pyc │ ├── pcd.cpython-37.pyc │ └── tank_test_config.cpython-37.pyc ├── dypcd_tanks.py ├── pcd.py └── tank_test_config.py ├── loss.py ├── main.py ├── model.py ├── networks ├── module.py └── mvsnet.py ├── scripts ├── blendedmvs_finetune.sh ├── dtu_test.sh ├── evaluation_dtu │ ├── BaseEval2Obj_web.m │ ├── BaseEvalMain_web.m │ ├── ComputeStat_web.m │ ├── MaxDistCP.m │ ├── PointCompareMain.m │ ├── plyread.m │ └── reducePts_haa.m ├── tank_test.sh └── train.sh └── tools.py /README.md: -------------------------------------------------------------------------------- 1 | # (ICCV 2023) Constraining Depth Map Geometry for Multi-View Stereo: A Dual-Depth Approach with Saddle-shaped Depth Cells 2 | 3 | - Xinyi Ye, Weiyue Zhao, Tianqi Liu, Zihao Huang, Zhiguo Cao, Xin Li 4 | 5 | ## [Paper](https://openaccess.thecvf.com/content/ICCV2023/html/Ye_Constraining_Depth_Map_Geometry_for_Multi-View_Stereo_A_Dual-Depth_Approach_ICCV_2023_paper.html) | Project Page | [Arxiv](https://arxiv.org/abs/2307.09160) | [Model](https://pan.baidu.com/s/1sw4lkIzoOymBJNFp622iVA) | [Points](https://pan.baidu.com/s/1bos9KatNs7WlrE3-JbNvqA) 6 | ![image](assets/pipeline.png) 7 | 8 | # Highlights 9 | ![image](assets/motivation.png) 10 | 11 | In this work, **we propose a fresh viewpoint for considering depth geometry in multi-view stereo, a factor that has not been adequately addressed in prior works**. We demonstrate, both qualitatively and quantitatively, that different depth geometries suffer from significant performance gaps in the MVS reconstruction task, even for the same depth estimation error. Based on this concept, we propose, for the first time, a depth geometry composed of saddle-shaped cells, together with a dual-depth approach that constrains the depth map to approach this geometry. 12 | 13 | # Abstract 14 | 15 | Learning-based multi-view stereo (MVS) methods deal with predicting accurate depth maps to achieve an accurate and complete 3D representation. Despite their excellent performance, existing methods ignore the fact that a suitable depth geometry is also critical in MVS. In this paper, we demonstrate that different depth geometries have significant performance gaps, even using the same depth prediction error. Therefore, we introduce an ideal depth geometry composed of **Saddle-Shaped Cells**, whose predicted depth map oscillates upward and downward around the ground-truth surface, rather than maintaining a continuous and smooth depth plane. To achieve this, we develop a coarse-to-fine framework called Dual-MVSNet (DMVSNet), which can produce an oscillating depth plane. Technically, we predict two depth values for each pixel (**Dual-Depth**), and propose a novel loss function and a checkerboard-shaped selecting strategy to constrain the predicted depth geometry. Compared to existing methods, DMVSNet achieves a high rank on the DTU benchmark and obtains the top performance on the challenging scenes of Tanks and Temples, demonstrating its strong performance and generalization ability. Our method also points to a new research direction for considering depth geometry in MVS.
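For intuition, the checkerboard-shaped selection can be pictured as alternating between the two predicted depth values pixel by pixel, so that neighbouring depths oscillate around the underlying surface. The snippet below is only a minimal illustration of this idea; the function name and tensor shapes are assumptions made for the example, not the released implementation.

```python
import torch

def checkerboard_select(depth_a: torch.Tensor, depth_b: torch.Tensor) -> torch.Tensor:
    # depth_a, depth_b: [B, H, W] dual-depth predictions for the same pixels.
    # Take depth_a on "black" squares and depth_b on "white" squares so that
    # adjacent pixels alternate between the two hypotheses.
    _, h, w = depth_a.shape
    ys = torch.arange(h, device=depth_a.device).view(1, h, 1)
    xs = torch.arange(w, device=depth_a.device).view(1, 1, w)
    checker = (ys + xs) % 2 == 0
    return torch.where(checker, depth_a, depth_b)
```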
16 | 17 | # Prepare Data 18 | #### 1. DTU Dataset 19 | 20 | **Training Data**. We adopt the full-resolution ground-truth depth provided in CasMVSNet or MVSNet. Download [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view) and [Depth raw](https://virutalbuy-public.oss-cn-hangzhou.aliyuncs.com/share/cascade-stereo/CasMVSNet/dtu_data/dtu_train_hr/Depths_raw.zip). 21 | Unzip them and put `Depths_raw` into the `dtu_training` folder. The structure is as follows: 22 | ``` 23 | dtu_training 24 | ├── Cameras 25 | ├── Depths 26 | ├── Depths_raw 27 | └── Rectified 28 | ``` 29 | **Testing Data**. Download [DTU testing data](https://drive.google.com/file/d/135oKPefcPTsdtLRzoDAQtPpHuoIrpRI_/view) and unzip it. The structure is as follows: 30 | ``` 31 | dtu_testing 32 | ├── Cameras 33 | ├── scan1 34 | ├── scan2 35 | ├── ... 36 | ``` 37 | 38 | #### 2. BlendedMVS Dataset 39 | 40 | **Training Data** and **Validation Data**. Download [BlendedMVS](https://drive.google.com/file/d/1ilxls-VJNvJnB7IaFj7P0ehMPr7ikRCb/view) and 41 | unzip it. We only adopt 42 | BlendedMVS for finetuning and do not test on it. The structure is as follows: 43 | ``` 44 | blendedmvs 45 | ├── 5a0271884e62597cdee0d0eb 46 | ├── 5a3ca9cb270f0e3f14d0eddb 47 | ├── ... 48 | ├── training_list.txt 49 | ├── ... 50 | ``` 51 | 52 | #### 3. Tanks and Temples Dataset 53 | 54 | **Testing Data**. Download [Tanks and Temples](https://drive.google.com/file/d/1YArOJaX9WVLJh4757uE8AEREYkgszrCo/view) and 55 | unzip it. Here, we adopt the camera parameters of the short depth range version (included in your download); therefore, **you should 56 | manually replace the `cams` folder in the `intermediate` folder with the short depth range version.** The 57 | structure is as follows: 58 | ``` 59 | tanksandtemples 60 | ├── advanced 61 | │ ├── Auditorium 62 | │ ├── ... 63 | └── intermediate 64 | ├── Family 65 | ├── ... 66 | ``` 67 | # Environment 68 | - PyTorch 1.8.1 69 | - Python 3.7 70 | - progressbar 2.5 71 | - thop 0.1 72 | 73 | # Scripts 74 | #### 1. train on DTU 75 | - modify `datapath` in `scripts/train.sh` 76 | ```bash 77 | bash scripts/train.sh 78 | ``` 79 | #### 2. evaluate on DTU 80 | - modify `datapath` and `resume` in `scripts/dtu_test.sh` 81 | ```bash 82 | bash scripts/dtu_test.sh 83 | ``` 84 | - modify `datapath`, `plyPath`, `resultsPath` in `scripts/evaluation_dtu/BaseEvalMain_web.m` 85 | - modify `datapath`, `resultsPath` in `scripts/evaluation_dtu/ComputeStat_web.m` 86 | ``` 87 | cd ./scripts/evaluation_dtu/ 88 | matlab -nodisplay 89 | 90 | BaseEvalMain_web 91 | 92 | ComputeStat_web 93 | ``` 94 | #### 3. finetune on BlendedMVS 95 | - modify `datapath` and `resume` in `scripts/blendedmvs_finetune.sh` 96 | ```bash 97 | bash scripts/blendedmvs_finetune.sh 98 | ``` 99 |
100 | #### 4. evaluate on Tanks and Temples 101 | - modify `datapath` and `resume` in `scripts/tank_test.sh` 102 | 103 | #### 5. points and model 104 | - [Points](https://pan.baidu.com/s/1bos9KatNs7WlrE3-JbNvqA) (extraction code: 2ygz) 105 | - [Model](https://pan.baidu.com/s/1sw4lkIzoOymBJNFp622iVA) (extraction code: 8lly) 106 | # Citation 107 | ```bibtex 108 | @inproceedings{Ye2023Dmvsnet, 109 | title={Constraining Depth Map Geometry for Multi-View Stereo: A Dual-Depth Approach with Saddle-shaped Depth Cells}, 110 | author={Xinyi Ye and Weiyue Zhao and Tianqi Liu and Zihao Huang and Zhiguo Cao and Xin Li}, 111 | booktitle={Proceedings of the IEEE/CVF International Conference on Computer Vision}, 112 | year={2023} 113 | } 114 | ``` 115 | 116 | # Acknowledgement 117 | We have incorporated some released code from [UniMVSNet](https://github.com/prstrive/UniMVSNet) and extend our gratitude for their excellent work. -------------------------------------------------------------------------------- /assets/motivation.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DIVE128/DMVSNet/d37a0b2da8017ac88d48208e4f062c72248c48d5/assets/motivation.png -------------------------------------------------------------------------------- /assets/pipeline.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DIVE128/DMVSNet/d37a0b2da8017ac88d48208e4f062c72248c48d5/assets/pipeline.png -------------------------------------------------------------------------------- /datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils import data 3 | import torch.distributed as dist 4 | from torch.utils.data import RandomSampler, SequentialSampler 5 | 6 | from .dtu_yao import MVSDataset as DtuDataset 7 | from .general_eval import MVSDataset as EvalDataset 8 | from .blendedmvs import BlendedMVSDataset 9 | 10 | 11 | def get_loader(args, datapath, listfile, nviews, mode="train", force_test=False): 12 | 13 | 14 | if args.dataset_name == "dtu_yao": 15 | dataset = DtuDataset(datapath, listfile, mode, nviews, args.img_size, args.numdepth, args.interval_scale) 16 | elif args.dataset_name == "general_eval": 17 | dataset = EvalDataset(datapath, listfile, mode, nviews, args.numdepth, args.interval_scale, args.inverse_depth, 18 | max_h=args.max_h, max_w=args.max_w, fix_res=args.fix_res) 19 | elif args.dataset_name == "blendedmvs": 20 | dataset = BlendedMVSDataset(datapath, listfile, mode, nviews, args.numdepth, args.interval_scale) 21 | else: 22 | raise NotImplementedError("Don't support dataset: {}".format(args.dataset_name)) 23 | 24 | if args.distributed: 25 | sampler = torch.utils.data.DistributedSampler(dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()) 26 | else: 27 | sampler = RandomSampler(dataset) if (mode == "train") else SequentialSampler(dataset) 28 | 29 | data_loader = data.DataLoader(dataset, args.batch_size, sampler=sampler, num_workers=4, drop_last=(mode == "train"), pin_memory=True) 30 | 31 | return data_loader, sampler 32 | -------------------------------------------------------------------------------- /datasets/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DIVE128/DMVSNet/d37a0b2da8017ac88d48208e4f062c72248c48d5/datasets/__pycache__/__init__.cpython-37.pyc
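For reference, the following is a minimal sketch (not part of the repository) of how the `get_loader` helper defined above in `datasets/__init__.py` might be invoked for DTU training; the argument values are illustrative assumptions only.

```python
from argparse import Namespace
from datasets import get_loader

# Hypothetical arguments mirroring the fields that get_loader reads for "dtu_yao".
args = Namespace(dataset_name="dtu_yao", img_size=[512, 640], numdepth=192,
                 interval_scale=1.06, distributed=False, batch_size=2)
train_loader, train_sampler = get_loader(args, datapath="<path/to/dtu_training>",
                                         listfile="datasets/lists/dtu/train.txt",
                                         nviews=5, mode="train")
```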
-------------------------------------------------------------------------------- /datasets/__pycache__/blendedmvs.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DIVE128/DMVSNet/d37a0b2da8017ac88d48208e4f062c72248c48d5/datasets/__pycache__/blendedmvs.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/data_io.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DIVE128/DMVSNet/d37a0b2da8017ac88d48208e4f062c72248c48d5/datasets/__pycache__/data_io.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/dtu_yao.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DIVE128/DMVSNet/d37a0b2da8017ac88d48208e4f062c72248c48d5/datasets/__pycache__/dtu_yao.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/__pycache__/general_eval.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DIVE128/DMVSNet/d37a0b2da8017ac88d48208e4f062c72248c48d5/datasets/__pycache__/general_eval.cpython-37.pyc -------------------------------------------------------------------------------- /datasets/blendedmvs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | from PIL import Image 4 | import numpy as np 5 | from torch.utils.data import Dataset 6 | from torchvision import transforms 7 | 8 | from datasets.data_io import * 9 | 10 | 11 | def motion_blur(img: np.ndarray, max_kernel_size=3): 12 | # Either vertial, hozirontal or diagonal blur 13 | mode = np.random.choice(['h', 'v', 'diag_down', 'diag_up']) 14 | ksize = np.random.randint(0, (max_kernel_size + 1) / 2) * 2 + 1 # make sure is odd 15 | center = int((ksize - 1) / 2) 16 | kernel = np.zeros((ksize, ksize)) 17 | if mode == 'h': 18 | kernel[center, :] = 1. 19 | elif mode == 'v': 20 | kernel[:, center] = 1. 21 | elif mode == 'diag_down': 22 | kernel = np.eye(ksize) 23 | elif mode == 'diag_up': 24 | kernel = np.flip(np.eye(ksize), 0) 25 | var = ksize * ksize / 16. 26 | grid = np.repeat(np.arange(ksize)[:, np.newaxis], ksize, axis=-1) 27 | gaussian = np.exp(-(np.square(grid - center) + np.square(grid.T - center)) / (2. 
* var)) 28 | kernel *= gaussian 29 | kernel /= np.sum(kernel) 30 | img = cv2.filter2D(img, -1, kernel) 31 | return img 32 | 33 | 34 | class BlendedMVSDataset(Dataset): 35 | def __init__(self, datapath, listfile, mode, nviews, ndepths=128, interval_scale=1.06): 36 | super(BlendedMVSDataset, self).__init__() 37 | 38 | self.datapath = datapath 39 | self.listfile = listfile 40 | self.mode = mode 41 | self.nviews = nviews 42 | self.ndepths = ndepths 43 | self.interval_scale = interval_scale 44 | self.metas = self.build_list() 45 | self.transform = transforms.ColorJitter(brightness=0.25, contrast=(0.3, 1.5)) 46 | 47 | def build_list(self): 48 | metas = [] 49 | proj_list = open(self.listfile).read().splitlines() 50 | 51 | for data_name in proj_list: 52 | dataset_folder = os.path.join(self.datapath, data_name) 53 | 54 | # read cluster 55 | cluster_path = os.path.join(dataset_folder, 'cams', 'pair.txt') 56 | cluster_lines = open(cluster_path).read().splitlines() 57 | image_num = int(cluster_lines[0]) 58 | 59 | # get per-image info 60 | for idx in range(0, image_num): 61 | 62 | ref_id = int(cluster_lines[2 * idx + 1]) 63 | cluster_info = cluster_lines[2 * idx + 2].rstrip().split() 64 | total_view_num = int(cluster_info[0]) 65 | if total_view_num < self.nviews - 1: 66 | continue 67 | 68 | src_ids = [int(x) for x in cluster_info[1::2]] 69 | 70 | metas.append((data_name, ref_id, src_ids)) 71 | 72 | return metas 73 | 74 | def __len__(self): 75 | return len(self.metas) 76 | 77 | def read_img(self, filename): 78 | img = Image.open(filename) 79 | if self.mode == "train": 80 | img = self.transform(img) 81 | img = motion_blur(np.array(img, dtype=np.float32)) 82 | 83 | # scale 0~255 to 0~1 84 | np_img = np.array(img, dtype=np.float32) / 255.0 85 | return np_img 86 | 87 | def read_cam(self, filename): 88 | with open(filename) as f: 89 | lines = f.readlines() 90 | lines = [line.rstrip() for line in lines] 91 | # extrinsics: line [1,5), 4x4 matrix 92 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 93 | # intrinsics: line [7-10), 3x3 matrix 94 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 95 | # depth_min & depth_interval: line 11 96 | depth_min = float(lines[11].split()[0]) 97 | depth_interval = float(lines[11].split()[1]) * self.interval_scale 98 | # depth_sample_num = float(lines[11].split()[2]) 99 | # depth_max = float(lines[11].split()[3]) 100 | return intrinsics, extrinsics, depth_min, depth_interval 101 | 102 | def read_mask(self, filename): 103 | masked_img = np.array(Image.open(filename), dtype=np.float32) 104 | mask = np.any(masked_img > 10, axis=2).astype(np.float32) 105 | 106 | h, w = mask.shape 107 | mask_ms = { 108 | "stage1": cv2.resize(mask, (w // 4, h // 4), interpolation=cv2.INTER_NEAREST), 109 | "stage2": cv2.resize(mask, (w // 2, h // 2), interpolation=cv2.INTER_NEAREST), 110 | "stage3": mask, 111 | } 112 | return mask_ms 113 | 114 | def read_depth_and_mask(self, filename, depth_min): 115 | # read pfm depth file 116 | # (576, 768) 117 | depth = np.array(read_pfm(filename)[0], dtype=np.float32) 118 | mask = np.array(depth >= depth_min, dtype=np.float32) 119 | 120 | h, w = depth.shape 121 | mask_ms = { 122 | "stage1": cv2.resize(mask, (w // 4, h // 4), interpolation=cv2.INTER_NEAREST), 123 | "stage2": cv2.resize(mask, (w // 2, h // 2), interpolation=cv2.INTER_NEAREST), 124 | "stage3": mask, 125 | } 126 | depth_ms = { 127 | "stage1": cv2.resize(depth, (w // 4, h // 4), interpolation=cv2.INTER_NEAREST), 
128 | "stage2": cv2.resize(depth, (w // 2, h // 2), interpolation=cv2.INTER_NEAREST), 129 | "stage3": depth, 130 | } 131 | return depth_ms, mask_ms 132 | 133 | def __getitem__(self, idx): 134 | data_name, ref_id, src_ids = self.metas[idx] 135 | view_ids = [ref_id] + src_ids[:self.nviews - 1] 136 | 137 | imgs = [] 138 | img_paths = [] 139 | proj_matrices = [] 140 | mask_ms, depth_ms, depth_values = None, None, None 141 | 142 | for i, vid in enumerate(view_ids): 143 | img_path = os.path.join(self.datapath, data_name, 'blended_images', '%08d.jpg' % vid) 144 | cam_path = os.path.join(self.datapath, data_name, 'cams', '%08d_cam.txt' % vid) 145 | img_paths.append(img_path) 146 | 147 | img = self.read_img(img_path) 148 | intrinsics, extrinsics, depth_min, depth_interval = self.read_cam(cam_path) 149 | 150 | proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) # 151 | proj_mat[0, :4, :4] = extrinsics 152 | proj_mat[1, :3, :3] = intrinsics 153 | 154 | imgs.append(img) 155 | proj_matrices.append(proj_mat) 156 | 157 | if i == 0: 158 | ref_depth_path = os.path.join(self.datapath, data_name, 'rendered_depth_maps', '%08d.pfm' % vid) 159 | depth_ms, mask_ms = self.read_depth_and_mask(ref_depth_path, depth_min) 160 | 161 | # ref_masked_img_path = os.path.join(self.datapath, data_name, 'blended_images', '%08d_masked.jpg' % vid) 162 | # mask_ms = self.read_mask(ref_masked_img_path) 163 | 164 | # -0.5 to prevent blendedmvs bug 165 | # get depth values 166 | depth_values = np.arange(depth_min, depth_interval * (self.ndepths - 0.5) + depth_min, depth_interval, dtype=np.float32) 167 | 168 | imgs = np.stack(imgs).transpose([0, 3, 1, 2]) 169 | # ms proj_mats 170 | proj_matrices = np.stack(proj_matrices) 171 | stage1_pjmats = proj_matrices.copy() 172 | stage1_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 0.25 173 | stage2_pjmats = proj_matrices.copy() 174 | stage2_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 0.5 175 | 176 | proj_matrices_ms = { 177 | "stage1": stage1_pjmats, 178 | "stage2": stage2_pjmats, 179 | "stage3": proj_matrices 180 | } 181 | 182 | return {"imgs": imgs, 183 | "img_paths": img_paths, 184 | "proj_matrices": proj_matrices_ms, 185 | "depth": depth_ms, 186 | "depth_values": depth_values, 187 | "mask": mask_ms} 188 | -------------------------------------------------------------------------------- /datasets/data_io.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import re 3 | import sys 4 | 5 | 6 | def read_pfm(filename): 7 | file = open(filename, 'rb') 8 | color = None 9 | width = None 10 | height = None 11 | scale = None 12 | endian = None 13 | 14 | header = file.readline().decode('utf-8').rstrip() 15 | if header == 'PF': 16 | color = True 17 | elif header == 'Pf': 18 | color = False 19 | else: 20 | raise Exception('Not a PFM file.') 21 | 22 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 23 | if dim_match: 24 | width, height = map(int, dim_match.groups()) 25 | else: 26 | raise Exception('Malformed PFM header.') 27 | 28 | scale = float(file.readline().rstrip()) 29 | if scale < 0: # little-endian 30 | endian = '<' 31 | scale = -scale 32 | else: 33 | endian = '>' # big-endian 34 | 35 | data = np.fromfile(file, endian + 'f') 36 | shape = (height, width, 3) if color else (height, width) 37 | 38 | data = np.reshape(data, shape) 39 | data = np.flipud(data) 40 | file.close() 41 | return data, scale 42 | 43 | 44 | def save_pfm(filename, image, scale=1): 45 | file = open(filename, "wb") 46 | color = 
None 47 | 48 | image = np.flipud(image) 49 | 50 | if image.dtype.name != 'float32': 51 | raise Exception('Image dtype must be float32.') 52 | 53 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 54 | color = True 55 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 56 | color = False 57 | else: 58 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 59 | 60 | file.write('PF\n'.encode('utf-8') if color else 'Pf\n'.encode('utf-8')) 61 | file.write('{} {}\n'.format(image.shape[1], image.shape[0]).encode('utf-8')) 62 | 63 | endian = image.dtype.byteorder 64 | 65 | if endian == '<' or endian == '=' and sys.byteorder == 'little': 66 | scale = -scale 67 | 68 | file.write(('%f\n' % scale).encode('utf-8')) 69 | 70 | image.tofile(file) 71 | file.close() 72 | 73 | import random, cv2 74 | class RandomCrop(object): 75 | def __init__(self, CropSize=0.1): 76 | self.CropSize = CropSize 77 | 78 | def __call__(self, image, normal): 79 | h, w = normal.shape[:2] 80 | img_h, img_w = image.shape[:2] 81 | CropSize_w, CropSize_h = max(1, int(w * self.CropSize)), max(1, int(h * self.CropSize)) 82 | x1, y1 = random.randint(0, CropSize_w), random.randint(0, CropSize_h) 83 | x2, y2 = random.randint(w - CropSize_w, w), random.randint(h - CropSize_h, h) 84 | 85 | normal_crop = normal[y1:y2, x1:x2] 86 | normal_resize = cv2.resize(normal_crop, (w, h), interpolation=cv2.INTER_NEAREST) 87 | 88 | image_crop = image[4*y1:4*y2, 4*x1:4*x2] 89 | image_resize = cv2.resize(image_crop, (img_w, img_h), interpolation=cv2.INTER_LINEAR) 90 | 91 | # import matplotlib.pyplot as plt 92 | # plt.subplot(2, 3, 1) 93 | # plt.imshow(image) 94 | # plt.subplot(2, 3, 2) 95 | # plt.imshow(image_crop) 96 | # plt.subplot(2, 3, 3) 97 | # plt.imshow(image_resize) 98 | # 99 | # plt.subplot(2, 3, 4) 100 | # plt.imshow((normal + 1.0) / 2, cmap="rainbow") 101 | # plt.subplot(2, 3, 5) 102 | # plt.imshow((normal_crop + 1.0) / 2, cmap="rainbow") 103 | # plt.subplot(2, 3, 6) 104 | # plt.imshow((normal_resize + 1.0) / 2, cmap="rainbow") 105 | # plt.show() 106 | # plt.pause(1) 107 | # plt.close() 108 | 109 | return image_resize, normal_resize 110 | def cv2_imread(filename): 111 | return cv2.imread(filename,0) -------------------------------------------------------------------------------- /datasets/dtu_yao.py: -------------------------------------------------------------------------------- 1 | import random 2 | from torch.utils.data import Dataset 3 | from torchvision import transforms 4 | import numpy as np 5 | import os, cv2, time, math 6 | from PIL import Image 7 | from datasets.data_io import * 8 | 9 | 10 | # the DTU dataset preprocessed by Yao Yao (only for training) 11 | class MVSDataset(Dataset): 12 | def __init__(self, datapath, listfile, mode, nviews, img_size=None, ndepths=192, interval_scale=1.06, **kwargs): 13 | super(MVSDataset, self).__init__() 14 | self.img_size = img_size if img_size is not None else [512, 640] 15 | assert self.img_size[0] % 32 == 0 and self.img_size[1] % 32 == 0, 'img_wh must both be multiples of 32!' 
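        # img_size is kept a multiple of 32 so that the repeated 2x downsampling in the
        # feature extractor, and the 1/4- and 1/2-resolution stage outputs derived from it,
        # always have integer spatial sizes.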
16 | 17 | self.datapath = datapath 18 | self.listfile = listfile 19 | self.mode = mode 20 | self.nviews = nviews 21 | self.ndepths = ndepths 22 | self.interval_scale = interval_scale 23 | self.kwargs = kwargs 24 | print("mvsdataset kwargs", self.kwargs) 25 | 26 | assert self.mode in ["train", "val", "test"] 27 | self.metas = self.build_list() 28 | 29 | # # less data 30 | # np.random.shuffle(self.metas) 31 | # self.metas = self.metas[:int(0.5 * len(self.metas))] 32 | 33 | def build_list(self): 34 | metas = [] 35 | with open(self.listfile) as f: 36 | scans = f.readlines() 37 | scans = [line.rstrip() for line in scans] 38 | 39 | # scans 40 | for scan in scans: 41 | pair_file = "Cameras/pair.txt" 42 | # read the pair file 43 | with open(os.path.join(self.datapath, pair_file)) as f: 44 | num_viewpoint = int(f.readline()) 45 | # viewpoints (49) 46 | for view_idx in range(num_viewpoint): 47 | ref_view = int(f.readline().rstrip()) 48 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 49 | # light conditions 0-6 50 | for light_idx in range(7): 51 | metas.append((scan, light_idx, ref_view, src_views)) 52 | print("dataset", self.mode, "metas:", len(metas)) 53 | return metas 54 | 55 | def __len__(self): 56 | return len(self.metas) 57 | 58 | def read_cam_file(self, filename): 59 | with open(filename) as f: 60 | lines = f.readlines() 61 | lines = [line.rstrip() for line in lines] 62 | # extrinsics: line [1,5), 4x4 matrix 63 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 64 | # intrinsics: line [7-10), 3x3 matrix 65 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 66 | # depth_min & depth_interval: line 11 67 | depth_min = float(lines[11].split()[0]) 68 | depth_interval = float(lines[11].split()[1]) * self.interval_scale 69 | return intrinsics, extrinsics, depth_min, depth_interval 70 | 71 | def read_img(self, filename): 72 | img = Image.open(filename) 73 | # img = img.resize(self.img_size[::-1], Image.BILINEAR) 74 | # scale 0~255 to 0~1 75 | np_img = np.array(img, dtype=np.float32) / 255.0 76 | return np_img 77 | 78 | def prepare_img(self, hr_img): 79 | # w1600-h1200-> 800-600 ; crop -> 640, 512; downsample 1/4 -> 160, 128 80 | 81 | # downsample 82 | h, w = hr_img.shape 83 | hr_img_ds = cv2.resize(hr_img, (w // 2, h // 2), interpolation=cv2.INTER_NEAREST) 84 | # crop 85 | h, w = hr_img_ds.shape 86 | # target_h, target_w = 512, 640 87 | target_h, target_w = 512, 640 88 | 89 | start_h, start_w = (h - target_h) // 2, (w - target_w) // 2 90 | hr_img_crop = hr_img_ds[start_h: start_h + target_h, start_w: start_w + target_w] 91 | 92 | # #downsample 93 | # lr_img = cv2.resize(hr_img_crop, (target_w//4, target_h//4), interpolation=cv2.INTER_NEAREST) 94 | 95 | return hr_img_crop 96 | 97 | def read_mask_hr(self, filename): 98 | img = Image.open(filename) 99 | np_img = np.array(img, dtype=np.float32) 100 | np_img = (np_img > 10).astype(np.float32) 101 | np_img = self.prepare_img(np_img) 102 | 103 | h, w = np_img.shape 104 | np_img_ms = { 105 | "stage1": cv2.resize(np_img, (w // 4, h // 4), interpolation=cv2.INTER_NEAREST), 106 | "stage2": cv2.resize(np_img, (w // 2, h // 2), interpolation=cv2.INTER_NEAREST), 107 | "stage3": np_img, 108 | } 109 | return np_img_ms 110 | 111 | def read_depth(self, filename): 112 | # read pfm depth file 113 | return np.array(read_pfm(filename)[0], dtype=np.float32) 114 | 115 | def read_depth_hr(self, filename): 116 | # read pfm depth file 117 | # w1600-h1200-> 800-600 ; crop 
-> 640, 512; downsample 1/4 -> 160, 128 118 | depth_hr = np.array(read_pfm(filename)[0], dtype=np.float32) 119 | depth_lr = self.prepare_img(depth_hr) 120 | 121 | h, w = depth_lr.shape 122 | depth_lr_ms = { 123 | "stage1": cv2.resize(depth_lr, (w // 4, h // 4), interpolation=cv2.INTER_NEAREST), 124 | "stage2": cv2.resize(depth_lr, (w // 2, h // 2), interpolation=cv2.INTER_NEAREST), 125 | "stage3": depth_lr, 126 | } 127 | return depth_lr_ms 128 | 129 | def __getitem__(self, idx): 130 | meta = self.metas[idx] 131 | scan, light_idx, ref_view, src_views = meta 132 | # use only the reference view and first nviews-1 source views 133 | view_ids = [ref_view] + src_views[:self.nviews - 1] 134 | 135 | imgs = [] 136 | mask = None 137 | depth_ms = None 138 | depth_values = None 139 | proj_matrices = [] 140 | 141 | for i, vid in enumerate(view_ids): 142 | # NOTE that the id in image file names is from 1 to 49 (not 0~48) 143 | img_filename = os.path.join(self.datapath, 144 | 'Rectified/{}_train/rect_{:0>3}_{}_r5000.png'.format(scan, vid + 1, light_idx)) 145 | 146 | mask_filename_hr = os.path.join(self.datapath, 'Depths_raw/{}/depth_visual_{:0>4}.png'.format(scan, vid)) 147 | depth_filename_hr = os.path.join(self.datapath, 'Depths_raw/{}/depth_map_{:0>4}.pfm'.format(scan, vid)) 148 | 149 | proj_mat_filename = os.path.join(self.datapath, 'Cameras/train/{:0>8}_cam.txt').format(vid) 150 | 151 | img = self.read_img(img_filename) 152 | 153 | intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename) 154 | 155 | proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) # 156 | proj_mat[0, :4, :4] = extrinsics 157 | proj_mat[1, :3, :3] = intrinsics 158 | 159 | proj_matrices.append(proj_mat) 160 | 161 | if i == 0: # reference view 162 | mask_read_ms = self.read_mask_hr(mask_filename_hr) 163 | depth_ms = self.read_depth_hr(depth_filename_hr) 164 | 165 | # get depth values 166 | depth_max = depth_interval * self.ndepths + depth_min 167 | depth_values = np.arange(depth_min, depth_max, depth_interval, dtype=np.float32) 168 | 169 | mask = mask_read_ms 170 | 171 | imgs.append(img) 172 | 173 | imgs = np.stack(imgs).transpose([0, 3, 1, 2]) 174 | # ms proj_mats 175 | proj_matrices = np.stack(proj_matrices) 176 | stage2_pjmats = proj_matrices.copy() 177 | stage2_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2 178 | stage3_pjmats = proj_matrices.copy() 179 | stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4 180 | 181 | proj_matrices_ms = { 182 | "stage1": proj_matrices, 183 | "stage2": stage2_pjmats, 184 | "stage3": stage3_pjmats 185 | } 186 | 187 | 188 | 189 | return {"imgs": imgs, 190 | "proj_matrices": proj_matrices_ms, 191 | "depth": depth_ms, 192 | "depth_values": depth_values, 193 | "mask": mask} 194 | -------------------------------------------------------------------------------- /datasets/general_eval.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.data import Dataset 3 | import numpy as np 4 | import os, cv2, time 5 | from PIL import Image, ImageEnhance 6 | from datasets.data_io import * 7 | global_base=32 8 | s_h, s_w = 0, 0 9 | class MVSDataset(Dataset): 10 | def __init__(self, datapath, listfile, mode, nviews, ndepths=192, interval_scale=1.06, inverse_depth=False, **kwargs): 11 | super(MVSDataset, self).__init__() 12 | self.datapath = datapath 13 | self.listfile = listfile 14 | self.mode = mode 15 | self.nviews = nviews 16 | self.ndepths = ndepths 17 | self.interval_scale = interval_scale 18 | self.max_h, 
self.max_w = kwargs["max_h"], kwargs["max_w"] 19 | self.fix_res = kwargs.get("fix_res", False) #whether to fix the resolution of input image. 20 | self.fix_wh = False 21 | self.inverse_depth = inverse_depth 22 | 23 | assert self.mode == "test" 24 | self.metas = self.build_list() 25 | 26 | def build_list(self): 27 | metas = [] 28 | scans = self.listfile 29 | 30 | interval_scale_dict = {} 31 | # scans 32 | for scan in scans: 33 | # determine the interval scale of each scene. default is 1.06 34 | if isinstance(self.interval_scale, float): 35 | interval_scale_dict[scan] = self.interval_scale 36 | else: 37 | interval_scale_dict[scan] = self.interval_scale[scan] 38 | 39 | pair_file = "{}/pair.txt".format(scan) 40 | # read the pair file 41 | with open(os.path.join(self.datapath, pair_file)) as f: 42 | num_viewpoint = int(f.readline()) 43 | # viewpoints 44 | for view_idx in range(num_viewpoint): 45 | ref_view = int(f.readline().rstrip()) 46 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 47 | # filter by no src view and fill to nviews 48 | if len(src_views) > 0: 49 | if len(src_views) < self.nviews-1: 50 | print("{}< src num_views:{}".format(len(src_views), self.nviews)) 51 | src_views += [src_views[0]] * (self.nviews - len(src_views)) 52 | metas.append((scan, ref_view, src_views, scan)) 53 | 54 | self.interval_scale = interval_scale_dict 55 | print("dataset", self.mode, "metas:", len(metas), "interval_scale:{}".format(self.interval_scale)) 56 | return metas 57 | 58 | def __len__(self): 59 | return len(self.metas) 60 | 61 | def read_cam_file(self, filename, interval_scale): 62 | with open(filename) as f: 63 | lines = f.readlines() 64 | lines = [line.rstrip() for line in lines] 65 | # extrinsics: line [1,5), 4x4 matrix 66 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 67 | # intrinsics: line [7-10), 3x3 matrix 68 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 69 | intrinsics[:2, :] /= 4.0 70 | # depth_min & depth_interval: line 11 71 | depth_min = float(lines[11].split()[0]) 72 | depth_interval = float(lines[11].split()[1]) 73 | 74 | if len(lines[11].split()) >= 3: 75 | num_depth = lines[11].split()[2] 76 | depth_max = depth_min + int(float(num_depth)) * depth_interval 77 | depth_interval = (depth_max - depth_min) / self.ndepths 78 | 79 | depth_interval *= interval_scale 80 | 81 | return intrinsics, extrinsics, depth_min, depth_interval 82 | 83 | def read_img(self, filename): 84 | img = Image.open(filename) 85 | 86 | # colorEnhancer = ImageEnhance.Sharpness(img) 87 | # img = colorEnhancer.enhance(40.0) 88 | # scale 0~255 to 0~1 89 | np_img = np.array(img, dtype=np.float32) / 255. 
90 | 91 | return np_img 92 | 93 | def read_depth(self, filename): 94 | # read pfm depth file 95 | return np.array(read_pfm(filename)[0], dtype=np.float32) 96 | 97 | def scale_mvs_input(self, img, intrinsics, max_w, max_h, base=global_base): 98 | h, w = img.shape[:2] 99 | if h > max_h or w > max_w: 100 | scale = 1.0 * max_h / h 101 | if scale * w > max_w: 102 | scale = 1.0 * max_w / w 103 | new_w, new_h = scale * w // base * base, scale * h // base * base 104 | else: 105 | new_w, new_h = 1.0 * w // base * base, 1.0 * h // base * base 106 | 107 | scale_w = 1.0 * new_w / w 108 | scale_h = 1.0 * new_h / h 109 | intrinsics[0, :] *= scale_w 110 | intrinsics[1, :] *= scale_h 111 | 112 | img = cv2.resize(img, (int(new_w), int(new_h))) 113 | 114 | return img, intrinsics 115 | def scale_depth_input(self,img,max_w, max_h, base=global_base,interp=cv2.INTER_NEAREST): 116 | h, w = img.shape[:2] 117 | if h > max_h or w > max_w: 118 | scale = 1.0 * max_h / h 119 | if scale * w > max_w: 120 | scale = 1.0 * max_w / w 121 | new_w, new_h = scale * w // base * base, scale * h // base * base 122 | else: 123 | new_w, new_h = 1.0 * w // base * base, 1.0 * h // base * base 124 | img = cv2.resize(img, (int(new_w), int(new_h)),interpolation=interp) 125 | return img 126 | def __getitem__(self, idx): 127 | global s_h, s_w 128 | meta = self.metas[idx] 129 | scan, ref_view, src_views, scene_name = meta 130 | # use only the reference view and first nviews-1 source views 131 | view_ids = [ref_view] + src_views[:self.nviews - 1] 132 | 133 | imgs = [] 134 | depth_values = None 135 | proj_matrices = [] 136 | 137 | for i, vid in enumerate(view_ids): 138 | img_filename = os.path.join(self.datapath, '{}/images_post/{:0>8}.jpg'.format(scan, vid)) 139 | if not os.path.exists(img_filename): 140 | img_filename = os.path.join(self.datapath, '{}/images/{:0>8}.jpg'.format(scan, vid)) 141 | 142 | proj_mat_filename = os.path.join(self.datapath, '{}/cams/{:0>8}_cam.txt'.format(scan, vid)) 143 | 144 | img = self.read_img(img_filename) 145 | intrinsics, extrinsics, depth_min, depth_interval = self.read_cam_file(proj_mat_filename, interval_scale= 146 | self.interval_scale[scene_name]) 147 | # scale input 148 | img, intrinsics = self.scale_mvs_input(img, intrinsics, self.max_w, self.max_h) 149 | 150 | if self.fix_res: 151 | # using the same standard height or width in entire scene. 152 | s_h, s_w = img.shape[:2] 153 | self.fix_res = False 154 | self.fix_wh = True 155 | 156 | if i == 0: 157 | if not self.fix_wh: 158 | # using the same standard height or width in each nviews. 
159 | s_h, s_w = img.shape[:2] 160 | 161 | # resize to standard height or width 162 | c_h, c_w = img.shape[:2] 163 | if (c_h != s_h) or (c_w != s_w): 164 | scale_h = 1.0 * s_h / c_h 165 | scale_w = 1.0 * s_w / c_w 166 | img = cv2.resize(img, (s_w, s_h)) 167 | intrinsics[0, :] *= scale_w 168 | intrinsics[1, :] *= scale_h 169 | 170 | imgs.append(img) 171 | # extrinsics, intrinsics 172 | proj_mat = np.zeros(shape=(2, 4, 4), dtype=np.float32) # 173 | proj_mat[0, :4, :4] = extrinsics 174 | proj_mat[1, :3, :3] = intrinsics 175 | proj_matrices.append(proj_mat) 176 | 177 | if i == 0: # reference view 178 | if self.inverse_depth: 179 | depth_end = depth_interval * self.ndepths + depth_min 180 | depth_values = np.linspace(1.0 / depth_min, 1.0 / depth_end, self.ndepths, endpoint=False) 181 | depth_values = (1.0 / depth_values).astype(np.float32) 182 | else: 183 | depth_values = np.arange(depth_min, depth_interval * (self.ndepths - 0.5) + depth_min, depth_interval, 184 | dtype=np.float32) 185 | #all 186 | imgs = np.stack(imgs).transpose([0, 3, 1, 2]) 187 | # ms proj_mats 188 | proj_matrices = np.stack(proj_matrices) 189 | stage2_pjmats = proj_matrices.copy() 190 | stage2_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 2 191 | stage3_pjmats = proj_matrices.copy() 192 | stage3_pjmats[:, 1, :2, :] = proj_matrices[:, 1, :2, :] * 4 193 | 194 | proj_matrices_ms = { 195 | "stage1": proj_matrices, 196 | "stage2": stage2_pjmats, 197 | "stage3": stage3_pjmats 198 | } 199 | 200 | return {"imgs": imgs, 201 | "proj_matrices": proj_matrices_ms, 202 | "depth_values": depth_values, 203 | "filename": scan + '/{}/' + '{:0>8}'.format(view_ids[0]) + "{}"} 204 | -------------------------------------------------------------------------------- /datasets/lists/blendedmvs/all_list.txt: -------------------------------------------------------------------------------- 1 | 5c1f33f1d33e1f2e4aa6dda4 2 | 5bfe5ae0fe0ea555e6a969ca 3 | 5bff3c5cfe0ea555e6bcbf3a 4 | 58eaf1513353456af3a1682a 5 | 5bfc9d5aec61ca1dd69132a2 6 | 5bf18642c50e6f7f8bdbd492 7 | 5bf26cbbd43923194854b270 8 | 5bf17c0fd439231948355385 9 | 5be3ae47f44e235bdbbc9771 10 | 5be3a5fb8cfdd56947f6b67c 11 | 5bbb6eb2ea1cfa39f1af7e0c 12 | 5ba75d79d76ffa2c86cf2f05 13 | 5bb7a08aea1cfa39f1a947ab 14 | 5b864d850d072a699b32f4ae 15 | 5b7a3890fc8fcf6781e2593a 16 | 5b6eff8b67b396324c5b2672 17 | 5b6e716d67b396324c2d77cb 18 | 5b69cc0cb44b61786eb959bf 19 | 5b62647143840965efc0dbde 20 | 5b60fa0c764f146feef84df0 21 | 5b558a928bbfb62204e77ba2 22 | 5b271079e0878c3816dacca4 23 | 5b08286b2775267d5b0634ba 24 | 5afacb69ab00705d0cefdd5b 25 | 5af28cea59bc705737003253 26 | 5af02e904c8216544b4ab5a2 27 | 5aa515e613d42d091d29d300 28 | 5c34529873a8df509ae57b58 29 | 5c34300a73a8df509add216d 30 | 5c189f2326173c3a09ed7ef3 31 | 5c1af2e2bee9a723c963d019 32 | 5c1892f726173c3a09ea9aeb 33 | 5c0d13b795da9479e12e2ee9 34 | 5c062d84a96e33018ff6f0a6 35 | 5bfd0f32ec61ca1dd69dc77b 36 | 5bf21799d43923194842c001 37 | 5bf3a82cd439231948877aed 38 | 5bf03590d4392319481971dc 39 | 5beb6e66abd34c35e18e66b9 40 | 5be883a4f98cee15019d5b83 41 | 5be47bf9b18881428d8fbc1d 42 | 5bcf979a6d5f586b95c258cd 43 | 5bce7ac9ca24970bce4934b6 44 | 5bb8a49aea1cfa39f1aa7f75 45 | 5b950c71608de421b1e7318f 46 | 5b78e57afc8fcf6781d0c3ba 47 | 5b21e18c58e2823a67a10dd8 48 | 5b22269758e2823a67a3bd03 49 | 5b192eb2170cf166458ff886 50 | 5ae2e9c5fe405c5076abc6b2 51 | 5adc6bd52430a05ecb2ffb85 52 | 5ab8b8e029f5351f7f2ccf59 53 | 5abc2506b53b042ead637d86 54 | 5ab85f1dac4291329b17cb50 55 | 5a969eea91dfc339a9a3ad2c 56 | 5a8aa0fab18050187cbe060e 57 | 
5a7d3db14989e929563eb153 58 | 5a69c47d0d5d0a7f3b2e9752 59 | 5a618c72784780334bc1972d 60 | 5a6400933d809f1d8200af15 61 | 5a6464143d809f1d8208c43c 62 | 5a588a8193ac3d233f77fbca 63 | 5a57542f333d180827dfc132 64 | 5a572fd9fc597b0478a81d14 65 | 5a563183425d0f5186314855 66 | 5a4a38dad38c8a075495b5d2 67 | 5a48d4b2c7dab83a7d7b9851 68 | 5a489fb1c7dab83a7d7b1070 69 | 5a48ba95c7dab83a7d7b44ed 70 | 5a3ca9cb270f0e3f14d0eddb 71 | 5a3cb4e4270f0e3f14d12f43 72 | 5a3f4aba5889373fbbc5d3b5 73 | 5a0271884e62597cdee0d0eb 74 | 59e864b2a9e91f2c5529325f 75 | 59d2657f82ca7774b1ec081d 76 | 599aa591d5b41f366fed0d58 77 | 59350ca084b7f26bf5ce6eb8 78 | 59338e76772c3e6384afbb15 79 | 5c20ca3a0843bc542d94e3e2 80 | 5c1dbf200843bc542d8ef8c4 81 | 5c1b1500bee9a723c96c3e78 82 | 5bea87f4abd34c35e1860ab5 83 | 5c2b3ed5e611832e8aed46bf 84 | 57f8d9bbe73f6760f10e916a 85 | 5bf7d63575c26f32dbf7413b 86 | 5be4ab93870d330ff2dce134 87 | 5bd43b4ba6b28b1ee86b92dd 88 | 5bccd6beca24970bce448134 89 | 5bc5f0e896b66a2cd8f9bd36 90 | 5ba19a8a360c7c30c1c169df 91 | 5b908d3dc6ab78485f3d24a9 92 | 5b2c67b5e0878c381608b8d8 93 | 5b4933abf2b5f44e95de482a 94 | 5b3b353d8d46a939f93524b9 95 | 5acf8ca0f3d8a750097e4b15 96 | 5ab8713ba3799a1d138bd69a 97 | 5aa235f64a17b335eeaf9609 98 | 5aa0f9d7a9efce63548c69a1 99 | 5a8315f624b8e938486e0bd8 100 | 5a48c4e9c7dab83a7d7b5cc7 101 | 59ecfd02e225f6492d20fcc9 102 | 59f87d0bfa6280566fb38c9a 103 | 59f363a8b45be22330016cad 104 | 59f70ab1e5c5d366af29bf3e 105 | 59817e4a1bd4b175e7038d19 106 | 59e75a2ca9e91f2c5526005d 107 | 5947719bf1b45630bd096665 108 | 5947b62af1b45630bd0c2a02 109 | 59056e6760bb961de55f3501 110 | 58f7f7299f5b5647873cb110 111 | 58cf4771d0f5fb221defe6da 112 | 58d36897f387231e6c929903 113 | 58c4bb4f4a69c55606122be4 114 | -------------------------------------------------------------------------------- /datasets/lists/blendedmvs/training_list.txt: -------------------------------------------------------------------------------- 1 | 5c1f33f1d33e1f2e4aa6dda4 2 | 5bfe5ae0fe0ea555e6a969ca 3 | 5bff3c5cfe0ea555e6bcbf3a 4 | 58eaf1513353456af3a1682a 5 | 5bfc9d5aec61ca1dd69132a2 6 | 5bf18642c50e6f7f8bdbd492 7 | 5bf26cbbd43923194854b270 8 | 5bf17c0fd439231948355385 9 | 5be3ae47f44e235bdbbc9771 10 | 5be3a5fb8cfdd56947f6b67c 11 | 5bbb6eb2ea1cfa39f1af7e0c 12 | 5ba75d79d76ffa2c86cf2f05 13 | 5bb7a08aea1cfa39f1a947ab 14 | 5b864d850d072a699b32f4ae 15 | 5b6eff8b67b396324c5b2672 16 | 5b6e716d67b396324c2d77cb 17 | 5b69cc0cb44b61786eb959bf 18 | 5b62647143840965efc0dbde 19 | 5b60fa0c764f146feef84df0 20 | 5b558a928bbfb62204e77ba2 21 | 5b271079e0878c3816dacca4 22 | 5b08286b2775267d5b0634ba 23 | 5afacb69ab00705d0cefdd5b 24 | 5af28cea59bc705737003253 25 | 5af02e904c8216544b4ab5a2 26 | 5aa515e613d42d091d29d300 27 | 5c34529873a8df509ae57b58 28 | 5c34300a73a8df509add216d 29 | 5c1af2e2bee9a723c963d019 30 | 5c1892f726173c3a09ea9aeb 31 | 5c0d13b795da9479e12e2ee9 32 | 5c062d84a96e33018ff6f0a6 33 | 5bfd0f32ec61ca1dd69dc77b 34 | 5bf21799d43923194842c001 35 | 5bf3a82cd439231948877aed 36 | 5bf03590d4392319481971dc 37 | 5beb6e66abd34c35e18e66b9 38 | 5be883a4f98cee15019d5b83 39 | 5be47bf9b18881428d8fbc1d 40 | 5bcf979a6d5f586b95c258cd 41 | 5bce7ac9ca24970bce4934b6 42 | 5bb8a49aea1cfa39f1aa7f75 43 | 5b78e57afc8fcf6781d0c3ba 44 | 5b21e18c58e2823a67a10dd8 45 | 5b22269758e2823a67a3bd03 46 | 5b192eb2170cf166458ff886 47 | 5ae2e9c5fe405c5076abc6b2 48 | 5adc6bd52430a05ecb2ffb85 49 | 5ab8b8e029f5351f7f2ccf59 50 | 5abc2506b53b042ead637d86 51 | 5ab85f1dac4291329b17cb50 52 | 5a969eea91dfc339a9a3ad2c 53 | 5a8aa0fab18050187cbe060e 54 | 5a7d3db14989e929563eb153 55 | 
5a69c47d0d5d0a7f3b2e9752 56 | 5a618c72784780334bc1972d 57 | 5a6464143d809f1d8208c43c 58 | 5a588a8193ac3d233f77fbca 59 | 5a57542f333d180827dfc132 60 | 5a572fd9fc597b0478a81d14 61 | 5a563183425d0f5186314855 62 | 5a4a38dad38c8a075495b5d2 63 | 5a48d4b2c7dab83a7d7b9851 64 | 5a489fb1c7dab83a7d7b1070 65 | 5a48ba95c7dab83a7d7b44ed 66 | 5a3ca9cb270f0e3f14d0eddb 67 | 5a3cb4e4270f0e3f14d12f43 68 | 5a3f4aba5889373fbbc5d3b5 69 | 5a0271884e62597cdee0d0eb 70 | 59e864b2a9e91f2c5529325f 71 | 599aa591d5b41f366fed0d58 72 | 59350ca084b7f26bf5ce6eb8 73 | 59338e76772c3e6384afbb15 74 | 5c20ca3a0843bc542d94e3e2 75 | 5c1dbf200843bc542d8ef8c4 76 | 5c1b1500bee9a723c96c3e78 77 | 5bea87f4abd34c35e1860ab5 78 | 5c2b3ed5e611832e8aed46bf 79 | 57f8d9bbe73f6760f10e916a 80 | 5bf7d63575c26f32dbf7413b 81 | 5be4ab93870d330ff2dce134 82 | 5bd43b4ba6b28b1ee86b92dd 83 | 5bccd6beca24970bce448134 84 | 5bc5f0e896b66a2cd8f9bd36 85 | 5b908d3dc6ab78485f3d24a9 86 | 5b2c67b5e0878c381608b8d8 87 | 5b4933abf2b5f44e95de482a 88 | 5b3b353d8d46a939f93524b9 89 | 5acf8ca0f3d8a750097e4b15 90 | 5ab8713ba3799a1d138bd69a 91 | 5aa235f64a17b335eeaf9609 92 | 5aa0f9d7a9efce63548c69a1 93 | 5a8315f624b8e938486e0bd8 94 | 5a48c4e9c7dab83a7d7b5cc7 95 | 59ecfd02e225f6492d20fcc9 96 | 59f87d0bfa6280566fb38c9a 97 | 59f363a8b45be22330016cad 98 | 59f70ab1e5c5d366af29bf3e 99 | 59e75a2ca9e91f2c5526005d 100 | 5947719bf1b45630bd096665 101 | 5947b62af1b45630bd0c2a02 102 | 59056e6760bb961de55f3501 103 | 58f7f7299f5b5647873cb110 104 | 58cf4771d0f5fb221defe6da 105 | 58d36897f387231e6c929903 106 | 58c4bb4f4a69c55606122be4 107 | -------------------------------------------------------------------------------- /datasets/lists/blendedmvs/validation_list.txt: -------------------------------------------------------------------------------- 1 | 5b7a3890fc8fcf6781e2593a 2 | 5c189f2326173c3a09ed7ef3 3 | 5b950c71608de421b1e7318f 4 | 5a6400933d809f1d8200af15 5 | 59d2657f82ca7774b1ec081d 6 | 5ba19a8a360c7c30c1c169df 7 | 59817e4a1bd4b175e7038d19 8 | -------------------------------------------------------------------------------- /datasets/lists/dtu/single.txt: -------------------------------------------------------------------------------- 1 | scan48 2 | scan49 3 | scan62 -------------------------------------------------------------------------------- /datasets/lists/dtu/test.txt: -------------------------------------------------------------------------------- 1 | scan1 2 | scan4 3 | scan9 4 | scan10 5 | scan11 6 | scan12 7 | scan13 8 | scan15 9 | scan23 10 | scan24 11 | scan29 12 | scan32 13 | scan33 14 | scan34 15 | scan48 16 | scan49 17 | scan62 18 | scan75 19 | scan77 20 | scan110 21 | scan114 22 | scan118 -------------------------------------------------------------------------------- /datasets/lists/dtu/train.txt: -------------------------------------------------------------------------------- 1 | scan2 2 | scan6 3 | scan7 4 | scan8 5 | scan14 6 | scan16 7 | scan18 8 | scan19 9 | scan20 10 | scan22 11 | scan30 12 | scan31 13 | scan36 14 | scan39 15 | scan41 16 | scan42 17 | scan44 18 | scan45 19 | scan46 20 | scan47 21 | scan50 22 | scan51 23 | scan52 24 | scan53 25 | scan55 26 | scan57 27 | scan58 28 | scan60 29 | scan61 30 | scan63 31 | scan64 32 | scan65 33 | scan68 34 | scan69 35 | scan70 36 | scan71 37 | scan72 38 | scan74 39 | scan76 40 | scan83 41 | scan84 42 | scan85 43 | scan87 44 | scan88 45 | scan89 46 | scan90 47 | scan91 48 | scan92 49 | scan93 50 | scan94 51 | scan95 52 | scan96 53 | scan97 54 | scan98 55 | scan99 56 | scan100 57 | scan101 58 | scan102 59 | scan103 60 
| scan104 61 | scan105 62 | scan107 63 | scan108 64 | scan109 65 | scan111 66 | scan112 67 | scan113 68 | scan115 69 | scan116 70 | scan119 71 | scan120 72 | scan121 73 | scan122 74 | scan123 75 | scan124 76 | scan125 77 | scan126 78 | scan127 79 | scan128 -------------------------------------------------------------------------------- /datasets/lists/dtu/trainval.txt: -------------------------------------------------------------------------------- 1 | scan2 2 | scan6 3 | scan7 4 | scan8 5 | scan14 6 | scan16 7 | scan18 8 | scan19 9 | scan20 10 | scan22 11 | scan30 12 | scan31 13 | scan36 14 | scan39 15 | scan41 16 | scan42 17 | scan44 18 | scan45 19 | scan46 20 | scan47 21 | scan50 22 | scan51 23 | scan52 24 | scan53 25 | scan55 26 | scan57 27 | scan58 28 | scan60 29 | scan61 30 | scan63 31 | scan64 32 | scan65 33 | scan68 34 | scan69 35 | scan70 36 | scan71 37 | scan72 38 | scan74 39 | scan76 40 | scan83 41 | scan84 42 | scan85 43 | scan87 44 | scan88 45 | scan89 46 | scan90 47 | scan91 48 | scan92 49 | scan93 50 | scan94 51 | scan95 52 | scan96 53 | scan97 54 | scan98 55 | scan99 56 | scan100 57 | scan101 58 | scan102 59 | scan103 60 | scan104 61 | scan105 62 | scan107 63 | scan108 64 | scan109 65 | scan111 66 | scan112 67 | scan113 68 | scan115 69 | scan116 70 | scan119 71 | scan120 72 | scan121 73 | scan122 74 | scan123 75 | scan124 76 | scan125 77 | scan126 78 | scan127 79 | scan128 80 | scan3 81 | scan5 82 | scan17 83 | scan21 84 | scan28 85 | scan35 86 | scan37 87 | scan38 88 | scan40 89 | scan43 90 | scan56 91 | scan59 92 | scan66 93 | scan67 94 | scan82 95 | scan86 96 | scan106 97 | scan117 -------------------------------------------------------------------------------- /datasets/lists/dtu/val.txt: -------------------------------------------------------------------------------- 1 | scan3 2 | scan5 3 | scan17 4 | scan21 5 | scan28 6 | scan35 7 | scan37 8 | scan38 9 | scan40 10 | scan43 11 | scan56 12 | scan59 13 | scan66 14 | scan67 15 | scan82 16 | scan86 17 | scan106 18 | scan117 -------------------------------------------------------------------------------- /filter/__init__.py: -------------------------------------------------------------------------------- 1 | # from .gipuma import gipuma_filter 2 | from .pcd import pcd_filter 3 | from .dypcd_tanks import dypcd_filter 4 | -------------------------------------------------------------------------------- /filter/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DIVE128/DMVSNet/d37a0b2da8017ac88d48208e4f062c72248c48d5/filter/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /filter/__pycache__/dypcd_tanks.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DIVE128/DMVSNet/d37a0b2da8017ac88d48208e4f062c72248c48d5/filter/__pycache__/dypcd_tanks.cpython-37.pyc -------------------------------------------------------------------------------- /filter/__pycache__/pcd.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/DIVE128/DMVSNet/d37a0b2da8017ac88d48208e4f062c72248c48d5/filter/__pycache__/pcd.cpython-37.pyc -------------------------------------------------------------------------------- /filter/__pycache__/tank_test_config.cpython-37.pyc: 
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/DIVE128/DMVSNet/d37a0b2da8017ac88d48208e4f062c72248c48d5/filter/__pycache__/tank_test_config.cpython-37.pyc -------------------------------------------------------------------------------- /filter/dypcd_tanks.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import signal 4 | import numpy as np 5 | from PIL import Image 6 | from functools import partial 7 | from multiprocessing import Pool 8 | from plyfile import PlyData, PlyElement 9 | 10 | from datasets.data_io import read_pfm 11 | from filter.tank_test_config import tank_cfg 12 | 13 | from datasets.data_io import save_pfm, read_pfm 14 | 15 | import torch.nn.functional as F 16 | import torch 17 | # save a binary mask 18 | def save_mask(filename, mask): 19 | assert mask.dtype == np.bool 20 | mask = mask.astype(np.uint8) * 255 21 | Image.fromarray(mask).save(filename) 22 | 23 | 24 | # read an image 25 | def read_img(filename): 26 | img = Image.open(filename) 27 | # scale 0~255 to 0~1 28 | np_img = np.array(img, dtype=np.float32) / 255. 29 | return np_img 30 | 31 | 32 | # read intrinsics and extrinsics 33 | def read_camera_parameters(filename): 34 | with open(filename) as f: 35 | lines = f.readlines() 36 | lines = [line.rstrip() for line in lines] 37 | # extrinsics: line [1,5), 4x4 matrix 38 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 39 | # intrinsics: line [7-10), 3x3 matrix 40 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 41 | # TODO: assume the feature is 1/4 of the original image size 42 | # intrinsics[:2, :] /= 4 43 | return intrinsics, extrinsics 44 | 45 | 46 | # read a pair file, [(ref_view1, [src_view1-1, ...]), (ref_view2, [src_view2-1, ...]), ...] 47 | def read_pair_file(filename): 48 | data = [] 49 | with open(filename) as f: 50 | num_viewpoint = int(f.readline()) 51 | # 49 viewpoints 52 | for view_idx in range(num_viewpoint): 53 | ref_view = int(f.readline().rstrip()) 54 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 55 | if len(src_views) > 0: 56 | data.append((ref_view, src_views)) 57 | return data 58 | 59 | 60 | # project the reference point cloud into the source view, then project back 61 | def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): 62 | width, height = depth_ref.shape[1], depth_ref.shape[0] 63 | ## step1. project reference pixels to the source view 64 | # reference view x, y 65 | x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) 66 | x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1]) 67 | # reference 3D space 68 | xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref), 69 | np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1])) 70 | # source 3D space 71 | xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)), 72 | np.vstack((xyz_ref, np.ones_like(x_ref))))[:3] 73 | # source view x, y 74 | K_xyz_src = np.matmul(intrinsics_src, xyz_src) 75 | xy_src = K_xyz_src[:2] / K_xyz_src[2:3] 76 | 77 | ## step2. 
reproject the source view points with source view depth estimation 78 | # find the depth estimation of the source view 79 | x_src = xy_src[0].reshape([height, width]).astype(np.float32) 80 | y_src = xy_src[1].reshape([height, width]).astype(np.float32) 81 | sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR) 82 | # mask = sampled_depth_src > 0 83 | 84 | # source 3D space 85 | # NOTE that we should use sampled source-view depth_here to project back 86 | xyz_src = np.matmul(np.linalg.inv(intrinsics_src), 87 | np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1])) 88 | # reference 3D space 89 | xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)), 90 | np.vstack((xyz_src, np.ones_like(x_ref))))[:3] 91 | # source view x, y, depth 92 | depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32) 93 | K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected) 94 | K_xyz_reprojected[2:3][K_xyz_reprojected[2:3]==0] += 0.00001 95 | xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3] 96 | x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32) 97 | y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32) 98 | 99 | return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src 100 | 101 | @torch.no_grad() 102 | def reproject_with_depth_pytorch(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): 103 | def tocuda(varlist:list): 104 | out=[] 105 | for var in varlist: 106 | if isinstance(var,np.ndarray): 107 | var=torch.from_numpy(var.copy()) 108 | out.append(var.cuda()) 109 | return out 110 | def tonumpy(varlist:list): 111 | out=[] 112 | for var in varlist: 113 | out.append(var.cpu().numpy()) 114 | return out 115 | 116 | [depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src]=tocuda([depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src]) 117 | 118 | width, height = depth_ref.shape[1], depth_ref.shape[0] 119 | ## step1. project reference pixels to the source view 120 | # reference view x, y 121 | #np.meshgrid(a,b)=torch.meshgrid(b,a) 122 | y_ref,x_ref = torch.meshgrid(torch.arange(0, height),torch.arange(0, width)) 123 | 124 | x_ref, y_ref = x_ref.reshape([-1]).cuda(), y_ref.reshape([-1]).cuda() 125 | # reference 3D space 126 | xyz_ref = torch.matmul(torch.linalg.inv(intrinsics_ref), 127 | torch.vstack((x_ref, y_ref, torch.ones_like(x_ref))) * depth_ref.reshape([-1])) 128 | # source 3D space 129 | xyz_src = torch.matmul(torch.matmul(extrinsics_src, torch.linalg.inv(extrinsics_ref)), 130 | torch.vstack((xyz_ref, torch.ones_like(x_ref))))[:3] 131 | # source view x, y 132 | K_xyz_src = torch.matmul(intrinsics_src, xyz_src) 133 | xy_src = K_xyz_src[:2] / K_xyz_src[2:3] 134 | 135 | ## step2. 
reproject the source view points with source view depth estimation 136 | # find the depth estimation of the source view 137 | x_src = xy_src[0]/ ((width - 1) / 2) - 1 138 | y_src = xy_src[1]/ ((height - 1) / 2) - 1 139 | proj_xy = torch.stack((x_src, y_src), dim=-1) # [H*W, 2] 140 | sampled_depth_src = F.grid_sample(depth_src.unsqueeze(0).unsqueeze(0), proj_xy.view(1, height, width, 2), mode='bilinear',padding_mode='zeros',align_corners=True).type(torch.float32).squeeze(0).squeeze(0) 141 | 142 | 143 | # mask = sampled_depth_src > 0 144 | 145 | # source 3D space 146 | # NOTE that we should use sampled source-view depth_here to project back 147 | xyz_src = torch.matmul(torch.linalg.inv(intrinsics_src), 148 | torch.vstack((xy_src, torch.ones_like(x_ref))) * sampled_depth_src.reshape([-1])) 149 | # reference 3D space 150 | xyz_reprojected = torch.matmul(torch.matmul(extrinsics_ref, torch.linalg.inv(extrinsics_src)), 151 | torch.vstack((xyz_src, torch.ones_like(x_ref))))[:3] 152 | # source view x, y, depth 153 | depth_reprojected = xyz_reprojected[2].reshape([height, width]) 154 | K_xyz_reprojected = torch.matmul(intrinsics_ref, xyz_reprojected) 155 | K_xyz_reprojected[2:3][K_xyz_reprojected[2:3]==0] += 0.00001 156 | xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3] 157 | x_reprojected = xy_reprojected[0].reshape([height, width]) 158 | y_reprojected = xy_reprojected[1].reshape([height, width]) 159 | 160 | [depth_reprojected, x_reprojected, y_reprojected, x_src, y_src]=tonumpy([depth_reprojected, x_reprojected, y_reprojected, x_src, y_src]) 161 | return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src 162 | 163 | 164 | def check_geometric_consistency(args, depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): 165 | width, height = depth_ref.shape[1], depth_ref.shape[0] 166 | x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) 167 | depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, 168 | depth_src, intrinsics_src, extrinsics_src) 169 | # check |p_reproj-p_1| < 1 170 | dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2) 171 | 172 | # check |d_reproj-d_1| / d_1 < 0.01 173 | depth_diff = np.abs(depth_reprojected - depth_ref) 174 | relative_depth_diff = depth_diff / depth_ref 175 | 176 | mask = None 177 | masks = [] 178 | for i in range(2, 11): 179 | # mask = np.logical_and(dist < i / 4, relative_depth_diff < i / 1300) 180 | mask = np.logical_and(dist < i * args.dist_base, relative_depth_diff < i * args.rel_diff_base) 181 | masks.append(mask) 182 | depth_reprojected[~mask] = 0 183 | 184 | return masks, mask, depth_reprojected, x2d_src, y2d_src 185 | 186 | def filter_depth(args, pair_folder, scan_folder, out_folder, plyfilename): 187 | num_stage = len(args.ndepths) 188 | 189 | # the pair file 190 | pair_file = os.path.join(pair_folder, "pair.txt") 191 | # for the final point cloud 192 | vertexs = [] 193 | vertex_colors = [] 194 | 195 | pair_data = read_pair_file(pair_file) 196 | nviews = len(pair_data) 197 | 198 | # for each reference view and the corresponding source views 199 | for ref_view, src_views in pair_data: 200 | # src_views = src_views[:args.num_view] 201 | # load the camera parameters 202 | ref_intrinsics, ref_extrinsics = read_camera_parameters( 203 | os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(ref_view))) 204 | # load the reference image 205 | ref_img = read_img(os.path.join(scan_folder, 
'images/{:0>8}.jpg'.format(ref_view))) 206 | # load the estimated depth of the reference view 207 | ref_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0] 208 | # load the photometric mask of the reference view 209 | confidence = read_pfm(os.path.join(out_folder, 'confidence/{:0>8}.pfm'.format(ref_view)))[0] 210 | 211 | if not os.path.exists(os.path.join(out_folder, 'confidence/{:0>8}_stage2.pfm'.format(ref_view))): 212 | confidence2=confidence1=confidence 213 | else: 214 | confidence2 = read_pfm(os.path.join(out_folder, 'confidence/{:0>8}_stage2.pfm'.format(ref_view)))[0] 215 | confidence1 = read_pfm(os.path.join(out_folder, 'confidence/{:0>8}_stage1.pfm'.format(ref_view)))[0] 216 | photo_mask = np.logical_and(np.logical_and(confidence > args.conf[2], confidence2 > args.conf[1]), confidence1 > args.conf[0]) 217 | 218 | # save_pfm(depth_filename, depth_est) 219 | 220 | 221 | if not (os.path.exists(os.path.join(out_folder, "mask/{:0>8}_geo.png".format(ref_view))) and os.path.exists(os.path.join(out_folder, 'depth_est/{:0>8}_averaged.pfm'.format(ref_view)))): 222 | all_srcview_depth_ests = [] 223 | # all_srcview_x = [] 224 | # all_srcview_y = [] 225 | all_srcview_geomask = [] 226 | 227 | # compute the geometric mask 228 | geo_mask_sum = 0 229 | dy_range = len(src_views) + 1 230 | geo_mask_sums = [0] * (dy_range - 2) 231 | for src_view in src_views: 232 | # camera parameters of the source view 233 | src_intrinsics, src_extrinsics = read_camera_parameters( 234 | os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(src_view))) 235 | # the estimated depth of the source view 236 | src_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))[0] 237 | 238 | masks, geo_mask, depth_reprojected, x2d_src, y2d_src = check_geometric_consistency(args, ref_depth_est, ref_intrinsics, 239 | ref_extrinsics, src_depth_est, 240 | src_intrinsics, src_extrinsics) 241 | geo_mask_sum += geo_mask.astype(np.int32) 242 | for i in range(2, dy_range): 243 | geo_mask_sums[i - 2] += masks[i - 2].astype(np.int32) 244 | 245 | all_srcview_depth_ests.append(depth_reprojected) 246 | all_srcview_geomask.append(geo_mask) 247 | 248 | depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1) 249 | 250 | save_pfm(os.path.join(out_folder, 'depth_est/{:0>8}_averaged.pfm'.format(ref_view)), depth_est_averaged.astype(np.float32)) 251 | # at least args.thres_view source views matched 252 | 253 | geo_mask = geo_mask_sum >= dy_range 254 | for i in range(2, dy_range): 255 | geo_mask = np.logical_or(geo_mask, geo_mask_sums[i - 2] >= i) 256 | else: 257 | geo_mask_path=os.path.join(out_folder, "mask/{:0>8}_geo.png".format(ref_view)) 258 | import cv2 259 | geo_mask=cv2.imread(geo_mask_path,-1)>0 260 | depth_est_averaged=read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}_averaged.pfm'.format(ref_view)))[0] 261 | print("finished") 262 | 263 | final_mask = np.logical_and(photo_mask, geo_mask) 264 | 265 | os.makedirs(os.path.join(out_folder, "mask"), exist_ok=True) 266 | save_mask(os.path.join(out_folder, "mask/{:0>8}_photo.png".format(ref_view)), photo_mask) 267 | save_mask(os.path.join(out_folder, "mask/{:0>8}_geo.png".format(ref_view)), geo_mask) 268 | save_mask(os.path.join(out_folder, "mask/{:0>8}_final.png".format(ref_view)), final_mask) 269 | 270 | print("processing {}, ref-view{:0>2}, photo/geo/final-mask:{}/{}/{}".format(scan_folder, ref_view, 271 | photo_mask.mean(), 272 | geo_mask.mean(), final_mask.mean())) 273 | 274 | if 
args.display: 275 | import cv2 276 | cv2.imshow('ref_img', ref_img[:, :, ::-1]) 277 | cv2.imshow('ref_depth', ref_depth_est / 800) 278 | cv2.imshow('ref_depth * photo_mask', ref_depth_est * photo_mask.astype(np.float32) / 800) 279 | cv2.imshow('ref_depth * geo_mask', ref_depth_est * geo_mask.astype(np.float32) / 800) 280 | cv2.imshow('ref_depth * mask', ref_depth_est * final_mask.astype(np.float32) / 800) 281 | cv2.waitKey(0) 282 | 283 | height, width = depth_est_averaged.shape[:2] 284 | x, y = np.meshgrid(np.arange(0, width), np.arange(0, height)) 285 | 286 | valid_points = final_mask 287 | print("valid_points", valid_points.mean()) 288 | x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[valid_points] 289 | # color = ref_img[1:-16:4, 1::4, :][valid_points] # hardcoded for DTU dataset 290 | 291 | if num_stage == 1: 292 | color = ref_img[1::4, 1::4, :][valid_points] 293 | elif num_stage == 2: 294 | color = ref_img[1::2, 1::2, :][valid_points] 295 | elif num_stage == 3: 296 | color = ref_img[valid_points] 297 | 298 | xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics), 299 | np.vstack((x, y, np.ones_like(x))) * depth) 300 | xyz_world = np.matmul(np.linalg.inv(ref_extrinsics), 301 | np.vstack((xyz_ref, np.ones_like(x))))[:3] 302 | vertexs.append(xyz_world.transpose((1, 0))) 303 | vertex_colors.append((color * 255).astype(np.uint8)) 304 | 305 | # # set used_mask[ref_view] 306 | # used_mask[ref_view][...] = True 307 | # for idx, src_view in enumerate(src_views): 308 | # src_mask = np.logical_and(final_mask, all_srcview_geomask[idx]) 309 | # src_y = all_srcview_y[idx].astype(np.int) 310 | # src_x = all_srcview_x[idx].astype(np.int) 311 | # used_mask[src_view][src_y[src_mask], src_x[src_mask]] = True 312 | 313 | vertexs = np.concatenate(vertexs, axis=0) 314 | vertex_colors = np.concatenate(vertex_colors, axis=0) 315 | vertexs = np.array([tuple(v) for v in vertexs], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) 316 | vertex_colors = np.array([tuple(v) for v in vertex_colors], dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 317 | 318 | vertex_all = np.empty(len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr) 319 | for prop in vertexs.dtype.names: 320 | vertex_all[prop] = vertexs[prop] 321 | for prop in vertex_colors.dtype.names: 322 | vertex_all[prop] = vertex_colors[prop] 323 | 324 | el = PlyElement.describe(vertex_all, 'vertex') 325 | PlyData([el]).write(plyfilename) 326 | print("saving the final model to", plyfilename) 327 | 328 | def check_geometric_consistency_geomean(args, depth_ref, intrinsics_ref, extrinsics_ref, depth_src_up,depth_src,depth_src_dn, intrinsics_src, extrinsics_src,f=2): 329 | width, height = depth_ref.shape[1], depth_ref.shape[0] 330 | x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) 331 | depth_reprojected_u, x2d_reprojected_u, y2d_reprojected_u, x2d_src_u, y2d_src_u = reproject_with_depth_pytorch(depth_ref, intrinsics_ref, extrinsics_ref, 332 | depth_src_up, intrinsics_src, extrinsics_src) 333 | depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth_pytorch(depth_ref, intrinsics_ref, extrinsics_ref, 334 | depth_src, intrinsics_src, extrinsics_src) 335 | depth_reprojected_d, x2d_reprojected_d, y2d_reprojected_d, x2d_src_d, y2d_src_d = reproject_with_depth_pytorch(depth_ref, intrinsics_ref, extrinsics_ref, 336 | depth_src_dn, intrinsics_src, extrinsics_src) 337 | 338 | 339 | depth_reprojected=np.stack((depth_reprojected_u,depth_reprojected,depth_reprojected_d),axis=0) 340 | 
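    # The reference depth is reprojected three times above, once through each source-view
    # hypothesis (depth_src_up, depth_src, depth_src_dn). Stacking the results along a new
    # leading axis lets the thresholds below (i * args.dist_base pixels and
    # i * args.rel_diff_base relative depth error) be evaluated against all three hypotheses
    # at once; a pixel is ultimately kept if at least one hypothesis passes (mask.sum(0) > 0).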
x2d_reprojected=np.stack((x2d_reprojected_u,x2d_reprojected,x2d_reprojected_d),axis=0) 341 | y2d_reprojected=np.stack((y2d_reprojected_u,y2d_reprojected,y2d_reprojected_d),axis=0) 342 | 343 | dist = np.sqrt((x2d_reprojected - x_ref[np.newaxis]) ** 2 + (y2d_reprojected - y_ref[np.newaxis]) ** 2) 344 | 345 | # check |d_reproj-d_1| / d_1 < 0.01 346 | depth_diff = np.abs(depth_reprojected - depth_ref[np.newaxis]) 347 | relative_depth_diff = depth_diff / depth_ref 348 | 349 | mask = None 350 | masks = [] 351 | for i in range(f, 6): 352 | # mask = np.logical_and(dist < i / 4, relative_depth_diff < i / 1300) 353 | mask = np.logical_and(dist < i * args.dist_base, relative_depth_diff < i * args.rel_diff_base) 354 | mask_ = np.logical_and(dist < (i) * args.dist_base, relative_depth_diff < (i) * args.rel_diff_base) 355 | mask_=mask_.mean(0) 356 | 357 | masks.append(mask_) 358 | div_=np.where(mask.sum(0)>0,(mask.sum(0)),np.ones_like(depth_reprojected[0])) 359 | depth_reprojected[~mask]=0 360 | depth_reprojected=np.where(mask.sum(0)>0,(depth_reprojected.sum(0))/div_,np.zeros_like(depth_reprojected[0])) 361 | 362 | mask_for_view=mask.mean(0) 363 | mask=mask.sum(0)>0 364 | 365 | 366 | return masks, mask,mask_for_view, depth_reprojected, x2d_src, y2d_src 367 | 368 | 369 | def dypcd_filter_worker(args, scene,suffix=None): 370 | if args.testlist != "all": 371 | scan_id = int(scene[4:]) 372 | save_name = 'mvsnet{:0>3}_l3.ply'.format(scan_id) 373 | else: 374 | save_name = '{}.ply'.format(scene) 375 | pair_folder = os.path.join(args.datapath, scene) 376 | scan_folder = os.path.join(args.outdir, scene) 377 | out_folder = os.path.join(args.outdir, scene) 378 | 379 | if scene in tank_cfg.scenes: 380 | scene_cfg = getattr(tank_cfg, scene) 381 | args.conf = scene_cfg.conf 382 | 383 | filter_depth(args, pair_folder, scan_folder, out_folder, os.path.join(args.outdir,"dypcd" ,save_name)) 384 | 385 | def init_worker(): 386 | ''' 387 | Catch Ctrl+C signal to termiante workers 388 | ''' 389 | signal.signal(signal.SIGINT, signal.SIG_IGN) 390 | 391 | 392 | def dypcd_filter(args, testlist, number_worker,suffix=None): 393 | if not os.path.exists(os.path.join(args.outdir,"dypcd")): 394 | os.makedirs(os.path.join(args.outdir,"dypcd")) 395 | if number_worker>1: 396 | partial_func = partial(dypcd_filter_worker, args) 397 | 398 | p = Pool(number_worker, init_worker) 399 | try: 400 | p.map(partial_func, testlist) 401 | except KeyboardInterrupt: 402 | print("....\n Caught KeyboardInterrupt, terminating workers") 403 | p.terminate() 404 | else: 405 | p.close() 406 | p.join() 407 | else: 408 | if suffix is not None: 409 | if not os.path.exists(os.path.join(args.outdir,"dypcd_{}".format(suffix))): 410 | os.makedirs(os.path.join(args.outdir,"dypcd_{}".format(suffix))) 411 | 412 | for scene in testlist: 413 | dypcd_filter_worker(args,scene,suffix=suffix) 414 | -------------------------------------------------------------------------------- /filter/pcd.py: -------------------------------------------------------------------------------- 1 | import os,sys 2 | import cv2 3 | import signal 4 | import numpy as np 5 | from PIL import Image 6 | from functools import partial 7 | from multiprocessing import Pool 8 | from plyfile import PlyData, PlyElement 9 | from tomlkit import value 10 | from datasets.data_io import cv2_imread 11 | import torch.nn.functional as F 12 | import torch 13 | import re 14 | from filter.tank_test_config import tank_cfg 15 | 16 | def read_pfm(filename): 17 | file = open(filename, 'rb') 18 | color = None 19 | width = 
None 20 | height = None 21 | scale = None 22 | endian = None 23 | 24 | header = file.readline().decode('utf-8').rstrip() 25 | if header == 'PF': 26 | color = True 27 | elif header == 'Pf': 28 | color = False 29 | else: 30 | raise Exception('Not a PFM file.') 31 | 32 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode('utf-8')) 33 | if dim_match: 34 | width, height = map(int, dim_match.groups()) 35 | else: 36 | raise Exception('Malformed PFM header.') 37 | 38 | scale = float(file.readline().rstrip()) 39 | if scale < 0: # little-endian 40 | endian = '<' 41 | scale = -scale 42 | else: 43 | endian = '>' # big-endian 44 | 45 | data = np.fromfile(file, endian + 'f') 46 | shape = (height, width, 3) if color else (height, width) 47 | 48 | data = np.reshape(data, shape) 49 | data = np.flipud(data) 50 | file.close() 51 | return data, scale 52 | 53 | # save a binary mask 54 | def save_mask(filename, mask): 55 | assert mask.dtype == np.bool 56 | mask = mask.astype(np.uint8) * 255 57 | Image.fromarray(mask).save(filename) 58 | 59 | # read an image 60 | def read_img(filename): 61 | img = Image.open(filename) 62 | # scale 0~255 to 0~1 63 | np_img = np.array(img, dtype=np.float32) / 255. 64 | return np_img 65 | 66 | 67 | # read intrinsics and extrinsics 68 | def read_camera_parameters(filename): 69 | with open(filename) as f: 70 | lines = f.readlines() 71 | lines = [line.rstrip() for line in lines] 72 | # extrinsics: line [1,5), 4x4 matrix 73 | extrinsics = np.fromstring(' '.join(lines[1:5]), dtype=np.float32, sep=' ').reshape((4, 4)) 74 | # intrinsics: line [7-10), 3x3 matrix 75 | intrinsics = np.fromstring(' '.join(lines[7:10]), dtype=np.float32, sep=' ').reshape((3, 3)) 76 | # TODO: assume the feature is 1/4 of the original image size 77 | # intrinsics[:2, :] /= 4 78 | return intrinsics, extrinsics 79 | 80 | 81 | # read a pair file, [(ref_view1, [src_view1-1, ...]), (ref_view2, [src_view2-1, ...]), ...] 82 | def read_pair_file(filename): 83 | data = [] 84 | with open(filename) as f: 85 | num_viewpoint = int(f.readline()) 86 | # 49 viewpoints 87 | for view_idx in range(num_viewpoint): 88 | ref_view = int(f.readline().rstrip()) 89 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 90 | if len(src_views) > 0: 91 | data.append((ref_view, src_views)) 92 | return data 93 | 94 | 95 | # project the reference point cloud into the source view, then project back 96 | def reproject_with_depth(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): 97 | width, height = depth_ref.shape[1], depth_ref.shape[0] 98 | ## step1. project reference pixels to the source view 99 | # reference view x, y 100 | x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) 101 | x_ref, y_ref = x_ref.reshape([-1]), y_ref.reshape([-1]) 102 | # reference 3D space 103 | xyz_ref = np.matmul(np.linalg.inv(intrinsics_ref), 104 | np.vstack((x_ref, y_ref, np.ones_like(x_ref))) * depth_ref.reshape([-1])) 105 | # source 3D space 106 | xyz_src = np.matmul(np.matmul(extrinsics_src, np.linalg.inv(extrinsics_ref)), 107 | np.vstack((xyz_ref, np.ones_like(x_ref))))[:3] 108 | # source view x, y 109 | K_xyz_src = np.matmul(intrinsics_src, xyz_src) 110 | xy_src = K_xyz_src[:2] / K_xyz_src[2:3] 111 | 112 | ## step2. 
reproject the source view points with source view depth estimation 113 | # find the depth estimation of the source view 114 | x_src = xy_src[0].reshape([height, width]).astype(np.float32) 115 | y_src = xy_src[1].reshape([height, width]).astype(np.float32) 116 | sampled_depth_src = cv2.remap(depth_src, x_src, y_src, interpolation=cv2.INTER_LINEAR) 117 | # mask = sampled_depth_src > 0 118 | 119 | # source 3D space 120 | # NOTE that we should use sampled source-view depth_here to project back 121 | xyz_src = np.matmul(np.linalg.inv(intrinsics_src), 122 | np.vstack((xy_src, np.ones_like(x_ref))) * sampled_depth_src.reshape([-1])) 123 | # reference 3D space 124 | xyz_reprojected = np.matmul(np.matmul(extrinsics_ref, np.linalg.inv(extrinsics_src)), 125 | np.vstack((xyz_src, np.ones_like(x_ref))))[:3] 126 | # source view x, y, depth 127 | depth_reprojected = xyz_reprojected[2].reshape([height, width]).astype(np.float32) 128 | K_xyz_reprojected = np.matmul(intrinsics_ref, xyz_reprojected) 129 | xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3] 130 | x_reprojected = xy_reprojected[0].reshape([height, width]).astype(np.float32) 131 | y_reprojected = xy_reprojected[1].reshape([height, width]).astype(np.float32) 132 | 133 | return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src 134 | 135 | @torch.no_grad() 136 | def tocuda(varlist:list): 137 | out=[] 138 | for var in varlist: 139 | if isinstance(var,np.ndarray): 140 | var=torch.from_numpy(var.copy()) 141 | out.append(var.cuda()) 142 | return out 143 | 144 | @torch.no_grad() 145 | def tonumpy(varlist:list): 146 | out=[] 147 | for var in varlist: 148 | out.append(var.cpu().numpy()) 149 | return out 150 | 151 | @torch.no_grad() 152 | def reproject_with_depth_pytorch(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src,tnp=True): 153 | 154 | [depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src]=tocuda([depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src]) 155 | 156 | width, height = depth_ref.shape[1], depth_ref.shape[0] 157 | ## step1. project reference pixels to the source view 158 | # reference view x, y 159 | #np.meshgrid(a,b)=torch.meshgrid(b,a) 160 | y_ref,x_ref = torch.meshgrid(torch.arange(0, height),torch.arange(0, width)) 161 | 162 | x_ref, y_ref = x_ref.reshape([-1]).cuda(), y_ref.reshape([-1]).cuda() 163 | # reference 3D space 164 | xyz_ref = torch.matmul(torch.linalg.inv(intrinsics_ref), 165 | torch.vstack((x_ref, y_ref, torch.ones_like(x_ref))) * depth_ref.reshape([-1])) 166 | # source 3D space 167 | xyz_src = torch.matmul(torch.matmul(extrinsics_src, torch.linalg.inv(extrinsics_ref)), 168 | torch.vstack((xyz_ref, torch.ones_like(x_ref))))[:3] 169 | # source view x, y 170 | K_xyz_src = torch.matmul(intrinsics_src, xyz_src) 171 | xy_src = K_xyz_src[:2] / K_xyz_src[2:3] 172 | 173 | ## step2. 
reproject the source view points with source view depth estimation 174 | # find the depth estimation of the source view 175 | x_src = xy_src[0]/ ((width - 1) / 2) - 1 176 | y_src = xy_src[1]/ ((height - 1) / 2) - 1 177 | proj_xy = torch.stack((x_src, y_src), dim=-1) # [H*W, 2] 178 | sampled_depth_src = F.grid_sample(depth_src.unsqueeze(0).unsqueeze(0), proj_xy.view(1, height, width, 2), mode='bilinear',padding_mode='zeros',align_corners=True).type(torch.float32).squeeze(0).squeeze(0) 179 | 180 | 181 | # mask = sampled_depth_src > 0 182 | 183 | # source 3D space 184 | # NOTE that we should use sampled source-view depth_here to project back 185 | xyz_src = torch.matmul(torch.linalg.inv(intrinsics_src), 186 | torch.vstack((xy_src, torch.ones_like(x_ref))) * sampled_depth_src.reshape([-1])) 187 | # reference 3D space 188 | xyz_reprojected = torch.matmul(torch.matmul(extrinsics_ref, torch.linalg.inv(extrinsics_src)), 189 | torch.vstack((xyz_src, torch.ones_like(x_ref))))[:3] 190 | # source view x, y, depth 191 | depth_reprojected = xyz_reprojected[2].reshape([height, width]) 192 | K_xyz_reprojected = torch.matmul(intrinsics_ref, xyz_reprojected) 193 | K_xyz_reprojected[2:3][K_xyz_reprojected[2:3]==0] += 0.00001 194 | xy_reprojected = K_xyz_reprojected[:2] / K_xyz_reprojected[2:3] 195 | x_reprojected = xy_reprojected[0].reshape([height, width]) 196 | y_reprojected = xy_reprojected[1].reshape([height, width]) 197 | if tnp: 198 | [depth_reprojected, x_reprojected, y_reprojected, x_src, y_src]=tonumpy([depth_reprojected, x_reprojected, y_reprojected, x_src, y_src]) 199 | 200 | return depth_reprojected, x_reprojected, y_reprojected, x_src, y_src 201 | 202 | @torch.no_grad() 203 | def check_geometric_consistency_pytorch(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src,alpha=1.0): 204 | width, height = depth_ref.shape[1], depth_ref.shape[0] 205 | # x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) 206 | y_ref,x_ref = torch.meshgrid(torch.arange(0, height),torch.arange(0, width)) 207 | [y_ref,x_ref]=tocuda([y_ref,x_ref]) 208 | 209 | depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth_pytorch(depth_ref, intrinsics_ref, extrinsics_ref, 210 | depth_src, intrinsics_src, extrinsics_src,tnp=False) 211 | # check |p_reproj-p_1| < 1 212 | dist = torch.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2) 213 | 214 | # check |d_reproj-d_1| / d_1 < 0.01 215 | depth_ref[depth_ref==0]=1e-4 216 | depth_diff = torch.abs(depth_reprojected - depth_ref) 217 | relative_depth_diff = depth_diff / depth_ref 218 | 219 | # mask = torch.logical_and(dist < 1, relative_depth_diff < 0.01) 220 | mask = torch.logical_and(dist < 1*alpha, relative_depth_diff < 0.01*alpha) 221 | 222 | depth_reprojected[~mask] = 0 223 | 224 | return mask, depth_reprojected, x2d_src, y2d_src 225 | 226 | def check_geometric_consistency(depth_ref, intrinsics_ref, extrinsics_ref, depth_src, intrinsics_src, extrinsics_src): 227 | width, height = depth_ref.shape[1], depth_ref.shape[0] 228 | x_ref, y_ref = np.meshgrid(np.arange(0, width), np.arange(0, height)) 229 | depth_reprojected, x2d_reprojected, y2d_reprojected, x2d_src, y2d_src = reproject_with_depth_pytorch(depth_ref, intrinsics_ref, extrinsics_ref, 230 | depth_src, intrinsics_src, extrinsics_src) 231 | # check |p_reproj-p_1| < 1 232 | dist = np.sqrt((x2d_reprojected - x_ref) ** 2 + (y2d_reprojected - y_ref) ** 2) 233 | 234 | # check |d_reproj-d_1| / d_1 < 0.01 235 | 
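    # Guard against division by zero: pixels with zero reference depth would make the
    # relative error below undefined, so they are clamped to a small epsilon first.
    # A pixel is then accepted only if both the reprojection test (< 1 px) and the
    # relative depth test (< 1%) pass.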
depth_ref[depth_ref==0]=1e-4 236 | depth_diff = np.abs(depth_reprojected - depth_ref) 237 | relative_depth_diff = depth_diff / depth_ref 238 | 239 | mask = np.logical_and(dist < 1, relative_depth_diff < 0.01) 240 | depth_reprojected[~mask] = 0 241 | 242 | return mask, depth_reprojected, x2d_src, y2d_src 243 | 244 | def filter_depth(args, pair_folder, scan_folder, out_folder, plyfilename): 245 | num_stage = len(args.ndepths) 246 | 247 | # the pair file 248 | pair_file = os.path.join(pair_folder, "pair.txt") 249 | # for the final point cloud 250 | vertexs = [] 251 | vertex_colors = [] 252 | 253 | pair_data = read_pair_file(pair_file) 254 | nviews = len(pair_data) 255 | 256 | for ref_view, src_views in pair_data: 257 | 258 | # src_views = src_views[:args.num_view] 259 | # load the camera parameters 260 | ref_intrinsics, ref_extrinsics = read_camera_parameters( 261 | os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(ref_view))) 262 | # load the reference image 263 | ref_img = read_img(os.path.join(scan_folder, 'images/{:0>8}.jpg'.format(ref_view))) 264 | # load the estimated depth of the reference view 265 | ref_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(ref_view)))[0] 266 | 267 | # load the photometric mask of the reference view 268 | confidence = read_pfm(os.path.join(out_folder, 'confidence/{:0>8}.pfm'.format(ref_view)))[0] 269 | if os.path.exists(os.path.join(out_folder, 'confidence/{:0>8}_stage2.pfm'.format(ref_view))): 270 | confidence2 = read_pfm(os.path.join(out_folder, 'confidence/{:0>8}_stage2.pfm'.format(ref_view)))[0] 271 | confidence1 = read_pfm(os.path.join(out_folder, 'confidence/{:0>8}_stage1.pfm'.format(ref_view)))[0] 272 | else: 273 | confidence2=confidence1=confidence 274 | photo_mask = np.logical_and(np.logical_and(confidence > args.conf[2], confidence2 > args.conf[1]), confidence1 > args.conf[0]) 275 | 276 | all_srcview_depth_ests = [] 277 | all_srcview_x = [] 278 | all_srcview_y = [] 279 | all_srcview_geomask = [] 280 | 281 | # compute the geometric mask 282 | geo_mask_sum = 0 283 | for src_view in src_views: 284 | # camera parameters of the source view 285 | src_intrinsics, src_extrinsics = read_camera_parameters( 286 | os.path.join(scan_folder, 'cams/{:0>8}_cam.txt'.format(src_view))) 287 | # the estimated depth of the source view 288 | src_depth_est = read_pfm(os.path.join(out_folder, 'depth_est/{:0>8}.pfm'.format(src_view)))[0] 289 | 290 | geo_mask, depth_reprojected, x2d_src, y2d_src = check_geometric_consistency(ref_depth_est, ref_intrinsics, ref_extrinsics, 291 | src_depth_est, 292 | src_intrinsics, src_extrinsics) 293 | geo_mask_sum += geo_mask.astype(np.int32) 294 | all_srcview_depth_ests.append(depth_reprojected) 295 | all_srcview_x.append(x2d_src) 296 | all_srcview_y.append(y2d_src) 297 | all_srcview_geomask.append(geo_mask) 298 | 299 | depth_est_averaged = (sum(all_srcview_depth_ests) + ref_depth_est) / (geo_mask_sum + 1) 300 | 301 | # at least args.thres_view source views matched 302 | geo_mask = geo_mask_sum >= args.thres_view 303 | final_mask = np.logical_and(photo_mask, geo_mask) 304 | final_mask=final_mask 305 | 306 | os.makedirs(os.path.join(out_folder, "mask"), exist_ok=True) 307 | save_mask(os.path.join(out_folder, "mask/{:0>8}_photo.png".format(ref_view)), photo_mask) 308 | save_mask(os.path.join(out_folder, "mask/{:0>8}_geo.png".format(ref_view)), geo_mask) 309 | save_mask(os.path.join(out_folder, "mask/{:0>8}_final.png".format(ref_view)), final_mask) 310 | 311 | print("processing {}, ref-view{:0>2}, 
photo/geo/final-mask:{}/{}/{}".format(scan_folder, ref_view, 312 | photo_mask.mean(), 313 | geo_mask.mean(), final_mask.mean())) 314 | 315 | if args.display: 316 | import cv2 317 | cv2.imshow('ref_img', ref_img[:, :, ::-1]) 318 | cv2.imshow('ref_depth', ref_depth_est / 800) 319 | cv2.imshow('ref_depth * photo_mask', ref_depth_est * photo_mask.astype(np.float32) / 800) 320 | cv2.imshow('ref_depth * geo_mask', ref_depth_est * geo_mask.astype(np.float32) / 800) 321 | cv2.imshow('ref_depth * mask', ref_depth_est * final_mask.astype(np.float32) / 800) 322 | cv2.waitKey(0) 323 | 324 | height, width = depth_est_averaged.shape[:2] 325 | x, y = np.meshgrid(np.arange(0, width), np.arange(0, height)) 326 | 327 | valid_points = final_mask 328 | print("valid_points", valid_points.mean()) 329 | x, y, depth = x[valid_points], y[valid_points], depth_est_averaged[valid_points] 330 | 331 | #color = ref_img[1:-16:4, 1::4, :][valid_points] # hardcoded for DTU dataset 332 | 333 | if num_stage == 1: 334 | color = ref_img[1::4, 1::4, :][valid_points] 335 | elif num_stage == 2: 336 | color = ref_img[1::2, 1::2, :][valid_points] 337 | elif num_stage == 3: 338 | color = ref_img[valid_points] 339 | 340 | xyz_ref = np.matmul(np.linalg.inv(ref_intrinsics), 341 | np.vstack((x, y, np.ones_like(x))) * depth) 342 | xyz_world = np.matmul(np.linalg.inv(ref_extrinsics), 343 | np.vstack((xyz_ref, np.ones_like(x))))[:3] 344 | vertexs.append(xyz_world.transpose((1, 0))) 345 | vertex_colors.append((color * 255).astype(np.uint8)) 346 | 347 | 348 | vertexs = np.concatenate(vertexs, axis=0) 349 | vertex_colors = np.concatenate(vertex_colors, axis=0) 350 | vertexs = np.array([tuple(v) for v in vertexs], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) 351 | vertex_colors = np.array([tuple(v) for v in vertex_colors], dtype=[('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) 352 | 353 | vertex_all = np.empty(len(vertexs), vertexs.dtype.descr + vertex_colors.dtype.descr) 354 | for prop in vertexs.dtype.names: 355 | vertex_all[prop] = vertexs[prop] 356 | for prop in vertex_colors.dtype.names: 357 | vertex_all[prop] = vertex_colors[prop] 358 | 359 | el = PlyElement.describe(vertex_all, 'vertex') 360 | PlyData([el]).write(plyfilename) 361 | print("saving the final model to", plyfilename) 362 | 363 | 364 | 365 | def pcd_filter_worker(args, scan,suffix=None): 366 | if args.testlist != "all": 367 | scan_id = int(scan[4:]) 368 | save_name = 'mvsnet{:0>3}_l3.ply'.format(scan_id) 369 | else: 370 | save_name = '{}.ply'.format(scan) 371 | pair_folder = os.path.join(args.datapath, scan) 372 | scan_folder = os.path.join(args.outdir, scan) 373 | out_folder = os.path.join(args.outdir, scan) 374 | 375 | if scan in tank_cfg.scenes: 376 | scene_cfg = getattr(tank_cfg, scan) 377 | args.conf = scene_cfg.conf 378 | 379 | filter_depth(args, pair_folder, scan_folder, out_folder, os.path.join(args.outdir,"pcd" ,save_name)) 380 | 381 | def init_worker(): 382 | ''' 383 | Catch Ctrl+C signal to termiante workers 384 | ''' 385 | signal.signal(signal.SIGINT, signal.SIG_IGN) 386 | 387 | 388 | def pcd_filter(args, testlist, number_worker,suffix=None): 389 | if not os.path.exists(os.path.join(args.outdir,"pcd")): 390 | os.makedirs(os.path.join(args.outdir,"pcd")) 391 | 392 | if number_worker>1: 393 | partial_func = partial(pcd_filter_worker, args) 394 | 395 | p = Pool(number_worker, init_worker) 396 | try: 397 | p.map(partial_func, testlist) 398 | except KeyboardInterrupt: 399 | print("....\nCaught KeyboardInterrupt, terminating workers") 400 | p.terminate() 401 | else: 
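            # try/except/else: this branch runs only when Pool.map finishes without a
            # KeyboardInterrupt, so the worker pool is shut down cleanly and joined.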
402 | p.close() 403 | p.join() 404 | else: 405 | if suffix is not None: 406 | if not os.path.exists(os.path.join(args.outdir,"pcd_{}".format(suffix))): 407 | os.makedirs(os.path.join(args.outdir,"pcd_{}".format(suffix))) 408 | 409 | for scene in testlist: 410 | pcd_filter_worker(args,scene,suffix=suffix) -------------------------------------------------------------------------------- /filter/tank_test_config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | 3 | tank_cfg = CN() 4 | 5 | tank_cfg.META_ARC = "tank_test_config" 6 | 7 | tank_cfg.scenes = ("Family", "Francis", "Horse", "Lighthouse", "M60", "Panther", "Playground", "Train", "Auditorium", "Ballroom", "Courtroom", "Museum", "Palace", "Temple") 8 | 9 | tank_cfg.Family = CN() 10 | tank_cfg.Family.max_h = 1080 11 | tank_cfg.Family.max_w = 2048 12 | 13 | tank_cfg.Family.conf = [0.6, 0.7, 0.95] 14 | 15 | tank_cfg.Francis = CN() 16 | tank_cfg.Francis.max_h = 1080 17 | tank_cfg.Francis.max_w = 2048 18 | tank_cfg.Francis.conf = [0.6, 0.7, 0.95] 19 | 20 | 21 | tank_cfg.Horse = CN() 22 | tank_cfg.Horse.max_h = 1080 23 | tank_cfg.Horse.max_w = 2048 24 | tank_cfg.Horse.conf = [0.15, 0.4, 0.8] 25 | 26 | 27 | tank_cfg.Lighthouse = CN() 28 | tank_cfg.Lighthouse.max_h = 1080 29 | tank_cfg.Lighthouse.max_w = 2048 30 | tank_cfg.Lighthouse.conf = [0.6, 0.7, 0.95] 31 | 32 | 33 | tank_cfg.M60 = CN() 34 | tank_cfg.M60.max_h = 1080 35 | tank_cfg.M60.max_w = 2048 36 | tank_cfg.M60.conf = [0.35, 0.65, 0.85] 37 | 38 | tank_cfg.Panther = CN() 39 | tank_cfg.Panther.max_h = 896 40 | tank_cfg.Panther.max_w = 1216 41 | tank_cfg.Panther.conf = [0.1, 0.15, 0.9] 42 | 43 | tank_cfg.Playground = CN() 44 | tank_cfg.Playground.max_h = 1080 45 | tank_cfg.Playground.max_w = 2048 46 | tank_cfg.Playground.conf = [0.6, 0.75, 0.95] 47 | 48 | tank_cfg.Train = CN() 49 | tank_cfg.Train.max_h = 1080 50 | tank_cfg.Train.max_w = 2048 51 | tank_cfg.Train.conf = [0.3, 0.6, 0.95] 52 | 53 | tank_cfg.Auditorium = CN() 54 | tank_cfg.Auditorium.max_h = 1080 55 | tank_cfg.Auditorium.max_w = 2048 56 | tank_cfg.Auditorium.conf = [0.0, 0.0, 0.4] 57 | 58 | tank_cfg.Ballroom = CN() 59 | tank_cfg.Ballroom.max_h = 1080 60 | tank_cfg.Ballroom.max_w = 2048 61 | tank_cfg.Ballroom.conf = [0.0, 0.0, 0.5] 62 | 63 | tank_cfg.Courtroom = CN() 64 | tank_cfg.Courtroom.max_h = 1080 65 | tank_cfg.Courtroom.max_w = 2048 66 | tank_cfg.Courtroom.conf = [0.0, 0.0, 0.4] 67 | 68 | tank_cfg.Museum = CN() 69 | tank_cfg.Museum.max_h = 1080 70 | tank_cfg.Museum.max_w = 2048 71 | tank_cfg.Museum.conf = [0.0, 0.0, 0.7] 72 | 73 | tank_cfg.Palace = CN() 74 | tank_cfg.Palace.max_h = 1080 75 | tank_cfg.Palace.max_w = 2048 76 | tank_cfg.Palace.conf = [0.0, 0.0, 0.7] 77 | 78 | tank_cfg.Temple = CN() 79 | tank_cfg.Temple.max_h = 1080 80 | tank_cfg.Temple.max_w = 2048 81 | tank_cfg.Temple.conf = [0.0, 0.0, 0.4] -------------------------------------------------------------------------------- /loss.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | 5 | def mvs_loss(inputs, depth_gt_ms, mask_ms, mode, **kwargs): 6 | depth_loss_weights = kwargs.get("dlossw", [1.0 for k in inputs.keys() if "stage" in k]) 7 | total_loss = torch.tensor(0.0, dtype=torch.float32, device=mask_ms["stage1"].device, requires_grad=False) 8 | for (stage_inputs, stage_key) in [(inputs[k], k) for k in inputs.keys() if "stage" in k]: 9 | prob_volume = 
stage_inputs["prob_volume"] if "global_volume" not in stage_inputs else stage_inputs["global_volume"]# (b, d, h, w) 10 | depth_values = stage_inputs["depth_values"] if "depth_values_new" not in stage_inputs else stage_inputs["depth_values_new"]# (b, d, h, w) 11 | interval = stage_inputs["interval"] # float 12 | depth_gt = depth_gt_ms[stage_key] # (b, h, w) 13 | mask = mask_ms[stage_key] 14 | 15 | mask = mask > 0.5 16 | 17 | stage_idx = int(stage_key.replace("stage", "")) - 1 18 | stage_weight = depth_loss_weights[stage_idx] 19 | 20 | 21 | if mode == "regression": 22 | 23 | depth_sub_plus=stage_inputs["depth_sub_plus"] 24 | depth_sup_plus_small,depth_sup_plus_huge=depth_sub_plus.split([2,2],dim=1) 25 | loss_depth=2*regression_loss(depth_sup_plus_small, depth_gt.unsqueeze(1).expand_as(depth_sup_plus_small), mask.unsqueeze(1).expand_as(depth_sup_plus_small),torch.ones_like(depth_sup_plus_small)*stage_weight)\ 26 | +2*regression_loss(depth_sup_plus_huge, depth_gt.unsqueeze(1).expand_as(depth_sup_plus_huge), mask.unsqueeze(1).expand_as(depth_sup_plus_huge),torch.ones_like(depth_sup_plus_huge)*stage_weight) 27 | 28 | 29 | var_gt=torch.where((depth_sub_plus[:,0]-depth_gt).abs()<(depth_sub_plus[:,1]-depth_gt).abs(),(depth_sub_plus[:,1]-depth_gt).abs(),(depth_sub_plus[:,0]-depth_gt).abs()) 30 | loss_var_small=regression_loss((depth_sub_plus[:,0]-depth_sub_plus[:,1]).abs(), var_gt, mask,torch.ones_like(var_gt)*stage_weight) 31 | 32 | var_gt=torch.where((depth_sub_plus[:,2]-depth_gt).abs()<(depth_sub_plus[:,3]-depth_gt).abs(),(depth_sub_plus[:,3]-depth_gt).abs(),(depth_sub_plus[:,2]-depth_gt).abs()) 33 | loss_var_huge=regression_loss((depth_sub_plus[:,2]-depth_sub_plus[:,3]).abs(), var_gt, mask,torch.ones_like(var_gt)*stage_weight) 34 | 35 | 36 | coors=torch.stack( 37 | [item.unsqueeze(0).expand_as(depth_sub_plus[:,0]) for item in torch.meshgrid(*[torch.arange(0, s) for s in depth_sub_plus[:,0].shape[-2:]])], 38 | axis=-1).to(depth_sub_plus[:,0].device) 39 | coor_mask=((coors[:,:,:,0]%2==0)&(coors[:,:,:,1]%2==0))|((coors[:,:,:,0]%2==1)&(coors[:,:,:,1]%2==1))# 40 | 41 | small_min,small_max=depth_sup_plus_small.min(1)[0],depth_sup_plus_small.max(1)[0] 42 | huge_min,huge_max=depth_sup_plus_huge.min(1)[0],depth_sup_plus_huge.max(1)[0] 43 | 44 | loss_m=Monte_Carlo_sampling_loss(torch.where(coor_mask,small_min,small_max),depth_gt,mask,torch.ones_like(depth_gt)*stage_weight,mode="center",regress_fn=regression_loss)+\ 45 | Monte_Carlo_sampling_loss(torch.where(~coor_mask,small_min,small_max),depth_gt,mask,torch.ones_like(depth_gt)*stage_weight,mode="center",regress_fn=regression_loss)+\ 46 | Monte_Carlo_sampling_loss(torch.where(coor_mask,huge_min,huge_max),depth_gt,mask,torch.ones_like(depth_gt)*stage_weight,mode="center",regress_fn=regression_loss)+\ 47 | Monte_Carlo_sampling_loss(torch.where(~coor_mask,huge_min,huge_max),depth_gt,mask,torch.ones_like(depth_gt)*stage_weight,mode="center",regress_fn=regression_loss) 48 | 49 | total_loss+=(loss_depth+loss_var_small+loss_var_huge+loss_m) 50 | 51 | 52 | ###refine*********************** 53 | 54 | depth_sub_plus=stage_inputs["depth_sub_plus_refine"] 55 | depth_sup_plus_small,depth_sup_plus_huge=depth_sub_plus.split([2,2],dim=1) 56 | loss_depth=2*regression_loss(depth_sup_plus_small, depth_gt.unsqueeze(1).expand_as(depth_sup_plus_small), mask.unsqueeze(1).expand_as(depth_sup_plus_small),torch.ones_like(depth_sup_plus_small)*stage_weight)\ 57 | +2*regression_loss(depth_sup_plus_huge, depth_gt.unsqueeze(1).expand_as(depth_sup_plus_huge), 
mask.unsqueeze(1).expand_as(depth_sup_plus_huge),torch.ones_like(depth_sup_plus_huge)*stage_weight) 58 | 59 | var_gt=torch.where((depth_sub_plus[:,0]-depth_gt).abs()<(depth_sub_plus[:,1]-depth_gt).abs(),(depth_sub_plus[:,1]-depth_gt).abs(),(depth_sub_plus[:,0]-depth_gt).abs()) 60 | loss_var_small=regression_loss((depth_sub_plus[:,0]-depth_sub_plus[:,1]).abs(), var_gt, mask,torch.ones_like(var_gt)*stage_weight) 61 | 62 | var_gt=torch.where((depth_sub_plus[:,2]-depth_gt).abs()<(depth_sub_plus[:,3]-depth_gt).abs(),(depth_sub_plus[:,3]-depth_gt).abs(),(depth_sub_plus[:,2]-depth_gt).abs()) 63 | loss_var_huge=regression_loss((depth_sub_plus[:,2]-depth_sub_plus[:,3]).abs(), var_gt, mask,torch.ones_like(var_gt)*stage_weight) 64 | 65 | 66 | coors=torch.stack( 67 | [item.unsqueeze(0).expand_as(depth_sub_plus[:,0]) for item in torch.meshgrid(*[torch.arange(0, s) for s in depth_sub_plus[:,0].shape[-2:]])], 68 | axis=-1).to(depth_sub_plus[:,0].device) 69 | coor_mask=((coors[:,:,:,0]%2==0)&(coors[:,:,:,1]%2==0))|((coors[:,:,:,0]%2==1)&(coors[:,:,:,1]%2==1))# 70 | 71 | small_min,small_max=depth_sup_plus_small.min(1)[0],depth_sup_plus_small.max(1)[0] 72 | huge_min,huge_max=depth_sup_plus_huge.min(1)[0],depth_sup_plus_huge.max(1)[0] 73 | 74 | loss_m=Monte_Carlo_sampling_loss(torch.where(coor_mask,small_min,small_max),depth_gt,mask,torch.ones_like(depth_gt)*stage_weight,mode="center",regress_fn=regression_loss)+\ 75 | Monte_Carlo_sampling_loss(torch.where(~coor_mask,small_min,small_max),depth_gt,mask,torch.ones_like(depth_gt)*stage_weight,mode="center",regress_fn=regression_loss)+\ 76 | Monte_Carlo_sampling_loss(torch.where(coor_mask,huge_min,huge_max),depth_gt,mask,torch.ones_like(depth_gt)*stage_weight,mode="center",regress_fn=regression_loss)+\ 77 | Monte_Carlo_sampling_loss(torch.where(~coor_mask,huge_min,huge_max),depth_gt,mask,torch.ones_like(depth_gt)*stage_weight,mode="center",regress_fn=regression_loss) 78 | 79 | 80 | total_loss+=(loss_depth+loss_var_small+loss_var_huge+loss_m) 81 | 82 | elif mode == "classification": 83 | # loss = classification_loss(prob_volume, depth_values, interval, depth_gt, mask, stage_weight) 84 | loss = classification_loss_1(prob_volume, depth_values, interval, depth_gt, mask, stage_weight) 85 | 86 | total_loss += loss 87 | elif mode =="gfocal": 88 | fl_gamas = [2, 1, 0] 89 | fl_alphas = [0.75, 0.5, 0.25] 90 | gamma = fl_gamas[stage_idx] 91 | alpha = fl_alphas[stage_idx] 92 | loss = gfocal_loss(prob_volume, depth_values, interval, depth_gt, mask, stage_weight, gamma, alpha) 93 | total_loss += loss 94 | elif mode == "unification": 95 | fl_gamas = [2, 1, 0] 96 | fl_alphas = [0.75, 0.5, 0.25] 97 | gamma = fl_gamas[stage_idx] 98 | alpha = fl_alphas[stage_idx] 99 | loss = unified_focal_loss(prob_volume, depth_values, interval, depth_gt, mask, stage_weight, gamma, alpha) 100 | total_loss += loss 101 | else: 102 | raise NotImplementedError("Only support regression, classification and unification!") 103 | 104 | return total_loss 105 | 106 | def Monte_Carlo_sampling_loss(depth_est, depth_gt, mask, weight,mode="center",reflect=False,regress_fn=None): 107 | 108 | batch,height, width= depth_gt.shape 109 | 110 | if mode=="center": 111 | x_offset,y_offset=0.5*torch.ones((batch,height-1, width-1)),0.5*torch.ones((batch,height-1, width-1)) 112 | else: 113 | x_offset,y_offset=torch.rand(batch,height-1, width-1),torch.rand((batch,height-1, width-1)) 114 | 115 | x_offset,y_offset=x_offset.to(depth_gt.device),y_offset.to(depth_gt.device) 116 | 117 | y, x = torch.meshgrid([torch.arange(0, 
height-1, dtype=torch.float32, device=depth_gt.device), 118 | torch.arange(0, width-1, dtype=torch.float32, device=depth_gt.device)]) 119 | y, x = y.contiguous().unsqueeze(0).repeat(batch,1,1)+y_offset, x.contiguous().unsqueeze(0).repeat(batch,1,1)+x_offset 120 | x=x/((width - 1) / 2) - 1 121 | y=y/((height - 1) / 2) - 1 122 | 123 | grid=torch.stack((x, y), dim=3) 124 | 125 | sampled_gt=F.grid_sample(depth_gt.unsqueeze(1), grid, mode='bilinear',padding_mode='zeros',align_corners=True).type(torch.float32) 126 | sampled_est=F.grid_sample(depth_est.unsqueeze(1), grid, mode='bilinear',padding_mode='zeros',align_corners=True).type(torch.float32) 127 | sampled_weight=F.grid_sample(weight.unsqueeze(1), grid, mode='bilinear',padding_mode='zeros',align_corners=True).type(torch.float32) 128 | sampled_mask=F.grid_sample(mask.float().unsqueeze(1), grid, mode='bilinear',padding_mode='zeros',align_corners=True).type(torch.float32) 129 | #mask!=1 mean there is zero depth\ 130 | sampled_mask=sampled_mask>=1. 131 | 132 | 133 | if reflect== False: 134 | # loss = F.smooth_l1_loss(sampled_est[sampled_mask], sampled_gt[sampled_mask], reduction='mean') 135 | loss =regress_fn(sampled_est, sampled_gt, sampled_mask,sampled_weight) 136 | 137 | else: 138 | with torch.no_grad(): 139 | err=depth_est-depth_gt 140 | kernel = torch.ones((2,2)).unsqueeze(0).unsqueeze(0).to(depth_gt.device) 141 | kernel_weight = torch.nn.Parameter(data=kernel, requires_grad=False) 142 | 143 | up_sum=F.conv2d((err.unsqueeze(1)>0).float(),kernel_weight) 144 | dn_sum=F.conv2d((err.unsqueeze(1)<0).float(),kernel_weight) 145 | 146 | reflect_weight=torch.where((up_sum==4.)|(dn_sum==4.),2*torch.ones_like(sampled_gt),torch.ones_like(sampled_gt)) 147 | # reflect_weight=reflect_weight[sampled_mask] 148 | 149 | loss = F.smooth_l1_loss((reflect_weight*sampled_est)[sampled_mask], (reflect_weight*sampled_gt)[sampled_mask], reduction='mean') 150 | 151 | # loss = loss* weight 152 | 153 | 154 | 155 | return loss 156 | def regression_loss(depth_est, depth_gt, mask, weight): 157 | loss = F.smooth_l1_loss(depth_est[mask], depth_gt[mask], reduction='none') 158 | loss = (loss* weight[mask]).mean() 159 | return loss 160 | 161 | def binary_cross_entropy_with_logits(input, target, weight=None, size_average=None, 162 | reduce=False, reduction='elementwise_mean', pos_weight=None,mask=None): 163 | 164 | if not (target.size() == input.size()): 165 | raise ValueError("Target size ({}) must be the same as input size ({})".format(target.size(), input.size())) 166 | 167 | max_val = (-input).clamp(min=0) 168 | 169 | if pos_weight is None: 170 | ce_loss = input - input * target + max_val + ((-max_val).exp() + (-input - max_val).exp()).log() 171 | else: 172 | log_weight = 1 + (pos_weight - 1) * target 173 | ce_loss = input - input * target + log_weight * (max_val + ((-max_val).exp() + (-input - max_val).exp()).log()) 174 | 175 | 176 | if weight is not None: 177 | ce_loss = ce_loss * weight 178 | if mask is not None: 179 | 180 | ce_loss = ce_loss[mask.unsqueeze(1).repeat(1,ce_loss.shape[1],1,1)] 181 | 182 | if reduction == False: 183 | return ce_loss 184 | elif reduction == 'elementwise_mean': 185 | return ce_loss.mean() 186 | else: 187 | return ce_loss.sum() 188 | def classification_loss_1(prob_volume, depth_values, interval, depth_gt, mask, weight): 189 | depth_gt_volume = depth_gt.unsqueeze(1).expand_as(depth_values) # (b, d, h, w) 190 | 191 | gt_index_volume = ( 192 | ((depth_values - interval / 2) <= depth_gt_volume).float() * ((depth_values + interval / 2) > 
depth_gt_volume).float()) 193 | 194 | pos_w = (depth_gt_volume.shape[1]-1)/1.0 # pos_w = neg_num / pos_num 195 | loss = binary_cross_entropy_with_logits(prob_volume, gt_index_volume, pos_weight=pos_w,mask=mask,weight=weight) 196 | return loss 197 | 198 | def classification_loss(prob_volume, depth_values, interval, depth_gt, mask, weight): 199 | depth_gt_volume = depth_gt.unsqueeze(1).expand_as(depth_values) # (b, d, h, w) 200 | 201 | gt_index_volume = ( 202 | ((depth_values - interval / 2) <= depth_gt_volume).float() * ((depth_values + interval / 2) > depth_gt_volume).float()) 203 | 204 | NEAR_0 = 1e-4 # Prevent overflow 205 | prob_volume = torch.where(prob_volume <= 0.0, torch.zeros_like(prob_volume) + NEAR_0, prob_volume) 206 | 207 | loss = -torch.sum(gt_index_volume * torch.log(prob_volume), dim=1)[mask].mean() 208 | loss = loss * weight 209 | return loss 210 | 211 | 212 | def gfocal_loss(prob_volume, depth_values, interval, depth_gt, mask, weight, gamma, alpha): 213 | depth_gt_volume = depth_gt.unsqueeze(1).expand_as(depth_values) # (b, d, h, w) 214 | 215 | gt_index_volume = ((depth_values <= depth_gt_volume) * ((depth_values + interval) > depth_gt_volume)) #gt 在哪一个value里面而已 216 | gt_index_volume=gt_index_volume.float() 217 | 218 | pos_weight = (gt_index_volume - prob_volume).abs() 219 | neg_weight = prob_volume 220 | focal_weight = (pos_weight.pow(gamma)) * (gt_index_volume > 0.0).float()\ 221 | + alpha*(neg_weight.pow(gamma)) * (gt_index_volume <= 0.0).float() 222 | 223 | NEAR_0 = 1e-4 # Prevent overflow 224 | prob_volume = torch.where(prob_volume <= 0.0, torch.zeros_like(prob_volume) + NEAR_0, prob_volume) 225 | 226 | mask = mask.unsqueeze(1).expand_as(depth_values).float() # b d h w 227 | loss = (F.binary_cross_entropy(prob_volume, gt_index_volume, reduction="none") * focal_weight * mask).sum() / mask.sum() # all 228 | loss = loss * weight 229 | return loss 230 | 231 | def unified_step_focal_loss(prob_volume, depth_values, interval, depth_gt, mask, weight, gamma, alpha): 232 | depth_gt_volume = depth_gt.unsqueeze(1).expand_as(depth_values) # (b, d, h, w) 233 | 234 | gt_index_volume = (depth_values-depth_gt_volume).abs()<=interval 235 | 236 | gt_unity_index_volume = torch.zeros_like(prob_volume, requires_grad=False) 237 | gt_unity_index_volume[gt_index_volume] = 1.0 - (depth_gt_volume[gt_index_volume] - depth_values[gt_index_volume]).abs() / interval 238 | 239 | gt_unity, _ = torch.max(gt_unity_index_volume, dim=1, keepdim=True) 240 | gt_unity = torch.where(gt_unity > 0.0, gt_unity, torch.ones_like(gt_unity)) # (b, 1, h, w) 241 | pos_weight = (sigmoid((gt_unity - prob_volume).abs() / gt_unity, base=5) - 0.5) * 4 + 1 # [1, 3] 242 | neg_weight = (sigmoid(prob_volume / gt_unity, base=5) - 0.5) * 2 # [0, 1] 243 | focal_weight = (gt_unity_index_volume > 0.0).float() + alpha * (gt_unity_index_volume <= 0.0).float() 244 | 245 | mask = mask.unsqueeze(1).expand_as(depth_values).float() 246 | # offset=prob_volume-1 247 | # torch.where 248 | prob_volume=prob_volume/(prob_volume.max()) 249 | loss = (F.binary_cross_entropy(prob_volume, gt_unity_index_volume, reduction="none") * focal_weight * mask).sum() / mask.sum() 250 | loss = loss * weight 251 | return loss 252 | def unified_focal_loss(prob_volume, depth_values, interval, depth_gt, mask, weight, gamma, alpha): 253 | depth_gt_volume = depth_gt.unsqueeze(1).expand_as(depth_values) # (b, d, h, w) 254 | 255 | gt_index_volume = ((depth_values <= depth_gt_volume) * ((depth_values + interval) > depth_gt_volume)) 256 | 257 | gt_unity_index_volume = 
torch.zeros_like(prob_volume, requires_grad=False) 258 | gt_unity_index_volume[gt_index_volume] = 1.0 - (depth_gt_volume[gt_index_volume] - depth_values[gt_index_volume]) / interval 259 | 260 | gt_unity, _ = torch.max(gt_unity_index_volume, dim=1, keepdim=True) 261 | gt_unity = torch.where(gt_unity > 0.0, gt_unity, torch.ones_like(gt_unity)) # (b, 1, h, w) 262 | pos_weight = (sigmoid((gt_unity - prob_volume).abs() / gt_unity, base=5) - 0.5) * 4 + 1 # [1, 3] 263 | neg_weight = (sigmoid(prob_volume / gt_unity, base=5) - 0.5) * 2 # [0, 1] 264 | focal_weight = pos_weight.pow(gamma) * (gt_unity_index_volume > 0.0).float() + alpha * neg_weight.pow(gamma) * ( 265 | gt_unity_index_volume <= 0.0).float() 266 | 267 | mask = mask.unsqueeze(1).expand_as(depth_values).float() 268 | loss = (F.binary_cross_entropy(prob_volume, gt_unity_index_volume, reduction="none") * focal_weight * mask).sum() / mask.sum() 269 | loss = loss * weight 270 | return loss 271 | def sigmoid(x, base=2.71828): 272 | return 1 / (1 + torch.pow(base, -x)) 273 | def entropy_loss(prob_volume, depth_gt, mask, depth_value, return_prob_map=False): 274 | # from AA 275 | mask_true = mask 276 | valid_pixel_num = torch.sum(mask_true, dim=[1,2]) + 1e-6 277 | 278 | shape = depth_gt.shape # B,H,W 279 | 280 | depth_num = depth_value.shape[1] 281 | if len(depth_value.shape) < 3: 282 | depth_value_mat = depth_value.repeat(shape[1], shape[2], 1, 1).permute(2,3,0,1) # B,N,H,W 283 | else: 284 | depth_value_mat = depth_value 285 | 286 | gt_index_image = torch.argmin(torch.abs(depth_value_mat-depth_gt.unsqueeze(1)), dim=1) 287 | temp=gt_index_image 288 | 289 | gt_index_image = torch.mul(mask_true, gt_index_image.type(torch.float)) 290 | gt_index_image = torch.round(gt_index_image).type(torch.long).unsqueeze(1) # B, 1, H, W 291 | 292 | # gt index map -> gt one hot volume (B x 1 x H x W ) 293 | gt_index_volume = torch.zeros(shape[0], depth_num, shape[1], shape[2]).type(mask_true.type()).scatter_(1, gt_index_image, 1) 294 | 295 | # cross entropy image (B x D X H x W) 296 | cross_entropy_image = -torch.sum(gt_index_volume * torch.log(prob_volume + 1e-6), dim=1).squeeze(1) # B, 1, H, W 297 | 298 | # masked cross entropy loss 299 | masked_cross_entropy_image = torch.mul(mask_true, cross_entropy_image) # valid pixel 300 | masked_cross_entropy = torch.sum(masked_cross_entropy_image, dim=[1, 2]) 301 | 302 | masked_cross_entropy = torch.mean(masked_cross_entropy / valid_pixel_num) # Origin use sum : aggregate with batch 303 | # winner-take-all depth map 304 | wta_index_map = torch.argmax(prob_volume, dim=1, keepdim=True).type(torch.long).squeeze(1) 305 | 306 | return masked_cross_entropy 307 | 308 | def entropy_loss_expand(prob_volume, depth_gt, mask, depth_value, return_prob_map=False): 309 | # from AA 310 | 311 | 312 | shape = depth_gt.shape # B,H,W 313 | depth_gt=depth_gt.unsqueeze(1).repeat(1,3,1,1).view(-1,shape[-2],shape[-1]) 314 | mask=mask.unsqueeze(1).repeat(1,3,1,1).view(-1,shape[-2],shape[-1]) 315 | shape = depth_gt.shape 316 | 317 | mask_true = mask 318 | valid_pixel_num = torch.sum(mask_true, dim=[1,2]) + 1e-6 319 | 320 | depth_num = depth_value.shape[1] 321 | if len(depth_value.shape) < 3: 322 | depth_value_mat = depth_value.repeat(shape[1], shape[2], 1, 1).permute(2,3,0,1) # B,N,H,W 323 | else: 324 | depth_value_mat = depth_value 325 | 326 | gt_index_image = torch.argmin(torch.abs(depth_value_mat-depth_gt.unsqueeze(1)), dim=1) 327 | temp=gt_index_image 328 | 329 | gt_index_image = torch.mul(mask_true, gt_index_image.type(torch.float)) 330 | 
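    # Round the per-pixel nearest-bin index back to an integer tensor; it is turned into a
    # one-hot volume with scatter_ below so a masked cross-entropy against log(prob_volume)
    # can be computed over the depth dimension.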
gt_index_image = torch.round(gt_index_image).type(torch.long).unsqueeze(1) # B, 1, H, W 331 | 332 | # gt index map -> gt one hot volume (B x 1 x H x W ) 333 | gt_index_volume = torch.zeros(shape[0], depth_num, shape[1], shape[2]).type(mask_true.type()).scatter_(1, gt_index_image, 1) 334 | 335 | # cross entropy image (B x D X H x W) 336 | cross_entropy_image = -torch.sum(gt_index_volume * torch.log(prob_volume + 1e-6), dim=1).squeeze(1) # B, 1, H, W 337 | 338 | # masked cross entropy loss 339 | masked_cross_entropy_image = torch.mul(mask_true, cross_entropy_image) # valid pixel 340 | masked_cross_entropy = torch.sum(masked_cross_entropy_image, dim=[1, 2]) 341 | 342 | masked_cross_entropy = torch.mean(masked_cross_entropy / valid_pixel_num) # Origin use sum : aggregate with batch 343 | # winner-take-all depth map 344 | wta_index_map = torch.argmax(prob_volume, dim=1, keepdim=True).type(torch.long).squeeze(1) 345 | 346 | return masked_cross_entropy 347 | -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import argparse,os 2 | from model import Model 3 | from tools import setup_seed 4 | parser = argparse.ArgumentParser(description="UniMVSNet args") 5 | 6 | # network 7 | parser.add_argument("--fea_mode", type=str, default="fpn", choices=["fpn", "unet","hrnet"]) 8 | parser.add_argument("--agg_mode", type=str, default="variance", choices=["variance", "adaptive","corr","corr_adaptive","corr_dynamic","corr_diff_att_adaptive","corr_dynamic_adaptive","corr_dynamic_adaptive_diff","corr_dynamic_adaptive_diff_att"]) 9 | # parser.add_argument("--agg_mode", type=str, default="variance", choices=["variance", "adaptive"]) 10 | 11 | parser.add_argument("--depth_mode", type=str, default="regression", choices=["regression", "classification", "unification","gfocal"]) 12 | parser.add_argument("--ndepths", type=int, nargs='+', default=[48, 32, 8]) 13 | parser.add_argument("--interval_ratio", type=float, nargs='+', default=[4, 2,1]) 14 | 15 | # dataset 16 | parser.add_argument("--datapath", type=str) 17 | parser.add_argument("--trainlist", type=str) 18 | parser.add_argument("--testlist", type=str) 19 | parser.add_argument("--dataset_name", type=str, default="dtu_yao", choices=["dtu_yao", "general_eval", "blendedmvs"]) 20 | parser.add_argument('--batch_size', type=int, default=1, help='train batch size') 21 | parser.add_argument('--numdepth', type=int, default=192, help='the number of depth values') 22 | parser.add_argument('--interval_scale', type=float, default=1.06, help='the number of depth values') 23 | parser.add_argument("--nviews", type=int, default=5) 24 | # only for train and eval 25 | parser.add_argument("--img_size", type=int, nargs='+', default=[512, 640]) 26 | parser.add_argument("--inverse_depth", action="store_true") 27 | 28 | # training and val 29 | parser.add_argument('--start_epoch', type=int, default=0) 30 | parser.add_argument('--epochs', type=int, default=16, help='number of epochs to train') 31 | parser.add_argument('--lr', type=float, default=0.001, help='learning rate') 32 | parser.add_argument('--wd', type=float, default=0.0, help='weight decay') 33 | parser.add_argument('--scheduler', type=str, default="steplr", choices=["steplr", "cosinelr"]) 34 | parser.add_argument('--warmup', type=float, default=0.2, help='warmup epochs') 35 | parser.add_argument('--milestones', type=float, nargs='+', default=[10, 12, 14], help='lr schedule') 36 | 
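# With the default --scheduler steplr, the learning rate is presumably multiplied by
# --lr_decay at each epoch listed in --milestones (see get_schedular in tools.py), and
# --warmup sets the warm-up length in epochs.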
parser.add_argument('--lr_decay', type=float, default=0.5, help='lr decay at every milestone') 37 | parser.add_argument('--resume', type=str, help='path to the resume model') 38 | parser.add_argument('--log_dir', type=str, help='path to the log dir') 39 | parser.add_argument('--dlossw', type=float, nargs='+', default=[0.5, 1.0, 2.0], help='depth loss weight for different stage') 40 | parser.add_argument('--eval_freq', type=int, default=1, help='eval freq') 41 | parser.add_argument('--summary_freq', type=int, default=50, help='print and summary frequency') 42 | parser.add_argument("--val", action="store_true") 43 | parser.add_argument("--sync_bn", action="store_true") 44 | parser.add_argument("--blendedmvs_finetune", action="store_true") 45 | 46 | # testing 47 | parser.add_argument("--test", action="store_true") 48 | parser.add_argument('--testpath_single_scene', help='testing data path for single scene') 49 | parser.add_argument('--outdir', default='./outputs', help='output dir') 50 | parser.add_argument('--num_view', type=int, default=5, help='num of view') 51 | parser.add_argument('--max_h', type=int, default=864, help='testing max h') 52 | parser.add_argument('--max_w', type=int, default=1152, help='testing max w') 53 | parser.add_argument('--fix_res', action='store_true', help='scene all using same res') 54 | parser.add_argument('--num_worker', type=int, default=4, help='depth_filer worker') 55 | parser.add_argument('--save_freq', type=int, default=20, help='save freq of local pcd') 56 | parser.add_argument('--filter_method', type=str, default='gipuma', choices=["gipuma","pcd_weight", "pcd", "dypcd"], help="filter method") 57 | parser.add_argument('--display', action='store_true', help='display depth images and masks') 58 | parser.add_argument("--winner_take_all_to_generate_depth", action="store_true") 59 | # pcd or dypcd 60 | parser.add_argument('--conf', type=float, nargs='+', default=[0.1, 0.15, 0.7], help='prob confidence, for pcd and dypcd') 61 | parser.add_argument('--thres_view', type=float, default=5, help='threshold of num view, only for pcd') 62 | # dypcd 63 | parser.add_argument('--dist_base', type=float, default=1 / 4) 64 | parser.add_argument('--rel_diff_base', type=float, default=1 / 1300) 65 | # gimupa 66 | parser.add_argument('--fusibile_exe_path', type=str, default='./fusibile/build/fusibile') 67 | parser.add_argument('--prob_threshold', type=float, default='0.3') 68 | parser.add_argument('--disp_threshold', type=float, default='0.25') 69 | parser.add_argument('--num_consistent', type=float, default='3') 70 | parser.add_argument('--build_stage', type=int, default='3') 71 | 72 | # visualization 73 | parser.add_argument("--vis", action="store_true") 74 | parser.add_argument('--depth_path', type=str) 75 | parser.add_argument('--depth_img_save_dir', type=str, default="./") 76 | 77 | 78 | # device and distributed 79 | parser.add_argument("--no_cuda", action="store_true") 80 | parser.add_argument("--local_rank", type=int, default=0) 81 | parser.add_argument('--dist-url', default='env://', help='url used to set up distributed training') 82 | args = parser.parse_args() 83 | 84 | if __name__ == '__main__': 85 | # os.environ['CUDA_VISIBLE_DEVICES']="2" 86 | 87 | model = Model(args) 88 | print(args) 89 | model.main() 90 | 91 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import imp 2 | import cv2 3 | import time 4 | import progressbar 5 | import 
torch.backends.cudnn as cudnn 6 | from tensorboardX import SummaryWriter 7 | from torch.nn.parallel import DistributedDataParallel 8 | import numpy as np 9 | from networks.mvsnet import MVSNet 10 | from datasets import get_loader 11 | from tools import * 12 | from loss import mvs_loss 13 | from datasets.data_io import save_pfm, read_pfm 14 | from filter import pcd_filter, dypcd_filter 15 | from filter.tank_test_config import tank_cfg 16 | from thop import profile,clever_format 17 | 18 | class Model: 19 | def __init__(self, args): 20 | 21 | if args.vis: 22 | self.args = args 23 | return 24 | 25 | cudnn.benchmark = True 26 | 27 | init_distributed_mode(args) 28 | 29 | self.args = args 30 | self.device = torch.device("cpu" if self.args.no_cuda or not torch.cuda.is_available() else "cuda") 31 | 32 | self.network = MVSNet(ndepths=args.ndepths, depth_interval_ratio=args.interval_ratio, fea_mode=args.fea_mode, 33 | agg_mode=args.agg_mode, depth_mode=args.depth_mode, 34 | winner_take_all_to_generate_depth=args.winner_take_all_to_generate_depth,inverse_depth=self.args.inverse_depth).to(self.device) 35 | 36 | if self.args.distributed and self.args.sync_bn: 37 | self.network = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.network) 38 | 39 | if not (self.args.val or self.args.test): 40 | 41 | self.optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, self.network.parameters()), lr=args.lr, 42 | weight_decay=args.wd) 43 | self.lr_scheduler = get_schedular(self.optimizer, self.args) 44 | self.train_loader, self.train_sampler = get_loader(args, args.datapath, args.trainlist, args.nviews, "train") 45 | 46 | if not self.args.test: 47 | self.loss_func = mvs_loss 48 | 49 | self.val_loader, self.val_sampler = get_loader(args, args.datapath, args.testlist, 5, "test",force_test=True) 50 | if is_main_process(): 51 | self.writer = SummaryWriter(log_dir=args.log_dir, comment="Record network info") 52 | 53 | self.network_without_ddp = self.network 54 | if self.args.distributed: 55 | self.network = DistributedDataParallel(self.network, device_ids=[self.args.local_rank]) 56 | # self.network = DistributedDataParallel(self.network, device_ids=[self.args.local_rank],find_unused_parameters=True) 57 | self.network_without_ddp = self.network.module 58 | 59 | if self.args.resume: 60 | checkpoint = torch.load(self.args.resume, map_location="cpu") 61 | if not (self.args.val or self.args.test or self.args.blendedmvs_finetune): 62 | self.args.start_epoch = checkpoint["epoch"] + 1 63 | self.optimizer.load_state_dict(checkpoint["optimizer"]) 64 | self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"]) 65 | import collections 66 | new_dic=collections.OrderedDict() 67 | for (key,values) in checkpoint["model"].items(): 68 | if "attn_mask" not in key: 69 | new_dic[key]=values 70 | self.network_without_ddp.load_state_dict(new_dic) 71 | 72 | self.blendmvs=('dataset_low_res' in args.datapath) 73 | 74 | def main(self): 75 | # print(self.args.test) 76 | if self.args.vis: 77 | self.visualization() 78 | return 79 | if self.args.val: 80 | self.validate() 81 | return 82 | if self.args.test: 83 | self.test() 84 | return 85 | self.train() 86 | 87 | def train(self): 88 | 89 | for epoch in range(self.args.start_epoch, self.args.start_epoch + self.args.epochs): 90 | if self.args.distributed: 91 | self.train_sampler.set_epoch(epoch) 92 | self.train_epoch(epoch) 93 | if is_main_process(): 94 | torch.save({ 95 | 'epoch': epoch, 96 | 'model': self.network_without_ddp.state_dict(), 97 | 'optimizer': self.optimizer.state_dict(), 98 | 
"lr_scheduler": self.lr_scheduler.state_dict()}, 99 | "{}/model_{:0>6}.ckpt".format(self.args.log_dir, epoch)) 100 | 101 | if (epoch % self.args.eval_freq == 0) or (epoch == self.args.epochs - 1): 102 | self.validate(epoch) 103 | torch.cuda.empty_cache() 104 | 105 | def train_epoch(self, epoch): 106 | self.network.train() 107 | 108 | if is_main_process(): 109 | pwidgets = [progressbar.Percentage(), " ", progressbar.Counter(format='%(value)02d/%(max_value)d'), " ", progressbar.Bar(), " ", 110 | progressbar.Timer(), ",", progressbar.ETA(), ",", progressbar.Variable('LR', width=1), ",", 111 | progressbar.Variable('Loss', width=1), ",", progressbar.Variable('Th2', width=1), ",", 112 | progressbar.Variable('Th4', width=1), ",", progressbar.Variable('Th8', width=1)] 113 | 114 | pbar = progressbar.ProgressBar(widgets=pwidgets, max_value=len(self.train_loader), 115 | prefix="Epoch {}/{}: ".format(epoch, self.args.epochs)).start() 116 | 117 | avg_scalars = DictAverageMeter() 118 | if not self.blendmvs: 119 | color_y=torch.zeros((3,512,640)).cuda() 120 | color_g=torch.zeros((3,512,640)).cuda() 121 | else: 122 | color_y=torch.zeros((3,576,768)).cuda() 123 | color_g=torch.zeros((3,576,768)).cuda() 124 | color_y[1]=1. 125 | color_y[0]=1. 126 | color_g[1]=1. 127 | for batch, data in enumerate(self.train_loader): 128 | data = tocuda(data) 129 | 130 | outputs = self.network(data["imgs"], data["proj_matrices"], data["depth_values"]) 131 | 132 | loss = self.loss_func(outputs, data["depth"], data["mask"], self.args.depth_mode, dlossw=self.args.dlossw) 133 | 134 | self.optimizer.zero_grad() 135 | loss.backward() 136 | self.optimizer.step() 137 | 138 | self.lr_scheduler.step(epoch + batch / len(self.train_loader)) 139 | 140 | gt_depth = data["depth"]["stage{}".format(len(self.args.ndepths))] 141 | mask = data["mask"]["stage{}".format(len(self.args.ndepths))] 142 | 143 | thres2mm = Thres_metrics(outputs["depth"], gt_depth, mask > 0.5, 2) 144 | thres4mm = Thres_metrics(outputs["depth"], gt_depth, mask > 0.5, 4) 145 | thres8mm = Thres_metrics(outputs["depth"], gt_depth, mask > 0.5, 8) 146 | abs_depth_error = AbsDepthError_metrics(outputs["depth"], gt_depth, mask > 0.5) 147 | 148 | 149 | scalar_outputs = {"loss": loss, 150 | "abs_depth_error": abs_depth_error, 151 | "thres2mm_error": thres2mm, 152 | "thres4mm_error": thres4mm, 153 | "thres8mm_error": thres8mm, 154 | } 155 | 156 | if "depth_refine" in outputs: 157 | thres2mm_r = Thres_metrics(outputs["depth_refine"], gt_depth, mask > 0.5, 2) 158 | thres4mm_r = Thres_metrics(outputs["depth_refine"], gt_depth, mask > 0.5, 4) 159 | thres8mm_r = Thres_metrics(outputs["depth_refine"], gt_depth, mask > 0.5, 8) 160 | abs_depth_error_r = AbsDepthError_metrics(outputs["depth_refine"], gt_depth, mask > 0.5) 161 | 162 | if "depth_refine" in outputs: 163 | scalar_outputs_r = { 164 | "abs_depth_error_r": abs_depth_error_r, 165 | "thres2mm_error_r": thres2mm_r, 166 | "thres4mm_error_r": thres4mm_r, 167 | "thres8mm_error_r": thres8mm_r} 168 | scalar_outputs={**scalar_outputs_r,**scalar_outputs} 169 | 170 | up_dn_mask=((mask>0)&((outputs["depth"] - gt_depth).abs()<2)).unsqueeze(1) 171 | up_dn=torch.where((outputs["depth"]>gt_depth).unsqueeze(1).repeat(1,3,1,1),color_g.unsqueeze(0).repeat(outputs["depth"].shape[0],1,1,1),color_y.unsqueeze(0).repeat(outputs["depth"].shape[0],1,1,1)) 172 | 173 | image_outputs = {"depth_est": outputs["depth"] * mask, 174 | "depth_est_nomask": outputs["depth"], 175 | "ref_img": data["imgs"][:, 0], 176 | "mask": mask, 177 | 
"conf":outputs["photometric_confidence"], 178 | 179 | "conf_09mask":(outputs["photometric_confidence"]>0.9).float(), 180 | "conf_05mask":(outputs["photometric_confidence"]>0.5).float(), 181 | "conf_01mask":(outputs["photometric_confidence"]>0.1).float(), 182 | "errormap": (outputs["depth"] - gt_depth).abs().clip(0,2) * mask, 183 | "up_dn":up_dn*up_dn_mask, 184 | "depth_gt": gt_depth, 185 | 186 | } 187 | if "depth_refine" in outputs: 188 | image_outputs_r={"depth_refine_est": outputs["depth_refine"] * mask, 189 | "depth_refine_est_nomask": outputs["depth_refine"], 190 | "errormap_refine": (outputs["depth_refine"] - gt_depth).abs() * mask 191 | } 192 | image_outputs={**image_outputs_r,**image_outputs} 193 | if self.args.distributed: 194 | scalar_outputs = reduce_scalar_outputs(scalar_outputs) 195 | 196 | scalar_outputs, image_outputs = tensor2float(scalar_outputs), tensor2numpy(image_outputs) 197 | 198 | if is_main_process(): 199 | avg_scalars.update(scalar_outputs) 200 | if batch >= len(self.train_loader) - 1: 201 | save_scalars(self.writer, 'train_avg', avg_scalars.avg_data, epoch) 202 | if (epoch * len(self.train_loader) + batch) % self.args.summary_freq == 0: 203 | save_scalars(self.writer, 'train', scalar_outputs, epoch * len(self.train_loader) + batch) 204 | save_images(self.writer, 'train', image_outputs, epoch * len(self.train_loader) + batch) 205 | 206 | pbar.update(batch, LR=self.optimizer.param_groups[0]['lr'], 207 | Loss="{:.3f}|{:.3f}".format(scalar_outputs["loss"], avg_scalars.avg_data["loss"]), 208 | Th2="{:.3f}|{:.3f}".format(scalar_outputs["thres2mm_error"], avg_scalars.avg_data["thres2mm_error"]), 209 | Th4="{:.3f}|{:.3f}".format(scalar_outputs["thres4mm_error"], avg_scalars.avg_data["thres4mm_error"]), 210 | Th8="{:.3f}|{:.3f}".format(scalar_outputs["thres8mm_error"], avg_scalars.avg_data["thres8mm_error"])) 211 | 212 | if is_main_process(): 213 | pbar.finish() 214 | 215 | @torch.no_grad() 216 | def validate(self, epoch=0): 217 | self.network.eval() 218 | 219 | if is_main_process(): 220 | pwidgets = [progressbar.Percentage(), " ", progressbar.Counter(format='%(value)02d/%(max_value)d'), " ", progressbar.Bar(), " ", 221 | progressbar.Timer(), ",", progressbar.ETA(), ",", progressbar.Variable('Loss', width=1), ",", 222 | progressbar.Variable('Th2', width=1), ",", progressbar.Variable('Th4', width=1), ",", 223 | progressbar.Variable('Th8', width=1)] 224 | pbar = progressbar.ProgressBar(widgets=pwidgets, max_value=len(self.val_loader), prefix="Val:").start() 225 | 226 | avg_scalars = DictAverageMeter() 227 | 228 | if not self.blendmvs: 229 | color_y=torch.zeros((3,512,640)).cuda() 230 | color_g=torch.zeros((3,512,640)).cuda() 231 | else: 232 | color_y=torch.zeros((3,576,768)).cuda() 233 | color_g=torch.zeros((3,576,768)).cuda() 234 | color_y[1]=1. 235 | color_y[0]=1. 236 | color_g[1]=1. 
237 | for batch, data in enumerate(self.val_loader): 238 | data = tocuda(data) 239 | 240 | outputs = self.network(data["imgs"], data["proj_matrices"], data["depth_values"]) 241 | 242 | loss = self.loss_func(outputs, data["depth"], data["mask"], self.args.depth_mode, dlossw=self.args.dlossw) 243 | 244 | gt_depth = data["depth"]["stage{}".format(len(self.args.ndepths))] 245 | mask = data["mask"]["stage{}".format(len(self.args.ndepths))] 246 | thres2mm = Thres_metrics(outputs["depth"], gt_depth, mask > 0.5, 2) 247 | thres4mm = Thres_metrics(outputs["depth"], gt_depth, mask > 0.5, 4) 248 | thres8mm = Thres_metrics(outputs["depth"], gt_depth, mask > 0.5, 8) 249 | abs_depth_error = AbsDepthError_metrics(outputs["depth"], gt_depth, mask > 0.5) 250 | 251 | 252 | 253 | scalar_outputs = {"loss": loss, 254 | "abs_depth_error": abs_depth_error, 255 | "thres2mm_error": thres2mm, 256 | "thres4mm_error": thres4mm, 257 | "thres8mm_error": thres8mm, 258 | 259 | } 260 | 261 | up_dn_mask=((mask>0)&((outputs["depth"] - gt_depth).abs()<2)).unsqueeze(1) 262 | up_dn=torch.where((outputs["depth"]>gt_depth).unsqueeze(1).repeat(1,3,1,1),color_g.unsqueeze(0).repeat(outputs["depth"].shape[0],1,1,1),color_y.unsqueeze(0).repeat(outputs["depth"].shape[0],1,1,1)) 263 | 264 | 265 | image_outputs = {"depth_est": outputs["depth"] * mask, 266 | "depth_est_nomask": outputs["depth"], 267 | "ref_img": data["imgs"][:, 0], 268 | "mask": mask, 269 | "conf":outputs["photometric_confidence"], 270 | "conf_09mask":(outputs["photometric_confidence"]>0.9).float(), 271 | "conf_05mask":(outputs["photometric_confidence"]>0.5).float(), 272 | "conf_01mask":(outputs["photometric_confidence"]>0.1).float(), 273 | "errormap": (outputs["depth"] - gt_depth).abs().clip(0,2) * mask, 274 | "up_dn":up_dn*up_dn_mask, 275 | "depth_gt": gt_depth, 276 | 277 | } 278 | 279 | if self.args.distributed: 280 | scalar_outputs = reduce_scalar_outputs(scalar_outputs) 281 | 282 | scalar_outputs, image_outputs = tensor2float(scalar_outputs), tensor2numpy(image_outputs) 283 | 284 | if is_main_process(): 285 | avg_scalars.update(scalar_outputs) 286 | if batch >= len(self.val_loader) - 1: 287 | save_scalars(self.writer, 'test_avg', avg_scalars.avg_data, epoch) 288 | if (epoch * len(self.val_loader) + batch) % self.args.summary_freq == 0: 289 | save_scalars(self.writer, 'test', scalar_outputs, epoch * len(self.val_loader) + batch) 290 | save_images(self.writer, 'test', image_outputs, epoch * len(self.val_loader) + batch) 291 | 292 | pbar.update(batch, 293 | Loss="{:.3f}|{:.3f}".format(scalar_outputs["loss"], avg_scalars.avg_data["loss"]), 294 | Th2="{:.3f}|{:.3f}".format(scalar_outputs["thres2mm_error"], avg_scalars.avg_data["thres2mm_error"]), 295 | Th4="{:.3f}|{:.3f}".format(scalar_outputs["thres4mm_error"], avg_scalars.avg_data["thres4mm_error"]), 296 | Th8="{:.3f}|{:.3f}".format(scalar_outputs["thres8mm_error"], avg_scalars.avg_data["thres8mm_error"])) 297 | 298 | if is_main_process(): 299 | pbar.finish() 300 | 301 | @torch.no_grad() 302 | def test(self,parameters=True): 303 | self.network.eval() 304 | 305 | if self.args.testpath_single_scene: 306 | self.args.datapath = os.path.dirname(self.args.testpath_single_scene) 307 | 308 | if self.args.testlist != "all": 309 | with open(self.args.testlist) as f: 310 | content = f.readlines() 311 | testlist = [line.rstrip() for line in content] 312 | 313 | else: 314 | # for tanks & temples or eth3d or colmap 315 | testlist = [e for e in os.listdir(self.args.datapath) if os.path.isdir(os.path.join(self.args.datapath, e))] \ 316 
| if not self.args.testpath_single_scene else [os.path.basename(self.args.testpath_single_scene)] 317 | 318 | print(testlist) 319 | 320 | num_stage = len(self.args.ndepths) 321 | 322 | # step1. save all the depth maps and the masks in outputs directory 323 | for scene in testlist: 324 | 325 | if scene in tank_cfg.scenes: 326 | scene_cfg = getattr(tank_cfg, scene) 327 | self.args.max_h = scene_cfg.max_h 328 | self.args.max_w = scene_cfg.max_w 329 | 330 | TestImgLoader, _ = get_loader(self.args, self.args.datapath, [scene], self.args.num_view, mode="test") 331 | 332 | for batch_idx, sample in enumerate(TestImgLoader): 333 | sample_cuda = tocuda(sample) 334 | start_time = time.time() 335 | 336 | outputs = self.network(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"]) 337 | 338 | if parameters==True: 339 | macs, params = profile(self.network, inputs=(sample_cuda["imgs"], sample_cuda["proj_matrices"], sample_cuda["depth_values"], )) 340 | 341 | print("params:{},macs:{}".format( params,macs)) 342 | parameters=False 343 | 344 | 345 | end_time = time.time() 346 | 347 | outputs = tensor2numpy(outputs) 348 | del sample_cuda 349 | filenames = sample["filename"] 350 | cams = sample["proj_matrices"]["stage{}".format(num_stage)].numpy() 351 | imgs = sample["imgs"].numpy() 352 | print('Iter {}/{}, Time:{} Res:{}'.format(batch_idx, len(TestImgLoader), end_time - start_time, imgs[0].shape)) 353 | 354 | # save depth maps and confidence maps 355 | for filename, cam, img, depth_est, photometric_confidence \ 356 | in zip(filenames, cams, imgs, outputs["depth"], 357 | outputs["photometric_confidence"] 358 | ): 359 | 360 | img = img[0] # ref view 361 | cam = cam[0] # ref cam 362 | depth_filename = os.path.join(self.args.outdir, filename.format('depth_est', '.pfm')) 363 | confidence_filename = os.path.join(self.args.outdir, filename.format('confidence', '.pfm')) 364 | cam_filename = os.path.join(self.args.outdir, filename.format('cams', '_cam.txt')) 365 | img_filename = os.path.join(self.args.outdir, filename.format('images', '.jpg')) 366 | # ply_filename = os.path.join(self.args.outdir, filename.format('ply_local', '.ply')) 367 | os.makedirs(depth_filename.rsplit('/', 1)[0], exist_ok=True) 368 | os.makedirs(confidence_filename.rsplit('/', 1)[0], exist_ok=True) 369 | os.makedirs(cam_filename.rsplit('/', 1)[0], exist_ok=True) 370 | os.makedirs(img_filename.rsplit('/', 1)[0], exist_ok=True) 371 | 372 | save_pfm(depth_filename, depth_est) 373 | 374 | 375 | save_pfm(confidence_filename, photometric_confidence) 376 | # save cams, img 377 | write_cam(cam_filename, cam) 378 | img = np.clip(np.transpose(img, (1, 2, 0)) * 255, 0, 255).astype(np.uint8) 379 | img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) 380 | cv2.imwrite(img_filename, img_bgr) 381 | 382 | 383 | 384 | torch.cuda.empty_cache() 385 | 386 | # step2. 
filter saved depth maps with photometric confidence maps and geometric constraints 387 | if self.args.filter_method == "pcd": 388 | pcd_filter(self.args, testlist, self.args.num_worker) 389 | elif self.args.filter_method == "dypcd": 390 | dypcd_filter(self.args, testlist, 1) 391 | 392 | @torch.no_grad() 393 | def visualization(self): 394 | 395 | import matplotlib as mpl 396 | import matplotlib.cm as cm 397 | from PIL import Image 398 | 399 | save_dir = self.args.depth_img_save_dir 400 | depth_path = self.args.depth_path 401 | 402 | depth, scale = read_pfm(depth_path) 403 | vmax = np.percentile(depth, 95) 404 | normalizer = mpl.colors.Normalize(vmin=depth.min(), vmax=vmax) 405 | mapper = cm.ScalarMappable(norm=normalizer, cmap='magma') 406 | colormapped_im = (mapper.to_rgba(depth)[:, :, :3] * 255).astype(np.uint8) 407 | im = Image.fromarray(colormapped_im) 408 | im.save(os.path.join(save_dir, "depth.png")) 409 | 410 | print("Successfully visualize!") 411 | 412 | -------------------------------------------------------------------------------- /networks/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import time 5 | import sys 6 | from torch.autograd import Variable 7 | 8 | sys.path.append("..") 9 | 10 | 11 | def init_bn(module): 12 | if module.weight is not None: 13 | nn.init.ones_(module.weight) 14 | if module.bias is not None: 15 | nn.init.zeros_(module.bias) 16 | return 17 | 18 | 19 | def init_uniform(module, init_method): 20 | if module.weight is not None: 21 | if init_method == "kaiming": 22 | nn.init.kaiming_uniform_(module.weight) 23 | elif init_method == "xavier": 24 | nn.init.xavier_uniform_(module.weight) 25 | return 26 | 27 | 28 | class Conv2d(nn.Module): 29 | """Applies a 2D convolution (optionally with batch normalization and relu activation) 30 | over an input signal composed of several input planes. 31 | 32 | Attributes: 33 | conv (nn.Module): convolution module 34 | bn (nn.Module): batch normalization module 35 | relu (bool): whether to activate by relu 36 | 37 | Notes: 38 | Default momentum for batch normalization is set to be 0.01, 39 | 40 | """ 41 | 42 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 43 | relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs): 44 | super(Conv2d, self).__init__() 45 | 46 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, 47 | bias=(not bn), **kwargs) 48 | self.kernel_size = kernel_size 49 | self.stride = stride 50 | self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None 51 | # self.bn = nn.GroupNorm(8, out_channels) if bn else None 52 | self.relu = relu 53 | 54 | # assert init_method in ["kaiming", "xavier"] 55 | # self.init_weights(init_method) 56 | 57 | def forward(self, x): 58 | x = self.conv(x) 59 | if self.bn is not None: 60 | x = self.bn(x) 61 | if self.relu: 62 | x = F.relu(x, inplace=True) 63 | return x 64 | 65 | def init_weights(self, init_method): 66 | """default initialization""" 67 | init_uniform(self.conv, init_method) 68 | if self.bn is not None: 69 | init_bn(self.bn) 70 | 71 | 72 | class Deconv2d(nn.Module): 73 | """Applies a 2D deconvolution (optionally with batch normalization and relu activation) 74 | over an input signal composed of several input planes. 
75 | 76 | Attributes: 77 | conv (nn.Module): convolution module 78 | bn (nn.Module): batch normalization module 79 | relu (bool): whether to activate by relu 80 | 81 | Notes: 82 | Default momentum for batch normalization is set to be 0.01, 83 | 84 | """ 85 | 86 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 87 | relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs): 88 | super(Deconv2d, self).__init__() 89 | self.out_channels = out_channels 90 | assert stride in [1, 2] 91 | self.stride = stride 92 | 93 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, 94 | bias=(not bn), **kwargs) 95 | self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None 96 | # self.bn = nn.GroupNorm(8, out_channels) if bn else None 97 | self.relu = relu 98 | 99 | # assert init_method in ["kaiming", "xavier"] 100 | # self.init_weights(init_method) 101 | 102 | def forward(self, x): 103 | y = self.conv(x) 104 | if self.stride == 2: 105 | h, w = list(x.size())[2:] 106 | y = y[:, :, :2 * h, :2 * w].contiguous() 107 | if self.bn is not None: 108 | x = self.bn(y) 109 | if self.relu: 110 | x = F.relu(x, inplace=True) 111 | return x 112 | 113 | def init_weights(self, init_method): 114 | """default initialization""" 115 | init_uniform(self.conv, init_method) 116 | if self.bn is not None: 117 | init_bn(self.bn) 118 | 119 | 120 | class Conv3d(nn.Module): 121 | """Applies a 3D convolution (optionally with batch normalization and relu activation) 122 | over an input signal composed of several input planes. 123 | 124 | Attributes: 125 | conv (nn.Module): convolution module 126 | bn (nn.Module): batch normalization module 127 | relu (bool): whether to activate by relu 128 | 129 | Notes: 130 | Default momentum for batch normalization is set to be 0.01, 131 | 132 | """ 133 | 134 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 135 | relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs): 136 | super(Conv3d, self).__init__() 137 | self.out_channels = out_channels 138 | self.kernel_size = kernel_size 139 | assert stride in [1, 2] 140 | self.stride = stride 141 | 142 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, 143 | bias=(not bn), **kwargs) 144 | self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None 145 | # self.bn = nn.GroupNorm(8, out_channels) if bn else None 146 | self.relu = relu 147 | 148 | # assert init_method in ["kaiming", "xavier"] 149 | # self.init_weights(init_method) 150 | 151 | def forward(self, x): 152 | x = self.conv(x) 153 | if self.bn is not None: 154 | x = self.bn(x) 155 | if self.relu: 156 | x = F.relu(x, inplace=True) 157 | return x 158 | 159 | def init_weights(self, init_method): 160 | """default initialization""" 161 | init_uniform(self.conv, init_method) 162 | if self.bn is not None: 163 | init_bn(self.bn) 164 | 165 | 166 | class Deconv3d(nn.Module): 167 | """Applies a 3D deconvolution (optionally with batch normalization and relu activation) 168 | over an input signal composed of several input planes. 
169 | 170 | Attributes: 171 | conv (nn.Module): convolution module 172 | bn (nn.Module): batch normalization module 173 | relu (bool): whether to activate by relu 174 | 175 | Notes: 176 | Default momentum for batch normalization is set to be 0.01, 177 | 178 | """ 179 | 180 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, 181 | relu=True, bn=True, bn_momentum=0.1, init_method="xavier", **kwargs): 182 | super(Deconv3d, self).__init__() 183 | self.out_channels = out_channels 184 | assert stride in [1, 2] 185 | self.stride = stride 186 | 187 | self.conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, 188 | bias=(not bn), **kwargs) 189 | self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None 190 | # self.bn = nn.GroupNorm(8, out_channels) if bn else None 191 | self.relu = relu 192 | 193 | # assert init_method in ["kaiming", "xavier"] 194 | # self.init_weights(init_method) 195 | 196 | def forward(self, x): 197 | y = self.conv(x) 198 | if self.bn is not None: 199 | x = self.bn(y) 200 | if self.relu: 201 | x = F.relu(x, inplace=True) 202 | return x 203 | 204 | def init_weights(self, init_method): 205 | """default initialization""" 206 | init_uniform(self.conv, init_method) 207 | if self.bn is not None: 208 | init_bn(self.bn) 209 | 210 | 211 | 212 | def homo_warping(src_fea, src_proj, ref_proj, depth_values): 213 | # src_fea: [B, C, H, W] 214 | # src_proj: [B, 4, 4] 215 | # ref_proj: [B, 4, 4] 216 | # depth_values: [B, Ndepth] o [B, Ndepth, H, W] 217 | # out: [B, C, Ndepth, H, W] 218 | batch, channels = src_fea.shape[0], src_fea.shape[1] 219 | num_depth = depth_values.shape[1] 220 | height, width = src_fea.shape[2], src_fea.shape[3] 221 | 222 | with torch.no_grad(): 223 | proj = torch.matmul(src_proj, torch.inverse(ref_proj)) 224 | rot = proj[:, :3, :3] # [B,3,3] 225 | trans = proj[:, :3, 3:4] # [B,3,1] 226 | 227 | y, x = torch.meshgrid([torch.arange(0, height, dtype=torch.float32, device=src_fea.device), 228 | torch.arange(0, width, dtype=torch.float32, device=src_fea.device)]) 229 | y, x = y.contiguous(), x.contiguous() 230 | y, x = y.view(height * width), x.view(height * width) 231 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] 232 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W] 233 | rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W] 234 | rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.view(batch, 1, num_depth, 235 | -1) # [B, 3, Ndepth, H*W] 236 | proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W] 237 | proj_xyz[:, 2:3][proj_xyz[:, 2:3] == 0] += 0.00001 # NAN BUG, not on dtu, but on blendedmvs 238 | 239 | proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W] 240 | proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1 241 | proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1 242 | proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2] 243 | grid = proj_xy 244 | 245 | # warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2), mode='bilinear', 246 | # padding_mode='zeros').type(torch.float32) 247 | warped_src_fea = F.grid_sample(src_fea, grid.view(batch, num_depth * height, width, 2), mode='bilinear', 248 | padding_mode='zeros',align_corners=True).type(torch.float32) 249 | warped_src_fea = warped_src_fea.view(batch, channels, num_depth, height, width) 250 | 251 | return warped_src_fea,grid.view(batch, num_depth,height, width, 2) 
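`homo_warping` above is the differentiable plane-sweep warp: for each depth hypothesis, reference-view pixels are lifted to 3D, re-projected into the source view via `src_proj · ref_proj⁻¹`, and the source features are bilinearly sampled at the projected locations, producing a `(B, C, Ndepth, H, W)` volume plus the sampling grid. A quick sanity-check sketch (toy shapes, the import path, and the identity-projection check are illustrative assumptions): with identical reference and source projections the homography is the identity, so every depth slice should reproduce the source features.

```python
import torch
from networks.module import homo_warping  # assumes the repo root is on PYTHONPATH

src_fea = torch.rand(1, 8, 16, 20)        # (B, C, H, W) toy feature map
proj = torch.eye(4).unsqueeze(0)          # identical ref/src projection matrices
depths = torch.full((1, 4, 16, 20), 5.0)  # (B, Ndepth, H, W) depth hypotheses

warped, grid = homo_warping(src_fea, proj, proj, depths)
print(warped.shape)  # torch.Size([1, 8, 4, 16, 20])
print(grid.shape)    # torch.Size([1, 4, 16, 20, 2])

# identity homography: each depth slice matches the source features
# (up to bilinear-interpolation round-off)
assert torch.allclose(warped[:, :, 0], src_fea, atol=1e-4)
```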
252 | 253 | class DeConv2dFuse(nn.Module): 254 | def __init__(self, in_channels, out_channels, kernel_size, relu=True, bn=True, 255 | bn_momentum=0.1): 256 | super(DeConv2dFuse, self).__init__() 257 | 258 | self.deconv = Deconv2d(in_channels, out_channels, kernel_size, stride=2, padding=1, output_padding=1, 259 | bn=True, relu=relu, bn_momentum=bn_momentum) 260 | 261 | self.conv = Conv2d(2 * out_channels, out_channels, kernel_size, stride=1, padding=1, 262 | bn=bn, relu=relu, bn_momentum=bn_momentum) 263 | 264 | # assert init_method in ["kaiming", "xavier"] 265 | # self.init_weights(init_method) 266 | 267 | def forward(self, x_pre, x): 268 | x = self.deconv(x) 269 | x = torch.cat((x, x_pre), dim=1) 270 | x = self.conv(x) 271 | return x 272 | 273 | 274 | class FeatureNet(nn.Module): 275 | def __init__(self, base_channels, num_stage=3, stride=4, mode="fpn",layernorm=False): 276 | super(FeatureNet, self).__init__() 277 | assert mode in ["unet", "fpn"], print("mode must be in 'unet', 'fpn', but get:{}".format(mode)) 278 | self.mode = mode 279 | self.stride = stride 280 | self.base_channels = base_channels 281 | self.num_stage = num_stage 282 | self.layernorm=layernorm 283 | self.conv0 = nn.Sequential( 284 | Conv2d(3, base_channels, 3, 1, padding=1), 285 | Conv2d(base_channels, base_channels, 3, 1, padding=1), 286 | ) 287 | 288 | self.conv1 = nn.Sequential( 289 | Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2), 290 | Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), 291 | Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), 292 | ) 293 | 294 | self.conv2 = nn.Sequential( 295 | Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2), 296 | Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1), 297 | Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1), 298 | ) 299 | 300 | 301 | self.out1 = nn.Conv2d(base_channels * 4, base_channels * 4 *2, 1, bias=False) 302 | self.out_channels = [4 * base_channels] 303 | final_chs = base_channels * 4 304 | 305 | self.inner1 = nn.Conv2d(base_channels * 2, final_chs, 1, bias=True) 306 | self.inner2 = nn.Conv2d(base_channels * 1, final_chs, 1, bias=True) 307 | 308 | self.out2 = nn.Conv2d(final_chs, base_channels * 2 *2, 3, padding=1, bias=False) 309 | self.out3 = nn.Conv2d(final_chs, base_channels *2, 3, padding=1, bias=False) 310 | self.out_channels.append(base_channels * 2) 311 | self.out_channels.append(base_channels) 312 | 313 | 314 | 315 | 316 | def forward(self, x): 317 | conv0 = self.conv0(x) 318 | conv1 = self.conv1(conv0) 319 | conv2 = self.conv2(conv1) 320 | 321 | intra_feat = conv2 322 | outputs = {} 323 | 324 | out = self.out1(intra_feat) 325 | # outputs["stage1"] = out 326 | outputs["stage1"],outputs["stage1_c"]= out.split([out.shape[1]//2,out.shape[1]//2],1) 327 | 328 | intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner1(conv1) 329 | out = self.out2(intra_feat) 330 | # outputs["stage2"] = out 331 | outputs["stage2"],outputs["stage2_c"]= out.split([out.shape[1]//2,out.shape[1]//2],1) 332 | 333 | intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="nearest") + self.inner2(conv0) 334 | out = self.out3(intra_feat) 335 | # outputs["stage3"] = out 336 | outputs["stage3"],outputs["stage3_c"]= out.split([out.shape[1]//2,out.shape[1]//2],1) 337 | 338 | 339 | 340 | return outputs 341 | 342 | class CostRegNet(nn.Module): 343 | def __init__(self, in_channels, base_channels,stage=0): 344 | super(CostRegNet, self).__init__() 345 | 
self.cosR_small=CostRegNet_part(in_channels, base_channels,stage=0) 346 | self.cosR_huge=CostRegNet_part(in_channels, base_channels,stage=0) 347 | def forward(self, x): 348 | results=torch.cat((self.cosR_small(x),self.cosR_huge(x)),axis=1) 349 | return results 350 | class CostRegNet_refine(nn.Module): 351 | def __init__(self, in_channels, base_channels,stage=0): 352 | super(CostRegNet_refine, self).__init__() 353 | self.cosR_small=CostRegNet_part_refine(in_channels, base_channels,stage=0) 354 | self.cosR_huge=CostRegNet_part_refine(in_channels, base_channels,stage=0) 355 | def forward(self, x): 356 | results=torch.cat((self.cosR_small(x),self.cosR_huge(x)),axis=1) 357 | return results 358 | class CostRegNet_part(nn.Module): 359 | def __init__(self, in_channels, base_channels,stage=0): 360 | super(CostRegNet_part, self).__init__() 361 | self.conv0 = Conv3d(in_channels, base_channels, padding=1) 362 | 363 | self.conv1 = Conv3d(base_channels, base_channels * 2, stride=2, padding=1) 364 | self.conv2 = Conv3d(base_channels * 2, base_channels * 2, padding=1) 365 | 366 | self.conv3 = Conv3d(base_channels * 2, base_channels * 4, stride=2, padding=1) 367 | self.conv4 = Conv3d(base_channels * 4, base_channels * 4, padding=1) 368 | 369 | self.conv5 = Conv3d(base_channels * 4, base_channels * 8, stride=2, padding=1) 370 | self.conv6 = Conv3d(base_channels * 8, base_channels * 8, padding=1) 371 | 372 | self.conv7 = Deconv3d(base_channels * 8, base_channels * 4, stride=2, padding=1, output_padding=1) 373 | 374 | self.conv9 = Deconv3d(base_channels * 4, base_channels * 2, stride=2, padding=1, output_padding=1) 375 | 376 | self.conv11 = Deconv3d(base_channels * 2, base_channels * 1, stride=2, padding=1, output_padding=1) 377 | 378 | # self.prob = nn.Conv3d(base_channels, 1 if stage==0 else 2, 3, stride=1, padding=1, bias=False) 379 | self.prob = nn.Conv3d(base_channels, 2, 3, stride=1, padding=1, bias=False) 380 | 381 | 382 | # for m in self.modules(): 383 | # if isinstance(m, nn.Conv2d): 384 | # nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') 385 | # elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): 386 | # nn.init.constant_(m.weight, 1) 387 | # nn.init.constant_(m.bias, 0) 388 | 389 | def forward(self, x): 390 | conv0 = self.conv0(x) 391 | conv2 = self.conv2(self.conv1(conv0)) 392 | conv4 = self.conv4(self.conv3(conv2)) 393 | x = self.conv6(self.conv5(conv4)) 394 | x = conv4 + self.conv7(x) 395 | x = conv2 + self.conv9(x) 396 | x = conv0 + self.conv11(x) 397 | x = self.prob(x) 398 | return x 399 | 400 | class CostRegNet_part_refine(nn.Module): 401 | def __init__(self, in_channels, base_channels,stage=0): 402 | super(CostRegNet_part_refine, self).__init__() 403 | self.conv0 = Conv3d(in_channels, base_channels, padding=1) 404 | 405 | self.conv1 = Conv3d(base_channels, base_channels * 2, stride=2, padding=1) 406 | self.conv2 = Conv3d(base_channels * 2, base_channels * 2, padding=1) 407 | 408 | self.conv3 = Conv3d(base_channels * 2, base_channels * 4, stride=2, padding=1) 409 | self.conv4 = Conv3d(base_channels * 4, base_channels * 4, padding=1) 410 | 411 | self.conv5 = Conv2d(base_channels * 4, base_channels * 8,3, stride=2, padding=1) 412 | self.conv6 = Conv2d(base_channels * 8, base_channels * 8,3, padding=1) 413 | 414 | self.conv7 = Deconv2d(base_channels * 8, base_channels * 4,3, stride=2, padding=1, output_padding=1) 415 | 416 | self.conv9 = Deconv3d(base_channels * 4, base_channels * 2, stride=2, padding=1, output_padding=1) 417 | 418 | self.conv11 = Deconv3d(base_channels * 
2, base_channels * 1, stride=2, padding=1, output_padding=1) 419 | 420 | # self.prob = nn.Conv3d(base_channels, 1 if stage==0 else 2, 3, stride=1, padding=1, bias=False) 421 | self.prob = nn.Conv3d(base_channels, 2, 3, stride=1, padding=1, bias=False) 422 | 423 | 424 | 425 | 426 | def forward(self, x,stage=0): 427 | conv0 = self.conv0(x) 428 | conv2 = self.conv2(self.conv1(conv0)) 429 | conv4 = self.conv4(self.conv3(conv2)).squeeze(2) 430 | x=self.conv6(self.conv5(conv4)) 431 | x=conv4+self.conv7(x) 432 | x=x.unsqueeze(2) 433 | x = conv2 + self.conv9(x) 434 | x = conv0 + self.conv11(x) 435 | x = self.prob(x) 436 | return x 437 | class AggWeightNetVolume(nn.Module): 438 | def __init__(self, in_channels=32,hid_channels=1,out_channels=1,relu=True): 439 | super(AggWeightNetVolume, self).__init__() 440 | self.w_net = nn.Sequential( 441 | Conv3d(in_channels, hid_channels, kernel_size=1, stride=1, padding=0,relu=relu), 442 | Conv3d(hid_channels, out_channels, kernel_size=1, stride=1, padding=0,relu=relu) 443 | ) 444 | 445 | def forward(self, x): 446 | """ 447 | :param x: (b, c, d, h, w) 448 | :return: (b, 1, d, h, w) 449 | """ 450 | w = self.w_net(x) 451 | return w 452 | 453 | 454 | def depth_regression(p, depth_values,axis=1): 455 | if depth_values.dim() <= 2: 456 | # print("regression dim <= 2") 457 | depth_values = depth_values.view(*depth_values.shape, 1, 1) 458 | depth = torch.sum(p * depth_values, axis=axis) 459 | 460 | return depth 461 | 462 | 463 | def winner_take_all(prob_volume, depth_values): 464 | """ 465 | :param prob_volume: (b, d, h, w) 466 | :param depth_values: (b, d, h, w) 467 | :return: (b, h, w) 468 | """ 469 | _, idx = torch.max(prob_volume, dim=1, keepdim=True) 470 | depth = torch.gather(depth_values, 1, idx).squeeze(1) 471 | return depth 472 | 473 | 474 | 475 | 476 | def get_cur_depth_range_samples_n(last_depth, ndepth, depth_inteval_pixel): 477 | # cur_depth: (B, H, W) 478 | # return depth_range_values: (B, D, H, W) 479 | last_depth_min = (last_depth - (ndepth+2) / 2 * depth_inteval_pixel) # (B, H, W) 480 | last_depth_max = (last_depth + (ndepth-2) / 2 * depth_inteval_pixel) 481 | # cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel).clamp(min=0.0) #(B, H, W) 482 | # cur_depth_max = (cur_depth_min + (ndepth - 1) * depth_inteval_pixel).clamp(max=max_depth) 483 | 484 | new_interval = (last_depth_max - last_depth_min) / (ndepth - 1) # (B, H, W) 485 | 486 | depth_range_samples = last_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=last_depth.device, 487 | dtype=last_depth.dtype, 488 | requires_grad=False).reshape(1, -1, 1, 489 | 1) * new_interval.unsqueeze(1)) 490 | 491 | return depth_range_samples, (ndepth * depth_inteval_pixel) / (ndepth - 1) 492 | def get_cur_depth_range_samples_p(last_depth, ndepth, depth_inteval_pixel): 493 | # cur_depth: (B, H, W) 494 | # return depth_range_values: (B, D, H, W) 495 | last_depth_min = (last_depth - (ndepth-2) / 2 * depth_inteval_pixel) # (B, H, W) 496 | last_depth_max = (last_depth + (ndepth+2) / 2 * depth_inteval_pixel) 497 | # cur_depth_min = (cur_depth - ndepth / 2 * depth_inteval_pixel).clamp(min=0.0) #(B, H, W) 498 | # cur_depth_max = (cur_depth_min + (ndepth - 1) * depth_inteval_pixel).clamp(max=max_depth) 499 | 500 | new_interval = (last_depth_max - last_depth_min) / (ndepth - 1) # (B, H, W) 501 | 502 | depth_range_samples = last_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=last_depth.device, 503 | dtype=last_depth.dtype, 504 | requires_grad=False).reshape(1, -1, 1, 505 | 1) * 
new_interval.unsqueeze(1)) 506 | 507 | return depth_range_samples, (ndepth * depth_inteval_pixel) / (ndepth - 1) 508 | 509 | def get_cur_depth_range_samples_inverse(last_depth, ndepth, depth_inteval_pixel): 510 | # cur_depth: (B, H, W) 511 | # return depth_range_values: (B, D, H, W) 512 | last_depth_min = (last_depth - ndepth / 2 * depth_inteval_pixel) # (B, H, W) 513 | last_depth_max = (last_depth + ndepth / 2 * depth_inteval_pixel) 514 | inverse_min=1/last_depth_min 515 | inverse_max=1/last_depth_max 516 | new_interval=(inverse_max-inverse_min)/(ndepth-1) 517 | inverse_depth_range_samples=inverse_min.unsqueeze(1)+(torch.arange(0, ndepth, device=last_depth.device, 518 | dtype=last_depth.dtype, 519 | requires_grad=False).reshape(1, -1, 1, 520 | 1) * new_interval.unsqueeze(1)) 521 | 522 | depth_range_samples=1/inverse_depth_range_samples 523 | return depth_range_samples, (ndepth * depth_inteval_pixel) / (ndepth - 1) 524 | 525 | def get_cur_depth_range_samples_inverse_p(last_depth, ndepth, depth_inteval_pixel): 526 | # cur_depth: (B, H, W) 527 | # return depth_range_values: (B, D, H, W) 528 | last_depth_min = (last_depth - (ndepth-2) / 2 * depth_inteval_pixel) # (B, H, W) 529 | last_depth_max = (last_depth + (ndepth+2) / 2 * depth_inteval_pixel) 530 | inverse_min=1/last_depth_min 531 | inverse_max=1/last_depth_max 532 | new_interval=(inverse_max-inverse_min)/(ndepth-1) 533 | inverse_depth_range_samples=inverse_min.unsqueeze(1)+(torch.arange(0, ndepth, device=last_depth.device, 534 | dtype=last_depth.dtype, 535 | requires_grad=False).reshape(1, -1, 1, 536 | 1) * new_interval.unsqueeze(1)) 537 | 538 | depth_range_samples=1/inverse_depth_range_samples 539 | return depth_range_samples, (ndepth * depth_inteval_pixel) / (ndepth - 1) 540 | def get_cur_depth_range_samples_inverse_n(last_depth, ndepth, depth_inteval_pixel): 541 | # cur_depth: (B, H, W) 542 | # return depth_range_values: (B, D, H, W) 543 | last_depth_min = (last_depth - (ndepth+2) / 2 * depth_inteval_pixel) # (B, H, W) 544 | last_depth_max = (last_depth + (ndepth-2) / 2 * depth_inteval_pixel) 545 | inverse_min=1/last_depth_min 546 | inverse_max=1/last_depth_max 547 | new_interval=(inverse_max-inverse_min)/(ndepth-1) 548 | inverse_depth_range_samples=inverse_min.unsqueeze(1)+(torch.arange(0, ndepth, device=last_depth.device, 549 | dtype=last_depth.dtype, 550 | requires_grad=False).reshape(1, -1, 1, 551 | 1) * new_interval.unsqueeze(1)) 552 | 553 | depth_range_samples=1/inverse_depth_range_samples 554 | return depth_range_samples, (ndepth * depth_inteval_pixel) / (ndepth - 1) 555 | 556 | def get_depth_range_samples(last_depth, ndepth, depth_inteval_pixel, shape=None,next_depth_inteval_pixel=None,inverse=False): 557 | # cur_depth: (B, H, W) or (B, D) 558 | # return depth_range_samples: (B, D, H, W) 559 | if not inverse: 560 | if last_depth.dim() == 2: 561 | last_depth_min = last_depth[:, 0] # (B,) 562 | last_depth_max = last_depth[:, -1] 563 | new_interval = (last_depth_max - last_depth_min) / (ndepth - 1) # (B, ) 564 | stage_interval = new_interval[0] 565 | 566 | depth_range_samples = last_depth_min.unsqueeze(1) + (torch.arange(0, ndepth, device=last_depth.device, dtype=last_depth.dtype, 567 | requires_grad=False).reshape(1, -1) * new_interval.unsqueeze( 568 | 1)) # (B, D) 569 | 570 | # (B, D, H, W) 571 | depth_range_samples = depth_range_samples.unsqueeze(-1).unsqueeze(-1).repeat(1, 1, shape[0], shape[1]) 572 | 573 | coors=torch.stack( 574 | [item.expand_as(depth_range_samples) \ 575 | for item in torch.meshgrid(*[torch.arange(0, s) 
for s in depth_range_samples.shape[-2:]])], 576 | axis=-1).to(depth_range_samples.device) 577 | mask=((coors[:,:,:,:,0]%2==0)&(coors[:,:,:,:,1]%2==0))|((coors[:,:,:,:,0]%2==1)&(coors[:,:,:,:,1]%2==1)) 578 | 579 | depth_range_samples=torch.where(mask,depth_range_samples-stage_interval,depth_range_samples+stage_interval) 580 | # depth_range_samples=torch.ma 581 | 582 | else: 583 | 584 | depth_range_samples_n, stage_interval = get_cur_depth_range_samples_n(last_depth, ndepth, depth_inteval_pixel) 585 | depth_range_samples_p, stage_interval = get_cur_depth_range_samples_p(last_depth, ndepth, depth_inteval_pixel) 586 | coors=torch.stack( 587 | [item.expand_as(depth_range_samples_n) \ 588 | for item in torch.meshgrid(*[torch.arange(0, s) for s in last_depth.shape[-2:]])], 589 | axis=-1).to(depth_range_samples_n.device) 590 | mask=((coors[:,:,:,:,0]%2==0)&(coors[:,:,:,:,1]%2==0))|((coors[:,:,:,:,0]%2==1)&(coors[:,:,:,:,1]%2==1)) 591 | depth_range_samples=torch.where(mask,\ 592 | depth_range_samples_n, 593 | depth_range_samples_p 594 | ) 595 | 596 | return depth_range_samples, stage_interval 597 | else: 598 | if last_depth.dim() == 2: 599 | 600 | last_depth_min = last_depth[:, 0] # (B,) 601 | last_depth_max = last_depth[:, -1] 602 | new_interval = (last_depth_max - last_depth_min) / (ndepth - 1) # (B, ) 603 | stage_interval = new_interval[0] 604 | 605 | 606 | last_depth_min = last_depth[:, 0]-stage_interval 607 | last_depth_max = last_depth[:, -1]-stage_interval 608 | 609 | new_interval = (last_depth_max - last_depth_min) / (ndepth - 1) # (B, ) 610 | stage_interval = new_interval[0] 611 | depth_values=[] 612 | for bg,end in zip(last_depth_min,last_depth_max): 613 | depth_values.append(torch.linspace(1 / bg , 1 / end, ndepth,device=last_depth.device)) 614 | depth_values=torch.stack(depth_values,dim=0) 615 | depth_range_samples_n = (1 / depth_values).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, shape[0], shape[1]) 616 | 617 | last_depth_min = last_depth[:, 0]+stage_interval 618 | last_depth_max = last_depth[:, -1]+stage_interval 619 | 620 | new_interval = (last_depth_max - last_depth_min) / (ndepth - 1) # (B, ) 621 | stage_interval = new_interval[0] 622 | depth_values=[] 623 | for bg,end in zip(last_depth_min,last_depth_max): 624 | depth_values.append(torch.linspace(1 / bg , 1 / end, ndepth,device=last_depth.device)) 625 | depth_values=torch.stack(depth_values,dim=0) 626 | depth_range_samples_p = (1 / depth_values).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, shape[0], shape[1]) 627 | 628 | coors=torch.stack( 629 | [item.expand_as(depth_range_samples_p) \ 630 | for item in torch.meshgrid(*[torch.arange(0, s) for s in depth_range_samples_p.shape[-2:]])], 631 | axis=-1).to(depth_range_samples_p.device) 632 | mask=((coors[:,:,:,:,0]%2==0)&(coors[:,:,:,:,1]%2==0))|((coors[:,:,:,:,0]%2==1)&(coors[:,:,:,:,1]%2==1)) 633 | 634 | depth_range_samples=torch.where(mask,depth_range_samples_n,depth_range_samples_p) 635 | 636 | else: 637 | # depth_range_samples, stage_interval = get_cur_depth_range_samples_inverse(last_depth, ndepth, depth_inteval_pixel) 638 | depth_range_samples_n, stage_interval = get_cur_depth_range_samples_inverse_n(last_depth, ndepth, depth_inteval_pixel) 639 | depth_range_samples_p, stage_interval = get_cur_depth_range_samples_inverse_p(last_depth, ndepth, depth_inteval_pixel) 640 | coors=torch.stack( 641 | [item.expand_as(depth_range_samples_n) \ 642 | for item in torch.meshgrid(*[torch.arange(0, s) for s in last_depth.shape[-2:]])], 643 | axis=-1).to(depth_range_samples_n.device) 644 | 
mask=((coors[:,:,:,:,0]%2==0)&(coors[:,:,:,:,1]%2==0))|((coors[:,:,:,:,0]%2==1)&(coors[:,:,:,:,1]%2==1)) 645 | depth_range_samples=torch.where(mask,\ 646 | depth_range_samples_n, 647 | depth_range_samples_p 648 | ) 649 | return depth_range_samples.float(), stage_interval.float() 650 | -------------------------------------------------------------------------------- /networks/mvsnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | 6 | from .module import * 7 | 8 | Align_Corners_Range = False 9 | 10 | 11 | class DepthNet(nn.Module): 12 | def __init__(self, mode="regression"): 13 | super(DepthNet, self).__init__() 14 | 15 | def forward(self, cost_reg, depth_values, num_depth, interval, prob_volume_init=None,stage=0): 16 | 17 | 18 | 19 | prob_volume = F.softmax(cost_reg, dim=2) # (b,2, ndepth, h, w) 20 | depth_sub_plus = depth_regression(prob_volume, depth_values=depth_values.unsqueeze(1),axis=2) # (b, h, w) 21 | 22 | depth_sup_plus_small,depth_sup_plus_huge=depth_sub_plus.split([2,2],dim=1) 23 | 24 | 25 | small_min,small_max=depth_sup_plus_small.min(1)[0],depth_sup_plus_small.max(1)[0] 26 | huge_min,huge_max=depth_sup_plus_huge.min(1)[0],depth_sup_plus_huge.max(1)[0] 27 | huge_min_d,huge_max_d=2*huge_min-huge_max,2*huge_max-huge_min 28 | small_min_d,small_max_d=2*small_min-small_max,2*small_max-small_min 29 | 30 | coors=torch.stack( 31 | [item.unsqueeze(0).expand_as(depth_sub_plus[:,0]) for item in torch.meshgrid(*[torch.arange(0, s) for s in depth_sub_plus[:,0].shape[-2:]])], 32 | axis=-1).to(depth_sub_plus[:,0].device) 33 | mask_00=((coors[:,:,:,0]%4==0)&(coors[:,:,:,1]%2==0)) 34 | mask_01=((coors[:,:,:,0]%4==0)&(coors[:,:,:,1]%2==1)) 35 | mask_10=((coors[:,:,:,0]%4==1)&(coors[:,:,:,1]%2==0)) 36 | mask_11=((coors[:,:,:,0]%4==1)&(coors[:,:,:,1]%2==1)) 37 | mask_20=((coors[:,:,:,0]%4==2)&(coors[:,:,:,1]%2==0)) 38 | mask_21=((coors[:,:,:,0]%4==2)&(coors[:,:,:,1]%2==1)) 39 | mask_30=((coors[:,:,:,0]%4==3)&(coors[:,:,:,1]%2==0)) 40 | mask_31=((coors[:,:,:,0]%4==3)&(coors[:,:,:,1]%2==1)) 41 | 42 | small_stack=torch.stack((3*small_min-2*small_max,2*small_min-small_max,small_min,small_max,2*small_max-small_min,3*small_max-2*small_min),1) 43 | small_stack_d=torch.stack((3*small_min_d-2*small_max_d,2*small_min_d-small_max_d,small_min_d,small_max_d,2*small_max_d-small_min_d,3*small_max_d-2*small_min_d),1) 44 | huge_stack=torch.stack((3*huge_min-2*huge_max,2*huge_min-huge_max,huge_min,huge_max,2*huge_max-huge_min,3*huge_max-2*huge_min),1) 45 | huge_stack_d=torch.stack((3*huge_min_d-2*huge_max_d,2*huge_min_d-huge_max_d,huge_min_d,huge_max_d,2*huge_max_d-huge_min_d,3*huge_max_d-2*huge_min_d),1) 46 | 47 | # depth=torch.zeros_like(depth_sub_plus[:,0]) 48 | depth_values_c=torch.zeros_like(depth_sub_plus) 49 | depth_values_c=torch.where(mask_00.unsqueeze(1),small_stack[:,:-2],depth_values_c) 50 | depth_values_c=torch.where(mask_01.unsqueeze(1),small_stack[:,2:],depth_values_c) 51 | depth_values_c=torch.where(mask_10.unsqueeze(1),huge_stack[:,2:],depth_values_c) 52 | depth_values_c=torch.where(mask_11.unsqueeze(1),huge_stack[:,:-2],depth_values_c) 53 | depth_values_c=torch.where(mask_20.unsqueeze(1),small_stack_d[:,:-2],depth_values_c) 54 | depth_values_c=torch.where(mask_21.unsqueeze(1),small_stack_d[:,2:],depth_values_c) 55 | depth_values_c=torch.where(mask_30.unsqueeze(1),huge_stack_d[:,2:],depth_values_c) 56 | 
depth_values_c=torch.where(mask_31.unsqueeze(1),huge_stack_d[:,:-2],depth_values_c) 57 | 58 | 59 | with torch.no_grad(): 60 | # photometric confidence 61 | temp_photometric_confidence=torch.sigmoid(interval/(depth_sub_plus.var(1,unbiased=False).sqrt()+1e-5)) 62 | photometric_confidence=2*(temp_photometric_confidence-0.5) 63 | 64 | 65 | return {"photometric_confidence": photometric_confidence, "prob_volume": prob_volume,"depth_sub_plus":depth_sub_plus,"depth_values_c":depth_values_c, 66 | "depth_values": depth_values, "interval": interval} 67 | def refine(self, cost_reg, depth_values, num_depth, interval,alpha=5): 68 | prob_volume = F.softmax(cost_reg*alpha, dim=2) # (b,2, ndepth, h, w) 69 | depth_sub_plus = depth_regression(prob_volume, depth_values=depth_values.unsqueeze(1),axis=2) # (b, h, w) 70 | 71 | depth_sup_plus_small,depth_sup_plus_huge=depth_sub_plus.split([2,2],dim=1) 72 | 73 | 74 | small_min,small_max=depth_sup_plus_small.min(1)[0],depth_sup_plus_small.max(1)[0] 75 | huge_min,huge_max=depth_sup_plus_huge.min(1)[0],depth_sup_plus_huge.max(1)[0] 76 | 77 | coors=torch.stack( 78 | [item.unsqueeze(0).expand_as(depth_sub_plus[:,0]) for item in torch.meshgrid(*[torch.arange(0, s) for s in depth_sub_plus[:,0].shape[-2:]])], 79 | axis=-1).to(depth_sub_plus[:,0].device) 80 | mask_00=((coors[:,:,:,0]%2==0)&(coors[:,:,:,1]%2==0)) 81 | mask_01=((coors[:,:,:,0]%2==0)&(coors[:,:,:,1]%2==1)) 82 | mask_10=((coors[:,:,:,0]%2==1)&(coors[:,:,:,1]%2==0)) 83 | mask_11=((coors[:,:,:,0]%2==1)&(coors[:,:,:,1]%2==1)) 84 | 85 | 86 | depth=torch.zeros_like(depth_sub_plus[:,0]) 87 | 88 | depth=torch.where(mask_00,small_min,depth) 89 | depth=torch.where(mask_01,small_max,depth) 90 | depth=torch.where(mask_10,huge_max,depth) 91 | depth=torch.where(mask_11,huge_min,depth) 92 | 93 | 94 | with torch.no_grad(): 95 | # photometric confidence 96 | temp_photometric_confidence=torch.sigmoid(interval/(depth_sub_plus.var(1,unbiased=False).sqrt()+1e-5)) 97 | photometric_confidence=2*(temp_photometric_confidence-0.5) 98 | 99 | 100 | return {"depth": depth, "photometric_confidence_refine": photometric_confidence,"depth_sub_plus_refine":depth_sub_plus} 101 | 102 | class CostAgg(nn.Module): 103 | def __init__(self, mode="variance", in_channels=None): 104 | super(CostAgg, self).__init__() 105 | self.mode = mode 106 | assert mode in ("variance", "adaptive"), "Don't support {}!".format(mode) 107 | if self.mode == "adaptive": 108 | self.weight_net = nn.ModuleList([AggWeightNetVolume(in_channels[i]) for i in range(len(in_channels))]) 109 | 110 | 111 | def forward(self, features, proj_matrices, depth_values, stage_idx): 112 | """ 113 | :param stage_idx: stage 114 | :param features: [ref_fea, src_fea1, src_fea2, ...], fea shape: (b, c, h, w) 115 | :param proj_matrices: (b, nview, ...) [ref_proj, src_proj1, src_proj2, ...] 
116 | :param depth_values: (b, ndepth, h, w) 117 | :return: matching cost volume (b, c, ndepth, h, w) 118 | """ 119 | ref_feature, src_features = features[0], features[1:] 120 | proj_matrices = torch.unbind(proj_matrices, 1) # to list 121 | ref_proj, src_projs = proj_matrices[0], proj_matrices[1:] 122 | 123 | num_views = len(features) 124 | num_depth = depth_values.shape[1] 125 | 126 | ref_volume = ref_feature.unsqueeze(2) 127 | 128 | similarity_sum = 0 129 | 130 | b,c,_,h,w=ref_volume.shape 131 | for src_fea, src_proj in zip(src_features, src_projs): 132 | # warpped features 133 | src_proj_new = src_proj[:, 0].clone() 134 | src_proj_new[:, :3, :4] = torch.matmul(src_proj[:, 1, :3, :3], src_proj[:, 0, :3, :4]) 135 | ref_proj_new = ref_proj[:, 0].clone() 136 | ref_proj_new[:, :3, :4] = torch.matmul(ref_proj[:, 1, :3, :3], ref_proj[:, 0, :3, :4]) 137 | warped_volume,_ = homo_warping(src_fea, src_proj_new, ref_proj_new, depth_values) 138 | 139 | similarity=(warped_volume.view(b,c//2,2,warped_volume.shape[2],h,w)*(ref_volume.view(b,c//2,2,1,h,w))).mean(1) 140 | 141 | if self.training: 142 | similarity_sum = similarity_sum + similarity # [B, 2, D, H, W] 143 | 144 | else: 145 | # TODO: this is only a temporal solution to save memory, better way? 146 | similarity_sum += similarity 147 | 148 | 149 | 150 | del warped_volume 151 | 152 | # aggregate multiple feature volumes by variance 153 | return similarity_sum 154 | 155 | 156 | class MVSNet(nn.Module): 157 | def __init__(self, ndepths, depth_interval_ratio, cr_base_chs=None, fea_mode="fpn", agg_mode="variance", depth_mode="regression",winner_take_all_to_generate_depth=True,inverse_depth=False): 158 | super(MVSNet, self).__init__() 159 | 160 | if cr_base_chs is None: 161 | cr_base_chs = [8] * len(ndepths) 162 | self.ndepths = ndepths 163 | self.depth_interval_ratio = depth_interval_ratio 164 | self.fea_mode = fea_mode 165 | self.cr_base_chs = cr_base_chs 166 | self.num_stage = len(ndepths) 167 | self.inverse_depth=inverse_depth 168 | 169 | print("netphs:", ndepths) 170 | print("depth_intervals_ratio:", depth_interval_ratio) 171 | print("cr_base_chs:", cr_base_chs) 172 | print("fea_mode:", fea_mode) 173 | print("agg_mode:", agg_mode) 174 | print("depth_mode:", depth_mode) 175 | 176 | assert len(ndepths) == len(depth_interval_ratio) 177 | 178 | self.feature = FeatureNet(base_channels=8, stride=4, num_stage=self.num_stage, mode=self.fea_mode) 179 | self.cost_aggregation = CostAgg(agg_mode, self.feature.out_channels) 180 | 181 | self.cost_regularization = nn.ModuleList( 182 | [CostRegNet(in_channels=2, base_channels=self.cr_base_chs[i],stage=i) for i in range(self.num_stage)]) 183 | self.cost_regularization_refine = nn.ModuleList( 184 | [CostRegNet_refine(in_channels=2, base_channels=self.cr_base_chs[i],stage=i) for i in range(self.num_stage)]) 185 | 186 | self.DepthNet = DepthNet(depth_mode) 187 | 188 | def forward(self, imgs, proj_matrices, depth_values): 189 | """ 190 | :param is_flip: augment only for 3D-UNet 191 | :param imgs: (b, nview, c, h, w) 192 | :param proj_matrices: 193 | :param depth_values: 194 | :return: 195 | """ 196 | depth_interval = (depth_values[0, -1] - depth_values[0, 0]) / depth_values.size(1) 197 | 198 | # step 1. 
feature extraction 199 | features = [] 200 | for nview_idx in range(imgs.size(1)): # imgs shape (B, N, C, H, W) 201 | img = imgs[:, nview_idx] 202 | features.append(self.feature(img)) 203 | 204 | ori_shape = imgs[:, 0].shape[2:] # (H, W) 205 | 206 | outputs = {} 207 | last_depth = None 208 | for stage_idx in range(self.num_stage): 209 | # print("*********************stage{}*********************".format(stage_idx + 1)) 210 | # stage feature, proj_mats, scales 211 | features_stage = [feat["stage{}".format(stage_idx + 1)] for feat in features] 212 | proj_matrices_stage = proj_matrices["stage{}".format(stage_idx + 1)] 213 | # stage1: 1/4, stage2: 1/2, stage3: 1 214 | stage_scale = 2 ** (3 - stage_idx - 1) 215 | 216 | stage_shape = [ori_shape[0] // int(stage_scale), ori_shape[1] // int(stage_scale)] 217 | 218 | if stage_idx == 0: 219 | last_depth = depth_values 220 | else: 221 | last_depth = last_depth.detach() 222 | 223 | # (B, D, H, W) 224 | depth_range_samples, interval = get_depth_range_samples(last_depth=last_depth, 225 | ndepth=self.ndepths[stage_idx], 226 | depth_inteval_pixel=self.depth_interval_ratio[ 227 | stage_idx] * depth_interval, 228 | shape=stage_shape, # only for first stage 229 | inverse=self.inverse_depth 230 | ) 231 | 232 | if stage_idx > 0: 233 | depth_range_samples = F.interpolate(depth_range_samples, stage_shape, mode='bilinear', align_corners=Align_Corners_Range) 234 | 235 | # (b, c, d, h, w) 236 | cost_volume = self.cost_aggregation(features_stage, proj_matrices_stage, depth_range_samples, stage_idx) 237 | # cost volume regularization 238 | # (b, 1, d, h, w) 239 | cost_reg = self.cost_regularization[stage_idx](cost_volume) 240 | 241 | # depth 242 | outputs_stage = self.DepthNet(cost_reg, depth_range_samples, num_depth=self.ndepths[stage_idx], interval=interval,stage=stage_idx) 243 | 244 | 245 | 246 | depth_values_c=outputs_stage["depth_values_c"] 247 | features_stage = [feat["stage{}_c".format(stage_idx + 1)] for feat in features] 248 | 249 | 250 | cost_volume_c = self.cost_aggregation(features_stage, proj_matrices_stage, depth_values_c, stage_idx) 251 | cost_reg_c= self.cost_regularization_refine[stage_idx](cost_volume_c) 252 | outputs_stage_refine = self.DepthNet.refine(cost_reg_c, depth_values_c, num_depth=4, interval=interval) 253 | 254 | outputs_stage={**outputs_stage_refine,**outputs_stage} 255 | last_depth = outputs_stage['depth'] 256 | 257 | outputs["stage{}".format(stage_idx + 1)] = outputs_stage 258 | outputs.update(outputs_stage) 259 | 260 | return outputs 261 | -------------------------------------------------------------------------------- /scripts/blendedmvs_finetune.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | datapath="/data2/yexinyi/datasets/MVS/blendMVS/dataset_low_res/" 3 | 4 | resume="" 5 | log_dir="./checkpoints/DMVSNet/finetune" 6 | if [ ! 
-d $log_dir ]; then 7 | mkdir -p $log_dir 8 | fi 9 | 10 | CUDA_VISIBLE_DEVICES=2,3 python -m torch.distributed.launch --nproc_per_node=2 --master_port=2342 main.py \ 11 | --sync_bn \ 12 | --blendedmvs_finetune \ 13 | --ndepths 48 32 8 \ 14 | --interval_ratio 4 2 1 \ 15 | --img_size 576 768 \ 16 | --dlossw 0.5 1.0 2.0 \ 17 | --log_dir $log_dir \ 18 | --datapath $datapath \ 19 | --resume $resume \ 20 | --dataset_name "blendedmvs" \ 21 | --nviews 7 \ 22 | --epochs 10 \ 23 | --batch_size 1 \ 24 | --lr 0.0001 \ 25 | --scheduler steplr \ 26 | --warmup 0.2 \ 27 | --milestones 6 8 \ 28 | --lr_decay 0.5 \ 29 | --trainlist "datasets/lists/blendedmvs/training_list.txt" \ 30 | --testlist "datasets/lists/blendedmvs/validation_list.txt" \ 31 | --fea_mode "fpn" \ 32 | --agg_mode "variance" \ 33 | --depth_mode "regression" \ 34 | --numdepth 128 \ 35 | --interval_scale 1.06 ${@:1} | tee -a $log_dir/log.txt 36 | -------------------------------------------------------------------------------- /scripts/dtu_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | datapath="/data2/yexinyi/datasets/MVS/preprocessed_inputs/dtu/" 3 | outdir="./outputs_dtu/DMVSNet/" 4 | resume="./checkpoints/DMVSNet/model.ckpt" 5 | fusibile_exe_path="./fusibile/build/fusibile" 6 | 7 | 8 | CUDA_VISIBLE_DEVICES=7 python main.py \ 9 | --test \ 10 | --ndepths 48 32 8 \ 11 | --interval_ratio 4 2 1 \ 12 | --max_h 864 \ 13 | --max_w 1152 \ 14 | --num_view 5 \ 15 | --outdir $outdir \ 16 | --datapath $datapath \ 17 | --resume $resume \ 18 | --dataset_name "general_eval" \ 19 | --batch_size 1 \ 20 | --testlist "datasets/lists/dtu/test.txt" \ 21 | --fea_mode "fpn" \ 22 | --agg_mode "variance" \ 23 | --depth_mode "regression" \ 24 | --numdepth 192 \ 25 | --interval_scale 1.06 \ 26 | --filter_method "pcd" \ 27 | --thres_view 5 \ 28 | --num_worker 1 \ 29 | --inverse_depth \ 30 | --conf 0. 0. 
0.3 ${@:1} 31 | -------------------------------------------------------------------------------- /scripts/evaluation_dtu/BaseEval2Obj_web.m: -------------------------------------------------------------------------------- 1 | function BaseEval2Obj_web(BaseEval,method_string,outputPath) 2 | 3 | if(nargin<3) 4 | outputPath='./'; 5 | end 6 | 7 | % tresshold for coloring alpha channel in the range of 0-10 mm 8 | dist_tresshold=10; 9 | 10 | cSet=BaseEval.cSet; 11 | 12 | Qdata=BaseEval.Qdata; 13 | alpha=min(BaseEval.Ddata,dist_tresshold)/dist_tresshold; 14 | 15 | fid=fopen([outputPath method_string '2Stl_' num2str(cSet) ' .obj'],'w+'); 16 | 17 | for cP=1:size(Qdata,2) 18 | if(BaseEval.DataInMask(cP)) 19 | C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold) 20 | else 21 | C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points outside the mask (which are not included in the analysis) 22 | end 23 | fprintf(fid,'v %f %f %f %f %f %f\n',[Qdata(1,cP) Qdata(2,cP) Qdata(3,cP) C(1) C(2) C(3)]); 24 | end 25 | fclose(fid); 26 | 27 | disp('Data2Stl saved as obj') 28 | 29 | Qstl=BaseEval.Qstl; 30 | fid=fopen([outputPath 'Stl2' method_string '_' num2str(cSet) '.obj'],'w+'); 31 | 32 | alpha=min(BaseEval.Dstl,dist_tresshold)/dist_tresshold; 33 | 34 | for cP=1:size(Qstl,2) 35 | if(BaseEval.StlAbovePlane(cP)) 36 | C=[1 0 0]*alpha(cP)+[1 1 1]*(1-alpha(cP)); %coloring from red to white in the range of 0-10 mm (0 to dist_tresshold) 37 | else 38 | C=[0 1 0]*alpha(cP)+[0 0 1]*(1-alpha(cP)); %green to blue for points below plane (which are not included in the analysis) 39 | end 40 | fprintf(fid,'v %f %f %f %f %f %f\n',[Qstl(1,cP) Qstl(2,cP) Qstl(3,cP) C(1) C(2) C(3)]); 41 | end 42 | fclose(fid); 43 | 44 | disp('Stl2Data saved as obj') -------------------------------------------------------------------------------- /scripts/evaluation_dtu/BaseEvalMain_web.m: -------------------------------------------------------------------------------- 1 | clear all 2 | close all 3 | format compact 4 | clc 5 | 6 | % script to calculate distances have been measured for all included scans (UsedSets) 7 | 8 | % dataPath=''; %path/Points 9 | dataPath='/data2/yexinyi/datasets/MVS/SampleSet/MVSData'; 10 | % plyPath=''; 11 | plyPath='/data2/yexinyi/code/DMVSNet/outputs_dtu/DMVSNet/pcd'; 12 | % resultsPath=''; 13 | resultsPath='/data2/yexinyi/code/DMVSNet/outputs_dtu/DMVSNet/pcd'; 14 | 15 | 16 | method_string='mvsnet'; 17 | light_string='l3'; % l3 is the setting with all lights on, l7 is randomly sampled between the 7 settings (index 0-6) 18 | representation_string='Points'; %mvs representation 'Points' or 'Surfaces' 19 | 20 | switch representation_string 21 | case 'Points' 22 | eval_string='_Eval_'; %results naming 23 | settings_string=''; 24 | end 25 | 26 | % get sets used in evaluation 27 | UsedSets=[1 4 9 10 11 12 13 15 23 24 29 32 33 34 48 49 62 75 77 110 114 118]; 28 | 29 | dst=0.2; %Min dist between points when reducing 30 | 31 | for cIdx=1:length(UsedSets) 32 | %Data set number 33 | cSet = UsedSets(cIdx) 34 | %input data name 35 | DataInName=[plyPath sprintf('/%s%03d_%s%s.ply',lower(method_string),cSet,light_string,settings_string)] 36 | 37 | %results name 38 | EvalName=[resultsPath method_string eval_string num2str(cSet) '.mat'] 39 | 40 | %check if file is already computed 41 | if(~exist(EvalName,'file')) 42 | disp(DataInName); 43 | 44 | time=clock;time(4:5), drawnow 45 | 46 | tic 47 | Mesh = plyread(DataInName); 48 | Qdata=[Mesh.vertex.x Mesh.vertex.y 
Mesh.vertex.z]'; 49 | toc 50 | 51 | BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath); 52 | 53 | disp('Saving results'), drawnow 54 | toc 55 | save(EvalName,'BaseEval'); 56 | toc 57 | 58 | % write obj-file of evaluation 59 | % BaseEval2Obj_web(BaseEval,method_string, resultsPath) 60 | % toc 61 | time=clock;time(4:5), drawnow 62 | 63 | BaseEval.MaxDist=20; %outlier threshold of 20 mm 64 | 65 | BaseEval.FilteredDstl=BaseEval.Dstl(BaseEval.StlAbovePlane); %use only points that are above the plane 66 | BaseEval.FilteredDstl=BaseEval.FilteredDstl(BaseEval.FilteredDstl=Low(1) & Qfrom(2,:)>=Low(2) & Qfrom(3,:)>=Low(3) &... 18 | Qfrom(1,:)=Low(1) & Qto(2,:)>=Low(2) & Qto(3,:)>=Low(3) &... 25 | Qto(1,:)3)] 49 | end 50 | 51 | -------------------------------------------------------------------------------- /scripts/evaluation_dtu/PointCompareMain.m: -------------------------------------------------------------------------------- 1 | function BaseEval=PointCompareMain(cSet,Qdata,dst,dataPath) 2 | % evaluation function the calculates the distantes from the reference data (stl) to the evalution points (Qdata) and the 3 | % distances from the evaluation points to the reference 4 | 5 | tic 6 | % reduce points 0.2 mm neighbourhood density 7 | 8 | Qdata=reducePts_haa(Qdata,dst); 9 | toc 10 | 11 | StlInName=[dataPath '/Points/stl/stl' sprintf('%03d',cSet) '_total.ply']; 12 | 13 | StlMesh = plyread(StlInName); %STL points already reduced 0.2 mm neighbourhood density 14 | Qstl=[StlMesh.vertex.x StlMesh.vertex.y StlMesh.vertex.z]'; 15 | 16 | %Load Mask (ObsMask) and Bounding box (BB) and Resolution (Res) 17 | Margin=10; 18 | MaskName=[dataPath '/ObsMask/ObsMask' num2str(cSet) '_' num2str(Margin) '.mat']; 19 | load(MaskName) 20 | 21 | MaxDist=60; 22 | disp('Computing Data 2 Stl distances') 23 | Ddata = MaxDistCP(Qstl,Qdata,BB,MaxDist); 24 | toc 25 | 26 | disp('Computing Stl 2 Data distances') 27 | Dstl=MaxDistCP(Qdata,Qstl,BB,MaxDist); 28 | disp('Distances computed') 29 | toc 30 | 31 | %use mask 32 | %From Get mask - inverted & modified. 33 | One=ones(1,size(Qdata,2)); 34 | Qv=(Qdata-BB(1,:)'*One)/Res+1; 35 | Qv=round(Qv); 36 | 37 | Midx1=find(Qv(1,:)>0 & Qv(1,:)<=size(ObsMask,1) & Qv(2,:)>0 & Qv(2,:)<=size(ObsMask,2) & Qv(3,:)>0 & Qv(3,:)<=size(ObsMask,3)); 38 | MidxA=sub2ind(size(ObsMask),Qv(1,Midx1),Qv(2,Midx1),Qv(3,Midx1)); 39 | Midx2=find(ObsMask(MidxA)); 40 | 41 | BaseEval.DataInMask(1:size(Qv,2))=false; 42 | BaseEval.DataInMask(Midx1(Midx2))=true; %If Data is within the mask 43 | 44 | BaseEval.cSet=cSet; 45 | BaseEval.Margin=Margin; %Margin of masks 46 | BaseEval.dst=dst; %Min dist between points when reducing 47 | BaseEval.Qdata=Qdata; %Input data points 48 | BaseEval.Ddata=Ddata; %distance from data to stl 49 | BaseEval.Qstl=Qstl; %Input stl points 50 | BaseEval.Dstl=Dstl; %Distance from the stl to data 51 | 52 | load([dataPath '/ObsMask/Plane' num2str(cSet)],'P') 53 | BaseEval.GroundPlane=P; % Plane used to destinguise which Stl points are 'used' 54 | BaseEval.StlAbovePlane=(P'*[Qstl;ones(1,size(Qstl,2))])>0; %Is stl above 'ground plane' 55 | BaseEval.Time=clock; %Time when computation is finished 56 | 57 | 58 | 59 | 60 | -------------------------------------------------------------------------------- /scripts/evaluation_dtu/plyread.m: -------------------------------------------------------------------------------- 1 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 2 | function [Elements,varargout] = plyread(Path,Str) 3 | %PLYREAD Read a PLY 3D data file. 
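% (Note: in this repository, BaseEvalMain_web.m uses PLYREAD to load the fused point
% clouds from plyPath; with its default method_string='mvsnet' and light_string='l3'
% they are expected to be named like mvsnet001_l3.ply for scan 1.)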
4 | % [DATA,COMMENTS] = PLYREAD(FILENAME) reads a version 1.0 PLY file 5 | % FILENAME and returns a structure DATA. The fields in this structure 6 | % are defined by the PLY header; each element type is a field and each 7 | % element property is a subfield. If the file contains any comments, 8 | % they are returned in a cell string array COMMENTS. 9 | % 10 | % [TRI,PTS] = PLYREAD(FILENAME,'tri') or 11 | % [TRI,PTS,DATA,COMMENTS] = PLYREAD(FILENAME,'tri') converts vertex 12 | % and face data into triangular connectivity and vertex arrays. The 13 | % mesh can then be displayed using the TRISURF command. 14 | % 15 | % Note: This function is slow for large mesh files (+50K faces), 16 | % especially when reading data with list type properties. 17 | % 18 | % Example: 19 | % [Tri,Pts] = PLYREAD('cow.ply','tri'); 20 | % trisurf(Tri,Pts(:,1),Pts(:,2),Pts(:,3)); 21 | % colormap(gray); axis equal; 22 | % 23 | % See also: PLYWRITE 24 | 25 | % Pascal Getreuer 2004 26 | 27 | [fid,Msg] = fopen(Path,'rt'); % open file in read text mode 28 | 29 | if fid == -1, error(Msg); end 30 | 31 | Buf = fscanf(fid,'%s',1); 32 | if ~strcmp(Buf,'ply') 33 | fclose(fid); 34 | error('Not a PLY file.'); 35 | end 36 | 37 | 38 | %%% read header %%% 39 | 40 | Position = ftell(fid); 41 | Format = ''; 42 | NumComments = 0; 43 | Comments = {}; % for storing any file comments 44 | NumElements = 0; 45 | NumProperties = 0; 46 | Elements = []; % structure for holding the element data 47 | ElementCount = []; % number of each type of element in file 48 | PropertyTypes = []; % corresponding structure recording property types 49 | ElementNames = {}; % list of element names in the order they are stored in the file 50 | PropertyNames = []; % structure of lists of property names 51 | 52 | while 1 53 | Buf = fgetl(fid); % read one line from file 54 | BufRem = Buf; 55 | Token = {}; 56 | Count = 0; 57 | 58 | while ~isempty(BufRem) % split line into tokens 59 | [tmp,BufRem] = strtok(BufRem); 60 | 61 | if ~isempty(tmp) 62 | Count = Count + 1; % count tokens 63 | Token{Count} = tmp; 64 | end 65 | end 66 | 67 | if Count % parse line 68 | switch lower(Token{1}) 69 | case 'format' % read data format 70 | if Count >= 2 71 | Format = lower(Token{2}); 72 | 73 | if Count == 3 & ~strcmp(Token{3},'1.0') 74 | fclose(fid); 75 | error('Only PLY format version 1.0 supported.'); 76 | end 77 | end 78 | case 'comment' % read file comment 79 | NumComments = NumComments + 1; 80 | Comments{NumComments} = ''; 81 | for i = 2:Count 82 | Comments{NumComments} = [Comments{NumComments},Token{i},' ']; 83 | end 84 | case 'element' % element name 85 | if Count >= 3 86 | if isfield(Elements,Token{2}) 87 | fclose(fid); 88 | error(['Duplicate element name, ''',Token{2},'''.']); 89 | end 90 | 91 | NumElements = NumElements + 1; 92 | NumProperties = 0; 93 | Elements = setfield(Elements,Token{2},[]); 94 | PropertyTypes = setfield(PropertyTypes,Token{2},[]); 95 | ElementNames{NumElements} = Token{2}; 96 | PropertyNames = setfield(PropertyNames,Token{2},{}); 97 | CurElement = Token{2}; 98 | ElementCount(NumElements) = str2double(Token{3}); 99 | 100 | if isnan(ElementCount(NumElements)) 101 | fclose(fid); 102 | error(['Bad element definition: ',Buf]); 103 | end 104 | else 105 | error(['Bad element definition: ',Buf]); 106 | end 107 | case 'property' % element property 108 | if ~isempty(CurElement) & Count >= 3 109 | NumProperties = NumProperties + 1; 110 | eval(['tmp=isfield(Elements.',CurElement,',Token{Count});'],... 
111 | 'fclose(fid);error([''Error reading property: '',Buf])'); 112 | 113 | if tmp 114 | error(['Duplicate property name, ''',CurElement,'.',Token{2},'''.']); 115 | end 116 | 117 | % add property subfield to Elements 118 | eval(['Elements.',CurElement,'.',Token{Count},'=[];'], ... 119 | 'fclose(fid);error([''Error reading property: '',Buf])'); 120 | % add property subfield to PropertyTypes and save type 121 | eval(['PropertyTypes.',CurElement,'.',Token{Count},'={Token{2:Count-1}};'], ... 122 | 'fclose(fid);error([''Error reading property: '',Buf])'); 123 | % record property name order 124 | eval(['PropertyNames.',CurElement,'{NumProperties}=Token{Count};'], ... 125 | 'fclose(fid);error([''Error reading property: '',Buf])'); 126 | else 127 | fclose(fid); 128 | 129 | if isempty(CurElement) 130 | error(['Property definition without element definition: ',Buf]); 131 | else 132 | error(['Bad property definition: ',Buf]); 133 | end 134 | end 135 | case 'end_header' % end of header, break from while loop 136 | break; 137 | end 138 | end 139 | end 140 | 141 | %%% set reading for specified data format %%% 142 | 143 | if isempty(Format) 144 | warning('Data format unspecified, assuming ASCII.'); 145 | Format = 'ascii'; 146 | end 147 | 148 | switch Format 149 | case 'ascii' 150 | Format = 0; 151 | case 'binary_little_endian' 152 | Format = 1; 153 | case 'binary_big_endian' 154 | Format = 2; 155 | otherwise 156 | fclose(fid); 157 | error(['Data format ''',Format,''' not supported.']); 158 | end 159 | 160 | if ~Format 161 | Buf = fscanf(fid,'%f'); % read the rest of the file as ASCII data 162 | BufOff = 1; 163 | else 164 | % reopen the file in read binary mode 165 | fclose(fid); 166 | 167 | if Format == 1 168 | fid = fopen(Path,'r','ieee-le.l64'); % little endian 169 | else 170 | fid = fopen(Path,'r','ieee-be.l64'); % big endian 171 | end 172 | 173 | % find the end of the header again (using ftell on the old handle doesn't give the correct position) 174 | BufSize = 8192; 175 | Buf = [blanks(10),char(fread(fid,BufSize,'uchar')')]; 176 | i = []; 177 | tmp = -11; 178 | 179 | while isempty(i) 180 | i = findstr(Buf,['end_header',13,10]); % look for end_header + CR/LF 181 | i = [i,findstr(Buf,['end_header',10])]; % look for end_header + LF 182 | 183 | if isempty(i) 184 | tmp = tmp + BufSize; 185 | Buf = [Buf(BufSize+1:BufSize+10),char(fread(fid,BufSize,'uchar')')]; 186 | end 187 | end 188 | 189 | % seek to just after the line feed 190 | fseek(fid,i + tmp + 11 + (Buf(i + 10) == 13),-1); 191 | end 192 | 193 | 194 | %%% read element data %%% 195 | 196 | % PLY and MATLAB data types (for fread) 197 | PlyTypeNames = {'char','uchar','short','ushort','int','uint','float','double', ... 
198 | 'char8','uchar8','short16','ushort16','int32','uint32','float32','double64'}; 199 | MatlabTypeNames = {'schar','uchar','int16','uint16','int32','uint32','single','double'}; 200 | SizeOf = [1,1,2,2,4,4,4,8]; % size in bytes of each type 201 | 202 | for i = 1:NumElements 203 | % get current element property information 204 | eval(['CurPropertyNames=PropertyNames.',ElementNames{i},';']); 205 | eval(['CurPropertyTypes=PropertyTypes.',ElementNames{i},';']); 206 | NumProperties = size(CurPropertyNames,2); 207 | 208 | % fprintf('Reading %s...\n',ElementNames{i}); 209 | 210 | if ~Format %%% read ASCII data %%% 211 | for j = 1:NumProperties 212 | Token = getfield(CurPropertyTypes,CurPropertyNames{j}); 213 | 214 | if strcmpi(Token{1},'list') 215 | Type(j) = 1; 216 | else 217 | Type(j) = 0; 218 | end 219 | end 220 | 221 | % parse buffer 222 | if ~any(Type) 223 | % no list types 224 | Data = reshape(Buf(BufOff:BufOff+ElementCount(i)*NumProperties-1),NumProperties,ElementCount(i))'; 225 | BufOff = BufOff + ElementCount(i)*NumProperties; 226 | else 227 | ListData = cell(NumProperties,1); 228 | 229 | for k = 1:NumProperties 230 | ListData{k} = cell(ElementCount(i),1); 231 | end 232 | 233 | % list type 234 | for j = 1:ElementCount(i) 235 | for k = 1:NumProperties 236 | if ~Type(k) 237 | Data(j,k) = Buf(BufOff); 238 | BufOff = BufOff + 1; 239 | else 240 | tmp = Buf(BufOff); 241 | ListData{k}{j} = Buf(BufOff+(1:tmp))'; 242 | BufOff = BufOff + tmp + 1; 243 | end 244 | end 245 | end 246 | end 247 | else %%% read binary data %%% 248 | % translate PLY data type names to MATLAB data type names 249 | ListFlag = 0; % = 1 if there is a list type 250 | SameFlag = 1; % = 1 if all types are the same 251 | 252 | for j = 1:NumProperties 253 | Token = getfield(CurPropertyTypes,CurPropertyNames{j}); 254 | 255 | if ~strcmp(Token{1},'list') % non-list type 256 | tmp = rem(strmatch(Token{1},PlyTypeNames,'exact')-1,8)+1; 257 | 258 | if ~isempty(tmp) 259 | TypeSize(j) = SizeOf(tmp); 260 | Type{j} = MatlabTypeNames{tmp}; 261 | TypeSize2(j) = 0; 262 | Type2{j} = ''; 263 | 264 | SameFlag = SameFlag & strcmp(Type{1},Type{j}); 265 | else 266 | fclose(fid); 267 | error(['Unknown property data type, ''',Token{1},''', in ', ... 268 | ElementNames{i},'.',CurPropertyNames{j},'.']); 269 | end 270 | else % list type 271 | if length(Token) == 3 272 | ListFlag = 1; 273 | SameFlag = 0; 274 | tmp = rem(strmatch(Token{2},PlyTypeNames,'exact')-1,8)+1; 275 | tmp2 = rem(strmatch(Token{3},PlyTypeNames,'exact')-1,8)+1; 276 | 277 | if ~isempty(tmp) & ~isempty(tmp2) 278 | TypeSize(j) = SizeOf(tmp); 279 | Type{j} = MatlabTypeNames{tmp}; 280 | TypeSize2(j) = SizeOf(tmp2); 281 | Type2{j} = MatlabTypeNames{tmp2}; 282 | else 283 | fclose(fid); 284 | error(['Unknown property data type, ''list ',Token{2},' ',Token{3},''', in ', ... 
285 | ElementNames{i},'.',CurPropertyNames{j},'.']); 286 | end 287 | else 288 | fclose(fid); 289 | error(['Invalid list syntax in ',ElementNames{i},'.',CurPropertyNames{j},'.']); 290 | end 291 | end 292 | end 293 | 294 | % read file 295 | if ~ListFlag 296 | if SameFlag 297 | % no list types, all the same type (fast) 298 | Data = fread(fid,[NumProperties,ElementCount(i)],Type{1})'; 299 | else 300 | % no list types, mixed type 301 | Data = zeros(ElementCount(i),NumProperties); 302 | 303 | for j = 1:ElementCount(i) 304 | for k = 1:NumProperties 305 | Data(j,k) = fread(fid,1,Type{k}); 306 | end 307 | end 308 | end 309 | else 310 | ListData = cell(NumProperties,1); 311 | 312 | for k = 1:NumProperties 313 | ListData{k} = cell(ElementCount(i),1); 314 | end 315 | 316 | if NumProperties == 1 317 | BufSize = 512; 318 | SkipNum = 4; 319 | j = 0; 320 | 321 | % list type, one property (fast if lists are usually the same length) 322 | while j < ElementCount(i) 323 | Position = ftell(fid); 324 | % read in BufSize count values, assuming all counts = SkipNum 325 | [Buf,BufSize] = fread(fid,BufSize,Type{1},SkipNum*TypeSize2(1)); 326 | Miss = find(Buf ~= SkipNum); % find first count that is not SkipNum 327 | fseek(fid,Position + TypeSize(1),-1); % seek back to after first count 328 | 329 | if isempty(Miss) % all counts are SkipNum 330 | Buf = fread(fid,[SkipNum,BufSize],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; 331 | fseek(fid,-TypeSize(1),0); % undo last skip 332 | 333 | for k = 1:BufSize 334 | ListData{1}{j+k} = Buf(k,:); 335 | end 336 | 337 | j = j + BufSize; 338 | BufSize = floor(1.5*BufSize); 339 | else 340 | if Miss(1) > 1 % some counts are SkipNum 341 | Buf2 = fread(fid,[SkipNum,Miss(1)-1],[int2str(SkipNum),'*',Type2{1}],TypeSize(1))'; 342 | 343 | for k = 1:Miss(1)-1 344 | ListData{1}{j+k} = Buf2(k,:); 345 | end 346 | 347 | j = j + k; 348 | end 349 | 350 | % read in the list with the missed count 351 | SkipNum = Buf(Miss(1)); 352 | j = j + 1; 353 | ListData{1}{j} = fread(fid,[1,SkipNum],Type2{1}); 354 | BufSize = ceil(0.6*BufSize); 355 | end 356 | end 357 | else 358 | % list type(s), multiple properties (slow) 359 | Data = zeros(ElementCount(i),NumProperties); 360 | 361 | for j = 1:ElementCount(i) 362 | for k = 1:NumProperties 363 | if isempty(Type2{k}) 364 | Data(j,k) = fread(fid,1,Type{k}); 365 | else 366 | tmp = fread(fid,1,Type{k}); 367 | ListData{k}{j} = fread(fid,[1,tmp],Type2{k}); 368 | end 369 | end 370 | end 371 | end 372 | end 373 | end 374 | 375 | % put data into Elements structure 376 | for k = 1:NumProperties 377 | if (~Format & ~Type(k)) | (Format & isempty(Type2{k})) 378 | eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=Data(:,k);']); 379 | else 380 | eval(['Elements.',ElementNames{i},'.',CurPropertyNames{k},'=ListData{k};']); 381 | end 382 | end 383 | end 384 | 385 | clear Data ListData; 386 | fclose(fid); 387 | 388 | if (nargin > 1 & strcmpi(Str,'Tri')) | nargout > 2 389 | % find vertex element field 390 | Name = {'vertex','Vertex','point','Point','pts','Pts'}; 391 | Names = []; 392 | 393 | for i = 1:length(Name) 394 | if any(strcmp(ElementNames,Name{i})) 395 | Names = getfield(PropertyNames,Name{i}); 396 | Name = Name{i}; 397 | break; 398 | end 399 | end 400 | 401 | if any(strcmp(Names,'x')) & any(strcmp(Names,'y')) & any(strcmp(Names,'z')) 402 | eval(['varargout{1}=[Elements.',Name,'.x,Elements.',Name,'.y,Elements.',Name,'.z];']); 403 | else 404 | varargout{1} = zeros(1,3); 405 | end 406 | 407 | varargout{2} = Elements; 408 | varargout{3} = Comments; 409 | Elements 
= []; 410 | 411 | % find face element field 412 | Name = {'face','Face','poly','Poly','tri','Tri'}; 413 | Names = []; 414 | 415 | for i = 1:length(Name) 416 | if any(strcmp(ElementNames,Name{i})) 417 | Names = getfield(PropertyNames,Name{i}); 418 | Name = Name{i}; 419 | break; 420 | end 421 | end 422 | 423 | if ~isempty(Names) 424 | % find vertex indices property subfield 425 | PropertyName = {'vertex_indices','vertex_indexes','vertex_index','indices','indexes'}; 426 | 427 | for i = 1:length(PropertyName) 428 | if any(strcmp(Names,PropertyName{i})) 429 | PropertyName = PropertyName{i}; 430 | break; 431 | end 432 | end 433 | 434 | if ~iscell(PropertyName) 435 | % convert face index lists to triangular connectivity 436 | eval(['FaceIndices=varargout{2}.',Name,'.',PropertyName,';']); 437 | N = length(FaceIndices); 438 | Elements = zeros(N*2,3); 439 | Extra = 0; 440 | 441 | for k = 1:N 442 | Elements(k,:) = FaceIndices{k}(1:3); 443 | 444 | for j = 4:length(FaceIndices{k}) 445 | Extra = Extra + 1; 446 | Elements(N + Extra,:) = [Elements(k,[1,j-1]),FaceIndices{k}(j)]; 447 | end 448 | end 449 | Elements = Elements(1:N+Extra,:) + 1; 450 | end 451 | end 452 | else 453 | varargout{1} = Comments; 454 | end -------------------------------------------------------------------------------- /scripts/evaluation_dtu/reducePts_haa.m: -------------------------------------------------------------------------------- 1 | function [ptsOut,indexSet] = reducePts_haa(pts, dst) 2 | 3 | %Reduces a point set, pts, in a stochastic manner, such that the minimum sdistance 4 | % between points is 'dst'. Writen by abd, edited by haa, then by raje 5 | 6 | nPoints=size(pts,2); 7 | 8 | indexSet=true(nPoints,1); 9 | RandOrd=randperm(nPoints);% 10 | 11 | %tic 12 | NS = KDTreeSearcher(pts'); 13 | %toc 14 | 15 | % search the KNTree for close neighbours in a chunk-wise fashion to save memory if point cloud is really big 16 | Chunks=1:min(4e6,nPoints-1):nPoints; 17 | Chunks(end)=nPoints; 18 | 19 | for cChunk=1:(length(Chunks)-1) 20 | Range=Chunks(cChunk):Chunks(cChunk+1); 21 | idx = rangesearch(NS,pts(:,RandOrd(Range))',dst); 22 | 23 | for i = 1:size(idx,1) 24 | id =RandOrd(i-1+Chunks(cChunk)); 25 | if (indexSet(id)) 26 | indexSet(idx{i}) = 0; 27 | indexSet(id) = 1; 28 | end 29 | end 30 | end 31 | 32 | ptsOut = pts(:,indexSet); 33 | 34 | disp(['downsample factor: ' num2str(nPoints/sum(indexSet))]); 35 | -------------------------------------------------------------------------------- /scripts/tank_test.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | datapath="/data2/yexinyi/datasets/MVS/preprocessed_inputs/tankandtemples/advanced/" 3 | # datapath="/data2/yexinyi/datasets/MVS/preprocessed_inputs/tankandtemples/intermediate/" 4 | 5 | outdir="./outputs_tank/DMVSNet/" 6 | resume="" 7 | 8 | CUDA_VISIBLE_DEVICES=1 python main.py \ 9 | --test \ 10 | --ndepths 64 32 8 \ 11 | --interval_ratio 3 2 1 \ 12 | --num_view 11 \ 13 | --outdir $outdir \ 14 | --datapath $datapath \ 15 | --resume $resume \ 16 | --dataset_name "general_eval" \ 17 | --batch_size 1 \ 18 | --testlist "all" \ 19 | --fea_mode "fpn" \ 20 | --agg_mode "variance" \ 21 | --depth_mode "regression" \ 22 | --numdepth 192 \ 23 | --interval_scale 1.06 \ 24 | --filter_method "dypcd" ${@:1} -------------------------------------------------------------------------------- /scripts/train.sh: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env bash 2 | 
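# DTU training settings used below: a 3-stage cascade with 48/32/8 depth hypotheses
# at interval ratios 4/2/1, 5 input views, 512x640 inputs, 16 epochs, and step decay
# (x0.5) at epochs 10/12/14. Point datapath at your local dtu_training folder.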
datapath="/data2/yexinyi/datasets/MVS/training_data/dtu_training/" 3 | 4 | log_dir="checkpoints/DMVSNet" 5 | if [ ! -d $log_dir ]; then 6 | mkdir -p $log_dir 7 | fi 8 | 9 | CUDA_VISIBLE_DEVICES=6,7 python -m torch.distributed.launch --nproc_per_node=2 --master_port=1111 main.py \ 10 | --sync_bn \ 11 | --ndepths 48 32 8 \ 12 | --interval_ratio 4 2 1 \ 13 | --img_size 512 640 \ 14 | --num_view 5 \ 15 | --dlossw 0.5 1.0 2.0 \ 16 | --log_dir $log_dir \ 17 | --datapath $datapath \ 18 | --dataset_name "dtu_yao" \ 19 | --epochs 16 \ 20 | --batch_size 2 \ 21 | --lr 0.001 \ 22 | --warmup 0.2 \ 23 | --scheduler "steplr" \ 24 | --milestones 10 12 14 \ 25 | --lr_decay 0.5 \ 26 | --trainlist "datasets/lists/dtu/train.txt" \ 27 | --testlist "datasets/lists/dtu/test.txt" \ 28 | --fea_mode "fpn" \ 29 | --agg_mode "variance" \ 30 | --depth_mode "regression" \ 31 | --inverse_depth \ 32 | --numdepth 192 \ 33 | --interval_scale 1.06 ${@:1} | tee -a $log_dir/log.txt 34 | 35 | -------------------------------------------------------------------------------- /tools.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import os 3 | import math 4 | import numpy as np 5 | import torchvision.utils as vutils 6 | import torch.distributed as dist 7 | from torch.optim.lr_scheduler import LambdaLR 8 | 9 | def setup_seed(seed=3407): 10 | os.environ["PYTHONHASHSEED"]=str(seed) 11 | torch.manual_seed(seed) 12 | torch.cuda.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | 15 | np.random.seed(seed) 16 | 17 | 18 | class DictAverageMeter(object): 19 | def __init__(self): 20 | self.sum_data = {} 21 | self.avg_data = {} 22 | self.count = 0 23 | 24 | def update(self, new_input): 25 | self.count += 1 26 | if len(self.sum_data) == 0: 27 | for k, v in new_input.items(): 28 | if not isinstance(v, float): 29 | raise NotImplementedError("invalid data {}: {}".format(k, type(v))) 30 | self.sum_data[k] = v 31 | self.avg_data[k] = v 32 | else: 33 | for k, v in new_input.items(): 34 | if not isinstance(v, float): 35 | raise NotImplementedError("invalid data {}: {}".format(k, type(v))) 36 | self.sum_data[k] += v 37 | self.avg_data[k] = self.sum_data[k] / self.count 38 | 39 | 40 | def write_cam(file, cam): 41 | f = open(file, "w") 42 | f.write('extrinsic\n') 43 | for i in range(0, 4): 44 | for j in range(0, 4): 45 | f.write(str(cam[0][i][j]) + ' ') 46 | f.write('\n') 47 | f.write('\n') 48 | 49 | f.write('intrinsic\n') 50 | for i in range(0, 3): 51 | for j in range(0, 3): 52 | f.write(str(cam[1][i][j]) + ' ') 53 | f.write('\n') 54 | 55 | f.write('\n' + str(cam[1][3][0]) + ' ' + str(cam[1][3][1]) + ' ' + str(cam[1][3][2]) + ' ' + str(cam[1][3][3]) + '\n') 56 | 57 | f.close() 58 | 59 | 60 | # convert a function into recursive style to handle nested dict/list/tuple variables 61 | def make_recursive_func(func): 62 | def wrapper(vars): 63 | if isinstance(vars, list): 64 | return [wrapper(x) for x in vars] 65 | elif isinstance(vars, tuple): 66 | return tuple([wrapper(x) for x in vars]) 67 | elif isinstance(vars, dict): 68 | return {k: wrapper(v) for k, v in vars.items()} 69 | else: 70 | return func(vars) 71 | 72 | return wrapper 73 | 74 | 75 | def save_scalars(logger, mode, scalar_dict, global_step): 76 | scalar_dict = tensor2float(scalar_dict) 77 | for key, value in scalar_dict.items(): 78 | if not isinstance(value, (list, tuple)): 79 | name = '{}/{}'.format(mode, key) 80 | logger.add_scalar(name, value, global_step) 81 | else: 82 | for idx in range(len(value)): 83 | name = 
'{}/{}_{}'.format(mode, key, idx) 84 | logger.add_scalar(name, value[idx], global_step) 85 | 86 | 87 | def save_images(logger, mode, images_dict, global_step): 88 | images_dict = tensor2numpy(images_dict) 89 | 90 | def preprocess(name, img): 91 | if not (len(img.shape) == 3 or len(img.shape) == 4): 92 | raise NotImplementedError("invalid img shape {}:{} in save_images".format(name, img.shape)) 93 | if len(img.shape) == 3: 94 | img = img[:, np.newaxis, :, :] 95 | img = torch.from_numpy(img[:1]) 96 | return vutils.make_grid(img, padding=0, nrow=1, normalize=True, scale_each=True) 97 | 98 | for key, value in images_dict.items(): 99 | if not isinstance(value, (list, tuple)): 100 | name = '{}/{}'.format(mode, key) 101 | logger.add_image(name, preprocess(name, value), global_step) 102 | else: 103 | for idx in range(len(value)): 104 | name = '{}/{}_{}'.format(mode, key, idx) 105 | logger.add_image(name, preprocess(name, value[idx]), global_step) 106 | 107 | 108 | @make_recursive_func 109 | def tensor2numpy(vars): 110 | if isinstance(vars, np.ndarray): 111 | return vars 112 | elif isinstance(vars, torch.Tensor): 113 | return vars.detach().cpu().numpy().copy() 114 | else: 115 | raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) 116 | 117 | 118 | @make_recursive_func 119 | def tensor2float(vars): 120 | if isinstance(vars, float): 121 | return vars 122 | elif isinstance(vars, torch.Tensor): 123 | return vars.data.item() 124 | else: 125 | raise NotImplementedError("invalid input type {} for tensor2float".format(type(vars))) 126 | 127 | 128 | def reduce_scalar_outputs(scalar_outputs): 129 | world_size = get_world_size() 130 | if world_size < 2: 131 | return scalar_outputs 132 | with torch.no_grad(): 133 | names = [] 134 | scalars = [] 135 | for k in sorted(scalar_outputs.keys()): 136 | names.append(k) 137 | scalars.append(scalar_outputs[k]) 138 | scalars = torch.stack(scalars, dim=0) 139 | dist.reduce(scalars, dst=0) 140 | if dist.get_rank() == 0: 141 | # only main process gets accumulated, so only divide by 142 | # world_size in this case 143 | scalars /= world_size 144 | reduced_scalars = {k: v for k, v in zip(names, scalars)} 145 | 146 | return reduced_scalars 147 | 148 | 149 | @make_recursive_func 150 | def tocuda(vars): 151 | if isinstance(vars, torch.Tensor): 152 | return vars.to(torch.device("cuda")) 153 | elif isinstance(vars, str): 154 | return vars 155 | else: 156 | raise NotImplementedError("invalid input type {} for tensor2numpy".format(type(vars))) 157 | 158 | 159 | # a wrapper to compute metrics for each image individually 160 | def compute_metrics_for_each_image(metric_func): 161 | def wrapper(depth_est, depth_gt, mask, *args): 162 | batch_size = depth_gt.shape[0] 163 | results = [] 164 | # compute result one by one 165 | for idx in range(batch_size): 166 | ret = metric_func(depth_est[idx], depth_gt[idx], mask[idx], *args) 167 | if torch.isnan(ret): 168 | results.append(torch.zeros_like(ret)) 169 | else: 170 | results.append(ret) 171 | return torch.stack(results).mean() 172 | 173 | return wrapper 174 | 175 | 176 | @torch.no_grad() 177 | @compute_metrics_for_each_image 178 | def AbsDepthError_metrics(depth_est, depth_gt, mask, thres=None): 179 | depth_est, depth_gt = depth_est[mask], depth_gt[mask] 180 | error = (depth_est - depth_gt).abs() 181 | if thres is not None: 182 | error = error[(error >= float(thres[0])) & (error <= float(thres[1]))] 183 | if error.shape[0] == 0: 184 | return torch.tensor(0, device=error.device, dtype=error.dtype) 185 | return 
torch.mean(error) 186 | 187 | 188 | @torch.no_grad() 189 | @compute_metrics_for_each_image 190 | def Thres_metrics(depth_est, depth_gt, mask, thres,return_Mean=False): 191 | assert isinstance(thres, (int, float)) 192 | depth_est, depth_gt = depth_est[mask], depth_gt[mask] 193 | errors = torch.abs(depth_est - depth_gt) 194 | err_mask = errors > thres 195 | if return_Mean==True: 196 | return torch.mean(err_mask.float()),errors[~err_mask].mean() 197 | else: 198 | if torch.isnan(torch.mean(err_mask.float())): 199 | return torch.zeros_like(torch.mean(err_mask.float())) 200 | else: 201 | return torch.mean(err_mask.float()) 202 | 203 | 204 | def generate_pointcloud(rgb, depth, ply_file, intr, scale=1.0): 205 | """ 206 | Generate a colored point cloud in PLY format from a color and a depth image. 207 | 208 | Input: 209 | rgb_file -- filename of color image 210 | depth_file -- filename of depth image 211 | ply_file -- filename of ply file 212 | 213 | """ 214 | fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2] 215 | points = [] 216 | for v in range(rgb.shape[0]): 217 | for u in range(rgb.shape[1]): 218 | color = rgb[v, u] # rgb.getpixel((u, v)) 219 | Z = depth[v, u] / scale 220 | if Z == 0: continue 221 | X = (u - cx) * Z / fx 222 | Y = (v - cy) * Z / fy 223 | points.append("%f %f %f %d %d %d 0\n" % (X, Y, Z, color[0], color[1], color[2])) 224 | file = open(ply_file, "w") 225 | file.write('''ply 226 | format ascii 1.0 227 | element vertex %d 228 | property float x 229 | property float y 230 | property float z 231 | property uchar red 232 | property uchar green 233 | property uchar blue 234 | property uchar alpha 235 | end_header 236 | %s 237 | ''' % (len(points), "".join(points))) 238 | file.close() 239 | print("save ply, fx:{}, fy:{}, cx:{}, cy:{}".format(fx, fy, cx, cy)) 240 | 241 | 242 | def get_schedular(optimizer, args): 243 | warmup = args.warmup 244 | milestones = np.array(args.milestones) 245 | decay = args.lr_decay 246 | if args.scheduler == "steplr": 247 | lambda_func = lambda step: 1 / 3 * (1 - step / warmup) + step / warmup if step < warmup \ 248 | else (decay ** (milestones <= step).sum()) 249 | elif args.scheduler == "cosinelr": 250 | max_lr = args.lr 251 | min_lr = max_lr * (args.lr_decay ** 3) 252 | T_max = args.epochs 253 | lambda_func = lambda step: 1 / 3 * (1 - step / warmup) + step / warmup if step < warmup else \ 254 | (min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos((step - warmup) / (T_max - warmup) * math.pi))) / max_lr 255 | 256 | scheduler = LambdaLR(optimizer, lambda_func) 257 | return scheduler 258 | 259 | 260 | def is_dist_avail_and_initialized(): 261 | if not dist.is_available(): 262 | return False 263 | if not dist.is_initialized(): 264 | return False 265 | return True 266 | 267 | 268 | def get_world_size(): 269 | if not is_dist_avail_and_initialized(): 270 | return 1 271 | return dist.get_world_size() 272 | 273 | 274 | def get_rank(): 275 | if not is_dist_avail_and_initialized(): 276 | return 0 277 | return dist.get_rank() 278 | 279 | 280 | def is_main_process(): 281 | return get_rank() == 0 282 | 283 | 284 | def setup_for_distributed(is_master): 285 | """ 286 | This function disables printing when not in master process 287 | """ 288 | import builtins as __builtin__ 289 | builtin_print = __builtin__.print 290 | 291 | def print(*args, **kwargs): 292 | force = kwargs.pop('force', False) 293 | if is_master or force: 294 | builtin_print(*args, **kwargs) 295 | 296 | __builtin__.print = print 297 | 298 | 299 | def init_distributed_mode(args): 300 | if 
'RANK' in os.environ and 'WORLD_SIZE' in os.environ: 301 | args.rank = int(os.environ["RANK"]) 302 | args.world_size = int(os.environ['WORLD_SIZE']) 303 | args.gpu = int(os.environ['LOCAL_RANK']) 304 | elif 'SLURM_PROCID' in os.environ: 305 | args.rank = int(os.environ['SLURM_PROCID']) 306 | args.gpu = args.rank % torch.cuda.device_count() 307 | elif hasattr(args, "rank"): 308 | pass 309 | else: 310 | print('Not using distributed mode') 311 | args.distributed = False 312 | return 313 | 314 | args.distributed = True 315 | 316 | torch.cuda.set_device(args.gpu) 317 | args.dist_backend = 'nccl' 318 | print('| distributed init (rank {}): {}'.format( 319 | args.rank, args.dist_url), flush=True) 320 | torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url, 321 | world_size=args.world_size, rank=args.rank) 322 | setup_for_distributed(args.rank == 0) 323 | --------------------------------------------------------------------------------
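For reference, `get_schedular` in `tools.py` ramps the learning-rate multiplier linearly from one third of the base value over the first `warmup` steps, then applies either step decay at the listed milestones (`steplr`) or a cosine decay down to `lr * lr_decay^3` (`cosinelr`). The sketch below is illustrative only and not part of the repository: it re-evaluates the `steplr` multiplier with the values from `scripts/train.sh`; how often `main.py` actually calls `scheduler.step()` is not shown here, so the step values are generic.

```python
# Illustrative sketch: recompute the "steplr" multiplier from get_schedular() in tools.py
# using the settings in scripts/train.sh (lr 0.001, warmup 0.2, milestones 10 12 14,
# lr_decay 0.5). The stepping granularity (per epoch vs. per iteration) is an assumption
# left open here; the loop just samples a few step values.
import numpy as np

base_lr, warmup, decay = 0.001, 0.2, 0.5
milestones = np.array([10, 12, 14])

def lr_multiplier(step):
    if step < warmup:
        # linear warm-up from 1/3 of the base lr up to the full base lr
        return 1 / 3 * (1 - step / warmup) + step / warmup
    # after warm-up: multiply by decay once for every milestone already passed
    return decay ** (milestones <= step).sum()

for step in [0, 0.1, 1, 9, 10, 12, 14, 15]:
    print(f"step {step:>4}: lr = {base_lr * lr_multiplier(step):.6f}")
```

With `--lr 0.001` this ramps to 0.001 during warm-up and then halves to 0.0005, 0.00025 and 0.000125 at the three milestones.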