# Repository layout of this dump:
#   ├── NetworkN1.py
#   ├── README.md
#   ├── Trained_Models
#   │   └── model_full_ae.pth
#   ├── dataset923.py
#   ├── image
#   │   ├── figure1.png
#   │   └── figure2.png
#   ├── testN.py
#   ├── train_NetworkN1.py
#   └── utils118.py
#
# /NetworkN1.py
"""Network definition for PCDNF (joint point-cloud denoising and normal filtering).

The module is built from:

* ``knn`` / ``get_idx`` / ``get_knn_feature`` - k-nearest-neighbour grouping helpers;
* ``FeatureExtration``      - per-point feature extractor with a residual coordinate head;
* ``ConsistentPointSelect`` - scores every point and keeps the top ``r * N`` of them;
* ``KeyFeatureFusion``      - fuses each kept point with the mean feature of its neighbours;
* ``NormalEncorder`` / ``MLP`` - regression heads for the normal and the point branch;
* ``PCPNet``                - the full network.

All sub-module attribute names are part of the ``state_dict`` layout; layers that
``forward`` never calls are kept on purpose so existing checkpoints keep loading.
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.utils.data
import torch.nn.functional as F
from torch.autograd import Variable
import os
import math

try:
    # Nothing below uses ``utils`` and the repository only ships ``utils118.py``;
    # a missing module must not make this file un-importable.
    import utils
except ImportError:
    utils = None


def knn(x, k):
    """Return the indices of the ``k`` nearest neighbours of every point.

    :param x: features or coordinates, [B, C, N]
    :param k: number of neighbours; a point is always its own nearest neighbour
    :return: LongTensor [B, N, k], neighbours sorted from nearest to farthest
    """
    inner = -2 * torch.matmul(x.transpose(2, 1), x)
    xx = torch.sum(x ** 2, dim=1, keepdim=True)
    # -(|xi|^2 - 2 xi.xj + |xj|^2) = -|xi - xj|^2, so the largest entries are the closest points.
    pairwise_distance = -xx - inner - xx.transpose(2, 1)

    return pairwise_distance.topk(k=k, dim=-1)[1]


def get_idx(x, k=20, idx=None, dim9=False):
    """Return kNN indices flattened so they address a ``(B * N, C)`` matrix.

    :param x: [B, C, N]
    :param k: number of neighbours
    :param idx: optional precomputed neighbour indices [B, N, k]
    :param dim9: when true, neighbours are searched on channels ``6:`` only
    :return: LongTensor [B * N * k]; entry ``(b, n, j)`` equals ``b * N + neighbour``
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        idx = knn(x[:, 6:], k=k) if dim9 else knn(x, k=k)

    # Build the per-batch offset on the device of the indices themselves. The
    # original hard-coded 'cuda', which failed on CPU-only machines and raised
    # a device mismatch for CPU inputs even when CUDA was present.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_points

    return (idx + idx_base).view(-1)


def get_knn_feature(x, k=20):
    """Build edge features ``[neighbour - centre, centre]`` for every point.

    :param x: [B, C, N]
    :param k: number of neighbours
    :return: [B, 2 * C, N, k]
    """
    idx = get_idx(x, k=k)

    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    num_dims = x.size(1)

    x = x.transpose(2, 1).contiguous()  # [B, N, C], so one row per point after flattening
    feature = x.view(batch_size * num_points, -1)[idx, :]
    feature = feature.view(batch_size, num_points, k, num_dims)
    x = x.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    return torch.cat((feature - x, x), dim=3).permute(0, 3, 1, 2).contiguous()


class FeatureExtration(nn.Module):
    """Per-point feature extractor with a residual coordinate-refinement head."""

    def __init__(self, input_dim, output_dim, rate1, rate2, rate3):
        """
        :param input_dim: channels of the input (3 for coordinates or normals)
        :param output_dim: channels of the returned per-point feature
        :param rate1: the k=8 branch outputs ``output_dim // rate1`` channels
        :param rate2: the k=16 branch outputs ``output_dim // rate2`` channels
        :param rate3: the fused intermediate feature has ``output_dim // rate3`` channels
        """
        super(FeatureExtration, self).__init__()
        # The norm layers are registered both as attributes and inside the
        # Sequentials below; both sets of state_dict keys must stay.
        self.bn1_1 = nn.BatchNorm2d(output_dim // rate1)
        self.bn1_2 = nn.BatchNorm2d(output_dim // rate2)
        self.bn1_3 = nn.BatchNorm1d(output_dim // rate3)
        self.bn1_4 = nn.BatchNorm1d(output_dim)
        self.bn1_5 = nn.BatchNorm2d(output_dim // rate3)

        self.conv1_1 = nn.Sequential(nn.Conv2d(input_dim * 2, output_dim // rate1, 1), self.bn1_1,
                                     nn.LeakyReLU(negative_slope=0.2))
        self.conv1_2 = nn.Sequential(nn.Conv2d(input_dim * 2, output_dim // rate2, 1), self.bn1_2,
                                     nn.LeakyReLU(negative_slope=0.2))
        self.conv1_3 = nn.Sequential(nn.Conv1d(output_dim // rate1 + output_dim // rate2, output_dim // rate3, 1),
                                     self.bn1_3,
                                     nn.LeakyReLU(negative_slope=0.2))
        self.conv1_5 = nn.Sequential(nn.Conv2d((output_dim // rate3) * 2, output_dim // rate3, 1), self.bn1_5,
                                     nn.LeakyReLU(negative_slope=0.2))

        self.conv1_4 = nn.Sequential(nn.Conv1d(output_dim // rate3, output_dim, 1), self.bn1_4,
                                     nn.LeakyReLU(negative_slope=0.2))

        self.fc1 = nn.Sequential(
            nn.Linear(output_dim, output_dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(output_dim // 2, 3)
        )

    def forward(self, point):
        """
        :param point: [B, 3, N]
        :return: feature [B, N, output_dim] and refinepoint [B, N, 3]
            (the input plus a predicted offset)
        """
        # Two neighbourhood sizes on the raw input, each max-pooled over its neighbours.
        small = self.conv1_1(get_knn_feature(point, k=8)).max(dim=-1, keepdim=False)[0]
        large = self.conv1_2(get_knn_feature(point, k=16)).max(dim=-1, keepdim=False)[0]
        pointfeature = self.conv1_3(torch.cat([small, large], dim=1))

        # Second grouping: neighbours are now searched in feature space.
        pointfeature = self.conv1_5(get_knn_feature(pointfeature, k=16)).max(dim=-1, keepdim=False)[0]

        pointfeature = self.conv1_4(pointfeature).transpose(2, 1)  # [B, N, output_dim]
        refinepoint = self.fc1(pointfeature) + point.transpose(2, 1)

        return pointfeature, refinepoint


class ConsistentPointSelect(nn.Module):
    """Score every point of a patch and keep the ``r * N`` best-scoring ones."""

    def __init__(self, r=0.5):
        """
        :param r: fraction of the N input points that is kept
        """
        super(ConsistentPointSelect, self).__init__()
        self.r = r

        self.fc1 = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
        )
        self.fc2 = nn.Sequential(
            nn.Linear(128, 64),
            nn.ReLU(inplace=True),
        )
        self.fc3 = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(inplace=True),
        )
        self.fc4 = nn.Sequential(
            nn.Linear(1, 32),
            nn.ReLU(inplace=True),
        )
        self.bn1 = nn.BatchNorm1d(128)
        self.conv1 = nn.Sequential(nn.Conv1d(192, 128, 1), self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.sig = nn.Sigmoid()

    def angle(self, v1, v2):
        """Return the unsigned angle between two batches of 3-vectors.

        :param v1: [..., 3]
        :param v2: [..., 3]
        :return: [..., 1], radians in [0, pi]
        """
        cross_prod_norm = torch.norm(torch.cross(v1, v2, dim=-1), dim=-1)
        dot_prod = torch.sum(v1 * v2, dim=-1)
        # atan2(|v1 x v2|, v1 . v2) stays well-conditioned near 0 and pi, unlike acos.
        return torch.atan2(cross_prod_norm, dot_prod).unsqueeze(-1)

    def get_center_normal(self, normalfeature, idx):
        """Subtract the feature of one reference point from every point's feature.

        Not called by ``forward`` at the moment.

        :param normalfeature: [B, N, C]
        :param idx: index of the reference point per batch element, [B, 1]
        :return: [B, N, C]
        """
        B, N, C = normalfeature.size()
        center_normal = torch.gather(normalfeature, dim=1, index=idx.unsqueeze(-1).expand(B, 1, C))
        center_normal = center_normal.repeat(1, N, 1)
        return normalfeature - center_normal

    def forward(self, pointfea, normalfea, index, point, normal):
        """
        :param pointfea: point-wise feature [B, N, C]
        :param normalfea: normal-wise feature [B, N, C]
        :param index: centre-point index [B, 1]; accepted for interface
            compatibility but currently unused
        :param point: (refined) point coordinates [B, N, 3]
        :param normal: (refined) normals [B, N, 3]
        :return: weight [B, N], top_idx [B, k], keypointfeature [B, k, C],
            keyrefinepoint [B, k, 3], keynormalfeature [B, k, C],
            keyrefinenormal [B, k, 3], with k = int(r * N)
        """
        B, N, C = pointfea.size()
        k = int(self.r * N)

        # Squared norm of every point, mapped through exp(-d). Original note
        # (translated): "||xi - xj||, unsure whether to keep this term".
        pointdist = torch.sum(point * point, dim=-1, keepdim=True)
        pointdist = self.fc4(torch.exp(-pointdist))

        angle = self.fc3(self.angle(point, normal))

        pointfeature = self.fc1(pointfea)
        normalfeature = self.fc2(normalfea)

        feature = torch.cat([pointfeature, normalfeature, angle, pointdist], dim=2)
        feature = self.conv1(feature.transpose(2, 1)).transpose(2, 1)  # [B, N, 128]
        weight = self.sig(torch.max(feature, dim=-1)[0])  # [B, N]
        top_idx = torch.argsort(weight, dim=-1, descending=True)[:, 0:k]

        feat_index = top_idx.unsqueeze(-1).expand(B, k, C)
        coord_index = top_idx.unsqueeze(-1).expand(B, k, 3)
        keypointfeature = torch.gather(pointfea, dim=1, index=feat_index)
        keynormalfeature = torch.gather(normalfea, dim=1, index=feat_index)
        keyrefinepoint = torch.gather(point, dim=1, index=coord_index)
        keyrefinenormal = torch.gather(normal, dim=1, index=coord_index)

        return weight, top_idx, keypointfeature, keyrefinepoint, keynormalfeature, keyrefinenormal


class KeyFeatureFusion(nn.Module):
    """Fuse every selected key point with the mean feature of its spatial neighbours."""

    def __init__(self):
        super(KeyFeatureFusion, self).__init__()

        # fc, t, p, g, z and gn are not used by ``forward`` (t/p/g only by
        # ``normalAttention``); they stay so state_dict keys do not change.
        self.fc = nn.Sequential(
            nn.Linear(128, 128),
            nn.ReLU(inplace=True),
        )
        self.conv = nn.Sequential(
            nn.Conv1d(128, 128, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.t = nn.Conv1d(128, 64, 1)  # projection used as the query
        self.p = nn.Conv1d(128, 64, 1)  # projection used as the key
        self.g = nn.Conv1d(128, 128, 1)  # projection used as the value
        self.z = nn.Conv1d(256, 256, 1)

        self.gn = nn.GroupNorm(num_groups=1, num_channels=256)

        self.softmax = nn.Softmax(dim=-1)

    def normalAttention(self, points, normals):
        """Attend over ``normals`` with weights computed from ``points``.

        Not called by ``forward`` at the moment.

        :param points: [B, 128, N]
        :param normals: [B, 128, N]
        :return: [B, 128, N]
        """
        proj_query = self.t(points)  # [B, 64, N]
        proj_key = self.p(points).transpose(2, 1)  # [B, N, 64]
        proj_value = self.g(normals)  # [B, 128, N]

        attention = self.softmax(torch.bmm(proj_key, proj_query))  # [B, N, N]
        return torch.bmm(proj_value, attention.permute(0, 2, 1))

    def knnfeature(self, x, normalfvals, k):
        """Gather the features of the ``k`` spatial neighbours of every point.

        :param x: coordinates used for the neighbour search, [B, N, 3]
        :param normalfvals: features; ``forward`` passes them as [B, N, C]
        :param k: number of neighbours
        :return: [B, N, k, C]
        """
        batch_size, num_points, num_dims = normalfvals.size()
        idx = get_idx(x.transpose(2, 1).contiguous(), k=k)
        # NOTE(review): with a [B, N, C] input this transpose produces [B, C, N]
        # before flattening to (B * N, C), so the rows of the flattened matrix
        # are not the per-point feature vectors (unlike get_knn_feature, which
        # starts from [B, C, N]). Left untouched because the released checkpoint
        # was presumably trained with this behaviour - confirm before changing.
        flat = normalfvals.transpose(2, 1).contiguous().view(batch_size * num_points, -1)

        return flat[idx, :].view(batch_size, num_points, k, num_dims)

    def featurefuse(self, knnpointfeature, keyfeature, topidx):
        """Add the mean neighbour feature of each key point to its own feature.

        :param knnpointfeature: [B, N, K, C]
        :param keyfeature: [B, k, C]
        :param topidx: indices of the key points, [B, k]
        :return: [B, k, C]
        """
        B, N, K, C = knnpointfeature.size()
        k = topidx.size(1)
        keyknnfeature = torch.gather(knnpointfeature, dim=1,
                                     index=topidx.unsqueeze(-1).unsqueeze(-1).expand(B, k, K, C))
        return keyfeature + torch.mean(keyknnfeature, dim=2)

    def forward(self, weight, allfeature, keyfeature, refinepoint, keypoint, topidx, k):
        """
        :param weight: per-point selection score [B, N]
        :param allfeature: [B, N, C]
        :param keyfeature: [B, k_sel, C]
        :param refinepoint: [B, N, 3]
        :param keypoint: [B, k_sel, 3]; accepted for interface compatibility, unused
        :param topidx: indices of the selected points, [B, k_sel]
        :param k: size of the kNN neighbourhood
        :return: fused key feature [B, C, k_sel]
        """
        allfeature = allfeature * weight.unsqueeze(-1)
        knnpointfeature = self.knnfeature(refinepoint, allfeature, k)  # [B, N, K, C]
        feature = self.featurefuse(knnpointfeature, keyfeature, topidx)

        return self.conv(feature.transpose(2, 1))


class NormalEncorder(nn.Module):
    """Regress one 3-vector per patch from fused normal and point features."""

    def __init__(self):
        super(NormalEncorder, self).__init__()

        self.conv1_1 = nn.Sequential(nn.Conv1d(256, 128, 1), nn.BatchNorm1d(128), nn.LeakyReLU(negative_slope=0.2))
        self.conv1_2 = nn.Sequential(nn.Conv2d(128 * 2, 64, 1), nn.BatchNorm2d(64),
                                     nn.LeakyReLU(negative_slope=0.2))

        self.conv2_1 = nn.Sequential(nn.Conv2d(64 * 2, 128, 1), nn.BatchNorm2d(128),
                                     nn.LeakyReLU(negative_slope=0.2))
        self.conv2_2 = nn.Sequential(nn.Conv1d(128, 256, 1), nn.BatchNorm1d(256), nn.LeakyReLU(negative_slope=0.2))

        # fc1..fc3 are not used by ``forward``; kept so state_dict keys do not change.
        self.fc1 = nn.Sequential(nn.Conv2d(256, 128, 1), nn.BatchNorm2d(128), nn.LeakyReLU(negative_slope=0.2))
        self.fc2 = nn.Sequential(nn.Conv2d(128, 64, 1), nn.BatchNorm2d(64), nn.LeakyReLU(negative_slope=0.2))
        self.fc3 = nn.Sequential(nn.Conv2d(64, 3, 1))

        # The original passed a stray third positional ``1`` to nn.Linear. That
        # slot is ``bias``, and 1 is truthy, so the default is equivalent.
        self.fc1_1 = nn.Sequential(nn.Linear(256, 128), nn.BatchNorm1d(128), nn.LeakyReLU(negative_slope=0.2))
        self.fc2_1 = nn.Sequential(nn.Linear(128, 64), nn.BatchNorm1d(64), nn.LeakyReLU(negative_slope=0.2))
        self.fc3_1 = nn.Sequential(nn.Linear(64, 3))

    def forward(self, x, normalfeature, pointfusefeature):
        """
        :param x: fused normal feature [B, 128, M]
        :param normalfeature: global normal feature broadcast to [B, 128, M]
        :param pointfusefeature: fused point feature [B, 128, M]
        :return: [B, 3]
        """
        x = torch.cat([x + normalfeature, pointfusefeature], dim=1)  # [B, 256, M]

        feature = self.conv1_1(x)

        # Two rounds of feature-space kNN grouping, each max-pooled over the neighbours.
        feature1 = self.conv1_2(get_knn_feature(feature, k=8)).max(dim=-1, keepdim=False)[0]  # [B, 64, M]
        feature1 = self.conv2_1(get_knn_feature(feature1, k=8)).max(dim=-1, keepdim=False)[0]  # [B, 128, M]

        feature = self.conv2_2(feature + feature1)

        dis = feature.max(dim=-1, keepdim=False)[0]  # global max pool -> [B, 256]
        return self.fc3_1(self.fc2_1(self.fc1_1(dis)))


class MLP(nn.Module):
    """Regress one 3-vector per patch, bounded to (-1, 1) by tanh."""

    def __init__(self):
        super(MLP, self).__init__()

        self.conv = nn.Sequential(nn.Conv1d(384, 512, 1), nn.BatchNorm1d(512), nn.LeakyReLU(negative_slope=0.2))
        # conv1 / conv2 are not used by ``forward``; kept so state_dict keys do not change.
        self.conv1 = nn.Sequential(nn.Conv1d(512, 256, 1), nn.BatchNorm1d(256), nn.LeakyReLU(negative_slope=0.2))
        self.conv2 = nn.Sequential(nn.Conv1d(256, 512, 1), nn.BatchNorm1d(512), nn.LeakyReLU(negative_slope=0.2))
        self.fc1_1 = nn.Linear(512, 256)
        self.fc1_2 = nn.Linear(256, 64)
        self.fc1_3 = nn.Linear(64, 3)
        self.bn1_11 = nn.BatchNorm1d(256)
        self.bn1_22 = nn.BatchNorm1d(64)

    def forward(self, x):
        """
        :param x: [B, 384, M]
        :return: [B, 3]
        """
        x = self.conv(x).max(dim=-1, keepdim=False)[0]
        x = F.relu(self.bn1_11(self.fc1_1(x)))
        x = F.relu(self.bn1_22(self.fc1_2(x)))
        return torch.tanh(self.fc1_3(x))


class PCPNet(nn.Module):
    """Full network: one point output and one normal output per input patch."""

    def __init__(self, num_points=500, output_dim=3, k=20):
        """
        :param num_points: stored on the instance; not read by ``forward``
        :param output_dim: accepted for interface compatibility, unused
        :param k: stored on the instance; not read by ``forward``
        """
        super(PCPNet, self).__init__()
        self.num_points = num_points
        self.k = k

        self.pointfeatEX = FeatureExtration(input_dim=3, output_dim=128, rate1=8, rate2=4, rate3=2)
        self.normalfeatEX = FeatureExtration(input_dim=3, output_dim=128, rate1=8, rate2=4, rate3=2)
        self.weight = ConsistentPointSelect(r=0.5)
        self.pointFeaFu = KeyFeatureFusion()
        self.normalFeaFu = KeyFeatureFusion()
        self.normalDecoder = NormalEncorder()

        self.mlp1 = MLP()

    def forward(self, x, normal, index):
        """
        :param x: point coordinates [B, 3, N]
        :param normal: normals [B, 3, N]
        :param index: centre-point index per patch; forwarded to the selector,
            which currently ignores it
        :return: p [B, 3], n [B, 3], weight [B, N], topidx [B, N // 2]
        """
        pointfeature, refinepoint = self.pointfeatEX(x)
        normalfeature, refinenormal = self.normalfeatEX(normal)
        weight, topidx, keypointfeature, keypoint, keynormalfeature, keynormal = self.weight(
            pointfeature, normalfeature, index, refinepoint, refinenormal)

        pointfusefeature = self.pointFeaFu(weight, pointfeature, keypointfeature, refinepoint, keypoint, topidx,
                                           k=10)  # [B, C, M]
        # The normal branch also groups by the refined *point* coordinates.
        normalfusefeature = self.normalFeaFu(weight, normalfeature, keynormalfeature, refinepoint, keypoint, topidx,
                                             k=10)

        num_key = pointfusefeature.size(2)

        # Global (max-pooled over all N points) descriptors, tiled to the key points.
        globalnormalfeature = torch.max(normalfeature, dim=1, keepdim=True)[0]
        globalnormalfeature = globalnormalfeature.repeat(1, num_key, 1).transpose(2, 1)  # [B, 128, M]

        globalpointfeature = torch.max(pointfeature, dim=1, keepdim=True)[0]
        globalpointfeature = globalpointfeature.repeat(1, num_key, 1).transpose(2, 1)  # [B, 128, M]

        pfeat = torch.cat([pointfusefeature, globalpointfeature, normalfusefeature], dim=1)  # [B, 384, M]

        p = self.mlp1(pfeat)
        n = self.normalDecoder(normalfusefeature, globalnormalfeature, pointfusefeature)

        return p, n, weight, topidx


if __name__ == '__main__':
    # Smoke test on random data; runs on CPU.
    point = torch.rand(64, 512, 3).transpose(2, 1)
    normal = torch.rand(64, 512, 3).transpose(2, 1)
    index = np.random.randint(10, 20, 64)
    net = PCPNet()
    net(point, normal, index)

# --- end of /NetworkN1.py; the dump continues with /README.md ---
# # PCDNF: Revisiting Learning-based Point Cloud Denoising via Joint Normal Filtering
#
# :zap:`Status Update: [2023/07/02] This paper has been accepted by the IEEE Transactions on Visualization and Computer Graphics (TVCG).`

6 | 7 |

8 | 9 | by [Zheng Liu](https://labzhengliu.github.io/), Yaowu Zhao, Sijing Zhan, [Yuanyuan Liu](https://cvlab-liuyuanyuan.github.io/), [Renjie Chen](http://staff.ustc.edu.cn/~renjiec/) and [Ying He](https://personal.ntu.edu.sg/yhe/) 10 | 11 | ## :bulb: Introduction 12 | Recovering high-quality surfaces from noisy point clouds, known as point cloud denoising, is a fundamental yet challenging 13 | problem in geometry processing. Most of the existing methods either directly denoise the noisy input or filter raw normals and then 14 | update point positions. Motivated by the essential interplay between point cloud denoising and normal filtering, we revisit point cloud 15 | denoising from a multitask perspective, and propose an end-to-end network, named PCDNF, to denoise point clouds via joint normal 16 | filtering. In particular, we introduce an auxiliary normal filtering task to help the overall network remove noise more effectively while 17 | preserving geometric features more accurately. In addition to the overall architecture, our network has two novel modules. On the one 18 | hand, to improve noise removal performance, we design a shape-aware selector to construct the latent tangent space representation of 19 | the specific point by comprehensively considering the learned point and normal features and geometry priors. On the other hand, point 20 | features are more suitable for describing geometric details, and normal features are more conducive to representing geometric 21 | structures (e.g., sharp edges and corners). Combining point and normal features allows us to overcome their individual weaknesses. Thus, we 22 | design a feature refinement module to fuse point and normal features to better recover geometric information. 23 | 24 |

25 | 26 |

27 | 28 | ## :wrench: Usage 29 | ### Environment 30 | * Python 3.6 31 | * PyTorch 1.5.0 32 | * CUDA and CuDNN (CUDA 10.1 & CuDNN 7.5) 33 | * TensorboardX (2.0) if logging training info. 34 | ### Install the required Python packages: 35 | ``` bash 36 | pip install numpy 37 | pip install scipy 38 | pip install plyfile 39 | pip install scikit-learn 40 | pip install tensorboardX  # only needed for the training stage 41 | pip install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html 42 | ``` 43 | ### Test the trained model: 44 | Set the parameters such as file path, batch size, iteration numbers, etc. in **testN.py** and then run it. 45 | We provide our pretrained model in `Trained_Models/model_full_ae.pth`. 46 | 47 | ### Train the model: 48 | Set the parameters such as file path, batch size, iteration numbers, etc. in **train_NetworkN1.py** and then run it. 49 | Our training set is from [PointFilter](https://github.com/dongbo-BUAA-VR/Pointfilter) and the normal information is computed by PCA. 50 | 51 | ## :link: Citation 52 | If you find this work helpful, please consider citing our [paper](https://ieeexplore.ieee.org/document/10173632): 53 | ``` 54 | @ARTICLE{10173632, 55 | author={Liu, Zheng and Zhao, Yaowu and Zhan, Sijing and Liu, Yuanyuan and Chen, Renjie and He, Ying}, 56 | journal={IEEE Transactions on Visualization and Computer Graphics}, 57 | title={PCDNF: Revisiting Learning-based Point Cloud Denoising via Joint Normal Filtering}, 58 | year={2023}, 59 | doi={10.1109/TVCG.2023.3292464} 60 | } 61 | ``` 62 | 63 | -------------------------------------------------------------------------------- /Trained_Models/model_full_ae.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/LabZhengLiu/PCDNF/4cc625096868c7888f80d0cbca38d4ba52a0c9eb/Trained_Models/model_full_ae.pth -------------------------------------------------------------------------------- /dataset923.py: 
-------------------------------------------------------------------------------- 1 | from __future__ import print_function 2 | 3 | import torch 4 | import torch.utils.data as data 5 | from torch.utils.data.dataloader import default_collate 6 | 7 | import os 8 | import numpy as np 9 | import scipy.spatial as sp 10 | 11 | from utils118 import pca_alignment 12 | 13 | 14 | ##################################New Dataloader Class########################### 15 | 16 | def my_collate(batch): 17 | batch = list(filter(lambda x: x is not None, batch)) 18 | return default_collate(batch) 19 | 20 | 21 | class RandomPointcloudPatchSampler(data.sampler.Sampler): 22 | 23 | def __init__(self, data_source, patches_per_shape, seed=None, identical_epochs=False): 24 | self.data_source = data_source 25 | self.patches_per_shape = patches_per_shape 26 | self.seed = seed 27 | self.identical_epochs = identical_epochs 28 | self.total_patch_count = None 29 | 30 | if self.seed is None: 31 | self.seed = np.random.random_integers(0, 2 ** 32 - 1, 1)[0] 32 | self.rng = np.random.RandomState(self.seed) 33 | 34 | self.total_patch_count = 0 35 | for shape_ind, _ in enumerate(self.data_source.shape_names): 36 | self.total_patch_count = self.total_patch_count + min(self.patches_per_shape, 37 | self.data_source.shape_patch_count[shape_ind]) 38 | 39 | def __iter__(self): 40 | 41 | if self.identical_epochs: 42 | self.rng.seed(self.seed) 43 | 44 | return iter( 45 | self.rng.choice(sum(self.data_source.shape_patch_count), size=self.total_patch_count, replace=False)) 46 | 47 | def __len__(self): 48 | return self.total_patch_count 49 | 50 | 51 | class PointcloudPatchDataset(data.Dataset): 52 | 53 | def __init__(self, root=None, shapes_list_file=None, patch_radius=0.05, points_per_patch=512, 54 | seed=None, train_state='train', shape_name=None, identical_epoches=False,knn=False): 55 | 56 | self.root = root 57 | self.shapes_list_file = shapes_list_file 58 | 59 | self.patch_radius = patch_radius 60 | 
self.points_per_patch = points_per_patch 61 | self.seed = seed 62 | self.train_state = train_state 63 | self.identical_epochs = identical_epoches 64 | self.knn=knn 65 | 66 | # initialize rng for picking points in a patch 67 | if self.seed is None: 68 | self.seed = np.random.random_integers(0, 2 ** 10 - 1, 1)[0] 69 | self.rng = np.random.RandomState(self.seed) 70 | 71 | self.shape_patch_count = [] 72 | self.patch_radius_absolute = [] 73 | self.gt_shapes = [] 74 | self.noise_shapes = [] 75 | 76 | self.shape_names = [] 77 | if self.train_state == 'evaluation' and shape_name is not None: 78 | pts_normal = np.load(os.path.join(self.root, shape_name + '.npy')) 79 | noise_pts = pts_normal[:, 0:3] 80 | noise_normal = pts_normal[:, 3:6] 81 | noise_kdtree = sp.cKDTree(noise_pts) 82 | self.noise_shapes.append( 83 | {'noise_pts': noise_pts, 'noise_kdtree': noise_kdtree, 'noise_normal': noise_normal}) 84 | self.shape_patch_count.append(noise_pts.shape[0]) 85 | bbdiag = float(np.linalg.norm(noise_pts.max(0) - noise_pts.min(0), 2)) 86 | self.patch_radius_absolute.append(bbdiag * self.patch_radius) 87 | elif self.train_state == 'train': 88 | with open(os.path.join(self.root, self.shapes_list_file)) as f: 89 | self.shape_names = f.readlines() 90 | self.shape_names = [x.strip() for x in self.shape_names] 91 | self.shape_names = list(filter(None, self.shape_names)) 92 | for shape_ind, shape_name in enumerate(self.shape_names): 93 | print('getting information for shape %s' % shape_name) 94 | if shape_ind % 6 == 0: 95 | gt_pts_normal = np.load(os.path.join(self.root, shape_name + '.npy')) 96 | gt_pts = gt_pts_normal[:, 0:3] 97 | gt_normal = gt_pts_normal[:, 3:6] 98 | gt_kdtree = sp.cKDTree(gt_pts) 99 | self.gt_shapes.append({'gt_pts': gt_pts, 'gt_normal': gt_normal, 'gt_kdtree': gt_kdtree}) 100 | self.noise_shapes.append( 101 | {'noise_pts': gt_pts, 'noise_kdtree': gt_kdtree, 'noise_normal': gt_normal}) 102 | noise_pts = gt_pts 103 | else: 104 | 105 | pts_normal = 
def patch_sampling(self, patch_inds):
    """Draw exactly ``self.points_per_patch`` indices from ``patch_inds``.

    Sampling is without replacement when the patch holds at least
    ``points_per_patch`` candidates and with replacement otherwise, so the
    result always has a fixed length.

    :param patch_inds: 1-D numpy array of candidate point indices.
    :return: 1-D numpy array of length ``self.points_per_patch``.
    """
    if self.identical_epochs:
        # Re-seed before every draw so each epoch sees identical patches.
        self.rng.seed(self.seed)

    with_replacement = len(patch_inds) < self.points_per_patch
    picked = self.rng.choice(len(patch_inds), self.points_per_patch, replace=with_replacement)
    return patch_inds[picked]


def gauss_fcn(self, x, mu=0, sigma2=0.12):
    """Return the un-normalised Gaussian ``exp(-(x - mu)^2 / (2 * sigma2))``."""
    return np.exp(-(x - mu) ** 2 / (2 * sigma2))


def __getitem__(self, index):
    """Build the patch centred on the point with global index ``index``.

    The noisy neighbourhood is centred on the query point, divided by the
    shape's absolute patch radius and rotated by ``pca_alignment``; normals
    and ground-truth data are rotated with the same matrix.

    :param index: global point index over all loaded shapes.
    :return: ``None`` when a neighbourhood has fewer than 3 points.
        In ``'evaluation'`` state: ``(noise_patch_pts, noise_patch_inv,
        center_point, noise_patch_normal, center_pos, center_normal)`` where
        ``center_point`` and ``center_normal`` are numpy arrays and the rest
        are tensors.  Otherwise: ``(noise_patch_pts, gt_patch_pts,
        noise_patch_normal, gt_patch_normal, support_radius, gt_normal,
        center_pos, center_normal)``, all tensors.
    """
    # Find the shape that contains the point with the given global index.
    shape_ind, patch_ind = self.shape_index(index)
    noise_shape = self.noise_shapes[shape_ind]
    patch_radius = self.patch_radius_absolute[shape_ind]
    center = noise_shape['noise_pts'][patch_ind]

    # ---- noisy patch ----
    if self.knn:
        # The k nearest neighbours; the returned indices include the centre point.
        _, noise_patch_idx = noise_shape['noise_kdtree'].query(center, self.points_per_patch)
        # np.int was removed in NumPy 1.24; use an explicit integer dtype.
        noise_patch_idx = np.asarray(noise_patch_idx, dtype=np.int64)
    else:
        # All points within patch_radius of the centre.
        # NOTE(review): the original comment said the centre is not included here - confirm.
        noise_patch_idx = np.array(noise_shape['noise_kdtree'].query_ball_point(center, patch_radius))

    if len(noise_patch_idx) < 3:
        return None

    noise_sample_idx = self.patch_sampling(noise_patch_idx)
    # Position(s) of the centre point inside the sampled patch.
    center_pos = np.where(noise_sample_idx == patch_ind)[0]

    noise_patch_pts = noise_shape['noise_pts'][noise_sample_idx] - center
    noise_patch_pts /= patch_radius
    # pca_alignment returns the rotated patch and the inverse rotation R^-1.
    noise_patch_pts, noise_patch_inv = pca_alignment(noise_patch_pts)

    # Forward rotation, computed once and reused for every vector set below.
    rotation = np.linalg.inv(noise_patch_inv)

    def to_patch_frame(vectors):
        # Rotate row vectors (N, 3) into the PCA-aligned patch frame.
        return (rotation @ vectors.T).T

    support_radius = np.linalg.norm(noise_patch_pts.max(0) - noise_patch_pts.min(0), 2) / noise_patch_pts.shape[0]
    support_radius = np.expand_dims(support_radius, axis=0)

    normal = to_patch_frame(np.expand_dims(noise_shape['noise_normal'][patch_ind], axis=0))
    noise_patch_normal = to_patch_frame(noise_shape['noise_normal'][noise_sample_idx])

    if self.train_state == 'evaluation':
        return torch.from_numpy(noise_patch_pts), torch.from_numpy(noise_patch_inv), \
            center, torch.from_numpy(noise_patch_normal), torch.from_numpy(center_pos), normal

    # ---- ground-truth patch ----
    # Every sixth shape is loaded as ground truth, hence the integer division.
    gt_shape = self.gt_shapes[shape_ind // 6]
    if self.knn:
        _, gt_patch_idx = gt_shape['gt_kdtree'].query(center, self.points_per_patch)
        gt_patch_idx = np.asarray(gt_patch_idx, dtype=np.int64)
    else:
        gt_patch_idx = np.array(gt_shape['gt_kdtree'].query_ball_point(center, patch_radius))

    if len(gt_patch_idx) < 3:
        return None

    gt_sample_idx = self.patch_sampling(gt_patch_idx)

    # Normalise the ground-truth patch exactly like the noisy one.
    gt_patch_pts = gt_shape['gt_pts'][gt_sample_idx] - center
    gt_patch_pts /= patch_radius
    gt_patch_pts = to_patch_frame(gt_patch_pts)

    gt_normal = to_patch_frame(np.expand_dims(gt_shape['gt_normal'][patch_ind], axis=0))
    gt_patch_normal = to_patch_frame(gt_shape['gt_normal'][gt_sample_idx])

    return (torch.from_numpy(noise_patch_pts), torch.from_numpy(gt_patch_pts),
            torch.from_numpy(noise_patch_normal), torch.from_numpy(gt_patch_normal),
            torch.from_numpy(support_radius), torch.from_numpy(gt_normal),
            torch.from_numpy(center_pos), torch.from_numpy(normal))


def __len__(self):
    """Return the total number of patches (one per point) over all shapes."""
    return sum(self.shape_patch_count)


def shape_index(self, index):
    """Map a global point index to ``(shape_ind, index_within_shape)``.

    :param index: global index in ``[0, len(self))``.
    :raises IndexError: if ``index`` falls outside every shape.
    """
    shape_patch_offset = 0
    for shape_ind, shape_patch_count in enumerate(self.shape_patch_count):
        if shape_patch_offset <= index < shape_patch_offset + shape_patch_count:
            return shape_ind, index - shape_patch_offset
        shape_patch_offset += shape_patch_count

    raise IndexError('patch index %d is out of range for a dataset of %d patches'
                     % (index, shape_patch_offset))
def eval(opt):
    """Iteratively filter every shape listed in ``<opt.testset>/test.txt``.

    For each shape the noisy cloud is copied to ``opt.save_dir`` as
    ``<shape>_pred_iter_0.npy``; each of the ``opt.eval_iter_nums`` passes
    reads the previous iteration's file, runs the network on every patch and
    writes ``<shape>_pred_iter_<k+1>.npy`` with six float32 columns
    (predicted position followed by predicted normal).

    :param opt: parsed options; reads ``testset``, ``save_dir``, ``eval_dir``,
        ``eval_iter_nums``, ``patch_radius``, ``batchSize`` and ``workers``.
    """
    with open(os.path.join(opt.testset, 'test.txt'), 'r') as f:
        shape_names = [line.strip() for line in f]
    shape_names = [name for name in shape_names if name]

    # Use the function argument rather than the module-level `parameters`.
    os.makedirs(opt.save_dir, exist_ok=True)

    # The weights do not change between iterations, so load the model once.
    pointfilter_eval = PCPNet()
    model_filename = os.path.join(opt.eval_dir, 'model_full_ae.pth')
    checkpoint = torch.load(model_filename)
    pointfilter_eval.load_state_dict(checkpoint['state_dict'])
    pointfilter_eval.cuda()
    pointfilter_eval.eval()

    for shape_name in shape_names:
        print(shape_name)
        original_noise_pts = np.load(os.path.join(opt.testset, shape_name + '.npy'))
        np.save(os.path.join(opt.save_dir, shape_name + '_pred_iter_0.npy'),
                original_noise_pts.astype('float32'))

        for eval_index in range(opt.eval_iter_nums):
            print(eval_index)
            test_dataset = PointcloudPatchDataset(
                root=opt.save_dir,
                shape_name=shape_name + '_pred_iter_' + str(eval_index),
                patch_radius=opt.patch_radius,
                train_state='evaluation',
                knn=True)
            test_dataloader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=opt.batchSize,
                collate_fn=my_collate,
                num_workers=int(opt.workers))

            # One-element list (a single shape is loaded); broadcasts against (B, 3).
            patch_radius = test_dataset.patch_radius_absolute
            pred_batches = []
            start = time.time()
            for data_tuple in test_dataloader:
                noise_patch, noise_inv, noise_point, patch_normal, index, _ = data_tuple

                noise_patch = noise_patch.float().cuda()
                noise_inv = noise_inv.float().cuda()
                patch_normal = patch_normal.float().cuda()
                index = index.cuda()

                noise_patch = noise_patch.transpose(2, 1).contiguous()
                patch_normal = patch_normal.transpose(2, 1).contiguous()

                with torch.no_grad():
                    dis, n, _, _ = pointfilter_eval(noise_patch, patch_normal, index)
                    # Rotate displacement and normal back out of the PCA frame.
                    # Squeezing only the last axis keeps a batch of one as (1, 3).
                    dis = torch.bmm(noise_inv, dis.unsqueeze(2)).squeeze(2)
                    n = torch.bmm(noise_inv, n.unsqueeze(2)).squeeze(2)

                points = dis.cpu().numpy() * patch_radius + noise_point.numpy()
                normals = n.cpu().numpy()
                pred_batches.append(np.concatenate((points, normals), axis=1))
            end = time.time()
            print("total_time:" + str(end - start))

            if pred_batches:
                pred_pts = np.concatenate(pred_batches, axis=0)
            else:
                pred_pts = np.empty((0, 6), dtype='float32')
            np.save(os.path.join(opt.save_dir, shape_name + '_pred_iter_' + str(eval_index + 1) + '.npy'),
                    pred_pts.astype('float32'))
"0" 16 | def train(opt): 17 | print(opt) 18 | if not os.path.exists(opt.summary_train): 19 | os.makedirs(opt.summary_train) 20 | if not os.path.exists(opt.network_model_dir): 21 | os.makedirs(opt.network_model_dir) 22 | print("Random Seed: ", opt.manualSeed) 23 | np.random.seed(opt.manualSeed) 24 | torch.manual_seed(opt.manualSeed) 25 | train_dataset = PointcloudPatchDataset( 26 | root=opt.trainset, 27 | shapes_list_file='train.txt', 28 | patch_radius=0.05, 29 | seed=opt.manualSeed, 30 | identical_epoches=False, 31 | knn=True) 32 | train_datasampler = RandomPointcloudPatchSampler( 33 | train_dataset, 34 | patches_per_shape=8000, 35 | seed=opt.manualSeed, 36 | identical_epochs=False) 37 | train_dataloader = torch.utils.data.DataLoader( 38 | train_dataset, 39 | collate_fn=my_collate, 40 | sampler=train_datasampler, 41 | shuffle=(train_datasampler is None), 42 | batch_size=opt.batchSize, 43 | num_workers=int(opt.workers), 44 | pin_memory=True) 45 | num_batch = len(train_dataloader) 46 | print(num_batch) 47 | # optionally resume from a checkpoint 48 | denoisenet =PCPNet() 49 | denoisenet.cuda() 50 | optimizer = optim.SGD( 51 | denoisenet.parameters(), 52 | lr=opt.lr, 53 | momentum=opt.momentum) 54 | train_writer = SummaryWriter(opt.summary_train) 55 | if opt.resume: 56 | if os.path.isfile(opt.resume): 57 | print("=> loading checkpoint '{}'".format(opt.resume)) 58 | checkpoint = torch.load(opt.resume) 59 | opt.start_epoch = checkpoint['epoch'] 60 | denoisenet.load_state_dict(checkpoint['state_dict']) 61 | optimizer.load_state_dict(checkpoint['optimizer']) 62 | print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch'])) 63 | else: 64 | print("=> no checkpoint found at '{}'".format(opt.resume)) 65 | 66 | for epoch in range(opt.start_epoch, opt.nepoch): 67 | adjust_learning_rate(optimizer, epoch, opt) 68 | print('lr is %.10f' % (optimizer.param_groups[0]['lr'])) 69 | for batch_ind, data_tuple in enumerate(train_dataloader): 70 | denoisenet.train() 
71 | optimizer.zero_grad() 72 | noise_patch, gt_patch,patch_normal,gt_patch_normal,support_radius,gt_normal,index,normal= data_tuple 73 | noise_patch = noise_patch.float().cuda() 74 | gt_patch = gt_patch.float().cuda() 75 | patch_normal=patch_normal.float().cuda() 76 | gt_patch_normal=gt_patch_normal.float().cuda() 77 | support_radius = opt.support_multiple * support_radius 78 | support_radius = support_radius.float().cuda(non_blocking=True) 79 | support_angle = (opt.support_angle / 360) * 2 * np.pi 80 | gt_normal=gt_normal.float().cuda() 81 | normal=normal.float().cuda() 82 | index=index.cuda() 83 | # print(index.shape) 84 | # exit(0) 85 | 86 | noise_patch = noise_patch.transpose(2, 1).contiguous() 87 | patch_normal=patch_normal.transpose(2,1).contiguous() 88 | 89 | x,n,w,topidx= denoisenet(noise_patch, patch_normal,index) 90 | # loss,loss1,loss2=comtrative_loss(x,n,gt_patch,gt_patch_normal,w,gt_normal,support_radius,support_angle,opt.repulsion_alpha) 91 | loss,loss1,loss2,loss3=compute_bilateral_loss(x,n,gt_patch,gt_patch_normal,w,support_radius,support_angle,opt.repulsion_alpha,topidx) 92 | loss.backward() 93 | optimizer.step() 94 | 95 | print('[%d: %d/%d] train loss: %f\n' % (epoch, batch_ind, num_batch, loss.item())) 96 | train_writer.add_scalar('loss', loss.data.item(), epoch * num_batch + batch_ind) 97 | 98 | train_writer.add_scalar('loss1', loss1.data.item(), epoch * num_batch + batch_ind) 99 | train_writer.add_scalar('loss2', loss2.data.item(), epoch * num_batch + batch_ind) 100 | train_writer.add_scalar('loss3', loss3.data.item(), epoch * num_batch + batch_ind) 101 | checpoint_state = { 102 | 'epoch': epoch + 1, 103 | 'state_dict': denoisenet.state_dict(), 104 | 'optimizer': optimizer.state_dict()} 105 | 106 | if epoch == (opt.nepoch - 1): 107 | 108 | torch.save(checpoint_state, '%s/model_full_ae.pth' % opt.network_model_dir) 109 | 110 | if epoch % opt.model_interval == 0: 111 | 112 | torch.save(checpoint_state, '%s/model_full_ae_%d.pth' % 
def str2bool(v):
    """Convert a command-line token to ``bool``.

    Accepts yes/true/t/y/1 and no/false/f/n/0 (case-insensitive); a value
    that is already a ``bool`` is returned unchanged.

    :raises argparse.ArgumentTypeError: for any other token.
    """
    if isinstance(v, bool):
        return v
    token = v.lower()
    if token in ('yes', 'true', 't', 'y', '1'):
        return True
    if token in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')


def parse_arguments(argv=None):
    """Build the option parser and parse ``argv``.

    :param argv: list of argument strings; ``None`` (the default) parses
        ``sys.argv[1:]`` as before.
    :return: ``argparse.Namespace`` with the training / evaluation options.
    """
    parser = argparse.ArgumentParser()
    # naming / file handling
    parser.add_argument('--name', type=str, default='pcdenoising', help='training run name')
    parser.add_argument('--network_model_dir', type=str, default='./Models/all/test1', help='output folder (trained models)')
    parser.add_argument('--trainset', type=str, default='./dataset/Train', help='training set file name')
    parser.add_argument('--testset', type=str, default='./Dataset/Test', help='testing set file name')
    parser.add_argument('--save_dir', type=str, default='./Results/all/test1', help='')
    # NOTE(review): '.logs/...' may be a typo for './logs/...' - confirm before changing.
    parser.add_argument('--summary_train', type=str, default='.logs/all/test', help='')
    parser.add_argument('--summary_test', type=str, default='./Summary/logs/model1/test', help='')

    # training parameters
    parser.add_argument('--nepoch', type=int, default=50, help='number of epochs to train for')
    parser.add_argument('--batchSize', type=int, default=32, help='input batch size')
    parser.add_argument('--workers', type=int, default=4, help='number of data loading workers')
    parser.add_argument('--manualSeed', type=int, default=3627473, help='manual seed')
    parser.add_argument('--start_epoch', type=int, default=0, help='')
    parser.add_argument('--patch_per_shape', type=int, default=8000, help='')
    parser.add_argument('--patch_radius', type=float, default=0.05, help='')
    # type=bool would turn every non-empty string (including "False") into True,
    # and the old flag contained a space; it is kept only as an alias.
    parser.add_argument('--knn_patch', '--knn patch', dest='knn_patch', type=str2bool, default=True,
                        help='use knn neighboorhood to construct patch')

    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='gradient descent momentum')
    parser.add_argument('--model_interval', type=int, default=5, metavar='N',
                        help='save a checkpoint every N epochs')

    # others parameters
    parser.add_argument('--resume', type=str, default='', help='refine model at this path')
    parser.add_argument('--support_multiple', type=float, default=4.0, help='the multiple of support radius')
    parser.add_argument('--support_angle', type=int, default=15, help='')
    parser.add_argument('--gt_normal_mode', type=str, default='nearest', help='')
    parser.add_argument('--repulsion_alpha', type=float, default=0.98, help='')

    # evaluation parameters
    parser.add_argument('--eval_dir', type=str, default='./Models/all/test1', help='')
    parser.add_argument('--eval_iter_nums', type=int, default=3, help='')

    args = parser.parse_args(argv)
    # Keep the attribute name the old flag produced, for any getattr-based reader.
    setattr(args, 'knn patch', args.knn_patch)
    return args
def pca_alignment(pts, random_flag=False):
    """Rotate a patch so its principal directions line up with fixed axes.

    The third principal direction is rotated onto +z, then the (rotated)
    second principal direction onto +x.

    :param pts: (N, 3) array of patch points.
    :param random_flag: when true, flip the sign of all principal directions
        at random before aligning.
    :return: ``(aligned_pts, inv_rotation)`` - the rotated (N, 3) points and
        the 3x3 array that undoes the rotation.
    """
    pca_dirs = get_principle_dirs(pts)

    if random_flag:
        pca_dirs *= np.random.choice([-1, 1], 1)

    # Plain arrays with explicit matrix products, so the result never depends
    # on whether a helper handed back an ndarray or an np.matrix.
    rotate_1 = np.asarray(compute_roatation_matrix(pca_dirs[2], [0, 0, 1], pca_dirs[1]))
    pca_dirs = (rotate_1 @ pca_dirs.T).T
    rotate_2 = np.asarray(compute_roatation_matrix(pca_dirs[1], [1, 0, 0], pca_dirs[2]))

    rotation = rotate_2 @ rotate_1
    pts = (rotation @ np.asarray(pts).T).T
    inv_rotation = np.linalg.inv(rotation)

    return pts, inv_rotation


def compute_roatation_matrix(sour_vec, dest_vec, sour_vertical_vec=None):
    """Return the rotation that takes ``sour_vec`` onto ``dest_vec``.

    Both vectors are used as given (no normalisation), so they are expected
    to be unit length.

    :param sour_vec: source direction, 3 components.
    :param dest_vec: target direction, 3 components.
    :param sour_vertical_vec: axis to turn about when the two vectors are
        anti-parallel; when ``None`` an arbitrary axis perpendicular to
        ``sour_vec`` is used.
    :return: 3x3 ``np.matrix`` on every path.
    """
    # http://immersivemath.com/forum/question/rotation-matrix-from-one-vector-to-another/
    sour_vec = np.asarray(sour_vec, dtype=float)
    dest_vec = np.asarray(dest_vec, dtype=float)

    cross = np.cross(sour_vec, dest_vec)
    cross_norm = np.linalg.norm(cross, 2)
    dot = np.dot(sour_vec, dest_vec)

    if cross_norm == 0 or np.abs(dot) >= 1.0:
        if dot < 0:
            # Anti-parallel: half a turn about any axis perpendicular to the source.
            if sour_vertical_vec is None:
                helper = np.eye(3)[np.argmin(np.abs(sour_vec))]
                sour_vertical_vec = np.cross(sour_vec, helper)
            return rotation_matrix(sour_vertical_vec, np.pi)
        # Already aligned.  Returned as np.matrix like every other path, so a
        # caller's `R * M` is always a matrix product, never element-wise.
        return np.matrix(np.identity(3))

    alpha = np.arccos(dot)
    a = cross / cross_norm
    c = np.cos(alpha)
    s = np.sin(alpha)
    R1 = [a[0] * a[0] * (1.0 - c) + c,
          a[0] * a[1] * (1.0 - c) - s * a[2],
          a[0] * a[2] * (1.0 - c) + s * a[1]]

    R2 = [a[0] * a[1] * (1.0 - c) + s * a[2],
          a[1] * a[1] * (1.0 - c) + c,
          a[1] * a[2] * (1.0 - c) - s * a[0]]

    R3 = [a[0] * a[2] * (1.0 - c) - s * a[1],
          a[1] * a[2] * (1.0 - c) + s * a[0],
          a[2] * a[2] * (1.0 - c) + c]

    return np.matrix([R1, R2, R3])


def rotation_matrix(axis, theta):
    """Return the rotation matrix for a counterclockwise rotation about ``axis``.

    :param axis: rotation axis, 3 components; normalised internally.
    :param theta: rotation angle in radians.
    :return: 3x3 ``np.matrix``.
    """
    axis = np.asarray(axis)
    axis = axis / math.sqrt(np.dot(axis, axis))
    a = math.cos(theta / 2.0)
    b, c, d = -axis * math.sin(theta / 2.0)
    aa, bb, cc, dd = a * a, b * b, c * c, d * d
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    return np.matrix(np.array([[aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
                               [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
                               [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc]]))
183 | return dist 184 | #使用双边滤波 185 | def compute_original_2_loss(pred_point, gt_patch_pts, gt_patch_normals, support_radius, support_angle, alpha): 186 | 187 | # Compute Spatial Weighted Function 188 | pred_point = pred_point.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1) 189 | dist_square = ((pred_point - gt_patch_pts) ** 2).sum(2) 190 | weight_theta = torch.exp(-1 * dist_square / (support_radius ** 2)) 191 | 192 | ############# Get The Nearest Normal For Predicted Point ############# 193 | nearest_idx = torch.argmin(dist_square, dim=1) 194 | pred_point_normal = torch.cat([gt_patch_normals[i, index, :] for i, index in enumerate(nearest_idx)]) 195 | pred_point_normal = pred_point_normal.view(-1, 3) 196 | pred_point_normal = pred_point_normal.unsqueeze(1) 197 | pred_point_normal = pred_point_normal.repeat(1, gt_patch_normals.size(1), 1) 198 | ############# Get The Nearest Normal For Predicted Point ############# 199 | 200 | # Compute Normal Weighted Function 201 | normal_proj_dist = (pred_point_normal * gt_patch_normals).sum(2) 202 | weight_phi = torch.exp(-1 * ((1 - normal_proj_dist) / (1 - np.cos(support_angle)))**2) 203 | 204 | # # avoid divided by zero 205 | weight = weight_theta * weight_phi + 1e-12 206 | weight = weight / weight.sum(1, keepdim=True) 207 | 208 | # key loss 209 | #不同于poinfilter的地方,Pointfilter用dist_square*normal 210 | project_dist = torch.sqrt(dist_square) 211 | imls_dist = (project_dist * weight).sum(1) 212 | 213 | # repulsion loss 214 | max_dist = torch.max(dist_square, 1)[0] 215 | 216 | # final loss 217 | dist = torch.mean((alpha * imls_dist) + (1 - alpha) * max_dist) 218 | 219 | return dist 220 | #PointCleanNet 221 | def compute_original_3_loss(pts_pred, gt_patch_pts, alpha): 222 | # PointCleanNet Loss 223 | pts_pred = pts_pred.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1) 224 | m = (pts_pred - gt_patch_pts).pow(2).sum(2) 225 | min_dist = torch.min(m, 1)[0] 226 | max_dist = torch.max(m, 1)[0] 227 | dist = torch.mean((alpha * min_dist) + 
(1 - alpha) * max_dist) 228 | # print('min_dist: %f max_dist: %f' % (alpha * torch.mean(min_dist).item(), (1 - alpha) * torch.mean(max_dist).item())) 229 | return dist * 100 230 | 231 | def compute_original_4_loss(pts_pred1,pts_pred2, gt_patch_pts,alpha): 232 | # PointCleanNet Loss 233 | pts_pred1= pts_pred1.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1) 234 | m = (pts_pred1 - gt_patch_pts).pow(2).sum(2) 235 | min_dist = torch.min(m, 1)[0] 236 | max_dist = torch.max(m, 1)[0] 237 | 238 | dist1 =torch.mean((alpha * min_dist) + (1 - alpha) * max_dist) 239 | 240 | pts_pred2= pts_pred2.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1) 241 | m = (pts_pred2 - gt_patch_pts).pow(2).sum(2) 242 | min_dist = torch.min(m, 1)[0] 243 | max_dist = torch.max(m, 1)[0] 244 | dist2 = torch.mean((alpha * min_dist) + (1 - alpha) * max_dist) 245 | 246 | dist=dist1+dist2 247 | 248 | # print('min_dist: %f max_dist: %f' % (alpha * torch.mean(min_dist).item(), (1 - alpha) * torch.mean(max_dist).item())) 249 | return dist * 100 250 | 251 | def compute_original_5_loss(pts_pred1,pts_pred2,normal, gt_patch_pts,gt_normal,alpha): 252 | # PointCleanNet Loss 253 | Batchsize=gt_patch_pts.size(0) 254 | pts_pred1= pts_pred1.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1) 255 | m = (pts_pred1 - gt_patch_pts).pow(2).sum(2) 256 | min_dist= torch.min(m, 1)[0] 257 | max_dist = torch.max(m, 1)[0] 258 | dist1 =torch.mean((alpha * min_dist) + (1 - alpha) * max_dist) 259 | 260 | 261 | pred_ponts=pts_pred2 262 | pts_pred2= pts_pred2.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1) 263 | m = (pts_pred2 - gt_patch_pts).pow(2).sum(2) 264 | min_dist,idx= torch.min(m, 1) 265 | max_dist = torch.max(m, 1)[0] 266 | dist2 = torch.mean((alpha * min_dist) + (1 - alpha) * max_dist) 267 | 268 | idx=idx.unsqueeze(-1).unsqueeze(-1) 269 | nearestpoint=torch.gather(gt_patch_pts,dim=1,index=idx.expand(Batchsize,1,3)) 270 | nearestpoint=nearestpoint.squeeze(1) 271 | point=(pred_ponts-nearestpoint).unsqueeze(-1) 272 | 
pointnormal=normal.unsqueeze(1) 273 | oth=torch.abs(torch.bmm(pointnormal,point)) 274 | oth=oth.mean()*100 275 | 276 | normal_dist=(normal-gt_normal).pow(2).sum(1).mean() 277 | dist=(dist1+dist2)*100 278 | out=dist+normal_dist+oth 279 | 280 | 281 | # print('min_dist: %f max_dist: %f' % (alpha * torch.mean(min_dist).item(), (1 - alpha) * torch.mean(max_dist).item())) 282 | return out,dist,normal_dist,oth 283 | 284 | 285 | def compute_original_6_loss(pts_pred1,gt_patch_pts,normal,gtnormal, alpha): 286 | # PointCleanNet Loss 287 | pts_pred1= pts_pred1.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1) 288 | m = (pts_pred1 - gt_patch_pts).pow(2).sum(2) 289 | min_dist = torch.min(m, 1)[0] 290 | max_dist = torch.max(m, 1)[0] 291 | 292 | dist =torch.mean((alpha * min_dist) + (1 - alpha) * max_dist) 293 | dist=dist*100 294 | 295 | loss1= torch.nn.functional.nll_loss(normal, gtnormal) 296 | 297 | loss=loss1+dist 298 | # print('min_dist: %f max_dist: %f' % (alpha * torch.mean(min_dist).item(), (1 - alpha) * torch.mean(max_dist).item())) 299 | return loss,dist,loss1 300 | ################################Ablation Study of Different Loss ############################### 301 | #作者改进的双边滤波 302 | def compute_original_7_loss(pts_pred1,gt_patch_pts,normal,gtnormal,patch_center_normal,alpha): 303 | # PointCleanNet Loss 304 | Batchsize = gt_patch_pts.size(0) 305 | pred_ponts = pts_pred1 306 | pts_pred1= pts_pred1.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1) 307 | m = (pts_pred1 - gt_patch_pts).pow(2).sum(2) 308 | min_dist,idx= torch.min(m, 1) 309 | max_dist = torch.max(m, 1)[0] 310 | 311 | dist =torch.mean((alpha * min_dist) + (1 - alpha) * max_dist) 312 | dist=dist*100 313 | ''' 314 | idx = idx.unsqueeze(-1).unsqueeze(-1) 315 | nearestpoint = torch.gather(gt_patch_pts, dim=1, index=idx.expand(Batchsize, 1, 3)) 316 | nearestpoint = nearestpoint.squeeze(1) 317 | point = (pred_ponts - nearestpoint).unsqueeze(-1) 318 | pointnormal = patch_center_normal 319 | oth = 
def compute_bilateral_loss(pred_point, pred_normal, gt_patch_pts, gt_patch_normals, predweight,
                           support_radius, support_angle, alpha, top_idx):
    """Bilateral projection loss plus normal and orthogonality terms.

    :param pred_point: predicted point per patch, indexed as (B, 3).
    :param pred_normal: predicted normal per patch, indexed as (B, 3).
    :param gt_patch_pts: ground-truth patch points, indexed as (B, N, 3).
    :param gt_patch_normals: ground-truth patch normals, indexed as (B, N, 3).
    :param predweight: unused; kept for interface compatibility.
    :param support_radius: spatial bandwidth; must broadcast against (B, N).
    :param support_angle: angular bandwidth in radians.
    :param alpha: blend between the projection term and the repulsion term.
    :param top_idx: unused; kept for interface compatibility.
    :return: ``(loss, loss1, loss2, oth_loss)`` with ``loss`` the sum of the
        other three.
    """
    # Broadcasting replaces the explicit .repeat() copies.
    offsets = pred_point.unsqueeze(1) - gt_patch_pts
    dist_square = offsets.pow(2).sum(2)
    weight_theta = torch.exp(-1 * dist_square / (support_radius ** 2))

    # Normal of the ground-truth point nearest to each prediction, gathered
    # with one indexing operation instead of a per-sample Python loop.
    nearest_idx = torch.argmin(dist_square, dim=1)
    batch_idx = torch.arange(gt_patch_normals.size(0), device=gt_patch_normals.device)
    nearest_normal = gt_patch_normals[batch_idx, nearest_idx].unsqueeze(1)

    normal_proj_dist = (nearest_normal * gt_patch_normals).sum(2)
    weight_phi = torch.exp(-1 * ((1 - normal_proj_dist) / (1 - np.cos(support_angle))) ** 2)

    # avoid divided by zero
    weight = weight_theta * weight_phi + 1e-12
    weight = weight / weight.sum(1, keepdim=True)

    # key loss: weighted distance along the ground-truth normals
    project_dist = torch.abs((offsets * gt_patch_normals).sum(2))
    imls_dist = (project_dist * weight).sum(1)

    # repulsion loss
    max_dist = torch.max(dist_square, 1)[0]

    loss1 = 100 * torch.mean((alpha * imls_dist) + (1 - alpha) * max_dist)

    # Predicted normal against the nearest ground-truth normal.
    pred_normal = pred_normal.unsqueeze(1)
    loss2 = 10 * (pred_normal - nearest_normal).pow(2).sum(2).mean(1).mean(0)

    # Predicted normal against the offsets to the ground-truth points.
    oth_loss = (pred_normal * offsets).sum(2).pow(2)
    oth_loss = 10 * oth_loss.mean()

    loss = loss1 + loss2 + oth_loss

    return loss, loss1, loss2, oth_loss
def compute_bilateral_loss_with_repulsion(pred_point, gt_patch_pts, gt_patch_normals, support_radius, support_angle, alpha):
    """Bilateral projection loss blended with a repulsion term.

    :param pred_point: predicted point per patch, indexed as (B, 3).
    :param gt_patch_pts: ground-truth patch points, indexed as (B, N, 3).
    :param gt_patch_normals: ground-truth patch normals, indexed as (B, N, 3).
    :param support_radius: spatial bandwidth; must broadcast against (B, N).
    :param support_angle: angular bandwidth in radians.
    :param alpha: weight of the projection term; ``1 - alpha`` weights the
        squared distance to the farthest ground-truth point.
    :return: scalar tensor.
    """
    # Broadcasting replaces the explicit .repeat() copy.
    offsets = pred_point.unsqueeze(1) - gt_patch_pts
    dist_square = offsets.pow(2).sum(2)
    weight_theta = torch.exp(-1 * dist_square / (support_radius ** 2))

    # Normal of the ground-truth point nearest to each prediction, gathered
    # with one indexing operation instead of a per-sample Python loop.
    nearest_idx = torch.argmin(dist_square, dim=1)
    batch_idx = torch.arange(gt_patch_normals.size(0), device=gt_patch_normals.device)
    nearest_normal = gt_patch_normals[batch_idx, nearest_idx].unsqueeze(1)

    normal_proj_dist = (nearest_normal * gt_patch_normals).sum(2)
    weight_phi = torch.exp(-1 * ((1 - normal_proj_dist) / (1 - np.cos(support_angle))) ** 2)

    # avoid divided by zero
    weight = weight_theta * weight_phi + 1e-12
    weight = weight / weight.sum(1, keepdim=True)

    # key loss
    project_dist = torch.abs((offsets * gt_patch_normals).sum(2))
    imls_dist = (project_dist * weight).sum(1)

    # repulsion loss
    max_dist = torch.max(dist_square, 1)[0]

    # final loss
    dist = torch.mean((alpha * imls_dist) + (1 - alpha) * max_dist)

    return dist
gt_normals=torch.gather(gt_normals, dim=1, index=top_idx.unsqueeze(-1).expand(B, k, 3))#[B,256,3] 506 | # normal_loss=torch.min((pred_normas-gt_normals).pow(2).sum(2),(pred_normas+gt_normals).pow(2).sum(2)).mean(1).mean() 507 | normal_loss = (pred_normas - gt_normals).pow(2).sum(2).mean(1).mean() 508 | 509 | return normal_loss 510 | 511 | def Normal_loss_Compute(pred_normal,gt_normal): 512 | gt_normal=gt_normal.squeeze(1) 513 | # normal_loss=torch.min((pred_normas-gt_normals).pow(2).sum(2),(pred_normas+gt_normals).pow(2).sum(2)).mean(1).mean() 514 | normal_loss = (pred_normal - gt_normal).pow(1).sum(1).mean() 515 | 516 | return normal_loss 517 | 518 | def Cos_Compute_Normal_Loss(pre_normals,gt_normals): 519 | loss=(1 - torch.abs(cos_angle(pre_normals,gt_normals))).pow(2).mean() 520 | return loss 521 | def Sin_Compute_Normal_Loss(pre_normals,gt_normals): 522 | 523 | loss= 0.5*torch.norm(torch.cross(pre_normals, gt_normals, dim=-1), p=2, dim=1).mean() 524 | return loss 525 | ''' 526 | def Otho_Loss(gt_normals,gt_normal,gt_points,pre_point,index): 527 | 528 | k=index.size(1) 529 | B=index.size(0) 530 | gt_normals=torch.gather(gt_normals,dim=1,index=index.unsqueeze(-1).expand(B,k,3)) 531 | gt_points=torch.gather(gt_points,dim=1,index=index.unsqueeze(-1).expand(B,k,3)) 532 | 533 | pre_point=pre_point.unsqueeze(-1).repeat(1,1,k).transpose(2,1) 534 | loss1=(torch.abs(gt_normals*(pre_point-gt_points))).sum(-1).sum(1).mean(0) 535 | 536 | gt_normal=gt_normal.repeat(1,k,1) 537 | loss2=(torch.abs(gt_normal*(pre_point-gt_points))).sum(-1).sum(1).mean(0) 538 | 539 | loss=loss1+loss2 540 | return loss 541 | ''' 542 | def Otho_Loss(pred_normal,gt_point,pred_point): 543 | 544 | pred_point=pred_point.unsqueeze(1) 545 | # pred_normal=pred_normal.unsqueeze(1) 546 | point_constrain=(pred_point-gt_point).pow(2).sum(2).mean(0) 547 | 548 | point_normal=(pred_normal*(pred_point-gt_point)).sum(2).unsqueeze(-1) 549 | normal_point_constrain=(point_normal*pred_normal).pow(2).sum(2).mean(0) 550 
| 551 | constrain=torch.abs(point_constrain-normal_point_constrain) 552 | 553 | return constrain 554 | 555 | 556 | 557 | 558 | 559 | def compute_loss(pred_point,pred_normal, gt_patch_pts, gt_patch_normals,predweight,gt_normal,support_radius,support_angle,alpha): 560 | # Our Loss 561 | orginal_point=pred_point 562 | pred_point = pred_point.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1)#[B,3]-->[B,N,3] 563 | dist_square = ((pred_point - gt_patch_pts) ** 2).sum(2) 564 | 565 | nearest_idx = torch.argmin(dist_square, dim=1) 566 | neareat_point = torch.cat([gt_patch_pts[i, index, :] for i, index in enumerate(nearest_idx)]) 567 | neareat_point = neareat_point.view(-1, 3)#[64,3] 568 | 569 | max_dist = torch.max(dist_square, 1)[0] 570 | max_dist=torch.mean(max_dist) 571 | # loss1=10*(torch.abs((orginal_point-neareat_point).pow(2).sum(1)-(pred_normal*(orginal_point-neareat_point)).sum(1).pow(2))).mean() 572 | gt_normal=gt_normal.squeeze(1) 573 | # key loss 574 | pred_normal=pred_normal.unsqueeze(1).repeat(1,gt_patch_pts.size(1),1) 575 | project_dist =(gt_patch_normals*(pred_point - gt_patch_pts)).sum(2).pow(2)#[b,n] 576 | normal_dist=(pred_normal*(gt_patch_pts-0)).sum(2).pow(2) 577 | oth_loss=100*((normal_dist+project_dist)*predweight).sum(1).mean() 578 | 579 | dist=oth_loss+max_dist 580 | 581 | return dist,oth_loss,max_dist 582 | 583 | def compute_loss1(pred_point,pred_normal, gt_patch_pts, gt_patch_normals,predweight,gt_normal,support_radius,support_angle,alpha): 584 | # Our Loss 585 | pred_point = pred_point.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1)#[B,3]-->[B,N,3] 586 | dist_square = ((pred_point - gt_patch_pts) ** 2).sum(2) 587 | 588 | min_dist=torch.min(dist_square,1)[0] 589 | max_dist = torch.max(dist_square, 1)[0] 590 | # final loss 591 | loss1 = 100*torch.mean((alpha * min_dist) + (1 - alpha) * max_dist) 592 | gt_normal=gt_normal.squeeze(1) 593 | loss2=10*(pred_normal-gt_normal).pow(2).sum(1).mean() 594 | # key loss 595 | 
pred_normal=pred_normal.unsqueeze(1).repeat(1,gt_patch_pts.size(1),1) 596 | project_dist =(gt_patch_normals*(pred_point - gt_patch_pts)).sum(2).pow(2)#[b,n] 597 | normal_dist=(pred_normal*(pred_point-gt_patch_pts)).sum(2).pow(2) 598 | oth_loss=10*((normal_dist+project_dist)*predweight).sum(1).mean() 599 | # regularizer = - torch.mean(predweight.log()) 600 | dist=0.5*loss1+0.5*loss2+oth_loss 601 | 602 | 603 | return dist,loss1,loss2,oth_loss 604 | def compute_ditstance(gt_patch_normals,gt_patch_pts,pred_point): 605 | dist=(gt_patch_normals*(pred_point-gt_patch_pts)).sum(2) 606 | return dist 607 | def comtrative_loss(pred_point,pred_normal, gt_patch_pts, gt_patch_normals,topidx,gt_normal,support_radius,support_angle,alpha): 608 | device = torch.device('cuda') 609 | B,N,C=gt_patch_pts.size() 610 | label=torch.zeros(B,N,1,device=device) 611 | idx_base = torch.arange(0, 64,device=device).view(-1, 1) *N 612 | topidx=topidx+idx_base 613 | topidx=topidx.view(-1) 614 | label=label.view(B*N,-1) 615 | label[topidx,:]=1 616 | label=label.view(B,N) 617 | margin=0.8 618 | 619 | pred_point = pred_point.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1) 620 | dist_square = ((pred_point - gt_patch_pts) ** 2).sum(2) 621 | weight_theta = torch.exp(-1 * dist_square / (support_radius ** 2)) 622 | 623 | nearest_idx = torch.argmin(dist_square, dim=1) 624 | pred_point_normal = torch.cat([gt_patch_normals[i, index, :] for i, index in enumerate(nearest_idx)]) 625 | pred_point_normal = pred_point_normal.view(-1, 3) 626 | pred_point_normal = pred_point_normal.unsqueeze(1) 627 | pred_point_normal = pred_point_normal.repeat(1, gt_patch_normals.size(1), 1) 628 | 629 | normal_proj_dist = (pred_point_normal * gt_patch_normals).sum(2) 630 | weight_phi = torch.exp(-1 * ((1 - normal_proj_dist) / (1 - np.cos(support_angle)))**2) 631 | 632 | # # avoid divided by zero 633 | weight = weight_theta * weight_phi + 1e-12 634 | weight = weight / weight.sum(1, keepdim=True) 635 | # 
r1=label*(compute_ditstance(gt_patch_normals,gt_patch_pts,pred_point).pow(2)) 636 | # r2=(1-label)*((torch.clamp(margin-compute_ditstance(gt_patch_normals,gt_patch_pts,pred_point),min=0.0)).pow(2)) 637 | loss1=label*(weight*compute_ditstance(gt_patch_normals,gt_patch_pts,pred_point).pow(2))+(1-label)*(weight*(torch.clamp(margin-compute_ditstance(gt_patch_normals,gt_patch_pts,pred_point),min=0.0)).pow(2)) 638 | loss1=1000*loss1.mean() 639 | gt_normal=gt_normal.squeeze(1) 640 | loss2 =(pred_normal - gt_normal).pow(2).sum(1).mean() 641 | loss=10*(loss1+loss2) 642 | return loss,loss1,loss2 643 | 644 | 645 | if __name__ == '__main__': 646 | 647 | pred_normal=torch.rand(64,3) 648 | pred_point=torch.rand(64,3) 649 | gt_normal=torch.rand(64,3) 650 | gt_patch_pts=torch.rand(64,512,3) 651 | gt_patch_normals=torch.rand(64,512,3) 652 | support_radius=torch.rand(64,1) 653 | support_angle=0.23898 654 | alpha=0.97 655 | predweight=torch.rand(64,512) 656 | # compute_bilateral_loss_with_repulsion(pred_point,pred_normal,gt_patch_pts,gt_patch_normal,predweight,support_radius,support_angle,alpha) 657 | compute_loss(pred_point,pred_normal,gt_patch_pts,gt_patch_normals,predweight,gt_normal,alpha) --------------------------------------------------------------------------------