├── README.md
├── dataset.py
├── log
│   └── 001
│       └── ckpts
│           └── ckpt_900.pt
├── net
│   ├── CMG_Net.py
│   ├── decode.py
│   ├── local_feature.py
│   └── utils.py
├── test.py
├── train.py
└── utils
    └── misc.py

/README.md:
--------------------------------------------------------------------------------
1 | # CMG-Net: Robust Normal Estimation for Point Clouds via Chamfer Normal Distance and Multi-scale Geometry (AAAI 2024)
2 | 
3 | ### [ArXiv](https://arxiv.org/abs/2312.09154)
4 | 
5 | This work presents an accurate and robust method for estimating normals from point clouds. In contrast to previous approaches, which directly minimize the deviation between annotated and predicted normals and therefore suffer from direction inconsistency, we first propose a new metric termed Chamfer Normal Distance to address this issue. This not only mitigates the problem but also facilitates network training and substantially improves the robustness of the network against noise. We then devise an architecture that combines Multi-scale Local Feature Aggregation and Hierarchical Geometric Information Fusion. This design enables the network to capture intricate geometric details more effectively and alleviates the ambiguity of scale selection. Extensive experiments demonstrate that our method achieves state-of-the-art performance on both synthetic and real-world datasets, particularly in scenarios contaminated by noise. This project is the PyTorch implementation of CMG-Net.
6 | 
7 | ## Requirements
8 | The code was developed and tested with the following environment:
9 | - Ubuntu 20.04
10 | - CUDA 11.3
11 | - Python 3.8
12 | - PyTorch 1.12
13 | - NumPy 1.24
14 | - SciPy 1.10
15 | 
16 | ## Dataset
17 | We train our network on the [PCPNet](http://geometry.cs.ucl.ac.uk/projects/2018/pcpnet/pclouds.zip) dataset.
18 | Download the dataset to the folder `***/dataset/` and copy the list files into the folder `***/dataset/PCPNet/list`.
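For noisy training shapes (names containing `_noise_white_`), `dataset.py` does not take the target normal from the same index of the noisy shape. Instead, `PatchDataset.make_patch` calls `kdtree_clean.query(seed_pnts)` and uses the normal of the nearest point in the corresponding clean point cloud. The default `normal_loss_type='sin'` branch of `Network.get_loss` then measures `|n_pred × n_target|`, which is identical for `n_target` and `-n_target`. This appears to be how the Chamfer Normal Distance idea enters the code, although the paper remains the authoritative definition. The NumPy/SciPy sketch below illustrates the two steps on a toy plane; the helper names `nearest_clean_normals` and `sin_loss` are hypothetical, and the sketch omits the PCA and QSTN rotations and the 0.1 loss weight used in the repository.

```python
import numpy as np
from scipy.spatial import cKDTree


def nearest_clean_normals(noisy_pts, clean_pts, clean_normals):
    """Hypothetical helper: the target normal of each noisy point is the normal
    of its nearest clean point (the lookup make_patch does for a single seed)."""
    tree = cKDTree(clean_pts)
    _, idx = tree.query(noisy_pts, k=1)  # idx has shape (N,)
    return clean_normals[idx]


def sin_loss(pred, target):
    """Hypothetical helper: mean |pred x target| for unit vectors, i.e. the sine
    of the angle between them, which is unchanged when either vector is flipped."""
    pred = pred / np.linalg.norm(pred, axis=1, keepdims=True)
    target = target / np.linalg.norm(target, axis=1, keepdims=True)
    return np.linalg.norm(np.cross(pred, target), axis=1).mean()


# Toy data: points on the plane z = 0, whose unoriented normal is (0, 0, 1).
rng = np.random.default_rng(0)
clean = np.column_stack([rng.uniform(-1, 1, 200), rng.uniform(-1, 1, 200), np.zeros(200)])
normals = np.tile([0.0, 0.0, 1.0], (200, 1))
noisy = clean + rng.normal(scale=0.01, size=clean.shape)

targets = nearest_clean_normals(noisy, clean, normals)
print(sin_loss(-targets, targets))  # a flipped prediction costs nothing: 0.0
```

In the repository the lookup runs once per query point inside `PatchDataset.__getitem__`, so targets are produced on the fly; precomputing them for a whole shape, as above, is only a simplification for illustration.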
The dataset is organized as follows:
19 | ```
20 | │dataset/
21 | ├──PCPNet/
22 | │ ├── list
23 | │ ├── ***.txt
24 | │ ├── ***.xyz
25 | │ ├── ***.normals
26 | │ ├── ***.pidx
27 | ```
28 | 
29 | ## Train
30 | Our trained model is provided at `./log/001/ckpts/ckpt_900.pt`.
31 | To train a new model on the PCPNet dataset, simply run:
32 | ```
33 | python train.py
34 | ```
35 | Your trained model will be saved in `./log/***/ckpts/`.
36 | 
37 | ## Test
38 | You can use the provided model for testing:
39 | ```
40 | python test.py
41 | ```
42 | The evaluation results will be saved in `./log/001/results_PCPNet/ckpt_900/`.
43 | 
44 | To test your own trained model, change the following variables in `test.py`:
45 | ```
46 | ckpt_dirs
47 | ckpt_iter
48 | ```
49 | To save the estimated normals of the input point clouds, change the following variables in `test.py`:
50 | ```
51 | save_pn = True          # save the point normals as '.normals' files
52 | sparse_patches = False  # whether to output normals only for the sparse point subset listed in the '.pidx' files
53 | ```
54 | 
55 | ## Acknowledgement
56 | The code is heavily based on [HSurf-Net](https://github.com/LeoQLi/HSurf-Net).
57 | If you find our work useful in your research, please cite the following papers: 58 | 59 | ``` 60 | @inproceedings{wu2024cmg, 61 | title={CMG-Net: Robust Normal Estimation for Point Clouds via Chamfer Normal Distance and Multi-scale Geometry}, 62 | author={Wu, Yingrui and Zhao, Mingyang and Li, Keqiang and Quan, Weize and Yu, Tianqi and Yang, Jianfeng and Jia, Xiaohong and Yan, Dong-Ming}, 63 | booktitle={Proceedings of the AAAI conference on artificial intelligence}, 64 | year={2024} 65 | } 66 | 67 | @inproceedings{ben2020deepfit, 68 | title={Deepfit: 3d surface fitting via neural network weighted least squares}, 69 | author={Ben-Shabat, Yizhak and Gould, Stephen}, 70 | booktitle={European conference on computer vision}, 71 | pages={20--34}, 72 | year={2020}, 73 | organization={Springer} 74 | } 75 | 76 | @article{li2022hsurf, 77 | title={HSurf-Net: Normal estimation for 3D point clouds by learning hyper surfaces}, 78 | author={Li, Qing and Liu, Yu-Shen and Cheng, Jin-San and Wang, Cheng and Fang, Yi and Han, Zhizhong}, 79 | journal={Advances in Neural Information Processing Systems}, 80 | volume={35}, 81 | pages={4218--4230}, 82 | year={2022} 83 | } 84 | ``` 85 | 86 | -------------------------------------------------------------------------------- /dataset.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import numpy as np 3 | from tqdm.auto import tqdm 4 | import scipy.spatial as spatial 5 | import torch 6 | from torch.utils.data import Dataset 7 | 8 | # All shapes of PCPNet dataset 9 | all_train_sets = ['fandisk100k', 'bunny100k', 'armadillo100k', 'dragon_xyzrgb100k', 'boxunion_uniform100k', 10 | 'tortuga100k', 'flower100k', 'Cup33100k'] 11 | all_test_sets = ['galera100k', 'icosahedron100k', 'netsuke100k', 'Cup34100k', 'sphere100k', 12 | 'cylinder100k', 'star_smooth100k', 'star_halfsmooth100k', 'star_sharp100k', 'Liberty100k', 13 | 'boxunion2100k', 'pipe100k', 'pipe_curve100k', 'column100k', 
'column_head100k', 14 | 'Boxy_smooth100k', 'sphere_analytic100k', 'cylinder_analytic100k', 'sheet_analytic100k'] 15 | all_val_sets = ['cylinder100k', 'galera100k', 'netsuke100k'] 16 | 17 | def load_data(filedir, filename, dtype=np.float32, wo=False): 18 | d = None 19 | filepath = os.path.join(filedir, 'npy', filename + '.npy') 20 | os.makedirs(os.path.join(filedir, 'npy'), exist_ok=True) 21 | if os.path.exists(filepath): 22 | if wo: 23 | return True 24 | d = np.load(filepath) 25 | else: 26 | d = np.loadtxt(os.path.join(filedir, filename), dtype=dtype) 27 | np.save(filepath, d) 28 | return d 29 | 30 | 31 | class PCATrans(object): 32 | def __init__(self): 33 | super().__init__() 34 | 35 | def __call__(self, data): 36 | # compute PCA of points in the patch, center the patch around the mean 37 | pts = data['pcl_pat'] 38 | pts_mean = pts.mean(0) 39 | pts = pts - pts_mean 40 | 41 | trans, _, _ = torch.svd(torch.t(pts)) # (3, 3) 42 | pts = torch.mm(pts, trans) 43 | 44 | # since the patch was originally centered, the original cp was at (0,0,0) 45 | cp_new = -pts_mean 46 | cp_new = torch.matmul(cp_new, trans) 47 | 48 | # re-center on original center point 49 | data['pcl_pat'] = pts - cp_new 50 | data['pca_trans'] = trans 51 | 52 | if 'center_normal' in data: 53 | data['center_normal'] = torch.matmul(data['center_normal'], trans) 54 | return data 55 | 56 | 57 | class SequentialPointcloudPatchSampler(torch.utils.data.sampler.Sampler): 58 | def __init__(self, data_source): 59 | self.data_source = data_source 60 | self.total_patch_count = sum(data_source.datasets.shape_patch_count) 61 | 62 | def __iter__(self): 63 | return iter(range(self.total_patch_count)) 64 | 65 | def __len__(self): 66 | return self.total_patch_count 67 | 68 | 69 | class RandomPointcloudPatchSampler(torch.utils.data.sampler.Sampler): 70 | # randomly get subset data from the whole dataset 71 | def __init__(self, data_source, patches_per_shape, seed=None, identical_epochs=False): 72 | self.data_source = 
data_source 73 | self.patches_per_shape = patches_per_shape 74 | self.seed = seed 75 | self.identical_epochs = identical_epochs 76 | 77 | if self.seed is None: 78 | self.seed = np.random.randint(0, 2**32, dtype=np.int64) # random seed in [0, 2**32 - 1] 79 | self.rng = np.random.RandomState(self.seed) 80 | 81 | self.total_patch_count = 0 82 | for shape_ind, _ in enumerate(data_source.datasets.shape_names): 83 | self.total_patch_count += min(self.patches_per_shape, data_source.datasets.shape_patch_count[shape_ind]) 84 | 85 | def __iter__(self): 86 | # optionally always pick the same permutation (mainly for debugging) 87 | if self.identical_epochs: 88 | self.rng.seed(self.seed) 89 | 90 | return iter(self.rng.choice(sum(self.data_source.datasets.shape_patch_count), size=self.total_patch_count, replace=False)) 91 | 92 | def __len__(self): 93 | return self.total_patch_count 94 | 95 | 96 | class PointCloudDataset(Dataset): 97 | def __init__(self, root, mode=None, data_set='', data_list='', sparse_patches=False): 98 | super().__init__() 99 | self.mode = mode 100 | self.data_set = data_set 101 | self.sparse_patches = sparse_patches 102 | self.data_dir = os.path.join(root, data_set) 103 | 104 | self.pointclouds = [] 105 | self.pointclouds_clean = [] 106 | self.shape_names = [] 107 | self.normals = [] 108 | self.pidxs = [] 109 | self.kdtrees = [] 110 | self.kdtrees_clean = [] 111 | self.shape_patch_count = [] # number of query points (patches) of each shape 112 | assert self.mode in ['train', 'val', 'test'] 113 | 114 | if len(data_list) > 0: 115 | # get all shape names 116 | cur_sets = [] 117 | with open(os.path.join(root, data_set, 'list', data_list)) as f: 118 | cur_sets = f.readlines() 119 | cur_sets = [x.strip() for x in cur_sets] 120 | cur_sets = list(filter(None, cur_sets)) 121 | 122 | print('Current %s dataset:' % self.mode) 123 | for s in cur_sets: 124 | print(' ', s) 125 | 126 | self.load_data(cur_sets) 127 | 128 | def load_data(self, cur_sets): 129 | for s in tqdm(cur_sets, desc='Loading data'): 130 | pcl = 
load_data(filedir=self.data_dir, filename='%s.xyz' % s, dtype=np.float32)[:, :3] 131 | 132 | if s.find('_noise_white_') == -1: 133 | s_clean = s 134 | else: 135 | s_clean = s.split('_noise_white_')[0] 136 | pcl_clean = load_data(filedir=self.data_dir, filename='%s.xyz' % s_clean, dtype=np.float32)[:, :3] 137 | 138 | nor = load_data(filedir=self.data_dir, filename=s_clean + '.normals', dtype=np.float32) 139 | 140 | self.pointclouds.append(pcl) 141 | self.pointclouds_clean.append(pcl_clean) 142 | self.normals.append(nor) 143 | self.shape_names.append(s) 144 | 145 | # KDTree construction may run out of recursions 146 | sys.setrecursionlimit(int(max(1000, round(pcl.shape[0]/10)))) 147 | kdtree = spatial.cKDTree(pcl, 10) 148 | self.kdtrees.append(kdtree) 149 | 150 | kdtree_clean = spatial.cKDTree(pcl_clean, 10) 151 | self.kdtrees_clean.append(kdtree_clean) 152 | 153 | if self.sparse_patches: 154 | pidx = load_data(filedir=self.data_dir, filename='%s.pidx' % s, dtype=np.int32) 155 | self.pidxs.append(pidx) 156 | self.shape_patch_count.append(len(pidx)) 157 | else: 158 | self.shape_patch_count.append(pcl.shape[0]) 159 | 160 | def __len__(self): 161 | return len(self.pointclouds) 162 | 163 | def __getitem__(self, idx): 164 | # KDTree uses a reference, not a copy of these points, 165 | # so modifying the points would make the kdtree give incorrect results! 
166 | data = { 167 | 'pcl': self.pointclouds[idx].copy(), 168 | 'pcl_clean': self.pointclouds_clean[idx].copy(), 169 | 'kdtree': self.kdtrees[idx], 170 | 'kdtree_clean': self.kdtrees_clean[idx], 171 | 'normal': self.normals[idx], 172 | 'pidx': self.pidxs[idx] if len(self.pidxs) > 0 else None, 173 | 'name': self.shape_names[idx], 174 | } 175 | return data 176 | 177 | 178 | class PatchDataset(Dataset): 179 | def __init__(self, datasets, patch_size=1, with_trans=True): 180 | super().__init__() 181 | self.datasets = datasets 182 | self.patch_size = patch_size 183 | self.trans = None 184 | if with_trans: 185 | self.trans = PCATrans() 186 | 187 | def __len__(self): 188 | return sum(self.datasets.shape_patch_count) 189 | 190 | def shape_index(self, index): 191 | """ 192 | Translate global (dataset-wide) point index to shape index & local (shape-wide) point index 193 | """ 194 | shape_patch_offset = 0 195 | shape_ind = None 196 | for shape_ind, shape_patch_count in enumerate(self.datasets.shape_patch_count): 197 | if index >= shape_patch_offset and index < shape_patch_offset + shape_patch_count: 198 | shape_patch_ind = index - shape_patch_offset # index in shape with ID shape_ind 199 | break 200 | shape_patch_offset = shape_patch_offset + shape_patch_count 201 | return shape_ind, shape_patch_ind 202 | 203 | def make_patch(self, pcl, kdtree=None, kdtree_clean=None, seed_idx=None, patch_size=1): 204 | """ 205 | Args: 206 | pcl: (N, 3), input (possibly noisy) point cloud 207 | kdtree: KD-tree built on pcl 208 | kdtree_clean: KD-tree built on the corresponding clean point cloud 209 | seed_idx: int, index of the query (patch center) point 210 | patch_size: K, number of points in the patch 211 | Returns: 212 | pcl_pat: (K, 3), patch centered at the seed and scaled by its radius; nor_idx: index of the nearest clean point, used to look up the target normal 213 | """ 214 | seed_pnts = pcl[seed_idx, :] 215 | dists, pat_idx = kdtree.query(seed_pnts, k=patch_size) # sorted by distance (nearest first) 216 | dist_max = max(dists) 217 | 218 | pcl_pat = pcl[pat_idx, :] # (K, 3) 219 | pcl_pat = pcl_pat - seed_pnts # center 220 | pcl_pat = pcl_pat / dist_max # normalize by the patch radius 221 | 222 | _, nor_idx = kdtree_clean.query(seed_pnts) 223 | 224 | return pcl_pat, nor_idx 225 | 226 | def
__getitem__(self, idx): 227 | """ 228 | Returns a patch centered at the point with the given global index 229 | and the ground truth normal of the patch center 230 | """ 231 | # find shape that contains the point with given global index 232 | shape_idx, patch_idx = self.shape_index(idx) 233 | shape_data = self.datasets[shape_idx] 234 | 235 | # get the query point 236 | if shape_data['pidx'] is None: 237 | center_point_idx = patch_idx 238 | else: 239 | center_point_idx = shape_data['pidx'][patch_idx] 240 | 241 | pcl_pat, normal_idx = self.make_patch(pcl=shape_data['pcl'], 242 | kdtree=shape_data['kdtree'], 243 | kdtree_clean=shape_data['kdtree_clean'], 244 | seed_idx=center_point_idx, 245 | patch_size=self.patch_size, 246 | ) 247 | data = { 248 | 'pcl_pat': torch.from_numpy(pcl_pat), 249 | 'center_normal': torch.from_numpy(shape_data['normal'][normal_idx, :]), 250 | 'name': shape_data['name'], 251 | } 252 | 253 | if self.trans is not None: 254 | data = self.trans(data) 255 | return data 256 | 257 | 258 | if __name__ == '__main__': 259 | root = './dataset/' 260 | data_set = 'PCPNet' 261 | data_list = 'testset_%s.txt' % data_set 262 | 263 | test_dset = PointCloudDataset( 264 | root=root, 265 | mode='test', 266 | data_set=data_set, 267 | data_list=data_list, 268 | ) 269 | test_set = PatchDataset( 270 | datasets=test_dset, 271 | patch_size=700, 272 | with_trans=True, # apply PCATrans to every patch 273 | ) 274 | test_dataloader = torch.utils.data.DataLoader( 275 | test_set, 276 | sampler=SequentialPointcloudPatchSampler(test_set), 277 | batch_size=10, 278 | num_workers=1, 279 | ) 280 | 281 | for batchind, data in enumerate(test_dataloader, 0): 282 | print(data['pcl_pat'].size()) 283 | 284 | 285 | 286 | 287 | -------------------------------------------------------------------------------- /log/001/ckpts/ckpt_900.pt: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/YingruiWoo/CMG-Net_Pytorch/8ba318062b0f4a72534160455231aae98f6e2dbc/log/001/ckpts/ckpt_900.pt -------------------------------------------------------------------------------- /net/CMG_Net.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import torch 3 | import torch.nn as nn 4 | import torch.nn.functional as F 5 | from .utils import knn_group_0, get_knn_idx 6 | from .local_feature import LocalFeature_Extraction, AdaptiveLayer 7 | from .decode import PosionFusion 8 | from .utils import HierarchicalLayer, Conv1D 9 | from .utils import batch_quat_to_rotmat 10 | 11 | class PointEncoder(nn.Module): 12 | def __init__(self, num_out=[], knn_l1=16, knn_l2=32, knn_h1=32, knn_h2=16, code_dim=128): 13 | super(PointEncoder, self).__init__() 14 | self.num_out = num_out 15 | 16 | self.stn = QSTN(num_points=700, dim=3, sym_op='max') 17 | 18 | self.encodeNet1 = LocalFeature_Extraction(num_convs=4, 19 | conv_channels=24, 20 | knn=knn_l1) 21 | 22 | self.encodeNet2 = LocalFeature_Extraction(num_convs=4, 23 | conv_channels=24, 24 | knn=knn_l2) 25 | 26 | dim_1 = self.encodeNet1.out_channels 27 | 28 | self.att_layer = AdaptiveLayer(dim_1) 29 | self.conv_1 = Conv1D(dim_1, 128) 30 | self.conv_2 = Conv1D(128, 256) 31 | 32 | self.knn_h1 = knn_h1 33 | self.knn_h2 = knn_h2 34 | 35 | self.shift_1 = HierarchicalLayer(self.num_out[0], 256, 256, with_fc=True, neighbor_feature=1) 36 | self.shift_2 = HierarchicalLayer(self.num_out[1], 256, 256, last_dim=128, with_last=True, with_fc=True, neighbor_feature=2) 37 | self.shift_3 = HierarchicalLayer(self.num_out[2], 256, 256, last_dim=128, with_last=True, with_fc=True, neighbor_feature=1) 38 | self.shift_4 = HierarchicalLayer(self.num_out[2], 256, 256, last_dim=128, with_last=True, with_fc=True, neighbor_feature=2) 39 | 40 | self.conv_3 = Conv1D(256, 256) 41 | self.conv_4 = Conv1D(256, code_dim) 42 | 43 | def forward(self, pos, knn_idx, knn_idx_l): 44 | """ 
45 | pos: (B, N, 3) 46 | knn_idx: (B, N, K) 47 | """ 48 | 49 | trans = self.stn(pos.transpose(2, 1)) 50 | pos = torch.bmm(pos, trans) 51 | 52 | ### Multi-scale Local Feature Aggregation 53 | y1 = self.encodeNet1(pos, knn_idx=knn_idx).transpose(2, 1) 54 | y2 = self.encodeNet2(pos, knn_idx=knn_idx_l).transpose(2, 1) 55 | y = self.att_layer(y1, y2) 56 | 57 | y = self.conv_1(y) 58 | y = self.conv_2(y) 59 | 60 | ### Hierarchical 61 | idx1 = get_knn_idx(pos, pos[:, :self.num_out[0], :], k=self.knn_h1, offset=1) 62 | y, global_1 = self.shift_1(y, knn_idx=idx1, pos=pos.transpose(2, 1)) 63 | idx2 = get_knn_idx(pos[:, :self.num_out[0], :], pos[:, :self.num_out[1], :], k=self.knn_h1, offset=1) 64 | y, global_2 = self.shift_2(y, knn_idx=idx2, pos=pos.transpose(2, 1), x_last=global_1) 65 | idx3 = get_knn_idx(pos[:, :self.num_out[1], :], pos[:, :self.num_out[2], :], k=self.knn_h2, offset=1) 66 | y, global_3 = self.shift_3(y, knn_idx=idx3, pos=pos.transpose(2, 1), x_last=global_2) 67 | idx4 = get_knn_idx(pos[:, :self.num_out[2], :], pos[:, :self.num_out[2], :], k=self.knn_h2, offset=1) 68 | y, global_4 = self.shift_4(y, knn_idx=idx4, pos=pos.transpose(2, 1), x_last=global_3) 69 | 70 | y = self.conv_3(y) + y 71 | y = self.conv_4(y) 72 | return y, trans, pos 73 | 74 | 75 | class Network(nn.Module): 76 | def __init__(self, num_in=1, knn_l1=16, knn_l2=32, knn_h1=16, knn_h2=32, knn_d=16): 77 | super(Network, self).__init__() 78 | self.num_in = num_in 79 | self.num_out = [num_in // 3 * 2, num_in // 3 * 2 // 3 * 2, num_in // 3 * 2 // 3 * 2 // 3 * 2] 80 | self.knn_l1 = knn_l1 81 | self.knn_l2 = knn_l2 82 | self.knn_h1 = knn_h1 83 | self.knn_h2 = knn_h2 84 | self.decode_knn = knn_d 85 | code_dim = 128 86 | 87 | self.pointEncoder = PointEncoder(num_out=self.num_out, knn_l1=self.knn_l1, knn_l2=self.knn_l2, 88 | knn_h1=self.knn_h1, knn_h2=self.knn_h2, code_dim=code_dim) 89 | 90 | pos_dim = 64 91 | self.out_dim = 128 92 | self.featDecoder = PosionFusion(in_dim=code_dim, 93 | pos_dim=pos_dim + 
3, 94 | out_dim=self.out_dim, 95 | hidden_size=128, 96 | num_blocks=3) 97 | 98 | self.mlp_pos = nn.Sequential( 99 | nn.Linear(3, 64), 100 | nn.ReLU(), 101 | nn.Linear(64, pos_dim), 102 | ) 103 | 104 | self.conv_1 = Conv1D(128, 128) 105 | self.conv_2 = Conv1D(128, 128) 106 | self.conv_w = nn.Conv1d(128, 1, 1) 107 | self.mlp_n = nn.Linear(128, 3) 108 | 109 | def forward(self, pos): 110 | """ 111 | pos: (B, N, 3) 112 | """ 113 | 114 | ### Encoder 115 | knn_idx = get_knn_idx(pos, pos, k=self.knn_l1+1) # (B, N, K+1) 116 | knn_idx_large = get_knn_idx(pos, pos, k=self.knn_l2+1) 117 | y, trans, pos = self.pointEncoder(pos, knn_idx=knn_idx[:,:,1:self.knn_l1+1], knn_idx_l=knn_idx_large[:,:,1:self.knn_l2+1]) # (B, C, n) 118 | B, Cy, _ = y.size() 119 | 120 | ### Position Embedding 121 | pos_sub = pos[:, :self.num_out[2], :] 122 | knn_idx = knn_idx[:, :self.num_out[2], :self.decode_knn] 123 | 124 | nn_pc = knn_group_0(pos, knn_idx) 125 | nn_pc = nn_pc - pos_sub.unsqueeze(2) 126 | 127 | nn_feat = self.mlp_pos(nn_pc) 128 | nn_feat = torch.cat([nn_pc, nn_feat], dim=-1) 129 | 130 | ### Position Fusion 131 | Cp = nn_feat.size()[-1] 132 | feat = self.featDecoder(x=nn_feat.view(B*self.num_out[2], self.decode_knn, Cp), 133 | c=y.transpose(2, 1).reshape(B*self.num_out[2], Cy), 134 | ) 135 | feat = feat.reshape(B, self.num_out[2], self.out_dim, self.decode_knn) 136 | feat = feat.max(dim=3, keepdim=False)[0] 137 | 138 | ### Weighted Output 139 | feat = self.conv_1(feat.transpose(2, 1)) 140 | weights = 0.01 + torch.sigmoid(self.conv_w(feat)) 141 | normal = self.mlp_n(self.conv_2(feat * weights).max(dim=2, keepdim=False)[0]) 142 | 143 | normal = F.normalize(normal, p=2, dim=-1) 144 | 145 | return normal, weights, trans 146 | 147 | def get_loss(self, q_target, q_pred, pred_weights=None, normal_loss_type='sin', pcl_in=None, trans=None): 148 | """ 149 | q_target: (B, 3) 150 | q_pred: (B, 3) 151 | pred_weights: (B, 1, N) 152 | pcl_in: (B, N, 3) 153 | trans: (B, 3, 3) 154 | """ 155 | def 
cos_angle(v1, v2): 156 | return torch.bmm(v1.unsqueeze(1), v2.unsqueeze(2)).view(-1) / torch.clamp(v1.norm(2, 1) * v2.norm(2, 1), min=0.000001) 157 | 158 | weight_loss = torch.zeros(1, device=q_pred.device, dtype=q_pred.dtype) 159 | 160 | ### query point normal 161 | o_pred = q_pred 162 | o_target = q_target 163 | 164 | o_pred = torch.bmm(o_pred.unsqueeze(1), trans.transpose(2, 1)).squeeze(1) 165 | 166 | if normal_loss_type == 'mse_loss': 167 | normal_loss = 0.5 * F.mse_loss(o_pred, o_target) 168 | elif normal_loss_type == 'ms_euclidean': 169 | normal_loss = 0.1 * torch.min((o_pred-o_target).pow(2).sum(1), (o_pred+o_target).pow(2).sum(1)).mean() 170 | elif normal_loss_type == 'ms_oneminuscos': 171 | cos_ang = cos_angle(o_pred, o_target) 172 | normal_loss = 1.0 * (1-torch.abs(cos_ang)).pow(2).mean() 173 | elif normal_loss_type == 'sin': 174 | normal_loss = 0.1 * torch.norm(torch.cross(o_pred, o_target, dim=-1), p=2, dim=1).mean() 175 | else: 176 | raise ValueError('Unsupported loss type: %s' % (normal_loss_type)) 177 | 178 | ### compute the true weight by fitting distance 179 | if pred_weights is not None: 180 | pred_weights = pred_weights.squeeze() # (B, 1, N) -> (B, N) 181 | pcl_in = pcl_in[:, :self.num_out[2], :] # points the predicted weights refer to 182 | thres_d = 0.05 * 0.05 183 | normal_dis = torch.bmm(o_target.unsqueeze(1), pcl_in.transpose(2, 1)).pow(2).squeeze() 184 | sigma = torch.mean(normal_dis, dim=1) * 0.3 + 1e-5 185 | threshold_matrix = torch.ones_like(sigma) * thres_d 186 | sigma = torch.where(sigma < thres_d, threshold_matrix, sigma) 187 | true_weight = torch.exp(-1 * torch.div(normal_dis, sigma.unsqueeze(-1))) 188 | 189 | weight_loss = (true_weight - pred_weights).pow(2).mean() 190 | 191 | regularizer_loss = 0.1 * torch.nn.MSELoss()(trans @ trans.permute(0, 2, 1), 192 | torch.eye(3, device=trans.device).unsqueeze(0).repeat( 193 | trans.size(0), 1, 1)) 194 | 195 | batch_size = trans.shape[0] 196 | z_vector = torch.from_numpy(np.array([0, 0, 1]).astype(np.float32)).squeeze().repeat(batch_size, 
1).to(trans.device) 197 | z_vector_rot = torch.bmm(z_vector.unsqueeze(1), trans.transpose(2, 1)).squeeze(1) 198 | z_vector_rot = F.normalize(z_vector_rot, dim=1) 199 | z_trans_loss = 0.5 * torch.norm(torch.cross(z_vector_rot, o_target, dim=-1), p=2, dim=1).mean() 200 | 201 | loss = normal_loss + weight_loss + regularizer_loss + z_trans_loss 202 | 203 | return loss, (normal_loss, weight_loss, regularizer_loss, z_trans_loss) 204 | 205 | class QSTN(nn.Module): 206 | def __init__(self, num_points=700, dim=3, sym_op='max'): 207 | super(QSTN, self).__init__() 208 | 209 | self.dim = dim 210 | self.sym_op = sym_op 211 | self.num_points = num_points 212 | 213 | self.conv1 = torch.nn.Conv1d(self.dim, 64, 1) 214 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 215 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 216 | self.mp1 = torch.nn.MaxPool1d(num_points) 217 | self.fc1 = nn.Linear(1024, 512) 218 | self.fc2 = nn.Linear(512, 256) 219 | self.fc3 = nn.Linear(256, 4) 220 | 221 | self.bn1 = nn.BatchNorm1d(64) 222 | self.bn2 = nn.BatchNorm1d(128) 223 | self.bn3 = nn.BatchNorm1d(1024) 224 | self.bn4 = nn.BatchNorm1d(512) 225 | self.bn5 = nn.BatchNorm1d(256) 226 | 227 | 228 | def forward(self, x): 229 | x = F.relu(self.bn1(self.conv1(x))) 230 | x = F.relu(self.bn2(self.conv2(x))) 231 | x = F.relu(self.bn3(self.conv3(x))) 232 | 233 | x = self.mp1(x) 234 | 235 | x = x.view(-1, 1024) 236 | 237 | 238 | x = F.relu(self.bn4(self.fc1(x))) 239 | x = F.relu(self.bn5(self.fc2(x))) 240 | x = self.fc3(x) 241 | 242 | iden = x.new_tensor([1, 0, 0, 0]) 243 | x = x + iden 244 | 245 | x = batch_quat_to_rotmat(x) 246 | 247 | return x 248 | 249 | 250 | -------------------------------------------------------------------------------- /net/decode.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | 4 | 5 | class ResnetBlock(nn.Module): 6 | def __init__(self, c_dim, size_in, size_h=None, size_out=None): 7 | """ 8 | size_in (int): 
input dimension 9 | size_h (int): hidden dimension 10 | size_out (int): output dimension 11 | """ 12 | super().__init__() 13 | if size_h is None: 14 | size_h = size_in 15 | if size_out is None: 16 | size_out = size_in 17 | 18 | self.mlp_1 = nn.Sequential( 19 | nn.BatchNorm1d(size_in), 20 | nn.ReLU(), 21 | nn.Conv1d(size_in, size_h, 1) 22 | ) 23 | self.mlp_2 = nn.Sequential( 24 | nn.BatchNorm1d(size_h), 25 | nn.ReLU(), 26 | nn.Conv1d(size_h, size_out, 1) 27 | ) 28 | 29 | self.fc_c = nn.Conv1d(c_dim, size_out, 1) 30 | 31 | if size_in == size_out: 32 | self.shortcut = None 33 | else: 34 | self.shortcut = nn.Conv1d(size_in, size_out, 1, bias=False) 35 | 36 | def forward(self, x, c): 37 | dx = self.mlp_1(x) 38 | dx = self.mlp_2(dx) 39 | 40 | if self.shortcut is not None: 41 | x_s = self.shortcut(x) 42 | else: 43 | x_s = x 44 | 45 | out = x_s + dx + self.fc_c(c) 46 | return out 47 | 48 | 49 | class PosionFusion(nn.Module): 50 | def __init__(self, in_dim, pos_dim, out_dim, hidden_size, num_blocks): 51 | """ 52 | in_dim: Dimension of context vectors 53 | pos_dim: Point dimension 54 | out_dim: Output dimension 55 | hidden_size: Hidden state dimension 56 | """ 57 | super().__init__() 58 | 59 | c_dim = in_dim + pos_dim 60 | self.conv_p = nn.Conv1d(c_dim, hidden_size, 1) 61 | self.blocks = nn.ModuleList([ 62 | ResnetBlock(c_dim, hidden_size) for _ in range(num_blocks) 63 | ]) 64 | self.mlp_out = nn.Sequential( 65 | nn.BatchNorm1d(hidden_size), 66 | nn.ReLU(), 67 | nn.Conv1d(hidden_size, out_dim, 1) 68 | ) 69 | 70 | def forward(self, x, c): 71 | """ 72 | x: (B, N, C) 73 | c: (B, in_dim), latent code 74 | """ 75 | x = x.transpose(2, 1) 76 | num_points = x.size(-1) 77 | c = c.unsqueeze(2).expand(-1, -1, num_points) 78 | 79 | xc = torch.cat([x, c], dim=1) 80 | net = self.conv_p(xc) 81 | 82 | for block in self.blocks: 83 | net = block(net, xc) 84 | 85 | out = self.mlp_out(net) 86 | return out 87 | 88 | -------------------------------------------------------------------------------- 
/net/local_feature.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | from .utils import knn_group_0, get_knn_idx 4 | from .utils import LinearLayer as FCLayer 5 | BN1d = 1 6 | BN2d = 2 7 | 8 | 9 | class Aggregator(nn.Module): 10 | def __init__(self, oper): 11 | super().__init__() 12 | assert oper in ('mean', 'sum', 'max') 13 | self.oper = oper 14 | 15 | def forward(self, x, dim=2): 16 | if self.oper == 'mean': 17 | return x.mean(dim=dim, keepdim=False) 18 | elif self.oper == 'sum': 19 | return x.sum(dim=dim, keepdim=False) 20 | elif self.oper == 'max': 21 | ret, _ = x.max(dim=dim, keepdim=False) 22 | return ret 23 | 24 | class AdaptiveLayer(nn.Module): 25 | def __init__(self, C, r=4): 26 | super(AdaptiveLayer, self).__init__() 27 | self.squeeze = nn.AdaptiveAvgPool1d(1) 28 | self.excitation = nn.Sequential( 29 | nn.Linear(C, C // r, bias=False), 30 | nn.ReLU(inplace=True), 31 | nn.Linear(C // r, C, bias=False), 32 | nn.Sigmoid() 33 | ) 34 | def forward(self, x1, x2): 35 | fea = x1 + x2 36 | b, C, _ = fea.shape 37 | out = self.squeeze(fea).view(b, C) 38 | out = self.excitation(out).view(b, C, 1) 39 | attention_vectors = out.expand_as(fea) 40 | fea_v = attention_vectors * x1 + (1 - attention_vectors) * x2 41 | return fea_v 42 | 43 | 44 | class GraphConv_L(nn.Module): 45 | def __init__(self, in_channels, num_fc_layers, growth_rate, knn=16, aggr='max', with_bn=BN2d, activation='relu', relative_feat_only=False): 46 | super().__init__() 47 | self.in_channels = in_channels 48 | self.knn = knn 49 | assert num_fc_layers > 2 50 | self.num_fc_layers = num_fc_layers 51 | self.growth_rate = growth_rate 52 | self.relative_feat_only = relative_feat_only 53 | 54 | if relative_feat_only: 55 | self.layer_first = FCLayer(in_channels+3, growth_rate, with_bn=with_bn, activation=activation) 56 | else: 57 | self.layer_first = FCLayer(in_channels*3, growth_rate, with_bn=with_bn, activation=activation) 58 | 
59 | self.layers_mid = nn.ModuleList() 60 | for i in range(1, num_fc_layers-1): 61 | self.layers_mid.append(FCLayer(in_channels + i * growth_rate, growth_rate, with_bn=with_bn, activation=activation)) 62 | 63 | self.layer_last = FCLayer(in_channels + (num_fc_layers - 1) * growth_rate, growth_rate, with_bn=False, activation=None) 64 | 65 | self.aggr = Aggregator(aggr) 66 | 67 | @property 68 | def out_channels(self): 69 | return self.in_channels + self.num_fc_layers * self.growth_rate 70 | 71 | def get_edge_feature(self, x, pos, knn_idx): 72 | """ 73 | :param x: (B, N, c) 74 | :param pos: (B, N, 3) 75 | :param knn_idx: (B, N, K) 76 | :return edge_feat: (B, N, K, C) 77 | """ 78 | knn_feat = knn_group_0(x, knn_idx) 79 | x_tiled = x.unsqueeze(-2).expand_as(knn_feat) 80 | if self.relative_feat_only: 81 | knn_pos = knn_group_0(pos, knn_idx) 82 | pos_tiled = pos.unsqueeze(-2) 83 | edge_feat = torch.cat([knn_pos - pos_tiled, knn_feat - x_tiled], dim=3) 84 | else: 85 | edge_feat = torch.cat([x_tiled, knn_feat, knn_feat - x_tiled], dim=3) 86 | return edge_feat 87 | 88 | def forward(self, x, pos, knn_idx=None): 89 | """ 90 | :param x: (B, N, x) 91 | pos: (B, N, y) 92 | :return y: (B, N, z) 93 | knn_idx: (B, N, K) 94 | """ 95 | if knn_idx is None: 96 | knn_idx = get_knn_idx(pos, pos, k=self.knn, offset=1) 97 | edge_feat = self.get_edge_feature(x, pos, knn_idx=knn_idx) 98 | 99 | ### First Layer 100 | y = torch.cat([ 101 | self.layer_first(edge_feat), 102 | x.unsqueeze(-2).repeat(1, 1, self.knn, 1) 103 | ], dim=-1) 104 | 105 | ### Intermediate Layers 106 | for layer in self.layers_mid: 107 | y = torch.cat([ 108 | layer(y), 109 | y 110 | ], dim=-1) 111 | 112 | ### Last Layer 113 | y = torch.cat([ 114 | self.layer_last(y), 115 | y 116 | ], dim=-1) 117 | 118 | ### Pooling Layer 119 | y = self.aggr(y, dim=-2) 120 | 121 | return y, knn_idx 122 | 123 | 124 | class LocalFeature_Extraction(nn.Module): 125 | def __init__(self, 126 | num_convs=4, 127 | in_channels=3, 128 | 
conv_channels=24, 129 | num_fc_layers=3, 130 | growth_rate=12, 131 | knn=16, 132 | aggr='max', 133 | activation='relu', 134 | ): 135 | super().__init__() 136 | self.num_convs = num_convs 137 | self.in_channels = in_channels 138 | 139 | self.trans = nn.ModuleList() 140 | self.convs = nn.ModuleList() 141 | for i in range(num_convs): 142 | tran = FCLayer(in_features=in_channels, out_features=conv_channels, with_bn=BN1d, activation=activation) 143 | conv = GraphConv_L( 144 | in_channels=conv_channels, 145 | num_fc_layers=num_fc_layers, 146 | growth_rate=growth_rate, 147 | knn=knn, 148 | aggr=aggr, 149 | activation=activation, 150 | relative_feat_only=(i == 0), 151 | ) 152 | self.trans.append(tran) 153 | self.convs.append(conv) 154 | in_channels = conv.out_channels 155 | 156 | @property 157 | def out_channels(self): 158 | return self.convs[-1].out_channels 159 | 160 | def forward(self, x, knn_idx=None): 161 | """ 162 | :param x: (B, N, 3+c) 163 | knn_idx: (B, N, K) 164 | :return y: (B, N, C), C = conv_channels+num_fc_layers*growth_rate 165 | """ 166 | pos = x[:,:,:3] 167 | for i in range(self.num_convs): 168 | x = self.trans[i](x) 169 | x, knn_idx = self.convs[i](x, pos=pos, knn_idx=knn_idx) 170 | return x -------------------------------------------------------------------------------- /net/utils.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | 5 | def square_distance(src, dst): 6 | """ 7 | Calculate Euclid distance between each two points. 
8 | src^T * dst = xn * xm + yn * ym + zn * zm; 9 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 10 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 11 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 12 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 13 | Input: 14 | src: source points, [B, M, C] 15 | dst: target points, [B, N, C] 16 | Output: 17 | dist: per-point square distance, [B, M, N] 18 | """ 19 | return torch.sum((src[:, :, None] - dst[:, None]) ** 2, dim=-1) 20 | 21 | def knn_group_0(x:torch.FloatTensor, idx:torch.LongTensor): 22 | """ 23 | :param x: (B, N, F) 24 | :param idx: (B, M, k) 25 | :return (B, M, k, F) 26 | """ 27 | B, N, F = tuple(x.size()) 28 | _, M, k = tuple(idx.size()) 29 | 30 | x = x.unsqueeze(1).expand(B, M, N, F) 31 | idx = idx.unsqueeze(3).expand(B, M, k, F) 32 | 33 | return torch.gather(x, dim=2, index=idx) 34 | 35 | def knn_group_1(x:torch.FloatTensor, idx:torch.LongTensor): 36 | """ 37 | :param x: (B, F, N) 38 | :param idx: (B, M, k) 39 | :return (B, F, M, k) 40 | """ 41 | B, F, N = tuple(x.size()) 42 | _, M, k = tuple(idx.size()) 43 | 44 | x = x.unsqueeze(2).expand(B, F, M, N) 45 | idx = idx.unsqueeze(1).expand(B, F, M, k) 46 | 47 | return torch.gather(x, dim=3, index=idx) 48 | 49 | def get_knn_idx(pos, query, k, offset=0): 50 | """ 51 | :param pos: (B, N, F) 52 | :param query: (B, M, F) 53 | :return knn_idx: (B, M, k - offset), indices of the points at sorted-distance positions [2*offset, k+offset); with offset=0 these are the k nearest points 54 | """ 55 | dists = square_distance(query, pos) 56 | dists, idx = dists.sort(dim=-1) 57 | dists, idx = dists[:, :, offset:k+offset], idx[:, :, offset:k+offset] 58 | return idx[:, :, offset:] 59 | 60 | 61 | class LinearLayer(torch.nn.Module): 62 | def __init__(self, in_features, out_features, with_bn=1, activation='relu'): 63 | super().__init__() 64 | assert with_bn in [0, 1, 2] 65 | self.with_bn = with_bn > 0 and activation is not None 66 | 67 | self.linear = nn.Linear(in_features, out_features) 68 | 69 | if self.with_bn: 70 | if with_bn == 2: 71 | self.bn = nn.BatchNorm2d(out_features) 72 | else: 73 | self.bn = 
nn.BatchNorm1d(out_features) 74 | 75 | if activation is None: 76 | self.activation = nn.Identity() 77 | elif activation == 'relu': 78 | self.activation = nn.ReLU(inplace=True) 79 | elif activation == 'elu': 80 | self.activation = nn.ELU(alpha=1.0) 81 | elif activation == 'lrelu': 82 | self.activation = nn.LeakyReLU(0.1) 83 | else: 84 | raise ValueError('Unsupported activation: %s' % activation) 85 | 86 | def forward(self, x): 87 | """ 88 | x: (*, C) 89 | y: (*, C) 90 | """ 91 | y = self.linear(x) 92 | if self.with_bn: 93 | if x.dim() == 2: # (B, C) 94 | y = self.activation(self.bn(y)) 95 | elif x.dim() == 3: # (B, N, C) 96 | y = self.activation(self.bn(y.transpose(1, 2))).transpose(1, 2) 97 | elif x.dim() == 4: # (B, H, W, C) 98 | y = self.activation(self.bn(y.permute(0, 3, 1, 2))).permute(0, 2, 3, 1) 99 | else: 100 | y = self.activation(y) 101 | return y 102 | 103 | 104 | class Conv1D(nn.Module): 105 | def __init__(self, input_dim, output_dim, with_bn=True, with_relu=True): 106 | super(Conv1D, self).__init__() 107 | self.with_bn = with_bn 108 | self.with_relu = with_relu 109 | self.conv = nn.Conv1d(input_dim, output_dim, 1) 110 | if with_bn: 111 | self.bn = nn.BatchNorm1d(output_dim) 112 | 113 | def forward(self, x): 114 | """ 115 | x: (B, C, N) 116 | """ 117 | if self.with_bn: 118 | x = self.bn(self.conv(x)) 119 | else: 120 | x = self.conv(x) 121 | 122 | if self.with_relu: 123 | x = F.relu(x) 124 | return x 125 | 126 | class Conv2D(nn.Module): 127 | def __init__(self, input_dim, output_dim, with_bn=True, with_relu=True): 128 | super(Conv2D, self).__init__() 129 | self.with_bn = with_bn 130 | self.with_relu = with_relu 131 | self.conv = nn.Conv2d(input_dim, output_dim, 1) 132 | if with_bn: 133 | self.bn = nn.BatchNorm2d(output_dim) 134 | 135 | def forward(self, x): 136 | """ 137 | x: (B, C, N, K) 138 | """ 139 | if self.with_bn: 140 | x = self.bn(self.conv(x)) 141 | else: 142 | x = self.conv(x) 143 | 144 | if self.with_relu: 145 | x = F.relu(x) 146 | return x 147 | 148 | class FC(nn.Module): 149 | def
__init__(self, input_dim, output_dim): 150 | super(FC, self).__init__() 151 | self.fc = nn.Linear(input_dim, output_dim) 152 | self.bn = nn.BatchNorm1d(output_dim) 153 | 154 | def forward(self, x): 155 | """ 156 | x: (B, C) 157 | """ 158 | x = F.relu(self.bn(self.fc(x))) 159 | return x 160 | 161 | class GraphConv_H(nn.Module): 162 | def __init__(self, in_channels, output_scale, neighbor_feature): 163 | super().__init__() 164 | self.in_channels = in_channels 165 | self.output_scale = output_scale 166 | self.neighbor_feature = neighbor_feature 167 | 168 | if self.neighbor_feature == 1: 169 | self.conv1 = Conv2D(3, 64, with_bn=True, with_relu=True) 170 | self.conv2 = Conv2D(64, 64, with_bn=True, with_relu=True) 171 | if self.neighbor_feature == 1: 172 | self.graph_conv = Conv2D(3 * 2 + 64, 256, with_bn=True, with_relu=True) 173 | if self.neighbor_feature == 2: 174 | self.graph_conv = Conv2D(3 * 2 + 256, 256, with_bn=True, with_relu=True) 175 | 176 | def get_edge_feature(self, x, pos, knn_idx): 177 | """ 178 | :param x: (B, C, N) 179 | :param pos: (B, 3, N) 180 | :param knn_idx: (B, M, K), M = output_scale 181 | :return edge_feat: (B, 3+3+C', M, K), C' = 64 if neighbor_feature == 1 else C 182 | """ 183 | 184 | knn_pos = knn_group_1(pos, knn_idx) # (B, 3, M, K) 185 | pos_tiled = pos[:, :, :self.output_scale].unsqueeze(-1).expand_as(knn_pos) 186 | 187 | knn_pos = knn_pos - pos_tiled 188 | knn_dist = torch.sum(knn_pos ** 2, dim=1, keepdim=True) 189 | knn_r = torch.sqrt(knn_dist.max(dim=3, keepdim=True)[0]) 190 | knn_pos = knn_pos / knn_r.expand_as(knn_pos) 191 | 192 | if self.neighbor_feature == 1: 193 | knn_x = self.conv1(knn_pos) 194 | knn_x = self.conv2(knn_x) + knn_x 195 | if self.neighbor_feature == 2: 196 | knn_x = knn_group_1(x, knn_idx) 197 | x_tiled = x[:, :, :self.output_scale].unsqueeze(-1).expand_as(knn_x) 198 | 199 | knn_x = knn_x - x_tiled 200 | knn_xdist = torch.sum(knn_x ** 2, dim=1, keepdim=True) 201 | knn_xr = torch.sqrt(knn_xdist.max(dim=3, keepdim=True)[0]) 202 | knn_x = knn_x / knn_xr.expand_as(knn_x) 203 | 204 |
edge_pos = torch.cat([pos_tiled, knn_pos, knn_x], dim=1) 205 | return edge_pos 206 | 207 | def forward(self, x, pos, knn_idx): 208 | """ 209 | :param x: (B, C, N) 210 | pos: (B, 3, N) 211 | :return y: (B, 256, M), M = output_scale 212 | knn_idx: (B, M, K) 213 | """ 214 | 215 | edge_pos = self.get_edge_feature(x, pos, knn_idx=knn_idx) 216 | 217 | y = self.graph_conv(edge_pos) 218 | 219 | y_global = y.max(dim=3, keepdim=False)[0] 220 | 221 | return y_global 222 | 223 | class HierarchicalLayer(nn.Module): 224 | def __init__(self, output_scale, input_dim, output_dim, last_dim=0, with_last=False, with_fc=True, neighbor_feature=False): 225 | super(HierarchicalLayer, self).__init__() 226 | self.output_scale = output_scale 227 | self.input_dim = input_dim 228 | self.output_dim = output_dim 229 | self.with_last = with_last 230 | self.with_fc = with_fc 231 | self.neighbor_feature = neighbor_feature 232 | 233 | self.conv_in = Conv1D(input_dim, input_dim*2, with_bn=True, with_relu=with_fc) 234 | 235 | if with_fc: 236 | self.fc = FC(input_dim*2, input_dim//2) 237 | if with_last: 238 | self.conv_out = Conv1D(input_dim + input_dim//2 + last_dim, output_dim, with_bn=True) 239 | else: 240 | self.conv_out = Conv1D(input_dim + input_dim//2, output_dim, with_bn=True) 241 | else: 242 | if with_last: 243 | self.conv_out = Conv1D(input_dim + input_dim*2 + last_dim, output_dim, with_bn=True) 244 | else: 245 | self.conv_out = Conv1D(input_dim + input_dim*2, output_dim, with_bn=True) 246 | 247 | if self.neighbor_feature: 248 | self.GraphConv = GraphConv_H(256, self.output_scale, self.neighbor_feature) 249 | 250 | def forward(self, x, x_last=None, knn_idx=None, pos=None): 251 | """ 252 | x: (B, C, N) 253 | x_last: (B, C) 254 | """ 255 | BS, _, _ = x.shape 256 | 257 | ### Global information 258 | ori_x = x 259 | y = self.conv_in(x) 260 | x_global = torch.max(y, dim=2, keepdim=False)[0] 261 | if self.with_fc: 262 | x_global = self.fc(x_global) 263 | 264 | ### Neighbor information 265 | if self.neighbor_feature:
266 | x = self.GraphConv(x, pos, knn_idx) 267 | x = ori_x[:, :, :self.output_scale] + x 268 | else: 269 | x = ori_x[:, :, :self.output_scale] 270 | 271 | ### Feature fusion for shifting 272 | if self.with_last: 273 | x = torch.cat([x_global.view(BS, -1, 1).repeat(1, 1, self.output_scale), 274 | x_last.view(BS, -1, 1).repeat(1, 1, self.output_scale), x], dim=1) 275 | else: 276 | x = torch.cat([x_global.view(BS, -1, 1).repeat(1, 1, self.output_scale), x], dim=1) 277 | 278 | x = self.conv_out(x) 279 | x = x + ori_x[:, :, :self.output_scale] 280 | 281 | return x, x_global 282 | 283 | def batch_quat_to_rotmat(q, out=None): 284 | """q: (B, 4) quaternions with q[:, 0] as the scalar part, not necessarily unit length; returns (B, 3, 3) rotation matrices""" 285 | batchsize = q.size(0) 286 | 287 | if out is None: 288 | out = q.new_empty(batchsize, 3, 3) 289 | 290 | # 2 / squared quaternion 2-norm 291 | s = 2/torch.sum(q.pow(2), 1) 292 | 293 | # outer product q q^T, i.e. h[:, i, j] = q_i * q_j 294 | h = torch.bmm(q.unsqueeze(2), q.unsqueeze(1)) 295 | 296 | out[:, 0, 0] = 1 - (h[:, 2, 2] + h[:, 3, 3]).mul(s) 297 | out[:, 0, 1] = (h[:, 1, 2] - h[:, 3, 0]).mul(s) 298 | out[:, 0, 2] = (h[:, 1, 3] + h[:, 2, 0]).mul(s) 299 | 300 | out[:, 1, 0] = (h[:, 1, 2] + h[:, 3, 0]).mul(s) 301 | out[:, 1, 1] = 1 - (h[:, 1, 1] + h[:, 3, 3]).mul(s) 302 | out[:, 1, 2] = (h[:, 2, 3] - h[:, 1, 0]).mul(s) 303 | 304 | out[:, 2, 0] = (h[:, 1, 3] - h[:, 2, 0]).mul(s) 305 | out[:, 2, 1] = (h[:, 2, 3] + h[:, 1, 0]).mul(s) 306 | out[:, 2, 2] = 1 - (h[:, 1, 1] + h[:, 2, 2]).mul(s) 307 | 308 | return out 309 | 310 | 311 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | import os, sys 2 | import shutil 3 | import time 4 | import argparse 5 | import torch 6 | import numpy as np 7 | 8 | from net.CMG_Net import Network 9 | from utils.misc import get_logger, seed_all 10 | from dataset import PointCloudDataset, PatchDataset, SequentialPointcloudPatchSampler, load_data 11 | import scipy.spatial
as spatial 12 | 13 | 14 | def parse_arguments(): 15 | parser = argparse.ArgumentParser() 16 | parser.add_argument('--gpu', type=int, default=0) 17 | parser.add_argument('--dataset_root', type=str, default='../') 18 | parser.add_argument('--data_set', type=str, default='PCPNet') 19 | parser.add_argument('--log_root', type=str, default='./log') 20 | parser.add_argument('--ckpt_dirs', type=str, default='001', help="can be multiple directories, separated by ',' ") 21 | parser.add_argument('--ckpt_iter', type=str, default='900') 22 | parser.add_argument('--seed', type=int, default=2022) 23 | parser.add_argument('--num_workers', type=int, default=8) 24 | parser.add_argument('--batch_size', type=int, default=100) 25 | parser.add_argument('--tag', type=str, default='') 26 | parser.add_argument('--testset_list', type=str, default='testset_all.txt') 27 | parser.add_argument('--eval_list', type=str, 28 | default=['testset_no_noise.txt', 'testset_low_noise.txt', 'testset_med_noise.txt', 'testset_high_noise.txt', 29 | 'testset_vardensity_striped.txt', 'testset_vardensity_gradient.txt'], 30 | nargs='*', help='list of .txt files containing sets of point cloud names for evaluation') 31 | parser.add_argument('--patch_size', type=int, default=800) 32 | parser.add_argument('--knn_l1', type=int, default=16) 33 | parser.add_argument('--knn_l2', type=int, default=32) 34 | parser.add_argument('--knn_h1', type=int, default=32) 35 | parser.add_argument('--knn_h2', type=int, default=16) 36 | parser.add_argument('--knn_d', type=int, default=16) 37 | parser.add_argument('--sparse_patches', type=eval, default=True, choices=[True, False], 38 | help='test on a sparse set of patches, given by a .pidx file containing the patch center point indices.') 39 | parser.add_argument('--save_pn', type=eval, default=False, choices=[True, False]) 40 | parser.add_argument('--matric', type=str, default='CND', choices=['CND', 'RMSE']) 41 | args = parser.parse_args() 42 | return args 43 | 44 | 45 | def 
get_data_loaders(args): 46 | test_dset = PointCloudDataset( 47 | root=args.dataset_root, 48 | mode='test', 49 | data_set=args.data_set, 50 | data_list=args.testset_list, 51 | sparse_patches=args.sparse_patches, 52 | ) 53 | test_set = PatchDataset( 54 | datasets=test_dset, 55 | patch_size=args.patch_size, 56 | ) 57 | test_dataloader = torch.utils.data.DataLoader( 58 | test_set, 59 | sampler=SequentialPointcloudPatchSampler(test_set), 60 | batch_size=args.batch_size, 61 | num_workers=args.num_workers, 62 | ) 63 | return test_dset, test_dataloader 64 | 65 | 66 | ### Arguments 67 | args = parse_arguments() 68 | arg_str = '\n'.join([' {}: {}'.format(op, getattr(args, op)) for op in vars(args)]) 69 | print('Arguments:\n %s\n' % arg_str) 70 | 71 | seed_all(args.seed) 72 | PID = os.getpid() 73 | 74 | assert args.gpu >= 0, "ERROR GPU ID!" 75 | _device = torch.device('cuda:%d' % args.gpu) 76 | 77 | ### Datasets and loaders 78 | test_dset, test_dataloader = get_data_loaders(args) 79 | 80 | 81 | def normal_error(normal_gts, normal_preds, eval_file='log.txt', matric='CND'): 82 | """ 83 | Compute normal root-mean-square error (CND/RMSE) 84 | """ 85 | def l2_norm(v): 86 | norm_v = np.sqrt(np.sum(np.square(v), axis=1)) 87 | return norm_v 88 | 89 | log_file = open(eval_file, 'w') 90 | def log_string(out_str): 91 | log_file.write(out_str+'\n') 92 | log_file.flush() 93 | 94 | errors = [] 95 | errors_o = [] 96 | pgp30 = [] 97 | pgp25 = [] 98 | pgp20 = [] 99 | pgp15 = [] 100 | pgp10 = [] 101 | pgp5 = [] 102 | pgp_alpha = [] 103 | 104 | for i in range(len(normal_gts)): 105 | normal_gt = normal_gts[i] 106 | normal_pred = normal_preds[i] 107 | 108 | normal_gt_norm = l2_norm(normal_gt) 109 | normal_results_norm = l2_norm(normal_pred) 110 | normal_pred = np.divide(normal_pred, np.tile(np.expand_dims(normal_results_norm, axis=1), [1, 3])) 111 | normal_gt = np.divide(normal_gt, np.tile(np.expand_dims(normal_gt_norm, axis=1), [1, 3])) 112 | 113 | ### Unoriented cnd/rmse 114 | nn = 
np.sum(np.multiply(normal_gt, normal_pred), axis=1) 115 | nn[nn > 1] = 1 116 | nn[nn < -1] = -1 117 | 118 | ang = np.rad2deg(np.arccos(np.abs(nn))) 119 | 120 | ### Error metric 121 | errors.append(np.sqrt(np.mean(np.square(ang)))) 122 | ### Portion of good points 123 | pgp30_shape = sum([j < 30.0 for j in ang]) / float(len(ang)) 124 | pgp25_shape = sum([j < 25.0 for j in ang]) / float(len(ang)) 125 | pgp20_shape = sum([j < 20.0 for j in ang]) / float(len(ang)) 126 | pgp15_shape = sum([j < 15.0 for j in ang]) / float(len(ang)) 127 | pgp10_shape = sum([j < 10.0 for j in ang]) / float(len(ang)) 128 | pgp5_shape = sum([j < 5.0 for j in ang]) / float(len(ang)) 129 | pgp30.append(pgp30_shape) 130 | pgp25.append(pgp25_shape) 131 | pgp20.append(pgp20_shape) 132 | pgp15.append(pgp15_shape) 133 | pgp10.append(pgp10_shape) 134 | pgp5.append(pgp5_shape) 135 | 136 | pgp_alpha_shape = [] 137 | for alpha in range(30): 138 | pgp_alpha_shape.append(sum([j < alpha for j in ang]) / float(len(ang))) 139 | 140 | pgp_alpha.append(pgp_alpha_shape) 141 | 142 | # Oriented cnd/rmse 143 | errors_o.append(np.sqrt(np.mean(np.square(np.rad2deg(np.arccos(nn)))))) 144 | 145 | avg_errors = np.mean(errors) 146 | avg_errors_o = np.mean(errors_o) 147 | avg_pgp30 = np.mean(pgp30) 148 | avg_pgp25 = np.mean(pgp25) 149 | avg_pgp20 = np.mean(pgp20) 150 | avg_pgp15 = np.mean(pgp15) 151 | avg_pgp10 = np.mean(pgp10) 152 | avg_pgp5 = np.mean(pgp5) 153 | avg_pgp_alpha = np.mean(np.array(pgp_alpha), axis=0) 154 | 155 | log_string('%s per shape: ' % matric + str(errors)) 156 | log_string('%s not oriented (shape average): ' % matric + str(avg_errors)) 157 | log_string('%s oriented (shape average): ' % matric + str(avg_errors_o)) 158 | log_string('PGP30 per shape: ' + str(pgp30)) 159 | log_string('PGP25 per shape: ' + str(pgp25)) 160 | log_string('PGP20 per shape: ' + str(pgp20)) 161 | log_string('PGP15 per shape: ' + str(pgp15)) 162 | log_string('PGP10 per shape: ' + str(pgp10)) 163 | log_string('PGP5 per shape: 
' + str(pgp5)) 164 | log_string('PGP30 average: ' + str(avg_pgp30)) 165 | log_string('PGP25 average: ' + str(avg_pgp25)) 166 | log_string('PGP20 average: ' + str(avg_pgp20)) 167 | log_string('PGP15 average: ' + str(avg_pgp15)) 168 | log_string('PGP10 average: ' + str(avg_pgp10)) 169 | log_string('PGP5 average: ' + str(avg_pgp5)) 170 | log_string('PGP alpha average: ' + str(avg_pgp_alpha)) 171 | log_file.close() 172 | 173 | return avg_errors 174 | 175 | 176 | def test(ckpt_dir, ckpt_iter): 177 | ### Input/Output 178 | ckpt_path = os.path.join(args.log_root, ckpt_dir, 'ckpts/ckpt_%s.pt' % ckpt_iter) 179 | output_dir = os.path.join(args.log_root, ckpt_dir, 'results_%s/ckpt_%s' % (args.data_set, ckpt_iter)) 180 | if args.tag is not None and len(args.tag) != 0: 181 | output_dir += '_' + args.tag 182 | if not os.path.exists(ckpt_path): 183 | print('ERROR path: %s' % ckpt_path) 184 | return False, False 185 | 186 | file_save_dir = os.path.join(output_dir, 'pred_normal') 187 | os.makedirs(output_dir, exist_ok=True) 188 | os.makedirs(file_save_dir, exist_ok=True) 189 | 190 | logger = get_logger('test(%d)(%s-%s)' % (PID, ckpt_dir, ckpt_iter), output_dir) 191 | logger.info('Command: {}'.format(' '.join(sys.argv))) 192 | 193 | ### Model 194 | logger.info('Loading model: %s' % ckpt_path) 195 | ckpt = torch.load(ckpt_path, map_location=_device) 196 | model = Network(num_in=args.patch_size, 197 | knn_l1=args.knn_l1, 198 | knn_l2=args.knn_l2, 199 | knn_h1=args.knn_h1, 200 | knn_h2=args.knn_h2, 201 | knn_d=args.knn_d, 202 | ).to(_device) 203 | 204 | model_parameters = filter(lambda p: p.requires_grad, model.parameters()) 205 | num_params = sum([np.prod(p.size()) for p in model_parameters]) 206 | trainable_num = sum(p.numel() for p in model.parameters() if p.requires_grad) 207 | logger.info('Num_params: %d' % num_params) 208 | logger.info('Num_params_trainable: %d' % trainable_num) 209 | 210 | model.load_state_dict(ckpt['state_dict']) 211 | model.eval() 212 | 213 | shape_ind = 0 214 
| shape_patch_offset = 0 215 | shape_num = len(test_dset.shape_names) 216 | shape_patch_count = test_dset.shape_patch_count[shape_ind] 217 | 218 | num_batch = len(test_dataloader) 219 | normal_prop = torch.zeros([shape_patch_count, 3]) 220 | 221 | total_time = 0 222 | for batchind, data in enumerate(test_dataloader, 0): 223 | pcl_pat = data['pcl_pat'].to(_device) # (B, N, 3) 224 | data_trans = data['pca_trans'].to(_device) 225 | 226 | start_time = time.time() 227 | with torch.no_grad(): 228 | n_est, _, trans = model(pcl_pat) 229 | end_time = time.time() 230 | elapsed_time = 1000 * (end_time - start_time) # ms 231 | total_time += elapsed_time 232 | 233 | if batchind % 5 == 0: 234 | batchSize = pcl_pat.size()[0] 235 | logger.info('[%d/%d] %s: elapsed_time per point/patch: %.3f ms' % ( 236 | batchind, num_batch-1, test_dset.shape_names[shape_ind], elapsed_time / batchSize)) 237 | 238 | n_est[:, :] = torch.bmm(n_est.unsqueeze(1), trans.transpose(2, 1)).squeeze(dim=1) 239 | if data_trans is not None: 240 | ### transform predictions with inverse PCA rotation (back to world space) 241 | n_est[:, :] = torch.bmm(n_est.unsqueeze(1), data_trans.transpose(2, 1)).squeeze(dim=1) 242 | 243 | ### Save the estimated normals to file 244 | batch_offset = 0 245 | while batch_offset < n_est.shape[0] and shape_ind + 1 <= shape_num: 246 | shape_patches_remaining = shape_patch_count - shape_patch_offset 247 | batch_patches_remaining = n_est.shape[0] - batch_offset 248 | 249 | ### append estimated patch properties batch to properties for the current shape on the CPU 250 | normal_prop[shape_patch_offset:shape_patch_offset + min(shape_patches_remaining, batch_patches_remaining), :] = \ 251 | n_est[batch_offset:batch_offset + min(shape_patches_remaining, batch_patches_remaining), :] 252 | 253 | batch_offset = batch_offset + min(shape_patches_remaining, batch_patches_remaining) 254 | shape_patch_offset = shape_patch_offset + min(shape_patches_remaining, batch_patches_remaining) 255 | 256 | if 
shape_patches_remaining <= batch_patches_remaining: 257 | normals_to_write = normal_prop.cpu().numpy() 258 | 259 | ### for faster reading speed in the evaluation 260 | save_path = os.path.join(file_save_dir, test_dset.shape_names[shape_ind] + '_normal.npy') 261 | np.save(save_path, normals_to_write) 262 | if args.save_pn: 263 | save_path = os.path.join(file_save_dir, test_dset.shape_names[shape_ind] + '.normals') 264 | np.savetxt(save_path, normals_to_write) 265 | logger.info('saved normal: {} \n'.format(save_path)) 266 | 267 | sys.stdout.flush() 268 | shape_patch_offset = 0 269 | shape_ind += 1 270 | if shape_ind < shape_num: 271 | shape_patch_count = test_dset.shape_patch_count[shape_ind] 272 | normal_prop = torch.zeros([shape_patch_count, 3]) 273 | 274 | logger.info('Total Time: %.2f s, Shape Num: %d' % (total_time/1000, shape_num)) 275 | return output_dir, file_save_dir 276 | 277 | 278 | def eval(normal_gt_path, normal_pred_path, output_dir): 279 | print('\n Evaluation ...') 280 | eval_summary_dir = os.path.join(output_dir, 'test_summary') 281 | os.makedirs(eval_summary_dir, exist_ok=True) 282 | 283 | all_avg_errors = [] 284 | for cur_list in args.eval_list: 285 | print("\n***************** " + cur_list + " *****************") 286 | print("Result path: " + normal_pred_path) 287 | 288 | ### get all shape names in the list 289 | shape_names = [] 290 | normal_gt_filenames = os.path.join(normal_gt_path, 'list', cur_list) 291 | with open(normal_gt_filenames) as f: 292 | shape_names = f.readlines() 293 | shape_names = [x.strip() for x in shape_names] 294 | shape_names = list(filter(None, shape_names)) 295 | 296 | ### load all shapes 297 | normal_gts = [] 298 | normal_preds = [] 299 | for shape in shape_names: 300 | print(shape) 301 | shape_gt = shape.split('_noise_white_')[0] 302 | xyz_ori = load_data(filedir=normal_gt_path, filename=shape + '.xyz', dtype=np.float32) 303 | xyz_gt = load_data(filedir=normal_gt_path, filename=shape_gt + '.xyz', dtype=np.float32) 304 | 
normal_gt = load_data(filedir=normal_gt_path, filename=shape_gt + '.normals', dtype=np.float32) # (N, 3) 305 | normal_pred = np.load(os.path.join(normal_pred_path, shape + '_normal.npy')) # (n, 3) 306 | ### eval with sparse point sets 307 | points_idx = load_data(filedir=normal_gt_path, filename=shape + '.pidx', dtype=np.int32) # (n,) 308 | sys.setrecursionlimit(int(max(1000, round(xyz_gt.shape[0] / 10)))) 309 | kdtree = spatial.cKDTree(xyz_gt, 10) 310 | query_points = xyz_ori[points_idx, :] 311 | _, nor_idx = kdtree.query(query_points) 312 | if args.matric == 'CND': 313 | normal_gt = normal_gt[nor_idx, :] 314 | elif args.matric == 'RMSE': 315 | normal_gt = normal_gt[points_idx, :] 316 | if normal_pred.shape[0] > normal_gt.shape[0]: 317 | normal_pred = normal_pred[points_idx, :] 318 | 319 | normal_gts.append(normal_gt) 320 | normal_preds.append(normal_pred) 321 | 322 | ### compute the error metric (CND or RMSE) per list 323 | avg_errors = normal_error(normal_gts=normal_gts, 324 | normal_preds=normal_preds, 325 | eval_file=os.path.join(eval_summary_dir, cur_list[:-4] + '_evaluation_results.txt'), 326 | matric=args.matric) 327 | all_avg_errors.append(avg_errors) 328 | print('%s: %f' % (args.matric, avg_errors)) 329 | 330 | s = ('\n {} \n All %s not oriented (shape average): {} | Mean: {}\n' % args.matric).format( 331 | normal_pred_path, str(all_avg_errors), np.mean(all_avg_errors)) 332 | print(s) 333 | 334 | ### delete the output point normals 335 | if not args.save_pn: 336 | shutil.rmtree(normal_pred_path) 337 | return all_avg_errors 338 | 339 | 340 | 341 | if __name__ == '__main__': 342 | ckpt_dirs = args.ckpt_dirs.split(',') 343 | 344 | for ckpt_dir in ckpt_dirs: 345 | eval_dict = '' 346 | sum_file = 'eval_' + args.data_set + ('_'+args.tag if len(args.tag) != 0 else '') 347 | log_file_sum = open(os.path.join(args.log_root, ckpt_dir, sum_file+'.txt'), 'a') 348 | log_file_sum.write('\n====== %s ======\n' % args.eval_list) 349 | 350 | output_dir, file_save_dir = test(ckpt_dir=ckpt_dir,
ckpt_iter=args.ckpt_iter) 351 | if not output_dir or args.data_set == 'Semantic3D': 352 | continue 353 | all_avg_errors = eval(normal_gt_path=os.path.join(args.dataset_root, args.data_set), 354 | normal_pred_path=file_save_dir, 355 | output_dir=output_dir) 356 | 357 | s = '%s: %s | Mean: %f\n' % (args.ckpt_iter, str(all_avg_errors), np.mean(all_avg_errors)) 358 | log_file_sum.write(s) 359 | log_file_sum.flush() 360 | eval_dict += s 361 | 362 | log_file_sum.close() 363 | s = ('\n All %s not oriented (shape average): \n{}\n' % args.matric).format(eval_dict) 364 | print(s) 365 | 366 | 367 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | import os, sys, time, random 2 | import argparse 3 | import math 4 | import torch 5 | import torch.utils.data 6 | import torch.utils.tensorboard 7 | import torch.optim as optim 8 | import torch.optim.lr_scheduler as lr_scheduler 9 | from torch.nn.utils import clip_grad_norm_ 10 | import torch.nn.functional as F 11 | import numpy as np 12 | from utils.misc import * 13 | from net.CMG_Net import Network 14 | from dataset import PointCloudDataset, PatchDataset, RandomPointcloudPatchSampler 15 | 16 |
| parser.add_argument('--dataset_root', type=str, default='../') 32 | parser.add_argument('--data_set', type=str, default='PCPNet', choices=['PCPNet']) 33 | parser.add_argument('--trainset_list', type=str, default='trainingset_whitenoise.txt') 34 | parser.add_argument('--batch_size', type=int, default=64) 35 | parser.add_argument('--num_workers', type=int, default=8) 36 | parser.add_argument('--patch_size', type=int, default=700) 37 | parser.add_argument('--knn_l1', type=int, default=16) 38 | parser.add_argument('--knn_l2', type=int, default=32) 39 | parser.add_argument('--knn_h1', type=int, default=32) 40 | parser.add_argument('--knn_h2', type=int, default=16) 41 | parser.add_argument('--knn_d', type=int, default=16) 42 | parser.add_argument('--patches_per_shape', type=int, default=1000, help='The number of patches sampled from each shape in an epoch') 43 | args = parser.parse_args() 44 | return args 45 | 46 | def get_data_loaders(args): 47 | def worker_init_fn(worker_id): 48 | random.seed(args.seed) 49 | np.random.seed(args.seed) 50 | 51 | 52 | train_dset = PointCloudDataset( 53 | root=args.dataset_root, 54 | mode='train', 55 | data_set=args.data_set, 56 | data_list=args.trainset_list, 57 | ) 58 | train_set = PatchDataset( 59 | datasets=train_dset, 60 | patch_size=args.patch_size, 61 | ) 62 | train_datasampler = RandomPointcloudPatchSampler(train_set, patches_per_shape=args.patches_per_shape, seed=args.seed) 63 | train_dataloader = torch.utils.data.DataLoader( 64 | train_set, 65 | sampler=train_datasampler, 66 | batch_size=args.batch_size, 67 | num_workers=int(args.num_workers), 68 | pin_memory=True, 69 | worker_init_fn=worker_init_fn, 70 | ) 71 | 72 | return train_dataloader, train_datasampler 73 | 74 | 75 | ### Arguments 76 | args = parse_arguments() 77 | seed_all(args.seed) 78 | 79 | assert args.gpu >= 0, "ERROR GPU ID!" 
80 | _device = torch.device('cuda:%d' % args.gpu) 81 | PID = os.getpid() 82 | 83 | ### Datasets and loaders 84 | print('Loading datasets ...') 85 | train_dataloader, train_datasampler = get_data_loaders(args) 86 | train_num_batch = len(train_dataloader) 87 | 88 | ### Model 89 | print('Building model ...') 90 | model = Network(num_in=args.patch_size, 91 | knn_l1=args.knn_l1, 92 | knn_l2=args.knn_l2, 93 | knn_h1=args.knn_h1, 94 | knn_h2=args.knn_h2, 95 | knn_d=args.knn_d, 96 | ).to(_device) 97 | 98 | ### Optimizer and Scheduler 99 | optimizer = optim.Adam(model.parameters(), lr=args.lr) 100 | lambda1 = lambda epoch: (0.99*epoch/100 + 0.01) if epoch < 100 else 2e-3 if (0.5 * (1+math.cos(math.pi*(epoch-100)/(args.nepoch-200)))<2e-3 or epoch > 800) else 0.5 * (1+math.cos(math.pi*(epoch-100)/(args.nepoch-200))) 101 | scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) 102 | 103 | ### Logging 104 | if args.logging: 105 | log_path, log_dir_name = get_new_log_dir(args.log_root, prefix='', 106 | postfix='_' + args.tag if args.tag is not None else '') 107 | sub_log_dir = os.path.join(log_path, 'log') 108 | os.makedirs(sub_log_dir) 109 | logger = get_logger(name='train(%d)(%s)' % (PID, log_dir_name), log_dir=sub_log_dir) 110 | writer = torch.utils.tensorboard.SummaryWriter(sub_log_dir) 111 | log_hyperparams(writer, sub_log_dir, args) 112 | ckpt_mgr = CheckpointManager(os.path.join(log_path, 'ckpts')) 113 | else: 114 | logger = get_logger('train', None) 115 | writer = BlackHole() 116 | ckpt_mgr = BlackHole() 117 | 118 | refine_epoch = -1 119 | if args.resume != '': 120 | assert os.path.exists(args.resume), 'ERROR path: %s' % args.resume 121 | logger.info('Resume from: %s' % args.resume) 122 | 123 | ckpt = torch.load(args.resume) 124 | model.load_state_dict(ckpt['state_dict']) 125 | refine_epoch = ckpt['others']['epoch'] 126 | 127 | logger.info('Loaded pretrained model: %s' % args.resume) 128 | 129 | if args.logging: 130 | code_dir = os.path.join(log_path, 'code') 131 |
os.makedirs(code_dir, exist_ok=True) 132 | os.system('cp %s %s' % ('*.py', code_dir)) 133 | os.system('cp -r %s %s' % ('net', code_dir)) 134 | os.system('cp -r %s %s' % ('utils', code_dir)) 135 | 136 | 137 | ### Arguments 138 | logger.info('Command: {}'.format(' '.join(sys.argv))) 139 | arg_str = '\n'.join([' {}: {}'.format(op, getattr(args, op)) for op in vars(args)]) 140 | logger.info('Arguments:\n' + arg_str) 141 | logger.info(repr(model)) 142 | logger.info('training set: %d patches (in %d batches)' % 143 | (len(train_datasampler), len(train_dataloader))) 144 | 145 | 146 | def train(epoch): 147 | for train_batchind, batch in enumerate(train_dataloader, 0): 148 | pcl_pat = batch['pcl_pat'].to(_device) 149 | center_normal = batch['center_normal'].to(_device) # (B, 3) 150 | 151 | ### Reset grad and model state 152 | model.train() 153 | optimizer.zero_grad() 154 | 155 | ### Forward 156 | pred_nor, weights, trans = model(pcl_pat) 157 | loss, loss_tuple = model.get_loss(q_target=center_normal, q_pred=pred_nor, pred_weights=weights, pcl_in=pcl_pat, trans=trans) 158 | 159 | ### Backward and optimize 160 | loss.backward() 161 | orig_grad_norm = clip_grad_norm_(model.parameters(), args.max_grad_norm) 162 | optimizer.step() 163 | 164 | ### Logging 165 | s = '' 166 | for l in loss_tuple: 167 | s += '%.5f+' % l.item() 168 | logger.info('[Train] [%03d: %03d/%03d] | Loss: %.6f(%s) | Grad: %.6f' % ( 169 | epoch, train_batchind, train_num_batch-1, loss.item(), s[:-1], orig_grad_norm) 170 | ) 171 | 172 | if __name__ == '__main__': 173 | logger.info('Start training ...') 174 | try: 175 | for epoch in range(1, args.nepoch+1): 176 | logger.info('### Epoch %d ###' % epoch) 177 | if epoch <= refine_epoch: 178 | scheduler.step() 179 | continue 180 | 181 | start_time = time.time() 182 | train(epoch) 183 | end_time = time.time() 184 | logger.info('Time cost: %.1f s \n' % (end_time-start_time)) 185 | 186 | scheduler.step() 187 | 188 | if epoch % args.interval == 0 or epoch == 
args.nepoch: 189 | opt_states = { 190 | 'epoch': epoch, 191 | } 192 | 193 | if args.logging: 194 | ckpt_mgr.save(model, args, others=opt_states, step=epoch) 195 | 196 | except KeyboardInterrupt: 197 | logger.info('Terminating ...') 198 | -------------------------------------------------------------------------------- /utils/misc.py: -------------------------------------------------------------------------------- 1 | import os 2 | import torch 3 | import numpy as np 4 | import random 5 | import time 6 | import logging 7 | import logging.handlers 8 | from datetime import datetime 9 | 10 | 11 | class BlackHole(object): 12 | def __setattr__(self, name, value): 13 | pass 14 | def __call__(self, *args, **kwargs): 15 | return self 16 | def __getattr__(self, name): 17 | return self 18 | 19 | 20 | class CheckpointManager(object): 21 | 22 | def __init__(self, save_dir, logger=BlackHole()): 23 | super().__init__() 24 | os.makedirs(save_dir, exist_ok=True) 25 | self.save_dir = save_dir 26 | self.ckpts = [] 27 | self.logger = logger 28 | 29 | for f in os.listdir(self.save_dir): 30 | if f[:4] != 'ckpt': 31 | continue 32 | parts = os.path.splitext(f)[0].split('_') # 'ckpt_<it>' or 'ckpt_<it>_<score>' 33 | it, score = parts[1], (parts[2] if len(parts) > 2 else None) 34 | self.ckpts.append({ 35 | 'score': float(score) if score is not None else None, 36 | 'file': f, 37 | 'iteration': int(it), 38 | }) 39 | 40 | def get_worst_ckpt_idx(self): 41 | idx = -1 42 | worst = float('-inf') 43 | for i, ckpt in enumerate(self.ckpts): 44 | if ckpt['score'] is not None and ckpt['score'] >= worst: 45 | idx = i 46 | worst = ckpt['score'] 47 | return idx if idx >= 0 else None 48 | 49 | def get_best_ckpt_idx(self): 50 | idx = -1 51 | best = float('inf') 52 | for i, ckpt in enumerate(self.ckpts): 53 | if ckpt['score'] is not None and ckpt['score'] <= best: 54 | idx = i 55 | best = ckpt['score'] 56 | return idx if idx >= 0 else None 57 | 58 | def get_latest_ckpt_idx(self): 59 | idx = -1 60 | latest_it = -1 61 | for i, ckpt in enumerate(self.ckpts): 62 | if ckpt['iteration'] > latest_it: 63 | idx = i 64 | latest_it = ckpt['iteration'] 65 | return idx if idx >= 0 else None 66 |
67 | def save(self, model, args, score=None, others=None, step=None): 68 | assert step is not None and step > -1, 'Please define the value of step' 69 | if score is None: 70 | fname = 'ckpt_%d.pt' % int(step) 71 | else: 72 | fname = 'ckpt_%d_%.6f.pt' % (int(step), float(score)) 73 | path = os.path.join(self.save_dir, fname) 74 | 75 | torch.save({ 76 | 'args': args, 77 | 'state_dict': model.state_dict(), 78 | 'others': others 79 | }, path) 80 | 81 | self.ckpts.append({ 82 | 'score': score, 83 | 'file': fname, 'iteration': int(step), 84 | }) 85 | return True 86 | 87 | def load_best(self): 88 | idx = self.get_best_ckpt_idx() 89 | if idx is None: 90 | raise IOError('No checkpoints found.') 91 | ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]['file'])) 92 | return ckpt 93 | 94 | def load_latest(self): 95 | idx = self.get_latest_ckpt_idx() 96 | if idx is None: 97 | raise IOError('No checkpoints found.') 98 | ckpt = torch.load(os.path.join(self.save_dir, self.ckpts[idx]['file'])) 99 | return ckpt 100 | 101 | def load_selected(self, file): 102 | ckpt = torch.load(os.path.join(self.save_dir, file)) 103 | return ckpt 104 | 105 | 106 | def seed_all(seed): 107 | random.seed(seed) 108 | np.random.seed(seed) 109 | torch.manual_seed(seed) 110 | torch.cuda.manual_seed(seed) 111 | torch.cuda.manual_seed_all(seed) 112 | os.environ['PYTHONHASHSEED'] = str(seed) 113 | torch.backends.cudnn.enabled = True 114 | torch.backends.cudnn.benchmark = False # autotuning may select different kernels between runs, which defeats determinism 115 | torch.backends.cudnn.deterministic = True 116 | 117 | 118 | def git_commit(logger, log_dir=None, git_name=None): 119 | """ Logs source code configuration 120 | """ 121 | import git 122 | import subprocess 123 | 124 | try: 125 | repo = git.Repo(search_parent_directories=True) 126 | git_sha = repo.head.object.hexsha 127 | git_date = datetime.fromtimestamp(repo.head.object.committed_date).strftime('%Y-%m-%d') 128 | git_message = repo.head.object.message 129 | logger.info('Source is from Commit {} ({}): {}'.format(git_sha[:8], git_date, git_message.strip())) 130 | 131 | #
create diff file in the log directory 132 | if log_dir is not None: 133 | with open(os.path.join(log_dir, 'compareHead.diff'), 'w') as fid: 134 | subprocess.run(['git', 'diff'], stdout=fid) 135 | 136 | git_name = git_name if git_name is not None else datetime.now().strftime("%y%m%d_%H%M%S") 137 | os.system("git add --all") 138 | os.system("git commit --all -m '{}'".format(git_name)) 139 | except git.exc.InvalidGitRepositoryError: 140 | pass 141 | 142 | 143 | def get_logger(name, log_dir=None): 144 | logger = logging.getLogger(name) 145 | logger.setLevel(logging.DEBUG) 146 | formatter = logging.Formatter('[%(asctime)s::%(name)s::%(levelname)s] %(message)s') 147 | 148 | stream_handler = logging.StreamHandler() 149 | stream_handler.setLevel(logging.DEBUG) 150 | stream_handler.setFormatter(formatter) 151 | logger.addHandler(stream_handler) 152 | 153 | if log_dir is not None: 154 | file_handler = logging.FileHandler(os.path.join(log_dir, 'log.txt'), mode='w') 155 | file_handler.setLevel(logging.INFO) 156 | file_handler.setFormatter(formatter) 157 | logger.addHandler(file_handler) 158 | logger.info('Output and logs will be saved to: {}'.format(log_dir)) 159 | return logger 160 | 161 | 162 | def get_new_log_dir(root='./logs', prefix='', postfix=''): 163 | name = prefix + time.strftime("%y%m%d_%H%M%S", time.localtime()) + postfix 164 | log_dir = os.path.join(root, name) 165 | os.makedirs(log_dir) 166 | return log_dir, name 167 | 168 | 169 | def int_tuple(argstr): 170 | return tuple(map(int, argstr.split(','))) 171 | 172 | 173 | def str_tuple(argstr): 174 | return tuple(argstr.split(',')) 175 | 176 | 177 | def int_list(argstr): 178 | return list(map(int, argstr.split(','))) 179 | 180 | 181 | def str_list(argstr): 182 | return list(argstr.split(',')) 183 | 184 | 185 | def log_hyperparams(writer, log_dir, args): 186 | from torch.utils.tensorboard.summary import hparams 187 | vars_args = {k:v if isinstance(v, str) else repr(v) for k, v in vars(args).items()} 188 | exp, ssi, 
sei = hparams(vars_args, {"hp_metric": -1}) 189 | fw = writer._get_file_writer() 190 | fw.add_summary(exp) 191 | fw.add_summary(ssi) 192 | fw.add_summary(sei) 193 | with open(os.path.join(log_dir, 'hparams.csv'), 'w') as csvf: 194 | csvf.write('key,value\n') 195 | for k, v in vars_args.items(): 196 | csvf.write('%s,%s\n' % (k, v)) 197 | 198 | --------------------------------------------------------------------------------
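The comma-separated parsers in `utils/misc.py` (`int_tuple`, `str_tuple`, `int_list`, `str_list`) are shaped to be passed as `type=` callables to `argparse`, and `save` names checkpoints as `ckpt_<step>.pt` or `ckpt_<step>_<score>.pt`, which is how the provided `ckpt_900.pt` gets its name. The sketch below is illustrative only: it copies two of the helpers so it runs standalone, and the `--scales`/`--shapes` options and the `ckpt_name` helper are hypothetical, not the flags or functions actually defined in `train.py` or `test.py`.

```python
import argparse


def int_tuple(argstr):
    # Same logic as the helper in utils/misc.py: "1,2,3" -> (1, 2, 3)
    return tuple(map(int, argstr.split(',')))


def str_list(argstr):
    # Same logic as the helper in utils/misc.py: "a,b" -> ['a', 'b']
    return argstr.split(',')


def build_parser():
    # Hypothetical options for illustration; the real flags live in train.py / test.py
    parser = argparse.ArgumentParser()
    parser.add_argument('--scales', type=int_tuple, default=(1, 2, 3))
    parser.add_argument('--shapes', type=str_list, default=['all'])
    return parser


def ckpt_name(step, score=None):
    # Hypothetical helper mirroring the file-name scheme used by save()
    if score is None:
        return 'ckpt_%d.pt' % int(step)
    return 'ckpt_%d_%.6f.pt' % (int(step), float(score))


if __name__ == '__main__':
    args = build_parser().parse_args(['--scales', '64,128,256', '--shapes', 'cube,sphere'])
    print(args.scales)           # (64, 128, 256)
    print(args.shapes)           # ['cube', 'sphere']
    print(ckpt_name(900))        # ckpt_900.pt
    print(ckpt_name(900, 0.12))  # ckpt_900_0.120000.pt
```

Because the defaults are already tuples/lists rather than strings, `argparse` uses them as-is and only applies the parser to values given on the command line.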