├── .gitignore
├── README.md
├── download
│   ├── _utils.py
│   ├── kitti.py
│   ├── scannetv2.py
│   ├── sevenscenes.py
│   └── tartanair.py
├── environment.yml
├── evaluate
│   ├── config
│   │   ├── kitti.yaml
│   │   ├── scannetv2.yaml
│   │   ├── sevenscenes.yaml
│   │   └── tartanair.yaml
│   ├── metrics.py
│   └── utils
│       ├── funcs.py
│       ├── metrics.py
│       └── pose_estimation.py
├── lib
│   ├── __init__.py
│   ├── benchmark
│   │   └── mesh
│   │       ├── __init__.py
│   │       ├── _evaluation.py
│   │       └── _fusion.py
│   ├── dataset
│   │   ├── __init__.py
│   │   ├── _resources
│   │   │   ├── __init__.py
│   │   │   ├── kitti
│   │   │   │   └── keyframes
│   │   │   │       └── standard
│   │   │   │           └── test_split.txt
│   │   │   ├── scannetv2
│   │   │   │   ├── keyframes
│   │   │   │   │   └── standard
│   │   │   │   │       └── test_split.txt
│   │   │   │   └── test_split.txt
│   │   │   ├── sevenscenes
│   │   │   │   ├── keyframes
│   │   │   │   │   └── standard
│   │   │   │   │       └── test_split.txt
│   │   │   │   └── test_split.txt
│   │   │   └── tartanair
│   │   │       ├── keyframes
│   │   │       │   └── standard
│   │   │       │       └── test_split.txt
│   │   │       └── test_split.txt
│   │   ├── _utils.py
│   │   ├── kitti.py
│   │   ├── scannetv2.py
│   │   ├── sevenscenes.py
│   │   ├── tartanair.py
│   │   └── utils
│   │       ├── __init__.py
│   │       └── metrics.py
│   └── visualize
│       ├── __init__.py
│       ├── depth.py
│       ├── flow.py
│       ├── gui.py
│       ├── matplotlib.py
│       └── plotly.py
├── media
│   └── scannetv2_scene0720_00.mp4
├── pyrightconfig.json
└── visualize
    ├── callbacks
    │   ├── build_scene.py
    │   └── run_demo.py
    └── run.py
/.gitignore:
--------------------------------------------------------------------------------
1 | data
2 | .env
3 | **/__pycache__
4 | renders
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
22 |
23 | We propose Depth on Demand (DoD), a framework addressing the three major issues of active depth sensors in streaming dense depth maps: spatial sparsity, limited frame rate, and energy consumption. DoD streams high-resolution depth from an RGB camera and a depth sensor without requiring the depth sensor to be dense or to match the frame rate of the RGB camera. It improves the temporal resolution of the active depth sensor by exploiting the higher frame rate of the RGB camera, estimating depth for each RGB frame, even for those without direct depth sensor measurements.
24 |
25 |
26 | ### Install
27 |
28 | Dependencies can be installed with `conda` or `mamba` as follows:
29 |
30 | ```bash
31 | $ git clone https://github.com/andreaconti/depth-on-demand.git
32 | $ cd depth-on-demand
33 | $ conda env create -f environment.yml # use mamba if conda is too slow
34 | $ conda activate depth-on-demand
35 | # then, download and install the wheel containing the pretrained models, available for Linux, Windows and macOS
36 | $ pip install https://github.com/andreaconti/depth-on-demand/releases/download/models%2Fv0.1.1/depth_on_demand-0.1.1-cp310-cp310-linux_x86_64.whl --no-deps
37 | ```
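
As a quick sanity check that the pretrained models are importable (the evaluation script imports `Model` from the `depth_on_demand` package installed by the wheel), you can optionally run:

```bash
$ python -c "from depth_on_demand import Model; print(Model)"
```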
38 |
39 | We provide scripts to prepare the datasets for evaluation; each one automatically downloads and unpacks its dataset into the `data` directory:
40 |
41 | ```bash
42 | # for example, the ScanNetV2 evaluation sequences can be downloaded as follows
43 | $ python download/scannetv2.py
44 | ```
45 |
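These scripts fetch some assets through the GitHub releases API (see `download/_utils.py`). If a download fails with a 404 error, setting a personal access token in the `GITHUB_TOKEN` environment variable may be required, for example:

```bash
$ export GITHUB_TOKEN=<your-token>   # placeholder, use your own token
$ python download/sevenscenes.py
```
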
46 | ### Evaluate
47 |
48 | To evaluate the framework we provide the following script, which loads the specified dataset and its configuration, runs inference, and reports the metrics.
49 |
50 | ```bash
51 | $ python evaluate/metrics.py --dataset scannetv2
52 | ```
53 |
54 | The configuration for each dataset is in `evaluate/config`.
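
Each config also exposes the inference-time behavior; for instance, in `evaluate/config/scannetv2.yaml` the `inference` block controls how often and how densely sparse depth hints are injected (the comments below are explanatory additions):

```yaml
inference:
  depth_hints:
    interval: 5        # a depth-sensor frame is consumed every 5 RGB frames
    density: 500       # sparse depth points sampled from each hints frame
    hints_from_pcd: false
    pnp_pose: false    # estimate the hints-to-target pose with PnP instead of using GT
  source:
    interval: 5        # how often the source RGB view is refreshed
  model_eval_params:
    n_cycles: 10       # forwarded to the model at inference time
```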
55 |
56 | ### Visualize
57 |
58 | We provide a GUI (based on [viser](https://viser.studio/latest/#)) to run Depth on Demand interactively on ScanNetV2 and 7Scenes sequences. To start the visualizer, use the following command and open your browser at `http://127.0.0.1:8080`, as indicated by the script output.
59 |
60 | ```bash
61 | $ python visualize/run.py
62 | ```
63 |
64 |
65 |
66 | The interface parameters can be configured on the right:
67 |
68 | - **Dataset**: select either scannetv2 or sevenscenes from the drop-down menu; the dataset root should be updated accordingly.
69 | - **Build Scene**: choose the scene to use from the drop-down menu. The other parameters control the temporal and spatial density of the input sparse depth.
70 | - **TSDF Parameters**: you shouldn't need to change these. They control the mesh reconstruction; tweaking them changes the quality of the reconstructed mesh.
71 |
72 | Press **Build Scene** to start the reconstruction. While the scene is being reconstructed, you can change the viewing angle by dragging with the mouse.
73 |
74 | **Attention: the interface in this phase may flicker and this step may potentially trigger seizures for people with photosensitive epilepsy. Viewer discretion is advised**
75 |
76 | Once the reconstruction is done, playback can be run multiple times at different frame rates using the **Demo** box and the **Run Demo** button. With the `keep background mesh` option you can either watch the mesh integration grow over time or keep the final mesh fixed while the point of view moves.
77 |
78 | Finally, the **Render Video** button allows you to save a demo video and, with the `save frames` checkbox, to **save all the predictions of Depth on Demand**. Results are saved in the `renders` directory.
79 |
80 | ## Citation
81 |
82 | ```bibtex
83 | @InProceedings{DoD_Conti24,
84 | author = {Conti, Andrea and Poggi, Matteo and Cambareri, Valerio and Mattoccia, Stefano},
85 | title = {Depth on Demand: Streaming Dense Depth from a Low Frame-Rate Active Sensor},
86 | booktitle = {European Conference on Computer Vision (ECCV)},
87 | month = {October},
88 | year = {2024},
89 | }
90 | ```
91 |
--------------------------------------------------------------------------------
/download/_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities to download data from GitHub Releases
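
Example (mirroring download/sevenscenes.py)::

    github_download_unzip_assets(
        "andreaconti", "depth-on-demand", ["155673151"], dst="data/sevenscenes"
    )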
3 | """
4 |
5 | from urllib.request import urlopen, Request
6 | from urllib.error import HTTPError
7 | from zipfile import ZipFile
8 | import tempfile
9 | from pathlib import Path
10 | import os
11 | from tqdm import tqdm
12 | import hashlib
13 | import shutil
14 |
15 | __all__ = [
16 | "github_download_release_asset",
17 | "github_download_unzip_assets",
18 | "download_url_to_file",
19 | ]
20 |
21 |
22 | def github_download_unzip_assets(
23 | owner: str,
24 | repo: str,
25 | asset_ids: list[str],
26 | dst: str | Path,
27 | ):
28 |     Path(dst).mkdir(exist_ok=True, parents=True)
29 |     tmpdir = tempfile.mkdtemp()
30 |     try:
31 |         for asset_id in asset_ids:
32 |             tmpfile = os.path.join(tmpdir, asset_id)
33 |             github_download_release_asset(owner, repo, asset_id, tmpfile)
34 |             with ZipFile(tmpfile) as f:
35 |                 f.extractall(dst)
36 |     finally:
37 |         # always clean up the temporary download directory
38 |         shutil.rmtree(tmpdir, ignore_errors=True)
39 |
40 |
41 | def github_download_release_asset(
42 | owner: str,
43 | repo: str,
44 | asset_id: str,
45 | dst: str | Path,
46 | ):
47 | headers = {
48 | "Accept": "application/octet-stream",
49 | }
50 | if token := os.environ.get("GITHUB_TOKEN", None):
51 | headers["Authorization"] = f"Bearer {token}"
52 |
53 | try:
54 | download_url_to_file(
55 | f"https://api.github.com/repos/{owner}/{repo}/releases/assets/{asset_id}",
56 | dst,
57 | headers=headers,
58 | )
59 | except HTTPError as e:
60 | if e.code == 404 and "Authorization" not in headers:
61 | raise RuntimeError(
62 | "File not found, maybe missing GITHUB_TOKEN env variable?"
63 | )
64 | raise e
65 |
66 |
67 | def download_url_to_file(
68 | url,
69 | dst,
70 | *,
71 | hash_prefix=None,
72 | progress=True,
73 | headers=None,
74 | ):
75 | # find file size
76 | file_size = None
77 | req = Request(url, headers={} if not headers else headers)
78 | u = urlopen(req)
79 | meta = u.info()
80 | if hasattr(meta, "getheaders"):
81 | content_length = meta.getheaders("Content-Length")
82 | else:
83 | content_length = meta.get_all("Content-Length")
84 | if content_length is not None and len(content_length) > 0:
85 | file_size = int(content_length[0])
86 |
87 | dst = os.path.expanduser(dst)
88 | dst_dir = os.path.dirname(dst)
89 | f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
90 |
91 | try:
92 | if hash_prefix is not None:
93 | sha256 = hashlib.sha256()
94 | with tqdm(
95 | total=file_size,
96 | disable=not progress,
97 | unit="B",
98 | unit_scale=True,
99 | unit_divisor=1024,
100 | ) as pbar:
101 | while True:
102 | buffer = u.read(8192)
103 | if len(buffer) == 0:
104 | break
105 | f.write(buffer)
106 | if hash_prefix is not None:
107 | sha256.update(buffer)
108 | pbar.update(len(buffer))
109 |
110 | f.close()
111 | if hash_prefix is not None:
112 | digest = sha256.hexdigest()
113 | if digest[: len(hash_prefix)] != hash_prefix:
114 | raise RuntimeError(
115 | 'invalid hash value (expected "{}", got "{}")'.format(
116 | hash_prefix, digest
117 | )
118 | )
119 | shutil.move(f.name, dst)
120 | finally:
121 | f.close()
122 | if os.path.exists(f.name):
123 | os.remove(f.name)
124 |
125 |
--------------------------------------------------------------------------------
/download/kitti.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to download the KITTI Depth Completion
3 | validation set
4 | """
5 |
6 | import os
7 | import re
8 | import shutil
9 | from tqdm import tqdm
10 | from pathlib import Path
11 | from zipfile import ZipFile
12 | import requests
13 | from argparse import ArgumentParser
14 |
15 |
16 | calib_scans = [
17 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_calib.zip",
18 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_29_calib.zip",
19 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_calib.zip",
20 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_10_03_calib.zip",
21 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_28_calib.zip",
22 | ]
23 |
24 | scenes = list(
25 | map(
26 | lambda x: "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/"
27 | + f"{x}/{x}_sync.zip",
28 | [
29 | "2011_09_26_drive_0020",
30 | "2011_09_26_drive_0036",
31 | "2011_09_26_drive_0002",
32 | "2011_09_26_drive_0013",
33 | "2011_09_26_drive_0005",
34 | "2011_09_26_drive_0113",
35 | "2011_09_26_drive_0023",
36 | "2011_09_26_drive_0079",
37 | "2011_09_29_drive_0026",
38 | "2011_09_30_drive_0016",
39 | "2011_10_03_drive_0047",
40 | "2011_09_26_drive_0095",
41 | "2011_09_28_drive_0037",
42 | ],
43 | )
44 | )
45 |
46 |
47 | def download_file(url, save_path, chunk_size=1024, verbose=True):
48 | """
49 |     Downloads the file at `url`, streaming it chunk by chunk into the
50 |     provided `save_path`.
51 | """
52 | r = requests.get(url, stream=True)
53 | zip_name = url.split("/")[-1]
54 |
55 | content_length = int(r.headers["Content-Length"]) / 10**6
56 |
57 | if verbose:
58 | bar = tqdm(total=content_length, unit="Mb", desc="download " + zip_name)
59 | with open(save_path, "wb") as fd:
60 | for chunk in r.iter_content(chunk_size=chunk_size):
61 | fd.write(chunk)
62 | if verbose:
63 | bar.update(chunk_size / 10**6)
64 |
65 | if verbose:
66 | bar.close()
67 |
68 |
69 | def raw_download(root_path: str):
70 |
71 | date_match = re.compile("[0-9]+_[0-9]+_[0-9]+")
72 | drive_match = re.compile("[0-9]+_[0-9]+_[0-9]+_drive_[0-9]+_sync")
73 |
74 | def download_unzip(url):
75 | date = date_match.findall(url)[0]
76 | drive = drive_match.findall(url)[0]
77 | os.makedirs(os.path.join(root_path, date), exist_ok=True)
78 | download_file(url, os.path.join(root_path, date, drive + ".zip"), verbose=False)
79 | with ZipFile(os.path.join(root_path, date, drive + ".zip"), "r") as zip_ref:
80 | zip_ref.extractall(os.path.join(root_path, date, drive + "_tmp"))
81 | os.rename(
82 | os.path.join(root_path, date, drive + "_tmp", date, drive),
83 | os.path.join(root_path, date, drive),
84 | )
85 | shutil.rmtree(os.path.join(root_path, date, drive + "_tmp"))
86 | os.remove(os.path.join(root_path, date, drive + ".zip"))
87 |
88 | for scene in tqdm(scenes, desc="download kitti (raw)"):
89 | download_unzip(scene)
90 |
91 |
92 | def dc_download(root_path: str, progress_bar=True):
93 | """
94 |     Downloads and scaffolds the depth completion dataset in `root_path`
95 | """
96 |
97 | if not os.path.exists(root_path):
98 | os.mkdir(root_path)
99 |
100 | # urls
101 | data_depth_selection_url = (
102 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_selection.zip"
103 | )
104 | data_depth_velodyne_url = (
105 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_velodyne.zip"
106 | )
107 | data_depth_annotated_url = (
108 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_annotated.zip"
109 | )
110 |
111 | # download of zips
112 | download_file(
113 | data_depth_selection_url,
114 | os.path.join(root_path, "data_depth_selection.zip"),
115 | verbose=progress_bar,
116 | )
117 |
118 | download_file(
119 | data_depth_velodyne_url,
120 | os.path.join(root_path, "data_depth_velodyne.zip"),
121 | verbose=progress_bar,
122 | )
123 | download_file(
124 | data_depth_annotated_url,
125 | os.path.join(root_path, "data_depth_annotated.zip"),
126 | verbose=progress_bar,
127 | )
128 |
129 | # unzip and remove zips
130 | with ZipFile(os.path.join(root_path, "data_depth_selection.zip"), "r") as zip_ref:
131 | zip_ref.extractall(root_path)
132 | os.rename(
133 | os.path.join(
134 | root_path, "depth_selection", "test_depth_completion_anonymous"
135 | ),
136 | os.path.join(root_path, "test_depth_completion_anonymous"),
137 | )
138 | os.rename(
139 | os.path.join(
140 | root_path, "depth_selection", "test_depth_prediction_anonymous"
141 | ),
142 | os.path.join(root_path, "test_depth_prediction_anonymous"),
143 | )
144 | os.rename(
145 | os.path.join(root_path, "depth_selection", "val_selection_cropped"),
146 | os.path.join(root_path, "val_selection_cropped"),
147 | )
148 | os.rmdir(os.path.join(root_path, "depth_selection"))
149 | with ZipFile(os.path.join(root_path, "data_depth_velodyne.zip"), "r") as zip_ref:
150 | zip_ref.extractall(root_path)
151 | with ZipFile(os.path.join(root_path, "data_depth_annotated.zip"), "r") as zip_ref:
152 | zip_ref.extractall(root_path)
153 |
154 | # remove zip files
155 | os.remove(os.path.join(root_path, "data_depth_selection.zip"))
156 | os.remove(os.path.join(root_path, "data_depth_annotated.zip"))
157 | os.remove(os.path.join(root_path, "data_depth_velodyne.zip"))
158 |
159 |
160 | def calib_download(root_path: str):
161 | """
162 | Downloads and scaffolds calibration files
163 | """
164 |
165 | Path(root_path).mkdir(exist_ok=True)
166 |
167 | for repo in calib_scans:
168 | calib_zip_path = os.path.join(root_path, "calib.zip")
169 | download_file(repo, calib_zip_path)
170 | with open(calib_zip_path, "rb") as f:
171 | ZipFile(f).extractall(root_path)
172 | os.remove(calib_zip_path)
173 |
174 |
175 | if __name__ == "__main__":
176 | parser = ArgumentParser("download the KITTI test set for evaluation")
177 | parser.add_argument("--root", type=Path, default=Path("data/kitti"))
178 | args = parser.parse_args()
179 | raw_download(str(args.root / "raw"))
180 | calib_download(str(args.root / "raw"))
181 | dc_download(str(args.root / "depth_completion"))
182 |
--------------------------------------------------------------------------------
/download/scannetv2.py:
--------------------------------------------------------------------------------
1 | from _utils import github_download_unzip_assets
2 | from argparse import ArgumentParser
3 | from zipfile import ZipFile
4 | from pathlib import Path
5 | import os
6 | import gdown
7 |
8 |
9 | def download_and_unzip_meshes(root: str | Path):
10 | root = Path(root)
11 |
12 | # from TransformerFusion
13 | # https://github.com/AljazBozic/TransformerFusion?tab=readme-ov-file
14 | url = "https://drive.usercontent.google.com/download?id=1-nto65_JTNs1vyeHycebidYFyQvE6kt4&authuser=0&confirm=t&uuid=05aba470-c11c-48f9-8ba3-162560ab15bb&at=APZUnTUsLTVanQWS8dVSGr5pA3hJ%3A1710523361429"
15 | out_zip = str(Path(root) / ".meshes.zip")
16 | gdown.download(url, out_zip, quiet=False)
17 |
18 | # unzip data in the correct format
19 | with ZipFile(out_zip, "r") as f:
20 | for file in f.infolist():
21 | if not file.is_dir():
22 | _, scene, fname = file.filename.split("/")
23 | new_path = root / (
24 | f"{scene}_vh_clean.ply"
25 | if fname == "mesh_gt.ply"
26 | else f"{scene}_{fname}"
27 | )
28 | data = f.read(file)
29 | with open(new_path, "wb") as out:
30 | out.write(data)
31 |
32 | # remove the zip
33 | os.remove(out_zip)
34 |
35 |
36 | if __name__ == "__main__":
37 |     parser = ArgumentParser("download the ScanNetV2 test set for evaluation")
38 | parser.add_argument("--root", type=Path, default=Path("data/scannetv2"))
39 | args = parser.parse_args()
40 | github_download_unzip_assets(
41 | "andreaconti",
42 | "depth-on-demand",
43 | ["156761793", "156761795", "156761794"],
44 | args.root,
45 | )
46 | download_and_unzip_meshes(args.root)
47 |
--------------------------------------------------------------------------------
/download/sevenscenes.py:
--------------------------------------------------------------------------------
1 | from _utils import github_download_unzip_assets
2 | from argparse import ArgumentParser
3 | from pathlib import Path
4 |
5 | if __name__ == "__main__":
6 | parser = ArgumentParser("download the 7Scenes test set for evaluation")
7 | parser.add_argument("--root", type=Path, default=Path("data/sevenscenes"))
8 | args = parser.parse_args()
9 | github_download_unzip_assets(
10 | "andreaconti",
11 | "depth-on-demand",
12 | ["155673151"],
13 | args.root,
14 | )
15 |
--------------------------------------------------------------------------------
/download/tartanair.py:
--------------------------------------------------------------------------------
1 | from _utils import github_download_unzip_assets
2 | from argparse import ArgumentParser
3 | from pathlib import Path
4 |
5 | if __name__ == "__main__":
6 | parser = ArgumentParser("download the TartanAir test set for evaluation")
7 | parser.add_argument("--root", type=Path, default=Path("data/tartanair"))
8 | args = parser.parse_args()
9 | github_download_unzip_assets(
10 | "andreaconti",
11 | "depth-on-demand",
12 | ["156760244", "156760245", "156760243", "156760242"],
13 | args.root,
14 | )
15 |
--------------------------------------------------------------------------------
/environment.yml:
--------------------------------------------------------------------------------
1 | name: depth-on-demand
2 | channels:
3 | - pytorch
4 | - nvidia
5 | - conda-forge
6 | - defaults
7 | - pytorch3d
8 | dependencies:
9 | - python=3.10.11
10 | - pytorch=2.0.1
11 | - pytorch-cuda=11.8
12 | - torchvision=0.15.2
13 | - torchdata=0.6.0
14 | - lightning=2.0.3
15 | - albumentations=1.3.1
16 | - kornia=0.6.12
17 | - pytorch3d=0.7.4
18 |
19 | # dev
20 | - ipython
21 | - ipykernel
22 | - black
23 | - pytest
24 |
25 | # c++ dependencies
26 | - eigen
27 |
28 | - pip
29 | - pip:
30 | - pipe==2.0.0
31 | - omegaconf==2.3.0
32 | - open3d==0.17.0
33 | - plyfile==0.9.0
34 | - pykitti==0.3.1
35 | - timm==0.9.7
36 |     - viser>=0.1.7,<1.0.0
37 | - h5py>=3.10.0
38 | - git+https://github.com/PoseLib/PoseLib.git
39 | - git+https://github.com/cvg/LightGlue.git
40 | - gdown
41 |
42 |
--------------------------------------------------------------------------------
/evaluate/config/kitti.yaml:
--------------------------------------------------------------------------------
1 | args:
2 | root_raw: data/kitti/raw
3 | root_completion: data/kitti/depth_completion
4 | batch_size: 1
5 | load_prevs: 0
6 | shuffle_keyframes: false
7 | order_sources_by_pose: false
8 | load_hints_pcd: true
9 | keyframes: standard
10 | min_depth: 0.001
11 | max_depth: 80.00
12 |
13 | inference:
14 | depth_hints:
15 | density: null
16 | interval: 10
17 | hints_from_pcd: true
18 | pnp_pose: true
19 | source:
20 | interval: ${inference.depth_hints.interval}
21 | model_eval_params:
22 | n_cycles: 10
--------------------------------------------------------------------------------
/evaluate/config/scannetv2.yaml:
--------------------------------------------------------------------------------
1 | args:
2 | root: data/scannetv2
3 | batch_size: 1
4 | load_prevs: 0
5 | shuffle_keyframes: false
6 | order_sources_by_pose: false
7 | keyframes: standard
8 | min_depth: 0.0
9 | max_depth: 10.0
10 |
11 | tsdf_fusion:
12 | engine: open3d
13 | depth_scale: 1
14 | depth_trunc: 3
15 | sdf_trunc: 0.12
16 | voxel_length: 0.04
17 |
18 | inference:
19 | depth_hints:
20 | interval: 5
21 | density: 500
22 | hints_from_pcd: false
23 | pnp_pose: false
24 | source:
25 | interval: 5
26 | model_eval_params:
27 | n_cycles: 10
28 |
--------------------------------------------------------------------------------
/evaluate/config/sevenscenes.yaml:
--------------------------------------------------------------------------------
1 | args:
2 | root: data/sevenscenes
3 | batch_size: 1
4 | load_prevs: 0
5 | shuffle_keyframes: false
6 | order_sources_by_pose: false
7 | keyframes: standard
8 | min_depth: 0.0
9 | max_depth: 10.0
10 |
11 | tsdf_fusion:
12 | engine: open3d
13 | depth_scale: 1
14 | depth_trunc: 3
15 | sdf_trunc: 0.12
16 | voxel_length: 0.04
17 |
18 | inference:
19 | depth_hints:
20 | interval: 5
21 | density: 500
22 | hints_from_pcd: false
23 | pnp_pose: false
24 | source:
25 | interval: 5
26 | model_eval_params:
27 | n_cycles: 10
--------------------------------------------------------------------------------
/evaluate/config/tartanair.yaml:
--------------------------------------------------------------------------------
1 | args:
2 | root: data/tartanair
3 | batch_size: 1
4 | load_prevs: 0
5 | shuffle_keyframes: false
6 | order_sources_by_pose: false
7 | keyframes: standard
8 | min_depth: 0.001
9 | max_depth: 100.0
10 |
11 | inference:
12 | depth_hints:
13 | interval: 5
14 | density: 500
15 | hints_from_pcd: false
16 | pnp_pose: false
17 | source:
18 | interval: 5
19 | model_eval_params:
20 | n_cycles: 10
--------------------------------------------------------------------------------
/evaluate/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Script to reproduce the DoD paper metrics on the supported datasets
3 | """
4 |
5 | import sys
6 |
7 | # In[] Imports
8 | from pathlib import Path
9 |
10 | sys.path.append(str(Path(__file__).parents[1]))
11 | import shutil
12 | import tempfile
13 | import warnings
14 | from argparse import ArgumentParser
15 | from collections import defaultdict
16 | from copy import deepcopy
17 |
18 | import pandas as pd
19 | import torch
20 | import utils.funcs as funcs
21 | from depth_on_demand import Model as DepthOnDemand
22 | from omegaconf import OmegaConf
23 | from torchmetrics import MeanMetric
24 | from tqdm import tqdm
25 | from utils.metrics import compute_metrics_depth
26 |
27 | from lib.benchmark.mesh import TSDFFusion, mesh_metrics
28 | from lib.dataset import load_datamodule
29 |
30 | warnings.filterwarnings("ignore", category=UserWarning)
31 |
32 | rootdir = Path(__file__).parents[1]
33 | thisdir = Path(__file__).parent
34 |
35 | # In[] Args
36 |
37 | parser = ArgumentParser("Test DoD Model")
38 | parser.add_argument(
39 | "--dataset",
40 | choices=["kitti", "scannetv2", "sevenscenes", "tartanair"],
41 | default="sevenscenes",
42 | )
43 | parser.add_argument("--device", default="cuda:0")
44 | parser.add_argument("--f")
45 | args = parser.parse_args()
46 |
47 | # In[] Load dataset
48 | cfg = OmegaConf.load(thisdir / "config" / (args.dataset + ".yaml"))
49 | dm = load_datamodule(
50 | args.dataset,
51 | **cfg.args,
52 | split_test_scans_loaders=True,
53 | )
54 | dm.prepare_data()
55 | dm.setup("test")
56 | scans = dm.test_dataloader()
57 |
58 | # In[] Load the model
59 | model = DepthOnDemand(
60 | pretrained={
61 | "sevenscenes": "scannetv2",
62 | "scannetv2": "scannetv2",
63 | "tartanair": "tartanair",
64 | "kitti": "kitti",
65 | }[args.dataset],
66 | device=args.device,
67 | )
68 |
69 | # In[] Testing
70 | predict_pose = None
71 | if cfg.inference.depth_hints.pnp_pose:
72 |
73 | from utils.pose_estimation import PoseEstimation
74 |
75 | predict_pose = PoseEstimation(dilate_depth=3)
76 | predict_pose.to(args.device)
77 |
78 | metrics = defaultdict(lambda: MeanMetric().to(args.device))
79 | for scan in tqdm(scans):
80 |
81 | buffer_hints = {}
82 | buffer_source = {}
83 | if args.dataset == "scannetv2":
84 | tsdf_volume = TSDFFusion(**cfg.tsdf_fusion)
85 |
86 | for idx, batch in enumerate(tqdm(scan, leave=False)):
87 |
88 | batch = {
89 | k: v.to(args.device) if isinstance(v, torch.Tensor) else v
90 | for k, v in batch.items()
91 | }
92 | if idx % cfg.inference.depth_hints.interval == 0:
93 | buffer_hints = deepcopy(batch)
94 | if idx % cfg.inference.source.interval == 0:
95 | buffer_source = deepcopy(batch)
96 |
97 | input = funcs.prepare_input(
98 | batch
99 | | {k + "_prev_0": v for k, v in buffer_hints.items()}
100 | | {k + "_prev_1": v for k, v in buffer_source.items()},
101 | hints_density=cfg.inference.depth_hints.density,
102 | hints_from_pcd=cfg.inference.depth_hints.hints_from_pcd,
103 | hints_postfix="_prev_0",
104 | source_postfix="_prev_1",
105 | predict_pose=predict_pose,
106 | )
107 | buffer_source["hints"] = input.pop("source_hints")
108 | buffer_hints["hints"] = input.pop("hints_hints")
109 |
110 | # inference
111 | gt = batch["depth"].to(args.device)
112 | mask = gt > 0
113 | with torch.no_grad():
114 | pred_depth = model(
115 | **input,
116 | n_cycles=cfg.inference.model_eval_params.n_cycles,
117 | )
118 |
119 | # metrics
120 | for k, v in compute_metrics_depth(pred_depth[mask], gt[mask], "test").items():
121 | metrics[k].update(v)
122 | if args.dataset == "scannetv2":
123 | tsdf_volume.integrate_rgbd(
124 | batch["image"][0].permute(1, 2, 0).cpu().numpy(),
125 | pred_depth[0].permute(1, 2, 0).detach().cpu().numpy(),
126 | batch["intrinsics"][0].cpu().numpy(),
127 | batch["extrinsics"][0].cpu().numpy(),
128 | )
129 |
130 | if args.dataset == "scannetv2":
131 | mesh_pred = tsdf_volume.triangle_mesh()
132 | for k, v in mesh_metrics(
133 | mesh_pred,
134 | batch["gt_mesh_path"][0],
135 | batch["gt_mesh_world2grid_path"][0],
136 | batch["gt_mesh_occl_path"][0],
137 | device=args.device,
138 | ).items():
139 | metrics["test/" + k].update(v)
140 | tmpdir = Path(tempfile.mkdtemp())
141 | scan_name = "_".join(Path(batch["gt_mesh_path"][0]).stem.split("_")[:2])
142 | mesh_path = tmpdir / (scan_name + ".ply")
143 | tsdf_volume.write_triangle_mesh(mesh_path)
144 | shutil.rmtree(tmpdir, ignore_errors=True)
145 |
146 | # %%
147 |
148 | print(f"== metrics for {args.dataset} ==")
149 | for k, v in metrics.items():
150 | print(f"{k.ljust(25)}: {v.compute():0.5f}")
151 |
--------------------------------------------------------------------------------
/evaluate/utils/funcs.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import kornia
3 | from torch import Tensor
4 | import random
5 | import torch.nn.functional as F
6 | from typing import Callable
7 | from lib.visualize.matplotlib import GridFigure
8 | from kornia.morphology import dilation
9 | import io
10 | import itertools
11 | from contextlib import redirect_stdout
12 | import kornia
13 | import numpy as np
14 | from kornia.geometry.quaternion import QuaternionCoeffOrder
15 | from kornia.geometry.conversions import (
16 | euler_from_quaternion,
17 | rotation_matrix_to_quaternion,
18 | )
19 | import re
20 | from lib.visualize.matplotlib import color_depth
21 |
22 | __all__ = ["sparsify_depth", "prepare_input"]
23 |
24 |
25 | def prepare_input(
26 | batch,
27 | hints_density: float | int = 500,
28 | hints_from_pcd: bool = False,
29 | source_postfix: str = "",
30 | hints_postfix: str = "",
31 | predict_pose: Callable | None = None,
32 | pose_noise_std_mult: float = 0.0,
33 | ):
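    """
    Assemble the model input from a dataloader batch: target/source images,
    stacked intrinsics, the source-to-target pose (taken from the ground-truth
    extrinsics or estimated with `predict_pose`) and the sparse depth hints
    reprojected onto the target frame (sampled from a dense depth map or taken
    from a point cloud when `hints_from_pcd` is true).
    """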
34 | # prepare source and target views
35 | image_target = batch["image"]
36 | image_source = batch[f"image{source_postfix}"]
37 | intrinsics = torch.stack(
38 | [batch["intrinsics"], batch[f"intrinsics{source_postfix}"]], 1
39 | )
40 |
41 | # sample hints on their frame
42 | hints_hints = find_hints(batch, hints_postfix, hints_density, hints_from_pcd)
43 | if hints_postfix != source_postfix:
44 | hints_src = find_hints(batch, source_postfix, hints_density, hints_from_pcd)
45 | else:
46 | hints_src = hints_hints
47 |
48 | # prepare pose
49 | if predict_pose is None:
50 | pose_src_tgt = (
51 | inv_pose(batch["extrinsics"]) @ batch[f"extrinsics{source_postfix}"]
52 | )
53 | pose_hints_tgt = (
54 | inv_pose(batch["extrinsics"]) @ batch[f"extrinsics{hints_postfix}"]
55 | )
56 | if pose_noise_std_mult > 0.0:
57 | if source_postfix == hints_postfix:
58 | trasl_std = pose_noise_std_mult * pose_src_tgt[:, :3, -1].abs()
59 | rot_std = pose_noise_std_mult * pose_to_angles(pose_src_tgt).abs()
60 | pnoise = pose_noise(rot_std, trasl_std)
61 | pose_src_tgt = pnoise @ pose_src_tgt
62 | pose_hints_tgt = pnoise @ pose_hints_tgt
63 | else:
64 | trasl_std = pose_noise_std_mult * pose_src_tgt[:, :3, -1].abs()
65 | rot_std = pose_noise_std_mult * pose_to_angles(pose_src_tgt).abs()
66 | pnoise = pose_noise(rot_std, trasl_std)
67 | pose_src_tgt = pnoise @ pose_src_tgt
68 |
69 | trasl_std = pose_noise_std_mult * pose_hints_tgt[:, :3, -1].abs()
70 | rot_std = pose_noise_std_mult * pose_to_angles(pose_hints_tgt).abs()
71 | pnoise = pose_noise(rot_std, trasl_std)
72 | pose_hints_tgt = pnoise @ pose_hints_tgt
73 |
74 | else:
75 | pose_hints_tgt = predict_pose(
76 | image_target,
77 | batch[f"image{hints_postfix}"],
78 | hints_hints,
79 | batch["intrinsics"],
80 | batch[f"intrinsics{hints_postfix}"],
81 | )
82 |
83 | if hints_postfix != source_postfix:
84 | pose_src_tgt = predict_pose(
85 | image_target,
86 | batch[f"image{source_postfix}"],
87 | hints_src,
88 | batch["intrinsics"],
89 | batch[f"intrinsics{source_postfix}"],
90 | )
91 | else:
92 | pose_src_tgt = pose_hints_tgt
93 |
94 | # project hints
95 | if not hints_from_pcd:
96 | h, w = image_target.shape[-2:]
97 | hints, _ = project_depth(
98 | hints_hints.to(torch.float32),
99 | batch[f"intrinsics{hints_postfix}"].to(torch.float32),
100 | batch[f"intrinsics"].to(torch.float32),
101 | pose_hints_tgt.to(torch.float32),
102 | torch.zeros_like(hints_hints, dtype=torch.float32),
103 | )
104 | else:
105 | b, _, h, w = image_target.shape
106 | device, dtype = image_target.device, image_target.dtype
107 | hints = torch.zeros(b, 1, h, w, device=device, dtype=dtype)
108 | project_pcd(
109 | batch[f"hints_pcd{hints_postfix}"],
110 | batch[f"intrinsics"],
111 | hints,
112 | pose_hints_tgt,
113 | )
114 |
115 | return {
116 | "target": image_target,
117 | "source": image_source,
118 | "pose_src_tgt": pose_src_tgt,
119 | "intrinsics": intrinsics,
120 | "hints": hints,
121 | "source_hints": hints_src,
122 | "hints_hints": hints_hints,
123 | }
124 |
125 |
126 | def prepare_input_onnx(
127 | batch,
128 | hints_density: float | int = 500,
129 | hints_from_pcd: bool = False,
130 | source_postfix: str = "",
131 | hints_postfix: str = "",
132 | predict_pose: Callable | None = None,
133 | pose_noise_std_mult: float = 0.0,
134 | ):
135 |
136 | # take the useful base outputs
137 | base_input = prepare_input(
138 | batch,
139 | hints_density,
140 | hints_from_pcd,
141 | source_postfix,
142 | hints_postfix,
143 | predict_pose,
144 | pose_noise_std_mult,
145 | )
146 | out = {
147 | "target": base_input["target"],
148 | "source": base_input["source"],
149 | "pose_src_tgt": base_input["pose_src_tgt"],
150 | "intrinsics": base_input["intrinsics"],
151 | }
152 |
153 | # compute the other outputs required by the onnx model
154 | b, _, h, w = base_input["hints_hints"].shape
155 | device = base_input["hints_hints"].device
156 | hints8, subpix8 = project_depth(
157 | base_input["hints_hints"],
158 | batch["intrinsics"],
159 | adjust_intrinsics(batch["intrinsics"], 1 / 8),
160 | torch.eye(4, device=device)[None],
161 | torch.zeros(
162 | 1, 1, h // 8, w // 8, device=device, dtype=base_input["hints_hints"].dtype
163 | ),
164 | )
165 | init_depth = create_init_depth(hints8)
166 | out |= {
167 | "hints8": hints8,
168 | "subpix8": subpix8,
169 | "init_depth": init_depth,
170 | }
171 | return out, base_input
172 |
173 |
174 | def adjust_intrinsics(intrinsics: Tensor, factor: float = 0.125):
175 | intrinsics = intrinsics.clone()
176 | intrinsics[..., :2, :] = intrinsics[..., :2, :] * factor
177 | return intrinsics
178 |
179 |
180 | def create_init_depth(hints8, fallback_init_depth=2.0, min_hints_for_init=20):
181 | batch_size = hints8.shape[0]
182 | mean_hints = (
183 | torch.ones(batch_size, dtype=hints8.dtype, device=hints8.device)
184 | * fallback_init_depth
185 | )
186 | for bi in range(batch_size):
187 | hints_mask = hints8[bi, 0] > 0
188 | if hints_mask.sum() > min_hints_for_init:
189 | mean_hints[bi] = hints8[bi, 0][hints8[bi, 0] > 0].mean()
190 | depth = torch.ones_like(hints8) * mean_hints[:, None, None, None]
191 | depth = torch.where(hints8 > 0, hints8, depth)
192 | return depth
193 |
194 |
195 | def find_hints(
196 | batch,
197 | postfix,
198 | hints_density,
199 | from_pcd: bool = False,
200 | project_pose: Tensor | None = None,
201 | ):
202 | if not from_pcd:
203 | if f"hints{postfix}" in batch:
204 | sparse_hints = batch[f"hints{postfix}"]
205 | else:
206 | sparse_hints = sparsify_depth(batch[f"depth{postfix}"], hints_density)
207 | return sparse_hints
208 | else:
209 | hints_pcd = batch[f"hints_pcd{postfix}"]
210 | device, dtype = hints_pcd.device, hints_pcd.dtype
211 | b, _, h, w = batch["image"].shape
212 | sparse_hints = project_pcd(
213 | batch[f"hints_pcd{postfix}"],
214 | batch["intrinsics"],
215 | torch.zeros(b, 1, h, w, device=device, dtype=dtype),
216 | project_pose,
217 | )
218 | return sparse_hints
219 |
220 |
221 | def sparsify_depth(
222 | depth: torch.Tensor, hints_perc: float | tuple[float, float] | int = 0.03
223 | ) -> torch.Tensor:
224 | if isinstance(hints_perc, tuple | list):
225 | hints_perc = random.uniform(*hints_perc)
226 |
227 | if hints_perc < 1.0:
228 | sparse_map = torch.rand_like(depth) < hints_perc
229 | sparse_depth = torch.where(
230 | sparse_map, depth, torch.tensor(0.0, dtype=depth.dtype, device=depth.device)
231 | )
232 | return sparse_depth
233 | else:
234 | b = depth.shape[0]
235 | idxs = torch.nonzero(depth[:, 0])
236 | idxs = idxs[torch.randperm(len(idxs))]
237 | sparse_depth = torch.zeros_like(depth)
238 | for bi in range(b):
239 | bidxs = idxs[idxs[:, 0] == bi][:hints_perc]
240 | sparse_depth[bi, 0, bidxs[:, 1], bidxs[:, 2]] = depth[
241 | bi, 0, bidxs[:, 1], bidxs[:, 2]
242 | ]
243 | return sparse_depth
244 |
245 |
246 | def inv_pose(pose: torch.Tensor) -> torch.Tensor:
247 | rot_inv = pose[:, :3, :3].permute(0, 2, 1)
248 | tr_inv = -rot_inv @ pose[:, :3, -1:]
249 | pose_inv = torch.eye(4, dtype=pose.dtype, device=pose.device)[None]
250 | pose_inv = pose_inv.repeat(pose.shape[0], 1, 1)
251 | pose_inv[:, :3, :3] = rot_inv
252 | pose_inv[:, :3, -1:] = tr_inv
253 | return pose_inv
254 |
255 |
256 | def pose_distance(
257 | pose: torch.Tensor,
258 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
259 | rot = pose[:, :3, :3]
260 | trasl = pose[:, :3, 3]
261 | R_trace = rot.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1)
262 | r_measure = torch.sqrt(
263 | 2 * (1 - torch.minimum(torch.ones_like(R_trace) * 3.0, R_trace) / 3)
264 | )
265 | t_measure = torch.norm(trasl, dim=1)
266 | combined_measure = torch.sqrt(t_measure**2 + r_measure**2)
267 |
268 | return combined_measure, r_measure, t_measure
269 |
270 |
271 | def project_pcd(
272 | xyz_pcd: torch.Tensor,
273 | intrinsics_to: torch.Tensor,
274 | depth_to: torch.Tensor,
275 | extrinsics_from_to: torch.Tensor | None = None,
276 | ) -> torch.Tensor:
277 | # transform pcd
278 | batch, _, _ = xyz_pcd.shape
279 | if extrinsics_from_to is None:
280 | extrinsics_from_to = torch.eye(4, dtype=xyz_pcd.dtype, device=xyz_pcd.device)[
281 | None
282 | ].repeat(batch, 1, 1)
283 | xyz_pcd_to = (
284 | intrinsics_to
285 | @ (
286 | extrinsics_from_to
287 | @ torch.nn.functional.pad(xyz_pcd, [0, 1], "constant", 1.0).permute(0, 2, 1)
288 | )[:, :3]
289 | )
290 |
291 | # project depth to 2D
292 | h_to, w_to = depth_to.shape[-2:]
293 | u, v = torch.unbind(xyz_pcd_to[:, :2] / xyz_pcd_to[:, -1:], 1)
294 | u, v = torch.round(u).to(torch.long), torch.round(v).to(torch.long)
295 | mask = (u >= 0) & (v >= 0) & (u < w_to) & (v < h_to)
296 |
297 | for b in range(batch):
298 | used_mask = mask[b]
299 | used_u, used_v = u[b, used_mask], v[b, used_mask]
300 | prev_depths = depth_to[b, 0, used_v, used_u]
301 | new_depths = xyz_pcd_to[b, :, used_mask][-1]
302 | merged_depths = torch.where(
303 | (prev_depths == 0) & (new_depths > 0), new_depths, prev_depths
304 | )
305 | depth_to[b, 0, used_v, used_u] = merged_depths
306 |
307 | return depth_to
308 |
309 |
310 | def project_depth(
311 | depth: Tensor,
312 | intrinsics_from: Tensor,
313 | intrinsics_to: Tensor,
314 | extrinsics_from_to: Tensor,
315 | depth_to: Tensor,
316 | ) -> tuple[Tensor, Tensor]:
317 | # project the depth in 3D
318 | batch, _, h, w = depth.shape
319 | xyz_pcd_from = kornia.geometry.depth_to_3d(depth, intrinsics_from).permute(
320 | 0, 2, 3, 1
321 | )
322 | xyz_pcd_to = intrinsics_to @ (
323 | (
324 | extrinsics_from_to
325 | @ torch.nn.functional.pad(
326 | xyz_pcd_from.view(batch, -1, 3), [0, 1], "constant", 1.0
327 | ).permute(0, 2, 1)
328 | )[:, :3]
329 | )
330 | xyz_pcd_to = xyz_pcd_to.permute(0, 2, 1).view(batch, h, w, 3)
331 |
332 | # project depth to 2D
333 | h_to, w_to = depth_to.shape[-2:]
334 | u_subpix, v_subpix = torch.unbind(xyz_pcd_to[..., :2] / xyz_pcd_to[..., -1:], -1)
335 | u, v = torch.round(u_subpix).to(torch.long), torch.round(v_subpix).to(torch.long)
336 | mask = (u >= 0) & (v >= 0) & (u < w_to) & (v < h_to) & (depth[:, 0] > 0)
337 | subpix = torch.zeros(
338 | batch, 2, h_to, w_to, dtype=depth_to.dtype, device=depth_to.device
339 | )
340 | for b in range(batch):
341 | used_mask = mask[b]
342 | used_u, used_v = u[b, used_mask], v[b, used_mask]
343 | prev_depths = depth_to[b, 0, used_v, used_u]
344 | new_depths = xyz_pcd_to[b, used_mask][:, -1]
345 | merged_depths = torch.where(prev_depths == 0, new_depths, prev_depths)
346 | depth_to[b, 0, used_v, used_u] = merged_depths
347 | subpix[b, 0, used_v, used_u] = u_subpix[b, used_mask] / w_to
348 | subpix[b, 1, used_v, used_u] = v_subpix[b, used_mask] / h_to
349 |
350 | return depth_to, subpix
351 |
352 |
353 | def prepare_plot(
354 | gt_depth,
355 | pred_depth,
356 | image_target,
357 | image_source,
358 | hints,
359 | gt_dilation: int | None = None,
360 | hints_dilation: int | None = None,
361 | depth_vmin: float = 0.0,
362 | depth_vmax: float = 10.0,
363 | ) -> GridFigure:
364 | b, _, h, w = image_source.shape
365 | grid = GridFigure(b, 4, size=(h // 2, w // 2))
366 |
367 | if hints_dilation:
368 | hints = dilation(
369 | hints, torch.ones((hints_dilation, hints_dilation), device=hints.device)
370 | )
371 | if gt_dilation:
372 | gt_depth = dilation(
373 | gt_depth, torch.ones((gt_dilation, gt_dilation), device=gt_depth.device)
374 | )
375 |
376 | dkwargs = {
377 | "vmin": depth_vmin,
378 | "vmax": depth_vmax,
379 | "cmap": "magma_r",
380 | }
381 |
382 | idx = 0
383 | for bi in range(b):
384 |
385 | # source view
386 | grid.imshow(idx := idx + 1, image_source[bi])
387 |
388 | # target view + hints
389 | img_target = image_target[bi]
390 | img_target = (img_target - img_target.min()) / (
391 | img_target.max() - img_target.min()
392 | )
393 | img_target = (
394 | (img_target.cpu().numpy() * 255).astype(np.uint8).transpose([1, 2, 0])
395 | )
396 | grid.imshow(idx := idx + 1, img_target, norm=False)
397 | hints_color = color_depth(hints[bi, 0].cpu().numpy(), **dkwargs)
398 | hints_color[..., :3] = hints_color[..., :3] / hints_color[..., :3].max()
399 | hints_color[..., -1:] = hints_color[..., -1:] / hints_color[..., -1:].max()
400 | grid.imshow(idx, hints_color, norm=False)
401 |
402 | # prediction
403 | grid.imshow(idx := idx + 1, pred_depth[bi], **dkwargs, norm=False)
404 |
405 | # gt view
406 | gt_depth_imshow = torch.where(gt_depth[bi] > 0, gt_depth[bi], depth_vmax)
407 | grid.imshow(idx := idx + 1, gt_depth_imshow, **dkwargs, norm=False)
408 |
409 | return grid
410 |
411 |
412 | def pose_to_angles(pose: torch.Tensor):
413 | out_angles = []
414 | angles = euler_from_quaternion(
415 | *rotation_matrix_to_quaternion(
416 | pose[:, :3, :3].contiguous(), order=QuaternionCoeffOrder.WXYZ
417 | )[0]
418 | )
419 | angles = torch.stack(angles, 0)
420 | out_angles.append(angles)
421 | return torch.rad2deg(torch.stack(out_angles, 0))
422 |
423 |
424 | def pose_noise(rot_std: torch.Tensor, trasl_std: torch.Tensor):
425 | std_vec = torch.cat([rot_std, trasl_std], 1)
426 | mean_vec = torch.zeros_like(std_vec)
427 | normal_noise = torch.normal(mean_vec, std_vec)
428 | noises = []
429 | for bi in range(rot_std.shape[0]):
430 | rot_noise = torch.eye(4).to(normal_noise.dtype).to(normal_noise.device)
431 | quat = kornia.geometry.quaternion_from_euler(
432 | torch.deg2rad(normal_noise[bi, 0]),
433 | torch.deg2rad(normal_noise[bi, 1]),
434 | torch.deg2rad(normal_noise[bi, 2]),
435 | )
436 | rot_noise[:3, :3] = kornia.geometry.quaternion_to_rotation_matrix(
437 | torch.stack(quat, -1), QuaternionCoeffOrder.WXYZ
438 | )
439 | rot_noise[:3, -1] += normal_noise[bi, 3:]
440 | noises.append(rot_noise)
441 | return torch.stack(noises, 0)
442 |
443 |
444 | def find_parameters():
445 | ipy = get_ipython() # type: ignore
446 | out = io.StringIO()
447 | with redirect_stdout(out):
448 | ipy.magic("history")
449 | x = out.getvalue().split("\n")
450 | param_lines = list(
451 | itertools.takewhile(
452 | lambda s: s != "##% END",
453 | itertools.dropwhile(lambda s: s != "##% PARAMETERS", x),
454 | )
455 | )
456 | params = []
457 | for param in param_lines:
458 | if match := re.match("([a-zA-Z_0-9]+).*=", param):
459 | name = match.groups()[0]
460 | params.append(name)
461 | return params
462 |
--------------------------------------------------------------------------------
/evaluate/utils/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Metrics used in this project
3 | """
4 |
5 | import torch
6 | from torch import Tensor
7 |
8 | __all__ = ["mae", "rmse", "sq_rel", "abs_rel", "rel_thresh", "compute_metrics_depth"]
9 |
10 |
11 | def mae(pred: Tensor, gt: Tensor) -> Tensor:
12 | return torch.mean(torch.abs(pred - gt))
13 |
14 |
15 | def rmse(pred: Tensor, gt: Tensor) -> Tensor:
16 | return torch.sqrt(torch.mean(torch.square(pred - gt)))
17 |
18 |
19 | def sq_rel(pred: Tensor, gt: Tensor, eps: float = 0.0) -> Tensor:
20 | return torch.mean(torch.square(pred - gt) / (gt + eps))
21 |
22 |
23 | def abs_rel(pred: Tensor, gt: Tensor, eps: float = 0.0) -> Tensor:
24 | return torch.mean(torch.abs(pred - gt) / (gt + eps))
25 |
26 |
27 | def rel_thresh(pred: Tensor, gt: Tensor, sigma: float) -> Tensor:
28 | rel = torch.maximum(gt / pred, pred / gt) < sigma
29 | rel = torch.mean(rel.float())
30 | return rel
31 |
32 |
33 | def compute_metrics_depth(
34 | pred: Tensor, gt: Tensor, label: str = ""
35 | ) -> dict[str, Tensor]:
36 | label = label if not label else label + "/"
37 | return {
38 | f"{label}mae": mae(pred, gt),
39 | f"{label}rmse": rmse(pred, gt),
40 | f"{label}sq_rel": sq_rel(pred, gt),
41 | f"{label}abs_rel": abs_rel(pred, gt),
42 | f"{label}rel_thresh_1.05": rel_thresh(pred, gt, 1.05),
43 | f"{label}rel_thresh_1.25": rel_thresh(pred, gt, 1.25),
44 | }
45 |
--------------------------------------------------------------------------------
/evaluate/utils/pose_estimation.py:
--------------------------------------------------------------------------------
1 | """
2 | Pose estimation between an RGBD frame and an RGB frame by means of
3 | LightGlue matches and PnP + LO-RANSAC
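
Example (mirroring evaluate/metrics.py)::

    estimator = PoseEstimation(dilate_depth=3).to("cuda:0")
    pose = estimator(image0, image1, depth1, intrinsics0, intrinsics1)
    # `pose` maps points from the frame of `image1`/`depth1` into the frame of `image0`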
4 | """
5 |
6 | import torch
7 | from torch import Tensor
8 | from lightglue import DISK, LightGlue
9 | import poselib
10 | import numpy as np
11 | from kornia.morphology import dilation
12 |
13 | __all__ = ["PoseEstimation"]
14 |
15 |
16 | class PoseEstimation(torch.nn.Module):
17 | def __init__(
18 | self,
19 | max_num_keypoints: int | None = None,
20 | dilate_depth: int | None = None,
21 | compile: bool = False,
22 | **kwargs,
23 | ):
24 | super().__init__()
25 | self.extractor = DISK(max_num_keypoints=max_num_keypoints).eval()
26 | self.matcher = LightGlue(features="disk").eval()
27 | if compile:
28 | self.matcher.compile()
29 | self.dilate_depth = dilate_depth
30 |
31 | # kwargs default
32 | if "ransac_max_reproj_error" not in kwargs:
33 | kwargs["ransac_max_reproj_error"] = 1.0
34 |
35 | self.kwargs = kwargs
36 |
37 | @torch.autocast("cuda", enabled=False)
38 | def forward(
39 | self,
40 | image0: Tensor,
41 | image1: Tensor,
42 | depth1: Tensor,
43 | intrinsics0: Tensor,
44 | intrinsics1: Tensor,
45 | ):
46 | # depth dilation for easier matching
47 | if self.dilate_depth is not None:
48 | depth1 = dilation(
49 | depth1,
50 | torch.ones(
51 | [self.dilate_depth, self.dilate_depth],
52 | device=depth1.device,
53 | dtype=depth1.dtype,
54 | ),
55 | )
56 |
57 | with torch.no_grad():
58 | batch_size = image0.shape[0]
59 | poses = []
60 | for b in range(batch_size):
61 | feats0 = self.extractor.extract(image0[b])
62 | feats1 = self.extractor.extract(image1[b])
63 | matches01 = self.matcher({"image0": feats0, "image1": feats1})
64 | matches01 = matches01["matches"]
65 | matches0 = _to_numpy(feats0["keypoints"][0][matches01[0][:, 0]])
66 | matches1 = _to_numpy(feats1["keypoints"][0][matches01[0][:, 1]])
67 |
68 | depths1 = _to_numpy(depth1)[
69 | b,
70 | 0,
71 | matches1[:, 1].round().astype(int),
72 | matches1[:, 0].round().astype(int),
73 | ]
74 | valid_mask = depths1 > 0
75 | pose = np.eye(4)
76 | pose[:3] = poselib.estimate_absolute_pose(
77 | matches0[valid_mask],
78 | _depth_to_3d(
79 | depths1[valid_mask],
80 | matches1[valid_mask],
81 | _to_numpy(intrinsics1[b]),
82 | ),
83 | {
84 | "model": "SIMPLE_PINHOLE",
85 | "width": image0.shape[-1],
86 | "height": image0.shape[-2],
87 | "params": [
88 | intrinsics0[b, 0, 0].item(), # fx
89 | intrinsics0[b, 0, 2].item(), # cx
90 | intrinsics0[b, 1, 2].item(), # cy
91 | ],
92 | },
93 | {
94 | "_".join(k.split("_")[1:]): v
95 | for k, v in self.kwargs.items()
96 | if k.startswith("ransac_")
97 | },
98 | {
99 | "_".join(k.split("_")[1:]): v
100 | for k, v in self.kwargs.items()
101 | if k.startswith("ransac_")
102 | },
103 | )[0].Rt
104 | poses.append(pose)
105 | return (
106 | torch.from_numpy(np.stack(poses, 0)).to(image0.device).to(image0.dtype)
107 | )
108 |
109 |
110 | # utils
111 |
112 |
113 | def _to_numpy(x: Tensor):
114 | return x.detach().cpu().numpy()
115 |
116 |
117 | def _depth_to_3d(depths: np.ndarray, coords: np.ndarray, intrinsics: np.ndarray):
118 | fx, fy = intrinsics[0, 0], intrinsics[1, 1]
119 | cx, cy = intrinsics[0, 2], intrinsics[1, 2]
120 | x3d = depths * (coords[:, 0] - cx) / fx
121 | y3d = depths * (coords[:, 1] - cy) / fy
122 | return np.stack([x3d, y3d, depths], -1)
123 |
124 |
125 | def _norm(img: Tensor):
126 | return (img - img.min()) / (img.max() - img.min())
127 |
--------------------------------------------------------------------------------
/lib/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreaconti/depth-on-demand/7376cd0f7a67e583b3d0eff86641cfd7af83de67/lib/__init__.py
--------------------------------------------------------------------------------
/lib/benchmark/mesh/__init__.py:
--------------------------------------------------------------------------------
1 | try:
2 | from ._fusion import TSDFFusion
3 | from ._evaluation import mesh_metrics
4 |
5 | __all__ = ["TSDFFusion", "mesh_metrics"]
6 | except ImportError:
7 | import warnings
8 |
9 | warnings.warn("open3d not found, can't import this module")
10 |
--------------------------------------------------------------------------------
/lib/benchmark/mesh/_evaluation.py:
--------------------------------------------------------------------------------
1 | """
2 | Evaluation procedure to compute metrics
3 | on meshes
4 | """
5 |
6 | import open3d as o3d
7 | from pytorch3d.ops import knn_points
8 | import numpy as np
9 | from pathlib import Path
10 | import torch
11 | import warnings
12 |
13 | __all__ = ["mesh_metrics"]
14 |
15 |
16 | def mesh_metrics(
17 | mesh_prediction: o3d.geometry.TriangleMesh | str | Path,
18 | mesh_groundtruth: o3d.geometry.TriangleMesh | str | Path,
19 | world2grid: np.ndarray | str | Path | None = None,
20 | occlusion_mask: np.ndarray | str | Path | None = None,
21 | dist_threshold: float = 0.05,
22 | max_dist: float = 1.0,
23 | num_points_samples: int = 200000,
24 | device: str | torch.device = "cpu",
25 | ) -> dict[str, torch.Tensor]:
26 | """
27 |     Takes as input the predicted and groundtruth meshes and computes a set
28 |     of common metrics between them: accuracy, completion, precision, recall
29 |     and their F1 score. To compute fair metrics, the predicted mesh can be
30 |     masked to remove areas not available in the groundtruth by providing
31 |     both the occlusion mask and the world2grid parameters.
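
    Example (paths are illustrative)::

        scores = mesh_metrics(
            mesh_prediction="scene0720_00_pred.ply",
            mesh_groundtruth="scene0720_00_vh_clean.ply",
            world2grid="scene0720_00_world2grid.txt",
            occlusion_mask="scene0720_00_occlusion_mask.npy",
            device="cuda:0",
        )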
32 | """
33 | # read data if not already done
34 |     assert (
35 |         (occlusion_mask is None and world2grid is None)
36 |         or (occlusion_mask is not None and world2grid is not None)
37 |     ), "occlusion_mask and world2grid must be both provided"
40 | if isinstance(mesh_groundtruth, (str, Path)):
41 | mesh_groundtruth = o3d.io.read_triangle_mesh(str(mesh_groundtruth))
42 | if isinstance(mesh_prediction, (str, Path)):
43 | mesh_prediction = o3d.io.read_triangle_mesh(str(mesh_prediction))
44 | if world2grid is not None and isinstance(world2grid, (str, Path)):
45 | world2grid = np.loadtxt(world2grid, dtype=np.float32)
46 | if occlusion_mask is not None and isinstance(occlusion_mask, (str, Path)):
47 | occlusion_mask = np.load(occlusion_mask).astype(np.float32)
48 |
49 | # compute gt -> pred distance
50 | points_pred = torch.from_numpy(
51 | np.asarray(
52 | mesh_prediction.sample_points_uniformly(num_points_samples).points,
53 | dtype=np.float32,
54 | )
55 | )
56 | points_pred = points_pred.to(device)
57 | points_gt = torch.from_numpy(
58 | np.asarray(mesh_groundtruth.vertices, dtype=np.float32)
59 | ).to(device)
60 | world2grid = torch.from_numpy(world2grid).to(device)
61 | occlusion_mask = torch.from_numpy(occlusion_mask).to(device).float()
62 | gt2pred_dist = chamfer_distance(points_gt[None], points_pred[None], max_dist)
63 |
64 | # compute pred -> gt distance
65 | points_pred_filtered = filter_occluded_points(
66 | points_pred, world2grid, occlusion_mask, device
67 | )
68 | pred2gt_dist = chamfer_distance(
69 | points_pred_filtered[None], points_gt[None], max_dist
70 | )
71 |
72 | # compute metrics
73 | accuracy = torch.mean(pred2gt_dist)
74 | completion = torch.mean(gt2pred_dist)
75 | precision = (pred2gt_dist <= dist_threshold).float().mean()
76 | recall = (gt2pred_dist <= dist_threshold).float().mean()
77 | f1_score = 2 * precision * recall / (precision + recall)
78 | chamfer = 0.5 * (accuracy + completion)
79 |
80 | return {
81 | "precision": precision,
82 | "recall": recall,
83 | "f1_score": f1_score,
84 | "accuracy": accuracy,
85 | "completion": completion,
86 | "chamfer": chamfer,
87 | }
88 |
89 |
90 | def chamfer_distance(points1, points2, max_dist):
91 |     # knn_points returns squared L2 distances: take the sqrt and clamp to max_dist
92 |     sq_dists = knn_points(points1, points2).dists
93 |     dists = torch.where(sq_dists > 0, torch.sqrt(sq_dists), 0.0)
94 |     return torch.minimum(dists, torch.full_like(dists, max_dist))
95 |
96 |
97 | def filter_occluded_points(points_pred, world2grid, occlusion_mask, device):
98 | dim_x = occlusion_mask.shape[0]
99 | dim_y = occlusion_mask.shape[1]
100 | dim_z = occlusion_mask.shape[2]
101 | num_points_pred = points_pred.shape[0]
102 |
103 | # Transform points to bbox space.
104 | R_world2grid = world2grid[:3, :3].view(1, 3, 3).expand(num_points_pred, -1, -1)
105 | t_world2grid = world2grid[:3, 3].view(1, 3, 1).expand(num_points_pred, -1, -1)
106 |
107 | points_pred_coords = (
108 | torch.matmul(R_world2grid, points_pred.view(num_points_pred, 3, 1))
109 | + t_world2grid
110 | ).view(num_points_pred, 3)
111 |
112 | # Normalize to [-1, 1]^3 space.
113 | # The world2grid transforms world positions to voxel centers, so we need to
114 | # use "align_corners=True".
115 | points_pred_coords[:, 0] /= dim_x - 1
116 | points_pred_coords[:, 1] /= dim_y - 1
117 | points_pred_coords[:, 2] /= dim_z - 1
118 | points_pred_coords = points_pred_coords * 2 - 1
119 |
120 | # Trilinearly interpolate occlusion mask.
121 | # Occlusion mask is given as (x, y, z) storage, but the grid_sample method
122 | # expects (c, z, y, x) storage.
123 | visibility_mask = 1 - occlusion_mask.view(dim_x, dim_y, dim_z)
124 | visibility_mask = visibility_mask.permute(2, 1, 0).contiguous()
125 | visibility_mask = visibility_mask.view(1, 1, dim_z, dim_y, dim_x)
126 |
127 | points_pred_coords = points_pred_coords.view(1, 1, 1, num_points_pred, 3)
128 |
129 | points_pred_visibility = torch.nn.functional.grid_sample(
130 | visibility_mask,
131 | points_pred_coords,
132 | mode="bilinear",
133 | padding_mode="zeros",
134 | align_corners=True,
135 | ).to(device)
136 |
137 | points_pred_visibility = points_pred_visibility.view(num_points_pred)
138 |
139 | eps = 1e-5
140 | points_pred_visibility = points_pred_visibility >= 1 - eps
141 |
142 | # Filter occluded predicted points.
143 | if points_pred_visibility.sum() == 0:
144 | # If no points are visible, we keep the original points, otherwise
145 | # we would penalize the sample as if nothing is predicted.
146 | warnings.warn(
147 | "All points occluded, keeping all predicted points!", RuntimeWarning
148 | )
149 | points_pred_visible = points_pred.clone()
150 | else:
151 | points_pred_visible = points_pred[points_pred_visibility]
152 |
153 | return points_pred_visible
154 |
--------------------------------------------------------------------------------
/lib/benchmark/mesh/_fusion.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities to build meshes by RGBD frames
3 | """
4 |
5 | import open3d as o3d
6 | import numpy as np
7 | import imageio.v3 as imageio
8 | from typing import Literal
9 | from pathlib import Path
10 |
11 | __all__ = ["TSDFFusion"]
12 |
13 |
14 | class TSDFFusion:
15 | """
16 |     Fusion utility to build a mesh via TSDF integration from multiple
17 |     calibrated and posed RGBD frames.
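
    Typical usage (mirroring evaluate/metrics.py)::

        tsdf = TSDFFusion(engine="open3d", voxel_length=0.04, sdf_trunc=0.12)
        tsdf.integrate_rgbd(image, depth, intrinsics, extrinsics)  # once per frame
        mesh = tsdf.triangle_mesh()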
18 | """
19 |
20 | def __init__(
21 | self,
22 | engine: Literal["open3d", "open3d-tensor"] = "open3d",
23 | # volume info
24 | voxel_length=0.04,
25 | sdf_trunc=3 * 0.04,
26 | # defaults
27 | color_type: o3d.pipelines.integration.TSDFVolumeColorType = o3d.pipelines.integration.TSDFVolumeColorType.RGB8,
28 | depth_scale: float = 1.0,
29 | depth_trunc: float = 3.0,
30 | convert_rgb_to_intensity: bool = False,
31 | device: str = "cuda:0",
32 | ):
33 | assert engine in ["open3d", "open3d-tensor"]
34 | self.engine = engine
35 | self.device = o3d.core.Device(device.upper())
36 | self.depth_scale = depth_scale
37 | self.depth_trunc = depth_trunc
38 | self.voxel_length = voxel_length
39 | self.convert_rgb_to_intensity = convert_rgb_to_intensity
40 |
41 | if engine == "open3d":
42 | self.volume = o3d.pipelines.integration.ScalableTSDFVolume(
43 | voxel_length=voxel_length,
44 | sdf_trunc=sdf_trunc,
45 | color_type=color_type,
46 | )
47 | elif engine == "open3d-tensor":
48 | self.volume = o3d.t.geometry.VoxelBlockGrid(
49 | attr_names=("tsdf", "weight", "color"),
50 | attr_dtypes=(o3d.core.float32, o3d.core.float32, o3d.core.float32),
51 | attr_channels=((1,), (1,), (3,)),
52 | voxel_size=voxel_length,
53 | device=self.device,
54 | )
55 |
56 | def convert_intrinsics(
57 | self,
58 | intrinsics: np.ndarray,
59 | height: int | None = None,
60 | width: int | None = None,
61 | ) -> o3d.camera.PinholeCameraIntrinsic | o3d.core.Tensor:
62 | intrins = o3d.camera.PinholeCameraIntrinsic(
63 | width if width else int(intrinsics[0, 2] * 2),
64 | height if height else int(intrinsics[1, 2] * 2),
65 | intrinsics,
66 | )
67 | if self.engine == "open3d-tensor":
68 | intrins = o3d.core.Tensor(intrins.intrinsic_matrix, o3d.core.Dtype.Float64)
69 | return intrins
70 |
71 | def read_rgbd(
72 | self,
73 | image: str | Path,
74 | depth: str | Path,
75 | depth_trunc: float | None = None,
76 | convert_rgb_to_intensity: bool | None = None,
77 | ) -> o3d.geometry.RGBDImage | tuple[o3d.core.Tensor, o3d.core.Tensor]:
78 | if self.engine == "open3d":
79 | return o3d.geometry.RGBDImage.create_from_color_and_depth(
80 |                 o3d.geometry.Image(imageio.imread(image)),
81 |                 o3d.geometry.Image(imageio.imread(depth)),
82 | depth_trunc=self.depth_trunc if depth_trunc is None else depth_trunc,
83 | convert_rgb_to_intensity=self.convert_rgb_to_intensity
84 | if convert_rgb_to_intensity is None
85 | else convert_rgb_to_intensity,
86 | )
87 | else:
88 | color = o3d.t.io.read_image(image).to(self.device)
89 | depth = o3d.t.io.read_image(depth).to(self.device)
90 | return color, depth
91 |
92 | def convert_rgbd(
93 | self,
94 | image: np.ndarray,
95 | depth: np.ndarray,
96 | depth_scale: float | None = None,
97 | depth_trunc: float | None = None,
98 | convert_rgb_to_intensity: bool | None = None,
99 | ) -> o3d.geometry.RGBDImage | tuple[o3d.core.Tensor, o3d.core.Tensor]:
100 | if self.engine == "open3d":
101 | if image.dtype != np.uint8:
102 | image = (image - image.min()) / (image.max() - image.min())
103 | image = np.round(image * 255).astype(np.uint8)
104 | image = o3d.geometry.Image(np.ascontiguousarray(image))
105 | depth = o3d.geometry.Image(depth[..., 0])
106 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth(
107 | image,
108 | depth,
109 | depth_scale=self.depth_scale if depth_scale is None else depth_scale,
110 | depth_trunc=self.depth_trunc if depth_trunc is None else depth_trunc,
111 | convert_rgb_to_intensity=self.convert_rgb_to_intensity
112 | if convert_rgb_to_intensity is None
113 | else convert_rgb_to_intensity,
114 | )
115 | return rgbd
116 | else:
117 | image = o3d.t.geometry.Image(
118 | o3d.core.Tensor(
119 | np.ascontiguousarray(image).astype(np.float32), device=self.device
120 | )
121 | )
122 | depth = o3d.t.geometry.Image(
123 | o3d.core.Tensor(depth[..., 0], device=self.device)
124 | )
125 | return image, depth
126 |
127 | def integrate_rgbd(
128 | self,
129 | image: np.ndarray | str | Path,
130 | depth: np.ndarray | str | Path,
131 | intrinsics: np.ndarray,
132 | extrinsics: np.ndarray,
133 | **kwargs,
134 | ):
135 | if isinstance(image, np.ndarray) and isinstance(depth, np.ndarray):
136 | rgbd = self.convert_rgbd(image, depth, **kwargs)
137 | elif isinstance(image, (str, Path)) and isinstance(depth, (str, Path)):
138 | rgbd = self.read_rgbd(image, depth, **kwargs)
139 | else:
140 | raise ValueError(
141 |                 f"unsupported image and depth types ({type(image)}, {type(depth)})"
142 | )
143 |
144 | if self.engine == "open3d":
145 | w, h = rgbd.color.get_max_bound().astype(int)
146 | self.volume.integrate(
147 | rgbd,
148 | self.convert_intrinsics(intrinsics, h, w),
149 | np.linalg.inv(extrinsics),
150 | )
151 | else:
152 | color, depth = rgbd
153 | extrins = o3d.core.Tensor(np.linalg.inv(extrinsics), o3d.core.Dtype.Float64)
154 | intrins = o3d.core.Tensor(intrinsics, o3d.core.Dtype.Float64)
155 | coords = self.volume.compute_unique_block_coordinates(
156 | depth,
157 | intrins,
158 | extrins,
159 | depth_scale=self.depth_scale,
160 | depth_max=self.depth_trunc,
161 | )
162 | self.volume.integrate(
163 | coords,
164 | depth,
165 | color,
166 | intrins,
167 | extrins,
168 | depth_scale=float(self.depth_scale),
169 | depth_max=float(self.depth_trunc),
170 | )
171 |
172 | def write_triangle_mesh(self, path: str | Path):
173 | return o3d.io.write_triangle_mesh(str(path), self.triangle_mesh())
174 |
175 | def triangle_mesh(self):
176 | if self.engine == "open3d":
177 | return self.volume.extract_triangle_mesh()
178 | else:
179 | return self.volume.extract_triangle_mesh().to_legacy()
180 |
181 | def reset(self):
182 | if self.engine == "open3d":
183 | self.volume.reset()
184 | else:
185 | self.volume = o3d.t.geometry.VoxelBlockGrid(
186 | attr_names=("tsdf", "weight", "color"),
187 | attr_dtypes=(o3d.core.float32, o3d.core.float32, o3d.core.float32),
188 | attr_channels=((1,), (1,), (3,)),
189 | voxel_size=self.voxel_length,
190 | device=self.device,
191 | )
192 |
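193 | 
194 | if __name__ == "__main__":
195 |     # Minimal usage sketch (illustrative only, not part of the evaluation
196 |     # pipeline): integrate a single synthetic fronto-parallel frame and
197 |     # extract the resulting mesh. All values below are placeholders.
198 |     h, w = 480, 640
199 |     image = np.full((h, w, 3), 128, dtype=np.uint8)
200 |     depth = np.full((h, w, 1), 2.0, dtype=np.float32)
201 |     intrinsics = np.array([[525.0, 0.0, w / 2], [0.0, 525.0, h / 2], [0.0, 0.0, 1.0]])
202 |     extrinsics = np.eye(4)
203 |     fusion = TSDFFusion(engine="open3d", device="CPU:0")
204 |     fusion.integrate_rgbd(image, depth, intrinsics, extrinsics)
205 |     print(fusion.triangle_mesh())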
--------------------------------------------------------------------------------
/lib/dataset/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Common interface to load the datasets: import the ``load_datamodule`` function,
3 | provide the name of the dataset to load and, optionally, the dataset-specific
4 | arguments (a usage sketch is provided at the end of this module)
5 | """
6 |
7 | import torch
8 | from .scannetv2 import ScanNetV2DataModule
9 | from .sevenscenes import SevenScenesDataModule
10 | from .tartanair import TartanairDataModule
11 | from .kitti import KittiDataModule
12 | from ._utils import list_scans
13 | from lightning import LightningDataModule
14 | import torchvision.transforms.functional as TF
15 |
16 | __all__ = ["load_datamodule", "list_scans"]
17 |
18 |
19 | def load_datamodule(name: str, /, **kwargs) -> LightningDataModule:
20 | match name:
21 | case "scannetv2":
22 | DataModule = ScanNetV2DataModule
23 | case "sevenscenes":
24 | DataModule = SevenScenesDataModule
25 | case "tartanair":
26 | DataModule = TartanairDataModule
27 | case "kitti":
28 | DataModule = KittiDataModule
29 | case other:
30 | raise ValueError(f"dataset {other} not available")
31 |
32 | return DataModule(**kwargs, eval_transform=_Preprocess(name))
33 |
34 |
35 | class _Preprocess:
36 | def __init__(self, dataset: str):
37 | self.dataset = dataset
38 |
39 | def __call__(self, sample: dict) -> dict:
40 | for k in sample:
41 | if "path" not in k:
42 | if "intrinsics" not in k and "extrinsics" not in k and "pcd" not in k:
43 | sample[k] = TF.to_tensor(sample[k])
44 | else:
45 | sample[k] = torch.from_numpy(sample[k])
46 | if self.dataset == "scannetv2":
47 | if "image" in k or "depth" in k:
48 | sample[k] = TF.center_crop(sample[k], [464, 624])
49 | if "intrinsics" in k:
50 | intrins = sample[k].clone()
51 | intrins[0, -1] -= 8
52 | intrins[1, -1] -= 8
53 | sample[k] = intrins
54 | if "image" in k:
55 | sample[k] = TF.normalize(
56 | sample[k], mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
57 | )
58 | return sample
59 |
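60 | 
61 | # Usage sketch (illustrative only): the keyword arguments below follow the
62 | # signatures of the datamodules above, but the data root is a placeholder and
63 | # must point to a dataset prepared with the download scripts.
64 | #
65 | #   dm = load_datamodule("scannetv2", root="data/scannetv2", load_prevs=2)
66 | #   dm.setup("test")
67 | #   for batch in dm.test_dataloader():
68 | #       image, depth = batch["image"], batch["depth"]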
--------------------------------------------------------------------------------
/lib/dataset/_resources/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Put in this folder the assets that you want to track with git
3 | and that you load in your datasets, like json files containing
4 | the splits; then you can load them by means of
5 |
6 | >>> from pathlib import Path
7 | >>> resource_path = Path(__file__).parent / "_resources/"
8 | """
9 |
--------------------------------------------------------------------------------
/lib/dataset/_resources/scannetv2/test_split.txt:
--------------------------------------------------------------------------------
1 | scene0707_00
2 | scene0708_00
3 | scene0709_00
4 | scene0710_00
5 | scene0711_00
6 | scene0712_00
7 | scene0713_00
8 | scene0714_00
9 | scene0715_00
10 | scene0716_00
11 | scene0717_00
12 | scene0718_00
13 | scene0719_00
14 | scene0720_00
15 | scene0721_00
16 | scene0722_00
17 | scene0723_00
18 | scene0724_00
19 | scene0725_00
20 | scene0726_00
21 | scene0727_00
22 | scene0728_00
23 | scene0729_00
24 | scene0730_00
25 | scene0731_00
26 | scene0732_00
27 | scene0733_00
28 | scene0734_00
29 | scene0735_00
30 | scene0736_00
31 | scene0737_00
32 | scene0738_00
33 | scene0739_00
34 | scene0740_00
35 | scene0741_00
36 | scene0742_00
37 | scene0743_00
38 | scene0744_00
39 | scene0745_00
40 | scene0746_00
41 | scene0747_00
42 | scene0748_00
43 | scene0749_00
44 | scene0750_00
45 | scene0751_00
46 | scene0752_00
47 | scene0753_00
48 | scene0754_00
49 | scene0755_00
50 | scene0756_00
51 | scene0757_00
52 | scene0758_00
53 | scene0759_00
54 | scene0760_00
55 | scene0761_00
56 | scene0762_00
57 | scene0763_00
58 | scene0764_00
59 | scene0765_00
60 | scene0766_00
61 | scene0767_00
62 | scene0768_00
63 | scene0769_00
64 | scene0770_00
65 | scene0771_00
66 | scene0772_00
67 | scene0773_00
68 | scene0774_00
69 | scene0775_00
70 | scene0776_00
71 | scene0777_00
72 | scene0778_00
73 | scene0779_00
74 | scene0780_00
75 | scene0781_00
76 | scene0782_00
77 | scene0783_00
78 | scene0784_00
79 | scene0785_00
80 | scene0786_00
81 | scene0787_00
82 | scene0788_00
83 | scene0789_00
84 | scene0790_00
85 | scene0791_00
86 | scene0792_00
87 | scene0793_00
88 | scene0794_00
89 | scene0795_00
90 | scene0796_00
91 | scene0797_00
92 | scene0798_00
93 | scene0799_00
94 | scene0800_00
95 | scene0801_00
96 | scene0802_00
97 | scene0803_00
98 | scene0804_00
99 | scene0805_00
100 | scene0806_00
--------------------------------------------------------------------------------
/lib/dataset/_resources/sevenscenes/test_split.txt:
--------------------------------------------------------------------------------
1 | chess-seq-01
2 | chess-seq-02
3 | fire-seq-01
4 | fire-seq-02
5 | heads-seq-02
6 | office-seq-01
7 | office-seq-03
8 | pumpkin-seq-03
9 | pumpkin-seq-06
10 | redkitchen-seq-01
11 | redkitchen-seq-07
12 | stairs-seq-02
13 | stairs-seq-06
--------------------------------------------------------------------------------
/lib/dataset/_resources/tartanair/test_split.txt:
--------------------------------------------------------------------------------
1 | abandonedfactory-Easy-P011
2 | abandonedfactory-Hard-P011
3 | abandonedfactory_night-Easy-P013
4 | abandonedfactory_night-Hard-P014
5 | amusement-Easy-P008
6 | amusement-Hard-P007
7 | carwelding-Easy-P007
8 | endofworld-Easy-P009
9 | gascola-Easy-P008
10 | gascola-Hard-P009
11 | hospital-Easy-P036
12 | hospital-Hard-P049
13 | japanesealley-Easy-P007
14 | japanesealley-Hard-P005
15 | neighborhood-Easy-P021
16 | neighborhood-Hard-P017
17 | ocean-Easy-P013
18 | ocean-Hard-P009
19 | office-Hard-P007
20 | office2-Easy-P011
21 | office2-Hard-P010
22 | oldtown-Easy-P007
23 | oldtown-Hard-P008
24 | seasidetown-Easy-P009
25 | seasonsforest-Easy-P011
26 | seasonsforest-Hard-P006
27 | seasonsforest_winter-Easy-P009
28 | seasonsforest_winter-Hard-P018
29 | soulcity-Easy-P012
30 | soulcity-Hard-P009
31 | westerndesert-Easy-P013
32 | westerndesert-Hard-P007
33 |
--------------------------------------------------------------------------------
/lib/dataset/_utils.py:
--------------------------------------------------------------------------------
1 | from pathlib import Path
2 | from typing import Literal, Callable, Type
3 | import torchdata.datapipes as dp
4 | from torch import Tensor
5 | from torch.utils.data import DataLoader
6 | from lightning import LightningDataModule
7 | from collections import defaultdict
8 | import random
9 | import numpy as np
10 |
11 | __all__ = ["GenericDataModule", "LoadDataset"]
12 |
13 |
14 | class GenericDataModule(LightningDataModule):
15 | def __init__(
16 | self,
17 | dataset: str,
18 | # dataset specific
19 | keyframes: str = "standard",
20 | load_prevs: int = 7,
21 | cycle: bool = True,
22 | filter_scans: list[str] | None = None,
23 | val_on_test: bool = False,
24 | split_test_scans_loaders: bool = False,
25 | # dataloader specific
26 | batch_size: int = 1,
27 | shuffle: bool = True,
28 | shuffle_keyframes: bool = False,
29 | order_sources_by_pose: bool = False,
30 | num_workers: int = 8,
31 | pin_memory: bool = True,
32 | load_dataset_cls: Type["LoadDataset"] | None = None,
33 | eval_transform: Callable[[dict], dict] | None = None,
34 | **sample_params,
35 | ):
36 | super().__init__()
37 |
38 | if load_prevs < 0:
39 | raise ValueError(f"0 <= load_prevs, not {load_prevs}")
40 |
41 | self.dataset = dataset
42 | self.keyframes = keyframes
43 | self.batch_size = batch_size
44 | self.shuffle = shuffle
45 | self.shuffle_keyframes = shuffle_keyframes
46 | self.num_workers = num_workers
47 | self.eval_transform = eval_transform
48 | self.cycle = cycle
49 | self.filter_scans = filter_scans
50 | self.load_prevs = load_prevs
51 | self.val_on_test = val_on_test
52 | self.split_test_scans_dataloaders = split_test_scans_loaders
53 | self.order_sources_by_pose = order_sources_by_pose
54 | self.pin_memory = pin_memory
55 | self._load_dataset_cls = load_dataset_cls
56 | self._sample_params = sample_params
57 | self._other_params = {}
58 |
59 | def _filter_traj(self, split) -> str | list[str]:
60 | if self.filter_scans is None:
61 | return split
62 | else:
63 | split_scans = list_scans(self.dataset, self.keyframes, split)
64 | return [traj for traj in self.filter_scans if traj in split_scans]
65 |
66 | def _dataloader(self, split, scan=None, num_workers=None):
67 | dl_builder = self._load_dataset_cls(
68 | self.dataset,
69 | keyframes=self.keyframes,
70 | load_prevs=self.load_prevs,
71 | cycle=False,
72 | batch_size=1,
73 | shuffle=False,
74 | shuffle_keyframes=False,
75 | order_sources_by_pose=self.order_sources_by_pose,
76 | transform=self.eval_transform,
77 | num_workers=self.num_workers if num_workers is None else num_workers,
78 | pin_memory=self.pin_memory,
79 | **self._sample_params,
80 | )
81 | return dl_builder.build_dataloader(
82 | self._filter_traj(split) if scan is None else scan,
83 | )
84 |
85 | def setup(self, stage: str | None = None):
86 | if stage not in ["test", None]:
87 | raise ValueError(f"stage {stage} invalid")
88 |
89 | if stage in ["test", None]:
90 | if not self.split_test_scans_dataloaders:
91 | self._test_dl = self._dataloader("test")
92 | else:
93 | keyframes = build_scan_frames_mapping(
94 | self.dataset,
95 | self.keyframes,
96 | "test",
97 | self.load_prevs,
98 | )
99 | if self.filter_scans:
100 | keyframes = {
101 | k: v for k, v in keyframes.items() if k in self.filter_scans
102 | }
103 | self._test_dl = [
104 | self._dataloader("test", {key: value}, num_workers=1)
105 | for key, value in keyframes.items()
106 | ]
107 |
108 | def test_dataloader(self):
109 | return self._test_dl
110 |
111 |
112 | class LoadDataset:
113 | """
114 |     This class embodies the whole creation of a dataloader; the only method
115 |     to override is `load_sample`, which loads a single sample of a sequence
116 | """
117 |
118 | def __init__(
119 | self,
120 | dataset: str,
121 | # dataset specific
122 | keyframes: Literal["standard", "dense", "offline"] = "standard",
123 | load_prevs: int = 7,
124 | cycle: bool = True,
125 | # dataloader specific
126 | batch_size: int = 1,
127 | shuffle: bool = False,
128 | shuffle_keyframes: bool = False,
129 | order_sources_by_pose: bool = True,
130 | transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None,
131 | num_workers: int = 8,
132 | pin_memory: bool = True,
133 | **kwargs,
134 | ):
135 | self.dataset = dataset
136 | self.keyframes = keyframes
137 | self.load_prevs = load_prevs
138 | self.cycle = cycle
139 | self.batch_size = batch_size
140 | self.shuffle = shuffle
141 | self.shuffle_keyframes = shuffle_keyframes
142 | self.order_sources_by_pose = order_sources_by_pose
143 | self.transform = transform
144 | self.num_workers = num_workers
145 | self.pin_memory = pin_memory
146 | self.sample_kwargs = kwargs
147 |
148 | def build_dataloader(
149 | self,
150 | split: Literal["train", "val", "test"] | str | list[str] | dict = "train",
151 | ):
152 | # checks
153 | if self.load_prevs < 0:
154 | raise ValueError(f"0 <= load_prevs, not {self.load_prevs}")
155 |
156 | # load the keyframe file
157 | if split in ["train", "val", "test"]:
158 | keyframes_path = (
159 | Path(__file__).parent
160 | / f"_resources/{self.dataset}/keyframes/{self.keyframes}/{split}_split.txt"
161 | )
162 | if not keyframes_path.exists():
163 | raise ValueError(
164 | f"split {split} for keyframes {self.keyframes} not available"
165 | )
166 | with open(keyframes_path, "rt") as f:
167 | keyframes_dict = defaultdict(_empty_dict)
168 | for ln in f:
169 | scene, keyframe, *src_frames = ln.split()
170 | keyframes_dict[scene][keyframe] = src_frames[: self.load_prevs]
171 | elif isinstance(split, (str, list)):
172 | split = [split] if isinstance(split, str) else split
173 | all_lines = []
174 | for file_path in (
175 | Path(__file__).parent
176 | / f"_resources/{self.dataset}/keyframes/{self.keyframes}"
177 | ).glob("*_split.txt"):
178 | with open(file_path, "rt") as f:
179 | all_lines.extend(f.readlines())
180 | keyframes_dict = defaultdict(_empty_dict)
181 | for scene in split:
182 | for ln in all_lines:
183 | if ln.startswith(scene):
184 | scene, keyframe, *src_frames = ln.split()
185 | keyframes_dict[scene][keyframe] = src_frames[: self.load_prevs]
186 | elif isinstance(split, dict):
187 | keyframes_dict = split
188 | else:
189 |             raise ValueError(f"split type {type(split)} not allowed")
190 |
191 | # loading and processing pipeline
192 | keyframes_list = []
193 | for scene in keyframes_dict:
194 | keyframes_list.extend(
195 | [
196 | {"scan": scene, "keyframe": keyframe, "sources": sources}
197 | for keyframe, sources in keyframes_dict[scene].items()
198 | ]
199 | )
200 |
201 | ds = dp.map.SequenceWrapper(keyframes_list)
202 | ds = ds.shuffle() if self.shuffle else ds
203 | if self.cycle:
204 | if self.shuffle:
205 | ds = ds.cycle()
206 | else:
207 | ds = dp.iter.IterableWrapper(ds).cycle()
208 | ds = ds.sharding_filter() if self.shuffle or self.cycle else ds
209 | ds = ds.map(self._load)
210 | if self.order_sources_by_pose:
211 | ds = ds.map(order_by_pose)
212 | ds = ds.map(self.transform) if self.transform else ds
213 | return DataLoader(
214 | ds,
215 | batch_size=self.batch_size,
216 | shuffle=self.shuffle,
217 | num_workers=self.num_workers,
218 | pin_memory=self.pin_memory,
219 | )
220 |
221 | def load_sample(self, scan: str, idx: str, suffix: str = "") -> dict:
222 | raise NotImplementedError("please implement load sample")
223 |
224 | def _load(self, sample):
225 | out = {}
226 |
227 | # shuffle keyframes
228 | if self.shuffle_keyframes:
229 | all_frames = [sample["keyframe"]] + sample["sources"]
230 |             keyframe, *sources = random.sample(all_frames, len(all_frames))
231 | sample = {"scan": sample["scan"], "keyframe": keyframe, "sources": sources}
232 |
233 | # load data
234 | out = self.load_sample(sample["scan"], sample["keyframe"], **self.sample_kwargs)
235 | for idx, src in enumerate(sample["sources"]):
236 | out |= self.load_sample(
237 | sample["scan"], src, suffix=f"_prev_{idx}", **self.sample_kwargs
238 | )
239 | return out
240 |
241 |
242 | ####
243 |
244 |
245 | def list_scans(dataset: str, keyframes: str, split: str | None = None):
246 | """
247 | List all the scans used in a specific split of a dataset for a specific
248 | keyframe setting.
249 | """
250 | if split is not None:
251 | keyframes_path = (
252 | Path(__file__).parent
253 | / f"_resources/{dataset}/keyframes/{keyframes}/{split}_split.txt"
254 | )
255 | scans = []
256 | if keyframes_path.exists():
257 | scans = [ln.split()[0] for ln in open(keyframes_path, "rt")]
258 | else:
259 | scans = []
260 | for keyframes_path in (
261 | Path(__file__).parent / f"_resources/{dataset}/keyframes/{keyframes}"
262 | ).glob("*_split.txt"):
263 | scans.extend([ln.split()[0] for ln in open(keyframes_path, "rt")])
264 | return list(np.unique(scans))
265 |
266 |
267 | def _empty_dict():
268 | return {}
269 |
270 |
271 | def build_scan_frames_mapping(
272 | dataset: str,
273 | keyframes: str,
274 | split: str,
275 | load_prevs: int | None = None,
276 | ) -> dict[Path, dict[str, list[str]]]:
277 | """
278 |     Given a dataset, its keyframes setting and a split, builds a dictionary
279 | mapping each scan in such split to a dictionary indexed by the target view and
280 | containing a list of valid source views.
281 | """
282 | keyframes_path = (
283 | Path(__file__).parent
284 | / f"_resources/{dataset}/keyframes/{keyframes}/{split}_split.txt"
285 | )
286 | if not keyframes_path.exists():
287 | raise ValueError(f"split {split} for keyframes {keyframes} not available")
288 | with open(keyframes_path, "rt") as f:
289 | keyframes_dict = defaultdict(_empty_dict)
290 | for ln in f:
291 | scene, keyframe, *src_frames = ln.split()
292 | if load_prevs is not None:
293 | src_frames = src_frames[:load_prevs]
294 | keyframes_dict[scene][keyframe] = src_frames
295 | return keyframes_dict
296 |
297 |
298 | def order_by_pose(ex):
299 | """
300 |     Takes an example and orders its source views according to their distance
301 | from the target view
302 | """
303 | srcs = sorted({int(k.split("prev_")[-1]) for k in ex.keys() if "prev_" in k})
304 | tgt_pose = ex["extrinsics"]
305 | distances = np.array(
306 | [
307 | _pose_distance(_inv_pose(tgt_pose) @ ex[f"extrinsics_prev_{src}"])[0]
308 | for src in srcs
309 | ]
310 | )
311 | order = np.argsort(distances)
312 | out = {k: v for k, v in ex.items() if "_prev_" not in k}
313 | for new_idx, prev_idx in enumerate(order):
314 | for k, v in ex.items():
315 | if f"prev_{prev_idx}" in k:
316 | out[k.replace(f"prev_{prev_idx}", f"prev_{new_idx}")] = v
317 | return out
318 |
319 |
320 | def _inv_pose(pose: np.ndarray):
321 | inverted = np.eye(4)
322 | inverted[:3, :3] = pose[:3, :3].T
323 | inverted[:3, 3:] = -inverted[:3, :3] @ pose[:3, 3:]
324 | return inverted
325 |
326 |
327 | def _pose_distance(pose: np.ndarray):
328 | rot = pose[:3, :3]
329 | trasl = pose[:3, 3]
330 | R_trace = rot.diagonal().sum()
331 | r_measure = np.sqrt(2 * (1 - np.minimum(np.ones_like(R_trace) * 3.0, R_trace) / 3))
332 | t_measure = np.linalg.norm(trasl)
333 | combined_measure = np.sqrt(t_measure**2 + r_measure**2)
334 | return combined_measure, r_measure, t_measure
335 |
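336 | 
337 | if __name__ == "__main__":
338 |     # Illustrative check (not used by the pipeline): each line of a keyframe
339 |     # split file is "<scan> <target_frame> <source_frame> ...", and
340 |     # `order_by_pose` re-indexes the source views so that `_prev_0` becomes the
341 |     # closest one to the target pose. The poses below are synthetic.
342 |     target = np.eye(4, dtype=np.float32)
343 |     far = np.eye(4, dtype=np.float32)
344 |     far[:3, 3] = [0.0, 0.0, 2.0]
345 |     near = np.eye(4, dtype=np.float32)
346 |     near[:3, 3] = [0.0, 0.0, 0.5]
347 |     example = {
348 |         "extrinsics": target,
349 |         "extrinsics_prev_0": far,
350 |         "extrinsics_prev_1": near,
351 |     }
352 |     reordered = order_by_pose(example)
353 |     assert np.allclose(reordered["extrinsics_prev_0"], near)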
--------------------------------------------------------------------------------
/lib/dataset/kitti.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataloader for KITTI
3 | """
4 |
5 | from pathlib import Path
6 | import imageio.v3 as imageio
7 | import numpy as np
8 | from ._utils import GenericDataModule, LoadDataset
9 | import pykitti
10 |
11 |
12 | __all__ = ["KittiDataModule"]
13 |
14 |
15 | class KittiDataModule(GenericDataModule):
16 | def __init__(
17 | self,
18 | *args,
19 | root_raw: str | Path,
20 | root_completion: str | Path,
21 | min_depth: float = 1e-3,
22 | max_depth: float = 80.0,
23 | load_hints_pcd: bool = False,
24 | **kwargs,
25 | ):
26 | # prepare base class
27 | super().__init__(
28 | "kitti",
29 | *args,
30 | root_raw=root_raw,
31 | root_completion=root_completion,
32 | min_depth=min_depth,
33 | max_depth=max_depth,
34 | load_hints_pcd=load_hints_pcd,
35 | load_dataset_cls=KittiLoadSample,
36 | **kwargs,
37 | )
38 |
39 | @staticmethod
40 | def augmenting_targets():
41 | out = {
42 | "image": "image",
43 | "depth": "depth",
44 | "hints": "depth",
45 | "extrinsics": "rt",
46 | "intrinsics": "intrinsics",
47 | "hints_pcd": "pcd",
48 | }
49 | for i in range(20):
50 | out |= {
51 | f"image_prev_{i}": "image",
52 | f"depth_prev_{i}": "depth",
53 | f"hints_prev_{i}": "depth",
54 | f"extrinsics_prev_{i}": "rt",
55 | f"intrinsics_prev_{i}": "intrinsics",
56 | f"hints_pcd_prev_{i}": "pcd",
57 | }
58 | return out
59 |
60 |
61 | class KittiLoadSample(LoadDataset):
62 | def __init__(self, *args, **kwargs):
63 | super().__init__(*args, **kwargs)
64 | self._kitti_meta = {}
65 |
66 | def _get_meta(self, root_raw: str, scan: str):
67 | if scan not in self._kitti_meta:
68 | date = "_".join(scan.split("_")[:3])
69 | scan_id = scan.split("_")[4]
70 | self._kitti_meta[scan] = pykitti.raw(root_raw, date, scan_id)
71 | return self._kitti_meta[scan]
72 |
73 | def load_sample(
74 | self,
75 | scan: str,
76 | idx: str,
77 | root_raw: str | Path,
78 | root_completion: str | Path,
79 | min_depth: float = 1e-3,
80 | max_depth: float = 100.0,
81 | load_hints_pcd: bool = False,
82 | suffix: str = "",
83 | ) -> dict:
84 | root_raw = Path(root_raw)
85 | root_compl = Path(root_completion)
86 | date = "_".join(scan.split("_")[:3])
87 | compl_path = next(root_compl.glob(f"*/{scan}/proj_depth"))
88 |
89 | # img
90 | image = imageio.imread(root_raw / date / scan / f"image_02/data/{idx}.png")
91 |
92 | # hints (lidar projected)
93 | hints = (
94 | imageio.imread(compl_path / f"velodyne_raw/image_02/{idx}.png")[
95 | ..., None
96 | ].astype(np.float32)
97 | / 256.0
98 | )
99 | hints[hints < min_depth] = 0.0
100 | hints[hints > max_depth] = 0.0
101 |
102 | # gt
103 | depth = (
104 | imageio.imread(compl_path / f"groundtruth/image_02/{idx}.png")[
105 | ..., None
106 | ].astype(np.float32)
107 | / 256.0
108 | )
109 | depth[depth < min_depth] = 0.0
110 | depth[depth > max_depth] = 0.0
111 |
112 | # intrinsics
113 | scan_meta = self._get_meta(str(root_raw), scan)
114 | intrinsics = scan_meta.calib.K_cam2.astype(np.float32).copy()
115 |
116 | # extrinsics
117 | extrinsics = (
118 | scan_meta.oxts[int(idx)].T_w_imu
119 | @ np.linalg.inv(scan_meta.calib.T_velo_imu)
120 | @ np.linalg.inv(
121 | scan_meta.calib.R_rect_20 @ scan_meta.calib.T_cam0_velo_unrect
122 | )
123 | ).astype(np.float32)
124 |
125 | # lidar pcd
126 | pcd = None
127 | if load_hints_pcd:
128 | pcd = scan_meta.get_velo(int(idx))[:, :3]
129 | pcd = (
130 | scan_meta.calib.T_cam2_velo
131 | @ np.pad(pcd, [(0, 0), (0, 1)], constant_values=1.0).T
132 | ).T[:, :3]
133 | n_points = pcd.shape[0]
134 | padding = 130000 - n_points
135 | pcd = np.pad(pcd, [(0, padding), (0, 0)]).astype(np.float32)
136 |
137 | # crop frames
138 | h, w, _ = depth.shape
139 | lc = (w - 1216) // 2
140 | rc = w - 1216 - lc
141 | image = image[-256:, lc:-rc]
142 | depth = depth[-256:, lc:-rc]
143 | hints = hints[-256:, lc:-rc]
144 | intrinsics[0, -1] -= lc
145 | intrinsics[1, -1] -= h - 256
146 |
147 | # compose output dict
148 | out = {
149 | f"image{suffix}": image,
150 | f"hints{suffix}": hints,
151 | f"depth{suffix}": depth,
152 | f"intrinsics{suffix}": intrinsics,
153 | f"extrinsics{suffix}": extrinsics,
154 | }
155 | if pcd is not None:
156 | out[f"hints_pcd{suffix}"] = pcd
157 |
158 | return out
159 |
--------------------------------------------------------------------------------
/lib/dataset/scannetv2.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataloader for ScanNetV2
3 | """
4 |
5 | from pathlib import Path
6 | import imageio.v3 as imageio
7 | import numpy as np
8 |
9 | __all__ = ["ScanNetV2DataModule"]
10 |
11 | from ._utils import GenericDataModule, LoadDataset
12 |
13 |
14 | class ScanNetV2DataModule(GenericDataModule):
15 | def __init__(
16 | self,
17 | *args,
18 | root: str | Path,
19 | min_depth: float = 1e-3,
20 | max_depth: float = 10.0,
21 | **kwargs,
22 | ):
23 | super().__init__(
24 | "scannetv2",
25 | *args,
26 | root=root,
27 | min_depth=min_depth,
28 | max_depth=max_depth,
29 | load_dataset_cls=ScanNetV2LoadSample,
30 | **kwargs,
31 | )
32 |
33 |
34 | class ScanNetV2LoadSample(LoadDataset):
35 | def load_sample(
36 | self,
37 | scan: str,
38 | idx: str,
39 | root,
40 | min_depth: float = 1e-3,
41 | max_depth: float = 10.0,
42 | suffix: str = "",
43 | ) -> dict:
44 | root = Path(root)
45 | img = imageio.imread(root / scan / f"{idx}.image.jpg")
46 | depth = (
47 | imageio.imread(root / scan / f"{idx}.depth.png")[..., None] / 1000
48 | ).astype(np.float32)
49 | depth[(depth < min_depth) | (depth > max_depth)] = 0.0
50 | intrinsics = np.load(root / scan / f"{idx}.intrinsics.npy")
51 | extrinsics = np.load(root / scan / f"{idx}.extrinsics.npy")
52 |
53 | out = {}
54 | if not suffix:
55 | ply_path = root / f"{scan}_vh_clean.ply"
56 | occl_path = root / f"{scan}_occlusion_mask.npy"
57 | world2grid_path = root / f"{scan}_world2grid.txt"
58 | if ply_path.exists():
59 | out["gt_mesh_path"] = str(ply_path)
60 | if occl_path.exists():
61 | out["gt_mesh_occl_path"] = str(occl_path)
62 | if world2grid_path.exists():
63 | out["gt_mesh_world2grid_path"] = str(world2grid_path)
64 |
65 | return out | {
66 | f"image{suffix}": img,
67 | f"depth{suffix}": depth,
68 | f"intrinsics{suffix}": intrinsics,
69 | f"extrinsics{suffix}": extrinsics,
70 | }
71 |
--------------------------------------------------------------------------------
/lib/dataset/sevenscenes.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataloader for 7Scenes
3 | """
4 |
5 | from pathlib import Path
6 | import imageio.v3 as imageio
7 | import numpy as np
8 |
9 | __all__ = ["SevenScenesDataModule"]
10 |
11 | from ._utils import GenericDataModule, LoadDataset
12 |
13 |
14 | class SevenScenesDataModule(GenericDataModule):
15 | def __init__(
16 | self,
17 | *args,
18 | root: str | Path,
19 | min_depth: float = 1e-3,
20 | max_depth: float = 10.0,
21 | **kwargs,
22 | ):
23 | super().__init__(
24 | "sevenscenes",
25 | *args,
26 | root=root,
27 | min_depth=min_depth,
28 | max_depth=max_depth,
29 |             load_dataset_cls=SevenScenesLoadSample,
30 | **kwargs,
31 | )
32 |
33 |
34 | class SevenScenesLoadSample(LoadDataset):
35 | def load_sample(
36 | self,
37 | scan: str,
38 | idx: str,
39 | root,
40 | min_depth: float = 1e-3,
41 | max_depth: float = 10.0,
42 | suffix: str = "",
43 | ) -> dict:
44 | root = Path(root)
45 | img = imageio.imread(root / scan / f"{idx}.image.jpg")
46 | depth = (
47 | imageio.imread(root / scan / f"{idx}.depth.png")[..., None] / 1000
48 | ).astype(np.float32)
49 | depth[(depth < min_depth) | (depth > max_depth)] = 0.0
50 | intrinsics = np.load(root / scan / f"{idx}.intrinsics.npy")
51 | extrinsics = np.load(root / scan / f"{idx}.extrinsics.npy")
52 |
53 | return {
54 | f"image{suffix}": img,
55 | f"depth{suffix}": depth,
56 | f"intrinsics{suffix}": intrinsics,
57 | f"extrinsics{suffix}": extrinsics,
58 | }
59 |
--------------------------------------------------------------------------------
/lib/dataset/tartanair.py:
--------------------------------------------------------------------------------
1 | """
2 | Dataloader for TartanAir
3 | """
4 |
5 | from pathlib import Path
6 | import imageio.v3 as imageio
7 | import json
8 | import numpy as np
9 | from ._utils import GenericDataModule, LoadDataset
10 |
11 |
12 | __all__ = ["TartanairDataModule"]
13 |
14 |
15 | class TartanairDataModule(GenericDataModule):
16 | def __init__(
17 | self,
18 | *args,
19 | root: str | Path,
20 | min_depth: float = 1e-3,
21 | max_depth: float = 100.0,
22 | **kwargs,
23 | ):
24 | super().__init__(
25 | "tartanair",
26 | *args,
27 | root=root,
28 | min_depth=min_depth,
29 | max_depth=max_depth,
30 | load_dataset_cls=TartanairLoadSample,
31 | **kwargs,
32 | )
33 |
34 |
35 | class TartanairLoadSample(LoadDataset):
36 | def load_sample(
37 | self,
38 | scan: str,
39 | idx: str,
40 | root: str | Path,
41 | min_depth: float = 1e-3,
42 | max_depth: float = 100.0,
43 | suffix: str = "",
44 | ) -> dict:
45 | root = Path(root)
46 | img = imageio.imread(root / scan / f"{idx}.image.jpg")
47 |
48 | depth = imageio.imread(root / scan / f"{idx}.depth.png")[..., None].astype(
49 | np.float32
50 | )
51 | drange = json.load(open(root / scan / f"{idx}.depth_range.json", "rt"))
52 | vmin, vmax = drange["vmin"], drange["vmax"]
53 | depth = vmin + (depth / 65535) * (vmax - vmin)
54 | depth[depth < min_depth] = 0.0
55 | depth[depth > max_depth] = 0.0
56 |
57 | intrinsics = np.load(root / scan / f"{idx}.intrinsics.npy")
58 | extrinsics = np.load(root / scan / f"{idx}.position.npy")
59 | extrinsics = _ned_to_cam.dot(extrinsics).dot(_ned_to_cam.T)
60 |
61 | return {
62 | f"image{suffix}": img,
63 | f"depth{suffix}": depth,
64 | f"intrinsics{suffix}": intrinsics,
65 | f"extrinsics{suffix}": extrinsics,
66 | }
67 |
68 |
69 | _ned_to_cam = np.array(
70 | [
71 | [0.0, 1.0, 0.0, 0.0],
72 | [0.0, 0.0, 1.0, 0.0],
73 | [1.0, 0.0, 0.0, 0.0],
74 | [0.0, 0.0, 0.0, 1.0],
75 | ],
76 | dtype=np.float32,
77 | )
78 |
--------------------------------------------------------------------------------
/lib/dataset/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import random
3 | import kornia
4 | from torch import Tensor
5 |
6 | __all__ = [
7 | "sparsify_depth",
8 | "project_depth",
9 | "project_pcd",
10 | "inv_pose",
11 | ]
12 |
13 |
14 | def sparsify_depth(
15 | depth: torch.Tensor, hints_perc: float | tuple[float, float] | int = 0.03
16 | ) -> torch.Tensor:
17 | if isinstance(hints_perc, tuple | list):
18 | hints_perc = random.uniform(*hints_perc)
19 |
20 | if hints_perc < 1.0:
21 | sparse_map = torch.rand_like(depth) < hints_perc
22 | sparse_depth = torch.where(
23 | sparse_map, depth, torch.tensor(0.0, dtype=depth.dtype, device=depth.device)
24 | )
25 | return sparse_depth
26 | else:
27 | b = depth.shape[0]
28 | idxs = torch.nonzero(depth[:, 0])
29 | idxs = idxs[torch.randperm(len(idxs))]
30 | sparse_depth = torch.zeros_like(depth)
31 | for bi in range(b):
32 | bidxs = idxs[idxs[:, 0] == bi][:hints_perc]
33 | sparse_depth[bi, 0, bidxs[:, 1], bidxs[:, 2]] = depth[
34 | bi, 0, bidxs[:, 1], bidxs[:, 2]
35 | ]
36 | return sparse_depth
37 |
38 |
39 | def project_depth(
40 | depth: Tensor,
41 | intrinsics_from: Tensor,
42 | intrinsics_to: Tensor,
43 | extrinsics_from_to: Tensor,
44 | depth_to: Tensor,
45 | ) -> Tensor:
46 | # project the depth in 3D
47 | batch, _, h, w = depth.shape
48 | xyz_pcd_from = kornia.geometry.depth_to_3d(depth, intrinsics_from).permute(
49 | 0, 2, 3, 1
50 | )
51 | xyz_pcd_to = intrinsics_to @ (
52 | (
53 | extrinsics_from_to
54 | @ torch.nn.functional.pad(
55 | xyz_pcd_from.view(batch, -1, 3), [0, 1], "constant", 1.0
56 | ).permute(0, 2, 1)
57 | )[:, :3]
58 | )
59 | xyz_pcd_to = xyz_pcd_to.permute(0, 2, 1).view(batch, h, w, 3)
60 |
61 | # project depth to 2D
62 | h_to, w_to = depth_to.shape[-2:]
63 | u, v = torch.unbind(xyz_pcd_to[..., :2] / xyz_pcd_to[..., -1:], -1)
64 | u, v = torch.round(u).to(torch.long), torch.round(v).to(torch.long)
65 | mask = (u >= 0) & (v >= 0) & (u < w_to) & (v < h_to) & (depth[:, 0] > 0)
66 |
67 | for b in range(batch):
68 | used_mask = mask[b]
69 | used_u, used_v = u[b, used_mask], v[b, used_mask]
70 | prev_depths = depth_to[b, 0, used_v, used_u]
71 | new_depths = xyz_pcd_to[b, used_mask][:, -1]
72 | merged_depths = torch.where(prev_depths == 0, new_depths, prev_depths)
73 | depth_to[b, 0, used_v, used_u] = merged_depths
74 |
75 | return depth_to
76 |
77 |
78 | def project_pcd(
79 | xyz_pcd: torch.Tensor,
80 | intrinsics_to: torch.Tensor,
81 | depth_to: torch.Tensor,
82 | extrinsics_from_to: torch.Tensor | None = None,
83 | ) -> torch.Tensor:
84 | # transform pcd
85 | batch, _, _ = xyz_pcd.shape
86 | if extrinsics_from_to is None:
87 | extrinsics_from_to = torch.eye(4, dtype=xyz_pcd.dtype, device=xyz_pcd.device)[
88 | None
89 | ].repeat(batch, 1, 1)
90 | xyz_pcd_to = (
91 | intrinsics_to
92 | @ (
93 | extrinsics_from_to
94 | @ torch.nn.functional.pad(xyz_pcd, [0, 1], "constant", 1.0).permute(0, 2, 1)
95 | )[:, :3]
96 | )
97 |
98 | # project depth to 2D
99 | h_to, w_to = depth_to.shape[-2:]
100 | u, v = torch.unbind(xyz_pcd_to[:, :2] / xyz_pcd_to[:, -1:], 1)
101 | u, v = torch.round(u).to(torch.long), torch.round(v).to(torch.long)
102 | mask = (u >= 0) & (v >= 0) & (u < w_to) & (v < h_to)
103 |
104 | for b in range(batch):
105 | used_mask = mask[b]
106 | used_u, used_v = u[b, used_mask], v[b, used_mask]
107 | prev_depths = depth_to[b, 0, used_v, used_u]
108 | new_depths = xyz_pcd_to[b, :, used_mask][-1]
109 | merged_depths = torch.where(
110 | (prev_depths == 0) & (new_depths > 0), new_depths, prev_depths
111 | )
112 | depth_to[b, 0, used_v, used_u] = merged_depths
113 |
114 | return depth_to
115 |
116 |
117 | def inv_pose(pose: torch.Tensor) -> torch.Tensor:
118 | rot_inv = pose[:, :3, :3].permute(0, 2, 1)
119 | tr_inv = -rot_inv @ pose[:, :3, -1:]
120 | pose_inv = torch.eye(4, dtype=pose.dtype, device=pose.device)[None]
121 | pose_inv = pose_inv.repeat(pose.shape[0], 1, 1)
122 | pose_inv[:, :3, :3] = rot_inv
123 | pose_inv[:, :3, -1:] = tr_inv
124 | return pose_inv
125 |
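126 | 
127 | if __name__ == "__main__":
128 |     # Small synthetic check (illustrative only): `sparsify_depth` keeps roughly
129 |     # the requested fraction of pixels, and `inv_pose` composes with the
130 |     # original pose to the identity.
131 |     depth = torch.rand(1, 1, 48, 64) + 0.5
132 |     sparse = sparsify_depth(depth, hints_perc=0.03)
133 |     print("kept fraction:", (sparse > 0).float().mean().item())
134 |     pose = torch.eye(4)[None].clone()
135 |     pose[:, :3, 3] = torch.tensor([0.1, -0.2, 0.3])
136 |     assert torch.allclose(inv_pose(pose) @ pose, torch.eye(4)[None], atol=1e-6)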
--------------------------------------------------------------------------------
/lib/dataset/utils/metrics.py:
--------------------------------------------------------------------------------
1 | """
2 | Metrics used in this project
3 | """
4 |
5 | import torch
6 | from torch import Tensor
7 |
8 | __all__ = ["mae", "rmse", "sq_rel", "abs_rel", "rel_thresh", "compute_metrics_depth"]
9 |
10 |
11 | def mae(pred: Tensor, gt: Tensor) -> Tensor:
12 | return torch.mean(torch.abs(pred - gt))
13 |
14 |
15 | def rmse(pred: Tensor, gt: Tensor) -> Tensor:
16 | return torch.sqrt(torch.mean(torch.square(pred - gt)))
17 |
18 |
19 | def sq_rel(pred: Tensor, gt: Tensor, eps: float = 0.0) -> Tensor:
20 | return torch.mean(torch.square(pred - gt) / (gt + eps))
21 |
22 |
23 | def abs_rel(pred: Tensor, gt: Tensor, eps: float = 0.0) -> Tensor:
24 | return torch.mean(torch.abs(pred - gt) / (gt + eps))
25 |
26 |
27 | def rel_thresh(pred: Tensor, gt: Tensor, sigma: float) -> Tensor:
28 | rel = torch.maximum(gt / pred, pred / gt) < sigma
29 | rel = torch.mean(rel.float())
30 | return rel
31 |
32 |
33 | def compute_metrics_depth(
34 | pred: Tensor, gt: Tensor, label: str = ""
35 | ) -> dict[str, Tensor]:
36 | label = label if not label else label + "/"
37 | return {
38 | f"{label}mae": mae(pred, gt),
39 | f"{label}rmse": rmse(pred, gt),
40 | f"{label}sq_rel": sq_rel(pred, gt),
41 | f"{label}abs_rel": abs_rel(pred, gt),
42 | f"{label}rel_thresh_1.05": rel_thresh(pred, gt, 1.05),
43 | f"{label}rel_thresh_1.25": rel_thresh(pred, gt, 1.25),
44 | }
45 |
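46 | 
47 | if __name__ == "__main__":
48 |     # Illustrative usage (not part of the evaluation scripts): metrics are
49 |     # meant to be computed on valid pixels only, hence the mask on the
50 |     # synthetic ground truth below.
51 |     gt = torch.rand(1, 1, 32, 32) * 10 + 0.1
52 |     pred = gt + torch.randn_like(gt) * 0.1
53 |     mask = gt > 0
54 |     for name, value in compute_metrics_depth(pred[mask], gt[mask], "demo").items():
55 |         print(f"{name}: {value.item():.4f}")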
--------------------------------------------------------------------------------
/lib/visualize/__init__.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualization tools for different needs; each module may require extra dependencies
3 | to be installed accordingly
4 | """
5 |
--------------------------------------------------------------------------------
/lib/visualize/depth.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities to visualize depth and pointclouds
3 | """
4 |
5 | import numpy as np
6 | from pathlib import Path
7 | from torch import Tensor
8 |
9 | try:
10 | import plyfile # type: ignore
11 | from kornia.geometry.depth import depth_to_3d # type: ignore
12 | from kornia.geometry import transform_points # type: ignore
13 | except ImportError:
14 | raise ImportError("To use depth visualize utilities install plyfile and kornia")
15 |
16 | __all__ = [
17 | "save_depth_ply",
18 | "save_pcd_ply",
19 | "save_depth_overlap_ply",
20 | ]
21 |
22 |
23 | def _norm_color(color):
24 | color = (color - color.min()) / (1e-6 + np.abs(color.max() - color.min()))
25 | return color * 255
26 |
27 |
28 | def save_depth_ply(
29 | filename: str | Path,
30 | depth: Tensor,
31 | intrinsics: Tensor,
32 | color: Tensor | None = None,
33 | _output_plyelement: bool = False,
34 | ):
35 | """
36 | Saves a depth map with optionally colors in a ply file,
37 | depth is of shape 1 x H x W and colors 3 x H x W if provided,
38 | """
39 | mask = depth[0] > 0
40 | pcd = depth_to_3d(depth[None], intrinsics[None]).permute(0, 2, 3, 1)[0][mask]
41 | if color is not None:
42 | color = color.permute(1, 2, 0)[mask]
43 | return save_pcd_ply(filename, pcd, color, _output_plyelement)
44 |
45 |
46 | def save_pcd_ply(
47 | filename: str | Path,
48 | pcd: Tensor,
49 | color: Tensor | None = None,
50 | _output_plyelement: bool = False,
51 | ):
52 | """
53 |     Saves a point cloud, optionally with colors, in a ply file;
54 |     pcd is of shape N x 3 and colors N x 3 if provided
55 | """
56 |
57 | pcd = pcd.cpu().numpy()
58 | if color is not None:
59 | color = _norm_color(color.cpu().numpy())
60 | else:
61 | color = np.zeros_like(pcd)
62 | color[:, 0] += 255
63 | pcd = np.array(
64 | list(
65 | zip(
66 | pcd[:, 0],
67 | pcd[:, 1],
68 | pcd[:, 2],
69 | color[:, 0],
70 | color[:, 1],
71 | color[:, 2],
72 | )
73 | ),
74 | dtype=[
75 | ("x", "f4"),
76 | ("y", "f4"),
77 | ("z", "f4"),
78 | ("red", "u1"),
79 | ("green", "u1"),
80 | ("blue", "u1"),
81 | ],
82 | )
83 |
84 | if _output_plyelement:
85 | return plyfile.PlyElement.describe(pcd, "vertex")
86 | else:
87 | plyfile.PlyData([plyfile.PlyElement.describe(pcd, "vertex")]).write(filename)
88 |
89 |
90 | def save_depth_overlap_ply(
91 | filename: str | Path,
92 | depths: list[Tensor],
93 | colors: list[Tensor],
94 | intrinsics: list[Tensor] | None = None,
95 | extrinsics: list[Tensor] | None = None,
96 | ):
97 | """
98 |     Takes a list of point clouds or depth maps and their color maps and saves
99 |     all of them as a whole point cloud in ply format
100 | """
101 |
102 | if intrinsics is None:
103 | intrinsics = [None] * len(depths)
104 |
105 | elems = []
106 | for idx, (d, c, i) in enumerate(zip(depths, colors, intrinsics)):
107 | if d.dim() == 2:
108 | if extrinsics is None:
109 | elems.append(save_pcd_ply(None, d, c, True).data)
110 | else:
111 | d = transform_points(extrinsics[idx][None], d[None])[0]
112 | elems.append(save_pcd_ply(None, d, c, True).data)
113 | else:
114 | if extrinsics is None:
115 | elems.append(save_depth_ply(None, d, i, c, True).data)
116 | else:
117 | mask = d[0] > 0
118 | d = depth_to_3d(d[None], i[None]).permute(0, 2, 3, 1)[0][mask]
119 | d = transform_points(extrinsics[idx][None], d[None])[0]
120 | c = c.permute(1, 2, 0)[mask]
121 | elems.append(save_pcd_ply(None, d, c, True).data)
122 |
123 | plyelem = plyfile.PlyElement.describe(np.concatenate(elems, 0), "vertex")
124 | plyfile.PlyData([plyelem]).write(filename)
125 |
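126 | 
127 | if __name__ == "__main__":
128 |     # Minimal example (illustrative only): export a synthetic fronto-parallel
129 |     # depth map as a colored point cloud; the output filename is a placeholder.
130 |     import torch
131 | 
132 |     depth = torch.full((1, 48, 64), 2.0)
133 |     intrinsics = torch.tensor([[50.0, 0.0, 32.0], [0.0, 50.0, 24.0], [0.0, 0.0, 1.0]])
134 |     color = torch.rand(3, 48, 64)
135 |     save_depth_ply("demo_depth.ply", depth, intrinsics, color)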
--------------------------------------------------------------------------------
/lib/visualize/flow.py:
--------------------------------------------------------------------------------
1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization
2 |
3 | # MIT License
4 | #
5 | # Copyright (c) 2018 Tom Runia
6 | #
7 | # Permission is hereby granted, free of charge, to any person obtaining a copy
8 | # of this software and associated documentation files (the "Software"), to deal
9 | # in the Software without restriction, including without limitation the rights
10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
11 | # copies of the Software, and to permit persons to whom the Software is
12 | # furnished to do so, subject to conditions.
13 | #
14 | # Author: Tom Runia
15 | # Date Created: 2018-08-03
16 |
17 | import numpy as np
18 |
19 | __all__ = ["flow_to_image"]
20 |
21 |
22 | def flow_to_image(flow_uv: np.ndarray, clip_flow=None, convert_to_bgr=False):
23 | """
24 |     Expects a two dimensional flow image of shape [H,W,2].
25 |
26 | Args:
27 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
28 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
29 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
30 |
31 | Returns:
32 | np.ndarray: Flow visualization image of shape [H,W,3]
33 | """
34 | assert flow_uv.ndim == 3, "input flow must have three dimensions"
35 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]"
36 | if clip_flow is not None:
37 | flow_uv = np.clip(flow_uv, 0, clip_flow)
38 | u = flow_uv[:, :, 0]
39 | v = flow_uv[:, :, 1]
40 | rad = np.sqrt(np.square(u) + np.square(v))
41 | rad_max = np.max(rad)
42 | epsilon = 1e-5
43 | u = u / (rad_max + epsilon)
44 | v = v / (rad_max + epsilon)
45 | return _flow_uv_to_colors(u, v, convert_to_bgr)
46 |
47 |
48 | # utils
49 |
50 |
51 | def _make_colorwheel():
52 | """
53 | Generates a color wheel for optical flow visualization as presented in:
54 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
55 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
56 |
57 | Code follows the original C++ source code of Daniel Scharstein.
58 |     Code follows the Matlab source code of Deqing Sun.
59 |
60 | Returns:
61 | np.ndarray: Color wheel
62 | """
63 |
64 | RY = 15
65 | YG = 6
66 | GC = 4
67 | CB = 11
68 | BM = 13
69 | MR = 6
70 |
71 | ncols = RY + YG + GC + CB + BM + MR
72 | colorwheel = np.zeros((ncols, 3))
73 | col = 0
74 |
75 | # RY
76 | colorwheel[0:RY, 0] = 255
77 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
78 | col = col + RY
79 | # YG
80 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
81 | colorwheel[col : col + YG, 1] = 255
82 | col = col + YG
83 | # GC
84 | colorwheel[col : col + GC, 1] = 255
85 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
86 | col = col + GC
87 | # CB
88 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
89 | colorwheel[col : col + CB, 2] = 255
90 | col = col + CB
91 | # BM
92 | colorwheel[col : col + BM, 2] = 255
93 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
94 | col = col + BM
95 | # MR
96 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
97 | colorwheel[col : col + MR, 0] = 255
98 | return colorwheel
99 |
100 |
101 | def _flow_uv_to_colors(u, v, convert_to_bgr=False):
102 | """
103 | Applies the flow color wheel to (possibly clipped) flow components u and v.
104 |
105 | According to the C++ source code of Daniel Scharstein
106 | According to the Matlab source code of Deqing Sun
107 |
108 | Args:
109 | u (np.ndarray): Input horizontal flow of shape [H,W]
110 | v (np.ndarray): Input vertical flow of shape [H,W]
111 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
112 |
113 | Returns:
114 | np.ndarray: Flow visualization image of shape [H,W,3]
115 | """
116 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
117 | colorwheel = _make_colorwheel() # shape [55x3]
118 | ncols = colorwheel.shape[0]
119 | rad = np.sqrt(np.square(u) + np.square(v))
120 | a = np.arctan2(-v, -u) / np.pi
121 | fk = (a + 1) / 2 * (ncols - 1)
122 | k0 = np.floor(fk).astype(np.int32)
123 | k1 = k0 + 1
124 | k1[k1 == ncols] = 0
125 | f = fk - k0
126 | for i in range(colorwheel.shape[1]):
127 | tmp = colorwheel[:, i]
128 | col0 = tmp[k0] / 255.0
129 | col1 = tmp[k1] / 255.0
130 | col = (1 - f) * col0 + f * col1
131 | idx = rad <= 1
132 | col[idx] = 1 - rad[idx] * (1 - col[idx])
133 | col[~idx] = col[~idx] * 0.75 # out of range
134 | # Note the 2-i => BGR instead of RGB
135 | ch_idx = 2 - i if convert_to_bgr else i
136 | flow_image[:, :, ch_idx] = np.floor(255 * col)
137 | return flow_image
138 |
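139 | 
140 | if __name__ == "__main__":
141 |     # Illustrative only: color-code a synthetic flow field whose vectors point
142 |     # away from the image center.
143 |     h, w = 64, 96
144 |     ys, xs = np.mgrid[0:h, 0:w]
145 |     flow = np.stack([xs - w / 2, ys - h / 2], axis=-1).astype(np.float32)
146 |     rgb = flow_to_image(flow)
147 |     print(rgb.shape, rgb.dtype)  # expected: (64, 96, 3) uint8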
--------------------------------------------------------------------------------
/lib/visualize/gui.py:
--------------------------------------------------------------------------------
1 | from collections import defaultdict
2 | from contextlib import contextmanager
3 | from time import sleep
4 | from typing import Any
5 |
6 | import numpy as np
7 | import torch
8 | import trimesh
9 | import viser
10 | import viser.transforms as ops
11 | from kornia.geometry import depth_to_3d
12 |
13 |
14 | class Visualize3D:
15 | def __init__(
16 | self,
17 | host: str = "127.0.0.1",
18 | port: int = 8080,
19 | ):
20 | self._server = viser.ViserServer(host, port)
21 | self.mesh = None
22 | self._steps = defaultdict(lambda: {})
23 | self._info = {}
24 | self._gui = defaultdict(lambda: {})
25 | self._gui_funcs = _GuiFuncs(self._server)
26 |
27 | def _convert_pose(self, pose: np.ndarray | None = None):
28 | if pose is None:
29 | return {}
30 | else:
31 | pose = ops.SE3.from_matrix(pose)
32 | return {
33 | "wxyz": pose.wxyz_xyz[:4],
34 | "position": pose.wxyz_xyz[4:],
35 | }
36 |
37 | @property
38 | def n_steps(self) -> int:
39 | return len(self._steps)
40 |
41 | @property
42 | def steps(self) -> list[int]:
43 | steps = []
44 | for step in self._steps.keys():
45 | steps.append(int(step.split("_")[-1]))
46 | return sorted(steps)
47 |
48 | def has_step(self, step: int, name: str):
49 | with self._server.atomic():
50 | step_data = self._steps.get(f"step_{step}", None)
51 | if not step_data:
52 |                 raise KeyError(f"step {step} does not exist")
53 | return name in step_data
54 |
55 | def step_visible(
56 | self,
57 | step: int,
58 | frame: bool | None = None,
59 | camera: bool | None = None,
60 | point_cloud: bool | None = None,
61 | mesh: bool | None = None,
62 | all: bool | None = None,
63 | ):
64 | with self._server.atomic():
65 | step_data = self._steps.get(f"step_{step}", None)
66 | if not step_data:
67 |                 raise KeyError(f"step {step} does not exist")
68 |
69 | if all is not None:
70 | if "frame" in step_data:
71 | step_data["frame"].visible = all
72 | if "camera" in step_data:
73 | step_data["camera"].visible = all
74 | if "point_cloud" in step_data:
75 | step_data["point_cloud"].visible = all
76 | if "mesh" in step_data:
77 | step_data["mesh"].visible = all
78 |
79 | if frame is not None and "frame" in step_data:
80 | step_data["frame"].visible = frame
81 | if camera is not None and "camera" in step_data:
82 | step_data["camera"].visible = camera
83 | if point_cloud is not None and "point_cloud" in step_data:
84 | step_data["point_cloud"].visible = point_cloud
85 | if mesh is not None and "mesh" in step_data:
86 | step_data["mesh"].visible = mesh
87 |
88 | def remove_step(self, step: int):
89 | step_data = self._steps[f"step_{step}"]
90 | for v in step_data.values():
91 | v.remove()
92 |
93 | def reset_scene(self):
94 | self._server.reset_scene()
95 | self._steps = defaultdict(lambda: {})
96 | self._info = {}
97 |
98 | def add_info(self, label: str, value: Any):
99 | self._info[label] = value
100 |
101 | def get_info(self, label: str) -> Any:
102 | return self._info[label]
103 |
104 | def add_frame(self, step: int, pose: np.ndarray | None = None, **kwargs):
105 | frame = self._server.add_frame(
106 | f"step_{step}_frame", **self._convert_pose(pose), **kwargs
107 | )
108 | self._steps[f"step_{step}"]["frame"] = frame
109 |
110 | def add_camera(
111 | self,
112 | step: int,
113 | fov: float,
114 | image: np.ndarray,
115 | scale: float = 0.3,
116 | color: tuple[int, int, int] = (20, 20, 20),
117 | pose: np.ndarray | None = None,
118 | visible: bool = True,
119 | ):
120 | camera = self._server.add_camera_frustum(
121 | f"step_{step}_camera",
122 | fov=np.deg2rad(fov),
123 | image=image,
124 | aspect=image.shape[1] / image.shape[0],
125 | color=color,
126 | scale=scale,
127 | visible=visible,
128 | **self._convert_pose(pose),
129 | )
130 | self._steps[f"step_{step}"]["camera"] = camera
131 |
132 | def add_depth_point_cloud(
133 | self,
134 | step: int,
135 | depth: np.ndarray,
136 | intrinsics: np.ndarray,
137 | color: np.ndarray | None = None,
138 | pose: np.ndarray | None = None,
139 | point_size: float = 0.01,
140 | visible: bool = True,
141 | ):
142 | pcd = (
143 | depth_to_3d(
144 | torch.from_numpy(depth)[None, None],
145 | torch.from_numpy(intrinsics)[None],
146 | )[0]
147 | .permute(1, 2, 0)
148 | .numpy()
149 | )
150 | pcd = pcd[depth > 0]
151 | if color is not None:
152 | color = color[depth > 0]
153 | handle = self._server.add_point_cloud(
154 | f"step_{step}_point_cloud",
155 | points=pcd,
156 | colors=color,
157 | point_size=point_size,
158 | **self._convert_pose(pose),
159 | visible=visible,
160 | )
161 | self._steps[f"step_{step}"]["point_cloud"] = handle
162 |
163 | def add_mesh(
164 | self,
165 | step: int,
166 | vertices: np.ndarray,
167 | triangles: np.ndarray,
168 | vertex_colors: np.ndarray | None = None,
169 | pose: np.ndarray | None = None,
170 | visible: bool = True,
171 | ):
172 | mesh = trimesh.Trimesh(
173 | vertices=vertices, faces=triangles, vertex_colors=vertex_colors
174 | )
175 | handle = self._server.add_mesh_trimesh(
176 | f"step_{step}_mesh", mesh, **self._convert_pose(pose), visible=visible
177 | )
178 | self._steps[f"step_{step}"]["mesh"] = handle
179 |
180 | # GUI handles
181 |
182 | @property
183 |     def gui(self) -> "_GuiFuncs":
184 | return self._gui_funcs
185 |
186 | @contextmanager
187 | def atomic(self):
188 | with self._server.atomic():
189 | yield
190 |
191 | # wait for gui
192 |
193 | def wait(self):
194 | try:
195 | while True:
196 | sleep(10.0)
197 | except KeyboardInterrupt:
198 | self._server.stop()
199 |
200 |
201 | class _GuiFuncs:
202 | def __init__(self, server):
203 | self._server = server
204 |
205 | def __getattr__(self, name: str) -> Any:
206 | if name.startswith("add_gui") and hasattr(self._server, name):
207 |             return getattr(self._server, name)
208 |         raise AttributeError(name)
209 | 
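210 | 
211 | # Usage sketch (illustrative only, not executed anywhere in this repository):
212 | # the viewer is driven by pushing per-step frames, cameras and point clouds,
213 | # then kept alive until interrupted. The arrays below are placeholders.
214 | #
215 | #   viz = Visualize3D(host="127.0.0.1", port=8080)
216 | #   viz.add_camera(step=0, fov=60.0, image=rgb_uint8, pose=cam_to_world)
217 | #   viz.add_depth_point_cloud(step=0, depth=depth_hw, intrinsics=K, pose=cam_to_world)
218 | #   viz.wait()  # blocks until Ctrl+C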
--------------------------------------------------------------------------------
/lib/visualize/matplotlib.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualization utilities related to matplotlib
3 | """
4 |
5 | from torch import Tensor
6 |
7 | import shutil
8 | import tempfile
9 | from pathlib import Path
10 | from typing import Iterable, Optional, Union
11 |
12 | from skimage.transform import resize as img_resize
13 | import imageio
14 | import matplotlib.pyplot as plt
15 | import numpy as np
16 | from matplotlib.figure import Figure
17 |
18 | try:
19 | from IPython import get_ipython as _get_ipython
20 |
21 | _is_interactive = lambda: _get_ipython() is not None
22 | except ImportError:
23 | _is_interactive = lambda: None
24 |
25 |
26 | __all__ = ["GridFigure", "make_gif", "color_depth"]
27 |
28 |
29 | class GridFigure:
30 | """
31 |     Utility class to plot a grid of images, useful in
32 |     training code to log qualitative results
33 | """
34 |
35 | figure: Figure
36 | rows: int
37 | cols: int
38 |
39 | def __init__(
40 | self,
41 | rows: int,
42 | cols: int,
43 | *,
44 | size: tuple[int, int] | None = None,
45 | tight_layout: bool = True,
46 | ):
47 | figsize = None
48 | if size:
49 | h, w = size
50 | px = 1 / plt.rcParams["figure.dpi"]
51 | figsize = (cols * w * px, rows * h * px)
52 | self.figure = plt.figure(tight_layout=tight_layout, figsize=figsize)
53 | self.rows = rows
54 | self.cols = cols
55 |
56 | def imshow(
57 | self,
58 | pos: int,
59 | image: Tensor,
60 | /,
61 | *,
62 | show_axis: bool = False,
63 | norm: bool = True,
64 | **kwargs,
65 | ):
66 | ax = self.figure.add_subplot(self.rows, self.cols, pos)
67 | if not show_axis:
68 | ax.set_axis_off()
69 | ax.imshow(_to_numpy(image, norm=norm), **kwargs)
70 |
71 | def show(self):
72 | self.figure.show()
73 |
74 | def close(self):
75 | plt.close(self.figure)
76 |
77 | def __enter__(self):
78 | return self
79 |
80 |     def __exit__(self, exc_type, exc_val, exc_tb):
81 | self.show()
82 | if not _is_interactive():
83 | self.close()
84 |
85 |
86 | def make_gif(
87 | imgs: Iterable[Union[np.ndarray, Figure]],
88 | path: Union[Path, str],
89 | axis: bool = False,
90 | close_figures: bool = True,
91 | bbox_inches: Optional[str] = "tight",
92 | pad_inches: float = 0,
93 | **kwargs,
94 | ):
95 | """
96 |     Renders each image using matplotlib and then composes them into a GIF saved
97 |     in path; kwargs are forwarded to ``plt.imshow``
98 | """
99 |
100 | tmpdir = Path(tempfile.mkdtemp())
101 | for i, img in enumerate(imgs):
102 | if isinstance(img, np.ndarray):
103 | plt.axis({False: "off", True: "on"}[axis])
104 | plt.imshow(img, **kwargs)
105 | plt.savefig(
106 | tmpdir / f"{i}.png", bbox_inches=bbox_inches, pad_inches=pad_inches
107 | )
108 | plt.close()
109 | elif isinstance(img, Figure):
110 | img.savefig(
111 | tmpdir / f"{i}.png", bbox_inches=bbox_inches, pad_inches=pad_inches
112 | )
113 | if close_figures:
114 | plt.close(img)
115 |
116 | with imageio.get_writer(path, mode="I") as writer:
117 | paths = sorted(tmpdir.glob("*.png"), key=lambda x: int(x.name.split(".")[0]))
118 |         for frame_path in paths:
119 |             image = imageio.imread(frame_path)
120 | writer.append_data(image)
121 | shutil.rmtree(tmpdir, ignore_errors=True)
122 |
123 |
124 | def color_depth(depth: np.ndarray, cmap="magma_r", **kwargs):
125 | px = 1 / plt.rcParams["figure.dpi"]
126 | h, w = depth.shape[:2]
127 | fig = plt.figure(figsize=(w * px, h * px))
128 | plt.axis("off")
129 | plt.tight_layout(pad=0.0, h_pad=0.0, w_pad=0.0)
130 | plt.imshow(depth, cmap=cmap, **kwargs)
131 | fig.canvas.draw()
132 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
133 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
134 | data = img_resize(data, depth.shape)
135 | alpha = (depth[..., None] > 0).astype(np.uint8) * 255
136 | data = np.concatenate([data, alpha], -1)
137 | plt.close(fig)
138 | return data
139 |
140 | # internal utils
141 |
142 |
143 | def _to_numpy(image: Tensor | np.ndarray, norm: bool = True):
144 | if isinstance(image, Tensor):
145 | image = image.detach().cpu().numpy().astype(np.float32)
146 | if image.ndim == 3:
147 | image = image.transpose([1, 2, 0])
148 | elif image.ndim != 2:
149 | raise ValueError(f"Unsupported torch.Tensor of shape {image.shape}")
150 | if norm:
151 | nan_mask = np.isnan(image)
152 | if not np.all(nan_mask):
153 | img_min, img_max = image[~nan_mask].min(), image[~nan_mask].max()
154 | image = (image - img_min) / (img_max - img_min)
155 | return image
156 |
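157 | 
158 | if __name__ == "__main__":
159 |     # Illustrative usage: log a 1x2 grid of qualitative images; with a
160 |     # non-interactive backend the figure is simply created and closed.
161 |     import torch
162 | 
163 |     with GridFigure(1, 2, size=(64, 96)) as grid:
164 |         grid.imshow(1, torch.rand(3, 64, 96))
165 |         grid.imshow(2, torch.rand(64, 96), cmap="magma_r")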
--------------------------------------------------------------------------------
/lib/visualize/plotly.py:
--------------------------------------------------------------------------------
1 | """
2 | Visualization tools implemented in plotly
3 | """
4 |
5 | import pandas as pd
6 |
7 | try:
8 | from plotly.graph_objects import Scatter3d, Figure # type: ignore
9 | except ImportError:
10 | raise ImportError("To use plotly visualize utilities install plotly")
11 |
12 |
13 | def point_cloud(
14 | data: pd.DataFrame,
15 | x: str,
16 | y: str,
17 | z: str,
18 | color: list[str] | str | None = None,
19 | point_size: float = 1.0,
20 | ) -> Scatter3d:
21 | """
22 | Utility function to visualize a Point Cloud with colors
23 | """
24 |
25 | marker_opts = {
26 | "size": point_size,
27 | }
28 |
29 | if color is not None:
30 | if isinstance(color, str):
31 | marker_opts["color"] = [color] * len(data[x])
32 | else:
33 | marker_opts["color"] = [
34 | f"rgb({r}, {g}, {b})"
35 | for r, g, b in zip(data[color[0]], data[color[1]], data[color[2]])
36 | ]
37 |
38 | scatter = Scatter3d(
39 | x=data[x], y=data[y], z=data[z], mode="markers", marker=marker_opts
40 | )
41 | return scatter
42 |
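43 | # Minimal usage sketch (illustrative, not part of the original module): the
44 | # returned Scatter3d trace still needs to be wrapped in a plotly Figure to be
45 | # displayed; the column names below are assumptions.
46 | #
47 | #     df = pd.DataFrame({
48 | #         "x": [0.0, 1.0], "y": [0.0, 1.0], "z": [0.0, 1.0],
49 | #         "r": [255, 0], "g": [0, 255], "b": [0, 0],
50 | #     })
51 | #     Figure(data=[point_cloud(df, "x", "y", "z", color=["r", "g", "b"])]).show()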
--------------------------------------------------------------------------------
/media/scannetv2_scene0720_00.mp4:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/andreaconti/depth-on-demand/7376cd0f7a67e583b3d0eff86641cfd7af83de67/media/scannetv2_scene0720_00.mp4
--------------------------------------------------------------------------------
/pyrightconfig.json:
--------------------------------------------------------------------------------
1 | {
2 | "exclude": [
3 | "**/.git",
4 | "**/__pycache__",
5 | "**/.env",
6 | "**/.pytest_cache",
7 | "**/data",
8 | "**/.data",
9 | ]
10 | }
11 |
12 |
--------------------------------------------------------------------------------
/visualize/callbacks/build_scene.py:
--------------------------------------------------------------------------------
1 | import time
2 | from pathlib import Path
3 |
4 | import matplotlib
5 | import numpy as np
6 | import torch
7 | import trimesh
8 | from depth_on_demand import Model
9 |
10 | from lib.benchmark.mesh import TSDFFusion
11 | from lib.dataset import list_scans, load_datamodule
12 | from lib.dataset.utils import inv_pose, project_depth, sparsify_depth
13 | from lib.visualize.gui import Visualize3D
14 |
15 |
16 | class BuildScene:
17 | def __init__(
18 | self,
19 | viz: Visualize3D,
20 | device: str = "cuda:0",
21 | ):
22 | self.viz = viz
23 | self.device = device
24 | self.to_clear = False
25 |
26 | # select the model
27 | with viz.gui.add_gui_folder("Model"):
28 | self.model_pretrained = viz.gui.add_gui_dropdown(
29 | "Model Pretrained", ["scannetv2"], initial_value="scannetv2"
30 | )
31 | self.load_model(self.model_pretrained.value)
32 | self.model_pretrained.on_update(self.load_model)
33 |
34 | with viz.gui.add_gui_folder("Dataset"):
35 | self.dataset = viz.gui.add_gui_dropdown(
36 | "Name", ["sevenscenes", "scannetv2"], initial_value="scannetv2"
37 | )
38 | self.dataset_root = viz.gui.add_gui_text(
39 | "Root", initial_value="data/scannetv2"
40 | )
41 | self.dataset.on_update(self.on_update_dataset)
42 |
43 | # create build scene panel
44 | with viz.gui.add_gui_folder("Build Scene"):
45 | self.scene_select = viz.gui.add_gui_dropdown(
46 | "Scene", scenes[self.dataset.value]
47 | )
48 | self.hints_interval = viz.gui.add_gui_slider(
49 | "Hints Interval", min=1, max=10, step=1, initial_value=5
50 | )
51 | self.hints_density = viz.gui.add_gui_slider(
52 | "Hints Density", min=10, max=2000, step=10, initial_value=500
53 | )
54 | self.pcd_points_size = viz.gui.add_gui_slider(
55 | "Points Size", min=0.01, max=1.0, step=0.01, initial_value=0.05
56 | )
57 | self.frustum_color = viz.gui.add_gui_dropdown(
58 | "Frustum Color", list(matplotlib.colormaps.keys()), initial_value="jet"
59 | )
60 | self.frustum_color_range = viz.gui.add_gui_vector2(
61 | "Color Map Range", (0.25, 0.5), (0.0, 0.0), (1.0, 1.0), 0.05
62 | )
63 | self.overlap_time = viz.gui.add_gui_slider(
64 | "frames overlap", 0, 1, 0.01, 0.25
65 | )
66 | self.build_scene = viz.gui.add_gui_button("Build Scene")
67 |
68 | with (folder := viz.gui.add_gui_folder("TSDF Parameters")):
69 | self.tsdf_folder = folder
70 | self.voxel_length = viz.gui.add_gui_slider(
71 | "Voxel Length", 0.01, 0.5, 0.01, 0.04
72 | )
73 | self.sdf_trunc = viz.gui.add_gui_slider(
74 | "SDF Trunc", 0.01, 0.5, 0.01, 3 * 0.04
75 | )
76 | self.depth_trunc = viz.gui.add_gui_slider("Depth Trunc", 0.5, 10, 0.1, 3.0)
77 | self.integrate_only_slow = viz.gui.add_gui_checkbox(
78 | "integrate only when hints", False
79 | )
80 |
81 | self.build_scene.on_click(self.build_scene_callback)
82 |
83 | def on_update_dataset(self, *args, **kwargs):
84 | self.scene_select.options = scenes[self.dataset.value]
85 | self.voxel_length.visible = self.dataset.value != "kitti"
86 | self.sdf_trunc.visible = self.dataset.value != "kitti"
87 | self.depth_trunc.visible = self.dataset.value != "kitti"
88 | self.integrate_only_slow.visible = self.dataset.value != "kitti"
89 | self.dataset_root.value = f"data/{self.dataset.value}"
90 |
91 | def load_model(self, *args, **kwargs):
92 | self.model = Model(self.model_pretrained.value, device=self.device)
93 |
94 | def build_scene_callback(self, event):
95 | viz = self.viz
96 | scene_data = load_data(
97 | self.dataset.value, self.scene_select.value, root=self.dataset_root.value
98 | )
99 | interval = self.hints_interval.value
100 | density = self.hints_density.value
101 |
102 | colors = matplotlib.colormaps[self.frustum_color.value](
103 | np.linspace(*self.frustum_color_range.value, len(scene_data))
104 | )
105 | colors = (colors[:, :3] * 255).astype(np.uint8)
106 |
107 | tsdf = TSDFFusion(
108 | voxel_length=self.voxel_length.value,
109 | sdf_trunc=self.sdf_trunc.value,
110 | depth_trunc=self.depth_trunc.value,
111 | device=self.device,
112 | )
113 | viz.reset_scene()
114 | viz.add_info("dataset", self.dataset.value)
115 | viz.add_info("scan", self.scene_select.value)
116 | viz.add_info("frames", [])
117 |
118 | for step, ex in enumerate(scene_data):
119 | ex = {
120 | k: v.to(self.device)
121 | for k, v in ex.items()
122 | if isinstance(v, torch.Tensor)
123 | }
124 | if step % interval == 0:
125 | buffer = ex.copy()
126 | buffer_hints = sparsify_depth(buffer["depth"], density).to(self.device)
127 | with torch.no_grad():
128 | rel_pose = inv_pose(ex["extrinsics"]) @ buffer["extrinsics"]
129 | hints = project_depth(
130 | buffer_hints,
131 | buffer["intrinsics"],
132 | ex["intrinsics"],
133 | rel_pose,
134 | torch.zeros_like(ex["image"][:, :1]),
135 | )
136 | depth = self.model(
137 | ex["image"],
138 | buffer["image"],
139 | hints,
140 | rel_pose,
141 | torch.stack([ex["intrinsics"], buffer["intrinsics"]], 1),
142 | )
143 |
144 | img = img_to_rgb(ex["image"])
145 | depth = depth.cpu().numpy()[0, 0, ..., None]
146 | self.viz.get_info("frames").append(
147 | {
148 | "target": img,
149 | "source": img_to_rgb(buffer["image"]),
150 | "target_hints": hints.cpu().numpy()[0, 0, ..., None],
151 | "source_hints": buffer_hints.cpu().numpy()[0, 0, ..., None],
152 | "depth": depth,
153 | }
154 | )
155 | pose = ex["extrinsics"][0].cpu().numpy()
156 | intrinsics = ex["intrinsics"][0].cpu().numpy()
157 | hints_cpu = buffer_hints[0, 0].cpu().numpy()
158 | hints_cpu = np.where(hints_cpu > self.depth_trunc.value, 0.0, hints_cpu)
159 |
160 | # viz.add_step(step)
161 | # viz.step_visible(step, container=False)
162 | if step % interval == 0:
163 | camera_color = (255, 0, 0)
164 | else:
165 | camera_color = tuple(colors[step])
166 | viz.add_camera(
167 | step, 60.0, img, color=camera_color, pose=pose, visible=False
168 | )
169 | if step % interval == 0:
170 | red = np.zeros_like(img)
171 | red[..., 0] = 255
172 | viz.add_depth_point_cloud(
173 | step,
174 | hints_cpu,
175 | intrinsics,
176 | red,
177 | point_size=self.pcd_points_size.value,
178 | pose=pose,
179 | visible=False,
180 | )
181 |
182 |             # integrate every frame unless "integrate only when hints" is enabled
183 |             if not self.integrate_only_slow.value or step % interval == 0:
184 |                 tsdf.integrate_rgbd(img, depth, intrinsics, pose)
185 |
186 | mesh = tsdf.triangle_mesh()
187 | viz.add_info(
188 | "mesh",
189 | trimesh.Trimesh(
190 | vertices=np.asarray(mesh.vertices),
191 | faces=np.asarray(mesh.triangles),
192 | vertex_colors=np.asarray(mesh.vertex_colors),
193 | ),
194 | )
195 | viz.add_mesh(
196 | step,
197 | vertices=np.asarray(mesh.vertices),
198 | triangles=np.asarray(mesh.triangles),
199 | vertex_colors=np.asarray(mesh.vertex_colors),
200 | # pose=np.linalg.inv(pose),
201 | visible=False,
202 | )
203 |
204 | if step > 0:
205 | viz.step_visible(step - 1, all=False, mesh=True)
206 | viz.step_visible(step, camera=True, point_cloud=True, mesh=True)
207 | if step > 0:
208 | time.sleep(self.overlap_time.value)
209 | viz.step_visible(step - 1, all=False)
210 |
211 |
212 | def load_data(dataset: str, scene: str, root: Path | str):
213 |     root_kwargs = {"root": root}  # keep the original ``root`` path intact for the kitti branch
214 |     if dataset == "kitti":
215 |         root_kwargs = {
216 |             "root_raw": Path(root) / "raw",
217 |             "root_completion": Path(root) / "depth_completion",
218 |         }
219 |
220 | dm = load_datamodule(
221 | dataset,
222 | load_prevs=0,
223 | keyframes="standard",
224 |         **root_kwargs,
225 | split_test_scans_loaders=True,
226 | filter_scans=[scene],
227 | )
228 | dm.prepare_data()
229 | dm.setup("test")
230 | return next(iter(dm.test_dataloader()))
231 |
232 |
233 | scenes = {
234 | "scannetv2": list_scans("scannetv2", "standard", "test"),
235 | "sevenscenes": list_scans("sevenscenes", "standard", "test"),
236 | "kitti": list_scans("kitti", "standard", "test"),
237 | "tartanair": list_scans("tartanair", "standard", "test"),
238 | }
239 |
240 |
241 | def img_to_rgb(image: torch.Tensor):
242 | image = image[0].permute(1, 2, 0).cpu().numpy()
243 | image = (image - image.min()) / (image.max() - image.min())
244 | image = (image * 255).astype(np.uint8)
245 | return image
246 |
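247 | # Minimal standalone sketch (illustrative, not part of the original module): the
248 | # same data pipeline and model used by BuildScene can be exercised without the
249 | # GUI; the scene index and dataset root below are assumptions.
250 | #
251 | #     model = Model("scannetv2", device="cuda:0")
252 | #     scan = scenes["scannetv2"][0]
253 | #     loader = load_data("scannetv2", scan, root="data/scannetv2")
254 | #     for ex in loader:
255 | #         ...  # sparsify, project and feed the hints as in build_scene_callback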
--------------------------------------------------------------------------------
/visualize/callbacks/run_demo.py:
--------------------------------------------------------------------------------
1 | import time
2 | from pathlib import Path
3 |
4 | import cv2
5 | import h5py
6 | import numpy as np
7 | import trimesh
8 | import viser
9 |
10 | from lib.visualize.gui import Visualize3D
11 |
12 |
13 | class RunDemo:
14 | def __init__(self, viz: Visualize3D, renders_folder: Path):
15 | self.viz = viz
16 | self.renders_folder = Path(renders_folder)
17 |
18 | # create gui
19 | with viz.gui.add_gui_folder("Demo"):
20 | self.speed = viz.gui.add_gui_slider("fps", 1, 20, 1, 5)
21 | self.overlap_time = viz.gui.add_gui_slider(
22 | "frames overlap", 0, 1, 0.01, 0.05
23 | )
24 | self.keep_mesh = viz.gui.add_gui_checkbox("Keep Background Mesh", False)
25 | self.run_button = viz.gui.add_gui_button("Run Demo")
26 | self.height = viz.gui.add_gui_number("Render Height", 960, 100, 2160, 1)
27 | self.width = viz.gui.add_gui_number("Render Width", 1280, 100, 2160, 1)
28 | self.render_button = viz.gui.add_gui_button("Render Video")
29 | self.save_frames = viz.gui.add_gui_checkbox("Save Frames", False)
30 | self.run_button.on_click(self.run_demo_clbk)
31 | self.render_button.on_click(self.render_video_clbk)
32 |
33 | def render_video_clbk(self, event: viser.GuiEvent):
34 | client = event.client
35 | assert client is not None
36 |         # the actual renders are captured via client.camera.get_render inside run_demo_clbk
37 | self.run_demo_clbk(None, client)
38 |
39 | def run_demo_clbk(self, event, render_client=None):
40 | keep_mesh = self.keep_mesh.value
41 | keys = ["point_cloud", "frame", "camera", "mesh"]
42 | disable_all = {k: False for k in keys}
43 | enable_all = {k: True for k in keys}
44 | viz = self.viz
45 |
46 | # hide all
47 | for step in viz.steps:
48 | viz.step_visible(
49 | step,
50 | **(
51 | disable_all | {"mesh": True}
52 | if keep_mesh and step == viz.steps[-1]
53 | else disable_all
54 | ),
55 | )
56 |
57 | # start animation
58 | images = []
59 | for step in viz.steps:
60 | viz.step_visible(
61 | step,
62 | **(
63 | enable_all | {"mesh": False}
64 | if keep_mesh and step != viz.steps[-1]
65 | else enable_all
66 | ),
67 | )
68 | time.sleep(1 / self.speed.value)
69 | if step != viz.steps[-1]:
70 | viz.step_visible(step, **disable_all)
71 |
72 | if render_client is not None:
73 | images.append(
74 | render_client.camera.get_render(
75 | height=self.height.value, width=self.width.value
76 | )
77 | )
78 |
79 | # save the animation
80 | if render_client is not None:
81 | scan_name = viz.get_info("scan")
82 | render_folder = self.renders_folder / viz.get_info("dataset") / scan_name
83 | render_folder.mkdir(exist_ok=True, parents=True)
84 |
85 | # write the video
86 | h, w, _ = images[0].shape
87 | writer = cv2.VideoWriter(
88 | str(render_folder / "video.avi"),
89 | cv2.VideoWriter_fourcc(*"MJPG"),
90 | self.speed.value,
91 | (w, h),
92 | )
93 | for frame in images:
94 | writer.write(frame[..., ::-1].astype(np.uint8))
95 | writer.release()
96 |
97 | # save the mesh
98 | mesh: trimesh.Trimesh = viz.get_info("mesh")
99 | mesh.vertices = mesh.vertices - mesh.centroid
100 | _ = mesh.export(
101 | str(render_folder / "mesh.glb"),
102 | file_type="glb",
103 | )
104 |
105 | # save frames
106 | if self.save_frames.value:
107 | for idx, dict_data in enumerate(self.viz.get_info("frames")):
108 | with h5py.File(render_folder / f"data-{idx:0>6d}.h5", "w") as f:
109 | f["target"] = dict_data["target"]
110 | f["source"] = dict_data["source"]
111 | f["depth"] = dict_data["depth"]
112 | f["target_hints"] = dict_data["target_hints"]
113 | f["source_hints"] = dict_data["source_hints"]
114 |
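115 | # Sketch (illustrative, not part of the original module) of how the frames saved
116 | # above can be inspected afterwards; the render folder path is an assumption.
117 | #
118 | #     with h5py.File("renders/scannetv2/<scan>/data-000000.h5", "r") as f:
119 | #         target, depth = f["target"][()], f["depth"][()]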
--------------------------------------------------------------------------------
/visualize/run.py:
--------------------------------------------------------------------------------
1 | """
2 | 3D visualizer of a scene reconstructed with Depth on Demand
3 | """
4 |
5 | import sys
6 | from pathlib import Path
7 |
8 | sys.path.append(str(Path(__file__).parents[1]))
9 | import argparse
10 | import warnings
11 |
12 | from callbacks.build_scene import BuildScene
13 | from callbacks.run_demo import RunDemo
14 |
15 | from lib.visualize.gui import Visualize3D
16 |
17 | warnings.filterwarnings("ignore", category=UserWarning)
18 |
19 |
20 | def parse_args():
21 | parser = argparse.ArgumentParser()
22 | parser.add_argument("--device", default="cuda:0")
23 | parser.add_argument("--host", default="0.0.0.0")
24 | parser.add_argument("--port", default=8080, type=int)
25 | parser.add_argument("--renders-folder", default="./renders")
26 | args = parser.parse_args()
27 | return args
28 |
29 |
30 | def main():
31 | args = parse_args()
32 | viz = Visualize3D(host=args.host, port=args.port)
33 | BuildScene(viz, args.device)
34 | RunDemo(viz, args.renders_folder)
35 | viz.wait()
36 |
37 |
38 | if __name__ == "__main__":
39 | main()
40 |
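41 | # Usage sketch (illustrative): run from the repository root and open the viser
42 | # client in a browser at the chosen host and port, e.g.
43 | #
44 | #     python visualize/run.py --device cuda:0 --host 0.0.0.0 --port 8080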
--------------------------------------------------------------------------------