├── loss.py
├── ReadMe.md
├── network.py
├── train_val.py
└── dataset.py

/loss.py:
--------------------------------------------------------------------------------
1 |
2 |
3 | import numpy as np
4 | import jittor as jt
5 |
6 |
7 | def weighted_cross_entropy(y_pred, y_true, class_weights=None):
8 |     """Class-weighted cross entropy between soft targets and predictions.
9 |
10 |     Args:
11 |         y_pred: network output; a softmax over dim 1 turns it into the
12 |             predicted distribution ``q``.
13 |         y_true: target values; ``softmax(2.0 * y_true)`` over dim 1 is
14 |             used as the target distribution ``p``.
15 |         class_weights: optional per-class weights multiplied into
16 |             ``p * log(q)``; must be broadcast-compatible with it. ``None``
17 |             (the default) weights every class by 1.
18 |
19 |     Returns:
20 |         Var holding ``-mean(sum_i w[i] * p[i] * log(q[i]))``, where the
21 |         sum runs over dim 1 and the mean over every remaining element.
22 |     """
23 |     soft_max = jt.nn.Softmax(dim=1)
24 |     # Target distribution: the targets are scaled by 2.0 before the softmax.
25 |     p = soft_max(y_true * 2.0)
26 |     # nn.Softmax is called with the input only; its axis is fixed by dim=1.
27 |     # NOTE(review): SVNet's forward pass already ends in a softmax, so
28 |     # training predictions are normalised twice -- consider returning logits.
29 |     q = soft_max(y_pred)
30 |     logq = jt.log(q)
31 |
32 |     p_logq = p * logq
33 |     if class_weights is None:
34 |         # No weights given: every class contributes equally.
35 |         w_p_logq = p_logq
36 |     else:
37 |         w_p_logq = class_weights * p_logq
38 |
39 |     # loss = - mean( \sum_i w[i] * p[i] * log(q[i]) ), summed over dim 1
40 |     return -w_p_logq.sum(1).mean()
--------------------------------------------------------------------------------
/ReadMe.md:
--------------------------------------------------------------------------------
1 |
2 |
3 | # Supervoxel-CNN with [Jittor](https://cg.cs.tsinghua.edu.cn/jittor/)
4 |
5 | This is the Supervoxel-CNN network implemented in [Jittor](https://cg.cs.tsinghua.edu.cn/jittor/). Supervoxel-CNN is the backbone network of SVNet, an online 3D semantic segmentation system that contains both an online 3D reconstruction system and an online 3D semantic prediction system (Supervoxel-CNN belongs to the latter). Here we only provide the Supervoxel-CNN part, to show how we implemented the network used in the paper. If you are interested in the whole system, please contact [Shi-Sheng Huang](https://shishenghuang.github.io/index/).
For more details about SVNet, please see our paper:
6 |
7 | ```
8 | @article{SupervoxelConv2021,
9 |   author    = {Shi{-}Sheng Huang and Ze{-}Yu Ma and Tai{-}Jiang Mu and Hongbo Fu and Shi{-}Min Hu},
10 |   title     = {Supervoxel Convolution for Online 3D Semantic Segmentation},
11 |   journal   = {ACM Transactions on Graphics},
12 |   volume    = {40},
13 |   number    = {3},
14 |   articleno = {34},
15 |   year      = {2021}
16 | }
17 | ```
18 |
19 | # network.py
20 |
21 | This file contains the Supervoxel-CNN network.
22 |
23 | # dataset.py
24 |
25 | This file contains the data preparation code. Its input is the training data, preprocessed as described in the paper. For details, please contact [Shi-Sheng Huang](https://shishenghuang.github.io/index/).
26 |
27 | # loss.py
28 |
29 | This file contains the weighted cross-entropy loss (`weighted_cross_entropy`) used to train Supervoxel-CNN.
30 |
31 | # train_val.py
32 |
33 | This file contains the code to train and evaluate Supervoxel-CNN.
34 |
35 |
--------------------------------------------------------------------------------
/network.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import jittor as jt
4 | from jittor import nn
5 | from jittor import Module
6 | from jittor import init
7 | #from jittor.contrib import concat , argmax_pool
8 |
9 |
10 | class SVNet(Module):
11 |     """Supervoxel-CNN backbone.
12 |
13 |     A stack of 1x1 Conv2d + BatchNorm + ReLU layers lifts the 3-channel
14 |     ``points`` input to 64 channels, the result is multiplied with
15 |     ``features``, and a 1x1 Conv1d head maps each flattened 64*30-value
16 |     row to ``num_class`` softmax scores.
17 |
18 |     Args:
19 |         num_points_: points per sample (stored as ``self.num_points``).
20 |         num_k_: neighbours per point (stored as ``self.num_k``).
21 |         num_class_: number of output classes.
22 |     """
23 |
24 |     def __init__(self, num_points_, num_k_, num_class_):
25 |         super(SVNet, self).__init__()
26 |
27 |         self.num_points = num_points_
28 |         self.num_k = num_k_
29 |         self.num_class = num_class_
30 |
31 |         # 1x1 Conv2d stack: 3 -> 8 -> 16 -> 32 -> 64 channels.
32 |         self.conv2d_x1 = nn.Conv2d(3, 8, 1)
33 |         self.bn1 = nn.BatchNorm(8)
34 |         self.conv2d_x2 = nn.Conv2d(8, 16, 1)
35 |         self.bn2 = nn.BatchNorm(16)
36 |         self.conv2d_x3 = nn.Conv2d(16, 32, 1)
37 |         self.bn3 = nn.BatchNorm(32)
38 |         self.conv2d_x4 = nn.Conv2d(32, 64, 1)
39 |         self.bn4 = nn.BatchNorm(64)
40 |
41 |         # 1x1 Conv1d head: 64*30 -> 1024 -> 256 -> 128 -> 64 -> num_class.
42 |         self.conv1d_x1 = nn.Conv1d(64 * 30, 1024, 1)
43 |         self.bn5 = nn.BatchNorm(1024)
44 |         self.conv1d_x2 = nn.Conv1d(1024, 256, 1)
45 |         self.bn6 = nn.BatchNorm(256)
46 |         self.conv1d_x3 = nn.Conv1d(256, 128, 1)
47 |         self.bn7 = nn.BatchNorm(128)
48 |         self.conv1d_x4 = nn.Conv1d(128, 64, 1)
49 |         self.bn8 = nn.BatchNorm(64)
50 |         self.conv1d_x5 = nn.Conv1d(64, self.num_class, 1)
51 |
52 |         # Softmax over the class axis (dim 1 of the Conv1d output); the
53 |         # constructor argument is the axis, not the number of classes.
54 |         self.softmax_k = nn.Softmax(dim=1)
55 |
56 |     def execute(self, points, features):
57 |         """Return class probabilities with shape (-1, num_class, 1).
58 |
59 |         Args:
60 |             points: input to ``conv2d_x1``; needs 3 channels on dim 1.
61 |             features: multiplied element-wise with the 64-channel output
62 |                 of the Conv2d stack.
63 |         """
64 |         x = nn.relu(self.bn1(self.conv2d_x1(points)))
65 |         x = nn.relu(self.bn2(self.conv2d_x2(x)))
66 |         x = nn.relu(self.bn3(self.conv2d_x3(x)))
67 |         x = nn.relu(self.bn4(self.conv2d_x4(x)))
68 |
69 |         # NOTE(review): x has 64 channels and the head expects 64*30 values
70 |         # per row, so an outer product with 30 feature channels may have
71 |         # been intended; as written this only works if `features`
72 |         # broadcasts against x -- confirm the intended layout.
73 |         y = x * features
74 |
75 |         # Conv1d needs (batch, channels, length): each 64*30-value row
76 |         # becomes one length-1 sample.
77 |         y = y.reshape([-1, 64 * 30, 1])
78 |
79 |         y = nn.relu(self.bn5(self.conv1d_x1(y)))
80 |         y = nn.relu(self.bn6(self.conv1d_x2(y)))
81 |         y = nn.relu(self.bn7(self.conv1d_x3(y)))
82 |         y = nn.relu(self.bn8(self.conv1d_x4(y)))
83 |         # Project to per-class logits before normalising.
84 |         y = self.conv1d_x5(y)
85 |
86 |         return self.softmax_k(y)
87 |
88 |     # Alias for the previous misspelled name, kept for existing callers.
89 |     excute = execute
--------------------------------------------------------------------------------
/train_val.py:
--------------------------------------------------------------------------------
1 |
2 | import argparse
3 | import os
4 | import random
5 |
6 | import numpy as np
7 | import jittor as jt
8 | import tqdm
9 | from network import SVNet
10 | from dataset import SVNetTrainDataSet, SVNetTestDataSet
11 | from loss import weighted_cross_entropy
12 |
13 |
14 | def freeze_random_seed(seed=0):
15 |     """Seed the Python, NumPy and Jittor RNGs so runs are reproducible."""
16 |     random.seed(seed)
17 |     np.random.seed(seed)
18 |     jt.set_global_seed(seed)
19 |
20 |
21 | def split_points_features(pnts):
22 |     """Split a batch of dataset samples into the two network inputs.
23 |
24 |     ``pnts`` is a batch of dataset samples, (B, num_points, num_k, C).
25 |     Both outputs are channels-first (B, channels, num_points, num_k), the
26 |     layout the 3-input-channel Conv2d in SVNet consumes.
27 |
28 |     Returns:
29 |         (points, features): channels 0:3 and channels 3: respectively.
30 |     """
31 |     pnts = pnts.permute(0, 3, 1, 2)
32 |     return pnts[:, :3], pnts[:, 3:]
33 |
34 |
35 | if __name__ == '__main__':
36 |
37 |     freeze_random_seed()
38 |     # Jittor selects the device globally through jt.flags.use_cuda.
39 |     if jt.has_cuda:
40 |         jt.flags.use_cuda = 1
41 |
42 |     parser = argparse.ArgumentParser(description='SVNet')
43 |     parser.add_argument('--train_file', type=str, default='./data/train.h5', metavar='N', help='The train file data')
44 |     parser.add_argument('--val_file', type=str, default='./data/val.h5', help='Evaluate file')
45 |     parser.add_argument('--batch_size', type=int, default=32, metavar='batch_size', help='Size of batch')
46 |     parser.add_argument('--lr', type=float, default=0.001, metavar='LR', help='learning rate')
47 |     parser.add_argument('--num_points', type=int, default=1024, help='Points Number')
48 |     parser.add_argument('--num_k', type=int, default=8, help='Number Neighbors')
49 |     parser.add_argument('--num_class', type=int, default=21, help='Number Classes')
50 |     parser.add_argument('--epoches', type=int, default=10, help="Train Epoches")
51 |
52 |     args = parser.parse_args()
53 |
54 |     train_files = args.train_file
55 |     val_files = args.val_file
56 |     num_points = args.num_points
57 |     num_k = args.num_k
58 |     num_class = args.num_class
59 |
60 |     lr = args.lr
61 |     epoches = args.epoches
62 |
63 |     net = SVNet(num_points_=num_points, num_k_=num_k, num_class_=num_class)
64 |
65 |     optimizer = jt.nn.Adam(net.parameters(), lr=lr)
66 |
67 |     # set_attrs configures batching on the Jittor dataset and returns it.
68 |     train_dataloader = SVNetTrainDataSet(train_files, num_points, num_k, num_class).set_attrs(
69 |         batch_size=args.batch_size, shuffle=True)
70 |     val_dataloader = SVNetTestDataSet(val_files, num_points, num_k, num_class).set_attrs(
71 |         batch_size=args.batch_size, shuffle=False)
72 |
73 |     checkpoint_dir = os.path.join('checkpoints', 'models')
74 |     os.makedirs(checkpoint_dir, exist_ok=True)
75 |
76 |     for epoch in range(epoches):
77 |
78 |         net.train()
79 |         for idx, (pnts_cuda, labels_cuda) in enumerate(train_dataloader):
80 |             points, features = split_points_features(pnts_cuda)
81 |
82 |             # NOTE(review): labels are (B, num_points, num_class) per
83 |             # dataset.py while the loss normalises over dim 1 -- confirm
84 |             # that the prediction and label layouts agree.
85 |             pred_label = net(points, features)
86 |             loss = weighted_cross_entropy(pred_label, labels_cuda)
87 |             optimizer.zero_grad()
88 |             # step(loss) runs the backward pass itself; no loss.backward().
89 |             optimizer.step(loss)
90 |
91 |             print("SVNet Train -- INFO -- epoch : %d, idx : %d, loss : %f\n" % (epoch, idx, loss.item()))
92 |
93 |         net.eval()
94 |         val_loss, val_batches = 0.0, 0
95 |         with jt.no_grad():
96 |             for pnts_cuda, labels_cuda in val_dataloader:
97 |                 points, features = split_points_features(pnts_cuda)
98 |                 pred_label = net(points, features)
99 |                 val_loss += weighted_cross_entropy(pred_label, labels_cuda).item()
100 |                 val_batches += 1
101 |         if val_batches:
102 |             print("SVNet Val -- INFO -- epoch : %d, loss : %f\n" % (epoch, val_loss / val_batches))
103 |
104 |         # Save after the epoch's updates so checkpoints hold trained weights;
105 |         # the final epoch is always saved.
106 |         if (epoch + 1) % 10 == 0 or epoch + 1 == epoches:
107 |             jt.save(net.state_dict(), os.path.join(checkpoint_dir, 'model_%d.th' % (epoch + 1)))
--------------------------------------------------------------------------------
/dataset.py:
--------------------------------------------------------------------------------
1 |
2 | import numpy as np
3 | import os
4 | import jittor as jt
5 | from jittor.dataset.dataset import Dataset, dataset_root
6 | import h5py
7 |
8 |
9 | def load_h5(h5_filename, data_key='data', label_key='label'):
10 |     """Read the point array and the label array from an HDF5 file.
11 |
12 |     h5py provides no ``load_h5`` function, so the datasets are read here.
13 |     NOTE(review): the 'data' / 'label' dataset names are an assumption
14 |     (PointNet convention) -- confirm against the preprocessed files.
15 |     """
16 |     with h5py.File(h5_filename, 'r') as f:
17 |         return f[data_key][:], f[label_key][:]
18 |
19 |
20 | def rotate_point_cloud(batch_data):
21 |     """Randomly rotate each sample in a batch about the y (up) axis.
22 |
23 |     Every sample gets one angle drawn uniformly from [0, 2*pi), applied to
24 |     channels 0:3, 3:6 and 9:12.
25 |
26 |     Input:
27 |         BxNxKxC array with C >= 12.
28 |     Return:
29 |         three BxNxKx3 float32 arrays: the rotated channels 0:3, 3:6 and
30 |         9:12, in that order.
31 |     """
32 |     batch_size, num_points, num_k = batch_data.shape[:3]
33 |     out_shape = (batch_size, num_points, num_k, 3)
34 |     rotated_points = np.zeros(out_shape, dtype=np.float32)
35 |     rotated_pts = np.zeros(out_shape, dtype=np.float32)
36 |     rotated_ns = np.zeros(out_shape, dtype=np.float32)
37 |
38 |     for k in range(batch_size):
39 |         rotation_angle = np.random.uniform() * 2 * np.pi
40 |         cosval = np.cos(rotation_angle)
41 |         sinval = np.sin(rotation_angle)
42 |         rotation_matrix = np.array([[cosval, 0, sinval],
43 |                                     [0, 1, 0],
44 |                                     [-sinval, 0, cosval]])
45 |         rotated_points[k] = np.dot(batch_data[k, :, :, 0:3], rotation_matrix)
46 |         rotated_pts[k] = np.dot(batch_data[k, :, :, 3:6], rotation_matrix)
47 |         rotated_ns[k] = np.dot(batch_data[k, :, :, 9:12], rotation_matrix)
48 |     return rotated_points, rotated_pts, rotated_ns
49 |
50 |
51 | def _load_samples(h5_file, num_points, num_k, num_class):
52 |     """Load ``h5_file`` and reshape it into per-sample arrays.
53 |
54 |     Returns:
55 |         features with shape (S, num_points, num_k, 33) and labels with
56 |         shape (S, num_points, num_class).
57 |     """
58 |     cur_points, cur_labels = load_h5(h5_file)
59 |     features = cur_points.reshape(-1, num_points, num_k, 33)
60 |     labels = cur_labels.reshape(-1, num_points, num_class)
61 |     return features, labels
62 |
63 |
64 | class SVNetTrainDataSet(Dataset):
65 |     """Training samples, randomly rotated each time they are read."""
66 |
67 |     def __init__(self, train_h5_file, num_points_, num_k_, num_class_):
68 |         super().__init__()
69 |
70 |         self.num_points = num_points_
71 |         self.num_k = num_k_
72 |         self.num_class = num_class_
73 |         self.train_feature_r, self.train_label_r = _load_samples(
74 |             train_h5_file, self.num_points, self.num_k, self.num_class)
75 |
76 |     def __getitem__(self, index):
77 |         """Return (features, labels) for sample ``index`` as Jittor Vars.
78 |
79 |         Channels 0:3, 3:6 and 9:12 are rotated; the output keeps the same
80 |         33-channel order that SVNetTestDataSet returns.
81 |         """
82 |         pnts = self.train_feature_r[index]
83 |         labs = self.train_label_r[index]
84 |
85 |         # rotate_point_cloud works on batches: wrap the sample in a batch of
86 |         # one and take element 0 of each result.
87 |         points_rot, pts_rot, norm_rot = rotate_point_cloud(pnts[np.newaxis, :, :, :12])
88 |         feature_total = np.concatenate(
89 |             [points_rot[0], pts_rot[0], pnts[:, :, 6:9], norm_rot[0], pnts[:, :, 12:]],
90 |             axis=-1)
91 |
92 |         return jt.array(feature_total), jt.array(labs)
93 |
94 |     def __len__(self):
95 |         return self.train_feature_r.shape[0]
96 |
97 |
98 | class SVNetTestDataSet(Dataset):
99 |     """Evaluation samples, returned without augmentation."""
100 |
101 |     def __init__(self, test_h5_file, num_points_, num_k_, num_class_):
102 |         super().__init__()
103 |
104 |         self.num_points = num_points_
105 |         self.num_k = num_k_
106 |         self.num_class = num_class_
107 |         self.train_feature_r, self.train_label_r = _load_samples(
108 |             test_h5_file, self.num_points, self.num_k, self.num_class)
109 |
110 |     def __getitem__(self, index):
111 |         """Return (features, labels) for sample ``index`` as Jittor Vars."""
112 |         return jt.array(self.train_feature_r[index]), jt.array(self.train_label_r[index])
113 |
114 |     def __len__(self):
115 |         return self.train_feature_r.shape[0]
--------------------------------------------------------------------------------