├── README.md
├── imgs
│   ├── architecture.png
│   └── attention.png
├── misc
│   └── ops.py
└── networks
    ├── cls
    │   └── pct.py
    └── seg
        └── pct_partseg.py
/README.md:
--------------------------------------------------------------------------------
1 | # PCT: Point Cloud Transformer
2 |
3 | This is a Jittor implementation of PCT: Point Cloud Transformer.
4 |
5 | Paper link: https://arxiv.org/pdf/2012.09688.pdf
6 |
7 |
8 | ## News :
9 |
10 | * 2021.3.31 : We added a simple position embedding to each self-attention layer, which gives a more stable training process and 93.3% accuracy (best of 5 runs) on the ModelNet40 dataset. The classification network code has been updated accordingly.
11 | * 2021.3.29 : PCT has been accepted by Computational Visual Media Journal (CVMJ).
12 |
13 |
14 | ## Abstract
15 |
16 |
17 | The irregular domain and lack of ordering make it challenging to design deep neural networks for point cloud processing. This paper presents a novel framework named Point Cloud Transformer (PCT) for point cloud learning. PCT is based on Transformer, which achieves huge success in natural language processing and displays great potential in image processing. It is inherently permutation invariant for processing a sequence of points, making it well-suited for point cloud learning. To better capture local context within the point cloud, we enhance input embedding with the support of farthest point sampling and nearest neighbor search. Extensive experiments demonstrate that the PCT achieves state-of-the-art performance on shape classification, part segmentation and normal estimation tasks.
18 |
19 |
20 | ![attention](./imgs/attention.png)
21 |
22 |
23 | ## Architecture
24 |
25 |
26 | ![architecture](./imgs/architecture.png)
27 |
28 |
29 |
30 | ## Jittor
31 |
32 | Jittor is a high-performance deep learning framework which is easy to learn and use. It provides interfaces like Pytorch.
33 |
34 | You can learn how to use Jittor from the following links:
35 |
36 | Jittor homepage: https://cg.cs.tsinghua.edu.cn/jittor/
37 |
38 | Jittor github: https://github.com/Jittor/jittor
39 |
40 | If you have any questions about Jittor, you can ask in the Jittor developer QQ Group: 761222083
41 |
42 | ## Other implementation
43 |
44 | ##### Version 1 : https://github.com/Strawberry-Eat-Mango/PCT_Pytorch (Pytorch version with classification acc 93.2% on ModelNet40)
45 | ##### Version 2 : https://github.com/qq456cvb/Point-Transformers (Pytorch version with classification acc 92.6% on ModelNet40)
46 | #### About part segmentation: if you want to reproduce the part segmentation results, you can refer to this repository: https://github.com/AnTao97/dgcnn.pytorch
47 |
48 |
49 |
50 |
51 |
53 | ## Citation
54 |
55 | If it is helpful for your work, please cite this paper:
56 | ```
57 | @article{Guo_2021,
58 | title={PCT: Point cloud transformer},
59 | volume={7},
60 | ISSN={2096-0662},
61 | url={http://dx.doi.org/10.1007/s41095-021-0229-5},
62 | DOI={10.1007/s41095-021-0229-5},
63 | number={2},
64 | journal={Computational Visual Media},
65 | publisher={Springer Science and Business Media LLC},
66 | author={Guo, Meng-Hao and Cai, Jun-Xiong and Liu, Zheng-Ning and Mu, Tai-Jiang and Martin, Ralph R. and Hu, Shi-Min},
67 | year={2021},
68 | month={Apr},
69 | pages={187–199}
70 | }
71 | ```
72 |
--------------------------------------------------------------------------------
/imgs/architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MenghaoGuo/PCT/b7152d6ccfe15c7d2097d3200a8591d73bcd442a/imgs/architecture.png
--------------------------------------------------------------------------------
/imgs/attention.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/MenghaoGuo/PCT/b7152d6ccfe15c7d2097d3200a8591d73bcd442a/imgs/attention.png
--------------------------------------------------------------------------------
/misc/ops.py:
--------------------------------------------------------------------------------
1 | import math
2 | import numpy as np
3 | import jittor as jt
4 | from jittor import nn
5 | from jittor.contrib import concat
6 |
def index_points(points, idx):
    """
    Gather rows of `points` along its second axis, batch by batch.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]; extra trailing index dimensions
            (e.g. [B, S, K]) are handled the same way
    Return:
        new_points: indexed points data, [B, S, C] (idx.shape followed by C)
    """
    # NOTE: this module defines `index_points` a second time further down;
    # that later definition is the one bound to the name after import.
    batch = points.shape[0]
    idx_shape = list(idx.shape)
    # [idx.shape[0], 1, ..., 1]: one batch id per sample.
    column_shape = idx_shape[:1] + [1] * (len(idx_shape) - 1)
    # [1, S, ...]: tile the batch ids over every index position.
    tiling = [1] + idx_shape[1:]
    batch_ids = jt.array(np.arange(batch, dtype=np.int32))
    batch_ids = batch_ids.view(column_shape).repeat(tiling)
    return points[batch_ids, idx, :]
23 |
24 |
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Relies on the expansion
        |s - d|^2 = sum(s^2) + sum(d^2) - 2 * s^T d
    so the cross term is a single batched matrix product.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    # NOTE: this module defines `square_distance` a second time further down;
    # that later definition is the one bound to the name after import.
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    dst_t = dst.permute(0, 2, 1)  # [B, C, M]
    dist = -2 * nn.matmul(src, dst_t)
    dist += jt.sum(src ** 2, -1).view(batch, n_src, 1)
    dist += jt.sum(dst ** 2, -1).view(batch, 1, n_dst)
    return dist
47 |
48 |
class PointNetFeaturePropagation(nn.Module):
    """
    Interpolate features from a sparse point set onto a denser one, then
    refine them with a stack of Conv1d(kernel 1) + BatchNorm1d + ReLU layers.
    """

    def __init__(self, in_channel, mlp):
        """
        in_channel: channel count fed to the first convolution
        mlp: iterable with the output width of each successive layer
        """
        super(PointNetFeaturePropagation, self).__init__()
        self.mlp_convs = nn.ModuleList()
        self.mlp_bns = nn.ModuleList()
        width = in_channel
        self.relu = nn.ReLU()
        for next_width in mlp:
            self.mlp_convs.append(nn.Conv1d(width, next_width, 1))
            self.mlp_bns.append(nn.BatchNorm1d(next_width))
            width = next_width

    def execute(self, xyz1, xyz2, points1, points2):
        """
        All inputs are used channel-last (the permutes that would accept
        channel-first tensors are not applied).

        Input:
            xyz1: dense point positions, [B, N, C]
            xyz2: sparse point positions, [B, S, C]
            points1: features on the dense points, [B, N, D1], or None
            points2: features on the sparse points, [B, S, D2]
        Return:
            new_points: upsampled features, [B, N, D']
        """
        B, N, _ = xyz1.shape
        _, S, _ = xyz2.shape

        if S != 1:
            # Inverse-distance weighting over the three nearest sparse points.
            pairwise = square_distance(xyz1, xyz2)
            order, ordered = jt.argsort(pairwise, dim=-1)
            nearest_d = ordered[:, :, :3]  # [B, N, 3]
            nearest_i = order[:, :, :3]    # [B, N, 3]

            inverse = 1.0 / (nearest_d + 1e-8)
            total = jt.sum(inverse, dim=2, keepdims=True)
            weight = inverse / total
            gathered = index_points(points2, nearest_i)
            interpolated_points = jt.sum(gathered * weight.view(B, N, 3, 1), dim=2)
        else:
            # A single sparse point: broadcast its feature to every dense point.
            interpolated_points = points2.repeat(1, N, 1)

        if points1 is None:
            new_points = interpolated_points
        else:
            new_points = concat([points1, interpolated_points], dim=-1)

        # Conv1d expects channel-first input.
        new_points = new_points.permute(0, 2, 1)
        for key, conv in self.mlp_convs.layers.items():
            norm = self.mlp_bns[key]
            new_points = self.relu(norm(conv(new_points)))
        return new_points.permute(0, 2, 1)
103 |
104 |
def optimal_block(batch_size):
    """
    Pick a CUDA block size (threads per block) for `batch_size` work items.

    Returns the largest power of two that does not exceed `batch_size`,
    capped at 512. The previous implementation used the natural logarithm
    (`math.log`) where a base-2 logarithm was intended, so it returned far
    fewer threads than requested (e.g. 4 for an input of 16).

    The cap exists because the tree reduction in
    `furthest_point_sampling_kernel` only has stages for block sizes up to 512.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
    """
    max_block = 512
    count = int(batch_size)
    if count < 1:
        raise ValueError("batch_size must be a positive integer")
    # bit_length() - 1 is floor(log2(count)), computed exactly on integers.
    return min(1 << (count.bit_length() - 1), max_block)
107 |
108 |
class FurthestPointSampler(nn.Module):
    """
    Furthest point sampling implemented as an inline CUDA kernel.

    One CUDA block handles one batch element. Starting from point 0, each
    iteration selects the point whose distance to the already selected set
    is largest, using a shared-memory tree reduction across the block.
    """

    # `#block_size` is substituted with a concrete integer in `execute`.
    cuda_src='''
        __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i,
                                 int idx1, int idx2) {
            const float v1 = dists[idx1], v2 = dists[idx2];
            const int i1 = dists_i[idx1], i2 = dists_i[idx2];
            dists[idx1] = max(v1, v2);
            dists_i[idx1] = v2 > v1 ? i2 : i1;
        }

        __global__ void furthest_point_sampling_kernel (
            int b, int n, int m, int block_size,
            const float *__restrict__ dataset,
            float *__restrict__ temp,
            int *__restrict__ idxs) {

            if (m <= 0) return;

            extern __shared__ int dists_i[];
            float *dists = (float *) &dists_i[block_size];

            int batch_index = blockIdx.x;
            dataset += batch_index * n * 3;
            temp += batch_index * n;
            idxs += batch_index * m;

            int tid = threadIdx.x;
            const int stride = block_size;

            int old = 0;
            if (threadIdx.x == 0) idxs[0] = old;

            // initialize temp with INF
            for (int k = tid; k < n; k += stride)
                temp[k] = 1e10;

            __syncthreads();
            for (int j = 1; j < m; j++) {
                int besti = 0;
                float best = -1;
                float x1 = dataset[old * 3 + 0];
                float y1 = dataset[old * 3 + 1];
                float z1 = dataset[old * 3 + 2];
                for (int k = tid; k < n; k += stride) {
                    float x2, y2, z2;
                    x2 = dataset[k * 3 + 0];
                    y2 = dataset[k * 3 + 1];
                    z2 = dataset[k * 3 + 2];
                    float mag = (x2 * x2) + (y2 * y2) + (z2 * z2);
                    if (mag <= 1e-3) continue;

                    float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1);

                    float d2 = min(d, temp[k]);
                    temp[k] = d2;
                    besti = d2 > best ? k : besti;
                    best = d2 > best ? d2 : best;
                }
                dists[tid] = best;
                dists_i[tid] = besti;
                __syncthreads();

                if (block_size >= 512) {
                    if (tid < 256) {
                        __update(dists, dists_i, tid, tid + 256);
                    }
                    __syncthreads();
                }
                if (block_size >= 256) {
                    if (tid < 128) {
                        __update(dists, dists_i, tid, tid + 128);
                    }
                    __syncthreads();
                }
                if (block_size >= 128) {
                    if (tid < 64) {
                        __update(dists, dists_i, tid, tid + 64);
                    }
                    __syncthreads();
                }
                if (block_size >= 64) {
                    if (tid < 32) {
                        __update(dists, dists_i, tid, tid + 32);
                    }
                    __syncthreads();
                }
                if (block_size >= 32) {
                    if (tid < 16) {
                        __update(dists, dists_i, tid, tid + 16);
                    }
                    __syncthreads();
                }
                if (block_size >= 16) {
                    if (tid < 8) {
                        __update(dists, dists_i, tid, tid + 8);
                    }
                    __syncthreads();
                }
                if (block_size >= 8) {
                    if (tid < 4) {
                        __update(dists, dists_i, tid, tid + 4);
                    }
                    __syncthreads();
                }
                if (block_size >= 4) {
                    if (tid < 2) {
                        __update(dists, dists_i, tid, tid + 2);
                    }
                    __syncthreads();
                }
                if (block_size >= 2) {
                    if (tid < 1) {
                        __update(dists, dists_i, tid, tid + 1);
                    }
                    __syncthreads();
                }

                old = dists_i[0];
                if (tid == 0) idxs[j] = old;
                // Every thread must have read dists_i[0] before thread 0
                // overwrites it in the next iteration.
                __syncthreads();
            }
        }

        int block_size = #block_size;

        float *temp;
        cudaMallocManaged(&temp, in0_shape0 * in0_shape1 * sizeof(float));

        // One block per batch element; dynamic shared memory holds
        // block_size ints (dists_i) followed by block_size floats (dists).
        furthest_point_sampling_kernel<<<in0_shape0, block_size, block_size * (sizeof(int) + sizeof(float))>>>(
            in0_shape0,
            in0_shape1,
            out_shape1,
            block_size,
            in0_p,
            temp,
            out_p
        );
        cudaDeviceSynchronize();
        cudaFree(temp);
    '''

    def __init__(self, n_samples):
        """
        n_samples: number of points to select from every cloud
        """
        super().__init__()
        self.n_samples = n_samples

    def execute(self, x):
        '''
        Parameters
        ----------
        x: jt.Var, (B, N, 3), float32

        Returns
        -------
        y: jt.Var, (B, n_samples, 3)
            coordinates of the sampled points
        idxs: jt.Var, (B, n_samples), int32
            indices of the sampled points along the N axis of x
        '''
        batch_size, n_points, n_coords = x.shape

        assert self.n_samples <= n_points
        assert n_coords == 3
        assert x.dtype == 'float32'

        block_size = optimal_block(batch_size)

        cuda_src = self.cuda_src.replace('#block_size', str(block_size))

        idxs_shape = [batch_size, self.n_samples]
        idxs = jt.code(idxs_shape, 'int32', [x,], cuda_src=cuda_src)

        # Gather the selected coordinates: y[b, s, c] = x[b, idxs[b, s], c].
        y = x.reindex([batch_size, self.n_samples, 3], [
            'i0',               # Batchid
            '@e0(i0, i1)',      # Nid
            'i2'
        ], extras=[idxs])

        return y, idxs
282 |
class BallQueryGrouper(nn.Module):
    """
    Ball-query grouping implemented as an inline CUDA kernel.

    For every query point, collects up to `n_samples` points of the point
    set whose squared distance is below `radius ** 2`. Unfilled slots repeat
    the first neighbour that was found.
    """

    # `#block_size`, `#radius` and `#nsample` are substituted in `execute`.
    cuda_src = '''
        __global__ void query_ball_point_kernel(int b, int n, int m, float radius,
                                                int nsample,
                                                const float *__restrict__ new_xyz,
                                                const float *__restrict__ xyz,
                                                int *__restrict__ idx,
                                                int *__restrict__ cnt) {
            int batch_index = blockIdx.x;
            xyz += batch_index * n * 3;
            new_xyz += batch_index * m * 3;
            idx += m * nsample * batch_index;
            cnt += batch_index * m;

            int index = threadIdx.x;
            int stride = blockDim.x;

            float radius2 = radius * radius;
            for (int j = index; j < m; j += stride) {
                float new_x = new_xyz[j * 3 + 0];
                float new_y = new_xyz[j * 3 + 1];
                float new_z = new_xyz[j * 3 + 2];
                cnt[j] = 0;

                for (int k = 0; k < n && cnt[j] < nsample; ++k) {
                    float x = xyz[k * 3 + 0];
                    float y = xyz[k * 3 + 1];
                    float z = xyz[k * 3 + 2];
                    float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) +
                               (new_z - z) * (new_z - z);

                    if (d2 < radius2) {
                        if (cnt[j] == 0) {
                            for (int l = 0; l < nsample; ++l)
                                idx[j * nsample + l] = k;
                        }
                        idx[j * nsample + cnt[j]] = k;
                        ++cnt[j];
                    }
                }
            }
        }

        int block_size = #block_size;

        // One block per batch element; threads stride over the query points.
        query_ball_point_kernel<<<in0_shape0, block_size>>>(
            in0_shape0, in1_shape1, in0_shape1, #radius, #nsample,
            in0_p, in1_p, out0_p, out1_p
        );
    '''

    def __init__(self, radius, n_samples, use_xyz):
        """
        radius: ball radius used for the neighbourhood test
        n_samples: number of neighbour slots per query point
        use_xyz: if true, prepend centre-relative coordinates to the features
        """
        super().__init__()
        self.radius = radius
        self.n_samples = n_samples
        self.use_xyz = use_xyz

    def execute(self, new_xyz, pointset, feature):
        '''
        Parameters
        ----------
        new_xyz: jt.Var, (B, n_input, 3)
            query (centre) points
        pointset: jt.Var, (B, n_points, 3)
            points to search in
        feature: jt.Var, (B, n_points, C), or None
            per-point features of `pointset`

        Returns
        -------
        new_feature: jt.Var, (B, n_input, n_samples, C')
            grouped features; when `use_xyz` is set, the centre-relative
            coordinates come first (C' = 3 + C, or 3 without features).
            None when there are no features and `use_xyz` is false.
        '''
        batch_size_x, n_input, n_coords = new_xyz.shape
        assert n_coords == 3

        batch_size_p, n_points, n_coords = pointset.shape
        assert n_coords == 3
        assert batch_size_x == batch_size_p

        if feature is not None:
            batch_size_f, n_points_f, n_feature = feature.shape
            assert batch_size_x == batch_size_f
            assert n_points == n_points_f

        block_size = optimal_block(batch_size_x)

        cuda_src = self.cuda_src.replace('#block_size', str(block_size)) \
                                .replace('#radius', str(self.radius)) \
                                .replace('#nsample', str(self.n_samples))

        idxs_shape = [batch_size_x, n_input, self.n_samples]
        cnts_shape = [batch_size_x, n_input]
        # cnts (neighbours found per query) is produced by the kernel but
        # not used afterwards.
        idxs, cnts = jt.code(
            [idxs_shape, cnts_shape],
            ['int32', 'int32'],
            [new_xyz, pointset],
            cuda_src=cuda_src
        )

        # new_pointset[b, i, s, c] = pointset[b, idxs[b, i, s], c]
        pc_shape = [batch_size_x, n_input, self.n_samples, 3]
        new_pointset = pointset.reindex(pc_shape, [
            'i0',
            '@e0(i0, i1, i2)',
            'i3',
        ], extras=[idxs])

        if feature is not None:
            feature_shape = [batch_size_x, n_input, self.n_samples, n_feature]
            new_feature = feature.reindex(feature_shape, [
                'i0',               # Batchid
                '@e0(i0, i1, i2)',  # Nid
                'i3',               # Featureid
            ], extras=[idxs])
        else:
            new_feature = None

        if self.use_xyz:
            # Coordinates relative to each query centre.
            local_xyz = new_pointset - new_xyz.unsqueeze(dim=2)
            if new_feature is not None:
                new_feature = jt.contrib.concat([local_xyz, new_feature], dim=-1)
            else:
                new_feature = local_xyz

        return new_feature
402 |
403 |
class GroupAll(nn.Module):
    """Group every point of a cloud into one single neighbourhood."""

    def __init__(self, use_xyz):
        """
        use_xyz: if true, prepend the point coordinates to the features
        """
        super().__init__()
        self.use_xyz = use_xyz

    def execute(self, new_xyz, pointset, feature):
        """
        Input:
            new_xyz: unused; kept so the signature matches the other groupers
            pointset: point coordinates, [B, N, 3]
            feature: per-point features, [B, N, C], or None
        Return:
            new_feature: [B, 1, N, C'] where C' is 3 + C with `use_xyz`,
                C without it, and 3 when `feature` is None
        """
        if feature is None:
            # Nothing to concatenate: the coordinates are the only features.
            new_feature = pointset
        elif self.use_xyz:
            new_feature = jt.contrib.concat([pointset, feature], dim=-1)
        else:
            # Previously this path left `new_feature` unbound.
            new_feature = feature
        new_feature = new_feature.unsqueeze(dim=1)  # [B, 1, N, C]
        return new_feature
414 |
415 |
class KNN(nn.Module):
    """
    Brute-force k-nearest-neighbour search backed by inline CUDA code:
    a tiled squared-distance kernel followed by a per-query partial
    insertion sort that keeps the k smallest distances.
    """

    def __init__(self, k):
        # k: number of neighbours returned per query point.
        self.k = k
        # NOTE(review): the CUDA text below looks truncated in two places:
        # the loop header inside modified_insertion_sort, and everything
        # between the start of compute_sqrt and the compute_distances launch
        # (which would include the knn_cuda_global signature). The kernel
        # launch configurations also read `<<>>`. It presumably cannot
        # compile as written; restore it from the upstream source before use.
        self.cuda_inc= """
        #undef out
        #include "helper_cuda.h"

        __global__ void compute_distances(float * ref,
                                        int ref_width,
                                        int ref_pitch,
                                        float * query,
                                        int query_width,
                                        int query_pitch,
                                        int height,
                                        float * dist) {

            // Declaration of the shared memory arrays As and Bs used to store the sub-matrix of A and B
            const int BLOCK_DIM = 16;
            __shared__ float shared_A[BLOCK_DIM][BLOCK_DIM];
            __shared__ float shared_B[BLOCK_DIM][BLOCK_DIM];

            // Sub-matrix of A (begin, step, end) and Sub-matrix of B (begin, step)
            __shared__ int begin_A;
            __shared__ int begin_B;
            __shared__ int step_A;
            __shared__ int step_B;
            __shared__ int end_A;

            // Thread index
            int tx = threadIdx.x;
            int ty = threadIdx.y;
            int batch_id = blockIdx.z;

            // Initializarion of the SSD for the current thread
            float ssd = 0.f;

            // Loop parameters
            begin_A = BLOCK_DIM * blockIdx.y;
            begin_B = BLOCK_DIM * blockIdx.x;
            step_A = BLOCK_DIM * ref_pitch;
            step_B = BLOCK_DIM * query_pitch;
            end_A = begin_A + (height-1) * ref_pitch;

            // Conditions
            int cond0 = (begin_A + tx < ref_width); // used to write in shared memory
            int cond1 = (begin_B + tx < query_width); // used to write in shared memory & to computations and to write in output array
            int cond2 = (begin_A + ty < ref_width); // used to computations and to write in output matrix

            // Loop over all the sub-matrices of A and B required to compute the block sub-matrix
            for (int a = begin_A, b = begin_B; a <= end_A; a += step_A, b += step_B) {

                // Load the matrices from device memory to shared memory; each thread loads one element of each matrix
                if (a/ref_pitch + ty < height) {
                    shared_A[ty][tx] = (cond0)? ref[a + ref_pitch * ty + tx + batch_id * height * ref_pitch] : 0;
                    shared_B[ty][tx] = (cond1)? query[b + query_pitch * ty + tx + batch_id * height * query_pitch] : 0;
                }
                else {
                    shared_A[ty][tx] = 0;
                    shared_B[ty][tx] = 0;
                }

                // Synchronize to make sure the matrices are loaded
                __syncthreads();

                // Compute the difference between the two matrixes; each thread computes one element of the block sub-matrix
                if (cond2 && cond1) {
                    for (int k = 0; k < BLOCK_DIM; ++k){
                        float tmp = shared_A[k][ty] - shared_B[k][tx];
                        ssd += tmp*tmp;
                    }
                }

                // Synchronize to make sure that the preceeding computation is done before loading two new sub-matrices of A and B in the next iteration
                __syncthreads();
            }

            // Write the block sub-matrix to device memory; each thread writes one element
            if (cond2 && cond1) {
                dist[ (begin_A + ty) * query_pitch + begin_B + tx + batch_id * ref_pitch * query_pitch ] = ssd;
            }
        }

        __global__ void modified_insertion_sort(float * dist,
                                                int ref_pitch,
                                                int * index,
                                                int index_pitch,
                                                int width,
                                                int height,
                                                int k){

            // Column position
            unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x;
            int batch_id = blockIdx.z ;


            // Do nothing if we are out of bounds
            if (xIndex < width) {

                // Pointer shift
                float * p_dist = dist + xIndex + batch_id * ref_pitch * index_pitch;
                int * p_index = index + xIndex + batch_id * index_pitch * k;

                // Initialise the first index
                p_index[0] = 0;

                // Go through all points
                for (int i=1; i= k and if it's higher the k-th slready sorted mallest value
                    if (i >= k && curr_dist >= p_dist[(k-1)*index_pitch]) {
                        continue;
                    }

                    // Shift values (and indexes) higher that the current distance to the right
                    int j = min(i, k-1);
                    while (j > 0 && p_dist[(j-1)*index_pitch] > curr_dist) {
                        p_dist[j*index_pitch] = p_dist[(j-1)*index_pitch];
                        p_index[j*index_pitch] = p_index[(j-1)*index_pitch];
                        --j;
                    }

                    // Write the current distance and index at their position
                    p_dist[j*index_pitch] = curr_dist;
                    p_index[j*index_pitch] = curr_index;
                }
            }
        }

        __global__ void compute_sqrt(float * dist, int width, int pitch, int k){
            unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x;
            unsigned int yIndex = blockIdx.y * blockDim.y + threadIdx.y;
            int batch_id = blockIdx.z;
            if (xIndex>>(ref_dev, ref_nb, ref_pitch, query_dev, query_nb, query_pitch, dim, dist_dev);
            // checkCudaErrors(cudaDeviceSynchronize());

            // printf("%d", cudaDeviceSynchronize());
            // printf(" after compute_distances \\n");

            // Sort the distances with their respective indexes
            dim3 block1(256, 1, 1);
            dim3 grid1(query_nb / 256, 1, batch_size);
            if (query_nb % 256 != 0) grid1.x += 1;
            // printf("%d", cudaDeviceSynchronize());
            // printf(" before modified_insertion_sort \\n");
            // checkCudaErrors(cudaDeviceSynchronize());

            modified_insertion_sort<<>>(dist_dev, ref_pitch, index_dev, index_pitch, query_nb, ref_nb, k);

            // checkCudaErrors(cudaDeviceSynchronize());
            // printf("%d", cudaDeviceSynchronize());
            // printf(" after modified_insertion_sort \\n");

            // Compute the square root of the k smallest distances
            //dim3 block2(16, 16, 1);
            //dim3 grid2(query_nb / 16, k / 16, batch_size);
            //if (query_nb % 16 != 0) grid2.x += 1;
            //if (k % 16 != 0) grid2.y += 1;
            //compute_sqrt<<>>(dist_dev, query_nb, query_pitch, k);


            // Copy k smallest distances / indexes from the device to the host
            // TODO: batch 2d copy dist
            // cudaMemcpy2DAsync(knn_dist, query_nb * size_of_float, dist_dev, dist_pitch*size_of_float, query_nb * size_of_float, k, cudaMemcpyDefault);

            return true;
        }


        """
        # Host-side call: in0 is the reference set, in1 the query set, in2 the
        # scratch distance buffer, out0 the index output.
        self.cuda_src = '''
            const int k = out0_shape1;
            const int query_nb = in1_shape2;
            const int ref_nb = in0_shape2;
            const int dim = in0_shape1;
            const int batch_size = in0_shape0;
            knn_cuda_global(batch_size, in0_p, ref_nb, in1_p, query_nb, dim, k, out0_p, in2_p);
        '''

    def execute(self, x_q, x_r): # n_points, c_dim
        """
        Input:
            x_q: query points, [B, c_dim, q_points]
            x_r: reference points, [B, c_dim, r_points]
        Return:
            idxs: int32 indices with shape [B, k, q_points]
        """
        batch_size, c_dim, q_points = x_q.shape
        batch_size, c_dim, r_points = x_r.shape
        out_idx_shapes = [batch_size, self.k, q_points]
        # Scratch buffer for the full reference-by-query distance table.
        tmp_dist = jt.empty((batch_size, r_points, q_points), "float32")
        idxs, = jt.code(
            [out_idx_shapes],
            ['int32'],
            [x_r, x_q, tmp_dist], # in0 r point in1 q point
            cuda_src=self.cuda_src,
            cuda_header=self.cuda_inc,
        )
        return idxs
658 |
659 |
660 |
def topk(input, k, dim=None, largest=True, sorted=True):
    """
    Return the k largest (or smallest) entries of `input` along one axis.

    Input:
        input: tensor to search
        k: number of entries to keep
        dim: axis to search along; None means the last axis
        largest: True selects the largest values, False the smallest
        sorted: accepted but not used by this implementation
    Return:
        [values, indices], both with length k along `dim`
    """
    n_axes = input.ndim
    axis = -1 if dim is None else dim
    if axis < 0:
        axis += n_axes

    # Swap the requested axis with axis 0 so sorting and slicing happen on
    # the leading axis; the same permutation undoes itself afterwards.
    swap = list(range(n_axes))
    swap[0] = axis
    swap[axis] = 0

    moved = input.transpose(swap)
    order, ordered = jt.argsort(moved, dim=0, descending=largest)
    top_indices = order[:k].transpose(swap)
    top_values = ordered[:k].transpose(swap)
    return [top_values, top_indices]
677 |
678 |
def square_distance(src, dst):
    """
    Pairwise squared Euclidean distance between two batched point sets.

    Relies on the expansion
        |s - d|^2 = sum(s^2) + sum(d^2) - 2 * s^T d
    so the cross term is a single batched matrix product.

    Input:
        src: source points, [B, N, C]
        dst: target points, [B, M, C]
    Output:
        dist: per-point square distance, [B, N, M]
    """
    batch, n_src, _ = src.shape
    _, n_dst, _ = dst.shape
    dst_t = dst.permute(0, 2, 1)  # [B, C, M]
    dist = -2 * jt.matmul(src, dst_t)
    dist += jt.sum(src ** 2, -1).view(batch, n_src, 1)
    dist += jt.sum(dst ** 2, -1).view(batch, 1, n_dst)
    return dist
700 |
def index_points(points, idx):
    """
    Gather rows of `points` along its second axis, batch by batch.

    Input:
        points: input points data, [B, N, C]
        idx: sample index data, [B, S]; extra trailing index dimensions
            (e.g. [B, S, K]) are handled the same way
    Return:
        new_points: indexed points data, [B, S, C] (idx.shape followed by C)
    """
    batch = points.shape[0]
    idx_shape = list(idx.shape)
    # [idx.shape[0], 1, ..., 1]: one batch id per sample.
    column_shape = idx_shape[:1] + [1] * (len(idx_shape) - 1)
    # [1, S, ...]: tile the batch ids over every index position.
    tiling = [1] + idx_shape[1:]
    batch_ids = jt.array(np.arange(batch, dtype='l'))
    batch_ids = batch_ids.view(column_shape).repeat(tiling)
    return points[batch_ids, idx, :]
719 |
720 |
def knn_point(nsample, xyz, new_xyz):
    """
    Find, for every query point, its `nsample` nearest points in `xyz`.

    Input:
        nsample: number of neighbours to return per query point
        xyz: all points, [B, N, C]
        new_xyz: query points, [B, S, C]
    Return:
        group_idx: grouped points index, [B, S, nsample]
    """
    pairwise = square_distance(new_xyz, xyz)  # [B, S, N]
    # Smallest squared distances first.
    nearest = topk(pairwise, nsample, dim=-1, largest=False, sorted=False)
    return nearest[1]
734 |
735 |
def knn (x, k):
    """
    Indices of the k nearest neighbours of every point within `x`.

    `x` is used channel-first (the squared norm is summed over axis 1).
    The negated squared distance is ranked, so the largest entries are
    the closest points.
    """
    cross = -2 * jt.nn.bmm(x.transpose(0, 2, 1), x)
    sq_norm = jt.sum(x ** 2, dim = 1, keepdims=True)
    neg_sq_dist = -sq_norm - cross - sq_norm.transpose(0, 2, 1)
    return topk(neg_sq_dist ,k=k, dim=-1)[1]
742 |
743 |
--------------------------------------------------------------------------------
/networks/cls/pct.py:
--------------------------------------------------------------------------------
1 | import jittor as jt
2 | from jittor import nn
3 | from jittor import init
4 | from jittor.contrib import concat
5 | import numpy as np
6 | from misc.ops import FurthestPointSampler
7 | from misc.ops import knn_point, index_points
8 |
9 |
def sample_and_group(npoint, nsample, xyz, points):
    """
    Pick `npoint` centres with furthest point sampling and describe each one
    by its `nsample` nearest neighbours.

    Input:
        npoint: number of centres to sample
        nsample: neighbours gathered per centre
        xyz: point coordinates, [B, N, C]
        points: point features, [B, N, D]
    Return:
        new_xyz: centre coordinates, [B, npoint, C]
        new_points: per-neighbour features, [B, npoint, nsample, 2 * D]:
            the neighbour feature minus the centre feature, followed by
            the centre feature itself
    """
    B, _, C = xyz.shape
    S = npoint

    fps = FurthestPointSampler(npoint)
    _, centre_idx = fps(xyz)  # [B, npoint]
    new_xyz = index_points(xyz, centre_idx)
    centre_feat = index_points(points, centre_idx)

    neighbour_idx = knn_point(nsample, xyz, new_xyz)
    # Centre-relative coordinates; computed but not part of the return value.
    grouped_xyz = index_points(xyz, neighbour_idx)  # [B, npoint, nsample, C]
    grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C)

    neighbour_feat = index_points(points, neighbour_idx)
    centre_feat_4d = centre_feat.view(B, S, 1, -1)
    relative_feat = neighbour_feat - centre_feat_4d
    tiled_centre = centre_feat.view(B, S, 1, -1).repeat(1, 1, nsample, 1)
    new_points = concat([relative_feat, tiled_centre], dim=-1)
    return new_xyz, new_points
31 |
32 |
33 |
class Point_Transformer2(nn.Module):
    """
    Classification network: point-wise embedding, two sample-and-group
    stages with local feature extraction, four stacked self-attention layers
    with position embedding, global max pooling and an MLP head.
    """

    def __init__(self, output_channels=40):
        """
        output_channels: number of classes produced by the final linear layer
        """
        super(Point_Transformer2, self).__init__()
        self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(64)
        # sample_and_group doubles the channel count (64 -> 128, 128 -> 256).
        self.gather_local_0 = Local_op(in_channels=128, out_channels=128)
        self.gather_local_1 = Local_op(in_channels=256, out_channels=256)
        self.pt_last = Point_Transformer_Last()

        self.relu = nn.ReLU()
        # 1280 = 4 * 256 attention outputs + 256 local features.
        self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False),
                                       nn.BatchNorm1d(1024),
                                       nn.LeakyReLU(scale=0.2))

        self.linear1 = nn.Linear(1024, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=0.5)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=0.5)
        self.linear3 = nn.Linear(256, output_channels)

    def execute(self, x):
        """
        Input:
            x: point coordinates, [B, 3, N]
        Return:
            class scores, [B, output_channels]
        """
        batch_size, _, _ = x.size()
        xyz = x.permute(0, 2, 1)  # [B, N, 3]

        # Point-wise embedding.
        feat = self.relu(self.bn1(self.conv1(x)))     # B, D, N
        feat = self.relu(self.bn2(self.conv2(feat)))  # B, D, N
        feat = feat.permute(0, 2, 1)

        # Two rounds of down-sampling with local neighbourhood aggregation.
        new_xyz, grouped = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=feat)
        feature_0 = self.gather_local_0(grouped)
        feat = feature_0.permute(0, 2, 1)
        new_xyz, grouped = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feat)
        feature_1 = self.gather_local_1(grouped)

        # Attention layers, each with the position embedding added.
        attended = self.pt_last(feature_1, new_xyz)
        fused = concat([attended, feature_1], dim=1)
        fused = self.conv_fuse(fused)
        pooled = jt.max(fused, 2)
        pooled = pooled.view(batch_size, -1)

        # Classification head.
        out = self.relu(self.bn6(self.linear1(pooled)))
        out = self.dp1(out)
        out = self.relu(self.bn7(self.linear2(out)))
        out = self.dp2(out)
        out = self.linear3(out)

        return out
83 |
84 |
85 |
class Point_Transformer(nn.Module):
    """
    Classification network without neighbourhood grouping: point-wise
    embedding, four stacked self-attention layers, global max pooling and
    an MLP head.
    """

    def __init__(self, output_channels=40):
        """
        output_channels: number of classes produced by the final linear layer
        """
        super(Point_Transformer, self).__init__()

        self.conv1 = nn.Conv1d(3, 128, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(128, 128, kernel_size=1, bias=False)

        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(128)

        self.sa1 = SA_Layer(128)
        self.sa2 = SA_Layer(128)
        self.sa3 = SA_Layer(128)
        self.sa4 = SA_Layer(128)

        # 512 = 4 attention outputs of 128 channels each.
        self.conv_fuse = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False),
                                       nn.BatchNorm1d(1024),
                                       nn.LeakyReLU(scale=0.2))

        self.linear1 = nn.Linear(1024, 512, bias=False)
        self.bn6 = nn.BatchNorm1d(512)
        self.dp1 = nn.Dropout(p=0.5)
        self.linear2 = nn.Linear(512, 256)
        self.bn7 = nn.BatchNorm1d(256)
        self.dp2 = nn.Dropout(p=0.5)
        self.linear3 = nn.Linear(256, output_channels)

        self.relu = nn.ReLU()

    def execute(self, x):
        """
        Input:
            x: point coordinates, [B, 3, N]
        Return:
            class scores, [B, output_channels]
        """
        batch_size, _, N = x.size()

        feat = self.relu(self.bn1(self.conv1(x)))  # B, D, N
        feat = self.relu(self.bn2(self.conv2(feat)))

        # NOTE(review): the SA layers are called with the features only; this
        # requires SA_Layer.execute to treat its positional-embedding
        # argument as optional - verify against SA_Layer.
        x1 = self.sa1(feat)
        x2 = self.sa2(x1)
        x3 = self.sa3(x2)
        x4 = self.sa4(x3)

        stacked = concat((x1, x2, x3, x4), dim=1)

        fused = self.conv_fuse(stacked)
        # Global max pooling over the point axis.
        pooled = jt.max(fused, 2)
        pooled = pooled.view(batch_size, -1)

        out = self.relu(self.bn6(self.linear1(pooled)))
        out = self.dp1(out)
        out = self.relu(self.bn7(self.linear2(out)))
        out = self.dp2(out)
        out = self.linear3(out)
        return out
139 |
140 |
141 |
class Point_Transformer_Last(nn.Module):
    """
    Four stacked self-attention layers that share one learned position
    embedding; the four layer outputs are concatenated channel-wise.
    """

    def __init__(self, channels=256):
        """
        channels: feature width of the input and of every attention layer
        """
        super(Point_Transformer_Last, self).__init__()
        self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False)
        # Lifts 3-D coordinates to `channels` so they can be added to features.
        self.conv_pos = nn.Conv1d(3, channels, kernel_size=1, bias=False)

        self.bn1 = nn.BatchNorm1d(channels)

        self.sa1 = SA_Layer(channels)
        self.sa2 = SA_Layer(channels)
        self.sa3 = SA_Layer(channels)
        self.sa4 = SA_Layer(channels)

        self.relu = nn.ReLU()

    def execute(self, x, xyz):
        """
        Input:
            x: point features, [B, channels, N]
            xyz: point coordinates, [B, N, 3]
        Return:
            concatenated attention outputs, [B, 4 * channels, N]
        """
        # Position embedding. The layer is registered as `conv_pos`; the
        # previous code called a non-existent `self.pos_xyz`.
        xyz = xyz.permute(0, 2, 1)  # [B, 3, N]
        xyz = self.conv_pos(xyz)

        x = self.relu(self.bn1(self.conv1(x)))  # B, D, N

        x1 = self.sa1(x, xyz)
        x2 = self.sa2(x1, xyz)
        x3 = self.sa3(x2, xyz)
        x4 = self.sa4(x3, xyz)

        x = concat((x1, x2, x3, x4), dim=1)

        return x
177 |
class Local_op(nn.Module):
    """
    Per-neighbourhood feature extractor: two Conv1d(kernel 1) + BatchNorm1d
    + ReLU layers followed by a max over the neighbours of each centre.
    """

    def __init__(self, in_channels, out_channels):
        """
        in_channels: feature width of every neighbour
        out_channels: feature width produced per centre
        """
        super(Local_op, self).__init__()
        self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False)
        self.bn1 = nn.BatchNorm1d(out_channels)
        self.bn2 = nn.BatchNorm1d(out_channels)
        self.relu = nn.ReLU()

    def execute(self, x):
        """
        Input:
            x: grouped features, [B, npoint, nsample, D]
        Return:
            one feature vector per centre, [B, out_channels, npoint]
        """
        b, n, s, d = x.size()
        # Fold the centres into the batch axis: [B * npoint, D, nsample].
        groups = x.permute(0, 1, 3, 2).reshape(-1, d, s)
        n_groups, _, N = groups.size()

        groups = self.relu(self.bn1(self.conv1(groups)))  # B, D, N
        groups = self.relu(self.bn2(self.conv2(groups)))  # B, D, N

        # Max over the neighbours of every centre.
        pooled = jt.max(groups, 2)
        pooled = pooled.view(n_groups, -1)
        return pooled.reshape(b, n, -1).permute(0, 2, 1)
198 |
199 |
200 |
class SA_Layer(nn.Module):
    """Self-attention block that adds a position embedding to its input.

    Args:
        channels: Channel width of the input and output features; queries
            and keys are projected to ``channels // 4``.
    """

    def __init__(self, channels):
        super().__init__()
        self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        # Query and key projections share a single weight tensor.
        # NOTE(review): relies on Conv1d exposing its inner layer as `.conv`;
        # confirm against the installed Jittor version.
        self.q_conv.conv.weight = self.k_conv.conv.weight
        self.v_conv = nn.Conv1d(channels, channels, 1)
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def execute(self, x, xyz):
        """Apply one attention step.

        Args:
            x: Features laid out as (b, c, n).
            xyz: Position embedding that is added element-wise to ``x``.

        Returns:
            ``x + xyz`` plus the transformed attention offset, same layout
            as ``x``.
        """
        feats = x + xyz

        query = self.q_conv(feats).permute(0, 2, 1)  # b, n, c
        key = self.k_conv(feats)  # b, c, n
        value = self.v_conv(feats)

        # Softmax over the last axis, then renormalise by the sum over dim 1;
        # the small constant guards the division.
        scores = self.softmax(nn.bmm(query, key))  # b, n, n
        column_sum = scores.sum(dim=1, keepdims=True)
        scores = scores / (1e-9 + column_sum)

        attended = nn.bmm(value, scores)  # b, c, n

        # The difference between input and attended features is transformed
        # and added back as a residual.
        offset = self.trans_conv(feats - attended)
        residual = self.act(self.after_norm(offset))
        return feats + residual
225 |
if __name__ == '__main__':
    # Smoke test: push one random batch through the classifier and print the
    # shape of the resulting logits.
    jt.flags.use_cuda = 1

    points = init.gauss((16, 3, 1024), dtype='float32')  # B, D, N

    model = Point_Transformer()
    logits = model(points)
    print(logits.shape)
235 |
236 |
--------------------------------------------------------------------------------
/networks/seg/pct_partseg.py:
--------------------------------------------------------------------------------
1 | import jittor as jt
2 | from jittor import nn
3 | from jittor.contrib import concat
4 | import numpy as np
5 | import math
6 |
7 |
class Point_Transformer_partseg(nn.Module):
    """Point cloud transformer head for per-point part segmentation.

    Args:
        part_num: Number of part classes predicted for every point.
        num_categories: Length of the per-shape category label vector fed
            to ``execute``.  Defaults to 16, the value that was previously
            hard-coded.
    """

    def __init__(self, part_num=50, num_categories=16):
        super(Point_Transformer_partseg, self).__init__()
        self.part_num = part_num
        self.num_categories = num_categories
        self.conv1 = nn.Conv1d(3, 128, kernel_size=1, bias=False)
        self.conv2 = nn.Conv1d(128, 128, kernel_size=1, bias=False)

        self.bn1 = nn.BatchNorm1d(128)
        self.bn2 = nn.BatchNorm1d(128)

        self.sa1 = SA_Layer(128)
        self.sa2 = SA_Layer(128)
        self.sa3 = SA_Layer(128)
        self.sa4 = SA_Layer(128)

        # Fuses the four concatenated 128-channel attention outputs.
        self.conv_fuse = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False),
                                       nn.BatchNorm1d(1024),
                                       nn.LeakyReLU(scale=0.2))

        # Embeds the category label vector into 64 channels.
        self.label_conv = nn.Sequential(nn.Conv1d(num_categories, 64, kernel_size=1, bias=False),
                                        nn.BatchNorm1d(64),
                                        nn.LeakyReLU(scale=0.2))

        # Input: per-point (1024) + max (1024) + mean (1024) + label (64).
        self.convs1 = nn.Conv1d(1024 * 3 + 64, 512, 1)
        self.dp1 = nn.Dropout(0.5)
        self.convs2 = nn.Conv1d(512, 256, 1)
        self.convs3 = nn.Conv1d(256, self.part_num, 1)
        self.bns1 = nn.BatchNorm1d(512)
        self.bns2 = nn.BatchNorm1d(256)

        self.relu = nn.ReLU()

    def execute(self, x, cls_label):
        """Predict part logits for every point.

        Args:
            x: Point coordinates laid out as (B, 3, N).
            cls_label: Per-shape category label holding
                ``B * num_categories`` values; it is viewed as
                (B, num_categories, 1).  Presumably a one-hot encoding --
                confirm against the data loader.

        Returns:
            Tensor of shape (B, part_num, N) with unnormalised part scores.
        """
        batch_size, _, N = x.size()
        x = self.relu(self.bn1(self.conv1(x)))  # B, D, N
        x = self.relu(self.bn2(self.conv2(x)))
        x1 = self.sa1(x)
        x2 = self.sa2(x1)
        x3 = self.sa3(x2)
        x4 = self.sa4(x3)
        x = concat((x1, x2, x3, x4), dim=1)  # B, 512, N
        x = self.conv_fuse(x)  # B, 1024, N

        # Global descriptors: max and mean over the point axis, broadcast
        # back to every point.
        x_max = jt.max(x, 2)
        x_avg = jt.mean(x, 2)
        x_max_feature = x_max.view(batch_size, -1).unsqueeze(-1).repeat(1, 1, N)
        x_avg_feature = x_avg.view(batch_size, -1).unsqueeze(-1).repeat(1, 1, N)

        cls_label_one_hot = cls_label.view(batch_size, self.num_categories, 1)
        cls_label_feature = self.label_conv(cls_label_one_hot).repeat(1, 1, N)

        x_global_feature = concat((x_max_feature, x_avg_feature, cls_label_feature), 1)  # 1024 * 2 + 64
        x = concat((x, x_global_feature), 1)  # 1024 * 3 + 64
        x = self.relu(self.bns1(self.convs1(x)))
        x = self.dp1(x)
        x = self.relu(self.bns2(self.convs2(x)))
        x = self.convs3(x)
        return x
63 |
64 |
65 |
class SA_Layer(nn.Module):
    """Self-attention block with a residual connection.

    Args:
        channels: Channel width of the input and output features; queries
            and keys are projected to ``channels // 4``.
    """

    def __init__(self, channels):
        super().__init__()
        self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False)
        # Query and key projections share a single weight tensor.
        # NOTE(review): relies on Conv1d exposing its inner layer as `.conv`;
        # confirm against the installed Jittor version.
        self.q_conv.conv.weight = self.k_conv.conv.weight
        self.v_conv = nn.Conv1d(channels, channels, 1)
        self.trans_conv = nn.Conv1d(channels, channels, 1)
        self.after_norm = nn.BatchNorm1d(channels)
        self.act = nn.ReLU()
        self.softmax = nn.Softmax(dim=-1)

    def execute(self, x):
        """Apply one attention step.

        Args:
            x: Features laid out as (b, c, n).

        Returns:
            ``x`` plus the transformed attention offset, same layout as
            ``x``.
        """
        query = self.q_conv(x).permute(0, 2, 1)  # b, n, c
        key = self.k_conv(x)  # b, c, n
        value = self.v_conv(x)

        # Softmax over the last axis, then renormalise by the sum over dim 1;
        # the small constant guards the division.
        scores = self.softmax(nn.bmm(query, key))  # b, n, n
        column_sum = scores.sum(dim=1, keepdims=True)
        scores = scores / (1e-9 + column_sum)

        attended = nn.bmm(value, scores)  # b, c, n

        # The difference between input and attended features is transformed
        # and added back as a residual.
        offset = self.trans_conv(x - attended)
        residual = self.act(self.after_norm(offset))
        return x + residual
89 |
--------------------------------------------------------------------------------