├── .gitignore
├── .pre-commit-config.yaml
├── .vscode
│   ├── extensions.json
│   └── settings.json
├── LICENSE
├── README.md
├── data
│   └── README.md
├── environment.yml
├── eval.py
├── guided_mvs_lib
│   ├── __init__.py
│   ├── datasets
│   │   ├── __init__.py
│   │   ├── blended_mvg_utils.py
│   │   ├── blended_mvs_utils.py
│   │   ├── dtu_blended_mvs.py
│   │   ├── dtu_utils.py
│   │   ├── sample_preprocess.py
│   │   └── utils.py
│   ├── models
│   │   ├── __init__.py
│   │   ├── cas_mvsnet
│   │   │   ├── __init__.py
│   │   │   ├── cas_mvsnet.py
│   │   │   └── module.py
│   │   ├── d2hc_rmvsnet
│   │   │   ├── __init__.py
│   │   │   ├── convlstm.py
│   │   │   ├── drmvsnet.py
│   │   │   ├── module.py
│   │   │   ├── rnnmodule.py
│   │   │   ├── submodule.py
│   │   │   ├── vamvsnet.py
│   │   │   └── vamvsnet_high_submodule.py
│   │   ├── mvsnet
│   │   │   ├── __init__.py
│   │   │   ├── module.py
│   │   │   └── mvsnet.py
│   │   ├── patchmatchnet
│   │   │   ├── __init__.py
│   │   │   ├── module.py
│   │   │   ├── net.py
│   │   │   └── patchmatch.py
│   │   └── ucsnet
│   │       ├── __init__.py
│   │       ├── submodules.py
│   │       └── ucsnet.py
│   └── utils.py
├── hubconf.py
├── params.yaml
├── pyproject.toml
├── results.ipynb
├── tests
│   ├── test_dataset.py
│   ├── test_models.py
│   ├── test_train.py
│   └── test_version.py
└── train.py
/.gitignore:
--------------------------------------------------------------------------------
1 | .python-version
2 | **/__pycache__/
3 | data/*
4 | !data/README.md
5 | eval_output/
6 | guided-mvs.code-workspace
7 | .pytest_cache
8 | notebooks/
9 | .current_run.yaml
10 | .envrc
11 | mlruns/
12 |
--------------------------------------------------------------------------------
/.pre-commit-config.yaml:
--------------------------------------------------------------------------------
1 | # See https://pre-commit.com for more information
2 | # See https://pre-commit.com/hooks.html for more hooks
3 | repos:
4 | - repo: https://github.com/pre-commit/pre-commit-hooks
5 | rev: v3.2.0
6 | hooks:
7 | - id: check-yaml
8 | - id: check-added-large-files
9 | - repo: https://github.com/psf/black
10 | rev: 21.9b0
11 | hooks:
12 | - id: black
13 | - repo: https://github.com/PyCQA/isort
14 | rev: 5.9.3
15 | hooks:
16 | - id: isort
--------------------------------------------------------------------------------
/.vscode/extensions.json:
--------------------------------------------------------------------------------
1 | {
2 | "recommendations": [
3 | "ms-python.python",
4 | "yzhang.markdown-all-in-one",
5 | ]
6 | }
--------------------------------------------------------------------------------
/.vscode/settings.json:
--------------------------------------------------------------------------------
1 | {
2 | "editor.formatOnSave": true,
3 | "python.formatting.provider": "black",
4 | "[python]": {
5 | "editor.codeActionsOnSave": {
6 | "source.organizeImports": true
7 | }
8 | },
9 | "files.watcherExclude": {
10 | "**/.git/objects/**": true,
11 | "**/.git/subtree-cache/**": true,
12 | "**/node_modules/*/**": true,
13 | "**/.hg/store/**": true,
14 | "**/output/**": true,
15 | "**/data/**": true,
16 | },
17 | "python.testing.pytestArgs": [
18 | "tests"
19 | ],
20 | "python.testing.unittestEnabled": false,
21 | "python.testing.pytestEnabled": true
22 | }
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2020 Fangjinhua Wang
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Multi-View Guided Multi-View Stereo
2 |
3 |
4 |
12 |
17 |
18 |
19 | This is the official source code of Multi-View Guided Multi-View Stereo, presented at the [IEEE/RSJ International Conference on Intelligent Robots and Systems](https://iros2022.org/) (IROS 2022).
20 |
21 | ## Citation
22 |
23 | ```
24 | @inproceedings{poggi2022guided,
25 | title={Multi-View Guided Multi-View Stereo},
26 | author={Poggi, Matteo and Conti, Andrea and Mattoccia, Stefano},
27 | booktitle={IEEE/RSJ International Conference on Intelligent Robots and Systems},
28 | note={IROS},
29 | year={2022}
30 | }
31 | ```
32 |
33 | ## Load pretrained models and evaluate
34 |
35 | We release many of the MVS networks tested in the paper, trained on Blended-MVG or trained on Blended-MVG and fine-tuned on DTU, with and without sparse depth points. These models can be loaded simply through the `torch.hub` API.
36 |
37 | ```python
38 | model = torch.hub.load(
39 | "andreaconti/multi-view-guided-multi-view-stereo",
40 | "mvsnet", # mvsnet | ucsnet | d2hc_rmvsnet | patchmatchnet | cas_mvsnet
41 | pretrained=True,
42 | dataset="blended_mvg", # blended_mvg | dtu_yao_blended_mvg
43 | hints="not_guided", # mvguided_filtered | not_guided | guided | mvguided
44 | )
45 | ```
46 |
47 | Once loaded, each model exposes the same interface shown below; moreover, each pretrained model provides its training parameters under the `train_params` attribute.
48 |
49 | ```python
50 | depth = model(
51 | images, # B x N x 3 x H x W
52 | intrinsics, # B x N x 3 x 3
53 | extrinsics, # B x N x 4 x 4
54 | depth_values, # B x D (128 usually)
55 | hints, # B x 1 x H x W (optional)
56 | )
57 | ```
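As noted above, the training configuration travels with each pretrained checkpoint; a minimal sketch, assuming `model` was loaded through `torch.hub` as in the first snippet:

```python
# inspect the training parameters stored with the pretrained checkpoint
print(model.train_params)
```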
58 |
59 | Finally, we also provide an interface over the datasets used, shown below. In this case PyTorch Lightning is required as a dependency and the dataset must be stored locally.
60 |
61 | ```python
62 | dm = torch.hub.load(
63 | "andreaconti/multi-view-guided-multi-view-stereo",
64 | "blended-mvg", # blended_mvg | blended_mvs | dtu
65 | root="data/blended-mvg",
66 | hints="not_guided", # mvguided_filtered | not_guided | guided | mvguided
67 | hints_density=0.03,
68 | )
69 | dm.prepare_data()
70 | dm.setup()
71 | dl = dm.train_dataloader()
72 | ```
73 |
74 | In [results.ipynb](https://github.com/andreaconti/multi-view-guided-multi-view-stereo/blob/main/results.ipynb) there is an example of how to reproduce some of the results shown in the paper through the `torch.hub` API.
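Below is a minimal end-to-end sketch combining the two interfaces above. It assumes Blended-MVG is stored locally under `data/blended-mvg` and that batches follow the layout produced by `MVSSampleTransform` in `guided_mvs_lib/datasets/sample_preprocess.py` (keys such as `imgs`, `intrinsics`, `extrinsics` and `depth_values`).

```python
import torch

# load a pretrained network and the corresponding datamodule through torch.hub
model = torch.hub.load(
    "andreaconti/multi-view-guided-multi-view-stereo",
    "mvsnet",
    pretrained=True,
    dataset="blended_mvg",
    hints="not_guided",
)
dm = torch.hub.load(
    "andreaconti/multi-view-guided-multi-view-stereo",
    "blended-mvg",
    root="data/blended-mvg",
    hints="not_guided",
)
dm.prepare_data()
dm.setup()

# take one batch and predict the reference depth map
batch = next(iter(dm.train_dataloader()))
with torch.no_grad():
    depth = model(
        batch["imgs"]["stage_0"],  # B x N x 3 x H x W
        batch["intrinsics"],       # B x N x 3 x 3
        batch["extrinsics"],       # B x N x 4 x 4
        batch["depth_values"],     # B x D
    )
```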
75 |
76 | ## Installation
77 |
78 | Install the dependencies using Conda or [Mamba](https://github.com/mamba-org/mamba):
79 |
80 | ```bash
81 | $ conda env create -f environment.yml
82 | $ conda activate guided-mvs
83 | ```
84 |
85 | ## Download the dataset(s)
86 |
87 | Download the datasets used in this work:
88 |
89 | * [DTU](http://roboimagedata.compute.dtu.dk/?page_id=36) preprocessed by [patchmatchnet](https://github.com/FangjinhuaWang/PatchmatchNet), [train](https://polybox.ethz.ch/index.php/s/ugDdJQIuZTk4S35) and [val](https://drive.google.com/file/d/1jN8yEQX0a-S22XwUjISM8xSJD39pFLL_/view?usp=sharing)
90 | * [BlendedMVG](https://github.com/YoYo000/BlendedMVS), download the lowres version from [BlendedMVS](https://1drv.ms/u/s!Ag8Dbz2Aqc81gVDu7FHfbPZwqhIy?e=BHY07t), [BlendedMVS+](https://1drv.ms/u/s!Ag8Dbz2Aqc81gVLILxpohZLEYiIa?e=MhwYSR), [BlendedMVS++](https://1drv.ms/u/s!Ag8Dbz2Aqc81gVHCxmURGz0UBGns?e=Tnw2KY)
91 |
92 | Then organize them as follows under the ``data`` folder (sym-links work fine):
93 |
94 | ```
95 | data/dtu
96 | |-- train_data
97 | |-- Cameras_1
98 | |-- Depths_raw
99 | |-- Rectified
100 | |-- test_data
101 | |-- scan1
102 | |-- scan2
103 | |-- ..
104 |
105 | data/blended-mvs
106 | |--
107 | |-- blended_images
108 | |-- cams
109 | |-- rendered_depth_maps
110 | |-- ..
111 | ```
112 |
113 | ## [Optional] Test everything is fine
114 |
115 | This project implements some tests to preliminarily check that everything is fine. Tests are grouped by different tags.
116 |
117 | ```
118 | # tags:
119 | # - data: tests related to the datasets
120 | # - dtu: tests related to the DTU dataset
121 | # - blended_mvs: tests related to blended MVS
122 | # - blended_mvg: tests related to blended MVG
123 | # - train: tests that launch all networks in fast dev run mode (1 batch of train and val for each network on each dataset)
124 | # - slow: tests that are slow to execute
125 |
126 | # EXAMPLES
127 |
128 | # runs all tests
129 | $ pytest
130 |
131 | # runs tests excluding slow ones
132 | $ pytest -m "not slow"
133 |
134 | # runs tests only on dtu
135 | $ pytest -m dtu
136 |
137 | # runs tests on data except for dtu ones
138 | $ pytest -m "data and not dtu"
139 | ```
140 |
141 | ## Training
142 |
143 | To train a model, edit ``params.yaml`` specifying the model to be trained among the following:
144 |
145 | * [cas_mvsnet](https://arxiv.org/pdf/1912.06378.pdf)
146 | * [d2hc_rmvsnet](https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123490647.pdf)
147 | * [mvsnet](https://arxiv.org/pdf/1804.02505.pdf)
148 | * [patchmatchnet](https://arxiv.org/pdf/2012.01411.pdf)
149 | * [ucsnet](https://arxiv.org/abs/1911.12012)
150 |
151 | Also specify the dataset among ``dtu_yao``, ``blended_mvs``, ``blended_mvg`` and the other training parameters, then hit:
152 |
153 | ```
154 | # python3 train.py --help to see the options
155 | $ python3 train.py
156 | ```
157 |
158 | The best model is stored in the ``output`` folder as ``model.ckpt``, along with a ``meta`` folder containing useful information about the executed training.
159 |
160 | ### Resume a training
161 |
162 | If something bad happens or if you stop the training process with a keyboard interrupt (``Ctrl-C``), the checkpoints will not be deleted and you can resume
163 | the training with the following option:
164 |
165 | ```
166 | # resume the last checkpoint saved in output/ckpts (last epoch)
167 | $ python3 train.py --resume-from-checkpoint
168 |
169 | # resume a chosen checkpoint stored elsewhere
170 | $ python3 train.py --resume-from-checkpoint my_checkpoint.ckpt
171 | ```
172 |
173 | This also takes care of properly updating the training logs on TensorBoard.
174 |
175 | ## Evaluation
176 |
177 | Once you have trained a model you can evaluate it using ``eval.py``. Here you have a few options, specifically:
178 |
179 | * the **dataset** on which to test
180 | * whether to evaluate using **guided** hints, **guided_integral** hints or none
181 | * the **hints density** to be used
182 | * the number of **views**
183 |
184 | ```
185 | # see the options
186 | $ python3 eval.py --help
187 | usage: eval.py [-h] [--dataset {dtu_yao,blended_mvs,blended_mvg}]
188 | [--hints {not_guided,guided,guided_integral}]
189 | [--hints-density HINTS_DENSITY] [--views VIEWS]
190 | [--loadckpt LOADCKPT] [--limit-scans LIMIT_SCANS]
191 | [--skip-steps {1,2,3} [{1,2,3} ...]] [--save-scans-output]
192 |
193 | # EXAMPLES
194 | # without guided hints on dtu_yao, 3 views
195 | $ python3 eval.py
196 |
197 | # with guided hints and 5 views and density of 0.01
198 | $ python3 eval.py --hints guided --views 5
199 |
200 | # with integral guided hints 3 views and 0.03 density
201 | $ python3 eval.py --hints guided_integral --hints-density 0.03
202 | ```
203 |
204 | Results will be stored under ``output/eval_<dataset>/[guided|not_guided|guided_integral]-[density=<density>]-views=<views>``; for instance, guiding on DTU with a 0.01 density and using 3 views, the results will be in:
205 |
206 | * ``output/eval_dtu_yao/guided-density=0.01-views=3/``
207 |
208 | Each of these folders will contain the point cloud for each testing scene and a ``metrics.json`` file containing the final metrics; these will differ depending on the dataset used for evaluation.
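For instance, the aggregated metrics of an evaluation run can be inspected with a short snippet like the following (the folder name matches the example above; the available keys depend on the evaluation dataset):

```python
import json

# evaluation run produced by eval.py, following the naming scheme described above
eval_dir = "output/eval_dtu_yao/guided-density=0.01-views=3"
with open(f"{eval_dir}/metrics.json") as f:
    metrics = json.load(f)
print(metrics)
```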
209 |
210 | ## Development
211 |
212 | ### Environment
213 |
214 | To develop, you have to create the conda virtual environment and **also** install the git hooks:
215 |
216 | ```bash
217 | $ conda env create -f environment.yml
218 | $ conda activate guided-mvs
219 | $ pre-commit install
220 | ```
221 |
222 | When you commit, [Black](https://github.com/psf/black) and [isort](https://pypi.org/project/isort/) will be executed on the modified
223 | files.
224 |
225 | ### VSCode specific settings
226 |
227 | If you use Visual Studio Code, its configuration and the recommended extensions are stored in ``.vscode``. Create a file in the root folder called ``.guided-mvs.code-workspace`` containing the following to load the conda environment properly:
228 |
229 | ```json
230 | {
231 | "folders": [
232 | {
233 | "path": "."
234 | }
235 | ],
236 | "settings": {
237 |         "python.pythonPath": "<path to the python interpreter of the guided-mvs conda env>"
238 | }
239 | }
240 | ```
241 |
--------------------------------------------------------------------------------
/data/README.md:
--------------------------------------------------------------------------------
1 | # Data Folder
2 |
3 | Here you have to put your datasets (even a soft link is enough)
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: guided-mvs
2 | channels:
3 | - pytorch-lts
4 | - conda-forge
5 | - defaults
6 | dependencies:
7 |
8 | # core
9 | - python=3.8
10 | - cudatoolkit=11.1
11 | - pytorch=1.8.*
12 | - pytorch-lightning=1.4.*
13 | - torchvision
14 | - numpy>=1.20.0
15 | - tqdm
16 | - pyyaml
17 | - scipy
18 | - tensorboard
19 | - joblib
20 | - pip
21 |
22 | # dev
23 | - black
24 | - isort
25 | - pre-commit
26 | - pytest
27 |
28 | - pip:
29 | - opencv-python
30 | - cmapy >=0.6.0, <1.0
31 | - plyfile
32 | - pillow
33 | - GitPython
34 | - mlflow
35 | - pysftp
36 |
--------------------------------------------------------------------------------
/guided_mvs_lib/__init__.py:
--------------------------------------------------------------------------------
1 | __version__ = "0.1.0"
2 |
--------------------------------------------------------------------------------
/guided_mvs_lib/datasets/__init__.py:
--------------------------------------------------------------------------------
1 | import multiprocessing
2 | from pathlib import Path
3 | from typing import Callable, List, Literal, Optional, Union
4 |
5 | from pytorch_lightning import LightningDataModule
6 | from torch.utils.data import DataLoader
7 |
8 | from . import blended_mvg_utils, blended_mvs_utils, dtu_utils
9 | from .dtu_blended_mvs import MVSDataset, MVSSample
10 |
11 | __all__ = ["MVSDataModule", "MVSSample", "find_dataset_def", "find_scans"]
12 |
13 |
14 | class MVSDataModule(LightningDataModule):
15 | def __init__(
16 | self,
17 | # selection of the dataset
18 | name: Literal["dtu_yao", "blended_mvs", "blended_mvg"],
19 | # args for the dataloader
20 | batch_size: int = 1,
21 | num_workers: int = multiprocessing.cpu_count() // 2,
22 | # args for the dataset
23 | **kwargs,
24 | ):
25 | super().__init__()
26 | self._ds_builder = find_dataset_def(name, kwargs.pop("datapath", None))
27 | self._ds_args = kwargs
28 |
29 | # dataloader args
30 | self.batch_size = batch_size
31 | self.num_workers = num_workers
32 |
33 | def setup(self, stage: Optional[str] = None):
34 | if stage in ("fit", None):
35 | self.mvs_train = self._ds_builder(mode="train", **self._ds_args)
36 | self.mvs_val = self._ds_builder(mode="val", **self._ds_args)
37 | if stage in ("test", None):
38 | self.mvs_test = self._ds_builder(mode="test", **self._ds_args)
39 |
40 | def train_dataloader(self):
41 | return DataLoader(
42 | self.mvs_train,
43 | batch_size=self.batch_size,
44 | shuffle=True,
45 | num_workers=self.num_workers,
46 | )
47 |
48 | def val_dataloader(self):
49 | return DataLoader(
50 | self.mvs_val,
51 | batch_size=1,
52 | shuffle=False,
53 | num_workers=self.num_workers,
54 | )
55 |
56 | def test_dataloader(self):
57 | return DataLoader(
58 | self.mvs_test,
59 | batch_size=1,
60 | shuffle=False,
61 | num_workers=self.num_workers,
62 | )
63 |
64 |
65 | def find_dataset_def(
66 | name: Literal["dtu_yao", "blended_mvs", "blended_mvg"],
67 | datapath: Union[Path, str, None] = None,
68 | ) -> Callable[..., MVSDataset]:
69 | assert name in ["dtu_yao", "blended_mvs", "blended_mvg"]
70 | if datapath is None:
71 | datapath = {
72 | "dtu_yao": "data/dtu",
73 | "blended_mvs": "data/blended-mvs",
74 | "blended_mvg": "data/blended-mvs",
75 | }[name]
76 | datapath = Path(datapath)
77 |
78 | def builder(*args, **kwargs):
79 | return MVSDataset(name, datapath, *args, **kwargs)
80 |
81 | return builder
82 |
83 |
84 | _SCANS = {
85 | "dtu_yao": {
86 | "train": dtu_utils.train_scans(),
87 | "val": dtu_utils.val_scans(),
88 | "test": dtu_utils.test_scans(),
89 | },
90 | "blended_mvs": {
91 | "train": blended_mvs_utils.train_scans(),
92 | "val": blended_mvs_utils.val_scans(),
93 | "test": blended_mvs_utils.test_scans(),
94 | },
95 | "blended_mvg": {
96 | "train": blended_mvg_utils.train_scans(),
97 | "val": blended_mvg_utils.val_scans(),
98 | "test": blended_mvg_utils.test_scans(),
99 | },
100 | }
101 |
102 |
103 | def find_scans(
104 | dataset_name: Literal["dtu_yao", "blended_mvs", "blended_mvg"],
105 | split: Literal["train", "val", "test"],
106 | ) -> Optional[List[str]]:
107 | try:
108 | return _SCANS[dataset_name][split]
109 | except KeyError:
110 | if dataset_name not in ["dtu_yao", "blended_mvs", "blended_mvg"]:
111 | raise ValueError(f"{dataset_name} not in dtu_utils, blended_mvs, blended_mvg")
112 | elif split not in ["train", "val", "test"]:
113 | raise ValueError(f"{split} not in train, val, test")
114 |
--------------------------------------------------------------------------------
/guided_mvs_lib/datasets/blended_mvs_utils.py:
--------------------------------------------------------------------------------
1 | import os
2 | from pathlib import Path
3 | from typing import List, Literal, Tuple, Union
4 |
5 | import numpy as np
6 |
7 | from .utils import DatasetExamplePaths, read_pfm
8 |
9 | __all__ = [
10 | "datapath_files",
11 | "read_cam_file",
12 | "read_depth_mask",
13 | "read_depth",
14 | "build_list",
15 | "train_scans",
16 | "val_scans",
17 | "test_scans",
18 | ]
19 |
20 |
21 | def datapath_files(
22 | datapath: Union[Path, str],
23 | scan: str,
24 | view_id: int,
25 | split: Literal["train", "test", "val"],
26 | light_id: int = None,
27 | ) -> DatasetExamplePaths:
28 | """
29 | Takes in input the root of the Blended MVS dataset and returns a dictionary containing
30 |     the paths to the files of a single scan, with the specified view_id and light_id
31 |     (the latter is unused for Blended MVS and kept only for interface compatibility)
32 |
33 | Parameters
34 | ----------
35 | datapath: path
36 |         path to the root of the Blended MVS dataset
37 | scan: str
38 | name of the used scan, say scan1
39 | view_id: int
40 | the index of the specific image in the scan
41 | split: train, val or test
42 |         which split of Blended MVS must be used to search the scan for
43 | light_id: int
44 |         the index of the specific lighting condition
45 |
46 | Returns
47 | -------
48 | out: Dict[str, Path]
49 | returns a dictionary containing the paths taken into account
50 | """
51 | root = Path(datapath)
52 | return {
53 | "img": root / f"{scan}/blended_images/{view_id:0>8}.jpg",
54 | "proj_mat": root / f"{scan}/cams/{view_id:0>8}_cam.txt",
55 | "depth_mask": root / f"{scan}/rendered_depth_maps/{view_id:0>8}.pfm",
56 | "depth": root / f"{scan}/rendered_depth_maps/{view_id:0>8}.pfm",
57 | "pcd": None,
58 | "obs_mask": None,
59 | "ground_plane": None,
60 | }
61 |
62 |
63 | def read_cam_file(path: str) -> Tuple[np.ndarray, np.ndarray, float, float]:
64 | """
65 | Reads a file containing the Blended MVS camera intrinsics, extrinsics, max depth and
66 | min depth.
67 |
68 | Parameters
69 | ----------
70 | path: str
71 | path of the source file (something like ../00000000_cam.txt)
72 |
73 | Returns
74 | -------
75 | out: Tuple[np.ndarray, np.ndarray, float, float]
76 | respectively intrinsics, extrinsics, min depth and max depth
77 | """
78 | with open(path) as f:
79 | lines = f.readlines()
80 | lines = [line.rstrip() for line in lines]
81 | extrinsics = np.fromstring(" ".join(lines[1:5]), dtype=np.float32, sep=" ").reshape((4, 4))
82 | intrinsics = np.fromstring(" ".join(lines[7:10]), dtype=np.float32, sep=" ").reshape((3, 3))
83 | depth_min = float(lines[11].split()[0])
84 | depth_max = float(lines[11].split()[3])
85 | return intrinsics, extrinsics, depth_min, depth_max
86 |
87 |
88 | def read_depth_mask(path: Union[str, Path]) -> np.ndarray:
89 | """
90 | Loads the Blended MVS depth mask
91 | """
92 | mask = np.array(read_pfm(path)[0], dtype=np.float32) > 0
93 | return mask
94 |
95 |
96 | def read_depth(path: Union[str, Path]) -> np.ndarray:
97 | """
98 |     Loads the Blended MVS depth map
99 | """
100 | depth = np.array(read_pfm(str(path))[0], dtype=np.float32)
101 | return depth
102 |
103 |
104 | def build_list(datapath: str, scans: list, nviews: int):
105 | metas = []
106 | for scan in scans:
107 | pair_file = "cams/pair.txt"
108 |
109 | with open(os.path.join(datapath, scan, pair_file)) as f:
110 | num_viewpoint = int(f.readline())
111 | for view_idx in range(num_viewpoint):
112 | ref_view = int(f.readline().rstrip())
113 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
114 | if len(src_views) >= nviews:
115 | metas.append((scan, None, ref_view, src_views))
116 | return metas
117 |
118 |
119 | def train_scans() -> List[str]:
120 | return [
121 | "5c1f33f1d33e1f2e4aa6dda4",
122 | "5bfe5ae0fe0ea555e6a969ca",
123 | "5bff3c5cfe0ea555e6bcbf3a",
124 | "58eaf1513353456af3a1682a",
125 | "5bfc9d5aec61ca1dd69132a2",
126 | "5bf18642c50e6f7f8bdbd492",
127 | "5bf26cbbd43923194854b270",
128 | "5bf17c0fd439231948355385",
129 | "5be3ae47f44e235bdbbc9771",
130 | "5be3a5fb8cfdd56947f6b67c",
131 | "5bbb6eb2ea1cfa39f1af7e0c",
132 | "5ba75d79d76ffa2c86cf2f05",
133 | # "5bb7a08aea1cfa39f1a947ab",
134 | "5b864d850d072a699b32f4ae",
135 | "5b6eff8b67b396324c5b2672",
136 | "5b6e716d67b396324c2d77cb",
137 | "5b69cc0cb44b61786eb959bf",
138 | "5b62647143840965efc0dbde",
139 | "5b60fa0c764f146feef84df0",
140 | "5b558a928bbfb62204e77ba2",
141 | "5b271079e0878c3816dacca4",
142 | "5b08286b2775267d5b0634ba",
143 | "5afacb69ab00705d0cefdd5b",
144 | "5af28cea59bc705737003253",
145 | # "5af02e904c8216544b4ab5a2",
146 | "5aa515e613d42d091d29d300",
147 | "5c34529873a8df509ae57b58",
148 | "5c34300a73a8df509add216d",
149 | "5c1af2e2bee9a723c963d019",
150 | "5c1892f726173c3a09ea9aeb",
151 | "5c0d13b795da9479e12e2ee9",
152 | "5c062d84a96e33018ff6f0a6",
153 | "5bfd0f32ec61ca1dd69dc77b",
154 | "5bf21799d43923194842c001",
155 | "5bf3a82cd439231948877aed",
156 | "5bf03590d4392319481971dc",
157 | "5beb6e66abd34c35e18e66b9",
158 | # "5be883a4f98cee15019d5b83",
159 | "5be47bf9b18881428d8fbc1d",
160 | "5bcf979a6d5f586b95c258cd",
161 | "5bce7ac9ca24970bce4934b6",
162 | "5bb8a49aea1cfa39f1aa7f75",
163 | "5b78e57afc8fcf6781d0c3ba",
164 | # "5b21e18c58e2823a67a10dd8",
165 | "5b22269758e2823a67a3bd03",
166 | "5b192eb2170cf166458ff886",
167 | "5ae2e9c5fe405c5076abc6b2",
168 | "5adc6bd52430a05ecb2ffb85",
169 | "5ab8b8e029f5351f7f2ccf59",
170 | "5abc2506b53b042ead637d86",
171 | "5ab85f1dac4291329b17cb50",
172 | "5a969eea91dfc339a9a3ad2c",
173 | "5a8aa0fab18050187cbe060e",
174 | # "5a7d3db14989e929563eb153",
175 | "5a69c47d0d5d0a7f3b2e9752",
176 | "5a618c72784780334bc1972d",
177 | "5a6464143d809f1d8208c43c",
178 | "5a588a8193ac3d233f77fbca",
179 | "5a57542f333d180827dfc132",
180 | "5a572fd9fc597b0478a81d14",
181 | "5a563183425d0f5186314855",
182 | "5a4a38dad38c8a075495b5d2",
183 | "5a48d4b2c7dab83a7d7b9851",
184 | "5a489fb1c7dab83a7d7b1070",
185 | # "5a48ba95c7dab83a7d7b44ed",
186 | "5a3ca9cb270f0e3f14d0eddb",
187 | "5a3cb4e4270f0e3f14d12f43",
188 | "5a3f4aba5889373fbbc5d3b5",
189 | "5a0271884e62597cdee0d0eb",
190 | "59e864b2a9e91f2c5529325f",
191 | "599aa591d5b41f366fed0d58",
192 | "59350ca084b7f26bf5ce6eb8",
193 | "59338e76772c3e6384afbb15",
194 | "5c20ca3a0843bc542d94e3e2",
195 | "5c1dbf200843bc542d8ef8c4",
196 | "5c1b1500bee9a723c96c3e78",
197 | "5bea87f4abd34c35e1860ab5",
198 | "5c2b3ed5e611832e8aed46bf",
199 | "57f8d9bbe73f6760f10e916a",
200 | "5bf7d63575c26f32dbf7413b",
201 | # "5be4ab93870d330ff2dce134",
202 | "5bd43b4ba6b28b1ee86b92dd",
203 | "5bccd6beca24970bce448134",
204 | "5bc5f0e896b66a2cd8f9bd36",
205 | "5b908d3dc6ab78485f3d24a9",
206 | "5b2c67b5e0878c381608b8d8",
207 | "5b4933abf2b5f44e95de482a",
208 | "5b3b353d8d46a939f93524b9",
209 | "5acf8ca0f3d8a750097e4b15",
210 | "5ab8713ba3799a1d138bd69a",
211 | "5aa235f64a17b335eeaf9609",
212 | # "5aa0f9d7a9efce63548c69a1",
213 | "5a8315f624b8e938486e0bd8",
214 | "5a48c4e9c7dab83a7d7b5cc7",
215 | "59ecfd02e225f6492d20fcc9",
216 | "59f87d0bfa6280566fb38c9a",
217 | "59f363a8b45be22330016cad",
218 | "59f70ab1e5c5d366af29bf3e",
219 | "59e75a2ca9e91f2c5526005d",
220 | "5947719bf1b45630bd096665",
221 | "5947b62af1b45630bd0c2a02",
222 | "59056e6760bb961de55f3501",
223 | "58f7f7299f5b5647873cb110",
224 | "58cf4771d0f5fb221defe6da",
225 | "58d36897f387231e6c929903",
226 | "58c4bb4f4a69c55606122be4",
227 | ]
228 |
229 |
230 | def val_scans() -> List[str]:
231 | return [
232 | "5bb7a08aea1cfa39f1a947ab",
233 | "5af02e904c8216544b4ab5a2",
234 | "5be883a4f98cee15019d5b83",
235 | "5b21e18c58e2823a67a10dd8",
236 | "5a7d3db14989e929563eb153",
237 | "5a48ba95c7dab83a7d7b44ed",
238 | "5be4ab93870d330ff2dce134",
239 | "5aa0f9d7a9efce63548c69a1",
240 | ]
241 |
242 |
243 | def test_scans() -> List[str]:
244 | return [
245 | "5b7a3890fc8fcf6781e2593a",
246 | "5c189f2326173c3a09ed7ef3",
247 | "5b950c71608de421b1e7318f",
248 | "5a6400933d809f1d8200af15",
249 | "59d2657f82ca7774b1ec081d",
250 | "5ba19a8a360c7c30c1c169df",
251 | "59817e4a1bd4b175e7038d19",
252 | ]
253 |
--------------------------------------------------------------------------------
/guided_mvs_lib/datasets/dtu_blended_mvs.py:
--------------------------------------------------------------------------------
1 | import os
2 | import random
3 | from collections import defaultdict
4 | from typing import Callable, Dict, List, Literal, Optional, TypedDict
5 |
6 | import cv2
7 | import numpy as np
8 | import scipy.io
9 | import torch
10 | from PIL import Image
11 | from plyfile import PlyData
12 | from torch.utils.data import Dataset
13 |
14 | import guided_mvs_lib.datasets.blended_mvg_utils as mvg_utils
15 | import guided_mvs_lib.datasets.blended_mvs_utils as mvs_utils
16 | import guided_mvs_lib.datasets.dtu_utils as dtu_utils
17 |
18 |
19 | class MVSSample(TypedDict):
20 | imgs: List[Image.Image]
21 | intrinsics: List[np.ndarray]
22 | extrinsics: List[np.ndarray]
23 | depths: List[np.ndarray]
24 | ref_depth_min: float
25 | ref_depth_max: float
26 | ref_depth_values: np.ndarray
27 | filename: str
28 | scan_pcd: Optional[np.ndarray]
29 | scan_pcd_obs_mask: Optional[np.ndarray]
30 | scan_pcd_bounding_box: Optional[np.ndarray]
31 | scan_pcd_resolution: Optional[float]
32 | scan_pcd_ground_plane: Optional[np.ndarray]
33 |
34 |
35 | def _identity_fn(x: MVSSample) -> Dict:
36 | return x
37 |
38 |
39 | class MVSDataset(Dataset):
40 | def __init__(
41 | self,
42 | name: Literal["dtu_yao", "blended_mvs", "blended_mvg"],
43 | datapath: str,
44 | mode: Literal["train", "val", "test"],
45 | nviews: int = 5,
46 | ndepths: int = 192,
47 | robust_train: bool = False,
48 | transform: Callable[[Dict], Dict] = _identity_fn,
49 | ):
50 | super().__init__()
51 | assert mode in ["train", "val", "test"], "MVSDataset train, val or test"
52 |
53 | if name == "dtu_yao":
54 | self.ds_utils = dtu_utils
55 | elif name == "blended_mvs":
56 | self.ds_utils = mvs_utils
57 | elif name == "blended_mvg":
58 | self.ds_utils = mvg_utils
59 | else:
60 | raise ValueError("datasets supported: dtu_yao, blended_mvs, blended_mvg")
61 |
62 | self.datapath = datapath
63 | self.transform = transform
64 | self.name = name
65 | self.mode = mode
66 | self.nviews = nviews
67 | self.ndepths = ndepths
68 | self.robust_train = robust_train
69 |
70 | scans = {
71 | "train": self.ds_utils.train_scans,
72 | "val": self.ds_utils.val_scans,
73 | "test": self.ds_utils.test_scans,
74 | }[mode]()
75 | if scans is None:
76 | raise ValueError(f"{mode} not supported on dataset {self.name}")
77 |
78 | self.metas = self.ds_utils.build_list(self.datapath, scans, nviews)
79 |
80 | def __len__(self):
81 | return len(self.metas)
82 |
83 | def _resize_depth_dtu(self, depth, size):
84 | h, w, _ = depth.shape
85 | depth = cv2.resize(depth, (w // 2, h // 2), interpolation=cv2.INTER_NEAREST)[..., None]
86 | h, w, _ = depth.shape
87 | start_h, start_w = (h - size[0]) // 2, (w - size[1]) // 2
88 | depth = depth[start_h : start_h + size[0], start_w : start_w + size[1]]
89 | return depth
90 |
91 | def __getitem__(self, idx: int) -> MVSSample:
92 |
93 | # load info
94 | scan, light_idx, ref_view, src_views = self.metas[idx]
95 | if self.mode in ["train", "val"] and self.robust_train:
96 | num_src_views = len(src_views)
97 | index = random.sample(range(num_src_views), self.nviews - 1)
98 | view_ids = [ref_view] + [src_views[i] for i in index]
99 | else:
100 | view_ids = [ref_view] + src_views[: self.nviews - 1]
101 |
102 | # collect the reference depth, the images and meta for this example
103 | out = defaultdict(
104 | lambda: [],
105 | scan_pcd=None,
106 | scan_pcd_obs_mask=None,
107 | scan_pcd_bounding_box=None,
108 | scan_pcd_resolution=None,
109 | scan_pcd_ground_plane=None,
110 | )
111 | for i, vid in enumerate(view_ids):
112 | datapaths = self.ds_utils.datapath_files(
113 | self.datapath, scan, vid, self.mode, light_idx
114 | )
115 | out["imgs"].append(Image.open(datapaths["img"]))
116 |
117 | if datapaths["pcd"] is not None and i == 0:
118 | mesh = PlyData.read(datapaths["pcd"])
119 | out["scan_pcd"] = np.stack(
120 | [mesh["vertex"]["x"], mesh["vertex"]["y"], mesh["vertex"]["z"]], -1
121 | )
122 | obs_mask_data = scipy.io.loadmat(datapaths["obs_mask"])
123 | out["scan_pcd_obs_mask"] = obs_mask_data["ObsMask"]
124 | out["scan_pcd_bounding_box"] = obs_mask_data["BB"]
125 | out["scan_pcd_resolution"] = obs_mask_data["Res"]
126 | out["scan_pcd_ground_plane"] = scipy.io.loadmat(datapaths["ground_plane"])["P"]
127 |
128 | # build proj matrix
129 | intrinsics, extrinsics, depth_min, depth_max = self.ds_utils.read_cam_file(
130 | datapaths["proj_mat"]
131 | )
132 | if self.name == "dtu_yao" and self.mode != "test":
133 | # preprocessed images loaded by dtu_yao train and val have been halved in dimension
134 | # and then cropped to (512, 640)
135 | intrinsics[:2] *= 0.5
136 | intrinsics[0, 2] *= 640 / 800
137 | intrinsics[1, 2] *= 512 / 600
138 | out["intrinsics"].append(intrinsics)
139 | out["extrinsics"].append(extrinsics)
140 |
141 | mask = self.ds_utils.read_depth_mask(datapaths["depth_mask"])
142 | depth = self.ds_utils.read_depth(datapaths["depth"]) * mask
143 |
144 | if self.name == "dtu_yao" and self.mode != "test":
145 | # the original depth map of dtu must be brought to (512, 640) when in training phase
146 | mask = self._resize_depth_dtu(mask.astype(np.float32), (512, 640))
147 | depth = self._resize_depth_dtu(depth, (512, 640))
148 |
149 | out["depths"].append(depth)
150 | if i == 0:
151 | out["ref_depth_min"] = depth_min
152 | out["ref_depth_max"] = depth_max
153 | out["ref_depth_values"] = np.arange(
154 | depth_min,
155 | depth_max,
156 | (depth_max - depth_min) / self.ndepths,
157 | dtype=np.float32,
158 | )
159 |
160 | out["filename"] = os.path.join(scan, "{}", f"{view_ids[0]:0>8}" + "{}")
161 | return self.transform(dict(out))
162 |
--------------------------------------------------------------------------------
/guided_mvs_lib/datasets/dtu_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utility functions to load the DTU dataset
3 | """
4 |
5 | from pathlib import Path
6 | from typing import Dict, Literal, Tuple, Union
7 |
8 | import numpy as np
9 | from PIL import Image
10 |
11 | from .utils import DatasetExamplePaths, read_pfm, save_pfm
12 |
13 | __all__ = [
14 | "datapath_files",
15 | "read_cam_file",
16 | "read_depth_mask",
17 | "read_depth",
18 | "build_list",
19 | "train_scans",
20 | "val_scans",
21 | "test_scans",
22 | ]
23 |
24 |
25 | def datapath_files(
26 | datapath: Union[Path, str],
27 | scan: str,
28 | view_id: int,
29 | split: Literal["train", "test", "val"],
30 | light_id: int = None,
31 | ) -> DatasetExamplePaths:
32 | """
33 | Takes in input the root of the DTU dataset and returns a dictionary containing
34 |     the paths to the files of a single scan, with the specified view_id and light_id
35 | (this last one is used only when ``split`` isn't test)
36 |
37 | Parameters
38 | ----------
39 | datapath: path
40 | path to the root of the dtu dataset
41 | scan: str
42 | name of the used scan, say scan1
43 | view_id: int
44 | the index of the specific image in the scan
45 | split: train, val or test
46 | which split of DTU must be used to search the scan for
47 | light_id: int
48 |         the index of the specific lighting condition
49 |
50 | Returns
51 | -------
52 | out: Dict[str, Path]
53 | returns a dictionary containing the paths taken into account
54 | """
55 | assert split in ["train", "val", "test"]
56 | root = Path(datapath)
57 | if split in ["train", "val"]:
58 | root = root / "train_data"
59 | return {
60 | "img": root / f"Rectified/{scan}_train/rect_{view_id + 1:0>3}_{light_id}_r5000.png",
61 | "proj_mat": root / f"Cameras_1/{view_id:0>8}_cam.txt",
62 | "depth_mask": root / f"Depths_raw/{scan}/depth_visual_{view_id:0>4}.png",
63 | "depth": root / f"Depths_raw/{scan}/depth_map_{view_id:0>4}.pfm",
64 | "pcd": None,
65 | "obs_mask": None,
66 | "ground_plane": None,
67 | }
68 | else:
69 | scan_idx = int(scan[4:])
70 | root = root / "test_data"
71 | return {
72 | "img": root / f"{scan}/images/{view_id:0>8}.jpg",
73 | "proj_mat": root / f"{scan}/cams_1/{view_id:0>8}_cam.txt",
74 | "depth_mask": root.parent
75 | / f"train_data/Depths_raw/{scan}/depth_visual_{view_id:0>4}.png",
76 | "depth": root.parent / f"train_data/Depths_raw/{scan}/depth_map_{view_id:0>4}.pfm",
77 | "pcd": root.parent / f"SampleSet/MVS Data/Points/stl/stl{scan_idx:0>3}_total.ply",
78 | "obs_mask": root.parent / f"SampleSet/MVS Data/ObsMask/ObsMask{scan_idx}_10.mat",
79 | "ground_plane": root.parent / f"SampleSet/MVS Data/ObsMask/Plane{scan_idx}.mat",
80 | }
81 |
82 |
83 | def read_cam_file(path: str) -> Tuple[np.ndarray, np.ndarray, float, float]:
84 | """
85 | Reads a file containing the DTU camera intrinsics, extrinsics, max depth and
86 | min depth.
87 |
88 | Parameters
89 | ----------
90 | path: str
91 | path of the source file (something like ../00000000_cam.txt)
92 |
93 | Returns
94 | -------
95 | out: Tuple[np.ndarray, np.ndarray, float, float]
96 | respectively intrinsics, extrinsics, min depth and max depth
97 | """
98 | with open(path) as f:
99 | lines = f.readlines()
100 | lines = [line.rstrip() for line in lines]
101 | extrinsics = np.fromstring(" ".join(lines[1:5]), dtype=np.float32, sep=" ").reshape((4, 4))
102 | intrinsics = np.fromstring(" ".join(lines[7:10]), dtype=np.float32, sep=" ").reshape((3, 3))
103 | depth_min = float(lines[11].split()[0])
104 | depth_max = float(lines[11].split()[1])
105 | return intrinsics, extrinsics, depth_min, depth_max
106 |
107 |
108 | def read_depth_mask(path: Union[str, Path]) -> np.ndarray:
109 | """
110 |     Loads the DTU depth mask
111 | """
112 | img = np.array(Image.open(path))[..., None] > 10
113 | return img
114 |
115 |
116 | def read_depth(path: Union[str, Path]) -> np.ndarray:
117 | """
118 |     Loads the DTU depth map
119 | """
120 | depth = np.array(read_pfm(str(path))[0], dtype=np.float32)
121 | return depth
122 |
123 |
124 | # LIST OF SCANS
125 |
126 |
127 | def build_list(datapath: Union[str, Path], scans: list, nviews: int):
128 | metas = []
129 |
130 | datapath = Path(datapath)
131 | for scan in scans:
132 |
133 | # find pair file
134 | pair_file_test = datapath / f"test_data/{scan}/pair.txt"
135 | pair_file_train = datapath / "train_data/Cameras_1/pair.txt"
136 | if not pair_file_test.exists():
137 | if not pair_file_train.exists():
138 | raise ValueError(f"scan {scan} not found")
139 | else:
140 | pair_file = pair_file_train
141 | split = "train"
142 | else:
143 | pair_file = pair_file_test
144 | split = "test"
145 |
146 | # use pair file
147 | with open(pair_file, "rt") as f:
148 | num_viewpoint = int(f.readline())
149 | for _ in range(num_viewpoint):
150 | ref_view = int(f.readline().rstrip())
151 | src_views = [int(x) for x in f.readline().rstrip().split()[1::2]]
152 |
153 | if split == "train" and len(src_views) >= nviews:
154 | for light_idx in range(7):
155 | metas.append((scan, light_idx, ref_view, src_views))
156 | else:
157 | metas.append((scan, None, ref_view, src_views))
158 |
159 | return metas
160 |
161 |
162 | def train_scans():
163 | return [
164 | "scan2",
165 | "scan6",
166 | "scan7",
167 | "scan8",
168 | "scan14",
169 | "scan16",
170 | "scan18",
171 | "scan19",
172 | "scan20",
173 | "scan22",
174 | "scan30",
175 | "scan31",
176 | "scan36",
177 | "scan39",
178 | "scan41",
179 | "scan42",
180 | "scan44",
181 | "scan45",
182 | "scan46",
183 | "scan47",
184 | "scan50",
185 | "scan51",
186 | "scan52",
187 | "scan53",
188 | "scan55",
189 | "scan57",
190 | "scan58",
191 | "scan60",
192 | "scan61",
193 | "scan63",
194 | "scan64",
195 | "scan65",
196 | "scan68",
197 | "scan69",
198 | "scan70",
199 | "scan71",
200 | "scan72",
201 | "scan74",
202 | "scan76",
203 | "scan83",
204 | "scan84",
205 | "scan85",
206 | "scan87",
207 | "scan88",
208 | "scan89",
209 | "scan90",
210 | "scan91",
211 | "scan92",
212 | "scan93",
213 | "scan94",
214 | "scan95",
215 | "scan96",
216 | "scan97",
217 | "scan98",
218 | "scan99",
219 | "scan100",
220 | "scan101",
221 | "scan102",
222 | "scan103",
223 | "scan104",
224 | "scan105",
225 | "scan107",
226 | "scan108",
227 | "scan109",
228 | "scan111",
229 | "scan112",
230 | "scan113",
231 | "scan115",
232 | "scan116",
233 | "scan119",
234 | "scan120",
235 | "scan121",
236 | "scan122",
237 | "scan123",
238 | "scan124",
239 | "scan125",
240 | "scan126",
241 | "scan127",
242 | "scan128",
243 | ]
244 |
245 |
246 | def val_scans():
247 | return [
248 | "scan3",
249 | "scan5",
250 | "scan17",
251 | "scan21",
252 | "scan28",
253 | "scan35",
254 | "scan37",
255 | "scan38",
256 | "scan40",
257 | "scan43",
258 | "scan56",
259 | "scan59",
260 | "scan66",
261 | "scan67",
262 | "scan82",
263 | "scan86",
264 | "scan106",
265 | "scan117",
266 | ]
267 |
268 |
269 | def test_scans():
270 | return [
271 | "scan1",
272 | "scan4",
273 | "scan9",
274 | "scan10",
275 | "scan11",
276 | "scan12",
277 | "scan13",
278 | "scan15",
279 | "scan23",
280 | "scan24",
281 | "scan29",
282 | "scan32",
283 | "scan33",
284 | "scan34",
285 | "scan48",
286 | "scan49",
287 | "scan62",
288 | "scan75",
289 | "scan77",
290 | "scan110",
291 | "scan114",
292 | "scan118",
293 | ]
294 |
--------------------------------------------------------------------------------
/guided_mvs_lib/datasets/sample_preprocess.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from typing import Dict, Literal, Tuple
3 |
4 | import cv2
5 | import numpy as np
6 | import torch
7 | import torchvision.transforms.functional as F
8 | from PIL import Image
9 |
10 | from guided_mvs_lib.datasets.dtu_blended_mvs import MVSSample
11 |
12 | __all__ = ["MVSSampleTransform"]
13 |
14 |
15 | class MVSSampleTransform:
16 | """
17 |     Preprocesses each sample computing each view at 4 different scales, converts the
18 |     sample to torch Tensors and computes hints in various ways
19 | """
20 |
21 | def __init__(
22 | self,
23 | generate_hints: Literal[
24 | "not_guided", "guided", "mvguided", "mvguided_filtered"
25 | ] = "not_guided",
26 | hints_perc: float = 0.01,
27 | filtering_window: Tuple[int, int] = (9, 9),
28 | ):
29 | assert generate_hints in ["not_guided", "guided", "mvguided", "mvguided_filtered"]
30 |
31 | self.generate_hints = generate_hints
32 | self._hints_perc = hints_perc
33 | self._height_bin = (filtering_window[0] // 2) - 1
34 | self._width_bin = (filtering_window[1] // 2) - 1
35 |
36 | def _generate_hints(self, sample: Dict) -> torch.Tensor:
37 |
38 | if self.generate_hints == "guided":
39 | # use only the ref depth map
40 | depth = sample["depths"][0]
41 | valid_hints = (depth > 0) & (np.random.rand(*depth.shape) <= self._hints_perc)
42 | hints = depth * valid_hints
43 | return torch.from_numpy(hints).permute(2, 0, 1)
44 | elif self.generate_hints in ["mvguided", "mvguided_filtered"]:
45 |
46 | hints_perc = self._hints_perc
47 |
48 | # extract hints from each depth map
49 | unwarped_hints = []
50 | for depth in sample["depths"]:
51 | valid_hints = (depth > 0) & (np.random.rand(*depth.shape) <= hints_perc)
52 | unwarped_hints.append(torch.from_numpy(depth * valid_hints).permute(2, 0, 1)[None])
53 |
54 | # project hints from other depth maps in the ref one
55 | ref_hints, ref_mask = unwarped_hints[0], unwarped_hints[0] > 0
56 | for i in range(1, len(unwarped_hints)):
57 | proj_mat_ref = sample["extrinsics"][0].copy()
58 | proj_mat_ref[:3, :4] = sample["intrinsics"][0] @ proj_mat_ref[:3, :4]
59 | proj_mat_src = sample["extrinsics"][i].copy()
60 | proj_mat_src[:3, :4] = sample["intrinsics"][i] @ proj_mat_src[:3, :4]
61 |
62 | warped_hints, warped_mask = hints_homo_warping(
63 | unwarped_hints[i],
64 | (unwarped_hints[i] > 0).to(torch.float32),
65 | torch.from_numpy(proj_mat_ref[None]),
66 | torch.from_numpy(proj_mat_src[None]),
67 | )
68 | assign_mask = (ref_mask * (1 - warped_mask)) == 0
69 | ref_hints[assign_mask] = warped_hints[assign_mask]
70 | ref_mask = ref_mask | warped_mask.to(torch.bool)
71 |
72 | # filter if required
73 | ref_hints = ref_hints[0] * ref_mask[0]
74 | if self.generate_hints == "mvguided_filtered":
75 | hints_mask = outlier_removal_mask(
76 | ref_hints.numpy().transpose([1, 2, 0]),
77 | sample["intrinsics"][0],
78 | height_bin=self._height_bin,
79 | width_bin=self._width_bin,
80 | )
81 | ref_hints[~torch.from_numpy(hints_mask).to(torch.bool).permute(2, 0, 1)] = 0
82 |
83 | return ref_hints
84 |
85 | def _split_pad(self, pad):
86 | if pad % 2 == 0:
87 | return pad // 2, pad // 2
88 | else:
89 | pad_1 = pad // 2
90 | pad_2 = (pad // 2) + 1
91 | return pad_1, pad_2
92 |
93 | def _pad_to_div_by(self, x, *, div_by=8):
94 |
95 | # compute padding
96 | if isinstance(x, Image.Image):
97 | w, h = x.size
98 | elif isinstance(x, np.ndarray):
99 | h, w, _ = x.shape
100 | else:
101 | raise ValueError("Image or np.ndarray")
102 |
103 | new_h = int(np.ceil(h / div_by)) * div_by
104 | new_w = int(np.ceil(w / div_by)) * div_by
105 | pad_t, pad_b = self._split_pad(new_h - h)
106 | pad_l, pad_r = self._split_pad(new_w - w)
107 |
108 | # return PIL or np.ndarray
109 | if isinstance(x, Image.Image):
110 | return F.pad(x, (pad_l, pad_t, pad_r, pad_b))
111 | elif isinstance(x, np.ndarray):
112 | return np.pad(x, [(pad_t, pad_b), (pad_l, pad_r), (0, 0)])
113 |
114 | def __call__(self, sample: MVSSample) -> Dict:
115 |
116 | # padding
117 | imgs = []
118 | for img, intrins in zip(sample["imgs"], sample["intrinsics"]):
119 |
120 | # pad the image
121 | w, h = img.size
122 | imgs.append(self._pad_to_div_by(img, div_by=32))
123 | w_new, h_new = imgs[-1].size
124 |
125 | # adapt intrinsics
126 | pad_w = (w_new - w) / 2
127 | ratio = (w + pad_w) / w
128 | intrins[0, 2] = intrins[0, 2] * ratio
129 | pad_h = (h_new - h) / 2
130 | ratio = (h + pad_h) / h
131 | intrins[1, 2] = intrins[1, 2] * ratio
132 |
133 | sample["imgs"] = imgs
134 | sample["depths"] = [self._pad_to_div_by(x, div_by=32) for x in sample["depths"]]
135 |
136 | # compute downsampled stages
137 | imgs, proj_matrices = defaultdict(lambda: []), defaultdict(lambda: [])
138 | depths = {}
139 |
140 | w, h = sample["imgs"][0].size
141 | for i in range(4):
142 | dsize = (w // (2 ** i), h // (2 ** i))
143 |
144 | for img in sample["imgs"]:
145 | imgs[f"stage_{i}"].append(F.to_tensor(cv2.resize(np.array(img), dsize)))
146 |
147 | for extrinsics, intrinsics in zip(sample["extrinsics"], sample["intrinsics"]):
148 | proj_mat = extrinsics.copy()
149 | intrinsics_copy = intrinsics.copy()
150 | intrinsics_copy[:2, :] = intrinsics_copy[:2, :] / (2 ** i)
151 | proj_mat[:3, :4] = intrinsics_copy @ proj_mat[:3, :4]
152 | proj_matrices[f"stage_{i}"].append(torch.from_numpy(proj_mat))
153 |
154 | depths[f"stage_{i}"] = F.to_tensor(cv2.resize(sample["depths"][0], dsize))
155 |
156 | # return result
157 | out = {
158 | "imgs": {k: torch.stack(v) for k, v in imgs.items()},
159 | "depth": depths,
160 | "proj_matrices": {k: torch.stack(v) for k, v in proj_matrices.items()},
161 | "intrinsics": torch.from_numpy(np.stack(sample["intrinsics"])),
162 | "extrinsics": torch.from_numpy(np.stack(sample["extrinsics"])),
163 | "depth_min": sample["ref_depth_min"],
164 | "depth_max": sample["ref_depth_max"],
165 | "depth_values": torch.from_numpy(sample["ref_depth_values"]),
166 | "filename": sample["filename"],
167 | }
168 |
169 | # add hints if requested
170 | if self.generate_hints != "not_guided":
171 | out["hints"] = self._generate_hints(sample)
172 |
173 | # add fields pcd if available
174 | for key in [
175 | "scan_pcd",
176 | "scan_pcd_obs_mask",
177 | "scan_pcd_bounding_box",
178 | "scan_pcd_resolution",
179 | "scan_pcd_ground_plane",
180 | ]:
181 | if key in sample and sample[key] is not None:
182 | out[key] = torch.from_numpy(sample[key])
183 |
184 | return out
185 |
186 |
187 | def _get_all_points(lidar: np.ndarray, intrinsic: np.ndarray) -> Tuple:
188 |
189 | lidar_32 = np.squeeze(lidar).astype(np.float32)
190 | height, width = np.shape(lidar_32)
191 | x_axis = np.arange(width).reshape(width, 1)
192 | x_image = np.tile(x_axis, height)
193 | x_image = np.transpose(x_image)
194 | y_axis = np.arange(height).reshape(height, 1)
195 | y_image = np.tile(y_axis, width)
196 | z_image = np.ones((height, width))
197 | image_coor_tensor = (
198 | np.asarray([x_image, y_image, z_image]).astype(np.float32).transpose([1, 0, 2])
199 | )
200 |
201 | intrinsic = np.reshape(intrinsic, [3, 3]).astype(np.float32)
202 | intrinsic_inverse = np.linalg.inv(intrinsic)
203 | points_homo = np.matmul(intrinsic_inverse, image_coor_tensor)
204 |
205 | lidar_32 = np.reshape(lidar_32, [height, 1, width])
206 | points_homo = points_homo * lidar_32
207 | extra_image = np.ones((height, width)).astype(np.float32)
208 | extra_image = np.reshape(extra_image, [height, 1, width])
209 | points_homo = np.concatenate([points_homo, extra_image], axis=1)
210 |
211 | extrinsic_v_2_c = [
212 | [0.007, -1, 0, 0],
213 | [0.0148, 0, -1, -0.076],
214 | [1, 0, 0.0148, -0.271],
215 | [0, 0, 0, 1],
216 | ]
217 | extrinsic_v_2_c = np.reshape(extrinsic_v_2_c, [4, 4]).astype(np.float32)
218 | extrinsic_c_2_v = np.linalg.inv(extrinsic_v_2_c)
219 | points_lidar = np.matmul(extrinsic_c_2_v, points_homo)
220 |
221 | mask = np.squeeze(lidar) > 0.1
222 | total_points = [
223 | points_lidar[:, 0, :][mask],
224 | points_lidar[:, 1, :][mask],
225 | points_lidar[:, 2, :][mask],
226 | ]
227 | total_points = np.asarray(total_points)
228 | total_points = np.transpose(total_points)
229 |
230 | return total_points, x_image[mask], y_image[mask], x_image, y_image
231 |
232 |
233 | def _do_range_projection_try(
234 | points: np.ndarray,
235 | fov_up: float = 3.0,
236 | fov_down: float = -18.0,
237 | ) -> Tuple[np.ndarray, np.ndarray]:
238 | # for each point, where it is in the range image
239 | proj_x = np.zeros((0, 1), dtype=np.float32) # [m, 1]: x
240 | proj_y = np.zeros((0, 1), dtype=np.float32) # [m, 1]: y
241 |
242 | # laser parameters
243 | fov_up = fov_up / 180.0 * np.pi # field of view up in rad
244 | fov_down = fov_down / 180.0 * np.pi # field of view down in rad
245 | fov = abs(fov_down) + abs(fov_up) # get field of view total in rad
246 |
247 | # get depth of all points
248 | depth = np.linalg.norm(points, 2, axis=1)
249 |
250 | # get scan components
251 | scan_x = points[:, 0]
252 | scan_y = points[:, 1]
253 | scan_z = points[:, 2]
254 |
255 | # get angles of all points
256 | yaw = -np.arctan2(scan_y, scan_x)
257 | pitch = np.arcsin(scan_z / depth)
258 |
259 | # get projections in image coords
260 | proj_x = 0.5 * (yaw / np.pi + 1.0) # in [0.0, 1.0]
261 |
262 | proj_y = 1.0 - (pitch + abs(fov_down)) / fov # in [0.0, 1.0]
263 | return proj_x, proj_y
264 |
265 |
266 | def _compute_trunck(v: np.ndarray, height_bin: int = 4, width_bin: int = 4) -> np.ndarray:
267 | v = np.squeeze(v)
268 | v_trunck = np.lib.stride_tricks.sliding_window_view(
269 | np.pad(v, [(height_bin, height_bin), (width_bin, width_bin)]),
270 | (height_bin * 2 + 1, width_bin * 2 + 1),
271 |     ).reshape(v.shape[0], v.shape[1], (height_bin * 2 + 1) * (width_bin * 2 + 1))
272 | return v_trunck
273 |
274 |
275 | def _compute_residual(v, height_bin, width_bin):
276 | v_trunck = _compute_trunck(v, height_bin, width_bin)
277 | residual = np.squeeze(v)[..., None] - v_trunck
278 | return residual
279 |
280 |
281 | def outlier_removal_mask(
282 | lidar: np.ndarray, intrinsic: np.ndarray, height_bin: int = 4, width_bin: int = 4
283 | ) -> np.ndarray:
284 |
285 | height, width, _ = lidar.shape
286 |
287 | total_points, x_indices, y_indices, width_image, height_image = _get_all_points(
288 | lidar, intrinsic
289 | )
290 | proj_x, proj_y = _do_range_projection_try(total_points)
291 |
292 | project_x, project_y = np.zeros((2, height, width, 1))
293 | project_x[y_indices, x_indices, 0] = proj_x
294 | project_y[y_indices, x_indices, 0] = proj_y
295 |
296 | project_x_trunck = _compute_trunck(project_x, height_bin, width_bin)
297 | project_x_residual = project_x - project_x_trunck
298 |
299 | project_y_trunck = _compute_trunck(project_y, height_bin, width_bin)
300 | project_y_residual = project_y - project_y_trunck
301 |
302 | height_image_trunck = _compute_trunck(height_image, height_bin, width_bin)
303 | height_image_residual = height_image[..., None] - height_image_trunck
304 |
305 | width_image_trunck = _compute_trunck(width_image, height_bin, width_bin)
306 | width_image_residual = width_image[..., None] - width_image_trunck
307 |
308 | lidar_trunck = _compute_trunck(lidar, height_bin, width_bin)
309 | zero_mask = np.logical_and(lidar > 0.1, lidar_trunck > 0.1)
310 |
311 | x_mask = np.logical_and(
312 | np.logical_or(
313 | np.logical_and(project_x_residual > 0.0000, width_image_residual <= 0),
314 | np.logical_and(project_x_residual < 0.0000, width_image_residual >= 0),
315 | ),
316 | zero_mask,
317 | )
318 |
319 | y_mask = np.logical_and(
320 | np.logical_or(
321 | np.logical_and(project_y_residual > 0, height_image_residual <= 0),
322 | np.logical_and(project_y_residual < 0, height_image_residual >= 0),
323 | ),
324 | zero_mask,
325 | )
326 |
327 | lidar_residual = lidar - lidar_trunck
328 | lidar_mask = np.logical_and(lidar_residual > 3.0, lidar > 0.01)
329 |
330 | final_mask = np.logical_and(lidar_mask, np.logical_or(x_mask, y_mask))
331 | final_mask = np.squeeze(final_mask)
332 | final_mask = np.sum(final_mask, axis=-1, keepdims=True) == 0
333 |
334 | return final_mask
335 |
336 |
337 | def hints_homo_warping(src_hints, src_validhints, src_proj, ref_proj):
338 | # src_fea: [B, 1, H, W]
339 | # src_proj: [B, 4, 4]
340 | # ref_proj: [B, 4, 4]
341 | # depth_values: [B, Ndepth]
342 | # out: [B, C, Ndepth, H, W]
343 | batch = src_hints.shape[0]
344 | height, width = src_hints.shape[2], src_hints.shape[3]
345 | warped_hints = torch.zeros_like(src_hints)
346 | warped_validhints = torch.zeros_like(src_validhints)
347 |
348 | with torch.no_grad():
349 | proj = torch.matmul(src_proj, torch.inverse(ref_proj))
350 | rot = proj[:, :3, :3] # [B,3,3]
351 | trans = proj[:, :3, 3:4] # [B,3,1]
352 |
353 | y, x = torch.meshgrid(
354 | [
355 | torch.arange(0, height, dtype=torch.float32, device=src_hints.device),
356 | torch.arange(0, width, dtype=torch.float32, device=src_hints.device),
357 | ]
358 | )
359 | y, x = y.contiguous(), x.contiguous()
360 | y, x = y.view(height * width), x.view(height * width)
361 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W]
362 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W]
363 | rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W]
364 | rot_depth_xyz = rot_xyz.unsqueeze(2) * src_hints.view(
365 | batch, 1, 1, -1
366 | ) # [B, 3, Ndepth, H*W]
367 | proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W]
368 |
369 | proj_x = proj_xyz[:, 0, :, :].view(batch, height, width)
370 | proj_y = proj_xyz[:, 1, :, :].view(batch, height, width)
371 | proj_z = proj_xyz[:, 2, :, :].view(batch, height, width)
372 |
373 |         proj_x = torch.clamp(torch.round(proj_x / proj_z).to(torch.int64), min=0, max=width - 1)
374 |         proj_y = torch.clamp(torch.round(proj_y / proj_z).to(torch.int64), min=0, max=height - 1)
375 | proj_z = proj_z.unsqueeze(-1)
376 |
377 | warped_hints = warped_hints.squeeze(1).unsqueeze(-1)
378 | warped_validhints = warped_validhints.squeeze(1).unsqueeze(-1)
379 | src_validhints = src_validhints.squeeze(1).unsqueeze(-1)
380 |
381 | # forward warping (will it work?)
382 | for i in range(warped_hints.shape[0]):
383 | warped_hints[i][proj_y, proj_x] = -proj_z[i]
384 | warped_validhints[i][proj_y, proj_x] = -src_validhints[i]
385 | warped_hints *= -1
386 | warped_validhints *= -1 * (warped_hints > 0)
387 |
388 | return warped_hints.permute(0, 3, 1, 2), warped_validhints.permute(
389 | 0, 3, 1, 2
390 | ) # , warped_validhints
391 |
--------------------------------------------------------------------------------
/guided_mvs_lib/datasets/utils.py:
--------------------------------------------------------------------------------
1 | import re
2 | import sys
3 | from pathlib import Path
4 | from typing import Optional, Tuple, TypedDict
5 |
6 | import numpy as np
7 |
8 | __all__ = ["DatasetExamplePaths", "read_pfm", "save_pfm"]
9 |
10 |
11 | class DatasetExamplePaths(TypedDict):
12 | img: Path
13 | proj_mat: Path
14 | depth_mask: Path
15 | depth: Path
16 | pcd: Optional[Path]
17 | obs_mask: Optional[Path]
18 | ground_plane: Optional[Path]
19 |
20 |
21 | def read_pfm(filename: str) -> Tuple[np.ndarray, float]:
22 | """Read a depth map from a .pfm file
23 |
24 | Args:
25 | filename: .pfm file path string
26 |
27 | Returns:
28 | data: array of shape (H, W, C) representing loaded depth map
29 | scale: float to recover actual depth map pixel values
30 | """
31 | file = open(filename, "rb") # treat as binary and read-only
32 | color = None
33 | width = None
34 | height = None
35 | scale = None
36 | endian = None
37 |
38 | header = file.readline().decode("utf-8").rstrip()
39 | if header == "PF":
40 | color = True
41 | elif header == "Pf": # depth is Pf
42 | color = False
43 | else:
44 | raise Exception("Not a PFM file.")
45 |
46 | dim_match = re.match(r"^(\d+)\s(\d+)\s$", file.readline().decode("utf-8"))
47 | if dim_match:
48 | width, height = map(int, dim_match.groups())
49 | else:
50 | raise Exception("Malformed PFM header.")
51 |
52 | scale = float(file.readline().rstrip())
53 | if scale < 0: # little-endian
54 | endian = "<"
55 | scale = -scale
56 | else:
57 | endian = ">" # big-endian
58 |
59 | data = np.fromfile(file, endian + "f")
60 | shape = (height, width, 3) if color else (height, width, 1)
61 |
62 | data = np.reshape(data, shape)
63 | data = np.flipud(data)
64 | file.close()
65 | return data, scale
66 |
67 |
68 | def save_pfm(filename: str, image: np.ndarray, scale: float = 1) -> None:
69 |     """Save a depth map to a .pfm file
70 |
71 | Args:
72 | filename: output .pfm file path string,
73 | image: depth map to save, of shape (H,W) or (H,W,C)
74 | scale: scale parameter to save
75 | """
76 | file = open(filename, "wb")
77 | color = None
78 |
79 | image = np.flipud(image)
80 |
81 | if image.dtype.name != "float32":
82 | raise Exception("Image dtype must be float32.")
83 |
84 | if len(image.shape) == 3 and image.shape[2] == 3: # color image
85 | color = True
86 | elif len(image.shape) == 2 or len(image.shape) == 3 and image.shape[2] == 1: # greyscale
87 | color = False
88 | else:
89 | raise Exception("Image must have H x W x 3, H x W x 1 or H x W dimensions.")
90 |
91 | file.write("PF\n".encode("utf-8") if color else "Pf\n".encode("utf-8"))
92 | file.write("{} {}\n".format(image.shape[1], image.shape[0]).encode("utf-8"))
93 |
94 | endian = image.dtype.byteorder
95 |
96 | if endian == "<" or endian == "=" and sys.byteorder == "little":
97 | scale = -scale
98 |
99 | file.write(("%f\n" % scale).encode("utf-8"))
100 |
101 | image.tofile(file)
102 | file.close()
103 |
--------------------------------------------------------------------------------
/guided_mvs_lib/models/__init__.py:
--------------------------------------------------------------------------------
1 | from types import SimpleNamespace
2 | from typing import Dict, Optional, Union
3 |
4 | import matplotlib.pyplot as plt
5 | import torch
6 | from pytorch_lightning import LightningModule
7 | from torch import optim
8 | from torch.optim.lr_scheduler import MultiStepLR
9 |
10 | from . import cas_mvsnet, d2hc_rmvsnet, mvsnet, patchmatchnet, ucsnet
11 |
12 | __all__ = ["MVSModel", "build_network"]
13 |
14 | _NETWORKS = {
15 | "mvsnet": mvsnet.NetBuilder,
16 | "cas_mvsnet": cas_mvsnet.NetBuilder,
17 | "ucsnet": ucsnet.NetBuilder,
18 | "d2hc_rmvsnet": d2hc_rmvsnet.NetBuilder,
19 | "patchmatchnet": patchmatchnet.NetBuilder,
20 | }
21 |
22 |
23 | def build_network(name: str, args: SimpleNamespace):
24 | try:
25 | return _NETWORKS[name](args)
26 | except KeyError:
27 | raise ValueError("network name in {}".format(", ".join(_NETWORKS.keys())))
28 |
29 |
30 | class MVSModel(LightningModule):
31 | def __init__(
32 | self,
33 | *,
34 | args: SimpleNamespace,
35 | mlflow_run_id: Optional[str] = None,
36 | v_num: Optional[str] = None, # (experiment version to show in tqdm)
37 | ):
38 | super().__init__()
39 |
40 | # instance the used model
41 | self.model = build_network(args.model, args)
42 | self.mlflow_run_id = mlflow_run_id
43 | self.loss_fn = self.model.loss
44 |
45 | # save train parameters
46 | hparams = dict(
47 | model=args.model,
48 | **{name: value for name, value in args.train.__dict__.items()},
49 | )
50 | if hasattr(self.model, "hparams"):
51 | hparams = dict(**hparams, **self.model.hparams)
52 |
53 | self.save_hyperparameters(hparams)
54 |
55 | # utils
56 | self._is_val = False
57 | self.v_num = v_num
58 |
59 | def get_progress_bar_dict(self) -> Dict[str, Union[int, str]]:
60 | progress_bar_dict = super().get_progress_bar_dict()
61 | progress_bar_dict["v_num"] = self.v_num
62 | return progress_bar_dict
63 |
64 | def forward(self, batch: dict):
65 | return self.model(batch)
66 |
67 | def configure_optimizers(self):
68 | optimizer = optim.Adam(
69 | self.model.parameters(),
70 | lr=self.hparams.lr,
71 | betas=(0.9, 0.999),
72 | weight_decay=self.hparams.weight_decay,
73 | )
74 |
75 | if self.hparams.epochs_lr_decay is not None:
76 | scheduler = MultiStepLR(
77 | optimizer,
78 | self.hparams.epochs_lr_decay,
79 | gamma=self.hparams.epochs_lr_gamma,
80 | )
81 | return [optimizer], [scheduler]
82 | else:
83 | return optimizer
84 |
85 | def _compute_masks(self, depth_dict: Dict) -> Dict:
86 | return {k: (v > 0).to(torch.float32) for k, v in depth_dict.items()}
87 |
88 | def training_step(self, batch, _):
89 | self._is_val = False
90 |
91 | outputs = self.model(batch)
92 | loss = self.loss_fn(
93 | outputs["loss_data"], batch["depth"], self._compute_masks(batch["depth"])
94 | )
95 | self._log_metrics(batch, outputs, loss, "train")
96 |
97 | return loss
98 |
99 | def validation_step(self, batch, _):
100 | outputs = self.model(batch)
101 | loss = self.loss_fn(
102 | outputs["loss_data"], batch["depth"], self._compute_masks(batch["depth"])
103 | )
104 | self._log_metrics(batch, outputs, loss, "val")
105 |
106 | return loss
107 |
108 | def _log_metrics(self, batch, outputs, loss, stage):
109 | on_epoch = stage != "train"
110 |
111 | depth_gt = batch["depth"]["stage_0"]
112 | mask = (depth_gt > 0).to(torch.float32)
113 | depth_est = outputs["depth"]["stage_0"]
114 |
115 | # log scalar metrics
116 | self.log(f"{stage}/loss", loss, on_epoch=on_epoch, on_step=not on_epoch)
117 | self.log(
118 | f"{stage}/mae",
119 | torch.mean(torch.abs(depth_est[mask > 0.5] - depth_gt[mask > 0.5])),
120 | on_epoch=on_epoch,
121 | on_step=not on_epoch,
122 | )
123 | for thresh in [1, 2, 3, 4, 8]:
124 | self.log(
125 | f"{stage}/perc_l1_upper_thresh_{thresh}",
126 | torch.mean(
127 | (torch.abs(depth_est[mask > 0.5] - depth_gt[mask > 0.5]) > thresh).float()
128 | ),
129 | on_epoch=on_epoch,
130 | on_step=not on_epoch,
131 | )
132 |
133 |         # log images only on the first validation batch and once every 5000 training steps
134 | if (stage == "train" and self.global_step % 5000 == 0) or (
135 | stage == "val" and not self._is_val
136 | ):
137 | if stage == "val":
138 | self._is_val = True
139 |
140 |             if self.mlflow_run_id is not None:
141 | for name, map_ in [("depth_pred", depth_est), ("depth_gt", depth_gt)]:
142 |
143 | fig = plt.figure(tight_layout=True)
144 | ax = plt.Axes(fig, [0.0, 0.0, 1.0, 1.0])
145 | ax.set_axis_off()
146 | fig.add_axes(ax)
147 | ax.imshow(
148 | (map_[0] * mask[0]).permute(1, 2, 0).detach().cpu().numpy(), cmap="magma_r"
149 | )
150 | self.logger.experiment.log_figure(
151 | self.mlflow_run_id, fig, f"{stage}/{name}-{self.global_step}.jpg"
152 | )
153 | plt.close(fig)
154 |
155 | self.logger.experiment.log_image(
156 | self.mlflow_run_id,
157 | batch["imgs"]["stage_0"][0, 0].permute(1, 2, 0).cpu().numpy(),
158 | f"{stage}/image-{self.global_step}.jpg",
159 | )
160 |
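161 | # Illustrative sketch, not part of the original module: build_network looks up one of the
162 | # registered architectures by name and instantiates its NetBuilder; the empty SimpleNamespace
163 | # below is only a placeholder for the experiment configuration object.
164 | #
165 | #     net = build_network("mvsnet", SimpleNamespace())
166 | #     out = net({"imgs": ..., "proj_matrices": ..., "depth_values": ...})  # plus optional "hints"
167 | #
168 | # An unknown name raises ValueError listing the supported keys.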
--------------------------------------------------------------------------------
/guided_mvs_lib/models/cas_mvsnet/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from types import SimpleNamespace
3 | from typing import Any, Optional
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch import Tensor
8 |
9 | from .cas_mvsnet import CascadeMVSNet, cas_mvsnet_loss
10 |
11 | __all__ = ["cas_mvsnet_loss", "CascadeMVSNet", "NetBuilder", "SimpleIntefaceNet"]
12 |
13 |
14 | class SimpleInterfaceNet(nn.Module):
15 | """
16 | Simple common interface to call the pretrained models
17 | """
18 |
19 | def __init__(self, *args, **kwargs):
20 | super().__init__()
21 | self.model = CascadeMVSNet(*args, **kwargs)
22 | self.all_outputs: dict[str, Any] = {}
23 |
24 | def forward(
25 | self,
26 | imgs: Tensor,
27 | intrinsics: Tensor,
28 | extrinsics: Tensor,
29 | depth_values: Tensor,
30 | hints: Optional[Tensor] = None,
31 | ):
32 | with warnings.catch_warnings():
33 | warnings.simplefilter("ignore", UserWarning)
34 |
35 | # compute poses
36 | proj_matrices = {}
37 | for i in range(4):
38 | proj_mat = extrinsics.clone()
39 | intrinsics_copy = intrinsics.clone()
40 | intrinsics_copy[..., :2, :] = intrinsics_copy[..., :2, :] / (2 ** i)
41 | proj_mat[..., :3, :4] = intrinsics_copy @ proj_mat[..., :3, :4]
42 | proj_matrices[f"stage_{i}"] = proj_mat
43 |
44 | # validhints
45 | validhints = None
46 | if hints is not None:
47 | validhints = (hints > 0).to(torch.float32)
48 |
49 | # call
50 | out = self.model(imgs, proj_matrices, depth_values, hints, validhints)
51 | self.all_outputs = out
52 | return out["depth"]["stage_0"]
53 |
54 |
55 | class NetBuilder(nn.Module):
56 | def __init__(self, args: SimpleNamespace):
57 | super().__init__()
58 | self.model = CascadeMVSNet()
59 |
60 | def loss_func(loss_data, depth_gt, mask):
61 | return cas_mvsnet_loss(
62 | loss_data["depth"],
63 | depth_gt,
64 | mask,
65 | )
66 |
67 | self.loss = loss_func
68 |
69 | def forward(self, batch: dict):
70 |
71 | hints, validhints = None, None
72 | if "hints" in batch:
73 | hints = batch["hints"]
74 | validhints = (hints > 0).to(torch.float32)
75 |
76 | return self.model(
77 | batch["imgs"]["stage_0"],
78 | batch["proj_matrices"],
79 | batch["depth_values"],
80 | hints,
81 | validhints,
82 | )
83 |
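84 | # Illustrative call sketch, not part of the original module (shapes are assumptions inferred
85 | # from the forward pass above, with B batches, N views and D depth hypotheses):
86 | #     net = SimpleInterfaceNet()
87 | #     imgs:         (B, N, 3, H, W)  reference view first, then the source views
88 | #     intrinsics:   (B, N, 3, 3)     extrinsics: (B, N, 4, 4)    depth_values: (B, D)
89 | #     hints:        (B, 1, H, W)     sparse depth, zero where no hint is available
90 | #     depth = net(imgs, intrinsics, extrinsics, depth_values, hints)   # -> (B, 1, H, W)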
--------------------------------------------------------------------------------
/guided_mvs_lib/models/cas_mvsnet/cas_mvsnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .module import *
6 |
7 | Align_Corners_Range = False
8 |
9 |
10 | class DepthNet(nn.Module):
11 | def __init__(self):
12 | super(DepthNet, self).__init__()
13 |
14 | def forward(
15 | self,
16 | features,
17 | proj_matrices,
18 | depth_values,
19 | num_depth,
20 | cost_regularization,
21 | prob_volume_init=None,
22 | hints=None,
23 | validhints=None,
24 | scale=1,
25 | ):
26 | proj_matrices = torch.unbind(proj_matrices, 1)
27 | assert len(features) == len(
28 | proj_matrices
29 | ), "Different number of images and projection matrices"
30 | assert depth_values.shape[1] == num_depth, "depth_values.shape[1]:{} num_depth:{}".format(
31 |             depth_values.shape[1], num_depth
32 | )
33 | num_views = len(features)
34 |
35 | # step 1. feature extraction
36 | # in: images; out: 32-channel feature maps
37 | ref_feature, src_features = features[0], features[1:]
38 | ref_proj, src_projs = proj_matrices[0], proj_matrices[1:]
39 |
40 | # step 2. differentiable homograph, build cost volume
41 | ref_volume = ref_feature.unsqueeze(2).repeat(1, 1, num_depth, 1, 1)
42 | volume_sum = ref_volume
43 | volume_sq_sum = ref_volume ** 2
44 | del ref_volume
45 | for src_fea, src_proj in zip(src_features, src_projs):
46 | warped_volume = homo_warping(src_fea, src_proj, ref_proj, depth_values)
47 |
48 | if self.training:
49 | volume_sum = volume_sum + warped_volume
50 | volume_sq_sum = volume_sq_sum + warped_volume ** 2
51 | else:
52 |                 # TODO: this is only a temporary solution to save memory, better way?
53 | volume_sum += warped_volume
54 | volume_sq_sum += warped_volume.pow_(
55 | 2
56 | ) # the memory of warped_volume has been modified
57 | del warped_volume
58 |
59 | # aggregate multiple feature volumes by variance
60 | volume_variance = volume_sq_sum.div_(num_views).sub_(volume_sum.div_(num_views).pow_(2))
61 |
62 | # TODO cost-volume modulation
63 | if hints is not None and validhints is not None:
64 | batch_size, feats, height, width = ref_feature.shape
65 | GAUSSIAN_HEIGHT = 10.0
66 | GAUSSIAN_WIDTH = 1.0
67 |
68 |             # the features of this stage are 1/scale of the input resolution: resample the
69 |             # hints and their validity mask to that resolution with nearest-neighbour
70 |             # interpolation, then zero out hint values that are not valid
71 |             hints = F.interpolate(hints, scale_factor=1 / scale, mode="nearest").unsqueeze(1)
72 |             validhints = F.interpolate(
73 |                 validhints, scale_factor=1 / scale, mode="nearest"
74 |             ).unsqueeze(1)
75 |             hints = hints * validhints
76 |
77 | # add feature and disparity dimensions to hints and validhints
78 | # and repeat their values along those dimensions, to obtain the same size as cost
79 | hints = hints.expand(-1, feats, num_depth, -1, -1)
80 | validhints = validhints.expand(-1, feats, num_depth, -1, -1)
81 |
82 | # create a tensor of the same size as cost, with disparities
83 | # between 0 and num_disp-1 along the disparity dimension
84 | depth_hyps = (
85 | depth_values.unsqueeze(1).expand(batch_size, feats, -1, height, width).detach()
86 | )
87 | volume_variance = volume_variance * (
88 | (1 - validhints)
89 | + validhints
90 | * GAUSSIAN_HEIGHT
91 | * (1 - torch.exp(-((depth_hyps - hints) ** 2) / (2 * GAUSSIAN_WIDTH ** 2)))
92 | )
93 |
94 | # step 3. cost volume regularization
95 | cost_reg = cost_regularization(volume_variance)
96 | # cost_reg = F.upsample(cost_reg, [num_depth * 4, img_height, img_width], mode='trilinear')
97 | prob_volume_pre = cost_reg.squeeze(1)
98 |
99 | if prob_volume_init is not None:
100 | prob_volume_pre += prob_volume_init
101 |
102 | prob_volume = F.softmax(prob_volume_pre, dim=1)
103 | depth = depth_regression(prob_volume, depth_values=depth_values)
104 |
105 | with torch.no_grad():
106 | # photometric confidence
107 | prob_volume_sum4 = (
108 | 4
109 | * F.avg_pool3d(
110 | F.pad(prob_volume.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)),
111 | (4, 1, 1),
112 | stride=1,
113 | padding=0,
114 | ).squeeze(1)
115 | )
116 | depth_index = depth_regression(
117 | prob_volume,
118 | depth_values=torch.arange(num_depth, device=prob_volume.device, dtype=torch.float),
119 | ).long()
120 | depth_index = depth_index.clamp(min=0, max=num_depth - 1)
121 | photometric_confidence = torch.gather(
122 | prob_volume_sum4, 1, depth_index.unsqueeze(1)
123 | ).squeeze(1)
124 |
125 | return {"depth": depth, "photometric_confidence": photometric_confidence}
126 |
127 |
128 | class CascadeMVSNet(nn.Module):
129 | def __init__(
130 | self,
131 | refine=False,
132 | ndepths=[48, 32, 8],
133 | depth_interals_ratio=[4, 2, 1],
134 | share_cr=False,
135 | grad_method="detach",
136 | arch_mode="fpn",
137 | cr_base_chs=[8, 8, 8],
138 | ):
139 | super(CascadeMVSNet, self).__init__()
140 | self.refine = refine
141 | self.share_cr = share_cr
142 | self.ndepths = ndepths
143 | self.depth_interals_ratio = depth_interals_ratio
144 | self.grad_method = grad_method
145 | self.arch_mode = arch_mode
146 | self.cr_base_chs = cr_base_chs
147 | self.num_stage = len(ndepths)
148 | assert len(ndepths) == len(depth_interals_ratio)
149 |
150 | self.stage_infos = {
151 | "stage_2": { # was "stage1"
152 | "scale": 4.0,
153 | },
154 | "stage_1": { # was "stage2"
155 | "scale": 2.0,
156 | },
157 | "stage_0": { # was "stage3"
158 | "scale": 1.0,
159 | },
160 | }
161 |
162 | self.feature = FeatureNet(
163 | base_channels=8, stride=4, num_stage=self.num_stage, arch_mode=self.arch_mode
164 | )
165 | if self.share_cr:
166 | self.cost_regularization = CostRegNet(
167 | in_channels=self.feature.out_channels, base_channels=8
168 | )
169 | else:
170 | self.cost_regularization = nn.ModuleList(
171 | [
172 | CostRegNet(
173 | in_channels=self.feature.out_channels[i], base_channels=self.cr_base_chs[i]
174 | )
175 | for i in [2, 1, 0]
176 | ]
177 | ) # range(self.num_stage)])
178 | if self.refine:
179 | self.refine_network = RefineNet()
180 | self.DepthNet = DepthNet()
181 |
182 | def forward(self, imgs, proj_matrices, depth_values, hints=None, validhints=None):
183 | depth_min = float(depth_values[0, 0].cpu().numpy())
184 | depth_max = float(depth_values[0, -1].cpu().numpy())
185 | depth_interval = (depth_max - depth_min) / depth_values.size(1)
186 |
187 | # step 1. feature extraction
188 | features = []
189 | for nview_idx in range(imgs.size(1)): # imgs shape (B, N, C, H, W)
190 | img = imgs[:, nview_idx]
191 | features.append(self.feature(img))
192 |
193 | outputs = {}
194 | depth, cur_depth = None, None
195 | for stage_idx in [2, 1, 0]: # range(self.num_stage):
196 | # stage feature, proj_mats, scales
197 | features_stage = [feat["stage_{}".format(stage_idx)] for feat in features]
198 | proj_matrices_stage = proj_matrices["stage_{}".format(stage_idx)] # + 1)]
199 | stage_scale = self.stage_infos["stage_{}".format(stage_idx)]["scale"]
200 |
201 | if depth is not None:
202 | if self.grad_method == "detach":
203 | cur_depth = depth.detach()
204 | else:
205 | cur_depth = depth
206 | cur_depth = F.interpolate(
207 | cur_depth.unsqueeze(1),
208 | [img.shape[2], img.shape[3]],
209 | mode="bilinear",
210 | align_corners=Align_Corners_Range,
211 | ).squeeze(1)
212 | else:
213 | cur_depth = depth_values
214 | depth_range_samples = get_depth_range_samples(
215 | cur_depth=cur_depth,
216 | ndepth=self.ndepths[2 - stage_idx],
217 | depth_inteval_pixel=self.depth_interals_ratio[2 - stage_idx] * depth_interval,
218 | dtype=img[0].dtype,
219 | device=img[0].device,
220 | shape=[img.shape[0], img.shape[2], img.shape[3]],
221 | max_depth=depth_max,
222 | min_depth=depth_min,
223 | )
224 |
225 | outputs_stage = self.DepthNet(
226 | features_stage,
227 | proj_matrices_stage,
228 | depth_values=F.interpolate(
229 | depth_range_samples.unsqueeze(1),
230 | [
231 | self.ndepths[2 - stage_idx],
232 | img.shape[2] // int(stage_scale),
233 | img.shape[3] // int(stage_scale),
234 | ],
235 | mode="trilinear",
236 | align_corners=Align_Corners_Range,
237 | ).squeeze(1),
238 | num_depth=self.ndepths[2 - stage_idx],
239 | cost_regularization=self.cost_regularization
240 | if self.share_cr
241 | else self.cost_regularization[stage_idx],
242 | hints=hints,
243 | validhints=validhints,
244 | scale=stage_scale,
245 | )
246 |
247 | depth = outputs_stage["depth"] # .unsqueeze(1)
248 |
249 | outputs["stage_{}".format(stage_idx)] = depth.unsqueeze(1)
250 |
251 | # depth map refinement
252 | if self.refine:
253 | refined_depth = self.refine_network(torch.cat((imgs[:, 0], depth), 1))
254 | outputs["refined_depth"] = refined_depth
255 |
256 | return {
257 | "depth": outputs["refined_depth"] if self.refine else outputs,
258 | "loss_data": {
259 | "depth": outputs,
260 | },
261 | "photometric_confidence": outputs_stage["photometric_confidence"],
262 | }
263 |
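264 | # Note on the guided modulation above (illustrative, not part of the original module): where a
265 | # hint is valid, the variance volume is multiplied by
266 | #     GAUSSIAN_HEIGHT * (1 - exp(-(d - hint)^2 / (2 * GAUSSIAN_WIDTH^2)))
267 | # so depth hypotheses d close to the hinted depth keep a near-zero matching cost while distant
268 | # ones are amplified; voxels without a valid hint pass through unchanged via (1 - validhints).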
--------------------------------------------------------------------------------
/guided_mvs_lib/models/d2hc_rmvsnet/__init__.py:
--------------------------------------------------------------------------------
1 | import logging
2 | import warnings
3 | from types import SimpleNamespace
4 | from typing import Any, Optional
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 | from torch import Tensor
10 |
11 | from .drmvsnet import D2HCRMVSNet
12 | from .vamvsnet import mvsnet_loss
13 |
14 | __all__ = ["mvsnet_loss", "D2HCRMVSNet", "NetBuilder", "SimpleInterfaceNet"]
15 |
16 | _logger = logging.getLogger(__name__)
17 |
18 |
19 | _DATASETS_SIZE = {
20 | "dtu_yao": (512, 640),
21 | "blended_mvs": (576, 768),
22 | "blended_mvg": (576, 768),
23 | }
24 |
25 |
26 | class SimpleInterfaceNet(nn.Module):
27 | """
28 | Simple common interface to call the pretrained models
29 | """
30 |
31 | def __init__(self, *args, **kwargs):
32 | super().__init__()
33 | self.model = D2HCRMVSNet(*args, **kwargs)
34 | self.all_outputs: dict[str, Any] = {}
35 |
36 | def forward(
37 | self,
38 | imgs: Tensor,
39 | intrinsics: Tensor,
40 | extrinsics: Tensor,
41 | depth_values: Tensor,
42 | hints: Optional[Tensor] = None,
43 | ):
44 | with warnings.catch_warnings():
45 | warnings.simplefilter("ignore", UserWarning)
46 |
47 | # compute poses
48 | proj_mat = extrinsics.clone()
49 | intrinsics_copy = intrinsics.clone()
50 | intrinsics_copy[..., :2, :] = intrinsics_copy[..., :2, :] / 4
51 | proj_mat[..., :3, :4] = intrinsics_copy @ proj_mat[..., :3, :4]
52 |
53 | # image resize
54 | imgs = torch.stack(
55 | [
56 | F.interpolate(img, scale_factor=0.25, mode="bilinear", align_corners=True)
57 | for img in torch.unbind(imgs, 2)
58 | ],
59 | 2,
60 | )
61 |
62 | # validhints
63 | validhints = None
64 | if hints is not None:
65 | validhints = (hints > 0).to(torch.float32)
66 |
67 | # call
68 | out = self.model(imgs, proj_mat, depth_values, hints, validhints)
69 | self.all_outputs = out
70 | return out["depth"]["stage_0"]
71 |
72 |
73 | class NetBuilder(nn.Module):
74 | def __init__(self, args: SimpleNamespace):
75 | super().__init__()
76 |
77 | self.model = D2HCRMVSNet()
78 | self.loss = mvsnet_loss
79 |
80 | def forward(self, batch: dict):
81 |
82 | hints, validhints = None, None
83 | if "hints" in batch:
84 | hints = batch["hints"]
85 | validhints = (hints > 0).to(torch.float32)
86 |
87 | out = self.model(
88 | batch["imgs"]["stage_2"],
89 | batch["proj_matrices"]["stage_2"],
90 | batch["depth_values"],
91 | hints,
92 | validhints,
93 | )
94 | return out
95 |
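96 | # Note (illustrative, not part of the original module): this wrapper downsamples the input
97 | # images by 4 and rescales the intrinsics to match before calling D2HCRMVSNet, so it can be
98 | # called with full-resolution inputs, e.g.
99 | #     depth = SimpleInterfaceNet()(imgs, intrinsics, extrinsics, depth_values, hints)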
--------------------------------------------------------------------------------
/guided_mvs_lib/models/d2hc_rmvsnet/convlstm.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | from torch.autograd import Variable
4 |
5 | from .module import *
6 |
7 |
8 | class ConvLSTMCell(nn.Module):
9 | def __init__(self, input_size, scale, input_dim, hidden_dim, kernel_size, bias=True):
10 | """
11 | Initialize ConvLSTM cell.
12 |
13 | Parameters
14 | ----------
15 | input_size: (int, int)
16 | Height and width of input tensor as (height, width).
17 | input_dim: int
18 | Number of channels of input tensor.
19 | hidden_dim: int
20 | Number of channels of hidden state.
21 | kernel_size: (int, int)
22 | Size of the convolutional kernel.
23 | bias: bool
24 | Whether or not to add the bias.
25 | """
26 |
27 | super(ConvLSTMCell, self).__init__()
28 |
29 | self.height, self.width = input_size
30 | self.input_dim = input_dim
31 | self.hidden_dim = hidden_dim
32 | self.scale = scale
33 |
34 | self.kernel_size = kernel_size
35 | self.padding = kernel_size[0] // 2, kernel_size[1] // 2
36 | self.bias = bias
37 |
38 | self.conv = nn.Conv2d(
39 | in_channels=self.input_dim + self.hidden_dim,
40 | out_channels=4 * self.hidden_dim,
41 | kernel_size=self.kernel_size,
42 | padding=self.padding,
43 | bias=self.bias,
44 | )
45 |
46 | def forward(self, input_tensor, cur_state):
47 |
48 | h_cur, c_cur = cur_state
49 |
50 | combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis
51 |
52 | combined_conv = self.conv(combined)
53 | cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
54 | i = torch.sigmoid(cc_i)
55 | f = torch.sigmoid(cc_f)
56 | o = torch.sigmoid(cc_o)
57 | g = torch.tanh(cc_g)
58 |
59 | c_next = f * c_cur + i * g
60 | h_next = o * torch.tanh(c_next)
61 |
62 | return h_next, c_next
63 |
64 | def init_hidden(self, batch_size, height=None, width=None):
65 | if height is None:
66 | height = self.height
67 | if width is None:
68 | width = self.width
69 | return (
70 | Variable(
71 | torch.zeros(
72 | batch_size, self.hidden_dim, int(height / self.scale), int(width / self.scale)
73 | )
74 | ).to(self.conv.weight.device),
75 | Variable(
76 | torch.zeros(
77 | batch_size, self.hidden_dim, int(height / self.scale), int(width / self.scale)
78 | )
79 | ).to(self.conv.weight.device),
80 | )
81 |
82 |
83 | class ConvBnLSTMCell(ConvLSTMCell):
84 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias=True):
85 |         super(ConvBnLSTMCell, self).__init__(input_size, 1, input_dim, hidden_dim, kernel_size, bias)  # scale=1 assumed (no downscaling)
86 |
87 | # self.conv = ConvBnReLU(in_channels=self.input_dim + self.hidden_dim,
88 | self.conv = ConvBn(
89 | in_channels=self.input_dim + self.hidden_dim,
90 | out_channels=4 * self.hidden_dim,
91 | kernel_size=self.kernel_size,
92 | stride=1,
93 | padding=self.padding,
94 | ) # bias = False in_channels, out_channels, kernel_size=3, stride=1, pad=1
95 |
96 |
97 | class ConvGnLSTMCell(ConvLSTMCell):
98 | def __init__(self, input_size, input_dim, hidden_dim, kernel_size, bias=True):
99 |         super(ConvGnLSTMCell, self).__init__(input_size, 1, input_dim, hidden_dim, kernel_size, bias)  # scale=1 assumed (no downscaling)
100 |
101 | # self.conv = ConvGnReLU(in_channels=self.input_dim + self.hidden_dim,
102 | self.conv = ConvGn(
103 | in_channels=self.input_dim + self.hidden_dim,
104 | out_channels=4 * self.hidden_dim,
105 | kernel_size=self.kernel_size,
106 | stride=1,
107 | padding=self.padding,
108 | )
109 |
110 |
111 | class ConvLSTM(nn.Module):
112 | def __init__(
113 | self,
114 | input_size,
115 | input_dim,
116 | hidden_dim,
117 | kernel_size,
118 | num_layers,
119 | batch_first=False,
120 | bias=True,
121 | return_all_layers=False,
122 | ):
123 | super(ConvLSTM, self).__init__()
124 |
125 | self._check_kernel_size_consistency(kernel_size)
126 |
127 | # Make sure that both `kernel_size` and `hidden_dim` are lists having len == num_layers
128 | kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
129 | hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)
130 | if not len(kernel_size) == len(hidden_dim) == num_layers:
131 | raise ValueError("Inconsistent list length.")
132 |
133 | self.height, self.width = input_size
134 |
135 | self.input_dim = input_dim
136 | self.hidden_dim = hidden_dim
137 | self.kernel_size = kernel_size
138 | self.num_layers = num_layers
139 | self.batch_first = batch_first
140 | self.bias = bias
141 | self.return_all_layers = return_all_layers
142 |
143 | cell_list = []
144 | for i in range(0, self.num_layers):
145 | cur_input_dim = self.input_dim if i == 0 else self.hidden_dim[i - 1]
146 |
147 | cell_list.append(
148 | ConvLSTMCell(
149 |                     input_size=(self.height, self.width),
150 |                     scale=1,  # assumed: ConvLSTMCell now requires a scale; 1 keeps the full resolution
151 |                     input_dim=cur_input_dim,
152 |                     hidden_dim=self.hidden_dim[i],
153 |                     kernel_size=self.kernel_size[i], bias=self.bias,
154 | )
155 | )
156 |
157 | self.cell_list = nn.ModuleList(cell_list)
158 |
159 | def forward(self, input_tensor, hidden_state=None):
160 | """
161 |
162 | Parameters
163 | ----------
164 | input_tensor: todo
165 | 5-D Tensor either of shape (t, b, c, h, w) or (b, t, c, h, w)
166 | hidden_state: todo
167 | None. todo implement stateful
168 |
169 | Returns
170 | -------
171 | last_state_list, layer_output
172 | """
173 | if not self.batch_first:
174 | # (t, b, c, h, w) -> (b, t, c, h, w)
175 | input_tensor = input_tensor.permute(1, 0, 2, 3, 4)
176 |
177 | # Implement stateful ConvLSTM
178 | if hidden_state is not None:
179 | raise NotImplementedError()
180 | else:
181 | hidden_state = self._init_hidden(batch_size=input_tensor.size(0))
182 |
183 | layer_output_list = []
184 | last_state_list = []
185 |
186 | seq_len = input_tensor.size(1)
187 | cur_layer_input = input_tensor
188 |
189 | for layer_idx in range(self.num_layers):
190 |
191 | h, c = hidden_state[layer_idx]
192 | output_inner = []
193 | for t in range(seq_len):
194 |
195 | h, c = self.cell_list[layer_idx](
196 | input_tensor=cur_layer_input[:, t, :, :, :], cur_state=[h, c]
197 | )
198 | output_inner.append(h)
199 |
200 | layer_output = torch.stack(output_inner, dim=1)
201 | cur_layer_input = layer_output
202 |
203 | layer_output_list.append(layer_output)
204 | last_state_list.append([h, c])
205 |
206 | if not self.return_all_layers:
207 | layer_output_list = layer_output_list[-1:]
208 | last_state_list = last_state_list[-1:]
209 |
210 | return layer_output_list, last_state_list
211 |
212 | def _init_hidden(self, batch_size):
213 | init_states = []
214 | for i in range(self.num_layers):
215 | init_states.append(self.cell_list[i].init_hidden(batch_size))
216 | return init_states
217 |
218 | @staticmethod
219 | def _check_kernel_size_consistency(kernel_size):
220 | if not (
221 | isinstance(kernel_size, tuple)
222 | or (
223 | isinstance(kernel_size, list)
224 | and all([isinstance(elem, tuple) for elem in kernel_size])
225 | )
226 | ):
227 | raise ValueError("`kernel_size` must be tuple or list of tuples")
228 |
229 | @staticmethod
230 | def _extend_for_multilayer(param, num_layers):
231 | if not isinstance(param, list):
232 | param = [param] * num_layers
233 | return param
234 |
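235 | # Illustrative usage sketch, not part of the original module (assumes the scale=1 default
236 | # used when the cells are built above):
237 | #     lstm = ConvLSTM(input_size=(32, 40), input_dim=16, hidden_dim=[32],
238 | #                     kernel_size=[(3, 3)], num_layers=1, batch_first=True)
239 | #     layer_outputs, last_states = lstm(torch.randn(2, 5, 16, 32, 40))  # (B, T, C, H, W)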
--------------------------------------------------------------------------------
/guided_mvs_lib/models/d2hc_rmvsnet/submodule.py:
--------------------------------------------------------------------------------
1 | #!/usr/bin/python
2 | # -*- coding: utf-8 -*-
3 | #
4 | # Hongwei Yi (hongweiyi@pku.edu.cn)
5 |
6 | import numpy as np
7 | import torch
8 | import torch.nn as nn
9 |
10 |
11 | def conv(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
12 | return nn.Sequential(
13 | nn.Conv2d(
14 | in_channels,
15 | out_channels,
16 | kernel_size=kernel_size,
17 | stride=stride,
18 | dilation=dilation,
19 | padding=((kernel_size - 1) // 2) * dilation,
20 | bias=bias,
21 | ),
22 | nn.ReLU(inplace=True),
23 | )
24 |
25 |
26 | def convbn(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
27 | return nn.Sequential(
28 | nn.Conv2d(
29 | in_channels,
30 | out_channels,
31 | kernel_size=kernel_size,
32 | stride=stride,
33 | dilation=dilation,
34 | padding=((kernel_size - 1) // 2) * dilation,
35 | bias=bias,
36 | ),
37 | nn.BatchNorm2d(out_channels),
38 | # nn.SyncBatchNorm(out_channels),
39 | # nn.LeakyReLU(0.0,inplace=True)
40 | nn.ReLU(inplace=True),
41 | )
42 |
43 |
44 | def convgnrelu(
45 | in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True, group_channel=8
46 | ):
47 | return nn.Sequential(
48 | nn.Conv2d(
49 | in_channels,
50 | out_channels,
51 | kernel_size=kernel_size,
52 | stride=stride,
53 | dilation=dilation,
54 | padding=((kernel_size - 1) // 2) * dilation,
55 | bias=bias,
56 | ),
57 | nn.GroupNorm(int(max(1, out_channels / group_channel)), out_channels),
58 | nn.ReLU(inplace=True),
59 | )
60 |
61 |
62 | # def conv3d(in_channels, out_channels, kernel_size=3, stride=1,dilation=1, bias=True):
63 | # return nn.Sequential(
64 | # nn.Conv3d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, dilation=dilation, padding=((kernel_size-1)//2)*dilation, bias=bias),
65 | # nn.LeakyReLU(0.0,inplace=True)
66 | # )
67 |
68 |
69 | def conv3dgn(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
70 | return nn.Sequential(
71 | nn.Conv3d(
72 | in_channels,
73 | out_channels,
74 | kernel_size=kernel_size,
75 | stride=stride,
76 | dilation=dilation,
77 | padding=((kernel_size - 1) // 2) * dilation,
78 | bias=bias,
79 | ),
80 | nn.GroupNorm(1, 1),
81 | nn.LeakyReLU(0.0, inplace=True),
82 | )
83 |
84 |
85 | def conv3d(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
86 | return nn.Sequential(
87 | nn.Conv3d(
88 | in_channels,
89 | out_channels,
90 | kernel_size=kernel_size,
91 | stride=stride,
92 | dilation=dilation,
93 | padding=((kernel_size - 1) // 2) * dilation,
94 | bias=bias,
95 | ),
96 | nn.BatchNorm3d(out_channels),
97 | nn.LeakyReLU(0.0, inplace=True),
98 | )
99 |
100 |
101 | def resnet_block(in_channels, kernel_size=3, dilation=[1, 1], bias=True):
102 | return ResnetBlock(in_channels, kernel_size, dilation, bias=bias)
103 |
104 |
105 | def resnet_block_bn(in_channels, kernel_size=3, dilation=[1, 1], bias=True):
106 | return ResnetBlockBn(in_channels, kernel_size, dilation, bias=bias)
107 |
108 |
109 | class ResnetBlock(nn.Module):
110 | def __init__(self, in_channels, kernel_size, dilation, bias):
111 | super(ResnetBlock, self).__init__()
112 | self.stem = nn.Sequential(
113 | nn.Conv2d(
114 | in_channels,
115 | in_channels,
116 | kernel_size=kernel_size,
117 | stride=1,
118 | dilation=dilation[0],
119 | padding=((kernel_size - 1) // 2) * dilation[0],
120 | bias=bias,
121 | ),
122 | nn.LeakyReLU(0.0, inplace=True),
123 | nn.Conv2d(
124 | in_channels,
125 | in_channels,
126 | kernel_size=kernel_size,
127 | stride=1,
128 | dilation=dilation[1],
129 | padding=((kernel_size - 1) // 2) * dilation[1],
130 | bias=bias,
131 | ),
132 | )
133 |
134 | def forward(self, x):
135 | out = self.stem(x) + x
136 | return out
137 |
138 |
139 | class ResnetBlockBn(nn.Module):
140 | def __init__(self, in_channels, kernel_size, dilation, bias):
141 | super(ResnetBlockBn, self).__init__()
142 | self.stem = nn.Sequential(
143 | convbn(
144 | in_channels,
145 | in_channels,
146 | kernel_size=kernel_size,
147 | stride=1,
148 | dilation=dilation[0],
149 | bias=bias,
150 | ),
151 | nn.Conv2d(
152 | in_channels,
153 | in_channels,
154 | kernel_size=kernel_size,
155 | stride=1,
156 | dilation=dilation[1],
157 | padding=((kernel_size - 1) // 2) * dilation[1],
158 | bias=bias,
159 | ),
160 | )
161 |
162 | def forward(self, x):
163 | out = self.stem(x) + x
164 | return out
165 |
166 |
167 | ##### Define weightnet-3d
168 | def volumegatelight(in_channels, kernel_size=3, dilation=[1, 1], bias=True):
169 | return nn.Sequential(
170 | # MSDilateBlock3D(in_channels, kernel_size, dilation, bias),
171 | conv3d(in_channels, 1, kernel_size=1, stride=1, bias=bias),
172 | conv3d(1, 1, kernel_size=1, stride=1),
173 | )
174 |
175 |
176 | def volumegatelightgn(in_channels, kernel_size=3, dilation=[1, 1], bias=True):
177 | return nn.Sequential(
178 | # MSDilateBlock3D(in_channels, kernel_size, dilation, bias),
179 | conv3dgn(in_channels, 1, kernel_size=1, stride=1, bias=bias),
180 | conv3dgn(1, 1, kernel_size=1, stride=1),
181 | )
182 |
183 |
184 | ##### Define gatenet
185 | def gatenetbn(bias=True):
186 | return nn.Sequential(
187 |         convbn(32, 16, kernel_size=3, stride=1, dilation=1, bias=bias),
188 | resnet_block_bn(16, kernel_size=1),
189 | nn.Conv2d(16, 1, kernel_size=1, padding=0),
190 | nn.Sigmoid(),
191 | )
192 |
193 |
194 | def pillarnetbn(bias=True):
195 | return nn.Sequential(
196 | nn.Linear(192, 32, bias=bias),
197 |         nn.LeakyReLU(0.0, inplace=True),  # assumed slope: cfg.NETWORK.LEAKY_VALUE is undefined in this file
198 | nn.Linear(32, 2),
199 | )
200 |
201 |
202 | class ResnetBlockGn(nn.Module):
203 | def __init__(self, in_channels, kernel_size, dilation, bias, group_channel=8):
204 | super(ResnetBlockGn, self).__init__()
205 | self.stem = nn.Sequential(
206 | convgnrelu(
207 | in_channels,
208 | in_channels,
209 | kernel_size=kernel_size,
210 | stride=1,
211 | dilation=dilation[0],
212 | bias=bias,
213 | group_channel=group_channel,
214 | ),
215 | nn.Conv2d(
216 | in_channels,
217 | in_channels,
218 | kernel_size=kernel_size,
219 | stride=1,
220 | dilation=dilation[1],
221 | padding=((kernel_size - 1) // 2) * dilation[1],
222 | bias=bias,
223 | ),
224 | nn.GroupNorm(int(max(1, in_channels / group_channel)), in_channels),
225 | )
226 | self.relu = nn.ReLU(inplace=True)
227 |
228 | def forward(self, x):
229 | out = self.stem(x) + x
230 | out = self.relu(out)
231 | return out
232 |
233 |
234 | def resnet_block_gn(in_channels, kernel_size=3, dilation=[1, 1], bias=True, group_channel=8):
235 | return ResnetBlockGn(
236 | in_channels, kernel_size, dilation, bias=bias, group_channel=group_channel
237 | )
238 |
239 |
240 | def gatenet(gn=True, in_channels=32, bias=True): # WORK
241 | if gn:
242 | return nn.Sequential(
243 | # nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, dilation=1, padding=1, bias=bias), # in_channels=64
244 | # nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE,inplace=True),
245 | convgnrelu(
246 | in_channels, 4, kernel_size=3, stride=1, dilation=1, bias=bias
247 | ), # 4: 10G,8.6G;
248 | resnet_block_gn(4, kernel_size=1),
249 | nn.Conv2d(4, 1, kernel_size=1, padding=0),
250 | nn.Sigmoid(),
251 | )
252 | else:
253 | return nn.Sequential(
254 | nn.Conv2d(
255 | in_channels, 16, kernel_size=3, stride=1, dilation=1, padding=1, bias=bias
256 | ), # in_channels=64
257 |             nn.LeakyReLU(0.0, inplace=True),  # assumed slope: cfg.NETWORK.LEAKY_VALUE is undefined in this file
258 | resnet_block(16, kernel_size=1),
259 | nn.Conv2d(16, 1, kernel_size=1, padding=0),
260 | nn.Sigmoid(),
261 | )
262 |
263 |
264 | def gatenet_m4(gn=True, in_channels=32, bias=True):
265 | if gn:
266 | return nn.Sequential(
267 | # nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, dilation=1, padding=1, bias=bias), # in_channels=64
268 | # nn.LeakyReLU(cfg.NETWORK.LEAKY_VALUE,inplace=True),
269 | convgnrelu(
270 | in_channels, 8, kernel_size=3, stride=1, dilation=1, bias=bias
271 | ), # 4: 10G,8.6G;
272 | resnet_block_gn(8, kernel_size=1),
273 | nn.Conv2d(8, 1, kernel_size=1, padding=0),
274 | nn.Sigmoid(),
275 | )
276 | else:
277 | return nn.Sequential(
278 | nn.Conv2d(
279 | in_channels, 16, kernel_size=3, stride=1, dilation=1, padding=1, bias=bias
280 | ), # in_channels=64
281 |             nn.LeakyReLU(0.0, inplace=True),  # assumed slope: cfg.NETWORK.LEAKY_VALUE is undefined in this file
282 | resnet_block(16, kernel_size=1),
283 | nn.Conv2d(16, 1, kernel_size=1, padding=0),
284 | nn.Sigmoid(),
285 | )
286 |
287 |
288 | def pillarnet(in_channels=192, bias=True): # origin_pillarnet: 192, 96, 48, 24
289 | return nn.Sequential(
290 | nn.Linear(in_channels, 32, bias=bias),
291 |         nn.LeakyReLU(0.0, inplace=True),  # assumed slope: cfg.NETWORK.LEAKY_VALUE is undefined in this file
292 | nn.Linear(32, 2),
293 | )
294 |
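295 | # Illustrative sketch, not part of the original module: gatenet builds a small per-pixel
296 | # gating head that maps a feature map to a sigmoid weight in [0, 1], e.g.
297 | #     gate = gatenet(gn=True, in_channels=32)
298 | #     weights = gate(torch.randn(1, 32, 64, 80))   # -> (1, 1, 64, 80)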
--------------------------------------------------------------------------------
/guided_mvs_lib/models/d2hc_rmvsnet/vamvsnet_high_submodule.py:
--------------------------------------------------------------------------------
1 | import sys
2 | from copy import deepcopy
3 |
4 | import torch
5 | import torch.nn as nn
6 | import torch.nn.functional as F
7 |
8 | from .module import *
9 |
10 | # Multi-scale feature extractor && Coarse To Fine Regression Module
11 |
12 |
13 | class FeatureNetHigh(nn.Module): # Original Paper Setting
14 | def __init__(self):
15 | super(FeatureNetHigh, self).__init__()
16 | self.inplanes = 32
17 |
18 | self.conv0 = ConvBnReLU(3, 8, 3, 1, 1)
19 | self.conv1 = ConvBnReLU(8, 8, 3, 1, 1)
20 |
21 | self.conv2 = ConvBnReLU(8, 16, 5, 2, 2)
22 | self.conv3 = ConvBnReLU(16, 16, 3, 1, 1)
23 | self.conv4 = ConvBnReLU(16, 16, 3, 1, 1)
24 |
25 | self.conv5 = ConvBnReLU(16, 32, 5, 2, 2)
26 | self.conv6 = ConvBnReLU(32, 32, 3, 1, 1)
27 |
28 | self.conv7 = ConvBnReLU(32, 32, 5, 2, 2)
29 | self.conv8 = ConvBnReLU(32, 32, 3, 1, 1)
30 |
31 | self.conv9 = ConvBnReLU(32, 64, 5, 2, 2)
32 | self.conv10 = ConvBnReLU(64, 64, 3, 1, 1)
33 |
34 | self.conv11 = ConvBnReLU(64, 64, 5, 2, 2)
35 | self.conv12 = ConvBnReLU(64, 64, 3, 1, 1)
36 |
37 | self.feature1 = nn.Conv2d(32, 32, 3, 1, 1)
38 |
39 | self.feature2 = nn.Conv2d(32, 32, 3, 1, 1)
40 |
41 | self.feature3 = nn.Conv2d(64, 64, 3, 1, 1)
42 |
43 | self.feature4 = nn.Conv2d(64, 64, 3, 1, 1)
44 |
45 | def forward(self, x):
46 | x = self.conv1(self.conv0(x))
47 | x = self.conv4(self.conv3(self.conv2(x)))
48 | x = self.conv6(self.conv5(x))
49 | feature1 = self.feature1(x)
50 | x = self.conv8(self.conv7(x))
51 | feature2 = self.feature2(x)
52 | x = self.conv10(self.conv9(x))
53 | feature3 = self.feature3(x)
54 | x = self.conv12(self.conv11(x))
55 | feature4 = self.feature4(x)
56 | return [feature1, feature2, feature3, feature4]
57 |
58 |
59 | class FeatureNetHighGN(nn.Module): # Original Paper Setting
60 | def __init__(self):
61 | super(FeatureNetHighGN, self).__init__()
62 | self.inplanes = 32
63 |
64 | self.conv0 = ConvGnReLU(3, 8, 3, 1, 1)
65 | self.conv1 = ConvGnReLU(8, 8, 3, 1, 1)
66 |
67 | self.conv2 = ConvGnReLU(8, 16, 5, 2, 2)
68 | self.conv3 = ConvGnReLU(16, 16, 3, 1, 1)
69 | self.conv4 = ConvGnReLU(16, 16, 3, 1, 1)
70 |
71 | self.conv5 = ConvGnReLU(16, 32, 5, 2, 2)
72 | self.conv6 = ConvGnReLU(32, 32, 3, 1, 1)
73 |
74 | self.conv7 = ConvGnReLU(32, 32, 5, 2, 2)
75 | self.conv8 = ConvGnReLU(32, 32, 3, 1, 1)
76 |
77 | self.conv9 = ConvGnReLU(32, 64, 5, 2, 2)
78 | self.conv10 = ConvGnReLU(64, 64, 3, 1, 1)
79 |
80 | self.conv11 = ConvGnReLU(64, 64, 5, 2, 2)
81 | self.conv12 = ConvGnReLU(64, 64, 3, 1, 1)
82 |
83 | self.feature1 = nn.Conv2d(32, 32, 3, 1, 1)
84 |
85 | self.feature2 = nn.Conv2d(32, 32, 3, 1, 1)
86 |
87 | self.feature3 = nn.Conv2d(64, 64, 3, 1, 1)
88 |
89 | self.feature4 = nn.Conv2d(64, 64, 3, 1, 1)
90 |
91 | def forward(self, x):
92 | x = self.conv1(self.conv0(x))
93 | x = self.conv4(self.conv3(self.conv2(x)))
94 | x = self.conv6(self.conv5(x))
95 | feature1 = self.feature1(x)
96 | x = self.conv8(self.conv7(x))
97 | feature2 = self.feature2(x)
98 | x = self.conv10(self.conv9(x))
99 | feature3 = self.feature3(x)
100 | x = self.conv12(self.conv11(x))
101 | feature4 = self.feature4(x)
102 | return [feature1, feature2, feature3, feature4]
103 |
104 |
105 | class RegNetUS0_Coarse2Fine(nn.Module):
106 | def __init__(self, origin_size=False, dp_ratio=0.0, image_scale=0.25):
107 | super(RegNetUS0_Coarse2Fine, self).__init__()
108 | self.origin_size = origin_size
109 | self.image_scale = image_scale
110 |
111 | self.conv0 = ConvBnReLU3D(32, 8)
112 |
113 | self.conv1 = ConvBnReLU3D(32, 16, stride=2)
114 | self.conv2 = ConvBnReLU3D(16, 16)
115 |
116 | self.conv3 = ConvBnReLU3D(16, 32, stride=2)
117 | self.conv4 = ConvBnReLU3D(32, 32)
118 |
119 | self.conv5 = ConvBnReLU3D(32, 64, stride=2)
120 | self.conv6 = ConvBnReLU3D(64, 64)
121 |
122 | self.conv7 = nn.Sequential(
123 | nn.ConvTranspose3d(
124 | 128, 32, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
125 | ),
126 | nn.BatchNorm3d(32),
127 | nn.ReLU(inplace=True),
128 | )
129 |
130 | self.conv9 = nn.Sequential(
131 | nn.ConvTranspose3d(
132 | 97, 16, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
133 | ),
134 | nn.BatchNorm3d(16),
135 | nn.ReLU(inplace=True),
136 | )
137 |
138 | self.conv11 = nn.Sequential(
139 | nn.ConvTranspose3d(
140 | 49, 8, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
141 | ),
142 | nn.BatchNorm3d(8),
143 | nn.ReLU(inplace=True),
144 | )
145 |
146 | self.prob1 = nn.Conv3d(41, 1, 1, bias=False)
147 | self.dropout1 = nn.Dropout3d(p=dp_ratio)
148 | self.prob2 = nn.Conv3d(49, 1, 1, bias=False)
149 | self.dropout2 = nn.Dropout3d(p=dp_ratio)
150 | self.prob3 = nn.Conv3d(97, 1, 1, bias=False)
151 | self.dropout3 = nn.Dropout3d(p=dp_ratio)
152 | self.prob4 = nn.Conv3d(128, 1, 1, bias=False)
153 | self.dropout4 = nn.Dropout3d(p=dp_ratio)
154 | # add Drop out
155 |
156 | def forward(self, x_list):
157 | x1, x2, x3, x4 = x_list # 32*192, 32*96, 64*48, 64*24
158 | input_shape = x1.shape
159 |
160 | conv0 = self.conv0(x1)
161 | conv1 = self.conv1(x1)
162 | conv3 = self.conv3(conv1)
163 | conv5 = self.conv5(conv3)
164 |
165 | x = torch.cat([self.conv6(conv5), x4], 1)
166 | prob4 = self.dropout4(self.prob4(x))
167 | # prob4 = self.prob4(x)
168 | x = self.conv7(x) + self.conv4(conv3)
169 | x = torch.cat(
170 | [x, x3, F.interpolate(prob4, scale_factor=2, mode="trilinear", align_corners=True)], 1
171 | )
172 | prob3 = self.dropout3(self.prob3(x))
173 | # prob3 = self.prob3(x)
174 | x = self.conv9(x) + self.conv2(conv1)
175 | x = torch.cat(
176 | [x, x2, F.interpolate(prob3, scale_factor=2, mode="trilinear", align_corners=True)], 1
177 | )
178 | prob2 = self.dropout2(self.prob2(x))
179 | # prob2 = self.prob2(x)
180 | x = self.conv11(x) + conv0
181 | x = torch.cat(
182 | [x, x1, F.interpolate(prob2, scale_factor=2, mode="trilinear", align_corners=True)], 1
183 | )
184 |
185 | if self.origin_size and self.image_scale == 0.50:
186 | x = F.interpolate(
187 | x,
188 | size=(input_shape[2], input_shape[3] * 2, input_shape[4] * 2),
189 | mode="trilinear",
190 | align_corners=True,
191 | )
192 | prob1 = self.dropout1(self.prob1(x))
193 | # prob1 = self.prob1(x) # without dropout
194 | # if self.origin_size:
195 | # x = F.interpolate(x, size=(input_shape[2], input_shape[3]*4, input_shape[4]*4), mode='trilinear', align_corners=True)
196 | return [prob1, prob2, prob3, prob4]
197 |
198 |
199 | class RegNetUS0_Coarse2FineGN(nn.Module):
200 | def __init__(self, origin_size=False, dp_ratio=0.0, image_scale=0.25):
201 | super(RegNetUS0_Coarse2FineGN, self).__init__()
202 | self.origin_size = origin_size
203 | self.image_scale = image_scale
204 |
205 | self.conv0 = ConvGnReLU3D(32, 8)
206 |
207 | self.conv1 = ConvGnReLU3D(32, 16, stride=2)
208 | self.conv2 = ConvGnReLU3D(16, 16)
209 |
210 | self.conv3 = ConvGnReLU3D(16, 32, stride=2)
211 | self.conv4 = ConvGnReLU3D(32, 32)
212 |
213 | self.conv5 = ConvGnReLU3D(32, 64, stride=2)
214 | self.conv6 = ConvGnReLU3D(64, 64)
215 |
216 | self.conv7 = nn.Sequential(
217 | nn.ConvTranspose3d(
218 | 128, 32, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
219 | ),
220 | # nn.BatchNorm3d(32),
221 | nn.GroupNorm(4, 32),
222 | nn.ReLU(inplace=True),
223 | )
224 |
225 | self.conv9 = nn.Sequential(
226 | nn.ConvTranspose3d(
227 | 97, 16, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
228 | ),
229 | nn.GroupNorm(2, 16),
230 | nn.ReLU(inplace=True),
231 | )
232 |
233 | self.conv11 = nn.Sequential(
234 | nn.ConvTranspose3d(
235 | 49, 8, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
236 | ),
237 | nn.GroupNorm(1, 8),
238 | nn.ReLU(inplace=True),
239 | )
240 |
241 | self.prob1 = nn.Conv3d(41, 1, 1, bias=False)
242 | self.dropout1 = nn.Dropout3d(p=dp_ratio)
243 | self.prob2 = nn.Conv3d(49, 1, 1, bias=False)
244 | self.dropout2 = nn.Dropout3d(p=dp_ratio)
245 | self.prob3 = nn.Conv3d(97, 1, 1, bias=False)
246 | self.dropout3 = nn.Dropout3d(p=dp_ratio)
247 | self.prob4 = nn.Conv3d(128, 1, 1, bias=False)
248 | self.dropout4 = nn.Dropout3d(p=dp_ratio)
249 | # add Drop out
250 |
251 | def forward(self, x_list):
252 | x1, x2, x3, x4 = x_list # 32*192, 32*96, 64*48, 64*24
253 | # print(x1.shape, x2.shape, x3.shape, x4.shape)
254 | input_shape = x1.shape
255 |
256 | conv0 = self.conv0(x1)
257 | conv1 = self.conv1(x1)
258 | conv3 = self.conv3(conv1)
259 | conv5 = self.conv5(conv3)
260 |
261 | x = torch.cat([self.conv6(conv5), x4], 1)
262 | prob4 = self.dropout4(self.prob4(x))
263 | # prob4 = self.prob4(x)
264 | x = self.conv7(x) + self.conv4(conv3)
265 | x = torch.cat(
266 | [x, x3, F.interpolate(prob4, scale_factor=2, mode="trilinear", align_corners=True)], 1
267 | )
268 | prob3 = self.dropout3(self.prob3(x))
269 | # prob3 = self.prob3(x)
270 | x = self.conv9(x) + self.conv2(conv1)
271 | x = torch.cat(
272 | [x, x2, F.interpolate(prob3, scale_factor=2, mode="trilinear", align_corners=True)], 1
273 | )
274 | prob2 = self.dropout2(self.prob2(x))
275 | # prob2 = self.prob2(x)
276 | x = self.conv11(x) + conv0
277 | x = torch.cat(
278 | [x, x1, F.interpolate(prob2, scale_factor=2, mode="trilinear", align_corners=True)], 1
279 | )
280 |
281 | if self.origin_size and self.image_scale == 0.50:
282 | x = F.interpolate(
283 | x,
284 | size=(input_shape[2], input_shape[3] * 2, input_shape[4] * 2),
285 | mode="trilinear",
286 | align_corners=True,
287 | )
288 | prob1 = self.dropout1(self.prob1(x))
289 | # prob1 = self.prob1(x) # without dropout
290 | # if self.origin_size:
291 | # x = F.interpolate(x, size=(input_shape[2], input_shape[3]*4, input_shape[4]*4), mode='trilinear', align_corners=True)
292 | return [prob1, prob2, prob3, prob4]
293 |
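294 | # Note (illustrative, not part of the original module): FeatureNetHigh / FeatureNetHighGN return
295 | # a four-level feature pyramid at 1/4, 1/8, 1/16 and 1/32 of the input resolution with 32, 32,
296 | # 64 and 64 channels; the coarse-to-fine regularizers above emit one probability volume per
297 | # level, decoded from the coarsest up to the finest.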
--------------------------------------------------------------------------------
/guided_mvs_lib/models/mvsnet/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from types import SimpleNamespace
3 | from typing import Any, Optional
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch import Tensor
8 |
9 | from .mvsnet import MVSNet, mvsnet_loss
10 |
11 | __all__ = ["mvsnet_loss", "MVSNet", "NetBuilder", "SimpleInterfaceNet"]
12 |
13 |
14 | class SimpleInterfaceNet(nn.Module):
15 | """
16 | Simple common interface to call the pretrained models
17 | """
18 |
19 | def __init__(self, refine: bool = False):
20 | super().__init__()
21 | self.model = MVSNet(refine=refine)
22 | self.all_outputs: dict[str, Any] = {}
23 |
24 | def forward(
25 | self,
26 | imgs: Tensor,
27 | intrinsics: Tensor,
28 | extrinsics: Tensor,
29 | depth_values: Tensor,
30 | hints: Optional[Tensor] = None,
31 | ):
32 | with warnings.catch_warnings():
33 | warnings.simplefilter("ignore", UserWarning)
34 |
35 | # compute poses
36 | proj_mat = extrinsics.clone()
37 | intrinsics_copy = intrinsics.clone()
38 | intrinsics_copy[..., :2, :] = intrinsics_copy[..., :2, :] / 4
39 | proj_mat[..., :3, :4] = intrinsics_copy @ proj_mat[..., :3, :4]
40 |
41 | # validhints
42 | validhints = None
43 | if hints is not None:
44 | validhints = (hints > 0).to(torch.float32)
45 |
46 | # call
47 | out = self.model(imgs, proj_mat, depth_values, hints, validhints)
48 | self.all_outputs = out
49 | return out["depth"]["stage_0"]
50 |
51 |
52 | class NetBuilder(nn.Module):
53 | def __init__(self, args: SimpleNamespace):
54 | super().__init__()
55 | self.model = MVSNet(refine=False)
56 | self.loss = mvsnet_loss
57 |
58 | def forward(self, batch: dict):
59 |
60 | hints, validhints = None, None
61 | if "hints" in batch:
62 | hints = batch["hints"]
63 | validhints = (hints > 0).to(torch.float32)
64 |
65 | return self.model(
66 | batch["imgs"]["stage_0"],
67 | batch["proj_matrices"]["stage_2"],
68 | batch["depth_values"],
69 | hints,
70 | validhints,
71 | )
72 |
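73 | # Note (illustrative, not part of the original module): NetBuilder pairs the full-resolution
74 | # images ("stage_0") with the quarter-resolution projection matrices ("stage_2") because
75 | # MVSNet's FeatureNet downsamples by a factor of 4, so the cost volume is built at the
76 | # feature resolution.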
--------------------------------------------------------------------------------
/guided_mvs_lib/models/mvsnet/module.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 |
6 | class ConvBnReLU(nn.Module):
7 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
8 | super(ConvBnReLU, self).__init__()
9 | self.conv = nn.Conv2d(
10 | in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False
11 | )
12 | self.bn = nn.BatchNorm2d(out_channels)
13 |
14 | def forward(self, x):
15 | return F.relu(self.bn(self.conv(x)), inplace=True)
16 |
17 |
18 | class ConvBn(nn.Module):
19 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
20 | super(ConvBn, self).__init__()
21 | self.conv = nn.Conv2d(
22 | in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False
23 | )
24 | self.bn = nn.BatchNorm2d(out_channels)
25 |
26 | def forward(self, x):
27 | return self.bn(self.conv(x))
28 |
29 |
30 | class ConvBnReLU3D(nn.Module):
31 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
32 | super(ConvBnReLU3D, self).__init__()
33 | self.conv = nn.Conv3d(
34 | in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False
35 | )
36 | self.bn = nn.BatchNorm3d(out_channels)
37 |
38 | def forward(self, x):
39 | return F.relu(self.bn(self.conv(x)), inplace=True)
40 |
41 |
42 | class ConvBn3D(nn.Module):
43 | def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, pad=1):
44 | super(ConvBn3D, self).__init__()
45 | self.conv = nn.Conv3d(
46 | in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False
47 | )
48 | self.bn = nn.BatchNorm3d(out_channels)
49 |
50 | def forward(self, x):
51 | return self.bn(self.conv(x))
52 |
53 |
54 | class BasicBlock(nn.Module):
55 | def __init__(self, in_channels, out_channels, stride, downsample=None):
56 | super(BasicBlock, self).__init__()
57 |
58 | self.conv1 = ConvBnReLU(in_channels, out_channels, kernel_size=3, stride=stride, pad=1)
59 | self.conv2 = ConvBn(out_channels, out_channels, kernel_size=3, stride=1, pad=1)
60 |
61 | self.downsample = downsample
62 | self.stride = stride
63 |
64 | def forward(self, x):
65 | out = self.conv1(x)
66 | out = self.conv2(out)
67 | if self.downsample is not None:
68 | x = self.downsample(x)
69 | out += x
70 | return out
71 |
72 |
73 | class Hourglass3d(nn.Module):
74 | def __init__(self, channels):
75 | super(Hourglass3d, self).__init__()
76 |
77 | self.conv1a = ConvBnReLU3D(channels, channels * 2, kernel_size=3, stride=2, pad=1)
78 | self.conv1b = ConvBnReLU3D(channels * 2, channels * 2, kernel_size=3, stride=1, pad=1)
79 |
80 | self.conv2a = ConvBnReLU3D(channels * 2, channels * 4, kernel_size=3, stride=2, pad=1)
81 | self.conv2b = ConvBnReLU3D(channels * 4, channels * 4, kernel_size=3, stride=1, pad=1)
82 |
83 | self.dconv2 = nn.Sequential(
84 | nn.ConvTranspose3d(
85 | channels * 4,
86 | channels * 2,
87 | kernel_size=3,
88 | padding=1,
89 | output_padding=1,
90 | stride=2,
91 | bias=False,
92 | ),
93 | nn.BatchNorm3d(channels * 2),
94 | )
95 |
96 | self.dconv1 = nn.Sequential(
97 | nn.ConvTranspose3d(
98 | channels * 2,
99 | channels,
100 | kernel_size=3,
101 | padding=1,
102 | output_padding=1,
103 | stride=2,
104 | bias=False,
105 | ),
106 | nn.BatchNorm3d(channels),
107 | )
108 |
109 | self.redir1 = ConvBn3D(channels, channels, kernel_size=1, stride=1, pad=0)
110 | self.redir2 = ConvBn3D(channels * 2, channels * 2, kernel_size=1, stride=1, pad=0)
111 |
112 | def forward(self, x):
113 | conv1 = self.conv1b(self.conv1a(x))
114 | conv2 = self.conv2b(self.conv2a(conv1))
115 | dconv2 = F.relu(self.dconv2(conv2) + self.redir2(conv1), inplace=True)
116 | dconv1 = F.relu(self.dconv1(dconv2) + self.redir1(x), inplace=True)
117 | return dconv1
118 |
119 |
120 | def homo_warping(src_fea, src_proj, ref_proj, depth_values):
121 | # src_fea: [B, C, H, W]
122 | # src_proj: [B, 4, 4]
123 | # ref_proj: [B, 4, 4]
124 | # depth_values: [B, Ndepth]
125 | # out: [B, C, Ndepth, H, W]
126 | batch, channels = src_fea.shape[0], src_fea.shape[1]
127 | num_depth = depth_values.shape[1]
128 | height, width = src_fea.shape[2], src_fea.shape[3]
129 |
130 | with torch.no_grad():
131 | proj = torch.matmul(src_proj, torch.inverse(ref_proj))
132 | rot = proj[:, :3, :3] # [B,3,3]
133 | trans = proj[:, :3, 3:4] # [B,3,1]
134 |
135 | y, x = torch.meshgrid(
136 | [
137 | torch.arange(0, height, dtype=torch.float32, device=src_fea.device),
138 | torch.arange(0, width, dtype=torch.float32, device=src_fea.device),
139 | ]
140 | )
141 | y, x = y.contiguous(), x.contiguous()
142 | y, x = y.view(height * width), x.view(height * width)
143 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W]
144 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W]
145 | rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W]
146 | rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_values.view(
147 | batch, 1, num_depth, 1
148 | ) # [B, 3, Ndepth, H*W]
149 | proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W]
150 | proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W]
151 | proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1
152 | proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1
153 | proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2]
154 | grid = proj_xy
155 |
156 | warped_src_fea = F.grid_sample(
157 | src_fea,
158 | grid.view(batch, num_depth * height, width, 2),
159 | mode="bilinear",
160 | padding_mode="zeros",
161 | )
162 | warped_src_fea = warped_src_fea.view(batch, channels, num_depth, height, width)
163 |
164 | return warped_src_fea
165 |
166 |
167 | # p: probability volume [B, D, H, W]
168 | # depth_values: discrete depth values [B, D]
169 | def depth_regression(p, depth_values):
170 | depth_values = depth_values.view(*depth_values.shape, 1, 1)
171 | depth = torch.sum(p * depth_values, 1)
172 | return depth
173 |
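174 | # Illustrative sketch, not part of the original module: depth_regression collapses a softmax
175 | # probability volume into a per-pixel expected depth (the depth range values are placeholders):
176 | #     p = torch.softmax(torch.randn(2, 48, 64, 80), dim=1)          # (B, D, H, W)
177 | #     depths = torch.linspace(425.0, 935.0, 48).repeat(2, 1)        # (B, D)
178 | #     depth_map = depth_regression(p, depth_values=depths)          # -> (B, H, W)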
--------------------------------------------------------------------------------
/guided_mvs_lib/models/mvsnet/mvsnet.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.functional as F
5 |
6 | from .module import *
7 |
8 |
9 | class FeatureNet(nn.Module):
10 | def __init__(self):
11 | super(FeatureNet, self).__init__()
12 | self.inplanes = 32
13 |
14 | self.conv0 = ConvBnReLU(3, 8, 3, 1, 1)
15 | self.conv1 = ConvBnReLU(8, 8, 3, 1, 1)
16 |
17 | self.conv2 = ConvBnReLU(8, 16, 5, 2, 2)
18 | self.conv3 = ConvBnReLU(16, 16, 3, 1, 1)
19 | self.conv4 = ConvBnReLU(16, 16, 3, 1, 1)
20 |
21 | self.conv5 = ConvBnReLU(16, 32, 5, 2, 2)
22 | self.conv6 = ConvBnReLU(32, 32, 3, 1, 1)
23 | self.feature = nn.Conv2d(32, 32, 3, 1, 1)
24 |
25 | def forward(self, x):
26 | x = self.conv1(self.conv0(x))
27 | x = self.conv4(self.conv3(self.conv2(x)))
28 | x = self.feature(self.conv6(self.conv5(x)))
29 | return x
30 |
31 |
32 | class CostRegNet(nn.Module):
33 | def __init__(self):
34 | super(CostRegNet, self).__init__()
35 | self.conv0 = ConvBnReLU3D(32, 8)
36 |
37 | self.conv1 = ConvBnReLU3D(8, 16, stride=2)
38 | self.conv2 = ConvBnReLU3D(16, 16)
39 |
40 | self.conv3 = ConvBnReLU3D(16, 32, stride=2)
41 | self.conv4 = ConvBnReLU3D(32, 32)
42 |
43 | self.conv5 = ConvBnReLU3D(32, 64, stride=2)
44 | self.conv6 = ConvBnReLU3D(64, 64)
45 |
46 | self.conv7 = nn.Sequential(
47 | nn.ConvTranspose3d(
48 | 64, 32, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
49 | ),
50 | nn.BatchNorm3d(32),
51 | nn.ReLU(inplace=True),
52 | )
53 |
54 | self.conv9 = nn.Sequential(
55 | nn.ConvTranspose3d(
56 | 32, 16, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
57 | ),
58 | nn.BatchNorm3d(16),
59 | nn.ReLU(inplace=True),
60 | )
61 |
62 | self.conv11 = nn.Sequential(
63 | nn.ConvTranspose3d(
64 | 16, 8, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False
65 | ),
66 | nn.BatchNorm3d(8),
67 | nn.ReLU(inplace=True),
68 | )
69 |
70 | self.prob = nn.Conv3d(8, 1, 3, stride=1, padding=1)
71 |
72 | def _split_pad(self, pad):
73 | if pad % 2 == 0:
74 | return pad // 2, pad // 2
75 | else:
76 | pad_1 = pad // 2
77 | pad_2 = (pad // 2) + 1
78 | return pad_1, pad_2
79 |
80 | def _generate_slice(self, pad):
81 | if pad == 0:
82 | return slice(0, None)
83 | elif pad % 2 == 0:
84 | return slice(pad // 2, -pad // 2)
85 | else:
86 | pad_1 = pad // 2
87 | pad_2 = (pad // 2) + 1
88 | return slice(pad_1, -pad_2)
89 |
90 | def _pad_to_div_by(self, x, *, div_by=8):
91 | _, _, _, h, w = x.shape
92 | new_h = int(np.ceil(h / div_by)) * div_by
93 | new_w = int(np.ceil(w / div_by)) * div_by
94 | pad_h_l, pad_h_r = self._split_pad(new_h - h)
95 | pad_w_t, pad_w_b = self._split_pad(new_w - w)
96 | return F.pad(x, (pad_w_t, pad_w_b, pad_h_l, pad_h_r))
97 |
98 | def forward(self, x):
99 |
100 | # padding
101 | _, _, _, h, w = x.shape
102 | x = self._pad_to_div_by(x, div_by=8)
103 | _, _, _, new_h, new_w = x.shape
104 |
105 | # regularization
106 | conv0 = self.conv0(x)
107 | conv2 = self.conv2(self.conv1(conv0))
108 | conv4 = self.conv4(self.conv3(conv2))
109 | x = self.conv6(self.conv5(conv4))
110 | x = conv4 + self.conv7(x)
111 | x = conv2 + self.conv9(x)
112 | x = conv0 + self.conv11(x)
113 | x = self.prob(x)
114 |
115 | # unpadding
116 | slice_h = self._generate_slice(new_h - h)
117 | slice_w = self._generate_slice(new_w - w)
118 | x = x[..., slice_h, slice_w]
119 |
120 | return x
121 |
122 |
123 | class RefineNet(nn.Module):
124 | def __init__(self):
125 | super(RefineNet, self).__init__()
126 | self.conv1 = ConvBnReLU(4, 32)
127 | self.conv2 = ConvBnReLU(32, 32)
128 | self.conv3 = ConvBnReLU(32, 32)
129 | self.res = ConvBnReLU(32, 1)
130 |
131 | def forward(self, img, depth_init):
132 |         concat = torch.cat((img, depth_init), dim=1)
133 | depth_residual = self.res(self.conv3(self.conv2(self.conv1(concat))))
134 | depth_refined = depth_init + depth_residual
135 | return depth_refined
136 |
137 |
138 | class MVSNet(nn.Module):
139 | def __init__(self, refine=True):
140 | super(MVSNet, self).__init__()
141 | self.refine = refine
142 | # self.depth_interval = depth_interval
143 |
144 | self.feature = FeatureNet()
145 | self.cost_regularization = CostRegNet()
146 | if self.refine:
147 | self.refine_network = RefineNet()
148 |
149 | def forward(self, imgs, proj_matrices, depth_values, hints=None, validhints=None):
150 | imgs = torch.unbind(imgs, 1)
151 | proj_matrices = torch.unbind(proj_matrices, 1)
152 | assert len(imgs) == len(
153 | proj_matrices
154 | ), "Different number of images and projection matrices"
155 | num_depth = depth_values.shape[1]
156 | num_views = len(imgs)
157 |
158 | # step 1. feature extraction
159 | # in: images; out: 32-channel feature maps
160 | features = [self.feature(img) for img in imgs]
161 | ref_feature, src_features = features[0], features[1:]
162 | ref_proj, src_projs = proj_matrices[0], proj_matrices[1:]
163 |
164 | # step 2. differentiable homograph, build cost volume
165 | ref_volume = ref_feature.unsqueeze(2).repeat(1, 1, num_depth, 1, 1)
166 | volume_sum = ref_volume
167 | volume_sq_sum = ref_volume ** 2
168 | del ref_volume
169 | for src_fea, src_proj in zip(src_features, src_projs):
170 | # warpped features
171 | warped_volume = homo_warping(src_fea, src_proj, ref_proj, depth_values)
172 | if self.training:
173 | volume_sum = volume_sum + warped_volume
174 | volume_sq_sum = volume_sq_sum + warped_volume ** 2
175 | else:
176 |                 # TODO: this is only a temporary solution to save memory, better way?
177 | volume_sum += warped_volume
178 | volume_sq_sum += warped_volume.pow_(
179 | 2
180 | ) # the memory of warped_volume has been modified
181 | del warped_volume
182 | # aggregate multiple feature volumes by variance
183 | volume_variance = volume_sq_sum.div_(num_views).sub_(volume_sum.div_(num_views).pow_(2))
184 |
185 | # TODO cost-volume modulation
186 | if hints is not None and validhints is not None:
187 | batch_size, feats, height, width = ref_feature.shape
188 | GAUSSIAN_HEIGHT = 10.0
189 | GAUSSIAN_WIDTH = 1.0
190 |
191 |             # image features are one fourth of the original size: subsample the hints to the same resolution
192 | hints = hints
193 | hints = F.interpolate(hints, scale_factor=0.25, mode="nearest").unsqueeze(1)
194 | validhints = validhints
195 | validhints = F.interpolate(validhints, scale_factor=0.25, mode="nearest").unsqueeze(1)
196 | hints = hints * validhints
197 |
198 |             # add feature and depth dimensions to hints and validhints
199 |             # and repeat their values along those dimensions, to obtain the same size as the cost volume
200 | hints = hints.expand(-1, feats, num_depth, -1, -1)
201 | validhints = validhints.expand(-1, feats, num_depth, -1, -1)
202 |
203 |             # create a tensor of the same size as the cost volume, containing
204 |             # the depth hypotheses replicated along the depth dimension
205 | depth_hyps = (
206 | depth_values.unsqueeze(1)
207 | .unsqueeze(3)
208 | .unsqueeze(4)
209 | .expand(batch_size, feats, -1, height, width)
210 | .detach()
211 | )
212 | volume_variance = volume_variance * (
213 | (1 - validhints)
214 | + validhints
215 | * GAUSSIAN_HEIGHT
216 | * (1 - torch.exp(-((depth_hyps - hints) ** 2) / (2 * GAUSSIAN_WIDTH ** 2)))
217 | )
218 |
219 | # step 3. cost volume regularization
220 | cost_reg = self.cost_regularization(volume_variance)
221 | # cost_reg = F.upsample(cost_reg, [num_depth * 4, img_height, img_width], mode='trilinear')
222 | cost_reg = cost_reg.squeeze(1)
223 | prob_volume = F.softmax(-cost_reg, dim=1)
224 | depth = depth_regression(prob_volume, depth_values=depth_values).unsqueeze(1)
225 | depth = F.interpolate(depth, scale_factor=4, mode="bilinear")
226 |
227 | with torch.no_grad():
228 | # photometric confidence
229 | prob_volume_sum4 = (
230 | 4
231 | * F.avg_pool3d(
232 | F.pad(prob_volume.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)),
233 | (4, 1, 1),
234 | stride=1,
235 | padding=0,
236 | ).squeeze(1)
237 | )
238 | depth_index = depth_regression(
239 | prob_volume,
240 | depth_values=torch.arange(num_depth, device=prob_volume.device, dtype=torch.float),
241 | ).long()
242 | photometric_confidence = torch.gather(
243 | prob_volume_sum4, 1, depth_index.unsqueeze(1)
244 | ).squeeze(1)
245 | photometric_confidence = F.interpolate(
246 | photometric_confidence[None], scale_factor=4, mode="bilinear"
247 | )[0]
248 |
249 | depth_dict = {}
250 | depth_dict["stage_0"] = depth
251 |
252 | # step 4. depth map refinement
253 | if not self.refine:
254 | return {
255 | "depth": depth_dict,
256 | "photometric_confidence": photometric_confidence,
257 | "loss_data": depth_dict,
258 | }
259 |         else:
260 |             depth_dict["stage_0"] = self.refine_network(imgs[0], depth)
261 |             return {
262 |                 "depth": depth_dict,
263 |                 "photometric_confidence": photometric_confidence,
264 |                 "loss_data": depth_dict,
265 |             }
266 |
267 |
268 | def mvsnet_loss(depth_est, depth_gt, mask):
269 | mask = mask["stage_0"] > 0.5
270 | return F.smooth_l1_loss(
271 |         depth_est["stage_0"][mask], depth_gt["stage_0"][mask], reduction="mean"
272 | )
273 |
--------------------------------------------------------------------------------
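
A minimal, self-contained sketch (not a file of the repository) of the hint-based cost-volume modulation performed in MVSNet.forward above: where a sparse depth hint is available, the cost of hypotheses far from the hinted depth is amplified by the Gaussian term while the cost near the hinted depth is suppressed; pixels without hints keep their original costs. All shapes and values below are toy assumptions.

import torch

B, C, D, H, W = 1, 8, 16, 4, 4
GAUSSIAN_HEIGHT, GAUSSIAN_WIDTH = 10.0, 1.0

volume = torch.rand(B, C, D, H, W)                      # variance cost volume
hints = torch.zeros(B, 1, H, W)
hints[:, :, 2, 2] = 5.0                                 # a single sparse depth hint
validhints = (hints > 0).float()

depth_values = torch.linspace(1.0, 16.0, D)             # depth hypotheses [D]
depth_hyps = depth_values.view(1, 1, D, 1, 1).expand(B, C, D, H, W)
hints_vol = hints.unsqueeze(2).expand(B, C, D, H, W)
valid_vol = validhints.unsqueeze(2).expand(B, C, D, H, W)

# same expression as in MVSNet.forward: cost left untouched where no hint exists,
# suppressed near the hinted depth and amplified elsewhere
modulated = volume * (
    (1 - valid_vol)
    + valid_vol
    * GAUSSIAN_HEIGHT
    * (1 - torch.exp(-((depth_hyps - hints_vol) ** 2) / (2 * GAUSSIAN_WIDTH ** 2)))
)
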
/guided_mvs_lib/models/patchmatchnet/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from types import SimpleNamespace
3 | from typing import Any, Optional
4 |
5 | import torch
6 | import torch.nn as nn
7 | import torch.nn.functional as F
8 | from torch import Tensor
9 |
10 | from .net import PatchmatchNet, patchmatchnet_loss
11 |
12 | __all__ = ["PatchmatchNet", "patchmatchnet_loss", "NetBuilder", "SimpleInterfaceNet"]
13 |
14 |
15 | class SimpleInterfaceNet(nn.Module):
16 | """
17 | Simple common interface to call the pretrained models
18 | """
19 |
20 | def __init__(self, **kwargs):
21 | super().__init__()
22 |
23 | default_args = dict(
24 | patchmatch_interval_scale=[0.005, 0.0125, 0.025],
25 | propagation_range=[6, 4, 2],
26 | patchmatch_iteration=[1, 2, 2],
27 | patchmatch_num_sample=[8, 8, 16],
28 | propagate_neighbors=[0, 8, 16],
29 | evaluate_neighbors=[9, 9, 9],
30 | )
31 | default_args.update(kwargs)
32 | self.model = PatchmatchNet(**default_args)
33 | self.all_outputs: dict[str, Any] = {}
34 |
35 | def forward(
36 | self,
37 | imgs: Tensor,
38 | intrinsics: Tensor,
39 | extrinsics: Tensor,
40 | depth_values: Tensor,
41 | hints: Optional[Tensor] = None,
42 | ):
43 | with warnings.catch_warnings():
44 | warnings.simplefilter("ignore", UserWarning)
45 |
46 | # compute poses
47 | proj_matrices = {}
48 | for i in range(4):
49 | proj_mat = extrinsics.clone()
50 | intrinsics_copy = intrinsics.clone()
51 | intrinsics_copy[..., :2, :] = intrinsics_copy[..., :2, :] / (2 ** i)
52 | proj_mat[..., :3, :4] = intrinsics_copy @ proj_mat[..., :3, :4]
53 | proj_matrices[f"stage_{i}"] = proj_mat
54 |
55 | # downsample images
56 | imgs_stages = {}
57 | h, w = imgs.shape[-2:]
58 | for i in range(4):
59 | dsize = h // (2 ** i), w // (2 ** i)
60 | imgs_stages[f"stage_{i}"] = torch.stack(
61 | [
62 | F.interpolate(img, dsize, mode="bilinear", align_corners=True)
63 | for img in torch.unbind(imgs, 1)
64 | ],
65 | 1,
66 | )
67 |
68 | # validhints
69 | validhints = None
70 | if hints is not None:
71 | validhints = (hints > 0).to(torch.float32)
72 |
73 | # call
74 | out = self.model(
75 | imgs_stages,
76 | proj_matrices,
77 | depth_values.min(1).values,
78 | depth_values.max(1).values,
79 | hints,
80 | validhints,
81 | )
82 | self.all_outputs = out
83 | return out["depth"]["stage_0"]
84 |
85 |
86 | class NetBuilder(nn.Module):
87 | def __init__(self, args: SimpleNamespace):
88 | super().__init__()
89 |
90 | self.model = PatchmatchNet(
91 | patchmatch_interval_scale=[0.005, 0.0125, 0.025],
92 | propagation_range=[6, 4, 2],
93 | patchmatch_iteration=[1, 2, 2],
94 | patchmatch_num_sample=[8, 8, 16],
95 | propagate_neighbors=[0, 8, 16],
96 | evaluate_neighbors=[9, 9, 9],
97 | )
98 |
99 | def loss_func(loss_data, depth_gt, mask):
100 | return patchmatchnet_loss(
101 | loss_data["depth_patchmatch"],
102 | loss_data["refined_depth"],
103 | depth_gt,
104 | mask,
105 | )
106 |
107 | self.loss = loss_func
108 |
109 | self.hparams = {
110 | "interval_scale": [0.005, 0.0125, 0.025],
111 | "propagation_range": [6, 4, 2],
112 | "iteration": [1, 2, 2],
113 | "num_sample": [8, 8, 16],
114 | "propagate_neighbors": [0, 8, 16],
115 | "evaluate_neighbors": [9, 9, 9],
116 | }
117 |
118 | def forward(self, batch: dict):
119 |
120 | hints, validhints = None, None
121 | if "hints" in batch:
122 | hints = batch["hints"]
123 | validhints = (hints > 0).to(torch.float32)
124 |
125 | return self.model(
126 | batch["imgs"],
127 | batch["proj_matrices"],
128 | batch["depth_min"],
129 | batch["depth_max"],
130 | hints,
131 | validhints,
132 | )
133 |
--------------------------------------------------------------------------------
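
A small illustration (toy camera values, not repository code) of the per-stage projection matrices assembled in SimpleInterfaceNet.forward above: the first two rows of the intrinsics are divided by 2^i at every coarser stage before being composed with the extrinsics.

import torch

B, V = 1, 3                                              # batch size, number of views
intrinsics = torch.eye(3).expand(B, V, 3, 3).clone()
intrinsics[..., 0, 0] = 800.0                            # fx (toy value)
intrinsics[..., 1, 1] = 800.0                            # fy (toy value)
extrinsics = torch.eye(4).expand(B, V, 4, 4).clone()

proj_matrices = {}
for i in range(4):
    proj_mat = extrinsics.clone()
    intr = intrinsics.clone()
    intr[..., :2, :] = intr[..., :2, :] / (2 ** i)       # scale focal/principal point per stage
    proj_mat[..., :3, :4] = intr @ proj_mat[..., :3, :4]
    proj_matrices[f"stage_{i}"] = proj_mat               # stage_0 full res ... stage_3 at 1/8 res
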
/guided_mvs_lib/models/patchmatchnet/module.py:
--------------------------------------------------------------------------------
1 | """
2 | Implementation of Pytorch layer primitives, such as Conv+BN+ReLU, differentiable warping layers,
3 | and depth regression based upon expectation of an input probability distribution.
4 | """
5 |
6 | import torch
7 | import torch.nn as nn
8 | import torch.nn.functional as F
9 |
10 |
11 | class ConvBnReLU(nn.Module):
12 | """Implements 2d Convolution + batch normalization + ReLU"""
13 |
14 | def __init__(
15 | self,
16 | in_channels: int,
17 | out_channels: int,
18 | kernel_size: int = 3,
19 | stride: int = 1,
20 | pad: int = 1,
21 | dilation: int = 1,
22 | ) -> None:
23 | """initialization method for convolution2D + batch normalization + relu module
24 | Args:
25 | in_channels: input channel number of convolution layer
26 | out_channels: output channel number of convolution layer
27 | kernel_size: kernel size of convolution layer
28 | stride: stride of convolution layer
29 | pad: pad of convolution layer
30 | dilation: dilation of convolution layer
31 | """
32 | super(ConvBnReLU, self).__init__()
33 | self.conv = nn.Conv2d(
34 | in_channels,
35 | out_channels,
36 | kernel_size,
37 | stride=stride,
38 | padding=pad,
39 | dilation=dilation,
40 | bias=False,
41 | )
42 | self.bn = nn.BatchNorm2d(out_channels)
43 |
44 | def forward(self, x: torch.Tensor) -> torch.Tensor:
45 | """forward method"""
46 | return F.relu(self.bn(self.conv(x)), inplace=True)
47 |
48 |
49 | class ConvBnReLU3D(nn.Module):
50 |     """Implements 3d convolution + batch normalization + ReLU."""
51 |
52 | def __init__(
53 | self,
54 | in_channels: int,
55 | out_channels: int,
56 | kernel_size: int = 3,
57 | stride: int = 1,
58 | pad: int = 1,
59 | dilation: int = 1,
60 | ) -> None:
61 | """initialization method for convolution3D + batch normalization + relu module
62 | Args:
63 | in_channels: input channel number of convolution layer
64 | out_channels: output channel number of convolution layer
65 | kernel_size: kernel size of convolution layer
66 | stride: stride of convolution layer
67 | pad: pad of convolution layer
68 | dilation: dilation of convolution layer
69 | """
70 | super(ConvBnReLU3D, self).__init__()
71 | self.conv = nn.Conv3d(
72 | in_channels,
73 | out_channels,
74 | kernel_size,
75 | stride=stride,
76 | padding=pad,
77 | dilation=dilation,
78 | bias=False,
79 | )
80 | self.bn = nn.BatchNorm3d(out_channels)
81 |
82 | def forward(self, x: torch.Tensor) -> torch.Tensor:
83 | """forward method"""
84 | return F.relu(self.bn(self.conv(x)), inplace=True)
85 |
86 |
87 | class ConvBnReLU1D(nn.Module):
88 | """Implements 1d Convolution + batch normalization + ReLU."""
89 |
90 | def __init__(
91 | self,
92 | in_channels: int,
93 | out_channels: int,
94 | kernel_size: int = 3,
95 | stride: int = 1,
96 | pad: int = 1,
97 | dilation: int = 1,
98 | ) -> None:
99 | """initialization method for convolution1D + batch normalization + relu module
100 | Args:
101 | in_channels: input channel number of convolution layer
102 | out_channels: output channel number of convolution layer
103 | kernel_size: kernel size of convolution layer
104 | stride: stride of convolution layer
105 | pad: pad of convolution layer
106 | dilation: dilation of convolution layer
107 | """
108 | super(ConvBnReLU1D, self).__init__()
109 | self.conv = nn.Conv1d(
110 | in_channels,
111 | out_channels,
112 | kernel_size,
113 | stride=stride,
114 | padding=pad,
115 | dilation=dilation,
116 | bias=False,
117 | )
118 | self.bn = nn.BatchNorm1d(out_channels)
119 |
120 | def forward(self, x: torch.Tensor) -> torch.Tensor:
121 | """forward method"""
122 | return F.relu(self.bn(self.conv(x)), inplace=True)
123 |
124 |
125 | class ConvBn(nn.Module):
126 |     """Implements 2d convolution + batch normalization."""
127 |
128 | def __init__(
129 | self,
130 | in_channels: int,
131 | out_channels: int,
132 | kernel_size: int = 3,
133 | stride: int = 1,
134 | pad: int = 1,
135 | ) -> None:
136 |         """initialization method for convolution2D + batch normalization module
137 | Args:
138 | in_channels: input channel number of convolution layer
139 | out_channels: output channel number of convolution layer
140 | kernel_size: kernel size of convolution layer
141 | stride: stride of convolution layer
142 | pad: pad of convolution layer
143 | """
144 | super(ConvBn, self).__init__()
145 | self.conv = nn.Conv2d(
146 | in_channels, out_channels, kernel_size, stride=stride, padding=pad, bias=False
147 | )
148 | self.bn = nn.BatchNorm2d(out_channels)
149 |
150 | def forward(self, x: torch.Tensor) -> torch.Tensor:
151 | """forward method"""
152 | return self.bn(self.conv(x))
153 |
154 |
155 | def differentiable_warping(
156 | src_fea: torch.Tensor,
157 | src_proj: torch.Tensor,
158 | ref_proj: torch.Tensor,
159 | depth_samples: torch.Tensor,
160 | ):
161 | """Differentiable homography-based warping, implemented in Pytorch.
162 |
163 | Args:
164 | src_fea: [B, C, H, W] source features, for each source view in batch
165 | src_proj: [B, 4, 4] source camera projection matrix, for each source view in batch
166 | ref_proj: [B, 4, 4] reference camera projection matrix, for each ref view in batch
167 | depth_samples: [B, Ndepth, H, W] virtual depth layers
168 | Returns:
169 | warped_src_fea: [B, C, Ndepth, H, W] features on depths after perspective transformation
170 | """
171 |
172 | batch, channels, height, width = src_fea.shape
173 | num_depth = depth_samples.shape[1]
174 |
175 | with torch.no_grad():
176 | proj = torch.matmul(src_proj, torch.inverse(ref_proj))
177 | rot = proj[:, :3, :3] # [B,3,3]
178 | trans = proj[:, :3, 3:4] # [B,3,1]
179 |
180 | y, x = torch.meshgrid(
181 | [
182 | torch.arange(0, height, dtype=torch.float32, device=src_fea.device),
183 | torch.arange(0, width, dtype=torch.float32, device=src_fea.device),
184 | ]
185 | )
186 | y, x = y.contiguous(), x.contiguous()
187 | y, x = y.view(height * width), x.view(height * width)
188 | xyz = torch.stack((x, y, torch.ones_like(x))) # [3, H*W]
189 | xyz = torch.unsqueeze(xyz, 0).repeat(batch, 1, 1) # [B, 3, H*W]
190 | rot_xyz = torch.matmul(rot, xyz) # [B, 3, H*W]
191 |
192 | rot_depth_xyz = rot_xyz.unsqueeze(2).repeat(1, 1, num_depth, 1) * depth_samples.view(
193 | batch, 1, num_depth, height * width
194 | ) # [B, 3, Ndepth, H*W]
195 | proj_xyz = rot_depth_xyz + trans.view(batch, 3, 1, 1) # [B, 3, Ndepth, H*W]
196 | # avoid negative depth
197 | negative_depth_mask = proj_xyz[:, 2:] <= 1e-3
198 | proj_xyz[:, 0:1][negative_depth_mask] = width
199 | proj_xyz[:, 1:2][negative_depth_mask] = height
200 | proj_xyz[:, 2:3][negative_depth_mask] = 1
201 | proj_xy = proj_xyz[:, :2, :, :] / proj_xyz[:, 2:3, :, :] # [B, 2, Ndepth, H*W]
202 | proj_x_normalized = proj_xy[:, 0, :, :] / ((width - 1) / 2) - 1 # [B, Ndepth, H*W]
203 | proj_y_normalized = proj_xy[:, 1, :, :] / ((height - 1) / 2) - 1
204 | proj_xy = torch.stack((proj_x_normalized, proj_y_normalized), dim=3) # [B, Ndepth, H*W, 2]
205 | grid = proj_xy
206 |
207 | warped_src_fea = F.grid_sample(
208 | src_fea,
209 | grid.view(batch, num_depth * height, width, 2),
210 | mode="bilinear",
211 | padding_mode="zeros",
212 | align_corners=True,
213 | )
214 |
215 | warped_src_fea = warped_src_fea.view(batch, channels, num_depth, height, width)
216 |
217 | return warped_src_fea
218 |
219 |
220 | def depth_regression(p: torch.Tensor, depth_values: torch.Tensor) -> torch.Tensor:
221 | """Implements per-pixel depth regression based upon a probability distribution per-pixel.
222 |
223 | The regressed depth value D(p) at pixel p is found as the expectation w.r.t. P of the hypotheses.
224 |
225 | Args:
226 | p: probability volume [B, D, H, W]
227 | depth_values: discrete depth values [B, D]
228 | Returns:
229 | result depth: expected value, soft argmin [B, 1, H, W]
230 | """
231 |
232 | depth_values = depth_values.view(*depth_values.shape, 1, 1)
233 | depth = torch.sum(p * depth_values, dim=1)
234 | depth = depth.unsqueeze(1)
235 | return depth
236 |
237 |
238 | def depth_regression_1(p: torch.Tensor, depth_values: torch.Tensor) -> torch.Tensor:
239 |     """Variant of depth_regression where the depth values already broadcast against p
240 |     Args:
241 |         p: probability volume [B, D, H, W]
242 |         depth_values: depth hypotheses broadcastable against p, e.g. [B, D, H, W]
243 | Returns:
244 | result depth: expected value, soft argmin [B, 1, H, W]
245 | """
246 |
247 | depth = torch.sum(p * depth_values, 1)
248 | depth = depth.unsqueeze(1)
249 | return depth
250 |
--------------------------------------------------------------------------------
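
A toy usage of depth_regression from the module above (values chosen only for illustration): the regressed depth is the expectation of the discrete hypotheses under the per-pixel probability volume, so it always lies inside the hypothesis range.

import torch
import torch.nn.functional as F

from guided_mvs_lib.models.patchmatchnet.module import depth_regression

B, D, H, W = 1, 4, 2, 2
depth_values = torch.tensor([[1.0, 2.0, 3.0, 4.0]])      # [B, D]
p = F.softmax(torch.randn(B, D, H, W), dim=1)            # probability volume

depth = depth_regression(p, depth_values)                # [B, 1, H, W]
assert depth.shape == (B, 1, H, W)
assert depth.min() >= 1.0 and depth.max() <= 4.0
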
/guided_mvs_lib/models/patchmatchnet/net.py:
--------------------------------------------------------------------------------
1 | from typing import Any, Dict, List
2 |
3 | import torch
4 | import torch.nn as nn
5 | import torch.nn.functional as F
6 |
7 | from .module import ConvBnReLU, depth_regression
8 | from .patchmatch import PatchMatch
9 |
10 |
11 | class FeatureNet(nn.Module):
12 | """Feature Extraction Network: to extract features of original images from each view"""
13 |
14 | def __init__(self):
15 | """Initialize different layers in the network"""
16 |
17 | super(FeatureNet, self).__init__()
18 |
19 | self.conv0 = ConvBnReLU(3, 8, 3, 1, 1)
20 | # [B,8,H,W]
21 | self.conv1 = ConvBnReLU(8, 8, 3, 1, 1)
22 | # [B,16,H/2,W/2]
23 | self.conv2 = ConvBnReLU(8, 16, 5, 2, 2)
24 | self.conv3 = ConvBnReLU(16, 16, 3, 1, 1)
25 | self.conv4 = ConvBnReLU(16, 16, 3, 1, 1)
26 | # [B,32,H/4,W/4]
27 | self.conv5 = ConvBnReLU(16, 32, 5, 2, 2)
28 | self.conv6 = ConvBnReLU(32, 32, 3, 1, 1)
29 | self.conv7 = ConvBnReLU(32, 32, 3, 1, 1)
30 | # [B,64,H/8,W/8]
31 | self.conv8 = ConvBnReLU(32, 64, 5, 2, 2)
32 | self.conv9 = ConvBnReLU(64, 64, 3, 1, 1)
33 | self.conv10 = ConvBnReLU(64, 64, 3, 1, 1)
34 |
35 | self.output1 = nn.Conv2d(64, 64, 1, bias=False)
36 | self.inner1 = nn.Conv2d(32, 64, 1, bias=True)
37 | self.inner2 = nn.Conv2d(16, 64, 1, bias=True)
38 | self.output2 = nn.Conv2d(64, 32, 1, bias=False)
39 | self.output3 = nn.Conv2d(64, 16, 1, bias=False)
40 |
41 | def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
42 | """Forward method
43 |
44 | Args:
45 | x: images from a single view, in the shape of [B, C, H, W]. Generally, C=3
46 |
47 | Returns:
48 |             output_feature: a python dictionary containing the extracted features from stage_1 to stage_3,
49 | keys are "stage_1", "stage_2", and "stage_3"
50 | """
51 | output_feature = {}
52 |
53 | conv1 = self.conv1(self.conv0(x))
54 | conv4 = self.conv4(self.conv3(self.conv2(conv1)))
55 |
56 | conv7 = self.conv7(self.conv6(self.conv5(conv4)))
57 | conv10 = self.conv10(self.conv9(self.conv8(conv7)))
58 |
59 | output_feature["stage_3"] = self.output1(conv10)
60 |
61 | intra_feat = F.interpolate(conv10, scale_factor=2, mode="bilinear") + self.inner1(conv7)
62 | del conv7, conv10
63 | output_feature["stage_2"] = self.output2(intra_feat)
64 |
65 | intra_feat = F.interpolate(intra_feat, scale_factor=2, mode="bilinear") + self.inner2(
66 | conv4
67 | )
68 | del conv4
69 | output_feature["stage_1"] = self.output3(intra_feat)
70 |
71 | del intra_feat
72 | return output_feature
73 |
74 |
75 | class Refinement(nn.Module):
76 | """Depth map refinement network"""
77 |
78 | def __init__(self):
79 | """Initialize"""
80 |
81 | super(Refinement, self).__init__()
82 |
83 | # img: [B,3,H,W]
84 | self.conv0 = ConvBnReLU(in_channels=3, out_channels=8)
85 | # depth map:[B,1,H/2,W/2]
86 | self.conv1 = ConvBnReLU(in_channels=1, out_channels=8)
87 | self.conv2 = ConvBnReLU(in_channels=8, out_channels=8)
88 | self.deconv = nn.ConvTranspose2d(
89 | in_channels=8,
90 | out_channels=8,
91 | kernel_size=3,
92 | padding=1,
93 | output_padding=1,
94 | stride=2,
95 | bias=False,
96 | )
97 |
98 | self.bn = nn.BatchNorm2d(8)
99 | self.conv3 = ConvBnReLU(in_channels=16, out_channels=8)
100 | self.res = nn.Conv2d(in_channels=8, out_channels=1, kernel_size=3, padding=1, bias=False)
101 |
102 | def forward(
103 | self,
104 | img: torch.Tensor,
105 | depth_0: torch.Tensor,
106 | depth_min: torch.Tensor,
107 | depth_max: torch.Tensor,
108 | ) -> torch.Tensor:
109 | """Forward method
110 |
111 | Args:
112 | img: input reference images (B, 3, H, W)
113 | depth_0: current depth map (B, 1, H//2, W//2)
114 | depth_min: pre-defined minimum depth (B, )
115 | depth_max: pre-defined maximum depth (B, )
116 |
117 | Returns:
118 | depth: refined depth map (B, 1, H, W)
119 | """
120 |
121 | batch_size = depth_min.size()[0]
122 | # pre-scale the depth map into [0,1]
123 | depth = (depth_0 - depth_min.view(batch_size, 1, 1, 1)) / (
124 | depth_max.view(batch_size, 1, 1, 1) - depth_min.view(batch_size, 1, 1, 1)
125 | )
126 |
127 | conv0 = self.conv0(img)
128 | deconv = F.relu(self.bn(self.deconv(self.conv2(self.conv1(depth)))), inplace=True)
129 | cat = torch.cat((deconv, conv0), dim=1)
130 | del deconv, conv0
131 | # depth residual
132 | res = self.res(self.conv3(cat))
133 | del cat
134 |
135 | depth = F.interpolate(depth, scale_factor=2, mode="nearest") + res
136 | # convert the normalized depth back
137 | depth = depth * (
138 | depth_max.view(batch_size, 1, 1, 1) - depth_min.view(batch_size, 1, 1, 1)
139 | ) + depth_min.view(batch_size, 1, 1, 1)
140 |
141 | return depth
142 |
143 |
144 | class PatchmatchNet(nn.Module):
145 | """Implementation of complete structure of PatchmatchNet"""
146 |
147 | def __init__(
148 | self,
149 | patchmatch_interval_scale: List[float] = [0.005, 0.0125, 0.025],
150 | propagation_range: List[int] = [6, 4, 2],
151 | patchmatch_iteration: List[int] = [1, 2, 2],
152 | patchmatch_num_sample: List[int] = [8, 8, 16],
153 | propagate_neighbors: List[int] = [0, 8, 16],
154 | evaluate_neighbors: List[int] = [9, 9, 9],
155 | ) -> None:
156 | """Initialize modules in PatchmatchNet
157 |
158 | Args:
159 | patchmatch_interval_scale: depth interval scale in patchmatch module
160 | propagation_range: propagation range
161 |             patchmatch_iteration: patchmatch iteration number
162 |             patchmatch_num_sample: patchmatch number of samples
163 |             propagate_neighbors: number of propagation neighbors
164 |             evaluate_neighbors: number of neighbors used during evaluation
165 | """
166 | super(PatchmatchNet, self).__init__()
167 |
168 | self.stages = 4
169 | self.feature = FeatureNet()
170 | self.patchmatch_num_sample = patchmatch_num_sample
171 |
172 | num_features = [8, 16, 32, 64]
173 |
174 | self.propagate_neighbors = propagate_neighbors
175 | self.evaluate_neighbors = evaluate_neighbors
176 | # number of groups for group-wise correlation
177 | self.G = [4, 8, 8]
178 |
179 | for i in range(self.stages - 1):
180 |
181 | if i == 2:
182 | patchmatch = PatchMatch(
183 | random_initialization=True,
184 | propagation_out_range=propagation_range[i],
185 | patchmatch_iteration=patchmatch_iteration[i],
186 | patchmatch_num_sample=patchmatch_num_sample[i],
187 | patchmatch_interval_scale=patchmatch_interval_scale[i],
188 | num_feature=num_features[i + 1],
189 | G=self.G[i],
190 | propagate_neighbors=self.propagate_neighbors[i],
191 | stage=i + 1,
192 | evaluate_neighbors=evaluate_neighbors[i],
193 | )
194 | else:
195 | patchmatch = PatchMatch(
196 | random_initialization=False,
197 | propagation_out_range=propagation_range[i],
198 | patchmatch_iteration=patchmatch_iteration[i],
199 | patchmatch_num_sample=patchmatch_num_sample[i],
200 | patchmatch_interval_scale=patchmatch_interval_scale[i],
201 | num_feature=num_features[i + 1],
202 | G=self.G[i],
203 | propagate_neighbors=self.propagate_neighbors[i],
204 | stage=i + 1,
205 | evaluate_neighbors=evaluate_neighbors[i],
206 | )
207 | setattr(self, f"patchmatch_{i+1}", patchmatch)
208 |
209 | self.upsample_net = Refinement()
210 |
211 | def forward(
212 | self,
213 | imgs: Dict[str, torch.Tensor],
214 | proj_matrices: Dict[str, torch.Tensor],
215 | depth_min: torch.Tensor,
216 | depth_max: torch.Tensor,
217 | hints: torch.Tensor = None,
218 | validhints: torch.Tensor = None,
219 | ) -> Dict[str, Any]:
220 | """Forward method for PatchMatchNet
221 |
222 | Args:
223 | imgs: different stages of images (B, 3, H, W) stored in the dictionary
224 |             proj_matrices: different stages of camera projection matrices (B, 4, 4) stored in the dictionary
225 | depth_min: minimum virtual depth (B, )
226 | depth_max: maximum virtual depth (B, )
227 |
228 | Returns:
229 | output dictionary of PatchMatchNet, containing refined depthmap, depth patchmatch
230 | and photometric_confidence.
231 | """
232 | imgs_0 = torch.unbind(imgs["stage_0"], 1)
233 | imgs_1 = torch.unbind(imgs["stage_1"], 1)
234 | imgs_2 = torch.unbind(imgs["stage_2"], 1)
235 | imgs_3 = torch.unbind(imgs["stage_3"], 1)
236 | del imgs
237 |
238 | self.imgs_0_ref = imgs_0[0]
239 | self.imgs_1_ref = imgs_1[0]
240 | self.imgs_2_ref = imgs_2[0]
241 | self.imgs_3_ref = imgs_3[0]
242 | del imgs_1, imgs_2, imgs_3
243 |
244 | self.proj_matrices_0 = torch.unbind(proj_matrices["stage_0"].float(), 1)
245 | self.proj_matrices_1 = torch.unbind(proj_matrices["stage_1"].float(), 1)
246 | self.proj_matrices_2 = torch.unbind(proj_matrices["stage_2"].float(), 1)
247 | self.proj_matrices_3 = torch.unbind(proj_matrices["stage_3"].float(), 1)
248 | del proj_matrices
249 |
250 | assert len(imgs_0) == len(
251 | self.proj_matrices_0
252 | ), "Different number of images and projection matrices"
253 |
254 | # step 1. Multi-scale feature extraction
255 | features = []
256 | for img in imgs_0:
257 | output_feature = self.feature(img)
258 | features.append(output_feature)
259 | del imgs_0
260 | ref_feature, src_features = features[0], features[1:]
261 |
262 | depth_min = depth_min.float()
263 | depth_max = depth_max.float()
264 |
265 | # step 2. Learning-based patchmatch
266 | depth = None
267 | view_weights = None
268 | depth_patchmatch = {}
269 | refined_depth = {}
270 |
271 | for l in reversed(range(1, self.stages)):
272 | src_features_l = [src_fea[f"stage_{l}"] for src_fea in src_features]
273 | projs_l = getattr(self, f"proj_matrices_{l}")
274 | ref_proj, src_projs = projs_l[0], projs_l[1:]
275 |
276 | if l > 1:
277 | depth, _, view_weights = getattr(self, f"patchmatch_{l}")(
278 | ref_feature=ref_feature[f"stage_{l}"],
279 | src_features=src_features_l,
280 | ref_proj=ref_proj,
281 | src_projs=src_projs,
282 | depth_min=depth_min,
283 | depth_max=depth_max,
284 | depth=depth,
285 | img=getattr(self, f"imgs_{l}_ref"),
286 | view_weights=view_weights,
287 | hints=hints,
288 | validhints=validhints,
289 | )
290 | else:
291 | depth, score, _ = getattr(self, f"patchmatch_{l}")(
292 | ref_feature=ref_feature[f"stage_{l}"],
293 | src_features=src_features_l,
294 | ref_proj=ref_proj,
295 | src_projs=src_projs,
296 | depth_min=depth_min,
297 | depth_max=depth_max,
298 | depth=depth,
299 | img=getattr(self, f"imgs_{l}_ref"),
300 | view_weights=view_weights,
301 | hints=hints,
302 | validhints=validhints,
303 | )
304 |
305 | del src_features_l, ref_proj, src_projs, projs_l
306 |
307 | depth_patchmatch[f"stage_{l}"] = depth
308 |
309 | depth = depth[-1].detach()
310 | if l > 1:
311 | # upsampling the depth map and pixel-wise view weight for next stage
312 | depth = F.interpolate(depth, scale_factor=2, mode="nearest")
313 | view_weights = F.interpolate(view_weights, scale_factor=2, mode="nearest")
314 |
315 | # step 3. Refinement
316 | depth = self.upsample_net(self.imgs_0_ref, depth, depth_min, depth_max)
317 | refined_depth["stage_0"] = depth
318 |
319 | del depth, ref_feature, src_features
320 |
321 | num_depth = self.patchmatch_num_sample[0]
322 | score_sum4 = 4 * F.avg_pool3d(
323 | F.pad(score.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)), (4, 1, 1), stride=1, padding=0
324 | ).squeeze(1)
325 | # [B, 1, H, W]
326 | depth_index = depth_regression(
327 | score, depth_values=torch.arange(num_depth, device=score.device, dtype=torch.float)
328 | ).long()
329 | depth_index = torch.clamp(depth_index, 0, num_depth - 1)
330 | photometric_confidence = torch.gather(score_sum4, 1, depth_index)
331 | photometric_confidence = F.interpolate(
332 | photometric_confidence, scale_factor=2, mode="nearest"
333 | )
334 | photometric_confidence = photometric_confidence.squeeze(1)
335 |
336 | return {
337 | "depth": refined_depth,
338 | "loss_data": {
339 | "depth_patchmatch": depth_patchmatch,
340 | "refined_depth": refined_depth,
341 | },
342 | "photometric_confidence": photometric_confidence,
343 | }
344 |
345 |
346 | def patchmatchnet_loss(
347 | depth_patchmatch: Dict[str, torch.Tensor],
348 | refined_depth: Dict[str, torch.Tensor],
349 | depth_gt: Dict[str, torch.Tensor],
350 | mask: Dict[str, torch.Tensor],
351 | ) -> torch.Tensor:
352 | """Patchmatch Net loss function
353 |
354 | Args:
355 | depth_patchmatch: depth map predicted by patchmatch net
356 | refined_depth: refined depth map predicted by patchmatch net
357 | depth_gt: ground truth depth map
358 |         mask: mask used to filter valid points
359 |
360 | Returns:
361 | loss: result loss value
362 | """
363 | stage = 4
364 |
365 | loss = 0
366 | for l in range(1, stage):
367 | depth_gt_l = depth_gt[f"stage_{l}"]
368 | mask_l = mask[f"stage_{l}"] > 0.5
369 | depth2 = depth_gt_l[mask_l]
370 |
371 | depth_patchmatch_l = depth_patchmatch[f"stage_{l}"]
372 | for i in range(len(depth_patchmatch_l)):
373 | depth1 = depth_patchmatch_l[i][mask_l]
374 | loss = loss + F.smooth_l1_loss(depth1, depth2, reduction="mean")
375 |
376 | l = 0
377 | depth_refined_l = refined_depth[f"stage_{l}"]
378 | depth_gt_l = depth_gt[f"stage_{l}"]
379 | mask_l = mask[f"stage_{l}"] > 0.5
380 |
381 | depth1 = depth_refined_l[mask_l]
382 | depth2 = depth_gt_l[mask_l]
383 | loss = loss + F.smooth_l1_loss(depth1, depth2, reduction="mean")
384 |
385 | return loss
386 |
--------------------------------------------------------------------------------
/guided_mvs_lib/models/ucsnet/__init__.py:
--------------------------------------------------------------------------------
1 | import warnings
2 | from types import SimpleNamespace
3 | from typing import Any, Optional
4 |
5 | import torch
6 | import torch.nn as nn
7 | from torch import Tensor
8 |
9 | from .ucsnet import UCSNet, ucsnet_loss
10 |
11 | __all__ = ["ucsnet_loss", "UCSNet", "NetBuilder", "SimpleInterfaceNet"]
12 |
13 |
14 | class SimpleInterfaceNet(nn.Module):
15 | """
16 | Simple common interface to call the pretrained models
17 | """
18 |
19 | def __init__(self, **kwargs):
20 | super().__init__()
21 |         self.model = UCSNet(**kwargs)
22 | self.all_outputs: dict[str, Any] = {}
23 |
24 | def forward(
25 | self,
26 | imgs: Tensor,
27 | intrinsics: Tensor,
28 | extrinsics: Tensor,
29 | depth_values: Tensor,
30 | hints: Optional[Tensor] = None,
31 | ):
32 | with warnings.catch_warnings():
33 | warnings.simplefilter("ignore", UserWarning)
34 |
35 | # compute poses
36 | proj_matrices = {}
37 | for i in range(4):
38 | proj_mat = extrinsics.clone()
39 | intrinsics_copy = intrinsics.clone()
40 | intrinsics_copy[..., :2, :] = intrinsics_copy[..., :2, :] / (2 ** i)
41 | proj_mat[..., :3, :4] = intrinsics_copy @ proj_mat[..., :3, :4]
42 | proj_matrices[f"stage_{i}"] = proj_mat
43 |
44 | # validhints
45 | validhints = None
46 | if hints is not None:
47 | validhints = (hints > 0).to(torch.float32)
48 |
49 | # call
50 | out = self.model(imgs, proj_matrices, depth_values, hints, validhints)
51 | self.all_outputs = out
52 | return out["depth"]["stage_0"]
53 |
54 |
55 | class NetBuilder(nn.Module):
56 | def __init__(self, args: SimpleNamespace):
57 | super().__init__()
58 | self.model = UCSNet()
59 |
60 | def loss_func(loss_data, depth_gt, mask):
61 | return ucsnet_loss(
62 | loss_data["depth"],
63 | depth_gt,
64 | mask,
65 | )
66 |
67 | self.loss = loss_func
68 |
69 | def forward(self, batch: dict):
70 |
71 | hints, validhints = None, None
72 | if "hints" in batch:
73 | hints = batch["hints"]
74 | validhints = (hints > 0).to(torch.float32)
75 |
76 | return self.model(
77 | batch["imgs"]["stage_0"],
78 | batch["proj_matrices"],
79 | batch["depth_values"],
80 | hints,
81 | validhints,
82 | )
83 |
--------------------------------------------------------------------------------
/guided_mvs_lib/models/ucsnet/ucsnet.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.functional as F
4 |
5 | from .submodules import *
6 |
7 |
8 | def compute_depth(
9 | feats,
10 | proj_mats,
11 | depth_samps,
12 | cost_reg,
13 | lamb,
14 | is_training=False,
15 | hints=None,
16 | validhints=None,
17 | scale=1,
18 | ):
19 |     """
20 |     :param feats: [(B, C, H, W), ] * num_views
21 |     :param proj_mats: (B, num_views, 4, 4) projection matrices, unbound along the view dimension
22 |     :param depth_samps: (B, Ndepth, H, W) per-pixel depth hypotheses
23 |     :param cost_reg: cost regularization network applied to the variance volume
24 |     :param lamb: scale factor of the uncertainty interval
25 |     :return: dict with the regressed depth, photometric confidence and per-pixel variance
26 |     """
27 |
28 | # print(proj_mats.shape)
29 | proj_mats = torch.unbind(proj_mats, 1)
30 | num_views = len(feats)
31 | num_depth = depth_samps.shape[1]
32 |
33 | assert len(proj_mats) == num_views, "Different number of images and projection matrices"
34 |
35 | ref_feat, src_feats = feats[0], feats[1:]
36 | ref_proj, src_projs = proj_mats[0], proj_mats[1:]
37 |
38 | ref_volume = ref_feat.unsqueeze(2).repeat(1, 1, num_depth, 1, 1)
39 | volume_sum = ref_volume
40 | volume_sq_sum = ref_volume ** 2
41 | del ref_volume
42 |
43 | # todo optimize impl
44 | for src_fea, src_proj in zip(src_feats, src_projs):
45 | warped_volume = homo_warping(src_fea, src_proj, ref_proj, depth_samps)
46 |
47 | if is_training:
48 | volume_sum = volume_sum + warped_volume
49 | volume_sq_sum = volume_sq_sum + warped_volume ** 2
50 | else:
51 | volume_sum += warped_volume
52 | volume_sq_sum += warped_volume.pow_(2) # in_place method
53 | del warped_volume
54 | volume_variance = volume_sq_sum.div_(num_views).sub_(volume_sum.div_(num_views).pow_(2))
55 |
56 | # TODO cost-volume modulation
57 | if hints is not None and validhints is not None:
58 | batch_size, feats, height, width = ref_feat.shape
59 | GAUSSIAN_HEIGHT = 10.0
60 | GAUSSIAN_WIDTH = 1.0
61 |
62 |         # subsample the hints to the resolution of the current stage features
63 | hints = hints
64 | hints = F.interpolate(hints, scale_factor=1 / scale, mode="nearest").unsqueeze(1)
65 | validhints = validhints
66 | validhints = F.interpolate(validhints, scale_factor=1 / scale, mode="nearest").unsqueeze(1)
67 | hints = hints * validhints
68 |
69 |         # add feature and depth dimensions to hints and validhints
70 |         # and repeat their values along those dimensions, to obtain the same size as the cost volume
71 | hints = hints.expand(-1, feats, num_depth, -1, -1)
72 | validhints = validhints.expand(-1, feats, num_depth, -1, -1)
73 |
74 |         # create a tensor of the same size as the cost volume, containing
75 |         # the depth hypotheses replicated along the depth dimension
76 | depth_hyps = depth_samps.unsqueeze(1).expand(batch_size, feats, -1, height, width).detach()
77 | volume_variance = volume_variance * (
78 | (1 - validhints)
79 | + validhints
80 | * GAUSSIAN_HEIGHT
81 | * (1 - torch.exp(-((depth_hyps - hints) ** 2) / (2 * GAUSSIAN_WIDTH ** 2)))
82 | )
83 |
84 | prob_volume_pre = cost_reg(volume_variance).squeeze(1)
85 | prob_volume = F.softmax(prob_volume_pre, dim=1)
86 | depth = depth_regression(prob_volume, depth_values=depth_samps)
87 |
88 | with torch.no_grad():
89 | prob_volume_sum4 = 4 * F.avg_pool3d(
90 | F.pad(prob_volume.unsqueeze(1), pad=(0, 0, 0, 0, 1, 2)), (4, 1, 1), stride=1, padding=0
91 | ).squeeze(1)
92 | depth_index = depth_regression(
93 | prob_volume,
94 | depth_values=torch.arange(num_depth, device=prob_volume.device, dtype=torch.float),
95 | ).long()
96 | depth_index = depth_index.clamp(min=0, max=num_depth - 1)
97 | prob_conf = torch.gather(prob_volume_sum4, 1, depth_index.unsqueeze(1)).squeeze(1)
98 |
99 | samp_variance = (depth_samps - depth.unsqueeze(1)) ** 2
100 | exp_variance = lamb * torch.sum(samp_variance * prob_volume, dim=1, keepdim=False) ** 0.5
101 |
102 | return {"depth": depth, "photometric_confidence": prob_conf, "variance": exp_variance}
103 |
104 |
105 | class UCSNet(nn.Module):
106 | def __init__(
107 | self,
108 | lamb=1.5,
109 | stage_configs=[64, 32, 8],
110 | grad_method="detach",
111 | base_chs=[8, 8, 8],
112 | feat_ext_ch=8,
113 | ):
114 | super(UCSNet, self).__init__()
115 |
116 | self.stage_configs = stage_configs
117 | self.grad_method = grad_method
118 | self.base_chs = base_chs
119 | self.lamb = lamb
120 | self.num_stage = len(stage_configs)
121 | self.ds_ratio = {
122 | "stage_2": 4.0, # "stage_1": 4.0,
123 | "stage_1": 2.0, # "stage_2": 2.0,
124 | "stage_0": 1.0, # "stage_3": 1.0
125 | }
126 |
127 | self.feature_extraction = FeatExtNet(
128 | base_channels=feat_ext_ch,
129 | num_stage=self.num_stage,
130 | )
131 |
132 | self.cost_regularization = nn.ModuleList(
133 | [
134 | CostRegNet(
135 | in_channels=self.feature_extraction.out_channels[i],
136 | base_channels=self.base_chs[i],
137 | )
138 | for i in [2, 1, 0]
139 | ]
140 | ) # range(self.num_stage)])
141 |
142 | def forward(self, imgs, proj_matrices, depth_values, hints=None, validhints=None):
143 | features = []
144 | for nview_idx in range(imgs.shape[1]):
145 | img = imgs[:, nview_idx]
146 | features.append(self.feature_extraction(img))
147 |
148 | outputs = {}
149 | depth, cur_depth, exp_var = None, None, None
150 | for stage_idx in [2, 1, 0]: # range(self.num_stage):
151 | features_stage = [feat["stage_{}".format(stage_idx)] for feat in features]
152 | proj_matrices_stage = proj_matrices["stage_{}".format(stage_idx)] # + 1)]
153 | stage_scale = self.ds_ratio["stage_{}".format(stage_idx)] # + 1)]
154 | cur_h = img.shape[2] // int(stage_scale)
155 | cur_w = img.shape[3] // int(stage_scale)
156 |
157 | if depth is not None:
158 | if self.grad_method == "detach":
159 | cur_depth = depth.detach()
160 | exp_var = exp_var.detach()
161 | else:
162 | cur_depth = depth
163 |
164 | cur_depth = F.interpolate(cur_depth.unsqueeze(1), [cur_h, cur_w], mode="bilinear")
165 | exp_var = F.interpolate(exp_var.unsqueeze(1), [cur_h, cur_w], mode="bilinear")
166 |
167 | else:
168 | cur_depth = depth_values
169 |
170 | depth_range_samples = uncertainty_aware_samples(
171 | cur_depth=cur_depth,
172 | exp_var=exp_var,
173 | ndepth=self.stage_configs[2 - stage_idx],
174 | dtype=img[0].dtype,
175 | device=img[0].device,
176 | shape=[img.shape[0], cur_h, cur_w],
177 | )
178 |
179 | outputs_stage = compute_depth(
180 | features_stage,
181 | proj_matrices_stage,
182 | depth_samps=depth_range_samples,
183 | cost_reg=self.cost_regularization[stage_idx],
184 | lamb=self.lamb,
185 | is_training=self.training,
186 | hints=hints,
187 | validhints=validhints,
188 | scale=stage_scale,
189 | )
190 |
191 | depth = outputs_stage["depth"]
192 | exp_var = outputs_stage["variance"]
193 |
194 | outputs["stage_{}".format(stage_idx)] = outputs_stage["depth"].unsqueeze(1)
195 |
196 | return {
197 | "depth": outputs,
198 | "loss_data": {
199 | "depth": outputs,
200 | },
201 | "photometric_confidence": outputs_stage["photometric_confidence"],
202 | }
203 |
204 |
205 | def ucsnet_loss(outputs, labels, masks, weights=[0.5, 1.0, 2.0]):
206 | tot_loss = 0
207 | for stage_id in [0, 1, 2]: # range(3):
208 | depth_i = outputs["stage_{}".format(stage_id)] # .unsqueeze(1)
209 | label_i = labels["stage_{}".format(stage_id)]
210 | mask_i = masks["stage_{}".format(stage_id)].bool()
211 | depth_loss = F.smooth_l1_loss(depth_i[mask_i], label_i[mask_i], reduction="mean")
212 | tot_loss += depth_loss * weights[2 - stage_id]
213 | return tot_loss
214 |
--------------------------------------------------------------------------------
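
A toy illustration (not repository code) of the uncertainty estimate computed at the end of compute_depth above: the expected depth plus the per-pixel standard deviation of the hypotheses under the probability volume, scaled by lamb; this interval drives the sampling range of the next stage.

import torch
import torch.nn.functional as F

B, D, H, W = 1, 8, 2, 2
lamb = 1.5
depth_samps = torch.linspace(1.0, 8.0, D).view(1, D, 1, 1).expand(B, D, H, W)
prob_volume = F.softmax(torch.randn(B, D, H, W), dim=1)

depth = torch.sum(prob_volume * depth_samps, dim=1)                # expected depth [B, H, W]
samp_variance = (depth_samps - depth.unsqueeze(1)) ** 2
exp_variance = lamb * torch.sum(samp_variance * prob_volume, dim=1) ** 0.5

# the next stage samples new hypotheses roughly inside [depth - exp_variance, depth + exp_variance]
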
/hubconf.py:
--------------------------------------------------------------------------------
1 | import tarfile
2 | from multiprocessing import cpu_count
3 | from pathlib import Path
4 | from typing import Literal, Tuple
5 |
6 | import torch
7 | import yaml
8 |
9 | from guided_mvs_lib import __version__ as CURR_VERS
10 | from guided_mvs_lib import models
11 |
12 | dependencies = ["torch", "torchvision"]
13 |
14 | ## entry points for each single model
15 |
16 | CURR_DIR = Path(__file__).parent
17 |
18 |
19 | def _get_archive() -> str:
20 | if not (CURR_DIR / "trained_models.tar.gz").exists():
21 | torch.hub.download_url_to_file(
22 | f"https://github.com/andreaconti/multi-view-guided-multi-view-stereo/releases/download/v{CURR_VERS}/trained_models.tar.gz",
23 | str(CURR_DIR / "trained_models.tar.gz"),
24 | )
25 | return str(CURR_DIR / "trained_models.tar.gz")
26 |
27 |
28 | def _load_model(
29 | tarpath: str,
30 | model: Literal["ucsnet", "d2hc_rmvsnet", "mvsnet", "patchmatchnet", "cas_mvsnet"] = "mvsnet",
31 | pretrained: bool = True,
32 | dataset: Literal["blended_mvg", "dtu_yao_blended_mvg"] = "blended_mvg",
33 | hints: Literal["mvguided_filtered", "not_guided", "guided", "mvguided"] = "not_guided",
34 | hints_density: float = 0.03,
35 | ):
36 | """
37 |     Utility function to load the chosen pretrained model from the tarfile containing all of them
38 | """
39 |
40 | assert model in [
41 | "ucsnet",
42 | "d2hc_rmvsnet",
43 | "mvsnet",
44 | "patchmatchnet",
45 | "cas_mvsnet",
46 | ]
47 | assert dataset in ["blended_mvg", "dtu_yao_blended_mvg"]
48 | assert hints in ["mvguided_filtered", "not_guided", "guided", "mvguided"]
49 |
50 | # model instance
51 | model_net = models.__dict__[model].SimpleInterfaceNet()
52 | model_net.train_params = None
53 |
54 | # find the correct checkpoint
55 | if pretrained:
56 | with tarfile.open(tarpath) as archive:
57 | info = yaml.safe_load(archive.extractfile("trained_models/info.yaml"))
58 | for ckpt_id, meta in info.items():
59 | found = meta["model"] == model and meta["hints"] == hints
60 | if hints != "not_guided":
61 | found = found and float(meta["hints_density"]) == hints_density
62 | if dataset == "blended_mvg":
63 | found = found and meta["dataset"] == dataset
64 | else:
65 | found = (
66 | found
67 | and meta["dataset"] == "dtu_yao"
68 | and "load_weights" in meta
69 | and info[meta["load_weights"]]["dataset"] == "blended_mvg"
70 | )
71 | if found:
72 | break
73 | if not found:
74 | raise ValueError("Model not available with the provided parameters")
75 |
76 | model_net.load_state_dict(
77 | {
78 | ".".join(n.split(".")[1:]): v
79 | for n, v in torch.load(archive.extractfile(f"trained_models/{ckpt_id}.ckpt"))[
80 | "state_dict"
81 | ].items()
82 | }
83 | )
84 | model_net.train_params = meta
85 | return model_net
86 |
87 |
88 | def mvsnet(
89 | pretrained: bool = True,
90 | dataset: Literal["blended_mvg", "dtu_yao_blended_mvg"] = "blended_mvg",
91 | hints: Literal["mvguided_filtered", "not_guided", "guided", "mvguided"] = "not_guided",
92 | hints_density: float = 0.03,
93 | ):
94 | """
95 | pretrained `MVSNet`_ network.
96 |
97 |     .. _MVSNet: https://arxiv.org/pdf/1804.02505.pdf
98 | """
99 | return _load_model(
100 | _get_archive(),
101 | "mvsnet",
102 | pretrained=pretrained,
103 | dataset=dataset,
104 | hints=hints,
105 | hints_density=hints_density,
106 | )
107 |
108 |
109 | def ucsnet(
110 | pretrained: bool = True,
111 | dataset: Literal["blended_mvg", "dtu_yao_blended_mvg"] = "blended_mvg",
112 | hints: Literal["mvguided_filtered", "not_guided", "guided", "mvguided"] = "not_guided",
113 | hints_density: float = 0.03,
114 | ):
115 | """
116 | pretrained `UCSNet`_ network.
117 |
118 |     .. _UCSNet: https://arxiv.org/pdf/1911.12012.pdf
119 | """
120 | return _load_model(
121 | _get_archive(),
122 | "ucsnet",
123 | pretrained=pretrained,
124 | dataset=dataset,
125 | hints=hints,
126 | hints_density=hints_density,
127 | )
128 |
129 |
130 | def d2hc_rmvsnet(
131 | pretrained: bool = True,
132 | dataset: Literal["blended_mvg", "dtu_yao_blended_mvg"] = "blended_mvg",
133 | hints: Literal["mvguided_filtered", "not_guided", "guided", "mvguided"] = "not_guided",
134 | hints_density: float = 0.03,
135 | ):
136 | """
137 | pretrained `D2HCRMVSNet`_ network.
138 |
139 |     .. _D2HCRMVSNet: https://www.ecva.net/papers/eccv_2020/papers_ECCV/papers/123490647.pdf
140 | """
141 | return _load_model(
142 | _get_archive(),
143 | "d2hc_rmvsnet",
144 | pretrained=pretrained,
145 | dataset=dataset,
146 | hints=hints,
147 | hints_density=hints_density,
148 | )
149 |
150 |
151 | def patchmatchnet(
152 | pretrained: bool = True,
153 | dataset: Literal["blended_mvg", "dtu_yao_blended_mvg"] = "blended_mvg",
154 | hints: Literal["mvguided_filtered", "not_guided", "guided", "mvguided"] = "not_guided",
155 | hints_density: float = 0.03,
156 | ):
157 | """
158 | pretrained `PatchMatchNet`_ network.
159 |
160 |     .. _PatchMatchNet: https://github.com/FangjinhuaWang/PatchmatchNet
161 | """
162 | return _load_model(
163 | _get_archive(),
164 | "patchmatchnet",
165 | pretrained=pretrained,
166 | dataset=dataset,
167 | hints=hints,
168 | hints_density=hints_density,
169 | )
170 |
171 |
172 | def cas_mvsnet(
173 | pretrained: bool = True,
174 | dataset: Literal["blended_mvg", "dtu_yao_blended_mvg"] = "blended_mvg",
175 | hints: Literal["mvguided_filtered", "not_guided", "guided", "mvguided"] = "not_guided",
176 | hints_density: float = 0.03,
177 | ):
178 | """
179 | pretrained `CASMVSNet`_ network.
180 |
181 |     .. _CASMVSNet: https://arxiv.org/pdf/1912.06378.pdf
182 | """
183 | return _load_model(
184 | _get_archive(),
185 | "cas_mvsnet",
186 | pretrained=pretrained,
187 | dataset=dataset,
188 | hints=hints,
189 | hints_density=hints_density,
190 | )
191 |
192 |
193 | ## Datasets
194 |
195 |
196 | def _load_dataset(
197 | dataset: str,
198 | root: str,
199 | batch_size: int = 1,
200 | nviews: int = 5,
201 | ndepths: int = 128,
202 | hints: str = "mvguided_filtered",
203 | hints_density: float = 0.03,
204 | filtering_window: Tuple[int, int] = (9, 9),
205 | num_workers: int = cpu_count() // 2,
206 | ):
207 | from guided_mvs_lib.datasets import MVSDataModule
208 | from guided_mvs_lib.datasets.sample_preprocess import MVSSampleTransform
209 |
210 | dm = MVSDataModule(
211 | dataset,
212 | batch_size=batch_size,
213 | num_workers=num_workers,
214 | datapath=root,
215 | nviews=nviews,
216 | ndepths=ndepths,
217 | robust_train=False,
218 | transform=MVSSampleTransform(
219 | generate_hints=hints,
220 | hints_perc=hints_density,
221 | filtering_window=filtering_window,
222 | ),
223 | )
224 | return dm
225 |
226 |
227 | def blended_mvs(
228 | root: str,
229 | batch_size: int = 1,
230 | nviews: int = 5,
231 | ndepths: int = 128,
232 | hints: str = "mvguided_filtered",
233 | hints_density: float = 0.03,
234 | filtering_window: Tuple[int, int] = (9, 9),
235 | num_workers: int = cpu_count() // 2,
236 | ):
237 | """
238 |     Utility function returning a PyTorch Lightning DataModule for
239 |     the BlendedMVS dataset
240 | """
241 | return _load_dataset(
242 | "blended_mvs",
243 | root=root,
244 | batch_size=batch_size,
245 | nviews=nviews,
246 | ndepths=ndepths,
247 | hints=hints,
248 | hints_density=hints_density,
249 | filtering_window=filtering_window,
250 | num_workers=num_workers,
251 | )
252 |
253 |
254 | def blended_mvg(
255 | root: str,
256 | batch_size: int = 1,
257 | nviews: int = 5,
258 | ndepths: int = 128,
259 | hints: str = "mvguided_filtered",
260 | hints_density: float = 0.03,
261 | filtering_window: Tuple[int, int] = (9, 9),
262 | num_workers: int = cpu_count() // 2,
263 | ):
264 | """
265 |     Utility function returning a PyTorch Lightning DataModule for
266 |     the BlendedMVG dataset
267 | """
268 | return _load_dataset(
269 | "blended_mvg",
270 | root=root,
271 | batch_size=batch_size,
272 | nviews=nviews,
273 | ndepths=ndepths,
274 | hints=hints,
275 | hints_density=hints_density,
276 | filtering_window=filtering_window,
277 | num_workers=num_workers,
278 | )
279 |
280 |
281 | def dtu(
282 | root: str,
283 | batch_size: int = 1,
284 | nviews: int = 5,
285 | ndepths: int = 128,
286 | hints: str = "mvguided_filtered",
287 | hints_density: float = 0.03,
288 | filtering_window: Tuple[int, int] = (9, 9),
289 | num_workers: int = 4, # (pretty memory aggressive)
290 | ):
291 | """
292 |     Utility function returning a PyTorch Lightning DataModule for
293 |     the DTU dataset
294 | """
295 | return _load_dataset(
296 | "dtu_yao",
297 | root=root,
298 | batch_size=batch_size,
299 | nviews=nviews,
300 | ndepths=ndepths,
301 | hints=hints,
302 | hints_density=hints_density,
303 | filtering_window=filtering_window,
304 | num_workers=num_workers,
305 | )
306 |
--------------------------------------------------------------------------------
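
A short usage sketch of the entry points above through torch.hub, mirroring what results.ipynb does with a local source; loading from GitHub by repository name should also work but is not shown here.

import torch

# load a pretrained MVSNet guided by filtered multi-view hints
model = torch.hub.load(
    ".",                          # path to this repository (source="local")
    "mvsnet",
    source="local",
    pretrained=True,
    dataset="blended_mvg",
    hints="mvguided_filtered",
    hints_density=0.03,
)
model.eval()

# the corresponding datamodule entry point
dm = torch.hub.load(".", "blended_mvg", source="local", root="data/blended-mvs")
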
/params.yaml:
--------------------------------------------------------------------------------
1 | model: patchmatchnet
2 |
3 | # train parameters
4 | train:
5 | dataset: dtu_yao # dtu_yao | blended_mvs (in data/)
6 |     epochs: 5 # int | null
7 |     steps: null # int | null, if both epochs and steps are set the smaller limit wins
8 | batch_size: 1
9 | lr: 0.001
10 | epochs_lr_decay: null # list of increasing int like [10, 12, 14] or null to disable
11 | epochs_lr_gamma: 2
12 | weight_decay: 0.0
13 | ndepths: 128 # 128
14 | views: 5
15 | hints: mvguided_filtered # not_guided | guided | mvguided | mvguided_filtered
16 | hints_filter_window: [9, 9]
17 | hints_density: 0.03
18 |
--------------------------------------------------------------------------------
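
A minimal sketch of reading params.yaml above from Python; train.py is not shown in this excerpt, so this only illustrates the file layout and is not the repository's actual loading code.

import yaml
from types import SimpleNamespace

with open("params.yaml") as f:
    params = yaml.safe_load(f)

model_name = params["model"]                  # e.g. "patchmatchnet"
train_cfg = SimpleNamespace(**params["train"])
print(model_name, train_cfg.dataset, train_cfg.hints, train_cfg.hints_density)
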
/pyproject.toml:
--------------------------------------------------------------------------------
1 | [tool.black]
2 | line-length = 99
3 | exclude = '''
4 | /(
5 | \.git
6 | | data
7 | | tensorboard_logs
8 | | \.vscode
9 | | output
10 | )/
11 | '''
12 |
13 | [tool.isort]
14 | profile = "black"
15 |
16 | [tool.pytest.ini_options]
17 | testpaths = ["tests"]
18 | markers = [
19 | "slow: marks tests as slow (testing trainings for instance)",
20 | "train: marks tests concerning the training process",
21 | "data: marks tests concerning datasets and data loading",
22 | "dtu: marks tests on the DTU dataset",
23 | "blended_mvs: marks tests on the Blended MVS dataset",
24 | "blended_mvg: marks tests on the Blended MVG dataset",
25 | "mvsnet: marks tests on mvsnet",
26 | "cas_mvsnet: marks tests on cas_mvsnet",
27 | "ucsnet: marks tests on ucsnet",
28 | "d2hc_rmvsnet: marks tests on d2hc_rmvsnet",
29 | "patchmatchnet: marks tests on patchmatchnet",
30 | ]
--------------------------------------------------------------------------------
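
The markers registered above allow selecting subsets of the test suite; a couple of illustrative invocations through the pytest API (the marker expressions are examples based only on the registered markers):

import pytest

# run only the fast dataset tests on DTU, skipping trainings marked as slow
pytest.main(["-m", "data and dtu and not slow"])

# run every test touching patchmatchnet
pytest.main(["-m", "patchmatchnet"])
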
/results.ipynb:
--------------------------------------------------------------------------------
1 | {
2 | "cells": [
3 | {
4 | "cell_type": "markdown",
5 | "metadata": {},
6 | "source": [
7 | "# Example to use the pretrained models\n",
8 | "\n",
9 |     "In this notebook we show how to load the dataset and the pretrained models to reproduce some of the results reported in the paper. To do so we leverage the ``torch.hub`` API"
10 | ]
11 | },
12 | {
13 | "cell_type": "code",
14 | "execution_count": 1,
15 | "metadata": {},
16 | "outputs": [],
17 | "source": [
18 | "import torch\n",
19 | "from tqdm import tqdm\n",
20 | "from collections import defaultdict\n",
21 | "import pandas as pd\n",
22 | "\n",
23 | "HUBCONF_URI = \".\"\n",
24 | "HUBCONF_SOURCE = \"local\""
25 | ]
26 | },
27 | {
28 | "cell_type": "markdown",
29 | "metadata": {},
30 | "source": [
31 | "### Load the dataset\n",
32 | "\n",
33 |     "We provide implementations for DTU and Blended-MVS / Blended-MVG; once the dataset is locally available it can be simply loaded by means of the torch.hub function"
34 | ]
35 | },
36 | {
37 | "cell_type": "code",
38 | "execution_count": 2,
39 | "metadata": {},
40 | "outputs": [],
41 | "source": [
42 | "# load the dataset\n",
43 | "dm = torch.hub.load(\n",
44 | " HUBCONF_URI,\n",
45 | " \"blended_mvg\",\n",
46 | " source=HUBCONF_SOURCE,\n",
47 | " root=\"data/blended-mvs\",\n",
48 | " hints=\"mvguided_filtered\",\n",
49 | " hints_density=0.03,\n",
50 | ")\n",
51 | "dm.prepare_data()\n",
52 | "dm.setup()\n",
53 | "dl = dm.test_dataloader()"
54 | ]
55 | },
56 | {
57 | "cell_type": "markdown",
58 | "metadata": {},
59 | "source": [
60 | "### Load the Network(s)\n",
61 | "\n",
62 |     "Here we load the networks pretrained with and without sparse depth points and test them, reproducing the results provided in the [paper](https://arxiv.org/pdf/2210.11467v1.pdf)"
63 | ]
64 | },
65 | {
66 | "cell_type": "code",
67 | "execution_count": 3,
68 | "metadata": {},
69 | "outputs": [
70 | {
71 | "name": "stderr",
72 | "output_type": "stream",
73 | "text": [
74 | "mvsnet: 100%|██████████| 915/915 [03:04<00:00, 4.95it/s]\n",
75 | "ucsnet: 100%|██████████| 915/915 [04:53<00:00, 3.12it/s]\n",
76 | "d2hc_rmvsnet: 100%|██████████| 915/915 [21:56<00:00, 1.44s/it]\n",
77 | "patchmatchnet: 100%|██████████| 915/915 [02:24<00:00, 6.34it/s]\n",
78 | "cas_mvsnet: 100%|██████████| 915/915 [04:30<00:00, 3.38it/s]\n"
79 | ]
80 | }
81 | ],
82 | "source": [
83 | "results = defaultdict(lambda: [])\n",
84 | "\n",
85 | "def metrics(pred: torch.Tensor, gt: torch.Tensor):\n",
86 | " mask = (gt > 0)\n",
87 | " diff = torch.abs(gt[mask] - pred[mask])\n",
88 | " return {\n",
89 | " \">1 px\": torch.mean((diff > 1).float()),\n",
90 | " \">2 px\": torch.mean((diff > 2).float()),\n",
91 | " \">3 px\": torch.mean((diff > 3).float()),\n",
92 | " \">4 px\": torch.mean((diff > 4).float()),\n",
93 | " }\n",
94 | "\n",
95 | "for model_name in [\n",
96 | " \"mvsnet\",\n",
97 | " \"ucsnet\",\n",
98 | " \"d2hc_rmvsnet\",\n",
99 | " \"patchmatchnet\",\n",
100 | " \"cas_mvsnet\",\n",
101 | "]:\n",
102 | "\n",
103 | " model_orig = torch.hub.load(\n",
104 | " HUBCONF_URI,\n",
105 | " model_name,\n",
106 | " source=HUBCONF_SOURCE,\n",
107 | " dataset=\"blended_mvg\",\n",
108 | " hints=\"not_guided\",\n",
109 | " )\n",
110 | " model_orig.eval()\n",
111 | " model_orig.cuda() # use a gpu for this\n",
112 | "\n",
113 | " model_hints = torch.hub.load(\n",
114 | " HUBCONF_URI,\n",
115 | " model_name,\n",
116 | " source=HUBCONF_SOURCE,\n",
117 | " dataset=\"blended_mvg\",\n",
118 | " hints=\"mvguided_filtered\",\n",
119 | " hints_density=0.03,\n",
120 | " )\n",
121 | " model_hints.eval()\n",
122 | " model_hints.cuda()\n",
123 | "\n",
124 | " with torch.no_grad():\n",
125 | " for ex in tqdm(dl, desc=model_name):\n",
126 | "\n",
127 | " # compute inputs\n",
128 | " inp_no_hints = {\n",
129 | " \"imgs\": ex[\"imgs\"][\"stage_0\"].cuda(),\n",
130 | " \"intrinsics\": ex[\"intrinsics\"].cuda(),\n",
131 | " \"extrinsics\": ex[\"extrinsics\"].cuda(),\n",
132 | " \"depth_values\": ex[\"depth_values\"].cuda(),\n",
133 | " }\n",
134 | "\n",
135 | " inp_hints = dict(\n",
136 | " **inp_no_hints,\n",
137 | " hints=ex[\"hints\"].cuda(),\n",
138 | " )\n",
139 | " \n",
140 | " # forward\n",
141 | " depth_orig = model_orig(**inp_no_hints)\n",
142 | " depth_hints = model_hints(**inp_hints)\n",
143 | "\n",
144 | " # metrics\n",
145 | " gt = ex[\"depth\"][\"stage_0\"].cuda()\n",
146 | " results[model_name].append(metrics(depth_orig, gt))\n",
147 | " results[model_name + \"_hints\"].append(metrics(depth_hints, gt))"
148 | ]
149 | },
150 | {
151 | "cell_type": "code",
152 | "execution_count": 4,
153 | "metadata": {},
154 | "outputs": [
155 | {
156 | "data": {
264 | "text/plain": [
265 | " >1 px >2 px >3 px >4 px\n",
266 | "model \n",
267 | "mvsnet 0.145 0.076 0.048 0.033\n",
268 | "mvsnet_hints 0.077 0.037 0.023 0.015\n",
269 | "ucsnet 0.083 0.042 0.027 0.019\n",
270 | "ucsnet_hints 0.040 0.018 0.011 0.008\n",
271 | "d2hc_rmvsnet 0.190 0.102 0.065 0.044\n",
272 | "d2hc_rmvsnet_hints 0.082 0.041 0.026 0.018\n",
273 | "patchmatchnet 0.083 0.041 0.026 0.018\n",
274 | "patchmatchnet_hints 0.065 0.034 0.022 0.016\n",
275 | "cas_mvsnet 0.083 0.040 0.025 0.018\n",
276 | "cas_mvsnet_hints 0.048 0.018 0.012 0.009"
277 | ]
278 | },
279 | "execution_count": 4,
280 | "metadata": {},
281 | "output_type": "execute_result"
282 | }
283 | ],
284 | "source": [
285 | "def mean_metrics(lst_d):\n",
286 | " return {k: f\"{torch.stack([dic[k] for dic in lst_d]).mean().item():.3f}\" for k in lst_d[0]}\n",
287 | "\n",
288 | "outs = [dict(model=net_name, **mean_metrics(results[net_name])) for net_name in results]\n",
289 | "pd.DataFrame(outs).set_index(\"model\")"
290 | ]
291 | }
292 | ],
293 | "metadata": {
294 | "kernelspec": {
295 | "display_name": "Python 3.8.10 64-bit ('guided-MVS')",
296 | "language": "python",
297 | "name": "python3"
298 | },
299 | "language_info": {
300 | "codemirror_mode": {
301 | "name": "ipython",
302 | "version": 3
303 | },
304 | "file_extension": ".py",
305 | "mimetype": "text/x-python",
306 | "name": "python",
307 | "nbconvert_exporter": "python",
308 | "pygments_lexer": "ipython3",
309 | "version": "3.8.10"
310 | },
311 | "orig_nbformat": 4,
312 | "vscode": {
313 | "interpreter": {
314 | "hash": "c0867fb98881dfcb09a7568ecd2197f83b886116f1d37752edcf5a1b3d692e7b"
315 | }
316 | }
317 | },
318 | "nbformat": 4,
319 | "nbformat_minor": 2
320 | }
321 |
--------------------------------------------------------------------------------
/tests/test_dataset.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests concerning data loading
3 | """
4 |
5 | import itertools
6 | import sys
7 | from pathlib import Path
8 |
9 | import numpy as np
10 | import pytest
11 | import torch
12 | from PIL import Image
13 |
14 | sys.path.append(".")
15 | from guided_mvs_lib.datasets import (
16 | MVSDataset,
17 | blended_mvg_utils,
18 | blended_mvs_utils,
19 | dtu_utils,
20 | )
21 | from guided_mvs_lib.datasets.sample_preprocess import MVSSampleTransform
22 |
23 |
24 | def _test_scans(dataset, split):
25 | if dataset == "dtu":
26 | ds_utils = dtu_utils
27 | elif dataset == "blended_mvs":
28 | ds_utils = blended_mvs_utils
29 | elif dataset == "blended_mvg":
30 | ds_utils = blended_mvg_utils
31 |
32 | scans = {
33 | "train": ds_utils.train_scans(),
34 | "val": ds_utils.val_scans(),
35 | "test": ds_utils.test_scans(),
36 | }[split]
37 |
38 | if not scans:
39 | raise ValueError(f"{split} not implemented for {dataset}")
40 |
41 | path = Path("data/blended-mvs")
42 | if dataset == "dtu":
43 | path = Path("data/dtu")
44 |
45 | missing_scans = []
46 | for scan in scans:
47 | paths = ds_utils.datapath_files(path, scan, 10, split, 3)
48 |
49 | exist = True
50 | error = []
51 | for k, v in paths.items():
52 | if v is not None and not v.parent.exists():
53 | exist = False
54 | error.append(f"not found: {str(v.parent)}")
55 | if split == "test" and dataset == "dtu":
56 | assert paths["pcd"] is not None
57 | assert paths["obs_mask"] is not None
58 | assert paths["ground_plane"] is not None
59 |
60 | if not exist:
61 | missing_scans.append((scan, error))
62 |
63 | error_out = ""
64 | if missing_scans:
65 | for scan, errors in missing_scans:
66 | error_out += f"\nerror in scan {scan}: \n" + "".join(f"- {err}\n" for err in errors)
67 | assert False, error_out
68 |
69 |
70 | @pytest.mark.dtu
71 | @pytest.mark.data
72 | @pytest.mark.parametrize("split", ["train", "val", "test"])
73 | def test_dtu_scans(split):
74 | _test_scans("dtu", split)
75 |
76 |
77 | @pytest.mark.blended_mvs
78 | @pytest.mark.data
79 | @pytest.mark.parametrize("split", ["train", "val", "test"])
80 | def test_blended_mvs_scans(split):
81 | _test_scans("blended_mvs", split)
82 |
83 |
84 | @pytest.mark.blended_mvg
85 | @pytest.mark.data
86 | @pytest.mark.parametrize("split", ["train", "val", "test"])
87 | def test_blended_mvg_scans(split):
88 | _test_scans("blended_mvg", split)
89 |
90 |
91 | def _test_ds_loading(name, mode, nviews, ndepths):
92 | if name == "dtu":
93 | name = "dtu_yao"
94 |
95 | if name in ["blended_mvs", "blended_mvg"]:
96 | datapath = "data/blended-mvs"
97 | else:
98 | datapath = "data/dtu"
99 |
100 | dataset = MVSDataset(
101 | name,
102 | datapath=datapath,
103 | mode=mode,
104 | nviews=nviews,
105 | ndepths=ndepths,
106 | )
107 |
108 | batch = dataset[0]
109 |
110 | for key in {
111 | "imgs",
112 | "intrinsics",
113 | "extrinsics",
114 | "depths",
115 | "ref_depth_min",
116 | "ref_depth_max",
117 | "ref_depth_values",
118 | "filename",
119 | }:
120 | assert key in batch.keys()
121 | assert isinstance(batch["imgs"], list)
122 | assert isinstance(batch["intrinsics"], list)
123 | assert batch["intrinsics"][0].shape == (3, 3)
124 | assert isinstance(batch["extrinsics"], list)
125 | assert batch["extrinsics"][0].shape == (4, 4)
126 |
127 | if name == "dtu_yao":
128 | img_shape = (512, 640)
129 | if mode == "test":
130 | img_shape = (1200, 1600)
131 | assert batch["depths"][0].shape == (*img_shape, 1)
132 | elif name in ["blended_mvs", "blended_mvg"]:
133 | assert batch["depths"][0].shape == (576, 768, 1)
134 |
135 | assert isinstance(batch["depths"], list)
136 | assert isinstance(batch["ref_depth_min"], float)
137 | assert isinstance(batch["ref_depth_max"], float)
138 | assert batch["ref_depth_values"].shape == (ndepths,)
139 | assert isinstance(batch["filename"], str)
140 |
141 |
142 | @pytest.mark.dtu
143 | @pytest.mark.data
144 | @pytest.mark.parametrize(
145 | "mode, nviews, ndepths",
146 | itertools.product(
147 | ["train", "val", "test"],
148 | [3, 5],
149 | [192, 128],
150 | ),
151 | )
152 | def test_dtu_loading(mode, nviews, ndepths):
153 | _test_ds_loading("dtu", mode, nviews, ndepths)
154 |
155 |
156 | @pytest.mark.blended_mvs
157 | @pytest.mark.data
158 | @pytest.mark.parametrize(
159 | "mode, nviews, ndepths",
160 | itertools.product(
161 | ["train", "val", "test"],
162 | [3, 5],
163 | [192, 128],
164 | ),
165 | )
166 | def test_blended_mvs_loading(mode, nviews, ndepths):
167 | _test_ds_loading("blended_mvs", mode, nviews, ndepths)
168 |
169 |
170 | @pytest.mark.blended_mvg
171 | @pytest.mark.data
172 | @pytest.mark.parametrize(
173 | "mode, nviews, ndepths",
174 | itertools.product(
175 | ["train", "val", "test"],
176 | [3, 5],
177 | [192, 128],
178 | ),
179 | )
180 | def test_blended_mvg_loading(mode, nviews, ndepths):
181 | _test_ds_loading("blended_mvg", mode, nviews, ndepths)
182 |
183 |
184 | # helper shared by the *_build_list tests below; it is called directly with explicit
185 | # (ds_name, split, path) arguments rather than being collected by pytest, so it must
186 | # not carry the test decorators
187 | def _test_ds_build_list(ds_name, split, path):
188 |
189 | ds_utils = {
190 | "dtu": dtu_utils,
191 | "blended_mvs": blended_mvs_utils,
192 | "blended_mvg": blended_mvg_utils,
193 | }[ds_name]
194 |
195 | scans = {
196 | "train": ds_utils.train_scans(),
197 | "val": ds_utils.val_scans(),
198 | "test": ds_utils.test_scans(),
199 | }[split]
200 |
201 | metas = ds_utils.build_list(path, scans, 3)
202 | for scan, light_idx, ref_view, _ in metas:
203 | paths = ds_utils.datapath_files(path, scan, ref_view, split, light_idx)
204 | for val in paths.values():
205 | if val is not None:
206 | assert val.exists(), f"{val} not found"
207 |
208 |
209 | @pytest.mark.data
210 | @pytest.mark.dtu
211 | @pytest.mark.parametrize("split", ["train", "val", "test"])
212 | def test_dtu_build_list(split):
213 | _test_ds_build_list("dtu", split, "data/dtu")
214 |
215 |
216 | @pytest.mark.data
217 | @pytest.mark.blended_mvs
218 | @pytest.mark.parametrize("split", ["train", "val", "test"])
219 | def test_blended_mvs_build_list(split):
220 | _test_ds_build_list("blended_mvs", split, "data/blended-mvs")
221 |
222 |
223 | @pytest.mark.data
224 | @pytest.mark.blended_mvg
225 | @pytest.mark.parametrize("split", ["train", "val", "test"])
226 | def test_blended_mvg_build_list(split):
227 | _test_ds_build_list("blended_mvg", split, "data/blended-mvs")
228 |
229 |
230 | @pytest.mark.data
231 | @pytest.mark.parametrize("hints", ["not_guided", "guided", "mvguided", "mvguided_filtered"])
232 | def test_dataset_preprocess(hints):
233 |
234 | # TODO: add tests for pcd fields
235 |
236 | fake_image = np.random.randint(0, 255, size=(512, 640, 3), dtype=np.uint8)
237 | sample = {
238 | "imgs": [Image.fromarray(fake_image)] * 3,
239 | "intrinsics": [np.random.randn(3, 3).astype(np.float32)] * 3,
240 | "extrinsics": [np.random.randn(4, 4).astype(np.float32)] * 3,
241 | "depths": [np.random.rand(512, 640, 1).astype(np.float32) * 1000] * 3,
242 | "ref_depth_min": 0.1,
243 | "ref_depth_max": 100.0,
244 | "ref_depth_values": np.arange(0, 1000, 192).astype(np.float32),
245 | "filename": "scan1/{}/00000000{}",
246 | "scan_pcd": None,
247 | "scan_pcd_obs_mask": None,
248 | "scan_pcd_bounding_box": None,
249 | "scan_pcd_resolution": None,
250 | "scan_pcd_ground_plane": None,
251 | }
252 |
253 | output = MVSSampleTransform(generate_hints=hints)(sample)
254 |
255 | for i in range(4):
256 | stage = f"stage_{i}"
257 | dims = 512 // (2 ** i), 640 // (2 ** i)
258 | assert output["imgs"][stage].shape == torch.Size([3, 3, *dims])
259 | assert output["proj_matrices"][stage].shape == torch.Size([3, 4, 4])
260 | assert output["depth"][stage].shape == torch.Size([1, *dims])
261 |
262 | assert "depth_min" in output
263 | assert "depth_max" in output
264 | assert "filename" in output
265 | if hints == "not_guided":
266 |         assert "hints" not in output, output.keys()
267 | else:
268 | assert output["hints"].shape == torch.Size([1, 512, 640])
269 |
--------------------------------------------------------------------------------
/tests/test_models.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests network interface and operational behaviour
3 | """
4 |
5 | import logging
6 | import sys
7 |
8 | sys.path.append(".")
9 | from types import SimpleNamespace
10 | from unittest.mock import MagicMock
11 |
12 | import pytest
13 |
14 | from guided_mvs_lib import models
15 | from guided_mvs_lib.datasets import MVSDataModule
16 | from guided_mvs_lib.datasets.sample_preprocess import MVSSampleTransform
17 |
18 | logging.basicConfig(level=logging.DEBUG)
19 |
20 | _train_args = {
21 | "epochs": 5,
22 | "steps": None,
23 | "batch_size": 1,
24 | "lr": 0.001,
25 | "epochs_lr_decay": None,
26 | "epochs_lr_gamma": 2,
27 | "weight_decay": 0.0,
28 | "ndepths": 128,
29 |     "views": 2,
30 | "hints_density": 0.01,
31 | }
32 |
33 |
34 | @pytest.mark.slow
35 | @pytest.mark.parametrize("hints", ["not_guided", "guided", "mvguided", "mvguided_filtered"])
36 | @pytest.mark.parametrize(
37 | "dataset",
38 | [
39 | pytest.param("blended_mvg", marks=pytest.mark.blended_mvg),
40 | pytest.param("blended_mvs", marks=pytest.mark.blended_mvs),
41 | pytest.param("dtu_yao", marks=pytest.mark.dtu),
42 | ],
43 | )
44 | @pytest.mark.parametrize(
45 | "network",
46 | [
47 | pytest.param("mvsnet", marks=pytest.mark.mvsnet),
48 | pytest.param("cas_mvsnet", marks=pytest.mark.cas_mvsnet),
49 | pytest.param("ucsnet", marks=pytest.mark.ucsnet),
50 | pytest.param("patchmatchnet", marks=pytest.mark.patchmatchnet),
51 | pytest.param("d2hc_rmvsnet", marks=pytest.mark.d2hc_rmvsnet),
52 | ],
53 | )
54 | def test_network_forward(dataset, network, hints):
55 |
56 | data_module = MVSDataModule(
57 | dataset,
58 | nviews=3,
59 | ndepths=128,
60 | transform=MVSSampleTransform(generate_hints=hints),
61 | )
62 | data_module.setup(stage="test")
63 | dl = data_module.test_dataloader()
64 | batch = next(iter(dl))
65 |
66 | # load model
67 | args = SimpleNamespace(model=network)
68 | args.train = SimpleNamespace(**_train_args, dataset=dataset, hints=hints)
69 | network = models.build_network(network, args)
70 |
71 | # mocking batch dictionary for inspection
72 | batch_mock = MagicMock()
73 | batch_mock.__getitem__.side_effect = batch.__getitem__
74 | batch_mock.__contains__.side_effect = batch.__contains__
75 |
76 | # assert output interface and hints usage
77 | output = network(batch_mock)
78 | assert set(output.keys()) == {"depth", "photometric_confidence", "loss_data"}
79 | if "hints" in batch:
80 | batch_mock.__getitem__.assert_any_call("hints")
81 |
--------------------------------------------------------------------------------
/tests/test_train.py:
--------------------------------------------------------------------------------
1 | """
2 | Tests the training procedure in fast-dev-run mode with all the networks
3 | on DTU and BlendedMVS
4 | """
5 |
6 | import shutil
7 | import sys
8 | import tempfile
9 | from pathlib import Path
10 |
11 | import py
12 | import pytest
13 |
14 | # tests are executed with the repository root as the working directory; the root,
15 | # however, is not a python package, so appending it to sys.path allows us to import
16 | # the training script as a module even though this test file lives in the `tests` folder
17 | sys.path.append(".")
18 | from train import run_training
19 |
20 |
21 | @pytest.mark.slow
22 | @pytest.mark.train
23 | @pytest.mark.parametrize(
24 | "model",
25 | [
26 | pytest.param("mvsnet", marks=pytest.mark.mvsnet),
27 | pytest.param("cas_mvsnet", marks=pytest.mark.cas_mvsnet),
28 | pytest.param("ucsnet", marks=pytest.mark.ucsnet),
29 | pytest.param("patchmatchnet", marks=pytest.mark.patchmatchnet),
30 | pytest.param("d2hc_rmvsnet", marks=pytest.mark.d2hc_rmvsnet),
31 | ],
32 | )
33 | @pytest.mark.parametrize(
34 | "dataset",
35 | [
36 | pytest.param("dtu_yao", marks=pytest.mark.dtu),
37 | pytest.param("blended_mvs", marks=[pytest.mark.blended_mvs, pytest.mark.blended_mvg]),
38 | ],
39 | )
40 | @pytest.mark.parametrize("hints", ["not_guided", "guided"])
41 | def test_network_train(model, dataset, hints):
42 |
43 | # create a folder for the output
44 | path = tempfile.mkdtemp(prefix="guided-mvs-test-")
45 |
46 | run_training(
47 | params={
48 | "model": model,
49 | "train": {
50 | "dataset": dataset,
51 | "epochs": 10,
52 | "steps": None,
53 | "batch_size": 1,
54 | "lr": 0.001,
55 | "epochs_lr_decay": None,
56 | "epochs_lr_gamma": 2,
57 | "weight_decay": 0.0,
58 | "ndepths": 192,
59 | "views": 3,
60 | "hints": hints,
61 | "hints_density": 0.01,
62 | "hints_filter_window": [9, 9],
63 | },
64 | },
65 | cmdline_args=["--gpus", "1", "--fast-dev-run"],
66 | outpath=Path(path),
67 | logspath=Path(path),
68 | )
69 |
70 | shutil.rmtree(path, ignore_errors=True)
71 |
--------------------------------------------------------------------------------
/tests/test_version.py:
--------------------------------------------------------------------------------
1 | """
2 | Simple check of library version
3 | """
4 |
5 | import guided_mvs_lib as lib
6 |
7 |
8 | def test_version():
9 | assert lib.__version__ == "0.1.0"
10 |
--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
1 | import argparse
2 | import os
3 | import re
4 | import shutil
5 | import subprocess
6 | import tempfile
7 | import warnings
8 | from datetime import datetime, timedelta
9 | from pathlib import Path
10 | from types import SimpleNamespace
11 | from typing import List, Optional, Union
12 |
13 | import numpy as np
14 | import pytorch_lightning as pl
15 | import urllib3
16 | import yaml
17 | from git import Repo
18 | from git.exc import InvalidGitRepositoryError
19 | from mlflow.tracking import MlflowClient
20 | from pytorch_lightning.callbacks import Callback, ModelCheckpoint
21 | from pytorch_lightning.loggers import MLFlowLogger
22 |
23 | import guided_mvs_lib.models as models
24 | from guided_mvs_lib.datasets import MVSDataModule
25 | from guided_mvs_lib.datasets.sample_preprocess import MVSSampleTransform
26 | from guided_mvs_lib.utils import *
27 |
28 |
29 | def run_training(
30 | params: Union[str, Path, dict] = "params.yaml",
31 | cmdline_args: Optional[List[str]] = None,
32 | datapath: Union[str, Path, None] = None,
33 | outpath: Union[str, Path] = "output",
34 | logspath: Union[str, Path] = ".",
35 | ):
36 | # handle args
37 | outpath = Path(outpath)
38 | logspath = Path(logspath)
39 |
40 | # remove annoying torch specific version warnings
41 | warnings.simplefilter("ignore", UserWarning)
42 | urllib3.disable_warnings()
43 |
44 | parser = argparse.ArgumentParser(description="training procedure")
45 |
46 | # training params
47 | parser.add_argument(
48 | "--gpus", type=int, default=1, help="number of gpus to select for training"
49 | )
50 | parser.add_argument(
51 | "--fast-dev-run",
52 | nargs="?",
53 | const=True,
54 | default=False,
55 |         help="execute a single batch of train and val, for debugging",
56 | )
57 | parser.add_argument(
58 | "--limit-train-batches",
59 | type=int,
60 | default=None,
61 |         help="limit the number of training batches per epoch, for debugging",
62 | )
63 | parser.add_argument(
64 | "--limit-val-batches",
65 | type=int,
66 | default=None,
67 |         help="limit the number of validation batches per epoch, for debugging",
68 | )
69 | parser.add_argument(
70 | "--resume-from-checkpoint",
71 | nargs="?",
72 | const=True,
73 | default=False,
74 |         help="resume from the last checkpoint, or from a specific checkpoint if a path is given",
75 | )
76 | parser.add_argument(
77 | "--load-weights",
78 | default=None,
79 | type=str,
80 |         help="load weights either from an mlflow run id or from a checkpoint file",
81 | )
82 |
83 | # experiment date
84 | date = datetime.now().strftime(r"%Y-%h-%d-%H-%M")
85 |
86 | # parse arguments and merge from params.yaml
87 | cmd_line_args = parser.parse_args(cmdline_args)
88 | if isinstance(params, dict):
89 | train_args = params
90 | else:
91 | with open(params, "rt") as f:
92 | train_args = yaml.safe_load(f)
93 |
94 | args = SimpleNamespace(**vars(cmd_line_args))
95 | for k, v in train_args.items():
96 | if not isinstance(v, dict):
97 | setattr(args, k, v)
98 | else:
99 | setattr(args, k, SimpleNamespace(**v))
100 |
101 | # Train using pytorch lightning
102 | pl.seed_everything(42)
103 |
104 | # Build LightningDataModule
105 | data_module = MVSDataModule(
106 | args.train.dataset,
107 | batch_size=args.train.batch_size,
108 | datapath=datapath,
109 | nviews=args.train.views,
110 | ndepths=args.train.ndepths,
111 |         robust_train=(args.train.dataset == "dtu_yao"),
112 | transform=MVSSampleTransform(
113 | generate_hints=args.train.hints,
114 | hints_perc=args.train.hints_density,
115 | filtering_window=tuple(args.train.hints_filter_window),
116 | ),
117 | )
118 |
119 |     # resume from a checkpoint or only load pretrained weights?
120 | if args.load_weights is not None and args.resume_from_checkpoint is not False:
121 | print("Use either --load-weights or --resume-from-checkpoint")
122 | return
123 |
124 | ckpt_path = None
125 |     steps_re = re.compile(r"step=(\d+)")
126 | if args.resume_from_checkpoint is True:
127 | if (outpath / "ckpts/last.ckpt").exists():
128 | ckpt_path = outpath / "ckpts/last.ckpt"
129 | else:
130 |             ckpts = [
131 |                 ckpt
132 |                 for ckpt in (outpath / "ckpts").glob("*.ckpt")
133 |                 if steps_re.findall(ckpt.name)
134 |             ]
135 |             steps = [int(steps_re.findall(ckpt.name)[0]) for ckpt in ckpts]
136 | if not steps:
137 | print("not found any valid checkpoint in", str(outpath / "ckpts"))
138 | return
139 | ckpt_path = ckpts[np.argmax(steps)]
140 | print(f"resuming from last checkpoint: {ckpt_path}")
141 | elif args.resume_from_checkpoint is not False:
142 | if Path(args.resume_from_checkpoint).exists():
143 | ckpt_path = args.resume_from_checkpoint
144 | print(f"resuming from choosen checkpoint: {ckpt_path}")
145 | else:
146 |             print(f"file {args.resume_from_checkpoint} does not exist")
147 | return
148 |
149 | # init mlflow logger and model
150 | if ckpt_path is None:
151 | logger = MLFlowLogger(
152 | experiment_name="guided-mvs",
153 | run_name=f"{args.model}-{date}",
154 | )
155 |
156 | outpath.mkdir(exist_ok=True, parents=True)
157 | with open(outpath / "run_uuid", "wt") as f:
158 | f.write(logger.run_id)
159 |
160 | model = models.MVSModel(
161 | args=args,
162 | mlflow_run_id=logger.run_id,
163 | v_num=f"{args.model}-{'-'.join(date.split('-')[1:3])}",
164 | )
165 | else:
166 |
167 | with open(outpath / "run_uuid", "rt") as f:
168 | mlflow_run_id = f.readline().strip()
169 |
170 | model = models.MVSModel.load_from_checkpoint(
171 | ckpt_path,
172 | args=args,
173 | mlflow_run_id=mlflow_run_id,
174 | v_num=f"{args.model}-{'-'.join(date.split('-')[1:3])}",
175 | )
176 | logger = MLFlowLogger(
177 | experiment_name="guided-mvs",
178 | run_name=f"{args.model}-{date}",
179 | )
180 | logger._run_id = mlflow_run_id
181 |
182 | # if required load weights
183 | if args.load_weights is not None:
184 | mlflow_client: MlflowClient = logger.experiment
185 | if args.load_weights in [
186 | run.run_uuid for run in mlflow_client.list_run_infos(logger.experiment_id)
187 | ]:
188 | # download the model
189 | run_weights_path = mlflow_client.download_artifacts(args.load_weights, "model.ckpt")
190 | model.load_state_dict(torch.load(run_weights_path)["state_dict"])
191 |
192 | # track the model weights
193 | run_weights_path = Path(run_weights_path)
194 | shutil.move(run_weights_path, run_weights_path.parent / "init_weights.ckpt")
195 | mlflow_client.log_artifact(
196 | logger.run_id, run_weights_path.parent / "init_weights.ckpt"
197 | )
198 | mlflow_client.set_tag(logger.run_id, "load_weights", args.load_weights)
199 | shutil.rmtree(Path(run_weights_path).parent, ignore_errors=True)
200 | else:
201 | try:
202 | model.load_state_dict(torch.load(args.load_weights)["state_dict"])
203 | tmpdir = Path(tempfile.mkdtemp())
204 | shutil.copy(args.load_weights, tmpdir / "init_weights.ckpt")
205 | mlflow_client.log_artifact(logger.run_id, tmpdir / "init_weights.ckpt")
206 | shutil.rmtree(tmpdir, ignore_errors=True)
207 | except FileNotFoundError:
208 |                 print(f"{args.load_weights} is neither a valid run id nor a path to a .ckpt")
209 | return
210 |
211 |     # handle checkpoints: when training by steps (no epoch limit, or a single epoch
212 |     # with an explicit step budget) checkpoint on a time interval instead of per epoch
213 |     if args.train.epochs is None or (
214 |         args.train.epochs == 1
215 |         and args.train.steps is not None
216 |         and args.train.steps > 0
217 |     ):
218 | ckpt_callback = ModelCheckpoint(
219 | outpath / "ckpts",
220 | train_time_interval=timedelta(hours=2),
221 | save_last=True,
222 | )
223 | else:
224 | ckpt_callback = ModelCheckpoint(outpath / "ckpts", save_last=True)
225 |
226 | remove_output = True
227 |
228 | class HandleOutputs(Callback):
229 | def on_train_end(self, trainer, pl_module):
230 |
231 | # save final model
232 | print("saving the final model.")
233 | torch.save(
234 | {"global_step": trainer.global_step, "state_dict": pl_module.state_dict()},
235 | outpath / "model.ckpt",
236 | )
237 |
238 | # copy the model and the params on MLFlow
239 | if not args.fast_dev_run:
240 | mlflow_client: MlflowClient = logger.experiment
241 |
242 | # store diff file if needed
243 | try:
244 | repo = Repo(Path.cwd())
245 |
246 | if repo.is_dirty():
247 | try:
248 | out = subprocess.check_output(["git", "diff"], cwd=Path.cwd())
249 | if out is not None:
250 | tmpfile = Path(tempfile.mkdtemp()) / "changes.diff"
251 | with open(tmpfile, "wb") as f:
252 | f.write(out)
253 | mlflow_client.log_artifact(logger.run_id, tmpfile)
254 | os.remove(tmpfile)
255 | except subprocess.CalledProcessError as e:
256 | print("Failed to save a diff file of the current experiment")
257 |
258 | except InvalidGitRepositoryError:
259 | pass
260 |
261 | # save the model
262 | mlflow_client.log_artifact(logger.run_id, str(outpath / "model.ckpt"))
263 |
264 |                 # finally, record the current run in a hidden file so that the eval
265 |                 # step can pick it up (the temporary output folder is removed after training)
266 | with open(".current_run.yaml", "wt") as f:
267 | yaml.safe_dump(
268 | {"experiment": logger.experiment_id, "run_uuid": logger.run_id}, f
269 | )
270 |
271 | def on_keyboard_interrupt(self, trainer, pl_module):
272 | print("training interrupted")
273 |
274 | # (not removing checkpoints)
275 | nonlocal remove_output
276 | remove_output = False
277 |
278 | # init train
279 | trainer_params = {
280 | "gpus": args.gpus,
281 | "fast_dev_run": args.fast_dev_run,
282 | "logger": logger,
283 | "benchmark": True,
284 | "callbacks": [ckpt_callback, HandleOutputs()],
285 | "weights_summary": None,
286 | "resume_from_checkpoint": ckpt_path,
287 | "num_sanity_val_steps": 0,
288 | }
289 |
290 | if (
291 | args.resume_from_checkpoint is not False
292 | and args.train.epochs is not None
293 | and args.train.epochs == 1
294 | and args.train.steps is not None
295 | and args.train.steps > 0
296 | and ckpt_path is not None
297 | ):
298 | args.train.epochs = None
299 |
300 | if args.train.epochs is not None:
301 | trainer_params["max_epochs"] = args.train.epochs
302 | if args.train.steps is not None:
303 | trainer_params["max_steps"] = args.train.steps
304 | if args.limit_train_batches is not None:
305 | trainer_params["limit_train_batches"] = args.limit_train_batches
306 | if args.limit_val_batches is not None:
307 | trainer_params["limit_val_batches"] = args.limit_val_batches
308 |
309 | trainer = pl.Trainer(**trainer_params)
310 | trainer.fit(model, data_module)
311 |
312 | if remove_output:
313 | shutil.rmtree(outpath)
314 |
315 |
316 | if __name__ == "__main__":
317 | run_training()
318 |
--------------------------------------------------------------------------------