├── Customer_Module.zip
├── LICENSE
├── README.md
├── data_utils
│   ├── DataLoader.py
│   └── Pointfilter_Utils.py
├── env.yml
├── eval_ours.py
├── models
│   ├── model.py
│   └── pointnet_util.py
├── provider.py
├── test.py
└── train.py

/Customer_Module.zip:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/ZeyongWei/PathNet/239c53cbf33627a1110efa9459170db8e0b6c8bc/Customer_Module.zip
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
MIT License

Copyright (c) 2023 wei zeyong

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# PathNet: Path-selective Point Cloud Denoising (TPAMI'24)
This is an implementation of PathNet.

## Abstract

Current point cloud denoising (PCD) models optimize single networks, trying to make their parameters adaptive to each point in a large pool of point clouds. Such a denoising network paradigm neglects that different points are often corrupted by different levels of noise and they may convey different geometric structures. Thus, the intricacy of both noise and geometry poses side effects including remnant noise, wrongly-smoothed edges, and distorted shape after denoising. We propose PathNet, a path-selective PCD paradigm based on reinforcement learning (RL). Unlike existing efforts, PathNet enables dynamic selection of the most appropriate denoising path for each point, best moving it onto its underlying surface. Besides the proposed framework of path-selective PCD for the first time, we have two more contributions. First, to leverage geometry expertise and benefit from training data, we propose a noise- and geometry-aware reward function to train the routing agent in RL. Second, the routing agent and the denoising network are trained jointly to avoid under- and over-smoothing. Extensive experiments show promising improvements of PathNet over its competitors, in terms of the effectiveness for removing different levels of noise and preserving multi-scale surface geometries. Furthermore, PathNet generalizes itself more smoothly to real scans than cutting-edge models.
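
As a rough illustration of the noise- and geometry-aware reward described above, the NumPy sketch below mirrors the visible part of `get_reward` in `models/model.py`. It is a simplified, non-authoritative version: the per-sample losses before and after denoising are passed in directly (the original computes them with `catculate_loss2`, which is not shown here), and stacking the per-block rewards into a `(B, block_num)` array is an assumption, because the original statement that combines them is truncated. The function name `path_rewards` is hypothetical; the constants `p = 0.002`, `L0 = 0.4` and `lam = 0.05` are taken from the source.

```python
import numpy as np

def path_rewards(loss_start, loss_end, label, paths, p=0.002, L0=0.4, lam=0.05):
    """Hypothetical NumPy sketch of PathNet-style per-block routing rewards.

    loss_start, loss_end: per-sample losses before/after denoising, shape (B,)
    label: per-sample geometry label, shape (B,)
    paths: list with one integer path-index array of shape (B,) per block
    Returns an array of shape (B, len(paths)) (stacking layout is assumed).
    """
    loss_start = np.asarray(loss_start, dtype=float)
    loss_end = np.asarray(loss_end, dtype=float)
    label = np.asarray(label, dtype=float)

    loss_dt = loss_end - loss_start          # negative when denoising helped
    d = np.minimum(loss_end / L0, 1.0)       # noise-aware weight, capped at 1
    l = np.exp(-label)                       # geometry-aware term

    # Intermediate blocks only pay a cost proportional to the chosen path index.
    rewards = [-p * np.asarray(path, dtype=float) for path in paths[:-1]]
    # The last block additionally receives the weighted loss improvement.
    reward_end = -p * np.asarray(paths[-1], dtype=float) - (d + lam * l) * loss_dt
    rewards.append(reward_end)
    return np.stack(rewards, axis=-1)

r = path_rewards([0.5, 0.2], [0.2, 0.1], [0, 0],
                 [np.array([0, 1]), np.array([1, 1]), np.array([1, 0])])
print(r)
```

In this toy call both samples end with a lower loss, so the last block earns a positive reward, while every non-zero path index at any block costs `p`.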

## Environment
* Python >=3.5
* PyTorch >=1.3.1
* CUDA >=10.1
* h5py
* numpy
* scipy
* tqdm
* TensorboardX

## Installation

You can install the dependencies from the conda environment file `env.yml`:

```bash
conda env create -f env.yml
conda activate pathnet
```

## Datasets and model
We provide pretrained models and datasets [here](https://drive.google.com/drive/folders/1qaxpcqBGVK59HBfTTS68AoaqSWLcp9si?usp=sharing).

Please extract `test_data.zip` and `train _data.hdf5` into the `data` folder. Note the space in `train _data.hdf5`: it matches the file name that `data_utils/DataLoader.py` opens.

## Denoise

## Train
Use the script `train.py` to train a model on our dataset (the trained model will be saved at `./log/path-denoise/model/checkpoints/best_model.pth`). Training has two stages: the first selects denoising paths at random (`--use_random_path 1`), and the second lets the analyser (routing agent) select them (`--use_random_path 0`).
```bash
cd PathNet
### First stage
python train.py --epoch 200 --use_random_path 1
### Second stage
python train.py --epoch 300 --use_random_path 0
```

## Test
The filtered results will be saved at `./data/results`.
```bash
cd PathNet
python test.py
```

## Citation
If you find the code useful for your research, please consider citing:
```
@article{wei2024pathnet,
  title={PathNet: Path-Selective Point Cloud Denoising},
  author={Wei, Zeyong and Chen, Honghua and Nan, Liangliang and Wang, Jun and Qin, Jing and Wei, Mingqiang},
  journal={IEEE Transactions on Pattern Analysis and Machine Intelligence},
  year={2024},
  publisher={IEEE}
}
```

## Acknowledgements
The repository is based on:
- [Path-Restore](https://github.com/yuke93/Path-Restore)
- [PU-GAN](https://liruihui.github.io/publication/PU-GAN/)
- [PU-Net](https://github.com/yulequan/PU-Net)
- [Mesh Denoising via Cascaded Normal Regression](https://wang-ps.github.io/denoising.html)

The point clouds are visualized with [Easy3D](https://github.com/LiangliangNan/Easy3D).

We thank the authors for their great work!

## License

This project is open-sourced under the MIT license.

--------------------------------------------------------------------------------
/data_utils/DataLoader.py:
--------------------------------------------------------------------------------
# *_*coding:utf-8 *_*
import os
import warnings
import numpy as np
import h5py
from torch.utils.data import Dataset
from data_utils.Pointfilter_Utils import pca_alignment

warnings.filterwarnings('ignore')

def pca_normalize(pc,pc2):
    # Normalize the noisy patches `pc` and the ground-truth patches `pc2` with the
    # same per-patch transform: center on the first point of the noisy patch,
    # scale the noisy patch into the unit sphere, then PCA-align it.
    normalize_data = np.zeros(pc.shape, dtype=np.float32)
    normalize_data2 = np.zeros(pc2.shape, dtype=np.float32)

    centroid = pc[:,0,:]
    centroid = np.expand_dims(centroid, axis=1)
    pc = pc - centroid
    pc2 = pc2 - centroid

    m = np.max(np.sqrt(np.sum(pc**2, axis=2)),axis = 1, keepdims=True)
    pc = pc / np.expand_dims(m, axis=-1)
    pc2 = pc2 / np.expand_dims(m, axis=-1)

    for B in range(pc.shape[0]):
        x, pca_martix_inv = pca_alignment(pc[B,:,:])
        # Apply the same rotation (inverse of the returned inverse) to the target patch.
        x2 = np.array(np.linalg.inv(pca_martix_inv) * np.matrix(pc2[B,:,:].T)).T
        normalize_data[B, ...] = x
        normalize_data2[B, ...] = x2

    return normalize_data, normalize_data2

class PatchDataset(Dataset):
    def __init__(self,root = './data/', npoints=128, split='train', class_choice=None, normal_channel=False):
        self.npoints = npoints
        self.root = root

        self.catfile = os.path.join(self.root, 'train _data.hdf5')

        f = h5py.File(self.catfile,'r')

        self.inputs = f["inputs"][:] #B, N ,3
        self.target = f["target"][:]
        self.label = f["label"][:]

        self.inputs,self.target= pca_normalize(self.inputs,self.target)

        self.label = self.label //2.3

        # Fixed-seed shuffle, then an 80/20 train/test split.
        idx = np.arange(0,self.inputs.shape[0])
        np.random.seed(1111)
        np.random.shuffle(idx)
        self.inputs = self.inputs[idx][:,:,:3]
        self.target = self.target[idx][:,:,:3]
        self.label = self.label[idx][:]

        sample_size = int(self.inputs.shape[0] * 0.8)
        if(split == 'train'):
            self.inputs = self.inputs[:sample_size]
            self.target = self.target[:sample_size]
            self.label = self.label[:sample_size]
        elif(split == 'test'):
            self.inputs = self.inputs[sample_size:]
            self.target = self.target[sample_size:]
            self.label = self.label[sample_size:]  # keep labels aligned with inputs/targets

        print('The size of %s inputs is %d'%(split,self.inputs.shape[0]))


        self.seg_classes = {'circle': [0,1]}

        self.cache = {}
        self.cache_size = 1000


    def __getitem__(self, index):
        if index in self.cache:
            inputs, target, label = self.cache[index]
        else:
            inputs = self.inputs[index].astype(np.float32) #N,3
            target = self.target[index].astype(np.float32)
            label = self.label[index].astype(np.float32)

            if len(self.cache) < self.cache_size:
                self.cache[index] = (inputs, target, label)

        return inputs, target, label

    def __len__(self):
        return self.inputs.shape[0]

--------------------------------------------------------------------------------
/data_utils/Pointfilter_Utils.py:
--------------------------------------------------------------------------------
import numpy as np
from sklearn.decomposition import PCA
import math
import torch
import argparse
##########################Parameters########################
#
#
#
#
###############################################################

def str2bool(v):
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        raise argparse.ArgumentTypeError('Boolean value expected.')
def parse_arguments():
    parser = argparse.ArgumentParser()
    # naming / file handling
    parser.add_argument('--name', type=str, default='pcdenoising', help='training run name')
    parser.add_argument('--network_model_dir', type=str, default='./Summary/Models/Train', help='output folder (trained models)')
    parser.add_argument('--trainset', type=str, default='./Dataset/Train', help='training set file name')
    parser.add_argument('--testset', type=str, default='./Dataset/Test', help='testing set file name')
    parser.add_argument('--save_dir', type=str, default='./Dataset/Results', help='')
    parser.add_argument('--summary_dir', type=str, default='./Summary/Models/Train/logs', help='')

    # training parameters
    parser.add_argument('--nepoch', type=int, default=50, help='number of epochs to train for')
    parser.add_argument('--batchSize', type=int, default=32, help='input batch size')
    parser.add_argument('--workers', type=int, default=4, help='number of data loading workers')
    parser.add_argument('--manualSeed', type=int, default=3627473, help='manual seed')
    parser.add_argument('--start_epoch', type=int, default=0, help='')
    parser.add_argument('--patch_per_shape', type=int, default=8000, help='')
    parser.add_argument('--patch_radius', type=float, default=0.05, help='')
    parser.add_argument('--lr', type=float, default=1e-2, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='gradient descent momentum')
    parser.add_argument('--model_interval', type=int, default=5, metavar='N', help='how many batches to wait before logging training status')

    # others parameters
    parser.add_argument('--resume', type=str, default='', help='refine model at this path')
    parser.add_argument('--support_multiple', type=float, default=4.0, help='the multiple of support radius')
    parser.add_argument('--support_angle', type=int, default=15, help='')
    parser.add_argument('--gt_normal_mode', type=str, default='nearest', help='')
    parser.add_argument('--repulsion_alpha', type=float, default=0.97, help='')

    # evaluation parameters
    parser.add_argument('--eval_dir', type=str, default='./Summary/pre_train_model', help='')
    parser.add_argument('--eval_iter_nums', type=int, default=10, help='')

    return parser.parse_args()

###################Pre-Processing Tools########################
#
#
#
#
###############################################################


def get_principle_dirs(pts):

    pts_pca = PCA(n_components=3)
    pts_pca.fit(pts)
    principle_dirs = pts_pca.components_
    principle_dirs /= np.linalg.norm(principle_dirs, 2, axis=0)

    return principle_dirs


def pca_alignment(pts, random_flag=False):

    pca_dirs = get_principle_dirs(pts)

    if random_flag:

        pca_dirs *= np.random.choice([-1, 1], 1)

    rotate_1 = compute_roatation_matrix(pca_dirs[2], [0, 0, 1], pca_dirs[1])
    pca_dirs = np.array(rotate_1 * pca_dirs.T).T
    rotate_2 = compute_roatation_matrix(pca_dirs[1], [1, 0, 0], pca_dirs[2])
    pts = np.array(rotate_2 * rotate_1 * np.matrix(pts.T)).T

    inv_rotation = np.array(np.linalg.inv(rotate_2 * rotate_1))

    return pts, inv_rotation

def compute_roatation_matrix(sour_vec, dest_vec, sour_vertical_vec=None):
    # http://immersivemath.com/forum/question/rotation-matrix-from-one-vector-to-another/
    if np.linalg.norm(np.cross(sour_vec, dest_vec), 2) == 0 or np.abs(np.dot(sour_vec, dest_vec)) >= 1.0:
        if np.dot(sour_vec, dest_vec) < 0:
            return rotation_matrix(sour_vertical_vec, np.pi)
        # Return np.matrix (like the other branches) so that `*` in pca_alignment
        # remains a matrix product rather than an element-wise product.
        return np.matrix(np.identity(3))
    alpha = np.arccos(np.dot(sour_vec, dest_vec))
    a = np.cross(sour_vec, dest_vec) / np.linalg.norm(np.cross(sour_vec, dest_vec), 2)
    c = np.cos(alpha)
    s = np.sin(alpha)
    R1 = [a[0] * a[0] * (1.0 - c) + c,
          a[0] * a[1] * (1.0 - c) - s * a[2],
          a[0] * a[2] * (1.0 - c) + s * a[1]]

    R2 = [a[0] * a[1] * (1.0 - c) + s * a[2],
          a[1] * a[1] * (1.0 - c) + c,
          a[1] * a[2] * (1.0 - c) - s * a[0]]

    R3 = [a[0] * a[2] * (1.0 - c) - s * a[1],
          a[1] * a[2] * (1.0 - c) + s * a[0],
          a[2] * a[2] * (1.0 - c) + c]

    R = np.matrix([R1, R2, R3])

    return R


def rotation_matrix(axis, theta):

    # Return the rotation matrix associated with counterclockwise rotation about the given axis by theta radians.

    axis = np.asarray(axis)
    axis = axis / math.sqrt(np.dot(axis, axis))
    a = math.cos(theta / 2.0)
    b, c, d = -axis * math.sin(theta / 2.0)
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    return np.matrix(np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
                               [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
                               [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]))


def patch_sampling(patch_pts, sample_num):

    if patch_pts.shape[0] > sample_num:

        sample_index = np.random.choice(range(patch_pts.shape[0]), sample_num, replace=False)

    else:

        sample_index = np.random.choice(range(patch_pts.shape[0]), sample_num)

    return sample_index

##########################Network Tools########################
#
#
#
#
###############################################################

def adjust_learning_rate(optimizer, epoch, opt):

    lr_shceduler(optimizer, epoch, opt.lr)

def lr_shceduler(optimizer, epoch, init_lr):
    # Step decay: scale the initial learning rate down after epochs 16, 24, 32 and 36.
    if epoch > 36:
        init_lr *= 0.5e-3
    elif epoch > 32:
        init_lr *= 1e-3
    elif epoch > 24:
        init_lr *= 1e-2
    elif epoch > 16:
        init_lr *= 1e-1
    for param_group in optimizer.param_groups:
        param_group['lr'] = init_lr

################################Ablation Study of Different Loss ###############################

def compute_original_1_loss(pts_pred, gt_patch_pts, gt_patch_normals, support_radius, alpha):

    pts_pred = pts_pred.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1)
    dist_square = (pts_pred - gt_patch_pts).pow(2).sum(2)

    # avoid division by zero
    weight = torch.exp(-1 * dist_square / (support_radius ** 2)) + 1e-12
    weight = weight / weight.sum(1, keepdim=True)

    # key loss
    project_dist = ((pts_pred - gt_patch_pts) * gt_patch_normals).sum(2)
    imls_dist = torch.abs((project_dist * weight).sum(1))

    # repulsion loss
    max_dist = torch.max(dist_square, 1)[0]

    # final loss
    dist = torch.mean((alpha * imls_dist) + (1 - alpha) * max_dist)

    return dist

def compute_original_2_loss(pred_point, gt_patch_pts, gt_patch_normals, support_radius, support_angle, alpha):

    # Compute Spatial Weighted Function
    pred_point = pred_point.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1)
    dist_square = ((pred_point - gt_patch_pts) ** 2).sum(2)
    weight_theta = torch.exp(-1 * dist_square / (support_radius ** 2))

    ############# Get The Nearest Normal For Predicted Point #############
    nearest_idx = torch.argmin(dist_square, dim=1)
    pred_point_normal = torch.cat([gt_patch_normals[i, index, :] for i, index in enumerate(nearest_idx)])
    pred_point_normal = pred_point_normal.view(-1, 3)
    pred_point_normal = pred_point_normal.unsqueeze(1)
    pred_point_normal = pred_point_normal.repeat(1, gt_patch_normals.size(1), 1)
    ############# Get The Nearest Normal For Predicted Point #############

    # Compute Normal Weighted Function
    normal_proj_dist = (pred_point_normal * gt_patch_normals).sum(2)
    weight_phi = torch.exp(-1 * ((1 - normal_proj_dist) / (1 - np.cos(support_angle)))**2)

    # avoid division by zero
    weight = weight_theta * weight_phi + 1e-12
    weight = weight / weight.sum(1, keepdim=True)

    # key loss
    project_dist = torch.sqrt(dist_square)
    imls_dist = (project_dist * weight).sum(1)

    # repulsion loss
    max_dist = torch.max(dist_square, 1)[0]

    # final loss
    dist = torch.mean((alpha * imls_dist) + (1 - alpha) * max_dist)

    return dist

def compute_original_3_loss(pts_pred, gt_patch_pts, alpha):
    # PointCleanNet Loss
    pts_pred = pts_pred.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1)
    m = (pts_pred - gt_patch_pts).pow(2).sum(2)
    min_dist = torch.min(m, 1)[0]
    max_dist = torch.max(m, 1)[0]
    dist = torch.mean((alpha * min_dist) + (1 - alpha) * max_dist)
    # print('min_dist: %f max_dist: %f' % (alpha * torch.mean(min_dist).item(), (1 - alpha) * torch.mean(max_dist).item()))
    return dist * 100


################################Ablation Study of Different Loss ###############################

def compute_bilateral_loss_with_repulsion(pred_point, gt_patch_pts, gt_patch_normals, support_radius, support_angle, alpha):

    # Our Loss
    pred_point = pred_point.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1)
    dist_square = ((pred_point - gt_patch_pts) ** 2).sum(2)
    weight_theta = torch.exp(-1 * dist_square / (support_radius ** 2))

    nearest_idx = torch.argmin(dist_square, dim=1)
    pred_point_normal = torch.cat([gt_patch_normals[i, index, :] for i, index in enumerate(nearest_idx)])
    pred_point_normal = pred_point_normal.view(-1, 3)
    pred_point_normal = pred_point_normal.unsqueeze(1)
    pred_point_normal = pred_point_normal.repeat(1, gt_patch_normals.size(1), 1)

    normal_proj_dist = (pred_point_normal * gt_patch_normals).sum(2)
    # Note: np.cos expects radians; support_angle is passed in unchanged.
    weight_phi = torch.exp(-1 * ((1 - normal_proj_dist) / (1 - np.cos(support_angle)))**2)

    # avoid division by zero
    weight = weight_theta * weight_phi + 1e-12
    weight = weight / weight.sum(1, keepdim=True)

    # key loss
    project_dist = torch.abs(((pred_point - gt_patch_pts) * gt_patch_normals).sum(2))
    imls_dist = (project_dist * weight).sum(1)

    # repulsion loss
    max_dist = torch.max(dist_square, 1)[0]

    # final loss
    dist = torch.mean((alpha * imls_dist) + (1 - alpha) * max_dist)

    return dist
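
`compute_roatation_matrix` builds an axis-angle rotation that takes one unit vector onto another; its expanded entries `R1`, `R2`, `R3` are Rodrigues' formula written out term by term. The NumPy sketch below shows the same construction in compact matrix form. `rotation_from_to` is a hypothetical helper, not part of the repository: it returns a plain `ndarray` instead of `np.matrix`, and it assumes the two vectors are neither parallel nor anti-parallel (the original handles those cases with separate branches).

```python
import numpy as np

def rotation_from_to(src, dst):
    """Hypothetical helper: 3x3 rotation R with R @ src_hat == dst_hat.

    Uses Rodrigues' formula R = I + sin(a) K + (1 - cos(a)) K^2, where K is the
    cross-product matrix of the unit rotation axis. Assumes src and dst are
    non-zero and neither parallel nor anti-parallel.
    """
    a = np.asarray(src, dtype=float)
    b = np.asarray(dst, dtype=float)
    a = a / np.linalg.norm(a)
    b = b / np.linalg.norm(b)
    axis = np.cross(a, b)
    s = np.linalg.norm(axis)        # sin of the rotation angle
    c = float(np.dot(a, b))         # cos of the rotation angle
    if s < 1e-12:
        raise ValueError("src and dst must not be parallel or anti-parallel")
    k = axis / s                    # unit rotation axis
    # Skew-symmetric matrix such that K @ v == np.cross(k, v)
    K = np.array([[0.0, -k[2], k[1]],
                  [k[2], 0.0, -k[0]],
                  [-k[1], k[0], 0.0]])
    return np.eye(3) + s * K + (1.0 - c) * (K @ K)

# Map the z-axis onto the x-axis: a 90-degree rotation about the y-axis.
R = rotation_from_to([0, 0, 1], [1, 0, 0])
print(R @ np.array([0.0, 0.0, 1.0]))
```

Returning an `ndarray` and multiplying with `@` avoids `np.matrix`, which NumPy no longer recommends; it also sidesteps mixing `np.matrix` and `ndarray` under `*`, where the operator means a matrix product for one type and an element-wise product for the other.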
--------------------------------------------------------------------------------
/env.yml:
--------------------------------------------------------------------------------
name: pathnet
channels:
  - pyg
  - iopath
  - fvcore
  - pytorch
  - nvidia
  - conda-forge
  - defaults
dependencies:
  - _libgcc_mutex=0.1=main
  - _openmp_mutex=4.5=1_gnu
  - absl-py=0.14.1=pyhd8ed1ab_0
  - aiohttp=3.7.4.post0=py38h497a2fe_0
  - async-timeout=3.0.1=py_1000
  - attrs=21.2.0=pyhd8ed1ab_0
  - blas=1.0=mkl
  - blinker=1.4=py_1
  - brotlipy=0.7.0=py38h497a2fe_1001
  - bzip2=1.0.8=h7b6447c_0
  - c-ares=1.17.1=h27cfd23_0
  - ca-certificates=2021.10.8=ha878542_0
  - cachetools=4.2.4=pyhd8ed1ab_0
  - certifi=2021.10.8=py38h578d9bd_0
  - cffi=1.14.6=py38ha65f79e_0
  - chardet=4.0.0=py38h578d9bd_1
  - charset-normalizer=2.0.0=pyhd8ed1ab_0
  - click=8.0.2=py38h578d9bd_0
  - colorama=0.4.4=pyh9f0ad1d_0
  - cryptography=3.4.7=py38ha5dfef3_0
  - cudatoolkit=11.1.74=h6bb024c_0
  - dataclasses=0.8=pyhc8e2a94_3
  - easydict=1.9=py_0
  - ffmpeg=4.3=hf484d3e_0
  - freetype=2.10.4=h5ab3b9f_0
  - fvcore=0.1.5.post20210915=py38
  - gmp=6.2.1=h2531618_2
  - gnutls=3.6.15=he1e5248_0
  - google-auth=1.35.0=pyh6c4a22f_0
  - google-auth-oauthlib=0.4.6=pyhd8ed1ab_0
  - grpcio=1.38.1=py38hdd6454d_0
  - idna=3.1=pyhd3deb0d_0
  - importlib-metadata=4.8.1=py38h578d9bd_0
  - intel-openmp=2021.3.0=h06a4308_3350
  - iopath=0.1.9=py38
  - joblib=1.1.0=pyhd8ed1ab_0
  - jpeg=9b=h024ee3a_2
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.35.1=h7274673_9
  - libblas=3.9.0=11_linux64_mkl
  - libcblas=3.9.0=11_linux64_mkl
  - libffi=3.3=he6710b0_2
  - libgcc-ng=9.3.0=h5101ec6_17
  - libgfortran-ng=11.2.0=h69a702a_10
  - libgfortran5=11.2.0=h5c6108e_10
  - libgomp=9.3.0=h5101ec6_17
  - libiconv=1.15=h63c8f33_5
  - libidn2=2.3.2=h7f8727e_0
  - liblapack=3.9.0=11_linux64_mkl
  - libpng=1.6.37=hbc83047_0
  - libprotobuf=3.15.8=h780b84a_0
  - libstdcxx-ng=9.3.0=hd4cf53a_17
  - libtasn1=4.16.0=h27cfd23_0
  - libtiff=4.2.0=h85742a9_0
  - libunistring=0.9.10=h27cfd23_0
  - libuv=1.40.0=h7b6447c_0
  - libwebp-base=1.2.0=h27cfd23_0
  - llvm-openmp=8.0.1=hc9558a2_0
  - lz4-c=1.9.3=h295c915_1
  - markdown=3.3.4=pyhd8ed1ab_0
  - mkl=2021.3.0=h06a4308_520
  - multidict=5.1.0=py38h27cfd23_2
  - ncurses=6.2=he6710b0_1
  - nettle=3.7.3=hbbd107a_1
  - ninja=1.10.2=hff7bd54_1
  - numpy=1.20.3=py38h9894fe3_1
  - oauthlib=3.1.1=pyhd8ed1ab_0
  - olefile=0.46=pyhd3eb1b0_0
  - openh264=2.1.0=hd408876_0
  - openjpeg=2.4.0=h3ad879b_0
  - openmp=8.0.1=0
  - openssl=1.1.1l=h7f8727e_0
  - pandas=1.2.5=py38h1abd341_0
  - pillow=8.3.1=py38h2c7a002_0
  - pip=21.2.4=py38h06a4308_0
  - point_cloud_utils=0.18.0=py38hc10631b_1
  - portalocker=2.3.2=py38h578d9bd_0
  - protobuf=3.15.8=py38h709712a_0
  - pyasn1=0.4.8=py_0
  - pyasn1-modules=0.2.7=py_0
  - pycparser=2.20=pyh9f0ad1d_2
  - pyjwt=2.2.0=pyhd8ed1ab_0
  - pyopenssl=21.0.0=pyhd8ed1ab_0
  - pysocks=1.7.1=py38h578d9bd_3
  - python=3.8.11=h12debd9_0_cpython
  - python-dateutil=2.8.2=pyhd8ed1ab_0
  - python_abi=3.8=2_cp38
  - pytorch=1.9.0=py3.8_cuda11.1_cudnn8.0.5_0
  - pytorch-cluster=1.5.9=py38_torch_1.9.0_cu111
  - pytorch-scatter=2.0.8=py38_torch_1.9.0_cu111
  - pytorch3d=0.5.0=py38_cu111_pyt190
  - pytz=2021.3=pyhd8ed1ab_0
  - pyu2f=0.1.5=pyhd8ed1ab_0
  - pyyaml=5.4.1=py38h497a2fe_0
  - readline=8.1=h27cfd23_0
  - requests=2.26.0=pyhd8ed1ab_0
  - requests-oauthlib=1.3.0=pyh9f0ad1d_0
  - rsa=4.7.2=pyh44b312d_0
  - scikit-learn=0.24.2=py38hdc147b9_0
  - scipy=1.6.3=py38h7b17777_0
  - setuptools=58.0.4=py38h06a4308_0
  - six=1.16.0=pyh6c4a22f_0
  - sqlite=3.36.0=hc218d9a_0
  - tabulate=0.8.9=pyhd8ed1ab_0
  - tensorboard=2.6.0=pyhd8ed1ab_1
  - tensorboard-data-server=0.6.0=py38h2b97feb_0
  - tensorboard-plugin-wit=1.8.0=pyh44b312d_0
  - termcolor=1.1.0=py_2
  - threadpoolctl=3.0.0=pyh8a188c0_0
  - tk=8.6.11=h1ccaba5_0
  - torchvision=0.10.0=py38_cu111
  - tqdm=4.62.3=pyhd8ed1ab_0
  - typing-extensions=3.10.0.2=hd3eb1b0_0
  - typing_extensions=3.10.0.2=pyh06a4308_0
  - urllib3=1.26.7=pyhd8ed1ab_0
  - werkzeug=2.0.1=pyhd8ed1ab_0
  - wheel=0.37.0=pyhd3eb1b0_1
  - xz=5.2.5=h7b6447c_0
  - yacs=0.1.6=py_0
  - yaml=0.2.5=h516909a_0
  - yarl=1.6.3=py38h497a2fe_2
  - zipp=3.6.0=pyhd8ed1ab_0
  - zlib=1.2.11=h7b6447c_3
  - zstd=1.4.9=haebb681_0

--------------------------------------------------------------------------------
/eval_ours.py:
--------------------------------------------------------------------------------
import scipy.spatial as sp
import numpy as np
import torch

import os

from Customer_Module.chamfer_distance.dist_chamfer import chamferDist
from plyfile import PlyData, PlyElement
nnd = chamferDist()

import logging

def log_string(msg):
    logger.info(msg)
    print(msg)

name_dirs = ['benchmark81_10000','benchmark81_20000','benchmark81_50000','kinect_fusion','kinect_v1','kinect_v2']

name_dir = 'benchmark81_20000'

model = ''

iter_num = 2

results_dir = '//' + model + '/' + name_dir + '/'

back_logs = ['_output_end','_output_end_output_end']

back_log = back_logs[iter_num-1]

if(not os.path.exists('./data/test_data/' + name_dir + '/eval/' + model +'/')):
    os.makedirs('./data/test_data/' + name_dir + '/eval/' + model +'/')

if(os.path.exists('./data/test_data/' + name_dir + '/eval/' + model +'/eval_log_' + name_dir +'_' + str(iter_num) + '.csv')):
    os.remove('./data/test_data/' + name_dir + '/eval/' + model +'/eval_log_' + name_dir +'_' + str(iter_num) + '.csv')

logger = logging.getLogger("Eval"+'_' + model + '_' + name_dir + '_' + str(iter_num))
logger.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
file_handler = logging.FileHandler('./data/test_data/' + name_dir + '/eval/' + model +'/eval_log_' + name_dir +'_' + str(iter_num) + '.csv')
file_handler.setLevel(logging.INFO)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)


def Eval_With_Charmfer_Distance():
    # The reported value is the Chamfer distance scaled by 1e5.
    log_string('************Errors under Chamfer Distance************')
    for shape_id, shape_name in enumerate(shape_names):
        if(name_dir.split('_')[0]=='benchmark81'):
            gt_pts = np.loadtxt(os.path.join('./data/test_data/gt/'+ name_dir +'/', shape_name.split('_0.')[0] + '.xyz'))
        else:
            gt_pts = np.loadtxt(os.path.join('./data/test_data/gt/'+ name_dir +'/', shape_name.split('_noisy')[0] + '.xyz'))
        pred_pts = np.loadtxt(os.path.join('./data/results/'+ results_dir +'/', shape_name + back_log + '.xyz'))[:,:3]
        with torch.no_grad():
            gt_pts_cuda = torch.from_numpy(np.expand_dims(gt_pts, axis=0)).cuda().float()
            pred_pts_cuda = torch.from_numpy(np.expand_dims(pred_pts, axis=0)).cuda().float()
            dist1, dist2 = nnd(pred_pts_cuda, gt_pts_cuda)
            chamfer_errors = torch.mean(dist1, dim=1) + torch.mean(dist2, dim=1)

        log_string('%12s %.3f' % (shape_names[shape_id], round(chamfer_errors.item() * 100000, 3)))

def Eval_With_Mean_Square_Error():
    # Despite the name, this reports the mean distance (scaled by 1e3) from each
    # predicted point to its 10 nearest ground-truth points, not a squared error.
    log_string('************Errors under Mean Square Error************')
    for shape_id, shape_name in enumerate(shape_names):
        if(name_dir.split('_')[0]=='benchmark81'):
            gt_pts = np.loadtxt(os.path.join('./data/test_data/gt/' + name_dir +'/', shape_name.split('_0.')[0] + '.xyz'))
        else:
            gt_pts = np.loadtxt(os.path.join('./data/test_data/gt/' + name_dir +'/', shape_name.split('_noisy')[0] + '.xyz'))
        gt_pts_tree = sp.cKDTree(gt_pts)
        pred_pts = np.loadtxt(os.path.join('./data/results/'+ results_dir +'/', shape_name + back_log +'.xyz'))[:,:3]
        pred_dist, _ = gt_pts_tree.query(pred_pts, 10)

        log_string('%12s %.3f' % (shape_names[shape_id], round(pred_dist.mean() * 1000, 3)))


if __name__ == '__main__':

    with open(os.path.join('./data/test_data/gt/'+ name_dir +'/', 'test.txt'), 'r') as f:
        shape_names = f.readlines()
    shape_names = [x.strip() for x in shape_names]
    shape_names = list(filter(None, shape_names))

    Eval_With_Charmfer_Distance()
    Eval_With_Mean_Square_Error()

--------------------------------------------------------------------------------
/models/model.py:
--------------------------------------------------------------------------------
import torch.nn as nn
import torch.utils.data
import torch
import torch.nn.functional as F
import numpy as np
from einops import rearrange, repeat

class Analyser_Block(nn.Module):
    def __init__(self, path_num = 3):
        super(Analyser_Block, self).__init__()
        self.path_num = path_num

        #encoder
        self.mlp_convs1 = nn.Conv1d(512,512,1)
        self.mlp_convs2 = nn.Conv1d(512,512,1)

        self.mlp_bns1 = nn.BatchNorm1d(512)
        self.mlp_bns2 = nn.BatchNorm1d(512)

        # decoder
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, self.path_num)

        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.4)
        self.drop2 = nn.Dropout(0.4)

    def forward(self, f1):
        #encoder
        f1 = F.relu(self.mlp_bns1(self.mlp_convs1(f1)))
        f1 = F.relu(self.mlp_bns2(self.mlp_convs2(f1)))

        #decoder
        f1_max = torch.max(f1, axis = 2)[0]

        path_f1 = self.drop1(F.relu(self.bn1(self.fc1(f1_max))))
        path_f1 = self.drop2(F.relu(self.bn2(self.fc2(path_f1))))
        path_f1 = self.fc3(path_f1) #B,3
        path_f1 = F.softmax(path_f1,dim=1)

        return path_f1

class get_analyser(nn.Module):
    def __init__(self, block_num = 3, path_num = 3):
        super(get_analyser, self).__init__()
        self.block_num = block_num
        self.path_num = path_num
        #Analyser_Block
        self.analyser_blocks = nn.ModuleList()
        for i in range(self.block_num):
            self.analyser_blocks.append(Analyser_Block(self.path_num))

    def forward(self, f1, ab_i = 0):

        path_f1 = self.analyser_blocks[ab_i](f1)

        return path_f1

class Path_Block(nn.Module):
    def __init__(self, path_num = 3):
        super(Path_Block, self).__init__()
        self.path_num = path_num

        #path-head
        self.mlp_convs_ph1 = nn.Conv1d(512,512,1)
        self.mlp_convs_ph2 = nn.Conv1d(512,512,1)

        self.mlp_bns_ph1 = nn.BatchNorm1d(512)
        self.mlp_bns_ph2 = nn.BatchNorm1d(512)

        #path0
        #pass

        #path1
        self.mlp_convs_p11 = nn.Conv1d(512,256,1)
        self.mlp_convs_p12 = nn.Conv1d(512,512,1)

        self.mlp_bns_p11 = nn.BatchNorm1d(256)
        self.mlp_bns_p12 = nn.BatchNorm1d(512)

        #path2
        self.mlp_convs_p21 = nn.Conv1d(512,256,1)
        self.mlp_convs_p22 = nn.Conv1d(512,512,1)

        self.mlp_bns_p21 = nn.BatchNorm1d(256)
        self.mlp_bns_p22 = nn.BatchNorm1d(512)

        self.mlp_convs_p23 = nn.Conv1d(512,256,1)
        self.mlp_convs_p24 = nn.Conv1d(512,512,1)

        self.mlp_bns_p23 = nn.BatchNorm1d(256)
        self.mlp_bns_p24 = nn.BatchNorm1d(512)

        #decoder
        self.fc1 = nn.Linear(512, 256)
        self.fc2 = nn.Linear(256, 128)
        self.fc3 = nn.Linear(128, 3)

        self.bn1 = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(128)
        self.drop1 = nn.Dropout(0.4)
        self.drop2 = nn.Dropout(0.4)

    def forward(self, f1, path_f1):
        #path_f1.shape #B

        #path_head
        f1_temp = F.relu(self.mlp_bns_ph1(self.mlp_convs_ph1(f1))) #B, d, N
        f1 = f1 + F.relu(self.mlp_bns_ph2(self.mlp_convs_ph2(f1_temp))) #B, d, N

        #path_chose
        # Only the two-path case is dispatched here: samples with path index 1 go
        # through path1, the others keep the path-head features (path0 = identity).
        # path2 is defined below but is not selected; with any other path_num,
        # f1_out is never assigned.
        if(self.path_num == 2):

            f1_out = f1
            idx = torch.where(path_f1==1)
            if(idx[0].shape[0]>0):
                f1_out[idx] = self.path1(f1_out[idx])

        #decoder
        t1 = self.path_denoise(f1_out)

        return f1_out, t1

    def path0(self,f1):

        #pass
        return f1

    def path1(self,f1):

        f1_temp = F.relu(self.mlp_bns_p11(self.mlp_convs_p11(f1)))
        f1_temp_max = torch.max(f1_temp, axis = 2)[0]
        f1_temp_max = f1_temp_max.unsqueeze(-1).repeat(1,1,f1_temp.shape[-1])
        f1_cat = torch.cat((f1_temp,f1_temp_max),1)
        f1 = f1 + F.relu(self.mlp_bns_p12(self.mlp_convs_p12(f1_cat)))

        return f1

    def path2(self,f1):
        f1_temp = F.relu(self.mlp_bns_p21(self.mlp_convs_p21(f1)))
        f1_temp_max = torch.max(f1_temp, axis = 2)[0]
        f1_temp_max = f1_temp_max.unsqueeze(-1).repeat(1,1,f1_temp.shape[-1])
        f1_temp = torch.cat((f1_temp,f1_temp_max),1)
        f2 = f1 + F.relu(self.mlp_bns_p22(self.mlp_convs_p22(f1_temp)))

        f1_temp = F.relu(self.mlp_bns_p23(self.mlp_convs_p23(f2)))
        f1_temp_max = torch.max(f1_temp, axis = 2)[0]
        f1_temp_max = f1_temp_max.unsqueeze(-1).repeat(1,1,f1_temp.shape[-1])
        f1_temp = torch.cat((f1_temp,f1_temp_max),1)
        f1 = f1 + F.relu(self.mlp_bns_p24(self.mlp_convs_p24(f1_temp)))

        return f1

    def path_denoise(self, f1):

        f1_max = torch.max(f1, axis = 2)[0]

        t1 = self.drop1(F.relu(self.bn1(self.fc1(f1_max))))
        t1 = self.drop2(F.relu(self.bn2(self.fc2(t1))))
        t1 = self.fc3(t1) #B,3

        return t1

class get_model(nn.Module):
    def __init__(self, block_num = 3, path_num = 3):
        super(get_model, self).__init__()
        channel = 3
        self.block_num = block_num
self.path_num = path_num 173 | 174 | #encoders 175 | self.mlp = np.array([64,128,256,512]) 176 | 177 | self.mlp_convs = nn.ModuleList() 178 | 179 | self.mlp_bns = nn.ModuleList() 180 | 181 | last_channel = channel 182 | for i in range(self.mlp.shape[0]): 183 | self.mlp_convs.append(nn.Conv1d(last_channel, self.mlp[i], 1)) 184 | self.mlp_bns.append(nn.BatchNorm1d(self.mlp[i])) 185 | 186 | last_channel = self.mlp[i] 187 | 188 | #path_blocks 189 | self.pbs = nn.ModuleList() 190 | for i in range(self.block_num): 191 | self.pbs.append(Path_Block(self.path_num)) 192 | 193 | 194 | def forward(self, x, analyser, use_random_path = 0): 195 | B, _, N = x.shape 196 | 197 | #encoder 198 | x = x - x[:,:,0:1] #B,3,N 199 | 200 | f0 = x 201 | for i in range(self.mlp.shape[0]): 202 | f0 = F.relu(self.mlp_bns[i](self.mlp_convs[i](f0))) 203 | 204 | feature_m = [] 205 | path_maxprob_m = [] 206 | path_m = [] 207 | trans_m = [] 208 | 209 | f1 = f0 210 | for pb_i in range(self.block_num): 211 | #path_analyser 212 | if(use_random_path == 1): 213 | path_prob_f1 = torch.rand(B,self.path_num, device=x.device) 214 | path_prob_f1 = F.softmax(path_prob_f1,-1) 215 | 216 | elif(pb_i < self.block_num and use_random_path == 0): 217 | path_prob_f1 = analyser(f1, pb_i) 218 | 219 | else: 220 | print("error: use_random_path must be 0 (analyser) or 1 (random)") 221 | return 0,0,0 222 | 223 | path_maxprob_f1, path_f1 = torch.max(path_prob_f1,axis = -1)#B 224 | 225 | #path_block 226 | f1, t1 = self.pbs[pb_i](f1,path_f1) 227 | 228 | #record 229 | feature_m.append(f1) 230 | path_maxprob_m.append(path_maxprob_f1) 231 | path_m.append(path_f1) 232 | trans_m.append(t1) 233 | 234 | return trans_m, path_m, path_maxprob_m 235 | 236 | class get_loss(nn.Module): 237 | def __init__(self): 238 | super(get_loss, self).__init__() 239 | 240 | def forward(self, source, target, trans_m, path_m, path_maxprob_m, all_stage = 1): 241 | loss_m = [] 242 | if(all_stage == 1): 243 | for i in range(len(trans_m)): 244 | trans_i = trans_m[i] 245 | points_denoise_i = source - trans_i 246 | loss_i =
self.catculate_loss(points_denoise_i,target) 247 | loss_m.append(loss_i) 248 | elif(all_stage == 0): 249 | trans_end = trans_m[-1] 250 | points_denoise_end = source - trans_end 251 | loss_end = self.catculate_loss(points_denoise_end,target) 252 | loss_m.append(loss_end) 253 | 254 | loss = loss_m[-1] 255 | for i in range(len(loss_m)-1): 256 | loss = loss + 0.1*loss_m[i] 257 | 258 | return loss 259 | 260 | def catculate_loss(self,p1,p2): #B,3 #B,N,3 261 | dist = torch.sum((p1.unsqueeze(1) - p2)**2,axis = -1) #B,N 262 | 263 | dist_min = torch.min(dist,axis = -1)[0] #B 264 | dist_max = torch.max(dist,axis = -1)[0] #B 265 | 266 | loss_1 = torch.mean(dist_min) 267 | loss_2 = torch.mean(dist_max) 268 | 269 | loss = 0.99 * loss_1 + 0.01* loss_2 270 | 271 | return loss 272 | 273 | 274 | class get_reward(nn.Module): 275 | def __init__(self): 276 | super(get_reward, self).__init__() 277 | self.reward_add = 0 278 | self.itt = 0 279 | 280 | def forward(self, source, target, label, trans_m, path_m, path_maxprob_m): 281 | 282 | trans_end = trans_m[-1] 283 | points_denoise = source - trans_end 284 | 285 | loss_start = self.catculate_loss2(source,target) # B 286 | loss_end = self.catculate_loss2(points_denoise,target) # B 287 | 288 | loss_dt = loss_end - loss_start # B 289 | 290 | p = torch.Tensor([0.002]).cuda() # penalty p 0.0002 291 | 292 | L0 = torch.Tensor([0.4]).cuda() # threshold max(L) 0.4 293 | d = loss_end / L0 # 294 | 295 | d = torch.where(d>1.0,torch.ones_like(d),d) 296 | l = torch.exp((-1.0)*(label)) 297 | lammda = 0.05 298 | 299 | path_end = path_m[-1].cuda() 300 | #reward_end = (-1) * p * path_end + (-1) * d * loss_dt #B noise awareness 301 | reward_end = (-1) * p * path_end + (-1) * (d + lammda * l)* loss_dt #B geometric awareness 302 | 303 | rewards = [] 304 | for i in range(len(path_m)-1): 305 | path_i = path_m[i].cuda() #B 306 | reward = (-1) * p * path_i #B 307 | rewards.append(reward.unsqueeze(-1)) 308 | rewards.append(reward_end.unsqueeze(-1)) 309 | rewards = 
torch.cat(rewards,axis = -1) 310 | 311 | loss = self.catculate_rewards_loss(rewards, path_maxprob_m) 312 | 313 | loss = torch.mean(loss) 314 | 315 | return loss 316 | 317 | def catculate_loss2(self,p1,p2): #B,3 #B,N,3 318 | dist = torch.sum((p1.unsqueeze(1) - p2)**2,axis = -1) #B,N 319 | 320 | dist_min = torch.min(dist,axis = -1)[0] #B 321 | dist_max = torch.max(dist,axis = -1)[0] #B 322 | 323 | loss_1 = dist_min 324 | loss_2 = dist_max 325 | 326 | loss = 0.99 * loss_1 + 0.01* loss_2 327 | 328 | return loss 329 | 330 | def catculate_rewards_loss(self,rewards, path_maxprob_m): 331 | R = torch.zeros(rewards.shape[0]).cuda() 332 | loss = 0 333 | gamma = 0.99 # discount factor 334 | for i in reversed(range(rewards.shape[1])): 335 | R = gamma * R + torch.log(path_maxprob_m[i]) * (rewards[:,i] + 0.02) 336 | loss = loss - R 337 | 338 | #R = gamma * R + rewards[:,i] 339 | #loss = loss - torch.log(path_maxprob_m[i])*(R + 0.02) 340 | loss = loss / rewards.shape[1] 341 | 342 | return loss 343 | 344 | 345 | -------------------------------------------------------------------------------- /models/pointnet_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from time import time 5 | import numpy as np 6 | 7 | def timeit(tag, t): 8 | print("{}: {}s".format(tag, time() - t)) 9 | return time() 10 | 11 | def pc_normalize(pc): 12 | l = pc.shape[0] 13 | centroid = np.mean(pc, axis=0) 14 | pc = pc - centroid 15 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 16 | pc = pc / m 17 | return pc 18 | 19 | def square_distance(src, dst): 20 | """ 21 | Calculate the squared Euclidean distance between every pair of points.
22 | 23 | src^T * dst = xn * xm + yn * ym + zn * zm; 24 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 25 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 26 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 27 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 28 | 29 | Input: 30 | src: source points, [B, N, C] 31 | dst: target points, [B, M, C] 32 | Output: 33 | dist: per-pair squared distance, [B, N, M] 34 | """ 35 | B, N, _ = src.shape 36 | _, M, _ = dst.shape 37 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 38 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 39 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 40 | return dist 41 | 42 | 43 | def index_points(points, idx): 44 | """ 45 | 46 | Input: 47 | points: input points data, [B, N, C] 48 | idx: sample index data, [B, S] or [B, S, K] 49 | Return: 50 | new_points: indexed points data, [B, S, C] or [B, S, K, C] 51 | """ 52 | device = points.device 53 | B = points.shape[0] 54 | view_shape = list(idx.shape) 55 | view_shape[1:] = [1] * (len(view_shape) - 1) 56 | repeat_shape = list(idx.shape) 57 | repeat_shape[0] = 1 58 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 59 | new_points = points[batch_indices, idx, :] 60 | return new_points 61 | 62 | 63 | def farthest_point_sample(xyz, npoint): 64 | """ 65 | Input: 66 | xyz: pointcloud data, [B, N, 3] 67 | npoint: number of samples 68 | Return: 69 | centroids: sampled pointcloud index, [B, npoint] 70 | """ 71 | device = xyz.device 72 | B, N, C = xyz.shape 73 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 74 | distance = torch.ones(B, N).to(device) * 1e10 75 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 76 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 77 | for i in range(npoint): 78 | centroids[:, i] = farthest 79 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 80 | dist = torch.sum((xyz - centroid) ** 2, -1) 81 | mask = dist < distance 82 | distance[mask] = dist[mask] 83 |
farthest = torch.max(distance, -1)[1] 84 | return centroids 85 | 86 | 87 | def query_ball_point(radius, nsample, xyz, new_xyz): 88 | """ 89 | Input: 90 | radius: local region radius 91 | nsample: max number of samples in each local region 92 | xyz: all points, [B, N, 3] 93 | new_xyz: query points, [B, S, 3] 94 | Return: 95 | group_idx: grouped points index, [B, S, nsample] 96 | """ 97 | device = xyz.device 98 | B, N, C = xyz.shape 99 | _, S, _ = new_xyz.shape 100 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 101 | sqrdists = square_distance(new_xyz, xyz) 102 | group_idx[sqrdists > radius ** 2] = N 103 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 104 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 105 | mask = group_idx == N 106 | group_idx[mask] = group_first[mask] 107 | return group_idx 108 | 109 | 110 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 111 | """ 112 | Input: 113 | npoint: number of centroids sampled by FPS 114 | radius: ball-query radius 115 | nsample: max number of samples in each local region 116 | xyz: input points position data, [B, N, 3] 117 | points: input points data, [B, N, D] 118 | Return: 119 | new_xyz: sampled points position data, [B, npoint, 3] 120 | new_points: sampled points data, [B, npoint, nsample, 3+D] 121 | """ 122 | B, N, C = xyz.shape 123 | S = npoint 124 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint] 125 | torch.cuda.empty_cache() 126 | new_xyz = index_points(xyz, fps_idx) 127 | torch.cuda.empty_cache() 128 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 129 | torch.cuda.empty_cache() 130 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 131 | torch.cuda.empty_cache() 132 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 133 | torch.cuda.empty_cache() 134 | 135 | if points is not None: 136 | grouped_points = index_points(points, idx) 137 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 138 | else: 139 | new_points
= grouped_xyz_norm 140 | if returnfps: 141 | return new_xyz, new_points, grouped_xyz, fps_idx 142 | else: 143 | return new_xyz, new_points 144 | 145 | 146 | def sample_and_group_all(xyz, points): 147 | """ 148 | Input: 149 | xyz: input points position data, [B, N, 3] 150 | points: input points data, [B, N, D] 151 | Return: 152 | new_xyz: sampled points position data, [B, 1, 3] 153 | new_points: sampled points data, [B, 1, N, 3+D] 154 | """ 155 | device = xyz.device 156 | B, N, C = xyz.shape 157 | new_xyz = torch.zeros(B, 1, C).to(device) 158 | grouped_xyz = xyz.view(B, 1, N, C) 159 | if points is not None: 160 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 161 | else: 162 | new_points = grouped_xyz 163 | return new_xyz, new_points 164 | 165 | 166 | class PointNetSetAbstraction(nn.Module): 167 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 168 | super(PointNetSetAbstraction, self).__init__() 169 | self.npoint = npoint 170 | self.radius = radius 171 | self.nsample = nsample 172 | self.mlp_convs = nn.ModuleList() 173 | self.mlp_bns = nn.ModuleList() 174 | last_channel = in_channel 175 | for out_channel in mlp: 176 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 177 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 178 | last_channel = out_channel 179 | self.group_all = group_all 180 | 181 | def forward(self, xyz, points): 182 | """ 183 | Input: 184 | xyz: input points position data, [B, C, N] 185 | points: input points data, [B, D, N] 186 | Return: 187 | new_xyz: sampled points position data, [B, C, S] 188 | new_points_concat: sample points feature data, [B, D', S] 189 | """ 190 | xyz = xyz.permute(0, 2, 1) 191 | if points is not None: 192 | points = points.permute(0, 2, 1) 193 | 194 | if self.group_all: 195 | new_xyz, new_points = sample_and_group_all(xyz, points) 196 | else: 197 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 198 | # new_xyz: 
sampled points position data, [B, npoint, C] 199 | # new_points: sampled points data, [B, npoint, nsample, C+D] 200 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 201 | for i, conv in enumerate(self.mlp_convs): 202 | bn = self.mlp_bns[i] 203 | new_points = F.relu(bn(conv(new_points))) 204 | 205 | new_points = torch.max(new_points, 2)[0] 206 | new_xyz = new_xyz.permute(0, 2, 1) 207 | return new_xyz, new_points 208 | 209 | 210 | class PointNetSetAbstractionMsg(nn.Module): 211 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 212 | super(PointNetSetAbstractionMsg, self).__init__() 213 | self.npoint = npoint 214 | self.radius_list = radius_list 215 | self.nsample_list = nsample_list 216 | self.conv_blocks = nn.ModuleList() 217 | self.bn_blocks = nn.ModuleList() 218 | for i in range(len(mlp_list)): 219 | convs = nn.ModuleList() 220 | bns = nn.ModuleList() 221 | last_channel = in_channel + 3 222 | for out_channel in mlp_list[i]: 223 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 224 | bns.append(nn.BatchNorm2d(out_channel)) 225 | last_channel = out_channel 226 | self.conv_blocks.append(convs) 227 | self.bn_blocks.append(bns) 228 | 229 | def forward(self, xyz, points, xyz2 ): 230 | """ 231 | Input: 232 | xyz: input points position data, [B, C, N] 233 | points: input points data, [B, D, N] 234 | Return: 235 | new_xyz: sampled points position data, [B, C, S] 236 | new_points_concat: sample points feature data, [B, D', S] 237 | """ 238 | xyz = xyz.permute(0, 2, 1) 239 | if points is not None: 240 | points = points.permute(0, 2, 1) 241 | if xyz2 is None: 242 | xyz2 = xyz 243 | else: 244 | xyz2 = xyz2.permute(0, 2, 1) 245 | B, N, C = xyz.shape 246 | S = self.npoint 247 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 248 | new_points_list = [] 249 | for i, radius in enumerate(self.radius_list): 250 | K = self.nsample_list[i] 251 | group_idx = query_ball_point(radius, K, xyz2, new_xyz) 252 | grouped_xyz 
= index_points(xyz2, group_idx) 253 | grouped_xyz -= new_xyz.view(B, S, 1, C) 254 | if points is not None: 255 | grouped_points = index_points(points, group_idx) 256 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 257 | else: 258 | grouped_points = grouped_xyz 259 | 260 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 261 | for j in range(len(self.conv_blocks[i])): 262 | conv = self.conv_blocks[i][j] 263 | bn = self.bn_blocks[i][j] 264 | grouped_points = F.relu(bn(conv(grouped_points))) 265 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 266 | new_points_list.append(new_points) 267 | 268 | new_xyz = new_xyz.permute(0, 2, 1) 269 | new_points_concat = torch.cat(new_points_list, dim=1) 270 | return new_xyz, new_points_concat 271 | 272 | 273 | class PointNetFeaturePropagation(nn.Module): 274 | def __init__(self, in_channel, mlp): 275 | super(PointNetFeaturePropagation, self).__init__() 276 | self.mlp_convs = nn.ModuleList() 277 | self.mlp_bns = nn.ModuleList() 278 | last_channel = in_channel 279 | for out_channel in mlp: 280 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 281 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 282 | last_channel = out_channel 283 | 284 | def forward(self, xyz1, xyz2, points1, points2): 285 | """ 286 | Input: 287 | xyz1: input points position data, [B, C, N] 288 | xyz2: sampled input points position data, [B, C, S] 289 | points1: input points data, [B, D, N] 290 | points2: input points data, [B, D, S] 291 | Return: 292 | new_points: upsampled points data, [B, D', N] 293 | """ 294 | xyz1 = xyz1.permute(0, 2, 1) 295 | xyz2 = xyz2.permute(0, 2, 1) 296 | 297 | points2 = points2.permute(0, 2, 1) 298 | B, N, C = xyz1.shape 299 | _, S, _ = xyz2.shape 300 | 301 | if S == 1: 302 | interpolated_points = points2.repeat(1, N, 1) 303 | else: 304 | dists = square_distance(xyz1, xyz2) 305 | dists, idx = dists.sort(dim=-1) 306 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 
307 | 308 | dist_recip = 1.0 / (dists + 1e-8) 309 | norm = torch.sum(dist_recip, dim=2, keepdim=True) 310 | weight = dist_recip / norm 311 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 312 | 313 | if points1 is not None: 314 | points1 = points1.permute(0, 2, 1) 315 | new_points = torch.cat([points1, interpolated_points], dim=-1) 316 | else: 317 | new_points = interpolated_points 318 | 319 | new_points = new_points.permute(0, 2, 1) 320 | for i, conv in enumerate(self.mlp_convs): 321 | bn = self.mlp_bns[i] 322 | new_points = F.relu(bn(conv(new_points))) 323 | return new_points 324 | 325 | -------------------------------------------------------------------------------- /provider.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | def normalize_data(batch_data): 4 | """ Normalize each point cloud in the batch: center it at the origin and scale it into the unit sphere. 5 | Input: 6 | BxNxC array 7 | Output: 8 | BxNxC array 9 | """ 10 | B, N, C = batch_data.shape 11 | normal_data = np.zeros((B, N, C)) 12 | for b in range(B): 13 | pc = batch_data[b] 14 | centroid = np.mean(pc, axis=0) 15 | pc = pc - centroid 16 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 17 | pc = pc / m 18 | normal_data[b] = pc 19 | return normal_data 20 | 21 | 22 | def shuffle_data(data, labels): 23 | """ Shuffle data and labels. 24 | Input: 25 | data: B,N,... numpy array 26 | labels: B,... numpy array 27 | Return: 28 | shuffled data, labels and shuffle indices 29 | """ 30 | idx = np.arange(len(labels)) 31 | np.random.shuffle(idx) 32 | return data[idx, ...], labels[idx], idx 33 | 34 | def shuffle_points(batch_data): 35 | """ Shuffle the order of points in each point cloud -- changes FPS behavior. 36 | Use the same shuffling idx for the entire batch.
37 | Input: 38 | BxNxC array 39 | Output: 40 | BxNxC array 41 | """ 42 | idx = np.arange(batch_data.shape[1]) 43 | np.random.shuffle(idx) 44 | return batch_data[:,idx,:] 45 | 46 | def rotate_point_cloud_y_z(batch_data,batch_data2,batch_data3,threshold = 1,threshold2 = 1): 47 | """ Randomly rotate the point clouds to augment the dataset 48 | rotation is per shape: a y-axis rotation followed by a z-axis rotation 49 | Input: 50 | BxNx3 arrays batch_data and batch_data2 (and batch_data3 unless it is empty) 51 | Return: 52 | BxNx3 rotated arrays, then Bx3x3 rotation matrices and their inverses 53 | """ 54 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 55 | rotated_data2 = np.zeros(batch_data2.shape, dtype=np.float32) 56 | if(len(batch_data3) > 0): 57 | rotated_data3 = np.zeros(batch_data3.shape, dtype=np.float32) 58 | rotation_matrixs = np.zeros([batch_data.shape[0],3,3], dtype=np.float32) 59 | rotation_matrixs_inv = np.zeros([batch_data.shape[0],3,3], dtype=np.float32) 60 | for k in range(batch_data.shape[0]): 61 | rotation_angle = np.random.uniform() * 2 * np.pi * threshold 62 | cosval = np.cos(rotation_angle) 63 | sinval = np.sin(rotation_angle) 64 | rotation_matrix1 = np.array([[cosval, 0, sinval], 65 | [0, 1, 0], 66 | [-sinval, 0, cosval]]) 67 | 68 | rotation_angle2 = np.random.uniform() * 2 * np.pi * threshold2 69 | cosval2 = np.cos(rotation_angle2) 70 | sinval2 = np.sin(rotation_angle2) 71 | rotation_matrix2 = np.array([[cosval2, sinval2, 0], 72 | [-sinval2, cosval2, 0], 73 | [0, 0, 1]]) 74 | rotation_matrix = np.dot(rotation_matrix1, rotation_matrix2) 75 | rotation_matrix_inv = np.linalg.inv(rotation_matrix) 76 | 77 | shape_pc = batch_data[k, ...] 78 | shape_pc2 = batch_data2[k, ...] 79 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 80 | rotated_data2[k, ...] = np.dot(shape_pc2.reshape((-1, 3)), rotation_matrix) 81 | if(len(batch_data3) > 0): 82 | shape_pc3 = batch_data3[k, ...] 83 | rotated_data3[k, ...] = np.dot(shape_pc3.reshape((-1, 3)), rotation_matrix) 84 | rotation_matrixs[k, ...]
= rotation_matrix 85 | rotation_matrixs_inv[k, ...] = rotation_matrix_inv 86 | if(len(batch_data3) > 0): 87 | return rotated_data, rotated_data2, rotated_data3, rotation_matrixs, rotation_matrixs_inv 88 | return rotated_data, rotated_data2, rotation_matrixs, rotation_matrixs_inv 89 | 90 | def rotate_point_cloud(batch_data,batch_data2,batch_data3 = [],threshold = 1): 91 | """ Randomly rotate the point clouds to augment the dataset 92 | rotation is per shape, about the up (y) axis 93 | Input: 94 | BxNx3 array, original batch of point clouds 95 | Return: 96 | BxNx3 array, rotated batch of point clouds 97 | """ 98 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 99 | rotated_data2 = np.zeros(batch_data2.shape, dtype=np.float32) 100 | if(len(batch_data3) > 0): 101 | rotated_data3 = np.zeros(batch_data3.shape, dtype=np.float32) 102 | for k in range(batch_data.shape[0]): 103 | rotation_angle = np.random.uniform() * 2 * np.pi * threshold 104 | cosval = np.cos(rotation_angle) 105 | sinval = np.sin(rotation_angle) 106 | rotation_matrix = np.array([[cosval, 0, sinval], 107 | [0, 1, 0], 108 | [-sinval, 0, cosval]]) 109 | shape_pc = batch_data[k, ...] 110 | shape_pc2 = batch_data2[k, ...] 111 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 112 | rotated_data2[k, ...] = np.dot(shape_pc2.reshape((-1, 3)), rotation_matrix) 113 | if(len(batch_data3) > 0): 114 | shape_pc3 = batch_data3[k, ...] 115 | rotated_data3[k, ...]
= np.dot(shape_pc3.reshape((-1, 3)), rotation_matrix) 116 | if(len(batch_data3) > 0): 117 | return rotated_data, rotated_data2, rotated_data3 118 | return rotated_data, rotated_data2 119 | 120 | def rotate_point_cloud_z(batch_data,batch_data2,batch_data3 = [],threshold = 1): 121 | """ Randomly rotate the point clouds to augment the dataset 122 | rotation is per shape, about the z axis 123 | Input: 124 | BxNx3 array, original batch of point clouds 125 | Return: 126 | BxNx3 array, rotated batch of point clouds 127 | """ 128 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 129 | rotated_data2 = np.zeros(batch_data2.shape, dtype=np.float32) 130 | if(len(batch_data3) > 0): 131 | rotated_data3 = np.zeros(batch_data3.shape, dtype=np.float32) 132 | for k in range(batch_data.shape[0]): 133 | rotation_angle = np.random.uniform() * 2 * np.pi * threshold 134 | cosval = np.cos(rotation_angle) 135 | sinval = np.sin(rotation_angle) 136 | rotation_matrix = np.array([[cosval, sinval, 0], 137 | [-sinval, cosval, 0], 138 | [0, 0, 1]]) 139 | shape_pc = batch_data[k, ...] 140 | shape_pc2 = batch_data2[k, ...] 141 | 142 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 143 | rotated_data2[k, ...] = np.dot(shape_pc2.reshape((-1, 3)), rotation_matrix) 144 | if(len(batch_data3) > 0): 145 | shape_pc3 = batch_data3[k, ...] 146 | rotated_data3[k, ...] = np.dot(shape_pc3.reshape((-1, 3)), rotation_matrix) 147 | if(len(batch_data3) > 0): 148 | return rotated_data, rotated_data2, rotated_data3 149 | return rotated_data, rotated_data2 150 | 151 | def rotate_point_cloud_with_normal(batch_xyz_normal): 152 | ''' Randomly rotate XYZ + normal point clouds (in place) about the up (y) axis.
153 | Input: 154 | batch_xyz_normal: B,N,6, first three channels are XYZ, last three are normals 155 | Output: 156 | B,N,6, rotated XYZ, normal point cloud 157 | ''' 158 | for k in range(batch_xyz_normal.shape[0]): 159 | rotation_angle = np.random.uniform() * 2 * np.pi 160 | cosval = np.cos(rotation_angle) 161 | sinval = np.sin(rotation_angle) 162 | rotation_matrix = np.array([[cosval, 0, sinval], 163 | [0, 1, 0], 164 | [-sinval, 0, cosval]]) 165 | shape_pc = batch_xyz_normal[k,:,0:3] 166 | shape_normal = batch_xyz_normal[k,:,3:6] 167 | batch_xyz_normal[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 168 | batch_xyz_normal[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), rotation_matrix) 169 | return batch_xyz_normal 170 | 171 | def rotate_perturbation_point_cloud_with_normal(batch_data, angle_sigma=0.06, angle_clip=0.18): 172 | """ Randomly perturb the point clouds by small rotations 173 | Input: 174 | BxNx6 array, original batch of point clouds and point normals 175 | Return: 176 | BxNx6 array, rotated batch of point clouds and point normals 177 | """ 178 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 179 | for k in range(batch_data.shape[0]): 180 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 181 | Rx = np.array([[1,0,0], 182 | [0,np.cos(angles[0]),-np.sin(angles[0])], 183 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 184 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 185 | [0,1,0], 186 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 187 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 188 | [np.sin(angles[2]),np.cos(angles[2]),0], 189 | [0,0,1]]) 190 | R = np.dot(Rz, np.dot(Ry,Rx)) 191 | shape_pc = batch_data[k,:,0:3] 192 | shape_normal = batch_data[k,:,3:6] 193 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), R) 194 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1, 3)), R) 195 | return rotated_data 196 | 197 | 198 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): 199 | """
Rotate the point clouds about the up (y) axis by a given angle. 200 | Input: 201 | BxNx3 array, original batch of point clouds 202 | Return: 203 | BxNx3 array, rotated batch of point clouds 204 | """ 205 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 206 | for k in range(batch_data.shape[0]): 207 | #rotation_angle = np.random.uniform() * 2 * np.pi 208 | cosval = np.cos(rotation_angle) 209 | sinval = np.sin(rotation_angle) 210 | rotation_matrix = np.array([[cosval, 0, sinval], 211 | [0, 1, 0], 212 | [-sinval, 0, cosval]]) 213 | shape_pc = batch_data[k,:,0:3] 214 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 215 | return rotated_data 216 | 217 | def rotate_point_cloud_by_angle_with_normal(batch_data, rotation_angle): 218 | """ Rotate the point clouds about the up (y) axis by a given angle. 219 | Input: 220 | BxNx6 array, original batch of point clouds with normals 221 | scalar, angle of rotation 222 | Return: 223 | BxNx6 array, rotated batch of point clouds with normals 224 | """ 225 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 226 | for k in range(batch_data.shape[0]): 227 | #rotation_angle = np.random.uniform() * 2 * np.pi 228 | cosval = np.cos(rotation_angle) 229 | sinval = np.sin(rotation_angle) 230 | rotation_matrix = np.array([[cosval, 0, sinval], 231 | [0, 1, 0], 232 | [-sinval, 0, cosval]]) 233 | shape_pc = batch_data[k,:,0:3] 234 | shape_normal = batch_data[k,:,3:6] 235 | rotated_data[k,:,0:3] = np.dot(shape_pc.reshape((-1, 3)), rotation_matrix) 236 | rotated_data[k,:,3:6] = np.dot(shape_normal.reshape((-1,3)), rotation_matrix) 237 | return rotated_data 238 | 239 | 240 | 241 | def rotate_perturbation_point_cloud(batch_data, angle_sigma=0.06, angle_clip=0.18): 242 | """ Randomly perturb the point clouds by small rotations 243 | Input: 244 | BxNx3 array, original batch of point clouds 245 | Return: 246 | BxNx3 array, rotated batch of point clouds 247 | """ 248 | rotated_data = np.zeros(batch_data.shape,
dtype=np.float32) 249 | for k in range(batch_data.shape[0]): 250 | angles = np.clip(angle_sigma*np.random.randn(3), -angle_clip, angle_clip) 251 | Rx = np.array([[1,0,0], 252 | [0,np.cos(angles[0]),-np.sin(angles[0])], 253 | [0,np.sin(angles[0]),np.cos(angles[0])]]) 254 | Ry = np.array([[np.cos(angles[1]),0,np.sin(angles[1])], 255 | [0,1,0], 256 | [-np.sin(angles[1]),0,np.cos(angles[1])]]) 257 | Rz = np.array([[np.cos(angles[2]),-np.sin(angles[2]),0], 258 | [np.sin(angles[2]),np.cos(angles[2]),0], 259 | [0,0,1]]) 260 | R = np.dot(Rz, np.dot(Ry,Rx)) 261 | shape_pc = batch_data[k, ...] 262 | rotated_data[k, ...] = np.dot(shape_pc.reshape((-1, 3)), R) 263 | return rotated_data 264 | 265 | 266 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 267 | """ Randomly jitter points. jittering is per point. 268 | Input: 269 | BxNx3 array, original batch of point clouds 270 | Return: 271 | BxNx3 array, jittered batch of point clouds 272 | """ 273 | B, N, C = batch_data.shape 274 | assert(clip > 0) 275 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) 276 | jittered_data += batch_data 277 | return jittered_data 278 | 279 | def shift_point_cloud(batch_data, batch_data2, shift_range=0.1): 280 | """ Randomly shift point cloud. Shift is per point cloud. 281 | Input: 282 | BxNx3 array, original batch of point clouds 283 | Return: 284 | BxNx3 array, shifted batch of point clouds 285 | """ 286 | B, N, C = batch_data.shape 287 | shifts = np.random.uniform(-shift_range, shift_range, (B,3)) 288 | for batch_index in range(B): 289 | batch_data[batch_index,:,:] += shifts[batch_index,:] 290 | batch_data2[batch_index,:,:] += shifts[batch_index,:] 291 | return batch_data,batch_data2 292 | 293 | 294 | def random_scale_point_cloud(batch_data, batch_data2, scale_low=0.8, scale_high=1.25): 295 | """ Randomly scale the point cloud. Scale is per point cloud. 
296 | Input: 297 | BxNx3 array, original batch of point clouds 298 | Return: 299 | BxNx3 array, scaled batch of point clouds 300 | """ 301 | B, N, C = batch_data.shape 302 | scales = np.random.uniform(scale_low, scale_high, B) 303 | for batch_index in range(B): 304 | batch_data[batch_index,:,:] *= scales[batch_index] 305 | batch_data2[batch_index,:,:] *= scales[batch_index] 306 | return batch_data,batch_data2 307 | 308 | def random_point_dropout(batch_pc, max_dropout_ratio=0.5): 309 | ''' batch_pc: BxNx3 ''' 310 | for b in range(batch_pc.shape[0]): 311 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~max_dropout_ratio 312 | drop_idx = np.where(np.random.random((batch_pc.shape[1]))<=dropout_ratio)[0] 313 | if len(drop_idx)>0: 314 | batch_pc[b,drop_idx,:] = batch_pc[b,0,:] # set to the first point 315 | return batch_pc 316 | 317 | def random_point_dropout2(batch_pc, max_dropout_ratio=0.5): 318 | ''' batch_pc: 1xNx3; only a single cloud is supported, because the loop replaces batch_pc with the kept points of cloud b ''' 319 | for b in range(batch_pc.shape[0]): 320 | dropout_ratio = np.random.random()*max_dropout_ratio # 0~max_dropout_ratio 321 | keep_idx = np.where(np.random.random((batch_pc.shape[1]))>=dropout_ratio)[0] 322 | if len(keep_idx)>4: 323 | batch_pc = batch_pc[b,keep_idx,:].reshape(1,-1,3) 324 | return batch_pc 325 | 326 | 327 | 328 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | """ 2 | Author: Wei 3 | Date: Mar 2022 4 | """ 5 | import argparse 6 | import os 7 | from data_utils.DataLoader import PatchDataset 8 | import torch 9 | import datetime 10 | import logging 11 | from pathlib import Path 12 | import sys 13 | import importlib 14 | import shutil 15 | from tqdm import tqdm 16 | import provider 17 | import numpy as np 18 | from scipy import spatial 19 | from glob import glob 20 | from data_utils.Pointfilter_Utils import pca_alignment 21 | import time 22 | BASE_DIR = os.path.dirname(os.path.abspath(__file__)) 23 | ROOT_DIR = BASE_DIR 24 |
sys.path.append(os.path.join(ROOT_DIR, 'models')) 25 | 26 | 27 | def parse_args(): 28 | parser = argparse.ArgumentParser('Model') 29 | parser.add_argument('--model', type=str, default='model', help='model name [default: model]') 30 | parser.add_argument('--batch_size', type=int, default=100, help='Batch size [default: 100]') 31 | parser.add_argument('--epoch', default=1000, type=int, help='Epoch to run [default: 1000]') 32 | parser.add_argument('--learning_rate', default=0.000001, type=float, help='Initial learning rate [default: 0.000001]') 33 | parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: GPU 0]') 34 | parser.add_argument('--optimizer', type=str, default='Adam', help='Adam or SGD [default: Adam]') 35 | parser.add_argument('--optimizer2', type=str, default='Adam', help='Adam or SGD [default: Adam]') 36 | parser.add_argument('--log_dir', type=str, default='model', help='Log path [default: model]') 37 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay [default: 1e-4]') 38 | parser.add_argument('--npoint', type=int, default=128, help='Point Number [default: 128]') 39 | parser.add_argument('--normal', action='store_true', default=False, help='Whether to use normal information [default: False]') 40 | parser.add_argument('--step_size', type=int, default=20, help='Decay step for lr decay [default: every 20 epochs]') 41 | parser.add_argument('--lr_decay', type=float, default=0.5, help='Decay rate for lr decay [default: 0.5]') 42 | parser.add_argument('--use_random_path', type=int, default=0, help='whether to use random paths: 0 no (analyser), 1 yes; 2 (all path 0) and 3 (all path 1) are not handled by models/model.py [default: 0]') 43 | parser.add_argument('--block_num', type=int, default=6, help='number of denoiser blocks [default: 6]') 44 | parser.add_argument('--path_num', type=int, default=2, help='number of paths in each denoiser block [default: 2]') 45 | 46 | return parser.parse_args() 47 | 48 | def pca_normalize(pc): 49 | normalize_data = np.zeros(pc.shape, dtype=np.float32) 50 | martix_inv =
np.zeros((pc.shape[0],3,3),dtype=np.float32) 51 | 52 | centroid = pc[:,0,:] 53 | centroid = np.expand_dims(centroid, axis=1) 54 | pc = pc - centroid 55 | 56 | m = np.max(np.sqrt(np.sum(pc**2, axis=2)),axis = 1, keepdims=True) 57 | scale_inv = m #B 58 | pc = pc / np.expand_dims(m, axis=-1) 59 | 60 | for B in range(pc.shape[0]): 61 | x, pca_martix_inv = pca_alignment(pc[B,:,:]) 62 | normalize_data[B, ...] = x 63 | martix_inv[B, ...] = pca_martix_inv 64 | 65 | return normalize_data, martix_inv, scale_inv 66 | 67 | def main(args,test_data_dir): 68 | def log_string(str): 69 | logger.info(str) 70 | print(str) 71 | 72 | '''HYPER PARAMETER''' 73 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 74 | 75 | '''CREATE DIR''' 76 | experiment_dir = Path('./log/') 77 | experiment_dir = experiment_dir.joinpath('path_denoise/'+ args.model) 78 | checkpoints_dir = experiment_dir.joinpath('checkpoints/') 79 | 80 | '''LOG''' 81 | args = parse_args() 82 | name_dir = test_data_dir 83 | root = 'data/test_data/' 84 | 85 | DATA_PATH = 'data/test_data/' + name_dir + '/' 86 | samples = glob(DATA_PATH+"/*.xyz") 87 | samples.sort() 88 | #print(samples) 89 | batch_size = args.batch_size 90 | 91 | block_num = args.block_num 92 | path_num = args.path_num 93 | 94 | '''MODEL LOADING''' 95 | MODEL = importlib.import_module(args.model) 96 | 97 | denoiser = MODEL.get_model(block_num, path_num).cuda() 98 | analyser = MODEL.get_analyser(block_num, path_num).cuda() 99 | 100 | checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth') 101 | denoiser.load_state_dict(checkpoint['denoiser_model_state_dict']) 102 | analyser.load_state_dict(checkpoint['analyser_model_state_dict']) 103 | 104 | 105 | '''save dir''' 106 | save_dir = './data/test_data/results/' + name_dir 107 | if(not os.path.exists(save_dir)): 108 | os.makedirs(save_dir) 109 | 110 | for i,item in tqdm(enumerate(samples)): 111 | print(item) 112 | start_time = time.time() 113 | data_name = item.split('/')[-1][:-4] 114 | input_data = 
np.loadtxt(DATA_PATH + data_name + '.xyz').astype(np.float32) # 115 | 116 | with torch.no_grad(): 117 | denoiser = denoiser.eval() 118 | analyser = analyser.eval() 119 | source_data = input_data[:,:3] 120 | 121 | path_first = None 122 | 123 | for iter_time in range(2): 124 | 125 | inputs = input_data[:,:3] 126 | nbrs = spatial.cKDTree(inputs) # kd tree 127 | 128 | batch_num = int(inputs.shape[0] / batch_size) 129 | add_batch = 0 if(inputs.shape[0] % batch_size ==0)else 1 130 | 131 | normalize_trans_all = [] 132 | path_all = [] 133 | 134 | dist,idxs = nbrs.query(inputs,k = int(128)) #s 135 | 136 | input_patchs = inputs[idxs,:] # B,128,3 137 | 138 | normalize_input_patchs, martix_inv, scale_inv = pca_normalize(input_patchs) # B,K,3 #B,3,3 #B 139 | print('time1:', time.time()-start_time) 140 | for i in tqdm(range(batch_num+add_batch)): 141 | points = normalize_input_patchs[i*batch_size:(i+1)*batch_size,:,:] # B,128,3 142 | 143 | points = torch.Tensor(points) 144 | points = points.float().cuda() 145 | 146 | points = points.transpose(2, 1) 147 | 148 | trans_m, path_m, _ = denoiser(points,analyser,args.use_random_path) 149 | 150 | trans_m = torch.cat(trans_m,axis = 0).reshape(block_num,-1,3).transpose(1,0) #B,block_num,3 151 | 152 | path_m = torch.cat(path_m,axis = 0).reshape(block_num,-1).transpose(1,0) #B,block_num 153 | 154 | normalize_trans_all.append(trans_m) 155 | path_all.append(path_m) 156 | 157 | normalize_trans_all = torch.cat(normalize_trans_all, axis = 0) # B,6,3 158 | path_all = torch.cat(path_all, axis = 0) 159 | 160 | normalize_trans_all = normalize_trans_all.cpu().numpy().astype(np.float32).reshape(-1,block_num,3) 161 | path_all = path_all.cpu().numpy().astype(np.float32).reshape(-1,block_num) 162 | 163 | trans_all = np.matmul(martix_inv, normalize_trans_all.transpose(0,2,1)).transpose(0,2,1) # B,6,3 164 | 165 | trans_all = trans_all * np.expand_dims(scale_inv,axis = -1) 166 | 167 | outputs_all = inputs[:,:3].reshape(-1,1,3) - trans_all 168 | 169 | 

                # keep the paths chosen in the first pass; they label the final output
                if(iter_time ==0):
                    path_first = path_all

                # save: xyz plus a 4th column holding routing information
                # input points, labelled with the path chosen in the first block
                inputs_start = np.concatenate((inputs[:,:3],path_all[:,0].reshape(-1,1)),axis = -1)
                np.savetxt(save_dir+"/"+ data_name +"_input_start.xyz", inputs_start.astype(np.float32), fmt = '%.6f')

                # input points, labelled with the sum of the per-block path indices
                inputs_sum = np.concatenate((inputs[:,:3],np.sum(path_all, axis = 1).reshape(-1,1)),axis = -1)
                np.savetxt(save_dir+"/"+ data_name +"_input_sum.xyz",inputs_sum.astype(np.float32),fmt = '%.6f')

                # final output, labelled with the summed path indices of the first pass
                output_end = np.concatenate((outputs_all[:,-1,:],np.sum(path_first, axis = 1).reshape(-1,1)),axis = -1)
                np.savetxt(save_dir+"/"+ data_name +"_output_end.xyz",output_end.astype(np.float32),fmt = '%.6f')

                # output of block i, labelled with the path chosen in block i+1
                for i in range(0,block_num-1):
                    path_i = path_all[:,i+1].reshape(-1,1)
                    output_i = outputs_all[:,i,:]
                    output_i = np.concatenate((output_i,path_i),axis = -1)
                    np.savetxt(save_dir+"/"+ data_name +"_output_" + str(i) + ".xyz",output_i.astype(np.float32),fmt = '%.6f')

                # feed the final output back in for the next pass
                input_data = outputs_all[:,-1,:]
                data_name = data_name +'_output_end'
                print('time2:', time.time()-start_time)

if __name__ == '__main__':
    args = parse_args()
    test_data_dir = 'benchmark81_20000' # test dir, one of ['benchmark81_10000','benchmark81_20000','benchmark81_50000','kinect_fusion','kinect_v1','kinect_v2']
    main(args,test_data_dir)

--------------------------------------------------------------------------------
/train.py:
--------------------------------------------------------------------------------
"""
Author: Zeyong Wei
Date: Mar 2022
"""
import argparse
import os
from data_utils.DataLoader import PatchDataset
import torch
import datetime
import logging
from pathlib import Path
import sys
import importlib
import shutil
from tqdm import tqdm
import provider
import numpy as np
from scipy import spatial

BASE_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = BASE_DIR
sys.path.append(os.path.join(ROOT_DIR, 'models'))


def parse_args():
    parser = argparse.ArgumentParser('Model')
    parser.add_argument('--model', type=str, default='model', help='model name [default: model]')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size during training [default: 64]')
    parser.add_argument('--epoch', default=300, type=int, help='Epoch to run [default: 300]')
    parser.add_argument('--learning_rate', default=0.000001, type=float, help='Initial learning rate [default: 1e-6]')
    parser.add_argument('--gpu', type=str, default='0', help='GPU to use [default: GPU 0]')
    parser.add_argument('--optimizer', type=str, default='Adam', help='Adam or SGD [default: Adam]')
    parser.add_argument('--optimizer2', type=str, default='Adam', help='Adam or SGD [default: Adam]')
    parser.add_argument('--log_dir', type=str, default='model', help='Log path [default: model]')
    parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay [default: 1e-4]')
    parser.add_argument('--npoint', type=int, default=128, help='Point number per patch [default: 128]')
    parser.add_argument('--normal', action='store_true', default=False, help='Whether to use normal information [default: False]')
    parser.add_argument('--step_size', type=int, default=20, help='Decay step for lr decay [default: every 20 epochs]')
    parser.add_argument('--lr_decay', type=float, default=0.5, help='Decay rate for lr decay [default: 0.5]')
    parser.add_argument('--use_random_path', type=int, default=0, help='whether to use random paths: 0 no, 1 yes, 2 all 0, 3 all 1')
    parser.add_argument('--block_num', type=int, default=6, help='number of denoiser blocks')
    parser.add_argument('--path_num', type=int, default=2, help='number of paths in each denoiser block')

    return parser.parse_args()

def main(args):
    def log_string(str):
        logger.info(str)
        print(str)

    '''HYPER PARAMETER'''
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    '''CREATE DIR'''
    timestr = str(datetime.datetime.now().strftime('%Y-%m-%d_%H-%M'))
    experiment_dir = Path('./log/')
    experiment_dir.mkdir(exist_ok=True)
    experiment_dir = experiment_dir.joinpath('path_denoise')
    experiment_dir.mkdir(exist_ok=True)
    if args.log_dir is None:
        experiment_dir = experiment_dir.joinpath(timestr)
    else:
        experiment_dir = experiment_dir.joinpath(args.log_dir)
    experiment_dir.mkdir(exist_ok=True)
    checkpoints_dir = experiment_dir.joinpath('checkpoints/')
    checkpoints_dir.mkdir(exist_ok=True)
    log_dir = experiment_dir.joinpath('logs/')
    log_dir.mkdir(exist_ok=True)

    '''LOG'''
    logger = logging.getLogger("Model")
    logger.setLevel(logging.INFO)
    formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
    file_handler = logging.FileHandler('%s/%s.txt' % (log_dir, args.model))
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    log_string('PARAMETER ...')
    log_string(args)

    root = 'data/'

    TRAIN_DATASET = PatchDataset(root = root, npoints=args.npoint, split='train', normal_channel=args.normal)
    trainDataLoader = torch.utils.data.DataLoader(TRAIN_DATASET, batch_size=args.batch_size,shuffle=True, num_workers=4)
    TEST_DATASET = PatchDataset(root = root, npoints=args.npoint, split='test', normal_channel=args.normal)
    testDataLoader = torch.utils.data.DataLoader(TEST_DATASET, batch_size=args.batch_size,shuffle=False, num_workers=4)
    log_string("The number of training data is: %d" % len(TRAIN_DATASET))
    log_string("The number of test data is: %d" % len(TEST_DATASET))

    block_num = args.block_num
    path_num = args.path_num

    '''MODEL LOADING'''
    MODEL = importlib.import_module(args.model)
    shutil.copy('models/%s.py' % args.model, str(experiment_dir))
    shutil.copy('models/pointnet_util.py', str(experiment_dir))
    shutil.copy('data_utils/DataLoader.py', str(experiment_dir))
    shutil.copy('provider.py', str(experiment_dir))
    shutil.copy('train.py', str(experiment_dir))
    shutil.copy('test.py', str(experiment_dir))

    denoiser = MODEL.get_model(block_num, path_num).cuda()
    criterion = MODEL.get_loss().cuda()

    analyser = MODEL.get_analyser(block_num, path_num).cuda()
    get_reward = MODEL.get_reward().cuda()

    # Xavier initialisation helper (defined but not applied in this script)
    def weights_init(m):
        classname = m.__class__.__name__
        if classname.find('Conv2d') != -1:
            torch.nn.init.xavier_normal_(m.weight.data)
            torch.nn.init.constant_(m.bias.data, 0.0)
        elif classname.find('Linear') != -1:
            torch.nn.init.xavier_normal_(m.weight.data)
            torch.nn.init.constant_(m.bias.data, 0.0)

    # resume from best_model.pth if it exists (both the denoiser and the routing agent)
    try:
        checkpoint = torch.load(str(experiment_dir) + '/checkpoints/best_model.pth')
        denoiser.load_state_dict(checkpoint['denoiser_model_state_dict'])
        analyser.load_state_dict(checkpoint['analyser_model_state_dict'])
        start_epoch = checkpoint['epoch'] + 1
        log_string('Use pretrain model')
    except Exception:
        log_string('No existing model, starting training from scratch...')
        start_epoch = 0


    if args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(
            denoiser.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer = torch.optim.SGD(denoiser.parameters(), lr=args.learning_rate, momentum=0.9)

    if args.optimizer2 == 'Adam':
        optimizer2 = torch.optim.Adam(
            analyser.parameters(),
            lr=args.learning_rate,
            betas=(0.9, 0.999),
            eps=1e-08,
            weight_decay=args.decay_rate
        )
    else:
        optimizer2 = torch.optim.SGD(analyser.parameters(), lr=args.learning_rate, momentum=0.9)


    LEARNING_RATE_CLIP = 1e-6

    best_loss_denoise = 999999
    global_epoch = 0

    for epoch in range(start_epoch,args.epoch):
        log_string('Epoch %d (%d/%s):' % (global_epoch, epoch, args.epoch))
        '''Adjust learning rate'''
        # step decay every step_size epochs, floored at LEARNING_RATE_CLIP
        lr = max(args.learning_rate * (args.lr_decay ** (epoch // args.step_size)), LEARNING_RATE_CLIP)
        log_string('Learning rate: %g' % lr)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        if(args.use_random_path == 0):
            for param_group in optimizer2.param_groups:
                param_group['lr'] = lr

        num_batches = len(trainDataLoader)

        '''learning one epoch'''
        loss_sum = 0
        reward_sum = 0

        for i, data in tqdm(enumerate(trainDataLoader), total=len(trainDataLoader), smoothing=0.9):
            inputs, target, label = data
            cur_batch_size, NUM_POINT, _ = inputs.size() # B,N,3

            points = inputs.data.numpy()
            target = target.data.numpy()
            label = label.data.numpy()

            points = provider.random_point_dropout(points)

            source = points[:,0,:] # B,3 query point of each patch

            points = torch.Tensor(points)
            source = torch.Tensor(source)
            target = torch.Tensor(target)
            label = torch.Tensor(label)

            points, source, target, label = points.float().cuda(), source.float().cuda(), target.float().cuda(), label.float().cuda()


            points = points.transpose(2, 1)

            if(args.use_random_path == 1):
                # random routing: only the denoiser is trained
                optimizer.zero_grad()
                denoiser = denoiser.train()
                analyser = analyser.eval()

                trans_m, path_m, path_maxprob_m = denoiser(points,analyser,args.use_random_path)

                trans = trans_m[-1].reshape(cur_batch_size, 3)
                points_denoise = source - trans

                loss = criterion(source, target, trans_m, None, None, 1)

                loss.backward()
                optimizer.step()
                loss_sum += loss.item() # .item() so the autograd graph is not kept alive for the whole epoch

            elif(args.use_random_path == 0):
                # learned routing: the denoiser and the analyser (routing agent) are trained jointly
                optimizer.zero_grad()
                denoiser = denoiser.train()

                optimizer2.zero_grad()
                analyser = analyser.train()

                trans_m, path_m, path_maxprob_m = denoiser(points,analyser,args.use_random_path)

                trans = trans_m[-1].reshape(cur_batch_size, 3)
                points_denoise = source - trans

                loss = criterion(source, target, trans_m, path_m, path_maxprob_m, 0)
                reward = get_reward(source, target, label, trans_m, path_m, path_maxprob_m)

                # retain the graph: the reward backward pass reuses it
                loss.backward(retain_graph=True)
                loss_sum += loss.item()

                reward.backward()
                reward_sum += reward.item()

                optimizer.step()
                optimizer2.step()

        if(args.use_random_path == 1):
            log_string('Training mean loss: %f' % (loss_sum / num_batches))
        elif(args.use_random_path == 0):
            log_string('Training mean loss: %f, Training mean reward: %f' % (loss_sum / num_batches,reward_sum / num_batches))

        with torch.no_grad():
            test_metrics = {}
            cur_mean_loss = []
            cur_mean_reward = []

            for batch_id, data in tqdm(enumerate(testDataLoader), total=len(testDataLoader), smoothing=0.9):
                inputs, target, label = data
                cur_batch_size, NUM_POINT, _ = inputs.size() # B,N,3

                points = inputs.data.numpy()
                target = target.data.numpy()
                label = label.data.numpy()

                source = points[:,0,:] # B,3

                points = torch.Tensor(points)
                source = torch.Tensor(source)
                target = torch.Tensor(target)
                label = torch.Tensor(label)

                points, source, target, label = points.float().cuda(), source.float().cuda(), target.float().cuda(), label.float().cuda()

                points = points.transpose(2, 1)

                denoiser = denoiser.eval()

                if(args.use_random_path == 0):
                    analyser = analyser.eval()

                trans_m, path_m, path_maxprob_m = denoiser(points,analyser,args.use_random_path)

                trans = trans_m[-1].reshape(cur_batch_size,3)
                points_denoise = source - trans

                if(args.use_random_path == 1):
                    loss = criterion(source, target, trans_m, None, None, 1)
                    cur_mean_loss.append(loss.item())
                elif(args.use_random_path == 0):
                    loss = criterion(source, target, trans_m, path_m, path_maxprob_m, 0)
                    reward = get_reward(source, target, label, trans_m, path_m, path_maxprob_m)

                    cur_mean_loss.append(loss.item())

                    cur_mean_reward.append(reward.item())

            if(args.use_random_path == 1):
                test_metrics['loss_denoise'] = np.mean(cur_mean_loss)
                log_string('Epoch %d test loss_denoise: %f' % (epoch, test_metrics['loss_denoise']))
            elif(args.use_random_path == 0):
                test_metrics['loss_denoise'] = np.mean(cur_mean_loss)
                test_metrics['reward_denoise'] = np.mean(cur_mean_reward)
                log_string('Epoch %d test loss_denoise: %f, test reward_denoise: %f' % (epoch, test_metrics['loss_denoise'],test_metrics['reward_denoise']))

        # periodic checkpoint every 10 epochs
        if (epoch%10 == 0):
            logger.info('Save model...')
            savepath = str(checkpoints_dir) + '/model_'+ str(epoch) +'.pth'
            log_string('Saving at %s'% savepath)
            state = {
                'epoch': epoch,
                'denoiser_model_state_dict': denoiser.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'analyser_model_state_dict': analyser.state_dict(),
                'optimizer2_state_dict': optimizer2.state_dict(),
            }
            torch.save(state, savepath)
            log_string('Saving model....')

        # best checkpoint by test denoising loss
        if test_metrics['loss_denoise'] < best_loss_denoise:
            best_loss_denoise = test_metrics['loss_denoise']
            logger.info('Save model...')
            savepath = str(checkpoints_dir) + '/model_'+ str(epoch) +'.pth'
            savepath2 = str(checkpoints_dir) + '/best_model' +'.pth'
            log_string('Saving at %s'% savepath)
            state = {
                'epoch': epoch,
                'denoiser_model_state_dict': denoiser.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'analyser_model_state_dict': analyser.state_dict(),
                'optimizer2_state_dict': optimizer2.state_dict(),
            }
            if(epoch > 1):
                torch.save(state, savepath)
            torch.save(state, savepath2)
            log_string('Saving model....')
        if(args.use_random_path == 1):
            log_string('Best loss_denoise is: %.6f'%(best_loss_denoise))
        elif(args.use_random_path == 0):
            log_string('Best loss_denoise is: %.6f, current reward_denoise is: %.6f'%(best_loss_denoise, test_metrics['reward_denoise']))
        global_epoch+=1

if __name__ == '__main__':
    args = parse_args()
    main(args)

--------------------------------------------------------------------------------
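
The patch normalization in test.py (`pca_normalize`, followed by the `np.matmul` and scaling steps that map predicted offsets back to input units) and the learning-rate schedule in train.py can be checked in isolation. The sketch below is a minimal NumPy mirror under stated assumptions: `pca_align` is a hypothetical stand-in for `pca_alignment` from data_utils/Pointfilter_Utils.py (not shown here), implemented with an eigendecomposition of the patch scatter matrix, and `normalize_patches`, `denormalize_offsets` and `step_lr` are illustrative helper names, not functions from the repository.

```python
import numpy as np


def pca_align(patch):
    """HYPOTHETICAL stand-in for pca_alignment (its real implementation is not shown).

    Rotates a centered (K, 3) patch into its principal-axis frame and returns the
    aligned patch plus the (3, 3) matrix that maps aligned coordinates back.
    """
    # eigh on the symmetric scatter matrix returns orthonormal eigenvectors (columns)
    _, eigvecs = np.linalg.eigh(patch.T @ patch)
    aligned = patch @ eigvecs  # coordinates of each point along the principal axes
    return aligned, eigvecs    # eigvecs @ aligned.T == patch.T because eigvecs is orthogonal


def normalize_patches(patches):
    """Mirror of pca_normalize: (B, K, 3) patches whose first point is the query point."""
    centered = patches - patches[:, :1, :]                                    # center on the query point
    scale = np.max(np.linalg.norm(centered, axis=2), axis=1, keepdims=True)  # (B, 1) patch radius
    scaled = centered / scale[:, :, None]                                     # unit radius
    aligned = np.empty_like(scaled)
    inv = np.empty((patches.shape[0], 3, 3), dtype=scaled.dtype)
    for b in range(patches.shape[0]):
        aligned[b], inv[b] = pca_align(scaled[b])
    return aligned, inv, scale


def denormalize_offsets(offsets, inv, scale):
    """Map (B, M, 3) offsets from the normalized frame back to input units, as test.py does."""
    rotated_back = np.matmul(inv, offsets.transpose(0, 2, 1)).transpose(0, 2, 1)
    return rotated_back * scale[:, :, None]


def step_lr(epoch, base_lr=1e-6, decay=0.5, step_size=20, clip=1e-6):
    """Same step-decay formula and defaults as the train.py epoch loop."""
    return max(base_lr * decay ** (epoch // step_size), clip)


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    patches = rng.normal(size=(4, 128, 3))
    aligned, inv, scale = normalize_patches(patches)
    # round trip: treating the normalized points as offsets recovers the centered patch
    recovered = denormalize_offsets(aligned, inv, scale)
    print(np.allclose(recovered, patches - patches[:, :1, :]))
    print([step_lr(e) for e in (0, 20, 200)])
```

Because the stand-in returns an orthonormal basis, the inverse rotation is just the transpose of the forward one, so the round trip recovers the centered patch up to floating-point error. With the train.py defaults (`--learning_rate` 1e-6 and `LEARNING_RATE_CLIP` 1e-6), `step_lr` returns 1e-6 for every epoch, so the step decay only has an effect when a larger `--learning_rate` is passed.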