├── .gitignore ├── .pre-commit-config.yaml ├── .vscode ├── extensions.json └── settings.json ├── LICENSE ├── README.md ├── data └── README.md ├── environment.yml ├── eval.py ├── guided_mvs_lib ├── __init__.py ├── datasets │ ├── __init__.py │ ├── blended_mvg_utils.py │ ├── blended_mvs_utils.py │ ├── dtu_blended_mvs.py │ ├── dtu_utils.py │ ├── sample_preprocess.py │ └── utils.py ├── models │ ├── __init__.py │ ├── cas_mvsnet │ │ ├── __init__.py │ │ ├── cas_mvsnet.py │ │ └── module.py │ ├── d2hc_rmvsnet │ │ ├── __init__.py │ │ ├── convlstm.py │ │ ├── drmvsnet.py │ │ ├── module.py │ │ ├── rnnmodule.py │ │ ├── submodule.py │ │ ├── vamvsnet.py │ │ └── vamvsnet_high_submodule.py │ ├── mvsnet │ │ ├── __init__.py │ │ ├── module.py │ │ └── mvsnet.py │ ├── patchmatchnet │ │ ├── __init__.py │ │ ├── module.py │ │ ├── net.py │ │ └── patchmatch.py │ └── ucsnet │ │ ├── __init__.py │ │ ├── submodules.py │ │ └── ucsnet.py └── utils.py ├── hubconf.py ├── params.yaml ├── pyproject.toml ├── results.ipynb ├── tests ├── test_dataset.py ├── test_models.py ├── test_train.py └── test_version.py └── train.py /.gitignore: -------------------------------------------------------------------------------- 1 | .python-version 2 | **/__pycache__/ 3 | data/* 4 | !data/README.md 5 | eval_output/ 6 | guided-mvs.code-workspace 7 | .pytest_cache 8 | notebooks/ 9 | .current_run.yaml 10 | .envrc 11 | mlruns/ 12 | -------------------------------------------------------------------------------- /.pre-commit-config.yaml: -------------------------------------------------------------------------------- 1 | # See https://pre-commit.com for more information 2 | # See https://pre-commit.com/hooks.html for more hooks 3 | repos: 4 | - repo: https://github.com/pre-commit/pre-commit-hooks 5 | rev: v3.2.0 6 | hooks: 7 | - id: check-yaml 8 | - id: check-added-large-files 9 | - repo: https://github.com/psf/black 10 | rev: 21.9b0 11 | hooks: 12 | - id: black 13 | - repo: https://github.com/PyCQA/isort 14 | rev: 5.9.3 15 | hooks: 16 | - id: isort -------------------------------------------------------------------------------- /.vscode/extensions.json: -------------------------------------------------------------------------------- 1 | { 2 | "recommendations": [ 3 | "ms-python.python", 4 | "yzhang.markdown-all-in-one", 5 | ] 6 | } -------------------------------------------------------------------------------- /.vscode/settings.json: -------------------------------------------------------------------------------- 1 | { 2 | "editor.formatOnSave": true, 3 | "python.formatting.provider": "black", 4 | "[python]": { 5 | "editor.codeActionsOnSave": { 6 | "source.organizeImports": true 7 | } 8 | }, 9 | "files.watcherExclude": { 10 | "**/.git/objects/**": true, 11 | "**/.git/subtree-cache/**": true, 12 | "**/node_modules/*/**": true, 13 | "**/.hg/store/**": true, 14 | "**/output/**": true, 15 | "**/data/**": true, 16 | }, 17 | "python.testing.pytestArgs": [ 18 | "tests" 19 | ], 20 | "python.testing.unittestEnabled": false, 21 | "python.testing.pytestEnabled": true 22 | } -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 Fangjinhua Wang 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, 
copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Multi-View Guided Multi-View Stereo 2 | 3 |

4 |

5 | Matteo Poggi* 6 | · 7 | Andrea Conti* 8 | · 9 | Stefano Mattoccia 10 | *joint first authorship 11 |
12 |
13 | [Arxiv] 14 | [Project Page] 15 | [Demo] 16 |
17 |

18 | 19 | This is the official source code of Multi-View Guided Multi-View Stereo, presented at the [IEEE/RSJ International Conference on Intelligent Robots and Systems](https://iros2022.org/) 20 | 21 | ## Citation 22 |
23 | ``` 24 | @inproceedings{poggi2022guided, 25 | title={Multi-View Guided Multi-View Stereo}, 26 | author={Poggi, Matteo and Conti, Andrea and Mattoccia, Stefano}, 27 | booktitle={IEEE/RSJ International Conference on Intelligent Robots and Systems}, 28 | note={IROS}, 29 | year={2022} 30 | } 31 | ``` 32 |
33 | ## Load pretrained models and evaluate 34 | 35 | We release many of the MVS networks tested in the paper, trained on Blended-MVG or on Blended-MVG and fine-tuned on DTU, with and without sparse depth points. These models can be loaded simply through the `torch.hub` API. 36 |
37 | ```python 38 | model = torch.hub.load( 39 | "andreaconti/multi-view-guided-multi-view-stereo", 40 | "mvsnet", # mvsnet | ucsnet | d2hc_rmvsnet | patchmatchnet | cas_mvsnet 41 | pretrained=True, 42 | dataset="blended_mvg", # blended_mvg | dtu_yao_blended_mvg 43 | hints="not_guided", # mvguided_filtered | not_guided | guided | mvguided 44 | ) 45 | ``` 46 |
47 | Once loaded, each model has the same interface, shown below; moreover, each pretrained model provides its training parameters under the attribute `train_params`. 48 |
49 | ```python 50 | depth = model( 51 | images, # B x N x 3 x H x W 52 | intrinsics, # B x N x 3 x 3 53 | extrinsics, # B x N x 4 x 4 54 | depth_values, # B x D (128 usually) 55 | hints, # B x 1 x H x W (optional) 56 | ) 57 | ``` 58 |
59 | Finally, we also provide an interface over the datasets used, shown below. In this case PyTorch Lightning is required as a dependency and the dataset must be stored locally. 60 |
61 | ```python 62 | dm = torch.hub.load( 63 | "andreaconti/multi-view-guided-multi-view-stereo", 64 | "blended-mvg", # blended_mvg | blended_mvs | dtu 65 | root="data/blended-mvg", 66 | hints="not_guided", # mvguided_filtered | not_guided | guided | mvguided 67 | hints_density=0.03, 68 | ) 69 | dm.prepare_data() 70 | dm.setup() 71 | dl = dm.train_dataloader() 72 | ``` 73 |
74 | In [results.ipynb](https://github.com/andreaconti/multi-view-guided-multi-view-stereo/blob/main/results.ipynb) there is an example of how to reproduce some of the results shown in the paper through the `torch.hub` API. 75 |
76 | ## Installation 77 | 78 | Install the dependencies using Conda or [Mamba](https://github.com/mamba-org/mamba): 79 |
80 | ```bash 81 | $ conda env create -f environment.yml 82 | $ conda activate guided-mvs 83 | ``` 84 |
85 | ## Download the dataset(s) 86 | 87 | Download the datasets used: 88 |
89 | * [DTU](http://roboimagedata.compute.dtu.dk/?page_id=36) preprocessed by [patchmatchnet](https://github.com/FangjinhuaWang/PatchmatchNet), [train](https://polybox.ethz.ch/index.php/s/ugDdJQIuZTk4S35) and [val](https://drive.google.com/file/d/1jN8yEQX0a-S22XwUjISM8xSJD39pFLL_/view?usp=sharing) 90 | * [BlendedMVG](https://github.com/YoYo000/BlendedMVS), download the lowres version from [BlendedMVS](https://1drv.ms/u/s!Ag8Dbz2Aqc81gVDu7FHfbPZwqhIy?e=BHY07t), [BlendedMVS+](https://1drv.ms/u/s!Ag8Dbz2Aqc81gVLILxpohZLEYiIa?e=MhwYSR), [BlendedMVS++](https://1drv.ms/u/s!Ag8Dbz2Aqc81gVHCxmURGz0UBGns?e=Tnw2KY) 91 |
92 | Then organize them as follows under the data folder (sym-links work fine): 93 |
94 | ``` 95 | data/dtu 96 | |-- train_data 97 | |-- Cameras_1 98 | |-- Depths_raw 99 | |-- Rectified 100 | |-- test_data 101 | |-- scan1 102 | |-- scan2 103 | |-- ..
104 | 105 | data/blended-mvs 106 | |-- 107 | |-- blended_images 108 | |-- cams 109 | |-- rendered_depth_maps 110 | |-- .. 111 | ``` 112 |
113 | ## [Optional] Test that everything is fine 114 | 115 | This project implements some tests to preliminarily check that everything is fine. Tests are grouped by different tags. 116 |
117 | ``` 118 | # tags: 119 | # - data: tests related to the datasets 120 | # - dtu: tests related to the DTU dataset 121 | # - blended_mvs: tests related to blended MVS 122 | # - blended_mvg: tests related to blended MVG 123 | # - train: tests launching all networks in fast dev run mode (1 batch of train/val for each network on each dataset) 124 | # - slow: tests that are slow to execute 125 | 126 | # EXAMPLES 127 | 128 | # runs all tests 129 | $ pytest 130 | 131 | # runs tests excluding slow ones 132 | $ pytest -m "not slow" 133 | 134 | # runs tests only on dtu 135 | $ pytest -m dtu 136 | 137 | # runs tests on data except for dtu ones 138 | $ pytest -m "data and not dtu" 139 | ``` 140 |
141 | ## Training 142 | 143 | To train a model, edit ``params.yaml`` specifying the model to be trained among the following (an illustrative sketch of this configuration is shown at the end of this section): 144 |
145 | * [cas_mvsnet](https://arxiv.org/pdf/1912.06378.pdf) 146 | * [d2hc_rmvsnet](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123490647.pdf) 147 | * [mvsnet](https://arxiv.org/pdf/1804.02505.pdf) 148 | * [patchmatchnet](https://arxiv.org/pdf/2012.01411.pdf) 149 | * [ucsnet](https://arxiv.org/abs/1911.12012) 150 |
151 | Choose the dataset among ``dtu_yao``, ``blended_mvs`` and ``blended_mvg``, set the other training parameters, then hit: 152 |
153 | ``` 154 | # python3 train.py --help to see the options 155 | $ python3 train.py 156 | ``` 157 |
158 | The best model is stored in the ``output`` folder as ``model.ckpt``, along with a ``meta`` folder containing useful information about the executed training. 159 |
160 | ### Resume a training 161 | 162 | If something bad happens or if you stop the training process using a keyboard interrupt (``Ctrl-C``), the checkpoints will not be deleted and you can resume 163 | the training with the following option: 164 |
165 | ``` 166 | # resume the last checkpoint saved in output/ckpts (last epoch) 167 | $ python3 train.py --resume-from-checkpoint 168 | 169 | # resume a chosen checkpoint stored elsewhere 170 | $ python3 train.py --resume-from-checkpoint my_checkpoint.ckpt 171 | ``` 172 |
173 | Resuming takes care of properly updating the training logs on TensorBoard.
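For reference, below is a minimal, illustrative sketch of what such a training configuration in ``params.yaml`` could look like. The key names are an assumption inferred from the options listed above and from the hyper-parameters consumed by the training code (``lr``, ``weight_decay``, ``epochs_lr_decay``, ``epochs_lr_gamma``); the ``params.yaml`` shipped with the repository remains the authoritative reference.

```yaml
# Illustrative sketch only: key names are assumptions, check params.yaml in the repository
model: mvsnet            # cas_mvsnet | d2hc_rmvsnet | mvsnet | patchmatchnet | ucsnet
dataset: blended_mvg     # dtu_yao | blended_mvs | blended_mvg
train:
  lr: 0.001              # learning rate of the Adam optimizer
  weight_decay: 0.0
  epochs_lr_decay: null  # e.g. [10, 12] to decay the learning rate at those epochs
  epochs_lr_gamma: 0.5   # multiplicative factor applied at each decay step
  hints: not_guided      # not_guided | guided | mvguided | mvguided_filtered
  hints_density: 0.01    # fraction of valid depth points sampled as hints
```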
174 | 175 | ## Evaluation 176 | 177 | Once you have trained a model you can evaluate it using ``eval.py``. Here you have a few options, specifically: 178 |
179 | * the **dataset** on which to test 180 | * whether to evaluate using **guided** hints, **guided_integral** hints or none 181 | * the **hints density** to be used 182 | * the number of **views** 183 |
184 | ``` 185 | # see the options 186 | $ python3 eval.py --help 187 | usage: eval.py [-h] [--dataset {dtu_yao,blended_mvs,blended_mvg}] 188 | [--hints {not_guided,guided,guided_integral}] 189 | [--hints-density HINTS_DENSITY] [--views VIEWS] 190 | [--loadckpt LOADCKPT] [--limit-scans LIMIT_SCANS] 191 | [--skip-steps {1,2,3} [{1,2,3} ...]] [--save-scans-output] 192 | 193 | # EXAMPLES 194 | # without guided hints on dtu_yao, 3 views 195 | $ python3 eval.py 196 | 197 | # with guided hints and 5 views and density of 0.01 198 | $ python3 eval.py --hints guided --views 5 199 | 200 | # with integral guided hints 3 views and 0.03 density 201 | $ python3 eval.py --hints guided_integral --hints-density 0.03 202 | ``` 203 |
204 | Results will be stored under ``output/eval_<dataset>/[guided|not_guided|guided_integral]-[density=<density>]-views=<views>``; for instance, guiding on DTU with a 0.01 density and using 3 views, the results will be in: 205 |
206 | * ``output/eval_dtu_yao/guided-density=0.01-views=3/`` 207 |
208 | Each of these folders will contain the point cloud for each testing scene and a ``metrics.json`` file containing the final metrics; these will differ depending on the dataset used for evaluation. 209 |
210 | ## Development 211 | 212 | ### Environment 213 | 214 | To develop, you have to create a conda virtual environment and **also** install the git hooks: 215 |
216 | ```bash 217 | $ conda env create -f environment.yml 218 | $ conda activate guided-mvs 219 | $ pre-commit install 220 | ``` 221 |
222 | When you commit, [Black](https://github.com/psf/black) and [Isort](https://pypi.org/project/isort/) will be executed on the modified 223 | files. 224 |
225 | ### VSCode specific settings 226 | 227 | If you use Visual Studio Code, its configuration and the needed extensions are stored in ``.vscode``. Create a file in the root folder called ``.guided-mvs.code-workspace`` containing the following to load the conda environment properly: 228 |
229 | ```json 230 | { 231 | "folders": [ 232 | { 233 | "path": "."
234 | } 235 | ], 236 | "settings": { 237 | "python.pythonPath": "" 238 | } 239 | } 240 | ``` 241 | -------------------------------------------------------------------------------- /data/README.md: -------------------------------------------------------------------------------- 1 | # Data Folder 2 | 3 | Here you have to put your datasets (even a soft link is enough) -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: guided-mvs 2 | channels: 3 | - pytorch-lts 4 | - conda-forge 5 | - defaults 6 | dependencies: 7 | 8 | # core 9 | - python=3.8 10 | - cudatoolkit=11.1 11 | - pytorch=1.8.* 12 | - pytorch-lightning=1.4.* 13 | - torchvision 14 | - numpy>=1.20.0 15 | - tqdm 16 | - pyyaml 17 | - scipy 18 | - tensorboard 19 | - joblib 20 | - pip 21 | 22 | # dev 23 | - black 24 | - isort 25 | - pre-commit 26 | - pytest 27 | 28 | - pip: 29 | - opencv-python 30 | - cmapy >=0.6.0, <1.0 31 | - plyfile 32 | - pillow 33 | - GitPython 34 | - mlflow 35 | - pysftp 36 | -------------------------------------------------------------------------------- /guided_mvs_lib/__init__.py: -------------------------------------------------------------------------------- 1 | __version__ = "0.1.0" 2 | -------------------------------------------------------------------------------- /guided_mvs_lib/datasets/__init__.py: -------------------------------------------------------------------------------- 1 | import multiprocessing 2 | from pathlib import Path 3 | from typing import Callable, List, Literal, Optional, Union 4 | 5 | from pytorch_lightning import LightningDataModule 6 | from torch.utils.data import DataLoader 7 | 8 | from . import blended_mvg_utils, blended_mvs_utils, dtu_utils 9 | from .dtu_blended_mvs import MVSDataset, MVSSample 10 | 11 | __all__ = ["MVSDataModule", "MVSSample", "find_dataset_def", "find_scans"] 12 | 13 | 14 | class MVSDataModule(LightningDataModule): 15 | def __init__( 16 | self, 17 | # selection of the dataset 18 | name: Literal["dtu_yao", "blended_mvs", "blended_mvg"], 19 | # args for the dataloader 20 | batch_size: int = 1, 21 | num_workers: int = multiprocessing.cpu_count() // 2, 22 | # args for the dataset 23 | **kwargs, 24 | ): 25 | super().__init__() 26 | self._ds_builder = find_dataset_def(name, kwargs.pop("datapath", None)) 27 | self._ds_args = kwargs 28 | 29 | # dataloader args 30 | self.batch_size = batch_size 31 | self.num_workers = num_workers 32 | 33 | def setup(self, stage: Optional[str] = None): 34 | if stage in ("fit", None): 35 | self.mvs_train = self._ds_builder(mode="train", **self._ds_args) 36 | self.mvs_val = self._ds_builder(mode="val", **self._ds_args) 37 | if stage in ("test", None): 38 | self.mvs_test = self._ds_builder(mode="test", **self._ds_args) 39 | 40 | def train_dataloader(self): 41 | return DataLoader( 42 | self.mvs_train, 43 | batch_size=self.batch_size, 44 | shuffle=True, 45 | num_workers=self.num_workers, 46 | ) 47 | 48 | def val_dataloader(self): 49 | return DataLoader( 50 | self.mvs_val, 51 | batch_size=1, 52 | shuffle=False, 53 | num_workers=self.num_workers, 54 | ) 55 | 56 | def test_dataloader(self): 57 | return DataLoader( 58 | self.mvs_test, 59 | batch_size=1, 60 | shuffle=False, 61 | num_workers=self.num_workers, 62 | ) 63 | 64 | 65 | def find_dataset_def( 66 | name: Literal["dtu_yao", "blended_mvs", "blended_mvg"], 67 | datapath: Union[Path, str, None] = None, 68 | ) -> Callable[..., MVSDataset]: 69 | assert name in 
["dtu_yao", "blended_mvs", "blended_mvg"] 70 | if datapath is None: 71 | datapath = { 72 | "dtu_yao": "data/dtu", 73 | "blended_mvs": "data/blended-mvs", 74 | "blended_mvg": "data/blended-mvs", 75 | }[name] 76 | datapath = Path(datapath) 77 | 78 | def builder(*args, **kwargs): 79 | return MVSDataset(name, datapath, *args, **kwargs) 80 | 81 | return builder 82 | 83 | 84 | _SCANS = { 85 | "dtu_yao": { 86 | "train": dtu_utils.train_scans(), 87 | "val": dtu_utils.val_scans(), 88 | "test": dtu_utils.test_scans(), 89 | }, 90 | "blended_mvs": { 91 | "train": blended_mvs_utils.train_scans(), 92 | "val": blended_mvs_utils.val_scans(), 93 | "test": blended_mvs_utils.test_scans(), 94 | }, 95 | "blended_mvg": { 96 | "train": blended_mvg_utils.train_scans(), 97 | "val": blended_mvg_utils.val_scans(), 98 | "test": blended_mvg_utils.test_scans(), 99 | }, 100 | } 101 | 102 | 103 | def find_scans( 104 | dataset_name: Literal["dtu_yao", "blended_mvs", "blended_mvg"], 105 | split: Literal["train", "val", "test"], 106 | ) -> Optional[List[str]]: 107 | try: 108 | return _SCANS[dataset_name][split] 109 | except KeyError: 110 | if dataset_name not in ["dtu_yao", "blended_mvs", "blended_mvg"]: 111 | raise ValueError(f"{dataset_name} not in dtu_utils, blended_mvs, blended_mvg") 112 | elif split not in ["train", "val", "test"]: 113 | raise ValueError(f"{split} not in train, val, test") 114 | -------------------------------------------------------------------------------- /guided_mvs_lib/datasets/blended_mvs_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | from pathlib import Path 3 | from typing import List, Literal, Tuple, Union 4 | 5 | import numpy as np 6 | 7 | from .utils import DatasetExamplePaths, read_pfm 8 | 9 | __all__ = [ 10 | "datapath_files", 11 | "read_cam_file", 12 | "read_depth_mask", 13 | "read_depth", 14 | "build_list", 15 | "train_scans", 16 | "val_scans", 17 | "test_scans", 18 | ] 19 | 20 | 21 | def datapath_files( 22 | datapath: Union[Path, str], 23 | scan: str, 24 | view_id: int, 25 | split: Literal["train", "test", "val"], 26 | light_id: int = None, 27 | ) -> DatasetExamplePaths: 28 | """ 29 | Takes in input the root of the Blended MVS dataset and returns a dictionary containing 30 | the paths to the files of a single scan, with the speciied view_id and light_id 31 | (this last one is used only when ``split`` isn't test) 32 | 33 | Parameters 34 | ---------- 35 | datapath: path 36 | path to the root of the dtu dataset 37 | scan: str 38 | name of the used scan, say scan1 39 | view_id: int 40 | the index of the specific image in the scan 41 | split: train, val or test 42 | which split of DTU must be used to search the scan for 43 | light_id: int 44 | the index of the specific lightning condition index 45 | 46 | Returns 47 | ------- 48 | out: Dict[str, Path] 49 | returns a dictionary containing the paths taken into account 50 | """ 51 | root = Path(datapath) 52 | return { 53 | "img": root / f"{scan}/blended_images/{view_id:0>8}.jpg", 54 | "proj_mat": root / f"{scan}/cams/{view_id:0>8}_cam.txt", 55 | "depth_mask": root / f"{scan}/rendered_depth_maps/{view_id:0>8}.pfm", 56 | "depth": root / f"{scan}/rendered_depth_maps/{view_id:0>8}.pfm", 57 | "pcd": None, 58 | "obs_mask": None, 59 | "ground_plane": None, 60 | } 61 | 62 | 63 | def read_cam_file(path: str) -> Tuple[np.ndarray, np.ndarray, float, float]: 64 | """ 65 | Reads a file containing the Blended MVS camera intrinsics, extrinsics, max depth and 66 | min depth. 
67 | 68 | Parameters 69 | ---------- 70 | path: str 71 | path of the source file (something like ../00000000_cam.txt) 72 | 73 | Returns 74 | ------- 75 | out: Tuple[np.ndarray, np.ndarray, float, float] 76 | respectively intrinsics, extrinsics, min depth and max depth 77 | """ 78 | with open(path) as f: 79 | lines = f.readlines() 80 | lines = [line.rstrip() for line in lines] 81 | extrinsics = np.fromstring(" ".join(lines[1:5]), dtype=np.float32, sep=" ").reshape((4, 4)) 82 | intrinsics = np.fromstring(" ".join(lines[7:10]), dtype=np.float32, sep=" ").reshape((3, 3)) 83 | depth_min = float(lines[11].split()[0]) 84 | depth_max = float(lines[11].split()[3]) 85 | return intrinsics, extrinsics, depth_min, depth_max 86 | 87 | 88 | def read_depth_mask(path: Union[str, Path]) -> np.ndarray: 89 | """ 90 | Loads the Blended MVS depth mask 91 | """ 92 | mask = np.array(read_pfm(path)[0], dtype=np.float32) > 0 93 | return mask 94 | 95 | 96 | def read_depth(path: Union[str, Path]) -> np.ndarray: 97 | """ 98 | Loads the depth DTU depth map 99 | """ 100 | depth = np.array(read_pfm(str(path))[0], dtype=np.float32) 101 | return depth 102 | 103 | 104 | def build_list(datapath: str, scans: list, nviews: int): 105 | metas = [] 106 | for scan in scans: 107 | pair_file = "cams/pair.txt" 108 | 109 | with open(os.path.join(datapath, scan, pair_file)) as f: 110 | num_viewpoint = int(f.readline()) 111 | for view_idx in range(num_viewpoint): 112 | ref_view = int(f.readline().rstrip()) 113 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 114 | if len(src_views) >= nviews: 115 | metas.append((scan, None, ref_view, src_views)) 116 | return metas 117 | 118 | 119 | def train_scans() -> List[str]: 120 | return [ 121 | "5c1f33f1d33e1f2e4aa6dda4", 122 | "5bfe5ae0fe0ea555e6a969ca", 123 | "5bff3c5cfe0ea555e6bcbf3a", 124 | "58eaf1513353456af3a1682a", 125 | "5bfc9d5aec61ca1dd69132a2", 126 | "5bf18642c50e6f7f8bdbd492", 127 | "5bf26cbbd43923194854b270", 128 | "5bf17c0fd439231948355385", 129 | "5be3ae47f44e235bdbbc9771", 130 | "5be3a5fb8cfdd56947f6b67c", 131 | "5bbb6eb2ea1cfa39f1af7e0c", 132 | "5ba75d79d76ffa2c86cf2f05", 133 | # "5bb7a08aea1cfa39f1a947ab", 134 | "5b864d850d072a699b32f4ae", 135 | "5b6eff8b67b396324c5b2672", 136 | "5b6e716d67b396324c2d77cb", 137 | "5b69cc0cb44b61786eb959bf", 138 | "5b62647143840965efc0dbde", 139 | "5b60fa0c764f146feef84df0", 140 | "5b558a928bbfb62204e77ba2", 141 | "5b271079e0878c3816dacca4", 142 | "5b08286b2775267d5b0634ba", 143 | "5afacb69ab00705d0cefdd5b", 144 | "5af28cea59bc705737003253", 145 | # "5af02e904c8216544b4ab5a2", 146 | "5aa515e613d42d091d29d300", 147 | "5c34529873a8df509ae57b58", 148 | "5c34300a73a8df509add216d", 149 | "5c1af2e2bee9a723c963d019", 150 | "5c1892f726173c3a09ea9aeb", 151 | "5c0d13b795da9479e12e2ee9", 152 | "5c062d84a96e33018ff6f0a6", 153 | "5bfd0f32ec61ca1dd69dc77b", 154 | "5bf21799d43923194842c001", 155 | "5bf3a82cd439231948877aed", 156 | "5bf03590d4392319481971dc", 157 | "5beb6e66abd34c35e18e66b9", 158 | # "5be883a4f98cee15019d5b83", 159 | "5be47bf9b18881428d8fbc1d", 160 | "5bcf979a6d5f586b95c258cd", 161 | "5bce7ac9ca24970bce4934b6", 162 | "5bb8a49aea1cfa39f1aa7f75", 163 | "5b78e57afc8fcf6781d0c3ba", 164 | # "5b21e18c58e2823a67a10dd8", 165 | "5b22269758e2823a67a3bd03", 166 | "5b192eb2170cf166458ff886", 167 | "5ae2e9c5fe405c5076abc6b2", 168 | "5adc6bd52430a05ecb2ffb85", 169 | "5ab8b8e029f5351f7f2ccf59", 170 | "5abc2506b53b042ead637d86", 171 | "5ab85f1dac4291329b17cb50", 172 | "5a969eea91dfc339a9a3ad2c", 173 | "5a8aa0fab18050187cbe060e", 174 | # 
"5a7d3db14989e929563eb153", 175 | "5a69c47d0d5d0a7f3b2e9752", 176 | "5a618c72784780334bc1972d", 177 | "5a6464143d809f1d8208c43c", 178 | "5a588a8193ac3d233f77fbca", 179 | "5a57542f333d180827dfc132", 180 | "5a572fd9fc597b0478a81d14", 181 | "5a563183425d0f5186314855", 182 | "5a4a38dad38c8a075495b5d2", 183 | "5a48d4b2c7dab83a7d7b9851", 184 | "5a489fb1c7dab83a7d7b1070", 185 | # "5a48ba95c7dab83a7d7b44ed", 186 | "5a3ca9cb270f0e3f14d0eddb", 187 | "5a3cb4e4270f0e3f14d12f43", 188 | "5a3f4aba5889373fbbc5d3b5", 189 | "5a0271884e62597cdee0d0eb", 190 | "59e864b2a9e91f2c5529325f", 191 | "599aa591d5b41f366fed0d58", 192 | "59350ca084b7f26bf5ce6eb8", 193 | "59338e76772c3e6384afbb15", 194 | "5c20ca3a0843bc542d94e3e2", 195 | "5c1dbf200843bc542d8ef8c4", 196 | "5c1b1500bee9a723c96c3e78", 197 | "5bea87f4abd34c35e1860ab5", 198 | "5c2b3ed5e611832e8aed46bf", 199 | "57f8d9bbe73f6760f10e916a", 200 | "5bf7d63575c26f32dbf7413b", 201 | # "5be4ab93870d330ff2dce134", 202 | "5bd43b4ba6b28b1ee86b92dd", 203 | "5bccd6beca24970bce448134", 204 | "5bc5f0e896b66a2cd8f9bd36", 205 | "5b908d3dc6ab78485f3d24a9", 206 | "5b2c67b5e0878c381608b8d8", 207 | "5b4933abf2b5f44e95de482a", 208 | "5b3b353d8d46a939f93524b9", 209 | "5acf8ca0f3d8a750097e4b15", 210 | "5ab8713ba3799a1d138bd69a", 211 | "5aa235f64a17b335eeaf9609", 212 | # "5aa0f9d7a9efce63548c69a1", 213 | "5a8315f624b8e938486e0bd8", 214 | "5a48c4e9c7dab83a7d7b5cc7", 215 | "59ecfd02e225f6492d20fcc9", 216 | "59f87d0bfa6280566fb38c9a", 217 | "59f363a8b45be22330016cad", 218 | "59f70ab1e5c5d366af29bf3e", 219 | "59e75a2ca9e91f2c5526005d", 220 | "5947719bf1b45630bd096665", 221 | "5947b62af1b45630bd0c2a02", 222 | "59056e6760bb961de55f3501", 223 | "58f7f7299f5b5647873cb110", 224 | "58cf4771d0f5fb221defe6da", 225 | "58d36897f387231e6c929903", 226 | "58c4bb4f4a69c55606122be4", 227 | ] 228 | 229 | 230 | def val_scans() -> List[str]: 231 | return [ 232 | "5bb7a08aea1cfa39f1a947ab", 233 | "5af02e904c8216544b4ab5a2", 234 | "5be883a4f98cee15019d5b83", 235 | "5b21e18c58e2823a67a10dd8", 236 | "5a7d3db14989e929563eb153", 237 | "5a48ba95c7dab83a7d7b44ed", 238 | "5be4ab93870d330ff2dce134", 239 | "5aa0f9d7a9efce63548c69a1", 240 | ] 241 | 242 | 243 | def test_scans() -> List[str]: 244 | return [ 245 | "5b7a3890fc8fcf6781e2593a", 246 | "5c189f2326173c3a09ed7ef3", 247 | "5b950c71608de421b1e7318f", 248 | "5a6400933d809f1d8200af15", 249 | "59d2657f82ca7774b1ec081d", 250 | "5ba19a8a360c7c30c1c169df", 251 | "59817e4a1bd4b175e7038d19", 252 | ] 253 | -------------------------------------------------------------------------------- /guided_mvs_lib/datasets/dtu_blended_mvs.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | from collections import defaultdict 4 | from typing import Callable, Dict, List, Literal, Optional, TypedDict 5 | 6 | import cv2 7 | import numpy as np 8 | import scipy.io 9 | import torch 10 | from PIL import Image 11 | from plyfile import PlyData 12 | from torch.utils.data import Dataset 13 | 14 | import guided_mvs_lib.datasets.blended_mvg_utils as mvg_utils 15 | import guided_mvs_lib.datasets.blended_mvs_utils as mvs_utils 16 | import guided_mvs_lib.datasets.dtu_utils as dtu_utils 17 | 18 | 19 | class MVSSample(TypedDict): 20 | imgs: List[Image.Image] 21 | intrinsics: List[np.ndarray] 22 | extrinsics: List[np.ndarray] 23 | depths: List[np.ndarray] 24 | ref_depth_min: float 25 | ref_depth_max: float 26 | ref_depth_values: np.ndarray 27 | filename: str 28 | scan_pcd: Optional[np.ndarray] 29 | scan_pcd_obs_mask: 
Optional[np.ndarray] 30 | scan_pcd_bounding_box: Optional[np.ndarray] 31 | scan_pcd_resolution: Optional[float] 32 | scan_pcd_ground_plane: Optional[np.ndarray] 33 | 34 | 35 | def _identity_fn(x: MVSSample) -> Dict: 36 | return x 37 | 38 | 39 | class MVSDataset(Dataset): 40 | def __init__( 41 | self, 42 | name: Literal["dtu_yao", "blended_mvs", "blended_mvg"], 43 | datapath: str, 44 | mode: Literal["train", "val", "test"], 45 | nviews: int = 5, 46 | ndepths: int = 192, 47 | robust_train: bool = False, 48 | transform: Callable[[Dict], Dict] = _identity_fn, 49 | ): 50 | super().__init__() 51 | assert mode in ["train", "val", "test"], "MVSDataset train, val or test" 52 | 53 | if name == "dtu_yao": 54 | self.ds_utils = dtu_utils 55 | elif name == "blended_mvs": 56 | self.ds_utils = mvs_utils 57 | elif name == "blended_mvg": 58 | self.ds_utils = mvg_utils 59 | else: 60 | raise ValueError("datasets supported: dtu_yao, blended_mvs, blended_mvg") 61 | 62 | self.datapath = datapath 63 | self.transform = transform 64 | self.name = name 65 | self.mode = mode 66 | self.nviews = nviews 67 | self.ndepths = ndepths 68 | self.robust_train = robust_train 69 | 70 | scans = { 71 | "train": self.ds_utils.train_scans, 72 | "val": self.ds_utils.val_scans, 73 | "test": self.ds_utils.test_scans, 74 | }[mode]() 75 | if scans is None: 76 | raise ValueError(f"{mode} not supported on dataset {self.name}") 77 | 78 | self.metas = self.ds_utils.build_list(self.datapath, scans, nviews) 79 | 80 | def __len__(self): 81 | return len(self.metas) 82 | 83 | def _resize_depth_dtu(self, depth, size): 84 | h, w, _ = depth.shape 85 | depth = cv2.resize(depth, (w // 2, h // 2), interpolation=cv2.INTER_NEAREST)[..., None] 86 | h, w, _ = depth.shape 87 | start_h, start_w = (h - size[0]) // 2, (w - size[1]) // 2 88 | depth = depth[start_h : start_h + size[0], start_w : start_w + size[1]] 89 | return depth 90 | 91 | def __getitem__(self, idx: int) -> MVSSample: 92 | 93 | # load info 94 | scan, light_idx, ref_view, src_views = self.metas[idx] 95 | if self.mode in ["train", "val"] and self.robust_train: 96 | num_src_views = len(src_views) 97 | index = random.sample(range(num_src_views), self.nviews - 1) 98 | view_ids = [ref_view] + [src_views[i] for i in index] 99 | else: 100 | view_ids = [ref_view] + src_views[: self.nviews - 1] 101 | 102 | # collect the reference depth, the images and meta for this example 103 | out = defaultdict( 104 | lambda: [], 105 | scan_pcd=None, 106 | scan_pcd_obs_mask=None, 107 | scan_pcd_bounding_box=None, 108 | scan_pcd_resolution=None, 109 | scan_pcd_ground_plane=None, 110 | ) 111 | for i, vid in enumerate(view_ids): 112 | datapaths = self.ds_utils.datapath_files( 113 | self.datapath, scan, vid, self.mode, light_idx 114 | ) 115 | out["imgs"].append(Image.open(datapaths["img"])) 116 | 117 | if datapaths["pcd"] is not None and i == 0: 118 | mesh = PlyData.read(datapaths["pcd"]) 119 | out["scan_pcd"] = np.stack( 120 | [mesh["vertex"]["x"], mesh["vertex"]["y"], mesh["vertex"]["z"]], -1 121 | ) 122 | obs_mask_data = scipy.io.loadmat(datapaths["obs_mask"]) 123 | out["scan_pcd_obs_mask"] = obs_mask_data["ObsMask"] 124 | out["scan_pcd_bounding_box"] = obs_mask_data["BB"] 125 | out["scan_pcd_resolution"] = obs_mask_data["Res"] 126 | out["scan_pcd_ground_plane"] = scipy.io.loadmat(datapaths["ground_plane"])["P"] 127 | 128 | # build proj matrix 129 | intrinsics, extrinsics, depth_min, depth_max = self.ds_utils.read_cam_file( 130 | datapaths["proj_mat"] 131 | ) 132 | if self.name == "dtu_yao" and self.mode != "test": 
133 | # preprocessed images loaded by dtu_yao train and val have been halved in dimension 134 | # and then cropped to (512, 640) 135 | intrinsics[:2] *= 0.5 136 | intrinsics[0, 2] *= 640 / 800 137 | intrinsics[1, 2] *= 512 / 600 138 | out["intrinsics"].append(intrinsics) 139 | out["extrinsics"].append(extrinsics) 140 | 141 | mask = self.ds_utils.read_depth_mask(datapaths["depth_mask"]) 142 | depth = self.ds_utils.read_depth(datapaths["depth"]) * mask 143 | 144 | if self.name == "dtu_yao" and self.mode != "test": 145 | # the original depth map of dtu must be brought to (512, 640) when in training phase 146 | mask = self._resize_depth_dtu(mask.astype(np.float32), (512, 640)) 147 | depth = self._resize_depth_dtu(depth, (512, 640)) 148 | 149 | out["depths"].append(depth) 150 | if i == 0: 151 | out["ref_depth_min"] = depth_min 152 | out["ref_depth_max"] = depth_max 153 | out["ref_depth_values"] = np.arange( 154 | depth_min, 155 | depth_max, 156 | (depth_max - depth_min) / self.ndepths, 157 | dtype=np.float32, 158 | ) 159 | 160 | out["filename"] = os.path.join(scan, "{}", f"{view_ids[0]:0>8}" + "{}") 161 | return self.transform(dict(out)) 162 | -------------------------------------------------------------------------------- /guided_mvs_lib/datasets/dtu_utils.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utility functions to load the DTU dataset 3 | """ 4 | 5 | from pathlib import Path 6 | from typing import Dict, Literal, Tuple, Union 7 | 8 | import numpy as np 9 | from PIL import Image 10 | 11 | from .utils import DatasetExamplePaths, read_pfm, save_pfm 12 | 13 | __all__ = [ 14 | "datapath_files", 15 | "read_cam_file", 16 | "read_depth_mask", 17 | "read_depth", 18 | "build_list", 19 | "train_scans", 20 | "val_scans", 21 | "test_scans", 22 | ] 23 | 24 | 25 | def datapath_files( 26 | datapath: Union[Path, str], 27 | scan: str, 28 | view_id: int, 29 | split: Literal["train", "test", "val"], 30 | light_id: int = None, 31 | ) -> DatasetExamplePaths: 32 | """ 33 | Takes in input the root of the DTU dataset and returns a dictionary containing 34 | the paths to the files of a single scan, with the speciied view_id and light_id 35 | (this last one is used only when ``split`` isn't test) 36 | 37 | Parameters 38 | ---------- 39 | datapath: path 40 | path to the root of the dtu dataset 41 | scan: str 42 | name of the used scan, say scan1 43 | view_id: int 44 | the index of the specific image in the scan 45 | split: train, val or test 46 | which split of DTU must be used to search the scan for 47 | light_id: int 48 | the index of the specific lightning condition index 49 | 50 | Returns 51 | ------- 52 | out: Dict[str, Path] 53 | returns a dictionary containing the paths taken into account 54 | """ 55 | assert split in ["train", "val", "test"] 56 | root = Path(datapath) 57 | if split in ["train", "val"]: 58 | root = root / "train_data" 59 | return { 60 | "img": root / f"Rectified/{scan}_train/rect_{view_id + 1:0>3}_{light_id}_r5000.png", 61 | "proj_mat": root / f"Cameras_1/{view_id:0>8}_cam.txt", 62 | "depth_mask": root / f"Depths_raw/{scan}/depth_visual_{view_id:0>4}.png", 63 | "depth": root / f"Depths_raw/{scan}/depth_map_{view_id:0>4}.pfm", 64 | "pcd": None, 65 | "obs_mask": None, 66 | "ground_plane": None, 67 | } 68 | else: 69 | scan_idx = int(scan[4:]) 70 | root = root / "test_data" 71 | return { 72 | "img": root / f"{scan}/images/{view_id:0>8}.jpg", 73 | "proj_mat": root / f"{scan}/cams_1/{view_id:0>8}_cam.txt", 74 | "depth_mask": root.parent 75 | / 
f"train_data/Depths_raw/{scan}/depth_visual_{view_id:0>4}.png", 76 | "depth": root.parent / f"train_data/Depths_raw/{scan}/depth_map_{view_id:0>4}.pfm", 77 | "pcd": root.parent / f"SampleSet/MVS Data/Points/stl/stl{scan_idx:0>3}_total.ply", 78 | "obs_mask": root.parent / f"SampleSet/MVS Data/ObsMask/ObsMask{scan_idx}_10.mat", 79 | "ground_plane": root.parent / f"SampleSet/MVS Data/ObsMask/Plane{scan_idx}.mat", 80 | } 81 | 82 | 83 | def read_cam_file(path: str) -> Tuple[np.ndarray, np.ndarray, float, float]: 84 | """ 85 | Reads a file containing the DTU camera intrinsics, extrinsics, max depth and 86 | min depth. 87 | 88 | Parameters 89 | ---------- 90 | path: str 91 | path of the source file (something like ../00000000_cam.txt) 92 | 93 | Returns 94 | ------- 95 | out: Tuple[np.ndarray, np.ndarray, float, float] 96 | respectively intrinsics, extrinsics, min depth and max depth 97 | """ 98 | with open(path) as f: 99 | lines = f.readlines() 100 | lines = [line.rstrip() for line in lines] 101 | extrinsics = np.fromstring(" ".join(lines[1:5]), dtype=np.float32, sep=" ").reshape((4, 4)) 102 | intrinsics = np.fromstring(" ".join(lines[7:10]), dtype=np.float32, sep=" ").reshape((3, 3)) 103 | depth_min = float(lines[11].split()[0]) 104 | depth_max = float(lines[11].split()[1]) 105 | return intrinsics, extrinsics, depth_min, depth_max 106 | 107 | 108 | def read_depth_mask(path: Union[str, Path]) -> np.ndarray: 109 | """ 110 | Loads the depth DTU depth mask 111 | """ 112 | img = np.array(Image.open(path))[..., None] > 10 113 | return img 114 | 115 | 116 | def read_depth(path: Union[str, Path]) -> np.ndarray: 117 | """ 118 | Loads the depth DTU depth mask 119 | """ 120 | depth = np.array(read_pfm(str(path))[0], dtype=np.float32) 121 | return depth 122 | 123 | 124 | # LIST OF SCANS 125 | 126 | 127 | def build_list(datapath: Union[str, Path], scans: list, nviews: int): 128 | metas = [] 129 | 130 | datapath = Path(datapath) 131 | for scan in scans: 132 | 133 | # find pair file 134 | pair_file_test = datapath / f"test_data/{scan}/pair.txt" 135 | pair_file_train = datapath / "train_data/Cameras_1/pair.txt" 136 | if not pair_file_test.exists(): 137 | if not pair_file_train.exists(): 138 | raise ValueError(f"scan {scan} not found") 139 | else: 140 | pair_file = pair_file_train 141 | split = "train" 142 | else: 143 | pair_file = pair_file_test 144 | split = "test" 145 | 146 | # use pair file 147 | with open(pair_file, "rt") as f: 148 | num_viewpoint = int(f.readline()) 149 | for _ in range(num_viewpoint): 150 | ref_view = int(f.readline().rstrip()) 151 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]] 152 | 153 | if split == "train" and len(src_views) >= nviews: 154 | for light_idx in range(7): 155 | metas.append((scan, light_idx, ref_view, src_views)) 156 | else: 157 | metas.append((scan, None, ref_view, src_views)) 158 | 159 | return metas 160 | 161 | 162 | def train_scans(): 163 | return [ 164 | "scan2", 165 | "scan6", 166 | "scan7", 167 | "scan8", 168 | "scan14", 169 | "scan16", 170 | "scan18", 171 | "scan19", 172 | "scan20", 173 | "scan22", 174 | "scan30", 175 | "scan31", 176 | "scan36", 177 | "scan39", 178 | "scan41", 179 | "scan42", 180 | "scan44", 181 | "scan45", 182 | "scan46", 183 | "scan47", 184 | "scan50", 185 | "scan51", 186 | "scan52", 187 | "scan53", 188 | "scan55", 189 | "scan57", 190 | "scan58", 191 | "scan60", 192 | "scan61", 193 | "scan63", 194 | "scan64", 195 | "scan65", 196 | "scan68", 197 | "scan69", 198 | "scan70", 199 | "scan71", 200 | "scan72", 201 | "scan74", 202 | 
"scan76", 203 | "scan83", 204 | "scan84", 205 | "scan85", 206 | "scan87", 207 | "scan88", 208 | "scan89", 209 | "scan90", 210 | "scan91", 211 | "scan92", 212 | "scan93", 213 | "scan94", 214 | "scan95", 215 | "scan96", 216 | "scan97", 217 | "scan98", 218 | "scan99", 219 | "scan100", 220 | "scan101", 221 | "scan102", 222 | "scan103", 223 | "scan104", 224 | "scan105", 225 | "scan107", 226 | "scan108", 227 | "scan109", 228 | "scan111", 229 | "scan112", 230 | "scan113", 231 | "scan115", 232 | "scan116", 233 | "scan119", 234 | "scan120", 235 | "scan121", 236 | "scan122", 237 | "scan123", 238 | "scan124", 239 | "scan125", 240 | "scan126", 241 | "scan127", 242 | "scan128", 243 | ] 244 | 245 | 246 | def val_scans(): 247 | return [ 248 | "scan3", 249 | "scan5", 250 | "scan17", 251 | "scan21", 252 | "scan28", 253 | "scan35", 254 | "scan37", 255 | "scan38", 256 | "scan40", 257 | "scan43", 258 | "scan56", 259 | "scan59", 260 | "scan66", 261 | "scan67", 262 | "scan82", 263 | "scan86", 264 | "scan106", 265 | "scan117", 266 | ] 267 | 268 | 269 | def test_scans(): 270 | return [ 271 | "scan1", 272 | "scan4", 273 | "scan9", 274 | "scan10", 275 | "scan11", 276 | "scan12", 277 | "scan13", 278 | "scan15", 279 | "scan23", 280 | "scan24", 281 | "scan29", 282 | "scan32", 283 | "scan33", 284 | "scan34", 285 | "scan48", 286 | "scan49", 287 | "scan62", 288 | "scan75", 289 | "scan77", 290 | "scan110", 291 | "scan114", 292 | "scan118", 293 | ] 294 | -------------------------------------------------------------------------------- /guided_mvs_lib/datasets/sample_preprocess.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from typing import Dict, Literal, Tuple 3 | 4 | import cv2 5 | import numpy as np 6 | import torch 7 | import torchvision.transforms.functional as F 8 | from PIL import Image 9 | 10 | from guided_mvs_lib.datasets.dtu_blended_mvs import MVSSample 11 | 12 | __all__ = ["MVSSampleTransform"] 13 | 14 | 15 | class MVSSampleTransform: 16 | """ 17 | Preprocess each sample computing each view at 4 different scales and converts the 18 | sample in torch Tensor and compute hints in various ways 19 | """ 20 | 21 | def __init__( 22 | self, 23 | generate_hints: Literal[ 24 | "not_guided", "guided", "mvguided", "mvguided_filtered" 25 | ] = "not_guided", 26 | hints_perc: float = 0.01, 27 | filtering_window: Tuple[int, int] = (9, 9), 28 | ): 29 | assert generate_hints in ["not_guided", "guided", "mvguided", "mvguided_filtered"] 30 | 31 | self.generate_hints = generate_hints 32 | self._hints_perc = hints_perc 33 | self._height_bin = (filtering_window[0] // 2) - 1 34 | self._width_bin = (filtering_window[1] // 2) - 1 35 | 36 | def _generate_hints(self, sample: Dict) -> torch.Tensor: 37 | 38 | if self.generate_hints == "guided": 39 | # use only the ref depth map 40 | depth = sample["depths"][0] 41 | valid_hints = (depth > 0) & (np.random.rand(*depth.shape) <= self._hints_perc) 42 | hints = depth * valid_hints 43 | return torch.from_numpy(hints).permute(2, 0, 1) 44 | elif self.generate_hints in ["mvguided", "mvguided_filtered"]: 45 | 46 | hints_perc = self._hints_perc 47 | 48 | # extract hints from each depth map 49 | unwarped_hints = [] 50 | for depth in sample["depths"]: 51 | valid_hints = (depth > 0) & (np.random.rand(*depth.shape) <= hints_perc) 52 | unwarped_hints.append(torch.from_numpy(depth * valid_hints).permute(2, 0, 1)[None]) 53 | 54 | # project hints from other depth maps in the ref one 55 | ref_hints, ref_mask = unwarped_hints[0], 
unwarped_hints[0] > 0 56 | for i in range(1, len(unwarped_hints)): 57 | proj_mat_ref = sample["extrinsics"][0].copy() 58 | proj_mat_ref[:3, :4] = sample["intrinsics"][0] @ proj_mat_ref[:3, :4] 59 | proj_mat_src = sample["extrinsics"][i].copy() 60 | proj_mat_src[:3, :4] = sample["intrinsics"][i] @ proj_mat_src[:3, :4] 61 | 62 | warped_hints, warped_mask = hints_homo_warping( 63 | unwarped_hints[i], 64 | (unwarped_hints[i] > 0).to(torch.float32), 65 | torch.from_numpy(proj_mat_ref[None]), 66 | torch.from_numpy(proj_mat_src[None]), 67 | ) 68 | assign_mask = (ref_mask * (1 - warped_mask)) == 0 69 | ref_hints[assign_mask] = warped_hints[assign_mask] 70 | ref_mask = ref_mask | warped_mask.to(torch.bool) 71 | 72 | # filter if required 73 | ref_hints = ref_hints[0] * ref_mask[0] 74 | if self.generate_hints == "mvguided_filtered": 75 | hints_mask = outlier_removal_mask( 76 | ref_hints.numpy().transpose([1, 2, 0]), 77 | sample["intrinsics"][0], 78 | height_bin=self._height_bin, 79 | width_bin=self._width_bin, 80 | ) 81 | ref_hints[~torch.from_numpy(hints_mask).to(torch.bool).permute(2, 0, 1)] = 0 82 | 83 | return ref_hints 84 | 85 | def _split_pad(self, pad): 86 | if pad % 2 == 0: 87 | return pad // 2, pad // 2 88 | else: 89 | pad_1 = pad // 2 90 | pad_2 = (pad // 2) + 1 91 | return pad_1, pad_2 92 | 93 | def _pad_to_div_by(self, x, *, div_by=8): 94 | 95 | # compute padding 96 | if isinstance(x, Image.Image): 97 | w, h = x.size 98 | elif isinstance(x, np.ndarray): 99 | h, w, _ = x.shape 100 | else: 101 | raise ValueError("Image or np.ndarray") 102 | 103 | new_h = int(np.ceil(h / div_by)) * div_by 104 | new_w = int(np.ceil(w / div_by)) * div_by 105 | pad_t, pad_b = self._split_pad(new_h - h) 106 | pad_l, pad_r = self._split_pad(new_w - w) 107 | 108 | # return PIL or np.ndarray 109 | if isinstance(x, Image.Image): 110 | return F.pad(x, (pad_l, pad_t, pad_r, pad_b)) 111 | elif isinstance(x, np.ndarray): 112 | return np.pad(x, [(pad_t, pad_b), (pad_l, pad_r), (0, 0)]) 113 | 114 | def __call__(self, sample: MVSSample) -> Dict: 115 | 116 | # padding 117 | imgs = [] 118 | for img, intrins in zip(sample["imgs"], sample["intrinsics"]): 119 | 120 | # pad the image 121 | w, h = img.size 122 | imgs.append(self._pad_to_div_by(img, div_by=32)) 123 | w_new, h_new = imgs[-1].size 124 | 125 | # adapt intrinsics 126 | pad_w = (w_new - w) / 2 127 | ratio = (w + pad_w) / w 128 | intrins[0, 2] = intrins[0, 2] * ratio 129 | pad_h = (h_new - h) / 2 130 | ratio = (h + pad_h) / h 131 | intrins[1, 2] = intrins[1, 2] * ratio 132 | 133 | sample["imgs"] = imgs 134 | sample["depths"] = [self._pad_to_div_by(x, div_by=32) for x in sample["depths"]] 135 | 136 | # compute downsampled stages 137 | imgs, proj_matrices = defaultdict(lambda: []), defaultdict(lambda: []) 138 | depths = {} 139 | 140 | w, h = sample["imgs"][0].size 141 | for i in range(4): 142 | dsize = (w // (2 ** i), h // (2 ** i)) 143 | 144 | for img in sample["imgs"]: 145 | imgs[f"stage_{i}"].append(F.to_tensor(cv2.resize(np.array(img), dsize))) 146 | 147 | for extrinsics, intrinsics in zip(sample["extrinsics"], sample["intrinsics"]): 148 | proj_mat = extrinsics.copy() 149 | intrinsics_copy = intrinsics.copy() 150 | intrinsics_copy[:2, :] = intrinsics_copy[:2, :] / (2 ** i) 151 | proj_mat[:3, :4] = intrinsics_copy @ proj_mat[:3, :4] 152 | proj_matrices[f"stage_{i}"].append(torch.from_numpy(proj_mat)) 153 | 154 | depths[f"stage_{i}"] = F.to_tensor(cv2.resize(sample["depths"][0], dsize)) 155 | 156 | # return result 157 | out = { 158 | "imgs": {k: torch.stack(v) for k, v 
in imgs.items()}, 159 | "depth": depths, 160 | "proj_matrices": {k: torch.stack(v) for k, v in proj_matrices.items()}, 161 | "intrinsics": torch.from_numpy(np.stack(sample["intrinsics"])), 162 | "extrinsics": torch.from_numpy(np.stack(sample["extrinsics"])), 163 | "depth_min": sample["ref_depth_min"], 164 | "depth_max": sample["ref_depth_max"], 165 | "depth_values": torch.from_numpy(sample["ref_depth_values"]), 166 | "filename": sample["filename"], 167 | } 168 | 169 | # add hints if requested 170 | if self.generate_hints != "not_guided": 171 | out["hints"] = self._generate_hints(sample) 172 | 173 | # add fields pcd if available 174 | for key in [ 175 | "scan_pcd", 176 | "scan_pcd_obs_mask", 177 | "scan_pcd_bounding_box", 178 | "scan_pcd_resolution", 179 | "scan_pcd_ground_plane", 180 | ]: 181 | if key in sample and sample[key] is not None: 182 | out[key] = torch.from_numpy(sample[key]) 183 | 184 | return out 185 | 186 | 187 | def _get_all_points(lidar: np.ndarray, intrinsic: np.ndarray) -> Tuple: 188 | 189 | lidar_32 = np.squeeze(lidar).astype(np.float32) 190 | height, width = np.shape(lidar_32) 191 | x_axis = np.arange(width).reshape(width, 1) 192 | x_image = np.tile(x_axis, height) 193 | x_image = np.transpose(x_image) 194 | y_axis = np.arange(height).reshape(height, 1) 195 | y_image = np.tile(y_axis, width) 196 | z_image = np.ones((height, width)) 197 | image_coor_tensor = ( 198 | np.asarray([x_image, y_image, z_image]).astype(np.float32).transpose([1, 0, 2]) 199 | ) 200 | 201 | intrinsic = np.reshape(intrinsic, [3, 3]).astype(np.float32) 202 | intrinsic_inverse = np.linalg.inv(intrinsic) 203 | points_homo = np.matmul(intrinsic_inverse, image_coor_tensor) 204 | 205 | lidar_32 = np.reshape(lidar_32, [height, 1, width]) 206 | points_homo = points_homo * lidar_32 207 | extra_image = np.ones((height, width)).astype(np.float32) 208 | extra_image = np.reshape(extra_image, [height, 1, width]) 209 | points_homo = np.concatenate([points_homo, extra_image], axis=1) 210 | 211 | extrinsic_v_2_c = [ 212 | [0.007, -1, 0, 0], 213 | [0.0148, 0, -1, -0.076], 214 | [1, 0, 0.0148, -0.271], 215 | [0, 0, 0, 1], 216 | ] 217 | extrinsic_v_2_c = np.reshape(extrinsic_v_2_c, [4, 4]).astype(np.float32) 218 | extrinsic_c_2_v = np.linalg.inv(extrinsic_v_2_c) 219 | points_lidar = np.matmul(extrinsic_c_2_v, points_homo) 220 | 221 | mask = np.squeeze(lidar) > 0.1 222 | total_points = [ 223 | points_lidar[:, 0, :][mask], 224 | points_lidar[:, 1, :][mask], 225 | points_lidar[:, 2, :][mask], 226 | ] 227 | total_points = np.asarray(total_points) 228 | total_points = np.transpose(total_points) 229 | 230 | return total_points, x_image[mask], y_image[mask], x_image, y_image 231 | 232 | 233 | def _do_range_projection_try( 234 | points: np.ndarray, 235 | fov_up: float = 3.0, 236 | fov_down: float = -18.0, 237 | ) -> Tuple[np.ndarray, np.ndarray]: 238 | # for each point, where it is in the range image 239 | proj_x = np.zeros((0, 1), dtype=np.float32) # [m, 1]: x 240 | proj_y = np.zeros((0, 1), dtype=np.float32) # [m, 1]: y 241 | 242 | # laser parameters 243 | fov_up = fov_up / 180.0 * np.pi # field of view up in rad 244 | fov_down = fov_down / 180.0 * np.pi # field of view down in rad 245 | fov = abs(fov_down) + abs(fov_up) # get field of view total in rad 246 | 247 | # get depth of all points 248 | depth = np.linalg.norm(points, 2, axis=1) 249 | 250 | # get scan components 251 | scan_x = points[:, 0] 252 | scan_y = points[:, 1] 253 | scan_z = points[:, 2] 254 | 255 | # get angles of all points 256 | yaw = -np.arctan2(scan_y, 
scan_x) 257 | pitch = np.arcsin(scan_z / depth) 258 | 259 | # get projections in image coords 260 | proj_x = 0.5 * (yaw / np.pi + 1.0) # in [0.0, 1.0] 261 | 262 | proj_y = 1.0 - (pitch + abs(fov_down)) / fov # in [0.0, 1.0] 263 | return proj_x, proj_y 264 | 265 | 266 | def _compute_trunck(v: np.ndarray, height_bin: int = 4, width_bin: int = 4) -> np.ndarray: 267 | v = np.squeeze(v) 268 | v_trunck = np.lib.stride_tricks.sliding_window_view( 269 | np.pad(v, [(height_bin, height_bin), (width_bin, width_bin)]), 270 | (height_bin * 2 + 1, width_bin * 2 + 1), 271 | ).reshape(v.shape[0], v.shape[1], (height_bin * 2 + 1) * (height_bin * 2 + 1)) 272 | return v_trunck 273 | 274 | 275 | def _compute_residual(v, height_bin, width_bin): 276 | v_trunck = _compute_trunck(v, height_bin, width_bin) 277 | residual = np.squeeze(v)[..., None] - v_trunck 278 | return residual 279 | 280 | 281 | def outlier_removal_mask( 282 | lidar: np.ndarray, intrinsic: np.ndarray, height_bin: int = 4, width_bin: int = 4 283 | ) -> np.ndarray: 284 | 285 | height, width, _ = lidar.shape 286 | 287 | total_points, x_indices, y_indices, width_image, height_image = _get_all_points( 288 | lidar, intrinsic 289 | ) 290 | proj_x, proj_y = _do_range_projection_try(total_points) 291 | 292 | project_x, project_y = np.zeros((2, height, width, 1)) 293 | project_x[y_indices, x_indices, 0] = proj_x 294 | project_y[y_indices, x_indices, 0] = proj_y 295 | 296 | project_x_trunck = _compute_trunck(project_x, height_bin, width_bin) 297 | project_x_residual = project_x - project_x_trunck 298 | 299 | project_y_trunck = _compute_trunck(project_y, height_bin, width_bin) 300 | project_y_residual = project_y - project_y_trunck 301 | 302 | height_image_trunck = _compute_trunck(height_image, height_bin, width_bin) 303 | height_image_residual = height_image[..., None] - height_image_trunck 304 | 305 | width_image_trunck = _compute_trunck(width_image, height_bin, width_bin) 306 | width_image_residual = width_image[..., None] - width_image_trunck 307 | 308 | lidar_trunck = _compute_trunck(lidar, height_bin, width_bin) 309 | zero_mask = np.logical_and(lidar > 0.1, lidar_trunck > 0.1) 310 | 311 | x_mask = np.logical_and( 312 | np.logical_or( 313 | np.logical_and(project_x_residual > 0.0000, width_image_residual <= 0), 314 | np.logical_and(project_x_residual < 0.0000, width_image_residual >= 0), 315 | ), 316 | zero_mask, 317 | ) 318 | 319 | y_mask = np.logical_and( 320 | np.logical_or( 321 | np.logical_and(project_y_residual > 0, height_image_residual <= 0), 322 | np.logical_and(project_y_residual < 0, height_image_residual >= 0), 323 | ), 324 | zero_mask, 325 | ) 326 | 327 | lidar_residual = lidar - lidar_trunck 328 | lidar_mask = np.logical_and(lidar_residual > 3.0, lidar > 0.01) 329 | 330 | final_mask = np.logical_and(lidar_mask, np.logical_or(x_mask, y_mask)) 331 | final_mask = np.squeeze(final_mask) 332 | final_mask = np.sum(final_mask, axis=-1, keepdims=True) == 0 333 | 334 | return final_mask 335 | 336 | 337 | def hints_homo_warping(src_hints, src_validhints, src_proj, ref_proj): 338 | # src_fea: [B, 1, H, W] 339 | # src_proj: [B, 4, 4] 340 | # ref_proj: [B, 4, 4] 341 | # depth_values: [B, Ndepth] 342 | # out: [B, C, Ndepth, H, W] 343 | batch = src_hints.shape[0] 344 | height, width = src_hints.shape[2], src_hints.shape[3] 345 | warped_hints = torch.zeros_like(src_hints) 346 | warped_validhints = torch.zeros_like(src_validhints) 347 | 348 | with torch.no_grad(): 349 | proj = torch.matmul(src_proj, torch.inverse(ref_proj)) 350 | rot = proj[:, :3, :3] # 
[B,3,3] 351 | trans = proj[:, :3, 3:4] # [B,3,1] 352 | 353 | y, x = torch.meshgrid( 354 | [ 355 | torch.arange(0, height, dtype=torch.float32, device=src_hints.device), 356 | torch.arange(0, width, dtype=torch.float32, device=src_hints.device), 357 | ] 358 | ) 359 | y, x = y.contiguous(), x.contiguous() 360 | y, x = y.view(height * width), x.view(height * width) 361 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] 362 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W] 363 | rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W] 364 | rot_depth_xyz = rot_xyz.unsqueeze(2) * src_hints.view( 365 | batch, 1, 1, -1 366 | ) # [B, 3, Ndepth, H*W] 367 | proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W] 368 | 369 | proj_x = proj_xyz[:, 0, :, :].view(batch, height, width) 370 | proj_y = proj_xyz[:, 1, :, :].view(batch, height, width) 371 | proj_z = proj_xyz[:, 2, :, :].view(batch, height, width) 372 | 373 | proj_x = torch.clamp(np.round(proj_x / proj_z).to(torch.int64), min=0, max=width - 1) 374 | proj_y = torch.clamp(np.round(proj_y / proj_z).to(torch.int64), min=0, max=height - 1) 375 | proj_z = proj_z.unsqueeze(-1) 376 | 377 | warped_hints = warped_hints.squeeze(1).unsqueeze(-1) 378 | warped_validhints = warped_validhints.squeeze(1).unsqueeze(-1) 379 | src_validhints = src_validhints.squeeze(1).unsqueeze(-1) 380 | 381 | # forward warping (will it work?) 382 | for i in range(warped_hints.shape[0]): 383 | warped_hints[i][proj_y, proj_x] = -proj_z[i] 384 | warped_validhints[i][proj_y, proj_x] = -src_validhints[i] 385 | warped_hints *= -1 386 | warped_validhints *= -1 * (warped_hints > 0) 387 | 388 | return warped_hints.permute(0, 3, 1, 2), warped_validhints.permute( 389 | 0, 3, 1, 2 390 | ) # , warped_validhints 391 | -------------------------------------------------------------------------------- /guided_mvs_lib/datasets/utils.py: -------------------------------------------------------------------------------- 1 | import re 2 | import sys 3 | from pathlib import Path 4 | from typing import Optional, Tuple, TypedDict 5 | 6 | import numpy as np 7 | 8 | __all__ = ["DatasetInstancePaths", "read_pfm", "save_pfm"] 9 | 10 | 11 | class DatasetExamplePaths(TypedDict): 12 | img: Path 13 | proj_mat: Path 14 | depth_mask: Path 15 | depth: Path 16 | pcd: Optional[Path] 17 | obs_mask: Optional[Path] 18 | ground_plane: Optional[Path] 19 | 20 | 21 | def read_pfm(filename: str) -> Tuple[np.ndarray, float]: 22 | """Read a depth map from a .pfm file 23 | 24 | Args: 25 | filename: .pfm file path string 26 | 27 | Returns: 28 | data: array of shape (H, W, C) representing loaded depth map 29 | scale: float to recover actual depth map pixel values 30 | """ 31 | file = open(filename, "rb") # treat as binary and read-only 32 | color = None 33 | width = None 34 | height = None 35 | scale = None 36 | endian = None 37 | 38 | header = file.readline().decode("utf-8").rstrip() 39 | if header == "PF": 40 | color = True 41 | elif header == "Pf": # depth is Pf 42 | color = False 43 | else: 44 | raise Exception("Not a PFM file.") 45 | 46 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("utf-8")) 47 | if dim_match: 48 | width, height = map(int, dim_match.groups()) 49 | else: 50 | raise Exception("Malformed PFM header.") 51 | 52 | scale = float(file.readline().rstrip()) 53 | if scale < 0: # little-endian 54 | endian = "<" 55 | scale = -scale 56 | else: 57 | endian = ">" # big-endian 58 | 59 | data = np.fromfile(file, endian + "f") 60 | shape = (height, width, 3) if color else 
(height, width, 1) 61 | 62 | data = np.reshape(data, shape) 63 | data = np.flipud(data) 64 | file.close() 65 | return data, scale 66 | 67 | 68 | def save_pfm(filename: str, image: np.ndarray, scale: float = 1) -> None: 69 | """Save a depth map from a .pfm file 70 | 71 | Args: 72 | filename: output .pfm file path string, 73 | image: depth map to save, of shape (H,W) or (H,W,C) 74 | scale: scale parameter to save 75 | """ 76 | file = open(filename, "wb") 77 | color = None 78 | 79 | image = np.flipud(image) 80 | 81 | if image.dtype.name != "float32": 82 | raise Exception("Image dtype must be float32.") 83 | 84 | if len(image.shape) == 3 and image.shape[2] == 3: # color image 85 | color = True 86 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale 87 | color = False 88 | else: 89 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.") 90 | 91 | file.write("PF\n".encode("utf-8") if color else "Pf\n".encode("utf-8")) 92 | file.write("{} {}\n".format(image.shape[1], image.shape[0]).encode("utf-8")) 93 | 94 | endian = image.dtype.byteorder 95 | 96 | if endian == "<" or endian == "=" and sys.byteorder == "little": 97 | scale = -scale 98 | 99 | file.write(("%f\n" % scale).encode("utf-8")) 100 | 101 | image.tofile(file) 102 | file.close() 103 | -------------------------------------------------------------------------------- /guided_mvs_lib/models/__init__.py: -------------------------------------------------------------------------------- 1 | from types import SimpleNamespace 2 | from typing import Dict, Optional, Union 3 | 4 | import matplotlib.pyplot as plt 5 | import torch 6 | from pytorch_lightning import LightningModule 7 | from torch import optim 8 | from torch.optim.lr_scheduler import MultiStepLR 9 | 10 | from . 
import cas_mvsnet, d2hc_rmvsnet, mvsnet, patchmatchnet, ucsnet 11 | 12 | __all__ = ["MVSModel", "build_network"] 13 | 14 | _NETWORKS = { 15 | "mvsnet": mvsnet.NetBuilder, 16 | "cas_mvsnet": cas_mvsnet.NetBuilder, 17 | "ucsnet": ucsnet.NetBuilder, 18 | "d2hc_rmvsnet": d2hc_rmvsnet.NetBuilder, 19 | "patchmatchnet": patchmatchnet.NetBuilder, 20 | } 21 | 22 | 23 | def build_network(name: str, args: SimpleNamespace): 24 | try: 25 | return _NETWORKS[name](args) 26 | except KeyError: 27 | raise ValueError("network name in {}".format(", ".join(_NETWORKS.keys()))) 28 | 29 | 30 | class MVSModel(LightningModule): 31 | def __init__( 32 | self, 33 | *, 34 | args: SimpleNamespace, 35 | mlflow_run_id: Optional[str] = None, 36 | v_num: Optional[str] = None, # (experiment version to show in tqdm) 37 | ): 38 | super().__init__() 39 | 40 | # instance the used model 41 | self.model = build_network(args.model, args) 42 | self.mlflow_run_id = mlflow_run_id 43 | self.loss_fn = self.model.loss 44 | 45 | # save train parameters 46 | hparams = dict( 47 | model=args.model, 48 | **{name: value for name, value in args.train.__dict__.items()}, 49 | ) 50 | if hasattr(self.model, "hparams"): 51 | hparams = dict(**hparams, **self.model.hparams) 52 | 53 | self.save_hyperparameters(hparams) 54 | 55 | # utils 56 | self._is_val = False 57 | self.v_num = v_num 58 | 59 | def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]: 60 | progress_bar_dict = super().get_progress_bar_dict() 61 | progress_bar_dict["v_num"] = self.v_num 62 | return progress_bar_dict 63 | 64 | def forward(self, batch: dict): 65 | return self.model(batch) 66 | 67 | def configure_optimizers(self): 68 | optimizer = optim.Adam( 69 | self.model.parameters(), 70 | lr=self.hparams.lr, 71 | betas=(0.9, 0.999), 72 | weight_decay=self.hparams.weight_decay, 73 | ) 74 | 75 | if self.hparams.epochs_lr_decay is not None: 76 | scheduler = MultiStepLR( 77 | optimizer, 78 | self.hparams.epochs_lr_decay, 79 | gamma=self.hparams.epochs_lr_gamma, 80 | ) 81 | return [optimizer], [scheduler] 82 | else: 83 | return optimizer 84 | 85 | def _compute_masks(self, depth_dict: Dict) -> Dict: 86 | return {k: (v > 0).to(torch.float32) for k, v in depth_dict.items()} 87 | 88 | def training_step(self, batch, _): 89 | self._is_val = False 90 | 91 | outputs = self.model(batch) 92 | loss = self.loss_fn( 93 | outputs["loss_data"], batch["depth"], self._compute_masks(batch["depth"]) 94 | ) 95 | self._log_metrics(batch, outputs, loss, "train") 96 | 97 | return loss 98 | 99 | def validation_step(self, batch, _): 100 | outputs = self.model(batch) 101 | loss = self.loss_fn( 102 | outputs["loss_data"], batch["depth"], self._compute_masks(batch["depth"]) 103 | ) 104 | self._log_metrics(batch, outputs, loss, "val") 105 | 106 | return loss 107 | 108 | def _log_metrics(self, batch, outputs, loss, stage): 109 | on_epoch = stage != "train" 110 | 111 | depth_gt = batch["depth"]["stage_0"] 112 | mask = (depth_gt > 0).to(torch.float32) 113 | depth_est = outputs["depth"]["stage_0"] 114 | 115 | # log scalar metrics 116 | self.log(f"{stage}/loss", loss, on_epoch=on_epoch, on_step=not on_epoch) 117 | self.log( 118 | f"{stage}/mae", 119 | torch.mean(torch.abs(depth_est[mask > 0.5] - depth_gt[mask > 0.5])), 120 | on_epoch=on_epoch, 121 | on_step=not on_epoch, 122 | ) 123 | for thresh in [1, 2, 3, 4, 8]: 124 | self.log( 125 | f"{stage}/perc_l1_upper_thresh_{thresh}", 126 | torch.mean( 127 | (torch.abs(depth_est[mask > 0.5] - depth_gt[mask > 0.5]) > thresh).float() 128 | ), 129 | on_epoch=on_epoch, 130 | 
on_step=not on_epoch, 131 | ) 132 | 133 | # logs only the first image for val, one every 5000 steps for train 134 | if (stage == "train" and self.global_step % 5000 == 0) or ( 135 | stage == "val" and not self._is_val 136 | ): 137 | if stage == "val": 138 | self._is_val = True 139 | 140 | if self.mlflow_run_id != None: 141 | for name, map_ in [("depth_pred", depth_est), ("depth_gt", depth_gt)]: 142 | 143 | fig = plt.figure(tight_layout=True) 144 | ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0]) 145 | ax.set_axis_off() 146 | fig.add_axes(ax) 147 | ax.imshow( 148 | (map_[0] * mask[0]).permute(1, 2, 0).detach().cpu().numpy(), cmap="magma_r" 149 | ) 150 | self.logger.experiment.log_figure( 151 | self.mlflow_run_id, fig, f"{stage}/{name}-{self.global_step}.jpg" 152 | ) 153 | plt.close(fig) 154 | 155 | self.logger.experiment.log_image( 156 | self.mlflow_run_id, 157 | batch["imgs"]["stage_0"][0, 0].permute(1, 2, 0).cpu().numpy(), 158 | f"{stage}/image-{self.global_step}.jpg", 159 | ) 160 | -------------------------------------------------------------------------------- /guided_mvs_lib/models/cas_mvsnet/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from types import SimpleNamespace 3 | from typing import Any, Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch import Tensor 8 | 9 | from .cas_mvsnet import CascadeMVSNet, cas_mvsnet_loss 10 | 11 | __all__ = ["cas_mvsnet_loss", "CascadeMVSNet", "NetBuilder", "SimpleIntefaceNet"] 12 | 13 | 14 | class SimpleInterfaceNet(nn.Module): 15 | """ 16 | Simple common interface to call the pretrained models 17 | """ 18 | 19 | def __init__(self, *args, **kwargs): 20 | super().__init__() 21 | self.model = CascadeMVSNet(*args, **kwargs) 22 | self.all_outputs: dict[str, Any] = {} 23 | 24 | def forward( 25 | self, 26 | imgs: Tensor, 27 | intrinsics: Tensor, 28 | extrinsics: Tensor, 29 | depth_values: Tensor, 30 | hints: Optional[Tensor] = None, 31 | ): 32 | with warnings.catch_warnings(): 33 | warnings.simplefilter("ignore", UserWarning) 34 | 35 | # compute poses 36 | proj_matrices = {} 37 | for i in range(4): 38 | proj_mat = extrinsics.clone() 39 | intrinsics_copy = intrinsics.clone() 40 | intrinsics_copy[..., :2, :] = intrinsics_copy[..., :2, :] / (2 ** i) 41 | proj_mat[..., :3, :4] = intrinsics_copy @ proj_mat[..., :3, :4] 42 | proj_matrices[f"stage_{i}"] = proj_mat 43 | 44 | # validhints 45 | validhints = None 46 | if hints is not None: 47 | validhints = (hints > 0).to(torch.float32) 48 | 49 | # call 50 | out = self.model(imgs, proj_matrices, depth_values, hints, validhints) 51 | self.all_outputs = out 52 | return out["depth"]["stage_0"] 53 | 54 | 55 | class NetBuilder(nn.Module): 56 | def __init__(self, args: SimpleNamespace): 57 | super().__init__() 58 | self.model = CascadeMVSNet() 59 | 60 | def loss_func(loss_data, depth_gt, mask): 61 | return cas_mvsnet_loss( 62 | loss_data["depth"], 63 | depth_gt, 64 | mask, 65 | ) 66 | 67 | self.loss = loss_func 68 | 69 | def forward(self, batch: dict): 70 | 71 | hints, validhints = None, None 72 | if "hints" in batch: 73 | hints = batch["hints"] 74 | validhints = (hints > 0).to(torch.float32) 75 | 76 | return self.model( 77 | batch["imgs"]["stage_0"], 78 | batch["proj_matrices"], 79 | batch["depth_values"], 80 | hints, 81 | validhints, 82 | ) 83 | -------------------------------------------------------------------------------- /guided_mvs_lib/models/cas_mvsnet/cas_mvsnet.py: 
-------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .module import * 6 | 7 | Align_Corners_Range = False 8 | 9 | 10 | class DepthNet(nn.Module): 11 | def __init__(self): 12 | super(DepthNet, self).__init__() 13 | 14 | def forward( 15 | self, 16 | features, 17 | proj_matrices, 18 | depth_values, 19 | num_depth, 20 | cost_regularization, 21 | prob_volume_init=None, 22 | hints=None, 23 | validhints=None, 24 | scale=1, 25 | ): 26 | proj_matrices = torch.unbind(proj_matrices, 1) 27 | assert len(features) == len( 28 | proj_matrices 29 | ), "Different number of images and projection matrices" 30 | assert depth_values.shape[1] == num_depth, "depth_values.shape[1]:{} num_depth:{}".format( 31 | depth_values.shapep[1], num_depth 32 | ) 33 | num_views = len(features) 34 | 35 | # step 1. feature extraction 36 | # in: images; out: 32-channel feature maps 37 | ref_feature, src_features = features[0], features[1:] 38 | ref_proj, src_projs = proj_matrices[0], proj_matrices[1:] 39 | 40 | # step 2. differentiable homograph, build cost volume 41 | ref_volume = ref_feature.unsqueeze(2).repeat(1, 1, num_depth, 1, 1) 42 | volume_sum = ref_volume 43 | volume_sq_sum = ref_volume ** 2 44 | del ref_volume 45 | for src_fea, src_proj in zip(src_features, src_projs): 46 | warped_volume = homo_warping(src_fea, src_proj, ref_proj, depth_values) 47 | 48 | if self.training: 49 | volume_sum = volume_sum + warped_volume 50 | volume_sq_sum = volume_sq_sum + warped_volume ** 2 51 | else: 52 | # TODO: this is only a temporal solution to save memory, better way? 53 | volume_sum += warped_volume 54 | volume_sq_sum += warped_volume.pow_( 55 | 2 56 | ) # the memory of warped_volume has been modified 57 | del warped_volume 58 | 59 | # aggregate multiple feature volumes by variance 60 | volume_variance = volume_sq_sum.div_(num_views).sub_(volume_sum.div_(num_views).pow_(2)) 61 | 62 | # TODO cost-volume modulation 63 | if hints is not None and validhints is not None: 64 | batch_size, feats, height, width = ref_feature.shape 65 | GAUSSIAN_HEIGHT = 10.0 66 | GAUSSIAN_WIDTH = 1.0 67 | 68 | # image features are one fourth the original size: subsample the hints and divide them by four 69 | hints = hints 70 | hints = F.interpolate(hints, scale_factor=1 / scale, mode="nearest").unsqueeze(1) 71 | validhints = validhints 72 | validhints = F.interpolate( 73 | validhints, scale_factor=1 / scale, mode="nearest" 74 | ).unsqueeze(1) 75 | hints = hints * validhints 76 | 77 | # add feature and disparity dimensions to hints and validhints 78 | # and repeat their values along those dimensions, to obtain the same size as cost 79 | hints = hints.expand(-1, feats, num_depth, -1, -1) 80 | validhints = validhints.expand(-1, feats, num_depth, -1, -1) 81 | 82 | # create a tensor of the same size as cost, with disparities 83 | # between 0 and num_disp-1 along the disparity dimension 84 | depth_hyps = ( 85 | depth_values.unsqueeze(1).expand(batch_size, feats, -1, height, width).detach() 86 | ) 87 | volume_variance = volume_variance * ( 88 | (1 - validhints) 89 | + validhints 90 | * GAUSSIAN_HEIGHT 91 | * (1 - torch.exp(-((depth_hyps - hints) ** 2) / (2 * GAUSSIAN_WIDTH ** 2))) 92 | ) 93 | 94 | # step 3. 
cost volume regularization 95 | cost_reg = cost_regularization(volume_variance) 96 | # cost_reg = F.upsample(cost_reg, [num_depth * 4, img_height, img_width], mode='trilinear') 97 | prob_volume_pre = cost_reg.squeeze(1) 98 | 99 | if prob_volume_init is not None: 100 | prob_volume_pre += prob_volume_init 101 | 102 | prob_volume = F.softmax(prob_volume_pre, dim=1) 103 | depth = depth_regression(prob_volume, depth_values=depth_values) 104 | 105 | with torch.no_grad(): 106 | # photometric confidence 107 | prob_volume_sum4 = ( 108 | 4 109 | * F.avg_pool3d( 110 | F.pad(prob_volume.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)), 111 | (4, 1, 1), 112 | stride=1, 113 | padding=0, 114 | ).squeeze(1) 115 | ) 116 | depth_index = depth_regression( 117 | prob_volume, 118 | depth_values=torch.arange(num_depth, device=prob_volume.device, dtype=torch.float), 119 | ).long() 120 | depth_index = depth_index.clamp(min=0, max=num_depth - 1) 121 | photometric_confidence = torch.gather( 122 | prob_volume_sum4, 1, depth_index.unsqueeze(1) 123 | ).squeeze(1) 124 | 125 | return {"depth": depth, "photometric_confidence": photometric_confidence} 126 | 127 | 128 | class CascadeMVSNet(nn.Module): 129 | def __init__( 130 | self, 131 | refine=False, 132 | ndepths=[48, 32, 8], 133 | depth_interals_ratio=[4, 2, 1], 134 | share_cr=False, 135 | grad_method="detach", 136 | arch_mode="fpn", 137 | cr_base_chs=[8, 8, 8], 138 | ): 139 | super(CascadeMVSNet, self).__init__() 140 | self.refine = refine 141 | self.share_cr = share_cr 142 | self.ndepths = ndepths 143 | self.depth_interals_ratio = depth_interals_ratio 144 | self.grad_method = grad_method 145 | self.arch_mode = arch_mode 146 | self.cr_base_chs = cr_base_chs 147 | self.num_stage = len(ndepths) 148 | assert len(ndepths) == len(depth_interals_ratio) 149 | 150 | self.stage_infos = { 151 | "stage_2": { # was "stage1" 152 | "scale": 4.0, 153 | }, 154 | "stage_1": { # was "stage2" 155 | "scale": 2.0, 156 | }, 157 | "stage_0": { # was "stage3" 158 | "scale": 1.0, 159 | }, 160 | } 161 | 162 | self.feature = FeatureNet( 163 | base_channels=8, stride=4, num_stage=self.num_stage, arch_mode=self.arch_mode 164 | ) 165 | if self.share_cr: 166 | self.cost_regularization = CostRegNet( 167 | in_channels=self.feature.out_channels, base_channels=8 168 | ) 169 | else: 170 | self.cost_regularization = nn.ModuleList( 171 | [ 172 | CostRegNet( 173 | in_channels=self.feature.out_channels[i], base_channels=self.cr_base_chs[i] 174 | ) 175 | for i in [2, 1, 0] 176 | ] 177 | ) # range(self.num_stage)]) 178 | if self.refine: 179 | self.refine_network = RefineNet() 180 | self.DepthNet = DepthNet() 181 | 182 | def forward(self, imgs, proj_matrices, depth_values, hints=None, validhints=None): 183 | depth_min = float(depth_values[0, 0].cpu().numpy()) 184 | depth_max = float(depth_values[0, -1].cpu().numpy()) 185 | depth_interval = (depth_max - depth_min) / depth_values.size(1) 186 | 187 | # step 1. 
feature extraction 188 | features = [] 189 | for nview_idx in range(imgs.size(1)): # imgs shape (B, N, C, H, W) 190 | img = imgs[:, nview_idx] 191 | features.append(self.feature(img)) 192 | 193 | outputs = {} 194 | depth, cur_depth = None, None 195 | for stage_idx in [2, 1, 0]: # range(self.num_stage): 196 | # stage feature, proj_mats, scales 197 | features_stage = [feat["stage_{}".format(stage_idx)] for feat in features] 198 | proj_matrices_stage = proj_matrices["stage_{}".format(stage_idx)] # + 1)] 199 | stage_scale = self.stage_infos["stage_{}".format(stage_idx)]["scale"] 200 | 201 | if depth is not None: 202 | if self.grad_method == "detach": 203 | cur_depth = depth.detach() 204 | else: 205 | cur_depth = depth 206 | cur_depth = F.interpolate( 207 | cur_depth.unsqueeze(1), 208 | [img.shape[2], img.shape[3]], 209 | mode="bilinear", 210 | align_corners=Align_Corners_Range, 211 | ).squeeze(1) 212 | else: 213 | cur_depth = depth_values 214 | depth_range_samples = get_depth_range_samples( 215 | cur_depth=cur_depth, 216 | ndepth=self.ndepths[2 - stage_idx], 217 | depth_inteval_pixel=self.depth_interals_ratio[2 - stage_idx] * depth_interval, 218 | dtype=img[0].dtype, 219 | device=img[0].device, 220 | shape=[img.shape[0], img.shape[2], img.shape[3]], 221 | max_depth=depth_max, 222 | min_depth=depth_min, 223 | ) 224 | 225 | outputs_stage = self.DepthNet( 226 | features_stage, 227 | proj_matrices_stage, 228 | depth_values=F.interpolate( 229 | depth_range_samples.unsqueeze(1), 230 | [ 231 | self.ndepths[2 - stage_idx], 232 | img.shape[2] // int(stage_scale), 233 | img.shape[3] // int(stage_scale), 234 | ], 235 | mode="trilinear", 236 | align_corners=Align_Corners_Range, 237 | ).squeeze(1), 238 | num_depth=self.ndepths[2 - stage_idx], 239 | cost_regularization=self.cost_regularization 240 | if self.share_cr 241 | else self.cost_regularization[stage_idx], 242 | hints=hints, 243 | validhints=validhints, 244 | scale=stage_scale, 245 | ) 246 | 247 | depth = outputs_stage["depth"] # .unsqueeze(1) 248 | 249 | outputs["stage_{}".format(stage_idx)] = depth.unsqueeze(1) 250 | 251 | # depth map refinement 252 | if self.refine: 253 | refined_depth = self.refine_network(torch.cat((imgs[:, 0], depth), 1)) 254 | outputs["refined_depth"] = refined_depth 255 | 256 | return { 257 | "depth": outputs["refined_depth"] if self.refine else outputs, 258 | "loss_data": { 259 | "depth": outputs, 260 | }, 261 | "photometric_confidence": outputs_stage["photometric_confidence"], 262 | } 263 | -------------------------------------------------------------------------------- /guided_mvs_lib/models/d2hc_rmvsnet/__init__.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import warnings 3 | from types import SimpleNamespace 4 | from typing import Any, Optional 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | from torch import Tensor 10 | 11 | from .drmvsnet import D2HCRMVSNet 12 | from .vamvsnet import mvsnet_loss 13 | 14 | __all__ = ["mvsnet_loss", "D2HCRMVSNet", "NetBuilder", "SimpleInterfaceNet"] 15 | 16 | _logger = logging.getLogger(__name__) 17 | 18 | 19 | _DATASETS_SIZE = { 20 | "dtu_yao": (512, 640), 21 | "blended_mvs": (576, 768), 22 | "blended_mvg": (576, 768), 23 | } 24 | 25 | 26 | class SimpleInterfaceNet(nn.Module): 27 | """ 28 | Simple common interface to call the pretrained models 29 | """ 30 | 31 | def __init__(self, *args, **kwargs): 32 | super().__init__() 33 | self.model = D2HCRMVSNet(*args, **kwargs) 34 | 
self.all_outputs: dict[str, Any] = {} 35 | 36 | def forward( 37 | self, 38 | imgs: Tensor, 39 | intrinsics: Tensor, 40 | extrinsics: Tensor, 41 | depth_values: Tensor, 42 | hints: Optional[Tensor] = None, 43 | ): 44 | with warnings.catch_warnings(): 45 | warnings.simplefilter("ignore", UserWarning) 46 | 47 | # compute poses 48 | proj_mat = extrinsics.clone() 49 | intrinsics_copy = intrinsics.clone() 50 | intrinsics_copy[..., :2, :] = intrinsics_copy[..., :2, :] / 4 51 | proj_mat[..., :3, :4] = intrinsics_copy @ proj_mat[..., :3, :4] 52 | 53 | # image resize 54 | imgs = torch.stack( 55 | [ 56 | F.interpolate(img, scale_factor=0.25, mode="bilinear", align_corners=True) 57 | for img in torch.unbind(imgs, 2) 58 | ], 59 | 2, 60 | ) 61 | 62 | # validhints 63 | validhints = None 64 | if hints is not None: 65 | validhints = (hints > 0).to(torch.float32) 66 | 67 | # call 68 | out = self.model(imgs, proj_mat, depth_values, hints, validhints) 69 | self.all_outputs = out 70 | return out["depth"]["stage_0"] 71 | 72 | 73 | class NetBuilder(nn.Module): 74 | def __init__(self, args: SimpleNamespace): 75 | super().__init__() 76 | 77 | self.model = D2HCRMVSNet() 78 | self.loss = mvsnet_loss 79 | 80 | def forward(self, batch: dict): 81 | 82 | hints, validhints = None, None 83 | if "hints" in batch: 84 | hints = batch["hints"] 85 | validhints = (hints > 0).to(torch.float32) 86 | 87 | out = self.model( 88 | batch["imgs"]["stage_2"], 89 | batch["proj_matrices"]["stage_2"], 90 | batch["depth_values"], 91 | hints, 92 | validhints, 93 | ) 94 | return out 95 | -------------------------------------------------------------------------------- /guided_mvs_lib/models/d2hc_rmvsnet/convlstm.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from torch.autograd import Variable 4 | 5 | from .module import * 6 | 7 | 8 | class ConvLSTMCell(nn.Module): 9 | def __init__(self, input_size, scale, input_dim, hidden_dim, kernel_size, bias=True): 10 | """ 11 | Initialize ConvLSTM cell. 12 | 13 | Parameters 14 | ---------- 15 | input_size: (int, int) 16 | Height and width of input tensor as (height, width). 17 | input_dim: int 18 | Number of channels of input tensor. 19 | hidden_dim: int 20 | Number of channels of hidden state. 21 | kernel_size: (int, int) 22 | Size of the convolutional kernel. 23 | bias: bool 24 | Whether or not to add the bias. 
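        scale: int
            Downscaling factor used by ``init_hidden``: the hidden and cell states
            are allocated with spatial size (height / scale, width / scale).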
25 | """ 26 | 27 | super(ConvLSTMCell, self).__init__() 28 | 29 | self.height, self.width = input_size 30 | self.input_dim = input_dim 31 | self.hidden_dim = hidden_dim 32 | self.scale = scale 33 | 34 | self.kernel_size = kernel_size 35 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2 36 | self.bias = bias 37 | 38 | self.conv = nn.Conv2d( 39 | in_channels=self.input_dim + self.hidden_dim, 40 | out_channels=4 * self.hidden_dim, 41 | kernel_size=self.kernel_size, 42 | padding=self.padding, 43 | bias=self.bias, 44 | ) 45 | 46 | def forward(self, input_tensor, cur_state): 47 | 48 | h_cur, c_cur = cur_state 49 | 50 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis 51 | 52 | combined_conv = self.conv(combined) 53 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1) 54 | i = torch.sigmoid(cc_i) 55 | f = torch.sigmoid(cc_f) 56 | o = torch.sigmoid(cc_o) 57 | g = torch.tanh(cc_g) 58 | 59 | c_next = f * c_cur + i * g 60 | h_next = o * torch.tanh(c_next) 61 | 62 | return h_next, c_next 63 | 64 | def init_hidden(self, batch_size, height=None, width=None): 65 | if height is None: 66 | height = self.height 67 | if width is None: 68 | width = self.width 69 | return ( 70 | Variable( 71 | torch.zeros( 72 | batch_size, self.hidden_dim, int(height / self.scale), int(width / self.scale) 73 | ) 74 | ).to(self.conv.weight.device), 75 | Variable( 76 | torch.zeros( 77 | batch_size, self.hidden_dim, int(height / self.scale), int(width / self.scale) 78 | ) 79 | ).to(self.conv.weight.device), 80 | ) 81 | 82 | 83 | class ConvBnLSTMCell(ConvLSTMCell): 84 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias=True): 85 | super(ConvBnLSTMCell, self).__init__(input_size, input_dim, hidden_dim, kernel_size, bias) 86 | 87 | # self.conv = ConvBnReLU(in_channels=self.input_dim + self.hidden_dim, 88 | self.conv = ConvBn( 89 | in_channels=self.input_dim + self.hidden_dim, 90 | out_channels=4 * self.hidden_dim, 91 | kernel_size=self.kernel_size, 92 | stride=1, 93 | padding=self.padding, 94 | ) # bias = False in_channels, out_channels, kernel_size=3, stride=1, pad=1 95 | 96 | 97 | class ConvGnLSTMCell(ConvLSTMCell): 98 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias=True): 99 | super(ConvGnLSTMCell, self).__init__(input_size, input_dim, hidden_dim, kernel_size, bias) 100 | 101 | # self.conv = ConvGnReLU(in_channels=self.input_dim + self.hidden_dim, 102 | self.conv = ConvGn( 103 | in_channels=self.input_dim + self.hidden_dim, 104 | out_channels=4 * self.hidden_dim, 105 | kernel_size=self.kernel_size, 106 | stride=1, 107 | padding=self.padding, 108 | ) 109 | 110 | 111 | class ConvLSTM(nn.Module): 112 | def __init__( 113 | self, 114 | input_size, 115 | input_dim, 116 | hidden_dim, 117 | kernel_size, 118 | num_layers, 119 | batch_first=False, 120 | bias=True, 121 | return_all_layers=False, 122 | ): 123 | super(ConvLSTM, self).__init__() 124 | 125 | self._check_kernel_size_consistency(kernel_size) 126 | 127 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers 128 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers) 129 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers) 130 | if not len(kernel_size) == len(hidden_dim) == num_layers: 131 | raise ValueError("Inconsistent list length.") 132 | 133 | self.height, self.width = input_size 134 | 135 | self.input_dim = input_dim 136 | self.hidden_dim = hidden_dim 137 | self.kernel_size = kernel_size 138 | 
self.num_layers = num_layers 139 | self.batch_first = batch_first 140 | self.bias = bias 141 | self.return_all_layers = return_all_layers 142 | 143 | cell_list = [] 144 | for i in range(0, self.num_layers): 145 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1] 146 | 147 | cell_list.append( 148 | ConvLSTMCell( 149 | input_size=(self.height, self.width), 150 | input_dim=cur_input_dim, 151 | hidden_dim=self.hidden_dim[i], 152 | kernel_size=self.kernel_size[i], 153 | bias=self.bias, 154 | ) 155 | ) 156 | 157 | self.cell_list = nn.ModuleList(cell_list) 158 | 159 | def forward(self, input_tensor, hidden_state=None): 160 | """ 161 | 162 | Parameters 163 | ---------- 164 | input_tensor: todo 165 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w) 166 | hidden_state: todo 167 | None. todo implement stateful 168 | 169 | Returns 170 | ------- 171 | last_state_list, layer_output 172 | """ 173 | if not self.batch_first: 174 | # (t, b, c, h, w) -> (b, t, c, h, w) 175 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4) 176 | 177 | # Implement stateful ConvLSTM 178 | if hidden_state is not None: 179 | raise NotImplementedError() 180 | else: 181 | hidden_state = self._init_hidden(batch_size=input_tensor.size(0)) 182 | 183 | layer_output_list = [] 184 | last_state_list = [] 185 | 186 | seq_len = input_tensor.size(1) 187 | cur_layer_input = input_tensor 188 | 189 | for layer_idx in range(self.num_layers): 190 | 191 | h, c = hidden_state[layer_idx] 192 | output_inner = [] 193 | for t in range(seq_len): 194 | 195 | h, c = self.cell_list[layer_idx]( 196 | input_tensor=cur_layer_input[:, t, :, :, :], cur_state=[h, c] 197 | ) 198 | output_inner.append(h) 199 | 200 | layer_output = torch.stack(output_inner, dim=1) 201 | cur_layer_input = layer_output 202 | 203 | layer_output_list.append(layer_output) 204 | last_state_list.append([h, c]) 205 | 206 | if not self.return_all_layers: 207 | layer_output_list = layer_output_list[-1:] 208 | last_state_list = last_state_list[-1:] 209 | 210 | return layer_output_list, last_state_list 211 | 212 | def _init_hidden(self, batch_size): 213 | init_states = [] 214 | for i in range(self.num_layers): 215 | init_states.append(self.cell_list[i].init_hidden(batch_size)) 216 | return init_states 217 | 218 | @staticmethod 219 | def _check_kernel_size_consistency(kernel_size): 220 | if not ( 221 | isinstance(kernel_size, tuple) 222 | or ( 223 | isinstance(kernel_size, list) 224 | and all([isinstance(elem, tuple) for elem in kernel_size]) 225 | ) 226 | ): 227 | raise ValueError("`kernel_size` must be tuple or list of tuples") 228 | 229 | @staticmethod 230 | def _extend_for_multilayer(param, num_layers): 231 | if not isinstance(param, list): 232 | param = [param] * num_layers 233 | return param 234 | -------------------------------------------------------------------------------- /guided_mvs_lib/models/d2hc_rmvsnet/submodule.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/python 2 | # -*- coding: utf-8 -*- 3 | # 4 | # Hongwei Yi (hongweiyi@pku.edu.cn) 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | 10 | 11 | def conv(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True): 12 | return nn.Sequential( 13 | nn.Conv2d( 14 | in_channels, 15 | out_channels, 16 | kernel_size=kernel_size, 17 | stride=stride, 18 | dilation=dilation, 19 | padding=((kernel_size - 1) // 2) * dilation, 20 | bias=bias, 21 | ), 22 | nn.ReLU(inplace=True), 23 | ) 24 | 25 | 26 | def 
convbn(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True): 27 | return nn.Sequential( 28 | nn.Conv2d( 29 | in_channels, 30 | out_channels, 31 | kernel_size=kernel_size, 32 | stride=stride, 33 | dilation=dilation, 34 | padding=((kernel_size - 1) // 2) * dilation, 35 | bias=bias, 36 | ), 37 | nn.BatchNorm2d(out_channels), 38 | # nn.SyncBatchNorm(out_channels), 39 | # nn.LeakyReLU(0.0,inplace=True) 40 | nn.ReLU(inplace=True), 41 | ) 42 | 43 | 44 | def convgnrelu( 45 | in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True, group_channel=8 46 | ): 47 | return nn.Sequential( 48 | nn.Conv2d( 49 | in_channels, 50 | out_channels, 51 | kernel_size=kernel_size, 52 | stride=stride, 53 | dilation=dilation, 54 | padding=((kernel_size - 1) // 2) * dilation, 55 | bias=bias, 56 | ), 57 | nn.GroupNorm(int(max(1, out_channels / group_channel)), out_channels), 58 | nn.ReLU(inplace=True), 59 | ) 60 | 61 | 62 | # def conv3d(in_channels, out_channels, kernel_size=3, stride=1,dilation=1, bias=True): 63 | # return nn.Sequential( 64 | # nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias), 65 | # nn.LeakyReLU(0.0,inplace=True) 66 | # ) 67 | 68 | 69 | def conv3dgn(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True): 70 | return nn.Sequential( 71 | nn.Conv3d( 72 | in_channels, 73 | out_channels, 74 | kernel_size=kernel_size, 75 | stride=stride, 76 | dilation=dilation, 77 | padding=((kernel_size - 1) // 2) * dilation, 78 | bias=bias, 79 | ), 80 | nn.GroupNorm(1, 1), 81 | nn.LeakyReLU(0.0, inplace=True), 82 | ) 83 | 84 | 85 | def conv3d(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True): 86 | return nn.Sequential( 87 | nn.Conv3d( 88 | in_channels, 89 | out_channels, 90 | kernel_size=kernel_size, 91 | stride=stride, 92 | dilation=dilation, 93 | padding=((kernel_size - 1) // 2) * dilation, 94 | bias=bias, 95 | ), 96 | nn.BatchNorm3d(out_channels), 97 | nn.LeakyReLU(0.0, inplace=True), 98 | ) 99 | 100 | 101 | def resnet_block(in_channels, kernel_size=3, dilation=[1, 1], bias=True): 102 | return ResnetBlock(in_channels, kernel_size, dilation, bias=bias) 103 | 104 | 105 | def resnet_block_bn(in_channels, kernel_size=3, dilation=[1, 1], bias=True): 106 | return ResnetBlockBn(in_channels, kernel_size, dilation, bias=bias) 107 | 108 | 109 | class ResnetBlock(nn.Module): 110 | def __init__(self, in_channels, kernel_size, dilation, bias): 111 | super(ResnetBlock, self).__init__() 112 | self.stem = nn.Sequential( 113 | nn.Conv2d( 114 | in_channels, 115 | in_channels, 116 | kernel_size=kernel_size, 117 | stride=1, 118 | dilation=dilation[0], 119 | padding=((kernel_size - 1) // 2) * dilation[0], 120 | bias=bias, 121 | ), 122 | nn.LeakyReLU(0.0, inplace=True), 123 | nn.Conv2d( 124 | in_channels, 125 | in_channels, 126 | kernel_size=kernel_size, 127 | stride=1, 128 | dilation=dilation[1], 129 | padding=((kernel_size - 1) // 2) * dilation[1], 130 | bias=bias, 131 | ), 132 | ) 133 | 134 | def forward(self, x): 135 | out = self.stem(x) + x 136 | return out 137 | 138 | 139 | class ResnetBlockBn(nn.Module): 140 | def __init__(self, in_channels, kernel_size, dilation, bias): 141 | super(ResnetBlockBn, self).__init__() 142 | self.stem = nn.Sequential( 143 | convbn( 144 | in_channels, 145 | in_channels, 146 | kernel_size=kernel_size, 147 | stride=1, 148 | dilation=dilation[0], 149 | bias=bias, 150 | ), 151 | nn.Conv2d( 152 | in_channels, 153 | in_channels, 154 
| kernel_size=kernel_size, 155 | stride=1, 156 | dilation=dilation[1], 157 | padding=((kernel_size - 1) // 2) * dilation[1], 158 | bias=bias, 159 | ), 160 | ) 161 | 162 | def forward(self, x): 163 | out = self.stem(x) + x 164 | return out 165 | 166 | 167 | ##### Define weightnet-3d 168 | def volumegatelight(in_channels, kernel_size=3, dilation=[1, 1], bias=True): 169 | return nn.Sequential( 170 | # MSDilateBlock3D(in_channels, kernel_size, dilation, bias), 171 | conv3d(in_channels, 1, kernel_size=1, stride=1, bias=bias), 172 | conv3d(1, 1, kernel_size=1, stride=1), 173 | ) 174 | 175 | 176 | def volumegatelightgn(in_channels, kernel_size=3, dilation=[1, 1], bias=True): 177 | return nn.Sequential( 178 | # MSDilateBlock3D(in_channels, kernel_size, dilation, bias), 179 | conv3dgn(in_channels, 1, kernel_size=1, stride=1, bias=bias), 180 | conv3dgn(1, 1, kernel_size=1, stride=1), 181 | ) 182 | 183 | 184 | ##### Define gatenet 185 | def gatenetbn(bias=True): 186 | return nn.Sequential( 187 | convbn(32, 16, kernel_size=3, stride=1, pad=1, dilation=1, bias=bias), 188 | resnet_block_bn(16, kernel_size=1), 189 | nn.Conv2d(16, 1, kernel_size=1, padding=0), 190 | nn.Sigmoid(), 191 | ) 192 | 193 | 194 | def pillarnetbn(bias=True): 195 | return nn.Sequential( 196 | nn.Linear(192, 32, bias=bias), 197 | nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE, inplace=True), 198 | nn.Linear(32, 2), 199 | ) 200 | 201 | 202 | class ResnetBlockGn(nn.Module): 203 | def __init__(self, in_channels, kernel_size, dilation, bias, group_channel=8): 204 | super(ResnetBlockGn, self).__init__() 205 | self.stem = nn.Sequential( 206 | convgnrelu( 207 | in_channels, 208 | in_channels, 209 | kernel_size=kernel_size, 210 | stride=1, 211 | dilation=dilation[0], 212 | bias=bias, 213 | group_channel=group_channel, 214 | ), 215 | nn.Conv2d( 216 | in_channels, 217 | in_channels, 218 | kernel_size=kernel_size, 219 | stride=1, 220 | dilation=dilation[1], 221 | padding=((kernel_size - 1) // 2) * dilation[1], 222 | bias=bias, 223 | ), 224 | nn.GroupNorm(int(max(1, in_channels / group_channel)), in_channels), 225 | ) 226 | self.relu = nn.ReLU(inplace=True) 227 | 228 | def forward(self, x): 229 | out = self.stem(x) + x 230 | out = self.relu(out) 231 | return out 232 | 233 | 234 | def resnet_block_gn(in_channels, kernel_size=3, dilation=[1, 1], bias=True, group_channel=8): 235 | return ResnetBlockGn( 236 | in_channels, kernel_size, dilation, bias=bias, group_channel=group_channel 237 | ) 238 | 239 | 240 | def gatenet(gn=True, in_channels=32, bias=True): # WORK 241 | if gn: 242 | return nn.Sequential( 243 | # nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, dilation=1, padding=1, bias=bias), # in_channels=64 244 | # nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE,inplace=True), 245 | convgnrelu( 246 | in_channels, 4, kernel_size=3, stride=1, dilation=1, bias=bias 247 | ), # 4: 10G,8.6G; 248 | resnet_block_gn(4, kernel_size=1), 249 | nn.Conv2d(4, 1, kernel_size=1, padding=0), 250 | nn.Sigmoid(), 251 | ) 252 | else: 253 | return nn.Sequential( 254 | nn.Conv2d( 255 | in_channels, 16, kernel_size=3, stride=1, dilation=1, padding=1, bias=bias 256 | ), # in_channels=64 257 | nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE, inplace=True), 258 | resnet_block(16, kernel_size=1), 259 | nn.Conv2d(16, 1, kernel_size=1, padding=0), 260 | nn.Sigmoid(), 261 | ) 262 | 263 | 264 | def gatenet_m4(gn=True, in_channels=32, bias=True): 265 | if gn: 266 | return nn.Sequential( 267 | # nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, dilation=1, padding=1, bias=bias), # in_channels=64 268 | # 
nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE,inplace=True), 269 | convgnrelu( 270 | in_channels, 8, kernel_size=3, stride=1, dilation=1, bias=bias 271 | ), # 4: 10G,8.6G; 272 | resnet_block_gn(8, kernel_size=1), 273 | nn.Conv2d(8, 1, kernel_size=1, padding=0), 274 | nn.Sigmoid(), 275 | ) 276 | else: 277 | return nn.Sequential( 278 | nn.Conv2d( 279 | in_channels, 16, kernel_size=3, stride=1, dilation=1, padding=1, bias=bias 280 | ), # in_channels=64 281 | nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE, inplace=True), 282 | resnet_block(16, kernel_size=1), 283 | nn.Conv2d(16, 1, kernel_size=1, padding=0), 284 | nn.Sigmoid(), 285 | ) 286 | 287 | 288 | def pillarnet(in_channels=192, bias=True): # origin_pillarnet: 192, 96, 48, 24 289 | return nn.Sequential( 290 | nn.Linear(in_channels, 32, bias=bias), 291 | nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE, inplace=True), 292 | nn.Linear(32, 2), 293 | ) 294 | -------------------------------------------------------------------------------- /guided_mvs_lib/models/d2hc_rmvsnet/vamvsnet_high_submodule.py: -------------------------------------------------------------------------------- 1 | import sys 2 | from copy import deepcopy 3 | 4 | import torch 5 | import torch.nn as nn 6 | import torch.nn.functional as F 7 | 8 | from .module import * 9 | 10 | # Multi-scale feature extractor && Coarse To Fine Regression Module 11 | 12 | 13 | class FeatureNetHigh(nn.Module): # Original Paper Setting 14 | def __init__(self): 15 | super(FeatureNetHigh, self).__init__() 16 | self.inplanes = 32 17 | 18 | self.conv0 = ConvBnReLU(3, 8, 3, 1, 1) 19 | self.conv1 = ConvBnReLU(8, 8, 3, 1, 1) 20 | 21 | self.conv2 = ConvBnReLU(8, 16, 5, 2, 2) 22 | self.conv3 = ConvBnReLU(16, 16, 3, 1, 1) 23 | self.conv4 = ConvBnReLU(16, 16, 3, 1, 1) 24 | 25 | self.conv5 = ConvBnReLU(16, 32, 5, 2, 2) 26 | self.conv6 = ConvBnReLU(32, 32, 3, 1, 1) 27 | 28 | self.conv7 = ConvBnReLU(32, 32, 5, 2, 2) 29 | self.conv8 = ConvBnReLU(32, 32, 3, 1, 1) 30 | 31 | self.conv9 = ConvBnReLU(32, 64, 5, 2, 2) 32 | self.conv10 = ConvBnReLU(64, 64, 3, 1, 1) 33 | 34 | self.conv11 = ConvBnReLU(64, 64, 5, 2, 2) 35 | self.conv12 = ConvBnReLU(64, 64, 3, 1, 1) 36 | 37 | self.feature1 = nn.Conv2d(32, 32, 3, 1, 1) 38 | 39 | self.feature2 = nn.Conv2d(32, 32, 3, 1, 1) 40 | 41 | self.feature3 = nn.Conv2d(64, 64, 3, 1, 1) 42 | 43 | self.feature4 = nn.Conv2d(64, 64, 3, 1, 1) 44 | 45 | def forward(self, x): 46 | x = self.conv1(self.conv0(x)) 47 | x = self.conv4(self.conv3(self.conv2(x))) 48 | x = self.conv6(self.conv5(x)) 49 | feature1 = self.feature1(x) 50 | x = self.conv8(self.conv7(x)) 51 | feature2 = self.feature2(x) 52 | x = self.conv10(self.conv9(x)) 53 | feature3 = self.feature3(x) 54 | x = self.conv12(self.conv11(x)) 55 | feature4 = self.feature4(x) 56 | return [feature1, feature2, feature3, feature4] 57 | 58 | 59 | class FeatureNetHighGN(nn.Module): # Original Paper Setting 60 | def __init__(self): 61 | super(FeatureNetHighGN, self).__init__() 62 | self.inplanes = 32 63 | 64 | self.conv0 = ConvGnReLU(3, 8, 3, 1, 1) 65 | self.conv1 = ConvGnReLU(8, 8, 3, 1, 1) 66 | 67 | self.conv2 = ConvGnReLU(8, 16, 5, 2, 2) 68 | self.conv3 = ConvGnReLU(16, 16, 3, 1, 1) 69 | self.conv4 = ConvGnReLU(16, 16, 3, 1, 1) 70 | 71 | self.conv5 = ConvGnReLU(16, 32, 5, 2, 2) 72 | self.conv6 = ConvGnReLU(32, 32, 3, 1, 1) 73 | 74 | self.conv7 = ConvGnReLU(32, 32, 5, 2, 2) 75 | self.conv8 = ConvGnReLU(32, 32, 3, 1, 1) 76 | 77 | self.conv9 = ConvGnReLU(32, 64, 5, 2, 2) 78 | self.conv10 = ConvGnReLU(64, 64, 3, 1, 1) 79 | 80 | self.conv11 = ConvGnReLU(64, 64, 5, 2, 2) 81 | 
self.conv12 = ConvGnReLU(64, 64, 3, 1, 1) 82 | 83 | self.feature1 = nn.Conv2d(32, 32, 3, 1, 1) 84 | 85 | self.feature2 = nn.Conv2d(32, 32, 3, 1, 1) 86 | 87 | self.feature3 = nn.Conv2d(64, 64, 3, 1, 1) 88 | 89 | self.feature4 = nn.Conv2d(64, 64, 3, 1, 1) 90 | 91 | def forward(self, x): 92 | x = self.conv1(self.conv0(x)) 93 | x = self.conv4(self.conv3(self.conv2(x))) 94 | x = self.conv6(self.conv5(x)) 95 | feature1 = self.feature1(x) 96 | x = self.conv8(self.conv7(x)) 97 | feature2 = self.feature2(x) 98 | x = self.conv10(self.conv9(x)) 99 | feature3 = self.feature3(x) 100 | x = self.conv12(self.conv11(x)) 101 | feature4 = self.feature4(x) 102 | return [feature1, feature2, feature3, feature4] 103 | 104 | 105 | class RegNetUS0_Coarse2Fine(nn.Module): 106 | def __init__(self, origin_size=False, dp_ratio=0.0, image_scale=0.25): 107 | super(RegNetUS0_Coarse2Fine, self).__init__() 108 | self.origin_size = origin_size 109 | self.image_scale = image_scale 110 | 111 | self.conv0 = ConvBnReLU3D(32, 8) 112 | 113 | self.conv1 = ConvBnReLU3D(32, 16, stride=2) 114 | self.conv2 = ConvBnReLU3D(16, 16) 115 | 116 | self.conv3 = ConvBnReLU3D(16, 32, stride=2) 117 | self.conv4 = ConvBnReLU3D(32, 32) 118 | 119 | self.conv5 = ConvBnReLU3D(32, 64, stride=2) 120 | self.conv6 = ConvBnReLU3D(64, 64) 121 | 122 | self.conv7 = nn.Sequential( 123 | nn.ConvTranspose3d( 124 | 128, 32, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False 125 | ), 126 | nn.BatchNorm3d(32), 127 | nn.ReLU(inplace=True), 128 | ) 129 | 130 | self.conv9 = nn.Sequential( 131 | nn.ConvTranspose3d( 132 | 97, 16, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False 133 | ), 134 | nn.BatchNorm3d(16), 135 | nn.ReLU(inplace=True), 136 | ) 137 | 138 | self.conv11 = nn.Sequential( 139 | nn.ConvTranspose3d( 140 | 49, 8, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False 141 | ), 142 | nn.BatchNorm3d(8), 143 | nn.ReLU(inplace=True), 144 | ) 145 | 146 | self.prob1 = nn.Conv3d(41, 1, 1, bias=False) 147 | self.dropout1 = nn.Dropout3d(p=dp_ratio) 148 | self.prob2 = nn.Conv3d(49, 1, 1, bias=False) 149 | self.dropout2 = nn.Dropout3d(p=dp_ratio) 150 | self.prob3 = nn.Conv3d(97, 1, 1, bias=False) 151 | self.dropout3 = nn.Dropout3d(p=dp_ratio) 152 | self.prob4 = nn.Conv3d(128, 1, 1, bias=False) 153 | self.dropout4 = nn.Dropout3d(p=dp_ratio) 154 | # add Drop out 155 | 156 | def forward(self, x_list): 157 | x1, x2, x3, x4 = x_list # 32*192, 32*96, 64*48, 64*24 158 | input_shape = x1.shape 159 | 160 | conv0 = self.conv0(x1) 161 | conv1 = self.conv1(x1) 162 | conv3 = self.conv3(conv1) 163 | conv5 = self.conv5(conv3) 164 | 165 | x = torch.cat([self.conv6(conv5), x4], 1) 166 | prob4 = self.dropout4(self.prob4(x)) 167 | # prob4 = self.prob4(x) 168 | x = self.conv7(x) + self.conv4(conv3) 169 | x = torch.cat( 170 | [x, x3, F.interpolate(prob4, scale_factor=2, mode="trilinear", align_corners=True)], 1 171 | ) 172 | prob3 = self.dropout3(self.prob3(x)) 173 | # prob3 = self.prob3(x) 174 | x = self.conv9(x) + self.conv2(conv1) 175 | x = torch.cat( 176 | [x, x2, F.interpolate(prob3, scale_factor=2, mode="trilinear", align_corners=True)], 1 177 | ) 178 | prob2 = self.dropout2(self.prob2(x)) 179 | # prob2 = self.prob2(x) 180 | x = self.conv11(x) + conv0 181 | x = torch.cat( 182 | [x, x1, F.interpolate(prob2, scale_factor=2, mode="trilinear", align_corners=True)], 1 183 | ) 184 | 185 | if self.origin_size and self.image_scale == 0.50: 186 | x = F.interpolate( 187 | x, 188 | size=(input_shape[2], input_shape[3] * 2, input_shape[4] * 2), 189 | 
mode="trilinear", 190 | align_corners=True, 191 | ) 192 | prob1 = self.dropout1(self.prob1(x)) 193 | # prob1 = self.prob1(x) # without dropout 194 | # if self.origin_size: 195 | # x = F.interpolate(x, size=(input_shape[2], input_shape[3]*4, input_shape[4]*4), mode='trilinear', align_corners=True) 196 | return [prob1, prob2, prob3, prob4] 197 | 198 | 199 | class RegNetUS0_Coarse2FineGN(nn.Module): 200 | def __init__(self, origin_size=False, dp_ratio=0.0, image_scale=0.25): 201 | super(RegNetUS0_Coarse2FineGN, self).__init__() 202 | self.origin_size = origin_size 203 | self.image_scale = image_scale 204 | 205 | self.conv0 = ConvGnReLU3D(32, 8) 206 | 207 | self.conv1 = ConvGnReLU3D(32, 16, stride=2) 208 | self.conv2 = ConvGnReLU3D(16, 16) 209 | 210 | self.conv3 = ConvGnReLU3D(16, 32, stride=2) 211 | self.conv4 = ConvGnReLU3D(32, 32) 212 | 213 | self.conv5 = ConvGnReLU3D(32, 64, stride=2) 214 | self.conv6 = ConvGnReLU3D(64, 64) 215 | 216 | self.conv7 = nn.Sequential( 217 | nn.ConvTranspose3d( 218 | 128, 32, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False 219 | ), 220 | # nn.BatchNorm3d(32), 221 | nn.GroupNorm(4, 32), 222 | nn.ReLU(inplace=True), 223 | ) 224 | 225 | self.conv9 = nn.Sequential( 226 | nn.ConvTranspose3d( 227 | 97, 16, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False 228 | ), 229 | nn.GroupNorm(2, 16), 230 | nn.ReLU(inplace=True), 231 | ) 232 | 233 | self.conv11 = nn.Sequential( 234 | nn.ConvTranspose3d( 235 | 49, 8, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False 236 | ), 237 | nn.GroupNorm(1, 8), 238 | nn.ReLU(inplace=True), 239 | ) 240 | 241 | self.prob1 = nn.Conv3d(41, 1, 1, bias=False) 242 | self.dropout1 = nn.Dropout3d(p=dp_ratio) 243 | self.prob2 = nn.Conv3d(49, 1, 1, bias=False) 244 | self.dropout2 = nn.Dropout3d(p=dp_ratio) 245 | self.prob3 = nn.Conv3d(97, 1, 1, bias=False) 246 | self.dropout3 = nn.Dropout3d(p=dp_ratio) 247 | self.prob4 = nn.Conv3d(128, 1, 1, bias=False) 248 | self.dropout4 = nn.Dropout3d(p=dp_ratio) 249 | # add Drop out 250 | 251 | def forward(self, x_list): 252 | x1, x2, x3, x4 = x_list # 32*192, 32*96, 64*48, 64*24 253 | # print(x1.shape, x2.shape, x3.shape, x4.shape) 254 | input_shape = x1.shape 255 | 256 | conv0 = self.conv0(x1) 257 | conv1 = self.conv1(x1) 258 | conv3 = self.conv3(conv1) 259 | conv5 = self.conv5(conv3) 260 | 261 | x = torch.cat([self.conv6(conv5), x4], 1) 262 | prob4 = self.dropout4(self.prob4(x)) 263 | # prob4 = self.prob4(x) 264 | x = self.conv7(x) + self.conv4(conv3) 265 | x = torch.cat( 266 | [x, x3, F.interpolate(prob4, scale_factor=2, mode="trilinear", align_corners=True)], 1 267 | ) 268 | prob3 = self.dropout3(self.prob3(x)) 269 | # prob3 = self.prob3(x) 270 | x = self.conv9(x) + self.conv2(conv1) 271 | x = torch.cat( 272 | [x, x2, F.interpolate(prob3, scale_factor=2, mode="trilinear", align_corners=True)], 1 273 | ) 274 | prob2 = self.dropout2(self.prob2(x)) 275 | # prob2 = self.prob2(x) 276 | x = self.conv11(x) + conv0 277 | x = torch.cat( 278 | [x, x1, F.interpolate(prob2, scale_factor=2, mode="trilinear", align_corners=True)], 1 279 | ) 280 | 281 | if self.origin_size and self.image_scale == 0.50: 282 | x = F.interpolate( 283 | x, 284 | size=(input_shape[2], input_shape[3] * 2, input_shape[4] * 2), 285 | mode="trilinear", 286 | align_corners=True, 287 | ) 288 | prob1 = self.dropout1(self.prob1(x)) 289 | # prob1 = self.prob1(x) # without dropout 290 | # if self.origin_size: 291 | # x = F.interpolate(x, size=(input_shape[2], input_shape[3]*4, input_shape[4]*4), 
mode='trilinear', align_corners=True) 292 | return [prob1, prob2, prob3, prob4] 293 | -------------------------------------------------------------------------------- /guided_mvs_lib/models/mvsnet/__init__.py: -------------------------------------------------------------------------------- 1 | import warnings 2 | from types import SimpleNamespace 3 | from typing import Any, Optional 4 | 5 | import torch 6 | import torch.nn as nn 7 | from torch import Tensor 8 | 9 | from .mvsnet import MVSNet, mvsnet_loss 10 | 11 | __all__ = ["mvsnet_loss", "MVSNet", "NetBuilder", "SimpleInterfaceNet"] 12 | 13 | 14 | class SimpleInterfaceNet(nn.Module): 15 | """ 16 | Simple common interface to call the pretrained models 17 | """ 18 | 19 | def __init__(self, refine: bool = False): 20 | super().__init__() 21 | self.model = MVSNet(refine=refine) 22 | self.all_outputs: dict[str, Any] = {} 23 | 24 | def forward( 25 | self, 26 | imgs: Tensor, 27 | intrinsics: Tensor, 28 | extrinsics: Tensor, 29 | depth_values: Tensor, 30 | hints: Optional[Tensor] = None, 31 | ): 32 | with warnings.catch_warnings(): 33 | warnings.simplefilter("ignore", UserWarning) 34 | 35 | # compute poses 36 | proj_mat = extrinsics.clone() 37 | intrinsics_copy = intrinsics.clone() 38 | intrinsics_copy[..., :2, :] = intrinsics_copy[..., :2, :] / 4 39 | proj_mat[..., :3, :4] = intrinsics_copy @ proj_mat[..., :3, :4] 40 | 41 | # validhints 42 | validhints = None 43 | if hints is not None: 44 | validhints = (hints > 0).to(torch.float32) 45 | 46 | # call 47 | out = self.model(imgs, proj_mat, depth_values, hints, validhints) 48 | self.all_outputs = out 49 | return out["depth"]["stage_0"] 50 | 51 | 52 | class NetBuilder(nn.Module): 53 | def __init__(self, args: SimpleNamespace): 54 | super().__init__() 55 | self.model = MVSNet(refine=False) 56 | self.loss = mvsnet_loss 57 | 58 | def forward(self, batch: dict): 59 | 60 | hints, validhints = None, None 61 | if "hints" in batch: 62 | hints = batch["hints"] 63 | validhints = (hints > 0).to(torch.float32) 64 | 65 | return self.model( 66 | batch["imgs"]["stage_0"], 67 | batch["proj_matrices"]["stage_2"], 68 | batch["depth_values"], 69 | hints, 70 | validhints, 71 | ) 72 | -------------------------------------------------------------------------------- /guided_mvs_lib/models/mvsnet/module.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | 6 | class ConvBnReLU(nn.Module): 7 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): 8 | super(ConvBnReLU, self).__init__() 9 | self.conv = nn.Conv2d( 10 | in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False 11 | ) 12 | self.bn = nn.BatchNorm2d(out_channels) 13 | 14 | def forward(self, x): 15 | return F.relu(self.bn(self.conv(x)), inplace=True) 16 | 17 | 18 | class ConvBn(nn.Module): 19 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): 20 | super(ConvBn, self).__init__() 21 | self.conv = nn.Conv2d( 22 | in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False 23 | ) 24 | self.bn = nn.BatchNorm2d(out_channels) 25 | 26 | def forward(self, x): 27 | return self.bn(self.conv(x)) 28 | 29 | 30 | class ConvBnReLU3D(nn.Module): 31 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): 32 | super(ConvBnReLU3D, self).__init__() 33 | self.conv = nn.Conv3d( 34 | in_channels, out_channels, kernel_size, stride=stride, padding=pad, 
bias=False 35 | ) 36 | self.bn = nn.BatchNorm3d(out_channels) 37 | 38 | def forward(self, x): 39 | return F.relu(self.bn(self.conv(x)), inplace=True) 40 | 41 | 42 | class ConvBn3D(nn.Module): 43 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1): 44 | super(ConvBn3D, self).__init__() 45 | self.conv = nn.Conv3d( 46 | in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False 47 | ) 48 | self.bn = nn.BatchNorm3d(out_channels) 49 | 50 | def forward(self, x): 51 | return self.bn(self.conv(x)) 52 | 53 | 54 | class BasicBlock(nn.Module): 55 | def __init__(self, in_channels, out_channels, stride, downsample=None): 56 | super(BasicBlock, self).__init__() 57 | 58 | self.conv1 = ConvBnReLU(in_channels, out_channels, kernel_size=3, stride=stride, pad=1) 59 | self.conv2 = ConvBn(out_channels, out_channels, kernel_size=3, stride=1, pad=1) 60 | 61 | self.downsample = downsample 62 | self.stride = stride 63 | 64 | def forward(self, x): 65 | out = self.conv1(x) 66 | out = self.conv2(out) 67 | if self.downsample is not None: 68 | x = self.downsample(x) 69 | out += x 70 | return out 71 | 72 | 73 | class Hourglass3d(nn.Module): 74 | def __init__(self, channels): 75 | super(Hourglass3d, self).__init__() 76 | 77 | self.conv1a = ConvBnReLU3D(channels, channels * 2, kernel_size=3, stride=2, pad=1) 78 | self.conv1b = ConvBnReLU3D(channels * 2, channels * 2, kernel_size=3, stride=1, pad=1) 79 | 80 | self.conv2a = ConvBnReLU3D(channels * 2, channels * 4, kernel_size=3, stride=2, pad=1) 81 | self.conv2b = ConvBnReLU3D(channels * 4, channels * 4, kernel_size=3, stride=1, pad=1) 82 | 83 | self.dconv2 = nn.Sequential( 84 | nn.ConvTranspose3d( 85 | channels * 4, 86 | channels * 2, 87 | kernel_size=3, 88 | padding=1, 89 | output_padding=1, 90 | stride=2, 91 | bias=False, 92 | ), 93 | nn.BatchNorm3d(channels * 2), 94 | ) 95 | 96 | self.dconv1 = nn.Sequential( 97 | nn.ConvTranspose3d( 98 | channels * 2, 99 | channels, 100 | kernel_size=3, 101 | padding=1, 102 | output_padding=1, 103 | stride=2, 104 | bias=False, 105 | ), 106 | nn.BatchNorm3d(channels), 107 | ) 108 | 109 | self.redir1 = ConvBn3D(channels, channels, kernel_size=1, stride=1, pad=0) 110 | self.redir2 = ConvBn3D(channels * 2, channels * 2, kernel_size=1, stride=1, pad=0) 111 | 112 | def forward(self, x): 113 | conv1 = self.conv1b(self.conv1a(x)) 114 | conv2 = self.conv2b(self.conv2a(conv1)) 115 | dconv2 = F.relu(self.dconv2(conv2) + self.redir2(conv1), inplace=True) 116 | dconv1 = F.relu(self.dconv1(dconv2) + self.redir1(x), inplace=True) 117 | return dconv1 118 | 119 | 120 | def homo_warping(src_fea, src_proj, ref_proj, depth_values): 121 | # src_fea: [B, C, H, W] 122 | # src_proj: [B, 4, 4] 123 | # ref_proj: [B, 4, 4] 124 | # depth_values: [B, Ndepth] 125 | # out: [B, C, Ndepth, H, W] 126 | batch, channels = src_fea.shape[0], src_fea.shape[1] 127 | num_depth = depth_values.shape[1] 128 | height, width = src_fea.shape[2], src_fea.shape[3] 129 | 130 | with torch.no_grad(): 131 | proj = torch.matmul(src_proj, torch.inverse(ref_proj)) 132 | rot = proj[:, :3, :3] # [B,3,3] 133 | trans = proj[:, :3, 3:4] # [B,3,1] 134 | 135 | y, x = torch.meshgrid( 136 | [ 137 | torch.arange(0, height, dtype=torch.float32, device=src_fea.device), 138 | torch.arange(0, width, dtype=torch.float32, device=src_fea.device), 139 | ] 140 | ) 141 | y, x = y.contiguous(), x.contiguous() 142 | y, x = y.view(height * width), x.view(height * width) 143 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] 144 | xyz = torch.unsqueeze(xyz, 
0).repeat(batch, 1, 1) # [B, 3, H*W] 145 | rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W] 146 | rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.view( 147 | batch, 1, num_depth, 1 148 | ) # [B, 3, Ndepth, H*W] 149 | proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W] 150 | proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W] 151 | proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1 152 | proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1 153 | proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2] 154 | grid = proj_xy 155 | 156 | warped_src_fea = F.grid_sample( 157 | src_fea, 158 | grid.view(batch, num_depth * height, width, 2), 159 | mode="bilinear", 160 | padding_mode="zeros", 161 | ) 162 | warped_src_fea = warped_src_fea.view(batch, channels, num_depth, height, width) 163 | 164 | return warped_src_fea 165 | 166 | 167 | # p: probability volume [B, D, H, W] 168 | # depth_values: discrete depth values [B, D] 169 | def depth_regression(p, depth_values): 170 | depth_values = depth_values.view(*depth_values.shape, 1, 1) 171 | depth = torch.sum(p * depth_values, 1) 172 | return depth 173 | -------------------------------------------------------------------------------- /guided_mvs_lib/models/mvsnet/mvsnet.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | 6 | from .module import * 7 | 8 | 9 | class FeatureNet(nn.Module): 10 | def __init__(self): 11 | super(FeatureNet, self).__init__() 12 | self.inplanes = 32 13 | 14 | self.conv0 = ConvBnReLU(3, 8, 3, 1, 1) 15 | self.conv1 = ConvBnReLU(8, 8, 3, 1, 1) 16 | 17 | self.conv2 = ConvBnReLU(8, 16, 5, 2, 2) 18 | self.conv3 = ConvBnReLU(16, 16, 3, 1, 1) 19 | self.conv4 = ConvBnReLU(16, 16, 3, 1, 1) 20 | 21 | self.conv5 = ConvBnReLU(16, 32, 5, 2, 2) 22 | self.conv6 = ConvBnReLU(32, 32, 3, 1, 1) 23 | self.feature = nn.Conv2d(32, 32, 3, 1, 1) 24 | 25 | def forward(self, x): 26 | x = self.conv1(self.conv0(x)) 27 | x = self.conv4(self.conv3(self.conv2(x))) 28 | x = self.feature(self.conv6(self.conv5(x))) 29 | return x 30 | 31 | 32 | class CostRegNet(nn.Module): 33 | def __init__(self): 34 | super(CostRegNet, self).__init__() 35 | self.conv0 = ConvBnReLU3D(32, 8) 36 | 37 | self.conv1 = ConvBnReLU3D(8, 16, stride=2) 38 | self.conv2 = ConvBnReLU3D(16, 16) 39 | 40 | self.conv3 = ConvBnReLU3D(16, 32, stride=2) 41 | self.conv4 = ConvBnReLU3D(32, 32) 42 | 43 | self.conv5 = ConvBnReLU3D(32, 64, stride=2) 44 | self.conv6 = ConvBnReLU3D(64, 64) 45 | 46 | self.conv7 = nn.Sequential( 47 | nn.ConvTranspose3d( 48 | 64, 32, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False 49 | ), 50 | nn.BatchNorm3d(32), 51 | nn.ReLU(inplace=True), 52 | ) 53 | 54 | self.conv9 = nn.Sequential( 55 | nn.ConvTranspose3d( 56 | 32, 16, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False 57 | ), 58 | nn.BatchNorm3d(16), 59 | nn.ReLU(inplace=True), 60 | ) 61 | 62 | self.conv11 = nn.Sequential( 63 | nn.ConvTranspose3d( 64 | 16, 8, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False 65 | ), 66 | nn.BatchNorm3d(8), 67 | nn.ReLU(inplace=True), 68 | ) 69 | 70 | self.prob = nn.Conv3d(8, 1, 3, stride=1, padding=1) 71 | 72 | def _split_pad(self, pad): 73 | if pad % 2 == 0: 74 | return pad // 2, pad // 2 75 | else: 76 | pad_1 = pad // 2 77 | pad_2 = (pad // 2) + 1 78 | return pad_1, pad_2 
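    # A minimal usage sketch (not part of the original file), assuming a hypothetical
    # [B, C, D, H, W] cost volume: `_split_pad` splits an odd padding amount as evenly
    # as possible, e.g. _split_pad(7) -> (3, 4); `_pad_to_div_by` below grows H and W
    # to the next multiple of 8 before the 3D U-Net, and `_generate_slice` crops the
    # result back to the original size after regularization.
    #
    #     reg = CostRegNet()
    #     vol = torch.zeros(1, 32, 8, 29, 37)          # H=29, W=37 not divisible by 8
    #     padded = reg._pad_to_div_by(vol, div_by=8)   # -> torch.Size([1, 32, 8, 32, 40])
    #     out = reg(vol)                               # forward() pads, regularizes, crops back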
79 | 80 | def _generate_slice(self, pad): 81 | if pad == 0: 82 | return slice(0, None) 83 | elif pad % 2 == 0: 84 | return slice(pad // 2, -pad // 2) 85 | else: 86 | pad_1 = pad // 2 87 | pad_2 = (pad // 2) + 1 88 | return slice(pad_1, -pad_2) 89 | 90 | def _pad_to_div_by(self, x, *, div_by=8): 91 | _, _, _, h, w = x.shape 92 | new_h = int(np.ceil(h / div_by)) * div_by 93 | new_w = int(np.ceil(w / div_by)) * div_by 94 | pad_h_l, pad_h_r = self._split_pad(new_h - h) 95 | pad_w_t, pad_w_b = self._split_pad(new_w - w) 96 | return F.pad(x, (pad_w_t, pad_w_b, pad_h_l, pad_h_r)) 97 | 98 | def forward(self, x): 99 | 100 | # padding 101 | _, _, _, h, w = x.shape 102 | x = self._pad_to_div_by(x, div_by=8) 103 | _, _, _, new_h, new_w = x.shape 104 | 105 | # regularization 106 | conv0 = self.conv0(x) 107 | conv2 = self.conv2(self.conv1(conv0)) 108 | conv4 = self.conv4(self.conv3(conv2)) 109 | x = self.conv6(self.conv5(conv4)) 110 | x = conv4 + self.conv7(x) 111 | x = conv2 + self.conv9(x) 112 | x = conv0 + self.conv11(x) 113 | x = self.prob(x) 114 | 115 | # unpadding 116 | slice_h = self._generate_slice(new_h - h) 117 | slice_w = self._generate_slice(new_w - w) 118 | x = x[..., slice_h, slice_w] 119 | 120 | return x 121 | 122 | 123 | class RefineNet(nn.Module): 124 | def __init__(self): 125 | super(RefineNet, self).__init__() 126 | self.conv1 = ConvBnReLU(4, 32) 127 | self.conv2 = ConvBnReLU(32, 32) 128 | self.conv3 = ConvBnReLU(32, 32) 129 | self.res = ConvBnReLU(32, 1) 130 | 131 | def forward(self, img, depth_init): 132 | concat = F.cat((img, depth_init), dim=1) 133 | depth_residual = self.res(self.conv3(self.conv2(self.conv1(concat)))) 134 | depth_refined = depth_init + depth_residual 135 | return depth_refined 136 | 137 | 138 | class MVSNet(nn.Module): 139 | def __init__(self, refine=True): 140 | super(MVSNet, self).__init__() 141 | self.refine = refine 142 | # self.depth_interval = depth_interval 143 | 144 | self.feature = FeatureNet() 145 | self.cost_regularization = CostRegNet() 146 | if self.refine: 147 | self.refine_network = RefineNet() 148 | 149 | def forward(self, imgs, proj_matrices, depth_values, hints=None, validhints=None): 150 | imgs = torch.unbind(imgs, 1) 151 | proj_matrices = torch.unbind(proj_matrices, 1) 152 | assert len(imgs) == len( 153 | proj_matrices 154 | ), "Different number of images and projection matrices" 155 | num_depth = depth_values.shape[1] 156 | num_views = len(imgs) 157 | 158 | # step 1. feature extraction 159 | # in: images; out: 32-channel feature maps 160 | features = [self.feature(img) for img in imgs] 161 | ref_feature, src_features = features[0], features[1:] 162 | ref_proj, src_projs = proj_matrices[0], proj_matrices[1:] 163 | 164 | # step 2. differentiable homograph, build cost volume 165 | ref_volume = ref_feature.unsqueeze(2).repeat(1, 1, num_depth, 1, 1) 166 | volume_sum = ref_volume 167 | volume_sq_sum = ref_volume ** 2 168 | del ref_volume 169 | for src_fea, src_proj in zip(src_features, src_projs): 170 | # warpped features 171 | warped_volume = homo_warping(src_fea, src_proj, ref_proj, depth_values) 172 | if self.training: 173 | volume_sum = volume_sum + warped_volume 174 | volume_sq_sum = volume_sq_sum + warped_volume ** 2 175 | else: 176 | # TODO: this is only a temporal solution to save memory, better way? 
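                # note: at inference time the running sums are updated in place so that no
                # extra feature volumes are allocated; volume_variance below is then the
                # biased variance E[x^2] - (E[x])^2 over the num_views feature volumes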
177 | volume_sum += warped_volume 178 | volume_sq_sum += warped_volume.pow_( 179 | 2 180 | ) # the memory of warped_volume has been modified 181 | del warped_volume 182 | # aggregate multiple feature volumes by variance 183 | volume_variance = volume_sq_sum.div_(num_views).sub_(volume_sum.div_(num_views).pow_(2)) 184 | 185 | # TODO cost-volume modulation 186 | if hints is not None and validhints is not None: 187 | batch_size, feats, height, width = ref_feature.shape 188 | GAUSSIAN_HEIGHT = 10.0 189 | GAUSSIAN_WIDTH = 1.0 190 | 191 | # image features are one fourth the original size: subsample the hints and divide them by four 192 | hints = hints 193 | hints = F.interpolate(hints, scale_factor=0.25, mode="nearest").unsqueeze(1) 194 | validhints = validhints 195 | validhints = F.interpolate(validhints, scale_factor=0.25, mode="nearest").unsqueeze(1) 196 | hints = hints * validhints 197 | 198 | # add feature and disparity dimensions to hints and validhints 199 | # and repeat their values along those dimensions, to obtain the same size as cost 200 | hints = hints.expand(-1, feats, num_depth, -1, -1) 201 | validhints = validhints.expand(-1, feats, num_depth, -1, -1) 202 | 203 | # create a tensor of the same size as cost, with disparities 204 | # between 0 and num_disp-1 along the disparity dimension 205 | depth_hyps = ( 206 | depth_values.unsqueeze(1) 207 | .unsqueeze(3) 208 | .unsqueeze(4) 209 | .expand(batch_size, feats, -1, height, width) 210 | .detach() 211 | ) 212 | volume_variance = volume_variance * ( 213 | (1 - validhints) 214 | + validhints 215 | * GAUSSIAN_HEIGHT 216 | * (1 - torch.exp(-((depth_hyps - hints) ** 2) / (2 * GAUSSIAN_WIDTH ** 2))) 217 | ) 218 | 219 | # step 3. cost volume regularization 220 | cost_reg = self.cost_regularization(volume_variance) 221 | # cost_reg = F.upsample(cost_reg, [num_depth * 4, img_height, img_width], mode='trilinear') 222 | cost_reg = cost_reg.squeeze(1) 223 | prob_volume = F.softmax(-cost_reg, dim=1) 224 | depth = depth_regression(prob_volume, depth_values=depth_values).unsqueeze(1) 225 | depth = F.interpolate(depth, scale_factor=4, mode="bilinear") 226 | 227 | with torch.no_grad(): 228 | # photometric confidence 229 | prob_volume_sum4 = ( 230 | 4 231 | * F.avg_pool3d( 232 | F.pad(prob_volume.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)), 233 | (4, 1, 1), 234 | stride=1, 235 | padding=0, 236 | ).squeeze(1) 237 | ) 238 | depth_index = depth_regression( 239 | prob_volume, 240 | depth_values=torch.arange(num_depth, device=prob_volume.device, dtype=torch.float), 241 | ).long() 242 | photometric_confidence = torch.gather( 243 | prob_volume_sum4, 1, depth_index.unsqueeze(1) 244 | ).squeeze(1) 245 | photometric_confidence = F.interpolate( 246 | photometric_confidence[None], scale_factor=4, mode="bilinear" 247 | )[0] 248 | 249 | depth_dict = {} 250 | depth_dict["stage_0"] = depth 251 | 252 | # step 4. 
depth map refinement
253 |         if not self.refine:
254 |             return {
255 |                 "depth": depth_dict,
256 |                 "photometric_confidence": photometric_confidence,
257 |                 "loss_data": depth_dict,
258 |             }
259 |         else:
260 |             refined_depth = self.refine_network(imgs[0], depth)
261 |             return {
262 |                 "depth": refined_depth,
263 |                 "photometric_confidence": photometric_confidence,
264 |                 "loss_data": refined_depth,
265 |             }
266 |
267 |
268 | def mvsnet_loss(depth_est, depth_gt, mask):
269 |     mask = mask["stage_0"] > 0.5
270 |     return F.smooth_l1_loss(
271 |         depth_est["stage_0"][mask], depth_gt["stage_0"][mask], reduction="mean"
272 |     )
273 |
-------------------------------------------------------------------------------- /guided_mvs_lib/models/patchmatchnet/__init__.py: --------------------------------------------------------------------------------
1 | import warnings
2 | from types import SimpleNamespace
3 | from typing import Any, Optional
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch import Tensor
9 |
10 | from .net import PatchmatchNet, patchmatchnet_loss
11 |
12 | __all__ = ["PatchmatchNet", "patchmatchnet_loss", "NetBuilder", "SimpleInterfaceNet"]
13 |
14 |
15 | class SimpleInterfaceNet(nn.Module):
16 |     """
17 |     Simple common interface to call the pretrained models
18 |     """
19 |
20 |     def __init__(self, **kwargs):
21 |         super().__init__()
22 |
23 |         default_args = dict(
24 |             patchmatch_interval_scale=[0.005, 0.0125, 0.025],
25 |             propagation_range=[6, 4, 2],
26 |             patchmatch_iteration=[1, 2, 2],
27 |             patchmatch_num_sample=[8, 8, 16],
28 |             propagate_neighbors=[0, 8, 16],
29 |             evaluate_neighbors=[9, 9, 9],
30 |         )
31 |         default_args.update(kwargs)
32 |         self.model = PatchmatchNet(**default_args)
33 |         self.all_outputs: dict[str, Any] = {}
34 |
35 |     def forward(
36 |         self,
37 |         imgs: Tensor,
38 |         intrinsics: Tensor,
39 |         extrinsics: Tensor,
40 |         depth_values: Tensor,
41 |         hints: Optional[Tensor] = None,
42 |     ):
43 |         with warnings.catch_warnings():
44 |             warnings.simplefilter("ignore", UserWarning)
45 |
46 |             # compute poses
47 |             proj_matrices = {}
48 |             for i in range(4):
49 |                 proj_mat = extrinsics.clone()
50 |                 intrinsics_copy = intrinsics.clone()
51 |                 intrinsics_copy[..., :2, :] = intrinsics_copy[..., :2, :] / (2 ** i)
52 |                 proj_mat[..., :3, :4] = intrinsics_copy @ proj_mat[..., :3, :4]
53 |                 proj_matrices[f"stage_{i}"] = proj_mat
54 |
55 |             # downsample images
56 |             imgs_stages = {}
57 |             h, w = imgs.shape[-2:]
58 |             for i in range(4):
59 |                 dsize = h // (2 ** i), w // (2 ** i)
60 |                 imgs_stages[f"stage_{i}"] = torch.stack(
61 |                     [
62 |                         F.interpolate(img, dsize, mode="bilinear", align_corners=True)
63 |                         for img in torch.unbind(imgs, 1)
64 |                     ],
65 |                     1,
66 |                 )
67 |
68 |             # validhints
69 |             validhints = None
70 |             if hints is not None:
71 |                 validhints = (hints > 0).to(torch.float32)
72 |
73 |             # call
74 |             out = self.model(
75 |                 imgs_stages,
76 |                 proj_matrices,
77 |                 depth_values.min(1).values,
78 |                 depth_values.max(1).values,
79 |                 hints,
80 |                 validhints,
81 |             )
82 |             self.all_outputs = out
83 |             return out["depth"]["stage_0"]
84 |
85 |
86 | class NetBuilder(nn.Module):
87 |     def __init__(self, args: SimpleNamespace):
88 |         super().__init__()
89 |
90 |         self.model = PatchmatchNet(
91 |             patchmatch_interval_scale=[0.005, 0.0125, 0.025],
92 |             propagation_range=[6, 4, 2],
93 |             patchmatch_iteration=[1, 2, 2],
94 |             patchmatch_num_sample=[8, 8, 16],
95 |             propagate_neighbors=[0, 8, 16],
96 |             evaluate_neighbors=[9, 9, 9],
97 |         )
98 |
99 |         def loss_func(loss_data, depth_gt, mask):
100 |
return patchmatchnet_loss( 101 | loss_data["depth_patchmatch"], 102 | loss_data["refined_depth"], 103 | depth_gt, 104 | mask, 105 | ) 106 | 107 | self.loss = loss_func 108 | 109 | self.hparams = { 110 | "interval_scale": [0.005, 0.0125, 0.025], 111 | "propagation_range": [6, 4, 2], 112 | "iteration": [1, 2, 2], 113 | "num_sample": [8, 8, 16], 114 | "propagate_neighbors": [0, 8, 16], 115 | "evaluate_neighbors": [9, 9, 9], 116 | } 117 | 118 | def forward(self, batch: dict): 119 | 120 | hints, validhints = None, None 121 | if "hints" in batch: 122 | hints = batch["hints"] 123 | validhints = (hints > 0).to(torch.float32) 124 | 125 | return self.model( 126 | batch["imgs"], 127 | batch["proj_matrices"], 128 | batch["depth_min"], 129 | batch["depth_max"], 130 | hints, 131 | validhints, 132 | ) 133 | -------------------------------------------------------------------------------- /guided_mvs_lib/models/patchmatchnet/module.py: -------------------------------------------------------------------------------- 1 | """ 2 | Implementation of Pytorch layer primitives, such as Conv+BN+ReLU, differentiable warping layers, 3 | and depth regression based upon expectation of an input probability distribution. 4 | """ 5 | 6 | import torch 7 | import torch.nn as nn 8 | import torch.nn.functional as F 9 | 10 | 11 | class ConvBnReLU(nn.Module): 12 | """Implements 2d Convolution + batch normalization + ReLU""" 13 | 14 | def __init__( 15 | self, 16 | in_channels: int, 17 | out_channels: int, 18 | kernel_size: int = 3, 19 | stride: int = 1, 20 | pad: int = 1, 21 | dilation: int = 1, 22 | ) -> None: 23 | """initialization method for convolution2D + batch normalization + relu module 24 | Args: 25 | in_channels: input channel number of convolution layer 26 | out_channels: output channel number of convolution layer 27 | kernel_size: kernel size of convolution layer 28 | stride: stride of convolution layer 29 | pad: pad of convolution layer 30 | dilation: dilation of convolution layer 31 | """ 32 | super(ConvBnReLU, self).__init__() 33 | self.conv = nn.Conv2d( 34 | in_channels, 35 | out_channels, 36 | kernel_size, 37 | stride=stride, 38 | padding=pad, 39 | dilation=dilation, 40 | bias=False, 41 | ) 42 | self.bn = nn.BatchNorm2d(out_channels) 43 | 44 | def forward(self, x: torch.Tensor) -> torch.Tensor: 45 | """forward method""" 46 | return F.relu(self.bn(self.conv(x)), inplace=True) 47 | 48 | 49 | class ConvBnReLU3D(nn.Module): 50 | """Implements of 3d convolution + batch normalization + ReLU.""" 51 | 52 | def __init__( 53 | self, 54 | in_channels: int, 55 | out_channels: int, 56 | kernel_size: int = 3, 57 | stride: int = 1, 58 | pad: int = 1, 59 | dilation: int = 1, 60 | ) -> None: 61 | """initialization method for convolution3D + batch normalization + relu module 62 | Args: 63 | in_channels: input channel number of convolution layer 64 | out_channels: output channel number of convolution layer 65 | kernel_size: kernel size of convolution layer 66 | stride: stride of convolution layer 67 | pad: pad of convolution layer 68 | dilation: dilation of convolution layer 69 | """ 70 | super(ConvBnReLU3D, self).__init__() 71 | self.conv = nn.Conv3d( 72 | in_channels, 73 | out_channels, 74 | kernel_size, 75 | stride=stride, 76 | padding=pad, 77 | dilation=dilation, 78 | bias=False, 79 | ) 80 | self.bn = nn.BatchNorm3d(out_channels) 81 | 82 | def forward(self, x: torch.Tensor) -> torch.Tensor: 83 | """forward method""" 84 | return F.relu(self.bn(self.conv(x)), inplace=True) 85 | 86 | 87 | class ConvBnReLU1D(nn.Module): 88 | 
"""Implements 1d Convolution + batch normalization + ReLU.""" 89 | 90 | def __init__( 91 | self, 92 | in_channels: int, 93 | out_channels: int, 94 | kernel_size: int = 3, 95 | stride: int = 1, 96 | pad: int = 1, 97 | dilation: int = 1, 98 | ) -> None: 99 | """initialization method for convolution1D + batch normalization + relu module 100 | Args: 101 | in_channels: input channel number of convolution layer 102 | out_channels: output channel number of convolution layer 103 | kernel_size: kernel size of convolution layer 104 | stride: stride of convolution layer 105 | pad: pad of convolution layer 106 | dilation: dilation of convolution layer 107 | """ 108 | super(ConvBnReLU1D, self).__init__() 109 | self.conv = nn.Conv1d( 110 | in_channels, 111 | out_channels, 112 | kernel_size, 113 | stride=stride, 114 | padding=pad, 115 | dilation=dilation, 116 | bias=False, 117 | ) 118 | self.bn = nn.BatchNorm1d(out_channels) 119 | 120 | def forward(self, x: torch.Tensor) -> torch.Tensor: 121 | """forward method""" 122 | return F.relu(self.bn(self.conv(x)), inplace=True) 123 | 124 | 125 | class ConvBn(nn.Module): 126 | """Implements of 2d convolution + batch normalization.""" 127 | 128 | def __init__( 129 | self, 130 | in_channels: int, 131 | out_channels: int, 132 | kernel_size: int = 3, 133 | stride: int = 1, 134 | pad: int = 1, 135 | ) -> None: 136 | """initialization method for convolution2D + batch normalization + ReLU module 137 | Args: 138 | in_channels: input channel number of convolution layer 139 | out_channels: output channel number of convolution layer 140 | kernel_size: kernel size of convolution layer 141 | stride: stride of convolution layer 142 | pad: pad of convolution layer 143 | """ 144 | super(ConvBn, self).__init__() 145 | self.conv = nn.Conv2d( 146 | in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False 147 | ) 148 | self.bn = nn.BatchNorm2d(out_channels) 149 | 150 | def forward(self, x: torch.Tensor) -> torch.Tensor: 151 | """forward method""" 152 | return self.bn(self.conv(x)) 153 | 154 | 155 | def differentiable_warping( 156 | src_fea: torch.Tensor, 157 | src_proj: torch.Tensor, 158 | ref_proj: torch.Tensor, 159 | depth_samples: torch.Tensor, 160 | ): 161 | """Differentiable homography-based warping, implemented in Pytorch. 
162 | 163 | Args: 164 | src_fea: [B, C, H, W] source features, for each source view in batch 165 | src_proj: [B, 4, 4] source camera projection matrix, for each source view in batch 166 | ref_proj: [B, 4, 4] reference camera projection matrix, for each ref view in batch 167 | depth_samples: [B, Ndepth, H, W] virtual depth layers 168 | Returns: 169 | warped_src_fea: [B, C, Ndepth, H, W] features on depths after perspective transformation 170 | """ 171 | 172 | batch, channels, height, width = src_fea.shape 173 | num_depth = depth_samples.shape[1] 174 | 175 | with torch.no_grad(): 176 | proj = torch.matmul(src_proj, torch.inverse(ref_proj)) 177 | rot = proj[:, :3, :3] # [B,3,3] 178 | trans = proj[:, :3, 3:4] # [B,3,1] 179 | 180 | y, x = torch.meshgrid( 181 | [ 182 | torch.arange(0, height, dtype=torch.float32, device=src_fea.device), 183 | torch.arange(0, width, dtype=torch.float32, device=src_fea.device), 184 | ] 185 | ) 186 | y, x = y.contiguous(), x.contiguous() 187 | y, x = y.view(height * width), x.view(height * width) 188 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W] 189 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W] 190 | rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W] 191 | 192 | rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_samples.view( 193 | batch, 1, num_depth, height * width 194 | ) # [B, 3, Ndepth, H*W] 195 | proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W] 196 | # avoid negative depth 197 | negative_depth_mask = proj_xyz[:, 2:] <= 1e-3 198 | proj_xyz[:, 0:1][negative_depth_mask] = width 199 | proj_xyz[:, 1:2][negative_depth_mask] = height 200 | proj_xyz[:, 2:3][negative_depth_mask] = 1 201 | proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W] 202 | proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1 # [B, Ndepth, H*W] 203 | proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1 204 | proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2] 205 | grid = proj_xy 206 | 207 | warped_src_fea = F.grid_sample( 208 | src_fea, 209 | grid.view(batch, num_depth * height, width, 2), 210 | mode="bilinear", 211 | padding_mode="zeros", 212 | align_corners=True, 213 | ) 214 | 215 | warped_src_fea = warped_src_fea.view(batch, channels, num_depth, height, width) 216 | 217 | return warped_src_fea 218 | 219 | 220 | def depth_regression(p: torch.Tensor, depth_values: torch.Tensor) -> torch.Tensor: 221 | """Implements per-pixel depth regression based upon a probability distribution per-pixel. 222 | 223 | The regressed depth value D(p) at pixel p is found as the expectation w.r.t. P of the hypotheses. 
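    In symbols, depth(x) = sum_j depth_values[j] * p[j, x], i.e. a soft argmin over the D hypotheses.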
224 | 225 | Args: 226 | p: probability volume [B, D, H, W] 227 | depth_values: discrete depth values [B, D] 228 | Returns: 229 | result depth: expected value, soft argmin [B, 1, H, W] 230 | """ 231 | 232 | depth_values = depth_values.view(*depth_values.shape, 1, 1) 233 | depth = torch.sum(p * depth_values, dim=1) 234 | depth = depth.unsqueeze(1) 235 | return depth 236 | 237 | 238 | def depth_regression_1(p: torch.Tensor, depth_values: torch.Tensor) -> torch.Tensor: 239 | """another version of depth regression function 240 | Args: 241 | p: probability volume [B, D, H, W] 242 | depth_values: discrete depth values [B, D] 243 | Returns: 244 | result depth: expected value, soft argmin [B, 1, H, W] 245 | """ 246 | 247 | depth = torch.sum(p * depth_values, 1) 248 | depth = depth.unsqueeze(1) 249 | return depth 250 | -------------------------------------------------------------------------------- /guided_mvs_lib/models/patchmatchnet/net.py: -------------------------------------------------------------------------------- 1 | from typing import Any, Dict, List 2 | 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.functional as F 6 | 7 | from .module import ConvBnReLU, depth_regression 8 | from .patchmatch import PatchMatch 9 | 10 | 11 | class FeatureNet(nn.Module): 12 | """Feature Extraction Network: to extract features of original images from each view""" 13 | 14 | def __init__(self): 15 | """Initialize different layers in the network""" 16 | 17 | super(FeatureNet, self).__init__() 18 | 19 | self.conv0 = ConvBnReLU(3, 8, 3, 1, 1) 20 | # [B,8,H,W] 21 | self.conv1 = ConvBnReLU(8, 8, 3, 1, 1) 22 | # [B,16,H/2,W/2] 23 | self.conv2 = ConvBnReLU(8, 16, 5, 2, 2) 24 | self.conv3 = ConvBnReLU(16, 16, 3, 1, 1) 25 | self.conv4 = ConvBnReLU(16, 16, 3, 1, 1) 26 | # [B,32,H/4,W/4] 27 | self.conv5 = ConvBnReLU(16, 32, 5, 2, 2) 28 | self.conv6 = ConvBnReLU(32, 32, 3, 1, 1) 29 | self.conv7 = ConvBnReLU(32, 32, 3, 1, 1) 30 | # [B,64,H/8,W/8] 31 | self.conv8 = ConvBnReLU(32, 64, 5, 2, 2) 32 | self.conv9 = ConvBnReLU(64, 64, 3, 1, 1) 33 | self.conv10 = ConvBnReLU(64, 64, 3, 1, 1) 34 | 35 | self.output1 = nn.Conv2d(64, 64, 1, bias=False) 36 | self.inner1 = nn.Conv2d(32, 64, 1, bias=True) 37 | self.inner2 = nn.Conv2d(16, 64, 1, bias=True) 38 | self.output2 = nn.Conv2d(64, 32, 1, bias=False) 39 | self.output3 = nn.Conv2d(64, 16, 1, bias=False) 40 | 41 | def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]: 42 | """Forward method 43 | 44 | Args: 45 | x: images from a single view, in the shape of [B, C, H, W]. 
Generally, C=3 46 | 47 | Returns: 48 | output_feature: a python dictionary contains extracted features from stage_1 to stage_3 49 | keys are "stage_1", "stage_2", and "stage_3" 50 | """ 51 | output_feature = {} 52 | 53 | conv1 = self.conv1(self.conv0(x)) 54 | conv4 = self.conv4(self.conv3(self.conv2(conv1))) 55 | 56 | conv7 = self.conv7(self.conv6(self.conv5(conv4))) 57 | conv10 = self.conv10(self.conv9(self.conv8(conv7))) 58 | 59 | output_feature["stage_3"] = self.output1(conv10) 60 | 61 | intra_feat = F.interpolate(conv10, scale_factor=2, mode="bilinear") + self.inner1(conv7) 62 | del conv7, conv10 63 | output_feature["stage_2"] = self.output2(intra_feat) 64 | 65 | intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear") + self.inner2( 66 | conv4 67 | ) 68 | del conv4 69 | output_feature["stage_1"] = self.output3(intra_feat) 70 | 71 | del intra_feat 72 | return output_feature 73 | 74 | 75 | class Refinement(nn.Module): 76 | """Depth map refinement network""" 77 | 78 | def __init__(self): 79 | """Initialize""" 80 | 81 | super(Refinement, self).__init__() 82 | 83 | # img: [B,3,H,W] 84 | self.conv0 = ConvBnReLU(in_channels=3, out_channels=8) 85 | # depth map:[B,1,H/2,W/2] 86 | self.conv1 = ConvBnReLU(in_channels=1, out_channels=8) 87 | self.conv2 = ConvBnReLU(in_channels=8, out_channels=8) 88 | self.deconv = nn.ConvTranspose2d( 89 | in_channels=8, 90 | out_channels=8, 91 | kernel_size=3, 92 | padding=1, 93 | output_padding=1, 94 | stride=2, 95 | bias=False, 96 | ) 97 | 98 | self.bn = nn.BatchNorm2d(8) 99 | self.conv3 = ConvBnReLU(in_channels=16, out_channels=8) 100 | self.res = nn.Conv2d(in_channels=8, out_channels=1, kernel_size=3, padding=1, bias=False) 101 | 102 | def forward( 103 | self, 104 | img: torch.Tensor, 105 | depth_0: torch.Tensor, 106 | depth_min: torch.Tensor, 107 | depth_max: torch.Tensor, 108 | ) -> torch.Tensor: 109 | """Forward method 110 | 111 | Args: 112 | img: input reference images (B, 3, H, W) 113 | depth_0: current depth map (B, 1, H//2, W//2) 114 | depth_min: pre-defined minimum depth (B, ) 115 | depth_max: pre-defined maximum depth (B, ) 116 | 117 | Returns: 118 | depth: refined depth map (B, 1, H, W) 119 | """ 120 | 121 | batch_size = depth_min.size()[0] 122 | # pre-scale the depth map into [0,1] 123 | depth = (depth_0 - depth_min.view(batch_size, 1, 1, 1)) / ( 124 | depth_max.view(batch_size, 1, 1, 1) - depth_min.view(batch_size, 1, 1, 1) 125 | ) 126 | 127 | conv0 = self.conv0(img) 128 | deconv = F.relu(self.bn(self.deconv(self.conv2(self.conv1(depth)))), inplace=True) 129 | cat = torch.cat((deconv, conv0), dim=1) 130 | del deconv, conv0 131 | # depth residual 132 | res = self.res(self.conv3(cat)) 133 | del cat 134 | 135 | depth = F.interpolate(depth, scale_factor=2, mode="nearest") + res 136 | # convert the normalized depth back 137 | depth = depth * ( 138 | depth_max.view(batch_size, 1, 1, 1) - depth_min.view(batch_size, 1, 1, 1) 139 | ) + depth_min.view(batch_size, 1, 1, 1) 140 | 141 | return depth 142 | 143 | 144 | class PatchmatchNet(nn.Module): 145 | """Implementation of complete structure of PatchmatchNet""" 146 | 147 | def __init__( 148 | self, 149 | patchmatch_interval_scale: List[float] = [0.005, 0.0125, 0.025], 150 | propagation_range: List[int] = [6, 4, 2], 151 | patchmatch_iteration: List[int] = [1, 2, 2], 152 | patchmatch_num_sample: List[int] = [8, 8, 16], 153 | propagate_neighbors: List[int] = [0, 8, 16], 154 | evaluate_neighbors: List[int] = [9, 9, 9], 155 | ) -> None: 156 | """Initialize modules in PatchmatchNet 157 | 158 | Args: 
159 | patchmatch_interval_scale: depth interval scale in patchmatch module 160 | propagation_range: propagation range 161 | patchmatch_iteration: patchmatch interation number 162 | patchmatch_num_sample: patchmatch number of samples 163 | propagate_neighbors: number of propagation neigbors 164 | evaluate_neighbors: number of propagation neigbors for evaluation 165 | """ 166 | super(PatchmatchNet, self).__init__() 167 | 168 | self.stages = 4 169 | self.feature = FeatureNet() 170 | self.patchmatch_num_sample = patchmatch_num_sample 171 | 172 | num_features = [8, 16, 32, 64] 173 | 174 | self.propagate_neighbors = propagate_neighbors 175 | self.evaluate_neighbors = evaluate_neighbors 176 | # number of groups for group-wise correlation 177 | self.G = [4, 8, 8] 178 | 179 | for i in range(self.stages - 1): 180 | 181 | if i == 2: 182 | patchmatch = PatchMatch( 183 | random_initialization=True, 184 | propagation_out_range=propagation_range[i], 185 | patchmatch_iteration=patchmatch_iteration[i], 186 | patchmatch_num_sample=patchmatch_num_sample[i], 187 | patchmatch_interval_scale=patchmatch_interval_scale[i], 188 | num_feature=num_features[i + 1], 189 | G=self.G[i], 190 | propagate_neighbors=self.propagate_neighbors[i], 191 | stage=i + 1, 192 | evaluate_neighbors=evaluate_neighbors[i], 193 | ) 194 | else: 195 | patchmatch = PatchMatch( 196 | random_initialization=False, 197 | propagation_out_range=propagation_range[i], 198 | patchmatch_iteration=patchmatch_iteration[i], 199 | patchmatch_num_sample=patchmatch_num_sample[i], 200 | patchmatch_interval_scale=patchmatch_interval_scale[i], 201 | num_feature=num_features[i + 1], 202 | G=self.G[i], 203 | propagate_neighbors=self.propagate_neighbors[i], 204 | stage=i + 1, 205 | evaluate_neighbors=evaluate_neighbors[i], 206 | ) 207 | setattr(self, f"patchmatch_{i+1}", patchmatch) 208 | 209 | self.upsample_net = Refinement() 210 | 211 | def forward( 212 | self, 213 | imgs: Dict[str, torch.Tensor], 214 | proj_matrices: Dict[str, torch.Tensor], 215 | depth_min: torch.Tensor, 216 | depth_max: torch.Tensor, 217 | hints: torch.Tensor = None, 218 | validhints: torch.Tensor = None, 219 | ) -> Dict[str, Any]: 220 | """Forward method for PatchMatchNet 221 | 222 | Args: 223 | imgs: different stages of images (B, 3, H, W) stored in the dictionary 224 | proj_matrics: different stages of camera projection matrices (B, 4, 4) stored in the dictionary 225 | depth_min: minimum virtual depth (B, ) 226 | depth_max: maximum virtual depth (B, ) 227 | 228 | Returns: 229 | output dictionary of PatchMatchNet, containing refined depthmap, depth patchmatch 230 | and photometric_confidence. 231 | """ 232 | imgs_0 = torch.unbind(imgs["stage_0"], 1) 233 | imgs_1 = torch.unbind(imgs["stage_1"], 1) 234 | imgs_2 = torch.unbind(imgs["stage_2"], 1) 235 | imgs_3 = torch.unbind(imgs["stage_3"], 1) 236 | del imgs 237 | 238 | self.imgs_0_ref = imgs_0[0] 239 | self.imgs_1_ref = imgs_1[0] 240 | self.imgs_2_ref = imgs_2[0] 241 | self.imgs_3_ref = imgs_3[0] 242 | del imgs_1, imgs_2, imgs_3 243 | 244 | self.proj_matrices_0 = torch.unbind(proj_matrices["stage_0"].float(), 1) 245 | self.proj_matrices_1 = torch.unbind(proj_matrices["stage_1"].float(), 1) 246 | self.proj_matrices_2 = torch.unbind(proj_matrices["stage_2"].float(), 1) 247 | self.proj_matrices_3 = torch.unbind(proj_matrices["stage_3"].float(), 1) 248 | del proj_matrices 249 | 250 | assert len(imgs_0) == len( 251 | self.proj_matrices_0 252 | ), "Different number of images and projection matrices" 253 | 254 | # step 1. 
Multi-scale feature extraction 255 | features = [] 256 | for img in imgs_0: 257 | output_feature = self.feature(img) 258 | features.append(output_feature) 259 | del imgs_0 260 | ref_feature, src_features = features[0], features[1:] 261 | 262 | depth_min = depth_min.float() 263 | depth_max = depth_max.float() 264 | 265 | # step 2. Learning-based patchmatch 266 | depth = None 267 | view_weights = None 268 | depth_patchmatch = {} 269 | refined_depth = {} 270 | 271 | for l in reversed(range(1, self.stages)): 272 | src_features_l = [src_fea[f"stage_{l}"] for src_fea in src_features] 273 | projs_l = getattr(self, f"proj_matrices_{l}") 274 | ref_proj, src_projs = projs_l[0], projs_l[1:] 275 | 276 | if l > 1: 277 | depth, _, view_weights = getattr(self, f"patchmatch_{l}")( 278 | ref_feature=ref_feature[f"stage_{l}"], 279 | src_features=src_features_l, 280 | ref_proj=ref_proj, 281 | src_projs=src_projs, 282 | depth_min=depth_min, 283 | depth_max=depth_max, 284 | depth=depth, 285 | img=getattr(self, f"imgs_{l}_ref"), 286 | view_weights=view_weights, 287 | hints=hints, 288 | validhints=validhints, 289 | ) 290 | else: 291 | depth, score, _ = getattr(self, f"patchmatch_{l}")( 292 | ref_feature=ref_feature[f"stage_{l}"], 293 | src_features=src_features_l, 294 | ref_proj=ref_proj, 295 | src_projs=src_projs, 296 | depth_min=depth_min, 297 | depth_max=depth_max, 298 | depth=depth, 299 | img=getattr(self, f"imgs_{l}_ref"), 300 | view_weights=view_weights, 301 | hints=hints, 302 | validhints=validhints, 303 | ) 304 | 305 | del src_features_l, ref_proj, src_projs, projs_l 306 | 307 | depth_patchmatch[f"stage_{l}"] = depth 308 | 309 | depth = depth[-1].detach() 310 | if l > 1: 311 | # upsampling the depth map and pixel-wise view weight for next stage 312 | depth = F.interpolate(depth, scale_factor=2, mode="nearest") 313 | view_weights = F.interpolate(view_weights, scale_factor=2, mode="nearest") 314 | 315 | # step 3. 
Refinement
316 |         depth = self.upsample_net(self.imgs_0_ref, depth, depth_min, depth_max)
317 |         refined_depth["stage_0"] = depth
318 |
319 |         del depth, ref_feature, src_features
320 |
321 |         num_depth = self.patchmatch_num_sample[0]
322 |         score_sum4 = 4 * F.avg_pool3d(
323 |             F.pad(score.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)), (4, 1, 1), stride=1, padding=0
324 |         ).squeeze(1)
325 |         # [B, 1, H, W]
326 |         depth_index = depth_regression(
327 |             score, depth_values=torch.arange(num_depth, device=score.device, dtype=torch.float)
328 |         ).long()
329 |         depth_index = torch.clamp(depth_index, 0, num_depth - 1)
330 |         photometric_confidence = torch.gather(score_sum4, 1, depth_index)
331 |         photometric_confidence = F.interpolate(
332 |             photometric_confidence, scale_factor=2, mode="nearest"
333 |         )
334 |         photometric_confidence = photometric_confidence.squeeze(1)
335 |
336 |         return {
337 |             "depth": refined_depth,
338 |             "loss_data": {
339 |                 "depth_patchmatch": depth_patchmatch,
340 |                 "refined_depth": refined_depth,
341 |             },
342 |             "photometric_confidence": photometric_confidence,
343 |         }
344 |
345 |
346 | def patchmatchnet_loss(
347 |     depth_patchmatch: Dict[str, torch.Tensor],
348 |     refined_depth: Dict[str, torch.Tensor],
349 |     depth_gt: Dict[str, torch.Tensor],
350 |     mask: Dict[str, torch.Tensor],
351 | ) -> torch.Tensor:
352 |     """Patchmatch Net loss function
353 |
354 |     Args:
355 |         depth_patchmatch: depth map predicted by patchmatch net
356 |         refined_depth: refined depth map predicted by patchmatch net
357 |         depth_gt: ground truth depth map
358 |         mask: mask to filter valid points
359 |
360 |     Returns:
361 |         loss: resulting loss value
362 |     """
363 |     stage = 4
364 |
365 |     loss = 0
366 |     for l in range(1, stage):
367 |         depth_gt_l = depth_gt[f"stage_{l}"]
368 |         mask_l = mask[f"stage_{l}"] > 0.5
369 |         depth2 = depth_gt_l[mask_l]
370 |
371 |         depth_patchmatch_l = depth_patchmatch[f"stage_{l}"]
372 |         for i in range(len(depth_patchmatch_l)):
373 |             depth1 = depth_patchmatch_l[i][mask_l]
374 |             loss = loss + F.smooth_l1_loss(depth1, depth2, reduction="mean")
375 |
376 |     l = 0
377 |     depth_refined_l = refined_depth[f"stage_{l}"]
378 |     depth_gt_l = depth_gt[f"stage_{l}"]
379 |     mask_l = mask[f"stage_{l}"] > 0.5
380 |
381 |     depth1 = depth_refined_l[mask_l]
382 |     depth2 = depth_gt_l[mask_l]
383 |     loss = loss + F.smooth_l1_loss(depth1, depth2, reduction="mean")
384 |
385 |     return loss
386 |
-------------------------------------------------------------------------------- /guided_mvs_lib/models/ucsnet/__init__.py: --------------------------------------------------------------------------------
1 | import warnings
2 | from types import SimpleNamespace
3 | from typing import Any, Optional
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch import Tensor
8 |
9 | from .ucsnet import UCSNet, ucsnet_loss
10 |
11 | __all__ = ["ucsnet_loss", "UCSNet", "NetBuilder", "SimpleInterfaceNet"]
12 |
13 |
14 | class SimpleInterfaceNet(nn.Module):
15 |     """
16 |     Simple common interface to call the pretrained models
17 |     """
18 |
19 |     def __init__(self, **kwargs):
20 |         super().__init__()
21 |         self.model = UCSNet(**kwargs)
22 |         self.all_outputs: dict[str, Any] = {}
23 |
24 |     def forward(
25 |         self,
26 |         imgs: Tensor,
27 |         intrinsics: Tensor,
28 |         extrinsics: Tensor,
29 |         depth_values: Tensor,
30 |         hints: Optional[Tensor] = None,
31 |     ):
32 |         with warnings.catch_warnings():
33 |             warnings.simplefilter("ignore", UserWarning)
34 |
35 |             # compute poses
36 |             proj_matrices = {}
37 |             for i in range(4):
38 |                 proj_mat = extrinsics.clone()
39 |
intrinsics_copy = intrinsics.clone() 40 | intrinsics_copy[..., :2, :] = intrinsics_copy[..., :2, :] / (2 ** i) 41 | proj_mat[..., :3, :4] = intrinsics_copy @ proj_mat[..., :3, :4] 42 | proj_matrices[f"stage_{i}"] = proj_mat 43 | 44 | # validhints 45 | validhints = None 46 | if hints is not None: 47 | validhints = (hints > 0).to(torch.float32) 48 | 49 | # call 50 | out = self.model(imgs, proj_matrices, depth_values, hints, validhints) 51 | self.all_outputs = out 52 | return out["depth"]["stage_0"] 53 | 54 | 55 | class NetBuilder(nn.Module): 56 | def __init__(self, args: SimpleNamespace): 57 | super().__init__() 58 | self.model = UCSNet() 59 | 60 | def loss_func(loss_data, depth_gt, mask): 61 | return ucsnet_loss( 62 | loss_data["depth"], 63 | depth_gt, 64 | mask, 65 | ) 66 | 67 | self.loss = loss_func 68 | 69 | def forward(self, batch: dict): 70 | 71 | hints, validhints = None, None 72 | if "hints" in batch: 73 | hints = batch["hints"] 74 | validhints = (hints > 0).to(torch.float32) 75 | 76 | return self.model( 77 | batch["imgs"]["stage_0"], 78 | batch["proj_matrices"], 79 | batch["depth_values"], 80 | hints, 81 | validhints, 82 | ) 83 | -------------------------------------------------------------------------------- /guided_mvs_lib/models/ucsnet/ucsnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | from .submodules import * 6 | 7 | 8 | def compute_depth( 9 | feats, 10 | proj_mats, 11 | depth_samps, 12 | cost_reg, 13 | lamb, 14 | is_training=False, 15 | hints=None, 16 | validhints=None, 17 | scale=1, 18 | ): 19 | """ 20 | :param feats: [(B, C, H, W), ] * num_views 21 | :param proj_mats: [()] 22 | :param depth_samps: 23 | :param cost_reg: 24 | :param lamb: 25 | :return: 26 | """ 27 | 28 | # print(proj_mats.shape) 29 | proj_mats = torch.unbind(proj_mats, 1) 30 | num_views = len(feats) 31 | num_depth = depth_samps.shape[1] 32 | 33 | assert len(proj_mats) == num_views, "Different number of images and projection matrices" 34 | 35 | ref_feat, src_feats = feats[0], feats[1:] 36 | ref_proj, src_projs = proj_mats[0], proj_mats[1:] 37 | 38 | ref_volume = ref_feat.unsqueeze(2).repeat(1, 1, num_depth, 1, 1) 39 | volume_sum = ref_volume 40 | volume_sq_sum = ref_volume ** 2 41 | del ref_volume 42 | 43 | # todo optimize impl 44 | for src_fea, src_proj in zip(src_feats, src_projs): 45 | warped_volume = homo_warping(src_fea, src_proj, ref_proj, depth_samps) 46 | 47 | if is_training: 48 | volume_sum = volume_sum + warped_volume 49 | volume_sq_sum = volume_sq_sum + warped_volume ** 2 50 | else: 51 | volume_sum += warped_volume 52 | volume_sq_sum += warped_volume.pow_(2) # in_place method 53 | del warped_volume 54 | volume_variance = volume_sq_sum.div_(num_views).sub_(volume_sum.div_(num_views).pow_(2)) 55 | 56 | # TODO cost-volume modulation 57 | if hints is not None and validhints is not None: 58 | batch_size, feats, height, width = ref_feat.shape 59 | GAUSSIAN_HEIGHT = 10.0 60 | GAUSSIAN_WIDTH = 1.0 61 | 62 | # image features are one fourth the original size: subsample the hints and divide them by four 63 | hints = hints 64 | hints = F.interpolate(hints, scale_factor=1 / scale, mode="nearest").unsqueeze(1) 65 | validhints = validhints 66 | validhints = F.interpolate(validhints, scale_factor=1 / scale, mode="nearest").unsqueeze(1) 67 | hints = hints * validhints 68 | 69 | # add feature and disparity dimensions to hints and validhints 70 | # and repeat their values along those 
dimensions, to obtain the same size as cost 71 | hints = hints.expand(-1, feats, num_depth, -1, -1) 72 | validhints = validhints.expand(-1, feats, num_depth, -1, -1) 73 | 74 | # create a tensor of the same size as cost, with disparities 75 | # between 0 and num_disp-1 along the disparity dimension 76 | depth_hyps = depth_samps.unsqueeze(1).expand(batch_size, feats, -1, height, width).detach() 77 | volume_variance = volume_variance * ( 78 | (1 - validhints) 79 | + validhints 80 | * GAUSSIAN_HEIGHT 81 | * (1 - torch.exp(-((depth_hyps - hints) ** 2) / (2 * GAUSSIAN_WIDTH ** 2))) 82 | ) 83 | 84 | prob_volume_pre = cost_reg(volume_variance).squeeze(1) 85 | prob_volume = F.softmax(prob_volume_pre, dim=1) 86 | depth = depth_regression(prob_volume, depth_values=depth_samps) 87 | 88 | with torch.no_grad(): 89 | prob_volume_sum4 = 4 * F.avg_pool3d( 90 | F.pad(prob_volume.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)), (4, 1, 1), stride=1, padding=0 91 | ).squeeze(1) 92 | depth_index = depth_regression( 93 | prob_volume, 94 | depth_values=torch.arange(num_depth, device=prob_volume.device, dtype=torch.float), 95 | ).long() 96 | depth_index = depth_index.clamp(min=0, max=num_depth - 1) 97 | prob_conf = torch.gather(prob_volume_sum4, 1, depth_index.unsqueeze(1)).squeeze(1) 98 | 99 | samp_variance = (depth_samps - depth.unsqueeze(1)) ** 2 100 | exp_variance = lamb * torch.sum(samp_variance * prob_volume, dim=1, keepdim=False) ** 0.5 101 | 102 | return {"depth": depth, "photometric_confidence": prob_conf, "variance": exp_variance} 103 | 104 | 105 | class UCSNet(nn.Module): 106 | def __init__( 107 | self, 108 | lamb=1.5, 109 | stage_configs=[64, 32, 8], 110 | grad_method="detach", 111 | base_chs=[8, 8, 8], 112 | feat_ext_ch=8, 113 | ): 114 | super(UCSNet, self).__init__() 115 | 116 | self.stage_configs = stage_configs 117 | self.grad_method = grad_method 118 | self.base_chs = base_chs 119 | self.lamb = lamb 120 | self.num_stage = len(stage_configs) 121 | self.ds_ratio = { 122 | "stage_2": 4.0, # "stage_1": 4.0, 123 | "stage_1": 2.0, # "stage_2": 2.0, 124 | "stage_0": 1.0, # "stage_3": 1.0 125 | } 126 | 127 | self.feature_extraction = FeatExtNet( 128 | base_channels=feat_ext_ch, 129 | num_stage=self.num_stage, 130 | ) 131 | 132 | self.cost_regularization = nn.ModuleList( 133 | [ 134 | CostRegNet( 135 | in_channels=self.feature_extraction.out_channels[i], 136 | base_channels=self.base_chs[i], 137 | ) 138 | for i in [2, 1, 0] 139 | ] 140 | ) # range(self.num_stage)]) 141 | 142 | def forward(self, imgs, proj_matrices, depth_values, hints=None, validhints=None): 143 | features = [] 144 | for nview_idx in range(imgs.shape[1]): 145 | img = imgs[:, nview_idx] 146 | features.append(self.feature_extraction(img)) 147 | 148 | outputs = {} 149 | depth, cur_depth, exp_var = None, None, None 150 | for stage_idx in [2, 1, 0]: # range(self.num_stage): 151 | features_stage = [feat["stage_{}".format(stage_idx)] for feat in features] 152 | proj_matrices_stage = proj_matrices["stage_{}".format(stage_idx)] # + 1)] 153 | stage_scale = self.ds_ratio["stage_{}".format(stage_idx)] # + 1)] 154 | cur_h = img.shape[2] // int(stage_scale) 155 | cur_w = img.shape[3] // int(stage_scale) 156 | 157 | if depth is not None: 158 | if self.grad_method == "detach": 159 | cur_depth = depth.detach() 160 | exp_var = exp_var.detach() 161 | else: 162 | cur_depth = depth 163 | 164 | cur_depth = F.interpolate(cur_depth.unsqueeze(1), [cur_h, cur_w], mode="bilinear") 165 | exp_var = F.interpolate(exp_var.unsqueeze(1), [cur_h, cur_w], mode="bilinear") 166 | 167 | 
else: 168 | cur_depth = depth_values 169 | 170 | depth_range_samples = uncertainty_aware_samples( 171 | cur_depth=cur_depth, 172 | exp_var=exp_var, 173 | ndepth=self.stage_configs[2 - stage_idx], 174 | dtype=img[0].dtype, 175 | device=img[0].device, 176 | shape=[img.shape[0], cur_h, cur_w], 177 | ) 178 | 179 | outputs_stage = compute_depth( 180 | features_stage, 181 | proj_matrices_stage, 182 | depth_samps=depth_range_samples, 183 | cost_reg=self.cost_regularization[stage_idx], 184 | lamb=self.lamb, 185 | is_training=self.training, 186 | hints=hints, 187 | validhints=validhints, 188 | scale=stage_scale, 189 | ) 190 | 191 | depth = outputs_stage["depth"] 192 | exp_var = outputs_stage["variance"] 193 | 194 | outputs["stage_{}".format(stage_idx)] = outputs_stage["depth"].unsqueeze(1) 195 | 196 | return { 197 | "depth": outputs, 198 | "loss_data": { 199 | "depth": outputs, 200 | }, 201 | "photometric_confidence": outputs_stage["photometric_confidence"], 202 | } 203 | 204 | 205 | def ucsnet_loss(outputs, labels, masks, weights=[0.5, 1.0, 2.0]): 206 | tot_loss = 0 207 | for stage_id in [0, 1, 2]: # range(3): 208 | depth_i = outputs["stage_{}".format(stage_id)] # .unsqueeze(1) 209 | label_i = labels["stage_{}".format(stage_id)] 210 | mask_i = masks["stage_{}".format(stage_id)].bool() 211 | depth_loss = F.smooth_l1_loss(depth_i[mask_i], label_i[mask_i], reduction="mean") 212 | tot_loss += depth_loss * weights[2 - stage_id] 213 | return tot_loss 214 | -------------------------------------------------------------------------------- /hubconf.py: -------------------------------------------------------------------------------- 1 | import tarfile 2 | from multiprocessing import cpu_count 3 | from pathlib import Path 4 | from typing import Literal, Tuple 5 | 6 | import torch 7 | import yaml 8 | 9 | from guided_mvs_lib import __version__ as CURR_VERS 10 | from guided_mvs_lib import models 11 | 12 | dependencies = ["torch", "torchvision"] 13 | 14 | ## entry points for each single model 15 | 16 | CURR_DIR = Path(__file__).parent 17 | 18 | 19 | def _get_archive() -> str: 20 | if not (CURR_DIR / "trained_models.tar.gz").exists(): 21 | torch.hub.download_url_to_file( 22 | f"https://github.com/andreaconti/multi-view-guided-multi-view-stereo/releases/download/v{CURR_VERS}/trained_models.tar.gz", 23 | str(CURR_DIR / "trained_models.tar.gz"), 24 | ) 25 | return str(CURR_DIR / "trained_models.tar.gz") 26 | 27 | 28 | def _load_model( 29 | tarpath: str, 30 | model: Literal["ucsnet", "d2hc_rmvsnet", "mvsnet", "patchmatchnet", "cas_mvsnet"] = "mvsnet", 31 | pretrained: bool = True, 32 | dataset: Literal["blended_mvg", "dtu_yao_blended_mvg"] = "blended_mvg", 33 | hints: Literal["mvguided_filtered", "not_guided", "guided", "mvguided"] = "not_guided", 34 | hints_density: float = 0.03, 35 | ): 36 | """ 37 | Utility function to load from the tarfile containing all the pretrained models the one choosen 38 | """ 39 | 40 | assert model in [ 41 | "ucsnet", 42 | "d2hc_rmvsnet", 43 | "mvsnet", 44 | "patchmatchnet", 45 | "cas_mvsnet", 46 | ] 47 | assert dataset in ["blended_mvg", "dtu_yao_blended_mvg"] 48 | assert hints in ["mvguided_filtered", "not_guided", "guided", "mvguided"] 49 | 50 | # model instance 51 | model_net = models.__dict__[model].SimpleInterfaceNet() 52 | model_net.train_params = None 53 | 54 | # find the correct checkpoint 55 | if pretrained: 56 | with tarfile.open(tarpath) as archive: 57 | info = yaml.safe_load(archive.extractfile("trained_models/info.yaml")) 58 | for ckpt_id, meta in info.items(): 59 | found = 
meta["model"] == model and meta["hints"] == hints 60 | if hints != "not_guided": 61 | found = found and float(meta["hints_density"]) == hints_density 62 | if dataset == "blended_mvg": 63 | found = found and meta["dataset"] == dataset 64 | else: 65 | found = ( 66 | found 67 | and meta["dataset"] == "dtu_yao" 68 | and "load_weights" in meta 69 | and info[meta["load_weights"]]["dataset"] == "blended_mvg" 70 | ) 71 | if found: 72 | break 73 | if not found: 74 | raise ValueError("Model not available with the provided parameters") 75 | 76 | model_net.load_state_dict( 77 | { 78 | ".".join(n.split(".")[1:]): v 79 | for n, v in torch.load(archive.extractfile(f"trained_models/{ckpt_id}.ckpt"))[ 80 | "state_dict" 81 | ].items() 82 | } 83 | ) 84 | model_net.train_params = meta 85 | return model_net 86 | 87 | 88 | def mvsnet( 89 | pretrained: bool = True, 90 | dataset: Literal["blended_mvg", "dtu_yao_blended_mvg"] = "blended_mvg", 91 | hints: Literal["mvguided_filtered", "not_guided", "guided", "mvguided"] = "not_guided", 92 | hints_density: float = 0.03, 93 | ): 94 | """ 95 | pretrained `MVSNet`_ network. 96 | 97 | .. _MVSNet https://arxiv.org/pdf/1804.02505.pdf 98 | """ 99 | return _load_model( 100 | _get_archive(), 101 | "mvsnet", 102 | pretrained=pretrained, 103 | dataset=dataset, 104 | hints=hints, 105 | hints_density=hints_density, 106 | ) 107 | 108 | 109 | def ucsnet( 110 | pretrained: bool = True, 111 | dataset: Literal["blended_mvg", "dtu_yao_blended_mvg"] = "blended_mvg", 112 | hints: Literal["mvguided_filtered", "not_guided", "guided", "mvguided"] = "not_guided", 113 | hints_density: float = 0.03, 114 | ): 115 | """ 116 | pretrained `UCSNet`_ network. 117 | 118 | .. _UCSNet https://arxiv.org/pdf/1911.12012.pdf 119 | """ 120 | return _load_model( 121 | _get_archive(), 122 | "ucsnet", 123 | pretrained=pretrained, 124 | dataset=dataset, 125 | hints=hints, 126 | hints_density=hints_density, 127 | ) 128 | 129 | 130 | def d2hc_rmvsnet( 131 | pretrained: bool = True, 132 | dataset: Literal["blended_mvg", "dtu_yao_blended_mvg"] = "blended_mvg", 133 | hints: Literal["mvguided_filtered", "not_guided", "guided", "mvguided"] = "not_guided", 134 | hints_density: float = 0.03, 135 | ): 136 | """ 137 | pretrained `D2HCRMVSNet`_ network. 138 | 139 | .. _D2HCRMVSNet https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123490647.pdf 140 | """ 141 | return _load_model( 142 | _get_archive(), 143 | "d2hc_rmvsnet", 144 | pretrained=pretrained, 145 | dataset=dataset, 146 | hints=hints, 147 | hints_density=hints_density, 148 | ) 149 | 150 | 151 | def patchmatchnet( 152 | pretrained: bool = True, 153 | dataset: Literal["blended_mvg", "dtu_yao_blended_mvg"] = "blended_mvg", 154 | hints: Literal["mvguided_filtered", "not_guided", "guided", "mvguided"] = "not_guided", 155 | hints_density: float = 0.03, 156 | ): 157 | """ 158 | pretrained `PatchMatchNet`_ network. 159 | 160 | .. _PatchMatchNet https://github.com/FangjinhuaWang/PatchmatchNet 161 | """ 162 | return _load_model( 163 | _get_archive(), 164 | "patchmatchnet", 165 | pretrained=pretrained, 166 | dataset=dataset, 167 | hints=hints, 168 | hints_density=hints_density, 169 | ) 170 | 171 | 172 | def cas_mvsnet( 173 | pretrained: bool = True, 174 | dataset: Literal["blended_mvg", "dtu_yao_blended_mvg"] = "blended_mvg", 175 | hints: Literal["mvguided_filtered", "not_guided", "guided", "mvguided"] = "not_guided", 176 | hints_density: float = 0.03, 177 | ): 178 | """ 179 | pretrained `CASMVSNet`_ network. 180 | 181 | .. 
_CASMVSNet https://arxiv.org/pdf/1912.06378.pdf 182 | """ 183 | return _load_model( 184 | _get_archive(), 185 | "cas_mvsnet", 186 | pretrained=pretrained, 187 | dataset=dataset, 188 | hints=hints, 189 | hints_density=hints_density, 190 | ) 191 | 192 | 193 | ## Datasets 194 | 195 | 196 | def _load_dataset( 197 | dataset: str, 198 | root: str, 199 | batch_size: int = 1, 200 | nviews: int = 5, 201 | ndepths: int = 128, 202 | hints: str = "mvguided_filtered", 203 | hints_density: float = 0.03, 204 | filtering_window: Tuple[int, int] = (9, 9), 205 | num_workers: int = cpu_count() // 2, 206 | ): 207 | from guided_mvs_lib.datasets import MVSDataModule 208 | from guided_mvs_lib.datasets.sample_preprocess import MVSSampleTransform 209 | 210 | dm = MVSDataModule( 211 | dataset, 212 | batch_size=batch_size, 213 | num_workers=num_workers, 214 | datapath=root, 215 | nviews=nviews, 216 | ndepths=ndepths, 217 | robust_train=False, 218 | transform=MVSSampleTransform( 219 | generate_hints=hints, 220 | hints_perc=hints_density, 221 | filtering_window=filtering_window, 222 | ), 223 | ) 224 | return dm 225 | 226 | 227 | def blended_mvs( 228 | root: str, 229 | batch_size: int = 1, 230 | nviews: int = 5, 231 | ndepths: int = 128, 232 | hints: str = "mvguided_filtered", 233 | hints_density: float = 0.03, 234 | filtering_window: Tuple[int, int] = (9, 9), 235 | num_workers: int = cpu_count() // 2, 236 | ): 237 | """ 238 | Utility function to load a Pytorch Lightning DataModule loading 239 | the BlendedMVS dataset 240 | """ 241 | return _load_dataset( 242 | "blended_mvs", 243 | root=root, 244 | batch_size=batch_size, 245 | nviews=nviews, 246 | ndepths=ndepths, 247 | hints=hints, 248 | hints_density=hints_density, 249 | filtering_window=filtering_window, 250 | num_workers=num_workers, 251 | ) 252 | 253 | 254 | def blended_mvg( 255 | root: str, 256 | batch_size: int = 1, 257 | nviews: int = 5, 258 | ndepths: int = 128, 259 | hints: str = "mvguided_filtered", 260 | hints_density: float = 0.03, 261 | filtering_window: Tuple[int, int] = (9, 9), 262 | num_workers: int = cpu_count() // 2, 263 | ): 264 | """ 265 | Utility function to load a Pytorch Lightning DataModule loading 266 | the BlendedMVG dataset 267 | """ 268 | return _load_dataset( 269 | "blended_mvg", 270 | root=root, 271 | batch_size=batch_size, 272 | nviews=nviews, 273 | ndepths=ndepths, 274 | hints=hints, 275 | hints_density=hints_density, 276 | filtering_window=filtering_window, 277 | num_workers=num_workers, 278 | ) 279 | 280 | 281 | def dtu( 282 | root: str, 283 | batch_size: int = 1, 284 | nviews: int = 5, 285 | ndepths: int = 128, 286 | hints: str = "mvguided_filtered", 287 | hints_density: float = 0.03, 288 | filtering_window: Tuple[int, int] = (9, 9), 289 | num_workers: int = 4, # (pretty memory aggressive) 290 | ): 291 | """ 292 | Utility function to load a Pytorch Lightning DataModule loading 293 | the DTU dataset 294 | """ 295 | return _load_dataset( 296 | "dtu_yao", 297 | root=root, 298 | batch_size=batch_size, 299 | nviews=nviews, 300 | ndepths=ndepths, 301 | hints=hints, 302 | hints_density=hints_density, 303 | filtering_window=filtering_window, 304 | num_workers=num_workers, 305 | ) 306 | -------------------------------------------------------------------------------- /params.yaml: -------------------------------------------------------------------------------- 1 | model: patchmatchnet 2 | 3 | # train parameters 4 | train: 5 | dataset: dtu_yao # dtu_yao | blended_mvs (in data/) 6 | epochs: 5 # | null 7 | steps: null # | null, if both epochs and 
steps it takes the minimum one 8 | batch_size: 1 9 | lr: 0.001 10 | epochs_lr_decay: null # list of increasing int like [10, 12, 14] or null to disable 11 | epochs_lr_gamma: 2 12 | weight_decay: 0.0 13 | ndepths: 128 # 128 14 | views: 5 15 | hints: mvguided_filtered # not_guided | guided | mvguided | mvguided_filtered 16 | hints_filter_window: [9, 9] 17 | hints_density: 0.03 18 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [tool.black] 2 | line-length = 99 3 | exclude = ''' 4 | /( 5 | \.git 6 | | data 7 | | tensorboard_logs 8 | | \.vscode 9 | | output 10 | )/ 11 | ''' 12 | 13 | [tool.isort] 14 | profile = "black" 15 | 16 | [tool.pytest.ini_options] 17 | testpaths = ["tests"] 18 | markers = [ 19 | "slow: marks tests as slow (testing trainings for instance)", 20 | "train: marks tests concerning the training process", 21 | "data: marks tests concerning datasets and data loading", 22 | "dtu: marks tests on the DTU dataset", 23 | "blended_mvs: marks tests on the Blended MVS dataset", 24 | "blended_mvg: marks tests on the Blended MVG dataset", 25 | "mvsnet: marks tests on mvsnet", 26 | "cas_mvsnet: marks tests on cas_mvsnet", 27 | "ucsnet: marks tests on ucsnet", 28 | "d2hc_rmvsnet: marks tests on d2hc_rmvsnet", 29 | "patchmatchnet: marks tests on patchmatchnet", 30 | ] -------------------------------------------------------------------------------- /results.ipynb: -------------------------------------------------------------------------------- 1 | { 2 | "cells": [ 3 | { 4 | "cell_type": "markdown", 5 | "metadata": {}, 6 | "source": [ 7 | "# Example to use the pretrained models\n", 8 | "\n", 9 | "In this notebook we show how to load the dataset and the pretrained models to reproduce some of the results reported in the paper, to do so we leverage the ``torch.hub`` API" 10 | ] 11 | }, 12 | { 13 | "cell_type": "code", 14 | "execution_count": 1, 15 | "metadata": {}, 16 | "outputs": [], 17 | "source": [ 18 | "import torch\n", 19 | "from tqdm import tqdm\n", 20 | "from collections import defaultdict\n", 21 | "import pandas as pd\n", 22 | "\n", 23 | "HUBCONF_URI = \".\"\n", 24 | "HUBCONF_SOURCE = \"local\"" 25 | ] 26 | }, 27 | { 28 | "cell_type": "markdown", 29 | "metadata": {}, 30 | "source": [ 31 | "### Load the dataset\n", 32 | "\n", 33 | "We provide implementation for DTU and Blended-MVS / Blended-MVG, once that the dataset is locally available it can be simply loaded by means of the torch.hub function" 34 | ] 35 | }, 36 | { 37 | "cell_type": "code", 38 | "execution_count": 2, 39 | "metadata": {}, 40 | "outputs": [], 41 | "source": [ 42 | "# load the dataset\n", 43 | "dm = torch.hub.load(\n", 44 | " HUBCONF_URI,\n", 45 | " \"blended_mvg\",\n", 46 | " source=HUBCONF_SOURCE,\n", 47 | " root=\"data/blended-mvs\",\n", 48 | " hints=\"mvguided_filtered\",\n", 49 | " hints_density=0.03,\n", 50 | ")\n", 51 | "dm.prepare_data()\n", 52 | "dm.setup()\n", 53 | "dl = dm.test_dataloader()" 54 | ] 55 | }, 56 | { 57 | "cell_type": "markdown", 58 | "metadata": {}, 59 | "source": [ 60 | "### Load the Network(s)\n", 61 | "\n", 62 | "Here we load the pretrained network trained with and without sparse depth points and test them, to reproduce the results provided in the [paper](https://arxiv.org/pdf/2210.11467v1.pdf)" 63 | ] 64 | }, 65 | { 66 | "cell_type": "code", 67 | "execution_count": 3, 68 | "metadata": {}, 69 | "outputs": [ 70 | { 71 | "name": "stderr", 72 | "output_type": 
"stream", 73 | "text": [ 74 | "mvsnet: 100%|██████████| 915/915 [03:04<00:00, 4.95it/s]\n", 75 | "ucsnet: 100%|██████████| 915/915 [04:53<00:00, 3.12it/s]\n", 76 | "d2hc_rmvsnet: 100%|██████████| 915/915 [21:56<00:00, 1.44s/it]\n", 77 | "patchmatchnet: 100%|██████████| 915/915 [02:24<00:00, 6.34it/s]\n", 78 | "cas_mvsnet: 100%|██████████| 915/915 [04:30<00:00, 3.38it/s]\n" 79 | ] 80 | } 81 | ], 82 | "source": [ 83 | "results = defaultdict(lambda: [])\n", 84 | "\n", 85 | "def metrics(pred: torch.Tensor, gt: torch.Tensor):\n", 86 | " mask = (gt > 0)\n", 87 | " diff = torch.abs(gt[mask] - pred[mask])\n", 88 | " return {\n", 89 | " \">1 px\": torch.mean((diff > 1).float()),\n", 90 | " \">2 px\": torch.mean((diff > 2).float()),\n", 91 | " \">3 px\": torch.mean((diff > 3).float()),\n", 92 | " \">4 px\": torch.mean((diff > 4).float()),\n", 93 | " }\n", 94 | "\n", 95 | "for model_name in [\n", 96 | " \"mvsnet\",\n", 97 | " \"ucsnet\",\n", 98 | " \"d2hc_rmvsnet\",\n", 99 | " \"patchmatchnet\",\n", 100 | " \"cas_mvsnet\",\n", 101 | "]:\n", 102 | "\n", 103 | " model_orig = torch.hub.load(\n", 104 | " HUBCONF_URI,\n", 105 | " model_name,\n", 106 | " source=HUBCONF_SOURCE,\n", 107 | " dataset=\"blended_mvg\",\n", 108 | " hints=\"not_guided\",\n", 109 | " )\n", 110 | " model_orig.eval()\n", 111 | " model_orig.cuda() # use a gpu for this\n", 112 | "\n", 113 | " model_hints = torch.hub.load(\n", 114 | " HUBCONF_URI,\n", 115 | " model_name,\n", 116 | " source=HUBCONF_SOURCE,\n", 117 | " dataset=\"blended_mvg\",\n", 118 | " hints=\"mvguided_filtered\",\n", 119 | " hints_density=0.03,\n", 120 | " )\n", 121 | " model_hints.eval()\n", 122 | " model_hints.cuda()\n", 123 | "\n", 124 | " with torch.no_grad():\n", 125 | " for ex in tqdm(dl, desc=model_name):\n", 126 | "\n", 127 | " # compute inputs\n", 128 | " inp_no_hints = {\n", 129 | " \"imgs\": ex[\"imgs\"][\"stage_0\"].cuda(),\n", 130 | " \"intrinsics\": ex[\"intrinsics\"].cuda(),\n", 131 | " \"extrinsics\": ex[\"extrinsics\"].cuda(),\n", 132 | " \"depth_values\": ex[\"depth_values\"].cuda(),\n", 133 | " }\n", 134 | "\n", 135 | " inp_hints = dict(\n", 136 | " **inp_no_hints,\n", 137 | " hints=ex[\"hints\"].cuda(),\n", 138 | " )\n", 139 | " \n", 140 | " # forward\n", 141 | " depth_orig = model_orig(**inp_no_hints)\n", 142 | " depth_hints = model_hints(**inp_hints)\n", 143 | "\n", 144 | " # metrics\n", 145 | " gt = ex[\"depth\"][\"stage_0\"].cuda()\n", 146 | " results[model_name].append(metrics(depth_orig, gt))\n", 147 | " results[model_name + \"_hints\"].append(metrics(depth_hints, gt))" 148 | ] 149 | }, 150 | { 151 | "cell_type": "code", 152 | "execution_count": 4, 153 | "metadata": {}, 154 | "outputs": [ 155 | { 156 | "data": { 157 | "text/html": [ 158 | "
\n", 159 | "\n", 172 | "\n", 173 | " \n", 174 | " \n", 175 | " \n", 176 | " \n", 177 | " \n", 178 | " \n", 179 | " \n", 180 | " \n", 181 | " \n", 182 | " \n", 183 | " \n", 184 | " \n", 185 | " \n", 186 | " \n", 187 | " \n", 188 | " \n", 189 | " \n", 190 | " \n", 191 | " \n", 192 | " \n", 193 | " \n", 194 | " \n", 195 | " \n", 196 | " \n", 197 | " \n", 198 | " \n", 199 | " \n", 200 | " \n", 201 | " \n", 202 | " \n", 203 | " \n", 204 | " \n", 205 | " \n", 206 | " \n", 207 | " \n", 208 | " \n", 209 | " \n", 210 | " \n", 211 | " \n", 212 | " \n", 213 | " \n", 214 | " \n", 215 | " \n", 216 | " \n", 217 | " \n", 218 | " \n", 219 | " \n", 220 | " \n", 221 | " \n", 222 | " \n", 223 | " \n", 224 | " \n", 225 | " \n", 226 | " \n", 227 | " \n", 228 | " \n", 229 | " \n", 230 | " \n", 231 | " \n", 232 | " \n", 233 | " \n", 234 | " \n", 235 | " \n", 236 | " \n", 237 | " \n", 238 | " \n", 239 | " \n", 240 | " \n", 241 | " \n", 242 | " \n", 243 | " \n", 244 | " \n", 245 | " \n", 246 | " \n", 247 | " \n", 248 | " \n", 249 | " \n", 250 | " \n", 251 | " \n", 252 | " \n", 253 | " \n", 254 | " \n", 255 | " \n", 256 | " \n", 257 | " \n", 258 | " \n", 259 | " \n", 260 | " \n", 261 | "
>1 px>2 px>3 px>4 px
model
mvsnet0.1450.0760.0480.033
mvsnet_hints0.0770.0370.0230.015
ucsnet0.0830.0420.0270.019
ucsnet_hints0.0400.0180.0110.008
d2hc_rmvsnet0.1900.1020.0650.044
d2hc_rmvsnet_hints0.0820.0410.0260.018
patchmatchnet0.0830.0410.0260.018
patchmatchnet_hints0.0650.0340.0220.016
cas_mvsnet0.0830.0400.0250.018
cas_mvsnet_hints0.0480.0180.0120.009
\n", 262 | "
" 263 | ], 264 | "text/plain": [ 265 | " >1 px >2 px >3 px >4 px\n", 266 | "model \n", 267 | "mvsnet 0.145 0.076 0.048 0.033\n", 268 | "mvsnet_hints 0.077 0.037 0.023 0.015\n", 269 | "ucsnet 0.083 0.042 0.027 0.019\n", 270 | "ucsnet_hints 0.040 0.018 0.011 0.008\n", 271 | "d2hc_rmvsnet 0.190 0.102 0.065 0.044\n", 272 | "d2hc_rmvsnet_hints 0.082 0.041 0.026 0.018\n", 273 | "patchmatchnet 0.083 0.041 0.026 0.018\n", 274 | "patchmatchnet_hints 0.065 0.034 0.022 0.016\n", 275 | "cas_mvsnet 0.083 0.040 0.025 0.018\n", 276 | "cas_mvsnet_hints 0.048 0.018 0.012 0.009" 277 | ] 278 | }, 279 | "execution_count": 4, 280 | "metadata": {}, 281 | "output_type": "execute_result" 282 | } 283 | ], 284 | "source": [ 285 | "def mean_metrics(lst_d):\n", 286 | " return {k: f\"{torch.stack([dic[k] for dic in lst_d]).mean().item():.3f}\" for k in lst_d[0]}\n", 287 | "\n", 288 | "outs = [dict(model=net_name, **mean_metrics(results[net_name])) for net_name in results]\n", 289 | "pd.DataFrame(outs).set_index(\"model\")" 290 | ] 291 | } 292 | ], 293 | "metadata": { 294 | "kernelspec": { 295 | "display_name": "Python 3.8.10 64-bit ('guided-MVS')", 296 | "language": "python", 297 | "name": "python3" 298 | }, 299 | "language_info": { 300 | "codemirror_mode": { 301 | "name": "ipython", 302 | "version": 3 303 | }, 304 | "file_extension": ".py", 305 | "mimetype": "text/x-python", 306 | "name": "python", 307 | "nbconvert_exporter": "python", 308 | "pygments_lexer": "ipython3", 309 | "version": "3.8.10" 310 | }, 311 | "orig_nbformat": 4, 312 | "vscode": { 313 | "interpreter": { 314 | "hash": "c0867fb98881dfcb09a7568ecd2197f83b886116f1d37752edcf5a1b3d692e7b" 315 | } 316 | } 317 | }, 318 | "nbformat": 4, 319 | "nbformat_minor": 2 320 | } 321 | -------------------------------------------------------------------------------- /tests/test_dataset.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests concerning data loading 3 | """ 4 | 5 | import itertools 6 | import sys 7 | from pathlib import Path 8 | 9 | import numpy as np 10 | import pytest 11 | import torch 12 | from PIL import Image 13 | 14 | sys.path.append(".") 15 | from guided_mvs_lib.datasets import ( 16 | MVSDataset, 17 | blended_mvg_utils, 18 | blended_mvs_utils, 19 | dtu_utils, 20 | ) 21 | from guided_mvs_lib.datasets.sample_preprocess import MVSSampleTransform 22 | 23 | 24 | def _test_scans(dataset, split): 25 | if dataset == "dtu": 26 | ds_utils = dtu_utils 27 | elif dataset == "blended_mvs": 28 | ds_utils = blended_mvs_utils 29 | elif dataset == "blended_mvg": 30 | ds_utils = blended_mvg_utils 31 | 32 | scans = { 33 | "train": ds_utils.train_scans(), 34 | "val": ds_utils.val_scans(), 35 | "test": ds_utils.test_scans(), 36 | }[split] 37 | 38 | if not scans: 39 | raise ValueError(f"{split} not implemented for {dataset}") 40 | 41 | path = Path("data/blended-mvs") 42 | if dataset == "dtu": 43 | path = Path("data/dtu") 44 | 45 | missing_scans = [] 46 | for scan in scans: 47 | paths = ds_utils.datapath_files(path, scan, 10, split, 3) 48 | 49 | exist = True 50 | error = [] 51 | for k, v in paths.items(): 52 | if v is not None and not v.parent.exists(): 53 | exist = False 54 | error.append(f"not found: {str(v.parent)}") 55 | if split == "test" and dataset == "dtu": 56 | assert paths["pcd"] is not None 57 | assert paths["obs_mask"] is not None 58 | assert paths["ground_plane"] is not None 59 | 60 | if not exist: 61 | missing_scans.append((scan, error)) 62 | 63 | error_out = "" 64 | if missing_scans: 65 | for scan, errors in 
missing_scans: 66 | error_out += f"\nerror in scan {scan}: \n" + "".join(f"- {err}\n" for err in errors) 67 | assert False, error_out 68 | 69 | 70 | @pytest.mark.dtu 71 | @pytest.mark.data 72 | @pytest.mark.parametrize("split", ["train", "val", "test"]) 73 | def test_dtu_scans(split): 74 | _test_scans("dtu", split) 75 | 76 | 77 | @pytest.mark.blended_mvs 78 | @pytest.mark.data 79 | @pytest.mark.parametrize("split", ["train", "val", "test"]) 80 | def test_blended_mvs_scans(split): 81 | _test_scans("blended_mvs", split) 82 | 83 | 84 | @pytest.mark.blended_mvg 85 | @pytest.mark.data 86 | @pytest.mark.parametrize("split", ["train", "val", "test"]) 87 | def test_blended_mvg_scans(split): 88 | _test_scans("blended_mvg", split) 89 | 90 | 91 | def _test_ds_loading(name, mode, nviews, ndepths): 92 | if name == "dtu": 93 | name = "dtu_yao" 94 | 95 | if name in ["blended_mvs", "blended_mvg"]: 96 | datapath = "data/blended-mvs" 97 | else: 98 | datapath = "data/dtu" 99 | 100 | dataset = MVSDataset( 101 | name, 102 | datapath=datapath, 103 | mode=mode, 104 | nviews=nviews, 105 | ndepths=ndepths, 106 | ) 107 | 108 | batch = dataset[0] 109 | 110 | for key in { 111 | "imgs", 112 | "intrinsics", 113 | "extrinsics", 114 | "depths", 115 | "ref_depth_min", 116 | "ref_depth_max", 117 | "ref_depth_values", 118 | "filename", 119 | }: 120 | assert key in batch.keys() 121 | assert isinstance(batch["imgs"], list) 122 | assert isinstance(batch["intrinsics"], list) 123 | assert batch["intrinsics"][0].shape == (3, 3) 124 | assert isinstance(batch["extrinsics"], list) 125 | assert batch["extrinsics"][0].shape == (4, 4) 126 | 127 | if name == "dtu_yao": 128 | img_shape = (512, 640) 129 | if mode == "test": 130 | img_shape = (1200, 1600) 131 | assert batch["depths"][0].shape == (*img_shape, 1) 132 | elif name in ["blended_mvs", "blended_mvg"]: 133 | assert batch["depths"][0].shape == (576, 768, 1) 134 | 135 | assert isinstance(batch["depths"], list) 136 | assert isinstance(batch["ref_depth_min"], float) 137 | assert isinstance(batch["ref_depth_max"], float) 138 | assert batch["ref_depth_values"].shape == (ndepths,) 139 | assert isinstance(batch["filename"], str) 140 | 141 | 142 | @pytest.mark.dtu 143 | @pytest.mark.data 144 | @pytest.mark.parametrize( 145 | "mode, nviews, ndepths", 146 | itertools.product( 147 | ["train", "val", "test"], 148 | [3, 5], 149 | [192, 128], 150 | ), 151 | ) 152 | def test_dtu_loading(mode, nviews, ndepths): 153 | _test_ds_loading("dtu", mode, nviews, ndepths) 154 | 155 | 156 | @pytest.mark.blended_mvs 157 | @pytest.mark.data 158 | @pytest.mark.parametrize( 159 | "mode, nviews, ndepths", 160 | itertools.product( 161 | ["train", "val", "test"], 162 | [3, 5], 163 | [192, 128], 164 | ), 165 | ) 166 | def test_blended_mvs_loading(mode, nviews, ndepths): 167 | _test_ds_loading("blended_mvs", mode, nviews, ndepths) 168 | 169 | 170 | @pytest.mark.blended_mvg 171 | @pytest.mark.data 172 | @pytest.mark.parametrize( 173 | "mode, nviews, ndepths", 174 | itertools.product( 175 | ["train", "val", "test"], 176 | [3, 5], 177 | [192, 128], 178 | ), 179 | ) 180 | def test_blended_mvg_loading(mode, nviews, ndepths): 181 | _test_ds_loading("blended_mvg", mode, nviews, ndepths) 182 | 183 | 184 | @pytest.mark.data 185 | @pytest.mark.dtu 186 | @pytest.mark.parametrize("split", ["train", "val", "test"]) 187 | def _test_ds_build_list(ds_name, split, path): 188 | 189 | ds_utils = { 190 | "dtu": dtu_utils, 191 | "blended_mvs": blended_mvs_utils, 192 | "blended_mvg": blended_mvg_utils, 193 | }[ds_name] 194 | 195 | scans 
= { 196 | "train": ds_utils.train_scans(), 197 | "val": ds_utils.val_scans(), 198 | "test": ds_utils.test_scans(), 199 | }[split] 200 | 201 | metas = ds_utils.build_list(path, scans, 3) 202 | for scan, light_idx, ref_view, _ in metas: 203 | paths = ds_utils.datapath_files(path, scan, ref_view, split, light_idx) 204 | for val in paths.values(): 205 | if val is not None: 206 | assert val.exists(), f"{val} not found" 207 | 208 | 209 | @pytest.mark.data 210 | @pytest.mark.dtu 211 | @pytest.mark.parametrize("split", ["train", "val", "test"]) 212 | def test_dtu_build_list(split): 213 | _test_ds_build_list("dtu", split, "data/dtu") 214 | 215 | 216 | @pytest.mark.data 217 | @pytest.mark.blended_mvs 218 | @pytest.mark.parametrize("split", ["train", "val", "test"]) 219 | def test_blended_mvs_build_list(split): 220 | _test_ds_build_list("blended_mvs", split, "data/blended-mvs") 221 | 222 | 223 | @pytest.mark.data 224 | @pytest.mark.blended_mvg 225 | @pytest.mark.parametrize("split", ["train", "val", "test"]) 226 | def test_blended_mvg_build_list(split): 227 | _test_ds_build_list("blended_mvg", split, "data/blended-mvs") 228 | 229 | 230 | @pytest.mark.data 231 | @pytest.mark.parametrize("hints", ["not_guided", "guided", "mvguided", "mvguided_filtered"]) 232 | def test_dataset_preprocess(hints): 233 | 234 | # TODO: add tests for pcd fields 235 | 236 | fake_image = np.random.randint(0, 255, size=(512, 640, 3), dtype=np.uint8) 237 | sample = { 238 | "imgs": [Image.fromarray(fake_image)] * 3, 239 | "intrinsics": [np.random.randn(3, 3).astype(np.float32)] * 3, 240 | "extrinsics": [np.random.randn(4, 4).astype(np.float32)] * 3, 241 | "depths": [np.random.rand(512, 640, 1).astype(np.float32) * 1000] * 3, 242 | "ref_depth_min": 0.1, 243 | "ref_depth_max": 100.0, 244 | "ref_depth_values": np.arange(0, 1000, 192).astype(np.float32), 245 | "filename": "scan1/{}/00000000{}", 246 | "scan_pcd": None, 247 | "scan_pcd_obs_mask": None, 248 | "scan_pcd_bounding_box": None, 249 | "scan_pcd_resolution": None, 250 | "scan_pcd_ground_plane": None, 251 | } 252 | 253 | output = MVSSampleTransform(generate_hints=hints)(sample) 254 | 255 | for i in range(4): 256 | stage = f"stage_{i}" 257 | dims = 512 // (2 ** i), 640 // (2 ** i) 258 | assert output["imgs"][stage].shape == torch.Size([3, 3, *dims]) 259 | assert output["proj_matrices"][stage].shape == torch.Size([3, 4, 4]) 260 | assert output["depth"][stage].shape == torch.Size([1, *dims]) 261 | 262 | assert "depth_min" in output 263 | assert "depth_max" in output 264 | assert "filename" in output 265 | if hints == "not_guided": 266 | assert "hints" not in output, output.keys() 267 | else: 268 | assert output["hints"].shape == torch.Size([1, 512, 640]) 269 | -------------------------------------------------------------------------------- /tests/test_models.py: -------------------------------------------------------------------------------- 1 | """ 2 | Tests network interface and operational behaviour 3 | """ 4 | 5 | import logging 6 | import sys 7 | 8 | sys.path.append(".") 9 | from types import SimpleNamespace 10 | from unittest.mock import MagicMock 11 | 12 | import pytest 13 | 14 | from guided_mvs_lib import models 15 | from guided_mvs_lib.datasets import MVSDataModule 16 | from guided_mvs_lib.datasets.sample_preprocess import MVSSampleTransform 17 | 18 | logging.basicConfig(level=logging.DEBUG) 19 | 20 | _train_args = { 21 | "epochs": 5, 22 | "steps": None, 23 | "batch_size": 1, 24 | "lr": 0.001, 25 | "epochs_lr_decay": None, 26 | "epochs_lr_gamma": 2, 27 |
"weight_decay": 0.0, 28 | "ndepths": 128, 29 | "wiews": 2, 30 | "hints_density": 0.01, 31 | } 32 | 33 | 34 | @pytest.mark.slow 35 | @pytest.mark.parametrize("hints", ["not_guided", "guided", "mvguided", "mvguided_filtered"]) 36 | @pytest.mark.parametrize( 37 | "dataset", 38 | [ 39 | pytest.param("blended_mvg", marks=pytest.mark.blended_mvg), 40 | pytest.param("blended_mvs", marks=pytest.mark.blended_mvs), 41 | pytest.param("dtu_yao", marks=pytest.mark.dtu), 42 | ], 43 | ) 44 | @pytest.mark.parametrize( 45 | "network", 46 | [ 47 | pytest.param("mvsnet", marks=pytest.mark.mvsnet), 48 | pytest.param("cas_mvsnet", marks=pytest.mark.cas_mvsnet), 49 | pytest.param("ucsnet", marks=pytest.mark.ucsnet), 50 | pytest.param("patchmatchnet", marks=pytest.mark.patchmatchnet), 51 | pytest.param("d2hc_rmvsnet", marks=pytest.mark.d2hc_rmvsnet), 52 | ], 53 | ) 54 | def test_network_forward(dataset, network, hints): 55 | 56 | data_module = MVSDataModule( 57 | dataset, 58 | nviews=3, 59 | ndepths=128, 60 | transform=MVSSampleTransform(generate_hints=hints), 61 | ) 62 | data_module.setup(stage="test") 63 | dl = data_module.test_dataloader() 64 | batch = next(iter(dl)) 65 | 66 | # load model 67 | args = SimpleNamespace(model=network) 68 | args.train = SimpleNamespace(**_train_args, dataset=dataset, hints=hints) 69 | network = models.build_network(network, args) 70 | 71 | # mocking batch dictionary for inspection 72 | batch_mock = MagicMock() 73 | batch_mock.__getitem__.side_effect = batch.__getitem__ 74 | batch_mock.__contains__.side_effect = batch.__contains__ 75 | 76 | # assert output interface and hints usage 77 | output = network(batch_mock) 78 | assert set(output.keys()) == {"depth", "photometric_confidence", "loss_data"} 79 | if "hints" in batch: 80 | batch_mock.__getitem__.assert_any_call("hints") 81 | -------------------------------------------------------------------------------- /tests/test_train.py: -------------------------------------------------------------------------------- 1 | """ 2 | Testing the training procedure in fast-dev-run with all the networks 3 | on DTU 4 | """ 5 | 6 | import shutil 7 | import sys 8 | import tempfile 9 | from pathlib import Path 10 | 11 | import py 12 | import pytest 13 | 14 | # tests are executed in the root dir working directory, the root directory 15 | # however is not a python package, this hack solves it and allows us to load 16 | # the training script as a module even if this testing file is in `tests` folder 17 | sys.path.append(".") 18 | from train import run_training 19 | 20 | 21 | @pytest.mark.slow 22 | @pytest.mark.train 23 | @pytest.mark.parametrize( 24 | "model", 25 | [ 26 | pytest.param("mvsnet", marks=pytest.mark.mvsnet), 27 | pytest.param("cas_mvsnet", marks=pytest.mark.cas_mvsnet), 28 | pytest.param("ucsnet", marks=pytest.mark.ucsnet), 29 | pytest.param("patchmatchnet", marks=pytest.mark.patchmatchnet), 30 | pytest.param("d2hc_rmvsnet", marks=pytest.mark.d2hc_rmvsnet), 31 | ], 32 | ) 33 | @pytest.mark.parametrize( 34 | "dataset", 35 | [ 36 | pytest.param("dtu_yao", marks=pytest.mark.dtu), 37 | pytest.param("blended_mvs", marks=[pytest.mark.blended_mvs, pytest.mark.blended_mvg]), 38 | ], 39 | ) 40 | @pytest.mark.parametrize("hints", ["not_guided", "guided"]) 41 | def test_network_train(model, dataset, hints): 42 | 43 | # create a folder for the output 44 | path = tempfile.mkdtemp(prefix="guided-mvs-test-") 45 | 46 | run_training( 47 | params={ 48 | "model": model, 49 | "train": { 50 | "dataset": dataset, 51 | "epochs": 10, 52 | "steps": None, 53 | 
"batch_size": 1, 54 | "lr": 0.001, 55 | "epochs_lr_decay": None, 56 | "epochs_lr_gamma": 2, 57 | "weight_decay": 0.0, 58 | "ndepths": 192, 59 | "views": 3, 60 | "hints": hints, 61 | "hints_density": 0.01, 62 | "hints_filter_window": [9, 9], 63 | }, 64 | }, 65 | cmdline_args=["--gpus", "1", "--fast-dev-run"], 66 | outpath=Path(path), 67 | logspath=Path(path), 68 | ) 69 | 70 | shutil.rmtree(path, ignore_errors=True) 71 | -------------------------------------------------------------------------------- /tests/test_version.py: -------------------------------------------------------------------------------- 1 | """ 2 | Simple check of library version 3 | """ 4 | 5 | import guided_mvs_lib as lib 6 | 7 | 8 | def test_version(): 9 | assert lib.__version__ == "0.1.0" 10 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import argparse 2 | import os 3 | import re 4 | import shutil 5 | import subprocess 6 | import tempfile 7 | import warnings 8 | from datetime import datetime, timedelta 9 | from pathlib import Path 10 | from types import SimpleNamespace 11 | from typing import List 12 | 13 | import numpy as np 14 | import pytorch_lightning as pl 15 | import urllib3 16 | import yaml 17 | from git import Repo 18 | from git.exc import InvalidGitRepositoryError 19 | from mlflow.tracking import MlflowClient 20 | from pytorch_lightning.callbacks import Callback, ModelCheckpoint 21 | from pytorch_lightning.loggers import MLFlowLogger 22 | 23 | import guided_mvs_lib.models as models 24 | from guided_mvs_lib.datasets import MVSDataModule 25 | from guided_mvs_lib.datasets.sample_preprocess import MVSSampleTransform 26 | from guided_mvs_lib.utils import * 27 | 28 | 29 | def run_training( 30 | params: Union[str, Path, dict] = "params.yaml", 31 | cmdline_args: Optional[List[str]] = None, 32 | datapath: Union[str, Path, None] = None, 33 | outpath: Union[str, Path] = "output", 34 | logspath: Union[str, Path] = ".", 35 | ): 36 | # handle args 37 | outpath = Path(outpath) 38 | logspath = Path(logspath) 39 | 40 | # remove annoying torch specific version warnings 41 | warnings.simplefilter("ignore", UserWarning) 42 | urllib3.disable_warnings() 43 | 44 | parser = argparse.ArgumentParser(description="training procedure") 45 | 46 | # training params 47 | parser.add_argument( 48 | "--gpus", type=int, default=1, help="number of gpus to select for training" 49 | ) 50 | parser.add_argument( 51 | "--fast-dev-run", 52 | nargs="?", 53 | const=True, 54 | default=False, 55 | help="if execute a single step of train and val, to debug", 56 | ) 57 | parser.add_argument( 58 | "--limit-train-batches", 59 | type=int, 60 | default=None, 61 | help="limits the number of batches for each epoch, to debug", 62 | ) 63 | parser.add_argument( 64 | "--limit-val-batches", 65 | type=int, 66 | default=None, 67 | help="limits the number of batches for each epoch, to debug", 68 | ) 69 | parser.add_argument( 70 | "--resume-from-checkpoint", 71 | nargs="?", 72 | const=True, 73 | default=False, 74 | help="if resume from the last checkpoint or from a specific checkpoint", 75 | ) 76 | parser.add_argument( 77 | "--load-weights", 78 | default=None, 79 | type=str, 80 | help="load weights either from a mlflow train or from a checkpoint file", 81 | ) 82 | 83 | # experiment date 84 | date = datetime.now().strftime(r"%Y-%h-%d-%H-%M") 85 | 86 | # parse arguments and merge from params.yaml 87 | cmd_line_args = 
parser.parse_args(cmdline_args) 88 | if isinstance(params, dict): 89 | train_args = params 90 | else: 91 | with open(params, "rt") as f: 92 | train_args = yaml.safe_load(f) 93 | 94 | args = SimpleNamespace(**vars(cmd_line_args)) 95 | for k, v in train_args.items(): 96 | if not isinstance(v, dict): 97 | setattr(args, k, v) 98 | else: 99 | setattr(args, k, SimpleNamespace(**v)) 100 | 101 | # Train using pytorch lightning 102 | pl.seed_everything(42) 103 | 104 | # Build LightningDataModule 105 | data_module = MVSDataModule( 106 | args.train.dataset, 107 | batch_size=args.train.batch_size, 108 | datapath=datapath, 109 | nviews=args.train.views, 110 | ndepths=args.train.ndepths, 111 | robust_train=True if args.train.dataset == "dtu_yao" else False, 112 | transform=MVSSampleTransform( 113 | generate_hints=args.train.hints, 114 | hints_perc=args.train.hints_density, 115 | filtering_window=tuple(args.train.hints_filter_window), 116 | ), 117 | ) 118 | 119 | # loading the full model or only the weights? 120 | if args.load_weights is not None and args.resume_from_checkpoint is not False: 121 | print("Use either --load-weights or --resume-from-checkpoint") 122 | return 123 | 124 | ckpt_path = None 125 | steps_re = re.compile(r"step=(\d+)") 126 | if args.resume_from_checkpoint is True: 127 | if (outpath / "ckpts/last.ckpt").exists(): 128 | ckpt_path = outpath / "ckpts/last.ckpt" 129 | else: 130 | ckpts = list((outpath / "ckpts").glob("*.ckpt")) 131 | steps = [ 132 | int(steps_re.findall(ckpt.name)[0]) 133 | for ckpt in ckpts 134 | if steps_re.findall(ckpt.name) 135 | ] 136 | if not steps: 137 | print("not found any valid checkpoint in", str(outpath / "ckpts")) 138 | return 139 | ckpt_path = ckpts[np.argmax(steps)] 140 | print(f"resuming from last checkpoint: {ckpt_path}") 141 | elif args.resume_from_checkpoint is not False: 142 | if Path(args.resume_from_checkpoint).exists(): 143 | ckpt_path = args.resume_from_checkpoint 144 | print(f"resuming from chosen checkpoint: {ckpt_path}") 145 | else: 146 | print(f"file {args.resume_from_checkpoint} does not exist") 147 | return 148 | 149 | # init mlflow logger and model 150 | if ckpt_path is None: 151 | logger = MLFlowLogger( 152 | experiment_name="guided-mvs", 153 | run_name=f"{args.model}-{date}", 154 | ) 155 | 156 | outpath.mkdir(exist_ok=True, parents=True) 157 | with open(outpath / "run_uuid", "wt") as f: 158 | f.write(logger.run_id) 159 | 160 | model = models.MVSModel( 161 | args=args, 162 | mlflow_run_id=logger.run_id, 163 | v_num=f"{args.model}-{'-'.join(date.split('-')[1:3])}", 164 | ) 165 | else: 166 | 167 | with open(outpath / "run_uuid", "rt") as f: 168 | mlflow_run_id = f.readline().strip() 169 | 170 | model = models.MVSModel.load_from_checkpoint( 171 | ckpt_path, 172 | args=args, 173 | mlflow_run_id=mlflow_run_id, 174 | v_num=f"{args.model}-{'-'.join(date.split('-')[1:3])}", 175 | ) 176 | logger = MLFlowLogger( 177 | experiment_name="guided-mvs", 178 | run_name=f"{args.model}-{date}", 179 | ) 180 | logger._run_id = mlflow_run_id 181 | 182 | # if required load weights 183 | if args.load_weights is not None: 184 | mlflow_client: MlflowClient = logger.experiment 185 | if args.load_weights in [ 186 | run.run_uuid for run in mlflow_client.list_run_infos(logger.experiment_id) 187 | ]: 188 | # download the model 189 | run_weights_path = mlflow_client.download_artifacts(args.load_weights, "model.ckpt") 190 | model.load_state_dict(torch.load(run_weights_path)["state_dict"]) 191 | 192 | # track the model weights 193 | run_weights_path = Path(run_weights_path) 194 |
shutil.move(run_weights_path, run_weights_path.parent / "init_weights.ckpt") 195 | mlflow_client.log_artifact( 196 | logger.run_id, run_weights_path.parent / "init_weights.ckpt" 197 | ) 198 | mlflow_client.set_tag(logger.run_id, "load_weights", args.load_weights) 199 | shutil.rmtree(Path(run_weights_path).parent, ignore_errors=True) 200 | else: 201 | try: 202 | model.load_state_dict(torch.load(args.load_weights)["state_dict"]) 203 | tmpdir = Path(tempfile.mkdtemp()) 204 | shutil.copy(args.load_weights, tmpdir / "init_weights.ckpt") 205 | mlflow_client.log_artifact(logger.run_id, tmpdir / "init_weights.ckpt") 206 | shutil.rmtree(tmpdir, ignore_errors=True) 207 | except FileNotFoundError: 208 | print(f"{args.load_weights} is neither a valid run id or a path to a .ckpt") 209 | return 210 | 211 | # handle checkpoints 212 | if ( 213 | args.train.epochs is None 214 | or args.train.epochs == 1 215 | and args.train.steps is not None 216 | and args.train.steps > 0 217 | ): 218 | ckpt_callback = ModelCheckpoint( 219 | outpath / "ckpts", 220 | train_time_interval=timedelta(hours=2), 221 | save_last=True, 222 | ) 223 | else: 224 | ckpt_callback = ModelCheckpoint(outpath / "ckpts", save_last=True) 225 | 226 | remove_output = True 227 | 228 | class HandleOutputs(Callback): 229 | def on_train_end(self, trainer, pl_module): 230 | 231 | # save final model 232 | print("saving the final model.") 233 | torch.save( 234 | {"global_step": trainer.global_step, "state_dict": pl_module.state_dict()}, 235 | outpath / "model.ckpt", 236 | ) 237 | 238 | # copy the model and the params on MLFlow 239 | if not args.fast_dev_run: 240 | mlflow_client: MlflowClient = logger.experiment 241 | 242 | # store diff file if needed 243 | try: 244 | repo = Repo(Path.cwd()) 245 | 246 | if repo.is_dirty(): 247 | try: 248 | out = subprocess.check_output(["git", "diff"], cwd=Path.cwd()) 249 | if out is not None: 250 | tmpfile = Path(tempfile.mkdtemp()) / "changes.diff" 251 | with open(tmpfile, "wb") as f: 252 | f.write(out) 253 | mlflow_client.log_artifact(logger.run_id, tmpfile) 254 | os.remove(tmpfile) 255 | except subprocess.CalledProcessError as e: 256 | print("Failed to save a diff file of the current experiment") 257 | 258 | except InvalidGitRepositoryError: 259 | pass 260 | 261 | # save the model 262 | mlflow_client.log_artifact(logger.run_id, str(outpath / "model.ckpt")) 263 | 264 | # finally, remove the temp output and log in a hidden file the current run 265 | # for the eval step 266 | with open(".current_run.yaml", "wt") as f: 267 | yaml.safe_dump( 268 | {"experiment": logger.experiment_id, "run_uuid": logger.run_id}, f 269 | ) 270 | 271 | def on_keyboard_interrupt(self, trainer, pl_module): 272 | print("training interrupted") 273 | 274 | # (not removing checkpoints) 275 | nonlocal remove_output 276 | remove_output = False 277 | 278 | # init train 279 | trainer_params = { 280 | "gpus": args.gpus, 281 | "fast_dev_run": args.fast_dev_run, 282 | "logger": logger, 283 | "benchmark": True, 284 | "callbacks": [ckpt_callback, HandleOutputs()], 285 | "weights_summary": None, 286 | "resume_from_checkpoint": ckpt_path, 287 | "num_sanity_val_steps": 0, 288 | } 289 | 290 | if ( 291 | args.resume_from_checkpoint is not False 292 | and args.train.epochs is not None 293 | and args.train.epochs == 1 294 | and args.train.steps is not None 295 | and args.train.steps > 0 296 | and ckpt_path is not None 297 | ): 298 | args.train.epochs = None 299 | 300 | if args.train.epochs is not None: 301 | trainer_params["max_epochs"] = args.train.epochs 302 
| if args.train.steps is not None: 303 | trainer_params["max_steps"] = args.train.steps 304 | if args.limit_train_batches is not None: 305 | trainer_params["limit_train_batches"] = args.limit_train_batches 306 | if args.limit_val_batches is not None: 307 | trainer_params["limit_val_batches"] = args.limit_val_batches 308 | 309 | trainer = pl.Trainer(**trainer_params) 310 | trainer.fit(model, data_module) 311 | 312 | if remove_output: 313 | shutil.rmtree(outpath) 314 | 315 | 316 | if __name__ == "__main__": 317 | run_training() 318 | --------------------------------------------------------------------------------
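
The test suite above gates everything behind custom pytest markers (data, slow, train, dtu, blended_mvs, blended_mvg, plus one marker per network), since most tests need the datasets under data/ and, for the training ones, a GPU. Below is a minimal sketch of selecting a subset with those markers; it assumes the markers are registered for pytest (for instance in pyproject.toml, whose content is not shown in this listing).

import pytest

# run only the DTU data-layout checks, leaving out the slow end-to-end tests;
# the marker expression uses standard pytest syntax and the marker names
# defined in the test files above
exit_code = pytest.main(["-m", "dtu and data and not slow", "tests"])

The same expression can also be passed on the command line via pytest -m.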
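
Finally, a minimal sketch of launching a quick debug training run through run_training, mirroring the programmatic call in tests/test_train.py and the command-line flags declared in train.py. It assumes params.yaml contains the model name and train section the script expects, that the datasets are in place under data/, and that output/debug_run is just an example output location.

from pathlib import Path

from train import run_training

# executes a single train/val step on one GPU to verify the whole setup
run_training(
    params="params.yaml",
    cmdline_args=["--gpus", "1", "--fast-dev-run"],
    outpath=Path("output/debug_run"),
    logspath=Path("."),
)

The equivalent shell invocation is python train.py --gpus 1 --fast-dev-run, since run_training falls back to sys.argv when cmdline_args is None.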