├── .gitattributes ├── .gitignore ├── LICENSE ├── README.md ├── assets ├── pipeline_ags_mesh.png └── pipeline_crop.jpg ├── dn_splatter ├── __init__.py ├── data │ ├── coolermap_dataparser.py │ ├── dn_dataset.py │ ├── download_scripts │ │ ├── download_omnidata.py │ │ ├── dtu_download.py │ │ ├── mushroom_download.py │ │ ├── nrgbd_download.py │ │ └── replica_download.py │ ├── g_sdfstudio_dataparser.py │ ├── mushroom_dataparser.py │ ├── mushroom_utils │ │ ├── eval_faro.py │ │ ├── mushroom_batch_run.sh │ │ ├── pointcloud_utils.py │ │ ├── reference_depth_download.py │ │ ├── render_faro_nm.py │ │ └── render_gt_depth.py │ ├── normal_nerfstudio.py │ ├── nrgbd_dataparser.py │ ├── replica_dataparser.py │ ├── replica_utils │ │ └── render_normals.py │ ├── scannetpp_dataparser.py │ └── scannetpp_utils │ │ └── pointcloud_utils.py ├── dn_config.py ├── dn_datamanager.py ├── dn_model.py ├── dn_pipeline.py ├── eval │ ├── __init__.py │ ├── baseline_models │ │ ├── __init__.py │ │ ├── eval_configs.py │ │ ├── g_depthnerfacto.py │ │ ├── g_nerfacto.py │ │ ├── g_neusfacto.py │ │ ├── mushroom_to_sdfstudio.py │ │ └── nerfstudio_to_sdfstudio.py │ ├── batch_run.py │ ├── eval.py │ ├── eval_instructions.md │ ├── eval_mesh_mushroom_vis_cull.py │ ├── eval_mesh_vis_cull.py │ ├── eval_normals.py │ └── eval_pd.py ├── export_mesh.py ├── losses.py ├── metrics.py ├── regularization_strategy.py ├── scripts │ ├── align_depth.py │ ├── compare_normals.py │ ├── comparison_video.sh │ ├── convert_colmap.py │ ├── depth_from_pretrain.py │ ├── depth_normal_consistency.py │ ├── depth_to_normal.py │ ├── dsine │ │ ├── __init__.py │ │ ├── dsine.py │ │ ├── dsine_predictor.py │ │ ├── rotations.py │ │ └── submodules.py │ ├── isooctree_dn.py │ ├── normals_from_pretrain.py │ ├── poses_to_colmap_sfm.py │ ├── process_sai.py │ ├── render_model.py │ └── vis_errors.py └── utils │ ├── camera_utils.py │ ├── knn.py │ ├── normal_utils.py │ └── utils.py ├── pixi.lock ├── pixi.toml └── pyproject.toml /.gitattributes: -------------------------------------------------------------------------------- 1 | # GitHub syntax highlighting 2 | pixi.lock linguist-language=YAML 3 | 4 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | .DS_store 2 | .idea 3 | datasets/ 4 | outputs/ 5 | dn_splatter.egg-info/ 6 | omnidata_ckpt/ 7 | pretrained_models/ 8 | mesh_exports/ 9 | 10 | # Byte-compiled / optimized / DLL files 11 | __pycache__/ 12 | *.py[cod] 13 | *$py.class 14 | 15 | # C extensions 16 | *.so 17 | 18 | # Distribution / packaging 19 | .Python 20 | build/ 21 | develop-eggs/ 22 | dist/ 23 | downloads/ 24 | eggs/ 25 | .eggs/ 26 | lib/ 27 | lib64/ 28 | parts/ 29 | sdist/ 30 | var/ 31 | wheels/ 32 | share/python-wheels/ 33 | *.egg-info/ 34 | .installed.cfg 35 | *.egg 36 | MANIFEST 37 | 38 | # PyInstaller 39 | # Usually these files are written by a python script from a template 40 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 
41 | *.manifest 42 | *.spec 43 | 44 | # Installer logs 45 | pip-log.txt 46 | pip-delete-this-directory.txt 47 | 48 | # Unit test / coverage reports 49 | htmlcov/ 50 | .tox/ 51 | .nox/ 52 | .coverage 53 | .coverage.* 54 | .cache 55 | nosetests.xml 56 | coverage.xml 57 | *.cover 58 | *.py,cover 59 | .hypothesis/ 60 | .pytest_cache/ 61 | cover/ 62 | 63 | # Translations 64 | *.mo 65 | *.pot 66 | 67 | # Django stuff: 68 | *.log 69 | local_settings.py 70 | db.sqlite3 71 | db.sqlite3-journal 72 | 73 | # Flask stuff: 74 | instance/ 75 | .webassets-cache 76 | 77 | # Scrapy stuff: 78 | .scrapy 79 | 80 | # Sphinx documentation 81 | docs/_build/ 82 | 83 | # PyBuilder 84 | .pybuilder/ 85 | target/ 86 | 87 | # Jupyter Notebook 88 | .ipynb_checkpoints 89 | 90 | # IPython 91 | profile_default/ 92 | ipython_config.py 93 | 94 | # pyenv 95 | # For a library or package, you might want to ignore these files since the code is 96 | # intended to run in multiple environments; otherwise, check them in: 97 | # .python-version 98 | 99 | # pipenv 100 | # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 101 | # However, in case of collaboration, if having platform-specific dependencies or dependencies 102 | # having no cross-platform support, pipenv may install dependencies that don't work, or not 103 | # install all needed dependencies. 104 | #Pipfile.lock 105 | 106 | # poetry 107 | # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. 108 | # This is especially recommended for binary packages to ensure reproducibility, and is more 109 | # commonly ignored for libraries. 110 | # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control 111 | #poetry.lock 112 | 113 | # pdm 114 | # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. 115 | #pdm.lock 116 | # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it 117 | # in version control. 118 | # https://pdm.fming.dev/#use-with-ide 119 | .pdm.toml 120 | 121 | # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm 122 | __pypackages__/ 123 | 124 | # Celery stuff 125 | celerybeat-schedule 126 | celerybeat.pid 127 | 128 | # SageMath parsed files 129 | *.sage.py 130 | 131 | # Environments 132 | .env 133 | .venv 134 | env/ 135 | venv/ 136 | ENV/ 137 | env.bak/ 138 | venv.bak/ 139 | 140 | # Spyder project settings 141 | .spyderproject 142 | .spyproject 143 | 144 | # Rope project settings 145 | .ropeproject 146 | 147 | # mkdocs documentation 148 | /site 149 | 150 | # mypy 151 | .mypy_cache/ 152 | .dmypy.json 153 | dmypy.json 154 | 155 | # Pyre type checker 156 | .pyre/ 157 | 158 | # pytype static type analyzer 159 | .pytype/ 160 | 161 | # Cython debug symbols 162 | cython_debug/ 163 | 164 | # PyCharm 165 | # JetBrains specific template is maintained in a separate JetBrains.gitignore that can 166 | # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore 167 | # and can be added to the global gitignore or merged into this file. For a more nuclear 168 | # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
169 | #.idea/ 170 | nerfstudio 171 | .vscode 172 | dn_splatter/__pycache__ 173 | # pixi environments 174 | .pixi -------------------------------------------------------------------------------- /assets/pipeline_ags_mesh.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maturk/dn-splatter/249d52c4bb14b7bf6dd18d7d66099a36eac2ee78/assets/pipeline_ags_mesh.png -------------------------------------------------------------------------------- /assets/pipeline_crop.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maturk/dn-splatter/249d52c4bb14b7bf6dd18d7d66099a36eac2ee78/assets/pipeline_crop.jpg -------------------------------------------------------------------------------- /dn_splatter/__init__.py: -------------------------------------------------------------------------------- 1 | from .data.coolermap_dataparser import CoolerMapDataParserSpecification 2 | from .data.g_sdfstudio_dataparser import GSDFStudioDataParserSpecification 3 | from .data.mushroom_dataparser import MushroomDataParserSpecification 4 | from .data.normal_nerfstudio import NormalNerfstudioSpecification 5 | from .data.nrgbd_dataparser import NRGBDDataParserSpecification 6 | from .data.replica_dataparser import ReplicaDataParserSpecification 7 | from .data.scannetpp_dataparser import ScanNetppDataParserSpecification 8 | 9 | __all__ = [ 10 | "__version__", 11 | MushroomDataParserSpecification, 12 | ReplicaDataParserSpecification, 13 | GSDFStudioDataParserSpecification, 14 | NRGBDDataParserSpecification, 15 | ScanNetppDataParserSpecification, 16 | CoolerMapDataParserSpecification, 17 | NormalNerfstudioSpecification, 18 | ] 19 | -------------------------------------------------------------------------------- /dn_splatter/data/download_scripts/download_omnidata.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from pathlib import Path 4 | 5 | import tyro 6 | 7 | 8 | def download_omnidata( 9 | save_dir: Path = Path(os.getcwd() + "/omnidata_ckpt"), 10 | ): 11 | save_dir.mkdir(parents=True, exist_ok=True) 12 | # Download the pretrained model weights using wget 13 | try: 14 | subprocess.run( 15 | [ 16 | "wget", 17 | "-P", 18 | save_dir, 19 | "https://zenodo.org/records/10447888/files/omnidata_dpt_normal_v2.ckpt", 20 | ] 21 | ) 22 | print("Pretrained model weights downloaded successfully!") 23 | except subprocess.CalledProcessError as e: 24 | print(f"Error downloading file: {e}") 25 | 26 | 27 | if __name__ == "__main__": 28 | tyro.cli(download_omnidata) 29 | -------------------------------------------------------------------------------- /dn_splatter/data/download_scripts/dtu_download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from pathlib import Path 4 | 5 | import tyro 6 | 7 | 8 | def download_dtu( 9 | save_dir: Path = Path(os.getcwd() + "/datasets"), 10 | ): 11 | save_zip_dir = save_dir 12 | save_dir.mkdir(parents=True, exist_ok=True) 13 | 14 | url = "https://s3.eu-central-1.amazonaws.com/avg-projects/monosdf/data/DTU.tar" 15 | wget_command = ["wget", "-P", str(save_zip_dir), url] 16 | file_name = "DTU.tar" 17 | extract_command = ["tar", "-xvf", save_dir / file_name, "-C", save_dir] 18 | 19 | try: 20 | subprocess.run(wget_command, check=True) 21 | print("File file downloaded succesfully.") 22 | except subprocess.CalledProcessError as e: 23 | 
print(f"Error downloading file: {e}") 24 | try: 25 | subprocess.run(extract_command, check=True) 26 | print("Extraction complete.") 27 | except subprocess.CalledProcessError as e: 28 | print(f"Extraction failed: {e}") 29 | 30 | 31 | if __name__ == "__main__": 32 | tyro.cli(download_dtu) 33 | -------------------------------------------------------------------------------- /dn_splatter/data/download_scripts/mushroom_download.py: -------------------------------------------------------------------------------- 1 | """Script to download example mushroom dataset to /datasets folder""" 2 | 3 | import os 4 | import subprocess 5 | from pathlib import Path 6 | from typing import Literal 7 | 8 | import tyro 9 | 10 | room_names = Literal[ 11 | "coffee_room", 12 | "computer", 13 | "classroom", 14 | "honka", 15 | "koivu", 16 | "vr_room", 17 | "kokko", 18 | "sauna", 19 | "activity", 20 | "olohuone", 21 | ] 22 | 23 | 24 | def download_mushroom( 25 | room_name: room_names = "activity", # type: ignore 26 | save_dir: Path = Path(os.getcwd() + "/datasets"), 27 | sequence: Literal["iphone", "kinect", "faro", "all"] = "all", 28 | ): 29 | save_dir.mkdir(parents=True, exist_ok=True) 30 | 31 | iphone_url = ( 32 | "https://zenodo.org/records/10230733/files/" + room_name + "_iphone.tar.gz" 33 | ) 34 | kinect_url = ( 35 | "https://zenodo.org/records/10209072/files/" + room_name + "_kinect.tar.gz" 36 | ) 37 | mesh_pd_url = ( 38 | "https://zenodo.org/records/10222321/files/" + room_name + "_mesh_pd.tar.gz" 39 | ) 40 | 41 | commands = [] 42 | extract_commands = [] 43 | 44 | if sequence in ["iphone", "all"]: 45 | wget_command = ["wget", "-P", str(save_dir), iphone_url] 46 | commands.append(wget_command) 47 | file_name = room_name + "_iphone.tar.gz" 48 | extract_commands.append(["tar", "-xvzf", save_dir / file_name, "-C", save_dir]) 49 | 50 | if sequence in ["kinect", "all"]: 51 | wget_command = ["wget", "-P", str(save_dir), kinect_url] 52 | commands.append(wget_command) 53 | file_name = room_name + "_kinect.tar.gz" 54 | extract_commands.append(["tar", "-xvzf", save_dir / file_name, "-C", save_dir]) 55 | 56 | if sequence in ["faro", "all"]: 57 | wget_command = ["wget", "-P", str(save_dir), mesh_pd_url] 58 | commands.append(wget_command) 59 | file_name = room_name + "_mesh_pd.tar.gz" 60 | extract_commands.append(["tar", "-xvzf", save_dir / file_name, "-C", save_dir]) 61 | 62 | for command in commands: 63 | try: 64 | subprocess.run(command, check=True) 65 | print("File file downloaded succesfully.") 66 | except subprocess.CalledProcessError as e: 67 | print(f"Error downloading file: {e}") 68 | for extract_command in extract_commands: 69 | try: 70 | subprocess.run(extract_command, check=True) 71 | print("Extraction complete.") 72 | except subprocess.CalledProcessError as e: 73 | print(f"Extraction failed: {e}") 74 | 75 | 76 | if __name__ == "__main__": 77 | tyro.cli(download_mushroom) 78 | -------------------------------------------------------------------------------- /dn_splatter/data/download_scripts/nrgbd_download.py: -------------------------------------------------------------------------------- 1 | """Script to download pre-processed Replica dataset. 
Total size of dataset is 12.4 gb.""" 2 | import os 3 | import subprocess 4 | from pathlib import Path 5 | 6 | import tyro 7 | 8 | 9 | def download_nrgbd( 10 | save_dir: Path = Path(os.getcwd() + "/datasets"), 11 | all: bool = True, 12 | test: bool = False, 13 | ): 14 | save_zip_dir = save_dir / "NRGBD" 15 | save_dir.mkdir(parents=True, exist_ok=True) 16 | 17 | url = "http://kaldir.vc.in.tum.de/neural_rgbd/neural_rgbd_data.zip" 18 | mesh_url = "http://kaldir.vc.in.tum.de/neural_rgbd/meshes.zip" 19 | commands = [] 20 | commands.append(["wget", "-P", str(save_zip_dir), url]) 21 | commands.append(["wget", "-P", str(save_zip_dir), mesh_url]) 22 | file_name = "neural_rgbd_data.zip" 23 | mesh_name = "meshes.zip" 24 | extract_commands = [] 25 | extract_commands.append( 26 | ["unzip", save_zip_dir / file_name, "-d", save_dir / "NRGBD"] 27 | ) 28 | extract_commands.append( 29 | ["unzip", save_zip_dir / mesh_name, "-d", save_dir / "NRGBD"] 30 | ) 31 | 32 | try: 33 | for command in commands: 34 | subprocess.run(command, check=True) 35 | print("File file downloaded succesfully.") 36 | except subprocess.CalledProcessError as e: 37 | print(f"Error downloading file: {e}") 38 | try: 39 | for e_command in extract_commands: 40 | subprocess.run(e_command, check=True) 41 | print("Extraction complete.") 42 | except subprocess.CalledProcessError as e: 43 | print(f"Extraction failed: {e}") 44 | 45 | 46 | if __name__ == "__main__": 47 | tyro.cli(download_nrgbd) 48 | -------------------------------------------------------------------------------- /dn_splatter/data/download_scripts/replica_download.py: -------------------------------------------------------------------------------- 1 | """Script to download pre-processed Replica dataset. Total size of dataset is 12.4 gb.""" 2 | import os 3 | import subprocess 4 | from pathlib import Path 5 | 6 | import tyro 7 | 8 | 9 | def download_replica( 10 | save_dir: Path = Path(os.getcwd() + "/datasets"), 11 | ): 12 | save_zip_dir = save_dir / "Replica" 13 | save_dir.mkdir(parents=True, exist_ok=True) 14 | 15 | url = "https://cvg-data.inf.ethz.ch/nice-slam/data/Replica.zip" 16 | wget_command = ["wget", "-P", str(save_zip_dir), url] 17 | file_name = "Replica.zip" 18 | extract_command = ["unzip", save_zip_dir / file_name, "-d", save_dir] 19 | 20 | try: 21 | subprocess.run(wget_command, check=True) 22 | print("File file downloaded succesfully.") 23 | except subprocess.CalledProcessError as e: 24 | print(f"Error downloading file: {e}") 25 | try: 26 | subprocess.run(extract_command, check=True) 27 | print("Extraction complete.") 28 | except subprocess.CalledProcessError as e: 29 | print(f"Extraction failed: {e}") 30 | 31 | 32 | if __name__ == "__main__": 33 | tyro.cli(download_replica) 34 | -------------------------------------------------------------------------------- /dn_splatter/data/mushroom_utils/eval_faro.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """eval.py 3 | 4 | run with: python python dn_splatter/eval.py --data [PATH_TO_DATA] 5 | 6 | options : 7 | --eval-faro / --no-eval-faro 8 | 9 | eval-faro option is used for reference faro scanner projected depth maps 10 | """ 11 | import json 12 | import os 13 | from pathlib import Path 14 | from typing import Optional 15 | 16 | import torch 17 | import torchvision.transforms.functional as F 18 | import tyro 19 | from dn_splatter.metrics import DepthMetrics 20 | from dn_splatter.utils.utils import depth_path_to_tensor 21 | from rich.console import Console 22 | 
from rich.progress import track 23 | from torchmetrics.functional import mean_squared_error 24 | 25 | CONSOLE = Console(width=120) 26 | BATCH_SIZE = 40 27 | 28 | 29 | def depth_eval_faro(data: Path, path_to_faro: Path): 30 | transform_meta = data / "dataparser_transforms.json" 31 | meta = json.load(open(transform_meta, "r")) 32 | scale = meta["scale"] 33 | depth_metrics = DepthMetrics() 34 | 35 | render_path = data / Path("final_renders/pred/depth/raw/") 36 | gt_path = path_to_faro 37 | 38 | long_depth_list = [ 39 | f for f in os.listdir(render_path) if f.endswith(".npy") and "long_capture" in f 40 | ] 41 | short_depth_list = [ 42 | f 43 | for f in os.listdir(render_path) 44 | if f.endswith(".npy") and "short_capture" in f 45 | ] 46 | long_depth_list = sorted( 47 | long_depth_list, key=lambda x: int(x.split(".")[0].split("_")[-1]) 48 | ) 49 | short_depth_list = sorted( 50 | short_depth_list, key=lambda x: int(x.split(".")[0].split("_")[-1]) 51 | ) 52 | 53 | test_id_within = path_to_faro / Path("long_capture/test.txt") 54 | with open(test_id_within) as f: 55 | lines = f.readlines() 56 | i_eval_within = [num.split("\n")[0] for num in lines] 57 | 58 | mse = mean_squared_error 59 | 60 | long_num_frames = len(long_depth_list) 61 | short_num_frames = len(short_depth_list) 62 | 63 | mse_score_batch = [] 64 | abs_rel_score_batch = [] 65 | sq_rel_score_batch = [] 66 | rmse_score_batch = [] 67 | rmse_log_score_batch = [] 68 | a1_score_batch = [] 69 | a2_score_batch = [] 70 | a3_score_batch = [] 71 | 72 | CONSOLE.print( 73 | f"[bold green]Batchifying and evaluating a total of {long_num_frames + short_num_frames} depth frames" 74 | ) 75 | 76 | def calculate_metrics(num_frames, depth_list, capture_mode): 77 | for batch_index in track(range(0, num_frames, BATCH_SIZE)): 78 | CONSOLE.print( 79 | f"[bold yellow]Evaluating batch {batch_index // BATCH_SIZE} / {num_frames//BATCH_SIZE}" 80 | ) 81 | batch_frames = depth_list[batch_index : batch_index + BATCH_SIZE] 82 | predicted_depth = [] 83 | gt_depth = [] 84 | 85 | for i in batch_frames: 86 | render_img = depth_path_to_tensor( 87 | Path(os.path.join(render_path, i)) 88 | ).permute(2, 0, 1) 89 | 90 | gt_folder = gt_path / Path(capture_mode) / Path("reference_depth") 91 | 92 | if "iphone" in str(gt_folder.parent): 93 | name_list = i.split(".")[0].split("_") 94 | depth_name = name_list[-2] + "_" + name_list[-1] 95 | gt_name = gt_folder / Path(depth_name + ".png") 96 | elif "kinect" in str(gt_folder.parent): 97 | depth_name = i.split(".")[0].split("_")[-1] 98 | gt_name = gt_folder / Path(depth_name + ".png") 99 | if not gt_name.exists(): 100 | print("could not find frame ", gt_name, " skipping it...") 101 | continue 102 | if capture_mode == "long_capture": 103 | if depth_name not in i_eval_within: 104 | continue 105 | 106 | origin_img = depth_path_to_tensor(gt_name).permute(2, 0, 1) 107 | 108 | if origin_img.shape[-2:] != render_img.shape[-2:]: 109 | render_img = F.resize( 110 | render_img, size=origin_img.shape[-2:], antialias=None 111 | ) 112 | 113 | render_img = render_img / scale 114 | predicted_depth.append(render_img) 115 | gt_depth.append(origin_img) 116 | 117 | predicted_depth = torch.stack(predicted_depth, 0) 118 | gt_depth = torch.stack(gt_depth, 0) 119 | 120 | mse_score = mse(predicted_depth, gt_depth) 121 | (abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3) = depth_metrics( 122 | predicted_depth, gt_depth 123 | ) 124 | 125 | mse_score_batch.append(mse_score) 126 | abs_rel_score_batch.append(abs_rel) 127 | sq_rel_score_batch.append(sq_rel) 128 | 
rmse_score_batch.append(rmse) 129 | rmse_log_score_batch.append(rmse_log) 130 | a1_score_batch.append(a1) 131 | a2_score_batch.append(a2) 132 | a3_score_batch.append(a3) 133 | 134 | mean_scores = { 135 | "mse": float(torch.stack(mse_score_batch).mean().item()), 136 | "abs_rel": float(torch.stack(abs_rel_score_batch).mean().item()), 137 | "sq_rel": float(torch.stack(sq_rel_score_batch).mean().item()), 138 | "rmse": float(torch.stack(rmse_score_batch).mean().item()), 139 | "rmse_log": float(torch.stack(rmse_log_score_batch).mean().item()), 140 | "a1": float(torch.stack(a1_score_batch).mean().item()), 141 | "a2": float(torch.stack(a2_score_batch).mean().item()), 142 | "a3": float(torch.stack(a3_score_batch).mean().item()), 143 | } 144 | return mean_scores 145 | 146 | long_means_scores = calculate_metrics( 147 | long_num_frames, long_depth_list, "long_capture" 148 | ) 149 | short_means_scores = calculate_metrics( 150 | short_num_frames, short_depth_list, "short_capture" 151 | ) 152 | 153 | metrics_dict = {} 154 | for key in long_means_scores.keys(): 155 | metrics_dict["within_faro_" + key] = long_means_scores[key] 156 | 157 | for key in short_means_scores.keys(): 158 | metrics_dict["with_faro_" + key] = short_means_scores[key] 159 | 160 | return metrics_dict 161 | 162 | 163 | def main(data: Path, eval_faro: bool = False, path_to_faro: Optional[Path] = None): 164 | if eval_faro: 165 | assert path_to_faro is not None, "need to specify faro path" 166 | depth_eval_faro(data, path_to_faro=path_to_faro) 167 | 168 | 169 | if __name__ == "__main__": 170 | tyro.cli(main) 171 | -------------------------------------------------------------------------------- /dn_splatter/data/mushroom_utils/mushroom_batch_run.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Runs kinect and iphone datasets for all loss types back-to-back 4 | # NOTE: make sure you 'chmod +x' to be able to execute it 5 | 6 | # Configs: Mushroom dataset 7 | dataset_name="activity" 8 | root_path_to_mushroom="" # dont add a trailing '/' here 9 | depth_lambda=0.1 10 | iters_to_save=25000 11 | 12 | # kinect commands 13 | python_command_no_depth="ns-train dn_splatter --viewer.quit-on-train-completion True --pipeline.model.eval_all_images_at_train_iter True --pipeline.model.iters_to_save ${iters_to_save} --pipeline.model.save_eval_name ${dataset_name}-kinect-no-depth --pipeline.model.use-depth-loss False --data ${root_path_to_mushroom}/${dataset_name}/kinect/long_capture mushroom --mode kinect" 14 | python_command_mse="ns-train dn_splatter --viewer.quit-on-train-completion True --pipeline.model.eval_all_images_at_train_iter True --pipeline.model.iters_to_save ${iters_to_save} --pipeline.model.save_eval_name ${dataset_name}-kinect-depth-mse --pipeline.model.use-depth-loss True --pipeline.model.depth-loss-type MSE --pipeline.model.depth-lambda ${depth_lambda} --data ${root_path_to_mushroom}/${dataset_name}/kinect/long_capture mushroom --mode kinect" 15 | python_command_logl1="ns-train dn_splatter --viewer.quit-on-train-completion True --pipeline.model.eval_all_images_at_train_iter True --pipeline.model.iters_to_save ${iters_to_save} --pipeline.model.save_eval_name ${dataset_name}-kinect-depth-logl1 --pipeline.model.use-depth-loss True --pipeline.model.depth-loss-type LogL1 --pipeline.model.depth-lambda ${depth_lambda} --data ${root_path_to_mushroom}/${dataset_name}/kinect/long_capture mushroom --mode kinect" 16 | python_command_huberl1="ns-train dn_splatter 
--viewer.quit-on-train-completion True --pipeline.model.eval_all_images_at_train_iter True --pipeline.model.iters_to_save ${iters_to_save} --pipeline.model.save_eval_name ${dataset_name}-kinect-depth-huberl1 --pipeline.model.use-depth-loss True --pipeline.model.depth-loss-type HuberL1 --pipeline.model.depth-lambda ${depth_lambda} --data ${root_path_to_mushroom}/${dataset_name}/kinect/long_capture mushroom --mode kinect" 17 | python_command_dssiml1="ns-train dn_splatter --viewer.quit-on-train-completion True --pipeline.model.eval_all_images_at_train_iter True --pipeline.model.iters_to_save ${iters_to_save} --pipeline.model.save_eval_name ${dataset_name}-kinect-depth-dssiml1 --pipeline.model.use-depth-loss True --pipeline.model.depth-loss-type DSSIML1 --pipeline.model.depth-lambda ${depth_lambda} --data ${root_path_to_mushroom}/${dataset_name}/kinect/long_capture mushroom --mode kinect" 18 | python_command_l1="ns-train dn_splatter --viewer.quit-on-train-completion True --pipeline.model.eval_all_images_at_train_iter True --pipeline.model.iters_to_save ${iters_to_save} --pipeline.model.save_eval_name ${dataset_name}-kinect-depth-l1 --pipeline.model.use-depth-loss True --pipeline.model.depth-loss-type L1 --pipeline.model.depth-lambda ${depth_lambda} --data ${root_path_to_mushroom}/${dataset_name}/kinect/long_capture mushroom --mode kinect" 19 | 20 | 21 | # execute kinect commands 22 | echo "Evaluating on the kinect no-depth sequence..." 23 | $python_command_no_depth 24 | echo "Evaluating on the kinect depth-mse sequence..." 25 | $python_command_mse 26 | echo "Evaluating on the kinect depth-logl1 sequence..." 27 | $python_command_logl1 28 | echo "Evaluating on the kinect depth-huberl1 sequence..." 29 | $python_command_huberl1 30 | echo "Evaluating on the kinect depth-dssiml1 sequence..." 31 | $python_command_dssiml1 32 | echo "Evaluating on the kinect depth-l1 sequence..." 
33 | $python_command_l1 34 | 35 | # iphone commands 36 | python_command_no_depth="ns-train dn_splatter --viewer.quit-on-train-completion True --pipeline.model.eval_all_images_at_train_iter True --pipeline.model.iters_to_save ${iters_to_save} --pipeline.model.save_eval_name ${dataset_name}-iphone-no-depth --pipeline.model.use-depth-loss False --data ${root_path_to_mushroom}/${dataset_name}/iphone/long_capture mushroom --mode iphone" 37 | python_command_mse="ns-train dn_splatter --viewer.quit-on-train-completion True --pipeline.model.eval_all_images_at_train_iter True --pipeline.model.iters_to_save ${iters_to_save} --pipeline.model.save_eval_name ${dataset_name}-iphone-depth-mse --pipeline.model.use-depth-loss True --pipeline.model.depth-loss-type MSE --pipeline.model.depth-lambda ${depth_lambda} --data ${root_path_to_mushroom}/${dataset_name}/iphone/long_capture mushroom --mode iphone" 38 | python_command_logl1="ns-train dn_splatter --viewer.quit-on-train-completion True --pipeline.model.eval_all_images_at_train_iter True --pipeline.model.iters_to_save ${iters_to_save} --pipeline.model.save_eval_name ${dataset_name}-iphone-depth-logl1 --pipeline.model.use-depth-loss True --pipeline.model.depth-loss-type LogL1 --pipeline.model.depth-lambda ${depth_lambda} --data ${root_path_to_mushroom}/${dataset_name}/iphone/long_capture mushroom --mode iphone" 39 | python_command_huberl1="ns-train dn_splatter --viewer.quit-on-train-completion True --pipeline.model.eval_all_images_at_train_iter True --pipeline.model.iters_to_save ${iters_to_save} --pipeline.model.save_eval_name ${dataset_name}-iphone-depth-huberl1 --pipeline.model.use-depth-loss True --pipeline.model.depth-loss-type HuberL1 --pipeline.model.depth-lambda ${depth_lambda} --data ${root_path_to_mushroom}/${dataset_name}/iphone/long_capture mushroom --mode iphone" 40 | python_command_dssiml1="ns-train dn_splatter --viewer.quit-on-train-completion True --pipeline.model.eval_all_images_at_train_iter True --pipeline.model.iters_to_save ${iters_to_save} --pipeline.model.save_eval_name ${dataset_name}-iphone-depth-dssiml1 --pipeline.model.use-depth-loss True --pipeline.model.depth-loss-type DSSIML1 --pipeline.model.depth-lambda ${depth_lambda} --data ${root_path_to_mushroom}/${dataset_name}/iphone/long_capture mushroom --mode iphone" 41 | python_command_l1="ns-train dn_splatter --viewer.quit-on-train-completion True --pipeline.model.eval_all_images_at_train_iter True --pipeline.model.iters_to_save ${iters_to_save} --pipeline.model.save_eval_name ${dataset_name}-iphone-depth-l1 --pipeline.model.use-depth-loss True --pipeline.model.depth-loss-type L1 --pipeline.model.depth-lambda ${depth_lambda} --data ${root_path_to_mushroom}/${dataset_name}/iphone/long_capture mushroom --mode iphone" 42 | 43 | # execute iphone commands 44 | echo "Evaluating on the iphone no-depth sequence..." 45 | $python_command_no_depth 46 | echo "Evaluating on the depth-mse iphone..." 47 | $python_command_mse 48 | echo "Evaluating on the depth-logl1 iphone..." 49 | $python_command_logl1 50 | echo "Evaluating on the depth-huberl1 iphone..." 51 | $python_command_huberl1 52 | echo "Evaluating on the depth-dssiml1 iphone..." 53 | $python_command_dssiml1 54 | echo "Evaluating on the depth-l1 iphone..." 
55 | $python_command_l1 56 | -------------------------------------------------------------------------------- /dn_splatter/data/mushroom_utils/pointcloud_utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import random 4 | from pathlib import Path 5 | 6 | import cv2 7 | import numpy as np 8 | import open3d as o3d 9 | import tyro 10 | from dn_splatter.utils.camera_utils import OPENGL_TO_OPENCV 11 | from rich.progress import track 12 | 13 | 14 | def generate_kinect_pointcloud_within_sequence( 15 | data_path: Path, num_points: int = 1_000_000 16 | ): 17 | print("Generating pointclouds from kinect data...") 18 | 19 | info_file = json.load(open(os.path.join(data_path, "transformations_colmap.json"))) 20 | 21 | frames = info_file["frames"] 22 | frame_names = [frame["file_path"].split("/")[-1].split(".")[0] for frame in frames] 23 | 24 | num_images = len(frames) 25 | i_all = np.arange(num_images) 26 | 27 | with open(os.path.join(data_path, "test.txt")) as f: 28 | lines = f.readlines() 29 | i_eval_name = [num.split("\n")[0] for num in lines] 30 | 31 | # only select images that exist in frame_names 32 | i_eval_name = [name for name in i_eval_name if name in frame_names] 33 | i_eval = [frame_names.index(name) for name in i_eval_name] 34 | i_train = np.setdiff1d(i_all, i_eval) # type: ignore 35 | 36 | index = i_train 37 | 38 | points_list = [] 39 | colors_list = [] 40 | normals_list = [] 41 | 42 | samples_per_frame = (num_points + len(index)) // (len(index)) 43 | 44 | for item in track(index, description="processing ... "): 45 | frame = frames[item] 46 | name = frame["file_path"].split("/")[-1].split(".")[0] 47 | pcd = o3d.io.read_point_cloud( 48 | os.path.join(data_path, "PointCloud", name + ".ply") 49 | ) 50 | 51 | # change the pd from spectacularAI pose world coordination to colmap pose world coordination 52 | original_pose = np.loadtxt( 53 | os.path.join(data_path, "pose", name + ".txt") 54 | ).reshape(4, 4) 55 | original_pose = np.matmul(original_pose, OPENGL_TO_OPENCV) 56 | pcd = pcd.transform(np.linalg.inv(original_pose)) 57 | 58 | colmap_pose = frame["transform_matrix"] 59 | pcd = pcd.transform(colmap_pose) 60 | 61 | samples_per_frame = min(samples_per_frame, len(pcd.points)) 62 | 63 | mask = random.sample(range(len(pcd.points)), samples_per_frame) 64 | mask = np.asarray(mask) 65 | color = np.asarray(pcd.colors)[mask] 66 | point = np.asarray(pcd.points)[mask] 67 | normal = np.asarray(pcd.normals)[mask] 68 | 69 | points_list.append(np.asarray(point)) 70 | colors_list.append(np.asarray(color)) 71 | normals_list.append(np.asarray(normal)) 72 | 73 | cloud = o3d.geometry.PointCloud() 74 | points = o3d.utility.Vector3dVector(np.vstack(points_list)) 75 | colors = o3d.utility.Vector3dVector(np.vstack(colors_list)) 76 | normals = o3d.utility.Vector3dVector(np.vstack(normals_list)) 77 | 78 | cloud.points = points 79 | cloud.colors = colors 80 | cloud.normals = normals 81 | if os.path.exists(os.path.join(data_path, "kinect_pointcloud.ply")): 82 | os.remove(os.path.join(data_path, "kinect_pointcloud.ply")) 83 | o3d.io.write_point_cloud(os.path.join(data_path, "kinect_pointcloud.ply"), cloud) 84 | 85 | 86 | def generate_iPhone_pointcloud_within_sequence( 87 | data_path: Path, num_points: int = 1_000_000 88 | ): 89 | print("Generating pointcloud from iPhone data...") 90 | info_file = json.load(open(os.path.join(data_path, "transformations_colmap.json"))) 91 | 92 | frames = info_file["frames"] 93 | frame_names = 
[frame["file_path"].split("/")[-1].split(".")[0] for frame in frames] 94 | 95 | volume = o3d.pipelines.integration.ScalableTSDFVolume( 96 | voxel_length=0.04, 97 | sdf_trunc=0.2, 98 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8, 99 | ) 100 | 101 | num_images = len(frames) 102 | i_all = np.arange(num_images) 103 | 104 | with open(os.path.join(data_path, "test.txt")) as f: 105 | lines = f.readlines() 106 | i_eval_name = [num.split("\n")[0] for num in lines] 107 | i_eval_name = [name for name in i_eval_name if name in frame_names] 108 | i_eval = [frame_names.index(name) for name in i_eval_name] 109 | i_train = np.setdiff1d(i_all, i_eval) 110 | 111 | index = i_train 112 | 113 | points_list = [] 114 | colors_list = [] 115 | 116 | if "fl_x" in info_file: 117 | fx, fy, cx, cy = ( 118 | float(info_file["fl_x"]), 119 | float(info_file["fl_y"]), 120 | float(info_file["cx"]), 121 | float(info_file["cy"]), 122 | ) 123 | H = int(info_file["h"]) 124 | W = int(info_file["w"]) 125 | camera_intrinsics = o3d.camera.PinholeCameraIntrinsic(W, H, fx, fy, cx, cy) 126 | 127 | samples_per_frame = (num_points + len(index)) // (len(index)) 128 | 129 | for item in track(index, description="processing ... "): 130 | frame = frames[item] 131 | if "fl_x" in frame: 132 | fx, fy, cx, cy = ( 133 | float(frame["fl_x"]), 134 | float(frame["fl_y"]), 135 | float(frame["cx"]), 136 | float(frame["cy"]), 137 | ) 138 | H = int(frame["h"]) 139 | W = int(frame["w"]) 140 | camera_intrinsics = o3d.camera.PinholeCameraIntrinsic(W, H, fx, fy, cx, cy) 141 | 142 | color = cv2.imread(os.path.join(data_path, frame["file_path"])) 143 | color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB) 144 | color = o3d.geometry.Image(color) 145 | 146 | # pose 147 | pose = frame["transform_matrix"] 148 | pose = np.matmul(np.array(pose), OPENGL_TO_OPENCV) 149 | 150 | depth = cv2.imread( 151 | os.path.join(data_path, frame["depth_file_path"]), cv2.IMREAD_ANYDEPTH 152 | ) 153 | depth = cv2.resize(depth, (W, H)) # type: ignore 154 | depth = o3d.geometry.Image(depth) 155 | 156 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 157 | color, depth, depth_trunc=4.0, convert_rgb_to_intensity=False 158 | ) 159 | 160 | volume.integrate( 161 | rgbd, 162 | camera_intrinsics, # type: ignore 163 | np.linalg.inv(pose), 164 | ) 165 | 166 | pcd = volume.extract_point_cloud() 167 | 168 | # randomly select samples_per_frame points from points 169 | samples_per_frame = min(samples_per_frame, len(pcd.points)) 170 | mask = random.sample(range(len(pcd.points)), samples_per_frame) 171 | mask = np.asarray(mask) 172 | color = np.asarray(pcd.colors)[mask] 173 | point = np.asarray(pcd.points)[mask] 174 | 175 | points_list.append(np.asarray(point)) 176 | colors_list.append(np.asarray(color)) 177 | 178 | points = np.vstack(points_list) 179 | colors = np.vstack(colors_list) 180 | 181 | if points.shape[0] > num_points: 182 | # ensure final num points is exact 183 | indices = np.random.choice(points.shape[0], size=num_points, replace=False) 184 | points = points[indices] 185 | colors = colors[indices] 186 | 187 | pcd = o3d.geometry.PointCloud() 188 | pcd.points = o3d.utility.Vector3dVector(points) 189 | pcd.colors = o3d.utility.Vector3dVector(colors) 190 | 191 | if os.path.exists(os.path.join(data_path, "iphone_pointcloud.ply")): 192 | os.remove(os.path.join(data_path, "iphone_pointcloud.ply")) 193 | o3d.io.write_point_cloud(os.path.join(data_path, "iphone_pointcloud.ply"), pcd) 194 | 195 | 196 | def generate_polycam_pointcloud(data_path: Path, num_points: int = 
1_000_000): 197 | print("Generating pointcloud from iPhone data...") 198 | info_file = json.load(open(os.path.join(data_path, "transforms.json"))) 199 | 200 | frames = info_file["frames"] 201 | 202 | volume = o3d.pipelines.integration.ScalableTSDFVolume( 203 | voxel_length=0.04, 204 | sdf_trunc=0.2, 205 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8, 206 | ) 207 | 208 | num_images = len(frames) 209 | i_all = np.arange(num_images) 210 | index = i_all 211 | 212 | points_list = [] 213 | colors_list = [] 214 | 215 | if "fl_x" in info_file: 216 | fx, fy, cx, cy = ( 217 | float(info_file["fl_x"]), 218 | float(info_file["fl_y"]), 219 | float(info_file["cx"]), 220 | float(info_file["cy"]), 221 | ) 222 | H = int(info_file["h"]) 223 | W = int(info_file["w"]) 224 | camera_intrinsics = o3d.camera.PinholeCameraIntrinsic(W, H, fx, fy, cx, cy) 225 | 226 | samples_per_frame = (num_points + len(index)) // (len(index)) 227 | 228 | for item in track(index, description="processing ... "): 229 | frame = frames[item] 230 | if "fl_x" in frame: 231 | fx, fy, cx, cy = ( 232 | float(frame["fl_x"]), 233 | float(frame["fl_y"]), 234 | float(frame["cx"]), 235 | float(frame["cy"]), 236 | ) 237 | H = int(frame["h"]) 238 | W = int(frame["w"]) 239 | camera_intrinsics = o3d.camera.PinholeCameraIntrinsic(W, H, fx, fy, cx, cy) 240 | 241 | color = cv2.imread(os.path.join(data_path, frame["file_path"])) 242 | color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB) 243 | color = o3d.geometry.Image(color) 244 | 245 | # pose 246 | pose = frame["transform_matrix"] 247 | pose = np.matmul(np.array(pose), OPENGL_TO_OPENCV) 248 | 249 | depth = cv2.imread( 250 | os.path.join(data_path, frame["depth_file_path"]), cv2.IMREAD_ANYDEPTH 251 | ) 252 | depth = cv2.resize(depth, (W, H)) # type: ignore 253 | depth = o3d.geometry.Image(depth) 254 | 255 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 256 | color, depth, depth_trunc=4.0, convert_rgb_to_intensity=False 257 | ) 258 | 259 | volume.integrate( 260 | rgbd, 261 | camera_intrinsics, # type: ignore 262 | np.linalg.inv(pose), 263 | ) 264 | 265 | pcd = volume.extract_point_cloud() 266 | 267 | # randomly select samples_per_frame points from points 268 | samples_per_frame = min(samples_per_frame, len(pcd.points)) 269 | mask = random.sample(range(len(pcd.points)), samples_per_frame) 270 | mask = np.asarray(mask) 271 | color = np.asarray(pcd.colors)[mask] 272 | point = np.asarray(pcd.points)[mask] 273 | 274 | points_list.append(np.asarray(point)) 275 | colors_list.append(np.asarray(color)) 276 | 277 | points = np.vstack(points_list) 278 | colors = np.vstack(colors_list) 279 | 280 | if points.shape[0] > num_points: 281 | # ensure final num points is exact 282 | indices = np.random.choice(points.shape[0], size=num_points, replace=False) 283 | points = points[indices] 284 | colors = colors[indices] 285 | 286 | pcd = o3d.geometry.PointCloud() 287 | pcd.points = o3d.utility.Vector3dVector(points) 288 | pcd.colors = o3d.utility.Vector3dVector(colors) 289 | 290 | o3d.io.write_point_cloud(os.path.join(data_path, "iphone_pointcloud.ply"), pcd) 291 | # write ply_file_path to transforms.json file 292 | info_file["ply_file_path"] = "./iphone_pointcloud.ply" 293 | with open(os.path.join(data_path, "transforms.json"), "w") as f: 294 | json.dump(info_file, f) 295 | 296 | 297 | if __name__ == "__main__": 298 | # tyro.cli(generate_kinect_pointcloud_within_sequence) 299 | tyro.cli(generate_polycam_pointcloud) 300 | 
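The TSDF-based generators above are wrapped with tyro, so they can be run from the command line or imported directly. A minimal, hypothetical usage sketch for generate_polycam_pointcloud follows; the capture directory path is an assumption and must contain the transforms.json, RGB, and depth layout the function expects.

    # Hypothetical usage sketch; the capture directory below is an assumed placeholder.
    from pathlib import Path

    from dn_splatter.data.mushroom_utils.pointcloud_utils import generate_polycam_pointcloud

    # Integrates the posed RGB-D frames into a ScalableTSDFVolume, samples roughly
    # num_points points, writes iphone_pointcloud.ply, and records it under the
    # "ply_file_path" key of transforms.json for downstream dataparsers.
    generate_polycam_pointcloud(data_path=Path("datasets/my_polycam_capture"), num_points=1_000_000)

    # Equivalent CLI form (via the tyro entry point in __main__), with the same assumed path:
    #   python dn_splatter/data/mushroom_utils/pointcloud_utils.py --data-path datasets/my_polycam_capture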
-------------------------------------------------------------------------------- /dn_splatter/data/mushroom_utils/reference_depth_download.py: -------------------------------------------------------------------------------- 1 | import os 2 | import subprocess 3 | from pathlib import Path 4 | 5 | import tyro 6 | 7 | 8 | def download_reference_depth( 9 | save_dir: Path = Path(os.getcwd() + "/datasets"), 10 | ): 11 | save_dir.mkdir(parents=True, exist_ok=True) 12 | 13 | depth_url = "https://zenodo.org/records/10438963/files/reference_depth.tar.gz" 14 | 15 | wget_command = ["wget", "-P", str(save_dir), depth_url] 16 | file_name = "reference_depth.tar.gz" 17 | extract_command = ["tar", "-xvzf", save_dir / file_name, "-C", save_dir] 18 | 19 | try: 20 | subprocess.run(wget_command, check=True) 21 | print("File file downloaded succesfully.") 22 | except subprocess.CalledProcessError as e: 23 | print(f"Error downloading file: {e}") 24 | try: 25 | subprocess.run(extract_command, check=True) 26 | print("Extraction complete.") 27 | except subprocess.CalledProcessError as e: 28 | print(f"Extraction failed: {e}") 29 | 30 | 31 | if __name__ == "__main__": 32 | tyro.cli(download_reference_depth) 33 | -------------------------------------------------------------------------------- /dn_splatter/data/mushroom_utils/render_faro_nm.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | from pathlib import Path 5 | 6 | import numpy as np 7 | import torch 8 | import torch.nn as nn 9 | import trimesh 10 | import tyro 11 | from PIL import Image 12 | from pytorch3d import transforms as py3d_transform 13 | from pytorch3d.ops.interp_face_attrs import interpolate_face_attributes 14 | from pytorch3d.renderer import ( 15 | MeshRasterizer, 16 | MeshRenderer, 17 | PerspectiveCameras, 18 | RasterizationSettings, 19 | ) 20 | from pytorch3d.renderer.blending import BlendParams, softmax_rgb_blend 21 | from pytorch3d.structures import Meshes 22 | 23 | # from dn_splatter.utils.camera_utils import OPENGL_TO_OPENCV 24 | 25 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 26 | 27 | 28 | def interpolate_vertex_normals( 29 | fragments, vertex_textures, faces_packed 30 | ) -> torch.Tensor: 31 | """ 32 | Detemine the normal color for each rasterized face. Interpolate the normal colors for 33 | vertices which form the face using the barycentric coordinates. 34 | Args: 35 | meshes: A Meshes class representing a batch of meshes. 36 | fragments: 37 | The outputs of rasterization. From this we use 38 | 39 | - pix_to_face: LongTensor of shape (N, H, W, K) specifying the indices 40 | of the faces (in the packed representation) which 41 | overlap each pixel in the image. 42 | - barycentric_coords: FloatTensor of shape (N, H, W, K, 3) specifying 43 | the barycentric coordianates of each pixel 44 | relative to the faces (in the packed 45 | representation) which overlap the pixel. 46 | 47 | Returns: 48 | texels: An normal color per pixel of shape (N, H, W, K, C). 49 | There will be one C dimensional value for each element in 50 | fragments.pix_to_face. 
51 | """ 52 | # vertex_textures = meshes.textures.verts_rgb_padded().reshape(-1, 3) # (V, C) 53 | # vertex_textures = vertex_textures[meshes.verts_padded_to_packed_idx(), :] 54 | 55 | # X: -1 to + 1: Red: 0 to 255 56 | # Y: -1 to + 1: Green: 0 to 255 57 | # Z: 0 to -1: Blue: 128 to 255 58 | 59 | vertex_textures[:, :2] += 1 60 | vertex_textures[:, :2] /= 2 61 | vertex_textures[:, 2] += 3 62 | vertex_textures[:, 2] /= 4 63 | vertex_textures /= torch.norm(vertex_textures, p=2, dim=-1).view( 64 | vertex_textures.shape[0], 1 65 | ) 66 | 67 | faces_textures = vertex_textures[faces_packed] # (F, 3, C) 68 | texels = interpolate_face_attributes( 69 | fragments.pix_to_face, fragments.bary_coords, faces_textures 70 | ) 71 | return texels 72 | 73 | 74 | class NormalShader(nn.Module): 75 | """ 76 | Per pixel lighting - the lighting model is applied using the interpolated 77 | coordinates and normals for each pixel. The blending function returns the 78 | soft aggregated color using all the faces per pixel. 79 | To use the default values, simply initialize the shader with the desired 80 | device e.g. 81 | .. code-block:: 82 | shader = SoftPhongShader(device=torch.device("cuda:0")) 83 | """ 84 | 85 | def __init__( 86 | self, 87 | device="cpu", 88 | cameras=None, 89 | blend_params=None, 90 | vertex_textures=None, 91 | faces_packed=None, 92 | ): 93 | super().__init__() 94 | 95 | self.cameras = cameras 96 | self.blend_params = blend_params if blend_params is not None else BlendParams() 97 | self.vertex_textures = vertex_textures 98 | self.faces_packed = faces_packed 99 | 100 | def forward(self, fragments, mesh, **kwargs) -> torch.Tensor: 101 | cameras = kwargs.get("cameras", self.cameras) 102 | if cameras is None: 103 | msg = "Cameras must be specified either at initialization \ 104 | or in the forward pass of SoftPhongShader" 105 | raise ValueError(msg) 106 | 107 | texels = interpolate_vertex_normals( 108 | fragments, self.vertex_textures, self.faces_packed 109 | ) 110 | images = softmax_rgb_blend(texels, fragments, self.blend_params) 111 | return images 112 | 113 | 114 | def main( 115 | input_dir: Path, 116 | gt_mesh_path: Path, 117 | transformation_path: Path, 118 | ): 119 | mesh = trimesh.load(os.path.join(gt_mesh_path), force="mesh", process=False) 120 | 121 | initial_transformation = np.array( 122 | json.load(open(os.path.join(transformation_path)))["gt_transformation"] 123 | ).reshape(4, 4) 124 | initial_transformation = np.linalg.inv(initial_transformation) 125 | 126 | mesh = mesh.apply_transform(initial_transformation) 127 | 128 | vertices = torch.tensor(mesh.vertices, dtype=torch.float32) 129 | faces = torch.tensor(mesh.faces, dtype=torch.int64) 130 | mesh = Meshes(verts=[vertices], faces=[faces]).to(device) 131 | 132 | Rz_rot = py3d_transform.euler_angles_to_matrix( 133 | torch.tensor([0.0, 0.0, math.pi]), convention="XYZ" 134 | ).cuda() 135 | output_path = os.path.join(input_dir, "reference_normal") 136 | 137 | os.makedirs(output_path, exist_ok=True) 138 | 139 | transformation_info = json.load( 140 | open(os.path.join(input_dir, "transformations_colmap.json")) 141 | ) 142 | frames = transformation_info["frames"] 143 | 144 | if "fl_x" in transformation_info: 145 | intrinsic_matrix_base = ( 146 | np.array( 147 | [ 148 | transformation_info["fl_x"], 149 | 0, 150 | transformation_info["cx"], 151 | 0, 152 | 0, 153 | transformation_info["fl_y"], 154 | transformation_info["cy"], 155 | 0, 156 | 0, 157 | 0, 158 | 1, 159 | 0, 160 | 0, 161 | 0, 162 | 0, 163 | 1, 164 | ] 165 | ) 166 | .reshape(4, 4) 167 | 
.astype(np.float32) 168 | ) 169 | H = transformation_info["h"] 170 | W = transformation_info["w"] 171 | 172 | mesh = mesh.to(device) 173 | vertex_textures = mesh.verts_normals_packed() # .to(device) 174 | faces_packed = mesh.faces_packed() 175 | OPENGL_TO_OPENCV = np.array( 176 | [[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]] 177 | ) 178 | 179 | for frame in frames: 180 | image_path = frame["file_path"] 181 | 182 | image_name = image_path.split("/")[-1] 183 | 184 | if "fl_x" in frame: 185 | intrinsic_matrix_base = ( 186 | np.array( 187 | [ 188 | frame["fl_x"], 189 | 0, 190 | frame["cx"], 191 | 0, 192 | 0, 193 | frame["fl_y"], 194 | frame["cy"], 195 | 0, 196 | 0, 197 | 0, 198 | 1, 199 | 0, 200 | 0, 201 | 0, 202 | 0, 203 | 1, 204 | ] 205 | ) 206 | .reshape(4, 4) 207 | .astype(np.float32) 208 | ) 209 | H = frame["h"] 210 | W = frame["w"] 211 | 212 | intrinsic_matrix = torch.from_numpy(intrinsic_matrix_base).unsqueeze(0).cuda() 213 | 214 | focal_length = torch.stack( 215 | [intrinsic_matrix[:, 0, 0], intrinsic_matrix[:, 1, 1]], dim=-1 216 | ) 217 | principal_point = intrinsic_matrix[:, :2, 2] 218 | 219 | image_size = torch.tensor([[H, W]]).cuda() 220 | 221 | image_size_wh = image_size.flip(dims=(1,)) 222 | 223 | s = image_size.min(dim=1, keepdim=True)[0] / 2 224 | 225 | s.expand(-1, 2) 226 | image_size_wh / 2.0 227 | 228 | c2w = frame["transform_matrix"] 229 | 230 | c2w = np.matmul(np.array(c2w), OPENGL_TO_OPENCV).astype(np.float32) 231 | c2w = np.linalg.inv(c2w) 232 | c2w = torch.from_numpy(c2w).cuda() 233 | 234 | R = c2w[:3, :3] 235 | T = c2w[:3, 3] 236 | 237 | R2 = (Rz_rot @ R).permute(-1, -2) 238 | T2 = Rz_rot @ T 239 | 240 | cameras = PerspectiveCameras( 241 | focal_length=focal_length, 242 | principal_point=principal_point, 243 | R=R2.unsqueeze(0), 244 | T=T2.unsqueeze(0), 245 | image_size=image_size, 246 | in_ndc=False, 247 | # K = intrinsic_matrix, 248 | device=device, 249 | ) 250 | 251 | raster_settings = RasterizationSettings( 252 | image_size=(H, W), 253 | blur_radius=0.0, 254 | faces_per_pixel=1, 255 | perspective_correct=True, 256 | ) 257 | 258 | renderer = MeshRenderer( 259 | rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings), 260 | shader=NormalShader( 261 | device=device, 262 | cameras=cameras, 263 | vertex_textures=vertex_textures.clone(), 264 | faces_packed=faces_packed.clone(), 265 | ), 266 | ) 267 | 268 | print("start rendering") 269 | render_img = renderer(mesh)[0, ..., :3] * 255 270 | render_img = render_img.squeeze().cpu().numpy().astype(np.uint8) 271 | 272 | if image_name.endswith("jpg"): 273 | image_name = image_name.replace("jpg", "png") 274 | 275 | Image.fromarray(render_img).save(os.path.join(output_path, image_name)) 276 | 277 | 278 | if __name__ == "__main__": 279 | tyro.cli(main) 280 | -------------------------------------------------------------------------------- /dn_splatter/data/mushroom_utils/render_gt_depth.py: -------------------------------------------------------------------------------- 1 | import json 2 | import math 3 | import os 4 | from pathlib import Path 5 | 6 | import cv2 7 | import numpy as np 8 | import torch 9 | import trimesh 10 | import tyro 11 | from pytorch3d import transforms as py3d_transform 12 | from pytorch3d.renderer import MeshRasterizer, PerspectiveCameras, RasterizationSettings 13 | from pytorch3d.structures import Meshes 14 | 15 | # from dn_splatter.utils.camera_utils import OPENGL_TO_OPENCV 16 | 17 | device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 18 | OPENGL_TO_OPENCV = 
np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 19 | 20 | 21 | def main( 22 | input_dir: Path, 23 | gt_mesh_path: Path, 24 | transformation_path: Path, 25 | ): 26 | mesh = trimesh.load(os.path.join(gt_mesh_path), force="mesh", process=False) 27 | 28 | initial_transformation = np.array( 29 | json.load(open(os.path.join(transformation_path)))["gt_transformation"] 30 | ).reshape(4, 4) 31 | initial_transformation = np.linalg.inv(initial_transformation) 32 | 33 | mesh = mesh.apply_transform(initial_transformation) 34 | 35 | vertices = torch.tensor(mesh.vertices, dtype=torch.float32) 36 | faces = torch.tensor(mesh.faces, dtype=torch.int64) 37 | mesh = Meshes(verts=[vertices], faces=[faces]).to(device) 38 | 39 | Rz_rot = py3d_transform.euler_angles_to_matrix( 40 | torch.tensor([0.0, 0.0, math.pi]), convention="XYZ" 41 | ).cuda() 42 | output_path = os.path.join(input_dir, "reference_depth") 43 | 44 | os.makedirs(output_path, exist_ok=True) 45 | 46 | transformation_info = json.load( 47 | open(os.path.join(input_dir, "transformations_colmap.json")) 48 | ) 49 | frames = transformation_info["frames"] 50 | 51 | if "fl_x" in transformation_info: 52 | intrinsic_matrix_base = ( 53 | np.array( 54 | [ 55 | transformation_info["fl_x"], 56 | 0, 57 | transformation_info["cx"], 58 | 0, 59 | 0, 60 | transformation_info["fl_y"], 61 | transformation_info["cy"], 62 | 0, 63 | 0, 64 | 0, 65 | 1, 66 | 0, 67 | 0, 68 | 0, 69 | 0, 70 | 1, 71 | ] 72 | ) 73 | .reshape(4, 4) 74 | .astype(np.float32) 75 | ) 76 | H = transformation_info["h"] 77 | W = transformation_info["w"] 78 | 79 | for frame in frames: 80 | image_path = frame["file_path"] 81 | image_name = image_path.split("/")[-1] 82 | 83 | if "fl_x" in frame: 84 | intrinsic_matrix_base = ( 85 | np.array( 86 | [ 87 | frame["fl_x"], 88 | 0, 89 | frame["cx"], 90 | 0, 91 | 0, 92 | frame["fl_y"], 93 | frame["cy"], 94 | 0, 95 | 0, 96 | 0, 97 | 1, 98 | 0, 99 | 0, 100 | 0, 101 | 0, 102 | 1, 103 | ] 104 | ) 105 | .reshape(4, 4) 106 | .astype(np.float32) 107 | ) 108 | H = frame["h"] 109 | W = frame["w"] 110 | 111 | intrinsic_matrix = torch.from_numpy(intrinsic_matrix_base).unsqueeze(0).cuda() 112 | 113 | focal_length = torch.stack( 114 | [intrinsic_matrix[:, 0, 0], intrinsic_matrix[:, 1, 1]], dim=-1 115 | ) 116 | principal_point = intrinsic_matrix[:, :2, 2] 117 | 118 | image_size = torch.tensor([[H, W]]).cuda() 119 | 120 | image_size_wh = image_size.flip(dims=(1,)) 121 | 122 | s = image_size.min(dim=1, keepdim=True)[0] / 2 123 | 124 | s.expand(-1, 2) 125 | image_size_wh / 2.0 126 | 127 | c2w = frame["transform_matrix"] 128 | 129 | c2w = np.matmul(np.array(c2w), OPENGL_TO_OPENCV).astype(np.float32) 130 | c2w = np.linalg.inv(c2w) 131 | c2w = torch.from_numpy(c2w).cuda() 132 | 133 | R = c2w[:3, :3] 134 | T = c2w[:3, 3] 135 | 136 | R2 = (Rz_rot @ R).permute(-1, -2) 137 | T2 = Rz_rot @ T 138 | 139 | cameras = PerspectiveCameras( 140 | focal_length=focal_length, 141 | principal_point=principal_point, 142 | R=R2.unsqueeze(0), 143 | T=T2.unsqueeze(0), 144 | image_size=image_size, 145 | in_ndc=False, 146 | # K = intrinsic_matrix, 147 | device=device, 148 | ) 149 | 150 | raster_settings = RasterizationSettings( 151 | image_size=(H, W), blur_radius=0.0, faces_per_pixel=1 152 | ) 153 | 154 | rasterizer = MeshRasterizer(cameras=cameras, raster_settings=raster_settings) 155 | 156 | print("start rendering") 157 | render_img = rasterizer(mesh.to(device)).zbuf[0, ..., 0] 158 | render_img = render_img.cpu().numpy() 159 | 160 | render_img[render_img < 0] = 0 161 | 162 | 
render_img = (render_img * 1000).astype(np.uint16) 163 | print(render_img.max()) 164 | if image_name.endswith("jpg"): 165 | image_name = image_name.replace("jpg", "png") 166 | cv2.imwrite(os.path.join(output_path, image_name), render_img) 167 | # import matplotlib.pyplot as plt 168 | # plt.imsave(os.path.join(output_path, image_name), render_img, cmap="viridis") 169 | 170 | 171 | if __name__ == "__main__": 172 | tyro.cli(main) 173 | -------------------------------------------------------------------------------- /dn_splatter/data/replica_utils/render_normals.py: -------------------------------------------------------------------------------- 1 | """Renders normal maps from mesh and camera pose trajectory 2 | 3 | Note: 4 | If you are running this in Headless mode, e.g. in a server with no monitor, 5 | you need to compile Open3D in headless mode: http://www.open3d.org/docs/release/tutorial/visualization/headless_rendering.html?highlight=headless 6 | 7 | - Tested with Open3D 0.17.0 and 0.16.1. Some versions of Open3D will not work. 8 | 9 | Important: 10 | Normal maps are rendered in OpenCV camera coordinate system (default Open3D conventions) 11 | """ 12 | 13 | from pathlib import Path 14 | 15 | import numpy as np 16 | import open3d as o3d 17 | import trimesh 18 | from dn_splatter.utils.utils import save_img 19 | from torch import Tensor 20 | from tqdm import tqdm 21 | 22 | 23 | def render_normals_gt( 24 | mesh_path, 25 | poses: Tensor, 26 | save_dir: Path, 27 | w=1200, 28 | h=680, 29 | fx=600.0, 30 | fy=600.0, 31 | cx=599.5, 32 | cy=339.6, 33 | ): 34 | """Render normal maps given a mesh ply file, a trajectory of poses, and camera intrinsics""" 35 | np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) 36 | mesh = trimesh.load_mesh(mesh_path).as_open3d 37 | mesh.compute_vertex_normals() 38 | vis = o3d.visualization.Visualizer() 39 | vis.create_window(visible=False, width=w, height=h) 40 | vis.add_geometry(mesh) 41 | vis.get_render_option().mesh_color_option = o3d.visualization.MeshColorOption.Normal 42 | vis.update_geometry(mesh) 43 | 44 | for i, c2w in tqdm( 45 | enumerate(poses), desc="Generating normals for each input pose ..." 
46 | ): 47 | w2c = np.linalg.inv(c2w) 48 | camera = vis.get_view_control().convert_to_pinhole_camera_parameters() 49 | camera.extrinsic = w2c 50 | vis.get_view_control().convert_from_pinhole_camera_parameters(camera) 51 | vis.poll_events() 52 | vis.update_renderer() 53 | color_image = vis.capture_screen_float_buffer(True) 54 | image = np.asarray(color_image) * 255 55 | image = image.astype(np.uint8) 56 | save_img(image, f"{str(save_dir)}/normal_{i:05d}.png", verbose=False) 57 | vis.destroy_window() 58 | -------------------------------------------------------------------------------- /dn_splatter/data/scannetpp_utils/pointcloud_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import random 3 | 4 | import cv2 5 | import numpy as np 6 | import open3d as o3d 7 | from dn_splatter.utils.camera_utils import OPENGL_TO_OPENCV 8 | 9 | 10 | def generate_iPhone_pointcloud( 11 | input_folder, meta, i_train, num_points: int = 1_000_000 12 | ): 13 | print("Generating pointcloud from iPhone data...") 14 | frames = meta["frames"] 15 | volume = o3d.pipelines.integration.ScalableTSDFVolume( 16 | voxel_length=0.04, 17 | sdf_trunc=0.2, 18 | color_type=o3d.pipelines.integration.TSDFVolumeColorType.RGB8, 19 | ) 20 | 21 | index = i_train 22 | samples_per_frame = (num_points + len(index)) // (len(index)) 23 | 24 | points_list = [] 25 | colors_list = [] 26 | 27 | for frame in frames: 28 | H, W, fx, fy, cx, cy = ( 29 | frame["h"], 30 | frame["w"], 31 | frame["fl_x"], 32 | frame["fl_y"], 33 | frame["cx"], 34 | frame["cy"], 35 | ) 36 | name = frame["file_path"].split("/")[-1] 37 | color = cv2.imread(str(input_folder / "rgb" / name)) 38 | color = cv2.cvtColor(color, cv2.COLOR_BGR2RGB) 39 | color = o3d.geometry.Image(color) 40 | 41 | # pose 42 | pose = frame["transform_matrix"] 43 | pose = np.matmul(np.array(pose), OPENGL_TO_OPENCV) 44 | depth = cv2.imread( 45 | str(input_folder / "depth" / name.replace("jpg", "png")), 46 | cv2.IMREAD_ANYDEPTH, 47 | ) 48 | depth = cv2.resize(depth, (W, H)) 49 | depth = o3d.geometry.Image(depth) 50 | 51 | camera_intrinsics = o3d.camera.PinholeCameraIntrinsic(W, H, fx, fy, cx, cy) 52 | 53 | rgbd = o3d.geometry.RGBDImage.create_from_color_and_depth( 54 | color, depth, depth_trunc=4.0, convert_rgb_to_intensity=False 55 | ) 56 | 57 | volume.integrate( 58 | rgbd, 59 | camera_intrinsics, # type: ignore 60 | np.linalg.inv(pose), 61 | ) 62 | 63 | pcd = volume.extract_point_cloud() 64 | 65 | # randomly select samples_per_frame points from points 66 | samples_per_frame = min(samples_per_frame, len(pcd.points)) 67 | mask = random.sample(range(len(pcd.points)), samples_per_frame) 68 | mask = np.asarray(mask) 69 | color = np.asarray(pcd.colors)[mask] 70 | point = np.asarray(pcd.points)[mask] 71 | 72 | points_list.append(np.asarray(point)) 73 | colors_list.append(np.asarray(color)) 74 | 75 | points = np.vstack(points_list) 76 | colors = np.vstack(colors_list) 77 | pcd = o3d.geometry.PointCloud() 78 | pcd.points = o3d.utility.Vector3dVector(points) 79 | pcd.colors = o3d.utility.Vector3dVector(colors) 80 | 81 | o3d.io.write_point_cloud(os.path.join(input_folder / "point_cloud.ply"), pcd) 82 | 83 | mesh = volume.extract_triangle_mesh() 84 | o3d.io.write_triangle_mesh(os.path.join(input_folder / "TSDFVolume.ply"), mesh) 85 | -------------------------------------------------------------------------------- /dn_splatter/dn_config.py: -------------------------------------------------------------------------------- 1 | from __future__ import 
annotations 2 | 3 | from dn_splatter.data.normal_nerfstudio import NormalNerfstudioConfig 4 | from dn_splatter.dn_datamanager import DNSplatterManagerConfig 5 | from dn_splatter.dn_model import DNSplatterModelConfig 6 | from dn_splatter.dn_pipeline import DNSplatterPipelineConfig 7 | from nerfstudio.configs.base_config import ViewerConfig 8 | from nerfstudio.engine.optimizers import AdamOptimizerConfig 9 | from nerfstudio.engine.schedulers import ExponentialDecaySchedulerConfig 10 | from nerfstudio.engine.trainer import TrainerConfig 11 | from nerfstudio.plugins.types import MethodSpecification 12 | 13 | dn_splatter = MethodSpecification( 14 | config=TrainerConfig( 15 | method_name="dn-splatter", 16 | steps_per_eval_image=500, 17 | steps_per_eval_batch=500, 18 | steps_per_save=1000000, 19 | steps_per_eval_all_images=1000000, 20 | max_num_iterations=30000, 21 | mixed_precision=False, 22 | gradient_accumulation_steps={"camera_opt": 100, "color": 10, "shs": 10}, 23 | pipeline=DNSplatterPipelineConfig( 24 | datamanager=DNSplatterManagerConfig( 25 | dataparser=NormalNerfstudioConfig(load_3D_points=True) 26 | ), 27 | model=DNSplatterModelConfig(regularization_strategy="dn-splatter"), 28 | ), 29 | optimizers={ 30 | "means": { 31 | "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15), 32 | "scheduler": ExponentialDecaySchedulerConfig( 33 | lr_final=1.6e-6, max_steps=30000 34 | ), 35 | }, 36 | "features_dc": { 37 | "optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15), 38 | "scheduler": None, 39 | }, 40 | "features_rest": { 41 | "optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15), 42 | "scheduler": None, 43 | }, 44 | "opacities": { 45 | "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15), 46 | "scheduler": None, 47 | }, 48 | "scales": { 49 | "optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15), 50 | "scheduler": None, 51 | }, 52 | "quats": { 53 | "optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), 54 | "scheduler": None, 55 | }, 56 | "camera_opt": { 57 | "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), 58 | "scheduler": ExponentialDecaySchedulerConfig( 59 | lr_final=5e-5, max_steps=30000 60 | ), 61 | }, 62 | "normals": { 63 | "optimizer": AdamOptimizerConfig( 64 | lr=1e-3, eps=1e-15 65 | ), # this does nothing, its just here to make the trainer happy 66 | "scheduler": None, 67 | }, 68 | }, 69 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15), 70 | vis="viewer", 71 | ), 72 | description="DN-Splatter: depth and normal priors for 3DGS", 73 | ) 74 | 75 | ags_mesh = MethodSpecification( 76 | config=TrainerConfig( 77 | method_name="ags-mesh", 78 | steps_per_eval_image=500, 79 | steps_per_eval_batch=500, 80 | steps_per_save=1000000, 81 | steps_per_eval_all_images=1000000, 82 | max_num_iterations=30000, 83 | mixed_precision=False, 84 | gradient_accumulation_steps={"camera_opt": 100, "color": 10, "shs": 10}, 85 | pipeline=DNSplatterPipelineConfig( 86 | datamanager=DNSplatterManagerConfig( 87 | dataparser=NormalNerfstudioConfig(load_3D_points=True) 88 | ), 89 | model=DNSplatterModelConfig(regularization_strategy="ags-mesh"), 90 | ), 91 | optimizers={ 92 | "means": { 93 | "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15), 94 | "scheduler": ExponentialDecaySchedulerConfig( 95 | lr_final=1.6e-6, max_steps=30000 96 | ), 97 | }, 98 | "features_dc": { 99 | "optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15), 100 | "scheduler": None, 101 | }, 102 | "features_rest": { 103 | "optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15), 104 | "scheduler": None, 105 | }, 106 | "opacities": { 
107 | "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15), 108 | "scheduler": None, 109 | }, 110 | "scales": { 111 | "optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15), 112 | "scheduler": None, 113 | }, 114 | "quats": { 115 | "optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), 116 | "scheduler": None, 117 | }, 118 | "camera_opt": { 119 | "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), 120 | "scheduler": ExponentialDecaySchedulerConfig( 121 | lr_final=5e-5, max_steps=30000 122 | ), 123 | }, 124 | "normals": { 125 | "optimizer": AdamOptimizerConfig( 126 | lr=1e-3, eps=1e-15 127 | ), # this does nothing, its just here to make the trainer happy 128 | "scheduler": None, 129 | }, 130 | }, 131 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15), 132 | vis="viewer", 133 | ), 134 | description="AGS-Mesh variant of Splatfacto model. Incorporates depth and normal filtering strategy.", 135 | ) 136 | 137 | dn_splatter_big = MethodSpecification( 138 | config=TrainerConfig( 139 | method_name="dn-splatter-big", 140 | steps_per_eval_image=500, 141 | steps_per_eval_batch=500, 142 | steps_per_save=1000000, 143 | steps_per_eval_all_images=1000000, 144 | max_num_iterations=30000, 145 | mixed_precision=False, 146 | pipeline=DNSplatterPipelineConfig( 147 | datamanager=DNSplatterManagerConfig( 148 | dataparser=NormalNerfstudioConfig(load_3D_points=True) 149 | ), 150 | model=DNSplatterModelConfig( 151 | cull_alpha_thresh=0.005, 152 | continue_cull_post_densification=False, 153 | ), 154 | ), 155 | optimizers={ 156 | "means": { 157 | "optimizer": AdamOptimizerConfig(lr=1.6e-4, eps=1e-15), 158 | "scheduler": ExponentialDecaySchedulerConfig( 159 | lr_final=1.6e-6, 160 | max_steps=30000, 161 | ), 162 | }, 163 | "features_dc": { 164 | "optimizer": AdamOptimizerConfig(lr=0.0025, eps=1e-15), 165 | "scheduler": None, 166 | }, 167 | "features_rest": { 168 | "optimizer": AdamOptimizerConfig(lr=0.0025 / 20, eps=1e-15), 169 | "scheduler": None, 170 | }, 171 | "opacities": { 172 | "optimizer": AdamOptimizerConfig(lr=0.05, eps=1e-15), 173 | "scheduler": None, 174 | }, 175 | "scales": { 176 | "optimizer": AdamOptimizerConfig(lr=0.005, eps=1e-15), 177 | "scheduler": None, 178 | }, 179 | "quats": { 180 | "optimizer": AdamOptimizerConfig(lr=0.001, eps=1e-15), 181 | "scheduler": None, 182 | }, 183 | "camera_opt": { 184 | "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), 185 | "scheduler": ExponentialDecaySchedulerConfig( 186 | lr_final=5e-5, max_steps=30000 187 | ), 188 | }, 189 | "normals": { 190 | "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), 191 | "scheduler": None, 192 | }, 193 | }, 194 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15), 195 | vis="viewer", 196 | ), 197 | description="DN-Splatter Big variant", 198 | ) 199 | -------------------------------------------------------------------------------- /dn_splatter/dn_datamanager.py: -------------------------------------------------------------------------------- 1 | """ 2 | Datamanager that processes optional depth and normal data. 
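Besides the RGB image, each batch may carry sensor depth, monocular depth, normal,
and confidence maps; these are moved to the device and resized to the RGB resolution
before being returned.

Illustrative access pattern (a sketch; variable names are placeholders):

    camera, batch = datamanager.next_train(step)
    rgb = batch["image"]               # H x W x 3, already on the device
    depth = batch.get("sensor_depth")  # present only if depth files were found
    normal = batch.get("normal")       # present only if normal files were found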
3 | """ 4 | 5 | import random 6 | from copy import deepcopy 7 | from dataclasses import dataclass, field 8 | from typing import Dict, Literal, Tuple, Type, Union 9 | 10 | import torch 11 | import torchvision.transforms.functional as TF 12 | 13 | from dn_splatter.data.dn_dataset import GDataset 14 | from nerfstudio.cameras.cameras import Cameras 15 | from nerfstudio.data.datamanagers.full_images_datamanager import ( 16 | FullImageDatamanager, 17 | FullImageDatamanagerConfig, 18 | ) 19 | from nerfstudio.data.datasets.base_dataset import InputDataset 20 | 21 | 22 | @dataclass 23 | class DNSplatterManagerConfig(FullImageDatamanagerConfig): 24 | """DataManager Config""" 25 | 26 | _target: Type = field(default_factory=lambda: DNSplatterDataManager) 27 | 28 | camera_res_scale_factor: float = 1.0 29 | """Rescale cameras""" 30 | 31 | 32 | class DNSplatterDataManager(FullImageDatamanager): 33 | """DataManager 34 | 35 | Args: 36 | config: the DataManagerConfig used to instantiate class 37 | """ 38 | 39 | config: DNSplatterManagerConfig 40 | train_dataset: GDataset 41 | eval_dataset: GDataset 42 | 43 | def __init__( 44 | self, 45 | config: DNSplatterManagerConfig, 46 | device: Union[torch.device, str] = "cpu", 47 | test_mode: Literal["test", "val", "inference"] = "val", 48 | world_size: int = 1, 49 | local_rank: int = 0, 50 | **kwargs, # pylint: disable=unused-argument 51 | ): 52 | self.config = config 53 | super().__init__( 54 | config=config, 55 | device=device, 56 | test_mode=test_mode, 57 | world_size=world_size, 58 | local_rank=local_rank, 59 | **kwargs, 60 | ) 61 | metadata = self.train_dataparser_outputs.metadata 62 | self.load_depths = ( 63 | True 64 | if ("depth_filenames" in metadata) 65 | or ("sensor_depth_filenames" in metadata) 66 | or ("mono_depth_filenames") in metadata 67 | else False 68 | ) 69 | 70 | self.load_normals = True if ("normal_filenames" in metadata) else False 71 | self.load_confidence = True if ("confidence_filenames" in metadata) else False 72 | self.image_idx = 0 73 | 74 | def create_train_dataset(self) -> InputDataset: 75 | """Sets up the data loaders for training""" 76 | return GDataset( 77 | dataparser_outputs=self.train_dataparser_outputs, 78 | scale_factor=self.config.camera_res_scale_factor, 79 | ) 80 | 81 | def create_eval_dataset(self) -> InputDataset: 82 | """Sets up the data loaders for evaluation""" 83 | return GDataset( 84 | dataparser_outputs=self.dataparser.get_dataparser_outputs( 85 | split=self.test_split 86 | ), 87 | scale_factor=self.config.camera_res_scale_factor, 88 | ) 89 | 90 | def next_train(self, step: int) -> Tuple[Cameras, Dict]: 91 | """Returns the next training batch""" 92 | 93 | # Don't randomly sample train images (keep t-1, t, t+1 ordering). 
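# pop(0) walks the cached frames in dataset order; once every camera has been
# used, the unseen-camera list is repopulated below so training loops over the
# sequence again.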
94 | self.image_idx = self.train_unseen_cameras.pop(0) 95 | if len(self.train_unseen_cameras) == 0: 96 | self.train_unseen_cameras = [i for i in range(len(self.train_dataset))] 97 | data = deepcopy(self.cached_train[self.image_idx]) 98 | data["image"] = data["image"].to(self.device) 99 | 100 | if "mask" in data: 101 | data["mask"] = data["mask"].to(self.device) 102 | if data["mask"].dim() == 2: 103 | data["mask"] = data["mask"][..., None] 104 | 105 | if self.load_depths: 106 | if "sensor_depth" in data: 107 | data["sensor_depth"] = data["sensor_depth"].to(self.device) 108 | if data["sensor_depth"].shape != data["image"].shape: 109 | data["sensor_depth"] = TF.resize( 110 | data["sensor_depth"].permute(2, 0, 1), 111 | data["image"].shape[:2], 112 | antialias=None, 113 | ).permute(1, 2, 0) 114 | if "mono_depth" in data: 115 | data["mono_depth"] = data["mono_depth"].to(self.device) 116 | if data["mono_depth"].shape != data["image"].shape: 117 | data["mono_depth"] = TF.resize( 118 | data["mono_depth"].permute(2, 0, 1), 119 | data["image"].shape[:2], 120 | antialias=None, 121 | ).permute(1, 2, 0) 122 | 123 | if self.load_normals: 124 | assert "normal" in data 125 | data["normal"] = data["normal"].to(self.device) 126 | if data["normal"].shape != data["image"].shape: 127 | data["normal"] = TF.resize( 128 | data["normal"].permute(2, 0, 1), 129 | data["image"].shape[:2], 130 | antialias=None, 131 | ).permute(1, 2, 0) 132 | if self.load_confidence: 133 | assert "confidence" in data 134 | data["confidence"] = data["confidence"].to(self.device) 135 | if data["confidence"].shape != data["image"].shape: 136 | data["confidence"] = TF.resize( 137 | data["confidence"].permute(2, 0, 1), 138 | data["image"].shape[:2], 139 | antialias=None, 140 | ).permute(1, 2, 0) 141 | assert ( 142 | len(self.train_dataset.cameras.shape) == 1 143 | ), "Assumes single batch dimension" 144 | camera = self.train_dataset.cameras[self.image_idx : self.image_idx + 1].to( 145 | self.device 146 | ) 147 | if camera.metadata is None: 148 | camera.metadata = {} 149 | camera.metadata["cam_idx"] = self.image_idx 150 | return camera, data 151 | 152 | def next_eval(self, step: int) -> Tuple[Cameras, Dict]: 153 | """Returns the next evaluation batch 154 | 155 | Returns a Camera instead of raybundle""" 156 | image_idx = self.eval_unseen_cameras[ 157 | random.randint(0, len(self.eval_unseen_cameras) - 1) 158 | ] 159 | 160 | # Make sure to re-populate the unseen cameras list if we have exhausted it 161 | if len(self.eval_unseen_cameras) == 0: 162 | self.eval_unseen_cameras = [i for i in range(len(self.eval_dataset))] 163 | data = deepcopy(self.cached_eval[image_idx]) 164 | data["image"] = data["image"].to(self.device) 165 | if "mask" in data: 166 | data["mask"] = data["mask"].to(self.device) 167 | if data["mask"].dim() == 2: 168 | data["mask"] = data["mask"][..., None] 169 | if self.load_depths: 170 | if "sensor_depth" in data: 171 | data["sensor_depth"] = data["sensor_depth"].to(self.device) 172 | if data["sensor_depth"].shape != data["image"].shape: 173 | data["sensor_depth"] = TF.resize( 174 | data["sensor_depth"].permute(2, 0, 1), 175 | data["image"].shape[:2], 176 | antialias=None, 177 | ).permute(1, 2, 0) 178 | if "mono_depth" in data: 179 | data["mono_depth"] = data["mono_depth"].to(self.device) 180 | if data["mono_depth"].shape != data["image"].shape: 181 | data["mono_depth"] = TF.resize( 182 | data["mono_depth"].permute(2, 0, 1), 183 | data["image"].shape[:2], 184 | antialias=None, 185 | ).permute(1, 2, 0) 186 | if self.load_normals: 
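# normal (and confidence) maps are resized the same way as the depth maps
# above, so every prior matches the RGB resolution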
187 | assert "normal" in data 188 | data["normal"] = data["normal"].to(self.device) 189 | if data["normal"].shape != data["image"].shape: 190 | data["normal"] = TF.resize( 191 | data["normal"].permute(2, 0, 1), 192 | data["image"].shape[:2], 193 | antialias=None, 194 | ).permute(1, 2, 0) 195 | if self.load_confidence: 196 | assert "confidence" in data 197 | data["confidence"] = data["confidence"].to(self.device) 198 | if data["confidence"].shape != data["image"].shape: 199 | data["confidence"] = TF.resize( 200 | data["confidence"].permute(2, 0, 1), 201 | data["image"].shape[:2], 202 | antialias=None, 203 | ).permute(1, 2, 0) 204 | assert ( 205 | len(self.eval_dataset.cameras.shape) == 1 206 | ), "Assumes single batch dimension" 207 | camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device) 208 | if camera.metadata is None: 209 | camera.metadata = {} 210 | camera.metadata["cam_idx"] = image_idx 211 | return camera, data 212 | 213 | def next_eval_image(self, step: int) -> Tuple[Cameras, Dict]: 214 | """Returns the next eval image""" 215 | 216 | image_idx = self.eval_unseen_cameras[ 217 | random.randint(0, len(self.eval_unseen_cameras) - 1) 218 | ] 219 | data = deepcopy(self.cached_eval[image_idx]) 220 | data["image"] = data["image"].to(self.device) 221 | assert ( 222 | len(self.eval_dataset.cameras.shape) == 1 223 | ), "Assumes single batch dimension" 224 | camera = self.eval_dataset.cameras[image_idx : image_idx + 1].to(self.device) 225 | return camera, data 226 | -------------------------------------------------------------------------------- /dn_splatter/eval/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maturk/dn-splatter/249d52c4bb14b7bf6dd18d7d66099a36eac2ee78/dn_splatter/eval/__init__.py -------------------------------------------------------------------------------- /dn_splatter/eval/baseline_models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maturk/dn-splatter/249d52c4bb14b7bf6dd18d7d66099a36eac2ee78/dn_splatter/eval/baseline_models/__init__.py -------------------------------------------------------------------------------- /dn_splatter/eval/baseline_models/eval_configs.py: -------------------------------------------------------------------------------- 1 | """ 2 | Eval configs 3 | """ 4 | 5 | from __future__ import annotations 6 | 7 | from dn_splatter.data.dn_dataset import GDataset 8 | from dn_splatter.dn_pipeline import DNSplatterPipelineConfig 9 | from dn_splatter.eval.baseline_models.g_depthnerfacto import GDepthNerfactoModelConfig 10 | from dn_splatter.eval.baseline_models.g_nerfacto import GNerfactoModelConfig 11 | from dn_splatter.eval.baseline_models.g_neusfacto import DNeuSFactoModelConfig 12 | from nerfstudio.cameras.camera_optimizers import CameraOptimizerConfig 13 | from nerfstudio.configs.base_config import ViewerConfig 14 | from nerfstudio.data.datamanagers.base_datamanager import ( 15 | VanillaDataManager, 16 | VanillaDataManagerConfig, 17 | ) 18 | from nerfstudio.data.pixel_samplers import PairPixelSamplerConfig 19 | from nerfstudio.engine.optimizers import AdamOptimizerConfig 20 | from nerfstudio.engine.schedulers import ( 21 | CosineDecaySchedulerConfig, 22 | ExponentialDecaySchedulerConfig, 23 | MultiStepSchedulerConfig, 24 | ) 25 | from nerfstudio.engine.trainer import TrainerConfig 26 | from nerfstudio.fields.sdf_field import SDFFieldConfig 27 | from nerfstudio.plugins.types 
import MethodSpecification 28 | 29 | gnerfacto = MethodSpecification( 30 | config=TrainerConfig( 31 | method_name="gnerfacto", 32 | steps_per_eval_batch=500, 33 | steps_per_save=500, 34 | max_num_iterations=30000, 35 | mixed_precision=True, 36 | pipeline=DNSplatterPipelineConfig( 37 | datamanager=VanillaDataManagerConfig( 38 | _target=VanillaDataManager[GDataset], 39 | train_num_rays_per_batch=4096, 40 | eval_num_rays_per_batch=4096, 41 | ), 42 | model=GNerfactoModelConfig( 43 | eval_num_rays_per_chunk=1 << 15, 44 | camera_optimizer=CameraOptimizerConfig(mode="off"), 45 | ), 46 | ), 47 | optimizers={ 48 | "proposal_networks": { 49 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), 50 | "scheduler": ExponentialDecaySchedulerConfig( 51 | lr_final=0.0001, max_steps=200000 52 | ), 53 | }, 54 | "fields": { 55 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), 56 | "scheduler": ExponentialDecaySchedulerConfig( 57 | lr_final=0.0001, max_steps=200000 58 | ), 59 | }, 60 | "camera_opt": { 61 | "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), 62 | "scheduler": ExponentialDecaySchedulerConfig( 63 | lr_final=1e-4, max_steps=5000 64 | ), 65 | }, 66 | }, 67 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15), 68 | vis="viewer", 69 | ), 70 | description="Our version of nerfacto for experimentation.", 71 | ) 72 | 73 | gdepthfacto = MethodSpecification( 74 | config=TrainerConfig( 75 | method_name="gdepthfacto", 76 | steps_per_eval_batch=500, 77 | steps_per_save=500, 78 | max_num_iterations=30000, 79 | mixed_precision=True, 80 | pipeline=DNSplatterPipelineConfig( 81 | datamanager=VanillaDataManagerConfig( 82 | _target=VanillaDataManager[GDataset], 83 | pixel_sampler=PairPixelSamplerConfig(), 84 | train_num_rays_per_batch=4096, 85 | eval_num_rays_per_batch=4096, 86 | ), 87 | model=GDepthNerfactoModelConfig( 88 | eval_num_rays_per_chunk=1 << 15, 89 | camera_optimizer=CameraOptimizerConfig(mode="off"), 90 | ), 91 | ), 92 | optimizers={ 93 | "proposal_networks": { 94 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), 95 | "scheduler": ExponentialDecaySchedulerConfig( 96 | lr_final=0.0001, max_steps=200000 97 | ), 98 | }, 99 | "fields": { 100 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), 101 | "scheduler": ExponentialDecaySchedulerConfig( 102 | lr_final=0.0001, max_steps=200000 103 | ), 104 | }, 105 | "camera_opt": { 106 | "optimizer": AdamOptimizerConfig(lr=1e-3, eps=1e-15), 107 | "scheduler": ExponentialDecaySchedulerConfig( 108 | lr_final=1e-4, max_steps=5000 109 | ), 110 | }, 111 | }, 112 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15), 113 | vis="viewer", 114 | ), 115 | description="Our version of depth-nerfacto for experimentation.", 116 | ) 117 | 118 | gneusfacto = MethodSpecification( 119 | config=TrainerConfig( 120 | method_name="gneusfacto", 121 | steps_per_eval_image=500, 122 | steps_per_eval_batch=500, 123 | steps_per_save=500, 124 | max_num_iterations=30000, 125 | mixed_precision=True, 126 | pipeline=DNSplatterPipelineConfig( 127 | datamanager=VanillaDataManagerConfig( 128 | _target=VanillaDataManager[GDataset], 129 | train_num_rays_per_batch=2048, 130 | eval_num_rays_per_batch=1024, 131 | ), 132 | model=DNeuSFactoModelConfig( 133 | # proposal network allows for significantly smaller sdf/color network 134 | sdf_field=SDFFieldConfig( 135 | use_grid_feature=True, 136 | num_layers=2, 137 | num_layers_color=2, 138 | hidden_dim=256, 139 | bias=0.5, 140 | beta_init=0.8, 141 | use_appearance_embedding=False, 142 | inside_outside=False, 143 | ), 144 | background_model="none", 
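# background_model="none" skips the separate background network that
# neus-facto would otherwise use for content outside the scene bounds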
145 | eval_num_rays_per_chunk=1024, 146 | ), 147 | ), 148 | optimizers={ 149 | "proposal_networks": { 150 | "optimizer": AdamOptimizerConfig(lr=1e-2, eps=1e-15), 151 | "scheduler": MultiStepSchedulerConfig( 152 | max_steps=20001, milestones=(10000, 1500, 18000) 153 | ), 154 | }, 155 | "fields": { 156 | "optimizer": AdamOptimizerConfig(lr=5e-4, eps=1e-15), 157 | "scheduler": CosineDecaySchedulerConfig( 158 | warm_up_end=500, learning_rate_alpha=0.05, max_steps=20001 159 | ), 160 | }, 161 | "field_background": { 162 | "optimizer": AdamOptimizerConfig(lr=5e-4, eps=1e-15), 163 | "scheduler": CosineDecaySchedulerConfig( 164 | warm_up_end=500, learning_rate_alpha=0.05, max_steps=20001 165 | ), 166 | }, 167 | }, 168 | viewer=ViewerConfig(num_rays_per_chunk=1 << 15), 169 | vis="viewer", 170 | ), 171 | description="Our version of neus-facto for experimentation.", 172 | ) 173 | -------------------------------------------------------------------------------- /dn_splatter/eval/baseline_models/g_depthnerfacto.py: -------------------------------------------------------------------------------- 1 | """Our version of Depth-Nerfacto for evaluation""" 2 | 3 | from __future__ import annotations 4 | 5 | from dataclasses import dataclass, field 6 | from typing import Dict, Tuple, Type 7 | 8 | import torch 9 | import torchvision.transforms.functional as TF 10 | from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure 11 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 12 | 13 | from dn_splatter.metrics import DepthMetrics 14 | from nerfstudio.model_components.scene_colliders import NearFarCollider 15 | from nerfstudio.models.depth_nerfacto import ( 16 | DepthNerfactoModel, 17 | DepthNerfactoModelConfig, 18 | ) 19 | 20 | 21 | @dataclass 22 | class GDepthNerfactoModelConfig(DepthNerfactoModelConfig): 23 | _target: Type = field(default_factory=lambda: GDepthNerfactoModel) 24 | disable_scene_contraction: bool = False 25 | """Whether to disable scene contraction or not.""" 26 | is_euclidean_depth: bool = False 27 | """Whether input depth maps are Euclidean distances (or z-distances).""" 28 | far_plane: float = 2.0 29 | """How far along the ray to stop sampling.""" 30 | predict_normals: bool = True 31 | """Whether to predict normals or not.""" 32 | 33 | 34 | class GDepthNerfactoModel(DepthNerfactoModel): 35 | config: GDepthNerfactoModelConfig 36 | 37 | def populate_modules(self): 38 | super().populate_modules() 39 | self.psnr = PeakSignalNoiseRatio(data_range=1.0) 40 | self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0, kernel_size=11) 41 | self.lpips = LearnedPerceptualImagePatchSimilarity() 42 | self.depth_metrics = DepthMetrics() 43 | self.collider = NearFarCollider( 44 | near_plane=self.config.near_plane, far_plane=self.config.far_plane 45 | ) 46 | 47 | def get_metrics_dict(self, outputs, batch): 48 | if "sensor_depth" in batch: 49 | batch["depth_image"] = batch["sensor_depth"] 50 | elif "mono_depth" in batch: 51 | batch["depth_image"] = batch["mono_depth"] 52 | metrics_dict = super().get_metrics_dict(outputs, batch) 53 | 54 | return metrics_dict 55 | 56 | def get_image_metrics_and_images( 57 | self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor] 58 | ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]: 59 | gt_rgb = batch["image"].to(self.device) 60 | predicted_rgb = outputs[ 61 | "rgb" 62 | ] # Blended with background (black if random background) 63 | combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1) 64 | # Switch images 
from [H, W, C] to [1, C, H, W] for metrics computations 65 | gt_rgb = torch.moveaxis(gt_rgb, -1, 0)[None, ...] 66 | predicted_rgb = torch.moveaxis(predicted_rgb, -1, 0)[None, ...] 67 | 68 | if "mask" in batch: 69 | mask = batch["mask"].to(self.device) 70 | gt_rgb = gt_rgb * mask 71 | predicted_rgb = predicted_rgb * mask 72 | 73 | psnr = self.psnr(gt_rgb, predicted_rgb) 74 | ssim = self.ssim(gt_rgb, predicted_rgb) 75 | lpips = self.lpips(gt_rgb, predicted_rgb) 76 | 77 | # all of these metrics will be logged as scalars 78 | metrics_dict = { 79 | "rgb_psnr": float(psnr.item()), 80 | "rgb_ssim": float(ssim), 81 | } # type: ignore 82 | metrics_dict["rgb_lpips"] = float(lpips) 83 | 84 | if "sensor_depth" in batch: 85 | gt_depth = batch["sensor_depth"].to(self.device) 86 | 87 | # change from z-depth to euclidean depth 88 | gt_depth = gt_depth * outputs["directions_norm"] 89 | 90 | predicted_depth = outputs["depth"] 91 | if predicted_depth.shape[:2] != gt_depth.shape[:2]: 92 | predicted_depth = TF.resize( 93 | predicted_depth.permute(2, 0, 1), gt_depth.shape[:2], antialias=None 94 | ).permute(1, 2, 0) 95 | 96 | gt_depth = gt_depth.to(torch.float32) # it is in float64 previous 97 | if "mask" in batch: 98 | gt_depth = gt_depth * mask 99 | predicted_depth = predicted_depth * mask 100 | 101 | # add depth eval metrics 102 | (abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3) = self.depth_metrics( 103 | predicted_depth.permute(2, 0, 1), gt_depth.permute(2, 0, 1) 104 | ) 105 | depth_metrics = { 106 | "depth_abs_rel": float(abs_rel.item()), 107 | "depth_sq_rel": float(sq_rel.item()), 108 | "depth_rmse": float(rmse.item()), 109 | "depth_rmse_log": float(rmse_log.item()), 110 | "depth_a1": float(a1.item()), 111 | "depth_a2": float(a2.item()), 112 | "depth_a3": float(a3.item()), 113 | } 114 | metrics_dict.update(depth_metrics) 115 | 116 | images_dict = {"img": combined_rgb} 117 | 118 | return metrics_dict, images_dict 119 | -------------------------------------------------------------------------------- /dn_splatter/eval/baseline_models/g_nerfacto.py: -------------------------------------------------------------------------------- 1 | """Our version of Nerfacto for evaluation""" 2 | 3 | from __future__ import annotations 4 | 5 | from dataclasses import dataclass, field 6 | from typing import Dict, Tuple, Type 7 | 8 | import torch 9 | import torchvision.transforms.functional as TF 10 | from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure 11 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 12 | 13 | from dn_splatter.metrics import DepthMetrics 14 | from nerfstudio.cameras.rays import RayBundle 15 | from nerfstudio.model_components.scene_colliders import NearFarCollider 16 | from nerfstudio.models.nerfacto import NerfactoModel, NerfactoModelConfig 17 | 18 | 19 | @dataclass 20 | class GNerfactoModelConfig(NerfactoModelConfig): 21 | _target: Type = field(default_factory=lambda: GNerfactoModel) 22 | disable_scene_contraction: bool = False 23 | """Whether to disable scene contraction or not.""" 24 | is_euclidean_depth: bool = False 25 | """Whether input depth maps are Euclidean distances (or z-distances).""" 26 | far_plane: float = 2.0 27 | """How far along the ray to stop sampling.""" 28 | predict_normals: bool = True 29 | """Whether to predict normals or not.""" 30 | 31 | 32 | class GNerfactoModel(NerfactoModel): 33 | config: GNerfactoModelConfig 34 | 35 | def populate_modules(self): 36 | super().populate_modules() 37 | self.psnr = 
PeakSignalNoiseRatio(data_range=1.0) 38 | self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0, kernel_size=11) 39 | self.lpips = LearnedPerceptualImagePatchSimilarity() 40 | self.depth_metrics = DepthMetrics() 41 | self.collider = NearFarCollider( 42 | near_plane=self.config.near_plane, far_plane=self.config.far_plane 43 | ) 44 | 45 | def get_outputs(self, ray_bundle: RayBundle): 46 | outputs = super().get_outputs(ray_bundle) 47 | if ray_bundle.metadata is not None and "directions_norm" in ray_bundle.metadata: 48 | outputs["directions_norm"] = ray_bundle.metadata["directions_norm"] 49 | return outputs 50 | 51 | def get_image_metrics_and_images( 52 | self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor] 53 | ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]: 54 | gt_rgb = batch["image"].to(self.device) 55 | predicted_rgb = outputs[ 56 | "rgb" 57 | ] # Blended with background (black if random background) 58 | combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1) 59 | # Switch images from [H, W, C] to [1, C, H, W] for metrics computations 60 | gt_rgb = torch.moveaxis(gt_rgb, -1, 0)[None, ...] 61 | predicted_rgb = torch.moveaxis(predicted_rgb, -1, 0)[None, ...] 62 | 63 | if "mask" in batch: 64 | mask = batch["mask"].to(self.device) 65 | gt_rgb = gt_rgb * mask 66 | predicted_rgb = predicted_rgb * mask 67 | 68 | psnr = self.psnr(gt_rgb, predicted_rgb) 69 | ssim = self.ssim(gt_rgb, predicted_rgb) 70 | lpips = self.lpips(gt_rgb, predicted_rgb) 71 | 72 | # all of these metrics will be logged as scalars 73 | metrics_dict = { 74 | "rgb_psnr": float(psnr.item()), 75 | "rgb_ssim": float(ssim), 76 | } # type: ignore 77 | metrics_dict["rgb_lpips"] = float(lpips) 78 | 79 | if "sensor_depth" in batch: 80 | gt_depth = batch["sensor_depth"].to(self.device) 81 | 82 | gt_depth = gt_depth * outputs["directions_norm"] 83 | 84 | predicted_depth = outputs["depth"] 85 | if predicted_depth.shape[:2] != gt_depth.shape[:2]: 86 | predicted_depth = TF.resize( 87 | predicted_depth.permute(2, 0, 1), gt_depth.shape[:2], antialias=None 88 | ).permute(1, 2, 0) 89 | 90 | gt_depth = gt_depth.to(torch.float32) # it is in float64 previous 91 | if "mask" in batch: 92 | gt_depth = gt_depth * mask 93 | predicted_depth = predicted_depth * mask 94 | 95 | # add depth eval metrics 96 | (abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3) = self.depth_metrics( 97 | predicted_depth.permute(2, 0, 1), gt_depth.permute(2, 0, 1) 98 | ) 99 | 100 | depth_metrics = { 101 | "depth_abs_rel": float(abs_rel.item()), 102 | "depth_sq_rel": float(sq_rel.item()), 103 | "depth_rmse": float(rmse.item()), 104 | "depth_rmse_log": float(rmse_log.item()), 105 | "depth_a1": float(a1.item()), 106 | "depth_a2": float(a2.item()), 107 | "depth_a3": float(a3.item()), 108 | } 109 | metrics_dict.update(depth_metrics) 110 | 111 | images_dict = {"img": combined_rgb} 112 | 113 | return metrics_dict, images_dict 114 | -------------------------------------------------------------------------------- /dn_splatter/eval/baseline_models/g_neusfacto.py: -------------------------------------------------------------------------------- 1 | """Our version of Neus-facto for evaluation""" 2 | 3 | from __future__ import annotations 4 | 5 | from dataclasses import dataclass, field 6 | from typing import Dict, Tuple, Type 7 | 8 | import torch 9 | import torchvision.transforms.functional as TF 10 | from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure 11 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 
12 | 13 | from dn_splatter.losses import SensorDepthLoss 14 | from dn_splatter.metrics import DepthMetrics 15 | from nerfstudio.models.neus_facto import NeuSFactoModel, NeuSFactoModelConfig 16 | 17 | 18 | @dataclass 19 | class DNeuSFactoModelConfig(NeuSFactoModelConfig): 20 | _target: Type = field(default_factory=lambda: DNeuSFactoModel) 21 | is_euclidean_depth: bool = False 22 | """Whether input depth maps are Euclidean distances (or z-distances).""" 23 | sensor_depth_l1_loss_mult: float = 0.0 24 | """Sensor depth L1 loss multiplier.""" 25 | sensor_depth_freespace_loss_mult: float = 0.0 26 | """Sensor depth free space loss multiplier.""" 27 | sensor_depth_sdf_loss_mult: float = 0.0 28 | """Sensor depth sdf loss multiplier.""" 29 | mono_normal_loss_mult: float = 0.05 30 | """Monocular normal consistency loss multiplier, monosdf default 0.05""" 31 | mono_depth_loss_mult: float = 0.0 32 | """Depth loss multiplier, monosdf default 0.1""" 33 | sensor_depth_truncation: float = 0.015 34 | """Sensor depth trunction, default value is 0.015 which means 5cm with a rough scale value 0.3 (0.015 = 0.05 * 0.3)""" 35 | 36 | 37 | class DNeuSFactoModel(NeuSFactoModel): 38 | config: NeuSFactoModelConfig 39 | 40 | def populate_modules(self): 41 | super().populate_modules() 42 | self.psnr = PeakSignalNoiseRatio(data_range=1.0) 43 | self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0, kernel_size=11) 44 | self.lpips = LearnedPerceptualImagePatchSimilarity() 45 | self.depth_metrics = DepthMetrics() 46 | self.sensor_depth_loss = SensorDepthLoss( 47 | truncation=self.config.sensor_depth_truncation 48 | ) 49 | 50 | def get_loss_dict(self, outputs, batch, metrics_dict=None) -> Dict: 51 | # rename mono_depth to depth 52 | if "mono_depth" in batch: 53 | batch["depth"] = batch.pop("mono_depth") 54 | loss_dict = super().get_loss_dict(outputs, batch, metrics_dict) 55 | if self.training: 56 | # sensor depth loss 57 | if "sensor_depth" in batch and ( 58 | self.config.sensor_depth_l1_loss_mult > 0.0 59 | or self.config.sensor_depth_freespace_loss_mult > 0.0 60 | or self.config.sensor_depth_sdf_loss_mult > 0.0 61 | ): 62 | l1_loss, free_space_loss, sdf_loss = self.sensor_depth_loss( 63 | batch, outputs 64 | ) 65 | loss_dict["sensor_l1_loss"] = ( 66 | l1_loss * self.config.sensor_depth_l1_loss_mult 67 | ) 68 | loss_dict["sensor_freespace_loss"] = ( 69 | free_space_loss * self.config.sensor_depth_freespace_loss_mult 70 | ) 71 | loss_dict["sensor_sdf_loss"] = ( 72 | sdf_loss * self.config.sensor_depth_sdf_loss_mult 73 | ) 74 | 75 | return loss_dict 76 | 77 | def get_image_metrics_and_images( 78 | self, outputs: Dict[str, torch.Tensor], batch: Dict[str, torch.Tensor] 79 | ) -> Tuple[Dict[str, float], Dict[str, torch.Tensor]]: 80 | gt_rgb = batch["image"].to(self.device) 81 | predicted_rgb = outputs[ 82 | "rgb" 83 | ] # Blended with background (black if random background) 84 | combined_rgb = torch.cat([gt_rgb, predicted_rgb], dim=1) 85 | # Switch images from [H, W, C] to [1, C, H, W] for metrics computations 86 | gt_rgb = torch.moveaxis(gt_rgb, -1, 0)[None, ...] 87 | predicted_rgb = torch.moveaxis(predicted_rgb, -1, 0)[None, ...] 
88 | 89 | if "mask" in batch: 90 | mask = batch["mask"].to(self.device) 91 | gt_rgb = gt_rgb * mask 92 | predicted_rgb = predicted_rgb * mask 93 | 94 | psnr = self.psnr(gt_rgb, predicted_rgb) 95 | ssim = self.ssim(gt_rgb, predicted_rgb) 96 | lpips = self.lpips(gt_rgb, predicted_rgb) 97 | 98 | # all of these metrics will be logged as scalars 99 | metrics_dict = { 100 | "rgb_psnr": float(psnr.item()), 101 | "rgb_ssim": float(ssim), 102 | } # type: ignore 103 | metrics_dict["rgb_lpips"] = float(lpips) 104 | 105 | if "sensor_depth" in batch: 106 | gt_depth = batch["sensor_depth"].to(self.device) 107 | predicted_depth = outputs["depth"] 108 | if predicted_depth.shape[:2] != gt_depth.shape[:2]: 109 | predicted_depth = TF.resize( 110 | predicted_depth.permute(2, 0, 1), gt_depth.shape[:2], antialias=None 111 | ).permute(1, 2, 0) 112 | 113 | gt_depth = gt_depth.to(torch.float32) # it is in float64 previous 114 | if "mask" in batch: 115 | gt_depth = gt_depth * mask 116 | predicted_depth = predicted_depth * mask 117 | 118 | # add depth eval metrics 119 | (abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3) = self.depth_metrics( 120 | predicted_depth.permute(2, 0, 1), gt_depth.permute(2, 0, 1) 121 | ) 122 | depth_metrics = { 123 | "depth_abs_rel": float(abs_rel.item()), 124 | "depth_sq_rel": float(sq_rel.item()), 125 | "depth_rmse": float(rmse.item()), 126 | "depth_rmse_log": float(rmse_log.item()), 127 | "depth_a1": float(a1.item()), 128 | "depth_a2": float(a2.item()), 129 | "depth_a3": float(a3.item()), 130 | } 131 | metrics_dict.update(depth_metrics) 132 | 133 | images_dict = {"img": combined_rgb} 134 | 135 | return metrics_dict, images_dict 136 | -------------------------------------------------------------------------------- /dn_splatter/eval/baseline_models/nerfstudio_to_sdfstudio.py: -------------------------------------------------------------------------------- 1 | """ 2 | nerfstudio format to sdfstudio format 3 | borrowed from sdfstudio https://github.com/autonomousvision/sdfstudio/blob/master/scripts/datasets/process_nerfstudio_to_sdfstudio.py 4 | """ 5 | import argparse 6 | import json 7 | from pathlib import Path 8 | 9 | import cv2 10 | import matplotlib.pyplot as plt 11 | import numpy as np 12 | from PIL import Image 13 | from tqdm import tqdm 14 | 15 | 16 | def main(args): 17 | """ 18 | Given data that follows the nerfstudio format such as the output from colmap or polycam, 19 | convert to a format that sdfstudio will ingest 20 | """ 21 | output_dir = Path(args.output_dir) 22 | input_dir = Path(args.input_dir) 23 | output_dir.mkdir(parents=True, exist_ok=True) 24 | cam_params = json.load(open(input_dir / "transforms.json")) 25 | 26 | # === load camera intrinsics and poses === 27 | cam_intrinsics = [] 28 | if args.data_type == "colmap": 29 | cam_intrinsics.append( 30 | np.array( 31 | [ 32 | [cam_params["fl_x"], 0, cam_params["cx"]], 33 | [0, cam_params["fl_y"], cam_params["cy"]], 34 | [0, 0, 1], 35 | ] 36 | ) 37 | ) 38 | 39 | frames = cam_params["frames"] 40 | poses = [] 41 | image_paths = [] 42 | depth_paths = [] 43 | mono_depth_paths = [] 44 | # only load images with corresponding pose info 45 | # currently in random order??, probably need to sort 46 | for frame in frames: 47 | # load intrinsics from polycam 48 | if args.data_type == "polycam": 49 | cam_intrinsics.append( 50 | np.array( 51 | [ 52 | [frame["fl_x"], 0, frame["cx"]], 53 | [0, frame["fl_y"], frame["cy"]], 54 | [0, 0, 1], 55 | ] 56 | ) 57 | ) 58 | 59 | # load poses 60 | # OpenGL/Blender convention, needs to change to COLMAP/OpenCV 
convention 61 | # https://docs.nerf.studio/en/latest/quickstart/data_conventions.html 62 | # IGNORED for now 63 | c2w = np.array(frame["transform_matrix"]) 64 | if c2w.shape == (3, 4): 65 | c2w = np.vstack([c2w, np.array([0, 0, 0, 1])]) 66 | c2w = c2w.reshape(4, 4) 67 | c2w[0:3, 1:3] *= -1 68 | poses.append(c2w) 69 | 70 | # load images 71 | img_path = input_dir / frame["file_path"] 72 | assert img_path.exists() 73 | image_paths.append(img_path) 74 | 75 | # load sensor depths 76 | if "depth_file_path" in frame: 77 | depth_path = input_dir / frame["depth_file_path"] 78 | assert depth_path.exists() 79 | depth_paths.append(depth_path) 80 | if "mono_depth_file_path" in frame: 81 | mono_depth_path = input_dir / frame["mono_depth_file_path"] 82 | assert mono_depth_path.exists() 83 | mono_depth_paths.append(mono_depth_paths) 84 | 85 | # Check correctness 86 | assert len(poses) == len(image_paths) 87 | assert len(poses) == len(cam_intrinsics) or len(cam_intrinsics) == 1 88 | 89 | # Filter invalid poses 90 | poses = np.array(poses) 91 | valid_poses = np.isfinite(poses).all(axis=2).all(axis=1) 92 | min_vertices = poses[:, :3, 3][valid_poses].min(axis=0) 93 | max_vertices = poses[:, :3, 3][valid_poses].max(axis=0) 94 | 95 | # === Normalize the scene === 96 | if args.scene_type in ["indoor", "object"]: 97 | # Enlarge bbox by 1.05 for object scene and by 5.0 for indoor scene 98 | # TODO: Adaptively estimate `scene_scale_mult` based on depth-map or point-cloud prior 99 | if not args.scene_scale_mult: 100 | args.scene_scale_mult = 1.05 if args.scene_type == "object" else 5.0 101 | scene_scale = 2.0 / ( 102 | np.max(max_vertices - min_vertices) * args.scene_scale_mult 103 | ) 104 | scene_center = (min_vertices + max_vertices) / 2.0 105 | # normalize pose to unit cube 106 | poses[:, :3, 3] -= scene_center 107 | poses[:, :3, 3] *= scene_scale 108 | # calculate scale matrix 109 | scale_mat = np.eye(4).astype(np.float32) 110 | scale_mat[:3, 3] -= scene_center 111 | scale_mat[:3] *= scene_scale 112 | scale_mat = np.linalg.inv(scale_mat) 113 | else: 114 | scene_scale = 1.0 115 | scale_mat = np.eye(4).astype(np.float32) 116 | 117 | # === Construct the scene box === 118 | if args.scene_type == "indoor": 119 | scene_box = { 120 | "aabb": [[-1, -1, -1], [1, 1, 1]], 121 | "near": 0.05, 122 | "far": 2.5, 123 | "radius": 1.0, 124 | "collider_type": "box", 125 | } 126 | elif args.scene_type == "object": 127 | scene_box = { 128 | "aabb": [[-1, -1, -1], [1, 1, 1]], 129 | "near": 0.05, 130 | "far": 2.0, 131 | "radius": 1.0, 132 | "collider_type": "near_far", 133 | } 134 | elif args.scene_type == "unbound": 135 | # TODO: case-by-case near far based on depth prior 136 | # such as colmap sparse points or sensor depths 137 | scene_box = { 138 | "aabb": [min_vertices.tolist(), max_vertices.tolist()], 139 | "near": 0.05, 140 | "far": 2.5 * np.max(max_vertices - min_vertices), 141 | "radius": np.min(max_vertices - min_vertices) / 2.0, 142 | "collider_type": "box", 143 | } 144 | 145 | # === Resize the images and intrinsics === 146 | # Only resize the images when we want to use mono prior 147 | if args.data_type == "colmap": 148 | h, w = cam_params["h"], cam_params["w"] 149 | else: 150 | h, w = frames[0]["h"], frames[0]["w"] 151 | 152 | # === Construct the frames in the meta_data.json === 153 | frames = [] 154 | out_index = 0 155 | mono_depth, sensor_depth = False, False 156 | if len(mono_depth_paths) > 0: 157 | mono_depth = True 158 | else: 159 | sensor_depth = True 160 | 161 | for idx, (valid, pose, image_path) in enumerate( 162 
| tqdm(zip(valid_poses, poses, image_paths)) 163 | ): 164 | if not valid: 165 | continue 166 | 167 | # save rgb image 168 | out_img_path = output_dir / f"{out_index:06d}_rgb.png" 169 | if args.save_imgs: 170 | img = Image.open(image_path) 171 | img.save(out_img_path) 172 | rgb_path = str(out_img_path.relative_to(output_dir)) 173 | 174 | frame = { 175 | "rgb_path": rgb_path, 176 | "camtoworld": pose.tolist(), 177 | "intrinsics": cam_intrinsics[0].tolist() 178 | if args.data_type == "colmap" 179 | else cam_intrinsics[idx].tolist(), 180 | } 181 | 182 | if sensor_depth: 183 | # load depth 184 | if args.save_imgs: 185 | depth_path = depth_paths[idx] 186 | out_depth_path = output_dir / f"{out_index:06d}_sensor_depth.png" 187 | depth = cv2.imread(str(depth_path), -1).astype(np.float32) / 1000.0 188 | # scale depth as we normalize the scene to unit box 189 | new_depth = np.copy(depth) * scene_scale 190 | plt.imsave(out_depth_path, new_depth, cmap="viridis") 191 | np.save(str(out_depth_path).replace(".png", ".npy"), new_depth) 192 | frame["mono_depth_path"] = rgb_path.replace("_rgb.png", "_sensor_depth.npy") 193 | if mono_depth: 194 | # load mono depth 195 | if args.save_imgs: 196 | mono_depth_path = mono_depth_paths[idx] 197 | out_mono_depth_path = output_dir / f"{out_index:06d}_mono_depth.png" 198 | mono_depth = ( 199 | cv2.imread(str(mono_depth_path), -1).astype(np.float32) / 1000.0 200 | ) 201 | # scale depth as we normalize the scene to unit box 202 | new_mono_depth = np.copy(mono_depth) * scene_scale 203 | plt.imsave(out_mono_depth_path, new_mono_depth, cmap="viridis") 204 | np.save( 205 | str(out_mono_depth_path).replace(".png", ".npy"), new_mono_depth 206 | ) 207 | frame["mono_depth_path"] = rgb_path.replace("_rgb.png", "_mono_depth.npy") 208 | 209 | # load normal 210 | out_normal_path = output_dir / f"{out_index:06d}_normal.png" 211 | if args.normal_from_pretrain: 212 | # load normal 213 | normal_path = input_dir / "normals_from_pretrain" / f"{image_path.stem}.png" 214 | elif args.normal_from_depth: 215 | # load normal from sensor depth 216 | normal_path = input_dir / "normals_from_depth" / f"{image_path.stem}.png" 217 | if args.normal_from_pretrain or args.normal_from_depth: 218 | if args.save_imgs: 219 | normal = Image.open(normal_path) 220 | normal.save(out_normal_path) 221 | normal = np.array(normal) 222 | np.save(str(out_normal_path).replace(".png", ".npy"), normal / 255.0) 223 | frame["mono_normal_path"] = rgb_path.replace("_rgb.png", "_normal.npy") 224 | 225 | frames.append(frame) 226 | out_index += 1 227 | 228 | # === Construct and export the metadata === 229 | meta_data = { 230 | "camera_model": "OPENCV", 231 | "height": h, 232 | "width": w, 233 | "has_foreground_mask": False, 234 | "pairs": None, 235 | "worldtogt": scale_mat.tolist(), 236 | "has_mono_prior": True, 237 | "scene_box": scene_box, 238 | "frames": frames, 239 | } 240 | with open(output_dir / "meta_data.json", "w", encoding="utf-8") as f: 241 | json.dump(meta_data, f, indent=4) 242 | 243 | print(f"Done! 
The processed data has been saved in {output_dir}") 244 | 245 | 246 | if __name__ == "__main__": 247 | parser = argparse.ArgumentParser( 248 | description="preprocess nerfstudio dataset to sdfstudio dataset, " 249 | "currently support colmap and polycam" 250 | ) 251 | 252 | parser.add_argument( 253 | "--input_dir", required=True, help="path to nerfstudio data directory" 254 | ) 255 | parser.add_argument( 256 | "--output_dir", required=True, help="path to output data directory" 257 | ) 258 | parser.add_argument( 259 | "--data-type", dest="data_type", required=True, choices=["colmap", "polycam"] 260 | ) 261 | parser.add_argument( 262 | "--scene-type", 263 | dest="scene_type", 264 | required=True, 265 | choices=["indoor", "object", "unbound"], 266 | help="The scene will be normalized into a unit sphere when selecting indoor or object.", 267 | ) 268 | parser.add_argument( 269 | "--scene-scale-mult", 270 | dest="scene_scale_mult", 271 | type=float, 272 | default=None, 273 | help="The bounding box of the scene is firstly calculated by the camera positions, " 274 | "then multiply with scene_scale_mult", 275 | ) 276 | parser.add_argument( 277 | "--normal_from_pretrain", 278 | action="store_true", 279 | help="Use normal from pretrain model", 280 | ) 281 | parser.add_argument( 282 | "--normal_from_depth", action="store_true", help="Use normal from sensor depth" 283 | ) 284 | parser.add_argument( 285 | "--save-imgs", 286 | action="store_true", 287 | required=True, 288 | help="Use normal from mono depth", 289 | ) 290 | 291 | args = parser.parse_args() 292 | 293 | main(args) 294 | -------------------------------------------------------------------------------- /dn_splatter/eval/eval.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | """eval.py 3 | 4 | run with: python python dn_splatter/eval.py --data [PATH_TO_DATA] 5 | 6 | options : 7 | --no-eval-rgb 8 | --no-eval-depth 9 | --eval-faro / --no-eval-faro 10 | 11 | eval-faro option is used for reference faro scanner projected depth maps 12 | """ 13 | import json 14 | import os 15 | from pathlib import Path 16 | from typing import Optional 17 | 18 | import cv2 19 | import torch 20 | import torchvision.transforms.functional as F 21 | import tyro 22 | from rich.console import Console 23 | from rich.progress import track 24 | from torchmetrics.functional import mean_squared_error 25 | from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure 26 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 27 | 28 | from dn_splatter.metrics import DepthMetrics 29 | from dn_splatter.utils.utils import depth_path_to_tensor 30 | 31 | CONSOLE = Console(width=120) 32 | BATCH_SIZE = 20 33 | 34 | 35 | @torch.no_grad() 36 | def rgb_eval(data: Path): 37 | render_path = data / Path("rgb") # os.path.join(args.data, "/rgb") 38 | gt_path = data / Path("gt/rgb/") # os.path.join(args.data, "gt", "rgb") 39 | 40 | image_list = [f for f in os.listdir(render_path) if f.endswith(".png")] 41 | image_list = sorted(image_list, key=lambda x: int(x.split(".")[0].split("_")[-1])) 42 | 43 | mse = mean_squared_error 44 | psnr = PeakSignalNoiseRatio(data_range=1.0) 45 | ssim = StructuralSimilarityIndexMeasure(data_range=1.0, kernel_size=11) 46 | lpips = LearnedPerceptualImagePatchSimilarity() 47 | 48 | num_frames = len(image_list) 49 | 50 | psnr_score_batch = [] 51 | ssim_score_batch = [] 52 | mse_score_batch = [] 53 | lpips_score_batch = [] 54 | 55 | CONSOLE.print( 56 | f"[bold 
green]Batchifying and evaluating a total of {num_frames} rgb frames" 57 | ) 58 | 59 | for batch_index in track(range(0, num_frames, BATCH_SIZE)): 60 | CONSOLE.print( 61 | f"[bold yellow]Evaluating batch {batch_index // BATCH_SIZE} / {num_frames//BATCH_SIZE}" 62 | ) 63 | batch_frames = image_list[batch_index : batch_index + BATCH_SIZE] 64 | predicted_rgb = [] 65 | gt_rgb = [] 66 | 67 | for i in batch_frames: 68 | render_img = cv2.imread(os.path.join(render_path, i)) / 255 69 | origin_img = cv2.imread(os.path.join(gt_path, i)) / 255 70 | origin_img = F.to_tensor(origin_img).to(torch.float32) 71 | render_img = F.to_tensor(render_img).to(torch.float32) 72 | 73 | predicted_rgb.append(render_img) 74 | gt_rgb.append(origin_img) 75 | 76 | predicted_image = torch.stack(predicted_rgb, 0) 77 | gt_image = torch.stack(gt_rgb, 0) 78 | 79 | mse_score = mse(predicted_image, gt_image) 80 | mse_score_batch.append(mse_score) 81 | psnr_score = psnr(predicted_image, gt_image) 82 | psnr_score_batch.append(psnr_score) 83 | ssim_score = ssim(predicted_image, gt_image) 84 | ssim_score_batch.append(ssim_score) 85 | lpips_score = lpips(predicted_image, gt_image) 86 | lpips_score_batch.append(lpips_score) 87 | 88 | mean_scores = { 89 | "mse": float(torch.stack(mse_score_batch).mean().item()), 90 | "psnr": float(torch.stack(psnr_score_batch).mean().item()), 91 | "ssim": float(torch.stack(ssim_score_batch).mean().item()), 92 | "lpips": float(torch.stack(lpips_score_batch).mean().item()), 93 | } 94 | print(list(mean_scores.keys())) 95 | print(list(mean_scores.values())) 96 | 97 | with open(os.path.join(render_path, "metrics.json"), "w") as outFile: 98 | print(f"Saving results to {os.path.join(render_path, 'metrics.json')}") 99 | json.dump(mean_scores, outFile, indent=2) 100 | 101 | 102 | @torch.no_grad() 103 | def depth_eval(data: Path): 104 | depth_metrics = DepthMetrics() 105 | 106 | render_path = data / Path("depth/raw/") # os.path.join(args.data, "/rgb") 107 | gt_path = data / Path("gt/depth/raw") # os.path.join(args.data, "gt", "rgb") 108 | 109 | depth_list = [f for f in os.listdir(render_path) if f.endswith(".npy")] 110 | depth_list = sorted(depth_list, key=lambda x: int(x.split(".")[0].split("_")[-1])) 111 | 112 | mse = mean_squared_error 113 | 114 | num_frames = len(depth_list) 115 | 116 | mse_score_batch = [] 117 | abs_rel_score_batch = [] 118 | sq_rel_score_batch = [] 119 | rmse_score_batch = [] 120 | rmse_log_score_batch = [] 121 | a1_score_batch = [] 122 | a2_score_batch = [] 123 | a3_score_batch = [] 124 | CONSOLE.print( 125 | f"[bold green]Batchifying and evaluating a total of {num_frames} depth frames" 126 | ) 127 | 128 | for batch_index in track(range(0, num_frames, BATCH_SIZE)): 129 | CONSOLE.print( 130 | f"[bold yellow]Evaluating batch {batch_index // BATCH_SIZE} / {num_frames//BATCH_SIZE}" 131 | ) 132 | batch_frames = depth_list[batch_index : batch_index + BATCH_SIZE] 133 | predicted_depth = [] 134 | gt_depth = [] 135 | 136 | for i in batch_frames: 137 | render_img = depth_path_to_tensor( 138 | Path(os.path.join(render_path, i)) 139 | ).permute(2, 0, 1) 140 | origin_img = depth_path_to_tensor(Path(os.path.join(gt_path, i))).permute( 141 | 2, 0, 1 142 | ) 143 | 144 | if origin_img.shape[-2:] != render_img.shape[-2:]: 145 | render_img = F.resize( 146 | render_img, size=origin_img.shape[-2:], antialias=None 147 | ) 148 | predicted_depth.append(render_img) 149 | gt_depth.append(origin_img) 150 | 151 | predicted_depth = torch.stack(predicted_depth, 0) 152 | gt_depth = torch.stack(gt_depth, 0) 153 | 154 | 
mse_score = mse(predicted_depth, gt_depth) 155 | mse_score_batch.append(mse_score) 156 | (abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3) = depth_metrics( 157 | predicted_depth, gt_depth 158 | ) 159 | abs_rel_score_batch.append(abs_rel) 160 | sq_rel_score_batch.append(sq_rel) 161 | rmse_score_batch.append(rmse) 162 | rmse_log_score_batch.append(rmse_log) 163 | a1_score_batch.append(a1) 164 | a2_score_batch.append(a2) 165 | a3_score_batch.append(a3) 166 | 167 | mean_scores = { 168 | "mse": float(torch.stack(mse_score_batch).mean().item()), 169 | "abs_rel": float(torch.stack(abs_rel_score_batch).mean().item()), 170 | "sq_rel": float(torch.stack(sq_rel_score_batch).mean().item()), 171 | "rmse": float(torch.stack(rmse_score_batch).mean().item()), 172 | "rmse_log": float(torch.stack(rmse_log_score_batch).mean().item()), 173 | "a1": float(torch.stack(a1_score_batch).mean().item()), 174 | "a2": float(torch.stack(a2_score_batch).mean().item()), 175 | "a3": float(torch.stack(a3_score_batch).mean().item()), 176 | } 177 | print(list(mean_scores.keys())) 178 | print(list(mean_scores.values())) 179 | with open(os.path.join(render_path, "metrics.json"), "w") as outFile: 180 | print(f"Saving results to {os.path.join(render_path, 'metrics.json')}") 181 | json.dump(mean_scores, outFile, indent=2) 182 | 183 | 184 | def depth_eval_faro(data: Path, path_to_faro: Path): 185 | depth_metrics = DepthMetrics() 186 | 187 | render_path = data / Path("depth/raw/") 188 | gt_path = path_to_faro 189 | 190 | depth_list = [f for f in os.listdir(render_path) if f.endswith(".png")] 191 | depth_list = sorted(depth_list, key=lambda x: int(x.split(".")[0].split("_")[-1])) 192 | 193 | mse = mean_squared_error 194 | 195 | num_frames = len(depth_list) 196 | 197 | mse_score_batch = [] 198 | abs_rel_score_batch = [] 199 | sq_rel_score_batch = [] 200 | rmse_score_batch = [] 201 | rmse_log_score_batch = [] 202 | a1_score_batch = [] 203 | a2_score_batch = [] 204 | a3_score_batch = [] 205 | CONSOLE.print( 206 | f"[bold green]Batchifying and evaluating a total of {num_frames} depth frames" 207 | ) 208 | 209 | for batch_index in track(range(0, num_frames, BATCH_SIZE)): 210 | CONSOLE.print( 211 | f"[bold yellow]Evaluating batch {batch_index // BATCH_SIZE} / {num_frames//BATCH_SIZE}" 212 | ) 213 | batch_frames = depth_list[batch_index : batch_index + BATCH_SIZE] 214 | predicted_depth = [] 215 | gt_depth = [] 216 | for i in batch_frames: 217 | render_img = depth_path_to_tensor( 218 | Path(os.path.join(render_path, i)) 219 | ).permute(2, 0, 1) 220 | 221 | if not Path(os.path.join(gt_path, i)).exists(): 222 | print("could not find frame ", i, " skipping it...") 223 | continue 224 | origin_img = depth_path_to_tensor(Path(os.path.join(gt_path, i))).permute( 225 | 2, 0, 1 226 | ) 227 | if origin_img.shape[-2:] != render_img.shape[-2:]: 228 | render_img = F.resize( 229 | render_img, size=origin_img.shape[-2:], antialias=None 230 | ) 231 | predicted_depth.append(render_img) 232 | gt_depth.append(origin_img) 233 | 234 | predicted_depth = torch.stack(predicted_depth, 0) 235 | gt_depth = torch.stack(gt_depth, 0) 236 | 237 | mse_score = mse(predicted_depth, gt_depth) 238 | mse_score_batch.append(mse_score) 239 | 240 | (abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3) = depth_metrics( 241 | predicted_depth, gt_depth 242 | ) 243 | abs_rel_score_batch.append(abs_rel) 244 | sq_rel_score_batch.append(sq_rel) 245 | rmse_score_batch.append(rmse) 246 | rmse_log_score_batch.append(rmse_log) 247 | a1_score_batch.append(a1) 248 | a2_score_batch.append(a2) 249 | 
a3_score_batch.append(a3) 250 | 251 | mean_scores = { 252 | "mse": float(torch.stack(mse_score_batch).mean().item()), 253 | "abs_rel": float(torch.stack(abs_rel_score_batch).mean().item()), 254 | "sq_rel": float(torch.stack(sq_rel_score_batch).mean().item()), 255 | "rmse": float(torch.stack(rmse_score_batch).mean().item()), 256 | "rmse_log": float(torch.stack(rmse_log_score_batch).mean().item()), 257 | "a1": float(torch.stack(a1_score_batch).mean().item()), 258 | "a2": float(torch.stack(a2_score_batch).mean().item()), 259 | "a3": float(torch.stack(a3_score_batch).mean().item()), 260 | } 261 | print("faro scanner metrics") 262 | print(list(mean_scores.keys())) 263 | print(list(mean_scores.values())) 264 | 265 | 266 | def main( 267 | data: Path, 268 | eval_rgb: bool = True, 269 | eval_depth: bool = True, 270 | eval_faro: bool = False, 271 | path_to_faro: Optional[Path] = None, 272 | ): 273 | if eval_rgb: 274 | rgb_eval(data) 275 | if eval_depth: 276 | depth_eval(data) 277 | if eval_faro: 278 | assert path_to_faro is not None, "need to specify faro path" 279 | depth_eval_faro(data, path_to_faro=path_to_faro) 280 | 281 | 282 | if __name__ == "__main__": 283 | tyro.cli(main) 284 | -------------------------------------------------------------------------------- /dn_splatter/eval/eval_instructions.md: -------------------------------------------------------------------------------- 1 | # Evaluation instructions 2 | 3 | In this document we briefly describe the evaluation protocol and methods used in the DN-Splatter project. 4 | 5 | - Under `dn_splatter/eval/` you will find various scripts for rgb, depth, mesh, and normal evaluation. 6 | - Under `dn_splatter/eval/baseline_models` you will find the configuration and model files for baseline models: Nerfacto, Depth-Nerfacto, and Neusfacto as well as some convenient scripts for converting between Nerfstudio and SDFStudio dataset formats. 7 | 8 | ## Computing Evaluation Metrics 9 | We include models for evaluating rgb, depth and/or mesh metrics using various method. 10 | The following methods are supported. 11 | 12 | ### Mesh Metrics 13 | 14 | We report the following metrics 15 | ``` 16 | accuracy # lower better 17 | completeness # lower better 18 | chamferL1 # lower better 19 | normals_correctness # higher better 20 | F-score # higher better 21 | ``` 22 | 23 | Evaluate mesh metrics and cull the predicted mesh based on training camera view visibility. Regions of the mesh not seen in the training dataset are ignored when computing mesh metrics. 
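For intuition, the sketch below shows how these point-based metrics are commonly computed: sample both meshes to point clouds and take nearest-neighbour distances in both directions. It is a simplified illustration, not the exact implementation in `eval_mesh_vis_cull.py`; the point count and the 5 cm F-score threshold are assumptions, and normal consistency (which additionally compares nearest-neighbour normals) is omitted for brevity.

```python
import numpy as np
import open3d as o3d
from scipy.spatial import cKDTree


def mesh_metrics(pred_mesh_path, gt_mesh_path, n_points=200_000, threshold=0.05):
    """Accuracy / completeness / chamfer-L1 / F-score between two meshes (sketch)."""
    pred = o3d.io.read_triangle_mesh(pred_mesh_path).sample_points_uniformly(n_points)
    gt = o3d.io.read_triangle_mesh(gt_mesh_path).sample_points_uniformly(n_points)
    pred_pts, gt_pts = np.asarray(pred.points), np.asarray(gt.points)

    # accuracy: predicted -> GT nearest-neighbour distances (lower is better)
    acc = cKDTree(gt_pts).query(pred_pts)[0]
    # completeness: GT -> predicted nearest-neighbour distances (lower is better)
    comp = cKDTree(pred_pts).query(gt_pts)[0]

    precision = (acc < threshold).mean()
    recall = (comp < threshold).mean()
    fscore = 2 * precision * recall / (precision + recall + 1e-8)
    return {
        "accuracy": float(acc.mean()),
        "completeness": float(comp.mean()),
        "chamferL1": float(0.5 * (acc.mean() + comp.mean())),
        "F-score": float(fscore),
    }
```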
24 | 25 | Run with: 26 | ```bash 27 | python dn_splatter/eval/eval_mesh_vis_cull.py --path-to-pd-mesh [PATH_TO_PREDICTED_PLY] --path-to-gt-mesh [PATH_TO_GT_PLY] 28 | ``` 29 | For the MuSHRoom dataset, use: 30 | ```bash 31 | python dn_splatter/eval/eval_mesh_mushroom_vis_cull.py --path-to-pd-mesh [PATH_TO_PREDICTED_PLY] --path-to-gt-mesh [PATH_TO_GT_PLY] 32 | ``` 33 | 34 | ### RGB/Depth metrics 35 | To evaluate depth estimation and novel-view RGB metrics, run the following command: 36 | 37 | ```bash 38 | ns-eval --load-config [PATH_TO_YAML] --output-path [PATH_TO_JSON] 39 | ``` 40 | 41 | ## Running Baseline Models 42 | ## Methods 43 | ### Nerfacto 44 | RGB-only supervision 45 | ```bash 46 | ns-train gnerfacto --data [PATH] dtu/nrgbd/replica/scannet/mushroom 47 | ``` 48 | ### Depth-Nerfacto 49 | RGB and depth supervision 50 | ```bash 51 | ns-train gdepthfacto --data [PATH] dtu/nrgbd/replica/scannet/mushroom 52 | ``` 53 | ### NeusFacto 54 | RGB, depth, and normal supervision 55 | ```bash 56 | ns-train gneusfacto --data [PATH] sdfstudio-data --include-mono-prior True 57 | ``` 58 | 59 | # SDFStudio Instructions 60 | ## Train on data downloaded with ns-download-data sdfstudio 61 | SDFStudio data can be loaded correctly with the `gsdf` dataparser. 62 | If using gneusfacto, remember to set the `--load-for-sdfstudio True` flag in gsdf; it defaults to False. 63 | ```bash 64 | ns-train gneusfacto --data ./datasets/DTU/scan65/ gsdf --load-for-sdfstudio True 65 | ``` 66 | 67 | For dn_splatter, run: 68 | ```bash 69 | ns-train dn_splatter --pipeline.model.use-depth-loss False --data ./datasets/DTU/scan65 gsdf 70 | ``` 71 | 72 | ## Convert non-SDFStudio datasets to SDFStudio format 73 | To convert data from the Nerfstudio format to the SDFStudio format, run: 74 | ```bash 75 | python dn_splatter/eval/baseline_models/nerfstudio_to_sdfstudio.py --input_dir [DATA_PATH] \ 76 | --output_dir [OUTPUT_PATH] \ 77 | --data-type colmap --scene-type indoor --normal_from_pretrain/normal_from_depth 78 | ``` 79 | 80 | -------------------------------------------------------------------------------- /dn_splatter/eval/eval_normals.py: -------------------------------------------------------------------------------- 1 | from pathlib import Path 2 | 3 | import numpy as np 4 | import tyro 5 | from torch import Tensor 6 | from torch.utils.data import DataLoader, Dataset 7 | 8 | from dn_splatter.metrics import NormalMetrics 9 | from dn_splatter.utils.utils import image_path_to_tensor 10 | 11 | 12 | def tensor_to_normals(img_tensor: Tensor) -> Tensor: 13 | """ 14 | Convert image tensor [0, 1] to normal vectors [-1, 1].
15 | """ 16 | normals = img_tensor * 2 - 1 # Normalize to [-1, 1] 17 | return normals 18 | 19 | 20 | class NormalsDataset(Dataset): 21 | def __init__( 22 | self, gt_folder_path: Path, prediction_folder_path: Path, transform=None 23 | ): 24 | self.gt_rgb_norm_fnames = sorted( 25 | [f for f in gt_folder_path.rglob("*") if f.is_file() and f.suffix == ".png"] 26 | ) 27 | self.pred_rgb_norm_fnames = sorted( 28 | [ 29 | f 30 | for f in prediction_folder_path.rglob("*") 31 | if f.is_file() and f.suffix == ".png" 32 | ] 33 | ) 34 | self.transform = transform 35 | 36 | def __len__(self): 37 | return len(self.gt_rgb_norm_fnames) 38 | 39 | def __getitem__(self, idx): 40 | gt_tensor = image_path_to_tensor(self.gt_rgb_norm_fnames[idx]) 41 | pred_tensor = image_path_to_tensor(self.pred_rgb_norm_fnames[idx]) 42 | if self.transform: 43 | gt_tensor = self.transform(gt_tensor) 44 | pred_tensor = self.transform(pred_tensor) 45 | 46 | return gt_tensor, pred_tensor 47 | 48 | 49 | def main(render_path: Path): 50 | gt_folder_path = render_path / Path("gt/normal") 51 | prediction_folder_path = render_path / Path("pred/normal") 52 | data_transform = tensor_to_normals 53 | 54 | # Create the dataset and DataLoader 55 | batch_size = 5 56 | dataset = NormalsDataset( 57 | gt_folder_path, prediction_folder_path, transform=data_transform 58 | ) 59 | dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False) 60 | 61 | metrics = NormalMetrics() 62 | mae_acc, rmse_acc, mean_acc, med_acc = 0.0, 0.0, 0.0, 0.0 63 | # Access batches in a loop 64 | for batch_idx, (gt_images, prediction_images) in enumerate(dataloader): 65 | mae_run, rmse_run, mean_err_run, med_err_run = metrics( 66 | prediction_images.permute(0, 3, 1, 2), gt_images.permute(0, 3, 1, 2) 67 | ) 68 | mae_acc += mae_run 69 | rmse_acc += rmse_run 70 | mean_acc += mean_err_run 71 | med_acc += med_err_run 72 | 73 | print("Performance report (normals estimation):") 74 | print("Error (Lower is better):") 75 | print( 76 | f"MAE (rad): {mae_acc / len(dataloader)}; MAE (deg): {np.rad2deg(mae_acc / len(dataloader))}" 77 | ) 78 | print(f"RMSE: {rmse_acc / len(dataloader)}") 79 | print(f"Mean: {mean_acc / len(dataloader)}") 80 | print(f"Med: {med_acc / len(dataloader)}") 81 | print( 82 | f"{mae_acc / len(dataloader)}, {np.rad2deg(mae_acc / len(dataloader))}," 83 | f" {rmse_acc / len(dataloader)}," 84 | f" {mean_acc / len(dataloader)}" 85 | ) 86 | 87 | 88 | if __name__ == "__main__": 89 | tyro.cli(main) 90 | -------------------------------------------------------------------------------- /dn_splatter/eval/eval_pd.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | from pathlib import Path 4 | 5 | import numpy as np 6 | import open3d as o3d 7 | import tyro 8 | from scipy.spatial import cKDTree 9 | 10 | 11 | def open3d_mesh_from_trimesh(tri_mesh): 12 | vertices = np.asarray(tri_mesh.vertices) 13 | faces = np.asarray(tri_mesh.faces) 14 | 15 | # Create open3d TriangleMesh object 16 | o3d_mesh = o3d.geometry.TriangleMesh() 17 | 18 | # Assign vertices and faces to open3d mesh 19 | o3d_mesh.vertices = o3d.utility.Vector3dVector(vertices) 20 | o3d_mesh.triangles = o3d.utility.Vector3iVector(faces) 21 | return o3d_mesh 22 | 23 | 24 | def calculate_accuracy( 25 | reconstructed_points, reference_points, percentile=90 26 | ): # Calculat accuracy: How far away 90% of the reconstructed point clouds are from the reference point cloud. 
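    # Accuracy is computed with a nearest-neighbor query: a KD-tree is built over
    # the reference cloud, each reconstructed point is matched to its closest
    # reference point, and the `percentile`-th percentile of those distances is
    # reported (lower is better).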
27 | tree = cKDTree(reference_points) 28 | distances, _ = tree.query(reconstructed_points) 29 | return np.percentile(distances, percentile) 30 | 31 | 32 | def calculate_completeness( 33 | reconstructed_points, reference_points, threshold=0.05 34 | ): # calucate completeness: What percentage of the reference point cloud is within a specific distance of the reconstructed point cloud. 35 | tree = cKDTree(reconstructed_points) 36 | distances, _ = tree.query(reference_points) 37 | within_threshold = np.sum(distances < threshold) / len(distances) 38 | return within_threshold * 100 39 | 40 | 41 | def main( 42 | export_pd: Path, 43 | path_to_room: Path = Path("room_datasets/activity"), 44 | device_type: Path = Path("iphone"), 45 | evaluate_protocol: str = "within", 46 | ): 47 | # import predicted pd 48 | reconstructed_pd = o3d.io.read_point_cloud(str(export_pd)) 49 | 50 | # load training pose 51 | if evaluate_protocol == "within": 52 | within_pose = json.load( 53 | open( 54 | os.path.join( 55 | path_to_room, device_type, "long_capture", "transformations.json" 56 | ) 57 | ) 58 | ) 59 | ref_pose = within_pose["frames"][0]["transform_matrix"] 60 | with_diff_pose = json.load( 61 | open( 62 | os.path.join( 63 | path_to_room, 64 | device_type, 65 | "long_capture", 66 | "transformations_colmap.json", 67 | ) 68 | ) 69 | ) 70 | diff_pose = with_diff_pose["frames"][0]["transform_matrix"] 71 | align_transformation = np.matmul(np.linalg.inv(ref_pose), diff_pose) 72 | print(align_transformation) 73 | reconstructed_pd = reconstructed_pd.transform(align_transformation) 74 | 75 | # load the transformation matrix to convert from colmap pose to reference mesh 76 | initial_transformation = np.array( 77 | json.load( 78 | open(os.path.join(path_to_room, "icp_{}.json".format(str(device_type)))) 79 | )["gt_transformation"] 80 | ).reshape(4, 4) 81 | reconstructed_pd = reconstructed_pd.transform(initial_transformation) 82 | reconstructed_pd = reconstructed_pd.voxel_down_sample(voxel_size=0.01) 83 | # import reference pd 84 | reference_pd = o3d.io.read_point_cloud(os.path.join(path_to_room, "gt_pd.ply")) 85 | 86 | reconstructed_points = np.asarray(reconstructed_pd.points) 87 | reference_points = np.asarray(reference_pd.points) 88 | accuracy = calculate_accuracy(reconstructed_points, reference_points) 89 | completeness = calculate_completeness(reconstructed_points, reference_points) 90 | print(accuracy, completeness) 91 | 92 | 93 | if __name__ == "__main__": 94 | tyro.cli(main) 95 | -------------------------------------------------------------------------------- /dn_splatter/metrics.py: -------------------------------------------------------------------------------- 1 | """Metrics""" 2 | 3 | import numpy as np 4 | import torch 5 | from scipy.spatial import cKDTree 6 | from torch import nn 7 | from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure 8 | from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity 9 | 10 | 11 | class PDMetrics(nn.Module): 12 | """Computation of error metrics between predicted and ground truth point clouds 13 | 14 | Input: 15 | pred: predicted pd 16 | gt: ground truth pd 17 | 18 | Returns: 19 | acc 20 | comp 21 | """ 22 | 23 | def __init__(self, **kwargs): 24 | super().__init__() 25 | 26 | self.acc = calculate_accuracy 27 | self.cmp = calculate_completeness 28 | 29 | @torch.no_grad() 30 | def forward(self, pred, gt): 31 | pred_points = np.asarray(pred.points) 32 | gt_points = np.asarray(gt.points) 33 | acc_score = self.acc(pred_points, gt_points) 34 | 
cmp_score = self.cmp(pred_points, gt_points) 35 | 36 | return (acc_score, cmp_score) 37 | 38 | 39 | def calculate_accuracy(reconstructed_points, reference_points, percentile=90): 40 | """ 41 | Calculat accuracy: How far away 90% of the reconstructed point clouds are from the reference point cloud. 42 | """ 43 | tree = cKDTree(reference_points) 44 | distances, _ = tree.query(reconstructed_points) 45 | return np.percentile(distances, percentile) 46 | 47 | 48 | def calculate_completeness(reconstructed_points, reference_points, threshold=0.05): 49 | """ 50 | calucate completeness: What percentage of the reference point cloud is within 51 | a specific distance of the reconstructed point cloud. 52 | """ 53 | tree = cKDTree(reconstructed_points) 54 | distances, _ = tree.query(reference_points) 55 | within_threshold = np.sum(distances < threshold) / len(distances) 56 | return within_threshold * 100 57 | 58 | 59 | def mean_angular_error(pred: torch.Tensor, gt: torch.Tensor) -> torch.Tensor: 60 | """Compute the mean angular error between predicted and reference normals 61 | 62 | Args: 63 | predicted_normals: [B, C, H, W] tensor of predicted normals 64 | reference_normals : [B, C, H, W] tensor of gt normals 65 | 66 | Returns: 67 | mae: [B, H, W] mean angular error 68 | """ 69 | dot_products = torch.sum(gt * pred, dim=1) # over the C dimension 70 | # Clamp the dot product to ensure valid cosine values (to avoid nans) 71 | dot_products = torch.clamp(dot_products, -1.0, 1.0) 72 | # Calculate the angle between the vectors (in radians) 73 | mae = torch.acos(dot_products) 74 | return mae 75 | 76 | 77 | class RGBMetrics(nn.Module): 78 | """Computation of error metrics between predicted and ground truth images 79 | 80 | Input: 81 | pred: predicted image [B, C, H, W] 82 | gt: ground truth image [B, C, H, W] 83 | 84 | Returns: 85 | PSNR 86 | SSIM 87 | LPIPS 88 | """ 89 | 90 | def __init__(self, **kwargs): 91 | super().__init__() 92 | 93 | self.psnr = PeakSignalNoiseRatio(data_range=1.0) 94 | self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0, kernel_size=11) 95 | self.lpips = LearnedPerceptualImagePatchSimilarity() 96 | 97 | @torch.no_grad() 98 | def forward(self, pred, gt): 99 | self.device = pred.device 100 | self.psnr.to(self.device) 101 | self.ssim.to(self.device) 102 | self.lpips.to(self.device) 103 | 104 | psnr_score = self.psnr(pred, gt) 105 | ssim_score = self.ssim(pred, gt) 106 | lpips_score = self.lpips(pred, gt) 107 | 108 | return (psnr_score, ssim_score, lpips_score) 109 | 110 | 111 | class DepthMetrics(nn.Module): 112 | """Computation of error metrics between predicted and ground truth depths 113 | 114 | from: 115 | https://arxiv.org/abs/1806.01260 116 | 117 | Returns: 118 | abs_rel: normalized avg absolute realtive error 119 | sqrt_rel: normalized square-root of absolute error 120 | rmse: root mean square error 121 | rmse_log: root mean square error in log space 122 | a1, a2, a3: metrics 123 | """ 124 | 125 | def __init__(self, tolerance: float = 0.1, **kwargs): 126 | self.tolerance = tolerance 127 | super().__init__() 128 | 129 | @torch.no_grad() 130 | def forward(self, pred, gt): 131 | mask = gt > self.tolerance 132 | 133 | thresh = torch.max((gt[mask] / pred[mask]), (pred[mask] / gt[mask])) 134 | a1 = (thresh < 1.25).float().mean() 135 | a2 = (thresh < 1.25**2).float().mean() 136 | a3 = (thresh < 1.25**3).float().mean() 137 | rmse = (gt[mask] - pred[mask]) ** 2 138 | rmse = torch.sqrt(rmse.mean()) 139 | 140 | rmse_log = (torch.log(gt[mask]) - torch.log(pred[mask])) ** 2 141 | # 
rmse_log[rmse_log == float("inf")] = float("nan") 142 | rmse_log = torch.sqrt(rmse_log).nanmean() 143 | 144 | abs_rel = torch.abs(gt - pred)[mask] / gt[mask] 145 | abs_rel = abs_rel.mean() 146 | sq_rel = (gt - pred)[mask] ** 2 / gt[mask] 147 | sq_rel = sq_rel.mean() 148 | 149 | return (abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3) 150 | 151 | 152 | class NormalMetrics(nn.Module): 153 | """Computation of error metrics between predicted and ground truth normal maps. 154 | 155 | Args: 156 | predicted_normals: [B, C, H, W] tensor of predicted normals 157 | reference_normals : [B, C, H, W] tensor of gt normals 158 | 159 | Returns: 160 | All metrics are averaged over the batch 161 | mae: mean angular error 162 | rmse: root mean squared error 163 | mean: mean error 164 | med: median error 165 | """ 166 | 167 | def __init__(self, **kwargs): 168 | super().__init__() 169 | 170 | @torch.no_grad() 171 | def forward(self, pred, gt): 172 | b, c, _, _ = gt.shape 173 | # calculate MAE 174 | mae = mean_angular_error(pred, gt).mean() 175 | # calculate RMSE 176 | rmse = torch.sqrt(torch.mean(torch.square(gt - pred), dim=[1, 2, 3])).mean() 177 | # calculate Mean 178 | mean_err = torch.mean(torch.abs(gt - pred), dim=[1, 2, 3]).mean() 179 | # calculate Median 180 | med_err = torch.median( 181 | torch.abs(gt.view(b, c, -1) - pred.view(b, c, -1)) 182 | ).mean() 183 | return mae, rmse, mean_err, med_err 184 | -------------------------------------------------------------------------------- /dn_splatter/scripts/compare_normals.py: -------------------------------------------------------------------------------- 1 | import rerun as rr 2 | import numpy as np 3 | from PIL import Image 4 | from pathlib import Path 5 | from argparse import ArgumentParser 6 | 7 | 8 | def main(first_normal_dir: Path, second_normal_dir: Path): 9 | assert first_normal_dir.is_dir(), f"{first_normal_dir} is not a directory" 10 | assert second_normal_dir.is_dir(), f"{second_normal_dir} is not a directory" 11 | 12 | first_normal_glob = sorted(first_normal_dir.glob("*.png")) 13 | second_normal_glob = sorted(second_normal_dir.glob("*.png")) 14 | assert len(first_normal_glob) > 0, f"No normal images found in {first_normal_dir}" 15 | assert len(second_normal_glob) > 0, f"No normal images found in {second_normal_dir}" 16 | for idx, (first_path, second_path) in enumerate( 17 | zip(first_normal_glob, second_normal_glob) 18 | ): 19 | rr.set_time_sequence("idx", idx) 20 | first_normal = np.array(Image.open(first_path), dtype="uint8")[..., :3] 21 | second_normal = np.array(Image.open(second_path), dtype="uint8")[..., :3] 22 | rr.log("first_normal", rr.Image(first_normal)) 23 | rr.log("second_normal", rr.Image(second_normal)) 24 | 25 | 26 | if __name__ == "__main__": 27 | parser = ArgumentParser("Surface Normal Comparison") 28 | parser.add_argument( 29 | "first_normal_dir", 30 | type=Path, 31 | help="Path to the first directory containing surface normal predictions", 32 | ) 33 | parser.add_argument( 34 | "second_normal_dir", 35 | type=Path, 36 | help="Path to the second directory containing surface normal predictions", 37 | ) 38 | rr.script_add_args(parser) 39 | args = parser.parse_args() 40 | rr.script_setup(args, "compare_normals") 41 | main( 42 | first_normal_dir=args.first_normal_dir, second_normal_dir=args.second_normal_dir 43 | ) 44 | rr.script_teardown(args) 45 | -------------------------------------------------------------------------------- /dn_splatter/scripts/comparison_video.sh: 
-------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -eux 3 | 4 | INPUT_BASELINE="$1" 5 | INPUT_OURS="$2" 6 | OUTPUT="$3" 7 | 8 | #: "${VIDEO_MODE:=HALF}" 9 | : "${VIDEO_MODE:=SWEEP}" 10 | 11 | : "${DRAW_TEXT:=ON}" 12 | : "${DRAW_BAR:=ON}" 13 | : "${CROP_TO_HD_ASPECT:=OFF}" 14 | 15 | if [ $CROP_TO_HD_ASPECT == "ON" ]; then 16 | BASE_FILTER=" 17 | [0:v]crop=iw:'min(ih,iw/16*9)'[base];\ 18 | [1:v]crop=iw:'min(ih,iw/16*9)'[ours]" 19 | else 20 | BASE_FILTER="[0:v]copy[base];[1:v]copy[ours]" 21 | fi 22 | if [ $DRAW_TEXT == "ON" ]; then 23 | BASE_FILTER=" 24 | $BASE_FILTER;\ 25 | [base]drawtext=text='Splatfacto':fontcolor=white:fontsize=h/50:x=w/50:y=h/50[base];\ 26 | [ours]drawtext=text='DN-Splatter':fontcolor=white:fontsize=h/50:x=w-tw-w/50:y=h/50[ours]" 27 | fi 28 | if [ $DRAW_BAR == "ON" ]; then 29 | BASE_FILTER=" 30 | $BASE_FILTER;\ 31 | color=0x80ff80,format=rgba[bar];\ 32 | [bar][base]scale2ref[bar][base];\ 33 | [bar]crop=iw:ih/200:0:0[bar];\ 34 | [ours][bar]overlay=x=0:y=0[ours]" 35 | fi 36 | 37 | case $VIDEO_MODE in 38 | HALF) 39 | VIDEO_FILTER=" 40 | $BASE_FILTER;\ 41 | [base]crop=iw/2:ih:0:0[left_crop];\ 42 | [ours]crop=iw/2:ih:iw/2:0[right_crop];\ 43 | [left_crop][right_crop]hstack" 44 | ;; 45 | 46 | SWEEP) 47 | LEN=8 48 | VIDEO_FILTER=" 49 | $BASE_FILTER;\ 50 | color=0x00000000,format=rgba,scale=[black];\ 51 | color=0xffffffff,format=rgba[white];\ 52 | [black][base]scale2ref[black][base];\ 53 | [white][base]scale2ref[white][base];\ 54 | [white][black]blend=all_expr='if(lte(X,W*abs(1-mod(T,$LEN)/$LEN*2)),B,A)'[mask];\ 55 | [ours][mask]alphamerge[overlayalpha]; \ 56 | [base][overlayalpha]overlay=shortest=1" 57 | ;; 58 | 59 | *) 60 | echo -n "unknown video mode $VIDEO_MODE" 61 | exit 1 62 | ;; 63 | esac 64 | 65 | ffmpeg -i "$INPUT_BASELINE" -i "$INPUT_OURS" -filter_complex "$VIDEO_FILTER" -hide_banner -y "$OUTPUT" -------------------------------------------------------------------------------- /dn_splatter/scripts/convert_colmap.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import os 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Literal 6 | 7 | import tyro 8 | 9 | """ 10 | 11 | |---image_path 12 | | |--- 13 | | |--- 14 | | |---... 
15 | |---colmap 16 | |---sparse 17 | |---0 18 | |---cameras.bin 19 | |---images.bin 20 | |---points3D.bin 21 | """ 22 | 23 | 24 | @dataclass 25 | class ConvertColmap: 26 | """Convert images to COLMAP format""" 27 | 28 | image_path: Path 29 | """Input to images folder""" 30 | use_gpu: bool = True 31 | """Whether to use_gpu with colmap""" 32 | skip_matching: bool = False 33 | """Skip matching""" 34 | skip_undistortion: bool = True 35 | """Skip undistorting images""" 36 | camera: Literal["OPENCV"] = "OPENCV" 37 | """Camera type""" 38 | resize: bool = False 39 | """Resize images""" 40 | 41 | def main(self): 42 | image_path = str(self.image_path.resolve()) 43 | use_gpu = 1 if self.use_gpu else 0 44 | colmap_command = "colmap" 45 | 46 | base_dir = str(Path(image_path).parent) 47 | 48 | if not self.skip_matching: 49 | os.makedirs(base_dir + "/colmap/sparse", exist_ok=True) 50 | ## Feature extraction 51 | feat_extracton_cmd = ( 52 | colmap_command + " feature_extractor " 53 | "--database_path " 54 | + base_dir 55 | + "/colmap/database.db \ 56 | --image_path " 57 | + image_path 58 | + " \ 59 | --ImageReader.single_camera 1 \ 60 | --ImageReader.camera_model " 61 | + self.camera 62 | + " \ 63 | --SiftExtraction.use_gpu " 64 | + str(use_gpu) 65 | ) 66 | exit_code = os.system(feat_extracton_cmd) 67 | if exit_code != 0: 68 | logging.error( 69 | f"Feature extraction failed with code {exit_code}. Exiting." 70 | ) 71 | exit(exit_code) 72 | 73 | ## Feature matching 74 | feat_matching_cmd = ( 75 | colmap_command 76 | + " exhaustive_matcher \ 77 | --database_path " 78 | + base_dir 79 | + "/colmap/database.db \ 80 | --SiftMatching.use_gpu " 81 | + str(use_gpu) 82 | ) 83 | exit_code = os.system(feat_matching_cmd) 84 | if exit_code != 0: 85 | logging.error( 86 | f"Feature matching failed with code {exit_code}. Exiting." 87 | ) 88 | exit(exit_code) 89 | 90 | ### Bundle adjustment 91 | # The default Mapper tolerance is unnecessarily large, 92 | # decreasing it speeds up bundle adjustment steps. 93 | mapper_cmd = ( 94 | colmap_command 95 | + " mapper \ 96 | --database_path " 97 | + base_dir 98 | + "/colmap/database.db \ 99 | --image_path " 100 | + image_path 101 | + " \ 102 | --output_path " 103 | + base_dir 104 | + "/colmap/sparse \ 105 | --Mapper.ba_global_function_tolerance=0.000001" 106 | ) 107 | exit_code = os.system(mapper_cmd) 108 | if exit_code != 0: 109 | logging.error(f"Mapper failed with code {exit_code}. Exiting.") 110 | exit(exit_code) 111 | 112 | ### Image undistortion 113 | if not self.skip_undistortion: 114 | img_undist_cmd = ( 115 | colmap_command 116 | + " image_undistorter \ 117 | --image_path " 118 | + image_path 119 | + " \ 120 | --input_path " 121 | + base_dir 122 | + "/colmap/sparse/0 \ 123 | --output_path " 124 | + base_dir 125 | + "\ 126 | --output_type COLMAP" 127 | ) 128 | exit_code = os.system(img_undist_cmd) 129 | if exit_code != 0: 130 | logging.error(f"Mapper failed with code {exit_code}. 
Exiting.") 131 | exit(exit_code) 132 | 133 | # TODO: move files to new destination folder 134 | # files = os.listdir(base_dir + "/sparse") 135 | # os.makedirs(base_dir + "/sparse/0", exist_ok=True) 136 | ## Copy each file from the source directory to the destination directory 137 | # for file in files: 138 | # if file == "0": 139 | # continue 140 | # source_file = os.path.join(base_dir, "sparse", file) 141 | # destination_file = os.path.join(base_dir, "sparse", "0", file) 142 | # shutil.move(source_file, destination_file) 143 | 144 | if self.resize: 145 | raise NotImplementedError 146 | 147 | 148 | if __name__ == "__main__": 149 | tyro.cli(ConvertColmap).main() 150 | -------------------------------------------------------------------------------- /dn_splatter/scripts/depth_normal_consistency.py: -------------------------------------------------------------------------------- 1 | """Check the normal consistency between the normals from the pre-trained model and the normals from the depth map.""" 2 | 3 | import open3d as o3d 4 | import numpy as np 5 | from pathlib import Path 6 | from dataclasses import dataclass 7 | import os 8 | import cv2 9 | from rich.console import Console 10 | from rich.progress import track 11 | from natsort import natsorted 12 | from copy import deepcopy 13 | from nerfstudio.utils.io import load_from_json 14 | from typing import Literal 15 | import tyro 16 | 17 | from PIL import Image 18 | 19 | CONSOLE = Console(width=120) 20 | OPENGL_TO_OPENCV = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 21 | SCALE_FACTOR = 0.001 22 | 23 | 24 | def get_camera_coords(img_size: tuple, pixel_offset: float = 0.5) -> np.ndarray: 25 | image_coords = np.meshgrid( 26 | np.arange(img_size[0]), 27 | np.arange(img_size[1]), 28 | indexing="xy", # W = u by H = v 29 | ) 30 | image_coords = ( 31 | np.stack(image_coords, axis=-1) + pixel_offset 32 | ) # stored as (x, y) coordinates 33 | image_coords = image_coords.reshape(-1, 2) 34 | image_coords = image_coords.astype(np.float32) 35 | return image_coords 36 | 37 | 38 | def backproject( 39 | depths: np.ndarray, 40 | fx: float, 41 | fy: float, 42 | cx: int, 43 | cy: int, 44 | img_size: tuple, 45 | c2w: np.ndarray, 46 | ): 47 | if depths.ndim == 3: 48 | depths = depths.reshape(-1, 1) 49 | elif depths.shape[-1] != 1: 50 | depths = depths[..., np.newaxis] 51 | depths = depths.reshape(-1, 1) 52 | 53 | image_coords = get_camera_coords(img_size) 54 | 55 | means3d = np.zeros([img_size[0], img_size[1], 3], dtype=np.float32).reshape(-1, 3) 56 | means3d[:, 0] = (image_coords[:, 0] - cx) * depths[:, 0] / fx # x 57 | means3d[:, 1] = (image_coords[:, 1] - cy) * depths[:, 0] / fy # y 58 | means3d[:, 2] = depths[:, 0] # z 59 | 60 | # to world coords 61 | means3d = means3d @ np.linalg.inv(c2w[..., :3, :3]) + c2w[..., :3, 3] 62 | return means3d, image_coords 63 | 64 | 65 | def compute_angle_between_normals(normal_map1, normal_map2): 66 | norm1 = np.linalg.norm(normal_map1, axis=2, keepdims=True) 67 | norm2 = np.linalg.norm(normal_map2, axis=2, keepdims=True) 68 | normal_map1_normalized = normal_map1 / norm1 69 | normal_map2_normalized = normal_map2 / norm2 70 | 71 | dot_product = np.sum(normal_map1_normalized * normal_map2_normalized, axis=2) 72 | 73 | dot_product = np.clip(dot_product, -1.0, 1.0) 74 | 75 | angles = np.arccos(dot_product) 76 | 77 | angles_degrees = np.degrees(angles) 78 | 79 | return angles_degrees 80 | 81 | 82 | def depth_path_to_array( 83 | depth_path: Path, scale_factor: float = SCALE_FACTOR, return_color=False 84 | ) -> 
np.ndarray: 85 | if depth_path.suffix == ".png": 86 | depth = cv2.imread(str(depth_path.absolute()), cv2.IMREAD_ANYDEPTH) 87 | elif depth_path.suffix == ".npy": 88 | depth = np.load(depth_path, allow_pickle=True) 89 | if len(depth.shape) == 3: 90 | depth = depth[..., 0] 91 | else: 92 | raise Exception(f"Format is not supported {depth_path.suffix}") 93 | depth = depth * scale_factor 94 | depth = depth.astype(np.float32) 95 | return depth 96 | 97 | 98 | @dataclass 99 | class DepthNormalConsistency: 100 | """ 101 | Check the normal consistency between the normals from the pre-trained model and the normals from the depth map, 102 | generate depth confidence mask based on the normal consistency. 103 | """ 104 | 105 | data_dir: Path = Path("dataset/room_datasets/vr_room/iphone/long_capture") 106 | """Path to data root""" 107 | transforms_name: str = "transformations_colmap.json" 108 | """transforms file name""" 109 | normal_format: Literal["omnidata", "dsine"] = "omnidata" 110 | """Model type for the pre-trained normals. Omnidata and dsine have different coordinate frames, so need to deal with this.""" 111 | angle_treshold: float = 20.0 112 | """Angle treshold for normal consistency check. Differences bigger than this threshold will be considered as inconsistent and masked.""" 113 | 114 | def main(self): 115 | if os.path.exists(os.path.join(self.data_dir, self.transforms_name)): 116 | CONSOLE.log(f"Found path to {self.transforms_name}") 117 | else: 118 | raise Exception(f"Could not find {self.transforms_name}") 119 | 120 | output_normal_path = os.path.join(self.data_dir, "depth_normals") 121 | output_mask_path = os.path.join(self.data_dir, "depth_normals_mask") 122 | 123 | os.makedirs(output_normal_path, exist_ok=True) 124 | os.makedirs(output_mask_path, exist_ok=True) 125 | 126 | mono_normal_path = os.path.join(self.data_dir, "normals_from_pretrain") 127 | 128 | transforms = load_from_json(self.data_dir / Path(self.transforms_name)) 129 | assert "frames" in transforms 130 | sorted_frames = natsorted(transforms["frames"], key=lambda x: x["file_path"]) 131 | converted_json = deepcopy(transforms) 132 | converted_json["frames"] = deepcopy(sorted_frames) 133 | num_frames = len(sorted_frames) 134 | CONSOLE.log(f"{num_frames} frames to process ...") 135 | if "fl_x" in transforms: 136 | fx = transforms["fl_x"] 137 | fy = transforms["fl_y"] 138 | cx = transforms["cx"] 139 | cy = transforms["cy"] 140 | h = transforms["h"] 141 | w = transforms["w"] 142 | else: 143 | fx = transforms["frames"][0]["fl_x"] 144 | fy = transforms["frames"][0]["fl_y"] 145 | cx = transforms["frames"][0]["cx"] 146 | cy = transforms["frames"][0]["cy"] 147 | h = transforms["frames"][0]["h"] 148 | w = transforms["frames"][0]["w"] 149 | 150 | for i in track(range(num_frames), description="Processing frames..."): 151 | c2w_ref = np.array(sorted_frames[i]["transform_matrix"]) 152 | if c2w_ref.shape[0] != 4: 153 | c2w_ref = np.concatenate([c2w_ref, np.array([[0, 0, 0, 1]])], axis=0) 154 | c2w_ref = c2w_ref @ OPENGL_TO_OPENCV 155 | depth_i = depth_path_to_array( 156 | self.data_dir / Path(sorted_frames[i]["depth_file_path"]) 157 | ) 158 | depth_i = cv2.resize(depth_i, (w, h), interpolation=cv2.INTER_NEAREST) 159 | means3d, image_coords = backproject( 160 | depths=depth_i, 161 | fx=fx, 162 | fy=fy, 163 | cx=cx, 164 | cy=cy, 165 | img_size=(w, h), 166 | c2w=c2w_ref, 167 | ) 168 | cam_center = c2w_ref[:3, 3] 169 | pcd = o3d.geometry.PointCloud() 170 | pcd.points = o3d.utility.Vector3dVector(means3d) 171 | pcd.estimate_normals( 172 | 
search_param=o3d.geometry.KDTreeSearchParamKNN(knn=200) 173 | ) 174 | normals_from_depth = np.array(pcd.normals) 175 | 176 | # check normal direction: if ray dir and normal angle is smaller than 90, reverse normal 177 | ray_dir = means3d - cam_center.reshape(1, 3) 178 | normal_dir_not_correct = (ray_dir * normals_from_depth).sum(axis=-1) > 0 179 | normals_from_depth[normal_dir_not_correct] = -normals_from_depth[ 180 | normal_dir_not_correct 181 | ] 182 | 183 | normals_from_depth = normals_from_depth.reshape(h, w, 3) 184 | # save images of normals_from_depth for visualization 185 | name = sorted_frames[i]["file_path"].split("/")[-1] 186 | save_name = name.replace("png", "jpg") 187 | cv2.imwrite( 188 | os.path.join(output_normal_path, save_name), 189 | ((normals_from_depth + 1) / 2 * 255).astype(np.uint8), 190 | ) 191 | 192 | # load mono normals 193 | mono_normal = Image.open( 194 | os.path.join(mono_normal_path, name.replace("jpg", "png")) 195 | ) 196 | mono_normal = np.array(mono_normal) / 255.0 197 | h, w, _ = mono_normal.shape 198 | # mono_normals are saved in [0,1] range, but need to be converted to [-1,1] 199 | mono_normal = 2 * mono_normal - 1 200 | 201 | if self.normal_format == "dsine": 202 | # convert normal map coordinate frame 203 | mono_normal = mono_normal.reshape(-1, 3) 204 | mono_normal = mono_normal @ np.diag([1, -1, -1]) 205 | mono_normal = mono_normal.reshape(h, w, 3) 206 | 207 | # convert mono normals from camera frame to world coordinate frame, same as normals_from_depth 208 | w2c = np.linalg.inv(c2w_ref) 209 | R = np.transpose( 210 | w2c[:3, :3] 211 | ) # R is stored transposed due to 'glm' in CUDA code 212 | T = w2c[:3, 3] 213 | mono_normal = mono_normal.reshape(-1, 3).transpose(1, 0) 214 | mono_normal = (R @ mono_normal).T 215 | mono_normal = mono_normal / np.linalg.norm( 216 | mono_normal, axis=1, keepdims=True 217 | ) 218 | mono_normal = mono_normal.reshape(h, w, 3) 219 | 220 | # compute angle between normals_from_depth and mono_normal 221 | degree_map = compute_angle_between_normals(normals_from_depth, mono_normal) 222 | mask = (degree_map > self.angle_treshold).astype(np.uint8) 223 | cv2.imwrite( 224 | os.path.join(output_mask_path, save_name), 225 | mask * 255.0, 226 | ) 227 | 228 | 229 | if __name__ == "__main__": 230 | tyro.cli(DepthNormalConsistency).main() 231 | -------------------------------------------------------------------------------- /dn_splatter/scripts/depth_to_normal.py: -------------------------------------------------------------------------------- 1 | """Script to generate normal estimates from raw sensor depth readings. 2 | 3 | The method uses KNN nearest points to estimate a surface normal. 
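In short: each sensor depth map is back-projected into a world-space point cloud
using the camera intrinsics and pose, Open3D estimates per-point normals from the
K nearest neighbors, normals facing away from the camera are flipped, and the
result is saved as a color-encoded normal image together with a mask flagging
pixels whose depth normal disagrees with the pretrained monocular normal.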
4 | 5 | """ 6 | 7 | import json 8 | import os 9 | from copy import deepcopy 10 | from dataclasses import dataclass 11 | from pathlib import Path 12 | 13 | import cv2 14 | import numpy as np 15 | import open3d as o3d 16 | import tyro 17 | from natsort import natsorted 18 | from PIL import Image 19 | from rich.console import Console 20 | from tqdm import tqdm 21 | 22 | CONSOLE = Console(width=120) 23 | OPENGL_TO_OPENCV = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 24 | SCALE_FACTOR = 0.001 25 | 26 | 27 | def get_camera_coords(img_size: tuple, pixel_offset: float = 0.5) -> np.ndarray: 28 | image_coords = np.meshgrid( 29 | np.arange(img_size[0]), 30 | np.arange(img_size[1]), 31 | indexing="xy", # W = u by H = v 32 | ) 33 | image_coords = ( 34 | np.stack(image_coords, axis=-1) + pixel_offset 35 | ) # stored as (x, y) coordinates 36 | image_coords = image_coords.reshape(-1, 2) 37 | image_coords = image_coords.astype(np.float32) 38 | return image_coords 39 | 40 | 41 | def backproject( 42 | depths: np.ndarray, 43 | fx: float, 44 | fy: float, 45 | cx: int, 46 | cy: int, 47 | img_size: tuple, 48 | c2w: np.ndarray, 49 | ): 50 | if depths.ndim == 3: 51 | depths = depths.reshape(-1, 1) 52 | elif depths.shape[-1] != 1: 53 | depths = depths[..., np.newaxis] 54 | depths = depths.reshape(-1, 1) 55 | 56 | image_coords = get_camera_coords(img_size) 57 | 58 | means3d = np.zeros([img_size[0], img_size[1], 3], dtype=np.float32).reshape(-1, 3) 59 | means3d[:, 0] = (image_coords[:, 0] - cx) * depths[:, 0] / fx # x 60 | means3d[:, 1] = (image_coords[:, 1] - cy) * depths[:, 0] / fy # y 61 | means3d[:, 2] = depths[:, 0] # z 62 | 63 | # to world coords 64 | means3d = means3d @ np.linalg.inv(c2w[..., :3, :3]) + c2w[..., :3, 3] 65 | return means3d, image_coords 66 | 67 | 68 | def compute_angle_between_normals(normal_map1, normal_map2): 69 | norm1 = np.linalg.norm(normal_map1, axis=2, keepdims=True) 70 | norm2 = np.linalg.norm(normal_map2, axis=2, keepdims=True) 71 | normal_map1_normalized = normal_map1 / norm1 72 | normal_map2_normalized = normal_map2 / norm2 73 | 74 | dot_product = np.sum(normal_map1_normalized * normal_map2_normalized, axis=2) 75 | 76 | dot_product = np.clip(dot_product, -1.0, 1.0) 77 | 78 | angles = np.arccos(dot_product) 79 | 80 | angles_degrees = np.degrees(angles) 81 | 82 | return angles_degrees 83 | 84 | 85 | def load_from_json(filename: Path): 86 | assert filename.suffix == ".json" 87 | with open(filename, encoding="UTF-8") as file: 88 | return json.load(file) 89 | 90 | 91 | def depth_path_to_array( 92 | depth_path: Path, scale_factor: float = SCALE_FACTOR, return_color=False 93 | ) -> np.ndarray: 94 | if depth_path.suffix == ".png": 95 | depth = cv2.imread(str(depth_path.absolute()), cv2.IMREAD_ANYDEPTH) 96 | elif depth_path.suffix == ".npy": 97 | depth = np.load(depth_path, allow_pickle=True) 98 | if len(depth.shape) == 3: 99 | depth = depth[..., 0] 100 | else: 101 | raise Exception(f"Format is not supported {depth_path.suffix}") 102 | depth = depth * scale_factor 103 | depth = depth.astype(np.float32) 104 | return depth 105 | 106 | 107 | @dataclass 108 | class DepthToNormal: 109 | data_dir: Path = None 110 | """Path to data root""" 111 | transforms_name: str = "transformations_colmap.json" 112 | """transforms file name""" 113 | 114 | def main(self): 115 | if os.path.exists(os.path.join(self.data_dir, self.transforms_name)): 116 | CONSOLE.log(f"Found path to {self.transforms_name}") 117 | else: 118 | raise Exception(f"Could not find {self.transforms_name}") 119 | 120 
| output_normal_path = os.path.join(self.data_dir, "depth_normals") 121 | output_mask_path = os.path.join(self.data_dir, "depth_normals_mask") 122 | 123 | os.makedirs(output_normal_path, exist_ok=True) 124 | os.makedirs(output_mask_path, exist_ok=True) 125 | 126 | mono_normal_path = os.path.join(self.data_dir, "normals_from_pretrain") 127 | 128 | transforms = load_from_json(self.data_dir / Path(self.transforms_name)) 129 | assert "frames" in transforms 130 | sorted_frames = natsorted(transforms["frames"], key=lambda x: x["file_path"]) 131 | converted_json = deepcopy(transforms) 132 | converted_json["frames"] = deepcopy(sorted_frames) 133 | num_frames = len(sorted_frames) 134 | CONSOLE.log(f"{num_frames} frames to process ...") 135 | if "fl_x" in transforms: 136 | fx = transforms["fl_x"] 137 | fy = transforms["fl_y"] 138 | cx = transforms["cx"] 139 | cy = transforms["cy"] 140 | h = transforms["h"] 141 | w = transforms["w"] 142 | else: 143 | fx = transforms["frames"][0]["fl_x"] 144 | fy = transforms["frames"][0]["fl_y"] 145 | cx = transforms["frames"][0]["cx"] 146 | cy = transforms["frames"][0]["cy"] 147 | h = transforms["frames"][0]["h"] 148 | w = transforms["frames"][0]["w"] 149 | 150 | for i in tqdm(range(num_frames), desc="processing ..."): 151 | c2w_ref = np.array(sorted_frames[i]["transform_matrix"]) 152 | if c2w_ref.shape[0] != 4: 153 | c2w_ref = np.concatenate([c2w_ref, np.array([[0, 0, 0, 1]])], axis=0) 154 | c2w_ref = c2w_ref @ OPENGL_TO_OPENCV 155 | depth_i = depth_path_to_array( 156 | self.data_dir / Path(sorted_frames[i]["depth_file_path"]) 157 | ) 158 | depth_i = cv2.resize(depth_i, (w, h), interpolation=cv2.INTER_NEAREST) 159 | means3d, image_coords = backproject( 160 | depths=depth_i, 161 | fx=fx, 162 | fy=fy, 163 | cx=cx, 164 | cy=cy, 165 | img_size=(w, h), 166 | c2w=c2w_ref, 167 | ) 168 | cam_center = c2w_ref[:3, 3] 169 | pcd = o3d.geometry.PointCloud() 170 | pcd.points = o3d.utility.Vector3dVector(means3d) 171 | pcd.estimate_normals( 172 | search_param=o3d.geometry.KDTreeSearchParamKNN(knn=200) 173 | ) 174 | normals = np.array(pcd.normals) 175 | 176 | # check normal direction: if ray dir and normal angle is smaller than 90, reverse normal 177 | ray_dir = means3d - cam_center.reshape(1, 3) 178 | normal_dir_not_correct = (ray_dir * normals).sum(axis=-1) > 0 179 | normals[normal_dir_not_correct] = -normals[normal_dir_not_correct] 180 | 181 | normals = normals.reshape(h, w, 3) 182 | # color normal 183 | normals = (normals + 1) / 2 184 | saved_normals = (normals * 255).astype(np.uint8) 185 | name = sorted_frames[i]["file_path"].split("/")[-1] 186 | 187 | cv2.imwrite( 188 | os.path.join(output_normal_path, name), 189 | saved_normals, 190 | ) 191 | 192 | mono_normal = Image.open( 193 | os.path.join(mono_normal_path, name.replace("jpg", "png")) 194 | ) 195 | mono_normal = np.array(mono_normal) / 255.0 196 | h, w, _ = mono_normal.shape 197 | w2c = np.linalg.inv(c2w_ref) 198 | R = np.transpose( 199 | w2c[:3, :3] 200 | ) # R is stored transposed due to 'glm' in CUDA code 201 | T = w2c[:3, 3] 202 | normal = mono_normal.reshape(-1, 3).transpose(1, 0) 203 | normal = (normal - 0.5) * 2 204 | normal = (R @ normal).T 205 | normal = normal / np.linalg.norm(normal, axis=1, keepdims=True) 206 | mono_normal = normal.reshape(h, w, 3) * 0.5 + 0.5 207 | 208 | degree_map = compute_angle_between_normals(normals, mono_normal) 209 | mask = (degree_map > 10).astype(np.uint8) 210 | cv2.imwrite( 211 | os.path.join(output_mask_path, name), 212 | mask * 255.0, 213 | ) 214 | 215 | # some useful debug 216 | 
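            # (left commented out) example of saving a filtered depth map in which
            # pixels failing the normal-consistency check are zeroed out; note that
            # `output_filtered_depth_path` is not created above and would need to be
            # defined first.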
217 | # filtered_depth = np.where(mask == 0, depth_i, 0) 218 | # np.save( 219 | # os.path.join(output_filtered_depth_path, name.replace("jpg", "npy")), 220 | # filtered_depth, 221 | # ) 222 | 223 | 224 | if __name__ == "__main__": 225 | tyro.cli(DepthToNormal).main() 226 | -------------------------------------------------------------------------------- /dn_splatter/scripts/dsine/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/maturk/dn-splatter/249d52c4bb14b7bf6dd18d7d66099a36eac2ee78/dn_splatter/scripts/dsine/__init__.py -------------------------------------------------------------------------------- /dn_splatter/scripts/dsine/dsine_predictor.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | from typing import Literal, Optional 5 | from torchvision import transforms 6 | from jaxtyping import UInt8, Float 7 | import torch.nn.functional as F 8 | 9 | from dn_splatter.scripts.dsine.dsine import DSINE 10 | 11 | 12 | def pad_input(original_height: int, original_width: int): 13 | if original_width % 32 == 0: 14 | left = 0 15 | right = 0 16 | else: 17 | new_width = 32 * ((original_width // 32) + 1) 18 | left = (new_width - original_width) // 2 19 | right = (new_width - original_width) - left 20 | 21 | if original_height % 32 == 0: 22 | top = 0 23 | bottom = 0 24 | else: 25 | new_height = 32 * ((original_height // 32) + 1) 26 | top = (new_height - original_height) // 2 27 | bottom = (new_height - original_height) - top 28 | return left, right, top, bottom 29 | 30 | 31 | def get_intrins_from_fov(new_fov, height, width, device): 32 | # NOTE: top-left pixel should be (0,0) 33 | if width >= height: 34 | new_fu = (width / 2.0) / np.tan(np.deg2rad(new_fov / 2.0)) 35 | new_fv = (width / 2.0) / np.tan(np.deg2rad(new_fov / 2.0)) 36 | else: 37 | new_fu = (height / 2.0) / np.tan(np.deg2rad(new_fov / 2.0)) 38 | new_fv = (height / 2.0) / np.tan(np.deg2rad(new_fov / 2.0)) 39 | 40 | new_cu = (width / 2.0) - 0.5 41 | new_cv = (height / 2.0) - 0.5 42 | 43 | new_intrins = torch.tensor( 44 | [[new_fu, 0, new_cu], [0, new_fv, new_cv], [0, 0, 1]], 45 | dtype=torch.float32, 46 | device=device, 47 | ) 48 | 49 | return new_intrins 50 | 51 | 52 | def _load_state_dict(local_file_path: Optional[str] = None): 53 | if local_file_path is not None and os.path.exists(local_file_path): 54 | # Load state_dict from local file 55 | state_dict = torch.load(local_file_path, map_location=torch.device("cpu")) 56 | else: 57 | # Load state_dict from the default URL 58 | file_name = "dsine.pt" 59 | url = f"https://huggingface.co/camenduru/DSINE/resolve/main/dsine.pt" 60 | state_dict = torch.hub.load_state_dict_from_url( 61 | url, file_name=file_name, map_location=torch.device("cpu") 62 | ) 63 | 64 | return state_dict["model"] 65 | 66 | 67 | class DSinePredictor: 68 | def __init__(self, device: Literal["cpu", "cuda"]): 69 | self.device = device 70 | self.transform = transforms.Normalize( 71 | mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] 72 | ) 73 | self.model = self.load_model() 74 | 75 | def load_model(self): 76 | state_dict = _load_state_dict(None) 77 | model = DSINE() 78 | model.load_state_dict(state_dict, strict=True) 79 | model.eval() 80 | model = model.to(self.device) 81 | model.pixel_coords = model.pixel_coords.to(self.device) 82 | 83 | return model 84 | 85 | def __call__( 86 | self, 87 | rgb: UInt8[np.ndarray, "h w 3"], 88 | K_33: 
Optional[Float[np.ndarray, "3 3"]] = None, 89 | ) -> Float[torch.Tensor, "b 3 h w"]: 90 | rgb = rgb.astype(np.float32) / 255.0 91 | rgb = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).to(self.device) 92 | _, _, h, w = rgb.shape 93 | 94 | # zero-pad the input image so that both the width and height are multiples of 32 95 | left, right, top, bottom = pad_input(h, w) 96 | rgb = F.pad(rgb, (left, right, top, bottom), mode="constant", value=0.0) 97 | rgb = self.transform(rgb) 98 | 99 | if K_33 is None: 100 | K_b33: Float[torch.Tensor, "b 3 3"] = get_intrins_from_fov( 101 | new_fov=60.0, height=h, width=w, device=self.device 102 | ).unsqueeze(0) 103 | else: 104 | K_b33 = torch.from_numpy(K_33).unsqueeze(0).to(self.device) 105 | 106 | # add padding to intrinsics 107 | K_b33[:, 0, 2] += left 108 | K_b33[:, 1, 2] += top 109 | 110 | with torch.no_grad(): 111 | normal_b3hw: Float[torch.Tensor, "b 3 h-t w-l"] = self.model( 112 | rgb, intrins=K_b33 113 | )[-1] 114 | normal_b3hw: Float[torch.Tensor, "b 3 h w"] = normal_b3hw[ 115 | :, :, top : top + h, left : left + w 116 | ] 117 | 118 | return normal_b3hw 119 | -------------------------------------------------------------------------------- /dn_splatter/scripts/dsine/rotations.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # NOTE: from PyTorch3D 5 | def axis_angle_to_quaternion(axis_angle: torch.Tensor) -> torch.Tensor: 6 | """ 7 | Convert rotations given as axis/angle to quaternions. 8 | 9 | Args: 10 | axis_angle: Rotations given as a vector in axis angle form, 11 | as a tensor of shape (..., 3), where the magnitude is 12 | the angle turned anticlockwise in radians around the 13 | vector's direction. 14 | 15 | Returns: 16 | quaternions with real part first, as tensor of shape (..., 4). 17 | """ 18 | angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True) 19 | half_angles = angles * 0.5 20 | eps = 1e-6 21 | small_angles = angles.abs() < eps 22 | sin_half_angles_over_angles = torch.empty_like(angles) 23 | sin_half_angles_over_angles[~small_angles] = ( 24 | torch.sin(half_angles[~small_angles]) / angles[~small_angles] 25 | ) 26 | # for x small, sin(x/2) is about x/2 - (x/2)^3/6 27 | # so sin(x/2)/x is about 1/2 - (x*x)/48 28 | sin_half_angles_over_angles[small_angles] = ( 29 | 0.5 - (angles[small_angles] * angles[small_angles]) / 48 30 | ) 31 | quaternions = torch.cat( 32 | [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1 33 | ) 34 | return quaternions 35 | 36 | 37 | # NOTE: from PyTorch3D 38 | def quaternion_to_matrix(quaternions: torch.Tensor) -> torch.Tensor: 39 | """ 40 | Convert rotations given as quaternions to rotation matrices. 41 | 42 | Args: 43 | quaternions: quaternions with real part first, 44 | as tensor of shape (..., 4). 45 | 46 | Returns: 47 | Rotation matrices as tensor of shape (..., 3, 3). 48 | """ 49 | r, i, j, k = torch.unbind(quaternions, -1) 50 | # pyre-fixme[58]: `/` is not supported for operand types `float` and `Tensor`. 
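    # Dividing by the squared norm (two_s = 2 / |q|^2) makes the conversion
    # tolerant of non-unit quaternions; for unit quaternions this reduces to the
    # standard quaternion-to-rotation-matrix formula.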
51 | two_s = 2.0 / (quaternions * quaternions).sum(-1) 52 | 53 | o = torch.stack( 54 | ( 55 | 1 - two_s * (j * j + k * k), 56 | two_s * (i * j - k * r), 57 | two_s * (i * k + j * r), 58 | two_s * (i * j + k * r), 59 | 1 - two_s * (i * i + k * k), 60 | two_s * (j * k - i * r), 61 | two_s * (i * k - j * r), 62 | two_s * (j * k + i * r), 63 | 1 - two_s * (i * i + j * j), 64 | ), 65 | -1, 66 | ) 67 | return o.reshape(quaternions.shape[:-1] + (3, 3)) 68 | 69 | 70 | # NOTE: from PyTorch3D 71 | def axis_angle_to_matrix(axis_angle: torch.Tensor) -> torch.Tensor: 72 | """ 73 | Convert rotations given as axis/angle to rotation matrices. 74 | 75 | Args: 76 | axis_angle: Rotations given as a vector in axis angle form, 77 | as a tensor of shape (..., 3), where the magnitude is 78 | the angle turned anticlockwise in radians around the 79 | vector's direction. 80 | 81 | Returns: 82 | Rotation matrices as tensor of shape (..., 3, 3). 83 | """ 84 | return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle)) 85 | -------------------------------------------------------------------------------- /dn_splatter/scripts/dsine/submodules.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import geffnet 5 | 6 | 7 | INPUT_CHANNELS_DICT = { 8 | 0: [1280, 112, 40, 24, 16], 9 | 1: [1280, 112, 40, 24, 16], 10 | 2: [1408, 120, 48, 24, 16], 11 | 3: [1536, 136, 48, 32, 24], 12 | 4: [1792, 160, 56, 32, 24], 13 | 5: [2048, 176, 64, 40, 24], 14 | 6: [2304, 200, 72, 40, 32], 15 | 7: [2560, 224, 80, 48, 32], 16 | } 17 | 18 | 19 | class Encoder(nn.Module): 20 | def __init__(self, B=5, pretrained=True): 21 | """e.g. B=5 will return EfficientNet-B5""" 22 | super(Encoder, self).__init__() 23 | basemodel = geffnet.create_model( 24 | "tf_efficientnet_b%s_ap" % B, pretrained=pretrained 25 | ) 26 | # Remove last layer 27 | basemodel.global_pool = nn.Identity() 28 | basemodel.classifier = nn.Identity() 29 | self.original_model = basemodel 30 | 31 | def forward(self, x): 32 | features = [x] 33 | for k, v in self.original_model._modules.items(): 34 | if k == "blocks": 35 | for ki, vi in v._modules.items(): 36 | features.append(vi(features[-1])) 37 | else: 38 | features.append(v(features[-1])) 39 | return features 40 | 41 | 42 | class ConvGRU(nn.Module): 43 | def __init__(self, hidden_dim, input_dim, ks=3): 44 | super(ConvGRU, self).__init__() 45 | p = (ks - 1) // 2 46 | self.convz = nn.Conv2d(hidden_dim + input_dim, hidden_dim, ks, padding=p) 47 | self.convr = nn.Conv2d(hidden_dim + input_dim, hidden_dim, ks, padding=p) 48 | self.convq = nn.Conv2d(hidden_dim + input_dim, hidden_dim, ks, padding=p) 49 | 50 | def forward(self, h, x): 51 | hx = torch.cat([h, x], dim=1) 52 | z = torch.sigmoid(self.convz(hx)) 53 | r = torch.sigmoid(self.convr(hx)) 54 | q = torch.tanh(self.convq(torch.cat([r * h, x], dim=1))) 55 | h = (1 - z) * h + z * q 56 | return h 57 | 58 | 59 | class RayReLU(nn.Module): 60 | def __init__(self, eps=1e-2): 61 | super(RayReLU, self).__init__() 62 | self.eps = eps 63 | 64 | def forward(self, pred_norm, ray): 65 | # angle between the predicted normal and ray direction 66 | cos = torch.cosine_similarity(pred_norm, ray, dim=1).unsqueeze( 67 | 1 68 | ) # (B, 1, H, W) 69 | 70 | # component of pred_norm along view 71 | norm_along_view = ray * cos 72 | 73 | # cos should be bigger than eps 74 | norm_along_view_relu = ray * (torch.relu(cos - self.eps) + self.eps) 75 | 76 | # difference 77 | diff = 
norm_along_view_relu - norm_along_view 78 | 79 | # updated pred_norm 80 | new_pred_norm = pred_norm + diff 81 | new_pred_norm = F.normalize(new_pred_norm, dim=1) 82 | 83 | return new_pred_norm 84 | 85 | 86 | class UpSampleBN(nn.Module): 87 | def __init__(self, skip_input, output_features, align_corners=True): 88 | super(UpSampleBN, self).__init__() 89 | self._net = nn.Sequential( 90 | nn.Conv2d(skip_input, output_features, kernel_size=3, stride=1, padding=1), 91 | nn.BatchNorm2d(output_features), 92 | nn.LeakyReLU(), 93 | nn.Conv2d( 94 | output_features, output_features, kernel_size=3, stride=1, padding=1 95 | ), 96 | nn.BatchNorm2d(output_features), 97 | nn.LeakyReLU(), 98 | ) 99 | self.align_corners = align_corners 100 | 101 | def forward(self, x, concat_with): 102 | up_x = F.interpolate( 103 | x, 104 | size=[concat_with.size(2), concat_with.size(3)], 105 | mode="bilinear", 106 | align_corners=self.align_corners, 107 | ) 108 | f = torch.cat([up_x, concat_with], dim=1) 109 | return self._net(f) 110 | 111 | 112 | class Conv2d_WS(nn.Conv2d): 113 | """weight standardization""" 114 | 115 | def __init__( 116 | self, 117 | in_channels, 118 | out_channels, 119 | kernel_size, 120 | stride=1, 121 | padding=0, 122 | dilation=1, 123 | groups=1, 124 | bias=True, 125 | ): 126 | super(Conv2d_WS, self).__init__( 127 | in_channels, 128 | out_channels, 129 | kernel_size, 130 | stride, 131 | padding, 132 | dilation, 133 | groups, 134 | bias, 135 | ) 136 | 137 | def forward(self, x): 138 | weight = self.weight 139 | weight_mean = ( 140 | weight.mean(dim=1, keepdim=True) 141 | .mean(dim=2, keepdim=True) 142 | .mean(dim=3, keepdim=True) 143 | ) 144 | weight = weight - weight_mean 145 | std = weight.view(weight.size(0), -1).std(dim=1).view(-1, 1, 1, 1) + 1e-5 146 | weight = weight / std.expand_as(weight) 147 | return F.conv2d( 148 | x, weight, self.bias, self.stride, self.padding, self.dilation, self.groups 149 | ) 150 | 151 | 152 | class UpSampleGN(nn.Module): 153 | """UpSample with GroupNorm""" 154 | 155 | def __init__(self, skip_input, output_features, align_corners=True): 156 | super(UpSampleGN, self).__init__() 157 | self._net = nn.Sequential( 158 | Conv2d_WS(skip_input, output_features, kernel_size=3, stride=1, padding=1), 159 | nn.GroupNorm(8, output_features), 160 | nn.LeakyReLU(), 161 | Conv2d_WS( 162 | output_features, output_features, kernel_size=3, stride=1, padding=1 163 | ), 164 | nn.GroupNorm(8, output_features), 165 | nn.LeakyReLU(), 166 | ) 167 | self.align_corners = align_corners 168 | 169 | def forward(self, x, concat_with): 170 | up_x = F.interpolate( 171 | x, 172 | size=[concat_with.size(2), concat_with.size(3)], 173 | mode="bilinear", 174 | align_corners=self.align_corners, 175 | ) 176 | f = torch.cat([up_x, concat_with], dim=1) 177 | return self._net(f) 178 | 179 | 180 | def upsample_via_bilinear(out, up_mask, downsample_ratio): 181 | """bilinear upsampling (up_mask is a dummy variable)""" 182 | return F.interpolate( 183 | out, scale_factor=downsample_ratio, mode="bilinear", align_corners=True 184 | ) 185 | 186 | 187 | def upsample_via_mask(out, up_mask, downsample_ratio): 188 | """convex upsampling""" 189 | # out: low-resolution output (B, o_dim, H, W) 190 | # up_mask: (B, 9*k*k, H, W) 191 | k = downsample_ratio 192 | 193 | N, o_dim, H, W = out.shape 194 | up_mask = up_mask.view(N, 1, 9, k, k, H, W) 195 | up_mask = torch.softmax(up_mask, dim=2) # (B, 1, 9, k, k, H, W) 196 | 197 | up_out = F.unfold(out, [3, 3], padding=1) # (B, 2, H, W) -> (B, 2 X 3*3, H*W) 198 | up_out = up_out.view(N, 
o_dim, 9, 1, 1, H, W) # (B, 2, 3*3, 1, 1, H, W) 199 | up_out = torch.sum(up_mask * up_out, dim=2) # (B, 2, k, k, H, W) 200 | 201 | up_out = up_out.permute(0, 1, 4, 2, 5, 3) # (B, 2, H, k, W, k) 202 | return up_out.reshape(N, o_dim, k * H, k * W) # (B, 2, kH, kW) 203 | 204 | 205 | def convex_upsampling(out, up_mask, k): 206 | # out: low-resolution output (B, C, H, W) 207 | # up_mask: (B, 9*k*k, H, W) 208 | B, C, H, W = out.shape 209 | up_mask = up_mask.view(B, 1, 9, k, k, H, W) 210 | up_mask = torch.softmax(up_mask, dim=2) # (B, 1, 9, k, k, H, W) 211 | 212 | out = F.pad(out, pad=(1, 1, 1, 1), mode="replicate") 213 | up_out = F.unfold(out, [3, 3], padding=0) # (B, C, H, W) -> (B, C X 3*3, H*W) 214 | up_out = up_out.view(B, C, 9, 1, 1, H, W) # (B, C, 9, 1, 1, H, W) 215 | 216 | up_out = torch.sum(up_mask * up_out, dim=2) # (B, C, k, k, H, W) 217 | up_out = up_out.permute(0, 1, 4, 2, 5, 3) # (B, C, H, k, W, k) 218 | return up_out.reshape(B, C, k * H, k * W) # (B, C, kH, kW) 219 | 220 | 221 | def get_unfold(pred_norm, ps, pad): 222 | B, C, H, W = pred_norm.shape 223 | pred_norm = F.pad( 224 | pred_norm, pad=(pad, pad, pad, pad), mode="replicate" 225 | ) # (B, C, h, w) 226 | pred_norm_unfold = F.unfold(pred_norm, [ps, ps], padding=0) # (B, C X ps*ps, h*w) 227 | pred_norm_unfold = pred_norm_unfold.view(B, C, ps * ps, H, W) # (B, C, ps*ps, h, w) 228 | return pred_norm_unfold 229 | 230 | 231 | def get_prediction_head(input_dim, hidden_dim, output_dim): 232 | return nn.Sequential( 233 | nn.Conv2d(input_dim, hidden_dim, 3, padding=1), 234 | nn.ReLU(inplace=True), 235 | nn.Conv2d(hidden_dim, hidden_dim, 1), 236 | nn.ReLU(inplace=True), 237 | nn.Conv2d(hidden_dim, output_dim, 1), 238 | ) 239 | -------------------------------------------------------------------------------- /dn_splatter/scripts/poses_to_colmap_sfm.py: -------------------------------------------------------------------------------- 1 | """Convert poses from transforms.json to colmap sfm""" 2 | 3 | import json 4 | import os 5 | from pathlib import Path 6 | 7 | import numpy as np 8 | 9 | from nerfstudio.data.utils.colmap_parsing_utils import rotmat2qvec 10 | 11 | 12 | class PosesToColmap: 13 | def __init__( 14 | self, 15 | transforms_path: str, 16 | run_colmap: bool = True, 17 | assume_colmap_world_coordinate_convention: bool = True, 18 | ): 19 | self.transforms_path = transforms_path 20 | self.run_colmap_cmd = run_colmap 21 | self.assume_colmap_world_coordinate_convention = ( 22 | assume_colmap_world_coordinate_convention 23 | ) 24 | self.camera = "OPENCV" 25 | self.image_path = "images" 26 | self.use_gpu = "False" 27 | 28 | def run_colmap(self): 29 | """ 30 | colmap feature_extractor \ 31 | --database_path $PROJECT_PATH/database.db \ 32 | --image_path $PROJECT_PATH/images 33 | 34 | colmap exhaustive_matcher \ # or alternatively any other matcher 35 | --database_path $PROJECT_PATH/database.db 36 | 37 | colmap point_triangulator \ 38 | --database_path $PROJECT_PATH/database.db \ 39 | --image_path $PROJECT_PATH/images 40 | --input_path path/to/manually/created/sparse/model \ 41 | --output_path path/to/triangulated/sparse/model 42 | 43 | """ 44 | output_db = self.base_dir / "database.db" 45 | use_gpu = 1 if self.use_gpu else 0 46 | 47 | feature_cmd = ( 48 | "colmap", 49 | "feature_extractor", 50 | "--database_path", 51 | str(output_db), 52 | "--image_path", 53 | str(self.base_dir / self.image_path), 54 | "--ImageReader.single_camera", 55 | "0", 56 | "--ImageReader.camera_model", 57 | self.camera, 58 | "--SiftExtraction.use_gpu", 59 | 
str(use_gpu), 60 | ) 61 | feature_cmd = (" ").join(feature_cmd) 62 | os.system(feature_cmd) 63 | 64 | match_cmd = ( 65 | "colmap", 66 | "exhaustive_matcher", 67 | "--database_path", 68 | str(output_db), 69 | ) 70 | match_cmd = (" ").join(match_cmd) 71 | os.system(match_cmd) 72 | 73 | triangulate_cmd = ( 74 | "colmap", 75 | "point_triangulator", 76 | "--database_path", 77 | str(output_db), 78 | "--image_path", 79 | str(self.base_dir / self.image_path), 80 | "--input_path", 81 | str(self.sparse_dir), 82 | "--output_path", 83 | str(self.sparse_dir), 84 | ) 85 | triangulate_cmd = (" ").join(triangulate_cmd) 86 | os.system(triangulate_cmd) 87 | 88 | def manual_sparse(self): 89 | print("Creating sparse model manually") 90 | with open(self.transforms_path) as f: 91 | data = json.load(f) 92 | 93 | self.base_dir = Path(self.transforms_path).parent.absolute() 94 | self.sparse_dir = self.base_dir / "sparse" / "0" 95 | self.sparse_dir.mkdir(parents=True, exist_ok=True) 96 | 97 | points_txt = self.sparse_dir / "points3D.txt" 98 | with open(str(points_txt), "w") as f: 99 | f.write("") 100 | images_txt = self.sparse_dir / "images.txt" 101 | cameras_txt = self.sparse_dir / "cameras.txt" 102 | 103 | camera_model = data["camera_model"] 104 | 105 | if "fl_x" in data: 106 | fx = data["fl_x"] 107 | fy = data["fl_y"] 108 | cx = data["cx"] 109 | cy = data["cy"] 110 | height = data["h"] 111 | width = data["w"] 112 | with open(str(cameras_txt), "w") as f: 113 | f.write("# Camera list with one line of data per camera:\n") 114 | f.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n") 115 | f.write("# Number of cameras: 1\n") 116 | f.write( 117 | f"1 {camera_model} {width} {height} {fx} {fy} {cx} {cy} 0 0 0 0\n" 118 | ) 119 | else: 120 | with open(str(cameras_txt), "w") as f: 121 | f.write("# Camera list with one line of data per camera:\n") 122 | f.write("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]\n") 123 | f.write(f"# Number of cameras: {len(data['frames'])}\n") 124 | for item, frame in enumerate(data["frames"]): 125 | fx = frame["fl_x"] 126 | fy = frame["fl_y"] 127 | cx = frame["cx"] 128 | cy = frame["cy"] 129 | height = frame["h"] 130 | width = frame["w"] 131 | f.write( 132 | f"{item+1} {camera_model} {width} {height} {fx} {fy} {cx} {cy} 0 0 0 0\n" 133 | ) 134 | with open(str(images_txt), "w") as imgs_txt: 135 | for id, frame in enumerate(data["frames"]): 136 | c2w = np.array(frame["transform_matrix"]) 137 | if self.assume_colmap_world_coordinate_convention: 138 | c2w[2, :] *= -1 139 | c2w = c2w[np.array([0, 2, 1, 3]), :] 140 | c2w[0:3, 1:3] *= -1 141 | w2c = np.linalg.inv(c2w) 142 | rotation = w2c[:3, :3] 143 | translation = w2c[:3, 3] 144 | qvec = rotmat2qvec(rotation) 145 | name = Path(frame["file_path"]).name 146 | camera_id = 1 147 | image_id = id + 1 148 | 149 | qvec_str = ", ".join(map(str, qvec.tolist())) 150 | translation_str = ", ".join(map(str, translation.tolist())) 151 | imgs_txt.write( 152 | f"{image_id} {qvec_str} {translation_str} {camera_id} {name}" 153 | ) 154 | imgs_txt.write("\n") 155 | imgs_txt.write("\n") 156 | 157 | def main(self): 158 | self.manual_sparse() 159 | if self.run_colmap_cmd: 160 | self.run_colmap() 161 | 162 | 163 | if __name__ == "__main__": 164 | import tyro 165 | 166 | tyro.cli(PosesToColmap).main() 167 | -------------------------------------------------------------------------------- /dn_splatter/scripts/process_sai.py: -------------------------------------------------------------------------------- 1 | """Process a single custom SAI input""" 2 | import json 3 | import os 
4 | import shutil 5 | import subprocess 6 | import tempfile 7 | 8 | SAI_CLI_PROCESS_PARAMS = { 9 | "image_format": "png", 10 | "no_undistort": None, 11 | "key_frame_distance": 0.1, 12 | "internal": { 13 | "maxKeypoints": 2000, 14 | "optimizerMaxIterations": 50, 15 | }, 16 | } 17 | 18 | DEFAULT_OUT_FOLDER = "datasets/custom" 19 | 20 | 21 | def ensure_exposure_time(target, input_folder): 22 | trans_fn = os.path.join(target, "transforms.json") 23 | with open(trans_fn) as f: 24 | transforms = json.load(f) 25 | 26 | if "exposure_time" in transforms: 27 | return 28 | 29 | with open(os.path.join(input_folder, "data.jsonl")) as f: 30 | for line in f: 31 | d = json.loads(line) 32 | if "frames" in d: 33 | e = d["frames"][0].get("exposureTimeSeconds", None) 34 | if e is not None: 35 | print("got exposure time %g from data.jsonl" % e) 36 | transforms["exposure_time"] = e 37 | with open(trans_fn, "wt") as f: 38 | json.dump(transforms, f, indent=4) 39 | return 40 | 41 | raise RuntimeError("no exposure time available") 42 | 43 | 44 | def process(args): 45 | def maybe_run_cmd(cmd): 46 | print("COMMAND:", cmd) 47 | if not args.dry_run: 48 | subprocess.check_call(cmd) 49 | 50 | def maybe_unzip(fn): 51 | name = os.path.basename(fn) 52 | if name.endswith(".zip"): 53 | name = name[:-4] 54 | tempdir = tempfile.mkdtemp() 55 | input_folder = os.path.join(tempdir, "recording") 56 | extract_command = [ 57 | "unzip", 58 | fn, 59 | "-d", 60 | input_folder, 61 | ] 62 | maybe_run_cmd(extract_command) 63 | if not args.dry_run: 64 | # handle folder inside zip 65 | for f in os.listdir(input_folder): 66 | if f == name: 67 | input_folder = os.path.join(input_folder, f) 68 | break 69 | else: 70 | input_folder = fn 71 | 72 | return name, input_folder 73 | 74 | sai_params = json.loads(json.dumps(SAI_CLI_PROCESS_PARAMS)) 75 | sai_params["key_frame_distance"] = args.key_frame_distance 76 | 77 | tempdir = None 78 | name, input_folder = maybe_unzip(args.spectacular_rec_input_folder_or_zip) 79 | sai_params_list = [] 80 | for k, v in sai_params.items(): 81 | if k == "internal": 82 | for k2, v2 in v.items(): 83 | sai_params_list.append(f"--{k}={k2}:{v2}") 84 | else: 85 | if v is None: 86 | sai_params_list.append(f"--{k}") 87 | else: 88 | sai_params_list.append(f"--{k}={v}") 89 | 90 | result_name = name 91 | 92 | if args.output_folder is None: 93 | final_target = os.path.join(DEFAULT_OUT_FOLDER, result_name) 94 | else: 95 | final_target = args.output_folder 96 | 97 | target = final_target 98 | 99 | cmd = ["sai-cli", "process", input_folder, target] + sai_params_list 100 | 101 | if args.preview: 102 | cmd.extend(["--preview", "--preview3d"]) 103 | 104 | if os.path.exists(target): 105 | shutil.rmtree(target) 106 | maybe_run_cmd(cmd) 107 | if not args.dry_run: 108 | ensure_exposure_time(target, input_folder) 109 | 110 | 111 | if __name__ == "__main__": 112 | import argparse 113 | 114 | parser = argparse.ArgumentParser(description=__doc__) 115 | parser.add_argument("spectacular_rec_input_folder_or_zip", type=str) 116 | parser.add_argument("output_folder", type=str, default=None, nargs="?") 117 | parser.add_argument("--preview", action="store_true") 118 | parser.add_argument("--dry_run", action="store_true") 119 | parser.add_argument( 120 | "--key_frame_distance", 121 | type=float, 122 | default=0.1, 123 | help="Minimum key frame distance in meters, default (0.1), increase for larger scenes", 124 | ) 125 | args = parser.parse_args() 126 | 127 | process(args) 128 | 
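# --- Illustrative sketch (not part of process_sai.py above) -------------------
# The flag-building loop in process() expands SAI_CLI_PROCESS_PARAMS into
# sai-cli arguments: None values become bare switches and the nested
# "internal" dict becomes repeated --internal=name:value flags. A standalone
# reproduction of that expansion, for reference (expand_sai_params is a
# hypothetical helper name, not defined in the script):

def expand_sai_params(params: dict) -> list:
    flags = []
    for key, value in params.items():
        if key == "internal":
            # nested options are passed as repeated --internal=name:value flags
            for name, inner in value.items():
                flags.append(f"--{key}={name}:{inner}")
        elif value is None:
            flags.append(f"--{key}")  # bare switch, e.g. --no_undistort
        else:
            flags.append(f"--{key}={value}")
    return flags

# expand_sai_params(SAI_CLI_PROCESS_PARAMS) yields, in insertion order:
#   ['--image_format=png', '--no_undistort', '--key_frame_distance=0.1',
#    '--internal=maxKeypoints:2000', '--internal=optimizerMaxIterations:50']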
-------------------------------------------------------------------------------- /dn_splatter/scripts/render_model.py: -------------------------------------------------------------------------------- 1 | """Load DN-Splatter model and render all outputs to disk""" 2 | 3 | from dataclasses import dataclass 4 | from pathlib import Path 5 | from typing import Literal 6 | 7 | import torch 8 | import tyro 9 | from dn_splatter.utils.utils import save_outputs_helper 10 | 11 | from nerfstudio.cameras.cameras import Cameras 12 | from nerfstudio.data.datasets.base_dataset import InputDataset 13 | from nerfstudio.models.splatfacto import SplatfactoModel 14 | from nerfstudio.utils import colormaps 15 | from nerfstudio.utils.eval_utils import eval_setup 16 | 17 | 18 | @dataclass 19 | class RenderModel: 20 | """Render outputs of a GS model.""" 21 | 22 | load_config: Path = Path("") 23 | """Path to the config YAML file.""" 24 | output_dir: Path = Path("./renders/") 25 | """Path to the output directory.""" 26 | render_rgb: bool = True 27 | render_depth: bool = True 28 | render_normal: bool = True 29 | split: Literal["all", "test", "eval"] = "all" 30 | 31 | def main(self): 32 | if not self.output_dir.exists(): 33 | self.output_dir.mkdir(parents=True) 34 | 35 | _, pipeline, _, _ = eval_setup(self.load_config) 36 | 37 | assert isinstance(pipeline.model, SplatfactoModel) 38 | 39 | model: SplatfactoModel = pipeline.model 40 | train_dataset: InputDataset = pipeline.datamanager.train_dataset 41 | eval_dataset: InputDataset = pipeline.datamanager.eval_dataset 42 | 43 | # train dataset 44 | if self.split in ["test", "all"]: 45 | with torch.no_grad(): 46 | cameras: Cameras = pipeline.datamanager.train_dataset.cameras # type: ignore 47 | for image_idx, data in enumerate( 48 | pipeline.datamanager.cached_train # Undistorted images 49 | ): # type: ignore 50 | # process batch gt data 51 | mask = None 52 | if "mask" in data: 53 | mask = data["mask"] 54 | 55 | gt_img = data["image"] 56 | if "sensor_depth" in data: 57 | depth_gt = data["sensor_depth"] 58 | depth_gt_color = colormaps.apply_depth_colormap( 59 | data["sensor_depth"] 60 | ) 61 | else: 62 | depth_gt = None 63 | depth_gt_color = None 64 | if "normal" in data: 65 | normal_gt = data["normal"] 66 | 67 | # process pred outputs 68 | camera = cameras[image_idx : image_idx + 1].to("cpu") 69 | outputs = model.get_outputs_for_camera(camera=camera) 70 | 71 | rgb_out, depth_out = outputs["rgb"], outputs["depth"] 72 | 73 | normal = None 74 | if "normal" in outputs: 75 | normal = outputs["normal"] 76 | 77 | seq_name = Path(train_dataset.image_filenames[image_idx]) 78 | image_name = f"{seq_name.stem}" 79 | 80 | depth_color = colormaps.apply_depth_colormap(depth_out) 81 | depth = depth_out.detach().cpu().numpy() 82 | 83 | if mask is not None: 84 | rgb_out = rgb_out * mask 85 | gt_img = gt_img * mask 86 | if depth_color is not None: 87 | depth_color = depth_color * mask 88 | if depth_gt_color is not None: 89 | depth_gt_color = depth_gt_color * mask 90 | if depth_gt is not None: 91 | depth_gt = depth_gt * mask 92 | if depth is not None: 93 | depth = depth * mask 94 | if normal_gt is not None: 95 | normal_gt = normal_gt * mask 96 | if normal is not None: 97 | normal = normal * mask 98 | 99 | # save all outputs 100 | save_outputs_helper( 101 | rgb_out if self.render_rgb else None, 102 | gt_img if self.render_rgb else None, 103 | depth_color if self.render_depth else None, 104 | depth_gt_color if self.render_depth else None, 105 | depth_gt if self.render_depth else None, 106 | 
depth if self.render_depth else None, 107 | normal_gt if self.render_normal else None, 108 | normal if self.render_normal else None, 109 | self.output_dir / "train", 110 | image_name, 111 | ) 112 | 113 | # eval dataset 114 | if self.split in ["eval", "all"]: 115 | with torch.no_grad(): 116 | cameras: Cameras = pipeline.datamanager.eval_dataset.cameras # type: ignore 117 | for image_idx, data in enumerate( 118 | pipeline.datamanager.cached_eval # Undistorted images 119 | ): # type: ignore 120 | 121 | # process batch gt data 122 | mask = None 123 | if "mask" in data: 124 | mask = data["mask"] 125 | 126 | gt_img = data["image"] 127 | if "sensor_depth" in data: 128 | depth_gt = data["sensor_depth"] 129 | depth_gt_color = colormaps.apply_depth_colormap( 130 | data["sensor_depth"] 131 | ) 132 | else: 133 | depth_gt = None 134 | depth_gt_color = None 135 | if "normal" in data: 136 | normal_gt = data["normal"] 137 | 138 | # process pred outputs 139 | camera = cameras[image_idx : image_idx + 1].to("cpu") 140 | outputs = model.get_outputs_for_camera(camera=camera) 141 | 142 | rgb_out, depth_out = outputs["rgb"], outputs["depth"] 143 | 144 | normal = None 145 | if "normal" in outputs: 146 | normal = outputs["normal"] 147 | 148 | seq_name = Path(eval_dataset.image_filenames[image_idx]) 149 | image_name = f"{seq_name.stem}" 150 | 151 | if "long_capture" in str(seq_name).split("/"): 152 | image_name = "long_capture_" + image_name 153 | if "short_capture" in str(seq_name).split("/"): 154 | image_name = "short_capture_" + image_name 155 | 156 | depth_color = colormaps.apply_depth_colormap(depth_out) 157 | depth = depth_out.detach().cpu().numpy() 158 | 159 | if mask is not None: 160 | rgb_out = rgb_out * mask 161 | gt_img = gt_img * mask 162 | if depth_color is not None: 163 | depth_color = depth_color * mask 164 | if depth_gt_color is not None: 165 | depth_gt_color = depth_gt_color * mask 166 | if depth_gt is not None: 167 | depth_gt = depth_gt * mask 168 | if depth is not None: 169 | depth = depth * mask 170 | if normal_gt is not None: 171 | normal_gt = normal_gt * mask 172 | if normal is not None: 173 | normal = normal * mask 174 | 175 | # save all outputs 176 | save_outputs_helper( 177 | rgb_out if self.render_rgb else None, 178 | gt_img if self.render_rgb else None, 179 | depth_color if self.render_depth else None, 180 | depth_gt_color if self.render_depth else None, 181 | depth_gt if self.render_depth else None, 182 | depth if self.render_depth else None, 183 | normal_gt if self.render_normal else None, 184 | normal if self.render_normal else None, 185 | self.output_dir / "eval", 186 | image_name, 187 | ) 188 | 189 | 190 | if __name__ == "__main__": 191 | tyro.cli(RenderModel).main() 192 | -------------------------------------------------------------------------------- /dn_splatter/utils/camera_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for projection and camera coords with different conventions""" 2 | 3 | import math 4 | from typing import List, Optional, Tuple 5 | 6 | import numpy as np 7 | import torch 8 | from torch import Tensor 9 | 10 | # opengl to opencv transformation matrix 11 | OPENGL_TO_OPENCV = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) 12 | 13 | 14 | # ndc space is x to the right y up. uv space is x to the right, y down. 
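# Illustrative note (added commentary, not in the original source): with this
# convention, pix2ndc_x(x, W) = 2 * x / W - 1 maps pixel x in [0, W] to
# [-1, 1] (x = 0 -> -1, x = W / 2 -> 0, x = W -> +1), while
# pix2ndc_y(y, H) = 1 - 2 * y / H flips the vertical axis so the top row
# (y = 0) maps to +1. ndc2pix_x and ndc2pix_y below invert these mappings and
# recover the original pixel coordinates (up to floating-point rounding).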
15 | def pix2ndc_x(x, W): 16 | x = x.float() 17 | return (2 * x) / W - 1 18 | 19 | 20 | def pix2ndc_y(y, H): 21 | y = y.float() 22 | return 1 - (2 * y) / H 23 | 24 | 25 | # ndc is y up and x right. uv is y down and x right 26 | def ndc2pix_x(x, W): 27 | return (x + 1) * 0.5 * W 28 | 29 | 30 | def ndc2pix_y(y, H): 31 | return (1 - y) * 0.5 * H 32 | 33 | 34 | def euclidean_to_z_depth( 35 | depths: Tensor, 36 | fx: float, 37 | fy: float, 38 | cx: int, 39 | cy: int, 40 | img_size: tuple, 41 | device: torch.device, 42 | ) -> Tensor: 43 | """Convert euclidean depths to z_depths given camera intrinsics""" 44 | if depths.dim() == 3: 45 | depths = depths.view(-1, 1) 46 | elif depths.shape[-1] != 1: 47 | depths = depths.unsqueeze(-1).contiguous() 48 | depths = depths.view(-1, 1) 49 | if depths.dtype != torch.float: 50 | depths = depths.float() 51 | image_coords = get_camera_coords(img_size=img_size) 52 | image_coords = image_coords.to(device) 53 | 54 | z_depth = torch.empty( 55 | size=(img_size[0], img_size[1], 3), dtype=torch.float32, device=device 56 | ).view(-1, 3) 57 | z_depth[:, 0] = (image_coords[:, 0] - cx) / fx # x 58 | z_depth[:, 1] = (image_coords[:, 1] - cy) / fy # y 59 | z_depth[:, 2] = 1 # z 60 | 61 | z_depth = z_depth / torch.norm(z_depth, dim=-1, keepdim=True) 62 | z_depth = (z_depth * depths)[:, 2] # pick only z component 63 | 64 | z_depth = z_depth[..., None] 65 | z_depth = z_depth.view(img_size[1], img_size[0], 1) 66 | 67 | return z_depth 68 | 69 | 70 | def get_camera_coords(img_size: tuple, pixel_offset: float = 0.5) -> Tensor: 71 | """Generates camera pixel coordinates [W,H] 72 | 73 | Returns: 74 | stacked coords [H*W,2] where [:,0] corresponds to W and [:,1] corresponds to H 75 | """ 76 | 77 | # img size is (w,h) 78 | image_coords = torch.meshgrid( 79 | torch.arange(img_size[0]), 80 | torch.arange(img_size[1]), 81 | indexing="xy", # W = u by H = v 82 | ) 83 | image_coords = ( 84 | torch.stack(image_coords, dim=-1) + pixel_offset 85 | ) # stored as (x, y) coordinates 86 | image_coords = image_coords.view(-1, 2) 87 | image_coords = image_coords.float() 88 | 89 | return image_coords 90 | 91 | 92 | def get_means3d_backproj( 93 | depths: Tensor, 94 | fx: float, 95 | fy: float, 96 | cx: int, 97 | cy: int, 98 | img_size: tuple, 99 | c2w: Tensor, 100 | device: torch.device, 101 | mask: Optional[Tensor] = None, 102 | ) -> Tuple[Tensor, List]: 103 | """Backprojection using camera intrinsics and extrinsics 104 | 105 | image_coords -> (x,y,depth) -> (X, Y, depth) 106 | 107 | Returns: 108 | Tuple of (means: Tensor, image_coords: Tensor) 109 | """ 110 | 111 | if depths.dim() == 3: 112 | depths = depths.view(-1, 1) 113 | elif depths.shape[-1] != 1: 114 | depths = depths.unsqueeze(-1).contiguous() 115 | depths = depths.view(-1, 1) 116 | if depths.dtype != torch.float: 117 | depths = depths.float() 118 | c2w = c2w.float() 119 | if c2w.device != device: 120 | c2w = c2w.to(device) 121 | 122 | image_coords = get_camera_coords(img_size) 123 | image_coords = image_coords.to(device) # note image_coords is (H,W) 124 | 125 | # TODO: account for skew / radial distortion 126 | means3d = torch.empty( 127 | size=(img_size[0], img_size[1], 3), dtype=torch.float32, device=device 128 | ).view(-1, 3) 129 | means3d[:, 0] = (image_coords[:, 0] - cx) * depths[:, 0] / fx # x 130 | means3d[:, 1] = (image_coords[:, 1] - cy) * depths[:, 0] / fy # y 131 | means3d[:, 2] = depths[:, 0] # z 132 | 133 | if mask is not None: 134 | if not torch.is_tensor(mask): 135 | mask = torch.tensor(mask, device=depths.device) 136 | 
means3d = means3d[mask] 137 | image_coords = image_coords[mask] 138 | 139 | if c2w is None: 140 | c2w = torch.eye((means3d.shape[0], 4, 4), device=device) 141 | 142 | # to world coords 143 | means3d = means3d @ torch.linalg.inv(c2w[..., :3, :3]) + c2w[..., :3, 3] 144 | return means3d, image_coords 145 | 146 | 147 | def project_pix( 148 | p: Tensor, 149 | fx: float, 150 | fy: float, 151 | cx: int, 152 | cy: int, 153 | c2w: Tensor, 154 | device: torch.device, 155 | return_z_depths: bool = False, 156 | ) -> Tensor: 157 | """Projects a world 3D point to uv coordinates using intrinsics/extrinsics 158 | 159 | Returns: 160 | uv coords 161 | """ 162 | if c2w is None: 163 | c2w = torch.eye((p.shape[0], 4, 4), device=device) # type: ignore 164 | if c2w.device != device: 165 | c2w = c2w.to(device) 166 | 167 | points_cam = (p.to(device) - c2w[..., :3, 3]) @ c2w[..., :3, :3] 168 | u = points_cam[:, 0] * fx / points_cam[:, 2] + cx # x 169 | v = points_cam[:, 1] * fy / points_cam[:, 2] + cy # y 170 | if return_z_depths: 171 | return torch.stack([u, v, points_cam[:, 2]], dim=-1) 172 | return torch.stack([u, v], dim=-1) 173 | 174 | 175 | def get_colored_points_from_depth( 176 | depths: Tensor, 177 | rgbs: Tensor, 178 | c2w: Tensor, 179 | fx: float, 180 | fy: float, 181 | cx: int, 182 | cy: int, 183 | img_size: tuple, 184 | mask: Optional[Tensor] = None, 185 | ) -> Tuple[Tensor, Tensor]: 186 | """Return colored pointclouds from depth and rgb frame and c2w. Optional masking. 187 | 188 | Returns: 189 | Tuple of (points, colors) 190 | """ 191 | points, _ = get_means3d_backproj( 192 | depths=depths.float(), 193 | fx=fx, 194 | fy=fy, 195 | cx=cx, 196 | cy=cy, 197 | img_size=img_size, 198 | c2w=c2w.float(), 199 | device=depths.device, 200 | ) 201 | points = points.squeeze(0) 202 | if mask is not None: 203 | if not torch.is_tensor(mask): 204 | mask = torch.tensor(mask, device=depths.device) 205 | colors = rgbs.view(-1, 3)[mask] 206 | points = points[mask] 207 | else: 208 | colors = rgbs.view(-1, 3) 209 | points = points 210 | return (points, colors) 211 | 212 | 213 | def get_rays_x_y_1(H, W, focal, c2w): 214 | """Get ray origins and directions in world coordinates. 215 | 216 | Convention here is (x,y,-1) such that depth*rays_d give real z depth values in world coordinates. 
217 | """ 218 | assert c2w.shape == torch.Size([3, 4]) 219 | image_coords = torch.meshgrid( 220 | torch.arange(W, dtype=torch.float32), 221 | torch.arange(H, dtype=torch.float32), 222 | indexing="ij", 223 | ) 224 | i, j = image_coords 225 | # dirs = torch.stack([(i-W*0.5)/focal, -(j-H*0.5)/focal, -torch.ones_like(i)], dim = -1) 226 | dirs = torch.stack( 227 | [(pix2ndc_x(i, W)) / focal, pix2ndc_y(j, H) / focal, -torch.ones_like(i)], 228 | dim=-1, 229 | ) 230 | dirs = dirs.view(-1, 3) 231 | rays_d = dirs[..., :] @ c2w[:3, :3] 232 | rays_o = c2w[:3, -1].expand_as(rays_d) 233 | 234 | # return world coordinate rays_o and rays_d 235 | return rays_o, rays_d 236 | 237 | 238 | def get_projection_matrix(znear=0.001, zfar=1000, fovx=None, fovy=None, **kwargs): 239 | """Opengl projection matrix 240 | 241 | Returns: 242 | projmat: Tensor 243 | """ 244 | 245 | t = znear * math.tan(0.5 * fovy) 246 | b = -t 247 | r = znear * math.tan(0.5 * fovx) 248 | l = -r 249 | n = znear 250 | f = zfar 251 | return torch.tensor( 252 | [ 253 | [2 * n / (r - l), 0.0, (r + l) / (r - l), 0.0], 254 | [0.0, 2 * n / (t - b), (t + b) / (t - b), 0.0], 255 | [0.0, 0.0, (f + n) / (f - n), -1.0 * f * n / (f - n)], 256 | [0.0, 0.0, 1.0, 0.0], 257 | ], 258 | **kwargs, 259 | ) 260 | -------------------------------------------------------------------------------- /dn_splatter/utils/knn.py: -------------------------------------------------------------------------------- 1 | """KNN implementations""" 2 | 3 | import torch 4 | from torch import Tensor 5 | 6 | device = torch.device("cuda:0") 7 | 8 | 9 | def fast_knn(x: Tensor, y: Tensor, k: int = 2): 10 | """Wrapper for torch_cluster.knn 11 | 12 | Args: 13 | x: input data 14 | y: query data 15 | k: k-nearest neighbours 16 | """ 17 | from torch_cluster import knn 18 | 19 | assert x.is_cuda 20 | assert y.is_cuda 21 | assert x.dim() == y.dim() == 2 22 | with torch.no_grad(): 23 | k = k + 1 24 | outs = knn(x.clone(), y.clone(), k, None, None)[1, :] 25 | outs = outs.reshape(y.shape[0], k)[:, 1:] 26 | return outs 27 | 28 | 29 | def knn_sk(x: torch.Tensor, y: torch.Tensor, k: int): 30 | import numpy as np 31 | 32 | x_np = x.cpu().numpy() 33 | y_np = y.cpu().numpy() 34 | 35 | from sklearn.neighbors import NearestNeighbors 36 | 37 | nn_model = NearestNeighbors( 38 | n_neighbors=k + 1, algorithm="auto", metric="euclidean" 39 | ).fit(x_np) 40 | 41 | distances, indices = nn_model.kneighbors(y_np) 42 | 43 | return torch.from_numpy(indices[:, 1:].astype(np.int64)).long().to(x.device) 44 | -------------------------------------------------------------------------------- /dn_splatter/utils/normal_utils.py: -------------------------------------------------------------------------------- 1 | """Utils for normals""" 2 | 3 | import torch 4 | from dn_splatter.utils.camera_utils import get_means3d_backproj 5 | from torch import Tensor 6 | import math 7 | 8 | 9 | def pcd_to_normal(xyz: Tensor): 10 | hd, wd, _ = xyz.shape 11 | bottom_point = xyz[..., 2:hd, 1 : wd - 1, :] 12 | top_point = xyz[..., 0 : hd - 2, 1 : wd - 1, :] 13 | right_point = xyz[..., 1 : hd - 1, 2:wd, :] 14 | left_point = xyz[..., 1 : hd - 1, 0 : wd - 2, :] 15 | left_to_right = right_point - left_point 16 | bottom_to_top = top_point - bottom_point 17 | xyz_normal = torch.cross(left_to_right, bottom_to_top, dim=-1) 18 | xyz_normal = torch.nn.functional.normalize(xyz_normal, p=2, dim=-1) 19 | xyz_normal = torch.nn.functional.pad( 20 | xyz_normal.permute(2, 0, 1), (1, 1, 1, 1), mode="constant" 21 | ).permute(1, 2, 0) 22 | return xyz_normal 23 | 24 | 
25 | def normal_from_depth_image( 26 | depths: Tensor, 27 | fx: float, 28 | fy: float, 29 | cx: float, 30 | cy: float, 31 | img_size: tuple, 32 | c2w: Tensor, 33 | device: torch.device, 34 | smooth: bool = False, 35 | ): 36 | """estimate normals from depth map""" 37 | if smooth: 38 | if torch.count_nonzero(depths) > 0: 39 | print("Input depth map contains 0 elements, skipping smoothing filter") 40 | else: 41 | kernel_size = (9, 9) 42 | depths = torch.from_numpy( 43 | cv2.GaussianBlur(depths.cpu().numpy(), kernel_size, 0) 44 | ).to(device) 45 | means3d, _ = get_means3d_backproj(depths, fx, fy, cx, cy, img_size, c2w, device) 46 | means3d = means3d.view(img_size[1], img_size[0], 3) 47 | normals = pcd_to_normal(means3d) 48 | return normals 49 | -------------------------------------------------------------------------------- /pixi.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "dn-splatter" 3 | version = "0.1.0" 4 | description = "Depth and normal priors for 3D Gaussian splatting and meshing" 5 | channels = ["nvidia/label/cuda-11.8.0", "nvidia", "conda-forge", "pytorch"] 6 | platforms = ["linux-64"] 7 | conda-pypi-map = { "conda-forge" = "https://raw.githubusercontent.com/prefix-dev/parselmouth/main/files/mapping_as_grayskull.json" } 8 | 9 | [tasks] 10 | # Quick Run example 11 | download-omnidata = {cmd="ls omnidata_ckpt/omnidata_dpt_normal_v2.ckpt || python dn_splatter/data/download_scripts/download_omnidata.py", outputs=["omnidata_ckpt/omnidata_dpt_normal_v2.ckpt"]} 12 | download-data-sample = {cmd="ls datasets/koivu_iphone.tar.gz || python dn_splatter/data/download_scripts/mushroom_download.py --sequence iphone --room-name koivu", outputs=["datasets/koivu_iphone.tar.gz"]} 13 | train-mushroom-sample = """ 14 | ns-train dn-splatter \ 15 | --pipeline.model.use-depth-loss True \ 16 | --pipeline.model.depth-lambda 0.2 \ 17 | --pipeline.model.use-depth-smooth-loss True \ 18 | --pipeline.model.use-normal-loss True \ 19 | --pipeline.model.normal-supervision mono \ 20 | mushroom --data datasets/room_datasets/koivu --mode iphone 21 | """ 22 | example = {cmd = "pwd", depends_on=["download-omnidata", "download-data-sample", "train-mushroom-sample"]} 23 | 24 | # Polycam example using dsine 25 | ## first download polycam only if ls fails (checking that file isn't already downloaded), || is used to check if file exists, if not move on to following command 26 | download-polycam = {cmd="ls datasets/polycam/6g-first-scan-poly.zip || wget -P datasets/polycam https://huggingface.co/datasets/pablovela5620/sample-polycam-room/resolve/main/6g-first-scan-poly.zip"} 27 | convert-poly-to-ns = {cmd="ls datasets/polycam/6g-first-scan-poly/transforms.json || ns-process-data polycam --data datasets/polycam/6g-first-scan-poly.zip --output-dir datasets/polycam/6g-first-scan-poly --use-depth", depends_on=["download-polycam"]} 28 | generate-normals = {cmd="ls datasets/polycam/6g-first-scan-poly/normals_from_pretrain/frame_00001.png || python dn_splatter/scripts/normals_from_pretrain.py --data-dir datasets/polycam/6g-first-scan-poly --normal-format dsine", depends_on=["convert-poly-to-ns"]} 29 | generate-pointcloud = "ls datasets/polycam/6g-first-scan-poly/iphone_pointcloud.ply || python dn_splatter/data/mushroom_utils/pointcloud_utils.py --data-path datasets/polycam/6g-first-scan-poly" 30 | train-polycam = """ 31 | ns-train dn-splatter \ 32 | --max-num-iterations 5001 \ 33 | --pipeline.model.use-depth-loss True \ 34 | --pipeline.model.depth-lambda 0.2 \ 35 | 
--pipeline.model.use-depth-smooth-loss True \ 36 | --pipeline.model.use-normal-loss True \ 37 | --pipeline.model.normal-supervision mono \ 38 | normal-nerfstudio --data datasets/polycam/6g-first-scan-poly --normal-format opencv 39 | """ 40 | 41 | example-polycam = {cmd="pwd", depends_on=["generate-normals", "generate-pointcloud", "train-polycam"]} 42 | 43 | [dependencies] 44 | python = "3.10.*" 45 | pip = ">=24.0,<25" 46 | cuda = {version = "*", channel="nvidia/label/cuda-11.8.0"} 47 | pytorch-cuda = {version = "11.8.*", channel="pytorch"} 48 | pytorch = {version = ">=2.2.0,<2.3", channel="pytorch"} 49 | torchvision = {version = ">=0.17.0,<0.18", channel="pytorch"} 50 | gcc = "11.*" 51 | gxx = ">=11.4.0,<11.5" 52 | pyarrow = ">=15.0.2,<15.1" 53 | rerun-sdk = ">=0.15.1,<0.16" 54 | 55 | [pypi-dependencies] 56 | nerfstudio = { git = "https://github.com/nerfstudio-project/nerfstudio.git", rev = "a64026f8db23a4233327a1d0303e6082bf5b9805" } 57 | dn-splatter = { path = ".", editable = true} 58 | 59 | [system-requirements] 60 | libc = { family="glibc", version="2.30" } 61 | -------------------------------------------------------------------------------- /pyproject.toml: -------------------------------------------------------------------------------- 1 | [project] 2 | name = "dn-splatter" 3 | description = "Depth and normal priors for 3D Gaussian splatting and meshing" 4 | version = "0.0.1" 5 | 6 | dependencies = [ 7 | "nerfstudio == 1.1.3", 8 | "gsplat == 1.0.0", 9 | "black == 22.3.0", 10 | "natsort", 11 | "pymeshlab>=2022.2.post2; platform_machine != 'arm64' and platform_machine != 'aarch64'", 12 | "pytest", 13 | "vdbfusion", 14 | "PyMCubes==0.1.2", 15 | "omnidata-tools", 16 | "pytorch-lightning", 17 | "torch", 18 | # required for dsine normal network 19 | "geffnet", 20 | "rerun-sdk", 21 | "pyrender", 22 | ] 23 | 24 | [tool.setuptools.packages.find] 25 | include = ["dn_splatter*"] 26 | 27 | [project.entry-points.'nerfstudio.method_configs'] 28 | dn_splatter = 'dn_splatter.dn_config:dn_splatter' 29 | ags_mesh = 'dn_splatter.dn_config:ags_mesh' 30 | dn_splatter_big = 'dn_splatter.dn_config:dn_splatter_big' 31 | #g-nerfacto = 'dn_splatter.eval.eval_configs:gnerfacto' 32 | #g-depthfacto = 'dn_splatter.eval.eval_configs:gdepthfacto' 33 | #g-neusfacto = 'dn_splatter.eval.eval_configs:gneusfacto' 34 | 35 | [project.entry-points.'nerfstudio.dataparser_configs'] 36 | mushroom = 'dn_splatter:MushroomDataParserSpecification' 37 | replica = 'dn_splatter:ReplicaDataParserSpecification' 38 | nrgbd = 'dn_splatter:NRGBDDataParserSpecification' 39 | gsdf = 'dn_splatter:GSDFStudioDataParserSpecification' 40 | scannetpp = 'dn_splatter:ScanNetppDataParserSpecification' 41 | coolermap = 'dn_splatter:CoolerMapDataParserSpecification' 42 | normal-nerfstudio = 'dn_splatter:NormalNerfstudioSpecification' 43 | 44 | [project.scripts] 45 | # export mesh scripts 46 | gs-mesh = "dn_splatter.export_mesh:entrypoint" --------------------------------------------------------------------------------
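As a quick sanity check of the plugin registration declared in pyproject.toml above, the following minimal sketch (an addition for illustration, assuming the package has been installed, for example editably as in the pixi configuration) enumerates the nerfstudio entry-point groups used by dn-splatter; the group names are taken verbatim from the [project.entry-points...] tables.

# Enumerate plugins registered under the nerfstudio entry-point groups named in
# pyproject.toml; after an install this should list dn_splatter, ags_mesh,
# dn_splatter_big and the custom dataparsers alongside any plugins from other
# installed packages. Requires Python >= 3.10 for the keyword form of entry_points().
from importlib.metadata import entry_points

for group in ("nerfstudio.method_configs", "nerfstudio.dataparser_configs"):
    print(f"[{group}]")
    for ep in sorted(entry_points(group=group), key=lambda e: e.name):
        print(f"  {ep.name} -> {ep.value}")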