├── .gitignore
├── LICENSE.txt
├── README.md
├── confs
│   ├── ABC.conf
│   ├── DTU.conf
│   └── Replica.conf
├── main.py
├── media
│   ├── overview.jpg
│   └── replica.gif
├── requirements.txt
├── scripts
│   ├── download_data.py
│   ├── get_gt_points_DTU.py
│   ├── run_ABC.bash
│   ├── run_DTU.bash
│   └── run_Replica.bash
└── src
    ├── dataset
    │   └── dataset.py
    ├── edge_extraction
    │   ├── edge_fitting
    │   │   ├── bezier_fit.py
    │   │   ├── line_fit.py
    │   │   └── main.py
    │   ├── extract_parametric_edge.py
    │   ├── extract_pointcloud.py
    │   ├── extract_util.py
    │   └── merging
    │       └── main.py
    ├── eval
    │   ├── ABC_scans.txt
    │   ├── DTU_scans.txt
    │   ├── eval_ABC.py
    │   ├── eval_DTU.py
    │   └── eval_util.py
    ├── models
    │   ├── __init__.py
    │   ├── embedder.py
    │   ├── loss.py
    │   ├── udf_model.py
    │   └── udf_renderer_blending.py
    ├── runner
    │   ├── runner_base.py
    │   └── runner_udf.py
    └── utils
        ├── __init__.py
        ├── math.py
        ├── plots.py
        ├── rend_util.py
        ├── tensor_dataclass.py
        ├── visualization.py
        └── warmup_scheduler.py
/.gitignore:
--------------------------------------------------------------------------------
1 |
2 | .idea/*
3 | *.pyc
4 | cmake-build-*
5 | *.egg-info/
6 | *.so
7 | */build/
8 | /exp/
9 | /data/
10 |
11 |
--------------------------------------------------------------------------------
/LICENSE.txt:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2024 Lei Li
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # 3D Neural Edge Reconstruction
2 |
3 | Lei Li · Songyou Peng · Zehao Yu · Shaohui Liu · Rémi Pautrat · Xiaochuan Yin · Marc Pollefeys
4 |
5 | CVPR 2024
6 |
7 | ![EMAP overview](media/overview.jpg)
8 |
9 | EMAP enables 3D edge reconstruction from multi-view 2D edge maps.
10 |
31 | ## Installation
32 |
33 | ```
34 | git clone https://github.com/cvg/EMAP.git
35 | cd EMAP
36 |
37 | conda create -n emap python=3.8
38 | conda activate emap
39 |
40 | conda install pytorch==2.1.0 torchvision==0.16.0 torchaudio==2.1.0 pytorch-cuda=12.1 -c pytorch -c nvidia
41 | pip install -r requirements.txt
42 | ```
43 |
44 | ## Datasets
45 | Download the datasets:
46 | ```
47 | python scripts/download_data.py
48 | ```
49 | The data is organized as follows:
50 |
51 | ```
52 | <scan_name>                # e.g. 00000325
53 | |-- meta_data.json         # camera parameters
54 | |-- color                  # RGB images for each view
55 | |   |-- 0_colors.png
56 | |   |-- 1_colors.png
57 | |   ...
58 | |-- edge_DexiNed           # edge maps extracted by DexiNed
59 | |   |-- 0_colors.png
60 | |   |-- 1_colors.png
61 | |   ...
62 | |-- edge_PidiNet           # edge maps extracted by PidiNet
63 | |   |-- 0_colors.png
64 | |   |-- 1_colors.png
65 | |   ...
66 | ```
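67 |
68 | By default, `scripts/download_data.py` extracts everything to `./data/`, which matches the `data_dir` paths used in the configs under `confs/`.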
67 |
68 | ## Training and Edge Extraction
69 | To train and extract edges on different datasets, use the following commands:
70 |
71 | #### ABC-NEF_Edge Dataset
72 | ```
73 | bash scripts/run_ABC.bash
74 | ```
75 |
76 | #### Replica_Edge Dataset
77 | ```
78 | bash scripts/run_Replica.bash
79 | ```
80 |
81 | #### DTU_Edge Dataset
82 | ```
83 | bash scripts/run_DTU.bash
84 | ```
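85 |
86 | Each script trains the scan specified in its config and then extracts parametric edges. To process a single scan, you can also call `main.py` directly and override the scan name (see `--scan` in `main.py`):
87 |
88 | ```
89 | python main.py --conf ./confs/ABC.conf --mode train --scan 00000325
90 | python main.py --conf ./confs/ABC.conf --mode extract_edge --scan 00000325
91 | ```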
85 |
86 | ### Checkpoints
87 | Model checkpoints are available on [Google Drive](https://drive.google.com/file/d/1kU87MqDv5IvwjCt8I8KecTlIok39fuws/view?usp=sharing).
88 |
89 | ## Evaluation
90 | To evaluate the extracted edges on the ABC-NEF_Edge dataset, run:
91 |
92 | #### ABC-NEF_Edge Dataset
93 | ```
94 | python src/eval/eval_ABC.py
95 | ```
96 |
97 | ## Code Release Status
98 | - [x] Training Code
99 | - [x] Inference Code
100 | - [x] Evaluation Code
101 | - [ ] Custom Dataset Support
102 |
103 | ## License
104 |
105 | Shield: [](https://opensource.org/licenses/MIT)
106 |
107 | The majority of EMAP is licensed under an [MIT License](LICENSE.txt).
108 |
109 | ## Citing EMAP
110 |
111 | If you find the code useful, please consider citing our paper with the following BibTeX entry.
112 |
113 | ```BibTeX
114 | @InProceedings{li2024neural,
115 | title={3D Neural Edge Reconstruction},
116 | author={Li, Lei and Peng, Songyou and Yu, Zehao and Liu, Shaohui and Pautrat, R{\'e}mi and Yin, Xiaochuan and Pollefeys, Marc},
117 | booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
118 | year={2024},
119 | }
120 | ```
121 |
122 | ## Contact
123 | If you encounter any issues, you can open a GitHub issue or contact Lei directly at lllei.li0386@gmail.com.
124 |
125 | ## Acknowledgement
126 |
127 | This project is built upon [NeuralUDF](https://github.com/xxlong0/NeuralUDF), [NeuS](https://github.com/Totoro97/NeuS) and [MeshUDF](https://github.com/cvlab-epfl/MeshUDF). We use pretrained [DexiNed](https://github.com/xavysp/DexiNed) and [PidiNet](https://github.com/hellozhuo/pidinet) for edge map extraction. We thank all the authors for their great work and repos.
128 |
--------------------------------------------------------------------------------
/confs/ABC.conf:
--------------------------------------------------------------------------------
1 | general {
2 | base_exp_dir = ./exp/ABC/
3 | expname = emap
4 |
5 | model_type = udf
6 | recording = [
7 | ./src/models,
8 | ./src/dataset,
9 | ./src/runner,
10 | ]
11 | }
12 |
13 | dataset {
14 | data_dir = ./data/ABC-NEF_Edge/data/
15 | scan = "00000325"
16 | dataset_name = NEF
17 | detector = DexiNed
18 | near = 0.05
19 | far = 6
20 | AABB = [-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]
21 | }
22 |
23 | train {
24 | latest_model_name = ckpt_best.pth
25 | importance_sample = True
26 | learning_rate = 5e-4
27 | learning_rate_geo = 1e-4
28 | learning_rate_alpha = 0.05
29 | end_iter = 50000
30 |
31 | batch_size = 1024
32 | validate_resolution_level = 1
33 | warm_up_end = 1000
34 | anneal_end = 10000
35 | use_white_bkgd = False
36 |
37 | warmup_sample = False
38 |
39 | save_freq = 1000
40 | val_freq = 1000
41 | report_freq = 1000
42 |
43 | igr_weight = 0.1
44 | igr_ns_weight = 0.0
45 | }
46 |
47 | edge_loss {
48 | edge_weight = 1.0
49 | loss_type = mse
50 | }
51 |
52 | model {
53 | nerf {
54 | D = 8
55 | d_in = 4
56 | d_in_view = 3
57 | W = 256
58 | multires = 10
59 | multires_view = 4
60 | output_ch = 4
61 | skips = [4]
62 | use_viewdirs = True
63 | }
64 |
65 | udf_network {
66 | d_out = 1
67 | d_in = 3
68 | d_hidden = 256
69 | n_layers = 8
70 | skip_in = [4]
71 | multires = 10
72 | bias = 0.5
73 | scale = 1.0
74 | geometric_init = True
75 | weight_norm = True
76 | udf_type = abs # square or abs
77 | }
78 |
79 | variance_network {
80 | init_val = 0.3
81 | }
82 |
83 | rendering_network {
84 | d_feature = 256
85 | mode = no_normal
86 | d_in = 6
87 | d_out = 1
88 | d_hidden = 128
89 | n_layers = 4
90 | weight_norm = True
91 | multires_view = 4
92 | squeeze_out = True
93 | blending_cand_views = 10
94 | }
95 |
96 |
97 | beta_network {
98 | init_var_beta = 0.5
99 | init_var_gamma = 0.3
100 | init_var_zeta = 0.3
101 | beta_min = 0.00005
102 | requires_grad_beta = True
103 | requires_grad_gamma = True
104 | requires_grad_zeta = False
105 | }
106 |
107 | udf_renderer {
108 | n_samples = 64
109 | n_importance = 50
110 | n_outside = 0
111 | up_sample_steps = 5
112 | perturb = 1.0
113 | sdf2alpha_type = numerical
114 | upsampling_type = classical
115 | use_unbias_render = True
116 | }
117 | }
118 |
119 | edge_extraction {
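120 |     # Brief parameter notes (inferred from usage in main.py and src/edge_extraction):
121 |     #   udf_threshold: keep sampled points whose predicted UDF is below this value
122 |     #   resolution: grid resolution at which the UDF is queried
123 |     #   sampling_N / sampling_delta: number and spread of samples around candidate points
124 |     #   is_pointshift / iters: iteratively shift points toward the UDF zero-level set
125 |     #   is_linedirection: estimate a per-point line direction for edge fitting
126 |     #   visible_checking: drop points that fail the multi-view visibility check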
120 | is_pointshift = True
121 | iters = 2
122 | is_linedirection = True
123 | udf_threshold = 0.02
124 | resolution = 128
125 | sampling_delta = 0.005
126 | sampling_N = 50
127 | visible_checking = False
128 |
129 | }
130 |
--------------------------------------------------------------------------------
/confs/DTU.conf:
--------------------------------------------------------------------------------
1 | general {
2 | base_exp_dir = ./exp/DTU/
3 | expname = emap
4 |
5 | model_type = udf
6 | recording = [
7 | ./src/models,
8 | ./src/dataset,
9 | ./src/runner,
10 | ]
11 | }
12 |
13 | dataset {
14 | data_dir = ./data/DTU_Edge/data/
15 | scan = "scan105"
16 | dataset_name = DTU
17 | detector = PidiNet
18 | near = 0.05
19 | far = 6.0
20 | AABB = [-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]
21 | }
22 |
23 | train {
24 | latest_model_name = ckpt_best.pth
25 | importance_sample = True
26 | learning_rate = 5e-4
27 | learning_rate_geo = 1e-4
28 | learning_rate_alpha = 0.05
29 | end_iter = 200000
30 |
31 | batch_size = 1024
32 | validate_resolution_level = 1
33 | warm_up_end = 1000
34 | anneal_end = 10000
35 | use_white_bkgd = False
36 |
37 | warmup_sample = False
38 |
39 | save_freq = 5000
40 | val_freq = 5000
41 | report_freq = 1000
42 |
43 | igr_weight = 0.01
44 | igr_ns_weight = 0.0
45 | }
46 |
47 | edge_loss {
48 | edge_weight = 1.0
49 | loss_type = mse
50 | }
51 |
52 | model {
53 | nerf {
54 | D = 8
55 | d_in = 4
56 | d_in_view = 3
57 | W = 256
58 | multires = 10
59 | multires_view = 4
60 | output_ch = 4
61 | skips = [4]
62 | use_viewdirs = True
63 | }
64 |
65 | udf_network {
66 | d_out = 1
67 | d_in = 3
68 | d_hidden = 256
69 | n_layers = 8
70 | skip_in = [4]
71 | multires = 10
72 | bias = 0.5
73 | scale = 1.0
74 | geometric_init = True
75 | weight_norm = True
76 | udf_type = abs # square or abs
77 | }
78 |
79 | variance_network {
80 | init_val = 0.3
81 | }
82 |
83 | rendering_network {
84 | d_feature = 256
85 | mode = no_normal
86 | d_in = 6
87 | d_out = 1
88 | d_hidden = 128
89 | n_layers = 4
90 | weight_norm = True
91 | multires_view = 4
92 | squeeze_out = True
93 | blending_cand_views = 10
94 | }
95 |
96 |
97 | beta_network {
98 | init_var_beta = 0.5
99 | init_var_gamma = 0.3
100 | init_var_zeta = 0.3
101 | beta_min = 0.00005
102 | requires_grad_beta = True
103 | requires_grad_gamma = True
104 | requires_grad_zeta = False
105 | }
106 |
107 | udf_renderer {
108 | n_samples = 64
109 | n_importance = 50
110 | n_outside = 0
111 | up_sample_steps = 5
112 | perturb = 1.0
113 | sdf2alpha_type = numerical
114 | upsampling_type = classical
115 | use_unbias_render = True
116 | }
117 | }
118 |
119 | edge_extraction {
120 | is_pointshift = True
121 | iters = 1
122 | is_linedirection = True
123 | udf_threshold = 0.015
124 | resolution = 256
125 | sampling_delta = 0.005
126 | sampling_N = 50
127 | visible_checking = True
128 |
129 | }
130 |
--------------------------------------------------------------------------------
/confs/Replica.conf:
--------------------------------------------------------------------------------
1 | general {
2 | base_exp_dir = ./exp/Replica/
3 | expname = emap
4 |
5 | model_type = udf
6 | recording = [
7 | ./src/models,
8 | ./src/dataset,
9 | ./src/runner,
10 | ]
11 | }
12 |
13 | dataset {
14 | data_dir = ./data/Replica_Edge
15 | scan = "room0"
16 | dataset_name = Replica
17 | detector = PidiNet
18 | near = 0.05
19 | far = 2.5
20 | AABB = [-1.0, -1.0, -1.0, 1.0, 1.0, 1.0]
21 | }
22 |
23 | train {
24 | latest_model_name = ckpt_best.pth
25 | importance_sample = True
26 | learning_rate = 5e-4
27 | learning_rate_geo = 1e-4
28 | learning_rate_alpha = 0.05
29 | end_iter = 200000
30 |
31 | batch_size = 1024
32 | validate_resolution_level = 1
33 | warm_up_end = 1000
34 | anneal_end = 10000
35 | use_white_bkgd = False
36 |
37 | warmup_sample = False
38 |
39 | save_freq = 5000
40 | val_freq = 5000
41 | report_freq = 1000
42 |
43 | igr_weight = 0.01
44 | igr_ns_weight = 0.0
45 | }
46 |
47 | edge_loss {
48 | edge_weight = 1.0
49 | loss_type = mse
50 | }
51 |
52 | model {
53 | nerf {
54 | D = 8
55 | d_in = 4
56 | d_in_view = 3
57 | W = 256
58 | multires = 10
59 | multires_view = 4
60 | output_ch = 4
61 | skips = [4]
62 | use_viewdirs = True
63 | }
64 |
65 | udf_network {
66 | d_out = 1
67 | d_in = 3
68 | d_hidden = 256
69 | n_layers = 8
70 | skip_in = [4]
71 | multires = 6 # 6 or 10
72 | bias = 0.5
73 | scale = 1.0
74 | geometric_init = True
75 | weight_norm = True
76 | udf_type = abs # square or abs
77 | }
78 |
79 | variance_network {
80 | init_val = 0.3
81 | }
82 |
83 | rendering_network {
84 | d_feature = 256
85 | mode = no_normal
86 | d_in = 6
87 | d_out = 1
88 | d_hidden = 128
89 | n_layers = 4
90 | weight_norm = True
91 | multires_view = 4
92 | squeeze_out = True
93 | blending_cand_views = 10
94 | }
95 |
96 | beta_network {
97 | init_var_beta = 0.5
98 | init_var_gamma = 0.3
99 | init_var_zeta = 0.3
100 | beta_min = 0.00005
101 | requires_grad_beta = True
102 | requires_grad_gamma = True
103 | requires_grad_zeta = False
104 | }
105 |
106 | udf_renderer {
107 | n_samples = 64
108 | n_importance = 50
109 | n_outside = 0
110 | up_sample_steps = 5
111 | perturb = 1.0
112 | sdf2alpha_type = numerical
113 | upsampling_type = classical
114 | use_unbias_render = True
115 | }
116 | }
117 |
118 | edge_extraction {
119 | is_pointshift = True
120 | iters = 1
121 | is_linedirection = True
122 | udf_threshold = 0.01
123 | resolution = 256
124 | sampling_delta = 0.005
125 | sampling_N = 50
126 | visible_checking = True
127 |
128 | }
--------------------------------------------------------------------------------
/main.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import argparse
3 | import logging
4 | import numpy as np
5 | import random
6 | from pyhocon import ConfigFactory
7 | from src.runner.runner_udf import Runner_UDF
8 |
9 |
10 | def fix_random_seeds(seed=42):
11 | """
12 | Fix the random seeds for reproducibility.
13 | """
14 | torch.manual_seed(seed)
15 | if torch.cuda.is_available():
16 | torch.cuda.manual_seed_all(seed)
17 | np.random.seed(seed)
18 | random.seed(seed)
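19 |     # Note: the seeds above cover torch, CUDA, numpy, and Python's random module;
20 |     # full GPU determinism would additionally require setting
21 |     # torch.backends.cudnn.deterministic = True (at some speed cost).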
19 |
20 |
21 | def get_runner(mode):
22 | """
23 | Get the runner based on the provided mode.
24 | """
25 | runners = {
26 | "udf": Runner_UDF,
27 | }
28 | if mode not in runners:
29 | raise ValueError(f"Unknown mode: {mode}")
30 | return runners[mode]
31 |
32 |
33 | def main():
34 | """
35 | Main function to parse arguments and run the appropriate mode.
36 | """
37 | torch.set_default_dtype(torch.float32)
38 | FORMAT = "[%(filename)s:%(lineno)s - %(funcName)20s() ] %(message)s"
39 | logging.basicConfig(level=logging.DEBUG, format=FORMAT)
40 |
41 | parser = argparse.ArgumentParser()
42 |
43 | # Parameters for the training
44 | parser.add_argument(
45 | "--conf", type=str, default="./confs/ABC.conf", help="Path to the config file."
46 | )
47 | parser.add_argument(
48 | "--mode",
49 | type=str,
50 | default="train",
51 | choices=["train", "extract_edge"],
52 | help="Mode to run.",
53 | )
54 | parser.add_argument(
55 | "--scan", type=str, default="null", help="The name of a dataset."
56 | )
57 | parser.add_argument("--gpu", type=int, default=0, help="GPU id to use.")
58 | parser.add_argument(
59 | "--is_continue",
60 | default=False,
61 | action="store_true",
62 | help="Flag to continue training.",
63 | )
64 |
65 | args = parser.parse_args()
66 |
67 | # Fix the random seed
68 | fix_random_seeds()
69 |
70 | with open(args.conf, "r") as f:
71 | conf_text = f.read()
72 | conf = ConfigFactory.parse_string(conf_text)
73 |
74 | if args.scan != "null":
75 | conf["dataset"]["scan"] = args.scan
76 |
77 | logging.info(f"Run on scan of {conf['dataset']['scan']}")
78 |
79 | runner_class = get_runner(conf["general"]["model_type"])
80 | runner = runner_class(conf, args.mode, args.is_continue, args)
81 |
82 | if args.mode == "train":
83 | logging.info(f"Training UDF")
84 | runner.train()
85 | elif args.mode == "extract_edge":
86 | logging.info(f"Extracting edges from UDF")
87 | runner.extract_edge(
88 | resolution=conf["edge_extraction"]["resolution"],
89 | udf_threshold=conf["edge_extraction"]["udf_threshold"],
90 | sampling_N=conf["edge_extraction"]["sampling_N"],
91 | sampling_delta=conf["edge_extraction"]["sampling_delta"],
92 | is_pointshift=conf["edge_extraction"]["is_pointshift"],
93 | iters=conf["edge_extraction"]["iters"],
94 | is_linedirection=conf["edge_extraction"]["is_linedirection"],
95 | visible_checking=conf["edge_extraction"]["visible_checking"],
96 | )
97 | else:
98 | raise ValueError(f"Invalid mode: {args.mode}")
99 |
100 |
101 | if __name__ == "__main__":
102 | main()
103 |
--------------------------------------------------------------------------------
/media/overview.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvg/EMAP/652f24ecc3f3cbf538928f27cc6d55dbebb360c7/media/overview.jpg
--------------------------------------------------------------------------------
/media/replica.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvg/EMAP/652f24ecc3f3cbf538928f27cc6d55dbebb360c7/media/replica.gif
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | pyhocon==0.3.59
2 | icecream==2.1.3
3 | opencv-python==4.9.0.80
4 | scipy==1.10.1
5 | h5py==3.10.0
6 | open3d==0.18.0
7 | point_cloud_utils==0.30.4
8 | trimesh==4.2.2
9 | tensorboard==2.14.0
10 | torch_optimizer==0.3.0
11 | flow_vis==0.1
12 | scikit-image==0.21.0
13 | termcolor==2.4.0
14 | gitpython==3.1.43
15 | gdown==5.2.0
--------------------------------------------------------------------------------
/scripts/download_data.py:
--------------------------------------------------------------------------------
1 | import os
2 | import zipfile
3 | import gdown
4 |
5 |
6 | def download_and_unzip_google_drive_files(paths, download_to="./data"):
7 | if not os.path.exists(download_to):
8 | os.makedirs(download_to)
9 |
10 | for path in paths:
11 | # Convert Google Drive link to direct download link
12 | file_id = path.split("/d/")[1].split("/")[0]
13 | direct_link = f"https://drive.google.com/uc?id={file_id}"
14 |
15 | print(f"Downloading from {direct_link}...")
16 | output_path = os.path.join(download_to, f"{file_id}.zip")
17 | gdown.download(direct_link, output_path, quiet=False)
18 | print(f"Downloaded {output_path}")
19 |
20 | # Unzip the downloaded file
21 | with zipfile.ZipFile(output_path, "r") as zip_ref:
22 | zip_ref.extractall(download_to)
23 | os.remove(output_path)
24 | print(f"Unzipped and deleted: {output_path}")
25 |
26 | print(f"Finished extracting files to: {download_to}")
27 |
28 |
29 | # Google Drive file paths (one archive per dataset)
30 | paths = [
31 |     "https://drive.google.com/file/d/17aUcCJCP5vgARs237H0BtlRoms5-CR6e/view?usp=sharing",
32 |     "https://drive.google.com/file/d/1eZZiMcTfoiYfIxtv4Wy3lQYAudZpKlE0/view?usp=sharing",
33 |     "https://drive.google.com/file/d/1pum-25MEFhXQu1fZLy_f9lRMBxvF1ssm/view?usp=sharing",
34 | ]
35 |
36 | if __name__ == "__main__":
37 |     download_and_unzip_google_drive_files(paths, download_to="./data")
38 |
--------------------------------------------------------------------------------
/scripts/get_gt_points_DTU.py:
--------------------------------------------------------------------------------
1 | import trimesh
2 | import open3d as o3d
3 | import numpy as np
4 | import math
5 | import json
6 | import os
7 | from src.eval.eval_util import (
8 | set_random_seeds,
9 | load_from_json,
10 | downsample_point_cloud_average,
11 | )
12 | from src.edge_extraction.edge_fitting.bezier_fit import bezier_curve
13 | import argparse
14 | from pathlib import Path
16 | import cv2
17 |
18 |
19 | def write_vis_pcd(file, points, colors):
20 | pcd = o3d.geometry.PointCloud()
21 | pcd.points = o3d.utility.Vector3dVector(points)
22 | pcd.colors = o3d.utility.Vector3dVector(colors)
23 | o3d.io.write_point_cloud(file, pcd)
24 |
25 |
26 | def convert_ply_to_obj(ply_file_path, obj_file_path):
27 | # Load the .ply file
28 | mesh = trimesh.load(ply_file_path)
29 | # Export the mesh to .obj format
30 | mesh.export(obj_file_path, file_type="obj")
31 |
32 |
33 | def save_point_cloud(file_path, points):
34 | pcd = o3d.geometry.PointCloud()
35 | pcd.points = o3d.utility.Vector3dVector(points)
36 | o3d.io.write_point_cloud(file_path, pcd)
37 |
38 |
49 |
50 | def sample_single_tri(input_):
51 | n1, n2, v1, v2, tri_vert = input_
52 | c = np.mgrid[: n1 + 1, : n2 + 1]
53 | c += 0.5
54 | c[0] /= max(n1, 1e-7)
55 | c[1] /= max(n2, 1e-7)
56 | c = np.transpose(c, (1, 2, 0))
57 | k = c[c.sum(axis=-1) < 1] # m2
58 | q = v1 * k[:, :1] + v2 * k[:, 1:] + tri_vert
59 | return q
60 |
61 |
62 | def convert_mesh_gt2world(mesh_path, out_mesh_path, gttoworld):
63 | mesh = trimesh.load(mesh_path)
64 | # mesh.transform(gttoworld)
65 | mesh.apply_transform(gttoworld)
66 | mesh.export(out_mesh_path, file_type="obj")
67 | return mesh
68 |
69 |
70 | def get_edge_maps(data_dir):
71 | meta = load_from_json(Path(data_dir) / "meta_data.json")
72 | h, w = meta["height"], meta["width"]
73 | edges_list, intrinsics_list, camtoworld_list = [], [], []
74 | for idx, frame in enumerate(meta["frames"]):
75 | intrinsics = np.array(frame["intrinsics"])
76 | camtoworld = np.array(frame["camtoworld"])[:4, :4]
77 | edges_list.append(
78 | os.path.join(
79 | data_dir,
80 | "edge_PidiNet",
81 | frame["rgb_path"],
82 | )
83 | )
84 | intrinsics_list.append(intrinsics)
85 | camtoworld_list.append(camtoworld)
86 |
87 | edges = [cv2.imread(im_name, 0)[..., None] for im_name in edges_list]
88 | edges = 1 - np.stack(edges) / 255.0
89 | intrinsics_list = np.stack(intrinsics_list)
90 | camtoworld_list = np.stack(camtoworld_list)
91 | return edges, intrinsics_list, camtoworld_list, h, w
92 |
93 |
94 | def compute_visibility(
95 | gt_points,
96 | edge_maps,
97 | intrinsics_list,
98 | camtoworld_list,
99 | h,
100 | w,
101 | edge_visibility_threshold,
102 | edge_visibility_frames,
103 | ):
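104 |     # A GT point counts as visible in a frame when it projects onto an edge pixel
105 |     # whose value exceeds edge_visibility_threshold; the point is kept overall only
106 |     # if it is visible in more than edge_visibility_frames frames.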
104 | img_frames = len(edge_maps)
105 | point_visibility_matrix = np.zeros((len(gt_points), img_frames))
106 |
107 | for frame_idx, (edge_map, intrinsic, camtoworld) in enumerate(
108 | zip(edge_maps, intrinsics_list, camtoworld_list)
109 | ):
110 | K = intrinsic[:3, :3]
111 | worldtocam = np.linalg.inv(camtoworld)
112 | edge_uv = project2D(K, worldtocam, gt_points)
113 | edge_uv = np.round(edge_uv).astype(np.int64)
114 |
115 | # Boolean mask for valid u, v coordinates
116 | valid_u_mask = (edge_uv[:, 0] >= 0) & (edge_uv[:, 0] < w)
117 | valid_v_mask = (edge_uv[:, 1] >= 0) & (edge_uv[:, 1] < h)
118 | valid_mask = valid_u_mask & valid_v_mask
119 |
120 | valid_edge_uv = edge_uv[valid_mask]
121 | valid_projected_edges = edge_map[valid_edge_uv[:, 1], valid_edge_uv[:, 0]]
122 |
123 | # Calculate visibility in a vectorized manner
124 | visibility = (valid_projected_edges > edge_visibility_threshold).reshape(-1)
125 |
126 | point_visibility_matrix[valid_mask, frame_idx] = visibility.astype(float)
127 |
128 | return np.sum(point_visibility_matrix, axis=1) > edge_visibility_frames
129 |
130 |
131 | def project2D(K, worldtocam, points3d):
132 | shape = points3d.shape
133 | R = worldtocam[:3, :3]
134 | T = worldtocam[:3, 3:]
135 |
136 | projected = K @ (R @ points3d.T + T)
137 | projected = projected.T
138 | projected = projected / projected[:, -1:]
139 | uv = projected.reshape(*shape)[..., :2].reshape(-1, 2)
140 | return uv
141 |
142 |
143 | def bezier_curve_length(control_points, num_samples):
144 | def binomial_coefficient(n, i):
145 | return math.factorial(n) // (math.factorial(i) * math.factorial(n - i))
146 |
147 | def derivative_bezier(t):
148 | n = len(control_points) - 1
149 | point = np.array([0.0, 0.0, 0.0])
150 | for i, (p1, p2) in enumerate(zip(control_points[:-1], control_points[1:])):
151 | point += (
152 | n
153 | * binomial_coefficient(n - 1, i)
154 | * (1 - t) ** (n - 1 - i)
155 | * t**i
156 | * (np.array(p2) - np.array(p1))
157 | )
158 | return point
159 |
160 | def simpson_integral(a, b, num_samples):
161 | h = (b - a) / num_samples
162 | sum1 = sum(
163 | np.linalg.norm(derivative_bezier(a + i * h))
164 | for i in range(1, num_samples, 2)
165 | )
166 | sum2 = sum(
167 | np.linalg.norm(derivative_bezier(a + i * h))
168 | for i in range(2, num_samples - 1, 2)
169 | )
170 | return (
171 | (
172 | np.linalg.norm(derivative_bezier(a))
173 | + 4 * sum1
174 | + 2 * sum2
175 | + np.linalg.norm(derivative_bezier(b))
176 | )
177 | * h
178 | / 3
179 | )
180 |
181 | # Compute the length of the 3D Bezier curve using composite Simpson's rule
182 | length = 0.0
183 | for i in range(num_samples):
184 | t0 = i / num_samples
185 | t1 = (i + 1) / num_samples
186 | length += simpson_integral(t0, t1, num_samples)
187 |
188 | return length
189 |
190 |
191 | def bezier_para_to_point_length(control_points, num_samples=100):
192 | t_fit = np.linspace(0, 1, num_samples)
193 | curve_point_set = []
194 | curve_length_set = []
195 | for control_point in control_points:
196 | points = bezier_curve(
197 | t_fit,
198 | control_point[0, 0],
199 | control_point[0, 1],
200 | control_point[0, 2],
201 | control_point[1, 0],
202 | control_point[1, 1],
203 | control_point[1, 2],
204 | control_point[2, 0],
205 | control_point[2, 1],
206 | control_point[2, 2],
207 | control_point[3, 0],
208 | control_point[3, 1],
209 | control_point[3, 2],
210 | )
211 | lengths = bezier_curve_length(control_point, num_samples=num_samples)
212 | curve_point_set.append(points)
213 | curve_length_set.append(lengths)
214 | return (
215 | np.array(curve_point_set).reshape(-1, num_samples, 3, 1),
216 | np.array(curve_length_set),
217 | )
218 |
219 |
220 | def main(gt_point_cloud_dir, dataset_dir, out_dir):
221 | set_random_seeds()
222 | gt_point_cloud_dir = os.path.join(gt_point_cloud_dir, "Points", "stl")
223 | if not os.path.exists(gt_point_cloud_dir):
224 | print(
225 | f"Ground truth point cloud directory {gt_point_cloud_dir} does not exist. Please download it from http://roboimagedata2.compute.dtu.dk/data/MVS/Points.zip"
226 | )
227 | return
228 |
229 | scan_names_dict = {
230 | "scan37": [0.55, 0.3],
231 | "scan83": [0.65, 0.2],
232 | "scan105": [0.65, 0.2],
233 | "scan110": [0.5, 0.3],
234 | "scan118": [0.5, 0.3],
235 | "scan122": [0.35, 0.4],
236 | }
237 |
238 | os.makedirs(out_dir, exist_ok=True)
239 |
240 | for scan_name, (
241 | edge_visibility_threshold,
242 | edge_visibility_frames_ratio,
243 | ) in scan_names_dict.items():
244 | output_file = os.path.join(out_dir, scan_name, "edge_points.ply")
245 | if os.path.exists(output_file):
246 | print(f"{output_file} already exists. Skipping.")
247 | continue
248 | os.makedirs(os.path.join(out_dir, scan_name), exist_ok=True)
249 | meta_data_json_path = os.path.join(dataset_dir, scan_name, "meta_data.json")
250 | meta_base_dir = os.path.join(dataset_dir, scan_name)
251 | worldtogt = np.array(load_from_json(Path(meta_data_json_path))["worldtogt"])
252 | gttoworld = np.linalg.inv(worldtogt)
253 | gt_point_cloud_path = os.path.join(
254 | gt_point_cloud_dir,
255 | f"stl{int(scan_name[4:]):03d}_total.ply",
256 | )
257 | gt_point_cloud = o3d.io.read_point_cloud(gt_point_cloud_path)
258 |
259 | gt_points = np.asarray(gt_point_cloud.points)
260 | points = gt_points @ gttoworld[:3, :3].T + gttoworld[:3, 3][None, ...]
261 |
262 | edge_maps, intrinsics_list, camtoworld_list, h, w = get_edge_maps(meta_base_dir)
263 | num_frames = len(edge_maps)
264 | edge_visibility_frames = max(
265 | 1, round(edge_visibility_frames_ratio * num_frames)
266 | )
267 | points_visibility = compute_visibility(
268 | points,
269 | edge_maps,
270 | intrinsics_list,
271 | camtoworld_list,
272 | h,
273 | w,
274 | edge_visibility_threshold,
275 | edge_visibility_frames,
276 | )
277 |
278 | print(
279 | f"{scan_name}: before visibility check: {len(points)}, after visibility check: {np.sum(points_visibility)}"
280 | )
281 |
282 | edge_points = points[points_visibility]
283 | downsampled_edge_points = downsample_point_cloud_average(
284 | edge_points, num_voxels_per_axis=256
285 | )
286 | downsampled_edge_points = (
287 | downsampled_edge_points @ worldtogt[:3, :3].T + worldtogt[:3, 3][None, ...]
288 | )
289 | save_point_cloud(output_file, downsampled_edge_points)
290 | print(f"Saved downsampled edge point cloud to {output_file}")
291 |
292 |
293 | if __name__ == "__main__":
294 | parser = argparse.ArgumentParser(
295 | description="Process and evaluate point cloud data."
296 | )
297 | parser.add_argument(
298 | "--gt_point_cloud_dir",
299 | type=str,
300 | default="data/DTU_Edge/groundtruth",
301 | )
302 | parser.add_argument("--dataset_dir", type=str, default="data/DTU_Edge/data")
303 | parser.add_argument(
304 | "--out_dir",
305 | type=str,
306 | default="data/DTU_Edge/groundtruth/edge_points",
307 | )
308 | args = parser.parse_args()
309 |
310 | main(args.gt_point_cloud_dir, args.dataset_dir, args.out_dir)
311 |
--------------------------------------------------------------------------------
/scripts/run_ABC.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | # source ~/miniconda3/etc/profile.d/conda.sh
5 | # conda activate emap
6 |
7 | # Set the PYTHONPATH environment variable
8 | export PYTHONPATH=.
9 |
10 | # Set the device for CUDA to use
11 | export CUDA_VISIBLE_DEVICES=0
12 |
13 | # Train UDF field
14 | python main.py --conf ./confs/ABC.conf --mode train
15 |
16 | # Extract parametric edges
17 | python main.py --conf ./confs/ABC.conf --mode extract_edge
18 |
--------------------------------------------------------------------------------
/scripts/run_DTU.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | # source ~/miniconda3/etc/profile.d/conda.sh
5 | # conda activate emap
6 |
7 | # Set the PYTHONPATH environment variable
8 | export PYTHONPATH=.
9 |
10 | # Set the device for CUDA to use
11 | export CUDA_VISIBLE_DEVICES=0
12 |
13 | # Train UDF field
14 | python main.py --conf ./confs/DTU.conf --mode train
15 |
16 | # Extract parametric edges
17 | python main.py --conf ./confs/DTU.conf --mode extract_edge
18 |
--------------------------------------------------------------------------------
/scripts/run_Replica.bash:
--------------------------------------------------------------------------------
1 | #!/bin/bash
2 | set -e
3 |
4 | # source ~/miniconda3/etc/profile.d/conda.sh
5 | # conda activate emap
6 |
7 | # Set the PYTHONPATH environment variable
8 | export PYTHONPATH=.
9 |
10 | # Set the device for CUDA to use
11 | export CUDA_VISIBLE_DEVICES=0
12 |
13 | # Train UDF field
14 | python main.py --conf ./confs/Replica.conf --mode train
15 |
16 | # Extract parametric edges
17 | python main.py --conf ./confs/Replica.conf --mode extract_edge
18 |
--------------------------------------------------------------------------------
/src/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import cv2 as cv
3 | import numpy as np
4 | import os
5 | from pathlib import Path
6 | import json
7 | import random
11 |
12 |
13 | def load_from_json(filename: Path):
14 | """Load a dictionary from a JSON filename.
15 |
16 | Args:
17 | filename: The filename to load from.
18 | """
19 | assert filename.suffix == ".json"
20 | with open(filename, encoding="UTF-8") as file:
21 | return json.load(file)
22 |
23 |
24 | # This function is borrowed from IDR: https://github.com/lioryariv/idr
25 | def load_K_Rt_from_P(filename, P=None):
26 | if P is None:
27 | lines = open(filename).read().splitlines()
28 | if len(lines) == 4:
29 | lines = lines[1:]
30 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
31 | P = np.asarray(lines).astype(np.float32).squeeze()
32 |
33 | out = cv.decomposeProjectionMatrix(P)
34 | K = out[0]
35 | R = out[1]
36 | t = out[2]
37 |
38 | K = K / K[2, 2]
39 | intrinsics = np.eye(4)
40 | intrinsics[:3, :3] = K
41 |
42 | pose = np.eye(4, dtype=np.float32)
43 | pose[:3, :3] = R.transpose()
44 | pose[:3, 3] = (t[:3] / t[3])[:, 0]
45 |
46 | return intrinsics, pose
47 |
48 |
49 | class Dataset:
50 | def __init__(self, conf):
51 | super(Dataset, self).__init__()
52 | print("Load data: Begin")
53 | self.device = (
54 | torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
55 | )
56 | self.conf = conf
57 | self.scan = conf.get_string("scan")
58 | self.data_dir = os.path.join(conf.get_string("data_dir"), self.scan)
59 | self.dataset_name = self.conf.get_string("dataset_name", default="ABC")
60 | self.detector = conf.get_string("detector", default="DexiNed")
61 | assert self.detector in ["DexiNed", "PidiNet"]
62 | self.load_metadata(conf)
63 | self.load_image_data()
64 | print("Load data: End")
65 |
66 | def load_metadata(self, conf):
67 | meta = load_from_json(Path(self.data_dir) / "meta_data.json")
68 | self.meta = meta
69 | self.intrinsics_all = []
70 | self.pose_all = []
71 | self.edges_list = []
72 | self.colors_list = []
73 |
74 | meta_scene_box = meta["scene_box"]
75 |
76 | self.near = meta_scene_box["near"]
77 | self.far = meta_scene_box["far"]
78 | self.radius = meta_scene_box["radius"]
79 |
80 | H, W = meta["height"], meta["width"]
81 | self.H, self.W, self.image_pixels = H, W, H * W
82 |
83 | for idx, frame in enumerate(meta["frames"]):
84 | self.process_frame(frame)
85 |
86 | def process_frame(self, frame):
87 | intrinsics = torch.tensor(frame["intrinsics"])
88 | camtoworld = torch.tensor(frame["camtoworld"])[:4, :4]
89 | image_name = frame["rgb_path"]
90 | if self.detector == "PidiNet":
91 | self.edges_list.append(
92 | os.path.join(
93 | self.data_dir,
94 | "edge_PidiNet",
95 | image_name[:-4] + ".png",
96 | )
97 | )
98 | elif self.detector == "DexiNed":
99 | self.edges_list.append(
100 | os.path.join(self.data_dir, "edge_DexiNed", image_name)
101 | )
102 | self.colors_list.append(os.path.join(self.data_dir, "color", image_name))
103 | self.intrinsics_all.append(intrinsics)
104 | self.pose_all.append(camtoworld)
105 |
106 | def load_image_data(self):
107 | self.load_edges_data()
108 | self.colors_np = (
109 | np.stack([cv.imread(im_name) for im_name in self.colors_list]) / 255.0
110 | )
111 |
112 | self.n_images = len(self.edges_list)
113 | self.edges = torch.from_numpy(
114 | self.edges_np.astype(np.float32)
115 |         )  # kept on CPU; moved to self.device when rays are generated
116 | self.colors = torch.from_numpy(self.colors_np.astype(np.float32))
117 |
122 | self.masks_np = (self.edges_np > 0.5).astype(np.float32)
123 | self.masks = torch.from_numpy(self.masks_np.astype(np.float32))
124 |
125 | self.intrinsics_all = torch.stack(self.intrinsics_all)
126 | self.intrinsics_all_inv = torch.inverse(self.intrinsics_all)
127 | self.focal = self.intrinsics_all[0][0, 0]
128 | self.pose_all = torch.stack(self.pose_all)
129 |
130 | self.object_bbox_min = np.array(self.meta["scene_box"]["aabb"][0])
131 | self.object_bbox_max = np.array(self.meta["scene_box"]["aabb"][1])
132 |
133 | def load_edges_data(self):
134 | edges = [cv.imread(im_name, 0)[..., None] for im_name in self.edges_list]
135 | self.edges_np = np.stack(edges) / 255.0
136 |
137 | def gen_rays_at(self, img_idx, resolution_level=1):
138 | """
139 | Generate rays at world space from one camera.
140 | """
141 | l = resolution_level
142 | tx = torch.linspace(0, self.W - 1, self.W // l)
143 | ty = torch.linspace(0, self.H - 1, self.H // l)
144 | pixels_x, pixels_y = torch.meshgrid(tx, ty)
145 | p = torch.stack(
146 | [pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1
147 | ) # W, H, 3
148 | p = torch.matmul(
149 | self.intrinsics_all_inv[img_idx, None, None, :3, :3], p[:, :, :, None]
150 | ).squeeze() # W, H, 3
151 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # W, H, 3
152 | depth_scale = rays_v[:, :, 2:]
153 | rays_v = torch.matmul(
154 | self.pose_all[img_idx, None, None, :3, :3], rays_v[:, :, :, None]
155 | ).squeeze() # W, H, 3
156 | rays_o = self.pose_all[img_idx, None, None, :3, 3].expand(
157 | rays_v.shape
158 | ) # W, H, 3
159 | pose = self.pose_all[img_idx] # [4, 4]
160 | intrinsics = self.intrinsics_all[img_idx] # [4, 4]
161 | return (
162 | rays_o.transpose(0, 1).to(self.device),
163 | rays_v.transpose(0, 1).to(self.device),
164 | pose.to(self.device),
165 | intrinsics.to(self.device),
166 | depth_scale.to(self.device),
167 | )
168 |
169 | def gen_one_ray_at(self, img_idx, x, y):
170 | """
171 |
172 | Parameters
173 | ----------
174 | img_idx :
175 | x : for width
176 | y : for height
177 |
178 | Returns
179 | -------
180 |
181 | """
183 |         image = np.uint8(self.edges_np[img_idx] * 255)  # 255, not 256: avoids uint8 overflow for edge values of 1.0
183 | image2 = cv.circle(image, (x, y), radius=10, color=(0, 0, 255), thickness=-1)
184 |
185 | pixels_x = torch.Tensor([x]).long()
186 | pixels_y = torch.Tensor([y]).long()
187 | edge = self.edges[img_idx][(pixels_y, pixels_x)] # batch_size, 1
188 | color = self.colors[img_idx][(pixels_y, pixels_x)] # batch_size, 3
189 | mask = (self.masks[img_idx][(pixels_y, pixels_x)] > 0).to(
190 | torch.float32
191 | ) # batch_size, 1
192 |
193 | p = torch.stack(
194 | [pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1
195 | ).float() # batch_size, 3
196 | p = torch.matmul(
197 | self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]
198 | ).squeeze(
199 | -1
200 | ) # batch_size, 3
201 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3
202 | rays_v = torch.matmul(
203 | self.pose_all[img_idx, None, :3, :3], rays_v[:, :, None]
204 | ).squeeze(
205 | -1
206 | ) # batch_size, 3
207 | rays_o = self.pose_all[img_idx, None, :3, 3].expand(
208 | rays_v.shape
209 | ) # batch_size, 3
210 |
211 | return (
212 | {
213 | "rays_o": rays_o.to(self.device),
214 | "rays_v": rays_v.to(self.device),
215 | "edge": edge.to(self.device),
216 | "color": color.to(self.device),
217 | "mask": mask[:, :1].to(self.device),
218 | },
219 | image2,
220 | )
221 |
222 | def gen_random_rays_patches_at(
223 | self,
224 | img_idx,
225 | batch_size,
226 | importance_sample=False,
227 | ):
228 | """
229 | Generate random rays at world space from one camera.
230 | """
231 |
232 | if not importance_sample:
233 | pixels_x = torch.randint(low=0, high=self.W, size=[batch_size])
234 | pixels_y = torch.randint(low=0, high=self.H, size=[batch_size])
235 | elif (
236 | importance_sample and self.masks is not None
237 | ): # sample more pts in the valid mask regions
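238 |             # Importance sampling: weight each pixel inversely to its class frequency,
239 |             # so sparse edge pixels (value > 0.1) are drawn about as often as the far
240 |             # more numerous background pixels.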
238 | img_np = self.edges[img_idx].cpu().numpy()
239 | edge_density = np.mean(img_np)
240 | probabilities = np.ones_like(img_np) * edge_density
241 | probabilities[img_np > 0.1] = 1.0 - edge_density
242 | probabilities = probabilities.reshape(-1)
243 |
244 | # randomly sample 50%
245 | pixels_x_1 = torch.randint(low=0, high=self.W, size=[batch_size // 2])
246 | pixels_y_1 = torch.randint(low=0, high=self.H, size=[batch_size // 2])
247 |
248 | ys, xs = torch.meshgrid(
249 | torch.linspace(0, self.H - 1, self.H),
250 | torch.linspace(0, self.W - 1, self.W),
251 | ) # pytorch's meshgrid has indexing='ij'
252 | p = torch.stack([xs, ys], dim=-1) # H, W, 2
253 | p_valid = p[self.masks[img_idx][:, :, 0] >= 0] # [num, 2]
254 |
255 | # randomly sample 50% mainly from edge regions
256 | number_list = np.arange(self.image_pixels)
257 | random_idx = random.choices(number_list, probabilities, k=batch_size // 2)
258 | random_idx = torch.from_numpy(np.array(random_idx)).to(
259 | torch.int64
260 | ) # .to(self.device)
261 | p_select = p_valid[random_idx] # [N_rays//2, 2]
262 | pixels_x_2 = p_select[:, 0]
263 | pixels_y_2 = p_select[:, 1]
264 |
265 | pixels_x = torch.cat([pixels_x_1, pixels_x_2], dim=0).to(torch.int64)
266 | pixels_y = torch.cat([pixels_y_1, pixels_y_2], dim=0).to(torch.int64)
267 | # normalized ndc uv coordinates, (-1, 1)
268 | ndc_u = 2 * pixels_x / (self.W - 1) - 1
269 | ndc_v = 2 * pixels_y / (self.H - 1) - 1
270 | rays_ndc_uv = torch.stack([ndc_u, ndc_v], dim=-1).view(-1, 2).float()
271 |
272 | edge = self.edges[img_idx][(pixels_y, pixels_x)] # batch_size, 1
273 | p = torch.stack(
274 | [pixels_x, pixels_y, torch.ones_like(pixels_y)], dim=-1
275 | ).float() # batch_size, 3
276 | p = torch.matmul(
277 | self.intrinsics_all_inv[img_idx, None, :3, :3], p[:, :, None]
278 | ).squeeze() # batch_size, 3
279 |
280 | rays_v = p / torch.linalg.norm(p, ord=2, dim=-1, keepdim=True) # batch_size, 3
281 | depth_scale = rays_v[:, 2:] # batch_size, 1
282 | rays_v = torch.matmul(
283 | self.pose_all[img_idx, None, :3, :3], rays_v[:, :, None]
284 | ).squeeze() # batch_size, 3
285 | rays_o = self.pose_all[img_idx, None, :3, 3].expand(
286 | rays_v.shape
287 | ) # batch_size, 3
288 |
289 | rays = {
290 | "rays_o": rays_o.to(self.device),
291 | "rays_v": rays_v.to(self.device),
292 | "edge": edge.to(self.device),
293 | }
294 |
295 | pose = self.pose_all[img_idx] # [4, 4]
296 | intrinsics = self.intrinsics_all[img_idx] # [4, 4]
297 |
298 | sample = {
299 | "rays": rays,
300 | "pose": pose,
301 | "intrinsics": intrinsics,
302 | "rays_ndc_uv": rays_ndc_uv.to(self.device),
303 | "rays_norm_XYZ_cam": p.to(self.device), # - XYZ_cam, before multiply depth,
304 | "depth_scale": depth_scale.to(self.device),
305 | }
306 |
307 | return sample
308 |
309 | def edge_at(self, idx, resolution_level):
310 | edge = cv.imread(self.edges_list[idx], 0)[..., None]
311 | edge = (
312 | cv.resize(edge, (self.W // resolution_level, self.H // resolution_level))
313 | ).clip(0, 255)
314 | return edge
315 |
316 | def color_at(self, idx, resolution_level):
317 | img = cv.imread(self.colors_list[idx])
318 | img = cv.resize(
319 | img,
320 | (self.W // resolution_level, self.H // resolution_level),
321 | interpolation=cv.INTER_NEAREST,
322 | )
323 | return img
324 |
--------------------------------------------------------------------------------
/src/edge_extraction/edge_fitting/bezier_fit.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.optimize import curve_fit
3 |
4 |
5 | def bezier_curve(tt, p0, p1, p2, p3, p4, p5, p6, p7, p8, p9, p10, p11):
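6 |     # Cubic Bezier in matrix form: B(t) = [t^3, t^2, t, 1] @ W @ P, where W is the
7 |     # cubic Bernstein basis matrix and P is the 4x3 control-point matrix
8 |     # (p0..p11 are its entries flattened row-wise).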
6 | n = len(tt)
7 | matrix_t = np.concatenate(
8 | [(tt**3)[..., None], (tt**2)[..., None], tt[..., None], np.ones((n, 1))],
9 | axis=1,
10 | ).astype(float)
11 | matrix_w = np.array(
12 | [[-1, 3, -3, 1], [3, -6, 3, 0], [-3, 3, 0, 0], [1, 0, 0, 0]]
13 | ).astype(float)
14 | matrix_p = np.array(
15 | [[p0, p1, p2], [p3, p4, p5], [p6, p7, p8], [p9, p10, p11]]
16 | ).astype(float)
17 | return np.dot(np.dot(matrix_t, matrix_w), matrix_p).reshape(-1)
18 |
19 |
20 | def bezier_fit(xyz, error_threshold=1.0):
21 | n = len(xyz)
22 | t = np.linspace(0, 1, n)
23 | xyz = xyz.reshape(-1)
24 |
25 | popt, _ = curve_fit(bezier_curve, t, xyz)
26 |
27 | # Generate fitted curve
28 | fitted_curve = bezier_curve(t, *popt).reshape(-1, 3)
29 |
30 | # Calculate residuals
31 | residuals = xyz.reshape(-1, 3) - fitted_curve
32 |
33 | # Calculate RMSE
34 | rmse = np.sqrt(np.mean(np.sum(residuals**2, axis=1)))
35 |
36 | if rmse > error_threshold:
37 | return None
38 | else:
39 | return popt
40 |
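41 |
42 | if __name__ == "__main__":
43 |     # Minimal self-check sketch (added for illustration, not part of the original
44 |     # pipeline): sample a known cubic Bezier, add small noise, and verify the fit
45 |     # is accepted under the default RMSE threshold.
46 |     t_demo = np.linspace(0, 1, 50)
47 |     ctrl = [0, 0, 0, 1, 2, 0, 2, -1, 1, 3, 0, 0]  # four 3D control points, flattened
48 |     xyz_demo = bezier_curve(t_demo, *ctrl).reshape(-1, 3)
49 |     noisy = xyz_demo + np.random.normal(scale=1e-3, size=xyz_demo.shape)
50 |     assert bezier_fit(noisy, error_threshold=1.0) is not None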
--------------------------------------------------------------------------------
/src/edge_extraction/edge_fitting/line_fit.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 | def split_into_monotonic_sublists(numbers, max_longsublists=2, min_length=4):
5 | if not numbers:
6 | return []
7 |
8 | # Initialize list to store continuous and monotonic sublists
9 | monotonic_sublists = []
10 | current_sublist = [numbers[0]]
11 |
12 | # Create continuous and monotonic sublists from the original list
13 | for i in range(1, len(numbers)):
14 | if numbers[i] == numbers[i - 1] + 1:
15 | current_sublist.append(numbers[i])
16 | else:
17 | if len(current_sublist) > 1:
18 | monotonic_sublists.append(tuple(current_sublist))
19 | current_sublist = [numbers[i]]
20 |
21 | # Add the last continuous and monotonic sublist if it contains more than one element
22 | if len(current_sublist) > 1:
23 | monotonic_sublists.append(tuple(current_sublist))
24 |
25 | # Convert to set and back to list to remove duplicates
26 | monotonic_sublists = list(set(monotonic_sublists))
27 | # Sort the sublists by length in descending order
28 | monotonic_sublists.sort(key=len, reverse=True)
29 |
30 | # Keep the specified number of longest sublists
31 | max_sublists = min(max_longsublists, len(monotonic_sublists))
32 | long_sublists = monotonic_sublists[:max_sublists]
33 | short_sublists = monotonic_sublists[max_sublists:]
34 |
35 | # Handle sublists that are too short
36 | sublists_out_curves = []
37 | for sublist in long_sublists:
38 | if len(sublist) < min_length:
39 | short_sublists.append(sublist)
40 | else:
41 | sublists_out_curves.append(list(sublist))
42 |
43 | # Split the remaining short sublists into pairs of numbers
44 | sublists_out_lines = []
45 | for sublist in short_sublists:
46 | for j in range(len(sublist) - 1):
47 | sublists_out_lines.append([sublist[j], sublist[j + 1]])
48 |
49 | return [list(t) for t in sublists_out_curves], [list(t) for t in sublists_out_lines]
50 |
51 |
52 | def fit_line_ransac_3d(
53 | points_wld,
54 | voxel_size=256,
55 | max_iterations=100,
56 | min_inliers=4,
57 | max_lines=3,
58 | max_curves=2,
59 | keep_short_lines=False,
60 | ransac_with_direction=False,
61 | ):
62 | """
63 | Fit multiple lines to 3D points using RANSAC.
64 |
65 | Parameters:
66 | - points (numpy.ndarray): Array of 3D points.
67 | - voxel_size (float): Voxel size for inlier distance threshold.
68 | - max_iterations (int): Maximum number of RANSAC iterations.
69 | - min_inliers (int): Minimum number of inliers required to consider a line.
70 | - max_lines (int): Maximum number of lines to fit.
71 |
72 | Returns:
73 | - best_endpoints (list): List of line endpoints (start and end points).
74 | - split_points (list): List of points belonging to each fitted line.
75 | - remaining_points (numpy.ndarray): Points not assigned to any line.
76 | - remaining_point_indices (list): Indices of remaining points in the original input points.
77 | """
78 | inlier_dist_threshold = 1.0 / voxel_size # 1.0 / voxel_size
79 | best_lines = []
80 | best_endpoints = []
81 | split_points = []
83 | N_points = len(points_wld)
84 | remaining_point_indices = np.arange(N_points)
85 | min_inlier_ratio = 1.0 / max_lines
86 | raw_points_wld = points_wld.copy()
87 |
88 | while max_lines and len(points_wld) >= min_inliers:
89 | max_lines -= 1
90 | best_line = None
91 | best_inliers_mask = None
92 | best_num_inliers = 0
93 |
94 | if not ransac_with_direction:
95 | for _ in range(max_iterations):
96 | # Generate all unique combinations of point pairs
97 | sample_indices = np.random.choice(len(points_wld), 2, replace=False)
98 | sample_points_wld = points_wld[sample_indices, :3]
99 |
100 | p1, p2 = sample_points_wld
101 | direction = p2 - p1
102 |
103 | if np.linalg.norm(direction) < 1e-6:
104 | continue
105 |
106 | direction /= np.linalg.norm(direction)
107 |
108 | distances = np.linalg.norm(
109 | np.cross(points_wld[:, :3] - p1, direction), axis=1
110 | )
111 |
112 | inliers_mask = distances < inlier_dist_threshold
113 | num_inliers = np.sum(inliers_mask)
114 |
115 | if num_inliers > best_num_inliers:
116 | best_line = (p1, direction)
117 | best_num_inliers = num_inliers
118 | best_inliers_mask = inliers_mask
119 |
120 | else:
121 | points = points_wld[:, :3]
122 | direction = points_wld[:, 3:]
123 | normalized_direction = direction / np.linalg.norm(
124 | direction, axis=1, keepdims=True
125 | )
126 |
127 | distance = np.linalg.norm(
128 | np.cross(points - points[:, None], normalized_direction), axis=2
129 | ) # N x N
130 | inlier_mask = distance < inlier_dist_threshold
131 | num_inliers = np.sum(inlier_mask, axis=1)
132 |
133 | best_inlier_idx = np.argmax(num_inliers)
134 | best_line = (points[best_inlier_idx], direction[best_inlier_idx])
135 | best_inliers_mask = inlier_mask[best_inlier_idx]
136 | best_num_inliers = num_inliers[best_inlier_idx]
137 |
138 | if best_num_inliers >= min_inliers:
139 | p1, direction = best_line
140 | inlier_points = points_wld[best_inliers_mask, :3]
141 | inlier_ratio_pred = best_num_inliers / N_points
142 | if inlier_ratio_pred < min_inlier_ratio:
143 | break
144 |
145 | center = np.mean(inlier_points, axis=0)
146 | endpoints_centered = inlier_points - center
147 | u, s, vh = np.linalg.svd(endpoints_centered, full_matrices=False)
148 | updated_direction = vh[0]
149 | updated_direction = updated_direction / np.linalg.norm(updated_direction)
150 |
151 | projections = np.dot(inlier_points - p1, updated_direction)
152 | line_segment = np.zeros(6)
153 | line_segment[:3] = p1 + np.min(projections) * updated_direction
154 | line_segment[3:] = p1 + np.max(projections) * updated_direction
155 |
156 | points_wld = points_wld[~best_inliers_mask]
157 | split_points.append(inlier_points.tolist())
158 | remaining_point_indices = remaining_point_indices[~best_inliers_mask]
159 |
160 | best_lines.append(best_line)
161 | best_endpoints.append(line_segment)
162 |
163 | # find potential curve points
164 | if len(remaining_point_indices) > 0:
165 | potential_curve_indices, shortline_indices = split_into_monotonic_sublists(
166 | remaining_point_indices.tolist(), max_curves
167 | )
168 | potential_curve_points = [
169 | raw_points_wld[potential_curve_indice, :3]
170 | for potential_curve_indice in potential_curve_indices
171 | ]
172 | if keep_short_lines and len(shortline_indices) > 0:
173 | shortline_points = raw_points_wld[shortline_indices, :3]
174 | shortline_points = shortline_points.reshape(-1, 6)
175 | best_endpoints.extend(shortline_points)
176 | split_points.extend(shortline_points.reshape(-1, 2, 3).tolist())
177 | else:
178 | potential_curve_points = []
179 |
180 | return best_endpoints, split_points, potential_curve_points
181 |
182 |
183 | def line_fitting(endpoints):
184 | center = np.mean(endpoints, axis=0)
185 |
186 | # compute the main direction through SVD
187 | endpoints_centered = endpoints - center
188 | u, s, vh = np.linalg.svd(endpoints_centered, full_matrices=False)
189 | lamda = s[0] / np.sum(s)
190 | main_direction = vh[0]
191 | main_direction = main_direction / np.linalg.norm(main_direction)
192 |
193 | # project endpoints onto the main direction
194 | projections = []
195 | for endpoint_centered in endpoints_centered:
196 | projections.append(np.dot(endpoint_centered, main_direction))
197 | projections = np.array(projections)
198 |
199 | # construct final line
200 | straight_line = np.zeros(6)
201 | # print(np.min(projections), np.max(projections))
202 | straight_line[:3] = center + main_direction * np.min(projections)
203 | straight_line[3:] = center + main_direction * np.max(projections)
204 |
205 | return straight_line, lamda
206 |
207 |
208 | def lines_fitting(lines, lamda_threshold):
209 | straight_lines = []
210 | curve_line_segments_candidate = []
211 | curves = []
212 | lamda_list = []
213 | for endpoints in lines:
214 | # merge line segments into a final line segment
215 | # total least squares on endpoints
216 | straight_line, lamda = line_fitting(endpoints)
217 | lamda_list.append(lamda)
218 | if lamda < lamda_threshold:
219 | curves.append(endpoints)
220 | curve_line_segments_candidate.append(
221 | [
222 | np.hstack([endpoints[i], endpoints[i + 1]])
223 | for i in range(len(endpoints) - 1)
224 | ]
225 | )
226 | continue
227 |
228 | straight_lines.append(straight_line)
229 | straight_lines = np.array(straight_lines)
230 | curve_line_segments_candidate = np.array(curve_line_segments_candidate)
231 | return (
232 | straight_lines,
233 | curve_line_segments_candidate,
234 | curves,
235 | lamda_list,
236 | )
237 |
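238 |
239 | if __name__ == "__main__":
240 |     # Minimal sketch (added for illustration): recover one segment from noisy
241 |     # points on a known 3D line; columns 3:6 carry the per-point direction that
242 |     # fit_line_ransac_3d expects.
243 |     rng = np.random.default_rng(0)
244 |     direction = np.array([1.0, 2.0, 0.5])
245 |     direction /= np.linalg.norm(direction)
246 |     pts = rng.uniform(0, 1, (100, 1)) * direction + rng.normal(scale=1e-4, size=(100, 3))
247 |     points_wld = np.hstack([pts, np.tile(direction, (100, 1))])
248 |     segments, inlier_sets, curve_candidates = fit_line_ransac_3d(points_wld)
249 |     print(f"fitted {len(segments)} segment(s)")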
--------------------------------------------------------------------------------
/src/edge_extraction/edge_fitting/main.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from scipy.spatial.distance import cdist
3 | import open3d as o3d
4 | import random
5 | from src.edge_extraction.edge_fitting.bezier_fit import bezier_fit, bezier_curve
6 | from src.edge_extraction.edge_fitting.line_fit import fit_line_ransac_3d
7 |
8 |
9 | class LineSegment:
10 | def __init__(self, start_point, end_point):
11 | self.start_point = start_point
12 | self.end_point = end_point
13 |
14 |
15 | def generate_segments_from_idx(connected_lines, points_wld):
16 | segments = []
17 | polylines_wld = []
18 | for connected_line in connected_lines:
19 | polyline_wld = [points_wld[connected_line[0]].tolist()]
20 | for i in range(len(connected_line) - 1):
21 | segment = [
22 | points_wld[connected_line[i]].tolist(),
23 | points_wld[connected_line[i + 1]].tolist(),
24 | ]
25 | segments.append(segment)
26 | polyline_wld.append(points_wld[connected_line[i + 1]])
27 | polyline_wld = np.array(polyline_wld).reshape(-1, 6)
28 | polylines_wld.append(polyline_wld)
29 |
30 | return np.array(segments).reshape(-1, 6), polylines_wld
31 |
32 |
33 | def create_line_segments_from_3d_array(segment_data):
34 | x1, y1, z1, x2, y2, z2 = (
35 | segment_data[:, 0],
36 | segment_data[:, 1],
37 | segment_data[:, 2],
38 | segment_data[:, 3],
39 | segment_data[:, 4],
40 | segment_data[:, 5],
41 | )
42 | segments = []
43 | for i in range(len(x1)):
44 | segments.append(
45 | LineSegment(
46 | np.array([x1[i], y1[i], z1[i]]), np.array([x2[i], y2[i], z2[i]])
47 | )
48 | )
49 | return segments
50 |
51 |
52 | def is_point_inside_ranges(point, ranges):
53 | point = np.array(point)
54 | if not np.all(point > ranges[0]) or not np.all(point < ranges[1]):
55 | return False
56 | return True
57 |
58 |
59 | def is_line_inside_ranges(line_segment, ranges):
60 | if not is_point_inside_ranges(line_segment.start_point, ranges):
61 | return False
62 | if not is_point_inside_ranges(line_segment.end_point, ranges):
63 | return False
64 | return True
65 |
66 |
67 | def create_open3d_line_set(
68 | line_segments, color=[0.5, 0.5, 0.5], width=2, ranges=None, scale=1.0
69 | ):
70 | o3d_points, o3d_lines, o3d_colors = [], [], []
71 | counter = 0
72 | for line_segment in line_segments:
73 | if ranges is not None:
74 | if not is_line_inside_ranges(line_segment, ranges):
75 | continue
76 | o3d_points.append(line_segment.start_point * scale)
77 | o3d_points.append(line_segment.end_point * scale)
78 |         o3d_lines.append([2 * counter, 2 * counter + 1])
79 |         o3d_colors.append(color)  # one color per segment so Open3D keeps points, lines, and colors aligned
80 |         counter += 1
80 | line_set = o3d.geometry.LineSet()
81 | line_set.points = o3d.utility.Vector3dVector(o3d_points)
82 | line_set.lines = o3d.utility.Vector2iVector(o3d_lines)
83 | line_set.colors = o3d.utility.Vector3dVector(o3d_colors)
84 | return line_set
85 |
86 |
87 | def save_3d_lines_to_file(line_segments, filename, width=2, ranges=None, scale=1.0):
88 | lines = create_line_segments_from_3d_array(line_segments)
89 | line_set = create_open3d_line_set(lines, width=width, ranges=ranges, scale=scale)
90 | o3d.io.write_line_set(filename, line_set)
91 |
92 |
93 | def connect_points(
94 | points, distance_threshold, angle_threshold, nms_factor, keep_short_lines
95 | ):
96 | num_points = len(points)
97 | connected_line_segments = []
98 |
99 | unvisited_points = set(range(num_points))
100 |
101 | while len(unvisited_points) > 0:
102 | selected_point = np.random.choice(list(unvisited_points))
103 | selected_point_opposite = selected_point
104 |
105 | unvisited_points.remove(selected_point)
106 | connected_line = [selected_point]
107 |
108 | while True: # forward connection
109 | dist = cdist(
110 | [points[selected_point, :3]], points[list(unvisited_points), :3]
111 | )
112 | neighboring_points = np.where(dist < distance_threshold)[1]
113 | neighboring_distance = dist[0, neighboring_points].reshape(-1)
114 |
115 | neighboring_points = (
116 | np.array(list(unvisited_points))[neighboring_points]
117 | ).tolist()
118 |
119 | if len(neighboring_points) == 0:
120 | break
121 |
122 | directions = (
123 | points[neighboring_points, :3] - points[selected_point, :3][None, ...]
124 | )
125 | directions /= np.linalg.norm(directions, axis=1)[:, np.newaxis] + 1e-6
126 |
127 | dot_products = np.dot(directions, points[selected_point, 3:])
128 |
129 | closest_point_idx = np.argmax(dot_products)
130 |
131 | if (
132 | dot_products[closest_point_idx] <= 1 - angle_threshold
133 | ): # no suitable point found
134 | break
135 |
136 | connected_line.append(
137 | neighboring_points[closest_point_idx]
138 | ) # add the point to the line
139 |
140 | invalid_points_idx = np.where(
141 | (neighboring_distance <= neighboring_distance[closest_point_idx])
142 | * (dot_products < dot_products[closest_point_idx])
143 | * (dot_products >= nms_factor * dot_products[closest_point_idx])
144 | )[0]
145 | invalid_points = np.array(neighboring_points)[invalid_points_idx].tolist()
146 |
147 | unvisited_points.difference_update(invalid_points)
148 |
149 | if (
150 | np.dot(
151 | points[neighboring_points[closest_point_idx], 3:],
152 | directions[closest_point_idx],
153 | )
154 | <= 0.5
155 |             ):  # stop if the connected point's stored direction flips against the travel direction
156 | break
157 |
158 | unvisited_points.remove(
159 | neighboring_points[closest_point_idx]
160 | ) # remove connected point from unvisited points set
161 | selected_point = neighboring_points[
162 | closest_point_idx
163 | ] # update anchor point
164 |
165 | while True: # backward connection
166 | dist = cdist(
167 | [points[selected_point_opposite, :3]],
168 | points[list(unvisited_points), :3],
169 | )
170 | neighboring_points = np.where(dist < distance_threshold)[1]
171 | neighboring_distance = dist[0, neighboring_points].reshape(-1)
172 |
173 | neighboring_points = (
174 | np.array(list(unvisited_points))[neighboring_points]
175 | ).tolist()
176 |
177 | if len(neighboring_points) == 0:
178 | break
179 |
180 | directions = (
181 | points[neighboring_points, :3]
182 | - points[selected_point_opposite, :3][None, ...]
183 | )
184 | directions /= np.linalg.norm(directions, axis=1)[:, np.newaxis] + 1e-6
185 |
186 | dot_products = np.dot(directions, points[selected_point_opposite, 3:])
187 |
188 | closest_point_idx = np.argmin(dot_products)
189 |
190 | if (
191 | abs(dot_products[closest_point_idx]) <= 1 - angle_threshold
192 | or dot_products[closest_point_idx] >= 0
193 | ):
194 | break
195 |
196 | connected_line.insert(
197 | 0, neighboring_points[closest_point_idx]
198 | ) # add connected point to the beginning of the line
199 |
200 | invalid_points_idx = np.where(
201 | (neighboring_distance <= neighboring_distance[closest_point_idx])
202 | * (dot_products > dot_products[closest_point_idx])
203 | * (dot_products <= nms_factor * dot_products[closest_point_idx])
204 | )[0]
205 | invalid_points = np.array(neighboring_points)[invalid_points_idx].tolist()
206 |
207 | unvisited_points.difference_update(invalid_points)
208 |
209 | if (
210 | np.dot(
211 | -points[neighboring_points[closest_point_idx], 3:],
212 | directions[closest_point_idx],
213 | )
214 | <= 0.5
215 |             ):  # same direction-consistency check, mirrored for the backward pass
216 | break
217 |
218 | unvisited_points.remove(neighboring_points[closest_point_idx])
219 | selected_point_opposite = neighboring_points[closest_point_idx]
220 |
221 | if not keep_short_lines:
222 | if len(connected_line) > 3:
223 | connected_line_segments.append(connected_line)
224 | else:
225 | if len(connected_line) > 1:
226 | connected_line_segments.append(connected_line)
227 |
228 | return connected_line_segments
229 |
230 |
231 | def edge_fitting(
232 | polylines_wld,
233 | voxel_size=256,
234 | max_iterations=100,
235 | min_inliers=4,
236 | max_lines=3,
237 | max_curves=2,
238 | keep_short_lines=True,
239 | ):
240 | straight_lines = []
241 | raw_points_on_straight_lines = []
242 | bezier_curve_params = []
243 | bezier_curve_points = []
244 | raw_points_on_curves = []
245 | t_fit = np.linspace(0, 1, 100)
246 | for endpoints_wld in polylines_wld:
247 | if len(endpoints_wld) < 4 and keep_short_lines: # keep short lines
248 |
249 | for i in range(len(endpoints_wld) - 1):
250 | segment = [
251 | endpoints_wld[i, :3],
252 | endpoints_wld[i + 1, :3],
253 | ]
254 | straight_lines.append(np.array(segment).reshape(-1))
255 | raw_points_on_straight_lines.extend(
256 | [np.array(segment).reshape(-1, 3).tolist()]
257 | )
258 | else:
259 | (
260 | straight_line,
261 | split_points,
262 | potential_curve_points,
263 | ) = fit_line_ransac_3d(
264 | endpoints_wld,
265 | voxel_size,
266 | max_iterations,
267 | min_inliers,
268 | max_lines,
269 | max_curves,
270 | keep_short_lines,
271 | )
272 |
273 | if len(split_points) >= 1: # fitted straight lines exist
274 | straight_lines.extend(straight_line)
275 | raw_points_on_straight_lines.extend(split_points)
276 |
277 | if len(potential_curve_points) >= 1: # fitted curves exist
278 | for curve_points in potential_curve_points:
279 | p = bezier_fit(curve_points, error_threshold=5.0 / voxel_size)
280 | if p is None:
281 |                         print("Bezier fitting error too high; skipping this curve")
282 | continue
283 | bezier_curve_params.append(p)
284 |
285 | xyz_fit = bezier_curve(t_fit, *p).reshape(-1, 3)
286 | bezier_curve_points.append(xyz_fit)
287 | raw_points_on_curves.append(curve_points.tolist())
288 |
289 | straight_lines = np.array(straight_lines)
290 |
291 | if len(bezier_curve_points) >= 1:
292 | bezier_curve_points = np.concatenate(bezier_curve_points, axis=0)
293 | bezier_curve_params = np.array(bezier_curve_params)
294 |
295 | return (
296 | straight_lines,
297 | raw_points_on_straight_lines,
298 | bezier_curve_params,
299 | bezier_curve_points,
300 | raw_points_on_curves,
301 | )
302 |
303 |
304 | def edge_fit(
305 | edge_data=None,
306 | angle_threshold=0.03,
307 | nms_factor=0.9,
308 | fit_distance_threshold=10.0,
309 | min_inliers=4,
310 | max_lines=4,
311 | max_curves=3,
312 | keep_short_lines=True,
313 | ):
314 | res = np.array(edge_data["resolution"])
315 | raw_points = np.array(edge_data["points"])
316 | raw_ld_colors = np.array(edge_data["ld_colors"])
317 | pcd = o3d.geometry.PointCloud()
318 | pcd.points = o3d.utility.Vector3dVector(raw_points)
319 | pcd.colors = o3d.utility.Vector3dVector(raw_ld_colors)
320 | fit_distance_threshold = fit_distance_threshold / res
321 | pcd = pcd.voxel_down_sample(voxel_size=2.0 / res)
322 |
323 | points = np.asarray(pcd.points)
324 | ld_colors = np.asarray(pcd.colors)
325 | linedirection = ld_colors * 2 - 1 # recover the line direction
326 | linedirection = linedirection / (
327 | np.linalg.norm(linedirection, axis=1)[:, np.newaxis] + 1e-6
328 | )
329 | points_wld = np.concatenate((points, linedirection), axis=1)
330 |
331 | connected_line = connect_points(
332 | points_wld,
333 | fit_distance_threshold,
334 | angle_threshold,
335 | nms_factor,
336 | keep_short_lines,
337 | )
338 |
339 | _, polylines_wld = generate_segments_from_idx(connected_line, points_wld)
340 |
341 | (
342 | straight_lines,
343 | raw_points_on_straight_lines,
344 | bezier_curve_params,
345 | bezier_curve_points,
346 | raw_points_on_curves,
347 | ) = edge_fitting(
348 | polylines_wld,
349 | voxel_size=res,
350 | max_iterations=100,
351 | min_inliers=min_inliers,
352 | max_lines=max_lines,
353 | max_curves=max_curves,
354 | keep_short_lines=keep_short_lines,
355 | )
356 |
357 | fitted_edge_dict = {
358 | "resolution": int(res),
359 | "lines_end_pts": straight_lines.tolist() if len(straight_lines) > 0 else [],
360 | "raw_points_on_lines": (
361 | raw_points_on_straight_lines
362 | if len(raw_points_on_straight_lines) > 0
363 | else []
364 | ),
365 | "curves_ctl_pts": (
366 | bezier_curve_params.tolist() if len(bezier_curve_params) > 0 else []
367 | ),
368 | "raw_points_on_curves": (
369 | raw_points_on_curves if len(raw_points_on_curves) > 0 else []
370 | ),
371 | }
372 |
373 | return fitted_edge_dict
374 |
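375 | # Usage sketch (illustrative only, not executed by the pipeline): `edge_fit`
376 | # consumes a dict holding the grid resolution, the extracted edge points, and
377 | # per-point line directions encoded as colors in [0, 1]. `edge_points` and
378 | # `line_dirs` below are hypothetical arrays standing in for real outputs of
379 | # the point-cloud extraction step.
380 | #
381 | # edge_data = {
382 | #     "resolution": 256,
383 | #     "points": edge_points,             # (N, 3) edge point positions
384 | #     "ld_colors": (line_dirs + 1) / 2,  # (N, 3) unit directions mapped to [0, 1]
385 | # }
386 | # fitted = edge_fit(edge_data=edge_data)
387 | # print(len(fitted["lines_end_pts"]), "lines,", len(fitted["curves_ctl_pts"]), "curves")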
--------------------------------------------------------------------------------
/src/edge_extraction/extract_parametric_edge.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import json
4 | from pathlib import Path
5 |
6 | import cv2
7 | import copy
8 | import math
9 | from src.edge_extraction.edge_fitting.main import edge_fit
10 | from src.edge_extraction.merging.main import merge
11 | from src.edge_extraction.extract_util import bezier_curve_length
12 |
13 |
14 | def load_from_json(filename: Path):
15 | """Load a dictionary from a JSON filename.
16 |
17 | Args:
18 | filename: The filename to load from.
19 | """
20 | assert filename.suffix == ".json"
21 | with open(filename, encoding="UTF-8") as file:
22 | return json.load(file)
23 |
24 |
25 | def get_edge_maps(data_dir, detector):
26 | meta = load_from_json(Path(data_dir) / "meta_data.json")
27 | h, w = meta["height"], meta["width"]
28 | edges_list, intrinsics_list, camtoworld_list = [], [], []
29 | for idx, frame in enumerate(meta["frames"]):
30 | intrinsics = np.array(frame["intrinsics"])
31 | camtoworld = np.array(frame["camtoworld"])[:4, :4]
32 | if detector == "DexiNed":
33 | edges_list.append(
34 | os.path.join(
35 | data_dir,
36 | "edge_DexiNed",
37 | frame["rgb_path"],
38 | )
39 | )
40 | elif detector == "PidiNet":
41 | edges_list.append(
42 | os.path.join(
43 | data_dir,
44 | "edge_PidiNet",
45 | frame["rgb_path"][:-4] + ".png",
46 | )
47 | )
48 | else:
49 | raise ValueError(f"Unknown detector: {detector}")
50 | intrinsics_list.append(intrinsics)
51 | camtoworld_list.append(camtoworld)
52 |
53 | edges = [cv2.imread(im_name, 0)[..., None] for im_name in edges_list]
54 |
55 | if detector == "DexiNed":
56 | edges = 1 - np.stack(edges) / 255.0
57 | elif detector == "PidiNet":
58 | edges = np.stack(edges) / 255.0
59 |
60 | intrinsics_list = np.stack(intrinsics_list)
61 | camtoworld_list = np.stack(camtoworld_list)
62 | return edges, intrinsics_list, camtoworld_list, h, w
63 |
64 |
65 | def process_geometry_data(
66 | edge_dict,
67 | worldtogt=None,
68 | valid_curve=None,
69 | valid_line=None,
70 | sample_resolution=0.005,
71 | ):
72 | """
73 | Processes edge data to transform and sample points from geometric data (curves and lines).
74 | Optionally transforms points to a target coordinate system and filters specific geometries.
75 |
76 | Parameters:
77 | edge_dict (dict): Dictionary containing 'curves_ctl_pts' and 'lines_end_pts'.
78 | worldtogt (np.array, optional): Transformation matrix to convert points to another coordinate system.
79 | valid_curve (np.array, optional): Indices to filter specific curves.
80 | valid_line (np.array, optional): Indices to filter specific lines.
81 | sample_resolution (float): Sampling resolution for generating points.
82 |
83 | Returns:
84 |     np.array: Points sampled along the (optionally transformed) curves and lines.
85 |     dict: Edge dictionary with the filtered curve control points and line endpoints.
86 | """
87 | # Process curves
88 | return_edge_dict = {}
89 | curve_data = edge_dict["curves_ctl_pts"]
90 | curve_paras = np.array(curve_data).reshape(-1, 12)
91 | if valid_curve is not None:
92 | curve_paras = curve_paras[valid_curve]
93 | curve_paras = curve_paras.reshape(-1, 4, 3)
94 | return_edge_dict["curves_ctl_pts"] = curve_paras.tolist()
95 |
96 | if worldtogt is not None:
97 | curve_paras = curve_paras @ worldtogt[:3, :3].T + worldtogt[:3, 3]
98 |
99 | # Process lines
100 | line_data = edge_dict["lines_end_pts"]
101 | lines = np.array(line_data).reshape(-1, 6)
102 | if valid_line is not None:
103 | lines = lines[valid_line]
104 |
105 | return_edge_dict["lines_end_pts"] = lines.tolist()
106 |
107 | lines = lines.reshape(-1, 2, 3)
108 |
109 | if worldtogt is not None:
110 | lines = lines @ worldtogt[:3, :3].T + worldtogt[:3, 3]
111 |
112 | all_points = []
113 |
114 | # Sample curves
115 | for curve in curve_paras:
116 | sample_num = int(
117 | bezier_curve_length(curve, num_samples=100) // sample_resolution
118 | )
119 | t = np.linspace(0, 1, sample_num)
120 | coefficients = np.array(
121 | [[-1, 3, -3, 1], [3, -6, 3, 0], [-3, 3, 0, 0], [1, 0, 0, 0]]
122 | )
123 | matrix_u = np.array([t**3, t**2, t, np.ones_like(t)])
124 | points = matrix_u.T.dot(coefficients).dot(curve)
125 | all_points.extend(points.tolist())
126 |
127 | # Sample lines
128 | for line in lines:
129 | sample_num = int(np.linalg.norm(line[0] - line[1]) // sample_resolution)
130 | t = np.linspace(0, 1, sample_num)
131 | line_points = np.outer(t, line[1] - line[0]) + line[0]
132 | all_points.extend(line_points.tolist())
133 |
134 | return np.array(all_points, dtype=np.float32), return_edge_dict
135 |
136 |
137 | def compute_visibility(
138 | all_curve_points,
139 | all_line_points,
140 | edges,
141 | intrinsics_list,
142 | camtoworld_list,
143 | h,
144 | w,
145 | edge_visibility_threshold,
146 | edge_visibility_frames,
147 | ):
148 | img_frames = len(edges)
149 | curve_num, line_num = len(all_curve_points), len(all_line_points)
150 | edge_num = curve_num + line_num
151 | edge_visibility_matrix = np.zeros((edge_num, img_frames))
152 |
153 | for frame_idx, (edge_map, intrinsic, camtoworld) in enumerate(
154 | zip(edges, intrinsics_list, camtoworld_list)
155 | ):
156 | K = intrinsic[:3, :3]
157 | worldtocam = np.linalg.inv(camtoworld)
158 | R = worldtocam[:3, :3]
159 | T = worldtocam[:3, 3:]
160 |
161 | all_curve_uv, all_line_uv = project2D(
162 | K, R, T, copy.deepcopy(all_curve_points), copy.deepcopy(all_line_points)
163 | )
164 | all_edge_uv = all_curve_uv + all_line_uv
165 |
166 | for edge_idx, edge_uv in enumerate(all_edge_uv):
167 | edge_uv = np.array(edge_uv)
168 | # print(edge_uv.shape)
169 | if len(edge_uv) == 0:
170 | continue
171 | edge_uv = np.round(edge_uv).astype(np.int32)
172 | edge_u = edge_uv[:, 0]
173 | edge_v = edge_uv[:, 1]
174 |
175 | valid_edge_uv = edge_uv[
176 | (edge_u >= 0) & (edge_u < w) & (edge_v >= 0) & (edge_v < h)
177 | ]
178 | visibility = 0
179 |
180 | if len(valid_edge_uv) > 0:
181 | projected_edge = edge_map[valid_edge_uv[:, 1], valid_edge_uv[:, 0]]
182 | visibility = float(
183 | np.mean(projected_edge) > edge_visibility_threshold
184 | and np.max(projected_edge) > 0.5
185 | )
186 | edge_visibility_matrix[edge_idx, frame_idx] = visibility
187 |
188 | return np.sum(edge_visibility_matrix, axis=1) > edge_visibility_frames
189 |
190 |
191 | def project2D(K, R, T, all_curve_points, all_line_points):
192 | all_curve_uv, all_line_uv = [], []
193 | for curve_points in all_curve_points:
194 | curve_points = np.array(curve_points).reshape(-1, 3)
195 | curve_uv = project2D_single(K, R, T, curve_points)
196 | all_curve_uv.append(curve_uv)
197 | for line_points in all_line_points:
198 | line_points = np.array(line_points).reshape(-1, 3)
199 | line_uv = project2D_single(K, R, T, line_points)
200 | all_line_uv.append(line_uv)
201 | return all_curve_uv, all_line_uv
202 |
203 |
204 | def project2D_single(K, R, T, points3d):
205 | shape = points3d.shape
206 | assert shape[-1] == 3
207 | X = points3d.reshape(-1, 3)
208 |
209 | x = K @ (R @ X.T + T)
210 | x = x.T
211 | x = x / x[:, -1:]
212 | x = x.reshape(*shape)[..., :2].reshape(-1, 2).tolist()
213 | return x
214 |
215 |
216 | def get_parametric_edge(
217 | edge_dict,
218 | visible_checking=False,
219 | ):
220 |
221 | detector = edge_dict["detector"]
222 | scene_name = edge_dict["scene_name"]
223 | dataset_dir = edge_dict["dataset_dir"]
224 | result_dir = edge_dict["result_dir"]
225 | meta_data_dir = os.path.join(dataset_dir, scene_name)
226 |
227 | # fixed parameters, but can be fine-tuned for better edge extraction
228 | is_merge = True
229 | nms_factor = 0.95
230 | angle_threshold = 0.03
231 | fit_distance_threshold = 10.0
232 | min_inliers = 5
233 | max_lines = 4
234 | max_curves = 3
235 | merge_edge_distance_threshold = 5.0
236 | merge_endpoints_distance_threshold = 2.0
237 | merge_similarity_threshold = 0.98
238 |
239 | fitted_edge_dict = edge_fit(
240 | edge_data=edge_dict,
241 | angle_threshold=angle_threshold,
242 | nms_factor=nms_factor,
243 | fit_distance_threshold=fit_distance_threshold,
244 | min_inliers=min_inliers,
245 | max_lines=max_lines,
246 | max_curves=max_curves,
247 | )
248 | if is_merge:
249 | merged_edge_dict = merge(
250 | result_dir,
251 | fitted_edge_dict,
252 | merge_edge_distance_threshold=merge_edge_distance_threshold,
253 | merge_endpoints_distance_threshold=merge_endpoints_distance_threshold,
254 | merge_similarity_threshold=merge_similarity_threshold,
255 | )
256 |
257 | if visible_checking:
258 | _, return_edge_dict = process_geometry_data(merged_edge_dict)
259 | all_curve_points = return_edge_dict["curves_ctl_pts"]
260 | all_line_points = return_edge_dict["lines_end_pts"]
261 | edges, intrinsics_list, camtoworld_list, h, w = get_edge_maps(
262 | meta_data_dir, detector
263 | )
264 | edge_visibility_threshold = 0.5
265 | edge_visibility_frames_ratio = 0.1
266 | num_frames = len(edges)
267 | edge_visibility_frames = math.ceil(edge_visibility_frames_ratio * num_frames)
268 |
269 | edge_visibility = compute_visibility(
270 | all_curve_points,
271 | all_line_points,
272 | edges,
273 | intrinsics_list,
274 | camtoworld_list,
275 | h,
276 | w,
277 | edge_visibility_threshold,
278 | edge_visibility_frames,
279 | )
280 | curve_visibility = edge_visibility[: len(all_curve_points)]
281 | line_visibility = edge_visibility[len(all_curve_points) :]
282 |
283 | print(
284 |             "edges before visibility check: ",
285 |             len(all_curve_points) + len(all_line_points),
286 |             "edges after visibility check: ",
287 | np.sum(edge_visibility),
288 | )
289 | worldtogt = np.eye(4)
290 | (pred_points, return_edge_dict) = process_geometry_data(
291 | merged_edge_dict, worldtogt, curve_visibility, line_visibility
292 | )
293 |
294 | else:
295 | worldtogt = np.eye(4)
296 | (pred_points, return_edge_dict) = process_geometry_data(
297 | merged_edge_dict, worldtogt, None, None
298 | )
299 |
300 | return pred_points, return_edge_dict
301 |
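302 |
303 | # Minimal sanity check for the projection helper (an illustrative sketch, not
304 | # part of the pipeline): a point on the optical axis of an identity-pose camera
305 | # should project exactly onto the principal point.
306 | if __name__ == "__main__":
307 |     K = np.array([[100.0, 0.0, 50.0], [0.0, 100.0, 50.0], [0.0, 0.0, 1.0]])
308 |     R, T = np.eye(3), np.zeros((3, 1))
309 |     uv = project2D_single(K, R, T, np.array([[0.0, 0.0, 2.0]]))
310 |     print(uv)  # expected: [[50.0, 50.0]]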
--------------------------------------------------------------------------------
/src/edge_extraction/extract_pointcloud.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.nn import functional as F
3 |
4 |
5 | def get_udf_normals_grid(
6 | func,
7 | func_grad,
8 | N,
9 | udf_threshold,
10 | is_linedirection=False,
11 | sampling_N=50,
12 | sampling_delta=0.005,
13 | max_batch=int(2**12),
14 | device="cuda",
15 | ):
16 | """
17 | Efficiently fills a dense N*N*N regular grid by querying the function and its gradient for distance field values
18 |     and optionally computing line directions. The voxel grid is fixed to span [-1, 1]^3.
19 |
20 | Parameters:
21 | - func: Callable for evaluating distance field values.
22 | - func_grad: Callable for evaluating gradients.
23 | - N: Size of the grid in each dimension.
24 | - udf_threshold: Threshold below which gradients are computed.
25 | - is_linedirection: Flag indicating whether to compute line directions.
26 | - sampling_N: Number of samples for line direction computation.
27 | - sampling_delta: Offset range for sampling around points for line direction.
28 | - max_batch: Max number of points processed in a single batch.
29 |
30 | Returns:
31 | Tuple of tensors (df_values, line_directions, gradients, samples, voxel_size) representing the computed
32 | distance field values, line directions, gradients at points below threshold, raw sample points, and the size
33 | of each voxel.
34 | """
35 |
36 | overall_index = torch.arange(0, N**3, 1, out=torch.LongTensor())
37 | samples = torch.zeros(N**3, 12, device=device)
38 | # transform first 3 columns
39 | # to be the x, y, z index
40 | samples[:, 2] = overall_index % N
41 | samples[:, 1] = torch.div(overall_index, N, rounding_mode="floor") % N
42 | samples[:, 0] = (
43 | torch.div(
44 | torch.div(overall_index, N, rounding_mode="floor"), N, rounding_mode="floor"
45 | )
46 | % N
47 | )
48 |     # The sample grid spans [-1, 1]^3: N samples per axis with spacing 2 / (N - 1)
49 | voxel_origin = [-1, -1, -1]
50 | voxel_size = 2.0 / (N - 1)
51 | samples[:, 0] = (samples[:, 0] * voxel_size) + voxel_origin[2]
52 | samples[:, 1] = (samples[:, 1] * voxel_size) + voxel_origin[1]
53 | samples[:, 2] = (samples[:, 2] * voxel_size) + voxel_origin[0]
54 |
55 | # Query function for distance field values
56 | num_samples = N**3
57 | samples.requires_grad = False
58 | for head in range(0, num_samples, max_batch):
59 | tail = min(head + max_batch, num_samples)
60 | sample_subset = samples[head:tail, :3].clone().to(device)
61 | df, _, _ = func(sample_subset)
62 | samples[head:tail, 3:4] = df.detach()
63 |
64 | # Compute gradients where distance field value is below threshold
65 | norm_mask = samples[:, 3] < udf_threshold
66 | norm_idx = torch.where(norm_mask)[0]
67 | for head in range(0, len(norm_idx), max_batch):
68 | tail = min(head + max_batch, len(norm_idx))
69 | idx_subset = norm_idx[head:tail]
70 | sample_subset = samples[idx_subset, :3].clone().to(device).requires_grad_(True)
71 | grad = func_grad(sample_subset).detach()
72 |         samples[idx_subset, 4:7] = -F.normalize(grad[:, 0], dim=1)  # grads are (B, 1, 3); normalize the 3D vectors
73 |
74 | # Compute line directions if requested
75 | if is_linedirection:
76 | sample_subset_ld = sample_subset.unsqueeze(
77 | 1
78 | ) + sampling_delta * torch.randn(
79 | (sample_subset.shape[0], sampling_N, 3), device=device
80 | )
81 | grad_ld = (
82 | func_grad(sample_subset_ld.reshape(-1, 3))
83 | .detach()
84 | .reshape(sample_subset.shape[0], sampling_N, 3)
85 | )
86 | _, _, vh = torch.linalg.svd(grad_ld)
87 | null_space = vh[:, -1, :].view(-1, 3)
88 | samples[idx_subset, 8:11] = F.normalize(null_space, dim=1)
89 |
90 | # Reshape output tensors
91 | df_values = samples[:, 3].reshape(N, N, N)
92 | vecs = samples[:, 4:7].reshape(N, N, N, 3)
93 | ld = samples[:, 8:11].reshape(N, N, N, 3)
94 |
95 | return df_values, ld, vecs, samples, torch.tensor(voxel_size)
96 |
97 |
98 | def get_udf_normals_slow(
99 | func,
100 | func_grad,
101 | voxel_size,
102 | xyz,
103 | is_linedirection,
104 | # N_ld=20,
105 | sampling_N=50,
106 | sampling_delta=0.005, # 0.005
107 | max_batch=int(2**12),
108 | device="cuda",
109 | ):
110 | """
111 | Computes distance field values, normals, and optionally line directions for a set of points.
112 |
113 | Parameters:
114 | - func: Function to evaluate the distance field.
115 | - func_grad: Function to evaluate the gradient of the distance field.
116 | - voxel_size: Size of the voxel (not used in this function but kept for compatibility).
117 | - xyz: (N,3) tensor representing coordinates to evaluate.
118 | - is_linedirection: Boolean indicating whether to compute line directions.
119 | - sampling_N: Number of samples for computing line direction.
120 | - sampling_delta: Delta range for sampling around points for line direction.
121 | - max_batch: Maximum number of points to process in a single batch.
122 |
123 | Returns:
124 | - df_values: (N,) tensor of distance field values at xyz locations.
125 | - normals: (N,3) tensor of gradient values at xyz locations.
126 | - ld: (N,3) tensor of line direction values at xyz locations, if computed.
127 |     - samples: (N, 13) tensor of x, y, z, distance field, normals (3), line directions (3), with the remaining columns unused.
128 | """
129 | # network.eval()
130 | ################
131 | # transform first 3 columns
132 | # to be the x, y, z coordinate
133 |
134 | num_samples = xyz.shape[0]
135 | # xyz = torch.from_numpy(xyz).float().cuda()
136 | samples = torch.cat([xyz, torch.zeros(num_samples, 10).float().cuda()], dim=-1)
137 | samples.requires_grad = False
138 | # samples.pin_memory()
139 | ################
140 | # 2: Run forward pass to fill the grid
141 | ################
142 | head = 0
143 | ## FIRST: fill distance field grid without gradients
144 | while head < num_samples:
145 | # xyz coords
146 | sample_subset = (
147 | samples[head : min(head + max_batch, num_samples), 0:3].clone().cuda()
148 | )
149 | # Create input
150 | xyz = sample_subset
151 |
152 | input = xyz.reshape(-1, xyz.shape[-1])
153 | # Run forward pass
154 | with torch.no_grad():
155 | df, _, PE = func(input)
156 | # Store df
157 | samples[head : min(head + max_batch, num_samples), 3] = (
158 | df.squeeze(-1).detach().cpu()
159 | )
160 | grad = func_grad(input).detach()[:, 0]
161 | normals = -F.normalize(grad, dim=1)
162 | samples[head : min(head + max_batch, num_samples), 4:7] = normals.cpu()
163 |
164 | if is_linedirection:
165 | input_ld = input.unsqueeze(
166 | 1
167 | ) + sampling_delta * torch.randn( # need to be fixed grid
168 | (input.shape[0], sampling_N, 3), device=device
169 | )
170 | # input_ld = input.unsqueeze(1) + offset
171 | input_ld = input_ld.reshape(-1, input.shape[-1])
172 | grad_ld = (
173 | func_grad(input_ld.float())
174 | .detach()[:, 0]
175 | .reshape(input.shape[0], -1, 3)
176 | )
177 | _, _, vh = torch.linalg.svd(grad_ld)
178 |
179 | # Extract the null space (non-zero solutions) for each matrix in the batch
180 | # null_space = vh[:, -1, :][vh[:, -1, :] != 0].view(-1, 3)
181 | null_space = vh[:, -1, :].view(-1, 3)
182 | samples[head : min(head + max_batch, num_samples), 7:10] = F.normalize(
183 | null_space, dim=1
184 | )
185 | # Next iter
186 | head += max_batch
187 |
188 | # Separate values in DF / gradients
189 | df_values = samples[:, 3]
190 | normals = samples[:, 4:7]
191 | ld = samples[:, 7:10]
192 |
193 | return df_values, normals, ld, samples
194 |
195 |
196 | import numpy as np
197 |
198 |
199 | def project_vector_onto_plane(A, B):
200 | # Calculate the dot product of A and B
201 | dot_product = torch.sum(A * B, dim=-1)
202 |
203 | # Calculate the projection of A onto B
204 | projection = dot_product.unsqueeze(-1) * B
205 |
206 | # Calculate the projected vector onto the plane perpendicular to B
207 | projected_vector = A - projection
208 |
209 | return projected_vector
210 |
211 |
212 | def get_pointcloud_from_udf(
213 | func,
214 | func_grad,
215 | N_MC=128,
216 | udf_threshold=1.0,
217 | sampling_N=50,
218 | sampling_delta=5e-3,
219 | is_pointshift=False,
220 | iters=1,
221 | is_linedirection=False,
222 | device="cuda",
223 | ):
224 | """
225 | Computes a point cloud from a distance field network conditioned on the latent vector.
226 | Inputs:
227 | func: Function to evaluate the distance field.
228 | func_grad: Function to evaluate the gradient of the distance field.
229 | N_MC: Size of the grid.
230 | udf_threshold: Threshold to filter surfaces with large UDF values.
231 | is_pointshift: Flag indicating if points should be shifted by df * normals.
232 | iters: Number of iterations for the point shift.
233 | is_linedirection: Flag indicating if line direction computation is needed.
234 | Returns:
235 |         pointcloud: (M, 3) numpy array of edge points whose UDF value is below udf_threshold.
236 |         line_directions: (M, 3) numpy array of per-point edge directions
237 |             (zeros unless is_linedirection is set).
238 | """
239 | # Compute UDF normals and grid
240 | df_values, lds, normals, samples, voxel_size = get_udf_normals_grid(
241 | func=func,
242 | func_grad=func_grad,
243 | N=N_MC,
244 | udf_threshold=udf_threshold,
245 | is_linedirection=is_linedirection,
246 | sampling_N=sampling_N,
247 | sampling_delta=sampling_delta,
248 | device=device,
249 | )
250 |
251 | # Reshape tensors for processing
252 | df_values, lds, normals, samples = (
253 | df_values.reshape(-1),
254 | lds.reshape(-1, 3),
255 | normals.reshape(-1, 3),
256 | samples.reshape(-1, 12),
257 | )
258 | xyz = samples[:, 0:3]
259 | df_values.clamp_(min=0) # Ensure distance field values are non-negative
260 |
261 | # Filter out points too far from the surface
262 | points_idx = df_values <= udf_threshold
263 | filtered_xyz, filtered_lds, normals, df_values = (
264 | xyz[points_idx],
265 | lds[points_idx],
266 | normals[points_idx],
267 | df_values[points_idx],
268 | )
269 |
270 | # Point shifting
271 | if is_pointshift and iters > 0:
272 | for iter in range(iters):
273 | shifted_xyz = filtered_xyz + df_values.unsqueeze(-1) * normals
274 | shifted_df_values, shifted_normals, filtered_lds, _ = get_udf_normals_slow(
275 | func=func,
276 | func_grad=func_grad,
277 | voxel_size=voxel_size,
278 | xyz=shifted_xyz,
279 | is_linedirection=True if iter == iters - 1 else False,
280 | device=device,
281 | )
282 | shifted_points_idx = shifted_df_values <= udf_threshold
283 | filtered_xyz, df_values, normals, filtered_lds = (
284 | shifted_xyz[shifted_points_idx],
285 | shifted_df_values[shifted_points_idx],
286 | shifted_normals[shifted_points_idx],
287 | filtered_lds[shifted_points_idx],
288 | )
289 |
290 | return (
291 | filtered_xyz.cpu().numpy(),
292 | filtered_lds.cpu().numpy() if filtered_lds is not None else None,
293 | )
294 |
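295 |
296 | # Toy sanity check (an illustrative sketch; `toy_udf` and `toy_udf_grad` are
297 | # made-up stand-ins for the trained network): the unsigned distance to the
298 | # z-axis is d(p) = sqrt(px^2 + py^2), so the recovered points should hug the
299 | # z-axis. Runs on CPU with a small grid.
300 | if __name__ == "__main__":
301 |     def toy_udf(p):
302 |         d = torch.linalg.norm(p[:, :2], dim=1, keepdim=True)
303 |         return d, None, None  # matches the (df, _, _) interface used above
304 |
305 |     def toy_udf_grad(p):
306 |         r = torch.linalg.norm(p[:, :2], dim=1, keepdim=True).clamp_min(1e-8)
307 |         g = torch.cat([p[:, :2] / r, torch.zeros_like(p[:, :1])], dim=1)
308 |         return g.unsqueeze(1)  # (B, 1, 3), the layout the callers above expect
309 |
310 |     pts, _ = get_pointcloud_from_udf(
311 |         toy_udf, toy_udf_grad, N_MC=32, udf_threshold=0.1, device="cpu"
312 |     )
313 |     print(pts.shape)  # a thin tube of points around the z-axis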
--------------------------------------------------------------------------------
/src/edge_extraction/extract_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import json
4 | import point_cloud_utils as pcu
5 | import math
6 | import open3d as o3d
7 | import trimesh
8 |
9 |
10 | def downsample_point_cloud_average(
11 | points, num_voxels_per_axis=256, min_bound=None, max_bound=None
12 | ):
13 | """
14 | Downsample a point set based on the number of voxels per axis by averaging the points within each voxel.
15 |
16 | Args:
17 | points: a [#v, 3]-shaped array of 3d points.
18 |         num_voxels_per_axis: a scalar or 3-tuple giving the number of voxels along each axis; min_bound and max_bound optionally fix the bounding box instead of deriving it from the points.
19 |
20 | Returns:
21 | A [#v', 3]-shaped numpy array of downsampled points, where #v' is the number of occupied voxels.
22 | """
23 |
24 | # Calculate the bounding box of the point cloud
25 | if min_bound is None:
26 | min_bound = np.min(points, axis=0)
27 | else:
28 | min_bound = np.array(min_bound)
29 |
30 | if max_bound is None:
31 | max_bound = np.max(points, axis=0)
32 | else:
33 | max_bound = np.array(max_bound)
34 |
35 | # Determine the size of the voxel based on the desired number of voxels per axis
36 | if isinstance(num_voxels_per_axis, int):
37 | voxel_size = (max_bound - min_bound) / num_voxels_per_axis
38 | else:
39 | voxel_size = [
40 | (max_bound[i] - min_bound[i]) / num_voxels_per_axis[i] for i in range(3)
41 | ]
42 |
43 | # Use the existing function to downsample the point cloud based on voxel size
44 | downsampled_points = pcu.downsample_point_cloud_on_voxel_grid(
45 | voxel_size, points, min_bound=min_bound, max_bound=max_bound
46 | )
47 |
48 | return downsampled_points
49 |
50 |
51 | def get_gt_points(
52 | name,
53 | edge_type="all",
54 | interval=0.005,
55 | return_CAD=False,
56 | return_direction=False,
57 | base_dir=None,
58 | ):
59 | """
60 | Get ground truth points from a dataset.
61 |
62 | Args:
63 | name (str): Name of the dataset.
64 |
65 | Returns:
66 |         tuple: (raw edge points, resampled edge points, CAD mesh or None, per-point edge directions).
67 | """
68 | objs_dir = os.path.join(base_dir, "obj")
69 | obj_names = os.listdir(objs_dir)
70 | obj_names.sort()
71 | index_obj_names = {}
72 | for obj_name in obj_names:
73 | index_obj_names[obj_name[:8]] = obj_name
74 |
75 | json_feats_path = os.path.join(base_dir, "chunk_0000_feats.json")
76 | with open(json_feats_path, "r") as f:
77 | json_data_feats = json.load(f)
78 | json_stats_path = os.path.join(base_dir, "chunk_0000_stats.json")
79 | with open(json_stats_path, "r") as f:
80 | json_data_stats = json.load(f)
81 |
82 | # get the normalize scale to help align the nerf points and gt points
83 | [
84 | x_min,
85 | y_min,
86 | z_min,
87 | x_max,
88 | y_max,
89 | z_max,
90 | x_range,
91 | y_range,
92 | z_range,
93 | ] = json_data_stats[name]["bbox"]
94 | scale = 1 / max(x_range, y_range, z_range)
95 | poi_center = (
96 | np.array([((x_min + x_max) / 2), ((y_min + y_max) / 2), ((z_min + z_max) / 2)])
97 | * scale
98 | )
99 | set_location = [0.5, 0.5, 0.5] - poi_center # based on the rendering settings
100 |
101 | obj_path = os.path.join(objs_dir, index_obj_names[name])
102 | if return_CAD:
103 | cad_obj = trimesh.load_mesh(obj_path)
104 | else:
105 | cad_obj = None
106 |
107 | with open(obj_path, encoding="utf-8") as file:
108 | data = file.readlines()
109 |
110 | vertices_obj = [each.split(" ") for each in data if each.split(" ")[0] == "v"]
111 | vertices_xyz = [
112 | [float(v[1]), float(v[2]), float(v[3].replace("\n", ""))] for v in vertices_obj
113 | ]
114 |
115 | edge_pts = []
116 | edge_pts_raw = []
117 | edge_pts_direction = []
118 | rename = {
119 | "BSpline": "curve",
120 | "Circle": "curve",
121 | "Ellipse": "curve",
122 | "Line": "line",
123 | }
124 | for each_curve in json_data_feats[name]:
125 | if edge_type != "all" and rename[each_curve["type"]] != edge_type:
126 | continue
127 |
128 | if each_curve["sharp"]: # each_curve["type"]: BSpline, Line, Circle
129 | each_edge_pts = [vertices_xyz[i] for i in each_curve["vert_indices"]]
130 | edge_pts_raw += each_edge_pts
131 |
132 | gt_sampling = []
133 | each_edge_pts = np.array(each_edge_pts)
134 | for index in range(len(each_edge_pts) - 1):
135 | next = each_edge_pts[index + 1]
136 | current = each_edge_pts[index]
137 | num = int(np.linalg.norm(next - current) // interval)
138 | linspace = np.linspace(0, 1, num)
139 | gt_sampling.append(
140 | linspace[:, None] * current + (1 - linspace)[:, None] * next
141 | )
142 |
143 | if return_direction:
144 | direction = (next - current) / np.linalg.norm(next - current)
145 | edge_pts_direction.extend([direction] * num)
146 | each_edge_pts = np.concatenate(gt_sampling).tolist()
147 | edge_pts += each_edge_pts
148 |
149 | if len(edge_pts_raw) == 0:
150 | return None, None, None, None
151 |
152 | edge_pts_raw = np.array(edge_pts_raw) * scale + set_location
153 | edge_pts = np.array(edge_pts) * scale + set_location
154 | edge_pts_direction = np.array(edge_pts_direction)
155 |
156 | return (
157 | edge_pts_raw.astype(np.float32),
158 | edge_pts.astype(np.float32),
159 | cad_obj,
160 | edge_pts_direction,
161 | )
162 |
163 |
164 | def chamfer_distance(x, y, return_index=False, p_norm=2, max_points_per_leaf=10):
165 | """
166 | Compute the chamfer distance between two point clouds x, and y
167 |
168 | Args:
169 |         x : A point set in R^d, i.e. shape [n_a, d]
170 |         y : A point set in R^d, i.e. shape [n_b, d]
171 | return_index: If set to True, will return a pair (corrs_x_to_y, corrs_y_to_x) where
172 | corrs_x_to_y[i] stores the index into y of the closest point to x[i]
173 | (i.e. y[corrs_x_to_y[i]] is the nearest neighbor to x[i] in y).
174 | corrs_y_to_x is similar to corrs_x_to_y but with x and y reversed.
175 | max_points_per_leaf : The maximum number of points per leaf node in the KD tree used by this function.
176 | Default is 10.
177 | p_norm : Which norm to use. p_norm can be any real number, inf (for the max norm) -inf (for the min norm),
178 | 0 (for sum(x != 0))
179 | Returns:
180 |         The chamfer distance between x and y.
181 | If return_index is set, then this function returns a tuple (chamfer_dist, corrs_x_to_y, corrs_y_to_x) where
182 | corrs_x_to_y and corrs_y_to_x are described above.
183 | """
184 |
185 | dists_x_to_y, corrs_x_to_y = pcu.k_nearest_neighbors(
186 | x, y, k=1, squared_distances=False, max_points_per_leaf=max_points_per_leaf
187 | )
188 | dists_y_to_x, corrs_y_to_x = pcu.k_nearest_neighbors(
189 | y, x, k=1, squared_distances=False, max_points_per_leaf=max_points_per_leaf
190 | )
191 |
192 |     dists_y_to_x = np.linalg.norm(x[corrs_y_to_x] - y, axis=-1, ord=p_norm).mean()
193 |     dists_x_to_y = np.linalg.norm(y[corrs_x_to_y] - x, axis=-1, ord=p_norm).mean()
194 |
195 |     Comp = np.mean(dists_y_to_x)  # mean distance from each point in y to its nearest point in x
196 |     Acc = np.mean(dists_x_to_y)  # mean distance from each point in x to its nearest point in y
197 | cham_dist = Comp + Acc
198 |
199 | if return_index:
200 | return cham_dist, corrs_x_to_y, corrs_y_to_x
201 |
202 | return cham_dist, Acc, Comp
203 |
204 |
205 | def compute_chamfer_distance(pred_sampled, gt_points):
206 | """
207 | Compute chamfer distance between predicted and ground truth points.
208 |
209 | Args:
210 | pred_sampled (numpy.ndarray): Predicted point cloud.
211 | gt_points (numpy.ndarray): Ground truth points.
212 |
213 | Returns:
214 | float: Chamfer distance.
215 | """
216 | chamfer_dist, acc, comp = chamfer_distance(pred_sampled, gt_points)
217 | return chamfer_dist, acc, comp
218 |
219 |
220 | def compute_precision_recall_IOU(
221 | pred_sampled, gt_points, metrics, thresh_list=[0.02], edge_type="all"
222 | ):
223 | """
224 | Compute precision, recall, F-score, and IOU.
225 |
226 | Args:
227 | pred_sampled (numpy.ndarray): Predicted point cloud.
228 | gt_points (numpy.ndarray): Ground truth points.
229 | metrics (dict): Dictionary to store metrics.
230 |
231 | Returns:
232 |         dict or tuple: Updated metrics when edge_type is "all"; otherwise (correct_gt_list, num_gt, correct_pred_list, num_pred, acc, comp).
233 | """
234 | if edge_type == "all":
235 | for thresh in thresh_list:
236 | dists_a_to_b, _ = pcu.k_nearest_neighbors(
237 | pred_sampled, gt_points, k=1
238 | ) # k closest points (in pts_b) for each point in pts_a
239 | correct_pred = np.sum(dists_a_to_b < thresh)
240 | precision = correct_pred / len(dists_a_to_b)
241 | metrics[f"precision_{thresh}"].append(precision)
242 |
243 | dists_b_to_a, _ = pcu.k_nearest_neighbors(gt_points, pred_sampled, k=1)
244 | correct_gt = np.sum(dists_b_to_a < thresh)
245 | recall = correct_gt / len(dists_b_to_a)
246 | metrics[f"recall_{thresh}"].append(recall)
247 |
248 | fscore = 2 * precision * recall / (precision + recall)
249 | metrics[f"fscore_{thresh}"].append(fscore)
250 |
251 | intersection = min(correct_pred, correct_gt)
252 | union = (
253 | len(dists_a_to_b) + len(dists_b_to_a) - max(correct_pred, correct_gt)
254 | )
255 |
256 | IOU = intersection / union
257 | metrics[f"IOU_{thresh}"].append(IOU)
258 | return metrics
259 | else:
260 | correct_gt_list = []
261 | correct_pred_list = []
262 | _, acc, comp = compute_chamfer_distance(pred_sampled, gt_points)
263 | for thresh in thresh_list:
264 | dists_b_to_a, _ = pcu.k_nearest_neighbors(gt_points, pred_sampled, k=1)
265 | correct_gt = np.sum(dists_b_to_a < thresh)
266 | num_gt = len(dists_b_to_a)
267 | correct_gt_list.append(correct_gt)
268 |
269 | dists_a_to_b, _ = pcu.k_nearest_neighbors(pred_sampled, gt_points, k=1)
270 | correct_pred = np.sum(dists_a_to_b < thresh)
271 | correct_pred_list.append(correct_pred)
272 | num_pred = len(dists_a_to_b)
273 |
274 | return correct_gt_list, num_gt, correct_pred_list, num_pred, acc, comp
275 |
276 |
277 | def f_score(precision, recall):
278 | """
279 | Compute F-score.
280 |
281 | Args:
282 | precision (float): Precision.
283 | recall (float): Recall.
284 |
285 | Returns:
286 | float: F-score.
287 | """
288 | return 2 * precision * recall / (precision + recall)
289 |
290 |
291 | def bezier_curve_length(control_points, num_samples):
292 | def binomial_coefficient(n, i):
293 | return math.factorial(n) // (math.factorial(i) * math.factorial(n - i))
294 |
295 | def derivative_bezier(t):
296 | n = len(control_points) - 1
297 | point = np.array([0.0, 0.0, 0.0])
298 | for i, (p1, p2) in enumerate(zip(control_points[:-1], control_points[1:])):
299 | point += (
300 | n
301 | * binomial_coefficient(n - 1, i)
302 | * (1 - t) ** (n - 1 - i)
303 | * t**i
304 | * (np.array(p2) - np.array(p1))
305 | )
306 | return point
307 |
308 | def simpson_integral(a, b, num_samples):
309 | h = (b - a) / num_samples
310 | sum1 = sum(
311 | np.linalg.norm(derivative_bezier(a + i * h))
312 | for i in range(1, num_samples, 2)
313 | )
314 | sum2 = sum(
315 | np.linalg.norm(derivative_bezier(a + i * h))
316 | for i in range(2, num_samples - 1, 2)
317 | )
318 | return (
319 | (
320 | np.linalg.norm(derivative_bezier(a))
321 | + 4 * sum1
322 | + 2 * sum2
323 | + np.linalg.norm(derivative_bezier(b))
324 | )
325 | * h
326 | / 3
327 | )
328 |
329 | # Compute the length of the 3D Bezier curve using composite Simpson's rule
330 | length = 0.0
331 | for i in range(num_samples):
332 | t0 = i / num_samples
333 | t1 = (i + 1) / num_samples
334 | length += simpson_integral(t0, t1, num_samples)
335 |
336 | return length
337 |
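338 |
339 | # Quick numerical check (an illustrative sketch): a collinear cubic Bezier
340 | # curve with evenly spaced control points traces its chord at constant speed,
341 | # so the computed arc length should equal the endpoint distance.
342 | if __name__ == "__main__":
343 |     ctrl = np.array([[0.0, 0, 0], [1.0, 0, 0], [2.0, 0, 0], [3.0, 0, 0]])
344 |     print(bezier_curve_length(ctrl, num_samples=100))  # expected: ~3.0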
--------------------------------------------------------------------------------
/src/edge_extraction/merging/main.py:
--------------------------------------------------------------------------------
1 | from src.edge_extraction.edge_fitting.main import save_3d_lines_to_file
2 | from src.edge_extraction.edge_fitting.line_fit import line_fitting
3 | from src.edge_extraction.edge_fitting.bezier_fit import (
4 | bezier_fit,
5 | bezier_curve,
6 | )
7 | import numpy as np
8 | from sklearn.metrics.pairwise import cosine_similarity
9 | from sklearn.metrics import silhouette_score
10 | from scipy.sparse.csgraph import connected_components
11 | import open3d as o3d
12 | from scipy.spatial.distance import euclidean, cdist
13 | import os
14 |
15 |
16 | def line_segment_point_distance(line_segment, query_point):
17 | """Compute the Euclidean distance between a line segment and a query point.
18 |
19 | Parameters:
20 | line_segment (np.ndarray): An array of shape (6,), representing two 3D endpoints.
21 | query_point (np.ndarray): An array of shape (3,), representing the 3D query point.
22 |
23 | Returns:
24 | float: The minimum distance from the query point to the line segment.
25 | """
26 | point1, point2 = line_segment[:3], line_segment[3:]
27 | point_delta = point2 - point1
28 | u = np.clip(
29 | np.dot(query_point - point1, point_delta) / np.dot(point_delta, point_delta),
30 | 0,
31 | 1,
32 | )
33 | closest_point = point1 + u * point_delta
34 | return np.linalg.norm(closest_point - query_point)
35 |
36 |
37 | def compute_pairwise_distances(line_segments):
38 | """Compute pairwise distances between line segments.
39 |
40 | Parameters:
41 | line_segments (np.ndarray): An array of shape (N, 6), each row represents a line segment in 3D.
42 |
43 | Returns:
44 | np.ndarray: A symmetric array of shape (N, N), containing pairwise distances.
45 | """
46 | num_lines = len(line_segments)
47 | endpoints = line_segments.reshape(-1, 3)
48 | dist_matrix = np.zeros((num_lines, num_lines))
49 |
50 | for i, line_segment in enumerate(line_segments):
51 | for j in range(i + 1, num_lines):
52 | min_distance = min(
53 | line_segment_point_distance(line_segment, endpoints[2 * j]),
54 | line_segment_point_distance(line_segment, endpoints[2 * j + 1]),
55 | )
56 | dist_matrix[i, j] = min_distance
57 |
58 | dist_matrix += dist_matrix.T # Make the matrix symmetric
59 | return dist_matrix
60 |
61 |
62 | def compute_pairwise_cosine_similarity(line_segments):
63 | direction_vectors = line_segments[:, 3:] - line_segments[:, :3]
64 | pairwise_similarity = cosine_similarity(direction_vectors)
65 | return pairwise_similarity
66 |
67 |
68 | def bezier_curve_distance(points1, points2):
69 | distances = np.linalg.norm(points1[:, np.newaxis] - points2, axis=2)
70 | min_distance = np.min(distances)
71 | return min_distance
72 |
73 |
74 | def bezier_slope_vector(P0, P1, P2, P3, t):
75 | # Calculate the derivative of the 3D Bézier curve
76 | dp_dt = (
77 | -3 * (1 - t) ** 2 * P0
78 | + 3 * (1 - 4 * t + 3 * t**2) * P1
79 | + 3 * (2 * t - 3 * t**2) * P2
80 | + 3 * t**2 * P3
81 | )
82 | return dp_dt
83 |
84 |
85 | def get_dist_similarity(control_points1, control_points2, points1, points2, t_values):
86 | control_points1 = control_points1.reshape(-1, 3)
87 | control_points2 = control_points2.reshape(-1, 3)
88 |
89 | # Calculate the pairwise distances between all points in the two sets
90 | distances = cdist(points1, points2)
91 |
92 | # Find the indices of the minimum distance pair
93 | min_indices = np.unravel_index(np.argmin(distances), distances.shape)
94 | min_distance = distances[min_indices]
95 |
96 | points_slope1 = bezier_slope_vector(
97 | control_points1[0],
98 | control_points1[1],
99 | control_points1[2],
100 | control_points1[3],
101 | t_values[min_indices[0]],
102 | )
103 |
104 | points_slope2 = bezier_slope_vector(
105 | control_points2[0],
106 | control_points2[1],
107 | control_points2[2],
108 | control_points2[3],
109 | t_values[min_indices[1]],
110 | )
111 |
112 | # Compute cosine similarity between the two slopes
113 | similarity = np.abs(np.dot(points_slope1, points_slope2)) / (
114 | np.linalg.norm(points_slope1) * np.linalg.norm(points_slope2)
115 | )
116 |
117 | return min_distance, similarity
118 |
119 |
120 | def merge_line_segments(
121 | line_segments, raw_points_on_lines, distance_threshold, similarity_threshold
122 | ):
123 | dist_matrix = compute_pairwise_distances(line_segments)
124 | similarity_matrix = compute_pairwise_cosine_similarity(line_segments)
125 |
126 | # Create adjacency matrix based on distance and similarity thresholds
127 | adjacency_matrix = (dist_matrix <= distance_threshold) & (
128 | similarity_matrix >= similarity_threshold
129 | )
130 | # Compute connected components
131 | num_components, labels = connected_components(adjacency_matrix)
132 |
133 | merged_line_segments = []
134 | for component in range(num_components):
135 | component_indices = np.where(labels == component)[0]
136 | if len(component_indices) == 1:
137 | merged_line_segments.append(line_segments[component_indices[0]])
138 | continue
139 | else:
140 | raw_points_on_lines_group = raw_points_on_lines[component_indices]
141 | raw_points_on_lines_group_array = np.array(
142 | [
143 | point
144 | for raw_points_on_lines in raw_points_on_lines_group
145 | for point in raw_points_on_lines
146 | ]
147 | ).reshape(-1, 3)
148 |
149 | try:
150 | line_segment, _ = line_fitting(raw_points_on_lines_group_array)
151 | merged_line_segments.append(line_segment)
152 |             except Exception:  # line fitting can fail on degenerate point groups; skip them
153 | continue
154 |
155 | merged_line_segments = np.array(merged_line_segments)
156 | return merged_line_segments
157 |
158 |
159 | def merge_bezier_curves(
160 | control_points_list,
161 | raw_points_on_curves,
162 | distance_threshold,
163 | similarity_threshold,
164 | num_samples=100,
165 | ):
166 | """
167 | Merge Bezier curves based on distance and similarity thresholds.
168 |
169 | Parameters:
170 | control_points_list (array): list of control points for each curve.
171 | raw_points_on_curves (array): actual points on each Bezier curve.
172 | distance_threshold (float): maximum distance for merging curves.
173 | similarity_threshold (float): minimum similarity for merging curves.
174 | num_samples (int, optional): number of samples for curve points.
175 |
176 | Returns:
177 | np.array: Merged control points for the Bezier curves.
178 | """
179 |
180 | if not isinstance(control_points_list, np.ndarray) or not isinstance(
181 | raw_points_on_curves, np.ndarray
182 | ):
183 | raise ValueError("Input must be NumPy arrays.")
184 |
185 | num_curves = len(control_points_list)
186 | dist_matrix = np.zeros((num_curves, num_curves))
187 | similarity_matrix = np.zeros((num_curves, num_curves))
188 | t_values = np.linspace(0, 1, num_samples)
189 |
190 | for i, control_points1 in enumerate(control_points_list):
191 | for j in range(i + 1, num_curves):
192 | control_points2 = control_points_list[j]
193 | points1 = bezier_curve(t_values, *(control_points1.tolist())).reshape(-1, 3)
194 | points2 = bezier_curve(t_values, *(control_points2.tolist())).reshape(-1, 3)
195 | dist_matrix[i, j], similarity_matrix[i, j] = get_dist_similarity(
196 | control_points1, control_points2, points1, points2, t_values
197 | )
198 |
199 | dist_matrix += dist_matrix.T
200 | similarity_matrix += similarity_matrix.T
201 |
202 | adjacency_matrix = (dist_matrix <= distance_threshold) & (
203 | similarity_matrix >= similarity_threshold
204 | )
205 | num_components, labels = connected_components(adjacency_matrix)
206 |
207 | merged_bezier_curves = []
208 | for component in range(num_components):
209 | component_indices = np.where(labels == component)[0]
210 | if len(component_indices) == 1:
211 | p = control_points_list[component_indices[0]]
212 | merged_bezier_curves.append(p)
213 | else:
214 | raw_points_group = [raw_points_on_curves[i] for i in component_indices]
215 | raw_points_group_array = np.concatenate(raw_points_group, axis=0)
216 | p = bezier_fit(raw_points_group_array)
217 | merged_bezier_curves.append(p)
218 |
219 | return np.array(merged_bezier_curves)
220 |
221 |
222 | def merge_endpoints(merged_line_segments, merged_bezier_curves, distance_threshold):
223 | N_lines = len(merged_line_segments)
224 | N_curves = len(merged_bezier_curves)
225 |
226 | if N_lines == 0 and N_curves == 0:
227 | return [], []
228 |
229 | if N_lines > 0:
230 | line_endpoints = merged_line_segments.reshape(-1, 3)
231 | else:
232 | line_endpoints = np.array([]).reshape(-1, 3)
233 |
234 | if N_curves > 0:
235 | curve_endpoints = merged_bezier_curves[:, [0, 1, 2, -3, -2, -1]].reshape(-1, 3)
236 | else:
237 | curve_endpoints = np.array([]).reshape(-1, 3)
238 |
239 | concat_endpoints = np.concatenate([line_endpoints, curve_endpoints], axis=0)
240 |
241 | dist_matrix = cdist(concat_endpoints, concat_endpoints)
242 | adjacency_matrix = dist_matrix <= distance_threshold
243 | num_components, labels = connected_components(adjacency_matrix)
244 | for component in range(num_components):
245 | component_indices = np.where(labels == component)[0]
246 | if len(component_indices) > 1:
247 | endpoints = concat_endpoints[component_indices]
248 | mean_endpoint = np.mean(endpoints, axis=0)
249 | concat_endpoints[component_indices] = mean_endpoint
250 |
251 | if N_lines > 0:
252 | merged_line_segments_merged_endpoints = concat_endpoints[: N_lines * 2].reshape(
253 | -1, 6
254 | )
255 | else:
256 | merged_line_segments_merged_endpoints = []
257 |
258 | if N_curves > 0:
259 | merged_curve_segments_merged_endpoints = np.zeros_like(merged_bezier_curves)
260 | curve_merged_endpoints = concat_endpoints[N_lines * 2 :].reshape(-1, 6)
261 | merged_curve_segments_merged_endpoints[:, :3] = curve_merged_endpoints[:, :3]
262 | merged_curve_segments_merged_endpoints[:, 3:9] = merged_bezier_curves[:, 3:9]
263 | merged_curve_segments_merged_endpoints[:, 9:] = curve_merged_endpoints[:, 3:]
264 |
265 | else:
266 | merged_curve_segments_merged_endpoints = []
267 |
268 | return merged_line_segments_merged_endpoints, merged_curve_segments_merged_endpoints
269 |
270 |
271 | def approximate_curve_length(P0, P1, P2, P3):
272 | return np.linalg.norm(P1 - P0) + np.linalg.norm(P2 - P1) + np.linalg.norm(P3 - P2)
273 |
274 |
275 | def generate_points_curve(curves):
276 | all_points = []
277 | for curve in curves:
278 | P0 = np.array(curve[:3])
279 | P1 = np.array(curve[3:6])
280 | P2 = np.array(curve[6:9])
281 | P3 = np.array(curve[9:])
282 |
283 | # Approximate number of points based on curve length
284 | length = approximate_curve_length(P0, P1, P2, P3)
285 | num_points = int(np.ceil(length * 1000)) # Adjust the factor as needed
286 |
287 | t_values = np.linspace(0, 1, num_points)
288 | curve_points = np.array(bezier_curve(t_values, *(curve.tolist())))
289 |
290 | all_points.extend(curve_points)
291 |
292 | return np.array(all_points).reshape(-1, 3)
293 |
294 |
295 | def merge(
296 | out_dir,
297 | fitted_edge_dict,
298 | merge_edge_distance_threshold=5.0,
299 | merge_endpoints_distance_threshold=1.0,
300 | merge_similarity_threshold=0.98,
301 | merge_endpoints_flag=True,
302 | merge_edge_flag=True,
303 | merge_curve_flag=False,
304 | save_ply=False,
305 | ):
306 |
307 | resolution = int(fitted_edge_dict["resolution"])
308 | lines = np.array(fitted_edge_dict["lines_end_pts"]).reshape(-1, 6)
309 | raw_points_on_lines = np.array(
310 | fitted_edge_dict["raw_points_on_lines"], dtype=object
311 | )
312 | bezier_curves = np.array(fitted_edge_dict["curves_ctl_pts"]).reshape(-1, 12)
313 | raw_points_on_curves = np.array(
314 | fitted_edge_dict["raw_points_on_curves"], dtype=object
315 | )
316 |
317 | # Normalize thresholds
318 | merge_edge_distance_threshold /= resolution
319 | merge_endpoints_distance_threshold /= resolution
320 |
321 | # Merge lines
322 | if merge_edge_flag and len(lines) > 0:
323 | merged_line_segments = merge_line_segments(
324 | lines,
325 | raw_points_on_lines,
326 | merge_edge_distance_threshold / 2.0,
327 | merge_similarity_threshold,
328 | )
329 | else:
330 | merged_line_segments = lines
331 |
332 | if merge_curve_flag and merge_edge_flag:
333 | if len(bezier_curves) > 0:
334 | merged_bezier_curves = merge_bezier_curves(
335 | bezier_curves,
336 | raw_points_on_curves,
337 | merge_edge_distance_threshold,
338 | merge_similarity_threshold,
339 | )
340 | else:
341 | merged_bezier_curves = []
342 | else:
343 | merged_bezier_curves = bezier_curves
344 |
345 | if merge_endpoints_flag:
346 | (
347 | merged_line_segments,
348 | merged_bezier_curves,
349 | ) = merge_endpoints(
350 | merged_line_segments,
351 | merged_bezier_curves,
352 | merge_endpoints_distance_threshold,
353 | )
354 |
355 | if save_ply:
356 | if len(merged_line_segments) > 0:
357 | save_3d_lines_to_file(
358 | merged_line_segments,
359 | os.path.join(out_dir, "merged_line_segments.ply"),
360 | width=2,
361 | scale=1.0,
362 | )
363 | print(f"Saved merged line segments to {out_dir}.")
364 |
365 | if len(merged_bezier_curves) > 0:
366 | pcd = o3d.geometry.PointCloud()
367 |             merged_bezier_curves_points = generate_points_curve(merged_bezier_curves)
368 |             pcd.points = o3d.utility.Vector3dVector(merged_bezier_curves_points)
369 | o3d.io.write_point_cloud(
370 | os.path.join(out_dir, "merged_bezier_curve_points.ply"),
371 | pcd,
372 | write_ascii=True,
373 | )
374 | print(f"Saved merged bezier curves to {out_dir}.")
375 |
376 | merged_edge_dict = {
377 | "lines_end_pts": (
378 | merged_line_segments.tolist() if len(merged_line_segments) > 0 else []
379 | ),
380 | "curves_ctl_pts": (
381 | merged_bezier_curves.tolist() if len(merged_bezier_curves) > 0 else []
382 | ),
383 | }
384 |
385 | return merged_edge_dict
386 |
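387 |
388 | # Small sanity check for the segment-to-point distance used when building the
389 | # merge adjacency matrix (an illustrative sketch): points beyond an endpoint
390 | # are clamped to that endpoint rather than measured against the infinite line.
391 | if __name__ == "__main__":
392 |     seg = np.array([0.0, 0.0, 0.0, 1.0, 0.0, 0.0])  # unit segment along x
393 |     print(line_segment_point_distance(seg, np.array([0.5, 1.0, 0.0])))  # 1.0
394 |     print(line_segment_point_distance(seg, np.array([2.0, 0.0, 0.0])))  # 1.0, clamped to the endpoint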
--------------------------------------------------------------------------------
/src/eval/ABC_scans.txt:
--------------------------------------------------------------------------------
1 | 00000325
--------------------------------------------------------------------------------
/src/eval/DTU_scans.txt:
--------------------------------------------------------------------------------
1 | scan105
--------------------------------------------------------------------------------
/src/eval/eval_ABC.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import argparse
4 | from src.eval.eval_util import (
5 | set_random_seeds,
6 | load_from_json,
7 | compute_chamfer_distance,
8 | f_score,
9 | compute_precision_recall_IOU,
10 | downsample_point_cloud_average,
11 | get_gt_points,
12 | get_pred_points_and_directions,
13 | )
14 |
15 |
16 | def update_totals_and_metrics(metrics, totals, results, edge_type):
17 | correct_gt, num_gt, correct_pred, num_pred, acc, comp = results
18 | metrics[f"comp_{edge_type}"].append(comp)
19 | metrics[f"acc_{edge_type}"].append(acc)
20 | for i, threshold in enumerate(["5", "10", "20"]):
21 | totals[f"thre{threshold}_correct_gt_total"] += correct_gt[i]
22 | totals[f"thre{threshold}_correct_pred_total"] += correct_pred[i]
23 | totals["num_gt_total"] += num_gt
24 | totals["num_pred_total"] += num_pred
25 |
26 |
27 | def finalize_metrics(metrics):
28 | for key, value in metrics.items():
29 | value = np.array(value)
30 | value[np.isnan(value)] = 0
31 | metrics[key] = round(np.mean(value), 4)
32 | return metrics
33 |
34 |
35 | def print_metrics(metrics, totals, edge_type):
36 | print(f"{edge_type.capitalize()}:")
37 | print(f" Completeness: {metrics[f'comp_{edge_type}']}")
38 | print(f" Accuracy: {metrics[f'acc_{edge_type}']}")
39 |
40 |
41 | def process_scan(scan_name, base_dir, exp_name, dataset_dir, metrics, totals):
42 | print(f"Processing: {scan_name}")
43 | json_path = os.path.join(
44 | base_dir, scan_name, exp_name, "results", "parametric_edges.json"
45 | )
46 | if not os.path.exists(json_path):
47 | print(f"Invalid prediction at {scan_name}")
48 | return
49 |
50 | all_curve_points, all_line_points, all_curve_directions, all_line_directions = (
51 | get_pred_points_and_directions(json_path)
52 | )
53 | pred_points = (
54 | np.concatenate([all_curve_points, all_line_points], axis=0)
55 | .reshape(-1, 3)
56 | .astype(np.float32)
57 | )
58 |
59 | if len(pred_points) == 0:
60 | print(f"Invalid prediction at {scan_name}")
61 | return
62 |
63 | pred_sampled = downsample_point_cloud_average(
64 | pred_points,
65 | num_voxels_per_axis=256,
66 | min_bound=[-1, -1, -1],
67 | max_bound=[1, 1, 1],
68 | )
69 |
70 | gt_points_raw, gt_points, _ = get_gt_points(
71 | scan_name, "all", data_base_dir=os.path.join(dataset_dir, "groundtruth")
72 | )
73 | if gt_points_raw is None:
74 | return
75 |
76 | chamfer_dist, acc, comp = compute_chamfer_distance(pred_sampled, gt_points)
77 | print(
78 | f" Chamfer Distance: {chamfer_dist:.4f}, Accuracy: {acc:.4f}, Completeness: {comp:.4f}"
79 | )
80 | metrics["chamfer"].append(chamfer_dist)
81 | metrics["acc"].append(acc)
82 | metrics["comp"].append(comp)
83 | metrics = compute_precision_recall_IOU(
84 | pred_sampled,
85 | gt_points,
86 | metrics,
87 | thresh_list=[0.005, 0.01, 0.02],
88 | edge_type="all",
89 | )
90 |
91 | for edge_type in ["curve", "line"]:
92 | gt_points_raw_edge, gt_points_edge, _ = get_gt_points(
93 | scan_name,
94 | edge_type,
95 | return_direction=True,
96 | data_base_dir=os.path.join(dataset_dir, "groundtruth"),
97 | )
98 | if gt_points_raw_edge is not None:
99 | results = compute_precision_recall_IOU(
100 | pred_sampled,
101 | gt_points_edge,
102 | None,
103 | thresh_list=[0.005, 0.01, 0.02],
104 | edge_type=edge_type,
105 | )
106 | update_totals_and_metrics(metrics, totals[edge_type], results, edge_type)
107 |
108 |
109 | def main(base_dir, dataset_dir, exp_name):
110 | set_random_seeds()
111 | metrics = {
112 | "chamfer": [],
113 | "acc": [],
114 | "comp": [],
115 | "comp_curve": [],
116 | "comp_line": [],
117 | "acc_curve": [],
118 | "acc_line": [],
119 | "precision_0.01": [],
120 | "recall_0.01": [],
121 | "fscore_0.01": [],
122 | "IOU_0.01": [],
123 | "precision_0.02": [],
124 | "recall_0.02": [],
125 | "fscore_0.02": [],
126 | "IOU_0.02": [],
127 | "precision_0.005": [],
128 | "recall_0.005": [],
129 | "fscore_0.005": [],
130 | "IOU_0.005": [],
131 | }
132 |
133 | totals = {
134 | "curve": {
135 | "thre5_correct_gt_total": 0,
136 | "thre10_correct_gt_total": 0,
137 | "thre20_correct_gt_total": 0,
138 | "thre5_correct_pred_total": 0,
139 | "thre10_correct_pred_total": 0,
140 | "thre20_correct_pred_total": 0,
141 | "num_gt_total": 0,
142 | "num_pred_total": 0,
143 | },
144 | "line": {
145 | "thre5_correct_gt_total": 0,
146 | "thre10_correct_gt_total": 0,
147 | "thre20_correct_gt_total": 0,
148 | "thre5_correct_pred_total": 0,
149 | "thre10_correct_pred_total": 0,
150 | "thre20_correct_pred_total": 0,
151 | "num_gt_total": 0,
152 | "num_pred_total": 0,
153 | },
154 | }
155 |
156 | with open("src/eval/ABC_scans.txt", "r") as f:
157 | scan_names = [line.strip() for line in f]
158 |
159 | for scan_name in scan_names:
160 | process_scan(scan_name, base_dir, exp_name, dataset_dir, metrics, totals)
161 |
162 | metrics = finalize_metrics(metrics)
163 |
164 | print("Summary:")
165 | print(f" Accuracy: {metrics['acc']:.4f}")
166 | print(f" Completeness: {metrics['comp']:.4f}")
167 | print(f" Recall @ 5 mm: {metrics['recall_0.005']:.4f}")
168 | print(f" Recall @ 10 mm: {metrics['recall_0.01']:.4f}")
169 | print(f" Recall @ 20 mm: {metrics['recall_0.02']:.4f}")
170 | print(f" Precision @ 5 mm: {metrics['precision_0.005']:.4f}")
171 | print(f" Precision @ 10 mm: {metrics['precision_0.01']:.4f}")
172 | print(f" Precision @ 20 mm: {metrics['precision_0.02']:.4f}")
173 | print(f" F-Score @ 5 mm: {metrics['fscore_0.005']:.4f}")
174 | print(f" F-Score @ 10 mm: {metrics['fscore_0.01']:.4f}")
175 | print(f" F-Score @ 20 mm: {metrics['fscore_0.02']:.4f}")
176 |
177 | if totals["curve"]["num_gt_total"] > 0:
178 | print_metrics(metrics, totals["curve"], "curve")
179 | else:
180 | print("Curve: No ground truth edges found.")
181 |
182 | if totals["line"]["num_gt_total"] > 0:
183 | print_metrics(metrics, totals["line"], "line")
184 | else:
185 | print("Line: No ground truth edges found.")
186 |
187 |
188 | if __name__ == "__main__":
189 | parser = argparse.ArgumentParser(
190 | description="Process CAD data and compute metrics."
191 | )
192 | parser.add_argument(
193 | "--base_dir",
194 | type=str,
195 | default="./exp/ABC",
196 | help="Base directory for experiments",
197 | )
198 | parser.add_argument(
199 | "--dataset_dir",
200 | type=str,
201 | default="./data/ABC-NEF_Edge",
202 | help="Directory for the dataset",
203 | )
204 | parser.add_argument("--exp_name", type=str, default="emap", help="Experiment name")
205 |
206 | args = parser.parse_args()
207 | main(args.base_dir, args.dataset_dir, args.exp_name)
208 |
--------------------------------------------------------------------------------
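Note that the summary above prints the normalized-space thresholds 0.005/0.01/0.02 with 5/10/20 mm labels. A minimal sketch of the bidirectional threshold metrics on synthetic stand-in points, assuming `point_cloud_utils` is installed (the toy arrays stand in for `pred_sampled` / `gt_points` from a real scan):

```
import numpy as np
import point_cloud_utils as pcu

pred = np.random.rand(1000, 3).astype(np.float32)  # stand-in for pred_sampled
gt = np.random.rand(1200, 3).astype(np.float32)    # stand-in for gt_points

for thresh in [0.005, 0.01, 0.02]:
    # nearest-neighbor distances in both directions
    d_pred_to_gt, _ = pcu.k_nearest_neighbors(pred, gt, k=1)
    d_gt_to_pred, _ = pcu.k_nearest_neighbors(gt, pred, k=1)
    precision = np.mean(d_pred_to_gt < thresh)  # predictions near some GT point
    recall = np.mean(d_gt_to_pred < thresh)     # GT points covered by a prediction
    fscore = 2 * precision * recall / max(precision + recall, 1e-8)
    print(f"thresh={thresh}: P={precision:.3f} R={recall:.3f} F={fscore:.3f}")
```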
/src/eval/eval_DTU.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import open3d as o3d
3 | import sklearn.neighbors as skln
4 | from tqdm import tqdm
5 | from scipy.io import loadmat
6 | import multiprocessing as mp
7 | import argparse
8 | import os
9 | from pathlib import Path
10 | from src.eval.eval_util import (
11 | set_random_seeds,
12 | load_from_json,
13 | downsample_point_cloud_average,
14 | get_pred_points_and_directions,
15 | )
16 |
17 |
18 | def process_scan(
19 | scan_name,
20 | base_dir,
21 | exp_name,
22 | dataset_dir,
23 | threshold,
24 | downsample_density,
25 | precision_list,
26 | recall_list,
27 | ):
28 | print(f"Processing: {scan_name}")
29 | json_path = os.path.join(
30 | base_dir, scan_name, exp_name, "results", "parametric_edges.json"
31 | )
32 | if not os.path.exists(json_path):
33 | print(f"Invalid prediction at {scan_name}")
34 | return
35 |
36 | meta_data_json_path = os.path.join(dataset_dir, "data", scan_name, "meta_data.json")
37 | worldtogt = np.array(load_from_json(Path(meta_data_json_path))["worldtogt"])
38 |
39 | all_curve_points, all_line_points, _, _ = get_pred_points_and_directions(json_path)
40 | all_points = np.concatenate([all_curve_points, all_line_points], axis=0).reshape(
41 | -1, 3
42 | )
43 | all_points = np.dot(all_points, worldtogt[:3, :3].T) + worldtogt[:3, 3]
44 |
45 | points_down = downsample_point_cloud_average(all_points, num_voxels_per_axis=256)
46 |
47 | nn_engine = skln.NearestNeighbors(
48 | n_neighbors=1, radius=downsample_density, algorithm="kd_tree", n_jobs=-1
49 | )
50 |
51 | gt_edge_points_path = os.path.join(
52 | dataset_dir, "groundtruth", "edge_points", scan_name, "edge_points.ply"
53 | )
54 | gt_edge_pcd = o3d.io.read_point_cloud(gt_edge_points_path)
55 | gt_edge_points = np.asarray(gt_edge_pcd.points)
56 |
57 | nn_engine.fit(gt_edge_points)
58 | dist_d2s, idx_d2s = nn_engine.kneighbors(
59 | points_down, n_neighbors=1, return_distance=True
60 | )
61 | precision = np.sum(dist_d2s <= threshold) / dist_d2s.shape[0]
62 | precision_list.append(precision)
63 |
64 | nn_engine.fit(points_down)
65 | dist_s2d, idx_s2d = nn_engine.kneighbors(
66 | gt_edge_points, n_neighbors=1, return_distance=True
67 | )
68 | recall = np.sum(dist_s2d <= threshold) / len(dist_s2d)
69 | recall_list.append(recall)
70 |
71 | print(f" Recall: {recall:.4f}, Precision: {precision:.4f}")
72 |
73 |
74 | def main(args):
75 | set_random_seeds()
76 | with open("src/eval/DTU_scans.txt", "r") as f:
77 | scan_names = [line.strip() for line in f]
78 |
79 | precision_list = []
80 | recall_list = []
81 |
82 | for scan_name in scan_names:
83 | process_scan(
84 | scan_name,
85 | args.base_dir,
86 | args.exp_name,
87 | args.dataset_dir,
88 | args.threshold,
89 | args.downsample_density,
90 | precision_list,
91 | recall_list,
92 | )
93 |
94 | print("\nSummary:")
95 | print(f" Mean Recall: {np.mean(recall_list):.4f}")
96 | print(f" Mean Precision: {np.mean(precision_list):.4f}")
97 |
98 |
99 | if __name__ == "__main__":
100 | parser = argparse.ArgumentParser(
101 | description="Process DTU data and compute metrics."
102 | )
103 | parser.add_argument(
104 | "--base_dir",
105 | type=str,
106 | default="./exp/DTU",
107 | help="Base directory for experiments",
108 | )
109 | parser.add_argument(
110 | "--dataset_dir",
111 | type=str,
112 | default="./data/DTU_Edge",
113 | help="Directory for the dataset",
114 | )
115 | parser.add_argument("--exp_name", type=str, default="emap", help="Experiment name")
116 | parser.add_argument("--downsample_density", type=float, default=0.5)
117 | parser.add_argument("--threshold", type=float, default=5)
118 | args = parser.parse_args()
119 | main(args)
120 |
--------------------------------------------------------------------------------
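The DTU evaluation boils down to two one-nearest-neighbor queries: prediction-to-GT distances give precision, GT-to-prediction distances give recall. A self-contained sketch with synthetic stand-ins for `points_down` and `gt_edge_points`, thresholded at 5 to match the default `--threshold`:

```
import numpy as np
import sklearn.neighbors as skln

pred = np.random.rand(500, 3) * 100.0  # stand-in for points_down
gt = np.random.rand(800, 3) * 100.0    # stand-in for gt_edge_points
threshold = 5.0

nn = skln.NearestNeighbors(n_neighbors=1, algorithm="kd_tree", n_jobs=-1)
nn.fit(gt)
d_p2g, _ = nn.kneighbors(pred, n_neighbors=1, return_distance=True)
precision = np.mean(d_p2g <= threshold)  # predictions close to a GT edge point

nn.fit(pred)
d_g2p, _ = nn.kneighbors(gt, n_neighbors=1, return_distance=True)
recall = np.mean(d_g2p <= threshold)     # GT edge points covered by a prediction

print(f"precision={precision:.4f} recall={recall:.4f}")
```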
/src/eval/eval_util.py:
--------------------------------------------------------------------------------
1 | import os
2 | import numpy as np
3 | import json
4 | import point_cloud_utils as pcu
5 | import math
6 | import random
7 | from pathlib import Path
8 |
9 |
10 | def set_random_seeds(seed=42):
11 | np.random.seed(seed)
12 | random.seed(seed)
13 |
14 | def load_from_json(filename: Path):
15 | """Load a dictionary from a JSON filename."""
16 | assert filename.suffix == ".json"
17 | with open(filename, encoding="UTF-8") as file:
18 | return json.load(file)
19 |
20 | def chamfer_distance(x, y, return_index=False, p_norm=2, max_points_per_leaf=10):
21 | """
22 | Compute the chamfer distance between two point clouds x, and y
23 |
24 | Args:
25 |         x : A point set in R^d, i.e. a numpy array of shape [n_a, d]
26 |         y : A point set in R^d, i.e. a numpy array of shape [n_b, d]
27 | return_index: If set to True, will return a pair (corrs_x_to_y, corrs_y_to_x) where
28 | corrs_x_to_y[i] stores the index into y of the closest point to x[i]
29 | (i.e. y[corrs_x_to_y[i]] is the nearest neighbor to x[i] in y).
30 | corrs_y_to_x is similar to corrs_x_to_y but with x and y reversed.
31 | max_points_per_leaf : The maximum number of points per leaf node in the KD tree used by this function.
32 | Default is 10.
33 |     p_norm : Which norm to use. p_norm can be any real number, inf (for the max norm),
34 |         -inf (for the min norm), or 0 (for sum(x != 0))
35 | Returns:
36 |         The chamfer distance between x and y.
37 | If return_index is set, then this function returns a tuple (chamfer_dist, corrs_x_to_y, corrs_y_to_x) where
38 | corrs_x_to_y and corrs_y_to_x are described above.
39 | """
40 |
41 |     _, corrs_x_to_y = pcu.k_nearest_neighbors(
42 |         x, y, k=1, squared_distances=False, max_points_per_leaf=max_points_per_leaf
43 |     )
44 |     _, corrs_y_to_x = pcu.k_nearest_neighbors(
45 |         y, x, k=1, squared_distances=False, max_points_per_leaf=max_points_per_leaf
46 |     )
47 | 
48 |     dists_y_to_x = np.linalg.norm(x[corrs_y_to_x] - y, axis=-1, ord=p_norm)
49 |     dists_x_to_y = np.linalg.norm(y[corrs_x_to_y] - x, axis=-1, ord=p_norm)
50 | 
51 |     Comp = np.mean(dists_y_to_x)  # mean distance from each y to its nearest x
52 |     Acc = np.mean(dists_x_to_y)  # mean distance from each x to its nearest y
53 |     cham_dist = Comp + Acc
54 |
55 | if return_index:
56 | return cham_dist, corrs_x_to_y, corrs_y_to_x
57 |
58 | return cham_dist, Acc, Comp
59 |
60 |
61 | def compute_chamfer_distance(pred_sampled, gt_points):
62 | """
63 | Compute chamfer distance between predicted and ground truth points.
64 |
65 | Args:
66 | pred_sampled (numpy.ndarray): Predicted point cloud.
67 | gt_points (numpy.ndarray): Ground truth points.
68 |
69 | Returns:
70 |         tuple: (chamfer distance, accuracy, completeness).
71 | """
72 | chamfer_dist, acc, comp = chamfer_distance(pred_sampled, gt_points)
73 | return chamfer_dist, acc, comp
74 |
75 |
76 | def f_score(precision, recall):
77 | """
78 | Compute F-score.
79 |
80 | Args:
81 | precision (float): Precision.
82 | recall (float): Recall.
83 |
84 | Returns:
85 | float: F-score.
86 | """
87 | return 2 * precision * recall / (precision + recall)
88 |
89 |
90 | def bezier_curve_length(control_points, num_samples):
91 | def binomial_coefficient(n, i):
92 | return math.factorial(n) // (math.factorial(i) * math.factorial(n - i))
93 |
94 | def derivative_bezier(t):
95 | n = len(control_points) - 1
96 | point = np.array([0.0, 0.0, 0.0])
97 | for i, (p1, p2) in enumerate(zip(control_points[:-1], control_points[1:])):
98 | point += (
99 | n
100 | * binomial_coefficient(n - 1, i)
101 | * (1 - t) ** (n - 1 - i)
102 | * t**i
103 | * (np.array(p2) - np.array(p1))
104 | )
105 | return point
106 |
107 | def simpson_integral(a, b, num_samples):
108 | h = (b - a) / num_samples
109 | sum1 = sum(
110 | np.linalg.norm(derivative_bezier(a + i * h))
111 | for i in range(1, num_samples, 2)
112 | )
113 | sum2 = sum(
114 | np.linalg.norm(derivative_bezier(a + i * h))
115 | for i in range(2, num_samples - 1, 2)
116 | )
117 | return (
118 | (
119 | np.linalg.norm(derivative_bezier(a))
120 | + 4 * sum1
121 | + 2 * sum2
122 | + np.linalg.norm(derivative_bezier(b))
123 | )
124 | * h
125 | / 3
126 | )
127 |
128 | # Compute the length of the 3D Bezier curve using composite Simpson's rule
129 | length = 0.0
130 | for i in range(num_samples):
131 | t0 = i / num_samples
132 | t1 = (i + 1) / num_samples
133 | length += simpson_integral(t0, t1, num_samples)
134 |
135 | return length
136 |
137 |
138 | def compute_precision_recall_IOU(
139 | pred_sampled, gt_points, metrics, thresh_list=[0.02], edge_type="all"
140 | ):
141 | """
142 | Compute precision, recall, F-score, and IOU.
143 |
144 | Args:
145 | pred_sampled (numpy.ndarray): Predicted point cloud.
146 | gt_points (numpy.ndarray): Ground truth points.
147 |         metrics (dict): Metrics dict, used when edge_type == "all".
148 |         thresh_list (list): Distance thresholds; edge_type (str): "all"/"curve"/"line".
149 |     Returns:
150 |         Updated metrics dict, or (correct_gt, num_gt, correct_pred, num_pred, acc, comp) when edge_type != "all".
151 | """
152 | if edge_type == "all":
153 | for thresh in thresh_list:
154 | dists_a_to_b, _ = pcu.k_nearest_neighbors(
155 | pred_sampled, gt_points, k=1
156 | ) # k closest points (in pts_b) for each point in pts_a
157 | correct_pred = np.sum(dists_a_to_b < thresh)
158 | precision = correct_pred / len(dists_a_to_b)
159 | metrics[f"precision_{thresh}"].append(precision)
160 |
161 | dists_b_to_a, _ = pcu.k_nearest_neighbors(gt_points, pred_sampled, k=1)
162 | correct_gt = np.sum(dists_b_to_a < thresh)
163 | recall = correct_gt / len(dists_b_to_a)
164 | metrics[f"recall_{thresh}"].append(recall)
165 |
166 | fscore = 2 * precision * recall / (precision + recall)
167 | metrics[f"fscore_{thresh}"].append(fscore)
168 |
169 | intersection = min(correct_pred, correct_gt)
170 | union = (
171 | len(dists_a_to_b) + len(dists_b_to_a) - max(correct_pred, correct_gt)
172 | )
173 |
174 | IOU = intersection / union
175 | metrics[f"IOU_{thresh}"].append(IOU)
176 | return metrics
177 | else:
178 | correct_gt_list = []
179 | correct_pred_list = []
180 | _, acc, comp = compute_chamfer_distance(pred_sampled, gt_points)
181 | for thresh in thresh_list:
182 | dists_b_to_a, _ = pcu.k_nearest_neighbors(gt_points, pred_sampled, k=1)
183 | correct_gt = np.sum(dists_b_to_a < thresh)
184 | num_gt = len(dists_b_to_a)
185 | correct_gt_list.append(correct_gt)
186 | dists_a_to_b, _ = pcu.k_nearest_neighbors(pred_sampled, gt_points, k=1)
187 | correct_pred = np.sum(dists_a_to_b < thresh)
188 | correct_pred_list.append(correct_pred)
189 | num_pred = len(dists_a_to_b)
190 |
191 | return correct_gt_list, num_gt, correct_pred_list, num_pred, acc, comp
192 |
193 |
194 | def get_gt_points(
195 | scan_name,
196 | edge_type="all",
197 | interval=0.005,
198 | return_direction=False,
199 | data_base_dir=None,
200 | ):
201 | """
202 | Get ground truth points from a dataset.
203 |
204 | Args:
205 |         scan_name (str): Name of the scan.
206 | 
207 |     Returns:
208 |         tuple: Raw GT points, resampled GT points, and per-point edge directions.
209 | """
210 | objs_dir = os.path.join(data_base_dir, "obj")
211 | obj_names = os.listdir(objs_dir)
212 | obj_names.sort()
213 | index_obj_names = {}
214 | for obj_name in obj_names:
215 | index_obj_names[obj_name[:8]] = obj_name
216 |
217 | json_feats_path = os.path.join(data_base_dir, "chunk_0000_feats.json")
218 | with open(json_feats_path, "r") as f:
219 | json_data_feats = json.load(f)
220 | json_stats_path = os.path.join(data_base_dir, "chunk_0000_stats.json")
221 | with open(json_stats_path, "r") as f:
222 | json_data_stats = json.load(f)
223 |
224 | # get the normalize scale to help align the nerf points and gt points
225 | [
226 | x_min,
227 | y_min,
228 | z_min,
229 | x_max,
230 | y_max,
231 | z_max,
232 | x_range,
233 | y_range,
234 | z_range,
235 | ] = json_data_stats[scan_name]["bbox"]
236 | scale = 1 / max(x_range, y_range, z_range)
237 | poi_center = (
238 | np.array([((x_min + x_max) / 2), ((y_min + y_max) / 2), ((z_min + z_max) / 2)])
239 | * scale
240 | )
241 | set_location = [0.5, 0.5, 0.5] - poi_center # based on the rendering settings
242 | obj_path = os.path.join(objs_dir, index_obj_names[scan_name])
243 |
244 | with open(obj_path, encoding="utf-8") as file:
245 | data = file.readlines()
246 |
247 | vertices_obj = [each.split(" ") for each in data if each.split(" ")[0] == "v"]
248 | vertices_xyz = [
249 | [float(v[1]), float(v[2]), float(v[3].replace("\n", ""))] for v in vertices_obj
250 | ]
251 |
252 | edge_pts = []
253 | edge_pts_raw = []
254 | edge_pts_direction = []
255 | rename = {
256 | "BSpline": "curve",
257 | "Circle": "curve",
258 | "Ellipse": "curve",
259 | "Line": "line",
260 | }
261 | for each_curve in json_data_feats[scan_name]:
262 | if edge_type != "all" and rename[each_curve["type"]] != edge_type:
263 | continue
264 |
265 | if each_curve["sharp"]: # each_curve["type"]: BSpline, Line, Circle
266 | each_edge_pts = [vertices_xyz[i] for i in each_curve["vert_indices"]]
267 | edge_pts_raw += each_edge_pts
268 |
269 | gt_sampling = []
270 | each_edge_pts = np.array(each_edge_pts)
271 | for index in range(len(each_edge_pts) - 1):
272 |                 nxt = each_edge_pts[index + 1]  # avoid shadowing builtin next()
273 |                 cur = each_edge_pts[index]
274 |                 num = int(np.linalg.norm(nxt - cur) // interval)
275 |                 linspace = np.linspace(0, 1, num)
276 |                 gt_sampling.append(
277 |                     linspace[:, None] * cur + (1 - linspace)[:, None] * nxt
278 |                 )
279 | 
280 |                 if return_direction:
281 |                     direction = (nxt - cur) / np.linalg.norm(nxt - cur)
282 |                     edge_pts_direction.extend([direction] * num)
283 | each_edge_pts = np.concatenate(gt_sampling).tolist()
284 | edge_pts += each_edge_pts
285 |
286 | if len(edge_pts_raw) == 0:
287 | return None, None, None
288 |
289 | edge_pts_raw = np.array(edge_pts_raw) * scale + set_location
290 | edge_pts = np.array(edge_pts) * scale + set_location
291 | edge_pts_direction = np.array(edge_pts_direction)
292 |
293 | return (
294 | edge_pts_raw.astype(np.float32),
295 | edge_pts.astype(np.float32),
296 | edge_pts_direction,
297 | )
298 |
299 |
300 | def get_pred_points_and_directions(
301 | json_path,
302 | sample_resolution=0.005,
303 | ):
304 | with open(json_path, "r") as f:
305 | json_data = json.load(f)
306 |
307 | curve_paras = np.array(json_data["curves_ctl_pts"]).reshape(-1, 3)
308 | curves_ctl_pts = curve_paras.reshape(-1, 4, 3)
309 | lines_end_pts = np.array(json_data["lines_end_pts"]).reshape(-1, 2, 3)
310 |
311 | num_curves = len(curves_ctl_pts)
312 | num_lines = len(lines_end_pts)
313 |
314 | all_curve_points = []
315 | all_curve_directions = []
316 |
317 | # # -----------------------------------for Cubic Bezier-----------------------------------
318 | if num_curves > 0:
319 | for i, each_curve in enumerate(curves_ctl_pts):
320 | each_curve = np.array(each_curve).reshape(4, 3) # shape: (4, 3)
321 | sample_num = int(
322 | bezier_curve_length(each_curve, num_samples=100) // sample_resolution
323 | )
324 | t = np.linspace(0, 1, sample_num)
325 | matrix_u = np.array([t**3, t**2, t, [1] * sample_num]).reshape(
326 | 4, sample_num
327 | )
328 |
329 | matrix_middle = np.array(
330 | [[-1, 3, -3, 1], [3, -6, 3, 0], [-3, 3, 0, 0], [1, 0, 0, 0]]
331 | )
332 |
333 | matrix = np.matmul(
334 | np.matmul(matrix_u.T, matrix_middle), each_curve
335 | ).reshape(sample_num, 3)
336 |
337 | all_curve_points += matrix.tolist()
338 |
339 | # Calculate the curve directions (derivatives)
340 | derivative_u = 3 * t**2
341 | derivative_v = 2 * t
342 |
343 | # Derivative matrices for x, y, z
344 | dx = (
345 | (
346 | -3 * each_curve[0][0]
347 | + 9 * each_curve[1][0]
348 | - 9 * each_curve[2][0]
349 | + 3 * each_curve[3][0]
350 | )
351 | * derivative_u
352 | + (6 * each_curve[0][0] - 12 * each_curve[1][0] + 6 * each_curve[2][0])
353 | * derivative_v
354 | + (-3 * each_curve[0][0] + 3 * each_curve[1][0])
355 | )
356 |
357 | dy = (
358 | (
359 | -3 * each_curve[0][1]
360 | + 9 * each_curve[1][1]
361 | - 9 * each_curve[2][1]
362 | + 3 * each_curve[3][1]
363 | )
364 | * derivative_u
365 | + (6 * each_curve[0][1] - 12 * each_curve[1][1] + 6 * each_curve[2][1])
366 | * derivative_v
367 | + (-3 * each_curve[0][1] + 3 * each_curve[1][1])
368 | )
369 |
370 | dz = (
371 | (
372 | -3 * each_curve[0][2]
373 | + 9 * each_curve[1][2]
374 | - 9 * each_curve[2][2]
375 | + 3 * each_curve[3][2]
376 | )
377 | * derivative_u
378 | + (6 * each_curve[0][2] - 12 * each_curve[1][2] + 6 * each_curve[2][2])
379 | * derivative_v
380 | + (-3 * each_curve[0][2] + 3 * each_curve[1][2])
381 | )
382 | for i in range(sample_num):
383 | direction = np.array([dx[i], dy[i], dz[i]])
384 | norm_direction = direction / np.linalg.norm(direction)
385 | all_curve_directions.append(norm_direction)
386 |
387 | all_line_points = []
388 | all_line_directions = []
389 | # # -------------------------------------for Line-----------------------------------------
390 | if num_lines > 0:
391 | for i, each_line in enumerate(lines_end_pts):
392 | each_line = np.array(each_line).reshape(2, 3) # shape: (2, 3)
393 | sample_num = int(
394 | np.linalg.norm(each_line[0] - each_line[-1]) // sample_resolution
395 | )
396 | t = np.linspace(0, 1, sample_num)
397 |
398 | matrix_u_l = np.array([t, [1] * sample_num])
399 | matrix_middle_l = np.array([[-1, 1], [1, 0]])
400 |
401 | matrix_l = np.matmul(
402 | np.matmul(matrix_u_l.T, matrix_middle_l), each_line
403 | ).reshape(sample_num, 3)
404 | all_line_points += matrix_l.tolist()
405 |
406 | # Calculate the direction vector for the line segment
407 | direction = each_line[1] - each_line[0]
408 | norm_direction = direction / (np.linalg.norm(direction) + 1e-6)
409 |
410 | for point in matrix_l:
411 | all_line_directions.append(norm_direction)
412 |
413 | all_curve_points = np.array(all_curve_points).reshape(-1, 3)
414 | all_line_points = np.array(all_line_points).reshape(-1, 3)
415 | return all_curve_points, all_line_points, all_curve_directions, all_line_directions
416 |
417 |
418 | def downsample_point_cloud_average(
419 | points, num_voxels_per_axis=256, min_bound=None, max_bound=None
420 | ):
421 | """
422 | Downsample a point set based on the number of voxels per axis by averaging the points within each voxel.
423 |
424 | Args:
425 | points: a [#v, 3]-shaped array of 3d points.
426 | num_voxels_per_axis: a scalar or 3-tuple specifying the number of voxels along each axis.
427 |
428 | Returns:
429 | A [#v', 3]-shaped numpy array of downsampled points, where #v' is the number of occupied voxels.
430 | """
431 |
432 | # Calculate the bounding box of the point cloud
433 | if min_bound is None:
434 | min_bound = np.min(points, axis=0)
435 | else:
436 | min_bound = np.array(min_bound)
437 |
438 | if max_bound is None:
439 | max_bound = np.max(points, axis=0)
440 | else:
441 | max_bound = np.array(max_bound)
442 |
443 | # Determine the size of the voxel based on the desired number of voxels per axis
444 | if isinstance(num_voxels_per_axis, int):
445 | voxel_size = (max_bound - min_bound) / num_voxels_per_axis
446 | else:
447 | voxel_size = [
448 | (max_bound[i] - min_bound[i]) / num_voxels_per_axis[i] for i in range(3)
449 | ]
450 |
451 | # Use the existing function to downsample the point cloud based on voxel size
452 | downsampled_points = pcu.downsample_point_cloud_on_voxel_grid(
453 | voxel_size, points, min_bound=min_bound, max_bound=max_bound
454 | )
455 |
456 | return downsampled_points
457 |
--------------------------------------------------------------------------------
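`get_pred_points_and_directions` samples each cubic Bézier through the matrix form `[t^3, t^2, t, 1] @ M @ P`, with `M` the standard cubic Bernstein basis matrix. A small sanity sketch with illustrative control points, checking that the matrix form agrees with direct Bernstein evaluation:

```
from math import comb

import numpy as np

ctl = np.array([[0, 0, 0], [1, 2, 0], [2, -1, 1], [3, 0, 0]], dtype=np.float64)
t = np.linspace(0, 1, 50)

# Matrix form used in get_pred_points_and_directions: [t^3, t^2, t, 1] @ M @ P
M = np.array([[-1, 3, -3, 1], [3, -6, 3, 0], [-3, 3, 0, 0], [1, 0, 0, 0]])
U = np.stack([t**3, t**2, t, np.ones_like(t)], axis=1)  # (50, 4)
pts_matrix = U @ M @ ctl

# Direct Bernstein form: sum_i C(3,i) (1-t)^(3-i) t^i P_i
pts_bernstein = sum(
    comb(3, i) * ((1 - t) ** (3 - i) * t**i)[:, None] * ctl[i] for i in range(4)
)
assert np.allclose(pts_matrix, pts_bernstein)
```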
/src/models/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/cvg/EMAP/652f24ecc3f3cbf538928f27cc6d55dbebb360c7/src/models/__init__.py
--------------------------------------------------------------------------------
/src/models/embedder.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 |
4 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf.
5 | class Embedder:
6 | def __init__(self, **kwargs):
7 | self.kwargs = kwargs
8 | self.create_embedding_fn()
9 |
10 | def create_embedding_fn(self):
11 | embed_fns = []
12 | d = self.kwargs["input_dims"]
13 | out_dim = 0
14 | if self.kwargs["include_input"]:
15 | embed_fns.append(lambda x: x)
16 | out_dim += d
17 |
18 | max_freq = self.kwargs["max_freq_log2"]
19 | N_freqs = self.kwargs["num_freqs"]
20 |
21 | if self.kwargs["log_sampling"]:
22 | freq_bands = 2.0 ** torch.linspace(0.0, max_freq, N_freqs)
23 | else:
24 | freq_bands = torch.linspace(2.0**0.0, 2.0**max_freq, N_freqs)
25 |
26 | for freq in freq_bands:
27 | for p_fn in self.kwargs["periodic_fns"]:
28 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq))
29 | out_dim += d
30 |
31 | self.embed_fns = embed_fns
32 | self.out_dim = out_dim
33 |
34 | def embed(self, inputs):
35 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
36 |
37 |
38 | def get_embedder(multires, input_dims=3):
39 | embed_kwargs = {
40 | "include_input": True,
41 | "input_dims": input_dims,
42 | "max_freq_log2": multires - 1,
43 | "num_freqs": multires,
44 | "log_sampling": True,
45 | "periodic_fns": [torch.sin, torch.cos],
46 | }
47 |
48 | embedder_obj = Embedder(**embed_kwargs)
49 |
50 | def embed(x, eo=embedder_obj):
51 | return eo.embed(x)
52 |
53 | return embed, embedder_obj.out_dim
54 |
--------------------------------------------------------------------------------
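Usage sketch for the embedder: with `multires=6` and 3-D inputs, the output dimension is 3 (identity) + 3 x 2 x 6 (sin/cos per frequency band) = 39:

```
import torch
from src.models.embedder import get_embedder

embed_fn, out_dim = get_embedder(multires=6, input_dims=3)
x = torch.rand(8, 3)
print(out_dim)            # 39
print(embed_fn(x).shape)  # torch.Size([8, 39])
```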
/src/models/loss.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch.nn.functional as F
3 |
4 |
5 | class EdgeLoss(nn.Module):
6 |
7 | def __init__(self, loss_type="mse"):
8 | super(EdgeLoss, self).__init__()
9 | if loss_type == "mse":
10 | self.loss_func = F.mse_loss
11 | elif loss_type == "l1":
12 | self.loss_func = F.l1_loss
13 |         else: raise ValueError(f"Unsupported loss_type: {loss_type}")
14 | def forward(self, pred_edge, gt_edge):
15 |
16 | edge_loss = self.loss_func(pred_edge, gt_edge, reduction="mean")
17 | return edge_loss
18 |
--------------------------------------------------------------------------------
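A minimal usage sketch for `EdgeLoss`; the random tensors stand in for rendered and ground-truth edge values:

```
import torch
from src.models.loss import EdgeLoss

loss_fn = EdgeLoss(loss_type="l1")
pred_edge = torch.rand(1024, 1)
gt_edge = torch.rand(1024, 1)
print(loss_fn(pred_edge, gt_edge))  # scalar mean L1 loss
```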
/src/models/udf_model.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import numpy as np
4 | from src.models.embedder import get_embedder
5 |
6 |
7 | class UDFNetwork(nn.Module):
8 | def __init__(
9 | self,
10 | d_in,
11 | d_out,
12 | d_hidden,
13 | n_layers,
14 | skip_in=(4,),
15 | multires=0,
16 | scale=1,
17 | bias=0.5,
18 | geometric_init=True,
19 | weight_norm=True,
20 | udf_type="abs",
21 | ):
22 | super(UDFNetwork, self).__init__()
23 |
24 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out]
25 |
26 | self.embed_fn_fine = None
27 |
28 | if multires > 0:
29 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in)
30 | self.embed_fn_fine = embed_fn
31 | dims[0] = input_ch
32 |
33 | self.num_layers = len(dims)
34 | self.skip_in = skip_in
35 | self.scale = scale
36 |
37 | self.geometric_init = geometric_init
38 |
39 | for l in range(0, self.num_layers - 1):
40 | if l + 1 in self.skip_in:
41 | out_dim = dims[l + 1] - dims[0]
42 | else:
43 | out_dim = dims[l + 1]
44 |
45 | lin = nn.Linear(dims[l], out_dim)
46 |
47 | if geometric_init:
48 | if l == self.num_layers - 2:
49 | torch.nn.init.normal_(
50 | lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001
51 | )
52 | # torch.nn.init.constant_(lin.bias, bias) # for indoor sdf setting
53 | torch.nn.init.constant_(lin.bias, -bias)
54 |
55 | elif multires > 0 and l == 0:
56 | torch.nn.init.constant_(lin.bias, 0.0)
57 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
58 | torch.nn.init.normal_(
59 | lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)
60 | )
61 | elif multires > 0 and l in self.skip_in:
62 | torch.nn.init.constant_(lin.bias, 0.0)
63 | torch.nn.init.normal_(
64 | lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)
65 | )
66 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3) :], 0.0)
67 | else:
68 | torch.nn.init.constant_(lin.bias, 0.0)
69 | torch.nn.init.normal_(
70 | lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)
71 | )
72 |
73 | if weight_norm:
74 | lin = nn.utils.parametrizations.weight_norm(lin)
75 |
76 | setattr(self, "lin" + str(l), lin)
77 |
78 | self.activation = nn.Softplus(beta=100)
79 | self.relu = nn.ReLU()
80 | self.udf_type = udf_type
81 |
82 | def udf_out(self, x):
83 | if self.udf_type == "abs":
84 | return torch.abs(x)
85 | elif self.udf_type == "square":
86 | return x**2
87 | elif self.udf_type == "sdf":
88 | return x
89 |
90 | def forward(self, inputs):
91 | inputs = inputs * self.scale
92 | if self.embed_fn_fine is not None:
93 | inputs = self.embed_fn_fine(inputs)
94 |
95 | x = inputs
96 | for l in range(0, self.num_layers - 1):
97 | lin = getattr(self, "lin" + str(l))
98 |
99 | if l in self.skip_in:
100 | x = torch.cat([x, inputs], 1) / np.sqrt(2)
101 |
102 | x = lin(x)
103 |
104 | if l < self.num_layers - 2:
105 | x = self.activation(x)
106 |
107 | return (
108 | torch.cat([self.udf_out(x[:, :1]) / self.scale, x[:, 1:]], dim=-1),
109 | inputs,
110 | )
111 |
112 | def udf(self, x):
113 | feature_out, PE = self.forward(x)
114 | udf = feature_out[:, :1]
115 | feature = feature_out[:, 1:]
116 | return udf, feature, PE
117 |
118 | def udf_hidden_appearance(self, x):
119 | return self.forward(x)
120 |
121 | def gradient(self, x):
122 | x.requires_grad_(True)
123 | with torch.set_grad_enabled(True):
124 | y = self.udf(x)[0]
125 |
126 | d_output = torch.ones_like(y, requires_grad=False, device=y.device)
127 | gradients = torch.autograd.grad(
128 | outputs=y,
129 | inputs=x,
130 | grad_outputs=d_output,
131 | create_graph=True,
132 | retain_graph=True,
133 | only_inputs=True,
134 | )[0]
135 | return gradients.unsqueeze(1)
136 |
137 |
138 | class RenderingNetwork(nn.Module):
139 | def __init__(
140 | self,
141 | d_feature,
142 | mode,
143 | d_in,
144 | d_out,
145 | d_hidden,
146 | n_layers,
147 | weight_norm=True,
148 | multires_view=0,
149 | squeeze_out=True,
150 | ):
151 | super().__init__()
152 |
153 | self.mode = mode
154 | self.squeeze_out = squeeze_out
155 | self.d_out = d_out
156 | dims = [d_in + d_feature] + [d_hidden for _ in range(n_layers)] + [d_out]
157 |
158 | self.embedview_fn = None
159 | if multires_view > 0 and self.mode != "no_view_dir":
160 | embedview_fn, input_ch = get_embedder(multires_view)
161 | self.embedview_fn = embedview_fn
162 | dims[0] += input_ch - 3
163 |
164 | self.num_layers = len(dims)
165 |
166 | for l in range(0, self.num_layers - 1):
167 | out_dim = dims[l + 1]
168 | lin = nn.Linear(dims[l], out_dim)
169 |
170 | if weight_norm:
171 | lin = nn.utils.parametrizations.weight_norm(lin)
172 |
173 | setattr(self, "lin" + str(l), lin)
174 |
175 | self.relu = nn.ReLU()
176 |
177 | def forward(self, points, normals, view_dirs, feature_vectors):
178 | if self.embedview_fn is not None:
179 | view_dirs = self.embedview_fn(view_dirs)
180 |
181 | rendering_input = None
182 | normals = normals.detach()
183 | if self.mode == "idr":
184 | rendering_input = torch.cat(
185 | [points, view_dirs, normals, -1 * normals, feature_vectors], dim=-1
186 | )
187 | elif self.mode == "no_view_dir":
188 | rendering_input = torch.cat(
189 | [points, normals, -1 * normals, feature_vectors], dim=-1
190 | )
191 | elif self.mode == "no_normal":
192 | rendering_input = torch.cat([points, view_dirs, feature_vectors], dim=-1)
193 |
194 | x = rendering_input
195 |
196 | for l in range(0, self.num_layers - 1):
197 | lin = getattr(self, "lin" + str(l))
198 |
199 | x = lin(x)
200 |
201 | if l < self.num_layers - 2:
202 | x = self.relu(x)
203 |
204 | if self.squeeze_out:
205 | color = torch.sigmoid(x[:, : self.d_out])
206 | else:
207 | color = x[:, : self.d_out]
208 |
209 | return color
210 |
211 |
212 | class SingleVarianceNetwork(nn.Module):
213 | def __init__(self, init_val, requires_grad=True):
214 | super(SingleVarianceNetwork, self).__init__()
215 | self.variance = nn.Parameter(
216 | torch.Tensor([init_val]), requires_grad=requires_grad
217 | )
218 | self.second_variance = nn.Parameter(
219 | torch.Tensor([init_val]), requires_grad=requires_grad
220 | )
221 |
222 | def set_trainable(self):
223 | self.variance.requires_grad = True
224 | self.second_variance.requires_grad = True
225 |
226 | def forward(self, x):
227 | return torch.ones([len(x), 1]).to(x.device) * torch.exp(self.variance * 10.0)
228 |
229 | def get_secondvariance(self, x):
230 | return torch.ones([len(x), 1]).to(x.device) * torch.exp(
231 | self.second_variance * 10.0
232 | )
233 |
234 |
235 | class BetaNetwork(nn.Module):
236 | def __init__(
237 | self,
238 | init_var_beta=0.1,
239 | init_var_gamma=0.1,
240 | init_var_zeta=0.05,
241 | beta_min=0.00005,
242 | requires_grad_beta=True,
243 | requires_grad_gamma=True,
244 | requires_grad_zeta=True,
245 | ):
246 | super().__init__()
247 |
248 | self.beta = nn.Parameter(
249 | torch.Tensor([init_var_beta]), requires_grad=requires_grad_beta
250 | )
251 | self.gamma = nn.Parameter(
252 | torch.Tensor([init_var_gamma]), requires_grad=requires_grad_gamma
253 | )
254 | self.zeta = nn.Parameter(
255 | torch.Tensor([init_var_zeta]), requires_grad=requires_grad_zeta
256 | )
257 | self.beta_min = beta_min
258 |
259 | def get_beta(self):
260 | return torch.exp(self.beta * 10).clip(0, 1.0 / self.beta_min)
261 |
262 | def get_gamma(self):
263 | return torch.exp(self.gamma * 10)
264 |
265 | def get_zeta(self):
266 | """
267 |         Zeta parameter of the UDF-to-probability mapping zeta * x / (1 + zeta * x).
268 |         Returns the absolute value of the learned parameter so that the
269 |         mapping stays non-negative and monotonically increasing in x.
270 | """
271 | return self.zeta.abs()
272 |
273 | def set_beta_trainable(self):
274 | self.beta.requires_grad = True
275 |
276 | @torch.no_grad()
277 | def set_gamma(self, x):
278 | self.gamma = nn.Parameter(
279 | torch.Tensor([x]), requires_grad=self.gamma.requires_grad
280 | ).to(self.gamma.device)
281 |
282 | def forward(self):
283 | beta = self.get_beta()
284 | gamma = self.get_gamma()
285 | zeta = self.get_zeta()
286 | return beta, gamma, zeta
287 |
--------------------------------------------------------------------------------
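A minimal sketch instantiating the UDF MLP and querying distances and gradients at random points. The hyperparameters below are illustrative, not the values shipped in the `confs/*.conf` files:

```
import torch
from src.models.udf_model import UDFNetwork

net = UDFNetwork(d_in=3, d_out=257, d_hidden=256, n_layers=8,
                 skip_in=(4,), multires=6, scale=1.0, udf_type="abs")

x = torch.rand(16, 3) * 2 - 1   # query points in [-1, 1]^3
udf, feature, _ = net.udf(x)    # unsigned distance + feature vector
grad = net.gradient(x)          # spatial gradient of the UDF
print(udf.shape, feature.shape, grad.shape)  # (16, 1) (16, 256) (16, 1, 3)
```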
/src/runner/runner_base.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn.functional as F
3 | import os
4 | import numpy as np
5 | from shutil import copyfile
6 | from icecream import ic
7 | from pyhocon import ConfigFactory, HOCONConverter
8 | from src.dataset.dataset import Dataset
9 | from src.models.udf_model import (
10 | UDFNetwork,
11 | BetaNetwork,
12 | SingleVarianceNetwork,
13 | )
14 | from src.models.udf_renderer_blending import UDFRendererBlending
15 | from src.models.loss import EdgeLoss
16 |
17 |
18 | class Runner:
19 | def __init__(
20 | self,
21 | conf,
22 | mode="train",
23 | is_continue=False,
24 | args=None,
25 | ):
26 | # Initial setting
27 | self.device = torch.device("cuda")
28 | self.conf = conf
29 |
30 | self.base_exp_dir = os.path.join(
31 | self.conf["general.base_exp_dir"],
32 | str(self.conf["dataset"]["scan"]),
33 | self.conf["general.expname"],
34 | )
35 | os.makedirs(self.base_exp_dir, exist_ok=True)
36 |
37 | self.dataset = Dataset(self.conf["dataset"])
38 | self.near, self.far = self.dataset.near, self.dataset.far
39 |
40 | self.iter_step = 0
41 |
42 |         # training parameters
43 | self.end_iter = self.conf.get_int("train.end_iter")
44 | self.save_freq = self.conf.get_int("train.save_freq")
45 | self.report_freq = self.conf.get_int("train.report_freq")
46 | self.val_freq = self.conf.get_int("train.val_freq")
47 | self.batch_size = self.conf.get_int("train.batch_size")
48 | self.validate_resolution_level = self.conf.get_int(
49 | "train.validate_resolution_level"
50 | )
51 | self.use_white_bkgd = self.conf.get_bool("train.use_white_bkgd")
52 | self.importance_sample = self.conf.get_bool("train.importance_sample")
53 |
54 | # setting about learning rate schedule
55 | self.learning_rate = self.conf.get_float("train.learning_rate")
56 | self.learning_rate_geo = self.conf.get_float("train.learning_rate_geo")
57 | self.learning_rate_alpha = self.conf.get_float("train.learning_rate_alpha")
58 | self.warm_up_end = self.conf.get_float("train.warm_up_end", default=0.0)
59 | self.anneal_end = self.conf.get_float("train.anneal_end", default=0.0)
60 | # don't train the udf network in the early steps
61 | self.fix_geo_end = self.conf.get_float("train.fix_geo_end", default=200)
62 | self.warmup_sample = self.conf.get_bool(
63 | "train.warmup_sample", default=False
64 | ) # * training schedule
65 | # whether the udf network and appearance network share the same learning rate
66 | self.same_lr = self.conf.get_bool("train.same_lr", default=False)
67 |
68 | # weights
69 | self.igr_weight = self.conf.get_float("train.igr_weight")
70 | self.igr_ns_weight = self.conf.get_float("train.igr_ns_weight", default=0.0)
71 |
72 | # loss functions
73 | self.edge_loss_func = EdgeLoss(self.conf["edge_loss"]["loss_type"])
74 | self.edge_weight = self.conf.get_float("edge_loss.edge_weight", 0.0)
75 | self.is_continue = is_continue
76 | # self.is_finetune = args.is_finetune
77 |
78 | self.mode = mode
79 | self.model_type = self.conf["general.model_type"]
80 | self.model_list = []
81 | self.writer = None
82 |
83 | # Networks
84 | params_to_train = []
85 | params_to_train_nerf = []
86 | params_to_train_geo = []
87 |
88 | self.nerf_outside = None
89 | self.nerf_coarse = None
90 | self.nerf_fine = None
91 | self.sdf_network_fine = None
92 | self.udf_network_fine = None
93 | self.variance_network_fine = None
94 |
95 | # self.nerf_outside = NeRF(**self.conf["model.nerf"]).to(self.device)
96 | self.udf_network_fine = UDFNetwork(**self.conf["model.udf_network"]).to(
97 | self.device
98 | )
99 | self.variance_network_fine = SingleVarianceNetwork(
100 | **self.conf["model.variance_network"]
101 | ).to(self.device)
102 | self.beta_network = BetaNetwork(**self.conf["model.beta_network"]).to(
103 | self.device
104 | )
105 | # params_to_train_nerf += list(self.nerf_outside.parameters())
106 | params_to_train_geo += list(self.udf_network_fine.parameters())
107 | params_to_train += list(self.variance_network_fine.parameters())
108 | params_to_train += list(self.beta_network.parameters())
109 |
110 | self.optimizer = torch.optim.Adam(
111 | [
112 | {"params": params_to_train_geo, "lr": self.learning_rate_geo},
113 | {"params": params_to_train},
114 | {"params": params_to_train_nerf},
115 | ],
116 | lr=self.learning_rate,
117 | )
118 |
119 | self.renderer = UDFRendererBlending(
120 | self.nerf_outside,
121 | self.udf_network_fine,
122 | self.variance_network_fine,
123 | self.beta_network,
124 | device=self.device,
125 | **self.conf["model.udf_renderer"],
126 | )
127 |
128 | def update_learning_rate(self, start_g_id=0):
129 | if self.iter_step < self.warm_up_end:
130 | learning_factor = self.iter_step / self.warm_up_end
131 | else:
132 | alpha = self.learning_rate_alpha
133 | progress = (self.iter_step - self.warm_up_end) / (
134 | self.end_iter - self.warm_up_end
135 | )
136 | learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (
137 | 1 - alpha
138 | ) + alpha
139 |
140 | for g in self.optimizer.param_groups[start_g_id:]:
141 | g["lr"] = self.learning_rate * learning_factor
142 |
143 | def update_learning_rate_geo(self):
144 | if self.iter_step < self.fix_geo_end: # * make bg nerf learn first
145 | learning_factor = 0.0
146 | elif self.iter_step < self.warm_up_end * 2:
147 | learning_factor = self.iter_step / (self.warm_up_end * 2)
148 | elif self.iter_step < self.end_iter * 0.5:
149 | learning_factor = 1.0
150 | else:
151 | alpha = self.learning_rate_alpha
152 | progress = (self.iter_step - self.end_iter * 0.5) / (
153 | self.end_iter - self.end_iter * 0.5
154 | )
155 | learning_factor = (np.cos(np.pi * progress) + 1.0) * 0.5 * (
156 | 1 - alpha
157 | ) + alpha
158 |
159 | for g in self.optimizer.param_groups[:1]:
160 | g["lr"] = self.learning_rate_geo * learning_factor
161 |
162 | def get_cos_anneal_ratio(self):
163 | if self.anneal_end == 0.0:
164 | return 1.0
165 | else:
166 | return np.min([1.0, self.iter_step / self.anneal_end])
167 |
168 | def train(self):
169 | self.train_udf()
170 |
171 | def get_flip_saturation(self, flip_saturation_max=0.9):
172 | start = 10000
173 | if self.iter_step < start:
174 | flip_saturation = 0.0
175 | elif self.iter_step < self.end_iter * 0.5:
176 | flip_saturation = flip_saturation_max
177 | else:
178 | flip_saturation = 1.0
179 |
180 | return flip_saturation
181 |
182 | def file_backup(self):
183 | # copy python file
184 | dir_lis = self.conf["general.recording"]
185 | cur_dir = os.path.join(self.base_exp_dir, "recording")
186 | os.makedirs(cur_dir, exist_ok=True)
187 | files = os.listdir("./")
188 | for f_name in files:
189 | if f_name[-3:] == ".py":
190 | copyfile(os.path.join("./", f_name), os.path.join(cur_dir, f_name))
191 |
192 | for dir_name in dir_lis:
193 | os.system(f"cp -r {dir_name} {cur_dir}")
194 |
195 | # copy configs
196 | # copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf'))
197 | with open(
198 | os.path.join(self.base_exp_dir, "recording", "config.conf"), "w"
199 | ) as fd:
200 | res = HOCONConverter.to_hocon(self.conf)
201 | fd.write(res)
202 |
203 |     def train_udf(self):
204 |         raise NotImplementedError
205 | 
206 |     def load_checkpoint(self, checkpoint_name):
207 |         raise NotImplementedError
208 | 
209 |     def save_checkpoint(self):
210 |         raise NotImplementedError
211 | 
212 |     def validate(self, idx=-1, resolution_level=-1):
213 |         raise NotImplementedError
214 |
--------------------------------------------------------------------------------
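`update_learning_rate` applies a linear warmup followed by a cosine decay toward the floor `learning_rate_alpha`. The schedule factor in isolation, with illustrative step counts:

```
import numpy as np

def lr_factor(step, warm_up_end, end_iter, alpha):
    """Warmup-then-cosine factor as computed in Runner.update_learning_rate."""
    if step < warm_up_end:
        return step / warm_up_end
    progress = (step - warm_up_end) / (end_iter - warm_up_end)
    return (np.cos(np.pi * progress) + 1.0) * 0.5 * (1 - alpha) + alpha

for step in [0, 500, 1000, 25000, 50000]:
    print(step, round(lr_factor(step, warm_up_end=1000, end_iter=50000, alpha=0.05), 4))
```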
/src/utils/__init__.py:
--------------------------------------------------------------------------------
1 | import torch
2 |
3 | # optimizer
4 | from torch.optim import SGD, Adam
5 | import torch_optimizer as optim
6 |
7 | # scheduler
8 | from torch.optim.lr_scheduler import CosineAnnealingLR, MultiStepLR, LambdaLR
9 | from .warmup_scheduler import GradualWarmupScheduler
10 |
11 | from .visualization import *
12 | from .plots import *
13 |
14 |
15 | def get_parameters(models):
16 | """Get all model parameters recursively."""
17 | parameters = []
18 | if isinstance(models, list):
19 | for model in models:
20 | parameters += get_parameters(model)
21 | elif isinstance(models, dict):
22 | for model in models.values():
23 | parameters += get_parameters(model)
24 | else: # models is actually a single pytorch model
25 | parameters += list(models.parameters())
26 | return parameters
27 |
28 |
29 | def get_optimizer(hparams, models):
30 | eps = 1e-8
31 | parameters = get_parameters(models)
32 | if hparams.optimizer == "sgd":
33 | optimizer = SGD(
34 | parameters,
35 | lr=hparams.lr,
36 | momentum=hparams.momentum,
37 | weight_decay=hparams.weight_decay,
38 | )
39 | elif hparams.optimizer == "adam":
40 | optimizer = Adam(
41 | parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay
42 | )
43 | elif hparams.optimizer == "radam":
44 | optimizer = optim.RAdam(
45 | parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay
46 | )
47 | elif hparams.optimizer == "ranger":
48 | optimizer = optim.Ranger(
49 | parameters, lr=hparams.lr, eps=eps, weight_decay=hparams.weight_decay
50 | )
51 | else:
52 | raise ValueError("optimizer not recognized!")
53 |
54 | return optimizer
55 |
56 |
57 | def get_scheduler(hparams, optimizer):
58 | eps = 1e-8
59 | if hparams.lr_scheduler == "steplr":
60 | scheduler = MultiStepLR(
61 | optimizer, milestones=hparams.decay_step, gamma=hparams.decay_gamma
62 | )
63 | elif hparams.lr_scheduler == "cosine":
64 | scheduler = CosineAnnealingLR(optimizer, T_max=hparams.num_epochs, eta_min=eps)
65 | elif hparams.lr_scheduler == "poly":
66 | scheduler = LambdaLR(
67 | optimizer,
68 | lambda epoch: (1 - epoch / hparams.num_epochs) ** hparams.poly_exp,
69 | )
70 | else:
71 | raise ValueError("scheduler not recognized!")
72 |
73 | if hparams.warmup_epochs > 0 and hparams.optimizer not in ["radam", "ranger"]:
74 | scheduler = GradualWarmupScheduler(
75 | optimizer,
76 | multiplier=hparams.warmup_multiplier,
77 | total_epoch=hparams.warmup_epochs,
78 | after_scheduler=scheduler,
79 | )
80 |
81 | return scheduler
82 |
83 |
84 | def get_learning_rate(optimizer):
85 | for param_group in optimizer.param_groups:
86 | return param_group["lr"]
87 |
88 |
89 | def extract_model_state_dict(ckpt_path, model_name="model", prefixes_to_ignore=[]):
90 | checkpoint = torch.load(ckpt_path, map_location=torch.device("cpu"))
91 | checkpoint_ = {}
92 | if "state_dict" in checkpoint: # if it's a pytorch-lightning checkpoint
93 | checkpoint = checkpoint["state_dict"]
94 | for k, v in checkpoint.items():
95 | if not k.startswith(model_name):
96 | continue
97 | k = k[len(model_name) + 1 :]
98 | for prefix in prefixes_to_ignore:
99 | if k.startswith(prefix):
100 | print("ignore", k)
101 | break
102 | else:
103 | checkpoint_[k] = v
104 | return checkpoint_
105 |
106 |
107 | def load_ckpt(model, ckpt_path, model_name="model", prefixes_to_ignore=[]):
108 | if not ckpt_path:
109 | return
110 | model_dict = model.state_dict()
111 | checkpoint_ = extract_model_state_dict(ckpt_path, model_name, prefixes_to_ignore)
112 | model_dict.update(checkpoint_)
113 | model.load_state_dict(model_dict)
114 |
--------------------------------------------------------------------------------
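Usage sketch for the optimizer/scheduler factories. The `hparams` namespace here is hypothetical; in practice these fields come from the experiment configuration:

```
from argparse import Namespace

import torch.nn as nn

from src.utils import get_optimizer, get_scheduler

# Hypothetical hyperparameters covering the fields the factories read.
hparams = Namespace(optimizer="adam", lr=5e-4, weight_decay=0.0,
                    lr_scheduler="cosine", num_epochs=20,
                    warmup_epochs=2, warmup_multiplier=1.0)

model = nn.Linear(3, 1)
optimizer = get_optimizer(hparams, model)
scheduler = get_scheduler(hparams, optimizer)  # cosine schedule wrapped in warmup
```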
/src/utils/math.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """ Math Helper Functions """
16 |
17 | from dataclasses import dataclass
18 |
19 | import torch
20 | from torchtyping import TensorType
21 |
22 |
23 | def components_from_spherical_harmonics(
24 | levels: int, directions: TensorType[..., 3]
25 | ) -> TensorType[..., "components"]:
26 | """
27 | Returns value for each component of spherical harmonics.
28 |
29 | Args:
30 | levels: Number of spherical harmonic levels to compute.
31 |         directions: Unit direction vectors to evaluate the harmonics at.
32 | """
33 | num_components = levels**2
34 | components = torch.zeros(
35 | (*directions.shape[:-1], num_components), device=directions.device
36 | )
37 |
38 |     assert 1 <= levels <= 5, f"SH levels must be in [1,5], got {levels}"
39 | assert (
40 | directions.shape[-1] == 3
41 | ), f"Direction input should have three dimensions. Got {directions.shape[-1]}"
42 |
43 | x = directions[..., 0]
44 | y = directions[..., 1]
45 | z = directions[..., 2]
46 |
47 | xx = x**2
48 | yy = y**2
49 | zz = z**2
50 |
51 | # l0
52 | components[..., 0] = 0.28209479177387814
53 |
54 | # l1
55 | if levels > 1:
56 | components[..., 1] = 0.4886025119029199 * y
57 | components[..., 2] = 0.4886025119029199 * z
58 | components[..., 3] = 0.4886025119029199 * x
59 |
60 | # l2
61 | if levels > 2:
62 | components[..., 4] = 1.0925484305920792 * x * y
63 | components[..., 5] = 1.0925484305920792 * y * z
64 | components[..., 6] = 0.9461746957575601 * zz - 0.31539156525251999
65 | components[..., 7] = 1.0925484305920792 * x * z
66 | components[..., 8] = 0.5462742152960396 * (xx - yy)
67 |
68 | # l3
69 | if levels > 3:
70 | components[..., 9] = 0.5900435899266435 * y * (3 * xx - yy)
71 | components[..., 10] = 2.890611442640554 * x * y * z
72 | components[..., 11] = 0.4570457994644658 * y * (5 * zz - 1)
73 | components[..., 12] = 0.3731763325901154 * z * (5 * zz - 3)
74 | components[..., 13] = 0.4570457994644658 * x * (5 * zz - 1)
75 | components[..., 14] = 1.445305721320277 * z * (xx - yy)
76 | components[..., 15] = 0.5900435899266435 * x * (xx - 3 * yy)
77 |
78 | # l4
79 | if levels > 4:
80 | components[..., 16] = 2.5033429417967046 * x * y * (xx - yy)
81 | components[..., 17] = 1.7701307697799304 * y * z * (3 * xx - yy)
82 | components[..., 18] = 0.9461746957575601 * x * y * (7 * zz - 1)
83 | components[..., 19] = 0.6690465435572892 * y * (7 * zz - 3)
84 | components[..., 20] = 0.10578554691520431 * (35 * zz * zz - 30 * zz + 3)
85 | components[..., 21] = 0.6690465435572892 * x * z * (7 * zz - 3)
86 | components[..., 22] = 0.47308734787878004 * (xx - yy) * (7 * zz - 1)
87 | components[..., 23] = 1.7701307697799304 * x * z * (xx - 3 * yy)
88 | components[..., 24] = 0.4425326924449826 * (
89 | xx * (xx - 3 * yy) - yy * (3 * xx - yy)
90 | )
91 |
92 | return components
93 |
94 |
95 | @dataclass
96 | class Gaussians:
97 | """Stores Gaussians
98 |
99 | Args:
100 | mean: Mean of multivariate Gaussian
101 | cov: Covariance of multivariate Gaussian.
102 | """
103 |
104 | mean: TensorType[..., "dim"]
105 | cov: TensorType[..., "dim", "dim"]
106 |
107 |
108 | def compute_3d_gaussian(
109 | directions: TensorType[..., 3],
110 | means: TensorType[..., 3],
111 | dir_variance: TensorType[..., 1],
112 | radius_variance: TensorType[..., 1],
113 | ) -> Gaussians:
114 | """Compute guassian along ray.
115 |
116 | Args:
117 | directions: Axis of Gaussian.
118 | means: Mean of Gaussian.
119 | dir_variance: Variance along direction axis.
120 | radius_variance: Variance tangent to direction axis.
121 |
122 | Returns:
123 | Gaussians: Oriented 3D gaussian.
124 | """
125 |
126 | dir_outer_product = directions[..., :, None] * directions[..., None, :]
127 | eye = torch.eye(directions.shape[-1], device=directions.device)
128 | dir_mag_sq = torch.clamp(
129 | torch.sum(directions**2, dim=-1, keepdim=True), min=1e-10
130 | )
131 | null_outer_product = (
132 | eye - directions[..., :, None] * (directions / dir_mag_sq)[..., None, :]
133 | )
134 | dir_cov_diag = dir_variance[..., None] * dir_outer_product[..., :, :]
135 | radius_cov_diag = radius_variance[..., None] * null_outer_product[..., :, :]
136 | cov = dir_cov_diag + radius_cov_diag
137 | return Gaussians(mean=means, cov=cov)
138 |
139 |
140 | def cylinder_to_gaussian(
141 | origins: TensorType[..., 3],
142 | directions: TensorType[..., 3],
143 | starts: TensorType[..., 1],
144 | ends: TensorType[..., 1],
145 | radius: TensorType[..., 1],
146 | ) -> Gaussians:
147 | """Approximates cylinders with a Gaussian distributions.
148 |
149 | Args:
150 | origins: Origins of cylinders.
151 | directions: Direction (axis) of cylinders.
152 | starts: Start of cylinders.
153 | ends: End of cylinders.
154 | radius: Radii of cylinders.
155 |
156 | Returns:
157 | Gaussians: Approximation of cylinders
158 | """
159 | means = origins + directions * ((starts + ends) / 2.0)
160 | dir_variance = (ends - starts) ** 2 / 12
161 | radius_variance = radius**2 / 4.0
162 | return compute_3d_gaussian(directions, means, dir_variance, radius_variance)
163 |
164 |
165 | def conical_frustum_to_gaussian(
166 | origins: TensorType[..., 3],
167 | directions: TensorType[..., 3],
168 | starts: TensorType[..., 1],
169 | ends: TensorType[..., 1],
170 | radius: TensorType[..., 1],
171 | ) -> Gaussians:
172 | """Approximates conical frustums with a Gaussian distributions.
173 |
174 | Uses stable parameterization described in mip-NeRF publication.
175 |
176 | Args:
177 | origins: Origins of cones.
178 | directions: Direction (axis) of frustums.
179 | starts: Start of conical frustums.
180 | ends: End of conical frustums.
181 | radius: Radii of cone a distance of 1 from the origin.
182 |
183 | Returns:
184 | Gaussians: Approximation of conical frustums
185 | """
186 | mu = (starts + ends) / 2.0
187 | hw = (ends - starts) / 2.0
188 | means = origins + directions * (
189 | mu + (2.0 * mu * hw**2.0) / (3.0 * mu**2.0 + hw**2.0)
190 | )
191 | dir_variance = (hw**2) / 3 - (4 / 15) * (
192 | (hw**4 * (12 * mu**2 - hw**2)) / (3 * mu**2 + hw**2) ** 2
193 | )
194 | radius_variance = radius**2 * (
195 | (mu**2) / 4
196 | + (5 / 12) * hw**2
197 | - 4 / 15 * (hw**4) / (3 * mu**2 + hw**2)
198 | )
199 | return compute_3d_gaussian(directions, means, dir_variance, radius_variance)
200 |
201 |
202 | def expected_sin(x_means: torch.Tensor, x_vars: torch.Tensor) -> torch.Tensor:
203 | """Computes the expected value of sin(y) where y ~ N(x_means, x_vars)
204 |
205 | Args:
206 | x_means: Mean values.
207 | x_vars: Variance of values.
208 |
209 | Returns:
210 | torch.Tensor: The expected value of sin.
211 | """
212 |
213 | return torch.exp(-0.5 * x_vars) * torch.sin(x_means)
214 |
--------------------------------------------------------------------------------
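A quick usage sketch for the spherical-harmonics helper: `levels=2` yields `levels**2 = 4` components, the constant band plus the three linear bands:

```
import torch
import torch.nn.functional as F

from src.utils.math import components_from_spherical_harmonics

dirs = F.normalize(torch.randn(4, 3), dim=-1)  # unit view directions
sh = components_from_spherical_harmonics(levels=2, directions=dirs)
print(sh.shape)  # torch.Size([4, 4])
```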
/src/utils/rend_util.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import imageio
3 | import skimage
4 | import cv2
5 | import torch
6 | from torch.nn import functional as F
7 |
8 |
9 | def get_psnr(img1, img2, normalize_rgb=False):
10 | if normalize_rgb: # [-1,1] --> [0,1]
11 | img1 = (img1 + 1.0) / 2.0
12 | img2 = (img2 + 1.0) / 2.0
13 |
14 | mse = torch.mean((img1 - img2) ** 2)
15 | psnr = -10.0 * torch.log(mse) / torch.log(torch.Tensor([10.0]).cuda())
16 |
17 | return psnr
18 |
19 |
20 | def load_rgb(path, normalize_rgb=False):
21 | img = imageio.imread(path)
22 | img = skimage.img_as_float32(img)
23 |
24 | if normalize_rgb: # [-1,1] --> [0,1]
25 | img -= 0.5
26 | img *= 2.0
27 | img = img.transpose(2, 0, 1)
28 | return img
29 |
30 |
31 | def load_K_Rt_from_P(filename, P=None):
32 | if P is None:
33 | lines = open(filename).read().splitlines()
34 | if len(lines) == 4:
35 | lines = lines[1:]
36 | lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
37 | P = np.asarray(lines).astype(np.float32).squeeze()
38 |
39 | out = cv2.decomposeProjectionMatrix(P)
40 | K = out[0]
41 | R = out[1]
42 | t = out[2]
43 |
44 | K = K / K[2, 2]
45 | intrinsics = np.eye(4)
46 | intrinsics[:3, :3] = K
47 |
48 | pose = np.eye(4, dtype=np.float32)
49 | pose[:3, :3] = R.transpose()
50 | pose[:3, 3] = (t[:3] / t[3])[:, 0]
51 |
52 | return intrinsics, pose
53 |
54 |
55 | def get_camera_params(uv, pose, intrinsics):
56 | if pose.shape[1] == 7: # In case of quaternion vector representation
57 | cam_loc = pose[:, 4:]
58 | R = quat_to_rot(pose[:, :4])
59 | p = torch.eye(4).repeat(pose.shape[0], 1, 1).cuda().float()
60 | p[:, :3, :3] = R
61 | p[:, :3, 3] = cam_loc
62 | else: # In case of pose matrix representation
63 | cam_loc = pose[:, :3, 3]
64 | p = pose
65 |
66 | batch_size, num_samples, _ = uv.shape
67 |
68 | depth = torch.ones((batch_size, num_samples)).cuda()
69 | x_cam = uv[:, :, 0].view(batch_size, -1)
70 | y_cam = uv[:, :, 1].view(batch_size, -1)
71 | z_cam = depth.view(batch_size, -1)
72 |
73 | pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)
74 |
75 | # permute for batch matrix product
76 | pixel_points_cam = pixel_points_cam.permute(0, 2, 1)
77 |
78 | world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
79 | ray_dirs = world_coords - cam_loc[:, None, :]
80 | ray_dirs = F.normalize(ray_dirs, dim=2)
81 |
82 | return ray_dirs, cam_loc
83 |
84 |
85 | def get_camera_for_plot(pose):
86 | if pose.shape[1] == 7: # In case of quaternion vector representation
87 | cam_loc = pose[:, 4:].detach()
88 | R = quat_to_rot(pose[:, :4].detach())
89 | else: # In case of pose matrix representation
90 | cam_loc = pose[:, :3, 3]
91 | R = pose[:, :3, :3]
92 | cam_dir = R[:, :3, 2]
93 | return cam_loc, cam_dir
94 |
95 |
96 | def lift(x, y, z, intrinsics):
97 | # parse intrinsics
98 | intrinsics = intrinsics.cuda()
99 | fx = intrinsics[:, 0, 0]
100 | fy = intrinsics[:, 1, 1]
101 | cx = intrinsics[:, 0, 2]
102 | cy = intrinsics[:, 1, 2]
103 | sk = intrinsics[:, 0, 1]
104 |
105 | x_lift = (
106 | (
107 | x
108 | - cx.unsqueeze(-1)
109 | + cy.unsqueeze(-1) * sk.unsqueeze(-1) / fy.unsqueeze(-1)
110 | - sk.unsqueeze(-1) * y / fy.unsqueeze(-1)
111 | )
112 | / fx.unsqueeze(-1)
113 | * z
114 | )
115 | y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
116 |
117 | # homogeneous
118 | return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1)
119 |
120 |
121 | def quat_to_rot(q):
122 | batch_size, _ = q.shape
123 | q = F.normalize(q, dim=1)
124 | R = torch.ones((batch_size, 3, 3)).cuda()
125 | qr = q[:, 0]
126 | qi = q[:, 1]
127 | qj = q[:, 2]
128 | qk = q[:, 3]
129 | R[:, 0, 0] = 1 - 2 * (qj**2 + qk**2)
130 | R[:, 0, 1] = 2 * (qj * qi - qk * qr)
131 | R[:, 0, 2] = 2 * (qi * qk + qr * qj)
132 | R[:, 1, 0] = 2 * (qj * qi + qk * qr)
133 | R[:, 1, 1] = 1 - 2 * (qi**2 + qk**2)
134 | R[:, 1, 2] = 2 * (qj * qk - qi * qr)
135 | R[:, 2, 0] = 2 * (qk * qi - qj * qr)
136 | R[:, 2, 1] = 2 * (qj * qk + qi * qr)
137 | R[:, 2, 2] = 1 - 2 * (qi**2 + qj**2)
138 | return R
139 |
140 |
141 | def rot_to_quat(R):
142 | batch_size, _, _ = R.shape
143 | q = torch.ones((batch_size, 4)).cuda()
144 |
145 | R00 = R[:, 0, 0]
146 | R01 = R[:, 0, 1]
147 | R02 = R[:, 0, 2]
148 | R10 = R[:, 1, 0]
149 | R11 = R[:, 1, 1]
150 | R12 = R[:, 1, 2]
151 | R20 = R[:, 2, 0]
152 | R21 = R[:, 2, 1]
153 | R22 = R[:, 2, 2]
154 |
155 | q[:, 0] = torch.sqrt(1.0 + R00 + R11 + R22) / 2
156 | q[:, 1] = (R21 - R12) / (4 * q[:, 0])
157 | q[:, 2] = (R02 - R20) / (4 * q[:, 0])
158 | q[:, 3] = (R10 - R01) / (4 * q[:, 0])
159 | return q
160 |
161 |
162 | def get_sphere_intersections(cam_loc, ray_directions, r=1.0):
163 | # Input: n_rays x 3 ; n_rays x 3
164 | # Output: n_rays x 1, n_rays x 1 (close and far)
165 |
166 | ray_cam_dot = torch.bmm(
167 | ray_directions.view(-1, 1, 3), cam_loc.view(-1, 3, 1)
168 | ).squeeze(-1)
169 | under_sqrt = ray_cam_dot**2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r**2)
170 |
171 |     # sanity check: every ray must intersect the bounding sphere
172 |     if (under_sqrt <= 0).sum() > 0:
173 |         raise ValueError(
174 |             "BOUNDING SPHERE PROBLEM: some rays miss the bounding sphere")
175 |
176 | sphere_intersections = (
177 | torch.sqrt(under_sqrt) * torch.Tensor([-1, 1]).cuda().float() - ray_cam_dot
178 | )
179 | sphere_intersections = sphere_intersections.clamp_min(0.0)
180 |
181 | return sphere_intersections
182 |
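# Hedged usage sketch (illustrative): a single ray against the unit bounding
# sphere. A camera 2 units from the origin looking straight at it should get
# near/far roots at t = 1 and t = 3. Requires CUDA (the helper allocates on
# .cuda()).
#
#   cam = torch.tensor([[0.0, 0.0, -2.0]]).cuda()       # (1, 3) camera center
#   d = torch.tensor([[0.0, 0.0, 1.0]]).cuda()          # (1, 3) unit ray direction
#   near_far = get_sphere_intersections(cam, d, r=1.0)  # tensor([[1., 3.]])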
--------------------------------------------------------------------------------
/src/utils/tensor_dataclass.py:
--------------------------------------------------------------------------------
1 | # Copyright 2022 The Nerfstudio Team. All rights reserved.
2 | #
3 | # Licensed under the Apache License, Version 2.0 (the "License");
4 | # you may not use this file except in compliance with the License.
5 | # You may obtain a copy of the License at
6 | #
7 | # http://www.apache.org/licenses/LICENSE-2.0
8 | #
9 | # Unless required by applicable law or agreed to in writing, software
10 | # distributed under the License is distributed on an "AS IS" BASIS,
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 | # See the License for the specific language governing permissions and
13 | # limitations under the License.
14 |
15 | """Tensor dataclass"""
16 |
17 | import dataclasses
18 | from copy import deepcopy
19 | from typing import Callable, Dict, List, NoReturn, Optional, Tuple, TypeVar, Union
20 |
21 | import numpy as np
22 | import torch
23 |
24 | TensorDataclassT = TypeVar("TensorDataclassT", bound="TensorDataclass")
25 |
26 |
27 | class TensorDataclass:
28 |     """@dataclass of tensors with the same batch size. Allows indexing and standard tensor ops.
29 | Fields that are not Tensors will not be batched unless they are also a TensorDataclass.
30 | Any fields that are dictionaries will have their Tensors or TensorDataclasses batched, and
31 | dictionaries will have their tensors or TensorDataclasses considered in the initial broadcast.
32 | Tensor fields must have at least 1 dimension, meaning that you must convert a field like torch.Tensor(1)
33 | to torch.Tensor([1])
34 |
35 | Example:
36 |
37 | .. code-block:: python
38 |
39 | @dataclass
40 | class TestTensorDataclass(TensorDataclass):
41 | a: torch.Tensor
42 | b: torch.Tensor
43 | c: torch.Tensor = None
44 |
45 | # Create a new tensor dataclass with batch size of [2,3,4]
46 | test = TestTensorDataclass(a=torch.ones((2, 3, 4, 2)), b=torch.ones((4, 3)))
47 |
48 | test.shape # [2, 3, 4]
49 | test.a.shape # [2, 3, 4, 2]
50 | test.b.shape # [2, 3, 4, 3]
51 |
52 | test.reshape((6,4)).shape # [6, 4]
53 | test.flatten().shape # [24,]
54 |
55 | test[..., 0].shape # [2, 3]
56 | test[:, 0, :].shape # [2, 4]
57 | """
58 |
59 | _shape: tuple
60 |
61 | # A mapping from field-name (str): n (int)
62 | # Any field OR any key in a dictionary field with this name (field-name) and a corresponding
63 | # torch.Tensor will be assumed to have n dimensions after the batch dims. These n final dimensions
64 | # will remain the same shape when doing reshapes, broadcasting, etc on the tensordataclass
65 | _field_custom_dimensions: Dict[str, int] = {}
66 |
67 | def __post_init__(self) -> None:
68 | """Finishes setting up the TensorDataclass
69 |
70 | This will 1) find the broadcasted shape and 2) broadcast all fields to this shape 3)
71 | set _shape to be the broadcasted shape.
72 | """
73 | if self._field_custom_dimensions is not None:
74 | for k, v in self._field_custom_dimensions.items():
75 | assert (
76 | isinstance(v, int) and v > 1
77 | ), f"Custom dimensions must be an integer greater than 1, since 1 is the default, received {k}: {v}"
78 |
79 | if not dataclasses.is_dataclass(self):
80 | raise TypeError("TensorDataclass must be a dataclass")
81 |
82 | batch_shapes = self._get_dict_batch_shapes(
83 | {f.name: self.__getattribute__(f.name) for f in dataclasses.fields(self)}
84 | )
85 | if len(batch_shapes) == 0:
86 | raise ValueError("TensorDataclass must have at least one tensor")
87 | batch_shape = torch.broadcast_shapes(*batch_shapes)
88 |
89 | broadcasted_fields = self._broadcast_dict_fields(
90 | {f.name: self.__getattribute__(f.name) for f in dataclasses.fields(self)},
91 | batch_shape,
92 | )
93 | for f, v in broadcasted_fields.items():
94 | self.__setattr__(f, v)
95 |
96 | self.__setattr__("_shape", batch_shape)
97 |
98 | def _get_dict_batch_shapes(self, dict_: Dict) -> List:
99 | """Returns batch shapes of all tensors in a dictionary
100 |
101 | Args:
102 | dict_: The dictionary to get the batch shapes of.
103 |
104 | Returns:
105 | The batch shapes of all tensors in the dictionary.
106 | """
107 | batch_shapes = []
108 | for k, v in dict_.items():
109 | if isinstance(v, torch.Tensor):
110 | if (
111 | isinstance(self._field_custom_dimensions, dict)
112 | and k in self._field_custom_dimensions
113 | ):
114 | # pylint: disable=unsubscriptable-object
115 | batch_shapes.append(v.shape[: -self._field_custom_dimensions[k]])
116 | else:
117 | batch_shapes.append(v.shape[:-1])
118 | elif isinstance(v, TensorDataclass):
119 | batch_shapes.append(v.shape)
120 | elif isinstance(v, Dict):
121 | batch_shapes.extend(self._get_dict_batch_shapes(v))
122 | return batch_shapes
123 |
124 | def _broadcast_dict_fields(self, dict_: Dict, batch_shape) -> Dict:
125 | """Broadcasts all tensors in a dictionary according to batch_shape
126 |
127 | Args:
128 |             dict_: The dictionary to broadcast.
129 |             batch_shape: The batch shape to broadcast the tensors to.
129 |
130 | Returns:
131 | The broadcasted dictionary.
132 | """
133 | new_dict = {}
134 | for k, v in dict_.items():
135 | if isinstance(v, torch.Tensor):
136 |                 # If this key has custom trailing dimensions, keep them intact while broadcasting the batch dims
137 | if (
138 | isinstance(self._field_custom_dimensions, dict)
139 | and k in self._field_custom_dimensions
140 | ):
141 | # pylint: disable=unsubscriptable-object
142 | new_dict[k] = v.broadcast_to(
143 | (
144 | *batch_shape,
145 | *v.shape[-self._field_custom_dimensions[k] :],
146 | )
147 | )
148 | else:
149 | new_dict[k] = v.broadcast_to((*batch_shape, v.shape[-1]))
150 | elif isinstance(v, TensorDataclass):
151 | new_dict[k] = v.broadcast_to(batch_shape)
152 | elif isinstance(v, Dict):
153 | new_dict[k] = self._broadcast_dict_fields(v, batch_shape)
154 | return new_dict
155 |
156 | def __getitem__(self: TensorDataclassT, indices) -> TensorDataclassT:
157 |         if isinstance(indices, torch.Tensor):
158 | return self._apply_fn_to_fields(lambda x: x[indices])
159 | if isinstance(indices, (int, slice, type(Ellipsis))):
160 | indices = (indices,)
161 | assert isinstance(indices, tuple)
162 | tensor_fn = lambda x: x[indices + (slice(None),)]
163 | dataclass_fn = lambda x: x[indices]
164 |
165 | def custom_tensor_dims_fn(k, v):
166 | custom_dims = self._field_custom_dimensions[
167 | k
168 | ] # pylint: disable=unsubscriptable-object
169 | return v[indices + ((slice(None),) * custom_dims)]
170 |
171 | return self._apply_fn_to_fields(
172 | tensor_fn, dataclass_fn, custom_tensor_dims_fn=custom_tensor_dims_fn
173 | )
174 |
175 | def __setitem__(self, indices, value) -> NoReturn:
176 | raise RuntimeError("Index assignment is not supported for TensorDataclass")
177 |
178 | def __len__(self) -> int:
179 | if len(self._shape) == 0:
180 | raise TypeError("len() of a 0-d tensor")
181 | return self.shape[0]
182 |
183 | def __bool__(self) -> bool:
184 | if len(self) == 0:
185 | raise ValueError(
186 | f"The truth value of {self.__class__.__name__} when `len(x) == 0` "
187 | "is ambiguous. Use `len(x)` or `x is not None`."
188 | )
189 | return True
190 |
191 | @property
192 | def shape(self) -> Tuple[int, ...]:
193 | """Returns the batch shape of the tensor dataclass."""
194 | return self._shape
195 |
196 | @property
197 | def size(self) -> int:
198 | """Returns the number of elements in the tensor dataclass batch dimension."""
199 | if len(self._shape) == 0:
200 | return 1
201 | return int(np.prod(self._shape))
202 |
203 | @property
204 | def ndim(self) -> int:
205 | """Returns the number of dimensions of the tensor dataclass."""
206 | return len(self._shape)
207 |
208 | def reshape(self: TensorDataclassT, shape: Tuple[int, ...]) -> TensorDataclassT:
209 | """Returns a new TensorDataclass with the same data but with a new shape.
210 |
211 | This should deepcopy as well.
212 |
213 | Args:
214 | shape: The new shape of the tensor dataclass.
215 |
216 | Returns:
217 | A new TensorDataclass with the same data but with a new shape.
218 | """
219 | if isinstance(shape, int):
220 | shape = (shape,)
221 | tensor_fn = lambda x: x.reshape((*shape, x.shape[-1]))
222 | dataclass_fn = lambda x: x.reshape(shape)
223 |
224 | def custom_tensor_dims_fn(k, v):
225 | custom_dims = self._field_custom_dimensions[
226 | k
227 | ] # pylint: disable=unsubscriptable-object
228 | return v.reshape((*shape, *v.shape[-custom_dims:]))
229 |
230 | return self._apply_fn_to_fields(
231 | tensor_fn, dataclass_fn, custom_tensor_dims_fn=custom_tensor_dims_fn
232 | )
233 |
234 | def flatten(self: TensorDataclassT) -> TensorDataclassT:
235 | """Returns a new TensorDataclass with flattened batch dimensions
236 |
237 | Returns:
238 | TensorDataclass: A new TensorDataclass with the same data but with a new shape.
239 | """
240 | return self.reshape((-1,))
241 |
242 | def broadcast_to(
243 | self: TensorDataclassT, shape: Union[torch.Size, Tuple[int, ...]]
244 | ) -> TensorDataclassT:
245 | """Returns a new TensorDataclass broadcast to new shape.
246 |
247 |         Changes to the original tensor dataclass should affect the returned tensor dataclass,
248 | meaning it is NOT a deepcopy, and they are still linked.
249 |
250 | Args:
251 | shape: The new shape of the tensor dataclass.
252 |
253 | Returns:
254 | A new TensorDataclass with the same data but with a new shape.
255 | """
256 |
257 | def custom_tensor_dims_fn(k, v):
258 | custom_dims = self._field_custom_dimensions[
259 | k
260 | ] # pylint: disable=unsubscriptable-object
261 | return v.broadcast_to((*shape, *v.shape[-custom_dims:]))
262 |
263 | return self._apply_fn_to_fields(
264 | lambda x: x.broadcast_to((*shape, x.shape[-1])),
265 | custom_tensor_dims_fn=custom_tensor_dims_fn,
266 | )
267 |
268 | def to(self: TensorDataclassT, device) -> TensorDataclassT:
269 | """Returns a new TensorDataclass with the same data but on the specified device.
270 |
271 | Args:
272 | device: The device to place the tensor dataclass.
273 |
274 | Returns:
275 | A new TensorDataclass with the same data but on the specified device.
276 | """
277 | return self._apply_fn_to_fields(lambda x: x.to(device))
278 |
279 | def _apply_fn_to_fields(
280 | self: TensorDataclassT,
281 | fn: Callable,
282 | dataclass_fn: Optional[Callable] = None,
283 | custom_tensor_dims_fn: Optional[Callable] = None,
284 | ) -> TensorDataclassT:
285 | """Applies a function to all fields of the tensor dataclass.
286 |
287 |         TODO: Someone needs to make a high-level design choice about whether or not we want this
288 |         to apply the function to fields inherited from arbitrary superclasses. This is an edge case until we
289 |         upgrade to Python 3.10, where dataclasses can be subclassed with vanilla Python and no
290 |         workarounds; but if someone writes subclasses that are grandchildren of TensorDataclass
291 |         (imagine subclassing RayBundle), it will matter even before the upgrade to 3.10.
292 |         Grandchildren currently do not work properly; to support them, iterate over
293 |         self.__dict__ instead of the dictionary built from dataclasses.fields(self),
294 |         as we do below and elsewhere.
295 |
296 | Args:
297 | fn: The function to apply to tensor fields.
298 |             dataclass_fn: The function to apply to TensorDataclass fields.
299 |             custom_tensor_dims_fn: Applied to tensor fields that carry custom trailing dimensions.
299 |
300 | Returns:
301 | A new TensorDataclass with the same data but with a new shape.
302 | """
303 |
304 | new_fields = self._apply_fn_to_dict(
305 | {f.name: self.__getattribute__(f.name) for f in dataclasses.fields(self)},
306 | fn,
307 | dataclass_fn,
308 | custom_tensor_dims_fn,
309 | )
310 |
311 | return dataclasses.replace(self, **new_fields)
312 |
313 | def _apply_fn_to_dict(
314 | self,
315 | dict_: Dict,
316 | fn: Callable,
317 | dataclass_fn: Optional[Callable] = None,
318 | custom_tensor_dims_fn: Optional[Callable] = None,
319 | ) -> Dict:
320 | """A helper function for _apply_fn_to_fields, applying a function to all fields of dict_
321 |
322 | Args:
323 | dict_: The dictionary to apply the function to.
324 | fn: The function to apply to tensor fields.
325 | dataclass_fn: The function to apply to TensorDataclass fields.
326 |
327 | Returns:
328 |             A new dictionary with the same data but with a new shape. Non-tensor values are deep copied."""
329 |
330 | field_names = dict_.keys()
331 | new_dict = {}
332 | for f in field_names:
333 | v = dict_[f]
334 | if v is not None:
335 | if isinstance(v, TensorDataclass) and dataclass_fn is not None:
336 | new_dict[f] = dataclass_fn(v)
337 | # This is the case when we have a custom dimensions tensor
338 | elif (
339 | isinstance(v, torch.Tensor)
340 | and isinstance(self._field_custom_dimensions, dict)
341 | and f in self._field_custom_dimensions
342 | and custom_tensor_dims_fn is not None
343 | ):
344 | new_dict[f] = custom_tensor_dims_fn(f, v)
345 | elif isinstance(v, (torch.Tensor, TensorDataclass)):
346 | new_dict[f] = fn(v)
347 | elif isinstance(v, Dict):
348 |                 new_dict[f] = self._apply_fn_to_dict(v, fn, dataclass_fn, custom_tensor_dims_fn)
349 | else:
350 | new_dict[f] = deepcopy(v)
351 |
352 | return new_dict
353 |
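# Hedged usage sketch (illustrative, not part of the original file): a
# concrete subclass exercising the broadcast-and-index behaviour described in
# the class docstring. The Rays class and its fields are made up.
#
#   from dataclasses import dataclass
#
#   @dataclass
#   class Rays(TensorDataclass):
#       origins: torch.Tensor     # (..., 3)
#       directions: torch.Tensor  # (..., 3)
#
#   rays = Rays(origins=torch.zeros(2, 3, 3), directions=torch.ones(3, 3))
#   rays.shape                 # (2, 3): fields broadcast to a common batch shape
#   rays.flatten().shape       # (6,)
#   rays[0].directions.shape   # torch.Size([3, 3])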
--------------------------------------------------------------------------------
/src/utils/visualization.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import cv2
3 | import flow_vis
7 |
8 |
9 | def visualize_depth(x, cmap=cv2.COLORMAP_JET):
10 |     """
11 |     x: depth map of shape (H, W)
12 |     """
13 |     x = np.nan_to_num(x)  # change nan to 0
14 |     mi = np.min(x)  # minimum depth
15 |     ma = np.max(x)  # maximum depth
16 |     x = (x - mi) / max(ma - mi, 1e-8)  # normalize to 0~1
17 |     x = (255 * x).astype(np.uint8)
18 |     return cv2.applyColorMap(x, cmap)  # (H, W, 3) uint8 BGR visualization
23 |
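# Hedged usage sketch (illustrative): colorizing a synthetic depth map. The
# output path is hypothetical.
#
#   depth = np.linspace(0.5, 2.0, 480 * 640).reshape(480, 640)
#   vis = visualize_depth(depth)       # (480, 640, 3) uint8 BGR image
#   cv2.imwrite("depth_vis.png", vis)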
24 |
25 | def get_flow_vis(ang):
26 | # norm = line_neighborhood + 1 - np.clip(df, 0, line_neighborhood)
27 | flow_uv = 5 * np.stack([np.cos(ang), np.sin(ang)], axis=-1)
28 | flow_img = flow_vis.flow_to_color(flow_uv, convert_to_bgr=False)
29 | return flow_img
30 |
31 |
32 | COLOR_MAP_ = np.array(
33 | [
34 | [0.9047944201469568, 0.3241718265806123, 0.33443746665210006],
35 | [0.4590171386127151, 0.9095038146383864, 0.3143840671974788],
36 | [0.4769356899795538, 0.5044406738441948, 0.5354530846360839],
37 | [0.00820945625670777, 0.24099210193126785, 0.15471834055332978],
38 | [0.6195684374237388, 0.4020380013509799, 0.26100266066404676],
39 | [0.08281237756545068, 0.05900744492710419, 0.06106221202154216],
40 | [0.2264886829978755, 0.04925271007292076, 0.10214429345996079],
41 | [0.1888247470009874, 0.11275000298612425, 0.46112894830685514],
42 | [0.37415767691880975, 0.844284596118331, 0.950471611180866],
43 | [0.3817344218157631, 0.3483259270707101, 0.6572989333690541],
44 | [0.2403115731054466, 0.03078280287279167, 0.5385975692534737],
45 | [0.7035076951650824, 0.12352084932325424, 0.12873080308790197],
46 | [0.12607434914489934, 0.111244793010015, 0.09333334699716023],
47 | [0.6551607300342269, 0.7003064103554443, 0.4131794512286162],
48 | [0.13592107365596595, 0.5390702818232149, 0.004540643174930525],
49 | [0.38286244894454347, 0.709142545393449, 0.529074791609835],
50 | [0.4279376583651734, 0.5634708596431771, 0.8505569717104301],
51 | [0.3460488523902999, 0.464769595519293, 0.6676839675477276],
52 | [0.8544063246675081, 0.5041190233407755, 0.9081217697141578],
53 | [0.9207009090747208, 0.2403865944739051, 0.05375410999863772],
54 | [0.6515786136947107, 0.6299918449948327, 0.45292029442034387],
55 | [0.986174217295693, 0.2424849846977214, 0.3981993323108266],
56 | [0.22101915872994693, 0.3408589198278038, 0.006381420347677524],
57 | [0.3159785813515982, 0.1145748921741011, 0.595754317197274],
58 | [0.10263421488052715, 0.5864139253490858, 0.23908000741142432],
59 | [0.8272999391532938, 0.6123527260897751, 0.3365197327803193],
60 | [0.5269583712937912, 0.25668929554516506, 0.7888411215078127],
61 | [0.2433880265410031, 0.7240751234287827, 0.8483215810528648],
62 | [0.7254601709704898, 0.8316525547295984, 0.9325253855921963],
63 | [0.5574483824856672, 0.2935331727879944, 0.6594839453793155],
64 | [0.6209642371433579, 0.054030693198821256, 0.5080873988178534],
65 | [0.9055507077365624, 0.12865888619203514, 0.9309191861440005],
66 | [0.9914469722960537, 0.3074114506206205, 0.8762107657323488],
67 | [0.4812682518247371, 0.15055826298548158, 0.9656340505308308],
68 | [0.6459219454316445, 0.9144794010251625, 0.751338812155106],
69 | [0.860840174209798, 0.8844626353077639, 0.3604624506769899],
70 | [0.8194991672032272, 0.926399617787601, 0.8059222327343247],
71 | [0.6540413175393658, 0.04579445254618297, 0.26891917826531275],
72 | [0.37778835833987046, 0.36247927666109536, 0.7989799305827889],
73 | [0.22738304978177726, 0.9038018263773739, 0.6970838854138303],
74 | [0.6362015495896184, 0.527680794236961, 0.5570915425178721],
75 | [0.6436401915860954, 0.6316925317144524, 0.9137151236993912],
76 | [0.04161828388587163, 0.3832413349082706, 0.6880829921949752],
77 | [0.7768167825719299, 0.8933821497682587, 0.7221278391266809],
78 | [0.8632760876301346, 0.3278628094906323, 0.8421587587114462],
79 | [0.8556499133262127, 0.6497385872901932, 0.5436895688477963],
80 | [0.9861940318610894, 0.03562313777386272, 0.9183454677106616],
81 | [0.8042586091176366, 0.6167222703170994, 0.24181981557207644],
82 | [0.9504247117633057, 0.3454233714011461, 0.6883727005547743],
83 | [0.9611909135491202, 0.46384154263898114, 0.32700443315058914],
84 | [0.523542176970206, 0.446222414615845, 0.9067402987747814],
85 | [0.7536954008682911, 0.6675512338797588, 0.22538238957839196],
86 | [0.1554052265688285, 0.05746097492966129, 0.8580358872587424],
87 | [0.8540838640971405, 0.9165504335482566, 0.6806982829158964],
88 | [0.7065090319405029, 0.8683059983962002, 0.05167128320624026],
89 | [0.39134812961899124, 0.8910075505622979, 0.7639815712623922],
90 | [0.1578117311479783, 0.20047326898284668, 0.9220177338840568],
91 | [0.2017488993096358, 0.6949259970936679, 0.8729196864798128],
92 | [0.5591089340651949, 0.15576770423813258, 0.1469857469387812],
93 | [0.14510398622626974, 0.24451497734532168, 0.46574271993578786],
94 | [0.13286397822351492, 0.4178244533944635, 0.03728728952131943],
95 | [0.556463206310225, 0.14027595183361663, 0.2731537988657907],
96 | [0.4093837966398032, 0.8015225687789814, 0.8033567296903834],
97 | [0.527442563956637, 0.902232617214431, 0.7066626674362227],
98 | [0.9058355503297827, 0.34983989180213004, 0.8353262183839384],
99 | [0.7108382186953104, 0.08591307895133471, 0.21434688012521974],
100 | [0.22757345065207668, 0.7943075496583976, 0.2992305547627421],
101 | [0.20454109788173636, 0.8251670332103687, 0.012981987094547232],
102 | [0.7672562637297392, 0.005429019973062554, 0.022163616037108702],
103 | [0.37487345910117564, 0.5086240194440863, 0.9061216063654387],
104 | [0.9878004014101087, 0.006345852772772331, 0.17499753379350858],
105 | [0.030061528704491303, 0.1409704315546606, 0.3337131835834506],
106 | [0.5022506782611504, 0.5448435505388706, 0.40584238936140726],
107 | [0.39560774627423445, 0.8905943695833262, 0.5850815030921116],
108 | [0.058615671926786406, 0.5365713844300387, 0.1620457551256279],
109 | [0.41843842882069693, 0.1536005983609976, 0.3127878501592438],
110 | [0.05947621790155899, 0.5412421167331932, 0.2611322146455659],
111 | [0.5196159938235607, 0.7066461551682705, 0.970261497412556],
112 | [0.30443031606149007, 0.45158581060034975, 0.4331841153149706],
113 | [0.8848298403933996, 0.7241791700943656, 0.8917110054596072],
114 | [0.5720260591898779, 0.3072801598203052, 0.8891066705989902],
115 | [0.13964015336177327, 0.2531778096760302, 0.5703756837403124],
116 | [0.2156307542329836, 0.4139947500641685, 0.87051676884144],
117 | [0.10800455881891169, 0.05554646035458266, 0.2947027428551443],
118 | [0.35198009410633857, 0.365849666213808, 0.06525787683513773],
119 | [0.5223264108118847, 0.9032195574351178, 0.28579084943315025],
120 | [0.7607724246546966, 0.3087194381828555, 0.6253235528354899],
121 | [0.5060485442077824, 0.19173600467625274, 0.9931175692203702],
122 | [0.5131805830323746, 0.07719515392040577, 0.923212006754969],
123 | [0.3629762141280106, 0.02429179642710888, 0.6963754952399983],
124 | [0.7542592485456767, 0.6478893299494212, 0.3424965345400731],
125 | [0.49944574453364454, 0.6775665366832825, 0.33758796076989583],
126 | [0.010621818120767679, 0.8221571611173205, 0.5186257457566332],
127 | [0.5857910304290109, 0.7178133992025467, 0.9729243483606071],
128 | [0.16987399482717613, 0.9942570210657463, 0.18120758122552927],
129 | [0.016362572521240848, 0.17582788603087263, 0.7255176922640298],
130 | [0.10981764283706419, 0.9078582203470377, 0.7638063718334003],
131 | [0.9252097840441119, 0.3330197086990039, 0.27888705301420136],
132 | [0.12769972651171546, 0.11121470804891687, 0.12710743734391716],
133 | [0.5753520518360334, 0.2763862879599456, 0.6115636613363361],
134 | ]
135 | )
136 |
137 |
138 | def prepare_semseg(img):
139 | if img.ndim == 3:
140 | img = img[..., 0]
141 | assert (
142 | img.ndim == 2
143 | ), f"Expecting 2D numpy array with semseg classes, got {img.shape}"
144 |     colors = COLOR_MAP_
145 |     img_color_ids = sorted(np.unique(img))
146 |     # compact the raw class ids into a dense 0..N-1 range so they index COLOR_MAP_
147 |     id_map = np.zeros(np.max(img) + 1, dtype=np.int32)
148 |     id_map[img_color_ids] = np.arange(len(img_color_ids))
149 |     img = id_map[img]
150 |     img = colors[img.astype(np.int32)] * 255
151 |     return img
152 |
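# Hedged usage sketch (illustrative): colorizing a toy label map. Class ids
# are compacted to a dense range before indexing COLOR_MAP_.
#
#   labels = np.array([[0, 5], [5, 0]])
#   rgb = prepare_semseg(labels)   # (2, 2, 3) float array scaled to 0..255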
--------------------------------------------------------------------------------
/src/utils/warmup_scheduler.py:
--------------------------------------------------------------------------------
1 | from torch.optim.lr_scheduler import _LRScheduler
2 | from torch.optim.lr_scheduler import ReduceLROnPlateau
3 |
4 |
5 | class GradualWarmupScheduler(_LRScheduler):
6 |     """Gradually warms up (increases) the learning rate in the optimizer.
7 | Proposed in 'Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour'.
8 | Args:
9 | optimizer (Optimizer): Wrapped optimizer.
10 | multiplier: target learning rate = base lr * multiplier
11 |         total_epoch: the target learning rate is reached gradually at total_epoch
12 |         after_scheduler: scheduler to hand off to after total_epoch (e.g. ReduceLROnPlateau)
13 | """
14 |
15 | def __init__(self, optimizer, multiplier, total_epoch, after_scheduler=None):
16 | self.multiplier = multiplier
17 | if self.multiplier < 1.0:
18 |             raise ValueError("multiplier should be greater than or equal to 1.")
19 | self.total_epoch = total_epoch
20 | self.after_scheduler = after_scheduler
21 | self.finished = False
22 | super().__init__(optimizer)
23 |
24 | def get_lr(self):
25 | if self.last_epoch > self.total_epoch:
26 | if self.after_scheduler:
27 | if not self.finished:
28 | self.after_scheduler.base_lrs = [
29 | base_lr * self.multiplier for base_lr in self.base_lrs
30 | ]
31 | self.finished = True
32 | return self.after_scheduler.get_lr()
33 | return [base_lr * self.multiplier for base_lr in self.base_lrs]
34 |
35 | return [
36 | base_lr
37 | * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
38 | for base_lr in self.base_lrs
39 | ]
40 |
41 | def step_ReduceLROnPlateau(self, metrics, epoch=None):
42 | if epoch is None:
43 | epoch = self.last_epoch + 1
44 | self.last_epoch = (
45 | epoch if epoch != 0 else 1
46 | ) # ReduceLROnPlateau is called at the end of epoch, whereas others are called at beginning
47 | if self.last_epoch <= self.total_epoch:
48 | warmup_lr = [
49 | base_lr
50 | * ((self.multiplier - 1.0) * self.last_epoch / self.total_epoch + 1.0)
51 | for base_lr in self.base_lrs
52 | ]
53 | for param_group, lr in zip(self.optimizer.param_groups, warmup_lr):
54 | param_group["lr"] = lr
55 |         else:
56 |             # epoch is never None here; it was filled in at the top of this method
57 |             self.after_scheduler.step(metrics, epoch - self.total_epoch)
60 |
61 | def step(self, epoch=None, metrics=None):
62 |         if not isinstance(self.after_scheduler, ReduceLROnPlateau):
63 | if self.finished and self.after_scheduler:
64 | if epoch is None:
65 | self.after_scheduler.step(None)
66 | else:
67 | self.after_scheduler.step(epoch - self.total_epoch)
68 | else:
69 | return super(GradualWarmupScheduler, self).step(epoch)
70 | else:
71 | self.step_ReduceLROnPlateau(metrics, epoch)
72 |
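# Hedged usage sketch (illustrative): warm the learning rate up from base_lr
# toward multiplier * base_lr over the first 5 epochs, then hand off to
# cosine annealing. The parameters and epoch counts below are made up.
#
#   import torch
#   from torch.optim.lr_scheduler import CosineAnnealingLR
#
#   params = [torch.nn.Parameter(torch.zeros(1))]
#   optim = torch.optim.Adam(params, lr=1e-4)
#   cosine = CosineAnnealingLR(optim, T_max=95)
#   sched = GradualWarmupScheduler(optim, multiplier=2.0, total_epoch=5,
#                                  after_scheduler=cosine)
#   for _ in range(100):
#       optim.step()
#       sched.step()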
--------------------------------------------------------------------------------