├── README.md ├── imgs ├── architecture.png └── attention.png ├── misc └── ops.py └── networks ├── cls └── pct.py └── seg └── pct_partseg.py /README.md: -------------------------------------------------------------------------------- 1 | # PCT: Point Cloud Transformer 2 | 3 | This is a Jittor implementation of PCT: Point Cloud Transformer. 4 | 5 | Paper link: https://arxiv.org/pdf/2012.09688.pdf 6 | 7 | 8 | ## News : 9 | 10 | * 2021.3.31 : We add a simple position embedding to each self-attention layer, which gives a more stable training process and 93.3% accuracy (best of 5 runs) on the ModelNet40 dataset. The classification network code has been updated accordingly. 11 | * 2021.3.29 : PCT has been accepted by Computational Visual Media Journal (CVMJ). 12 | 13 | 14 | ## Abstract 15 | 16 | 17 | The irregular domain and lack of ordering make it challenging to design deep neural networks for point cloud processing. This paper presents a novel framework named Point Cloud Transformer (PCT) for point cloud learning. PCT is based on Transformer, which achieves huge success in natural language processing and displays great potential in image processing. It is inherently permutation invariant for processing a sequence of points, making it well-suited for point cloud learning. To better capture local context within the point cloud, we enhance input embedding with the support of farthest point sampling and nearest neighbor search. Extensive experiments demonstrate that the PCT achieves the state-of-the-art performance on shape classification, part segmentation and normal estimation tasks. 18 | 19 | 20 | ![image](https://github.com/MenghaoGuo/PCT/blob/main/imgs/attention.png) 21 | 22 | 23 | ## Architecture 24 | 25 | 26 | ![image](https://github.com/MenghaoGuo/PCT/blob/main/imgs/architecture.png) 27 | 28 | 29 | 30 | ## Jittor 31 | 32 | Jittor is a high-performance deep learning framework which is easy to learn and use. It provides PyTorch-like interfaces.
33 | 34 | You can learn how to use Jittor via the following links: 35 | 36 | Jittor homepage: https://cg.cs.tsinghua.edu.cn/jittor/ 37 | 38 | Jittor GitHub: https://github.com/Jittor/jittor 39 | 40 | If you have any questions about Jittor, you can ask in the Jittor developer QQ group: 761222083 41 | 42 | ## Other implementations 43 | 44 | ##### Version 1 : https://github.com/Strawberry-Eat-Mango/PCT_Pytorch (PyTorch version with classification accuracy 93.2% on ModelNet40) 45 | ##### Version 2 : https://github.com/qq456cvb/Point-Transformers (PyTorch version with classification accuracy 92.6% on ModelNet40) 46 | #### About part segmentation: if you want to reproduce the part segmentation results, you can refer to this implementation: https://github.com/AnTao97/dgcnn.pytorch 47 | 48 | 49 | 50 | 51 | 53 | ## Citation 54 | 55 | If this work is helpful for your research, please cite this paper: 56 | ``` 57 | @article{Guo_2021, 58 | title={PCT: Point cloud transformer}, 59 | volume={7}, 60 | ISSN={2096-0662}, 61 | url={http://dx.doi.org/10.1007/s41095-021-0229-5}, 62 | DOI={10.1007/s41095-021-0229-5}, 63 | number={2}, 64 | journal={Computational Visual Media}, 65 | publisher={Springer Science and Business Media LLC}, 66 | author={Guo, Meng-Hao and Cai, Jun-Xiong and Liu, Zheng-Ning and Mu, Tai-Jiang and Martin, Ralph R.
and Hu, Shi-Min}, 67 | year={2021}, 68 | month={Apr}, 69 | pages={187–199} 70 | } 71 | ``` 72 | -------------------------------------------------------------------------------- /imgs/architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MenghaoGuo/PCT/b7152d6ccfe15c7d2097d3200a8591d73bcd442a/imgs/architecture.png -------------------------------------------------------------------------------- /imgs/attention.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/MenghaoGuo/PCT/b7152d6ccfe15c7d2097d3200a8591d73bcd442a/imgs/attention.png -------------------------------------------------------------------------------- /misc/ops.py: -------------------------------------------------------------------------------- 1 | import math 2 | import numpy as np 3 | import jittor as jt 4 | from jittor import nn 5 | from jittor.contrib import concat 6 | 7 | def index_points(points, idx): 8 | """ 9 | Input: 10 | points: input points data, [B, N, C] 11 | idx: sample index data, [B, S] 12 | Return: 13 | new_points:, indexed points data, [B, S, C] 14 | """ 15 | B = points.shape[0] 16 | view_shape = list(idx.shape) 17 | view_shape[1:] = [1] * (len(view_shape) - 1) 18 | repeat_shape = list(idx.shape) 19 | repeat_shape[0] = 1 20 | batch_indices = jt.array(np.arange(B, dtype=np.int32)).view(view_shape).repeat(repeat_shape) 21 | new_points = points[batch_indices, idx, :] 22 | return new_points 23 | 24 | 25 | def square_distance(src, dst): 26 | """ 27 | Calculate Euclid distance between each two points. 
28 | src^T * dst = xn * xm + yn * ym + zn * zm; 29 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 30 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 31 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 32 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 33 | Input: 34 | src: source points, [B, N, C] 35 | dst: target points, [B, M, C] 36 | Output: 37 | dist: per-point square distance, [B, N, M] 38 | """ 39 | 40 | B, N, _ = src.shape 41 | _, M, _ = dst.shape 42 | # print ('before matmul size', src.size(), dst.size()) 43 | dist = -2 * nn.matmul(src, dst.permute(0, 2, 1)) 44 | dist += jt.sum(src ** 2, -1).view(B, N, 1) 45 | dist += jt.sum(dst ** 2, -1).view(B, 1, M) 46 | return dist 47 | 48 | 49 | class PointNetFeaturePropagation(nn.Module): 50 | def __init__(self, in_channel, mlp): 51 | super(PointNetFeaturePropagation, self).__init__() 52 | self.mlp_convs = nn.ModuleList() 53 | self.mlp_bns = nn.ModuleList() 54 | last_channel = in_channel 55 | self.relu = nn.ReLU() 56 | for out_channel in mlp: 57 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 58 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 59 | last_channel = out_channel 60 | 61 | def execute(self, xyz1, xyz2, points1, points2): 62 | """ 63 | Input: 64 | xyz1: input points position data, [B, C, N] 65 | xyz2: sampled input points position data, [B, C, S] 66 | points1: input points data, [B, D, N] 67 | points2: input points data, [B, D, S] 68 | Return: 69 | new_points: upsampled points data, [B, D', N] 70 | """ 71 | # xyz1 = xyz1.permute(0, 2, 1) 72 | # xyz2 = xyz2.permute(0, 2, 1) 73 | 74 | # points2 = points2.permute(0, 2, 1) 75 | B, N, C = xyz1.shape 76 | _, S, _ = xyz2.shape 77 | 78 | if S == 1: 79 | interpolated_points = points2.repeat(1, N, 1) 80 | else: 81 | dists = square_distance(xyz1, xyz2) 82 | idx, dists = jt.argsort(dists, dim=-1) 83 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 84 | 85 | dist_recip = 1.0 / (dists + 1e-8) 86 | norm = jt.sum(dist_recip, dim=2, 
keepdims=True) 87 | weight = dist_recip / norm 88 | interpolated_points = jt.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 89 | 90 | if points1 is not None: 91 | # points1 = points1.permute(0, 2, 1) 92 | new_points = concat ([points1, interpolated_points], dim=-1) 93 | else: 94 | new_points = interpolated_points 95 | 96 | new_points = new_points.permute(0, 2, 1) 97 | # l = len(self.mlp_convs) 98 | for i, conv in self.mlp_convs.layers.items(): 99 | # conv = self.mlp_convs[i] 100 | bn = self.mlp_bns[i] 101 | new_points = self.relu(bn(conv(new_points))) 102 | return new_points.permute(0, 2, 1) 103 | 104 | 105 | def optimal_block(batch_size): 106 | return 2 ** int(math.log(batch_size)) 107 | 108 | 109 | class FurthestPointSampler(nn.Module): 110 | cuda_src=''' 111 | __device__ void __update(float *__restrict__ dists, int *__restrict__ dists_i, 112 | int idx1, int idx2) { 113 | const float v1 = dists[idx1], v2 = dists[idx2]; 114 | const int i1 = dists_i[idx1], i2 = dists_i[idx2]; 115 | dists[idx1] = max(v1, v2); 116 | dists_i[idx1] = v2 > v1 ? 
i2 : i1; 117 | } 118 | 119 | __global__ void furthest_point_sampling_kernel ( 120 | int b, int n, int m, int block_size, 121 | const float *__restrict__ dataset, 122 | float *__restrict__ temp, 123 | int *__restrict__ idxs) { 124 | 125 | if (m <= 0) return; 126 | 127 | extern __shared__ int dists_i[]; 128 | float *dists = (float *) &dists_i[block_size]; 129 | 130 | int batch_index = blockIdx.x; 131 | dataset += batch_index * n * 3; 132 | temp += batch_index * n; 133 | idxs += batch_index * m; 134 | 135 | int tid = threadIdx.x; 136 | const int stride = block_size; 137 | 138 | int old = 0; 139 | if (threadIdx.x == 0) idxs[0] = old; 140 | 141 | // initialize temp with INF 142 | for (int k = tid; k < n; k += stride) 143 | temp[k] = 1e10; 144 | 145 | __syncthreads(); 146 | for (int j = 1; j < m; j++) { 147 | int besti = 0; 148 | float best = -1; 149 | float x1 = dataset[old * 3 + 0]; 150 | float y1 = dataset[old * 3 + 1]; 151 | float z1 = dataset[old * 3 + 2]; 152 | for (int k = tid; k < n; k += stride) { 153 | float x2, y2, z2; 154 | x2 = dataset[k * 3 + 0]; 155 | y2 = dataset[k * 3 + 1]; 156 | z2 = dataset[k * 3 + 2]; 157 | float mag = (x2 * x2) + (y2 * y2) + (z2 * z2); 158 | if (mag <= 1e-3) continue; 159 | 160 | float d = (x2 - x1) * (x2 - x1) + (y2 - y1) * (y2 - y1) + (z2 - z1) * (z2 - z1); 161 | 162 | float d2 = min(d, temp[k]); 163 | temp[k] = d2; 164 | besti = d2 > best ? k : besti; 165 | best = d2 > best ?
d2 : best; 166 | } 167 | dists[tid] = best; 168 | dists_i[tid] = besti; 169 | __syncthreads(); 170 | 171 | if (block_size >= 512) { 172 | if (tid < 256) { 173 | __update(dists, dists_i, tid, tid + 256); 174 | } 175 | __syncthreads(); 176 | } 177 | if (block_size >= 256) { 178 | if (tid < 128) { 179 | __update(dists, dists_i, tid, tid + 128); 180 | } 181 | __syncthreads(); 182 | } 183 | if (block_size >= 128) { 184 | if (tid < 64) { 185 | __update(dists, dists_i, tid, tid + 64); 186 | } 187 | __syncthreads(); 188 | } 189 | if (block_size >= 64) { 190 | if (tid < 32) { 191 | __update(dists, dists_i, tid, tid + 32); 192 | } 193 | __syncthreads(); 194 | } 195 | if (block_size >= 32) { 196 | if (tid < 16) { 197 | __update(dists, dists_i, tid, tid + 16); 198 | } 199 | __syncthreads(); 200 | } 201 | if (block_size >= 16) { 202 | if (tid < 8) { 203 | __update(dists, dists_i, tid, tid + 8); 204 | } 205 | __syncthreads(); 206 | } 207 | if (block_size >= 8) { 208 | if (tid < 4) { 209 | __update(dists, dists_i, tid, tid + 4); 210 | } 211 | __syncthreads(); 212 | } 213 | if (block_size >= 4) { 214 | if (tid < 2) { 215 | __update(dists, dists_i, tid, tid + 2); 216 | } 217 | __syncthreads(); 218 | } 219 | if (block_size >= 2) { 220 | if (tid < 1) { 221 | __update(dists, dists_i, tid, tid + 1); 222 | } 223 | __syncthreads(); 224 | } 225 | 226 | old = dists_i[0]; 227 | if (tid == 0) idxs[j] = old; 228 | } 229 | } 230 | 231 | int block_size = #block_size; 232 | 233 | float *temp; 234 | cudaMallocManaged(&temp, in0_shape0 * in0_shape1 * sizeof(float)); 235 | 236 | furthest_point_sampling_kernel<<>>( 237 | in0_shape0, 238 | in0_shape1, 239 | out_shape1, 240 | block_size, 241 | in0_p, 242 | temp, 243 | out_p 244 | ); 245 | cudaDeviceSynchronize(); 246 | cudaFree(temp); 247 | ''' 248 | def __init__(self, n_samples): 249 | super().__init__() 250 | self.n_samples = n_samples 251 | 252 | def execute(self, x): 253 | ''' Furthest point sampling of n_samples points per batch element, computed by the CUDA kernel in cuda_src. 254 | Parameters 255 | ---------- 256 | x: jt.Var, (B, N, 3) 257 | 258
| Returns 259 | ------- 260 | y: jt.Var, (B, n_samples, 3) sampled points; idxs: jt.Var, int32 (B, n_samples) indices of the sampled points in x 261 | ''' 262 | batch_size, n_points, n_coords = x.shape 263 | 264 | assert self.n_samples <= n_points 265 | assert n_coords == 3 266 | assert x.dtype == 'float32' 267 | 268 | block_size = optimal_block(batch_size) 269 | 270 | cuda_src = self.cuda_src.replace('#block_size', str(block_size)) 271 | 272 | idxs_shape = [batch_size, self.n_samples] 273 | idxs = jt.code(idxs_shape, 'int32', [x,], cuda_src=cuda_src) 274 | 275 | y = x.reindex([batch_size, self.n_samples, 3], [ 276 | 'i0', # Batchid 277 | '@e0(i0, i1)', # Nid 278 | 'i2' 279 | ], extras=[idxs]) 280 | 281 | return y, idxs 282 | 283 | class BallQueryGrouper(nn.Module): 284 | cuda_src = ''' 285 | __global__ void query_ball_point_kernel(int b, int n, int m, float radius, 286 | int nsample, 287 | const float *__restrict__ new_xyz, 288 | const float *__restrict__ xyz, 289 | int *__restrict__ idx, 290 | int *__restrict__ cnt) { 291 | int batch_index = blockIdx.x; 292 | xyz += batch_index * n * 3; 293 | new_xyz += batch_index * m * 3; 294 | idx += m * nsample * batch_index; 295 | cnt += batch_index * m; 296 | 297 | int index = threadIdx.x; 298 | int stride = blockDim.x; 299 | 300 | float radius2 = radius * radius; 301 | for (int j = index; j < m; j += stride) { 302 | float new_x = new_xyz[j * 3 + 0]; 303 | float new_y = new_xyz[j * 3 + 1]; 304 | float new_z = new_xyz[j * 3 + 2]; 305 | cnt[j] = 0; 306 | 307 | for (int k = 0; k < n && cnt[j] < nsample; ++k) { 308 | float x = xyz[k * 3 + 0]; 309 | float y = xyz[k * 3 + 1]; 310 | float z = xyz[k * 3 + 2]; 311 | float d2 = (new_x - x) * (new_x - x) + (new_y - y) * (new_y - y) + 312 | (new_z - z) * (new_z - z); 313 | 314 | if (d2 < radius2) { 315 | if (cnt[j] == 0) { 316 | for (int l = 0; l < nsample; ++l) 317 | idx[j * nsample + l] = k; 318 | } 319 | idx[j * nsample + cnt[j]] = k; 320 | ++cnt[j]; 321 | } 322 | } 323 | } 324 | } 325 | 326 | int block_size = #block_size; 327 | 328 | query_ball_point_kernel<<>>( 329
| in0_shape0, in1_shape1, in0_shape1, #radius, #nsample, 330 | in0_p, in1_p, out0_p, out1_p 331 | ); 332 | ''' 333 | def __init__(self, radius, n_samples, use_xyz): 334 | super().__init__() 335 | self.radius = radius 336 | self.n_samples = n_samples 337 | self.use_xyz = use_xyz 338 | 339 | def execute(self, new_xyz, pointset, feature): 340 | ''' 341 | Parameters 342 | ---------- 343 | xyz: jt.Var, (B, N, 3) 344 | features: jt.Var, (B, N, C) 345 | 346 | Returns 347 | ------- 348 | new_feature: jt.Var, (B, N, n_samples, C) 349 | ''' 350 | batch_size_x, n_input, n_coords = new_xyz.shape 351 | assert n_coords == 3 352 | 353 | batch_size_p, n_points, n_coords = pointset.shape 354 | assert n_coords == 3 355 | assert batch_size_x == batch_size_p 356 | 357 | if feature is not None: 358 | batch_size_f, n_points_f, n_feature = feature.shape 359 | assert batch_size_x == batch_size_f 360 | assert n_points == n_points_f 361 | 362 | block_size = optimal_block(batch_size_x) 363 | 364 | cuda_src = self.cuda_src.replace('#block_size', str(block_size)) \ 365 | .replace('#radius', str(self.radius)) \ 366 | .replace('#nsample', str(self.n_samples)) 367 | 368 | idxs_shape = [batch_size_x, n_input, self.n_samples] 369 | cnts_shape = [batch_size_x, n_input] 370 | idxs, cnts = jt.code( 371 | [idxs_shape, cnts_shape], 372 | ['int32', 'int32'], 373 | [new_xyz, pointset], 374 | cuda_src=cuda_src 375 | ) 376 | 377 | pc_shape = [batch_size_x, n_input, self.n_samples, 3] 378 | new_pointset = pointset.reindex(pc_shape, [ 379 | 'i0', 380 | '@e0(i0, i1, i2)', 381 | 'i3', 382 | ], extras=[idxs]) 383 | 384 | if feature is not None: 385 | feature_shape = [batch_size_x, n_input, self.n_samples, n_feature] 386 | new_feature = feature.reindex(feature_shape, [ 387 | 'i0', # Batchid 388 | '@e0(i0, i1, i2)', # Nid 389 | 'i3', # Featureid 390 | ], extras=[idxs]) 391 | else: 392 | new_feature = None 393 | 394 | if self.use_xyz: 395 | local_xyz = new_pointset - new_xyz.unsqueeze(dim=2) 396 | if new_feature is 
not None: 397 | new_feature = jt.contrib.concat([local_xyz, new_feature], dim=-1) 398 | else: 399 | new_feature = local_xyz 400 | 401 | return new_feature 402 | 403 | 404 | class GroupAll(nn.Module): 405 | def __init__(self, use_xyz): 406 | super().__init__() 407 | self.use_xyz = use_xyz 408 | 409 | def execute(self, new_xyz, pointset, feature): 410 | if self.use_xyz: 411 | new_feature = jt.contrib.concat([pointset, feature], dim=-1) 412 | new_feature = new_feature.unsqueeze(dim=1) # [B, 1, N, C] 413 | return new_feature 414 | 415 | 416 | class KNN(nn.Module): 417 | def __init__(self, k): 418 | self.k = k 419 | self.cuda_inc= """ 420 | #undef out 421 | #include "helper_cuda.h" 422 | 423 | __global__ void compute_distances(float * ref, 424 | int ref_width, 425 | int ref_pitch, 426 | float * query, 427 | int query_width, 428 | int query_pitch, 429 | int height, 430 | float * dist) { 431 | 432 | // Declaration of the shared memory arrays As and Bs used to store the sub-matrix of A and B 433 | const int BLOCK_DIM = 16; 434 | __shared__ float shared_A[BLOCK_DIM][BLOCK_DIM]; 435 | __shared__ float shared_B[BLOCK_DIM][BLOCK_DIM]; 436 | 437 | // Sub-matrix of A (begin, step, end) and Sub-matrix of B (begin, step) 438 | __shared__ int begin_A; 439 | __shared__ int begin_B; 440 | __shared__ int step_A; 441 | __shared__ int step_B; 442 | __shared__ int end_A; 443 | 444 | // Thread index 445 | int tx = threadIdx.x; 446 | int ty = threadIdx.y; 447 | int batch_id = blockIdx.z; 448 | 449 | // Initializarion of the SSD for the current thread 450 | float ssd = 0.f; 451 | 452 | // Loop parameters 453 | begin_A = BLOCK_DIM * blockIdx.y; 454 | begin_B = BLOCK_DIM * blockIdx.x; 455 | step_A = BLOCK_DIM * ref_pitch; 456 | step_B = BLOCK_DIM * query_pitch; 457 | end_A = begin_A + (height-1) * ref_pitch; 458 | 459 | // Conditions 460 | int cond0 = (begin_A + tx < ref_width); // used to write in shared memory 461 | int cond1 = (begin_B + tx < query_width); // used to write in shared memory 
& to computations and to write in output array 462 | int cond2 = (begin_A + ty < ref_width); // used to computations and to write in output matrix 463 | 464 | // Loop over all the sub-matrices of A and B required to compute the block sub-matrix 465 | for (int a = begin_A, b = begin_B; a <= end_A; a += step_A, b += step_B) { 466 | 467 | // Load the matrices from device memory to shared memory; each thread loads one element of each matrix 468 | if (a/ref_pitch + ty < height) { 469 | shared_A[ty][tx] = (cond0)? ref[a + ref_pitch * ty + tx + batch_id * height * ref_pitch] : 0; 470 | shared_B[ty][tx] = (cond1)? query[b + query_pitch * ty + tx + batch_id * height * query_pitch] : 0; 471 | } 472 | else { 473 | shared_A[ty][tx] = 0; 474 | shared_B[ty][tx] = 0; 475 | } 476 | 477 | // Synchronize to make sure the matrices are loaded 478 | __syncthreads(); 479 | 480 | // Compute the difference between the two matrixes; each thread computes one element of the block sub-matrix 481 | if (cond2 && cond1) { 482 | for (int k = 0; k < BLOCK_DIM; ++k){ 483 | float tmp = shared_A[k][ty] - shared_B[k][tx]; 484 | ssd += tmp*tmp; 485 | } 486 | } 487 | 488 | // Synchronize to make sure that the preceeding computation is done before loading two new sub-matrices of A and B in the next iteration 489 | __syncthreads(); 490 | } 491 | 492 | // Write the block sub-matrix to device memory; each thread writes one element 493 | if (cond2 && cond1) { 494 | dist[ (begin_A + ty) * query_pitch + begin_B + tx + batch_id * ref_pitch * query_pitch ] = ssd; 495 | } 496 | } 497 | 498 | __global__ void modified_insertion_sort(float * dist, 499 | int ref_pitch, 500 | int * index, 501 | int index_pitch, 502 | int width, 503 | int height, 504 | int k){ 505 | 506 | // Column position 507 | unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x; 508 | int batch_id = blockIdx.z ; 509 | 510 | 511 | // Do nothing if we are out of bounds 512 | if (xIndex < width) { 513 | 514 | // Pointer shift 515 | float * 
p_dist = dist + xIndex + batch_id * ref_pitch * index_pitch; 516 | int * p_index = index + xIndex + batch_id * index_pitch * k; 517 | 518 | // Initialise the first index 519 | p_index[0] = 0; 520 | 521 | // Go through all points 522 | for (int i=1; i= k and if it's higher the k-th slready sorted mallest value 529 | if (i >= k && curr_dist >= p_dist[(k-1)*index_pitch]) { 530 | continue; 531 | } 532 | 533 | // Shift values (and indexes) higher that the current distance to the right 534 | int j = min(i, k-1); 535 | while (j > 0 && p_dist[(j-1)*index_pitch] > curr_dist) { 536 | p_dist[j*index_pitch] = p_dist[(j-1)*index_pitch]; 537 | p_index[j*index_pitch] = p_index[(j-1)*index_pitch]; 538 | --j; 539 | } 540 | 541 | // Write the current distance and index at their position 542 | p_dist[j*index_pitch] = curr_dist; 543 | p_index[j*index_pitch] = curr_index; 544 | } 545 | } 546 | } 547 | 548 | __global__ void compute_sqrt(float * dist, int width, int pitch, int k){ 549 | unsigned int xIndex = blockIdx.x * blockDim.x + threadIdx.x; 550 | unsigned int yIndex = blockIdx.y * blockDim.y + threadIdx.y; 551 | int batch_id = blockIdx.z; 552 | if (xIndex>>(ref_dev, ref_nb, ref_pitch, query_dev, query_nb, query_pitch, dim, dist_dev); 600 | // checkCudaErrors(cudaDeviceSynchronize()); 601 | 602 | // printf("%d", cudaDeviceSynchronize()); 603 | // printf(" after compute_distances \\n"); 604 | 605 | // Sort the distances with their respective indexes 606 | dim3 block1(256, 1, 1); 607 | dim3 grid1(query_nb / 256, 1, batch_size); 608 | if (query_nb % 256 != 0) grid1.x += 1; 609 | // printf("%d", cudaDeviceSynchronize()); 610 | // printf(" before modified_insertion_sort \\n"); 611 | // checkCudaErrors(cudaDeviceSynchronize()); 612 | 613 | modified_insertion_sort<<>>(dist_dev, ref_pitch, index_dev, index_pitch, query_nb, ref_nb, k); 614 | 615 | // checkCudaErrors(cudaDeviceSynchronize()); 616 | // printf("%d", cudaDeviceSynchronize()); 617 | // printf(" after modified_insertion_sort 
\\n"); 618 | 619 | // Compute the square root of the k smallest distances 620 | //dim3 block2(16, 16, 1); 621 | //dim3 grid2(query_nb / 16, k / 16, batch_size); 622 | //if (query_nb % 16 != 0) grid2.x += 1; 623 | //if (k % 16 != 0) grid2.y += 1; 624 | //compute_sqrt<<>>(dist_dev, query_nb, query_pitch, k); 625 | 626 | 627 | // Copy k smallest distances / indexes from the device to the host 628 | // TODO: batch 2d copy dist 629 | // cudaMemcpy2DAsync(knn_dist, query_nb * size_of_float, dist_dev, dist_pitch*size_of_float, query_nb * size_of_float, k, cudaMemcpyDefault); 630 | 631 | return true; 632 | } 633 | 634 | 635 | """ 636 | self.cuda_src = ''' 637 | const int k = out0_shape1; 638 | const int query_nb = in1_shape2; 639 | const int ref_nb = in0_shape2; 640 | const int dim = in0_shape1; 641 | const int batch_size = in0_shape0; 642 | knn_cuda_global(batch_size, in0_p, ref_nb, in1_p, query_nb, dim, k, out0_p, in2_p); 643 | ''' 644 | 645 | def execute(self, x_q, x_r): # n_points, c_dim 646 | batch_size, c_dim, q_points = x_q.shape 647 | batch_size, c_dim, r_points = x_r.shape 648 | out_idx_shapes = [batch_size, self.k, q_points] 649 | tmp_dist = jt.empty((batch_size, r_points, q_points), "float32") 650 | idxs, = jt.code( 651 | [out_idx_shapes], 652 | ['int32'], 653 | [x_r, x_q, tmp_dist], # in0 r point in1 q point 654 | cuda_src=self.cuda_src, 655 | cuda_header=self.cuda_inc, 656 | ) 657 | return idxs 658 | 659 | 660 | 661 | def topk(input, k, dim=None, largest=True, sorted=True): 662 | if dim is None: 663 | dim = -1 664 | if dim<0: 665 | dim+=input.ndim 666 | 667 | transpose_dims = [i for i in range(input.ndim)] 668 | transpose_dims[0] = dim 669 | transpose_dims[dim] = 0 670 | input = input.transpose(transpose_dims) 671 | index,values = jt.argsort(input,dim=0,descending=largest) 672 | indices = index[:k] 673 | values = values[:k] 674 | indices = indices.transpose(transpose_dims) 675 | values = values.transpose(transpose_dims) 676 | return [values,indices] 677 | 678 
| 679 | def square_distance(src, dst): 680 | """ 681 | Calculate Euclid distance between each two points. 682 | src^T * dst = xn * xm + yn * ym + zn * zm; 683 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 684 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 685 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 686 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 687 | Input: 688 | src: source points, [B, N, C] 689 | dst: target points, [B, M, C] 690 | Output: 691 | dist: per-point square distance, [B, N, M] 692 | """ 693 | # print (src.size(), dst.size()) 694 | B, N, _ = src.shape 695 | _, M, _ = dst.shape 696 | dist = -2 * jt.matmul(src, dst.permute(0, 2, 1)) 697 | dist += jt.sum(src ** 2, -1).view(B, N, 1) 698 | dist += jt.sum(dst ** 2, -1).view(B, 1, M) 699 | return dist 700 | 701 | def index_points(points, idx): 702 | """ 703 | Input: 704 | points: input points data, [B, N, C] 705 | idx: sample index data, [B, S] 706 | Return: 707 | new_points:, indexed points data, [B, S, C] 708 | """ 709 | #device = points.device 710 | B = points.shape[0] 711 | view_shape = list(idx.shape) 712 | view_shape[1:] = [1] * (len(view_shape) - 1) 713 | repeat_shape = list(idx.shape) 714 | repeat_shape[0] = 1 715 | batch_indices = np.arange(B, dtype='l') 716 | batch_indices = jt.array(batch_indices).view(view_shape).repeat(repeat_shape) 717 | new_points = points[batch_indices, idx, :] 718 | return new_points 719 | 720 | 721 | def knn_point(nsample, xyz, new_xyz): 722 | """ 723 | Input: 724 | nsample: max sample number in local region 725 | xyz: all points, [B, N, C] 726 | new_xyz: query points, [B, S, C] 727 | Return: 728 | group_idx: grouped points index, [B, S, nsample] 729 | """ 730 | # print ('new xyz size, xyz size =',new_xyz.size(), xyz.size()) 731 | sqrdists = square_distance(new_xyz, xyz) 732 | _, group_idx = topk(sqrdists, nsample, dim = -1, largest=False, sorted=False) 733 | return group_idx 734 | 735 | 736 | def knn (x, k): 737 | inner = -2 * jt.nn.bmm(x.transpose(0, 2, 
1), x) 738 | xx = jt.sum(x ** 2, dim = 1, keepdims=True) 739 | distance = -xx - inner - xx.transpose(0, 2, 1) 740 | idx = topk(distance ,k=k, dim=-1)[1] 741 | return idx 742 | 743 | -------------------------------------------------------------------------------- /networks/cls/pct.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import nn 3 | from jittor import init 4 | from jittor.contrib import concat 5 | import numpy as np 6 | from misc.ops import FurthestPointSampler 7 | from misc.ops import knn_point, index_points 8 | 9 | 10 | def sample_and_group(npoint, nsample, xyz, points): 11 | B, N, C = xyz.shape 12 | S = npoint 13 | # xyz = xyz.contiguous() 14 | sampler = FurthestPointSampler(npoint) 15 | _, fps_idx = sampler(xyz) # [B, npoint] 16 | # print ('fps size=', fps_idx.size()) 17 | # fps_idx = sampler(xyz).long() # [B, npoint] 18 | new_xyz = index_points(xyz, fps_idx) 19 | new_points = index_points(points, fps_idx) 20 | # new_xyz = xyz[:] 21 | # new_points = points[:] 22 | 23 | idx = knn_point(nsample, xyz, new_xyz) 24 | #idx = query_ball_point(radius, nsample, xyz, new_xyz) 25 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 26 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 27 | grouped_points = index_points(points, idx) 28 | grouped_points_norm = grouped_points - new_points.view(B, S, 1, -1) 29 | new_points = concat([grouped_points_norm, new_points.view(B, S, 1, -1).repeat(1, 1, nsample, 1)], dim=-1) 30 | return new_xyz, new_points 31 | 32 | 33 | 34 | class Point_Transformer2(nn.Module): 35 | def __init__(self, output_channels=40): 36 | super(Point_Transformer2, self).__init__() 37 | self.conv1 = nn.Conv1d(3, 64, kernel_size=1, bias=False) 38 | self.conv2 = nn.Conv1d(64, 64, kernel_size=1, bias=False) 39 | self.bn1 = nn.BatchNorm1d(64) 40 | self.bn2 = nn.BatchNorm1d(64) 41 | self.gather_local_0 = Local_op(in_channels=128, out_channels=128) 42 | 
self.gather_local_1 = Local_op(in_channels=256, out_channels=256) 43 | self.pt_last = Point_Transformer_Last() 44 | 45 | self.relu = nn.ReLU() 46 | self.conv_fuse = nn.Sequential(nn.Conv1d(1280, 1024, kernel_size=1, bias=False), 47 | nn.BatchNorm1d(1024), 48 | nn.LeakyReLU(scale=0.2)) 49 | 50 | self.linear1 = nn.Linear(1024, 512, bias=False) 51 | self.bn6 = nn.BatchNorm1d(512) 52 | self.dp1 = nn.Dropout(p=0.5) 53 | self.linear2 = nn.Linear(512, 256) 54 | self.bn7 = nn.BatchNorm1d(256) 55 | self.dp2 = nn.Dropout(p=0.5) 56 | self.linear3 = nn.Linear(256, output_channels) 57 | 58 | def execute(self, x): 59 | xyz = x.permute(0, 2, 1) 60 | batch_size, _, _ = x.size() 61 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N 62 | x = self.relu(self.bn2(self.conv2(x))) # B, D, N 63 | x = x.permute(0, 2, 1) 64 | new_xyz, new_feature = sample_and_group(npoint=512, nsample=32, xyz=xyz, points=x) 65 | feature_0 = self.gather_local_0(new_feature) 66 | feature = feature_0.permute(0, 2, 1) 67 | new_xyz, new_feature = sample_and_group(npoint=256, nsample=32, xyz=new_xyz, points=feature) 68 | feature_1 = self.gather_local_1(new_feature) 69 | # add position embedding on each layer 70 | x = self.pt_last(feature_1, new_xyz) 71 | x = concat([x, feature_1], dim=1) 72 | x = self.conv_fuse(x) 73 | x = jt.max(x, 2) 74 | x = x.view(batch_size, -1) 75 | 76 | x = self.relu(self.bn6(self.linear1(x))) 77 | x = self.dp1(x) 78 | x = self.relu(self.bn7(self.linear2(x))) 79 | x = self.dp2(x) 80 | x = self.linear3(x) 81 | 82 | return x 83 | 84 | 85 | 86 | class Point_Transformer(nn.Module): 87 | def __init__(self, output_channels=40): 88 | super(Point_Transformer, self).__init__() 89 | 90 | self.conv1 = nn.Conv1d(3, 128, kernel_size=1, bias=False) 91 | self.conv2 = nn.Conv1d(128, 128, kernel_size=1, bias=False) 92 | 93 | self.bn1 = nn.BatchNorm1d(128) 94 | self.bn2 = nn.BatchNorm1d(128) 95 | 96 | self.sa1 = SA_Layer(128) 97 | self.sa2 = SA_Layer(128) 98 | self.sa3 = SA_Layer(128) 99 | self.sa4 = 
SA_Layer(128) 100 | 101 | self.conv_fuse = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False), 102 | nn.BatchNorm1d(1024), 103 | nn.LeakyReLU(scale=0.2)) 104 | 105 | self.linear1 = nn.Linear(1024, 512, bias=False) 106 | self.bn6 = nn.BatchNorm1d(512) 107 | self.dp1 = nn.Dropout(p=0.5) 108 | self.linear2 = nn.Linear(512, 256) 109 | self.bn7 = nn.BatchNorm1d(256) 110 | self.dp2 = nn.Dropout(p=0.5) 111 | self.linear3 = nn.Linear(256, output_channels) 112 | 113 | self.relu = nn.ReLU() 114 | 115 | def execute(self, x): 116 | 117 | batch_size, _, N = x.size() 118 | # print (x.size()) 119 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N 120 | x = self.relu(self.bn2(self.conv2(x))) 121 | 122 | x1 = self.sa1(x) 123 | x2 = self.sa2(x1) 124 | x3 = self.sa3(x2) 125 | x4 = self.sa4(x3) 126 | 127 | x = concat((x1, x2, x3, x4), dim=1) 128 | 129 | x = self.conv_fuse(x) 130 | # x = F.adaptive_max_pool1d(x, 1).view(batch_size, -1) 131 | x = jt.max(x, 2) 132 | x = x.view(batch_size, -1) 133 | x = self.relu(self.bn6(self.linear1(x))) 134 | x = self.dp1(x) 135 | x = self.relu(self.bn7(self.linear2(x))) 136 | x = self.dp2(x) 137 | x = self.linear3(x) 138 | return x 139 | 140 | 141 | 142 | class Point_Transformer_Last(nn.Module): 143 | def __init__(self, channels=256): 144 | super(Point_Transformer_Last, self).__init__() 145 | self.conv1 = nn.Conv1d(channels, channels, kernel_size=1, bias=False) 146 | self.conv_pos = nn.Conv1d(3, channels, kernel_size=1, bias=False) 147 | 148 | self.bn1 = nn.BatchNorm1d(channels) 149 | 150 | self.sa1 = SA_Layer(channels) 151 | self.sa2 = SA_Layer(channels) 152 | self.sa3 = SA_Layer(channels) 153 | self.sa4 = SA_Layer(channels) 154 | 155 | self.relu = nn.ReLU() 156 | def execute(self, x, xyz): 157 | # 158 | # b, 3, npoint, nsample 159 | # conv2d 3 -> 128 channels 1, 1 160 | # b * npoint, c, nsample 161 | # permute reshape 162 | batch_size, _, N = x.size() 163 | # add position embedding 164 | xyz = xyz.permute(0, 2, 1) 165 | xyz = 
self.pos_xyz(xyz) 166 | # end 167 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N 168 | 169 | x1 = self.sa1(x, xyz) 170 | x2 = self.sa2(x1, xyz) 171 | x3 = self.sa3(x2, xyz) 172 | x4 = self.sa4(x3, xyz) 173 | 174 | x = concat((x1, x2, x3, x4), dim=1) 175 | 176 | return x 177 | 178 | class Local_op(nn.Module): 179 | def __init__(self, in_channels, out_channels): 180 | super(Local_op, self).__init__() 181 | self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False) 182 | self.conv2 = nn.Conv1d(out_channels, out_channels, kernel_size=1, bias=False) 183 | self.bn1 = nn.BatchNorm1d(out_channels) 184 | self.bn2 = nn.BatchNorm1d(out_channels) 185 | self.relu = nn.ReLU() 186 | 187 | def execute(self, x): 188 | b, n, s, d = x.size() # torch.Size([32, 512, 32, 6]) 189 | x = x.permute(0, 1, 3, 2) 190 | x = x.reshape(-1, d, s) 191 | batch_size, _, N = x.size() 192 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N 193 | x = self.relu(self.bn2(self.conv2(x))) # B, D, N 194 | x = jt.max(x, 2) 195 | x = x.view(batch_size, -1) 196 | x = x.reshape(b, n, -1).permute(0, 2, 1) 197 | return x 198 | 199 | 200 | 201 | class SA_Layer(nn.Module): 202 | def __init__(self, channels): 203 | super(SA_Layer, self).__init__() 204 | self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 205 | self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 206 | self.q_conv.conv.weight = self.k_conv.conv.weight 207 | self.v_conv = nn.Conv1d(channels, channels, 1) 208 | self.trans_conv = nn.Conv1d(channels, channels, 1) 209 | self.after_norm = nn.BatchNorm1d(channels) 210 | self.act = nn.ReLU() 211 | self.softmax = nn.Softmax(dim=-1) 212 | 213 | def execute(self, x, xyz): 214 | x = x + xyz 215 | x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c 216 | x_k = self.k_conv(x)# b, c, n 217 | x_v = self.v_conv(x) 218 | energy = nn.bmm(x_q, x_k) # b, n, n 219 | attention = self.softmax(energy) 220 | attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True)) 221 | x_r = 
nn.bmm(x_v, attention) # b, c, n 222 | x_r = self.act(self.after_norm(self.trans_conv(x - x_r))) 223 | x = x + x_r 224 | return x 225 | 226 | if __name__ == '__main__': 227 | 228 | jt.flags.use_cuda=1 229 | input_points = init.gauss((16, 3, 1024), dtype='float32') # B, D, N 230 | 231 | 232 | network = Point_Transformer() 233 | out_logits = network(input_points) 234 | print (out_logits.shape) 235 | 236 | -------------------------------------------------------------------------------- /networks/seg/pct_partseg.py: -------------------------------------------------------------------------------- 1 | import jittor as jt 2 | from jittor import nn 3 | from jittor.contrib import concat 4 | import numpy as np 5 | import math 6 | 7 | 8 | class Point_Transformer_partseg(nn.Module): 9 | def __init__(self, part_num=50): 10 | super(Point_Transformer_partseg, self).__init__() 11 | self.part_num = part_num 12 | self.conv1 = nn.Conv1d(3, 128, kernel_size=1, bias=False) 13 | self.conv2 = nn.Conv1d(128, 128, kernel_size=1, bias=False) 14 | 15 | self.bn1 = nn.BatchNorm1d(128) 16 | self.bn2 = nn.BatchNorm1d(128) 17 | 18 | self.sa1 = SA_Layer(128) 19 | self.sa2 = SA_Layer(128) 20 | self.sa3 = SA_Layer(128) 21 | self.sa4 = SA_Layer(128) 22 | 23 | self.conv_fuse = nn.Sequential(nn.Conv1d(512, 1024, kernel_size=1, bias=False), 24 | nn.BatchNorm1d(1024), 25 | nn.LeakyReLU(scale=0.2)) 26 | 27 | self.label_conv = nn.Sequential(nn.Conv1d(16, 64, kernel_size=1, bias=False), 28 | nn.BatchNorm1d(64), 29 | nn.LeakyReLU(scale=0.2)) 30 | 31 | self.convs1 = nn.Conv1d(1024 * 3 + 64, 512, 1) 32 | self.dp1 = nn.Dropout(0.5) 33 | self.convs2 = nn.Conv1d(512, 256, 1) 34 | self.convs3 = nn.Conv1d(256, self.part_num, 1) 35 | self.bns1 = nn.BatchNorm1d(512) 36 | self.bns2 = nn.BatchNorm1d(256) 37 | 38 | self.relu = nn.ReLU() 39 | 40 | def execute(self, x, cls_label): 41 | batch_size, _, N = x.size() 42 | x = self.relu(self.bn1(self.conv1(x))) # B, D, N 43 | x = self.relu(self.bn2(self.conv2(x))) 44 | x1 = 
self.sa1(x) 45 | x2 = self.sa2(x1) 46 | x3 = self.sa3(x2) 47 | x4 = self.sa4(x3) 48 | x = concat((x1, x2, x3, x4), dim=1) 49 | x = self.conv_fuse(x) 50 | x_max = jt.max(x, 2) 51 | x_avg = jt.mean(x, 2) 52 | x_max_feature = x_max.view(batch_size, -1).unsqueeze(-1).repeat(1, 1, N) 53 | x_avg_feature = x_avg.view(batch_size, -1).unsqueeze(-1).repeat(1, 1, N) 54 | cls_label_one_hot = cls_label.view(batch_size,16,1) 55 | cls_label_feature = self.label_conv(cls_label_one_hot).repeat(1, 1, N) 56 | x_global_feature = concat((x_max_feature, x_avg_feature, cls_label_feature), 1) # 1024 + 64 57 | x = concat((x, x_global_feature), 1) # 1024 * 3 + 64 58 | x = self.relu(self.bns1(self.convs1(x))) 59 | x = self.dp1(x) 60 | x = self.relu(self.bns2(self.convs2(x))) 61 | x = self.convs3(x) 62 | return x 63 | 64 | 65 | 66 | class SA_Layer(nn.Module): 67 | def __init__(self, channels): 68 | super(SA_Layer, self).__init__() 69 | self.q_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 70 | self.k_conv = nn.Conv1d(channels, channels // 4, 1, bias=False) 71 | self.q_conv.conv.weight = self.k_conv.conv.weight 72 | self.v_conv = nn.Conv1d(channels, channels, 1) 73 | self.trans_conv = nn.Conv1d(channels, channels, 1) 74 | self.after_norm = nn.BatchNorm1d(channels) 75 | self.act = nn.ReLU() 76 | self.softmax = nn.Softmax(dim=-1) 77 | 78 | def execute(self, x): 79 | x_q = self.q_conv(x).permute(0, 2, 1) # b, n, c 80 | x_k = self.k_conv(x)# b, c, n 81 | x_v = self.v_conv(x) 82 | energy = nn.bmm(x_q, x_k) # b, n, n 83 | attention = self.softmax(energy) 84 | attention = attention / (1e-9 + attention.sum(dim=1, keepdims=True)) 85 | x_r = nn.bmm(x_v, attention) # b, c, n 86 | x_r = self.act(self.after_norm(self.trans_conv(x - x_r))) 87 | x = x + x_r 88 | return x 89 | --------------------------------------------------------------------------------