├── .gitignore ├── LICENSE.txt ├── README.md ├── confs ├── ABC.conf ├── DTU.conf └── Replica.conf ├── main.py ├── media ├── overview.jpg └── replica.gif ├── requirements.txt ├── scripts ├── download_data.py ├── get_gt_points_DTU.py ├── run_ABC.bash ├── run_DTU.bash └── run_Replica.bash └── src ├── dataset └── dataset.py ├── edge_extraction ├── edge_fitting │ ├── bezier_fit.py │ ├── line_fit.py │ └── main.py ├── extract_parametric_edge.py ├── extract_pointcloud.py ├── extract_util.py └── merging │ └── main.py ├── eval ├── ABC_scans.txt ├── DTU_scans.txt ├── eval_ABC.py ├── eval_DTU.py └── eval_util.py ├── models ├── __init__.py ├── embedder.py ├── loss.py ├── udf_model.py └── udf_renderer_blending.py ├── runner ├── runner_base.py └── runner_udf.py └── utils ├── __init__.py ├── math.py ├── plots.py ├── rend_util.py ├── tensor_dataclass.py ├── visualization.py └── warmup_scheduler.py /.gitignore: -------------------------------------------------------------------------------- 1 | 2 | .idea/* 3 | *.pyc 4 | cmake-build-* 5 | *.egg-info/ 6 | *.so 7 | */build/ 8 | /exp/ 9 | /data/ 10 | 11 | -------------------------------------------------------------------------------- /LICENSE.txt: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2024 Lei Li 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 |

2 |

3D Neural Edge Reconstruction

3 |

4 | Lei Li 5 | · 6 | Songyou Peng 7 | · 8 | Zehao Yu 9 | · 10 | Shaohui Liu 11 | · 12 | Rémi Pautrat 13 |
14 | Xiaochuan Yin 15 | · 16 | Marc Pollefeys 17 |

18 |

CVPR 2024

19 |

Paper | Video | Project Page

20 |

21 | 22 |

23 | 24 |

25 | 26 |

27 | EMAP enables 3D edge reconstruction from multi-view 2D edge maps. 28 |

29 |
30 | 31 | ## Installation 32 | 33 | ``` 34 | git clone https://github.com/cvg/EMAP.git 35 | cd EMAP 36 | 37 | conda create -n emap python=3.8 38 | conda activate emap 39 | 40 | conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia 41 | pip install -r requirements.txt 42 | ``` 43 | 44 | ## Datasets 45 | Download datasets: 46 | ``` 47 | python scripts/download_data.py 48 | ``` 49 | The data is organized as follows: 50 | 51 | ``` 52 | 53 | |-- meta_data.json # camera parameters 54 | |-- color # images for each view 55 | |-- 0_colors.png 56 | |-- 1_colors.png 57 | ... 58 | |-- edge_DexiNed # edge maps extracted from DexiNed 59 | |-- 0_colors.png 60 | |-- 1_colors.png 61 | ... 62 | |-- edge_PidiNet # edge maps extracted from PidiNet 63 | |-- 0_colors.png 64 | |-- 1_colors.png 65 | ... 66 | ``` 67 | 68 | ## Training and Edge Extraction 69 | To train and extract edges on different datasets, use the following commands: 70 | 71 | #### ABC-NEF_Edge Dataset 72 | ``` 73 | bash scripts/run_ABC.bash 74 | ``` 75 | 76 | #### Replica_Edge Dataset 77 | ``` 78 | bash scripts/run_Replica.bash 79 | ``` 80 | 81 | #### DTU_Edge Dataset 82 | ``` 83 | bash scripts/run_DTU.bash 84 | ``` 85 | 86 | ### Checkpoints 87 | We have uploaded the model checkpoints on [Google Drive](https://drive.google.com/file/d/1kU87MqDv5IvwjCt8I8KecTlIok39fuws/view?usp=sharing). 88 | 89 | ## Evaluation 90 | To evaluate extracted edges on ABC-NEF_Edge dataset, use the following commands: 91 | 92 | #### ABC-NEF_Edge Dataset 93 | ``` 94 | python src/eval/eval_ABC.py 95 | ``` 96 | 97 | ## Code Release Status 98 | - [x] Training Code 99 | - [x] Inference Code 100 | - [x] Evaluation Code 101 | - [ ] Custom Dataset Support 102 | 103 | ## License 104 | 105 | Shield: [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) 106 | 107 | The majority of EMAP is licensed under a [MIT License](LICENSE.txt). 108 | 109 | ## Citing EMAP 110 | 111 | If you find the code useful, please consider the following BibTeX entry. 112 | 113 | ```BibTeX 114 | @InProceedings{li2024neural, 115 | title={3D Neural Edge Reconstruction}, 116 | author={Li, Lei and Peng, Songyou and Yu, Zehao and Liu, Shaohui and Pautrat, R{\'e}mi and Yin, Xiaochuan and Pollefeys, Marc}, 117 | booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}, 118 | year={2024}, 119 | } 120 | ``` 121 | 122 | ## Contact 123 | If you encounter any issues, you can also contact Lei through lllei.li0386@gmail.com. 124 | 125 | ## Acknowledgement 126 | 127 | This project is built upon [NeuralUDF](https://github.com/xxlong0/NeuralUDF), [NeuS](https://github.com/Totoro97/NeuS) and [MeshUDF](https://github.com/cvlab-epfl/MeshUDF). We use pretrained [DexiNed](https://github.com/xavysp/DexiNed) and [PidiNet](https://github.com/hellozhuo/pidinet) for edge map extraction. We thank all the authors for their great work and repos. 128 | -------------------------------------------------------------------------------- /confs/ABC.conf: -------------------------------------------------------------------------------- 1 | general { 2 | base_exp_dir = ./exp/ABC/ 3 | expname = emap 4 | 5 | model_type = udf 6 | recording = [ 7 | ./src/models, 8 | ./src/dataset, 9 | ./src/runner, 10 | ] 11 | } 12 | 13 | dataset { 14 | data_dir = ./data/ABC-NEF_Edge/data/ 15 | scan = "00000325" 16 | dataset_name = NEF 17 | detector = DexiNed 18 | near = 0.05 19 | far = 6 20 | AABB = [-1.0, -1.0, -1.0, 1.0, 1.0, 1.0] 21 | } 22 | 23 | train { 24 | latest_model_name = ckpt_best.pth 25 | importance_sample = True 26 | learning_rate = 5e-4 27 | learning_rate_geo = 1e-4 28 | learning_rate_alpha = 0.05 29 | end_iter = 50000 30 | 31 | batch_size = 1024 32 | validate_resolution_level = 1 33 | warm_up_end = 1000 34 | anneal_end = 10000 35 | use_white_bkgd = False 36 | 37 | warmup_sample = False 38 | 39 | save_freq = 1000 40 | val_freq = 1000 41 | report_freq = 1000 42 | 43 | igr_weight = 0.1 44 | igr_ns_weight = 0.0 45 | } 46 | 47 | edge_loss { 48 | edge_weight = 1.0 49 | loss_type = mse 50 | } 51 | 52 | model { 53 | nerf { 54 | D = 8 55 | d_in = 4 56 | d_in_view = 3 57 | W = 256 58 | multires = 10 59 | multires_view = 4 60 | output_ch = 4 61 | skips = [4] 62 | use_viewdirs = True 63 | } 64 | 65 | udf_network { 66 | d_out = 1 67 | d_in = 3 68 | d_hidden = 256 69 | n_layers = 8 70 | skip_in = [4] 71 | multires = 10 72 | bias = 0.5 73 | scale = 1.0 74 | geometric_init = True 75 | weight_norm = True 76 | udf_type = abs # square or abs 77 | } 78 | 79 | variance_network { 80 | init_val = 0.3 81 | } 82 | 83 | rendering_network { 84 | d_feature = 256 85 | mode = no_normal 86 | d_in = 6 87 | d_out = 1 88 | d_hidden = 128 89 | n_layers = 4 90 | weight_norm = True 91 | multires_view = 4 92 | squeeze_out = True 93 | blending_cand_views = 10 94 | } 95 | 96 | 97 | beta_network { 98 | init_var_beta = 0.5 99 | init_var_gamma = 0.3 100 | init_var_zeta = 0.3 101 | beta_min = 0.00005 102 | requires_grad_beta = True 103 | requires_grad_gamma = True 104 | requires_grad_zeta = False 105 | } 106 | 107 | udf_renderer { 108 | n_samples = 64 109 | n_importance = 50 110 | n_outside = 0 111 | up_sample_steps = 5 112 | perturb = 1.0 113 | sdf2alpha_type = numerical 114 | upsampling_type = classical 115 | use_unbias_render = True 116 | } 117 | } 118 | 119 | edge_extraction { 120 | is_pointshift = True 121 | iters = 2 122 | is_linedirection = True 123 | udf_threshold = 0.02 124 | resolution = 128 125 | sampling_delta = 0.005 126 | sampling_N = 50 127 | visible_checking = False 128 | 129 | } 130 | -------------------------------------------------------------------------------- /confs/DTU.conf: -------------------------------------------------------------------------------- 1 | general { 2 | base_exp_dir = ./exp/DTU/ 3 | expname = emap 4 | 5 | model_type = udf 6 | recording = [ 7 | ./src/models, 8 | ./src/dataset, 9 | ./src/runner, 10 | ] 11 | } 12 | 13 | dataset { 14 | data_dir = ./data/DTU_Edge/data/ 15 | scan = "scan105" 16 | dataset_name = DTU 17 | detector = PidiNet 18 | near = 0.05 19 | far = 6.0 20 | AABB = [-1.0, -1.0, -1.0, 1.0, 1.0, 1.0] 21 | } 22 | 23 | train { 24 | latest_model_name = ckpt_best.pth 25 | importance_sample = True 26 | learning_rate = 5e-4 27 | learning_rate_geo = 1e-4 28 | learning_rate_alpha = 0.05 29 | end_iter = 200000 30 | 31 | batch_size = 1024 32 | validate_resolution_level = 1 33 | warm_up_end = 1000 34 | anneal_end = 10000 35 | use_white_bkgd = False 36 | 37 | warmup_sample = False 38 | 39 | save_freq = 5000 40 | val_freq = 5000 41 | report_freq = 1000 42 | 43 | igr_weight = 0.01 44 | igr_ns_weight = 0.0 45 | } 46 | 47 | edge_loss { 48 | edge_weight = 1.0 49 | loss_type = mse 50 | } 51 | 52 | model { 53 | nerf { 54 | D = 8 55 | d_in = 4 56 | d_in_view = 3 57 | W = 256 58 | multires = 10 59 | multires_view = 4 60 | output_ch = 4 61 | skips = [4] 62 | use_viewdirs = True 63 | } 64 | 65 | udf_network { 66 | d_out = 1 67 | d_in = 3 68 | d_hidden = 256 69 | n_layers = 8 70 | skip_in = [4] 71 | multires = 10 72 | bias = 0.5 73 | scale = 1.0 74 | geometric_init = True 75 | weight_norm = True 76 | udf_type = abs # square or abs 77 | } 78 | 79 | variance_network { 80 | init_val = 0.3 81 | } 82 | 83 | rendering_network { 84 | d_feature = 256 85 | mode = no_normal 86 | d_in = 6 87 | d_out = 1 88 | d_hidden = 128 89 | n_layers = 4 90 | weight_norm = True 91 | multires_view = 4 92 | squeeze_out = True 93 | blending_cand_views = 10 94 | } 95 | 96 | 97 | beta_network { 98 | init_var_beta = 0.5 99 | init_var_gamma = 0.3 100 | init_var_zeta = 0.3 101 | beta_min = 0.00005 102 | requires_grad_beta = True 103 | requires_grad_gamma = True 104 | requires_grad_zeta = False 105 | } 106 | 107 | udf_renderer { 108 | n_samples = 64 109 | n_importance = 50 110 | n_outside = 0 111 | up_sample_steps = 5 112 | perturb = 1.0 113 | sdf2alpha_type = numerical 114 | upsampling_type = classical 115 | use_unbias_render = True 116 | } 117 | } 118 | 119 | edge_extraction { 120 | is_pointshift = True 121 | iters = 1 122 | is_linedirection = True 123 | udf_threshold = 0.015 124 | resolution = 256 125 | sampling_delta = 0.005 126 | sampling_N = 50 127 | visible_checking = True 128 | 129 | } 130 | -------------------------------------------------------------------------------- /confs/Replica.conf: -------------------------------------------------------------------------------- 1 | general { 2 | base_exp_dir = ./exp/Replica/ 3 | expname = emap 4 | 5 | model_type = udf 6 | recording = [ 7 | ./src/models, 8 | ./src/dataset, 9 | ./src/runner, 10 | ] 11 | } 12 | 13 | dataset { 14 | data_dir = ./data/Replica_Edge 15 | scan = "room0" 16 | dataset_name = Replica 17 | detector = PidiNet 18 | near = 0.05 19 | far = 2.5 20 | AABB = [-1.0, -1.0, -1.0, 1.0, 1.0, 1.0] 21 | } 22 | 23 | train { 24 | latest_model_name = ckpt_best.pth 25 | importance_sample = True 26 | learning_rate = 5e-4 27 | learning_rate_geo = 1e-4 28 | learning_rate_alpha = 0.05 29 | end_iter = 200000 30 | 31 | batch_size = 1024 32 | validate_resolution_level = 1 33 | warm_up_end = 1000 34 | anneal_end = 10000 35 | use_white_bkgd = False 36 | 37 | warmup_sample = False 38 | 39 | save_freq = 5000 40 | val_freq = 5000 41 | report_freq = 1000 42 | 43 | igr_weight = 0.01 44 | igr_ns_weight = 0.0 45 | } 46 | 47 | edge_loss { 48 | edge_weight = 1.0 49 | loss_type = mse 50 | } 51 | 52 | model { 53 | nerf { 54 | D = 8 55 | d_in = 4 56 | d_in_view = 3 57 | W = 256 58 | multires = 10 59 | multires_view = 4 60 | output_ch = 4 61 | skips = [4] 62 | use_viewdirs = True 63 | } 64 | 65 | udf_network { 66 | d_out = 1 67 | d_in = 3 68 | d_hidden = 256 69 | n_layers = 8 70 | skip_in = [4] 71 | multires = 6 # 6 or 10 72 | bias = 0.5 73 | scale = 1.0 74 | geometric_init = True 75 | weight_norm = True 76 | udf_type = abs # square or abs 77 | } 78 | 79 | variance_network { 80 | init_val = 0.3 81 | } 82 | 83 | rendering_network { 84 | d_feature = 256 85 | mode = no_normal 86 | d_in = 6 87 | d_out = 1 88 | d_hidden = 128 89 | n_layers = 4 90 | weight_norm = True 91 | multires_view = 4 92 | squeeze_out = True 93 | blending_cand_views = 10 94 | } 95 | 96 | beta_network { 97 | init_var_beta = 0.5 98 | init_var_gamma = 0.3 99 | init_var_zeta = 0.3 100 | beta_min = 0.00005 101 | requires_grad_beta = True 102 | requires_grad_gamma = True 103 | requires_grad_zeta = False 104 | } 105 | 106 | udf_renderer { 107 | n_samples = 64 108 | n_importance = 50 109 | n_outside = 0 110 | up_sample_steps = 5 111 | perturb = 1.0 112 | sdf2alpha_type = numerical 113 | upsampling_type = classical 114 | use_unbias_render = True 115 | } 116 | } 117 | 118 | edge_extraction { 119 | is_pointshift = True 120 | iters = 1 121 | is_linedirection = True 122 | udf_threshold = 0.01 123 | resolution = 256 124 | sampling_delta = 0.005 125 | sampling_N = 50 126 | visible_checking = True 127 | 128 | } -------------------------------------------------------------------------------- /main.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import argparse 3 | import logging 4 | import numpy as np 5 | import random 6 | from pyhocon import ConfigFactory 7 | from src.runner.runner_udf import Runner_UDF 8 | 9 | 10 | def fix_random_seeds(seed=42): 11 | """ 12 | Fix the random seeds for reproducibility. 13 | """ 14 | torch.manual_seed(seed) 15 | if torch.cuda.is_available(): 16 | torch.cuda.manual_seed_all(seed) 17 | np.random.seed(seed) 18 | random.seed(seed) 19 | 20 | 21 | def get_runner(mode): 22 | """ 23 | Get the runner based on the provided mode. 24 | """ 25 | runners = { 26 | "udf": Runner_UDF, 27 | } 28 | if mode not in runners: 29 | raise ValueError(f"Unknown mode: {mode}") 30 | return runners[mode] 31 | 32 | 33 | def main(): 34 | """ 35 | Main function to parse arguments and run the appropriate mode. 36 | """ 37 | torch.set_default_dtype(torch.float32) 38 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s" 39 | logging.basicConfig(level=logging.DEBUG, format=FORMAT) 40 | 41 | parser = argparse.ArgumentParser() 42 | 43 | # Parameters for the training 44 | parser.add_argument( 45 | "--conf", type=str, default="./confs/ABC.conf", help="Path to the config file." 46 | ) 47 | parser.add_argument( 48 | "--mode", 49 | type=str, 50 | default="train", 51 | choices=["train", "extract_edge"], 52 | help="Mode to run.", 53 | ) 54 | parser.add_argument( 55 | "--scan", type=str, default="null", help="The name of a dataset." 56 | ) 57 | parser.add_argument("--gpu", type=int, default=0, help="GPU id to use.") 58 | parser.add_argument( 59 | "--is_continue", 60 | default=False, 61 | action="store_true", 62 | help="Flag to continue training.", 63 | ) 64 | 65 | args = parser.parse_args() 66 | 67 | # Fix the random seed 68 | fix_random_seeds() 69 | 70 | with open(args.conf, "r") as f: 71 | conf_text = f.read() 72 | conf = ConfigFactory.parse_string(conf_text) 73 | 74 | if args.scan != "null": 75 | conf["dataset"]["scan"] = args.scan 76 | 77 | logging.info(f"Run on scan of {conf['dataset']['scan']}") 78 | 79 | runner_class = get_runner(conf["general"]["model_type"]) 80 | runner = runner_class(conf, args.mode, args.is_continue, args) 81 | 82 | if args.mode == "train": 83 | logging.info(f"Training UDF") 84 | runner.train() 85 | elif args.mode == "extract_edge": 86 | logging.info(f"Extracting edges from UDF") 87 | runner.extract_edge( 88 | resolution=conf["edge_extraction"]["resolution"], 89 | udf_threshold=conf["edge_extraction"]["udf_threshold"], 90 | sampling_N=conf["edge_extraction"]["sampling_N"], 91 | sampling_delta=conf["edge_extraction"]["sampling_delta"], 92 | is_pointshift=conf["edge_extraction"]["is_pointshift"], 93 | iters=conf["edge_extraction"]["iters"], 94 | is_linedirection=conf["edge_extraction"]["is_linedirection"], 95 | visible_checking=conf["edge_extraction"]["visible_checking"], 96 | ) 97 | else: 98 | raise ValueError(f"Invalid mode: {args.mode}") 99 | 100 | 101 | if __name__ == "__main__": 102 | main() 103 | -------------------------------------------------------------------------------- /media/overview.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/EMAP/652f24ecc3f3cbf538928f27cc6d55dbebb360c7/media/overview.jpg -------------------------------------------------------------------------------- /media/replica.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/EMAP/652f24ecc3f3cbf538928f27cc6d55dbebb360c7/media/replica.gif -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | pyhocon==0.3.59 2 | icecream==2.1.3 3 | opencv-python==4.9.0.80 4 | scipy==1.10.1 5 | h5py==3.10.0 6 | open3d==0.18.0 7 | point_cloud_utils==0.30.4 8 | trimesh==4.2.2 9 | tensorboard==2.14.0 10 | torch_optimizer==0.3.0 11 | flow_vis==0.1 12 | scikit-image==0.21.0 13 | termcolor==2.4.0 14 | gitpython==3.1.43 15 | gdown==5.2.0 -------------------------------------------------------------------------------- /scripts/download_data.py: -------------------------------------------------------------------------------- 1 | import os 2 | import zipfile 3 | import gdown 4 | 5 | 6 | def download_and_unzip_google_drive_files(paths, download_to="./data"): 7 | if not os.path.exists(download_to): 8 | os.makedirs(download_to) 9 | 10 | for path in paths: 11 | # Convert Google Drive link to direct download link 12 | file_id = path.split("/d/")[1].split("/")[0] 13 | direct_link = f"https://drive.google.com/uc?id={file_id}" 14 | 15 | print(f"Downloading from {direct_link}...") 16 | output_path = os.path.join(download_to, f"{file_id}.zip") 17 | gdown.download(direct_link, output_path, quiet=False) 18 | print(f"Downloaded {output_path}") 19 | 20 | # Unzip the downloaded file 21 | with zipfile.ZipFile(output_path, "r") as zip_ref: 22 | zip_ref.extractall(download_to) 23 | os.remove(output_path) 24 | print(f"Unzipped and deleted: {output_path}") 25 | 26 | print(f"Finished extracting files to: {download_to}") 27 | 28 | 29 | # Google Drive file paths 30 | paths = [ 31 | "https://drive.google.com/file/d/17aUcCJCP5vgARs237H0BtlRoms5-CR6e/view?usp=sharing", 32 | "https://drive.google.com/file/d/1eZZiMcTfoiYfIxtv4Wy3lQYAudZpKlE0/view?usp=sharing", 33 | "https://drive.google.com/file/d/1pum-25MEFhXQu1fZLy_f9lRMBxvF1ssm/view?usp=sharing", 34 | ] 35 | 36 | download_to = "./data" 37 | download_and_unzip_google_drive_files(paths, download_to) 38 | -------------------------------------------------------------------------------- /scripts/get_gt_points_DTU.py: -------------------------------------------------------------------------------- 1 | import trimesh 2 | import open3d as o3d 3 | import numpy as np 4 | import math 5 | import json 6 | import os 7 | from src.eval.eval_util import ( 8 | set_random_seeds, 9 | load_from_json, 10 | downsample_point_cloud_average, 11 | ) 12 | from src.edge_extraction.edge_fitting.bezier_fit import bezier_curve 13 | import argparse 14 | from pathlib import Path 15 | import json 16 | import cv2 17 | 18 | 19 | def write_vis_pcd(file, points, colors): 20 | pcd = o3d.geometry.PointCloud() 21 | pcd.points = o3d.utility.Vector3dVector(points) 22 | pcd.colors = o3d.utility.Vector3dVector(colors) 23 | o3d.io.write_point_cloud(file, pcd) 24 | 25 | 26 | def convert_ply_to_obj(ply_file_path, obj_file_path): 27 | # Load the .ply file 28 | mesh = trimesh.load(ply_file_path) 29 | # Export the mesh to .obj format 30 | mesh.export(obj_file_path, file_type="obj") 31 | 32 | 33 | def save_point_cloud(file_path, points): 34 | pcd = o3d.geometry.PointCloud() 35 | pcd.points = o3d.utility.Vector3dVector(points) 36 | o3d.io.write_point_cloud(file_path, pcd) 37 | 38 | 39 | def load_from_json(filename: Path): 40 | """Load a dictionary from a JSON filename. 41 | 42 | Args: 43 | filename: The filename to load from. 44 | """ 45 | assert filename.suffix == ".json" 46 | with open(filename, encoding="UTF-8") as file: 47 | return json.load(file) 48 | 49 | 50 | def sample_single_tri(input_): 51 | n1, n2, v1, v2, tri_vert = input_ 52 | c = np.mgrid[: n1 + 1, : n2 + 1] 53 | c += 0.5 54 | c[0] /= max(n1, 1e-7) 55 | c[1] /= max(n2, 1e-7) 56 | c = np.transpose(c, (1, 2, 0)) 57 | k = c[c.sum(axis=-1) < 1] # m2 58 | q = v1 * k[:, :1] + v2 * k[:, 1:] + tri_vert 59 | return q 60 | 61 | 62 | def convert_mesh_gt2world(mesh_path, out_mesh_path, gttoworld): 63 | mesh = trimesh.load(mesh_path) 64 | # mesh.transform(gttoworld) 65 | mesh.apply_transform(gttoworld) 66 | mesh.export(out_mesh_path, file_type="obj") 67 | return mesh 68 | 69 | 70 | def get_edge_maps(data_dir): 71 | meta = load_from_json(Path(data_dir) / "meta_data.json") 72 | h, w = meta["height"], meta["width"] 73 | edges_list, intrinsics_list, camtoworld_list = [], [], [] 74 | for idx, frame in enumerate(meta["frames"]): 75 | intrinsics = np.array(frame["intrinsics"]) 76 | camtoworld = np.array(frame["camtoworld"])[:4, :4] 77 | edges_list.append( 78 | os.path.join( 79 | data_dir, 80 | "edge_PidiNet", 81 | frame["rgb_path"], 82 | ) 83 | ) 84 | intrinsics_list.append(intrinsics) 85 | camtoworld_list.append(camtoworld) 86 | 87 | edges = [cv2.imread(im_name, 0)[..., None] for im_name in edges_list] 88 | edges = 1 - np.stack(edges) / 255.0 89 | intrinsics_list = np.stack(intrinsics_list) 90 | camtoworld_list = np.stack(camtoworld_list) 91 | return edges, intrinsics_list, camtoworld_list, h, w 92 | 93 | 94 | def compute_visibility( 95 | gt_points, 96 | edge_maps, 97 | intrinsics_list, 98 | camtoworld_list, 99 | h, 100 | w, 101 | edge_visibility_threshold, 102 | edge_visibility_frames, 103 | ): 104 | img_frames = len(edge_maps) 105 | point_visibility_matrix = np.zeros((len(gt_points), img_frames)) 106 | 107 | for frame_idx, (edge_map, intrinsic, camtoworld) in enumerate( 108 | zip(edge_maps, intrinsics_list, camtoworld_list) 109 | ): 110 | K = intrinsic[:3, :3] 111 | worldtocam = np.linalg.inv(camtoworld) 112 | edge_uv = project2D(K, worldtocam, gt_points) 113 | edge_uv = np.round(edge_uv).astype(np.int64) 114 | 115 | # Boolean mask for valid u, v coordinates 116 | valid_u_mask = (edge_uv[:, 0] >= 0) & (edge_uv[:, 0] < w) 117 | valid_v_mask = (edge_uv[:, 1] >= 0) & (edge_uv[:, 1] < h) 118 | valid_mask = valid_u_mask & valid_v_mask 119 | 120 | valid_edge_uv = edge_uv[valid_mask] 121 | valid_projected_edges = edge_map[valid_edge_uv[:, 1], valid_edge_uv[:, 0]] 122 | 123 | # Calculate visibility in a vectorized manner 124 | visibility = (valid_projected_edges > edge_visibility_threshold).reshape(-1) 125 | 126 | point_visibility_matrix[valid_mask, frame_idx] = visibility.astype(float) 127 | 128 | return np.sum(point_visibility_matrix, axis=1) > edge_visibility_frames 129 | 130 | 131 | def project2D(K, worldtocam, points3d): 132 | shape = points3d.shape 133 | R = worldtocam[:3, :3] 134 | T = worldtocam[:3, 3:] 135 | 136 | projected = K @ (R @ points3d.T + T) 137 | projected = projected.T 138 | projected = projected / projected[:, -1:] 139 | uv = projected.reshape(*shape)[..., :2].reshape(-1, 2) 140 | return uv 141 | 142 | 143 | def bezier_curve_length(control_points, num_samples): 144 | def binomial_coefficient(n, i): 145 | return math.factorial(n) // (math.factorial(i) * math.factorial(n - i)) 146 | 147 | def derivative_bezier(t): 148 | n = len(control_points) - 1 149 | point = np.array([0.0, 0.0, 0.0]) 150 | for i, (p1, p2) in enumerate(zip(control_points[:-1], control_points[1:])): 151 | point += ( 152 | n 153 | * binomial_coefficient(n - 1, i) 154 | * (1 - t) ** (n - 1 - i) 155 | * t**i 156 | * (np.array(p2) - np.array(p1)) 157 | ) 158 | return point 159 | 160 | def simpson_integral(a, b, num_samples): 161 | h = (b - a) / num_samples 162 | sum1 = sum( 163 | np.linalg.norm(derivative_bezier(a + i * h)) 164 | for i in range(1, num_samples, 2) 165 | ) 166 | sum2 = sum( 167 | np.linalg.norm(derivative_bezier(a + i * h)) 168 | for i in range(2, num_samples - 1, 2) 169 | ) 170 | return ( 171 | ( 172 | np.linalg.norm(derivative_bezier(a)) 173 | + 4 * sum1 174 | + 2 * sum2 175 | + np.linalg.norm(derivative_bezier(b)) 176 | ) 177 | * h 178 | / 3 179 | ) 180 | 181 | # Compute the length of the 3D Bezier curve using composite Simpson's rule 182 | length = 0.0 183 | for i in range(num_samples): 184 | t0 = i / num_samples 185 | t1 = (i + 1) / num_samples 186 | length += simpson_integral(t0, t1, num_samples) 187 | 188 | return length 189 | 190 | 191 | def bezier_para_to_point_length(control_points, num_samples=100): 192 | t_fit = np.linspace(0, 1, num_samples) 193 | curve_point_set = [] 194 | curve_length_set = [] 195 | for control_point in control_points: 196 | points = bezier_curve( 197 | t_fit, 198 | control_point[0, 0], 199 | control_point[0, 1], 200 | control_point[0, 2], 201 | control_point[1, 0], 202 | control_point[1, 1], 203 | control_point[1, 2], 204 | control_point[2, 0], 205 | control_point[2, 1], 206 | control_point[2, 2], 207 | control_point[3, 0], 208 | control_point[3, 1], 209 | control_point[3, 2], 210 | ) 211 | lengths = bezier_curve_length(control_point, num_samples=num_samples) 212 | curve_point_set.append(points) 213 | curve_length_set.append(lengths) 214 | return ( 215 | np.array(curve_point_set).reshape(-1, num_samples, 3, 1), 216 | np.array(curve_length_set), 217 | ) 218 | 219 | 220 | def main(gt_point_cloud_dir, dataset_dir, out_dir): 221 | set_random_seeds() 222 | gt_point_cloud_dir = os.path.join(gt_point_cloud_dir, "Points", "stl") 223 | if not os.path.exists(gt_point_cloud_dir): 224 | print( 225 | f"Ground truth point cloud directory {gt_point_cloud_dir} does not exist. Please download it from http://roboimagedata2.compute.dtu.dk/data/MVS/Points.zip" 226 | ) 227 | return 228 | 229 | scan_names_dict = { 230 | "scan37": [0.55, 0.3], 231 | "scan83": [0.65, 0.2], 232 | "scan105": [0.65, 0.2], 233 | "scan110": [0.5, 0.3], 234 | "scan118": [0.5, 0.3], 235 | "scan122": [0.35, 0.4], 236 | } 237 | 238 | os.makedirs(out_dir, exist_ok=True) 239 | 240 | for scan_name, ( 241 | edge_visibility_threshold, 242 | edge_visibility_frames_ratio, 243 | ) in scan_names_dict.items(): 244 | output_file = os.path.join(out_dir, scan_name, "edge_points.ply") 245 | if os.path.exists(output_file): 246 | print(f"{output_file} already exists. Skipping.") 247 | continue 248 | os.makedirs(os.path.join(out_dir, scan_name), exist_ok=True) 249 | meta_data_json_path = os.path.join(dataset_dir, scan_name, "meta_data.json") 250 | meta_base_dir = os.path.join(dataset_dir, scan_name) 251 | worldtogt = np.array(load_from_json(Path(meta_data_json_path))["worldtogt"]) 252 | gttoworld = np.linalg.inv(worldtogt) 253 | gt_point_cloud_path = os.path.join( 254 | gt_point_cloud_dir, 255 | f"stl{int(scan_name[4:]):03d}_total.ply", 256 | ) 257 | gt_point_cloud = o3d.io.read_point_cloud(gt_point_cloud_path) 258 | 259 | gt_points = np.asarray(gt_point_cloud.points) 260 | points = gt_points @ gttoworld[:3, :3].T + gttoworld[:3, 3][None, ...] 261 | 262 | edge_maps, intrinsics_list, camtoworld_list, h, w = get_edge_maps(meta_base_dir) 263 | num_frames = len(edge_maps) 264 | edge_visibility_frames = max( 265 | 1, round(edge_visibility_frames_ratio * num_frames) 266 | ) 267 | points_visibility = compute_visibility( 268 | points, 269 | edge_maps, 270 | intrinsics_list, 271 | camtoworld_list, 272 | h, 273 | w, 274 | edge_visibility_threshold, 275 | edge_visibility_frames, 276 | ) 277 | 278 | print( 279 | f"{scan_name}: before visibility check: {len(points)}, after visibility check: {np.sum(points_visibility)}" 280 | ) 281 | 282 | edge_points = points[points_visibility] 283 | downsampled_edge_points = downsample_point_cloud_average( 284 | edge_points, num_voxels_per_axis=256 285 | ) 286 | downsampled_edge_points = ( 287 | downsampled_edge_points @ worldtogt[:3, :3].T + worldtogt[:3, 3][None, ...] 288 | ) 289 | save_point_cloud(output_file, downsampled_edge_points) 290 | print(f"Saved downsampled edge point cloud to {output_file}") 291 | 292 | 293 | if __name__ == "__main__": 294 | parser = argparse.ArgumentParser( 295 | description="Process and evaluate point cloud data." 296 | ) 297 | parser.add_argument( 298 | "--gt_point_cloud_dir", 299 | type=str, 300 | default="data/DTU_Edge/groundtruth", 301 | ) 302 | parser.add_argument("--dataset_dir", type=str, default="data/DTU_Edge/data") 303 | parser.add_argument( 304 | "--out_dir", 305 | type=str, 306 | default="data/DTU_Edge/groundtruth/edge_points", 307 | ) 308 | args = parser.parse_args() 309 | 310 | main(args.gt_point_cloud_dir, args.dataset_dir, args.out_dir) 311 | -------------------------------------------------------------------------------- /scripts/run_ABC.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # source ~/miniconda3/etc/profile.d/conda.sh 5 | # conda activate emap 6 | 7 | # Set the PYTHONPATH environment variable 8 | export PYTHONPATH=. 9 | 10 | # Set the device for CUDA to use 11 | export CUDA_VISIBLE_DEVICES=0 12 | 13 | # Train UDF field 14 | python main.py --conf ./confs/ABC.conf --mode train 15 | 16 | # Extract parametric edges 17 | python main.py --conf ./confs/ABC.conf --mode extract_edge 18 | -------------------------------------------------------------------------------- /scripts/run_DTU.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # source ~/miniconda3/etc/profile.d/conda.sh 5 | # conda activate emap 6 | 7 | # Set the PYTHONPATH environment variable 8 | export PYTHONPATH=. 9 | 10 | # Set the device for CUDA to use 11 | export CUDA_VISIBLE_DEVICES=0 12 | 13 | # Train UDF field 14 | python main.py --conf ./confs/DTU.conf --mode train 15 | 16 | # Extract parametric edges 17 | python main.py --conf ./confs/DTU.conf --mode extract_edge 18 | -------------------------------------------------------------------------------- /scripts/run_Replica.bash: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | set -e 3 | 4 | # source ~/miniconda3/etc/profile.d/conda.sh 5 | # conda activate emap 6 | 7 | # Set the PYTHONPATH environment variable 8 | export PYTHONPATH=. 9 | 10 | # Set the device for CUDA to use 11 | export CUDA_VISIBLE_DEVICES=0 12 | 13 | # Train UDF field 14 | python main.py --conf ./confs/Replica.conf --mode train 15 | 16 | # Extract parametric edges 17 | python main.py --conf ./confs/Replica.conf --mode extract_edge 18 | -------------------------------------------------------------------------------- /src/dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import cv2 as cv 4 | import numpy as np 5 | import os 6 | from scipy.spatial.transform import Rotation as Rot 7 | from scipy.spatial.transform import Slerp 8 | from pathlib import Path 9 | import json 10 | import random 11 | 12 | 13 | def load_from_json(filename: Path): 14 | """Load a dictionary from a JSON filename. 15 | 16 | Args: 17 | filename: The filename to load from. 18 | """ 19 | assert filename.suffix == ".json" 20 | with open(filename, encoding="UTF-8") as file: 21 | return json.load(file) 22 | 23 | 24 | # This function is borrowed from IDR: https://github.com/lioryariv/idr 25 | def load_K_Rt_from_P(filename, P=None): 26 | if P is None: 27 | lines = open(filename).read().splitlines() 28 | if len(lines) == 4: 29 | lines = lines[1:] 30 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 31 | P = np.asarray(lines).astype(np.float32).squeeze() 32 | 33 | out = cv.decomposeProjectionMatrix(P) 34 | K = out[0] 35 | R = out[1] 36 | t = out[2] 37 | 38 | K = K / K[2, 2] 39 | intrinsics = np.eye(4) 40 | intrinsics[:3, :3] = K 41 | 42 | pose = np.eye(4, dtype=np.float32) 43 | pose[:3, :3] = R.transpose() 44 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 45 | 46 | return intrinsics, pose 47 | 48 | 49 | class Dataset: 50 | def __init__(self, conf): 51 | super(Dataset, self).__init__() 52 | print("Load data: Begin") 53 | self.device = ( 54 | torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu") 55 | ) 56 | self.conf = conf 57 | self.scan = conf.get_string("scan") 58 | self.data_dir = os.path.join(conf.get_string("data_dir"), self.scan) 59 | self.dataset_name = self.conf.get_string("dataset_name", default="ABC") 60 | self.detector = conf.get_string("detector", default="DexiNed") 61 | assert self.detector in ["DexiNed", "PidiNet"] 62 | self.load_metadata(conf) 63 | self.load_image_data() 64 | print("Load data: End") 65 | 66 | def load_metadata(self, conf): 67 | meta = load_from_json(Path(self.data_dir) / "meta_data.json") 68 | self.meta = meta 69 | self.intrinsics_all = [] 70 | self.pose_all = [] 71 | self.edges_list = [] 72 | self.colors_list = [] 73 | 74 | meta_scene_box = meta["scene_box"] 75 | 76 | self.near = meta_scene_box["near"] 77 | self.far = meta_scene_box["far"] 78 | self.radius = meta_scene_box["radius"] 79 | 80 | H, W = meta["height"], meta["width"] 81 | self.H, self.W, self.image_pixels = H, W, H * W 82 | 83 | for idx, frame in enumerate(meta["frames"]): 84 | self.process_frame(frame) 85 | 86 | def process_frame(self, frame): 87 | intrinsics = torch.tensor(frame["intrinsics"]) 88 | camtoworld = torch.tensor(frame["camtoworld"])[:4, :4] 89 | image_name = frame["rgb_path"] 90 | if self.detector == "PidiNet": 91 | self.edges_list.append( 92 | os.path.join( 93 | self.data_dir, 94 | "edge_PidiNet", 95 | image_name[:-4] + ".png", 96 | ) 97 | ) 98 | elif self.detector == "DexiNed": 99 | self.edges_list.append( 100 | os.path.join(self.data_dir, "edge_DexiNed", image_name) 101 | ) 102 | self.colors_list.append(os.path.join(self.data_dir, "color", image_name)) 103 | self.intrinsics_all.append(intrinsics) 104 | self.pose_all.append(camtoworld) 105 | 106 | def load_image_data(self): 107 | self.load_edges_data() 108 | self.colors_np = ( 109 | np.stack([cv.imread(im_name) for im_name in self.colors_list]) / 255.0 110 | ) 111 | 112 | self.n_images = len(self.edges_list) 113 | self.edges = torch.from_numpy( 114 | self.edges_np.astype(np.float32) 115 | ) # .to(self.device) 116 | self.colors = torch.from_numpy(self.colors_np.astype(np.float32)) 117 | 118 | # .to( 119 | # self.device 120 | # ) 121 | 122 | self.masks_np = (self.edges_np > 0.5).astype(np.float32) 123 | self.masks = torch.from_numpy(self.masks_np.astype(np.float32)) 124 | 125 | self.intrinsics_all = torch.stack(self.intrinsics_all) 126 | self.intrinsics_all_inv = torch.inverse(self.intrinsics_all) 127 | self.focal = self.intrinsics_all[0][0, 0] 128 | self.pose_all = torch.stack(self.pose_all) 129 | 130 | self.object_bbox_min = np.array(self.meta["scene_box"]["aabb"][0]) 131 | self.object_bbox_max = np.array(self.meta["scene_box"]["aabb"][1]) 132 | 133 | def load_edges_data(self): 134 | edges = [cv.imread(im_name, 0)[..., None] for im_name in self.edges_list] 135 | self.edges_np = np.stack(edges) / 255.0 136 | 137 | def gen_rays_at(self, img_idx, resolution_level=1): 138 | """ 139 | Generate rays at world space from one camera. 140 | """ 141 | l = resolution_level 142 | tx = torch.linspace(0, self.W - 1, self.W // l) 143 | ty = torch.linspace(0, self.H - 1, self.H // l) 144 | pixels_x, pixels_y = torch.meshgrid(tx, ty) 145 | p = torch.stack( 146 | [pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1 147 | ) # W, H, 3 148 | p = torch.matmul( 149 | self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None] 150 | ).squeeze() # W, H, 3 151 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3 152 | depth_scale = rays_v[:, :, 2:] 153 | rays_v = torch.matmul( 154 | self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None] 155 | ).squeeze() # W, H, 3 156 | rays_o = self.pose_all[img_idx, None, None, :3, 3].expand( 157 | rays_v.shape 158 | ) # W, H, 3 159 | pose = self.pose_all[img_idx] # [4, 4] 160 | intrinsics = self.intrinsics_all[img_idx] # [4, 4] 161 | return ( 162 | rays_o.transpose(0, 1).to(self.device), 163 | rays_v.transpose(0, 1).to(self.device), 164 | pose.to(self.device), 165 | intrinsics.to(self.device), 166 | depth_scale.to(self.device), 167 | ) 168 | 169 | def gen_one_ray_at(self, img_idx, x, y): 170 | """ 171 | 172 | Parameters 173 | ---------- 174 | img_idx : 175 | x : for width 176 | y : for height 177 | 178 | Returns 179 | ------- 180 | 181 | """ 182 | image = np.uint8(self.edges_np[img_idx] * 256) 183 | image2 = cv.circle(image, (x, y), radius=10, color=(0, 0, 255), thickness=-1) 184 | 185 | pixels_x = torch.Tensor([x]).long() 186 | pixels_y = torch.Tensor([y]).long() 187 | edge = self.edges[img_idx][(pixels_y, pixels_x)] # batch_size, 1 188 | color = self.colors[img_idx][(pixels_y, pixels_x)] # batch_size, 3 189 | mask = (self.masks[img_idx][(pixels_y, pixels_x)] > 0).to( 190 | torch.float32 191 | ) # batch_size, 1 192 | 193 | p = torch.stack( 194 | [pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1 195 | ).float() # batch_size, 3 196 | p = torch.matmul( 197 | self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None] 198 | ).squeeze( 199 | -1 200 | ) # batch_size, 3 201 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3 202 | rays_v = torch.matmul( 203 | self.pose_all[img_idx, None, :3, :3], rays_v[:, :, None] 204 | ).squeeze( 205 | -1 206 | ) # batch_size, 3 207 | rays_o = self.pose_all[img_idx, None, :3, 3].expand( 208 | rays_v.shape 209 | ) # batch_size, 3 210 | 211 | return ( 212 | { 213 | "rays_o": rays_o.to(self.device), 214 | "rays_v": rays_v.to(self.device), 215 | "edge": edge.to(self.device), 216 | "color": color.to(self.device), 217 | "mask": mask[:, :1].to(self.device), 218 | }, 219 | image2, 220 | ) 221 | 222 | def gen_random_rays_patches_at( 223 | self, 224 | img_idx, 225 | batch_size, 226 | importance_sample=False, 227 | ): 228 | """ 229 | Generate random rays at world space from one camera. 230 | """ 231 | 232 | if not importance_sample: 233 | pixels_x = torch.randint(low=0, high=self.W, size=[batch_size]) 234 | pixels_y = torch.randint(low=0, high=self.H, size=[batch_size]) 235 | elif ( 236 | importance_sample and self.masks is not None 237 | ): # sample more pts in the valid mask regions 238 | img_np = self.edges[img_idx].cpu().numpy() 239 | edge_density = np.mean(img_np) 240 | probabilities = np.ones_like(img_np) * edge_density 241 | probabilities[img_np > 0.1] = 1.0 - edge_density 242 | probabilities = probabilities.reshape(-1) 243 | 244 | # randomly sample 50% 245 | pixels_x_1 = torch.randint(low=0, high=self.W, size=[batch_size // 2]) 246 | pixels_y_1 = torch.randint(low=0, high=self.H, size=[batch_size // 2]) 247 | 248 | ys, xs = torch.meshgrid( 249 | torch.linspace(0, self.H - 1, self.H), 250 | torch.linspace(0, self.W - 1, self.W), 251 | ) # pytorch's meshgrid has indexing='ij' 252 | p = torch.stack([xs, ys], dim=-1) # H, W, 2 253 | p_valid = p[self.masks[img_idx][:, :, 0] >= 0] # [num, 2] 254 | 255 | # randomly sample 50% mainly from edge regions 256 | number_list = np.arange(self.image_pixels) 257 | random_idx = random.choices(number_list, probabilities, k=batch_size // 2) 258 | random_idx = torch.from_numpy(np.array(random_idx)).to( 259 | torch.int64 260 | ) # .to(self.device) 261 | p_select = p_valid[random_idx] # [N_rays//2, 2] 262 | pixels_x_2 = p_select[:, 0] 263 | pixels_y_2 = p_select[:, 1] 264 | 265 | pixels_x = torch.cat([pixels_x_1, pixels_x_2], dim=0).to(torch.int64) 266 | pixels_y = torch.cat([pixels_y_1, pixels_y_2], dim=0).to(torch.int64) 267 | # normalized ndc uv coordinates, (-1, 1) 268 | ndc_u = 2 * pixels_x / (self.W - 1) - 1 269 | ndc_v = 2 * pixels_y / (self.H - 1) - 1 270 | rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float() 271 | 272 | edge = self.edges[img_idx][(pixels_y, pixels_x)] # batch_size, 1 273 | p = torch.stack( 274 | [pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1 275 | ).float() # batch_size, 3 276 | p = torch.matmul( 277 | self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None] 278 | ).squeeze() # batch_size, 3 279 | 280 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3 281 | depth_scale = rays_v[:, 2:] # batch_size, 1 282 | rays_v = torch.matmul( 283 | self.pose_all[img_idx, None, :3, :3], rays_v[:, :, None] 284 | ).squeeze() # batch_size, 3 285 | rays_o = self.pose_all[img_idx, None, :3, 3].expand( 286 | rays_v.shape 287 | ) # batch_size, 3 288 | 289 | rays = { 290 | "rays_o": rays_o.to(self.device), 291 | "rays_v": rays_v.to(self.device), 292 | "edge": edge.to(self.device), 293 | } 294 | 295 | pose = self.pose_all[img_idx] # [4, 4] 296 | intrinsics = self.intrinsics_all[img_idx] # [4, 4] 297 | 298 | sample = { 299 | "rays": rays, 300 | "pose": pose, 301 | "intrinsics": intrinsics, 302 | "rays_ndc_uv": rays_ndc_uv.to(self.device), 303 | "rays_norm_XYZ_cam": p.to(self.device), # - XYZ_cam, before multiply depth, 304 | "depth_scale": depth_scale.to(self.device), 305 | } 306 | 307 | return sample 308 | 309 | def edge_at(self, idx, resolution_level): 310 | edge = cv.imread(self.edges_list[idx], 0)[..., None] 311 | edge = ( 312 | cv.resize(edge, (self.W // resolution_level, self.H // resolution_level)) 313 | ).clip(0, 255) 314 | return edge 315 | 316 | def color_at(self, idx, resolution_level): 317 | img = cv.imread(self.colors_list[idx]) 318 | img = cv.resize( 319 | img, 320 | (self.W // resolution_level, self.H // resolution_level), 321 | interpolation=cv.INTER_NEAREST, 322 | ) 323 | return img 324 | -------------------------------------------------------------------------------- /src/edge_extraction/edge_fitting/bezier_fit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.optimize import curve_fit 3 | 4 | 5 | def bezier_curve(tt, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11): 6 | n = len(tt) 7 | matrix_t = np.concatenate( 8 | [(tt**3)[..., None], (tt**2)[..., None], tt[..., None], np.ones((n, 1))], 9 | axis=1, 10 | ).astype(float) 11 | matrix_w = np.array( 12 | [[-1, 3, -3, 1], [3, -6, 3, 0], [-3, 3, 0, 0], [1, 0, 0, 0]] 13 | ).astype(float) 14 | matrix_p = np.array( 15 | [[p0, p1, p2], [p3, p4, p5], [p6, p7, p8], [p9, p10, p11]] 16 | ).astype(float) 17 | return np.dot(np.dot(matrix_t, matrix_w), matrix_p).reshape(-1) 18 | 19 | 20 | def bezier_fit(xyz, error_threshold=1.0): 21 | n = len(xyz) 22 | t = np.linspace(0, 1, n) 23 | xyz = xyz.reshape(-1) 24 | 25 | popt, _ = curve_fit(bezier_curve, t, xyz) 26 | 27 | # Generate fitted curve 28 | fitted_curve = bezier_curve(t, *popt).reshape(-1, 3) 29 | 30 | # Calculate residuals 31 | residuals = xyz.reshape(-1, 3) - fitted_curve 32 | 33 | # Calculate RMSE 34 | rmse = np.sqrt(np.mean(np.sum(residuals**2, axis=1))) 35 | 36 | if rmse > error_threshold: 37 | return None 38 | else: 39 | return popt 40 | -------------------------------------------------------------------------------- /src/edge_extraction/edge_fitting/line_fit.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def split_into_monotonic_sublists(numbers, max_longsublists=2, min_length=4): 5 | if not numbers: 6 | return [] 7 | 8 | # Initialize list to store continuous and monotonic sublists 9 | monotonic_sublists = [] 10 | current_sublist = [numbers[0]] 11 | 12 | # Create continuous and monotonic sublists from the original list 13 | for i in range(1, len(numbers)): 14 | if numbers[i] == numbers[i - 1] + 1: 15 | current_sublist.append(numbers[i]) 16 | else: 17 | if len(current_sublist) > 1: 18 | monotonic_sublists.append(tuple(current_sublist)) 19 | current_sublist = [numbers[i]] 20 | 21 | # Add the last continuous and monotonic sublist if it contains more than one element 22 | if len(current_sublist) > 1: 23 | monotonic_sublists.append(tuple(current_sublist)) 24 | 25 | # Convert to set and back to list to remove duplicates 26 | monotonic_sublists = list(set(monotonic_sublists)) 27 | # Sort the sublists by length in descending order 28 | monotonic_sublists.sort(key=len, reverse=True) 29 | 30 | # Keep the specified number of longest sublists 31 | max_sublists = min(max_longsublists, len(monotonic_sublists)) 32 | long_sublists = monotonic_sublists[:max_sublists] 33 | short_sublists = monotonic_sublists[max_sublists:] 34 | 35 | # Handle sublists that are too short 36 | sublists_out_curves = [] 37 | for sublist in long_sublists: 38 | if len(sublist) < min_length: 39 | short_sublists.append(sublist) 40 | else: 41 | sublists_out_curves.append(list(sublist)) 42 | 43 | # Split the remaining short sublists into pairs of numbers 44 | sublists_out_lines = [] 45 | for sublist in short_sublists: 46 | for j in range(len(sublist) - 1): 47 | sublists_out_lines.append([sublist[j], sublist[j + 1]]) 48 | 49 | return [list(t) for t in sublists_out_curves], [list(t) for t in sublists_out_lines] 50 | 51 | 52 | def fit_line_ransac_3d( 53 | points_wld, 54 | voxel_size=256, 55 | max_iterations=100, 56 | min_inliers=4, 57 | max_lines=3, 58 | max_curves=2, 59 | keep_short_lines=False, 60 | ransac_with_direction=False, 61 | ): 62 | """ 63 | Fit multiple lines to 3D points using RANSAC. 64 | 65 | Parameters: 66 | - points (numpy.ndarray): Array of 3D points. 67 | - voxel_size (float): Voxel size for inlier distance threshold. 68 | - max_iterations (int): Maximum number of RANSAC iterations. 69 | - min_inliers (int): Minimum number of inliers required to consider a line. 70 | - max_lines (int): Maximum number of lines to fit. 71 | 72 | Returns: 73 | - best_endpoints (list): List of line endpoints (start and end points). 74 | - split_points (list): List of points belonging to each fitted line. 75 | - remaining_points (numpy.ndarray): Points not assigned to any line. 76 | - remaining_point_indices (list): Indices of remaining points in the original input points. 77 | """ 78 | inlier_dist_threshold = 1.0 / voxel_size # 1.0 / voxel_size 79 | best_lines = [] 80 | best_endpoints = [] 81 | split_points = [] 82 | # remaining_point_indices = [] # List to store indices of remaining points 83 | N_points = len(points_wld) 84 | remaining_point_indices = np.arange(N_points) 85 | min_inlier_ratio = 1.0 / max_lines 86 | raw_points_wld = points_wld.copy() 87 | 88 | while max_lines and len(points_wld) >= min_inliers: 89 | max_lines -= 1 90 | best_line = None 91 | best_inliers_mask = None 92 | best_num_inliers = 0 93 | 94 | if not ransac_with_direction: 95 | for _ in range(max_iterations): 96 | # Generate all unique combinations of point pairs 97 | sample_indices = np.random.choice(len(points_wld), 2, replace=False) 98 | sample_points_wld = points_wld[sample_indices, :3] 99 | 100 | p1, p2 = sample_points_wld 101 | direction = p2 - p1 102 | 103 | if np.linalg.norm(direction) < 1e-6: 104 | continue 105 | 106 | direction /= np.linalg.norm(direction) 107 | 108 | distances = np.linalg.norm( 109 | np.cross(points_wld[:, :3] - p1, direction), axis=1 110 | ) 111 | 112 | inliers_mask = distances < inlier_dist_threshold 113 | num_inliers = np.sum(inliers_mask) 114 | 115 | if num_inliers > best_num_inliers: 116 | best_line = (p1, direction) 117 | best_num_inliers = num_inliers 118 | best_inliers_mask = inliers_mask 119 | 120 | else: 121 | points = points_wld[:, :3] 122 | direction = points_wld[:, 3:] 123 | normalized_direction = direction / np.linalg.norm( 124 | direction, axis=1, keepdims=True 125 | ) 126 | 127 | distance = np.linalg.norm( 128 | np.cross(points - points[:, None], normalized_direction), axis=2 129 | ) # N x N 130 | inlier_mask = distance < inlier_dist_threshold 131 | num_inliers = np.sum(inlier_mask, axis=1) 132 | 133 | best_inlier_idx = np.argmax(num_inliers) 134 | best_line = (points[best_inlier_idx], direction[best_inlier_idx]) 135 | best_inliers_mask = inlier_mask[best_inlier_idx] 136 | best_num_inliers = num_inliers[best_inlier_idx] 137 | 138 | if best_num_inliers >= min_inliers: 139 | p1, direction = best_line 140 | inlier_points = points_wld[best_inliers_mask, :3] 141 | inlier_ratio_pred = best_num_inliers / N_points 142 | if inlier_ratio_pred < min_inlier_ratio: 143 | break 144 | 145 | center = np.mean(inlier_points, axis=0) 146 | endpoints_centered = inlier_points - center 147 | u, s, vh = np.linalg.svd(endpoints_centered, full_matrices=False) 148 | updated_direction = vh[0] 149 | updated_direction = updated_direction / np.linalg.norm(updated_direction) 150 | 151 | projections = np.dot(inlier_points - p1, updated_direction) 152 | line_segment = np.zeros(6) 153 | line_segment[:3] = p1 + np.min(projections) * updated_direction 154 | line_segment[3:] = p1 + np.max(projections) * updated_direction 155 | 156 | points_wld = points_wld[~best_inliers_mask] 157 | split_points.append(inlier_points.tolist()) 158 | remaining_point_indices = remaining_point_indices[~best_inliers_mask] 159 | 160 | best_lines.append(best_line) 161 | best_endpoints.append(line_segment) 162 | 163 | # find potential curve points 164 | if len(remaining_point_indices) > 0: 165 | potential_curve_indices, shortline_indices = split_into_monotonic_sublists( 166 | remaining_point_indices.tolist(), max_curves 167 | ) 168 | potential_curve_points = [ 169 | raw_points_wld[potential_curve_indice, :3] 170 | for potential_curve_indice in potential_curve_indices 171 | ] 172 | if keep_short_lines and len(shortline_indices) > 0: 173 | shortline_points = raw_points_wld[shortline_indices, :3] 174 | shortline_points = shortline_points.reshape(-1, 6) 175 | best_endpoints.extend(shortline_points) 176 | split_points.extend(shortline_points.reshape(-1, 2, 3).tolist()) 177 | else: 178 | potential_curve_points = [] 179 | 180 | return best_endpoints, split_points, potential_curve_points 181 | 182 | 183 | def line_fitting(endpoints): 184 | center = np.mean(endpoints, axis=0) 185 | 186 | # compute the main direction through SVD 187 | endpoints_centered = endpoints - center 188 | u, s, vh = np.linalg.svd(endpoints_centered, full_matrices=False) 189 | lamda = s[0] / np.sum(s) 190 | main_direction = vh[0] 191 | main_direction = main_direction / np.linalg.norm(main_direction) 192 | 193 | # project endpoints onto the main direction 194 | projections = [] 195 | for endpoint_centered in endpoints_centered: 196 | projections.append(np.dot(endpoint_centered, main_direction)) 197 | projections = np.array(projections) 198 | 199 | # construct final line 200 | straight_line = np.zeros(6) 201 | # print(np.min(projections), np.max(projections)) 202 | straight_line[:3] = center + main_direction * np.min(projections) 203 | straight_line[3:] = center + main_direction * np.max(projections) 204 | 205 | return straight_line, lamda 206 | 207 | 208 | def lines_fitting(lines, lamda_threshold): 209 | straight_lines = [] 210 | curve_line_segments_candidate = [] 211 | curves = [] 212 | lamda_list = [] 213 | for endpoints in lines: 214 | # merge line segments into a final line segment 215 | # total least squares on endpoints 216 | straight_line, lamda = line_fitting(endpoints) 217 | lamda_list.append(lamda) 218 | if lamda < lamda_threshold: 219 | curves.append(endpoints) 220 | curve_line_segments_candidate.append( 221 | [ 222 | np.hstack([endpoints[i], endpoints[i + 1]]) 223 | for i in range(len(endpoints) - 1) 224 | ] 225 | ) 226 | continue 227 | 228 | straight_lines.append(straight_line) 229 | straight_lines = np.array(straight_lines) 230 | curve_line_segments_candidate = np.array(curve_line_segments_candidate) 231 | return ( 232 | straight_lines, 233 | curve_line_segments_candidate, 234 | curves, 235 | lamda_list, 236 | ) 237 | -------------------------------------------------------------------------------- /src/edge_extraction/edge_fitting/main.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy.spatial.distance import cdist 3 | import open3d as o3d 4 | import random 5 | from src.edge_extraction.edge_fitting.bezier_fit import bezier_fit, bezier_curve 6 | from src.edge_extraction.edge_fitting.line_fit import fit_line_ransac_3d 7 | 8 | 9 | class LineSegment: 10 | def __init__(self, start_point, end_point): 11 | self.start_point = start_point 12 | self.end_point = end_point 13 | 14 | 15 | def generate_segments_from_idx(connected_lines, points_wld): 16 | segments = [] 17 | polylines_wld = [] 18 | for connected_line in connected_lines: 19 | polyline_wld = [points_wld[connected_line[0]].tolist()] 20 | for i in range(len(connected_line) - 1): 21 | segment = [ 22 | points_wld[connected_line[i]].tolist(), 23 | points_wld[connected_line[i + 1]].tolist(), 24 | ] 25 | segments.append(segment) 26 | polyline_wld.append(points_wld[connected_line[i + 1]]) 27 | polyline_wld = np.array(polyline_wld).reshape(-1, 6) 28 | polylines_wld.append(polyline_wld) 29 | 30 | return np.array(segments).reshape(-1, 6), polylines_wld 31 | 32 | 33 | def create_line_segments_from_3d_array(segment_data): 34 | x1, y1, z1, x2, y2, z2 = ( 35 | segment_data[:, 0], 36 | segment_data[:, 1], 37 | segment_data[:, 2], 38 | segment_data[:, 3], 39 | segment_data[:, 4], 40 | segment_data[:, 5], 41 | ) 42 | segments = [] 43 | for i in range(len(x1)): 44 | segments.append( 45 | LineSegment( 46 | np.array([x1[i], y1[i], z1[i]]), np.array([x2[i], y2[i], z2[i]]) 47 | ) 48 | ) 49 | return segments 50 | 51 | 52 | def is_point_inside_ranges(point, ranges): 53 | point = np.array(point) 54 | if not np.all(point > ranges[0]) or not np.all(point < ranges[1]): 55 | return False 56 | return True 57 | 58 | 59 | def is_line_inside_ranges(line_segment, ranges): 60 | if not is_point_inside_ranges(line_segment.start_point, ranges): 61 | return False 62 | if not is_point_inside_ranges(line_segment.end_point, ranges): 63 | return False 64 | return True 65 | 66 | 67 | def create_open3d_line_set( 68 | line_segments, color=[0.5, 0.5, 0.5], width=2, ranges=None, scale=1.0 69 | ): 70 | o3d_points, o3d_lines, o3d_colors = [], [], [] 71 | counter = 0 72 | for line_segment in line_segments: 73 | if ranges is not None: 74 | if not is_line_inside_ranges(line_segment, ranges): 75 | continue 76 | o3d_points.append(line_segment.start_point * scale) 77 | o3d_points.append(line_segment.end_point * scale) 78 | o3d_lines.append([2 * counter, 2 * counter + 1]) 79 | counter += 1 80 | line_set = o3d.geometry.LineSet() 81 | line_set.points = o3d.utility.Vector3dVector(o3d_points) 82 | line_set.lines = o3d.utility.Vector2iVector(o3d_lines) 83 | line_set.colors = o3d.utility.Vector3dVector(o3d_colors) 84 | return line_set 85 | 86 | 87 | def save_3d_lines_to_file(line_segments, filename, width=2, ranges=None, scale=1.0): 88 | lines = create_line_segments_from_3d_array(line_segments) 89 | line_set = create_open3d_line_set(lines, width=width, ranges=ranges, scale=scale) 90 | o3d.io.write_line_set(filename, line_set) 91 | 92 | 93 | def connect_points( 94 | points, distance_threshold, angle_threshold, nms_factor, keep_short_lines 95 | ): 96 | num_points = len(points) 97 | connected_line_segments = [] 98 | 99 | unvisited_points = set(range(num_points)) 100 | 101 | while len(unvisited_points) > 0: 102 | selected_point = np.random.choice(list(unvisited_points)) 103 | selected_point_opposite = selected_point 104 | 105 | unvisited_points.remove(selected_point) 106 | connected_line = [selected_point] 107 | 108 | while True: # forward connection 109 | dist = cdist( 110 | [points[selected_point, :3]], points[list(unvisited_points), :3] 111 | ) 112 | neighboring_points = np.where(dist < distance_threshold)[1] 113 | neighboring_distance = dist[0, neighboring_points].reshape(-1) 114 | 115 | neighboring_points = ( 116 | np.array(list(unvisited_points))[neighboring_points] 117 | ).tolist() 118 | 119 | if len(neighboring_points) == 0: 120 | break 121 | 122 | directions = ( 123 | points[neighboring_points, :3] - points[selected_point, :3][None, ...] 124 | ) 125 | directions /= np.linalg.norm(directions, axis=1)[:, np.newaxis] + 1e-6 126 | 127 | dot_products = np.dot(directions, points[selected_point, 3:]) 128 | 129 | closest_point_idx = np.argmax(dot_products) 130 | 131 | if ( 132 | dot_products[closest_point_idx] <= 1 - angle_threshold 133 | ): # no suitable point found 134 | break 135 | 136 | connected_line.append( 137 | neighboring_points[closest_point_idx] 138 | ) # add the point to the line 139 | 140 | invalid_points_idx = np.where( 141 | (neighboring_distance <= neighboring_distance[closest_point_idx]) 142 | * (dot_products < dot_products[closest_point_idx]) 143 | * (dot_products >= nms_factor * dot_products[closest_point_idx]) 144 | )[0] 145 | invalid_points = np.array(neighboring_points)[invalid_points_idx].tolist() 146 | 147 | unvisited_points.difference_update(invalid_points) 148 | 149 | if ( 150 | np.dot( 151 | points[neighboring_points[closest_point_idx], 3:], 152 | directions[closest_point_idx], 153 | ) 154 | <= 0.5 155 | ): # 156 | break 157 | 158 | unvisited_points.remove( 159 | neighboring_points[closest_point_idx] 160 | ) # remove connected point from unvisited points set 161 | selected_point = neighboring_points[ 162 | closest_point_idx 163 | ] # update anchor point 164 | 165 | while True: # backward connection 166 | dist = cdist( 167 | [points[selected_point_opposite, :3]], 168 | points[list(unvisited_points), :3], 169 | ) 170 | neighboring_points = np.where(dist < distance_threshold)[1] 171 | neighboring_distance = dist[0, neighboring_points].reshape(-1) 172 | 173 | neighboring_points = ( 174 | np.array(list(unvisited_points))[neighboring_points] 175 | ).tolist() 176 | 177 | if len(neighboring_points) == 0: 178 | break 179 | 180 | directions = ( 181 | points[neighboring_points, :3] 182 | - points[selected_point_opposite, :3][None, ...] 183 | ) 184 | directions /= np.linalg.norm(directions, axis=1)[:, np.newaxis] + 1e-6 185 | 186 | dot_products = np.dot(directions, points[selected_point_opposite, 3:]) 187 | 188 | closest_point_idx = np.argmin(dot_products) 189 | 190 | if ( 191 | abs(dot_products[closest_point_idx]) <= 1 - angle_threshold 192 | or dot_products[closest_point_idx] >= 0 193 | ): 194 | break 195 | 196 | connected_line.insert( 197 | 0, neighboring_points[closest_point_idx] 198 | ) # add connected point to the beginning of the line 199 | 200 | invalid_points_idx = np.where( 201 | (neighboring_distance <= neighboring_distance[closest_point_idx]) 202 | * (dot_products > dot_products[closest_point_idx]) 203 | * (dot_products <= nms_factor * dot_products[closest_point_idx]) 204 | )[0] 205 | invalid_points = np.array(neighboring_points)[invalid_points_idx].tolist() 206 | 207 | unvisited_points.difference_update(invalid_points) 208 | 209 | if ( 210 | np.dot( 211 | -points[neighboring_points[closest_point_idx], 3:], 212 | directions[closest_point_idx], 213 | ) 214 | <= 0.5 215 | ): 216 | break 217 | 218 | unvisited_points.remove(neighboring_points[closest_point_idx]) 219 | selected_point_opposite = neighboring_points[closest_point_idx] 220 | 221 | if not keep_short_lines: 222 | if len(connected_line) > 3: 223 | connected_line_segments.append(connected_line) 224 | else: 225 | if len(connected_line) > 1: 226 | connected_line_segments.append(connected_line) 227 | 228 | return connected_line_segments 229 | 230 | 231 | def edge_fitting( 232 | polylines_wld, 233 | voxel_size=256, 234 | max_iterations=100, 235 | min_inliers=4, 236 | max_lines=3, 237 | max_curves=2, 238 | keep_short_lines=True, 239 | ): 240 | straight_lines = [] 241 | raw_points_on_straight_lines = [] 242 | bezier_curve_params = [] 243 | bezier_curve_points = [] 244 | raw_points_on_curves = [] 245 | t_fit = np.linspace(0, 1, 100) 246 | for endpoints_wld in polylines_wld: 247 | if len(endpoints_wld) < 4 and keep_short_lines: # keep short lines 248 | 249 | for i in range(len(endpoints_wld) - 1): 250 | segment = [ 251 | endpoints_wld[i, :3], 252 | endpoints_wld[i + 1, :3], 253 | ] 254 | straight_lines.append(np.array(segment).reshape(-1)) 255 | raw_points_on_straight_lines.extend( 256 | [np.array(segment).reshape(-1, 3).tolist()] 257 | ) 258 | else: 259 | ( 260 | straight_line, 261 | split_points, 262 | potential_curve_points, 263 | ) = fit_line_ransac_3d( 264 | endpoints_wld, 265 | voxel_size, 266 | max_iterations, 267 | min_inliers, 268 | max_lines, 269 | max_curves, 270 | keep_short_lines, 271 | ) 272 | 273 | if len(split_points) >= 1: # fitted straight lines exist 274 | straight_lines.extend(straight_line) 275 | raw_points_on_straight_lines.extend(split_points) 276 | 277 | if len(potential_curve_points) >= 1: # fitted curves exist 278 | for curve_points in potential_curve_points: 279 | p = bezier_fit(curve_points, error_threshold=5.0 / voxel_size) 280 | if p is None: 281 | print("Fitting error too high, not fitting") 282 | continue 283 | bezier_curve_params.append(p) 284 | 285 | xyz_fit = bezier_curve(t_fit, *p).reshape(-1, 3) 286 | bezier_curve_points.append(xyz_fit) 287 | raw_points_on_curves.append(curve_points.tolist()) 288 | 289 | straight_lines = np.array(straight_lines) 290 | 291 | if len(bezier_curve_points) >= 1: 292 | bezier_curve_points = np.concatenate(bezier_curve_points, axis=0) 293 | bezier_curve_params = np.array(bezier_curve_params) 294 | 295 | return ( 296 | straight_lines, 297 | raw_points_on_straight_lines, 298 | bezier_curve_params, 299 | bezier_curve_points, 300 | raw_points_on_curves, 301 | ) 302 | 303 | 304 | def edge_fit( 305 | edge_data=None, 306 | angle_threshold=0.03, 307 | nms_factor=0.9, 308 | fit_distance_threshold=10.0, 309 | min_inliers=4, 310 | max_lines=4, 311 | max_curves=3, 312 | keep_short_lines=True, 313 | ): 314 | res = np.array(edge_data["resolution"]) 315 | raw_points = np.array(edge_data["points"]) 316 | raw_ld_colors = np.array(edge_data["ld_colors"]) 317 | pcd = o3d.geometry.PointCloud() 318 | pcd.points = o3d.utility.Vector3dVector(raw_points) 319 | pcd.colors = o3d.utility.Vector3dVector(raw_ld_colors) 320 | fit_distance_threshold = fit_distance_threshold / res 321 | pcd = pcd.voxel_down_sample(voxel_size=2.0 / res) 322 | 323 | points = np.asarray(pcd.points) 324 | ld_colors = np.asarray(pcd.colors) 325 | linedirection = ld_colors * 2 - 1 # recover the line direction 326 | linedirection = linedirection / ( 327 | np.linalg.norm(linedirection, axis=1)[:, np.newaxis] + 1e-6 328 | ) 329 | points_wld = np.concatenate((points, linedirection), axis=1) 330 | 331 | connected_line = connect_points( 332 | points_wld, 333 | fit_distance_threshold, 334 | angle_threshold, 335 | nms_factor, 336 | keep_short_lines, 337 | ) 338 | 339 | _, polylines_wld = generate_segments_from_idx(connected_line, points_wld) 340 | 341 | ( 342 | straight_lines, 343 | raw_points_on_straight_lines, 344 | bezier_curve_params, 345 | bezier_curve_points, 346 | raw_points_on_curves, 347 | ) = edge_fitting( 348 | polylines_wld, 349 | voxel_size=res, 350 | max_iterations=100, 351 | min_inliers=min_inliers, 352 | max_lines=max_lines, 353 | max_curves=max_curves, 354 | keep_short_lines=keep_short_lines, 355 | ) 356 | 357 | fitted_edge_dict = { 358 | "resolution": int(res), 359 | "lines_end_pts": straight_lines.tolist() if len(straight_lines) > 0 else [], 360 | "raw_points_on_lines": ( 361 | raw_points_on_straight_lines 362 | if len(raw_points_on_straight_lines) > 0 363 | else [] 364 | ), 365 | "curves_ctl_pts": ( 366 | bezier_curve_params.tolist() if len(bezier_curve_params) > 0 else [] 367 | ), 368 | "raw_points_on_curves": ( 369 | raw_points_on_curves if len(raw_points_on_curves) > 0 else [] 370 | ), 371 | } 372 | 373 | return fitted_edge_dict 374 | -------------------------------------------------------------------------------- /src/edge_extraction/extract_parametric_edge.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | from pathlib import Path 5 | import json 6 | import cv2 7 | import copy 8 | import math 9 | from src.edge_extraction.edge_fitting.main import edge_fit 10 | from src.edge_extraction.merging.main import merge 11 | from src.edge_extraction.extract_util import bezier_curve_length 12 | 13 | 14 | def load_from_json(filename: Path): 15 | """Load a dictionary from a JSON filename. 16 | 17 | Args: 18 | filename: The filename to load from. 19 | """ 20 | assert filename.suffix == ".json" 21 | with open(filename, encoding="UTF-8") as file: 22 | return json.load(file) 23 | 24 | 25 | def get_edge_maps(data_dir, detector): 26 | meta = load_from_json(Path(data_dir) / "meta_data.json") 27 | h, w = meta["height"], meta["width"] 28 | edges_list, intrinsics_list, camtoworld_list = [], [], [] 29 | for idx, frame in enumerate(meta["frames"]): 30 | intrinsics = np.array(frame["intrinsics"]) 31 | camtoworld = np.array(frame["camtoworld"])[:4, :4] 32 | if detector == "DexiNed": 33 | edges_list.append( 34 | os.path.join( 35 | data_dir, 36 | "edge_DexiNed", 37 | frame["rgb_path"], 38 | ) 39 | ) 40 | elif detector == "PidiNet": 41 | edges_list.append( 42 | os.path.join( 43 | data_dir, 44 | "edge_PidiNet", 45 | frame["rgb_path"][:-4] + ".png", 46 | ) 47 | ) 48 | else: 49 | raise ValueError(f"Unknown detector: {detector}") 50 | intrinsics_list.append(intrinsics) 51 | camtoworld_list.append(camtoworld) 52 | 53 | edges = [cv2.imread(im_name, 0)[..., None] for im_name in edges_list] 54 | 55 | if detector == "DexiNed": 56 | edges = 1 - np.stack(edges) / 255.0 57 | elif detector == "PidiNet": 58 | edges = np.stack(edges) / 255.0 59 | 60 | intrinsics_list = np.stack(intrinsics_list) 61 | camtoworld_list = np.stack(camtoworld_list) 62 | return edges, intrinsics_list, camtoworld_list, h, w 63 | 64 | 65 | def process_geometry_data( 66 | edge_dict, 67 | worldtogt=None, 68 | valid_curve=None, 69 | valid_line=None, 70 | sample_resolution=0.005, 71 | ): 72 | """ 73 | Processes edge data to transform and sample points from geometric data (curves and lines). 74 | Optionally transforms points to a target coordinate system and filters specific geometries. 75 | 76 | Parameters: 77 | edge_dict (dict): Dictionary containing 'curves_ctl_pts' and 'lines_end_pts'. 78 | worldtogt (np.array, optional): Transformation matrix to convert points to another coordinate system. 79 | valid_curve (np.array, optional): Indices to filter specific curves. 80 | valid_line (np.array, optional): Indices to filter specific lines. 81 | sample_resolution (float): Sampling resolution for generating points. 82 | 83 | Returns: 84 | np.array: Array of sampled points. 85 | int: Number of curve points generated. 86 | """ 87 | # Process curves 88 | return_edge_dict = {} 89 | curve_data = edge_dict["curves_ctl_pts"] 90 | curve_paras = np.array(curve_data).reshape(-1, 12) 91 | if valid_curve is not None: 92 | curve_paras = curve_paras[valid_curve] 93 | curve_paras = curve_paras.reshape(-1, 4, 3) 94 | return_edge_dict["curves_ctl_pts"] = curve_paras.tolist() 95 | 96 | if worldtogt is not None: 97 | curve_paras = curve_paras @ worldtogt[:3, :3].T + worldtogt[:3, 3] 98 | 99 | # Process lines 100 | line_data = edge_dict["lines_end_pts"] 101 | lines = np.array(line_data).reshape(-1, 6) 102 | if valid_line is not None: 103 | lines = lines[valid_line] 104 | 105 | return_edge_dict["lines_end_pts"] = lines.tolist() 106 | 107 | lines = lines.reshape(-1, 2, 3) 108 | 109 | if worldtogt is not None: 110 | lines = lines @ worldtogt[:3, :3].T + worldtogt[:3, 3] 111 | 112 | all_points = [] 113 | 114 | # Sample curves 115 | for curve in curve_paras: 116 | sample_num = int( 117 | bezier_curve_length(curve, num_samples=100) // sample_resolution 118 | ) 119 | t = np.linspace(0, 1, sample_num) 120 | coefficients = np.array( 121 | [[-1, 3, -3, 1], [3, -6, 3, 0], [-3, 3, 0, 0], [1, 0, 0, 0]] 122 | ) 123 | matrix_u = np.array([t**3, t**2, t, np.ones_like(t)]) 124 | points = matrix_u.T.dot(coefficients).dot(curve) 125 | all_points.extend(points.tolist()) 126 | 127 | # Sample lines 128 | for line in lines: 129 | sample_num = int(np.linalg.norm(line[0] - line[1]) // sample_resolution) 130 | t = np.linspace(0, 1, sample_num) 131 | line_points = np.outer(t, line[1] - line[0]) + line[0] 132 | all_points.extend(line_points.tolist()) 133 | 134 | return np.array(all_points, dtype=np.float32), return_edge_dict 135 | 136 | 137 | def compute_visibility( 138 | all_curve_points, 139 | all_line_points, 140 | edges, 141 | intrinsics_list, 142 | camtoworld_list, 143 | h, 144 | w, 145 | edge_visibility_threshold, 146 | edge_visibility_frames, 147 | ): 148 | img_frames = len(edges) 149 | curve_num, line_num = len(all_curve_points), len(all_line_points) 150 | edge_num = curve_num + line_num 151 | edge_visibility_matrix = np.zeros((edge_num, img_frames)) 152 | 153 | for frame_idx, (edge_map, intrinsic, camtoworld) in enumerate( 154 | zip(edges, intrinsics_list, camtoworld_list) 155 | ): 156 | K = intrinsic[:3, :3] 157 | worldtocam = np.linalg.inv(camtoworld) 158 | R = worldtocam[:3, :3] 159 | T = worldtocam[:3, 3:] 160 | 161 | all_curve_uv, all_line_uv = project2D( 162 | K, R, T, copy.deepcopy(all_curve_points), copy.deepcopy(all_line_points) 163 | ) 164 | all_edge_uv = all_curve_uv + all_line_uv 165 | 166 | for edge_idx, edge_uv in enumerate(all_edge_uv): 167 | edge_uv = np.array(edge_uv) 168 | # print(edge_uv.shape) 169 | if len(edge_uv) == 0: 170 | continue 171 | edge_uv = np.round(edge_uv).astype(np.int32) 172 | edge_u = edge_uv[:, 0] 173 | edge_v = edge_uv[:, 1] 174 | 175 | valid_edge_uv = edge_uv[ 176 | (edge_u >= 0) & (edge_u < w) & (edge_v >= 0) & (edge_v < h) 177 | ] 178 | visibility = 0 179 | 180 | if len(valid_edge_uv) > 0: 181 | projected_edge = edge_map[valid_edge_uv[:, 1], valid_edge_uv[:, 0]] 182 | visibility = float( 183 | np.mean(projected_edge) > edge_visibility_threshold 184 | and np.max(projected_edge) > 0.5 185 | ) 186 | edge_visibility_matrix[edge_idx, frame_idx] = visibility 187 | 188 | return np.sum(edge_visibility_matrix, axis=1) > edge_visibility_frames 189 | 190 | 191 | def project2D(K, R, T, all_curve_points, all_line_points): 192 | all_curve_uv, all_line_uv = [], [] 193 | for curve_points in all_curve_points: 194 | curve_points = np.array(curve_points).reshape(-1, 3) 195 | curve_uv = project2D_single(K, R, T, curve_points) 196 | all_curve_uv.append(curve_uv) 197 | for line_points in all_line_points: 198 | line_points = np.array(line_points).reshape(-1, 3) 199 | line_uv = project2D_single(K, R, T, line_points) 200 | all_line_uv.append(line_uv) 201 | return all_curve_uv, all_line_uv 202 | 203 | 204 | def project2D_single(K, R, T, points3d): 205 | shape = points3d.shape 206 | assert shape[-1] == 3 207 | X = points3d.reshape(-1, 3) 208 | 209 | x = K @ (R @ X.T + T) 210 | x = x.T 211 | x = x / x[:, -1:] 212 | x = x.reshape(*shape)[..., :2].reshape(-1, 2).tolist() 213 | return x 214 | 215 | 216 | def get_parametric_edge( 217 | edge_dict, 218 | visible_checking=False, 219 | ): 220 | 221 | detector = edge_dict["detector"] 222 | scene_name = edge_dict["scene_name"] 223 | dataset_dir = edge_dict["dataset_dir"] 224 | result_dir = edge_dict["result_dir"] 225 | meta_data_dir = os.path.join(dataset_dir, scene_name) 226 | 227 | # fixed parameters, but can be fine-tuned for better edge extraction 228 | is_merge = True 229 | nms_factor = 0.95 230 | angle_threshold = 0.03 231 | fit_distance_threshold = 10.0 232 | min_inliers = 5 233 | max_lines = 4 234 | max_curves = 3 235 | merge_edge_distance_threshold = 5.0 236 | merge_endpoints_distance_threshold = 2.0 237 | merge_similarity_threshold = 0.98 238 | 239 | fitted_edge_dict = edge_fit( 240 | edge_data=edge_dict, 241 | angle_threshold=angle_threshold, 242 | nms_factor=nms_factor, 243 | fit_distance_threshold=fit_distance_threshold, 244 | min_inliers=min_inliers, 245 | max_lines=max_lines, 246 | max_curves=max_curves, 247 | ) 248 | if is_merge: 249 | merged_edge_dict = merge( 250 | result_dir, 251 | fitted_edge_dict, 252 | merge_edge_distance_threshold=merge_edge_distance_threshold, 253 | merge_endpoints_distance_threshold=merge_endpoints_distance_threshold, 254 | merge_similarity_threshold=merge_similarity_threshold, 255 | ) 256 | 257 | if visible_checking: 258 | _, return_edge_dict = process_geometry_data(merged_edge_dict) 259 | all_curve_points = return_edge_dict["curves_ctl_pts"] 260 | all_line_points = return_edge_dict["lines_end_pts"] 261 | edges, intrinsics_list, camtoworld_list, h, w = get_edge_maps( 262 | meta_data_dir, detector 263 | ) 264 | edge_visibility_threshold = 0.5 265 | edge_visibility_frames_ratio = 0.1 266 | num_frames = len(edges) 267 | edge_visibility_frames = math.ceil(edge_visibility_frames_ratio * num_frames) 268 | 269 | edge_visibility = compute_visibility( 270 | all_curve_points, 271 | all_line_points, 272 | edges, 273 | intrinsics_list, 274 | camtoworld_list, 275 | h, 276 | w, 277 | edge_visibility_threshold, 278 | edge_visibility_frames, 279 | ) 280 | curve_visibility = edge_visibility[: len(all_curve_points)] 281 | line_visibility = edge_visibility[len(all_curve_points) :] 282 | 283 | print( 284 | "before visible checking: ", 285 | len(all_curve_points) + len(all_line_points), 286 | "after visible checking: ", 287 | np.sum(edge_visibility), 288 | ) 289 | worldtogt = np.eye(4) 290 | (pred_points, return_edge_dict) = process_geometry_data( 291 | merged_edge_dict, worldtogt, curve_visibility, line_visibility 292 | ) 293 | 294 | else: 295 | worldtogt = np.eye(4) 296 | (pred_points, return_edge_dict) = process_geometry_data( 297 | merged_edge_dict, worldtogt, None, None 298 | ) 299 | 300 | return pred_points, return_edge_dict 301 | -------------------------------------------------------------------------------- /src/edge_extraction/extract_pointcloud.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.nn import functional as F 3 | 4 | 5 | def get_udf_normals_grid( 6 | func, 7 | func_grad, 8 | N, 9 | udf_threshold, 10 | is_linedirection=False, 11 | sampling_N=50, 12 | sampling_delta=0.005, 13 | max_batch=int(2**12), 14 | device="cuda", 15 | ): 16 | """ 17 | Efficiently fills a dense N*N*N regular grid by querying the function and its gradient for distance field values 18 | and optionally computing line directions. Adjusts voxel grid based on specified origin and max values. 19 | 20 | Parameters: 21 | - func: Callable for evaluating distance field values. 22 | - func_grad: Callable for evaluating gradients. 23 | - N: Size of the grid in each dimension. 24 | - udf_threshold: Threshold below which gradients are computed. 25 | - is_linedirection: Flag indicating whether to compute line directions. 26 | - sampling_N: Number of samples for line direction computation. 27 | - sampling_delta: Offset range for sampling around points for line direction. 28 | - max_batch: Max number of points processed in a single batch. 29 | 30 | Returns: 31 | Tuple of tensors (df_values, line_directions, gradients, samples, voxel_size) representing the computed 32 | distance field values, line directions, gradients at points below threshold, raw sample points, and the size 33 | of each voxel. 34 | """ 35 | 36 | overall_index = torch.arange(0, N**3, 1, out=torch.LongTensor()) 37 | samples = torch.zeros(N**3, 12, device=device) 38 | # transform first 3 columns 39 | # to be the x, y, z index 40 | samples[:, 2] = overall_index % N 41 | samples[:, 1] = torch.div(overall_index, N, rounding_mode="floor") % N 42 | samples[:, 0] = ( 43 | torch.div( 44 | torch.div(overall_index, N, rounding_mode="floor"), N, rounding_mode="floor" 45 | ) 46 | % N 47 | ) 48 | # Ensure voxel_origin and voxel_max are correctly set 49 | voxel_origin = [-1, -1, -1] 50 | voxel_size = 2.0 / (N - 1) 51 | samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2] 52 | samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1] 53 | samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0] 54 | 55 | # Query function for distance field values 56 | num_samples = N**3 57 | samples.requires_grad = False 58 | for head in range(0, num_samples, max_batch): 59 | tail = min(head + max_batch, num_samples) 60 | sample_subset = samples[head:tail, :3].clone().to(device) 61 | df, _, _ = func(sample_subset) 62 | samples[head:tail, 3:4] = df.detach() 63 | 64 | # Compute gradients where distance field value is below threshold 65 | norm_mask = samples[:, 3] < udf_threshold 66 | norm_idx = torch.where(norm_mask)[0] 67 | for head in range(0, len(norm_idx), max_batch): 68 | tail = min(head + max_batch, len(norm_idx)) 69 | idx_subset = norm_idx[head:tail] 70 | sample_subset = samples[idx_subset, :3].clone().to(device).requires_grad_(True) 71 | grad = func_grad(sample_subset).detach() 72 | samples[idx_subset, 4:7] = -F.normalize(grad, dim=1)[:, 0] 73 | 74 | # Compute line directions if requested 75 | if is_linedirection: 76 | sample_subset_ld = sample_subset.unsqueeze( 77 | 1 78 | ) + sampling_delta * torch.randn( 79 | (sample_subset.shape[0], sampling_N, 3), device=device 80 | ) 81 | grad_ld = ( 82 | func_grad(sample_subset_ld.reshape(-1, 3)) 83 | .detach() 84 | .reshape(sample_subset.shape[0], sampling_N, 3) 85 | ) 86 | _, _, vh = torch.linalg.svd(grad_ld) 87 | null_space = vh[:, -1, :].view(-1, 3) 88 | samples[idx_subset, 8:11] = F.normalize(null_space, dim=1) 89 | 90 | # Reshape output tensors 91 | df_values = samples[:, 3].reshape(N, N, N) 92 | vecs = samples[:, 4:7].reshape(N, N, N, 3) 93 | ld = samples[:, 8:11].reshape(N, N, N, 3) 94 | 95 | return df_values, ld, vecs, samples, torch.tensor(voxel_size) 96 | 97 | 98 | def get_udf_normals_slow( 99 | func, 100 | func_grad, 101 | voxel_size, 102 | xyz, 103 | is_linedirection, 104 | # N_ld=20, 105 | sampling_N=50, 106 | sampling_delta=0.005, # 0.005 107 | max_batch=int(2**12), 108 | device="cuda", 109 | ): 110 | """ 111 | Computes distance field values, normals, and optionally line directions for a set of points. 112 | 113 | Parameters: 114 | - func: Function to evaluate the distance field. 115 | - func_grad: Function to evaluate the gradient of the distance field. 116 | - voxel_size: Size of the voxel (not used in this function but kept for compatibility). 117 | - xyz: (N,3) tensor representing coordinates to evaluate. 118 | - is_linedirection: Boolean indicating whether to compute line directions. 119 | - sampling_N: Number of samples for computing line direction. 120 | - sampling_delta: Delta range for sampling around points for line direction. 121 | - max_batch: Maximum number of points to process in a single batch. 122 | 123 | Returns: 124 | - df_values: (N,) tensor of distance field values at xyz locations. 125 | - normals: (N,3) tensor of gradient values at xyz locations. 126 | - ld: (N,3) tensor of line direction values at xyz locations, if computed. 127 | - samples: (N, 10) tensor of x, y, z, distance field, grad_x, grad_y, grad_z, and optionally line directions. 128 | """ 129 | # network.eval() 130 | ################ 131 | # transform first 3 columns 132 | # to be the x, y, z coordinate 133 | 134 | num_samples = xyz.shape[0] 135 | # xyz = torch.from_numpy(xyz).float().cuda() 136 | samples = torch.cat([xyz, torch.zeros(num_samples, 10).float().cuda()], dim=-1) 137 | samples.requires_grad = False 138 | # samples.pin_memory() 139 | ################ 140 | # 2: Run forward pass to fill the grid 141 | ################ 142 | head = 0 143 | ## FIRST: fill distance field grid without gradients 144 | while head < num_samples: 145 | # xyz coords 146 | sample_subset = ( 147 | samples[head : min(head + max_batch, num_samples), 0:3].clone().cuda() 148 | ) 149 | # Create input 150 | xyz = sample_subset 151 | 152 | input = xyz.reshape(-1, xyz.shape[-1]) 153 | # Run forward pass 154 | with torch.no_grad(): 155 | df, _, PE = func(input) 156 | # Store df 157 | samples[head : min(head + max_batch, num_samples), 3] = ( 158 | df.squeeze(-1).detach().cpu() 159 | ) 160 | grad = func_grad(input).detach()[:, 0] 161 | normals = -F.normalize(grad, dim=1) 162 | samples[head : min(head + max_batch, num_samples), 4:7] = normals.cpu() 163 | 164 | if is_linedirection: 165 | input_ld = input.unsqueeze( 166 | 1 167 | ) + sampling_delta * torch.randn( # need to be fixed grid 168 | (input.shape[0], sampling_N, 3), device=device 169 | ) 170 | # input_ld = input.unsqueeze(1) + offset 171 | input_ld = input_ld.reshape(-1, input.shape[-1]) 172 | grad_ld = ( 173 | func_grad(input_ld.float()) 174 | .detach()[:, 0] 175 | .reshape(input.shape[0], -1, 3) 176 | ) 177 | _, _, vh = torch.linalg.svd(grad_ld) 178 | 179 | # Extract the null space (non-zero solutions) for each matrix in the batch 180 | # null_space = vh[:, -1, :][vh[:, -1, :] != 0].view(-1, 3) 181 | null_space = vh[:, -1, :].view(-1, 3) 182 | samples[head : min(head + max_batch, num_samples), 7:10] = F.normalize( 183 | null_space, dim=1 184 | ) 185 | # Next iter 186 | head += max_batch 187 | 188 | # Separate values in DF / gradients 189 | df_values = samples[:, 3] 190 | normals = samples[:, 4:7] 191 | ld = samples[:, 7:10] 192 | 193 | return df_values, normals, ld, samples 194 | 195 | 196 | import numpy as np 197 | 198 | 199 | def project_vector_onto_plane(A, B): 200 | # Calculate the dot product of A and B 201 | dot_product = torch.sum(A * B, dim=-1) 202 | 203 | # Calculate the projection of A onto B 204 | projection = dot_product.unsqueeze(-1) * B 205 | 206 | # Calculate the projected vector onto the plane perpendicular to B 207 | projected_vector = A - projection 208 | 209 | return projected_vector 210 | 211 | 212 | def get_pointcloud_from_udf( 213 | func, 214 | func_grad, 215 | N_MC=128, 216 | udf_threshold=1.0, 217 | sampling_N=50, 218 | sampling_delta=5e-3, 219 | is_pointshift=False, 220 | iters=1, 221 | is_linedirection=False, 222 | device="cuda", 223 | ): 224 | """ 225 | Computes a point cloud from a distance field network conditioned on the latent vector. 226 | Inputs: 227 | func: Function to evaluate the distance field. 228 | func_grad: Function to evaluate the gradient of the distance field. 229 | N_MC: Size of the grid. 230 | udf_threshold: Threshold to filter surfaces with large UDF values. 231 | is_pointshift: Flag indicating if points should be shifted by df * normals. 232 | iters: Number of iterations for the point shift. 233 | is_linedirection: Flag indicating if line direction computation is needed. 234 | Returns: 235 | pointcloud: (N**3, 3) tensor representing the edge point cloud. 236 | samples: (N**3, 7) tensor representing (x,y,z, distance field, grad_x, grad_y, grad_z). 237 | indices: Indices of coordinates that need updating in the next iteration. 238 | """ 239 | # Compute UDF normals and grid 240 | df_values, lds, normals, samples, voxel_size = get_udf_normals_grid( 241 | func=func, 242 | func_grad=func_grad, 243 | N=N_MC, 244 | udf_threshold=udf_threshold, 245 | is_linedirection=is_linedirection, 246 | sampling_N=sampling_N, 247 | sampling_delta=sampling_delta, 248 | device=device, 249 | ) 250 | 251 | # Reshape tensors for processing 252 | df_values, lds, normals, samples = ( 253 | df_values.reshape(-1), 254 | lds.reshape(-1, 3), 255 | normals.reshape(-1, 3), 256 | samples.reshape(-1, 12), 257 | ) 258 | xyz = samples[:, 0:3] 259 | df_values.clamp_(min=0) # Ensure distance field values are non-negative 260 | 261 | # Filter out points too far from the surface 262 | points_idx = df_values <= udf_threshold 263 | filtered_xyz, filtered_lds, normals, df_values = ( 264 | xyz[points_idx], 265 | lds[points_idx], 266 | normals[points_idx], 267 | df_values[points_idx], 268 | ) 269 | 270 | # Point shifting 271 | if is_pointshift and iters > 0: 272 | for iter in range(iters): 273 | shifted_xyz = filtered_xyz + df_values.unsqueeze(-1) * normals 274 | shifted_df_values, shifted_normals, filtered_lds, _ = get_udf_normals_slow( 275 | func=func, 276 | func_grad=func_grad, 277 | voxel_size=voxel_size, 278 | xyz=shifted_xyz, 279 | is_linedirection=True if iter == iters - 1 else False, 280 | device=device, 281 | ) 282 | shifted_points_idx = shifted_df_values <= udf_threshold 283 | filtered_xyz, df_values, normals, filtered_lds = ( 284 | shifted_xyz[shifted_points_idx], 285 | shifted_df_values[shifted_points_idx], 286 | shifted_normals[shifted_points_idx], 287 | filtered_lds[shifted_points_idx], 288 | ) 289 | 290 | return ( 291 | filtered_xyz.cpu().numpy(), 292 | filtered_lds.cpu().numpy() if filtered_lds is not None else None, 293 | ) 294 | -------------------------------------------------------------------------------- /src/edge_extraction/extract_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import point_cloud_utils as pcu 5 | import math 6 | import open3d as o3d 7 | import trimesh 8 | 9 | 10 | def downsample_point_cloud_average( 11 | points, num_voxels_per_axis=256, min_bound=None, max_bound=None 12 | ): 13 | """ 14 | Downsample a point set based on the number of voxels per axis by averaging the points within each voxel. 15 | 16 | Args: 17 | points: a [#v, 3]-shaped array of 3d points. 18 | num_voxels_per_axis: a scalar or 3-tuple specifying the number of voxels along each axis. 19 | 20 | Returns: 21 | A [#v', 3]-shaped numpy array of downsampled points, where #v' is the number of occupied voxels. 22 | """ 23 | 24 | # Calculate the bounding box of the point cloud 25 | if min_bound is None: 26 | min_bound = np.min(points, axis=0) 27 | else: 28 | min_bound = np.array(min_bound) 29 | 30 | if max_bound is None: 31 | max_bound = np.max(points, axis=0) 32 | else: 33 | max_bound = np.array(max_bound) 34 | 35 | # Determine the size of the voxel based on the desired number of voxels per axis 36 | if isinstance(num_voxels_per_axis, int): 37 | voxel_size = (max_bound - min_bound) / num_voxels_per_axis 38 | else: 39 | voxel_size = [ 40 | (max_bound[i] - min_bound[i]) / num_voxels_per_axis[i] for i in range(3) 41 | ] 42 | 43 | # Use the existing function to downsample the point cloud based on voxel size 44 | downsampled_points = pcu.downsample_point_cloud_on_voxel_grid( 45 | voxel_size, points, min_bound=min_bound, max_bound=max_bound 46 | ) 47 | 48 | return downsampled_points 49 | 50 | 51 | def get_gt_points( 52 | name, 53 | edge_type="all", 54 | interval=0.005, 55 | return_CAD=False, 56 | return_direction=False, 57 | base_dir=None, 58 | ): 59 | """ 60 | Get ground truth points from a dataset. 61 | 62 | Args: 63 | name (str): Name of the dataset. 64 | 65 | Returns: 66 | numpy.ndarray: Raw and processed ground truth points. 67 | """ 68 | objs_dir = os.path.join(base_dir, "obj") 69 | obj_names = os.listdir(objs_dir) 70 | obj_names.sort() 71 | index_obj_names = {} 72 | for obj_name in obj_names: 73 | index_obj_names[obj_name[:8]] = obj_name 74 | 75 | json_feats_path = os.path.join(base_dir, "chunk_0000_feats.json") 76 | with open(json_feats_path, "r") as f: 77 | json_data_feats = json.load(f) 78 | json_stats_path = os.path.join(base_dir, "chunk_0000_stats.json") 79 | with open(json_stats_path, "r") as f: 80 | json_data_stats = json.load(f) 81 | 82 | # get the normalize scale to help align the nerf points and gt points 83 | [ 84 | x_min, 85 | y_min, 86 | z_min, 87 | x_max, 88 | y_max, 89 | z_max, 90 | x_range, 91 | y_range, 92 | z_range, 93 | ] = json_data_stats[name]["bbox"] 94 | scale = 1 / max(x_range, y_range, z_range) 95 | poi_center = ( 96 | np.array([((x_min + x_max) / 2), ((y_min + y_max) / 2), ((z_min + z_max) / 2)]) 97 | * scale 98 | ) 99 | set_location = [0.5, 0.5, 0.5] - poi_center # based on the rendering settings 100 | 101 | obj_path = os.path.join(objs_dir, index_obj_names[name]) 102 | if return_CAD: 103 | cad_obj = trimesh.load_mesh(obj_path) 104 | else: 105 | cad_obj = None 106 | 107 | with open(obj_path, encoding="utf-8") as file: 108 | data = file.readlines() 109 | 110 | vertices_obj = [each.split(" ") for each in data if each.split(" ")[0] == "v"] 111 | vertices_xyz = [ 112 | [float(v[1]), float(v[2]), float(v[3].replace("\n", ""))] for v in vertices_obj 113 | ] 114 | 115 | edge_pts = [] 116 | edge_pts_raw = [] 117 | edge_pts_direction = [] 118 | rename = { 119 | "BSpline": "curve", 120 | "Circle": "curve", 121 | "Ellipse": "curve", 122 | "Line": "line", 123 | } 124 | for each_curve in json_data_feats[name]: 125 | if edge_type != "all" and rename[each_curve["type"]] != edge_type: 126 | continue 127 | 128 | if each_curve["sharp"]: # each_curve["type"]: BSpline, Line, Circle 129 | each_edge_pts = [vertices_xyz[i] for i in each_curve["vert_indices"]] 130 | edge_pts_raw += each_edge_pts 131 | 132 | gt_sampling = [] 133 | each_edge_pts = np.array(each_edge_pts) 134 | for index in range(len(each_edge_pts) - 1): 135 | next = each_edge_pts[index + 1] 136 | current = each_edge_pts[index] 137 | num = int(np.linalg.norm(next - current) // interval) 138 | linspace = np.linspace(0, 1, num) 139 | gt_sampling.append( 140 | linspace[:, None] * current + (1 - linspace)[:, None] * next 141 | ) 142 | 143 | if return_direction: 144 | direction = (next - current) / np.linalg.norm(next - current) 145 | edge_pts_direction.extend([direction] * num) 146 | each_edge_pts = np.concatenate(gt_sampling).tolist() 147 | edge_pts += each_edge_pts 148 | 149 | if len(edge_pts_raw) == 0: 150 | return None, None, None, None 151 | 152 | edge_pts_raw = np.array(edge_pts_raw) * scale + set_location 153 | edge_pts = np.array(edge_pts) * scale + set_location 154 | edge_pts_direction = np.array(edge_pts_direction) 155 | 156 | return ( 157 | edge_pts_raw.astype(np.float32), 158 | edge_pts.astype(np.float32), 159 | cad_obj, 160 | edge_pts_direction, 161 | ) 162 | 163 | 164 | def chamfer_distance(x, y, return_index=False, p_norm=2, max_points_per_leaf=10): 165 | """ 166 | Compute the chamfer distance between two point clouds x, and y 167 | 168 | Args: 169 | x : A m-sized minibatch of point sets in R^d. i.e. shape [m, n_a, d] 170 | y : A m-sized minibatch of point sets in R^d. i.e. shape [m, n_b, d] 171 | return_index: If set to True, will return a pair (corrs_x_to_y, corrs_y_to_x) where 172 | corrs_x_to_y[i] stores the index into y of the closest point to x[i] 173 | (i.e. y[corrs_x_to_y[i]] is the nearest neighbor to x[i] in y). 174 | corrs_y_to_x is similar to corrs_x_to_y but with x and y reversed. 175 | max_points_per_leaf : The maximum number of points per leaf node in the KD tree used by this function. 176 | Default is 10. 177 | p_norm : Which norm to use. p_norm can be any real number, inf (for the max norm) -inf (for the min norm), 178 | 0 (for sum(x != 0)) 179 | Returns: 180 | The chamfer distance between x an dy. 181 | If return_index is set, then this function returns a tuple (chamfer_dist, corrs_x_to_y, corrs_y_to_x) where 182 | corrs_x_to_y and corrs_y_to_x are described above. 183 | """ 184 | 185 | dists_x_to_y, corrs_x_to_y = pcu.k_nearest_neighbors( 186 | x, y, k=1, squared_distances=False, max_points_per_leaf=max_points_per_leaf 187 | ) 188 | dists_y_to_x, corrs_y_to_x = pcu.k_nearest_neighbors( 189 | y, x, k=1, squared_distances=False, max_points_per_leaf=max_points_per_leaf 190 | ) 191 | 192 | dists_x_to_y = np.linalg.norm(x[corrs_y_to_x] - y, axis=-1, ord=p_norm).mean() 193 | dists_y_to_x = np.linalg.norm(y[corrs_x_to_y] - x, axis=-1, ord=p_norm).mean() 194 | 195 | Comp = np.mean(dists_x_to_y) 196 | Acc = np.mean(dists_y_to_x) 197 | cham_dist = Comp + Acc 198 | 199 | if return_index: 200 | return cham_dist, corrs_x_to_y, corrs_y_to_x 201 | 202 | return cham_dist, Acc, Comp 203 | 204 | 205 | def compute_chamfer_distance(pred_sampled, gt_points): 206 | """ 207 | Compute chamfer distance between predicted and ground truth points. 208 | 209 | Args: 210 | pred_sampled (numpy.ndarray): Predicted point cloud. 211 | gt_points (numpy.ndarray): Ground truth points. 212 | 213 | Returns: 214 | float: Chamfer distance. 215 | """ 216 | chamfer_dist, acc, comp = chamfer_distance(pred_sampled, gt_points) 217 | return chamfer_dist, acc, comp 218 | 219 | 220 | def compute_precision_recall_IOU( 221 | pred_sampled, gt_points, metrics, thresh_list=[0.02], edge_type="all" 222 | ): 223 | """ 224 | Compute precision, recall, F-score, and IOU. 225 | 226 | Args: 227 | pred_sampled (numpy.ndarray): Predicted point cloud. 228 | gt_points (numpy.ndarray): Ground truth points. 229 | metrics (dict): Dictionary to store metrics. 230 | 231 | Returns: 232 | dict: Updated metrics. 233 | """ 234 | if edge_type == "all": 235 | for thresh in thresh_list: 236 | dists_a_to_b, _ = pcu.k_nearest_neighbors( 237 | pred_sampled, gt_points, k=1 238 | ) # k closest points (in pts_b) for each point in pts_a 239 | correct_pred = np.sum(dists_a_to_b < thresh) 240 | precision = correct_pred / len(dists_a_to_b) 241 | metrics[f"precision_{thresh}"].append(precision) 242 | 243 | dists_b_to_a, _ = pcu.k_nearest_neighbors(gt_points, pred_sampled, k=1) 244 | correct_gt = np.sum(dists_b_to_a < thresh) 245 | recall = correct_gt / len(dists_b_to_a) 246 | metrics[f"recall_{thresh}"].append(recall) 247 | 248 | fscore = 2 * precision * recall / (precision + recall) 249 | metrics[f"fscore_{thresh}"].append(fscore) 250 | 251 | intersection = min(correct_pred, correct_gt) 252 | union = ( 253 | len(dists_a_to_b) + len(dists_b_to_a) - max(correct_pred, correct_gt) 254 | ) 255 | 256 | IOU = intersection / union 257 | metrics[f"IOU_{thresh}"].append(IOU) 258 | return metrics 259 | else: 260 | correct_gt_list = [] 261 | correct_pred_list = [] 262 | _, acc, comp = compute_chamfer_distance(pred_sampled, gt_points) 263 | for thresh in thresh_list: 264 | dists_b_to_a, _ = pcu.k_nearest_neighbors(gt_points, pred_sampled, k=1) 265 | correct_gt = np.sum(dists_b_to_a < thresh) 266 | num_gt = len(dists_b_to_a) 267 | correct_gt_list.append(correct_gt) 268 | 269 | dists_a_to_b, _ = pcu.k_nearest_neighbors(pred_sampled, gt_points, k=1) 270 | correct_pred = np.sum(dists_a_to_b < thresh) 271 | correct_pred_list.append(correct_pred) 272 | num_pred = len(dists_a_to_b) 273 | 274 | return correct_gt_list, num_gt, correct_pred_list, num_pred, acc, comp 275 | 276 | 277 | def f_score(precision, recall): 278 | """ 279 | Compute F-score. 280 | 281 | Args: 282 | precision (float): Precision. 283 | recall (float): Recall. 284 | 285 | Returns: 286 | float: F-score. 287 | """ 288 | return 2 * precision * recall / (precision + recall) 289 | 290 | 291 | def bezier_curve_length(control_points, num_samples): 292 | def binomial_coefficient(n, i): 293 | return math.factorial(n) // (math.factorial(i) * math.factorial(n - i)) 294 | 295 | def derivative_bezier(t): 296 | n = len(control_points) - 1 297 | point = np.array([0.0, 0.0, 0.0]) 298 | for i, (p1, p2) in enumerate(zip(control_points[:-1], control_points[1:])): 299 | point += ( 300 | n 301 | * binomial_coefficient(n - 1, i) 302 | * (1 - t) ** (n - 1 - i) 303 | * t**i 304 | * (np.array(p2) - np.array(p1)) 305 | ) 306 | return point 307 | 308 | def simpson_integral(a, b, num_samples): 309 | h = (b - a) / num_samples 310 | sum1 = sum( 311 | np.linalg.norm(derivative_bezier(a + i * h)) 312 | for i in range(1, num_samples, 2) 313 | ) 314 | sum2 = sum( 315 | np.linalg.norm(derivative_bezier(a + i * h)) 316 | for i in range(2, num_samples - 1, 2) 317 | ) 318 | return ( 319 | ( 320 | np.linalg.norm(derivative_bezier(a)) 321 | + 4 * sum1 322 | + 2 * sum2 323 | + np.linalg.norm(derivative_bezier(b)) 324 | ) 325 | * h 326 | / 3 327 | ) 328 | 329 | # Compute the length of the 3D Bezier curve using composite Simpson's rule 330 | length = 0.0 331 | for i in range(num_samples): 332 | t0 = i / num_samples 333 | t1 = (i + 1) / num_samples 334 | length += simpson_integral(t0, t1, num_samples) 335 | 336 | return length 337 | -------------------------------------------------------------------------------- /src/edge_extraction/merging/main.py: -------------------------------------------------------------------------------- 1 | from src.edge_extraction.edge_fitting.main import save_3d_lines_to_file 2 | from src.edge_extraction.edge_fitting.line_fit import line_fitting 3 | from src.edge_extraction.edge_fitting.bezier_fit import ( 4 | bezier_fit, 5 | bezier_curve, 6 | ) 7 | import numpy as np 8 | from sklearn.metrics.pairwise import cosine_similarity 9 | from sklearn.metrics import silhouette_score 10 | from scipy.sparse.csgraph import connected_components 11 | import open3d as o3d 12 | from scipy.spatial.distance import euclidean, cdist 13 | import os 14 | 15 | 16 | def line_segment_point_distance(line_segment, query_point): 17 | """Compute the Euclidean distance between a line segment and a query point. 18 | 19 | Parameters: 20 | line_segment (np.ndarray): An array of shape (6,), representing two 3D endpoints. 21 | query_point (np.ndarray): An array of shape (3,), representing the 3D query point. 22 | 23 | Returns: 24 | float: The minimum distance from the query point to the line segment. 25 | """ 26 | point1, point2 = line_segment[:3], line_segment[3:] 27 | point_delta = point2 - point1 28 | u = np.clip( 29 | np.dot(query_point - point1, point_delta) / np.dot(point_delta, point_delta), 30 | 0, 31 | 1, 32 | ) 33 | closest_point = point1 + u * point_delta 34 | return np.linalg.norm(closest_point - query_point) 35 | 36 | 37 | def compute_pairwise_distances(line_segments): 38 | """Compute pairwise distances between line segments. 39 | 40 | Parameters: 41 | line_segments (np.ndarray): An array of shape (N, 6), each row represents a line segment in 3D. 42 | 43 | Returns: 44 | np.ndarray: A symmetric array of shape (N, N), containing pairwise distances. 45 | """ 46 | num_lines = len(line_segments) 47 | endpoints = line_segments.reshape(-1, 3) 48 | dist_matrix = np.zeros((num_lines, num_lines)) 49 | 50 | for i, line_segment in enumerate(line_segments): 51 | for j in range(i + 1, num_lines): 52 | min_distance = min( 53 | line_segment_point_distance(line_segment, endpoints[2 * j]), 54 | line_segment_point_distance(line_segment, endpoints[2 * j + 1]), 55 | ) 56 | dist_matrix[i, j] = min_distance 57 | 58 | dist_matrix += dist_matrix.T # Make the matrix symmetric 59 | return dist_matrix 60 | 61 | 62 | def compute_pairwise_cosine_similarity(line_segments): 63 | direction_vectors = line_segments[:, 3:] - line_segments[:, :3] 64 | pairwise_similarity = cosine_similarity(direction_vectors) 65 | return pairwise_similarity 66 | 67 | 68 | def bezier_curve_distance(points1, points2): 69 | distances = np.linalg.norm(points1[:, np.newaxis] - points2, axis=2) 70 | min_distance = np.min(distances) 71 | return min_distance 72 | 73 | 74 | def bezier_slope_vector(P0, P1, P2, P3, t): 75 | # Calculate the derivative of the 3D Bézier curve 76 | dp_dt = ( 77 | -3 * (1 - t) ** 2 * P0 78 | + 3 * (1 - 4 * t + 3 * t**2) * P1 79 | + 3 * (2 * t - 3 * t**2) * P2 80 | + 3 * t**2 * P3 81 | ) 82 | return dp_dt 83 | 84 | 85 | def get_dist_similarity(control_points1, control_points2, points1, points2, t_values): 86 | control_points1 = control_points1.reshape(-1, 3) 87 | control_points2 = control_points2.reshape(-1, 3) 88 | 89 | # Calculate the pairwise distances between all points in the two sets 90 | distances = cdist(points1, points2) 91 | 92 | # Find the indices of the minimum distance pair 93 | min_indices = np.unravel_index(np.argmin(distances), distances.shape) 94 | min_distance = distances[min_indices] 95 | 96 | points_slope1 = bezier_slope_vector( 97 | control_points1[0], 98 | control_points1[1], 99 | control_points1[2], 100 | control_points1[3], 101 | t_values[min_indices[0]], 102 | ) 103 | 104 | points_slope2 = bezier_slope_vector( 105 | control_points2[0], 106 | control_points2[1], 107 | control_points2[2], 108 | control_points2[3], 109 | t_values[min_indices[1]], 110 | ) 111 | 112 | # Compute cosine similarity between the two slopes 113 | similarity = np.abs(np.dot(points_slope1, points_slope2)) / ( 114 | np.linalg.norm(points_slope1) * np.linalg.norm(points_slope2) 115 | ) 116 | 117 | return min_distance, similarity 118 | 119 | 120 | def merge_line_segments( 121 | line_segments, raw_points_on_lines, distance_threshold, similarity_threshold 122 | ): 123 | dist_matrix = compute_pairwise_distances(line_segments) 124 | similarity_matrix = compute_pairwise_cosine_similarity(line_segments) 125 | 126 | # Create adjacency matrix based on distance and similarity thresholds 127 | adjacency_matrix = (dist_matrix <= distance_threshold) & ( 128 | similarity_matrix >= similarity_threshold 129 | ) 130 | # Compute connected components 131 | num_components, labels = connected_components(adjacency_matrix) 132 | 133 | merged_line_segments = [] 134 | for component in range(num_components): 135 | component_indices = np.where(labels == component)[0] 136 | if len(component_indices) == 1: 137 | merged_line_segments.append(line_segments[component_indices[0]]) 138 | continue 139 | else: 140 | raw_points_on_lines_group = raw_points_on_lines[component_indices] 141 | raw_points_on_lines_group_array = np.array( 142 | [ 143 | point 144 | for raw_points_on_lines in raw_points_on_lines_group 145 | for point in raw_points_on_lines 146 | ] 147 | ).reshape(-1, 3) 148 | 149 | try: 150 | line_segment, _ = line_fitting(raw_points_on_lines_group_array) 151 | merged_line_segments.append(line_segment) 152 | except: 153 | continue 154 | 155 | merged_line_segments = np.array(merged_line_segments) 156 | return merged_line_segments 157 | 158 | 159 | def merge_bezier_curves( 160 | control_points_list, 161 | raw_points_on_curves, 162 | distance_threshold, 163 | similarity_threshold, 164 | num_samples=100, 165 | ): 166 | """ 167 | Merge Bezier curves based on distance and similarity thresholds. 168 | 169 | Parameters: 170 | control_points_list (array): list of control points for each curve. 171 | raw_points_on_curves (array): actual points on each Bezier curve. 172 | distance_threshold (float): maximum distance for merging curves. 173 | similarity_threshold (float): minimum similarity for merging curves. 174 | num_samples (int, optional): number of samples for curve points. 175 | 176 | Returns: 177 | np.array: Merged control points for the Bezier curves. 178 | """ 179 | 180 | if not isinstance(control_points_list, np.ndarray) or not isinstance( 181 | raw_points_on_curves, np.ndarray 182 | ): 183 | raise ValueError("Input must be NumPy arrays.") 184 | 185 | num_curves = len(control_points_list) 186 | dist_matrix = np.zeros((num_curves, num_curves)) 187 | similarity_matrix = np.zeros((num_curves, num_curves)) 188 | t_values = np.linspace(0, 1, num_samples) 189 | 190 | for i, control_points1 in enumerate(control_points_list): 191 | for j in range(i + 1, num_curves): 192 | control_points2 = control_points_list[j] 193 | points1 = bezier_curve(t_values, *(control_points1.tolist())).reshape(-1, 3) 194 | points2 = bezier_curve(t_values, *(control_points2.tolist())).reshape(-1, 3) 195 | dist_matrix[i, j], similarity_matrix[i, j] = get_dist_similarity( 196 | control_points1, control_points2, points1, points2, t_values 197 | ) 198 | 199 | dist_matrix += dist_matrix.T 200 | similarity_matrix += similarity_matrix.T 201 | 202 | adjacency_matrix = (dist_matrix <= distance_threshold) & ( 203 | similarity_matrix >= similarity_threshold 204 | ) 205 | num_components, labels = connected_components(adjacency_matrix) 206 | 207 | merged_bezier_curves = [] 208 | for component in range(num_components): 209 | component_indices = np.where(labels == component)[0] 210 | if len(component_indices) == 1: 211 | p = control_points_list[component_indices[0]] 212 | merged_bezier_curves.append(p) 213 | else: 214 | raw_points_group = [raw_points_on_curves[i] for i in component_indices] 215 | raw_points_group_array = np.concatenate(raw_points_group, axis=0) 216 | p = bezier_fit(raw_points_group_array) 217 | merged_bezier_curves.append(p) 218 | 219 | return np.array(merged_bezier_curves) 220 | 221 | 222 | def merge_endpoints(merged_line_segments, merged_bezier_curves, distance_threshold): 223 | N_lines = len(merged_line_segments) 224 | N_curves = len(merged_bezier_curves) 225 | 226 | if N_lines == 0 and N_curves == 0: 227 | return [], [] 228 | 229 | if N_lines > 0: 230 | line_endpoints = merged_line_segments.reshape(-1, 3) 231 | else: 232 | line_endpoints = np.array([]).reshape(-1, 3) 233 | 234 | if N_curves > 0: 235 | curve_endpoints = merged_bezier_curves[:, [0, 1, 2, -3, -2, -1]].reshape(-1, 3) 236 | else: 237 | curve_endpoints = np.array([]).reshape(-1, 3) 238 | 239 | concat_endpoints = np.concatenate([line_endpoints, curve_endpoints], axis=0) 240 | 241 | dist_matrix = cdist(concat_endpoints, concat_endpoints) 242 | adjacency_matrix = dist_matrix <= distance_threshold 243 | num_components, labels = connected_components(adjacency_matrix) 244 | for component in range(num_components): 245 | component_indices = np.where(labels == component)[0] 246 | if len(component_indices) > 1: 247 | endpoints = concat_endpoints[component_indices] 248 | mean_endpoint = np.mean(endpoints, axis=0) 249 | concat_endpoints[component_indices] = mean_endpoint 250 | 251 | if N_lines > 0: 252 | merged_line_segments_merged_endpoints = concat_endpoints[: N_lines * 2].reshape( 253 | -1, 6 254 | ) 255 | else: 256 | merged_line_segments_merged_endpoints = [] 257 | 258 | if N_curves > 0: 259 | merged_curve_segments_merged_endpoints = np.zeros_like(merged_bezier_curves) 260 | curve_merged_endpoints = concat_endpoints[N_lines * 2 :].reshape(-1, 6) 261 | merged_curve_segments_merged_endpoints[:, :3] = curve_merged_endpoints[:, :3] 262 | merged_curve_segments_merged_endpoints[:, 3:9] = merged_bezier_curves[:, 3:9] 263 | merged_curve_segments_merged_endpoints[:, 9:] = curve_merged_endpoints[:, 3:] 264 | 265 | else: 266 | merged_curve_segments_merged_endpoints = [] 267 | 268 | return merged_line_segments_merged_endpoints, merged_curve_segments_merged_endpoints 269 | 270 | 271 | def approximate_curve_length(P0, P1, P2, P3): 272 | return np.linalg.norm(P1 - P0) + np.linalg.norm(P2 - P1) + np.linalg.norm(P3 - P2) 273 | 274 | 275 | def generate_points_curve(curves): 276 | all_points = [] 277 | for curve in curves: 278 | P0 = np.array(curve[:3]) 279 | P1 = np.array(curve[3:6]) 280 | P2 = np.array(curve[6:9]) 281 | P3 = np.array(curve[9:]) 282 | 283 | # Approximate number of points based on curve length 284 | length = approximate_curve_length(P0, P1, P2, P3) 285 | num_points = int(np.ceil(length * 1000)) # Adjust the factor as needed 286 | 287 | t_values = np.linspace(0, 1, num_points) 288 | curve_points = np.array(bezier_curve(t_values, *(curve.tolist()))) 289 | 290 | all_points.extend(curve_points) 291 | 292 | return np.array(all_points).reshape(-1, 3) 293 | 294 | 295 | def merge( 296 | out_dir, 297 | fitted_edge_dict, 298 | merge_edge_distance_threshold=5.0, 299 | merge_endpoints_distance_threshold=1.0, 300 | merge_similarity_threshold=0.98, 301 | merge_endpoints_flag=True, 302 | merge_edge_flag=True, 303 | merge_curve_flag=False, 304 | save_ply=False, 305 | ): 306 | 307 | resolution = int(fitted_edge_dict["resolution"]) 308 | lines = np.array(fitted_edge_dict["lines_end_pts"]).reshape(-1, 6) 309 | raw_points_on_lines = np.array( 310 | fitted_edge_dict["raw_points_on_lines"], dtype=object 311 | ) 312 | bezier_curves = np.array(fitted_edge_dict["curves_ctl_pts"]).reshape(-1, 12) 313 | raw_points_on_curves = np.array( 314 | fitted_edge_dict["raw_points_on_curves"], dtype=object 315 | ) 316 | 317 | # Normalize thresholds 318 | merge_edge_distance_threshold /= resolution 319 | merge_endpoints_distance_threshold /= resolution 320 | 321 | # Merge lines 322 | if merge_edge_flag and len(lines) > 0: 323 | merged_line_segments = merge_line_segments( 324 | lines, 325 | raw_points_on_lines, 326 | merge_edge_distance_threshold / 2.0, 327 | merge_similarity_threshold, 328 | ) 329 | else: 330 | merged_line_segments = lines 331 | 332 | if merge_curve_flag and merge_edge_flag: 333 | if len(bezier_curves) > 0: 334 | merged_bezier_curves = merge_bezier_curves( 335 | bezier_curves, 336 | raw_points_on_curves, 337 | merge_edge_distance_threshold, 338 | merge_similarity_threshold, 339 | ) 340 | else: 341 | merged_bezier_curves = [] 342 | else: 343 | merged_bezier_curves = bezier_curves 344 | 345 | if merge_endpoints_flag: 346 | ( 347 | merged_line_segments, 348 | merged_bezier_curves, 349 | ) = merge_endpoints( 350 | merged_line_segments, 351 | merged_bezier_curves, 352 | merge_endpoints_distance_threshold, 353 | ) 354 | 355 | if save_ply: 356 | if len(merged_line_segments) > 0: 357 | save_3d_lines_to_file( 358 | merged_line_segments, 359 | os.path.join(out_dir, "merged_line_segments.ply"), 360 | width=2, 361 | scale=1.0, 362 | ) 363 | print(f"Saved merged line segments to {out_dir}.") 364 | 365 | if len(merged_bezier_curves) > 0: 366 | pcd = o3d.geometry.PointCloud() 367 | meregd_bezier_curves_points = generate_points_curve(merged_bezier_curves) 368 | pcd.points = o3d.utility.Vector3dVector(meregd_bezier_curves_points) 369 | o3d.io.write_point_cloud( 370 | os.path.join(out_dir, "merged_bezier_curve_points.ply"), 371 | pcd, 372 | write_ascii=True, 373 | ) 374 | print(f"Saved merged bezier curves to {out_dir}.") 375 | 376 | merged_edge_dict = { 377 | "lines_end_pts": ( 378 | merged_line_segments.tolist() if len(merged_line_segments) > 0 else [] 379 | ), 380 | "curves_ctl_pts": ( 381 | merged_bezier_curves.tolist() if len(merged_bezier_curves) > 0 else [] 382 | ), 383 | } 384 | 385 | return merged_edge_dict 386 | -------------------------------------------------------------------------------- /src/eval/ABC_scans.txt: -------------------------------------------------------------------------------- 1 | 00000325 -------------------------------------------------------------------------------- /src/eval/DTU_scans.txt: -------------------------------------------------------------------------------- 1 | scan105 -------------------------------------------------------------------------------- /src/eval/eval_ABC.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import argparse 4 | from src.eval.eval_util import ( 5 | set_random_seeds, 6 | load_from_json, 7 | compute_chamfer_distance, 8 | f_score, 9 | compute_precision_recall_IOU, 10 | downsample_point_cloud_average, 11 | get_gt_points, 12 | get_pred_points_and_directions, 13 | ) 14 | 15 | 16 | def update_totals_and_metrics(metrics, totals, results, edge_type): 17 | correct_gt, num_gt, correct_pred, num_pred, acc, comp = results 18 | metrics[f"comp_{edge_type}"].append(comp) 19 | metrics[f"acc_{edge_type}"].append(acc) 20 | for i, threshold in enumerate(["5", "10", "20"]): 21 | totals[f"thre{threshold}_correct_gt_total"] += correct_gt[i] 22 | totals[f"thre{threshold}_correct_pred_total"] += correct_pred[i] 23 | totals["num_gt_total"] += num_gt 24 | totals["num_pred_total"] += num_pred 25 | 26 | 27 | def finalize_metrics(metrics): 28 | for key, value in metrics.items(): 29 | value = np.array(value) 30 | value[np.isnan(value)] = 0 31 | metrics[key] = round(np.mean(value), 4) 32 | return metrics 33 | 34 | 35 | def print_metrics(metrics, totals, edge_type): 36 | print(f"{edge_type.capitalize()}:") 37 | print(f" Completeness: {metrics[f'comp_{edge_type}']}") 38 | print(f" Accuracy: {metrics[f'acc_{edge_type}']}") 39 | 40 | 41 | def process_scan(scan_name, base_dir, exp_name, dataset_dir, metrics, totals): 42 | print(f"Processing: {scan_name}") 43 | json_path = os.path.join( 44 | base_dir, scan_name, exp_name, "results", "parametric_edges.json" 45 | ) 46 | if not os.path.exists(json_path): 47 | print(f"Invalid prediction at {scan_name}") 48 | return 49 | 50 | all_curve_points, all_line_points, all_curve_directions, all_line_directions = ( 51 | get_pred_points_and_directions(json_path) 52 | ) 53 | pred_points = ( 54 | np.concatenate([all_curve_points, all_line_points], axis=0) 55 | .reshape(-1, 3) 56 | .astype(np.float32) 57 | ) 58 | 59 | if len(pred_points) == 0: 60 | print(f"Invalid prediction at {scan_name}") 61 | return 62 | 63 | pred_sampled = downsample_point_cloud_average( 64 | pred_points, 65 | num_voxels_per_axis=256, 66 | min_bound=[-1, -1, -1], 67 | max_bound=[1, 1, 1], 68 | ) 69 | 70 | gt_points_raw, gt_points, _ = get_gt_points( 71 | scan_name, "all", data_base_dir=os.path.join(dataset_dir, "groundtruth") 72 | ) 73 | if gt_points_raw is None: 74 | return 75 | 76 | chamfer_dist, acc, comp = compute_chamfer_distance(pred_sampled, gt_points) 77 | print( 78 | f" Chamfer Distance: {chamfer_dist:.4f}, Accuracy: {acc:.4f}, Completeness: {comp:.4f}" 79 | ) 80 | metrics["chamfer"].append(chamfer_dist) 81 | metrics["acc"].append(acc) 82 | metrics["comp"].append(comp) 83 | metrics = compute_precision_recall_IOU( 84 | pred_sampled, 85 | gt_points, 86 | metrics, 87 | thresh_list=[0.005, 0.01, 0.02], 88 | edge_type="all", 89 | ) 90 | 91 | for edge_type in ["curve", "line"]: 92 | gt_points_raw_edge, gt_points_edge, _ = get_gt_points( 93 | scan_name, 94 | edge_type, 95 | return_direction=True, 96 | data_base_dir=os.path.join(dataset_dir, "groundtruth"), 97 | ) 98 | if gt_points_raw_edge is not None: 99 | results = compute_precision_recall_IOU( 100 | pred_sampled, 101 | gt_points_edge, 102 | None, 103 | thresh_list=[0.005, 0.01, 0.02], 104 | edge_type=edge_type, 105 | ) 106 | update_totals_and_metrics(metrics, totals[edge_type], results, edge_type) 107 | 108 | 109 | def main(base_dir, dataset_dir, exp_name): 110 | set_random_seeds() 111 | metrics = { 112 | "chamfer": [], 113 | "acc": [], 114 | "comp": [], 115 | "comp_curve": [], 116 | "comp_line": [], 117 | "acc_curve": [], 118 | "acc_line": [], 119 | "precision_0.01": [], 120 | "recall_0.01": [], 121 | "fscore_0.01": [], 122 | "IOU_0.01": [], 123 | "precision_0.02": [], 124 | "recall_0.02": [], 125 | "fscore_0.02": [], 126 | "IOU_0.02": [], 127 | "precision_0.005": [], 128 | "recall_0.005": [], 129 | "fscore_0.005": [], 130 | "IOU_0.005": [], 131 | } 132 | 133 | totals = { 134 | "curve": { 135 | "thre5_correct_gt_total": 0, 136 | "thre10_correct_gt_total": 0, 137 | "thre20_correct_gt_total": 0, 138 | "thre5_correct_pred_total": 0, 139 | "thre10_correct_pred_total": 0, 140 | "thre20_correct_pred_total": 0, 141 | "num_gt_total": 0, 142 | "num_pred_total": 0, 143 | }, 144 | "line": { 145 | "thre5_correct_gt_total": 0, 146 | "thre10_correct_gt_total": 0, 147 | "thre20_correct_gt_total": 0, 148 | "thre5_correct_pred_total": 0, 149 | "thre10_correct_pred_total": 0, 150 | "thre20_correct_pred_total": 0, 151 | "num_gt_total": 0, 152 | "num_pred_total": 0, 153 | }, 154 | } 155 | 156 | with open("src/eval/ABC_scans.txt", "r") as f: 157 | scan_names = [line.strip() for line in f] 158 | 159 | for scan_name in scan_names: 160 | process_scan(scan_name, base_dir, exp_name, dataset_dir, metrics, totals) 161 | 162 | metrics = finalize_metrics(metrics) 163 | 164 | print("Summary:") 165 | print(f" Accuracy: {metrics['acc']:.4f}") 166 | print(f" Completeness: {metrics['comp']:.4f}") 167 | print(f" Recall @ 5 mm: {metrics['recall_0.005']:.4f}") 168 | print(f" Recall @ 10 mm: {metrics['recall_0.01']:.4f}") 169 | print(f" Recall @ 20 mm: {metrics['recall_0.02']:.4f}") 170 | print(f" Precision @ 5 mm: {metrics['precision_0.005']:.4f}") 171 | print(f" Precision @ 10 mm: {metrics['precision_0.01']:.4f}") 172 | print(f" Precision @ 20 mm: {metrics['precision_0.02']:.4f}") 173 | print(f" F-Score @ 5 mm: {metrics['fscore_0.005']:.4f}") 174 | print(f" F-Score @ 10 mm: {metrics['fscore_0.01']:.4f}") 175 | print(f" F-Score @ 20 mm: {metrics['fscore_0.02']:.4f}") 176 | 177 | if totals["curve"]["num_gt_total"] > 0: 178 | print_metrics(metrics, totals["curve"], "curve") 179 | else: 180 | print("Curve: No ground truth edges found.") 181 | 182 | if totals["line"]["num_gt_total"] > 0: 183 | print_metrics(metrics, totals["line"], "line") 184 | else: 185 | print("Line: No ground truth edges found.") 186 | 187 | 188 | if __name__ == "__main__": 189 | parser = argparse.ArgumentParser( 190 | description="Process CAD data and compute metrics." 191 | ) 192 | parser.add_argument( 193 | "--base_dir", 194 | type=str, 195 | default="./exp/ABC", 196 | help="Base directory for experiments", 197 | ) 198 | parser.add_argument( 199 | "--dataset_dir", 200 | type=str, 201 | default="./data/ABC-NEF_Edge", 202 | help="Directory for the dataset", 203 | ) 204 | parser.add_argument("--exp_name", type=str, default="emap", help="Experiment name") 205 | 206 | args = parser.parse_args() 207 | main(args.base_dir, args.dataset_dir, args.exp_name) 208 | -------------------------------------------------------------------------------- /src/eval/eval_DTU.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d as o3d 3 | import sklearn.neighbors as skln 4 | from tqdm import tqdm 5 | from scipy.io import loadmat 6 | import multiprocessing as mp 7 | import argparse 8 | import os 9 | from pathlib import Path 10 | from src.eval.eval_util import ( 11 | set_random_seeds, 12 | load_from_json, 13 | downsample_point_cloud_average, 14 | get_pred_points_and_directions, 15 | ) 16 | 17 | 18 | def process_scan( 19 | scan_name, 20 | base_dir, 21 | exp_name, 22 | dataset_dir, 23 | threshold, 24 | downsample_density, 25 | precision_list, 26 | recall_list, 27 | ): 28 | print(f"Processing: {scan_name}") 29 | json_path = os.path.join( 30 | base_dir, scan_name, exp_name, "results", "parametric_edges.json" 31 | ) 32 | if not os.path.exists(json_path): 33 | print(f"Invalid prediction at {scan_name}") 34 | return 35 | 36 | meta_data_json_path = os.path.join(dataset_dir, "data", scan_name, "meta_data.json") 37 | worldtogt = np.array(load_from_json(Path(meta_data_json_path))["worldtogt"]) 38 | 39 | all_curve_points, all_line_points, _, _ = get_pred_points_and_directions(json_path) 40 | all_points = np.concatenate([all_curve_points, all_line_points], axis=0).reshape( 41 | -1, 3 42 | ) 43 | all_points = np.dot(all_points, worldtogt[:3, :3].T) + worldtogt[:3, 3] 44 | 45 | points_down = downsample_point_cloud_average(all_points, num_voxels_per_axis=256) 46 | 47 | nn_engine = skln.NearestNeighbors( 48 | n_neighbors=1, radius=downsample_density, algorithm="kd_tree", n_jobs=-1 49 | ) 50 | 51 | gt_edge_points_path = os.path.join( 52 | dataset_dir, "groundtruth", "edge_points", scan_name, "edge_points.ply" 53 | ) 54 | gt_edge_pcd = o3d.io.read_point_cloud(gt_edge_points_path) 55 | gt_edge_points = np.asarray(gt_edge_pcd.points) 56 | 57 | nn_engine.fit(gt_edge_points) 58 | dist_d2s, idx_d2s = nn_engine.kneighbors( 59 | points_down, n_neighbors=1, return_distance=True 60 | ) 61 | precision = np.sum(dist_d2s <= threshold) / dist_d2s.shape[0] 62 | precision_list.append(precision) 63 | 64 | nn_engine.fit(points_down) 65 | dist_s2d, idx_s2d = nn_engine.kneighbors( 66 | gt_edge_points, n_neighbors=1, return_distance=True 67 | ) 68 | recall = np.sum(dist_s2d <= threshold) / len(dist_s2d) 69 | recall_list.append(recall) 70 | 71 | print(f" Recall: {recall:.4f}, Precision: {precision:.4f}") 72 | 73 | 74 | def main(args): 75 | set_random_seeds() 76 | with open("src/eval/DTU_scans.txt", "r") as f: 77 | scan_names = [line.strip() for line in f] 78 | 79 | precision_list = [] 80 | recall_list = [] 81 | 82 | for scan_name in scan_names: 83 | process_scan( 84 | scan_name, 85 | args.base_dir, 86 | args.exp_name, 87 | args.dataset_dir, 88 | args.threshold, 89 | args.downsample_density, 90 | precision_list, 91 | recall_list, 92 | ) 93 | 94 | print("\nSummary:") 95 | print(f" Mean Recall: {np.mean(recall_list):.4f}") 96 | print(f" Mean Precision: {np.mean(precision_list):.4f}") 97 | 98 | 99 | if __name__ == "__main__": 100 | parser = argparse.ArgumentParser( 101 | description="Process DTU data and compute metrics." 102 | ) 103 | parser.add_argument( 104 | "--base_dir", 105 | type=str, 106 | default="./exp/DTU", 107 | help="Base directory for experiments", 108 | ) 109 | parser.add_argument( 110 | "--dataset_dir", 111 | type=str, 112 | default="./data/DTU_Edge", 113 | help="Directory for the dataset", 114 | ) 115 | parser.add_argument("--exp_name", type=str, default="emap", help="Experiment name") 116 | parser.add_argument("--downsample_density", type=float, default=0.5) 117 | parser.add_argument("--threshold", type=float, default=5) 118 | args = parser.parse_args() 119 | main(args) 120 | -------------------------------------------------------------------------------- /src/eval/eval_util.py: -------------------------------------------------------------------------------- 1 | import os 2 | import numpy as np 3 | import json 4 | import point_cloud_utils as pcu 5 | import math 6 | import random 7 | from pathlib import Path 8 | 9 | 10 | def set_random_seeds(seed=42): 11 | np.random.seed(seed) 12 | random.seed(seed) 13 | 14 | def load_from_json(filename: Path): 15 | """Load a dictionary from a JSON filename.""" 16 | assert filename.suffix == ".json" 17 | with open(filename, encoding="UTF-8") as file: 18 | return json.load(file) 19 | 20 | def chamfer_distance(x, y, return_index=False, p_norm=2, max_points_per_leaf=10): 21 | """ 22 | Compute the chamfer distance between two point clouds x, and y 23 | 24 | Args: 25 | x : A m-sized minibatch of point sets in R^d. i.e. shape [m, n_a, d] 26 | y : A m-sized minibatch of point sets in R^d. i.e. shape [m, n_b, d] 27 | return_index: If set to True, will return a pair (corrs_x_to_y, corrs_y_to_x) where 28 | corrs_x_to_y[i] stores the index into y of the closest point to x[i] 29 | (i.e. y[corrs_x_to_y[i]] is the nearest neighbor to x[i] in y). 30 | corrs_y_to_x is similar to corrs_x_to_y but with x and y reversed. 31 | max_points_per_leaf : The maximum number of points per leaf node in the KD tree used by this function. 32 | Default is 10. 33 | p_norm : Which norm to use. p_norm can be any real number, inf (for the max norm) -inf (for the min norm), 34 | 0 (for sum(x != 0)) 35 | Returns: 36 | The chamfer distance between x an dy. 37 | If return_index is set, then this function returns a tuple (chamfer_dist, corrs_x_to_y, corrs_y_to_x) where 38 | corrs_x_to_y and corrs_y_to_x are described above. 39 | """ 40 | 41 | dists_x_to_y, corrs_x_to_y = pcu.k_nearest_neighbors( 42 | x, y, k=1, squared_distances=False, max_points_per_leaf=max_points_per_leaf 43 | ) 44 | dists_y_to_x, corrs_y_to_x = pcu.k_nearest_neighbors( 45 | y, x, k=1, squared_distances=False, max_points_per_leaf=max_points_per_leaf 46 | ) 47 | 48 | dists_x_to_y = np.linalg.norm(x[corrs_y_to_x] - y, axis=-1, ord=p_norm).mean() 49 | dists_y_to_x = np.linalg.norm(y[corrs_x_to_y] - x, axis=-1, ord=p_norm).mean() 50 | 51 | Comp = np.mean(dists_x_to_y) 52 | Acc = np.mean(dists_y_to_x) 53 | cham_dist = Comp + Acc 54 | 55 | if return_index: 56 | return cham_dist, corrs_x_to_y, corrs_y_to_x 57 | 58 | return cham_dist, Acc, Comp 59 | 60 | 61 | def compute_chamfer_distance(pred_sampled, gt_points): 62 | """ 63 | Compute chamfer distance between predicted and ground truth points. 64 | 65 | Args: 66 | pred_sampled (numpy.ndarray): Predicted point cloud. 67 | gt_points (numpy.ndarray): Ground truth points. 68 | 69 | Returns: 70 | float: Chamfer distance. 71 | """ 72 | chamfer_dist, acc, comp = chamfer_distance(pred_sampled, gt_points) 73 | return chamfer_dist, acc, comp 74 | 75 | 76 | def f_score(precision, recall): 77 | """ 78 | Compute F-score. 79 | 80 | Args: 81 | precision (float): Precision. 82 | recall (float): Recall. 83 | 84 | Returns: 85 | float: F-score. 86 | """ 87 | return 2 * precision * recall / (precision + recall) 88 | 89 | 90 | def bezier_curve_length(control_points, num_samples): 91 | def binomial_coefficient(n, i): 92 | return math.factorial(n) // (math.factorial(i) * math.factorial(n - i)) 93 | 94 | def derivative_bezier(t): 95 | n = len(control_points) - 1 96 | point = np.array([0.0, 0.0, 0.0]) 97 | for i, (p1, p2) in enumerate(zip(control_points[:-1], control_points[1:])): 98 | point += ( 99 | n 100 | * binomial_coefficient(n - 1, i) 101 | * (1 - t) ** (n - 1 - i) 102 | * t**i 103 | * (np.array(p2) - np.array(p1)) 104 | ) 105 | return point 106 | 107 | def simpson_integral(a, b, num_samples): 108 | h = (b - a) / num_samples 109 | sum1 = sum( 110 | np.linalg.norm(derivative_bezier(a + i * h)) 111 | for i in range(1, num_samples, 2) 112 | ) 113 | sum2 = sum( 114 | np.linalg.norm(derivative_bezier(a + i * h)) 115 | for i in range(2, num_samples - 1, 2) 116 | ) 117 | return ( 118 | ( 119 | np.linalg.norm(derivative_bezier(a)) 120 | + 4 * sum1 121 | + 2 * sum2 122 | + np.linalg.norm(derivative_bezier(b)) 123 | ) 124 | * h 125 | / 3 126 | ) 127 | 128 | # Compute the length of the 3D Bezier curve using composite Simpson's rule 129 | length = 0.0 130 | for i in range(num_samples): 131 | t0 = i / num_samples 132 | t1 = (i + 1) / num_samples 133 | length += simpson_integral(t0, t1, num_samples) 134 | 135 | return length 136 | 137 | 138 | def compute_precision_recall_IOU( 139 | pred_sampled, gt_points, metrics, thresh_list=[0.02], edge_type="all" 140 | ): 141 | """ 142 | Compute precision, recall, F-score, and IOU. 143 | 144 | Args: 145 | pred_sampled (numpy.ndarray): Predicted point cloud. 146 | gt_points (numpy.ndarray): Ground truth points. 147 | metrics (dict): Dictionary to store metrics. 148 | 149 | Returns: 150 | dict: Updated metrics. 151 | """ 152 | if edge_type == "all": 153 | for thresh in thresh_list: 154 | dists_a_to_b, _ = pcu.k_nearest_neighbors( 155 | pred_sampled, gt_points, k=1 156 | ) # k closest points (in pts_b) for each point in pts_a 157 | correct_pred = np.sum(dists_a_to_b < thresh) 158 | precision = correct_pred / len(dists_a_to_b) 159 | metrics[f"precision_{thresh}"].append(precision) 160 | 161 | dists_b_to_a, _ = pcu.k_nearest_neighbors(gt_points, pred_sampled, k=1) 162 | correct_gt = np.sum(dists_b_to_a < thresh) 163 | recall = correct_gt / len(dists_b_to_a) 164 | metrics[f"recall_{thresh}"].append(recall) 165 | 166 | fscore = 2 * precision * recall / (precision + recall) 167 | metrics[f"fscore_{thresh}"].append(fscore) 168 | 169 | intersection = min(correct_pred, correct_gt) 170 | union = ( 171 | len(dists_a_to_b) + len(dists_b_to_a) - max(correct_pred, correct_gt) 172 | ) 173 | 174 | IOU = intersection / union 175 | metrics[f"IOU_{thresh}"].append(IOU) 176 | return metrics 177 | else: 178 | correct_gt_list = [] 179 | correct_pred_list = [] 180 | _, acc, comp = compute_chamfer_distance(pred_sampled, gt_points) 181 | for thresh in thresh_list: 182 | dists_b_to_a, _ = pcu.k_nearest_neighbors(gt_points, pred_sampled, k=1) 183 | correct_gt = np.sum(dists_b_to_a < thresh) 184 | num_gt = len(dists_b_to_a) 185 | correct_gt_list.append(correct_gt) 186 | dists_a_to_b, _ = pcu.k_nearest_neighbors(pred_sampled, gt_points, k=1) 187 | correct_pred = np.sum(dists_a_to_b < thresh) 188 | correct_pred_list.append(correct_pred) 189 | num_pred = len(dists_a_to_b) 190 | 191 | return correct_gt_list, num_gt, correct_pred_list, num_pred, acc, comp 192 | 193 | 194 | def get_gt_points( 195 | scan_name, 196 | edge_type="all", 197 | interval=0.005, 198 | return_direction=False, 199 | data_base_dir=None, 200 | ): 201 | """ 202 | Get ground truth points from a dataset. 203 | 204 | Args: 205 | name (str): Name of the dataset. 206 | 207 | Returns: 208 | numpy.ndarray: Raw and processed ground truth points. 209 | """ 210 | objs_dir = os.path.join(data_base_dir, "obj") 211 | obj_names = os.listdir(objs_dir) 212 | obj_names.sort() 213 | index_obj_names = {} 214 | for obj_name in obj_names: 215 | index_obj_names[obj_name[:8]] = obj_name 216 | 217 | json_feats_path = os.path.join(data_base_dir, "chunk_0000_feats.json") 218 | with open(json_feats_path, "r") as f: 219 | json_data_feats = json.load(f) 220 | json_stats_path = os.path.join(data_base_dir, "chunk_0000_stats.json") 221 | with open(json_stats_path, "r") as f: 222 | json_data_stats = json.load(f) 223 | 224 | # get the normalize scale to help align the nerf points and gt points 225 | [ 226 | x_min, 227 | y_min, 228 | z_min, 229 | x_max, 230 | y_max, 231 | z_max, 232 | x_range, 233 | y_range, 234 | z_range, 235 | ] = json_data_stats[scan_name]["bbox"] 236 | scale = 1 / max(x_range, y_range, z_range) 237 | poi_center = ( 238 | np.array([((x_min + x_max) / 2), ((y_min + y_max) / 2), ((z_min + z_max) / 2)]) 239 | * scale 240 | ) 241 | set_location = [0.5, 0.5, 0.5] - poi_center # based on the rendering settings 242 | obj_path = os.path.join(objs_dir, index_obj_names[scan_name]) 243 | 244 | with open(obj_path, encoding="utf-8") as file: 245 | data = file.readlines() 246 | 247 | vertices_obj = [each.split(" ") for each in data if each.split(" ")[0] == "v"] 248 | vertices_xyz = [ 249 | [float(v[1]), float(v[2]), float(v[3].replace("\n", ""))] for v in vertices_obj 250 | ] 251 | 252 | edge_pts = [] 253 | edge_pts_raw = [] 254 | edge_pts_direction = [] 255 | rename = { 256 | "BSpline": "curve", 257 | "Circle": "curve", 258 | "Ellipse": "curve", 259 | "Line": "line", 260 | } 261 | for each_curve in json_data_feats[scan_name]: 262 | if edge_type != "all" and rename[each_curve["type"]] != edge_type: 263 | continue 264 | 265 | if each_curve["sharp"]: # each_curve["type"]: BSpline, Line, Circle 266 | each_edge_pts = [vertices_xyz[i] for i in each_curve["vert_indices"]] 267 | edge_pts_raw += each_edge_pts 268 | 269 | gt_sampling = [] 270 | each_edge_pts = np.array(each_edge_pts) 271 | for index in range(len(each_edge_pts) - 1): 272 | next = each_edge_pts[index + 1] 273 | current = each_edge_pts[index] 274 | num = int(np.linalg.norm(next - current) // interval) 275 | linspace = np.linspace(0, 1, num) 276 | gt_sampling.append( 277 | linspace[:, None] * current + (1 - linspace)[:, None] * next 278 | ) 279 | 280 | if return_direction: 281 | direction = (next - current) / np.linalg.norm(next - current) 282 | edge_pts_direction.extend([direction] * num) 283 | each_edge_pts = np.concatenate(gt_sampling).tolist() 284 | edge_pts += each_edge_pts 285 | 286 | if len(edge_pts_raw) == 0: 287 | return None, None, None 288 | 289 | edge_pts_raw = np.array(edge_pts_raw) * scale + set_location 290 | edge_pts = np.array(edge_pts) * scale + set_location 291 | edge_pts_direction = np.array(edge_pts_direction) 292 | 293 | return ( 294 | edge_pts_raw.astype(np.float32), 295 | edge_pts.astype(np.float32), 296 | edge_pts_direction, 297 | ) 298 | 299 | 300 | def get_pred_points_and_directions( 301 | json_path, 302 | sample_resolution=0.005, 303 | ): 304 | with open(json_path, "r") as f: 305 | json_data = json.load(f) 306 | 307 | curve_paras = np.array(json_data["curves_ctl_pts"]).reshape(-1, 3) 308 | curves_ctl_pts = curve_paras.reshape(-1, 4, 3) 309 | lines_end_pts = np.array(json_data["lines_end_pts"]).reshape(-1, 2, 3) 310 | 311 | num_curves = len(curves_ctl_pts) 312 | num_lines = len(lines_end_pts) 313 | 314 | all_curve_points = [] 315 | all_curve_directions = [] 316 | 317 | # # -----------------------------------for Cubic Bezier----------------------------------- 318 | if num_curves > 0: 319 | for i, each_curve in enumerate(curves_ctl_pts): 320 | each_curve = np.array(each_curve).reshape(4, 3) # shape: (4, 3) 321 | sample_num = int( 322 | bezier_curve_length(each_curve, num_samples=100) // sample_resolution 323 | ) 324 | t = np.linspace(0, 1, sample_num) 325 | matrix_u = np.array([t**3, t**2, t, [1] * sample_num]).reshape( 326 | 4, sample_num 327 | ) 328 | 329 | matrix_middle = np.array( 330 | [[-1, 3, -3, 1], [3, -6, 3, 0], [-3, 3, 0, 0], [1, 0, 0, 0]] 331 | ) 332 | 333 | matrix = np.matmul( 334 | np.matmul(matrix_u.T, matrix_middle), each_curve 335 | ).reshape(sample_num, 3) 336 | 337 | all_curve_points += matrix.tolist() 338 | 339 | # Calculate the curve directions (derivatives) 340 | derivative_u = 3 * t**2 341 | derivative_v = 2 * t 342 | 343 | # Derivative matrices for x, y, z 344 | dx = ( 345 | ( 346 | -3 * each_curve[0][0] 347 | + 9 * each_curve[1][0] 348 | - 9 * each_curve[2][0] 349 | + 3 * each_curve[3][0] 350 | ) 351 | * derivative_u 352 | + (6 * each_curve[0][0] - 12 * each_curve[1][0] + 6 * each_curve[2][0]) 353 | * derivative_v 354 | + (-3 * each_curve[0][0] + 3 * each_curve[1][0]) 355 | ) 356 | 357 | dy = ( 358 | ( 359 | -3 * each_curve[0][1] 360 | + 9 * each_curve[1][1] 361 | - 9 * each_curve[2][1] 362 | + 3 * each_curve[3][1] 363 | ) 364 | * derivative_u 365 | + (6 * each_curve[0][1] - 12 * each_curve[1][1] + 6 * each_curve[2][1]) 366 | * derivative_v 367 | + (-3 * each_curve[0][1] + 3 * each_curve[1][1]) 368 | ) 369 | 370 | dz = ( 371 | ( 372 | -3 * each_curve[0][2] 373 | + 9 * each_curve[1][2] 374 | - 9 * each_curve[2][2] 375 | + 3 * each_curve[3][2] 376 | ) 377 | * derivative_u 378 | + (6 * each_curve[0][2] - 12 * each_curve[1][2] + 6 * each_curve[2][2]) 379 | * derivative_v 380 | + (-3 * each_curve[0][2] + 3 * each_curve[1][2]) 381 | ) 382 | for i in range(sample_num): 383 | direction = np.array([dx[i], dy[i], dz[i]]) 384 | norm_direction = direction / np.linalg.norm(direction) 385 | all_curve_directions.append(norm_direction) 386 | 387 | all_line_points = [] 388 | all_line_directions = [] 389 | # # -------------------------------------for Line----------------------------------------- 390 | if num_lines > 0: 391 | for i, each_line in enumerate(lines_end_pts): 392 | each_line = np.array(each_line).reshape(2, 3) # shape: (2, 3) 393 | sample_num = int( 394 | np.linalg.norm(each_line[0] - each_line[-1]) // sample_resolution 395 | ) 396 | t = np.linspace(0, 1, sample_num) 397 | 398 | matrix_u_l = np.array([t, [1] * sample_num]) 399 | matrix_middle_l = np.array([[-1, 1], [1, 0]]) 400 | 401 | matrix_l = np.matmul( 402 | np.matmul(matrix_u_l.T, matrix_middle_l), each_line 403 | ).reshape(sample_num, 3) 404 | all_line_points += matrix_l.tolist() 405 | 406 | # Calculate the direction vector for the line segment 407 | direction = each_line[1] - each_line[0] 408 | norm_direction = direction / (np.linalg.norm(direction) + 1e-6) 409 | 410 | for point in matrix_l: 411 | all_line_directions.append(norm_direction) 412 | 413 | all_curve_points = np.array(all_curve_points).reshape(-1, 3) 414 | all_line_points = np.array(all_line_points).reshape(-1, 3) 415 | return all_curve_points, all_line_points, all_curve_directions, all_line_directions 416 | 417 | 418 | def downsample_point_cloud_average( 419 | points, num_voxels_per_axis=256, min_bound=None, max_bound=None 420 | ): 421 | """ 422 | Downsample a point set based on the number of voxels per axis by averaging the points within each voxel. 423 | 424 | Args: 425 | points: a [#v, 3]-shaped array of 3d points. 426 | num_voxels_per_axis: a scalar or 3-tuple specifying the number of voxels along each axis. 427 | 428 | Returns: 429 | A [#v', 3]-shaped numpy array of downsampled points, where #v' is the number of occupied voxels. 430 | """ 431 | 432 | # Calculate the bounding box of the point cloud 433 | if min_bound is None: 434 | min_bound = np.min(points, axis=0) 435 | else: 436 | min_bound = np.array(min_bound) 437 | 438 | if max_bound is None: 439 | max_bound = np.max(points, axis=0) 440 | else: 441 | max_bound = np.array(max_bound) 442 | 443 | # Determine the size of the voxel based on the desired number of voxels per axis 444 | if isinstance(num_voxels_per_axis, int): 445 | voxel_size = (max_bound - min_bound) / num_voxels_per_axis 446 | else: 447 | voxel_size = [ 448 | (max_bound[i] - min_bound[i]) / num_voxels_per_axis[i] for i in range(3) 449 | ] 450 | 451 | # Use the existing function to downsample the point cloud based on voxel size 452 | downsampled_points = pcu.downsample_point_cloud_on_voxel_grid( 453 | voxel_size, points, min_bound=min_bound, max_bound=max_bound 454 | ) 455 | 456 | return downsampled_points 457 | -------------------------------------------------------------------------------- /src/models/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/cvg/EMAP/652f24ecc3f3cbf538928f27cc6d55dbebb360c7/src/models/__init__.py -------------------------------------------------------------------------------- /src/models/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | 4 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. 5 | class Embedder: 6 | def __init__(self, **kwargs): 7 | self.kwargs = kwargs 8 | self.create_embedding_fn() 9 | 10 | def create_embedding_fn(self): 11 | embed_fns = [] 12 | d = self.kwargs["input_dims"] 13 | out_dim = 0 14 | if self.kwargs["include_input"]: 15 | embed_fns.append(lambda x: x) 16 | out_dim += d 17 | 18 | max_freq = self.kwargs["max_freq_log2"] 19 | N_freqs = self.kwargs["num_freqs"] 20 | 21 | if self.kwargs["log_sampling"]: 22 | freq_bands = 2.0 ** torch.linspace(0.0, max_freq, N_freqs) 23 | else: 24 | freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, N_freqs) 25 | 26 | for freq in freq_bands: 27 | for p_fn in self.kwargs["periodic_fns"]: 28 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 29 | out_dim += d 30 | 31 | self.embed_fns = embed_fns 32 | self.out_dim = out_dim 33 | 34 | def embed(self, inputs): 35 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 36 | 37 | 38 | def get_embedder(multires, input_dims=3): 39 | embed_kwargs = { 40 | "include_input": True, 41 | "input_dims": input_dims, 42 | "max_freq_log2": multires - 1, 43 | "num_freqs": multires, 44 | "log_sampling": True, 45 | "periodic_fns": [torch.sin, torch.cos], 46 | } 47 | 48 | embedder_obj = Embedder(**embed_kwargs) 49 | 50 | def embed(x, eo=embedder_obj): 51 | return eo.embed(x) 52 | 53 | return embed, embedder_obj.out_dim 54 | -------------------------------------------------------------------------------- /src/models/loss.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch.nn.functional as F 3 | 4 | 5 | class EdgeLoss(nn.Module): 6 | 7 | def __init__(self, loss_type="mse"): 8 | super(EdgeLoss, self).__init__() 9 | if loss_type == "mse": 10 | self.loss_func = F.mse_loss 11 | elif loss_type == "l1": 12 | self.loss_func = F.l1_loss 13 | 14 | def forward(self, pred_edge, gt_edge): 15 | 16 | edge_loss = self.loss_func(pred_edge, gt_edge, reduction="mean") 17 | return edge_loss 18 | -------------------------------------------------------------------------------- /src/models/udf_model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import numpy as np 4 | from src.models.embedder import get_embedder 5 | 6 | 7 | class UDFNetwork(nn.Module): 8 | def __init__( 9 | self, 10 | d_in, 11 | d_out, 12 | d_hidden, 13 | n_layers, 14 | skip_in=(4,), 15 | multires=0, 16 | scale=1, 17 | bias=0.5, 18 | geometric_init=True, 19 | weight_norm=True, 20 | udf_type="abs", 21 | ): 22 | super(UDFNetwork, self).__init__() 23 | 24 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] 25 | 26 | self.embed_fn_fine = None 27 | 28 | if multires > 0: 29 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 30 | self.embed_fn_fine = embed_fn 31 | dims[0] = input_ch 32 | 33 | self.num_layers = len(dims) 34 | self.skip_in = skip_in 35 | self.scale = scale 36 | 37 | self.geometric_init = geometric_init 38 | 39 | for l in range(0, self.num_layers - 1): 40 | if l + 1 in self.skip_in: 41 | out_dim = dims[l + 1] - dims[0] 42 | else: 43 | out_dim = dims[l + 1] 44 | 45 | lin = nn.Linear(dims[l], out_dim) 46 | 47 | if geometric_init: 48 | if l == self.num_layers - 2: 49 | torch.nn.init.normal_( 50 | lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001 51 | ) 52 | # torch.nn.init.constant_(lin.bias, bias) # for indoor sdf setting 53 | torch.nn.init.constant_(lin.bias, -bias) 54 | 55 | elif multires > 0 and l == 0: 56 | torch.nn.init.constant_(lin.bias, 0.0) 57 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 58 | torch.nn.init.normal_( 59 | lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim) 60 | ) 61 | elif multires > 0 and l in self.skip_in: 62 | torch.nn.init.constant_(lin.bias, 0.0) 63 | torch.nn.init.normal_( 64 | lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim) 65 | ) 66 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0) 67 | else: 68 | torch.nn.init.constant_(lin.bias, 0.0) 69 | torch.nn.init.normal_( 70 | lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim) 71 | ) 72 | 73 | if weight_norm: 74 | lin = nn.utils.parametrizations.weight_norm(lin) 75 | 76 | setattr(self, "lin" + str(l), lin) 77 | 78 | self.activation = nn.Softplus(beta=100) 79 | self.relu = nn.ReLU() 80 | self.udf_type = udf_type 81 | 82 | def udf_out(self, x): 83 | if self.udf_type == "abs": 84 | return torch.abs(x) 85 | elif self.udf_type == "square": 86 | return x**2 87 | elif self.udf_type == "sdf": 88 | return x 89 | 90 | def forward(self, inputs): 91 | inputs = inputs * self.scale 92 | if self.embed_fn_fine is not None: 93 | inputs = self.embed_fn_fine(inputs) 94 | 95 | x = inputs 96 | for l in range(0, self.num_layers - 1): 97 | lin = getattr(self, "lin" + str(l)) 98 | 99 | if l in self.skip_in: 100 | x = torch.cat([x, inputs], 1) / np.sqrt(2) 101 | 102 | x = lin(x) 103 | 104 | if l < self.num_layers - 2: 105 | x = self.activation(x) 106 | 107 | return ( 108 | torch.cat([self.udf_out(x[:, :1]) / self.scale, x[:, 1:]], dim=-1), 109 | inputs, 110 | ) 111 | 112 | def udf(self, x): 113 | feature_out, PE = self.forward(x) 114 | udf = feature_out[:, :1] 115 | feature = feature_out[:, 1:] 116 | return udf, feature, PE 117 | 118 | def udf_hidden_appearance(self, x): 119 | return self.forward(x) 120 | 121 | def gradient(self, x): 122 | x.requires_grad_(True) 123 | with torch.set_grad_enabled(True): 124 | y = self.udf(x)[0] 125 | 126 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 127 | gradients = torch.autograd.grad( 128 | outputs=y, 129 | inputs=x, 130 | grad_outputs=d_output, 131 | create_graph=True, 132 | retain_graph=True, 133 | only_inputs=True, 134 | )[0] 135 | return gradients.unsqueeze(1) 136 | 137 | 138 | class RenderingNetwork(nn.Module): 139 | def __init__( 140 | self, 141 | d_feature, 142 | mode, 143 | d_in, 144 | d_out, 145 | d_hidden, 146 | n_layers, 147 | weight_norm=True, 148 | multires_view=0, 149 | squeeze_out=True, 150 | ): 151 | super().__init__() 152 | 153 | self.mode = mode 154 | self.squeeze_out = squeeze_out 155 | self.d_out = d_out 156 | dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out] 157 | 158 | self.embedview_fn = None 159 | if multires_view > 0 and self.mode != "no_view_dir": 160 | embedview_fn, input_ch = get_embedder(multires_view) 161 | self.embedview_fn = embedview_fn 162 | dims[0] += input_ch - 3 163 | 164 | self.num_layers = len(dims) 165 | 166 | for l in range(0, self.num_layers - 1): 167 | out_dim = dims[l + 1] 168 | lin = nn.Linear(dims[l], out_dim) 169 | 170 | if weight_norm: 171 | lin = nn.utils.parametrizations.weight_norm(lin) 172 | 173 | setattr(self, "lin" + str(l), lin) 174 | 175 | self.relu = nn.ReLU() 176 | 177 | def forward(self, points, normals, view_dirs, feature_vectors): 178 | if self.embedview_fn is not None: 179 | view_dirs = self.embedview_fn(view_dirs) 180 | 181 | rendering_input = None 182 | normals = normals.detach() 183 | if self.mode == "idr": 184 | rendering_input = torch.cat( 185 | [points, view_dirs, normals, -1 * normals, feature_vectors], dim=-1 186 | ) 187 | elif self.mode == "no_view_dir": 188 | rendering_input = torch.cat( 189 | [points, normals, -1 * normals, feature_vectors], dim=-1 190 | ) 191 | elif self.mode == "no_normal": 192 | rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1) 193 | 194 | x = rendering_input 195 | 196 | for l in range(0, self.num_layers - 1): 197 | lin = getattr(self, "lin" + str(l)) 198 | 199 | x = lin(x) 200 | 201 | if l < self.num_layers - 2: 202 | x = self.relu(x) 203 | 204 | if self.squeeze_out: 205 | color = torch.sigmoid(x[:, : self.d_out]) 206 | else: 207 | color = x[:, : self.d_out] 208 | 209 | return color 210 | 211 | 212 | class SingleVarianceNetwork(nn.Module): 213 | def __init__(self, init_val, requires_grad=True): 214 | super(SingleVarianceNetwork, self).__init__() 215 | self.variance = nn.Parameter( 216 | torch.Tensor([init_val]), requires_grad=requires_grad 217 | ) 218 | self.second_variance = nn.Parameter( 219 | torch.Tensor([init_val]), requires_grad=requires_grad 220 | ) 221 | 222 | def set_trainable(self): 223 | self.variance.requires_grad = True 224 | self.second_variance.requires_grad = True 225 | 226 | def forward(self, x): 227 | return torch.ones([len(x), 1]).to(x.device) * torch.exp(self.variance * 10.0) 228 | 229 | def get_secondvariance(self, x): 230 | return torch.ones([len(x), 1]).to(x.device) * torch.exp( 231 | self.second_variance * 10.0 232 | ) 233 | 234 | 235 | class BetaNetwork(nn.Module): 236 | def __init__( 237 | self, 238 | init_var_beta=0.1, 239 | init_var_gamma=0.1, 240 | init_var_zeta=0.05, 241 | beta_min=0.00005, 242 | requires_grad_beta=True, 243 | requires_grad_gamma=True, 244 | requires_grad_zeta=True, 245 | ): 246 | super().__init__() 247 | 248 | self.beta = nn.Parameter( 249 | torch.Tensor([init_var_beta]), requires_grad=requires_grad_beta 250 | ) 251 | self.gamma = nn.Parameter( 252 | torch.Tensor([init_var_gamma]), requires_grad=requires_grad_gamma 253 | ) 254 | self.zeta = nn.Parameter( 255 | torch.Tensor([init_var_zeta]), requires_grad=requires_grad_zeta 256 | ) 257 | self.beta_min = beta_min 258 | 259 | def get_beta(self): 260 | return torch.exp(self.beta * 10).clip(0, 1.0 / self.beta_min) 261 | 262 | def get_gamma(self): 263 | return torch.exp(self.gamma * 10) 264 | 265 | def get_zeta(self): 266 | """ 267 | used for udf2prob mapping zeta*x/(1+zeta*x) 268 | :return: 269 | :rtype: 270 | """ 271 | return self.zeta.abs() 272 | 273 | def set_beta_trainable(self): 274 | self.beta.requires_grad = True 275 | 276 | @torch.no_grad() 277 | def set_gamma(self, x): 278 | self.gamma = nn.Parameter( 279 | torch.Tensor([x]), requires_grad=self.gamma.requires_grad 280 | ).to(self.gamma.device) 281 | 282 | def forward(self): 283 | beta = self.get_beta() 284 | gamma = self.get_gamma() 285 | zeta = self.get_zeta() 286 | return beta, gamma, zeta 287 | -------------------------------------------------------------------------------- /src/runner/runner_base.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import os 4 | import numpy as np 5 | from shutil import copyfile 6 | from icecream import ic 7 | from pyhocon import ConfigFactory, HOCONConverter 8 | from src.dataset.dataset import Dataset 9 | from src.models.udf_model import ( 10 | UDFNetwork, 11 | BetaNetwork, 12 | SingleVarianceNetwork, 13 | ) 14 | from src.models.udf_renderer_blending import UDFRendererBlending 15 | from src.models.loss import EdgeLoss 16 | 17 | 18 | class Runner: 19 | def __init__( 20 | self, 21 | conf, 22 | mode="train", 23 | is_continue=False, 24 | args=None, 25 | ): 26 | # Initial setting 27 | self.device = torch.device("cuda") 28 | self.conf = conf 29 | 30 | self.base_exp_dir = os.path.join( 31 | self.conf["general.base_exp_dir"], 32 | str(self.conf["dataset"]["scan"]), 33 | self.conf["general.expname"], 34 | ) 35 | os.makedirs(self.base_exp_dir, exist_ok=True) 36 | 37 | self.dataset = Dataset(self.conf["dataset"]) 38 | self.near, self.far = self.dataset.near, self.dataset.far 39 | 40 | self.iter_step = 0 41 | 42 | # trainning parameters 43 | self.end_iter = self.conf.get_int("train.end_iter") 44 | self.save_freq = self.conf.get_int("train.save_freq") 45 | self.report_freq = self.conf.get_int("train.report_freq") 46 | self.val_freq = self.conf.get_int("train.val_freq") 47 | self.batch_size = self.conf.get_int("train.batch_size") 48 | self.validate_resolution_level = self.conf.get_int( 49 | "train.validate_resolution_level" 50 | ) 51 | self.use_white_bkgd = self.conf.get_bool("train.use_white_bkgd") 52 | self.importance_sample = self.conf.get_bool("train.importance_sample") 53 | 54 | # setting about learning rate schedule 55 | self.learning_rate = self.conf.get_float("train.learning_rate") 56 | self.learning_rate_geo = self.conf.get_float("train.learning_rate_geo") 57 | self.learning_rate_alpha = self.conf.get_float("train.learning_rate_alpha") 58 | self.warm_up_end = self.conf.get_float("train.warm_up_end", default=0.0) 59 | self.anneal_end = self.conf.get_float("train.anneal_end", default=0.0) 60 | # don't train the udf network in the early steps 61 | self.fix_geo_end = self.conf.get_float("train.fix_geo_end", default=200) 62 | self.warmup_sample = self.conf.get_bool( 63 | "train.warmup_sample", default=False 64 | ) # * training schedule 65 | # whether the udf network and appearance network share the same learning rate 66 | self.same_lr = self.conf.get_bool("train.same_lr", default=False) 67 | 68 | # weights 69 | self.igr_weight = self.conf.get_float("train.igr_weight") 70 | self.igr_ns_weight = self.conf.get_float("train.igr_ns_weight", default=0.0) 71 | 72 | # loss functions 73 | self.edge_loss_func = EdgeLoss(self.conf["edge_loss"]["loss_type"]) 74 | self.edge_weight = self.conf.get_float("edge_loss.edge_weight", 0.0) 75 | self.is_continue = is_continue 76 | # self.is_finetune = args.is_finetune 77 | 78 | self.mode = mode 79 | self.model_type = self.conf["general.model_type"] 80 | self.model_list = [] 81 | self.writer = None 82 | 83 | # Networks 84 | params_to_train = [] 85 | params_to_train_nerf = [] 86 | params_to_train_geo = [] 87 | 88 | self.nerf_outside = None 89 | self.nerf_coarse = None 90 | self.nerf_fine = None 91 | self.sdf_network_fine = None 92 | self.udf_network_fine = None 93 | self.variance_network_fine = None 94 | 95 | # self.nerf_outside = NeRF(**self.conf["model.nerf"]).to(self.device) 96 | self.udf_network_fine = UDFNetwork(**self.conf["model.udf_network"]).to( 97 | self.device 98 | ) 99 | self.variance_network_fine = SingleVarianceNetwork( 100 | **self.conf["model.variance_network"] 101 | ).to(self.device) 102 | self.beta_network = BetaNetwork(**self.conf["model.beta_network"]).to( 103 | self.device 104 | ) 105 | # params_to_train_nerf += list(self.nerf_outside.parameters()) 106 | params_to_train_geo += list(self.udf_network_fine.parameters()) 107 | params_to_train += list(self.variance_network_fine.parameters()) 108 | params_to_train += list(self.beta_network.parameters()) 109 | 110 | self.optimizer = torch.optim.Adam( 111 | [ 112 | {"params": params_to_train_geo, "lr": self.learning_rate_geo}, 113 | {"params": params_to_train}, 114 | {"params": params_to_train_nerf}, 115 | ], 116 | lr=self.learning_rate, 117 | ) 118 | 119 | self.renderer = UDFRendererBlending( 120 | self.nerf_outside, 121 | self.udf_network_fine, 122 | self.variance_network_fine, 123 | self.beta_network, 124 | device=self.device, 125 | **self.conf["model.udf_renderer"], 126 | ) 127 | 128 | def update_learning_rate(self, start_g_id=0): 129 | if self.iter_step < self.warm_up_end: 130 | learning_factor = self.iter_step / self.warm_up_end 131 | else: 132 | alpha = self.learning_rate_alpha 133 | progress = (self.iter_step - self.warm_up_end) / ( 134 | self.end_iter - self.warm_up_end 135 | ) 136 | learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * ( 137 | 1 - alpha 138 | ) + alpha 139 | 140 | for g in self.optimizer.param_groups[start_g_id:]: 141 | g["lr"] = self.learning_rate * learning_factor 142 | 143 | def update_learning_rate_geo(self): 144 | if self.iter_step < self.fix_geo_end: # * make bg nerf learn first 145 | learning_factor = 0.0 146 | elif self.iter_step < self.warm_up_end * 2: 147 | learning_factor = self.iter_step / (self.warm_up_end * 2) 148 | elif self.iter_step < self.end_iter * 0.5: 149 | learning_factor = 1.0 150 | else: 151 | alpha = self.learning_rate_alpha 152 | progress = (self.iter_step - self.end_iter * 0.5) / ( 153 | self.end_iter - self.end_iter * 0.5 154 | ) 155 | learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * ( 156 | 1 - alpha 157 | ) + alpha 158 | 159 | for g in self.optimizer.param_groups[:1]: 160 | g["lr"] = self.learning_rate_geo * learning_factor 161 | 162 | def get_cos_anneal_ratio(self): 163 | if self.anneal_end == 0.0: 164 | return 1.0 165 | else: 166 | return np.min([1.0, self.iter_step / self.anneal_end]) 167 | 168 | def train(self): 169 | self.train_udf() 170 | 171 | def get_flip_saturation(self, flip_saturation_max=0.9): 172 | start = 10000 173 | if self.iter_step < start: 174 | flip_saturation = 0.0 175 | elif self.iter_step < self.end_iter * 0.5: 176 | flip_saturation = flip_saturation_max 177 | else: 178 | flip_saturation = 1.0 179 | 180 | return flip_saturation 181 | 182 | def file_backup(self): 183 | # copy python file 184 | dir_lis = self.conf["general.recording"] 185 | cur_dir = os.path.join(self.base_exp_dir, "recording") 186 | os.makedirs(cur_dir, exist_ok=True) 187 | files = os.listdir("./") 188 | for f_name in files: 189 | if f_name[-3:] == ".py": 190 | copyfile(os.path.join("./", f_name), os.path.join(cur_dir, f_name)) 191 | 192 | for dir_name in dir_lis: 193 | os.system(f"cp -r {dir_name} {cur_dir}") 194 | 195 | # copy configs 196 | # copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf')) 197 | with open( 198 | os.path.join(self.base_exp_dir, "recording", "config.conf"), "w" 199 | ) as fd: 200 | res = HOCONConverter.to_hocon(self.conf) 201 | fd.write(res) 202 | 203 | def train_udf(self): 204 | return NotImplementedError 205 | 206 | def load_checkpoint(self, checkpoint_name): 207 | return NotImplementedError 208 | 209 | def save_checkpoint(self): 210 | return NotImplementedError 211 | 212 | def validate(self, idx=-1, resolution_level=-1): 213 | return NotImplementedError 214 | -------------------------------------------------------------------------------- /src/utils/__init__.py: -------------------------------------------------------------------------------- 1 | import torch 2 | 3 | # optimizer 4 | from torch.optim import SGD, Adam 5 | import torch_optimizer as optim 6 | 7 | # scheduler 8 | from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR 9 | from .warmup_scheduler import GradualWarmupScheduler 10 | 11 | from .visualization import * 12 | from .plots import * 13 | 14 | 15 | def get_parameters(models): 16 | """Get all model parameters recursively.""" 17 | parameters = [] 18 | if isinstance(models, list): 19 | for model in models: 20 | parameters += get_parameters(model) 21 | elif isinstance(models, dict): 22 | for model in models.values(): 23 | parameters += get_parameters(model) 24 | else: # models is actually a single pytorch model 25 | parameters += list(models.parameters()) 26 | return parameters 27 | 28 | 29 | def get_optimizer(hparams, models): 30 | eps = 1e-8 31 | parameters = get_parameters(models) 32 | if hparams.optimizer == "sgd": 33 | optimizer = SGD( 34 | parameters, 35 | lr=hparams.lr, 36 | momentum=hparams.momentum, 37 | weight_decay=hparams.weight_decay, 38 | ) 39 | elif hparams.optimizer == "adam": 40 | optimizer = Adam( 41 | parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay 42 | ) 43 | elif hparams.optimizer == "radam": 44 | optimizer = optim.RAdam( 45 | parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay 46 | ) 47 | elif hparams.optimizer == "ranger": 48 | optimizer = optim.Ranger( 49 | parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay 50 | ) 51 | else: 52 | raise ValueError("optimizer not recognized!") 53 | 54 | return optimizer 55 | 56 | 57 | def get_scheduler(hparams, optimizer): 58 | eps = 1e-8 59 | if hparams.lr_scheduler == "steplr": 60 | scheduler = MultiStepLR( 61 | optimizer, milestones=hparams.decay_step, gamma=hparams.decay_gamma 62 | ) 63 | elif hparams.lr_scheduler == "cosine": 64 | scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps) 65 | elif hparams.lr_scheduler == "poly": 66 | scheduler = LambdaLR( 67 | optimizer, 68 | lambda epoch: (1 - epoch / hparams.num_epochs) ** hparams.poly_exp, 69 | ) 70 | else: 71 | raise ValueError("scheduler not recognized!") 72 | 73 | if hparams.warmup_epochs > 0 and hparams.optimizer not in ["radam", "ranger"]: 74 | scheduler = GradualWarmupScheduler( 75 | optimizer, 76 | multiplier=hparams.warmup_multiplier, 77 | total_epoch=hparams.warmup_epochs, 78 | after_scheduler=scheduler, 79 | ) 80 | 81 | return scheduler 82 | 83 | 84 | def get_learning_rate(optimizer): 85 | for param_group in optimizer.param_groups: 86 | return param_group["lr"] 87 | 88 | 89 | def extract_model_state_dict(ckpt_path, model_name="model", prefixes_to_ignore=[]): 90 | checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu")) 91 | checkpoint_ = {} 92 | if "state_dict" in checkpoint: # if it's a pytorch-lightning checkpoint 93 | checkpoint = checkpoint["state_dict"] 94 | for k, v in checkpoint.items(): 95 | if not k.startswith(model_name): 96 | continue 97 | k = k[len(model_name) + 1 :] 98 | for prefix in prefixes_to_ignore: 99 | if k.startswith(prefix): 100 | print("ignore", k) 101 | break 102 | else: 103 | checkpoint_[k] = v 104 | return checkpoint_ 105 | 106 | 107 | def load_ckpt(model, ckpt_path, model_name="model", prefixes_to_ignore=[]): 108 | if not ckpt_path: 109 | return 110 | model_dict = model.state_dict() 111 | checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore) 112 | model_dict.update(checkpoint_) 113 | model.load_state_dict(model_dict) 114 | -------------------------------------------------------------------------------- /src/utils/math.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """ Math Helper Functions """ 16 | 17 | from dataclasses import dataclass 18 | 19 | import torch 20 | from torchtyping import TensorType 21 | 22 | 23 | def components_from_spherical_harmonics( 24 | levels: int, directions: TensorType[..., 3] 25 | ) -> TensorType[..., "components"]: 26 | """ 27 | Returns value for each component of spherical harmonics. 28 | 29 | Args: 30 | levels: Number of spherical harmonic levels to compute. 31 | directions: Spherical hamonic coefficients 32 | """ 33 | num_components = levels**2 34 | components = torch.zeros( 35 | (*directions.shape[:-1], num_components), device=directions.device 36 | ) 37 | 38 | assert 1 <= levels <= 5, f"SH levels must be in [1,4], got {levels}" 39 | assert ( 40 | directions.shape[-1] == 3 41 | ), f"Direction input should have three dimensions. Got {directions.shape[-1]}" 42 | 43 | x = directions[..., 0] 44 | y = directions[..., 1] 45 | z = directions[..., 2] 46 | 47 | xx = x**2 48 | yy = y**2 49 | zz = z**2 50 | 51 | # l0 52 | components[..., 0] = 0.28209479177387814 53 | 54 | # l1 55 | if levels > 1: 56 | components[..., 1] = 0.4886025119029199 * y 57 | components[..., 2] = 0.4886025119029199 * z 58 | components[..., 3] = 0.4886025119029199 * x 59 | 60 | # l2 61 | if levels > 2: 62 | components[..., 4] = 1.0925484305920792 * x * y 63 | components[..., 5] = 1.0925484305920792 * y * z 64 | components[..., 6] = 0.9461746957575601 * zz - 0.31539156525251999 65 | components[..., 7] = 1.0925484305920792 * x * z 66 | components[..., 8] = 0.5462742152960396 * (xx - yy) 67 | 68 | # l3 69 | if levels > 3: 70 | components[..., 9] = 0.5900435899266435 * y * (3 * xx - yy) 71 | components[..., 10] = 2.890611442640554 * x * y * z 72 | components[..., 11] = 0.4570457994644658 * y * (5 * zz - 1) 73 | components[..., 12] = 0.3731763325901154 * z * (5 * zz - 3) 74 | components[..., 13] = 0.4570457994644658 * x * (5 * zz - 1) 75 | components[..., 14] = 1.445305721320277 * z * (xx - yy) 76 | components[..., 15] = 0.5900435899266435 * x * (xx - 3 * yy) 77 | 78 | # l4 79 | if levels > 4: 80 | components[..., 16] = 2.5033429417967046 * x * y * (xx - yy) 81 | components[..., 17] = 1.7701307697799304 * y * z * (3 * xx - yy) 82 | components[..., 18] = 0.9461746957575601 * x * y * (7 * zz - 1) 83 | components[..., 19] = 0.6690465435572892 * y * (7 * zz - 3) 84 | components[..., 20] = 0.10578554691520431 * (35 * zz * zz - 30 * zz + 3) 85 | components[..., 21] = 0.6690465435572892 * x * z * (7 * zz - 3) 86 | components[..., 22] = 0.47308734787878004 * (xx - yy) * (7 * zz - 1) 87 | components[..., 23] = 1.7701307697799304 * x * z * (xx - 3 * yy) 88 | components[..., 24] = 0.4425326924449826 * ( 89 | xx * (xx - 3 * yy) - yy * (3 * xx - yy) 90 | ) 91 | 92 | return components 93 | 94 | 95 | @dataclass 96 | class Gaussians: 97 | """Stores Gaussians 98 | 99 | Args: 100 | mean: Mean of multivariate Gaussian 101 | cov: Covariance of multivariate Gaussian. 102 | """ 103 | 104 | mean: TensorType[..., "dim"] 105 | cov: TensorType[..., "dim", "dim"] 106 | 107 | 108 | def compute_3d_gaussian( 109 | directions: TensorType[..., 3], 110 | means: TensorType[..., 3], 111 | dir_variance: TensorType[..., 1], 112 | radius_variance: TensorType[..., 1], 113 | ) -> Gaussians: 114 | """Compute guassian along ray. 115 | 116 | Args: 117 | directions: Axis of Gaussian. 118 | means: Mean of Gaussian. 119 | dir_variance: Variance along direction axis. 120 | radius_variance: Variance tangent to direction axis. 121 | 122 | Returns: 123 | Gaussians: Oriented 3D gaussian. 124 | """ 125 | 126 | dir_outer_product = directions[..., :, None] * directions[..., None, :] 127 | eye = torch.eye(directions.shape[-1], device=directions.device) 128 | dir_mag_sq = torch.clamp( 129 | torch.sum(directions**2, dim=-1, keepdim=True), min=1e-10 130 | ) 131 | null_outer_product = ( 132 | eye - directions[..., :, None] * (directions / dir_mag_sq)[..., None, :] 133 | ) 134 | dir_cov_diag = dir_variance[..., None] * dir_outer_product[..., :, :] 135 | radius_cov_diag = radius_variance[..., None] * null_outer_product[..., :, :] 136 | cov = dir_cov_diag + radius_cov_diag 137 | return Gaussians(mean=means, cov=cov) 138 | 139 | 140 | def cylinder_to_gaussian( 141 | origins: TensorType[..., 3], 142 | directions: TensorType[..., 3], 143 | starts: TensorType[..., 1], 144 | ends: TensorType[..., 1], 145 | radius: TensorType[..., 1], 146 | ) -> Gaussians: 147 | """Approximates cylinders with a Gaussian distributions. 148 | 149 | Args: 150 | origins: Origins of cylinders. 151 | directions: Direction (axis) of cylinders. 152 | starts: Start of cylinders. 153 | ends: End of cylinders. 154 | radius: Radii of cylinders. 155 | 156 | Returns: 157 | Gaussians: Approximation of cylinders 158 | """ 159 | means = origins + directions * ((starts + ends) / 2.0) 160 | dir_variance = (ends - starts) ** 2 / 12 161 | radius_variance = radius**2 / 4.0 162 | return compute_3d_gaussian(directions, means, dir_variance, radius_variance) 163 | 164 | 165 | def conical_frustum_to_gaussian( 166 | origins: TensorType[..., 3], 167 | directions: TensorType[..., 3], 168 | starts: TensorType[..., 1], 169 | ends: TensorType[..., 1], 170 | radius: TensorType[..., 1], 171 | ) -> Gaussians: 172 | """Approximates conical frustums with a Gaussian distributions. 173 | 174 | Uses stable parameterization described in mip-NeRF publication. 175 | 176 | Args: 177 | origins: Origins of cones. 178 | directions: Direction (axis) of frustums. 179 | starts: Start of conical frustums. 180 | ends: End of conical frustums. 181 | radius: Radii of cone a distance of 1 from the origin. 182 | 183 | Returns: 184 | Gaussians: Approximation of conical frustums 185 | """ 186 | mu = (starts + ends) / 2.0 187 | hw = (ends - starts) / 2.0 188 | means = origins + directions * ( 189 | mu + (2.0 * mu * hw**2.0) / (3.0 * mu**2.0 + hw**2.0) 190 | ) 191 | dir_variance = (hw**2) / 3 - (4 / 15) * ( 192 | (hw**4 * (12 * mu**2 - hw**2)) / (3 * mu**2 + hw**2) ** 2 193 | ) 194 | radius_variance = radius**2 * ( 195 | (mu**2) / 4 196 | + (5 / 12) * hw**2 197 | - 4 / 15 * (hw**4) / (3 * mu**2 + hw**2) 198 | ) 199 | return compute_3d_gaussian(directions, means, dir_variance, radius_variance) 200 | 201 | 202 | def expected_sin(x_means: torch.Tensor, x_vars: torch.Tensor) -> torch.Tensor: 203 | """Computes the expected value of sin(y) where y ~ N(x_means, x_vars) 204 | 205 | Args: 206 | x_means: Mean values. 207 | x_vars: Variance of values. 208 | 209 | Returns: 210 | torch.Tensor: The expected value of sin. 211 | """ 212 | 213 | return torch.exp(-0.5 * x_vars) * torch.sin(x_means) 214 | -------------------------------------------------------------------------------- /src/utils/rend_util.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import imageio 3 | import skimage 4 | import cv2 5 | import torch 6 | from torch.nn import functional as F 7 | 8 | 9 | def get_psnr(img1, img2, normalize_rgb=False): 10 | if normalize_rgb: # [-1,1] --> [0,1] 11 | img1 = (img1 + 1.0) / 2.0 12 | img2 = (img2 + 1.0) / 2.0 13 | 14 | mse = torch.mean((img1 - img2) ** 2) 15 | psnr = -10.0 * torch.log(mse) / torch.log(torch.Tensor([10.0]).cuda()) 16 | 17 | return psnr 18 | 19 | 20 | def load_rgb(path, normalize_rgb=False): 21 | img = imageio.imread(path) 22 | img = skimage.img_as_float32(img) 23 | 24 | if normalize_rgb: # [-1,1] --> [0,1] 25 | img -= 0.5 26 | img *= 2.0 27 | img = img.transpose(2, 0, 1) 28 | return img 29 | 30 | 31 | def load_K_Rt_from_P(filename, P=None): 32 | if P is None: 33 | lines = open(filename).read().splitlines() 34 | if len(lines) == 4: 35 | lines = lines[1:] 36 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)] 37 | P = np.asarray(lines).astype(np.float32).squeeze() 38 | 39 | out = cv2.decomposeProjectionMatrix(P) 40 | K = out[0] 41 | R = out[1] 42 | t = out[2] 43 | 44 | K = K / K[2, 2] 45 | intrinsics = np.eye(4) 46 | intrinsics[:3, :3] = K 47 | 48 | pose = np.eye(4, dtype=np.float32) 49 | pose[:3, :3] = R.transpose() 50 | pose[:3, 3] = (t[:3] / t[3])[:, 0] 51 | 52 | return intrinsics, pose 53 | 54 | 55 | def get_camera_params(uv, pose, intrinsics): 56 | if pose.shape[1] == 7: # In case of quaternion vector representation 57 | cam_loc = pose[:, 4:] 58 | R = quat_to_rot(pose[:, :4]) 59 | p = torch.eye(4).repeat(pose.shape[0], 1, 1).cuda().float() 60 | p[:, :3, :3] = R 61 | p[:, :3, 3] = cam_loc 62 | else: # In case of pose matrix representation 63 | cam_loc = pose[:, :3, 3] 64 | p = pose 65 | 66 | batch_size, num_samples, _ = uv.shape 67 | 68 | depth = torch.ones((batch_size, num_samples)).cuda() 69 | x_cam = uv[:, :, 0].view(batch_size, -1) 70 | y_cam = uv[:, :, 1].view(batch_size, -1) 71 | z_cam = depth.view(batch_size, -1) 72 | 73 | pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics) 74 | 75 | # permute for batch matrix product 76 | pixel_points_cam = pixel_points_cam.permute(0, 2, 1) 77 | 78 | world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3] 79 | ray_dirs = world_coords - cam_loc[:, None, :] 80 | ray_dirs = F.normalize(ray_dirs, dim=2) 81 | 82 | return ray_dirs, cam_loc 83 | 84 | 85 | def get_camera_for_plot(pose): 86 | if pose.shape[1] == 7: # In case of quaternion vector representation 87 | cam_loc = pose[:, 4:].detach() 88 | R = quat_to_rot(pose[:, :4].detach()) 89 | else: # In case of pose matrix representation 90 | cam_loc = pose[:, :3, 3] 91 | R = pose[:, :3, :3] 92 | cam_dir = R[:, :3, 2] 93 | return cam_loc, cam_dir 94 | 95 | 96 | def lift(x, y, z, intrinsics): 97 | # parse intrinsics 98 | intrinsics = intrinsics.cuda() 99 | fx = intrinsics[:, 0, 0] 100 | fy = intrinsics[:, 1, 1] 101 | cx = intrinsics[:, 0, 2] 102 | cy = intrinsics[:, 1, 2] 103 | sk = intrinsics[:, 0, 1] 104 | 105 | x_lift = ( 106 | ( 107 | x 108 | - cx.unsqueeze(-1) 109 | + cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1) 110 | - sk.unsqueeze(-1) * y / fy.unsqueeze(-1) 111 | ) 112 | / fx.unsqueeze(-1) 113 | * z 114 | ) 115 | y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z 116 | 117 | # homogeneous 118 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1) 119 | 120 | 121 | def quat_to_rot(q): 122 | batch_size, _ = q.shape 123 | q = F.normalize(q, dim=1) 124 | R = torch.ones((batch_size, 3, 3)).cuda() 125 | qr = q[:, 0] 126 | qi = q[:, 1] 127 | qj = q[:, 2] 128 | qk = q[:, 3] 129 | R[:, 0, 0] = 1 - 2 * (qj**2 + qk**2) 130 | R[:, 0, 1] = 2 * (qj * qi - qk * qr) 131 | R[:, 0, 2] = 2 * (qi * qk + qr * qj) 132 | R[:, 1, 0] = 2 * (qj * qi + qk * qr) 133 | R[:, 1, 1] = 1 - 2 * (qi**2 + qk**2) 134 | R[:, 1, 2] = 2 * (qj * qk - qi * qr) 135 | R[:, 2, 0] = 2 * (qk * qi - qj * qr) 136 | R[:, 2, 1] = 2 * (qj * qk + qi * qr) 137 | R[:, 2, 2] = 1 - 2 * (qi**2 + qj**2) 138 | return R 139 | 140 | 141 | def rot_to_quat(R): 142 | batch_size, _, _ = R.shape 143 | q = torch.ones((batch_size, 4)).cuda() 144 | 145 | R00 = R[:, 0, 0] 146 | R01 = R[:, 0, 1] 147 | R02 = R[:, 0, 2] 148 | R10 = R[:, 1, 0] 149 | R11 = R[:, 1, 1] 150 | R12 = R[:, 1, 2] 151 | R20 = R[:, 2, 0] 152 | R21 = R[:, 2, 1] 153 | R22 = R[:, 2, 2] 154 | 155 | q[:, 0] = torch.sqrt(1.0 + R00 + R11 + R22) / 2 156 | q[:, 1] = (R21 - R12) / (4 * q[:, 0]) 157 | q[:, 2] = (R02 - R20) / (4 * q[:, 0]) 158 | q[:, 3] = (R10 - R01) / (4 * q[:, 0]) 159 | return q 160 | 161 | 162 | def get_sphere_intersections(cam_loc, ray_directions, r=1.0): 163 | # Input: n_rays x 3 ; n_rays x 3 164 | # Output: n_rays x 1, n_rays x 1 (close and far) 165 | 166 | ray_cam_dot = torch.bmm( 167 | ray_directions.view(-1, 1, 3), cam_loc.view(-1, 3, 1) 168 | ).squeeze(-1) 169 | under_sqrt = ray_cam_dot**2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r**2) 170 | 171 | # sanity check 172 | if (under_sqrt <= 0).sum() > 0: 173 | print("BOUNDING SPHERE PROBLEM!") 174 | exit() 175 | 176 | sphere_intersections = ( 177 | torch.sqrt(under_sqrt) * torch.Tensor([-1, 1]).cuda().float() - ray_cam_dot 178 | ) 179 | sphere_intersections = sphere_intersections.clamp_min(0.0) 180 | 181 | return sphere_intersections 182 | -------------------------------------------------------------------------------- /src/utils/tensor_dataclass.py: -------------------------------------------------------------------------------- 1 | # Copyright 2022 The Nerfstudio Team. All rights reserved. 2 | # 3 | # Licensed under the Apache License, Version 2.0 (the "License"); 4 | # you may not use this file except in compliance with the License. 5 | # You may obtain a copy of the License at 6 | # 7 | # http://www.apache.org/licenses/LICENSE-2.0 8 | # 9 | # Unless required by applicable law or agreed to in writing, software 10 | # distributed under the License is distributed on an "AS IS" BASIS, 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 | # See the License for the specific language governing permissions and 13 | # limitations under the License. 14 | 15 | """Tensor dataclass""" 16 | 17 | import dataclasses 18 | from copy import deepcopy 19 | from typing import Callable, Dict, List, NoReturn, Optional, Tuple, TypeVar, Union 20 | 21 | import numpy as np 22 | import torch 23 | 24 | TensorDataclassT = TypeVar("TensorDataclassT", bound="TensorDataclass") 25 | 26 | 27 | class TensorDataclass: 28 | """@dataclass of tensors with the same size batch. Allows indexing and standard tensor ops. 29 | Fields that are not Tensors will not be batched unless they are also a TensorDataclass. 30 | Any fields that are dictionaries will have their Tensors or TensorDataclasses batched, and 31 | dictionaries will have their tensors or TensorDataclasses considered in the initial broadcast. 32 | Tensor fields must have at least 1 dimension, meaning that you must convert a field like torch.Tensor(1) 33 | to torch.Tensor([1]) 34 | 35 | Example: 36 | 37 | .. code-block:: python 38 | 39 | @dataclass 40 | class TestTensorDataclass(TensorDataclass): 41 | a: torch.Tensor 42 | b: torch.Tensor 43 | c: torch.Tensor = None 44 | 45 | # Create a new tensor dataclass with batch size of [2,3,4] 46 | test = TestTensorDataclass(a=torch.ones((2, 3, 4, 2)), b=torch.ones((4, 3))) 47 | 48 | test.shape # [2, 3, 4] 49 | test.a.shape # [2, 3, 4, 2] 50 | test.b.shape # [2, 3, 4, 3] 51 | 52 | test.reshape((6,4)).shape # [6, 4] 53 | test.flatten().shape # [24,] 54 | 55 | test[..., 0].shape # [2, 3] 56 | test[:, 0, :].shape # [2, 4] 57 | """ 58 | 59 | _shape: tuple 60 | 61 | # A mapping from field-name (str): n (int) 62 | # Any field OR any key in a dictionary field with this name (field-name) and a corresponding 63 | # torch.Tensor will be assumed to have n dimensions after the batch dims. These n final dimensions 64 | # will remain the same shape when doing reshapes, broadcasting, etc on the tensordataclass 65 | _field_custom_dimensions: Dict[str, int] = {} 66 | 67 | def __post_init__(self) -> None: 68 | """Finishes setting up the TensorDataclass 69 | 70 | This will 1) find the broadcasted shape and 2) broadcast all fields to this shape 3) 71 | set _shape to be the broadcasted shape. 72 | """ 73 | if self._field_custom_dimensions is not None: 74 | for k, v in self._field_custom_dimensions.items(): 75 | assert ( 76 | isinstance(v, int) and v > 1 77 | ), f"Custom dimensions must be an integer greater than 1, since 1 is the default, received {k}: {v}" 78 | 79 | if not dataclasses.is_dataclass(self): 80 | raise TypeError("TensorDataclass must be a dataclass") 81 | 82 | batch_shapes = self._get_dict_batch_shapes( 83 | {f.name: self.__getattribute__(f.name) for f in dataclasses.fields(self)} 84 | ) 85 | if len(batch_shapes) == 0: 86 | raise ValueError("TensorDataclass must have at least one tensor") 87 | batch_shape = torch.broadcast_shapes(*batch_shapes) 88 | 89 | broadcasted_fields = self._broadcast_dict_fields( 90 | {f.name: self.__getattribute__(f.name) for f in dataclasses.fields(self)}, 91 | batch_shape, 92 | ) 93 | for f, v in broadcasted_fields.items(): 94 | self.__setattr__(f, v) 95 | 96 | self.__setattr__("_shape", batch_shape) 97 | 98 | def _get_dict_batch_shapes(self, dict_: Dict) -> List: 99 | """Returns batch shapes of all tensors in a dictionary 100 | 101 | Args: 102 | dict_: The dictionary to get the batch shapes of. 103 | 104 | Returns: 105 | The batch shapes of all tensors in the dictionary. 106 | """ 107 | batch_shapes = [] 108 | for k, v in dict_.items(): 109 | if isinstance(v, torch.Tensor): 110 | if ( 111 | isinstance(self._field_custom_dimensions, dict) 112 | and k in self._field_custom_dimensions 113 | ): 114 | # pylint: disable=unsubscriptable-object 115 | batch_shapes.append(v.shape[: -self._field_custom_dimensions[k]]) 116 | else: 117 | batch_shapes.append(v.shape[:-1]) 118 | elif isinstance(v, TensorDataclass): 119 | batch_shapes.append(v.shape) 120 | elif isinstance(v, Dict): 121 | batch_shapes.extend(self._get_dict_batch_shapes(v)) 122 | return batch_shapes 123 | 124 | def _broadcast_dict_fields(self, dict_: Dict, batch_shape) -> Dict: 125 | """Broadcasts all tensors in a dictionary according to batch_shape 126 | 127 | Args: 128 | dict_: The dictionary to broadcast. 129 | 130 | Returns: 131 | The broadcasted dictionary. 132 | """ 133 | new_dict = {} 134 | for k, v in dict_.items(): 135 | if isinstance(v, torch.Tensor): 136 | # If custom dimension key, then we need to 137 | if ( 138 | isinstance(self._field_custom_dimensions, dict) 139 | and k in self._field_custom_dimensions 140 | ): 141 | # pylint: disable=unsubscriptable-object 142 | new_dict[k] = v.broadcast_to( 143 | ( 144 | *batch_shape, 145 | *v.shape[-self._field_custom_dimensions[k] :], 146 | ) 147 | ) 148 | else: 149 | new_dict[k] = v.broadcast_to((*batch_shape, v.shape[-1])) 150 | elif isinstance(v, TensorDataclass): 151 | new_dict[k] = v.broadcast_to(batch_shape) 152 | elif isinstance(v, Dict): 153 | new_dict[k] = self._broadcast_dict_fields(v, batch_shape) 154 | return new_dict 155 | 156 | def __getitem__(self: TensorDataclassT, indices) -> TensorDataclassT: 157 | if isinstance(indices, (torch.Tensor)): 158 | return self._apply_fn_to_fields(lambda x: x[indices]) 159 | if isinstance(indices, (int, slice, type(Ellipsis))): 160 | indices = (indices,) 161 | assert isinstance(indices, tuple) 162 | tensor_fn = lambda x: x[indices + (slice(None),)] 163 | dataclass_fn = lambda x: x[indices] 164 | 165 | def custom_tensor_dims_fn(k, v): 166 | custom_dims = self._field_custom_dimensions[ 167 | k 168 | ] # pylint: disable=unsubscriptable-object 169 | return v[indices + ((slice(None),) * custom_dims)] 170 | 171 | return self._apply_fn_to_fields( 172 | tensor_fn, dataclass_fn, custom_tensor_dims_fn=custom_tensor_dims_fn 173 | ) 174 | 175 | def __setitem__(self, indices, value) -> NoReturn: 176 | raise RuntimeError("Index assignment is not supported for TensorDataclass") 177 | 178 | def __len__(self) -> int: 179 | if len(self._shape) == 0: 180 | raise TypeError("len() of a 0-d tensor") 181 | return self.shape[0] 182 | 183 | def __bool__(self) -> bool: 184 | if len(self) == 0: 185 | raise ValueError( 186 | f"The truth value of {self.__class__.__name__} when `len(x) == 0` " 187 | "is ambiguous. Use `len(x)` or `x is not None`." 188 | ) 189 | return True 190 | 191 | @property 192 | def shape(self) -> Tuple[int, ...]: 193 | """Returns the batch shape of the tensor dataclass.""" 194 | return self._shape 195 | 196 | @property 197 | def size(self) -> int: 198 | """Returns the number of elements in the tensor dataclass batch dimension.""" 199 | if len(self._shape) == 0: 200 | return 1 201 | return int(np.prod(self._shape)) 202 | 203 | @property 204 | def ndim(self) -> int: 205 | """Returns the number of dimensions of the tensor dataclass.""" 206 | return len(self._shape) 207 | 208 | def reshape(self: TensorDataclassT, shape: Tuple[int, ...]) -> TensorDataclassT: 209 | """Returns a new TensorDataclass with the same data but with a new shape. 210 | 211 | This should deepcopy as well. 212 | 213 | Args: 214 | shape: The new shape of the tensor dataclass. 215 | 216 | Returns: 217 | A new TensorDataclass with the same data but with a new shape. 218 | """ 219 | if isinstance(shape, int): 220 | shape = (shape,) 221 | tensor_fn = lambda x: x.reshape((*shape, x.shape[-1])) 222 | dataclass_fn = lambda x: x.reshape(shape) 223 | 224 | def custom_tensor_dims_fn(k, v): 225 | custom_dims = self._field_custom_dimensions[ 226 | k 227 | ] # pylint: disable=unsubscriptable-object 228 | return v.reshape((*shape, *v.shape[-custom_dims:])) 229 | 230 | return self._apply_fn_to_fields( 231 | tensor_fn, dataclass_fn, custom_tensor_dims_fn=custom_tensor_dims_fn 232 | ) 233 | 234 | def flatten(self: TensorDataclassT) -> TensorDataclassT: 235 | """Returns a new TensorDataclass with flattened batch dimensions 236 | 237 | Returns: 238 | TensorDataclass: A new TensorDataclass with the same data but with a new shape. 239 | """ 240 | return self.reshape((-1,)) 241 | 242 | def broadcast_to( 243 | self: TensorDataclassT, shape: Union[torch.Size, Tuple[int, ...]] 244 | ) -> TensorDataclassT: 245 | """Returns a new TensorDataclass broadcast to new shape. 246 | 247 | Changes to the original tensor dataclass should effect the returned tensor dataclass, 248 | meaning it is NOT a deepcopy, and they are still linked. 249 | 250 | Args: 251 | shape: The new shape of the tensor dataclass. 252 | 253 | Returns: 254 | A new TensorDataclass with the same data but with a new shape. 255 | """ 256 | 257 | def custom_tensor_dims_fn(k, v): 258 | custom_dims = self._field_custom_dimensions[ 259 | k 260 | ] # pylint: disable=unsubscriptable-object 261 | return v.broadcast_to((*shape, *v.shape[-custom_dims:])) 262 | 263 | return self._apply_fn_to_fields( 264 | lambda x: x.broadcast_to((*shape, x.shape[-1])), 265 | custom_tensor_dims_fn=custom_tensor_dims_fn, 266 | ) 267 | 268 | def to(self: TensorDataclassT, device) -> TensorDataclassT: 269 | """Returns a new TensorDataclass with the same data but on the specified device. 270 | 271 | Args: 272 | device: The device to place the tensor dataclass. 273 | 274 | Returns: 275 | A new TensorDataclass with the same data but on the specified device. 276 | """ 277 | return self._apply_fn_to_fields(lambda x: x.to(device)) 278 | 279 | def _apply_fn_to_fields( 280 | self: TensorDataclassT, 281 | fn: Callable, 282 | dataclass_fn: Optional[Callable] = None, 283 | custom_tensor_dims_fn: Optional[Callable] = None, 284 | ) -> TensorDataclassT: 285 | """Applies a function to all fields of the tensor dataclass. 286 | 287 | TODO: Someone needs to make a high level design choice for whether not not we want this 288 | to apply the function to any fields in arbitray superclasses. This is an edge case until we 289 | upgrade to python 3.10 and dataclasses can actually be subclassed with vanilla python and no 290 | janking, but if people try to jank some subclasses that are grandchildren of TensorDataclass 291 | (imagine if someone tries to subclass the RayBundle) this will matter even before upgrading 292 | to 3.10 . Currently we aren't going to be able to work properly for grandchildren, but you 293 | want to use self.__dict__ if you want to apply this to grandchildren instead of our dictionary 294 | from dataclasses.fields(self) as we do below and in other places. 295 | 296 | Args: 297 | fn: The function to apply to tensor fields. 298 | dataclass_fn: The function to apply to TensorDataclass fields. 299 | 300 | Returns: 301 | A new TensorDataclass with the same data but with a new shape. 302 | """ 303 | 304 | new_fields = self._apply_fn_to_dict( 305 | {f.name: self.__getattribute__(f.name) for f in dataclasses.fields(self)}, 306 | fn, 307 | dataclass_fn, 308 | custom_tensor_dims_fn, 309 | ) 310 | 311 | return dataclasses.replace(self, **new_fields) 312 | 313 | def _apply_fn_to_dict( 314 | self, 315 | dict_: Dict, 316 | fn: Callable, 317 | dataclass_fn: Optional[Callable] = None, 318 | custom_tensor_dims_fn: Optional[Callable] = None, 319 | ) -> Dict: 320 | """A helper function for _apply_fn_to_fields, applying a function to all fields of dict_ 321 | 322 | Args: 323 | dict_: The dictionary to apply the function to. 324 | fn: The function to apply to tensor fields. 325 | dataclass_fn: The function to apply to TensorDataclass fields. 326 | 327 | Returns: 328 | A new dictionary with the same data but with a new shape. Will deep copy""" 329 | 330 | field_names = dict_.keys() 331 | new_dict = {} 332 | for f in field_names: 333 | v = dict_[f] 334 | if v is not None: 335 | if isinstance(v, TensorDataclass) and dataclass_fn is not None: 336 | new_dict[f] = dataclass_fn(v) 337 | # This is the case when we have a custom dimensions tensor 338 | elif ( 339 | isinstance(v, torch.Tensor) 340 | and isinstance(self._field_custom_dimensions, dict) 341 | and f in self._field_custom_dimensions 342 | and custom_tensor_dims_fn is not None 343 | ): 344 | new_dict[f] = custom_tensor_dims_fn(f, v) 345 | elif isinstance(v, (torch.Tensor, TensorDataclass)): 346 | new_dict[f] = fn(v) 347 | elif isinstance(v, Dict): 348 | new_dict[f] = self._apply_fn_to_dict(v, fn, dataclass_fn) 349 | else: 350 | new_dict[f] = deepcopy(v) 351 | 352 | return new_dict 353 | -------------------------------------------------------------------------------- /src/utils/visualization.py: -------------------------------------------------------------------------------- 1 | import torchvision.transforms as T 2 | import numpy as np 3 | import cv2 4 | from PIL import Image 5 | import flow_vis 6 | import torch 7 | 8 | 9 | def visualize_depth(x, cmap=cv2.COLORMAP_JET): 10 | """ 11 | depth: (H, W) 12 | """ 13 | # x = depth.cpu().numpy() 14 | x = np.nan_to_num(x) # change nan to 0 15 | mi = np.min(x) # get minimum depth 16 | ma = np.max(x) 17 | x = (x - mi) / max(ma - mi, 1e-8) # normalize to 0~1 18 | x = (255 * x).astype(np.uint8) 19 | # x_ = Image.fromarray(cv2.applyColorMap(x, cmap)) 20 | # x_ = T.ToTensor()(x_) # (3, H, W) 21 | x_ = np.array(cv2.applyColorMap(x, cmap)) 22 | return x_ 23 | 24 | 25 | def get_flow_vis(ang): 26 | # norm = line_neighborhood + 1 - np.clip(df, 0, line_neighborhood) 27 | flow_uv = 5 * np.stack([np.cos(ang), np.sin(ang)], axis=-1) 28 | flow_img = flow_vis.flow_to_color(flow_uv, convert_to_bgr=False) 29 | return flow_img 30 | 31 | 32 | COLOR_MAP_ = np.array( 33 | [ 34 | [0.9047944201469568, 0.3241718265806123, 0.33443746665210006], 35 | [0.4590171386127151, 0.9095038146383864, 0.3143840671974788], 36 | [0.4769356899795538, 0.5044406738441948, 0.5354530846360839], 37 | [0.00820945625670777, 0.24099210193126785, 0.15471834055332978], 38 | [0.6195684374237388, 0.4020380013509799, 0.26100266066404676], 39 | [0.08281237756545068, 0.05900744492710419, 0.06106221202154216], 40 | [0.2264886829978755, 0.04925271007292076, 0.10214429345996079], 41 | [0.1888247470009874, 0.11275000298612425, 0.46112894830685514], 42 | [0.37415767691880975, 0.844284596118331, 0.950471611180866], 43 | [0.3817344218157631, 0.3483259270707101, 0.6572989333690541], 44 | [0.2403115731054466, 0.03078280287279167, 0.5385975692534737], 45 | [0.7035076951650824, 0.12352084932325424, 0.12873080308790197], 46 | [0.12607434914489934, 0.111244793010015, 0.09333334699716023], 47 | [0.6551607300342269, 0.7003064103554443, 0.4131794512286162], 48 | [0.13592107365596595, 0.5390702818232149, 0.004540643174930525], 49 | [0.38286244894454347, 0.709142545393449, 0.529074791609835], 50 | [0.4279376583651734, 0.5634708596431771, 0.8505569717104301], 51 | [0.3460488523902999, 0.464769595519293, 0.6676839675477276], 52 | [0.8544063246675081, 0.5041190233407755, 0.9081217697141578], 53 | [0.9207009090747208, 0.2403865944739051, 0.05375410999863772], 54 | [0.6515786136947107, 0.6299918449948327, 0.45292029442034387], 55 | [0.986174217295693, 0.2424849846977214, 0.3981993323108266], 56 | [0.22101915872994693, 0.3408589198278038, 0.006381420347677524], 57 | [0.3159785813515982, 0.1145748921741011, 0.595754317197274], 58 | [0.10263421488052715, 0.5864139253490858, 0.23908000741142432], 59 | [0.8272999391532938, 0.6123527260897751, 0.3365197327803193], 60 | [0.5269583712937912, 0.25668929554516506, 0.7888411215078127], 61 | [0.2433880265410031, 0.7240751234287827, 0.8483215810528648], 62 | [0.7254601709704898, 0.8316525547295984, 0.9325253855921963], 63 | [0.5574483824856672, 0.2935331727879944, 0.6594839453793155], 64 | [0.6209642371433579, 0.054030693198821256, 0.5080873988178534], 65 | [0.9055507077365624, 0.12865888619203514, 0.9309191861440005], 66 | [0.9914469722960537, 0.3074114506206205, 0.8762107657323488], 67 | [0.4812682518247371, 0.15055826298548158, 0.9656340505308308], 68 | [0.6459219454316445, 0.9144794010251625, 0.751338812155106], 69 | [0.860840174209798, 0.8844626353077639, 0.3604624506769899], 70 | [0.8194991672032272, 0.926399617787601, 0.8059222327343247], 71 | [0.6540413175393658, 0.04579445254618297, 0.26891917826531275], 72 | [0.37778835833987046, 0.36247927666109536, 0.7989799305827889], 73 | [0.22738304978177726, 0.9038018263773739, 0.6970838854138303], 74 | [0.6362015495896184, 0.527680794236961, 0.5570915425178721], 75 | [0.6436401915860954, 0.6316925317144524, 0.9137151236993912], 76 | [0.04161828388587163, 0.3832413349082706, 0.6880829921949752], 77 | [0.7768167825719299, 0.8933821497682587, 0.7221278391266809], 78 | [0.8632760876301346, 0.3278628094906323, 0.8421587587114462], 79 | [0.8556499133262127, 0.6497385872901932, 0.5436895688477963], 80 | [0.9861940318610894, 0.03562313777386272, 0.9183454677106616], 81 | [0.8042586091176366, 0.6167222703170994, 0.24181981557207644], 82 | [0.9504247117633057, 0.3454233714011461, 0.6883727005547743], 83 | [0.9611909135491202, 0.46384154263898114, 0.32700443315058914], 84 | [0.523542176970206, 0.446222414615845, 0.9067402987747814], 85 | [0.7536954008682911, 0.6675512338797588, 0.22538238957839196], 86 | [0.1554052265688285, 0.05746097492966129, 0.8580358872587424], 87 | [0.8540838640971405, 0.9165504335482566, 0.6806982829158964], 88 | [0.7065090319405029, 0.8683059983962002, 0.05167128320624026], 89 | [0.39134812961899124, 0.8910075505622979, 0.7639815712623922], 90 | [0.1578117311479783, 0.20047326898284668, 0.9220177338840568], 91 | [0.2017488993096358, 0.6949259970936679, 0.8729196864798128], 92 | [0.5591089340651949, 0.15576770423813258, 0.1469857469387812], 93 | [0.14510398622626974, 0.24451497734532168, 0.46574271993578786], 94 | [0.13286397822351492, 0.4178244533944635, 0.03728728952131943], 95 | [0.556463206310225, 0.14027595183361663, 0.2731537988657907], 96 | [0.4093837966398032, 0.8015225687789814, 0.8033567296903834], 97 | [0.527442563956637, 0.902232617214431, 0.7066626674362227], 98 | [0.9058355503297827, 0.34983989180213004, 0.8353262183839384], 99 | [0.7108382186953104, 0.08591307895133471, 0.21434688012521974], 100 | [0.22757345065207668, 0.7943075496583976, 0.2992305547627421], 101 | [0.20454109788173636, 0.8251670332103687, 0.012981987094547232], 102 | [0.7672562637297392, 0.005429019973062554, 0.022163616037108702], 103 | [0.37487345910117564, 0.5086240194440863, 0.9061216063654387], 104 | [0.9878004014101087, 0.006345852772772331, 0.17499753379350858], 105 | [0.030061528704491303, 0.1409704315546606, 0.3337131835834506], 106 | [0.5022506782611504, 0.5448435505388706, 0.40584238936140726], 107 | [0.39560774627423445, 0.8905943695833262, 0.5850815030921116], 108 | [0.058615671926786406, 0.5365713844300387, 0.1620457551256279], 109 | [0.41843842882069693, 0.1536005983609976, 0.3127878501592438], 110 | [0.05947621790155899, 0.5412421167331932, 0.2611322146455659], 111 | [0.5196159938235607, 0.7066461551682705, 0.970261497412556], 112 | [0.30443031606149007, 0.45158581060034975, 0.4331841153149706], 113 | [0.8848298403933996, 0.7241791700943656, 0.8917110054596072], 114 | [0.5720260591898779, 0.3072801598203052, 0.8891066705989902], 115 | [0.13964015336177327, 0.2531778096760302, 0.5703756837403124], 116 | [0.2156307542329836, 0.4139947500641685, 0.87051676884144], 117 | [0.10800455881891169, 0.05554646035458266, 0.2947027428551443], 118 | [0.35198009410633857, 0.365849666213808, 0.06525787683513773], 119 | [0.5223264108118847, 0.9032195574351178, 0.28579084943315025], 120 | [0.7607724246546966, 0.3087194381828555, 0.6253235528354899], 121 | [0.5060485442077824, 0.19173600467625274, 0.9931175692203702], 122 | [0.5131805830323746, 0.07719515392040577, 0.923212006754969], 123 | [0.3629762141280106, 0.02429179642710888, 0.6963754952399983], 124 | [0.7542592485456767, 0.6478893299494212, 0.3424965345400731], 125 | [0.49944574453364454, 0.6775665366832825, 0.33758796076989583], 126 | [0.010621818120767679, 0.8221571611173205, 0.5186257457566332], 127 | [0.5857910304290109, 0.7178133992025467, 0.9729243483606071], 128 | [0.16987399482717613, 0.9942570210657463, 0.18120758122552927], 129 | [0.016362572521240848, 0.17582788603087263, 0.7255176922640298], 130 | [0.10981764283706419, 0.9078582203470377, 0.7638063718334003], 131 | [0.9252097840441119, 0.3330197086990039, 0.27888705301420136], 132 | [0.12769972651171546, 0.11121470804891687, 0.12710743734391716], 133 | [0.5753520518360334, 0.2763862879599456, 0.6115636613363361], 134 | ] 135 | ) 136 | 137 | 138 | def prepare_semseg(img): 139 | if img.ndim == 3: 140 | img = img[..., 0] 141 | assert ( 142 | img.ndim == 2 143 | ), f"Expecting 2D numpy array with semseg classes, got {img.shape}" 144 | colors = COLOR_MAP_ 145 | img_color_ids = sorted(np.unique(img)) 146 | map = np.zeros(np.max(img) + 1, dtype=np.int32) 147 | map[img_color_ids] = np.arange(len(img_color_ids)) 148 | img = map[img] 149 | # assert all(0 <= c_id < len(colors) for c_id in img_color_ids) 150 | img = colors[img.astype(np.int32)] * 255 151 | return img 152 | -------------------------------------------------------------------------------- /src/utils/warmup_scheduler.py: -------------------------------------------------------------------------------- 1 | from torch.optim.lr_scheduler import _LRScheduler 2 | from torch.optim.lr_scheduler import ReduceLROnPlateau 3 | 4 | 5 | class GradualWarmupScheduler(_LRScheduler): 6 | """Gradually warm-up(increasing) learning rate in optimizer. 7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'. 8 | Args: 9 | optimizer (Optimizer): Wrapped optimizer. 10 | multiplier: target learning rate = base lr * multiplier 11 | total_epoch: target learning rate is reached at total_epoch, gradually 12 | after_scheduler: after target_epoch, use this scheduler(eg. ReduceLROnPlateau) 13 | """ 14 | 15 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None): 16 | self.multiplier = multiplier 17 | if self.multiplier < 1.0: 18 | raise ValueError("multiplier should be greater thant or equal to 1.") 19 | self.total_epoch = total_epoch 20 | self.after_scheduler = after_scheduler 21 | self.finished = False 22 | super().__init__(optimizer) 23 | 24 | def get_lr(self): 25 | if self.last_epoch > self.total_epoch: 26 | if self.after_scheduler: 27 | if not self.finished: 28 | self.after_scheduler.base_lrs = [ 29 | base_lr * self.multiplier for base_lr in self.base_lrs 30 | ] 31 | self.finished = True 32 | return self.after_scheduler.get_lr() 33 | return [base_lr * self.multiplier for base_lr in self.base_lrs] 34 | 35 | return [ 36 | base_lr 37 | * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0) 38 | for base_lr in self.base_lrs 39 | ] 40 | 41 | def step_ReduceLROnPlateau(self, metrics, epoch=None): 42 | if epoch is None: 43 | epoch = self.last_epoch + 1 44 | self.last_epoch = ( 45 | epoch if epoch != 0 else 1 46 | ) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning 47 | if self.last_epoch <= self.total_epoch: 48 | warmup_lr = [ 49 | base_lr 50 | * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0) 51 | for base_lr in self.base_lrs 52 | ] 53 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr): 54 | param_group["lr"] = lr 55 | else: 56 | if epoch is None: 57 | self.after_scheduler.step(metrics, None) 58 | else: 59 | self.after_scheduler.step(metrics, epoch - self.total_epoch) 60 | 61 | def step(self, epoch=None, metrics=None): 62 | if type(self.after_scheduler) != ReduceLROnPlateau: 63 | if self.finished and self.after_scheduler: 64 | if epoch is None: 65 | self.after_scheduler.step(None) 66 | else: 67 | self.after_scheduler.step(epoch - self.total_epoch) 68 | else: 69 | return super(GradualWarmupScheduler, self).step(epoch) 70 | else: 71 | self.step_ReduceLROnPlateau(metrics, epoch) 72 | --------------------------------------------------------------------------------