├── .gitignore ├── LICENSE ├── README.md ├── configs └── dtu.yaml ├── data ├── fastmvsnet ├── __init__.py ├── config.py ├── dataset.py ├── functions │ ├── __init__.py │ └── functions.py ├── model.py ├── networks.py ├── nn │ ├── __init__.py │ ├── conv.py │ ├── freeze_weight.py │ ├── functional.py │ ├── init.py │ ├── linear.py │ └── mlp.py ├── solver.py ├── test.py ├── train.py └── utils │ ├── __init__.py │ ├── checkpoint.py │ ├── eval_file_logger.py │ ├── feature_fetcher.py │ ├── file_logger.py │ ├── io.py │ ├── logger.py │ ├── metric_logger.py │ ├── preprocess.py │ ├── tensorboard_logger.py │ └── torch_utils.py ├── outputs └── pretrained.pth ├── requirements.txt └── tools └── depthfusion.py /.gitignore: -------------------------------------------------------------------------------- 1 | #ide 2 | .idea/ 3 | 4 | # checkpoints 5 | data/ 6 | 7 | # Byte-compiled / optimized / DLL files 8 | __pycache__/ 9 | *.py[cod] 10 | *$py.class 11 | 12 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 svip-lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Fast-MVSNet 2 | 3 | PyTorch implementation of our CVPR 2020 paper: 4 | 5 | [Fast-MVSNet: Sparse-to-Dense Multi-View Stereo With Learned Propagation and Gauss-Newton Refinement](https://arxiv.org/pdf/2003.13017.pdf) 6 | 7 | Zehao Yu, 8 | [Shenghua Gao](http://sist.shanghaitech.edu.cn/sist_en/2018/0820/c3846a31775/page.htm) 9 | 10 | ## How to use 11 | ```bash 12 | git clone git@github.com:svip-lab/FastMVSNet.git 13 | ``` 14 | ### Installation 15 | ```bash 16 | pip install -r requirements.txt 17 | ``` 18 | 19 | ### Training 20 | * Download the preprocessed [DTU training data](https://drive.google.com/file/d/1eDjh-_bxKKnEuz5h-HXS7EDJn59clx6V/view) from [MVSNet](https://github.com/YoYo000/MVSNet) and unzip it to ```data/dtu```. 21 | * Train the network 22 | 23 | ```python fastmvsnet/train.py --cfg configs/dtu.yaml``` 24 | 25 | You could change the batch size in the configuration file according to your own pc. 
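For reference, the batch size lives under `TRAIN.BATCH_SIZE` in `configs/dtu.yaml`, and the comment in `fastmvsnet/config.py` notes that the learning rate should be changed according to the batch size. A minimal edit for a smaller GPU (the values below are only examples) might look like:

```yaml
# configs/dtu.yaml (excerpt) -- example values only
SOLVER:
  BASE_LR: 0.00025   # default 0.0005; scale with the batch size
TRAIN:
  BATCH_SIZE: 8      # default 16
```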
26 | 27 | ### Testing 28 | * Download the [rectified images](http://roboimagedata2.compute.dtu.dk/data/MVS/Rectified.zip) from [DTU benchmark](http://roboimagedata.compute.dtu.dk/?page_id=36) and unzip it to ```data/dtu/Eval```. 29 | 30 | * Test with the pretrained model 31 | 32 | ```python fastmvsnet/test.py --cfg configs/dtu.yaml TEST.WEIGHT outputs/pretrained.pth``` 33 | 34 | ### Depth Fusion 35 | We need to apply depth fusion ```tools/depthfusion.py``` to get the complete point cloud. Please refer to [MVSNet](https://github.com/YoYo000/MVSNet) for more details. 36 | 37 | ```bash 38 | python tools/depthfusion.py -f dtu -n flow2 39 | ``` 40 | 41 | ## Acknowledgements 42 | Most of the code is borrowed from [PointMVSNet](https://github.com/callmeray/PointMVSNet). We thank Rui Chen for his great works and repos. 43 | 44 | ## Citation 45 | Please cite our paper for any purpose of usage. 46 | ``` 47 | @inproceedings{Yu_2020_fastmvsnet, 48 | author = {Zehao Yu and Shenghua Gao}, 49 | title = {Fast-MVSNet: Sparse-to-Dense Multi-View Stereo With Learned Propagation and Gauss-Newton Refinement}, 50 | booktitle = {CVPR}, 51 | year = {2020} 52 | } 53 | ``` 54 | -------------------------------------------------------------------------------- /configs/dtu.yaml: -------------------------------------------------------------------------------- 1 | DATA: 2 | TRAIN: 3 | ROOT_DIR: "data/dtu2" 4 | NUM_VIRTUAL_PLANE: 48 5 | INTER_SCALE: 4.24 6 | VAL: 7 | ROOT_DIR: "data/dtu" 8 | TEST: 9 | ROOT_DIR: "data/dtu2" 10 | NUM_VIEW: 5 11 | IMG_HEIGHT: 960 12 | IMG_WIDTH: 1280 13 | NUM_VIRTUAL_PLANE: 96 14 | INTER_SCALE: 2.13 15 | NUM_WORKERS: 16 16 | MODEL: 17 | EDGE_CHANNELS: (32, 32, 64) 18 | TRAIN: 19 | IMG_SCALES: (0.25, 0.5) 20 | INTER_SCALES: (0.75, 0.375) 21 | TEST: 22 | IMG_SCALES: (0.25, 0.5) 23 | INTER_SCALES: (0.75, 0.15) 24 | SCHEDULER: 25 | TYPE: "StepLR" 26 | INIT_EPOCH: 4 27 | MAX_EPOCH: 16 28 | StepLR: 29 | gamma: 0.9 30 | step_size: 2 31 | SOLVER: 32 | BASE_LR: 0.0005 33 | WEIGHT_DECAY: 0.001 34 | TYPE: 'RMSprop' 35 | TRAIN: 36 | BATCH_SIZE: 16 37 | CHECKPOINT_PERIOD: 1 38 | LOG_PERIOD: 10 39 | TEST: 40 | WEIGHT: "outputs/dtu_wde3/model_016.pth" 41 | 42 | 43 | -------------------------------------------------------------------------------- /data: -------------------------------------------------------------------------------- 1 | /p300/yuzh/projects/PointMVSNet/data -------------------------------------------------------------------------------- /fastmvsnet/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/FastMVSNet/ccb686dda2717613c67d8a289dfe7b2aeb60e2fd/fastmvsnet/__init__.py -------------------------------------------------------------------------------- /fastmvsnet/config.py: -------------------------------------------------------------------------------- 1 | from yacs.config import CfgNode as CN 2 | from yacs.config import load_cfg 3 | 4 | 5 | _C = CN() 6 | 7 | # if set to @, the filename of config will be used by default 8 | _C.OUTPUT_DIR = "@" 9 | # Automatically resume weights from last checkpoints 10 | _C.AUTO_RESUME = True 11 | # For reproducibility...but not really because modern fast GPU libraries use 12 | # non-deterministic op implementations 13 | # -1 means not to set explicitly. 
14 | _C.RNG_SEED = 1 15 | 16 | # ----------------------------------------------------------------------------- 17 | # DATA 18 | # ----------------------------------------------------------------------------- 19 | 20 | _C.DATA = CN() 21 | 22 | _C.DATA.NUM_WORKERS = 1 23 | 24 | _C.DATA.TRAIN = CN() 25 | _C.DATA.TRAIN.ROOT_DIR = "" 26 | _C.DATA.TRAIN.NUM_VIEW = 3 27 | _C.DATA.TRAIN.NUM_VIRTUAL_PLANE = 48 28 | _C.DATA.TRAIN.INTER_SCALE = 4.24 29 | 30 | 31 | _C.DATA.VAL = CN() 32 | _C.DATA.VAL = CN() 33 | _C.DATA.VAL.ROOT_DIR = "" 34 | _C.DATA.VAL.NUM_VIEW = 3 35 | 36 | 37 | _C.DATA.TEST = CN() 38 | _C.DATA.TEST = CN() 39 | _C.DATA.TEST.ROOT_DIR = "" 40 | _C.DATA.TEST.NUM_VIEW = 3 41 | _C.DATA.TEST.IMG_HEIGHT = 512 42 | _C.DATA.TEST.IMG_WIDTH = 640 43 | _C.DATA.TEST.NUM_VIRTUAL_PLANE = 48 44 | _C.DATA.TEST.INTER_SCALE = 4.24 45 | 46 | # ----------------------------------------------------------------------------- 47 | # MODEL 48 | # ----------------------------------------------------------------------------- 49 | 50 | _C.MODEL = CN() 51 | _C.MODEL.WEIGHT = "" 52 | 53 | _C.MODEL.EDGE_CHANNELS = () 54 | _C.MODEL.FLOW_CHANNELS = (64, 64, 16, 1) 55 | _C.MODEL.NUM_VIRTUAL_PLANE = 48 56 | _C.MODEL.IMG_BASE_CHANNELS = 8 57 | _C.MODEL.VOL_BASE_CHANNELS = 8 58 | 59 | _C.MODEL.VALID_THRESHOLD = 8.0 60 | 61 | _C.MODEL.TRAIN = CN() 62 | _C.MODEL.TRAIN.IMG_SCALES = (0.125, 0.25) 63 | _C.MODEL.TRAIN.INTER_SCALES = (0.75, 0.375) 64 | 65 | _C.MODEL.VAL = CN() 66 | _C.MODEL.VAL.IMG_SCALES = (0.125, 0.25) 67 | _C.MODEL.VAL.INTER_SCALES = (0.75, 0.375) 68 | 69 | _C.MODEL.TEST = CN() 70 | _C.MODEL.TEST.IMG_SCALES = (0.125, 0.25, 0.5) 71 | _C.MODEL.TEST.INTER_SCALES = (1.0, 0.75, 0.15) 72 | 73 | # ---------------------------------------------------------------------------- # 74 | # Solver (optimizer) 75 | # ---------------------------------------------------------------------------- # 76 | 77 | _C.SOLVER = CN() 78 | 79 | # Type of optimizer 80 | _C.SOLVER.TYPE = "RMSprop" 81 | 82 | # Basic parameters of solvers 83 | # Notice to change learning rate according to batch size 84 | _C.SOLVER.BASE_LR = 0.001 85 | 86 | _C.SOLVER.WEIGHT_DECAY = 0.0 87 | 88 | # Specific parameters of solvers 89 | _C.SOLVER.RMSprop = CN() 90 | _C.SOLVER.RMSprop.alpha = 0.9 91 | 92 | _C.SOLVER.SGD = CN() 93 | _C.SOLVER.SGD.momentum = 0.9 94 | 95 | # ---------------------------------------------------------------------------- # 96 | # Scheduler (learning rate schedule) 97 | # ---------------------------------------------------------------------------- # 98 | _C.SCHEDULER = CN() 99 | _C.SCHEDULER.TYPE = "" 100 | 101 | _C.SCHEDULER.INIT_EPOCH = 2 102 | _C.SCHEDULER.MAX_EPOCH = 2 103 | 104 | _C.SCHEDULER.StepLR = CN() 105 | _C.SCHEDULER.StepLR.step_size = 0 106 | _C.SCHEDULER.StepLR.gamma = 0.1 107 | 108 | _C.SCHEDULER.MultiStepLR = CN() 109 | _C.SCHEDULER.MultiStepLR.milestones = () 110 | _C.SCHEDULER.MultiStepLR.gamma = 0.1 111 | 112 | # ---------------------------------------------------------------------------- # 113 | # Specific train options 114 | # ---------------------------------------------------------------------------- # 115 | _C.TRAIN = CN() 116 | 117 | _C.TRAIN.BATCH_SIZE = 1 118 | 119 | 120 | # The period to save a checkpoint 121 | _C.TRAIN.CHECKPOINT_PERIOD = 1000 122 | _C.TRAIN.LOG_PERIOD = 10 123 | # The period to validate 124 | _C.TRAIN.VAL_PERIOD = 0 125 | # Data augmentation. 
The format is "method" or ("method", *args) 126 | # For example, ("PointCloudRotate", ("PointCloudRotatePerturbation",0.1, 0.2)) 127 | _C.TRAIN.AUGMENTATION = () 128 | 129 | # Regex patterns of modules and/or parameters to freeze 130 | # For example, ("bn",) will freeze all batch normalization layers' weight and bias; 131 | # And ("module:bn",) will freeze all batch normalization layers' running mean and var. 132 | _C.TRAIN.FROZEN_PATTERNS = () 133 | 134 | _C.TRAIN.VAL_METRIC = "<1_cor" 135 | 136 | # ---------------------------------------------------------------------------- # 137 | # Specific test options 138 | # ---------------------------------------------------------------------------- # 139 | _C.TEST = CN() 140 | 141 | _C.TEST.BATCH_SIZE = 1 142 | 143 | # The path of weights to be tested. "@" has similar syntax as OUTPUT_DIR. 144 | # If not set, the last checkpoint will be used by default. 145 | _C.TEST.WEIGHT = "" 146 | 147 | # Data augmentation. 148 | _C.TEST.AUGMENTATION = () 149 | 150 | _C.TEST.LOG_PERIOD = 10 151 | 152 | 153 | def load_cfg_from_file(cfg_filename): 154 | """Load config from a file 155 | 156 | Args: 157 | cfg_filename (str): 158 | 159 | Returns: 160 | CfgNode: loaded configuration 161 | 162 | """ 163 | with open(cfg_filename, "r") as f: 164 | cfg = load_cfg(f) 165 | 166 | cfg_template = _C 167 | cfg_template.merge_from_other_cfg(cfg) 168 | return cfg_template 169 | 170 | 171 | 172 | 173 | 174 | -------------------------------------------------------------------------------- /fastmvsnet/dataset.py: -------------------------------------------------------------------------------- 1 | import os 2 | import os.path as osp 3 | from tqdm import tqdm 4 | import cv2 5 | import numpy as np 6 | import scipy.io 7 | 8 | import torch 9 | from torch.utils.data import Dataset, DataLoader 10 | 11 | from fastmvsnet.utils.preprocess import mask_depth_image, norm_image, scale_dtu_input, crop_dtu_input 12 | import fastmvsnet.utils.io as io 13 | 14 | import random 15 | 16 | class DTU_Train_Val_Set(Dataset): 17 | training_set = [2, 6, 7, 8, 14, 16, 18, 19, 20, 22, 30, 31, 36, 39, 41, 42, 44, 18 | 45, 46, 47, 50, 51, 52, 53, 55, 57, 58, 60, 61, 63, 64, 65, 68, 69, 70, 71, 72, 19 | 74, 76, 83, 84, 85, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 20 | 101, 102, 103, 104, 105, 107, 108, 109, 111, 112, 113, 115, 116, 119, 120, 21 | 121, 122, 123, 124, 125, 126, 127, 128] 22 | validation_set = [3, 5, 17, 21, 28, 35, 37, 38, 40, 43, 56, 59, 66, 67, 82, 86, 106, 117] 23 | 24 | training_lighting_set = [0, 1, 2, 3, 4, 5, 6] 25 | validation_lighting_set = [3] 26 | 27 | mean = torch.tensor([1.97145182, -1.52387525, 651.07223895]) 28 | std = torch.tensor([84.45612252, 93.22252387, 80.08551226]) 29 | 30 | cluster_file_path = "Cameras/pair.txt" 31 | 32 | def __init__(self, root_dir, dataset_name, 33 | num_view=3, 34 | num_virtual_plane=128, 35 | interval_scale=1.6, 36 | ): 37 | 38 | self.root_dir = root_dir 39 | self.num_view = num_view 40 | self.interval_scale = interval_scale 41 | self.num_virtual_plane = num_virtual_plane 42 | 43 | self.cluster_file_path = osp.join(root_dir, self.cluster_file_path) 44 | self.cluster_list = open(self.cluster_file_path).read().split() 45 | # self.cluster_list = 46 | assert (dataset_name in ["train", "valid"]), "Unknown dataset_name: {}".format(dataset_name) 47 | 48 | if dataset_name == "train": 49 | self.data_set = self.training_set 50 | self.lighting_set = self.training_lighting_set 51 | elif dataset_name == "valid": 52 | self.data_set = 
self.validation_set 53 | self.lighting_set = self.validation_lighting_set 54 | 55 | self.path_list = self._load_dataset(self.data_set, self.lighting_set) 56 | 57 | def _load_dataset(self, dataset, lighting_set): 58 | path_list = [] 59 | for ind in dataset: 60 | image_folder = osp.join(self.root_dir, "Rectified/scan{}_train".format(ind)) 61 | cam_folder = osp.join(self.root_dir, "Cameras/train") 62 | depth_folder = osp.join(self.root_dir, "Depths/scan{}_train".format(ind)) 63 | 64 | for lighting_ind in lighting_set: 65 | # for each reference image 66 | for p in range(0, int(self.cluster_list[0])): 67 | paths = {} 68 | pts_paths = [] 69 | view_image_paths = [] 70 | view_cam_paths = [] 71 | view_depth_paths = [] 72 | 73 | # ref image 74 | ref_index = int(self.cluster_list[22 * p + 1]) 75 | ref_image_path = osp.join( 76 | image_folder, "rect_{:03d}_{}_r5000.png".format(ref_index + 1, lighting_ind)) 77 | ref_cam_path = osp.join(cam_folder, "{:08d}_cam.txt".format(ref_index)) 78 | ref_depth_path = osp.join(depth_folder, "depth_map_{:04d}.pfm".format(ref_index)) 79 | 80 | view_image_paths.append(ref_image_path) 81 | view_cam_paths.append(ref_cam_path) 82 | view_depth_paths.append(ref_depth_path) 83 | 84 | # view images 85 | for view in range(self.num_view - 1): 86 | view_index = int(self.cluster_list[22 * p + 2 * view + 3]) 87 | view_image_path = osp.join( 88 | image_folder, "rect_{:03d}_{}_r5000.png".format(view_index + 1, lighting_ind)) 89 | view_cam_path = osp.join(cam_folder, "{:08d}_cam.txt".format(view_index)) 90 | view_depth_path = osp.join(depth_folder, "depth_map_{:04d}.pfm".format(view_index)) 91 | view_image_paths.append(view_image_path) 92 | view_cam_paths.append(view_cam_path) 93 | view_depth_paths.append(view_depth_path) 94 | paths["view_image_paths"] = view_image_paths 95 | paths["view_cam_paths"] = view_cam_paths 96 | paths["view_depth_paths"] = view_depth_paths 97 | 98 | path_list.append(paths) 99 | 100 | return path_list 101 | 102 | def __getitem__(self, index): 103 | paths = self.path_list[index] 104 | images = [] 105 | cams = [] 106 | for view in range(self.num_view): 107 | while True: 108 | try: 109 | image = cv2.imread(paths["view_image_paths"][view]) 110 | #todo 111 | # image = norm_image(image) 112 | except Exception: 113 | print(paths["view_image_paths"][view]) 114 | continue 115 | break 116 | cam = io.load_cam_dtu(open(paths["view_cam_paths"][view]), 117 | num_depth=self.num_virtual_plane, 118 | interval_scale=self.interval_scale) 119 | images.append(image) 120 | cams.append(cam) 121 | 122 | depth_images = [] 123 | for depth_path in paths["view_depth_paths"]: 124 | depth_image = io.load_pfm(depth_path)[0] 125 | depth_images.append(depth_image) 126 | 127 | # mask out-of-range depth pixels (in a relaxed range) 128 | ref_depth = depth_images[0] 129 | depth_start = cams[0][1, 3, 0] + cams[0][1, 3, 1] 130 | depth_end = cams[0][1, 3, 0] + (self.num_virtual_plane - 2) * cams[0][1, 3, 1] 131 | ref_depth = mask_depth_image(ref_depth, depth_start, depth_end) 132 | 133 | depth_list = np.stack(depth_images, axis=0) 134 | img_list = np.stack(images, axis=0) 135 | cam_params_list = np.stack(cams, axis=0) 136 | 137 | img_list = torch.tensor(img_list).permute(0, 3, 1, 2).type(torch.float) 138 | cam_params_list = torch.tensor(cam_params_list).type(torch.float) 139 | ref_depth = torch.tensor(ref_depth).permute(2, 0, 1).type(torch.float) 140 | depth_list = torch.tensor(depth_list).unsqueeze(-1).permute(0, 3, 1, 2).type(torch.float) 141 | depth_list = depth_list * (depth_list > 
depth_start).float() * (depth_list < depth_end).float() 142 | 143 | return { 144 | "img_list": img_list, 145 | "cam_params_list": cam_params_list, 146 | "gt_depth_img": ref_depth, 147 | "depth_list": depth_list, 148 | "ref_img_path": paths["view_image_paths"][0], 149 | "mean": self.mean, 150 | "std": self.std, 151 | } 152 | 153 | def __len__(self): 154 | return len(self.path_list) 155 | 156 | 157 | class DTU_Test_Set(Dataset): 158 | test_set = [1, 4, 9, 10, 11, 12, 13, 15, 23, 24, 29, 32, 33, 34, 48, 49, 62, 75, 77, 159 | 110, 114, 118] 160 | test_lighting_set = [3] 161 | 162 | mean = torch.tensor([1.97145182, -1.52387525, 651.07223895]) 163 | std = torch.tensor([84.45612252, 93.22252387, 80.08551226]) 164 | 165 | cluster_file_path = "Cameras/pair.txt" 166 | 167 | def __init__(self, root_dir, dataset_name, 168 | num_view=3, 169 | height=1152, width=1600, 170 | num_virtual_plane=128, 171 | interval_scale=1.6, 172 | base_image_size=64, 173 | depth_folder=""): 174 | 175 | self.root_dir = root_dir 176 | self.num_view = num_view 177 | self.interval_scale = interval_scale 178 | self.num_virtual_plane = num_virtual_plane 179 | self.base_image_size = base_image_size 180 | self.height = height 181 | self.width = width 182 | self.depth_folder = depth_folder 183 | 184 | self.cluster_file_path = osp.join(root_dir, self.cluster_file_path) 185 | self.cluster_list = open(self.cluster_file_path).read().split() 186 | # self.cluster_list = 187 | assert (dataset_name in ["test"]), "Unknown dataset_name: {}".format(dataset_name) 188 | 189 | self.data_set = self.test_set 190 | self.lighting_set = self.test_lighting_set 191 | 192 | self.path_list = self._load_dataset(self.data_set, self.lighting_set) 193 | 194 | def _load_dataset(self, dataset, lighting_set): 195 | path_list = [] 196 | for ind in dataset: 197 | image_folder = osp.join(self.root_dir, "Eval/Rectified/scan{}".format(ind)) 198 | cam_folder = osp.join(self.root_dir, "Cameras") 199 | depth_folder = osp.join(self.depth_folder, "scan{}".format(ind)) 200 | 201 | for lighting_ind in lighting_set: 202 | # for each reference image 203 | for p in range(0, int(self.cluster_list[0])): 204 | paths = {} 205 | # pts_paths = [] 206 | view_image_paths = [] 207 | view_cam_paths = [] 208 | view_depth_paths = [] 209 | 210 | # ref image 211 | ref_index = int(self.cluster_list[22 * p + 1]) 212 | ref_image_path = osp.join( 213 | image_folder, "rect_{:03d}_{}_r5000.png".format(ref_index + 1, lighting_ind)) 214 | ref_cam_path = osp.join(cam_folder, "{:08d}_cam.txt".format(ref_index)) 215 | ref_depth_path = osp.join(depth_folder, "depth_map_{:04d}.pfm".format(ref_index)) 216 | 217 | view_image_paths.append(ref_image_path) 218 | view_cam_paths.append(ref_cam_path) 219 | view_depth_paths.append(ref_depth_path) 220 | 221 | # view images 222 | for view in range(self.num_view - 1): 223 | view_index = int(self.cluster_list[22 * p + 2 * view + 3]) 224 | view_image_path = osp.join( 225 | image_folder, "rect_{:03d}_{}_r5000.png".format(view_index + 1, lighting_ind)) 226 | view_cam_path = osp.join(cam_folder, "{:08d}_cam.txt".format(view_index)) 227 | view_depth_path = osp.join(depth_folder, "depth_map_{:04d}.pfm".format(view_index)) 228 | view_image_paths.append(view_image_path) 229 | view_cam_paths.append(view_cam_path) 230 | view_depth_paths.append(view_depth_path) 231 | paths["view_image_paths"] = view_image_paths 232 | paths["view_cam_paths"] = view_cam_paths 233 | paths["view_depth_paths"] = view_depth_paths 234 | 235 | path_list.append(paths) 236 | 237 | return path_list 238 
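# Note on the `cluster_list` indexing used in `_load_dataset` above: the
# arithmetic assumes `pair.txt` is one token stream whose first token is the
# number of reference views, followed by one 22-token block per reference view
# (the reference index, the candidate count, and ten (view index, score) pairs).
# Hence `cluster_list[22 * p + 1]` is the p-th reference index and
# `cluster_list[22 * p + 2 * view + 3]` steps through the candidate view ids.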
| 239 | def __getitem__(self, index): 240 | paths = self.path_list[index] 241 | depth_images = [] 242 | 243 | images = [] 244 | cams = [] 245 | for view in range(self.num_view): 246 | while True: 247 | try: 248 | image = cv2.imread(paths["view_image_paths"][view]) 249 | except Exception: 250 | print(paths["view_image_paths"][view]) 251 | continue 252 | break 253 | 254 | cam = io.load_cam_dtu(open(paths["view_cam_paths"][view]), 255 | num_depth=self.num_virtual_plane, 256 | interval_scale=self.interval_scale) 257 | 258 | images.append(image) 259 | cams.append(cam) 260 | 261 | if self.depth_folder: 262 | for depth_path in paths["view_depth_paths"]: 263 | depth_image = io.load_pfm(depth_path)[0] 264 | depth_images.append(depth_image) 265 | else: 266 | for depth_path in paths["view_depth_paths"]: 267 | depth_images.append(np.zeros((self.height, self.width), np.float)) 268 | 269 | ref_depth = depth_images[0].copy() 270 | 271 | h_scale = float(self.height) / images[0].shape[0] 272 | w_scale = float(self.width) / images[0].shape[1] 273 | if h_scale > 1 or w_scale > 1: 274 | print("max_h, max_w should < W and H!") 275 | exit() 276 | resize_scale = h_scale 277 | if w_scale > h_scale: 278 | resize_scale = w_scale 279 | scaled_input_images, scaled_input_cams, ref_depth = scale_dtu_input(images, cams, depth_image=ref_depth, 280 | scale=resize_scale) 281 | 282 | # crop to fit network 283 | croped_images, croped_cams, ref_depth = crop_dtu_input(scaled_input_images, scaled_input_cams, 284 | height=self.height, width=self.width, 285 | base_image_size=self.base_image_size, 286 | depth_image=ref_depth) 287 | ref_image = croped_images[0].copy() 288 | for i, image in enumerate(croped_images): 289 | croped_images[i] = norm_image(image) 290 | 291 | depth_list = np.stack(depth_images, axis=0) 292 | img_list = np.stack(croped_images, axis=0) 293 | cam_params_list = np.stack(croped_cams, axis=0) 294 | # cam_pos_list = np.stack(camspos, axis=0) 295 | 296 | img_list = torch.tensor(img_list).permute(0, 3, 1, 2).float() 297 | cam_params_list = torch.tensor(cam_params_list).float() 298 | depth_list = torch.tensor(depth_list).unsqueeze(-1).permute(0, 3, 1, 2).float() 299 | 300 | return { 301 | "img_list": img_list, 302 | "cam_params_list": cam_params_list, 303 | "gt_depth_img": ref_depth, 304 | "depth_list": depth_list, 305 | "ref_img_path": paths["view_image_paths"][0], 306 | "ref_img": ref_image, 307 | "mean": self.mean, 308 | "std": self.std, 309 | } 310 | 311 | def __len__(self): 312 | return len(self.path_list) 313 | 314 | 315 | def build_data_loader(cfg, mode="train"): 316 | if mode == "train": 317 | dataset = DTU_Train_Val_Set( 318 | root_dir=cfg.DATA.TRAIN.ROOT_DIR, 319 | dataset_name="train", 320 | num_view=cfg.DATA.TRAIN.NUM_VIEW, 321 | interval_scale=cfg.DATA.TRAIN.INTER_SCALE, 322 | num_virtual_plane=cfg.DATA.TRAIN.NUM_VIRTUAL_PLANE, 323 | ) 324 | elif mode == "val": 325 | dataset = DTU_Train_Val_Set( 326 | root_dir=cfg.DATA.VAL.ROOT_DIR, 327 | dataset_name="val", 328 | num_view=cfg.DATA.VAL.NUM_VIEW, 329 | interval_scale=cfg.DATA.TRAIN.INTER_SCALE, 330 | num_virtual_plane=cfg.DATA.TRAIN.NUM_VIRTUAL_PLANE, 331 | ) 332 | elif mode == "test": 333 | dataset = DTU_Test_Set( 334 | root_dir=cfg.DATA.TEST.ROOT_DIR, 335 | dataset_name="test", 336 | num_view=cfg.DATA.TEST.NUM_VIEW, 337 | height=cfg.DATA.TEST.IMG_HEIGHT, 338 | width=cfg.DATA.TEST.IMG_WIDTH, 339 | interval_scale=cfg.DATA.TEST.INTER_SCALE, 340 | num_virtual_plane=cfg.DATA.TEST.NUM_VIRTUAL_PLANE, 341 | ) 342 | else: 343 | raise ValueError("Unknown mode: 
{}.".format(mode)) 344 | 345 | if mode == "train": 346 | batch_size = cfg.TRAIN.BATCH_SIZE 347 | else: 348 | batch_size = cfg.TEST.BATCH_SIZE 349 | 350 | data_loader = DataLoader( 351 | dataset, 352 | batch_size, 353 | shuffle=(mode == "train"), 354 | num_workers=cfg.DATA.NUM_WORKERS, 355 | ) 356 | 357 | return data_loader 358 | -------------------------------------------------------------------------------- /fastmvsnet/functions/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/FastMVSNet/ccb686dda2717613c67d8a289dfe7b2aeb60e2fd/fastmvsnet/functions/__init__.py -------------------------------------------------------------------------------- /fastmvsnet/functions/functions.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | from fastmvsnet.nn.functional import pdist 4 | 5 | 6 | def get_pixel_grids(height, width): 7 | with torch.no_grad(): 8 | # texture coordinate 9 | x_linspace = torch.linspace(0.5, width - 0.5, width).view(1, width).expand(height, width) 10 | y_linspace = torch.linspace(0.5, height - 0.5, height).view(height, 1).expand(height, width) 11 | # y_coordinates, x_coordinates = torch.meshgrid(y_linspace, x_linspace) 12 | x_coordinates = x_linspace.contiguous().view(-1) 13 | y_coordinates = y_linspace.contiguous().view(-1) 14 | ones = torch.ones(height * width) 15 | indices_grid = torch.stack([x_coordinates, y_coordinates, ones], dim=0) 16 | return indices_grid 17 | 18 | 19 | def get_propability_map(cv, depth_map, depth_start, depth_interval): 20 | """get probability map from cost volume""" 21 | with torch.no_grad(): 22 | batch_size, channels, height, width = list(depth_map.size()) 23 | depth = cv.size(1) 24 | 25 | # byx coordinates, batched & flattened 26 | b_coordinates = torch.arange(batch_size, dtype=torch.int64) 27 | y_coordinates = torch.arange(height, dtype=torch.int64) 28 | x_coordinates = torch.arange(width, dtype=torch.int64) 29 | b_coordinates = b_coordinates.view(batch_size, 1, 1).expand(batch_size, height, width) 30 | y_coordinates = y_coordinates.view(1, height, 1).expand(batch_size, height, width) 31 | x_coordinates = x_coordinates.view(1, 1, width).expand(batch_size, height, width) 32 | 33 | b_coordinates = b_coordinates.contiguous().view(-1).type(torch.long) 34 | y_coordinates = y_coordinates.contiguous().view(-1).type(torch.long) 35 | x_coordinates = x_coordinates.contiguous().view(-1).type(torch.long) 36 | # b_coordinates = _repeat_(b_coordinates, batch_size) 37 | # y_coordinates = _repeat_(y_coordinates, batch_size) 38 | # x_coordinates = _repeat_(x_coordinates, batch_size) 39 | 40 | # d coordinates (floored and ceiled), batched & flattened 41 | d_coordinates = ((depth_map - depth_start.view(-1, 1, 1, 1)) / depth_interval.view(-1, 1, 1, 1)).view(-1) 42 | d_coordinates = torch.detach(d_coordinates) 43 | d_coordinates_left0 = torch.clamp(d_coordinates.floor(), 0, depth - 1).type(torch.long) 44 | d_coordinates_right0 = torch.clamp(d_coordinates.ceil(), 0, depth - 1).type(torch.long) 45 | 46 | # # get probability image by gathering 47 | prob_map_left0 = cv[b_coordinates, d_coordinates_left0, y_coordinates, x_coordinates] 48 | prob_map_right0 = cv[b_coordinates, d_coordinates_right0, y_coordinates, x_coordinates] 49 | 50 | prob_map = prob_map_left0 + prob_map_right0 51 | prob_map = prob_map.view(batch_size, 1, height, width) 52 | 53 | return prob_map 54 | 
-------------------------------------------------------------------------------- /fastmvsnet/model.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import collections 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from fastmvsnet.networks import * 9 | from fastmvsnet.functions.functions import get_pixel_grids, get_propability_map 10 | from fastmvsnet.utils.feature_fetcher import FeatureFetcher, FeatureGradFetcher, PointGrad, ProjectUVFetcher 11 | 12 | 13 | class FastMVSNet(nn.Module): 14 | def __init__(self, 15 | img_base_channels=8, 16 | vol_base_channels=8, 17 | flow_channels=(64, 64, 16, 1), 18 | k=16, 19 | ): 20 | super(FastMVSNet, self).__init__() 21 | self.k = k 22 | 23 | self.feature_fetcher = FeatureFetcher() 24 | self.feature_grad_fetcher = FeatureGradFetcher() 25 | self.point_grad_fetcher = PointGrad() 26 | 27 | self.coarse_img_conv = ImageConv(img_base_channels) 28 | self.coarse_vol_conv = VolumeConv(img_base_channels * 4, vol_base_channels) 29 | self.propagation_net = PropagationNet(img_base_channels) 30 | self.flow_img_conv = ImageConv(img_base_channels) 31 | 32 | def forward(self, data_batch, img_scales, inter_scales, isGN, isTest=False): 33 | preds = collections.OrderedDict() 34 | img_list = data_batch["img_list"] 35 | cam_params_list = data_batch["cam_params_list"] 36 | 37 | cam_extrinsic = cam_params_list[:, :, 0, :3, :4].clone() # (B, V, 3, 4) 38 | R = cam_extrinsic[:, :, :3, :3] 39 | t = cam_extrinsic[:, :, :3, 3].unsqueeze(-1) 40 | R_inv = torch.inverse(R) 41 | cam_intrinsic = cam_params_list[:, :, 1, :3, :3].clone() 42 | 43 | if isTest: 44 | cam_intrinsic[:, :, :2, :3] = cam_intrinsic[:, :, :2, :3] / 4.0 45 | 46 | depth_start = cam_params_list[:, 0, 1, 3, 0] 47 | depth_interval = cam_params_list[:, 0, 1, 3, 1] 48 | num_depth = cam_params_list[0, 0, 1, 3, 2].long() 49 | 50 | depth_end = depth_start + (num_depth - 1) * depth_interval 51 | 52 | batch_size, num_view, img_channel, img_height, img_width = list(img_list.size()) 53 | 54 | coarse_feature_maps = [] 55 | for i in range(num_view): 56 | curr_img = img_list[:, i, :, :, :] 57 | curr_feature_map = self.coarse_img_conv(curr_img)["conv2"] 58 | coarse_feature_maps.append(curr_feature_map) 59 | 60 | feature_list = torch.stack(coarse_feature_maps, dim=1) 61 | 62 | feature_channels, feature_height, feature_width = list(curr_feature_map.size())[1:] 63 | 64 | depths = [] 65 | for i in range(batch_size): 66 | depths.append(torch.linspace(depth_start[i], depth_end[i], num_depth, device=img_list.device) \ 67 | .view(1, 1, num_depth, 1)) 68 | depths = torch.stack(depths, dim=0) # (B, 1, 1, D, 1) 69 | 70 | feature_map_indices_grid = get_pixel_grids(feature_height, feature_width) 71 | # print("before:", feature_map_indices_grid.size()) 72 | feature_map_indices_grid = feature_map_indices_grid.view(1, 3, feature_height, feature_width)[:, :, ::2, ::2].contiguous() 73 | # print("after:", feature_map_indices_grid.size()) 74 | feature_map_indices_grid = feature_map_indices_grid.view(1, 1, 3, -1).expand(batch_size, 1, 3, -1).to(img_list.device) 75 | 76 | ref_cam_intrinsic = cam_intrinsic[:, 0, :, :].clone() 77 | uv = torch.matmul(torch.inverse(ref_cam_intrinsic).unsqueeze(1), feature_map_indices_grid) # (B, 1, 3, FH*FW) 78 | 79 | cam_points = (uv.unsqueeze(3) * depths).view(batch_size, 1, 3, -1) # (B, 1, 3, D*FH*FW) 80 | world_points = torch.matmul(R_inv[:, 0:1, :, :], cam_points - t[:, 0:1, :, :]).transpose(1, 2).contiguous() \ 81 | 
.view(batch_size, 3, -1) # (B, 3, D*FH*FW) 82 | 83 | preds["world_points"] = world_points 84 | 85 | num_world_points = world_points.size(-1) 86 | assert num_world_points == feature_height * feature_width * num_depth / 4 87 | 88 | point_features = self.feature_fetcher(feature_list, world_points, cam_intrinsic, cam_extrinsic) 89 | ref_feature = coarse_feature_maps[0] 90 | #print("before ref feature:", ref_feature.size()) 91 | ref_feature = ref_feature[:, :, ::2,::2].contiguous() 92 | #print("after ref feature:", ref_feature.size()) 93 | ref_feature = ref_feature.unsqueeze(2).expand(-1, -1, num_depth, -1, -1)\ 94 | .contiguous().view(batch_size,feature_channels,-1) 95 | point_features[:, 0, :, :] = ref_feature 96 | 97 | avg_point_features = torch.mean(point_features, dim=1) 98 | avg_point_features_2 = torch.mean(point_features ** 2, dim=1) 99 | 100 | point_features = avg_point_features_2 - (avg_point_features ** 2) 101 | 102 | cost_volume = point_features.view(batch_size, feature_channels, num_depth, feature_height // 2, feature_width // 2) 103 | 104 | filtered_cost_volume = self.coarse_vol_conv(cost_volume).squeeze(1) 105 | 106 | probability_volume = F.softmax(-filtered_cost_volume, dim=1) 107 | depth_volume = [] 108 | for i in range(batch_size): 109 | depth_array = torch.linspace(depth_start[i], depth_end[i], num_depth, device=depth_start.device) 110 | depth_volume.append(depth_array) 111 | depth_volume = torch.stack(depth_volume, dim=0) # (B, D) 112 | depth_volume = depth_volume.view(batch_size, num_depth, 1, 1).expand(probability_volume.shape) 113 | pred_depth_img = torch.sum(depth_volume * probability_volume, dim=1).unsqueeze(1) # (B, 1, FH, FW) 114 | 115 | prob_map = get_propability_map(probability_volume, pred_depth_img, depth_start, depth_interval) 116 | 117 | # image guided depth map propagation 118 | pred_depth_img = F.interpolate(pred_depth_img, (feature_height, feature_width), mode="nearest") 119 | prob_map = F.interpolate(prob_map, (feature_height, feature_width), mode="bilinear") 120 | pred_depth_img = self.propagation_net(pred_depth_img, img_list[:, 0, :, :, :]) 121 | 122 | preds["coarse_depth_map"] = pred_depth_img 123 | preds["coarse_prob_map"] = prob_map 124 | 125 | if isGN: 126 | feature_pyramids = {} 127 | chosen_conv = ["conv1", "conv2"] 128 | for conv in chosen_conv: 129 | feature_pyramids[conv] = [] 130 | for i in range(num_view): 131 | curr_img = img_list[:, i, :, :, :] 132 | curr_feature_pyramid = self.flow_img_conv(curr_img) 133 | for conv in chosen_conv: 134 | feature_pyramids[conv].append(curr_feature_pyramid[conv]) 135 | 136 | for conv in chosen_conv: 137 | feature_pyramids[conv] = torch.stack(feature_pyramids[conv], dim=1) 138 | 139 | if isTest: 140 | for conv in chosen_conv: 141 | feature_pyramids[conv] = torch.detach(feature_pyramids[conv]) 142 | 143 | 144 | def gn_update(estimated_depth_map, interval, image_scale, it): 145 | nonlocal chosen_conv 146 | # print(estimated_depth_map.size(), image_scale) 147 | flow_height, flow_width = list(estimated_depth_map.size())[2:] 148 | if flow_height != int(img_height * image_scale): 149 | flow_height = int(img_height * image_scale) 150 | flow_width = int(img_width * image_scale) 151 | estimated_depth_map = F.interpolate(estimated_depth_map, (flow_height, flow_width), mode="nearest") 152 | else: 153 | # if it is the same size return directly 154 | return estimated_depth_map 155 | # pass 156 | 157 | if isTest: 158 | estimated_depth_map = estimated_depth_map.detach() 159 | 160 | # GN step 161 | cam_intrinsic = 
cam_params_list[:, :, 1, :3, :3].clone() 162 | if isTest: 163 | cam_intrinsic[:, :, :2, :3] *= image_scale 164 | else: 165 | cam_intrinsic[:, :, :2, :3] *= (4 * image_scale) 166 | 167 | ref_cam_intrinsic = cam_intrinsic[:, 0, :, :].clone() 168 | feature_map_indices_grid = get_pixel_grids(flow_height, flow_width) \ 169 | .view(1, 1, 3, -1).expand(batch_size, 1, 3, -1).to(img_list.device) 170 | 171 | uv = torch.matmul(torch.inverse(ref_cam_intrinsic).unsqueeze(1), 172 | feature_map_indices_grid) # (B, 1, 3, FH*FW) 173 | 174 | interval_depth_map = estimated_depth_map 175 | cam_points = (uv * interval_depth_map.view(batch_size, 1, 1, -1)) 176 | world_points = torch.matmul(R_inv[:, 0:1, :, :], cam_points - t[:, 0:1, :, :]).transpose(1, 2) \ 177 | .contiguous().view(batch_size, 3, -1) # (B, 3, D*FH*FW) 178 | 179 | grad_pts = self.point_grad_fetcher(world_points, cam_intrinsic, cam_extrinsic) 180 | 181 | R_tar_ref = torch.bmm(R.view(batch_size * num_view, 3, 3), 182 | R_inv[:, 0:1, :, :].repeat(1, num_view, 1, 1).view(batch_size * num_view, 3, 3)) 183 | 184 | R_tar_ref = R_tar_ref.view(batch_size, num_view, 3, 3) 185 | d_pts_d_d = uv.unsqueeze(-1).permute(0, 1, 3, 2, 4).contiguous().repeat(1, num_view, 1, 1, 1) 186 | d_pts_d_d = R_tar_ref.unsqueeze(2) @ d_pts_d_d 187 | d_uv_d_d = torch.bmm(grad_pts.view(-1, 2, 3), d_pts_d_d.view(-1, 3, 1)).view(batch_size, num_view, 1, 188 | -1, 2, 1) 189 | all_features = [] 190 | for conv in chosen_conv: 191 | curr_feature = feature_pyramids[conv] 192 | c, h, w = list(curr_feature.size())[2:] 193 | curr_feature = curr_feature.contiguous().view(-1, c, h, w) 194 | curr_feature = F.interpolate(curr_feature, (flow_height, flow_width), mode="bilinear") 195 | curr_feature = curr_feature.contiguous().view(batch_size, num_view, c, flow_height, flow_width) 196 | 197 | all_features.append(curr_feature) 198 | 199 | all_features = torch.cat(all_features, dim=2) 200 | 201 | if isTest: 202 | point_features, point_features_grad = \ 203 | self.feature_grad_fetcher.test_forward(all_features, world_points, cam_intrinsic, cam_extrinsic) 204 | else: 205 | point_features, point_features_grad = \ 206 | self.feature_grad_fetcher(all_features, world_points, cam_intrinsic, cam_extrinsic) 207 | 208 | c = all_features.size(2) 209 | d_uv_d_d_tmp = d_uv_d_d.repeat(1, 1, c, 1, 1, 1) 210 | # print("d_uv_d_d tmp size:", d_uv_d_d_tmp.size()) 211 | J = point_features_grad.view(-1, 1, 2) @ d_uv_d_d_tmp.view(-1, 2, 1) 212 | J = J.view(batch_size, num_view, c, -1, 1)[:, 1:, ...].contiguous()\ 213 | .permute(0, 3, 1, 2, 4).contiguous().view(-1, c * (num_view - 1), 1) 214 | 215 | # print(J.size()) 216 | resid = point_features[:, 1:, ...] - point_features[:, 0:1, ...] 
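# Gauss-Newton step, solved per pixel below: the only unknown is the scalar
# depth, so J has shape (-1, C*(V-1), 1) and both H = J^T J and b = -J^T r
# collapse to 1x1 values per pixel; the update delta = b / (H + 1e-6) is
# therefore a plain division rather than a matrix solve. `first_resid`, the
# summed absolute residual before the update, is kept so the update can later
# be accepted only where it actually reduces the residual.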
217 | first_resid = torch.sum(torch.abs(resid), dim=(1, 2)) 218 | # print(resid.size()) 219 | resid = resid.permute(0, 3, 1, 2).contiguous().view(-1, c * (num_view - 1), 1) 220 | 221 | J_t = torch.transpose(J, 1, 2) 222 | H = J_t @ J 223 | b = -J_t @ resid 224 | delta = b / (H + 1e-6) 225 | # #print(delta.size()) 226 | _, _, h, w = estimated_depth_map.size() 227 | flow_result = estimated_depth_map + delta.view(-1, 1, h, w) 228 | 229 | # check update results 230 | interval_depth_map = flow_result 231 | cam_points = (uv * interval_depth_map.view(batch_size, 1, 1, -1)) 232 | world_points = torch.matmul(R_inv[:, 0:1, :, :], cam_points - t[:, 0:1, :, :]).transpose(1, 2) \ 233 | .contiguous().view(batch_size, 3, -1) # (B, 3, D*FH*FW) 234 | 235 | point_features = \ 236 | self.feature_fetcher(all_features, world_points, cam_intrinsic, cam_extrinsic) 237 | 238 | resid = point_features[:, 1:, ...] - point_features[:, 0:1, ...] 239 | second_resid = torch.sum(torch.abs(resid), dim=(1, 2)) 240 | # print(first_resid.size(), second_resid.size()) 241 | 242 | # only accept good update 243 | flow_result = torch.where((second_resid < first_resid).view(batch_size, 1, flow_height, flow_width), 244 | flow_result, estimated_depth_map) 245 | return flow_result 246 | 247 | for i, (img_scale, inter_scale) in enumerate(zip(img_scales, inter_scales)): 248 | if isTest: 249 | pred_depth_img = torch.detach(pred_depth_img) 250 | print("update: {}".format(i)) 251 | flow = gn_update(pred_depth_img, inter_scale* depth_interval, img_scale, i) 252 | preds["flow{}".format(i+1)] = flow 253 | pred_depth_img = flow 254 | 255 | return preds 256 | 257 | 258 | class PointMVSNetLoss(nn.Module): 259 | def __init__(self, valid_threshold): 260 | super(PointMVSNetLoss, self).__init__() 261 | self.maeloss = MAELoss() 262 | self.valid_maeloss = Valid_MAELoss(valid_threshold) 263 | 264 | def forward(self, preds, labels, isFlow): 265 | gt_depth_img = labels["gt_depth_img"] 266 | depth_interval = labels["cam_params_list"][:, 0, 1, 3, 1] 267 | 268 | coarse_depth_map = preds["coarse_depth_map"] 269 | resize_gt_depth = F.interpolate(gt_depth_img, (coarse_depth_map.shape[2], coarse_depth_map.shape[3])) 270 | coarse_loss = self.maeloss(coarse_depth_map, resize_gt_depth, depth_interval) 271 | 272 | losses = {} 273 | losses["coarse_loss"] = coarse_loss 274 | 275 | if isFlow: 276 | flow1 = preds["flow1"] 277 | resize_gt_depth = F.interpolate(gt_depth_img, (flow1.shape[2], flow1.shape[3])) 278 | flow1_loss = self.maeloss(flow1, resize_gt_depth, 0.75 * depth_interval) 279 | losses["flow1_loss"] = flow1_loss 280 | 281 | flow2 = preds["flow2"] 282 | resize_gt_depth = F.interpolate(gt_depth_img, (flow2.shape[2], flow2.shape[3])) 283 | flow2_loss = self.maeloss(flow2, resize_gt_depth, 0.375 * depth_interval) 284 | losses["flow2_loss"] = flow2_loss 285 | 286 | for k in losses.keys(): 287 | losses[k] /= float(len(losses.keys())) 288 | 289 | return losses 290 | 291 | 292 | def cal_less_percentage(pred_depth, gt_depth, depth_interval, threshold): 293 | shape = list(pred_depth.size()) 294 | mask_valid = (~torch.eq(gt_depth, 0.0)).type(torch.float) 295 | denom = torch.sum(mask_valid) + 1e-7 296 | interval_image = depth_interval.view(-1, 1, 1, 1).expand(shape) 297 | abs_diff_image = torch.abs(pred_depth - gt_depth) / interval_image 298 | 299 | pct = mask_valid * (abs_diff_image <= threshold).type(torch.float) 300 | 301 | pct = torch.sum(pct) / denom 302 | 303 | return pct 304 | 305 | 306 | def cal_valid_less_percentage(pred_depth, gt_depth, before_depth, 
depth_interval, threshold, valid_threshold): 307 | shape = list(pred_depth.size()) 308 | mask_true = (~torch.eq(gt_depth, 0.0)).type(torch.float) 309 | interval_image = depth_interval.view(-1, 1, 1, 1).expand(shape) 310 | abs_diff_image = torch.abs(pred_depth - gt_depth) / interval_image 311 | 312 | if before_depth.size(2) != shape[2]: 313 | before_depth = F.interpolate(before_depth, (shape[2], shape[3])) 314 | 315 | diff = torch.abs(before_depth - gt_depth) / interval_image 316 | mask_valid = (diff < valid_threshold).type(torch.float) 317 | mask_valid = mask_valid * mask_true 318 | 319 | denom = torch.sum(mask_valid) + 1e-7 320 | pct = mask_valid * (abs_diff_image <= threshold).type(torch.float) 321 | 322 | pct = torch.sum(pct) / denom 323 | 324 | return pct 325 | 326 | 327 | class PointMVSNetMetric(nn.Module): 328 | def __init__(self, valid_threshold): 329 | super(PointMVSNetMetric, self).__init__() 330 | self.valid_threshold = valid_threshold 331 | 332 | def forward(self, preds, labels, isFlow): 333 | gt_depth_img = labels["gt_depth_img"] 334 | depth_interval = labels["cam_params_list"][:, 0, 1, 3, 1] 335 | 336 | coarse_depth_map = preds["coarse_depth_map"] 337 | resize_gt_depth = F.interpolate(gt_depth_img, (coarse_depth_map.shape[2], coarse_depth_map.shape[3])) 338 | 339 | less_one_pct_coarse = cal_less_percentage(coarse_depth_map, resize_gt_depth, depth_interval, 1.0) 340 | less_three_pct_coarse = cal_less_percentage(coarse_depth_map, resize_gt_depth, depth_interval, 3.0) 341 | 342 | metrics = { 343 | "<1_pct_cor": less_one_pct_coarse, 344 | "<3_pct_cor": less_three_pct_coarse, 345 | } 346 | 347 | if isFlow: 348 | flow1 = preds["flow1"] 349 | resize_gt_depth = F.interpolate(gt_depth_img, (flow1.shape[2], flow1.shape[3])) 350 | 351 | less_one_pct_flow1 = cal_valid_less_percentage(flow1, resize_gt_depth, coarse_depth_map, 352 | 0.75 * depth_interval, 1.0, self.valid_threshold) 353 | less_three_pct_flow1 = cal_valid_less_percentage(flow1, resize_gt_depth, coarse_depth_map, 354 | 0.75 * depth_interval, 3.0, self.valid_threshold) 355 | 356 | metrics["<1_pct_flow1"] = less_one_pct_flow1 357 | metrics["<3_pct_flow1"] = less_three_pct_flow1 358 | 359 | flow2 = preds["flow2"] 360 | resize_gt_depth = F.interpolate(gt_depth_img, (flow2.shape[2], flow2.shape[3])) 361 | 362 | less_one_pct_flow2 = cal_valid_less_percentage(flow2, resize_gt_depth, flow1, 363 | 0.375 * depth_interval, 1.0, self.valid_threshold) 364 | less_three_pct_flow2 = cal_valid_less_percentage(flow2, resize_gt_depth, flow1, 365 | 0.375 * depth_interval, 3.0, self.valid_threshold) 366 | 367 | metrics["<1_pct_flow2"] = less_one_pct_flow2 368 | metrics["<3_pct_flow2"] = less_three_pct_flow2 369 | 370 | return metrics 371 | 372 | 373 | def build_pointmvsnet(cfg): 374 | net = FastMVSNet( 375 | img_base_channels=cfg.MODEL.IMG_BASE_CHANNELS, 376 | vol_base_channels=cfg.MODEL.VOL_BASE_CHANNELS, 377 | flow_channels=cfg.MODEL.FLOW_CHANNELS, 378 | ) 379 | 380 | loss_fn = PointMVSNetLoss( 381 | valid_threshold=cfg.MODEL.VALID_THRESHOLD, 382 | ) 383 | 384 | metric_fn = PointMVSNetMetric( 385 | valid_threshold=cfg.MODEL.VALID_THRESHOLD, 386 | ) 387 | 388 | return net, loss_fn, metric_fn 389 | 390 | 391 | -------------------------------------------------------------------------------- /fastmvsnet/networks.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from fastmvsnet.nn.conv import * 6 | import numpy as np 7 | 8 | 9 | 
class ImageConv(nn.Module): 10 | def __init__(self, base_channels, in_channels=3): 11 | super(ImageConv, self).__init__() 12 | self.base_channels = base_channels 13 | self.out_channels = 8 * base_channels 14 | self.conv0 = nn.Sequential( 15 | Conv2d(in_channels, base_channels, 3, 1, padding=1), 16 | Conv2d(base_channels, base_channels, 3, 1, padding=1), 17 | ) 18 | 19 | self.conv1 = nn.Sequential( 20 | Conv2d(base_channels, base_channels * 2, 5, stride=2, padding=2), 21 | Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), 22 | Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), 23 | ) 24 | 25 | self.conv2 = nn.Sequential( 26 | Conv2d(base_channels * 2, base_channels * 4, 5, stride=2, padding=2), 27 | Conv2d(base_channels * 4, base_channels * 4, 3, 1, padding=1), 28 | nn.Conv2d(base_channels * 4, base_channels * 4, 3, padding=1, bias=False) 29 | ) 30 | 31 | 32 | def forward(self, imgs): 33 | out_dict = {} 34 | 35 | conv0 = self.conv0(imgs) 36 | out_dict["conv0"] = conv0 37 | conv1 = self.conv1(conv0) 38 | out_dict["conv1"] = conv1 39 | conv2 = self.conv2(conv1) 40 | out_dict["conv2"] = conv2 41 | 42 | return out_dict 43 | 44 | 45 | 46 | class PropagationNet(nn.Module): 47 | def __init__(self, base_channels): 48 | super(PropagationNet, self).__init__() 49 | self.base_channels = base_channels 50 | 51 | self.img_conv = ImageConv(base_channels) 52 | 53 | self.conv1 = nn.Sequential( 54 | Conv2d(base_channels*4, base_channels * 4, 3, padding=1), 55 | Conv2d(base_channels * 4, base_channels * 2, 3, 1, padding=1), 56 | ) 57 | 58 | self.conv2 = nn.Sequential( 59 | Conv2d(base_channels * 4, base_channels * 2, 3, 1, padding=1), 60 | Conv2d(base_channels * 2, base_channels * 2, 3, 1, padding=1), 61 | ) 62 | 63 | self.conv3 = nn.Sequential( 64 | Conv2d(base_channels * 4, base_channels * 2, 3, 1, padding=1), 65 | nn.Conv2d(base_channels*2, 9, 3, padding=1, bias=False) 66 | ) 67 | 68 | self.unfold = nn.Unfold(kernel_size=(3,3), stride=1, padding=0) 69 | 70 | def forward(self, depth, img): 71 | img_featues = self.img_conv(img) 72 | img_conv2 = img_featues["conv2"] 73 | 74 | x = self.conv3(img_conv2) 75 | prob = F.softmax(x, dim=1) 76 | 77 | depth_pad = F.pad(depth, (1, 1, 1, 1), mode='replicate') 78 | depth_unfold = self.unfold(depth_pad) 79 | 80 | b, c, h, w = prob.size() 81 | prob = prob.view(b, 9, h*w) 82 | 83 | result_depth = torch.sum(depth_unfold * prob, dim=1) 84 | result_depth = result_depth.view(b, 1, h, w) 85 | return result_depth 86 | 87 | 88 | class VolumeConv(nn.Module): 89 | def __init__(self, in_channels, base_channels): 90 | super(VolumeConv, self).__init__() 91 | self.in_channels = in_channels 92 | self.out_channels = base_channels * 8 93 | self.base_channels = base_channels 94 | self.conv1_0 = Conv3d(in_channels, base_channels * 2, 3, stride=2, padding=1) 95 | self.conv2_0 = Conv3d(base_channels * 2, base_channels * 4, 3, stride=2, padding=1) 96 | self.conv3_0 = Conv3d(base_channels * 4, base_channels * 8, 3, stride=2, padding=1) 97 | 98 | self.conv0_1 = Conv3d(in_channels, base_channels, 3, 1, padding=1) 99 | 100 | self.conv1_1 = Conv3d(base_channels * 2, base_channels * 2, 3, 1, padding=1) 101 | self.conv2_1 = Conv3d(base_channels * 4, base_channels * 4, 3, 1, padding=1) 102 | 103 | self.conv3_1 = Conv3d(base_channels * 8, base_channels * 8, 3, 1, padding=1) 104 | self.conv4_0 = Deconv3d(base_channels * 8, base_channels * 4, 3, 2, padding=1, output_padding=1) 105 | self.conv5_0 = Deconv3d(base_channels * 4, base_channels * 2, 3, 2, padding=1, output_padding=1) 106 
| self.conv6_0 = Deconv3d(base_channels * 2, base_channels, 3, 2, padding=1, output_padding=1) 107 | 108 | self.conv6_2 = nn.Conv3d(base_channels, 1, 3, padding=1, bias=False) 109 | 110 | def forward(self, x): 111 | conv0_1 = self.conv0_1(x) 112 | 113 | conv1_0 = self.conv1_0(x) 114 | conv2_0 = self.conv2_0(conv1_0) 115 | conv3_0 = self.conv3_0(conv2_0) 116 | 117 | conv1_1 = self.conv1_1(conv1_0) 118 | conv2_1 = self.conv2_1(conv2_0) 119 | conv3_1 = self.conv3_1(conv3_0) 120 | 121 | conv4_0 = self.conv4_0(conv3_1) 122 | 123 | conv5_0 = self.conv5_0(conv4_0 + conv2_1) 124 | conv6_0 = self.conv6_0(conv5_0 + conv1_1) 125 | 126 | conv6_2 = self.conv6_2(conv6_0 + conv0_1) 127 | 128 | return conv6_2 129 | 130 | 131 | class MAELoss(nn.Module): 132 | def forward(self, pred_depth_image, gt_depth_image, depth_interval): 133 | """non zero mean absolute loss for one batch""" 134 | # shape = list(pred_depth_image) 135 | depth_interval = depth_interval.view(-1) 136 | mask_valid = (~torch.eq(gt_depth_image, 0.0)).type(torch.float) 137 | denom = torch.sum(mask_valid, dim=(1, 2, 3)) + 1e-7 138 | masked_abs_error = mask_valid * torch.abs(pred_depth_image - gt_depth_image) 139 | masked_mae = torch.sum(masked_abs_error, dim=(1, 2, 3)) 140 | masked_mae = torch.sum((masked_mae / depth_interval) / denom) 141 | 142 | return masked_mae 143 | 144 | 145 | class Valid_MAELoss(nn.Module): 146 | def __init__(self, valid_threshold=2.0): 147 | super(Valid_MAELoss, self).__init__() 148 | self.valid_threshold = valid_threshold 149 | 150 | def forward(self, pred_depth_image, gt_depth_image, depth_interval, before_depth_image): 151 | """non zero mean absolute loss for one batch""" 152 | # shape = list(pred_depth_image) 153 | pred_height = pred_depth_image.size(2) 154 | pred_width = pred_depth_image.size(3) 155 | depth_interval = depth_interval.view(-1) 156 | mask_true = (~torch.eq(gt_depth_image, 0.0)).type(torch.float) 157 | before_hight = before_depth_image.size(2) 158 | if before_hight != pred_height: 159 | before_depth_image = F.interpolate(before_depth_image, (pred_height, pred_width)) 160 | diff = torch.abs(gt_depth_image - before_depth_image) / depth_interval.view(-1, 1, 1, 1) 161 | mask_valid = (diff < self.valid_threshold).type(torch.float) 162 | mask_valid = mask_true * mask_valid 163 | denom = torch.sum(mask_valid, dim=(1, 2, 3)) + 1e-7 164 | masked_abs_error = mask_valid * torch.abs(pred_depth_image - gt_depth_image) 165 | masked_mae = torch.sum(masked_abs_error, dim=(1, 2, 3)) 166 | masked_mae = torch.sum((masked_mae / depth_interval) / denom) 167 | 168 | return masked_mae 169 | -------------------------------------------------------------------------------- /fastmvsnet/nn/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/FastMVSNet/ccb686dda2717613c67d8a289dfe7b2aeb60e2fd/fastmvsnet/nn/__init__.py -------------------------------------------------------------------------------- /fastmvsnet/nn/conv.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | from .init import init_uniform, init_bn 5 | 6 | 7 | class Conv1d(nn.Module): 8 | """Applies a 1D convolution over an input signal composed of several input planes. 
9 | optionally followed by batch normalization and relu activation 10 | 11 | Attributes: 12 | conv (nn.Module): convolution module 13 | bn (nn.Module): batch normalization module 14 | relu (bool): whether to activate by relu 15 | 16 | """ 17 | 18 | def __init__(self, in_channels, out_channels, kernel_size, 19 | relu=True, bn=True, bn_momentum=0.1, **kwargs): 20 | super(Conv1d, self).__init__() 21 | 22 | self.conv = nn.Conv1d(in_channels, out_channels, kernel_size, 23 | bias=(not bn), **kwargs) 24 | self.bn = nn.BatchNorm1d(out_channels, momentum=bn_momentum) if bn else None 25 | self.relu = relu 26 | 27 | self.init_weights() 28 | 29 | def forward(self, x): 30 | x = self.conv(x) 31 | if self.bn is not None: 32 | x = self.bn(x) 33 | if self.relu: 34 | x = F.relu(x, inplace=True) 35 | return x 36 | 37 | def init_weights(self): 38 | """default initialization""" 39 | init_uniform(self.conv) 40 | if self.bn is not None: 41 | init_bn(self.bn) 42 | 43 | 44 | class Conv2d(nn.Module): 45 | """Applies a 2D convolution (optionally with batch normalization and relu activation) 46 | over an input signal composed of several input planes. 47 | 48 | Attributes: 49 | conv (nn.Module): convolution module 50 | bn (nn.Module): batch normalization module 51 | relu (bool): whether to activate by relu 52 | 53 | Notes: 54 | Default momentum for batch normalization is set to be 0.01, 55 | 56 | """ 57 | 58 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 59 | relu=True, bn=True, bn_momentum=0.1, **kwargs): 60 | super(Conv2d, self).__init__() 61 | 62 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, 63 | bias=(not bn), **kwargs) 64 | self.kernel_size = kernel_size 65 | self.stride = stride 66 | self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None 67 | self.relu = relu 68 | 69 | self.init_weights() 70 | 71 | def forward(self, x): 72 | x = self.conv(x) 73 | if self.bn is not None: 74 | x = self.bn(x) 75 | if self.relu: 76 | x = F.relu(x, inplace=True) 77 | return x 78 | 79 | def init_weights(self): 80 | """default initialization""" 81 | init_uniform(self.conv) 82 | if self.bn is not None: 83 | init_bn(self.bn) 84 | 85 | 86 | class Conv2d_gn(nn.Module): 87 | """Applies a 2D convolution (optionally with batch normalization and relu activation) 88 | over an input signal composed of several input planes. 
89 | 90 | Attributes: 91 | conv (nn.Module): convolution module 92 | bn (nn.Module): batch normalization module 93 | relu (bool): whether to activate by relu 94 | 95 | Notes: 96 | Default momentum for batch normalization is set to be 0.01, 97 | 98 | """ 99 | 100 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 101 | relu=True, gn=True, group_channel=8, **kwargs): 102 | super(Conv2d_gn, self).__init__() 103 | 104 | self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride=stride, 105 | bias=(not gn), **kwargs) 106 | self.kernel_size = kernel_size 107 | self.stride = stride 108 | G = max(1, out_channels // group_channel) 109 | self.gn = nn.GroupNorm(G, out_channels) if gn else None 110 | self.relu = relu 111 | 112 | self.init_weights() 113 | 114 | def forward(self, x): 115 | x = self.conv(x) 116 | if self.gn is not None: 117 | x = self.gn(x) 118 | if self.relu: 119 | x = F.relu(x, inplace=True) 120 | return x 121 | 122 | def init_weights(self): 123 | """default initialization""" 124 | init_uniform(self.conv) 125 | 126 | 127 | class Conv3d(nn.Module): 128 | """Applies a 3D convolution (optionally with batch normalization and relu activation) 129 | over an input signal composed of several input planes. 130 | 131 | Attributes: 132 | conv (nn.Module): convolution module 133 | bn (nn.Module): batch normalization module 134 | relu (bool): whether to activate by relu 135 | 136 | Notes: 137 | Default momentum for batch normalization is set to be 0.01, 138 | 139 | """ 140 | 141 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 142 | relu=True, bn=True, bn_momentum=0.1, **kwargs): 143 | super(Conv3d, self).__init__() 144 | self.out_channels = out_channels 145 | self.kernel_size = kernel_size 146 | assert stride in [1, 2] 147 | self.stride = stride 148 | 149 | self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, 150 | bias=(not bn), **kwargs) 151 | self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None 152 | self.relu = relu 153 | 154 | self.init_weights() 155 | 156 | def forward(self, x): 157 | x = self.conv(x) 158 | if self.bn is not None: 159 | x = self.bn(x) 160 | if self.relu: 161 | x = F.relu(x, inplace=True) 162 | return x 163 | 164 | def init_weights(self): 165 | """default initialization""" 166 | init_uniform(self.conv) 167 | if self.bn is not None: 168 | init_bn(self.bn) 169 | 170 | 171 | class Deconv2d(nn.Module): 172 | """Applies a 2D deconvolution (optionally with batch normalization and relu activation) 173 | over an input signal composed of several input planes. 
174 | 175 | Attributes: 176 | conv (nn.Module): convolution module 177 | bn (nn.Module): batch normalization module 178 | relu (bool): whether to activate by relu 179 | 180 | Notes: 181 | Default momentum for batch normalization is set to be 0.01, 182 | 183 | """ 184 | 185 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 186 | relu=True, bn=True, bn_momentum=0.1, **kwargs): 187 | super(Deconv2d, self).__init__() 188 | self.out_channels = out_channels 189 | assert stride in [1, 2] 190 | self.stride = stride 191 | 192 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, 193 | bias=(not bn), **kwargs) 194 | self.bn = nn.BatchNorm2d(out_channels, momentum=bn_momentum) if bn else None 195 | self.relu = relu 196 | 197 | self.init_weights() 198 | 199 | def forward(self, x): 200 | y = self.conv(x) 201 | if self.stride == 2: 202 | h, w = list(x.size())[2:] 203 | y = y[:, :, :2 * h, :2 * w].contiguous() 204 | if self.bn is not None: 205 | x = self.bn(y) 206 | if self.relu: 207 | x = F.relu(x, inplace=True) 208 | return x 209 | 210 | def init_weights(self): 211 | """default initialization""" 212 | init_uniform(self.conv) 213 | if self.bn is not None: 214 | init_bn(self.bn) 215 | 216 | 217 | class Deconv2d_gn(nn.Module): 218 | """Applies a 2D deconvolution (optionally with batch normalization and relu activation) 219 | over an input signal composed of several input planes. 220 | 221 | Attributes: 222 | conv (nn.Module): convolution module 223 | bn (nn.Module): batch normalization module 224 | relu (bool): whether to activate by relu 225 | 226 | Notes: 227 | Default momentum for batch normalization is set to be 0.01, 228 | 229 | """ 230 | 231 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 232 | relu=True, gn=True, group_channel=8, **kwargs): 233 | super(Deconv2d_gn, self).__init__() 234 | self.out_channels = out_channels 235 | assert stride in [1, 2] 236 | self.stride = stride 237 | 238 | self.conv = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, 239 | bias=(not gn), **kwargs) 240 | G = max(1, out_channels // group_channel) 241 | self.gn = nn.GroupNorm(G, out_channels) if gn else None 242 | self.relu = relu 243 | 244 | self.init_weights() 245 | 246 | def forward(self, x): 247 | y = self.conv(x) 248 | if self.stride == 2: 249 | h, w = list(x.size())[2:] 250 | y = y[:, :, :2 * h, :2 * w].contiguous() 251 | if self.gn is not None: 252 | x = self.gn(y) 253 | if self.relu: 254 | x = F.relu(x, inplace=True) 255 | return x 256 | 257 | def init_weights(self): 258 | """default initialization""" 259 | init_uniform(self.conv) 260 | 261 | class Deconv3d(nn.Module): 262 | """Applies a 3D deconvolution (optionally with batch normalization and relu activation) 263 | over an input signal composed of several input planes. 
264 | 265 | Attributes: 266 | conv (nn.Module): convolution module 267 | bn (nn.Module): batch normalization module 268 | relu (bool): whether to activate by relu 269 | 270 | Notes: 271 | Default momentum for batch normalization is set to be 0.01, 272 | 273 | """ 274 | 275 | def __init__(self, in_channels, out_channels, kernel_size, stride=1, 276 | relu=True, bn=True, bn_momentum=0.1, **kwargs): 277 | super(Deconv3d, self).__init__() 278 | self.out_channels = out_channels 279 | assert stride in [1, 2] 280 | self.stride = stride 281 | 282 | self.conv = nn.ConvTranspose3d(in_channels, out_channels, kernel_size, stride=stride, 283 | bias=(not bn), **kwargs) 284 | self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) if bn else None 285 | self.relu = relu 286 | 287 | self.init_weights() 288 | 289 | def forward(self, x): 290 | y = self.conv(x) 291 | if self.bn is not None: 292 | x = self.bn(y) 293 | if self.relu: 294 | x = F.relu(x, inplace=True) 295 | return x 296 | 297 | def init_weights(self): 298 | """default initialization""" 299 | init_uniform(self.conv) 300 | if self.bn is not None: 301 | init_bn(self.bn) 302 | -------------------------------------------------------------------------------- /fastmvsnet/nn/freeze_weight.py: -------------------------------------------------------------------------------- 1 | """Helpers for operating weights/params""" 2 | import re 3 | import torch.nn as nn 4 | 5 | __all__ = ['freeze_bn', 'freeze_params', 'freeze_modules', 'freeze_by_patterns'] 6 | 7 | 8 | def freeze_bn(module, bn_eval, bn_frozen): 9 | """Freeze Batch Normalization in Module 10 | 11 | Args: 12 | module (torch.nn.Module): 13 | bn_eval (bool): flag to using global stats 14 | bn_frozen (bool): flag to freeze bn params 15 | 16 | Returns: 17 | 18 | """ 19 | for module_name, m in module.named_modules(): 20 | if isinstance(m, nn.BatchNorm2d): 21 | if bn_eval: 22 | # Notice the difference between the behaviors of 23 | # BatchNorm.eval() and BatchNorm(track_running_stats=False) 24 | m.eval() 25 | # print('BN: %s in eval mode.' % module_name) 26 | if bn_frozen: 27 | for param_name, params in m.named_parameters(): 28 | params.requires_grad = False 29 | # print('BN: %s is frozen.' % (module_name + '.' + param_name)) 30 | 31 | 32 | def freeze_params(module, frozen_params): 33 | """Freeze params and/or convert them into eval mode 34 | 35 | Args: 36 | module (torch.nn.Module): 37 | frozen_params: a list/tuple of strings, 38 | which define all the patterns of interests 39 | 40 | Returns: 41 | 42 | """ 43 | for name, params in module.named_parameters(): 44 | print(name) 45 | for pattern in frozen_params: 46 | assert isinstance(pattern, str) 47 | if re.search(pattern, name): 48 | params.requires_grad = False 49 | print('Params %s is frozen.' % name) 50 | 51 | 52 | def freeze_modules(module, frozen_modules, prefix=''): 53 | """Set module's eval mode and freeze its params 54 | 55 | Args: 56 | module (torch.nn.Module): 57 | frozen_modules (list[str]): 58 | 59 | Returns: 60 | 61 | """ 62 | for name, m in module._modules.items(): 63 | for pattern in frozen_modules: 64 | assert isinstance(pattern, str) 65 | full_name = prefix + ('.' if prefix else '') + name 66 | if re.search(pattern, full_name): 67 | m.eval() 68 | _freeze_all_params(m) 69 | print('Module %s is frozen.' 
% full_name) 70 | else: 71 | freeze_modules(m, frozen_modules, prefix=full_name) 72 | 73 | 74 | def freeze_by_patterns(module, patterns): 75 | """Freeze Module by matching patterns""" 76 | frozen_params = [] 77 | frozen_modules = [] 78 | for pattern in patterns: 79 | if pattern.startswith('module:'): 80 | frozen_modules.append(pattern[7:]) 81 | else: 82 | frozen_params.append(pattern) 83 | freeze_params(module, frozen_params) 84 | freeze_modules(module, frozen_modules) 85 | 86 | 87 | def _freeze_all_params(module): 88 | """Freeze all params in a module""" 89 | for name, params in module.named_parameters(): 90 | params.requires_grad = False 91 | 92 | 93 | def unfreeze_params(module, frozen_params): 94 | """Unfreeze params and/or convert them into eval mode 95 | 96 | Args: 97 | module (torch.nn.Module): 98 | frozen_params: a list/tuple of strings, 99 | which define all the patterns of interests 100 | 101 | Returns: 102 | 103 | """ 104 | for name, params in module.named_parameters(): 105 | print(name) 106 | for pattern in frozen_params: 107 | assert isinstance(pattern, str) 108 | if re.search(pattern, name): 109 | params.requires_grad = True 110 | print('Params %s is unfrozen.' % name) 111 | 112 | 113 | def unfreeze_modules(module, frozen_modules, prefix=''): 114 | """Set module's eval mode and freeze its params 115 | 116 | Args: 117 | module (torch.nn.Module): 118 | frozen_modules (list[str]): 119 | 120 | Returns: 121 | 122 | """ 123 | for name, m in module._modules.items(): 124 | for pattern in frozen_modules: 125 | assert isinstance(pattern, str) 126 | full_name = prefix + ('.' if prefix else '') + name 127 | if re.search(pattern, full_name): 128 | m.eval() 129 | _unfreeze_all_params(m) 130 | print('Module %s is unfrozen.' % full_name) 131 | else: 132 | unfreeze_modules(m, frozen_modules, prefix=full_name) 133 | 134 | 135 | def unfreeze_by_patterns(module, patterns): 136 | """Defrost Module by matching patterns""" 137 | unfreeze_params = [] 138 | unfreeze_modules = [] 139 | for pattern in patterns: 140 | if pattern.startswith('module:'): 141 | unfreeze_modules.append(pattern[7:]) 142 | else: 143 | unfreeze_params.append(pattern) 144 | 145 | 146 | def _unfreeze_all_params(module): 147 | """Freeze all params in a module""" 148 | for name, params in module.named_parameters(): 149 | params.requires_grad = True 150 | print('Params %s is unfrozen.' % name) 151 | -------------------------------------------------------------------------------- /fastmvsnet/nn/functional.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | 4 | 5 | # ----------------------------------------------------------------------------- 6 | # Distance 7 | # ----------------------------------------------------------------------------- 8 | 9 | def pdist(feature): 10 | """Compute pairwise distances of features. 11 | 12 | Args: 13 | feature (torch.Tensor): (batch_size, channels, num_features) 14 | 15 | Returns: 16 | distance (torch.Tensor): (batch_size, num_features, num_features) 17 | 18 | Notes: 19 | This method returns square distances, and is optimized for lower memory and faster speed. 20 | Sqaure sum is more efficient than gather diagonal from inner product. 
21 | 22 | """ 23 | square_sum = torch.sum(feature ** 2, 1, keepdim=True) 24 | square_sum = square_sum + square_sum.transpose(1, 2) 25 | distance = torch.baddbmm(square_sum, feature.transpose(1, 2), feature, alpha=-2.0) 26 | return distance 27 | 28 | 29 | def pdist2(feature1, feature2): 30 | """Compute pairwise distances of two sets of features. 31 | 32 | Args: 33 | feature1 (torch.Tensor): (batch_size, channels, num_features1) 34 | feature2 (torch.Tensor): (batch_size, channels, num_features2) 35 | 36 | Returns: 37 | distance (torch.Tensor): (batch_size, num_features1, num_features2) 38 | 39 | Notes: 40 | This method returns square distances, and is optimized for lower memory and faster speed. 41 | Sqaure sum is more efficient than gather diagonal from inner product. 42 | 43 | """ 44 | square_sum1 = torch.sum(feature1 ** 2, 1, keepdim=True) 45 | square_sum2 = torch.sum(feature2 ** 2, 1, keepdim=True) 46 | square_sum = square_sum1.transpose(1, 2) + square_sum2 47 | distance = torch.baddbmm(square_sum, feature1.transpose(1, 2), feature2, alpha=-2.0) 48 | return distance 49 | 50 | 51 | # ----------------------------------------------------------------------------- 52 | # Losses 53 | # ----------------------------------------------------------------------------- 54 | 55 | def encode_one_hot(target, num_classes): 56 | """Encode integer labels into one-hot vectors 57 | 58 | Args: 59 | target (torch.Tensor): (N,) 60 | num_classes (int): the number of classes 61 | 62 | Returns: 63 | torch.FloatTensor: (N, C) 64 | 65 | """ 66 | one_hot = target.new_zeros(target.size(0), num_classes) 67 | one_hot = one_hot.scatter(1, target.unsqueeze(1), 1) 68 | return one_hot.float() 69 | 70 | 71 | def smooth_cross_entropy(input, target, label_smoothing): 72 | """Cross entropy loss with label smoothing 73 | 74 | Args: 75 | input (torch.Tensor): (N, C) 76 | target (torch.Tensor): (N,) 77 | label_smoothing (float): 78 | 79 | Returns: 80 | loss (torch.Tensor): scalar 81 | 82 | """ 83 | assert input.dim() == 2 and target.dim() == 1 84 | assert isinstance(label_smoothing, float) 85 | batch_size, num_classes = input.shape 86 | one_hot = torch.zeros_like(input).scatter(1, target.unsqueeze(1), 1) 87 | smooth_one_hot = one_hot * (1 - label_smoothing) + torch.ones_like(input) * (label_smoothing / num_classes) 88 | log_prob = F.log_softmax(input, dim=1) 89 | loss = (- smooth_one_hot * log_prob).sum(1).mean() 90 | return loss 91 | 92 | 93 | -------------------------------------------------------------------------------- /fastmvsnet/nn/init.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | 3 | 4 | def init_bn(module): 5 | if module.weight is not None: 6 | nn.init.ones_(module.weight) 7 | if module.bias is not None: 8 | nn.init.zeros_(module.bias) 9 | return 10 | 11 | 12 | def set_bn(model, momentum): 13 | for m in model.modules(): 14 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 15 | m.momentum = momentum 16 | 17 | 18 | def init_uniform(module): 19 | if module.weight is not None: 20 | # nn.init.kaiming_uniform_(module.weight) 21 | nn.init.xavier_uniform_(module.weight) 22 | if module.bias is not None: 23 | nn.init.zeros_(module.bias) 24 | return 25 | 26 | 27 | def set_eps(model, eps): 28 | for m in model.modules(): 29 | if isinstance(m, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)): 30 | m.eps = eps 31 | -------------------------------------------------------------------------------- /fastmvsnet/nn/linear.py: 
-------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | from .init import init_uniform, init_bn 5 | 6 | 7 | class FC(nn.Module): 8 | """Applies a linear transformation to the incoming data 9 | optionally followed by batch normalization and relu activation 10 | 11 | Attributes: 12 | conv (nn.Module): convolution module 13 | bn (nn.Module): batch normalization module 14 | relu (bool): whether to activate by relu 15 | 16 | """ 17 | 18 | def __init__(self, in_channels, out_channels, 19 | relu=True, bn=True, bn_momentum=0.1): 20 | super(FC, self).__init__() 21 | 22 | self.fc = nn.Linear(in_channels, out_channels, bias=(not bn)) 23 | self.bn = nn.BatchNorm1d(out_channels, momentum=bn_momentum) if bn else None 24 | self.relu = relu 25 | 26 | self.init_weights() 27 | 28 | def forward(self, x): 29 | x = self.fc(x) 30 | if self.bn is not None: 31 | x = self.bn(x) 32 | if self.relu: 33 | x = F.relu(x, inplace=True) 34 | return x 35 | 36 | def init_weights(self): 37 | """default initialization""" 38 | init_uniform(self.fc) 39 | if self.bn is not None: 40 | init_bn(self.bn) 41 | -------------------------------------------------------------------------------- /fastmvsnet/nn/mlp.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | import torch.nn.functional as F 3 | 4 | from .conv import Conv1d, Conv2d 5 | from .linear import FC 6 | 7 | 8 | class MLP(nn.ModuleList): 9 | """Multilayer perceptron 10 | 11 | Args: 12 | in_channels (int): the number of channels of input tensor 13 | mlp_channels (tuple): the numbers of channels of fully connected layers 14 | dropout (float or None): dropout ratio 15 | bn (bool): whether to use batch normalization 16 | 17 | """ 18 | 19 | def __init__(self, 20 | in_channels, 21 | mlp_channels, 22 | dropout=None, 23 | bn=True, 24 | bn_momentum=0.1): 25 | super(MLP, self).__init__() 26 | 27 | self.in_channels = in_channels 28 | self.dropout = dropout 29 | 30 | for ind, out_channels in enumerate(mlp_channels): 31 | self.append(FC(in_channels, out_channels, 32 | relu=True, bn=bn, bn_momentum=bn_momentum)) 33 | in_channels = out_channels 34 | 35 | self.out_channels = in_channels 36 | 37 | def forward(self, x): 38 | for module in self: 39 | x = module(x) 40 | if self.dropout: 41 | x = F.dropout(x, self.dropout, self.training, inplace=False) 42 | return x 43 | 44 | 45 | class SharedMLP(nn.ModuleList): 46 | def __init__(self, 47 | in_channels, 48 | mlp_channels, 49 | ndim=1, 50 | bn=True, 51 | bn_momentum=0.1): 52 | """Multilayer perceptron shared on resolution (1D or 2D) 53 | 54 | Args: 55 | in_channels (int): the number of channels of input tensor 56 | mlp_channels (tuple): the numbers of channels of fully connected layers 57 | ndim (int): the number of dimensions to share 58 | bn (bool): whether to use batch normalization 59 | """ 60 | super(SharedMLP, self).__init__() 61 | 62 | self.in_channels = in_channels 63 | 64 | if ndim == 1: 65 | mlp_module = Conv1d 66 | elif ndim == 2: 67 | mlp_module = Conv2d 68 | else: 69 | raise ValueError() 70 | 71 | for ind, out_channels in enumerate(mlp_channels): 72 | self.append(mlp_module(in_channels, out_channels, 1, 73 | relu=True, bn=bn, bn_momentum=bn_momentum)) 74 | in_channels = out_channels 75 | 76 | self.out_channels = in_channels 77 | 78 | def forward(self, x): 79 | for module in self: 80 | x = module(x) 81 | return x 82 | 
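The building blocks in `fastmvsnet/nn` (`FC`, `MLP`, `SharedMLP`, and the `Conv*`/`Deconv*` wrappers in `conv.py`) all follow the same linear-or-conv → optional normalization → optional ReLU pattern and initialize themselves through the helpers in `init.py`. Below is a minimal usage sketch, not part of the repository: it assumes only that the repository root is on `PYTHONPATH` and PyTorch is installed, and the tensor shapes are illustrative rather than taken from the DTU config.

```python
import torch

from fastmvsnet.nn.linear import FC
from fastmvsnet.nn.mlp import MLP, SharedMLP

# MLP stacks FC layers, so it expects per-sample feature vectors: (batch, channels).
mlp = MLP(in_channels=32, mlp_channels=(64, 64, 16), dropout=0.5)
x = torch.rand(4, 32)
print(mlp(x).shape)            # torch.Size([4, 16])

# SharedMLP uses 1x1 convolutions, so the same weights are shared over the
# point dimension: input is (batch, channels, num_points) for ndim=1.
shared_mlp = SharedMLP(in_channels=32, mlp_channels=(64, 16), ndim=1)
pts = torch.rand(4, 32, 128)
print(shared_mlp(pts).shape)   # torch.Size([4, 16, 128])

# A single FC layer; with bn=False the underlying nn.Linear keeps its bias.
fc = FC(32, 8, relu=False, bn=False)
print(fc(x).shape)             # torch.Size([4, 8])
```

The batch size is kept above 1 only because the `FC` layers run `BatchNorm1d` in training mode, which rejects single-sample batches.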
-------------------------------------------------------------------------------- /fastmvsnet/solver.py: -------------------------------------------------------------------------------- 1 | """ 2 | Build optimizers and schedulers 3 | 4 | Notes: 5 | Default optimizer will optimize all parameters. 6 | Custom optimizer should be implemented and registered in '_OPTIMIZER_BUILDERS' 7 | 8 | """ 9 | import torch 10 | import torch.nn as nn 11 | from torch.nn.modules.batchnorm import _BatchNorm 12 | from torch.nn.modules.conv import _ConvNd 13 | 14 | _OPTIMIZER_BUILDERS = {} 15 | 16 | 17 | def build_optimizer(cfg, model): 18 | name = cfg.SOLVER.TYPE 19 | if hasattr(torch.optim, name): 20 | def builder(cfg, model): 21 | return getattr(torch.optim, name)( 22 | group_weight(model, cfg.SOLVER.WEIGHT_DECAY), 23 | lr=cfg.SOLVER.BASE_LR, 24 | **cfg.SOLVER[name], 25 | ) 26 | elif name in _OPTIMIZER_BUILDERS: 27 | builder = _OPTIMIZER_BUILDERS[name] 28 | else: 29 | raise ValueError("Unsupported type of optimizer.") 30 | 31 | return builder(cfg, model) 32 | 33 | 34 | def group_weight(module, weight_decay): 35 | group_decay = [] 36 | group_no_decay = [] 37 | keywords = [".bn."] 38 | 39 | for m in list(module.named_parameters()): 40 | exclude = False 41 | for k in keywords: 42 | if k in m[0]: 43 | print("Weight decay exclude: "+m[0]) 44 | group_no_decay.append(m[1]) 45 | exclude = True 46 | break 47 | if not exclude: 48 | print("Weight decay include: " + m[0]) 49 | group_decay.append(m[1]) 50 | 51 | assert len(list(module.parameters())) == len(group_decay) + len(group_no_decay) 52 | groups = [dict(params=group_decay, weight_decay=weight_decay), dict(params=group_no_decay, weight_decay=.0)] 53 | return groups 54 | 55 | 56 | def register_optimizer_builder(name, builder): 57 | if name in _OPTIMIZER_BUILDERS: 58 | raise KeyError( 59 | "Duplicate keys for {:s} with {} and {}." 
60 | "Solve key conflicts first!".format(name, _OPTIMIZER_BUILDERS[name], builder)) 61 | _OPTIMIZER_BUILDERS[name] = builder 62 | 63 | 64 | def build_scheduler(cfg, optimizer): 65 | name = cfg.SCHEDULER.TYPE 66 | if hasattr(torch.optim.lr_scheduler, name): 67 | def builder(cfg, optimizer): 68 | return getattr(torch.optim.lr_scheduler, name)( 69 | optimizer, 70 | **cfg.SCHEDULER[name], 71 | ) 72 | elif name in _OPTIMIZER_BUILDERS: 73 | builder = _OPTIMIZER_BUILDERS[name] 74 | else: 75 | raise ValueError("Unsupported type of optimizer.") 76 | 77 | return builder(cfg, optimizer) 78 | -------------------------------------------------------------------------------- /fastmvsnet/test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import os.path as osp 4 | import logging 5 | import time 6 | import sys 7 | sys.path.insert(0, osp.dirname(__file__) + '/..') 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from fastmvsnet.config import load_cfg_from_file 13 | from fastmvsnet.utils.io import mkdir 14 | from fastmvsnet.utils.logger import setup_logger 15 | from fastmvsnet.model import build_pointmvsnet as build_model 16 | from fastmvsnet.utils.checkpoint import Checkpointer 17 | from fastmvsnet.dataset import build_data_loader 18 | from fastmvsnet.utils.metric_logger import MetricLogger 19 | from fastmvsnet.utils.eval_file_logger import eval_file_logger 20 | 21 | 22 | def parse_args(): 23 | parser = argparse.ArgumentParser(description="PyTorch Fast-MVSNet Evaluation") 24 | parser.add_argument( 25 | "--cfg", 26 | dest="config_file", 27 | default="", 28 | metavar="FILE", 29 | help="path to config file", 30 | type=str, 31 | ) 32 | parser.add_argument( 33 | "--cpu", 34 | action='store_true', 35 | default=False, 36 | help="whether to only use cpu for test", 37 | ) 38 | parser.add_argument( 39 | "opts", 40 | help="Modify config options using the command-line", 41 | default=None, 42 | nargs=argparse.REMAINDER, 43 | ) 44 | 45 | args = parser.parse_args() 46 | return args 47 | 48 | 49 | def test_model(model, 50 | image_scales, 51 | inter_scales, 52 | data_loader, 53 | folder, 54 | isCPU=False, 55 | ): 56 | logger = logging.getLogger("fastmvsnet.train") 57 | meters = MetricLogger(delimiter=" ") 58 | model.train() 59 | end = time.time() 60 | total_iteration = data_loader.__len__() 61 | path_list = [] 62 | with torch.no_grad(): 63 | for iteration, data_batch in enumerate(data_loader): 64 | data_time = time.time() - end 65 | curr_ref_img_path = data_batch["ref_img_path"][0] 66 | path_list.extend(curr_ref_img_path) 67 | if not isCPU: 68 | data_batch = {k: v.cuda(non_blocking=True) for k, v in data_batch.items() if isinstance(v, torch.Tensor)} 69 | preds = model(data_batch, image_scales, inter_scales, isGN=True, isTest=True) 70 | 71 | batch_time = time.time() - end 72 | end = time.time() 73 | meters.update(time=batch_time, data=data_time) 74 | logger.info( 75 | "{} finished.".format(curr_ref_img_path) + str(meters)) 76 | eval_file_logger(data_batch, preds, curr_ref_img_path, folder) 77 | 78 | 79 | def test(cfg, output_dir, isCPU=False): 80 | logger = logging.getLogger("fastmvsnet.tester") 81 | # build model 82 | model, _, _ = build_model(cfg) 83 | if not isCPU: 84 | model = nn.DataParallel(model).cuda() 85 | 86 | # build checkpointer 87 | checkpointer = Checkpointer(model, save_dir=output_dir) 88 | 89 | if cfg.TEST.WEIGHT: 90 | weight_path = cfg.TEST.WEIGHT.replace("@", output_dir) 91 | checkpointer.load(weight_path, 
resume=False) 92 | else: 93 | checkpointer.load(None, resume=True) 94 | 95 | # build data loader 96 | test_data_loader = build_data_loader(cfg, mode="test") 97 | start_time = time.time() 98 | test_model(model, 99 | image_scales=cfg.MODEL.TEST.IMG_SCALES, 100 | inter_scales=cfg.MODEL.TEST.INTER_SCALES, 101 | data_loader=test_data_loader, 102 | folder=output_dir.split("/")[-1], 103 | isCPU=isCPU, 104 | ) 105 | test_time = time.time() - start_time 106 | logger.info("Test forward time: {:.2f}s".format(test_time)) 107 | 108 | 109 | def main(): 110 | args = parse_args() 111 | num_gpus = torch.cuda.device_count() 112 | 113 | cfg = load_cfg_from_file(args.config_file) 114 | cfg.merge_from_list(args.opts) 115 | cfg.freeze() 116 | assert cfg.TEST.BATCH_SIZE == 1 117 | 118 | isCPU = args.cpu 119 | 120 | output_dir = cfg.OUTPUT_DIR 121 | if output_dir: 122 | config_path = osp.splitext(args.config_file)[0] 123 | config_path = config_path.replace("configs", "outputs") 124 | output_dir = output_dir.replace('@', config_path) 125 | mkdir(output_dir) 126 | 127 | logger = setup_logger("fastmvsnet", output_dir, prefix="test") 128 | if isCPU: 129 | logger.info("Using CPU") 130 | else: 131 | logger.info("Using {} GPUs".format(num_gpus)) 132 | logger.info(args) 133 | 134 | logger.info("Loaded configuration file {}".format(args.config_file)) 135 | logger.info("Running with config:\n{}".format(cfg)) 136 | 137 | test(cfg, output_dir, isCPU=isCPU) 138 | 139 | 140 | if __name__ == "__main__": 141 | main() 142 | -------------------------------------------------------------------------------- /fastmvsnet/train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | import argparse 3 | import os.path as osp 4 | import logging 5 | import time 6 | import sys 7 | sys.path.insert(0, osp.dirname(__file__) + '/..') 8 | 9 | import torch 10 | import torch.nn as nn 11 | 12 | from fastmvsnet.config import load_cfg_from_file 13 | from fastmvsnet.utils.io import mkdir 14 | from fastmvsnet.utils.logger import setup_logger 15 | from fastmvsnet.utils.torch_utils import set_random_seed 16 | from fastmvsnet.model import build_pointmvsnet as build_model 17 | from fastmvsnet.solver import build_optimizer, build_scheduler 18 | from fastmvsnet.utils.checkpoint import Checkpointer 19 | from fastmvsnet.dataset import build_data_loader 20 | from fastmvsnet.utils.tensorboard_logger import TensorboardLogger 21 | from fastmvsnet.utils.metric_logger import MetricLogger 22 | from fastmvsnet.utils.file_logger import file_logger 23 | 24 | 25 | def parse_args(): 26 | parser = argparse.ArgumentParser(description="PyTorch Fast-MVSNet Training") 27 | parser.add_argument( 28 | "--cfg", 29 | dest="config_file", 30 | default="", 31 | metavar="FILE", 32 | help="path to config file", 33 | type=str, 34 | ) 35 | parser.add_argument( 36 | "opts", 37 | help="Modify config options using the command-line", 38 | default=None, 39 | nargs=argparse.REMAINDER, 40 | ) 41 | 42 | args = parser.parse_args() 43 | return args 44 | 45 | 46 | def train_model(model, 47 | loss_fn, 48 | metric_fn, 49 | image_scales, 50 | inter_scales, 51 | isFlow, 52 | data_loader, 53 | optimizer, 54 | curr_epoch, 55 | tensorboard_logger, 56 | log_period=1, 57 | output_dir="", 58 | ): 59 | logger = logging.getLogger("fastmvsnet.train") 60 | meters = MetricLogger(delimiter=" ") 61 | model.train() 62 | end = time.time() 63 | total_iteration = data_loader.__len__() 64 | path_list = [] 65 | 66 | for iteration, data_batch in 
enumerate(data_loader): 67 | data_time = time.time() - end 68 | curr_ref_img_path = data_batch["ref_img_path"] 69 | path_list.extend(curr_ref_img_path) 70 | data_batch = {k: v.cuda(non_blocking=True) for k, v in data_batch.items() if isinstance(v, torch.Tensor)} 71 | 72 | preds = model(data_batch, image_scales, inter_scales, isFlow) 73 | optimizer.zero_grad() 74 | 75 | loss_dict = loss_fn(preds, data_batch, isFlow) 76 | metric_dict = metric_fn(preds, data_batch, isFlow) 77 | losses = sum(loss_dict.values()) 78 | meters.update(loss=losses, **loss_dict, **metric_dict) 79 | 80 | losses.backward() 81 | 82 | optimizer.step() 83 | 84 | batch_time = time.time() - end 85 | end = time.time() 86 | meters.update(time=batch_time, data=data_time) 87 | 88 | if iteration % log_period == 0: 89 | logger.info( 90 | meters.delimiter.join( 91 | [ 92 | "EPOCH: {epoch:2d}", 93 | "iter: {iter:4d}", 94 | "{meters}", 95 | "lr: {lr:.2e}", 96 | "max mem: {memory:.0f}", 97 | ] 98 | ).format( 99 | epoch=curr_epoch, 100 | iter=iteration, 101 | meters=str(meters), 102 | lr=optimizer.param_groups[0]["lr"], 103 | memory=torch.cuda.max_memory_allocated() / (1024.0 ** 2), 104 | ) 105 | ) 106 | tensorboard_logger.add_scalars(loss_dict, curr_epoch * total_iteration + iteration, prefix="train") 107 | tensorboard_logger.add_scalars(metric_dict, curr_epoch * total_iteration + iteration, prefix="train") 108 | 109 | if iteration % (100 * log_period) == 0: 110 | file_logger(data_batch, preds, curr_epoch * total_iteration + iteration, output_dir, prefix="train") 111 | 112 | return meters 113 | 114 | 115 | def validate_model(model, 116 | loss_fn, 117 | metric_fn, 118 | image_scales, 119 | inter_scales, 120 | isFlow, 121 | data_loader, 122 | curr_epoch, 123 | tensorboard_logger, 124 | log_period=1, 125 | output_dir="", 126 | ): 127 | logger = logging.getLogger("fastmvsnet.validate") 128 | meters = MetricLogger(delimiter=" ") 129 | model.train() 130 | end = time.time() 131 | total_iteration = data_loader.__len__() 132 | with torch.no_grad(): 133 | for iteration, data_batch in enumerate(data_loader): 134 | data_time = time.time() - end 135 | curr_ref_img_path = data_batch["ref_img_path"] 136 | 137 | data_batch = {k: v.cuda(non_blocking=True) for k, v in data_batch.items() if isinstance(v, torch.Tensor)} 138 | 139 | preds = model(data_batch, image_scales, inter_scales, isFlow) 140 | loss_dict = loss_fn(preds, data_batch, isFlow) 141 | metric_dict = metric_fn(preds, data_batch, isFlow) 142 | losses = sum(loss_dict.values()) 143 | meters.update(loss=losses, **loss_dict, **metric_dict) 144 | batch_time = time.time() - end 145 | end = time.time() 146 | meters.update(time=batch_time, data=data_time) 147 | 148 | if iteration % log_period == 0: 149 | logger.info( 150 | meters.delimiter.join( 151 | [ 152 | "EPOCH: {epoch:2d}", 153 | "iter: {iter:4d}", 154 | "{meters}", 155 | ] 156 | ).format( 157 | epoch=curr_epoch, 158 | iter=iteration, 159 | meters=str(meters), 160 | ) 161 | ) 162 | tensorboard_logger.add_scalars(meters.meters, curr_epoch * total_iteration + iteration, prefix="valid") 163 | 164 | if iteration % (100 * log_period) == 0: 165 | file_logger(data_batch, preds, curr_epoch * total_iteration + iteration, output_dir, prefix="valid") 166 | 167 | return meters 168 | 169 | 170 | def train(cfg, output_dir=""): 171 | logger = logging.getLogger("fastmvsnet.trainer") 172 | 173 | # build model 174 | set_random_seed(cfg.RNG_SEED) 175 | model, loss_fn, metric_fn = build_model(cfg) 176 | logger.info("Build model:\n{}".format(str(model))) 177 | 
model = nn.DataParallel(model).cuda() 178 | 179 | # build optimizer 180 | optimizer = build_optimizer(cfg, model) 181 | 182 | # build lr scheduler 183 | scheduler = build_scheduler(cfg, optimizer) 184 | 185 | # build checkpointer 186 | checkpointer = Checkpointer(model, 187 | optimizer=optimizer, 188 | scheduler=scheduler, 189 | save_dir=output_dir, 190 | logger=logger) 191 | 192 | checkpoint_data = checkpointer.load(cfg.MODEL.WEIGHT, resume=cfg.AUTO_RESUME) 193 | ckpt_period = cfg.TRAIN.CHECKPOINT_PERIOD 194 | 195 | # build data loader 196 | train_data_loader = build_data_loader(cfg, mode="train") 197 | val_period = cfg.TRAIN.VAL_PERIOD 198 | val_data_loader = build_data_loader(cfg, mode="val") if val_period > 0 else None 199 | 200 | # build tensorboard logger (optionally by comment) 201 | tensorboard_logger = TensorboardLogger(output_dir) 202 | 203 | # train 204 | max_epoch = cfg.SCHEDULER.MAX_EPOCH 205 | start_epoch = checkpoint_data.get("epoch", 0) 206 | best_metric_name = "best_{}".format(cfg.TRAIN.VAL_METRIC) 207 | best_metric = checkpoint_data.get(best_metric_name, None) 208 | logger.info("Start training from epoch {}".format(start_epoch)) 209 | for epoch in range(start_epoch, max_epoch): 210 | cur_epoch = epoch + 1 211 | scheduler.step() 212 | start_time = time.time() 213 | train_meters = train_model(model, 214 | loss_fn, 215 | metric_fn, 216 | image_scales=cfg.MODEL.TRAIN.IMG_SCALES, 217 | inter_scales=cfg.MODEL.TRAIN.INTER_SCALES, 218 | isFlow=(cur_epoch > cfg.SCHEDULER.INIT_EPOCH), 219 | data_loader=train_data_loader, 220 | optimizer=optimizer, 221 | curr_epoch=epoch, 222 | tensorboard_logger=tensorboard_logger, 223 | log_period=cfg.TRAIN.LOG_PERIOD, 224 | output_dir=output_dir, 225 | ) 226 | epoch_time = time.time() - start_time 227 | logger.info("Epoch[{}]-Train {} total_time: {:.2f}s".format( 228 | cur_epoch, train_meters.summary_str, epoch_time)) 229 | 230 | # checkpoint 231 | if cur_epoch % ckpt_period == 0 or cur_epoch == max_epoch: 232 | checkpoint_data["epoch"] = cur_epoch 233 | checkpoint_data[best_metric_name] = best_metric 234 | checkpointer.save("model_{:03d}".format(cur_epoch), **checkpoint_data) 235 | 236 | # validate 237 | if val_period < 1: 238 | continue 239 | if cur_epoch % val_period == 0 or cur_epoch == max_epoch: 240 | val_meters = validate_model(model, 241 | loss_fn, 242 | metric_fn, 243 | image_scales=cfg.MODEL.VAL.IMG_SCALES, 244 | inter_scales=cfg.MODEL.VAL.INTER_SCALES, 245 | isFlow=(cur_epoch > cfg.SCHEDULER.INIT_EPOCH), 246 | data_loader=val_data_loader, 247 | curr_epoch=epoch, 248 | tensorboard_logger=tensorboard_logger, 249 | log_period=cfg.TEST.LOG_PERIOD, 250 | output_dir=output_dir, 251 | ) 252 | logger.info("Epoch[{}]-Val {}".format(cur_epoch, val_meters.summary_str)) 253 | 254 | # best validation 255 | cur_metric = val_meters.meters[cfg.TRAIN.VAL_METRIC].global_avg 256 | if best_metric is None or cur_metric > best_metric: 257 | best_metric = cur_metric 258 | checkpoint_data["epoch"] = cur_epoch 259 | checkpoint_data[best_metric_name] = best_metric 260 | checkpointer.save("model_best", **checkpoint_data) 261 | 262 | logger.info("Best val-{} = {}".format(cfg.TRAIN.VAL_METRIC, best_metric)) 263 | 264 | return model 265 | 266 | 267 | def main(): 268 | args = parse_args() 269 | num_gpus = torch.cuda.device_count() 270 | 271 | cfg = load_cfg_from_file(args.config_file) 272 | cfg.merge_from_list(args.opts) 273 | cfg.freeze() 274 | 275 | output_dir = cfg.OUTPUT_DIR 276 | if output_dir: 277 | config_path = osp.splitext(args.config_file)[0] 278 | 
config_path = config_path.replace("configs", "outputs") 279 | output_dir = output_dir.replace('@', config_path) 280 | mkdir(output_dir) 281 | 282 | logger = setup_logger("fastmvsnet", output_dir, prefix="train") 283 | logger.info("Using {} GPUs".format(num_gpus)) 284 | logger.info(args) 285 | 286 | logger.info("Loaded configuration file {}".format(args.config_file)) 287 | logger.info("Running with config:\n{}".format(cfg)) 288 | 289 | train(cfg, output_dir) 290 | 291 | 292 | if __name__ == "__main__": 293 | main() 294 | -------------------------------------------------------------------------------- /fastmvsnet/utils/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/FastMVSNet/ccb686dda2717613c67d8a289dfe7b2aeb60e2fd/fastmvsnet/utils/__init__.py -------------------------------------------------------------------------------- /fastmvsnet/utils/checkpoint.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import logging 3 | import os 4 | 5 | import torch 6 | 7 | 8 | class Checkpointer(object): 9 | def __init__( 10 | self, 11 | model, 12 | optimizer=None, 13 | scheduler=None, 14 | save_dir="", 15 | logger=None, 16 | ): 17 | self.model = model 18 | self.optimizer = optimizer 19 | self.scheduler = scheduler 20 | self.save_dir = save_dir 21 | if logger is None: 22 | logger = logging.getLogger(__name__) 23 | self.logger = logger 24 | 25 | def save(self, name, **kwargs): 26 | if not self.save_dir: 27 | return 28 | 29 | data = {} 30 | data["model"] = self.model.state_dict() 31 | if self.optimizer is not None: 32 | data["optimizer"] = self.optimizer.state_dict() 33 | if self.scheduler is not None: 34 | data["scheduler"] = self.scheduler.state_dict() 35 | data.update(kwargs) 36 | 37 | save_file = os.path.join(self.save_dir, "{}.pth".format(name)) 38 | self.logger.info("Saving checkpoint to {}".format(save_file)) 39 | torch.save(data, save_file) 40 | self.tag_last_checkpoint(save_file) 41 | 42 | def load(self, f=None, resume=True): 43 | if resume and self.has_checkpoint(): 44 | # override argument with existing checkpoint 45 | f = self.get_checkpoint_file() 46 | if not f: 47 | # no checkpoint could be found 48 | self.logger.info("No checkpoint found. 
Initializing model from scratch") 49 | return {} 50 | self.logger.info("Loading checkpoint from {}".format(f)) 51 | checkpoint = self._load_file(f) 52 | self.model.load_state_dict(checkpoint.pop("model"), False) 53 | # if "optimizer" in checkpoint and self.optimizer: 54 | # self.logger.info("Loading optimizer from {}".format(f)) 55 | # self.optimizer.load_state_dict(checkpoint.pop("optimizer")) 56 | if "scheduler" in checkpoint and self.scheduler: 57 | self.logger.info("Loading scheduler from {}".format(f)) 58 | self.scheduler.load_state_dict(checkpoint.pop("scheduler")) 59 | 60 | # return any further checkpoint data 61 | return checkpoint 62 | 63 | def has_checkpoint(self): 64 | save_file = os.path.join(self.save_dir, "last_checkpoint") 65 | return os.path.exists(save_file) 66 | 67 | def get_checkpoint_file(self): 68 | save_file = os.path.join(self.save_dir, "last_checkpoint") 69 | try: 70 | with open(save_file, "r") as f: 71 | last_saved = f.read().strip() 72 | except IOError: 73 | # if file doesn't exist, maybe because it has just been 74 | # deleted by a separate process 75 | last_saved = "" 76 | return last_saved 77 | 78 | def tag_last_checkpoint(self, last_filename): 79 | save_file = os.path.join(self.save_dir, "last_checkpoint") 80 | with open(save_file, "w") as f: 81 | f.write(last_filename) 82 | 83 | def _load_file(self, f): 84 | return torch.load(f, map_location=torch.device("cpu")) 85 | -------------------------------------------------------------------------------- /fastmvsnet/utils/eval_file_logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | import cv2 4 | import scipy 5 | 6 | import torch 7 | import torch.nn.functional as F 8 | 9 | from fastmvsnet.utils.io import mkdir, write_cam_dtu, write_pfm 10 | 11 | 12 | def eval_file_logger(data_batch, preds, ref_img_path, folder, scene_name_index=-2, out_index_minus=1, save_prob_volume=False): 13 | l = ref_img_path.split("/") 14 | eval_folder = "/".join(l[:-3]) 15 | 16 | scene = l[scene_name_index] 17 | 18 | scene_folder = osp.join(eval_folder, folder, scene) 19 | 20 | if not osp.isdir(scene_folder): 21 | mkdir(scene_folder) 22 | print("**** {} ****".format(scene)) 23 | 24 | out_index = int(l[-1][5:8]) - out_index_minus 25 | 26 | cam_params_list = data_batch["cam_params_list"].cpu().numpy() 27 | 28 | ref_cam_paras = cam_params_list[0, 0, :, :, :] 29 | 30 | init_depth_map_path = scene_folder + ('/%08d_init.pfm' % out_index) 31 | init_prob_map_path = scene_folder + ('/%08d_init_prob.pfm' % out_index) 32 | out_ref_image_path = scene_folder + ('/%08d.jpg' % out_index) 33 | 34 | init_depth_map = preds["coarse_depth_map"].cpu().numpy()[0, 0] 35 | init_prob_map = preds["coarse_prob_map"].cpu().numpy()[0, 0] 36 | ref_image = data_batch["ref_img"][0].cpu().numpy() 37 | 38 | write_pfm(init_depth_map_path, init_depth_map) 39 | write_pfm(init_prob_map_path, init_prob_map) 40 | cv2.imwrite(out_ref_image_path, ref_image) 41 | 42 | out_init_cam_path = scene_folder + ('/cam_%08d_init.txt' % out_index) 43 | init_cam_paras = ref_cam_paras.copy() 44 | init_cam_paras[1, :2, :3] *= (float(init_depth_map.shape[0]) / ref_image.shape[0]) 45 | write_cam_dtu(out_init_cam_path, init_cam_paras) 46 | 47 | interval_list = np.array([-2.0, -1.0, 0.0, 1.0, 2.0]) 48 | interval_list = np.reshape(interval_list, [1, 1, -1]) 49 | 50 | for i, k in enumerate(preds.keys()): 51 | if "flow" in k: 52 | if "prob" in k: 53 | out_flow_prob_map = preds[k][0].cpu().permute(1, 2, 0).numpy() 
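# The block below collapses the per-interval flow probabilities into a single
# confidence map: it computes the expected offset over interval_list and sums
# the probabilities of the two intervals that bracket that expected value.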
54 | num_interval = out_flow_prob_map.shape[-1] 55 | assert num_interval == interval_list.size 56 | pred_interval = np.sum(out_flow_prob_map * interval_list, axis=-1) + 2.0 57 | pred_floor = np.floor(pred_interval).astype(np.int)[..., np.newaxis] 58 | pred_ceil = pred_floor + 1 59 | pred_ceil = np.clip(pred_ceil, 0, num_interval - 1) 60 | pred_floor = np.clip(pred_floor, 0, num_interval - 1) 61 | prob_height, prob_width = pred_floor.shape[:2] 62 | prob_height_ind = np.tile(np.reshape(np.arange(prob_height), [-1, 1, 1]), [1, prob_width, 1]) 63 | prob_width_ind = np.tile(np.reshape(np.arange(prob_width), [1, -1, 1]), [prob_height, 1, 1]) 64 | 65 | floor_prob = np.squeeze(out_flow_prob_map[prob_height_ind, prob_width_ind, pred_floor], -1) 66 | ceil_prob = np.squeeze(out_flow_prob_map[prob_height_ind, prob_width_ind, pred_ceil], -1) 67 | flow_prob = floor_prob + ceil_prob 68 | flow_prob_map_path = scene_folder + "/{:08d}_{}.pfm".format(out_index, k) 69 | write_pfm(flow_prob_map_path, flow_prob) 70 | 71 | else: 72 | out_flow_depth_map = preds[k][0, 0].cpu().numpy() 73 | flow_depth_map_path = scene_folder + "/{:08d}_{}.pfm".format(out_index, k) 74 | write_pfm(flow_depth_map_path, out_flow_depth_map) 75 | out_flow_cam_path = scene_folder + "/cam_{:08d}_{}.txt".format(out_index, k) 76 | flow_cam_paras = ref_cam_paras.copy() 77 | flow_cam_paras[1, :2, :3] *= (float(out_flow_depth_map.shape[0]) / float(ref_image.shape[0])) 78 | write_cam_dtu(out_flow_cam_path, flow_cam_paras) 79 | 80 | world_pts = depth2pts_np(out_flow_depth_map, flow_cam_paras[1][:3, :3], flow_cam_paras[0]) 81 | save_points(osp.join(scene_folder, "{:08d}_{}pts.xyz".format(out_index, k)), world_pts) 82 | # save cost volume 83 | if save_prob_volume: 84 | probability_volume = preds["coarse_prob_volume"].cpu().numpy()[0] 85 | init_prob_volume_path = scene_folder + ('/%08d_init_prob_volume.npz' % out_index) 86 | np.savez(init_prob_volume_path, probability_volume) 87 | 88 | 89 | def depth2pts_np(depth_map, cam_intrinsic, cam_extrinsic): 90 | feature_grid = get_pixel_grids_np(depth_map.shape[0], depth_map.shape[1]) 91 | 92 | uv = np.matmul(np.linalg.inv(cam_intrinsic), feature_grid) 93 | cam_points = uv * np.reshape(depth_map, (1, -1)) 94 | 95 | R = cam_extrinsic[:3, :3] 96 | t = cam_extrinsic[:3, 3:4] 97 | R_inv = np.linalg.inv(R) 98 | 99 | world_points = np.matmul(R_inv, cam_points - t).transpose() 100 | return world_points 101 | 102 | 103 | def get_pixel_grids_np(height, width): 104 | x_linspace = np.linspace(0.5, width - 0.5, width) 105 | y_linspace = np.linspace(0.5, height - 0.5, height) 106 | x_coordinates, y_coordinates = np.meshgrid(x_linspace, y_linspace) 107 | x_coordinates = np.reshape(x_coordinates, (1, -1)) 108 | y_coordinates = np.reshape(y_coordinates, (1, -1)) 109 | ones = np.ones_like(x_coordinates).astype(np.float) 110 | grid = np.concatenate([x_coordinates, y_coordinates, ones], axis=0) 111 | 112 | return grid 113 | 114 | 115 | def save_points(path, points): 116 | np.savetxt(path, points, delimiter=' ', fmt='%.4f') 117 | -------------------------------------------------------------------------------- /fastmvsnet/utils/feature_fetcher.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import random 5 | import numpy as np 6 | 7 | 8 | class FeatureFetcher(nn.Module): 9 | def __init__(self, mode="bilinear"): 10 | super(FeatureFetcher, self).__init__() 11 | self.mode = mode 12 | 13 | def forward(self, 
feature_maps, pts, cam_intrinsics, cam_extrinsics): 14 | """ 15 | 16 | :param feature_maps: torch.tensor, [B, V, C, H, W] 17 | :param pts: torch.tensor, [B, 3, N] 18 | :param cam_intrinsics: torch.tensor, [B, V, 3, 3] 19 | :param cam_extrinsics: torch.tensor, [B, V, 3, 4], [R|t], p_cam = R*p_world + t 20 | :return: 21 | pts_feature: torch.tensor, [B, V, C, N] 22 | """ 23 | batch_size, num_view, channels, height, width = list(feature_maps.size()) 24 | feature_maps = feature_maps.view(batch_size * num_view, channels, height, width) 25 | 26 | curr_batch_size = batch_size * num_view 27 | cam_intrinsics = cam_intrinsics.view(curr_batch_size, 3, 3) 28 | 29 | with torch.no_grad(): 30 | num_pts = pts.size(2) 31 | pts_expand = pts.unsqueeze(1).contiguous().expand(batch_size, num_view, 3, num_pts) \ 32 | .contiguous().view(curr_batch_size, 3, num_pts) 33 | if cam_extrinsics is None: 34 | transformed_pts = pts_expand.type(torch.float).transpose(1, 2) 35 | else: 36 | cam_extrinsics = cam_extrinsics.view(curr_batch_size, 3, 4) 37 | R = torch.narrow(cam_extrinsics, 2, 0, 3) 38 | t = torch.narrow(cam_extrinsics, 2, 3, 1).expand(curr_batch_size, 3, num_pts) 39 | transformed_pts = torch.bmm(R, pts_expand) + t 40 | transformed_pts = transformed_pts.type(torch.float).transpose(1, 2) 41 | x = transformed_pts[..., 0] 42 | y = transformed_pts[..., 1] 43 | z = transformed_pts[..., 2] 44 | 45 | normal_uv = torch.cat( 46 | [torch.div(x, z).unsqueeze(-1), torch.div(y, z).unsqueeze(-1), torch.ones_like(x).unsqueeze(-1)], 47 | dim=-1) 48 | uv = torch.bmm(normal_uv, cam_intrinsics.transpose(1, 2)) 49 | uv = uv[:, :, :2] 50 | 51 | grid = (uv - 0.5).view(curr_batch_size, num_pts, 1, 2) 52 | grid[..., 0] = (grid[..., 0] / float(width - 1)) * 2 - 1.0 53 | grid[..., 1] = (grid[..., 1] / float(height - 1)) * 2 - 1.0 54 | 55 | # pts_feature = F.grid_sample(feature_maps, grid, mode=self.mode, padding_mode='border') 56 | # print("without border pad-----------------------") 57 | pts_feature = F.grid_sample(feature_maps, grid, mode=self.mode) 58 | pts_feature = pts_feature.squeeze(3) 59 | 60 | pts_feature = pts_feature.view(batch_size, num_view, channels, num_pts) 61 | 62 | return pts_feature 63 | 64 | 65 | class FeatureGradFetcher(nn.Module): 66 | def __init__(self, mode="bilinear"): 67 | super(FeatureGradFetcher, self).__init__() 68 | self.mode = mode 69 | 70 | def forward(self, feature_maps, pts, cam_intrinsics, cam_extrinsics): 71 | """ 72 | 73 | :param feature_maps: torch.tensor, [B, V, C, H, W] 74 | :param pts: torch.tensor, [B, 3, N] 75 | :param cam_intrinsics: torch.tensor, [B, V, 3, 3] 76 | :param cam_extrinsics: torch.tensor, [B, V, 3, 4], [R|t], p_cam = R*p_world + t 77 | :return: 78 | pts_feature: torch.tensor, [B, V, C, N] 79 | """ 80 | batch_size, num_view, channels, height, width = list(feature_maps.size()) 81 | feature_maps = feature_maps.view(batch_size * num_view, channels, height, width) 82 | 83 | curr_batch_size = batch_size * num_view 84 | cam_intrinsics = cam_intrinsics.view(curr_batch_size, 3, 3) 85 | 86 | with torch.no_grad(): 87 | num_pts = pts.size(2) 88 | pts_expand = pts.unsqueeze(1).contiguous().expand(batch_size, num_view, 3, num_pts) \ 89 | .contiguous().view(curr_batch_size, 3, num_pts) 90 | if cam_extrinsics is None: 91 | transformed_pts = pts_expand.type(torch.float).transpose(1, 2) 92 | else: 93 | cam_extrinsics = cam_extrinsics.view(curr_batch_size, 3, 4) 94 | R = torch.narrow(cam_extrinsics, 2, 0, 3) 95 | t = torch.narrow(cam_extrinsics, 2, 3, 1).expand(curr_batch_size, 3, num_pts) 96 | 
transformed_pts = torch.bmm(R, pts_expand) + t 97 | transformed_pts = transformed_pts.type(torch.float).transpose(1, 2) 98 | x = transformed_pts[..., 0] 99 | y = transformed_pts[..., 1] 100 | z = transformed_pts[..., 2] 101 | 102 | normal_uv = torch.cat( 103 | [torch.div(x, z).unsqueeze(-1), torch.div(y, z).unsqueeze(-1), torch.ones_like(x).unsqueeze(-1)], 104 | dim=-1) 105 | uv = torch.bmm(normal_uv, cam_intrinsics.transpose(1, 2)) 106 | uv = uv[:, :, :2] 107 | 108 | grid = (uv - 0.5).view(curr_batch_size, num_pts, 1, 2) 109 | grid[..., 0] = (grid[..., 0] / float(width - 1)) * 2 - 1.0 110 | grid[..., 1] = (grid[..., 1] / float(height - 1)) * 2 - 1.0 111 | 112 | #todo check bug 113 | grid_l = grid.clone() 114 | grid_l[..., 0] -= (1. / float(width - 1)) * 2 115 | 116 | grid_r = grid.clone() 117 | grid_r[..., 0] += (1. / float(width - 1)) * 2 118 | 119 | grid_t = grid.clone() 120 | grid_t[..., 1] -= (1. / float(height - 1)) * 2 121 | 122 | grid_b = grid.clone() 123 | grid_b[..., 1] += (1. / float(height - 1)) * 2 124 | 125 | 126 | def get_features(grid_uv): 127 | pts_feature = F.grid_sample(feature_maps, grid_uv, mode=self.mode) 128 | pts_feature = pts_feature.squeeze(3) 129 | 130 | pts_feature = pts_feature.view(batch_size, num_view, channels, num_pts) 131 | return pts_feature 132 | 133 | pts_feature = get_features(grid) 134 | 135 | pts_feature_l = get_features(grid_l) 136 | pts_feature_r = get_features(grid_r) 137 | pts_feature_t = get_features(grid_t) 138 | pts_feature_b = get_features(grid_b) 139 | 140 | pts_feature_grad_x = 0.5 * (pts_feature_r - pts_feature_l) 141 | pts_feature_grad_y = 0.5 * (pts_feature_b - pts_feature_t) 142 | 143 | pts_feature_grad = torch.stack((pts_feature_grad_x, pts_feature_grad_y), dim=-1) 144 | # print("================features++++++++++++") 145 | # print(feature_maps) 146 | # print ("===========grad+++++++++++++++") 147 | # print (pts_feature_grad) 148 | return pts_feature, pts_feature_grad 149 | 150 | def get_result(self, feature_maps, pts, cam_intrinsics, cam_extrinsics): 151 | batch_size, num_view, channels, height, width = list(feature_maps.size()) 152 | feature_maps = feature_maps.view(batch_size * num_view, channels, height, width) 153 | 154 | curr_batch_size = batch_size * num_view 155 | cam_intrinsics = cam_intrinsics.view(curr_batch_size, 3, 3) 156 | 157 | num_pts = pts.size(2) 158 | pts_expand = pts.unsqueeze(1).contiguous().expand(batch_size, num_view, 3, num_pts) \ 159 | .contiguous().view(curr_batch_size, 3, num_pts) 160 | if cam_extrinsics is None: 161 | transformed_pts = pts_expand.type(torch.float).transpose(1, 2) 162 | else: 163 | cam_extrinsics = cam_extrinsics.view(curr_batch_size, 3, 4) 164 | R = torch.narrow(cam_extrinsics, 2, 0, 3) 165 | t = torch.narrow(cam_extrinsics, 2, 3, 1).expand(curr_batch_size, 3, num_pts) 166 | transformed_pts = torch.bmm(R, pts_expand) + t 167 | transformed_pts = transformed_pts.type(torch.float).transpose(1, 2) 168 | x = transformed_pts[..., 0] 169 | y = transformed_pts[..., 1] 170 | z = transformed_pts[..., 2] 171 | 172 | normal_uv = torch.cat( 173 | [torch.div(x, z).unsqueeze(-1), torch.div(y, z).unsqueeze(-1), torch.ones_like(x).unsqueeze(-1)], 174 | dim=-1) 175 | uv = torch.bmm(normal_uv, cam_intrinsics.transpose(1, 2)) 176 | uv = uv[:, :, :2] 177 | 178 | grid = (uv - 0.5).view(curr_batch_size, num_pts, 1, 2) 179 | grid[..., 0] = (grid[..., 0] / float(width - 1)) * 2 - 1.0 180 | grid[..., 1] = (grid[..., 1] / float(height - 1)) * 2 - 1.0 181 | 182 | def get_features(grid_uv): 183 | pts_feature = 
F.grid_sample(feature_maps, grid_uv, mode=self.mode) 184 | pts_feature = pts_feature.squeeze(3) 185 | 186 | pts_feature = pts_feature.view(batch_size, num_view, channels, num_pts) 187 | return pts_feature.detach() 188 | 189 | pts_feature = get_features(grid) 190 | 191 | # todo check bug 192 | grid[..., 0] -= (1. / float(width - 1)) * 2 193 | pts_feature_l = get_features(grid) 194 | grid[..., 0] += (1. / float(width - 1)) * 2 195 | 196 | grid[..., 0] += (1. / float(width - 1)) * 2 197 | pts_feature_r = get_features(grid) 198 | grid[..., 0] -= (1. / float(width - 1)) * 2 199 | 200 | grid[..., 1] -= (1. / float(height - 1)) * 2 201 | pts_feature_t = get_features(grid) 202 | grid[..., 1] += (1. / float(height - 1)) * 2 203 | 204 | grid[..., 1] += (1. / float(height - 1)) * 2 205 | pts_feature_b = get_features(grid) 206 | grid[..., 1] -= (1. / float(height - 1)) * 2 207 | 208 | pts_feature_r -= pts_feature_l 209 | pts_feature_r *= 0.5 210 | pts_feature_b -= pts_feature_t 211 | pts_feature_b *= 0.5 212 | 213 | return pts_feature.detach(), pts_feature_r.detach(), pts_feature_b.detach() 214 | 215 | def test_forward(self, feature_maps, pts, cam_intrinsics, cam_extrinsics): 216 | """ 217 | 218 | :param feature_maps: torch.tensor, [B, V, C, H, W] 219 | :param pts: torch.tensor, [B, 3, N] 220 | :param cam_intrinsics: torch.tensor, [B, V, 3, 3] 221 | :param cam_extrinsics: torch.tensor, [B, V, 3, 4], [R|t], p_cam = R*p_world + t 222 | :return: 223 | pts_feature: torch.tensor, [B, V, C, N] 224 | """ 225 | with torch.no_grad(): 226 | pts_feature, grad_x, grad_y = \ 227 | self.get_result(feature_maps, pts, cam_intrinsics, cam_extrinsics) 228 | torch.cuda.empty_cache() 229 | pts_feature_grad = torch.stack((grad_x, grad_y), dim=-1) 230 | 231 | return pts_feature.detach(), pts_feature_grad.detach() 232 | 233 | 234 | class PointGrad(nn.Module): 235 | def __init__(self): 236 | super(PointGrad, self).__init__() 237 | 238 | def forward(self, pts, cam_intrinsics, cam_extrinsics): 239 | """ 240 | :param pts: torch.tensor, [B, 3, N] 241 | :param cam_intrinsics: torch.tensor, [B, V, 3, 3] 242 | :param cam_extrinsics: torch.tensor, [B, V, 3, 4], [R|t], p_cam = R*p_world + t 243 | :return: 244 | pts_feature: torch.tensor, [B, V, C, N] 245 | """ 246 | batch_size, num_view, _, _ = list(cam_extrinsics.size()) 247 | 248 | curr_batch_size = batch_size * num_view 249 | cam_intrinsics = cam_intrinsics.view(curr_batch_size, 3, 3) 250 | 251 | with torch.no_grad(): 252 | num_pts = pts.size(2) 253 | pts_expand = pts.unsqueeze(1).contiguous().expand(batch_size, num_view, 3, num_pts) \ 254 | .contiguous().view(curr_batch_size, 3, num_pts) 255 | if cam_extrinsics is None: 256 | transformed_pts = pts_expand.type(torch.float).transpose(1, 2) 257 | else: 258 | cam_extrinsics = cam_extrinsics.view(curr_batch_size, 3, 4) 259 | R = torch.narrow(cam_extrinsics, 2, 0, 3) 260 | t = torch.narrow(cam_extrinsics, 2, 3, 1).expand(curr_batch_size, 3, num_pts) 261 | transformed_pts = torch.bmm(R, pts_expand) + t 262 | transformed_pts = transformed_pts.type(torch.float).transpose(1, 2) 263 | x = transformed_pts[..., 0] 264 | y = transformed_pts[..., 1] 265 | z = transformed_pts[..., 2] 266 | 267 | fx = cam_intrinsics[..., 0, 0].view(curr_batch_size, 1) 268 | fy = cam_intrinsics[..., 1, 1].view(curr_batch_size, 1) 269 | 270 | # print("x", x.size()) 271 | # print("fx", fx.size(), fx, fy) 272 | 273 | zero = torch.zeros_like(x) 274 | grad_u = torch.stack([fx / z, zero, -fx * x / (z**2)], dim=-1) 275 | grad_v = torch.stack([zero, fy / z, -fy * y / 
(z**2)], dim=-1) 276 | grad_p = torch.stack((grad_u, grad_v), dim=-2) 277 | # print("grad_u size:", grad_u.size()) 278 | # print("grad_p size:", grad_p.size()) 279 | grad_p = grad_p.view(batch_size, num_view, num_pts, 2, 3) 280 | return grad_p 281 | 282 | 283 | 284 | class ProjectUVFetcher(nn.Module): 285 | def __init__(self, mode="bilinear"): 286 | super(ProjectUVFetcher, self).__init__() 287 | self.mode = mode 288 | 289 | def forward(self, pts, cam_intrinsics, cam_extrinsics): 290 | """ 291 | 292 | :param pts: torch.tensor, [B, 3, N] 293 | :param cam_intrinsics: torch.tensor, [B, V, 3, 3] 294 | :param cam_extrinsics: torch.tensor, [B, V, 3, 4], [R|t], p_cam = R*p_world + t 295 | :return: 296 | pts_feature: torch.tensor, [B, V, C, N] 297 | """ 298 | batch_size, num_view = cam_extrinsics.size()[:2] 299 | 300 | curr_batch_size = batch_size * num_view 301 | cam_intrinsics = cam_intrinsics.view(curr_batch_size, 3, 3) 302 | 303 | with torch.no_grad(): 304 | num_pts = pts.size(2) 305 | pts_expand = pts.unsqueeze(1).contiguous().expand(batch_size, num_view, 3, num_pts) \ 306 | .contiguous().view(curr_batch_size, 3, num_pts) 307 | if cam_extrinsics is None: 308 | transformed_pts = pts_expand.type(torch.float).transpose(1, 2) 309 | else: 310 | cam_extrinsics = cam_extrinsics.view(curr_batch_size, 3, 4) 311 | R = torch.narrow(cam_extrinsics, 2, 0, 3) 312 | t = torch.narrow(cam_extrinsics, 2, 3, 1).expand(curr_batch_size, 3, num_pts) 313 | transformed_pts = torch.bmm(R, pts_expand) + t 314 | transformed_pts = transformed_pts.type(torch.float).transpose(1, 2) 315 | x = transformed_pts[..., 0] 316 | y = transformed_pts[..., 1] 317 | z = transformed_pts[..., 2] 318 | 319 | normal_uv = torch.cat( 320 | [torch.div(x, z).unsqueeze(-1), torch.div(y, z).unsqueeze(-1), torch.ones_like(x).unsqueeze(-1)], 321 | dim=-1) 322 | uv = torch.bmm(normal_uv, cam_intrinsics.transpose(1, 2)) 323 | uv = uv[:, :, :2] 324 | 325 | grid = (uv - 0.5).view(curr_batch_size, num_pts, 1, 2) 326 | 327 | return grid.view(batch_size, num_view, num_pts, 1, 2) 328 | 329 | 330 | def test_feature_fetching(): 331 | import numpy as np 332 | batch_size = 3 333 | num_view = 2 334 | channels = 16 335 | height = 240 336 | width = 320 337 | num_pts = 32 338 | 339 | cam_intrinsic = torch.tensor([[10, 0, 1], [0, 10, 1], [0, 0, 1]]).float() \ 340 | .view(1, 1, 3, 3).expand(batch_size, num_view, 3, 3).cuda() 341 | cam_extrinsic = torch.rand(batch_size, num_view, 3, 4).cuda() 342 | 343 | feature_fetcher = FeatureFetcher().cuda() 344 | 345 | features = torch.rand(batch_size, num_view, channels, height, width).cuda() 346 | 347 | imgpt = torch.tensor([60.5, 80.5, 1.0]).view(1, 1, 3, 1).expand(batch_size, num_view, 3, num_pts).cuda() 348 | 349 | z = 200 350 | 351 | pt = torch.matmul(torch.inverse(cam_intrinsic), imgpt) * z 352 | 353 | pt = torch.matmul(torch.inverse(cam_extrinsic[:, :, :, :3]), 354 | (pt - cam_extrinsic[:, :, :, 3].unsqueeze(-1))) # Xc = [R|T] Xw 355 | 356 | gathered_feature = feature_fetcher(features, pt[:, 0, :, :], cam_intrinsic, cam_extrinsic) 357 | 358 | gathered_feature = gathered_feature[:, 0, :, 0] 359 | np.savetxt("gathered_feature.txt", gathered_feature.detach().cpu().numpy(), fmt="%.4f") 360 | 361 | groundtruth_feature = features[:, :, :, 80, 60][:, 0, :] 362 | np.savetxt("groundtruth_feature.txt", groundtruth_feature.detach().cpu().numpy(), fmt="%.4f") 363 | 364 | print(np.allclose(gathered_feature.detach().cpu().numpy(), groundtruth_feature.detach().cpu().numpy(), 1.e-2)) 365 | 366 | 367 | if __name__ == "__main__": 368 | 
test_feature_fetching() 369 | -------------------------------------------------------------------------------- /fastmvsnet/utils/file_logger.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import os.path as osp 3 | import cv2 4 | 5 | import torch 6 | import torch.nn.functional as F 7 | 8 | from fastmvsnet.utils.io import mkdir 9 | from fastmvsnet.functions.functions import get_pixel_grids 10 | 11 | 12 | def file_logger(data_batch, preds, step, output_dir, prefix): 13 | step_dir = osp.join(output_dir, "{}_step{:05d}".format(prefix, step)) 14 | mkdir(step_dir) 15 | print("start saving files in ", step_dir) 16 | 17 | img_list = data_batch["img_list"] 18 | batch_size, num_view, img_channel, img_height, img_width = list(img_list.size()) 19 | 20 | cam_params_list = data_batch["cam_params_list"] 21 | 22 | for i in range(num_view): 23 | np.savetxt(osp.join(step_dir, "img{}.txt".format(i)), img_list[0, i, 0].detach().cpu().numpy(), fmt="%.4f") 24 | np.savetxt(osp.join(step_dir, "cam{}_extrinsic.txt".format(i)), cam_params_list[0, i, 0].detach().cpu().numpy(), fmt="%.4f") 25 | np.savetxt(osp.join(step_dir, "cam{}_intrinsic.txt".format(i)), cam_params_list[0, i, 1].detach().cpu().numpy(), fmt="%.4f") 26 | np.savetxt(osp.join(step_dir, "gt_depth_img.txt"), data_batch["gt_depth_img"][0, 0].detach().cpu().numpy(), fmt="%.4f") 27 | np.savetxt(osp.join(step_dir, "coarse_depth_img.txt"), preds["coarse_depth_map"][0, 0].detach().cpu().numpy(), fmt="%.4f") 28 | 29 | cam_extrinsic = cam_params_list[0, 0, 0, :3, :4].clone() # (3, 4) 30 | 31 | cam_intrinsic = cam_params_list[0, 0, 1, :3, :3].clone() 32 | 33 | world_points = preds["world_points"] 34 | world_points = world_points[0].cpu().numpy().transpose() 35 | save_points(osp.join(step_dir, "world_points.xyz"), world_points) 36 | 37 | prob_map = preds["coarse_prob_map"][0][0].cpu().numpy() 38 | 39 | coarse_points = depth2pts(preds["coarse_depth_map"], prob_map, 40 | cam_intrinsic, cam_extrinsic, (img_height, img_width)) 41 | save_points(osp.join(step_dir, "coarse_point.xyz"), coarse_points) 42 | 43 | gt_points = depth2pts(data_batch["gt_depth_img"], prob_map, 44 | cam_intrinsic, cam_extrinsic, (img_height, img_width)) 45 | save_points(osp.join(step_dir, "gt_points.xyz"), gt_points) 46 | 47 | if "flow1" in preds.keys(): 48 | flow1_points = depth2pts(preds["flow1"], prob_map, 49 | cam_intrinsic, cam_extrinsic, (img_height, img_width)) 50 | save_points(osp.join(step_dir, "flow1_points.xyz"), flow1_points) 51 | 52 | if "flow2" in preds.keys(): 53 | flow2_points = depth2pts(preds["flow2"], prob_map, 54 | cam_intrinsic, cam_extrinsic, (img_height, img_width)) 55 | save_points(osp.join(step_dir, "flow2_points.xyz"), flow2_points) 56 | 57 | print("saving finished.") 58 | 59 | 60 | def depth2pts(depth_map, prob_map, cam_intrinsic, cam_extrinsic, img_size): 61 | feature_map_indices_grid = get_pixel_grids(depth_map.size(2), depth_map.size(3)).to(depth_map.device) # (3, H*W) 62 | 63 | curr_cam_intrinsic = cam_intrinsic.clone() 64 | scale = (depth_map.size(2) + 0.0) / (img_size[0] + 0.0) * 4.0 65 | curr_cam_intrinsic[:2, :3] *= scale 66 | 67 | uv = torch.matmul(torch.inverse(curr_cam_intrinsic), feature_map_indices_grid) 68 | cam_points = uv * depth_map[0].view(1, -1) 69 | 70 | R = cam_extrinsic[:3, :3] 71 | t = cam_extrinsic[:3, 3].unsqueeze(-1) 72 | R_inv = torch.inverse(R) 73 | 74 | world_points = torch.matmul(R_inv, cam_points - t).detach().cpu().numpy().transpose() 75 | 76 | curr_prob_map = 
prob_map.copy() 77 | if curr_prob_map.shape[0] != depth_map.size(2): 78 | curr_prob_map = cv2.resize(curr_prob_map, (depth_map.size(3), depth_map.size(2)), 79 | interpolation=cv2.INTER_LANCZOS4) 80 | curr_prob_map = np.reshape(curr_prob_map, (-1, 1)) 81 | 82 | world_points = np.concatenate([world_points, curr_prob_map], axis=1) 83 | 84 | return world_points 85 | 86 | 87 | def save_points(path, points): 88 | np.savetxt(path, points, delimiter=' ', fmt='%.4f') 89 | -------------------------------------------------------------------------------- /fastmvsnet/utils/io.py: -------------------------------------------------------------------------------- 1 | import os 2 | import re 3 | import numpy as np 4 | import sys 5 | import errno 6 | import pickle 7 | import cv2 8 | import torch 9 | 10 | 11 | def mkdir(path): 12 | os.makedirs(path, exist_ok=True) 13 | 14 | 15 | def load_cam_dtu(file, num_depth=0, interval_scale=1.0): 16 | """ read camera txt file """ 17 | cam = np.zeros((2, 4, 4)) 18 | words = file.read().split() 19 | # read extrinsic 20 | for i in range(0, 4): 21 | for j in range(0, 4): 22 | extrinsic_index = 4 * i + j + 1 23 | cam[0][i][j] = words[extrinsic_index] 24 | 25 | # read intrinsic 26 | for i in range(0, 3): 27 | for j in range(0, 3): 28 | intrinsic_index = 3 * i + j + 18 29 | cam[1][i][j] = words[intrinsic_index] 30 | 31 | if len(words) == 29: 32 | cam[1][3][0] = words[27] 33 | cam[1][3][1] = float(words[28]) * interval_scale 34 | cam[1][3][2] = num_depth 35 | cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * (num_depth - 1) 36 | elif len(words) == 30: 37 | cam[1][3][0] = words[27] 38 | cam[1][3][1] = float(words[28]) * interval_scale 39 | cam[1][3][2] = words[29] 40 | cam[1][3][3] = cam[1][3][0] + cam[1][3][1] * (num_depth - 1) 41 | elif len(words) == 31: 42 | cam[1][3][0] = words[27] 43 | cam[1][3][1] = float(words[28]) * interval_scale 44 | cam[1][3][2] = words[29] 45 | cam[1][3][3] = words[30] 46 | else: 47 | cam[1][3][0] = 0 48 | cam[1][3][1] = 0 49 | cam[1][3][2] = 0 50 | cam[1][3][3] = 0 51 | 52 | return cam 53 | 54 | 55 | def write_cam_dtu(file, cam): 56 | # f = open(file, "w") 57 | f = open(file, "w") 58 | 59 | f.write('extrinsic\n') 60 | for i in range(0, 4): 61 | for j in range(0, 4): 62 | f.write(str(cam[0][i][j]) + ' ') 63 | f.write('\n') 64 | f.write('\n') 65 | 66 | f.write('intrinsic\n') 67 | for i in range(0, 3): 68 | for j in range(0, 3): 69 | f.write(str(cam[1][i][j]) + ' ') 70 | f.write('\n') 71 | 72 | f.write( 73 | '\n' + str(cam[1][3][0]) + ' ' + str(cam[1][3][1]) + ' ' + str(cam[1][3][2]) + ' ' + str(cam[1][3][3]) + '\n') 74 | 75 | f.close() 76 | 77 | 78 | def load_pfm(file): 79 | file = open(file, 'rb') 80 | 81 | color = None 82 | width = None 83 | height = None 84 | scale = None 85 | endian = None 86 | 87 | header = file.readline().rstrip() 88 | if header.decode("ascii") == 'PF': 89 | color = True 90 | elif header.decode("ascii") == 'Pf': 91 | color = False 92 | else: 93 | raise Exception('Not a PFM file.') 94 | 95 | dim_match = re.match(r'^(\d+)\s(\d+)\s$', file.readline().decode("ascii")) 96 | if dim_match: 97 | width, height = list(map(int, dim_match.groups())) 98 | else: 99 | raise Exception('Malformed PFM header.') 100 | 101 | scale = float(file.readline().decode("ascii").rstrip()) 102 | if scale < 0: # little-endian 103 | endian = '<' 104 | scale = -scale 105 | else: 106 | endian = '>' # big-endian 107 | 108 | data = np.fromfile(file, endian + 'f') 109 | shape = (height, width, 3) if color else (height, width) 110 | 111 | data = np.reshape(data, shape) 
112 | data = np.flipud(data) 113 | return data, scale 114 | 115 | 116 | def write_pfm(file, image, scale=1): 117 | file = open(file, mode='wb') 118 | color = None 119 | 120 | if image.dtype.name != 'float32': 121 | raise Exception('Image dtype must be float32.') 122 | 123 | image = np.flipud(image) 124 | 125 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 126 | color = True 127 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 128 | color = False 129 | else: 130 | raise Exception('Image must have H x W x 3, H x W x 1 or H x W dimensions.') 131 | 132 | file.write(('PF\n' if color else 'Pf\n').encode()) 133 | file.write('%d %d\n'.encode() % (image.shape[1], image.shape[0])) 134 | 135 | endian = image.dtype.byteorder 136 | 137 | if endian == '<' or (endian == '=' and sys.byteorder == 'little'): 138 | scale = -scale 139 | 140 | file.write('%f\n'.encode() % scale) 141 | 142 | image_string = image.tobytes() 143 | file.write(image_string) 144 | 145 | file.close() 146 | -------------------------------------------------------------------------------- /fastmvsnet/utils/logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | import logging 3 | import os 4 | import sys 5 | import time 6 | import numpy as np 7 | import torch 8 | from os.path import join 9 | import cv2 10 | 11 | 12 | def setup_logger(name, save_dir, prefix="", timestamp=True): 13 | logger = logging.getLogger(name) 14 | logger.setLevel(logging.INFO) 15 | ch = logging.StreamHandler(stream=sys.stdout) 16 | ch.setLevel(logging.INFO) 17 | formatter = logging.Formatter("%(asctime)s %(name)s %(levelname)s: %(message)s") 18 | ch.setFormatter(formatter) 19 | logger.addHandler(ch) 20 | 21 | if save_dir: 22 | timestamp = time.strftime(".%m_%d_%H_%M_%S") if timestamp else "" 23 | prefix = "." + prefix if prefix else "" 24 | log_file = os.path.join(save_dir, "log{}.txt".format(prefix + timestamp)) 25 | fh = logging.FileHandler(log_file) 26 | fh.setLevel(logging.INFO) 27 | fh.setFormatter(formatter) 28 | logger.addHandler(fh) 29 | 30 | logger.propagate = False 31 | return logger 32 | 33 | 34 | def shutdown_logger(logger): 35 | logger.handlers = [] 36 | 37 | -------------------------------------------------------------------------------- /fastmvsnet/utils/metric_logger.py: -------------------------------------------------------------------------------- 1 | # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved. 2 | from collections import defaultdict 3 | from collections import deque 4 | 5 | import numpy as np 6 | import torch 7 | 8 | 9 | class AverageMeter(object): 10 | """Track a series of values and provide access to smoothed values over a 11 | window or the global series average.
12 | """ 13 | 14 | def __init__(self, window_size=20): 15 | self.values = deque(maxlen=window_size) 16 | self.counts = deque(maxlen=window_size) 17 | self.sum = 0.0 18 | self.count = 0 19 | 20 | def update(self, value, count=1): 21 | self.values.append(value) 22 | self.counts.append(count) 23 | self.sum += value 24 | self.count += count 25 | 26 | @property 27 | def avg(self): 28 | if np.sum(self.counts) == 0: 29 | return 0 30 | return np.sum(self.values) / np.sum(self.counts) 31 | 32 | @property 33 | def global_avg(self): 34 | if self.count == 0: 35 | return 0 36 | return self.sum / self.count 37 | 38 | 39 | class MetricLogger(object): 40 | def __init__(self, delimiter="\t"): 41 | self.meters = defaultdict(AverageMeter) 42 | self.delimiter = delimiter 43 | 44 | def update(self, **kwargs): 45 | for k, v in kwargs.items(): 46 | count = 1 47 | if isinstance(v, torch.Tensor): 48 | if v.numel() == 1: 49 | v = v.item() 50 | else: 51 | count = v.numel() 52 | v = v.sum().item() 53 | assert isinstance(v, (float, int)) 54 | self.meters[k].update(v, count) 55 | 56 | def __getattr__(self, attr): 57 | if attr in self.meters: 58 | return self.meters[attr] 59 | return object.__getattr__(self, attr) 60 | 61 | def __str__(self): 62 | metric_str = [] 63 | for name, meter in self.meters.items(): 64 | metric_str.append( 65 | "{}: {:.4f} ({:.4f})".format(name, meter.avg, meter.global_avg) 66 | ) 67 | return self.delimiter.join(metric_str) 68 | 69 | @property 70 | def summary_str(self): 71 | metric_str = [] 72 | for name, meter in self.meters.items(): 73 | metric_str.append( 74 | "{}: {:.4f}".format(name, meter.global_avg) 75 | ) 76 | return self.delimiter.join(metric_str) 77 | -------------------------------------------------------------------------------- /fastmvsnet/utils/preprocess.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import cv2 3 | import math 4 | 5 | 6 | def norm_image(img): 7 | """ normalize image input """ 8 | img = img.astype(np.float32) 9 | var = np.var(img, axis=(0, 1), keepdims=True) 10 | mean = np.mean(img, axis=(0, 1), keepdims=True) 11 | return (img - mean) / (np.sqrt(var) + 1e-7) 12 | 13 | 14 | def mask_depth_image(depth_image, min_depth, max_depth): 15 | """ mask out-of-range pixel to zero """ 16 | # print ('mask min max', min_depth, max_depth) 17 | ret, depth_image = cv2.threshold(depth_image, min_depth, 100000, cv2.THRESH_TOZERO) 18 | ret, depth_image = cv2.threshold(depth_image, max_depth, 100000, cv2.THRESH_TOZERO_INV) 19 | depth_image = np.expand_dims(depth_image, 2) 20 | return depth_image 21 | 22 | 23 | def scale_camera(cam, scale=1): 24 | """ resize input in order to produce sampled depth map """ 25 | new_cam = np.copy(cam) 26 | # focal: 27 | new_cam[1][0][0] = cam[1][0][0] * scale 28 | new_cam[1][1][1] = cam[1][1][1] * scale 29 | # principle point: 30 | new_cam[1][0][2] = cam[1][0][2] * scale 31 | new_cam[1][1][2] = cam[1][1][2] * scale 32 | return new_cam 33 | 34 | 35 | def scale_image(image, scale=1, interpolation='linear'): 36 | """ resize image using cv2 """ 37 | if interpolation == 'linear': 38 | return cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR) 39 | if interpolation == 'nearest': 40 | return cv2.resize(image, None, fx=scale, fy=scale, interpolation=cv2.INTER_NEAREST) 41 | 42 | 43 | def scale_dtu_input(images, cams, depth_image=None, scale=1): 44 | """ resize input to fit into the memory """ 45 | for view in range(len(images)): 46 | images[view] = scale_image(images[view], 
scale=scale) 47 | cams[view] = scale_camera(cams[view], scale=scale) 48 | 49 | if depth_image is None: 50 | return images, cams 51 | else: 52 | depth_image = scale_image(depth_image, scale=scale, interpolation='nearest') 53 | return images, cams, depth_image 54 | 55 | 56 | def crop_dtu_input(images, cams, height, width, base_image_size, depth_image=None): 57 | """ resize images and cameras to fit the network (can be divided by base image size) """ 58 | 59 | # crop images and cameras 60 | for view in range(len(images)): 61 | h, w = images[view].shape[0:2] 62 | new_h = h 63 | new_w = w 64 | if new_h > height: 65 | new_h = height 66 | else: 67 | new_h = int(math.floor(h / base_image_size) * base_image_size) 68 | if new_w > width: 69 | new_w = width 70 | else: 71 | new_w = int(math.floor(w / base_image_size) * base_image_size) 72 | start_h = int(math.floor((h - new_h) / 2)) 73 | start_w = int(math.floor((w - new_w) / 2)) 74 | finish_h = start_h + new_h 75 | finish_w = start_w + new_w 76 | images[view] = images[view][start_h:finish_h, start_w:finish_w] 77 | cams[view][1][0][2] = cams[view][1][0][2] - start_w 78 | cams[view][1][1][2] = cams[view][1][1][2] - start_h 79 | 80 | # crop depth image 81 | if not depth_image is None: 82 | depth_image = depth_image[start_h:finish_h, start_w:finish_w] 83 | return images, cams, depth_image 84 | else: 85 | return images, cams 86 | -------------------------------------------------------------------------------- /fastmvsnet/utils/tensorboard_logger.py: -------------------------------------------------------------------------------- 1 | import time 2 | import os.path as osp 3 | import torch 4 | 5 | from .metric_logger import AverageMeter 6 | from tensorboardX import SummaryWriter 7 | from .io import mkdir 8 | 9 | _KEYWORDS = ("loss", "pct") 10 | 11 | 12 | class TensorboardLogger(object): 13 | def __init__(self, log_dir, keywords=_KEYWORDS): 14 | self.log_dir = osp.join(log_dir, "events.{}".format(time.strftime("%m_%d_%H_%M_%S"))) 15 | mkdir(self.log_dir) 16 | self.keywords = keywords 17 | self.writer = SummaryWriter(log_dir=self.log_dir) 18 | 19 | def add_scalars(self, meters, step, prefix=""): 20 | for k, meter in meters.items(): 21 | for keyword in _KEYWORDS: 22 | if keyword in k: 23 | if isinstance(meter, AverageMeter): 24 | v = meter.global_avg 25 | elif isinstance(meter, (int, float)): 26 | v = meter 27 | elif isinstance(meter, torch.Tensor): 28 | v = meter.cpu().item() 29 | else: 30 | raise TypeError() 31 | 32 | self.writer.add_scalar(osp.join(prefix, k), v, global_step=step) 33 | 34 | def add_image(self, img, step, prefix=""): 35 | assert len(img.size()) == 3 36 | self.writer.add_image(osp.join(prefix, "_img"), img, global_step=step) 37 | -------------------------------------------------------------------------------- /fastmvsnet/utils/torch_utils.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | import torch 5 | import torch.nn.functional as F 6 | 7 | def set_random_seed(seed): 8 | if seed < 0: 9 | return 10 | random.seed(seed) 11 | np.random.seed(seed) 12 | torch.manual_seed(seed) 13 | torch.cuda.manual_seed_all(seed) 14 | 15 | 16 | def get_knn_3d(xyz, kernel_size=5, knn=20): 17 | """ Use 3D Conv to compute neighbour distance and find k nearest neighbour 18 | xyz: (B, 3, D, H, W) 19 | 20 | Returns: 21 | idx: (B, D*H*W, k) 22 | """ 23 | batch_size, _, depth, height, width = list(xyz.size()) 24 | assert (kernel_size % 2 == 1) 25 | hk = (kernel_size // 2) 26 | k2 = 
kernel_size ** 2 27 | k3 = kernel_size ** 3 28 | 29 | t = np.zeros((kernel_size, kernel_size, kernel_size, 1, kernel_size ** 3)) 30 | ind = 0 31 | for i in range(kernel_size): 32 | for j in range(kernel_size): 33 | for k in range(kernel_size): 34 | t[i, j, k, 0, ind] -= 1.0 35 | t[hk, hk, hk, 0, ind] += 1.0 36 | ind += 1 37 | weight = np.zeros((kernel_size, kernel_size, kernel_size, 3, 3 * k3)) 38 | weight[:, :, :, 0:1, :k3] = t 39 | weight[:, :, :, 1:2, k3:2 * k3] = t 40 | weight[:, :, :, 2:3, 2 * k3:3 * k3] = t 41 | weight = torch.tensor(weight).float() 42 | 43 | weights_torch = torch.Tensor(weight.permute((4, 3, 0, 1, 2))).to(xyz.device) 44 | dist = F.conv3d(xyz, weights_torch, padding=hk) 45 | 46 | dist_flat = dist.contiguous().view(batch_size, 3, k3, -1) 47 | dist2 = torch.sum(dist_flat ** 2, dim=1) 48 | 49 | _, nn_idx = torch.topk(-dist2, k=knn, dim=1) 50 | nn_idx = nn_idx.permute(0, 2, 1) 51 | d_offset = nn_idx // k2 - hk 52 | h_offset = (nn_idx % k2) // kernel_size - hk 53 | w_offset = nn_idx % kernel_size - hk 54 | 55 | idx = torch.arange(depth * height * width).to(xyz.device) 56 | idx = idx.view(1, -1, 1).expand(batch_size, -1, knn) 57 | idx = idx + (d_offset * height * width) + (h_offset * width) + w_offset 58 | 59 | idx = torch.clamp(idx, 0, depth * height * width - 1) 60 | 61 | return idx 62 | -------------------------------------------------------------------------------- /outputs/pretrained.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/svip-lab/FastMVSNet/ccb686dda2717613c67d8a289dfe7b2aeb60e2fd/outputs/pretrained.pth -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | progressbar2>=3.0 2 | numpy>=1.13 3 | opencv_python>=3.2 4 | scikit-learn>=0.18 5 | scipy>=0.18 6 | matplotlib>=1.5 7 | Pillow>=3.1.2 8 | yacs 9 | tqdm 10 | tensorboardX 11 | -------------------------------------------------------------------------------- /tools/depthfusion.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """ 3 | Copyright 2018, Yao Yao, HKUST. 4 | Edited by Rui Chen. 5 | convert Point-MVSNet output to Gipuma format for post-processing. 
6 | """ 7 | 8 | from __future__ import print_function 9 | 10 | import argparse 11 | import os.path as osp 12 | from struct import * 13 | import sys 14 | import os 15 | 16 | sys.path.append("/root/projects/FastMVSNet") 17 | from fastmvsnet.utils.io import * 18 | 19 | 20 | def mkdir(path): 21 | import errno 22 | try: 23 | os.makedirs(path) 24 | except OSError as e: 25 | if e.errno != errno.EEXIST: 26 | raise 27 | 28 | 29 | def read_gipuma_dmb(path): 30 | '''read Gipuma .dmb format image''' 31 | 32 | with open(path, "rb") as fid: 33 | image_type = unpack(' 0, 1, 0)) 115 | mask_image = np.reshape(mask_image, (image_shape[0], image_shape[1], 1)) 116 | mask_image = np.tile(mask_image, [1, 1, 3]) 117 | mask_image = np.float32(mask_image) 118 | 119 | normal_image = np.multiply(normal_image, mask_image) 120 | normal_image = np.float32(normal_image) 121 | 122 | write_gipuma_dmb(out_normal_path, normal_image) 123 | return 124 | 125 | 126 | def mvsnet_to_gipuma(scene_folder, gipuma_point_folder, name, view_num): 127 | gipuma_cam_folder = os.path.join(gipuma_point_folder, 'cams') 128 | gipuma_image_folder = os.path.join(gipuma_point_folder, 'images') 129 | mkdir(gipuma_cam_folder) 130 | mkdir(gipuma_image_folder) 131 | 132 | for v in range(view_num): 133 | # convert cameras 134 | in_cam_file = os.path.join(scene_folder, 'cam_{:08d}_{}.txt'.format(v, name)) 135 | out_cam_file = os.path.join(gipuma_cam_folder, '{:08d}.jpg.P'.format(v)) 136 | mvsnet_to_gipuma_cam(in_cam_file, out_cam_file) 137 | 138 | # convert depth maps and fake normal maps 139 | gipuma_prefix = '2333__' 140 | sub_depth_folder = os.path.join(gipuma_point_folder, gipuma_prefix + "{:08d}".format(v)) 141 | mkdir(sub_depth_folder) 142 | in_depth_pfm = os.path.join(scene_folder, "{:08d}_{}_prob_filtered.pfm".format(v, name)) 143 | out_depth_dmb = os.path.join(sub_depth_folder, 'disp.dmb') 144 | fake_normal_dmb = os.path.join(sub_depth_folder, 'normals.dmb') 145 | mvsnet_to_gipuma_dmb(in_depth_pfm, out_depth_dmb) 146 | fake_colmap_normal(out_depth_dmb, fake_normal_dmb) 147 | 148 | # copy images to gipuma image folder 149 | in_image_file = os.path.join(scene_folder, '{:08d}.jpg'.format(v)) 150 | out_image_file = os.path.join(gipuma_image_folder, '{:08d}.jpg'.format(v)) 151 | in_image = cv2.imread(in_image_file) 152 | 153 | depth_image = load_pfm(in_depth_pfm)[0] 154 | if in_image.shape[:2] != depth_image.shape[:2]: 155 | in_image = cv2.resize(in_image, (depth_image.shape[1], depth_image.shape[0]), 156 | interpolation=cv2.INTER_NEAREST) 157 | cv2.imwrite(out_image_file, in_image) 158 | 159 | 160 | def probability_filter(scene_folder, init_prob_threshold, flow_prob_threshold, name, view_num, mode): 161 | name_bak = name 162 | for v in range(view_num): 163 | # name = 'init' 164 | init_prob_map_path = os.path.join(scene_folder, "{:08d}_init_prob.pfm".format(v)) 165 | prob_map_path = os.path.join(scene_folder, "{:08d}_{}_prob.pfm".format(v, name)) 166 | init_depth_map_path = os.path.join(scene_folder, "{:08d}_{}.pfm".format(v, name)) 167 | # name = name_bak 168 | out_depth_map_path = os.path.join(scene_folder, "{:08d}_{}_prob_filtered.pfm".format(v, name)) 169 | 170 | depth_map = load_pfm(init_depth_map_path)[0] 171 | prob_map = load_pfm(prob_map_path)[0] 172 | init_prob_map = load_pfm(init_prob_map_path)[0] 173 | 174 | # depth_map = cv2.resize(depth_map, (640, 480), interpolation=cv2.INTER_NEAREST) 175 | 176 | if prob_map.shape != depth_map.shape: 177 | prob_map = cv2.resize(prob_map, (depth_map.shape[1], depth_map.shape[0]), interpolation=mode) 
178 | if init_prob_map.shape != depth_map.shape: 179 | init_prob_map = cv2.resize(init_prob_map, (depth_map.shape[1], depth_map.shape[0]), interpolation=mode) 180 | 181 | depth_map[prob_map < flow_prob_threshold] = 0 182 | depth_map[init_prob_map < init_prob_threshold] = 0 183 | write_pfm(out_depth_map_path, depth_map) 184 | 185 | 186 | def probability_filter2(scene_folder, init_prob_threshold, flow_prob_threshold, name, view_num, mode): 187 | name_bak = name 188 | for v in range(view_num): 189 | # name = 'init' 190 | init_prob_map_path = os.path.join(scene_folder, "{:08d}_init_prob.pfm".format(v)) 191 | init_depth_map_path = os.path.join(scene_folder, "{:08d}_{}.pfm".format(v, name)) 192 | # name = name_bak 193 | out_depth_map_path = os.path.join(scene_folder, "{:08d}_{}_prob_filtered.pfm".format(v, name)) 194 | 195 | depth_map = load_pfm(init_depth_map_path)[0] 196 | init_prob_map = load_pfm(init_prob_map_path)[0] 197 | # print(depth_map.shape) 198 | # depth_map = cv2.resize(depth_map, (640, 480), interpolation=cv2.INTER_NEAREST) 199 | # print(depth_map.shape) 200 | if init_prob_map.shape != depth_map.shape: 201 | init_prob_map = cv2.resize(init_prob_map, (depth_map.shape[1], depth_map.shape[0]), interpolation=mode) 202 | 203 | depth_map[init_prob_map < init_prob_threshold] = 0 204 | write_pfm(out_depth_map_path, depth_map) 205 | 206 | 207 | def depth_map_fusion(point_folder, fusibile_exe_path, disp_thresh, num_consistent): 208 | cam_folder = os.path.join(point_folder, 'cams') 209 | image_folder = os.path.join(point_folder, 'images') 210 | depth_min = 0.001 211 | depth_max = 100000 212 | normal_thresh = 360 213 | 214 | cmd = fusibile_exe_path 215 | cmd = cmd + ' -input_folder ' + point_folder + '/' 216 | cmd = cmd + ' -p_folder ' + cam_folder + '/' 217 | cmd = cmd + ' -images_folder ' + image_folder + '/' 218 | cmd = cmd + ' --depth_min=' + str(depth_min) 219 | cmd = cmd + ' --depth_max=' + str(depth_max) 220 | cmd = cmd + ' --normal_thresh=' + str(normal_thresh) 221 | cmd = cmd + ' --disp_thresh=' + str(disp_thresh) 222 | cmd = cmd + ' --num_consistent=' + str(num_consistent) 223 | print(cmd) 224 | os.system(cmd) 225 | 226 | return 227 | 228 | 229 | if __name__ == '__main__': 230 | 231 | parser = argparse.ArgumentParser() 232 | parser.add_argument('--eval_folder', type=str, 233 | default='data/dtu/Eval/') 234 | parser.add_argument('--fusibile_exe_path', type=str, default='/root/projects/fusibile/fusibile') 235 | parser.add_argument('--init_prob_threshold', type=float, default=0.2) 236 | parser.add_argument('--flow_prob_threshold', type=float, default=0.1) 237 | parser.add_argument('--disp_threshold', type=float, default=0.12) 238 | parser.add_argument('--num_consistent', type=int, default=3) 239 | parser.add_argument("-v", '--view_num', type=int, default=49) 240 | parser.add_argument("-n", '--name', type=str) 241 | parser.add_argument("-m", '--inter_mode', type=str, default='LANCZOS4') 242 | parser.add_argument("-f", '--depth_folder', type=str) 243 | args = parser.parse_args() 244 | 245 | eval_folder = args.eval_folder 246 | fusibile_exe_path = args.fusibile_exe_path 247 | init_prob_threshold = args.init_prob_threshold 248 | flow_prob_threshold = args.flow_prob_threshold 249 | disp_threshold = args.disp_threshold 250 | num_consistent = args.num_consistent 251 | view_num = args.view_num 252 | name = args.name 253 | 254 | if args.inter_mode == "NEAREST": 255 | mode = cv2.INTER_NEAREST 256 | elif args.inter_mode == "BILINEAR": 257 | mode = cv2.INTER_LINEAR 258 | elif args.inter_mode == 
"CUBIC": 259 | mode = cv2.INTER_CUBIC 260 | elif args.inter_mode == "LANCZOS4": 261 | mode = cv2.INTER_LANCZOS4 262 | else: 263 | raise ValueError("Unknown interpolation mode: {}.".format(args.inter_mode)) 264 | 265 | DEPTH_FOLDER = args.depth_folder 266 | 267 | out_point_folder = os.path.join(eval_folder, DEPTH_FOLDER, '{}_3ITER_{}_ip{}_fp{}_d{}_nc{}' 268 | .format(args.inter_mode, name, init_prob_threshold, flow_prob_threshold, 269 | disp_threshold, num_consistent)) 270 | mkdir(out_point_folder) 271 | 272 | scene_list = ["scan1", "scan4", "scan9", "scan10", "scan11", "scan12", "scan13", "scan15", "scan23", 273 | "scan24", "scan29", "scan32", "scan33", "scan34", "scan48", "scan49", "scan62", "scan75", 274 | "scan77", "scan110", "scan114", "scan118"] 275 | 276 | for scene in scene_list: 277 | scene_folder = osp.join(eval_folder, DEPTH_FOLDER, scene) 278 | if not osp.isdir(scene_folder): 279 | continue 280 | if scene[:4] != "scan": 281 | continue 282 | print("**** Fusion for {} ****".format(scene)) 283 | 284 | # probability filter 285 | print('filter depth map with probability map') 286 | probability_filter2(scene_folder, init_prob_threshold, flow_prob_threshold, name, view_num, mode) 287 | 288 | # convert to gipuma format 289 | print('Convert mvsnet output to gipuma input') 290 | point_folder = osp.join(out_point_folder, scene) 291 | mkdir(point_folder) 292 | mvsnet_to_gipuma(scene_folder, point_folder, name, view_num) 293 | 294 | # depth map fusion with gipuma 295 | print('Run depth map fusion & filter') 296 | depth_map_fusion(point_folder, fusibile_exe_path, disp_threshold, num_consistent) 297 | 298 | cur_dirs = os.listdir(point_folder) 299 | filter_dirs = list(filter(lambda x:x.startswith("consistencyCheck"), cur_dirs)) 300 | 301 | assert (len(filter_dirs) == 1) 302 | 303 | rename_cmd = "cp " + osp.join(point_folder, filter_dirs[0]) + "/final3d_model.ply {}/{}_ip{}_fp{}_d{}_nc{}_{}.ply".format( 304 | out_point_folder, scene, init_prob_threshold, flow_prob_threshold, disp_threshold, num_consistent, 305 | args.inter_mode 306 | ) 307 | print(rename_cmd) 308 | os.system(rename_cmd) 309 | 310 | # remove tmp file 311 | remove_cmd = "rm -r " + point_folder 312 | print(remove_cmd) 313 | os.system(remove_cmd) 314 | --------------------------------------------------------------------------------