├── NetworkN1.py
├── README.md
├── Trained_Models
│   └── model_full_ae.pth
├── dataset923.py
├── image
│   ├── figure1.png
│   └── figure2.png
├── testN.py
├── train_NetworkN1.py
└── utils118.py
/NetworkN1.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.parallel
5 | import torch.utils.data
6 | import torch.nn.functional as F
7 | from torch.autograd import Variable
8 | import utils
9 | import os
10 | import math
11 |
12 |
def knn(x, k):
    """Return the indices of the ``k`` nearest neighbours of every point.

    Args:
        x: tensor of shape ``(batch_size, num_dims, num_points)``.
        k: number of neighbours to keep (each point counts as its own
            nearest neighbour).

    Returns:
        Index tensor of shape ``(batch_size, num_points, k)``; neighbours are
        ranked by increasing squared Euclidean distance.
    """
    gram = torch.matmul(x.transpose(2, 1), x)
    sq_norm = torch.sum(x ** 2, dim=1, keepdim=True)
    # Negative squared distance: -|a|^2 + 2 a.b - |b|^2, so topk picks the closest.
    neg_sq_dist = -sq_norm + 2 * gram - sq_norm.transpose(2, 1)
    _, nearest = neg_sq_dist.topk(k=k, dim=-1)
    return nearest
20 |
21 |
def get_idx(x, k=20, idx=None, dim9=False):
    """Return flattened k-NN indices that address rows of a ``(B * N, C)`` view.

    Args:
        x: tensor whose dim 0 is the batch and dim 2 the points; it is viewed
            as ``(B, C, N)``.
        k: neighbours per point, used when ``idx`` is not given.
        idx: optional precomputed ``(B, N, k)`` neighbour indices (per batch,
            in ``[0, N)``); computed with ``knn`` when ``None``.
        dim9: if true, only channels ``6:`` are used for the neighbour search.

    Returns:
        1-D index tensor of length ``B * N * k`` in which the indices of batch
        ``b`` are offset by ``b * N``. It lives on the same device as ``idx``.
    """
    batch_size = x.size(0)
    num_points = x.size(2)
    x = x.view(batch_size, -1, num_points)
    if idx is None:
        idx = knn(x[:, 6:], k=k) if dim9 else knn(x, k=k)  # (B, N, k)

    # Offset each batch so the indices address rows of a (B*N, C) tensor.
    # Allocate on idx's device instead of hard-coding CUDA, which crashed on
    # CPU-only machines and mismatched tensors living on other GPUs.
    idx_base = torch.arange(0, batch_size, device=idx.device).view(-1, 1, 1) * num_points
    return (idx + idx_base).view(-1)  # (B * N * k,)
40 |
41 |
def get_knn_feature(x, k=20):
    """Build per-edge features ``[x_j - x_i, x_i]`` over each point's k-NN graph.

    Args:
        x: tensor viewed as ``(B, C, N)``.
        k: neighbours per point (found in the feature space of ``x``).

    Returns:
        Tensor of shape ``(B, 2 * C, N, k)``.
    """
    flat_idx = get_idx(x, k=k)

    batch_size, num_points = x.size(0), x.size(2)
    points = x.view(batch_size, -1, num_points)
    num_dims = points.size(1)

    # (B, C, N) -> (B, N, C) so that a (B*N, C) view lines up with flat_idx.
    points = points.transpose(2, 1).contiguous()
    neighbours = points.view(batch_size * num_points, -1)[flat_idx, :]
    neighbours = neighbours.view(batch_size, num_points, k, num_dims)
    centers = points.view(batch_size, num_points, 1, num_dims).repeat(1, 1, k, 1)

    edge = torch.cat((neighbours - centers, centers), dim=3)  # (B, N, k, 2C)
    return edge.permute(0, 3, 1, 2).contiguous()
59 |
60 |
class FeatureExtration(nn.Module):
    """Two-scale neighbourhood feature extractor with a residual coordinate head.

    Edge features from the 8- and 16-nearest-neighbour graphs of the input are
    convolved, max-pooled over neighbours and fused. A further 16-NN edge
    convolution runs in the fused feature space, and a per-point projection
    gives ``output_dim`` channels. A small MLP predicts 3 values per point,
    which are added to the (transposed) input.

    Args:
        input_dim: channels of the input cloud (3 everywhere in this file).
        output_dim: channels of the returned per-point feature.
        rate1, rate2, rate3: divisors of ``output_dim`` giving the widths of
            the 8-NN branch, the 16-NN branch and the fused feature.
    """

    def __init__(self, input_dim, output_dim, rate1, rate2, rate3):
        super(FeatureExtration, self).__init__()
        width_k8 = output_dim // rate1
        width_k16 = output_dim // rate2
        width_fused = output_dim // rate3

        def leaky():
            return nn.LeakyReLU(negative_slope=0.2)

        # Attribute names and creation order are significant. They define the
        # state-dict keys: each BatchNorm appears both as ``bn1_*`` and inside
        # its Sequential. They also fix the order of parameter initialisation.
        self.bn1_1 = nn.BatchNorm2d(width_k8)
        self.bn1_2 = nn.BatchNorm2d(width_k16)
        self.bn1_3 = nn.BatchNorm1d(width_fused)
        self.bn1_4 = nn.BatchNorm1d(output_dim)
        self.bn1_5 = nn.BatchNorm2d(width_fused)

        self.conv1_1 = nn.Sequential(nn.Conv2d(input_dim * 2, width_k8, 1), self.bn1_1, leaky())
        self.conv1_2 = nn.Sequential(nn.Conv2d(input_dim * 2, width_k16, 1), self.bn1_2, leaky())
        self.conv1_3 = nn.Sequential(nn.Conv1d(width_k8 + width_k16, width_fused, 1), self.bn1_3, leaky())
        self.conv1_5 = nn.Sequential(nn.Conv2d(width_fused * 2, width_fused, 1), self.bn1_5, leaky())
        self.conv1_4 = nn.Sequential(nn.Conv1d(width_fused, output_dim, 1), self.bn1_4, leaky())

        self.fc1 = nn.Sequential(
            nn.Linear(output_dim, output_dim // 2),
            nn.ReLU(inplace=True),
            nn.Linear(output_dim // 2, 3),
        )

    def forward(self, point):
        """Extract per-point features and a refined copy of the input.

        Args:
            point: tensor of shape ``(B, input_dim, N)``. The residual add
                below requires ``input_dim == 3``.

        Returns:
            ``(features, refined)``: ``features`` has shape ``(B, N, output_dim)``,
            ``refined`` has shape ``(B, N, 3)`` and equals
            ``fc1(features) + point^T``.
        """
        def edge_conv(feat, conv, k):
            # (B, C, N) -> (B, 2C, N, k) -> conv -> max over neighbours -> (B, C', N)
            return conv(get_knn_feature(feat, k=k)).max(dim=-1, keepdim=False)[0]

        local_k8 = edge_conv(point, self.conv1_1, 8)
        local_k16 = edge_conv(point, self.conv1_2, 16)
        fused = self.conv1_3(torch.cat([local_k8, local_k16], dim=1))
        fused = edge_conv(fused, self.conv1_5, 16)

        features = self.conv1_4(fused).transpose(2, 1)  # (B, N, output_dim)
        offsets = self.fc1(features)                     # (B, N, 3)
        return features, offsets + point.transpose(2, 1)
117 |
118 |
class ConsistentPointSelect(nn.Module):
    """Score every point and keep the top ``int(r * N)`` of them as key points.

    A point's score is computed from four inputs: its point feature, its
    normal feature, the angle between its coordinate vector and its normal,
    and ``exp(-|p|^2)`` of its coordinates. The scores are squashed by a
    sigmoid.
    """

    def __init__(self, r=0.5):
        super(ConsistentPointSelect, self).__init__()
        self.r = r

        def linear_relu(in_features, out_features):
            return nn.Sequential(nn.Linear(in_features, out_features), nn.ReLU(inplace=True))

        # Names and creation order are preserved (state-dict keys / init order).
        self.fc1 = linear_relu(128, 64)  # point feature  -> 64
        self.fc2 = linear_relu(128, 64)  # normal feature -> 64
        self.fc3 = linear_relu(1, 32)    # angle          -> 32
        self.fc4 = linear_relu(1, 32)    # exp(-|p|^2)    -> 32
        self.bn1 = nn.BatchNorm1d(128)
        # 64 + 64 + 32 + 32 = 192 input channels.
        self.conv1 = nn.Sequential(nn.Conv1d(192, 128, 1), self.bn1,
                                   nn.LeakyReLU(negative_slope=0.2))
        self.sig = nn.Sigmoid()

    def angle(self, v1, v2):
        """Angle in ``[0, pi]`` between ``v1`` and ``v2`` along the last axis.

        Computed as ``atan2(|v1 x v2|, v1 . v2)`` using components 0..2. The
        result keeps a trailing singleton dim, e.g. ``(B, N, 3) -> (B, N, 1)``.
        """
        ax, ay, az = v1[..., 0], v1[..., 1], v1[..., 2]
        bx, by, bz = v2[..., 0], v2[..., 1], v2[..., 2]
        cross = torch.stack([ay * bz - az * by,
                             az * bx - ax * bz,
                             ax * by - ay * bx], dim=-1)
        sin_part = torch.norm(cross, dim=-1)
        cos_part = torch.sum(v1 * v2, dim=-1)
        return torch.atan2(sin_part, cos_part).unsqueeze(-1)

    def get_center_normal(self, normalfeature, idx):
        """Subtract the feature at position ``idx`` (one per batch) from every point.

        Args:
            normalfeature: ``(B, N, C)`` tensor.
            idx: ``(B, 1)`` index tensor.

        Not called by ``forward`` (the call there is commented out).
        """
        B, N, C = normalfeature.size()
        gather_index = idx.unsqueeze(-1).expand(B, 1, C)
        center = torch.gather(normalfeature, dim=1, index=gather_index).repeat(1, N, 1)
        return normalfeature - center

    def forward(self, pointfea, normalfea, index, point, normal):
        """Select key points.

        Args:
            pointfea: per-point features ``(B, N, 128)``.
            normalfea: per-normal features ``(B, N, 128)``.
            index: not read by this method.
            point: point coordinates ``(B, N, 3)``.
            normal: normal vectors ``(B, N, 3)``.

        Returns:
            ``(weight, top_idx, keypointfeature, keypoint, keynormalfeature,
            keynormal)`` with shapes ``(B, N)``, ``(B, k)``, ``(B, k, C)``,
            ``(B, k, 3)``, ``(B, k, C)``, ``(B, k, 3)``, where
            ``k = int(r * N)``. ``top_idx`` is ordered by decreasing weight.
        """
        B, N, C = pointfea.size()
        k = int(self.r * N)

        # Geometric priors, one scalar per point.
        dist_prior = self.fc4(torch.exp(-torch.sum(point * point, dim=-1, keepdim=True)))
        angle_prior = self.fc3(self.angle(point, normal))

        point_code = self.fc1(pointfea)
        normal_code = self.fc2(normalfea)

        stacked = torch.cat([point_code, normal_code, angle_prior, dist_prior], dim=2)  # (B, N, 192)
        # Max over the 128 conv channels gives one score per point.
        scores = self.conv1(stacked.transpose(2, 1)).max(dim=1)[0]  # (B, N)
        weight = self.sig(scores)
        top_idx = torch.argsort(weight, dim=-1, descending=True)[:, 0:k]

        def pick(source, width):
            return torch.gather(source, dim=1, index=top_idx.unsqueeze(-1).expand(B, k, width))

        return (weight, top_idx,
                pick(pointfea, C), pick(point, 3), pick(normalfea, C), pick(normal, 3))
214 |
215 |
class KeyFeatureFusion(nn.Module):
    """Fuse each key point's feature with the mean feature of its spatial neighbours.

    ``forward`` uses only ``conv``. ``fc``, ``z`` and ``gn`` are unused. ``t``,
    ``p``, ``g`` and ``softmax`` are used only by ``normalAttention``, which
    ``forward`` does not call. All of them are still registered, because
    removing any would change the module's state-dict keys.
    """

    def __init__(self):
        super(KeyFeatureFusion, self).__init__()
        self.fc = nn.Sequential(nn.Linear(128, 128), nn.ReLU(inplace=True))
        self.conv = nn.Sequential(
            nn.Conv1d(128, 128, 1),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(negative_slope=0.2),
        )
        self.t = nn.Conv1d(128, 64, 1)   # projection of `points` in normalAttention
        self.p = nn.Conv1d(128, 64, 1)   # projection of `points` in normalAttention
        self.g = nn.Conv1d(128, 128, 1)  # projection of `normals` (values)
        self.z = nn.Conv1d(256, 256, 1)
        self.gn = nn.GroupNorm(num_groups=1, num_channels=256)
        self.softmax = nn.Softmax(dim=-1)

    def normalAttention(self, points, normals):
        """Point-to-point attention that aggregates projected ``normals``.

        Args:
            points, normals: ``(B, 128, N)`` tensors.

        Returns:
            ``(B, 128, N)`` tensor. ``energy[b, i, j] = p_i . t_j`` is
            softmax-normalised over ``j``.
        """
        t_proj = self.t(points)        # (B, 64, N)
        p_proj = self.p(points)        # (B, 64, N)
        values = self.g(normals)       # (B, 128, N)
        energy = torch.bmm(p_proj.transpose(2, 1), t_proj)  # (B, N, N)
        attention = self.softmax(energy)
        return torch.bmm(values, attention.permute(0, 2, 1))

    def knnfeature(self, x, normalfvals, k):
        """Gather the features of each point's ``k`` nearest neighbours.

        Args:
            x: coordinates ``(B, N, 3)`` used for the neighbour search.
            normalfvals: per-point features, unpacked below as ``(B, N, C)``.
            k: neighbours per point.

        Returns:
            ``(B, N, k, C)`` tensor.
        """
        coords = x.transpose(2, 1).contiguous()  # (B, 3, N) for get_idx
        batch_size, num_points, num_dims = normalfvals.size()
        flat_idx = get_idx(coords, k=k)
        # NOTE(review): normalfvals is already (B, N, C). Transposing it to
        # (B, C, N) before the (B*N, C) view means row i is NOT point i's
        # feature: the view mixes points and channels. Fixing this (dropping
        # the transpose) changes the network's outputs, so any checkpoint
        # trained with this code would need retraining. Confirm before changing.
        table = normalfvals.transpose(2, 1).contiguous().view(batch_size * num_points, -1)
        gathered = table[flat_idx, :]
        return gathered.view(batch_size, num_points, k, num_dims)

    def featurefuse(self, knnpointfeature, keyfeature, topidx):
        """Add the neighbour-mean of each key point to its own feature.

        Args:
            knnpointfeature: ``(B, N, K, C)`` neighbour features of all points.
            keyfeature: ``(B, k, C)`` features of the key points.
            topidx: ``(B, k)`` positions of the key points.

        Returns:
            ``(B, k, C)`` tensor: ``keyfeature + mean over K``.
        """
        B, N, K, C = knnpointfeature.size()
        k = topidx.size(1)
        gather_index = topidx[..., None, None].expand(B, k, K, C)
        neighbour_mean = torch.gather(knnpointfeature, dim=1, index=gather_index).mean(dim=2)
        return keyfeature + neighbour_mean

    def forward(self, weight, allfeature, keyfeature, refinepoint, keypoint, topidx, k):
        """Weight all features, gather neighbours and fuse them into key-point features.

        Args:
            weight: ``(B, N)`` per-point weights.
            allfeature: ``(B, N, 128)`` per-point features.
            keyfeature: ``(B, k_key, 128)`` key-point features.
            refinepoint: ``(B, N, 3)`` coordinates used for the neighbour search.
            keypoint: not read by this method.
            topidx: ``(B, k_key)`` key-point positions.
            k: neighbourhood size.

        Returns:
            ``(B, 128, k_key)`` tensor.
        """
        weighted = allfeature * weight.unsqueeze(-1)
        neighbours = self.knnfeature(refinepoint, weighted, k)  # (B, N, k, C)
        fused = self.featurefuse(neighbours, keyfeature, topidx)  # (B, k_key, C)
        return self.conv(fused.transpose(2, 1))
315 |
316 |
class NormalEncorder(nn.Module):
    """Regress a 3-vector from fused normal/point features.

    Used by ``PCPNet`` as ``normalDecoder``. ``fc1``, ``fc2`` and ``fc3`` are
    registered but not used by ``forward``.
    """

    def __init__(self):
        super(NormalEncorder, self).__init__()

        def act():
            return nn.LeakyReLU(negative_slope=0.2)

        # Names and creation order are preserved (state-dict keys / init order).
        self.conv1_1 = nn.Sequential(nn.Conv1d(256, 128, 1), nn.BatchNorm1d(128), act())
        self.conv1_2 = nn.Sequential(nn.Conv2d(128 * 2, 64, 1), nn.BatchNorm2d(64), act())
        self.conv2_1 = nn.Sequential(nn.Conv2d(64 * 2, 128, 1), nn.BatchNorm2d(128), act())
        self.conv2_2 = nn.Sequential(nn.Conv1d(128, 256, 1), nn.BatchNorm1d(256), act())

        self.fc1 = nn.Sequential(nn.Conv2d(256, 128, 1), nn.BatchNorm2d(128), act())
        self.fc2 = nn.Sequential(nn.Conv2d(128, 64, 1), nn.BatchNorm2d(64), act())
        self.fc3 = nn.Sequential(nn.Conv2d(64, 3, 1))

        # Regression head applied to the max-pooled feature (bias enabled).
        self.fc1_1 = nn.Sequential(nn.Linear(256, 128, bias=True), nn.BatchNorm1d(128), act())
        self.fc2_1 = nn.Sequential(nn.Linear(128, 64, bias=True), nn.BatchNorm1d(64), act())
        self.fc3_1 = nn.Sequential(nn.Linear(64, 3, bias=True))

    def forward(self, x, normalfeature, pointfusefeature):
        """Predict a ``(B, 3)`` vector.

        Args:
            x, normalfeature, pointfusefeature: ``(B, 128, M)`` tensors.
                ``x + normalfeature`` is concatenated with ``pointfusefeature``
                to give the 256 channels expected by ``conv1_1``.

        Returns:
            ``(B, 3)`` tensor (no output activation).
        """
        base = self.conv1_1(torch.cat([x + normalfeature, pointfusefeature], dim=1))  # (B, 128, M)

        edge = self.conv1_2(get_knn_feature(base, k=8)).max(dim=-1, keepdim=False)[0]  # (B, 64, M)
        edge = self.conv2_1(get_knn_feature(edge, k=8)).max(dim=-1, keepdim=False)[0]  # (B, 128, M)

        per_point = self.conv2_2(base + edge)                  # (B, 256, M)
        pooled = per_point.max(dim=-1, keepdim=False)[0]       # (B, 256)
        return self.fc3_1(self.fc2_1(self.fc1_1(pooled)))
370 |
371 |
# NOTE: an earlier, unused variant of NormalEncorder was kept here as a bare
# module-level string literal (dead code that was never executed). It has been
# removed; the NormalEncorder class defined above is the one in use.
426 |
427 |
class MLP(nn.Module):
    """Regress a tanh-bounded 3-vector from a ``(B, 384, N)`` feature map.

    ``conv1`` and ``conv2`` are registered but not used by ``forward``.
    """

    def __init__(self):
        super(MLP, self).__init__()

        def conv_bn_act(cin, cout):
            return nn.Sequential(nn.Conv1d(cin, cout, 1), nn.BatchNorm1d(cout),
                                 nn.LeakyReLU(negative_slope=0.2))

        # Names and creation order are preserved (state-dict keys / init order).
        self.conv = conv_bn_act(384, 512)
        self.conv1 = conv_bn_act(512, 256)
        self.conv2 = conv_bn_act(256, 512)
        self.fc1_1 = nn.Linear(512, 256)
        self.fc1_2 = nn.Linear(256, 64)
        self.fc1_3 = nn.Linear(64, 3)
        self.bn1_11 = nn.BatchNorm1d(256)
        self.bn1_22 = nn.BatchNorm1d(64)

    def forward(self, x):
        """Map ``(B, 384, N)`` features to a ``(B, 3)`` output in ``[-1, 1]``."""
        pooled = self.conv(x).max(dim=-1, keepdim=False)[0]   # (B, 512)
        hidden = F.relu(self.bn1_11(self.fc1_1(pooled)))       # (B, 256)
        hidden = F.relu(self.bn1_22(self.fc1_2(hidden)))       # (B, 64)
        return torch.tanh(self.fc1_3(hidden))                  # (B, 3)
448 |
449 |
class PCPNet(nn.Module):
    """Joint point/normal network: feature extraction, key-point selection,
    feature fusion, then two regression heads.

    ``num_points``, ``output_dim`` and ``k`` are accepted for API compatibility.
    ``forward`` does not read them.
    """

    def __init__(self, num_points=500, output_dim=3, k=20):
        super(PCPNet, self).__init__()
        self.num_points = num_points
        self.k = k

        # Names and creation order are preserved (state-dict keys / init order).
        self.pointfeatEX = FeatureExtration(input_dim=3, output_dim=128, rate1=8, rate2=4, rate3=2)
        self.normalfeatEX = FeatureExtration(input_dim=3, output_dim=128, rate1=8, rate2=4, rate3=2)
        self.weight = ConsistentPointSelect(r=0.5)
        self.pointFeaFu = KeyFeatureFusion()
        self.normalFeaFu = KeyFeatureFusion()
        self.normalDecoder = NormalEncorder()

        self.mlp1 = MLP()

    def forward(self, x, normal, index):
        """Run the network on a batch of patches.

        Args:
            x: patch point coordinates ``(B, 3, N)``.
            normal: patch normals ``(B, 3, N)``.
            index: forwarded to ``ConsistentPointSelect``; that module's
                ``forward`` does not read it.

        Returns:
            ``(p, n, weight, topidx)``:
            - ``p``: ``(B, 3)`` from ``MLP`` (tanh-bounded).
            - ``n``: ``(B, 3)`` from ``NormalEncorder``.
            - ``weight``: ``(B, N)`` point scores.
            - ``topidx``: ``(B, k_key)`` selected key points.

            Presumably ``p`` is the denoised point and ``n`` the filtered
            normal, both in the patch frame (TODO confirm against the
            training script).
        """
        pointfeature, refinepoint = self.pointfeatEX(x)
        normalfeature, refinenormal = self.normalfeatEX(normal)
        weight, topidx, keypointfeature, keypoint, keynormalfeature, keynormal = self.weight(
            pointfeature, normalfeature, index, refinepoint, refinenormal)

        pointfusefeature = self.pointFeaFu(weight, pointfeature, keypointfeature, refinepoint, keypoint, topidx,
                                           k=10)  # (B, 128, k_key)
        # Normal features are also grouped by neighbourhoods of the refined *points*.
        normalfusefeature = self.normalFeaFu(weight, normalfeature, keynormalfeature, refinepoint, keypoint, topidx,
                                             k=10)  # (B, 128, k_key)

        num_key = pointfusefeature.size(2)

        # Global (max-pooled over points) features, broadcast to every key point.
        globalnormalfeature = torch.max(normalfeature, dim=1, keepdim=True)[0]
        globalnormalfeature = globalnormalfeature.repeat(1, num_key, 1).transpose(2, 1)  # (B, 128, k_key)
        globalpointfeature = torch.max(pointfeature, dim=1, keepdim=True)[0]
        globalpointfeature = globalpointfeature.repeat(1, num_key, 1).transpose(2, 1)  # (B, 128, k_key)

        pfeat = torch.cat([pointfusefeature, globalpointfeature, normalfusefeature], dim=1)  # (B, 384, k_key)

        p = self.mlp1(pfeat)
        n = self.normalDecoder(normalfusefeature, globalnormalfeature, pointfusefeature)
        return p, n, weight, topidx
511 |
512 |
if __name__ == '__main__':
    # Smoke test: one forward pass of PCPNet on random patches (runs on CPU).
    batch_size = 64
    num_points = 512
    point = torch.rand(batch_size, num_points, 3).transpose(2, 1)   # (B, 3, N)
    normal = torch.rand(batch_size, num_points, 3).transpose(2, 1)  # (B, 3, N)
    # Centre positions; ConsistentPointSelect.forward does not read this value.
    index = np.random.randint(10, 20, batch_size)

    net = PCPNet()
    p, n, weight, topidx = net(point, normal, index)
    print('p', tuple(p.shape), 'n', tuple(n.shape),
          'weight', tuple(weight.shape), 'topidx', tuple(topidx.shape))
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PCDNF: Revisiting Learning-based Point Cloud Denoising via Joint Normal Filtering
2 |
3 | :zap:`Status Update: [2023/07/02] This paper has been accepted by the IEEE Transactions on Visualization and Computer Graphics (TVCG).`
4 |
5 |
6 |
7 |
8 |
9 | by [Zheng Liu](https://labzhengliu.github.io/), Yaowu Zhao, Sijing Zhan, [Yuanyuan Liu](https://cvlab-liuyuanyuan.github.io/), [Renjie Chen](http://staff.ustc.edu.cn/~renjiec/) and [Ying He](https://personal.ntu.edu.sg/yhe/)
10 |
11 | ## :bulb: Introduction
12 | Recovering high quality surfaces from noisy point clouds, known as point cloud denoising, is a fundamental yet challenging
13 | problem in geometry processing. Most of the existing methods either directly denoise the noisy input or filter raw normals followed by
14 | updating point positions. Motivated by the essential interplay between point cloud denoising and normal filtering, we revisit point cloud
15 | denoising from a multitask perspective, and propose an end-to-end network, named PCDNF, to denoise point clouds via joint normal
16 | filtering. In particular, we introduce an auxiliary normal filtering task to help the overall network remove noise more effectively while
17 | preserving geometric features more accurately. In addition to the overall architecture, our network has two novel modules. On one
18 | hand, to improve noise removal performance, we design a shape-aware selector to construct the latent tangent space representation of
19 | the specific point by comprehensively considering the learned point and normal features and geometry priors. On the other hand, point
20 | features are more suitable for describing geometric details, and normal features are more conducive to representing geometric
21 | structures (e.g., sharp edges and corners). Combining point and normal features allows each to compensate for the other's weaknesses. Thus, we
22 | design a feature refinement module to fuse point and normal features for better recovering geometric information.
23 |
24 |
25 |
26 |
27 |
28 | ## :wrench: Usage
29 | ## Environment
30 | * Python 3.6
31 | * PyTorch 1.5.0
32 | * CUDA and CuDNN (CUDA 10.1 & CuDNN 7.5)
33 | * TensorboardX (2.0) if logging training info.
34 | ## Install required python packages:
35 | ``` bash
36 | pip install numpy
37 | pip install scipy
38 | pip install plyfile
39 | pip install scikit-learn
40 | pip install tensorboardX  # only needed for training
41 | pip install torch==1.5.0+cu101 torchvision==0.6.0+cu101 -f https://download.pytorch.org/whl/torch_stable.html
42 | ```
43 | ### Test the trained model:
44 | Set parameters such as the file paths, batch size, and number of iterations in **testN.py**, then run it.
45 | A pretrained model is provided in `Trained_Models/model_full_ae.pth`.
46 |
47 | ### Train the model:
48 | Set parameters such as the file paths, batch size, and number of iterations in **train_NetworkN1.py**, then run it.
49 | Our training set is from [PointFilter](https://github.com/dongbo-BUAA-VR/Pointfilter); the normals are estimated with PCA.
50 |
51 | ## :link: Citation
52 | If you find this work helpful, please consider citing our [paper](https://ieeexplore.ieee.org/document/10173632):
53 | ```
54 | @ARTICLE{10173632,
55 | author={Liu, Zheng and Zhao, Yaowu and Zhan, Sijing and Liu, Yuanyuan and Chen, Renjie and He, Ying},
56 | journal={IEEE Transactions on Visualization and Computer Graphics},
57 | title={PCDNF: Revisiting Learning-based Point Cloud Denoising via Joint Normal Filtering},
58 | year={2023},
59 | doi={10.1109/TVCG.2023.3292464}
60 | }
61 | ```
62 |
63 |
--------------------------------------------------------------------------------
/Trained_Models/model_full_ae.pth:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LabZhengLiu/PCDNF/4cc625096868c7888f80d0cbca38d4ba52a0c9eb/Trained_Models/model_full_ae.pth
--------------------------------------------------------------------------------
/dataset923.py:
--------------------------------------------------------------------------------
1 | from __future__ import print_function
2 |
3 | import torch
4 | import torch.utils.data as data
5 | from torch.utils.data.dataloader import default_collate
6 |
7 | import os
8 | import numpy as np
9 | import scipy.spatial as sp
10 |
11 | from utils118 import pca_alignment
12 |
13 |
14 | ##################################New Dataloader Class###########################
15 |
def my_collate(batch):
    """Collate a batch after dropping samples that are ``None``.

    ``PointcloudPatchDataset.__getitem__`` returns ``None`` for patches with
    fewer than 3 points; those samples are removed before delegating to
    ``default_collate``.
    """
    kept = [sample for sample in batch if sample is not None]
    return default_collate(kept)
19 |
20 |
class RandomPointcloudPatchSampler(data.sampler.Sampler):
    """Draw a random subset of global patch indices each epoch.

    The subset size is ``sum(min(patches_per_shape, count))`` over the shapes
    of ``data_source``. The indices themselves are drawn uniformly without
    replacement from ``range(sum(data_source.shape_patch_count))``, so the
    per-shape cap sets the total only, not how many patches each shape gets.

    Args:
        data_source: object exposing ``shape_names`` and ``shape_patch_count``.
        patches_per_shape: per-shape cap used to compute the epoch size.
        seed: RNG seed; drawn at random when ``None``.
        identical_epochs: reseed before every epoch so each epoch yields the
            same indices.
    """

    def __init__(self, data_source, patches_per_shape, seed=None, identical_epochs=False):
        self.data_source = data_source
        self.patches_per_shape = patches_per_shape
        self.seed = seed
        self.identical_epochs = identical_epochs

        if self.seed is None:
            # np.random.random_integers is deprecated. randint's upper bound is
            # exclusive, so 2**32 covers the same inclusive range [0, 2**32 - 1].
            # An explicit int64 avoids overflow where the default int is 32-bit.
            self.seed = int(np.random.randint(0, 2 ** 32, dtype=np.int64))
        self.rng = np.random.RandomState(self.seed)

        self.total_patch_count = sum(
            min(self.patches_per_shape, count)
            for _, count in zip(self.data_source.shape_names, self.data_source.shape_patch_count)
        )

    def __iter__(self):
        if self.identical_epochs:
            self.rng.seed(self.seed)
        population = sum(self.data_source.shape_patch_count)
        return iter(self.rng.choice(population, size=self.total_patch_count, replace=False))

    def __len__(self):
        return self.total_patch_count
49 |
50 |
51 | class PointcloudPatchDataset(data.Dataset):
52 |
53 | def __init__(self, root=None, shapes_list_file=None, patch_radius=0.05, points_per_patch=512,
54 | seed=None, train_state='train', shape_name=None, identical_epoches=False,knn=False):
55 |
56 | self.root = root
57 | self.shapes_list_file = shapes_list_file
58 |
59 | self.patch_radius = patch_radius
60 | self.points_per_patch = points_per_patch
61 | self.seed = seed
62 | self.train_state = train_state
63 | self.identical_epochs = identical_epoches
64 | self.knn=knn
65 |
66 | # initialize rng for picking points in a patch
67 | if self.seed is None:
68 | self.seed = np.random.random_integers(0, 2 ** 10 - 1, 1)[0]
69 | self.rng = np.random.RandomState(self.seed)
70 |
71 | self.shape_patch_count = []
72 | self.patch_radius_absolute = []
73 | self.gt_shapes = []
74 | self.noise_shapes = []
75 |
76 | self.shape_names = []
77 | if self.train_state == 'evaluation' and shape_name is not None:
78 | pts_normal = np.load(os.path.join(self.root, shape_name + '.npy'))
79 | noise_pts = pts_normal[:, 0:3]
80 | noise_normal = pts_normal[:, 3:6]
81 | noise_kdtree = sp.cKDTree(noise_pts)
82 | self.noise_shapes.append(
83 | {'noise_pts': noise_pts, 'noise_kdtree': noise_kdtree, 'noise_normal': noise_normal})
84 | self.shape_patch_count.append(noise_pts.shape[0])
85 | bbdiag = float(np.linalg.norm(noise_pts.max(0) - noise_pts.min(0), 2))
86 | self.patch_radius_absolute.append(bbdiag * self.patch_radius)
87 | elif self.train_state == 'train':
88 | with open(os.path.join(self.root, self.shapes_list_file)) as f:
89 | self.shape_names = f.readlines()
90 | self.shape_names = [x.strip() for x in self.shape_names]
91 | self.shape_names = list(filter(None, self.shape_names))
92 | for shape_ind, shape_name in enumerate(self.shape_names):
93 | print('getting information for shape %s' % shape_name)
94 | if shape_ind % 6 == 0:
95 | gt_pts_normal = np.load(os.path.join(self.root, shape_name + '.npy'))
96 | gt_pts = gt_pts_normal[:, 0:3]
97 | gt_normal = gt_pts_normal[:, 3:6]
98 | gt_kdtree = sp.cKDTree(gt_pts)
99 | self.gt_shapes.append({'gt_pts': gt_pts, 'gt_normal': gt_normal, 'gt_kdtree': gt_kdtree})
100 | self.noise_shapes.append(
101 | {'noise_pts': gt_pts, 'noise_kdtree': gt_kdtree, 'noise_normal': gt_normal})
102 | noise_pts = gt_pts
103 | else:
104 |
105 | pts_normal = np.load(os.path.join(self.root, shape_name + '.npy'))
106 | noise_pts = pts_normal[:, 0:3]
107 | noise_normal = pts_normal[:, 3:6]
108 | noise_kdtree = sp.cKDTree(noise_pts)
109 | self.noise_shapes.append(
110 | {'noise_pts': noise_pts, 'noise_kdtree': noise_kdtree, 'noise_normal': noise_normal})
111 |
112 | self.shape_patch_count.append(noise_pts.shape[0])
113 | bbdiag = float(np.linalg.norm(noise_pts.max(0) - noise_pts.min(0), 2))
114 | self.patch_radius_absolute.append(bbdiag * self.patch_radius)
115 |
116 | def patch_sampling(self, patch_inds):
117 |
118 | if self.identical_epochs:
119 | self.rng.seed(self.seed)
120 |
121 | # if patch_pts.shape[0] > self.points_per_patch:
122 | #
123 | # sample_index = self.rng.choice(range(patch_pts.shape[0]), self.points_per_patch, replace=False)
124 | #
125 | # else:
126 | #
127 | # sample_index = self.rng.choice(range(patch_pts.shape[0]), self.points_per_patch)
128 | # point_count = min(self.points_per_patch, len(patch_inds))
129 | if len(patch_inds)>=self.points_per_patch:
130 | patch_inds = patch_inds[self.rng.choice(len(patch_inds), self.points_per_patch, replace=False)]
131 | else:
132 | patch_inds=patch_inds[self.rng.choice(len(patch_inds),self.points_per_patch)]
133 |
134 | return patch_inds
135 |
136 | def gauss_fcn(self,x, mu=0, sigma2=0.12):
137 | tmp = -(x - mu) ** 2 / (2 * sigma2)
138 |
139 | return np.exp(tmp)
140 |
141 |
def __getitem__(self, index):
    """Build one patch sample centred on the point with global index ``index``.

    The noisy patch is gathered around the centre point. It uses k-NN with
    ``self.points_per_patch`` neighbours when ``self.knn`` is set, otherwise a
    ball of radius ``patch_radius``. The patch is then subsampled with
    ``self.patch_sampling``, centred on the point, divided by ``patch_radius``
    and PCA-aligned. Normals and ground-truth data are rotated into the same
    aligned frame.

    Returns:
        ``None`` when the noisy or ground-truth neighbourhood has fewer
        than 3 points.

        In ``'evaluation'`` mode the tuple is:
        (noise_patch_pts, noise_patch_inv, centre point, noise_patch_normal,
        index, normal).

        Otherwise the tuple is:
        (noise_patch_pts, gt_patch_pts, noise_patch_normal, gt_patch_normal,
        support_radius, gt_normal, index, normal).

        ``index`` holds the position(s) of the centre point inside the
        sampled patch.
    """
    # Find the shape that contains the point with the given global index.
    shape_ind, patch_ind = self.shape_index(index)
    noise_shape = self.noise_shapes[shape_ind]
    patch_radius = self.patch_radius_absolute[shape_ind]
    center = noise_shape['noise_pts'][patch_ind]

    # ---------------- noisy patch ----------------
    if self.knn:
        # k-NN query; the centre point itself is among the neighbours.
        # (np.int was removed in NumPy 1.24, so the builtin int is used.)
        _, noise_patch_idx = noise_shape['noise_kdtree'].query(center, self.points_per_patch)
        noise_patch_idx = np.asarray(noise_patch_idx).astype(int)
    else:
        # Ball query: every point within patch_radius of the centre.
        noise_patch_idx = np.array(noise_shape['noise_kdtree'].query_ball_point(center, patch_radius))

    if len(noise_patch_idx) < 3:
        return None

    noise_sample_idx = self.patch_sampling(noise_patch_idx)
    # Position(s) of the centre point inside the sampled patch.
    index = np.where(noise_sample_idx == patch_ind)[0]

    # Centre and scale the patch, then PCA-align it.
    # noise_patch_inv maps aligned coordinates back to the original frame.
    noise_patch_pts = noise_shape['noise_pts'][noise_sample_idx] - center
    noise_patch_pts /= patch_radius
    noise_patch_pts, noise_patch_inv = pca_alignment(noise_patch_pts)

    # Forward rotation into the aligned frame, computed once and reused below.
    rotation = np.linalg.inv(noise_patch_inv)

    def to_patch_frame(vectors):
        """Rotate (M, 3) row vectors into the aligned patch frame."""
        return np.asarray(rotation @ vectors.T).T

    # Bounding-box diagonal of the aligned patch divided by its point count.
    support_radius = np.linalg.norm(noise_patch_pts.max(0) - noise_patch_pts.min(0), 2) / noise_patch_pts.shape[0]
    support_radius = np.expand_dims(support_radius, axis=0)

    normal = to_patch_frame(np.expand_dims(noise_shape['noise_normal'][patch_ind], axis=0))
    noise_patch_normal = to_patch_frame(noise_shape['noise_normal'][noise_sample_idx])

    if self.train_state == 'evaluation':
        return torch.from_numpy(noise_patch_pts), torch.from_numpy(noise_patch_inv), \
            center, torch.from_numpy(noise_patch_normal), torch.from_numpy(index), normal

    # ---------------- ground-truth patch ----------------
    # NOTE(review): presumably each clean shape has 6 consecutive noisy
    # variants, and the clean shape shares the noisy shape's point ordering
    # (gt_normal is indexed with patch_ind). Confirm against the dataset.
    gt_shape = self.gt_shapes[shape_ind // 6]
    if self.knn:
        _, gt_patch_idx = gt_shape['gt_kdtree'].query(center, self.points_per_patch)
        gt_patch_idx = np.asarray(gt_patch_idx).astype(int)
    else:
        gt_patch_idx = np.array(gt_shape['gt_kdtree'].query_ball_point(center, patch_radius))
    if len(gt_patch_idx) < 3:
        return None

    gt_sample_idx = self.patch_sampling(gt_patch_idx)
    # Normalise the GT patch exactly like the noisy one (same centre, scale, rotation).
    gt_patch_pts = gt_shape['gt_pts'][gt_sample_idx] - center
    gt_patch_pts /= patch_radius
    gt_patch_pts = to_patch_frame(gt_patch_pts)

    gt_normal = to_patch_frame(np.expand_dims(gt_shape['gt_normal'][patch_ind], axis=0))
    gt_patch_normal = to_patch_frame(gt_shape['gt_normal'][gt_sample_idx])

    return torch.from_numpy(noise_patch_pts), torch.from_numpy(gt_patch_pts), torch.from_numpy(noise_patch_normal), \
        torch.from_numpy(gt_patch_normal), torch.from_numpy(support_radius), torch.from_numpy(gt_normal), \
        torch.from_numpy(index), torch.from_numpy(normal)
220 |
def __len__(self):
    """Return the total number of patches: the sum of the per-shape patch counts."""
    total = 0
    for count in self.shape_patch_count:
        total += count
    return total
223 |
def shape_index(self, index):
    """Map a global patch index to ``(shape_ind, shape_patch_ind)``.

    Shapes are laid out back to back with ``self.shape_patch_count[i]``
    patches each.

    Raises:
        IndexError: if ``index`` falls outside ``[0, total patches)``.
            Previously this surfaced as an UnboundLocalError.
    """
    shape_patch_offset = 0
    for shape_ind, shape_patch_count in enumerate(self.shape_patch_count):
        if shape_patch_offset <= index < shape_patch_offset + shape_patch_count:
            return shape_ind, index - shape_patch_offset
        shape_patch_offset += shape_patch_count
    raise IndexError('patch index {} out of range for {} patches'.format(index, shape_patch_offset))
234 |
235 |
if __name__ == '__main__':
    # Smoke test: build the training dataset and fetch one sample from it.
    seed = 3627473
    train_dataset = PointcloudPatchDataset(
        root='./dataset',
        shapes_list_file='train.txt',
        seed=seed,
        train_state='train',
        identical_epoches=True,
        knn=True,
    )
    train_dataset[100000]
262 |
--------------------------------------------------------------------------------
/image/figure1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LabZhengLiu/PCDNF/4cc625096868c7888f80d0cbca38d4ba52a0c9eb/image/figure1.png
--------------------------------------------------------------------------------
/image/figure2.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/LabZhengLiu/PCDNF/4cc625096868c7888f80d0cbca38d4ba52a0c9eb/image/figure2.png
--------------------------------------------------------------------------------
/testN.py:
--------------------------------------------------------------------------------
1 | import os
2 | import torch
3 | import time
4 | import numpy as np
5 | from NetworkN1 import PCPNet
6 | from dataset923 import PointcloudPatchDataset,my_collate
7 | from utils118 import parse_arguments
8 |
# Expose only GPU 0 to CUDA for this process.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
def eval(opt):
    """Iteratively denoise every shape listed in ``<opt.testset>/test.txt``.

    For each shape, the raw points are first copied to
    ``<opt.save_dir>/<name>_pred_iter_0.npy``. Each of the
    ``opt.eval_iter_nums`` iterations then reads the previous iteration's file
    and runs the network on every patch. It writes
    ``<name>_pred_iter_<i + 1>.npy``, whose rows are 3 filtered coordinates
    followed by 3 predicted normal components.

    Args:
        opt: namespace providing testset, save_dir, eval_dir, patch_radius,
            batchSize, workers and eval_iter_nums.
    """
    with open(os.path.join(opt.testset, 'test.txt'), 'r') as f:
        shape_names = [line.strip() for line in f]
    shape_names = [name for name in shape_names if name]

    # Use opt rather than the module-level ``parameters`` global, so eval()
    # works when called from other code.
    os.makedirs(opt.save_dir, exist_ok=True)

    # The network is identical for every shape and iteration: load it once.
    pointfilter_eval = PCPNet()
    model_filename = os.path.join(opt.eval_dir, 'model_full_ae.pth')
    checkpoint = torch.load(model_filename)
    pointfilter_eval.load_state_dict(checkpoint['state_dict'])
    pointfilter_eval.cuda()
    pointfilter_eval.eval()

    for shape_name in shape_names:
        print(shape_name)
        original_noise_pts = np.load(os.path.join(opt.testset, shape_name + '.npy'))
        np.save(os.path.join(opt.save_dir, shape_name + '_pred_iter_0.npy'), original_noise_pts.astype('float32'))
        for eval_index in range(opt.eval_iter_nums):
            print(eval_index)
            test_dataset = PointcloudPatchDataset(
                root=opt.save_dir,
                shape_name=shape_name + '_pred_iter_' + str(eval_index),
                patch_radius=opt.patch_radius,
                train_state='evaluation',
                knn=True)
            test_dataloader = torch.utils.data.DataLoader(
                test_dataset,
                batch_size=opt.batchSize,
                collate_fn=my_collate,
                num_workers=int(opt.workers))

            patch_radius = test_dataset.patch_radius_absolute
            batch_outputs = []
            # Previously `start` was commented out, so the timing print raised NameError.
            start = time.time()
            for data_tuple in test_dataloader:
                # The last item (centre normal) is not needed for inference.
                noise_patch, noise_inv, noise_point, patch_normal, index, _ = data_tuple

                noise_patch = noise_patch.float().cuda().transpose(2, 1).contiguous()
                patch_normal = patch_normal.float().cuda().transpose(2, 1).contiguous()
                noise_inv = noise_inv.float().cuda()
                index = index.cuda()

                with torch.no_grad():
                    dis, n, _, _ = pointfilter_eval(noise_patch, patch_normal, index)
                    # Rotate the predictions back out of the PCA-aligned patch frame.
                    dis = torch.bmm(noise_inv, dis.unsqueeze(2))
                    n = torch.bmm(noise_inv, n.unsqueeze(2))

                # reshape(-1, 3) keeps a 2-D result even for a batch of size 1,
                # so no squeeze/reshape fix-up is needed.
                dis = dis.cpu().numpy().reshape(-1, 3) * patch_radius + noise_point.numpy()
                normal = n.cpu().numpy().reshape(-1, 3)
                batch_outputs.append(np.concatenate((dis, normal), axis=1))
            print("total_time:" + str(time.time() - start))

            # One concatenation instead of np.append per batch (quadratic copying).
            if batch_outputs:
                pred_pts = np.concatenate(batch_outputs, axis=0)
            else:
                pred_pts = np.empty((0, 6), dtype='float32')
            np.save(os.path.join(opt.save_dir, shape_name + '_pred_iter_' + str(eval_index + 1) + '.npy'),
                    pred_pts.astype('float32'))
92 |
93 |
94 |
if __name__ == '__main__':
    parameters = parse_arguments()
    # Evaluation settings that override the command-line defaults, applied in order.
    for attr_name, attr_value in (
            ('testset', r'testdir'),
            ('eval_dir', './Trained_Models/'),
            ('batchSize', 64),
            ('eval_iter_nums', 1),
            ('workers', 4),
            ('save_dir', r'savedir'),
            ('patch_radius', 0.05),
    ):
        setattr(parameters, attr_name, attr_value)
    eval(parameters)
--------------------------------------------------------------------------------
/train_NetworkN1.py:
--------------------------------------------------------------------------------
1 | # coding=utf-8
2 |
3 | from __future__ import print_function
4 | from tensorboardX import SummaryWriter
5 | from NetworkNN import PCPNet
6 | from dataset923 import PointcloudPatchDataset, RandomPointcloudPatchSampler, my_collate
7 | from utils118 import parse_arguments, adjust_learning_rate,compute_bilateral_loss
8 |
9 | import os
10 | import numpy as np
11 | import torch.utils.data
12 | import torch.optim as optim
13 | import torch.backends.cudnn as cudnn
# Let cuDNN benchmark and pick its fastest convolution algorithms.
cudnn.benchmark = True
# Expose only GPU 0 to CUDA for this process.
os.environ.update({"CUDA_VISIBLE_DEVICES": "0"})
def train(opt):
    """Train PCPNet on patches of the shapes listed in ``<opt.trainset>/train.txt``.

    Per-batch loss, loss1, loss2 and loss3 are logged to TensorBoard under
    ``opt.summary_train``. Checkpoints (epoch, state_dict, optimizer) are
    written to ``opt.network_model_dir`` as ``model_full_ae_<epoch>.pth``
    every ``opt.model_interval`` epochs, and as ``model_full_ae.pth`` after
    the last epoch. ``opt.resume`` optionally names a checkpoint to continue
    from.

    ``opt.patch_radius`` and ``opt.patch_per_shape`` are honoured when
    present. They fall back to the previously hard-coded 0.05 and 8000,
    which are also their command-line defaults.

    NOTE(review): this module imports ``PCPNet`` from ``NetworkNN``, but the
    repository tree only lists ``NetworkN1.py`` (which testN.py imports
    from). Confirm the module name.
    """
    print(opt)
    os.makedirs(opt.summary_train, exist_ok=True)
    os.makedirs(opt.network_model_dir, exist_ok=True)
    print("Random Seed: ", opt.manualSeed)
    np.random.seed(opt.manualSeed)
    torch.manual_seed(opt.manualSeed)

    train_dataset = PointcloudPatchDataset(
        root=opt.trainset,
        shapes_list_file='train.txt',
        patch_radius=getattr(opt, 'patch_radius', 0.05),
        seed=opt.manualSeed,
        identical_epoches=False,
        knn=True)
    train_datasampler = RandomPointcloudPatchSampler(
        train_dataset,
        patches_per_shape=getattr(opt, 'patch_per_shape', 8000),
        seed=opt.manualSeed,
        identical_epochs=False)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        collate_fn=my_collate,
        sampler=train_datasampler,
        shuffle=False,  # ordering comes from the custom sampler
        batch_size=opt.batchSize,
        num_workers=int(opt.workers),
        pin_memory=True)
    num_batch = len(train_dataloader)
    print(num_batch)

    denoisenet = PCPNet()
    denoisenet.cuda()
    optimizer = optim.SGD(
        denoisenet.parameters(),
        lr=opt.lr,
        momentum=opt.momentum)
    train_writer = SummaryWriter(opt.summary_train)

    # Optionally resume from a checkpoint.
    if opt.resume:
        if os.path.isfile(opt.resume):
            print("=> loading checkpoint '{}'".format(opt.resume))
            checkpoint = torch.load(opt.resume)
            opt.start_epoch = checkpoint['epoch']
            denoisenet.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(opt.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(opt.resume))

    # Angular bandwidth of the bilateral loss, degrees -> radians (loop invariant).
    support_angle = (opt.support_angle / 360) * 2 * np.pi

    try:
        for epoch in range(opt.start_epoch, opt.nepoch):
            adjust_learning_rate(optimizer, epoch, opt)
            print('lr is %.10f' % (optimizer.param_groups[0]['lr']))
            denoisenet.train()
            for batch_ind, data_tuple in enumerate(train_dataloader):
                optimizer.zero_grad()
                # The 6th item (gt centre normal) and the last item (noisy centre
                # normal) are not used by the loss, so they are not moved to the GPU.
                (noise_patch, gt_patch, patch_normal, gt_patch_normal,
                 support_radius, _gt_normal, index, _normal) = data_tuple
                # (B, N, 3) -> (B, 3, N) for the network inputs.
                noise_patch = noise_patch.float().cuda().transpose(2, 1).contiguous()
                patch_normal = patch_normal.float().cuda().transpose(2, 1).contiguous()
                gt_patch = gt_patch.float().cuda()
                gt_patch_normal = gt_patch_normal.float().cuda()
                support_radius = (opt.support_multiple * support_radius).float().cuda(non_blocking=True)
                index = index.cuda()

                x, n, w, topidx = denoisenet(noise_patch, patch_normal, index)
                loss, loss1, loss2, loss3 = compute_bilateral_loss(
                    x, n, gt_patch, gt_patch_normal, w, support_radius,
                    support_angle, opt.repulsion_alpha, topidx)
                loss.backward()
                optimizer.step()

                step = epoch * num_batch + batch_ind
                print('[%d: %d/%d] train loss: %f\n' % (epoch, batch_ind, num_batch, loss.item()))
                train_writer.add_scalar('loss', loss.item(), step)
                train_writer.add_scalar('loss1', loss1.item(), step)
                train_writer.add_scalar('loss2', loss2.item(), step)
                train_writer.add_scalar('loss3', loss3.item(), step)

            checkpoint_state = {
                'epoch': epoch + 1,
                'state_dict': denoisenet.state_dict(),
                'optimizer': optimizer.state_dict()}

            if epoch == (opt.nepoch - 1):
                torch.save(checkpoint_state, '%s/model_full_ae.pth' % opt.network_model_dir)

            if epoch % opt.model_interval == 0:
                torch.save(checkpoint_state, '%s/model_full_ae_%d.pth' % (opt.network_model_dir, epoch))
    finally:
        # Flush and close the TensorBoard event file even if training fails.
        train_writer.close()
113 |
if __name__ == '__main__':
    parameters = parse_arguments()
    # Training settings that override the command-line defaults, applied in order.
    for attr_name, attr_value in (
            ('trainset', './trainset'),
            ('summary_train', './log'),
            ('network_model_dir', './Models/'),
            ('batchSize', 128),
            ('lr', 1e-4),
            ('workers', 4),
            ('nepoch', 50),
    ):
        setattr(parameters, attr_name, attr_value)
    print(parameters)
    train(parameters)
125 |
--------------------------------------------------------------------------------
/utils118.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from sklearn.decomposition import PCA
3 | import math
4 | import torch
5 | import argparse
6 | ##########################Parameters########################
7 | #
8 | #
9 | #
10 | #
11 | ###############################################################
12 |
def str2bool(v):
    """Parse a boolean command-line value.

    Case-insensitively accepts yes/true/t/y/1 and no/false/f/n/0, ignoring
    surrounding whitespace. A value that is already a ``bool`` is returned
    unchanged, which makes programmatic calls safe.

    Raises:
        argparse.ArgumentTypeError: for any other value.
    """
    if isinstance(v, bool):
        return v
    value = v.strip().lower()
    if value in ('yes', 'true', 't', 'y', '1'):
        return True
    if value in ('no', 'false', 'f', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError('Boolean value expected.')
def parse_arguments(args=None):
    """Build the training/evaluation argument parser and parse ``args``.

    Args:
        args: optional list of argument strings. ``None`` (the default)
            parses ``sys.argv[1:]``, as before.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    parser = argparse.ArgumentParser()
    # naming / file handling
    parser.add_argument('--name', type=str, default='pcdenoising', help='training run name')
    parser.add_argument('--network_model_dir', type=str, default='./Models/all/test1', help='output folder (trained models)')
    parser.add_argument('--trainset', type=str, default='./dataset/Train', help='training set file name')
    parser.add_argument('--testset', type=str, default='./Dataset/Test', help='testing set file name')
    parser.add_argument('--save_dir', type=str, default='./Results/all/test1', help='')
    parser.add_argument('--summary_train', type=str, default='.logs/all/test', help='')
    parser.add_argument('--summary_test', type=str, default='./Summary/logs/model1/test', help='')

    # training parameters
    parser.add_argument('--nepoch', type=int, default=50, help='number of epochs to train for')
    parser.add_argument('--batchSize', type=int, default=32, help='input batch size')
    parser.add_argument('--workers', type=int, default=4, help='number of data loading workers')
    parser.add_argument('--manualSeed', type=int, default=3627473, help='manual seed')
    parser.add_argument('--start_epoch', type=int, default=0, help='')
    parser.add_argument('--patch_per_shape', type=int, default=8000, help='')
    parser.add_argument('--patch_radius', type=float, default=0.05, help='')
    # type=bool would turn any non-empty string (including "False") into True,
    # so str2bool is used instead. '--knn_patch' is a typeable alias; both
    # spellings store into the original attribute name 'knn patch'.
    parser.add_argument('--knn_patch', '--knn patch', dest='knn patch', type=str2bool, default=True,
                        help='use knn neighboorhood to construct patch')

    parser.add_argument('--lr', type=float, default=1e-4, help='learning rate')
    parser.add_argument('--momentum', type=float, default=0.9, help='gradient descent momentum')
    # Used by train() as the epoch interval between checkpoint saves.
    parser.add_argument('--model_interval', type=int, default=5, metavar='N',
                        help='how many epochs to wait between saving model checkpoints')

    # others parameters
    parser.add_argument('--resume', type=str, default='', help='refine model at this path')
    parser.add_argument('--support_multiple', type=float, default=4.0, help='the multiple of support radius')
    parser.add_argument('--support_angle', type=int, default=15, help='')
    parser.add_argument('--gt_normal_mode', type=str, default='nearest', help='')
    parser.add_argument('--repulsion_alpha', type=float, default=0.98, help='')

    # evaluation parameters
    parser.add_argument('--eval_dir', type=str, default='./Models/all/test1', help='')
    parser.add_argument('--eval_iter_nums', type=int, default=3, help='')

    return parser.parse_args(args)
57 |
58 | ###################Pre-Processing Tools########################
59 | #
60 | #
61 | #
62 | #
63 | ###############################################################
64 |
65 |
def get_principle_dirs(pts):
    """Return the three PCA principal directions of ``pts`` as rows of a 3x3 array.

    The rows are sklearn's ``PCA.components_``. The array is then divided
    column-wise by its column norms.
    """
    components = PCA(n_components=3).fit(pts).components_
    column_norms = np.linalg.norm(components, 2, axis=0)
    components /= column_norms
    return components
74 |
75 |
def pca_alignment(pts, random_flag=False):
    """Rotate ``pts`` into its PCA-aligned frame.

    The third principal direction (index 2 of ``get_principle_dirs``) is
    rotated onto +z. The second one, after that first rotation, is then
    rotated onto +x.

    Args:
        pts: (N, 3) point array.
        random_flag: if true, all principal directions are flipped together
            by a random sign before aligning.

    Returns:
        (aligned_pts, inv_rotation): the rotated (N, 3) ndarray, and the 3x3
        ndarray mapping aligned coordinates back to the input frame.
    """
    pca_dirs = get_principle_dirs(pts)

    if random_flag:
        pca_dirs *= np.random.choice([-1, 1], 1)

    # compute_roatation_matrix may return a plain ndarray (identity case) as
    # well as an np.matrix. With ``*`` an ndarray identity multiplies
    # element-wise and zeroes the off-diagonal terms. Converting to ndarray
    # and using ``@`` always performs a true matrix product.
    rotate_1 = np.asarray(compute_roatation_matrix(pca_dirs[2], [0, 0, 1], pca_dirs[1]))
    pca_dirs = (rotate_1 @ pca_dirs.T).T
    rotate_2 = np.asarray(compute_roatation_matrix(pca_dirs[1], [1, 0, 0], pca_dirs[2]))
    rotation = rotate_2 @ rotate_1

    aligned_pts = (rotation @ np.asarray(pts).T).T
    inv_rotation = np.linalg.inv(rotation)

    return aligned_pts, inv_rotation
92 |
def compute_roatation_matrix(sour_vec, dest_vec, sour_vertical_vec=None):
    """Return the 3x3 rotation (``np.matrix``) that maps ``sour_vec`` onto ``dest_vec``.

    Both vectors are assumed to be unit length: their dot product is used
    directly as the cosine of the rotation angle. Degenerate configurations
    (zero cross product, or |cos| >= 1) are handled explicitly:

    * anti-parallel vectors give a rotation by pi about ``sour_vertical_vec``;
    * parallel vectors give the identity.

    Every branch returns ``np.matrix``, so callers that combine the result
    with ``*`` always get a matrix product. A plain ndarray identity would
    make ``*`` element-wise.
    """
    # http://immersivemath.com/forum/question/rotation-matrix-from-one-vector-to-another/
    cross = np.cross(sour_vec, dest_vec)
    cross_norm = np.linalg.norm(cross, 2)
    cos_angle = np.dot(sour_vec, dest_vec)
    if cross_norm == 0 or np.abs(cos_angle) >= 1.0:
        if cos_angle < 0:
            return rotation_matrix(sour_vertical_vec, np.pi)
        return np.matrix(np.identity(3))

    # Rodrigues' formula: rotate by alpha about the unit axis a = (s x d) / |s x d|.
    alpha = np.arccos(cos_angle)
    a = cross / cross_norm
    c = np.cos(alpha)
    s = np.sin(alpha)
    R1 = [a[0] * a[0] * (1.0 - c) + c,
          a[0] * a[1] * (1.0 - c) - s * a[2],
          a[0] * a[2] * (1.0 - c) + s * a[1]]

    R2 = [a[0] * a[1] * (1.0 - c) + s * a[2],
          a[1] * a[1] * (1.0 - c) + c,
          a[1] * a[2] * (1.0 - c) - s * a[0]]

    R3 = [a[0] * a[2] * (1.0 - c) - s * a[1],
          a[1] * a[2] * (1.0 - c) + s * a[0],
          a[2] * a[2] * (1.0 - c) + c]

    return np.matrix([R1, R2, R3])
118 |
119 |
def rotation_matrix(axis, theta):
    """Return the ``np.matrix`` for a counterclockwise rotation of ``theta`` radians about ``axis``.

    Uses the Euler-Rodrigues parameterisation. ``axis`` is normalised here,
    so it need not be unit length, but it must be non-zero.
    """
    unit_axis = np.asarray(axis)
    unit_axis = unit_axis / math.sqrt(np.dot(unit_axis, unit_axis))
    half_angle = theta / 2.0
    a = math.cos(half_angle)
    b, c, d = -unit_axis * math.sin(half_angle)

    aa, bb, cc, dd = (q * q for q in (a, b, c, d))
    bc, ad, ac, ab, bd, cd = b * c, a * d, a * c, a * b, b * d, c * d
    rows = [
        [aa + bb - cc - dd, 2 * (bc + ad), 2 * (bd - ac)],
        [2 * (bc - ad), aa + cc - bb - dd, 2 * (cd + ab)],
        [2 * (bd + ac), 2 * (cd - ab), aa + dd - bb - cc],
    ]
    return np.matrix(np.array(rows))
133 |
134 |
135 |
136 |
137 | ##########################Network Tools########################
138 | #
139 | #
140 | #
141 | #
142 | ###############################################################
143 |
def adjust_learning_rate(optimizer, epoch, opt):
    """Apply the step learning-rate schedule for ``epoch``, starting from ``opt.lr``."""
    base_lr = opt.lr
    lr_shceduler(optimizer, epoch, base_lr)
147 |
def lr_shceduler(optimizer, epoch, init_lr):
    """Set every param group's learning rate to ``init_lr`` times a step-decay factor.

    Factor by epoch: > 36 -> 0.5e-3, > 32 -> 1e-3, > 24 -> 1e-2,
    > 16 -> 1e-1, otherwise ``init_lr`` is used unchanged.
    """
    decay_steps = ((36, 0.5e-3), (32, 1e-3), (24, 1e-2), (16, 1e-1))
    lr = init_lr
    for threshold, factor in decay_steps:
        if epoch > threshold:
            lr = init_lr * factor
            break
    for group in optimizer.param_groups:
        group['lr'] = lr
160 |
161 | ################################Ablation Study of Different Loss ###############################
162 |
# First scheme from the paper: La_proj.
def compute_original_1_loss(pts_pred, gt_patch_pts, gt_patch_normals, support_radius, alpha):
    """Weighted projection loss plus a max-distance repulsion term.

    Args:
        pts_pred: (B, 3) predicted points.
        gt_patch_pts, gt_patch_normals: (B, N, 3) ground-truth patch points and normals.
        support_radius: bandwidth of the Gaussian spatial weights.
        alpha: blend between the projection term and the repulsion term.

    Returns:
        Scalar tensor: batch mean of
        alpha * |sum_j w_j (p - g_j) . n_j| + (1 - alpha) * max_j |p - g_j|^2.
    """
    offsets = pts_pred.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1) - gt_patch_pts
    dist_square = offsets.pow(2).sum(2)

    # Gaussian spatial weights; the epsilon avoids dividing by zero when normalising.
    weight = torch.exp(-1 * dist_square / (support_radius ** 2)) + 1e-12
    weight = weight / weight.sum(1, keepdim=True)

    # Signed distance to each GT tangent plane, weighted and summed over the patch.
    plane_dist = (offsets * gt_patch_normals).sum(2)
    imls_dist = torch.abs((plane_dist * weight).sum(1))

    # Repulsion term: squared distance to the farthest GT point.
    farthest = torch.max(dist_square, 1)[0]

    return torch.mean((alpha * imls_dist) + (1 - alpha) * farthest)
# Uses bilateral filtering weights.
def compute_original_2_loss(pred_point, gt_patch_pts, gt_patch_normals, support_radius, support_angle, alpha):
    """Bilateral-weighted point-to-point distance loss with a repulsion term.

    Args:
        pred_point: (B, 3) predicted points.
        gt_patch_pts, gt_patch_normals: (B, N, 3) ground-truth patch points and normals.
        support_radius: bandwidth of the Gaussian spatial weights.
        support_angle: angular bandwidth in radians for the normal weights.
        alpha: blend between the distance term and the repulsion term.

    Returns:
        Scalar tensor: batch mean of
        alpha * sum_j w_j |p - g_j| + (1 - alpha) * max_j |p - g_j|^2.
    """
    batch_size, num_gt = gt_patch_pts.size(0), gt_patch_pts.size(1)

    # Spatial weights.
    pred_point = pred_point.unsqueeze(1).repeat(1, num_gt, 1)
    dist_square = ((pred_point - gt_patch_pts) ** 2).sum(2)
    weight_theta = torch.exp(-1 * dist_square / (support_radius ** 2))

    # Normal of the GT point nearest to each prediction, broadcast over the patch.
    nearest_idx = torch.argmin(dist_square, dim=1)
    batch_idx = torch.arange(batch_size, device=nearest_idx.device)
    pred_point_normal = gt_patch_normals[batch_idx, nearest_idx].unsqueeze(1).repeat(1, num_gt, 1)

    # Normal (range) weights.
    normal_proj_dist = (pred_point_normal * gt_patch_normals).sum(2)
    weight_phi = torch.exp(-1 * ((1 - normal_proj_dist) / (1 - np.cos(support_angle))) ** 2)

    # The epsilon avoids dividing by zero when normalising.
    weight = weight_theta * weight_phi + 1e-12
    weight = weight / weight.sum(1, keepdim=True)

    # Unlike Pointfilter (which projects the offset onto the normal), this uses
    # the Euclidean distance. sqrt has an infinite derivative at 0, which
    # produced NaN gradients whenever a prediction coincided with a GT point,
    # so the argument is clamped first.
    project_dist = torch.sqrt(dist_square.clamp_min(1e-12))
    imls_dist = (project_dist * weight).sum(1)

    # Repulsion term.
    max_dist = torch.max(dist_square, 1)[0]

    return torch.mean((alpha * imls_dist) + (1 - alpha) * max_dist)
# PointCleanNet loss.
def compute_original_3_loss(pts_pred, gt_patch_pts, alpha):
    """PointCleanNet loss, scaled by 100.

    Batch mean of alpha * (min squared distance to the GT patch) plus
    (1 - alpha) * (max squared distance). ``pts_pred`` is (B, 3) and
    ``gt_patch_pts`` is (B, N, 3).
    """
    tiled = pts_pred.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1)
    sq_dist = (tiled - gt_patch_pts).pow(2).sum(2)
    nearest = sq_dist.min(1).values
    farthest = sq_dist.max(1).values
    return torch.mean((alpha * nearest) + (1 - alpha) * farthest) * 100
230 |
def compute_original_4_loss(pts_pred1, pts_pred2, gt_patch_pts, alpha):
    """Sum of the PointCleanNet terms of two predictions, scaled by 100.

    Each term is the batch mean of alpha * min plus (1 - alpha) * max of
    the squared distances from the (B, 3) prediction to the (B, N, 3) GT
    patch.
    """
    num_gt = gt_patch_pts.size(1)

    def pointcleannet_term(pred):
        """PointCleanNet term for one (B, 3) prediction."""
        sq_dist = (pred.unsqueeze(1).repeat(1, num_gt, 1) - gt_patch_pts).pow(2).sum(2)
        return torch.mean((alpha * torch.min(sq_dist, 1)[0]) + (1 - alpha) * torch.max(sq_dist, 1)[0])

    total = pointcleannet_term(pts_pred1) + pointcleannet_term(pts_pred2)
    return total * 100
250 |
def compute_original_5_loss(pts_pred1, pts_pred2, normal, gt_patch_pts, gt_normal, alpha):
    """Two-prediction PointCleanNet loss plus two normal-related terms.

    Returns:
        (total, dist, normal_dist, oth), where:
          * dist is 100 * (PointCleanNet term of pts_pred1 + term of pts_pred2);
          * normal_dist is the batch mean of |normal - gt_normal|^2;
          * oth is 100 * mean |normal . (pts_pred2 - nearest GT point)|;
          * total = dist + normal_dist + oth.
    """
    batch_size, num_gt = gt_patch_pts.size(0), gt_patch_pts.size(1)

    sq1 = (pts_pred1.unsqueeze(1).repeat(1, num_gt, 1) - gt_patch_pts).pow(2).sum(2)
    dist1 = torch.mean((alpha * torch.min(sq1, 1)[0]) + (1 - alpha) * torch.max(sq1, 1)[0])

    sq2 = (pts_pred2.unsqueeze(1).repeat(1, num_gt, 1) - gt_patch_pts).pow(2).sum(2)
    min2, nearest_idx = torch.min(sq2, 1)
    dist2 = torch.mean((alpha * min2) + (1 - alpha) * torch.max(sq2, 1)[0])

    # Gather the GT point nearest to pts_pred2: (B, 1, 3) -> (B, 3).
    gather_idx = nearest_idx.view(-1, 1, 1).expand(batch_size, 1, 3)
    nearest_point = torch.gather(gt_patch_pts, dim=1, index=gather_idx).squeeze(1)
    offset = (pts_pred2 - nearest_point).unsqueeze(-1)
    oth = torch.abs(torch.bmm(normal.unsqueeze(1), offset)).mean() * 100

    normal_dist = (normal - gt_normal).pow(2).sum(1).mean()
    dist = (dist1 + dist2) * 100
    total = dist + normal_dist + oth
    return total, dist, normal_dist, oth
283 |
284 |
def compute_original_6_loss(pts_pred1, gt_patch_pts, normal, gtnormal, alpha):
    """PointCleanNet distance (x100) plus an NLL classification loss.

    ``normal`` and ``gtnormal`` are passed to ``torch.nn.functional.nll_loss``
    as input (log-probabilities) and target (class indices).

    Returns:
        (total, dist, nll), where total = nll + dist.
    """
    sq_dist = (pts_pred1.unsqueeze(1).repeat(1, gt_patch_pts.size(1), 1) - gt_patch_pts).pow(2).sum(2)
    dist = torch.mean((alpha * torch.min(sq_dist, 1)[0]) + (1 - alpha) * torch.max(sq_dist, 1)[0]) * 100
    nll = torch.nn.functional.nll_loss(normal, gtnormal)
    return nll + dist, dist, nll
300 | ################################Ablation Study of Different Loss ###############################
# The authors' improved bilateral-filter variant.
def compute_original_7_loss(pts_pred1, gt_patch_pts, normal, gtnormal, patch_center_normal, alpha):
    """PointCleanNet distance (x100) plus an NLL classification loss.

    ``patch_center_normal`` is accepted for interface compatibility but is
    not used; the orthogonality term that would use it is disabled.

    Returns:
        (total, dist, nll), where total = nll + dist.
    """
    num_gt = gt_patch_pts.size(1)
    sq_dist = (pts_pred1.unsqueeze(1).repeat(1, num_gt, 1) - gt_patch_pts).pow(2).sum(2)
    nearest = torch.min(sq_dist, 1)[0]
    farthest = torch.max(sq_dist, 1)[0]
    dist = torch.mean((alpha * nearest) + (1 - alpha) * farthest) * 100

    # Open question from the authors: should the point/normal product carry a weighting coefficient?
    nll = torch.nn.functional.nll_loss(normal, gtnormal)

    return nll + dist, dist, nll
def compute_original_8_loss(pred_point, gt_patch_pts, gt_patch_normals, deltnorma, prednormal, support_radius, support_angle, alpha):
    """Bilateral projection loss plus an NLL term on ``prednormal``.

    The point term is the batch mean of
    alpha * sum_j w_j |(p - g_j) . n_j| + (1 - alpha) * max_j |p - g_j|^2,
    with w_j the normalised product of a spatial Gaussian (``support_radius``)
    and a normal Gaussian (``support_angle``, radians). The NLL term is
    ``nll_loss(prednormal, deltnorma)``. Returns their sum.
    """
    batch_size, num_gt = gt_patch_pts.size(0), gt_patch_pts.size(1)
    tiled_pred = pred_point.unsqueeze(1).repeat(1, num_gt, 1)
    offsets = tiled_pred - gt_patch_pts
    sq_dist = (offsets ** 2).sum(2)
    spatial_w = torch.exp(-1 * sq_dist / (support_radius ** 2))

    # Normal of the GT point nearest to each prediction, broadcast over the patch.
    closest = torch.argmin(sq_dist, dim=1)
    closest_normal = gt_patch_normals[torch.arange(batch_size, device=closest.device), closest]
    closest_normal = closest_normal.unsqueeze(1).repeat(1, num_gt, 1)

    cos_sim = (closest_normal * gt_patch_normals).sum(2)
    angular_w = torch.exp(-1 * ((1 - cos_sim) / (1 - np.cos(support_angle))) ** 2)

    # The epsilon avoids dividing by zero when normalising.
    w = spatial_w * angular_w + 1e-12
    w = w / w.sum(1, keepdim=True)

    plane_dist = torch.abs((offsets * gt_patch_normals).sum(2))
    imls_dist = (plane_dist * w).sum(1)
    farthest = torch.max(sq_dist, 1)[0]

    nll = torch.nn.functional.nll_loss(prednormal, deltnorma)
    point_term = torch.mean((alpha * imls_dist) + (1 - alpha) * farthest)

    return point_term + nll
def compute_bilateral_loss(pred_point, pred_normal, gt_patch_pts, gt_patch_normals, predweight, support_radius, support_angle, alpha, top_idx):
    """Training loss: bilateral point loss + normal loss + orthogonality loss.

    Args:
        pred_point: (B, 3) predicted point positions.
        pred_normal: (B, 3) predicted normals.
        gt_patch_pts, gt_patch_normals: (B, N, 3) ground-truth patch points and normals.
        predweight, top_idx: accepted for interface compatibility; unused.
        support_radius: spatial Gaussian bandwidth, broadcastable against (B, N).
        support_angle: angular bandwidth in radians for the normal weights.
        alpha: blend between the projection term and the repulsion term.

    Returns:
        (loss, loss1, loss2, oth_loss), where:
          * loss1 is 100 * batch mean of
            alpha * sum_j w_j |(p - g_j) . n_j| + (1 - alpha) * max_j |p - g_j|^2;
          * loss2 is 10 * mean |pred_normal - normal of nearest GT point|^2;
          * oth_loss is 10 * mean over batch and patch of (pred_normal . (p - g_j))^2;
          * loss = loss1 + loss2 + oth_loss.
    """
    batch_size = gt_patch_pts.size(0)

    # Broadcasting (B, 1, 3) against (B, N, 3) avoids materialising repeats.
    offsets = pred_point.unsqueeze(1) - gt_patch_pts          # (B, N, 3)
    dist_square = (offsets ** 2).sum(2)                        # (B, N)
    weight_theta = torch.exp(-1 * dist_square / (support_radius ** 2))

    # Normal of the GT point nearest to each prediction, gathered with one
    # indexing op instead of a per-sample Python loop plus torch.cat.
    nearest_idx = torch.argmin(dist_square, dim=1)
    batch_idx = torch.arange(batch_size, device=nearest_idx.device)
    nearest_normal = gt_patch_normals[batch_idx, nearest_idx]  # (B, 3)

    normal_proj_dist = (nearest_normal.unsqueeze(1) * gt_patch_normals).sum(2)
    weight_phi = torch.exp(-1 * ((1 - normal_proj_dist) / (1 - np.cos(support_angle))) ** 2)

    # The epsilon avoids dividing by zero when normalising.
    weight = weight_theta * weight_phi + 1e-12
    weight = weight / weight.sum(1, keepdim=True)

    # Key term: weighted distance to the GT tangent planes.
    project_dist = torch.abs((offsets * gt_patch_normals).sum(2))
    imls_dist = (project_dist * weight).sum(1)

    # Repulsion term: squared distance to the farthest GT point.
    max_dist = torch.max(dist_square, 1)[0]

    loss1 = 100 * torch.mean((alpha * imls_dist) + (1 - alpha) * max_dist)

    # Normal term. The per-sample value was previously repeated N times and
    # averaged, which is the same quantity.
    loss2 = 10 * (pred_normal - nearest_normal).pow(2).sum(1).mean()

    # Orthogonality term: squared offset of each GT point from the plane
    # through the prediction with normal pred_normal.
    oth_loss = 10 * (pred_normal.unsqueeze(1) * offsets).sum(2).pow(2).mean()

    loss = loss1 + loss2 + oth_loss

    return loss, loss1, loss2, oth_loss
404 |
def compute_bilateral_loss1(pred_point, pred_normal, gt_patch_pts, gt_patch_normals, predweight, support_radius, support_angle, alpha, top_idx):
    """Bilateral point loss plus normal loss, without an orthogonality term.

    Returns:
        (loss1 + loss2, loss1, loss2), where:
          * loss1 is 100 * batch mean of
            alpha * sum_j w_j |(p - g_j) . n_j| + (1 - alpha) * max_j |p - g_j|^2;
          * loss2 is 10 * mean |pred_normal - normal of nearest GT point|^2.

    ``predweight`` and ``top_idx`` are unused.
    """
    n_batch, n_gt = gt_patch_pts.size(0), gt_patch_pts.size(1)
    tiled_pred = pred_point.unsqueeze(1).repeat(1, n_gt, 1)
    sq_dist = ((tiled_pred - gt_patch_pts) ** 2).sum(2)
    spatial_w = torch.exp(-1 * sq_dist / (support_radius ** 2))

    # Normal of the GT point nearest to each prediction, broadcast over the patch.
    closest = torch.argmin(sq_dist, dim=1)
    closest_normal = gt_patch_normals[torch.arange(n_batch, device=closest.device), closest]
    closest_normal = closest_normal.unsqueeze(1).repeat(1, n_gt, 1)

    cos_sim = (closest_normal * gt_patch_normals).sum(2)
    angular_w = torch.exp(-1 * ((1 - cos_sim) / (1 - np.cos(support_angle))) ** 2)

    # The epsilon avoids dividing by zero when normalising.
    w = spatial_w * angular_w + 1e-12
    w = w / w.sum(1, keepdim=True)

    plane_dist = torch.abs(((tiled_pred - gt_patch_pts) * gt_patch_normals).sum(2))
    imls_dist = (plane_dist * w).sum(1)
    farthest = torch.max(sq_dist, 1)[0]
    loss1 = 100 * torch.mean((alpha * imls_dist) + (1 - alpha) * farthest)

    tiled_normal = pred_normal.unsqueeze(1).repeat(1, n_gt, 1)
    loss2 = 10 * (tiled_normal - closest_normal).pow(2).sum(2).mean(1).mean(0)

    return loss1 + loss2, loss1, loss2
446 |
447 |
def compute_bilateral_loss_with_repulsion(pred_point, gt_patch_pts, gt_patch_normals, support_radius, support_angle, alpha):
    """Bilateral projection loss with a max-distance repulsion term.

    Returns the batch mean of
    alpha * sum_j w_j |(p - g_j) . n_j| + (1 - alpha) * max_j |p - g_j|^2.
    Here w_j is the normalised product of a spatial Gaussian
    (``support_radius``) and a normal Gaussian (``support_angle``, radians)
    around the normal of the GT point nearest to p.
    """
    batch_size, num_gt = gt_patch_pts.size(0), gt_patch_pts.size(1)
    tiled_pred = pred_point.unsqueeze(1).repeat(1, num_gt, 1)
    offsets = tiled_pred - gt_patch_pts
    sq_dist = (offsets ** 2).sum(2)
    spatial_w = torch.exp(-1 * sq_dist / (support_radius ** 2))

    # Normal of the GT point nearest to each prediction, broadcast over the patch.
    closest = torch.argmin(sq_dist, dim=1)
    closest_normal = gt_patch_normals[torch.arange(batch_size, device=closest.device), closest]
    closest_normal = closest_normal.unsqueeze(1).repeat(1, num_gt, 1)

    cos_sim = (closest_normal * gt_patch_normals).sum(2)
    angular_w = torch.exp(-1 * ((1 - cos_sim) / (1 - np.cos(support_angle))) ** 2)

    # The epsilon avoids dividing by zero when normalising.
    w = spatial_w * angular_w + 1e-12
    w = w / w.sum(1, keepdim=True)

    plane_dist = torch.abs((offsets * gt_patch_normals).sum(2))
    imls_dist = (plane_dist * w).sum(1)
    farthest = torch.max(sq_dist, 1)[0]

    return torch.mean((alpha * imls_dist) + (1 - alpha) * farthest)
def compute_original_L2_loss(pts_pred, gt_patch_pts, gt_mask, pred_mask, alpha):
    """PointCleanNet-style loss: equal blend of mask classification and point loss.

    An earlier call site looked like
    ``compute_original_L2_loss(x, gt_patch, delta_normal, xx, opt.repulsion_alpha)``.

    Args:
        pts_pred: predicted points, forwarded to ``compute_original_3_loss``.
        gt_patch_pts: ground-truth patch points, forwarded likewise.
        gt_mask: class targets for ``nll_loss``.
        pred_mask: predictions for ``nll_loss`` (``nll_loss`` expects
            log-probabilities).
        alpha: forwarded to ``compute_original_3_loss``.

    Returns:
        ``(loss, loss1, loss2)``, where loss1 is the NLL classification loss,
        loss2 is the point loss, and ``loss = 0.5 * loss1 + 0.5 * loss2``.
    """
    classification = torch.nn.functional.nll_loss(pred_mask, gt_mask)
    regression = compute_original_3_loss(pts_pred, gt_patch_pts, alpha)
    total = 0.5 * classification + 0.5 * regression
    return total, classification, regression
490 |
def compute_orginal_Pointfilter_loss(pred_point, gt_patch_pts, gt_patch_normals, support_radius, support_angle, gt_mask, pred_mask, alpha):
    """Equal blend of mask classification (NLL) and the Pointfilter bilateral loss.

    Returns:
        ``(loss, loss1, loss2)``, where loss1 is ``nll_loss(pred_mask, gt_mask)``,
        loss2 is ``compute_bilateral_loss_with_repulsion(...)``, and
        ``loss = 0.5 * loss1 + 0.5 * loss2``.
    """
    mask_loss = torch.nn.functional.nll_loss(pred_mask, gt_mask)
    point_loss = compute_bilateral_loss_with_repulsion(
        pred_point, gt_patch_pts, gt_patch_normals, support_radius, support_angle, alpha
    )
    return 0.5 * mask_loss + 0.5 * point_loss, mask_loss, point_loss
def cos_angle(v1, v2):
    """Row-wise cosine of the angle between two batches of vectors.

    Args:
        v1, v2: 2-D tensors of matching shape (one vector per row).

    Returns:
        1-D tensor with one cosine per row. The norm product is clamped to at
        least 1e-6, so a zero vector yields 0 instead of NaN.
    """
    dots = torch.bmm(v1.unsqueeze(1), v2.unsqueeze(2)).view(-1)
    norm_prod = torch.clamp(v1.norm(2, 1) * v2.norm(2, 1), min=0.000001)
    return dots / norm_prod
501 |
def Patch_Normal_loss_Compute(pred_normas, gt_normals, top_idx):
    """Mean squared error between predicted normals and a gathered subset of gt normals.

    Args:
        pred_normas: predicted normals, (B, k, 3).
        gt_normals: ground-truth normals, (B, N, 3); rows are selected by top_idx.
        top_idx: (B, k) integer indices into dim 1 of gt_normals.

    Returns:
        Scalar tensor: squared L2 error per normal, averaged over k, then over B.
    """
    batch, k = top_idx.size()
    gather_index = top_idx.unsqueeze(-1).expand(batch, k, 3)
    selected = torch.gather(gt_normals, dim=1, index=gather_index)  # [B,k,3]
    sq_err = (pred_normas - selected).pow(2).sum(2)
    return sq_err.mean(1).mean()
510 |
def Normal_loss_Compute(pred_normal, gt_normal):
    """Mean squared L2 distance between predicted and ground-truth normals.

    Args:
        pred_normal: predicted normals, reduced over dim 1 (assumed (B, 3)).
        gt_normal: ground-truth normals; a singleton dim 1 (e.g. (B, 1, 3))
            is squeezed away before comparison.

    Returns:
        Scalar tensor: mean over the batch of ||pred_normal - gt_normal||^2.
    """
    gt_normal = gt_normal.squeeze(1)
    # The residual must be squared. The previous ``.pow(1)`` gave a signed sum
    # of component differences, which can be negative and is unbounded below,
    # so minimising it drove predictions away from the target.
    return (pred_normal - gt_normal).pow(2).sum(1).mean()
517 |
def Cos_Compute_Normal_Loss(pre_normals, gt_normals):
    """Orientation-invariant normal loss: mean of (1 - |cos(angle)|)^2 over rows.

    Taking the absolute cosine makes a normal and its flip equally good.
    """
    abs_cos = torch.abs(cos_angle(pre_normals, gt_normals))
    return (1 - abs_cos).pow(2).mean()
def Sin_Compute_Normal_Loss(pre_normals, gt_normals):
    """Half the mean norm of the row-wise cross product of two normal batches.

    For unit vectors ||a x b|| = |sin(angle)|, so this penalises angular error
    regardless of orientation.
    """
    cross = torch.cross(pre_normals, gt_normals, dim=-1)
    cross_len = torch.norm(cross, p=2, dim=1)
    return 0.5 * cross_len.mean()
525 | '''
526 | def Otho_Loss(gt_normals,gt_normal,gt_points,pre_point,index):
527 |
528 | k=index.size(1)
529 | B=index.size(0)
530 | gt_normals=torch.gather(gt_normals,dim=1,index=index.unsqueeze(-1).expand(B,k,3))
531 | gt_points=torch.gather(gt_points,dim=1,index=index.unsqueeze(-1).expand(B,k,3))
532 |
533 | pre_point=pre_point.unsqueeze(-1).repeat(1,1,k).transpose(2,1)
534 | loss1=(torch.abs(gt_normals*(pre_point-gt_points))).sum(-1).sum(1).mean(0)
535 |
536 | gt_normal=gt_normal.repeat(1,k,1)
537 | loss2=(torch.abs(gt_normal*(pre_point-gt_points))).sum(-1).sum(1).mean(0)
538 |
539 | loss=loss1+loss2
540 | return loss
541 | '''
def Otho_Loss(pred_normal, gt_point, pred_point):
    """Compare the full offset with its component along the predicted normal.

    With d = pred_point - p_j for each patch point p_j and n = pred_normal,
    this returns | mean_b ||d||^2 - mean_b ||(n . d) n||^2 |. For unit-length
    n (NOTE(review): assumed, not enforced here) the second term is the
    squared normal component of d.

    Args:
        pred_normal: predicted normals; must broadcast against the offsets.
            NOTE(review): the original left ``pred_normal.unsqueeze(1)``
            commented out, so callers presumably pass (B, 1, 3) -- confirm.
        gt_point: patch points, presumably (B, N, 3).
        pred_point: predicted point, unsqueezed at dim 1 (assumed (B, 3)).

    Returns:
        Tensor reduced over the batch dimension only (``.mean(0)``), i.e. one
        value per patch point -- not a scalar.
    """
    offset = pred_point.unsqueeze(1) - gt_point
    squared_len = offset.pow(2).sum(2).mean(0)

    along_normal = (pred_normal * offset).sum(2).unsqueeze(-1)
    squared_normal_part = (along_normal * pred_normal).pow(2).sum(2).mean(0)

    return torch.abs(squared_len - squared_normal_part)
554 |
555 |
556 |
557 |
558 |
def compute_loss(pred_point, pred_normal, gt_patch_pts, gt_patch_normals, predweight, gt_normal, support_radius, support_angle, alpha):
    """Weighted point-to-plane / normal-plane loss plus a farthest-point term.

    Args:
        pred_point: predicted point, (B, 3).
        pred_normal: predicted normal, assumed (B, 3).
        gt_patch_pts: ground-truth patch points, (B, N, 3).
        gt_patch_normals: normals of those points, (B, N, 3).
        predweight: per-point weights, must broadcast against (B, N).
        gt_normal: unused; kept for signature compatibility.
        support_radius: unused; kept for signature compatibility.
        support_angle: unused; kept for signature compatibility.
        alpha: unused; kept for signature compatibility.

    Returns:
        ``(dist, oth_loss, max_dist)``, where:

        - ``max_dist`` is the batch mean of the largest squared distance from
          pred_point to a patch point;
        - ``oth_loss = 100 * mean_b sum_j w_j * ((n_j . (x - p_j))^2 + (m . p_j)^2)``;
        - ``dist = oth_loss + max_dist``.
    """
    num_pts = gt_patch_pts.size(1)
    pred_point = pred_point.unsqueeze(1).repeat(1, num_pts, 1)  # [B,3] -> [B,N,3]
    dist_square = ((pred_point - gt_patch_pts) ** 2).sum(2)  # [B,N]

    # Repulsion term: squared distance to the farthest patch point.
    max_dist = torch.mean(torch.max(dist_square, 1)[0])

    pred_normal = pred_normal.unsqueeze(1).repeat(1, num_pts, 1)
    # Residual of the prediction against each ground-truth tangent plane.
    project_dist = (gt_patch_normals * (pred_point - gt_patch_pts)).sum(2).pow(2)  # [B,N]
    # Residual of each gt point against the plane through the origin with
    # normal pred_normal. NOTE(review): presumably patches are centred on the
    # query point -- confirm.
    normal_dist = (pred_normal * gt_patch_pts).sum(2).pow(2)
    oth_loss = 100 * ((normal_dist + project_dist) * predweight).sum(1).mean()

    dist = oth_loss + max_dist
    return dist, oth_loss, max_dist
582 |
def compute_loss1(pred_point, pred_normal, gt_patch_pts, gt_patch_normals, predweight, gt_normal, support_radius, support_angle, alpha):
    """Nearest/farthest distance loss + normal regression + weighted orthogonality.

    With r_j = pred_point - p_j, gt normals n_j, and predicted normal m:

    - ``loss1 = 100 * mean_b(alpha * min_j ||r_j||^2 + (1 - alpha) * max_j ||r_j||^2)``
    - ``loss2 = 10 * mean_b ||m - gt_normal||^2``
    - ``oth_loss = 10 * mean_b sum_j w_j * ((n_j . r_j)^2 + (m . r_j)^2)``

    ``support_radius`` and ``support_angle`` are unused and kept for
    signature compatibility.

    Returns:
        ``(dist, loss1, loss2, oth_loss)``, where
        ``dist = 0.5 * loss1 + 0.5 * loss2 + oth_loss``.
    """
    offset = pred_point.unsqueeze(1) - gt_patch_pts  # [B,1,3] - [B,N,3] -> [B,N,3]
    sq = (offset ** 2).sum(2)

    nearest_sq = torch.min(sq, 1)[0]
    farthest_sq = torch.max(sq, 1)[0]
    loss1 = 100 * torch.mean((alpha * nearest_sq) + (1 - alpha) * farthest_sq)

    target_normal = gt_normal.squeeze(1)
    loss2 = 10 * (pred_normal - target_normal).pow(2).sum(1).mean()

    gt_plane = (gt_patch_normals * offset).sum(2).pow(2)  # [B,N]
    pred_plane = (pred_normal.unsqueeze(1) * offset).sum(2).pow(2)
    oth_loss = 10 * ((pred_plane + gt_plane) * predweight).sum(1).mean()

    total = 0.5 * loss1 + 0.5 * loss2 + oth_loss
    return total, loss1, loss2, oth_loss
def compute_ditstance(gt_patch_normals, gt_patch_pts, pred_point):
    """Signed distance of pred_point from each ground-truth tangent plane.

    Returns ``sum_k n_jk * (x_k - p_jk)`` reduced over dim 2, i.e. one value
    per (batch, point). pred_point must broadcast against gt_patch_pts.
    """
    offset = pred_point - gt_patch_pts
    return torch.sum(gt_patch_normals * offset, dim=2)
def comtrative_loss(pred_point, pred_normal, gt_patch_pts, gt_patch_normals, topidx, gt_normal, support_radius, support_angle, alpha, margin=0.8):
    """Contrastive, bilaterally weighted point-to-plane loss plus normal regression.

    Patch points listed in ``topidx`` are positives: their weighted squared
    signed plane distance is penalised. All other points are negatives: they
    are penalised only when the signed distance is below ``margin``.

    Args:
        pred_point: predicted point, (B, 3).
        pred_normal: predicted normal, (B, 3).
        gt_patch_pts: patch points, (B, N, 3).
        gt_patch_normals: patch normals, (B, N, 3).
        topidx: (B, K) integer indices of the positive points in each patch.
        gt_normal: ground-truth normal; a singleton dim 1 is squeezed away.
        support_radius: spatial bandwidth; must broadcast against (B, N).
        support_angle: angular bandwidth in radians.
        alpha: unused; kept for signature compatibility.
        margin: hinge margin for negatives (previously hard-coded as 0.8).

    Returns:
        ``(loss, loss1, loss2)``, where ``loss1 = 1000 * mean(contrastive terms)``,
        ``loss2`` is the mean squared normal error, and
        ``loss = 10 * (loss1 + loss2)``.
    """
    B, N, _ = gt_patch_pts.size()
    # Follow the inputs' device and batch size. The previous version
    # hard-coded 'cuda' and a batch of 64.
    device = gt_patch_pts.device

    # label[b, j] = 1 for the points listed in topidx[b], 0 elsewhere.
    label = torch.zeros(B, N, device=device, dtype=gt_patch_pts.dtype)
    label.scatter_(1, topidx.to(device=device, dtype=torch.long), 1.0)

    offset = pred_point.unsqueeze(1).repeat(1, N, 1) - gt_patch_pts  # [B,N,3]
    dist_square = (offset ** 2).sum(2)
    weight_theta = torch.exp(-1 * dist_square / (support_radius ** 2))

    # Normal of the gt point nearest to the prediction, broadcast over the patch.
    nearest_idx = torch.argmin(dist_square, dim=1)
    nearest_normal = gt_patch_normals[torch.arange(B, device=device), nearest_idx].unsqueeze(1)
    normal_proj_dist = (nearest_normal * gt_patch_normals).sum(2)
    weight_phi = torch.exp(-1 * ((1 - normal_proj_dist) / (1 - np.cos(support_angle))) ** 2)

    # The epsilon avoids division by zero when every weight underflows.
    weight = weight_theta * weight_phi + 1e-12
    weight = weight / weight.sum(1, keepdim=True)

    signed_dist = (gt_patch_normals * offset).sum(2)  # computed once, used twice
    positive = label * (weight * signed_dist.pow(2))
    negative = (1 - label) * (weight * torch.clamp(margin - signed_dist, min=0.0).pow(2))
    loss1 = 1000 * (positive + negative).mean()

    loss2 = (pred_normal - gt_normal.squeeze(1)).pow(2).sum(1).mean()
    loss = 10 * (loss1 + loss2)
    return loss, loss1, loss2
643 |
644 |
if __name__ == '__main__':
    # Smoke test: evaluate compute_loss on random CPU tensors.
    pred_normal = torch.rand(64, 3)
    pred_point = torch.rand(64, 3)
    gt_normal = torch.rand(64, 3)
    gt_patch_pts = torch.rand(64, 512, 3)
    gt_patch_normals = torch.rand(64, 512, 3)
    support_radius = torch.rand(64, 1)
    support_angle = 0.23898
    alpha = 0.97
    predweight = torch.rand(64, 512)
    # compute_loss takes nine positional arguments. The previous call omitted
    # support_radius and support_angle and raised a TypeError.
    dist, oth_loss, max_dist = compute_loss(
        pred_point, pred_normal, gt_patch_pts, gt_patch_normals, predweight,
        gt_normal, support_radius, support_angle, alpha,
    )
    print(f'loss: {dist.item():.6f}  oth_loss: {oth_loss.item():.6f}  max_dist: {max_dist.item():.6f}')
--------------------------------------------------------------------------------