├── .gitignore ├── README.md ├── download ├── _utils.py ├── kitti.py ├── scannetv2.py ├── sevenscenes.py └── tartanair.py ├── environment.yml ├── evaluate ├── config │ ├── kitti.yaml │ ├── scannetv2.yaml │ ├── sevenscenes.yaml │ └── tartanair.yaml ├── metrics.py └── utils │ ├── funcs.py │ ├── metrics.py │ └── pose_estimation.py ├── lib ├── __init__.py ├── benchmark │ └── mesh │ │ ├── __init__.py │ │ ├── _evaluation.py │ │ └── _fusion.py ├── dataset │ ├── __init__.py │ ├── _resources │ │ ├── __init__.py │ │ ├── kitti │ │ │ └── keyframes │ │ │ │ └── standard │ │ │ │ └── test_split.txt │ │ ├── scannetv2 │ │ │ ├── keyframes │ │ │ │ └── standard │ │ │ │ │ └── test_split.txt │ │ │ └── test_split.txt │ │ ├── sevenscenes │ │ │ ├── keyframes │ │ │ │ └── standard │ │ │ │ │ └── test_split.txt │ │ │ └── test_split.txt │ │ └── tartanair │ │ │ ├── keyframes │ │ │ └── standard │ │ │ │ └── test_split.txt │ │ │ └── test_split.txt │ ├── _utils.py │ ├── kitti.py │ ├── scannetv2.py │ ├── sevenscenes.py │ ├── tartanair.py │ └── utils │ │ ├── __init__.py │ │ └── metrics.py └── visualize │ ├── __init__.py │ ├── depth.py │ ├── flow.py │ ├── gui.py │ ├── matplotlib.py │ └── plotly.py ├── media └── scannetv2_scene0720_00.mp4 ├── pyrightconfig.json └── visualize ├── callbacks ├── build_scene.py └── run_demo.py └── run.py /.gitignore: -------------------------------------------------------------------------------- 1 | data 2 | .env 3 | **/__pycache__ 4 | renders -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 | Depth on Demand: Streaming Dense Depth from a Low Frame-Rate Active Sensor 3 |

4 | 5 |

6 |

7 | Andrea Conti 8 | · 9 | Matteo Poggi 10 | · 11 | Valerio Cambareri 12 | · 13 | Stefano Mattoccia 14 |
15 |
16 | [Arxiv] 17 | [Project Page] 18 |
19 |

20 | 
21 | ![](https://andreaconti.github.io/projects/depth_on_demand/images/setup_example.png)
22 | 
23 | We propose Depth on Demand (DoD), a framework addressing the three major issues of active depth sensors in streaming dense depth maps: spatial sparsity, limited frame rate, and energy consumption. DoD streams high-resolution depth from an RGB camera and a depth sensor without requiring the depth sensor to be dense or to match the frame rate of the RGB camera. Depth on Demand improves the temporal resolution of an active depth sensor by exploiting the higher frame rate of an RGB camera: it estimates depth for each RGB frame, even for those without direct depth sensor measurements.
24 | 
25 | 
26 | ### Install
27 | 
28 | Dependencies can be installed with `conda` or `mamba` as follows:
29 | 
30 | ```bash
31 | $ git clone https://github.com/andreaconti/depth-on-demand.git
32 | $ cd depth-on-demand
33 | $ conda env create -f environment.yml # use mamba if conda is too slow
34 | $ conda activate depth-on-demand
35 | # then, download and install the wheel containing the pretrained models, available for linux, windows and macos
36 | $ pip install https://github.com/andreaconti/depth-on-demand/releases/download/models%2Fv0.1.1/depth_on_demand-0.1.1-cp310-cp310-linux_x86_64.whl --no-deps
37 | ```
38 | 
39 | We provide scripts to prepare the datasets for evaluation; each one automatically downloads and unpacks its dataset into the `data` directory:
40 | 
41 | ```bash
42 | # for example, the evaluation sequences of scannetv2 can be downloaded as follows
43 | $ python download/scannetv2.py
44 | ```
45 | 
46 | ### Evaluate
47 | 
48 | To evaluate the framework we provide the following script, which loads the specified dataset and its configuration and reports the metrics:
49 | 
50 | ```bash
51 | $ python evaluate/metrics.py --dataset scannetv2
52 | ```
53 | 
54 | The configuration for each dataset is in `evaluate/config`.
55 | 
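The model shipped in the wheel can also be called directly from Python. The snippet below is only a minimal, illustrative sketch of that usage, mirroring how `evaluate/metrics.py` drives the network: random tensors stand in for real data, and the input names and shapes follow `evaluate/utils/funcs.prepare_input` (ImageNet-normalized RGB frames, target and source intrinsics stacked along dim 1, source-to-target pose, sparse depth hints aligned with the target view). The identity intrinsics/pose and the empty hints are placeholders only; a real run needs calibrated frames and projected sparse depth.

```python
import torch
from depth_on_demand import Model

device = "cuda:0" if torch.cuda.is_available() else "cpu"

# pretrained weights bundled with the wheel: "scannetv2", "tartanair" or "kitti"
model = Model(pretrained="scannetv2", device=device)

# dummy stand-ins for one calibrated sample (H, W match the ScanNetV2 center crop)
B, H, W = 1, 464, 624
target = torch.randn(B, 3, H, W, device=device)  # normalized RGB frame to densify
source = torch.randn(B, 3, H, W, device=device)  # normalized RGB source view
intrinsics = torch.eye(3, device=device).repeat(B, 2, 1, 1)  # target and source intrinsics, (B, 2, 3, 3)
pose_src_tgt = torch.eye(4, device=device).repeat(B, 1, 1)   # relative pose from source to target view, (B, 4, 4)
hints = torch.zeros(B, 1, H, W, device=device)   # sparse depth on the target view, 0 where missing

with torch.no_grad():
    depth = model(
        target=target,
        source=source,
        pose_src_tgt=pose_src_tgt,
        intrinsics=intrinsics,
        hints=hints,
        n_cycles=10,  # refinement iterations, as in evaluate/config
    )
print(depth.shape)  # dense depth prediction aligned with the target frame
```

For real data, `evaluate/utils/funcs.prepare_input` shows how these tensors are built from a dataset batch, including hint sampling and projection onto the target view.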
56 | ### Visualize
57 | 
58 | We provide a GUI (based on [viser](https://viser.studio/latest/#)) to run Depth on Demand interactively on ScanNetV2 and 7Scenes sequences. To start the visualizer, run the following command and open your browser at `http://127.0.0.1:8080`, as described by the script output.
59 | 
60 | ```bash
61 | $ python visualize/run.py
62 | ```
63 | 
64 | visualizer_example
65 | 
66 | On the right, the interface parameters can be configured:
67 | 
68 | - **Dataset**: select either scannetv2 or sevenscenes from the drop-down menu; the dataset root should be updated accordingly.
69 | - **Build Scene**: choose the scene to use from the drop-down menu. The other parameters control the density of the input sparse depth, both temporally and spatially.
70 | - **TSDF Parameters**: you shouldn't need to change these parameters. They control the mesh reconstruction; tweaking them changes the quality of the mesh.
71 | 
72 | Press **Build Scene** to start the reconstruction. You'll see the reconstruction progress, and you can move the view angle by dragging with the mouse.
73 | 
74 | **Attention: the interface in this phase may flicker and this step may potentially trigger seizures for people with photosensitive epilepsy. Viewer discretion is advised**
75 | 
76 | Once the reconstruction is done, playback can be run multiple times at different frame rates using the **Demo** box and the **Run Demo** button. With the `keep background mesh` option you can either observe the mesh integration growing over time or keep the final mesh fixed while the point of view moves.
77 | 
78 | Finally, the **Render Video** button allows you to save a demo video and, with the `save frames` checkbox, to **save all the predictions of Depth on Demand**. Results are saved in the `renders` directory.
79 | 
80 | ## Citation
81 | 
82 | ```bibtex
83 | @InProceedings{DoD_Conti24,
84 |     author = {Conti, Andrea and Poggi, Matteo and Cambareri, Valerio and Mattoccia, Stefano},
85 |     title = {Depth on Demand: Streaming Dense Depth from a Low Frame-Rate Active Sensor},
86 |     booktitle = {European Conference on Computer Vision (ECCV)},
87 |     month = {October},
88 |     year = {2024},
89 | }
90 | ```
91 | 
--------------------------------------------------------------------------------
/download/_utils.py:
--------------------------------------------------------------------------------
1 | """
2 | Utilities to download data from GitHub Releases
3 | """
4 | 
5 | from urllib.request import urlopen, Request
6 | from urllib.error import HTTPError
7 | from zipfile import ZipFile
8 | import tempfile
9 | from pathlib import Path
10 | import os
11 | from tqdm import tqdm
12 | import hashlib
13 | import shutil
14 | 
15 | __all__ = [
16 |     "github_download_release_asset",
17 |     "github_download_unzip_assets",
18 |     "download_url_to_file",
19 | ]
20 | 
21 | 
22 | def github_download_unzip_assets(
23 |     owner: str,
24 |     repo: str,
25 |     asset_ids: list[str],
26 |     dst: str | Path,
27 | ):
28 |     try:
29 |         Path(dst).mkdir(exist_ok=True, parents=True)
30 |         tmpdir = tempfile.mkdtemp()
31 |         for asset_id in asset_ids:
32 |             tmpfile = os.path.join(tmpdir, asset_id)
33 |             github_download_release_asset(owner, repo, asset_id, tmpfile)
34 |             with ZipFile(tmpfile) as f:
35 |                 f.extractall(dst)
36 |     finally:
37 |         if os.path.exists(tmpdir):
38 |             shutil.rmtree(tmpdir, ignore_errors=True)
39 | 
40 | 
41 | def github_download_release_asset(
42 |     owner: str,
43 |     repo: str,
44 |     asset_id: str,
45 |     dst: str | Path,
46 | ):
47 |     headers = {
48 |         "Accept": "application/octet-stream",
49 |     }
50 |     if token := os.environ.get("GITHUB_TOKEN", None):
51 |         headers["Authorization"] = f"Bearer {token}"
52 | 
53 |     try:
54 |         download_url_to_file(
55 |             f"https://api.github.com/repos/{owner}/{repo}/releases/assets/{asset_id}",
56 |             dst,
57 |             headers=headers,
58 |         )
59 |     except HTTPError as e:
60 |         if e.code == 404 and "Authorization" not in headers:
61 |             raise RuntimeError(
62 |                 "File not found, maybe missing GITHUB_TOKEN env variable?"
63 | ) 64 | raise e 65 | 66 | 67 | def download_url_to_file( 68 | url, 69 | dst, 70 | *, 71 | hash_prefix=None, 72 | progress=True, 73 | headers=None, 74 | ): 75 | # find file size 76 | file_size = None 77 | req = Request(url, headers={} if not headers else headers) 78 | u = urlopen(req) 79 | meta = u.info() 80 | if hasattr(meta, "getheaders"): 81 | content_length = meta.getheaders("Content-Length") 82 | else: 83 | content_length = meta.get_all("Content-Length") 84 | if content_length is not None and len(content_length) > 0: 85 | file_size = int(content_length[0]) 86 | 87 | dst = os.path.expanduser(dst) 88 | dst_dir = os.path.dirname(dst) 89 | f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) 90 | 91 | try: 92 | if hash_prefix is not None: 93 | sha256 = hashlib.sha256() 94 | with tqdm( 95 | total=file_size, 96 | disable=not progress, 97 | unit="B", 98 | unit_scale=True, 99 | unit_divisor=1024, 100 | ) as pbar: 101 | while True: 102 | buffer = u.read(8192) 103 | if len(buffer) == 0: 104 | break 105 | f.write(buffer) 106 | if hash_prefix is not None: 107 | sha256.update(buffer) 108 | pbar.update(len(buffer)) 109 | 110 | f.close() 111 | if hash_prefix is not None: 112 | digest = sha256.hexdigest() 113 | if digest[: len(hash_prefix)] != hash_prefix: 114 | raise RuntimeError( 115 | 'invalid hash value (expected "{}", got "{}")'.format( 116 | hash_prefix, digest 117 | ) 118 | ) 119 | shutil.move(f.name, dst) 120 | finally: 121 | f.close() 122 | if os.path.exists(f.name): 123 | os.remove(f.name) 124 | 125 | -------------------------------------------------------------------------------- /download/kitti.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to download the KITTI Depth Completion 3 | validation set 4 | """ 5 | 6 | import os 7 | import re 8 | import shutil 9 | from tqdm import tqdm 10 | from pathlib import Path 11 | from zipfile import ZipFile 12 | import requests 13 | from argparse import ArgumentParser 14 | 15 | 16 | calib_scans = [ 17 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_30_calib.zip", 18 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_29_calib.zip", 19 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_26_calib.zip", 20 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_10_03_calib.zip", 21 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/2011_09_28_calib.zip", 22 | ] 23 | 24 | scenes = list( 25 | map( 26 | lambda x: "https://s3.eu-central-1.amazonaws.com/avg-kitti/raw_data/" 27 | + f"{x}/{x}_sync.zip", 28 | [ 29 | "2011_09_26_drive_0020", 30 | "2011_09_26_drive_0036", 31 | "2011_09_26_drive_0002", 32 | "2011_09_26_drive_0013", 33 | "2011_09_26_drive_0005", 34 | "2011_09_26_drive_0113", 35 | "2011_09_26_drive_0023", 36 | "2011_09_26_drive_0079", 37 | "2011_09_29_drive_0026", 38 | "2011_09_30_drive_0016", 39 | "2011_10_03_drive_0047", 40 | "2011_09_26_drive_0095", 41 | "2011_09_28_drive_0037", 42 | ], 43 | ) 44 | ) 45 | 46 | 47 | def download_file(url, save_path, chunk_size=1024, verbose=True): 48 | """ 49 | Downloads a zip file from an `url` into a zip file in the 50 | provided `save_path`. 
51 | """ 52 | r = requests.get(url, stream=True) 53 | zip_name = url.split("/")[-1] 54 | 55 | content_length = int(r.headers["Content-Length"]) / 10**6 56 | 57 | if verbose: 58 | bar = tqdm(total=content_length, unit="Mb", desc="download " + zip_name) 59 | with open(save_path, "wb") as fd: 60 | for chunk in r.iter_content(chunk_size=chunk_size): 61 | fd.write(chunk) 62 | if verbose: 63 | bar.update(chunk_size / 10**6) 64 | 65 | if verbose: 66 | bar.close() 67 | 68 | 69 | def raw_download(root_path: str): 70 | 71 | date_match = re.compile("[0-9]+_[0-9]+_[0-9]+") 72 | drive_match = re.compile("[0-9]+_[0-9]+_[0-9]+_drive_[0-9]+_sync") 73 | 74 | def download_unzip(url): 75 | date = date_match.findall(url)[0] 76 | drive = drive_match.findall(url)[0] 77 | os.makedirs(os.path.join(root_path, date), exist_ok=True) 78 | download_file(url, os.path.join(root_path, date, drive + ".zip"), verbose=False) 79 | with ZipFile(os.path.join(root_path, date, drive + ".zip"), "r") as zip_ref: 80 | zip_ref.extractall(os.path.join(root_path, date, drive + "_tmp")) 81 | os.rename( 82 | os.path.join(root_path, date, drive + "_tmp", date, drive), 83 | os.path.join(root_path, date, drive), 84 | ) 85 | shutil.rmtree(os.path.join(root_path, date, drive + "_tmp")) 86 | os.remove(os.path.join(root_path, date, drive + ".zip")) 87 | 88 | for scene in tqdm(scenes, desc="download kitti (raw)"): 89 | download_unzip(scene) 90 | 91 | 92 | def dc_download(root_path: str, progress_bar=True): 93 | """ 94 | Downloads and scaffold depth completion dataset in `root_path` 95 | """ 96 | 97 | if not os.path.exists(root_path): 98 | os.mkdir(root_path) 99 | 100 | # urls 101 | data_depth_selection_url = ( 102 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_selection.zip" 103 | ) 104 | data_depth_velodyne_url = ( 105 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_velodyne.zip" 106 | ) 107 | data_depth_annotated_url = ( 108 | "https://s3.eu-central-1.amazonaws.com/avg-kitti/data_depth_annotated.zip" 109 | ) 110 | 111 | # download of zips 112 | download_file( 113 | data_depth_selection_url, 114 | os.path.join(root_path, "data_depth_selection.zip"), 115 | verbose=progress_bar, 116 | ) 117 | 118 | download_file( 119 | data_depth_velodyne_url, 120 | os.path.join(root_path, "data_depth_velodyne.zip"), 121 | verbose=progress_bar, 122 | ) 123 | download_file( 124 | data_depth_annotated_url, 125 | os.path.join(root_path, "data_depth_annotated.zip"), 126 | verbose=progress_bar, 127 | ) 128 | 129 | # unzip and remove zips 130 | with ZipFile(os.path.join(root_path, "data_depth_selection.zip"), "r") as zip_ref: 131 | zip_ref.extractall(root_path) 132 | os.rename( 133 | os.path.join( 134 | root_path, "depth_selection", "test_depth_completion_anonymous" 135 | ), 136 | os.path.join(root_path, "test_depth_completion_anonymous"), 137 | ) 138 | os.rename( 139 | os.path.join( 140 | root_path, "depth_selection", "test_depth_prediction_anonymous" 141 | ), 142 | os.path.join(root_path, "test_depth_prediction_anonymous"), 143 | ) 144 | os.rename( 145 | os.path.join(root_path, "depth_selection", "val_selection_cropped"), 146 | os.path.join(root_path, "val_selection_cropped"), 147 | ) 148 | os.rmdir(os.path.join(root_path, "depth_selection")) 149 | with ZipFile(os.path.join(root_path, "data_depth_velodyne.zip"), "r") as zip_ref: 150 | zip_ref.extractall(root_path) 151 | with ZipFile(os.path.join(root_path, "data_depth_annotated.zip"), "r") as zip_ref: 152 | zip_ref.extractall(root_path) 153 | 154 | # remove zip files 155 | 
    os.remove(os.path.join(root_path, "data_depth_selection.zip"))
156 |     os.remove(os.path.join(root_path, "data_depth_annotated.zip"))
157 |     os.remove(os.path.join(root_path, "data_depth_velodyne.zip"))
158 | 
159 | 
160 | def calib_download(root_path: str):
161 |     """
162 |     Downloads and scaffolds calibration files
163 |     """
164 | 
165 |     Path(root_path).mkdir(exist_ok=True)
166 | 
167 |     for repo in calib_scans:
168 |         calib_zip_path = os.path.join(root_path, "calib.zip")
169 |         download_file(repo, calib_zip_path)
170 |         with open(calib_zip_path, "rb") as f:
171 |             ZipFile(f).extractall(root_path)
172 |         os.remove(calib_zip_path)
173 | 
174 | 
175 | if __name__ == "__main__":
176 |     parser = ArgumentParser("download the KITTI test set for evaluation")
177 |     parser.add_argument("--root", type=Path, default=Path("data/kitti"))
178 |     args = parser.parse_args()
179 |     raw_download(str(args.root / "raw"))
180 |     calib_download(str(args.root / "raw"))
181 |     dc_download(str(args.root / "depth_completion"))
182 | 
--------------------------------------------------------------------------------
/download/scannetv2.py:
--------------------------------------------------------------------------------
1 | from _utils import github_download_unzip_assets
2 | from argparse import ArgumentParser
3 | from zipfile import ZipFile
4 | from pathlib import Path
5 | import os
6 | import gdown
7 | 
8 | 
9 | def download_and_unzip_meshes(root: str | Path):
10 |     root = Path(root)
11 | 
12 |     # from TransformerFusion
13 |     # https://github.com/AljazBozic/TransformerFusion?tab=readme-ov-file
14 |     url = "https://drive.usercontent.google.com/download?id=1-nto65_JTNs1vyeHycebidYFyQvE6kt4&authuser=0&confirm=t&uuid=05aba470-c11c-48f9-8ba3-162560ab15bb&at=APZUnTUsLTVanQWS8dVSGr5pA3hJ%3A1710523361429"
15 |     out_zip = str(Path(root) / ".meshes.zip")
16 |     gdown.download(url, out_zip, quiet=False)
17 | 
18 |     # unzip data in the correct format
19 |     with ZipFile(out_zip, "r") as f:
20 |         for file in f.infolist():
21 |             if not file.is_dir():
22 |                 _, scene, fname = file.filename.split("/")
23 |                 new_path = root / (
24 |                     f"{scene}_vh_clean.ply"
25 |                     if fname == "mesh_gt.ply"
26 |                     else f"{scene}_{fname}"
27 |                 )
28 |                 data = f.read(file)
29 |                 with open(new_path, "wb") as out:
30 |                     out.write(data)
31 | 
32 |     # remove the zip
33 |     os.remove(out_zip)
34 | 
35 | 
36 | if __name__ == "__main__":
37 |     parser = ArgumentParser("download the ScanNetV2 test set for evaluation")
38 |     parser.add_argument("--root", type=Path, default=Path("data/scannetv2"))
39 |     args = parser.parse_args()
40 |     github_download_unzip_assets(
41 |         "andreaconti",
42 |         "depth-on-demand",
43 |         ["156761793", "156761795", "156761794"],
44 |         args.root,
45 |     )
46 |     download_and_unzip_meshes(args.root)
47 | 
--------------------------------------------------------------------------------
/download/sevenscenes.py:
--------------------------------------------------------------------------------
1 | from _utils import github_download_unzip_assets
2 | from argparse import ArgumentParser
3 | from pathlib import Path
4 | 
5 | if __name__ == "__main__":
6 |     parser = ArgumentParser("download the 7Scenes test set for evaluation")
7 |     parser.add_argument("--root", type=Path, default=Path("data/sevenscenes"))
8 |     args = parser.parse_args()
9 |     github_download_unzip_assets(
10 |         "andreaconti",
11 |         "depth-on-demand",
12 |         ["155673151"],
13 |         args.root,
14 |     )
15 | 
--------------------------------------------------------------------------------
/download/tartanair.py:
-------------------------------------------------------------------------------- 1 | from _utils import github_download_unzip_assets 2 | from argparse import ArgumentParser 3 | from pathlib import Path 4 | 5 | if __name__ == "__main__": 6 | parser = ArgumentParser("download the TartanAir test set for evaluation") 7 | parser.add_argument("--root", type=Path, default=Path("data/tartanair")) 8 | args = parser.parse_args() 9 | github_download_unzip_assets( 10 | "andreaconti", 11 | "depth-on-demand", 12 | ["156760244", "156760245", "156760243", "156760242"], 13 | args.root, 14 | ) 15 | -------------------------------------------------------------------------------- /environment.yml: -------------------------------------------------------------------------------- 1 | name: depth-on-demand 2 | channels: 3 | - pytorch 4 | - nvidia 5 | - conda-forge 6 | - defaults 7 | - pytorch3d 8 | dependencies: 9 | - python=3.10.11 10 | - pytorch=2.0.1 11 | - pytorch-cuda=11.8 12 | - torchvision=0.15.2 13 | - torchdata=0.6.0 14 | - lightning=2.0.3 15 | - albumentations=1.3.1 16 | - kornia=0.6.12 17 | - pytorch3d=0.7.4 18 | 19 | # dev 20 | - ipython 21 | - ipykernel 22 | - black 23 | - pytest 24 | 25 | # c++ dependencies 26 | - eigen 27 | 28 | - pip 29 | - pip: 30 | - pipe==2.0.0 31 | - omegaconf==2.3.0 32 | - open3d==0.17.0 33 | - plyfile==0.9.0 34 | - pykitti==0.3.1 35 | - timm==0.9.7 36 | - viser>=0.1.7<1.0.0 37 | - h5py>=3.10.0 38 | - git+https://github.com/PoseLib/PoseLib.git 39 | - git+https://github.com/cvg/LightGlue.git 40 | - gdown 41 | 42 | -------------------------------------------------------------------------------- /evaluate/config/kitti.yaml: -------------------------------------------------------------------------------- 1 | args: 2 | root_raw: data/kitti/raw 3 | root_completion: data/kitti/depth_completion 4 | batch_size: 1 5 | load_prevs: 0 6 | shuffle_keyframes: false 7 | order_sources_by_pose: false 8 | load_hints_pcd: true 9 | keyframes: standard 10 | min_depth: 0.001 11 | max_depth: 80.00 12 | 13 | inference: 14 | depth_hints: 15 | density: null 16 | interval: 10 17 | hints_from_pcd: true 18 | pnp_pose: true 19 | source: 20 | interval: ${inference.depth_hints.interval} 21 | model_eval_params: 22 | n_cycles: 10 -------------------------------------------------------------------------------- /evaluate/config/scannetv2.yaml: -------------------------------------------------------------------------------- 1 | args: 2 | root: data/scannetv2 3 | batch_size: 1 4 | load_prevs: 0 5 | shuffle_keyframes: false 6 | order_sources_by_pose: false 7 | keyframes: standard 8 | min_depth: 0.0 9 | max_depth: 10.0 10 | 11 | tsdf_fusion: 12 | engine: open3d 13 | depth_scale: 1 14 | depth_trunc: 3 15 | sdf_trunc: 0.12 16 | voxel_length: 0.04 17 | 18 | inference: 19 | depth_hints: 20 | interval: 5 21 | density: 500 22 | hints_from_pcd: false 23 | pnp_pose: false 24 | source: 25 | interval: 5 26 | model_eval_params: 27 | n_cycles: 10 28 | -------------------------------------------------------------------------------- /evaluate/config/sevenscenes.yaml: -------------------------------------------------------------------------------- 1 | args: 2 | root: data/sevenscenes 3 | batch_size: 1 4 | load_prevs: 0 5 | shuffle_keyframes: false 6 | order_sources_by_pose: false 7 | keyframes: standard 8 | min_depth: 0.0 9 | max_depth: 10.0 10 | 11 | tsdf_fusion: 12 | engine: open3d 13 | depth_scale: 1 14 | depth_trunc: 3 15 | sdf_trunc: 0.12 16 | voxel_length: 0.04 17 | 18 | inference: 19 | depth_hints: 20 | interval: 5 21 | 
density: 500 22 | hints_from_pcd: false 23 | pnp_pose: false 24 | source: 25 | interval: 5 26 | model_eval_params: 27 | n_cycles: 10 -------------------------------------------------------------------------------- /evaluate/config/tartanair.yaml: -------------------------------------------------------------------------------- 1 | args: 2 | root: data/tartanair 3 | batch_size: 1 4 | load_prevs: 0 5 | shuffle_keyframes: false 6 | order_sources_by_pose: false 7 | keyframes: standard 8 | min_depth: 0.001 9 | max_depth: 100.0 10 | 11 | inference: 12 | depth_hints: 13 | interval: 5 14 | density: 500 15 | hints_from_pcd: false 16 | pnp_pose: false 17 | source: 18 | interval: 5 19 | model_eval_params: 20 | n_cycles: 10 -------------------------------------------------------------------------------- /evaluate/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Script to reproduce the DoD Paper metrics on various datasets used 3 | """ 4 | 5 | import sys 6 | 7 | # In[] Imports 8 | from pathlib import Path 9 | 10 | sys.path.append(str(Path(__file__).parents[1])) 11 | import shutil 12 | import tempfile 13 | import warnings 14 | from argparse import ArgumentParser 15 | from collections import defaultdict 16 | from copy import deepcopy 17 | 18 | import pandas as pd 19 | import torch 20 | import utils.funcs as funcs 21 | from depth_on_demand import Model as DepthOnDemand 22 | from omegaconf import OmegaConf 23 | from torchmetrics import MeanMetric 24 | from tqdm import tqdm 25 | from utils.metrics import compute_metrics_depth 26 | 27 | from lib.benchmark.mesh import TSDFFusion, mesh_metrics 28 | from lib.dataset import load_datamodule 29 | 30 | warnings.filterwarnings("ignore", category=UserWarning) 31 | 32 | rootdir = Path(__file__).parents[1] 33 | thisdir = Path(__file__).parent 34 | 35 | # In[] Args 36 | 37 | parser = ArgumentParser("Test DoD Model") 38 | parser.add_argument( 39 | "--dataset", 40 | choices=["kitti", "scannetv2", "sevenscenes", "tartanair"], 41 | default="sevenscenes", 42 | ) 43 | parser.add_argument("--device", default="cuda:0") 44 | parser.add_argument("--f") 45 | args = parser.parse_args() 46 | 47 | # In[] Load dataset 48 | cfg = OmegaConf.load(thisdir / "config" / (args.dataset + ".yaml")) 49 | dm = load_datamodule( 50 | args.dataset, 51 | **cfg.args, 52 | split_test_scans_loaders=True, 53 | ) 54 | dm.prepare_data() 55 | dm.setup("test") 56 | scans = dm.test_dataloader() 57 | 58 | # In[] Load the model 59 | model = DepthOnDemand( 60 | pretrained={ 61 | "sevenscenes": "scannetv2", 62 | "scannetv2": "scannetv2", 63 | "tartanair": "tartanair", 64 | "kitti": "kitti", 65 | }[args.dataset], 66 | device=args.device, 67 | ) 68 | 69 | # In[] Testing 70 | predict_pose = None 71 | if cfg.inference.depth_hints.pnp_pose: 72 | 73 | from utils.pose_estimation import PoseEstimation 74 | 75 | predict_pose = PoseEstimation(dilate_depth=3) 76 | predict_pose.to(args.device) 77 | 78 | metrics = defaultdict(lambda: MeanMetric().to(args.device)) 79 | for scan in tqdm(scans): 80 | 81 | buffer_hints = {} 82 | buffer_source = {} 83 | if args.dataset == "scannetv2": 84 | tsdf_volume = TSDFFusion(**cfg.tsdf_fusion) 85 | 86 | for idx, batch in enumerate(tqdm(scan, leave=False)): 87 | 88 | batch = { 89 | k: v.to(args.device) if isinstance(v, torch.Tensor) else v 90 | for k, v in batch.items() 91 | } 92 | if idx % cfg.inference.depth_hints.interval == 0: 93 | buffer_hints = deepcopy(batch) 94 | if idx % cfg.inference.source.interval == 0: 95 | buffer_source = 
deepcopy(batch) 96 | 97 | input = funcs.prepare_input( 98 | batch 99 | | {k + "_prev_0": v for k, v in buffer_hints.items()} 100 | | {k + "_prev_1": v for k, v in buffer_source.items()}, 101 | hints_density=cfg.inference.depth_hints.density, 102 | hints_from_pcd=cfg.inference.depth_hints.hints_from_pcd, 103 | hints_postfix="_prev_0", 104 | source_postfix="_prev_1", 105 | predict_pose=predict_pose, 106 | ) 107 | buffer_source["hints"] = input.pop("source_hints") 108 | buffer_hints["hints"] = input.pop("hints_hints") 109 | 110 | # inference 111 | gt = batch["depth"].to(args.device) 112 | mask = gt > 0 113 | with torch.no_grad(): 114 | pred_depth = model( 115 | **input, 116 | n_cycles=cfg.inference.model_eval_params.n_cycles, 117 | ) 118 | 119 | # metrics 120 | for k, v in compute_metrics_depth(pred_depth[mask], gt[mask], "test").items(): 121 | metrics[k].update(v) 122 | if args.dataset == "scannetv2": 123 | tsdf_volume.integrate_rgbd( 124 | batch["image"][0].permute(1, 2, 0).cpu().numpy(), 125 | pred_depth[0].permute(1, 2, 0).detach().cpu().numpy(), 126 | batch["intrinsics"][0].cpu().numpy(), 127 | batch["extrinsics"][0].cpu().numpy(), 128 | ) 129 | 130 | if args.dataset == "scannetv2": 131 | mesh_pred = tsdf_volume.triangle_mesh() 132 | for k, v in mesh_metrics( 133 | mesh_pred, 134 | batch["gt_mesh_path"][0], 135 | batch["gt_mesh_world2grid_path"][0], 136 | batch["gt_mesh_occl_path"][0], 137 | device=args.device, 138 | ).items(): 139 | metrics["test/" + k].update(v) 140 | tmpdir = Path(tempfile.mkdtemp()) 141 | scan_name = "_".join(Path(batch["gt_mesh_path"][0]).stem.split("_")[:2]) 142 | mesh_path = tmpdir / (scan_name + ".ply") 143 | tsdf_volume.write_triangle_mesh(mesh_path) 144 | shutil.rmtree(tmpdir, ignore_errors=True) 145 | 146 | # %% 147 | 148 | print(f"== metrics for {args.dataset} ==") 149 | for k, v in metrics.items(): 150 | print(f"{k.ljust(25)}: {v.compute():0.5f}") 151 | -------------------------------------------------------------------------------- /evaluate/utils/funcs.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import kornia 3 | from torch import Tensor 4 | import random 5 | import torch.nn.functional as F 6 | from typing import Callable 7 | from lib.visualize.matplotlib import GridFigure 8 | from kornia.morphology import dilation 9 | import io 10 | import itertools 11 | from contextlib import redirect_stdout 12 | import kornia 13 | import numpy as np 14 | from kornia.geometry.quaternion import QuaternionCoeffOrder 15 | from kornia.geometry.conversions import ( 16 | euler_from_quaternion, 17 | rotation_matrix_to_quaternion, 18 | ) 19 | import re 20 | from lib.visualize.matplotlib import color_depth 21 | 22 | __all__ = ["sparsity_depth", "prepare_input"] 23 | 24 | 25 | def prepare_input( 26 | batch, 27 | hints_density: float | int = 500, 28 | hints_from_pcd: bool = False, 29 | source_postfix: str = "", 30 | hints_postfix: str = "", 31 | predict_pose: Callable | None = None, 32 | pose_noise_std_mult: float = 0.0, 33 | ): 34 | # prepare source and target views 35 | image_target = batch["image"] 36 | image_source = batch[f"image{source_postfix}"] 37 | intrinsics = torch.stack( 38 | [batch["intrinsics"], batch[f"intrinsics{source_postfix}"]], 1 39 | ) 40 | 41 | # sample hints on their frame 42 | hints_hints = find_hints(batch, hints_postfix, hints_density, hints_from_pcd) 43 | if hints_postfix != source_postfix: 44 | hints_src = find_hints(batch, source_postfix, hints_density, hints_from_pcd) 45 | else: 46 | 
hints_src = hints_hints 47 | 48 | # prepare pose 49 | if predict_pose is None: 50 | pose_src_tgt = ( 51 | inv_pose(batch["extrinsics"]) @ batch[f"extrinsics{source_postfix}"] 52 | ) 53 | pose_hints_tgt = ( 54 | inv_pose(batch["extrinsics"]) @ batch[f"extrinsics{hints_postfix}"] 55 | ) 56 | if pose_noise_std_mult > 0.0: 57 | if source_postfix == hints_postfix: 58 | trasl_std = pose_noise_std_mult * pose_src_tgt[:, :3, -1].abs() 59 | rot_std = pose_noise_std_mult * pose_to_angles(pose_src_tgt).abs() 60 | pnoise = pose_noise(rot_std, trasl_std) 61 | pose_src_tgt = pnoise @ pose_src_tgt 62 | pose_hints_tgt = pnoise @ pose_hints_tgt 63 | else: 64 | trasl_std = pose_noise_std_mult * pose_src_tgt[:, :3, -1].abs() 65 | rot_std = pose_noise_std_mult * pose_to_angles(pose_src_tgt).abs() 66 | pnoise = pose_noise(rot_std, trasl_std) 67 | pose_src_tgt = pnoise @ pose_src_tgt 68 | 69 | trasl_std = pose_noise_std_mult * pose_hints_tgt[:, :3, -1].abs() 70 | rot_std = pose_noise_std_mult * pose_to_angles(pose_hints_tgt).abs() 71 | pnoise = pose_noise(rot_std, trasl_std) 72 | pose_hints_tgt = pnoise @ pose_hints_tgt 73 | 74 | else: 75 | pose_hints_tgt = predict_pose( 76 | image_target, 77 | batch[f"image{hints_postfix}"], 78 | hints_hints, 79 | batch["intrinsics"], 80 | batch[f"intrinsics{hints_postfix}"], 81 | ) 82 | 83 | if hints_postfix != source_postfix: 84 | pose_src_tgt = predict_pose( 85 | image_target, 86 | batch[f"image{source_postfix}"], 87 | hints_src, 88 | batch["intrinsics"], 89 | batch[f"intrinsics{source_postfix}"], 90 | ) 91 | else: 92 | pose_src_tgt = pose_hints_tgt 93 | 94 | # project hints 95 | if not hints_from_pcd: 96 | h, w = image_target.shape[-2:] 97 | hints, _ = project_depth( 98 | hints_hints.to(torch.float32), 99 | batch[f"intrinsics{hints_postfix}"].to(torch.float32), 100 | batch[f"intrinsics"].to(torch.float32), 101 | pose_hints_tgt.to(torch.float32), 102 | torch.zeros_like(hints_hints, dtype=torch.float32), 103 | ) 104 | else: 105 | b, _, h, w = image_target.shape 106 | device, dtype = image_target.device, image_target.dtype 107 | hints = torch.zeros(b, 1, h, w, device=device, dtype=dtype) 108 | project_pcd( 109 | batch[f"hints_pcd{hints_postfix}"], 110 | batch[f"intrinsics"], 111 | hints, 112 | pose_hints_tgt, 113 | ) 114 | 115 | return { 116 | "target": image_target, 117 | "source": image_source, 118 | "pose_src_tgt": pose_src_tgt, 119 | "intrinsics": intrinsics, 120 | "hints": hints, 121 | "source_hints": hints_src, 122 | "hints_hints": hints_hints, 123 | } 124 | 125 | 126 | def prepare_input_onnx( 127 | batch, 128 | hints_density: float | int = 500, 129 | hints_from_pcd: bool = False, 130 | source_postfix: str = "", 131 | hints_postfix: str = "", 132 | predict_pose: Callable | None = None, 133 | pose_noise_std_mult: float = 0.0, 134 | ): 135 | 136 | # take the useful base outputs 137 | base_input = prepare_input( 138 | batch, 139 | hints_density, 140 | hints_from_pcd, 141 | source_postfix, 142 | hints_postfix, 143 | predict_pose, 144 | pose_noise_std_mult, 145 | ) 146 | out = { 147 | "target": base_input["target"], 148 | "source": base_input["source"], 149 | "pose_src_tgt": base_input["pose_src_tgt"], 150 | "intrinsics": base_input["intrinsics"], 151 | } 152 | 153 | # compute the other outputs required by the onnx model 154 | b, _, h, w = base_input["hints_hints"].shape 155 | device = base_input["hints_hints"].device 156 | hints8, subpix8 = project_depth( 157 | base_input["hints_hints"], 158 | batch["intrinsics"], 159 | adjust_intrinsics(batch["intrinsics"], 1 / 8), 160 | 
torch.eye(4, device=device)[None], 161 | torch.zeros( 162 | 1, 1, h // 8, w // 8, device=device, dtype=base_input["hints_hints"].dtype 163 | ), 164 | ) 165 | init_depth = create_init_depth(hints8) 166 | out |= { 167 | "hints8": hints8, 168 | "subpix8": subpix8, 169 | "init_depth": init_depth, 170 | } 171 | return out, base_input 172 | 173 | 174 | def adjust_intrinsics(intrinsics: Tensor, factor: float = 0.125): 175 | intrinsics = intrinsics.clone() 176 | intrinsics[..., :2, :] = intrinsics[..., :2, :] * factor 177 | return intrinsics 178 | 179 | 180 | def create_init_depth(hints8, fallback_init_depth=2.0, min_hints_for_init=20): 181 | batch_size = hints8.shape[0] 182 | mean_hints = ( 183 | torch.ones(batch_size, dtype=hints8.dtype, device=hints8.device) 184 | * fallback_init_depth 185 | ) 186 | for bi in range(batch_size): 187 | hints_mask = hints8[bi, 0] > 0 188 | if hints_mask.sum() > min_hints_for_init: 189 | mean_hints[bi] = hints8[bi, 0][hints8[bi, 0] > 0].mean() 190 | depth = torch.ones_like(hints8) * mean_hints[:, None, None, None] 191 | depth = torch.where(hints8 > 0, hints8, depth) 192 | return depth 193 | 194 | 195 | def find_hints( 196 | batch, 197 | postfix, 198 | hints_density, 199 | from_pcd: bool = False, 200 | project_pose: Tensor | None = None, 201 | ): 202 | if not from_pcd: 203 | if f"hints{postfix}" in batch: 204 | sparse_hints = batch[f"hints{postfix}"] 205 | else: 206 | sparse_hints = sparsify_depth(batch[f"depth{postfix}"], hints_density) 207 | return sparse_hints 208 | else: 209 | hints_pcd = batch[f"hints_pcd{postfix}"] 210 | device, dtype = hints_pcd.device, hints_pcd.dtype 211 | b, _, h, w = batch["image"].shape 212 | sparse_hints = project_pcd( 213 | batch[f"hints_pcd{postfix}"], 214 | batch["intrinsics"], 215 | torch.zeros(b, 1, h, w, device=device, dtype=dtype), 216 | project_pose, 217 | ) 218 | return sparse_hints 219 | 220 | 221 | def sparsify_depth( 222 | depth: torch.Tensor, hints_perc: float | tuple[float, float] | int = 0.03 223 | ) -> torch.Tensor: 224 | if isinstance(hints_perc, tuple | list): 225 | hints_perc = random.uniform(*hints_perc) 226 | 227 | if hints_perc < 1.0: 228 | sparse_map = torch.rand_like(depth) < hints_perc 229 | sparse_depth = torch.where( 230 | sparse_map, depth, torch.tensor(0.0, dtype=depth.dtype, device=depth.device) 231 | ) 232 | return sparse_depth 233 | else: 234 | b = depth.shape[0] 235 | idxs = torch.nonzero(depth[:, 0]) 236 | idxs = idxs[torch.randperm(len(idxs))] 237 | sparse_depth = torch.zeros_like(depth) 238 | for bi in range(b): 239 | bidxs = idxs[idxs[:, 0] == bi][:hints_perc] 240 | sparse_depth[bi, 0, bidxs[:, 1], bidxs[:, 2]] = depth[ 241 | bi, 0, bidxs[:, 1], bidxs[:, 2] 242 | ] 243 | return sparse_depth 244 | 245 | 246 | def inv_pose(pose: torch.Tensor) -> torch.Tensor: 247 | rot_inv = pose[:, :3, :3].permute(0, 2, 1) 248 | tr_inv = -rot_inv @ pose[:, :3, -1:] 249 | pose_inv = torch.eye(4, dtype=pose.dtype, device=pose.device)[None] 250 | pose_inv = pose_inv.repeat(pose.shape[0], 1, 1) 251 | pose_inv[:, :3, :3] = rot_inv 252 | pose_inv[:, :3, -1:] = tr_inv 253 | return pose_inv 254 | 255 | 256 | def pose_distance( 257 | pose: torch.Tensor, 258 | ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: 259 | rot = pose[:, :3, :3] 260 | trasl = pose[:, :3, 3] 261 | R_trace = rot.diagonal(offset=0, dim1=-1, dim2=-2).sum(-1) 262 | r_measure = torch.sqrt( 263 | 2 * (1 - torch.minimum(torch.ones_like(R_trace) * 3.0, R_trace) / 3) 264 | ) 265 | t_measure = torch.norm(trasl, dim=1) 266 | combined_measure = 
torch.sqrt(t_measure**2 + r_measure**2) 267 | 268 | return combined_measure, r_measure, t_measure 269 | 270 | 271 | def project_pcd( 272 | xyz_pcd: torch.Tensor, 273 | intrinsics_to: torch.Tensor, 274 | depth_to: torch.Tensor, 275 | extrinsics_from_to: torch.Tensor | None = None, 276 | ) -> torch.Tensor: 277 | # transform pcd 278 | batch, _, _ = xyz_pcd.shape 279 | if extrinsics_from_to is None: 280 | extrinsics_from_to = torch.eye(4, dtype=xyz_pcd.dtype, device=xyz_pcd.device)[ 281 | None 282 | ].repeat(batch, 1, 1) 283 | xyz_pcd_to = ( 284 | intrinsics_to 285 | @ ( 286 | extrinsics_from_to 287 | @ torch.nn.functional.pad(xyz_pcd, [0, 1], "constant", 1.0).permute(0, 2, 1) 288 | )[:, :3] 289 | ) 290 | 291 | # project depth to 2D 292 | h_to, w_to = depth_to.shape[-2:] 293 | u, v = torch.unbind(xyz_pcd_to[:, :2] / xyz_pcd_to[:, -1:], 1) 294 | u, v = torch.round(u).to(torch.long), torch.round(v).to(torch.long) 295 | mask = (u >= 0) & (v >= 0) & (u < w_to) & (v < h_to) 296 | 297 | for b in range(batch): 298 | used_mask = mask[b] 299 | used_u, used_v = u[b, used_mask], v[b, used_mask] 300 | prev_depths = depth_to[b, 0, used_v, used_u] 301 | new_depths = xyz_pcd_to[b, :, used_mask][-1] 302 | merged_depths = torch.where( 303 | (prev_depths == 0) & (new_depths > 0), new_depths, prev_depths 304 | ) 305 | depth_to[b, 0, used_v, used_u] = merged_depths 306 | 307 | return depth_to 308 | 309 | 310 | def project_depth( 311 | depth: Tensor, 312 | intrinsics_from: Tensor, 313 | intrinsics_to: Tensor, 314 | extrinsics_from_to: Tensor, 315 | depth_to: Tensor, 316 | ) -> tuple[Tensor, Tensor]: 317 | # project the depth in 3D 318 | batch, _, h, w = depth.shape 319 | xyz_pcd_from = kornia.geometry.depth_to_3d(depth, intrinsics_from).permute( 320 | 0, 2, 3, 1 321 | ) 322 | xyz_pcd_to = intrinsics_to @ ( 323 | ( 324 | extrinsics_from_to 325 | @ torch.nn.functional.pad( 326 | xyz_pcd_from.view(batch, -1, 3), [0, 1], "constant", 1.0 327 | ).permute(0, 2, 1) 328 | )[:, :3] 329 | ) 330 | xyz_pcd_to = xyz_pcd_to.permute(0, 2, 1).view(batch, h, w, 3) 331 | 332 | # project depth to 2D 333 | h_to, w_to = depth_to.shape[-2:] 334 | u_subpix, v_subpix = torch.unbind(xyz_pcd_to[..., :2] / xyz_pcd_to[..., -1:], -1) 335 | u, v = torch.round(u_subpix).to(torch.long), torch.round(v_subpix).to(torch.long) 336 | mask = (u >= 0) & (v >= 0) & (u < w_to) & (v < h_to) & (depth[:, 0] > 0) 337 | subpix = torch.zeros( 338 | batch, 2, h_to, w_to, dtype=depth_to.dtype, device=depth_to.device 339 | ) 340 | for b in range(batch): 341 | used_mask = mask[b] 342 | used_u, used_v = u[b, used_mask], v[b, used_mask] 343 | prev_depths = depth_to[b, 0, used_v, used_u] 344 | new_depths = xyz_pcd_to[b, used_mask][:, -1] 345 | merged_depths = torch.where(prev_depths == 0, new_depths, prev_depths) 346 | depth_to[b, 0, used_v, used_u] = merged_depths 347 | subpix[b, 0, used_v, used_u] = u_subpix[b, used_mask] / w_to 348 | subpix[b, 1, used_v, used_u] = v_subpix[b, used_mask] / h_to 349 | 350 | return depth_to, subpix 351 | 352 | 353 | def prepare_plot( 354 | gt_depth, 355 | pred_depth, 356 | image_target, 357 | image_source, 358 | hints, 359 | gt_dilation: int | None = None, 360 | hints_dilation: int | None = None, 361 | depth_vmin: float = 0.0, 362 | depth_vmax: float = 10.0, 363 | ) -> GridFigure: 364 | b, _, h, w = image_source.shape 365 | grid = GridFigure(b, 4, size=(h // 2, w // 2)) 366 | 367 | if hints_dilation: 368 | hints = dilation( 369 | hints, torch.ones((hints_dilation, hints_dilation), device=hints.device) 370 | ) 371 | if gt_dilation: 372 
| gt_depth = dilation( 373 | gt_depth, torch.ones((gt_dilation, gt_dilation), device=gt_depth.device) 374 | ) 375 | 376 | dkwargs = { 377 | "vmin": depth_vmin, 378 | "vmax": depth_vmax, 379 | "cmap": "magma_r", 380 | } 381 | 382 | idx = 0 383 | for bi in range(b): 384 | 385 | # source view 386 | grid.imshow(idx := idx + 1, image_source[bi]) 387 | 388 | # target view + hints 389 | img_target = image_target[bi] 390 | img_target = (img_target - img_target.min()) / ( 391 | img_target.max() - img_target.min() 392 | ) 393 | img_target = ( 394 | (img_target.cpu().numpy() * 255).astype(np.uint8).transpose([1, 2, 0]) 395 | ) 396 | grid.imshow(idx := idx + 1, img_target, norm=False) 397 | hints_color = color_depth(hints[bi, 0].cpu().numpy(), **dkwargs) 398 | hints_color[..., :3] = hints_color[..., :3] / hints_color[..., :3].max() 399 | hints_color[..., -1:] = hints_color[..., -1:] / hints_color[..., -1:].max() 400 | grid.imshow(idx, hints_color, norm=False) 401 | 402 | # prediction 403 | grid.imshow(idx := idx + 1, pred_depth[bi], **dkwargs, norm=False) 404 | 405 | # gt view 406 | gt_depth_imshow = torch.where(gt_depth[bi] > 0, gt_depth[bi], depth_vmax) 407 | grid.imshow(idx := idx + 1, gt_depth_imshow, **dkwargs, norm=False) 408 | 409 | return grid 410 | 411 | 412 | def pose_to_angles(pose: torch.Tensor): 413 | out_angles = [] 414 | angles = euler_from_quaternion( 415 | *rotation_matrix_to_quaternion( 416 | pose[:, :3, :3].contiguous(), order=QuaternionCoeffOrder.WXYZ 417 | )[0] 418 | ) 419 | angles = torch.stack(angles, 0) 420 | out_angles.append(angles) 421 | return torch.rad2deg(torch.stack(out_angles, 0)) 422 | 423 | 424 | def pose_noise(rot_std: torch.Tensor, trasl_std: torch.Tensor): 425 | std_vec = torch.cat([rot_std, trasl_std], 1) 426 | mean_vec = torch.zeros_like(std_vec) 427 | normal_noise = torch.normal(mean_vec, std_vec) 428 | noises = [] 429 | for bi in range(rot_std.shape[0]): 430 | rot_noise = torch.eye(4).to(normal_noise.dtype).to(normal_noise.device) 431 | quat = kornia.geometry.quaternion_from_euler( 432 | torch.deg2rad(normal_noise[bi, 0]), 433 | torch.deg2rad(normal_noise[bi, 1]), 434 | torch.deg2rad(normal_noise[bi, 2]), 435 | ) 436 | rot_noise[:3, :3] = kornia.geometry.quaternion_to_rotation_matrix( 437 | torch.stack(quat, -1), QuaternionCoeffOrder.WXYZ 438 | ) 439 | rot_noise[:3, -1] += normal_noise[bi, 3:] 440 | noises.append(rot_noise) 441 | return torch.stack(noises, 0) 442 | 443 | 444 | def find_parameters(): 445 | ipy = get_ipython() # type: ignore 446 | out = io.StringIO() 447 | with redirect_stdout(out): 448 | ipy.magic("history") 449 | x = out.getvalue().split("\n") 450 | param_lines = list( 451 | itertools.takewhile( 452 | lambda s: s != "##% END", 453 | itertools.dropwhile(lambda s: s != "##% PARAMETERS", x), 454 | ) 455 | ) 456 | params = [] 457 | for param in param_lines: 458 | if match := re.match("([a-zA-Z_0-9]+).*=", param): 459 | name = match.groups()[0] 460 | params.append(name) 461 | return params 462 | -------------------------------------------------------------------------------- /evaluate/utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used metrics in this project 3 | """ 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | __all__ = ["mae", "rmse", "sq_rel", "rel_thresh", "compute_metrics"] 9 | 10 | 11 | def mae(pred: Tensor, gt: Tensor) -> Tensor: 12 | return torch.mean(torch.abs(pred - gt)) 13 | 14 | 15 | def rmse(pred: Tensor, gt: Tensor) -> Tensor: 16 | return 
torch.sqrt(torch.mean(torch.square(pred - gt))) 17 | 18 | 19 | def sq_rel(pred: Tensor, gt: Tensor, eps: float = 0.0) -> Tensor: 20 | return torch.mean(torch.square(pred - gt) / (gt + eps)) 21 | 22 | 23 | def abs_rel(pred: Tensor, gt: Tensor, eps: float = 0.0) -> Tensor: 24 | return torch.mean(torch.abs(pred - gt) / (gt + eps)) 25 | 26 | 27 | def rel_thresh(pred: Tensor, gt: Tensor, sigma: float) -> Tensor: 28 | rel = torch.maximum(gt / pred, pred / gt) < sigma 29 | rel = torch.mean(rel.float()) 30 | return rel 31 | 32 | 33 | def compute_metrics_depth( 34 | pred: Tensor, gt: Tensor, label: str = "" 35 | ) -> dict[str, Tensor]: 36 | label = label if not label else label + "/" 37 | return { 38 | f"{label}mae": mae(pred, gt), 39 | f"{label}rmse": rmse(pred, gt), 40 | f"{label}sq_rel": sq_rel(pred, gt), 41 | f"{label}abs_rel": abs_rel(pred, gt), 42 | f"{label}rel_thresh_1.05": rel_thresh(pred, gt, 1.05), 43 | f"{label}rel_thresh_1.25": rel_thresh(pred, gt, 1.25), 44 | } 45 | -------------------------------------------------------------------------------- /evaluate/utils/pose_estimation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Pose estimation by means of RGBD frame + RGB py means of 3 | LightGlue and PnP + LO-RANSAC 4 | """ 5 | 6 | import torch 7 | from torch import Tensor 8 | from lightglue import DISK, LightGlue 9 | import poselib 10 | import numpy as np 11 | from kornia.morphology import dilation 12 | 13 | __all__ = ["PoseEstimation"] 14 | 15 | 16 | class PoseEstimation(torch.nn.Module): 17 | def __init__( 18 | self, 19 | max_num_keypoints: int | None = None, 20 | dilate_depth: int | None = None, 21 | compile: bool = False, 22 | **kwargs, 23 | ): 24 | super().__init__() 25 | self.extractor = DISK(max_num_keypoints=max_num_keypoints).eval() 26 | self.matcher = LightGlue(features="disk").eval() 27 | if compile: 28 | self.matcher.compile() 29 | self.dilate_depth = dilate_depth 30 | 31 | # kwargs default 32 | if "ransac_max_reproj_error" not in kwargs: 33 | kwargs["ransac_max_reproj_error"] = 1.0 34 | 35 | self.kwargs = kwargs 36 | 37 | @torch.autocast("cuda", enabled=False) 38 | def forward( 39 | self, 40 | image0: Tensor, 41 | image1: Tensor, 42 | depth1: Tensor, 43 | intrinsics0: Tensor, 44 | intrinsics1: Tensor, 45 | ): 46 | # depth dilation for easier matching 47 | if self.dilate_depth is not None: 48 | depth1 = dilation( 49 | depth1, 50 | torch.ones( 51 | [self.dilate_depth, self.dilate_depth], 52 | device=depth1.device, 53 | dtype=depth1.dtype, 54 | ), 55 | ) 56 | 57 | with torch.no_grad(): 58 | batch_size = image0.shape[0] 59 | poses = [] 60 | for b in range(batch_size): 61 | feats0 = self.extractor.extract(image0[b]) 62 | feats1 = self.extractor.extract(image1[b]) 63 | matches01 = self.matcher({"image0": feats0, "image1": feats1}) 64 | matches01 = matches01["matches"] 65 | matches0 = _to_numpy(feats0["keypoints"][0][matches01[0][:, 0]]) 66 | matches1 = _to_numpy(feats1["keypoints"][0][matches01[0][:, 1]]) 67 | 68 | depths1 = _to_numpy(depth1)[ 69 | b, 70 | 0, 71 | matches1[:, 1].round().astype(int), 72 | matches1[:, 0].round().astype(int), 73 | ] 74 | valid_mask = depths1 > 0 75 | pose = np.eye(4) 76 | pose[:3] = poselib.estimate_absolute_pose( 77 | matches0[valid_mask], 78 | _depth_to_3d( 79 | depths1[valid_mask], 80 | matches1[valid_mask], 81 | _to_numpy(intrinsics1[b]), 82 | ), 83 | { 84 | "model": "SIMPLE_PINHOLE", 85 | "width": image0.shape[-1], 86 | "height": image0.shape[-2], 87 | "params": [ 88 | intrinsics0[b, 0, 
0].item(), # fx 89 | intrinsics0[b, 0, 2].item(), # cx 90 | intrinsics0[b, 1, 2].item(), # cy 91 | ], 92 | }, 93 | { 94 | "_".join(k.split("_")[1:]): v 95 | for k, v in self.kwargs.items() 96 | if k.startswith("ransac_") 97 | }, 98 | { 99 | "_".join(k.split("_")[1:]): v 100 | for k, v in self.kwargs.items() 101 | if k.startswith("ransac_") 102 | }, 103 | )[0].Rt 104 | poses.append(pose) 105 | return ( 106 | torch.from_numpy(np.stack(poses, 0)).to(image0.device).to(image0.dtype) 107 | ) 108 | 109 | 110 | # utils 111 | 112 | 113 | def _to_numpy(x: Tensor): 114 | return x.detach().cpu().numpy() 115 | 116 | 117 | def _depth_to_3d(depths: np.ndarray, coords: np.ndarray, intrinsics: np.ndarray): 118 | fx, fy = intrinsics[0, 0], intrinsics[1, 1] 119 | cx, cy = intrinsics[0, 2], intrinsics[1, 2] 120 | x3d = depths * (coords[:, 0] - cx) / fx 121 | y3d = depths * (coords[:, 1] - cy) / fy 122 | return np.stack([x3d, y3d, depths], -1) 123 | 124 | 125 | def _norm(img: Tensor): 126 | return (img - img.min()) / (img.max() - img.min()) 127 | -------------------------------------------------------------------------------- /lib/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreaconti/depth-on-demand/7376cd0f7a67e583b3d0eff86641cfd7af83de67/lib/__init__.py -------------------------------------------------------------------------------- /lib/benchmark/mesh/__init__.py: -------------------------------------------------------------------------------- 1 | try: 2 | from ._fusion import TSDFFusion 3 | from ._evaluation import mesh_metrics 4 | 5 | __all__ = ["TSDFFusion", "mesh_metrics"] 6 | except ImportError: 7 | import warnings 8 | 9 | warnings.warn("open3d not found, can't import this module") 10 | -------------------------------------------------------------------------------- /lib/benchmark/mesh/_evaluation.py: -------------------------------------------------------------------------------- 1 | """ 2 | Evaluation procedure to compute metrics 3 | on meshes 4 | """ 5 | 6 | import open3d as o3d 7 | from pytorch3d.ops import knn_points 8 | import numpy as np 9 | from pathlib import Path 10 | import torch 11 | import warnings 12 | 13 | __all__ = ["mesh_metrics"] 14 | 15 | 16 | def mesh_metrics( 17 | mesh_prediction: o3d.geometry.TriangleMesh | str | Path, 18 | mesh_groundtruth: o3d.geometry.TriangleMesh | str | Path, 19 | world2grid: np.ndarray | str | Path | None = None, 20 | occlusion_mask: np.ndarray | str | Path | None = None, 21 | dist_threshold: float = 0.05, 22 | max_dist: float = 1.0, 23 | num_points_samples: int = 200000, 24 | device: str | torch.device = "cpu", 25 | ) -> dict[str, torch.Tensor]: 26 | """ 27 | Takes in input the predicted and groundtruth meshes and computes a set 28 | of common metrics between them: accuracy and completion, precision and 29 | recall and their f1_score. To compute fair metrics the predicted mesh 30 | can be masked to remove areas not available in the groundtruth providing 31 | the occlusion mask and the world2grid parameters (both). 
32 | """ 33 | # read data if not already done 34 | assert ( 35 | occlusion_mask is None 36 | and world2grid is None 37 | or occlusion_mask is not None 38 | and world2grid is not None 39 | ), "occlusion_mask and world2grid must be both provided" 40 | if isinstance(mesh_groundtruth, (str, Path)): 41 | mesh_groundtruth = o3d.io.read_triangle_mesh(str(mesh_groundtruth)) 42 | if isinstance(mesh_prediction, (str, Path)): 43 | mesh_prediction = o3d.io.read_triangle_mesh(str(mesh_prediction)) 44 | if world2grid is not None and isinstance(world2grid, (str, Path)): 45 | world2grid = np.loadtxt(world2grid, dtype=np.float32) 46 | if occlusion_mask is not None and isinstance(occlusion_mask, (str, Path)): 47 | occlusion_mask = np.load(occlusion_mask).astype(np.float32) 48 | 49 | # compute gt -> pred distance 50 | points_pred = torch.from_numpy( 51 | np.asarray( 52 | mesh_prediction.sample_points_uniformly(num_points_samples).points, 53 | dtype=np.float32, 54 | ) 55 | ) 56 | points_pred = points_pred.to(device) 57 | points_gt = torch.from_numpy( 58 | np.asarray(mesh_groundtruth.vertices, dtype=np.float32) 59 | ).to(device) 60 | world2grid = torch.from_numpy(world2grid).to(device) 61 | occlusion_mask = torch.from_numpy(occlusion_mask).to(device).float() 62 | gt2pred_dist = chamfer_distance(points_gt[None], points_pred[None], max_dist) 63 | 64 | # compute pred -> gt distance 65 | points_pred_filtered = filter_occluded_points( 66 | points_pred, world2grid, occlusion_mask, device 67 | ) 68 | pred2gt_dist = chamfer_distance( 69 | points_pred_filtered[None], points_gt[None], max_dist 70 | ) 71 | 72 | # compute metrics 73 | accuracy = torch.mean(pred2gt_dist) 74 | completion = torch.mean(gt2pred_dist) 75 | precision = (pred2gt_dist <= dist_threshold).float().mean() 76 | recall = (gt2pred_dist <= dist_threshold).float().mean() 77 | f1_score = 2 * precision * recall / (precision + recall) 78 | chamfer = 0.5 * (accuracy + completion) 79 | 80 | return { 81 | "precision": precision, 82 | "recall": recall, 83 | "f1_score": f1_score, 84 | "accuracy": accuracy, 85 | "completion": completion, 86 | "chamfer": chamfer, 87 | } 88 | 89 | 90 | def chamfer_distance(points1, points2, max_dist): 91 | l2dists = knn_points(points1, points2).dists 92 | sqdists = torch.where(l2dists > 0, torch.sqrt(l2dists), 0.0) 93 | sqdists = torch.minimum(sqdists, torch.full_like(sqdists, max_dist)) 94 | return sqdists 95 | 96 | 97 | def filter_occluded_points(points_pred, world2grid, occlusion_mask, device): 98 | dim_x = occlusion_mask.shape[0] 99 | dim_y = occlusion_mask.shape[1] 100 | dim_z = occlusion_mask.shape[2] 101 | num_points_pred = points_pred.shape[0] 102 | 103 | # Transform points to bbox space. 104 | R_world2grid = world2grid[:3, :3].view(1, 3, 3).expand(num_points_pred, -1, -1) 105 | t_world2grid = world2grid[:3, 3].view(1, 3, 1).expand(num_points_pred, -1, -1) 106 | 107 | points_pred_coords = ( 108 | torch.matmul(R_world2grid, points_pred.view(num_points_pred, 3, 1)) 109 | + t_world2grid 110 | ).view(num_points_pred, 3) 111 | 112 | # Normalize to [-1, 1]^3 space. 113 | # The world2grid transforms world positions to voxel centers, so we need to 114 | # use "align_corners=True". 115 | points_pred_coords[:, 0] /= dim_x - 1 116 | points_pred_coords[:, 1] /= dim_y - 1 117 | points_pred_coords[:, 2] /= dim_z - 1 118 | points_pred_coords = points_pred_coords * 2 - 1 119 | 120 | # Trilinearly interpolate occlusion mask. 121 | # Occlusion mask is given as (x, y, z) storage, but the grid_sample method 122 | # expects (c, z, y, x) storage. 
123 | visibility_mask = 1 - occlusion_mask.view(dim_x, dim_y, dim_z) 124 | visibility_mask = visibility_mask.permute(2, 1, 0).contiguous() 125 | visibility_mask = visibility_mask.view(1, 1, dim_z, dim_y, dim_x) 126 | 127 | points_pred_coords = points_pred_coords.view(1, 1, 1, num_points_pred, 3) 128 | 129 | points_pred_visibility = torch.nn.functional.grid_sample( 130 | visibility_mask, 131 | points_pred_coords, 132 | mode="bilinear", 133 | padding_mode="zeros", 134 | align_corners=True, 135 | ).to(device) 136 | 137 | points_pred_visibility = points_pred_visibility.view(num_points_pred) 138 | 139 | eps = 1e-5 140 | points_pred_visibility = points_pred_visibility >= 1 - eps 141 | 142 | # Filter occluded predicted points. 143 | if points_pred_visibility.sum() == 0: 144 | # If no points are visible, we keep the original points, otherwise 145 | # we would penalize the sample as if nothing is predicted. 146 | warnings.warn( 147 | "All points occluded, keeping all predicted points!", RuntimeWarning 148 | ) 149 | points_pred_visible = points_pred.clone() 150 | else: 151 | points_pred_visible = points_pred[points_pred_visibility] 152 | 153 | return points_pred_visible 154 | -------------------------------------------------------------------------------- /lib/benchmark/mesh/_fusion.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities to build meshes by RGBD frames 3 | """ 4 | 5 | import open3d as o3d 6 | import numpy as np 7 | import imageio.v3 as imageio 8 | from typing import Literal 9 | from pathlib import Path 10 | 11 | __all__ = ["TSDFFusion"] 12 | 13 | 14 | class TSDFFusion: 15 | """ 16 | Fusion utility to build a mesh from multiple RGBD calibrated and with pose 17 | frames through TSDF. 18 | """ 19 | 20 | def __init__( 21 | self, 22 | engine: Literal["open3d", "open3d-tensor"] = "open3d", 23 | # volume info 24 | voxel_length=0.04, 25 | sdf_trunc=3 * 0.04, 26 | # defaults 27 | color_type: o3d.pipelines.integration.TSDFVolumeColorType = o3d.pipelines.integration.TSDFVolumeColorType.RGB8, 28 | depth_scale: float = 1.0, 29 | depth_trunc: float = 3.0, 30 | convert_rgb_to_intensity: bool = False, 31 | device: str = "cuda:0", 32 | ): 33 | assert engine in ["open3d", "open3d-tensor"] 34 | self.engine = engine 35 | self.device = o3d.core.Device(device.upper()) 36 | self.depth_scale = depth_scale 37 | self.depth_trunc = depth_trunc 38 | self.voxel_length = voxel_length 39 | self.convert_rgb_to_intensity = convert_rgb_to_intensity 40 | 41 | if engine == "open3d": 42 | self.volume = o3d.pipelines.integration.ScalableTSDFVolume( 43 | voxel_length=voxel_length, 44 | sdf_trunc=sdf_trunc, 45 | color_type=color_type, 46 | ) 47 | elif engine == "open3d-tensor": 48 | self.volume = o3d.t.geometry.VoxelBlockGrid( 49 | attr_names=("tsdf", "weight", "color"), 50 | attr_dtypes=(o3d.core.float32, o3d.core.float32, o3d.core.float32), 51 | attr_channels=((1,), (1,), (3,)), 52 | voxel_size=voxel_length, 53 | device=self.device, 54 | ) 55 | 56 | def convert_intrinsics( 57 | self, 58 | intrinsics: np.ndarray, 59 | height: int | None = None, 60 | width: int | None = None, 61 | ) -> o3d.camera.PinholeCameraIntrinsic | o3d.core.Tensor: 62 | intrins = o3d.camera.PinholeCameraIntrinsic( 63 | width if width else int(intrinsics[0, 2] * 2), 64 | height if height else int(intrinsics[1, 2] * 2), 65 | intrinsics, 66 | ) 67 | if self.engine == "open3d-tensor": 68 | intrins = o3d.core.Tensor(intrins.intrinsic_matrix, o3d.core.Dtype.Float64) 69 | return intrins 70 | 71 | def 
read_rgbd( 72 | self, 73 | image: str | Path, 74 | depth: str | Path, 75 | depth_trunc: float | None = None, 76 | convert_rgb_to_intensity: bool | None = None, 77 | ) -> o3d.geometry.RGBDImage | tuple[o3d.core.Tensor, o3d.core.Tensor]: 78 | if self.engine == "open3d": 79 | return o3d.geometry.RGBDImage.create_from_color_and_depth( 80 | imageio.imread(image), 81 | imageio.imread(depth), 82 | depth_trunc=self.depth_trunc if depth_trunc is None else depth_trunc, 83 | convert_rgb_to_intensity=self.convert_rgb_to_intensity 84 | if convert_rgb_to_intensity is None 85 | else convert_rgb_to_intensity, 86 | ) 87 | else: 88 | color = o3d.t.io.read_image(image).to(self.device) 89 | depth = o3d.t.io.read_image(depth).to(self.device) 90 | return color, depth 91 | 92 | def convert_rgbd( 93 | self, 94 | image: np.ndarray, 95 | depth: np.ndarray, 96 | depth_scale: float | None = None, 97 | depth_trunc: float | None = None, 98 | convert_rgb_to_intensity: bool | None = None, 99 | ) -> o3d.geometry.RGBDImage | tuple[o3d.core.Tensor, o3d.core.Tensor]: 100 | if self.engine == "open3d": 101 | if image.dtype != np.uint8: 102 | image = (image - image.min()) / (image.max() - image.min()) 103 | image = np.round(image * 255).astype(np.uint8) 104 | image = o3d.geometry.Image(np.ascontiguousarray(image)) 105 | depth = o3d.geometry.Image(depth[..., 0]) 106 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 107 | image, 108 | depth, 109 | depth_scale=self.depth_scale if depth_scale is None else depth_scale, 110 | depth_trunc=self.depth_trunc if depth_trunc is None else depth_trunc, 111 | convert_rgb_to_intensity=self.convert_rgb_to_intensity 112 | if convert_rgb_to_intensity is None 113 | else convert_rgb_to_intensity, 114 | ) 115 | return rgbd 116 | else: 117 | image = o3d.t.geometry.Image( 118 | o3d.core.Tensor( 119 | np.ascontiguousarray(image).astype(np.float32), device=self.device 120 | ) 121 | ) 122 | depth = o3d.t.geometry.Image( 123 | o3d.core.Tensor(depth[..., 0], device=self.device) 124 | ) 125 | return image, depth 126 | 127 | def integrate_rgbd( 128 | self, 129 | image: np.ndarray | str | Path, 130 | depth: np.ndarray | str | Path, 131 | intrinsics: np.ndarray, 132 | extrinsics: np.ndarray, 133 | **kwargs, 134 | ): 135 | if isinstance(image, np.ndarray) and isinstance(depth, np.ndarray): 136 | rgbd = self.convert_rgbd(image, depth, **kwargs) 137 | elif isinstance(image, (str, Path)) and isinstance(depth, (str, Path)): 138 | rgbd = self.read_rgbd(image, depth, **kwargs) 139 | else: 140 | raise ValueError( 141 | f"not supported image and depth types ({type(image), type(depth)})" 142 | ) 143 | 144 | if self.engine == "open3d": 145 | w, h = rgbd.color.get_max_bound().astype(int) 146 | self.volume.integrate( 147 | rgbd, 148 | self.convert_intrinsics(intrinsics, h, w), 149 | np.linalg.inv(extrinsics), 150 | ) 151 | else: 152 | color, depth = rgbd 153 | extrins = o3d.core.Tensor(np.linalg.inv(extrinsics), o3d.core.Dtype.Float64) 154 | intrins = o3d.core.Tensor(intrinsics, o3d.core.Dtype.Float64) 155 | coords = self.volume.compute_unique_block_coordinates( 156 | depth, 157 | intrins, 158 | extrins, 159 | depth_scale=self.depth_scale, 160 | depth_max=self.depth_trunc, 161 | ) 162 | self.volume.integrate( 163 | coords, 164 | depth, 165 | color, 166 | intrins, 167 | extrins, 168 | depth_scale=float(self.depth_scale), 169 | depth_max=float(self.depth_trunc), 170 | ) 171 | 172 | def write_triangle_mesh(self, path: str | Path): 173 | return o3d.io.write_triangle_mesh(str(path), self.triangle_mesh()) 174 | 175 | 
def triangle_mesh(self): 176 | if self.engine == "open3d": 177 | return self.volume.extract_triangle_mesh() 178 | else: 179 | return self.volume.extract_triangle_mesh().to_legacy() 180 | 181 | def reset(self): 182 | if self.engine == "open3d": 183 | self.volume.reset() 184 | else: 185 | self.volume = o3d.t.geometry.VoxelBlockGrid( 186 | attr_names=("tsdf", "weight", "color"), 187 | attr_dtypes=(o3d.core.float32, o3d.core.float32, o3d.core.float32), 188 | attr_channels=((1,), (1,), (3,)), 189 | voxel_size=self.voxel_length, 190 | device=self.device, 191 | ) 192 | -------------------------------------------------------------------------------- /lib/dataset/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Common interface to load the datasets, simply import the ``load_datamodule`` function, 3 | provide the name of the dataset to load and optionally the specific arguments for that 4 | dataset 5 | """ 6 | 7 | import torch 8 | from .scannetv2 import ScanNetV2DataModule 9 | from .sevenscenes import SevenScenesDataModule 10 | from .tartanair import TartanairDataModule 11 | from .kitti import KittiDataModule 12 | from ._utils import list_scans 13 | from lightning import LightningDataModule 14 | import torchvision.transforms.functional as TF 15 | 16 | __all__ = ["load_datamodule", "list_scans"] 17 | 18 | 19 | def load_datamodule(name: str, /, **kwargs) -> LightningDataModule: 20 | match name: 21 | case "scannetv2": 22 | DataModule = ScanNetV2DataModule 23 | case "sevenscenes": 24 | DataModule = SevenScenesDataModule 25 | case "tartanair": 26 | DataModule = TartanairDataModule 27 | case "kitti": 28 | DataModule = KittiDataModule 29 | case other: 30 | raise ValueError(f"dataset {other} not available") 31 | 32 | return DataModule(**kwargs, eval_transform=_Preprocess(name)) 33 | 34 | 35 | class _Preprocess: 36 | def __init__(self, dataset: str): 37 | self.dataset = dataset 38 | 39 | def __call__(self, sample: dict) -> dict: 40 | for k in sample: 41 | if "path" not in k: 42 | if "intrinsics" not in k and "extrinsics" not in k and "pcd" not in k: 43 | sample[k] = TF.to_tensor(sample[k]) 44 | else: 45 | sample[k] = torch.from_numpy(sample[k]) 46 | if self.dataset == "scannetv2": 47 | if "image" in k or "depth" in k: 48 | sample[k] = TF.center_crop(sample[k], [464, 624]) 49 | if "intrinsics" in k: 50 | intrins = sample[k].clone() 51 | intrins[0, -1] -= 8 52 | intrins[1, -1] -= 8 53 | sample[k] = intrins 54 | if "image" in k: 55 | sample[k] = TF.normalize( 56 | sample[k], mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 57 | ) 58 | return sample 59 | -------------------------------------------------------------------------------- /lib/dataset/_resources/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Put in this folder the assets that you want to track with git 3 | and that you load in you datasets, like json files containing 4 | the splits, then you can load them by means of 5 | 6 | >>> from pathlib import Path 7 | >>> resource_path = Path(__file__).parent / "_resources/" 8 | """ 9 | -------------------------------------------------------------------------------- /lib/dataset/_resources/scannetv2/test_split.txt: -------------------------------------------------------------------------------- 1 | scene0707_00 2 | scene0708_00 3 | scene0709_00 4 | scene0710_00 5 | scene0711_00 6 | scene0712_00 7 | scene0713_00 8 | scene0714_00 9 | scene0715_00 10 | scene0716_00 11 | scene0717_00 12 | 
scene0718_00 13 | scene0719_00 14 | scene0720_00 15 | scene0721_00 16 | scene0722_00 17 | scene0723_00 18 | scene0724_00 19 | scene0725_00 20 | scene0726_00 21 | scene0727_00 22 | scene0728_00 23 | scene0729_00 24 | scene0730_00 25 | scene0731_00 26 | scene0732_00 27 | scene0733_00 28 | scene0734_00 29 | scene0735_00 30 | scene0736_00 31 | scene0737_00 32 | scene0738_00 33 | scene0739_00 34 | scene0740_00 35 | scene0741_00 36 | scene0742_00 37 | scene0743_00 38 | scene0744_00 39 | scene0745_00 40 | scene0746_00 41 | scene0747_00 42 | scene0748_00 43 | scene0749_00 44 | scene0750_00 45 | scene0751_00 46 | scene0752_00 47 | scene0753_00 48 | scene0754_00 49 | scene0755_00 50 | scene0756_00 51 | scene0757_00 52 | scene0758_00 53 | scene0759_00 54 | scene0760_00 55 | scene0761_00 56 | scene0762_00 57 | scene0763_00 58 | scene0764_00 59 | scene0765_00 60 | scene0766_00 61 | scene0767_00 62 | scene0768_00 63 | scene0769_00 64 | scene0770_00 65 | scene0771_00 66 | scene0772_00 67 | scene0773_00 68 | scene0774_00 69 | scene0775_00 70 | scene0776_00 71 | scene0777_00 72 | scene0778_00 73 | scene0779_00 74 | scene0780_00 75 | scene0781_00 76 | scene0782_00 77 | scene0783_00 78 | scene0784_00 79 | scene0785_00 80 | scene0786_00 81 | scene0787_00 82 | scene0788_00 83 | scene0789_00 84 | scene0790_00 85 | scene0791_00 86 | scene0792_00 87 | scene0793_00 88 | scene0794_00 89 | scene0795_00 90 | scene0796_00 91 | scene0797_00 92 | scene0798_00 93 | scene0799_00 94 | scene0800_00 95 | scene0801_00 96 | scene0802_00 97 | scene0803_00 98 | scene0804_00 99 | scene0805_00 100 | scene0806_00 -------------------------------------------------------------------------------- /lib/dataset/_resources/sevenscenes/test_split.txt: -------------------------------------------------------------------------------- 1 | chess-seq-01 2 | chess-seq-02 3 | fire-seq-01 4 | fire-seq-02 5 | heads-seq-02 6 | office-seq-01 7 | office-seq-03 8 | pumpkin-seq-03 9 | pumpkin-seq-06 10 | redkitchen-seq-01 11 | redkitchen-seq-07 12 | stairs-seq-02 13 | stairs-seq-06 -------------------------------------------------------------------------------- /lib/dataset/_resources/tartanair/test_split.txt: -------------------------------------------------------------------------------- 1 | abandonedfactory-Easy-P011 2 | abandonedfactory-Hard-P011 3 | abandonedfactory_night-Easy-P013 4 | abandonedfactory_night-Hard-P014 5 | amusement-Easy-P008 6 | amusement-Hard-P007 7 | carwelding-Easy-P007 8 | endofworld-Easy-P009 9 | gascola-Easy-P008 10 | gascola-Hard-P009 11 | hospital-Easy-P036 12 | hospital-Hard-P049 13 | japanesealley-Easy-P007 14 | japanesealley-Hard-P005 15 | neighborhood-Easy-P021 16 | neighborhood-Hard-P017 17 | ocean-Easy-P013 18 | ocean-Hard-P009 19 | office-Hard-P007 20 | office2-Easy-P011 21 | office2-Hard-P010 22 | oldtown-Easy-P007 23 | oldtown-Hard-P008 24 | seasidetown-Easy-P009 25 | seasonsforest-Easy-P011 26 | seasonsforest-Hard-P006 27 | seasonsforest_winter-Easy-P009 28 | seasonsforest_winter-Hard-P018 29 | soulcity-Easy-P012 30 | soulcity-Hard-P009 31 | westerndesert-Easy-P013 32 | westerndesert-Hard-P007 33 | -------------------------------------------------------------------------------- /lib/dataset/_utils.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | from typing import Literal, Callable, Type 3 | import torchdata.datapipes as dp 4 | from torch import Tensor 5 | from torch.utils.data import DataLoader 6 | from lightning import LightningDataModule 7 | 
from collections import defaultdict 8 | import torchdata.datapipes as dp 9 | import numpy as np 10 | 11 | __all__ = ["GenericDataModule", "LoadDataset"] 12 | 13 | 14 | class GenericDataModule(LightningDataModule): 15 | def __init__( 16 | self, 17 | dataset: str, 18 | # dataset specific 19 | keyframes: str = "standard", 20 | load_prevs: int = 7, 21 | cycle: bool = True, 22 | filter_scans: list[str] | None = None, 23 | val_on_test: bool = False, 24 | split_test_scans_loaders: bool = False, 25 | # dataloader specific 26 | batch_size: int = 1, 27 | shuffle: bool = True, 28 | shuffle_keyframes: bool = False, 29 | order_sources_by_pose: bool = False, 30 | num_workers: int = 8, 31 | pin_memory: bool = True, 32 | load_dataset_cls: Type["LoadDataset"] | None = None, 33 | eval_transform: Callable[[dict], dict] | None = None, 34 | **sample_params, 35 | ): 36 | super().__init__() 37 | 38 | if load_prevs < 0: 39 | raise ValueError(f"0 <= load_prevs, not {load_prevs}") 40 | 41 | self.dataset = dataset 42 | self.keyframes = keyframes 43 | self.batch_size = batch_size 44 | self.shuffle = shuffle 45 | self.shuffle_keyframes = shuffle_keyframes 46 | self.num_workers = num_workers 47 | self.eval_transform = eval_transform 48 | self.cycle = cycle 49 | self.filter_scans = filter_scans 50 | self.load_prevs = load_prevs 51 | self.val_on_test = val_on_test 52 | self.split_test_scans_dataloaders = split_test_scans_loaders 53 | self.order_sources_by_pose = order_sources_by_pose 54 | self.pin_memory = pin_memory 55 | self._load_dataset_cls = load_dataset_cls 56 | self._sample_params = sample_params 57 | self._other_params = {} 58 | 59 | def _filter_traj(self, split) -> str | list[str]: 60 | if self.filter_scans is None: 61 | return split 62 | else: 63 | split_scans = list_scans(self.dataset, self.keyframes, split) 64 | return [traj for traj in self.filter_scans if traj in split_scans] 65 | 66 | def _dataloader(self, split, scan=None, num_workers=None): 67 | dl_builder = self._load_dataset_cls( 68 | self.dataset, 69 | keyframes=self.keyframes, 70 | load_prevs=self.load_prevs, 71 | cycle=False, 72 | batch_size=1, 73 | shuffle=False, 74 | shuffle_keyframes=False, 75 | order_sources_by_pose=self.order_sources_by_pose, 76 | transform=self.eval_transform, 77 | num_workers=self.num_workers if num_workers is None else num_workers, 78 | pin_memory=self.pin_memory, 79 | **self._sample_params, 80 | ) 81 | return dl_builder.build_dataloader( 82 | self._filter_traj(split) if scan is None else scan, 83 | ) 84 | 85 | def setup(self, stage: str | None = None): 86 | if stage not in ["test", None]: 87 | raise ValueError(f"stage {stage} invalid") 88 | 89 | if stage in ["test", None]: 90 | if not self.split_test_scans_dataloaders: 91 | self._test_dl = self._dataloader("test") 92 | else: 93 | keyframes = build_scan_frames_mapping( 94 | self.dataset, 95 | self.keyframes, 96 | "test", 97 | self.load_prevs, 98 | ) 99 | if self.filter_scans: 100 | keyframes = { 101 | k: v for k, v in keyframes.items() if k in self.filter_scans 102 | } 103 | self._test_dl = [ 104 | self._dataloader("test", {key: value}, num_workers=1) 105 | for key, value in keyframes.items() 106 | ] 107 | 108 | def test_dataloader(self): 109 | return self._test_dl 110 | 111 | 112 | class LoadDataset: 113 | """ 114 | This function embodies the whole creation of a datalaoder, the unique method 115 | to be overloaded is `load_sample` which loads a single sample of a sequence 116 | """ 117 | 118 | def __init__( 119 | self, 120 | dataset: str, 121 | # dataset specific 122 | 
keyframes: Literal["standard", "dense", "offline"] = "standard", 123 | load_prevs: int = 7, 124 | cycle: bool = True, 125 | # dataloader specific 126 | batch_size: int = 1, 127 | shuffle: bool = False, 128 | shuffle_keyframes: bool = False, 129 | order_sources_by_pose: bool = True, 130 | transform: Callable[[dict[str, Tensor]], dict[str, Tensor]] | None = None, 131 | num_workers: int = 8, 132 | pin_memory: bool = True, 133 | **kwargs, 134 | ): 135 | self.dataset = dataset 136 | self.keyframes = keyframes 137 | self.load_prevs = load_prevs 138 | self.cycle = cycle 139 | self.batch_size = batch_size 140 | self.shuffle = shuffle 141 | self.shuffle_keyframes = shuffle_keyframes 142 | self.order_sources_by_pose = order_sources_by_pose 143 | self.transform = transform 144 | self.num_workers = num_workers 145 | self.pin_memory = pin_memory 146 | self.sample_kwargs = kwargs 147 | 148 | def build_dataloader( 149 | self, 150 | split: Literal["train", "val", "test"] | str | list[str] | dict = "train", 151 | ): 152 | # checks 153 | if self.load_prevs < 0: 154 | raise ValueError(f"0 <= load_prevs, not {self.load_prevs}") 155 | 156 | # load the keyframe file 157 | if split in ["train", "val", "test"]: 158 | keyframes_path = ( 159 | Path(__file__).parent 160 | / f"_resources/{self.dataset}/keyframes/{self.keyframes}/{split}_split.txt" 161 | ) 162 | if not keyframes_path.exists(): 163 | raise ValueError( 164 | f"split {split} for keyframes {self.keyframes} not available" 165 | ) 166 | with open(keyframes_path, "rt") as f: 167 | keyframes_dict = defaultdict(_empty_dict) 168 | for ln in f: 169 | scene, keyframe, *src_frames = ln.split() 170 | keyframes_dict[scene][keyframe] = src_frames[: self.load_prevs] 171 | elif isinstance(split, (str, list)): 172 | split = [split] if isinstance(split, str) else split 173 | all_lines = [] 174 | for file_path in ( 175 | Path(__file__).parent 176 | / f"_resources/{self.dataset}/keyframes/{self.keyframes}" 177 | ).glob("*_split.txt"): 178 | with open(file_path, "rt") as f: 179 | all_lines.extend(f.readlines()) 180 | keyframes_dict = defaultdict(_empty_dict) 181 | for scene in split: 182 | for ln in all_lines: 183 | if ln.startswith(scene): 184 | scene, keyframe, *src_frames = ln.split() 185 | keyframes_dict[scene][keyframe] = src_frames[: self.load_prevs] 186 | elif isinstance(split, dict): 187 | keyframes_dict = split 188 | else: 189 | raise ValueError(f"split type not allowed") 190 | 191 | # loading and processing pipeline 192 | keyframes_list = [] 193 | for scene in keyframes_dict: 194 | keyframes_list.extend( 195 | [ 196 | {"scan": scene, "keyframe": keyframe, "sources": sources} 197 | for keyframe, sources in keyframes_dict[scene].items() 198 | ] 199 | ) 200 | 201 | ds = dp.map.SequenceWrapper(keyframes_list) 202 | ds = ds.shuffle() if self.shuffle else ds 203 | if self.cycle: 204 | if self.shuffle: 205 | ds = ds.cycle() 206 | else: 207 | ds = dp.iter.IterableWrapper(ds).cycle() 208 | ds = ds.sharding_filter() if self.shuffle or self.cycle else ds 209 | ds = ds.map(self._load) 210 | if self.order_sources_by_pose: 211 | ds = ds.map(order_by_pose) 212 | ds = ds.map(self.transform) if self.transform else ds 213 | return DataLoader( 214 | ds, 215 | batch_size=self.batch_size, 216 | shuffle=self.shuffle, 217 | num_workers=self.num_workers, 218 | pin_memory=self.pin_memory, 219 | ) 220 | 221 | def load_sample(self, scan: str, idx: str, suffix: str = "") -> dict: 222 | raise NotImplementedError("please implement load sample") 223 | 224 | def _load(self, sample): 225 | out = 
{} 226 | 227 | # shuffle keyframes 228 | if self.shuffle_keyframes: 229 | all_frames = [sample["keyframe"]] + sample["sources"] 230 | keyframe, sources = all_frames[0], all_frames[1:] 231 | sample = {"scan": sample["scan"], "keyframe": keyframe, "sources": sources} 232 | 233 | # load data 234 | out = self.load_sample(sample["scan"], sample["keyframe"], **self.sample_kwargs) 235 | for idx, src in enumerate(sample["sources"]): 236 | out |= self.load_sample( 237 | sample["scan"], src, suffix=f"_prev_{idx}", **self.sample_kwargs 238 | ) 239 | return out 240 | 241 | 242 | #### 243 | 244 | 245 | def list_scans(dataset: str, keyframes: str, split: str | None = None): 246 | """ 247 | List all the scans used in a specific split of a dataset for a specific 248 | keyframe setting. 249 | """ 250 | if split is not None: 251 | keyframes_path = ( 252 | Path(__file__).parent 253 | / f"_resources/{dataset}/keyframes/{keyframes}/{split}_split.txt" 254 | ) 255 | scans = [] 256 | if keyframes_path.exists(): 257 | scans = [ln.split()[0] for ln in open(keyframes_path, "rt")] 258 | else: 259 | scans = [] 260 | for keyframes_path in ( 261 | Path(__file__).parent / f"_resources/{dataset}/keyframes/{keyframes}" 262 | ).glob("*_split.txt"): 263 | scans.extend([ln.split()[0] for ln in open(keyframes_path, "rt")]) 264 | return list(np.unique(scans)) 265 | 266 | 267 | def _empty_dict(): 268 | return {} 269 | 270 | 271 | def build_scan_frames_mapping( 272 | dataset: str, 273 | keyframes: str, 274 | split: str, 275 | load_prevs: int | None = None, 276 | ) -> dict[Path, dict[str, list[str]]]: 277 | """ 278 | Given a dataset, its root, specified keyframes and split it build a dictionary 279 | mapping each scan in such split to a dictionary indexed by the target view and 280 | containing a list of valid source views. 
281 | """ 282 | keyframes_path = ( 283 | Path(__file__).parent 284 | / f"_resources/{dataset}/keyframes/{keyframes}/{split}_split.txt" 285 | ) 286 | if not keyframes_path.exists(): 287 | raise ValueError(f"split {split} for keyframes {keyframes} not available") 288 | with open(keyframes_path, "rt") as f: 289 | keyframes_dict = defaultdict(_empty_dict) 290 | for ln in f: 291 | scene, keyframe, *src_frames = ln.split() 292 | if load_prevs is not None: 293 | src_frames = src_frames[:load_prevs] 294 | keyframes_dict[scene][keyframe] = src_frames 295 | return keyframes_dict 296 | 297 | 298 | def order_by_pose(ex): 299 | """ 300 | Takes an example and orders source views according with their distance 301 | from the target view 302 | """ 303 | srcs = sorted({int(k.split("prev_")[-1]) for k in ex.keys() if "prev_" in k}) 304 | tgt_pose = ex["extrinsics"] 305 | distances = np.array( 306 | [ 307 | _pose_distance(_inv_pose(tgt_pose) @ ex[f"extrinsics_prev_{src}"])[0] 308 | for src in srcs 309 | ] 310 | ) 311 | order = np.argsort(distances) 312 | out = {k: v for k, v in ex.items() if "_prev_" not in k} 313 | for new_idx, prev_idx in enumerate(order): 314 | for k, v in ex.items(): 315 | if f"prev_{prev_idx}" in k: 316 | out[k.replace(f"prev_{prev_idx}", f"prev_{new_idx}")] = v 317 | return out 318 | 319 | 320 | def _inv_pose(pose: np.ndarray): 321 | inverted = np.eye(4) 322 | inverted[:3, :3] = pose[:3, :3].T 323 | inverted[:3, 3:] = -inverted[:3, :3] @ pose[:3, 3:] 324 | return inverted 325 | 326 | 327 | def _pose_distance(pose: np.ndarray): 328 | rot = pose[:3, :3] 329 | trasl = pose[:3, 3] 330 | R_trace = rot.diagonal().sum() 331 | r_measure = np.sqrt(2 * (1 - np.minimum(np.ones_like(R_trace) * 3.0, R_trace) / 3)) 332 | t_measure = np.linalg.norm(trasl) 333 | combined_measure = np.sqrt(t_measure**2 + r_measure**2) 334 | return combined_measure, r_measure, t_measure 335 | -------------------------------------------------------------------------------- /lib/dataset/kitti.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataloader for TartanAir 3 | """ 4 | 5 | from pathlib import Path 6 | import imageio.v3 as imageio 7 | import numpy as np 8 | from ._utils import GenericDataModule, LoadDataset 9 | import pykitti 10 | 11 | 12 | __all__ = ["KittiDataModule"] 13 | 14 | 15 | class KittiDataModule(GenericDataModule): 16 | def __init__( 17 | self, 18 | *args, 19 | root_raw: str | Path, 20 | root_completion: str | Path, 21 | min_depth: float = 1e-3, 22 | max_depth: float = 80.0, 23 | load_hints_pcd: bool = False, 24 | **kwargs, 25 | ): 26 | # prepare base class 27 | super().__init__( 28 | "kitti", 29 | *args, 30 | root_raw=root_raw, 31 | root_completion=root_completion, 32 | min_depth=min_depth, 33 | max_depth=max_depth, 34 | load_hints_pcd=load_hints_pcd, 35 | load_dataset_cls=KittiLoadSample, 36 | **kwargs, 37 | ) 38 | 39 | @staticmethod 40 | def augmenting_targets(): 41 | out = { 42 | "image": "image", 43 | "depth": "depth", 44 | "hints": "depth", 45 | "extrinsics": "rt", 46 | "intrinsics": "intrinsics", 47 | "hints_pcd": "pcd", 48 | } 49 | for i in range(20): 50 | out |= { 51 | f"image_prev_{i}": "image", 52 | f"depth_prev_{i}": "depth", 53 | f"hints_prev_{i}": "depth", 54 | f"extrinsics_prev_{i}": "rt", 55 | f"intrinsics_prev_{i}": "intrinsics", 56 | f"hints_pcd_prev_{i}": "pcd", 57 | } 58 | return out 59 | 60 | 61 | class KittiLoadSample(LoadDataset): 62 | def __init__(self, *args, **kwargs): 63 | super().__init__(*args, **kwargs) 64 | self._kitti_meta = 
{} 65 | 66 | def _get_meta(self, root_raw: str, scan: str): 67 | if scan not in self._kitti_meta: 68 | date = "_".join(scan.split("_")[:3]) 69 | scan_id = scan.split("_")[4] 70 | self._kitti_meta[scan] = pykitti.raw(root_raw, date, scan_id) 71 | return self._kitti_meta[scan] 72 | 73 | def load_sample( 74 | self, 75 | scan: str, 76 | idx: str, 77 | root_raw: str | Path, 78 | root_completion: str | Path, 79 | min_depth: float = 1e-3, 80 | max_depth: float = 100.0, 81 | load_hints_pcd: bool = False, 82 | suffix: str = "", 83 | ) -> dict: 84 | root_raw = Path(root_raw) 85 | root_compl = Path(root_completion) 86 | date = "_".join(scan.split("_")[:3]) 87 | compl_path = next(root_compl.glob(f"*/{scan}/proj_depth")) 88 | 89 | # img 90 | image = imageio.imread(root_raw / date / scan / f"image_02/data/{idx}.png") 91 | 92 | # hints (lidar projected) 93 | hints = ( 94 | imageio.imread(compl_path / f"velodyne_raw/image_02/{idx}.png")[ 95 | ..., None 96 | ].astype(np.float32) 97 | / 256.0 98 | ) 99 | hints[hints < min_depth] = 0.0 100 | hints[hints > max_depth] = 0.0 101 | 102 | # gt 103 | depth = ( 104 | imageio.imread(compl_path / f"groundtruth/image_02/{idx}.png")[ 105 | ..., None 106 | ].astype(np.float32) 107 | / 256.0 108 | ) 109 | depth[depth < min_depth] = 0.0 110 | depth[depth > max_depth] = 0.0 111 | 112 | # intrinsics 113 | scan_meta = self._get_meta(str(root_raw), scan) 114 | intrinsics = scan_meta.calib.K_cam2.astype(np.float32).copy() 115 | 116 | # extrinsics 117 | extrinsics = ( 118 | scan_meta.oxts[int(idx)].T_w_imu 119 | @ np.linalg.inv(scan_meta.calib.T_velo_imu) 120 | @ np.linalg.inv( 121 | scan_meta.calib.R_rect_20 @ scan_meta.calib.T_cam0_velo_unrect 122 | ) 123 | ).astype(np.float32) 124 | 125 | # lidar pcd 126 | pcd = None 127 | if load_hints_pcd: 128 | pcd = scan_meta.get_velo(int(idx))[:, :3] 129 | pcd = ( 130 | scan_meta.calib.T_cam2_velo 131 | @ np.pad(pcd, [(0, 0), (0, 1)], constant_values=1.0).T 132 | ).T[:, :3] 133 | n_points = pcd.shape[0] 134 | padding = 130000 - n_points 135 | pcd = np.pad(pcd, [(0, padding), (0, 0)]).astype(np.float32) 136 | 137 | # crop frames 138 | h, w, _ = depth.shape 139 | lc = (w - 1216) // 2 140 | rc = w - 1216 - lc 141 | image = image[-256:, lc:-rc] 142 | depth = depth[-256:, lc:-rc] 143 | hints = hints[-256:, lc:-rc] 144 | intrinsics[0, -1] -= lc 145 | intrinsics[1, -1] -= h - 256 146 | 147 | # compose output dict 148 | out = { 149 | f"image{suffix}": image, 150 | f"hints{suffix}": hints, 151 | f"depth{suffix}": depth, 152 | f"intrinsics{suffix}": intrinsics, 153 | f"extrinsics{suffix}": extrinsics, 154 | } 155 | if pcd is not None: 156 | out[f"hints_pcd{suffix}"] = pcd 157 | 158 | return out 159 | -------------------------------------------------------------------------------- /lib/dataset/scannetv2.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataloader for ScanNetV2 3 | """ 4 | 5 | from pathlib import Path 6 | import imageio.v3 as imageio 7 | import numpy as np 8 | 9 | __all__ = ["ScanNetV2DataModule"] 10 | 11 | from ._utils import GenericDataModule, LoadDataset 12 | 13 | 14 | class ScanNetV2DataModule(GenericDataModule): 15 | def __init__( 16 | self, 17 | *args, 18 | root: str | Path, 19 | min_depth: float = 1e-3, 20 | max_depth: float = 10.0, 21 | **kwargs, 22 | ): 23 | super().__init__( 24 | "scannetv2", 25 | *args, 26 | root=root, 27 | min_depth=min_depth, 28 | max_depth=max_depth, 29 | load_dataset_cls=ScanNetV2LoadSample, 30 | **kwargs, 31 | ) 32 | 33 | 34 | class 
ScanNetV2LoadSample(LoadDataset): 35 | def load_sample( 36 | self, 37 | scan: str, 38 | idx: str, 39 | root, 40 | min_depth: float = 1e-3, 41 | max_depth: float = 10.0, 42 | suffix: str = "", 43 | ) -> dict: 44 | root = Path(root) 45 | img = imageio.imread(root / scan / f"{idx}.image.jpg") 46 | depth = ( 47 | imageio.imread(root / scan / f"{idx}.depth.png")[..., None] / 1000 48 | ).astype(np.float32) 49 | depth[(depth < min_depth) | (depth > max_depth)] = 0.0 50 | intrinsics = np.load(root / scan / f"{idx}.intrinsics.npy") 51 | extrinsics = np.load(root / scan / f"{idx}.extrinsics.npy") 52 | 53 | out = {} 54 | if not suffix: 55 | ply_path = root / f"{scan}_vh_clean.ply" 56 | occl_path = root / f"{scan}_occlusion_mask.npy" 57 | world2grid_path = root / f"{scan}_world2grid.txt" 58 | if ply_path.exists(): 59 | out["gt_mesh_path"] = str(ply_path) 60 | if occl_path.exists(): 61 | out["gt_mesh_occl_path"] = str(occl_path) 62 | if world2grid_path.exists(): 63 | out["gt_mesh_world2grid_path"] = str(world2grid_path) 64 | 65 | return out | { 66 | f"image{suffix}": img, 67 | f"depth{suffix}": depth, 68 | f"intrinsics{suffix}": intrinsics, 69 | f"extrinsics{suffix}": extrinsics, 70 | } 71 | -------------------------------------------------------------------------------- /lib/dataset/sevenscenes.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataloader for 7Scenes 3 | """ 4 | 5 | from pathlib import Path 6 | import imageio.v3 as imageio 7 | import numpy as np 8 | 9 | __all__ = ["SevenScenesDataModule"] 10 | 11 | from ._utils import GenericDataModule, LoadDataset 12 | 13 | 14 | class SevenScenesDataModule(GenericDataModule): 15 | def __init__( 16 | self, 17 | *args, 18 | root: str | Path, 19 | min_depth: float = 1e-3, 20 | max_depth: float = 10.0, 21 | **kwargs, 22 | ): 23 | super().__init__( 24 | "sevenscenes", 25 | *args, 26 | root=root, 27 | min_depth=min_depth, 28 | max_depth=max_depth, 29 | load_dataset_cls=SceneScenesLoadSample, 30 | **kwargs, 31 | ) 32 | 33 | 34 | class SceneScenesLoadSample(LoadDataset): 35 | def load_sample( 36 | self, 37 | scan: str, 38 | idx: str, 39 | root, 40 | min_depth: float = 1e-3, 41 | max_depth: float = 10.0, 42 | suffix: str = "", 43 | ) -> dict: 44 | root = Path(root) 45 | img = imageio.imread(root / scan / f"{idx}.image.jpg") 46 | depth = ( 47 | imageio.imread(root / scan / f"{idx}.depth.png")[..., None] / 1000 48 | ).astype(np.float32) 49 | depth[(depth < min_depth) | (depth > max_depth)] = 0.0 50 | intrinsics = np.load(root / scan / f"{idx}.intrinsics.npy") 51 | extrinsics = np.load(root / scan / f"{idx}.extrinsics.npy") 52 | 53 | return { 54 | f"image{suffix}": img, 55 | f"depth{suffix}": depth, 56 | f"intrinsics{suffix}": intrinsics, 57 | f"extrinsics{suffix}": extrinsics, 58 | } 59 | -------------------------------------------------------------------------------- /lib/dataset/tartanair.py: -------------------------------------------------------------------------------- 1 | """ 2 | Dataloader for TartanAir 3 | """ 4 | 5 | from pathlib import Path 6 | import imageio.v3 as imageio 7 | import json 8 | import numpy as np 9 | from ._utils import GenericDataModule, LoadDataset 10 | 11 | 12 | __all__ = ["TartanairDataModule"] 13 | 14 | 15 | class TartanairDataModule(GenericDataModule): 16 | def __init__( 17 | self, 18 | *args, 19 | root: str | Path, 20 | min_depth: float = 1e-3, 21 | max_depth: float = 100.0, 22 | **kwargs, 23 | ): 24 | super().__init__( 25 | "tartanair", 26 | *args, 27 | root=root, 28 | 
min_depth=min_depth, 29 | max_depth=max_depth, 30 | load_dataset_cls=TartanairLoadSample, 31 | **kwargs, 32 | ) 33 | 34 | 35 | class TartanairLoadSample(LoadDataset): 36 | def load_sample( 37 | self, 38 | scan: str, 39 | idx: str, 40 | root: str | Path, 41 | min_depth: float = 1e-3, 42 | max_depth: float = 100.0, 43 | suffix: str = "", 44 | ) -> dict: 45 | root = Path(root) 46 | img = imageio.imread(root / scan / f"{idx}.image.jpg") 47 | 48 | depth = imageio.imread(root / scan / f"{idx}.depth.png")[..., None].astype( 49 | np.float32 50 | ) 51 | drange = json.load(open(root / scan / f"{idx}.depth_range.json", "rt")) 52 | vmin, vmax = drange["vmin"], drange["vmax"] 53 | depth = vmin + (depth / 65535) * (vmax - vmin) 54 | depth[depth < min_depth] = 0.0 55 | depth[depth > max_depth] = 0.0 56 | 57 | intrinsics = np.load(root / scan / f"{idx}.intrinsics.npy") 58 | extrinsics = np.load(root / scan / f"{idx}.position.npy") 59 | extrinsics = _ned_to_cam.dot(extrinsics).dot(_ned_to_cam.T) 60 | 61 | return { 62 | f"image{suffix}": img, 63 | f"depth{suffix}": depth, 64 | f"intrinsics{suffix}": intrinsics, 65 | f"extrinsics{suffix}": extrinsics, 66 | } 67 | 68 | 69 | _ned_to_cam = np.array( 70 | [ 71 | [0.0, 1.0, 0.0, 0.0], 72 | [0.0, 0.0, 1.0, 0.0], 73 | [1.0, 0.0, 0.0, 0.0], 74 | [0.0, 0.0, 0.0, 1.0], 75 | ], 76 | dtype=np.float32, 77 | ) 78 | -------------------------------------------------------------------------------- /lib/dataset/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import random 3 | import kornia 4 | from torch import Tensor 5 | 6 | __all__ = [ 7 | "sparsify_depth", 8 | "project_depth", 9 | "project_pcd", 10 | "inv_pose", 11 | ] 12 | 13 | 14 | def sparsify_depth( 15 | depth: torch.Tensor, hints_perc: float | tuple[float, float] | int = 0.03 16 | ) -> torch.Tensor: 17 | if isinstance(hints_perc, tuple | list): 18 | hints_perc = random.uniform(*hints_perc) 19 | 20 | if hints_perc < 1.0: 21 | sparse_map = torch.rand_like(depth) < hints_perc 22 | sparse_depth = torch.where( 23 | sparse_map, depth, torch.tensor(0.0, dtype=depth.dtype, device=depth.device) 24 | ) 25 | return sparse_depth 26 | else: 27 | b = depth.shape[0] 28 | idxs = torch.nonzero(depth[:, 0]) 29 | idxs = idxs[torch.randperm(len(idxs))] 30 | sparse_depth = torch.zeros_like(depth) 31 | for bi in range(b): 32 | bidxs = idxs[idxs[:, 0] == bi][:hints_perc] 33 | sparse_depth[bi, 0, bidxs[:, 1], bidxs[:, 2]] = depth[ 34 | bi, 0, bidxs[:, 1], bidxs[:, 2] 35 | ] 36 | return sparse_depth 37 | 38 | 39 | def project_depth( 40 | depth: Tensor, 41 | intrinsics_from: Tensor, 42 | intrinsics_to: Tensor, 43 | extrinsics_from_to: Tensor, 44 | depth_to: Tensor, 45 | ) -> Tensor: 46 | # project the depth in 3D 47 | batch, _, h, w = depth.shape 48 | xyz_pcd_from = kornia.geometry.depth_to_3d(depth, intrinsics_from).permute( 49 | 0, 2, 3, 1 50 | ) 51 | xyz_pcd_to = intrinsics_to @ ( 52 | ( 53 | extrinsics_from_to 54 | @ torch.nn.functional.pad( 55 | xyz_pcd_from.view(batch, -1, 3), [0, 1], "constant", 1.0 56 | ).permute(0, 2, 1) 57 | )[:, :3] 58 | ) 59 | xyz_pcd_to = xyz_pcd_to.permute(0, 2, 1).view(batch, h, w, 3) 60 | 61 | # project depth to 2D 62 | h_to, w_to = depth_to.shape[-2:] 63 | u, v = torch.unbind(xyz_pcd_to[..., :2] / xyz_pcd_to[..., -1:], -1) 64 | u, v = torch.round(u).to(torch.long), torch.round(v).to(torch.long) 65 | mask = (u >= 0) & (v >= 0) & (u < w_to) & (v < h_to) & (depth[:, 0] > 0) 66 | 67 | for b in range(batch): 68 | used_mask = mask[b] 69 | 
used_u, used_v = u[b, used_mask], v[b, used_mask] 70 | prev_depths = depth_to[b, 0, used_v, used_u] 71 | new_depths = xyz_pcd_to[b, used_mask][:, -1] 72 | merged_depths = torch.where(prev_depths == 0, new_depths, prev_depths) 73 | depth_to[b, 0, used_v, used_u] = merged_depths 74 | 75 | return depth_to 76 | 77 | 78 | def project_pcd( 79 | xyz_pcd: torch.Tensor, 80 | intrinsics_to: torch.Tensor, 81 | depth_to: torch.Tensor, 82 | extrinsics_from_to: torch.Tensor | None = None, 83 | ) -> torch.Tensor: 84 | # transform pcd 85 | batch, _, _ = xyz_pcd.shape 86 | if extrinsics_from_to is None: 87 | extrinsics_from_to = torch.eye(4, dtype=xyz_pcd.dtype, device=xyz_pcd.device)[ 88 | None 89 | ].repeat(batch, 1, 1) 90 | xyz_pcd_to = ( 91 | intrinsics_to 92 | @ ( 93 | extrinsics_from_to 94 | @ torch.nn.functional.pad(xyz_pcd, [0, 1], "constant", 1.0).permute(0, 2, 1) 95 | )[:, :3] 96 | ) 97 | 98 | # project depth to 2D 99 | h_to, w_to = depth_to.shape[-2:] 100 | u, v = torch.unbind(xyz_pcd_to[:, :2] / xyz_pcd_to[:, -1:], 1) 101 | u, v = torch.round(u).to(torch.long), torch.round(v).to(torch.long) 102 | mask = (u >= 0) & (v >= 0) & (u < w_to) & (v < h_to) 103 | 104 | for b in range(batch): 105 | used_mask = mask[b] 106 | used_u, used_v = u[b, used_mask], v[b, used_mask] 107 | prev_depths = depth_to[b, 0, used_v, used_u] 108 | new_depths = xyz_pcd_to[b, :, used_mask][-1] 109 | merged_depths = torch.where( 110 | (prev_depths == 0) & (new_depths > 0), new_depths, prev_depths 111 | ) 112 | depth_to[b, 0, used_v, used_u] = merged_depths 113 | 114 | return depth_to 115 | 116 | 117 | def inv_pose(pose: torch.Tensor) -> torch.Tensor: 118 | rot_inv = pose[:, :3, :3].permute(0, 2, 1) 119 | tr_inv = -rot_inv @ pose[:, :3, -1:] 120 | pose_inv = torch.eye(4, dtype=pose.dtype, device=pose.device)[None] 121 | pose_inv = pose_inv.repeat(pose.shape[0], 1, 1) 122 | pose_inv[:, :3, :3] = rot_inv 123 | pose_inv[:, :3, -1:] = tr_inv 124 | return pose_inv 125 | -------------------------------------------------------------------------------- /lib/dataset/utils/metrics.py: -------------------------------------------------------------------------------- 1 | """ 2 | Used metrics in this project 3 | """ 4 | 5 | import torch 6 | from torch import Tensor 7 | 8 | __all__ = ["mae", "rmse", "sq_rel", "abs_rel", "rel_thresh", "compute_metrics"] 9 | 10 | 11 | def mae(pred: Tensor, gt: Tensor) -> Tensor: 12 | return torch.mean(torch.abs(pred - gt)) 13 | 14 | 15 | def rmse(pred: Tensor, gt: Tensor) -> Tensor: 16 | return torch.sqrt(torch.mean(torch.square(pred - gt))) 17 | 18 | 19 | def sq_rel(pred: Tensor, gt: Tensor, eps: float = 0.0) -> Tensor: 20 | return torch.mean(torch.square(pred - gt) / (gt + eps)) 21 | 22 | 23 | def abs_rel(pred: Tensor, gt: Tensor, eps: float = 0.0) -> Tensor: 24 | return torch.mean(torch.abs(pred - gt) / (gt + eps)) 25 | 26 | 27 | def rel_thresh(pred: Tensor, gt: Tensor, sigma: float) -> Tensor: 28 | rel = torch.maximum(gt / pred, pred / gt) < sigma 29 | rel = torch.mean(rel.float()) 30 | return rel 31 | 32 | 33 | def compute_metrics_depth( 34 | pred: Tensor, gt: Tensor, label: str = "" 35 | ) -> dict[str, Tensor]: 36 | label = label if not label else label + "/" 37 | return { 38 | f"{label}mae": mae(pred, gt), 39 | f"{label}rmse": rmse(pred, gt), 40 | f"{label}sq_rel": sq_rel(pred, gt), 41 | f"{label}abs_rel": abs_rel(pred, gt), 42 | f"{label}rel_thresh_1.05": rel_thresh(pred, gt, 1.05), 43 | f"{label}rel_thresh_1.25": rel_thresh(pred, gt, 1.25), 44 | } 45 | 
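A minimal usage sketch for the depth metrics above, assuming `pred` and `gt` are depth tensors of the same shape and that invalid ground-truth pixels are encoded as zeros (the convention used by the dataloaders), so they are masked out before computing the metrics:

```python
import torch

from lib.dataset.utils.metrics import compute_metrics_depth

# toy prediction / ground truth; zeros in gt mark invalid pixels (assumed convention)
pred = torch.rand(1, 1, 240, 320) * 10 + 0.1
gt = torch.rand(1, 1, 240, 320) * 10 + 0.1
gt[..., :20, :] = 0.0  # e.g. missing depth at the top of the frame

mask = gt > 0  # evaluate only where ground truth is valid
metrics = compute_metrics_depth(pred[mask], gt[mask], label="test")
print({k: round(v.item(), 4) for k, v in metrics.items()})
```

With `label="test"` the returned keys are prefixed accordingly, e.g. `test/mae` and `test/rmse`.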
-------------------------------------------------------------------------------- /lib/visualize/__init__.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualization tools for different needs, each module may require extra dependencies 3 | to install accordingly 4 | """ 5 | -------------------------------------------------------------------------------- /lib/visualize/depth.py: -------------------------------------------------------------------------------- 1 | """ 2 | Utilities to visualize depth and pointclouds 3 | """ 4 | 5 | import numpy as np 6 | from pathlib import Path 7 | from torch import Tensor 8 | 9 | try: 10 | import plyfile # type: ignore 11 | from kornia.geometry.depth import depth_to_3d # type: ignore 12 | from kornia.geometry import transform_points # type: ignore 13 | except ImportError: 14 | raise ImportError("To use depth visualize utilities install plyfile and kornia") 15 | 16 | __all__ = [ 17 | "save_depth_ply", 18 | "save_pcd_ply", 19 | "save_depth_overlap_ply", 20 | ] 21 | 22 | 23 | def _norm_color(color): 24 | color = (color - color.min()) / (1e-6 + np.abs(color.max() - color.min())) 25 | return color * 255 26 | 27 | 28 | def save_depth_ply( 29 | filename: str | Path, 30 | depth: Tensor, 31 | intrinsics: Tensor, 32 | color: Tensor | None = None, 33 | _output_plyelement: bool = False, 34 | ): 35 | """ 36 | Saves a depth map with optionally colors in a ply file, 37 | depth is of shape 1 x H x W and colors 3 x H x W if provided, 38 | """ 39 | mask = depth[0] > 0 40 | pcd = depth_to_3d(depth[None], intrinsics[None]).permute(0, 2, 3, 1)[0][mask] 41 | if color is not None: 42 | color = color.permute(1, 2, 0)[mask] 43 | return save_pcd_ply(filename, pcd, color, _output_plyelement) 44 | 45 | 46 | def save_pcd_ply( 47 | filename: str | Path, 48 | pcd: Tensor, 49 | color: Tensor | None = None, 50 | _output_plyelement: bool = False, 51 | ): 52 | """ 53 | Saves a a point cloud with optionally colors in a ply file, 54 | pcd is of shape N x 3 and colors N x 3 if provided 55 | """ 56 | 57 | pcd = pcd.cpu().numpy() 58 | if color is not None: 59 | color = _norm_color(color.cpu().numpy()) 60 | else: 61 | color = np.zeros_like(pcd) 62 | color[:, 0] += 255 63 | pcd = np.array( 64 | list( 65 | zip( 66 | pcd[:, 0], 67 | pcd[:, 1], 68 | pcd[:, 2], 69 | color[:, 0], 70 | color[:, 1], 71 | color[:, 2], 72 | ) 73 | ), 74 | dtype=[ 75 | ("x", "f4"), 76 | ("y", "f4"), 77 | ("z", "f4"), 78 | ("red", "u1"), 79 | ("green", "u1"), 80 | ("blue", "u1"), 81 | ], 82 | ) 83 | 84 | if _output_plyelement: 85 | return plyfile.PlyElement.describe(pcd, "vertex") 86 | else: 87 | plyfile.PlyData([plyfile.PlyElement.describe(pcd, "vertex")]).write(filename) 88 | 89 | 90 | def save_depth_overlap_ply( 91 | filename: str | Path, 92 | depths: list[Tensor], 93 | colors: list[Tensor], 94 | intrinsics: list[Tensor] | None = None, 95 | extrinsics: list[Tensor] | None = None, 96 | ): 97 | """ 98 | Takes a list of pointclouds or depth maps and their color map and saves all of them 99 | as a whole point cloud in ply format 100 | """ 101 | 102 | if intrinsics is None: 103 | intrinsics = [None] * len(depths) 104 | 105 | elems = [] 106 | for idx, (d, c, i) in enumerate(zip(depths, colors, intrinsics)): 107 | if d.dim() == 2: 108 | if extrinsics is None: 109 | elems.append(save_pcd_ply(None, d, c, True).data) 110 | else: 111 | d = transform_points(extrinsics[idx][None], d[None])[0] 112 | elems.append(save_pcd_ply(None, d, c, True).data) 113 | else: 114 | if extrinsics is 
None: 115 | elems.append(save_depth_ply(None, d, i, c, True).data) 116 | else: 117 | mask = d[0] > 0 118 | d = depth_to_3d(d[None], i[None]).permute(0, 2, 3, 1)[0][mask] 119 | d = transform_points(extrinsics[idx][None], d[None])[0] 120 | c = c.permute(1, 2, 0)[mask] 121 | elems.append(save_pcd_ply(None, d, c, True).data) 122 | 123 | plyelem = plyfile.PlyElement.describe(np.concatenate(elems, 0), "vertex") 124 | plyfile.PlyData([plyelem]).write(filename) 125 | -------------------------------------------------------------------------------- /lib/visualize/flow.py: -------------------------------------------------------------------------------- 1 | # Flow visualization code used from https://github.com/tomrunia/OpticalFlow_Visualization 2 | 3 | # MIT License 4 | # 5 | # Copyright (c) 2018 Tom Runia 6 | # 7 | # Permission is hereby granted, free of charge, to any person obtaining a copy 8 | # of this software and associated documentation files (the "Software"), to deal 9 | # in the Software without restriction, including without limitation the rights 10 | # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 11 | # copies of the Software, and to permit persons to whom the Software is 12 | # furnished to do so, subject to conditions. 13 | # 14 | # Author: Tom Runia 15 | # Date Created: 2018-08-03 16 | 17 | import numpy as np 18 | 19 | __all__ = ["flow_to_image"] 20 | 21 | 22 | def flow_to_image(flow_uv: np.ndarray, clip_flow=None, convert_to_bgr=False): 23 | """ 24 | Expects a two dimensional flow image of shape. 25 | 26 | Args: 27 | flow_uv (np.ndarray): Flow UV image of shape [H,W,2] 28 | clip_flow (float, optional): Clip maximum of flow values. Defaults to None. 29 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 30 | 31 | Returns: 32 | np.ndarray: Flow visualization image of shape [H,W,3] 33 | """ 34 | assert flow_uv.ndim == 3, "input flow must have three dimensions" 35 | assert flow_uv.shape[2] == 2, "input flow must have shape [H,W,2]" 36 | if clip_flow is not None: 37 | flow_uv = np.clip(flow_uv, 0, clip_flow) 38 | u = flow_uv[:, :, 0] 39 | v = flow_uv[:, :, 1] 40 | rad = np.sqrt(np.square(u) + np.square(v)) 41 | rad_max = np.max(rad) 42 | epsilon = 1e-5 43 | u = u / (rad_max + epsilon) 44 | v = v / (rad_max + epsilon) 45 | return _flow_uv_to_colors(u, v, convert_to_bgr) 46 | 47 | 48 | # utils 49 | 50 | 51 | def _make_colorwheel(): 52 | """ 53 | Generates a color wheel for optical flow visualization as presented in: 54 | Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007) 55 | URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf 56 | 57 | Code follows the original C++ source code of Daniel Scharstein. 58 | Code follows the the Matlab source code of Deqing Sun. 
59 | 60 | Returns: 61 | np.ndarray: Color wheel 62 | """ 63 | 64 | RY = 15 65 | YG = 6 66 | GC = 4 67 | CB = 11 68 | BM = 13 69 | MR = 6 70 | 71 | ncols = RY + YG + GC + CB + BM + MR 72 | colorwheel = np.zeros((ncols, 3)) 73 | col = 0 74 | 75 | # RY 76 | colorwheel[0:RY, 0] = 255 77 | colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY) 78 | col = col + RY 79 | # YG 80 | colorwheel[col : col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG) 81 | colorwheel[col : col + YG, 1] = 255 82 | col = col + YG 83 | # GC 84 | colorwheel[col : col + GC, 1] = 255 85 | colorwheel[col : col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC) 86 | col = col + GC 87 | # CB 88 | colorwheel[col : col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB) 89 | colorwheel[col : col + CB, 2] = 255 90 | col = col + CB 91 | # BM 92 | colorwheel[col : col + BM, 2] = 255 93 | colorwheel[col : col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM) 94 | col = col + BM 95 | # MR 96 | colorwheel[col : col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR) 97 | colorwheel[col : col + MR, 0] = 255 98 | return colorwheel 99 | 100 | 101 | def _flow_uv_to_colors(u, v, convert_to_bgr=False): 102 | """ 103 | Applies the flow color wheel to (possibly clipped) flow components u and v. 104 | 105 | According to the C++ source code of Daniel Scharstein 106 | According to the Matlab source code of Deqing Sun 107 | 108 | Args: 109 | u (np.ndarray): Input horizontal flow of shape [H,W] 110 | v (np.ndarray): Input vertical flow of shape [H,W] 111 | convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False. 112 | 113 | Returns: 114 | np.ndarray: Flow visualization image of shape [H,W,3] 115 | """ 116 | flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8) 117 | colorwheel = _make_colorwheel() # shape [55x3] 118 | ncols = colorwheel.shape[0] 119 | rad = np.sqrt(np.square(u) + np.square(v)) 120 | a = np.arctan2(-v, -u) / np.pi 121 | fk = (a + 1) / 2 * (ncols - 1) 122 | k0 = np.floor(fk).astype(np.int32) 123 | k1 = k0 + 1 124 | k1[k1 == ncols] = 0 125 | f = fk - k0 126 | for i in range(colorwheel.shape[1]): 127 | tmp = colorwheel[:, i] 128 | col0 = tmp[k0] / 255.0 129 | col1 = tmp[k1] / 255.0 130 | col = (1 - f) * col0 + f * col1 131 | idx = rad <= 1 132 | col[idx] = 1 - rad[idx] * (1 - col[idx]) 133 | col[~idx] = col[~idx] * 0.75 # out of range 134 | # Note the 2-i => BGR instead of RGB 135 | ch_idx = 2 - i if convert_to_bgr else i 136 | flow_image[:, :, ch_idx] = np.floor(255 * col) 137 | return flow_image 138 | -------------------------------------------------------------------------------- /lib/visualize/gui.py: -------------------------------------------------------------------------------- 1 | from collections import defaultdict 2 | from contextlib import contextmanager 3 | from time import sleep 4 | from typing import Any 5 | 6 | import numpy as np 7 | import torch 8 | import trimesh 9 | import viser 10 | import viser.transforms as ops 11 | from kornia.geometry import depth_to_3d 12 | 13 | 14 | class Visualize3D: 15 | def __init__( 16 | self, 17 | host: str = "127.0.0.1", 18 | port: int = 8080, 19 | ): 20 | self._server = viser.ViserServer(host, port) 21 | self.mesh = None 22 | self._steps = defaultdict(lambda: {}) 23 | self._info = {} 24 | self._gui = defaultdict(lambda: {}) 25 | self._gui_funcs = _GuiFuncs(self._server) 26 | 27 | def _convert_pose(self, pose: np.ndarray | None = None): 28 | if pose is None: 29 | return {} 30 | else: 31 | pose = ops.SE3.from_matrix(pose) 32 | return { 33 | "wxyz": 
pose.wxyz_xyz[:4], 34 | "position": pose.wxyz_xyz[4:], 35 | } 36 | 37 | @property 38 | def n_steps(self) -> int: 39 | return len(self._steps) 40 | 41 | @property 42 | def steps(self) -> list[int]: 43 | steps = [] 44 | for step in self._steps.keys(): 45 | steps.append(int(step.split("_")[-1])) 46 | return sorted(steps) 47 | 48 | def has_step(self, step: int, name: str): 49 | with self._server.atomic(): 50 | step_data = self._steps.get(f"step_{step}", None) 51 | if not step_data: 52 | raise KeyError(f"step {step} not exists") 53 | return name in step_data 54 | 55 | def step_visible( 56 | self, 57 | step: int, 58 | frame: bool | None = None, 59 | camera: bool | None = None, 60 | point_cloud: bool | None = None, 61 | mesh: bool | None = None, 62 | all: bool | None = None, 63 | ): 64 | with self._server.atomic(): 65 | step_data = self._steps.get(f"step_{step}", None) 66 | if not step_data: 67 | raise KeyError(f"step {step} not exists") 68 | 69 | if all is not None: 70 | if "frame" in step_data: 71 | step_data["frame"].visible = all 72 | if "camera" in step_data: 73 | step_data["camera"].visible = all 74 | if "point_cloud" in step_data: 75 | step_data["point_cloud"].visible = all 76 | if "mesh" in step_data: 77 | step_data["mesh"].visible = all 78 | 79 | if frame is not None and "frame" in step_data: 80 | step_data["frame"].visible = frame 81 | if camera is not None and "camera" in step_data: 82 | step_data["camera"].visible = camera 83 | if point_cloud is not None and "point_cloud" in step_data: 84 | step_data["point_cloud"].visible = point_cloud 85 | if mesh is not None and "mesh" in step_data: 86 | step_data["mesh"].visible = mesh 87 | 88 | def remove_step(self, step: int): 89 | step_data = self._steps[f"step_{step}"] 90 | for v in step_data.values(): 91 | v.remove() 92 | 93 | def reset_scene(self): 94 | self._server.reset_scene() 95 | self._steps = defaultdict(lambda: {}) 96 | self._info = {} 97 | 98 | def add_info(self, label: str, value: Any): 99 | self._info[label] = value 100 | 101 | def get_info(self, label: str) -> Any: 102 | return self._info[label] 103 | 104 | def add_frame(self, step: int, pose: np.ndarray | None = None, **kwargs): 105 | frame = self._server.add_frame( 106 | f"step_{step}_frame", **self._convert_pose(pose), **kwargs 107 | ) 108 | self._steps[f"step_{step}"]["frame"] = frame 109 | 110 | def add_camera( 111 | self, 112 | step: int, 113 | fov: float, 114 | image: np.ndarray, 115 | scale: float = 0.3, 116 | color: tuple[int, int, int] = (20, 20, 20), 117 | pose: np.ndarray | None = None, 118 | visible: bool = True, 119 | ): 120 | camera = self._server.add_camera_frustum( 121 | f"step_{step}_camera", 122 | fov=np.deg2rad(fov), 123 | image=image, 124 | aspect=image.shape[1] / image.shape[0], 125 | color=color, 126 | scale=scale, 127 | visible=visible, 128 | **self._convert_pose(pose), 129 | ) 130 | self._steps[f"step_{step}"]["camera"] = camera 131 | 132 | def add_depth_point_cloud( 133 | self, 134 | step: int, 135 | depth: np.ndarray, 136 | intrinsics: np.ndarray, 137 | color: np.ndarray | None = None, 138 | pose: np.ndarray | None = None, 139 | point_size: float = 0.01, 140 | visible: bool = True, 141 | ): 142 | pcd = ( 143 | depth_to_3d( 144 | torch.from_numpy(depth)[None, None], 145 | torch.from_numpy(intrinsics)[None], 146 | )[0] 147 | .permute(1, 2, 0) 148 | .numpy() 149 | ) 150 | pcd = pcd[depth > 0] 151 | if color is not None: 152 | color = color[depth > 0] 153 | handle = self._server.add_point_cloud( 154 | f"step_{step}_point_cloud", 155 | points=pcd, 156 | 
colors=color, 157 | point_size=point_size, 158 | **self._convert_pose(pose), 159 | visible=visible, 160 | ) 161 | self._steps[f"step_{step}"]["point_cloud"] = handle 162 | 163 | def add_mesh( 164 | self, 165 | step: int, 166 | vertices: np.ndarray, 167 | triangles: np.ndarray, 168 | vertex_colors: np.ndarray | None = None, 169 | pose: np.ndarray | None = None, 170 | visible: bool = True, 171 | ): 172 | mesh = trimesh.Trimesh( 173 | vertices=vertices, faces=triangles, vertex_colors=vertex_colors 174 | ) 175 | handle = self._server.add_mesh_trimesh( 176 | f"step_{step}_mesh", mesh, **self._convert_pose(pose), visible=visible 177 | ) 178 | self._steps[f"step_{step}"]["mesh"] = handle 179 | 180 | # GUI handles 181 | 182 | @property 183 | def gui(self) -> viser.ViserServer: 184 | return self._gui_funcs 185 | 186 | @contextmanager 187 | def atomic(self): 188 | with self._server.atomic(): 189 | yield 190 | 191 | # wait for gui 192 | 193 | def wait(self): 194 | try: 195 | while True: 196 | sleep(10.0) 197 | except KeyboardInterrupt: 198 | self._server.stop() 199 | 200 | 201 | class _GuiFuncs: 202 | def __init__(self, server): 203 | self._server = server 204 | 205 | def __getattr__(self, name: str) -> Any: 206 | if name.startswith("add_gui") and hasattr(self._server, name): 207 | return getattr(self._server, name) 208 | -------------------------------------------------------------------------------- /lib/visualize/matplotlib.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualization utilities related to matplotlib 3 | """ 4 | 5 | from torch import Tensor 6 | 7 | import shutil 8 | import tempfile 9 | from pathlib import Path 10 | from typing import Iterable, Optional, Union 11 | 12 | from skimage.transform import resize as img_resize 13 | import imageio 14 | import matplotlib.pyplot as plt 15 | import numpy as np 16 | from matplotlib.figure import Figure 17 | 18 | try: 19 | from IPython import get_ipython as _get_ipython 20 | 21 | _is_interactive = lambda: _get_ipython() is not None 22 | except ImportError: 23 | _is_interactive = lambda: None 24 | 25 | 26 | __all__ = ["GridFigure", "make_gif", "color_depth"] 27 | 28 | 29 | class GridFigure: 30 | """ 31 | Utility class to plot a grid of images, really useful 32 | in training code to log qualitatives 33 | """ 34 | 35 | figure: Figure 36 | rows: int 37 | cols: int 38 | 39 | def __init__( 40 | self, 41 | rows: int, 42 | cols: int, 43 | *, 44 | size: tuple[int, int] | None = None, 45 | tight_layout: bool = True, 46 | ): 47 | figsize = None 48 | if size: 49 | h, w = size 50 | px = 1 / plt.rcParams["figure.dpi"] 51 | figsize = (cols * w * px, rows * h * px) 52 | self.figure = plt.figure(tight_layout=tight_layout, figsize=figsize) 53 | self.rows = rows 54 | self.cols = cols 55 | 56 | def imshow( 57 | self, 58 | pos: int, 59 | image: Tensor, 60 | /, 61 | *, 62 | show_axis: bool = False, 63 | norm: bool = True, 64 | **kwargs, 65 | ): 66 | ax = self.figure.add_subplot(self.rows, self.cols, pos) 67 | if not show_axis: 68 | ax.set_axis_off() 69 | ax.imshow(_to_numpy(image, norm=norm), **kwargs) 70 | 71 | def show(self): 72 | self.figure.show() 73 | 74 | def close(self): 75 | plt.close(self.figure) 76 | 77 | def __enter__(self): 78 | return self 79 | 80 | def __exit__(self, exc_type, exc_val, ext_tb): 81 | self.show() 82 | if not _is_interactive(): 83 | self.close() 84 | 85 | 86 | def make_gif( 87 | imgs: Iterable[Union[np.ndarray, Figure]], 88 | path: Union[Path, str], 89 | axis: bool = False, 90 | 
close_figures: bool = True, 91 | bbox_inches: Optional[str] = "tight", 92 | pad_inches: float = 0, 93 | **kwargs, 94 | ): 95 | """ 96 | Renders each image using matplotlib and then composes them into a GIF saved 97 | in ``path``; ``kwargs`` are forwarded to ``plt.imshow`` 98 | """ 99 | 100 | tmpdir = Path(tempfile.mkdtemp()) 101 | for i, img in enumerate(imgs): 102 | if isinstance(img, np.ndarray): 103 | plt.axis({False: "off", True: "on"}[axis]) 104 | plt.imshow(img, **kwargs) 105 | plt.savefig( 106 | tmpdir / f"{i}.png", bbox_inches=bbox_inches, pad_inches=pad_inches 107 | ) 108 | plt.close() 109 | elif isinstance(img, Figure): 110 | img.savefig( 111 | tmpdir / f"{i}.png", bbox_inches=bbox_inches, pad_inches=pad_inches 112 | ) 113 | if close_figures: 114 | plt.close(img) 115 | 116 | with imageio.get_writer(path, mode="I") as writer: 117 | paths = sorted(tmpdir.glob("*.png"), key=lambda x: int(x.name.split(".")[0])) 118 | for frame_path in paths: 119 | image = imageio.imread(frame_path) 120 | writer.append_data(image) 121 | shutil.rmtree(tmpdir, ignore_errors=True) 122 | 123 | 124 | def color_depth(depth: np.ndarray, cmap="magma_r", **kwargs): 125 | px = 1 / plt.rcParams["figure.dpi"] 126 | h, w = depth.shape[:2] 127 | fig = plt.figure(figsize=(w * px, h * px)) 128 | plt.axis("off") 129 | plt.tight_layout(pad=0.0, h_pad=0.0, w_pad=0.0) 130 | plt.imshow(depth, cmap=cmap, **kwargs) 131 | fig.canvas.draw() 132 | data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) 133 | data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,)) 134 | data = img_resize(data, depth.shape) 135 | alpha = (depth[..., None] > 0).astype(np.uint8) * 255 136 | data = np.concatenate([data, alpha], -1) 137 | plt.close(fig) 138 | return data 139 | 140 | # internal utils 141 | 142 | 143 | def _to_numpy(image: Tensor | np.ndarray, norm: bool = True): 144 | if isinstance(image, Tensor): 145 | image = image.detach().cpu().numpy().astype(np.float32) 146 | if image.ndim == 3: 147 | image = image.transpose([1, 2, 0]) 148 | elif image.ndim != 2: 149 | raise ValueError(f"Unsupported torch.Tensor of shape {image.shape}") 150 | if norm: 151 | nan_mask = np.isnan(image) 152 | if not np.all(nan_mask): 153 | img_min, img_max = image[~nan_mask].min(), image[~nan_mask].max() 154 | image = (image - img_min) / (img_max - img_min) 155 | return image 156 | -------------------------------------------------------------------------------- /lib/visualize/plotly.py: -------------------------------------------------------------------------------- 1 | """ 2 | Visualization tools implemented in plotly 3 | """ 4 | 5 | import pandas as pd 6 | 7 | try: 8 | from plotly.graph_objects import Scatter3d, Figure # type: ignore 9 | except ImportError: 10 | raise ImportError("To use plotly visualize utilities install plotly") 11 | 12 | 13 | def point_cloud( 14 | data: pd.DataFrame, 15 | x: str, 16 | y: str, 17 | z: str, 18 | color: list[str] | str | None = None, 19 | point_size: float = 1.0, 20 | ) -> Scatter3d: 21 | """ 22 | Utility function to visualize a Point Cloud with colors 23 | """ 24 | 25 | marker_opts = { 26 | "size": point_size, 27 | } 28 | 29 | if color is not None: 30 | if isinstance(color, str): 31 | marker_opts["color"] = [color] * len(data[x]) 32 | else: 33 | marker_opts["color"] = [ 34 | f"rgb({r}, {g}, {b})" 35 | for r, g, b in zip(data[color[0]], data[color[1]], data[color[2]]) 36 | ] 37 | 38 | scatter = Scatter3d( 39 | x=data[x], y=data[y], z=data[z], mode="markers", marker=marker_opts 40 | ) 41 | return scatter 42 | 
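A minimal sketch of how the `point_cloud` helper above could be used, assuming the repository root is on `PYTHONPATH`; the DataFrame column names (`x`, `y`, `z`, `r`, `g`, `b`) are purely illustrative, any column names can be passed as arguments:

```python
import numpy as np
import pandas as pd
from plotly.graph_objects import Figure

from lib.visualize.plotly import point_cloud

# toy point cloud with per-point RGB colors (columns are illustrative)
pts = np.random.rand(1000, 3)
rgb = np.random.randint(0, 256, size=(1000, 3))
df = pd.DataFrame(
    {
        "x": pts[:, 0], "y": pts[:, 1], "z": pts[:, 2],
        "r": rgb[:, 0], "g": rgb[:, 1], "b": rgb[:, 2],
    }
)

# color accepts either three column names holding RGB values or a single color string
scatter = point_cloud(df, "x", "y", "z", color=["r", "g", "b"], point_size=2.0)
Figure(data=[scatter]).show()  # opens an interactive 3D scatter in the browser
```

Passing a single color string instead of three column names draws every point with that color, which is what the `isinstance(color, str)` branch above handles.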
-------------------------------------------------------------------------------- /media/scannetv2_scene0720_00.mp4: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/andreaconti/depth-on-demand/7376cd0f7a67e583b3d0eff86641cfd7af83de67/media/scannetv2_scene0720_00.mp4 -------------------------------------------------------------------------------- /pyrightconfig.json: -------------------------------------------------------------------------------- 1 | { 2 | "exclude": [ 3 | "**/.git", 4 | "**/__pycache__", 5 | "**/.env", 6 | "**/.pytest_cache", 7 | "**/data", 8 | "**/.data", 9 | ] 10 | } 11 | 12 | -------------------------------------------------------------------------------- /visualize/callbacks/build_scene.py: -------------------------------------------------------------------------------- 1 | import time 2 | from pathlib import Path 3 | 4 | import matplotlib 5 | import numpy as np 6 | import torch 7 | import trimesh 8 | from depth_on_demand import Model 9 | 10 | from lib.benchmark.mesh import TSDFFusion 11 | from lib.dataset import list_scans, load_datamodule 12 | from lib.dataset.utils import inv_pose, project_depth, sparsify_depth 13 | from lib.visualize.gui import Visualize3D 14 | 15 | 16 | class BuildScene: 17 | def __init__( 18 | self, 19 | viz: Visualize3D, 20 | device: str = "cuda:0", 21 | ): 22 | self.viz = viz 23 | self.device = device 24 | self.to_clear = False 25 | 26 | # select the model 27 | with viz.gui.add_gui_folder("Model"): 28 | self.model_pretrained = viz.gui.add_gui_dropdown( 29 | "Model Pretrained", ["scannetv2"], initial_value="scannetv2" 30 | ) 31 | self.load_model(self.model_pretrained.value) 32 | self.model_pretrained.on_update(self.load_model) 33 | 34 | with viz.gui.add_gui_folder("Dataset"): 35 | self.dataset = viz.gui.add_gui_dropdown( 36 | "Name", ["sevenscenes", "scannetv2"], initial_value="scannetv2" 37 | ) 38 | self.dataset_root = viz.gui.add_gui_text( 39 | "Root", initial_value="data/scannetv2" 40 | ) 41 | self.dataset.on_update(self.on_update_dataset) 42 | 43 | # create build scene panel 44 | with viz.gui.add_gui_folder("Build Scene"): 45 | self.scene_select = viz.gui.add_gui_dropdown( 46 | "Scene", scenes[self.dataset.value] 47 | ) 48 | self.hints_interval = viz.gui.add_gui_slider( 49 | "Hints Interval", min=1, max=10, step=1, initial_value=5 50 | ) 51 | self.hints_density = viz.gui.add_gui_slider( 52 | "Hints Density", min=10, max=2000, step=10, initial_value=500 53 | ) 54 | self.pcd_points_size = viz.gui.add_gui_slider( 55 | "Points Size", min=0.01, max=1.0, step=0.01, initial_value=0.05 56 | ) 57 | self.frustum_color = viz.gui.add_gui_dropdown( 58 | "Frustum Color", list(matplotlib.colormaps.keys()), initial_value="jet" 59 | ) 60 | self.frustum_color_range = viz.gui.add_gui_vector2( 61 | "Color Map Range", (0.25, 0.5), (0.0, 0.0), (1.0, 1.0), 0.05 62 | ) 63 | self.overlap_time = viz.gui.add_gui_slider( 64 | "frames overlap", 0, 1, 0.01, 0.25 65 | ) 66 | self.build_scene = viz.gui.add_gui_button("Build Scene") 67 | 68 | with (folder := viz.gui.add_gui_folder("TSDF Parameters")): 69 | self.tsdf_folder = folder 70 | self.voxel_length = viz.gui.add_gui_slider( 71 | "Voxel Length", 0.01, 0.5, 0.01, 0.04 72 | ) 73 | self.sdf_trunc = viz.gui.add_gui_slider( 74 | "SDF Trunc", 0.01, 0.5, 0.01, 3 * 0.04 75 | ) 76 | self.depth_trunc = viz.gui.add_gui_slider("Depth Trunc", 0.5, 10, 0.1, 3.0) 77 | self.integrate_only_slow = viz.gui.add_gui_checkbox( 78 | "integrate only when hints", False 79 | ) 80 
        self.build_scene.on_click(self.build_scene_callback)

    def on_update_dataset(self, *args, **kwargs):
        self.scene_select.options = scenes[self.dataset.value]
        self.voxel_length.visible = self.dataset.value != "kitti"
        self.sdf_trunc.visible = self.dataset.value != "kitti"
        self.depth_trunc.visible = self.dataset.value != "kitti"
        self.integrate_only_slow.visible = self.dataset.value != "kitti"
        self.dataset_root.value = f"data/{self.dataset.value}"

    def load_model(self, *args, **kwargs):
        self.model = Model(self.model_pretrained.value, device=self.device)

    def build_scene_callback(self, event):
        viz = self.viz
        scene_data = load_data(
            self.dataset.value, self.scene_select.value, root=self.dataset_root.value
        )
        interval = self.hints_interval.value
        density = self.hints_density.value

        # one frustum color per frame, sampled from the chosen colormap range
        colors = matplotlib.colormaps[self.frustum_color.value](
            np.linspace(*self.frustum_color_range.value, len(scene_data))
        )
        colors = (colors[:, :3] * 255).astype(np.uint8)

        tsdf = TSDFFusion(
            voxel_length=self.voxel_length.value,
            sdf_trunc=self.sdf_trunc.value,
            depth_trunc=self.depth_trunc.value,
            device=self.device,
        )
        viz.reset_scene()
        viz.add_info("dataset", self.dataset.value)
        viz.add_info("scan", self.scene_select.value)
        viz.add_info("frames", [])

        for step, ex in enumerate(scene_data):
            ex = {
                k: v.to(self.device)
                for k, v in ex.items()
                if isinstance(v, torch.Tensor)
            }
            # every `interval` frames refresh the buffered source view and its sparse hints
            if step % interval == 0:
                buffer = ex.copy()
                buffer_hints = sparsify_depth(buffer["depth"], density).to(self.device)
            # predict dense depth for the current RGB frame from the buffered sparse hints
            with torch.no_grad():
                rel_pose = inv_pose(ex["extrinsics"]) @ buffer["extrinsics"]
                hints = project_depth(
                    buffer_hints,
                    buffer["intrinsics"],
                    ex["intrinsics"],
                    rel_pose,
                    torch.zeros_like(ex["image"][:, :1]),
                )
                depth = self.model(
                    ex["image"],
                    buffer["image"],
                    hints,
                    rel_pose,
                    torch.stack([ex["intrinsics"], buffer["intrinsics"]], 1),
                )

            img = img_to_rgb(ex["image"])
            depth = depth.cpu().numpy()[0, 0, ..., None]
            self.viz.get_info("frames").append(
                {
                    "target": img,
                    "source": img_to_rgb(buffer["image"]),
                    "target_hints": hints.cpu().numpy()[0, 0, ..., None],
                    "source_hints": buffer_hints.cpu().numpy()[0, 0, ..., None],
                    "depth": depth,
                }
            )
            pose = ex["extrinsics"][0].cpu().numpy()
            intrinsics = ex["intrinsics"][0].cpu().numpy()
            hints_cpu = buffer_hints[0, 0].cpu().numpy()
            hints_cpu = np.where(hints_cpu > self.depth_trunc.value, 0.0, hints_cpu)

            # viz.add_step(step)
            # viz.step_visible(step, container=False)
            # keyframes with fresh hints are drawn in red, the others follow the colormap
            if step % interval == 0:
                camera_color = (255, 0, 0)
            else:
                camera_color = tuple(colors[step])
            viz.add_camera(
                step, 60.0, img, color=camera_color, pose=pose, visible=False
            )
            if step % interval == 0:
                red = np.zeros_like(img)
                red[..., 0] = 255
                viz.add_depth_point_cloud(
                    step,
                    hints_cpu,
                    intrinsics,
                    red,
                    point_size=self.pcd_points_size.value,
                    pose=pose,
                    visible=False,
                )

            # integrate every frame unless "integrate only when hints" is checked,
            # in which case only keyframes with fresh sparse depth are fused
            if not self.integrate_only_slow.value or step % interval == 0:
                tsdf.integrate_rgbd(img, depth, intrinsics, pose)
                mesh = tsdf.triangle_mesh()
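                # keep the fused mesh in the shared info so the Run Demo panel can export it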
                viz.add_info(
                    "mesh",
                    trimesh.Trimesh(
                        vertices=np.asarray(mesh.vertices),
                        faces=np.asarray(mesh.triangles),
                        vertex_colors=np.asarray(mesh.vertex_colors),
                    ),
                )
                viz.add_mesh(
                    step,
                    vertices=np.asarray(mesh.vertices),
                    triangles=np.asarray(mesh.triangles),
                    vertex_colors=np.asarray(mesh.vertex_colors),
                    # pose=np.linalg.inv(pose),
                    visible=False,
                )

            # show the current step, then fade out the previous one after a short overlap
            if step > 0:
                viz.step_visible(step - 1, all=False, mesh=True)
            viz.step_visible(step, camera=True, point_cloud=True, mesh=True)
            if step > 0:
                time.sleep(self.overlap_time.value)
                viz.step_visible(step - 1, all=False)


def load_data(dataset: str, scene: str, root: Path | str):
    # kitti expects two separate roots (raw data and depth completion),
    # all the other datasets take a single root folder
    roots = {"root": root}
    if dataset == "kitti":
        roots = {
            "root_raw": Path(root) / "raw",
            "root_completion": Path(root) / "depth_completion",
        }

    dm = load_datamodule(
        dataset,
        load_prevs=0,
        keyframes="standard",
        **roots,
        split_test_scans_loaders=True,
        filter_scans=[scene],
    )
    dm.prepare_data()
    dm.setup("test")
    return next(iter(dm.test_dataloader()))


scenes = {
    "scannetv2": list_scans("scannetv2", "standard", "test"),
    "sevenscenes": list_scans("sevenscenes", "standard", "test"),
    "kitti": list_scans("kitti", "standard", "test"),
    "tartanair": list_scans("tartanair", "standard", "test"),
}


def img_to_rgb(image: torch.Tensor):
    # min-max normalize the first image of the batch and convert it to HWC uint8
    image = image[0].permute(1, 2, 0).cpu().numpy()
    image = (image - image.min()) / (image.max() - image.min())
    image = (image * 255).astype(np.uint8)
    return image
--------------------------------------------------------------------------------
/visualize/callbacks/run_demo.py:
--------------------------------------------------------------------------------
import time
from pathlib import Path

import cv2
import h5py
import numpy as np
import trimesh
import viser

from lib.visualize.gui import Visualize3D


class RunDemo:
    def __init__(self, viz: Visualize3D, renders_folder: Path):
        self.viz = viz
        self.renders_folder = Path(renders_folder)

        # create gui
        with viz.gui.add_gui_folder("Demo"):
            self.speed = viz.gui.add_gui_slider("fps", 1, 20, 1, 5)
            self.overlap_time = viz.gui.add_gui_slider(
                "frames overlap", 0, 1, 0.01, 0.05
            )
            self.keep_mesh = viz.gui.add_gui_checkbox("Keep Background Mesh", False)
            self.run_button = viz.gui.add_gui_button("Run Demo")
            self.height = viz.gui.add_gui_number("Render Height", 960, 100, 2160, 1)
            self.width = viz.gui.add_gui_number("Render Width", 1280, 100, 2160, 1)
            self.render_button = viz.gui.add_gui_button("Render Video")
            self.save_frames = viz.gui.add_gui_checkbox("Save Frames", False)
            self.run_button.on_click(self.run_demo_clbk)
            self.render_button.on_click(self.render_video_clbk)

    def render_video_clbk(self, event: viser.GuiEvent):
        # the render button replays the demo while capturing frames from the client camera
        client = event.client
        assert client is not None
        self.run_demo_clbk(None, client)

    def run_demo_clbk(self, event, render_client=None):
        keep_mesh = self.keep_mesh.value
        keys = ["point_cloud", "frame", "camera", "mesh"]
        disable_all = {k: False for k in keys}
        enable_all = {k: True for k in keys}
        viz = self.viz

        # hide all
        for step in viz.steps:
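            # when "Keep Background Mesh" is on, the mesh of the last step stays visible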
            viz.step_visible(
                step,
                **(
                    disable_all | {"mesh": True}
                    if keep_mesh and step == viz.steps[-1]
                    else disable_all
                ),
            )

        # start animation
        images = []
        for step in viz.steps:
            viz.step_visible(
                step,
                **(
                    enable_all | {"mesh": False}
                    if keep_mesh and step != viz.steps[-1]
                    else enable_all
                ),
            )
            time.sleep(1 / self.speed.value)
            if step != viz.steps[-1]:
                viz.step_visible(step, **disable_all)

            if render_client is not None:
                images.append(
                    render_client.camera.get_render(
                        height=self.height.value, width=self.width.value
                    )
                )

        # save the animation
        if render_client is not None:
            scan_name = viz.get_info("scan")
            render_folder = self.renders_folder / viz.get_info("dataset") / scan_name
            render_folder.mkdir(exist_ok=True, parents=True)

            # write the video (OpenCV expects BGR frames)
            h, w, _ = images[0].shape
            writer = cv2.VideoWriter(
                str(render_folder / "video.avi"),
                cv2.VideoWriter_fourcc(*"MJPG"),
                self.speed.value,
                (w, h),
            )
            for frame in images:
                writer.write(frame[..., ::-1].astype(np.uint8))
            writer.release()

            # save the mesh, recentered on its centroid
            mesh: trimesh.Trimesh = viz.get_info("mesh")
            mesh.vertices = mesh.vertices - mesh.centroid
            _ = mesh.export(
                str(render_folder / "mesh.glb"),
                file_type="glb",
            )

            # save frames
            if self.save_frames.value:
                for idx, dict_data in enumerate(self.viz.get_info("frames")):
                    with h5py.File(render_folder / f"data-{idx:0>6d}.h5", "w") as f:
                        f["target"] = dict_data["target"]
                        f["source"] = dict_data["source"]
                        f["depth"] = dict_data["depth"]
                        f["target_hints"] = dict_data["target_hints"]
                        f["source_hints"] = dict_data["source_hints"]
--------------------------------------------------------------------------------
/visualize/run.py:
--------------------------------------------------------------------------------
"""
3D visualizer of scenes reconstructed with Depth on Demand
"""

import sys
from pathlib import Path

# make the repository root importable when the script is run directly
sys.path.append(str(Path(__file__).parents[1]))
import argparse
import warnings

from callbacks.build_scene import BuildScene
from callbacks.run_demo import RunDemo

from lib.visualize.gui import Visualize3D

warnings.filterwarnings("ignore", category=UserWarning)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--host", default="0.0.0.0")
    parser.add_argument("--port", default=8080, type=int)
    parser.add_argument("--renders-folder", default="./renders")
    args = parser.parse_args()
    return args


def main():
    args = parse_args()
    viz = Visualize3D(host=args.host, port=args.port)
    BuildScene(viz, args.device)
    RunDemo(viz, args.renders_folder)
    viz.wait()


if __name__ == "__main__":
    main()
--------------------------------------------------------------------------------
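For reference, the snippet below is a minimal sketch of driving the pretrained model directly, without the GUI, following the same steps as `build_scene_callback` above. It only uses calls that appear in that file; the choice of scan (the first listed one), the number of sparse hints (500) and the tensor conventions (batched images, per-view 4x4 extrinsics and 3x3 intrinsics) are assumptions inferred from that code, not a documented API.

```python
import torch
from depth_on_demand import Model

from lib.dataset import list_scans, load_datamodule
from lib.dataset.utils import inv_pose, project_depth, sparsify_depth

device = "cuda:0"
model = Model("scannetv2", device=device)

# load a single ScanNetV2 test scan, exactly as load_data() does in build_scene.py
scan = list_scans("scannetv2", "standard", "test")[0]
dm = load_datamodule(
    "scannetv2",
    load_prevs=0,
    keyframes="standard",
    root="data/scannetv2",
    split_test_scans_loaders=True,
    filter_scans=[scan],
)
dm.prepare_data()
dm.setup("test")
frames = iter(next(iter(dm.test_dataloader())))


def to_device(ex):
    return {k: v.to(device) for k, v in ex.items() if isinstance(v, torch.Tensor)}


source = to_device(next(frames))  # frame observed by the depth sensor
target = to_device(next(frames))  # later RGB-only frame to densify

with torch.no_grad():
    # subsample the source depth to emulate a sparse sensor (500 points assumed)
    source_hints = sparsify_depth(source["depth"], 500).to(device)
    # relative pose between the target and the source views
    rel_pose = inv_pose(target["extrinsics"]) @ source["extrinsics"]
    # reproject the sparse hints onto the target image plane
    hints = project_depth(
        source_hints,
        source["intrinsics"],
        target["intrinsics"],
        rel_pose,
        torch.zeros_like(target["image"][:, :1]),
    )
    # dense depth for the target RGB frame
    depth = model(
        target["image"],
        source["image"],
        hints,
        rel_pose,
        torch.stack([target["intrinsics"], source["intrinsics"]], 1),
    )
```

Keeping the source frame and its sparse hints fixed while only the target RGB frame advances is what lets the visualizer stream dense depth at the RGB frame rate even though the depth sensor is sampled only every few frames.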