├── .gitignore
├── .gitmodules
├── README.md
├── configs
│   ├── default.yaml
│   └── replica.yaml
├── dataset
│   └── dataset.py
├── geom
│   ├── __init__.py
│   └── plane_utils.py
├── lib
│   ├── alphatablets.py
│   ├── keyframe.py
│   ├── load_colmap.py
│   ├── load_data.py
│   └── load_replica.py
├── recon
│   ├── __init__.py
│   ├── run_depth.py
│   ├── run_recon.py
│   ├── run_sp.py
│   └── utils.py
├── requirements.txt
├── run.py
├── test.py
└── tools
    ├── __init__.py
    ├── chamfer3D
    │   ├── chamfer3D.cu
    │   ├── chamfer_cuda.cpp
    │   ├── dist_chamfer_3D.py
    │   └── setup.py
    ├── generate_gt.py
    ├── generate_planes.py
    ├── random_color.py
    └── simple_loader.py
/.gitignore:
--------------------------------------------------------------------------------
1 | **/__pycache__
2 | data
3 | logs*
4 | results
5 | planes_9
--------------------------------------------------------------------------------
/.gitmodules:
--------------------------------------------------------------------------------
1 | [submodule "recon/third_party/omnidata"]
2 | path = recon/third_party/omnidata
3 | url = https://github.com/hyz317/omnidata
4 | [submodule "recon/third_party/metric3d"]
5 | path = recon/third_party/metric3d
6 | url = https://github.com/hyz317/Metric3D
7 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | 
2 |
3 | # AlphaTablets
4 |
5 |
6 | 📃 Paper • 🌐 Project Page
7 |
8 |
9 | > **AlphaTablets: A Generic Plane Representation for 3D Planar Reconstruction from Monocular Videos**
10 | >
11 | > [Yuze He](https://hyzcluster.github.io), [Wang Zhao](https://github.com/thuzhaowang), [Shaohui Liu](http://b1ueber2y.me/), [Yubin Hu](https://github.com/AlbertHuyb), [Yushi Bai](https://bys0318.github.io/), Yu-Hui Wen, Yong-Jin Liu
12 | >
13 | > NeurIPS 2024
14 |
15 | **AlphaTablets** is a novel and generic representation of 3D planes that features continuous 3D surface and precise boundary delineation. By representing 3D planes as rectangles with alpha channels, AlphaTablets combine the advantages of current 2D and 3D plane representations, enabling accurate, consistent and flexible modeling of 3D planes.
16 |
17 | We propose a novel bottom-up pipeline for 3D planar reconstruction from monocular videos. Starting with 2D superpixels and geometric cues from pre-trained models, we initialize 3D planes as AlphaTablets and optimize them via differentiable rendering. An effective merging scheme is introduced to facilitate the growth and refinement of AlphaTablets. Through iterative optimization and merging, we reconstruct complete and accurate 3D planes with solid surfaces and clear boundaries.
18 |
19 |
20 |
21 |
22 |
23 | ## Quick Start
24 |
25 | ### 1. Clone the Repository
26 |
27 | Make sure to clone the repository along with its submodules:
28 |
29 | ```bash
30 | git clone --recursive https://github.com/THU-LYJ-Lab/AlphaTablets
31 | ```
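If you already cloned without `--recursive`, fetch the submodules afterwards:

```bash
git submodule update --init --recursive
```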
32 |
33 | ### 2. Install Dependencies
34 |
35 | Set up a Python environment and install the required packages:
36 |
37 | ```bash
38 | conda create -n alphatablets python=3.9
39 | conda activate alphatablets
40 |
41 | # Install PyTorch based on your machine configuration
42 | pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118
43 |
44 | # Install other dependencies
45 | # Note: the mmcv package also requires CUDA. To avoid potential errors, set the CUDA_HOME environment variable and install a CUDA-compatible build of the package.
46 | # Example: python -m pip install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu118/torch2.1/index.html
47 | pip install -r requirements.txt
48 | ```
49 |
50 | ### 3. Download Pretrained Weights
51 |
52 | #### Monocular Normal Estimation Weights
53 |
54 | Download **Omnidata** pretrained weights:
55 |
56 | - File: `omnidata_dpt_normal_v2.ckpt`
57 | - Link: [Download Here](https://www.dropbox.com/scl/fo/348s01x0trt0yxb934cwe/h?rlkey=a96g2incso7g53evzamzo0j0y&e=2&dl=0)
58 |
59 | Place the file in the directory:
60 |
61 | ```plaintext
62 | ./recon/third_party/omnidata/omnidata_tools/torch/pretrained_models
63 | ```
64 |
65 | #### Depth Estimation Weights
66 |
67 | Download **Metric3D** pretrained weights:
68 |
69 | - File: `metric_depth_vit_giant2_800k.pth`
70 | - Link: [Download Here](https://huggingface.co/JUGGHM/Metric3D/blob/main/metric_depth_vit_giant2_800k.pth)
71 |
72 | Place the file in the directory:
73 |
74 | ```plaintext
75 | ./recon/third_party/metric3d/weight
76 | ```
77 |
78 | ### 4. Running Demos
79 |
80 | #### ScanNet Demo
81 |
82 | 1. Download the `scene0684_01` demo scene from [here](https://drive.google.com/drive/folders/13rYkek_CQuOk_N5erJL08R26B1BkYmwD?usp=sharing) and extract it to `./data/`.
83 | 2. Run the demo with the following command:
84 |
85 | ```bash
86 | python run.py --job scene0684_01
87 | ```
88 |
89 | #### Replica Demo
90 |
91 | 1. Download the `office0` demo scene from [here](https://drive.google.com/drive/folders/13rYkek_CQuOk_N5erJL08R26B1BkYmwD?usp=sharing) and extract it to `./data/`.
92 | 2. Run the demo using the specified configuration:
93 |
94 | ```bash
95 | python run.py --config configs/replica.yaml --job office0
96 | ```
97 |
98 | ### Tips
99 |
100 | - **Out-of-Memory (OOM):** Reduce `batch_size` if you encounter memory issues.
101 | - **Low Frame Rate Sequences:** Increase `weight_decay`, or set it to `-1` for automatic decay. The default of `0.9` works well for ScanNet and Replica; it can be raised toward, but must not exceed, `1.0`.
102 | - **Scene Scaling Issues:** If the scene scale differs significantly from real-world dimensions, adjust merging parameters such as `dist_thres` (the maximum allowable distance for tablet merging); see the config sketch below.
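A hypothetical sketch of these knobs in a config file. Only the parameter names come from the tips above; their top-level placement is an assumption (the shipped `configs/*.yaml` only show `data`, `init`, and `crop`), so check `configs/default.yaml` and `run.py` for the authoritative schema:

```yaml
# Assumed placement; parameter names taken from the tips above.
batch_size: 2        # lower on OOM
weight_decay: 0.95   # raise toward (but never above) 1.0 for low-frame-rate input; -1 = automatic decay
dist_thres: 0.1      # maximum tablet-merging distance; scale with scene units
```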
103 |
104 |
105 |
106 | ## Evaluation on the ScanNet v2 dataset
107 |
108 | 1. **Download and Extract ScanNet**:
109 | Follow the instructions provided on the [ScanNet website](http://www.scan-net.org/) to download and extract the dataset.
110 |
111 | 2. **Prepare the Data**:
112 | Use the data preparation script to parse the raw ScanNet data into a processed pickle format and generate ground truth planes using code modified from [PlaneRCNN](https://github.com/NVlabs/planercnn/blob/master/data_prep/parse.py) and [PlanarRecon](https://github.com/neu-vi/PlanarRecon/tree/main).
113 |
114 | Run the following command under the PlanarRecon environment:
115 |
116 | ```bash
117 | python tools/generate_gt.py --data_path PATH_TO_SCANNET --save_name planes_9/ --window_size 9 --n_proc 2 --n_gpu 1
118 | python tools/prepare_inst_gt_txt.py --val_list PATH_TO_SCANNET/scannetv2_val.txt --plane_mesh_path ./planes_9
119 | ```
120 |
121 | 3. **Process Scenes in the Validation Set**:
122 | You can use the following command to process each scene in the validation set. Update `scene????_??` with the specific scene name. Train/val/test split information is available [here](https://github.com/ScanNet/ScanNet/tree/master/Tasks/Benchmark):
123 |
124 | ```bash
125 | python run.py --job scene????_?? --input_dir PATH_TO_SCANNET/scene????_??
126 | ```
127 |
128 | 4. **Run the Test Script**:
129 | Finally, execute the test script to evaluate the processed data:
130 |
131 | ```bash
132 | python test.py
133 | ```
134 |
135 |
136 |
137 | ## Citation
138 |
139 | If you find our work useful, please kindly cite:
140 |
141 | ```
142 | @article{he2024alphatablets,
143 | title={AlphaTablets: A Generic Plane Representation for 3D Planar Reconstruction from Monocular Videos},
144 | author={Yuze He and Wang Zhao and Shaohui Liu and Yubin Hu and Yushi Bai and Yu-Hui Wen and Yong-Jin Liu},
145 | journal={arXiv preprint arXiv:2411.19950},
146 | year={2024}
147 | }
148 | ```
149 |
150 |
151 |
152 | ## Acknowledgements
153 |
154 | Some of the test code and the installation guide in this repo are borrowed from [NeuralRecon](https://github.com/zju3dv/NeuralRecon), [PlanarRecon](https://github.com/neu-vi/PlanarRecon/tree/main) and [ParticleSfM](https://github.com/bytedance/particle-sfm)! We sincerely thank them all.
155 |
--------------------------------------------------------------------------------
/configs/default.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_type: 'scannet'
3 | input_dir: './data/scene0684_01'
4 |
5 | init:
6 | depth_model_type: 'metric3d'
7 | normal_model_type: 'omnidata'
8 |
9 | crop: 20
--------------------------------------------------------------------------------
/configs/replica.yaml:
--------------------------------------------------------------------------------
1 | data:
2 | dataset_type: 'replica'
3 | input_dir: './data/office0'
4 |
5 | init:
6 | depth_model_type: 'metric3d'
7 | normal_model_type: 'omnidata'
8 |
9 | crop: 20
--------------------------------------------------------------------------------
/dataset/dataset.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from torch.utils.data import Dataset
3 |
4 | class CustomDataset(Dataset):
5 | def __init__(self, images, mvp_mtxs, normals, depths):
6 | self.images = images
7 | self.mvp_mtxs = mvp_mtxs
8 | self.normals = normals
9 | self.depths = depths
10 |
11 | def __len__(self):
12 | return len(self.images)
13 |
14 | def __getitem__(self, idx):
15 | img = torch.from_numpy(self.images[idx])
16 | mvp_mtx = torch.from_numpy(self.mvp_mtxs[idx]).float()
17 | normal = torch.from_numpy(self.normals[idx])
18 | depth = torch.from_numpy(self.depths[idx])
19 |
20 | return img, mvp_mtx, normal, depth
21 |
--------------------------------------------------------------------------------
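A minimal usage sketch for `CustomDataset` above, with synthetic numpy stand-ins for the real inputs produced by the loaders in `lib/` (assumes the repo root is on `PYTHONPATH`):

```python
import numpy as np
from torch.utils.data import DataLoader
from dataset.dataset import CustomDataset

# Four synthetic frames: 8x8 RGB images, 4x4 MVP matrices, normals, depths.
images = [np.random.rand(8, 8, 3).astype(np.float32) for _ in range(4)]
mvp_mtxs = [np.eye(4, dtype=np.float32) for _ in range(4)]
normals = [np.random.rand(8, 8, 3).astype(np.float32) for _ in range(4)]
depths = [np.random.rand(8, 8).astype(np.float32) for _ in range(4)]

loader = DataLoader(CustomDataset(images, mvp_mtxs, normals, depths), batch_size=2)
for img, mvp, normal, depth in loader:
    print(img.shape, mvp.shape, normal.shape, depth.shape)
```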
/geom/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-LYJ-Lab/AlphaTablets/735cbfe6aa7f03f7bc37f48303045fe71ffa042a/geom/__init__.py
--------------------------------------------------------------------------------
/geom/plane_utils.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 |
4 | from sklearn.neighbors import KDTree
5 | from tqdm import trange
6 |
7 |
8 | class DisjointSet:
9 | def __init__(self, size, params, mean_colors):
10 |         self.parent = list(range(size))
11 |         self.rank = [0] * size
12 |         self.accum_normal = list(params[:, :3].copy())
13 |         self.accum_center = list(params[:, 3:].copy())
14 |         self.accum_color = list(mean_colors.copy())
15 | self.accum_num = [1] * size
16 |
17 | def find(self, x):
18 | if self.parent[x] != x:
19 | self.parent[x] = self.find(self.parent[x])
20 | return self.parent[x]
21 |
22 | def union(self, x, y, normal_thres=0.9, dis_thres=0.1, color_thres=1.):
23 | root_x = self.find(x)
24 | root_y = self.find(y)
25 | if root_x != root_y and np.abs(
26 | np.sum(self.accum_normal[root_x] * self.accum_normal[root_y]) /
27 | np.linalg.norm(self.accum_normal[root_x]) /
28 | np.linalg.norm(self.accum_normal[root_y])
29 | ) > normal_thres and np.linalg.norm(
30 | (self.accum_center[root_x] / self.accum_num[root_x] -
31 | self.accum_center[root_y] / self.accum_num[root_y])
32 | * self.accum_normal[root_x] / np.linalg.norm(self.accum_normal[root_x])
33 | ) < dis_thres and np.linalg.norm(
34 | self.accum_color[root_x] / self.accum_num[root_x] -
35 | self.accum_color[root_y] / self.accum_num[root_y]
36 | ) < color_thres:
37 | if self.rank[root_x] < self.rank[root_y]:
38 | self.parent[root_x] = root_y
39 | self.accum_normal[root_y] += self.accum_normal[root_x]
40 | self.accum_center[root_y] += self.accum_center[root_x]
41 | self.accum_color[root_y] += self.accum_color[root_x]
42 | self.accum_num[root_y] += self.accum_num[root_x]
43 | elif self.rank[root_x] > self.rank[root_y]:
44 | self.parent[root_y] = root_x
45 | self.accum_normal[root_x] += self.accum_normal[root_y]
46 | self.accum_center[root_x] += self.accum_center[root_y]
47 | self.accum_color[root_x] += self.accum_color[root_y]
48 | self.accum_num[root_x] += self.accum_num[root_y]
49 | else:
50 | self.parent[root_x] = root_y
51 | self.rank[root_y] += 1
52 | self.accum_normal[root_y] += self.accum_normal[root_x]
53 | self.accum_center[root_y] += self.accum_center[root_x]
54 | self.accum_color[root_y] += self.accum_color[root_x]
55 | self.accum_num[root_y] += self.accum_num[root_x]
56 |
57 |
58 | def calc_plane(K, pose, depth, dis, mask, x_range_uv, y_range_uv, plane_normal=None):
59 |     """Calculate world-space plane parameters from the camera pose and a masked pixel window.
60 |     Args:
61 |         K: camera intrinsic matrix, [3, 3]
62 |         pose: camera pose, world-to-cam, [4, 4]
63 |         depth: depth map of the pixel window, [H, W]; dis: pixel plane distance to camera, scalar
64 |         mask: plane pixel mask, [H, W]
65 |         x_range_uv: pixel x range, [2]; y_range_uv: pixel y range, [2]
66 |         plane_normal: plane normal vector, [3], optional
67 |     Returns:
68 |         plane: plane normal and center, [6]
69 |         plane_up_vec: plane up vector, [3]
70 |         resol: pixel plane resolution, [2]
71 |         xy_min, xy_max: corners of the recentered pixel range, [2] each
72 |         (returns None instead if the fitted plane is too oblique)
73 |     """
74 |
75 | # create a pixel grid in the image plane
76 | u = np.linspace(x_range_uv[0], x_range_uv[1], int(x_range_uv[1]-x_range_uv[0]+1))
77 | v = np.linspace(y_range_uv[0], y_range_uv[1], int(y_range_uv[1]-y_range_uv[0]+1))
78 | U, V = np.meshgrid(u, v)
79 |
80 | # back project the pixel grid to 3D space
81 | X = depth * (U + 0.5 - K[0, 2]) / K[0, 0]
82 | Y = -depth * (V + 0.5 - K[1, 2]) / K[1, 1]
83 | Z = -depth * np.ones(U.shape)
84 |
85 | # transform the points from camera coordinate to world coordinate
86 | points = np.stack((X, Y, Z, np.ones(U.shape)), axis=-1)
87 | points_world = np.matmul(points, np.linalg.inv(pose).T)
88 | points_world = points_world[:, :, :3] / points_world[:, :, 3:]
89 |
90 | # use PCA to fit a plane to these points
91 | points_world_flat = points_world[mask]
92 | mean = np.mean(points_world_flat, axis=0)
93 | points_world_zero_centroid = points_world_flat - mean
94 |
95 | if plane_normal is not None:
96 | plane_normal = np.matmul(plane_normal, np.linalg.inv(pose[:3, :3]).T)
97 | plane_normal = plane_normal / np.linalg.norm(plane_normal)
98 |
99 | else:
100 | if len(points_world_zero_centroid) < 10000:
101 | _, _, v = np.linalg.svd(points_world_zero_centroid)
102 | else:
103 | import random
104 | _, _, v = np.linalg.svd(random.sample(list(points_world_zero_centroid), 10000))
105 |
106 | # plane parameters
107 | plane_normal = v[-1, :] / np.linalg.norm(v[-1, :])
108 |         if np.abs(plane_normal).max() != plane_normal.max():  # flip so the dominant-magnitude component is positive
109 | plane_normal = -plane_normal
110 |
111 | plane = np.concatenate((plane_normal, mean))
112 |
113 | resol = np.array([K[0,0]/dis, K[1,1]/dis])
114 |
115 | # calculate plane_up_vector according to the plane_normal
116 | pose_up_vec = np.array([0, 0, 1])
117 | if np.abs(np.sum(plane_normal * pose_up_vec)) > 0.8:
118 | pose_up_vec = np.array([0, 1, 0])
119 | plane_up_vec = pose_up_vec - np.sum(plane_normal * pose_up_vec) * plane_normal
120 | plane_up_vec = plane_up_vec / np.linalg.norm(plane_up_vec)
121 |
122 | U, V = U[mask], V[mask]
123 |
124 | plane_right = np.cross(plane_normal, plane_up_vec)
125 |
126 |     new_x_argmin = np.matmul(points_world_flat, plane_right).argmin()
127 |     new_x_argmax = np.matmul(points_world_flat, plane_right).argmax()
128 |     new_y_argmin = np.matmul(points_world_flat, plane_up_vec).argmin()
129 |     new_y_argmax = np.matmul(points_world_flat, plane_up_vec).argmax()
130 |
131 | new_x_width = np.sqrt((U[new_x_argmin] - U[new_x_argmax])**2 + (V[new_x_argmin] - V[new_x_argmax])**2)
132 | new_y_width = np.sqrt((U[new_y_argmin] - U[new_y_argmax])**2 + (V[new_y_argmin] - V[new_y_argmax])**2)
133 |
134 | new_x_range_uv = [-new_x_width/2 - 2, new_x_width/2 + 2]
135 | new_y_range_uv = [-new_y_width/2 - 2, new_y_width/2 + 2]
136 |
137 | y_range_3d = np.matmul(points_world_flat, plane_up_vec).max() - np.matmul(points_world_flat, plane_up_vec).min()
138 | x_range_3d = np.matmul(points_world_flat, plane_right).max() - np.matmul(points_world_flat, plane_right).min()
139 |
140 | resol = np.array([
141 | new_x_width / (x_range_3d + 1e-6),
142 | new_y_width / (y_range_3d + 1e-6)
143 | ])
144 |
145 | if resol[0] < K[0,0]/dis / 10 or resol[1] < K[1,1]/dis / 10:
146 | return None
147 |
148 | xy_min = np.array([new_x_range_uv[0], new_y_range_uv[0]])
149 | xy_max = np.array([new_x_range_uv[1], new_y_range_uv[1]])
150 |
151 | return plane, plane_up_vec, resol, xy_min, xy_max
152 |
153 |
154 | def ray_plane_intersect(plane, ray_origin, ray_direction):
155 | """Calculate the intersection of a ray and a plane.
156 | Args:
157 | plane: plane parameters, [4]
158 | ray_origin: ray origin, [3]
159 | ray_direction: ray direction, [3]
160 |
161 | Returns:
162 | intersection: intersection point, [3]
163 | """
164 |
165 | # calculate intersection
166 | t = -(plane[3] + np.dot(plane[:3], ray_origin)) / np.dot(plane[:3], ray_direction)
167 | intersection = ray_origin + t * ray_direction
168 |
169 | return intersection
170 |
171 |
172 | def points_xyz_to_plane_uv(points, plane, resol, plane_up):
173 | """Project 3D points to the pixel plane.
174 | Args:
175 | points: 3D points, [N, 3]
176 | plane: plane parameters, [4]
177 | resol: pixel plane resolution, [2]
178 | plane_up: plane up vector, [3]
179 | Returns:
180 | uv: pixel plane coordinates, [N, 2]
181 | """
182 |
183 | # plane normal vector
184 | plane_normal = np.asarray(plane[:3])
185 | # projection points of 'points' on the plane
186 | points_proj = points - np.outer( np.sum(points*plane_normal, axis=1)+plane[3] , plane_normal) / np.linalg.norm(plane_normal)
187 | mean_proj = np.mean(points_proj, axis=0)
188 | uvw_right = np.cross(plane_normal, plane_up)
189 | uvw_up = plane_up
190 | # calculate the uv coordinates
191 | uv = np.c_[np.sum((points_proj-mean_proj) * uvw_right, axis=-1), np.sum((points_proj-mean_proj) * uvw_up, axis=-1)]
192 | uv = uv * resol
193 | return uv
194 |
195 |
196 | def points_xyz_to_plane_uv_torch(points, plane, resol, plane_up):
197 | """Project 3D points to the pixel plane.
198 | Args:
199 | points: 3D points, [N, 3]
200 | plane: plane parameters, [4]
201 | resol: pixel plane resolution, [2]
202 | plane_up: plane up vector, [3]
203 | Returns:
204 | uv: pixel plane coordinates, [N, 2]
205 | """
206 |
207 | # plane normal vector
208 | plane_normal = plane[:3]
209 | # projection points of 'points' on the plane
210 | points_proj = points - torch.outer( torch.sum(points*plane_normal, dim=1)+plane[3] , plane_normal) / torch.norm(plane_normal)
211 | mean_proj = torch.mean(points_proj, dim=0)
212 | uvw_right = torch.cross(plane_normal, plane_up)
213 | uvw_up = plane_up
214 | # calculate the uv coordinates
215 | uv = torch.cat([torch.sum((points_proj-mean_proj) * uvw_right, dim=-1, keepdim=True), torch.sum((points_proj-mean_proj) * uvw_up, dim=-1, keepdim=True)], dim=-1)
216 | uv = uv * resol
217 | return uv
218 |
219 |
220 | def simple_distribute_planes_2D(uv_ranges, gap=2, min_H=8192):
221 | """Distribute pixel planes in 2D space.
222 | Args:
223 | uv_ranges: pixel ranges of each plane, [N, 2]
224 | Returns:
225 | plane_leftup: left-up pixel of each plane, [N, 2]
226 | """
227 |
228 | # calculate the left-up pixel of each plane
229 | plane_leftup = torch.zeros_like(uv_ranges)
230 |
231 | H = (int(torch.max(uv_ranges[:, 1]).item()) + 2) // gap * gap + gap
232 | H = max(H, min_H)
233 |
234 | # sort the planes by the height
235 | _, sort_idx = torch.sort(uv_ranges[:, 1], descending=True)
236 |
237 | # distribute the planes
238 | idx = 0
239 | now_H = 0
240 | now_W = 0
241 | prev_W = 0
242 | while idx < len(sort_idx):
243 | now_leftup = torch.tensor((prev_W+1, now_H+1))
244 | plane_leftup[sort_idx[idx]] = now_leftup
245 | now_W = max(now_W, uv_ranges[sort_idx[idx], 0] + 1 + prev_W)
246 | now_H += uv_ranges[sort_idx[idx], 1] + 1
247 |
248 | if idx + 1 < len(sort_idx) and now_H + uv_ranges[sort_idx[idx+1], 1] + 1 > H:
249 | prev_W = now_W
250 | now_H = 0
251 |
252 | idx += 1
253 |
254 | W = (int(now_W.item()) + 2) // gap * gap + gap
255 |
256 | return plane_leftup, W, H
257 |
258 |
259 | def cluster_planes(planes, K=2, thres=0.999,
260 | color_thres_1=0.3, color_thres_2=0.2,
261 | dis_thres_1=0.5, dis_thres_2=0.1,
262 | merge_edge_planes=False,
263 | init_plane_sets=None):
264 | # planes[:, :3]: plane normal vector
265 | # planes[:, 3:6]: plane center
266 |
267 | # create disjoint set
268 | ds = DisjointSet(len(planes), planes[:, :6], planes[:, 17:20])
269 |
270 | if init_plane_sets is not None:
271 | for plane_set in init_plane_sets:
272 | for i in range(len(plane_set)-1):
273 |                 ds.union(plane_set[i], plane_set[i+1], normal_thres=0.99, dis_thres=1, color_thres=114514)  # huge color_thres effectively disables the color check
274 |
275 | # construct kd-tree for plane center
276 | print('clustering tablets ...')
277 | tree = KDTree(planes[:, 3:6])
278 |
279 | # find the nearest K neighbors for each plane
280 | _, ind = tree.query(planes[:, 3:6], k=K+1)
281 |
282 | # calculate the angle between each plane and its neighbors
283 | neighbor_normals = planes[ind[:, 1:], :3]
284 | plane_normals = planes[:, :3]
285 | cos = np.sum(neighbor_normals * plane_normals[:, None, :], axis=-1)
286 |
287 |     # first pass: merge near-parallel neighbors under looser distance/color thresholds
288 | for i in trange(len(planes)):
289 | if not merge_edge_planes and planes[i, 15] == True:
290 | continue
291 |
292 | for j in range(K):
293 | if not merge_edge_planes and planes[ind[i, j+1], 15] == True:
294 | continue
295 | if cos[i, j] > thres:
296 | ds.union(i, ind[i, j+1], normal_thres=0.99, dis_thres=dis_thres_1, color_thres=color_thres_1)
297 |
298 |     # second pass: relax the normal threshold but tighten the distance/color thresholds
299 | for i in trange(len(planes)):
300 | if not merge_edge_planes and planes[i, 15] == True:
301 | continue
302 |
303 | for j in range(K):
304 | if not merge_edge_planes and planes[ind[i, j+1], 15] == True:
305 | continue
306 | if cos[i, j] > thres:
307 | ds.union(i, ind[i, j+1], normal_thres=0.9, dis_thres=dis_thres_2, color_thres=color_thres_2)
308 |
309 | root2idx = {}
310 | for i in range(len(planes)):
311 | root = ds.find(i)
312 | if root not in root2idx:
313 | root2idx[root] = []
314 | root2idx[root].append(i)
315 |
316 | plane_sets = []
317 | for root in root2idx:
318 | plane_sets.append(root2idx[root])
319 |
320 | return plane_sets
321 |
322 |
--------------------------------------------------------------------------------
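A worked sketch of the two small geometry helpers above. Values are hand-picked for easy checking; note these helpers take a 4-vector plane `[n, d]` with `n·x + d = 0`, unlike `calc_plane`, which returns normal plus center:

```python
import numpy as np
from geom.plane_utils import ray_plane_intersect, points_xyz_to_plane_uv

# Plane z = 2 in n.x + d = 0 form: n = (0, 0, 1), d = -2.
plane = np.array([0.0, 0.0, 1.0, -2.0])

# A ray from the origin along +z hits it at (0, 0, 2).
hit = ray_plane_intersect(plane, np.array([0.0, 0.0, 0.0]), np.array([0.0, 0.0, 1.0]))
print(hit)  # [0. 0. 2.]

# Project two points onto the ground plane z = 0; uv is measured from the
# mean of the projected points, scaled by resol.
ground = np.array([0.0, 0.0, 1.0, 0.0])
pts = np.array([[1.0, 2.0, 5.0], [3.0, 4.0, 7.0]])
uv = points_xyz_to_plane_uv(pts, ground, np.array([1.0, 1.0]), np.array([0.0, 1.0, 0.0]))
print(uv)  # [[ 1. -1.] [-1.  1.]]
```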
/lib/keyframe.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import glob
3 | import os
4 | from datetime import datetime
5 |
6 |
7 | def get_keyframes(folder, min_angle=15, min_distance=0.2, window_size=9,
8 | min_mean=0.2, max_mean=10):
9 | txt_list = sorted(glob.glob(f'{folder}/pose/*.txt'), key=lambda x: int(x.split('/')[-1].split('.')[0]))
10 | if len(txt_list) != 0:
11 | last_pose = np.loadtxt(txt_list[0])
12 | image_skip = np.ceil(len(txt_list) / 2000)
13 | txt_list = txt_list[::int(image_skip)]
14 | pose_num = len(txt_list)
15 | else:
16 | extrs = np.loadtxt(os.path.join(folder, 'traj.txt')).reshape(-1, 4, 4)
17 | last_pose = extrs[0]
18 | image_skip = np.ceil(len(extrs) / 2000)
19 | extrs = extrs[::int(image_skip)]
20 | pose_num = len(extrs)
21 |
22 | count = 1
23 | all_ids = []
24 |
25 | if len(txt_list) != 0:
26 | depth_list = [ pname.replace('pose', 'aligned_dense_depths').replace('.txt', '.npy') for pname in txt_list ]
27 | else:
28 | depth_list = sorted(glob.glob(f'{folder}/aligned_dense_depths/*.npy'))
29 | for i, j in zip(txt_list, depth_list):
30 | if int(i.split('/')[-1].split('.')[0]) != int(j.split('/')[-1].split('.')[0]):
31 | print(i, j)
32 | raise ValueError('pose and depth not match')
33 | depth_list = depth_list[::int(image_skip)]
34 |
35 | depth_list = np.array([ np.load(i).mean() for i in depth_list ])
36 | id_list = np.linspace(0, len(depth_list)-1, pose_num).astype(int)[::int(image_skip)]
37 | id_list = id_list[np.logical_and(depth_list > min_mean, depth_list < max_mean)]
38 | depth_list = depth_list[np.logical_and(depth_list > min_mean, depth_list < max_mean)]
39 |
40 | from scipy.signal import medfilt
41 | filtered_depth_list = medfilt(depth_list, kernel_size=29)
42 |
43 |     # drop frames whose mean depth deviates more than 15% from the median-filtered sequence
44 | id_list = id_list[
45 | np.logical_and(
46 | depth_list > filtered_depth_list * 0.85,
47 | depth_list < filtered_depth_list * 1.15
48 | )
49 | ]
50 |
51 | depth_list = depth_list[
52 | np.logical_and(
53 | depth_list > filtered_depth_list * 0.85,
54 | depth_list < filtered_depth_list * 1.15
55 | )
56 | ]
57 |
58 | print(str(datetime.now()) + ': \033[92mI', 'filtered out depth_list', len(id_list), '/', pose_num, '\033[0m')
59 |
60 | ids = [id_list[0]]
61 |
62 | for idx_pos, idx in enumerate(id_list[1:]):
63 | if len(txt_list) != 0:
64 | i = txt_list[idx]
65 |             with open(i, 'r') as f:
66 |                 cam_pose = np.loadtxt(f)
67 | else:
68 | cam_pose = extrs[idx]
69 | angle = np.arccos(
70 | ((np.linalg.inv(cam_pose[:3, :3]) @ last_pose[:3, :3] @ np.array([0, 0, 1]).T) * np.array(
71 | [0, 0, 1])).sum())
72 | dis = np.linalg.norm(cam_pose[:3, 3] - last_pose[:3, 3])
73 | if angle > (min_angle / 180) * np.pi or dis > min_distance:
74 | ids.append(idx)
75 | last_pose = cam_pose
76 |             # count this frame toward the current keyframe window
77 | count += 1
78 | if count == window_size:
79 | ids = [i * int(image_skip) for i in ids]
80 | all_ids.append(ids)
81 | ids = []
82 | count = 0
83 |
84 | if len(ids) > 2:
85 | ids = [i * int(image_skip) for i in ids]
86 | all_ids.append(ids)
87 | else:
88 | ids = [i * int(image_skip) for i in ids]
89 | all_ids[-1].extend(ids)
90 |
91 | return all_ids, int(image_skip)
92 |
93 |
94 | if __name__ == '__main__':
95 | folder = '/data0/bys/Hex/PixelPlane/data/scene0709_01'
96 | keyframes, image_skip = get_keyframes(folder)
97 | print(keyframes, image_skip)
98 |
--------------------------------------------------------------------------------
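The selection rule above keeps a frame once the viewing direction rotates more than `min_angle` degrees or the camera translates more than `min_distance` since the last kept pose. A standalone sketch of just that test (synthetic 4x4 poses; `is_keyframe` is a name introduced here, not part of the repo):

```python
import numpy as np

def is_keyframe(cam_pose, last_pose, min_angle=15, min_distance=0.2):
    # Same criterion as get_keyframes: angle between the two camera z-axes,
    # plus Euclidean distance between the camera centers.
    z = np.array([0.0, 0.0, 1.0])
    cos = ((np.linalg.inv(cam_pose[:3, :3]) @ last_pose[:3, :3] @ z) * z).sum()
    angle = np.arccos(np.clip(cos, -1.0, 1.0))
    dis = np.linalg.norm(cam_pose[:3, 3] - last_pose[:3, 3])
    return angle > (min_angle / 180) * np.pi or dis > min_distance

last = np.eye(4)
moved = np.eye(4)
moved[:3, 3] = [0.5, 0.0, 0.0]   # 0.5 units of translation exceeds min_distance
print(is_keyframe(moved, last))  # True
print(is_keyframe(last, last))   # False
```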
/lib/load_colmap.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import os
3 | import cv2
4 | import glob
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | from recon.utils import load_sfm_pose, load_scannet_pose
9 |
10 | from geom.plane_utils import calc_plane
11 |
12 | import torch
13 | import torch.nn.functional as F
14 |
15 |
16 | def dilate(img):
17 | kernel = torch.ones(3, 3).cuda().double()
18 | img = torch.tensor(img.astype('int32')).cuda().double()
19 |
20 | while True:
21 | mask = img > 0
22 | if torch.logical_not(mask).sum() == 0:
23 | break
24 | mask = mask.double()
25 |
26 | mask_dilated = F.conv2d(mask[None, None], kernel[None, None], padding=1).squeeze(0).squeeze(0)
27 |
28 | img_dilated = F.conv2d(img[None, None], kernel[None, None], padding=1).squeeze(0).squeeze(0)
29 | img_dilated = torch.where(mask_dilated == 0, img_dilated, img_dilated / mask_dilated)
30 |
31 | img = torch.where(mask == 0, img_dilated, img)
32 |
33 | img = img.cpu().numpy()
34 |
35 | return img
36 |
37 |
38 | def parse_sp(sp, intr, pose, depth, normal, img_sp, edge_thres=0.95, crop=20):
39 | """Parse the superpixel segmentation into planes
40 | """
41 | sp = torch.from_numpy(sp).cuda()
42 | depth = torch.from_numpy(depth).cuda()
43 | normal = (torch.from_numpy(normal).cuda() - 0.5) * 2
44 | normal = normal.permute(1, 2, 0)
45 | img_sp = torch.from_numpy(img_sp).cuda()
46 | sp = sp.int()
47 | sp_ids = torch.unique(sp)
48 | planes = []
49 | mean_colors = []
50 |
51 | # mask crop to -1
52 | sp[:crop] = -1
53 | sp[-crop:] = -1
54 | sp[:, :crop] = -1
55 | sp[:, -crop:] = -1
56 |
57 | new_sp = torch.ones_like(sp).cuda() * -1
58 |
59 | sp_id_cnt = 0
60 |
61 | for sp_id in tqdm(sp_ids):
62 | sp_mask = sp == sp_id
63 |
64 | if sp_mask.sum() <= 15:
65 | continue
66 |
67 | new_sp[sp_mask] = sp_id_cnt
68 | mean_color = torch.mean(img_sp[sp_mask], dim=0)
69 |
70 | sp_dis = depth[sp_mask]
71 | sp_dis = torch.median(sp_dis)
72 | if sp_dis < 0:
73 | continue
74 | sp_normal = normal[sp_mask]
75 |
76 | sp_coeff = torch.einsum('ij,kj->ik', sp_normal, sp_normal)
77 | if sp_coeff.min() < edge_thres:
78 | is_edge_plane = True
79 | else:
80 | is_edge_plane = False
81 |
82 | sp_normal = torch.mean(sp_normal, dim=0)
83 | sp_normal = sp_normal / torch.norm(sp_normal)
84 | sp_normal[1] = -sp_normal[1]
85 | sp_normal[2] = -sp_normal[2]
86 |
87 | x_accum = torch.sum(sp_mask, dim=0)
88 | y_accum = torch.sum(sp_mask, dim=1)
89 | x_range_uv = [torch.min(torch.nonzero(x_accum)).item(), torch.max(torch.nonzero(x_accum)).item()]
90 | y_range_uv = [torch.min(torch.nonzero(y_accum)).item(), torch.max(torch.nonzero(y_accum)).item()]
91 |
92 | sp_dis = sp_dis.cpu().numpy()
93 | sp_depth = depth[y_range_uv[0]:y_range_uv[1]+1, x_range_uv[0]:x_range_uv[1]+1].cpu().numpy()
94 | sp_mask = sp_mask[y_range_uv[0]:y_range_uv[1]+1, x_range_uv[0]:x_range_uv[1]+1].cpu().numpy()
95 |
96 | ret = calc_plane(intr, pose, sp_depth, sp_dis, sp_mask, x_range_uv, y_range_uv, plane_normal=sp_normal.cpu().numpy())
97 | if ret is None:
98 | continue
99 | plane, plane_up, resol, new_x_range_uv, new_y_range_uv = ret
100 |
101 | plane = np.concatenate([plane, plane_up, resol, new_x_range_uv, new_y_range_uv, [is_edge_plane]])
102 | planes.append(plane)
103 | mean_colors.append(mean_color.cpu().numpy())
104 |
105 | sp_id_cnt += 1
106 |
107 | return planes, new_sp.cpu().numpy(), mean_colors
108 |
109 |
110 | def get_projection_matrix(fovy: float, aspect_wh: float, near: float, far: float):
111 | proj_mtx = np.zeros((4, 4), dtype=np.float32)
112 | proj_mtx[0, 0] = 1.0 / (np.tan(fovy / 2.0) * aspect_wh)
113 | proj_mtx[1, 1] = -1.0 / np.tan(
114 | fovy / 2.0
115 | ) # add a negative sign here as the y axis is flipped in nvdiffrast output
116 | proj_mtx[2, 2] = -(far + near) / (far - near)
117 | proj_mtx[2, 3] = -2.0 * far * near / (far - near)
118 | proj_mtx[3, 2] = -1.0
119 | return proj_mtx
120 |
121 |
122 | def get_mvp_matrix(c2w, proj_mtx):
123 |     # calculate w2c from c2w: R' = R^T, t' = -R^T @ t
124 | # mathematically equivalent to (c2w)^-1
125 | w2c = np.zeros((c2w.shape[0], 4, 4))
126 | w2c[:, :3, :3] = np.transpose(c2w[:, :3, :3], (0, 2, 1))
127 | w2c[:, :3, 3:] = np.transpose(-c2w[:, :3, :3], (0, 2, 1)) @ c2w[:, :3, 3:]
128 | w2c[:, 3, 3] = 1.0
129 |
130 | # calculate mvp matrix by proj_mtx @ w2c (mv_mtx)
131 | mvp_mtx = proj_mtx @ w2c
132 |
133 | return mvp_mtx
134 |
135 |
136 | def load_colmap_data(args, kf_list=[], image_skip=1, load_plane=True, scannet_pose=True):
137 | imgs = []
138 | image_dir = os.path.join(args.input_dir, "images")
139 | start = kf_list[0]
140 | end = kf_list[-1] + 1 if kf_list[-1] != -1 else len(glob.glob(os.path.join(image_dir, "*")))
141 | image_names = sorted(glob.glob(os.path.join(image_dir, "*")), key=lambda x: int(x.split('/')[-1][:-4]))
142 |
143 | image_names_sp = [image_names[i] for i in kf_list]
144 | image_names = image_names[start:end:image_skip]
145 |
146 | print(str(datetime.now()) + ': \033[92mI', 'loading images ...', '\033[0m')
147 | for name in tqdm(image_names):
148 | img = cv2.imread(name)
149 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
150 | imgs.append(img.astype(np.float32) / 255.)
151 |
152 | imgs_sp = []
153 | for name in image_names_sp:
154 | img = cv2.imread(name)
155 |         # convert to HSV color space using cv2
156 | img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
157 | imgs_sp.append(img.astype(np.float32) / 255.)
158 |
159 | imgs_sp = np.stack(imgs_sp, 0)
160 |
161 | if scannet_pose:
162 | print(str(datetime.now()) + ': \033[92mI', 'loading scannet poses ...', '\033[0m')
163 | intr_scannet, sfm_poses_scannet, sfm_camprojs_scannet = load_scannet_pose(args.input_dir)
164 | intr = intr_scannet
165 | sfm_poses, sfm_camprojs = sfm_poses_scannet, sfm_camprojs_scannet
166 | else:
167 | print(str(datetime.now()) + ': \033[92mI', 'loading colmap poses ...', '\033[0m')
168 | intr, sfm_poses, sfm_camprojs = load_sfm_pose(os.path.join(args.input_dir, "sfm"))
169 |
170 |
171 | sfm_poses_sp = np.array(sfm_poses)[kf_list]
172 | sfm_poses = np.array(sfm_poses)[start:end:image_skip]
173 | sfm_camprojs = np.array(sfm_camprojs)[start:end:image_skip]
174 |
175 | normal_dir = os.path.join(args.input_dir, "omnidata_normal")
176 | normal_names = sorted(glob.glob(os.path.join(normal_dir, "*")), key=lambda x: int(x.split('/')[-1][:-4]))
177 | depth_names = [ iname.replace('omnidata_normal', 'aligned_dense_depths') for iname in normal_names ]
178 | depth_names = [depth_names[i] for i in kf_list]
179 |
180 | print(str(datetime.now()) + ': \033[92mI', 'loading depths ...', '\033[0m')
181 | depths = []
182 | for name in tqdm(depth_names):
183 | depth = np.load(name)
184 | depths.append(depth)
185 |
186 | depths = np.stack(depths, 0)
187 | aligned_depth_dir = os.path.join(args.input_dir, "aligned_dense_depths")
188 | all_depth_names = sorted(glob.glob(os.path.join(aligned_depth_dir, "*")), key=lambda x: int(x.split('/')[-1][:-4]))[start:end:image_skip]
189 | all_depths = []
190 |
191 | for name, pose in tqdm(zip(all_depth_names, sfm_poses)):
192 | depth = np.load(name)
193 | all_depths.append(depth)
194 |
195 | if args.init.normal_model_type == 'omnidata':
196 | normal_dir = os.path.join(args.input_dir, "omnidata_normal")
197 | else:
198 | raise NotImplementedError(f'Unknown normal model type {args.init.normal_model_type} exiting')
199 |
200 | normal_names = sorted(glob.glob(os.path.join(normal_dir, "*")), key=lambda x: int(x.split('/')[-1][:-4]))
201 | normal_names = [normal_names[i] for i in kf_list]
202 |
203 | print(str(datetime.now()) + ': \033[92mI', 'loading normals ...', '\033[0m')
204 | normals = []
205 | for name in tqdm(normal_names):
206 | normal = np.load(name)[:3]
207 | normals.append(normal)
208 |
209 | normals = np.stack(normals, 0)
210 |
211 | all_normal_names = sorted(glob.glob(os.path.join(normal_dir, "*")), key=lambda x: int(x.split('/')[-1][:-4]))[start:end:image_skip]
212 | all_normals = []
213 |
214 | for name, pose in tqdm(zip(all_normal_names, sfm_poses)):
215 | normal = np.load(name)
216 | normal = normal[:3]
217 |
218 | normal = torch.from_numpy(normal).cuda()
219 | normal = (normal - 0.5) * 2
220 | normal = normal / torch.norm(normal, dim=0, keepdim=True)
221 | normal = normal.permute(1, 2, 0)
222 | normal[:, :, 1] = -normal[:, :, 1]
223 | normal[:, :, 2] = -normal[:, :, 2]
224 | normal = torch.einsum('ijk,kl->ijl', normal, torch.from_numpy(np.linalg.inv(pose[:3, :3]).T).cuda().float())
225 | normal = F.interpolate(normal.permute(2, 0, 1).unsqueeze(0),
226 | (int(imgs[0].shape[0]/64)*8, int(imgs[0].shape[1]/64)*8),
227 | mode='bilinear', align_corners=False).squeeze(0).permute(1, 2, 0)
228 | all_normals.append(normal.cpu().numpy())
229 |
230 | sp_dir = os.path.join(args.input_dir, "sp")
231 | sp_names = [os.path.join(sp_dir, os.path.split(name)[-1].replace('.jpg', '.npy').replace('.png', '.npy')) for name in image_names_sp]
232 |
233 | if load_plane:
234 | print(str(datetime.now()) + ': \033[92mI', 'loading superpixels ...', '\033[0m')
235 | planes_all = []
236 | new_sp_all = []
237 | mean_colors_all = []
238 | plane_index_start = 0
239 | for plane_idx, (sp_name, depth, normal, pose, img_sp) in tqdm(enumerate(zip(sp_names, depths, normals, sfm_poses_sp, imgs_sp))):
240 | sp = np.load(sp_name)
241 | planes, new_sp, mean_colors = parse_sp(sp, intr, pose, depth, normal, img_sp)
242 | planes = np.array(planes)
243 | planes_idx = np.ones((planes.shape[0], 1)) * plane_idx
244 | planes = np.concatenate([planes, planes_idx], 1)
245 | planes_all.extend(planes)
246 | new_sp_all.append(new_sp + plane_index_start)
247 | plane_index_start += planes.shape[0]
248 | mean_colors = np.array(mean_colors)
249 | mean_colors_all.extend(mean_colors)
250 |
251 | planes_all = np.array(planes_all)
252 | mean_colors_all = np.array(mean_colors_all)
253 |
254 | else:
255 | planes_all = None
256 | new_sp_all = None
257 | mean_colors_all = None
258 |
259 |     # projection matrices
260 | frames_proj = []
261 | frames_c2w = []
262 | frames_center = []
263 | for i in range(len(sfm_camprojs)):
264 | fovy = 2 * np.arctan(0.5 * imgs[0].shape[0] / intr[1, 1])
265 | proj = get_projection_matrix(
266 | fovy, imgs[0].shape[1] / imgs[0].shape[0], 0.1, 1000.0
267 | )
268 | proj = np.array(proj)
269 | frames_proj.append(proj)
270 | # sfm_poses is w2c
271 | c2w = np.linalg.inv(sfm_poses[i])
272 | frames_c2w.append(c2w)
273 | frames_center.append(c2w[:3, 3])
274 |
275 | frames_proj = np.stack(frames_proj, 0)
276 | frames_c2w = np.stack(frames_c2w, 0)
277 | frames_center = np.stack(frames_center, 0)
278 |
279 | mvp_mtxs = get_mvp_matrix(
280 | frames_c2w, frames_proj,
281 | )
282 |
283 | index_init = ((np.array(kf_list) - kf_list[0]) / image_skip).astype(np.int32)
284 |
285 | return imgs, intr, sfm_poses, sfm_camprojs, frames_center, all_depths, all_normals, planes_all, \
286 | mvp_mtxs, index_init, new_sp_all, mean_colors_all
287 |
--------------------------------------------------------------------------------
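A sketch of how `get_projection_matrix` and `get_mvp_matrix` compose, with made-up intrinsics (`load_colmap_data` derives `fovy` the same way from `intr[1, 1]` and the image height):

```python
import numpy as np
from lib.load_colmap import get_projection_matrix, get_mvp_matrix

H, W, fy = 480, 640, 580.0
fovy = 2 * np.arctan(0.5 * H / fy)                # vertical FoV from intrinsics
proj = get_projection_matrix(fovy, W / H, 0.1, 1000.0)

c2w = np.eye(4)[None]                             # one identity camera-to-world pose
mvp = get_mvp_matrix(c2w, proj)                   # proj @ w2c for each pose
print(mvp.shape)                                  # (1, 4, 4)
```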
/lib/load_data.py:
--------------------------------------------------------------------------------
1 | from .load_colmap import load_colmap_data
2 | from .load_replica import load_replica_data
3 |
4 |
5 | def load_data(args, kf_list, image_skip, load_plane):
6 | if args.dataset_type == 'scannet':
7 | images, intr, sfm_poses, sfm_camprojs, cam_centers, \
8 | all_depths, normals, planes_all, mvp_mtxs, index_init, \
9 | new_sps, mean_colors = load_colmap_data(args, kf_list=kf_list,
10 | image_skip=image_skip,
11 | load_plane=load_plane)
12 | print('Loaded scannet', intr, len(images), sfm_poses.shape, sfm_camprojs.shape, args.input_dir)
13 |
14 | data_dict = dict(
15 | poses=sfm_poses, images=images,
16 | intr=intr, sfm_camprojs=sfm_camprojs,
17 | cam_centers=cam_centers,
18 | all_depths=all_depths,
19 | normals=normals, planes_all=planes_all,
20 | mvp_mtxs=mvp_mtxs,
21 | index_init=index_init,
22 | new_sps=new_sps,
23 | mean_colors=mean_colors,
24 | )
25 | return data_dict
26 |
27 |
28 | elif args.dataset_type == 'replica':
29 | images, intr, sfm_poses, sfm_camprojs, cam_centers, \
30 | all_depths, normals, planes_all, mvp_mtxs, index_init, \
31 | new_sps, mean_colors = load_replica_data(args, kf_list=kf_list,
32 | image_skip=image_skip,
33 | load_plane=load_plane)
34 |
35 | print('Loaded replica', intr, len(images), sfm_poses.shape, sfm_camprojs.shape, args.input_dir)
36 |
37 | data_dict = dict(
38 | poses=sfm_poses, images=images,
39 | intr=intr, sfm_camprojs=sfm_camprojs,
40 | cam_centers=cam_centers,
41 | all_depths=all_depths,
42 | normals=normals, planes_all=planes_all,
43 | mvp_mtxs=mvp_mtxs,
44 | index_init=index_init,
45 | new_sps=new_sps,
46 | mean_colors=mean_colors,
47 | )
48 | return data_dict
49 |
50 |
51 | else:
52 | raise NotImplementedError(f'Unknown dataset type {args.dataset_type} exiting')
53 |
54 |
--------------------------------------------------------------------------------
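`load_data` expects an `args` object exposing `dataset_type`, `input_dir`, and `init.normal_model_type` (the attributes the loaders read). A hypothetical stand-in via `SimpleNamespace`; in practice the object comes from the YAML config:

```python
from types import SimpleNamespace
from lib.load_data import load_data

args = SimpleNamespace(
    dataset_type='scannet',
    input_dir='./data/scene0684_01',
    init=SimpleNamespace(normal_model_type='omnidata'),
)
# Requires the preprocessed scene on disk (images/, pose/, aligned_dense_depths/,
# omnidata_normal/, sp/), so the call is shown but not executed here:
# data_dict = load_data(args, kf_list=[0, 10, 20], image_skip=1, load_plane=True)
```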
/lib/load_replica.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import os
3 | import cv2
4 | import glob
5 | import numpy as np
6 | from tqdm import tqdm
7 |
8 | from recon.utils import load_sfm_pose, load_replica_pose
9 |
10 | from geom.plane_utils import calc_plane
11 |
12 | import torch
13 | import torch.nn.functional as F
14 |
15 |
16 | def dilate(img):
17 | kernel = torch.ones(3, 3).cuda().double()
18 | img = torch.tensor(img.astype('int32')).cuda().double()
19 |
20 | while True:
21 | mask = img > 0
22 | if torch.logical_not(mask).sum() == 0:
23 | break
24 | mask = mask.double()
25 |
26 | mask_dilated = F.conv2d(mask[None, None], kernel[None, None], padding=1).squeeze(0).squeeze(0)
27 |
28 | img_dilated = F.conv2d(img[None, None], kernel[None, None], padding=1).squeeze(0).squeeze(0)
29 | img_dilated = torch.where(mask_dilated == 0, img_dilated, img_dilated / mask_dilated)
30 |
31 | img = torch.where(mask == 0, img_dilated, img)
32 |
33 | img = img.cpu().numpy()
34 |
35 | return img
36 |
37 |
38 | def parse_sp(sp, intr, pose, depth, normal, img_sp, edge_thres=0.95, crop=20):
39 | """Parse the superpixel segmentation into planes
40 | """
41 | sp = torch.from_numpy(sp).cuda()
42 | depth = torch.from_numpy(depth).cuda()
43 | normal = (torch.from_numpy(normal).cuda() - 0.5) * 2
44 | normal = normal.permute(1, 2, 0)
45 | img_sp = torch.from_numpy(img_sp).cuda()
46 | sp = sp.int()
47 | sp_ids = torch.unique(sp)
48 | planes = []
49 | mean_colors = []
50 |
51 | # mask crop to -1
52 | sp[:crop] = -1
53 | sp[-crop:] = -1
54 | sp[:, :crop] = -1
55 | sp[:, -crop:] = -1
56 |
57 | new_sp = torch.ones_like(sp).cuda() * -1
58 | sp_id_cnt = 0
59 |
60 | for sp_id in tqdm(sp_ids):
61 | sp_mask = sp == sp_id
62 |
63 | if sp_mask.sum() <= 15:
64 | continue
65 |
66 | new_sp[sp_mask] = sp_id_cnt
67 | mean_color = torch.mean(img_sp[sp_mask], dim=0)
68 |
69 | sp_dis = depth[sp_mask]
70 | sp_dis = torch.median(sp_dis)
71 | if sp_dis < 0:
72 | continue
73 | sp_normal = normal[sp_mask]
74 |
75 | sp_coeff = torch.einsum('ij,kj->ik', sp_normal, sp_normal)
76 | if sp_coeff.min() < edge_thres:
77 | is_edge_plane = True
78 | else:
79 | is_edge_plane = False
80 |
81 | sp_normal = torch.mean(sp_normal, dim=0)
82 | sp_normal = sp_normal / torch.norm(sp_normal)
83 | sp_normal[1] = -sp_normal[1]
84 | sp_normal[2] = -sp_normal[2]
85 |
86 | x_accum = torch.sum(sp_mask, dim=0)
87 | y_accum = torch.sum(sp_mask, dim=1)
88 | x_range_uv = [torch.min(torch.nonzero(x_accum)).item(), torch.max(torch.nonzero(x_accum)).item()]
89 | y_range_uv = [torch.min(torch.nonzero(y_accum)).item(), torch.max(torch.nonzero(y_accum)).item()]
90 |
91 | sp_dis = sp_dis.cpu().numpy()
92 | sp_depth = depth[y_range_uv[0]:y_range_uv[1]+1, x_range_uv[0]:x_range_uv[1]+1].cpu().numpy()
93 | sp_mask = sp_mask[y_range_uv[0]:y_range_uv[1]+1, x_range_uv[0]:x_range_uv[1]+1].cpu().numpy()
94 |
95 | ret = calc_plane(intr, pose, sp_depth, sp_dis, sp_mask, x_range_uv, y_range_uv, plane_normal=sp_normal.cpu().numpy())
96 | if ret is None:
97 | continue
98 | plane, plane_up, resol, new_x_range_uv, new_y_range_uv = ret
99 |
100 | plane = np.concatenate([plane, plane_up, resol, new_x_range_uv, new_y_range_uv, [is_edge_plane]])
101 | planes.append(plane)
102 | mean_colors.append(mean_color.cpu().numpy())
103 |
104 | sp_id_cnt += 1
105 |
106 | return planes, new_sp.cpu().numpy(), mean_colors
107 |
108 |
109 | def get_projection_matrix(fovy: float, aspect_wh: float, near: float, far: float):
110 | proj_mtx = np.zeros((4, 4), dtype=np.float32)
111 | proj_mtx[0, 0] = 1.0 / (np.tan(fovy / 2.0) * aspect_wh)
112 | proj_mtx[1, 1] = -1.0 / np.tan(
113 | fovy / 2.0
114 | ) # add a negative sign here as the y axis is flipped in nvdiffrast output
115 | proj_mtx[2, 2] = -(far + near) / (far - near)
116 | proj_mtx[2, 3] = -2.0 * far * near / (far - near)
117 | proj_mtx[3, 2] = -1.0
118 | return proj_mtx
119 |
120 |
121 | def get_mvp_matrix(c2w, proj_mtx):
122 |     # calculate w2c from c2w: R' = R^T, t' = -R^T @ t
123 | # mathematically equivalent to (c2w)^-1
124 | w2c = np.zeros((c2w.shape[0], 4, 4))
125 | w2c[:, :3, :3] = np.transpose(c2w[:, :3, :3], (0, 2, 1))
126 | w2c[:, :3, 3:] = np.transpose(-c2w[:, :3, :3], (0, 2, 1)) @ c2w[:, :3, 3:]
127 | w2c[:, 3, 3] = 1.0
128 |
129 | # calculate mvp matrix by proj_mtx @ w2c (mv_mtx)
130 | mvp_mtx = proj_mtx @ w2c
131 |
132 | return mvp_mtx
133 |
134 |
135 | def load_replica_data(args, kf_list=[], image_skip=1, load_plane=True):
136 | imgs = []
137 | image_dir = os.path.join(args.input_dir, "images")
138 | start = kf_list[0]
139 | end = kf_list[-1] + 1 if kf_list[-1] != -1 else len(glob.glob(os.path.join(image_dir, "*")))
140 | image_names = sorted(glob.glob(os.path.join(image_dir, "*")))
141 |
142 | image_names_sp = [image_names[i] for i in kf_list]
143 | image_names = image_names[start:end:image_skip]
144 |
145 | print(str(datetime.now()) + ': \033[92mI', 'loading images ...', '\033[0m')
146 | for name in tqdm(image_names):
147 | img = cv2.imread(name)
148 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
149 | imgs.append(img.astype(np.float32) / 255.)
150 |
151 | imgs_sp = []
152 | for name in image_names_sp:
153 | img = cv2.imread(name)
154 |         # convert to HSV color space using cv2
155 | img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV)
156 | imgs_sp.append(img.astype(np.float32) / 255.)
157 |
158 | imgs_sp = np.stack(imgs_sp, 0)
159 |
160 | intr, sfm_poses, sfm_camprojs = load_replica_pose(args.input_dir)
161 |
162 | sfm_poses_sp = np.array(sfm_poses)[kf_list]
163 | sfm_poses = np.array(sfm_poses)[start:end:image_skip]
164 | sfm_camprojs = np.array(sfm_camprojs)[start:end:image_skip]
165 |
166 | normal_dir = os.path.join(args.input_dir, "omnidata_normal")
167 | normal_names = sorted(glob.glob(os.path.join(normal_dir, "*")))
168 | depth_names = [ iname.replace('omnidata_normal', 'aligned_dense_depths') for iname in normal_names ]
169 | depth_names = [depth_names[i] for i in kf_list]
170 |
171 | print(str(datetime.now()) + ': \033[92mI', 'loading depths ...', '\033[0m')
172 | depths = []
173 | for name in tqdm(depth_names):
174 | depth = np.load(name)
175 | depths.append(depth)
176 |
177 | depths = np.stack(depths, 0)
178 | aligned_depth_dir = os.path.join(args.input_dir, "aligned_dense_depths")
179 | all_depth_names = sorted(glob.glob(os.path.join(aligned_depth_dir, "*")))[start:end:image_skip]
180 | all_depths = []
181 |
182 | for name, pose in tqdm(zip(all_depth_names, sfm_poses)):
183 | depth = np.load(name)
184 | all_depths.append(depth)
185 |
186 | if args.init.normal_model_type == 'omnidata':
187 | normal_dir = os.path.join(args.input_dir, "omnidata_normal")
188 | else:
189 | raise NotImplementedError(f'Unknown normal model type {args.init.normal_model_type} exiting')
190 | normal_names = sorted(glob.glob(os.path.join(normal_dir, "*")))
191 | normal_names = [normal_names[i] for i in kf_list]
192 |
193 | print(str(datetime.now()) + ': \033[92mI', 'loading normals ...', '\033[0m')
194 | normals = []
195 | for name in tqdm(normal_names):
196 | normal = np.load(name)[:3]
197 | normals.append(normal)
198 |
199 | normals = np.stack(normals, 0)
200 |
201 | all_normal_names = sorted(glob.glob(os.path.join(normal_dir, "*")))[start:end:image_skip]
202 | all_normals = []
203 |
204 | for name, pose in tqdm(zip(all_normal_names, sfm_poses)):
205 | normal = np.load(name)
206 | normal = normal[:3]
207 |
208 | normal = torch.from_numpy(normal).cuda()
209 | normal = (normal - 0.5) * 2
210 | normal = normal / torch.norm(normal, dim=0, keepdim=True)
211 | normal = normal.permute(1, 2, 0)
212 | normal[:, :, 1] = -normal[:, :, 1]
213 | normal[:, :, 2] = -normal[:, :, 2]
214 | normal = torch.einsum('ijk,kl->ijl', normal, torch.from_numpy(np.linalg.inv(pose[:3, :3]).T).cuda().float())
215 | normal = F.interpolate(normal.permute(2, 0, 1).unsqueeze(0),
216 | (int(imgs[0].shape[0]/64)*8, int(imgs[0].shape[1]/64)*8),
217 | mode='bilinear', align_corners=False).squeeze(0).permute(1, 2, 0)
218 | all_normals.append(normal.cpu().numpy())
219 |
220 | sp_dir = os.path.join(args.input_dir, "sp")
221 | sp_names = [os.path.join(sp_dir, os.path.split(name)[-1].replace('.jpg', '.npy').replace('.png', '.npy')) for name in image_names_sp]
222 |
223 | if load_plane:
224 | print(str(datetime.now()) + ': \033[92mI', 'loading superpixels ...', '\033[0m')
225 | planes_all = []
226 | new_sp_all = []
227 | mean_colors_all = []
228 | plane_index_start = 0
229 | for plane_idx, (sp_name, depth, normal, pose, img_sp) in tqdm(enumerate(zip(sp_names, depths, normals, sfm_poses_sp, imgs_sp))):
230 | sp = np.load(sp_name)
231 | planes, new_sp, mean_colors = parse_sp(sp, intr, pose, depth, normal, img_sp)
232 | planes = np.array(planes)
233 | planes_idx = np.ones((planes.shape[0], 1)) * plane_idx
234 | planes = np.concatenate([planes, planes_idx], 1)
235 | planes_all.extend(planes)
236 | new_sp_all.append(new_sp + plane_index_start)
237 | plane_index_start += planes.shape[0]
238 | mean_colors = np.array(mean_colors)
239 | mean_colors_all.extend(mean_colors)
240 |
241 | planes_all = np.array(planes_all)
242 | mean_colors_all = np.array(mean_colors_all)
243 |
244 | else:
245 | planes_all = None
246 | new_sp_all = None
247 | mean_colors_all = None
248 |
249 |     # projection matrices
250 | frames_proj = []
251 | frames_c2w = []
252 | frames_center = []
253 | for i in range(len(sfm_camprojs)):
254 | fovy = 2 * np.arctan(0.5 * imgs[0].shape[0] / intr[1, 1])
255 | proj = get_projection_matrix(
256 | fovy, imgs[0].shape[1] / imgs[0].shape[0], 0.1, 1000.0
257 | )
258 | proj = np.array(proj)
259 | frames_proj.append(proj)
260 | # sfm_poses is w2c
261 | c2w = np.linalg.inv(sfm_poses[i])
262 | frames_c2w.append(c2w)
263 | frames_center.append(c2w[:3, 3])
264 |
265 | frames_proj = np.stack(frames_proj, 0)
266 | frames_c2w = np.stack(frames_c2w, 0)
267 | frames_center = np.stack(frames_center, 0)
268 |
269 | mvp_mtxs = get_mvp_matrix(
270 | frames_c2w, frames_proj,
271 | )
272 |
273 | index_init = ((np.array(kf_list) - kf_list[0]) / image_skip).astype(np.int32)
274 |
275 | return imgs, intr, sfm_poses, sfm_camprojs, frames_center, all_depths, all_normals, planes_all, \
276 | mvp_mtxs, index_init, new_sp_all, mean_colors_all
277 |
--------------------------------------------------------------------------------
/recon/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-LYJ-Lab/AlphaTablets/735cbfe6aa7f03f7bc37f48303045fe71ffa042a/recon/__init__.py
--------------------------------------------------------------------------------
/recon/run_depth.py:
--------------------------------------------------------------------------------
1 | import os
2 | import json
3 | import numpy as np
4 |
5 | from recon.third_party.omnidata.omnidata_tools.torch.demo_normal_custom_func import demo_normal_custom_func
6 | from recon.third_party.metric3d.depth_custom_func import depth_custom_func
7 |
8 |
9 | def metric3d_depth(img_dir, input_dir, depth_dir, dataset_type):
10 | """Initialize the dense depth of each image using single-image depth prediction
11 | """
12 | os.makedirs(depth_dir, exist_ok=True)
13 | if dataset_type == 'scannet':
14 | intr = np.loadtxt(f'{input_dir}/intrinsic/intrinsic_color.txt')
15 | fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2]
16 | elif dataset_type == 'replica':
17 | if os.path.exists(os.path.join(input_dir, 'cam_params.json')):
18 | cam_param_path = os.path.join(input_dir, 'cam_params.json')
19 | elif os.path.exists(os.path.join(input_dir, '../cam_params.json')):
20 | cam_param_path = os.path.join(input_dir, '../cam_params.json')
21 | else:
22 | raise FileNotFoundError('cam_params.json not found')
23 | with open(cam_param_path, 'r') as f:
24 | j = json.load(f)
25 | fx, fy, cx, cy = j["camera"]["fx"], j["camera"]["fy"], j["camera"]["cx"], j["camera"]["cy"]
26 | else:
27 | raise NotImplementedError(f'Unknown dataset type {dataset_type} exiting')
28 | indir = img_dir
29 | outdir = depth_dir
30 | depth_custom_func(fx, fy, cx, cy, indir, outdir)
31 |
32 |
33 | def omnidata_normal(img_dir, input_dir, normal_dir):
34 | """Initialize the dense normal of each image using single-image normal prediction
35 | """
36 | os.makedirs(normal_dir, exist_ok=True)
37 | demo_normal_custom_func(img_dir, normal_dir)
38 |
--------------------------------------------------------------------------------
/recon/run_recon.py:
--------------------------------------------------------------------------------
1 | import os
2 | import glob
3 | from datetime import datetime
4 |
5 |
6 | class Recon:
7 | def __init__(self, config, input_dir, dataset_type):
8 | self.config = config
9 | self.input_dir = input_dir
10 | self.image_dir = os.path.join(input_dir, "images")
11 | self.output_root = input_dir
12 | self.dataset_type = dataset_type
13 | self.preprocess()
14 |
15 | def preprocess(self):
16 | if not os.path.exists(self.image_dir):
17 | if self.dataset_type == 'scannet':
18 | original_image_dir = os.path.join(self.input_dir, "color")
19 | if not os.path.exists(original_image_dir):
20 |                     raise ValueError(f'Original image directory {original_image_dir} not found')
21 | old_dir = os.getcwd()
22 | os.chdir(self.input_dir)
23 | os.symlink("color", "images")
24 | os.chdir(old_dir)
25 | print(str(datetime.now()) + ': \033[92mI', f'Linked {original_image_dir} to {self.image_dir}', '\033[0m')
26 | elif self.dataset_type == 'replica':
27 | original_image_dir = os.path.join(self.input_dir, "results")
28 | if not os.path.exists(original_image_dir):
29 |                     raise ValueError(f'Original image directory {original_image_dir} not found')
30 | os.makedirs(self.image_dir)
31 | old_dir = os.getcwd()
32 | os.chdir(self.image_dir)
33 | for img in glob.glob(os.path.join("../results", "frame*")):
34 | os.symlink(os.path.join("../results", os.path.split(img)[-1]), os.path.split(img)[-1])
35 | os.chdir(old_dir)
36 | print(str(datetime.now()) + ': \033[92mI', f'Linked {original_image_dir} to {self.image_dir}', '\033[0m')
37 | else:
38 | raise NotImplementedError(f'Unknown dataset type {self.dataset_type} exiting')
39 |
40 |
41 | def recon(self):
42 | if os.path.exists(os.path.join(self.output_root, "recon.lock")):
43 | print(str(datetime.now()) + ': \033[92mI', 'Monocular estimation already done!', '\033[0m')
44 | return
45 |
46 | from .run_depth import omnidata_normal, metric3d_depth
47 |
48 | depth_dir = os.path.join(self.output_root, "aligned_dense_depths")
49 | print(str(datetime.now()) + ': \033[92mI', 'Running Depth Estimation ...', '\033[0m')
50 | metric3d_depth(self.image_dir, self.output_root, depth_dir, self.dataset_type)
51 | print(str(datetime.now()) + ': \033[92mI', 'Depth Estimation Done!', '\033[0m')
52 |
53 | normal_dir = os.path.join(self.output_root, "omnidata_normal")
54 | print(str(datetime.now()) + ': \033[92mI', 'Running Normal Estimation ...', '\033[0m')
55 | omnidata_normal(self.image_dir, self.output_root, normal_dir)
56 | print(str(datetime.now()) + ': \033[92mI', 'Normal Estimation Done!', '\033[0m')
57 |
58 | # create lock file
59 | open(os.path.join(self.output_root, "recon.lock"), "w").close()
60 |
61 | def run_sp(self, kf_list):
62 | from .run_sp import run_sp
63 | print(str(datetime.now()) + ': \033[92mI', 'Running SuperPixel Subdivision ...', '\033[0m')
64 | sp_dir = run_sp(self.image_dir, self.output_root, kf_list)
65 | print(str(datetime.now()) + ': \033[92mI', 'SuperPixel Subdivision Done!', '\033[0m')
66 | return sp_dir
67 |
--------------------------------------------------------------------------------
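A hypothetical driver for the `Recon` class above, matching the ScanNet demo layout. Passing `config=None` is an assumption based on `preprocess`/`recon`/`run_sp` never reading `self.config`:

```python
from recon.run_recon import Recon

# Assumes ./data/scene0684_01 exists with a color/ folder (ScanNet layout).
recon = Recon(config=None, input_dir='./data/scene0684_01', dataset_type='scannet')
recon.recon()                               # depth + normals, guarded by recon.lock
sp_dir = recon.run_sp(kf_list=[0, 10, 20])  # superpixels for chosen keyframes
```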
/recon/run_sp.py:
--------------------------------------------------------------------------------
1 | import os
2 | import cv2
3 | import numpy as np
4 | from tqdm import tqdm
5 |
6 |
7 | def sp_division(img):
8 | slic = cv2.ximgproc.createSuperpixelSLIC(img)
9 | slic.iterate(10)
10 | labels = slic.getLabels()
11 |     # mask_slic = slic.getLabelContourMask()  # get the boundary mask; superpixel edge pixels == 1
12 |     # number_slic = slic.getNumberOfSuperpixels()  # get the number of superpixels
13 |     # mask_inv_slic = cv2.bitwise_not(mask_slic)
14 |     # img_slic = cv2.bitwise_and(img, img, mask=mask_inv_slic)  # draw superpixel boundaries on the original image
15 |     # cv2.imwrite('3.png', img_slic)
16 | return labels
17 |
18 |
19 | def run_sp(image_dir, output_root, kf_list=None):
20 | sp_dir = os.path.join(output_root, "sp")
21 | os.makedirs(sp_dir, exist_ok=True)
22 | try:
23 | image_names = sorted(os.listdir(image_dir), key=lambda x: int(x.split('/')[-1][:-4]))
24 |     except ValueError:  # non-numeric filenames: fall back to lexicographic order
25 | image_names = sorted(os.listdir(image_dir))
26 |     image_names = [image_names[k] for k in kf_list] if kf_list is not None else image_names
27 | for name in tqdm(image_names):
28 | if os.path.exists(os.path.join(sp_dir, name[:-4] + '.npy')):
29 | continue
30 | img = cv2.imread(os.path.join(image_dir, name))
31 | labels = sp_division(img)
32 | np.save(os.path.join(sp_dir, name[:-4]), labels)
33 |
34 |
35 | if __name__ == '__main__':
36 | img = cv2.imread('/home/hyz/git-plane/PixelPlane/data/scene0000_00/images/36.jpg')
37 | sp_division(img)
--------------------------------------------------------------------------------
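A quick check of `sp_division` on a synthetic image (requires `opencv-contrib-python` for `cv2.ximgproc`):

```python
import cv2
import numpy as np
from recon.run_sp import sp_division

img = (np.random.rand(120, 160, 3) * 255).astype(np.uint8)  # random BGR image
labels = sp_division(img)
print(labels.shape, labels.max() + 1)  # per-pixel label map, superpixel count
```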
/recon/utils.py:
--------------------------------------------------------------------------------
1 | import json
2 | import os
3 | import cv2
4 | import glob
5 | import numpy as np
6 |
7 |
8 | def load_images(img_dir):
9 | imgs = []
10 | image_names = sorted(os.listdir(img_dir))
11 | for name in image_names:
12 | img = cv2.imread(os.path.join(img_dir, name))
13 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
14 | imgs.append(img)
15 | return imgs, image_names
16 |
17 |
18 | def load_sfm(sfm_dir):
19 | """Load converted sfm world-to-cam depths and poses
20 | """
21 | depth_dir = os.path.join(sfm_dir, "colmap_outputs_converted/depths")
22 | pose_dir = os.path.join(sfm_dir, "colmap_outputs_converted/poses")
23 | depth_names = sorted(glob.glob(os.path.join(depth_dir, "*.npy")))
24 | pose_names = sorted(glob.glob(os.path.join(pose_dir, "*.txt")))
25 | sfm_depths, sfm_poses = [], []
26 | for dn, pn in zip(depth_names, pose_names):
27 | assert os.path.basename(dn)[:-4] == os.path.basename(pn)[:-4]
28 | depth = np.load(dn)
29 | pose = np.loadtxt(pn)
30 | sfm_depths.append(depth)
31 | sfm_poses.append(pose)
32 | return sfm_depths, sfm_poses
33 |
34 |
35 | def load_sfm_pose(sfm_dir):
36 | """Load sfm poses and then convert into proj mat
37 | """
38 | pose_dir = os.path.join(sfm_dir, "colmap_outputs_converted/poses")
39 | intr_dir = os.path.join(sfm_dir, "colmap_outputs_converted/intrinsics")
40 | pose_names = sorted(glob.glob(os.path.join(pose_dir, "*.txt")), key=lambda x: int(x.split('/')[-1][:-4]))
41 | intr_names = sorted(glob.glob(os.path.join(intr_dir, "*.txt")), key=lambda x: int(x.split('/')[-1][:-4]))
42 | K = np.loadtxt(intr_names[0])
43 | KH = np.eye(4)
44 | KH[:3,:3] = K
45 | sfm_poses, sfm_projmats = [], []
46 | for pn in pose_names:
47 | # world-to-cam
48 | pose = np.loadtxt(pn)
49 | pose = np.concatenate([pose, np.array([[0,0,0,1]])], 0)
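50 | # convert world-to-cam to cam-to-world, flip the y/z camera axes (OpenCV- vs OpenGL-style convention), then invert back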
50 | c2w = np.linalg.inv(pose)
51 | c2w[0:3, 1:3] *= -1
52 | pose = np.linalg.inv(c2w)
53 | sfm_poses.append(pose)
54 | # projmat
55 | projmat = KH @ pose
56 | sfm_projmats.append(projmat)
57 | return K, sfm_poses, sfm_projmats
58 |
59 |
60 | def load_scannet_pose(scannet_dir):
61 | """Load scannet poses and then convert into proj mat
62 | """
63 | pose_dir = os.path.join(scannet_dir, "pose")
64 | intr_dir = os.path.join(scannet_dir, "intrinsic")
65 | pose_names = sorted(glob.glob(os.path.join(pose_dir, "*.txt")), key=lambda x: int(x.split('/')[-1][:-4]))
66 | intr_name = os.path.join(intr_dir, "intrinsic_color.txt")
67 | KH = np.loadtxt(intr_name)
68 | scannet_poses, scannet_projmats = [], []
69 | for pn in pose_names:
70 | p = np.loadtxt(pn)
71 | R = p[:3, :3]
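72 | # multiplying by diag(1, -1, -1) flips the camera y/z axes, mirroring the convention change in load_sfm_pose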
72 | R = np.matmul(R, np.array([
73 | [1, 0, 0],
74 | [0, -1, 0],
75 | [0, 0, -1]
76 | ]))
77 | p[:3, :3] = R
78 | p = np.linalg.inv(p)
79 | scannet_poses.append(p)
80 | # projmat
81 | projmat = KH @ p
82 | scannet_projmats.append(projmat)
83 | return KH[:3, :3], scannet_poses, scannet_projmats
84 |
85 |
86 | def load_replica_pose(replica_dir):
87 | """Load replica poses and then convert into proj mat
88 | """
89 | if os.path.exists(os.path.join(replica_dir, 'cam_params.json')):
90 | cam_param_path = os.path.join(replica_dir, 'cam_params.json')
91 | elif os.path.exists(os.path.join(replica_dir, '../cam_params.json')):
92 | cam_param_path = os.path.join(replica_dir, '../cam_params.json')
93 | else:
94 | raise FileNotFoundError('cam_params.json not found')
95 | with open(cam_param_path, 'r') as f:
96 | j = json.load(f)
97 | intrinsics = np.array([
98 | j["camera"]["fx"], 0, j["camera"]["cx"],
99 | 0, j["camera"]["fy"], j["camera"]["cy"],
100 | 0, 0, 1
101 | ], dtype=np.float32).reshape(3, 3)
102 |
103 | KH = np.eye(4)
104 | KH[:3, :3] = intrinsics
105 |
106 | extrinsics = np.loadtxt(os.path.join(replica_dir, 'traj.txt')).reshape(-1, 4, 4)
107 | poses = []
108 | projmats = []
109 | for extrinsic in extrinsics:
110 | p = extrinsic
111 | R = p[:3, :3]
112 | R = np.matmul(R, np.array([
113 | [1, 0, 0],
114 | [0, -1, 0],
115 | [0, 0, -1]
116 | ]))
117 | p[:3, :3] = R
118 | p = np.linalg.inv(p)
119 | poses.append(p)
120 | # projmat
121 | projmat = KH @ p
122 | projmats.append(projmat)
123 |
124 | return intrinsics, poses, projmats
125 |
126 |
127 | def load_dense_depths(depth_dir):
128 | """Load initialized single-image dense depth predictions
129 | """
130 | depth_names = sorted(os.listdir(depth_dir))
131 | depths = []
132 | for dn in depth_names:
133 | depth = np.load(os.path.join(depth_dir, dn))
134 | depths.append(depth)
135 | return depths, depth_names
136 |
137 |
138 | def depth2disp(depth):
139 | """Convert depth map to disparity
140 | """
141 | disp = 1.0 / (depth + 1e-8)
142 | return disp
143 |
144 |
145 | def disp2depth(disp):
146 | """Convert disparity map to depth
147 | """
148 | depth = 1.0 / (disp + 1e-8)
149 | return depth
150 |
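151 | # note: depth2disp and disp2depth are mutual inverses up to the 1e-8 stabilizer that guards against division by zero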
--------------------------------------------------------------------------------
/requirements.txt:
--------------------------------------------------------------------------------
1 | tqdm
2 | scikit-learn
3 | opencv-python
4 | opencv-contrib-python
5 | hydra-core
6 | matplotlib
7 | timm
8 | h5py
9 | pytorch_lightning
10 | mmengine
11 | mmcv
12 | ninja
13 | trimesh
14 | open3d
15 | git+https://github.com/NVlabs/nvdiffrast.git
--------------------------------------------------------------------------------
/run.py:
--------------------------------------------------------------------------------
1 | from datetime import datetime
2 | import os, time, random, argparse
3 |
4 | import numpy as np
5 | import cv2
6 |
7 | from tqdm import tqdm, trange
8 |
9 | import torch
10 | import torch.nn.functional as F
11 | from torch.utils.data import DataLoader
12 | import torch.optim.lr_scheduler as lr_scheduler
13 |
14 | from lib.load_data import load_data
15 | from lib.alphatablets import AlphaTablets
16 | from dataset.dataset import CustomDataset
17 | from lib.keyframe import get_keyframes
18 | from recon.run_recon import Recon
19 |
20 |
21 | def config_parser():
22 | '''Define command line arguments
23 | '''
24 | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
25 | parser.add_argument('--config', type=str, default='configs/default.yaml',
26 | help='config file path')
27 | parser.add_argument("--seed", type=int, default=777,
28 | help='Random seed')
29 | parser.add_argument("--job", type=str, default=str(time.time()),
30 | help='Job name')
31 |
32 | # learning options
33 | parser.add_argument("--lr_tex", type=float, default=0.01,
34 | help='Learning rate of texture color')
35 | parser.add_argument("--lr_alpha", type=float, default=0.03,
36 | help='Learning rate of texture alpha')
37 | parser.add_argument("--lr_plane_n", type=float, default=0.0001,
38 | help='Learning rate of plane normal')
39 | parser.add_argument("--lr_plane_dis", type=float, default=0.0005,
40 | help='Learning rate of plane distance')
41 | parser.add_argument("--lr_plane_dis_stage2", type=float, default=0.0002,
42 | help='Learning rate of plane distance in stage 2')
43 |
44 | # loss weights
45 | parser.add_argument("--weight_alpha_inv", type=float, default=1.0,
46 | help='Weight of alpha inv loss')
47 | parser.add_argument("--weight_normal", type=float, default=4.0,
48 | help='Weight of direct normal loss')
49 | parser.add_argument("--weight_depth", type=float, default=4.0,
50 | help='Weight of direct depth loss')
51 | parser.add_argument("--weight_distortion", type=float, default=20.0,
52 | help='Weight of distortion loss')
53 | parser.add_argument("--weight_decay", type=float, default=0.9,
54 | help='Weight of tablet alpha decay after a single step. -1 denotes automatic decay')
55 |
56 | # merging options
57 | parser.add_argument("--merge_normal_thres_init", type=float, default=0.97,
58 | help='Threshold of init merging planes')
59 | parser.add_argument("--merge_normal_thres", type=float, default=0.93,
60 | help='Threshold of merging planes')
61 | parser.add_argument("--merge_dist_thres1", type=float, default=0.5,
62 | help='Threshold of init merging planes')
63 | parser.add_argument("--merge_dist_thres2", type=float, default=0.1,
64 | help='Threshold of merging planes')
65 | parser.add_argument("--merge_color_thres1", type=float, default=0.3,
66 | help='Threshold of init merging planes')
67 | parser.add_argument("--merge_color_thres2", type=float, default=0.2,
68 | help='Threshold of merging planes')
69 | parser.add_argument("--merge_Y_decay", type=float, default=0.5,
70 | help='Decay rate of Y channel')
71 |
72 | # optimization options
73 | parser.add_argument("--batch_size", type=int, default=3,
74 | help='Batch size')
75 | parser.add_argument("--max_steps", type=int, default=32,
76 | help='Max optimization steps')
77 | parser.add_argument("--merge_interval", type=int, default=13,
78 | help='Merge interval')
79 | parser.add_argument("--max_steps_union", type=int, default=9,
80 | help='Max optimization steps for union optimization')
81 | parser.add_argument("--merge_interval_union", type=int, default=3,
82 | help='Merge interval for union optimization')
83 |
84 | parser.add_argument("--alpha_init", type=float, default=0.5,
85 | help='Initial alpha value')
86 | parser.add_argument("--alpha_init_empty", type=float, default=0.,
87 | help='Initial alpha value for empty pixels')
88 | parser.add_argument("--depth_inside_mask_alphathres", type=float, default=0.5,
89 | help='Threshold of alpha for inside mask')
90 |
91 | parser.add_argument("--max_rasterize_layers", type=int, default=15,
92 | help='Max rasterize layers')
93 |
94 | # logging/saving options
95 | parser.add_argument("--log_path", type=str, default='./logs',
96 | help='path to save logs')
97 | parser.add_argument("--dump_images", type=bool, default=False)
98 | parser.add_argument("--dump_interval", type=int, default=1,
99 | help='Dump interval')
100 |
101 | # update input
102 | parser.add_argument("--input_dir", type=str, default='',
103 | help='input directory')
104 |
105 | args = parser.parse_args()
106 |
107 | from hydra import compose, initialize
108 |
109 | initialize(version_base=None, config_path='./')
110 | ori_cfg = compose(config_name=args.config)['configs']
111 |
112 | class Struct:
113 | def __init__(self, **entries):
114 | self.__dict__.update(entries)
115 |
116 | cfg = Struct(**ori_cfg)
117 |
118 | # dump args and cfg
119 | import json
120 | os.makedirs(os.path.join(args.log_path, args.job), exist_ok=True)
121 | with open(os.path.join(args.log_path, args.job, 'args.json'), 'w') as f:
122 | args_json = {k: v for k, v in vars(args).items() if k != 'config'}
123 | json.dump(args_json, f, indent=4)
124 |
125 | import shutil
126 | shutil.copy(args.config, os.path.join(args.log_path, args.job, 'config.yaml'))
127 |
128 | if args.input_dir != '':
129 | cfg.data.input_dir = args.input_dir
130 |
131 | return args, cfg
132 |
133 |
134 | def seed_everything():
135 | '''Seed everything for better reproducibility.
136 | (some PyTorch operations are non-deterministic, e.g. the backward pass of grid_sample)
137 | '''
138 | torch.manual_seed(args.seed)
139 | np.random.seed(args.seed)
140 | random.seed(args.seed)
141 |
142 |
143 | def load_everything(cfg, kf_list, image_skip, load_plane=True):
144 | '''Load images / poses / camera settings / data split.
145 | '''
146 | data_dict = load_data(cfg.data, kf_list=kf_list, image_skip=image_skip, load_plane=load_plane)
147 | data_dict['poses'] = torch.Tensor(data_dict['poses'])
148 | return data_dict
149 |
150 |
151 | if __name__=='__main__':
152 | # load setup
153 | args, cfg = config_parser()
154 | merge_cfgs = {
155 | 'normal_thres_init': args.merge_normal_thres_init,
156 | 'normal_thres': args.merge_normal_thres,
157 | 'dist_thres1': args.merge_dist_thres1,
158 | 'dist_thres2': args.merge_dist_thres2,
159 | 'color_thres1': args.merge_color_thres1,
160 | 'color_thres2': args.merge_color_thres2,
161 | 'Y_decay': args.merge_Y_decay
162 | }
163 |
164 | # init environment
165 | if torch.cuda.is_available():
166 | torch.set_default_tensor_type('torch.cuda.FloatTensor')
167 | device = torch.device('cuda')
168 | else:
169 | torch.set_default_tensor_type('torch.FloatTensor')
170 | device = torch.device('cpu')
171 | torch.set_default_dtype(torch.float32)
172 | seed_everything()
173 |
174 | recon = Recon(cfg.data.init, cfg.data.input_dir, cfg.data.dataset_type)
175 | recon.recon()
176 |
177 | kf_lists, image_skip = get_keyframes(cfg.data.input_dir)
178 | subset_mean_len = np.mean([ k[-1]-k[0]+1 for k in kf_lists ])
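179 | # append a [0, -1] marker: this final entry triggers the union optimization pass over all subsets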
179 | kf_lists.append([0, -1])
180 |
181 | for set_num, kf_list in enumerate(kf_lists):
182 | if set_num == len(kf_lists) - 1:
183 | print(str(datetime.now()) + f': \033[94mUnion Optimization', '\033[0m')
184 | else:
185 | print(str(datetime.now()) + f': \033[94mSubset {set_num+1}/{len(kf_lists)-1}, Keyframes', kf_list, '\033[0m')
186 |
187 | union_optimize = False
188 | if kf_list[0] == 0 and kf_list[1] == -1:
189 | union_optimize = True
190 | if os.path.exists(os.path.join(args.log_path, args.job, 'ckpt', f'./ckpt_{set_num:02d}_{args.max_steps_union-1}.pt')):
191 | print(str(datetime.now()) + f': \033[94mUnion optimization already done!', '\033[0m')
192 | continue
193 | else:
194 | if os.path.exists(os.path.join(args.log_path, args.job, 'ckpt', f'./ckpt_{set_num:02d}_{args.max_steps-1}.pt')):
195 | print(str(datetime.now()) + f': \033[94mSubset {set_num+1}/{len(kf_lists)-1} optimization already done!', '\033[0m')
196 | continue
197 |
198 | if kf_list[-1] != -1:
199 | recon.run_sp(kf_list)
200 |
201 | # load images / poses / camera settings / data split
202 | data_dict = load_everything(cfg=cfg, kf_list=kf_list, image_skip=image_skip,
203 | load_plane=not union_optimize)
204 |
205 | images, mvp_mtxs, K = data_dict['images'], data_dict['mvp_mtxs'], data_dict['intr']
206 | index_init = data_dict['index_init']
207 | new_sps = data_dict['new_sps']
208 | cam_centers = data_dict['cam_centers']
209 | normals = data_dict['normals']
210 | mean_colors = data_dict['mean_colors']
211 | depths = data_dict['all_depths']
212 |
213 | sp_images = np.stack([ images[idx] for idx in index_init ])
214 |
215 | if not union_optimize:
216 | pp = AlphaTablets(plane_params=data_dict['planes_all'],
217 | sp=[ sp_images , new_sps, mvp_mtxs[index_init]],
218 | cam_centers=cam_centers[index_init],
219 | mean_colors=mean_colors,
220 | HW=images[0].shape[0:2],
221 | alpha_init=args.alpha_init, alpha_init_empty=args.alpha_init_empty,
222 | inside_mask_alpha_thres=args.depth_inside_mask_alphathres,
223 | merge_cfgs=merge_cfgs)
224 | else:
225 | pp = AlphaTablets(ckpt_paths=[
226 | os.path.join(args.log_path, args.job, 'ckpt', f'./ckpt_{ii:02d}_{args.max_steps-1}.pt')
227 | for ii in range(len(kf_lists) - 1)
228 | ], merge_cfgs=merge_cfgs,
229 | alpha_init=args.alpha_init, alpha_init_empty=args.alpha_init_empty,
230 | inside_mask_alpha_thres=args.depth_inside_mask_alphathres)
231 |
232 | dataset = CustomDataset(images, mvp_mtxs, normals, depths)
233 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, generator=torch.Generator(device=device))
234 |
235 | optimizer = torch.optim.Adam([
236 | {'params': pp.tex_color, 'lr': args.lr_tex},
237 | {'params': pp.tex_alpha, 'lr': args.lr_alpha},
238 | {'params': pp.plane_n, 'lr': args.lr_plane_n},
239 | {'params': pp.plane_dis, 'lr': args.lr_plane_dis},
240 | ])
241 |
242 | scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.3)
243 | crop = cfg.data.crop
244 |
245 | max_steps = args.max_steps if not union_optimize else args.max_steps_union
246 | merge_interval = args.merge_interval if not union_optimize else args.merge_interval_union
247 |
248 | psnr_lst = []
249 | time0 = time.time()
250 | print(str(datetime.now()) + ': \033[92mI', 'Start AlphaTablets Optimization', '\033[0m')
251 | for global_step in trange(0, max_steps):
252 | avg_psnr = 0
253 | for batch_idx, (imgs, mvp_mtxs, normals, depths) in enumerate(dataloader):
254 | mvp_mtxs = mvp_mtxs.cuda()
255 | imgs = imgs.cuda()
256 | normals = normals.cuda()
257 | depths = depths.cuda()
258 |
259 | render_result = pp(mvp_mtxs,
260 | max_rasterize_layers=args.max_rasterize_layers)
261 |
262 | try:
263 | render_result, render_normal, render_depth, alpha_loss, distort_loss, inside_mask, plane_inside_mask = render_result
264 | except:
265 | continue
266 |
267 | # gradient descent step
268 | optimizer.zero_grad(set_to_none=True)
269 | crop_mask = torch.zeros_like(imgs)
270 | crop_mask[:, crop:-crop, crop:-crop] = 1
271 | mask = crop_mask * inside_mask
272 | loss = F.mse_loss(render_result * mask, imgs * mask)
273 | direct_normal_loss = F.mse_loss(render_normal, normals)
274 | direct_depth_loss = F.mse_loss(render_depth * mask[..., 0] * plane_inside_mask[..., 0], depths * mask[..., 0] * plane_inside_mask[..., 0])
275 | psnr = -10 * torch.log10(loss)
276 |
277 | loss += alpha_loss * args.weight_alpha_inv
278 | loss += distort_loss * args.weight_distortion
279 | loss += direct_normal_loss * args.weight_normal
280 | loss += direct_depth_loss * args.weight_depth
281 | avg_psnr += psnr.item()
282 |
283 | loss.backward()
284 | optimizer.step()
285 |
286 | if global_step % args.dump_interval == 0 and batch_idx == 0 and args.dump_images:
287 | optimizer.zero_grad(set_to_none=True)
288 | mvp_mtxs = torch.from_numpy(data_dict['mvp_mtxs'][0:1]).cuda().float()
289 | render_result, alpha_acc = pp(mvp_mtxs, return_alpha=True,
290 | max_rasterize_layers=args.max_rasterize_layers)
291 |
292 | os.makedirs(os.path.join(args.log_path, args.job, 'dump_images'), exist_ok=True)
293 |
294 | cv2.imwrite(os.path.join(args.log_path, args.job, 'dump_images', f'{set_num:02d}_{global_step:05d}.png'), render_result[0].detach().cpu().numpy() * 255)
295 | cv2.imwrite(os.path.join(args.log_path, args.job, 'dump_images', f'{set_num:02d}_{global_step:05d}_mask.png'), mask[0].detach().cpu().numpy() * 255)
296 | cv2.imwrite(os.path.join(args.log_path, args.job, 'dump_images', f'{set_num:02d}_{global_step:05d}_alpha.png'), alpha_acc[0].detach().cpu().numpy() * 255)
297 |
298 | error_map = torch.abs(render_result - torch.from_numpy(data_dict['images'][0][None, ...]).cuda().float())
299 | error_map = error_map[0].detach().cpu().numpy()
300 | # turn to red-blue, 0-255
301 | error_map = (error_map - error_map.min()) / (error_map.max() - error_map.min()) * 255
302 | error_map = cv2.applyColorMap(error_map.astype(np.uint8), cv2.COLORMAP_JET)
303 | cv2.imwrite(os.path.join(args.log_path, args.job, 'dump_images', f'{set_num:02d}_{global_step:05d}_error_map.png'), error_map)
304 |
305 | if global_step % merge_interval == 0 and batch_idx == 0 and global_step > 0:
306 | optimizer.zero_grad(set_to_none=True)
307 |
308 | pp.weightCheck(torch.from_numpy(data_dict['mvp_mtxs']).cuda().float())
309 |
310 | pp = AlphaTablets(pp=pp, merge_cfgs=merge_cfgs)
311 | del optimizer
312 | optimizer = torch.optim.Adam([
313 | {'params': pp.tex_color, 'lr': args.lr_tex},
314 | {'params': pp.tex_alpha, 'lr': args.lr_alpha},
315 | {'params': pp.plane_n, 'lr': args.lr_plane_n},
316 | {'params': pp.plane_dis, 'lr': args.lr_plane_dis_stage2},
317 | ])
318 |
319 | scheduler.step()
320 |
321 | if global_step == max_steps - 1:
322 | with torch.no_grad():
323 | os.makedirs(os.path.join(args.log_path, args.job, 'ckpt'), exist_ok=True)
324 | pp.save_ckpt(os.path.join(args.log_path, args.job, 'ckpt', f'ckpt_{set_num:02d}_{global_step}.pt'))
325 | optimizer.zero_grad(set_to_none=True)
326 |
327 | if union_optimize:
328 | with torch.no_grad():
329 | decay = args.weight_decay if args.weight_decay > 0 else max(1-0.00067*subset_mean_len, 0.8)
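330 | # logit-space decay: with a = sigmoid(tex_alpha), this update sets a <- decay * a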
330 | pp.tex_alpha.data = torch.log( decay / (1 - decay + torch.exp(-pp.tex_alpha)) )
331 | elif global_step < args.merge_interval:
332 | with torch.no_grad():
333 | decay = args.weight_decay if args.weight_decay > 0 else max(1-0.00067*len(images), 0.8)
334 | pp.tex_alpha.data = torch.log( decay / (1 - decay + torch.exp(-pp.tex_alpha)) )
335 |
336 | if union_optimize:
337 | os.makedirs(os.path.join(args.log_path, args.job, 'results'), exist_ok=True)
338 | export_name_obj = os.path.join(args.log_path, args.job, 'results', f'{set_num:02d}_final.obj')
339 | pp.export_mesh_with_weight_check(torch.from_numpy(data_dict['mvp_mtxs']).cuda().float(), name=export_name_obj)
340 |
341 | os.makedirs(os.path.join(args.log_path, args.job, 'plys'), exist_ok=True)
342 | export_name = os.path.join(args.log_path, args.job, 'plys', f'{set_num:02d}_final.ply')
343 | pp.export_ply(torch.from_numpy(data_dict['mvp_mtxs'][::8]).cuda().float(), name=export_name)
344 |
345 | if set_num == len(kf_lists) - 1:
346 | print(str(datetime.now()) + f': \033[94mUnion optimization finished!', '\033[0m')
347 | else:
348 | print(str(datetime.now()) + f': \033[94mSubset {set_num+1}/{len(kf_lists)-1} optimization finished', '\033[0m')
349 |
350 | # create complete lock
351 | with open(os.path.join(args.log_path, args.job, 'complete.lock'), 'w') as f:
352 | f.write('done')
353 |
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | import open3d as o3d
5 | import trimesh
6 | from tqdm import tqdm
7 | from sklearn.metrics import rand_score
8 | from skimage.metrics import variation_of_information
9 |
10 | def compute_sc(gt_in, pred_in):
11 | # to be consistent with the skimage / sklearn input arrangement
12 |
13 | assert len(pred_in.shape) == 1 and len(gt_in.shape) == 1
14 |
15 | acc, pred, gt = match_seg(pred_in, gt_in) # n_gt * n_pred
16 |
17 | bestmatch_gt2pred = acc.max(axis=1)
18 | bestmatch_pred2gt = acc.max(axis=0)
19 |
20 | pred_id, pred_cnt = np.unique(pred, return_counts=True)
21 | gt_id, gt_cnt = np.unique(gt, return_counts=True)
22 |
23 | cnt_pred, sum_pred = 0, 0
24 | for i, _ in enumerate(pred_id):
25 | cnt_pred += bestmatch_pred2gt[i] * pred_cnt[i]
26 | sum_pred += pred_cnt[i]
27 |
28 | cnt_gt, sum_gt = 0, 0
29 | for i, _ in enumerate(gt_id):
30 | cnt_gt += bestmatch_gt2pred[i] * gt_cnt[i]
31 | sum_gt += gt_cnt[i]
32 |
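33 | # SC (segmentation covering): mean of the two size-weighted best-overlap (IoU) scores, pred->gt and gt->pred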
33 | sc = (cnt_pred / sum_pred + cnt_gt / sum_gt) / 2
34 |
35 | return sc
36 |
37 | def match_seg(pred_in, gt_in):
38 | assert len(pred_in.shape) == 1 and len(gt_in.shape) == 1
39 |
40 | pred, gt = compact_segm(pred_in), compact_segm(gt_in)
41 | n_gt = gt.max() + 1
42 | n_pred = pred.max() + 1
43 |
44 | # compute the overlap between gt and pred segment labels:
45 | # each (gt=i, pred=j) pair is encoded as the single integer i + j * n_gt,
46 | # so a histogram of `overlap` decodes into a confusion matrix with gt as rows and pred as columns
47 | # e.g. with 13 gt labels and 6 pred labels, gt label 1 shows up as 14, 1 + 2*13, ..., 1 + 6*13
48 | overlap = gt + n_gt * pred
49 | freq, bin_val = np.histogram(overlap, np.arange(0, n_gt * n_pred+1)) # hist given bins [1, 2, 3] --> return [1, 2), [2, 3)
50 | conf_mat = freq.reshape([ n_gt, n_pred], order='F') # column first reshape, like matlab
51 |
52 | acc = np.zeros([n_gt, n_pred])
53 | for i in range(n_gt):
54 | for j in range(n_pred):
55 | gt_i = conf_mat[i].sum()
56 | pred_j = conf_mat[:, j].sum()
57 | gt_pred = conf_mat[i, j]
58 | acc[i,j] = gt_pred / (gt_i + pred_j - gt_pred) if (gt_i + pred_j - gt_pred) != 0 else 0
59 | return acc[1:, 1:], pred, gt
60 |
61 | def compact_segm(seg_in):
62 | seg = seg_in.copy()
63 | uniq_id = np.unique(seg)
64 | cnt = 1
65 | for id in sorted(uniq_id):
66 | if id == 0:
67 | continue
68 | seg[seg==id] = cnt
69 | cnt += 1
70 |
71 | # every id (including non-plane) must be non-zero for the later processing in match_seg
72 | seg = seg + 1
73 | return seg
74 |
75 | def project_to_mesh(from_mesh, to_mesh, attribute, attr_name, color_mesh=None, dist_thresh=None):
76 | """ Transfers attributs from from_mesh to to_mesh using nearest neighbors
77 |
78 | Each vertex in to_mesh gets assigned the attribute of the nearest
79 | vertex in from mesh. Used for semantic evaluation.
80 |
81 | Args:
82 | from_mesh: Trimesh with known attributes
83 | to_mesh: Trimesh to be labeled
84 | attribute: Which attribute to transfer
85 | dist_thresh: Do not transfer attributes beyond this distance
86 | (None transfers regardless of distance between from and to vertices)
87 |
88 | Returns:
89 | Trimesh containing the transferred attribute
90 | """
91 |
92 | if len(from_mesh.vertices) == 0:
93 | to_mesh.vertex_attributes[attr_name] = np.zeros((0), dtype=np.uint8)
94 | to_mesh.visual.vertex_colors = np.zeros((0), dtype=np.uint8)
95 | return to_mesh
96 |
97 | pcd = o3d.geometry.PointCloud()
98 | pcd.points = o3d.utility.Vector3dVector(from_mesh.vertices)
99 | kdtree = o3d.geometry.KDTreeFlann(pcd)
100 |
101 | pred_ids = attribute.copy()
102 | pred_colors = from_mesh.visual.vertex_colors if color_mesh is None else color_mesh.visual.vertex_colors
103 |
104 | matched_ids = np.zeros((to_mesh.vertices.shape[0]), dtype=np.uint8)
105 | matched_colors = np.zeros((to_mesh.vertices.shape[0], 4), dtype=np.uint8)
106 |
107 | for i, vert in enumerate(to_mesh.vertices):
108 | _, inds, dist = kdtree.search_knn_vector_3d(vert, 1)
109 | if dist_thresh is None or dist[0] < dist_thresh:
110 | matched_ids[i] = pred_ids[inds[0]]
111 | matched_colors[i] = pred_colors[inds[0]]
112 | 
113 | mesh = to_mesh.copy()
114 | mesh.vertex_attributes[attr_name] = matched_ids
115 | mesh.visual.vertex_colors = matched_colors
116 | return mesh
242 | vertices_mask_new = np.logical_and.reduce((
243 | np.logical_and(sample_points[:, 0] >= xmin, sample_points[:, 0] <= xmax),
244 | np.logical_and(sample_points[:, 1] >= ymin, sample_points[:, 1] <= ymax),
245 | np.logical_and(sample_points[:, 2] >= zmin, sample_points[:, 2] <= zmax),
246 | ))
247 | points_mask = np.logical_and(points_mask, vertices_mask_new)
248 |
249 | sample_points = sample_points[points_mask]
250 | sample_indices = sample_indices[points_mask]
251 |
252 | vertices_eval = trimesh.Trimesh(vertices=sample_points, process=False)
253 | vertices_eval.export(mesh_file_eval_ori.replace('.obj', '.ply'))
254 | mesh_file_eval = mesh_file_eval_ori.replace('.obj', '.ply')
255 |
256 | # eval 3d geometry
257 | metrics_mesh, prec_err_pcd, recal_err_pcd = eval_mesh(mesh_file_eval, file_mesh_trgt, error_map=False)
258 | metrics = {**metrics_mesh}
259 | # o3d.io.write_triangle_mesh(os.path.join('./','%s_precErr.ply' % scene), prec_err_pcd)
260 | # o3d.io.write_triangle_mesh(os.path.join('./', '%s_recErr.ply' % scene), recal_err_pcd)
261 |
262 | dist1.append(metrics['dist1'])
263 | dist2.append(metrics['dist2'])
264 | prec.append(metrics['prec'])
265 | recall.append(metrics['recal'])
266 | fscore.append(metrics['fscore'])
267 |
268 | # prepare files for instance evaluation
269 | mesh_trgt = trimesh.load(file_mesh_trgt, process=False)
270 |
271 | new_pred_ins = np.array(instance_ids)[np.array(sample_indices).astype('int32')].astype('int32')
272 |
273 | # assign colors to vertices_eval from the random color pool, indexed by new_pred_ins
274 | color_pool = np.random.rand(32768, 3) * 255
275 | color_pool = np.concatenate([color_pool, np.ones((32768, 1)) * 255], axis=1).astype(np.uint8)
276 | colors = color_pool[new_pred_ins]
277 | vertices_eval.visual.vertex_colors = colors
278 |
279 | mesh_planeIns_transfer = project_to_mesh(vertices_eval, mesh_trgt, new_pred_ins, 'plane_ins')
280 |
281 | planeIns = mesh_planeIns_transfer.vertex_attributes['plane_ins']
282 |
283 | plnIns_save_pth = os.path.join(save_path, 'plane_ins')
284 | if not os.path.isdir(plnIns_save_pth):
285 | os.makedirs(plnIns_save_pth)
286 |
287 | mesh_planeIns_transfer.export(os.path.join(plnIns_save_pth, '%s_planeIns_transfer.ply' % scene))
288 | np.savetxt(plnIns_save_pth + '/%s.txt'%scene, planeIns, fmt='%d')
289 |
290 | pred_pth = os.path.join(plnIns_save_pth, '{}.txt'.format(scene))
291 | gt_pth = os.path.join(f'./planes_9/instance/{scene}.txt')
292 |
293 | pred_ins = np.loadtxt(pred_pth).astype(np.int32)
294 | gt_ins = np.loadtxt(gt_pth).astype(np.int32)
295 |
296 | ri = rand_score(gt_ins, pred_ins)
297 | h1, h2 = variation_of_information(gt_ins, pred_ins)
298 | voi = h1 + h2
299 | sc = compute_sc(gt_ins, pred_ins)
300 |
301 | ris.append(ri)
302 | vois.append(voi)
303 | scs.append(sc)
304 |
305 | return metrics, ri, voi, sc
306 |
307 |
308 | if __name__ == '__main__':
309 | import glob
310 | now_scenes = sorted(glob.glob('./logs/scene????_??'))
311 | flag = False
312 | stats = []
313 | for scene in tqdm(now_scenes):
314 | scene = scene.split('/')[-1]
315 | metrics, ri, voi, sc = process(scene)
316 | stats.append([metrics['dist1'], metrics['dist2'], metrics['prec'], metrics['recal'], metrics['fscore'], ri, voi, sc])
317 | print('scene', scene, '\t'.join([f'{k}: {v:.4f}' for k, v in metrics.items()]))
318 |
319 | stats = np.array(stats)
320 | print(f'dist1:\t{np.mean(stats[:, 0]):.3f}')
321 | print(f'dist2:\t{np.mean(stats[:, 1]):.3f}')
322 | print(f'prec:\t{np.mean(stats[:, 2]):.3f}')
323 | print(f'recall:\t{np.mean(stats[:, 3]):.3f}')
324 | print(f'fscore:\t{np.mean(stats[:, 4]):.3f}')
325 | print(f'ri:\t{np.mean(stats[:, 5]):.3f}')
326 | print(f'voi:\t{np.mean(stats[:, 6]):.3f}')
327 | print(f'sc:\t{np.mean(stats[:, 7]):.3f}')
328 |
--------------------------------------------------------------------------------
/tools/__init__.py:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/THU-LYJ-Lab/AlphaTablets/735cbfe6aa7f03f7bc37f48303045fe71ffa042a/tools/__init__.py
--------------------------------------------------------------------------------
/tools/chamfer3D/chamfer3D.cu:
--------------------------------------------------------------------------------
1 |
2 | #include <stdio.h>
3 | #include <ATen/ATen.h>
4 | 
5 | #include <cuda.h>
6 | #include <cuda_runtime.h>
7 | 
8 | #include <vector>
9 |
10 |
11 |
12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){
13 | const int batch=512;
14 | __shared__ float buf[batch*3];
15 | for (int i=blockIdx.x;i<b;i+=gridDim.x){
16 | for (int k2=0;k2<m;k2+=batch){
17 | // stage a tile of xyz2 into shared memory
18 | int end_k=min(m,k2+batch)-k2;
19 | for (int j=threadIdx.x;j<end_k*3;j+=blockDim.x){
20 | buf[j]=xyz2[(i*m+k2)*3+j];
21 | }
22 | __syncthreads();
23 | // each thread scans the tile for its query point's nearest neighbor
24 | for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
25 | float x1=xyz[(i*n+j)*3+0];
26 | float y1=xyz[(i*n+j)*3+1];
27 | float z1=xyz[(i*n+j)*3+2];
28 | int best_i=0;
29 | float best=0;
30 | for (int k=0;k<end_k;k++){
31 | float x2=buf[k*3+0]-x1;
32 | float y2=buf[k*3+1]-y1;
33 | float z2=buf[k*3+2]-z1;
34 | float d=x2*x2+y2*y2+z2*z2;
35 | if (k==0||d<best){
36 | best=d;
37 | best_i=k+k2;
38 | }
39 | }
40 | // keep the running minimum across tiles
41 | if (k2==0||result[(i*n+j)]>best){
127 | result[(i*n+j)]=best;
128 | result_i[(i*n+j)]=best_i;
129 | }
130 | }
131 | __syncthreads();
132 | }
133 | }
134 | }
135 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){
136 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){
137 |
138 | const auto batch_size = xyz1.size(0);
139 | const auto n = xyz1.size(1); //num_points point cloud A
140 | const auto m = xyz2.size(1); //num_points point cloud B
141 |
142 | NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, n, xyz1.data<float>(), m, xyz2.data<float>(), dist1.data<float>(), idx1.data<int>());
143 | NmDistanceKernel<<<dim3(32,16,1),512>>>(batch_size, m, xyz2.data<float>(), n, xyz1.data<float>(), dist2.data<float>(), idx2.data<int>());
144 |
145 | cudaError_t err = cudaGetLastError();
146 | if (err != cudaSuccess) {
147 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err));
148 | //THError("aborting");
149 | return 0;
150 | }
151 | return 1;
152 |
153 |
154 | }
155 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){
156 | for (int i=blockIdx.x;i<b;i+=gridDim.x){
157 | for (int j=threadIdx.x+blockIdx.y*blockDim.x;j<n;j+=blockDim.x*gridDim.y){
158 | float x1=xyz1[(i*n+j)*3+0];
159 | float y1=xyz1[(i*n+j)*3+1];
160 | float z1=xyz1[(i*n+j)*3+2];
161 | int j2=idx1[i*n+j];
162 | float x2=xyz2[(i*m+j2)*3+0];
163 | float y2=xyz2[(i*m+j2)*3+1];
164 | float z2=xyz2[(i*m+j2)*3+2];
165 | float g=grad_dist1[i*n+j]*2;
166 | atomicAdd(&(grad_xyz1[(i*n+j)*3+0]),g*(x1-x2));
167 | atomicAdd(&(grad_xyz1[(i*n+j)*3+1]),g*(y1-y2));
168 | atomicAdd(&(grad_xyz1[(i*n+j)*3+2]),g*(z1-z2));
169 | atomicAdd(&(grad_xyz2[(i*m+j2)*3+0]),-(g*(x1-x2)));
170 | atomicAdd(&(grad_xyz2[(i*m+j2)*3+1]),-(g*(y1-y2)));
171 | atomicAdd(&(grad_xyz2[(i*m+j2)*3+2]),-(g*(z1-z2)));
172 | }
173 | }
174 | }
175 | 
176 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2){
177 | const auto batch_size = xyz1.size(0);
178 | const auto n = xyz1.size(1);
179 | const auto m = xyz2.size(1);
180 | 
181 | NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,n,xyz1.data<float>(),m,xyz2.data<float>(),graddist1.data<float>(),idx1.data<int>(),gradxyz1.data<float>(),gradxyz2.data<float>());
182 | NmDistanceGradKernel<<<dim3(1,16,1),256>>>(batch_size,m,xyz2.data<float>(),n,xyz1.data<float>(),graddist2.data<float>(),idx2.data<int>(),gradxyz2.data<float>(),gradxyz1.data<float>());
186 |
187 | cudaError_t err = cudaGetLastError();
188 | if (err != cudaSuccess) {
189 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err));
190 | //THError("aborting");
191 | return 0;
192 | }
193 | return 1;
194 |
195 | }
196 |
197 |
--------------------------------------------------------------------------------
/tools/chamfer3D/chamfer_cuda.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/torch.h>
2 | #include <vector>
3 |
4 | ///TMP
5 | //#include "common.h"
6 | /// NOT TMP
7 |
8 |
9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2);
10 |
11 |
12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2);
13 |
14 |
15 |
16 |
17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) {
18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2);
19 | }
20 |
21 |
22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1,
23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) {
24 |
25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2);
26 | }
27 |
28 |
29 |
30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)");
32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)");
33 | }
--------------------------------------------------------------------------------
/tools/chamfer3D/dist_chamfer_3D.py:
--------------------------------------------------------------------------------
1 | from torch import nn
2 | from torch.autograd import Function
3 | import torch
4 | import importlib.util
5 | import os
6 | chamfer_found = importlib.util.find_spec("chamfer_3D") is not None
7 | if not chamfer_found:
8 | ## Cool trick from https://github.com/chrdiller
9 | print("Jitting Chamfer 3D")
10 | cur_path = os.path.dirname(os.path.abspath(__file__))
11 | build_path = cur_path.replace('chamfer3D', 'tmp')
12 | os.makedirs(build_path, exist_ok=True)
13 |
14 | from torch.utils.cpp_extension import load
15 | chamfer_3D = load(name="chamfer_3D",
16 | sources=[
17 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]),
18 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]),
19 | ], build_directory=build_path)
20 | print("Loaded JIT 3D CUDA chamfer distance")
21 |
22 | else:
23 | import chamfer_3D
24 | print("Loaded compiled 3D CUDA chamfer distance")
25 |
26 |
27 | # Chamfer's distance module @thibaultgroueix
28 | # GPU tensors only
29 | class chamfer_3DFunction(Function):
30 | @staticmethod
31 | def forward(ctx, xyz1, xyz2):
32 | batchsize, n, dim = xyz1.size()
33 | assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
34 | _, m, dim = xyz2.size()
35 | assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()"
36 | device = xyz1.device
37 | 
40 | dist1 = torch.zeros(batchsize, n)
41 | dist2 = torch.zeros(batchsize, m)
42 |
43 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor)
44 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor)
45 |
46 | dist1 = dist1.to(device)
47 | dist2 = dist2.to(device)
48 | idx1 = idx1.to(device)
49 | idx2 = idx2.to(device)
50 | torch.cuda.set_device(device)
51 |
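52 | # fills dist1/dist2 with squared nearest-neighbor distances and idx1/idx2 with the matching point indices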
52 | chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2)
53 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
54 | return dist1, dist2, idx1, idx2
55 |
56 | @staticmethod
57 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2):
58 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
59 | graddist1 = graddist1.contiguous()
60 | graddist2 = graddist2.contiguous()
61 | device = graddist1.device
62 |
63 | gradxyz1 = torch.zeros(xyz1.size())
64 | gradxyz2 = torch.zeros(xyz2.size())
65 |
66 | gradxyz1 = gradxyz1.to(device)
67 | gradxyz2 = gradxyz2.to(device)
68 | chamfer_3D.backward(
69 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2
70 | )
71 | return gradxyz1, gradxyz2
72 |
73 |
74 | class chamfer_3DDist(nn.Module):
75 | def __init__(self):
76 | super(chamfer_3DDist, self).__init__()
77 |
78 | def forward(self, input1, input2):
79 | input1 = input1.contiguous()
80 | input2 = input2.contiguous()
81 | return chamfer_3DFunction.apply(input1, input2)
82 |
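83 | # Minimal usage sketch (assumes CUDA tensors; `pts_a` / `pts_b` are placeholder names):
84 | #   cham = chamfer_3DDist()
85 | #   dist1, dist2, idx1, idx2 = cham(pts_a, pts_b)   # pts_a: (B, N, 3), pts_b: (B, M, 3)
86 | #   loss = dist1.mean() + dist2.mean()              # symmetric squared chamfer distance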
--------------------------------------------------------------------------------
/tools/chamfer3D/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
3 |
4 | setup(
5 | name='chamfer_3D',
6 | ext_modules=[
7 | CUDAExtension('chamfer_3D', [
8 | "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']),
9 | "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']),
10 | ]),
11 | ],
12 | cmdclass={
13 | 'build_ext': BuildExtension
14 | })
--------------------------------------------------------------------------------
/tools/generate_gt.py:
--------------------------------------------------------------------------------
1 | # This file is derived from [NeuralRecon](https://github.com/zju3dv/NeuralRecon).
2 | # Originating Author: Yiming Xie
3 | # Modified for [PlanarRecon](https://github.com/neu-vi/PlanarRecon) by Yiming Xie.
4 |
5 | # Original header:
6 | # Copyright SenseTime. All Rights Reserved.
7 |
8 | # Licensed under the Apache License, Version 2.0 (the "License");
9 | # you may not use this file except in compliance with the License.
10 | # You may obtain a copy of the License at
11 |
12 | # http://www.apache.org/licenses/LICENSE-2.0
13 |
14 | # Unless required by applicable law or agreed to in writing, software
15 | # distributed under the License is distributed on an "AS IS" BASIS,
16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17 | # See the License for the specific language governing permissions and
18 | # limitations under the License.
19 |
20 | import sys
21 | import os
22 |
23 | import numpy as np
24 |
25 | sys.path.append('.')
26 |
27 | import pickle
28 | import argparse
29 | from tqdm import tqdm
30 | import ray
31 | import torch.multiprocessing
32 | from tools.simple_loader import *
33 | from tools.generate_planes import generate_planes
34 | from tools.chamfer3D import dist_chamfer_3D
35 |
36 | torch.multiprocessing.set_sharing_strategy('file_system')
37 |
38 |
39 | def coordinates(voxel_dim, device=torch.device('cuda')):
40 | """ 3d meshgrid of given size.
41 |
42 | Args:
43 | voxel_dim: tuple of 3 ints (nx,ny,nz) specifying the size of the volume
44 |
45 | Returns:
46 | torch long tensor of size (3,nx*ny*nz)
47 | """
48 |
49 | nx, ny, nz = voxel_dim
50 | x = torch.arange(nx, dtype=torch.long, device=device)
51 | y = torch.arange(ny, dtype=torch.long, device=device)
52 | z = torch.arange(nz, dtype=torch.long, device=device)
53 | x, y, z = torch.meshgrid(x, y, z)
54 | return torch.stack((x.flatten(), y.flatten(), z.flatten()))
55 |
56 |
57 | def parse_args():
58 | parser = argparse.ArgumentParser(description='Generate ground truth plane')
59 | parser.add_argument("--dataset", default='scannet')
60 | parser.add_argument("--data_path", metavar="DIR",
61 | help="path to dataset", default='./scannet')
62 | parser.add_argument("--save_name", metavar="DIR",
63 | help="file name", default='planes_9/')
64 | parser.add_argument('--max_depth', default=3., type=float,
65 | help='mask out large depth values since they are noisy')
66 | parser.add_argument('--voxel_size', default=0.04, type=float)
67 |
68 | parser.add_argument('--window_size', default=9, type=int)
69 | parser.add_argument('--min_angle', default=15, type=float)
70 | parser.add_argument('--min_distance', default=0.1, type=float)
71 |
72 | # ray multi processes
73 | parser.add_argument('--n_proc', type=int, default=8, help='#processes launched to process scenes.')
74 | parser.add_argument('--n_gpu', type=int, default=1, help='number of gpus')
75 | parser.add_argument('--num_workers', type=int, default=8)
76 | parser.add_argument('--loader_num_workers', type=int, default=8)
77 | return parser.parse_args()
78 |
79 |
80 | args = parse_args()
81 | args.save_path = args.save_name
82 |
83 |
84 | def rigid_transform(xyz, transform):
85 | """Applies a rigid transform to an (N, 3) pointcloud.
86 | """
87 | xyz_h = np.hstack([xyz, np.ones((len(xyz), 1), dtype=np.float32)])
88 | xyz_t_h = np.dot(transform, xyz_h.T).T
89 | return xyz_t_h[:, :3]
90 |
91 |
92 | def get_view_frustum(depth_im, cam_intr, cam_pose, max_depth=3.0):
93 | """Get corners of 3D camera view frustum of depth image
94 | """
95 | if depth_im is not None:
96 | im_h = depth_im.shape[0]
97 | im_w = depth_im.shape[1]
98 | max_depth = np.max(depth_im)
99 | else:
100 | im_h = 480
101 | im_w = 640
102 | view_frust_pts = np.array([
103 | (np.array([0, 0, 0, im_w, im_w]) - cam_intr[0, 2]) * np.array([0, max_depth, max_depth, max_depth, max_depth]) /
104 | cam_intr[0, 0],
105 | (np.array([0, 0, im_h, 0, im_h]) - cam_intr[1, 2]) * np.array([0, max_depth, max_depth, max_depth, max_depth]) /
106 | cam_intr[1, 1],
107 | np.array([0, max_depth, max_depth, max_depth, max_depth])
108 | ])
109 | view_frust_pts = rigid_transform(view_frust_pts.T, cam_pose).T
110 | return view_frust_pts
111 |
112 |
113 | def compute_global_volume(args, cam_intr, cam_pose_list):
114 | # ======================================================================================================== #
115 | # (Optional) This is an example of how to compute the 3D bounds
116 | # in world coordinates of the convex hull of all camera view
117 | # frustums in the dataset
118 | # ======================================================================================================== #
119 | vol_bnds = np.zeros((3, 2))
120 |
121 | n_imgs = len(cam_pose_list.keys())
122 | if n_imgs > 200:
123 | ind = np.linspace(0, n_imgs - 1, 200).astype(np.int32)
124 | image_id = np.array(list(cam_pose_list.keys()))[ind]
125 | else:
126 | image_id = cam_pose_list.keys()
127 | for id in image_id:
128 | cam_pose = cam_pose_list[id]
129 |
130 | # Compute camera view frustum and extend convex hull
131 | view_frust_pts = get_view_frustum(None, cam_intr, cam_pose, max_depth=args.max_depth)
132 | vol_bnds[:, 0] = np.minimum(vol_bnds[:, 0], np.amin(view_frust_pts, axis=1))
133 | vol_bnds[:, 1] = np.maximum(vol_bnds[:, 1], np.amax(view_frust_pts, axis=1))
134 | # ======================================================================================================== #
135 |
136 | # Adjust volume bounds and ensure C-order contiguous
137 | vol_dim = np.round((vol_bnds[:, 1] - vol_bnds[:, 0]) / args.voxel_size).copy(
138 | order='C').astype(int)
139 | vol_bnds[:, 1] = vol_bnds[:, 0] + vol_dim * args.voxel_size
140 | vol_origin = vol_bnds[:, 0].copy(order='C').astype(np.float32)
141 |
142 | return vol_dim, vol_origin
143 |
144 |
145 | def save_label_full(args, scene, vol_dim, vol_origin, planes, plane_points):
146 | planes = np.concatenate([planes, - np.ones_like(planes[:, :1])], axis=-1)
147 |
148 | # ========================generate indicator gt========================
149 | planes = torch.from_numpy(planes).cuda()
150 | coords = coordinates(vol_dim, device=planes.device)
151 | coords = coords.type(torch.float) * args.voxel_size + torch.from_numpy(vol_origin).view(3, 1).cuda()
152 | coords = coords.permute(1, 0).contiguous()
153 |
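154 | # label each voxel with its nearest plane (chamfer distance to the plane point sets); voxels farther than 0.36 m are set to id -1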
154 | min_dist = None
155 | indices = None
156 | for i, points in enumerate(plane_points[:]):
157 | points = torch.from_numpy(points).cuda()
158 | chamLoss = dist_chamfer_3D.chamfer_3DDist()
159 | dist1, _, _, _ = chamLoss(coords.unsqueeze(0), points.unsqueeze(0))
160 | if min_dist is None:
161 | min_dist = dist1
162 | indices = torch.zeros_like(dist1)
163 | else:
164 | current_id = torch.ones_like(indices) * i
165 | indices = torch.where(dist1 < min_dist, current_id, indices)
166 | min_dist = torch.where(dist1 < min_dist, dist1, min_dist)
167 | # remove too far points which may not have a plane
168 | current_id = torch.ones_like(indices) * -1
169 | indices = torch.where(0.36 ** 2 < min_dist, current_id, indices)
170 | indices = indices.view(vol_dim.tolist()).data.cpu().numpy()
171 |
172 | np.savez_compressed(os.path.join(args.save_path, scene, 'indices'), indices)
173 | # ==============================================================================================
174 |
175 |
176 | def save_fragment_pkl(args, scene, cam_pose_list, vol_dim, vol_origin):
177 | # view selection
178 | fragments = []
179 | print('segment: process scene {}'.format(scene))
180 |
181 | all_ids = []
182 | ids = []
183 | count = 0
184 | last_pose = None
185 | for id in cam_pose_list.keys():
186 | cam_pose = cam_pose_list[id]
187 |
188 | if count == 0:
189 | ids.append(id)
190 | last_pose = cam_pose
191 | count += 1
192 | else:
193 | angle = np.arccos(
194 | ((np.linalg.inv(cam_pose[:3, :3]) @ last_pose[:3, :3] @ np.array([0, 0, 1]).T) * np.array(
195 | [0, 0, 1])).sum())
196 | dis = np.linalg.norm(cam_pose[:3, 3] - last_pose[:3, 3])
197 | if angle > (args.min_angle / 180) * np.pi or dis > args.min_distance:
198 | ids.append(id)
199 | last_pose = cam_pose
200 | count += 1
201 | if count == args.window_size:
202 | all_ids.append(ids)
203 | ids = []
204 | count = 0
205 |
206 | # save fragments
207 | for i, ids in enumerate(all_ids):
208 | fragments.append({
209 | 'scene': scene,
210 | 'fragment_id': i,
211 | 'image_ids': ids,
212 | 'vol_origin': vol_origin,
213 | 'vol_dim': vol_dim,
214 | 'voxel_size': args.voxel_size,
215 | })
216 |
217 | with open(os.path.join(args.save_path, scene, 'fragments.pkl'), 'wb') as f:
218 | pickle.dump(fragments, f)
219 |
220 | return
221 |
222 |
223 | @ray.remote(num_cpus=args.num_workers + 1, num_gpus=(1 / args.n_proc))
224 | def process_with_single_worker(args, scannet_files):
225 | planes_all = []
226 | for scene in tqdm(scannet_files):
227 | if os.path.exists(os.path.join(args.save_path, scene, 'fragments.pkl')):
228 | continue
229 | print('read from disk')
230 |
231 | cam_pose_all = {}
232 |
233 | n_imgs = len(os.listdir(os.path.join(args.data_path, scene, 'color')))
234 | intrinsic_dir = os.path.join(args.data_path, scene, 'intrinsic', 'intrinsic_depth.txt')
235 | cam_intr = np.loadtxt(intrinsic_dir, delimiter=' ')[:3, :3]
236 | dataset = ScanNetDataset(n_imgs, scene, args.data_path, args.max_depth)
237 |
238 | planes, plane_points = generate_planes(args, scene, save_mesh=True)
239 |
240 | dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, collate_fn=collate_fn,
241 | batch_sampler=None, num_workers=args.loader_num_workers)
242 |
243 | for id, (cam_pose, _, _) in enumerate(dataloader):
244 | if id % 100 == 0:
245 | print("{}: read frame {}/{}".format(scene, str(id), str(n_imgs)))
246 |
247 | if not np.isfinite(cam_pose[0][0]):
248 | continue
249 | cam_pose_all.update({id: cam_pose})
250 |
251 | vol_dim, vol_origin = compute_global_volume(args, cam_intr, cam_pose_all)
252 | save_label_full(args, scene, vol_dim, vol_origin, planes, plane_points)
253 | save_fragment_pkl(args, scene, cam_pose_all, vol_dim, vol_origin)
254 |
255 | planes_center = np.array([
256 | [0, -1, 0],
257 | [0, 1, 0],
258 | [-1, 0, 0],
259 | [-1, 0, -1],
260 | [0, 0, -1],
261 | [1, 0, -1],
262 | [1, 0, 1],
263 | ])
264 | planes_center = planes_center / np.linalg.norm(planes_center, axis=1)[..., np.newaxis]
265 |
266 | np.save(os.path.join(args.save_path, 'normal_anchors.npy'), planes_center)
267 |
268 |
269 | def split_list(_list, n):
270 | assert len(_list) >= n
271 | ret = [[] for _ in range(n)]
272 | for idx, item in enumerate(_list):
273 | ret[idx % n].append(item)
274 | return ret
275 |
276 |
277 | def generate_pkl(args):
278 | all_scenes = sorted(os.listdir(args.save_path))
279 | # todo: fix for both train/val
280 | splits = ['train', 'val']
281 | for split in splits:
282 | fragments = []
283 | with open(os.path.join(args.data_path[:-6],'scannetv2_{}.txt'.format(split))) as f:
284 | split_files = f.readlines()
285 | for scene in all_scenes:
286 | if 'scene' not in scene:
287 | continue
288 | if scene + '\n' in split_files:
289 | with open(os.path.join(args.save_path, scene, 'fragments.pkl'), 'rb') as f:
290 | frag_scene = pickle.load(f)
291 | fragments.extend(frag_scene)
292 |
293 | with open(os.path.join(args.save_path, 'fragments_{}.pkl'.format(split)), 'wb') as f:
294 | pickle.dump(fragments, f)
295 |
296 |
297 | if __name__ == "__main__":
298 | all_proc = args.n_proc * args.n_gpu
299 |
300 | ray.init(num_cpus=all_proc * (args.num_workers + 1), num_gpus=args.n_gpu)
301 |
302 | if args.dataset == 'scannet':
303 | args.data_raw_path = os.path.join(args.data_path, 'scans/')
304 | args.data_path = os.path.join(args.data_path, 'scans')
305 | files = sorted(os.listdir(args.data_path))
306 | else:
307 | raise NameError('error!')
308 |
309 | files = split_list(files, all_proc)
310 |
311 | ray_worker_ids = []
312 | for w_idx in range(all_proc):
313 | ray_worker_ids.append(process_with_single_worker.remote(args, files[w_idx]))
314 |
315 | results = ray.get(ray_worker_ids)
316 |
317 |
318 | # process_with_single_worker(args, files)
319 |
320 | if args.dataset == 'scannet':
321 | generate_pkl(args)
322 |
--------------------------------------------------------------------------------
/tools/generate_planes.py:
--------------------------------------------------------------------------------
1 | # This file is derived from [PlaneRCNN](https://github.com/NVlabs/planercnn).
2 | # Originating Author: Chen Liu
3 | # Modified for [PlanarRecon](https://github.com/neu-vi/PlanarRecon) by Yiming Xie.
4 |
5 | # Original header:
6 | # Copyright (C) 2019 NVIDIA Corporation. All rights reserved.
7 | # Licensed under the CC BY-NC-SA 4.0 license
8 | # (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode).
9 |
10 | import numpy as np
11 | import sys
12 | import os
13 | from plyfile import PlyData, PlyElement
14 | import json
15 | import glob
16 | import cv2
17 | import matplotlib.pyplot as plt
18 | import torch
19 | from tools.random_color import random_color
20 |
21 | numPlanes = 200
22 | numPlanesPerSegment = 2
23 | planeAreaThreshold = 100
24 | numIterations = 100
25 | numIterationsPair = 1000
26 | planeDiffThreshold = 0.05
27 | fittingErrorThreshold = planeDiffThreshold
28 | orthogonalThreshold = np.cos(np.deg2rad(60))
29 | parallelThreshold = np.cos(np.deg2rad(30))
30 |
31 |
32 | ## Fit a 3D plane from points
33 | def fitPlane(points):
34 | if points.shape[0] == points.shape[1]:
35 | return np.linalg.solve(points, np.ones(points.shape[0]))
36 | else:
37 | return np.linalg.lstsq(points, np.ones(points.shape[0]), rcond=None)[0]
40 |
41 |
42 | class ColorPalette:
43 | def __init__(self, numColors):
44 | np.random.seed(2)
45 | self.colorMap = np.array([[255, 0, 0],
46 | [0, 255, 0],
47 | [0, 0, 255],
48 | [80, 128, 255],
49 | [255, 230, 180],
50 | [255, 0, 255],
51 | [0, 255, 255],
52 | [100, 0, 0],
53 | [0, 100, 0],
54 | [255, 255, 0],
55 | [50, 150, 0],
56 | [200, 255, 255],
57 | [255, 200, 255],
58 | [128, 128, 80],
59 | [0, 50, 128],
60 | [0, 100, 100],
61 | [0, 255, 128],
62 | [0, 128, 255],
63 | [255, 0, 128],
64 | [128, 0, 255],
65 | [255, 128, 0],
66 | [128, 255, 0],
67 | ])
68 |
69 | if numColors > self.colorMap.shape[0]:
70 | self.colorMap = np.concatenate(
71 | [self.colorMap, np.random.randint(255, size=(numColors - self.colorMap.shape[0], 3))], axis=0)
72 | pass
73 |
74 | return
75 |
76 | def getColorMap(self):
77 | return self.colorMap
78 |
79 | def getColor(self, index):
80 | if index >= self.colorMap.shape[0]:
81 | return np.random.randint(255, size=(3))
82 | else:
83 | return self.colorMap[index]
84 | pass
85 |
86 |
87 | def loadClassMap(args):
88 | classMap = {}
89 | classLabelMap = {}
90 | with open(args.data_raw_path[:-6] + '/scannetv2-labels.combined.tsv') as info_file:
91 | line_index = 0
92 | for line in info_file:
93 | if line_index > 0:
94 | line = line.split('\t')
95 |
96 | key = line[1].strip()
97 | classMap[key] = line[7].strip()
98 | classMap[key + 's'] = line[7].strip()
99 | classMap[key + 'es'] = line[7].strip()
100 | classMap[key[:-1] + 'ves'] = line[7].strip()
101 |
102 | if line[4].strip() != '':
103 | nyuLabel = int(line[4].strip())
104 | else:
105 | nyuLabel = -1
106 | pass
107 | classLabelMap[key] = [nyuLabel, line_index - 1]
108 | classLabelMap[key + 's'] = [nyuLabel, line_index - 1]
109 | classLabelMap[key[:-1] + 'ves'] = [nyuLabel, line_index - 1]
110 | pass
111 | line_index += 1
112 | continue
113 | pass
114 | return classMap, classLabelMap
115 |
116 |
117 | def writePointCloudFace(filename, points, faces):
118 | with open(filename, 'w') as f:
119 | header = """ply
120 | format ascii 1.0
121 | element vertex """
122 | header += str(len(points))
123 | header += """
124 | property float x
125 | property float y
126 | property float z
127 | property uchar red
128 | property uchar green
129 | property uchar blue
130 | element face """
131 | header += str(len(faces))
132 | header += """
133 | property list uchar int vertex_index
134 | end_header
135 | """
136 | f.write(header)
137 | for point in points:
138 | for value in point[:3]:
139 | f.write(str(value) + ' ')
140 | continue
141 | for value in point[3:]:
142 | f.write(str(int(value)) + ' ')
143 | continue
144 | f.write('\n')
145 | continue
146 | for face in faces:
147 | f.write('3 ' + str(face[0]) + ' ' + str(face[1]) + ' ' + str(face[2]) + '\n')
148 | continue
149 | f.close()
150 | pass
151 | return
152 |
153 |
154 | def mergePlanes(points, planes, planePointIndices, planeSegments, segmentNeighbors, numPlanes, debug=False):
155 | planeFittingErrors = []
156 | for plane, pointIndices in zip(planes, planePointIndices):
157 | XYZ = points[pointIndices]
158 | planeNorm = np.linalg.norm(plane)
159 | if planeNorm == 0:
160 | planeFittingErrors.append(fittingErrorThreshold * 2)
161 | continue
162 | diff = np.abs(np.matmul(XYZ, plane) - np.ones(XYZ.shape[0])) / planeNorm
163 | planeFittingErrors.append(diff.mean())
164 | continue
165 |
166 | planeList = list(zip(planes, planePointIndices, planeSegments, planeFittingErrors))
167 | planeList = sorted(planeList, key=lambda x: x[3])
168 |
169 | while len(planeList) > 0:
170 | hasChange = False
171 | planeIndex = 0
172 |
173 | if debug:
174 | for index, planeInfo in enumerate(sorted(planeList, key=lambda x: -len(x[1]))):
175 | print(index, planeInfo[0] / np.linalg.norm(planeInfo[0]), planeInfo[2], planeInfo[3])
176 | continue
177 | pass
178 |
179 | while planeIndex < len(planeList):
180 | plane, pointIndices, segments, fittingError = planeList[planeIndex]
181 | if fittingError > fittingErrorThreshold:
182 | break
183 | neighborSegments = []
184 | for segment in segments:
185 | if segment in segmentNeighbors:
186 | neighborSegments += segmentNeighbors[segment]
187 | pass
188 | continue
189 | neighborSegments += list(segments)
190 | neighborSegments = set(neighborSegments)
191 | bestNeighborPlane = (fittingErrorThreshold, -1, None)
192 | for neighborPlaneIndex, neighborPlane in enumerate(planeList):
193 | if neighborPlaneIndex <= planeIndex:
194 | continue
195 | if not bool(neighborSegments & neighborPlane[2]):
196 | continue
197 | neighborPlaneNorm = np.linalg.norm(neighborPlane[0])
198 | if neighborPlaneNorm < 1e-4:
199 | continue
200 | dotProduct = np.abs(
201 | np.dot(neighborPlane[0], plane) / np.maximum(neighborPlaneNorm * np.linalg.norm(plane), 1e-4))
202 | if dotProduct < orthogonalThreshold:
203 | continue
204 | newPointIndices = np.concatenate([neighborPlane[1], pointIndices], axis=0)
205 | XYZ = points[newPointIndices]
206 | if dotProduct > parallelThreshold and len(neighborPlane[1]) > len(pointIndices) * 0.5:
207 | newPlane = fitPlane(XYZ)
208 | else:
209 | newPlane = plane
210 | pass
211 | diff = np.abs(np.matmul(XYZ, newPlane) - np.ones(XYZ.shape[0])) / np.linalg.norm(newPlane)
212 | newFittingError = diff.mean()
213 | if debug:
214 | print(
215 | len(planeList), planeIndex, neighborPlaneIndex, newFittingError, plane / np.linalg.norm(plane),
216 | neighborPlane[
217 | 0] / np.linalg.norm(
218 | neighborPlane[0]),
219 | dotProduct, orthogonalThreshold)
220 | pass
221 | if newFittingError < bestNeighborPlane[0]:
222 | newPlaneInfo = [newPlane, newPointIndices, segments.union(neighborPlane[2]), newFittingError]
223 | bestNeighborPlane = (newFittingError, neighborPlaneIndex, newPlaneInfo)
224 | pass
225 | continue
226 | if bestNeighborPlane[1] != -1:
227 | newPlaneList = planeList[:planeIndex] + planeList[planeIndex + 1:bestNeighborPlane[1]] + planeList[
228 | bestNeighborPlane[
229 | 1] + 1:]
230 | newFittingError, newPlaneIndex, newPlane = bestNeighborPlane
231 | for newPlaneIndex in range(len(newPlaneList)):
232 | if (newPlaneIndex == 0 and newPlaneList[newPlaneIndex][3] > newFittingError) \
233 | or newPlaneIndex == len(newPlaneList) - 1 \
234 | or (newPlaneList[newPlaneIndex][3] < newFittingError and newPlaneList[newPlaneIndex + 1][
235 | 3] > newFittingError):
236 | newPlaneList.insert(newPlaneIndex, newPlane)
237 | break
238 | continue
239 | if len(newPlaneList) == 0:
240 | newPlaneList = [newPlane]
241 | pass
242 | planeList = newPlaneList
243 | hasChange = True
244 | else:
245 | planeIndex += 1
246 | pass
247 | continue
248 | if not hasChange:
249 | break
250 | continue
251 |
252 | planeList = sorted(planeList, key=lambda x: -len(x[1]))
253 |
254 | minNumPlanes, maxNumPlanes = numPlanes
255 | if minNumPlanes == 1 and len(planeList) == 0:
256 | if debug:
257 | print('at least one plane')
258 | pass
259 | elif len(planeList) > maxNumPlanes:
260 | if debug:
261 | print('too many planes', len(planeList), maxNumPlanes)
262 | pass
263 | planeList = planeList[:maxNumPlanes] + [(np.zeros(3), planeInfo[1], planeInfo[2], fittingErrorThreshold) for
264 | planeInfo in planeList[maxNumPlanes:]]
265 | pass
266 |
267 | groupedPlanes, groupedPlanePointIndices, groupedPlaneSegments, groupedPlaneFittingErrors = zip(*planeList)
268 | return groupedPlanes, groupedPlanePointIndices, groupedPlaneSegments
269 |
270 |
271 | def furthest_point_sampling(points, N=100):
272 | from sklearn.metrics import pairwise_distances
273 | D = pairwise_distances(points, metric='euclidean')
274 | # By default, takes the first point in the list to be the
275 | # first point in the permutation, but could be random
276 | perm = np.zeros(N, dtype=np.int64)
277 | lambdas = np.zeros(N)
278 | ds = D[0, :]
279 | for i in range(1, N):
280 | idx = np.argmax(ds)
281 | perm[i] = idx
282 | lambdas[i] = ds[idx]
283 | ds = np.minimum(ds, D[idx, :])
284 | return (perm, lambdas)
285 |
286 |
287 | def writePointFacePlane(filename, points, faces):
288 | with open(filename, 'w') as f:
289 | header = """ply
290 | format ascii 1.0
291 | element vertex """
292 | header += str(len(points))
293 | header += """
294 | property float x
295 | property float y
296 | property float z
297 | property uchar red
298 | property uchar green
299 | property uchar blue
300 | element face """
301 | header += str(len(faces))
302 | header += """
303 | property list uchar int vertex_index
304 | end_header
305 | """
306 | f.write(header)
307 | for point in points:
308 | for value in point[:3]:
309 | f.write(str(value) + ' ')
310 | continue
311 | for value in point[3:]:
312 | f.write(str(int(value)) + ' ')
313 | continue
314 | f.write('\n')
315 | continue
316 |         for face in faces:
317 |             # each PLY face record must start with its vertex count
318 |             f.write(str(len(face)) + ' ')
319 |             for value in face:
320 |                 f.write(str(value) + ' ')
321 |                 continue
322 |             f.write('\n')
323 |             continue
324 | return
325 |
326 |
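    | # Homogeneous change of basis: S stacks the frame points (A, A+U, A+V, A+N) as
    | # columns and D their canonical images, so M = D @ inv(S) maps A to the origin,
    | # U/V to the in-plane x/y axes, and the normal N to z.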
327 | def project2plane(plane, plane_points):
328 | A, B = plane_points[0], plane_points[1]
329 | AB = B - A
330 | N = plane / np.linalg.norm(plane, ord=2)
331 | U = AB / np.linalg.norm(AB, ord=2)
332 | V = np.cross(U, N)
333 | u = A + U
334 | v = A + V
335 | n = A + N
336 | S = [[A[0], u[0], v[0], n[0]], [A[1], u[1], v[1], n[1]], [A[2], u[2], v[2], n[2]], [1, 1, 1, 1]]
337 | D = [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1]]
338 | M = np.matmul(D, np.linalg.inv(S))
339 | return M
340 |
341 |
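    | # Rasterize 2D plane coordinates onto a grid with cell size `size`, extract
    | # region outlines with cv2.findContours, and map them back to plane units.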
342 | def points2contours(points, size=0.1):
343 | points = points / size
344 |     points = points.astype(np.int64)  # np.int was removed in recent NumPy versions
345 | min_point = points.min(axis=0) - 5
346 | max_point = points.max(axis=0) + 5
347 | image_size = max_point - min_point + 1
348 | points = points - min_point
349 | image = np.zeros(image_size).astype(np.uint8)
350 | image[points[:, 0], points[:, 1]] = 255
351 | # plt.imshow(image, cmap='gray')
352 | # plt.show()
353 | contours, hierarchy = cv2.findContours(image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
354 | contours_valid = []
355 | for i, con in enumerate(contours):
356 | if con.shape[0] > 10:
357 | con = (con + min_point) * size
358 | contours_valid.append(con)
359 | return contours_valid
360 |
361 |
362 | def generate_planes(args, scene_id, high_res=False, save_mesh=False, debug=False):
363 | if not os.path.exists(args.save_path + '/' + scene_id + '/annotation'):
364 | os.system('mkdir -p ' + args.save_path + '/' + scene_id + '/annotation')
365 |
366 | filename = args.data_raw_path + scene_id + '/' + scene_id + '.aggregation.json'
367 | data = json.load(open(filename, 'r'))
368 | aggregation = np.array(data['segGroups'])
369 |
370 | if high_res:
371 | filename = args.data_raw_path + scene_id + '/' + scene_id + '_vh_clean.labels.ply'
372 | else:
373 | filename = args.data_raw_path + scene_id + '/' + scene_id + '_vh_clean_2.labels.ply'
374 |
375 | plydata = PlyData.read(filename)
376 | vertices = plydata['vertex']
377 | points = np.stack([vertices['x'], vertices['y'], vertices['z']], axis=1)
378 | faces = np.array(plydata['face']['vertex_indices'])
379 |
380 | # semanticSegmentation = vertices['label']
381 |
382 | if high_res:
383 | filename = args.data_raw_path + scene_id + '/' + scene_id + '_vh_clean.segs.json'
384 | else:
385 | filename = args.data_raw_path + scene_id + '/' + scene_id + '_vh_clean_2.0.010000.segs.json'
386 |
387 | data = json.load(open(filename, 'r'))
388 | segmentation = np.array(data['segIndices'])
389 |
390 | groupSegments = []
391 | groupLabels = []
392 | for segmentIndex in range(len(aggregation)):
393 | groupSegments.append(aggregation[segmentIndex]['segments'])
394 | groupLabels.append(aggregation[segmentIndex]['label'])
395 |
396 | segmentation = segmentation.astype(np.int32)
397 |
398 | uniqueSegments = np.unique(segmentation).tolist()
399 | numSegments = 0
400 | for segments in groupSegments:
401 | for segmentIndex in segments:
402 | if segmentIndex in uniqueSegments:
403 | uniqueSegments.remove(segmentIndex)
404 | numSegments += len(segments)
405 |
406 | for segment in uniqueSegments:
407 | groupSegments.append([segment, ])
408 | groupLabels.append('unannotated')
409 |
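    |     # Build segment adjacency from the mesh: two segments are neighbors whenever
    |     # a triangle spans both.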
410 | segmentEdges = []
411 | for faceIndex in range(faces.shape[0]):
412 | face = faces[faceIndex]
413 | segment_1 = segmentation[face[0]]
414 | segment_2 = segmentation[face[1]]
415 | segment_3 = segmentation[face[2]]
416 | if segment_1 != segment_2 or segment_1 != segment_3:
417 | if segment_1 != segment_2 and segment_1 != -1 and segment_2 != -1:
418 | segmentEdges.append((min(segment_1, segment_2), max(segment_1, segment_2)))
419 | if segment_1 != segment_3 and segment_1 != -1 and segment_3 != -1:
420 | segmentEdges.append((min(segment_1, segment_3), max(segment_1, segment_3)))
421 | if segment_2 != segment_3 and segment_2 != -1 and segment_3 != -1:
422 | segmentEdges.append((min(segment_2, segment_3), max(segment_2, segment_3)))
423 | segmentEdges = list(set(segmentEdges))
424 |
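    |     # Per-category (min, max) number of planes to fit for one instance.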
425 | labelNumPlanes = {'wall': [1, 3],
426 | 'floor': [1, 1],
427 | 'cabinet': [0, 5],
428 | 'bed': [0, 5],
429 | 'chair': [0, 5],
430 | 'sofa': [0, 10],
431 | 'table': [0, 5],
432 | 'door': [1, 2],
433 | 'window': [0, 2],
434 | 'bookshelf': [0, 5],
435 | 'picture': [1, 1],
436 | 'counter': [0, 10],
437 | 'blinds': [0, 0],
438 | 'desk': [0, 10],
439 | 'shelf': [0, 5],
440 | 'shelves': [0, 5],
441 | 'curtain': [0, 0],
442 | 'dresser': [0, 5],
443 | 'pillow': [0, 0],
444 | 'mirror': [0, 0],
445 | 'entrance': [1, 1],
446 | 'floor mat': [1, 1],
447 | 'clothes': [0, 0],
448 | 'ceiling': [0, 5],
449 | 'book': [0, 1],
450 | 'books': [0, 1],
451 | 'refridgerator': [0, 5],
452 | 'television': [1, 1],
453 | 'paper': [0, 1],
454 | 'towel': [0, 1],
455 | 'shower curtain': [0, 1],
456 | 'box': [0, 5],
457 | 'whiteboard': [1, 5],
458 | 'person': [0, 0],
459 | 'night stand': [1, 5],
460 | 'toilet': [0, 5],
461 | 'sink': [0, 5],
462 | 'lamp': [0, 1],
463 | 'bathtub': [0, 5],
464 | 'bag': [0, 1],
465 | 'otherprop': [0, 5],
466 | 'otherstructure': [0, 5],
467 | 'otherfurniture': [0, 5],
468 | 'unannotated': [0, 5],
469 | '': [0, 0],
470 | }
471 | nonPlanarGroupLabels = ['bicycle', 'bottle', 'water bottle']
472 | nonPlanarGroupLabels = {label: True for label in nonPlanarGroupLabels}
473 |
474 | # verticalLabels = ['wall', 'door', 'cabinet']
475 | classMap, classLabelMap = loadClassMap(args)
476 | classMap['unannotated'] = 'unannotated'
477 | classLabelMap['unannotated'] = [max([index for index, label in classLabelMap.values()]) + 1, 41]
478 | allXYZ = points.reshape(-1, 3)
479 |
480 | segmentNeighbors = {}
481 | for segmentEdge in segmentEdges:
482 | if segmentEdge[0] not in segmentNeighbors:
483 | segmentNeighbors[segmentEdge[0]] = []
484 | segmentNeighbors[segmentEdge[0]].append(segmentEdge[1])
485 |
486 | if segmentEdge[1] not in segmentNeighbors:
487 | segmentNeighbors[segmentEdge[1]] = []
488 | segmentNeighbors[segmentEdge[1]].append(segmentEdge[0])
489 |
490 | planeGroups = []
491 | print('num groups', len(groupSegments))
492 |
493 | debugIndex = -1
494 |
495 | for groupIndex, group in enumerate(groupSegments):
496 | if debugIndex != -1 and groupIndex != debugIndex:
497 | continue
498 | if groupLabels[groupIndex] in nonPlanarGroupLabels:
499 | groupLabel = groupLabels[groupIndex]
500 | minNumPlanes, maxNumPlanes = 0, 0
501 | elif groupLabels[groupIndex] in classMap:
502 | groupLabel = classMap[groupLabels[groupIndex]]
503 | minNumPlanes, maxNumPlanes = labelNumPlanes[groupLabel]
504 | else:
505 | minNumPlanes, maxNumPlanes = 0, 0
506 | groupLabel = ''
507 |
508 | if maxNumPlanes == 0:
509 | pointMasks = []
510 | for segmentIndex in group:
511 | pointMasks.append(segmentation == segmentIndex)
512 | pointIndices = np.any(np.stack(pointMasks, 0), 0).nonzero()[0]
513 | groupPlanes = [[np.zeros(3), pointIndices, []]]
514 | planeGroups.append(groupPlanes)
515 | continue
516 | groupPlanes = []
517 | groupPlanePointIndices = []
518 | groupPlaneSegments = []
519 | for segmentIndex in group:
520 | segmentMask = segmentation == segmentIndex
521 | allSegmentIndices = segmentMask.nonzero()[0]
522 | segmentIndices = allSegmentIndices.copy()
523 |
524 | XYZ = allXYZ[segmentMask.reshape(-1)]
525 | numPoints = XYZ.shape[0]
526 |
527 | for c in range(2):
528 | if c == 0:
529 | ## First try to fit one plane
530 | plane = fitPlane(XYZ)
531 | diff = np.abs(np.matmul(XYZ, plane) - np.ones(XYZ.shape[0])) / np.linalg.norm(plane)
532 | if diff.mean() < fittingErrorThreshold:
533 | groupPlanes.append(plane)
534 | groupPlanePointIndices.append(segmentIndices)
535 | groupPlaneSegments.append(set([segmentIndex]))
536 | break
537 | else:
538 | ## Run ransac
539 | segmentPlanes = []
540 | segmentPlanePointIndices = []
541 |
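    |                     # RANSAC: repeatedly sample 3 points, fit a candidate plane,
    |                     # keep the one with the most inliers, then peel those inliers
    |                     # off and fit the next plane on the remainder.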
542 | for planeIndex in range(numPlanesPerSegment):
543 | if len(XYZ) < planeAreaThreshold:
544 | continue
545 | bestPlaneInfo = [None, 0, None]
546 | for iteration in range(min(XYZ.shape[0], numIterations)):
547 | sampledPoints = XYZ[np.random.choice(np.arange(XYZ.shape[0]), size=(3), replace=False)]
548 | try:
549 | plane = fitPlane(sampledPoints)
550 | except:
551 | continue
552 | diff = np.abs(np.matmul(XYZ, plane) - np.ones(XYZ.shape[0])) / np.linalg.norm(plane)
553 | inlierMask = diff < planeDiffThreshold
554 | numInliers = inlierMask.sum()
555 | if numInliers > bestPlaneInfo[1]:
556 | bestPlaneInfo = [plane, numInliers, inlierMask]
557 |
558 | if bestPlaneInfo[1] < planeAreaThreshold:
559 | break
560 |
561 | pointIndices = segmentIndices[bestPlaneInfo[2]]
562 | bestPlane = fitPlane(XYZ[bestPlaneInfo[2]])
563 |
564 | segmentPlanes.append(bestPlane)
565 | segmentPlanePointIndices.append(pointIndices)
566 |
567 | outlierMask = np.logical_not(bestPlaneInfo[2])
568 | segmentIndices = segmentIndices[outlierMask]
569 | XYZ = XYZ[outlierMask]
570 |
571 | if sum([len(indices) for indices in segmentPlanePointIndices]) < numPoints * 0.5:
572 | groupPlanes.append(np.zeros(3))
573 | groupPlanePointIndices.append(allSegmentIndices)
574 | groupPlaneSegments.append(set([segmentIndex]))
575 | else:
576 | if len(segmentIndices) > 0:
577 | ## Add remaining non-planar regions
578 | segmentPlanes.append(np.zeros(3))
579 | segmentPlanePointIndices.append(segmentIndices)
580 | groupPlanes += segmentPlanes
581 | groupPlanePointIndices += segmentPlanePointIndices
582 |
583 | for _ in range(len(segmentPlanes)):
584 | groupPlaneSegments.append(set([segmentIndex]))
585 |
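    |         # a zero normal marks a non-planar cluster; count only real planes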
586 | numRealPlanes = len([plane for plane in groupPlanes if np.linalg.norm(plane) > 1e-4])
587 | if minNumPlanes == 1 and numRealPlanes == 0:
588 |             ## Some instances always contain at least one plane (e.g., the floor)
589 | maxArea = (planeAreaThreshold, -1)
590 | for index, indices in enumerate(groupPlanePointIndices):
591 | if len(indices) > maxArea[0]:
592 | maxArea = (len(indices), index)
593 | maxArea, planeIndex = maxArea
594 | if planeIndex >= 0:
595 | groupPlanes[planeIndex] = fitPlane(allXYZ[groupPlanePointIndices[planeIndex]])
596 | numRealPlanes = 1
597 | if minNumPlanes == 1 and maxNumPlanes == 1 and numRealPlanes > 1:
598 |             ## Some instances contain at most one plane (e.g., the floor)
599 |
600 |             # pool the points of every candidate plane and refit a single plane
601 |             pointIndices = np.concatenate(
602 |                 [indices for _, indices in zip(groupPlanes, groupPlanePointIndices)], axis=0)
603 | XYZ = allXYZ[pointIndices]
604 | plane = fitPlane(XYZ)
605 | diff = np.abs(np.matmul(XYZ, plane) - np.ones(XYZ.shape[0])) / np.linalg.norm(plane)
606 |
607 | if groupLabel == 'floor':
608 | ## Relax the constraint for the floor due to the misalignment issue in ScanNet
609 | fittingErrorScale = 3
610 | else:
611 | fittingErrorScale = 1
612 |
613 | if diff.mean() < fittingErrorThreshold * fittingErrorScale:
614 | groupPlanes = [plane]
615 | groupPlanePointIndices = [pointIndices]
616 | planeSegments = []
617 | for segments in groupPlaneSegments:
618 | planeSegments += list(segments)
619 | groupPlaneSegments = [set(planeSegments)]
620 | numRealPlanes = 1
621 |
622 | if numRealPlanes > 1:
623 |             groupPlanes, groupPlanePointIndices, groupPlaneSegments = mergePlanes(
624 |                 points, groupPlanes,
625 |                 groupPlanePointIndices,
626 |                 groupPlaneSegments, segmentNeighbors,
627 |                 numPlanes=(minNumPlanes, maxNumPlanes),
628 |                 debug=debugIndex != -1)
629 |
630 | groupNeighbors = []
631 | for planeIndex, planeSegments in enumerate(groupPlaneSegments):
632 | neighborSegments = []
633 | for segment in planeSegments:
634 | if segment in segmentNeighbors:
635 | neighborSegments += segmentNeighbors[segment]
636 | neighborSegments += list(planeSegments)
637 | neighborSegments = set(neighborSegments)
638 | neighborPlaneIndices = []
639 | for neighborPlaneIndex, neighborPlaneSegments in enumerate(groupPlaneSegments):
640 | if neighborPlaneIndex == planeIndex:
641 | continue
642 | if bool(neighborSegments & neighborPlaneSegments):
643 | plane = groupPlanes[planeIndex]
644 | neighborPlane = groupPlanes[neighborPlaneIndex]
645 | if np.linalg.norm(plane) * np.linalg.norm(neighborPlane) < 1e-4:
646 | continue
647 |                     dotProduct = np.abs(
648 |                         np.dot(plane, neighborPlane)
649 |                         / np.maximum(np.linalg.norm(plane) * np.linalg.norm(neighborPlane), 1e-4))
650 | if dotProduct < orthogonalThreshold:
651 | neighborPlaneIndices.append(neighborPlaneIndex)
652 | groupNeighbors.append(neighborPlaneIndices)
653 | groupPlanes = list(zip(groupPlanes, groupPlanePointIndices, groupNeighbors))
654 | # groupPlanes = zip(groupPlanes, groupPlanePointIndices, groupNeighbors)
655 | planeGroups.append(groupPlanes)
656 |
657 | if debug:
658 | colorMap = ColorPalette(segmentation.max() + 2).getColorMap()
659 | colorMap[-1] = 0
660 | colorMap[-2] = 255
661 | annotationFolder = 'test/'
662 | else:
663 | numPlanes = sum([len(group) for group in planeGroups])
664 | segmentationColor = (np.arange(numPlanes + 1) + 1) * 100
665 |         colorMap = np.stack([segmentationColor // (256 * 256), segmentationColor // 256 % 256,
666 |                              segmentationColor % 256], axis=1)  # integer division keeps channels in [0, 255]
667 | colorMap[-1] = 0
668 |         annotationFolder = args.save_path + '/' + scene_id + '/annotation'
669 |
670 | if debug:
671 | colors = colorMap[segmentation]
672 | writePointCloudFace(annotationFolder + '/segments.ply', np.concatenate([points, colors], axis=-1), faces)
673 |
674 | groupedSegmentation = np.full(segmentation.shape, fill_value=-1)
675 | for segmentIndex in range(len(aggregation)):
676 | indices = aggregation[segmentIndex]['segments']
677 | for index in indices:
678 | groupedSegmentation[segmentation == index] = segmentIndex
679 | groupedSegmentation = groupedSegmentation.astype(np.int32)
680 | colors = colorMap[groupedSegmentation]
681 | writePointCloudFace(annotationFolder + '/groups.ply', np.concatenate([points, colors], axis=-1), faces)
682 |
683 | planes = []
684 | planePointIndices = []
685 | planeInfo = []
686 | structureIndex = 0
687 | for index, group in enumerate(planeGroups):
688 | groupPlanes, groupPlanePointIndices, groupNeighbors = zip(*group)
689 |
690 |         # adjacency matrix seeded with self-connections on the diagonal
691 |         adjacencyMatrix = np.eye(len(groupNeighbors))
692 | for groupIndex, neighbors in enumerate(groupNeighbors):
693 | for neighbor in neighbors:
694 | adjacencyMatrix[groupIndex][neighbor] = 1
695 | if groupLabels[index] in classLabelMap:
696 | label = classLabelMap[groupLabels[index]]
697 | else:
698 | print('label not valid', groupLabels[index])
699 | exit(1)
700 |
701 | groupInfo = [[(index, label[0], label[1])] for _ in range(len(groupPlanes))]
702 | groupPlaneIndices = (adjacencyMatrix.sum(-1) >= 2).nonzero()[0]
703 | usedMask = {}
704 | for groupPlaneIndex in groupPlaneIndices:
705 | if groupPlaneIndex in usedMask:
706 | continue
707 | groupStructure = adjacencyMatrix[groupPlaneIndex].copy()
708 | for neighbor in groupStructure.nonzero()[0]:
709 | if np.any(adjacencyMatrix[neighbor] < groupStructure):
710 | groupStructure[neighbor] = 0
711 | groupStructure = groupStructure.nonzero()[0]
712 |
713 | if len(groupStructure) < 2:
714 | print('invalid structure')
715 | print(groupPlaneIndex, groupPlaneIndices)
716 | print(groupNeighbors)
717 | print(groupPlaneIndex)
718 | print(adjacencyMatrix.sum(-1) >= 2)
719 | print((adjacencyMatrix.sum(-1) >= 2).nonzero()[0])
720 | print(adjacencyMatrix[groupPlaneIndex])
721 | print(adjacencyMatrix)
722 | print(groupStructure)
723 | exit(1)
724 | if len(groupStructure) >= 4:
725 | print('complex structure')
726 | print('group index', index)
727 | print(adjacencyMatrix)
728 | print(groupStructure)
729 | groupStructure = groupStructure[:3]
730 | if len(groupStructure) in [2, 3]:
731 | for planeIndex in groupStructure:
732 | groupInfo[planeIndex].append((structureIndex, len(groupStructure)))
733 | structureIndex += 1
734 | for planeIndex in groupStructure:
735 | usedMask[planeIndex] = True
736 | planes += groupPlanes
737 | planePointIndices += groupPlanePointIndices
738 | planeInfo += groupInfo
739 |
740 | planeSegmentation = np.full(segmentation.shape, fill_value=-1, dtype=np.int32)
741 | for planeIndex, planePoints in enumerate(planePointIndices):
742 | planeSegmentation[planePoints] = planeIndex
743 |
744 | # generate planar
745 | if save_mesh:
746 | import copy
747 | color_vis = random_color()
748 | points_plane = []
749 | points_tensor = torch.Tensor(points).cuda()
750 | faces_copy = copy.deepcopy(faces)
751 | faces_copy = torch.Tensor(np.stack(faces_copy)).cuda().int()
752 | planes_tensor = torch.zeros_like(points_tensor)
753 | indices_tensor = torch.zeros_like(points_tensor[:, 0]).int()
754 | for i in range(len(planes)):
755 | planes_tensor[planePointIndices[i]] = torch.Tensor(planes[i]).cuda()
756 | indices_tensor[planePointIndices[i]] = i
757 |
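    |         # Snap each vertex of a planar instance onto its fitted plane n.x = 1 by
    |         # moving along the normal: p' = p - t * n with t = (n.p - 1) / |n|^2.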
758 | valid = (planes_tensor != 0).any(-1)
759 | invalid_ind = torch.nonzero(valid == 0, as_tuple=False).squeeze(1)
760 | planes_tensor_valid = planes_tensor[valid]
761 | points_tensor_valid = points_tensor[valid]
762 | t = ((points_tensor_valid.unsqueeze(1) @ planes_tensor_valid.unsqueeze(-1)).squeeze() - 1) / (planes_tensor_valid[:, 0] ** 2 + planes_tensor_valid[:, 1] ** 2 + planes_tensor_valid[:, 2] ** 2)
763 | plane_points = points_tensor_valid - planes_tensor_valid[:, :3] * t.unsqueeze(-1)
764 | points_tensor[valid] = plane_points
765 | points_tensor[invalid_ind] = plane_points[0]
766 |
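    |         # Drop faces that reference any invalid (non-planar) vertex. The boolean
    |         # comparison tensor is invalid_vertices x faces, so filter the faces in
    |         # chunks to bound GPU memory.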
767 | n = 100
768 | part_num = faces_copy.shape[0] // n
769 | match_list = []
770 | for i in range(n):
771 | if i == n-1:
772 | faces_part = faces_copy[i * part_num:]
773 | else:
774 | faces_part = faces_copy[i * part_num: (i + 1) * part_num]
775 | match = (faces_part.unsqueeze(0) != invalid_ind.unsqueeze(-1).unsqueeze(-1))
776 | match = match.all(-1)
777 | match = match.all(0)
778 | match_list.append(match)
779 | match = torch.cat(match_list)
780 | faces_copy = faces_copy[match]
781 | points_plane = points_tensor.data.cpu().numpy()
782 |
783 | n_ins = indices_tensor.data.cpu().numpy().max() + 1
784 | indices_tensor = indices_tensor.data.cpu().numpy()
785 | segmentationColor = (np.arange(n_ins + 1) + 1) * 100
786 |         colorMap = np.stack([segmentationColor // (256 * 256), segmentationColor // 256 % 256,
787 |                              segmentationColor % 256], axis=1)
788 | colorMap[-1] = 0
789 | plane_colors = colorMap[indices_tensor]
790 |
791 | # for vis
792 | colorMap_vis = color_vis(n_ins)
793 | plane_colors_vis = colorMap_vis[indices_tensor]
794 |
795 | writePointCloudFace(annotationFolder + '/planes_mesh.ply',
796 | np.concatenate([points_plane, plane_colors], axis=-1), faces_copy.data.cpu().numpy())
797 |
798 | writePointCloudFace(annotationFolder + '/planes_mesh_vis.ply',
799 | np.concatenate([points_plane, plane_colors_vis], axis=-1), faces_copy.data.cpu().numpy())
800 |
801 | if debug:
802 | groupSegmentation = np.full(segmentation.shape, fill_value=-1, dtype=np.int32)
803 | structureSegmentation = np.full(segmentation.shape, fill_value=-1, dtype=np.int32)
804 | typeSegmentation = np.full(segmentation.shape, fill_value=-1, dtype=np.int32)
805 | for planeIndex, planePoints in enumerate(planePointIndices):
806 | if len(planeInfo[planeIndex]) > 1:
807 | structureSegmentation[planePoints] = planeInfo[planeIndex][1][0]
808 | typeSegmentation[planePoints] = np.maximum(typeSegmentation[planePoints],
809 | planeInfo[planeIndex][1][1] - 2)
810 | groupSegmentation[planePoints] = planeInfo[planeIndex][0][0]
811 |
812 | colors = colorMap[groupSegmentation]
813 | writePointCloudFace(annotationFolder + '/group.ply', np.concatenate([points, colors], axis=-1), faces)
814 |
815 | colors = colorMap[structureSegmentation]
816 | writePointCloudFace(annotationFolder + '/structure.ply', np.concatenate([points, colors], axis=-1), faces)
817 |
818 | colors = colorMap[typeSegmentation]
819 | writePointCloudFace(annotationFolder + '/type.ply', np.concatenate([points, colors], axis=-1), faces)
820 |
821 | planes = np.array(planes)
822 | print('number of planes: ', planes.shape[0])
823 | # planesD = 1.0 / np.maximum(np.linalg.norm(planes, axis=-1, keepdims=True), 1e-4)
824 | # planes *= pow(planesD, 2)
825 |
826 | if debug:
827 | print(len(planes), len(planeInfo))
828 | exit(1)
829 |
830 | plane_points = []
831 | for i in range(len(planePointIndices)):
832 | plane_points.append(points[planePointIndices[i]])
833 |
834 | planes_valid = []
835 | planeInfo_valid = []
836 | plane_points_valid = []
837 | for i in range(len(plane_points)):
838 | if (planes[i] != 0).any():
839 | planes_valid.append(planes[i])
840 | planeInfo_valid.append(planeInfo[i])
841 | plane_points_valid.append(plane_points[i])
842 |
843 | planes_valid = np.stack(planes_valid)
844 | np.save(annotationFolder + '/planes.npy', planes_valid)
845 |     np.save(annotationFolder + '/plane_info.npy', np.array(planeInfo_valid, dtype=object))
846 |     np.save(annotationFolder + '/plane_points.npy', np.array(plane_points_valid, dtype=object))
847 |
848 | return planes_valid, plane_points_valid
849 |
850 |
851 | if __name__ == '__main__':
852 |
853 |     scene_ids = os.listdir(args.data_raw_path)
854 | scene_ids = sorted(scene_ids)
855 |
856 | for index, scene_id in enumerate(scene_ids):
857 | if scene_id[:5] != 'scene':
858 | continue
859 |
860 | if not os.path.exists(args.save_path + '/' + scene_id + '/annotation'):
861 | os.system('mkdir -p ' + args.save_path + '/' + scene_id + '/annotation')
862 |
863 |         if not os.path.exists(args.save_path + '/' + scene_id + '/annotation/planes.npy'):
864 |             print('plane fitting', scene_id)
865 |             generate_planes(args, scene_id)
866 |
--------------------------------------------------------------------------------
/tools/random_color.py:
--------------------------------------------------------------------------------
1 | import random
2 | import numpy as np
3 |
4 | # ============ for viz =====================
5 | class random_color(object):
6 | def __init__(self):
7 |         num_of_colors = 3000
8 | self.colors = ["#"+''.join([random.choice('0123456789ABCDEF') for i in range(6)])
9 | for j in range(num_of_colors)]
10 |
11 |     def __call__(self, ret_n=10):
12 | assert len(self.colors) > ret_n
13 | ret_color = np.zeros([ret_n, 3])
14 | for i in range(ret_n):
15 | hex_color = self.colors[i][1:]
16 | ret_color[i] = np.array([int(hex_color[j:j + 2], 16) for j in (0, 2, 4)])
17 | return ret_color
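   |
   | # Example usage (illustrative only):
   | #   palette = random_color()
   | #   colors = palette(ret_n=20)   # (20, 3) array of RGB values in [0, 255]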
18 |
19 |
--------------------------------------------------------------------------------
/tools/simple_loader.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | import os
4 | import cv2
5 |
6 |
7 | def collate_fn(list_data):
8 | cam_pose, depth_im, _ = list_data
9 | # Concatenate all lists
10 | return cam_pose, depth_im, _
11 |
12 |
13 | class ScanNetDataset(torch.utils.data.Dataset):
14 | """Pytorch Dataset for a single scene. getitem loads individual frames"""
15 |
16 | def __init__(self, n_imgs, scene, data_path, max_depth, id_list=None):
17 |         """
18 |         Args: n_imgs (frame count), scene, data_path, max_depth, id_list (optional frame ids)
19 |         """
20 | self.n_imgs = n_imgs
21 | self.scene = scene
22 | self.data_path = data_path
23 | self.max_depth = max_depth
24 | if id_list is None:
25 | self.id_list = [i for i in range(n_imgs)]
26 | else:
27 | self.id_list = id_list
28 |
29 | def __len__(self):
30 | return self.n_imgs
31 |
32 | def __getitem__(self, id):
33 | """
34 | Returns:
35 | dict of meta data and images for a single frame
36 | """
37 | id = self.id_list[id]
38 | cam_pose = np.loadtxt(os.path.join(self.data_path, self.scene, "pose", '{0}.txt'.format(id)), delimiter=' ')
39 |
40 | # Read depth image and camera pose
41 | # depth_im = cv2.imread(os.path.join(self.data_path, self.scene, "depth", 'frame-{0:06d}.depth.pgm'.format(id)), -1).astype(
42 | # np.float32)
43 | # depth_im /= 1000. # depth is saved in 16-bit PNG in millimeters
44 | # depth_im[depth_im > self.max_depth] = 0
45 |
46 | # Read RGB image
47 | # color_image = cv2.cvtColor(cv2.imread(os.path.join(self.data_path, self.scene, "color", str(id) + ".jpg")),
48 | # cv2.COLOR_BGR2RGB)
49 | # color_image = cv2.resize(color_image, (depth_im.shape[1], depth_im.shape[0]), interpolation=cv2.INTER_AREA)
50 |
51 | return cam_pose, None, None
52 |
53 |
54 | class ReplicaDataset(torch.utils.data.Dataset):
55 | """Pytorch Dataset for a single scene. getitem loads individual frames"""
56 |
57 | def __init__(self, n_imgs, scene, data_path, max_depth, id_list=None):
58 |         """
59 |         Args: n_imgs (frame count), scene, data_path, max_depth, id_list (optional frame ids)
60 |         """
61 | self.n_imgs = n_imgs
62 | self.scene = scene
63 | self.data_path = data_path
64 | self.max_depth = max_depth
65 | if id_list is None:
66 | self.id_list = [i for i in range(n_imgs)]
67 | else:
68 | self.id_list = id_list
69 |
70 | self.extr = np.loadtxt(os.path.join(data_path, scene, 'traj.txt')).reshape(-1, 4, 4)
71 |
72 | def __len__(self):
73 | return self.n_imgs
74 |
75 | def __getitem__(self, id):
76 | """
77 | Returns:
78 | dict of meta data and images for a single frame
79 | """
80 | id = self.id_list[id]
81 | # cam_pose = np.loadtxt(os.path.join(self.data_path, self.scene, "pose", '{0}.txt'.format(id)), delimiter=' ')
82 | cam_pose = self.extr[id]
83 |
84 | return cam_pose, None, None
85 |
--------------------------------------------------------------------------------