├── CRSNET └── CRSNET.py ├── C_utils ├── build.py ├── build │ └── libsift.so ├── example.py ├── include │ ├── PointSift.h │ └── PointSift_cuda.h ├── install.sh ├── libsift │ ├── __init__.py │ ├── __pycache__ │ │ └── __init__.cpython-36.pyc │ └── _libsift.so └── src │ ├── PointSift.c │ └── PointSift_cuda.cu ├── FastFCN ├── fastfcn.py └── resnet.py ├── PCN ├── chamfer_distance │ ├── __init__.py │ ├── __pycache__ │ │ ├── __init__.cpython-37.pyc │ │ └── chamfer_distance.cpython-37.pyc │ ├── chamfer_distance.cpp │ ├── chamfer_distance.cu │ └── chamfer_distance.py ├── data_loader.py ├── example.png ├── main.py ├── models.py └── test.py ├── PointCNN ├── PointCNN.py ├── Utils.py └── xConv.py ├── PointSift └── sift.py ├── Pointnet2 ├── Pointnet2.py └── Pointnet2_msg.py ├── README.md ├── SPGN ├── SGPN.py ├── SGPN_utils.py ├── __pycache__ │ └── SGPN_utils.cpython-37.pyc ├── test_SGPN.py └── train_SGPN.py ├── Utils ├── Utilities.py ├── __pycache__ │ └── net_utils.cpython-37.pyc └── net_utils.py ├── cppattempt ├── Point.cpp ├── Point.h ├── Point_cuda.cu ├── build │ ├── lib.linux-x86_64-3.7 │ │ └── point.cpython-37m-x86_64-linux-gnu.so │ └── temp.linux-x86_64-3.7 │ │ ├── Point.o │ │ ├── Point_cuda.o │ │ └── point_api.o ├── dist │ └── point-0.0.0-py3.7-linux-x86_64.egg ├── example.py ├── point.egg-info │ ├── PKG-INFO │ ├── SOURCES.txt │ ├── dependency_links.txt │ └── top_level.txt ├── point_api.cpp └── setup.py ├── data_loaders └── create_grid.py ├── metrics.py └── unit_test ├── centroids.npy ├── centroids_cpp.npy ├── group_idx.npy ├── group_idxcpp.npy ├── test_cpp_vs_C.py └── test_pc.ply /CRSNET/CRSNET.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import numpy as np 5 | import torchvision.models as models 6 | 7 | 8 | class CRSNET(nn.Module): 9 | def __init__(self, n_channels, n_classes): 10 | super(CRSNET, self).__init__() 11 | self.base = models.vgg16().features 12 | self.encoder = [] 13 | for i in range(0,23): 14 | self.encoder.append(self.base[i]) 15 | 16 | self.encoder = nn.Sequential(*self.encoder) 17 | del self.base 18 | 19 | self.decoder = nn.Sequential( 20 | nn.Conv2d(512,512,3,dilation=2,padding=2), 21 | nn.ReLU(), 22 | nn.Conv2d(512,512,3,dilation=2,padding=2), 23 | nn.ReLU(), 24 | nn.Conv2d(512,256,3,dilation=2,padding=2), 25 | nn.ReLU(), 26 | nn.Conv2d(256,128,3,dilation=2,padding=2), 27 | nn.ReLU(), 28 | nn.Conv2d(128,64,3,dilation=2,padding=2), 29 | nn.ReLU(), 30 | nn.Conv2d(64,64,3,dilation=2,padding=2), 31 | nn.ReLU(), 32 | ) 33 | 34 | self.out_conv = nn.Conv2d(64,n_classes,1) 35 | def forward(self,x): 36 | x = self.encoder(x) 37 | print(x.shape) 38 | x = self.decoder(x) 39 | return self.out_conv(x) 40 | 41 | 42 | 43 | if __name__ == '__main__': 44 | model = CRSNET(3,1) 45 | print(model) 46 | input_tensor = torch.ones((1,3,224,224)) 47 | output = model(input_tensor) 48 | print(output.shape) -------------------------------------------------------------------------------- /C_utils/build.py: -------------------------------------------------------------------------------- 1 | import glob 2 | import torch 3 | from os import path as osp 4 | from torch.utils.ffi import create_extension 5 | 6 | abs_path = osp.dirname(osp.realpath(__file__)) 7 | extra_objects = [osp.join(abs_path, 'build/libsift.so')] 8 | extra_objects += glob.glob('/usr/local/cuda-9.0/lib64/*.a') 9 | 10 | ffi = create_extension( 11 | 'libsift', 12 | headers=['include/PointSift.h'], 13 | 
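    # NOTE: torch.utils.ffi was removed in PyTorch 1.0, so this build script
    # only works on PyTorch <= 0.4; on newer versions the equivalent extension
    # would be built with torch.utils.cpp_extension instead (the cppattempt/
    # folder in this repo takes that route).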
sources=['src/PointSift.c'],
14 |     define_macros=[('WITH_CUDA', None)],
15 |     relative_to=__file__,
16 |     with_cuda=True,
17 |     extra_objects=extra_objects,
18 |     include_dirs=[osp.join(abs_path, 'include'), "/opt/cuda/include"]
19 | )
20 | 
21 | 
22 | if __name__ == '__main__':
23 |     print("COMPILING C")
24 |     assert torch.cuda.is_available(), 'Please install CUDA for GPU support.'
25 |     ffi.build()
26 | 
--------------------------------------------------------------------------------
/C_utils/build/libsift.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/C_utils/build/libsift.so
--------------------------------------------------------------------------------
/C_utils/example.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch import nn
3 | from torch.autograd import Function
4 | import torch
5 | import sys
6 | import time
7 | import libsift
8 | 
9 | 
10 | radius = 0.4  # search radius, shared by the CUDA kernel and the python reference below
11 | x = torch.rand(1, 800, 3).cuda()
12 | y = torch.zeros((1, 800, 8), dtype=torch.int32).cuda()
13 | start_time = time.time()
14 | libsift.select_cube(x, y, 1, 800, radius)  # (xyz, idx_out, batch, n_points, radius)
15 | print(time.time() - start_time)
16 | 
17 | 
18 | # pure-python reference implementation, for timing comparison
19 | start_time = time.time()
20 | xyz = x.cpu()
21 | Dist = lambda x, y, z: x ** 2 + y ** 2 + z ** 2
22 | B, N, _ = xyz.shape
23 | idx = torch.empty(B, N, 8)
24 | judge_dist = radius ** 2
25 | temp_dist = torch.ones(B, N, 8) * 1e10
26 | for b in range(B):
27 |     for n in range(N):
28 |         idx[b, n, :] = n
29 |         x, y, z = xyz[b, n]
30 |         for p in range(N):
31 |             if p == n: continue
32 |             tx, ty, tz = xyz[b, p]
33 |             dist = Dist(x - tx, y - ty, z - tz)
34 |             if dist > judge_dist: continue
35 |             _x, _y, _z = tx > x, ty > y, tz > z
36 |             temp_idx = (_x * 4 + _y * 2 + _z).int()
37 |             if dist < temp_dist[b, n, temp_idx]:
38 |                 idx[b, n, temp_idx] = p
39 |                 temp_dist[b, n, temp_idx] = dist
40 | 
41 | print(time.time() - start_time)
42 | 
--------------------------------------------------------------------------------
/C_utils/include/PointSift.h:
--------------------------------------------------------------------------------
1 | void select_cube(THCudaTensor *xyz, THCudaIntTensor *idx_out, int b, int n,float radius);
2 | void group_points(int b, int n, int c , int m , int nsamples, THCudaTensor *xyz, THCudaIntTensor *idx, THCudaTensor *out);
3 | void ball_query(int b, int n, int m, float radius, int nsample, THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaIntTensor * idx, THCudaIntTensor * pts_cnt);
4 | void farthestPoint(int b,int n,int m,THCudaTensor * inp, THCudaTensor * temp,THCudaIntTensor * out);
5 | void interpolate(int b, int n, int m, THFloatTensor *xyz1p, THFloatTensor *xyz2p, THFloatTensor *distp, THIntTensor *idxp);
6 | void three_interpolate(int b, int m, int c, int n, THFloatTensor *points, THIntTensor *idx, THFloatTensor *weight, THFloatTensor *out);
7 | void IOUcalc(THFloatTensor * b1b, THFloatTensor *b2b,THFloatTensor *out , int nb1, int nb2);
8 | 
--------------------------------------------------------------------------------
/C_utils/include/PointSift_cuda.h:
--------------------------------------------------------------------------------
1 | #ifndef _POINTSHIFT_CUDA_H
2 | #define _POINTSHIFT_CUDA_H
3 | 
4 | #ifdef __cplusplus
5 | extern "C" {
6 | #endif
7 | 
8 | void cubeSelectLauncher(int b, int n, float radius, const float* xyz, int* idx_out);
9 | void group_pointsLauncher(int b, int n, int c, int m, int nsamples, const
float * pointsp, const int * idxp, float * outp); 10 | void queryBallPointLauncher(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx, int *pts_cnt); 11 | void farthestpointsamplingLauncher(int b,int n,int m,const float * inp,float * temp,int * out); 12 | void threennLauncher(int b, int n, int m, const float *xyz1, const float *xyz2, float *dist, int *idx); 13 | void interpolateLauncher(int b, int m, int c, int n, const float *points, const int *idx, const float *weight, float *out); 14 | void groupPointLauncher(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out); 15 | #ifdef __cplusplus 16 | } 17 | #endif 18 | 19 | #endif 20 | -------------------------------------------------------------------------------- /C_utils/install.sh: -------------------------------------------------------------------------------- 1 | #!bin/bash 2 | /usr/local/cuda-9.0/bin/nvcc -c -o build/libsift.so src/PointSift_cuda.cu -x cu -Xcompiler -fPIC 3 | python build.py 4 | 5 | #gcc PointSift.c pointSIFT_g.cu.o -o tf_pointSIFT_so.so -shared -fPIC 6 | -------------------------------------------------------------------------------- /C_utils/libsift/__init__.py: -------------------------------------------------------------------------------- 1 | 2 | from torch.utils.ffi import _wrap_function 3 | from ._libsift import lib as _lib, ffi as _ffi 4 | 5 | __all__ = [] 6 | def _import_symbols(locals): 7 | for symbol in dir(_lib): 8 | fn = getattr(_lib, symbol) 9 | if callable(fn): 10 | locals[symbol] = _wrap_function(fn, _ffi) 11 | else: 12 | locals[symbol] = fn 13 | __all__.append(symbol) 14 | 15 | _import_symbols(locals()) 16 | -------------------------------------------------------------------------------- /C_utils/libsift/__pycache__/__init__.cpython-36.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/C_utils/libsift/__pycache__/__init__.cpython-36.pyc -------------------------------------------------------------------------------- /C_utils/libsift/_libsift.so: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/C_utils/libsift/_libsift.so -------------------------------------------------------------------------------- /C_utils/src/PointSift.c: -------------------------------------------------------------------------------- 1 | 2 | #include 3 | #include "PointSift_cuda.h" 4 | #include 5 | #include 6 | 7 | #define MAX(x, y) (((x) > (y)) ? (x) : (y)) 8 | #define MIN(x, y) (((x) < (y)) ? 
(x) : (y)) 9 | 10 | extern THCState *state; 11 | 12 | void select_cube(THCudaTensor *xyz, THCudaIntTensor *idx_out, int b, int n,float radius) 13 | { 14 | //select_cuda(int b, int n,float radius, float* xyz, float* idx_out) 15 | int * output = THCudaIntTensor_data(state, idx_out); 16 | float * input = THCudaTensor_data(state, xyz); 17 | cubeSelectLauncher(b,n,radius,input,output); 18 | } 19 | 20 | void group_points(int b, int n, int c , int m , int nsamples, THCudaTensor *xyz, THCudaIntTensor *idx, THCudaTensor *out) 21 | { 22 | int * idxp = THCudaIntTensor_data(state, idx); 23 | float * pointsp = THCudaTensor_data(state, xyz); 24 | float * outp = THCudaTensor_data(state, out); 25 | group_pointsLauncher(b,n,c,m,nsamples,pointsp,idxp,outp); 26 | } 27 | void ball_query (int b, int n, int m, float radius, int nsample, THCudaTensor *xyz1, THCudaTensor *xyz2, THCudaIntTensor * idx, THCudaIntTensor * pts_cnt){ 28 | 29 | int * idxp = THCudaIntTensor_data(state, idx); 30 | int * pts_cntp = THCudaIntTensor_data(state, pts_cnt); 31 | float * xyz1p = THCudaTensor_data(state, xyz1); 32 | float * xyz2p = THCudaTensor_data(state, xyz2); 33 | queryBallPointLauncher(b, n, m, radius, nsample, xyz1p, xyz2p, idxp, pts_cntp); 34 | } 35 | 36 | void farthestPoint(int b,int n,int m,THCudaTensor * inp, THCudaTensor * temp,THCudaIntTensor * out){ 37 | 38 | int * out2 = THCudaIntTensor_data(state, out); 39 | float * inp2 = THCudaTensor_data(state, inp); 40 | float * temp2 = THCudaTensor_data(state, temp); 41 | // clock_t t; 42 | // t = clock(); 43 | farthestpointsamplingLauncher(b, n, m, inp2, temp2, out2); 44 | // t = clock() - t; 45 | // double time_taken = ((double)t)/CLOCKS_PER_SEC; // in seconds 46 | // printf("fun() took %f seconds to execute \n", time_taken); 47 | } 48 | void interpolate(int b, int n, int m, THFloatTensor *xyz1p, THFloatTensor *xyz2p, THFloatTensor *distp, THIntTensor *idxp){ 49 | 50 | float * xyz1 = THCudaTensor_data(state, xyz1p); 51 | float * xyz2 = THCudaTensor_data(state, xyz2p); 52 | float * dist = THCudaTensor_data(state, distp); 53 | int * idx = THCudaIntTensor_data(state, idxp); 54 | 55 | for (int i=0;i 3 | 4 | __global__ void cubeselect(int n,float radius, const float* xyz, int* idx_out) 5 | { 6 | int batch_idx = blockIdx.x; 7 | xyz += batch_idx * n * 3; 8 | idx_out += batch_idx * n * 8; 9 | float temp_dist[8]; 10 | float judge_dist = radius * radius; 11 | for(int i = threadIdx.x; i < n;i += blockDim.x) { 12 | float x = xyz[i * 3]; 13 | float y = xyz[i * 3 + 1]; 14 | float z = xyz[i * 3 + 2]; 15 | for(int j = 0;j < 8;j ++) { 16 | temp_dist[j] = 1e8; 17 | idx_out[i * 8 + j] = i; // if not found, just return itself.. 
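        // Each of the 8 idx_out slots corresponds to one octant around point i:
        // the 3-bit code _x*4 + _y*2 + _z computed below encodes whether a
        // neighbour lies above (1) or below (0) point i along x, y and z, and
        // only the nearest in-radius neighbour found per octant is kept.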
18 | } 19 | for(int j = 0;j < n;j ++) { 20 | if(i != j){ 21 | float tx = xyz[j * 3]; 22 | float ty = xyz[j * 3 + 1]; 23 | float tz = xyz[j * 3 + 2]; 24 | float dist = (x - tx) * (x - tx) + (y - ty) * (y - ty) + (z - tz) * (z - tz); 25 | if(dist <= judge_dist){ 26 | int _x = (tx > x); 27 | int _y = (ty > y); 28 | int _z = (tz > z); 29 | int temp_idx = _x * 4 + _y * 2 + _z; 30 | if(dist < temp_dist[temp_idx]) { 31 | idx_out[i * 8 + temp_idx] = j; 32 | temp_dist[temp_idx] = dist; 33 | } 34 | } 35 | } 36 | } 37 | 38 | } 39 | } 40 | 41 | // input: points (b,n,c), idx (b,m,nsample) 42 | // output: out (b,m,nsample,c) 43 | __global__ void group_point_gpu(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out) 44 | { 45 | int batch_index = blockIdx.x; 46 | points += n*c*batch_index; 47 | idx += m*nsample*batch_index; 48 | out += m*nsample*c*batch_index; 49 | 50 | int index = threadIdx.x; 51 | int stride = blockDim.x; 52 | 53 | for (int j=index;jbest){ 143 | best=d2; 144 | besti=k; 145 | } 146 | } 147 | dists[threadIdx.x]=best; 148 | dists_i[threadIdx.x]=besti; 149 | for (int u=0;(1<>(u+1))){ 152 | int i1=(threadIdx.x*2)<>>(n, radius, xyz, idx_out); 245 | } 246 | void group_pointsLauncher(int b, int n, int c, int m, int nsamples, const float * pointsp, const int * idxp, float * outp){ 247 | group_point_gpu<<>>(b,n,c,m,nsamples,pointsp,idxp,outp); 248 | } 249 | void queryBallPointLauncher(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx, int *pts_cnt) { 250 | query_ball_point_gpu<<>>(b,n,m,radius,nsample,xyz1,xyz2,idx,pts_cnt); 251 | //cudaDeviceSynchronize(); 252 | } 253 | void farthestpointsamplingLauncher(int b,int n,int m,const float * inp,float * temp,int * out){ 254 | farthestpointsamplingKernel<<<32,512>>>(b,n,m,inp,temp,out); 255 | } 256 | 257 | void threennLauncher(int b, int n, int m, const float *xyz1, const float *xyz2, float *dist, int *idx){ 258 | threenn<<>>(b,n,m,xyz1,xyz2,dist,idx); 259 | } 260 | 261 | void interpolateLauncher(int b, int m, int c, int n, const float *points, const int *idx, const float *weight, float *out){ 262 | interpolategp<<>>(b,m,c,n,points,idx,weight,out); 263 | } 264 | -------------------------------------------------------------------------------- /FastFCN/fastfcn.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | import torch.nn.functional as F 4 | import resnet 5 | import numpy as np 6 | from torch.nn.functional import interpolate 7 | 8 | class SeparableConv2d(nn.Module): 9 | def __init__(self, inplanes, planes, kernel_size=3, stride=1, padding=1, dilation=1, bias=False, norm_layer=nn.BatchNorm2d): 10 | super(SeparableConv2d, self).__init__() 11 | 12 | self.conv1 = nn.Conv2d(inplanes, inplanes, kernel_size, stride, padding, dilation, groups=inplanes, bias=bias) 13 | self.bn = norm_layer(inplanes) 14 | self.pointwise = nn.Conv2d(inplanes, planes, 1, 1, 0, 1, 1, bias=bias) 15 | 16 | def forward(self, x): 17 | x = self.conv1(x) 18 | x = self.bn(x) 19 | x = self.pointwise(x) 20 | return x 21 | 22 | class FCNHead(nn.Module): 23 | def __init__(self, in_channels, out_channels, norm_layer): 24 | super(FCNHead, self).__init__() 25 | inter_channels = in_channels // 4 26 | self.conv5 = nn.Sequential(nn.Conv2d(in_channels, inter_channels, 3, padding=1, bias=False), 27 | norm_layer(inter_channels), 28 | nn.ReLU(), 29 | nn.Dropout2d(0.1, False), 30 | nn.Conv2d(inter_channels, out_channels, 1)) 31 | 32 | def 
forward(self, x): 33 | return self.conv5(x) 34 | 35 | class FastFCN(nn.Module): 36 | def __init__(self,input_size,nb_classes,imsize,backbone,dilated=False, norm_layer=nn.BatchNorm2d,sigmo=True): 37 | super(FastFCN, self).__init__() 38 | 39 | 40 | self.imsize = imsize 41 | self.sigmo = sigmo 42 | if self.sigmo: 43 | self.sig = nn.Sigmoid() 44 | ############ ENCODER PART ####################### 45 | if backbone == 'resnet50': 46 | self.pretrained = resnet.resnet50(pretrained=True, dilated=dilated, 47 | norm_layer=norm_layer) 48 | elif backbone == 'resnet101': 49 | self.pretrained = resnet.resnet101(pretrained=True, dilated=dilated, 50 | norm_layer=norm_layer) 51 | elif backbone == 'resnet152': 52 | self.pretrained = resnet.resnet152(pretrained=True, dilated=dilated, 53 | norm_layer=norm_layer) 54 | else: 55 | raise RuntimeError('unknown backbone: {}'.format(backbone)) 56 | 57 | #############JPU part ################ 58 | in_channels = [512, 1024, 2048] 59 | width=512 60 | self.conv5 = nn.Sequential( 61 | nn.Conv2d(in_channels[-1], width, 3, padding=1, bias=False), 62 | norm_layer(width), 63 | nn.ReLU(inplace=True)) 64 | self.conv4 = nn.Sequential( 65 | nn.Conv2d(in_channels[-2], width, 3, padding=1, bias=False), 66 | norm_layer(width), 67 | nn.ReLU(inplace=True)) 68 | self.conv3 = nn.Sequential( 69 | nn.Conv2d(in_channels[-3], width, 3, padding=1, bias=False), 70 | norm_layer(width), 71 | nn.ReLU(inplace=True)) 72 | 73 | self.dilation1 = nn.Sequential(SeparableConv2d(3*width, width, kernel_size=3, padding=1, dilation=1, bias=False), 74 | norm_layer(width), 75 | nn.ReLU(inplace=True)) 76 | self.dilation2 = nn.Sequential(SeparableConv2d(3*width, width, kernel_size=3, padding=2, dilation=2, bias=False), 77 | norm_layer(width), 78 | nn.ReLU(inplace=True)) 79 | self.dilation3 = nn.Sequential(SeparableConv2d(3*width, width, kernel_size=3, padding=4, dilation=4, bias=False), 80 | norm_layer(width), 81 | nn.ReLU(inplace=True)) 82 | self.dilation4 = nn.Sequential(SeparableConv2d(3*width, width, kernel_size=3, padding=8, dilation=8, bias=False), 83 | norm_layer(width), 84 | nn.ReLU(inplace=True)) 85 | 86 | 87 | ############# Segmentation Head (can be changed)####################### 88 | self.head = FCNHead(2048, nb_classes, norm_layer) 89 | 90 | def forward(self, x,size): 91 | x = self.pretrained.conv1(x) 92 | x = self.pretrained.bn1(x) 93 | x = self.pretrained.relu(x) 94 | x = self.pretrained.maxpool(x) 95 | c1 = self.pretrained.layer1(x) 96 | c2 = self.pretrained.layer2(c1) 97 | c3 = self.pretrained.layer3(c2) 98 | c4 = self.pretrained.layer4(c3) 99 | 100 | 101 | 102 | feats = [self.conv5(c4), self.conv4(c3), self.conv3(c2)] 103 | _, _, h, w = feats[-1].size() 104 | feats[-2] = F.interpolate(feats[-2], (h, w),mode='bilinear',align_corners=True) 105 | feats[-3] = F.interpolate(feats[-3], (h, w),mode='bilinear',align_corners=True) 106 | feat = torch.cat(feats, dim=1) 107 | feat = torch.cat([self.dilation1(feat), self.dilation2(feat), self.dilation3(feat), self.dilation4(feat)], dim=1) 108 | 109 | c4 = self.head(feat) 110 | output = interpolate(c4,size,mode='bilinear',align_corners=True) 111 | if self.sigmo: 112 | return self.sig(output) 113 | else: 114 | return output 115 | return output 116 | 117 | 118 | 119 | if __name__ == "__main__": 120 | model = FastFCN(3,1,(180,360),'resnet50',sigmo=True) 121 | input_tensor = torch.from_numpy(np.ones((1,3,180,360))) 122 | output = model(input_tensor.float(),(360,180)) 123 | print(output.shape) 124 | 
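# Shape walk-through for the example above (a sketch, assuming the
# dilated=False ResNet-50 path): an input of (1,3,180,360) gives c1..c4 at
# strides 4/8/16/32; conv5/conv4/conv3 reduce c4/c3/c2 to 512 channels each,
# everything is upsampled to c2's stride-8 resolution and concatenated (3*512
# channels), and the four parallel dilated separable convolutions (rates
# 1/2/4/8) are concatenated to 4*512 = 2048 channels -- exactly the
# in_channels that FCNHead expects. A minimal CPU sanity check:
#
#   model = FastFCN(3, 1, (180, 360), 'resnet50', sigmo=True)
#   out = model(torch.rand(1, 3, 180, 360), (360, 180))
#   assert out.shape == (1, 1, 360, 180)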
-------------------------------------------------------------------------------- /FastFCN/resnet.py: -------------------------------------------------------------------------------- 1 | """Dilated ResNet""" 2 | import math 3 | import torch 4 | import torch.nn as nn 5 | import torchvision.models as models 6 | 7 | __all__ = ['ResNet', 'resnet18', 'resnet34', 'resnet50', 'resnet101', 8 | 'resnet152', 'BasicBlock', 'Bottleneck'] 9 | 10 | 11 | 12 | def conv3x3(in_planes, out_planes, stride=1): 13 | "3x3 convolution with padding" 14 | return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride, 15 | padding=1, bias=False) 16 | 17 | 18 | class BasicBlock(nn.Module): 19 | """ResNet BasicBlock 20 | """ 21 | expansion = 1 22 | def __init__(self, inplanes, planes, stride=1, dilation=1, downsample=None, previous_dilation=1, 23 | norm_layer=None): 24 | super(BasicBlock, self).__init__() 25 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=3, stride=stride, 26 | padding=dilation, dilation=dilation, bias=False) 27 | self.bn1 = norm_layer(planes) 28 | self.relu = nn.ReLU(inplace=True) 29 | self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, 30 | padding=previous_dilation, dilation=previous_dilation, bias=False) 31 | self.bn2 = norm_layer(planes) 32 | self.downsample = downsample 33 | self.stride = stride 34 | 35 | def forward(self, x): 36 | residual = x 37 | 38 | out = self.conv1(x) 39 | out = self.bn1(out) 40 | out = self.relu(out) 41 | 42 | out = self.conv2(out) 43 | out = self.bn2(out) 44 | 45 | if self.downsample is not None: 46 | residual = self.downsample(x) 47 | 48 | out += residual 49 | out = self.relu(out) 50 | 51 | return out 52 | 53 | 54 | class Bottleneck(nn.Module): 55 | """ResNet Bottleneck 56 | """ 57 | # pylint: disable=unused-argument 58 | expansion = 4 59 | def __init__(self, inplanes, planes, stride=1, dilation=1, 60 | downsample=None, previous_dilation=1, norm_layer=None): 61 | super(Bottleneck, self).__init__() 62 | self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False) 63 | self.bn1 = norm_layer(planes) 64 | self.conv2 = nn.Conv2d( 65 | planes, planes, kernel_size=3, stride=stride, 66 | padding=dilation, dilation=dilation, bias=False) 67 | self.bn2 = norm_layer(planes) 68 | self.conv3 = nn.Conv2d( 69 | planes, planes * 4, kernel_size=1, bias=False) 70 | self.bn3 = norm_layer(planes * 4) 71 | self.relu = nn.ReLU(inplace=True) 72 | self.downsample = downsample 73 | self.dilation = dilation 74 | self.stride = stride 75 | 76 | def _sum_each(self, x, y): 77 | assert(len(x) == len(y)) 78 | z = [] 79 | for i in range(len(x)): 80 | z.append(x[i]+y[i]) 81 | return z 82 | 83 | def forward(self, x): 84 | residual = x 85 | 86 | out = self.conv1(x) 87 | out = self.bn1(out) 88 | out = self.relu(out) 89 | 90 | out = self.conv2(out) 91 | out = self.bn2(out) 92 | out = self.relu(out) 93 | 94 | out = self.conv3(out) 95 | out = self.bn3(out) 96 | 97 | if self.downsample is not None: 98 | residual = self.downsample(x) 99 | 100 | out += residual 101 | out = self.relu(out) 102 | 103 | return out 104 | 105 | 106 | class ResNet(nn.Module): 107 | """Dilated Pre-trained ResNet Model, which preduces the stride of 8 featuremaps at conv5. 108 | Parameters 109 | ---------- 110 | block : Block 111 | Class for the residual block. Options are BasicBlockV1, BottleneckV1. 112 | layers : list of int 113 | Numbers of layers in each block 114 | classes : int, default 1000 115 | Number of classification classes. 
116 | dilated : bool, default False 117 | Applying dilation strategy to pretrained ResNet yielding a stride-8 model, 118 | typically used in Semantic Segmentation. 119 | norm_layer : object 120 | Normalization layer used in backbone network (default: :class:`mxnet.gluon.nn.BatchNorm`; 121 | for Synchronized Cross-GPU BachNormalization). 122 | Reference: 123 | - He, Kaiming, et al. "Deep residual learning for image recognition." Proceedings of the IEEE conference on computer vision and pattern recognition. 2016. 124 | - Yu, Fisher, and Vladlen Koltun. "Multi-scale context aggregation by dilated convolutions." 125 | """ 126 | # pylint: disable=unused-variable 127 | def __init__(self, block, layers, num_classes=1000, dilated=True, 128 | deep_base=False, norm_layer=nn.BatchNorm2d, output_size=8): 129 | self.inplanes = 128 if deep_base else 64 130 | super(ResNet, self).__init__() 131 | if deep_base: 132 | self.conv1 = nn.Sequential( 133 | nn.Conv2d(3, 64, kernel_size=3, stride=2, padding=1, bias=False), 134 | norm_layer(64), 135 | nn.ReLU(inplace=True), 136 | nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=False), 137 | norm_layer(64), 138 | nn.ReLU(inplace=True), 139 | nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=False), 140 | ) 141 | else: 142 | self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, 143 | bias=False) 144 | self.bn1 = norm_layer(self.inplanes) 145 | self.relu = nn.ReLU(inplace=True) 146 | self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) 147 | self.layer1 = self._make_layer(block, 64, layers[0], norm_layer=norm_layer) 148 | self.layer2 = self._make_layer(block, 128, layers[1], stride=2, norm_layer=norm_layer) 149 | 150 | dilation_rate = 2 151 | if dilated and output_size <= 8: 152 | self.layer3 = self._make_layer(block, 256, layers[2], stride=1, 153 | dilation=dilation_rate, norm_layer=norm_layer) 154 | dilation_rate *= 2 155 | else: 156 | self.layer3 = self._make_layer(block, 256, layers[2], stride=2, 157 | norm_layer=norm_layer) 158 | 159 | if dilated and output_size <= 16: 160 | self.layer4 = self._make_layer(block, 512, layers[3], stride=1, 161 | dilation=dilation_rate, norm_layer=norm_layer) 162 | else: 163 | self.layer4 = self._make_layer(block, 512, layers[3], stride=2, 164 | norm_layer=norm_layer) 165 | 166 | self.avgpool = nn.AvgPool2d(7, stride=1) 167 | self.fc = nn.Linear(512 * block.expansion, num_classes) 168 | 169 | for m in self.modules(): 170 | if isinstance(m, nn.Conv2d): 171 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 172 | m.weight.data.normal_(0, math.sqrt(2. 
/ n)) 173 | elif isinstance(m, norm_layer): 174 | m.weight.data.fill_(1) 175 | m.bias.data.zero_() 176 | 177 | def _make_layer(self, block, planes, blocks, stride=1, dilation=1, norm_layer=None): 178 | downsample = None 179 | if stride != 1 or self.inplanes != planes * block.expansion: 180 | downsample = nn.Sequential( 181 | nn.Conv2d(self.inplanes, planes * block.expansion, 182 | kernel_size=1, stride=stride, bias=False), 183 | norm_layer(planes * block.expansion), 184 | ) 185 | 186 | layers = [] 187 | if dilation == 1 or dilation == 2: 188 | layers.append(block(self.inplanes, planes, stride, dilation=1, 189 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 190 | elif dilation == 4: 191 | layers.append(block(self.inplanes, planes, stride, dilation=2, 192 | downsample=downsample, previous_dilation=dilation, norm_layer=norm_layer)) 193 | else: 194 | raise RuntimeError("=> unknown dilation size: {}".format(dilation)) 195 | 196 | self.inplanes = planes * block.expansion 197 | for i in range(1, blocks): 198 | layers.append(block(self.inplanes, planes, dilation=dilation, previous_dilation=dilation, 199 | norm_layer=norm_layer)) 200 | 201 | return nn.Sequential(*layers) 202 | 203 | def forward(self, x): 204 | x = self.conv1(x) 205 | x = self.bn1(x) 206 | x = self.relu(x) 207 | x = self.maxpool(x) 208 | 209 | x = self.layer1(x) 210 | x = self.layer2(x) 211 | x = self.layer3(x) 212 | x = self.layer4(x) 213 | 214 | x = self.avgpool(x) 215 | x = x.view(x.size(0), -1) 216 | x = self.fc(x) 217 | 218 | return x 219 | 220 | 221 | def resnet18(pretrained=False, **kwargs): 222 | """Constructs a ResNet-18 model. 223 | Args: 224 | pretrained (bool): If True, returns a model pre-trained on ImageNet 225 | """ 226 | model = ResNet(BasicBlock, [2, 2, 2, 2], **kwargs) 227 | if pretrained: 228 | model.load_state_dict(models.resnet18().state_dict()) 229 | return model 230 | 231 | 232 | def resnet34(pretrained=False, **kwargs): 233 | """Constructs a ResNet-34 model. 234 | Args: 235 | pretrained (bool): If True, returns a model pre-trained on ImageNet 236 | """ 237 | model = ResNet(BasicBlock, [3, 4, 6, 3], **kwargs) 238 | if pretrained: 239 | model.load_state_dict(models.resnet34().state_dict()) 240 | return model 241 | 242 | 243 | def resnet50(pretrained=False, **kwargs): 244 | """Constructs a ResNet-50 model. 245 | Args: 246 | pretrained (bool): If True, returns a model pre-trained on ImageNet 247 | """ 248 | model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs) 249 | if pretrained: 250 | model.load_state_dict(models.resnet50().state_dict()) 251 | return model 252 | 253 | 254 | def resnet101(pretrained=False, **kwargs): 255 | """Constructs a ResNet-101 model. 256 | Args: 257 | pretrained (bool): If True, returns a model pre-trained on ImageNet 258 | """ 259 | model = ResNet(Bottleneck, [3, 4, 23, 3], **kwargs) 260 | if pretrained: 261 | model.load_state_dict(models.resnet101().state_dict()) 262 | return model 263 | 264 | 265 | def resnet152(pretrained=False, **kwargs): 266 | """Constructs a ResNet-152 model. 
267 | Args: 268 | pretrained (bool): If True, returns a model pre-trained on ImageNet 269 | """ 270 | model = ResNet(Bottleneck, [3, 8, 36, 3], **kwargs) 271 | if pretrained: 272 | model.load_state_dict(models.resnet152().state_dict()) 273 | return model 274 | 275 | 276 | if __name__ == "__main__": 277 | model = resnet152() 278 | print(model) -------------------------------------------------------------------------------- /PCN/chamfer_distance/__init__.py: -------------------------------------------------------------------------------- 1 | from .chamfer_distance import ChamferDistance 2 | -------------------------------------------------------------------------------- /PCN/chamfer_distance/__pycache__/__init__.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/PCN/chamfer_distance/__pycache__/__init__.cpython-37.pyc -------------------------------------------------------------------------------- /PCN/chamfer_distance/__pycache__/chamfer_distance.cpython-37.pyc: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/PCN/chamfer_distance/__pycache__/chamfer_distance.cpython-37.pyc -------------------------------------------------------------------------------- /PCN/chamfer_distance/chamfer_distance.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | // CUDA forward declarations 4 | int ChamferDistanceKernelLauncher( 5 | const int b, const int n, 6 | const float* xyz, 7 | const int m, 8 | const float* xyz2, 9 | float* result, 10 | int* result_i, 11 | float* result2, 12 | int* result2_i); 13 | 14 | int ChamferDistanceGradKernelLauncher( 15 | const int b, const int n, 16 | const float* xyz1, 17 | const int m, 18 | const float* xyz2, 19 | const float* grad_dist1, 20 | const int* idx1, 21 | const float* grad_dist2, 22 | const int* idx2, 23 | float* grad_xyz1, 24 | float* grad_xyz2); 25 | 26 | 27 | void chamfer_distance_forward_cuda( 28 | const at::Tensor xyz1, 29 | const at::Tensor xyz2, 30 | const at::Tensor dist1, 31 | const at::Tensor dist2, 32 | const at::Tensor idx1, 33 | const at::Tensor idx2) 34 | { 35 | ChamferDistanceKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 36 | xyz2.size(1), xyz2.data(), 37 | dist1.data(), idx1.data(), 38 | dist2.data(), idx2.data()); 39 | } 40 | 41 | void chamfer_distance_backward_cuda( 42 | const at::Tensor xyz1, 43 | const at::Tensor xyz2, 44 | at::Tensor gradxyz1, 45 | at::Tensor gradxyz2, 46 | at::Tensor graddist1, 47 | at::Tensor graddist2, 48 | at::Tensor idx1, 49 | at::Tensor idx2) 50 | { 51 | ChamferDistanceGradKernelLauncher(xyz1.size(0), xyz1.size(1), xyz1.data(), 52 | xyz2.size(1), xyz2.data(), 53 | graddist1.data(), idx1.data(), 54 | graddist2.data(), idx2.data(), 55 | gradxyz1.data(), gradxyz2.data()); 56 | } 57 | 58 | 59 | void nnsearch( 60 | const int b, const int n, const int m, 61 | const float* xyz1, 62 | const float* xyz2, 63 | float* dist, 64 | int* idx) 65 | { 66 | for (int i = 0; i < b; i++) { 67 | for (int j = 0; j < n; j++) { 68 | const float x1 = xyz1[(i*n+j)*3+0]; 69 | const float y1 = xyz1[(i*n+j)*3+1]; 70 | const float z1 = xyz1[(i*n+j)*3+2]; 71 | double best = 0; 72 | int besti = 0; 73 | for (int k = 0; k < m; k++) { 74 | const float x2 = xyz2[(i*m+k)*3+0] - x1; 75 | const float y2 = 
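                // Brute-force O(n*m) nearest-neighbour search: for every point of
                // xyz1, scan all of xyz2 and keep the closest squared distance.
                // This CPU path mirrors what ChamferDistanceKernel does on the GPU.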
xyz2[(i*m+k)*3+1] - y1; 76 | const float z2 = xyz2[(i*m+k)*3+2] - z1; 77 | const double d=x2*x2+y2*y2+z2*z2; 78 | if (k==0 || d < best){ 79 | best = d; 80 | besti = k; 81 | } 82 | } 83 | dist[i*n+j] = best; 84 | idx[i*n+j] = besti; 85 | } 86 | } 87 | } 88 | 89 | 90 | void chamfer_distance_forward( 91 | const at::Tensor xyz1, 92 | const at::Tensor xyz2, 93 | const at::Tensor dist1, 94 | const at::Tensor dist2, 95 | const at::Tensor idx1, 96 | const at::Tensor idx2) 97 | { 98 | const int batchsize = xyz1.size(0); 99 | const int n = xyz1.size(1); 100 | const int m = xyz2.size(1); 101 | 102 | const float* xyz1_data = xyz1.data(); 103 | const float* xyz2_data = xyz2.data(); 104 | float* dist1_data = dist1.data(); 105 | float* dist2_data = dist2.data(); 106 | int* idx1_data = idx1.data(); 107 | int* idx2_data = idx2.data(); 108 | 109 | nnsearch(batchsize, n, m, xyz1_data, xyz2_data, dist1_data, idx1_data); 110 | nnsearch(batchsize, m, n, xyz2_data, xyz1_data, dist2_data, idx2_data); 111 | } 112 | 113 | 114 | void chamfer_distance_backward( 115 | const at::Tensor xyz1, 116 | const at::Tensor xyz2, 117 | at::Tensor gradxyz1, 118 | at::Tensor gradxyz2, 119 | at::Tensor graddist1, 120 | at::Tensor graddist2, 121 | at::Tensor idx1, 122 | at::Tensor idx2) 123 | { 124 | const int b = xyz1.size(0); 125 | const int n = xyz1.size(1); 126 | const int m = xyz2.size(1); 127 | 128 | const float* xyz1_data = xyz1.data(); 129 | const float* xyz2_data = xyz2.data(); 130 | float* gradxyz1_data = gradxyz1.data(); 131 | float* gradxyz2_data = gradxyz2.data(); 132 | float* graddist1_data = graddist1.data(); 133 | float* graddist2_data = graddist2.data(); 134 | const int* idx1_data = idx1.data(); 135 | const int* idx2_data = idx2.data(); 136 | 137 | for (int i = 0; i < b*n*3; i++) 138 | gradxyz1_data[i] = 0; 139 | for (int i = 0; i < b*m*3; i++) 140 | gradxyz2_data[i] = 0; 141 | for (int i = 0;i < b; i++) { 142 | for (int j = 0; j < n; j++) { 143 | const float x1 = xyz1_data[(i*n+j)*3+0]; 144 | const float y1 = xyz1_data[(i*n+j)*3+1]; 145 | const float z1 = xyz1_data[(i*n+j)*3+2]; 146 | const int j2 = idx1_data[i*n+j]; 147 | 148 | const float x2 = xyz2_data[(i*m+j2)*3+0]; 149 | const float y2 = xyz2_data[(i*m+j2)*3+1]; 150 | const float z2 = xyz2_data[(i*m+j2)*3+2]; 151 | const float g = graddist1_data[i*n+j]*2; 152 | 153 | gradxyz1_data[(i*n+j)*3+0] += g*(x1-x2); 154 | gradxyz1_data[(i*n+j)*3+1] += g*(y1-y2); 155 | gradxyz1_data[(i*n+j)*3+2] += g*(z1-z2); 156 | gradxyz2_data[(i*m+j2)*3+0] -= (g*(x1-x2)); 157 | gradxyz2_data[(i*m+j2)*3+1] -= (g*(y1-y2)); 158 | gradxyz2_data[(i*m+j2)*3+2] -= (g*(z1-z2)); 159 | } 160 | for (int j = 0; j < m; j++) { 161 | const float x1 = xyz2_data[(i*m+j)*3+0]; 162 | const float y1 = xyz2_data[(i*m+j)*3+1]; 163 | const float z1 = xyz2_data[(i*m+j)*3+2]; 164 | const int j2 = idx2_data[i*m+j]; 165 | const float x2 = xyz1_data[(i*n+j2)*3+0]; 166 | const float y2 = xyz1_data[(i*n+j2)*3+1]; 167 | const float z2 = xyz1_data[(i*n+j2)*3+2]; 168 | const float g = graddist2_data[i*m+j]*2; 169 | gradxyz2_data[(i*m+j)*3+0] += g*(x1-x2); 170 | gradxyz2_data[(i*m+j)*3+1] += g*(y1-y2); 171 | gradxyz2_data[(i*m+j)*3+2] += g*(z1-z2); 172 | gradxyz1_data[(i*n+j2)*3+0] -= (g*(x1-x2)); 173 | gradxyz1_data[(i*n+j2)*3+1] -= (g*(y1-y2)); 174 | gradxyz1_data[(i*n+j2)*3+2] -= (g*(z1-z2)); 175 | } 176 | } 177 | } 178 | 179 | 180 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { 181 | m.def("forward", &chamfer_distance_forward, "ChamferDistance forward"); 182 | m.def("forward_cuda", &chamfer_distance_forward_cuda, 
"ChamferDistance forward (CUDA)"); 183 | m.def("backward", &chamfer_distance_backward, "ChamferDistance backward"); 184 | m.def("backward_cuda", &chamfer_distance_backward_cuda, "ChamferDistance backward (CUDA)"); 185 | } 186 | -------------------------------------------------------------------------------- /PCN/chamfer_distance/chamfer_distance.cu: -------------------------------------------------------------------------------- 1 | #include 2 | 3 | #include 4 | #include 5 | 6 | __global__ 7 | void ChamferDistanceKernel( 8 | int b, 9 | int n, 10 | const float* xyz, 11 | int m, 12 | const float* xyz2, 13 | float* result, 14 | int* result_i) 15 | { 16 | const int batch=512; 17 | __shared__ float buf[batch*3]; 18 | for (int i=blockIdx.x;ibest){ 130 | result[(i*n+j)]=best; 131 | result_i[(i*n+j)]=best_i; 132 | } 133 | } 134 | __syncthreads(); 135 | } 136 | } 137 | } 138 | 139 | void ChamferDistanceKernelLauncher( 140 | const int b, const int n, 141 | const float* xyz, 142 | const int m, 143 | const float* xyz2, 144 | float* result, 145 | int* result_i, 146 | float* result2, 147 | int* result2_i) 148 | { 149 | ChamferDistanceKernel<<>>(b, n, xyz, m, xyz2, result, result_i); 150 | ChamferDistanceKernel<<>>(b, m, xyz2, n, xyz, result2, result2_i); 151 | 152 | cudaError_t err = cudaGetLastError(); 153 | if (err != cudaSuccess) 154 | printf("error in chamfer distance updateOutput: %s\n", cudaGetErrorString(err)); 155 | } 156 | 157 | 158 | __global__ 159 | void ChamferDistanceGradKernel( 160 | int b, int n, 161 | const float* xyz1, 162 | int m, 163 | const float* xyz2, 164 | const float* grad_dist1, 165 | const int* idx1, 166 | float* grad_xyz1, 167 | float* grad_xyz2) 168 | { 169 | for (int i = blockIdx.x; i>>(b, n, xyz1, m, xyz2, grad_dist1, idx1, grad_xyz1, grad_xyz2); 204 | ChamferDistanceGradKernel<<>>(b, m, xyz2, n, xyz1, grad_dist2, idx2, grad_xyz2, grad_xyz1); 205 | 206 | cudaError_t err = cudaGetLastError(); 207 | if (err != cudaSuccess) 208 | printf("error in chamfer distance get grad: %s\n", cudaGetErrorString(err)); 209 | } 210 | -------------------------------------------------------------------------------- /PCN/chamfer_distance/chamfer_distance.py: -------------------------------------------------------------------------------- 1 | 2 | import torch 3 | 4 | from torch.utils.cpp_extension import load 5 | cd = load(name="cd", 6 | sources=["chamfer_distance/chamfer_distance.cpp", 7 | "chamfer_distance/chamfer_distance.cu"]) 8 | 9 | class ChamferDistanceFunction(torch.autograd.Function): 10 | @staticmethod 11 | def forward(ctx, xyz1, xyz2): 12 | batchsize, n, _ = xyz1.size() 13 | _, m, _ = xyz2.size() 14 | xyz1 = xyz1.contiguous() 15 | xyz2 = xyz2.contiguous() 16 | dist1 = torch.zeros(batchsize, n) 17 | dist2 = torch.zeros(batchsize, m) 18 | 19 | idx1 = torch.zeros(batchsize, n, dtype=torch.int) 20 | idx2 = torch.zeros(batchsize, m, dtype=torch.int) 21 | 22 | if not xyz1.is_cuda: 23 | cd.forward(xyz1, xyz2, dist1, dist2, idx1, idx2) 24 | else: 25 | dist1 = dist1.cuda() 26 | dist2 = dist2.cuda() 27 | idx1 = idx1.cuda() 28 | idx2 = idx2.cuda() 29 | cd.forward_cuda(xyz1, xyz2, dist1, dist2, idx1, idx2) 30 | 31 | ctx.save_for_backward(xyz1, xyz2, idx1, idx2) 32 | 33 | return dist1, dist2 34 | 35 | @staticmethod 36 | def backward(ctx, graddist1, graddist2): 37 | xyz1, xyz2, idx1, idx2 = ctx.saved_tensors 38 | 39 | graddist1 = graddist1.contiguous() 40 | graddist2 = graddist2.contiguous() 41 | 42 | gradxyz1 = torch.zeros(xyz1.size()) 43 | gradxyz2 = torch.zeros(xyz2.size()) 44 | 45 | if not 
graddist1.is_cuda:
46 |             cd.backward(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
47 |         else:
48 |             gradxyz1 = gradxyz1.cuda()
49 |             gradxyz2 = gradxyz2.cuda()
50 |             cd.backward_cuda(xyz1, xyz2, gradxyz1, gradxyz2, graddist1, graddist2, idx1, idx2)
51 | 
52 |         return gradxyz1, gradxyz2
53 | 
54 | 
55 | class ChamferDistance(torch.nn.Module):
56 |     def forward(self, xyz1, xyz2):
57 |         return ChamferDistanceFunction.apply(xyz1, xyz2)
58 | 
--------------------------------------------------------------------------------
/PCN/data_loader.py:
--------------------------------------------------------------------------------
1 | import open3d
2 | import numpy as np
3 | from torch.utils.data.dataset import Dataset
4 | import torch
5 | 
6 | class Shapenet_dataset(Dataset):
7 |     def __init__(self,liste):
8 |         self.list = liste
9 | 
10 |     def __getitem__(self, index):
11 |         pcd_gt = torch.from_numpy(np.asarray(open3d.read_point_cloud("./shapenet/test/complete/"+self.list[index]+".pcd").points)).float()
12 |         pcd_input = self.resample_pcd(torch.from_numpy(np.asarray(open3d.read_point_cloud("./shapenet/test/partial/"+self.list[index]+".pcd").points)).float(),1024)
13 |         return (pcd_input,pcd_gt)
14 | 
15 |     def __len__(self):
16 |         return len(self.list)
17 | 
18 |     def resample_pcd(self,pcd, n):
19 |         """Drop or duplicate points so that pcd has exactly n points"""
20 |         idx = np.random.permutation(pcd.shape[0])
21 |         if idx.shape[0] < n:
22 |             idx = np.concatenate([idx, np.random.randint(pcd.shape[0], size=n-pcd.shape[0])])
23 |         return pcd[idx[:n]]
24 | 
25 | def load_data(path):
26 |     with open(path+"/test.list") as file:
27 |         model_list = file.read().splitlines()
28 |     return Shapenet_dataset(model_list)
29 | 
30 | if __name__ == '__main__':
31 |     load_data("shapenet")
32 | 
--------------------------------------------------------------------------------
/PCN/example.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/PCN/example.png
--------------------------------------------------------------------------------
/PCN/main.py:
--------------------------------------------------------------------------------
1 | import open3d
2 | import models
3 | import data_loader
4 | import torch.optim as optim
5 | import torch
6 | import numpy as np
7 | from torch.utils.tensorboard import SummaryWriter
8 | import tqdm
9 | 
10 | model = models.PCNEMD().cuda()
11 | # model = torch.load("./models/model_10000.ckpt").cuda()
12 | alpha = [0.01, 0.1, 0.5, 1.0]
13 | lr = 1e-6
14 | optimizer = optim.Adam([{'params': model.parameters(), 'lr': lr}])
15 | dataset = data_loader.load_data("shapenet")
16 | my_dataset_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=1,shuffle=False)
17 | epochs = 70000
18 | writer = SummaryWriter()
19 | a = 0
20 | for p in tqdm.tqdm(range(0,epochs)):
21 |     lost = []
22 |     for input_tensor,gt_tensor in my_dataset_loader:
23 |         optimizer.zero_grad()
24 |         input_tensor = input_tensor.cuda()
25 |         gt_tensor = gt_tensor.cuda()
26 |         coarse,fine = model(input_tensor)
27 |         loss = model.create_loss(coarse,fine,gt_tensor,alpha[a])
28 |         loss.backward()
29 |         lost.append(loss.data.item())
30 |         optimizer.step()
31 | 
32 |     if(p%10000==0 and p!=0):
33 |         torch.save(model, "./models/model_"+str(p)+".ckpt")
34 |         model.cuda()
35 |         lr = lr/10
36 |         for param_group in optimizer.param_groups:
37 |             param_group['lr'] = lr
38 | 
39 |     if(p%10==0):
40 |         writer.add_scalar('Loss',np.array(lost).mean(), p)
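    # Staged schedule from the comment in models.py ([10000, 20000, 50000],
    # [0.01, 0.1, 0.5, 1.0]): alpha weights the dense chamfer loss and is
    # stepped up below as a goes 0 -> 3 over the course of training.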
41 |     if p in (10000, 20000, 50000):
42 |         a = a + 1
43 |         pcd = open3d.PointCloud()
44 |         pcd.points = open3d.Vector3dVector(fine.data.cpu().numpy()[0]+np.array([1.0,0.0,0.0]))
45 |         pcd.colors = open3d.Vector3dVector(np.ones((fine.shape[1],3))* [0.76,0.23,0.14])
46 | 
47 |         pcd2 = open3d.PointCloud()
48 |         pcd2.points = open3d.Vector3dVector(gt_tensor.data.cpu().numpy()[0]+np.array([-1.0,0.0,0.0]))
49 |         pcd2.colors = open3d.Vector3dVector(np.ones((gt_tensor.shape[1],3))* [0.16,0.23,0.14])
50 | 
51 |         pcd3 = open3d.PointCloud()
52 |         pcd3.points = open3d.Vector3dVector(input_tensor.data.cpu().numpy()[0])
53 |         pcd3.colors = open3d.Vector3dVector(np.ones((input_tensor.shape[1],3))* [0.16,0.23,0.14])
54 |         open3d.draw_geometries([pcd,pcd2,pcd3])
55 | 
--------------------------------------------------------------------------------
/PCN/models.py:
--------------------------------------------------------------------------------
1 | import open3d
2 | import torch
3 | import torch.nn as nn
4 | import torch.nn.init as init
5 | import torch.nn.functional as F
6 | import numpy as np
7 | from time import time
8 | from emd import earth_mover_distance
9 | from chamfer_distance import ChamferDistance
10 | 
11 | chamfer_dist = ChamferDistance()
12 | 
13 | class PCNEMD(nn.Module):
14 |     def __init__(self):
15 |         super(PCNEMD, self).__init__()
16 |         self.num_coarse = 1024
17 |         self.grid_size = 4
18 |         self.grid_scale = 0.05
19 |         self.num_fine = self.grid_size ** 2 * self.num_coarse
20 |         self.npts = [1]
21 |         # alpha schedule: [10000, 20000, 50000], [0.01, 0.1, 0.5, 1.0]
22 |         #### ENCODER
23 | 
24 |         ## first mlp
25 |         mlps1 = [128, 256]
26 |         first_mlp_list = []
27 |         in_features = 3
28 |         for m in range(0,len(mlps1)-1):
29 |             first_mlp_list.append(nn.Conv1d(in_features, mlps1[m], 1))
30 |             first_mlp_list.append(nn.ReLU())
31 |             in_features = mlps1[m]
32 |         first_mlp_list.append(nn.Conv1d(in_features, mlps1[-1], 1))
33 |         self.first_mpl = nn.Sequential(*first_mlp_list)
34 | 
35 | 
36 |         ## Second mlp
37 |         mlps2 = [512, 1024]
38 |         second_mlp_list = []
39 |         in_features = 512
40 |         for m in range(0,len(mlps2)-1):
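            # Encoder stage 2: in_features is 512 because the 256-d per-point
            # features from the first MLP get concatenated with the max-pooled
            # 256-d global feature before this MLP lifts them to 1024-d.
            # (NOTE: torch.max takes `keepdim`, not numpy-style `keepdims`, so
            # point_maxpool below may need that spelling on current PyTorch.)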
second_mlp_list.append(nn.Conv1d(in_features, mlps2[m], 1)) 42 | second_mlp_list.append(nn.ReLU()) 43 | in_features = mlps2[m] 44 | second_mlp_list.append(nn.Conv1d(in_features, mlps2[-1], 1)) 45 | self.second_mpl = nn.Sequential(*second_mlp_list) 46 | 47 | 48 | #### DECODER 49 | coarse1 = [1024,1024,self.num_coarse*3] 50 | in_features = 1024 51 | decoder_list = [] 52 | for m in range(0,len(coarse1)-1): 53 | decoder_list.append(nn.Linear(in_features, coarse1[m])) 54 | in_features = coarse1[m] 55 | decoder_list.append(nn.Linear(in_features, coarse1[-1])) 56 | self.decoder = nn.Sequential(*decoder_list) 57 | 58 | ## FOLDING 59 | mlpsfold = [512, 512,3] 60 | fold_mlp_list = [] 61 | in_features = 1029 62 | for m in range(0,len(mlpsfold)-1): 63 | fold_mlp_list.append(nn.Conv1d(in_features, mlpsfold[m], 1)) 64 | fold_mlp_list.append(nn.ReLU()) 65 | in_features = mlpsfold[m] 66 | fold_mlp_list.append(nn.Conv1d(in_features, mlpsfold[-1], 1)) 67 | self.fold_mpl = nn.Sequential(*fold_mlp_list) 68 | 69 | def point_maxpool(self,features,npts,keepdims=True): 70 | # splitted = torch.split(features,npts[0],dim=1) 71 | # outputs = [torch.max(f,dim=2,keepdims=keepdims)[0] for f in splitted] 72 | # return torch.cat(outputs,dim=0) 73 | return torch.max(features,dim=2,keepdims=keepdims)[0] 74 | 75 | 76 | def point_unpool(self,features,npts): 77 | # features = torch.split(features,features.shape[0],dim=0) 78 | # outputs = [f.repeat([1,npts[i],1]) for i,f in enumerate(features)] 79 | # return torch.cat(outputs,dim=1) 80 | return features.repeat([1,1,256]) 81 | 82 | 83 | def forward(self, xyz): 84 | xyz = xyz.permute(0,2,1) 85 | #####ENCODER 86 | features = self.first_mpl(xyz) 87 | features_global = self.point_maxpool(features.permute(0,2,1),self.npts,keepdims=True) 88 | features_global = self.point_unpool(features_global,self.npts) 89 | 90 | features = torch.cat([features,features_global.permute(0,2,1)],dim=1) 91 | features = self.second_mpl(features) 92 | features = self.point_maxpool(features.permute(0,2,1),self.npts).squeeze(2) 93 | 94 | ##DECODER 95 | coarse = self.decoder(features) 96 | coarse = coarse.view(-1,self.num_coarse,3) 97 | 98 | ##FOLDING 99 | grid_row = torch.linspace(-0.05,0.05,self.grid_size).cuda() 100 | grid_column = torch.linspace(-0.05,0.05,self.grid_size).cuda() 101 | grid = torch.meshgrid(grid_row,grid_column) 102 | grid = torch.reshape(torch.stack(grid,dim=2),(-1,2)).unsqueeze(0) 103 | grid_feat = grid.repeat([features.shape[0],self.num_coarse,1]) 104 | # print("grid_Feat",grid_feat.shape) 105 | 106 | point_feat = coarse.unsqueeze(2).repeat([1,1,self.grid_size**2,1]) 107 | point_feat = torch.reshape(point_feat, [-1,self.num_fine,3]) 108 | # print("point_Feat",point_feat.shape) 109 | global_feat = features.unsqueeze(1).repeat([1,self.num_fine,1]) 110 | # print("global_Feat",global_feat.shape) 111 | feat = torch.cat([grid_feat,point_feat,global_feat],dim=2) 112 | 113 | center = coarse.unsqueeze(2).repeat([1,1,self.grid_size**2,1]) 114 | center = torch.reshape(center, [-1,self.num_fine,3]) 115 | 116 | fine = self.fold_mpl(feat.permute(0,2,1)) 117 | # print("fine shape",fine.shape," center shape",center.shape) 118 | fine = fine.permute(0,2,1) + center 119 | 120 | return coarse, fine 121 | 122 | def create_loss(self,coarse,fine,gt,alpha): 123 | gt_ds = gt[:,:coarse.shape[1],:] 124 | loss_coarse = earth_mover_distance(coarse, gt_ds, transpose=False) 125 | dist1, dist2 = chamfer_dist(fine, gt) 126 | loss_fine = (torch.mean(dist1)) + (torch.mean(dist2)) 127 | 128 | loss = loss_coarse + alpha * 
loss_fine 129 | 130 | return loss 131 | 132 | 133 | 134 | 135 | if __name__ == '__main__': 136 | # alpha [ 0.01,0.1,0.5,1.0] 137 | for i in range(10): 138 | xyz = torch.rand(1, 1024,3).cuda() 139 | pcd1 = open3d.PointCloud() 140 | pcd1.points = open3d.Vector3dVector(xyz.data.cpu().numpy()[0]) 141 | pcd1.colors = open3d.Vector3dVector(np.ones((1024,3))* [0.00,0.53,0.90]) 142 | colors = torch.rand(1, 2048,3).cuda() 143 | net = PCNEMD() 144 | net.cuda() 145 | coarse, fine = net(xyz) 146 | net.create_loss(coarse,fine,xyz,1.0) 147 | 148 | pcd = open3d.PointCloud() 149 | pcd.points = open3d.Vector3dVector(coarse.data.cpu().numpy()[0]+np.array([1.0,0.0,0.0])) 150 | pcd.colors = open3d.Vector3dVector(np.ones((1024,3))* [0.76,0.23,0.14]) 151 | 152 | pcd2 = open3d.PointCloud() 153 | pcd2.points = open3d.Vector3dVector(fine.data.cpu().numpy()[0]+np.array([-1.0,0.0,0.0])) 154 | pcd2.colors = open3d.Vector3dVector(np.ones((fine.shape[1],3))* [0.16,0.53,0.44]) 155 | open3d.draw_geometries([pcd,pcd1,pcd2]) 156 | exit() 157 | -------------------------------------------------------------------------------- /PCN/test.py: -------------------------------------------------------------------------------- 1 | import open3d 2 | import models 3 | import data_loader 4 | import torch.optim as optim 5 | import torch 6 | import numpy as np 7 | from torch.utils.tensorboard import SummaryWriter 8 | 9 | 10 | model = torch.load("./models/model_60000.ckpt").cuda() 11 | dataset = data_loader.load_data("shapenet") 12 | my_dataset_loader = torch.utils.data.DataLoader(dataset=dataset,batch_size=1,shuffle=False) 13 | for input_tensor,gt_tensor in my_dataset_loader: 14 | input_tensor =input_tensor.cuda() 15 | gt_tensor = gt_tensor.cuda() 16 | coarse,fine = model(input_tensor) 17 | print(coarse.shape,fine.shape,gt_tensor.shape) 18 | pcd = open3d.PointCloud() 19 | pcd.points = open3d.Vector3dVector(fine.data.cpu().numpy()[0]+np.array([1.0,0.0,0.0])) 20 | pcd.colors = open3d.Vector3dVector(np.ones((fine.shape[1],3))* [0.76,0.23,0.14]) 21 | 22 | pcd2 = open3d.PointCloud() 23 | pcd2.points = open3d.Vector3dVector(gt_tensor.data.cpu().numpy()[0]+np.array([-1.0,0.0,0.0])) 24 | pcd2.colors = open3d.Vector3dVector(np.ones((gt_tensor.shape[1],3))* [0.16,0.23,0.14]) 25 | 26 | pcd3 = open3d.PointCloud() 27 | pcd3.points = open3d.Vector3dVector(input_tensor.data.cpu().numpy()[0]) 28 | pcd3.colors = open3d.Vector3dVector(np.ones((input_tensor.shape[1],3))* [0.16,0.23,0.14]) 29 | open3d.draw_geometries([pcd,pcd2,pcd3]) 30 | -------------------------------------------------------------------------------- /PointCNN/PointCNN.py: -------------------------------------------------------------------------------- 1 | from Utils import knn_indices_func_cpu 2 | import torch 3 | from torch.autograd import Variable 4 | import numpy as np 5 | import torch.nn as nn 6 | from torch import cuda, FloatTensor, LongTensor 7 | from xConv import XConv,Dense 8 | from collections import OrderedDict 9 | 10 | 11 | class PointCnnLayer(nn.Module): 12 | def __init__(self,pc,features,settings): 13 | """ 14 | :param points: Points cloud 15 | :param features: features as input, can be None if nothing known 16 | :param settings: Settings of the network, inside there are: 17 | :setting xcon parameters : parameters for the xconvolutions : xconv_param_name = ('K', 'D', 'P', 'C') "C = C_out" 8, 1, -1, 32 * x 18 | : setting fc parameters : parameters for the fully convolutional part of the network : fc_param_name = ('C', 'dropout_rate') 19 | """ 20 | 
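        # X-Conv hyper-parameters per layer (see xconv_param_name in __main__):
        # K = number of neighbours, D = dilation used when selecting them,
        # P = representative points kept by the layer (-1 keeps all points),
        # C = output channels.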
super(PointCnnLayer,self).__init__() 21 | N = pc 22 | #print("There are " + str(N) + " points in the begining") 23 | with_X_transformation = True ## investigating that 24 | sorting_method = None ## sorting or not points along a dimension 25 | sampling = 'fps' ## investigating that 26 | self.nb_xconv = len(settings[0]) 27 | self.settings = settings 28 | 29 | #C_mid = C_out // 2 if C_in == 0 else C_out // 4 30 | #depth_multiplier = min(int(np.ceil(C_out / C_in)), 4)]) 31 | self.xvonc1 = XConv(C_in = 0, 32 | C_out = settings[0][0].get('C'), 33 | dims = 3, 34 | K = settings[0][0].get('K'), 35 | P = N, 36 | C_mid = settings[0][0].get('C') //2 , 37 | depth_multiplier = 4) ## First XConvolution 38 | 39 | self.dense1 = Dense(settings[0][0].get('C'),settings[0][1].get('C')//2,drop=0) 40 | self.xvonc2 = XConv(C_in = settings[0][1].get('C')//2, 41 | C_out = settings[0][1].get('C'), 42 | dims = 3, 43 | K = settings[0][1].get('K'), 44 | P = self.settings[0][1].get('P'), 45 | C_mid = settings[0][1].get('C') // 4 , 46 | depth_multiplier = settings[0][0].get('C')//4) ## Second XConvolution 47 | 48 | self.dense2 = Dense(settings[0][1].get('C'),settings[0][2].get('C')//2,drop=0) 49 | self.xvonc3 = XConv(C_in = settings[0][2].get('C')//2, 50 | C_out = settings[0][2].get('C'), 51 | dims = 3, 52 | K = settings[0][2].get('K'), 53 | P = self.settings[0][2].get('P'), 54 | C_mid = settings[0][2].get('C') // 4 , 55 | depth_multiplier = settings[0][1].get('C')//4) ## Third XConvolution 56 | 57 | self.dense3 = Dense(settings[0][2].get('C'),settings[0][3].get('C')//2,drop=0) 58 | self.xvonc4 = XConv(C_in = settings[0][3].get('C')//2, 59 | C_out = settings[0][3].get('C'), 60 | dims = 3, 61 | K = settings[0][3].get('K'), 62 | P = self.settings[0][3].get('P'), 63 | C_mid = settings[0][3].get('C') // 4 , 64 | depth_multiplier = settings[0][2].get('C')//4) ## Third XConvolution 65 | self.layers_conv =[self.xvonc1,self.xvonc2,self.xvonc3,self.xvonc4] 66 | ## deconvolution inputs : pts (output of previous conv/deconv/), fts (output of previous conv/deconv/), N , K, D, P (of concatenated layer output), 67 | #C (of concatenated layer output),C_prev (C of previous layer ) // 4 , depth_multiplier = 1 68 | # xdconv_param_name = ('K', 'D', 'pts_layer_idx', 'qrs_layer_idx') 69 | #print(self.layers_conv) 70 | deconvolutions = OrderedDict() 71 | dense_deconv = OrderedDict() 72 | fc = OrderedDict() 73 | for i in range(0,len(settings[1])): 74 | deconvolutions["deconv" + str(i)] = XConv(C_in = self.layers_conv[settings[1][i].get('pts_layer_idx')].C , 75 | C_out = self.layers_conv[settings[1][i].get('qrs_layer_idx')].C, 76 | dims = 3, 77 | K = settings[1][i].get('K'), 78 | P = self.layers_conv[settings[1][i].get('qrs_layer_idx')].P, 79 | C_mid = self.layers_conv[settings[1][i].get('pts_layer_idx')].C//4 , 80 | depth_multiplier = 1) ## First dXConvolution 81 | 82 | dense_deconv["dense" + str(i)] = Dense(self.layers_conv[settings[1][i].get('qrs_layer_idx')].C*2,self.layers_conv[settings[1][i].get('qrs_layer_idx')].C) 83 | self.deconvolutions = nn.Sequential(deconvolutions) 84 | self.dense_deconv = nn.Sequential(dense_deconv) 85 | 86 | 87 | for i in range(0,len(settings[2])): 88 | fc["fc" + str(i)] = Dense(self.dense_deconv[-1].out if i==0 else fc["fc" + str(i-1)].out, 89 | settings[2][i].get('C'), 90 | drop=settings[2][i].get('dropout_rate'), 91 | acti=True) 92 | 93 | self.fc = nn.Sequential(fc) 94 | 95 | 96 | self.sigmoid = nn.Sigmoid() 97 | def forward(self,x): 98 | layer_pts = [x] 99 | outs = [None] 100 | pts_regionals = [] 101 | 
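        # Decoder bookkeeping: each conv stage appends its representative points
        # and outputs to these lists so the X-deconv stages can index back into
        # them via pts_layer_idx / qrs_layer_idx for the skip connections.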
fts_regionals = [] 102 | #print("First CONVOLUTION") 103 | if(self.settings[0][0].get('P')!=-1): 104 | idxx = np.random.choice(x.size()[1], self.settings[0][0].get('P'), replace = False).tolist() ## select representative points 105 | rep_pts = x[:,idxx,:] 106 | else: 107 | rep_pts = x 108 | pts_idx = knn_indices_func_cpu(rep_pts,x,self.settings[0][0].get('K'),self.settings[0][0].get('D') ) 109 | #pts_idx = pts_idx[:,::self.settings[0][0].get('D'),:] 110 | pts_regional = torch.stack([x[n][idx,:] for n, idx in enumerate(torch.unbind(pts_idx, dim = 0))], dim = 0) 111 | out = self.xvonc1(rep_pts,pts_regional,None) ## FTS 112 | layer_pts.append(rep_pts) 113 | outs.append(out) 114 | pts_regionals.append(pts_regional) 115 | fts_regionals.append(None) 116 | 117 | #print("SECOND CONVOLUTION") 118 | if(not (self.settings[0][1].get('P')==-1)): 119 | # print("been there" + str(self.settings[0][1].get('P'))) 120 | idxx = np.random.choice(rep_pts.size()[1], self.settings[0][1].get('P'), replace = False).tolist() ## select representative points 121 | rep_pts2 = rep_pts[:,idxx,:] 122 | else: 123 | rep_pts2 = rep_pts 124 | fts = self.dense1(out) 125 | pts_idx = knn_indices_func_cpu(rep_pts2,rep_pts,self.settings[0][1].get('K') ,self.settings[0][1].get('D')) 126 | #pts_idx = pts_idx[:,:,::self.settings[0][1].get('D')] 127 | pts_regional = torch.stack([rep_pts[n][idx,:] for n, idx in enumerate(torch.unbind(pts_idx, dim = 0))], dim = 0) 128 | fts_regional = torch.stack([ fts[n][idx,:] for n, idx in enumerate(torch.unbind(pts_idx, dim = 0))], dim = 0) 129 | out2 = self.xvonc2(rep_pts2,pts_regional,fts_regional) 130 | layer_pts.append(rep_pts2) 131 | outs.append(out2) 132 | pts_regionals.append(pts_regional) 133 | fts_regionals.append(fts_regional) 134 | 135 | #print("THIRD CONVOLUTION") 136 | if(not (self.settings[0][2].get('P')==-1)): 137 | #print("been there" + str(self.settings[0][1].get('P'))) 138 | idxx = np.random.choice(rep_pts2.size()[1], self.settings[0][2].get('P'), replace = False).tolist() ## select representative points 139 | rep_pts3 = rep_pts2[:,idxx,:] 140 | else: 141 | rep_pts3 = rep_pts2 142 | fts = self.dense2(out2) 143 | pts_idx = knn_indices_func_cpu(rep_pts3,rep_pts2,self.settings[0][2].get('K') ,self.settings[0][2].get('D')) 144 | #pts_idx = pts_idx[:,:,::self.settings[0][2].get('D')] 145 | pts_regional = torch.stack([rep_pts2[n][idx,:] for n, idx in enumerate(torch.unbind(pts_idx, dim = 0))], dim = 0) 146 | fts_regional = torch.stack([ fts[n][idx,:] for n, idx in enumerate(torch.unbind(pts_idx, dim = 0))], dim = 0) 147 | out3 = self.xvonc3(rep_pts3,pts_regional,fts_regional) 148 | layer_pts.append(rep_pts3) 149 | outs.append(out3) 150 | pts_regionals.append(pts_regional) 151 | fts_regionals.append(fts_regional) 152 | 153 | #print("FOURTH CONVOLUTION") 154 | if(not (self.settings[0][3].get('P')==-1)): 155 | # print("been there" + str(self.settings[0][1].get('P'))) 156 | idxx = np.random.choice(rep_pts3.size()[1], self.settings[0][3].get('P'), replace = False).tolist() ## select representative points 157 | rep_pts4 = rep_pts3[:,idxx,:] 158 | else: 159 | rep_pts4 = rep_pts3 160 | fts = self.dense3(out3) 161 | #print("dimensions rep pts : " +str(rep_pts4.shape) +" : " + str(rep_pts3.shape)) 162 | #print("inputs " + str(rep_pts4.shape) + " : " + str(rep_pts3.shape)) 163 | pts_idx = knn_indices_func_cpu(rep_pts4,rep_pts3,self.settings[0][3].get('K'),self.settings[0][3].get('D')) 164 | #pts_idx = pts_idx[:,:,::self.settings[0][3].get('D')] 165 | pts_regional = torch.stack([rep_pts3[n][idx,:] 
for n, idx in enumerate(torch.unbind(pts_idx, dim = 0))], dim = 0) 166 | fts_regional = torch.stack([ fts[n][idx,:] for n, idx in enumerate(torch.unbind(pts_idx, dim = 0))], dim = 0) 167 | out4 = self.xvonc4(rep_pts4,pts_regional,fts_regional) 168 | layer_pts.append(rep_pts4) 169 | outs.append(out4) 170 | pts_regionals.append(pts_regional) 171 | fts_regionals.append(fts_regional) 172 | 173 | ############################END CONVOLUTION, START DECONVOLUTIONS #################### 174 | 175 | for i in range(0,len(self.deconvolutions)): 176 | 177 | #print("DECONVOLUTION " + str(i)) 178 | this_out = outs[self.settings[1][i].get('pts_layer_idx')+1] if i == 0 else outs[-1] 179 | rep = layer_pts[self.settings[1][i].get('qrs_layer_idx')+1] 180 | rep2 = layer_pts[self.settings[1][i].get('pts_layer_idx')+1] 181 | #print("dimensions rep pts : " +str(rep.shape) + " : "+ str(rep2.shape) ) 182 | pts_idx = knn_indices_func_cpu(rep, 183 | rep2, 184 | self.settings[1][i].get('K'), 185 | self.settings[1][i].get('D')) 186 | #pts_idx = pts_idx[:,:,::self.settings[0][3].get('D')] 187 | pts_regional = torch.stack([rep2[n][idx,:] for n, idx in enumerate(torch.unbind(pts_idx, dim = 0))], dim = 0) 188 | this_out = torch.stack([ this_out[n][idx,:] for n, idx in enumerate(torch.unbind(pts_idx, dim = 0))], dim = 0) 189 | #print("features in : " + str( this_out.shape)) 190 | out = self.deconvolutions[i](rep, pts_regional,this_out) 191 | 192 | out = torch.cat((out,outs[self.settings[1][i].get('qrs_layer_idx')+1]),-1) 193 | #print("OUT CONTATENATED : " + str(out.shape)) 194 | densed = self.dense_deconv[i](out) 195 | #print("DENSED : "+ str(densed.shape)) 196 | outs.append( densed ) 197 | 198 | ############################END DECONVOLUTIONS, START Fully connected #################### 199 | #print("FULLY_CONNECTED") 200 | output = outs[-1] 201 | for i in range(0,len(self.fc)): 202 | output = self.fc[i](output) 203 | 204 | return output # self.sigmoid(output) 205 | 206 | 207 | if __name__ == "__main__": 208 | x = 8 209 | xyz = torch.rand(1, 30000,3).cuda() 210 | xconv_param_name = ('K', 'D', 'P', 'C') 211 | xconv_params = [dict(zip(xconv_param_name, xconv_param)) for xconv_param in 212 | [(8, 1, -1, 256), 213 | (12, 2, 768, 256), 214 | (16, 2, 384, 512), 215 | (16, 4, 128, 1024)]] 216 | 217 | xdconv_param_name = ('K', 'D', 'pts_layer_idx', 'qrs_layer_idx') 218 | xdconv_params = [dict(zip(xdconv_param_name, xdconv_param)) for xdconv_param in 219 | [(16, 4, 3, 3), 220 | (16, 2, 3, 2), 221 | (12, 2, 2, 1), 222 | (8, 2, 1, 0)]] 223 | 224 | fc_param_name = ('C', 'dropout_rate') 225 | fc_params = [dict(zip(fc_param_name, fc_param)) for fc_param in 226 | [(32 * x, 0.0), 227 | (32 * x, 0.5), 228 | (2,0.5)]] 229 | 230 | model = PointCnnLayer(xyz,["features"],[ xconv_params,xdconv_params,fc_params ]).cuda() 231 | #print(model) 232 | out = model(xyz) 233 | print(out.shape) 234 | -------------------------------------------------------------------------------- /PointCNN/Utils.py: -------------------------------------------------------------------------------- 1 | import open3d 2 | import torch 3 | from torch.autograd import Variable 4 | import numpy as np 5 | import torch.nn as nn 6 | import torch.optim as optim 7 | from torch.utils import data 8 | from torch import cuda, FloatTensor, LongTensor 9 | from xConv import XConv,Dense 10 | import sys 11 | 12 | from sklearn.neighbors import NearestNeighbors 13 | #import matplotlib.pyplot as plt 14 | params = {'batch_size': 5, 15 | 'shuffle': False, 16 | 'num_workers': 4} 17 | def 
knn_indices_func_cpu(rep_pts : FloatTensor, # (N, pts, dim) 18 | pts : FloatTensor, # (N, x, dim) 19 | K : int, D : int 20 | ) -> LongTensor: # (N, pts, K) 21 | """ 22 | CPU-based indexing function based on K-nearest-neighbors search. 23 | :param rep_pts: Representative points. 24 | :param pts: Point cloud to get indices from. 25 | :param K: Number of nearest neighbors to collect. 26 | :param D: Dilation factor. 27 | :return: Array of indices, P_idx, into pts such that pts[n][P_idx[n],:] 28 | is the set of K nearest neighbors of the representative points in rep_pts[n]. 29 | """ 30 | if rep_pts.is_cuda: 31 | rep_pts = rep_pts.cpu() 32 | if pts.is_cuda: 33 | pts = pts.cpu() 34 | rep_pts = rep_pts.data.numpy() 35 | pts = pts.data.numpy() 36 | 37 | region_idx = [] 38 | 39 | for n, p in enumerate(rep_pts): 40 | P_particular = pts[n] 41 | nbrs = NearestNeighbors(n_neighbors = D*K + 1, algorithm = "auto").fit(P_particular) 42 | indices = nbrs.kneighbors(p)[1] 43 | region_idx.append(indices[:,1::D]) # skip the first neighbour (the query point itself) and keep every D-th one 44 | 45 | region_idx = torch.from_numpy(np.stack(region_idx, axis = 0)) 46 | 47 | return region_idx 48 | -------------------------------------------------------------------------------- /PointCNN/xConv.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch.autograd import Variable 3 | import numpy as np 4 | import torch.nn as nn 5 | from torch import cuda, FloatTensor, LongTensor 6 | from typing import Tuple, Callable, Optional, Union 7 | 8 | 9 | class Dense(nn.Module): 10 | def __init__(self,inn,out,drop=0,acti=True): 11 | super(Dense,self).__init__() 12 | self.inn = inn 13 | self.out = out 14 | self.acti = acti 15 | self.drop = drop 16 | self.linear = nn.Linear(inn,out) 17 | self.elu = nn.ELU() 18 | if(self.drop>0): 19 | self.dropout = nn.Dropout(drop) 20 | 21 | def forward(self,x): 22 | out = self.linear(x.float()) 23 | if(self.acti): 24 | out = self.elu(out) 25 | if (self.drop>0): 26 | out = self.dropout(out) 27 | return out 28 | 29 | 30 | 31 | 32 | class XConv (nn.Module): 33 | 34 | def __init__(self, C_in : int, C_out : int, dims : int, K : int, P : int, C_mid : int, depth_multiplier : int) : 35 | """ 36 | :param C_in: Input dimension of the points' features. 37 | :param C_out: Output dimension of the representative point features. 38 | :param dims: Spatial dimensionality of points. 39 | :param K: Number of neighbors to convolve over. 40 | :param P: Number of representative points. 41 | :param C_mid: Dimensionality of lifted point features. 42 | :param depth_multiplier: Depth multiplier for internal depthwise separable convolution. 43 | """ 44 | super(XConv, self).__init__() 45 | self.dense1 = Dense(dims, C_mid) 46 | self.dense2 = Dense(C_mid, C_mid) 47 | self.C_in = C_in 48 | self.C = C_out 49 | self.P = P 50 | self.K = K 51 | ### learn the K*K transformation matrix X from the local coordinates ### 52 | #x = x.permute(0,3,1,2)# 53 | self.conv1 = nn.Sequential( 54 | nn.Conv2d(dims, K*K, (1, K), bias = True), 55 | nn.ELU() 56 | ) 57 | #x = x.permute(0,2,3,1)# 58 | self.x_dense1 = Dense(K*K, K*K) 59 | self.x_dense2 = Dense(K*K, K*K, acti = False) 60 | 61 | ### final convolution ### 62 | #x = x.permute(0,3,1,2)# 63 | self.conv2 = nn.Sequential( 64 | nn.Conv2d(C_mid + C_in, (C_mid + C_in) * depth_multiplier, (1, K), groups = C_mid + C_in), 65 | nn.Conv2d( (C_mid + C_in) * depth_multiplier, C_out, 1, bias = True), 66 | nn.ELU(), 67 | nn.BatchNorm2d(C_out, momentum = 0.9) 68 | ) 69 | #x = x.permute(0,2,3,1)# 70 | def forward(self, rep_pt,pts,fts): 71 | """ 72 | Applies XConv to the input data. 
73 | :param rep_pt: Representative points, shape (N, P, dims). 74 | :param pts: Regional point cloud such that pts[:,p_idx,:] is the point 75 | associated with the feature fts[:,p_idx,:]. 76 | :param fts: Regional features such that fts[:,p_idx,:] is the feature 77 | vector of the point pts[:,p_idx,:] (None at the first layer). 78 | :return: Features aggregated into each representative point, shape (N, P, C_out). 79 | """ 83 | 84 | N = len(pts) 85 | P = rep_pt.shape[1] # pts has shape (N, P, K, dims) 86 | p_center = torch.unsqueeze(rep_pt, dim = 2) # (N, P, 1, dims) 87 | ##FIRST STEP : Move pts to local coordinates of the reference point ## 88 | 89 | pts_local = pts - p_center # (N, P, K, dims) 90 | 91 | 92 | 93 | 94 | ##SECOND STEP : We lift every point individually to C_mid space 95 | fts_lifted0 = self.dense1(pts_local) 96 | fts_lifted = self.dense2(fts_lifted0) # (N, P, K, C_mid) 97 | ## THIRD STEP : We check whether features were given as input (i.e. not the first layer) and concatenate F_sigma (the lifted features) with the previous features 98 | if fts is None: 99 | fts_cat = fts_lifted 100 | else: 101 | 102 | fts_cat = torch.cat((fts_lifted, fts), -1) # (N, P, K, C_mid + C_in) 103 | 104 | ##FOURTH STEP : We need to learn the transformation matrix 105 | X_shape = (N, P, self.K, self.K) 106 | X = pts_local.permute(0,3,1,2) 107 | X = self.conv1(X) 108 | X = X.permute(0,2,3,1) 109 | X = self.x_dense1(X) 110 | X = self.x_dense2(X) 111 | 112 | X = X.view(*X_shape) # one K x K transformation matrix per representative point 113 | 114 | 115 | ## FIFTH STEP : we weight and permute F* with X 116 | fts_X = torch.matmul(X, fts_cat) 117 | 118 | #SIXTH STEP : Last convolution giving us the output of the X convolution 119 | X2 = fts_X.permute(0,3,1,2) # (N, C_mid + C_in, P, K) 120 | X2 = self.conv2(X2) # depthwise-separable conv over the K neighbors -> (N, C_out, P, 1) 121 | X2 = X2.permute(0,2,3,1) # (N, P, 1, C_out) 122 | fts_p = X2.squeeze(dim = 2) # (N, P, C_out) 123 | 133 | return fts_p 134 | 135 | 136 | ### TEST THE X CONVOLUTION ### 137 | if __name__ == "__main__": 138 | N = 4 139 | D = 3 140 | C_in = 8 141 | C_out = 32 142 | N_neighbors = 100 143 | convo = XConv(4,8,2,10,1,1000,10).cuda() 144 | print(convo) 145 | -------------------------------------------------------------------------------- /PointSift/sift.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0,'..') 3 | import math 4 | import open3d 5 | import torch 6 | import torch.nn as nn 7 | import torch.nn.functional as F 8 | import point 9 | import numpy as np 10 | import time 11 | import torch.optim as optim 12 | from Utils.net_utils import * 13 | 14 | 15 | def conv_bn(inp, oup, kernel, stride=1, activation='relu'): 16 | seq = nn.Sequential( 17 | nn.Conv2d(inp, oup, kernel, stride), 18 | nn.BatchNorm2d(oup) 19 | ) 20 | if activation == 'relu': 21 | seq.add_module('2', nn.ReLU()) 22 | return seq 23 | 24 | def conv1d_bn(inp, oup, kernel, stride=1, activation='relu'): 25 | seq = 
nn.Sequential( 26 | nn.Conv1d(inp, oup, kernel, stride), 27 | nn.BatchNorm1d(oup) 28 | ) 29 | if activation == 'relu': 30 | seq.add_module('2', nn.ReLU()) 31 | return seq 32 | 33 | 34 | def fc_bn(inp, oup): 35 | return nn.Sequential( 36 | nn.Linear(inp, oup), 37 | nn.BatchNorm1d(oup), 38 | nn.ReLU() 39 | ) 40 | 41 | class PointNet_SA_module_basic(nn.Module): 42 | def __init__(self): 43 | super(PointNet_SA_module_basic, self).__init__() 44 | 45 | def index_points(self,points, idx): 46 | """ 47 | Input: 48 | points: input points data, [B, N, C] 49 | idx: sample index data, [B, D1, D2, ..., Dn] 50 | Return: 51 | new_points: indexed points data, [B, D1, D2, ..., Dn, C] 52 | """ 53 | device = points.device 54 | B = points.shape[0] 55 | view_shape = list(idx.shape) 56 | view_shape[1:] = [1] * (len(view_shape) - 1) 57 | repeat_shape = list(idx.shape) 58 | repeat_shape[0] = 1 59 | batch_indices = torch.arange(B, dtype=torch.long).view(view_shape).repeat(repeat_shape) 60 | new_points = points[batch_indices, idx, :] 61 | return new_points 62 | 63 | def square_distance(self, src, dst): 64 | """ 65 | Description: 66 | the squared Euclidean distance ||x - y||^2, expanded as x^2 - 2xy + y^2 67 | Input: 68 | src: source points, [B, N, C] 69 | dst: target points, [B, M, C] 70 | Output: 71 | dist: per-point squared distance, [B, N, M] 72 | """ 73 | B, N, _ = src.shape 74 | _, M, _ = dst.shape 75 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1).contiguous()) 76 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 77 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 78 | return dist 79 | 80 | def group_points(self,xyz,idx): 81 | b , n , c = xyz.shape 82 | m = idx.shape[1] 83 | nsample = idx.shape[2] 84 | out = torch.zeros((b, m, nsample, c)).cuda() # m groups of nsample points each 85 | point.group_points(b,n,c,m,nsample,xyz,idx.int(),out) 86 | return out 87 | 88 | def farthest_point_sample_gpu(self, xyz, npoint): 89 | b, n ,c = xyz.shape 90 | centroid = torch.zeros((xyz.shape[0],npoint), dtype=torch.int32).cuda() 91 | temp = torch.zeros((32,n)).cuda() 92 | point.farthestPoint(b,n, npoint, xyz , temp ,centroid) 93 | return centroid.long() 94 | 95 | def ball_query(self, radius, nsample, xyz, new_xyz): 96 | b, n ,c = xyz.shape 97 | m = new_xyz.shape[1] 98 | group_idx = torch.zeros((new_xyz.shape[0],new_xyz.shape[1], nsample), dtype=torch.int32).cuda() 99 | pts_cnt = torch.zeros((xyz.shape[0],xyz.shape[1]), dtype=torch.int32).cuda() 100 | point.ball_query (b, n, m, radius, nsample, xyz, new_xyz, group_idx ,pts_cnt) 101 | 102 | return group_idx.long() 103 | 104 | def idx_pts(self,points,idx): 105 | new_points = torch.cat([points.index_select(1,idx[b]) for b in range(0,idx.shape[0])], dim=0) 106 | return new_points 107 | def sample_and_group(self, npoint, radius, nsample, xyz, points): 108 | """ 109 | Input: 110 | npoint: the number of local regions (centroids) to sample. 
111 | radius: the radius of the local region 112 | nsample: the number of points in a local region 113 | xyz: input points position data, [B, N, C] 114 | points: input points data, [B, N, D] 115 | Return: 116 | new_xyz: sampled points position data, [B, npoint, C] 117 | new_points: sampled points data, [B, npoint, nsample, C+D] 118 | """ 119 | B, N, C = xyz.shape 120 | Np = npoint 121 | assert isinstance(Np, int) 122 | 123 | new_xyz = self.index_points(xyz, self.farthest_point_sample_gpu(xyz, npoint)) # [B,n,3] and [B,np] → [B,np,3] 124 | idx = self.ball_query(radius, nsample, xyz, new_xyz) 125 | grouped_xyz = self.index_points(xyz, idx)# [B,n,3] and [B,np,M] → [B,np,M,3] 126 | grouped_xyz -= new_xyz.view(B, Np, 1, C) # the points of each group are normalized with respect to their centroid 127 | if points is not None: 128 | grouped_points = self.index_points(points, idx)# [B,n,D] and [B,np,M] → [B,np,M,D] 129 | new_points = torch.cat([grouped_xyz, grouped_points], dim=-1) 130 | else: 131 | new_points = grouped_xyz 132 | return new_xyz, new_points 133 | 134 | def sample_and_group_all(self, xyz, points): 135 | """ 136 | Description: 137 | Equivalent to sample_and_group with npoint=1, radius=np.inf, and the centroid is (0, 0, 0) 138 | Input: 139 | xyz: input points position data, [B, N, C] 140 | points: input points data, [B, N, D] 141 | Return: 142 | new_xyz: sampled points position data, [B, 1, C] 143 | new_points: sampled points data, [B, 1, N, C+D] 144 | """ 145 | device = xyz.device 146 | B, N, C = xyz.shape 147 | new_xyz = torch.zeros(B, 1, C).to(device) 148 | grouped_xyz = xyz.view(B, 1, N, C) 149 | if points is not None: 150 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 151 | else: 152 | new_points = grouped_xyz 153 | return new_xyz, new_points 154 | 155 | 156 | class Pointnet_SA_module(PointNet_SA_module_basic): 157 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 158 | 159 | super(Pointnet_SA_module, self).__init__() 160 | self.npoint = npoint 161 | self.radius = radius 162 | self.nsample = nsample 163 | self.group_all = group_all 164 | 165 | self.conv_bns = nn.Sequential() 166 | in_channel += 3 # +3 because the point features are concatenated with the xyz coordinates 167 | for i, out_channel in enumerate(mlp): 168 | m = conv_bn(in_channel, out_channel, 1) 169 | self.conv_bns.add_module(str(i), m) 170 | in_channel = out_channel 171 | 172 | def forward(self, xyz, points): 173 | """ 174 | Input: 175 | xyz: the shape is [B, N, 3] 176 | points: the shape is [B, N, D]; it carries the per-point feature information 177 | Return: 178 | new_xyz: the shape is [B, Np, 3] 179 | new_points: the shape is [B, Np, D'] 180 | """ 181 | 182 | if self.group_all: 183 | new_xyz, new_points = self.sample_and_group_all(xyz, points) 184 | else: 185 | new_xyz, new_points = self.sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 186 | new_points = new_points.permute(0, 3, 1, 2).contiguous() # change size to (B, C, Np, Ns), as expected by the convolutions 187 | # print("1:", new_points.shape) 188 | new_points = self.conv_bns(new_points) 189 | #print("2:", new_points.shape) 190 | new_points = torch.max(new_points, 3)[0] # per channel, take the max over all sampled points of each local region (max pooling) 191 | 192 | new_points = new_points.permute(0, 2, 1).contiguous() 193 | #print(new_points.shape) 194 | return new_xyz, new_points 195 | 196 | class PointSIFT_module_basic(nn.Module): 197 | def __init__(self): 198 | super(PointSIFT_module_basic, self).__init__() 199 | 200 | 201 | def group_points(self,xyz,idx): 202 | b , n , c = xyz.shape 203 | m = idx.shape[1] 204 | 
nsample = idx.shape[2] 205 | out = torch.zeros((xyz.shape[0],xyz.shape[1], 8,c)).cuda() 206 | point.group_points(b,n,c,m,nsample,xyz,idx.int(),out) 207 | return out 208 | 209 | def pointsift_select(self, radius, xyz): 210 | y = torch.zeros((xyz.shape[0],xyz.shape[1], 8), dtype=torch.int32).cuda() 211 | point.select_cube(xyz,y,xyz.shape[0],xyz.shape[1],radius) 212 | return y.long() 213 | 214 | def pointsift_group(self, radius, xyz, points, use_xyz=True): 215 | 216 | B, N, C = xyz.shape 217 | assert C == 3 218 | # start_time = time.time() 219 | idx = self.pointsift_select(radius, xyz) # B, N, 8 220 | # print("select SIR 1 ", time.time() - start_time, xyz.shape) 221 | 222 | # start_time = time.time() 223 | grouped_xyz = self.group_points(xyz, idx) # B, N, 8, 3 224 | # print("group SIR SIR 1 ", time.time() - start_time) 225 | 226 | grouped_xyz -= xyz.view(B, N, 1, 3) 227 | if points is not None: 228 | grouped_points = self.group_points(points, idx) 229 | if use_xyz: 230 | grouped_points = torch.cat([grouped_xyz, grouped_points], dim=-1) 231 | else: 232 | grouped_points = grouped_xyz 233 | return grouped_xyz, grouped_points, idx 234 | 235 | def pointsift_group_with_idx(self, idx, xyz, points, use_xyz=True): 236 | 237 | B, N, C = xyz.shape 238 | grouped_xyz = self.group_points(xyz, idx) # B, N, 8, 3 239 | grouped_xyz -= xyz.view(B, N, 1, 3) 240 | if points is not None: 241 | grouped_points = self.group_points(points, idx) 242 | if use_xyz: 243 | grouped_points = torch.cat([grouped_xyz, grouped_points], dim=-1) 244 | else: 245 | grouped_points = grouped_xyz 246 | return grouped_xyz, grouped_points 247 | 248 | class PointSIFT_res_module(PointSIFT_module_basic): 249 | 250 | def __init__(self, radius, output_channel, extra_input_channel=0, merge='add', same_dim=False): 251 | super(PointSIFT_res_module, self).__init__() 252 | self.radius = radius 253 | self.merge = merge 254 | self.same_dim = same_dim 255 | 256 | self.conv1 = nn.Sequential( 257 | conv_bn(3 + extra_input_channel, output_channel, [1, 2], [1, 2]), 258 | conv_bn(output_channel, output_channel, [1, 2], [1, 2]), 259 | conv_bn(output_channel, output_channel, [1, 2], [1, 2]) 260 | ) 261 | 262 | self.conv2 = nn.Sequential( 263 | conv_bn(3 + output_channel, output_channel, [1, 2], [1, 2]), 264 | conv_bn(output_channel, output_channel, [1, 2], [1, 2]), 265 | conv_bn(output_channel, output_channel, [1, 2], [1, 2], activation=None) 266 | ) 267 | if same_dim: 268 | self.convt = nn.Sequential( 269 | nn.Conv1d(extra_input_channel, output_channel, 1), 270 | nn.BatchNorm1d(output_channel), 271 | nn.ReLU() 272 | ) 273 | 274 | def forward(self, xyz, points): 275 | _, grouped_points, idx = self.pointsift_group(self.radius, xyz, points) # [B, N, 8, 3], [B, N, 8, 3 + C] 276 | 277 | grouped_points = grouped_points.permute(0, 3, 1, 2).contiguous() # B, C, N, 8 278 | ##print(grouped_points.shape) 279 | new_points = self.conv1(grouped_points) 280 | ##print(new_points.shape) 281 | new_points = new_points.squeeze(-1).permute(0, 2, 1).contiguous() 282 | 283 | _, grouped_points = self.pointsift_group_with_idx(idx, xyz, new_points) 284 | grouped_points = grouped_points.permute(0, 3, 1, 2).contiguous() 285 | 286 | ##print(grouped_points.shape) 287 | new_points = self.conv2(grouped_points) 288 | 289 | new_points = new_points.squeeze(-1) 290 | 291 | if points is not None: 292 | points = points.permute(0, 2, 1).contiguous() 293 | # print(points.shape) 294 | if self.same_dim: 295 | points = self.convt(points) 296 | if self.merge == 'add': 297 | new_points = new_points + 
points 298 | elif self.merge == 'concat': 299 | new_points = torch.cat([new_points, points], dim=1) 300 | 301 | new_points = F.relu(new_points) 302 | new_points = new_points.permute(0, 2, 1).contiguous() 303 | 304 | return xyz, new_points 305 | 306 | class PointSIFT_module(PointSIFT_module_basic): 307 | 308 | def __init__(self, radius, output_channel, extra_input_channel=0, merge='add', same_dim=False): 309 | super(PointSIFT_module, self).__init__() 310 | self.radius = radius 311 | self.merge = merge 312 | self.same_dim = same_dim 313 | 314 | self.conv1 = nn.Sequential( 315 | conv_bn(3+extra_input_channel, output_channel, [1, 2], [1, 2]), 316 | conv_bn(output_channel, output_channel, [1, 2], [1, 2]), 317 | conv_bn(output_channel, output_channel, [1, 2], [1, 2]) 318 | ) 319 | 320 | self.conv2 = conv_bn(output_channel, output_channel, [1, 1], [1, 1]) 321 | 322 | 323 | def forward(self, xyz, points): 324 | _, grouped_points, idx = self.pointsift_group(self.radius, xyz, points) # [B, N, 8, 3], [B, N, 8, 3 + C] 325 | 326 | grouped_points = grouped_points.permute(0, 3, 1, 2).contiguous() # B, C, N, 8 327 | ##print(grouped_points.shape) 328 | new_points = self.conv1(grouped_points) 329 | new_points = self.conv2(new_points) 330 | 331 | new_points = new_points.squeeze(-1) 332 | 333 | return xyz, new_points 334 | 335 | class Pointnet_fp_module(nn.Module): 336 | def __init__(self,mlp,dimin): 337 | super(Pointnet_fp_module, self).__init__() 338 | self.tanh = nn.Hardtanh(min_val=0, max_val=1e-10) 339 | self.convs = [] 340 | for i,m in enumerate(mlp): 341 | self.convs.append(conv_bn(dimin[i],m,[1,1],[1,1]).cuda()) 342 | 343 | def forward(self,xyz1, xyz2, points1, points2): 344 | b,n, c = xyz1.shape 345 | m = xyz2.shape[1] 346 | dist = torch.zeros((xyz1.shape[0],xyz1.shape[1], 3)).cpu() 347 | idx = torch.zeros((xyz1.shape[0],xyz1.shape[1], 3), dtype=torch.int32).cpu() 348 | point.interpolate(b,n, m, xyz1.cpu(), xyz2.cpu(), dist,idx) 349 | 350 | 351 | dist = self.tanh(dist) 352 | norm = torch.sum((1.0/ (dist+1e-10 ) ),dim = 2, keepdim=True) 353 | norm = norm.repeat([1,1,3]) 354 | weight = (1.0/dist) / (norm+1e-10) 355 | 356 | interpolated_points = torch.zeros((b,n, points1.shape[2])).cpu() 357 | # print(torch.sum(points2 != points2)) 358 | # print(torch.sum(idx != idx)) 359 | # print(torch.sum(weight != weight)) 360 | # print(torch.sum(interpolated_points != interpolated_points)) 361 | # exit() 362 | 363 | interpolated_points = point.three_interpolate(b, m, c, n,points1.shape[2], points2.cpu(), idx.cpu(), weight.cpu(), interpolated_points) 364 | xyz1 = xyz1.cuda() 365 | xyz2 = xyz2.cuda() 366 | points1 = points1.cuda() 367 | points2 = points2.cuda() 368 | interpolated_points = torch.cat([interpolated_points.cuda(),points1],dim=2) 369 | interpolated_points = interpolated_points.unsqueeze(2).permute(0,3,2,1) 370 | for c in range(0,len(self.convs)): 371 | interpolated_points = self.convs[c](interpolated_points) 372 | 373 | interpolated_points = interpolated_points.squeeze(2) 374 | return interpolated_points 375 | 376 | class PointSIFT(nn.Module): 377 | def __init__(self,nb_classes): 378 | super(PointSIFT, self).__init__() 379 | 380 | self.num_classes = nb_classes 381 | 382 | self.pointsift_res_m3 = PointSIFT_res_module(radius=0.1, output_channel=64, merge='concat')#extra_input_channel=64) 383 | self.pointnet_sa_m3 = Pointnet_SA_module(npoint=1024, radius=0.1, nsample=32, in_channel=64, mlp=[64, 128],group_all=False) 384 | 385 | self.pointsift_res_m4 = PointSIFT_res_module(radius=0.2, output_channel=128, 
extra_input_channel=128) 386 | self.pointnet_sa_m4 = Pointnet_SA_module(npoint=256, radius=0.2, nsample=32, in_channel=128, mlp=[128, 256],group_all=False) 387 | 388 | self.pointsift_res_m5_1 = PointSIFT_res_module(radius=0.2, output_channel=256, extra_input_channel=256) 389 | self.pointsift_res_m5_2 = PointSIFT_res_module(radius=0.2, output_channel=512, extra_input_channel=256,same_dim=True) 390 | 391 | self.conv1 = conv1d_bn(768, 512, 1, stride=1, activation='none') 392 | 393 | self.pointnet_sa_m6 = Pointnet_SA_module(npoint=64, radius=0.2, nsample=32, in_channel=512, mlp=[512,512],group_all=False) 394 | self.pointnet_fp_m0 = Pointnet_fp_module([512,512],[512,512]) 395 | 396 | self.pointsift_m0 = PointSIFT_module(radius=0.5, output_channel=512,extra_input_channel=512) 397 | 398 | self.pointsift_m1 = PointSIFT_module(radius=0.5, output_channel=512,extra_input_channel=512) 399 | 400 | self.pointsift_m2 = PointSIFT_module(radius=0.5, output_channel=512,extra_input_channel=512) 401 | 402 | self.conv2 = conv1d_bn(512, 512, 1, stride=1, activation='none') 403 | 404 | self.pointnet_fp_m1 = Pointnet_fp_module([256,256],[256,256]) 405 | 406 | self.pointsift_m3 = PointSIFT_module(radius=0.25, output_channel=256,extra_input_channel=256) 407 | 408 | self.pointsift_m4 = PointSIFT_module(radius=0.25, output_channel=256,extra_input_channel=256) 409 | 410 | self.conv3 = conv1d_bn(256, 256, 1, stride=1, activation='none') 411 | 412 | self.pointnet_fp_m2 = Pointnet_fp_module([128,128,128],[128,128,128]) 413 | 414 | self.pointsift_m5 = PointSIFT_module(radius=0.1, output_channel=128,extra_input_channel=128) 415 | 416 | ### fc 417 | 418 | self.conv_fc = conv1d_bn(128, 128, 1, stride=1, activation='none') 419 | 420 | self.drop_fc = nn.Dropout(p=0.5) 421 | 422 | self.conv2_fc = conv1d_bn(128, 2, 1, stride=1, activation='none') 423 | 424 | 425 | 426 | 427 | def forward(self, xyz, points=None): 428 | """ 429 | Input: 430 | xyz: is the raw point cloud(B * N * 3) 431 | Return: 432 | """ 433 | B = xyz.size()[0] 434 | 435 | l3_xyz, l3_points = self.pointsift_res_m3(xyz, points) 436 | # print(l3_xyz.shape, l3_points.shape) 437 | c3_xyz, c3_points = self.pointnet_sa_m3(l3_xyz, l3_points) 438 | 439 | 440 | l4_xyz, l4_points = self.pointsift_res_m4(c3_xyz, c3_points) 441 | c4_xyz, c4_points = self.pointnet_sa_m4(l4_xyz, l4_points) 442 | 443 | 444 | l5_xyz, l5_points = self.pointsift_res_m5_1(c4_xyz, c4_points) 445 | l5_2_xyz, l5_2_points = self.pointsift_res_m5_2(l5_xyz, l5_points) 446 | 447 | 448 | l2_cat_points = torch.cat([l5_points, l5_2_points], dim=2) 449 | 450 | fc_l2_points = self.conv1(l2_cat_points.permute(0,2,1)).permute(0,2,1) 451 | 452 | 453 | l3b_xyz, l3b_points = self.pointnet_sa_m6(l5_2_xyz,fc_l2_points) 454 | 455 | l2_points = self.pointnet_fp_m0(c4_xyz,l3b_xyz, c4_points,l3_points ).permute(0,2,1) 456 | # print(torch.sum(l2_points != l2_points)) 457 | # exit() 458 | 459 | _, l2_points_1 = self.pointsift_m0(c4_xyz,l2_points) 460 | _, l2_points_2 = self.pointsift_m1(c4_xyz,l2_points) 461 | _, l2_points_3 = self.pointsift_m2(c4_xyz,l2_points) 462 | 463 | 464 | 465 | l2_points = torch.cat([l2_points_1,l2_points_2,l2_points_3],dim=-1) 466 | l2_points = self.conv2(l2_points) 467 | 468 | 469 | 470 | 471 | l1_points = self.pointnet_fp_m1(c3_xyz,c4_xyz, c3_points,l2_points ).permute(0,2,1) 472 | print(torch.sum(l1_points != l1_points)) 473 | 474 | _, l1_points_1 = self.pointsift_m3(c3_xyz,l1_points) 475 | _, l1_points_2 = self.pointsift_m4(c3_xyz,l1_points) 476 | 477 | 478 | l1_points = 
torch.cat([l1_points_1,l1_points_2], dim =-1) 479 | 480 | 481 | l0_points = self.conv3(l1_points) 482 | 483 | 484 | l0_points = self.pointnet_fp_m2(l3_xyz,c3_xyz, l3_points,l0_points ).permute(0,2,1) 485 | 486 | _, l0_points_1 = self.pointsift_m5(l3_xyz,l0_points) 487 | 488 | 489 | net = self.conv_fc(l0_points_1) 490 | net = self.drop_fc(net) 491 | net = self.conv2_fc(net) 492 | print(torch.sum(net != net)) 493 | return net 494 | 495 | 496 | 497 | 498 | @staticmethod 499 | def get_loss(input, target): 500 | classify_loss = nn.CrossEntropyLoss() 501 | loss = classify_loss(input, target) 502 | return loss 503 | 504 | def initialize_weights(self): 505 | for m in self.modules(): 506 | if isinstance(m, nn.Conv2d): 507 | n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels 508 | m.weight.data.normal_(0, math.sqrt(2. / n)) 509 | if m.bias is not None: 510 | m.bias.data.zero_() 511 | elif isinstance(m, nn.BatchNorm2d): 512 | m.weight.data.fill_(1) 513 | m.bias.data.zero_() 514 | elif isinstance(m, nn.Linear): 515 | # n = m.weight.size(1) 516 | m.weight.data.normal_(0, 0.01) 517 | m.bias.data.zero_() 518 | 519 | 520 | if __name__ == "__main__": 521 | for i in range(100): 522 | xyz = torch.rand(16, 2048,3).cuda() 523 | net = PointSIFT(1) 524 | net.cuda() 525 | x = net(xyz) 526 | print(torch.sum(x != x)) 527 | -------------------------------------------------------------------------------- /Pointnet2/Pointnet2.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0,'..') 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from time import time 9 | from Utils.net_utils import * 10 | 11 | 12 | class PointNet2SemSeg(nn.Module): 13 | def __init__(self, num_classes): 14 | super(PointNet2SemSeg, self).__init__() 15 | self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 3, [32, 32, 64], False) 16 | self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False) 17 | self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False) 18 | self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False) 19 | self.fp4 = PointNetFeaturePropagation(768, [256, 256]) 20 | self.fp3 = PointNetFeaturePropagation(384, [256, 256]) 21 | self.fp2 = PointNetFeaturePropagation(320, [256, 128]) 22 | self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128]) 23 | self.conv1 = nn.Conv1d(128, 128, 1) 24 | self.bn1 = nn.BatchNorm1d(128) 25 | self.drop1 = nn.Dropout(0.5) 26 | self.conv2 = nn.Conv1d(128, num_classes, 1) 27 | 28 | def forward(self, xyz): 29 | xyz = xyz.permute(0, 2, 1) 30 | l1_xyz, l1_points = self.sa1(xyz, None) 31 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 32 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 33 | l4_xyz, l4_points = self.sa4(l3_xyz, l3_points) 34 | 35 | l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points) 36 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) 37 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 38 | l0_points = self.fp1(xyz, l1_xyz, None, l1_points) 39 | 40 | x = self.drop1(F.relu(self.bn1(self.conv1(l0_points)))) 41 | x = self.conv2(x) 42 | # x = F.log_softmax(x, dim=1) 43 | return x 44 | 45 | if __name__ == '__main__': 46 | for i in range(10): 47 | xyz = torch.rand(16, 2048,3).cuda() 48 | net = PointNet2SemSeg(1) 49 | net.cuda() 50 | x = net(xyz) 51 | -------------------------------------------------------------------------------- 
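The PointNet2SemSeg model above relies on the farthest-point-sampling CUDA kernel of the compiled `point` extension to pick the centroids of each set-abstraction level. As a reference for what that kernel computes, here is a minimal pure-PyTorch sketch of iterative farthest point sampling; this is an illustrative re-implementation, not the kernel shipped in cppattempt, and it is much slower on large clouds:

```python
import torch

def farthest_point_sample(xyz: torch.Tensor, npoint: int) -> torch.Tensor:
    """Greedy farthest point sampling. xyz: (B, N, 3) -> centroid indices (B, npoint)."""
    B, N, _ = xyz.shape
    centroids = torch.zeros(B, npoint, dtype=torch.long, device=xyz.device)
    # distance[b, i] = squared distance from point i to its closest centroid chosen so far
    distance = torch.full((B, N), 1e10, device=xyz.device)
    farthest = torch.zeros(B, dtype=torch.long, device=xyz.device)  # seed with point 0
    batch = torch.arange(B, device=xyz.device)
    for i in range(npoint):
        centroids[:, i] = farthest
        centroid = xyz[batch, farthest, :].unsqueeze(1)      # (B, 1, 3)
        dist = torch.sum((xyz - centroid) ** 2, dim=-1)      # (B, N)
        distance = torch.min(distance, dist)                 # update nearest-centroid distances
        farthest = torch.argmax(distance, dim=-1)            # next centroid: the farthest point
    return centroids

# usage: idx = farthest_point_sample(torch.rand(2, 2048, 3), 512), then gather with index_points
```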
/Pointnet2/Pointnet2_msg.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0,'..') 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from time import time 9 | import point 10 | from Utils.net_utils import * 11 | 12 | 13 | class PointNet2SemSeg(nn.Module): 14 | def __init__(self, num_classes): 15 | super(PointNet2SemSeg, self).__init__() 16 | self.sa05 = PointNetSetAbstractionMsg(2048, [0.1, 0.2, 0.4], [32, 64, 128], 6,[[32, 32, 64], [64, 64, 64], [64, 96, 128]], False) 17 | self.sa1 = PointNetSetAbstraction(1024, 0.2, 32, 128+64+64+3, [32, 32, 64], False)# npoint, radius, nsample, in_channel, mlp, group_all 18 | self.sa2 = PointNetSetAbstraction(256, 0.4, 32, 64 + 3, [64, 64, 128], False) 19 | 20 | self.fp2 = PointNetFeaturePropagation(192, [512, 256])#in_channel, mlp 21 | self.fp1 = PointNetFeaturePropagation(512, [256, 128])#in_channel, mlp 22 | self.fp05 = PointNetFeaturePropagation(131, [128, 64]) 23 | self.conv1 = nn.Conv1d(64, 128, 1) 24 | self.bn1 = nn.BatchNorm1d(128) 25 | self.drop1 = nn.Dropout(0.5) 26 | self.conv2 = nn.Conv1d(128,num_classes, 1) 27 | 28 | 29 | 30 | 31 | def forward(self, xyz,color): 32 | xyz = xyz.permute(0, 2, 1) 33 | color = color.permute(0, 2, 1) 34 | l05_xyz, l05_points = self.sa05(xyz, color) 35 | l1_xyz, l1_points = self.sa1(l05_xyz, l05_points) 36 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 37 | 38 | 39 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 40 | l05_points = self.fp1(l05_xyz, l1_xyz, l05_points, l1_points) 41 | l0_points = self.fp05(xyz, l05_xyz, color, l05_points) 42 | 43 | x = self.drop1(F.relu(self.bn1(self.conv1(l0_points)))) 44 | x = self.conv2(x) 45 | 46 | return torch.sigmoid(x)## sigmoid output, since this is a binary classification 47 | 48 | def weights_init(m): 49 | if isinstance(m, nn.Conv2d): 50 | torch.nn.init.xavier_uniform_(m.weight.data) 51 | 52 | if __name__ == '__main__': 53 | for i in range(10): 54 | xyz = torch.rand(1, 30000,3).cuda() 55 | colors = torch.rand(1, 30000,3).cuda() 56 | net = PointNet2SemSeg(2) 57 | net.cuda() 58 | x = net(xyz,colors) 59 | print(x.shape) 60 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | 2 | 3 | # 3DNetworksPytorch 4 | 5 | 6 | ```diff 7 | - Looking for new papers to implement in pytorch! Comment in the dedicated issue with papers you would like to see implemented using pytorch! 8 | ``` 9 | 10 | ***This repository mostly contains implementations of papers using the pytorch framework. PLEASE cite the corresponding papers when referencing this work. User discretion is advised concerning the accuracy and readiness of my implementations; please create issues when you encounter problems and I will try my best to fix them.*** 11 | 12 | 13 | This repository is meant as a way to learn different 3D deep learning architectures for point clouds by implementing them. I haven't tested them on the papers' benchmark datasets, only on some toy examples. If you spot any mistake, I am open to pull requests and any collaboration on the topic. 14 | 15 | (I haven't cleaned the code completely, so it might seem a bit messy at first sight.) 16 | Most of the networks use the CUDA code in cppattempt. Please go in there and install the extension (python setup.py install) so that they can import it. 
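For reference, that install step boils down to something like the following (a sketch assuming a CUDA toolkit matching your PyTorch build; `example.py` is the small test script that ships in the same folder):

```
cd cppattempt
python setup.py install
python example.py   # quick check that the compiled extension imports and runs
```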
17 | The only requirements should be pytorch 1.0+ and the corresponding cudatoolkit, with everything configured correctly. See the pytorch documentation for how to compile C++ extensions. 18 | 19 | # Table of Contents 20 | 1. [PointSift](#PointSift) 21 | 2. [PointCNN](#PointCNN) 22 | 3. [PointNet++](#PointNet++) 23 | 4. [Cuda_Extension](#Cuda_Extension) 24 | 5. [3D-BoNet](#3D-BoNet) 25 | 6. [SPGN](#SPGN) 26 | 7. [PCN](#PCN) 27 | 8. [3D_completion_challenge](#3D_completion_challenge) 28 | 9. [FastFCN](#FastFCN) 29 | 10. [CSRNet](#CSRNet) 30 | 31 | 32 | 33 | 34 | ## PointSift 35 | An implementation of PointSIFT (https://arxiv.org/pdf/1807.00652.pdf) in Pytorch lies in the PointSift folder. 36 | The C_utils folder contains some algorithms implemented in CUDA and C++, taken from the original implementation of PointSIFT (https://github.com/MVIG-SJTU/pointSIFT) but wrapped to be used with Pytorch tensors directly. 37 | 38 | ## PointCNN 39 | An implementation of PointCNN (https://arxiv.org/pdf/1801.07791.pdf) in Pytorch lies in the PointCNN folder. 40 | 41 | ## PointNet++ 42 | An implementation of PointNet++ (https://arxiv.org/pdf/1706.02413.pdf) in Pytorch lies in the Pointnet2 folder. 43 | It uses the same GPU algorithms as PointSIFT, since PointSIFT builds on PointNet++ modules. 44 | 45 | 46 | ## Cuda_Extension 47 | There are two versions of the CUDA extensions for pointnet and pointsift. The first one is in C_utils and was implemented using the old C api for torch. As that api is deprecated in newer versions of pytorch, which recommend the C++ extension api instead, I made an attempt at the latter in the cppattempt folder. 48 | 49 | ## 3D-BoNet 50 | 51 | Quick implementation of 3D-BoNet (https://arxiv.org/pdf/1906.01140.pdf) in pytorch: https://gist.github.com/lelouedec/5a7ba5547df5cef71b50ab306199623f. It is all in one file; you need to compile the C++ pointnet extension. The code does not converge for the bounding-box regression. 52 | 53 | ## SPGN 54 | 55 | Implementation of SGPN (https://arxiv.org/pdf/1711.08588.pdf) based on the Pointnet implementation. 56 | 57 | 58 | ## PCN 59 | Implementation of PCN (PCN: Point Completion Network, https://arxiv.org/pdf/1808.00671.pdf, https://github.com/wentaoyuan/pcn) using pytorch. For the chamfer distance and the EMD loss, I used the implementations from https://github.com/chrdiller/pyTorchChamferDistance and https://github.com/daerduoCarey/PyTorchEMD respectively. See these repositories for how to use them. Copy emd.py and the compiled ".so" lib into the same directory as your model and it should be fine. 60 | Tested with the PCN paper's shapenet data; download it from the google drive link provided in their repository. The dataloader will help load the pointclouds from the shapenet directory. See the following screenshot for an example (left is the ground truth, middle the input and right the output of the network): 61 | ![Example for pcn](./PCN/example.png) 62 | 63 | 64 | ## 3D_completion_challenge 65 | A new 3D completion challenge is available here: https://github.com/lynetcha/completion3d ; it includes PCN (seen above). 66 | 67 | ## FastFCN 68 | A two-file implementation of the FastFCN paper, based on the authors' own implementation. See the paper and repository for more details: https://arxiv.org/pdf/1903.11816.pdf https://github.com/wuhuikai/FastFCN 69 | 70 | 71 | ## CSRNet 72 | Implementation of the model used in the paper https://arxiv.org/pdf/1802.10062.pdf. It is only one of the variants, but it is the one used and advertised by the authors. 
The output as in the paper needs to be upsampled to compare to the original image. 73 | 74 | 75 | As asked: 76 | ``` 77 | @software{lelouedec_2020_3766070, 78 | author = {lelouedec}, 79 | title = {lelouedec/3DNetworksPytorch: pre-alpha}, 80 | month = apr, 81 | year = 2020, 82 | version = {0.1}, 83 | } 84 | ``` 85 | -------------------------------------------------------------------------------- /SPGN/SGPN.py: -------------------------------------------------------------------------------- 1 | import sys 2 | sys.path.insert(0,'..') 3 | import torch 4 | import torch.nn as nn 5 | import torch.nn.init as init 6 | import torch.nn.functional as F 7 | import numpy as np 8 | from time import time 9 | import point 10 | from Utils.net_utils import * 11 | from SGPN_utils import * 12 | 13 | 14 | class PointNet2SemSeg(nn.Module): 15 | def __init__(self, num_classes): 16 | super(PointNet2SemSeg, self).__init__() 17 | self.sa0 = PointNetSetAbstraction(4096, 0.1, 32, 6, [16, 16, 32], False) 18 | self.sa05 = PointNetSetAbstraction(2048, 0.1, 32, 32+3, [32, 32, 32], False) 19 | self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, 32+3, [32, 32, 64], False)# npoint, radius, nsample, in_channel, mlp, group_all 20 | self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False) 21 | self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False) 22 | self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False) 23 | 24 | self.fp4 = PointNetFeaturePropagation(768, [256, 256])#in_channel, mlp 25 | self.fp3 = PointNetFeaturePropagation(384, [256, 256])#in_channel, mlp 26 | self.fp2 = PointNetFeaturePropagation(320, [256, 128])#in_channel, mlp 27 | self.fp1 = PointNetFeaturePropagation(160, [128, 128, 128])#in_channel, mlp 28 | self.fp05 = PointNetFeaturePropagation(160, [128, 128, 64]) 29 | self.fp0 = PointNetFeaturePropagation(67, [128, 128, 64]) 30 | self.conv1 = nn.Conv1d(64, 128, 1) 31 | self.bn1 = nn.BatchNorm1d(128) 32 | self.drop1 = nn.Dropout(0.5) 33 | self.conv2 = nn.Conv1d(128, 1, 1) 34 | 35 | 36 | ## similarity 37 | self.conv2_1 = nn.Conv2d(64,128,kernel_size=(1,1),stride=(1,1)) 38 | 39 | ## confidence map 40 | self.conv3_1 = nn.Conv2d(64,128,kernel_size=(1,1),stride=(1,1)) 41 | self.conv3_2 = nn.Conv2d(128,1,kernel_size=(1,1),stride=(1,1)) 42 | 43 | self.criterion_semseg = nn.BCELoss().cuda() 44 | self.criterion2 = nn.MSELoss(reduction='mean') 45 | 46 | 47 | def forward(self, xyz,color,target,training,epoch,just_seg): 48 | xyz = xyz.permute(0, 2, 1) 49 | color = color.permute(0, 2, 1) 50 | l0_xyz, l0_points = self.sa0(xyz, color) 51 | l05_xyz, l05_points = self.sa05(l0_xyz, l0_points) 52 | l1_xyz, l1_points = self.sa1(l05_xyz, l05_points) 53 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 54 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 55 | l4_xyz, l4_points = self.sa4(l3_xyz, l3_points) 56 | 57 | l3_points = self.fp4(l3_xyz, l4_xyz, l3_points, l4_points) 58 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) 59 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 60 | l05_points = self.fp1(l05_xyz, l1_xyz, l05_points, l1_points) 61 | l0_points = self.fp05(l0_xyz, l05_xyz, l0_points, l05_points) 62 | l0_points = self.fp0(xyz, l0_xyz, color, l0_points) 63 | 64 | x = self.drop1(F.relu(self.bn1(self.conv1(l0_points)))) 65 | semseg = self.conv2(x) 66 | semseg_logit = torch.sigmoid(semseg) 67 | 68 | 69 | ##similarity 70 | Fsim = self.conv2_1(l0_points.unsqueeze(2)).squeeze(2) 71 | r = torch.sum(Fsim*Fsim,dim=1) 72 | r = 
r.view((l0_points.shape[0],-1,1)).permute(0,2,1) 73 | trans = torch.transpose(Fsim ,2, 1) 74 | mul = 2 * torch.matmul(trans, Fsim) 75 | sub = r - mul 76 | D = sub + torch.transpose(r, 2, 1) 77 | D[D<=0.0] = 0.0 78 | 79 | ##Confidence 80 | conf_logit = self.conv3_2(self.conv3_1(l0_points.unsqueeze(2))).squeeze(2) 81 | conf = torch.sigmoid(conf_logit) 82 | 83 | # print(semseg.shape,semseg_logit.shape,D.shape,conf.shape,conf_logit.shape) 84 | ## COmputing loss 85 | if(just_seg): 86 | return [0.0,self.criterion_semseg(semseg_logit,target['semseg']),0.0,0.0],semseg_logit 87 | if(training): 88 | return self.compute_loss(semseg ,semseg_logit ,D ,conf ,conf_logit,target,epoch),semseg_logit 89 | else: 90 | pts_semseg_label,pts_semseg = self.convert_seg_to_one_hot(target['semseg']) 91 | pts_group_label, group_mask = self.convert_groupandcate_to_one_hot(target['ptsgroup']) 92 | group_mat_label = torch.matmul(pts_group_label,torch.transpose(pts_group_label,2,1)) 93 | pts_corr_val = D[0].squeeze() 94 | pred_confidence_val = conf[0].squeeze() 95 | ptsclassification_val = semseg_logit[0].squeeze() 96 | NUM_CATEGORY = 2 97 | ths = np.zeros(NUM_CATEGORY) 98 | ths_ = np.zeros(NUM_CATEGORY) 99 | cnt = np.zeros(NUM_CATEGORY) 100 | # ths,ths_,cnt = Get_Ths(pts_corr_val, target['semseg'].cpu().numpy()[0], target['ptsgroup'].cpu().numpy()[0], ths, ths_, cnt) 101 | # ths = [ths[i]/cnt[i] if cnt[i] != 0 else 0.2 for i in range(len(cnt))] 102 | groupids_block = Create_groups(pts_corr_val, pred_confidence_val, ptsclassification_val,xyz[0].permute(1,0)) 103 | # groupids_block, refineseg, group_seg = GroupMerging(pts_corr_val, pred_confidence_val, ptsclassification_val,[0.1,0.1]) 104 | # groupids_block[groupids_block==-1] = 0.0 105 | return 0.0,groupids_block,ptsclassification_val 106 | 107 | 108 | def compute_loss(self,semseg ,semseg_logit ,D ,conf ,conf_logit, gt_truth,epoch): 109 | #gt_truth = {"ptsgroup":mask_target,"semseg":target,"bounding_boxes":bb_target} 110 | 111 | pts_semseg_label,pts_semseg = self.convert_seg_to_one_hot(gt_truth['semseg']) 112 | pts_group_label, group_mask = self.convert_groupandcate_to_one_hot(gt_truth['ptsgroup']) 113 | 114 | alpha=2.0 115 | margin=[1.,2.] 
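# Annotation: the margins above parametrize SGPN's double-hinge similarity loss computed below.
# For a point pair (i, j) with feature distance D[i, j]: same-instance pairs minimize D directly
# (the "pos" term), different-instance but same-class pairs are pushed at least margin[0] apart
# (weighted by alpha, which is halved every 5 epochs), and different-class pairs are pushed at
# least margin[1] apart.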
116 | if(epoch%5==0 and epoch!=0): 117 | alpha = alpha/2 118 | 119 | 120 | ## Similarity loss 121 | B = pts_group_label.shape[0] 122 | N = pts_group_label.shape[1] 123 | 124 | group_mat_label = torch.matmul(pts_group_label,torch.transpose(pts_group_label,1,2)) 125 | diag_idx = torch.arange(0,group_mat_label.shape[1], out=torch.LongTensor()) 126 | group_mat_label[:,diag_idx,diag_idx] = 1.0 127 | 128 | sem_mat_label = torch.matmul(pts_semseg_label,torch.transpose(pts_semseg_label,1,2)) 129 | sem_mat_label[:,diag_idx,diag_idx] = 1.0 130 | 131 | samesem_mat_label = sem_mat_label 132 | diffsem_mat_label = 1.0 - sem_mat_label 133 | 134 | samegroup_mat_label = group_mat_label 135 | diffgroup_mat_label = 1.0 - group_mat_label 136 | diffgroup_samesem_mat_label = diffgroup_mat_label * samesem_mat_label 137 | diffgroup_diffsem_mat_label = diffgroup_mat_label * diffsem_mat_label 138 | 139 | num_samegroup = torch.sum(samegroup_mat_label) 140 | num_diffgroup_samesem = torch.sum(diffgroup_samesem_mat_label) 141 | num_diffgroup_diffsem = torch.sum(diffgroup_diffsem_mat_label) 142 | 143 | pos = samegroup_mat_label * D 144 | sub = margin[0] - D 145 | sub[sub<=0.0] = 0.0 146 | 147 | sub2 = margin[1] - D 148 | sub2[sub2<=0.0] = 0.0 149 | 150 | neg_samesem = alpha * (diffgroup_samesem_mat_label * sub) 151 | neg_diffsem = diffgroup_diffsem_mat_label * sub2 152 | 153 | simmat_loss = neg_samesem + neg_diffsem + pos 154 | 155 | group_mask_weight = torch.matmul(group_mask.unsqueeze(2), torch.transpose(group_mask.unsqueeze(2), 2, 1)) 156 | simmat_loss = simmat_loss * group_mask_weight 157 | simmat_loss = torch.mean(simmat_loss) 158 | 159 | ## Confidence map loss 160 | Pr_obj = torch.sum(pts_semseg_label,dim=2).float().cuda() 161 | ng_label = group_mat_label 162 | ng_label = torch.gt(ng_label,0.5) 163 | ng = torch.lt(D,margin[0]) 164 | epsilon = torch.ones(ng_label.shape[:2]).float().cuda() * 1e-6 165 | 166 | up = torch.sum((ng & ng_label).float()) 167 | down = torch.sum((ng | ng_label).float()) + epsilon 168 | pts_iou = torch.div(up,down) 169 | confidence_label = pts_iou * Pr_obj 170 | 171 | confidence_loss = self.criterion2(confidence_label.unsqueeze(1),conf_logit.squeeze(2))##MSE 172 | 173 | 174 | ##semseg loss 175 | sem_seg_loss = self.criterion_semseg(semseg_logit,gt_truth['semseg']) 176 | 177 | 178 | 179 | loss = simmat_loss + sem_seg_loss + confidence_loss 180 | 181 | grouperr = torch.abs(ng.float()-ng_label.float()) 182 | 183 | return (loss,simmat_loss,sem_seg_loss,confidence_loss)#, grouperr.mean(), torch.sum(grouperr+diffgroup_samesem_mat_label),num_diffgroup_samesem \ 184 | #torch.sum(grouperr * diffgroup_diffsem_mat_label), num_diffgroup_diffsem, \ 185 | #torch.sum(grouperr * samegroup_mat_label), num_samegroup 186 | 187 | 188 | 189 | def convert_seg_to_one_hot(self,labels): 190 | # labels:BxN 191 | labels = labels.permute(0,2,1) 192 | NUM_CATEGORY = 2 193 | label_one_hot = torch.zeros((labels.shape[0], labels.shape[1], NUM_CATEGORY)).cuda() 194 | pts_label_mask = torch.zeros((labels.shape[0], labels.shape[1])).cuda() 195 | 196 | un, cnt = torch.unique(labels, return_counts=True) 197 | label_count_dictionary = {} 198 | for v,u in enumerate(un): 199 | label_count_dictionary[int(u.item())] = cnt[v].item() 200 | 201 | totalnum = 0 202 | for k_un, v_cnt in label_count_dictionary.items(): 203 | if k_un != -1: 204 | totalnum += v_cnt 205 | 206 | for idx in range(labels.shape[0]): 207 | for jdx in range(labels.shape[1]): 208 | if labels[idx, jdx] != -1: 209 | label_one_hot[idx, jdx, int(labels[idx, jdx])] = 1 
210 | pts_label_mask[idx, jdx] = float(totalnum) / float(label_count_dictionary[int(labels[idx, jdx])]) # 1. - float(label_count_dictionary[labels[idx, jdx]]) / totalnum 211 | return label_one_hot, pts_label_mask 212 | 213 | def convert_groupandcate_to_one_hot(self,grouplabels): 214 | # grouplabels: BxN 215 | NUM_GROUPS = 50 216 | group_one_hot = torch.zeros((grouplabels.shape[0], grouplabels.shape[1], NUM_GROUPS)).cuda() 217 | pts_group_mask = torch.zeros((grouplabels.shape[0], grouplabels.shape[1])).cuda() 218 | 219 | un, cnt = torch.unique(grouplabels, return_counts=True) 220 | group_count_dictionary = {} 221 | for v,u in enumerate(un): 222 | group_count_dictionary[int(u.item())] = cnt[v].item() 223 | totalnum = 0 224 | for k_un, v_cnt in group_count_dictionary.items(): 225 | if k_un != -1: 226 | totalnum += v_cnt 227 | 228 | for idx in range(grouplabels.shape[0]): 229 | for jdx in range(grouplabels.shape[1]): 230 | if grouplabels[idx, jdx] != -1: 231 | group_one_hot[idx, jdx, int(grouplabels[idx, jdx])] = 1 232 | pts_group_mask[idx, jdx] = float(totalnum) / float(group_count_dictionary[int(grouplabels[idx, jdx])]) # 1. - float(group_count_dictionary[grouplabels[idx, jdx]]) / totalnum 233 | 234 | return group_one_hot.float(), pts_group_mask 235 | 236 | def freeze_pn(self): 237 | for params in self.sa0.parameters(): 238 | params.requires_grad = False 239 | for params in self.sa1.parameters(): 240 | params.requires_grad = False 241 | for params in self.sa2.parameters(): 242 | params.requires_grad = False 243 | for params in self.sa3.parameters(): 244 | params.requires_grad = False 245 | for params in self.sa4.parameters(): 246 | params.requires_grad = False 247 | 248 | for params in self.fp4.parameters(): 249 | params.requires_grad = False 250 | for params in self.fp3.parameters(): 251 | params.requires_grad = False 252 | for params in self.fp2.parameters(): 253 | params.requires_grad = False 254 | for params in self.fp1.parameters(): 255 | params.requires_grad = False 256 | for params in self.fp05.parameters(): 257 | params.requires_grad = False 258 | for params in self.fp0.parameters(): 259 | params.requires_grad = False 260 | 261 | self.conv1.requires_grad = False 262 | self.conv2.requires_grad = False 263 | 264 | def weights_init(m): 265 | if isinstance(m, nn.Conv2d): 266 | torch.nn.init.xavier_uniform_(m.weight.data) 267 | 268 | if __name__ == '__main__': 269 | for i in range(10): 270 | xyz = torch.rand(1, 4096,3).cuda() # small cloud: the NxN similarity matrix is quadratic in memory 271 | colors = torch.rand(1, 4096,3).cuda() 272 | net = PointNet2SemSeg(2) 273 | net.cuda() 274 | target = {"semseg": torch.rand(1, 1, 4096).cuda()} # dummy labels, exercise the just_seg path 275 | losses, x = net(xyz, colors, target, False, 0, True) 276 | print(x.shape) 277 | -------------------------------------------------------------------------------- /SPGN/SGPN_utils.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | from scipy import stats 3 | #ths,ths_,cnt = Get_Ths(pts_corr_val, target['semseg'].cpu().numpy()[0], target['ptsgroup'].cpu().numpy()[0], ths, ths_, cnt) 4 | def Get_Ths(pts_corr, seg, ins, ths, ths_, cnt): 5 | seg = np.transpose(seg,(1,0)) 6 | pts_corr = pts_corr.detach().cpu().numpy() 7 | pts_in_ins = {} 8 | for ip, pt in enumerate(pts_corr): 9 | if ins[ip] in pts_in_ins.keys(): 10 | pts_in_curins_ind = pts_in_ins[ins[ip]] 11 | pts_notin_curins_ind = (~(pts_in_ins[ins[ip]])) & (seg==seg[ip]).squeeze() 12 | hist, bin = np.histogram(pt[pts_in_curins_ind], bins=20) 13 | 14 | if seg[ip]==8: 15 | print ("pouet",bin) 16 | 17 | numpt_in_curins = np.sum(pts_in_curins_ind) 18 | numpt_notin_curins = np.sum(pts_notin_curins_ind) 19 | 
20 | if numpt_notin_curins > 0: 21 | 22 | tp_over_fp = 0 23 | ib_opt = -2 24 | for ib, b in enumerate(bin): 25 | if b == 0: 26 | break 27 | tp = float(np.sum(pt[pts_in_curins_ind] < bin[ib])) / float(numpt_in_curins) 28 | fp = float(np.sum(pt[pts_notin_curins_ind] < bin[ib])) / float(numpt_notin_curins) 29 | 30 | if tp <= 0.5: 31 | continue 32 | 33 | if fp == 0. and tp > 0.5: 34 | ib_opt = ib 35 | break 36 | 37 | if tp/fp > tp_over_fp: 38 | tp_over_fp = tp / fp 39 | ib_opt = ib 40 | 41 | if tp_over_fp > 4.: 42 | ths[int(seg[ip])] += bin[ib_opt] 43 | ths_[int(seg[ip])] += bin[ib_opt] 44 | cnt[int(seg[ip])] += 1 45 | 46 | else: 47 | pts_in_curins_ind = (ins == ins[ip]) 48 | pts_in_ins[ins[ip]] = pts_in_curins_ind 49 | pts_notin_curins_ind = (~(pts_in_ins[ins[ip]])) & (seg==seg[ip]).squeeze() 50 | hist, bin = np.histogram(pt[pts_in_curins_ind], bins=20) 51 | if seg[ip]==8: 52 | print ("pouet",bin) 53 | numpt_in_curins = np.sum(pts_in_curins_ind) 54 | numpt_notin_curins = np.sum(pts_notin_curins_ind) 55 | if numpt_notin_curins > 0: 56 | tp_over_fp = 0 57 | ib_opt = -2 58 | for ib, b in enumerate(bin): 59 | 60 | if b == 0: 61 | break 62 | 63 | tp = float(np.sum(pt[pts_in_curins_ind] < bin[ib])) / float(numpt_in_curins) 64 | fp = float(np.sum(pt[pts_notin_curins_ind] < bin[ib])) / float(numpt_notin_curins) 65 | 66 | if tp <= 0.5: 67 | continue 68 | 69 | if fp == 0. and tp > 0.5: 70 | ib_opt = ib 71 | break 72 | 73 | if tp / fp > tp_over_fp: 74 | tp_over_fp = tp / fp 75 | ib_opt = ib 76 | 77 | if tp_over_fp > 4.: 78 | ths[int(seg[ip])] += bin[ib_opt] 79 | ths_[int(seg[ip])] += bin[ib_opt] 80 | cnt[int(seg[ip])] += 1 81 | 82 | return ths, ths_, cnt 83 | 84 | def square_distance(src, dst): 85 | N = src.shape[0] 86 | M = dst.shape[0] 87 | dist = -2 * np.matmul(src, np.transpose(dst,[1,0])) 88 | dist += np.sum(src ** 2, axis=-1) 89 | dist += np.sum(dst ** 2, axis=-1) 90 | return dist 91 | 92 | def Create_groups(pts_corr,confidence,seg,pts): 93 | seg = seg.detach().cpu().numpy() 94 | pts = pts.detach().cpu().numpy() 95 | 96 | pts_corr = pts_corr.detach().cpu().numpy() 97 | confidence = confidence.detach().cpu().numpy() 98 | seg[seg>0.5] = 1 99 | seg[seg<0.5] = 0 100 | confvalidpts = (confidence>0.4) 101 | groupid = np.zeros(seg.shape[0]) 102 | groups = {} 103 | grp_id = 1 104 | pts_in_seg = (seg==1)## points in the segmentation mask with this class 105 | valid_seg_group = np.where(pts_in_seg & confvalidpts) ## points with this class and a confidence > 0.4 106 | for p in valid_seg_group[0]: 107 | if(groupid[p]==0):## if the point doesn't have a group already 108 | groupid[p] = grp_id 109 | valid_grp = np.where( (pts_corr[p]<0.1) & pts_in_seg )[0] 110 | 111 | # valid_grp = np.where(valid_grp & (distances<10.0))[0] 112 | groupid[valid_grp] = grp_id 113 | grp_id = grp_id+1 114 | 115 | print(grp_id) 116 | return groupid 117 | 118 | 119 | def GroupMerging(pts_corr, confidence, seg,label_bin): 120 | seg = seg.detach().cpu().numpy() 121 | pts_corr = pts_corr.detach().cpu().numpy() 122 | confidence = confidence.detach().cpu().numpy() 123 | seg[seg>0.5] = 1 124 | seg[seg<0.5] = 0 125 | confvalidpts = (confidence>0.4) 126 | un_seg = np.unique(seg) 127 | refineseg = np.zeros(pts_corr.shape[0]) 128 | groupid = -1* np.ones(pts_corr.shape[0]) 129 | # print(groupid,refineseg) 130 | numgroups = 0 131 | groupseg = {} 132 | for i_seg in un_seg: 133 | if i_seg==-1 : 134 | continue 135 | pts_in_seg = (seg==i_seg)## points in the segmentation mask with this class 136 | valid_seg_group = np.where(pts_in_seg & confvalidpts) ## points with this class and a confidence > 0.4 137 | proposals = [] 138 | if valid_seg_group[0].shape[0]==0:## if there are no points in this segmentation group (no points of this class with enough 
confidence) 139 | proposals += [pts_in_seg] 140 | else: 141 | for ip in valid_seg_group[0]:## for all the points in this class and with enough confidence 142 | validpt = (pts_corr[ip] < label_bin[int(i_seg)]) & pts_in_seg ## take points in correlation matrix with a distance lower than a threshold and that in same class as pts 143 | if np.sum(validpt)>5:##if there are more than 5 points 144 | flag = False 145 | for gp in range(len(proposals)): 146 | iou = float(np.sum(validpt & proposals[gp])) / np.sum(validpt|proposals[gp])#uniou 147 | validpt_in_gp = float(np.sum(validpt & proposals[gp])) / np.sum(validpt)#uniou 148 | if iou > 0.8 or validpt_in_gp > 0.8: 149 | flag = True 150 | if np.sum(validpt)>np.sum(proposals[gp]): 151 | proposals[gp] = validpt 152 | continue 153 | 154 | if not flag: 155 | proposals += [validpt] 156 | 157 | if len(proposals) == 0: 158 | proposals += [pts_in_seg] 159 | for gp in range(len(proposals)): 160 | if np.sum(proposals[gp])>50: 161 | groupid[proposals[gp]] = numgroups 162 | groupseg[numgroups] = i_seg 163 | numgroups += 1 164 | refineseg[proposals[gp]] = stats.mode(seg[proposals[gp]])[0] 165 | 166 | 167 | un, cnt = np.unique(groupid, return_counts=True) 168 | for ig, g in enumerate(un): 169 | if cnt[ig] < 60: 170 | groupid[groupid==g] = -1 171 | 172 | un, cnt = np.unique(groupid, return_counts=True) 173 | groupidnew = groupid.copy() 174 | for ig, g in enumerate(un): 175 | if g == -1: 176 | continue 177 | groupidnew[groupid==g] = (ig-1) 178 | groupseg[(ig-1)] = groupseg.pop(g) 179 | groupid = groupidnew 180 | 181 | 182 | for ip, gid in enumerate(groupid): 183 | if gid == -1: 184 | pts_in_gp_ind = (pts_corr[ip] < label_bin[int(seg[ip])]) 185 | pts_in_gp = groupid[pts_in_gp_ind] 186 | pts_in_gp_valid = pts_in_gp[pts_in_gp!=-1] 187 | if len(pts_in_gp_valid) != 0: 188 | groupid[ip] = stats.mode(pts_in_gp_valid)[0][0] 189 | 190 | print(np.unique(groupid).shape) 191 | return groupid, refineseg, groupseg 192 | 193 | def BlockMerging(volume, volume_seg, pts, grouplabel, groupseg, gap=1e-3): 194 | 195 | overlapgroupcounts = np.zeros([100,300]) 196 | groupcounts = np.ones(100) 197 | x=(pts[:,0]/gap).astype(np.int32) 198 | y=(pts[:,1]/gap).astype(np.int32) 199 | z=(pts[:,2]/gap).astype(np.int32) 200 | for i in range(pts.shape[0]): 201 | xx=x[i] 202 | yy=y[i] 203 | zz=z[i] 204 | if grouplabel[i] != -1: 205 | if volume[xx,yy,zz]!=-1 and volume_seg[xx,yy,zz]==groupseg[grouplabel[i]]: 206 | overlapgroupcounts[grouplabel[i],volume[xx,yy,zz]] += 1 207 | groupcounts[grouplabel[i]] += 1 208 | 209 | groupcate = np.argmax(overlapgroupcounts,axis=1) 210 | maxoverlapgroupcounts = np.max(overlapgroupcounts,axis=1) 211 | 212 | curr_max = np.max(volume) 213 | for i in range(groupcate.shape[0]): 214 | if maxoverlapgroupcounts[i]<7 and groupcounts[i]>30: 215 | curr_max += 1 216 | groupcate[i] = curr_max 217 | 218 | 219 | finalgrouplabel = -1 * np.ones(pts.shape[0]) 220 | 221 | for i in range(pts.shape[0]): 222 | if grouplabel[i] != -1 and volume[x[i],y[i],z[i]]==-1: 223 | volume[x[i],y[i],z[i]] = groupcate[grouplabel[i]] 224 | volume_seg[x[i],y[i],z[i]] = groupseg[grouplabel[i]] 225 | finalgrouplabel[i] = groupcate[grouplabel[i]] 226 | return finalgrouplabel 227 | -------------------------------------------------------------------------------- /SPGN/__pycache__/SGPN_utils.cpython-37.pyc: -------------------------------------------------------------------------------- 
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/SPGN/__pycache__/SGPN_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/SPGN/test_SGPN.py:
--------------------------------------------------------------------------------
1 | import open3d as opend
2 | import torch
3 | import numpy as np
4 | import SGPN
5 | # import pptk
6 | # import load_data2
7 | import torch.optim as optim
8 | import torch.nn as nn
9 | import tqdm
10 | import torch.nn as nn
11 | import torch.nn.init as init
12 | import load_data
13 | import time
14 |
15 |
16 | def get_boxe_from_p(boxes):
17 | vertices = []
18 | lines = np.array( [[0,0]])
19 | i = 0
20 | for b in boxes:
21 | first_p = b[1]
22 | second_p = b[0]
23 | width = first_p[0] - second_p[0]
24 | height = first_p[1] - second_p[1]
25 | depth = first_p[2] - second_p[2]
26 |
27 | vertices.append(first_p) # top front right
28 | vertices.append(first_p-[width,0,0]) # top front left
29 | vertices.append(first_p-[width,height,0]) # bottom front left
30 | vertices.append(first_p-[0,height,0]) # bottom front right
31 |
32 | vertices.append(second_p) # bottom back left
33 | vertices.append(first_p-[width,0,depth]) # top back left
34 | vertices.append(first_p-[0,height,depth]) # bottom back right
35 | vertices.append(first_p-[0,0,depth]) # top back right
36 |
37 | edges = [[0+(i*8),1+(i*8)],[1+(i*8),2+(i*8)],[2+(i*8),3+(i*8)],[3+(i*8),0+(i*8)]
38 | ,[4+(i*8),5+(i*8)],[4+(i*8),6+(i*8)],[6+(i*8),7+(i*8)],[7+(i*8),5+(i*8)]
39 | ,[0+(i*8),7+(i*8)],[1+(i*8),5+(i*8)],[4+(i*8),2+(i*8)],[3+(i*8),6+(i*8)]]
40 | lines = np.concatenate([lines,edges],axis = 0)
41 | i = i+1
42 |
43 |
44 | line_set = opend.geometry.LineSet()
45 | line_set.points = opend.utility.Vector3dVector(vertices)
46 | line_set.lines = opend.utility.Vector2iVector(lines[1:])
47 | line_set.colors = opend.utility.Vector3dVector([[1, 0, 0] for i in range(lines[1:].shape[0])])
48 | # i = i + 1
49 |
50 | return line_set
51 |
52 | def generate_bb(mask,points,seg):
53 | seg = seg.squeeze()
54 | u,u_idx,u_cnt = np.unique(mask,return_index=True,return_counts=True)
55 | boxes = []
56 | for v in u:
57 | vals_idx = np.where(mask==v)
58 | vals_points = points[vals_idx]
59 | min = vals_points.min(axis=0,keepdims=True)[0]
60 | max = vals_points.max(axis=0,keepdims=True)[0]
61 | # if(compute_volume([min,max])*100<0.05):
62 | boxes.append([min,max]) # keep every box; the volume filter above stays disabled
63 | return get_boxe_from_p(np.asarray(boxes))
64 |
65 |
66 | def compute_volume(boxe):
67 | min = boxe[0]
68 | max = boxe[1]
69 | width = max[0]-min[0]
70 | height = max[1]-min[1]
71 | length = max[2]-min[2]
72 | volume = width*height*length
73 | return volume
74 |
75 | print("loading data...")
76 | points,colors,annotations,centers,bb_list,mask_list,label_list = load_data.load_data("../testing")
77 | model = torch.load("./models/test_model_SGPN_0_008.ckpt").cuda()
78 | color_map = [[0.57, 0.15, 0.43], [0.85, 0.57, 0.29], [0.7100000000000001, 0.7100000000000001, 0.01], [0.01, 0.7100000000000001, 0.15], [0.01, 0.57, 0.85], [0.43, 0.01, 0.7100000000000001], [0.43, 0.7100000000000001, 0.57], [0.29, 0.29, 0.29], [0.29, 0.7100000000000001, 0.7100000000000001], [0.57, 0.57, 0.29], [0.57, 0.7100000000000001, 0.7100000000000001], [0.01, 0.01, 0.29], [0.29, 0.57, 0.7100000000000001], [0.29, 0.15, 0.57], [0.29, 0.7100000000000001, 0.43], [0.57, 0.29, 0.01], [0.29, 0.15, 0.15], [0.7100000000000001, 0.29, 0.01], [0.01, 0.85, 0.15], [0.85, 0.01, 0.01], [0.29, 0.15, 0.7100000000000001],
[0.7100000000000001, 0.15, 0.43], [0.29, 0.43, 0.7100000000000001], [0.43, 0.43, 0.7100000000000001], [0.29, 0.57, 0.01], [0.57, 0.29, 0.29], [0.57, 0.85, 0.15], [0.15, 0.29, 0.29], [0.15, 0.7100000000000001, 0.15], [0.85, 0.01, 0.29], [0.43, 0.85, 0.29], [0.43, 0.29, 0.85], [0.57, 0.85, 0.85], [0.15, 0.57, 0.01], [0.57, 0.29, 0.15], [0.7100000000000001, 0.85, 0.57], [0.57, 0.01, 0.57], [0.01, 0.85, 0.43], [0.01, 0.01, 0.01], [0.85, 0.01, 0.43], [0.57, 0.43, 0.57], [0.85, 0.01, 0.57], [0.01, 0.43, 0.43], [0.01, 0.29, 0.85], [0.57, 0.57, 0.7100000000000001], [0.7100000000000001, 0.29, 0.57], [0.57, 0.7100000000000001, 0.43], [0.29, 0.15, 0.01], [0.57, 0.15, 0.15], [0.85, 0.57, 0.85], [0.85, 0.29, 0.85], [0.85, 0.15, 0.01], [0.85, 0.7100000000000001, 0.01], [0.01, 0.57, 0.15], [0.43, 0.01, 0.43], [0.57, 0.15, 0.85], [0.01, 0.29, 0.57], [0.29, 0.85, 0.43], [0.57, 0.29, 0.43], [0.43, 0.01, 0.29], [0.15, 0.85, 0.7100000000000001], [0.85, 0.57, 0.43], [0.01, 0.15, 0.57], [0.7100000000000001, 0.7100000000000001, 0.29], [0.7100000000000001, 0.15, 0.57], [0.43, 0.43, 0.29], [0.7100000000000001, 0.43, 0.43], [0.01, 0.43, 0.57], [0.57, 0.01, 0.15], [0.57, 0.57, 0.01], [0.29, 0.01, 0.29], [0.7100000000000001, 0.01, 0.29], [0.85, 0.85, 0.7100000000000001], [0.85, 0.15, 0.29], [0.43, 0.29, 0.57], [0.43, 0.43, 0.85], [0.85, 0.15, 0.85], [0.57, 0.85, 0.29], [0.57, 0.7100000000000001, 0.01], [0.7100000000000001, 0.85, 0.15], [0.85, 0.7100000000000001, 0.43], [0.01, 0.15, 0.01], [0.85, 0.29, 0.43], [0.43, 0.85, 0.15], [0.15, 0.01, 0.15], [0.7100000000000001, 0.7100000000000001, 0.85], [0.43, 0.29, 0.01], [0.15, 0.43, 0.29], [0.7100000000000001, 0.57, 0.15], [0.29, 0.85, 0.29], [0.29, 0.7100000000000001, 0.57], [0.57, 0.85, 0.7100000000000001], [0.15, 0.01, 0.85], [0.43, 0.15, 0.57], [0.57, 0.57, 0.15], [0.01, 0.57, 0.01], [0.15, 0.29, 0.57], [0.29, 0.57, 0.43], [0.15, 0.7100000000000001, 0.01], [0.15, 0.15, 0.15], [0.43, 0.29, 0.15], [0.7100000000000001, 0.29, 0.7100000000000001], [0.7100000000000001, 0.85, 0.43], [0.15, 0.29, 0.7100000000000001], [0.15, 0.43, 0.57], [0.01, 0.7100000000000001, 0.01], [0.85, 0.29, 0.01], [0.15, 0.01, 0.57], [0.29, 0.29, 0.7100000000000001], [0.15, 0.7100000000000001, 0.29], [0.01, 0.15, 0.43], [0.7100000000000001, 0.01, 0.15], [0.57, 0.43, 0.01], [0.85, 0.43, 0.01], [0.43, 0.85, 0.7100000000000001], [0.85, 0.43, 0.43], [0.85, 0.01, 0.15], [0.01, 0.43, 0.85], [0.15, 0.15, 0.7100000000000001], [0.29, 0.57, 0.85], [0.43, 0.15, 0.15], [0.29, 0.85, 0.85], [0.15, 0.57, 0.29], [0.85, 0.85, 0.85], [0.29, 0.43, 0.43], [0.01, 0.43, 0.29], [0.43, 0.15, 0.7100000000000001], [0.7100000000000001, 0.01, 0.57], [0.7100000000000001, 0.43, 0.15], [0.01, 0.85, 0.01], [0.85, 0.01, 0.7100000000000001], [0.57, 0.43, 0.43], [0.57, 0.85, 0.01], [0.01, 0.57, 0.43], [0.15, 0.15, 0.01], [0.85, 0.43, 0.85], [0.57, 0.15, 0.29], [0.7100000000000001, 0.7100000000000001, 0.57], [0.57, 0.01, 0.85], [0.29, 0.43, 0.15], [0.7100000000000001, 0.57, 0.7100000000000001], [0.43, 0.7100000000000001, 0.85], [0.01, 0.15, 0.15], [0.85, 0.85, 0.57], [0.43, 0.85, 0.01], [0.15, 0.15, 0.85], [0.29, 0.29, 0.43], [0.29, 0.43, 0.57], [0.7100000000000001, 0.29, 0.85], [0.15, 0.15, 0.43], [0.85, 0.7100000000000001, 0.85], [0.85, 0.15, 0.43], [0.43, 0.43, 0.15], [0.57, 0.7100000000000001, 0.15], [0.7100000000000001, 0.43, 0.57], [0.7100000000000001, 0.43, 0.01], [0.85, 0.29, 0.29], [0.85, 0.15, 0.15], [0.43, 0.57, 0.85], [0.01, 0.85, 0.29], [0.29, 0.7100000000000001, 0.15], [0.57, 0.85, 0.57], [0.43, 0.43, 0.57], [0.01, 
0.7100000000000001, 0.7100000000000001], [0.57, 0.15, 0.57], [0.57, 0.57, 0.43], [0.85, 0.57, 0.57], [0.85, 0.7100000000000001, 0.7100000000000001], [0.57, 0.7100000000000001, 0.85], [0.15, 0.85, 0.85], [0.29, 0.57, 0.57], [0.15, 0.7100000000000001, 0.85], [0.57, 0.01, 0.29], [0.29, 0.7100000000000001, 0.01], [0.7100000000000001, 0.29, 0.15], [0.85, 0.01, 0.85], [0.29, 0.01, 0.57], [0.29, 0.01, 0.7100000000000001], [0.7100000000000001, 0.85, 0.29], [0.85, 0.29, 0.7100000000000001], [0.43, 0.15, 0.43], [0.01, 0.01, 0.15], [0.01, 0.7100000000000001, 0.57], [0.7100000000000001, 0.15, 0.01], [0.7100000000000001, 0.15, 0.15], [0.29, 0.29, 0.85], [0.01, 0.43, 0.7100000000000001], [0.57, 0.15, 0.01], [0.85, 0.15, 0.7100000000000001], [0.43, 0.43, 0.01], [0.29, 0.85, 0.57], [0.01, 0.29, 0.43], [0.57, 0.43, 0.29], [0.43, 0.29, 0.43], [0.85, 0.7100000000000001, 0.29], [0.7100000000000001, 0.85, 0.85], [0.43, 0.43, 0.43], [0.43, 0.29, 0.7100000000000001], [0.7100000000000001, 0.57, 0.57], [0.57, 0.29, 0.85], [0.01, 0.15, 0.7100000000000001], [0.7100000000000001, 0.7100000000000001, 0.15], [0.15, 0.85, 0.01], [0.43, 0.7100000000000001, 0.7100000000000001], [0.43, 0.57, 0.43], [0.7100000000000001, 0.01, 0.85], [0.29, 0.43, 0.29], [0.57, 0.15, 0.7100000000000001], [0.29, 0.57, 0.15], [0.15, 0.29, 0.43], [0.43, 0.85, 0.57], [0.43, 0.57, 0.57], [0.7100000000000001, 0.01, 0.01], [0.85, 0.85, 0.29], [0.15, 0.01, 0.29], [0.85, 0.29, 0.15], [0.15, 0.01, 0.01], [0.01, 0.01, 0.57], [0.15, 0.57, 0.15], [0.15, 0.43, 0.85], [0.01, 0.01, 0.7100000000000001], [0.85, 0.15, 0.57], [0.29, 0.15, 0.29], [0.15, 0.29, 0.15], [0.43, 0.85, 0.43], [0.01, 0.15, 0.29], [0.85, 0.7100000000000001, 0.15], [0.01, 0.01, 0.85], [0.7100000000000001, 0.57, 0.85], [0.01, 0.85, 0.57], [0.29, 0.85, 0.7100000000000001], [0.15, 0.01, 0.7100000000000001], [0.85, 0.57, 0.15], [0.15, 0.57, 0.85], [0.15, 0.43, 0.15], [0.15, 0.85, 0.57], [0.7100000000000001, 0.85, 0.01], [0.7100000000000001, 0.01, 0.7100000000000001], [0.15, 0.43, 0.43], [0.01, 0.57, 0.7100000000000001], [0.43, 0.01, 0.57], [0.01, 0.57, 0.57], [0.57, 0.43, 0.7100000000000001], [0.7100000000000001, 0.29, 0.29], [0.15, 0.43, 0.7100000000000001], [0.57, 0.43, 0.15], [0.01, 0.7100000000000001, 0.29], [0.7100000000000001, 0.57, 0.01], [0.01, 0.29, 0.15], [0.29, 0.85, 0.01], [0.57, 0.29, 0.57], [0.85, 0.43, 0.29], [0.7100000000000001, 0.43, 0.29], [0.29, 0.7100000000000001, 0.85], [0.85, 0.57, 0.01], [0.01, 0.29, 0.7100000000000001], [0.15, 0.57, 0.7100000000000001], [0.85, 0.57, 0.7100000000000001], [0.85, 0.43, 0.7100000000000001], [0.15, 0.57, 0.43], [0.15, 0.85, 0.29], [0.15, 0.29, 0.85], [0.7100000000000001, 0.29, 0.43], [0.15, 0.15, 0.29], [0.43, 0.7100000000000001, 0.01], [0.29, 0.7100000000000001, 0.29], [0.7100000000000001, 0.7100000000000001, 0.43], [0.29, 0.85, 0.15], [0.43, 0.15, 0.29], [0.29, 0.01, 0.15], [0.57, 0.01, 0.01], [0.85, 0.85, 0.15], [0.43, 0.7100000000000001, 0.15], [0.43, 0.7100000000000001, 0.29], [0.7100000000000001, 0.57, 0.43], [0.29, 0.29, 0.57], [0.29, 0.01, 0.43], [0.85, 0.7100000000000001, 0.57], [0.57, 0.29, 0.7100000000000001], [0.43, 0.29, 0.29], [0.7100000000000001, 0.43, 0.7100000000000001], [0.43, 0.57, 0.29], [0.43, 0.01, 0.15], [0.15, 0.7100000000000001, 0.57], [0.01, 0.29, 0.29], [0.29, 0.01, 0.01], [0.7100000000000001, 0.15, 0.85], [0.57, 0.7100000000000001, 0.29], [0.57, 0.7100000000000001, 0.57], [0.15, 0.29, 0.01], [0.29, 0.43, 0.01], [0.7100000000000001, 0.15, 0.7100000000000001], [0.85, 0.85, 0.43], [0.7100000000000001, 0.15, 0.29], 
[0.7100000000000001, 0.7100000000000001, 0.7100000000000001], [0.29, 0.57, 0.29], [0.85, 0.43, 0.57], [0.01, 0.57, 0.29], [0.57, 0.01, 0.43], [0.01, 0.29, 0.01], [0.29, 0.01, 0.85], [0.43, 0.57, 0.15], [0.85, 0.85, 0.01], [0.29, 0.15, 0.43], [0.29, 0.29, 0.15], [0.01, 0.43, 0.01], [0.43, 0.85, 0.85], [0.15, 0.7100000000000001, 0.7100000000000001], [0.15, 0.85, 0.43], [0.7100000000000001, 0.85, 0.7100000000000001], [0.7100000000000001, 0.01, 0.43], [0.15, 0.43, 0.01], [0.43, 0.57, 0.01], [0.7100000000000001, 0.57, 0.29], [0.01, 0.85, 0.85], [0.01, 0.85, 0.7100000000000001], [0.01, 0.7100000000000001, 0.85], [0.43, 0.7100000000000001, 0.43], [0.43, 0.15, 0.01], [0.57, 0.85, 0.43], [0.85, 0.43, 0.15], [0.01, 0.15, 0.85], [0.85, 0.29, 0.57], [0.01, 0.01, 0.43], [0.57, 0.01, 0.7100000000000001], [0.29, 0.29, 0.01], [0.57, 0.43, 0.85], [0.43, 0.15, 0.85], [0.01, 0.7100000000000001, 0.43], [0.15, 0.57, 0.57], [0.15, 0.15, 0.57], [0.29, 0.43, 0.85], [0.43, 0.01, 0.01], [0.57, 0.57, 0.85], [0.15, 0.7100000000000001, 0.43], [0.43, 0.01, 0.85], [0.15, 0.01, 0.43], [0.43, 0.57, 0.7100000000000001], [0.29, 0.15, 0.85], [0.15, 0.85, 0.15], [0.7100000000000001, 0.43, 0.85], [0.01, 0.43, 0.15], [0.57, 0.57, 0.57]] 79 | for i in tqdm.tqdm(range(0,len(points))): 80 | input_tensor = torch.from_numpy(points[i]).float().unsqueeze(0).cuda() 81 | # input_tensor3 = torch.from_numpy(centers[i]).float().unsqueeze(0).cuda() 82 | if(input_tensor.shape[1]>1000):# and input_tensor3.shape[1]>0): 83 | input_tensor2 = torch.from_numpy(colors[i]).float().unsqueeze(0).cuda() 84 | target = torch.from_numpy(annotations[i]).unsqueeze(0).unsqueeze(1).cuda().float() 85 | bb_target = torch.from_numpy(bb_list[i]).unsqueeze(0).cuda().float() 86 | mask_target = torch.from_numpy(mask_list[i]).unsqueeze(0).cuda().float() 87 | 88 | gt = {"ptsgroup":mask_target,"semseg":target,"bounding_boxes":bb_target} 89 | loss,mask_numpy ,seg =model(input_tensor,input_tensor2,gt,False,0,False) 90 | # seg = seg.detach().cpu().numpy() 91 | print(seg.shape) 92 | seg[seg>0.5] = 1 93 | seg[seg<0.5] = 0 94 | colors3 = np.zeros((mask_numpy.shape[0],3)) 95 | for a,v in enumerate(seg): 96 | if(v==1): 97 | colors3[a] = [1.0,0.0,0.0] 98 | else: 99 | colors3[a] = [0.0,1.0,0.0] 100 | # colors3 = 101 | boxes = generate_bb(mask_numpy,points[i].copy(),seg) 102 | colors2 = np.zeros((mask_numpy.shape[0],3)) 103 | for a,v in enumerate(mask_numpy): 104 | colors2[a] = color_map[int(v)] 105 | 106 | pcd2 = opend.PointCloud() 107 | pcd2.points = opend.Vector3dVector(points[i]) 108 | pcd2.colors = opend.Vector3dVector(colors2) 109 | 110 | pcd3 = opend.PointCloud() 111 | pcd3.points = opend.Vector3dVector(points[i]+np.array([0.5,0.0,0.0])) 112 | pcd3.colors = opend.Vector3dVector(colors[i]) 113 | 114 | pcd4 = opend.PointCloud() 115 | pcd4.points = opend.Vector3dVector(points[i]+np.array([0.4,0.4,0.0])) 116 | pcd4.colors = opend.Vector3dVector(colors3) 117 | 118 | opend.draw_geometries([pcd2,pcd3,pcd4]) 119 | -------------------------------------------------------------------------------- /SPGN/train_SGPN.py: -------------------------------------------------------------------------------- 1 | import open3d as opend 2 | import torch 3 | import numpy as np 4 | import SGPN 5 | import SGPN2 6 | # import pptk 7 | # import load_data2 8 | import torch.optim as optim 9 | import torch.nn as nn 10 | import tqdm 11 | import torch.nn as nn 12 | import torch.nn.init as init 13 | import load_data 14 | from torch.utils.tensorboard import SummaryWriter 15 | import time 16 | 17 | 18 | 19 | 20 
| def get_boxe_from_p(boxes):
21 | vertices = []
22 | lines = np.array( [[0,0]])
23 | i = 0
24 | for b in boxes:
25 | first_p = b[1]
26 | second_p = b[0]
27 | width = first_p[0] - second_p[0]
28 | height = first_p[1] - second_p[1]
29 | depth = first_p[2] - second_p[2]
30 |
31 | vertices.append(first_p) # top front right
32 | vertices.append(first_p-[width,0,0]) # top front left
33 | vertices.append(first_p-[width,height,0]) # bottom front left
34 | vertices.append(first_p-[0,height,0]) # bottom front right
35 |
36 | vertices.append(second_p) # bottom back left
37 | vertices.append(first_p-[width,0,depth]) # top back left
38 | vertices.append(first_p-[0,height,depth]) # bottom back right
39 | vertices.append(first_p-[0,0,depth]) # top back right
40 |
41 | edges = [[0+(i*8),1+(i*8)],[1+(i*8),2+(i*8)],[2+(i*8),3+(i*8)],[3+(i*8),0+(i*8)]
42 | ,[4+(i*8),5+(i*8)],[4+(i*8),6+(i*8)],[6+(i*8),7+(i*8)],[7+(i*8),5+(i*8)]
43 | ,[0+(i*8),7+(i*8)],[1+(i*8),5+(i*8)],[4+(i*8),2+(i*8)],[3+(i*8),6+(i*8)]]
44 | lines = np.concatenate([lines,edges],axis = 0)
45 | i = i+1
46 |
47 |
48 | line_set = opend.geometry.LineSet()
49 | line_set.points = opend.utility.Vector3dVector(vertices)
50 | line_set.lines = opend.utility.Vector2iVector(lines[1:])
51 | line_set.colors = opend.utility.Vector3dVector([[1, 0, 0] for i in range(lines[1:].shape[0])])
52 | # i = i + 1
53 |
54 | return line_set
55 |
56 |
57 | seed = list(map(ord, 'toto'))
58 | seed = map(str, seed)
59 | seed = ''.join(seed)
60 | seed = int(seed)
61 | torch.manual_seed(seed)
62 | np.random.seed(1)
63 |
64 | torch.set_num_threads(1)
65 | # export OMP_NUM_THREADS=1 in the shell before launching; assigning a plain Python variable here has no effect
66 | torch.backends.cudnn.deterministic = False
67 | torch.backends.cudnn.benchmark = False
68 |
69 | print("loading data...")
70 | points,colors,annotations,centers,bb_list,mask_list,label_list = load_data.load_data("../test")
71 | model = SGPN.PointNet2SemSeg(2).cuda()
72 | # model = torch.load("./models/test_model_SGPN_only_pn.ckpt").cuda()
73 | optimizer = optim.Adam([{'params': model.parameters(), 'lr': 0.00005}])
74 | epochs = 100
75 |
76 |
77 | writer = SummaryWriter()
78 | print("training...")
79 | for p in range(0,epochs):
80 | lost = []
81 | lost_seg = []
82 | lost_simmat_loss = []
83 | lost_confidence_loss = []
84 | for i in tqdm.tqdm(range(0,len(points))):
85 | input_tensor = torch.from_numpy(points[i]).float().unsqueeze(0).cuda()
86 | # input_tensor3 = torch.from_numpy(centers[i]).float().unsqueeze(0).cuda()
87 | if(input_tensor.shape[1]>1000):# and input_tensor3.shape[1]>0):
88 | input_tensor2 = torch.from_numpy(colors[i]).float().unsqueeze(0).cuda()
89 | target = torch.from_numpy(annotations[i]).unsqueeze(0).unsqueeze(1).cuda().float()
90 | bb_target = torch.from_numpy(bb_list[i]).unsqueeze(0).cuda().float()
91 | mask_target = torch.from_numpy(mask_list[i]).unsqueeze(0).cuda().float()
92 | optimizer.zero_grad() # must be called, otherwise gradients accumulate across iterations
93 | gt = {"ptsgroup":mask_target,"semseg":target,"bounding_boxes":bb_target}
94 | loss,mask = model(input_tensor,input_tensor2,gt,True,p,False)
95 | lost.append(loss[0].data.item())
96 | lost_seg.append(loss[1].data.item())
97 | lost_simmat_loss.append(loss[2].data.item())
98 | lost_confidence_loss.append(loss[3].data.item())
99 | loss = loss[0]
100 | loss.backward()
101 | optimizer.step()
102 |
103 | if(p%100 ==0 and p!=0 ):
104 | torch.save(model.cpu(), "./models/test_model"+str(p)+".ckpt")
105 | model.cuda()
106 | writer.add_scalar('Loss', np.array(lost).mean(), p)
107 | writer.add_scalar('Loss seg', np.array(lost_seg).mean(), p)
108 | writer.add_scalar('Loss simmat_loss ',
np.array(lost_simmat_loss).mean(), p)
109 | writer.add_scalar('Loss confidence_loss', np.array(lost_confidence_loss).mean(), p)
110 | torch.save(model.cpu(), "./models/test_model_SGPN_0_008.ckpt")
111 |
--------------------------------------------------------------------------------
/Utils/Utilities.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | from numba import jit # the decorator below is used bare, so import the name directly
3 |
4 |
5 |
6 | @jit(nopython=True)
7 | def aligne_depth(depth_raw,rowCount,columnCount,depth_aligned,focal_length, principal_point, color_focal_length, color_principal_point,extrinsic):
8 | for depth_y in range(0,rowCount):
9 | for depth_x in range(0, columnCount ):
10 | ## depth to meters
11 | z = depth_raw[depth_y][depth_x] * 0.001
12 | if(z!=0 and z <= 3):
13 | ##### top left corner to rgb image
14 | ## map depth map to 3D point
15 | x_a = (depth_x - 0.5 - principal_point[0] ) / focal_length[0]
16 | y_a = (depth_y - 0.5 - principal_point[1] ) / focal_length[1]
17 |
18 | ## Rotate and translate 3D point into RGB point of view
19 | point1 = np.array([x_a*z,y_a*z,z, 1.0])
20 | point1 = np.dot(extrinsic,point1)
21 | point1 = point1[:3] / point1[3]
22 |
23 | #### mapping 3D depth points to RGB image ###
24 | rgb_x1 = point1[0] / point1[2]
25 | rgb_y1 = point1[1] / point1[2]
26 | rgb_x_1 = color_focal_length[0] * rgb_x1 + color_principal_point[0]
27 | rgb_y_1 = color_focal_length[1] * rgb_y1 + color_principal_point[1]
28 |
29 | rgb_x_1 = int(rgb_x_1 + 0.5)
30 | rgb_y_1 = int(rgb_y_1 + 0.5 )
31 |
32 |
33 | ## Bottom right corner to rgb image
34 | ## map depth map to 3D point
35 | x_b = (depth_x + 0.5 - principal_point[0]) / focal_length[0]
36 | y_b = (depth_y + 0.5 - principal_point[1] ) / focal_length[1]
37 |
38 | ## Rotate and translate 3D point into RGB point of view
39 | point2 = np.array([x_b *z ,y_b *z ,z,1.0])
40 | point2 = np.dot(extrinsic,point2)
41 | point2 = point2[:3] / point2[3]
42 |
43 | #### mapping 3D depth points to RGB image ###
44 | rgb_x2 = point2[0] /point2[2]
45 | rgb_y2 = point2[1] /point2[2]
46 | rgb_x_2 = color_focal_length[0] * rgb_x2 + color_principal_point[0]
47 | rgb_y_2 = color_focal_length[1] * rgb_y2 + color_principal_point[1]
48 |
49 | rgb_x_2 = int(rgb_x_2 + 0.5)
50 | rgb_y_2 = int(rgb_y_2 + 0.5)
51 |
52 | if(rgb_x_1 > 0 and rgb_y_1 > 0 and rgb_y_2 < len(depth_aligned) and rgb_x_2 < len(depth_aligned[0]) ):
53 | for a in range(rgb_y_1,rgb_y_2+1):
54 | for b in range(rgb_x_1,rgb_x_2+1):
55 | if(depth_aligned[a][b] != 0):
56 | depth_aligned[a][b] = min(depth_aligned[a][b], depth_raw[depth_y][depth_x] )
57 | else:
58 | depth_aligned[a][b] = depth_raw[depth_y][depth_x]
59 | return depth_aligned
60 |
--------------------------------------------------------------------------------
/Utils/__pycache__/net_utils.cpython-37.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/Utils/__pycache__/net_utils.cpython-37.pyc
--------------------------------------------------------------------------------
/Utils/net_utils.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import torch.nn as nn
3 | import torch.nn.init as init
4 | import torch.nn.functional as F
5 | import numpy as np
6 | from time import time
7 | import point
8 |
9 |
10 | def index_points(points, idx):
11 | """
12 | Input:
13 | points: input points data, [B, N, C]
14 | idx: sample index data, [B, D1, D2,
..., Dn]
15 | Return:
16 | new_points: indexed points data, [B, D1, D2, ..., Dn, C]
17 | """
18 | device = points.device
19 | B = points.shape[0]
20 | view_shape = list(idx.shape)
21 | view_shape[1:] = [1] * (len(view_shape) - 1)
22 | repeat_shape = list(idx.shape)
23 | repeat_shape[0] = 1
24 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
25 | new_points = points[batch_indices, idx, :]
26 | return new_points
27 |
28 | def square_distance(src, dst):
29 | """
30 | Description:
31 | per-pair squared Euclidean distance, ||x - y||^2 = ||x||^2 - 2<x, y> + ||y||^2
32 | Input:
33 | src: source points, [B, N, C]
34 | dst: target points, [B, M, C]
35 | Output:
36 | dist: per-point square distance, [B, N, M]
37 | """
38 | B, N, _ = src.shape
39 | _, M, _ = dst.shape
40 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1).contiguous())
41 | dist += torch.sum(src ** 2, -1).view(B, N, 1)
42 | dist += torch.sum(dst ** 2, -1).view(B, 1, M)
43 | return dist
44 |
45 | def group_points(xyz,idx):
46 | b , n , c = xyz.shape
47 | m = idx.shape[1]
48 | nsample = idx.shape[2]
49 | out = torch.zeros((xyz.shape[0], m, idx.shape[2], c)).cuda() # the kernel writes one row per queried centroid (m), not per input point (n)
50 | point.group_points(b,n,c,m,nsample,xyz,idx.int(),out)
51 | return out
52 |
53 | def farthest_point_sample_gpu(xyz, npoint):
54 | b, n ,c = xyz.shape
55 | centroid = torch.zeros((xyz.shape[0],npoint)).int().cuda()
56 | temp = torch.zeros((32,n)).cuda()
57 | point.farthestPoint(b,n, npoint, xyz , temp ,centroid)
58 | return centroid.long()
59 |
60 | def ball_query(radius, nsample, xyz, new_xyz):
61 | b, n ,c = xyz.shape
62 | m = new_xyz.shape[1]
63 | group_idx = torch.zeros((new_xyz.shape[0],new_xyz.shape[1], nsample), dtype=torch.int32).cuda()
64 | pts_cnt = torch.zeros((new_xyz.shape[0],new_xyz.shape[1]), dtype=torch.int32).cuda() # one neighbour count per queried centroid
65 | point.ball_query (b, n, m, radius, nsample, xyz, new_xyz, group_idx ,pts_cnt)
66 | return group_idx.long()
67 |
68 | def idx_pts(points,idx):
69 | new_points = torch.cat([points.index_select(1,idx[b]) for b in range(0,idx.shape[0])], dim=0)
70 | return new_points
71 |
72 | def sample_and_group(npoint, radius, nsample, xyz, points):
73 | """
74 | Input:
75 | npoint: the number of centroids sampled with farthest point sampling (one local region per centroid)
76 | radius: the radius of the local region
77 | nsample: the number of points in a local region
78 | xyz: input points position data, [B, N, C]
79 | points: input points data, [B, N, D]
80 | Return:
81 | new_xyz: sampled points position data, [B, npoint, C]
82 | new_points: sampled points data, [B, npoint, nsample, C+D]
83 | """
84 | B, N, C = xyz.shape
85 | Np = npoint
86 | assert isinstance(Np, int)
87 |
88 | new_xyz = index_points(xyz, farthest_point_sample_gpu(xyz, npoint)) # [B,n,3] and [B,np] → [B,np,3]
89 | idx = ball_query(radius, nsample, xyz, new_xyz)
90 | grouped_xyz = index_points(xyz, idx)# [B,n,3] and [B,n,M] → [B,n,M,3]
91 | grouped_xyz -= new_xyz.view(B, Np, 1, C) # the points of each group are normalized with their centroid
92 | if points is not None:
93 | grouped_points = index_points(points, idx)# [B,n,3] and [B,n,M] → [B,n,M,3]
94 | new_points = torch.cat([grouped_xyz, grouped_points], dim=-1)
95 | else:
96 | new_points = grouped_xyz
97 | return new_xyz, new_points
98 |
99 | def sample_and_group_all(xyz, points):
100 | """
101 | Description:
102 | Equivalent to sample_and_group with npoint=1, radius=np.inf, and the centroid is (0, 0, 0)
103 | Input:
104 | xyz: input points position data, [B, N, C]
105 | points: input points data, [B, N, D]
106 | Return:
107 | new_xyz: sampled points position data, [B, 1, C]
108 | new_points: sampled points data, [B, 1, N, C+D]
109 | """
110 | device = xyz.device
111 | B, N, C = xyz.shape
112 | new_xyz = torch.zeros(B, 1, C).to(device)
113 | grouped_xyz = xyz.view(B, 1, N, C)
114 | if points is not None:
115 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1)
116 | else:
117 | new_points = grouped_xyz
118 | return new_xyz, new_points
119 |
120 |
121 | class PointNetSetAbstraction(nn.Module):
122 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all):
123 | super(PointNetSetAbstraction, self).__init__()
124 | self.npoint = npoint
125 | self.radius = radius
126 | self.nsample = nsample
127 | self.mlp_convs = nn.ModuleList()
128 | self.mlp_bns = nn.ModuleList()
129 | last_channel = in_channel
130 | for out_channel in mlp:
131 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1,1))
132 | self.mlp_bns.append(nn.BatchNorm2d(out_channel))
133 | last_channel = out_channel
134 | self.group_all = group_all
135 |
136 |
137 | def forward(self, xyz, points):
138 | """
139 | Input:
140 | xyz: input points position data, [B, C, N]
141 | points: input points data, [B, D, N]
142 | Return:
143 | new_xyz: sampled points position data, [B, C, S]
144 | new_points_concat: sample points feature data, [B, D', S]
145 | """
146 | xyz = xyz.permute(0, 2, 1).contiguous()
147 | if points is not None:
148 | points = points.permute(0, 2, 1)
149 |
150 | if self.group_all:
151 | new_xyz, new_points = sample_and_group_all(xyz, points)
152 | else:
153 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points)
154 |
155 | new_points = new_points.permute(0, 3, 2, 1).contiguous()
156 | for i, conv in enumerate(self.mlp_convs):
157 | bn = self.mlp_bns[i]
158 | new_points = F.relu(bn(conv(new_points)))
159 |
160 | new_points = torch.max(new_points, 2)[0]
161 | new_xyz = new_xyz.permute(0, 2, 1)
162 | return new_xyz, new_points
163 | class PointNetSetAbstractionMsg(nn.Module):
164 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp, group_all):
165 | super(PointNetSetAbstractionMsg, self).__init__()
166 | self.npoint = npoint
167 | self.radius_list = radius_list
168 |
self.nsample_list = nsample_list 169 | self.mlp_convs = nn.ModuleList() 170 | self.mlp_bns = nn.ModuleList() 171 | last_channel = in_channel 172 | for i in range(len(mlp)): 173 | convs = nn.ModuleList() 174 | bns = nn.ModuleList() 175 | last_channel = in_channel 176 | for out_channel in mlp[i]: 177 | convs.append(nn.Conv2d(last_channel, out_channel, 1,1)) 178 | bns.append(nn.BatchNorm2d(out_channel)) 179 | last_channel = out_channel 180 | self.mlp_convs.append(convs) 181 | self.mlp_bns.append(bns) 182 | self.group_all = group_all 183 | 184 | 185 | def forward(self, xyz, points): 186 | """ 187 | Input: 188 | xyz: input points position data, [B, C, N] 189 | points: input points data, [B, D, N] 190 | Return: 191 | new_xyz: sampled points position data, [B, C, S] 192 | new_points_concat: sample points feature data, [B, D', S] 193 | """ 194 | xyz = xyz.permute(0, 2, 1).contiguous() 195 | if points is not None: 196 | points = points.permute(0, 2, 1) 197 | 198 | B, N, C = xyz.shape 199 | S = self.npoint 200 | new_xyz = index_points(xyz, farthest_point_sample_gpu(xyz, S)) 201 | new_points_list = [] 202 | for i, radius in enumerate(self.radius_list): 203 | K = self.nsample_list[i] 204 | group_idx = ball_query(radius, K, xyz, new_xyz) 205 | grouped_xyz = index_points(xyz, group_idx) 206 | grouped_xyz -= new_xyz.view(B, S, 1, C) 207 | if points is not None: 208 | grouped_points = index_points(points, group_idx) 209 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 210 | else: 211 | grouped_points = grouped_xyz 212 | 213 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 214 | for j in range(len(self.mlp_convs[i])): 215 | conv = self.mlp_convs[i][j] 216 | bn = self.mlp_bns[i][j] 217 | grouped_points = F.relu(bn(conv(grouped_points))) 218 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 219 | new_points_list.append(new_points) 220 | 221 | new_xyz = new_xyz.permute(0, 2, 1) 222 | new_points_concat = torch.cat(new_points_list, dim=1) 223 | return new_xyz, new_points_concat 224 | class PointNetFeaturePropagation(nn.Module): 225 | def __init__(self, in_channel, mlp): 226 | super(PointNetFeaturePropagation, self).__init__() 227 | self.mlp_convs = nn.ModuleList() 228 | self.mlp_bns = nn.ModuleList() 229 | last_channel = in_channel 230 | for out_channel in mlp: 231 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 232 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 233 | last_channel = out_channel 234 | 235 | def forward(self, xyz1, xyz2, points1, points2): 236 | """ 237 | Input: 238 | xyz1: input points position data, [B, C, N] 239 | xyz2: sampled input points position data, [B, C, S] 240 | points1: input points data, [B, D, N] 241 | points2: input points data, [B, D, S] 242 | Return: 243 | new_points: upsampled points data, [B, D', N] 244 | """ 245 | xyz1 = xyz1.permute(0, 2, 1) 246 | xyz2 = xyz2.permute(0, 2, 1) 247 | 248 | points2 = points2.permute(0, 2, 1) 249 | B, N, C = xyz1.shape 250 | _, S, _ = xyz2.shape 251 | 252 | if S == 1: 253 | interpolated_points = points2.repeat(1, N, 1) 254 | else: 255 | dists = square_distance(xyz1, xyz2) 256 | dists, idx = dists.sort(dim=-1) 257 | dists, idx = dists[:,:,:3], idx[:,:,:3] #[B, N, 3] 258 | dists[dists < 1e-10] = 1e-10 259 | weight = 1.0 / dists #[B, N, 3] 260 | weight = weight / torch.sum(weight, dim=-1).view(B, N, 1) #[B, N, 3] 261 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim = 2) 262 | 263 | if points1 is not None: 264 | points1 = 
points1.permute(0, 2, 1) 265 | new_points = torch.cat([points1, interpolated_points], dim=-1) 266 | else: 267 | new_points = interpolated_points 268 | 269 | new_points = new_points.permute(0, 2, 1) 270 | for i, conv in enumerate(self.mlp_convs): 271 | bn = self.mlp_bns[i] 272 | new_points = F.relu(bn(conv(new_points))) 273 | return new_points 274 | -------------------------------------------------------------------------------- /cppattempt/Point.cpp: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "Point.h" 5 | 6 | void select_cube(at::Tensor xyz, at::Tensor idx_out, int b, int n,float radius) 7 | { 8 | cubeSelectLauncher(b,n,radius,xyz.contiguous().data(), idx_out.contiguous().data()); 9 | } 10 | 11 | void group_points(int b, int n, int c , int m , int nsamples, at::Tensor xyz, at::Tensor idx, at::Tensor out) 12 | { 13 | group_pointsLauncher(b,n,c,m,nsamples,xyz.contiguous().data(),idx.contiguous().data(),out.contiguous().data()); 14 | } 15 | void ball_query (int b, int n, int m, float radius, int nsample, at::Tensor xyz1, at::Tensor xyz2, at::Tensor idx, at::Tensor pts_cnt) 16 | { 17 | queryBallPointLauncher(b, n, m, radius, nsample, xyz1.contiguous().data(), xyz2.contiguous().data(), idx.contiguous().data(), pts_cnt.contiguous().data()); 18 | } 19 | 20 | void farthestPoint(int b,int n,int m, at::Tensor inp, at::Tensor temp,at::Tensor out) 21 | { 22 | farthestpointsamplingLauncher(b, n, m, inp.contiguous().data(), temp.contiguous().data(),out.contiguous().data()); 23 | } 24 | 25 | void interpolate(int b, int n, int m, at::Tensor xyz1p, at::Tensor xyz2p, at::Tensor distp, at::Tensor idxp){ 26 | 27 | auto xyz1 = xyz1p.contiguous().data(); 28 | auto xyz2 = xyz2p.contiguous().data(); 29 | auto dist = distp.contiguous().data(); 30 | auto idx = idxp.contiguous().data(); 31 | 32 | for (int i=0;i()[j*3]=best1; 63 | idxp.contiguous().data()[j*3]=besti1; 64 | distp.contiguous().data()[j*3+1]=best2; 65 | idxp.contiguous().data()[j*3+1]=besti2; 66 | distp.contiguous().data()[j*3+2]=best3; 67 | idxp.contiguous().data()[j*3+2]=besti3; 68 | } 69 | xyz1+=n*3; 70 | xyz2+=m*3; 71 | dist+=n*3; 72 | idx+=n*3; 73 | } 74 | } 75 | 76 | at::Tensor three_interpolate(int b, int m, int c, int n,int d, at::Tensor points, at::Tensor idx, at::Tensor weight, at::Tensor out){ 77 | 78 | float * pointsp = points.contiguous().data(); 79 | float * weightp = weight.contiguous().data(); 80 | float * outp = out.contiguous().data(); 81 | int * idxp = idx.contiguous().data(); 82 | float w1,w2,w3; 83 | int i1,i2,i3; 84 | for (int i=0;i 6 | #include 7 | #include 8 | #include 9 | 10 | void select_cube(at::Tensor xyz, at::Tensor idx_out, int b, int n,float radius); 11 | void group_points(int b, int n, int c , int m , int nsamples, at::Tensor xyz, at::Tensor idx, at::Tensor out); 12 | void ball_query (int b, int n, int m, float radius, int nsample, at::Tensor xyz1, at::Tensor xyz2, at::Tensor idx, at::Tensor pts_cnt); 13 | void farthestPoint(int b,int n,int m, at::Tensor inp, at::Tensor temp,at::Tensor out); 14 | void interpolate(int b, int n, int m, at::Tensor xyz1p, at::Tensor xyz2p, at::Tensor distp, at::Tensor idxp); 15 | at::Tensor three_interpolate(int b, int m, int c, int n,int d, at::Tensor points, at::Tensor idx, at::Tensor weight, at::Tensor out); 16 | 17 | 18 | 19 | 20 | 21 | void cubeSelectLauncher(int b, int n, float radius, float * xyz, int * idx_out); 22 | void queryBallPointLauncher(int b, int n, int m, float radius, int nsample, 
const float *xyz1, const float *xyz2, int *idx, int *pts_cnt); 23 | void farthestpointsamplingLauncher(int b,int n,int m,const float * inp,float * temp,int * out); 24 | void threennLauncher(int b, int n, int m, const float *xyz1, const float *xyz2, float *dist, int *idx); 25 | void interpolateLauncher(int b, int m, int c, int n, const float *points, const int *idx, const float *weight, float *out); 26 | void group_pointsLauncher(int b, int n, int c, int m, int nsamples, const float * pointsp, const int * idxp, float * outp); 27 | 28 | #endif 29 | -------------------------------------------------------------------------------- /cppattempt/Point_cuda.cu: -------------------------------------------------------------------------------- 1 | #include 2 | #include 3 | #include 4 | #include "Point.h" 5 | 6 | 7 | __global__ void cubeselect(int n,float radius, float * xyz, int * idx_out) 8 | { 9 | int batch_idx = blockIdx.x; 10 | xyz += batch_idx * n * 3; 11 | idx_out += batch_idx * n * 8; 12 | float temp_dist[8]; 13 | float judge_dist = radius * radius; 14 | for(int i = threadIdx.x; i < n;i += blockDim.x) { 15 | float x = xyz[i * 3]; 16 | float y = xyz[i * 3 + 1]; 17 | float z = xyz[i * 3 + 2]; 18 | for(int j = 0;j < 8;j ++) { 19 | temp_dist[j] = 1e8; 20 | idx_out[i * 8 + j] = i; // if not found, just return itself.. 21 | } 22 | for(int j = 0;j < n;j ++) { 23 | if(i != j){ 24 | float tx = xyz[j * 3]; 25 | float ty = xyz[j * 3 + 1]; 26 | float tz = xyz[j * 3 + 2]; 27 | float dist = (x - tx) * (x - tx) + (y - ty) * (y - ty) + (z - tz) * (z - tz); 28 | if(dist <= judge_dist){ 29 | int _x = (tx > x); 30 | int _y = (ty > y); 31 | int _z = (tz > z); 32 | int temp_idx = _x * 4 + _y * 2 + _z; 33 | if(dist < temp_dist[temp_idx]) { 34 | idx_out[i * 8 + temp_idx] = j; 35 | temp_dist[temp_idx] = dist; 36 | } 37 | } 38 | } 39 | } 40 | 41 | } 42 | } 43 | 44 | // input: points (b,n,c), idx (b,m,nsample) 45 | // output: out (b,m,nsample,c) 46 | __global__ void group_point_gpu(int b, int n, int c, int m, int nsample, const float *points, const int *idx, float *out) 47 | { 48 | int batch_index = blockIdx.x; 49 | points += n*c*batch_index; 50 | idx += m*nsample*batch_index; 51 | out += m*nsample*c*batch_index; 52 | 53 | int index = threadIdx.x; 54 | int stride = blockDim.x; 55 | 56 | for (int j=index;jbest){ 146 | best=d2; 147 | besti=k; 148 | } 149 | } 150 | dists[threadIdx.x]=best; 151 | dists_i[threadIdx.x]=besti; 152 | for (int u=0;(1<>(u+1))){ 155 | int i1=(threadIdx.x*2)<>>(n, radius, xyz, idx_out); 247 | } 248 | 249 | void group_pointsLauncher(int b, int n, int c, int m, int nsamples, const float * pointsp, const int * idxp, float * outp){ 250 | group_point_gpu<<>>(b,n,c,m,nsamples,pointsp,idxp,outp); 251 | } 252 | void queryBallPointLauncher(int b, int n, int m, float radius, int nsample, const float *xyz1, const float *xyz2, int *idx, int *pts_cnt) { 253 | query_ball_point_gpu<<>>(b,n,m,radius,nsample,xyz1,xyz2,idx,pts_cnt); 254 | //cudaDeviceSynchronize(); 255 | } 256 | void farthestpointsamplingLauncher(int b,int n,int m,const float * inp,float * temp,int * out){ 257 | farthestpointsamplingKernel<<<32,512>>>(b,n,m,inp,temp,out); 258 | } 259 | 260 | void threennLauncher(int b, int n, int m, const float *xyz1, const float *xyz2, float *dist, int *idx){ 261 | threenn<<>>(b,n,m,xyz1,xyz2,dist,idx); 262 | } 263 | 264 | void interpolateLauncher(int b, int m, int c, int n, const float *points, const int *idx, const float *weight, float *out){ 265 | interpolategp<<>>(b,m,c,n,points,idx,weight,out); 266 | } 267 | 
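The kernels above are exposed to Python through point_api.cpp below. As a minimal usage sketch (an illustration only, assuming the `point` extension was built and installed with cppattempt/setup.py and a CUDA device is available; shapes follow the wrappers in Utils/net_utils.py):

import torch
import point

B, N, npoint, nsample, radius = 1, 2048, 256, 32, 0.1
xyz = torch.rand(B, N, 3).cuda()                     # (B, N, 3) float32 point cloud on the GPU

# Farthest point sampling: fills `centroids` with npoint point indices per batch.
centroids = torch.zeros((B, npoint), dtype=torch.int32).cuda()
temp = torch.zeros((32, N)).cuda()                   # scratch buffer matching the 32-block kernel launch
point.farthestPoint(B, N, npoint, xyz, temp, centroids)

# Ball query: for each sampled centroid, gather up to nsample neighbour indices
# within `radius`; pts_cnt receives how many neighbours were actually found.
new_xyz = xyz[:, centroids[0].long(), :]             # gather the sampled positions (B=1 here)
group_idx = torch.zeros((B, npoint, nsample), dtype=torch.int32).cuda()
pts_cnt = torch.zeros((B, npoint), dtype=torch.int32).cuda()
point.ball_query(B, N, npoint, radius, nsample, xyz, new_xyz, group_idx, pts_cnt)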
--------------------------------------------------------------------------------
/cppattempt/build/lib.linux-x86_64-3.7/point.cpython-37m-x86_64-linux-gnu.so:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/cppattempt/build/lib.linux-x86_64-3.7/point.cpython-37m-x86_64-linux-gnu.so
--------------------------------------------------------------------------------
/cppattempt/build/temp.linux-x86_64-3.7/Point.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/cppattempt/build/temp.linux-x86_64-3.7/Point.o
--------------------------------------------------------------------------------
/cppattempt/build/temp.linux-x86_64-3.7/Point_cuda.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/cppattempt/build/temp.linux-x86_64-3.7/Point_cuda.o
--------------------------------------------------------------------------------
/cppattempt/build/temp.linux-x86_64-3.7/point_api.o:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/cppattempt/build/temp.linux-x86_64-3.7/point_api.o
--------------------------------------------------------------------------------
/cppattempt/dist/point-0.0.0-py3.7-linux-x86_64.egg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/cppattempt/dist/point-0.0.0-py3.7-linux-x86_64.egg
--------------------------------------------------------------------------------
/cppattempt/example.py:
--------------------------------------------------------------------------------
1 | import math
2 | from torch import nn
3 | from torch.autograd import Function
4 | import torch
5 | import sys
6 | import time
7 | import point
8 |
9 | x = torch.rand(1, 200, 3).cuda() # random points, so the CPU/GPU comparison below is meaningful
10 | ya = torch.zeros((1,200, 8), dtype=torch.int32).cuda()
11 | radius = 0.4 # must match the radius of the CPU reference below
12 | start_time = time.time()
13 | point.select_cube(x,ya,1,200,radius) # b and n must match x's shape (B=1, N=200)
14 | print(time.time() - start_time)
15 | ya = ya.cpu()
16 |
17 | start_time = time.time()
18 | xyz = x.cpu()
19 | radius = 0.4
20 | Dist = lambda x, y, z: x ** 2 + y ** 2 + z ** 2
21 | B, N, _ = xyz.shape
22 | idx = torch.empty(B, N, 8)
23 | judge_dist = radius ** 2
24 | temp_dist = torch.ones(B, N, 8) * 1e10
25 | for b in range(B):
26 | for n in range(N):
27 | idx[b, n, :] = n
28 | x, y, z = xyz[b, n]
29 | for p in range(N):
30 | if p == n: continue
31 | tx, ty, tz = xyz[b, p]
32 | dist = Dist(x - tx, y - ty, z - tz)
33 | if dist > judge_dist: continue
34 | _x, _y, _z = tx > x, ty > y, tz > z
35 | temp_idx = (_x * 4 + _y * 2 + _z).int()
36 | if dist < temp_dist[b, n, temp_idx]:
37 | idx[b, n, temp_idx] = p
38 | temp_dist[b, n, temp_idx] = dist
39 | print(idx)
40 | print(ya)
41 | if(torch.all(torch.eq(idx.int(), ya.cpu().int()))):
42 | print("success")
43 |
44 | print(time.time() - start_time)
45 |
--------------------------------------------------------------------------------
/cppattempt/point.egg-info/PKG-INFO:
--------------------------------------------------------------------------------
1 | Metadata-Version: 1.0
2 | Name: point
3 | Version: 0.0.0
4 | Summary: UNKNOWN
5 | Home-page: UNKNOWN
6 | Author: UNKNOWN
7 | Author-email: UNKNOWN
8 | License: UNKNOWN
9 | Description: UNKNOWN
10 | Platform: UNKNOWN
11 |
--------------------------------------------------------------------------------
/cppattempt/point.egg-info/SOURCES.txt:
--------------------------------------------------------------------------------
1 | Point.cpp
2 | Point_cuda.cu
3 | point_api.cpp
4 | setup.py
5 | point.egg-info/PKG-INFO
6 | point.egg-info/SOURCES.txt
7 | point.egg-info/dependency_links.txt
8 | point.egg-info/top_level.txt
--------------------------------------------------------------------------------
/cppattempt/point.egg-info/dependency_links.txt:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/cppattempt/point.egg-info/top_level.txt:
--------------------------------------------------------------------------------
1 | point
2 |
--------------------------------------------------------------------------------
/cppattempt/point_api.cpp:
--------------------------------------------------------------------------------
1 | #include <torch/extension.h> // provides TORCH_EXTENSION_NAME and pybind11 (headers assumed; angle-bracket includes were lost)
2 | #include <vector>
3 |
4 | #include "Point.h"
5 |
6 |
7 | PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
8 | m.def("select_cube", &select_cube);
9 | m.def("group_points", &group_points);
10 | m.def("ball_query", &ball_query);
11 | m.def("farthestPoint", &farthestPoint);
12 | m.def("interpolate", &interpolate);
13 | m.def("three_interpolate", &three_interpolate);
14 | }
15 |
--------------------------------------------------------------------------------
/cppattempt/setup.py:
--------------------------------------------------------------------------------
1 | from setuptools import setup
2 | from torch.utils.cpp_extension import CppExtension, BuildExtension,CUDAExtension
3 |
4 | setup(name='point',
5 | ext_modules=[
6 | CUDAExtension('point', [
7 | 'point_api.cpp',
8 | 'Point.cpp',
9 | 'Point_cuda.cu',
10 | ]),
11 | ],
12 | cmdclass={'build_ext': BuildExtension})
13 |
--------------------------------------------------------------------------------
/data_loaders/create_grid.py:
--------------------------------------------------------------------------------
1 | import open3d as opend
2 | import sys
3 | import numpy as np
4 | import time
5 |
6 |
7 | class Grid ():
8 | def __init__(self,points,colors,cell_size,radius,annotations):
9 | self.corners = []
10 | self.centers = []
11 | self.cells = []
12 | self.assigned = []
13 | self.colors = []
14 | self.grid_lines = []
15 | self.annotations = []
16 | self.create_grid(points,colors,cell_size,radius,annotations)
17 | # self.attribute_points( points)
18 |
19 | def create_grid(self,points,colors,cell_size,radius,annotations):
20 |
21 | max_width = np.amax(points[:, 0])
22 | min_width = np.amin(points[:, 0])
23 |
24 | min_height = np.amin(points[:, 1])
25 | max_height = np.amax(points[:, 1])
26 |
27 | max_depth = np.amax(points[:, 2])
28 | min_depth = np.amin(points[:, 2])
29 | print(max_width,min_width)
30 |
31 | counter = 0
32 |
33 | for i in range(int(min_width * cell_size), int(max_width * cell_size) + 1):
34 | for j in range(int(min_height * cell_size) - 1, int(max_height * cell_size) + 1):
35 | for k in range(int(min_depth * cell_size), int(max_depth * cell_size) + 2 ):
36 | if (i < int(max_width * cell_size) + 1 and j < int(max_height * cell_size) + 1 and k + 1 < int(max_depth * cell_size)):
37 | self.grid_lines.append([counter, counter + 1])
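# Each pair appended to self.grid_lines holds two corner indices ("counter" values);
# the pairs are later turned into an Open3D LineSet to draw the grid cell edges,
# while self.cells stores each cell's [x0, y0, z0, x1, y1, z1] bounds and
# self.corners the corresponding corner coordinates.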
38 | if (i < int(max_width * cell_size) +1 and j +1 < int(max_height * cell_size) and k + 1 < int(max_depth * cell_size)): 39 | self.grid_lines.append([counter, counter + (int(max_height * cell_size) ) + int(max_depth * cell_size) ]) 40 | if (i < int(max_width * cell_size) + 1 and j < int(max_height * cell_size) + 1 and k < int(max_depth * cell_size) + 2): 41 | self.cells.append([i / cell_size, j / cell_size, k / cell_size, (i + 1) / cell_size, (j + 1) / cell_size, 42 | (k + int((max_depth - min_depth) * 2)) / cell_size]) 43 | self.corners.append(([i / cell_size, j / cell_size, k / cell_size])) 44 | counter = counter + 1 45 | 46 | print(self.corners) 47 | print(self.cells) 48 | 49 | ######################################################################################################################################## 50 | ####################Create group for each center of cell ############################################################################# 51 | ###################################################################################################################################### 52 | pcd2 = opend.PointCloud() 53 | pcd2.points = opend.Vector3dVector(points) 54 | pcd2.colors = opend.Vector3dVector(colors) 55 | # opend.draw_geometries([pcd2]) 56 | pcd_tree = opend.KDTreeFlann(pcd2) 57 | counter = 0 58 | nb_pt = 0 59 | for i,c in enumerate(self.cells): 60 | center = np.array([c[0] + (c[3]-c[0])/2, c[1] + (c[4]-c[1])/2 , c[2] + (c[5]-c[2])/2]) 61 | [k, idx, _] = pcd_tree.search_radius_vector_3d(center,radius) 62 | if(np.asarray(idx).shape[0]>256): 63 | counter = counter + 1 64 | nb_pt = nb_pt + np.asarray(idx).shape[0] 65 | self.assigned.append( points[np.array(idx)] ) 66 | self.colors.append(colors[np.array(idx)]) 67 | self.annotations.append(annotations[np.array([idx])].squeeze(0)) 68 | self.centers.append(center) 69 | 70 | 71 | ######################################################################################################################################## 72 | ####################Fuse neigboring cells if their number of point is too low ######################################################### 73 | ######################################################################################################################################## 74 | pcd3 = opend.PointCloud() 75 | pcd3.points = opend.Vector3dVector(self.centers) 76 | pcd_tree = opend.KDTreeFlann(pcd3) 77 | deleted = [] 78 | for j,pg in enumerate(self.assigned): 79 | center = self.centers[j] 80 | [k, idx, _] = pcd_tree.search_knn_vector_3d(center, 3) 81 | if(j not in deleted): 82 | if (pg.shape[0] <= 1024): 83 | cell_values = self.assigned[idx[0]] 84 | next_cell_values = self.assigned[idx[1]] 85 | next_next_cell_values = self.assigned[idx[2]] 86 | if (cell_values.shape[0] + next_cell_values.shape[0] <= 8000): 87 | self.assigned[j] = np.concatenate([self.assigned[j], next_cell_values], 0) 88 | self.colors[j] = np.concatenate([self.colors[j], self.colors[idx[1]]], 0) 89 | self.annotations[j] = np.concatenate([self.annotations[j], self.annotations[idx[1]]],0) 90 | deleted.append(idx[1]) 91 | if (cell_values.shape[0] + next_cell_values.shape[0] <= 1024): 92 | if (cell_values.shape[0] + next_cell_values.shape[0] + next_next_cell_values.shape[0] <= 8000): 93 | self.assigned[j] = np.concatenate([self.assigned[j], next_next_cell_values], 0) 94 | self.colors[j] = np.concatenate([self.colors[j], self.colors[idx[2]]], 0) 95 | self.annotations[j] = np.concatenate([self.annotations[j], self.annotations[idx[2]]], 0) 96 | 
deleted.append(idx[2])
97 | res_assigned = []
98 | res_annotations = []
99 | res_colors = []
100 | for j,p in enumerate(self.assigned):
101 | if(j not in deleted):
102 | res_assigned.append(self.assigned[j])
103 | res_annotations.append(self.annotations[j])
104 | res_colors.append(self.colors[j])
105 |
106 | self.assigned = res_assigned
107 | self.annotations = res_annotations
108 | self.colors = res_colors # keep the colour list in sync with the fused cells
109 |
110 | def display(self):
111 |
112 | self.colors = []
113 | distribcr = np.random.randint(low=0, size=len(self.assigned), high=len(self.assigned))/len(self.assigned)
114 | distribcg = np.random.randint(low=0, size=len(self.assigned), high=len(self.assigned)) / len(self.assigned)
115 | distribcb = np.random.randint(low=0, size=len(self.assigned), high=len(self.assigned)) / len(self.assigned)
116 | for i,jk in enumerate(self.assigned):
117 | self.colors.append(np.array([[distribcr[i], distribcg[i], distribcb[i]] for k in range(jk.shape[0])]))
118 | pcd2 = opend.PointCloud() # assemble one coloured cloud per cell before drawing
119 | pcd2.points = opend.Vector3dVector(np.concatenate(self.assigned))
120 | pcd2.colors = opend.Vector3dVector(np.concatenate(self.colors))
121 | opend.draw_geometries([pcd2])
122 |
123 | def attribute_points(self,points):
124 | for c in self.cells:
125 | cellule = []
126 | deleted = []
127 | for i, p in enumerate(points):
128 | # print(i,p)
129 | if(p[0]>=c[0] and p[0]<=c[3] and p[1]>=c[1] and p[1]<=c[4] and p[2]>=c[2] and p[2]<=c[5]):
130 | cellule.append(p)
131 | deleted.append(i)
132 | points = np.delete(points,deleted,0)
133 | self.assigned.append(cellule)
134 |
135 |
136 |
137 |
138 |
139 |
140 | if __name__ == "__main__":
141 |
142 | # path = sys.argv[1]
143 | #
144 | #
145 | # pcd = opend.read_point_cloud(path)
146 | points = np.random.rand(20048,3)
147 | colors = np.random.rand(20048,3)
148 | # pcd2 = opend.PointCloud()
149 | # pcd2.points = opend.Vector3dVector(points)
150 | # pcd2.colors = opend.Vector3dVector(colors)
151 | # opend.draw_geometries([pcd2])
152 |
153 | # points = np.asarray(pcd.points)
154 | start_time = time.time()
155 | annotation = np.ones((points.shape[0],1))
156 | grid = Grid(points,colors, 0.05, 0.025,annotation)#points,colors,cell_size,radius,annotations
157 |
158 | print(time.time() - start_time)
159 | print(grid.assigned)
160 | pcd2 = opend.PointCloud()
161 | pcd2.points = opend.Vector3dVector(np.concatenate(grid.assigned)+np.array([1.0,0.0,0.0]))
162 | pcd2.colors = opend.Vector3dVector(np.concatenate(grid.colors))
163 |
164 | pcd3 = opend.PointCloud()
165 | pcd3.points = opend.Vector3dVector(np.array(grid.corners))
166 |
167 | line_set = opend.LineSet()
168 | line_set.points = opend.Vector3dVector(np.array(grid.corners))
169 | line_set.lines = opend.Vector2iVector(np.array(grid.grid_lines))
170 | opend.draw_geometries([pcd2,pcd3,line_set]) # pcd only exists when reading from disk above, so it is not drawn here
--------------------------------------------------------------------------------
/metrics.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 |
4 |
5 | def Kappa_cohen(predictions,groundtruth):
6 | TP = 0
7 | TN = 0
8 | FP = 0
9 | FN = 0
10 | gt_r = 0
11 | gt_p = 0
12 | for j,an in enumerate(predictions):
13 | if(an == groundtruth[j] and an == 1 ):
14 | TP = TP + 1
15 | elif(an == groundtruth[j] and an == 0):
16 | TN = TN + 1
17 | elif(an != groundtruth[j] and an == 0):
18 | FN = FN + 1
19 | elif(an != groundtruth[j] and an == 1):
20 | FP = FP + 1
21 | if(groundtruth[j]== 0):
22 | gt_p = gt_p + 1
23 | else:
24 | gt_r = gt_r + 1
25 |
26 | observed_accuracy = (TP+TN)/groundtruth.shape[0]
27 | expected_accuracy = ((gt_r*(TP+FP))/groundtruth.shape[0] + (gt_p*(TN+FN))/groundtruth.shape[0])/groundtruth.shape[0] # chance agreement: P(pred=1)P(gt=1) + P(pred=0)P(gt=0)
28 |
29 | return (observed_accuracy - expected_accuracy)/ (1- expected_accuracy)
30 |
31 |
32 | def IoU(predictions,groundtruth):
33 | TP
= 0
34 | TN = 0
35 | FP = 0
36 | FN = 0
37 | for j,an in enumerate(predictions):
38 | if(an == groundtruth[j] and an == 1 ):
39 | TP = TP + 1
40 | elif(an == groundtruth[j] and an == 0):
41 | TN = TN + 1
42 | elif(an != groundtruth[j] and an == 0):
43 | FN = FN + 1
44 | elif(an != groundtruth[j] and an == 1):
45 | FP = FP + 1
46 |
47 | return TP/(TP+FP+FN)
48 |
49 |
50 | def Accuracy(predictions,groundtruth):
51 | TP = 0
52 | TN = 0
53 | FP = 0
54 | FN = 0
55 | for j,an in enumerate(predictions):
56 | if(an == groundtruth[j] and an == 1 ):
57 | TP = TP + 1
58 | elif(an == groundtruth[j] and an == 0):
59 | TN = TN + 1
60 | elif(an != groundtruth[j] and an == 0):
61 | FN = FN + 1
62 | elif(an != groundtruth[j] and an == 1):
63 | FP = FP + 1
64 |
65 | return (TP+TN)/groundtruth.shape[0]
66 |
--------------------------------------------------------------------------------
/unit_test/centroids.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/unit_test/centroids.npy
--------------------------------------------------------------------------------
/unit_test/centroids_cpp.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/unit_test/centroids_cpp.npy
--------------------------------------------------------------------------------
/unit_test/group_idx.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/unit_test/group_idx.npy
--------------------------------------------------------------------------------
/unit_test/group_idxcpp.npy:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/unit_test/group_idxcpp.npy
--------------------------------------------------------------------------------
/unit_test/test_cpp_vs_C.py:
--------------------------------------------------------------------------------
1 | import sys
2 | sys.path.insert(0,'..')
3 | import open3d as opend
4 | import torch
5 | import numpy as np
6 |
7 |
8 | def index_points(points, idx):
9 | """
10 | Input:
11 | points: input points data, [B, N, C]
12 | idx: sample index data, [B, D1, D2, ..., Dn]
13 | Return:
14 | new_points: indexed points data, [B, D1, D2, ..., Dn, C]
15 | """
16 | device = points.device
17 | B = points.shape[0]
18 | view_shape = list(idx.shape)
19 | view_shape[1:] = [1] * (len(view_shape) - 1)
20 | repeat_shape = list(idx.shape)
21 | repeat_shape[0] = 1
22 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape)
23 | new_points = points[batch_indices, idx, :]
24 | return new_points
25 |
26 | if(len(sys.argv)>1):
27 | if(sys.argv[1]=="C"):
28 | from C_utils import libsift ###C version, use venv using pytorch 0.4
29 | def group_points(xyz,idx):
30 | b , n , c = xyz.shape
31 | m = idx.shape[1]
32 | nsample = idx.shape[2]
33 | out = torch.zeros((xyz.shape[0], m, idx.shape[2],c)).cuda() # one row per queried centroid
34 | libsift.group_points(b,n,c,m,nsample,xyz,idx.int(),out)
35 | np.save("grouped_points.npy", out.cpu().numpy()) # move to host before saving
36 |
37 | def farthest_point_sample_gpu(xyz, npoint):
38 | b, n ,c = xyz.shape
39 | centroid = torch.zeros((xyz.shape[0],npoint),
dtype=torch.int32).cuda()
40 | temp = torch.zeros((32,n)).cuda()
41 | libsift.farthestPoint(b,n, npoint, xyz , temp ,centroid)
42 | np.save("centroids.npy", centroid.long().cpu().numpy() )
43 |
44 | def ball_query(radius, nsample, xyz, new_xyz):
45 | b, n ,c = xyz.shape
46 | m = new_xyz.shape[1]
47 | group_idx = torch.zeros((new_xyz.shape[0],new_xyz.shape[1], nsample), dtype=torch.int32).cuda()
48 | pts_cnt = torch.zeros((new_xyz.shape[0],new_xyz.shape[1]), dtype=torch.int32).cuda() # one count per queried centroid
49 | libsift.ball_query (b, n, m, radius, nsample, xyz, new_xyz, group_idx ,pts_cnt)
50 | np.save("group_idx.npy",group_idx.long().cpu().numpy())
51 |
52 | elif(sys.argv[1]=="Cpp"):
53 | import point ###Cpp version, use venv using pytorch 1.0+
54 | def group_points(xyz,idx):
55 | b , n , c = xyz.shape
56 | m = idx.shape[1]
57 | nsample = idx.shape[2]
58 | out = torch.zeros((xyz.shape[0], m, idx.shape[2],c)).cuda()
59 | point.group_points(b,n,c,m,nsample,xyz,idx.int(),out)
60 | np.save("grouped_points_cpp.npy", out.cpu().numpy())
61 |
62 | def farthest_point_sample_gpu(xyz, npoint):
63 | b, n ,c = xyz.shape
64 | centroid = torch.zeros((xyz.shape[0],npoint), dtype=torch.int32).cuda()
65 | temp = torch.zeros((32,n)).cuda()
66 | point.farthestPoint(b,n, npoint, xyz , temp ,centroid)
67 | np.save("centroids_cpp.npy", centroid.long().cpu().numpy() )
68 |
69 | def ball_query(radius, nsample, xyz, new_xyz):
70 | b, n ,c = xyz.shape
71 | m = new_xyz.shape[1]
72 | group_idx = torch.zeros((new_xyz.shape[0],new_xyz.shape[1], nsample), dtype=torch.int32).cuda()
73 | pts_cnt = torch.zeros((new_xyz.shape[0],new_xyz.shape[1]), dtype=torch.int32).cuda()
74 | point.ball_query (b, n, m, radius, nsample, xyz, new_xyz, group_idx ,pts_cnt)
75 | np.save("group_idxcpp.npy",group_idx.long().cpu().numpy())
76 | else:
77 | print("Please select the extension you want to use !! ")
78 |
79 | def test_farthest():
80 | pc = torch.from_numpy(np.asarray(opend.read_point_cloud("test_pc.ply").points)).unsqueeze(0).float().cuda()
81 | farthest_point_sample_gpu(pc,500)
82 | def test_ball_query():
83 | pc = torch.from_numpy(np.asarray(opend.read_point_cloud("test_pc.ply").points)).unsqueeze(0).float().cuda()
84 | centroids = torch.from_numpy(np.load('centroids.npy')).cuda()
85 | new_xyz = index_points(pc, centroids)
86 | ball_query(0.1, 32, pc, new_xyz)
87 |
88 | def test_group_points():
89 | pc = torch.from_numpy(np.asarray(opend.read_point_cloud("test_pc.ply").points)).unsqueeze(0).float().cuda()
90 | centroids = torch.from_numpy(np.load('centroids.npy')).cuda()
91 | # test_farthest()
92 | # test_ball_query()
93 | def test_arrays():
94 | c_f = np.load('centroids.npy') # saved by the C path
95 | cpp_f = np.load('centroids_cpp.npy') # saved by the C++ path
96 | if np.array_equal(c_f, cpp_f): # element-wise equality; .all()==.all() only compared two booleans
97 | print("gotcha centroids")
98 | c_g = np.load('group_idx.npy')
99 | cpp_g = np.load('group_idxcpp.npy')
100 | if np.array_equal(c_g, cpp_g):
101 | print("gotcha groups")
102 | test_arrays()
103 |
--------------------------------------------------------------------------------
/unit_test/test_pc.ply:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/lelouedec/3DNetworksPytorch/331900efe405f2b5ed8eb094cca2c43c546156bb/unit_test/test_pc.ply
--------------------------------------------------------------------------------
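The binary metrics in metrics.py above can be exercised on a toy prediction. A minimal sketch (the arrays are made-up illustration values, not repository data; run from the repository root so `metrics` is importable):

import numpy as np
import metrics

pred = np.array([1, 1, 0, 0, 1, 0])
gt = np.array([1, 0, 0, 0, 1, 1])     # TP=2, TN=2, FP=1, FN=1

print(metrics.Accuracy(pred, gt))     # (TP+TN)/6 = 4/6
print(metrics.IoU(pred, gt))          # TP/(TP+FP+FN) = 2/4
print(metrics.Kappa_cohen(pred, gt))  # observed accuracy corrected for chance agreement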