├── LICENSE
├── README.md
├── confs
    ├── object.conf
    └── scene.conf
├── data
    ├── input
    │   └── demo_car.xyz
    └── query_data
    │   └── demo_car.npz
├── extensions
    └── chamfer_dist
    │   ├── __init__.py
    │   ├── chamfer.cu
    │   ├── chamfer_cuda.cpp
    │   ├── setup.py
    │   └── test.py
├── figs
    ├── cars.png
    ├── kitti.png
    ├── pcpnet.png
    ├── pugan.png
    ├── scenes.png
    └── tmp.txt
├── models
    ├── dataset.py
    ├── embedder.py
    └── fields.py
├── run.py
└── tools
    ├── logger.py
    ├── surface_extraction.py
    └── utils.py

/LICENSE:
--------------------------------------------------------------------------------
1 | MIT License
2 |
3 | Copyright (c) 2023 Junsheng Zhou
4 |
5 | Permission is hereby granted, free of charge, to any person obtaining a copy
6 | of this software and associated documentation files (the "Software"), to deal
7 | in the Software without restriction, including without limitation the rights
8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 | copies of the Software, and to permit persons to whom the Software is
10 | furnished to do so, subject to the following conditions:
11 |
12 | The above copyright notice and this permission notice shall be included in all
13 | copies or substantial portions of the Software.
14 |
15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 | SOFTWARE.
22 |

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 |

2 | 3 | 4 |

Learning a More Continuous Zero Level Set in Unsigned Distance Fields through Level Set Projection

5 |

6 | Junsheng Zhou* 7 | · 8 | Baorui Ma* 9 | · 10 | Shujuan Li 11 | · 12 | Yu-Shen Liu 13 | · 14 | Zhizhong Han 15 | 16 |

17 |

(* Equal Contribution)

18 |

ICCV 2023

19 |
20 |

21 |
22 | We release the code of the paper Learning a More Continuous Zero Level Set in Unsigned Distance Fields through Level Set Projection in this repository.
23 |
24 |
25 | ## Reconstruction Results
26 | ### ShapeNetCars
27 |

28 | 29 |

30 |
31 | ### 3DScenes
32 |

33 | 34 |

35 |
36 | ### KITTI
37 |

38 | 39 |

40 |
41 | ## Point Upsampling Results
42 |

43 | 44 |

45 |
46 | ## Point Normal Estimation Results
47 |

48 | 49 |

50 |
51 |
52 |
53 | ## Installation
54 | Our code is implemented in Python 3.8, PyTorch 1.11.0 and CUDA 11.3.
55 | - Install Python dependencies
56 | ```bash
57 | conda create -n levelsetudf python=3.8
58 | conda activate levelsetudf
59 | conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch
60 | pip install tqdm pyhocon==0.3.57 trimesh PyMCubes scipy point_cloud_utils==0.29.7
61 | ```
62 | - Compile C++ extensions
63 | ```
64 | cd extensions/chamfer_dist
65 | python setup.py install
66 | ```
67 |
68 | ## Quick Start
69 |
70 | For a quick start, you can train LevelSetUDF to reconstruct a surface from a single point cloud:
71 | ```
72 | python run.py --gpu 0 --conf confs/object.conf --dataname demo_car --dir demo_car
73 | ```
74 | - We provide the data for a demo car in the `./data` folder for a quick start on LevelSetUDF.
75 |
76 | You can find the outputs in the `./outs` folder:
77 |
78 | ```
79 | │outs/
80 | ├──demo_car/
81 | │  ├── mesh
82 | │  ├── densepoints
83 | │  ├── normal
84 | ```
85 | - The reconstructed meshes are saved in the `mesh` folder
86 | - The upsampled dense point clouds are saved in the `densepoints` folder
87 | - The estimated normals for the point cloud are saved in the `normal` folder
88 |
89 | ## Use Your Own Data
90 | The following instructions describe how to train on your own data.
91 |
92 | ### Data
93 | First, put your own data in the `./data/input` folder. The dataset is organized as follows:
94 | ```
95 | │data/
96 | │── input
97 | │   ├── (dataname).ply/xyz/npy
98 | ```
99 | We support point clouds in the `.ply`, `.xyz` and `.npy` formats.
100 |
101 | ### Run
102 | To train on your own data, simply run:
103 | ```
104 | python run.py --gpu 0 --conf confs/object.conf --dataname (dataname) --dir (dataname)
105 | ```
106 |
107 | ### Notice
108 | - To achieve better performance on point clouds of different complexity, the weights of the losses should be adjusted. 
For example, we provide two configs in the `./confs` folder, i.e., `object.conf` and `scene.conf`. If you are reconstructing large-scale scenes, `scene.conf` is recommended; otherwise, `object.conf` should work fine for object-level reconstructions.
109 |
110 | - Because point cloud density varies across datasets and your own data, the hyperparameter [scale](https://github.com/junshengzhou/LevelSetUDF/blob/44cd4e72b895f51bd2d06689392e25b31fed017a/models/dataset.py#L52), which controls the distance between the query points and the point cloud, has a very strong influence on the final result. To get better results, you should adjust this parameter. We give `0.25 * np.sqrt(POINT_NUM_GT / 20000)` (the code in `models/dataset.py` never lets it drop below `0.25`) here as a reference value, which can be used for most object-level reconstructions.
111 |
112 | ## Related works
113 | Please also check out the following works that inspire us a lot:
114 | * [Junsheng Zhou et al. - Learning consistency-aware unsigned distance functions progressively from raw point clouds. (NeurIPS2022)](https://junshengzhou.github.io/CAP-UDF/)
115 | * [Baorui Ma et al. - Neural-Pull: Learning Signed Distance Functions from Point Clouds by Learning to Pull Space onto Surfaces (ICML2021)](https://github.com/mabaorui/NeuralPull-Pytorch)
116 | * [Baorui Ma et al. - Surface Reconstruction from Point Clouds by Learning Predictive Context Priors (CVPR2022)](https://mabaorui.github.io/PredictableContextPrior_page/)
117 | * [Baorui Ma et al. 
- Reconstructing Surfaces for Sparse Point Clouds with On-Surface Priors (CVPR2022)](https://mabaorui.github.io/-OnSurfacePrior_project_page/)
118 |
119 | ## Citation
120 | If you find our code or paper useful, please consider citing
121 |
122 |     @inproceedings{zhou2023levelset,
123 |       title={Learning a More Continuous Zero Level Set in Unsigned Distance Fields through Level Set Projection},
124 |       author={Zhou, Junsheng and Ma, Baorui and Li, Shujuan and Liu, Yu-Shen and Han, Zhizhong},
125 |       booktitle={Proceedings of the IEEE/CVF international conference on computer vision},
126 |       year={2023}
127 |     }
128 |

--------------------------------------------------------------------------------
/confs/object.conf:
--------------------------------------------------------------------------------
1 | general {
2 |     base_exp_dir = ./outs/
3 |     recording = [
4 |         ./,
5 |         ./models
6 |     ]
7 | }
8 |
9 | dataset {
10 |     data_dir = data/
11 | }
12 |
13 | train {
14 |     learning_rate = 0.001
15 |     step1_maxiter = 40000
16 |     step2_maxiter = 40000
17 |     warm_up_end = 1000
18 |     eval_num_points = 1000000
19 |     df_filter = 0.01
20 |     far = 0.015
21 |     outlier = 0.0035
22 |     extra_points_rate = 5
23 |     low_range = 1.1
24 |
25 |     batch_size = 5000
26 |     batch_size_step2 = 20000
27 |
28 |     save_freq = 5000
29 |     val_freq = 10000
30 |     val_mesh_freq = 10000
31 |     report_freq = 200
32 |
33 |     igr_weight = 0.1
34 |     mask_weight = 0.0
35 |     load_ckpt = none
36 |
37 |     proj_weight = 0.002
38 |     proj_adapt = 10
39 |
40 |     surf_weight = 0.1
41 |     orth_weight = 0.01
42 | }
43 |
44 | model {
45 |     udf_network {
46 |         d_out = 1
47 |         d_in = 3
48 |         d_hidden = 256
49 |         n_layers = 8
50 |         skip_in = [4]
51 |         multires = 0
52 |         bias = 0.5
53 |         scale = 1.0
54 |         geometric_init = True
55 |         weight_norm = True
56 |     }
57 |
58 | }
59 |

--------------------------------------------------------------------------------
/confs/scene.conf:
--------------------------------------------------------------------------------
1 | general {
2 |     base_exp_dir = 
./outs/
3 |     recording = [
4 |         ./,
5 |         ./models
6 |     ]
7 | }
8 |
9 | dataset {
10 |     data_dir = data/
11 | }
12 |
13 | train {
14 |     learning_rate = 0.001
15 |     step1_maxiter = 100000
16 |     step2_maxiter = 100000
17 |     warm_up_end = 1000
18 |     eval_num_points = 1000000
19 |     df_filter = 0.01
20 |     far = 0.01
21 |     outlier = 0.002
22 |     extra_points_rate = 2
23 |     low_range = 1.1
24 |
25 |     batch_size = 5000
26 |     batch_size_step2 = 20000
27 |
28 |     save_freq = 5000
29 |     val_freq = 10000
30 |     val_mesh_freq = 10000
31 |     report_freq = 200
32 |
33 |     igr_weight = 0.1
34 |     mask_weight = 0.0
35 |     load_ckpt = none
36 |
37 |     proj_weight = 0.001
38 |     proj_adapt = 10
39 |
40 |     surf_weight = 0.01
41 |     orth_weight = 0.0001
42 | }
43 |
44 | model {
45 |     udf_network {
46 |         d_out = 1
47 |         d_in = 3
48 |         d_hidden = 256
49 |         n_layers = 8
50 |         skip_in = [4]
51 |         multires = 0
52 |         bias = 0.5
53 |         scale = 1.0
54 |         geometric_init = True
55 |         weight_norm = True
56 |     }
57 |
58 | }
59 |

--------------------------------------------------------------------------------
/data/query_data/demo_car.npz:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/junshengzhou/LevelSetUDF/4c6d37b5a010a0adb296f6d17431ef6c561d0204/data/query_data/demo_car.npz

--------------------------------------------------------------------------------
/extensions/chamfer_dist/__init__.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Thibault GROUEIX
3 | # @Date: 2019-08-07 20:54:24
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2019-12-18 15:06:25
6 | # @Email: cshzxie@gmail.com
7 |
8 | import torch
9 |
10 | import chamfer
11 |
12 |
13 | class ChamferFunction(torch.autograd.Function):
14 |     @staticmethod
15 |     def forward(ctx, xyz1, xyz2):
16 |         dist1, dist2, idx1, idx2 = chamfer.forward(xyz1, xyz2)
17 |         ctx.save_for_backward(xyz1, xyz2, idx1, idx2)
18 |
19 |         return dist1, dist2
20 |
21 |     @staticmethod
22 |     def backward(ctx, grad_dist1, grad_dist2):
23 |         xyz1, xyz2, idx1, idx2 = ctx.saved_tensors
24 |         grad_xyz1, grad_xyz2 = chamfer.backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2)
25 |         return grad_xyz1, grad_xyz2
26 |
27 |
28 | class ChamferDistanceL2(torch.nn.Module):
29 |     ''' Chamfer Distance L2
30 |     '''
31 |     def __init__(self, ignore_zeros=False):
32 |         super().__init__()
33 |         self.ignore_zeros = ignore_zeros
34 |
35 |     def forward(self, xyz1, xyz2):
36 |         batch_size = xyz1.size(0)
37 |         if batch_size == 1 and self.ignore_zeros:
38 |             non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
39 |             non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
40 |             xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
41 |             xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
42 |
43 |         dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
44 |         return torch.mean(dist1) + torch.mean(dist2)
45 |
46 | class ChamferDistanceL2_split(torch.nn.Module):
47 |     ''' Chamfer Distance L2 (the two directional terms are returned separately)
48 |     '''
49 |     def __init__(self, ignore_zeros=False):
50 |         super().__init__()
51 |         self.ignore_zeros = ignore_zeros
52 |
53 |     def forward(self, xyz1, xyz2):
54 |         batch_size = xyz1.size(0)
55 |         if batch_size == 1 and self.ignore_zeros:
56 |             non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
57 |             non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
58 |             xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
59 |             xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
60 |
61 |         dist1, dist2 = ChamferFunction.apply(xyz1, xyz2)
62 |         return torch.mean(dist1), torch.mean(dist2)
63 |
64 | class ChamferDistanceL1(torch.nn.Module):
65 |     ''' Chamfer Distance L1
66 |     '''
67 |     def __init__(self, ignore_zeros=False):
68 |         super().__init__()
69 |         self.ignore_zeros = ignore_zeros
70 |
71 |     def forward(self, xyz1, xyz2):
72 |         batch_size = xyz1.size(0)
73 |         if batch_size == 1 and self.ignore_zeros:
74 |             non_zeros1 = torch.sum(xyz1, dim=2).ne(0)
75 |             non_zeros2 = torch.sum(xyz2, dim=2).ne(0)
76 |             xyz1 = xyz1[non_zeros1].unsqueeze(dim=0)
77 |             xyz2 = xyz2[non_zeros2].unsqueeze(dim=0)
78 |
79 |         dist1, 
dist2 = ChamferFunction.apply(xyz1, xyz2) 80 | # import pdb 81 | # pdb.set_trace() 82 | dist1 = torch.sqrt(dist1) 83 | dist2 = torch.sqrt(dist2) 84 | return (torch.mean(dist1) + torch.mean(dist2))/2 85 | 86 | -------------------------------------------------------------------------------- /extensions/chamfer_dist/chamfer.cu: -------------------------------------------------------------------------------- 1 | /* 2 | * @Author: Haozhe Xie 3 | * @Date: 2019-08-07 20:54:24 4 | * @Last Modified by: Haozhe Xie 5 | * @Last Modified time: 2020-06-17 14:58:55 6 | * @Email: cshzxie@gmail.com 7 | */ 8 | 9 | #include 10 | #include 11 | #include 12 | 13 | #include 14 | 15 | __global__ void chamfer_dist_kernel(int batch_size, 16 | int n, 17 | const float* xyz1, 18 | int m, 19 | const float* xyz2, 20 | float* dist, 21 | int* indexes) { 22 | const int batch = 512; 23 | __shared__ float buf[batch * 3]; 24 | for (int i = blockIdx.x; i < batch_size; i += gridDim.x) { 25 | for (int k2 = 0; k2 < m; k2 += batch) { 26 | int end_k = min(m, k2 + batch) - k2; 27 | for (int j = threadIdx.x; j < end_k * 3; j += blockDim.x) { 28 | buf[j] = xyz2[(i * m + k2) * 3 + j]; 29 | } 30 | __syncthreads(); 31 | for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; 32 | j += blockDim.x * gridDim.y) { 33 | float x1 = xyz1[(i * n + j) * 3 + 0]; 34 | float y1 = xyz1[(i * n + j) * 3 + 1]; 35 | float z1 = xyz1[(i * n + j) * 3 + 2]; 36 | float best_dist = 0; 37 | int best_dist_index = 0; 38 | int end_ka = end_k - (end_k & 3); 39 | if (end_ka == batch) { 40 | for (int k = 0; k < batch; k += 4) { 41 | { 42 | float x2 = buf[k * 3 + 0] - x1; 43 | float y2 = buf[k * 3 + 1] - y1; 44 | float z2 = buf[k * 3 + 2] - z1; 45 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 46 | 47 | if (k == 0 || dist < best_dist) { 48 | best_dist = dist; 49 | best_dist_index = k + k2; 50 | } 51 | } 52 | { 53 | float x2 = buf[k * 3 + 3] - x1; 54 | float y2 = buf[k * 3 + 4] - y1; 55 | float z2 = buf[k * 3 + 5] - z1; 56 | float dist = x2 * 
x2 + y2 * y2 + z2 * z2; 57 | if (dist < best_dist) { 58 | best_dist = dist; 59 | best_dist_index = k + k2 + 1; 60 | } 61 | } 62 | { 63 | float x2 = buf[k * 3 + 6] - x1; 64 | float y2 = buf[k * 3 + 7] - y1; 65 | float z2 = buf[k * 3 + 8] - z1; 66 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 67 | if (dist < best_dist) { 68 | best_dist = dist; 69 | best_dist_index = k + k2 + 2; 70 | } 71 | } 72 | { 73 | float x2 = buf[k * 3 + 9] - x1; 74 | float y2 = buf[k * 3 + 10] - y1; 75 | float z2 = buf[k * 3 + 11] - z1; 76 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 77 | if (dist < best_dist) { 78 | best_dist = dist; 79 | best_dist_index = k + k2 + 3; 80 | } 81 | } 82 | } 83 | } else { 84 | for (int k = 0; k < end_ka; k += 4) { 85 | { 86 | float x2 = buf[k * 3 + 0] - x1; 87 | float y2 = buf[k * 3 + 1] - y1; 88 | float z2 = buf[k * 3 + 2] - z1; 89 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 90 | if (k == 0 || dist < best_dist) { 91 | best_dist = dist; 92 | best_dist_index = k + k2; 93 | } 94 | } 95 | { 96 | float x2 = buf[k * 3 + 3] - x1; 97 | float y2 = buf[k * 3 + 4] - y1; 98 | float z2 = buf[k * 3 + 5] - z1; 99 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 100 | if (dist < best_dist) { 101 | best_dist = dist; 102 | best_dist_index = k + k2 + 1; 103 | } 104 | } 105 | { 106 | float x2 = buf[k * 3 + 6] - x1; 107 | float y2 = buf[k * 3 + 7] - y1; 108 | float z2 = buf[k * 3 + 8] - z1; 109 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 110 | if (dist < best_dist) { 111 | best_dist = dist; 112 | best_dist_index = k + k2 + 2; 113 | } 114 | } 115 | { 116 | float x2 = buf[k * 3 + 9] - x1; 117 | float y2 = buf[k * 3 + 10] - y1; 118 | float z2 = buf[k * 3 + 11] - z1; 119 | float dist = x2 * x2 + y2 * y2 + z2 * z2; 120 | if (dist < best_dist) { 121 | best_dist = dist; 122 | best_dist_index = k + k2 + 3; 123 | } 124 | } 125 | } 126 | } 127 | for (int k = end_ka; k < end_k; k++) { 128 | float x2 = buf[k * 3 + 0] - x1; 129 | float y2 = buf[k * 3 + 1] - y1; 130 | float z2 = buf[k * 3 + 2] - z1; 131 
| float dist = x2 * x2 + y2 * y2 + z2 * z2; 132 | if (k == 0 || dist < best_dist) { 133 | best_dist = dist; 134 | best_dist_index = k + k2; 135 | } 136 | } 137 | if (k2 == 0 || dist[(i * n + j)] > best_dist) { 138 | dist[(i * n + j)] = best_dist; 139 | indexes[(i * n + j)] = best_dist_index; 140 | } 141 | } 142 | __syncthreads(); 143 | } 144 | } 145 | } 146 | 147 | std::vector chamfer_cuda_forward(torch::Tensor xyz1, 148 | torch::Tensor xyz2) { 149 | const int batch_size = xyz1.size(0); 150 | const int n = xyz1.size(1); // num_points point cloud A 151 | const int m = xyz2.size(1); // num_points point cloud B 152 | torch::Tensor dist1 = 153 | torch::zeros({batch_size, n}, torch::CUDA(torch::kFloat)); 154 | torch::Tensor dist2 = 155 | torch::zeros({batch_size, m}, torch::CUDA(torch::kFloat)); 156 | torch::Tensor idx1 = torch::zeros({batch_size, n}, torch::CUDA(torch::kInt)); 157 | torch::Tensor idx2 = torch::zeros({batch_size, m}, torch::CUDA(torch::kInt)); 158 | 159 | chamfer_dist_kernel<<>>( 160 | batch_size, n, xyz1.data_ptr(), m, xyz2.data_ptr(), 161 | dist1.data_ptr(), idx1.data_ptr()); 162 | chamfer_dist_kernel<<>>( 163 | batch_size, m, xyz2.data_ptr(), n, xyz1.data_ptr(), 164 | dist2.data_ptr(), idx2.data_ptr()); 165 | 166 | cudaError_t err = cudaGetLastError(); 167 | if (err != cudaSuccess) { 168 | printf("Error in chamfer_cuda_forward: %s\n", cudaGetErrorString(err)); 169 | } 170 | return {dist1, dist2, idx1, idx2}; 171 | } 172 | 173 | __global__ void chamfer_dist_grad_kernel(int b, 174 | int n, 175 | const float* xyz1, 176 | int m, 177 | const float* xyz2, 178 | const float* grad_dist1, 179 | const int* idx1, 180 | float* grad_xyz1, 181 | float* grad_xyz2) { 182 | for (int i = blockIdx.x; i < b; i += gridDim.x) { 183 | for (int j = threadIdx.x + blockIdx.y * blockDim.x; j < n; 184 | j += blockDim.x * gridDim.y) { 185 | float x1 = xyz1[(i * n + j) * 3 + 0]; 186 | float y1 = xyz1[(i * n + j) * 3 + 1]; 187 | float z1 = xyz1[(i * n + j) * 3 + 2]; 188 | int j2 = 
idx1[i * n + j]; 189 | float x2 = xyz2[(i * m + j2) * 3 + 0]; 190 | float y2 = xyz2[(i * m + j2) * 3 + 1]; 191 | float z2 = xyz2[(i * m + j2) * 3 + 2]; 192 | float g = grad_dist1[i * n + j] * 2; 193 | atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 0]), g * (x1 - x2)); 194 | atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 1]), g * (y1 - y2)); 195 | atomicAdd(&(grad_xyz1[(i * n + j) * 3 + 2]), g * (z1 - z2)); 196 | atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 0]), -(g * (x1 - x2))); 197 | atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 1]), -(g * (y1 - y2))); 198 | atomicAdd(&(grad_xyz2[(i * m + j2) * 3 + 2]), -(g * (z1 - z2))); 199 | } 200 | } 201 | } 202 | 203 | std::vector chamfer_cuda_backward(torch::Tensor xyz1, 204 | torch::Tensor xyz2, 205 | torch::Tensor idx1, 206 | torch::Tensor idx2, 207 | torch::Tensor grad_dist1, 208 | torch::Tensor grad_dist2) { 209 | const int batch_size = xyz1.size(0); 210 | const int n = xyz1.size(1); // num_points point cloud A 211 | const int m = xyz2.size(1); // num_points point cloud B 212 | torch::Tensor grad_xyz1 = torch::zeros_like(xyz1, torch::CUDA(torch::kFloat)); 213 | torch::Tensor grad_xyz2 = torch::zeros_like(xyz2, torch::CUDA(torch::kFloat)); 214 | 215 | chamfer_dist_grad_kernel<<>>( 216 | batch_size, n, xyz1.data_ptr(), m, xyz2.data_ptr(), 217 | grad_dist1.data_ptr(), idx1.data_ptr(), 218 | grad_xyz1.data_ptr(), grad_xyz2.data_ptr()); 219 | chamfer_dist_grad_kernel<<>>( 220 | batch_size, m, xyz2.data_ptr(), n, xyz1.data_ptr(), 221 | grad_dist2.data_ptr(), idx2.data_ptr(), 222 | grad_xyz2.data_ptr(), grad_xyz1.data_ptr()); 223 | 224 | cudaError_t err = cudaGetLastError(); 225 | if (err != cudaSuccess) { 226 | printf("Error in chamfer_cuda_backward: %s\n", cudaGetErrorString(err)); 227 | } 228 | return {grad_xyz1, grad_xyz2}; 229 | } 230 | -------------------------------------------------------------------------------- /extensions/chamfer_dist/chamfer_cuda.cpp: -------------------------------------------------------------------------------- 1 | /* 
2 |  * @Author: Haozhe Xie
3 |  * @Date: 2019-08-07 20:54:24
4 |  * @Last Modified by: Haozhe Xie
5 |  * @Last Modified time: 2019-12-10 10:33:50
6 |  * @Email: cshzxie@gmail.com
7 |  */
8 |
9 | #include <torch/extension.h>
10 | #include <vector>
11 |
12 | std::vector<torch::Tensor> chamfer_cuda_forward(torch::Tensor xyz1,
13 |                                                 torch::Tensor xyz2);
14 |
15 | std::vector<torch::Tensor> chamfer_cuda_backward(torch::Tensor xyz1,
16 |                                                  torch::Tensor xyz2,
17 |                                                  torch::Tensor idx1,
18 |                                                  torch::Tensor idx2,
19 |                                                  torch::Tensor grad_dist1,
20 |                                                  torch::Tensor grad_dist2);
21 |
22 | std::vector<torch::Tensor> chamfer_forward(torch::Tensor xyz1,
23 |                                            torch::Tensor xyz2) {
24 |   return chamfer_cuda_forward(xyz1, xyz2);
25 | }
26 |
27 | std::vector<torch::Tensor> chamfer_backward(torch::Tensor xyz1,
28 |                                             torch::Tensor xyz2,
29 |                                             torch::Tensor idx1,
30 |                                             torch::Tensor idx2,
31 |                                             torch::Tensor grad_dist1,
32 |                                             torch::Tensor grad_dist2) {
33 |   return chamfer_cuda_backward(xyz1, xyz2, idx1, idx2, grad_dist1, grad_dist2);
34 | }
35 |
36 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
37 |   m.def("forward", &chamfer_forward, "Chamfer forward (CUDA)");
38 |   m.def("backward", &chamfer_backward, "Chamfer backward (CUDA)");
39 | }
40 |

--------------------------------------------------------------------------------
/extensions/chamfer_dist/setup.py:
--------------------------------------------------------------------------------
1 | # -*- coding: utf-8 -*-
2 | # @Author: Haozhe Xie
3 | # @Date: 2019-08-07 20:54:24
4 | # @Last Modified by: Haozhe Xie
5 | # @Last Modified time: 2019-12-10 10:04:25
6 | # @Email: cshzxie@gmail.com
7 |
8 | from setuptools import setup
9 | from torch.utils.cpp_extension import BuildExtension, CUDAExtension
10 |
11 | setup(name='chamfer',
12 |       version='2.0.0',
13 |       ext_modules=[
14 |           CUDAExtension('chamfer', [
15 |               'chamfer_cuda.cpp',
16 |               'chamfer.cu',
17 |           ]),
18 |       ],
19 |       cmdclass={'build_ext': BuildExtension})
20 |

--------------------------------------------------------------------------------
/extensions/chamfer_dist/test.py: 
-------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | # @Author: Haozhe Xie 3 | # @Date: 2019-12-10 10:38:01 4 | # @Last Modified by: Haozhe Xie 5 | # @Last Modified time: 2019-12-26 14:21:36 6 | # @Email: cshzxie@gmail.com 7 | # 8 | # Note: 9 | # - Replace float -> double, kFloat -> kDouble in chamfer.cu 10 | 11 | import os 12 | import sys 13 | import torch 14 | import unittest 15 | 16 | 17 | from torch.autograd import gradcheck 18 | 19 | sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir, os.path.pardir))) 20 | from extensions.chamfer_dist import ChamferFunction 21 | 22 | 23 | class ChamferDistanceTestCase(unittest.TestCase): 24 | def test_chamfer_dist(self): 25 | x = torch.rand(4, 64, 3).double() 26 | y = torch.rand(4, 128, 3).double() 27 | x.requires_grad = True 28 | y.requires_grad = True 29 | print(gradcheck(ChamferFunction.apply, [x.cuda(), y.cuda()])) 30 | 31 | 32 | 33 | if __name__ == '__main__': 34 | # unittest.main() 35 | import pdb 36 | x = torch.rand(32,128,3) 37 | y = torch.rand(32,128,3) 38 | pdb.set_trace() 39 | -------------------------------------------------------------------------------- /figs/cars.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junshengzhou/LevelSetUDF/4c6d37b5a010a0adb296f6d17431ef6c561d0204/figs/cars.png -------------------------------------------------------------------------------- /figs/kitti.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junshengzhou/LevelSetUDF/4c6d37b5a010a0adb296f6d17431ef6c561d0204/figs/kitti.png -------------------------------------------------------------------------------- /figs/pcpnet.png: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/junshengzhou/LevelSetUDF/4c6d37b5a010a0adb296f6d17431ef6c561d0204/figs/pcpnet.png -------------------------------------------------------------------------------- /figs/pugan.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junshengzhou/LevelSetUDF/4c6d37b5a010a0adb296f6d17431ef6c561d0204/figs/pugan.png -------------------------------------------------------------------------------- /figs/scenes.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/junshengzhou/LevelSetUDF/4c6d37b5a010a0adb296f6d17431ef6c561d0204/figs/scenes.png -------------------------------------------------------------------------------- /figs/tmp.txt: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /models/dataset.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn.functional as F 3 | import numpy as np 4 | import os 5 | from scipy.spatial import cKDTree 6 | import trimesh 7 | 8 | def search_nearest_point(point_batch, point_gt): 9 | num_point_batch, num_point_gt = point_batch.shape[0], point_gt.shape[0] 10 | point_batch = point_batch.unsqueeze(1).repeat(1, num_point_gt, 1) 11 | point_gt = point_gt.unsqueeze(0).repeat(num_point_batch, 1, 1) 12 | 13 | distances = torch.sqrt(torch.sum((point_batch-point_gt) ** 2, axis=-1) + 1e-12) 14 | dis_idx = torch.argmin(distances, axis=1).detach().cpu().numpy() 15 | 16 | return dis_idx 17 | 18 | def process_data(data_dir, dataname): 19 | if os.path.exists(os.path.join(data_dir, 'input', dataname) + '.ply'): 20 | pointcloud = trimesh.load(os.path.join(data_dir, 'input', dataname) + '.ply').vertices 21 | pointcloud = np.asarray(pointcloud) 22 | elif os.path.exists(os.path.join(data_dir, 'input', 
dataname) + '.xyz'): 23 | pointcloud = np.loadtxt(os.path.join(data_dir, 'input', dataname) + '.xyz') 24 | elif os.path.exists(os.path.join(data_dir, 'input', dataname) + '.npy'): 25 | pointcloud = np.load(os.path.join(data_dir, 'input', dataname) + '.npy') 26 | else: 27 | print('Only support .ply, .xyz or .npy data. Please adjust your data format.') 28 | exit() 29 | shape_scale = np.max([np.max(pointcloud[:,0])-np.min(pointcloud[:,0]),np.max(pointcloud[:,1])-np.min(pointcloud[:,1]),np.max(pointcloud[:,2])-np.min(pointcloud[:,2])]) 30 | shape_center = [(np.max(pointcloud[:,0])+np.min(pointcloud[:,0]))/2, (np.max(pointcloud[:,1])+np.min(pointcloud[:,1]))/2, (np.max(pointcloud[:,2])+np.min(pointcloud[:,2]))/2] 31 | pointcloud = pointcloud - shape_center 32 | pointcloud = pointcloud / shape_scale 33 | 34 | POINT_NUM = pointcloud.shape[0] // 60 35 | POINT_NUM_GT = pointcloud.shape[0] // 60 * 60 36 | QUERY_EACH = 1000000//POINT_NUM_GT 37 | 38 | point_idx = np.random.choice(pointcloud.shape[0], POINT_NUM_GT, replace = False) 39 | pointcloud = pointcloud[point_idx,:] 40 | ptree = cKDTree(pointcloud) 41 | sigmas = [] 42 | for p in np.array_split(pointcloud,100,axis=0): 43 | d = ptree.query(p,51) 44 | sigmas.append(d[0][:,-1]) 45 | 46 | sigmas = np.concatenate(sigmas) 47 | sample = [] 48 | sample_near = [] 49 | 50 | for i in range(QUERY_EACH): 51 | theta = 0.25 52 | scale = theta if theta * np.sqrt(POINT_NUM_GT / 20000) < theta else theta * np.sqrt(POINT_NUM_GT / 20000) 53 | tt = pointcloud + scale*np.expand_dims(sigmas,-1) * np.random.normal(0.0, 1.0, size=pointcloud.shape) 54 | sample.append(tt) 55 | tt = tt.reshape(-1,POINT_NUM,3) 56 | 57 | sample_near_tmp = [] 58 | for j in range(tt.shape[0]): 59 | nearest_idx = search_nearest_point(torch.tensor(tt[j]).float().cuda(), torch.tensor(pointcloud).float().cuda()) 60 | nearest_points = pointcloud[nearest_idx] 61 | nearest_points = np.asarray(nearest_points).reshape(-1,3) 62 | sample_near_tmp.append(nearest_points) 63 | 
sample_near_tmp = np.asarray(sample_near_tmp) 64 | sample_near_tmp = sample_near_tmp.reshape(-1,3) 65 | sample_near.append(sample_near_tmp) 66 | 67 | sample = np.asarray(sample) 68 | sample_near = np.asarray(sample_near) 69 | 70 | os.makedirs(os.path.join(data_dir, 'query_data'), exist_ok=True) 71 | np.savez(os.path.join(data_dir, 'query_data', dataname)+'.npz', sample = sample, point = pointcloud, sample_near = sample_near) 72 | 73 | class Dataset: 74 | def __init__(self, conf, dataname): 75 | super(Dataset, self).__init__() 76 | print('Load data: Begin') 77 | self.device = torch.device('cuda') 78 | self.conf = conf 79 | 80 | self.data_dir = conf.get_string('data_dir') 81 | self.data_name = dataname + '.npz' 82 | 83 | if os.path.exists(os.path.join(self.data_dir, 'query_data', self.data_name)): 84 | print('Query data existing. Loading data...') 85 | else: 86 | print('Query data not found. Processing data...') 87 | process_data(self.data_dir, dataname) 88 | 89 | load_data = np.load(os.path.join(self.data_dir, 'query_data', self.data_name)) 90 | 91 | self.point = np.asarray(load_data['sample_near']).reshape(-1,3) 92 | self.sample = np.asarray(load_data['sample']).reshape(-1,3) 93 | self.point_gt = np.asarray(load_data['point']).reshape(-1,3) 94 | self.sample_points_num = self.sample.shape[0]-1 95 | 96 | self.object_bbox_min = np.array([np.min(self.point[:,0]), np.min(self.point[:,1]), np.min(self.point[:,2])]) -0.05 97 | self.object_bbox_max = np.array([np.max(self.point[:,0]), np.max(self.point[:,1]), np.max(self.point[:,2])]) +0.05 98 | print('bd:',self.object_bbox_min,self.object_bbox_max) 99 | 100 | self.point = torch.from_numpy(self.point).to(self.device).float() 101 | self.sample = torch.from_numpy(self.sample).to(self.device).float() 102 | self.point_gt = torch.from_numpy(self.point_gt).to(self.device).float() 103 | 104 | print('NP Load data: End') 105 | 106 | def get_train_data(self, batch_size): 107 | index_coarse = np.random.choice(10, 1) 108 | index_fine 
= np.random.choice(self.sample_points_num//10, batch_size, replace = False) 109 | index = index_fine * 10 + index_coarse # for accelerating random choice operation 110 | points = self.point[index] 111 | sample = self.sample[index] 112 | return points, sample, self.point_gt 113 | 114 | def gen_new_data(self, tree): 115 | distance, index = tree.query(self.sample.detach().cpu().numpy(), 1) 116 | self.point_new = tree.data[index] 117 | self.point_new = torch.from_numpy(self.point_new).to(self.device).float() 118 | 119 | 120 | def get_train_data_step2(self, batch_size): 121 | index_coarse = np.random.choice(10, 1) 122 | index_fine = np.random.choice(self.sample_points_num//10, batch_size, replace = False) 123 | index = index_fine * 10 + index_coarse # for accelerating random choice operation 124 | points = self.point_new[index] 125 | sample = self.sample[index] 126 | return points, sample, self.point_gt 127 | 128 | -------------------------------------------------------------------------------- /models/embedder.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | # Positional encoding embedding. Code was taken from https://github.com/bmild/nerf. 6 | class Embedder: 7 | def __init__(self, **kwargs): 8 | self.kwargs = kwargs 9 | self.create_embedding_fn() 10 | 11 | def create_embedding_fn(self): 12 | embed_fns = [] 13 | d = self.kwargs['input_dims'] 14 | out_dim = 0 15 | if self.kwargs['include_input']: 16 | embed_fns.append(lambda x: x) 17 | out_dim += d 18 | 19 | max_freq = self.kwargs['max_freq_log2'] 20 | N_freqs = self.kwargs['num_freqs'] 21 | 22 | if self.kwargs['log_sampling']: 23 | freq_bands = 2. 
** torch.linspace(0., max_freq, N_freqs) 24 | else: 25 | freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs) 26 | 27 | for freq in freq_bands: 28 | for p_fn in self.kwargs['periodic_fns']: 29 | embed_fns.append(lambda x, p_fn=p_fn, freq=freq: p_fn(x * freq)) 30 | out_dim += d 31 | 32 | self.embed_fns = embed_fns 33 | self.out_dim = out_dim 34 | 35 | def embed(self, inputs): 36 | return torch.cat([fn(inputs) for fn in self.embed_fns], -1) 37 | 38 | 39 | def get_embedder(multires, input_dims=3): 40 | embed_kwargs = { 41 | 'include_input': True, 42 | 'input_dims': input_dims, 43 | 'max_freq_log2': multires-1, 44 | 'num_freqs': multires, 45 | 'log_sampling': True, 46 | 'periodic_fns': [torch.sin, torch.cos], 47 | } 48 | 49 | embedder_obj = Embedder(**embed_kwargs) 50 | def embed(x, eo=embedder_obj): return eo.embed(x) 51 | return embed, embedder_obj.out_dim 52 | -------------------------------------------------------------------------------- /models/fields.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | from models.embedder import get_embedder 6 | 7 | class LevelSetUDFNetwork(nn.Module): 8 | def __init__(self, 9 | d_in, 10 | d_out, 11 | d_hidden, 12 | n_layers, 13 | skip_in=(4,), 14 | multires=0, 15 | bias=0.5, 16 | scale=1, 17 | geometric_init=True, 18 | weight_norm=True, 19 | inside_outside=False): 20 | super(LevelSetUDFNetwork, self).__init__() 21 | 22 | dims = [d_in] + [d_hidden for _ in range(n_layers)] + [d_out] 23 | 24 | self.embed_fn_fine = None 25 | 26 | if multires > 0: 27 | embed_fn, input_ch = get_embedder(multires, input_dims=d_in) 28 | self.embed_fn_fine = embed_fn 29 | dims[0] = input_ch 30 | 31 | self.num_layers = len(dims) 32 | self.skip_in = skip_in 33 | self.scale = scale 34 | 35 | for l in range(0, self.num_layers - 1): 36 | if l + 1 in self.skip_in: 37 | out_dim = dims[l + 1] - dims[0] 38 | else: 39 | 
out_dim = dims[l + 1] 40 | 41 | lin = nn.Linear(dims[l], out_dim) 42 | 43 | if geometric_init: 44 | if l == self.num_layers - 2: 45 | if not inside_outside: 46 | torch.nn.init.normal_(lin.weight, mean=np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 47 | torch.nn.init.constant_(lin.bias, -bias) 48 | else: 49 | torch.nn.init.normal_(lin.weight, mean=-np.sqrt(np.pi) / np.sqrt(dims[l]), std=0.0001) 50 | torch.nn.init.constant_(lin.bias, bias) 51 | elif multires > 0 and l == 0: 52 | torch.nn.init.constant_(lin.bias, 0.0) 53 | torch.nn.init.constant_(lin.weight[:, 3:], 0.0) 54 | torch.nn.init.normal_(lin.weight[:, :3], 0.0, np.sqrt(2) / np.sqrt(out_dim)) 55 | elif multires > 0 and l in self.skip_in: 56 | torch.nn.init.constant_(lin.bias, 0.0) 57 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 58 | torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):], 0.0) 59 | else: 60 | torch.nn.init.constant_(lin.bias, 0.0) 61 | torch.nn.init.normal_(lin.weight, 0.0, np.sqrt(2) / np.sqrt(out_dim)) 62 | 63 | if weight_norm: 64 | lin = nn.utils.weight_norm(lin) 65 | 66 | setattr(self, "lin" + str(l), lin) 67 | 68 | self.activation = nn.ReLU() 69 | 70 | self.act_last = nn.Sigmoid() 71 | 72 | def forward(self, inputs): 73 | inputs = inputs * self.scale 74 | if self.embed_fn_fine is not None: 75 | inputs = self.embed_fn_fine(inputs) 76 | 77 | x = inputs 78 | for l in range(0, self.num_layers - 1): 79 | lin = getattr(self, "lin" + str(l)) 80 | 81 | if l in self.skip_in: 82 | x = torch.cat([x, inputs], 1) / np.sqrt(2) 83 | 84 | x = lin(x) 85 | 86 | if l < self.num_layers - 2: 87 | x = self.activation(x) 88 | 89 | res = torch.abs(x) 90 | return res / self.scale 91 | 92 | def udf(self, x): 93 | return self.forward(x) 94 | 95 | def udf_hidden_appearance(self, x): 96 | return self.forward(x) 97 | 98 | def gradient(self, x): 99 | x.requires_grad_(True) 100 | y = self.udf(x) 101 | d_output = torch.ones_like(y, requires_grad=False, device=y.device) 102 | gradients = 
torch.autograd.grad( 103 | outputs=y, 104 | inputs=x, 105 | grad_outputs=d_output, 106 | create_graph=True, 107 | retain_graph=True, 108 | only_inputs=True)[0] 109 | return gradients.unsqueeze(1) 110 | 111 | -------------------------------------------------------------------------------- /run.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | import time 4 | import torch 5 | import torch.nn.functional as F 6 | from tqdm import tqdm 7 | from models.dataset import Dataset 8 | from models.fields import LevelSetUDFNetwork 9 | import argparse 10 | from pyhocon import ConfigFactory 11 | import os 12 | from shutil import copyfile 13 | import numpy as np 14 | from tools.logger import get_logger, get_root_logger, print_log 15 | from tools.utils import remove_far, remove_outlier 16 | from tools.surface_extraction import as_mesh, surface_extraction 17 | from extensions.chamfer_dist import ChamferDistanceL1, ChamferDistanceL2 18 | import math 19 | import warnings 20 | warnings.filterwarnings('ignore') 21 | 22 | def extract_fields(bound_min, bound_max, resolution, query_func, grad_func): 23 | N = 32 24 | X = torch.linspace(bound_min[0], bound_max[0], resolution).split(N) 25 | Y = torch.linspace(bound_min[1], bound_max[1], resolution).split(N) 26 | Z = torch.linspace(bound_min[2], bound_max[2], resolution).split(N) 27 | 28 | u = np.zeros([resolution, resolution, resolution], dtype=np.float32) 29 | g = np.zeros([resolution, resolution, resolution, 3], dtype=np.float32) 30 | # with torch.no_grad(): 31 | for xi, xs in enumerate(X): 32 | for yi, ys in enumerate(Y): 33 | for zi, zs in enumerate(Z): 34 | xx, yy, zz = torch.meshgrid(xs, ys, zs) 35 | 36 | pts = torch.cat([xx.reshape(-1, 1), yy.reshape(-1, 1), zz.reshape(-1, 1)], dim=-1).cuda() 37 | 38 | grad = grad_func(pts).reshape(len(xs), len(ys), len(zs), 3).detach().cpu().numpy() 39 | val = query_func(pts).reshape(len(xs), len(ys), len(zs)).detach().cpu().numpy() 
40 | u[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = val 41 | g[xi * N: xi * N + len(xs), yi * N: yi * N + len(ys), zi * N: zi * N + len(zs)] = grad 42 | 43 | return u, g 44 | 45 | def extract_geometry(bound_min, bound_max, resolution, threshold, out_dir, iter_step, dataname, logger, query_func, grad_func): 46 | 47 | print('Extracting mesh with resolution: {}'.format(resolution)) 48 | u, g = extract_fields(bound_min, bound_max, resolution, query_func, grad_func) 49 | b_max = bound_max.detach().cpu().numpy() 50 | b_min = bound_min.detach().cpu().numpy() 51 | mesh = surface_extraction(u, g, out_dir, iter_step, b_max, b_min, resolution) 52 | 53 | return mesh 54 | 55 | class Runner: 56 | def __init__(self, args, conf_path): 57 | self.device = torch.device('cuda') 58 | 59 | # Configuration 60 | self.conf_path = conf_path 61 | f = open(self.conf_path) 62 | conf_text = f.read() 63 | f.close() 64 | 65 | self.conf = ConfigFactory.parse_string(conf_text) 66 | self.base_exp_dir = self.conf['general.base_exp_dir'] + args.dir 67 | os.makedirs(self.base_exp_dir, exist_ok=True) 68 | 69 | 70 | self.dataset = Dataset(self.conf['dataset'], args.dataname) 71 | self.dataname = args.dataname 72 | self.iter_step = 0 73 | 74 | # Training parameters 75 | self.step1_maxiter = self.conf.get_int('train.step1_maxiter') 76 | self.step2_maxiter = self.conf.get_int('train.step2_maxiter') 77 | self.save_freq = self.conf.get_int('train.save_freq') 78 | self.report_freq = self.conf.get_int('train.report_freq') 79 | self.val_freq = self.conf.get_int('train.val_freq') 80 | self.val_mesh_freq = self.conf.get_int('train.val_mesh_freq') 81 | self.batch_size = self.conf.get_int('train.batch_size') 82 | self.batch_size_step2 = self.conf.get_int('train.batch_size_step2') 83 | self.learning_rate = self.conf.get_float('train.learning_rate') 84 | self.warm_up_end = self.conf.get_float('train.warm_up_end', default=0.0) 85 | self.eval_num_points = 
self.conf.get_int('train.eval_num_points') 86 | self.df_filter = self.conf.get_float('train.df_filter') 87 | 88 | self.proj_weight = self.conf.get_float('train.proj_weight') 89 | self.surf_weight = self.conf.get_float('train.surf_weight') 90 | self.orth_weight = self.conf.get_float('train.orth_weight') 91 | self.proj_adapt = self.conf.get_float('train.proj_adapt') 92 | 93 | self.ChamferDisL1 = ChamferDistanceL1().cuda() 94 | self.ChamferDisL2 = ChamferDistanceL2().cuda() 95 | 96 | # Weights 97 | self.igr_weight = self.conf.get_float('train.igr_weight') 98 | self.mask_weight = self.conf.get_float('train.mask_weight') 99 | self.model_list = [] 100 | self.writer = None 101 | 102 | # Networks 103 | self.udf_network = LevelSetUDFNetwork(**self.conf['model.udf_network']).to(self.device) 104 | if self.conf.get_string('train.load_ckpt') != 'none': 105 | self.udf_network.load_state_dict(torch.load(self.conf.get_string('train.load_ckpt'), map_location=self.device)["udf_network_fine"]) 106 | 107 | self.optimizer = torch.optim.Adam(self.udf_network.parameters(), lr=self.learning_rate) 108 | 109 | # Backup codes and configs for debug 110 | self.file_backup() 111 | 112 | def train(self): 113 | timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime()) 114 | log_file = os.path.join(os.path.join(self.base_exp_dir), f'{timestamp}.log') 115 | logger = get_root_logger(log_file=log_file, name='outs') 116 | self.logger = logger 117 | batch_size = self.batch_size 118 | 119 | for iter_i in tqdm(range(self.iter_step, self.step1_maxiter)): 120 | self.update_learning_rate(self.iter_step) 121 | 122 | points, samples, point_gt = self.dataset.get_train_data(batch_size) 123 | 124 | samples.requires_grad = True 125 | gradients_sample = self.udf_network.gradient(samples).squeeze() # 5000x3 126 | udf_sample = self.udf_network.udf(samples) # 5000x1 127 | grad_norm = F.normalize(gradients_sample, dim=1) # 5000x3 128 | sample_moved = samples - grad_norm * udf_sample # 5000x3 129 | 130 | # 
--------------------levelset projection loss ------------------------- 131 | grad_moved = self.udf_network.gradient(sample_moved).squeeze() 132 | grad_moved_norm = F.normalize(grad_moved, dim=-1) 133 | consis_constraint = 1 - torch.abs(F.cosine_similarity(grad_moved_norm, grad_norm, dim=-1)) 134 | weight_moved = torch.exp(-self.proj_adapt * torch.abs(udf_sample)).reshape(-1,consis_constraint.shape[-1]) 135 | consis_constraint = consis_constraint * weight_moved 136 | loss_proj = consis_constraint.mean() * self.proj_weight 137 | 138 | # --------------------surface regularizer loss --------------------------- 139 | udf_surface = self.udf_network.udf(point_gt) 140 | loss_surf = torch.mean(udf_surface) * self.surf_weight 141 | 142 | 143 | #---------------------gradient-surface orthogonal loss ------------------ 144 | grad_gt_norm = F.normalize(samples - points, dim=1) 145 | loss_orth = (1 - torch.abs(F.cosine_similarity(grad_norm, grad_gt_norm, dim=1))).mean() * self.orth_weight 146 | 147 | loss_cd = self.ChamferDisL1(points.unsqueeze(0), sample_moved.unsqueeze(0)) 148 | 149 | loss = loss_cd + loss_proj + loss_surf + loss_orth 150 | 151 | self.optimizer.zero_grad() 152 | loss.backward() 153 | self.optimizer.step() 154 | 155 | self.iter_step += 1 156 | if self.iter_step % self.report_freq == 0: 157 | print_log('iter:{:>8d} cd_l1 = {} consis = {} surf = {} query_grad = {} lr={}'.format(self.iter_step, loss_cd, loss_proj, loss_surf, loss_orth, self.optimizer.param_groups[0]['lr']), logger=logger) 158 | 159 | if self.iter_step % self.save_freq == 0: 160 | self.save_checkpoint() 161 | 162 | if self.iter_step == self.step1_maxiter: 163 | _ = self.gen_extra_pointcloud(self.iter_step, 1) 164 | 165 | if self.iter_step % self.val_freq == 0: 166 | grad_surf = self.udf_network.gradient(point_gt) 167 | grad_surf_norm = F.normalize(grad_surf, dim=-1) 168 | out_dir_norm = os.path.join(self.base_exp_dir, 'normal') 169 | os.makedirs(out_dir_norm, exist_ok=True) 170 | 
np.save(os.path.join(out_dir_norm, 'normal_%d.npy' % self.iter_step), grad_surf_norm.detach().cpu().numpy()) 171 | 172 | if self.iter_step % self.val_freq == 0 and self.iter_step != self.step1_maxiter: 173 | self.extract_mesh(resolution=128, threshold=0.0, point_gt=point_gt, iter_step=self.iter_step, logger=logger) 174 | 175 | if self.iter_step == self.step1_maxiter: 176 | self.extract_mesh(resolution=args.mcube_resolution, threshold=0.0, point_gt=point_gt, iter_step=self.iter_step, logger=logger) 177 | 178 | 179 | def extract_mesh(self, resolution=64, threshold=0.0, point_gt=None, iter_step=0, logger=None): 180 | 181 | bound_min = torch.tensor(self.dataset.object_bbox_min, dtype=torch.float32) 182 | bound_max = torch.tensor(self.dataset.object_bbox_max, dtype=torch.float32) 183 | out_dir = os.path.join(self.base_exp_dir, 'mesh') 184 | os.makedirs(out_dir, exist_ok=True) 185 | 186 | mesh = extract_geometry(bound_min, bound_max, resolution=resolution, threshold=threshold, \ 187 | out_dir=out_dir, iter_step=iter_step, dataname=self.dataname, logger=logger, \ 188 | query_func=lambda pts: self.udf_network.udf(pts), grad_func=lambda pts: self.udf_network.gradient(pts)) 189 | if self.conf.get_float('train.far') > 0: 190 | mesh = remove_far(point_gt.detach().cpu().numpy(), mesh, self.conf.get_float('train.far')) 191 | 192 | mesh.export(out_dir+'/'+str(iter_step)+'_mesh.obj') 193 | 194 | 195 | 196 | def gen_extra_pointcloud(self, iter_step, low_range): 197 | 198 | res = [] 199 | num_points = self.eval_num_points 200 | gen_nums = 0 201 | 202 | os.makedirs(os.path.join(self.base_exp_dir, 'densepoints'), exist_ok=True) 203 | 204 | while gen_nums < num_points: 205 | 206 | points, samples, point_gt = self.dataset.get_train_data(5000) 207 | offsets = samples - points 208 | std = torch.std(offsets) 209 | 210 | extra_std = std * low_range 211 | rands = torch.normal(0.0, extra_std, size=points.shape) 212 | samples = points + torch.tensor(rands).cuda().float() 213 | 214 | 
samples.requires_grad = True 215 | gradients_sample = self.udf_network.gradient(samples).squeeze() # 5000x3 216 | udf_sample = self.udf_network.udf(samples) # 5000x1 217 | grad_norm = F.normalize(gradients_sample, dim=1) # 5000x3 218 | sample_moved = samples - grad_norm * udf_sample # 5000x3 219 | 220 | index = udf_sample < self.df_filter 221 | index = index.squeeze(1) 222 | sample_moved = sample_moved[index] 223 | 224 | gen_nums += sample_moved.shape[0] 225 | 226 | res.append(sample_moved.detach().cpu().numpy()) 227 | 228 | res = np.concatenate(res) 229 | res = res[:num_points] 230 | np.savetxt(os.path.join(self.base_exp_dir, 'densepoints', 'densepoints_%d.xyz'%(iter_step)), res) 231 | 232 | res = remove_outlier(point_gt.detach().cpu().numpy(), res, dis_trunc=self.conf.get_float('train.outlier')) 233 | return res 234 | 235 | def update_learning_rate(self, iter_step): 236 | 237 | warm_up = self.warm_up_end 238 | max_iter = self.step1_maxiter 239 | init_lr = self.learning_rate 240 | lr = (iter_step / warm_up) if iter_step < warm_up else 0.5 * (math.cos((iter_step - warm_up)/(max_iter - warm_up) * math.pi) + 1) 241 | lr = lr * init_lr 242 | 243 | for g in self.optimizer.param_groups: 244 | g['lr'] = lr 245 | 246 | def file_backup(self): 247 | dir_lis = self.conf['general.recording'] 248 | os.makedirs(os.path.join(self.base_exp_dir, 'recording'), exist_ok=True) 249 | for dir_name in dir_lis: 250 | cur_dir = os.path.join(self.base_exp_dir, 'recording', dir_name) 251 | os.makedirs(cur_dir, exist_ok=True) 252 | files = os.listdir(dir_name) 253 | for f_name in files: 254 | if f_name[-3:] == '.py': 255 | copyfile(os.path.join(dir_name, f_name), os.path.join(cur_dir, f_name)) 256 | 257 | copyfile(self.conf_path, os.path.join(self.base_exp_dir, 'recording', 'config.conf')) 258 | 259 | def load_checkpoint(self, checkpoint_name): 260 | checkpoint = torch.load(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name), map_location=self.device) 261 | 
print(os.path.join(self.base_exp_dir, 'checkpoints', checkpoint_name)) 262 | self.udf_network.load_state_dict(checkpoint['udf_network_fine']) 263 | 264 | self.iter_step = checkpoint['iter_step'] 265 | 266 | def save_checkpoint(self): 267 | checkpoint = { 268 | 'udf_network_fine': self.udf_network.state_dict(), 269 | 'iter_step': self.iter_step, 270 | } 271 | os.makedirs(os.path.join(self.base_exp_dir, 'checkpoints'), exist_ok=True) 272 | torch.save(checkpoint, os.path.join(self.base_exp_dir, 'checkpoints', 'ckpt_{:0>6d}.pth'.format(self.iter_step))) 273 | 274 | 275 | if __name__ == '__main__': 276 | torch.set_default_tensor_type('torch.cuda.FloatTensor') 277 | parser = argparse.ArgumentParser() 278 | parser.add_argument('--conf', type=str, default='./confs/ndf.conf') 279 | parser.add_argument('--mcube_resolution', type=int, default=256) 280 | parser.add_argument('--gpu', type=int, default=0) 281 | parser.add_argument('--dir', type=str, default='test') 282 | parser.add_argument('--dataname', type=str, default='demo') 283 | args = parser.parse_args() 284 | 285 | torch.cuda.set_device(args.gpu) 286 | runner = Runner(args, args.conf) 287 | 288 | runner.train() 289 | -------------------------------------------------------------------------------- /tools/logger.py: -------------------------------------------------------------------------------- 1 | import logging 2 | import torch.distributed as dist 3 | 4 | logger_initialized = {} 5 | 6 | def get_root_logger(log_file=None, log_level=logging.INFO, name='main'): 7 | """Get root logger and add a keyword filter to it. 8 | The logger will be initialized if it has not been initialized. By default a 9 | StreamHandler will be added. If `log_file` is specified, a FileHandler will 10 | also be added. The name of the root logger is the top-level package name, 11 | e.g., "mmdet3d". 12 | Args: 13 | log_file (str, optional): File path of log. Defaults to None. 14 | log_level (int, optional): The level of logger. 
15 | Defaults to logging.INFO. 16 | name (str, optional): The name of the root logger, also used as a 17 | filter keyword. Defaults to 'main'. 18 | Returns: 19 | :obj:`logging.Logger`: The obtained logger 20 | """ 21 | logger = get_logger(name=name, log_file=log_file, log_level=log_level) 22 | # build a logging filter keyed on the logger name (note: it is never attached to the logger) 23 | logging_filter = logging.Filter(name) 24 | logging_filter.filter = lambda record: record.name.find(name) != -1 25 | 26 | return logger 27 | 28 | 29 | def get_logger(name, log_file=None, log_level=logging.INFO, file_mode='w'): 30 | """Initialize and get a logger by name. 31 | If the logger has not been initialized, this method will initialize the 32 | logger by adding one or two handlers, otherwise the initialized logger will 33 | be directly returned. During initialization, a StreamHandler will always be 34 | added. If `log_file` is specified and the process rank is 0, a FileHandler 35 | will also be added. 36 | Args: 37 | name (str): Logger name. 38 | log_file (str | None): The log filename. If specified, a FileHandler 39 | will be added to the logger. 40 | log_level (int): The logger level. Note that only the process of 41 | rank 0 is affected, and other processes will set the level to 42 | "Error" thus be silent most of the time. 43 | file_mode (str): The file mode used in opening log file. 44 | Defaults to 'w'. 45 | Returns: 46 | logging.Logger: The expected logger. 47 | """ 48 | logger = logging.getLogger(name) 49 | if name in logger_initialized: 50 | return logger 51 | # handle hierarchical names 52 | # e.g., logger "a" is initialized, then logger "a.b" will skip the 53 | # initialization since it is a child of "a". 54 | for logger_name in logger_initialized: 55 | if name.startswith(logger_name): 56 | return logger 57 | 58 | # handle duplicate logs to the console 59 | # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler (NOTSET) 60 | # to the root logger. 
As logger.propagate is True by default, this root 61 | # level handler causes logging messages from rank>0 processes to 62 | # unexpectedly show up on the console, creating much unwanted clutter. 63 | # To fix this issue, we set the root logger's StreamHandler, if any, to log 64 | # at the ERROR level. 65 | for handler in logger.root.handlers: 66 | if type(handler) is logging.StreamHandler: 67 | handler.setLevel(logging.ERROR) 68 | 69 | stream_handler = logging.StreamHandler() 70 | handlers = [stream_handler] 71 | 72 | if dist.is_available() and dist.is_initialized(): 73 | rank = dist.get_rank() 74 | else: 75 | rank = 0 76 | 77 | # only rank 0 will add a FileHandler 78 | if rank == 0 and log_file is not None: 79 | # Here, the default behaviour of the official logger is 'a'. Thus, we 80 | # provide an interface to change the file mode to the default 81 | # behaviour. 82 | file_handler = logging.FileHandler(log_file, file_mode) 83 | handlers.append(file_handler) 84 | 85 | formatter = logging.Formatter( 86 | '%(asctime)s - %(name)s - %(levelname)s - %(message)s') 87 | for handler in handlers: 88 | handler.setFormatter(formatter) 89 | handler.setLevel(log_level) 90 | logger.addHandler(handler) 91 | 92 | if rank == 0: 93 | logger.setLevel(log_level) 94 | else: 95 | logger.setLevel(logging.ERROR) 96 | 97 | logger_initialized[name] = True 98 | 99 | 100 | return logger 101 | 102 | 103 | def print_log(msg, logger=None, level=logging.INFO): 104 | """Print a log message. 105 | Args: 106 | msg (str): The message to be logged. 107 | logger (logging.Logger | str | None): The logger to be used. 108 | Some special loggers are: 109 | - "silent": no message will be printed. 110 | - other str: the logger obtained with `get_root_logger(logger)`. 111 | - None: The `print()` method will be used to print log messages. 112 | level (int): Logging level. Only available when `logger` is a Logger 113 | object or "root". 
114 | """ 115 | if logger is None: 116 | print(msg) 117 | elif isinstance(logger, logging.Logger): 118 | logger.log(level, msg) 119 | elif logger == 'silent': 120 | pass 121 | elif isinstance(logger, str): 122 | _logger = get_logger(logger) 123 | _logger.log(level, msg) 124 | else: 125 | raise TypeError( 126 | 'logger should be either a logging.Logger object, str, ' 127 | f'"silent" or None, but got {type(logger)}') -------------------------------------------------------------------------------- /tools/surface_extraction.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import mcubes 3 | import trimesh 4 | import torch 5 | import sys 6 | from extensions.chamfer_dist import ChamferDistanceL2 7 | from tools.logger import print_log 8 | 9 | def as_mesh(scene_or_mesh): 10 | """ 11 | Convert a possible scene to a mesh. 12 | 13 | If conversion occurs, the returned mesh has only vertex and face data. 14 | Suggested by https://github.com/mikedh/trimesh/issues/507 15 | """ 16 | if isinstance(scene_or_mesh, trimesh.Scene): 17 | if len(scene_or_mesh.geometry) == 0: 18 | mesh = None # empty scene 19 | else: 20 | # we lose texture information here 21 | mesh = trimesh.util.concatenate( 22 | tuple(trimesh.Trimesh(vertices=g.vertices, faces=g.faces) 23 | for g in scene_or_mesh.geometry.values())) 24 | else: 25 | assert(isinstance(scene_or_mesh, trimesh.Trimesh)) 26 | mesh = scene_or_mesh 27 | return mesh 28 | 29 | def surface_extraction(ndf, grad, out_path, iter_step, b_max, b_min, resolution): 30 | v_all = [] 31 | t_all = [] 32 | threshold = 0.005 # accelerate extraction 33 | v_num = 0 34 | for i in range(resolution-1): 35 | for j in range(resolution-1): 36 | for k in range(resolution-1): 37 | ndf_loc = ndf[i:i+2] 38 | ndf_loc = ndf_loc[:,j:j+2,:] 39 | ndf_loc = ndf_loc[:,:,k:k+2] 40 | if np.min(ndf_loc) > threshold: 41 | continue 42 | grad_loc = grad[i:i+2] 43 | grad_loc = grad_loc[:,j:j+2,:] 44 | grad_loc = 
grad_loc[:,:,k:k+2] 45 | 46 | res = np.ones((2,2,2)) 47 | for ii in range(2): 48 | for jj in range(2): 49 | for kk in range(2): 50 | if np.dot(grad_loc[0][0][0], grad_loc[ii][jj][kk]) < 0: 51 | res[ii][jj][kk] = -ndf_loc[ii][jj][kk] 52 | else: 53 | res[ii][jj][kk] = ndf_loc[ii][jj][kk] 54 | 55 | if res.min()<0: 56 | vertices, triangles = mcubes.marching_cubes( 57 | res, 0.0) 58 | # print(vertices) 59 | # vertices -= 1.5 60 | # vertices /= 128 61 | vertices[:,0] += i #/ resolution 62 | vertices[:,1] += j #/ resolution 63 | vertices[:,2] += k #/ resolution 64 | triangles += v_num 65 | # vertices = 66 | # vertices[:,1] /= 3 # TODO 67 | v_all.append(vertices) 68 | t_all.append(triangles) 69 | 70 | v_num += vertices.shape[0] 71 | # print(v_num) 72 | 73 | v_all = np.concatenate(v_all) 74 | t_all = np.concatenate(t_all) 75 | # Create mesh 76 | v_all = v_all / (resolution - 1.0) * (b_max - b_min)[None, :] + b_min[None, :] 77 | 78 | mesh = trimesh.Trimesh(v_all, t_all, process=False) 79 | 80 | return mesh 81 | -------------------------------------------------------------------------------- /tools/utils.py: -------------------------------------------------------------------------------- 1 | # -*- coding: utf-8 -*- 2 | 3 | from random import sample 4 | import time 5 | 6 | from shutil import copyfile 7 | import numpy as np 8 | import trimesh 9 | 10 | from scipy.spatial import cKDTree 11 | 12 | 13 | def get_aver(distances, face): 14 | return (distances[face[0]] + distances[face[1]] + distances[face[2]]) / 3.0 15 | 16 | def remove_far(gt_pts, mesh, dis_trunc=0.1, is_use_prj=False): 17 | # gt_pts: trimesh 18 | # mesh: trimesh 19 | 20 | gt_kd_tree = cKDTree(gt_pts) 21 | distances, vertex_ids = gt_kd_tree.query(mesh.vertices, p=2, distance_upper_bound=dis_trunc) 22 | faces_remaining = [] 23 | faces = mesh.faces 24 | 25 | if is_use_prj: 26 | normals = gt_pts.vertex_normals 27 | closest_points = gt_pts.vertices[vertex_ids] 28 | closest_normals = 
normals[vertex_ids] 29 | direction_from_surface = mesh.vertices - closest_points 30 | distances = direction_from_surface * closest_normals 31 | distances = np.sum(distances, axis=1) 32 | 33 | for i in range(faces.shape[0]): 34 | if get_aver(distances, faces[i]) < dis_trunc: 35 | faces_remaining.append(faces[i]) 36 | mesh_cleaned = mesh.copy() 37 | mesh_cleaned.faces = faces_remaining 38 | mesh_cleaned.remove_unreferenced_vertices() 39 | 40 | return mesh_cleaned 41 | 42 | def remove_outlier(gt_pts, q_pts, dis_trunc=0.003, is_use_prj=False): 43 | # gt_pts: trimesh 44 | # mesh: trimesh 45 | 46 | gt_kd_tree = cKDTree(gt_pts) 47 | distances, q_ids = gt_kd_tree.query(q_pts, p=2, distance_upper_bound=dis_trunc) 48 | 49 | q_pts = q_pts[distances