├── .gitignore ├── .gitmodules ├── README.md ├── configs ├── default.yaml └── replica.yaml ├── dataset └── dataset.py ├── geom ├── __init__.py └── plane_utils.py ├── lib ├── alphatablets.py ├── keyframe.py ├── load_colmap.py ├── load_data.py └── load_replica.py ├── recon ├── __init__.py ├── run_depth.py ├── run_recon.py ├── run_sp.py └── utils.py ├── requirements.txt ├── run.py ├── test.py └── tools ├── __init__.py ├── chamfer3D ├── chamfer3D.cu ├── chamfer_cuda.cpp ├── dist_chamfer_3D.py └── setup.py ├── generate_gt.py ├── generate_planes.py ├── random_color.py └── simple_loader.py /.gitignore: -------------------------------------------------------------------------------- 1 | **/__pycache__ 2 | data 3 | logs* 4 | results 5 | planes_9 -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "recon/third_party/omnidata"] 2 | path = recon/third_party/omnidata 3 | url = https://github.com/hyz317/omnidata 4 | [submodule "recon/third_party/metric3d"] 5 | path = recon/third_party/metric3d 6 | url = https://github.com/hyz317/Metric3D 7 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ![title](https://github.com/user-attachments/assets/ac7675cb-3a2e-4e22-8316-da7c420ba69e) 2 | 3 | # AlphaTablets 4 | 5 |

6 | 📃 [Paper](https://arxiv.org/abs/2411.19950) • 🌐 Project Page 7 |

8 | 9 | > **AlphaTablets: A Generic Plane Representation for 3D Planar Reconstruction from Monocular Videos** 10 | > 11 | > [Yuze He](https://hyzcluster.github.io), [Wang Zhao](https://github.com/thuzhaowang), [Shaohui Liu](http://b1ueber2y.me/), [Yubin Hu](https://github.com/AlbertHuyb), [Yushi Bai](https://bys0318.github.io/), Yu-Hui Wen, Yong-Jin Liu 12 | > 13 | > NeurIPS 2024 14 | 15 | **AlphaTablets** is a novel and generic representation of 3D planes that features continuous 3D surface and precise boundary delineation. By representing 3D planes as rectangles with alpha channels, AlphaTablets combine the advantages of current 2D and 3D plane representations, enabling accurate, consistent and flexible modeling of 3D planes. 16 | 17 | We propose a novel bottom-up pipeline for 3D planar reconstruction from monocular videos. Starting with 2D superpixels and geometric cues from pre-trained models, we initialize 3D planes as AlphaTablets and optimize them via differentiable rendering. An effective merging scheme is introduced to facilitate the growth and refinement of AlphaTablets. Through iterative optimization and merging, we reconstruct complete and accurate 3D planes with solid surfaces and clear boundaries. 18 | 19 | 20 | 21 | 22 | 23 | ## Quick Start 24 | 25 | ### 1. Clone the Repository 26 | 27 | Make sure to clone the repository along with its submodules: 28 | 29 | ```bash 30 | git clone --recursive https://github.com/THU-LYJ-Lab/AlphaTablets 31 | ``` 32 | 33 | ### 2. Install Dependencies 34 | 35 | Set up a Python environment and install the required packages: 36 | 37 | ```bash 38 | conda create -n alphatablets python=3.9 39 | conda activate alphatablets 40 | 41 | # Install PyTorch based on your machine configuration 42 | pip install torch==2.1.2 torchvision==0.16.2 torchaudio==2.1.2 --index-url https://download.pytorch.org/whl/cu118 43 | 44 | # Install other dependencies 45 | # Note: mmcv package also requires CUDA. To avoid potential errors, set the CUDA_HOME environment variable and download a CUDA-compatible version of the library. 46 | # Example: python -m pip install mmcv==2.2.0 -f https://download.openmmlab.com/mmcv/dist/cu118/torch2.1/index.html 47 | pip install -r requirements.txt 48 | ``` 49 | 50 | ### 3. Download Pretrained Weights 51 | 52 | #### Monocular Normal Estimation Weights 53 | 54 | Download **Omnidata** pretrained weights: 55 | 56 | - File: `omnidata_dpt_normal_v2.ckpt` 57 | - Link: [Download Here](https://www.dropbox.com/scl/fo/348s01x0trt0yxb934cwe/h?rlkey=a96g2incso7g53evzamzo0j0y&e=2&dl=0) 58 | 59 | Place the file in the directory: 60 | 61 | ```plaintext 62 | ./recon/third_party/omnidata/omnidata_tools/torch/pretrained_models 63 | ``` 64 | 65 | #### Depth Estimation Weights 66 | 67 | Download **Metric3D** pretrained weights: 68 | 69 | - File: `metric_depth_vit_giant2_800k.pth` 70 | - Link: [Download Here](https://huggingface.co/JUGGHM/Metric3D/blob/main/metric_depth_vit_giant2_800k.pth) 71 | 72 | Place the file in the directory: 73 | 74 | ```plaintext 75 | ./recon/third_party/metric3d/weight 76 | ``` 77 | 78 | ### 4. Running Demos 79 | 80 | #### ScanNet Demo 81 | 82 | 1. Download the `scene0684_01` demo scene from [here](https://drive.google.com/drive/folders/13rYkek_CQuOk_N5erJL08R26B1BkYmwD?usp=sharing) and extract it to `./data/`. 83 | 2. Run the demo with the following command: 84 | 85 | ```bash 86 | python run.py --job scene0684_01 87 | ``` 88 | 89 | #### Replica Demo 90 | 91 | 1. 
Download the `office0` demo scene from [here](https://drive.google.com/drive/folders/13rYkek_CQuOk_N5erJL08R26B1BkYmwD?usp=sharing) and extract it to `./data/`. 92 | 2. Run the demo using the specified configuration: 93 | 94 | ```bash 95 | python run.py --config configs/replica.yaml --job office0 96 | ``` 97 | 98 | ### Tips 99 | 100 | - **Out-of-Memory (OOM):** Reduce `batch_size` if you encounter memory issues. 101 | - **Low Frame Rate Sequences:** Increase `weight_decay`, or set it to `-1` for an automatic decay. The default value is `0.9` (works well for ScanNet and Replica), but it can go up to larger values (no more than `1.0`). 102 | - **Scene Scaling Issues:** If the scene scale differs significantly from real-world dimensions, adjust merging parameters such as `dist_thres` (maximum allowable distance for tablet merging). 103 | 104 | 105 | 106 | ## Evaluation on the ScanNet v2 dataset 107 | 108 | 1. **Download and Extract ScanNet**: 109 | Follow the instructions provided on the [ScanNet website](http://www.scan-net.org/) to download and extract the dataset. 110 | 111 | 2. **Prepare the Data**: 112 | Use the data preparation script to parse the raw ScanNet data into a processed pickle format and generate ground truth planes using code modified from [PlaneRCNN](https://github.com/NVlabs/planercnn/blob/master/data_prep/parse.py) and [PlanarRecon](https://github.com/neu-vi/PlanarRecon/tree/main). 113 | 114 | Run the following command under the PlanarRecon environment: 115 | 116 | ```bash 117 | python tools/generate_gt.py --data_path PATH_TO_SCANNET --save_name planes_9/ --window_size 9 --n_proc 2 --n_gpu 1 118 | python tools/prepare_inst_gt_txt.py --val_list PATH_TO_SCANNET/scannetv2_val.txt --plane_mesh_path ./planes_9 119 | ``` 120 | 121 | 3. **Process Scenes in the Validation Set**: 122 | You can use the following command to process each scene in the validation set. Update `scene????_??` with the specific scene name. Train/val/test split information is available [here](https://github.com/ScanNet/ScanNet/tree/master/Tasks/Benchmark): 123 | 124 | ```bash 125 | python run.py --job scene????_?? --input_dir PATH_TO_SCANNET/scene????_?? 126 | ``` 127 | 128 | 4. **Run the Test Script**: 129 | Finally, execute the test script to evaluate the processed data: 130 | 131 | ```bash 132 | python test.py 133 | ``` 134 | 135 | 136 | 137 | ## Citation 138 | 139 | If you find our work useful, please kindly cite: 140 | 141 | ``` 142 | @article{he2024alphatablets, 143 | title={AlphaTablets: A Generic Plane Representation for 3D Planar Reconstruction from Monocular Videos}, 144 | author={Yuze He and Wang Zhao and Shaohui Liu and Yubin Hu and Yushi Bai and Yu-Hui Wen and Yong-Jin Liu}, 145 | journal={arXiv preprint arXiv:2411.19950}, 146 | year={2024} 147 | } 148 | ``` 149 | 150 | 151 | 152 | ## Acknowledgements 153 | 154 | Some of the test code and installation guide in this repo is borrowed from [NeuralRecon](https://github.com/zju3dv/NeuralRecon), [PlanarRecon](https://github.com/neu-vi/PlanarRecon/tree/main) and [ParticleSfM](https://github.com/bytedance/particle-sfm)! We sincerely thank them all. 
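For step 3 of the ScanNet evaluation above, launching `run.py` scene by scene is tedious. The snippet below is a minimal driver sketch, not part of this repository: it assumes `PATH_TO_SCANNET/scannetv2_val.txt` lists one scene name per line (as used in step 2) and that every scene folder sits directly under `PATH_TO_SCANNET`.

```python
# eval_val_split.py -- hypothetical helper for evaluation step 3:
# run AlphaTablets on every scene in the ScanNet validation split.
import subprocess
from pathlib import Path

scannet_root = Path("PATH_TO_SCANNET")  # adjust to your extracted ScanNet root

# one scene name per line, e.g. "scene0684_01"
lines = (scannet_root / "scannetv2_val.txt").read_text().splitlines()
scenes = [s.strip() for s in lines if s.strip()]

for scene in scenes:
    subprocess.run(
        ["python", "run.py", "--job", scene,
         "--input_dir", str(scannet_root / scene)],
        check=True,  # abort on the first failing scene
    )
```

Once every scene is processed, `python test.py` (step 4) evaluates the results.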
155 | -------------------------------------------------------------------------------- /configs/default.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_type: 'scannet' 3 | input_dir: './data/scene0684_01' 4 | 5 | init: 6 | depth_model_type: 'metric3d' 7 | normal_model_type: 'omnidata' 8 | 9 | crop: 20 -------------------------------------------------------------------------------- /configs/replica.yaml: -------------------------------------------------------------------------------- 1 | data: 2 | dataset_type: 'replica' 3 | input_dir: './data/office0' 4 | 5 | init: 6 | depth_model_type: 'metric3d' 7 | normal_model_type: 'omnidata' 8 | 9 | crop: 20 -------------------------------------------------------------------------------- /dataset/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.utils.data import Dataset 3 | 4 | class CustomDataset(Dataset): 5 | def __init__(self, images, mvp_mtxs, normals, depths): 6 | self.images = images 7 | self.mvp_mtxs = mvp_mtxs 8 | self.normals = normals 9 | self.depths = depths 10 | 11 | def __len__(self): 12 | return len(self.images) 13 | 14 | def __getitem__(self, idx): 15 | img = torch.from_numpy(self.images[idx]) 16 | mvp_mtx = torch.from_numpy(self.mvp_mtxs[idx]).float() 17 | normal = torch.from_numpy(self.normals[idx]) 18 | depth = torch.from_numpy(self.depths[idx]) 19 | 20 | return img, mvp_mtx, normal, depth 21 | -------------------------------------------------------------------------------- /geom/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-LYJ-Lab/AlphaTablets/735cbfe6aa7f03f7bc37f48303045fe71ffa042a/geom/__init__.py -------------------------------------------------------------------------------- /geom/plane_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | 4 | from sklearn.neighbors import KDTree 5 | from tqdm import trange 6 | 7 | 8 | class DisjointSet: 9 | def __init__(self, size, params, mean_colors): 10 | self.parent = [i for i in range(size)] 11 | self.rank = [0] * size 12 | self.accum_normal = [_ for _ in params[:, :3].copy()] 13 | self.accum_center = [_ for _ in params[:, 3:].copy()] 14 | self.accum_color = [_ for _ in mean_colors.copy()] 15 | self.accum_num = [1] * size 16 | 17 | def find(self, x): 18 | if self.parent[x] != x: 19 | self.parent[x] = self.find(self.parent[x]) 20 | return self.parent[x] 21 | 22 | def union(self, x, y, normal_thres=0.9, dis_thres=0.1, color_thres=1.): 23 | root_x = self.find(x) 24 | root_y = self.find(y) 25 | if root_x != root_y and np.abs( 26 | np.sum(self.accum_normal[root_x] * self.accum_normal[root_y]) / 27 | np.linalg.norm(self.accum_normal[root_x]) / 28 | np.linalg.norm(self.accum_normal[root_y]) 29 | ) > normal_thres and np.linalg.norm( 30 | (self.accum_center[root_x] / self.accum_num[root_x] - 31 | self.accum_center[root_y] / self.accum_num[root_y]) 32 | * self.accum_normal[root_x] / np.linalg.norm(self.accum_normal[root_x]) 33 | ) < dis_thres and np.linalg.norm( 34 | self.accum_color[root_x] / self.accum_num[root_x] - 35 | self.accum_color[root_y] / self.accum_num[root_y] 36 | ) < color_thres: 37 | if self.rank[root_x] < self.rank[root_y]: 38 | self.parent[root_x] = root_y 39 | self.accum_normal[root_y] += self.accum_normal[root_x] 40 | self.accum_center[root_y] += self.accum_center[root_x] 41 | 
self.accum_color[root_y] += self.accum_color[root_x] 42 | self.accum_num[root_y] += self.accum_num[root_x] 43 | elif self.rank[root_x] > self.rank[root_y]: 44 | self.parent[root_y] = root_x 45 | self.accum_normal[root_x] += self.accum_normal[root_y] 46 | self.accum_center[root_x] += self.accum_center[root_y] 47 | self.accum_color[root_x] += self.accum_color[root_y] 48 | self.accum_num[root_x] += self.accum_num[root_y] 49 | else: 50 | self.parent[root_x] = root_y 51 | self.rank[root_y] += 1 52 | self.accum_normal[root_y] += self.accum_normal[root_x] 53 | self.accum_center[root_y] += self.accum_center[root_x] 54 | self.accum_color[root_y] += self.accum_color[root_x] 55 | self.accum_num[root_y] += self.accum_num[root_x] 56 | 57 | 58 | def calc_plane(K, pose, depth, dis, mask, x_range_uv, y_range_uv, plane_normal=None): 59 | """Calculate the plane parameters from the camera pose and the pixel plane parameters. 60 | Args: 61 | K: camera intrinsic matrix, [3, 3] 62 | pose: camera pose, world-to-cam, [4, 4] 63 | dis: pixel plane distance to camera, scalar 64 | x_range_uv: pixel x range, [2] 65 | y_range_uv: pixel y range, [2] 66 | plane_normal: plane normal vector, [3], optional 67 | Returns: 68 | plane: plane parameters, [4] 69 | plane_up_vec: plane up vector, [3] 70 | resol: pixel plane resolution, [2] 71 | new_x_range_uv: new pixel x range, [2] 72 | new_y_range_uv: new pixel y range, [2] 73 | """ 74 | 75 | # create a pixel grid in the image plane 76 | u = np.linspace(x_range_uv[0], x_range_uv[1], int(x_range_uv[1]-x_range_uv[0]+1)) 77 | v = np.linspace(y_range_uv[0], y_range_uv[1], int(y_range_uv[1]-y_range_uv[0]+1)) 78 | U, V = np.meshgrid(u, v) 79 | 80 | # back project the pixel grid to 3D space 81 | X = depth * (U + 0.5 - K[0, 2]) / K[0, 0] 82 | Y = -depth * (V + 0.5 - K[1, 2]) / K[1, 1] 83 | Z = -depth * np.ones(U.shape) 84 | 85 | # transform the points from camera coordinate to world coordinate 86 | points = np.stack((X, Y, Z, np.ones(U.shape)), axis=-1) 87 | points_world = np.matmul(points, np.linalg.inv(pose).T) 88 | points_world = points_world[:, :, :3] / points_world[:, :, 3:] 89 | 90 | # use PCA to fit a plane to these points 91 | points_world_flat = points_world[mask] 92 | mean = np.mean(points_world_flat, axis=0) 93 | points_world_zero_centroid = points_world_flat - mean 94 | 95 | if plane_normal is not None: 96 | plane_normal = np.matmul(plane_normal, np.linalg.inv(pose[:3, :3]).T) 97 | plane_normal = plane_normal / np.linalg.norm(plane_normal) 98 | 99 | else: 100 | if len(points_world_zero_centroid) < 10000: 101 | _, _, v = np.linalg.svd(points_world_zero_centroid) 102 | else: 103 | import random 104 | _, _, v = np.linalg.svd(random.sample(list(points_world_zero_centroid), 10000)) 105 | 106 | # plane parameters 107 | plane_normal = v[-1, :] / np.linalg.norm(v[-1, :]) 108 | if np.abs(plane_normal).max() != plane_normal.max(): 109 | plane_normal = -plane_normal 110 | 111 | plane = np.concatenate((plane_normal, mean)) 112 | 113 | resol = np.array([K[0,0]/dis, K[1,1]/dis]) 114 | 115 | # calculate plane_up_vector according to the plane_normal 116 | pose_up_vec = np.array([0, 0, 1]) 117 | if np.abs(np.sum(plane_normal * pose_up_vec)) > 0.8: 118 | pose_up_vec = np.array([0, 1, 0]) 119 | plane_up_vec = pose_up_vec - np.sum(plane_normal * pose_up_vec) * plane_normal 120 | plane_up_vec = plane_up_vec / np.linalg.norm(plane_up_vec) 121 | 122 | U, V = U[mask], V[mask] 123 | 124 | plane_right = np.cross(plane_normal, plane_up_vec) 125 | 126 | new_x_argmin =np.matmul(points_world_flat, 
plane_right).argmin() 127 | new_x_argmax =np.matmul(points_world_flat, plane_right).argmax() 128 | new_y_argmin =np.matmul(points_world_flat, plane_up_vec).argmin() 129 | new_y_argmax =np.matmul(points_world_flat, plane_up_vec).argmax() 130 | 131 | new_x_width = np.sqrt((U[new_x_argmin] - U[new_x_argmax])**2 + (V[new_x_argmin] - V[new_x_argmax])**2) 132 | new_y_width = np.sqrt((U[new_y_argmin] - U[new_y_argmax])**2 + (V[new_y_argmin] - V[new_y_argmax])**2) 133 | 134 | new_x_range_uv = [-new_x_width/2 - 2, new_x_width/2 + 2] 135 | new_y_range_uv = [-new_y_width/2 - 2, new_y_width/2 + 2] 136 | 137 | y_range_3d = np.matmul(points_world_flat, plane_up_vec).max() - np.matmul(points_world_flat, plane_up_vec).min() 138 | x_range_3d = np.matmul(points_world_flat, plane_right).max() - np.matmul(points_world_flat, plane_right).min() 139 | 140 | resol = np.array([ 141 | new_x_width / (x_range_3d + 1e-6), 142 | new_y_width / (y_range_3d + 1e-6) 143 | ]) 144 | 145 | if resol[0] < K[0,0]/dis / 10 or resol[1] < K[1,1]/dis / 10: 146 | return None 147 | 148 | xy_min = np.array([new_x_range_uv[0], new_y_range_uv[0]]) 149 | xy_max = np.array([new_x_range_uv[1], new_y_range_uv[1]]) 150 | 151 | return plane, plane_up_vec, resol, xy_min, xy_max 152 | 153 | 154 | def ray_plane_intersect(plane, ray_origin, ray_direction): 155 | """Calculate the intersection of a ray and a plane. 156 | Args: 157 | plane: plane parameters, [4] 158 | ray_origin: ray origin, [3] 159 | ray_direction: ray direction, [3] 160 | 161 | Returns: 162 | intersection: intersection point, [3] 163 | """ 164 | 165 | # calculate intersection 166 | t = -(plane[3] + np.dot(plane[:3], ray_origin)) / np.dot(plane[:3], ray_direction) 167 | intersection = ray_origin + t * ray_direction 168 | 169 | return intersection 170 | 171 | 172 | def points_xyz_to_plane_uv(points, plane, resol, plane_up): 173 | """Project 3D points to the pixel plane. 174 | Args: 175 | points: 3D points, [N, 3] 176 | plane: plane parameters, [4] 177 | resol: pixel plane resolution, [2] 178 | plane_up: plane up vector, [3] 179 | Returns: 180 | uv: pixel plane coordinates, [N, 2] 181 | """ 182 | 183 | # plane normal vector 184 | plane_normal = np.asarray(plane[:3]) 185 | # projection points of 'points' on the plane 186 | points_proj = points - np.outer( np.sum(points*plane_normal, axis=1)+plane[3] , plane_normal) / np.linalg.norm(plane_normal) 187 | mean_proj = np.mean(points_proj, axis=0) 188 | uvw_right = np.cross(plane_normal, plane_up) 189 | uvw_up = plane_up 190 | # calculate the uv coordinates 191 | uv = np.c_[np.sum((points_proj-mean_proj) * uvw_right, axis=-1), np.sum((points_proj-mean_proj) * uvw_up, axis=-1)] 192 | uv = uv * resol 193 | return uv 194 | 195 | 196 | def points_xyz_to_plane_uv_torch(points, plane, resol, plane_up): 197 | """Project 3D points to the pixel plane. 
198 | Args: 199 | points: 3D points, [N, 3] 200 | plane: plane parameters, [4] 201 | resol: pixel plane resolution, [2] 202 | plane_up: plane up vector, [3] 203 | Returns: 204 | uv: pixel plane coordinates, [N, 2] 205 | """ 206 | 207 | # plane normal vector 208 | plane_normal = plane[:3] 209 | # projection points of 'points' on the plane 210 | points_proj = points - torch.outer( torch.sum(points*plane_normal, dim=1)+plane[3] , plane_normal) / torch.norm(plane_normal) 211 | mean_proj = torch.mean(points_proj, dim=0) 212 | uvw_right = torch.cross(plane_normal, plane_up) 213 | uvw_up = plane_up 214 | # calculate the uv coordinates 215 | uv = torch.cat([torch.sum((points_proj-mean_proj) * uvw_right, dim=-1, keepdim=True), torch.sum((points_proj-mean_proj) * uvw_up, dim=-1, keepdim=True)], dim=-1) 216 | uv = uv * resol 217 | return uv 218 | 219 | 220 | def simple_distribute_planes_2D(uv_ranges, gap=2, min_H=8192): 221 | """Distribute pixel planes in 2D space. 222 | Args: 223 | uv_ranges: pixel ranges of each plane, [N, 2] 224 | Returns: 225 | plane_leftup: left-up pixel of each plane, [N, 2] 226 | """ 227 | 228 | # calculate the left-up pixel of each plane 229 | plane_leftup = torch.zeros_like(uv_ranges) 230 | 231 | H = (int(torch.max(uv_ranges[:, 1]).item()) + 2) // gap * gap + gap 232 | H = max(H, min_H) 233 | 234 | # sort the planes by the height 235 | _, sort_idx = torch.sort(uv_ranges[:, 1], descending=True) 236 | 237 | # distribute the planes 238 | idx = 0 239 | now_H = 0 240 | now_W = 0 241 | prev_W = 0 242 | while idx < len(sort_idx): 243 | now_leftup = torch.tensor((prev_W+1, now_H+1)) 244 | plane_leftup[sort_idx[idx]] = now_leftup 245 | now_W = max(now_W, uv_ranges[sort_idx[idx], 0] + 1 + prev_W) 246 | now_H += uv_ranges[sort_idx[idx], 1] + 1 247 | 248 | if idx + 1 < len(sort_idx) and now_H + uv_ranges[sort_idx[idx+1], 1] + 1 > H: 249 | prev_W = now_W 250 | now_H = 0 251 | 252 | idx += 1 253 | 254 | W = (int(now_W.item()) + 2) // gap * gap + gap 255 | 256 | return plane_leftup, W, H 257 | 258 | 259 | def cluster_planes(planes, K=2, thres=0.999, 260 | color_thres_1=0.3, color_thres_2=0.2, 261 | dis_thres_1=0.5, dis_thres_2=0.1, 262 | merge_edge_planes=False, 263 | init_plane_sets=None): 264 | # planes[:, :3]: plane normal vector 265 | # planes[:, 3:6]: plane center 266 | 267 | # create disjoint set 268 | ds = DisjointSet(len(planes), planes[:, :6], planes[:, 17:20]) 269 | 270 | if init_plane_sets is not None: 271 | for plane_set in init_plane_sets: 272 | for i in range(len(plane_set)-1): 273 | ds.union(plane_set[i], plane_set[i+1], normal_thres=0.99, dis_thres=1, color_thres=114514) 274 | 275 | # construct kd-tree for plane center 276 | print('clustering tablets ...') 277 | tree = KDTree(planes[:, 3:6]) 278 | 279 | # find the nearest K neighbors for each plane 280 | _, ind = tree.query(planes[:, 3:6], k=K+1) 281 | 282 | # calculate the angle between each plane and its neighbors 283 | neighbor_normals = planes[ind[:, 1:], :3] 284 | plane_normals = planes[:, :3] 285 | cos = np.sum(neighbor_normals * plane_normals[:, None, :], axis=-1) 286 | 287 | # merge planes that have cos > thres 288 | for i in trange(len(planes)): 289 | if not merge_edge_planes and planes[i, 15] == True: 290 | continue 291 | 292 | for j in range(K): 293 | if not merge_edge_planes and planes[ind[i, j+1], 15] == True: 294 | continue 295 | if cos[i, j] > thres: 296 | ds.union(i, ind[i, j+1], normal_thres=0.99, dis_thres=dis_thres_1, color_thres=color_thres_1) 297 | 298 | # merge planes that have cos > thres 299 | for i 
in trange(len(planes)): 300 | if not merge_edge_planes and planes[i, 15] == True: 301 | continue 302 | 303 | for j in range(K): 304 | if not merge_edge_planes and planes[ind[i, j+1], 15] == True: 305 | continue 306 | if cos[i, j] > thres: 307 | ds.union(i, ind[i, j+1], normal_thres=0.9, dis_thres=dis_thres_2, color_thres=color_thres_2) 308 | 309 | root2idx = {} 310 | for i in range(len(planes)): 311 | root = ds.find(i) 312 | if root not in root2idx: 313 | root2idx[root] = [] 314 | root2idx[root].append(i) 315 | 316 | plane_sets = [] 317 | for root in root2idx: 318 | plane_sets.append(root2idx[root]) 319 | 320 | return plane_sets 321 | 322 | -------------------------------------------------------------------------------- /lib/keyframe.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import glob 3 | import os 4 | from datetime import datetime 5 | 6 | 7 | def get_keyframes(folder, min_angle=15, min_distance=0.2, window_size=9, 8 | min_mean=0.2, max_mean=10): 9 | txt_list = sorted(glob.glob(f'{folder}/pose/*.txt'), key=lambda x: int(x.split('/')[-1].split('.')[0])) 10 | if len(txt_list) != 0: 11 | last_pose = np.loadtxt(txt_list[0]) 12 | image_skip = np.ceil(len(txt_list) / 2000) 13 | txt_list = txt_list[::int(image_skip)] 14 | pose_num = len(txt_list) 15 | else: 16 | extrs = np.loadtxt(os.path.join(folder, 'traj.txt')).reshape(-1, 4, 4) 17 | last_pose = extrs[0] 18 | image_skip = np.ceil(len(extrs) / 2000) 19 | extrs = extrs[::int(image_skip)] 20 | pose_num = len(extrs) 21 | 22 | count = 1 23 | all_ids = [] 24 | 25 | if len(txt_list) != 0: 26 | depth_list = [ pname.replace('pose', 'aligned_dense_depths').replace('.txt', '.npy') for pname in txt_list ] 27 | else: 28 | depth_list = sorted(glob.glob(f'{folder}/aligned_dense_depths/*.npy')) 29 | for i, j in zip(txt_list, depth_list): 30 | if int(i.split('/')[-1].split('.')[0]) != int(j.split('/')[-1].split('.')[0]): 31 | print(i, j) 32 | raise ValueError('pose and depth not match') 33 | depth_list = depth_list[::int(image_skip)] 34 | 35 | depth_list = np.array([ np.load(i).mean() for i in depth_list ]) 36 | id_list = np.linspace(0, len(depth_list)-1, pose_num).astype(int)[::int(image_skip)] 37 | id_list = id_list[np.logical_and(depth_list > min_mean, depth_list < max_mean)] 38 | depth_list = depth_list[np.logical_and(depth_list > min_mean, depth_list < max_mean)] 39 | 40 | from scipy.signal import medfilt 41 | filtered_depth_list = medfilt(depth_list, kernel_size=29) 42 | 43 | # filtered out the depth_list 44 | id_list = id_list[ 45 | np.logical_and( 46 | depth_list > filtered_depth_list * 0.85, 47 | depth_list < filtered_depth_list * 1.15 48 | ) 49 | ] 50 | 51 | depth_list = depth_list[ 52 | np.logical_and( 53 | depth_list > filtered_depth_list * 0.85, 54 | depth_list < filtered_depth_list * 1.15 55 | ) 56 | ] 57 | 58 | print(str(datetime.now()) + ': \033[92mI', 'filtered out depth_list', len(id_list), '/', pose_num, '\033[0m') 59 | 60 | ids = [id_list[0]] 61 | 62 | for idx_pos, idx in enumerate(id_list[1:]): 63 | if len(txt_list) != 0: 64 | i = txt_list[idx] 65 | with open(i, 'r') as f: 66 | cam_pose = np.loadtxt(i) 67 | else: 68 | cam_pose = extrs[idx] 69 | angle = np.arccos( 70 | ((np.linalg.inv(cam_pose[:3, :3]) @ last_pose[:3, :3] @ np.array([0, 0, 1]).T) * np.array( 71 | [0, 0, 1])).sum()) 72 | dis = np.linalg.norm(cam_pose[:3, 3] - last_pose[:3, 3]) 73 | if angle > (min_angle / 180) * np.pi or dis > min_distance: 74 | ids.append(idx) 75 | last_pose = cam_pose 76 | # Compute 
camera view frustum and extend convex hull 77 | count += 1 78 | if count == window_size: 79 | ids = [i * int(image_skip) for i in ids] 80 | all_ids.append(ids) 81 | ids = [] 82 | count = 0 83 | 84 | if len(ids) > 2: 85 | ids = [i * int(image_skip) for i in ids] 86 | all_ids.append(ids) 87 | else: 88 | ids = [i * int(image_skip) for i in ids] 89 | all_ids[-1].extend(ids) 90 | 91 | return all_ids, int(image_skip) 92 | 93 | 94 | if __name__ == '__main__': 95 | folder = '/data0/bys/Hex/PixelPlane/data/scene0709_01' 96 | keyframes, image_skip = get_keyframes(folder) 97 | print(keyframes, image_skip) 98 | -------------------------------------------------------------------------------- /lib/load_colmap.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import os 3 | import cv2 4 | import glob 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from recon.utils import load_sfm_pose, load_scannet_pose 9 | 10 | from geom.plane_utils import calc_plane 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | 15 | 16 | def dilate(img): 17 | kernel = torch.ones(3, 3).cuda().double() 18 | img = torch.tensor(img.astype('int32')).cuda().double() 19 | 20 | while True: 21 | mask = img > 0 22 | if torch.logical_not(mask).sum() == 0: 23 | break 24 | mask = mask.double() 25 | 26 | mask_dilated = F.conv2d(mask[None, None], kernel[None, None], padding=1).squeeze(0).squeeze(0) 27 | 28 | img_dilated = F.conv2d(img[None, None], kernel[None, None], padding=1).squeeze(0).squeeze(0) 29 | img_dilated = torch.where(mask_dilated == 0, img_dilated, img_dilated / mask_dilated) 30 | 31 | img = torch.where(mask == 0, img_dilated, img) 32 | 33 | img = img.cpu().numpy() 34 | 35 | return img 36 | 37 | 38 | def parse_sp(sp, intr, pose, depth, normal, img_sp, edge_thres=0.95, crop=20): 39 | """Parse the superpixel segmentation into planes 40 | """ 41 | sp = torch.from_numpy(sp).cuda() 42 | depth = torch.from_numpy(depth).cuda() 43 | normal = (torch.from_numpy(normal).cuda() - 0.5) * 2 44 | normal = normal.permute(1, 2, 0) 45 | img_sp = torch.from_numpy(img_sp).cuda() 46 | sp = sp.int() 47 | sp_ids = torch.unique(sp) 48 | planes = [] 49 | mean_colors = [] 50 | 51 | # mask crop to -1 52 | sp[:crop] = -1 53 | sp[-crop:] = -1 54 | sp[:, :crop] = -1 55 | sp[:, -crop:] = -1 56 | 57 | new_sp = torch.ones_like(sp).cuda() * -1 58 | 59 | sp_id_cnt = 0 60 | 61 | for sp_id in tqdm(sp_ids): 62 | sp_mask = sp == sp_id 63 | 64 | if sp_mask.sum() <= 15: 65 | continue 66 | 67 | new_sp[sp_mask] = sp_id_cnt 68 | mean_color = torch.mean(img_sp[sp_mask], dim=0) 69 | 70 | sp_dis = depth[sp_mask] 71 | sp_dis = torch.median(sp_dis) 72 | if sp_dis < 0: 73 | continue 74 | sp_normal = normal[sp_mask] 75 | 76 | sp_coeff = torch.einsum('ij,kj->ik', sp_normal, sp_normal) 77 | if sp_coeff.min() < edge_thres: 78 | is_edge_plane = True 79 | else: 80 | is_edge_plane = False 81 | 82 | sp_normal = torch.mean(sp_normal, dim=0) 83 | sp_normal = sp_normal / torch.norm(sp_normal) 84 | sp_normal[1] = -sp_normal[1] 85 | sp_normal[2] = -sp_normal[2] 86 | 87 | x_accum = torch.sum(sp_mask, dim=0) 88 | y_accum = torch.sum(sp_mask, dim=1) 89 | x_range_uv = [torch.min(torch.nonzero(x_accum)).item(), torch.max(torch.nonzero(x_accum)).item()] 90 | y_range_uv = [torch.min(torch.nonzero(y_accum)).item(), torch.max(torch.nonzero(y_accum)).item()] 91 | 92 | sp_dis = sp_dis.cpu().numpy() 93 | sp_depth = depth[y_range_uv[0]:y_range_uv[1]+1, x_range_uv[0]:x_range_uv[1]+1].cpu().numpy() 94 | sp_mask = 
sp_mask[y_range_uv[0]:y_range_uv[1]+1, x_range_uv[0]:x_range_uv[1]+1].cpu().numpy() 95 | 96 | ret = calc_plane(intr, pose, sp_depth, sp_dis, sp_mask, x_range_uv, y_range_uv, plane_normal=sp_normal.cpu().numpy()) 97 | if ret is None: 98 | continue 99 | plane, plane_up, resol, new_x_range_uv, new_y_range_uv = ret 100 | 101 | plane = np.concatenate([plane, plane_up, resol, new_x_range_uv, new_y_range_uv, [is_edge_plane]]) 102 | planes.append(plane) 103 | mean_colors.append(mean_color.cpu().numpy()) 104 | 105 | sp_id_cnt += 1 106 | 107 | return planes, new_sp.cpu().numpy(), mean_colors 108 | 109 | 110 | def get_projection_matrix(fovy: float, aspect_wh: float, near: float, far: float): 111 | proj_mtx = np.zeros((4, 4), dtype=np.float32) 112 | proj_mtx[0, 0] = 1.0 / (np.tan(fovy / 2.0) * aspect_wh) 113 | proj_mtx[1, 1] = -1.0 / np.tan( 114 | fovy / 2.0 115 | ) # add a negative sign here as the y axis is flipped in nvdiffrast output 116 | proj_mtx[2, 2] = -(far + near) / (far - near) 117 | proj_mtx[2, 3] = -2.0 * far * near / (far - near) 118 | proj_mtx[3, 2] = -1.0 119 | return proj_mtx 120 | 121 | 122 | def get_mvp_matrix(c2w, proj_mtx): 123 | # calculate w2c from c2w: R' = Rt, t' = -Rt * t 124 | # mathematically equivalent to (c2w)^-1 125 | w2c = np.zeros((c2w.shape[0], 4, 4)) 126 | w2c[:, :3, :3] = np.transpose(c2w[:, :3, :3], (0, 2, 1)) 127 | w2c[:, :3, 3:] = np.transpose(-c2w[:, :3, :3], (0, 2, 1)) @ c2w[:, :3, 3:] 128 | w2c[:, 3, 3] = 1.0 129 | 130 | # calculate mvp matrix by proj_mtx @ w2c (mv_mtx) 131 | mvp_mtx = proj_mtx @ w2c 132 | 133 | return mvp_mtx 134 | 135 | 136 | def load_colmap_data(args, kf_list=[], image_skip=1, load_plane=True, scannet_pose=True): 137 | imgs = [] 138 | image_dir = os.path.join(args.input_dir, "images") 139 | start = kf_list[0] 140 | end = kf_list[-1] + 1 if kf_list[-1] != -1 else len(glob.glob(os.path.join(image_dir, "*"))) 141 | image_names = sorted(glob.glob(os.path.join(image_dir, "*")), key=lambda x: int(x.split('/')[-1][:-4])) 142 | 143 | image_names_sp = [image_names[i] for i in kf_list] 144 | image_names = image_names[start:end:image_skip] 145 | 146 | print(str(datetime.now()) + ': \033[92mI', 'loading images ...', '\033[0m') 147 | for name in tqdm(image_names): 148 | img = cv2.imread(name) 149 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 150 | imgs.append(img.astype(np.float32) / 255.) 151 | 152 | imgs_sp = [] 153 | for name in image_names_sp: 154 | img = cv2.imread(name) 155 | # convert to ycbcr using cv2 156 | img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 157 | imgs_sp.append(img.astype(np.float32) / 255.) 
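# NOTE: despite the "convert to ycbcr" comment above, cv2.COLOR_BGR2HSV yields
# HSV images. imgs_sp is only used to compute per-superpixel mean colors in
# parse_sp(); geom/plane_utils.cluster_planes later compares those means
# against its color thresholds when deciding which tablets to merge.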
158 | 159 | imgs_sp = np.stack(imgs_sp, 0) 160 | 161 | if scannet_pose: 162 | print(str(datetime.now()) + ': \033[92mI', 'loading scannet poses ...', '\033[0m') 163 | intr_scannet, sfm_poses_scannet, sfm_camprojs_scannet = load_scannet_pose(args.input_dir) 164 | intr = intr_scannet 165 | sfm_poses, sfm_camprojs = sfm_poses_scannet, sfm_camprojs_scannet 166 | else: 167 | print(str(datetime.now()) + ': \033[92mI', 'loading colmap poses ...', '\033[0m') 168 | intr, sfm_poses, sfm_camprojs = load_sfm_pose(os.path.join(args.input_dir, "sfm")) 169 | 170 | 171 | sfm_poses_sp = np.array(sfm_poses)[kf_list] 172 | sfm_poses = np.array(sfm_poses)[start:end:image_skip] 173 | sfm_camprojs = np.array(sfm_camprojs)[start:end:image_skip] 174 | 175 | normal_dir = os.path.join(args.input_dir, "omnidata_normal") 176 | normal_names = sorted(glob.glob(os.path.join(normal_dir, "*")), key=lambda x: int(x.split('/')[-1][:-4])) 177 | depth_names = [ iname.replace('omnidata_normal', 'aligned_dense_depths') for iname in normal_names ] 178 | depth_names = [depth_names[i] for i in kf_list] 179 | 180 | print(str(datetime.now()) + ': \033[92mI', 'loading depths ...', '\033[0m') 181 | depths = [] 182 | for name in tqdm(depth_names): 183 | depth = np.load(name) 184 | depths.append(depth) 185 | 186 | depths = np.stack(depths, 0) 187 | aligned_depth_dir = os.path.join(args.input_dir, "aligned_dense_depths") 188 | all_depth_names = sorted(glob.glob(os.path.join(aligned_depth_dir, "*")), key=lambda x: int(x.split('/')[-1][:-4]))[start:end:image_skip] 189 | all_depths = [] 190 | 191 | for name, pose in tqdm(zip(all_depth_names, sfm_poses)): 192 | depth = np.load(name) 193 | all_depths.append(depth) 194 | 195 | if args.init.normal_model_type == 'omnidata': 196 | normal_dir = os.path.join(args.input_dir, "omnidata_normal") 197 | else: 198 | raise NotImplementedError(f'Unknown normal model type {args.init.normal_model_type} exiting') 199 | 200 | normal_names = sorted(glob.glob(os.path.join(normal_dir, "*")), key=lambda x: int(x.split('/')[-1][:-4])) 201 | normal_names = [normal_names[i] for i in kf_list] 202 | 203 | print(str(datetime.now()) + ': \033[92mI', 'loading normals ...', '\033[0m') 204 | normals = [] 205 | for name in tqdm(normal_names): 206 | normal = np.load(name)[:3] 207 | normals.append(normal) 208 | 209 | normals = np.stack(normals, 0) 210 | 211 | all_normal_names = sorted(glob.glob(os.path.join(normal_dir, "*")), key=lambda x: int(x.split('/')[-1][:-4]))[start:end:image_skip] 212 | all_normals = [] 213 | 214 | for name, pose in tqdm(zip(all_normal_names, sfm_poses)): 215 | normal = np.load(name) 216 | normal = normal[:3] 217 | 218 | normal = torch.from_numpy(normal).cuda() 219 | normal = (normal - 0.5) * 2 220 | normal = normal / torch.norm(normal, dim=0, keepdim=True) 221 | normal = normal.permute(1, 2, 0) 222 | normal[:, :, 1] = -normal[:, :, 1] 223 | normal[:, :, 2] = -normal[:, :, 2] 224 | normal = torch.einsum('ijk,kl->ijl', normal, torch.from_numpy(np.linalg.inv(pose[:3, :3]).T).cuda().float()) 225 | normal = F.interpolate(normal.permute(2, 0, 1).unsqueeze(0), 226 | (int(imgs[0].shape[0]/64)*8, int(imgs[0].shape[1]/64)*8), 227 | mode='bilinear', align_corners=False).squeeze(0).permute(1, 2, 0) 228 | all_normals.append(normal.cpu().numpy()) 229 | 230 | sp_dir = os.path.join(args.input_dir, "sp") 231 | sp_names = [os.path.join(sp_dir, os.path.split(name)[-1].replace('.jpg', '.npy').replace('.png', '.npy')) for name in image_names_sp] 232 | 233 | if load_plane: 234 | print(str(datetime.now()) + ': \033[92mI', 
'loading superpixels ...', '\033[0m') 235 | planes_all = [] 236 | new_sp_all = [] 237 | mean_colors_all = [] 238 | plane_index_start = 0 239 | for plane_idx, (sp_name, depth, normal, pose, img_sp) in tqdm(enumerate(zip(sp_names, depths, normals, sfm_poses_sp, imgs_sp))): 240 | sp = np.load(sp_name) 241 | planes, new_sp, mean_colors = parse_sp(sp, intr, pose, depth, normal, img_sp) 242 | planes = np.array(planes) 243 | planes_idx = np.ones((planes.shape[0], 1)) * plane_idx 244 | planes = np.concatenate([planes, planes_idx], 1) 245 | planes_all.extend(planes) 246 | new_sp_all.append(new_sp + plane_index_start) 247 | plane_index_start += planes.shape[0] 248 | mean_colors = np.array(mean_colors) 249 | mean_colors_all.extend(mean_colors) 250 | 251 | planes_all = np.array(planes_all) 252 | mean_colors_all = np.array(mean_colors_all) 253 | 254 | else: 255 | planes_all = None 256 | new_sp_all = None 257 | mean_colors_all = None 258 | 259 | # proj matrixs 260 | frames_proj = [] 261 | frames_c2w = [] 262 | frames_center = [] 263 | for i in range(len(sfm_camprojs)): 264 | fovy = 2 * np.arctan(0.5 * imgs[0].shape[0] / intr[1, 1]) 265 | proj = get_projection_matrix( 266 | fovy, imgs[0].shape[1] / imgs[0].shape[0], 0.1, 1000.0 267 | ) 268 | proj = np.array(proj) 269 | frames_proj.append(proj) 270 | # sfm_poses is w2c 271 | c2w = np.linalg.inv(sfm_poses[i]) 272 | frames_c2w.append(c2w) 273 | frames_center.append(c2w[:3, 3]) 274 | 275 | frames_proj = np.stack(frames_proj, 0) 276 | frames_c2w = np.stack(frames_c2w, 0) 277 | frames_center = np.stack(frames_center, 0) 278 | 279 | mvp_mtxs = get_mvp_matrix( 280 | frames_c2w, frames_proj, 281 | ) 282 | 283 | index_init = ((np.array(kf_list) - kf_list[0]) / image_skip).astype(np.int32) 284 | 285 | return imgs, intr, sfm_poses, sfm_camprojs, frames_center, all_depths, all_normals, planes_all, \ 286 | mvp_mtxs, index_init, new_sp_all, mean_colors_all 287 | -------------------------------------------------------------------------------- /lib/load_data.py: -------------------------------------------------------------------------------- 1 | from .load_colmap import load_colmap_data 2 | from .load_replica import load_replica_data 3 | 4 | 5 | def load_data(args, kf_list, image_skip, load_plane): 6 | if args.dataset_type == 'scannet': 7 | images, intr, sfm_poses, sfm_camprojs, cam_centers, \ 8 | all_depths, normals, planes_all, mvp_mtxs, index_init, \ 9 | new_sps, mean_colors = load_colmap_data(args, kf_list=kf_list, 10 | image_skip=image_skip, 11 | load_plane=load_plane) 12 | print('Loaded scannet', intr, len(images), sfm_poses.shape, sfm_camprojs.shape, args.input_dir) 13 | 14 | data_dict = dict( 15 | poses=sfm_poses, images=images, 16 | intr=intr, sfm_camprojs=sfm_camprojs, 17 | cam_centers=cam_centers, 18 | all_depths=all_depths, 19 | normals=normals, planes_all=planes_all, 20 | mvp_mtxs=mvp_mtxs, 21 | index_init=index_init, 22 | new_sps=new_sps, 23 | mean_colors=mean_colors, 24 | ) 25 | return data_dict 26 | 27 | 28 | elif args.dataset_type == 'replica': 29 | images, intr, sfm_poses, sfm_camprojs, cam_centers, \ 30 | all_depths, normals, planes_all, mvp_mtxs, index_init, \ 31 | new_sps, mean_colors = load_replica_data(args, kf_list=kf_list, 32 | image_skip=image_skip, 33 | load_plane=load_plane) 34 | 35 | print('Loaded replica', intr, len(images), sfm_poses.shape, sfm_camprojs.shape, args.input_dir) 36 | 37 | data_dict = dict( 38 | poses=sfm_poses, images=images, 39 | intr=intr, sfm_camprojs=sfm_camprojs, 40 | cam_centers=cam_centers, 41 | all_depths=all_depths, 
42 | normals=normals, planes_all=planes_all, 43 | mvp_mtxs=mvp_mtxs, 44 | index_init=index_init, 45 | new_sps=new_sps, 46 | mean_colors=mean_colors, 47 | ) 48 | return data_dict 49 | 50 | 51 | else: 52 | raise NotImplementedError(f'Unknown dataset type {args.dataset_type} exiting') 53 | 54 | -------------------------------------------------------------------------------- /lib/load_replica.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import os 3 | import cv2 4 | import glob 5 | import numpy as np 6 | from tqdm import tqdm 7 | 8 | from recon.utils import load_sfm_pose, load_replica_pose 9 | 10 | from geom.plane_utils import calc_plane 11 | 12 | import torch 13 | import torch.nn.functional as F 14 | 15 | 16 | def dilate(img): 17 | kernel = torch.ones(3, 3).cuda().double() 18 | img = torch.tensor(img.astype('int32')).cuda().double() 19 | 20 | while True: 21 | mask = img > 0 22 | if torch.logical_not(mask).sum() == 0: 23 | break 24 | mask = mask.double() 25 | 26 | mask_dilated = F.conv2d(mask[None, None], kernel[None, None], padding=1).squeeze(0).squeeze(0) 27 | 28 | img_dilated = F.conv2d(img[None, None], kernel[None, None], padding=1).squeeze(0).squeeze(0) 29 | img_dilated = torch.where(mask_dilated == 0, img_dilated, img_dilated / mask_dilated) 30 | 31 | img = torch.where(mask == 0, img_dilated, img) 32 | 33 | img = img.cpu().numpy() 34 | 35 | return img 36 | 37 | 38 | def parse_sp(sp, intr, pose, depth, normal, img_sp, edge_thres=0.95, crop=20): 39 | """Parse the superpixel segmentation into planes 40 | """ 41 | sp = torch.from_numpy(sp).cuda() 42 | depth = torch.from_numpy(depth).cuda() 43 | normal = (torch.from_numpy(normal).cuda() - 0.5) * 2 44 | normal = normal.permute(1, 2, 0) 45 | img_sp = torch.from_numpy(img_sp).cuda() 46 | sp = sp.int() 47 | sp_ids = torch.unique(sp) 48 | planes = [] 49 | mean_colors = [] 50 | 51 | # mask crop to -1 52 | sp[:crop] = -1 53 | sp[-crop:] = -1 54 | sp[:, :crop] = -1 55 | sp[:, -crop:] = -1 56 | 57 | new_sp = torch.ones_like(sp).cuda() * -1 58 | sp_id_cnt = 0 59 | 60 | for sp_id in tqdm(sp_ids): 61 | sp_mask = sp == sp_id 62 | 63 | if sp_mask.sum() <= 15: 64 | continue 65 | 66 | new_sp[sp_mask] = sp_id_cnt 67 | mean_color = torch.mean(img_sp[sp_mask], dim=0) 68 | 69 | sp_dis = depth[sp_mask] 70 | sp_dis = torch.median(sp_dis) 71 | if sp_dis < 0: 72 | continue 73 | sp_normal = normal[sp_mask] 74 | 75 | sp_coeff = torch.einsum('ij,kj->ik', sp_normal, sp_normal) 76 | if sp_coeff.min() < edge_thres: 77 | is_edge_plane = True 78 | else: 79 | is_edge_plane = False 80 | 81 | sp_normal = torch.mean(sp_normal, dim=0) 82 | sp_normal = sp_normal / torch.norm(sp_normal) 83 | sp_normal[1] = -sp_normal[1] 84 | sp_normal[2] = -sp_normal[2] 85 | 86 | x_accum = torch.sum(sp_mask, dim=0) 87 | y_accum = torch.sum(sp_mask, dim=1) 88 | x_range_uv = [torch.min(torch.nonzero(x_accum)).item(), torch.max(torch.nonzero(x_accum)).item()] 89 | y_range_uv = [torch.min(torch.nonzero(y_accum)).item(), torch.max(torch.nonzero(y_accum)).item()] 90 | 91 | sp_dis = sp_dis.cpu().numpy() 92 | sp_depth = depth[y_range_uv[0]:y_range_uv[1]+1, x_range_uv[0]:x_range_uv[1]+1].cpu().numpy() 93 | sp_mask = sp_mask[y_range_uv[0]:y_range_uv[1]+1, x_range_uv[0]:x_range_uv[1]+1].cpu().numpy() 94 | 95 | ret = calc_plane(intr, pose, sp_depth, sp_dis, sp_mask, x_range_uv, y_range_uv, plane_normal=sp_normal.cpu().numpy()) 96 | if ret is None: 97 | continue 98 | plane, plane_up, resol, new_x_range_uv, new_y_range_uv = ret 99 | 
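# NOTE: each row of `planes` built below is laid out as
# [0:3] plane normal, [3:6] world-space centroid (together the 6-dim `plane`),
# [6:9] plane up vector, [9:11] pixel-plane resolution,
# [11:13] uv minimum, [13:15] uv maximum (calc_plane returns xy_min/xy_max,
# despite the local names new_x_range_uv/new_y_range_uv), [15] is_edge_plane;
# the caller appends the keyframe index at [16]. cluster_planes() indexes into
# this layout (e.g. planes[:, :6] and planes[:, 15]).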
100 | plane = np.concatenate([plane, plane_up, resol, new_x_range_uv, new_y_range_uv, [is_edge_plane]]) 101 | planes.append(plane) 102 | mean_colors.append(mean_color.cpu().numpy()) 103 | 104 | sp_id_cnt += 1 105 | 106 | return planes, new_sp.cpu().numpy(), mean_colors 107 | 108 | 109 | def get_projection_matrix(fovy: float, aspect_wh: float, near: float, far: float): 110 | proj_mtx = np.zeros((4, 4), dtype=np.float32) 111 | proj_mtx[0, 0] = 1.0 / (np.tan(fovy / 2.0) * aspect_wh) 112 | proj_mtx[1, 1] = -1.0 / np.tan( 113 | fovy / 2.0 114 | ) # add a negative sign here as the y axis is flipped in nvdiffrast output 115 | proj_mtx[2, 2] = -(far + near) / (far - near) 116 | proj_mtx[2, 3] = -2.0 * far * near / (far - near) 117 | proj_mtx[3, 2] = -1.0 118 | return proj_mtx 119 | 120 | 121 | def get_mvp_matrix(c2w, proj_mtx): 122 | # calculate w2c from c2w: R' = Rt, t' = -Rt * t 123 | # mathematically equivalent to (c2w)^-1 124 | w2c = np.zeros((c2w.shape[0], 4, 4)) 125 | w2c[:, :3, :3] = np.transpose(c2w[:, :3, :3], (0, 2, 1)) 126 | w2c[:, :3, 3:] = np.transpose(-c2w[:, :3, :3], (0, 2, 1)) @ c2w[:, :3, 3:] 127 | w2c[:, 3, 3] = 1.0 128 | 129 | # calculate mvp matrix by proj_mtx @ w2c (mv_mtx) 130 | mvp_mtx = proj_mtx @ w2c 131 | 132 | return mvp_mtx 133 | 134 | 135 | def load_replica_data(args, kf_list=[], image_skip=1, load_plane=True): 136 | imgs = [] 137 | image_dir = os.path.join(args.input_dir, "images") 138 | start = kf_list[0] 139 | end = kf_list[-1] + 1 if kf_list[-1] != -1 else len(glob.glob(os.path.join(image_dir, "*"))) 140 | image_names = sorted(glob.glob(os.path.join(image_dir, "*"))) 141 | 142 | image_names_sp = [image_names[i] for i in kf_list] 143 | image_names = image_names[start:end:image_skip] 144 | 145 | print(str(datetime.now()) + ': \033[92mI', 'loading images ...', '\033[0m') 146 | for name in tqdm(image_names): 147 | img = cv2.imread(name) 148 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 149 | imgs.append(img.astype(np.float32) / 255.) 150 | 151 | imgs_sp = [] 152 | for name in image_names_sp: 153 | img = cv2.imread(name) 154 | # convert to ycbcr using cv2 155 | img = cv2.cvtColor(img, cv2.COLOR_BGR2HSV) 156 | imgs_sp.append(img.astype(np.float32) / 255.) 
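# NOTE: as in lib/load_colmap.py, the conversion above is BGR -> HSV (the
# "ycbcr" comment is stale). load_replica_pose() from recon/utils.py, called
# below, reads intrinsics from cam_params.json and per-frame world-to-cam
# extrinsics from traj.txt, flipping the y and z camera axes to the convention
# expected by get_projection_matrix/get_mvp_matrix.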
157 | 158 | imgs_sp = np.stack(imgs_sp, 0) 159 | 160 | intr, sfm_poses, sfm_camprojs = load_replica_pose(args.input_dir) 161 | 162 | sfm_poses_sp = np.array(sfm_poses)[kf_list] 163 | sfm_poses = np.array(sfm_poses)[start:end:image_skip] 164 | sfm_camprojs = np.array(sfm_camprojs)[start:end:image_skip] 165 | 166 | normal_dir = os.path.join(args.input_dir, "omnidata_normal") 167 | normal_names = sorted(glob.glob(os.path.join(normal_dir, "*"))) 168 | depth_names = [ iname.replace('omnidata_normal', 'aligned_dense_depths') for iname in normal_names ] 169 | depth_names = [depth_names[i] for i in kf_list] 170 | 171 | print(str(datetime.now()) + ': \033[92mI', 'loading depths ...', '\033[0m') 172 | depths = [] 173 | for name in tqdm(depth_names): 174 | depth = np.load(name) 175 | depths.append(depth) 176 | 177 | depths = np.stack(depths, 0) 178 | aligned_depth_dir = os.path.join(args.input_dir, "aligned_dense_depths") 179 | all_depth_names = sorted(glob.glob(os.path.join(aligned_depth_dir, "*")))[start:end:image_skip] 180 | all_depths = [] 181 | 182 | for name, pose in tqdm(zip(all_depth_names, sfm_poses)): 183 | depth = np.load(name) 184 | all_depths.append(depth) 185 | 186 | if args.init.normal_model_type == 'omnidata': 187 | normal_dir = os.path.join(args.input_dir, "omnidata_normal") 188 | else: 189 | raise NotImplementedError(f'Unknown normal model type {args.init.normal_model_type} exiting') 190 | normal_names = sorted(glob.glob(os.path.join(normal_dir, "*"))) 191 | normal_names = [normal_names[i] for i in kf_list] 192 | 193 | print(str(datetime.now()) + ': \033[92mI', 'loading normals ...', '\033[0m') 194 | normals = [] 195 | for name in tqdm(normal_names): 196 | normal = np.load(name)[:3] 197 | normals.append(normal) 198 | 199 | normals = np.stack(normals, 0) 200 | 201 | all_normal_names = sorted(glob.glob(os.path.join(normal_dir, "*")))[start:end:image_skip] 202 | all_normals = [] 203 | 204 | for name, pose in tqdm(zip(all_normal_names, sfm_poses)): 205 | normal = np.load(name) 206 | normal = normal[:3] 207 | 208 | normal = torch.from_numpy(normal).cuda() 209 | normal = (normal - 0.5) * 2 210 | normal = normal / torch.norm(normal, dim=0, keepdim=True) 211 | normal = normal.permute(1, 2, 0) 212 | normal[:, :, 1] = -normal[:, :, 1] 213 | normal[:, :, 2] = -normal[:, :, 2] 214 | normal = torch.einsum('ijk,kl->ijl', normal, torch.from_numpy(np.linalg.inv(pose[:3, :3]).T).cuda().float()) 215 | normal = F.interpolate(normal.permute(2, 0, 1).unsqueeze(0), 216 | (int(imgs[0].shape[0]/64)*8, int(imgs[0].shape[1]/64)*8), 217 | mode='bilinear', align_corners=False).squeeze(0).permute(1, 2, 0) 218 | all_normals.append(normal.cpu().numpy()) 219 | 220 | sp_dir = os.path.join(args.input_dir, "sp") 221 | sp_names = [os.path.join(sp_dir, os.path.split(name)[-1].replace('.jpg', '.npy').replace('.png', '.npy')) for name in image_names_sp] 222 | 223 | if load_plane: 224 | print(str(datetime.now()) + ': \033[92mI', 'loading superpixels ...', '\033[0m') 225 | planes_all = [] 226 | new_sp_all = [] 227 | mean_colors_all = [] 228 | plane_index_start = 0 229 | for plane_idx, (sp_name, depth, normal, pose, img_sp) in tqdm(enumerate(zip(sp_names, depths, normals, sfm_poses_sp, imgs_sp))): 230 | sp = np.load(sp_name) 231 | planes, new_sp, mean_colors = parse_sp(sp, intr, pose, depth, normal, img_sp) 232 | planes = np.array(planes) 233 | planes_idx = np.ones((planes.shape[0], 1)) * plane_idx 234 | planes = np.concatenate([planes, planes_idx], 1) 235 | planes_all.extend(planes) 236 | new_sp_all.append(new_sp + 
plane_index_start) 237 | plane_index_start += planes.shape[0] 238 | mean_colors = np.array(mean_colors) 239 | mean_colors_all.extend(mean_colors) 240 | 241 | planes_all = np.array(planes_all) 242 | mean_colors_all = np.array(mean_colors_all) 243 | 244 | else: 245 | planes_all = None 246 | new_sp_all = None 247 | mean_colors_all = None 248 | 249 | # proj matrixs 250 | frames_proj = [] 251 | frames_c2w = [] 252 | frames_center = [] 253 | for i in range(len(sfm_camprojs)): 254 | fovy = 2 * np.arctan(0.5 * imgs[0].shape[0] / intr[1, 1]) 255 | proj = get_projection_matrix( 256 | fovy, imgs[0].shape[1] / imgs[0].shape[0], 0.1, 1000.0 257 | ) 258 | proj = np.array(proj) 259 | frames_proj.append(proj) 260 | # sfm_poses is w2c 261 | c2w = np.linalg.inv(sfm_poses[i]) 262 | frames_c2w.append(c2w) 263 | frames_center.append(c2w[:3, 3]) 264 | 265 | frames_proj = np.stack(frames_proj, 0) 266 | frames_c2w = np.stack(frames_c2w, 0) 267 | frames_center = np.stack(frames_center, 0) 268 | 269 | mvp_mtxs = get_mvp_matrix( 270 | frames_c2w, frames_proj, 271 | ) 272 | 273 | index_init = ((np.array(kf_list) - kf_list[0]) / image_skip).astype(np.int32) 274 | 275 | return imgs, intr, sfm_poses, sfm_camprojs, frames_center, all_depths, all_normals, planes_all, \ 276 | mvp_mtxs, index_init, new_sp_all, mean_colors_all 277 | -------------------------------------------------------------------------------- /recon/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-LYJ-Lab/AlphaTablets/735cbfe6aa7f03f7bc37f48303045fe71ffa042a/recon/__init__.py -------------------------------------------------------------------------------- /recon/run_depth.py: -------------------------------------------------------------------------------- 1 | import os 2 | import json 3 | import numpy as np 4 | 5 | from recon.third_party.omnidata.omnidata_tools.torch.demo_normal_custom_func import demo_normal_custom_func 6 | from recon.third_party.metric3d.depth_custom_func import depth_custom_func 7 | 8 | 9 | def metric3d_depth(img_dir, input_dir, depth_dir, dataset_type): 10 | """Initialize the dense depth of each image using single-image depth prediction 11 | """ 12 | os.makedirs(depth_dir, exist_ok=True) 13 | if dataset_type == 'scannet': 14 | intr = np.loadtxt(f'{input_dir}/intrinsic/intrinsic_color.txt') 15 | fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2] 16 | elif dataset_type == 'replica': 17 | if os.path.exists(os.path.join(input_dir, 'cam_params.json')): 18 | cam_param_path = os.path.join(input_dir, 'cam_params.json') 19 | elif os.path.exists(os.path.join(input_dir, '../cam_params.json')): 20 | cam_param_path = os.path.join(input_dir, '../cam_params.json') 21 | else: 22 | raise FileNotFoundError('cam_params.json not found') 23 | with open(cam_param_path, 'r') as f: 24 | j = json.load(f) 25 | fx, fy, cx, cy = j["camera"]["fx"], j["camera"]["fy"], j["camera"]["cx"], j["camera"]["cy"] 26 | else: 27 | raise NotImplementedError(f'Unknown dataset type {dataset_type} exiting') 28 | indir = img_dir 29 | outdir = depth_dir 30 | depth_custom_func(fx, fy, cx, cy, indir, outdir) 31 | 32 | 33 | def omnidata_normal(img_dir, input_dir, normal_dir): 34 | """Initialize the dense normal of each image using single-image normal prediction 35 | """ 36 | os.makedirs(normal_dir, exist_ok=True) 37 | demo_normal_custom_func(img_dir, normal_dir) 38 | -------------------------------------------------------------------------------- /recon/run_recon.py: 
-------------------------------------------------------------------------------- 1 | import os 2 | import glob 3 | from datetime import datetime 4 | 5 | 6 | class Recon: 7 | def __init__(self, config, input_dir, dataset_type): 8 | self.config = config 9 | self.input_dir = input_dir 10 | self.image_dir = os.path.join(input_dir, "images") 11 | self.output_root = input_dir 12 | self.dataset_type = dataset_type 13 | self.preprocess() 14 | 15 | def preprocess(self): 16 | if not os.path.exists(self.image_dir): 17 | if self.dataset_type == 'scannet': 18 | original_image_dir = os.path.join(self.input_dir, "color") 19 | if not os.path.exists(original_image_dir): 20 | raise ValueError(f'Image directory {self.image_dir} not found') 21 | old_dir = os.getcwd() 22 | os.chdir(self.input_dir) 23 | os.symlink("color", "images") 24 | os.chdir(old_dir) 25 | print(str(datetime.now()) + ': \033[92mI', f'Linked {original_image_dir} to {self.image_dir}', '\033[0m') 26 | elif self.dataset_type == 'replica': 27 | original_image_dir = os.path.join(self.input_dir, "results") 28 | if not os.path.exists(original_image_dir): 29 | raise ValueError(f'Image directory {self.image_dir} not found') 30 | os.makedirs(self.image_dir) 31 | old_dir = os.getcwd() 32 | os.chdir(self.image_dir) 33 | for img in glob.glob(os.path.join("../results", "frame*")): 34 | os.symlink(os.path.join("../results", os.path.split(img)[-1]), os.path.split(img)[-1]) 35 | os.chdir(old_dir) 36 | print(str(datetime.now()) + ': \033[92mI', f'Linked {original_image_dir} to {self.image_dir}', '\033[0m') 37 | else: 38 | raise NotImplementedError(f'Unknown dataset type {self.dataset_type} exiting') 39 | 40 | 41 | def recon(self): 42 | if os.path.exists(os.path.join(self.output_root, "recon.lock")): 43 | print(str(datetime.now()) + ': \033[92mI', 'Monocular estimation already done!', '\033[0m') 44 | return 45 | 46 | from .run_depth import omnidata_normal, metric3d_depth 47 | 48 | depth_dir = os.path.join(self.output_root, "aligned_dense_depths") 49 | print(str(datetime.now()) + ': \033[92mI', 'Running Depth Estimation ...', '\033[0m') 50 | metric3d_depth(self.image_dir, self.output_root, depth_dir, self.dataset_type) 51 | print(str(datetime.now()) + ': \033[92mI', 'Depth Estimation Done!', '\033[0m') 52 | 53 | normal_dir = os.path.join(self.output_root, "omnidata_normal") 54 | print(str(datetime.now()) + ': \033[92mI', 'Running Normal Estimation ...', '\033[0m') 55 | omnidata_normal(self.image_dir, self.output_root, normal_dir) 56 | print(str(datetime.now()) + ': \033[92mI', 'Normal Estimation Done!', '\033[0m') 57 | 58 | # create lock file 59 | open(os.path.join(self.output_root, "recon.lock"), "w").close() 60 | 61 | def run_sp(self, kf_list): 62 | from .run_sp import run_sp 63 | print(str(datetime.now()) + ': \033[92mI', 'Running SuperPixel Subdivision ...', '\033[0m') 64 | sp_dir = run_sp(self.image_dir, self.output_root, kf_list) 65 | print(str(datetime.now()) + ': \033[92mI', 'SuperPixel Subdivision Done!', '\033[0m') 66 | return sp_dir 67 | -------------------------------------------------------------------------------- /recon/run_sp.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import numpy as np 4 | from tqdm import tqdm 5 | 6 | 7 | def sp_division(img): 8 | slic = cv2.ximgproc.createSuperpixelSLIC(img) 9 | slic.iterate(10) 10 | labels = slic.getLabels() 11 | # mask_slic = slic.getLabelContourMask() # get the contour mask, superpixel boundary pixels == 1 12 | # number_slic = slic.getNumberOfSuperpixels() # get the number of superpixels 
13 | # mask_inv_slic = cv2.bitwise_not(mask_slic) 14 | # img_slic = cv2.bitwise_and(img,img,mask = mask_inv_slic) # draw superpixel boundaries on the original image 15 | # cv2.imwrite('3.png', img_slic) 16 | return labels 17 | 18 | 19 | def run_sp(image_dir, output_root, kf_list=None): 20 | sp_dir = os.path.join(output_root, "sp") 21 | os.makedirs(sp_dir, exist_ok=True) 22 | try: 23 | image_names = sorted(os.listdir(image_dir), key=lambda x: int(x.split('/')[-1][:-4])) 24 | except: 25 | image_names = sorted(os.listdir(image_dir)) 26 | image_names = [image_names[k] for k in kf_list] 27 | for name in tqdm(image_names): 28 | if os.path.exists(os.path.join(sp_dir, name[:-4] + '.npy')): 29 | continue 30 | img = cv2.imread(os.path.join(image_dir, name)) 31 | labels = sp_division(img) 32 | np.save(os.path.join(sp_dir, name[:-4]), labels) 33 | 34 | 35 | if __name__ == '__main__': 36 | img = cv2.imread('/home/hyz/git-plane/PixelPlane/data/scene0000_00/images/36.jpg') 37 | sp_division(img) -------------------------------------------------------------------------------- /recon/utils.py: -------------------------------------------------------------------------------- 1 | import json 2 | import os 3 | import cv2 4 | import glob 5 | import numpy as np 6 | 7 | 8 | def load_images(img_dir): 9 | imgs = [] 10 | image_names = sorted(os.listdir(img_dir)) 11 | for name in image_names: 12 | img = cv2.imread(os.path.join(img_dir, name)) 13 | img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) 14 | imgs.append(img) 15 | return imgs, image_names 16 | 17 | 18 | def load_sfm(sfm_dir): 19 | """Load converted sfm world-to-cam depths and poses 20 | """ 21 | depth_dir = os.path.join(sfm_dir, "colmap_outputs_converted/depths") 22 | pose_dir = os.path.join(sfm_dir, "colmap_outputs_converted/poses") 23 | depth_names = sorted(glob.glob(os.path.join(depth_dir, "*.npy"))) 24 | pose_names = sorted(glob.glob(os.path.join(pose_dir, "*.txt"))) 25 | sfm_depths, sfm_poses = [], [] 26 | for dn, pn in zip(depth_names, pose_names): 27 | assert os.path.basename(dn)[:-4] == os.path.basename(pn)[:-4] 28 | depth = np.load(dn) 29 | pose = np.loadtxt(pn) 30 | sfm_depths.append(depth) 31 | sfm_poses.append(pose) 32 | return sfm_depths, sfm_poses 33 | 34 | 35 | def load_sfm_pose(sfm_dir): 36 | """Load sfm poses and then convert into proj mat 37 | """ 38 | pose_dir = os.path.join(sfm_dir, "colmap_outputs_converted/poses") 39 | intr_dir = os.path.join(sfm_dir, "colmap_outputs_converted/intrinsics") 40 | pose_names = sorted(glob.glob(os.path.join(pose_dir, "*.txt")), key=lambda x: int(x.split('/')[-1][:-4])) 41 | intr_names = sorted(glob.glob(os.path.join(intr_dir, "*.txt")), key=lambda x: int(x.split('/')[-1][:-4])) 42 | K = np.loadtxt(intr_names[0]) 43 | KH = np.eye(4) 44 | KH[:3,:3] = K 45 | sfm_poses, sfm_projmats = [], [] 46 | for pn in pose_names: 47 | # world-to-cam 48 | pose = np.loadtxt(pn) 49 | pose = np.concatenate([pose, np.array([[0,0,0,1]])], 0) 50 | c2w = np.linalg.inv(pose) 51 | c2w[0:3, 1:3] *= -1 52 | pose = np.linalg.inv(c2w) 53 | sfm_poses.append(pose) 54 | # projmat 55 | projmat = KH @ pose 56 | sfm_projmats.append(projmat) 57 | return K, sfm_poses, sfm_projmats 58 | 59 | 60 | def load_scannet_pose(scannet_dir): 61 | """Load scannet poses and then convert into proj mat 62 | """ 63 | pose_dir = os.path.join(scannet_dir, "pose") 64 | intr_dir = os.path.join(scannet_dir, "intrinsic") 65 | pose_names = sorted(glob.glob(os.path.join(pose_dir, "*.txt")), key=lambda x: int(x.split('/')[-1][:-4])) 66 | intr_name = os.path.join(intr_dir, "intrinsic_color.txt") 67 | KH = 
np.loadtxt(intr_name) 68 | scannet_poses, scannet_projmats = [], [] 69 | for pn in pose_names: 70 | p = np.loadtxt(pn) 71 | R = p[:3, :3] 72 | R = np.matmul(R, np.array([ 73 | [1, 0, 0], 74 | [0, -1, 0], 75 | [0, 0, -1] 76 | ])) 77 | p[:3, :3] = R 78 | p = np.linalg.inv(p) 79 | scannet_poses.append(p) 80 | # projmat 81 | projmat = KH @ p 82 | scannet_projmats.append(projmat) 83 | return KH[:3, :3], scannet_poses, scannet_projmats 84 | 85 | 86 | def load_replica_pose(replica_dir): 87 | """Load replica poses and then convert into proj mat 88 | """ 89 | if os.path.exists(os.path.join(replica_dir, 'cam_params.json')): 90 | cam_param_path = os.path.join(replica_dir, 'cam_params.json') 91 | elif os.path.exists(os.path.join(replica_dir, '../cam_params.json')): 92 | cam_param_path = os.path.join(replica_dir, '../cam_params.json') 93 | else: 94 | raise FileNotFoundError('cam_params.json not found') 95 | with open(cam_param_path, 'r') as f: 96 | j = json.load(f) 97 | intrinsics = np.array([ 98 | j["camera"]["fx"], 0, j["camera"]["cx"], 99 | 0, j["camera"]["fy"], j["camera"]["cy"], 100 | 0, 0, 1 101 | ], dtype=np.float32).reshape(3, 3) 102 | 103 | KH = np.eye(4) 104 | KH[:3, :3] = intrinsics 105 | 106 | extrinsics = np.loadtxt(os.path.join(replica_dir, 'traj.txt')).reshape(-1, 4, 4) 107 | poses = [] 108 | projmats = [] 109 | for extrinsic in extrinsics: 110 | p = extrinsic 111 | R = p[:3, :3] 112 | R = np.matmul(R, np.array([ 113 | [1, 0, 0], 114 | [0, -1, 0], 115 | [0, 0, -1] 116 | ])) 117 | p[:3, :3] = R 118 | p = np.linalg.inv(p) 119 | poses.append(p) 120 | # projmat 121 | projmat = KH @ p 122 | projmats.append(projmat) 123 | 124 | return intrinsics, poses, projmats 125 | 126 | 127 | def load_dense_depths(depth_dir): 128 | """Load initialized single-image dense depth predictions 129 | """ 130 | depth_names = sorted(os.listdir(depth_dir)) 131 | depths = [] 132 | for dn in depth_names: 133 | depth = np.load(os.path.join(depth_dir, dn)) 134 | depths.append(depth) 135 | return depths, depth_names 136 | 137 | 138 | def depth2disp(depth): 139 | """Convert depth map to disparity 140 | """ 141 | disp = 1.0 / (depth + 1e-8) 142 | return disp 143 | 144 | 145 | def disp2depth(disp): 146 | """Convert disparity map to depth 147 | """ 148 | depth = 1.0 / (disp + 1e-8) 149 | return depth 150 | -------------------------------------------------------------------------------- /requirements.txt: -------------------------------------------------------------------------------- 1 | tqdm 2 | scikit-learn 3 | opencv-python 4 | opencv-contrib-python 5 | hydra-core 6 | matplotlib 7 | timm 8 | h5py 9 | pytorch_lightning 10 | mmengine 11 | mmcv 12 | ninja 13 | trimesh 14 | open3d 15 | git+https://github.com/NVlabs/nvdiffrast.git -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | from datetime import datetime 2 | import os, time, random, argparse 3 | 4 | import numpy as np 5 | import cv2 6 | 7 | from tqdm import tqdm, trange 8 | 9 | import torch 10 | import torch.nn.functional as F 11 | from torch.utils.data import DataLoader 12 | import torch.optim.lr_scheduler as lr_scheduler 13 | 14 | from lib.load_data import load_data 15 | from lib.alphatablets import AlphaTablets 16 | from dataset.dataset import CustomDataset 17 | from lib.keyframe import get_keyframes 18 | from recon.run_recon import Recon 19 | 20 | 21 | def config_parser(): 22 | '''Define command line arguments 23 | ''' 24 | parser = 
argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
25 |     parser.add_argument('--config', type=str, default='configs/default.yaml',
26 |                         help='config file path')
27 |     parser.add_argument("--seed", type=int, default=777,
28 |                         help='Random seed')
29 |     parser.add_argument("--job", type=str, default=str(time.time()),
30 |                         help='Job name')
31 | 
32 |     # learning options
33 |     parser.add_argument("--lr_tex", type=float, default=0.01,
34 |                         help='Learning rate of texture color')
35 |     parser.add_argument("--lr_alpha", type=float, default=0.03,
36 |                         help='Learning rate of texture alpha')
37 |     parser.add_argument("--lr_plane_n", type=float, default=0.0001,
38 |                         help='Learning rate of plane normal')
39 |     parser.add_argument("--lr_plane_dis", type=float, default=0.0005,
40 |                         help='Learning rate of plane distance')
41 |     parser.add_argument("--lr_plane_dis_stage2", type=float, default=0.0002,
42 |                         help='Learning rate of plane distance in stage 2')
43 | 
44 |     # loss weights
45 |     parser.add_argument("--weight_alpha_inv", type=float, default=1.0,
46 |                         help='Weight of alpha inv loss')
47 |     parser.add_argument("--weight_normal", type=float, default=4.0,
48 |                         help='Weight of direct normal loss')
49 |     parser.add_argument("--weight_depth", type=float, default=4.0,
50 |                         help='Weight of direct depth loss')
51 |     parser.add_argument("--weight_distortion", type=float, default=20.0,
52 |                         help='Weight of distortion loss')
53 |     parser.add_argument("--weight_decay", type=float, default=0.9,
54 |                         help='Weight of tablet alpha decay after a single step. -1 denotes automatic decay')
55 | 
56 |     # merging options
57 |     parser.add_argument("--merge_normal_thres_init", type=float, default=0.97,
58 |                         help='Normal similarity threshold for the initial merging pass')
59 |     parser.add_argument("--merge_normal_thres", type=float, default=0.93,
60 |                         help='Normal similarity threshold for later merging passes')
61 |     parser.add_argument("--merge_dist_thres1", type=float, default=0.5,
62 |                         help='Distance threshold for the initial merging pass')
63 |     parser.add_argument("--merge_dist_thres2", type=float, default=0.1,
64 |                         help='Distance threshold for later merging passes')
65 |     parser.add_argument("--merge_color_thres1", type=float, default=0.3,
66 |                         help='Color threshold for the initial merging pass')
67 |     parser.add_argument("--merge_color_thres2", type=float, default=0.2,
68 |                         help='Color threshold for later merging passes')
69 |     parser.add_argument("--merge_Y_decay", type=float, default=0.5,
70 |                         help='Decay rate of Y channel')
71 | 
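    # The merge thresholds above are consumed by the tablet-merging scheme in
    # lib/alphatablets.py; as the defaults suggest, the normal thresholds are
    # cosine similarities between unit normals (0.97 ~ 14 deg, 0.93 ~ 21.6 deg),
    # and the *_thres1 / *_thres2 pairs gate the initial and the later, stricter
    # merging passes (mirroring merge_normal_thres_init vs merge_normal_thres).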
72 |     # optimization options
73 |     parser.add_argument("--batch_size", type=int, default=3,
74 |                         help='Batch size')
75 |     parser.add_argument("--max_steps", type=int, default=32,
76 |                         help='Max optimization steps')
77 |     parser.add_argument("--merge_interval", type=int, default=13,
78 |                         help='Merge interval')
79 |     parser.add_argument("--max_steps_union", type=int, default=9,
80 |                         help='Max optimization steps for union optimization')
81 |     parser.add_argument("--merge_interval_union", type=int, default=3,
82 |                         help='Merge interval for union optimization')
83 | 
84 |     parser.add_argument("--alpha_init", type=float, default=0.5,
85 |                         help='Initial alpha value')
86 |     parser.add_argument("--alpha_init_empty", type=float, default=0.,
87 |                         help='Initial alpha value for empty pixels')
88 |     parser.add_argument("--depth_inside_mask_alphathres", type=float, default=0.5,
89 |                         help='Threshold of alpha for inside mask')
90 | 
91 |     parser.add_argument("--max_rasterize_layers", type=int, default=15,
92 |                         help='Max rasterize layers')
93 | 
94 |     # logging/saving options
95 |     parser.add_argument("--log_path", type=str, default='./logs',
96 |                         help='path to save logs')
97 |     parser.add_argument("--dump_images", action='store_true')  # type=bool would parse any non-empty string, even "False", as True
98 |     parser.add_argument("--dump_interval", type=int, default=1,
99 |                         help='Dump interval')
100 | 
101 |     # update input
102 |     parser.add_argument("--input_dir", type=str, default='',
103 |                         help='input directory')
104 | 
105 |     args = parser.parse_args()
106 | 
107 |     from hydra import compose, initialize
108 | 
109 |     initialize(version_base=None, config_path='./')
110 |     ori_cfg = compose(config_name=args.config)['configs']
111 | 
112 |     class Struct:
113 |         def __init__(self, **entries):
114 |             self.__dict__.update(entries)
115 | 
116 |     cfg = Struct(**ori_cfg)
117 | 
118 |     # dump args and cfg
119 |     import json
120 |     os.makedirs(os.path.join(args.log_path, args.job), exist_ok=True)
121 |     with open(os.path.join(args.log_path, args.job, 'args.json'), 'w') as f:
122 |         args_json = {k: v for k, v in vars(args).items() if k != 'config'}
123 |         json.dump(args_json, f, indent=4)
124 | 
125 |     import shutil
126 |     shutil.copy(args.config, os.path.join(args.log_path, args.job, 'config.yaml'))
127 | 
128 |     if args.input_dir != '':
129 |         cfg.data.input_dir = args.input_dir
130 | 
131 |     return args, cfg
132 | 
133 | 
134 | def seed_everything():
135 |     '''Seed everything for better reproducibility.
136 |     (some PyTorch operations are non-deterministic, e.g. the backward pass of grid_sample)
137 |     '''
138 |     torch.manual_seed(args.seed)
139 |     np.random.seed(args.seed)
140 |     random.seed(args.seed)
141 | 
142 | 
143 | def load_everything(cfg, kf_list, image_skip, load_plane=True):
144 |     '''Load images / poses / camera settings / data split.
145 |     '''
146 |     data_dict = load_data(cfg.data, kf_list=kf_list, image_skip=image_skip, load_plane=load_plane)
147 |     data_dict['poses'] = torch.Tensor(data_dict['poses'])
148 |     return data_dict
149 | 
150 | 
151 | if __name__=='__main__':
152 |     # load setup
153 |     args, cfg = config_parser()
154 |     merge_cfgs = {
155 |         'normal_thres_init': args.merge_normal_thres_init,
156 |         'normal_thres': args.merge_normal_thres,
157 |         'dist_thres1': args.merge_dist_thres1,
158 |         'dist_thres2': args.merge_dist_thres2,
159 |         'color_thres1': args.merge_color_thres1,
160 |         'color_thres2': args.merge_color_thres2,
161 |         'Y_decay': args.merge_Y_decay
162 |     }
163 | 
164 |     # init environment
165 |     if torch.cuda.is_available():
166 |         torch.set_default_tensor_type('torch.cuda.FloatTensor')
167 |         device = torch.device('cuda')
168 |     else:
169 |         torch.set_default_tensor_type('torch.FloatTensor')
170 |         device = torch.device('cpu')
171 |     torch.set_default_dtype(torch.float32)
172 |     seed_everything()
173 | 
174 |     recon = Recon(cfg.data.init, cfg.data.input_dir, cfg.data.dataset_type)
175 |     recon.recon()
176 | 
177 |     kf_lists, image_skip = get_keyframes(cfg.data.input_dir)
178 |     subset_mean_len = np.mean([ k[-1]-k[0]+1 for k in kf_lists ])
179 |     kf_lists.append([0, -1])
180 | 
181 |     for set_num, kf_list in enumerate(kf_lists):
182 |         if set_num == len(kf_lists) - 1:
183 |             print(str(datetime.now()) + f': \033[94mUnion Optimization', '\033[0m')
184 |         else:
185 |             print(str(datetime.now()) + f': \033[94mSubset {set_num+1}/{len(kf_lists)-1}, Keyframes', kf_list, '\033[0m')
186 | 
187 |         union_optimize = False
188 |         if kf_list[0] == 0 and kf_list[1] == -1:
189 |             union_optimize = True
190 |             if os.path.exists(os.path.join(args.log_path, args.job, 'ckpt', f'./ckpt_{set_num:02d}_{args.max_steps_union-1}.pt')):
191 |                 print(str(datetime.now()) + f': \033[94mUnion optimization already done!', '\033[0m')
192 |                 continue
193 |         else:
194 |             if 
os.path.exists(os.path.join(args.log_path, args.job, 'ckpt', f'./ckpt_{set_num:02d}_{args.max_steps-1}.pt')): 195 | print(str(datetime.now()) + f': \033[94mSubset {set_num+1}/{len(kf_lists)-1} optimization already done!', '\033[0m') 196 | continue 197 | 198 | if kf_list[-1] != -1: 199 | recon.run_sp(kf_list) 200 | 201 | # load images / poses / camera settings / data split 202 | data_dict = load_everything(cfg=cfg, kf_list=kf_list, image_skip=image_skip, 203 | load_plane=not union_optimize) 204 | 205 | images, mvp_mtxs, K = data_dict['images'], data_dict['mvp_mtxs'], data_dict['intr'] 206 | index_init = data_dict['index_init'] 207 | new_sps = data_dict['new_sps'] 208 | cam_centers = data_dict['cam_centers'] 209 | normals = data_dict['normals'] 210 | mean_colors = data_dict['mean_colors'] 211 | depths = data_dict['all_depths'] 212 | 213 | sp_images = np.stack([ images[idx] for idx in index_init ]) 214 | 215 | if not union_optimize: 216 | pp = AlphaTablets(plane_params=data_dict['planes_all'], 217 | sp=[ sp_images , new_sps, mvp_mtxs[index_init]], 218 | cam_centers=cam_centers[index_init], 219 | mean_colors=mean_colors, 220 | HW=images[0].shape[0:2], 221 | alpha_init=args.alpha_init, alpha_init_empty=args.alpha_init_empty, 222 | inside_mask_alpha_thres=args.depth_inside_mask_alphathres, 223 | merge_cfgs=merge_cfgs) 224 | else: 225 | pp = AlphaTablets(ckpt_paths=[ 226 | os.path.join(args.log_path, args.job, 'ckpt', f'./ckpt_{ii:02d}_{args.max_steps-1}.pt') 227 | for ii in range(len(kf_lists) - 1) 228 | ], merge_cfgs=merge_cfgs, 229 | alpha_init=args.alpha_init, alpha_init_empty=args.alpha_init_empty, 230 | inside_mask_alpha_thres=args.depth_inside_mask_alphathres) 231 | 232 | dataset = CustomDataset(images, mvp_mtxs, normals, depths) 233 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, generator=torch.Generator(device=device)) 234 | 235 | optimizer = torch.optim.Adam([ 236 | {'params': pp.tex_color, 'lr': args.lr_tex}, 237 | {'params': pp.tex_alpha, 'lr': args.lr_alpha}, 238 | {'params': pp.plane_n, 'lr': args.lr_plane_n}, 239 | {'params': pp.plane_dis, 'lr': args.lr_plane_dis}, 240 | ]) 241 | 242 | scheduler = lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.3) 243 | crop = cfg.data.crop 244 | 245 | max_steps = args.max_steps if not union_optimize else args.max_steps_union 246 | merge_interval = args.merge_interval if not union_optimize else args.merge_interval_union 247 | 248 | psnr_lst = [] 249 | time0 = time.time() 250 | print(str(datetime.now()) + ': \033[92mI', 'Start AlphaTablets Optimization', '\033[0m') 251 | for global_step in trange(0, max_steps): 252 | avg_psnr = 0 253 | for batch_idx, (imgs, mvp_mtxs, normals, depths) in enumerate(dataloader): 254 | mvp_mtxs = mvp_mtxs.cuda() 255 | imgs = imgs.cuda() 256 | normals = normals.cuda() 257 | depths = depths.cuda() 258 | 259 | render_result = pp(mvp_mtxs, 260 | max_rasterize_layers=args.max_rasterize_layers) 261 | 262 | try: 263 | render_result, render_normal, render_depth, alpha_loss, distort_loss, inside_mask, plane_inside_mask = render_result 264 | except: 265 | continue 266 | 267 | # gradient descent step 268 | optimizer.zero_grad(set_to_none=True) 269 | crop_mask = torch.zeros_like(imgs) 270 | crop_mask[:, crop:-crop, crop:-crop] = 1 271 | mask = crop_mask * inside_mask 272 | loss = F.mse_loss(render_result * mask, imgs * mask) 273 | direct_normal_loss = F.mse_loss(render_normal, normals) 274 | direct_depth_loss = F.mse_loss(render_depth * mask[..., 0] * plane_inside_mask[..., 0], depths * mask[..., 
0] * plane_inside_mask[..., 0])
275 |                 psnr = -10 * torch.log10(loss)
276 | 
277 |                 loss += alpha_loss * args.weight_alpha_inv
278 |                 loss += distort_loss * args.weight_distortion
279 |                 loss += direct_normal_loss * args.weight_normal
280 |                 loss += direct_depth_loss * args.weight_depth
281 |                 avg_psnr += psnr.item()
282 | 
283 |                 loss.backward()
284 |                 optimizer.step()
285 | 
286 |                 if global_step % args.dump_interval == 0 and batch_idx == 0 and args.dump_images:
287 |                     optimizer.zero_grad(set_to_none=True)
288 |                     mvp_mtxs = torch.from_numpy(data_dict['mvp_mtxs'][0:1]).cuda().float()
289 |                     render_result, alpha_acc = pp(mvp_mtxs, return_alpha=True,
290 |                                                   max_rasterize_layers=args.max_rasterize_layers)
291 | 
292 |                     os.makedirs(os.path.join(args.log_path, args.job, 'dump_images'), exist_ok=True)
293 | 
294 |                     cv2.imwrite(os.path.join(args.log_path, args.job, 'dump_images', f'{set_num:02d}_{global_step:05d}.png'), render_result[0].detach().cpu().numpy() * 255)
295 |                     cv2.imwrite(os.path.join(args.log_path, args.job, 'dump_images', f'{set_num:02d}_{global_step:05d}_mask.png'), mask[0].detach().cpu().numpy() * 255)
296 |                     cv2.imwrite(os.path.join(args.log_path, args.job, 'dump_images', f'{set_num:02d}_{global_step:05d}_alpha.png'), alpha_acc[0].detach().cpu().numpy() * 255)
297 | 
298 |                     error_map = torch.abs(render_result - torch.from_numpy(data_dict['images'][0][None, ...]).cuda().float())
299 |                     error_map = error_map[0].detach().cpu().numpy()
300 |                     # turn to red-blue, 0-255
301 |                     error_map = (error_map - error_map.min()) / (error_map.max() - error_map.min()) * 255
302 |                     error_map = cv2.applyColorMap(error_map.astype(np.uint8), cv2.COLORMAP_JET)
303 |                     cv2.imwrite(os.path.join(args.log_path, args.job, 'dump_images', f'{set_num:02d}_{global_step:05d}_error_map.png'), error_map)
304 | 
305 |                 if global_step % merge_interval == 0 and batch_idx == 0 and global_step > 0:
306 |                     optimizer.zero_grad(set_to_none=True)
307 | 
308 |                     pp.weightCheck(torch.from_numpy(data_dict['mvp_mtxs']).cuda().float())
309 | 
310 |                     pp = AlphaTablets(pp=pp, merge_cfgs=merge_cfgs)
311 |                     del optimizer
312 |                     optimizer = torch.optim.Adam([
313 |                         {'params': pp.tex_color, 'lr': args.lr_tex},
314 |                         {'params': pp.tex_alpha, 'lr': args.lr_alpha},
315 |                         {'params': pp.plane_n, 'lr': args.lr_plane_n},
316 |                         {'params': pp.plane_dis, 'lr': args.lr_plane_dis_stage2},
317 |                     ])
318 | 
319 |             scheduler.step()
320 | 
321 |             if global_step == max_steps - 1:
322 |                 with torch.no_grad():
323 |                     os.makedirs(os.path.join(args.log_path, args.job, 'ckpt'), exist_ok=True)
324 |                     pp.save_ckpt(os.path.join(args.log_path, args.job, 'ckpt', f'ckpt_{set_num:02d}_{global_step}.pt'))
325 |                     optimizer.zero_grad(set_to_none=True)
326 | 
327 |             if union_optimize:
328 |                 with torch.no_grad():
329 |                     decay = args.weight_decay if args.weight_decay > 0 else max(1-0.00067*subset_mean_len, 0.8)
330 |                     pp.tex_alpha.data = torch.log( decay / (1 - decay + torch.exp(-pp.tex_alpha)) )  # scales sigmoid(tex_alpha) by `decay`, computed in logit space
331 |             elif global_step < args.merge_interval:
332 |                 with torch.no_grad():
333 |                     decay = args.weight_decay if args.weight_decay > 0 else max(1-0.00067*len(images), 0.8)
334 |                     pp.tex_alpha.data = torch.log( decay / (1 - decay + torch.exp(-pp.tex_alpha)) )  # same logit-space decay as above
335 | 
336 |         if union_optimize:
337 |             os.makedirs(os.path.join(args.log_path, args.job, 'results'), exist_ok=True)
338 |             export_name_obj = os.path.join(args.log_path, args.job, 'results', f'{set_num:02d}_final.obj')
339 |             pp.export_mesh_with_weight_check(torch.from_numpy(data_dict['mvp_mtxs']).cuda().float(), name=export_name_obj)
340 | 
341 |         os.makedirs(os.path.join(args.log_path, args.job, 'plys'), exist_ok=True)
342 |         export_name = os.path.join(args.log_path, args.job, 'plys', f'{set_num:02d}_final.ply')
343 |         pp.export_ply(torch.from_numpy(data_dict['mvp_mtxs'][::8]).cuda().float(), name=export_name)
344 | 
345 |         if set_num == len(kf_lists) - 1:
346 |             print(str(datetime.now()) + f': \033[94mUnion optimization finished!', '\033[0m')
347 |         else:
348 |             print(str(datetime.now()) + f': \033[94mSubset {set_num+1}/{len(kf_lists)-1} optimization finished', '\033[0m')
349 | 
350 |     # create complete lock
351 |     with open(os.path.join(args.log_path, args.job, 'complete.lock'), 'w') as f:
352 |         f.write('done')
353 | 
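`test.py` (next) scores each reconstructed scene with the Rand index, variation of information, and segmentation covering over per-vertex plane-instance ids. A toy sketch of the first two metrics on made-up labels (segmentation covering additionally needs `compute_sc` defined below):

```python
# Toy labels only; test.py feeds per-vertex plane-instance ids instead.
import numpy as np
from sklearn.metrics import rand_score
from skimage.metrics import variation_of_information

gt_ins = np.array([0, 0, 1, 1, 2, 2])
pred_ins = np.array([0, 0, 1, 2, 2, 2])

ri = rand_score(gt_ins, pred_ins)                    # Rand index, higher is better
h1, h2 = variation_of_information(gt_ins, pred_ins)
voi = h1 + h2                                        # variation of information, lower is better
print(f'RI={ri:.3f}  VOI={voi:.3f}')
```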
--------------------------------------------------------------------------------
/test.py:
--------------------------------------------------------------------------------
1 | import os
2 | import pickle
3 | import numpy as np
4 | import open3d as o3d
5 | import trimesh
6 | from tqdm import tqdm
7 | from sklearn.metrics import rand_score
8 | from skimage.metrics import variation_of_information
9 | 
10 | def compute_sc(gt_in, pred_in):
11 |     # to be consistent with the skimage / sklearn input arrangement
12 | 
13 |     assert len(pred_in.shape) == 1 and len(gt_in.shape) == 1
14 | 
15 |     acc, pred, gt = match_seg(pred_in, gt_in)  # n_gt * n_pred
16 | 
17 |     bestmatch_gt2pred = acc.max(axis=1)
18 |     bestmatch_pred2gt = acc.max(axis=0)
19 | 
20 |     pred_id, pred_cnt = np.unique(pred, return_counts=True)
21 |     gt_id, gt_cnt = np.unique(gt, return_counts=True)
22 | 
23 |     cnt_pred, sum_pred = 0, 0
24 |     for i, _ in enumerate(pred_id):
25 |         cnt_pred += bestmatch_pred2gt[i] * pred_cnt[i]
26 |         sum_pred += pred_cnt[i]
27 | 
28 |     cnt_gt, sum_gt = 0, 0
29 |     for i, _ in enumerate(gt_id):
30 |         cnt_gt += bestmatch_gt2pred[i] * gt_cnt[i]
31 |         sum_gt += gt_cnt[i]
32 | 
33 |     sc = (cnt_pred / sum_pred + cnt_gt / sum_gt) / 2
34 | 
35 |     return sc
36 | 
37 | def match_seg(pred_in, gt_in):
38 |     assert len(pred_in.shape) == 1 and len(gt_in.shape) == 1
39 | 
40 |     pred, gt = compact_segm(pred_in), compact_segm(gt_in)
41 |     n_gt = gt.max() + 1
42 |     n_pred = pred.max() + 1
43 | 
44 |     # this will offer the overlap between gt and pred
45 |     # if gt == 1, we will later have conf[1, j] = gt(1) + pred(j) * n_gt
46 |     # essentially, we encode the overlap into conf_mat[i, j]; when we decode it, rows are gt and columns are pred
47 |     # e.g. with 13 gt labels and 6 pred labels, gt 1 will correspond to the codes 14, 1+2*13, ..., 1 + 6*13
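    # worked example: with n_gt = 3 and compacted labels gt = [1, 1, 2], pred = [1, 2, 2],
    # overlap = gt + 3 * pred = [4, 7, 8]; the histogram over [0, 9) counts each
    # code once, and reshape(order='F') decodes code k as (gt = k % 3, pred = k // 3),
    # so rows index gt labels and columns index pred labels.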
48 |     overlap = gt + n_gt * pred
49 |     freq, bin_val = np.histogram(overlap, np.arange(0, n_gt * n_pred + 1))  # hist given bins [1, 2, 3] --> returns counts for [1, 2), [2, 3)
50 |     conf_mat = freq.reshape([n_gt, n_pred], order='F')  # column-first reshape, like matlab
51 | 
52 |     acc = np.zeros([n_gt, n_pred])
53 |     for i in range(n_gt):
54 |         for j in range(n_pred):
55 |             gt_i = conf_mat[i].sum()
56 |             pred_j = conf_mat[:, j].sum()
57 |             gt_pred = conf_mat[i, j]
58 |             acc[i, j] = gt_pred / (gt_i + pred_j - gt_pred) if (gt_i + pred_j - gt_pred) != 0 else 0  # IoU between gt label i and pred label j
59 |     return acc[1:, 1:], pred, gt
60 | 
61 | def compact_segm(seg_in):
62 |     seg = seg_in.copy()
63 |     uniq_id = np.unique(seg)
64 |     cnt = 1
65 |     for id in sorted(uniq_id):
66 |         if id == 0:
67 |             continue
68 |         seg[seg == id] = cnt
69 |         cnt += 1
70 | 
71 |     # every id (including non-plane) must be nonzero for the later processing in match_seg
72 |     seg = seg + 1
73 |     return seg
74 | 
75 | def project_to_mesh(from_mesh, to_mesh, attribute, attr_name, color_mesh=None, dist_thresh=None):
76 |     """ Transfers attributes from from_mesh to to_mesh using nearest neighbors
77 | 
78 |     Each vertex in to_mesh gets assigned the attribute of the nearest
79 |     vertex in from_mesh. Used for semantic evaluation.
80 | 
81 |     Args:
82 |         from_mesh: Trimesh with known attributes
83 |         to_mesh: Trimesh to be labeled
84 |         attribute: Which attribute to transfer
85 |         dist_thresh: Do not transfer attributes beyond this distance
86 |             (None transfers regardless of distance between from and to vertices)
87 | 
88 |     Returns:
89 |         Trimesh containing transferred attribute
90 |     """
91 | 
92 |     if len(from_mesh.vertices) == 0:
93 |         to_mesh.vertex_attributes[attr_name] = np.zeros((0), dtype=np.uint8)
94 |         to_mesh.visual.vertex_colors = np.zeros((0), dtype=np.uint8)
95 |         return to_mesh
96 | 
97 |     pcd = o3d.geometry.PointCloud()
98 |     pcd.points = o3d.utility.Vector3dVector(from_mesh.vertices)
99 |     kdtree = o3d.geometry.KDTreeFlann(pcd)
100 | 
101 |     pred_ids = attribute.copy()
102 |     pred_colors = from_mesh.visual.vertex_colors if color_mesh is None else color_mesh.visual.vertex_colors
103 | 
104 |     matched_ids = np.zeros((to_mesh.vertices.shape[0]), dtype=np.uint8)
105 |     matched_colors = np.zeros((to_mesh.vertices.shape[0], 4), dtype=np.uint8)
106 | 
107 |     for i, vert in enumerate(to_mesh.vertices):
108 |         _, inds, dist = kdtree.search_knn_vector_3d(vert, 1)
109 |         if dist_thresh is None or dist[0]= xmin, sample_points[:, 0] <= xmax),
244 |         np.logical_and(sample_points[:, 1] >= ymin, sample_points[:, 1] <= ymax),
245 |         np.logical_and(sample_points[:, 2] >= zmin, sample_points[:, 2] <= zmax),
246 |     )
247 |     points_mask = np.logical_and(points_mask, vertices_mask_new)
248 | 
249 |     sample_points = sample_points[points_mask]
250 |     sample_indices = sample_indices[points_mask]
251 | 
252 |     vertices_eval = trimesh.Trimesh(vertices=sample_points, process=False)
253 |     vertices_eval.export(mesh_file_eval_ori.replace('.obj', '.ply'))
254 |     mesh_file_eval = mesh_file_eval_ori.replace('.obj', '.ply')
255 | 
256 |     # eval 3d geometry
257 |     metrics_mesh, prec_err_pcd, recal_err_pcd = eval_mesh(mesh_file_eval, file_mesh_trgt, error_map=False)
258 |     metrics = {**metrics_mesh}
259 |     # o3d.io.write_triangle_mesh(os.path.join('./','%s_precErr.ply' % scene), prec_err_pcd)
260 |     # o3d.io.write_triangle_mesh(os.path.join('./', '%s_recErr.ply' % scene), recal_err_pcd)
261 | 
262 |     dist1.append(metrics['dist1'])
263 |     dist2.append(metrics['dist2'])
264 |     prec.append(metrics['prec'])
265 |     recall.append(metrics['recal'])
266 | 
fscore.append(metrics['fscore']) 267 | 268 | # prepare files for instance evaluation 269 | mesh_trgt = trimesh.load(file_mesh_trgt, process=False) 270 | 271 | new_pred_ins = np.array(instance_ids)[np.array(sample_indices).astype('int32')].astype('int32') 272 | 273 | # specify color to vertces_eval by color pool with new_pred_ins 274 | color_pool = np.random.rand(32768, 3) * 255 275 | color_pool = np.concatenate([color_pool, np.ones((32768, 1)) * 255], axis=1).astype(np.uint8) 276 | colors = color_pool[new_pred_ins] 277 | vertices_eval.visual.vertex_colors = colors 278 | 279 | mesh_planeIns_transfer = project_to_mesh(vertices_eval, mesh_trgt, new_pred_ins, 'plane_ins') 280 | 281 | planeIns = mesh_planeIns_transfer.vertex_attributes['plane_ins'] 282 | 283 | plnIns_save_pth = os.path.join(save_path, 'plane_ins') 284 | if not os.path.isdir(plnIns_save_pth): 285 | os.makedirs(plnIns_save_pth) 286 | 287 | mesh_planeIns_transfer.export(os.path.join(plnIns_save_pth, '%s_planeIns_transfer.ply' % scene)) 288 | np.savetxt(plnIns_save_pth + '/%s.txt'%scene, planeIns, fmt='%d') 289 | 290 | pred_pth = os.path.join(plnIns_save_pth, '{}.txt'.format(scene)) 291 | gt_pth = os.path.join(f'./planes_9/instance/{scene}.txt') 292 | 293 | pred_ins = np.loadtxt(pred_pth).astype(np.int32) 294 | gt_ins = np.loadtxt(gt_pth).astype(np.int32) 295 | 296 | ri = rand_score(gt_ins, pred_ins) 297 | h1, h2 = variation_of_information(gt_ins, pred_ins) 298 | voi = h1 + h2 299 | sc = compute_sc(gt_ins, pred_ins) 300 | 301 | ris.append(ri) 302 | vois.append(voi) 303 | scs.append(sc) 304 | 305 | return metrics, ri, voi, sc 306 | 307 | 308 | if __name__ == '__main__': 309 | import glob 310 | now_scenes = sorted(glob.glob('./logs/scene????_??')) 311 | flag = False 312 | stats = [] 313 | for scene in tqdm(now_scenes): 314 | scene = scene.split('/')[-1] 315 | metrics, ri, voi, sc = process(scene) 316 | stats.append([metrics['dist1'], metrics['dist2'], metrics['prec'], metrics['recal'], metrics['fscore'], ri, voi, sc]) 317 | print('scene', scene, '\t'.join([f'{k}: {v:.4f}' for k, v in metrics.items()])) 318 | 319 | stats = np.array(stats) 320 | print(f'dist1:\t{np.mean(stats[:, 0]):.3f}') 321 | print(f'dist2:\t{np.mean(stats[:, 1]):.3f}') 322 | print(f'prec:\t{np.mean(stats[:, 2]):.3f}') 323 | print(f'recall:\t{np.mean(stats[:, 3]):.3f}') 324 | print(f'fscore:\t{np.mean(stats[:, 4]):.3f}') 325 | print(f'ri:\t{np.mean(stats[:, 5]):.3f}') 326 | print(f'voi:\t{np.mean(stats[:, 6]):.3f}') 327 | print(f'sc:\t{np.mean(stats[:, 7]):.3f}') 328 | -------------------------------------------------------------------------------- /tools/__init__.py: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/THU-LYJ-Lab/AlphaTablets/735cbfe6aa7f03f7bc37f48303045fe71ffa042a/tools/__init__.py -------------------------------------------------------------------------------- /tools/chamfer3D/chamfer3D.cu: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include 4 | 5 | #include 6 | #include 7 | 8 | #include 9 | 10 | 11 | 12 | __global__ void NmDistanceKernel(int b,int n,const float * xyz,int m,const float * xyz2,float * result,int * result_i){ 13 | const int batch=512; 14 | __shared__ float buf[batch*3]; 15 | for (int i=blockIdx.x;ibest){ 127 | result[(i*n+j)]=best; 128 | result_i[(i*n+j)]=best_i; 129 | } 130 | } 131 | __syncthreads(); 132 | } 133 | } 134 | } 135 | // int chamfer_cuda_forward(int b,int n,const float * xyz,int m,const float * 
xyz2,float * result,int * result_i,float * result2,int * result2_i, cudaStream_t stream){ 136 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2){ 137 | 138 | const auto batch_size = xyz1.size(0); 139 | const auto n = xyz1.size(1); //num_points point cloud A 140 | const auto m = xyz2.size(1); //num_points point cloud B 141 | 142 | NmDistanceKernel<<>>(batch_size, n, xyz1.data(), m, xyz2.data(), dist1.data(), idx1.data()); 143 | NmDistanceKernel<<>>(batch_size, m, xyz2.data(), n, xyz1.data(), dist2.data(), idx2.data()); 144 | 145 | cudaError_t err = cudaGetLastError(); 146 | if (err != cudaSuccess) { 147 | printf("error in nnd updateOutput: %s\n", cudaGetErrorString(err)); 148 | //THError("aborting"); 149 | return 0; 150 | } 151 | return 1; 152 | 153 | 154 | } 155 | __global__ void NmDistanceGradKernel(int b,int n,const float * xyz1,int m,const float * xyz2,const float * grad_dist1,const int * idx1,float * grad_xyz1,float * grad_xyz2){ 156 | for (int i=blockIdx.x;i>>(batch_size,n,xyz1.data(),m,xyz2.data(),graddist1.data(),idx1.data(),gradxyz1.data(),gradxyz2.data()); 185 | NmDistanceGradKernel<<>>(batch_size,m,xyz2.data(),n,xyz1.data(),graddist2.data(),idx2.data(),gradxyz2.data(),gradxyz1.data()); 186 | 187 | cudaError_t err = cudaGetLastError(); 188 | if (err != cudaSuccess) { 189 | printf("error in nnd get grad: %s\n", cudaGetErrorString(err)); 190 | //THError("aborting"); 191 | return 0; 192 | } 193 | return 1; 194 | 195 | } 196 | 197 | -------------------------------------------------------------------------------- /tools/chamfer3D/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | 4 | ///TMP 5 | //#include "common.h" 6 | /// NOT TMP 7 | 8 | 9 | int chamfer_cuda_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2); 10 | 11 | 12 | int chamfer_cuda_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2); 13 | 14 | 15 | 16 | 17 | int chamfer_forward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor dist1, at::Tensor dist2, at::Tensor idx1, at::Tensor idx2) { 18 | return chamfer_cuda_forward(xyz1, xyz2, dist1, dist2, idx1, idx2); 19 | } 20 | 21 | 22 | int chamfer_backward(at::Tensor xyz1, at::Tensor xyz2, at::Tensor gradxyz1, at::Tensor gradxyz2, at::Tensor graddist1, 23 | at::Tensor graddist2, at::Tensor idx1, at::Tensor idx2) { 24 | 25 | return chamfer_cuda_backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2); 26 | } 27 | 28 | 29 | 30 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 31 | m.def("forward", &chamfer_forward, "chamfer forward (CUDA)"); 32 | m.def("backward", &chamfer_backward, "chamfer backward (CUDA)"); 33 | } -------------------------------------------------------------------------------- /tools/chamfer3D/dist_chamfer_3D.py: -------------------------------------------------------------------------------- 1 | from torch import nn 2 | from torch.autograd import Function 3 | import torch 4 | import importlib 5 | import os 6 | chamfer_found = importlib.find_loader("chamfer_3D") is not None 7 | if not chamfer_found: 8 | ## Cool trick from https://github.com/chrdiller 9 | print("Jitting Chamfer 3D") 10 | cur_path = os.path.dirname(os.path.abspath(__file__)) 11 | build_path = cur_path.replace('chamfer3D', 'tmp') 12 | os.makedirs(build_path, 
exist_ok=True) 13 | 14 | from torch.utils.cpp_extension import load 15 | chamfer_3D = load(name="chamfer_3D", 16 | sources=[ 17 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer_cuda.cpp"]), 18 | "/".join(os.path.abspath(__file__).split('/')[:-1] + ["chamfer3D.cu"]), 19 | ], build_directory=build_path) 20 | print("Loaded JIT 3D CUDA chamfer distance") 21 | 22 | else: 23 | import chamfer_3D 24 | print("Loaded compiled 3D CUDA chamfer distance") 25 | 26 | 27 | # Chamfer's distance module @thibaultgroueix 28 | # GPU tensors only 29 | class chamfer_3DFunction(Function): 30 | @staticmethod 31 | def forward(ctx, xyz1, xyz2): 32 | batchsize, n, dim = xyz1.size() 33 | assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()" 34 | _, m, dim = xyz2.size() 35 | assert dim==3, "Wrong last dimension for the chamfer distance 's input! Check with .size()" 36 | device = xyz1.device 37 | 38 | device = xyz1.device 39 | 40 | dist1 = torch.zeros(batchsize, n) 41 | dist2 = torch.zeros(batchsize, m) 42 | 43 | idx1 = torch.zeros(batchsize, n).type(torch.IntTensor) 44 | idx2 = torch.zeros(batchsize, m).type(torch.IntTensor) 45 | 46 | dist1 = dist1.to(device) 47 | dist2 = dist2.to(device) 48 | idx1 = idx1.to(device) 49 | idx2 = idx2.to(device) 50 | torch.cuda.set_device(device) 51 | 52 | chamfer_3D.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 53 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 54 | return dist1, dist2, idx1, idx2 55 | 56 | @staticmethod 57 | def backward(ctx, graddist1, graddist2, gradidx1, gradidx2): 58 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 59 | graddist1 = graddist1.contiguous() 60 | graddist2 = graddist2.contiguous() 61 | device = graddist1.device 62 | 63 | gradxyz1 = torch.zeros(xyz1.size()) 64 | gradxyz2 = torch.zeros(xyz2.size()) 65 | 66 | gradxyz1 = gradxyz1.to(device) 67 | gradxyz2 = gradxyz2.to(device) 68 | chamfer_3D.backward( 69 | xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2 70 | ) 71 | return gradxyz1, gradxyz2 72 | 73 | 74 | class chamfer_3DDist(nn.Module): 75 | def __init__(self): 76 | super(chamfer_3DDist, self).__init__() 77 | 78 | def forward(self, input1, input2): 79 | input1 = input1.contiguous() 80 | input2 = input2.contiguous() 81 | return chamfer_3DFunction.apply(input1, input2) 82 | -------------------------------------------------------------------------------- /tools/chamfer3D/setup.py: -------------------------------------------------------------------------------- 1 | from setuptools import setup 2 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension 3 | 4 | setup( 5 | name='chamfer_3D', 6 | ext_modules=[ 7 | CUDAExtension('chamfer_3D', [ 8 | "/".join(__file__.split('/')[:-1] + ['chamfer_cuda.cpp']), 9 | "/".join(__file__.split('/')[:-1] + ['chamfer3D.cu']), 10 | ]), 11 | ], 12 | cmdclass={ 13 | 'build_ext': BuildExtension 14 | }) -------------------------------------------------------------------------------- /tools/generate_gt.py: -------------------------------------------------------------------------------- 1 | # This file is derived from [NeuralRecon](https://github.com/zju3dv/NeuralRecon). 2 | # Originating Author: Yiming Xie 3 | # Modified for [PlanarRecon](https://github.com/neu-vi/PlanarRecon) by Yiming Xie. 4 | 5 | # Original header: 6 | # Copyright SenseTime. All Rights Reserved. 7 | 8 | # Licensed under the Apache License, Version 2.0 (the "License"); 9 | # you may not use this file except in compliance with the License. 
10 | # You may obtain a copy of the License at 11 | 12 | # http://www.apache.org/licenses/LICENSE-2.0 13 | 14 | # Unless required by applicable law or agreed to in writing, software 15 | # distributed under the License is distributed on an "AS IS" BASIS, 16 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 17 | # See the License for the specific language governing permissions and 18 | # limitations under the License. 19 | 20 | import sys 21 | import os 22 | 23 | import numpy as np 24 | 25 | sys.path.append('.') 26 | 27 | import pickle 28 | import argparse 29 | from tqdm import tqdm 30 | import ray 31 | import torch.multiprocessing 32 | from tools.simple_loader import * 33 | from tools.generate_planes import generate_planes 34 | from tools.chamfer3D import dist_chamfer_3D 35 | 36 | torch.multiprocessing.set_sharing_strategy('file_system') 37 | 38 | 39 | def coordinates(voxel_dim, device=torch.device('cuda')): 40 | """ 3d meshgrid of given size. 41 | 42 | Args: 43 | voxel_dim: tuple of 3 ints (nx,ny,nz) specifying the size of the volume 44 | 45 | Returns: 46 | torch long tensor of size (3,nx*ny*nz) 47 | """ 48 | 49 | nx, ny, nz = voxel_dim 50 | x = torch.arange(nx, dtype=torch.long, device=device) 51 | y = torch.arange(ny, dtype=torch.long, device=device) 52 | z = torch.arange(nz, dtype=torch.long, device=device) 53 | x, y, z = torch.meshgrid(x, y, z) 54 | return torch.stack((x.flatten(), y.flatten(), z.flatten())) 55 | 56 | 57 | def parse_args(): 58 | parser = argparse.ArgumentParser(description='Generate ground truth plane') 59 | parser.add_argument("--dataset", default='scannet') 60 | parser.add_argument("--data_path", metavar="DIR", 61 | help="path to dataset", default='./scannet') 62 | parser.add_argument("--save_name", metavar="DIR", 63 | help="file name", default='planes_9/') 64 | parser.add_argument('--max_depth', default=3., type=float, 65 | help='mask out large depth values since they are noisy') 66 | parser.add_argument('--voxel_size', default=0.04, type=float) 67 | 68 | parser.add_argument('--window_size', default=9, type=int) 69 | parser.add_argument('--min_angle', default=15, type=float) 70 | parser.add_argument('--min_distance', default=0.1, type=float) 71 | 72 | # ray multi processes 73 | parser.add_argument('--n_proc', type=int, default=8, help='#processes launched to process scenes.') 74 | parser.add_argument('--n_gpu', type=int, default=1, help='#number of gpus') 75 | parser.add_argument('--num_workers', type=int, default=8) 76 | parser.add_argument('--loader_num_workers', type=int, default=8) 77 | return parser.parse_args() 78 | 79 | 80 | args = parse_args() 81 | args.save_path = args.save_name 82 | 83 | 84 | def rigid_transform(xyz, transform): 85 | """Applies a rigid transform to an (N, 3) pointcloud. 
86 | """ 87 | xyz_h = np.hstack([xyz, np.ones((len(xyz), 1), dtype=np.float32)]) 88 | xyz_t_h = np.dot(transform, xyz_h.T).T 89 | return xyz_t_h[:, :3] 90 | 91 | 92 | def get_view_frustum(depth_im, cam_intr, cam_pose, max_depth=3.0): 93 | """Get corners of 3D camera view frustum of depth image 94 | """ 95 | if depth_im is not None: 96 | im_h = depth_im.shape[0] 97 | im_w = depth_im.shape[1] 98 | max_depth = np.max(depth_im) 99 | else: 100 | im_h = 480 101 | im_w = 640 102 | view_frust_pts = np.array([ 103 | (np.array([0, 0, 0, im_w, im_w]) - cam_intr[0, 2]) * np.array([0, max_depth, max_depth, max_depth, max_depth]) / 104 | cam_intr[0, 0], 105 | (np.array([0, 0, im_h, 0, im_h]) - cam_intr[1, 2]) * np.array([0, max_depth, max_depth, max_depth, max_depth]) / 106 | cam_intr[1, 1], 107 | np.array([0, max_depth, max_depth, max_depth, max_depth]) 108 | ]) 109 | view_frust_pts = rigid_transform(view_frust_pts.T, cam_pose).T 110 | return view_frust_pts 111 | 112 | 113 | def compute_global_volume(args, cam_intr, cam_pose_list): 114 | # ======================================================================================================== # 115 | # (Optional) This is an example of how to compute the 3D bounds 116 | # in world coordinates of the convex hull of all camera view 117 | # frustums in the dataset 118 | # ======================================================================================================== # 119 | vol_bnds = np.zeros((3, 2)) 120 | 121 | n_imgs = len(cam_pose_list.keys()) 122 | if n_imgs > 200: 123 | ind = np.linspace(0, n_imgs - 1, 200).astype(np.int32) 124 | image_id = np.array(list(cam_pose_list.keys()))[ind] 125 | else: 126 | image_id = cam_pose_list.keys() 127 | for id in image_id: 128 | cam_pose = cam_pose_list[id] 129 | 130 | # Compute camera view frustum and extend convex hull 131 | view_frust_pts = get_view_frustum(None, cam_intr, cam_pose, max_depth=args.max_depth) 132 | vol_bnds[:, 0] = np.minimum(vol_bnds[:, 0], np.amin(view_frust_pts, axis=1)) 133 | vol_bnds[:, 1] = np.maximum(vol_bnds[:, 1], np.amax(view_frust_pts, axis=1)) 134 | # ======================================================================================================== # 135 | 136 | # Adjust volume bounds and ensure C-order contiguous 137 | vol_dim = np.round((vol_bnds[:, 1] - vol_bnds[:, 0]) / args.voxel_size).copy( 138 | order='C').astype(int) 139 | vol_bnds[:, 1] = vol_bnds[:, 0] + vol_dim * args.voxel_size 140 | vol_origin = vol_bnds[:, 0].copy(order='C').astype(np.float32) 141 | 142 | return vol_dim, vol_origin 143 | 144 | 145 | def save_label_full(args, scene, vol_dim, vol_origin, planes, plane_points): 146 | planes = np.concatenate([planes, - np.ones_like(planes[:, :1])], axis=-1) 147 | 148 | # ========================generate indicator gt======================== 149 | planes = torch.from_numpy(planes).cuda() 150 | coords = coordinates(vol_dim, device=planes.device) 151 | coords = coords.type(torch.float) * args.voxel_size + torch.from_numpy(vol_origin).view(3, 1).cuda() 152 | coords = coords.permute(1, 0).contiguous() 153 | 154 | min_dist = None 155 | indices = None 156 | for i, points in enumerate(plane_points[:]): 157 | points = torch.from_numpy(points).cuda() 158 | chamLoss = dist_chamfer_3D.chamfer_3DDist() 159 | dist1, _, _, _ = chamLoss(coords.unsqueeze(0), points.unsqueeze(0)) 160 | if min_dist is None: 161 | min_dist = dist1 162 | indices = torch.zeros_like(dist1) 163 | else: 164 | current_id = torch.ones_like(indices) * i 165 | indices = torch.where(dist1 < min_dist, 
current_id, indices)
166 |             min_dist = torch.where(dist1 < min_dist, dist1, min_dist)
167 |     # remove too far points which may not have a plane
168 |     current_id = torch.ones_like(indices) * -1
169 |     indices = torch.where(0.36 ** 2 < min_dist, current_id, indices)
170 |     indices = indices.view(vol_dim.tolist()).data.cpu().numpy()
171 | 
172 |     np.savez_compressed(os.path.join(args.save_path, scene, 'indices'), indices)
173 |     # ==============================================================================================
174 | 
175 | 
176 | def save_fragment_pkl(args, scene, cam_pose_list, vol_dim, vol_origin):
177 |     # view selection
178 |     fragments = []
179 |     print('segment: process scene {}'.format(scene))
180 | 
181 |     all_ids = []
182 |     ids = []
183 |     count = 0
184 |     last_pose = None
185 |     for id in cam_pose_list.keys():
186 |         cam_pose = cam_pose_list[id]
187 | 
188 |         if count == 0:
189 |             ids.append(id)
190 |             last_pose = cam_pose
191 |             count += 1
192 |         else:
193 |             angle = np.arccos(
194 |                 ((np.linalg.inv(cam_pose[:3, :3]) @ last_pose[:3, :3] @ np.array([0, 0, 1]).T) * np.array(
195 |                     [0, 0, 1])).sum())
196 |             dis = np.linalg.norm(cam_pose[:3, 3] - last_pose[:3, 3])
197 |             if angle > (args.min_angle / 180) * np.pi or dis > args.min_distance:
198 |                 ids.append(id)
199 |                 last_pose = cam_pose
200 |                 count += 1
201 |                 if count == args.window_size:
202 |                     all_ids.append(ids)
203 |                     ids = []
204 |                     count = 0
205 | 
206 |     # save fragments
207 |     for i, ids in enumerate(all_ids):
208 |         fragments.append({
209 |             'scene': scene,
210 |             'fragment_id': i,
211 |             'image_ids': ids,
212 |             'vol_origin': vol_origin,
213 |             'vol_dim': vol_dim,
214 |             'voxel_size': args.voxel_size,
215 |         })
216 | 
217 |     with open(os.path.join(args.save_path, scene, 'fragments.pkl'), 'wb') as f:
218 |         pickle.dump(fragments, f)
219 | 
220 |     return
221 | 
222 | 
223 | @ray.remote(num_cpus=args.num_workers + 1, num_gpus=(1 / args.n_proc))
224 | def process_with_single_worker(args, scannet_files):
225 |     planes_all = []
226 |     for scene in tqdm(scannet_files):
227 |         if os.path.exists(os.path.join(args.save_path, scene, 'fragments.pkl')):
228 |             continue
229 |         print('read from disk')
230 | 
231 |         cam_pose_all = {}
232 | 
233 |         n_imgs = len(os.listdir(os.path.join(args.data_path, scene, 'color')))
234 |         intrinsic_dir = os.path.join(args.data_path, scene, 'intrinsic', 'intrinsic_depth.txt')
235 |         cam_intr = np.loadtxt(intrinsic_dir, delimiter=' ')[:3, :3]
236 |         dataset = ScanNetDataset(n_imgs, scene, args.data_path, args.max_depth)
237 | 
238 |         planes, plane_points = generate_planes(args, scene, save_mesh=True)
239 | 
240 |         dataloader = torch.utils.data.DataLoader(dataset, batch_size=None, collate_fn=collate_fn,
241 |                                                  batch_sampler=None, num_workers=args.loader_num_workers)
242 | 
243 |         for id, (cam_pose, _, _) in enumerate(dataloader):
244 |             if id % 100 == 0:
245 |                 print("{}: read frame {}/{}".format(scene, str(id), str(n_imgs)))
246 | 
247 |             if not np.isfinite(cam_pose[0][0]):  # skip frames with inf/NaN poses
248 |                 continue
249 |             cam_pose_all.update({id: cam_pose})
250 | 
251 |         vol_dim, vol_origin = compute_global_volume(args, cam_intr, cam_pose_all)
252 |         save_label_full(args, scene, vol_dim, vol_origin, planes, plane_points)
253 |         save_fragment_pkl(args, scene, cam_pose_all, vol_dim, vol_origin)
254 | 
255 |     planes_center = np.array([
256 |         [0, -1, 0],
257 |         [0, 1, 0],
258 |         [-1, 0, 0],
259 |         [-1, 0, -1],
260 |         [0, 0, -1],
261 |         [1, 0, -1],
262 |         [1, 0, 1],
263 |     ])
264 |     planes_center = planes_center / np.linalg.norm(planes_center, axis=1)[..., 
np.newaxis] 265 | 266 | np.save(os.path.join(args.save_path, 'normal_anchors.npy'), planes_center) 267 | 268 | 269 | def split_list(_list, n): 270 | assert len(_list) >= n 271 | ret = [[] for _ in range(n)] 272 | for idx, item in enumerate(_list): 273 | ret[idx % n].append(item) 274 | return ret 275 | 276 | 277 | def generate_pkl(args): 278 | all_scenes = sorted(os.listdir(args.save_path)) 279 | # todo: fix for both train/val 280 | splits = ['train', 'val'] 281 | for split in splits: 282 | fragments = [] 283 | with open(os.path.join(args.data_path[:-6],'scannetv2_{}.txt'.format(split))) as f: 284 | split_files = f.readlines() 285 | for scene in all_scenes: 286 | if 'scene' not in scene: 287 | continue 288 | if scene + '\n' in split_files: 289 | with open(os.path.join(args.save_path, scene, 'fragments.pkl'), 'rb') as f: 290 | frag_scene = pickle.load(f) 291 | fragments.extend(frag_scene) 292 | 293 | with open(os.path.join(args.save_path, 'fragments_{}.pkl'.format(split)), 'wb') as f: 294 | pickle.dump(fragments, f) 295 | 296 | 297 | if __name__ == "__main__": 298 | all_proc = args.n_proc * args.n_gpu 299 | 300 | ray.init(num_cpus=all_proc * (args.num_workers + 1), num_gpus=args.n_gpu) 301 | 302 | if args.dataset == 'scannet': 303 | args.data_raw_path = os.path.join(args.data_path, 'scans/') 304 | args.data_path = os.path.join(args.data_path, 'scans') 305 | files = sorted(os.listdir(args.data_path)) 306 | else: 307 | raise NameError('error!') 308 | 309 | files = split_list(files, all_proc) 310 | 311 | ray_worker_ids = [] 312 | for w_idx in range(all_proc): 313 | ray_worker_ids.append(process_with_single_worker.remote(args, files[w_idx])) 314 | 315 | results = ray.get(ray_worker_ids) 316 | 317 | 318 | # process_with_single_worker(args, files) 319 | 320 | if args.dataset == 'scannet': 321 | generate_pkl(args) 322 | -------------------------------------------------------------------------------- /tools/generate_planes.py: -------------------------------------------------------------------------------- 1 | # This file is derived from [PlaneRCNN](https://github.com/NVlabs/planercnn). 2 | # Originating Author: Chen Liu 3 | # Modified for [PlanarRecon](https://github.com/neu-vi/PlanarRecon) by Yiming Xie. 4 | 5 | # Original header: 6 | # Copyright (C) 2019 NVIDIA Corporation. All rights reserved. 7 | # Licensed under the CC BY-NC-SA 4.0 license 8 | # (https://creativecommons.org/licenses/by-nc-sa/4.0/legalcode). 
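# Overview: this script fits plane instances to the annotated ScanNet meshes.
# Mesh segments from the *.aggregation.json / *.segs.json files are fitted with
# planes (fitPlane), greedily merged across neighboring segments (mergePlanes)
# under the fitting-error and angle thresholds defined below, and capped per
# semantic class via labelNumPlanes.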
9 | 
10 | import numpy as np
11 | import sys
12 | import os
13 | from plyfile import PlyData, PlyElement
14 | import json
15 | import glob
16 | import cv2
17 | import matplotlib.pyplot as plt
18 | import torch
19 | from tools.random_color import random_color
20 | 
21 | numPlanes = 200
22 | numPlanesPerSegment = 2
23 | planeAreaThreshold = 100
24 | numIterations = 100
25 | numIterationsPair = 1000
26 | planeDiffThreshold = 0.05
27 | fittingErrorThreshold = planeDiffThreshold
28 | orthogonalThreshold = np.cos(np.deg2rad(60))
29 | parallelThreshold = np.cos(np.deg2rad(30))
30 | 
31 | 
32 | ## Fit a 3D plane from points
33 | def fitPlane(points):
34 |     if points.shape[0] == points.shape[1]:
35 |         return np.linalg.solve(points, np.ones(points.shape[0]))
36 |     else:
37 |         print(points)
38 |         return np.linalg.lstsq(points, np.ones(points.shape[0]), rcond=None)[0]
39 |     return
40 | 
41 | 
42 | class ColorPalette:
43 |     def __init__(self, numColors):
44 |         np.random.seed(2)
45 |         self.colorMap = np.array([[255, 0, 0],
46 |                                   [0, 255, 0],
47 |                                   [0, 0, 255],
48 |                                   [80, 128, 255],
49 |                                   [255, 230, 180],
50 |                                   [255, 0, 255],
51 |                                   [0, 255, 255],
52 |                                   [100, 0, 0],
53 |                                   [0, 100, 0],
54 |                                   [255, 255, 0],
55 |                                   [50, 150, 0],
56 |                                   [200, 255, 255],
57 |                                   [255, 200, 255],
58 |                                   [128, 128, 80],
59 |                                   [0, 50, 128],
60 |                                   [0, 100, 100],
61 |                                   [0, 255, 128],
62 |                                   [0, 128, 255],
63 |                                   [255, 0, 128],
64 |                                   [128, 0, 255],
65 |                                   [255, 128, 0],
66 |                                   [128, 255, 0],
67 |                                   ])
68 | 
69 |         if numColors > self.colorMap.shape[0]:
70 |             self.colorMap = np.concatenate(
71 |                 [self.colorMap, np.random.randint(255, size=(numColors - self.colorMap.shape[0], 3))], axis=0)
72 |             pass
73 | 
74 |         return
75 | 
76 |     def getColorMap(self):
77 |         return self.colorMap
78 | 
79 |     def getColor(self, index):
80 |         if index >= self.colorMap.shape[0]:
81 |             return np.random.randint(255, size=(3))
82 |         else:
83 |             return self.colorMap[index]
84 |         pass
85 | 
86 | 
87 | def loadClassMap(args):
88 |     classMap = {}
89 |     classLabelMap = {}
90 |     with open(args.data_raw_path[:-6] + '/scannetv2-labels.combined.tsv') as info_file:
91 |         line_index = 0
92 |         for line in info_file:
93 |             if line_index > 0:
94 |                 line = line.split('\t')
95 | 
96 |                 key = line[1].strip()
97 |                 classMap[key] = line[7].strip()
98 |                 classMap[key + 's'] = line[7].strip()
99 |                 classMap[key + 'es'] = line[7].strip()
100 |                 classMap[key[:-1] + 'ves'] = line[7].strip()
101 | 
102 |                 if line[4].strip() != '':
103 |                     nyuLabel = int(line[4].strip())
104 |                 else:
105 |                     nyuLabel = -1
106 |                     pass
107 |                 classLabelMap[key] = [nyuLabel, line_index - 1]
108 |                 classLabelMap[key + 's'] = [nyuLabel, line_index - 1]
109 |                 classLabelMap[key[:-1] + 'ves'] = [nyuLabel, line_index - 1]
110 |                 pass
111 |             line_index += 1
112 |             continue
113 |         pass
114 |     return classMap, classLabelMap
115 | 
116 | 
117 | def writePointCloudFace(filename, points, faces):
118 |     with open(filename, 'w') as f:
119 |         header = """ply
120 | format ascii 1.0
121 | element vertex """
122 |         header += str(len(points))
123 |         header += """
124 | property float x
125 | property float y
126 | property float z
127 | property uchar red
128 | property uchar green
129 | property uchar blue
130 | element face """
131 |         header += str(len(faces))
132 |         header += """
133 | property list uchar int vertex_index
134 | end_header
135 | """
136 |         f.write(header)
137 |         for point in points:
138 |             for value in point[:3]:
139 |                 f.write(str(value) + ' ')
140 |                 continue
141 |             for value in point[3:]:
142 |                 f.write(str(int(value)) + ' ')
143 |                 continue
144 |             f.write('\n')
145 |             continue
146 |         for face in faces:
147 |             f.write('3 ' + 
str(face[0]) + ' ' + str(face[1]) + ' ' + str(face[2]) + '\n') 148 | continue 149 | f.close() 150 | pass 151 | return 152 | 153 | 154 | def mergePlanes(points, planes, planePointIndices, planeSegments, segmentNeighbors, numPlanes, debug=False): 155 | planeFittingErrors = [] 156 | for plane, pointIndices in zip(planes, planePointIndices): 157 | XYZ = points[pointIndices] 158 | planeNorm = np.linalg.norm(plane) 159 | if planeNorm == 0: 160 | planeFittingErrors.append(fittingErrorThreshold * 2) 161 | continue 162 | diff = np.abs(np.matmul(XYZ, plane) - np.ones(XYZ.shape[0])) / planeNorm 163 | planeFittingErrors.append(diff.mean()) 164 | continue 165 | 166 | planeList = list(zip(planes, planePointIndices, planeSegments, planeFittingErrors)) 167 | planeList = sorted(planeList, key=lambda x: x[3]) 168 | 169 | while len(planeList) > 0: 170 | hasChange = False 171 | planeIndex = 0 172 | 173 | if debug: 174 | for index, planeInfo in enumerate(sorted(planeList, key=lambda x: -len(x[1]))): 175 | print(index, planeInfo[0] / np.linalg.norm(planeInfo[0]), planeInfo[2], planeInfo[3]) 176 | continue 177 | pass 178 | 179 | while planeIndex < len(planeList): 180 | plane, pointIndices, segments, fittingError = planeList[planeIndex] 181 | if fittingError > fittingErrorThreshold: 182 | break 183 | neighborSegments = [] 184 | for segment in segments: 185 | if segment in segmentNeighbors: 186 | neighborSegments += segmentNeighbors[segment] 187 | pass 188 | continue 189 | neighborSegments += list(segments) 190 | neighborSegments = set(neighborSegments) 191 | bestNeighborPlane = (fittingErrorThreshold, -1, None) 192 | for neighborPlaneIndex, neighborPlane in enumerate(planeList): 193 | if neighborPlaneIndex <= planeIndex: 194 | continue 195 | if not bool(neighborSegments & neighborPlane[2]): 196 | continue 197 | neighborPlaneNorm = np.linalg.norm(neighborPlane[0]) 198 | if neighborPlaneNorm < 1e-4: 199 | continue 200 | dotProduct = np.abs( 201 | np.dot(neighborPlane[0], plane) / np.maximum(neighborPlaneNorm * np.linalg.norm(plane), 1e-4)) 202 | if dotProduct < orthogonalThreshold: 203 | continue 204 | newPointIndices = np.concatenate([neighborPlane[1], pointIndices], axis=0) 205 | XYZ = points[newPointIndices] 206 | if dotProduct > parallelThreshold and len(neighborPlane[1]) > len(pointIndices) * 0.5: 207 | newPlane = fitPlane(XYZ) 208 | else: 209 | newPlane = plane 210 | pass 211 | diff = np.abs(np.matmul(XYZ, newPlane) - np.ones(XYZ.shape[0])) / np.linalg.norm(newPlane) 212 | newFittingError = diff.mean() 213 | if debug: 214 | print( 215 | len(planeList), planeIndex, neighborPlaneIndex, newFittingError, plane / np.linalg.norm(plane), 216 | neighborPlane[ 217 | 0] / np.linalg.norm( 218 | neighborPlane[0]), 219 | dotProduct, orthogonalThreshold) 220 | pass 221 | if newFittingError < bestNeighborPlane[0]: 222 | newPlaneInfo = [newPlane, newPointIndices, segments.union(neighborPlane[2]), newFittingError] 223 | bestNeighborPlane = (newFittingError, neighborPlaneIndex, newPlaneInfo) 224 | pass 225 | continue 226 | if bestNeighborPlane[1] != -1: 227 | newPlaneList = planeList[:planeIndex] + planeList[planeIndex + 1:bestNeighborPlane[1]] + planeList[ 228 | bestNeighborPlane[ 229 | 1] + 1:] 230 | newFittingError, newPlaneIndex, newPlane = bestNeighborPlane 231 | for newPlaneIndex in range(len(newPlaneList)): 232 | if (newPlaneIndex == 0 and newPlaneList[newPlaneIndex][3] > newFittingError) \ 233 | or newPlaneIndex == len(newPlaneList) - 1 \ 234 | or (newPlaneList[newPlaneIndex][3] < newFittingError and 
newPlaneList[newPlaneIndex + 1][
235 |                             3] > newFittingError):
236 |                         newPlaneList.insert(newPlaneIndex, newPlane)
237 |                         break
238 |                     continue
239 |                 if len(newPlaneList) == 0:
240 |                     newPlaneList = [newPlane]
241 |                     pass
242 |                 planeList = newPlaneList
243 |                 hasChange = True
244 |             else:
245 |                 planeIndex += 1
246 |                 pass
247 |             continue
248 |         if not hasChange:
249 |             break
250 |         continue
251 | 
252 |     planeList = sorted(planeList, key=lambda x: -len(x[1]))
253 | 
254 |     minNumPlanes, maxNumPlanes = numPlanes
255 |     if minNumPlanes == 1 and len(planeList) == 0:
256 |         if debug:
257 |             print('at least one plane')
258 |             pass
259 |     elif len(planeList) > maxNumPlanes:
260 |         if debug:
261 |             print('too many planes', len(planeList), maxNumPlanes)
262 |             pass
263 |         planeList = planeList[:maxNumPlanes] + [(np.zeros(3), planeInfo[1], planeInfo[2], fittingErrorThreshold) for
264 |                                                 planeInfo in planeList[maxNumPlanes:]]
265 |         pass
266 | 
267 |     groupedPlanes, groupedPlanePointIndices, groupedPlaneSegments, groupedPlaneFittingErrors = zip(*planeList)
268 |     return groupedPlanes, groupedPlanePointIndices, groupedPlaneSegments
269 | 
270 | 
271 | def furthest_point_sampling(points, N=100):
272 |     from sklearn.metrics import pairwise_distances
273 |     D = pairwise_distances(points, metric='euclidean')
274 |     # By default, takes the first point in the list to be the
275 |     # first point in the permutation, but could be random
276 |     perm = np.zeros(N, dtype=np.int64)
277 |     lambdas = np.zeros(N)
278 |     ds = D[0, :]
279 |     for i in range(1, N):
280 |         idx = np.argmax(ds)
281 |         perm[i] = idx
282 |         lambdas[i] = ds[idx]
283 |         ds = np.minimum(ds, D[idx, :])
284 |     return (perm, lambdas)
285 | 
286 | 
287 | def writePointFacePlane(filename, points, faces):
288 |     with open(filename, 'w') as f:
289 |         header = """ply
290 | format ascii 1.0
291 | element vertex """
292 |         header += str(len(points))
293 |         header += """
294 | property float x
295 | property float y
296 | property float z
297 | property uchar red
298 | property uchar green
299 | property uchar blue
300 | element face """
301 |         header += str(len(faces))
302 |         header += """
303 | property list uchar int vertex_index
304 | end_header
305 | """
306 |         f.write(header)
307 |         for point in points:
308 |             for value in point[:3]:
309 |                 f.write(str(value) + ' ')
310 |                 continue
311 |             for value in point[3:]:
312 |                 f.write(str(int(value)) + ' ')
313 |                 continue
314 |             f.write('\n')
315 |             continue
316 |         for face in faces:
317 |             for value in face:
318 |                 f.write(str(value) + ' ')
319 |                 continue
320 |             f.write('\n')
321 |             continue
322 |         f.close()
323 |         pass
324 |     return
325 | 
326 | 
327 | def project2plane(plane, plane_points):
328 |     A, B = plane_points[0], plane_points[1]
329 |     AB = B - A
330 |     N = plane / np.linalg.norm(plane, ord=2)
331 |     U = AB / np.linalg.norm(AB, ord=2)
332 |     V = np.cross(U, N)
333 |     u = A + U
334 |     v = A + V
335 |     n = A + N
336 |     S = [[A[0], u[0], v[0], n[0]], [A[1], u[1], v[1], n[1]], [A[2], u[2], v[2], n[2]], [1, 1, 1, 1]]
337 |     D = [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [1, 1, 1, 1]]
338 |     M = np.matmul(D, np.linalg.inv(S))
339 |     return M
340 | 
341 | 
342 | def points2contours(points, size=0.1):
343 |     points = points / size
344 |     points = points.astype(int)  # np.int was removed in NumPy 1.24
345 |     min_point = points.min(axis=0) - 5
346 |     max_point = points.max(axis=0) + 5
347 |     image_size = max_point - min_point + 1
348 |     points = points - min_point
349 |     image = np.zeros(image_size).astype(np.uint8)
350 |     image[points[:, 0], points[:, 1]] = 255
351 |     # plt.imshow(image, cmap='gray')
352 |     # plt.show()
contours, hierarchy = cv2.findContours(image, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) 354 | contours_valid = [] 355 | for i, con in enumerate(contours): 356 | if con.shape[0] > 10: 357 | con = (con + min_point) * size 358 | contours_valid.append(con) 359 | return contours_valid 360 | 361 | 362 | def generate_planes(args, scene_id, high_res=False, save_mesh=False, debug=False): 363 | if not os.path.exists(args.save_path + '/' + scene_id + '/annotation'): 364 | os.system('mkdir -p ' + args.save_path + '/' + scene_id + '/annotation') 365 | 366 | filename = args.data_raw_path + scene_id + '/' + scene_id + '.aggregation.json' 367 | data = json.load(open(filename, 'r')) 368 | aggregation = np.array(data['segGroups']) 369 | 370 | if high_res: 371 | filename = args.data_raw_path + scene_id + '/' + scene_id + '_vh_clean.labels.ply' 372 | else: 373 | filename = args.data_raw_path + scene_id + '/' + scene_id + '_vh_clean_2.labels.ply' 374 | 375 | plydata = PlyData.read(filename) 376 | vertices = plydata['vertex'] 377 | points = np.stack([vertices['x'], vertices['y'], vertices['z']], axis=1) 378 | faces = np.array(plydata['face']['vertex_indices']) 379 | 380 | # semanticSegmentation = vertices['label'] 381 | 382 | if high_res: 383 | filename = args.data_raw_path + scene_id + '/' + scene_id + '_vh_clean.segs.json' 384 | else: 385 | filename = args.data_raw_path + scene_id + '/' + scene_id + '_vh_clean_2.0.010000.segs.json' 386 | 387 | data = json.load(open(filename, 'r')) 388 | segmentation = np.array(data['segIndices']) 389 | 390 | groupSegments = [] 391 | groupLabels = [] 392 | for segmentIndex in range(len(aggregation)): 393 | groupSegments.append(aggregation[segmentIndex]['segments']) 394 | groupLabels.append(aggregation[segmentIndex]['label']) 395 | 396 | segmentation = segmentation.astype(np.int32) 397 | 398 | uniqueSegments = np.unique(segmentation).tolist() 399 | numSegments = 0 400 | for segments in groupSegments: 401 | for segmentIndex in segments: 402 | if segmentIndex in uniqueSegments: 403 | uniqueSegments.remove(segmentIndex) 404 | numSegments += len(segments) 405 | 406 | for segment in uniqueSegments: 407 | groupSegments.append([segment, ]) 408 | groupLabels.append('unannotated') 409 | 410 | segmentEdges = [] 411 | for faceIndex in range(faces.shape[0]): 412 | face = faces[faceIndex] 413 | segment_1 = segmentation[face[0]] 414 | segment_2 = segmentation[face[1]] 415 | segment_3 = segmentation[face[2]] 416 | if segment_1 != segment_2 or segment_1 != segment_3: 417 | if segment_1 != segment_2 and segment_1 != -1 and segment_2 != -1: 418 | segmentEdges.append((min(segment_1, segment_2), max(segment_1, segment_2))) 419 | if segment_1 != segment_3 and segment_1 != -1 and segment_3 != -1: 420 | segmentEdges.append((min(segment_1, segment_3), max(segment_1, segment_3))) 421 | if segment_2 != segment_3 and segment_2 != -1 and segment_3 != -1: 422 | segmentEdges.append((min(segment_2, segment_3), max(segment_2, segment_3))) 423 | segmentEdges = list(set(segmentEdges)) 424 | 425 | labelNumPlanes = {'wall': [1, 3], 426 | 'floor': [1, 1], 427 | 'cabinet': [0, 5], 428 | 'bed': [0, 5], 429 | 'chair': [0, 5], 430 | 'sofa': [0, 10], 431 | 'table': [0, 5], 432 | 'door': [1, 2], 433 | 'window': [0, 2], 434 | 'bookshelf': [0, 5], 435 | 'picture': [1, 1], 436 | 'counter': [0, 10], 437 | 'blinds': [0, 0], 438 | 'desk': [0, 10], 439 | 'shelf': [0, 5], 440 | 'shelves': [0, 5], 441 | 'curtain': [0, 0], 442 | 'dresser': [0, 5], 443 | 'pillow': [0, 0], 444 | 'mirror': [0, 0], 445 | 'entrance': [1, 1], 446 | 
'floor mat': [1, 1], 447 | 'clothes': [0, 0], 448 | 'ceiling': [0, 5], 449 | 'book': [0, 1], 450 | 'books': [0, 1], 451 | 'refridgerator': [0, 5], 452 | 'television': [1, 1], 453 | 'paper': [0, 1], 454 | 'towel': [0, 1], 455 | 'shower curtain': [0, 1], 456 | 'box': [0, 5], 457 | 'whiteboard': [1, 5], 458 | 'person': [0, 0], 459 | 'night stand': [1, 5], 460 | 'toilet': [0, 5], 461 | 'sink': [0, 5], 462 | 'lamp': [0, 1], 463 | 'bathtub': [0, 5], 464 | 'bag': [0, 1], 465 | 'otherprop': [0, 5], 466 | 'otherstructure': [0, 5], 467 | 'otherfurniture': [0, 5], 468 | 'unannotated': [0, 5], 469 | '': [0, 0], 470 | } 471 | nonPlanarGroupLabels = ['bicycle', 'bottle', 'water bottle'] 472 | nonPlanarGroupLabels = {label: True for label in nonPlanarGroupLabels} 473 | 474 | # verticalLabels = ['wall', 'door', 'cabinet'] 475 | classMap, classLabelMap = loadClassMap(args) 476 | classMap['unannotated'] = 'unannotated' 477 | classLabelMap['unannotated'] = [max([index for index, label in classLabelMap.values()]) + 1, 41] 478 | allXYZ = points.reshape(-1, 3) 479 | 480 | segmentNeighbors = {} 481 | for segmentEdge in segmentEdges: 482 | if segmentEdge[0] not in segmentNeighbors: 483 | segmentNeighbors[segmentEdge[0]] = [] 484 | segmentNeighbors[segmentEdge[0]].append(segmentEdge[1]) 485 | 486 | if segmentEdge[1] not in segmentNeighbors: 487 | segmentNeighbors[segmentEdge[1]] = [] 488 | segmentNeighbors[segmentEdge[1]].append(segmentEdge[0]) 489 | 490 | planeGroups = [] 491 | print('num groups', len(groupSegments)) 492 | 493 | debugIndex = -1 494 | 495 | for groupIndex, group in enumerate(groupSegments): 496 | if debugIndex != -1 and groupIndex != debugIndex: 497 | continue 498 | if groupLabels[groupIndex] in nonPlanarGroupLabels: 499 | groupLabel = groupLabels[groupIndex] 500 | minNumPlanes, maxNumPlanes = 0, 0 501 | elif groupLabels[groupIndex] in classMap: 502 | groupLabel = classMap[groupLabels[groupIndex]] 503 | minNumPlanes, maxNumPlanes = labelNumPlanes[groupLabel] 504 | else: 505 | minNumPlanes, maxNumPlanes = 0, 0 506 | groupLabel = '' 507 | 508 | if maxNumPlanes == 0: 509 | pointMasks = [] 510 | for segmentIndex in group: 511 | pointMasks.append(segmentation == segmentIndex) 512 | pointIndices = np.any(np.stack(pointMasks, 0), 0).nonzero()[0] 513 | groupPlanes = [[np.zeros(3), pointIndices, []]] 514 | planeGroups.append(groupPlanes) 515 | continue 516 | groupPlanes = [] 517 | groupPlanePointIndices = [] 518 | groupPlaneSegments = [] 519 | for segmentIndex in group: 520 | segmentMask = segmentation == segmentIndex 521 | allSegmentIndices = segmentMask.nonzero()[0] 522 | segmentIndices = allSegmentIndices.copy() 523 | 524 | XYZ = allXYZ[segmentMask.reshape(-1)] 525 | numPoints = XYZ.shape[0] 526 | 527 | for c in range(2): 528 | if c == 0: 529 | ## First try to fit one plane 530 | plane = fitPlane(XYZ) 531 | diff = np.abs(np.matmul(XYZ, plane) - np.ones(XYZ.shape[0])) / np.linalg.norm(plane) 532 | if diff.mean() < fittingErrorThreshold: 533 | groupPlanes.append(plane) 534 | groupPlanePointIndices.append(segmentIndices) 535 | groupPlaneSegments.append(set([segmentIndex])) 536 | break 537 | else: 538 | ## Run ransac 539 | segmentPlanes = [] 540 | segmentPlanePointIndices = [] 541 | 542 | for planeIndex in range(numPlanesPerSegment): 543 | if len(XYZ) < planeAreaThreshold: 544 | continue 545 | bestPlaneInfo = [None, 0, None] 546 | for iteration in range(min(XYZ.shape[0], numIterations)): 547 | sampledPoints = XYZ[np.random.choice(np.arange(XYZ.shape[0]), size=(3), replace=False)] 548 | try: 549 | plane = 
fitPlane(sampledPoints) 550 | except Exception:  # fitPlane can fail on degenerate (e.g. collinear) samples 551 | continue 552 | diff = np.abs(np.matmul(XYZ, plane) - np.ones(XYZ.shape[0])) / np.linalg.norm(plane) 553 | inlierMask = diff < planeDiffThreshold 554 | numInliers = inlierMask.sum() 555 | if numInliers > bestPlaneInfo[1]: 556 | bestPlaneInfo = [plane, numInliers, inlierMask] 557 | 558 | if bestPlaneInfo[1] < planeAreaThreshold: 559 | break 560 | 561 | pointIndices = segmentIndices[bestPlaneInfo[2]] 562 | bestPlane = fitPlane(XYZ[bestPlaneInfo[2]]) 563 | 564 | segmentPlanes.append(bestPlane) 565 | segmentPlanePointIndices.append(pointIndices) 566 | 567 | outlierMask = np.logical_not(bestPlaneInfo[2]) 568 | segmentIndices = segmentIndices[outlierMask] 569 | XYZ = XYZ[outlierMask] 570 | 571 | if sum([len(indices) for indices in segmentPlanePointIndices]) < numPoints * 0.5: 572 | groupPlanes.append(np.zeros(3)) 573 | groupPlanePointIndices.append(allSegmentIndices) 574 | groupPlaneSegments.append(set([segmentIndex])) 575 | else: 576 | if len(segmentIndices) > 0: 577 | ## Add remaining non-planar regions 578 | segmentPlanes.append(np.zeros(3)) 579 | segmentPlanePointIndices.append(segmentIndices) 580 | groupPlanes += segmentPlanes 581 | groupPlanePointIndices += segmentPlanePointIndices 582 | 583 | for _ in range(len(segmentPlanes)): 584 | groupPlaneSegments.append(set([segmentIndex])) 585 | 586 | numRealPlanes = len([plane for plane in groupPlanes if np.linalg.norm(plane) > 1e-4]) 587 | if minNumPlanes == 1 and numRealPlanes == 0: 588 | ## Some instances always contain at least one plane (e.g., the floor) 589 | maxArea = (planeAreaThreshold, -1) 590 | for index, indices in enumerate(groupPlanePointIndices): 591 | if len(indices) > maxArea[0]: 592 | maxArea = (len(indices), index) 593 | maxArea, planeIndex = maxArea 594 | if planeIndex >= 0: 595 | groupPlanes[planeIndex] = fitPlane(allXYZ[groupPlanePointIndices[planeIndex]]) 596 | numRealPlanes = 1 597 | if minNumPlanes == 1 and maxNumPlanes == 1 and numRealPlanes > 1: 598 | ## Some instances always contain at most one plane (e.g., the floor) 599 | 600 | pointIndices = np.concatenate( 601 | [indices for plane, indices in list(zip(groupPlanes, groupPlanePointIndices))], 602 | axis=0) 603 | XYZ = allXYZ[pointIndices] 604 | plane = fitPlane(XYZ) 605 | diff = np.abs(np.matmul(XYZ, plane) - np.ones(XYZ.shape[0])) / np.linalg.norm(plane) 606 | 607 | if groupLabel == 'floor': 608 | ## Relax the constraint for the floor due to the misalignment issue in ScanNet 609 | fittingErrorScale = 3 610 | else: 611 | fittingErrorScale = 1 612 | 613 | if diff.mean() < fittingErrorThreshold * fittingErrorScale: 614 | groupPlanes = [plane] 615 | groupPlanePointIndices = [pointIndices] 616 | planeSegments = [] 617 | for segments in groupPlaneSegments: 618 | planeSegments += list(segments) 619 | groupPlaneSegments = [set(planeSegments)] 620 | numRealPlanes = 1 621 | 622 | if numRealPlanes > 1: 623 | groupPlanes, groupPlanePointIndices, groupPlaneSegments = mergePlanes(points, groupPlanes, 624 | groupPlanePointIndices, 625 | groupPlaneSegments, segmentNeighbors, 626 | numPlanes=( 627 | minNumPlanes, maxNumPlanes), 628 | debug=debugIndex != -1) 629 | 630 | groupNeighbors = [] 631 | for planeIndex, planeSegments in enumerate(groupPlaneSegments): 632 | neighborSegments = [] 633 | for segment in planeSegments: 634 | if segment in segmentNeighbors: 635 | neighborSegments += segmentNeighbors[segment] 636 | neighborSegments += list(planeSegments) 637 | neighborSegments = set(neighborSegments) 638 | neighborPlaneIndices = [] 639 | for 
neighborPlaneIndex, neighborPlaneSegments in enumerate(groupPlaneSegments): 640 | if neighborPlaneIndex == planeIndex: 641 | continue 642 | if bool(neighborSegments & neighborPlaneSegments): 643 | plane = groupPlanes[planeIndex] 644 | neighborPlane = groupPlanes[neighborPlaneIndex] 645 | if np.linalg.norm(plane) * np.linalg.norm(neighborPlane) < 1e-4: 646 | continue 647 | dotProduct = np.abs( 648 | np.dot(plane, neighborPlane) / np.maximum(np.linalg.norm(plane) * np.linalg.norm(neighborPlane), 649 | 1e-4)) 650 | if dotProduct < orthogonalThreshold: 651 | neighborPlaneIndices.append(neighborPlaneIndex) 652 | groupNeighbors.append(neighborPlaneIndices) 653 | groupPlanes = list(zip(groupPlanes, groupPlanePointIndices, groupNeighbors)) 654 | # groupPlanes = zip(groupPlanes, groupPlanePointIndices, groupNeighbors) 655 | planeGroups.append(groupPlanes) 656 | 657 | if debug: 658 | colorMap = ColorPalette(segmentation.max() + 2).getColorMap() 659 | colorMap[-1] = 0 660 | colorMap[-2] = 255 661 | annotationFolder = 'test/' 662 | else: 663 | numPlanes = sum([len(group) for group in planeGroups]) 664 | segmentationColor = (np.arange(numPlanes + 1) + 1) * 100 665 | colorMap = np.stack([segmentationColor / (256 * 256), segmentationColor / 256 % 256, segmentationColor % 256], 666 | axis=1) 667 | colorMap[-1] = 0 668 | annotationFolder = args.save_path + scene_id + '/annotation' 669 | 670 | if debug: 671 | colors = colorMap[segmentation] 672 | writePointCloudFace(annotationFolder + '/segments.ply', np.concatenate([points, colors], axis=-1), faces) 673 | 674 | groupedSegmentation = np.full(segmentation.shape, fill_value=-1) 675 | for segmentIndex in range(len(aggregation)): 676 | indices = aggregation[segmentIndex]['segments'] 677 | for index in indices: 678 | groupedSegmentation[segmentation == index] = segmentIndex 679 | groupedSegmentation = groupedSegmentation.astype(np.int32) 680 | colors = colorMap[groupedSegmentation] 681 | writePointCloudFace(annotationFolder + '/groups.ply', np.concatenate([points, colors], axis=-1), faces) 682 | 683 | planes = [] 684 | planePointIndices = [] 685 | planeInfo = [] 686 | structureIndex = 0 687 | for index, group in enumerate(planeGroups): 688 | groupPlanes, groupPlanePointIndices, groupNeighbors = zip(*group) 689 | 690 | diag = np.diag(np.ones(len(groupNeighbors))) 691 | adjacencyMatrix = diag.copy() 692 | for groupIndex, neighbors in enumerate(groupNeighbors): 693 | for neighbor in neighbors: 694 | adjacencyMatrix[groupIndex][neighbor] = 1 695 | if groupLabels[index] in classLabelMap: 696 | label = classLabelMap[groupLabels[index]] 697 | else: 698 | print('label not valid', groupLabels[index]) 699 | exit(1) 700 | label = -1 701 | groupInfo = [[(index, label[0], label[1])] for _ in range(len(groupPlanes))] 702 | groupPlaneIndices = (adjacencyMatrix.sum(-1) >= 2).nonzero()[0] 703 | usedMask = {} 704 | for groupPlaneIndex in groupPlaneIndices: 705 | if groupPlaneIndex in usedMask: 706 | continue 707 | groupStructure = adjacencyMatrix[groupPlaneIndex].copy() 708 | for neighbor in groupStructure.nonzero()[0]: 709 | if np.any(adjacencyMatrix[neighbor] < groupStructure): 710 | groupStructure[neighbor] = 0 711 | groupStructure = groupStructure.nonzero()[0] 712 | 713 | if len(groupStructure) < 2: 714 | print('invalid structure') 715 | print(groupPlaneIndex, groupPlaneIndices) 716 | print(groupNeighbors) 717 | print(groupPlaneIndex) 718 | print(adjacencyMatrix.sum(-1) >= 2) 719 | print((adjacencyMatrix.sum(-1) >= 2).nonzero()[0]) 720 | 
print(adjacencyMatrix[groupPlaneIndex]) 721 | print(adjacencyMatrix) 722 | print(groupStructure) 723 | exit(1) 724 | if len(groupStructure) >= 4: 725 | print('complex structure') 726 | print('group index', index) 727 | print(adjacencyMatrix) 728 | print(groupStructure) 729 | groupStructure = groupStructure[:3] 730 | if len(groupStructure) in [2, 3]: 731 | for planeIndex in groupStructure: 732 | groupInfo[planeIndex].append((structureIndex, len(groupStructure))) 733 | structureIndex += 1 734 | for planeIndex in groupStructure: 735 | usedMask[planeIndex] = True 736 | planes += groupPlanes 737 | planePointIndices += groupPlanePointIndices 738 | planeInfo += groupInfo 739 | 740 | planeSegmentation = np.full(segmentation.shape, fill_value=-1, dtype=np.int32) 741 | for planeIndex, planePoints in enumerate(planePointIndices): 742 | planeSegmentation[planePoints] = planeIndex 743 | 744 | # generate planar 745 | if save_mesh: 746 | import copy 747 | color_vis = random_color() 748 | points_plane = [] 749 | points_tensor = torch.Tensor(points).cuda() 750 | faces_copy = copy.deepcopy(faces) 751 | faces_copy = torch.Tensor(np.stack(faces_copy)).cuda().int() 752 | planes_tensor = torch.zeros_like(points_tensor) 753 | indices_tensor = torch.zeros_like(points_tensor[:, 0]).int() 754 | for i in range(len(planes)): 755 | planes_tensor[planePointIndices[i]] = torch.Tensor(planes[i]).cuda() 756 | indices_tensor[planePointIndices[i]] = i 757 | 758 | valid = (planes_tensor != 0).any(-1) 759 | invalid_ind = torch.nonzero(valid == 0, as_tuple=False).squeeze(1) 760 | planes_tensor_valid = planes_tensor[valid] 761 | points_tensor_valid = points_tensor[valid] 762 | t = ((points_tensor_valid.unsqueeze(1) @ planes_tensor_valid.unsqueeze(-1)).squeeze() - 1) / (planes_tensor_valid[:, 0] ** 2 + planes_tensor_valid[:, 1] ** 2 + planes_tensor_valid[:, 2] ** 2) 763 | plane_points = points_tensor_valid - planes_tensor_valid[:, :3] * t.unsqueeze(-1) 764 | points_tensor[valid] = plane_points 765 | points_tensor[invalid_ind] = plane_points[0] 766 | 767 | n = 100 768 | part_num = faces_copy.shape[0] // n 769 | match_list = [] 770 | for i in range(n): 771 | if i == n-1: 772 | faces_part = faces_copy[i * part_num:] 773 | else: 774 | faces_part = faces_copy[i * part_num: (i + 1) * part_num] 775 | match = (faces_part.unsqueeze(0) != invalid_ind.unsqueeze(-1).unsqueeze(-1)) 776 | match = match.all(-1) 777 | match = match.all(0) 778 | match_list.append(match) 779 | match = torch.cat(match_list) 780 | faces_copy = faces_copy[match] 781 | points_plane = points_tensor.data.cpu().numpy() 782 | 783 | n_ins = indices_tensor.data.cpu().numpy().max() + 1 784 | indices_tensor = indices_tensor.data.cpu().numpy() 785 | segmentationColor = (np.arange(n_ins + 1) + 1) * 100 786 | colorMap = np.stack([segmentationColor / (256 * 256), segmentationColor / 256 % 256, segmentationColor % 256], 787 | axis=1) 788 | colorMap[-1] = 0 789 | plane_colors = colorMap[indices_tensor] 790 | 791 | # for vis 792 | colorMap_vis = color_vis(n_ins) 793 | plane_colors_vis = colorMap_vis[indices_tensor] 794 | 795 | writePointCloudFace(annotationFolder + '/planes_mesh.ply', 796 | np.concatenate([points_plane, plane_colors], axis=-1), faces_copy.data.cpu().numpy()) 797 | 798 | writePointCloudFace(annotationFolder + '/planes_mesh_vis.ply', 799 | np.concatenate([points_plane, plane_colors_vis], axis=-1), faces_copy.data.cpu().numpy()) 800 | 801 | if debug: 802 | groupSegmentation = np.full(segmentation.shape, fill_value=-1, dtype=np.int32) 803 | structureSegmentation = 
np.full(segmentation.shape, fill_value=-1, dtype=np.int32) 804 | typeSegmentation = np.full(segmentation.shape, fill_value=-1, dtype=np.int32) 805 | for planeIndex, planePoints in enumerate(planePointIndices): 806 | if len(planeInfo[planeIndex]) > 1: 807 | structureSegmentation[planePoints] = planeInfo[planeIndex][1][0] 808 | typeSegmentation[planePoints] = np.maximum(typeSegmentation[planePoints], 809 | planeInfo[planeIndex][1][1] - 2) 810 | groupSegmentation[planePoints] = planeInfo[planeIndex][0][0] 811 | 812 | colors = colorMap[groupSegmentation] 813 | writePointCloudFace(annotationFolder + '/group.ply', np.concatenate([points, colors], axis=-1), faces) 814 | 815 | colors = colorMap[structureSegmentation] 816 | writePointCloudFace(annotationFolder + '/structure.ply', np.concatenate([points, colors], axis=-1), faces) 817 | 818 | colors = colorMap[typeSegmentation] 819 | writePointCloudFace(annotationFolder + '/type.ply', np.concatenate([points, colors], axis=-1), faces) 820 | 821 | planes = np.array(planes) 822 | print('number of planes: ', planes.shape[0]) 823 | # planesD = 1.0 / np.maximum(np.linalg.norm(planes, axis=-1, keepdims=True), 1e-4) 824 | # planes *= pow(planesD, 2) 825 | 826 | if debug: 827 | print(len(planes), len(planeInfo)) 828 | exit(1) 829 | 830 | plane_points = [] 831 | for i in range(len(planePointIndices)): 832 | plane_points.append(points[planePointIndices[i]]) 833 | 834 | planes_valid = [] 835 | planeInfo_valid = [] 836 | plane_points_valid = [] 837 | for i in range(len(plane_points)): 838 | if (planes[i] != 0).any(): 839 | planes_valid.append(planes[i]) 840 | planeInfo_valid.append(planeInfo[i]) 841 | plane_points_valid.append(plane_points[i]) 842 | 843 | planes_valid = np.stack(planes_valid) 844 | np.save(annotationFolder + '/planes.npy', planes_valid) 845 | np.save(annotationFolder + '/plane_info.npy', np.array(planeInfo_valid, dtype=object), allow_pickle=True)  # ragged list: wrap as an object array so np.save works on NumPy >= 1.24 846 | np.save(annotationFolder + '/plane_points.npy', np.array(plane_points_valid, dtype=object), allow_pickle=True) 847 | 848 | return planes_valid, plane_points_valid 849 | 850 | 851 | if __name__ == '__main__': 852 | 853 | scene_ids = os.listdir(ROOT_FOLDER) 854 | scene_ids = sorted(scene_ids) 855 | 856 | for index, scene_id in enumerate(scene_ids): 857 | if scene_id[:5] != 'scene': 858 | continue 859 | 860 | if not os.path.exists(args.save_path + '/' + scene_id + '/annotation'): 861 | os.system('mkdir -p ' + args.save_path + '/' + scene_id + '/annotation') 862 | 863 | if not os.path.exists(args.save_path + '/' + scene_id + '/annotation/planes.ply'): 864 | print('plane fitting', scene_id) 865 | readMesh(scene_id) 866 | -------------------------------------------------------------------------------- /tools/random_color.py: -------------------------------------------------------------------------------- 1 | import random 2 | import numpy as np 3 | 4 | # ============ for viz ===================== 5 | class random_color(object): 6 | def __init__(self): 7 | num_of_colors = 3000 8 | self.colors = ["#"+''.join([random.choice('0123456789ABCDEF') for i in range(6)]) 9 | for j in range(num_of_colors)] 10 | 11 | def __call__(self, ret_n=10): 12 | assert len(self.colors) >= ret_n  # allow requesting up to the full palette 13 | ret_color = np.zeros([ret_n, 3]) 14 | for i in range(ret_n): 15 | hex_color = self.colors[i][1:] 16 | ret_color[i] = np.array([int(hex_color[j:j + 2], 16) for j in (0, 2, 4)]) 17 | return ret_color 18 | 19 | -------------------------------------------------------------------------------- /tools/simple_loader.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import 
numpy as np 3 | import os 4 | import cv2 5 | 6 | 7 | def collate_fn(list_data): 8 | cam_pose, depth_im, _ = list_data 9 | # Pass the single sample through unchanged (no batching is needed here) 10 | return cam_pose, depth_im, _ 11 | 12 | 13 | class ScanNetDataset(torch.utils.data.Dataset): 14 | """PyTorch Dataset for a single scene; __getitem__ loads individual frames.""" 15 | 16 | def __init__(self, n_imgs, scene, data_path, max_depth, id_list=None): 17 | """ 18 | Args: n_imgs (int), scene (str), data_path (str), max_depth (float), id_list (optional list of frame ids) 19 | """ 20 | self.n_imgs = n_imgs 21 | self.scene = scene 22 | self.data_path = data_path 23 | self.max_depth = max_depth 24 | if id_list is None: 25 | self.id_list = list(range(n_imgs)) 26 | else: 27 | self.id_list = id_list 28 | 29 | def __len__(self): 30 | return self.n_imgs 31 | 32 | def __getitem__(self, id): 33 | """ 34 | Returns: 35 | tuple of camera pose and two placeholder slots for a single frame 36 | """ 37 | id = self.id_list[id] 38 | cam_pose = np.loadtxt(os.path.join(self.data_path, self.scene, "pose", '{0}.txt'.format(id)), delimiter=' ') 39 | 40 | # Read depth image and camera pose 41 | # depth_im = cv2.imread(os.path.join(self.data_path, self.scene, "depth", 'frame-{0:06d}.depth.pgm'.format(id)), -1).astype( 42 | # np.float32) 43 | # depth_im /= 1000.  # depth is saved in 16-bit PNG in millimeters 44 | # depth_im[depth_im > self.max_depth] = 0 45 | 46 | # Read RGB image 47 | # color_image = cv2.cvtColor(cv2.imread(os.path.join(self.data_path, self.scene, "color", str(id) + ".jpg")), 48 | # cv2.COLOR_BGR2RGB) 49 | # color_image = cv2.resize(color_image, (depth_im.shape[1], depth_im.shape[0]), interpolation=cv2.INTER_AREA) 50 | 51 | return cam_pose, None, None 52 | 53 | 54 | class ReplicaDataset(torch.utils.data.Dataset): 55 | """PyTorch Dataset for a single scene; __getitem__ loads individual frames.""" 56 | 57 | def __init__(self, n_imgs, scene, data_path, max_depth, id_list=None): 58 | """ 59 | Args: n_imgs (int), scene (str), data_path (str), max_depth (float), id_list (optional list of frame ids) 60 | """ 61 | self.n_imgs = n_imgs 62 | self.scene = scene 63 | self.data_path = data_path 64 | self.max_depth = max_depth 65 | if id_list is None: 66 | self.id_list = list(range(n_imgs)) 67 | else: 68 | self.id_list = id_list 69 | 70 | self.extr = np.loadtxt(os.path.join(data_path, scene, 'traj.txt')).reshape(-1, 4, 4) 71 | 72 | def __len__(self): 73 | return self.n_imgs 74 | 75 | def __getitem__(self, id): 76 | """ 77 | Returns: 78 | tuple of camera pose and two placeholder slots for a single frame 79 | """ 80 | id = self.id_list[id] 81 | # cam_pose = np.loadtxt(os.path.join(self.data_path, self.scene, "pose", '{0}.txt'.format(id)), delimiter=' ') 82 | cam_pose = self.extr[id] 83 | 84 | return cam_pose, None, None 85 | --------------------------------------------------------------------------------