├── .gitignore ├── README.md ├── checkpoints ├── pointnet-inview-0.52077-0018.pth └── pointnet2-inview-0.55884-0001.pth ├── clf.py ├── config ├── calib_cam_to_cam.txt ├── calib_imu_to_velo.txt ├── calib_velo_to_cam.txt ├── ego_view.json ├── render_option.json └── semantic-kitti.yaml ├── data_utils ├── ColorGenerator_Loader.py ├── ModelNetDataLoader.py ├── S3DISDataLoader.py ├── SemKITTI_Loader.py ├── ShapeNetDataLoader.py ├── augmentation.py ├── download_data.sh ├── kitti_utils.py └── redis_utils.py ├── model ├── chamfer.py ├── pointnet.py ├── pointnet2.py ├── pointnet_util.py └── utils.py ├── my_log.py ├── partseg.py ├── pcd_utils.py ├── pcdseg.py ├── pcdvis.py └── semseg.py /.gitignore: -------------------------------------------------------------------------------- 1 | # Byte-compiled / optimized / DLL files 2 | __pycache__/ 3 | *.py[cod] 4 | *$py.class 5 | .vscode/ 6 | data/ 7 | experiment/ 8 | pretrain/ 9 | *.h5 10 | # C extensions 11 | *.so 12 | checkpoints/backup/ 13 | 14 | # Distribution / packaging 15 | .Python 16 | build/ 17 | develop-eggs/ 18 | dist/ 19 | downloads/ 20 | eggs/ 21 | .eggs/ 22 | lib/ 23 | lib64/ 24 | parts/ 25 | sdist/ 26 | var/ 27 | wheels/ 28 | *.egg-info/ 29 | .installed.cfg 30 | *.egg 31 | MANIFEST 32 | 33 | # PyInstaller 34 | # Usually these files are written by a python script from a template 35 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 36 | *.manifest 37 | *.spec 38 | 39 | # Installer logs 40 | pip-log.txt 41 | pip-delete-this-directory.txt 42 | 43 | # Unit test / coverage reports 44 | htmlcov/ 45 | .tox/ 46 | .nox/ 47 | .coverage 48 | .coverage.* 49 | .cache 50 | nosetests.xml 51 | coverage.xml 52 | *.cover 53 | .hypothesis/ 54 | .pytest_cache/ 55 | 56 | # Translations 57 | *.mo 58 | *.pot 59 | 60 | # Django stuff: 61 | *.log 62 | local_settings.py 63 | db.sqlite3 64 | 65 | # Flask stuff: 66 | instance/ 67 | .webassets-cache 68 | 69 | # Scrapy stuff: 70 | .scrapy 71 | 72 | # Sphinx documentation 73 | docs/_build/ 74 | 75 | # PyBuilder 76 | target/ 77 | 78 | # Jupyter Notebook 79 | .ipynb_checkpoints 80 | 81 | # IPython 82 | profile_default/ 83 | ipython_config.py 84 | 85 | # pyenv 86 | .python-version 87 | 88 | # celery beat schedule file 89 | celerybeat-schedule 90 | 91 | # SageMath parsed files 92 | *.sage.py 93 | 94 | # Environments 95 | .env 96 | .venv 97 | env/ 98 | venv/ 99 | ENV/ 100 | env.bak/ 101 | venv.bak/ 102 | 103 | # Spyder project settings 104 | .spyderproject 105 | .spyproject 106 | 107 | # Rope project settings 108 | .ropeproject 109 | 110 | # mkdocs documentation 111 | /site 112 | 113 | # mypy 114 | .mypy_cache/ 115 | .dmypy.json 116 | dmypy.json 117 | 118 | # Pyre type checker 119 | .pyre/ 120 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | ## Pytorch Implementation of PointNet and PointNet++ Trained on KITTI Point Cloud Semantic Segmentation dataset 2 | 3 | This repo is implementation for [PointNet](http://openaccess.thecvf.com/content_cvpr_2017/papers/Qi_PointNet_Deep_Learning_CVPR_2017_paper.pdf) and [PointNet++](http://papers.nips.cc/paper/7095-pointnet-deep-hierarchical-feature-learning-on-point-sets-in-a-metric-space.pdf) in pytorch. 
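Besides the KITTI demo described below, the repository also includes ModelNet40 classification via `clf.py`. For example (a hedged sketch, assuming the ModelNet40 HDF5 files are placed under `experiment/data/modelnet40_ply_hdf5_2048/`, the path that `load_data` in `clf.py` expects):

```bash
# train PointNet (use --model_name pointnet2 for PointNet++)
python clf.py --model_name pointnet --mode train --batch_size 16 --gpu 0

# evaluate a saved checkpoint (placeholder path; checkpoints are written as clf-<model>-<acc>-<epoch>.pth)
python clf.py --model_name pointnet --mode eval --pretrain experiment/clf/pointnet/clf-pointnet-<acc>-<epoch>.pth
```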
4 | Links for Official Code: 5 | [Official PointNet](https://github.com/charlesq34/pointnet) and [Official PointNet++](https://github.com/charlesq34/pointnet2) 6 | 7 | # Install 8 | ```bash 9 | conda create --name tf python=3.6 10 | conda activate tf 11 | conda install -c pytorch pytorch torchvision 12 | conda install -c open3d-admin open3d=0.9.0.0 13 | pip install pyyaml matplotlib tqdm h5py redis numpy pandas opencv-python==4.2.0.32 14 | ``` 15 | 16 | # Run demo 17 | ```bash 18 | export KITTI_ROOT=PATH/odometry/dataset 19 | # windows: $env:KITTI_ROOT = 'PATH/odometry/dataset' 20 | python pcdvis.py 21 | ``` -------------------------------------------------------------------------------- /checkpoints/pointnet-inview-0.52077-0018.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiang-Muyun/PointNet12/698b82e8c687a0dd33187f5642d084f46ae25355/checkpoints/pointnet-inview-0.52077-0018.pth -------------------------------------------------------------------------------- /checkpoints/pointnet2-inview-0.55884-0001.pth: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/Jiang-Muyun/PointNet12/698b82e8c687a0dd33187f5642d084f46ae25355/checkpoints/pointnet2-inview-0.55884-0001.pth -------------------------------------------------------------------------------- /clf.py: -------------------------------------------------------------------------------- 1 | import open3d 2 | import argparse 3 | import os 4 | import time 5 | import h5py 6 | import datetime 7 | import numpy as np 8 | from matplotlib import pyplot as plt 9 | import torch 10 | from torch.utils.data import DataLoader 11 | import torch.nn.functional as F 12 | 13 | from tqdm import tqdm 14 | import my_log as log 15 | 16 | from data_utils.ModelNetDataLoader import ModelNetDataLoader, load_data, class_names 17 | from utils import test_clf, save_checkpoint, select_avaliable, mkdir 18 | from model.pointnet2 import PointNet2ClsMsg 19 | from model.pointnet import PointNetCls, feature_transform_reguliarzer 20 | 21 | def parse_args(notebook = False): 22 | parser = argparse.ArgumentParser('PointNet') 23 | parser.add_argument('--model_name', default='pointnet', help='pointnet or pointnet2') 24 | parser.add_argument('--mode', default='train', help='train or eval') 25 | parser.add_argument('--batch_size', type=int, default=16, help='batch size in training') 26 | parser.add_argument('--epoch', default=100, type=int, help='number of epoch in training') 27 | parser.add_argument('--learning_rate', default=0.001, type=float, help='learning rate in training') 28 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device') 29 | parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer for training') 30 | parser.add_argument('--pretrain', type=str, default=None, help='whether use pretrain model') 31 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='decay rate of learning rate') 32 | parser.add_argument('--feature_transform', default=False, help="use feature transform in pointnet") 33 | parser.add_argument('--augment', default=False, action='store_true', help="Enable data augmentation") 34 | if notebook: 35 | return parser.parse_args([]) 36 | else: 37 | return parser.parse_args() 38 | 39 | def train(args): 40 | experiment_dir = mkdir('./experiment/') 41 | checkpoints_dir = mkdir('./experiment/clf/%s/'%(args.model_name)) 42 | train_data, train_label, test_data, test_label = 
load_data('experiment/data/modelnet40_ply_hdf5_2048/') 43 | 44 | trainDataset = ModelNetDataLoader(train_data, train_label, data_augmentation = args.augment) 45 | trainDataLoader = DataLoader(trainDataset, batch_size=args.batch_size, shuffle=True) 46 | 47 | testDataset = ModelNetDataLoader(test_data, test_label) 48 | testDataLoader = torch.utils.data.DataLoader(testDataset, batch_size=args.batch_size, shuffle=False) 49 | 50 | log.info('Building Model',args.model_name) 51 | if args.model_name == 'pointnet': 52 | num_class = 40 53 | model = PointNetCls(num_class,args.feature_transform).cuda() 54 | else: 55 | model = PointNet2ClsMsg().cuda() 56 | 57 | torch.backends.cudnn.benchmark = True 58 | model = torch.nn.DataParallel(model).cuda() 59 | log.debug('Using gpu:',args.gpu) 60 | 61 | if args.pretrain is not None: 62 | log.info('Use pretrain model...') 63 | state_dict = torch.load(args.pretrain) 64 | model.load_state_dict(state_dict) 65 | init_epoch = int(args.pretrain[:-4].split('-')[-1]) 66 | log.info('start epoch from', init_epoch) 67 | else: 68 | log.info('Training from scratch') 69 | init_epoch = 0 70 | 71 | if args.optimizer == 'SGD': 72 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) 73 | elif args.optimizer == 'Adam': 74 | optimizer = torch.optim.Adam( 75 | model.parameters(), 76 | lr=args.learning_rate, 77 | betas=(0.9, 0.999), 78 | eps=1e-08, 79 | weight_decay=args.decay_rate 80 | ) 81 | 82 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) 83 | LEARNING_RATE_CLIP = 1e-5 84 | 85 | global_epoch = 0 86 | global_step = 0 87 | best_tst_accuracy = 0.0 88 | 89 | log.info('Start training...') 90 | for epoch in range(init_epoch,args.epoch): 91 | scheduler.step() 92 | lr = max(optimizer.param_groups[0]['lr'],LEARNING_RATE_CLIP) 93 | 94 | log.debug(job='clf',model=args.model_name,gpu=args.gpu,epoch='%d/%s' % (epoch, args.epoch),lr=lr) 95 | 96 | for param_group in optimizer.param_groups: 97 | param_group['lr'] = lr 98 | 99 | for batch_id, data in tqdm(enumerate(trainDataLoader, 0), total=len(trainDataLoader), smoothing=0.9): 100 | points, target = data 101 | target = target[:, 0] 102 | points = points.transpose(2, 1) 103 | points, target = points.cuda(), target.cuda() 104 | optimizer.zero_grad() 105 | model = model.train() 106 | pred, trans_feat = model(points) 107 | loss = F.nll_loss(pred, target.long()) 108 | if args.feature_transform and args.model_name == 'pointnet': 109 | loss += feature_transform_reguliarzer(trans_feat) * 0.001 110 | loss.backward() 111 | optimizer.step() 112 | global_step += 1 113 | 114 | log.debug('clear cuda cache') 115 | torch.cuda.empty_cache() 116 | 117 | acc = test_clf(model, testDataLoader) 118 | log.info(loss='%.5f' % (loss.data)) 119 | log.info(Test_Accuracy='%.5f' % acc) 120 | 121 | if acc >= best_tst_accuracy: 122 | best_tst_accuracy = acc 123 | fn_pth = 'clf-%s-%.5f-%04d.pth'%(args.model_name, acc, epoch) 124 | log.debug('Saving model....', fn_pth) 125 | torch.save(model.state_dict(), os.path.join(checkpoints_dir,fn_pth)) 126 | global_epoch += 1 127 | 128 | log.info(Best_Accuracy = best_tst_accuracy) 129 | log.info('End of training...') 130 | 131 | def evaluate(args): 132 | test_data, test_label = load_data('experiment/data/modelnet40_ply_hdf5_2048/', train = False) 133 | testDataset = ModelNetDataLoader(test_data, test_label) 134 | testDataLoader = torch.utils.data.DataLoader(testDataset, batch_size=args.batch_size, shuffle=False) 135 | 136 | log.debug('Building Model',args.model_name) 137 | if 
args.model_name == 'pointnet': 138 | num_class = 40 139 | model = PointNetCls(num_class,args.feature_transform) 140 | else: 141 | model = PointNet2ClsMsg() 142 | 143 | torch.backends.cudnn.benchmark = True 144 | model = torch.nn.DataParallel(model).cuda() 145 | log.debug('Using gpu:',args.gpu) 146 | 147 | if args.pretrain is None: 148 | log.err('No pretrain model') 149 | return 150 | 151 | log.debug('Loading pretrain model...') 152 | state_dict = torch.load(args.pretrain) 153 | model.load_state_dict(state_dict) 154 | 155 | acc = test_clf(model.eval(), testDataLoader) 156 | log.msg(Test_Accuracy='%.5f' % (acc)) 157 | 158 | def vis(args): 159 | test_data, test_label = load_data(root, train = False) 160 | log.info(test_data=test_data.shape,test_label=test_label.shape) 161 | 162 | log.debug('Building Model',args.model_name) 163 | if args.model_name == 'pointnet': 164 | num_class = 40 165 | model = PointNetCls(num_class,args.feature_transform).cuda() 166 | else: 167 | model = PointNet2ClsMsg().cuda() 168 | 169 | torch.backends.cudnn.benchmark = True 170 | model = torch.nn.DataParallel(model) 171 | model.cuda() 172 | log.info('Using multi GPU:',args.gpu) 173 | 174 | if args.pretrain is None: 175 | log.err('No pretrain model') 176 | return 177 | 178 | log.debug('Loading pretrain model...') 179 | checkpoint = torch.load(args.pretrain) 180 | model.load_state_dict(checkpoint) 181 | model.eval() 182 | 183 | log.info('Press space to exit, press Q for next frame') 184 | 185 | for idx in range(test_data.shape[0]): 186 | point_np = test_data[idx:idx+1] 187 | gt = test_label[idx][0] 188 | 189 | points = torch.from_numpy(point_np) 190 | points = points.transpose(2, 1).cuda() 191 | 192 | pred, trans_feat = model(points) 193 | pred_choice = pred.data.max(1)[1] 194 | log.info(gt=class_names[gt], pred_choice=class_names[pred_choice.cpu().numpy().item()]) 195 | 196 | point_cloud = open3d.geometry.PointCloud() 197 | point_cloud.points = open3d.utility.Vector3dVector(point_np[0]) 198 | 199 | vis = open3d.visualization.VisualizerWithKeyCallback() 200 | vis.create_window() 201 | vis.get_render_option().background_color = np.asarray([0, 0, 0]) 202 | vis.add_geometry(point_cloud) 203 | vis.register_key_callback(32, lambda vis: exit()) 204 | vis.run() 205 | vis.destroy_window() 206 | 207 | if __name__ == '__main__': 208 | args = parse_args() 209 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 210 | if args.mode == "train": 211 | train(args) 212 | if args.mode == "eval": 213 | evaluate(args) 214 | if args.mode == "vis": 215 | vis(args) 216 | -------------------------------------------------------------------------------- /config/calib_cam_to_cam.txt: -------------------------------------------------------------------------------- 1 | calib_time: 09-Jan-2012 13:57:47 2 | corner_dist: 9.950000e-02 3 | S_00: 1.392000e+03 5.120000e+02 4 | K_00: 9.842439e+02 0.000000e+00 6.900000e+02 0.000000e+00 9.808141e+02 2.331966e+02 0.000000e+00 0.000000e+00 1.000000e+00 5 | D_00: -3.728755e-01 2.037299e-01 2.219027e-03 1.383707e-03 -7.233722e-02 6 | R_00: 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 7 | T_00: 2.573699e-16 -1.059758e-16 1.614870e-16 8 | S_rect_00: 1.242000e+03 3.750000e+02 9 | R_rect_00: 9.999239e-01 9.837760e-03 -7.445048e-03 -9.869795e-03 9.999421e-01 -4.278459e-03 7.402527e-03 4.351614e-03 9.999631e-01 10 | P_rect_00: 7.215377e+02 0.000000e+00 6.095593e+02 0.000000e+00 0.000000e+00 7.215377e+02 1.728540e+02 0.000000e+00 0.000000e+00 
0.000000e+00 1.000000e+00 0.000000e+00 11 | S_01: 1.392000e+03 5.120000e+02 12 | K_01: 9.895267e+02 0.000000e+00 7.020000e+02 0.000000e+00 9.878386e+02 2.455590e+02 0.000000e+00 0.000000e+00 1.000000e+00 13 | D_01: -3.644661e-01 1.790019e-01 1.148107e-03 -6.298563e-04 -5.314062e-02 14 | R_01: 9.993513e-01 1.860866e-02 -3.083487e-02 -1.887662e-02 9.997863e-01 -8.421873e-03 3.067156e-02 8.998467e-03 9.994890e-01 15 | T_01: -5.370000e-01 4.822061e-03 -1.252488e-02 16 | S_rect_01: 1.242000e+03 3.750000e+02 17 | R_rect_01: 9.996878e-01 -8.976826e-03 2.331651e-02 8.876121e-03 9.999508e-01 4.418952e-03 -2.335503e-02 -4.210612e-03 9.997184e-01 18 | P_rect_01: 7.215377e+02 0.000000e+00 6.095593e+02 -3.875744e+02 0.000000e+00 7.215377e+02 1.728540e+02 0.000000e+00 0.000000e+00 0.000000e+00 1.000000e+00 0.000000e+00 19 | S_02: 1.392000e+03 5.120000e+02 20 | K_02: 9.597910e+02 0.000000e+00 6.960217e+02 0.000000e+00 9.569251e+02 2.241806e+02 0.000000e+00 0.000000e+00 1.000000e+00 21 | D_02: -3.691481e-01 1.968681e-01 1.353473e-03 5.677587e-04 -6.770705e-02 22 | R_02: 9.999758e-01 -5.267463e-03 -4.552439e-03 5.251945e-03 9.999804e-01 -3.413835e-03 4.570332e-03 3.389843e-03 9.999838e-01 23 | T_02: 5.956621e-02 2.900141e-04 2.577209e-03 24 | S_rect_02: 1.242000e+03 3.750000e+02 25 | R_rect_02: 9.998817e-01 1.511453e-02 -2.841595e-03 -1.511724e-02 9.998853e-01 -9.338510e-04 2.827154e-03 9.766976e-04 9.999955e-01 26 | P_rect_02: 7.215377e+02 0.000000e+00 6.095593e+02 4.485728e+01 0.000000e+00 7.215377e+02 1.728540e+02 2.163791e-01 0.000000e+00 0.000000e+00 1.000000e+00 2.745884e-03 27 | S_03: 1.392000e+03 5.120000e+02 28 | K_03: 9.037596e+02 0.000000e+00 6.957519e+02 0.000000e+00 9.019653e+02 2.242509e+02 0.000000e+00 0.000000e+00 1.000000e+00 29 | D_03: -3.639558e-01 1.788651e-01 6.029694e-04 -3.922424e-04 -5.382460e-02 30 | R_03: 9.995599e-01 1.699522e-02 -2.431313e-02 -1.704422e-02 9.998531e-01 -1.809756e-03 2.427880e-02 2.223358e-03 9.997028e-01 31 | T_03: -4.731050e-01 5.551470e-03 -5.250882e-03 32 | S_rect_03: 1.242000e+03 3.750000e+02 33 | R_rect_03: 9.998321e-01 -7.193136e-03 1.685599e-02 7.232804e-03 9.999712e-01 -2.293585e-03 -1.683901e-02 2.415116e-03 9.998553e-01 34 | P_rect_03: 7.215377e+02 0.000000e+00 6.095593e+02 -3.395242e+02 0.000000e+00 7.215377e+02 1.728540e+02 2.199936e+00 0.000000e+00 0.000000e+00 1.000000e+00 2.729905e-03 35 | -------------------------------------------------------------------------------- /config/calib_imu_to_velo.txt: -------------------------------------------------------------------------------- 1 | calib_time: 25-May-2012 16:47:16 2 | R: 9.999976e-01 7.553071e-04 -2.035826e-03 -7.854027e-04 9.998898e-01 -1.482298e-02 2.024406e-03 1.482454e-02 9.998881e-01 3 | T: -8.086759e-01 3.195559e-01 -7.997231e-01 4 | -------------------------------------------------------------------------------- /config/calib_velo_to_cam.txt: -------------------------------------------------------------------------------- 1 | calib_time: 15-Mar-2012 11:37:16 2 | R: 7.533745e-03 -9.999714e-01 -6.166020e-04 1.480249e-02 7.280733e-04 -9.998902e-01 9.998621e-01 7.523790e-03 1.480755e-02 3 | T: -4.069766e-03 -7.631618e-02 -2.717806e-01 4 | delta_f: 0.000000e+00 0.000000e+00 5 | delta_c: 0.000000e+00 0.000000e+00 6 | -------------------------------------------------------------------------------- /config/ego_view.json: -------------------------------------------------------------------------------- 1 | { 2 | "class_name": "PinholeCameraParameters", 3 | "extrinsic": [ 4 | 0.05666746175333122, 5 | 
-0.39173740030181986, 6 | 0.9183303370700582, 7 | 0.0, 8 | -0.9975986525936923, 9 | 0.014470015169489602, 10 | 0.06773143291149537, 11 | 0.0, 12 | -0.039821189355472554, 13 | -0.9199632752810118, 14 | -0.38997672368046415, 15 | 0.0, 16 | -1.2434222622584756, 17 | 7.123571596128595, 18 | 18.14031098830259, 19 | 1.0 20 | ], 21 | "intrinsic": { 22 | "height": 800, 23 | "intrinsic_matrix": [ 24 | 692.820323027551, 25 | 0.0, 26 | 0.0, 27 | 0.0, 28 | 692.820323027551, 29 | 0.0, 30 | 399.5, 31 | 399.5, 32 | 1.0 33 | ], 34 | "width": 800 35 | }, 36 | "version_major": 1, 37 | "version_minor": 0, 38 | "h_fov": [ 39 | -40, 40 | 40 41 | ], 42 | "v_fov": [ 43 | -20, 44 | 20 45 | ], 46 | "x_range": null, 47 | "y_range": null, 48 | "z_range": null, 49 | "d_range": [ 50 | 0, 51 | 80 52 | ] 53 | } -------------------------------------------------------------------------------- /config/render_option.json: -------------------------------------------------------------------------------- 1 | { 2 | "background_color" : [ 0.0, 0.0, 0.0 ], 3 | "class_name" : "RenderOption", 4 | "default_mesh_color" : [ 0.69999999999999996, 0.69999999999999996, 0.69999999999999996 ], 5 | "image_max_depth" : 3000, 6 | "image_stretch_option" : 1, 7 | "interpolation_option" : 0, 8 | "light0_color" : [ 1.0, 1.0, 1.0 ], 9 | "light0_diffuse_power" : 0.66000000000000003, 10 | "light0_position" : [ 0.0, 0.0, 2.0 ], 11 | "light0_specular_power" : 0.20000000000000001, 12 | "light0_specular_shininess" : 100.0, 13 | "light1_color" : [ 1.0, 1.0, 1.0 ], 14 | "light1_diffuse_power" : 0.66000000000000003, 15 | "light1_position" : [ 0.0, 0.0, 2.0 ], 16 | "light1_specular_power" : 0.20000000000000001, 17 | "light1_specular_shininess" : 100.0, 18 | "light2_color" : [ 1.0, 1.0, 1.0 ], 19 | "light2_diffuse_power" : 0.66000000000000003, 20 | "light2_position" : [ 0.0, 0.0, -2.0 ], 21 | "light2_specular_power" : 0.20000000000000001, 22 | "light2_specular_shininess" : 100.0, 23 | "light3_color" : [ 1.0, 1.0, 1.0 ], 24 | "light3_diffuse_power" : 0.66000000000000003, 25 | "light3_position" : [ 0.0, 0.0, -2.0 ], 26 | "light3_specular_power" : 0.20000000000000001, 27 | "light3_specular_shininess" : 100.0, 28 | "light_ambient_color" : [ 0.0, 0.0, 0.0 ], 29 | "light_on" : true, 30 | "line_width" : 1.0, 31 | "mesh_color_option" : 1, 32 | "mesh_shade_option" : 0, 33 | "mesh_show_back_face" : false, 34 | "mesh_show_wireframe" : false, 35 | "point_color_option" : 0, 36 | "point_show_normal" : false, 37 | "point_size" : 2.0, 38 | "show_coordinate_frame" : false, 39 | "version_major" : 1, 40 | "version_minor" : 0 41 | } -------------------------------------------------------------------------------- /config/semantic-kitti.yaml: -------------------------------------------------------------------------------- 1 | # This file is covered by the LICENSE file in the root of this project. 
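# Overview of the sections below (loaded by data_utils/kitti_utils.py as self.sem_cfg):
#   labels            raw SemanticKITTI label id -> class name
#   color_map         raw label id -> BGR colour used for visualization
#   learning_map      raw label id -> training id; moving-* classes are remapped to
#                     their closest static equivalent, and id 0 ("unlabeled") is
#                     dropped by the loader before training
#   learning_map_inv  training id -> a representative raw label id
#   learning_ignore   training ids flagged to be ignored
#   split             KITTI odometry sequence numbers for the train / valid / test splits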
2 | labels: 3 | 0 : "unlabeled" 4 | 1 : "outlier" 5 | 10: "car" 6 | 11: "bicycle" 7 | 13: "bus" 8 | 15: "motorcycle" 9 | 16: "on-rails" 10 | 18: "truck" 11 | 20: "other-vehicle" 12 | 30: "person" 13 | 31: "bicyclist" 14 | 32: "motorcyclist" 15 | 40: "road" 16 | 44: "parking" 17 | 48: "sidewalk" 18 | 49: "other-ground" 19 | 50: "building" 20 | 51: "fence" 21 | 52: "other-structure" 22 | 60: "lane-marking" 23 | 70: "vegetation" 24 | 71: "trunk" 25 | 72: "terrain" 26 | 80: "pole" 27 | 81: "traffic-sign" 28 | 99: "other-object" 29 | 252: "moving-car" 30 | 253: "moving-bicyclist" 31 | 254: "moving-person" 32 | 255: "moving-motorcyclist" 33 | 256: "moving-on-rails" 34 | 257: "moving-bus" 35 | 258: "moving-truck" 36 | 259: "moving-other-vehicle" 37 | color_map: # bgr 38 | 0 : [0, 0, 0] 39 | 1 : [0, 0, 255] 40 | 10: [245, 150, 100] 41 | 11: [245, 230, 100] 42 | 13: [250, 80, 100] 43 | 15: [150, 60, 30] 44 | 16: [255, 0, 0] 45 | 18: [180, 30, 80] 46 | 20: [255, 0, 0] 47 | 30: [30, 30, 255] 48 | 31: [200, 40, 255] 49 | 32: [90, 30, 150] 50 | 40: [255, 0, 255] 51 | 44: [255, 150, 255] 52 | 48: [75, 0, 75] 53 | 49: [75, 0, 175] 54 | 50: [0, 200, 255] 55 | 51: [50, 120, 255] 56 | 52: [0, 150, 255] 57 | 60: [170, 255, 150] 58 | 70: [0, 175, 0] 59 | 71: [0, 60, 135] 60 | 72: [80, 240, 150] 61 | 80: [150, 240, 255] 62 | 81: [0, 0, 255] 63 | 99: [255, 255, 50] 64 | 252: [245, 150, 100] 65 | 256: [255, 0, 0] 66 | 253: [200, 40, 255] 67 | 254: [30, 30, 255] 68 | 255: [90, 30, 150] 69 | 257: [250, 80, 100] 70 | 258: [180, 30, 80] 71 | 259: [255, 0, 0] 72 | content: # as a ratio with the total number of points 73 | 0: 0.018889854628292943 74 | 1: 0.0002937197336781505 75 | 10: 0.040818519255974316 76 | 11: 0.00016609538710764618 77 | 13: 2.7879693665067774e-05 78 | 15: 0.00039838616015114444 79 | 16: 0.0 80 | 18: 0.0020633612104619787 81 | 20: 0.0016218197275284021 82 | 30: 0.00017698551338515307 83 | 31: 1.1065903904919655e-08 84 | 32: 5.532951952459828e-09 85 | 40: 0.1987493871255525 86 | 44: 0.014717169549888214 87 | 48: 0.14392298360372 88 | 49: 0.0039048553037472045 89 | 50: 0.1326861944777486 90 | 51: 0.0723592229456223 91 | 52: 0.002395131480328884 92 | 60: 4.7084144280367186e-05 93 | 70: 0.26681502148037506 94 | 71: 0.006035012012626033 95 | 72: 0.07814222006271769 96 | 80: 0.002855498193863172 97 | 81: 0.0006155958086189918 98 | 99: 0.009923127583046915 99 | 252: 0.001789309418528068 100 | 253: 0.00012709999297008662 101 | 254: 0.00016059776092534436 102 | 255: 3.745553104802113e-05 103 | 256: 0.0 104 | 257: 0.00011351574470342043 105 | 258: 0.00010157861367183268 106 | 259: 4.3840131989471124e-05 107 | # classes that are indistinguishable from single scan or inconsistent in 108 | # ground truth are mapped to their closest equivalent 109 | learning_map: 110 | 0 : 0 # "unlabeled" 111 | 1 : 0 # "outlier" mapped to "unlabeled" --------------------------mapped 112 | 10: 1 # "car" 113 | 11: 2 # "bicycle" 114 | 13: 5 # "bus" mapped to "other-vehicle" --------------------------mapped 115 | 15: 3 # "motorcycle" 116 | 16: 5 # "on-rails" mapped to "other-vehicle" ---------------------mapped 117 | 18: 4 # "truck" 118 | 20: 5 # "other-vehicle" 119 | 30: 6 # "person" 120 | 31: 7 # "bicyclist" 121 | 32: 8 # "motorcyclist" 122 | 40: 9 # "road" 123 | 44: 10 # "parking" 124 | 48: 11 # "sidewalk" 125 | 49: 12 # "other-ground" 126 | 50: 13 # "building" 127 | 51: 14 # "fence" 128 | 52: 0 # "other-structure" mapped to "unlabeled" ------------------mapped 129 | 60: 9 # "lane-marking" to "road" 
---------------------------------mapped 130 | 70: 15 # "vegetation" 131 | 71: 16 # "trunk" 132 | 72: 17 # "terrain" 133 | 80: 18 # "pole" 134 | 81: 19 # "traffic-sign" 135 | 99: 0 # "other-object" to "unlabeled" ----------------------------mapped 136 | 252: 1 # "moving-car" to "car" ------------------------------------mapped 137 | 253: 7 # "moving-bicyclist" to "bicyclist" ------------------------mapped 138 | 254: 6 # "moving-person" to "person" ------------------------------mapped 139 | 255: 8 # "moving-motorcyclist" to "motorcyclist" ------------------mapped 140 | 256: 5 # "moving-on-rails" mapped to "other-vehicle" --------------mapped 141 | 257: 5 # "moving-bus" mapped to "other-vehicle" -------------------mapped 142 | 258: 4 # "moving-truck" to "truck" --------------------------------mapped 143 | 259: 5 # "moving-other"-vehicle to "other-vehicle" ----------------mapped 144 | learning_map_inv: # inverse of previous map 145 | 0: 0 # "unlabeled", and others ignored 146 | 1: 10 # "car" 147 | 2: 11 # "bicycle" 148 | 3: 15 # "motorcycle" 149 | 4: 18 # "truck" 150 | 5: 20 # "other-vehicle" 151 | 6: 30 # "person" 152 | 7: 31 # "bicyclist" 153 | 8: 32 # "motorcyclist" 154 | 9: 40 # "road" 155 | 10: 44 # "parking" 156 | 11: 48 # "sidewalk" 157 | 12: 49 # "other-ground" 158 | 13: 50 # "building" 159 | 14: 51 # "fence" 160 | 15: 70 # "vegetation" 161 | 16: 71 # "trunk" 162 | 17: 72 # "terrain" 163 | 18: 80 # "pole" 164 | 19: 81 # "traffic-sign" 165 | learning_ignore: # Ignore classes 166 | 0: True # "unlabeled", and others ignored 167 | 1: False # "car" 168 | 2: False # "bicycle" 169 | 3: False # "motorcycle" 170 | 4: False # "truck" 171 | 5: False # "other-vehicle" 172 | 6: False # "person" 173 | 7: False # "bicyclist" 174 | 8: False # "motorcyclist" 175 | 9: False # "road" 176 | 10: False # "parking" 177 | 11: False # "sidewalk" 178 | 12: False # "other-ground" 179 | 13: False # "building" 180 | 14: False # "fence" 181 | 15: False # "vegetation" 182 | 16: False # "trunk" 183 | 17: False # "terrain" 184 | 18: False # "pole" 185 | 19: False # "traffic-sign" 186 | split: # sequence numbers 187 | train: 188 | - 0 189 | - 1 190 | - 2 191 | - 3 192 | - 4 193 | - 5 194 | - 6 195 | - 7 196 | - 9 197 | - 10 198 | valid: 199 | - 8 200 | test: 201 | - 11 202 | - 12 203 | - 13 204 | - 14 205 | - 15 206 | - 16 207 | - 17 208 | - 18 209 | - 19 210 | - 20 211 | - 21 212 | -------------------------------------------------------------------------------- /data_utils/ColorGenerator_Loader.py: -------------------------------------------------------------------------------- 1 | class PointNetColorGen(nn.Module): 2 | def __init__(self,num_class, input_dims=4, feature_transform=False): 3 | super(PointNetColorGen, self).__init__() 4 | self.k = num_class 5 | self.feat = PointNetEncoder(global_feat=False,input_dims = input_dims, feature_transform=feature_transform) 6 | self.conv1 = torch.nn.Conv1d(1088, 512, 1) 7 | self.conv2 = torch.nn.Conv1d(512, 256, 1) 8 | self.conv3 = torch.nn.Conv1d(256, 128, 1) 9 | self.conv4 = torch.nn.Conv1d(128, self.k, 1) 10 | self.bn1 = nn.BatchNorm1d(512) 11 | self.bn2 = nn.BatchNorm1d(256) 12 | self.bn3 = nn.BatchNorm1d(128) 13 | 14 | def forward(self, x): 15 | batchsize = x.size()[0] 16 | n_pts = x.size()[2] 17 | x, trans, trans_feat = self.feat(x) 18 | x = F.relu(self.bn1(self.conv1(x))) 19 | x = F.relu(self.bn2(self.conv2(x))) 20 | x = F.relu(self.bn3(self.conv3(x))) 21 | x = self.conv4(x) 22 | x = x.transpose(2,1).contiguous() 23 | x = x.view(batchsize, n_pts, self.k) 24 | return x, 
trans_feat 25 | 26 | class ColorGeneratorLoader(Dataset): 27 | def __init__(self, root, npoints, train = True): 28 | self.root = root 29 | self.train = train 30 | self.npoints = npoints 31 | self.np_redis = Mat_Redis_Utils() 32 | self.utils = Semantic_KITTI_Utils(root,'inview','learning') 33 | 34 | part_length = {'00': 4540,'01':1100,'02':4660,'03':800,'04':270,'05':2760, 35 | '06':1100,'07':1100,'08':4070,'09':1590,'10':1200} 36 | 37 | self.keys = [] 38 | alias = 'gen' 39 | 40 | if self.train: 41 | for part in ['00','01','02','03','04','05','06','07','09','10']: 42 | length = part_length[part] 43 | for index in range(0,length,2): 44 | self.keys.append('%s/%s/%06d'%(alias, part, index)) 45 | else: 46 | for part in ['08']: 47 | length = part_length[part] 48 | for index in range(0,length,2): 49 | self.keys.append('%s/%s/%06d'%(alias, part, index)) 50 | 51 | def __len__(self): 52 | return len(self.keys) 53 | 54 | def get_data(self, key): 55 | if not self.np_redis.exists(key): 56 | alias, part, index = key.split('/') 57 | 58 | point_cloud, _ = self.utils.get(part, int(index), load_image = True) 59 | pts_2d = self.utils.project_3d_to_2d(point_cloud[:,:3]).astype(np.int32) 60 | pts_color = np.zeros((point_cloud.shape[0],3), dtype=np.float32) 61 | 62 | frame_shape = self.utils.frame.shape 63 | for i,(y,x) in enumerate(pts_2d): 64 | if x >= 0 and x < frame_shape[0] and y >= 0 and y < frame_shape[1]: 65 | pts_color[i] = self.utils.frame_HSV[x,y] 66 | # img = self.utils.draw_2d_points(pts_2d, pts_color, on_black=True) 67 | # plt.imshow(cv2.cvtColor(img, cv2.COLOR_HSV2RGB)) 68 | pts_color[:,0] /= 180 69 | pts_color[:,1:] /= 255 70 | 71 | to_store = np.concatenate((point_cloud, pts_color),axis=1) 72 | self.np_redis.set(key, to_store) 73 | print('add', key, to_store.shape, to_store.dtype) 74 | else: 75 | data = self.np_redis.get(key) 76 | point_cloud = data[:,:4] 77 | pts_color = data[:,4:] 78 | 79 | # Unnormalized Point cloud 80 | return point_cloud, pts_color 81 | 82 | def __getitem__(self, index): 83 | point_cloud, label = self.get_data(self.keys[index]) 84 | #pcd = point_cloud 85 | pcd = pcd_normalize(point_cloud) 86 | if self.train: 87 | pcd = pcd_jitter(pcd) 88 | 89 | length = pcd.shape[0] 90 | if length == self.npoints: 91 | pass 92 | elif length > self.npoints: 93 | start_idx = np.random.randint(0, length - self.npoints) 94 | end_idx = start_idx + self.npoints 95 | pcd = pcd[start_idx:end_idx] 96 | label = label[start_idx:end_idx] 97 | else: 98 | rows_short = self.npoints - length 99 | pcd = np.concatenate((pcd,pcd[0:rows_short]),axis=0) 100 | label = np.concatenate((label,label[0:rows_short]),axis=0) 101 | return pcd, label -------------------------------------------------------------------------------- /data_utils/ModelNetDataLoader.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import warnings 3 | import h5py 4 | from torch.utils.data import Dataset 5 | import sys 6 | from .augmentation import rotate_point_cloud, jitter_point_cloud, point_cloud_normalize 7 | 8 | class_names = ['airplane','bathtub','bed','bench','bookshelf','bottle', 9 | 'bowl','car','chair','cone','cup','curtain','desk','door', 10 | 'dresser','flower_pot','glass_box','guitar','keyboard','lamp', 11 | 'laptop','mantel','monitor','night_stand','person','piano', 12 | 'plant','radio','range_hood','sink','sofa','stairs','stool', 13 | 'table','tent','toilet','tv_stand','vase','wardrobe','xbox'] 14 | 15 | def load_h5(h5_filename): 16 | f = h5py.File(h5_filename,'r') 17 
| data = f['data'][:] 18 | label = f['label'][:] 19 | seg = [] 20 | return (data, label, seg) 21 | 22 | def load_data(path, train=True, classification = True): 23 | if train: 24 | data_train0, label_train0, Seglabel_train0 = load_h5(path + 'ply_data_train0.h5') 25 | data_train1, label_train1, Seglabel_train1 = load_h5(path + 'ply_data_train1.h5') 26 | data_train2, label_train2, Seglabel_train2 = load_h5(path + 'ply_data_train2.h5') 27 | data_train3, label_train3, Seglabel_train3 = load_h5(path + 'ply_data_train3.h5') 28 | data_train4, label_train4, Seglabel_train4 = load_h5(path + 'ply_data_train4.h5') 29 | train_data = np.concatenate([data_train0,data_train1,data_train2,data_train3,data_train4]) 30 | train_label = np.concatenate([label_train0,label_train1,label_train2,label_train3,label_train4]) 31 | train_Seglabel = np.concatenate([Seglabel_train0,Seglabel_train1,Seglabel_train2,Seglabel_train3,Seglabel_train4]) 32 | 33 | data_test0, label_test0, Seglabel_test0 = load_h5(path + 'ply_data_test0.h5') 34 | data_test1, label_test1, Seglabel_test1 = load_h5(path + 'ply_data_test1.h5') 35 | test_data = np.concatenate([data_test0,data_test1]) 36 | test_label = np.concatenate([label_test0,label_test1]) 37 | 38 | test_Seglabel = np.concatenate([Seglabel_test0,Seglabel_test1]) 39 | if train: 40 | if classification: 41 | return train_data, train_label, test_data, test_label 42 | else: 43 | return train_data, train_Seglabel, test_data, test_Seglabel 44 | else: 45 | if classification: 46 | return test_data, test_label 47 | else: 48 | return test_data, test_Seglabel 49 | 50 | 51 | class ModelNetDataLoader(Dataset): 52 | def __init__(self, data, labels, data_augmentation = False): 53 | self.data = data 54 | self.labels = labels 55 | self.data_augmentation = data_augmentation 56 | 57 | def __len__(self): 58 | return len(self.data) 59 | 60 | def __getitem__(self, index): 61 | pointcloud = self.data[index] 62 | label = self.labels[index] 63 | 64 | if self.data_augmentation: 65 | pcd = np.expand_dims(pcd,axis=0) 66 | pcd = rotate_point_cloud(pcd) 67 | pcd = jitter_point_cloud(pcd).astype(np.float32) 68 | pcd = np.squeeze(pcd, axis=0) 69 | 70 | return pointcloud, label -------------------------------------------------------------------------------- /data_utils/S3DISDataLoader.py: -------------------------------------------------------------------------------- 1 | # *_*coding:utf-8 *_* 2 | import os 3 | from torch.utils.data import Dataset 4 | import numpy as np 5 | import h5py 6 | import sys 7 | from .augmentation import rotate_point_cloud, jitter_point_cloud 8 | 9 | classes = ['ceiling','floor','wall','beam','column','window','door','table','chair','sofa','bookcase','board','clutter'] 10 | class2label = {cls: i for i,cls in enumerate(classes)} 11 | 12 | label_id_to_name = {} 13 | for i,cat in enumerate(class2label.keys()): 14 | label_id_to_name[i] = cat 15 | 16 | def getDataFiles(list_filename): 17 | return [line.rstrip() for line in open(list_filename)] 18 | 19 | def load_h5(h5_filename): 20 | f = h5py.File(h5_filename) 21 | data = f['data'][:] 22 | label = f['label'][:] 23 | return (data, label) 24 | 25 | def loadDataFile(filename): 26 | print(filename) 27 | return load_h5(filename) 28 | 29 | def recognize_all_data(root, test_area = 5): 30 | ALL_FILES = getDataFiles(os.path.join(root, 'all_files.txt')) 31 | room_filelist = [line.rstrip() for line in open(os.path.join(root, 'room_filelist.txt'))] 32 | data_batch_list = [] 33 | label_batch_list = [] 34 | 35 | for h5_filename in ALL_FILES: 36 | h5_filename 
= h5_filename.split('/')[-1] 37 | data_batch, label_batch = loadDataFile(os.path.join(root,h5_filename)) 38 | data_batch_list.append(data_batch) 39 | label_batch_list.append(label_batch) 40 | 41 | data_batches = np.concatenate(data_batch_list, 0) 42 | label_batches = np.concatenate(label_batch_list, 0) 43 | 44 | test_area = 'Area_' + str(test_area) 45 | train_idxs = [] 46 | test_idxs = [] 47 | for i, room_name in enumerate(room_filelist): 48 | if test_area in room_name: 49 | test_idxs.append(i) 50 | else: 51 | train_idxs.append(i) 52 | 53 | train_data = data_batches[train_idxs] 54 | train_label = label_batches[train_idxs] 55 | test_data = data_batches[test_idxs] 56 | test_label = label_batches[test_idxs] 57 | return train_data,train_label,test_data,test_label 58 | 59 | class S3DISDataLoader(Dataset): 60 | def __init__(self, data, labels, data_augmentation=False): 61 | self.data = data 62 | self.labels = labels 63 | self.data_augmentation = data_augmentation 64 | 65 | def __len__(self): 66 | return len(self.data) 67 | 68 | def __getitem__(self, index): 69 | pointcloud = self.data[index] 70 | label = self.labels[index] 71 | 72 | if self.data_augmentation: 73 | pcd = np.expand_dims(pcd,axis=0) 74 | # pcd = rotate_point_cloud(pcd) 75 | pcd = jitter_point_cloud(pcd).astype(np.float32) 76 | pcd = np.squeeze(pcd, axis=0) 77 | 78 | return pointcloud, label -------------------------------------------------------------------------------- /data_utils/SemKITTI_Loader.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import yaml 5 | import numpy as np 6 | import random 7 | from tqdm import tqdm 8 | import torch 9 | from torch.utils.data import Dataset 10 | import threading 11 | import multiprocessing 12 | from PIL import Image 13 | 14 | from .kitti_utils import Semantic_KITTI_Utils 15 | from .redis_utils import Mat_Redis_Utils 16 | 17 | def pcd_jitter(pcd, sigma=0.01, clip=0.05): 18 | N, C = pcd.shape 19 | jittered_data = np.clip(sigma * np.random.randn(N, C), -1*clip, clip).astype(pcd.dtype) 20 | jittered_data += pcd 21 | return jittered_data 22 | 23 | def pcd_normalize(pcd): 24 | pcd = pcd.copy() 25 | pcd[:,0] = pcd[:,0] / 70 26 | pcd[:,1] = pcd[:,1] / 70 27 | pcd[:,2] = pcd[:,2] / 3 28 | pcd[:,3] = (pcd[:,3] - 0.5)*2 29 | pcd = np.clip(pcd,-1,1) 30 | return pcd 31 | 32 | def pcd_unnormalize(pcd): 33 | pcd = pcd.copy() 34 | pcd[:,0] = pcd[:,0] * 70 35 | pcd[:,1] = pcd[:,1] * 70 36 | pcd[:,2] = pcd[:,2] * 3 37 | pcd[:,3] = pcd[:,3] / 2 + 0.5 38 | return pcd 39 | 40 | def pcd_tensor_unnorm(pcd): 41 | pcd_unnorm = pcd.clone() 42 | pcd_unnorm[:,0] = pcd[:,0] * 70 43 | pcd_unnorm[:,1] = pcd[:,1] * 70 44 | pcd_unnorm[:,2] = pcd[:,2] * 3 45 | pcd_unnorm[:,3] = pcd[:,3] / 2 + 0.5 46 | return pcd_unnorm 47 | 48 | class SemKITTI_Loader(Dataset): 49 | def __init__(self, root, npoints, train = True, subset = 'all'): 50 | self.root = root 51 | self.train = train 52 | self.npoints = npoints 53 | self.np_redis = Mat_Redis_Utils() 54 | self.utils = Semantic_KITTI_Utils(root,subset) 55 | 56 | part_length = {'00': 4540,'01':1100,'02':4660,'03':800,'04':270,'05':2760, 57 | '06':1100,'07':1100,'08':4070,'09':1590,'10':1200} 58 | 59 | self.keys = [] 60 | alias = subset[0] 61 | 62 | if self.train: 63 | for part in ['00','01','02','03','04','05','06','07','09','10']: 64 | length = part_length[part] 65 | for index in range(0,length,2): 66 | self.keys.append('%s/%s/%06d'%(alias, part, index)) 67 | else: 68 | for part in ['08']: 69 | length = 
part_length[part] 70 | for index in range(0,length): 71 | self.keys.append('%s/%s/%06d'%(alias, part, index)) 72 | 73 | def __len__(self): 74 | return len(self.keys) 75 | 76 | def get_data(self, key): 77 | if not self.np_redis.exists(key): 78 | alias, part, index = key.split('/') 79 | point_cloud, label = self.utils.get(part, int(index)) 80 | 81 | to_store = np.concatenate((point_cloud, label.reshape((-1,1)).astype(np.float32)),axis=1) 82 | self.np_redis.set(key, to_store) 83 | else: 84 | data = self.np_redis.get(key) 85 | point_cloud = data[:,:4] 86 | label = data[:,4].astype(np.int32) 87 | 88 | # Unnormalized Point cloud 89 | return point_cloud, label 90 | 91 | def __getitem__(self, index): 92 | point_cloud, label = self.get_data(self.keys[index]) 93 | pcd = pcd_normalize(point_cloud) 94 | if self.train: 95 | pcd = pcd_jitter(pcd) 96 | 97 | # length = pcd.shape[0] 98 | # if length == self.npoints: 99 | # pass 100 | # elif length > self.npoints: 101 | # start_idx = np.random.randint(0, length - self.npoints) 102 | # end_idx = start_idx + self.npoints 103 | # pcd = pcd[start_idx:end_idx] 104 | # label = label[start_idx:end_idx] 105 | # else: 106 | # rows_short = self.npoints - length 107 | # pcd = np.concatenate((pcd,pcd[0:rows_short]),axis=0) 108 | # label = np.concatenate((label,label[0:rows_short]),axis=0) 109 | 110 | length = pcd.shape[0] 111 | choice = np.random.choice(length, self.npoints, replace=True) 112 | pcd = pcd[choice] 113 | label = label[choice] 114 | 115 | return pcd, label 116 | 117 | 118 | -------------------------------------------------------------------------------- /data_utils/ShapeNetDataLoader.py: -------------------------------------------------------------------------------- 1 | # *_*coding:utf-8 *_* 2 | import os 3 | import json 4 | import warnings 5 | import numpy as np 6 | import gc 7 | from tqdm import tqdm 8 | import h5py 9 | from torch.utils.data import Dataset 10 | warnings.filterwarnings('ignore') 11 | import sys 12 | from .augmentation import rotate_point_cloud, jitter_point_cloud, point_cloud_normalize 13 | 14 | seg_classes = { 15 | 'Earphone': [16, 17, 18], 16 | 'Motorbike': [30, 31, 32, 33, 34, 35], 17 | 'Rocket': [41, 42, 43], 18 | 'Car': [8, 9, 10, 11], 19 | 'Laptop': [28, 29], 20 | 'Cap': [6, 7], 21 | 'Skateboard': [44, 45, 46], 22 | 'Mug': [36, 37], 23 | 'Guitar': [19, 20, 21], 24 | 'Bag': [4, 5], 25 | 'Lamp': [24, 25, 26, 27], 26 | 'Table': [47, 48, 49], 27 | 'Airplane': [0, 1, 2, 3], 28 | 'Pistol': [38, 39, 40], 29 | 'Chair': [12, 13, 14, 15], 30 | 'Knife': [22, 23] 31 | } 32 | label_id_to_name = {} 33 | for cat in seg_classes.keys(): 34 | for label in seg_classes[cat]: 35 | label_id_to_name[label] = cat 36 | 37 | class PartNormalDataset(Dataset): 38 | def __init__(self, root, cache = {}, npoints=2500, split='train', normalize=True, data_augmentation=False): 39 | self.npoints = npoints 40 | self.root = root 41 | self.category = {} 42 | self.normalize = normalize 43 | self.cache = cache 44 | self.data_augmentation = data_augmentation 45 | 46 | self.wordnet_id_to_category = {} 47 | with open(os.path.join(self.root, 'synsetoffset2category.txt'), 'r') as f: 48 | for line in f: 49 | line = line.strip().split() 50 | self.category[line[0]] = line[1] 51 | self.wordnet_id_to_category[line[1]] = line[0] 52 | 53 | fn_split = os.path.join(self.root, 'train_test_split') 54 | with open(os.path.join(fn_split,'shuffled_train_file_list.json'), 'r') as f: 55 | train_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 56 | with 
open(os.path.join(fn_split,'shuffled_val_file_list.json'), 'r') as f: 57 | val_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 58 | with open(os.path.join(fn_split,'shuffled_test_file_list.json'), 'r') as f: 59 | test_ids = set([str(d.split('/')[2]) for d in json.load(f)]) 60 | 61 | self.meta = {} 62 | for item in self.category: 63 | self.meta[item] = [] 64 | dir_point = os.path.join(self.root, self.category[item]) 65 | fns = sorted(os.listdir(dir_point)) 66 | 67 | if split == 'trainval': 68 | fns = [fn for fn in fns if ((fn[0:-4] in train_ids) or (fn[0:-4] in val_ids))] 69 | elif split == 'train': 70 | fns = [fn for fn in fns if fn[0:-4] in train_ids] 71 | elif split == 'val': 72 | fns = [fn for fn in fns if fn[0:-4] in val_ids] 73 | elif split == 'test': 74 | fns = [fn for fn in fns if fn[0:-4] in test_ids] 75 | else: 76 | raise ValueError('Unknown split: %s. Exiting..' % (split)) 77 | 78 | for fn in fns: 79 | self.meta[item].append(os.path.join(dir_point, fn)) 80 | 81 | self.datapath = [] 82 | for item in self.category: 83 | for fn in self.meta[item]: 84 | self.datapath.append(fn) 85 | 86 | self.classes = dict(zip(self.category, range(len(self.category)))) 87 | # print('classes',self.classes.keys()) 88 | 89 | self.seg_classes = {'Earphone': [16, 17, 18], 'Motorbike': [30, 31, 32, 33, 34, 35], 'Rocket': [41, 42, 43], 90 | 'Car': [8, 9, 10, 11], 'Laptop': [28, 29], 'Cap': [6, 7], 'Skateboard': [44, 45, 46], 91 | 'Mug': [36, 37], 'Guitar': [19, 20, 21], 'Bag': [4, 5], 'Lamp': [24, 25, 26, 27], 92 | 'Table': [47, 48, 49], 'Airplane': [0, 1, 2, 3], 'Pistol': [38, 39, 40], 93 | 'Chair': [12, 13, 14, 15], 'Knife': [22, 23]} 94 | 95 | def __getitem__(self, index): 96 | fn_full = self.datapath[index] 97 | parts = fn_full.split('/') 98 | wordnet_id = parts[-2] 99 | category = self.wordnet_id_to_category[wordnet_id] 100 | cls_id = np.array([self.classes[category]]).astype(np.int32) 101 | token = parts[-1].split('.')[0] 102 | h5_index = '%s_%s'%(wordnet_id,token) 103 | 104 | if h5_index in self.cache.keys(): 105 | data = self.cache[h5_index] 106 | pointcloud = data[:, 0:3] 107 | normal = data[:, 3:6] 108 | seg = data[:, -1].astype(np.int32) 109 | else: 110 | print('Error: cache miss',h5_index) 111 | data = np.loadtxt(fn_full).astype(np.float32) 112 | 113 | if self.normalize: 114 | pointcloud = point_cloud_normalize(pointcloud) 115 | 116 | if self.data_augmentation: 117 | pointcloud = np.expand_dims(pointcloud,axis=0) 118 | pointcloud = rotate_point_cloud(pointcloud) 119 | pointcloud = jitter_point_cloud(pointcloud).astype(np.float32) 120 | pointcloud = np.squeeze(pointcloud, axis=0) 121 | 122 | # resample 123 | choice = np.random.choice(len(seg), self.npoints, replace=True) 124 | pointcloud = pointcloud[choice, :] 125 | seg = seg[choice] 126 | normal = normal[choice, :] 127 | return pointcloud, cls_id, seg, normal 128 | 129 | def __len__(self): 130 | return len(self.datapath) -------------------------------------------------------------------------------- /data_utils/augmentation.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | 3 | 4 | def point_cloud_normalize(pc): 5 | centroid = np.mean(pc, axis=0) 6 | pc = pc - centroid 7 | m = np.max(np.sqrt(np.sum(pc ** 2, axis=1))) 8 | pc = pc / m 9 | return pc 10 | 11 | 12 | def shuffle_data(data, labels): 13 | """ Shuffle data and labels. 14 | Input: 15 | data: B,N,... numpy array 16 | label: B,... 
numpy array 17 | Return: 18 | shuffled data, label and shuffle indices 19 | """ 20 | idx = np.arange(len(labels)) 21 | np.random.shuffle(idx) 22 | return data[idx, ...], labels[idx], idx 23 | 24 | 25 | def rotate_point_cloud(batch_data): 26 | """ Randomly rotate the point clouds to augument the dataset 27 | rotation is per shape based along up direction 28 | Input: 29 | BxNx3 array, original batch of point clouds 30 | Return: 31 | BxNx3 array, rotated batch of point clouds 32 | """ 33 | assert len(batch_data.shape) == 3, batch_data.shape 34 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 35 | for k in range(batch_data.shape[0]): 36 | rotation_angle = np.random.uniform() * 2 * np.pi 37 | cosval = np.cos(rotation_angle) 38 | sinval = np.sin(rotation_angle) 39 | rotation_matrix = np.array([[cosval, 0, sinval], 40 | [0, 1, 0], 41 | [-sinval, 0, cosval]]) 42 | shape_pc = batch_data[k, ...] 43 | rotated_data[k, ...] = np.dot( 44 | shape_pc.reshape((-1, 3)), rotation_matrix) 45 | return rotated_data 46 | 47 | 48 | def rotate_point_cloud_by_angle(batch_data, rotation_angle): 49 | """ Rotate the point cloud along up direction with certain angle. 50 | Input: 51 | BxNx3 array, original batch of point clouds 52 | Return: 53 | BxNx3 array, rotated batch of point clouds 54 | """ 55 | assert len(batch_data.shape) == 3, batch_data.shape 56 | rotated_data = np.zeros(batch_data.shape, dtype=np.float32) 57 | for k in range(batch_data.shape[0]): 58 | # rotation_angle = np.random.uniform() * 2 * np.pi 59 | cosval = np.cos(rotation_angle) 60 | sinval = np.sin(rotation_angle) 61 | rotation_matrix = np.array([[cosval, 0, sinval], 62 | [0, 1, 0], 63 | [-sinval, 0, cosval]]) 64 | shape_pc = batch_data[k, ...] 65 | rotated_data[k, ...] = np.dot( 66 | shape_pc.reshape((-1, 3)), rotation_matrix) 67 | return rotated_data 68 | 69 | 70 | def jitter_point_cloud(batch_data, sigma=0.01, clip=0.05): 71 | """ Randomly jitter points. jittering is per point. 
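Noise is drawn i.i.d. per coordinate from N(0, sigma^2) and clipped to [-clip, clip] before being added to the points.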
72 | Input: 73 | BxNx3 array, original batch of point clouds 74 | Return: 75 | BxNx3 array, jittered batch of point clouds 76 | """ 77 | assert len(batch_data.shape) == 3, batch_data.shape 78 | B, N, C = batch_data.shape 79 | assert(clip > 0) 80 | jittered_data = np.clip(sigma * np.random.randn(B, N, C), -1*clip, clip) 81 | jittered_data += batch_data 82 | return jittered_data -------------------------------------------------------------------------------- /data_utils/download_data.sh: -------------------------------------------------------------------------------- 1 | #!/bin/bash 2 | 3 | # Download HDF5 for indoor 3d semantic segmentation (around 1.6GB) 4 | wget https://shapenet.cs.stanford.edu/media/indoor3d_sem_seg_hdf5_data.zip 5 | unzip indoor3d_sem_seg_hdf5_data.zip 6 | rm indoor3d_sem_seg_hdf5_data.zip 7 | 8 | -------------------------------------------------------------------------------- /data_utils/kitti_utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import cv2 3 | import json 4 | import yaml 5 | import numpy as np 6 | import random 7 | from tqdm import tqdm 8 | import torch 9 | from PIL import Image 10 | 11 | class KITTI_2_Common(): 12 | def __init__(self, model): 13 | self.model = model 14 | self.common = None 15 | self.kitti_names = ['road', 'sidewalk', 'building', 'wall', 'fence', 16 | 'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain', 17 | 'sky', 'person', 'rider', 'car', 'truck', 'bus', 'train', 18 | 'motorcycle', 'bicycle' 19 | ] 20 | self.kitti_2_common = [ 21 | 'road','sidewalk','building+wall','fence','pole', 22 | 'traffic_light+traffic_sign','vegetation', 23 | 'terrain','person','rider','car','truck', 24 | 'bus+train','motorcycle','bicycle','sky' 25 | ] 26 | self.kitti_colors = [ 27 | [128, 64, 128], [244, 35, 232], [70, 70, 70], [102, 102, 156], 28 | [190, 153, 153], [153, 153, 153], [250, 170, 30],[220, 220, 0], [107, 142, 35], 29 | [152, 251, 152], [0, 130, 180], [220, 20, 60], [255, 0, 0], [0, 0, 142], 30 | [0, 0, 70], [0, 60, 100], [0, 80, 100], [0, 0, 230], [119, 11, 32] 31 | ] 32 | self.colors = [] 33 | for index, c_class in enumerate(self.kitti_2_common): 34 | name_0 = c_class.split('+')[0] 35 | index_0 = self.kitti_names.index(name_0) 36 | self.colors.append(self.kitti_colors[index_0]) 37 | 38 | self.kitti_colors = np.array(self.kitti_colors) 39 | self.colors = np.array(self.colors) 40 | 41 | def __call__(self, x): 42 | logits = self.model(x) 43 | new_size = list(logits.size()) 44 | new_size[1] = len(self.kitti_2_common) 45 | self.common = torch.zeros(new_size, dtype=logits.dtype).cuda() 46 | 47 | for index, c_class in enumerate(self.kitti_2_common): 48 | merge = c_class.split('+') 49 | if len(merge) == 1: 50 | index_0 = self.kitti_names.index(merge[0]) 51 | self.common[:, index] = logits[:,index_0] 52 | elif len(merge) == 2: 53 | index_0 = self.kitti_names.index(merge[0]) 54 | index_1 = self.kitti_names.index(merge[1]) 55 | self.common[:, index] = logits[:,[index_0,index_1]].max(1)[0] 56 | else: 57 | raise NotImplementedError("not implemented!") 58 | return self.common 59 | 60 | 61 | class SemKITTI_2_Common(): 62 | def __init__(self, model, model_name): 63 | self.model_name = model_name 64 | self.model = model 65 | self.common = None 66 | self.semkitti_names = [ 67 | 'car', 'bicycle', 'motorcycle', 'truck', 'other-vehicle', 68 | 'person', 'bicyclist', 'motorcyclist', 'road', 'parking', 'sidewalk', 'other-ground', 69 | 'building', 'fence', 'vegetation', 'trunk', 'terrain','pole', 
'traffic-sign' 70 | ] 71 | self.semkitti_2_common = [ 72 | 'road', 'parking+sidewalk', 'building', 'fence', 'trunk+pole', 73 | 'traffic-sign', 'vegetation', 'terrain', 'person', 'bicyclist+motorcyclist', 74 | 'car', 'truck', 'other-vehicle', 'motorcycle', 'bicycle','other-ground' 75 | ] 76 | self.semkitti_colors = [ 77 | [245, 150, 100],[245, 230, 100],[150, 60, 30],[180, 30, 80], 78 | [255, 0, 0],[30, 30, 255],[200, 40, 255],[90, 30, 150],[255, 0, 255], 79 | [255, 150, 255], [75, 0, 75],[75, 0, 175],[0, 200, 255],[50, 120, 255], 80 | [0, 175, 0],[0, 60, 135],[80, 240, 150],[150, 240, 255],[0, 0, 255] 81 | ] 82 | 83 | self.colors = [] 84 | for index, c_class in enumerate(self.semkitti_2_common): 85 | name_0 = c_class.split('+')[0] 86 | index_0 = self.semkitti_names.index(name_0) 87 | self.colors.append(self.semkitti_colors[index_0]) 88 | 89 | self.semkitti_colors = np.array(self.semkitti_colors) 90 | self.colors = np.array(self.colors) 91 | 92 | def __call__(self, x): 93 | if self.model_name == 'pointnet': 94 | logits, feature_transform = self.model(x) 95 | else: 96 | logits = self.model(x) 97 | 98 | new_size = list(logits.size()) 99 | new_size[2] = len(self.semkitti_2_common) 100 | self.common = torch.zeros(new_size, dtype=logits.dtype).cuda() 101 | 102 | for index, c_class in enumerate(self.semkitti_2_common): 103 | merge = c_class.split('+') 104 | if len(merge) == 1: 105 | index_0 = self.semkitti_names.index(merge[0]) 106 | self.common[:, :, index] = logits[:, :, index_0] 107 | elif len(merge) == 2: 108 | index_0 = self.semkitti_names.index(merge[0]) 109 | index_1 = self.semkitti_names.index(merge[1]) 110 | self.common[:, :, index] = logits[:, :, [index_0,index_1]].max(2)[0] 111 | else: 112 | raise NotImplementedError("not implemented!") 113 | 114 | if self.model_name == 'pointnet': 115 | return self.common, feature_transform 116 | else: 117 | return self.common 118 | 119 | 120 | sem_kitti_class_names = [ 121 | 'car','bicycle','motorcycle','truck','other-vehicle', 122 | 'person','bicyclist','motorcyclist','road','parking','sidewalk','other-ground', 123 | 'building','fence','vegetation','trunk','terrain','pole','traffic-sign'] 124 | 125 | sem_kitti_colors = [[245, 150, 100],[245, 230, 100],[150, 60, 30],[180, 30, 80], 126 | [255, 0, 0],[30, 30, 255],[200, 40, 255],[90, 30, 150],[255, 0, 255], 127 | [255, 150, 255], [75, 0, 75],[75, 0, 175],[0, 200, 255],[50, 120, 255], 128 | [0, 175, 0],[0, 60, 135],[80, 240, 150],[150, 240, 255],[0, 0, 255] 129 | ] 130 | 131 | kitti_class_names = [ 132 | 'road', 'sidewalk', 'building', 'wall', 'fence', 133 | 'pole', 'traffic_light', 'traffic_sign', 'vegetation', 'terrain','sky', 'person', 134 | 'rider', 'car', 'truck', 'bus', 'train','motorcycle', 'bicycle'] 135 | 136 | kitti_colors = [[128, 64, 128],[244, 35, 232],[70, 70, 70],[102, 102, 156], 137 | [190, 153, 153],[153, 153, 153],[250, 170, 30],[220, 220, 0],[107, 142, 35], 138 | [152, 251, 152],[0, 130, 180],[220, 20, 60],[255, 0, 0],[0, 0, 142], 139 | [0, 0, 70],[0, 60, 100],[0, 80, 100],[0, 0, 230],[119, 11, 32] 140 | ] 141 | 142 | class Semantic_KITTI_Utils(): 143 | def __init__(self, root, subset = 'all'): 144 | self.root = root 145 | 146 | base_path = os.path.dirname(os.path.realpath(__file__)) + '/../' 147 | 148 | self.R, self.T = self.calib_velo2cam(base_path+'config/calib_velo_to_cam.txt') 149 | self.P = self.calib_cam2cam(base_path+'config/calib_cam_to_cam.txt' ,mode="02") 150 | self.RT = np.concatenate((self.R, self.T), axis=1) 151 | 152 | self.sem_cfg = 
yaml.load(open(base_path+'config/semantic-kitti.yaml','r'), Loader=yaml.SafeLoader) 153 | 154 | self.class_names = self.sem_cfg['labels'] 155 | self.learning_map = self.sem_cfg['learning_map'] 156 | self.learning_map_inv = self.sem_cfg['learning_map_inv'] 157 | self.learning_ignore = self.sem_cfg['learning_ignore'] 158 | self.sem_color_map = self.sem_cfg['color_map'] 159 | 160 | self.length = { 161 | '00': 4540,'01':1100,'02':4660,'03':800,'04':270,'05':2760, 162 | '06':1100,'07':1100,'08':4070,'09':1590,'10':1200 163 | } 164 | assert subset in ['all', 'inview'], subset 165 | 166 | self.subset = subset 167 | 168 | self.num_classes = 19 169 | self.index_to_name = {i:name for i,name in enumerate(sem_kitti_class_names)} 170 | self.name_to_index = {name:i for i,name in enumerate(sem_kitti_class_names)} 171 | 172 | self.kitti_index_to_name = {i:name for i,name in enumerate(kitti_class_names)} 173 | self.kitti_name_to_index = {name:i for i,name in enumerate(kitti_class_names)} 174 | 175 | self.class_names = sem_kitti_class_names 176 | 177 | self.kitti_colors = np.array(kitti_colors,np.uint8) 178 | self.kitti_colors_bgr = np.array([list(reversed(c)) for c in kitti_colors],np.uint8) 179 | 180 | self.colors = np.array(sem_kitti_colors,np.uint8) 181 | self.colors_bgr = np.array([list(reversed(c)) for c in sem_kitti_colors],np.uint8) 182 | 183 | def get(self, part, index, load_image = False): 184 | 185 | sequence_root = os.path.join(self.root, 'sequences/%s/'%(part)) 186 | assert index <= self.length[part], index 187 | 188 | if load_image: 189 | fn_frame = os.path.join(sequence_root, 'image_2/%06d.png' % (index)) 190 | assert os.path.exists(fn_frame), 'Broken dataset %s' % (fn_frame) 191 | self.frame_BGR = cv2.imread(fn_frame) 192 | self.frame = cv2.cvtColor(self.frame_BGR, cv2.COLOR_BGR2RGB) 193 | #self.frame_HSV = cv2.cvtColor(self.frame_BGR, cv2.COLOR_BGR2HSV) 194 | 195 | fn_velo = os.path.join(sequence_root, 'velodyne/%06d.bin' %(index)) 196 | fn_label = os.path.join(sequence_root, 'labels/%06d.label' %(index)) 197 | assert os.path.exists(fn_velo), 'Broken dataset %s' % (fn_velo) 198 | assert os.path.exists(fn_label), 'Broken dataset %s' % (fn_label) 199 | 200 | points = np.fromfile(fn_velo, dtype=np.float32).reshape(-1, 4) 201 | raw_label = np.fromfile(fn_label, dtype=np.uint32).reshape((-1)) 202 | 203 | if raw_label.shape[0] == points.shape[0]: 204 | label = raw_label & 0xFFFF # semantic label in lower half 205 | inst_label = raw_label >> 16 # instance id in upper half 206 | assert((label + (inst_label << 16) == raw_label).all()) # sanity check 207 | else: 208 | print("Points shape: ", points.shape) 209 | print("Label shape: ", label.shape) 210 | raise ValueError("Scan and Label don't contain same number of points") 211 | 212 | # Map to learning 20 classes 213 | label = np.array([self.learning_map[x] for x in label], dtype=np.int32) 214 | 215 | # Drop class -> 0 216 | drop_class_0 = np.where(label != 0) 217 | points = points[drop_class_0] 218 | label = label[drop_class_0] - 1 219 | assert (label >=0).all and (label (-fov[1] * np.pi / 180), \ 244 | np.arctan2(n, m) < (-fov[0] * np.pi / 180)) 245 | elif fov_type == 'v': 246 | return np.logical_and(np.arctan2(n, m) < (fov[1] * np.pi / 180), \ 247 | np.arctan2(n, m) > (fov[0] * np.pi / 180)) 248 | else: 249 | raise NameError("fov type must be set between 'h' and 'v' ") 250 | 251 | def box_in_range(self,x,y,z,d, x_range, y_range, z_range, d_range): 252 | """ extract filtered in-range velodyne coordinates based on x,y,z limit """ 253 | return 
np.logical_and.reduce(( 254 | x > x_range[0], x < x_range[1], 255 | y > y_range[0], y < y_range[1], 256 | z > z_range[0], z < z_range[1], 257 | d > d_range[0], d < d_range[1])) 258 | 259 | def points_basic_filter(self, points): 260 | """ 261 | filter points based on h,v FOV and x,y,z distance range. 262 | x,y,z direction is based on velodyne coordinates 263 | 1. azimuth & elevation angle limit check 264 | 2. x,y,z distance limit 265 | return a bool array 266 | """ 267 | assert points.shape[1] == 4, points.shape # [N,3] 268 | x, y, z = points[:, 0], points[:, 1], points[:, 2] 269 | d = np.sqrt(x ** 2 + y ** 2 + z ** 2) # this is much faster than d = np.sqrt(np.power(points,2).sum(1)) 270 | 271 | # extract in-range fov points 272 | h_points = self.hv_in_range(x, y, self.h_fov, fov_type='h') 273 | v_points = self.hv_in_range(d, z, self.v_fov, fov_type='v') 274 | combined = np.logical_and(h_points, v_points) 275 | 276 | # extract in-range x,y,z points 277 | in_range = self.box_in_range(x,y,z,d, self.x_range, self.y_range, self.z_range, self.d_range) 278 | combined = np.logical_and(combined, in_range) 279 | 280 | return combined 281 | 282 | def calib_velo2cam(self, fn_v2c): 283 | """ 284 | get Rotation(R : 3x3), Translation(T : 3x1) matrix info 285 | using R,T matrix, we can convert velodyne coordinates to camera coordinates 286 | """ 287 | for line in open(fn_v2c, "r"): 288 | (key, val) = line.split(':', 1) 289 | if key == 'R': 290 | R = np.fromstring(val, sep=' ') 291 | R = R.reshape(3, 3) 292 | if key == 'T': 293 | T = np.fromstring(val, sep=' ') 294 | T = T.reshape(3, 1) 295 | return R, T 296 | 297 | def calib_cam2cam(self, fn_c2c, mode = '02'): 298 | """ 299 | If your image is 'rectified image' :get only Projection(P : 3x4) matrix is enough 300 | but if your image is 'distorted image'(not rectified image) : 301 | you need undistortion step using distortion coefficients(5 : D) 302 | In this code, only P matrix info is used for rectified image 303 | """ 304 | # with open(fn_c2c, "r") as f: c2c_file = f.readlines() 305 | for line in open(fn_c2c, "r"): 306 | (key, val) = line.split(':', 1) 307 | if key == ('P_rect_' + mode): 308 | P = np.fromstring(val, sep=' ') 309 | P = P.reshape(3, 4) 310 | P = P[:3, :3] # erase 4th column ([0,0,0]) 311 | return P 312 | 313 | def project_3d_to_2d(self, pts_3d): 314 | assert pts_3d.shape[1] == 3, pts_3d.shape 315 | pts_3d = pts_3d.copy() 316 | 317 | # Concat and change shape from [N,3] to [N,4] to [4,N] 318 | one_mat = np.ones((pts_3d.shape[0], 1),dtype=np.float32) 319 | xyz_v = np.concatenate((pts_3d, one_mat), axis=1).T 320 | 321 | # convert velodyne coordinates(X_v, Y_v, Z_v) to camera coordinates(X_c, Y_c, Z_c) 322 | for i in range(xyz_v.shape[1]): 323 | xyz_v[:3, i] = np.matmul(self.RT, xyz_v[:, i]) 324 | 325 | xyz_c = xyz_v[:3] 326 | 327 | # convert camera coordinates(X_c, Y_c, Z_c) image(pixel) coordinates(x,y) 328 | for i in range(xyz_c.shape[1]): 329 | xyz_c[:, i] = np.matmul(self.P, xyz_c[:, i]) 330 | 331 | # normalize image(pixel) coordinates(x,y) 332 | xy_i = xyz_c / xyz_c[2] 333 | 334 | # get pixels location 335 | pts_2d = xy_i[:2].T 336 | return pts_2d 337 | 338 | def torch_project_3d_to_2d(self, pts_3d): 339 | assert pts_3d.shape[1] == 3, pts_3d.shape 340 | pts_3d = pts_3d.copy() 341 | 342 | # Create a [N,1] array 343 | one_mat = np.ones((pts_3d.shape[0], 1),dtype=np.float32) 344 | xyz_v = np.concatenate((pts_3d, one_mat), axis=1) 345 | 346 | RT = torch.from_numpy(self.RT).float().cuda() 347 | P = torch.from_numpy(self.P).float().cuda() 348 | 
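# torch_project_3d_to_2d is the batched GPU counterpart of project_3d_to_2d above:
# each velodyne point [x, y, z] is lifted to homogeneous form [x, y, z, 1], mapped to
# camera coordinates with [R|T] (self.RT), projected with P (self.P), and finally
# normalised by depth, i.e. roughly
#   uv_hom = P @ (RT @ [x, y, z, 1]^T);  u = uv_hom[0] / uv_hom[2];  v = uv_hom[1] / uv_hom[2]
# The per-point loops above are replaced here by batched torch.bmm calls over the whole cloud.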
xyz_v = torch.from_numpy(xyz_v).float().cuda() 349 | 350 | assert xyz_v.size(1) == 4, xyz_v.size() 351 | 352 | xyz_v = xyz_v.unsqueeze(2) 353 | RT_rep = RT.expand(xyz_v.size(0),3,4) 354 | P_rep = P.expand(xyz_v.size(0),3,3) 355 | 356 | xyz_c = torch.bmm(RT_rep, xyz_v) 357 | #log.info(xyz_c.shape, RT_rep.shape, xyz_v.shape) 358 | 359 | xy_v = torch.bmm(P_rep, xyz_c) 360 | #log.msg(xy_v.shape, P_rep.shape, xyz_c.shape) 361 | 362 | xy_i = xy_v.squeeze(2).transpose(1,0) 363 | xy_n = xy_i / xy_i[2] 364 | pts_2d = (xy_n[:2]).transpose(1,0) 365 | 366 | return pts_2d.detach().cpu().numpy() 367 | 368 | def draw_2d_points(self, pts_2d, colors, image = None): 369 | """ draw 2d points in camera image """ 370 | assert pts_2d.shape[1] == 2, pts_2d.shape 371 | 372 | if image is None: 373 | image = self.frame.copy() 374 | pts = pts_2d.astype(np.int32).tolist() 375 | 376 | for (x,y),c in zip(pts, colors.tolist()): 377 | cv2.circle(image, (x, y), 2, c, -1) 378 | 379 | return image 380 | 381 | def draw_2d_top_view(self, pcd_3d, colors): 382 | """ draw 2d points in camera image """ 383 | assert pcd_3d.shape[1] == 3, pcd_3d.shape 384 | 385 | image = np.zeros((600,800,3),dtype=np.uint8) 386 | 387 | for (x,y,z),c in zip(pcd_3d.tolist(), colors.tolist()): 388 | X = int(-x*800+600) 389 | Y = int(-y*800+400) 390 | cv2.circle(image, (Y,X), 3, c, -1) 391 | 392 | return image 393 | 394 | def get_max_index(self,part): 395 | return self.length[part] -------------------------------------------------------------------------------- /data_utils/redis_utils.py: -------------------------------------------------------------------------------- 1 | import redis 2 | import numpy as np 3 | import struct 4 | from PIL import Image 5 | import io 6 | 7 | class Mat_Redis_Utils(): 8 | def __init__(self, host='127.0.0.1', port=6379, db=0): 9 | self.handle = redis.Redis(host, port, db) 10 | self.dtype_table = [ 11 | np.int8,np.int16,np.int32,np.int64, 12 | np.uint8,np.uint16,np.uint32,np.uint64, 13 | np.float16,np.float32,np.float64 14 | ] 15 | 16 | def mat_to_bytes(self, arr): 17 | dtype_id = self.dtype_table.index(arr.dtype) 18 | header = struct.pack('>'+'I' * (2+arr.ndim), dtype_id, arr.ndim, *arr.shape) 19 | data = header + arr.tobytes() 20 | return data 21 | 22 | def bytes_to_mat(self, data): 23 | dtype_id, ndim = struct.unpack('>II',data[:8]) 24 | dtype = self.dtype_table[dtype_id] 25 | shape = struct.unpack('>'+'I'*ndim, data[8:4*(2+ndim)]) 26 | arr = np.frombuffer(data[4*(2+ndim):], dtype=dtype, offset=0) 27 | arr = arr.reshape((shape)) 28 | return arr 29 | 30 | def set(self, key, arr): 31 | return self.handle.set(key, self.mat_to_bytes(arr)) 32 | 33 | def get(self, key, dtype = np.float32): 34 | data = self.handle.get(key) 35 | if data is None: 36 | raise ValueError('%s not exist in Redis'%(key)) 37 | return self.bytes_to_mat(data) 38 | 39 | def set_PIL(self, key, fn): 40 | return self.handle.set(key, open(fn, "rb").read()) 41 | 42 | def get_PIL(self, key): 43 | data = self.handle.get(key) 44 | if data is None: 45 | raise ValueError('%s not exist in Redis'%(key)) 46 | return Image.open(io.BytesIO(data)) 47 | 48 | def exists(self, key): 49 | return bool(self.handle.execute_command('EXISTS ' + key)) 50 | 51 | def ls_keys(self): 52 | return self.handle.execute_command('KEYS *') 53 | 54 | def flush_all(self): 55 | print('Del all keys in Redis') 56 | return self.handle.execute_command('flushall') -------------------------------------------------------------------------------- /model/chamfer.py: 
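chamfer.py implements a directed Chamfer distance: for every point in p1, the distance to its nearest neighbour in p2 is taken, and these minima are summed (chamfer_non_batch) or summed per cloud and averaged over the batch (chamfer_batch).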
-------------------------------------------------------------------------------- 1 | import torch 2 | import numpy as np 3 | 4 | def num(x): 5 | return x.detach().cpu().numpy() 6 | 7 | def chamfer_non_batch(p1, p2): 8 | ''' 9 | Calculate Chamfer Distance between two point sets 10 | :param p1: size[1, N, D] 11 | :param p2: size[1, M, D] 12 | :return: sum of Chamfer Distance of two point sets 13 | ''' 14 | assert p1.size(0) == 1 and p2.size(0) == 1 15 | assert p1.size(2) == p2.size(2) 16 | 17 | # print(p1.cpu().numpy().shape, p2.cpu().numpy().shape) 18 | p1 = p1.repeat(p2.size(1), 1, 1) 19 | # print('repeat p1', p1.cpu().numpy().shape) 20 | p1 = p1.transpose(0, 1) 21 | # print('transpose', p1.cpu().numpy().shape) 22 | p2 = p2.repeat(p1.size(0), 1, 1) 23 | # print('repeat p2',p2.cpu().numpy().shape) 24 | 25 | dist = torch.add(p1, torch.neg(p2)) 26 | dist = torch.norm(dist, 2, dim=2) 27 | dist = torch.min(dist, dim=1)[0] 28 | dist = torch.sum(dist) 29 | 30 | return dist 31 | 32 | def chamfer_batch(p1, p2): 33 | ''' 34 | Calculate Chamfer Distance between two point sets 35 | :param p1: size[B, N, D] 36 | :param p2: size[B, M, D] 37 | :return: sum of all batches of Chamfer Distance of two point sets 38 | ''' 39 | assert p1.size(0) == p2.size(0) and p1.size(2) == p2.size(2) 40 | 41 | p1,p2 = p1.unsqueeze(1), p2.unsqueeze(1) 42 | 43 | p1 = p1.repeat(1, p2.size(2), 1, 1) 44 | p1 = p1.transpose(1, 2) 45 | 46 | p2 = p2.repeat(1, p1.size(1), 1, 1) 47 | 48 | dist = torch.add(p1, torch.neg(p2)) 49 | dist = torch.norm(dist, 2, dim=3) 50 | dist = torch.min(dist, dim=2)[0] 51 | dist = torch.sum(dist)/p1.size(0) 52 | 53 | return dist 54 | 55 | if __name__ == '__main__': 56 | p1 = torch.from_numpy(np.array([[[1., 2, 3], [4, 5, 6], [3, 5, 6], [5, 6, 7]],[[2., 2, 3], [3, 5, 6], [4, 5, 6], [8, 6, 7]]])) 57 | p2 = torch.from_numpy(np.array([[[3., 7, 8], [1, 4, 5]],[[3., 8, 8], [2, 4, 5]]])) 58 | 59 | p1_1 = torch.from_numpy(np.array([[[1., 2, 3], [4, 5, 6], [3, 5, 6], [5, 6, 7]]])) 60 | p1_2 = torch.from_numpy(np.array([[[2., 2, 3], [3, 5, 6], [4, 5, 6], [8, 6, 7]]])) 61 | 62 | p2_1 = torch.from_numpy(np.array([[[3., 7, 8], [1, 4, 5]]])) 63 | p2_2 = torch.from_numpy(np.array([[[3., 8, 8], [2, 4, 5]]])) 64 | 65 | print('p1 size is {}, p2 size is {}'.format(p1.size(), p2.size())) 66 | print(chamfer_batch(p1, p2)) 67 | print((chamfer_batch(p1_1, p2_1)+chamfer_batch(p1_2, p2_2))/2) 68 | -------------------------------------------------------------------------------- /model/pointnet.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.parallel 4 | import torch.utils.data 5 | from torch.autograd import Variable 6 | import numpy as np 7 | import torch.nn.functional as F 8 | 9 | 10 | class STN3d(nn.Module): 11 | def __init__(self): 12 | super(STN3d, self).__init__() 13 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 14 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 15 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 16 | self.fc1 = nn.Linear(1024, 512) 17 | self.fc2 = nn.Linear(512, 256) 18 | self.fc3 = nn.Linear(256, 9) 19 | self.relu = nn.ReLU() 20 | 21 | self.bn1 = nn.BatchNorm1d(64) 22 | self.bn2 = nn.BatchNorm1d(128) 23 | self.bn3 = nn.BatchNorm1d(1024) 24 | self.bn4 = nn.BatchNorm1d(512) 25 | self.bn5 = nn.BatchNorm1d(256) 26 | 27 | def forward(self, x): 28 | batchsize = x.size()[0] 29 | x = F.relu(self.bn1(self.conv1(x))) 30 | x = F.relu(self.bn2(self.conv2(x))) 31 | x = F.relu(self.bn3(self.conv3(x))) 32 | x = torch.max(x, 2, 
keepdim=True)[0] 33 | x = x.view(-1, 1024) 34 | 35 | x = F.relu(self.bn4(self.fc1(x))) 36 | x = F.relu(self.bn5(self.fc2(x))) 37 | x = self.fc3(x) 38 | 39 | iden = Variable(torch.from_numpy(np.array([1, 0, 0, 0, 1, 0, 0, 0, 1]).astype(np.float32))).view(1, 9).repeat( 40 | batchsize, 1) 41 | if x.is_cuda: 42 | iden = iden.cuda() 43 | x = x + iden 44 | x = x.view(-1, 3, 3) 45 | return x 46 | 47 | class STNkd(nn.Module): 48 | def __init__(self, k=64): 49 | super(STNkd, self).__init__() 50 | self.conv1 = torch.nn.Conv1d(k, 64, 1) 51 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 52 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 53 | self.fc1 = nn.Linear(1024, 512) 54 | self.fc2 = nn.Linear(512, 256) 55 | self.fc3 = nn.Linear(256, k * k) 56 | self.relu = nn.ReLU() 57 | 58 | self.bn1 = nn.BatchNorm1d(64) 59 | self.bn2 = nn.BatchNorm1d(128) 60 | self.bn3 = nn.BatchNorm1d(1024) 61 | self.bn4 = nn.BatchNorm1d(512) 62 | self.bn5 = nn.BatchNorm1d(256) 63 | 64 | self.k = k 65 | 66 | def forward(self, x): 67 | batchsize = x.size()[0] 68 | x = F.relu(self.bn1(self.conv1(x))) 69 | x = F.relu(self.bn2(self.conv2(x))) 70 | x = F.relu(self.bn3(self.conv3(x))) 71 | x = torch.max(x, 2, keepdim=True)[0] 72 | x = x.view(-1, 1024) 73 | 74 | x = F.relu(self.bn4(self.fc1(x))) 75 | x = F.relu(self.bn5(self.fc2(x))) 76 | x = self.fc3(x) 77 | 78 | eye_mat = np.eye(self.k).flatten().astype(np.float32) 79 | iden = Variable(torch.from_numpy(eye_mat)).view(1, self.k * self.k).repeat(batchsize, 1) 80 | if x.is_cuda: 81 | iden = iden.cuda() 82 | x = x + iden 83 | x = x.view(-1, self.k, self.k) 84 | return x 85 | 86 | class PointNetEncoder(nn.Module): 87 | def __init__(self, global_feat=True, input_dims = 4, feature_transform=False): 88 | super(PointNetEncoder, self).__init__() 89 | self.stn = STNkd(k = input_dims) 90 | self.conv1 = torch.nn.Conv1d(input_dims, 64, 1) 91 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 92 | self.conv3 = torch.nn.Conv1d(128, 1024, 1) 93 | self.bn1 = nn.BatchNorm1d(64) 94 | self.bn2 = nn.BatchNorm1d(128) 95 | self.bn3 = nn.BatchNorm1d(1024) 96 | self.global_feat = global_feat 97 | self.feature_transform = feature_transform 98 | if self.feature_transform: 99 | self.fstn = STNkd(k=64) 100 | 101 | def forward(self, x): 102 | # x -> [batch, channels, n_pts] 103 | n_pts = x.size()[2] 104 | trans = self.stn(x) 105 | x = x.transpose(2, 1) 106 | x = torch.bmm(x, trans) 107 | x = x.transpose(2, 1) 108 | x = F.relu(self.bn1(self.conv1(x))) 109 | 110 | if self.feature_transform: 111 | trans_feat = self.fstn(x) 112 | x = x.transpose(2, 1) 113 | x = torch.bmm(x, trans_feat) 114 | x = x.transpose(2, 1) 115 | else: 116 | trans_feat = None 117 | 118 | pointfeat = x 119 | x = F.relu(self.bn2(self.conv2(x))) 120 | x = self.bn3(self.conv3(x)) 121 | #print('before max', x.shape) 122 | x = torch.max(x, 2, keepdim=True)[0] 123 | x = x.view(-1, 1024) 124 | #print('after max', x.shape) 125 | if self.global_feat: 126 | return x, trans, trans_feat 127 | else: 128 | x = x.view(-1, 1024, 1).repeat(1, 1, n_pts) 129 | #print('before return', x.shape, pointfeat.shape) 130 | #print('return', torch.cat([x, pointfeat], 1).shape) 131 | return torch.cat([x, pointfeat], 1), trans, trans_feat 132 | 133 | class PointNetCls(nn.Module): 134 | def __init__(self, k=2, feature_transform=False): 135 | super(PointNetCls, self).__init__() 136 | self.feature_transform = feature_transform 137 | self.feat = PointNetEncoder(global_feat=True, feature_transform=feature_transform, input_dims=3) 138 | self.fc1 = nn.Linear(1024, 512) 139 | self.fc2 = 
nn.Linear(512, 256) 140 | self.fc3 = nn.Linear(256, k) 141 | self.dropout = nn.Dropout(p=0.3) 142 | self.bn1 = nn.BatchNorm1d(512) 143 | self.bn2 = nn.BatchNorm1d(256) 144 | self.relu = nn.ReLU() 145 | 146 | def forward(self, x): 147 | x, trans, trans_feat = self.feat(x) 148 | x = F.relu(self.bn1(self.fc1(x))) 149 | x = F.relu(self.bn2(self.dropout(self.fc2(x)))) 150 | x = self.fc3(x) 151 | return F.log_softmax(x, dim=1), trans_feat 152 | 153 | class PointNetDenseCls(nn.Module): 154 | def __init__(self, cat_num=16,part_num=50): 155 | super(PointNetDenseCls, self).__init__() 156 | self.cat_num = cat_num 157 | self.part_num = part_num 158 | self.stn = STN3d() 159 | self.conv1 = torch.nn.Conv1d(3, 64, 1) 160 | self.conv2 = torch.nn.Conv1d(64, 128, 1) 161 | self.conv3 = torch.nn.Conv1d(128, 128, 1) 162 | self.conv4 = torch.nn.Conv1d(128, 512, 1) 163 | self.conv5 = torch.nn.Conv1d(512, 2048, 1) 164 | self.bn1 = nn.BatchNorm1d(64) 165 | self.bn2 = nn.BatchNorm1d(128) 166 | self.bn3 = nn.BatchNorm1d(128) 167 | self.bn4 = nn.BatchNorm1d(512) 168 | self.bn5 = nn.BatchNorm1d(2048) 169 | self.fstn = STNkd(k=128) 170 | # classification network 171 | self.fc1 = nn.Linear(2048, 256) 172 | self.fc2 = nn.Linear(256, 256) 173 | self.fc3 = nn.Linear(256, cat_num) 174 | self.dropout = nn.Dropout(p=0.3) 175 | self.bnc1 = nn.BatchNorm1d(256) 176 | self.bnc2 = nn.BatchNorm1d(256) 177 | # segmentation network 178 | self.convs1 = torch.nn.Conv1d(4944, 256, 1) 179 | self.convs2 = torch.nn.Conv1d(256, 256, 1) 180 | self.convs3 = torch.nn.Conv1d(256, 128, 1) 181 | self.convs4 = torch.nn.Conv1d(128, part_num, 1) 182 | self.bns1 = nn.BatchNorm1d(256) 183 | self.bns2 = nn.BatchNorm1d(256) 184 | self.bns3 = nn.BatchNorm1d(128) 185 | 186 | def forward(self, point_cloud, label): 187 | batchsize,_ , n_pts = point_cloud.size() 188 | # point_cloud_transformed 189 | trans = self.stn(point_cloud) 190 | point_cloud = point_cloud.transpose(2, 1) 191 | point_cloud_transformed = torch.bmm(point_cloud, trans) 192 | point_cloud_transformed = point_cloud_transformed.transpose(2, 1) 193 | 194 | # MLP 195 | out1 = F.relu(self.bn1(self.conv1(point_cloud_transformed))) 196 | out2 = F.relu(self.bn2(self.conv2(out1))) 197 | out3 = F.relu(self.bn3(self.conv3(out2))) 198 | 199 | # net_transformed 200 | trans_feat = self.fstn(out3) 201 | x = out3.transpose(2, 1) 202 | net_transformed = torch.bmm(x, trans_feat) 203 | net_transformed = net_transformed.transpose(2, 1) 204 | 205 | # MLP 206 | out4 = F.relu(self.bn4(self.conv4(net_transformed))) 207 | out5 = self.bn5(self.conv5(out4)) 208 | out_max = torch.max(out5, 2, keepdim=True)[0] 209 | out_max = out_max.view(-1, 2048) 210 | 211 | # classification network 212 | net = F.relu(self.bnc1(self.fc1(out_max))) 213 | net = F.relu(self.bnc2(self.dropout(self.fc2(net)))) 214 | net = self.fc3(net) # [B,16] 215 | 216 | # segmentation network 217 | out_max = torch.cat([out_max, label],1) 218 | expand = out_max.view(-1, 2048+16, 1).repeat(1, 1, n_pts) 219 | concat = torch.cat([expand, out1, out2, out3, out4, out5], 1) 220 | net2 = F.relu(self.bns1(self.convs1(concat))) 221 | net2 = F.relu(self.bns2(self.convs2(net2))) 222 | net2 = F.relu(self.bns3(self.convs3(net2))) 223 | net2 = self.convs4(net2) 224 | net2 = net2.transpose(2, 1).contiguous() 225 | net2 = F.log_softmax(net2.view(-1, self.part_num), dim=-1) 226 | net2 = net2.view(batchsize, n_pts, self.part_num) # [B, N 50] 227 | 228 | return net, net2, trans_feat 229 | 230 | class PointNetSeg(nn.Module): 231 | def __init__(self,num_class, input_dims=4, 
feature_transform=False): 232 | super(PointNetSeg, self).__init__() 233 | self.k = num_class 234 | self.feat = PointNetEncoder(global_feat=False,input_dims = input_dims, feature_transform=feature_transform) 235 | self.conv1 = torch.nn.Conv1d(1088, 512, 1) 236 | self.conv2 = torch.nn.Conv1d(512, 256, 1) 237 | self.conv3 = torch.nn.Conv1d(256, 128, 1) 238 | self.conv4 = torch.nn.Conv1d(128, self.k, 1) 239 | self.bn1 = nn.BatchNorm1d(512) 240 | self.bn2 = nn.BatchNorm1d(256) 241 | self.bn3 = nn.BatchNorm1d(128) 242 | 243 | def forward(self, x): 244 | batchsize = x.size()[0] 245 | n_pts = x.size()[2] 246 | x, trans, trans_feat = self.feat(x) 247 | x = F.relu(self.bn1(self.conv1(x))) 248 | x = F.relu(self.bn2(self.conv2(x))) 249 | x = F.relu(self.bn3(self.conv3(x))) 250 | x = self.conv4(x) 251 | x = x.transpose(2,1).contiguous() 252 | x = F.log_softmax(x.view(-1,self.k), dim=-1) 253 | x = x.view(batchsize, n_pts, self.k) 254 | return x, trans_feat 255 | 256 | 257 | def feature_transform_reguliarzer(trans): 258 | d = trans.size()[1] 259 | I = torch.eye(d)[None, :, :] 260 | if trans.is_cuda: 261 | I = I.cuda() 262 | loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1) - I), dim=(1, 2))) 263 | return loss 264 | 265 | 266 | class PointNetLoss(torch.nn.Module): 267 | def __init__(self, weight=1,mat_diff_loss_scale=0.001): 268 | super(PointNetLoss, self).__init__() 269 | self.mat_diff_loss_scale = mat_diff_loss_scale 270 | self.weight = weight 271 | 272 | def forward(self, labels_pred, label, seg_pred, seg, trans_feat): 273 | seg_loss = F.nll_loss(seg_pred, seg) 274 | mat_diff_loss = feature_transform_reguliarzer(trans_feat) 275 | label_loss = F.nll_loss(labels_pred, label) 276 | 277 | loss = self.weight * seg_loss + (1-self.weight) * label_loss + mat_diff_loss * self.mat_diff_loss_scale 278 | return loss, seg_loss, label_loss 279 | 280 | 281 | if __name__ == '__main__': 282 | import os 283 | os.environ["CUDA_VISIBLE_DEVICES"] = '2' 284 | points = torch.randn(2,4,7000) 285 | model = PointNetColorGen(3, input_dims = 4, feature_transform=True) 286 | pred, trans_feat = model(points) 287 | print(pred.shape,trans_feat.shape) 288 | -------------------------------------------------------------------------------- /model/pointnet2.py: -------------------------------------------------------------------------------- 1 | import torch.nn as nn 2 | import torch 3 | import numpy as np 4 | import torch.nn.functional as F 5 | from .pointnet_util import PointNetSetAbstractionMsg,PointNetSetAbstraction,PointNetFeaturePropagation 6 | 7 | class PointNet2ClsMsg(nn.Module): 8 | def __init__(self): 9 | super(PointNet2ClsMsg, self).__init__() 10 | self.sa1 = PointNetSetAbstractionMsg( 11 | 512, [0.1, 0.2, 0.4], [16, 32, 128], 0, 12 | [ 13 | [32, 32, 64], 14 | [64, 64, 128], 15 | [64, 96, 128] 16 | ] 17 | ) 18 | self.sa2 = PointNetSetAbstractionMsg( 19 | 128, [0.2, 0.4, 0.8], [32, 64, 128], 320, 20 | [ 21 | [64, 64, 128], 22 | [128, 128, 256], 23 | [128, 128, 256] 24 | ] 25 | ) 26 | self.sa3 = PointNetSetAbstraction( 27 | None, None, None, 640 + 3, [256, 512, 1024], True 28 | ) 29 | self.fc1 = nn.Linear(1024, 512) 30 | self.bn1 = nn.BatchNorm1d(512) 31 | self.drop1 = nn.Dropout(0.4) 32 | self.fc2 = nn.Linear(512, 256) 33 | self.bn2 = nn.BatchNorm1d(256) 34 | self.drop2 = nn.Dropout(0.4) 35 | self.fc3 = nn.Linear(256, 40) 36 | 37 | def forward(self, xyz): 38 | B, _, _ = xyz.shape 39 | l1_xyz, l1_points = self.sa1(xyz, None) 40 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 41 | l3_xyz, l3_points = 
self.sa3(l2_xyz, l2_points) 42 | x = l3_points.view(B, 1024) 43 | x = self.drop1(F.relu(self.bn1(self.fc1(x)))) 44 | x = self.drop2(F.relu(self.bn2(self.fc2(x)))) 45 | x = self.fc3(x) 46 | x = F.log_softmax(x, -1) 47 | return x,l3_points 48 | 49 | class PointNet2ClsSsg(nn.Module): 50 | def __init__(self): 51 | super(PointNet2ClsSsg, self).__init__() 52 | self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=32, in_channel=3, mlp=[64, 64, 128], group_all=False) 53 | self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False) 54 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True) 55 | self.fc1 = nn.Linear(1024, 512) 56 | self.bn1 = nn.BatchNorm1d(512) 57 | self.drop1 = nn.Dropout(0.4) 58 | self.fc2 = nn.Linear(512, 256) 59 | self.bn2 = nn.BatchNorm1d(256) 60 | self.drop2 = nn.Dropout(0.4) 61 | self.fc3 = nn.Linear(256, 40) 62 | 63 | def forward(self, xyz): 64 | B, _, _ = xyz.shape 65 | l1_xyz, l1_points = self.sa1(xyz, None) 66 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 67 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 68 | x = l3_points.view(B, 1024) 69 | x = self.drop1(F.relu(self.bn1(self.fc1(x)))) 70 | x = self.drop2(F.relu(self.bn2(self.fc2(x)))) 71 | x = self.fc3(x) 72 | x = F.log_softmax(x, -1) 73 | return x 74 | 75 | class PointNet2PartSegSsg(nn.Module): #TODO part segmentation tasks 76 | def __init__(self, num_classes): 77 | super(PointNet2PartSegSsg, self).__init__() 78 | self.sa1 = PointNetSetAbstraction(npoint=512, radius=0.2, nsample=64, in_channel=3, mlp=[64, 64, 128], group_all=False) 79 | self.sa2 = PointNetSetAbstraction(npoint=128, radius=0.4, nsample=64, in_channel=128 + 3, mlp=[128, 128, 256], group_all=False) 80 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=256 + 3, mlp=[256, 512, 1024], group_all=True) 81 | self.fp3 = PointNetFeaturePropagation(in_channel=1280, mlp=[256, 256]) 82 | self.fp2 = PointNetFeaturePropagation(in_channel=384, mlp=[256, 128]) 83 | self.fp1 = PointNetFeaturePropagation(in_channel=128, mlp=[128, 128, 128]) 84 | self.conv1 = nn.Conv1d(128, 128, 1) 85 | self.bn1 = nn.BatchNorm1d(128) 86 | self.drop1 = nn.Dropout(0.5) 87 | self.conv2 = nn.Conv1d(128, num_classes, 1) 88 | 89 | def forward(self, xyz): 90 | # Set Abstraction layers 91 | l1_xyz, l1_points = self.sa1(xyz, None) 92 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 93 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 94 | # Feature Propagation layers 95 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) 96 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 97 | l0_points = self.fp1(xyz, l1_xyz, None, l1_points) 98 | # FC layers 99 | feat = F.relu(self.bn1(self.conv1(l0_points))) 100 | x = self.drop1(feat) 101 | x = self.conv2(x) 102 | x = F.log_softmax(x, dim=1) 103 | x = x.permute(0, 2, 1) 104 | return x, feat 105 | 106 | class PointNet2PartSegMsg_one_hot(nn.Module): 107 | def __init__(self, num_classes): 108 | super(PointNet2PartSegMsg_one_hot, self).__init__() 109 | self.sa1 = PointNetSetAbstractionMsg(512, [0.1, 0.2, 0.4], [32, 64, 128], 0+3, [[32, 32, 64], [64, 64, 128], [64, 96, 128]]) 110 | self.sa2 = PointNetSetAbstractionMsg(128, [0.4,0.8], [64, 128], 128+128+64, [[128, 128, 256], [128, 196, 256]]) 111 | self.sa3 = PointNetSetAbstraction(npoint=None, radius=None, nsample=None, in_channel=512 + 3, mlp=[256, 512, 1024], group_all=True) 112 | self.fp3 = 
PointNetFeaturePropagation(in_channel=1536, mlp=[256, 256]) 113 | self.fp2 = PointNetFeaturePropagation(in_channel=576, mlp=[256, 128]) 114 | self.fp1 = PointNetFeaturePropagation(in_channel=150, mlp=[128, 128]) 115 | self.conv1 = nn.Conv1d(128, 128, 1) 116 | self.bn1 = nn.BatchNorm1d(128) 117 | self.drop1 = nn.Dropout(0.5) 118 | self.conv2 = nn.Conv1d(128, num_classes, 1) 119 | 120 | def forward(self, xyz, norm_plt, cls_label): 121 | # Set Abstraction layers 122 | B,C,N = xyz.size() 123 | l0_xyz = xyz 124 | l0_points = norm_plt 125 | l1_xyz, l1_points = self.sa1(l0_xyz, norm_plt) 126 | l2_xyz, l2_points = self.sa2(l1_xyz, l1_points) 127 | l3_xyz, l3_points = self.sa3(l2_xyz, l2_points) 128 | # Feature Propagation layers 129 | l2_points = self.fp3(l2_xyz, l3_xyz, l2_points, l3_points) 130 | l1_points = self.fp2(l1_xyz, l2_xyz, l1_points, l2_points) 131 | cls_label_one_hot = cls_label.view(B,16,1).repeat(1,1,N) 132 | l0_points = self.fp1(l0_xyz, l1_xyz, torch.cat([cls_label_one_hot,l0_xyz,l0_points],1), l1_points) 133 | # FC layers 134 | feat = F.relu(self.bn1(self.conv1(l0_points))) 135 | x = self.drop1(feat) 136 | x = self.conv2(x) 137 | x = F.log_softmax(x, dim=1) 138 | x = x.permute(0, 2, 1) 139 | return x 140 | 141 | class PointNet2SemSeg(nn.Module): 142 | def __init__(self, num_classes, feature_dims = 3): 143 | super(PointNet2SemSeg, self).__init__() 144 | self.feature_dims = feature_dims 145 | self.sa1 = PointNetSetAbstraction(1024, 0.1, 32, feature_dims + 3, [32, 32, 64], False) 146 | self.sa2 = PointNetSetAbstraction(256, 0.2, 32, 64 + 3, [64, 64, 128], False) 147 | self.sa3 = PointNetSetAbstraction(64, 0.4, 32, 128 + 3, [128, 128, 256], False) 148 | self.sa4 = PointNetSetAbstraction(16, 0.8, 32, 256 + 3, [256, 256, 512], False) 149 | 150 | self.fp4 = PointNetFeaturePropagation(768, [256, 256]) 151 | self.fp3 = PointNetFeaturePropagation(384, [256, 256]) 152 | self.fp2 = PointNetFeaturePropagation(320, [256, 128]) 153 | self.fp1 = PointNetFeaturePropagation(128, [128, 128, 128]) 154 | self.conv1 = nn.Conv1d(128, 128, 1) 155 | self.bn1 = nn.BatchNorm1d(128) 156 | self.drop1 = nn.Dropout(0.5) 157 | self.conv2 = nn.Conv1d(128, num_classes, 1) 158 | 159 | def forward(self, points): 160 | xyz,feature = points[:,:3,:],points[:,3:,:] 161 | 162 | l1_xyz, l1_feature = self.sa1(xyz, feature) 163 | l2_xyz, l2_feature = self.sa2(l1_xyz, l1_feature) 164 | l3_xyz, l3_feature = self.sa3(l2_xyz, l2_feature) 165 | l4_xyz, l4_feature = self.sa4(l3_xyz, l3_feature) 166 | 167 | l3_feature = self.fp4(l3_xyz, l4_xyz, l3_feature, l4_feature) 168 | l2_feature = self.fp3(l2_xyz, l3_xyz, l2_feature, l3_feature) 169 | l1_feature = self.fp2(l1_xyz, l2_xyz, l1_feature, l2_feature) 170 | l0_feature = self.fp1(xyz, l1_xyz, None, l1_feature) 171 | 172 | x = self.drop1(F.relu(self.bn1(self.conv1(l0_feature)))) 173 | x = self.conv2(x) 174 | x = F.log_softmax(x, dim=1) 175 | x = x.permute(0, 2, 1) 176 | return x 177 | 178 | 179 | if __name__ == '__main__': 180 | import os 181 | import torch 182 | os.environ["CUDA_VISIBLE_DEVICES"] = '0' 183 | input = torch.randn((8,3,2048)) 184 | label = torch.randn(8,16) 185 | model = PointNet2PartSegMsg_one_hot(num_classes=50) 186 | output= model(input,input,label) 187 | print(output.size()) 188 | 189 | -------------------------------------------------------------------------------- /model/pointnet_util.py: -------------------------------------------------------------------------------- 1 | import torch 2 | import torch.nn as nn 3 | import torch.nn.functional as F 4 | from time 
import time 5 | import numpy as np 6 | 7 | def timeit(tag, t): 8 | print("{}: {}s".format(tag, time() - t)) 9 | return time() 10 | 11 | def pc_normalize(pc): 12 | l = pc.shape[0] 13 | centroid = np.mean(pc, axis=0) 14 | pc = pc - centroid 15 | m = np.max(np.sqrt(np.sum(pc**2, axis=1))) 16 | pc = pc / m 17 | return pc 18 | 19 | def square_distance(src, dst): 20 | """ 21 | Calculate Euclid distance between each two points. 22 | 23 | src^T * dst = xn * xm + yn * ym + zn * zm; 24 | sum(src^2, dim=-1) = xn*xn + yn*yn + zn*zn; 25 | sum(dst^2, dim=-1) = xm*xm + ym*ym + zm*zm; 26 | dist = (xn - xm)^2 + (yn - ym)^2 + (zn - zm)^2 27 | = sum(src**2,dim=-1)+sum(dst**2,dim=-1)-2*src^T*dst 28 | 29 | Input: 30 | src: source points, [B, N, C] 31 | dst: target points, [B, M, C] 32 | Output: 33 | dist: per-point square distance, [B, N, M] 34 | """ 35 | B, N, _ = src.shape 36 | _, M, _ = dst.shape 37 | dist = -2 * torch.matmul(src, dst.permute(0, 2, 1)) 38 | dist += torch.sum(src ** 2, -1).view(B, N, 1) 39 | dist += torch.sum(dst ** 2, -1).view(B, 1, M) 40 | return dist 41 | 42 | 43 | def index_points(points, idx): 44 | """ 45 | 46 | Input: 47 | points: input points data, [B, N, C] 48 | idx: sample index data, [B, S] 49 | Return: 50 | new_points:, indexed points data, [B, S, C] 51 | """ 52 | device = points.device 53 | B = points.shape[0] 54 | view_shape = list(idx.shape) 55 | view_shape[1:] = [1] * (len(view_shape) - 1) 56 | repeat_shape = list(idx.shape) 57 | repeat_shape[0] = 1 58 | batch_indices = torch.arange(B, dtype=torch.long).to(device).view(view_shape).repeat(repeat_shape) 59 | new_points = points[batch_indices, idx, :] 60 | return new_points 61 | 62 | 63 | def farthest_point_sample(xyz, npoint): 64 | """ 65 | Input: 66 | xyz: pointcloud data, [B, N, C] 67 | npoint: number of samples 68 | Return: 69 | centroids: sampled pointcloud index, [B, npoint] 70 | """ 71 | device = xyz.device 72 | B, N, C = xyz.shape 73 | centroids = torch.zeros(B, npoint, dtype=torch.long).to(device) 74 | distance = torch.ones(B, N).to(device) * 1e10 75 | farthest = torch.randint(0, N, (B,), dtype=torch.long).to(device) 76 | batch_indices = torch.arange(B, dtype=torch.long).to(device) 77 | for i in range(npoint): 78 | centroids[:, i] = farthest 79 | centroid = xyz[batch_indices, farthest, :].view(B, 1, 3) 80 | dist = torch.sum((xyz - centroid) ** 2, -1) 81 | mask = dist < distance 82 | distance[mask] = dist[mask] 83 | farthest = torch.max(distance, -1)[1] 84 | return centroids 85 | 86 | 87 | def query_ball_point(radius, nsample, xyz, new_xyz): 88 | """ 89 | Input: 90 | radius: local region radius 91 | nsample: max sample number in local region 92 | xyz: all points, [B, N, C] 93 | new_xyz: query points, [B, S, C] 94 | Return: 95 | group_idx: grouped points index, [B, S, nsample] 96 | """ 97 | device = xyz.device 98 | B, N, C = xyz.shape 99 | _, S, _ = new_xyz.shape 100 | group_idx = torch.arange(N, dtype=torch.long).to(device).view(1, 1, N).repeat([B, S, 1]) 101 | sqrdists = square_distance(new_xyz, xyz) 102 | group_idx[sqrdists > radius ** 2] = N 103 | group_idx = group_idx.sort(dim=-1)[0][:, :, :nsample] 104 | group_first = group_idx[:, :, 0].view(B, S, 1).repeat([1, 1, nsample]) 105 | mask = group_idx == N 106 | group_idx[mask] = group_first[mask] 107 | return group_idx 108 | 109 | 110 | def sample_and_group(npoint, radius, nsample, xyz, points, returnfps=False): 111 | """ 112 | Input: 113 | npoint: 114 | radius: 115 | nsample: 116 | xyz: input points position data, [B, N, C] 117 | points: input points data, [B, N, D] 118 
| Return: 119 | new_xyz: sampled points position data, [B, 1, C] 120 | new_points: sampled points data, [B, 1, N, C+D] 121 | """ 122 | B, N, C = xyz.shape 123 | S = npoint 124 | fps_idx = farthest_point_sample(xyz, npoint) # [B, npoint, C] 125 | new_xyz = index_points(xyz, fps_idx) 126 | idx = query_ball_point(radius, nsample, xyz, new_xyz) 127 | grouped_xyz = index_points(xyz, idx) # [B, npoint, nsample, C] 128 | grouped_xyz_norm = grouped_xyz - new_xyz.view(B, S, 1, C) 129 | if points is not None: 130 | grouped_points = index_points(points, idx) 131 | new_points = torch.cat([grouped_xyz_norm, grouped_points], dim=-1) # [B, npoint, nsample, C+D] 132 | else: 133 | new_points = grouped_xyz_norm 134 | if returnfps: 135 | return new_xyz, new_points, grouped_xyz, fps_idx 136 | else: 137 | return new_xyz, new_points 138 | 139 | 140 | def sample_and_group_all(xyz, points): 141 | """ 142 | Input: 143 | xyz: input points position data, [B, N, C] 144 | points: input points data, [B, N, D] 145 | Return: 146 | new_xyz: sampled points position data, [B, 1, C] 147 | new_points: sampled points data, [B, 1, N, C+D] 148 | """ 149 | device = xyz.device 150 | B, N, C = xyz.shape 151 | new_xyz = torch.zeros(B, 1, C).to(device) 152 | grouped_xyz = xyz.view(B, 1, N, C) 153 | if points is not None: 154 | new_points = torch.cat([grouped_xyz, points.view(B, 1, N, -1)], dim=-1) 155 | else: 156 | new_points = grouped_xyz 157 | return new_xyz, new_points 158 | 159 | 160 | class PointNetSetAbstraction(nn.Module): 161 | def __init__(self, npoint, radius, nsample, in_channel, mlp, group_all): 162 | super(PointNetSetAbstraction, self).__init__() 163 | self.npoint = npoint 164 | self.radius = radius 165 | self.nsample = nsample 166 | self.mlp_convs = nn.ModuleList() 167 | self.mlp_bns = nn.ModuleList() 168 | last_channel = in_channel 169 | for out_channel in mlp: 170 | self.mlp_convs.append(nn.Conv2d(last_channel, out_channel, 1)) 171 | self.mlp_bns.append(nn.BatchNorm2d(out_channel)) 172 | last_channel = out_channel 173 | self.group_all = group_all 174 | 175 | def forward(self, xyz, points): 176 | """ 177 | Input: 178 | xyz: input points position data, [B, C, N] 179 | points: input points data, [B, D, N] 180 | Return: 181 | new_xyz: sampled points position data, [B, C, S] 182 | new_points_concat: sample points feature data, [B, D', S] 183 | """ 184 | xyz = xyz.permute(0, 2, 1) 185 | if points is not None: 186 | points = points.permute(0, 2, 1) 187 | 188 | if self.group_all: 189 | new_xyz, new_points = sample_and_group_all(xyz, points) 190 | else: 191 | new_xyz, new_points = sample_and_group(self.npoint, self.radius, self.nsample, xyz, points) 192 | # new_xyz: sampled points position data, [B, npoint, C] 193 | # new_points: sampled points data, [B, npoint, nsample, C+D] 194 | new_points = new_points.permute(0, 3, 2, 1) # [B, C+D, nsample,npoint] 195 | for i, conv in enumerate(self.mlp_convs): 196 | bn = self.mlp_bns[i] 197 | new_points = F.relu(bn(conv(new_points))) 198 | 199 | new_points = torch.max(new_points, 2)[0] 200 | new_xyz = new_xyz.permute(0, 2, 1) 201 | return new_xyz, new_points 202 | 203 | 204 | class PointNetSetAbstractionMsg(nn.Module): 205 | def __init__(self, npoint, radius_list, nsample_list, in_channel, mlp_list): 206 | super(PointNetSetAbstractionMsg, self).__init__() 207 | self.npoint = npoint 208 | self.radius_list = radius_list 209 | self.nsample_list = nsample_list 210 | self.conv_blocks = nn.ModuleList() 211 | self.bn_blocks = nn.ModuleList() 212 | for i in range(len(mlp_list)): 213 | convs = 
nn.ModuleList() 214 | bns = nn.ModuleList() 215 | last_channel = in_channel + 3 216 | for out_channel in mlp_list[i]: 217 | convs.append(nn.Conv2d(last_channel, out_channel, 1)) 218 | bns.append(nn.BatchNorm2d(out_channel)) 219 | last_channel = out_channel 220 | self.conv_blocks.append(convs) 221 | self.bn_blocks.append(bns) 222 | 223 | def forward(self, xyz, points): 224 | """ 225 | Input: 226 | xyz: input points position data, [B, C, N] 227 | points: input points data, [B, D, N] 228 | Return: 229 | new_xyz: sampled points position data, [B, C, S] 230 | new_points_concat: sample points feature data, [B, D', S] 231 | """ 232 | xyz = xyz.permute(0, 2, 1) 233 | if points is not None: 234 | points = points.permute(0, 2, 1) 235 | 236 | B, N, C = xyz.shape 237 | S = self.npoint 238 | new_xyz = index_points(xyz, farthest_point_sample(xyz, S)) 239 | new_points_list = [] 240 | for i, radius in enumerate(self.radius_list): 241 | K = self.nsample_list[i] 242 | group_idx = query_ball_point(radius, K, xyz, new_xyz) 243 | grouped_xyz = index_points(xyz, group_idx) 244 | grouped_xyz -= new_xyz.view(B, S, 1, C) 245 | if points is not None: 246 | grouped_points = index_points(points, group_idx) 247 | grouped_points = torch.cat([grouped_points, grouped_xyz], dim=-1) 248 | else: 249 | grouped_points = grouped_xyz 250 | 251 | grouped_points = grouped_points.permute(0, 3, 2, 1) # [B, D, K, S] 252 | for j in range(len(self.conv_blocks[i])): 253 | conv = self.conv_blocks[i][j] 254 | bn = self.bn_blocks[i][j] 255 | grouped_points = F.relu(bn(conv(grouped_points))) 256 | new_points = torch.max(grouped_points, 2)[0] # [B, D', S] 257 | new_points_list.append(new_points) 258 | 259 | new_xyz = new_xyz.permute(0, 2, 1) 260 | new_points_concat = torch.cat(new_points_list, dim=1) 261 | return new_xyz, new_points_concat 262 | 263 | 264 | class PointNetFeaturePropagation(nn.Module): 265 | def __init__(self, in_channel, mlp): 266 | super(PointNetFeaturePropagation, self).__init__() 267 | self.mlp_convs = nn.ModuleList() 268 | self.mlp_bns = nn.ModuleList() 269 | last_channel = in_channel 270 | for out_channel in mlp: 271 | self.mlp_convs.append(nn.Conv1d(last_channel, out_channel, 1)) 272 | self.mlp_bns.append(nn.BatchNorm1d(out_channel)) 273 | last_channel = out_channel 274 | 275 | def forward(self, xyz1, xyz2, points1, points2): 276 | """ 277 | Input: 278 | xyz1: input points position data, [B, C, N] 279 | xyz2: sampled input points position data, [B, C, S] 280 | points1: input points data, [B, D, N] 281 | points2: input points data, [B, D, S] 282 | Return: 283 | new_points: upsampled points data, [B, D', N] 284 | """ 285 | xyz1 = xyz1.permute(0, 2, 1) 286 | xyz2 = xyz2.permute(0, 2, 1) 287 | 288 | points2 = points2.permute(0, 2, 1) 289 | B, N, C = xyz1.shape 290 | _, S, _ = xyz2.shape 291 | 292 | if S == 1: 293 | interpolated_points = points2.repeat(1, N, 1) 294 | else: 295 | dists = square_distance(xyz1, xyz2) 296 | dists, idx = dists.sort(dim=-1) 297 | dists, idx = dists[:, :, :3], idx[:, :, :3] # [B, N, 3] 298 | dists[dists < 1e-10] = 1e-10 299 | weight = 1.0 / dists # [B, N, 3] 300 | weight = weight / torch.sum(weight, dim=-1).view(B, N, 1) # [B, N, 3] 301 | interpolated_points = torch.sum(index_points(points2, idx) * weight.view(B, N, 3, 1), dim=2) 302 | 303 | if points1 is not None: 304 | points1 = points1.permute(0, 2, 1) 305 | new_points = torch.cat([points1, interpolated_points], dim=-1) 306 | else: 307 | new_points = interpolated_points 308 | 309 | new_points = new_points.permute(0, 2, 1) 310 | for i, conv in 
enumerate(self.mlp_convs): 311 | bn = self.mlp_bns[i] 312 | new_points = F.relu(bn(conv(new_points))) 313 | return new_points 314 | 315 | -------------------------------------------------------------------------------- /model/utils.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import json 4 | import numpy as np 5 | import torch 6 | import torch.nn as nn 7 | import torch.optim as optim 8 | # import torchvision.transforms as T 9 | import torch.nn.functional as F 10 | from torch.utils.data import DataLoader 11 | 12 | from .pointnet import PointNetSeg, feature_transform_reguliarzer 13 | from .pointnet2 import PointNet2SemSeg 14 | 15 | def load_pointnet(model_name, num_classes, fn_pth): 16 | if model_name == 'pointnet': 17 | model = PointNetSeg(num_classes, input_dims = 4, feature_transform=True) 18 | else: 19 | model = PointNet2SemSeg(num_classes, feature_dims = 1) 20 | 21 | torch.backends.cudnn.benchmark = True 22 | model = torch.nn.DataParallel(model) 23 | 24 | assert fn_pth is not None,'No pretrain model' 25 | if not torch.cuda.is_available(): 26 | print('=> cuda not available') 27 | checkpoint = torch.load(fn_pth, map_location=torch.device('cpu')) 28 | else: 29 | checkpoint = torch.load(fn_pth) 30 | model.cuda() 31 | 32 | model.load_state_dict(checkpoint) 33 | model.eval() 34 | return model -------------------------------------------------------------------------------- /my_log.py: -------------------------------------------------------------------------------- 1 | import os 2 | import sys 3 | import time 4 | import numpy as np 5 | import torch 6 | import torch.nn.parallel 7 | import torch.nn as nn 8 | import torch.utils.data 9 | import torch.optim as optim 10 | # import torchvision.transforms as T 11 | import torch.nn.functional as F 12 | import matplotlib.pyplot as plt 13 | 14 | 15 | # Logging 16 | 17 | def gray(x): return '\033[90m' + str(x) + '\033[0m' 18 | def red(x): return '\033[91m' + str(x) + '\033[0m' 19 | def green(x): return '\033[92m' + str(x) + '\033[0m' 20 | def yellow(x): return '\033[93m' + str(x) + '\033[0m' 21 | def blue(x): return '\033[94m' + str(x) + '\033[0m' 22 | def magenta(x): return '\033[95m' + str(x) + '\033[0m' 23 | def cyan(x): return '\033[96m' + str(x) + '\033[0m' 24 | def white(x): return '\033[97m' + str(x) + '\033[0m' 25 | 26 | def fmt(fn_color, *args, **kwargs): 27 | tmp = '' 28 | end = '\n' 29 | for msg in args: 30 | if isinstance(msg,float): 31 | msg = '%.5f' % msg 32 | tmp += '%s ' % (fn_color(msg)) 33 | for k in kwargs.keys(): 34 | if k == 'end': 35 | end = kwargs['end'] 36 | else: 37 | msg = kwargs[k] 38 | if isinstance(msg,float): 39 | msg = '%.5f' % msg 40 | tmp += '%s: %s ' % (k, fn_color(msg)) 41 | tmp += end 42 | return tmp 43 | 44 | def print_base(fn_color, *args, **kwargs): 45 | print(fmt(fn_color, *args, **kwargs),end='') 46 | 47 | def debug(*args, **kwargs): 48 | print_base(gray, *args, **kwargs) 49 | 50 | def info(*args, **kwargs): 51 | print_base(green, *args, **kwargs) 52 | 53 | def msg(*args, **kwargs): 54 | print_base(yellow, *args, **kwargs) 55 | 56 | def warn(*args, **kwargs): 57 | print_base(magenta, *args, **kwargs) 58 | 59 | def err(*args, **kwargs): 60 | print_base(red, *args, **kwargs) 61 | 62 | 63 | # utils 64 | 65 | def mkdir(fn): 66 | os.makedirs(fn, exist_ok=True) 67 | return fn 68 | 69 | def select_avaliable(fn_list): 70 | selected = None 71 | for fn in fn_list: 72 | if os.path.exists(fn): 73 | selected = fn 74 | break 75 | if selected is None: 76 | 
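# select_avaliable() returns the first path in fn_list that exists on disk, so the same
# script can run unchanged on machines that mount the dataset under different roots;
# if none of the candidate paths exist, it logs an error and implicitly returns None.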
log.err(log.yellow("Could not find dataset from"), fn_list) 77 | else: 78 | return selected 79 | 80 | # Numpy functions 81 | 82 | def num(x): 83 | return x.detach().cpu().numpy() 84 | 85 | def norm_01(x): 86 | return (x - x.min())/(x.max() - x.min() + + 1e-6) 87 | 88 | def relu(x): 89 | return np.maximum(0, x) 90 | 91 | def np_l2_sum(x): 92 | return np.sqrt(np.square(x.copy()).sum()) 93 | 94 | def np_l2_mean(x): 95 | return np.sqrt(np.square(x.copy()).mean()) 96 | 97 | def np_inf_norm(x): 98 | return np.linalg.norm(x, ord=np.inf_norm) 99 | 100 | def np_clip_by_l2norm(x, clip_norm): 101 | return x * clip_norm / np.linalg.norm(x, ord=2) 102 | 103 | def np_clip_by_infnorm(x, clip_norm): 104 | return x * clip_norm / np.linalg.norm(x, ord=np.inf) 105 | 106 | # Display functions 107 | 108 | def print_mat(x): 109 | info(x.shape, x.dtype, min=x.min(), max=x.max()) 110 | 111 | def print_l2(x): 112 | info(x.shape,min=x.min(),max=x.max(),sum_l2 =np_l2_sum(x),mean_l2=np_l2_mean(x)) 113 | 114 | def get_fig(figsize=(8,4)): 115 | fig = plt.figure(figsize=figsize, dpi=100, facecolor='w', edgecolor='k') 116 | return fig 117 | 118 | def sub_plot(fig, rows, cols, index, title, image): 119 | axis = fig.add_subplot(rows, cols, index) 120 | if title != None: 121 | axis.title.set_text(title) 122 | axis.axis('off') 123 | plt.imshow(image) 124 | 125 | 126 | # Timing 127 | 128 | class Tick(): 129 | def __init__(self, name='', silent=False): 130 | self.name = name 131 | self.silent = silent 132 | 133 | def __enter__(self): 134 | self.t_start = time.time() 135 | if not self.silent: 136 | print(cyan('> %s ... ' % (self.name)), end='') 137 | sys.stdout.flush() 138 | return self 139 | 140 | def __exit__(self, exc_type, exc_val, exc_tb): 141 | self.t_end = time.time() 142 | self.delta = self.t_end-self.t_start 143 | self.fps = 1/self.delta 144 | 145 | if not self.silent: 146 | print(cyan('[%.0f ms]' % (self.delta * 1000))) 147 | sys.stdout.flush() 148 | 149 | 150 | class Tock(): 151 | def __init__(self, name=None, report_time=True): 152 | self.name = '' if name == None else name+': ' 153 | self.report_time = report_time 154 | 155 | def __enter__(self): 156 | self.t_start = time.time() 157 | return self 158 | 159 | def __exit__(self, exc_type, exc_val, exc_tb): 160 | self.t_end = time.time() 161 | self.delta = self.t_end-self.t_start 162 | self.fps = 1/self.delta 163 | if self.report_time: 164 | print(yellow('(%s%.0fms) ' % (self.name, self.delta * 1000)), end='') 165 | else: 166 | print(yellow('.'), end='') 167 | sys.stdout.flush() 168 | -------------------------------------------------------------------------------- /partseg.py: -------------------------------------------------------------------------------- 1 | import open3d 2 | import argparse 3 | import os 4 | import time 5 | import h5py 6 | import datetime 7 | import numpy as np 8 | from matplotlib import pyplot as plt 9 | import torch 10 | from torch.utils.data import DataLoader 11 | import torch.nn.functional as F 12 | 13 | import my_log as log 14 | from tqdm import tqdm 15 | 16 | from utils import test_partseg, select_avaliable, mkdir, to_categorical 17 | from data_utils.ShapeNetDataLoader import PartNormalDataset, label_id_to_name 18 | from model.pointnet2 import PointNet2PartSegMsg_one_hot 19 | from model.pointnet import PointNetDenseCls,PointNetLoss 20 | 21 | def parse_args(notebook = False): 22 | parser = argparse.ArgumentParser('PointNet2') 23 | parser.add_argument('--model_name', type=str, default='pointnet', help='pointnet or pointnet2') 24 | 
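# Typical invocations (the paths, scores and GPU ids below are placeholders, not files
# guaranteed to exist in this repo -- adjust them to your own setup):
#   python partseg.py --mode train --model_name pointnet2 --batch_size 16 --gpu 0
#   python partseg.py --mode eval  --model_name pointnet  --pretrain experiment/partseg/pointnet/partseg-pointnet-0.80000-0042.pth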
parser.add_argument('--mode', default='train', help='train or eval') 25 | parser.add_argument('--batch_size', type=int, default=16, help='input batch size') 26 | parser.add_argument('--workers', type=int, default=4, help='number of data loading workers') 27 | parser.add_argument('--epoch', type=int, default=100, help='number of epochs for training') 28 | parser.add_argument('--pretrain', type=str, default=None, help='whether use pretrain model') 29 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device') 30 | parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate for training') 31 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay') 32 | parser.add_argument('--optimizer', type=str, default='Adam', help='type of optimizer') 33 | parser.add_argument('--augment', default=False, action='store_true', help="Enable data augmentation") 34 | if notebook: 35 | return parser.parse_args([]) 36 | else: 37 | return parser.parse_args() 38 | 39 | root = select_avaliable([ 40 | '/media/james/HDD/James_Least/Large_Dataset/ShapeNet/shapenetcore_partanno_segmentation_benchmark_v0_normal/', 41 | '/media/james/Ubuntu_Data/dataset/ShapeNet/shapenetcore_partanno_segmentation_benchmark_v0_normal/', 42 | '/media/james/MyPassport/James/dataset/ShapeNet/shapenetcore_partanno_segmentation_benchmark_v0_normal/', 43 | '/home/james/dataset/ShapeNet/shapenetcore_partanno_segmentation_benchmark_v0_normal/', 44 | '/media/james/HDD/James_Least/Datasets/ShapeNet/shapenetcore_partanno_segmentation_benchmark_v0_normal/' 45 | ]) 46 | 47 | def _load(root): 48 | fn_cache = 'experiment/data/shapenetcore_partanno_segmentation_benchmark_v0_normal.h5' 49 | if not os.path.exists(fn_cache): 50 | log.debug('Indexing Files...') 51 | fns_full = [] 52 | fp_h5 = h5py.File(fn_cache,"w") 53 | 54 | for line in open(os.path.join(root, 'synsetoffset2category.txt'), 'r'): 55 | name,wordnet_id = line.strip().split() 56 | pt_folder = os.path.join(root, wordnet_id) 57 | log.info('Building',name, wordnet_id) 58 | for fn in tqdm(os.listdir(pt_folder)): 59 | token = fn.split('.')[0] 60 | fn_full = os.path.join(pt_folder, fn) 61 | data = np.loadtxt(fn_full).astype(np.float32) 62 | 63 | h5_index = '%s_%s'%(wordnet_id,token) 64 | fp_h5.create_dataset(h5_index, data = data) 65 | 66 | log.debug('Building cache...') 67 | fp_h5.close() 68 | 69 | log.debug('Loading from cache...') 70 | fp_h5 = h5py.File(fn_cache, 'r') 71 | cache = {} 72 | for token in fp_h5.keys(): 73 | cache[token] = fp_h5.get(token)[()] 74 | return cache 75 | 76 | def train(args): 77 | experiment_dir = mkdir('./experiment/') 78 | checkpoints_dir = mkdir('./experiment/partseg/%s/'%(args.model_name)) 79 | cache = _load(root) 80 | 81 | norm = True if args.model_name == 'pointnet' else False 82 | npoints = 2048 83 | train_ds = PartNormalDataset(root,cache,npoints=npoints, split='trainval', data_augmentation = args.augment) 84 | dataloader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=int(args.workers)) 85 | 86 | test_ds = PartNormalDataset(root,cache,npoints=npoints, split='test') 87 | testdataloader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=int(args.workers)) 88 | 89 | num_classes = 16 90 | num_part = 50 91 | log.info(len_training=len(train_ds), len_testing=len(test_ds)) 92 | log.info(num_classes=num_classes, num_part=num_part) 93 | 94 | if args.model_name == 'pointnet': 95 | model = PointNetDenseCls(cat_num=num_classes,part_num=num_part) 96 | 
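# ShapeNet part segmentation uses 16 object categories and 50 part labels in total.
# 'pointnet' -> PointNetDenseCls, which predicts the object category and the per-point
# part label jointly; 'pointnet2' -> PointNet2PartSegMsg_one_hot, which predicts only
# per-point part labels and additionally consumes per-point normals (norm_plt).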
else: 97 | model = PointNet2PartSegMsg_one_hot(num_part) 98 | 99 | torch.backends.cudnn.benchmark = True 100 | model = torch.nn.DataParallel(model).cuda() 101 | log.debug('Using gpu:',args.gpu) 102 | 103 | if args.pretrain is not None and args.pretrain != 'None': 104 | log.debug('Use pretrain model...') 105 | model.load_state_dict(torch.load(args.pretrain)) 106 | init_epoch = int(args.pretrain[:-4].split('-')[-1]) 107 | log.debug('start epoch from', init_epoch) 108 | else: 109 | log.debug('Training from scratch') 110 | init_epoch = 0 111 | 112 | if args.optimizer == 'SGD': 113 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) 114 | elif args.optimizer == 'Adam': 115 | optimizer = torch.optim.Adam( 116 | model.parameters(), 117 | lr=args.learning_rate, 118 | betas=(0.9, 0.999), 119 | eps=1e-08, 120 | weight_decay=args.decay_rate) 121 | 122 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) 123 | 124 | history = {'loss':[]} 125 | best_acc = 0 126 | best_class_avg_iou = 0 127 | best_inctance_avg_iou = 0 128 | LEARNING_RATE_CLIP = 1e-5 129 | 130 | # criterion = PointNetLoss() 131 | def feature_transform_reguliarzer(trans): 132 | d = trans.size()[1] 133 | I = torch.eye(d)[None, :, :] 134 | if trans.is_cuda: 135 | I = I.cuda() 136 | loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1) - I), dim=(1, 2))) 137 | return loss 138 | 139 | def PointNet_Loss(labels_pred, label, seg_pred, seg, trans_feat): 140 | mat_diff_loss_scale = 0.001 141 | weight = 1 142 | seg_loss = F.nll_loss(seg_pred, seg) 143 | mat_diff_loss = feature_transform_reguliarzer(trans_feat) 144 | label_loss = F.nll_loss(labels_pred, label) 145 | loss = weight * seg_loss + (1-weight) * label_loss + mat_diff_loss * mat_diff_loss_scale 146 | return loss, seg_loss, label_loss 147 | 148 | for epoch in range(init_epoch,args.epoch): 149 | scheduler.step() 150 | lr = max(optimizer.param_groups[0]['lr'],LEARNING_RATE_CLIP) 151 | log.info(job='partseg',model=args.model_name,gpu=args.gpu,epoch='%d/%s' % (epoch, args.epoch),lr=lr) 152 | 153 | for param_group in optimizer.param_groups: 154 | param_group['lr'] = lr 155 | 156 | for i, data in tqdm(enumerate(dataloader, 0),total=len(dataloader),smoothing=0.9): 157 | points, label, target, norm_plt = data 158 | points, label, target = points.float(), label.long(), target.long() 159 | points = points.transpose(2, 1) 160 | norm_plt = norm_plt.transpose(2, 1) 161 | points, label, target,norm_plt = points.cuda(),label.squeeze().cuda(), target.cuda(), norm_plt.cuda() 162 | optimizer.zero_grad() 163 | model = model.train() 164 | 165 | if args.model_name == 'pointnet': 166 | labels_pred, seg_pred, trans_feat = model(points, to_categorical(label, 16)) 167 | seg_pred = seg_pred.contiguous().view(-1, num_part) 168 | target = target.view(-1, 1)[:, 0] 169 | # loss, seg_loss, label_loss = criterion(labels_pred, label, seg_pred, target, trans_feat) 170 | loss, seg_loss, label_loss = PointNet_Loss(labels_pred, label, seg_pred, target, trans_feat) 171 | else: 172 | seg_pred = model(points, norm_plt, to_categorical(label, 16)) 173 | seg_pred = seg_pred.contiguous().view(-1, num_part) 174 | target = target.view(-1, 1)[:, 0] 175 | loss = F.nll_loss(seg_pred, target) 176 | 177 | history['loss'].append(loss.cpu().data.numpy()) 178 | loss.backward() 179 | optimizer.step() 180 | 181 | log.debug('clear cuda cache') 182 | torch.cuda.empty_cache() 183 | 184 | test_metrics, test_hist_acc, cat_mean_iou = test_partseg( 185 | model.eval(), 186 | testdataloader, 
187 | label_id_to_name, 188 | args.model_name, 189 | num_part, 190 | ) 191 | 192 | save_model = False 193 | if test_metrics['accuracy'] > best_acc: 194 | best_acc = test_metrics['accuracy'] 195 | 196 | if test_metrics['class_avg_iou'] > best_class_avg_iou: 197 | best_class_avg_iou = test_metrics['class_avg_iou'] 198 | 199 | if test_metrics['inctance_avg_iou'] > best_inctance_avg_iou: 200 | best_inctance_avg_iou = test_metrics['inctance_avg_iou'] 201 | save_model = True 202 | 203 | if save_model: 204 | fn_pth = 'partseg-%s-%.5f-%04d.pth' % (args.model_name, best_inctance_avg_iou, epoch) 205 | log.info('Save model...',fn = fn_pth) 206 | torch.save(model.state_dict(), os.path.join(checkpoints_dir, fn_pth)) 207 | log.info(cat_mean_iou) 208 | else: 209 | log.info('No need to save model') 210 | 211 | log.warn('Curr', accuracy=test_metrics['accuracy'],class_avg_mIOU = test_metrics['class_avg_iou'], 212 | inctance_avg_mIOU = test_metrics['inctance_avg_iou']) 213 | 214 | log.warn('Best', accuracy=best_acc,class_avg_mIOU = best_class_avg_iou, 215 | inctance_avg_mIOU = best_inctance_avg_iou) 216 | 217 | def evaluate(args): 218 | cache = _load(root) 219 | norm = True if args.model_name == 'pointnet' else False 220 | test_ds = PartNormalDataset(root, cache, npoints=2048, split='test') 221 | testdataloader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False, num_workers=int(args.workers)) 222 | log.info("The number of test data is:", len(test_ds)) 223 | 224 | log.info('Building Model', args.model_name) 225 | num_classes = 16 226 | num_part = 50 227 | if args.model_name == 'pointnet2': 228 | model = PointNet2PartSegMsg_one_hot(num_part) 229 | else: 230 | model = PointNetDenseCls(cat_num=num_classes,part_num=num_part) 231 | 232 | torch.backends.cudnn.benchmark = True 233 | model = torch.nn.DataParallel(model).cuda() 234 | log.debug('Using gpu:',args.gpu) 235 | 236 | if args.pretrain is None: 237 | log.err('No pretrain model') 238 | return 239 | 240 | log.debug('Loading pretrain model...') 241 | state_dict = torch.load(args.pretrain) 242 | model.load_state_dict(state_dict) 243 | 244 | log.info('Testing pretrain model...') 245 | 246 | test_metrics, test_hist_acc, cat_mean_iou = test_partseg( 247 | model.eval(), 248 | testdataloader, 249 | label_id_to_name, 250 | args.model_name, 251 | num_part, 252 | ) 253 | 254 | log.info('test_hist_acc',len(test_hist_acc)) 255 | log.info(cat_mean_iou) 256 | log.info('Test Accuracy','%.5f' % test_metrics['accuracy']) 257 | log.info('Class avg mIOU:','%.5f' % test_metrics['class_avg_iou']) 258 | log.info('Inctance avg mIOU:','%.5f' % test_metrics['inctance_avg_iou']) 259 | 260 | def vis(args): 261 | cache = _load(root) 262 | norm = True if args.model_name == 'pointnet' else False 263 | test_ds = PartNormalDataset(root, cache, npoints=2048, split='test') 264 | testdataloader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=True, num_workers=int(args.workers)) 265 | log.info("The number of test data is:", len(test_ds)) 266 | 267 | log.info('Building Model', args.model_name) 268 | num_classes = 16 269 | num_part = 50 270 | if args.model_name == 'pointnet': 271 | model = PointNetDenseCls(cat_num=num_classes,part_num=num_part) 272 | else: 273 | model = PointNet2PartSegMsg_one_hot(num_part) 274 | 275 | torch.backends.cudnn.benchmark = True 276 | model = torch.nn.DataParallel(model) 277 | model.cuda() 278 | log.debug('Using multi GPU:',args.gpu) 279 | 280 | if args.pretrain is None: 281 | log.err('No pretrain model') 282 | return 283 | 284 | log.info('Loading 
pretrain model...') 285 | checkpoint = torch.load(args.pretrain) 286 | model.load_state_dict(checkpoint) 287 | 288 | log.info('Press space to exit, press Q for next frame') 289 | for batch_id, (points, label, target, norm_plt) in enumerate(testdataloader): 290 | batchsize, num_point, _= points.size() 291 | points, label, target, norm_plt = points.float(),label.long(), target.long(),norm_plt.float() 292 | points = points.transpose(2, 1) 293 | norm_plt = norm_plt.transpose(2, 1) 294 | points, label, target, norm_plt = points.cuda(), label.squeeze().cuda(), target.cuda(), norm_plt.cuda() 295 | if args.model_name == 'pointnet': 296 | labels_pred, seg_pred, _ = model(points,to_categorical(label,16)) 297 | else: 298 | seg_pred = model(points, norm_plt, to_categorical(label, 16)) 299 | pred_choice = seg_pred.max(-1)[1] 300 | log.info(seg_pred=seg_pred.shape, pred_choice=pred_choice.shape) 301 | log.info(seg_pred=seg_pred.shape, pred_choice=pred_choice.shape) 302 | 303 | cmap_plt = plt.cm.get_cmap("hsv", num_part) 304 | cmap_list = [cmap_plt(i)[:3] for i in range(num_part)] 305 | np.random.shuffle(cmap_list) 306 | cmap = np.array(cmap_list) 307 | 308 | #log.info('points',points.shape,'label',label.shape,'target',target.shape,'norm_plt',norm_plt.shape) 309 | for idx in range(batchsize): 310 | pt, gt, pred = points[idx].transpose(1, 0), target[idx], pred_choice[idx].transpose(-1, 0) 311 | # log.info('pt',pt.size(),'gt',gt.size(),'pred',pred.shape) 312 | 313 | gt_color = cmap[gt.cpu().numpy() - 1, :] 314 | pred_color = cmap[pred.cpu().numpy() - 1, :] 315 | 316 | point_cloud = open3d.geometry.PointCloud() 317 | point_cloud.points = open3d.utility.Vector3dVector(pt.cpu().numpy()) 318 | point_cloud.colors = open3d.utility.Vector3dVector(pred_color) 319 | 320 | vis = open3d.visualization.VisualizerWithKeyCallback() 321 | vis.create_window() 322 | vis.get_render_option().background_color = np.asarray([0, 0, 0]) 323 | vis.add_geometry(point_cloud) 324 | vis.register_key_callback(32, lambda vis: exit()) 325 | vis.run() 326 | vis.destroy_window() 327 | 328 | 329 | if __name__ == '__main__': 330 | args = parse_args() 331 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 332 | if args.mode == "train": 333 | train(args) 334 | if args.mode == "eval": 335 | evaluate(args) 336 | if args.mode == "vis": 337 | vis(args) -------------------------------------------------------------------------------- /pcd_utils.py: -------------------------------------------------------------------------------- 1 | # *_*coding:utf-8 *_* 2 | import os 3 | import numpy as np 4 | import torch 5 | import matplotlib.pyplot as plt 6 | from torch.autograd import Variable 7 | from tqdm import tqdm 8 | from collections import defaultdict 9 | import datetime 10 | import multiprocessing 11 | import pandas as pd 12 | import torch.nn.functional as F 13 | import sys 14 | import my_log as log 15 | import time 16 | 17 | def mkdir(fn): 18 | os.makedirs(fn, exist_ok=True) 19 | return fn 20 | 21 | def select_avaliable(fn_list): 22 | selected = None 23 | for fn in fn_list: 24 | if os.path.exists(fn): 25 | selected = fn 26 | break 27 | if selected is None: 28 | log.err(log.yellow("Could not find dataset from"), fn_list) 29 | else: 30 | return selected 31 | 32 | def to_categorical(y, num_classes): 33 | """ 1-hot encodes a tensor """ 34 | new_y = torch.eye(num_classes)[y.cpu().data.numpy(),] 35 | if (y.is_cuda): 36 | return new_y.cuda() 37 | return new_y 38 | 39 | def show_example(x, y, x_reconstruction, y_pred,save_dir, figname): 40 | x = 
x.squeeze().cpu().data.numpy() 41 | x = x.permute(0,2,1) 42 | y = y.cpu().data.numpy() 43 | x_reconstruction = x_reconstruction.squeeze().cpu().data.numpy() 44 | _, y_pred = torch.max(y_pred, -1) 45 | y_pred = y_pred.cpu().data.numpy() 46 | 47 | fig, ax = plt.subplots(1, 2) 48 | ax[0].imshow(x, cmap='Greys') 49 | ax[0].set_title('Input: %d' % y) 50 | ax[1].imshow(x_reconstruction, cmap='Greys') 51 | ax[1].set_title('Output: %d' % y_pred) 52 | plt.savefig(save_dir + figname + '.png') 53 | 54 | def save_checkpoint(epoch, train_accuracy, test_accuracy, model, optimizer, path,modelnet='checkpoint'): 55 | savepath = path + '/%s-%.5f-%04d.pth' % (modelnet,test_accuracy, epoch) 56 | state = { 57 | 'epoch': epoch, 58 | 'train_accuracy': train_accuracy, 59 | 'test_accuracy': test_accuracy, 60 | 'model_state_dict': model.state_dict(), 61 | 'optimizer_state_dict': optimizer.state_dict(), 62 | } 63 | torch.save(state, savepath) 64 | 65 | def test_clf(model, loader): 66 | mean_correct = [] 67 | for j, data in tqdm(enumerate(loader, 0), total=len(loader), smoothing=0.9): 68 | points, target = data 69 | target = target[:, 0] 70 | points = points.transpose(2, 1) 71 | points, target = points.cuda(), target.cuda() 72 | classifier = model.eval() 73 | pred, _ = classifier(points) 74 | pred_choice = pred.data.max(1)[1] 75 | correct = pred_choice.eq(target.long().data).cpu().sum() 76 | mean_correct.append(correct.item()/float(points.size()[0])) 77 | return np.mean(mean_correct) 78 | 79 | def compute_cat_iou(pred, target, num_classes ,iou_tabel): 80 | iou_list = [] 81 | target = target.cpu().data.numpy() 82 | for j in range(pred.size(0)): 83 | batch_pred = pred[j] 84 | batch_target = target[j] 85 | batch_choice = batch_pred.data.max(1)[1].cpu().data.numpy() 86 | for cat in range(num_classes): 87 | # intersection = np.sum((batch_target == cat) & (batch_choice == cat)) 88 | # union = float(np.sum((batch_target == cat) | (batch_choice == cat))) 89 | # iou = intersection/union if not union ==0 else 1 90 | I = np.sum(np.logical_and(batch_choice == cat, batch_target == cat)) 91 | U = np.sum(np.logical_or(batch_choice == cat, batch_target == cat)) 92 | if U == 0: 93 | iou = 1 # If the union of groundtruth and prediction points is empty, then count part IoU as 1 94 | else: 95 | iou = I / float(U) 96 | iou_tabel[cat,0] += iou 97 | iou_tabel[cat,1] += 1 98 | iou_list.append(iou) 99 | return iou_tabel,iou_list 100 | 101 | def calc_categorical_iou(pred, target, num_classes ,iou_tabel): 102 | choice = pred.max(2)[1] 103 | target.squeeze_(-1) 104 | for cat in range(num_classes): 105 | I = torch.sum((choice == cat) & (target == cat)).float() 106 | U = torch.sum((choice == cat) | (target == cat)).float() 107 | if U == 0: 108 | iou = 1 109 | else: 110 | iou = (I / U).cpu().numpy() 111 | iou_tabel[cat,0] += iou 112 | iou_tabel[cat,1] += 1 113 | return iou_tabel 114 | 115 | def compute_overall_iou(pred, target, num_classes): 116 | shape_ious = [] 117 | pred_np = pred.cpu().data.numpy() 118 | target_np = target.cpu().data.numpy() 119 | for shape_idx in range(pred.size(0)): 120 | part_ious = [] 121 | for part in range(num_classes): 122 | I = np.sum(np.logical_and(pred_np[shape_idx].max(1) == part, target_np[shape_idx] == part)) 123 | U = np.sum(np.logical_or(pred_np[shape_idx].max(1) == part, target_np[shape_idx] == part)) 124 | if U == 0: 125 | iou = 1 #If the union of groundtruth and prediction points is empty, then count part IoU as 1 126 | else: 127 | iou = I / float(U) 128 | part_ious.append(iou) 129 | 
shape_ious.append(np.mean(part_ious)) 130 | return shape_ious 131 | 132 | def test_partseg(model, loader, catdict, model_name, num_classes = 50): 133 | ''' catdict = {0:Airplane, 1:Airplane, ...49:Table} ''' 134 | iou_tabel = np.zeros((len(catdict),3)) 135 | iou_list = [] 136 | metrics = defaultdict(lambda:list()) 137 | hist_acc = [] 138 | # mean_correct = [] 139 | 140 | for points, label, target, norm_plt in tqdm(loader, total=len(loader), smoothing=0.9, dynamic_ncols=True): 141 | batchsize, num_point,_= points.size() 142 | points, label, target, norm_plt = Variable(points.float()),Variable(label.long()), Variable(target.long()),Variable(norm_plt.float()) 143 | points = points.transpose(2, 1) 144 | norm_plt = norm_plt.transpose(2, 1) 145 | points, label, target, norm_plt = points.cuda(), label.squeeze().cuda(), target.cuda(), norm_plt.cuda() 146 | if model_name == 'pointnet': 147 | labels_pred, seg_pred, _ = model(points,to_categorical(label,16)) 148 | else: 149 | seg_pred = model(points, norm_plt, to_categorical(label, 16)) 150 | # labels_pred_choice = labels_pred.data.max(1)[1] 151 | # labels_correct = labels_pred_choice.eq(label.long().data).cpu().sum() 152 | # mean_correct.append(labels_correct.item() / float(points.size()[0])) 153 | 154 | # print(pred.size()) 155 | iou_tabel, iou = compute_cat_iou(seg_pred,target,num_classes,iou_tabel) 156 | iou_list+=iou 157 | # shape_ious += compute_overall_iou(pred, target, num_classes) 158 | seg_pred = seg_pred.contiguous().view(-1, num_classes) 159 | target = target.view(-1, 1)[:, 0] 160 | pred_choice = seg_pred.data.max(1)[1] 161 | correct = pred_choice.eq(target.data).cpu().sum() 162 | metrics['accuracy'].append(correct.item()/ (batchsize * num_point)) 163 | 164 | iou_tabel[:,2] = iou_tabel[:,0] /iou_tabel[:,1] 165 | hist_acc += metrics['accuracy'] 166 | metrics['accuracy'] = np.mean(hist_acc) 167 | metrics['inctance_avg_iou'] = np.mean(iou_list) 168 | # metrics['label_accuracy'] = np.mean(mean_correct) 169 | 170 | iou_tabel = pd.DataFrame(iou_tabel,columns=['iou','count','mean_iou']) 171 | iou_tabel['Category_IOU'] = [catdict[i] for i in range(len(catdict)) ] 172 | cat_iou = iou_tabel.groupby('Category_IOU')['mean_iou'].mean() 173 | metrics['class_avg_iou'] = np.mean(cat_iou) 174 | 175 | return metrics, hist_acc, cat_iou 176 | 177 | def test_semseg(model, loader, catdict, model_name, num_classes): 178 | iou_tabel = np.zeros((len(catdict),3)) 179 | metrics = defaultdict(lambda:list()) 180 | 181 | with torch.no_grad(): 182 | for points, target in tqdm(loader, total=len(loader), smoothing=0.9, dynamic_ncols=True): 183 | batchsize, num_point, _ = points.size() 184 | points, target = Variable(points.float()), Variable(target.long()) 185 | points = points.transpose(2, 1) 186 | points, target = points.cuda(), target.cuda() 187 | if model_name == 'pointnet': 188 | pred, _ = model(points) 189 | else: 190 | pred = model(points) 191 | 192 | # iou_tabel, iou_list = compute_cat_iou(pred,target,num_classes,iou_tabel) 193 | iou_tabel = calc_categorical_iou(pred,target,num_classes,iou_tabel) 194 | 195 | # shape_ious += compute_overall_iou(pred, target, num_classes) 196 | pred = pred.contiguous().view(-1, num_classes) 197 | target = target.view(-1, 1)[:, 0] 198 | pred_choice = pred.data.max(1)[1] 199 | correct = pred_choice.eq(target.data).cpu().sum() 200 | metrics['accuracy'].append(correct.item()/ (batchsize * num_point)) 201 | 202 | iou_tabel[:,2] = iou_tabel[:,0] /iou_tabel[:,1] 203 | metrics['accuracy'] = np.mean(metrics['accuracy']) 204 | metrics['iou'] 
= np.mean(iou_tabel[:, 2]) 205 | 206 | iou_tabel = pd.DataFrame(iou_tabel,columns=['iou','count','mean_iou']) 207 | iou_tabel['Category_IOU'] = [catdict[i] for i in range(len(catdict)) ] 208 | cat_iou = iou_tabel.groupby('Category_IOU')['mean_iou'].mean() 209 | 210 | return metrics, cat_iou 211 | 212 | def compute_avg_curve(y, n_points_avg): 213 | avg_kernel = np.ones((n_points_avg,)) / n_points_avg 214 | rolling_mean = np.convolve(y, avg_kernel, mode='valid') 215 | return rolling_mean 216 | 217 | def plot_loss_curve(history,n_points_avg,n_points_plot,save_dir): 218 | curve = np.asarray(history['loss'])[-n_points_plot:] 219 | avg_curve = compute_avg_curve(curve, n_points_avg) 220 | plt.plot(avg_curve, '-g') 221 | 222 | curve = np.asarray(history['margin_loss'])[-n_points_plot:] 223 | avg_curve = compute_avg_curve(curve, n_points_avg) 224 | plt.plot(avg_curve, '-b') 225 | 226 | curve = np.asarray(history['reconstruction_loss'])[-n_points_plot:] 227 | avg_curve = compute_avg_curve(curve, n_points_avg) 228 | plt.plot(avg_curve, '-r') 229 | 230 | plt.legend(['Total Loss', 'Margin Loss', 'Reconstruction Loss']) 231 | plt.savefig(save_dir + '/'+ str(datetime.datetime.now().strftime('%Y-%m-%d %H-%M')) + '_total_result.png') 232 | plt.close() 233 | 234 | def plot_acc_curve(total_train_acc,total_test_acc,save_dir): 235 | plt.plot(total_train_acc, '-b',label = 'train_acc') 236 | plt.plot(total_test_acc, '-r',label = 'test_acc') 237 | plt.legend() 238 | plt.ylabel('acc') 239 | plt.xlabel('epoch') 240 | plt.title('Accuracy of training and test') 241 | plt.savefig(save_dir +'/'+ str(datetime.datetime.now().strftime('%Y-%m-%d %H-%M'))+'_total_acc.png') 242 | plt.close() 243 | 244 | def show_point_cloud(tuple,seg_label=[],title=None): 245 | import matplotlib.pyplot as plt 246 | if len(seg_label) == 0: 247 | x = [x[0] for x in tuple] 248 | y = [y[1] for y in tuple] 249 | z = [z[2] for z in tuple] 250 | ax = plt.subplot(111, projection='3d') 251 | ax.scatter(x, y, z, c='b') 252 | ax.set_zlabel('Z') 253 | ax.set_ylabel('Y') 254 | ax.set_xlabel('X') 255 | else: 256 | category = list(np.unique(seg_label)) 257 | color = ['b','r','g','y','w','c','m'] # distinct, valid matplotlib colors, one per part label 258 | ax = plt.subplot(111, projection='3d') 259 | for categ_index in range(len(category)): 260 | tuple_seg = tuple[seg_label == category[categ_index]] 261 | x = [x[0] for x in tuple_seg] 262 | y = [y[1] for y in tuple_seg] 263 | z = [z[2] for z in tuple_seg] 264 | ax.scatter(x, y, z, c=color[categ_index]) 265 | ax.set_zlabel('Z') 266 | ax.set_ylabel('Y') 267 | ax.set_xlabel('X') 268 | plt.title(title) 269 | plt.show() -------------------------------------------------------------------------------- /pcdseg.py: -------------------------------------------------------------------------------- 1 | import open3d 2 | import argparse 3 | import os 4 | import time 5 | import json 6 | import h5py 7 | import datetime 8 | import cv2 9 | import numpy as np 10 | import pandas as pd 11 | import torch 12 | from torch.utils.data import DataLoader 13 | import torch.nn.functional as F 14 | import torch.nn as nn 15 | 16 | from tqdm import tqdm 17 | from matplotlib import pyplot as plt 18 | import my_log as log 19 | 20 | from model.pointnet import PointNetSeg, feature_transform_reguliarzer 21 | from model.pointnet2 import PointNet2SemSeg 22 | from model.utils import load_pointnet 23 | 24 | from pcd_utils import mkdir, select_avaliable 25 | from data_utils.SemKITTI_Loader import SemKITTI_Loader 26 | from data_utils.kitti_utils import Semantic_KITTI_Utils 
27 | 28 | KITTI_ROOT = os.environ['KITTI_ROOT'] 29 | 30 | def parse_args(notebook = False): 31 | parser = argparse.ArgumentParser('PointNet') 32 | parser.add_argument('--mode', default='train', choices=('train', 'eval')) 33 | parser.add_argument('--model_name', type=str, default='pointnet', choices=('pointnet', 'pointnet2')) 34 | parser.add_argument('--pn2', default=False, action='store_true') 35 | parser.add_argument('--batch_size', type=int, default=8, help='input batch size') 36 | parser.add_argument('--subset', type=str, default='inview', choices=('inview', 'all')) 37 | parser.add_argument('--workers', type=int, default=4, help='number of data loading workers') 38 | parser.add_argument('--epoch', type=int, default=100, help='number of epochs for training') 39 | parser.add_argument('--pretrain', type=str, default=None, help='whether use pretrain model') 40 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device') 41 | parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate for training') 42 | parser.add_argument('--optimizer', type=str, default='Adam', help='type of optimizer') 43 | parser.add_argument('--augment', default=False, action='store_true', help="Enable data augmentation") 44 | if notebook: 45 | args = parser.parse_args([]) 46 | else: 47 | args = parser.parse_args() 48 | 49 | if args.pn2 == False: 50 | args.model_name = 'pointnet' 51 | else: 52 | args.model_name = 'pointnet2' 53 | return args 54 | 55 | def calc_decay(init_lr, epoch): 56 | return init_lr * 1/(1 + 0.03*epoch) 57 | 58 | def test_kitti_semseg(model, loader, model_name, num_classes, class_names): 59 | ious = np.zeros((num_classes,), dtype = np.float32) 60 | count = np.zeros((num_classes,), dtype = np.uint32) 61 | count[0] = 1 62 | accuracy = [] 63 | 64 | for points, target in tqdm(loader, total=len(loader), smoothing=0.9, dynamic_ncols=True): 65 | batch_size, num_point, _ = points.size() 66 | points = points.float().transpose(2, 1).cuda() 67 | target = target.long().cuda() 68 | 69 | with torch.no_grad(): 70 | if model_name == 'pointnet': 71 | pred, _ = model(points) 72 | else: 73 | pred = model(points) 74 | 75 | pred_choice = pred.argmax(-1) 76 | target = target.squeeze(-1) 77 | 78 | for class_id in range(num_classes): 79 | I = torch.sum((pred_choice == class_id) & (target == class_id)).cpu().item() 80 | U = torch.sum((pred_choice == class_id) | (target == class_id)).cpu().item() 81 | iou = 1 if U == 0 else I/U 82 | ious[class_id] += iou 83 | count[class_id] += 1 84 | 85 | correct = (pred_choice == target).sum().cpu().item() 86 | accuracy.append(correct/ (batch_size * num_point)) 87 | 88 | categorical_iou = ious / count 89 | df = pd.DataFrame(categorical_iou, columns=['mIOU'], index=class_names) 90 | df = df.sort_values(by='mIOU', ascending=False) 91 | 92 | log.info('categorical mIOU') 93 | log.msg(df) 94 | 95 | acc = np.mean(accuracy) 96 | miou = np.mean(categorical_iou[1:]) 97 | return acc, miou 98 | 99 | def train(args): 100 | experiment_dir = mkdir('experiment/') 101 | checkpoints_dir = mkdir('experiment/%s/'%(args.model_name)) 102 | 103 | kitti_utils = Semantic_KITTI_Utils(KITTI_ROOT, subset=args.subset) 104 | class_names = kitti_utils.class_names 105 | num_classes = kitti_utils.num_classes 106 | 107 | if args.subset == 'inview': 108 | train_npts = 8000 109 | test_npts = 24000 110 | 111 | if args.subset == 'all': 112 | train_npts = 50000 113 | test_npts = 100000 114 | 115 | log.info(subset=args.subset, train_npts=train_npts, test_npts=test_npts) 116 | 117 | 
dataset = SemKITTI_Loader(KITTI_ROOT, train_npts, train=True, subset=args.subset) 118 | dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, 119 | num_workers=args.workers, pin_memory=True) 120 | 121 | test_dataset = SemKITTI_Loader(KITTI_ROOT, test_npts, train=False, subset=args.subset) 122 | testdataloader = DataLoader(test_dataset, batch_size=int(args.batch_size/2), shuffle=False, 123 | num_workers=args.workers, pin_memory=True) 124 | 125 | if args.model_name == 'pointnet': 126 | model = PointNetSeg(num_classes, input_dims = 4, feature_transform=True) 127 | else: 128 | model = PointNet2SemSeg(num_classes, feature_dims = 1) 129 | 130 | if args.optimizer == 'SGD': 131 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) 132 | elif args.optimizer == 'Adam': 133 | optimizer = torch.optim.Adam( 134 | model.parameters(), 135 | lr=args.learning_rate, 136 | betas=(0.9, 0.999), 137 | eps=1e-08, 138 | weight_decay=1e-4) 139 | 140 | torch.backends.cudnn.benchmark = True 141 | model = torch.nn.DataParallel(model) 142 | model.cuda() 143 | log.info('Using gpu:',args.gpu) 144 | 145 | if args.pretrain is not None: 146 | log.info('Use pretrain model...') 147 | model.load_state_dict(torch.load(args.pretrain)) 148 | init_epoch = int(args.pretrain[:-4].split('-')[-1]) 149 | log.info('Restart training', epoch=init_epoch) 150 | else: 151 | log.msg('Training from scratch') 152 | init_epoch = 0 153 | 154 | best_acc = 0 155 | best_miou = 0 156 | 157 | for epoch in range(init_epoch,args.epoch): 158 | model.train() 159 | lr = calc_decay(args.learning_rate, epoch) 160 | log.info(subset=args.subset, model=args.model_name, gpu=args.gpu, epoch=epoch, lr=lr) 161 | 162 | for param_group in optimizer.param_groups: 163 | param_group['lr'] = lr 164 | 165 | for points, target in tqdm(dataloader, total=len(dataloader), smoothing=0.9, dynamic_ncols=True): 166 | points = points.float().transpose(2, 1).cuda() 167 | target = target.long().cuda() 168 | 169 | if args.model_name == 'pointnet': 170 | logits, trans_feat = model(points) 171 | else: 172 | logits = model(points) 173 | 174 | #logits = logits.contiguous().view(-1, num_classes) 175 | #target = target.view(-1, 1)[:, 0] 176 | #loss = F.nll_loss(logits, target) 177 | 178 | logits = logits.transpose(2, 1) 179 | loss = nn.CrossEntropyLoss()(logits, target) 180 | 181 | if args.model_name == 'pointnet': 182 | loss += feature_transform_reguliarzer(trans_feat) * 0.001 183 | 184 | optimizer.zero_grad() 185 | loss.backward() 186 | optimizer.step() 187 | 188 | torch.cuda.empty_cache() 189 | 190 | acc, miou = test_kitti_semseg(model.eval(), testdataloader, 191 | args.model_name,num_classes,class_names) 192 | 193 | save_model = False 194 | if acc > best_acc: 195 | best_acc = acc 196 | 197 | if miou > best_miou: 198 | best_miou = miou 199 | save_model = True 200 | 201 | if save_model: 202 | fn_pth = '%s-%s-%.5f-%04d.pth' % (args.model_name, args.subset, best_miou, epoch) 203 | log.info('Save model...',fn = fn_pth) 204 | torch.save(model.state_dict(), os.path.join(checkpoints_dir, fn_pth)) 205 | else: 206 | log.info('No need to save model') 207 | 208 | log.warn('Curr',accuracy=acc, mIOU=miou) 209 | log.warn('Best',accuracy=best_acc, mIOU=best_miou) 210 | 211 | def evaluate(args): 212 | kitti_utils = Semantic_KITTI_Utils(KITTI_ROOT, subset=args.subset) 213 | class_names = kitti_utils.class_names 214 | num_classes = kitti_utils.num_classes 215 | 216 | if args.subset == 'inview': 217 | test_npts = 24000 218 | 219 | if args.subset == 'all': 220 | 
test_npts = 100000 221 | 222 | log.info(subset=args.subset, test_npts=test_npts) 223 | 224 | test_dataset = SemKITTI_Loader(KITTI_ROOT, test_npts, train=False, subset=args.subset) 225 | testdataloader = DataLoader(test_dataset, batch_size=int(args.batch_size/2), shuffle=False, num_workers=args.workers) 226 | 227 | model = load_pointnet(args.model_name, kitti_utils.num_classes, args.pretrain) 228 | 229 | acc, miou = test_kitti_semseg(model.eval(), testdataloader,args.model_name,num_classes,class_names) 230 | 231 | log.info('Curr', accuracy=acc, mIOU=miou) 232 | 233 | if __name__ == '__main__': 234 | args = parse_args() 235 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 236 | if args.mode == "train": 237 | train(args) 238 | if args.mode == "eval": 239 | evaluate(args) 240 | -------------------------------------------------------------------------------- /pcdvis.py: -------------------------------------------------------------------------------- 1 | import open3d 2 | import argparse 3 | import os 4 | import time 5 | import json 6 | import h5py 7 | import datetime 8 | import cv2 9 | import yaml 10 | import colorsys 11 | import numpy as np 12 | import pandas as pd 13 | import torch 14 | from torch.utils.data import DataLoader 15 | import torch.nn.functional as F 16 | 17 | from tqdm import tqdm 18 | from matplotlib import pyplot as plt 19 | import my_log as log 20 | 21 | from model.pointnet import PointNetSeg, feature_transform_reguliarzer 22 | from model.pointnet2 import PointNet2SemSeg 23 | from model.utils import load_pointnet 24 | 25 | from pcdseg import parse_args 26 | from data_utils.SemKITTI_Loader import pcd_normalize 27 | from data_utils.kitti_utils import Semantic_KITTI_Utils 28 | 29 | KITTI_ROOT = os.environ['KITTI_ROOT'] 30 | 31 | class Window_Manager(): 32 | def __init__(self): 33 | self.param = open3d.io.read_pinhole_camera_parameters('config/ego_view.json') 34 | self.vis = open3d.visualization.VisualizerWithKeyCallback() 35 | self.vis.create_window(width=800, height=800, left=100) 36 | self.vis.register_key_callback(32, lambda vis: exit()) 37 | self.vis.get_render_option().load_from_json('config/render_option.json') 38 | self.pcd = open3d.geometry.PointCloud() 39 | 40 | def update(self, pts_3d, colors): 41 | self.pcd.points = open3d.utility.Vector3dVector(pts_3d) 42 | self.pcd.colors = open3d.utility.Vector3dVector(colors/255) 43 | self.vis.remove_geometry(self.pcd) 44 | self.vis.add_geometry(self.pcd) 45 | self.vis.get_view_control().convert_from_pinhole_camera_parameters(self.param) 46 | self.vis.update_geometry(self.pcd) 47 | self.vis.poll_events() 48 | self.vis.update_renderer() 49 | 50 | def capture_screen(self,fn): 51 | self.vis.capture_screen_image(fn, False) 52 | 53 | def export_video(): 54 | fourcc = cv2.VideoWriter_fourcc(*'XVID') 55 | font = cv2.FONT_HERSHEY_SIMPLEX 56 | out = cv2.VideoWriter('experiment/pn_compare.avi',fourcc, 15.0, (int(1600*0.8),int(740*0.8))) 57 | 58 | # mkdir('experiment/imgs/%s/'%(args.model_name)) 59 | # vis_handle.capture_screen('experiment/imgs/%s/%d_3d.png'%(args.model_name,i)) 60 | # cv2.imwrite('experiment/imgs/%s/%d_sem.png'%(args.model_name, i), img_semantic) 61 | 62 | for index in range(100, 320): 63 | pn_3d = cv2.imread('experiment/imgs/pointnet/%d_3d.png' % (index)) 64 | pn_sem = cv2.imread('experiment/imgs/pointnet/%d_sem.png' % (index)) 65 | pn2_3d = cv2.imread('experiment/imgs/pointnet2/%d_3d.png' % (index)) 66 | pn2_sem = cv2.imread('experiment/imgs/pointnet2/%d_sem.png' % (index)) 67 | 68 | pn_3d = pn_3d[160:650] 69 | pn2_3d = 
pn2_3d[160:650] 70 | 71 | pn_sem = cv2.resize(pn_sem, (800, 250)) 72 | pn2_sem = cv2.resize(pn2_sem, (800, 250)) 73 | 74 | pn = np.vstack((pn_3d, pn_sem)) 75 | pn2 = np.vstack((pn2_3d, pn2_sem)) 76 | 77 | cv2.putText(pn, 'PointNet', (20, 100), font,1, (255, 255, 255), 2, cv2.LINE_AA) 78 | cv2.putText(pn, 'PointNet', (20, 520), font,1, (255, 255, 255), 2, cv2.LINE_AA) 79 | cv2.putText(pn2, 'PointNet2', (20, 100), font,1, (255, 255, 255), 2, cv2.LINE_AA) 80 | cv2.putText(pn2, 'PointNet2', (20, 520), font,1, (255, 255, 255), 2, cv2.LINE_AA) 81 | 82 | merge = np.hstack((pn, pn2)) 83 | class_names = ['unlabelled', 'vehicle', 'human', 'ground', 'structure', 'nature'] 84 | colors = [[255, 255, 255],[245, 150, 100],[30, 30, 255],[255, 0, 255],[0, 200, 255],[0, 175, 0]] 85 | for i,(name,c) in enumerate(zip(class_names, colors)): 86 | cv2.putText(merge, name, (200 + i * 200, 50), font,1, [c[2],c[1],c[0]], 2, cv2.LINE_AA) 87 | 88 | cv2.line(merge,(0,70),(1600,70),(255,255,255),2) 89 | cv2.line(merge,(800,70),(800,1300),(255,255,255),2) 90 | 91 | merge = cv2.resize(merge,(0,0),fx=0.8,fy=0.8) 92 | # cv2.imshow('merge', merge) 93 | # if 32 == waitKey(1): 94 | # break 95 | out.write(merge) 96 | 97 | print(index) 98 | out.release() 99 | 100 | def vis(args): 101 | part = '01' 102 | args.subset ='inview' 103 | args.model_name = 'pointnet' 104 | 105 | kitti_utils = Semantic_KITTI_Utils(KITTI_ROOT, subset=args.subset) 106 | 107 | vis_handle = Window_Manager() 108 | if args.model_name == 'pointnet': 109 | args.pretrain = 'checkpoints/pointnet-inview-0.52077-0018.pth' 110 | else: 111 | args.pretrain = 'checkpoints/pointnet2-inview-0.55884-0001.pth' 112 | 113 | model = load_pointnet(args.model_name, kitti_utils.num_classes, args.pretrain) 114 | 115 | for index in range(0, kitti_utils.get_max_index(part)): 116 | point_cloud, label = kitti_utils.get(part, index, load_image=True) 117 | 118 | # resample point cloud 119 | length = point_cloud.shape[0] 120 | npoints = 25000 121 | choice = np.random.choice(length, npoints, replace=True) 122 | point_cloud = point_cloud[choice] 123 | label = label[choice] 124 | 125 | pts_3d = point_cloud[:,:3] 126 | pcd = pcd_normalize(point_cloud) 127 | 128 | with log.Tick(): 129 | points = torch.from_numpy(pcd).unsqueeze(0).transpose(2, 1).cuda() 130 | 131 | with torch.no_grad(): 132 | if args.model_name == 'pointnet': 133 | logits, _ = model(points) 134 | else: 135 | logits = model(points) 136 | pred = logits[0].argmax(-1).cpu().numpy() 137 | 138 | print(index, pred.shape, end='') 139 | 140 | # pts_2d = kitti_utils.project_3d_to_2d(pts_3d) 141 | pts_2d = kitti_utils.torch_project_3d_to_2d(pts_3d) 142 | 143 | vis_handle.update(pts_3d, kitti_utils.colors[pred]) 144 | sem_img = kitti_utils.draw_2d_points(pts_2d, kitti_utils.colors_bgr[pred]) 145 | 146 | cv2.imshow('semantic', sem_img) 147 | cv2.imshow('frame', cv2.cvtColor(kitti_utils.frame,cv2.COLOR_BGR2RGB)) 148 | if 32 == cv2.waitKey(1): 149 | break 150 | 151 | if __name__ == '__main__': 152 | args = parse_args() 153 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 154 | vis(args) -------------------------------------------------------------------------------- /semseg.py: -------------------------------------------------------------------------------- 1 | import open3d 2 | import argparse 3 | import os 4 | import time 5 | import h5py 6 | import datetime 7 | import numpy as np 8 | from matplotlib import pyplot as plt 9 | import torch 10 | from torch.utils.data import DataLoader 11 | import torch.nn.functional as F 12 | 13 | import 
my_log as log 14 | from tqdm import tqdm 15 | 16 | from pcd_utils import test_semseg, select_avaliable, mkdir 17 | from data_utils.S3DISDataLoader import S3DISDataLoader, recognize_all_data, label_id_to_name 18 | from model.pointnet2 import PointNet2SemSeg 19 | from model.pointnet import PointNetSeg, feature_transform_reguliarzer 20 | 21 | def parse_args(notebook = False): 22 | parser = argparse.ArgumentParser('PointNet') 23 | parser.add_argument('--model_name', type=str, default='pointnet', help='pointnet or pointnet2') 24 | parser.add_argument('--mode', default='train', help='train or eval') 25 | parser.add_argument('--batch_size', type=int, default=16, help='input batch size') 26 | parser.add_argument('--workers', type=int, default=4, help='number of data loading workers') 27 | parser.add_argument('--epoch', type=int, default=100, help='number of epochs for training') 28 | parser.add_argument('--pretrain', type=str, default=None, help='whether use pretrain model') 29 | parser.add_argument('--gpu', type=str, default='0', help='specify gpu device') 30 | parser.add_argument('--learning_rate', type=float, default=0.001, help='learning rate for training') 31 | parser.add_argument('--decay_rate', type=float, default=1e-4, help='weight decay') 32 | parser.add_argument('--optimizer', type=str, default='Adam', help='type of optimizer') 33 | parser.add_argument('--augment', default=False, action='store_true', help="Enable data augmentation") 34 | if notebook: 35 | return parser.parse_args([]) 36 | else: 37 | return parser.parse_args() 38 | 39 | root = select_avaliable([ 40 | '/media/james/HDD/James_Least/Large_Dataset/ShapeNet/indoor3d_sem_seg_hdf5_data/', 41 | '/media/james/Ubuntu_Data/dataset/ShapeNet/indoor3d_sem_seg_hdf5_data/', 42 | '/media/james/MyPassport/James/dataset/ShapeNet/indoor3d_sem_seg_hdf5_data/', 43 | '/home/james/dataset/ShapeNet/indoor3d_sem_seg_hdf5_data/' 44 | ]) 45 | 46 | def _load(load_train = True): 47 | dataset_tmp = 'experiment/data/indoor3d_sem_seg_hdf5_data.h5' 48 | if not os.path.exists(dataset_tmp): 49 | log.info('Loading data...') 50 | train_data, train_label, test_data, test_label = recognize_all_data(root, test_area = 5) 51 | fp_h5 = h5py.File(dataset_tmp,"w") 52 | fp_h5.create_dataset('train_data', data = train_data) 53 | fp_h5.create_dataset('train_label', data = train_label) 54 | fp_h5.create_dataset('test_data', data = test_data) 55 | fp_h5.create_dataset('test_label', data = test_label) 56 | else: 57 | log.info('Loading from h5...') 58 | fp_h5 = h5py.File(dataset_tmp, 'r') 59 | if load_train: 60 | train_data = fp_h5.get('train_data')[()] 61 | train_label = fp_h5.get('train_label')[()] 62 | test_data = fp_h5.get('test_data')[()] 63 | test_label = fp_h5.get('test_label')[()] 64 | 65 | if load_train: 66 | log.info(train_data=train_data.shape, train_label=train_label.shape) 67 | log.info(test_data=test_data.shape, test_label=test_label.shape) 68 | return train_data, train_label, test_data, test_label 69 | else: 70 | log.info(test_data=test_data.shape, test_label=test_label.shape) 71 | return test_data, test_label 72 | 73 | def train(args): 74 | experiment_dir = mkdir('./experiment/') 75 | checkpoints_dir = mkdir('./experiment/semseg/%s/'%(args.model_name)) 76 | train_data, train_label, test_data, test_label = _load() 77 | 78 | dataset = S3DISDataLoader(train_data, train_label, data_augmentation = args.augment) 79 | dataloader = DataLoader(dataset, batch_size=args.batch_size,shuffle=True, num_workers=args.workers) 80 | 81 | test_dataset = S3DISDataLoader(test_data, 
test_label) 82 | testdataloader = DataLoader(test_dataset, batch_size=args.batch_size,shuffle=True, num_workers=args.workers) 83 | 84 | num_classes = 13 85 | if args.model_name == 'pointnet': 86 | model = PointNetSeg(num_classes, feature_transform=True, input_dims = 9) 87 | else: 88 | model = PointNet2SemSeg(num_classes, feature_dims = 6) 89 | 90 | torch.backends.cudnn.benchmark = True 91 | model = torch.nn.DataParallel(model).cuda() 92 | log.debug('Using gpu:',args.gpu) 93 | 94 | if args.pretrain is not None: 95 | log.debug('Use pretrain model...') 96 | model.load_state_dict(torch.load(args.pretrain)) 97 | init_epoch = int(args.pretrain[:-4].split('-')[-1]) 98 | log.debug('start epoch from', init_epoch) 99 | else: 100 | log.debug('Training from scratch') 101 | init_epoch = 0 102 | 103 | if args.optimizer == 'SGD': 104 | optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9) 105 | elif args.optimizer == 'Adam': 106 | optimizer = torch.optim.Adam( 107 | model.parameters(), 108 | lr=args.learning_rate, 109 | betas=(0.9, 0.999), 110 | eps=1e-08, 111 | weight_decay=args.decay_rate) 112 | 113 | scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5) 114 | LEARNING_RATE_CLIP = 1e-5 115 | 116 | history = {'loss':[]} 117 | best_acc = 0 118 | best_meaniou = 0 119 | 120 | for epoch in range(init_epoch,args.epoch): 121 | scheduler.step() 122 | lr = max(optimizer.param_groups[0]['lr'],LEARNING_RATE_CLIP) 123 | 124 | log.info(job='semseg',model=args.model_name,gpu=args.gpu,epoch='%d/%s' % (epoch, args.epoch),lr=lr) 125 | 126 | for param_group in optimizer.param_groups: 127 | param_group['lr'] = lr 128 | 129 | for points, target in tqdm(dataloader, total=len(dataloader), smoothing=0.9, dynamic_ncols=True): 130 | points, target = points.float(), target.long() 131 | points = points.transpose(2, 1) 132 | points, target = points.cuda(), target.cuda() 133 | optimizer.zero_grad() 134 | model = model.train() 135 | 136 | if args.model_name == 'pointnet': 137 | pred, trans_feat = model(points) 138 | else: 139 | pred = model(points) 140 | 141 | pred = pred.contiguous().view(-1, num_classes) 142 | target = target.view(-1, 1)[:, 0] 143 | loss = F.nll_loss(pred, target) 144 | 145 | if args.model_name == 'pointnet': 146 | loss += feature_transform_reguliarzer(trans_feat) * 0.001 147 | 148 | history['loss'].append(loss.cpu().data.numpy()) 149 | loss.backward() 150 | optimizer.step() 151 | 152 | log.debug('clear cuda cache') 153 | torch.cuda.empty_cache() 154 | 155 | test_metrics, cat_mean_iou = test_semseg( 156 | model.eval(), 157 | testdataloader, 158 | label_id_to_name, 159 | args.model_name, 160 | num_classes, 161 | ) 162 | mean_iou = np.mean(cat_mean_iou) 163 | 164 | save_model = False 165 | if test_metrics['accuracy'] > best_acc: 166 | best_acc = test_metrics['accuracy'] 167 | 168 | if mean_iou > best_meaniou: 169 | best_meaniou = mean_iou 170 | save_model = True 171 | 172 | if save_model: 173 | fn_pth = 'semseg-%s-%.5f-%04d.pth' % (args.model_name, best_meaniou, epoch) 174 | log.info('Save model...',fn = fn_pth) 175 | torch.save(model.state_dict(), os.path.join(checkpoints_dir, fn_pth)) 176 | log.warn(cat_mean_iou) 177 | else: 178 | log.info('No need to save model') 179 | log.warn(cat_mean_iou) 180 | 181 | log.warn('Curr',accuracy=test_metrics['accuracy'], meanIOU=mean_iou) 182 | log.warn('Best',accuracy=best_acc, meanIOU=best_meaniou) 183 | 184 | 185 | def evaluate(args): 186 | test_data, test_label = _load(load_train = False) 187 | test_dataset = 
S3DISDataLoader(test_data,test_label) 188 | testdataloader = DataLoader(test_dataset, batch_size=args.batch_size,shuffle=True, num_workers=args.workers) 189 | 190 | log.debug('Building Model', args.model_name) 191 | num_classes = 13 192 | if args.model_name == 'pointnet': 193 | model = PointNetSeg(num_classes, feature_transform=True, input_dims = 9) 194 | else: 195 | model = PointNet2SemSeg(num_classes, feature_dims = 6) 196 | 197 | torch.backends.cudnn.benchmark = True 198 | model = torch.nn.DataParallel(model).cuda() 199 | log.debug('Using gpu:',args.gpu) 200 | 201 | if args.pretrain is None: 202 | log.err('No pretrain model') 203 | return 204 | 205 | log.debug('Loading pretrain model...') 206 | state_dict = torch.load(args.pretrain) 207 | model.load_state_dict(state_dict) 208 | 209 | test_metrics, cat_mean_iou = test_semseg( 210 | model.eval(), 211 | testdataloader, 212 | label_id_to_name, 213 | args.model_name, 214 | num_classes, 215 | ) 216 | mean_iou = np.mean(cat_mean_iou) 217 | log.info(Test_accuracy=test_metrics['accuracy'], Test_meanIOU=mean_iou) 218 | 219 | def vis(args): 220 | test_data, test_label = _load(load_train = False) 221 | test_dataset = S3DISDataLoader(test_data,test_label) 222 | testdataloader = DataLoader(test_dataset, batch_size=args.batch_size,shuffle=False, num_workers=args.workers) 223 | 224 | log.debug('Building Model', args.model_name) 225 | num_classes = 13 226 | if args.model_name == 'pointnet2': 227 | model = PointNet2SemSeg(num_classes) 228 | else: 229 | model = PointNetSeg(num_classes, feature_transform=True, input_dims = 9) # same arguments as train()/evaluate() so the saved checkpoint loads 230 | 231 | torch.backends.cudnn.benchmark = True 232 | model = torch.nn.DataParallel(model) 233 | model.cuda() 234 | log.debug('Using gpu:',args.gpu) 235 | 236 | if args.pretrain is None: 237 | log.err('No pretrain model') 238 | return 239 | 240 | log.debug('Loading pretrain model...') 241 | checkpoint = torch.load(args.pretrain) 242 | model.load_state_dict(checkpoint) 243 | model.eval() 244 | 245 | cmap = plt.cm.get_cmap("hsv", 13) 246 | cmap = np.array([cmap(i) for i in range(13)])[:, :3] 247 | 248 | for batch_id, (points, target) in enumerate(testdataloader): 249 | log.info('Press space to exit','press Q for next frame') 250 | batchsize, num_point, _ = points.size() 251 | points, target = points.float(), target.long() 252 | points = points.transpose(2, 1) 253 | points, target = points.cuda(), target.cuda() 254 | if args.model_name == 'pointnet2': 255 | pred = model(points) 256 | else: 257 | pred, _ = model(points) 258 | 259 | points = points[:, :3, :].transpose(-1, 1) 260 | pred_choice = pred.data.max(-1)[1] 261 | 262 | for idx in range(batchsize): 263 | pt, gt, pred = points[idx], target[idx], pred_choice[idx] 264 | gt_color = cmap[gt.cpu().numpy() - 1, :] 265 | pred_color = cmap[pred.cpu().numpy() - 1, :] 266 | 267 | point_cloud = open3d.geometry.PointCloud() 268 | point_cloud.points = open3d.utility.Vector3dVector(pt.cpu().numpy()) 269 | point_cloud.colors = open3d.utility.Vector3dVector(gt_color) 270 | 271 | vis = open3d.visualization.VisualizerWithKeyCallback() 272 | vis.create_window() 273 | vis.get_render_option().background_color = np.asarray([0, 0, 0]) 274 | vis.add_geometry(point_cloud) 275 | 276 | vis.register_key_callback(32, lambda vis: exit()) 277 | vis.run() 278 | vis.destroy_window() 279 | 280 | if __name__ == '__main__': 281 | args = parse_args() 282 | os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu 283 | if args.mode == "train": 284 | train(args) 285 | if args.mode == "eval": 286 | evaluate(args) 287 | if 
args.mode == "vis": 288 | vis(args) 289 | --------------------------------------------------------------------------------