├── .gitignore ├── .gitmodules ├── LICENSE ├── README.md ├── configs ├── car_auto_T0_train_config ├── car_auto_T0_train_eval_config ├── car_auto_T0_train_train_config ├── car_auto_T1_train_config ├── car_auto_T1_train_eval_config ├── car_auto_T1_train_train_config ├── car_auto_T2_train_config ├── car_auto_T2_train_eval_config ├── car_auto_T2_train_train_config ├── car_auto_T3_train_config ├── car_auto_T3_train_eval_config ├── car_auto_T3_train_train_config ├── car_auto_T3_trainval_config ├── car_auto_T3_trainval_eval_config ├── car_auto_T3_trainval_train_config ├── car_fixed_T3_train_config ├── car_fixed_T3_train_eval_config ├── car_fixed_T3_train_train_config ├── ped_cyl_auto_T3_trainval_config ├── ped_cyl_auto_T3_trainval_eval_config └── ped_cyl_auto_T3_trainval_train_config ├── dataset └── kitti_dataset.py ├── eval.py ├── kitty_dataset.py ├── model.py ├── models ├── box_encoding.py ├── crop_aug.py ├── gnn.py ├── graph_gen.py ├── loss.py ├── models.py ├── nms.py └── preprocess.py ├── mytrain.py ├── run.py ├── scripts └── point_cloud_downsample.py ├── splits ├── train_car.txt ├── train_pedestrian_cyclist.txt ├── trainval_car.txt └── trainval_pedestrian_cyclist.txt ├── train.py ├── train.sh ├── train_tf.sh └── util ├── config_util.py ├── metrics.py ├── summary_util.py └── tf_util.py /.gitignore: -------------------------------------------------------------------------------- 1 | checkpoints* 2 | checkpoints/ 3 | saved_models 4 | checkpoints 5 | data 6 | *__pycache__* 7 | -------------------------------------------------------------------------------- /.gitmodules: -------------------------------------------------------------------------------- 1 | [submodule "kitti_native_evaluation"] 2 | path = kitti_native_evaluation 3 | url = https://github.com/asharakeh/kitti_native_evaluation 4 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2020 WeijingShi 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 
22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Point-GNN 2 | 3 | This repository is a PyTorch reimplementation of [Point-GNN: Graph Neural Network for 3D Object Detection in a Point Cloud](http://openaccess.thecvf.com/content_CVPR_2020/papers/Shi_Point-GNN_Graph_Neural_Network_for_3D_Object_Detection_in_a_CVPR_2020_paper.pdf), CVPR 2020. 4 | It is based on the original CVPR paper and the authors' [TensorFlow implementation](https://github.com/WeijingShi/Point-GNN/). 5 | 6 | Many thanks to the authors. If you find this code useful in your research, please consider citing their work: 7 | ``` 8 | @InProceedings{Point-GNN, 9 | author = {Shi, Weijing and Rajkumar, Ragunathan (Raj)}, 10 | title = {Point-GNN: Graph Neural Network for 3D Object Detection in a Point Cloud}, 11 | booktitle = {The IEEE Conference on Computer Vision and Pattern Recognition (CVPR)}, 12 | month = {June}, 13 | year = {2020} 14 | } 15 | ``` 16 | 17 | ## Getting Started 18 | 19 | ### Prerequisites 20 | 21 | ``` 22 | conda install pytorch torchvision 23 | 24 | ``` 25 | Install torch-scatter for your PyTorch version by following the instructions at https://github.com/rusty1s/pytorch_scatter. 26 | 27 | To install the other dependencies: 28 | ``` 29 | pip3 install --user opencv-python 30 | pip3 install --user open3d-python==0.7.0.0 31 | pip3 install --user scikit-learn 32 | pip3 install --user tqdm 33 | pip3 install --user shapely 34 | ``` 35 | 36 | ### KITTI Dataset 37 | 38 | We use the KITTI 3D Object Detection dataset. Please download the dataset from the KITTI website and also download the 3DOP train/val split [here](https://xiaozhichen.github.io/files/mv3d/imagesets.tar.gz). We provide extra split files for separate classes in [splits/](splits). We recommend the following file structure: 39 | 40 | DATASET_ROOT_DIR 41 | ├── image # Left color images 42 | │ ├── training 43 | | | └── image_2 44 | │ └── testing 45 | | └── image_2 46 | ├── velodyne # Velodyne point cloud files 47 | │ ├── training 48 | | | └── velodyne 49 | │ └── testing 50 | | └── velodyne 51 | ├── calib # Calibration files 52 | │ ├── training 53 | | | └── calib 54 | │ └── testing 55 | | └── calib 56 | ├── labels # Training labels 57 | │ └── training 58 | | └── label_2 59 | └── 3DOP_splits # split files. 60 | ├── train.txt 61 | ├── train_car.txt 62 | └── ...
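The training and evaluation scripts hard-code these sub-paths relative to the dataset root. As a quick sanity check of your layout, the sketch below mirrors the `KittiDataset` constructor call in `eval.py`; the dataset root, split file, and `num_classes` here are placeholders (the real values come from the command-line flags and the config files):

```python
import os
from dataset.kitti_dataset import KittiDataset

DATASET_DIR = '../dataset/kitti/'  # placeholder for your DATASET_ROOT_DIR

# Same sub-paths that eval.py joins onto the dataset root.
dataset = KittiDataset(
    os.path.join(DATASET_DIR, 'image/training/image_2'),
    os.path.join(DATASET_DIR, 'velodyne/training/velodyne/'),
    os.path.join(DATASET_DIR, 'calib/training/calib/'),
    os.path.join(DATASET_DIR, 'labels/training/label_2'),
    os.path.join(DATASET_DIR, '3DOP_splits/train_car.txt'),
    num_classes=4)  # placeholder; each config sets its own value
print('%d frames in split' % dataset.num_files)
```

If this raises an error or reports 0 frames, one of the directories above is misplaced.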
63 | 64 | ### Download Point-GNN 65 | 66 | Clone the repository recursively: 67 | ``` 68 | git clone https://github.com/Shudeng/Point-GNN.pytorch --recursive 69 | ``` 70 | 71 | ### Training 72 | ``` 73 | bash train.sh 74 | ``` 75 | 76 | 77 | ## License 78 | 79 | This project is licensed under the MIT License - see the [LICENSE](LICENSE) file for details 80 | 81 | 82 | 83 | 180 | -------------------------------------------------------------------------------- /configs/car_auto_T0_train_config: -------------------------------------------------------------------------------- 1 | { 2 | "box_encoding_method": "classaware_all_class_box_encoding", 3 | "downsample_by_voxel_size": null, 4 | "eval_is_training": true, 5 | "graph_gen_kwargs": { 6 | "add_rnd3d": true, 7 | "base_voxel_size": 0.8, 8 | "downsample_method": "random", 9 | "level_configs": [ 10 | { 11 | "graph_gen_kwargs": { 12 | "num_neighbors": -1, 13 | "radius": 1.0 14 | }, 15 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 16 | "graph_level": 0, 17 | "graph_scale": 1 18 | }, 19 | { 20 | "graph_gen_kwargs": { 21 | "num_neighbors": 256, 22 | "radius": 4.0 23 | }, 24 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 25 | "graph_level": 1, 26 | "graph_scale": 1 27 | } 28 | ] 29 | }, 30 | "graph_gen_method": "multi_level_local_graph_v3", 31 | "input_features": "i", 32 | "label_method": "Car", 33 | "loss": { 34 | "cls_loss_type": "softmax", 35 | "cls_loss_weight": 0.1, 36 | "loc_loss_weight": 10.0 37 | }, 38 | "model_kwargs": { 39 | "layer_configs": [ 40 | { 41 | "graph_level": 0, 42 | "kwargs": { 43 | "output_MLP_activation_type": "ReLU", 44 | "output_MLP_depth_list": [ 45 | 300, 46 | 300 47 | ], 48 | "output_MLP_normalization_type": "NONE", 49 | "point_MLP_activation_type": "ReLU", 50 | "point_MLP_depth_list": [ 51 | 32, 52 | 64, 53 | 128, 54 | 300 55 | ], 56 | "point_MLP_normalization_type": "NONE" 57 | }, 58 | "scope": "layer1", 59 | "type": "scatter_max_point_set_pooling" 60 | }, 61 | { 62 | "graph_level": 1, 63 | "kwargs": { 64 | "activation_type": "ReLU", 65 | "normalization_type": "NONE" 66 | }, 67 | "scope": "output", 68 | "type": "classaware_predictor" 69 | } 70 | ], 71 | "regularizer_kwargs": { 72 | "scale": 5e-07 73 | }, 74 | "regularizer_type": "l1" 75 | }, 76 | "model_name": "multi_layer_fast_local_graph_model_v2", 77 | "nms_overlapped_thres": 0.01, 78 | "num_classes": 4, 79 | "runtime_graph_gen_kwargs": { 80 | "add_rnd3d": false, 81 | "base_voxel_size": 0.8, 82 | "level_configs": [ 83 | { 84 | "graph_gen_kwargs": { 85 | "num_neighbors": -1, 86 | "radius": 1.0 87 | }, 88 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 89 | "graph_level": 0, 90 | "graph_scale": 0.5 91 | }, 92 | { 93 | "graph_gen_kwargs": { 94 | "num_neighbors": -1, 95 | "radius": 4.0 96 | }, 97 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 98 | "graph_level": 1, 99 | "graph_scale": 0.5 100 | } 101 | ] 102 | } 103 | } -------------------------------------------------------------------------------- /configs/car_auto_T0_train_eval_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_TEST_SAMPLE": -1, 3 | "checkpoint_path": "model", 4 | "config_path": "config", 5 | "data_aug_configs": [], 6 | "eval_dataset": "val.txt", 7 | "eval_dir": "./checkpoints/car_auto_T0_train_eval", 8 | "eval_every_second": 60, 9 | "gpu_memusage": -1, 10 | "max_step": 1400170, 11 | "train_dir": "./checkpoints/car_auto_T0_train", 12 | "visualization": false 13 | } 14 | 
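Each model is described by three files in `configs/`: a model config (e.g. `car_auto_T0_train_config`, defining graph generation, layers, and loss), a `*_train_config` (optimizer, augmentation, and the checkpoint `train_dir`), and a `*_eval_config` that points back at the same `train_dir`. `eval.py` waits for the model config to appear inside `train_dir` under the name given by `config_path` (it is presumably written there by the training run). A minimal sketch of how the files are resolved, assuming the default paths above:

```python
import os
from util.config_util import load_config, load_train_config

# Run-time evaluation settings.
eval_config = load_train_config('configs/car_auto_T0_train_eval_config')

# Model config written by the training run,
# e.g. ./checkpoints/car_auto_T0_train/config
config_path = os.path.join(eval_config['train_dir'],
                           eval_config['config_path'])
config = load_config(config_path)
print(config['model_name'], config['num_classes'])
```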
-------------------------------------------------------------------------------- /configs/car_auto_T0_train_train_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_GPU": 2, 3 | "NUM_TEST_SAMPLE": -1, 4 | "batch_size": 4, 5 | "capacity": 1, 6 | "checkpoint_path": "model", 7 | "config_path": "config", 8 | "data_aug_configs": [ 9 | { 10 | "method_kwargs": { 11 | "expend_factor": [ 12 | 1.0, 13 | 1.0, 14 | 1.0 15 | ], 16 | "method_name": "normal", 17 | "yaw_std": 0.39269908169872414 18 | }, 19 | "method_name": "random_rotation_all" 20 | }, 21 | { 22 | "method_kwargs": { 23 | "flip_prob": 0.5 24 | }, 25 | "method_name": "random_flip_all" 26 | }, 27 | { 28 | "method_kwargs": { 29 | "appr_factor": 10, 30 | "expend_factor": [ 31 | 1.1, 32 | 1.1, 33 | 1.1 34 | ], 35 | "max_overlap_num_allowed": 100, 36 | "max_overlap_rate": 0.01, 37 | "max_trails": 100, 38 | "method_name": "normal", 39 | "xyz_std": [ 40 | 3, 41 | 0, 42 | 3 43 | ] 44 | }, 45 | "method_name": "random_box_shift" 46 | } 47 | ], 48 | "decay_factor": 0.1, 49 | "decay_step": 400000, 50 | "gpu_memusage": -1, 51 | "initial_lr": 0.125, 52 | "load_dataset_every_N_time": 0, 53 | "load_dataset_to_mem": true, 54 | "max_epoch": 1718, 55 | "max_steps": 1400000, 56 | "num_load_dataset_workers": 16, 57 | "optimizer": "sgd", 58 | "optimizer_kwargs": {}, 59 | "save_every_epoch": 20, 60 | "train_dataset": "train_car.txt", 61 | "train_dir": "./checkpoints/car_auto_T0_train", 62 | "unify_copies": true, 63 | "visualization": false 64 | } 65 | -------------------------------------------------------------------------------- /configs/car_auto_T1_train_config: -------------------------------------------------------------------------------- 1 | { 2 | "box_encoding_method": "classaware_all_class_box_encoding", 3 | "downsample_by_voxel_size": null, 4 | "eval_is_training": true, 5 | "graph_gen_kwargs": { 6 | "add_rnd3d": true, 7 | "base_voxel_size": 0.8, 8 | "downsample_method": "random", 9 | "level_configs": [ 10 | { 11 | "graph_gen_kwargs": { 12 | "num_neighbors": -1, 13 | "radius": 1.0 14 | }, 15 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 16 | "graph_level": 0, 17 | "graph_scale": 1 18 | }, 19 | { 20 | "graph_gen_kwargs": { 21 | "num_neighbors": 256, 22 | "radius": 4.0 23 | }, 24 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 25 | "graph_level": 1, 26 | "graph_scale": 1 27 | } 28 | ] 29 | }, 30 | "graph_gen_method": "multi_level_local_graph_v3", 31 | "input_features": "i", 32 | "label_method": "Car", 33 | "loss": { 34 | "cls_loss_type": "softmax", 35 | "cls_loss_weight": 0.1, 36 | "loc_loss_weight": 10.0 37 | }, 38 | "model_kwargs": { 39 | "layer_configs": [ 40 | { 41 | "graph_level": 0, 42 | "kwargs": { 43 | "output_MLP_activation_type": "ReLU", 44 | "output_MLP_depth_list": [ 45 | 300, 46 | 300 47 | ], 48 | "output_MLP_normalization_type": "NONE", 49 | "point_MLP_activation_type": "ReLU", 50 | "point_MLP_depth_list": [ 51 | 32, 52 | 64, 53 | 128, 54 | 300 55 | ], 56 | "point_MLP_normalization_type": "NONE" 57 | }, 58 | "scope": "layer1", 59 | "type": "scatter_max_point_set_pooling" 60 | }, 61 | { 62 | "graph_level": 1, 63 | "kwargs": { 64 | "auto_offset": true, 65 | "auto_offset_MLP_depth_list": [ 66 | 64, 67 | 3 68 | ], 69 | "auto_offset_MLP_feature_activation_type": "ReLU", 70 | "auto_offset_MLP_normalization_type": "NONE", 71 | "edge_MLP_activation_type": "ReLU", 72 | "edge_MLP_depth_list": [ 73 | 300, 74 | 300 75 | ], 76 | "edge_MLP_normalization_type": "NONE", 77 | 
"update_MLP_activation_type": "ReLU", 78 | "update_MLP_depth_list": [ 79 | 300, 80 | 300 81 | ], 82 | "update_MLP_normalization_type": "NONE" 83 | }, 84 | "scope": "layer2", 85 | "type": "scatter_max_graph_auto_center_net" 86 | }, 87 | { 88 | "graph_level": 1, 89 | "kwargs": { 90 | "activation_type": "ReLU", 91 | "normalization_type": "NONE" 92 | }, 93 | "scope": "output", 94 | "type": "classaware_predictor" 95 | } 96 | ], 97 | "regularizer_kwargs": { 98 | "scale": 5e-07 99 | }, 100 | "regularizer_type": "l1" 101 | }, 102 | "model_name": "multi_layer_fast_local_graph_model_v2", 103 | "nms_overlapped_thres": 0.01, 104 | "num_classes": 4, 105 | "runtime_graph_gen_kwargs": { 106 | "add_rnd3d": false, 107 | "base_voxel_size": 0.8, 108 | "level_configs": [ 109 | { 110 | "graph_gen_kwargs": { 111 | "num_neighbors": -1, 112 | "radius": 1.0 113 | }, 114 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 115 | "graph_level": 0, 116 | "graph_scale": 0.5 117 | }, 118 | { 119 | "graph_gen_kwargs": { 120 | "num_neighbors": -1, 121 | "radius": 4.0 122 | }, 123 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 124 | "graph_level": 1, 125 | "graph_scale": 0.5 126 | } 127 | ] 128 | } 129 | } -------------------------------------------------------------------------------- /configs/car_auto_T1_train_eval_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_TEST_SAMPLE": -1, 3 | "checkpoint_path": "model", 4 | "config_path": "config", 5 | "data_aug_configs": [], 6 | "eval_dataset": "val.txt", 7 | "eval_dir": "./checkpoints/car_auto_T1_train_eval", 8 | "eval_every_second": 60, 9 | "gpu_memusage": -1, 10 | "max_step": 1400170, 11 | "train_dir": "./checkpoints/car_auto_T1_train", 12 | "visualization": false 13 | } 14 | -------------------------------------------------------------------------------- /configs/car_auto_T1_train_train_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_GPU": 2, 3 | "NUM_TEST_SAMPLE": -1, 4 | "batch_size": 4, 5 | "capacity": 1, 6 | "checkpoint_path": "model", 7 | "config_path": "config", 8 | "data_aug_configs": [ 9 | { 10 | "method_kwargs": { 11 | "expend_factor": [ 12 | 1.0, 13 | 1.0, 14 | 1.0 15 | ], 16 | "method_name": "normal", 17 | "yaw_std": 0.39269908169872414 18 | }, 19 | "method_name": "random_rotation_all" 20 | }, 21 | { 22 | "method_kwargs": { 23 | "flip_prob": 0.5 24 | }, 25 | "method_name": "random_flip_all" 26 | }, 27 | { 28 | "method_kwargs": { 29 | "appr_factor": 10, 30 | "expend_factor": [ 31 | 1.1, 32 | 1.1, 33 | 1.1 34 | ], 35 | "max_overlap_num_allowed": 100, 36 | "max_overlap_rate": 0.01, 37 | "max_trails": 100, 38 | "method_name": "normal", 39 | "xyz_std": [ 40 | 3, 41 | 0, 42 | 3 43 | ] 44 | }, 45 | "method_name": "random_box_shift" 46 | } 47 | ], 48 | "decay_factor": 0.1, 49 | "decay_step": 400000, 50 | "gpu_memusage": -1, 51 | "initial_lr": 0.125, 52 | "load_dataset_every_N_time": 0, 53 | "load_dataset_to_mem": true, 54 | "max_epoch": 1718, 55 | "max_steps": 1400000, 56 | "num_load_dataset_workers": 16, 57 | "optimizer": "sgd", 58 | "optimizer_kwargs": {}, 59 | "save_every_epoch": 20, 60 | "train_dataset": "train_car.txt", 61 | "train_dir": "./checkpoints/car_auto_T1_train", 62 | "unify_copies": true, 63 | "visualization": false 64 | } 65 | -------------------------------------------------------------------------------- /configs/car_auto_T2_train_config: -------------------------------------------------------------------------------- 1 | { 2 | 
"box_encoding_method": "classaware_all_class_box_encoding", 3 | "downsample_by_voxel_size": null, 4 | "eval_is_training": true, 5 | "graph_gen_kwargs": { 6 | "add_rnd3d": true, 7 | "base_voxel_size": 0.8, 8 | "downsample_method": "random", 9 | "level_configs": [ 10 | { 11 | "graph_gen_kwargs": { 12 | "num_neighbors": -1, 13 | "radius": 1.0 14 | }, 15 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 16 | "graph_level": 0, 17 | "graph_scale": 1 18 | }, 19 | { 20 | "graph_gen_kwargs": { 21 | "num_neighbors": 256, 22 | "radius": 4.0 23 | }, 24 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 25 | "graph_level": 1, 26 | "graph_scale": 1 27 | } 28 | ] 29 | }, 30 | "graph_gen_method": "multi_level_local_graph_v3", 31 | "input_features": "i", 32 | "label_method": "Car", 33 | "loss": { 34 | "cls_loss_type": "softmax", 35 | "cls_loss_weight": 0.1, 36 | "loc_loss_weight": 10.0 37 | }, 38 | "model_kwargs": { 39 | "layer_configs": [ 40 | { 41 | "graph_level": 0, 42 | "kwargs": { 43 | "output_MLP_activation_type": "ReLU", 44 | "output_MLP_depth_list": [ 45 | 300, 46 | 300 47 | ], 48 | "output_MLP_normalization_type": "NONE", 49 | "point_MLP_activation_type": "ReLU", 50 | "point_MLP_depth_list": [ 51 | 32, 52 | 64, 53 | 128, 54 | 300 55 | ], 56 | "point_MLP_normalization_type": "NONE" 57 | }, 58 | "scope": "layer1", 59 | "type": "scatter_max_point_set_pooling" 60 | }, 61 | { 62 | "graph_level": 1, 63 | "kwargs": { 64 | "auto_offset": true, 65 | "auto_offset_MLP_depth_list": [ 66 | 64, 67 | 3 68 | ], 69 | "auto_offset_MLP_feature_activation_type": "ReLU", 70 | "auto_offset_MLP_normalization_type": "NONE", 71 | "edge_MLP_activation_type": "ReLU", 72 | "edge_MLP_depth_list": [ 73 | 300, 74 | 300 75 | ], 76 | "edge_MLP_normalization_type": "NONE", 77 | "update_MLP_activation_type": "ReLU", 78 | "update_MLP_depth_list": [ 79 | 300, 80 | 300 81 | ], 82 | "update_MLP_normalization_type": "NONE" 83 | }, 84 | "scope": "layer2", 85 | "type": "scatter_max_graph_auto_center_net" 86 | }, 87 | { 88 | "graph_level": 1, 89 | "kwargs": { 90 | "auto_offset": true, 91 | "auto_offset_MLP_depth_list": [ 92 | 64, 93 | 3 94 | ], 95 | "auto_offset_MLP_feature_activation_type": "ReLU", 96 | "auto_offset_MLP_normalization_type": "NONE", 97 | "edge_MLP_activation_type": "ReLU", 98 | "edge_MLP_depth_list": [ 99 | 300, 100 | 300 101 | ], 102 | "edge_MLP_normalization_type": "NONE", 103 | "update_MLP_activation_type": "ReLU", 104 | "update_MLP_depth_list": [ 105 | 300, 106 | 300 107 | ], 108 | "update_MLP_normalization_type": "NONE" 109 | }, 110 | "scope": "layer3", 111 | "type": "scatter_max_graph_auto_center_net" 112 | }, 113 | { 114 | "graph_level": 1, 115 | "kwargs": { 116 | "activation_type": "ReLU", 117 | "normalization_type": "NONE" 118 | }, 119 | "scope": "output", 120 | "type": "classaware_predictor" 121 | } 122 | ], 123 | "regularizer_kwargs": { 124 | "scale": 5e-07 125 | }, 126 | "regularizer_type": "l1" 127 | }, 128 | "model_name": "multi_layer_fast_local_graph_model_v2", 129 | "nms_overlapped_thres": 0.01, 130 | "num_classes": 4, 131 | "runtime_graph_gen_kwargs": { 132 | "add_rnd3d": false, 133 | "base_voxel_size": 0.8, 134 | "level_configs": [ 135 | { 136 | "graph_gen_kwargs": { 137 | "num_neighbors": -1, 138 | "radius": 1.0 139 | }, 140 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 141 | "graph_level": 0, 142 | "graph_scale": 0.5 143 | }, 144 | { 145 | "graph_gen_kwargs": { 146 | "num_neighbors": -1, 147 | "radius": 4.0 148 | }, 149 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 150 | 
"graph_level": 1, 151 | "graph_scale": 0.5 152 | } 153 | ] 154 | } 155 | } -------------------------------------------------------------------------------- /configs/car_auto_T2_train_eval_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_TEST_SAMPLE": -1, 3 | "checkpoint_path": "model", 4 | "config_path": "config", 5 | "data_aug_configs": [], 6 | "eval_dataset": "val.txt", 7 | "eval_dir": "./checkpoints/car_auto_T2_train_eval", 8 | "eval_every_second": 60, 9 | "gpu_memusage": -1, 10 | "max_step": 1400170, 11 | "train_dir": "./checkpoints/car_auto_T2_train", 12 | "visualization": false 13 | } 14 | -------------------------------------------------------------------------------- /configs/car_auto_T2_train_train_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_GPU": 2, 3 | "NUM_TEST_SAMPLE": -1, 4 | "batch_size": 4, 5 | "capacity": 1, 6 | "checkpoint_path": "model", 7 | "config_path": "config", 8 | "data_aug_configs": [ 9 | { 10 | "method_kwargs": { 11 | "expend_factor": [ 12 | 1.0, 13 | 1.0, 14 | 1.0 15 | ], 16 | "method_name": "normal", 17 | "yaw_std": 0.39269908169872414 18 | }, 19 | "method_name": "random_rotation_all" 20 | }, 21 | { 22 | "method_kwargs": { 23 | "flip_prob": 0.5 24 | }, 25 | "method_name": "random_flip_all" 26 | }, 27 | { 28 | "method_kwargs": { 29 | "appr_factor": 10, 30 | "expend_factor": [ 31 | 1.1, 32 | 1.1, 33 | 1.1 34 | ], 35 | "max_overlap_num_allowed": 100, 36 | "max_overlap_rate": 0.01, 37 | "max_trails": 100, 38 | "method_name": "normal", 39 | "xyz_std": [ 40 | 3, 41 | 0, 42 | 3 43 | ] 44 | }, 45 | "method_name": "random_box_shift" 46 | } 47 | ], 48 | "decay_factor": 0.1, 49 | "decay_step": 400000, 50 | "gpu_memusage": -1, 51 | "initial_lr": 0.125, 52 | "load_dataset_every_N_time": 0, 53 | "load_dataset_to_mem": true, 54 | "max_epoch": 1718, 55 | "max_steps": 1400000, 56 | "num_load_dataset_workers": 16, 57 | "optimizer": "sgd", 58 | "optimizer_kwargs": {}, 59 | "save_every_epoch": 20, 60 | "train_dataset": "train_car.txt", 61 | "train_dir": "./checkpoints/car_auto_T2_train", 62 | "unify_copies": true, 63 | "visualization": false 64 | } 65 | -------------------------------------------------------------------------------- /configs/car_auto_T3_train_config: -------------------------------------------------------------------------------- 1 | { 2 | "box_encoding_method": "classaware_all_class_box_encoding", 3 | "downsample_by_voxel_size": null, 4 | "eval_is_training": true, 5 | "graph_gen_kwargs": { 6 | "add_rnd3d": true, 7 | "base_voxel_size": 0.8, 8 | "downsample_method": "random", 9 | "level_configs": [ 10 | { 11 | "graph_gen_kwargs": { 12 | "num_neighbors": -1, 13 | "radius": 1.0 14 | }, 15 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 16 | "graph_level": 0, 17 | "graph_scale": 1 18 | }, 19 | { 20 | "graph_gen_kwargs": { 21 | "num_neighbors": 256, 22 | "radius": 4.0 23 | }, 24 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 25 | "graph_level": 1, 26 | "graph_scale": 1 27 | } 28 | ] 29 | }, 30 | "graph_gen_method": "multi_level_local_graph_v3", 31 | "input_features": "i", 32 | "label_method": "Car", 33 | "loss": { 34 | "cls_loss_type": "softmax", 35 | "cls_loss_weight": 0.1, 36 | "loc_loss_weight": 10.0 37 | }, 38 | "model_kwargs": { 39 | "layer_configs": [ 40 | { 41 | "graph_level": 0, 42 | "kwargs": { 43 | "output_MLP_activation_type": "ReLU", 44 | "output_MLP_depth_list": [ 45 | 300, 46 | 300 47 | ], 48 | "output_MLP_normalization_type": 
"NONE", 49 | "point_MLP_activation_type": "ReLU", 50 | "point_MLP_depth_list": [ 51 | 32, 52 | 64, 53 | 128, 54 | 300 55 | ], 56 | "point_MLP_normalization_type": "NONE" 57 | }, 58 | "scope": "layer1", 59 | "type": "scatter_max_point_set_pooling" 60 | }, 61 | { 62 | "graph_level": 1, 63 | "kwargs": { 64 | "auto_offset": true, 65 | "auto_offset_MLP_depth_list": [ 66 | 64, 67 | 3 68 | ], 69 | "auto_offset_MLP_feature_activation_type": "ReLU", 70 | "auto_offset_MLP_normalization_type": "NONE", 71 | "edge_MLP_activation_type": "ReLU", 72 | "edge_MLP_depth_list": [ 73 | 300, 74 | 300 75 | ], 76 | "edge_MLP_normalization_type": "NONE", 77 | "update_MLP_activation_type": "ReLU", 78 | "update_MLP_depth_list": [ 79 | 300, 80 | 300 81 | ], 82 | "update_MLP_normalization_type": "NONE" 83 | }, 84 | "scope": "layer2", 85 | "type": "scatter_max_graph_auto_center_net" 86 | }, 87 | { 88 | "graph_level": 1, 89 | "kwargs": { 90 | "auto_offset": true, 91 | "auto_offset_MLP_depth_list": [ 92 | 64, 93 | 3 94 | ], 95 | "auto_offset_MLP_feature_activation_type": "ReLU", 96 | "auto_offset_MLP_normalization_type": "NONE", 97 | "edge_MLP_activation_type": "ReLU", 98 | "edge_MLP_depth_list": [ 99 | 300, 100 | 300 101 | ], 102 | "edge_MLP_normalization_type": "NONE", 103 | "update_MLP_activation_type": "ReLU", 104 | "update_MLP_depth_list": [ 105 | 300, 106 | 300 107 | ], 108 | "update_MLP_normalization_type": "NONE" 109 | }, 110 | "scope": "layer3", 111 | "type": "scatter_max_graph_auto_center_net" 112 | }, 113 | { 114 | "graph_level": 1, 115 | "kwargs": { 116 | "auto_offset": true, 117 | "auto_offset_MLP_depth_list": [ 118 | 64, 119 | 3 120 | ], 121 | "auto_offset_MLP_feature_activation_type": "ReLU", 122 | "auto_offset_MLP_normalization_type": "NONE", 123 | "edge_MLP_activation_type": "ReLU", 124 | "edge_MLP_depth_list": [ 125 | 300, 126 | 300 127 | ], 128 | "edge_MLP_normalization_type": "NONE", 129 | "update_MLP_activation_type": "ReLU", 130 | "update_MLP_depth_list": [ 131 | 300, 132 | 300 133 | ], 134 | "update_MLP_normalization_type": "NONE" 135 | }, 136 | "scope": "layer4", 137 | "type": "scatter_max_graph_auto_center_net" 138 | }, 139 | { 140 | "graph_level": 1, 141 | "kwargs": { 142 | "activation_type": "ReLU", 143 | "normalization_type": "NONE" 144 | }, 145 | "scope": "output", 146 | "type": "classaware_predictor" 147 | } 148 | ], 149 | "regularizer_kwargs": { 150 | "scale": 5e-07 151 | }, 152 | "regularizer_type": "l1" 153 | }, 154 | "model_name": "multi_layer_fast_local_graph_model_v2", 155 | "nms_overlapped_thres": 0.01, 156 | "num_classes": 4, 157 | "runtime_graph_gen_kwargs": { 158 | "add_rnd3d": false, 159 | "base_voxel_size": 0.8, 160 | "level_configs": [ 161 | { 162 | "graph_gen_kwargs": { 163 | "num_neighbors": -1, 164 | "radius": 1.0 165 | }, 166 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 167 | "graph_level": 0, 168 | "graph_scale": 0.5 169 | }, 170 | { 171 | "graph_gen_kwargs": { 172 | "num_neighbors": -1, 173 | "radius": 4.0 174 | }, 175 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 176 | "graph_level": 1, 177 | "graph_scale": 0.5 178 | } 179 | ] 180 | } 181 | } -------------------------------------------------------------------------------- /configs/car_auto_T3_train_eval_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_TEST_SAMPLE": -1, 3 | "checkpoint_path": "model", 4 | "config_path": "config", 5 | "data_aug_configs": [], 6 | "eval_dataset": "val.txt", 7 | "eval_dir": "./checkpoints/car_auto_T3_train_eval", 8 | 
"eval_every_second": 60, 9 | "gpu_memusage": -1, 10 | "max_step": 1400170, 11 | "train_dir": "./checkpoints/car_auto_T3_train", 12 | "visualization": false 13 | } 14 | -------------------------------------------------------------------------------- /configs/car_auto_T3_train_train_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_GPU": 2, 3 | "NUM_TEST_SAMPLE": -1, 4 | "batch_size": 4, 5 | "capacity": 1, 6 | "checkpoint_path": "model", 7 | "config_path": "config", 8 | "data_aug_configs": [ 9 | { 10 | "method_kwargs": { 11 | "expend_factor": [ 12 | 1.0, 13 | 1.0, 14 | 1.0 15 | ], 16 | "method_name": "normal", 17 | "yaw_std": 0.39269908169872414 18 | }, 19 | "method_name": "random_rotation_all" 20 | }, 21 | { 22 | "method_kwargs": { 23 | "flip_prob": 0.5 24 | }, 25 | "method_name": "random_flip_all" 26 | }, 27 | { 28 | "method_kwargs": { 29 | "appr_factor": 10, 30 | "expend_factor": [ 31 | 1.1, 32 | 1.1, 33 | 1.1 34 | ], 35 | "max_overlap_num_allowed": 100, 36 | "max_overlap_rate": 0.01, 37 | "max_trails": 100, 38 | "method_name": "normal", 39 | "xyz_std": [ 40 | 3, 41 | 0, 42 | 3 43 | ] 44 | }, 45 | "method_name": "random_box_shift" 46 | } 47 | ], 48 | "decay_factor": 0.1, 49 | "decay_step": 400000, 50 | "gpu_memusage": -1, 51 | "initial_lr": 0.125, 52 | "load_dataset_every_N_time": 0, 53 | "load_dataset_to_mem": true, 54 | "max_epoch": 1718, 55 | "max_steps": 1400000, 56 | "num_load_dataset_workers": 16, 57 | "optimizer": "sgd", 58 | "optimizer_kwargs": {}, 59 | "save_every_epoch": 20, 60 | "train_dataset": "train_car.txt", 61 | "train_dir": "./checkpoints/car_auto_T3_train", 62 | "unify_copies": true, 63 | "visualization": false 64 | } 65 | -------------------------------------------------------------------------------- /configs/car_auto_T3_trainval_config: -------------------------------------------------------------------------------- 1 | { 2 | "box_encoding_method": "classaware_all_class_box_encoding", 3 | "downsample_by_voxel_size": null, 4 | "eval_is_training": true, 5 | "graph_gen_kwargs": { 6 | "add_rnd3d": true, 7 | "base_voxel_size": 0.8, 8 | "downsample_method": "random", 9 | "level_configs": [ 10 | { 11 | "graph_gen_kwargs": { 12 | "num_neighbors": -1, 13 | "radius": 1.0 14 | }, 15 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 16 | "graph_level": 0, 17 | "graph_scale": 1 18 | }, 19 | { 20 | "graph_gen_kwargs": { 21 | "num_neighbors": 256, 22 | "radius": 4.0 23 | }, 24 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 25 | "graph_level": 1, 26 | "graph_scale": 1 27 | } 28 | ] 29 | }, 30 | "graph_gen_method": "multi_level_local_graph_v3", 31 | "input_features": "i", 32 | "label_method": "Car", 33 | "loss": { 34 | "cls_loss_type": "softmax", 35 | "cls_loss_weight": 0.1, 36 | "loc_loss_weight": 10.0 37 | }, 38 | "model_kwargs": { 39 | "layer_configs": [ 40 | { 41 | "graph_level": 0, 42 | "kwargs": { 43 | "output_MLP_activation_type": "ReLU", 44 | "output_MLP_depth_list": [ 45 | 300, 46 | 300 47 | ], 48 | "output_MLP_normalization_type": "NONE", 49 | "point_MLP_activation_type": "ReLU", 50 | "point_MLP_depth_list": [ 51 | 32, 52 | 64, 53 | 128, 54 | 300 55 | ], 56 | "point_MLP_normalization_type": "NONE" 57 | }, 58 | "scope": "layer1", 59 | "type": "scatter_max_point_set_pooling" 60 | }, 61 | { 62 | "graph_level": 1, 63 | "kwargs": { 64 | "auto_offset": true, 65 | "auto_offset_MLP_depth_list": [ 66 | 64, 67 | 3 68 | ], 69 | "auto_offset_MLP_feature_activation_type": "ReLU", 70 | "auto_offset_MLP_normalization_type": 
"NONE", 71 | "edge_MLP_activation_type": "ReLU", 72 | "edge_MLP_depth_list": [ 73 | 300, 74 | 300 75 | ], 76 | "edge_MLP_normalization_type": "NONE", 77 | "update_MLP_activation_type": "ReLU", 78 | "update_MLP_depth_list": [ 79 | 300, 80 | 300 81 | ], 82 | "update_MLP_normalization_type": "NONE" 83 | }, 84 | "scope": "layer2", 85 | "type": "scatter_max_graph_auto_center_net" 86 | }, 87 | { 88 | "graph_level": 1, 89 | "kwargs": { 90 | "auto_offset": true, 91 | "auto_offset_MLP_depth_list": [ 92 | 64, 93 | 3 94 | ], 95 | "auto_offset_MLP_feature_activation_type": "ReLU", 96 | "auto_offset_MLP_normalization_type": "NONE", 97 | "edge_MLP_activation_type": "ReLU", 98 | "edge_MLP_depth_list": [ 99 | 300, 100 | 300 101 | ], 102 | "edge_MLP_normalization_type": "NONE", 103 | "update_MLP_activation_type": "ReLU", 104 | "update_MLP_depth_list": [ 105 | 300, 106 | 300 107 | ], 108 | "update_MLP_normalization_type": "NONE" 109 | }, 110 | "scope": "layer3", 111 | "type": "scatter_max_graph_auto_center_net" 112 | }, 113 | { 114 | "graph_level": 1, 115 | "kwargs": { 116 | "auto_offset": true, 117 | "auto_offset_MLP_depth_list": [ 118 | 64, 119 | 3 120 | ], 121 | "auto_offset_MLP_feature_activation_type": "ReLU", 122 | "auto_offset_MLP_normalization_type": "NONE", 123 | "edge_MLP_activation_type": "ReLU", 124 | "edge_MLP_depth_list": [ 125 | 300, 126 | 300 127 | ], 128 | "edge_MLP_normalization_type": "NONE", 129 | "update_MLP_activation_type": "ReLU", 130 | "update_MLP_depth_list": [ 131 | 300, 132 | 300 133 | ], 134 | "update_MLP_normalization_type": "NONE" 135 | }, 136 | "scope": "layer4", 137 | "type": "scatter_max_graph_auto_center_net" 138 | }, 139 | { 140 | "graph_level": 1, 141 | "kwargs": { 142 | "activation_type": "ReLU", 143 | "normalization_type": "NONE" 144 | }, 145 | "scope": "output", 146 | "type": "classaware_predictor" 147 | } 148 | ], 149 | "regularizer_kwargs": { 150 | "scale": 5e-07 151 | }, 152 | "regularizer_type": "l1" 153 | }, 154 | "model_name": "multi_layer_fast_local_graph_model_v2", 155 | "nms_overlapped_thres": 0.01, 156 | "num_classes": 4, 157 | "runtime_graph_gen_kwargs": { 158 | "add_rnd3d": false, 159 | "base_voxel_size": 0.8, 160 | "level_configs": [ 161 | { 162 | "graph_gen_kwargs": { 163 | "num_neighbors": -1, 164 | "radius": 1.0 165 | }, 166 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 167 | "graph_level": 0, 168 | "graph_scale": 0.5 169 | }, 170 | { 171 | "graph_gen_kwargs": { 172 | "num_neighbors": -1, 173 | "radius": 4.0 174 | }, 175 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 176 | "graph_level": 1, 177 | "graph_scale": 0.5 178 | } 179 | ] 180 | } 181 | } -------------------------------------------------------------------------------- /configs/car_auto_T3_trainval_eval_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_TEST_SAMPLE": -1, 3 | "checkpoint_path": "model", 4 | "config_path": "config", 5 | "data_aug_configs": [], 6 | "eval_dataset": "val.txt", 7 | "eval_dir": "./checkpoints/car_auto_T3_trainval_eval", 8 | "eval_every_second": 60, 9 | "gpu_memusage": -1, 10 | "max_step": 1400298, 11 | "train_dir": "./checkpoints/car_auto_T3_trainval", 12 | "visualization": false 13 | } 14 | -------------------------------------------------------------------------------- /configs/car_auto_T3_trainval_train_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_GPU": 2, 3 | "NUM_TEST_SAMPLE": -1, 4 | "batch_size": 4, 5 | "capacity": 1, 6 | 
"checkpoint_path": "model", 7 | "config_path": "config", 8 | "data_aug_configs": [ 9 | { 10 | "method_kwargs": { 11 | "expend_factor": [ 12 | 1.0, 13 | 1.0, 14 | 1.0 15 | ], 16 | "method_name": "normal", 17 | "yaw_std": 0.39269908169872414 18 | }, 19 | "method_name": "random_rotation_all" 20 | }, 21 | { 22 | "method_kwargs": { 23 | "flip_prob": 0.5 24 | }, 25 | "method_name": "random_flip_all" 26 | }, 27 | { 28 | "method_kwargs": { 29 | "appr_factor": 10, 30 | "expend_factor": [ 31 | 1.1, 32 | 1.1, 33 | 1.1 34 | ], 35 | "max_overlap_num_allowed": 100, 36 | "max_overlap_rate": 0.01, 37 | "max_trails": 100, 38 | "method_name": "normal", 39 | "xyz_std": [ 40 | 3, 41 | 0, 42 | 3 43 | ] 44 | }, 45 | "method_name": "random_box_shift" 46 | } 47 | ], 48 | "decay_factor": 0.1, 49 | "decay_step": 400000, 50 | "gpu_memusage": -1, 51 | "initial_lr": 0.125, 52 | "load_dataset_every_N_time": 0, 53 | "load_dataset_to_mem": true, 54 | "max_epoch": 838, 55 | "max_steps": 1400000, 56 | "num_load_dataset_workers": 16, 57 | "optimizer": "sgd", 58 | "optimizer_kwargs": {}, 59 | "save_every_epoch": 20, 60 | "train_dataset": "trainval_car.txt", 61 | "train_dir": "./checkpoints/car_auto_T3_trainval", 62 | "unify_copies": true, 63 | "visualization": false 64 | } 65 | -------------------------------------------------------------------------------- /configs/car_fixed_T3_train_config: -------------------------------------------------------------------------------- 1 | { 2 | "box_encoding_method": "classaware_all_class_box_encoding", 3 | "downsample_by_voxel_size": null, 4 | "eval_is_training": true, 5 | "graph_gen_kwargs": { 6 | "add_rnd3d": true, 7 | "base_voxel_size": 0.8, 8 | "downsample_method": "random", 9 | "level_configs": [ 10 | { 11 | "graph_gen_kwargs": { 12 | "num_neighbors": -1, 13 | "radius": 1.0 14 | }, 15 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 16 | "graph_level": 0, 17 | "graph_scale": 1 18 | }, 19 | { 20 | "graph_gen_kwargs": { 21 | "num_neighbors": 256, 22 | "radius": 4.0 23 | }, 24 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 25 | "graph_level": 1, 26 | "graph_scale": 1 27 | } 28 | ] 29 | }, 30 | "graph_gen_method": "multi_level_local_graph_v3", 31 | "input_features": "i", 32 | "label_method": "Car", 33 | "loss": { 34 | "cls_loss_type": "softmax", 35 | "cls_loss_weight": 0.1, 36 | "loc_loss_weight": 10.0 37 | }, 38 | "model_kwargs": { 39 | "layer_configs": [ 40 | { 41 | "graph_level": 0, 42 | "kwargs": { 43 | "output_MLP_activation_type": "ReLU", 44 | "output_MLP_depth_list": [ 45 | 300, 46 | 300 47 | ], 48 | "output_MLP_normalization_type": "NONE", 49 | "point_MLP_activation_type": "ReLU", 50 | "point_MLP_depth_list": [ 51 | 32, 52 | 64, 53 | 128, 54 | 300 55 | ], 56 | "point_MLP_normalization_type": "NONE" 57 | }, 58 | "scope": "layer1", 59 | "type": "scatter_max_point_set_pooling" 60 | }, 61 | { 62 | "graph_level": 1, 63 | "kwargs": { 64 | "auto_offset": false, 65 | "auto_offset_MLP_depth_list": [ 66 | 64, 67 | 3 68 | ], 69 | "auto_offset_MLP_feature_activation_type": "ReLU", 70 | "auto_offset_MLP_normalization_type": "NONE", 71 | "edge_MLP_activation_type": "ReLU", 72 | "edge_MLP_depth_list": [ 73 | 300, 74 | 300 75 | ], 76 | "edge_MLP_normalization_type": "NONE", 77 | "update_MLP_activation_type": "ReLU", 78 | "update_MLP_depth_list": [ 79 | 300, 80 | 300 81 | ], 82 | "update_MLP_normalization_type": "NONE" 83 | }, 84 | "scope": "layer2", 85 | "type": "scatter_max_graph_auto_center_net" 86 | }, 87 | { 88 | "graph_level": 1, 89 | "kwargs": { 90 | "auto_offset": false, 
91 | "auto_offset_MLP_depth_list": [ 92 | 64, 93 | 3 94 | ], 95 | "auto_offset_MLP_feature_activation_type": "ReLU", 96 | "auto_offset_MLP_normalization_type": "NONE", 97 | "edge_MLP_activation_type": "ReLU", 98 | "edge_MLP_depth_list": [ 99 | 300, 100 | 300 101 | ], 102 | "edge_MLP_normalization_type": "NONE", 103 | "update_MLP_activation_type": "ReLU", 104 | "update_MLP_depth_list": [ 105 | 300, 106 | 300 107 | ], 108 | "update_MLP_normalization_type": "NONE" 109 | }, 110 | "scope": "layer3", 111 | "type": "scatter_max_graph_auto_center_net" 112 | }, 113 | { 114 | "graph_level": 1, 115 | "kwargs": { 116 | "auto_offset": false, 117 | "auto_offset_MLP_depth_list": [ 118 | 64, 119 | 3 120 | ], 121 | "auto_offset_MLP_feature_activation_type": "ReLU", 122 | "auto_offset_MLP_normalization_type": "NONE", 123 | "edge_MLP_activation_type": "ReLU", 124 | "edge_MLP_depth_list": [ 125 | 300, 126 | 300 127 | ], 128 | "edge_MLP_normalization_type": "NONE", 129 | "update_MLP_activation_type": "ReLU", 130 | "update_MLP_depth_list": [ 131 | 300, 132 | 300 133 | ], 134 | "update_MLP_normalization_type": "NONE" 135 | }, 136 | "scope": "layer4", 137 | "type": "scatter_max_graph_auto_center_net" 138 | }, 139 | { 140 | "graph_level": 1, 141 | "kwargs": { 142 | "activation_type": "ReLU", 143 | "normalization_type": "NONE" 144 | }, 145 | "scope": "output", 146 | "type": "classaware_predictor" 147 | } 148 | ], 149 | "regularizer_kwargs": { 150 | "scale": 5e-07 151 | }, 152 | "regularizer_type": "l1" 153 | }, 154 | "model_name": "multi_layer_fast_local_graph_model_v2", 155 | "nms_overlapped_thres": 0.01, 156 | "num_classes": 4, 157 | "runtime_graph_gen_kwargs": { 158 | "add_rnd3d": false, 159 | "base_voxel_size": 0.8, 160 | "level_configs": [ 161 | { 162 | "graph_gen_kwargs": { 163 | "num_neighbors": -1, 164 | "radius": 1.0 165 | }, 166 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 167 | "graph_level": 0, 168 | "graph_scale": 0.5 169 | }, 170 | { 171 | "graph_gen_kwargs": { 172 | "num_neighbors": -1, 173 | "radius": 4.0 174 | }, 175 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 176 | "graph_level": 1, 177 | "graph_scale": 0.5 178 | } 179 | ] 180 | } 181 | } -------------------------------------------------------------------------------- /configs/car_fixed_T3_train_eval_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_TEST_SAMPLE": -1, 3 | "checkpoint_path": "model", 4 | "config_path": "config", 5 | "data_aug_configs": [], 6 | "eval_dataset": "val.txt", 7 | "eval_dir": "./checkpoints/car_fixed_T3_train_eval", 8 | "eval_every_second": 60, 9 | "gpu_memusage": -1, 10 | "max_step": 1400170, 11 | "train_dir": "./checkpoints/car_fixed_T3_train", 12 | "visualization": false 13 | } 14 | -------------------------------------------------------------------------------- /configs/car_fixed_T3_train_train_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_GPU": 2, 3 | "NUM_TEST_SAMPLE": -1, 4 | "batch_size": 4, 5 | "capacity": 1, 6 | "checkpoint_path": "model", 7 | "config_path": "config", 8 | "data_aug_configs": [ 9 | { 10 | "method_kwargs": { 11 | "expend_factor": [ 12 | 1.0, 13 | 1.0, 14 | 1.0 15 | ], 16 | "method_name": "normal", 17 | "yaw_std": 0.39269908169872414 18 | }, 19 | "method_name": "random_rotation_all" 20 | }, 21 | { 22 | "method_kwargs": { 23 | "flip_prob": 0.5 24 | }, 25 | "method_name": "random_flip_all" 26 | }, 27 | { 28 | "method_kwargs": { 29 | "appr_factor": 10, 30 | 
"expend_factor": [ 31 | 1.1, 32 | 1.1, 33 | 1.1 34 | ], 35 | "max_overlap_num_allowed": 100, 36 | "max_overlap_rate": 0.01, 37 | "max_trails": 100, 38 | "method_name": "normal", 39 | "xyz_std": [ 40 | 3, 41 | 0, 42 | 3 43 | ] 44 | }, 45 | "method_name": "random_box_shift" 46 | } 47 | ], 48 | "decay_factor": 0.1, 49 | "decay_step": 400000, 50 | "gpu_memusage": -1, 51 | "initial_lr": 0.125, 52 | "load_dataset_every_N_time": 0, 53 | "load_dataset_to_mem": true, 54 | "max_epoch": 1718, 55 | "max_steps": 1400000, 56 | "num_load_dataset_workers": 16, 57 | "optimizer": "sgd", 58 | "optimizer_kwargs": {}, 59 | "save_every_epoch": 20, 60 | "train_dataset": "train_car.txt", 61 | "train_dir": "./checkpoints/car_fixed_T3_train", 62 | "unify_copies": true, 63 | "visualization": false 64 | } 65 | -------------------------------------------------------------------------------- /configs/ped_cyl_auto_T3_trainval_config: -------------------------------------------------------------------------------- 1 | { 2 | "box_encoding_method": "classaware_all_class_box_encoding", 3 | "downsample_by_voxel_size": null, 4 | "eval_is_training": true, 5 | "graph_gen_kwargs": { 6 | "add_rnd3d": true, 7 | "base_voxel_size": 0.8, 8 | "downsample_method": "random", 9 | "level_configs": [ 10 | { 11 | "graph_gen_kwargs": { 12 | "num_neighbors": -1, 13 | "radius": 0.4 14 | }, 15 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 16 | "graph_level": 0, 17 | "graph_scale": 0.5 18 | }, 19 | { 20 | "graph_gen_kwargs": { 21 | "num_neighbors": 256, 22 | "radius": 1.6 23 | }, 24 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 25 | "graph_level": 1, 26 | "graph_scale": 0.5 27 | } 28 | ] 29 | }, 30 | "graph_gen_method": "multi_level_local_graph_v3", 31 | "input_features": "i", 32 | "label_method": "Pedestrian_and_Cyclist", 33 | "loss": { 34 | "cls_loss_type": "softmax", 35 | "cls_loss_weight": 0.1, 36 | "loc_loss_weight": 10.0 37 | }, 38 | "model_kwargs": { 39 | "layer_configs": [ 40 | { 41 | "graph_level": 0, 42 | "kwargs": { 43 | "output_MLP_activation_type": "ReLU", 44 | "output_MLP_depth_list": [ 45 | 256, 46 | 256 47 | ], 48 | "output_MLP_normalization_type": "NONE", 49 | "point_MLP_activation_type": "ReLU", 50 | "point_MLP_depth_list": [ 51 | 32, 52 | 64, 53 | 128, 54 | 256, 55 | 512 56 | ], 57 | "point_MLP_normalization_type": "NONE" 58 | }, 59 | "scope": "layer1", 60 | "type": "scatter_max_point_set_pooling" 61 | }, 62 | { 63 | "graph_level": 1, 64 | "kwargs": { 65 | "auto_offset": true, 66 | "auto_offset_MLP_depth_list": [ 67 | 64, 68 | 3 69 | ], 70 | "auto_offset_MLP_feature_activation_type": "ReLU", 71 | "auto_offset_MLP_normalization_type": "NONE", 72 | "edge_MLP_activation_type": "ReLU", 73 | "edge_MLP_depth_list": [ 74 | 256, 75 | 256 76 | ], 77 | "edge_MLP_normalization_type": "NONE", 78 | "update_MLP_activation_type": "ReLU", 79 | "update_MLP_depth_list": [ 80 | 256, 81 | 256 82 | ], 83 | "update_MLP_normalization_type": "NONE" 84 | }, 85 | "scope": "layer2", 86 | "type": "scatter_max_graph_auto_center_net" 87 | }, 88 | { 89 | "graph_level": 1, 90 | "kwargs": { 91 | "auto_offset": true, 92 | "auto_offset_MLP_depth_list": [ 93 | 64, 94 | 3 95 | ], 96 | "auto_offset_MLP_feature_activation_type": "ReLU", 97 | "auto_offset_MLP_normalization_type": "NONE", 98 | "edge_MLP_activation_type": "ReLU", 99 | "edge_MLP_depth_list": [ 100 | 256, 101 | 256 102 | ], 103 | "edge_MLP_normalization_type": "NONE", 104 | "update_MLP_activation_type": "ReLU", 105 | "update_MLP_depth_list": [ 106 | 256, 107 | 256 108 | ], 109 | 
"update_MLP_normalization_type": "NONE" 110 | }, 111 | "scope": "layer3", 112 | "type": "scatter_max_graph_auto_center_net" 113 | }, 114 | { 115 | "graph_level": 1, 116 | "kwargs": { 117 | "auto_offset": true, 118 | "auto_offset_MLP_depth_list": [ 119 | 64, 120 | 3 121 | ], 122 | "auto_offset_MLP_feature_activation_type": "ReLU", 123 | "auto_offset_MLP_normalization_type": "NONE", 124 | "edge_MLP_activation_type": "ReLU", 125 | "edge_MLP_depth_list": [ 126 | 256, 127 | 256 128 | ], 129 | "edge_MLP_normalization_type": "NONE", 130 | "update_MLP_activation_type": "ReLU", 131 | "update_MLP_depth_list": [ 132 | 256, 133 | 256 134 | ], 135 | "update_MLP_normalization_type": "NONE" 136 | }, 137 | "scope": "layer4", 138 | "type": "scatter_max_graph_auto_center_net" 139 | }, 140 | { 141 | "graph_level": 1, 142 | "kwargs": { 143 | "activation_type": "ReLU", 144 | "normalization_type": "NONE" 145 | }, 146 | "scope": "output", 147 | "type": "classaware_predictor" 148 | } 149 | ], 150 | "regularizer_kwargs": { 151 | "scale": 1e-06 152 | }, 153 | "regularizer_type": "l1" 154 | }, 155 | "model_name": "multi_layer_fast_local_graph_model_v2", 156 | "nms_overlapped_thres": 0.2, 157 | "num_classes": 6, 158 | "runtime_graph_gen_kwargs": { 159 | "add_rnd3d": false, 160 | "base_voxel_size": 0.8, 161 | "level_configs": [ 162 | { 163 | "graph_gen_kwargs": { 164 | "num_neighbors": -1, 165 | "radius": 0.4 166 | }, 167 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 168 | "graph_level": 0, 169 | "graph_scale": 0.25 170 | }, 171 | { 172 | "graph_gen_kwargs": { 173 | "num_neighbors": -1, 174 | "radius": 1.6 175 | }, 176 | "graph_gen_method": "disjointed_rnn_local_graph_v3", 177 | "graph_level": 1, 178 | "graph_scale": 0.25 179 | } 180 | ] 181 | } 182 | } -------------------------------------------------------------------------------- /configs/ped_cyl_auto_T3_trainval_eval_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_TEST_SAMPLE": -1, 3 | "checkpoint_path": "model", 4 | "config_path": "config", 5 | "data_aug_configs": [], 6 | "eval_dataset": "val.txt", 7 | "eval_dir": "./checkpoints/ped_cyl_auto_T3_trainval_eval", 8 | "eval_every_second": 60, 9 | "gpu_memusage": -1, 10 | "max_step": 1000000, 11 | "train_dir": "./checkpoints/ped_cyl_auto_T3_trainval", 12 | "visualization": false 13 | } 14 | -------------------------------------------------------------------------------- /configs/ped_cyl_auto_T3_trainval_train_config: -------------------------------------------------------------------------------- 1 | { 2 | "NUM_GPU": 2, 3 | "NUM_TEST_SAMPLE": -1, 4 | "batch_size": 4, 5 | "capacity": 1, 6 | "checkpoint_path": "model", 7 | "config_path": "config", 8 | "data_aug_configs": [ 9 | { 10 | "method_kwargs": { 11 | "expend_factor": [ 12 | 1.0, 13 | 1.0, 14 | 1.0 15 | ], 16 | "method_name": "normal", 17 | "yaw_std": 0.39269908169872414 18 | }, 19 | "method_name": "random_rotation_all" 20 | }, 21 | { 22 | "method_kwargs": { 23 | "flip_prob": 0.5 24 | }, 25 | "method_name": "random_flip_all" 26 | }, 27 | { 28 | "method_kwargs": { 29 | "appr_factor": 10, 30 | "expend_factor": [ 31 | 1.1, 32 | 1.1, 33 | 1.1 34 | ], 35 | "max_overlap_num_allowed": 100, 36 | "max_overlap_rate": 0.01, 37 | "max_trails": 100, 38 | "method_name": "normal", 39 | "xyz_std": [ 40 | 3, 41 | 0, 42 | 3 43 | ] 44 | }, 45 | "method_name": "random_box_shift" 46 | } 47 | ], 48 | "decay_factor": 0.25, 49 | "decay_step": 400000, 50 | "gpu_memusage": -1, 51 | "initial_lr": 0.32, 52 | 
"load_dataset_every_N_time": 0, 53 | "load_dataset_to_mem": true, 54 | "max_epoch": 1611, 55 | "max_steps": 1000000, 56 | "num_load_dataset_workers": 16, 57 | "optimizer": "sgd", 58 | "optimizer_kwargs": {}, 59 | "save_every_epoch": 20, 60 | "train_dataset": "trainval_pedestrian_cyclist.txt", 61 | "train_dir": "./checkpoints/ped_cyl_auto_T3_trainval", 62 | "unify_copies": true, 63 | "visualization": false 64 | } 65 | -------------------------------------------------------------------------------- /eval.py: -------------------------------------------------------------------------------- 1 | """This file defines the evaluation process of Point-GNN object detection.""" 2 | 3 | import os 4 | import time 5 | import argparse 6 | 7 | import numpy as np 8 | import tensorflow as tf 9 | 10 | from dataset.kitti_dataset import KittiDataset 11 | from models.graph_gen import get_graph_generate_fn 12 | from models.models import get_model 13 | from models.box_encoding import get_box_decoding_fn, get_box_encoding_fn, \ 14 | get_encoding_len 15 | from models import preprocess 16 | from util.config_util import load_config, load_train_config 17 | from util.summary_util import write_summary_scale 18 | 19 | parser = argparse.ArgumentParser(description='Repeated evaluation of PointGNN.') 20 | parser.add_argument('eval_config_path', type=str, 21 | help='Path to train_config') 22 | parser.add_argument('--dataset_root_dir', type=str, default='../dataset/kitti/', 23 | help='Path to KITTI dataset. Default="../dataset/kitti/"') 24 | parser.add_argument('--dataset_split_file', type=str, 25 | default='', 26 | help='Path to KITTI dataset split file.' 27 | 'Default="DATASET_ROOT_DIR/3DOP_splits' 28 | '/eval_config["eval_dataset"]"') 29 | args = parser.parse_args() 30 | eval_config = load_train_config(args.eval_config_path) 31 | DATASET_DIR = args.dataset_root_dir 32 | if args.dataset_split_file == '': 33 | DATASET_SPLIT_FILE = os.path.join(DATASET_DIR, 34 | './3DOP_splits/'+eval_config['eval_dataset']) 35 | else: 36 | DATASET_SPLIT_FILE = args.dataset_split_file 37 | 38 | config_path = os.path.join(eval_config['train_dir'], eval_config['config_path']) 39 | while not os.path.isfile(config_path): 40 | print('No config file found in %s, waiting' % config_path) 41 | time.sleep(eval_config['eval_every_second']) 42 | config = load_config(config_path) 43 | if 'eval' in config: 44 | config = config['eval'] 45 | dataset = KittiDataset( 46 | os.path.join(DATASET_DIR, 'image/training/image_2'), 47 | os.path.join(DATASET_DIR, 'velodyne/training/velodyne/'), 48 | os.path.join(DATASET_DIR, 'calib/training/calib/'), 49 | os.path.join(DATASET_DIR, 'labels/training/label_2'), 50 | DATASET_SPLIT_FILE, 51 | num_classes=config['num_classes']) 52 | NUM_CLASSES = dataset.num_classes 53 | 54 | if 'NUM_TEST_SAMPLE' not in eval_config: 55 | NUM_TEST_SAMPLE = dataset.num_files 56 | else: 57 | if eval_config['NUM_TEST_SAMPLE'] < 0: 58 | NUM_TEST_SAMPLE = dataset.num_files 59 | else: 60 | NUM_TEST_SAMPLE = eval_config['NUM_TEST_SAMPLE'] 61 | 62 | print(NUM_TEST_SAMPLE) 63 | BOX_ENCODING_LEN = get_encoding_len(config['box_encoding_method']) 64 | box_encoding_fn = get_box_encoding_fn(config['box_encoding_method']) 65 | box_decoding_fn = get_box_decoding_fn(config['box_encoding_method']) 66 | 67 | aug_fn = preprocess.get_data_aug(eval_config['data_aug_configs']) 68 | def fetch_data(frame_idx): 69 | cam_rgb_points = dataset.get_cam_points_in_image_with_rgb(frame_idx, 70 | config['downsample_by_voxel_size']) 71 | box_label_list = 
dataset.get_label(frame_idx) 72 | cam_rgb_points, box_label_list = aug_fn(cam_rgb_points, box_label_list) 73 | graph_generate_fn= get_graph_generate_fn(config['graph_gen_method']) 74 | (vertex_coord_list, keypoint_indices_list, edges_list) = graph_generate_fn( 75 | cam_rgb_points.xyz, **config['graph_gen_kwargs']) 76 | if config['input_features'] == 'irgb': 77 | input_v = cam_rgb_points.attr 78 | elif config['input_features'] == '0rgb': 79 | input_v = np.hstack([np.zeros((cam_rgb_points.attr.shape[0], 1)), 80 | cam_rgb_points.attr[:, 1:]]) 81 | elif config['input_features'] == '0000': 82 | input_v = np.zeros_like(cam_rgb_points.attr) 83 | elif config['input_features'] == 'i000': 84 | input_v = np.hstack([cam_rgb_points.attr[:, [0]], 85 | np.zeros((cam_rgb_points.attr.shape[0], 3))]) 86 | elif config['input_features'] == 'i': 87 | input_v = cam_rgb_points.attr[:, [0]] 88 | elif config['input_features'] == '0': 89 | input_v = np.zeros((cam_rgb_points.attr.shape[0], 1)) 90 | last_layer_graph_level = config['model_kwargs'][ 91 | 'layer_configs'][-1]['graph_level'] 92 | last_layer_points_xyz = vertex_coord_list[last_layer_graph_level+1] 93 | if config['label_method'] == 'yaw': 94 | (cls_labels, boxes_3d, valid_boxes, label_map) =\ 95 | dataset.assign_classaware_label_to_points(box_label_list, 96 | last_layer_points_xyz, expend_factor=(1.0, 1.0, 1.0)) 97 | if config['label_method'] == 'Car': 98 | cls_labels, boxes_3d, valid_boxes, label_map =\ 99 | dataset.assign_classaware_car_label_to_points(box_label_list, 100 | last_layer_points_xyz, expend_factor=(1.0, 1.0, 1.0)) 101 | if config['label_method'] == 'Pedestrian_and_Cyclist': 102 | cls_labels, boxes_3d, valid_boxes, label_map =\ 103 | dataset.assign_classaware_ped_and_cyc_label_to_points( 104 | box_label_list, 105 | last_layer_points_xyz, expend_factor=(1.0, 1.0, 1.0)) 106 | encoded_boxes = box_encoding_fn( 107 | cls_labels, last_layer_points_xyz, boxes_3d, label_map) 108 | # reducing memory usage by casting to 32bits 109 | input_v = input_v.astype(np.float32) 110 | vertex_coord_list = [p.astype(np.float32) for p in vertex_coord_list] 111 | keypoint_indices_list = [e.astype(np.int32) for e in keypoint_indices_list] 112 | edges_list = [e.astype(np.int32) for e in edges_list] 113 | cls_labels = cls_labels.astype(np.int32) 114 | encoded_boxes = encoded_boxes.astype(np.float32) 115 | valid_boxes = valid_boxes.astype(np.float32) 116 | return(input_v, vertex_coord_list, keypoint_indices_list, edges_list, 117 | cls_labels, encoded_boxes, valid_boxes) 118 | 119 | # model ======================================================================= 120 | if config['input_features'] == 'irgb': 121 | t_initial_vertex_features = tf.placeholder( 122 | dtype=tf.float32, shape=[None, 4]) 123 | elif config['input_features'] == 'rgb': 124 | t_initial_vertex_features = tf.placeholder( 125 | dtype=tf.float32, shape=[None, 3]) 126 | elif config['input_features'] == '0000': 127 | t_initial_vertex_features = tf.placeholder( 128 | dtype=tf.float32, shape=[None, 4]) 129 | elif config['input_features'] == 'i000': 130 | t_initial_vertex_features = tf.placeholder( 131 | dtype=tf.float32, shape=[None, 4]) 132 | elif config['input_features'] == 'i': 133 | t_initial_vertex_features = tf.placeholder( 134 | dtype=tf.float32, shape=[None, 1]) 135 | elif config['input_features'] == '0': 136 | t_initial_vertex_features = tf.placeholder( 137 | dtype=tf.float32, shape=[None, 1]) 138 | 139 | t_vertex_coord_list = [tf.placeholder(dtype=tf.float32, shape=[None, 3])] 140 | for _ in 
range(len(config['graph_gen_kwargs']['level_configs'])): 141 | t_vertex_coord_list.append( 142 | tf.placeholder(dtype=tf.float32, shape=[None, 3])) 143 | 144 | t_edges_list = [] 145 | for _ in range(len(config['graph_gen_kwargs']['level_configs'])): 146 | t_edges_list.append( 147 | tf.placeholder(dtype=tf.int32, shape=[None, 2])) 148 | 149 | t_keypoint_indices_list = [] 150 | for _ in range(len(config['graph_gen_kwargs']['level_configs'])): 151 | t_keypoint_indices_list.append( 152 | tf.placeholder(dtype=tf.int32, shape=[None, 1])) 153 | 154 | t_class_labels = tf.placeholder(dtype=tf.int32, shape=[None, 1]) 155 | t_encoded_gt_boxes = tf.placeholder(dtype=tf.float32, 156 | shape=[None, 1, BOX_ENCODING_LEN]) 157 | t_valid_gt_boxes = tf.placeholder(dtype=tf.float32, shape=[None, 1, 1]) 158 | 159 | t_is_training = tf.placeholder(dtype=tf.bool, shape=[]) 160 | 161 | model = get_model(config['model_name'])(num_classes=NUM_CLASSES, 162 | box_encoding_len=BOX_ENCODING_LEN, mode='eval', **config['model_kwargs']) 163 | t_logits, t_pred_box = model.predict( 164 | t_initial_vertex_features, t_vertex_coord_list, t_keypoint_indices_list, 165 | t_edges_list, 166 | t_is_training) 167 | t_probs = model.postprocess(t_logits) 168 | t_predictions = tf.argmax(t_probs, axis=1, output_type=tf.int32) 169 | t_loss_dict = model.loss(t_logits, t_class_labels, t_pred_box, 170 | t_encoded_gt_boxes, t_valid_gt_boxes, **config['loss']) 171 | t_cls_loss = t_loss_dict['cls_loss'] 172 | t_loc_loss = t_loss_dict['loc_loss'] 173 | t_reg_loss = t_loss_dict['reg_loss'] 174 | t_classwise_loc_loss = t_loss_dict['classwise_loc_loss'] 175 | t_total_loss = t_cls_loss + t_loc_loss + t_reg_loss 176 | 177 | t_classwise_loc_loss_update_ops = {} 178 | for class_idx in range(NUM_CLASSES): 179 | for bi in range(BOX_ENCODING_LEN): 180 | classwise_loc_loss_ind =t_classwise_loc_loss[class_idx][bi] 181 | t_mean_loss, t_mean_loss_op = tf.metrics.mean( 182 | classwise_loc_loss_ind, 183 | name=('loc_loss_cls_%d_box_%d'%(class_idx, bi))) 184 | t_classwise_loc_loss_update_ops[ 185 | ('loc_loss_cls_%d_box_%d'%(class_idx, bi))] = t_mean_loss_op 186 | classwise_loc_loss =t_classwise_loc_loss[class_idx] 187 | t_mean_loss, t_mean_loss_op = tf.metrics.mean( 188 | classwise_loc_loss, 189 | name=('loc_loss_cls_%d'%class_idx)) 190 | t_classwise_loc_loss_update_ops[ 191 | ('loc_loss_cls_%d'%class_idx)] = t_mean_loss_op 192 | 193 | # metrics 194 | t_recall_update_ops = {} 195 | for class_idx in range(NUM_CLASSES): 196 | t_recall, t_recall_update_op = tf.metrics.recall( 197 | tf.equal(t_class_labels, tf.constant(class_idx, tf.int32)), 198 | tf.equal(t_predictions, tf.constant(class_idx, tf.int32)), 199 | name=('recall_%d'%class_idx)) 200 | t_recall_update_ops[('recall_%d'%class_idx)] = t_recall_update_op 201 | 202 | t_precision_update_ops = {} 203 | for class_idx in range(NUM_CLASSES): 204 | t_precision, t_precision_update_op = tf.metrics.precision( 205 | tf.equal(t_class_labels, tf.constant(class_idx, tf.int32)), 206 | tf.equal(t_predictions, tf.constant(class_idx, tf.int32)), 207 | name=('precision_%d'%class_idx)) 208 | t_precision_update_ops[('precision_%d'%class_idx)] = t_precision_update_op 209 | 210 | t_mAP_update_ops = {} 211 | for class_idx in range(NUM_CLASSES): 212 | t_mAP, t_mAP_update_op = tf.metrics.auc( 213 | tf.equal(t_class_labels, tf.constant(class_idx, tf.int32)), 214 | t_probs[:, class_idx], 215 | num_thresholds=200, 216 | curve='PR', 217 | name=('mAP_%d'%class_idx), 218 | summation_method='careful_interpolation') 219 | 
t_mAP_update_ops[('mAP_%d'%class_idx)] = t_mAP_update_op
220 | 
221 | t_mean_cls_loss, t_mean_cls_loss_op = tf.metrics.mean(
222 |     t_cls_loss,
223 |     name='mean_cls_loss')
224 | t_mean_loc_loss, t_mean_loc_loss_op = tf.metrics.mean(
225 |     t_loc_loss,
226 |     name='mean_loc_loss')
227 | t_mean_reg_loss, t_mean_reg_loss_op = tf.metrics.mean(
228 |     t_reg_loss,
229 |     name='mean_reg_loss')
230 | t_mean_total_loss, t_mean_total_loss_op = tf.metrics.mean(
231 |     t_total_loss,
232 |     name='mean_total_loss')
233 | 
234 | metrics_update_ops = {
235 |     'cls_loss': t_mean_cls_loss_op,
236 |     'loc_loss': t_mean_loc_loss_op,
237 |     'reg_loss': t_mean_reg_loss_op,
238 |     'total_loss': t_mean_total_loss_op,}
239 | metrics_update_ops.update(t_recall_update_ops)
240 | metrics_update_ops.update(t_precision_update_ops)
241 | metrics_update_ops.update(t_mAP_update_ops)
242 | metrics_update_ops.update(t_classwise_loc_loss_update_ops)
243 | 
244 | # fetches (no optimizer is built in evaluation) =============================
245 | global_step = tf.Variable(0, dtype=tf.int32, trainable=False)
246 | fetches = {
247 |     'step': global_step,
248 |     'predictions': t_predictions,
249 |     'pred_box': t_pred_box
250 | 
251 |     }
252 | fetches.update(metrics_update_ops)
253 | 
254 | # preprocessing data ========================================================
255 | class DataProvider(object):
256 |     """This class provides input data for training.
257 |     It has an option to load the dataset into memory so that preprocessing
258 |     is not repeated every time.
259 |     Note: if there is randomness inside graph creation, samples should be
260 |     reloaded for the randomness to take effect.
261 |     """
262 |     def __init__(self, fetch_data, load_dataset_to_mem=True,
263 |         load_dataset_every_N_time=1, capacity=1):
264 |         self._fetch_data = fetch_data
265 |         self._loaded_data_dic = {}
266 |         self._loaded_data_ctr_dic = {}
267 |         self._load_dataset_to_mem = load_dataset_to_mem
268 |         self._load_every_N_time = load_dataset_every_N_time
269 |         self._capacity = capacity
270 |     def provide(self, frame_idx):
271 |         extend_frame_idx = frame_idx+np.random.choice(
272 |             self._capacity)*NUM_TEST_SAMPLE
273 |         if self._load_dataset_to_mem:
274 |             if extend_frame_idx in self._loaded_data_ctr_dic:
275 |                 ctr = self._loaded_data_ctr_dic[extend_frame_idx]
276 |                 if ctr >= self._load_every_N_time:
277 |                     del self._loaded_data_ctr_dic[extend_frame_idx]
278 |                     del self._loaded_data_dic[extend_frame_idx]
279 |             if extend_frame_idx not in self._loaded_data_dic:  # cache is keyed by extend_frame_idx
280 |                 self._loaded_data_dic[extend_frame_idx] = self._fetch_data(
281 |                     frame_idx)
282 |                 self._loaded_data_ctr_dic[extend_frame_idx] = 0
283 |             self._loaded_data_ctr_dic[extend_frame_idx] += 1
284 |             return self._loaded_data_dic[extend_frame_idx]
285 |         else:
286 |             return self._fetch_data(frame_idx)
287 | 
288 | data_provider = DataProvider(fetch_data, load_dataset_to_mem=False)
289 | saver = tf.train.Saver()
290 | graph = tf.get_default_graph()
291 | if eval_config['gpu_memusage'] < 0:
292 |     gpu_options = tf.GPUOptions(allow_growth=True)
293 | else:
294 |     gpu_options = tf.GPUOptions(
295 |         per_process_gpu_memory_fraction=eval_config['gpu_memusage'])
296 | 
297 | def eval_once(graph, gpu_options, saver, checkpoint_path):
298 |     """Evaluate the model once.
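
    Args:
        graph: the tf.Graph holding the model tensors built above.
        gpu_options: a tf.GPUOptions object controlling GPU memory usage.
        saver: a tf.train.Saver used to restore variables from the checkpoint.
        checkpoint_path: path of the checkpoint to evaluate.

    Returns the global step stored in the checkpoint. A minimal usage sketch
    (the checkpoint lookup below is only illustrative; eval_repeat() drives
    this function in practice):

        ckpt = tf.train.latest_checkpoint(eval_config['train_dir'])
        if ckpt is not None:
            eval_once(graph, gpu_options, saver, ckpt)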
""" 299 | with tf.Session(graph=graph, 300 | config=tf.ConfigProto(gpu_options=gpu_options)) as sess: 301 | sess.run(tf.variables_initializer(tf.global_variables())) 302 | sess.run(tf.variables_initializer(tf.local_variables())) 303 | print('Restore from checkpoint %s' % checkpoint_path) 304 | saver.restore(sess, checkpoint_path) 305 | previous_step = sess.run(global_step) 306 | print('Global step = %d' % previous_step) 307 | start_time = time.time() 308 | for frame_idx in range(NUM_TEST_SAMPLE): 309 | (input_v, vertex_coord_list, keypoint_indices_list, edges_list, 310 | cls_labels, encoded_boxes, valid_boxes)\ 311 | = data_provider.provide(frame_idx) 312 | feed_dict = { 313 | t_initial_vertex_features: input_v, 314 | t_class_labels: cls_labels, 315 | t_encoded_gt_boxes: encoded_boxes, 316 | t_valid_gt_boxes: valid_boxes, 317 | t_is_training: config['eval_is_training'], 318 | } 319 | feed_dict.update(dict(zip(t_edges_list, edges_list))) 320 | feed_dict.update( 321 | dict(zip(t_keypoint_indices_list, keypoint_indices_list))) 322 | feed_dict.update(dict(zip(t_vertex_coord_list, vertex_coord_list))) 323 | results = sess.run(fetches, feed_dict=feed_dict) 324 | 325 | if NUM_TEST_SAMPLE >= 10: 326 | if (frame_idx + 1) % (NUM_TEST_SAMPLE // 10) == 0: 327 | print('@frame %d' % frame_idx) 328 | print('cls:%f, loc:%f, reg:%f, loss: %f' 329 | % (results['cls_loss'], results['loc_loss'], 330 | results['reg_loss'], results['total_loss'])) 331 | for class_idx in range(NUM_CLASSES): 332 | print('Class_%d: recall=%f, prec=%f, mAP=%f, loc=%f' 333 | % (class_idx, 334 | results['recall_%d'%class_idx], 335 | results['precision_%d'%class_idx], 336 | results['mAP_%d'%class_idx], 337 | results['loc_loss_cls_%d'%class_idx])) 338 | print(' '+\ 339 | 'x=%.4f y=%.4f z=%.4f l=%.4f h=%.4f w=%.4f y=%.4f' 340 | %( 341 | results['loc_loss_cls_%d_box_%d'%(class_idx, 0)], 342 | results['loc_loss_cls_%d_box_%d'%(class_idx, 1)], 343 | results['loc_loss_cls_%d_box_%d'%(class_idx, 2)], 344 | results['loc_loss_cls_%d_box_%d'%(class_idx, 3)], 345 | results['loc_loss_cls_%d_box_%d'%(class_idx, 4)], 346 | results['loc_loss_cls_%d_box_%d'%(class_idx, 5)], 347 | results['loc_loss_cls_%d_box_%d'%(class_idx, 6)]), 348 | ) 349 | print('STEP: %d, time cost: %f' 350 | % (results['step'], time.time()-start_time)) 351 | print('cls:%f, loc:%f, reg:%f, loss: %f' 352 | % (results['cls_loss'], results['loc_loss'], results['reg_loss'], 353 | results['total_loss'])) 354 | for class_idx in range(NUM_CLASSES): 355 | print('Class_%d: recall=%f, prec=%f, mAP=%f, loc=%f' 356 | % (class_idx, 357 | results['recall_%d'%class_idx], 358 | results['precision_%d'%class_idx], 359 | results['mAP_%d'%class_idx], 360 | results['loc_loss_cls_%d'%class_idx])) 361 | print(" x=%.4f y=%.4f z=%.4f l=%.4f h=%.4f w=%.4f y=%.4f" 362 | %( 363 | results['loc_loss_cls_%d_box_%d'%(class_idx, 0)], 364 | results['loc_loss_cls_%d_box_%d'%(class_idx, 1)], 365 | results['loc_loss_cls_%d_box_%d'%(class_idx, 2)], 366 | results['loc_loss_cls_%d_box_%d'%(class_idx, 3)], 367 | results['loc_loss_cls_%d_box_%d'%(class_idx, 4)], 368 | results['loc_loss_cls_%d_box_%d'%(class_idx, 5)], 369 | results['loc_loss_cls_%d_box_%d'%(class_idx, 6)]), 370 | ) 371 | # add summaries ==================================================== 372 | for key in metrics_update_ops: 373 | write_summary_scale(key, results[key], results['step'], 374 | eval_config['eval_dir']) 375 | return results['step'] 376 | 377 | def eval_repeat(): 378 | last_evaluated_model_path = None 379 | while True: 380 | previous_time = 
381 |         model_path = tf.train.latest_checkpoint(eval_config['train_dir'])
382 |         if not model_path:
383 |             print('No checkpoint found in %s, wait for %f seconds'
384 |                 % (eval_config['train_dir'], eval_config['eval_every_second']))
385 |         if last_evaluated_model_path == model_path:
386 |             print(
387 |                 'Checkpoint %s has been evaluated already, wait for %f seconds'
388 |                 % (model_path, eval_config['eval_every_second']))
389 |         else:
390 |             last_evaluated_model_path = model_path
391 |             current_step = eval_once(graph, gpu_options, saver, model_path)
392 |             if current_step >= eval_config['max_step']:
393 |                 break
394 |         time_to_next_eval = (
395 |             previous_time + eval_config['eval_every_second'] - time.time())
396 |         if time_to_next_eval > 0:
397 |             time.sleep(time_to_next_eval)
398 | 
399 | if __name__ == '__main__':
400 |     eval_repeat()
401 | 
--------------------------------------------------------------------------------
/kitty_dataset.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import torch
3 | from models import preprocess
4 | from models.box_encoding import get_box_decoding_fn, get_box_encoding_fn,\
5 |     get_encoding_len
6 | from dataset.kitti_dataset import KittiDataset
7 | from models.graph_gen import get_graph_generate_fn
from models.crop_aug import CropAugSampler  # needed by the 'crop_aug' branch below
8 | from multiprocessing import Pool, Queue, Process
9 | import os
10 | import argparse
11 | from util.config_util import save_config, save_train_config, \
12 |     load_train_config, load_config
13 | 
14 | def fetch_data(dataset, frame_idx, train_config, config):
15 |     aug_fn = preprocess.get_data_aug(train_config['data_aug_configs'])
16 |     BOX_ENCODING_LEN = get_encoding_len(config['box_encoding_method'])
17 |     box_encoding_fn = get_box_encoding_fn(config['box_encoding_method'])
18 |     box_decoding_fn = get_box_decoding_fn(config['box_encoding_method'])
19 |     graph_generate_fn = get_graph_generate_fn(config['graph_gen_method'])
20 | 
21 |     cam_rgb_points = dataset.get_cam_points_in_image_with_rgb(frame_idx,
22 |         config['downsample_by_voxel_size'])
23 | 
24 |     box_label_list = dataset.get_label(frame_idx)
25 |     if 'crop_aug' in train_config:
        # `sampler` was previously undefined here; construct it the same way
        # train.py does (this assumes the config stores the cropped-box file
        # path under 'crop_filename').
        sampler = CropAugSampler(train_config['crop_aug']['crop_filename'])
26 |         cam_rgb_points, box_label_list = sampler.crop_aug(cam_rgb_points,
27 |             box_label_list,
28 |             sample_rate=train_config['crop_aug']['sample_rate'],
29 |             parser_kwargs=train_config['crop_aug']['parser_kwargs'])
30 | 
31 |     cam_rgb_points, box_label_list = aug_fn(cam_rgb_points, box_label_list)
32 | 
33 |     (vertex_coord_list, keypoint_indices_list, edges_list) = \
34 |         graph_generate_fn(cam_rgb_points.xyz, **config['graph_gen_kwargs'])
35 |     if config['input_features'] == 'irgb':
36 |         input_v = cam_rgb_points.attr
37 |     elif config['input_features'] == '0rgb':
38 |         input_v = np.hstack([np.zeros((cam_rgb_points.attr.shape[0], 1)),
39 |             cam_rgb_points.attr[:, 1:]])
40 |     elif config['input_features'] == '0000':
41 |         input_v = np.zeros_like(cam_rgb_points.attr)
42 |     elif config['input_features'] == 'i000':
43 |         input_v = np.hstack([cam_rgb_points.attr[:, [0]],
44 |             np.zeros((cam_rgb_points.attr.shape[0], 3))])
45 |     elif config['input_features'] == 'i':
46 |         input_v = cam_rgb_points.attr[:, [0]]
47 |     elif config['input_features'] == '0':
48 |         input_v = np.zeros((cam_rgb_points.attr.shape[0], 1))
49 |     last_layer_graph_level = config['model_kwargs'][
50 |         'layer_configs'][-1]['graph_level']
51 |     last_layer_points_xyz = vertex_coord_list[last_layer_graph_level+1]
52 |     if config['label_method'] == 'yaw':
53 |         cls_labels, boxes_3d, valid_boxes, label_map = \
54 |
dataset.assign_classaware_label_to_points(box_label_list, 55 | last_layer_points_xyz, 56 | expend_factor=train_config.get('expend_factor', (1.0, 1.0, 1.0))) 57 | if config['label_method'] == 'Car': 58 | cls_labels, boxes_3d, valid_boxes, label_map = \ 59 | dataset.assign_classaware_car_label_to_points(box_label_list, 60 | last_layer_points_xyz, 61 | expend_factor=train_config.get('expend_factor', (1.0, 1.0, 1.0))) 62 | if config['label_method'] == 'Pedestrian_and_Cyclist': 63 | (cls_labels, boxes_3d, valid_boxes, label_map) =\ 64 | dataset.assign_classaware_ped_and_cyc_label_to_points( 65 | box_label_list, last_layer_points_xyz, 66 | expend_factor=train_config.get('expend_factor', (1.0, 1.0, 1.0))) 67 | encoded_boxes = box_encoding_fn(cls_labels, last_layer_points_xyz, 68 | boxes_3d, label_map) 69 | input_v = input_v.astype(np.float32) 70 | vertex_coord_list = [p.astype(np.float32) for p in vertex_coord_list] 71 | keypoint_indices_list = [e.astype(np.int32) for e in keypoint_indices_list] 72 | edges_list = [e.astype(np.int32) for e in edges_list] 73 | cls_labels = cls_labels.astype(np.int32) 74 | encoded_boxes = encoded_boxes.astype(np.float32) 75 | valid_boxes = valid_boxes.astype(np.float32) 76 | return(input_v, vertex_coord_list, keypoint_indices_list, edges_list, 77 | cls_labels, encoded_boxes, valid_boxes) 78 | 79 | 80 | def batch_data(batch_list): 81 | N_input_v, N_vertex_coord_list, N_keypoint_indices_list, N_edges_list,\ 82 | N_cls_labels, N_encoded_boxes, N_valid_boxes = zip(*batch_list) 83 | batch_size = len(batch_list) 84 | level_num = len(N_vertex_coord_list[0]) 85 | batched_keypoint_indices_list = [] 86 | batched_edges_list = [] 87 | for level_idx in range(level_num-1): 88 | centers = [] 89 | vertices = [] 90 | point_counter = 0 91 | center_counter = 0 92 | for batch_idx in range(batch_size): 93 | centers.append( 94 | N_keypoint_indices_list[batch_idx][level_idx]+point_counter) 95 | vertices.append(np.hstack( 96 | [N_edges_list[batch_idx][level_idx][:,[0]]+point_counter, 97 | N_edges_list[batch_idx][level_idx][:,[1]]+center_counter])) 98 | point_counter += N_vertex_coord_list[batch_idx][level_idx].shape[0] 99 | center_counter += \ 100 | N_keypoint_indices_list[batch_idx][level_idx].shape[0] 101 | batched_keypoint_indices_list.append(np.vstack(centers)) 102 | batched_edges_list.append(np.vstack(vertices)) 103 | batched_vertex_coord_list = [] 104 | for level_idx in range(level_num): 105 | points = [] 106 | counter = 0 107 | for batch_idx in range(batch_size): 108 | points.append(N_vertex_coord_list[batch_idx][level_idx]) 109 | batched_vertex_coord_list.append(np.vstack(points)) 110 | batched_input_v = np.vstack(N_input_v) 111 | batched_cls_labels = np.vstack(N_cls_labels) 112 | batched_encoded_boxes = np.vstack(N_encoded_boxes) 113 | batched_valid_boxes = np.vstack(N_valid_boxes) 114 | 115 | batched_input_v = torch.from_numpy(batched_input_v) 116 | batched_vertex_coord_list = [torch.from_numpy(item) for item in batched_vertex_coord_list] 117 | batched_keypoint_indices_list = [torch.from_numpy(item).long() for item in batched_keypoint_indices_list] 118 | batched_edges_list = [torch.from_numpy(item).long() for item in batched_edges_list] 119 | batched_cls_labels = torch.from_numpy(batched_cls_labels) 120 | batched_encoded_boxes = torch.from_numpy(batched_encoded_boxes) 121 | batched_valid_boxes = torch.from_numpy(batched_valid_boxes) 122 | 123 | return (batched_input_v, batched_vertex_coord_list, 124 | batched_keypoint_indices_list, batched_edges_list, batched_cls_labels, 125 | 
batched_encoded_boxes, batched_valid_boxes)
126 | 
127 | class DataProvider(object):
128 |     """This class provides input data for training.
129 |     It has an option to load the dataset into memory so that preprocessing
130 |     is not repeated every time.
131 |     Note: if there is randomness inside graph creation, the dataset should be
132 |     reloaded.
133 |     """
134 |     def __init__(self, dataset, train_config, config, async_load_rate=1.0, result_pool_limit=10000):
135 |         if 'NUM_TEST_SAMPLE' not in train_config:
136 |             self.NUM_TEST_SAMPLE = dataset.num_files
137 |         else:
138 |             if train_config['NUM_TEST_SAMPLE'] < 0:
139 |                 self.NUM_TEST_SAMPLE = dataset.num_files
140 |             else:
141 |                 self.NUM_TEST_SAMPLE = train_config['NUM_TEST_SAMPLE']
142 |         load_dataset_to_mem = train_config['load_dataset_to_mem']
143 |         load_dataset_every_N_time = train_config['load_dataset_every_N_time']
144 |         capacity = train_config['capacity']
145 |         num_workers = train_config['num_load_dataset_workers']
146 |         preload_list = list(range(self.NUM_TEST_SAMPLE))
147 | 
148 |         self.dataset = dataset
149 |         self.train_config = train_config
150 |         self.config = config
151 |         self._fetch_data = fetch_data
152 |         self._batch_data = batch_data
153 |         self._buffer = {}
154 |         self._results = {}
155 |         self._load_dataset_to_mem = load_dataset_to_mem
156 |         self._load_every_N_time = load_dataset_every_N_time
157 |         self._capacity = capacity
158 |         self._worker_pool = Pool(processes=num_workers)
159 |         self._preload_list = preload_list
160 |         self._async_load_rate = async_load_rate
161 |         self._result_pool_limit = result_pool_limit
162 |         #if len(self._preload_list) > 0:
163 |         #    self.preload(self._preload_list)
164 | 
165 |     def preload(self, frame_idx_list):
166 |         """Asynchronously load a list of frames into memory."""
167 |         for frame_idx in frame_idx_list:
168 |             result = self._worker_pool.apply_async(
169 |                 self._fetch_data, (self.dataset, frame_idx, self.train_config, self.config))
170 |             self._results[frame_idx] = result
171 | 
172 |     def async_load(self, frame_idx):
173 |         """Fetch one frame, then schedule an asynchronous reload of it."""
174 |         if frame_idx in self._results:
175 |             data = self._results[frame_idx].get()
176 |             del self._results[frame_idx]
177 |         else:
178 |             data = self._fetch_data(self.dataset, frame_idx, self.train_config, self.config)
179 |         if np.random.random() < self._async_load_rate:
180 |             if len(self._results) < self._result_pool_limit:
181 |                 result = self._worker_pool.apply_async(
182 |                     self._fetch_data, (self.dataset, frame_idx, self.train_config, self.config))
183 |                 self._results[frame_idx] = result
184 |         return data
185 | 
186 |     def provide(self, frame_idx):
187 |         if self._load_dataset_to_mem:
188 |             if self._load_every_N_time >= 1:
189 |                 extend_frame_idx = frame_idx+np.random.choice(
190 |                     self._capacity)*self.NUM_TEST_SAMPLE
191 |                 if extend_frame_idx not in self._buffer:
192 |                     data = self.async_load(frame_idx)
193 |                     self._buffer[extend_frame_idx] = (data, 0)
194 |                 data, ctr = self._buffer[extend_frame_idx]
195 |                 if ctr == self._load_every_N_time:
196 |                     data = self.async_load(frame_idx)
197 |                     self._buffer[extend_frame_idx] = (data, 0)
198 |                     data, ctr = self._buffer[extend_frame_idx]
199 |                 self._buffer[extend_frame_idx] = (data, ctr+1)
200 |                 return data
201 |             else:
202 |                 # do not buffer
203 |                 return self.async_load(frame_idx)
204 |         else:
205 |             return self._fetch_data(self.dataset, frame_idx, self.train_config, self.config)
206 | 
207 |     def provide_batch(self, frame_idx_list):
208 |         batch_list = []
209 |         for frame_idx in frame_idx_list:
210 |             batch_list.append(self.provide(frame_idx))
211 |         return
self._batch_data(batch_list) 212 | 213 | 214 | if __name__ == "__main__": 215 | parser = argparse.ArgumentParser(description='Training of PointGNN') 216 | parser.add_argument('train_config_path', type=str, 217 | help='Path to train_config') 218 | parser.add_argument('config_path', type=str, 219 | help='Path to config') 220 | parser.add_argument('--dataset_root_dir', type=str, default='../dataset/kitti/', 221 | help='Path to KITTI dataset. Default="../dataset/kitti/"') 222 | parser.add_argument('--dataset_split_file', type=str, 223 | default='', 224 | help='Path to KITTI dataset split file.' 225 | 'Default="DATASET_ROOT_DIR/3DOP_splits' 226 | '/train_config["train_dataset"]"') 227 | 228 | args = parser.parse_args() 229 | train_config = load_train_config(args.train_config_path) 230 | DATASET_DIR = args.dataset_root_dir 231 | config_complete = load_config(args.config_path) 232 | if 'train' in config_complete: 233 | config = config_complete['train'] 234 | else: 235 | config = config_complete 236 | 237 | if args.dataset_split_file == '': 238 | DATASET_SPLIT_FILE = os.path.join(DATASET_DIR, 239 | './3DOP_splits/'+train_config['train_dataset']) 240 | else: 241 | DATASET_SPLIT_FILE = args.dataset_split_file 242 | 243 | # input function ============================================================== 244 | dataset = KittiDataset( 245 | os.path.join(DATASET_DIR, 'image/training/image_2'), 246 | os.path.join(DATASET_DIR, 'velodyne/training/velodyne/'), 247 | os.path.join(DATASET_DIR, 'calib/training/calib/'), 248 | os.path.join(DATASET_DIR, 'labels/training/label_2'), 249 | DATASET_SPLIT_FILE, 250 | num_classes=config['num_classes']) 251 | 252 | data_provider = DataProvider(dataset, train_config, config) 253 | 254 | input_v, vertex_coord_list, keypoint_indices_list, edges_list, \ 255 | cls_labels, encoded_boxes, valid_boxes = data_provider.provide_batch([1545, 1546]) 256 | 257 | 258 | #batch_list = [] 259 | #batch_list += [fetch_data(dataset, 1545, train_config, config)] 260 | #batch_list += [fetch_data(dataset, 1546, train_config, config)] 261 | #input_v, vertex_coord_list, keypoint_indices_list, edges_list, \ 262 | # cls_labels, encoded_boxes, valid_boxes = batch_data(batch_list) 263 | 264 | 265 | print(f"input_v: {input_v.shape}") 266 | for i, vertex_coord in enumerate(vertex_coord_list): 267 | print(f"vertex_coord: {i}: {vertex_coord.shape}") 268 | 269 | for i, indices in enumerate(keypoint_indices_list): 270 | print(f"indices: {i}: {indices.shape}") 271 | print(indices) 272 | for i, edge in enumerate(edges_list): 273 | print(f"edge: {i}: {edge.shape}") 274 | print(edge) 275 | #for item in edge: 276 | # if item[0]==item[1]: print(item) 277 | print(f"cls_labels:{cls_labels.shape}") 278 | print(f"encoded_boxes: {encoded_boxes.shape}") 279 | print(f"valid_boxes: {valid_boxes.shape}") 280 | print(valid_boxes) 281 | print(f"max: {valid_boxes.max()}, min:{valid_boxes.min()}, sum: {valid_boxes.sum()}") 282 | 283 | 284 | 285 | 286 | 287 | 288 | 289 | -------------------------------------------------------------------------------- /model.py: -------------------------------------------------------------------------------- 1 | import torch 2 | from torch import nn 3 | from torch_scatter import scatter_max 4 | 5 | 6 | def multi_layer_neural_network_fn(Ks): 7 | linears = [] 8 | for i in range(1, len(Ks)): 9 | linears += [ 10 | nn.Linear(Ks[i-1], Ks[i]), 11 | nn.ReLU(), 12 | nn.BatchNorm1d(Ks[i])] 13 | return nn.Sequential(*linears) 14 | 15 | def multi_layer_fc_fn(Ks=[300, 64, 32, 64], num_classes=4, 
is_logits=False, num_layers=4):
16 |     assert len(Ks) == num_layers
17 |     linears = []
18 |     for i in range(1, len(Ks)):
19 |         linears += [
20 |             nn.Linear(Ks[i-1], Ks[i]),
21 |             nn.ReLU(),
22 |             nn.BatchNorm1d(Ks[i])
23 |         ]
24 | 
25 |     if is_logits:
26 |         linears += [
27 |             nn.Linear(Ks[-1], num_classes)]
28 |     else:
29 |         linears += [
30 |             nn.Linear(Ks[-1], num_classes),
31 |             nn.ReLU(),
32 |             nn.BatchNorm1d(num_classes)
33 |         ]
34 |     return nn.Sequential(*linears)
35 | 
36 | def max_aggregation_fn(features, index, l):
37 |     """
38 |     Args: features: N x dim tensor.
39 |         index: length-N vector of set indices, e.g. [0, 0, 0, 1, 1, ..., l-1]
40 |         l: length (number) of keypoints, i.e. the number of output rows.
41 | 
42 |     """
43 |     index = index.unsqueeze(-1).expand(-1, features.shape[-1]) # N x dim
44 |     set_features = torch.zeros((l, features.shape[-1]), device=features.device).permute(1,0).contiguous() # dim x l
45 |     set_features, argmax = scatter_max(features.permute(1,0), index.permute(1,0), out=set_features)
46 |     set_features = set_features.permute(1,0)
47 |     return set_features
48 | 
49 | def focal_loss_sigmoid(labels, logits, alpha=0.5, gamma=2):
50 |     """
51 |     github.com/tensorflow/models/blob/master/\
52 |     research/object_detection/core/losses.py
53 |     Compute focal loss for binary classification
54 |     Args:
55 |         labels: An int32 tensor of shape [N, 1].
56 |         logits: A float32 tensor of shape [N, C].
57 |         alpha: A scalar for focal loss alpha hyper-parameter.
58 |             If positive samples number > negative samples number,
59 |             alpha < 0.5 and vice versa.
60 |         gamma: A scalar for focal loss gamma hyper-parameter.
61 |     Returns:
62 |         A tensor of the same shape as `logits`
63 |     """
64 | 
65 |     prob = logits.sigmoid()
66 |     labels = torch.nn.functional.one_hot(labels.squeeze().long(), num_classes=prob.shape[1])
67 | 
68 |     cross_ent = torch.clamp(logits, min=0) - logits * labels + torch.log(1+torch.exp(-torch.abs(logits)))
69 |     prob_t = (labels*prob) + (1-labels) * (1-prob)
70 |     modulating = torch.pow(1-prob_t, gamma)
71 |     alpha_weight = (labels*alpha)+(1-labels)*(1-alpha)
72 | 
73 |     focal_cross_entropy = modulating * alpha_weight * cross_ent
74 |     return focal_cross_entropy
75 | 
76 | class PointSetPooling(nn.Module):
77 |     def __init__(self, point_MLP_depth_list=[4, 32, 64, 128, 300], output_MLP_depth_list=[300, 300, 300]):
78 |         super(PointSetPooling, self).__init__()
79 | 
80 |         Ks = list(point_MLP_depth_list)
81 |         self.point_linears = multi_layer_neural_network_fn(Ks)
82 | 
83 |         Ks = list(output_MLP_depth_list)
84 |         self.out_linears = multi_layer_neural_network_fn(Ks)
85 | 
86 |     def forward(self,
87 |             point_features,
88 |             point_coordinates,
89 |             keypoint_indices,
90 |             set_indices):
91 |         """Apply feature extraction to point sets.
92 |         Args:
93 |             point_features: a [N, M] tensor. N is the number of points.
94 |                 M is the length of the features.
95 |             point_coordinates: a [N, D] tensor. N is the number of points.
96 |                 D is the dimension of the coordinates.
97 |             keypoint_indices: a [K, 1] tensor. Indices of K keypoints.
98 |             set_indices: a [S, 2] tensor. S pairs of (point_index, set_index).
99 |                 i.e. (i, j) indicates point[i] belongs to the point set created by
100 |                 grouping around keypoint[j].
101 |         returns: a [K, output_depth] tensor as the set feature.
102 |             Output_depth depends on the feature extraction options that
103 |             are selected.
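
        A minimal shape sketch (toy sizes; it assumes input_features='i', so
        the single intensity feature plus 3 relative coordinates match the
        default point-MLP input width of 4):

            import torch
            pooling = PointSetPooling()                # default MLP depths
            point_features = torch.rand(50, 1)         # N=50 points, M=1
            point_coordinates = torch.rand(50, 3)      # D=3
            keypoint_indices = torch.randint(0, 50, (10, 1))   # K=10
            set_indices = torch.stack(                 # S=200 (point, set) pairs
                [torch.randint(0, 50, (200,)),
                 torch.randint(0, 10, (200,))], dim=1)
            set_features = pooling(point_features, point_coordinates,
                                   keypoint_indices, set_indices)   # [10, 300]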
104 |         """
105 | 
106 |         #print(f"point_features: {point_features.shape}")
107 |         #print(f"point_coordinates: {point_coordinates.shape}")
108 |         #print(f"keypoint_indices: {keypoint_indices.shape}")
109 |         #print(f"set_indices: {set_indices.shape}")
110 | 
111 |         # Gather the points in a set
112 |         point_set_features = point_features[set_indices[:, 0]]
113 |         point_set_coordinates = point_coordinates[set_indices[:, 0]]
114 |         point_set_keypoint_indices = keypoint_indices[set_indices[:, 1]]
115 | 
116 |         #point_set_keypoint_coordinates_1 = point_features[point_set_keypoint_indices[:, 0]]
117 |         point_set_keypoint_coordinates = point_coordinates[point_set_keypoint_indices[:, 0]]
118 | 
119 |         point_set_coordinates = point_set_coordinates - point_set_keypoint_coordinates
120 |         point_set_features = torch.cat([point_set_features, point_set_coordinates], axis=-1)
121 | 
122 |         # Step 1: Extract all vertex_features
123 |         extracted_features = self.point_linears(point_set_features) # S x 300 with the default MLP depths
124 | 
125 |         # Step 2: Aggregate features using the scatter max method.
126 |         #index = set_indices[:, 1].unsqueeze(-1).expand(-1, extracted_features.shape[-1]) # N x 64
127 |         #set_features = torch.zeros((len(keypoint_indices), extracted_features.shape[-1]), device=extracted_features.device).permute(1,0).contiguous() # len x 64
128 |         #set_features, argmax = scatter_max(extracted_features.permute(1,0), index.permute(1,0), out=set_features)
129 |         #set_features = set_features.permute(1,0)
130 | 
131 |         set_features = max_aggregation_fn(extracted_features, set_indices[:, 1], len(keypoint_indices))
132 | 
133 |         # Step 3: MLP for set_features
134 |         set_features = self.out_linears(set_features)
135 |         return set_features
136 | 
137 | class GraphNetAutoCenter(nn.Module):
138 |     def __init__(self, auto_offset=True, auto_offset_MLP_depth_list=[300, 64, 3], edge_MLP_depth_list=[303, 300, 300], update_MLP_depth_list=[300, 300, 300]):
139 |         super(GraphNetAutoCenter, self).__init__()
140 |         self.auto_offset = auto_offset
141 |         self.auto_offset_fn = multi_layer_neural_network_fn(auto_offset_MLP_depth_list)
142 |         self.edge_feature_fn = multi_layer_neural_network_fn(edge_MLP_depth_list)
143 |         self.update_fn = multi_layer_neural_network_fn(update_MLP_depth_list)
144 | 
145 | 
146 |     def forward(self, input_vertex_features,
147 |             input_vertex_coordinates,
148 |             keypoint_indices,
149 |             edges):
150 |         """Apply one layer of graph network on a graph.
151 |         Args:
152 |             input_vertex_features: a [N, M] tensor. N is the number of vertices.
153 |                 M is the length of the features.
154 |             input_vertex_coordinates: a [N, D] tensor. N is the number of
155 |                 vertices. D is the dimension of the coordinates.
156 |             keypoint_indices: a [K, 1] tensor; only its length is used here, to size the aggregation output (kept for API compatibility).
157 |             edges: a [K, 2] tensor. K pairs of (src, dest) vertex indices.
158 |         returns: a [N, M] tensor. Updated vertex features.
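
        A minimal usage sketch (toy sizes; with the default MLP depths the
        vertex feature width M must be 300):

            import torch
            layer = GraphNetAutoCenter()
            vertex_features = torch.rand(10, 300)      # N=10, M=300
            vertex_coordinates = torch.rand(10, 3)     # D=3
            keypoint_indices = torch.arange(10).unsqueeze(1)  # sizes the output
            edges = torch.stack([torch.randint(0, 10, (40,)),
                                 torch.randint(0, 10, (40,))], dim=1)
            out = layer(vertex_features, vertex_coordinates,
                        keypoint_indices, edges)       # [10, 300]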
159 | """ 160 | #print(f"input_vertex_features: {input_vertex_features.shape}") 161 | #print(f"input_vertex_coordinates: {input_vertex_coordinates.shape}") 162 | #print(NOT_USED) 163 | #print(f"edges: {edges.shape}") 164 | 165 | # Gather the source vertex of the edges 166 | s_vertex_features = input_vertex_features[edges[:, 0]] 167 | s_vertex_coordinates = input_vertex_coordinates[edges[:, 0]] 168 | 169 | if self.auto_offset: 170 | offset = self.auto_offset_fn(input_vertex_features) 171 | input_vertex_coordinates = input_vertex_coordinates + offset 172 | 173 | # Gather the destination vertex of the edges 174 | d_vertex_coordinates = input_vertex_coordinates[edges[:, 1]] 175 | 176 | # Prepare initial edge features 177 | edge_features = torch.cat([s_vertex_features, s_vertex_coordinates - d_vertex_coordinates], dim=-1) 178 | 179 | # Extract edge features 180 | edge_features = self.edge_feature_fn(edge_features) 181 | 182 | # Aggregate edge features 183 | aggregated_edge_features = max_aggregation_fn(edge_features, edges[:,1], len(keypoint_indices)) 184 | 185 | # Update vertex features 186 | update_features = self.update_fn(aggregated_edge_features) 187 | output_vertex_features = update_features + input_vertex_features 188 | return output_vertex_features 189 | 190 | class ClassAwarePredictor(nn.Module): 191 | def __init__(self, num_classes, box_encoding_len): 192 | super(ClassAwarePredictor, self).__init__() 193 | self.cls_fn = multi_layer_fc_fn(Ks=[300, 64], num_layers=2, num_classes=num_classes, is_logits=True) 194 | self.loc_fns = nn.ModuleList() 195 | self.num_classes = num_classes 196 | self.box_encoding_len = box_encoding_len 197 | 198 | for i in range(num_classes): 199 | self.loc_fns += [ 200 | multi_layer_fc_fn(Ks=[300, 300, 64], num_layers=3, num_classes=box_encoding_len, is_logits=True)] 201 | 202 | def forward(self, features): 203 | logits = self.cls_fn(features) 204 | box_encodings_list = [] 205 | for loc_fn in self.loc_fns: 206 | box_encodings = loc_fn(features).unsqueeze(1) 207 | box_encodings_list += [box_encodings] 208 | 209 | box_encodings = torch.cat(box_encodings_list, dim=1) 210 | return logits, box_encodings 211 | 212 | class MultiLayerFastLocalGraphModelV2(nn.Module): 213 | def __init__(self, num_classes, box_encoding_len, regularizer_type=None, \ 214 | regularizer_kwargs=None, layer_configs=None, mode=None, graph_net_layers=3): 215 | super(MultiLayerFastLocalGraphModelV2, self).__init__() 216 | self.num_classes = num_classes 217 | self.box_encoding_len = box_encoding_len 218 | self.point_set_pooling = PointSetPooling() 219 | 220 | self.graph_nets = nn.ModuleList() 221 | for i in range(graph_net_layers): 222 | self.graph_nets.append(GraphNetAutoCenter()) 223 | 224 | self.predictor = ClassAwarePredictor(num_classes, box_encoding_len) 225 | 226 | 227 | def forward(self, batch, is_training): 228 | input_v, vertex_coord_list, keypoint_indices_list, edges_list, \ 229 | cls_labels, encoded_boxes, valid_boxes = batch 230 | 231 | point_features, point_coordinates, keypoint_indices, set_indices = input_v, vertex_coord_list[0], keypoint_indices_list[0], edges_list[0] 232 | point_features = self.point_set_pooling(point_features, point_coordinates, keypoint_indices, set_indices) 233 | 234 | 235 | point_coordinates, keypoint_indices, set_indices = vertex_coord_list[1], keypoint_indices_list[1], edges_list[1] 236 | for i, graph_net in enumerate(self.graph_nets): 237 | point_features = graph_net(point_features, point_coordinates, keypoint_indices, set_indices) 238 | logits, 
box_encodings = self.predictor(point_features)
239 |         return logits, box_encodings
240 | 
241 |     def postprocess(self, logits):
242 |         softmax = nn.Softmax(dim=1)
243 |         prob = softmax(logits)
244 |         return prob
245 | 
246 |     def loss(self, logits, labels, pred_box, gt_box, valid_box,
247 |             cls_loss_type="focal_sigmoid", loc_loss_type='huber_loss', loc_loss_weight=1.0, cls_loss_weight=1.0):
248 |         """Output loss value.
249 |         Args:
250 |             logits: [N, num_classes] tensor. The classification logits from
251 |                 predict method.
252 |             labels: [N, 1] tensor. The class labels.
253 |             pred_box: [N, num_classes, box_encoding_len] tensor. The encoded
254 |                 bounding boxes from the predict method.
255 |             gt_box: [N, 1, box_encoding_len] tensor. The ground truth encoded
256 |                 bounding boxes.
257 |             valid_box: [N, 1, 1] tensor. An indicator of whether the vertex is
258 |                 from an object of interest (whether it has a valid bounding box).
259 |             cls_loss_type: string, the type of classification loss function.
260 |             cls_loss_kwargs: dict, keyword args to the classification loss.
261 |             loc_loss_type: string, the type of localization loss function.
262 |             loc_loss_kwargs: dict, keyword args to the localization loss.
263 |             loc_loss_weight: scalar, weight on localization loss.
264 |             cls_loss_weight: scalar, weight on the classification loss.
265 |         returns: a dict of cls_loss, loc_loss, reg_loss, num_endpoint,
266 |             num_valid_endpoint. num_endpoint is the number of output vertices.
267 |             num_valid_endpoint is the number of output vertices that have a valid
268 |             bounding box. Those numbers are useful for weighting during batching.
269 |         """
270 |         """
271 |         print(f"logits: {logits.shape}")
272 |         print(f"labels: {labels.shape}")
273 |         print(f"pred_box: {pred_box.shape}")
274 |         print(f"gt_box: {gt_box.shape}")
275 |         print(f"valid_box: {valid_box.shape}")
276 |         """
277 | 
278 |         point_loss = focal_loss_sigmoid(labels, logits) # same shape as logits, N x C
279 |         num_endpoint = point_loss.shape[0]
280 |         cls_loss = cls_loss_weight * point_loss.mean()
281 | 
282 |         batch_idx = torch.arange(0, pred_box.shape[0])
283 |         batch_idx = batch_idx.unsqueeze(1).to(labels.device)
284 |         batch_idx = torch.cat([batch_idx, labels], dim=1)
285 |         pred_box = pred_box[batch_idx[:, 0], batch_idx[:, 1]]
286 |         huber_loss = nn.SmoothL1Loss(reduction="none")
287 |         all_loc_loss = huber_loss(pred_box, gt_box.squeeze())
288 |         all_loc_loss = all_loc_loss * valid_box.squeeze(1)
289 | 
290 |         num_valid_endpoint = valid_box.sum()
291 |         mean_loc_loss = all_loc_loss.mean(dim=1)
292 | 
293 |         if num_valid_endpoint == 0:
294 |             loc_loss = 0
295 |         else: loc_loss = mean_loc_loss.sum() / num_valid_endpoint
296 |         classwise_loc_loss = []
297 | 
298 | 
299 |         for class_idx in range(self.num_classes):
300 |             class_mask = torch.nonzero(labels == int(class_idx), as_tuple=False)
301 |             l = mean_loc_loss[class_mask]
302 |             classwise_loc_loss += [l]
303 |         loss_dict = {}
304 |         loss_dict['classwise_loc_loss'] = classwise_loc_loss
305 | 
306 |         params = torch.cat([x.view(-1) for x in self.parameters()])
307 |         reg_loss = torch.mean(params.abs())
308 | 
309 |         loss_dict.update({'cls_loss': cls_loss, 'loc_loss': loc_loss,
310 |             'reg_loss': reg_loss,
311 |             'num_end_point': num_endpoint,
312 |             'num_valid_endpoint': num_valid_endpoint
313 |             })
314 |         return loss_dict
315 | 
--------------------------------------------------------------------------------
/models/crop_aug.py:
--------------------------------------------------------------------------------
1 | """This file implements augmentation by cropping and parsing
ground truth boxes""" 2 | 3 | import os 4 | import json 5 | 6 | import numpy as np 7 | import open3d 8 | from copy import deepcopy 9 | from tqdm import tqdm 10 | 11 | from dataset.kitti_dataset import KittiDataset, sel_xyz_in_box3d, \ 12 | sel_xyz_in_box2d, Points 13 | from models.nms import boxes_3d_to_corners, overlapped_boxes_3d, \ 14 | overlapped_boxes_3d_fast_poly 15 | from models import preprocess 16 | 17 | def save_cropped_boxes(dataset, filename, expand_factor=[1.1, 1.1, 1.1], 18 | minimum_points=10, backlist=[]): 19 | cropped_labels = {} 20 | cropped_cam_points = {} 21 | for frame_idx in tqdm(range(dataset.num_files)): 22 | labels = dataset.get_label(frame_idx) 23 | cam_points = dataset.get_cam_points_in_image_with_rgb(frame_idx) 24 | for label in labels: 25 | if label['name'] != "DontCare": 26 | if label['name'] not in backlist: 27 | mask = sel_xyz_in_box3d(label, cam_points.xyz, 28 | expand_factor) 29 | if np.sum(mask) > minimum_points: 30 | if label['name'] in cropped_labels: 31 | cropped_labels[label['name']].append(label) 32 | cropped_cam_points[label['name']].append( 33 | [cam_points.xyz[mask].tolist(), 34 | cam_points.attr[mask].tolist()]) 35 | else: 36 | cropped_labels[label['name']] = [label] 37 | cropped_cam_points[label['name']] = [ 38 | [cam_points.xyz[mask].tolist(), 39 | cam_points.attr[mask].tolist()]] 40 | 41 | with open(filename, 'w') as outfile: 42 | json.dump((cropped_labels,cropped_cam_points), outfile) 43 | 44 | def load_cropped_boxes(filename): 45 | with open(filename, 'r') as infile: 46 | cropped_labels, cropped_cam_points = json.load(infile) 47 | for key in cropped_cam_points: 48 | print("Load %d %s" % (len(cropped_cam_points[key]), key)) 49 | for i, cam_points in enumerate(cropped_cam_points[key]): 50 | cropped_cam_points[key][i] = Points(xyz=np.array(cam_points[0]), 51 | attr=np.array(cam_points[1])) 52 | return cropped_labels, cropped_cam_points 53 | 54 | def vis_cropped_boxes(cropped_labels, cropped_cam_points, dataset): 55 | for key in cropped_cam_points: 56 | if key == 'Pedestrian': 57 | for i, cam_points in enumerate(cropped_cam_points[key]): 58 | label = cropped_labels[key][i] 59 | print(label['name']) 60 | pcd = open3d.PointCloud() 61 | pcd.points = open3d.Vector3dVector(cam_points.xyz) 62 | pcd.colors = open3d.Vector3dVector(cam_points.attr[:, 1:]) 63 | def custom_draw_geometry_load_option(geometry_list): 64 | vis = open3d.Visualizer() 65 | vis.create_window() 66 | for geometry in geometry_list: 67 | vis.add_geometry(geometry) 68 | ctr = vis.get_view_control() 69 | ctr.rotate(0.0, 3141.0, 0) 70 | vis.run() 71 | vis.destroy_window() 72 | custom_draw_geometry_load_option( 73 | [pcd] + dataset.draw_open3D_box(label)) 74 | 75 | def parser_without_collision(cam_rgb_points, labels, 76 | sample_cam_points, sample_labels, 77 | overlap_mode = 'box', 78 | auto_box_height = False, 79 | max_overlap_rate = 0.01, 80 | appr_factor = 100, 81 | max_overlap_num_allowed=1, max_trails=1, method_name='normal', 82 | yaw_std=0.3, expand_factor=(1.1, 1.1, 1.1), 83 | must_have_ground=False): 84 | xyz = cam_rgb_points.xyz 85 | attr = cam_rgb_points.attr 86 | if overlap_mode == 'box' or overlap_mode == 'box_and_point': 87 | label_boxes = np.array([ 88 | [l['x3d'], l['y3d'], l['z3d'], l['length'], 89 | l['height'], l['width'], l['yaw']] 90 | for l in labels ]) 91 | label_boxes_corners = np.int32( 92 | appr_factor*boxes_3d_to_corners(label_boxes)) 93 | for i, label in enumerate(sample_labels): 94 | trial = 0 95 | sucess = False 96 | for trial in range(max_trails): 97 | # 
random rotate 98 | if method_name == 'normal': 99 | delta_yaw = np.random.normal(scale=yaw_std) 100 | else: 101 | if method_name == 'uniform': 102 | delta_yaw = np.random.uniform(low=-yaw_std, high=yaw_std) 103 | new_label = deepcopy(label) 104 | R = np.array([[np.cos(delta_yaw), 0, np.sin(delta_yaw)], 105 | [0, 1, 0 ], 106 | [-np.sin(delta_yaw), 0, np.cos(delta_yaw)]]); 107 | tx = new_label['x3d'] 108 | ty = new_label['y3d'] 109 | tz = new_label['z3d'] 110 | xyz_center = np.array([[tx, ty, tz]]) 111 | xyz_center = xyz_center.dot(np.transpose(R)) 112 | new_label['x3d'], new_label['y3d'], new_label['z3d'] = xyz_center[0] 113 | new_label['yaw'] = new_label['yaw']+delta_yaw 114 | if auto_box_height: 115 | original_height = new_label['height'] 116 | mask_2d = sel_xyz_in_box2d(new_label, xyz, expand_factor) 117 | if np.sum(mask_2d) > 0: 118 | ground_height = np.amax(xyz[mask_2d][:,1]) 119 | y3d_adjust = ground_height - new_label['y3d'] 120 | else: 121 | if must_have_ground: 122 | continue; 123 | y3d_adjust = 0 124 | # if np.abs(y3d_adjust) > 1: 125 | # y3d_adjust = 0 126 | new_label['y3d'] += y3d_adjust 127 | new_label['height'] = original_height 128 | mask = sel_xyz_in_box3d(new_label, xyz, expand_factor) 129 | # check if the new box includes more points than before 130 | below_overlap = False 131 | if overlap_mode == 'box': 132 | new_boxes = np.array([ 133 | [new_label['x3d'], 134 | new_label['y3d'], 135 | new_label['z3d'], 136 | new_label['length'], 137 | new_label['height'], 138 | new_label['width'], 139 | new_label['yaw']] 140 | ]) 141 | new_boxes_corners = np.int32( 142 | appr_factor*boxes_3d_to_corners(new_boxes)) 143 | below_overlap = np.all(overlapped_boxes_3d_fast_poly( 144 | new_boxes_corners[0], 145 | label_boxes_corners) < max_overlap_rate) 146 | if overlap_mode == 'point': 147 | below_overlap = np.sum(mask) < max_overlap_num_allowed 148 | if overlap_mode == 'box_and_point': 149 | new_boxes = np.array([ 150 | [new_label['x3d'], 151 | new_label['y3d'], 152 | new_label['z3d'], 153 | new_label['length'], 154 | new_label['height'], 155 | new_label['width'], 156 | new_label['yaw']] 157 | ]) 158 | new_boxes_corners = np.int32( 159 | appr_factor*boxes_3d_to_corners(new_boxes)) 160 | below_overlap = np.all( 161 | overlapped_boxes_3d_fast_poly(new_boxes_corners[0], 162 | label_boxes_corners) < max_overlap_rate) 163 | below_overlap = np.logical_and(below_overlap, 164 | (np.sum(mask) < max_overlap_num_allowed)) 165 | if below_overlap: 166 | 167 | points_xyz = sample_cam_points[i].xyz 168 | points_attr = sample_cam_points[i].attr 169 | points_xyz = points_xyz.dot(np.transpose(R)) 170 | if auto_box_height: 171 | points_xyz[:,1] = points_xyz[:,1] + y3d_adjust 172 | xyz = xyz[np.logical_not(mask)] 173 | xyz = np.concatenate([points_xyz, xyz], axis=0) 174 | attr = attr[np.logical_not(mask)] 175 | attr = np.concatenate([points_attr, attr], axis=0) 176 | # update boxes and label 177 | labels.append(new_label) 178 | if overlap_mode == 'box' or overlap_mode == 'box_and_point': 179 | label_boxes_corners = np.append(label_boxes_corners, 180 | new_boxes_corners,axis=0) 181 | sucess = True 182 | break; 183 | # if not sucess: 184 | # if not sucess, keep the old label 185 | # print('Warning: fail to parse cropped box') 186 | return Points(xyz=xyz, attr=attr), labels 187 | 188 | class CropAugSampler(): 189 | """ A class to sample from cropped objects and parse it to a frame """ 190 | def __init__(self, crop_filename): 191 | self._cropped_labels, self._cropped_cam_points = load_cropped_boxes(\ 192 | 
crop_filename) 193 | def crop_aug(self, cam_rgb_points, labels, 194 | sample_rate={"Car":1, "Pedestrian":1, "Cyclist":1}, 195 | parser_kwargs={}): 196 | sample_labels = [] 197 | sample_cam_points = [] 198 | for key in sample_rate: 199 | sample_indices = np.random.choice(len(self._cropped_labels[key]), 200 | size=sample_rate[key], replace=False) 201 | sample_labels.extend( 202 | deepcopy([self._cropped_labels[key][idx] 203 | for idx in sample_indices])) 204 | sample_cam_points.extend( 205 | deepcopy([self._cropped_cam_points[key][idx] 206 | for idx in sample_indices])) 207 | return parser_without_collision(cam_rgb_points, labels, 208 | sample_cam_points, sample_labels, 209 | **parser_kwargs) 210 | 211 | def vis_crop_aug_sampler(crop_filename, dataset): 212 | sampler = CropAugSampler(crop_filename) 213 | for frame_idx in range(10): 214 | labels = dataset.get_label(frame_idx) 215 | cam_rgb_points = dataset.get_cam_points_in_image_with_rgb(frame_idx) 216 | cam_rgb_points, labels = sampler.crop_aug(cam_rgb_points, labels, 217 | sample_rate={"Car":2, "Pedestrian":10, "Cyclist":10}, 218 | parser_kwargs={ 219 | 'max_overlap_num_allowed': 50, 220 | 'max_trails':100, 221 | 'method_name':'normal', 222 | 'yaw_std':np.pi/16, 223 | 'expand_factor':(1.1, 1.1, 1.1), 224 | 'auto_box_height': True, 225 | 'overlap_mode':'box_and_point', 226 | 'max_overlap_rate': 1e-6, 227 | 'appr_factor': 100, 228 | 'must_have_ground': True, 229 | }) 230 | aug_configs = [ 231 | {'method_name': 'random_box_global_rotation', 232 | 'method_kwargs': { 'max_overlap_num_allowed':100, 233 | 'max_trails': 100, 234 | 'appr_factor':100, 235 | 'method_name':'normal', 236 | 'yaw_std':np.pi/8, 237 | 'expend_factor':(1.1, 1.1, 1.1) 238 | } 239 | } 240 | ] 241 | aug_fn = preprocess.get_data_aug(aug_configs) 242 | cam_rgb_points, labels = aug_fn(cam_rgb_points, labels) 243 | dataset.vis_points(cam_rgb_points, labels, expend_factor=(1.1, 1.1,1.1)) 244 | 245 | 246 | # # Example of usage 247 | # print('generate training split: ') 248 | # kitti_train = KittiDataset( 249 | # '../dataset/kitti/image/training/image_2', 250 | # '../dataset/kitti/velodyne/training/velodyne/', 251 | # '../dataset/kitti/calib/training/calib/', 252 | # '../dataset/kitti/labels/training/label_2/', 253 | # '../dataset/kitti/3DOP_splits/train.txt',) 254 | # save_cropped_boxes(kitti_train, "../dataset/kitti/cropped/car_person_cyclist_train.json", 255 | # expand_factor = (1.1, 1.1, 1.1), minimum_points=10, 256 | # backlist=['Van', 'Truck', 'Misc', 'Tram', 'Person_sitting']) 257 | # print("generate val split: ") 258 | # kitti_val = KittiDataset( 259 | # '../dataset/kitti/image/training/image_2', 260 | # '../dataset/kitti/velodyne/training/velodyne/', 261 | # '../dataset/kitti/calib/training/calib/', 262 | # '../dataset/kitti/labels/training/label_2/', 263 | # '../dataset/kitti/3DOP_splits/val.txt',) 264 | # save_cropped_boxes(kitti_val, "../dataset/kitti/cropped/car_person_cyclist_val.json", 265 | # expand_factor = (1.1, 1.1, 1.1), minimum_points=10, 266 | # backlist=['Van', 'Truck', 'Misc', 'Tram', 'Person_sitting']) 267 | # print("generate trainval: ") 268 | # kitti_trainval = KittiDataset( 269 | # '../dataset/kitti/image/training/image_2', 270 | # '../dataset/kitti/velodyne/training/velodyne/', 271 | # '../dataset/kitti/calib/training/calib/', 272 | # '../dataset/kitti/labels/training/label_2/', 273 | # '../dataset/kitti/3DOP_splits/trainval.txt',) 274 | # save_cropped_boxes(kitti_trainval, "../dataset/kitti/cropped/car_person_cyclist_trainval.json", 275 | # expand_factor 
= (1.1, 1.1, 1.1), minimum_points=10,
276 | #         backlist=['Van', 'Truck', 'Misc', 'Tram', 'Person_sitting'])
277 | # cropped_labels, cropped_cam_points = load_cropped_boxes(
278 | #     "../dataset/kitti/cropped/car_person_cyclist_train.json")
279 | # vis_cropped_boxes(cropped_labels, cropped_cam_points, kitti_train)
280 | # cropped_labels, cropped_cam_points = load_cropped_boxes(
281 | #     "../dataset/kitti/cropped/car_person_cyclist_val.json")
282 | # vis_cropped_boxes(cropped_labels, cropped_cam_points, kitti_val)
283 | # cropped_labels, cropped_cam_points = load_cropped_boxes(
284 | #     "../dataset/kitti/cropped/car_person_cyclist_trainval.json")
285 | # vis_cropped_boxes(cropped_labels, cropped_cam_points, kitti_trainval)
286 | # vis_crop_aug_sampler("../dataset/kitti/cropped/car_person_cyclist_val.json", kitti_val)
287 | 
--------------------------------------------------------------------------------
/models/gnn.py:
--------------------------------------------------------------------------------
1 | """This file defines classes for the graph neural network. """
2 | 
3 | from functools import partial
4 | 
5 | import tensorflow as tf
6 | import numpy as np
7 | import tensorflow.contrib.slim as slim
8 | 
9 | def instance_normalization(features):
10 |     with tf.variable_scope(None, default_name='IN'):
11 |         mean, variance = tf.nn.moments(
12 |             features, [0], name='IN_stats', keep_dims=True)
13 |         features = tf.nn.batch_normalization(
14 |             features, mean, variance, None, None, 1e-12, name='IN_apply')
15 |     return features
16 | 
17 | normalization_fn_dict = {
18 |     'fused_BN_center': slim.batch_norm,
19 |     'BN': partial(slim.batch_norm, fused=False, center=False),
20 |     'BN_center': partial(slim.batch_norm, fused=False),
21 |     'IN': instance_normalization,
22 |     'NONE': None
23 | }
24 | activation_fn_dict = {
25 |     'ReLU': tf.nn.relu,
26 |     'ReLU6': tf.nn.relu6,
27 |     'LeakyReLU': partial(tf.nn.leaky_relu, alpha=0.01),
28 |     'ELU': tf.nn.elu,
29 |     'NONE': None,
30 |     'Sigmoid': tf.nn.sigmoid,
31 |     'Tanh': tf.nn.tanh,
32 | }
33 | 
34 | def multi_layer_fc_fn(sv, mask=None, Ks=(64, 32, 64), num_classes=4,
35 |     is_logits=False, num_layer=4, normalization_type="fused_BN_center",
36 |     activation_type='ReLU'):
37 |     """A function to create multiple layers of neural network to compute
38 |     features passing through each edge.
39 | 
40 |     Args:
41 |         sv: a [N, M] or [T, DEGREE, M] tensor.
42 |             N is the total number of edges, M is the length of features. T is
43 |             the number of receiving vertices, DEGREE is the in-degree of each
44 |             receiving vertex. When a [T, DEGREE, M] tensor is provided, the
45 |             degree of each receiving vertex is assumed to be the same.
50 |         mask: an optional [N, 1] or [T, DEGREE, 1] tensor. A value 1 is used
51 |             to indicate a valid output feature, while a value 0 indicates
52 |             an invalid output feature which is set to 0.
53 |         num_layer: number of layers to add.
54 | 
55 |     returns: a [N, K] tensor or [T, DEGREE, K].
56 |         K is the length of the new features on the edge.
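
    A minimal usage sketch (TF1 style; the defaults already satisfy
    len(Ks) == num_layer - 1):

        sv = tf.placeholder(tf.float32, shape=[None, 300])
        logits = multi_layer_fc_fn(sv, Ks=(64, 32, 64), num_classes=4,
                                   is_logits=True, num_layer=4)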
57 |     """
58 |     assert len(sv.shape) == 2
59 |     assert len(Ks) == num_layer-1
60 |     if is_logits:
61 |         features = sv
62 |         for i in range(num_layer-1):
63 |             features = slim.fully_connected(features, Ks[i],
64 |                 activation_fn=activation_fn_dict[activation_type],
65 |                 normalizer_fn=normalization_fn_dict[normalization_type],
66 |                 )
67 |         features = slim.fully_connected(features, num_classes,
68 |             activation_fn=None,
69 |             normalizer_fn=None
70 |             )
71 |     else:
72 |         features = sv
73 |         for i in range(num_layer-1):
74 |             features = slim.fully_connected(features, Ks[i],
75 |                 activation_fn=activation_fn_dict[activation_type],
76 |                 normalizer_fn=normalization_fn_dict[normalization_type],
77 |                 )
78 |         features = slim.fully_connected(features, num_classes,
79 |             activation_fn=activation_fn_dict[activation_type],
80 |             normalizer_fn=normalization_fn_dict[normalization_type],
81 |             )
82 |     if mask is not None:
83 |         features = features * mask
84 |     return features
85 | 
86 | def multi_layer_neural_network_fn(features, Ks=(64, 32, 64), is_logits=False,
87 |     normalization_type="fused_BN_center", activation_type='ReLU'):
88 |     """A function to create multiple layers of neural network.
89 |     """
90 |     assert len(features.shape) == 2
91 |     if is_logits:
92 |         for i in range(len(Ks)-1):
93 |             features = slim.fully_connected(features, Ks[i],
94 |                 activation_fn=activation_fn_dict[activation_type],
95 |                 normalizer_fn=normalization_fn_dict[normalization_type])
96 |         features = slim.fully_connected(features, Ks[-1],
97 |             activation_fn=None,
98 |             normalizer_fn=None)
99 |     else:
100 |         for i in range(len(Ks)):
101 |             features = slim.fully_connected(features, Ks[i],
102 |                 activation_fn=activation_fn_dict[activation_type],
103 |                 normalizer_fn=normalization_fn_dict[normalization_type])
104 |     return features
105 | 
106 | def graph_scatter_max_fn(point_features, point_centers, num_centers):
107 |     aggregated = tf.math.unsorted_segment_max(point_features,
108 |         point_centers, num_centers, name='scatter_max')
109 |     return aggregated
110 | 
111 | def graph_scatter_sum_fn(point_features, point_centers, num_centers):
112 |     aggregated = tf.math.unsorted_segment_sum(point_features,
113 |         point_centers, num_centers, name='scatter_sum')
114 |     return aggregated
115 | 
116 | def graph_scatter_mean_fn(point_features, point_centers, num_centers):
117 |     aggregated = tf.math.unsorted_segment_mean(point_features,
118 |         point_centers, num_centers, name='scatter_mean')
119 |     return aggregated
120 | 
121 | class ClassAwarePredictor(object):
122 |     """A class to predict 3D bounding boxes and class labels."""
123 | 
124 |     def __init__(self, cls_fn, loc_fn):
125 |         """
126 |         Args:
127 |             cls_fn: a function to classify labels.
128 |             loc_fn: a function to predict 3D bounding boxes.
129 |         """
130 |         self._cls_fn = cls_fn
131 |         self._loc_fn = loc_fn
132 | 
133 |     def apply_regular(self, features, num_classes, box_encoding_len,
134 |         normalization_type='fused_BN_center',
135 |         activation_type='ReLU'):
136 |         """
137 |         Args:
138 |             features: input feature vectors. [N, M].
139 |             num_classes: the number of classes to predict.
140 |             box_encoding_len: the length of the box encoding.
141 |             normalization_type / activation_type: options for the MLPs.
142 | 
143 |         returns: logits, box_encodings.
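
        A minimal usage sketch (box_encoding_len=7 mirrors the 7-value
        x/y/z/l/h/w/yaw losses printed by eval.py and is only illustrative):

            predictor = ClassAwarePredictor(cls_fn=multi_layer_fc_fn,
                                            loc_fn=multi_layer_fc_fn)
            features = tf.placeholder(tf.float32, shape=[None, 300])
            logits, box_encodings = predictor.apply_regular(
                features, num_classes=4, box_encoding_len=7)
            # logits: [N, 4]; box_encodings: [N, 4, 7]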
144 |         """
145 |         box_encodings_list = []
146 |         with tf.variable_scope('predictor'):
147 |             with tf.variable_scope('cls'):
148 |                 logits = self._cls_fn(
149 |                     features, num_classes=num_classes, is_logits=True,
150 |                     normalization_type=normalization_type,
151 |                     activation_type=activation_type)
152 |             with tf.variable_scope('loc'):
153 |                 for class_idx in range(num_classes):
154 |                     with tf.variable_scope('cls_%d' % class_idx):
155 |                         box_encodings = self._loc_fn(
156 |                             features, num_classes=box_encoding_len,
157 |                             is_logits=True,
158 |                             normalization_type=normalization_type,
159 |                             activation_type=activation_type)
160 |                         box_encodings = tf.expand_dims(box_encodings, axis=1)
161 |                         box_encodings_list.append(box_encodings)
162 |             box_encodings = tf.concat(box_encodings_list, axis=1)
163 |         return logits, box_encodings
164 | 
165 | class ClassAwareSeparatedPredictor(object):
166 |     """A class to predict 3D bounding boxes and class labels."""
167 | 
168 |     def __init__(self, cls_fn, loc_fn):
169 |         """
170 |         Args:
171 |             cls_fn: a function to classify labels.
172 |             loc_fn: a function to predict 3D bounding boxes.
173 |         """
174 |         self._cls_fn = cls_fn
175 |         self._loc_fn = loc_fn
176 | 
177 |     def apply_regular(self, features, num_classes, box_encoding_len,
178 |         normalization_type='fused_BN_center',
179 |         activation_type='ReLU'):
180 |         """
181 |         Args:
182 |             features: input feature vectors. [N, M]. M is split evenly
183 |                 across the per-class localization branches.
184 |             num_classes: the number of classes to predict.
185 |             box_encoding_len: the length of the box encoding.
186 | 
187 |         returns: logits, box_encodings.
188 |         """
189 |         box_encodings_list = []
190 |         with tf.variable_scope('predictor'):
191 |             with tf.variable_scope('cls'):
192 |                 logits = self._cls_fn(
193 |                     features, num_classes=num_classes, is_logits=True,
194 |                     normalization_type=normalization_type,
195 |                     activation_type=activation_type)
196 |             features_splits = tf.split(features, num_classes, axis=-1)
197 |             with tf.variable_scope('loc'):
198 |                 for class_idx in range(num_classes):
199 |                     with tf.variable_scope('cls_%d' % class_idx):
200 |                         box_encodings = self._loc_fn(
201 |                             features_splits[class_idx],
202 |                             num_classes=box_encoding_len,
203 |                             is_logits=True,
204 |                             normalization_type=normalization_type,
205 |                             activation_type=activation_type)
206 |                         box_encodings = tf.expand_dims(box_encodings, axis=1)
207 |                         box_encodings_list.append(box_encodings)
208 |             box_encodings = tf.concat(box_encodings_list, axis=1)
209 |         return logits, box_encodings
210 | 
211 | class PointSetPooling(object):
212 |     """A class to implement a local graph neural network."""
213 | 
214 |     def __init__(self,
215 |         point_feature_fn=multi_layer_neural_network_fn,
216 |         aggregation_fn=graph_scatter_max_fn,
217 |         output_fn=multi_layer_neural_network_fn):
218 |         self._point_feature_fn = point_feature_fn
219 |         self._aggregation_fn = aggregation_fn
220 |         self._output_fn = output_fn
221 | 
222 |     def apply_regular(self,
223 |         point_features,
224 |         point_coordinates,
225 |         keypoint_indices,
226 |         set_indices,
227 |         point_MLP_depth_list=None,
228 |         point_MLP_normalization_type='fused_BN_center',
229 |         point_MLP_activation_type = 'ReLU',
230 |         output_MLP_depth_list=None,
231 |         output_MLP_normalization_type='fused_BN_center',
232 |         output_MLP_activation_type = 'ReLU'):
233 |         """Apply feature extraction to point sets.
234 | 
235 |         Args:
236 |             point_features: a [N, M] tensor. N is the number of points.
237 |                 M is the length of the features.
238 |             point_coordinates: a [N, D] tensor. N is the number of points.
239 |                 D is the dimension of the coordinates.
240 |             keypoint_indices: a [K, 1] tensor. Indices of K keypoints.
241 |             set_indices: a [S, 2] tensor. S pairs of (point_index, set_index).
242 |                 i.e. (i, j) indicates point[i] belongs to the point set created by
243 |                 grouping around keypoint[j].
244 |             point_MLP_depth_list: a list of MLP units to extract point features.
245 |             point_MLP_normalization_type: the normalization function of MLP.
246 |             point_MLP_activation_type: the activation function of MLP.
247 |             output_MLP_depth_list: a list of MLP units to embed set features.
248 |             output_MLP_normalization_type: the normalization function of MLP.
249 |             output_MLP_activation_type: the activation function of MLP.
250 | 
251 |         returns: a [K, output_depth] tensor as the set feature.
252 |             Output_depth depends on the feature extraction options that
253 |             are selected.
254 |         """
255 |         # Gather the points in a set
256 |         point_set_features = tf.gather(point_features, set_indices[:,0])
257 |         point_set_coordinates = tf.gather(point_coordinates, set_indices[:,0])
258 |         # Gather the keypoints for each set
259 |         point_set_keypoint_indices = tf.gather(
260 |             keypoint_indices, set_indices[:, 1])
261 |         point_set_keypoint_coordinates = tf.gather(point_coordinates,
262 |             point_set_keypoint_indices[:,0])
263 |         # points within a set use relative coordinates to its keypoint
264 |         point_set_coordinates = \
265 |             point_set_coordinates - point_set_keypoint_coordinates
266 |         point_set_features = tf.concat(
267 |             [point_set_features, point_set_coordinates], axis=-1)
268 |         with tf.variable_scope('extract_vertex_features'):
269 |             # Step 1: Extract all vertex_features
270 |             extracted_point_features = self._point_feature_fn(
271 |                 point_set_features,
272 |                 Ks=point_MLP_depth_list, is_logits=False,
273 |                 normalization_type=point_MLP_normalization_type,
274 |                 activation_type=point_MLP_activation_type)
275 |             set_features = self._aggregation_fn(
276 |                 extracted_point_features, set_indices[:, 1],
277 |                 tf.shape(keypoint_indices)[0])
278 |         with tf.variable_scope('combined_features'):
279 |             set_features = self._output_fn(set_features,
280 |                 Ks=output_MLP_depth_list, is_logits=False,
281 |                 normalization_type=output_MLP_normalization_type,
282 |                 activation_type=output_MLP_activation_type)
283 |         return set_features
284 | 
285 | class GraphNetAutoCenter(object):
286 |     """A class to implement a point graph neural network layer."""
287 | 
288 |     def __init__(self,
289 |         edge_feature_fn=multi_layer_neural_network_fn,
290 |         aggregation_fn=graph_scatter_max_fn,
291 |         update_fn=multi_layer_neural_network_fn,
292 |         auto_offset_fn=multi_layer_neural_network_fn):
293 |         self._edge_feature_fn = edge_feature_fn
294 |         self._aggregation_fn = aggregation_fn
295 |         self._update_fn = update_fn
296 |         self._auto_offset_fn = auto_offset_fn
297 | 
298 |     def apply_regular(self,
299 |         input_vertex_features,
300 |         input_vertex_coordinates,
301 |         NOT_USED,
302 |         edges,
303 |         edge_MLP_depth_list=None,
304 |         edge_MLP_normalization_type='fused_BN_center',
305 |         edge_MLP_activation_type = 'ReLU',
306 |         update_MLP_depth_list=None,
307 |         update_MLP_normalization_type='fused_BN_center',
308 |         update_MLP_activation_type = 'ReLU',
309 |         auto_offset=False,
310 |         auto_offset_MLP_depth_list=None,
311 |         auto_offset_MLP_normalization_type='fused_BN_center',
312 |         auto_offset_MLP_feature_activation_type = 'ReLU',
313 |         ):
314 |         """Apply one layer of graph network on a graph.
315 | 
316 |         Args:
317 |             input_vertex_features: a [N, M] tensor. N is the number of vertices.
318 |                 M is the length of the features.
319 | input_vertex_coordinates: a [N, D] tensor. N is the number of 320 | vertices. D is the dimension of the coordinates. 321 | NOT_USED: leave it here for API compatibility. 322 | edges: a [K, 2] tensor. K pairs of (src, dest) vertex indices. 323 | edge_MLP_depth_list: a list of MLP units to extract edge features. 324 | edge_MLP_normalization_type: the normalization function of MLP. 325 | edge_MLP_activation_type: the activation function of MLP. 326 | update_MLP_depth_list: a list of MLP units to extract update 327 | features. 328 | update_MLP_normalization_type: the normalization function of MLP. 329 | update_MLP_activation_type: the activation function of MLP. 330 | auto_offset: boolean, use auto registration or not. 331 | auto_offset_MLP_depth_list: a list of MLP units to compute offset. 332 | auto_offset_MLP_normalization_type: the normalization function. 333 | auto_offset_MLP_feature_activation_type: the activation function. 334 | 335 | returns: a [N, M] tensor. Updated vertex features. 336 | """ 337 | # Gather the source vertex of the edges 338 | s_vertex_features = tf.gather(input_vertex_features, edges[:,0]) 339 | s_vertex_coordinates = tf.gather(input_vertex_coordinates, edges[:,0]) 340 | # [optional] Compute the coordinates offset 341 | if auto_offset: 342 | offset = self._auto_offset_fn(input_vertex_features, 343 | Ks=auto_offset_MLP_depth_list, is_logits=True, 344 | normalization_type=auto_offset_MLP_normalization_type, 345 | activation_type=auto_offset_MLP_feature_activation_type) 346 | input_vertex_coordinates = input_vertex_coordinates + offset 347 | # Gather the destination vertex of the edges 348 | d_vertex_coordinates = tf.gather(input_vertex_coordinates, edges[:, 1]) 349 | # Prepare initial edge features 350 | edge_features = tf.concat( 351 | [s_vertex_features, s_vertex_coordinates - d_vertex_coordinates], 352 | axis=-1) 353 | with tf.variable_scope('extract_vertex_features'): 354 | # Extract edge features 355 | edge_features = self._edge_feature_fn( 356 | edge_features, 357 | Ks=edge_MLP_depth_list, 358 | is_logits=False, 359 | normalization_type=edge_MLP_normalization_type, 360 | activation_type=edge_MLP_activation_type) 361 | # Aggregate edge features 362 | aggregated_edge_features = self._aggregation_fn( 363 | edge_features, 364 | edges[:, 1], 365 | tf.shape(input_vertex_features)[0]) 366 | # Update vertex features 367 | with tf.variable_scope('combined_features'): 368 | update_features = self._update_fn(aggregated_edge_features, 369 | Ks=update_MLP_depth_list, is_logits=True, 370 | normalization_type=update_MLP_normalization_type, 371 | activation_type=update_MLP_activation_type) 372 | output_vertex_features = update_features + input_vertex_features 373 | return output_vertex_features 374 | -------------------------------------------------------------------------------- /models/graph_gen.py: -------------------------------------------------------------------------------- 1 | """The file defines functions to generate graphs.""" 2 | 3 | import time 4 | import random 5 | 6 | import numpy as np 7 | from sklearn.neighbors import NearestNeighbors 8 | import open3d 9 | import tensorflow as tf 10 | 11 | def multi_layer_downsampling(points_xyz, base_voxel_size, levels=[1], 12 | add_rnd3d=False,): 13 | """Downsample the points using base_voxel_size at different scales""" 14 | xmax, ymax, zmax = np.amax(points_xyz, axis=0) 15 | xmin, ymin, zmin = np.amin(points_xyz, axis=0) 16 | xyz_offset = np.asarray([[xmin, ymin, zmin]]) 17 | xyz_zeros = np.asarray([0, 0, 0], 
        dtype=np.float32)
18 |     downsampled_list = [points_xyz]
19 |     last_level = 0
20 |     for level in levels:
21 |         if np.isclose(last_level, level):
22 |             downsampled_list.append(np.copy(downsampled_list[-1]))
23 |         else:
24 |             if add_rnd3d:
25 |                 xyz_idx = (points_xyz-xyz_offset+
26 |                     base_voxel_size*level*np.random.random((1,3)))//\
27 |                     (base_voxel_size*level)
28 |                 xyz_idx = xyz_idx.astype(np.int32)
29 |                 dim_x, dim_y, dim_z = np.amax(xyz_idx, axis=0) + 1
30 |                 keys = xyz_idx[:, 0]+xyz_idx[:, 1]*dim_x+\
31 |                     xyz_idx[:, 2]*dim_y*dim_x
32 |                 sorted_order = np.argsort(keys)
33 |                 sorted_keys = keys[sorted_order]
34 |                 sorted_points_xyz = points_xyz[sorted_order]
35 |                 _, lens = np.unique(sorted_keys, return_counts=True)
36 |                 indices = np.hstack([[0], lens[:-1]]).cumsum()
37 |                 downsampled_xyz = np.add.reduceat(
38 |                     sorted_points_xyz, indices, axis=0)/lens[:,np.newaxis]
39 |                 downsampled_list.append(np.array(downsampled_xyz))
40 |             else:
41 |                 pcd = open3d.PointCloud()
42 |                 pcd.points = open3d.Vector3dVector(points_xyz)
43 |                 downsampled_xyz = np.asarray(open3d.voxel_down_sample(
44 |                     pcd, voxel_size = base_voxel_size*level).points)
45 |                 downsampled_list.append(downsampled_xyz)
46 |         last_level = level
47 |     return downsampled_list
48 | 
49 | def multi_layer_downsampling_select(points_xyz, base_voxel_size, levels=[1],
50 |     add_rnd3d=False):
51 |     """Downsample the points at different scales and match the downsampled
52 |     points to original points by a nearest neighbor search.
53 | 
54 |     Args:
55 |         points_xyz: a [N, D] matrix. N is the total number of the points. D is
56 |             the dimension of the coordinates.
57 |         base_voxel_size: scalar, the cell size of voxel.
58 |         levels: a list of downsampling scales, each given as a multiple
59 |             of base_voxel_size.
60 |         add_rnd3d: boolean, whether to add random offset when downsampling.
61 | 
62 |     returns: vertex_coord_list, keypoint_indices_list
63 |     """
64 |     # Voxel downsampling
65 |     vertex_coord_list = multi_layer_downsampling(
66 |         points_xyz, base_voxel_size, levels=levels, add_rnd3d=add_rnd3d)
67 |     num_levels = len(vertex_coord_list)
68 |     assert num_levels == len(levels) + 1
69 |     # Match downsampled vertices to original by a nearest neighbor search.
70 |     keypoint_indices_list = []
71 |     last_level = 0
72 |     for i in range(1, num_levels):
73 |         current_level = levels[i-1]
74 |         base_points = vertex_coord_list[i-1]
75 |         current_points = vertex_coord_list[i]
76 |         if np.isclose(current_level, last_level):
77 |             # same downsample scale (gnn layer),
78 |             # just copy it, no need to search.
79 |             vertex_coord_list[i] = base_points
80 |             keypoint_indices_list.append(
81 |                 np.expand_dims(np.arange(base_points.shape[0]),axis=1))
82 |         else:
83 |             # different scale (pooling layer), search original points.
84 |             nbrs = NearestNeighbors(n_neighbors=1,
85 |                 algorithm='kd_tree', n_jobs=1).fit(base_points)
86 |             indices = nbrs.kneighbors(current_points, return_distance=False)
87 |             vertex_coord_list[i] = base_points[indices[:, 0], :]
88 |             keypoint_indices_list.append(indices)
89 |         last_level = current_level
90 |     return vertex_coord_list, keypoint_indices_list
91 | 
92 | def multi_layer_downsampling_random(points_xyz, base_voxel_size, levels=[1],
93 |     add_rnd3d=False):
94 |     """Downsample the points at different scales by randomly selecting a point
95 |     within a voxel cell.
96 | 
97 |     Args:
98 |         points_xyz: a [N, D] matrix. N is the total number of the points. D is
99 |             the dimension of the coordinates.
100 |         base_voxel_size: scalar, the cell size of voxel.
101 |         levels: a list of downsampling scales, each given as a multiple
102 |             of base_voxel_size.
103 |         add_rnd3d: boolean, whether to add random offset when downsampling.
104 | 
105 |     returns: vertex_coord_list, keypoint_indices_list
106 |     """
107 |     xmax, ymax, zmax = np.amax(points_xyz, axis=0)
108 |     xmin, ymin, zmin = np.amin(points_xyz, axis=0)
109 |     xyz_offset = np.asarray([[xmin, ymin, zmin]])
110 |     xyz_zeros = np.asarray([0, 0, 0], dtype=np.float32)
111 |     vertex_coord_list = [points_xyz]
112 |     keypoint_indices_list = []
113 |     last_level = 0
114 |     for level in levels:
115 |         last_points_xyz = vertex_coord_list[-1]
116 |         if np.isclose(last_level, level):
117 |             # same downsample scale (gnn layer), just copy it
118 |             vertex_coord_list.append(np.copy(last_points_xyz))
119 |             keypoint_indices_list.append(
120 |                 np.expand_dims(np.arange(len(last_points_xyz)), axis=1))
121 |         else:
122 |             if not add_rnd3d:
123 |                 xyz_idx = (last_points_xyz - xyz_offset) \
124 |                     // (base_voxel_size*level)
125 |             else:
126 |                 xyz_idx = (last_points_xyz - xyz_offset +
127 |                     base_voxel_size*level*np.random.random((1,3))) \
128 |                     // (base_voxel_size*level)
129 |             xyz_idx = xyz_idx.astype(np.int32)
130 |             dim_x, dim_y, dim_z = np.amax(xyz_idx, axis=0) + 1
131 |             keys = xyz_idx[:, 0]+xyz_idx[:, 1]*dim_x+xyz_idx[:, 2]*dim_y*dim_x
132 |             num_points = xyz_idx.shape[0]
133 | 
134 |             voxels_idx = {}
135 |             for pidx in range(len(last_points_xyz)):
136 |                 key = keys[pidx]
137 |                 if key in voxels_idx:
138 |                     voxels_idx[key].append(pidx)
139 |                 else:
140 |                     voxels_idx[key] = [pidx]
141 | 
142 |             downsampled_xyz = []
143 |             downsampled_xyz_idx = []
144 |             for key in voxels_idx:
145 |                 center_idx = random.choice(voxels_idx[key])
146 |                 downsampled_xyz.append(last_points_xyz[center_idx])
147 |                 downsampled_xyz_idx.append(center_idx)
148 |             vertex_coord_list.append(np.array(downsampled_xyz))
149 |             keypoint_indices_list.append(
150 |                 np.expand_dims(np.array(downsampled_xyz_idx),axis=1))
151 |         last_level = level
152 | 
153 |     return vertex_coord_list, keypoint_indices_list
154 | 
155 | def gen_multi_level_local_graph_v3(
156 |     points_xyz, base_voxel_size, level_configs, add_rnd3d=False,
157 |     downsample_method='center'):
158 |     """Generate graphs at multiple scales. This function enforces that the
159 |     output vertices of one graph match the input vertices of the next graph,
160 |     so that GNN layers can be applied sequentially.
161 | 
162 |     Args:
163 |         points_xyz: a [N, D] matrix. N is the total number of the points. D is
164 |             the dimension of the coordinates.
165 |         base_voxel_size: scalar, the cell size of voxel.
166 |         level_configs: a list of dicts; each dict provides 'graph_level',
167 |             'graph_gen_method', 'graph_gen_kwargs' and 'graph_scale'.
168 |         add_rnd3d: boolean, whether to add random offset when downsampling.
169 |         downsample_method: string, the name of downsampling method.
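
        A hypothetical single-level example (the keys match those read in
        the edge-generation loop below; 'radius' and 'num_neighbors' are
        arguments of gen_disjointed_rnn_local_graph_v3, where
        num_neighbors=-1 keeps all radius neighbors):
            level_configs = [{
                'graph_level': 0,
                'graph_scale': 1,
                'graph_gen_method': 'disjointed_rnn_local_graph_v3',
                'graph_gen_kwargs': {'radius': 1.0, 'num_neighbors': -1}}]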
170 |     returns: vertex_coord_list, keypoint_indices_list, edges_list
171 |     """
172 |     if isinstance(base_voxel_size, list):
173 |         base_voxel_size = np.array(base_voxel_size)
174 |     # Gather the downsample scale for each graph
175 |     scales = [config['graph_scale'] for config in level_configs]
176 |     # Generate vertex coordinates
177 |     if downsample_method=='center':
178 |         vertex_coord_list, keypoint_indices_list = \
179 |             multi_layer_downsampling_select(
180 |                 points_xyz, base_voxel_size, scales, add_rnd3d=add_rnd3d)
181 |     if downsample_method=='random':
182 |         vertex_coord_list, keypoint_indices_list = \
183 |             multi_layer_downsampling_random(
184 |                 points_xyz, base_voxel_size, scales, add_rnd3d=add_rnd3d)
185 |     # Create edges
186 |     edges_list = []
187 |     for config in level_configs:
188 |         graph_level = config['graph_level']
189 |         gen_graph_fn = get_graph_generate_fn(config['graph_gen_method'])
190 |         method_kwarg = config['graph_gen_kwargs']
191 |         points_xyz = vertex_coord_list[graph_level]
192 |         center_xyz = vertex_coord_list[graph_level+1]
193 |         vertices = gen_graph_fn(points_xyz, center_xyz, **method_kwarg)
194 |         edges_list.append(vertices)
195 |     return vertex_coord_list, keypoint_indices_list, edges_list
196 | 
197 | def gen_disjointed_rnn_local_graph_v3(
198 |     points_xyz, center_xyz, radius, num_neighbors,
199 |     neighbors_downsample_method='random',
200 |     scale=None):
201 |     """Generate a local graph by radius neighbors.
202 |     """
203 |     if scale is not None:
204 |         scale = np.array(scale)
205 |         points_xyz = points_xyz/scale
206 |         center_xyz = center_xyz/scale
207 |     nbrs = NearestNeighbors(
208 |         radius=radius,algorithm='ball_tree', n_jobs=1, ).fit(points_xyz)
209 |     indices = nbrs.radius_neighbors(center_xyz, return_distance=False)
210 |     if num_neighbors > 0:
211 |         if neighbors_downsample_method == 'random':
212 |             indices = [neighbors if neighbors.size <= num_neighbors else
213 |                 np.random.choice(neighbors, num_neighbors, replace=False)
214 |                 for neighbors in indices]
215 |     vertices_v = np.concatenate(indices)
216 |     vertices_i = np.concatenate(
217 |         [i*np.ones(neighbors.size, dtype=np.int32)
218 |             for i, neighbors in enumerate(indices)])
219 |     vertices = np.array([vertices_v, vertices_i]).transpose()
220 |     return vertices
221 | 
222 | def get_graph_generate_fn(method_name):
223 |     method_map = {
224 |         'disjointed_rnn_local_graph_v3':gen_disjointed_rnn_local_graph_v3,
225 |         'multi_level_local_graph_v3': gen_multi_level_local_graph_v3,
226 |     }
227 |     return method_map[method_name]
228 | --------------------------------------------------------------------------------
/models/loss.py:
--------------------------------------------------------------------------------
1 | """Implements popular losses. """
2 | 
3 | import tensorflow as tf
4 | 
5 | def focal_loss_sigmoid(labels, logits, alpha=0.5, gamma=2):
6 |     """
7 |     github.com/tensorflow/models/blob/master/\
8 |     research/object_detection/core/losses.py
9 |     Compute focal loss for binary classification
10 |     Args:
11 |         labels: A int32 tensor of shape [batch_size, 1].
12 |         logits: A float32 tensor of shape [batch_size, num_classes].
13 |         alpha: A scalar for focal loss alpha hyper-parameter.
14 |             If the number of positive samples > the number of negative
15 |             samples, set alpha < 0.5, and vice versa.
16 |         gamma: A scalar for focal loss gamma hyper-parameter.
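
    For reference, each element below follows the standard focal loss form
        FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t),
    where p_t is the predicted probability of the true label (prob_t in
    the code) and alpha_t is alpha for positives and (1 - alpha) for
    negatives.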
17 |     Returns:
18 |         A [batch_size, num_classes] tensor of per-element focal loss.
19 |     """
20 |     prob = tf.sigmoid(logits)
21 |     labels = tf.one_hot(labels,depth=prob.shape[1])
22 |     labels = tf.squeeze(labels, axis=1)
23 |     cross_ent = tf.nn.sigmoid_cross_entropy_with_logits(
24 |         labels=labels, logits=logits)
25 |     prob_t= (labels*prob) + (1-labels)*(1-prob)
26 |     modulating = tf.pow(1-prob_t, gamma)
27 |     alpha_weight = (labels*alpha) + (1-labels)*(1-alpha)
28 |     focal_cross_entropy = (modulating * alpha_weight * cross_ent)
29 |     return focal_cross_entropy
30 | 
31 | def focal_loss_softmax(labels, logits, gamma=2):
32 |     """
33 |     https://github.com/fudannlp16/focal-loss/blob/master/focal_loss.py
34 |     Compute focal loss for multi-class classification
35 |     Args:
36 |         labels: A int32 tensor of shape [batch_size, 1].
37 |         logits: A float32 tensor of shape [batch_size,num_classes].
38 |         gamma: A scalar for focal loss gamma hyper-parameter.
39 |     Returns:
40 |         A [batch_size, 1] tensor of per-example focal loss.
41 |     """
42 |     y_pred=tf.nn.softmax(logits,axis=-1) # [batch_size,num_classes]
43 |     cross_ent = tf.nn.sparse_softmax_cross_entropy_with_logits(
44 |         labels=tf.squeeze(labels, axis=1), logits=logits)
45 |     cross_ent = tf.expand_dims(cross_ent, 1)
46 |     labels = tf.cast(labels, tf.int32)
47 |     L=((1.0-tf.batch_gather(y_pred, labels))**gamma)*cross_ent
48 |     return L
49 | 
50 | def test_focal_loss():
51 |     labels = tf.constant([[1], [2]], dtype=tf.int32)
52 |     logits = tf.constant([[-100, -100, -100], [-20, -20, -40.0]],
53 |         dtype=tf.float32)
54 |     focal_sigmoid = focal_loss_sigmoid(labels, logits)
55 |     focal_softmax = focal_loss_softmax(labels, logits)
56 |     with tf.Session() as sess:
57 |         print(sess.run(focal_sigmoid))
58 |         print(sess.run(focal_softmax))
59 | if __name__ == '__main__':
60 |     test_focal_loss()
61 | --------------------------------------------------------------------------------
/models/models.py:
--------------------------------------------------------------------------------
1 | """This file implements models for object detection. """
2 | 
3 | from functools import partial
4 | 
5 | import tensorflow as tf
6 | import tensorflow.contrib.slim as slim
7 | 
8 | from models.loss import focal_loss_sigmoid, focal_loss_softmax
9 | from models import gnn
10 | 
11 | regularizer_dict = {
12 |     'l2': slim.l2_regularizer,
13 |     'l1': slim.l1_regularizer,
14 |     'l1_l2': slim.l1_l2_regularizer,
15 | }
16 | keras_regularizer_dict = {
17 |     'l2': tf.keras.regularizers.l2,
18 |     'l1': tf.keras.regularizers.l1,
19 |     'l1_l2': tf.keras.regularizers.l1_l2,
20 | }
21 | 
22 | class MultiLayerFastLocalGraphModelV2(object):
23 |     """General multiple layer GNN model. The graphs are generated outside this
24 |     model and then fed into this model. This model applies a list of layers
25 |     sequentially, while each layer chooses the graph it operates on.
26 |     """
27 | 
28 |     def __init__(self, num_classes, box_encoding_len, regularizer_type=None,
29 |         regularizer_kwargs=None, layer_configs=None, mode=None):
30 |         """
31 |         Args:
32 |             num_classes: int, the number of object classes.
33 |             box_encoding_len: int, the length of encoded bounding box.
34 |             regularizer_type: string, one of 'l2','l1', 'l1_l2'.
35 |             regularizer_kwargs: dict, keyword args to the regularizer.
36 |             layer_configs: a list of layer configurations.
37 |             mode: string, one of 'train', 'eval', 'test'.
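
        A hypothetical layer_configs entry (the keys are the ones consumed
        in predict() below, with 'device' optional; the kwargs mirror the
        parameters of PointSetPooling.apply_regular, and the depth values
        are chosen only for illustration):
            {'scope': 'layer1', 'type': 'scatter_max_point_set_pooling',
             'graph_level': 0,
             'kwargs': {'point_MLP_depth_list': [32, 64],
                        'output_MLP_depth_list': [64, 128]}}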
38 |         """
39 |         self.num_classes = num_classes
40 |         self.box_encoding_len = box_encoding_len
41 |         if regularizer_type is None:
42 |             assert regularizer_kwargs is None, 'No regularizer no kwargs'
43 |             self._regularizer = None
44 |         else:
45 |             self._regularizer = regularizer_dict[regularizer_type](
46 |                 **regularizer_kwargs)
47 |         self._layer_configs = layer_configs
48 |         self._default_layers_type = {
49 |             'scatter_max_point_set_pooling': gnn.PointSetPooling(
50 |                 point_feature_fn=gnn.multi_layer_neural_network_fn,
51 |                 aggregation_fn=gnn.graph_scatter_max_fn,
52 |                 output_fn=gnn.multi_layer_neural_network_fn
53 |             ),
54 |             'scatter_max_graph_auto_center_net': gnn.GraphNetAutoCenter(
55 |                 edge_feature_fn=gnn.multi_layer_neural_network_fn,
56 |                 aggregation_fn=gnn.graph_scatter_max_fn,
57 |                 update_fn=gnn.multi_layer_neural_network_fn,
58 |                 auto_offset_fn=gnn.multi_layer_neural_network_fn,
59 |             ),
60 |             'classaware_predictor': gnn.ClassAwarePredictor(
61 |                 cls_fn=partial(gnn.multi_layer_fc_fn, Ks=(64,), num_layer=2),
62 |                 loc_fn=partial(gnn.multi_layer_fc_fn,
63 |                     Ks=(64, 64,), num_layer=3)
64 |             ),
65 |             'classaware_predictor_128': gnn.ClassAwarePredictor(
66 |                 cls_fn=partial(gnn.multi_layer_fc_fn, Ks=(128,), num_layer=2),
67 |                 loc_fn=partial(gnn.multi_layer_fc_fn,
68 |                     Ks=(128, 128), num_layer=3)
69 |             ),
70 |             'classaware_separated_predictor': gnn.ClassAwareSeparatedPredictor(
71 |                 cls_fn=partial(gnn.multi_layer_fc_fn, Ks=(64,), num_layer=2),
72 |                 loc_fn=partial(gnn.multi_layer_fc_fn,
73 |                     Ks=(64, 64,), num_layer=3)
74 |             ),
75 |         }
76 |         assert mode in ['train', 'eval', 'test'], 'Unsupported mode'
77 |         self._mode = mode
78 | 
79 |     def predict(self,
80 |         t_initial_vertex_features,
81 |         t_vertex_coord_list,
82 |         t_keypoint_indices_list,
83 |         t_edges_list,
84 |         is_training,
85 |         ):
86 |         """
87 |         Predict the objects with initial vertex features and a list of graphs.
88 |         The model applies layers sequentially, while each layer chooses the
89 |         graph it operates on. For example, a layer can choose the i-th graph,
90 |         which is composed of t_vertex_coord_list[i], t_edges_list[i], and
91 |         optionally t_keypoint_indices_list[i]. It operates on the graph and
92 |         outputs the updated vertex_features. Then the next layer takes the
93 |         vertex_features and also chooses a graph to further update the features.
94 | 
95 |         Args:
96 |             t_initial_vertex_features: a [N, M] tensor, the initial features of
97 |                 N vertices. For example, the intensity value of lidar reflection.
98 |             t_vertex_coord_list: a list of [Ni, 3] tensors, the coordinates of
99 |                 a list of graph vertices.
100 |             t_keypoint_indices_list: a list of [Nj, 1] tensors or None. For a
101 |                 pooling layer, it outputs a reduced number of vertices, a.k.a. the
102 |                 keypoints. t_keypoint_indices_list[i] is the indices of those
103 |                 keypoints. For a gnn layer, it does not reduce the vertex number,
104 |                 thus t_keypoint_indices_list[i] should be set to 'None'.
105 |             t_edges_list: a list of [Ki, 2] tensors. t_edges_list[i] are edges
106 |                 for the i-th graph. It contains Ki pairs of (source, destination)
107 |                 vertex indices.
108 |             is_training: boolean, whether in training mode or not.
109 |         returns: [N_output, num_classes] logits tensor for classification,
110 |             [N_output, num_classes, box_encoding_len] box_encodings tensor for
111 |             localization.
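
        A hypothetical walk-through: with t_initial_vertex_features of
        shape [N, 1] (e.g. lidar intensity) and a pooling layer configured
        with graph_level 0, that layer reads t_vertex_coord_list[0],
        t_keypoint_indices_list[0] and t_edges_list[0], and emits features
        for the keypoints, which the next layer consumes in the same way.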
112 |         """
113 |         with slim.arg_scope([slim.batch_norm], is_training=is_training), \
114 |             slim.arg_scope([slim.fully_connected],
115 |                 weights_regularizer=self._regularizer):
116 |             tfeatures_list = []
117 |             tfeatures = t_initial_vertex_features
118 |             tfeatures_list.append(tfeatures)
119 |             for idx in range(len(self._layer_configs)-1):
120 |                 layer_config = self._layer_configs[idx]
121 |                 layer_scope = layer_config['scope']
122 |                 layer_type = layer_config['type']
123 |                 layer_kwargs = layer_config['kwargs']
124 |                 graph_level = layer_config['graph_level']
125 |                 t_vertex_coordinates = t_vertex_coord_list[graph_level]
126 |                 t_keypoint_indices = t_keypoint_indices_list[graph_level]
127 |                 t_edges = t_edges_list[graph_level]
128 |                 with tf.variable_scope(layer_scope, reuse=tf.AUTO_REUSE):
129 |                     flgn = self._default_layers_type[layer_type]
130 |                     print('@ level %d Graph, Add layer: %s, type: %s'%
131 |                         (graph_level, layer_scope, layer_type))
132 |                     if 'device' in layer_config:
133 |                         with tf.device(layer_config['device']):
134 |                             tfeatures = flgn.apply_regular(
135 |                                 tfeatures,
136 |                                 t_vertex_coordinates,
137 |                                 t_keypoint_indices,
138 |                                 t_edges,
139 |                                 **layer_kwargs)
140 |                     else:
141 |                         tfeatures = flgn.apply_regular(
142 |                             tfeatures,
143 |                             t_vertex_coordinates,
144 |                             t_keypoint_indices,
145 |                             t_edges,
146 |                             **layer_kwargs)
147 | 
148 |                     tfeatures_list.append(tfeatures)
149 |                     print('Feature Dim:' + str(tfeatures.shape[-1]))
150 |             predictor_config = self._layer_configs[-1]
151 |             assert (predictor_config['type']=='classaware_predictor' or
152 |                 predictor_config['type']=='classaware_predictor_128' or
153 |                 predictor_config['type']=='classaware_separated_predictor')
154 |             predictor = self._default_layers_type[predictor_config['type']]
155 |             print('Final Feature Dim:'+str(tfeatures.shape[-1]))
156 |             with tf.variable_scope(predictor_config['scope'],
157 |                 reuse=tf.AUTO_REUSE):
158 |                 logits, box_encodings = predictor.apply_regular(tfeatures,
159 |                     num_classes=self.num_classes,
160 |                     box_encoding_len=self.box_encoding_len,
161 |                     **predictor_config['kwargs'])
162 |                 print("Prediction %d classes" % self.num_classes)
163 |                 return logits, box_encodings
164 | 
165 |     def postprocess(self, logits):
166 |         """Output predictions. """
167 |         prob = tf.nn.softmax(logits, axis=-1)
168 |         return prob
169 | 
170 |     def loss(self, logits, labels, pred_box, gt_box, valid_box,
171 |         cls_loss_type='focal_sigmoid', cls_loss_kwargs={},
172 |         loc_loss_type='huber_loss', loc_loss_kwargs={},
173 |         loc_loss_weight=1.0,
174 |         cls_loss_weight=1.0):
175 |         """Output loss value.
176 | 
177 |         Args:
178 |             logits: [N, num_classes] tensor. The classification logits from
179 |                 predict method.
180 |             labels: [N, 1] tensor. The integer class labels.
181 |             pred_box: [N, num_classes, box_encoding_len] tensor. The encoded
182 |                 bounding boxes from the predict method.
183 |             gt_box: [N, 1, box_encoding_len] tensor. The ground truth encoded
184 |                 bounding boxes.
185 |             valid_box: [N] tensor. An indicator of whether the vertex is from
186 |                 an object of interest (whether it has a valid bounding box).
187 |             cls_loss_type: string, the type of classification loss function.
188 |             cls_loss_kwargs: dict, keyword args to the classification loss.
189 |             loc_loss_type: string, the type of localization loss function.
190 |             loc_loss_kwargs: dict, keyword args to the localization loss.
191 |             loc_loss_weight: scalar, weight on localization loss.
192 |             cls_loss_weight: scalar, weight on the classification loss.
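
        Any of the *_type and *_weight arguments may also be given as a
        dict keyed by mode ('train'/'eval'/'test'), resolved at the top of
        this method; when a *_type is a dict, the matching *_kwargs is
        expected to be keyed by mode as well. A hypothetical example:
            loc_loss_weight = {'train': 1.0, 'eval': 1.0}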
193 | returns: a dict of cls_loss, loc_loss, reg_loss, num_endpoint, 194 | num_valid_endpoint. num_endpoint is the number of output vertices. 195 | num_valid_endpoint is the number of output vertices that have a valid 196 | bounding box. Those numbers are useful for weighting during batching. 197 | """ 198 | if isinstance(loc_loss_weight, dict): 199 | loc_loss_weight = loc_loss_weight[self._mode] 200 | if isinstance(cls_loss_weight, dict): 201 | cls_loss_weight = cls_loss_weight[self._mode] 202 | if isinstance(cls_loss_type, dict): 203 | cls_loss_type = cls_loss_type[self._mode] 204 | cls_loss_kwargs = cls_loss_kwargs[self._mode] 205 | if isinstance(loc_loss_type, dict): 206 | loc_loss_type = loc_loss_type[self._mode] 207 | loc_loss_kwargs = loc_loss_kwargs[self._mode] 208 | 209 | loss_dict = {} 210 | assert cls_loss_type in ['softmax', 'top_k_softmax', 211 | 'focal_sigmoid', 'focal_softmax'] 212 | if cls_loss_type == 'softmax': 213 | point_loss =tf.nn.sparse_softmax_cross_entropy_with_logits( 214 | labels=tf.squeeze(labels, axis=1), logits=logits) 215 | num_endpoint = tf.shape(point_loss)[0] 216 | if cls_loss_type == 'focal_sigmoid': 217 | point_loss = focal_loss_sigmoid(labels, logits, **cls_loss_kwargs) 218 | num_endpoint = tf.shape(point_loss)[0] 219 | if cls_loss_type == 'focal_softmax': 220 | point_loss = focal_loss_softmax(labels, logits, **cls_loss_kwargs) 221 | num_endpoint = tf.shape(point_loss)[0] 222 | if cls_loss_type == 'top_k_softmax': 223 | point_loss =tf.nn.sparse_softmax_cross_entropy_with_logits( 224 | labels=tf.squeeze(labels, axis=1), logits=logits) 225 | num_endpoint = tf.shape(point_loss)[0] 226 | k = cls_loss_kwargs['k'] 227 | top_k_cls_loss, _ = tf.math.top_k(point_loss, k=k, sorted=True) 228 | point_loss = top_k_cls_loss 229 | cls_loss = cls_loss_weight* tf.reduce_mean(point_loss) 230 | batch_idx = tf.range(tf.shape(pred_box)[0]) 231 | batch_idx = tf.expand_dims(batch_idx, axis=1) 232 | batch_idx = tf.concat([batch_idx, labels], axis=1) 233 | pred_box = tf.gather_nd(pred_box, batch_idx) 234 | pred_box = tf.expand_dims(pred_box, axis=1) 235 | #pred_box = tf.batch_gather(pred_box, labels) 236 | if loc_loss_type == 'huber_loss': 237 | all_loc_loss = loc_loss_weight*tf.losses.huber_loss( 238 | gt_box, 239 | pred_box, 240 | delta=1.0, 241 | weights=valid_box, 242 | reduction=tf.losses.Reduction.NONE) 243 | print(all_loc_loss.shape) 244 | all_loc_loss = tf.squeeze(all_loc_loss, axis=1) 245 | if 'classwise_loc_loss_weight' in loc_loss_kwargs and\ 246 | self._mode == 'train': 247 | classwise_loc_loss_weight = loc_loss_kwargs[ 248 | 'classwise_loc_loss_weight'] 249 | classwise_loc_loss_weight = tf.gather( 250 | classwise_loc_loss_weight, labels) 251 | all_loc_loss = all_loc_loss*classwise_loc_loss_weight 252 | num_valid_endpoint = tf.reduce_sum(valid_box) 253 | mean_loc_loss = tf.reduce_mean(all_loc_loss, axis=1) 254 | loc_loss = tf.div_no_nan(tf.reduce_sum(mean_loc_loss), 255 | num_valid_endpoint) 256 | classwise_loc_loss = [] 257 | for class_idx in range(self.num_classes): 258 | class_mask = tf.where(tf.equal(tf.squeeze(labels, axis=1), 259 | tf.constant(class_idx, tf.int32))) 260 | l = tf.reduce_sum(tf.gather(all_loc_loss, class_mask), axis=0) 261 | l = tf.squeeze(l, axis=0) 262 | is_nan_mask = tf.is_nan(l) 263 | l = tf.where(is_nan_mask, tf.zeros_like(l),l) 264 | classwise_loc_loss.append(l) 265 | loss_dict['classwise_loc_loss'] = classwise_loc_loss 266 | if loc_loss_type == 'top_k_huber_loss': 267 | k = loc_loss_kwargs['k'] 268 | all_loc_loss = 
loc_loss_weight*tf.losses.huber_loss( 269 | gt_box, 270 | pred_box, 271 | delta=1.0, 272 | weights=valid_box, 273 | reduction=tf.losses.Reduction.NONE) 274 | all_loc_loss = tf.squeeze(all_loc_loss, axis=1) 275 | if 'classwise_loc_loss_weight' in loc_loss_kwargs \ 276 | and self._mode == 'train': 277 | classwise_loc_loss_weight = loc_loss_kwargs[ 278 | 'classwise_loc_loss_weight'] 279 | classwise_loc_loss_weight = tf.gather( 280 | classwise_loc_loss_weight, labels) 281 | all_loc_loss = all_loc_loss*classwise_loc_loss_weight 282 | loc_loss = tf.reduce_mean(all_loc_loss, axis=1) 283 | top_k_loc_loss, top_k_indices = tf.math.top_k(loc_loss, 284 | k=k, sorted=True) 285 | valid_box = tf.squeeze(valid_box, axis=2) 286 | valid_box = tf.squeeze(valid_box, axis=1) 287 | top_k_valid_box = tf.gather(valid_box, top_k_indices) 288 | num_valid_endpoint = tf.reduce_sum(top_k_valid_box) 289 | loc_loss = tf.div_no_nan(tf.reduce_sum(top_k_loc_loss), 290 | num_valid_endpoint) 291 | top_k_labels = tf.gather(labels, top_k_indices) 292 | all_top_k_loc_loss = tf.gather(all_loc_loss, top_k_indices) 293 | classwise_loc_loss = [] 294 | for class_idx in range(self.num_classes): 295 | class_mask = tf.where(tf.equal(tf.squeeze(top_k_labels), 296 | tf.constant(class_idx, tf.int32))) 297 | l = tf.reduce_sum(tf.gather(all_top_k_loc_loss, class_mask), 298 | axis=0) 299 | l = tf.squeeze(l, axis=0) 300 | is_nan_mask = tf.is_nan(l) 301 | l = tf.where(is_nan_mask, tf.zeros_like(l),l) 302 | classwise_loc_loss.append(l) 303 | loss_dict['classwise_loc_loss'] = classwise_loc_loss 304 | 305 | with tf.control_dependencies([tf.assert_equal(tf.is_nan(loc_loss), 306 | False)]): 307 | reg_loss = tf.reduce_sum(tf.losses.get_regularization_losses()) 308 | loss_dict.update({'cls_loss': cls_loss, 'loc_loss': loc_loss, 309 | 'reg_loss': reg_loss, 'num_endpoint': num_endpoint, 310 | 'num_valid_endpoint':num_valid_endpoint}) 311 | return loss_dict 312 | 313 | def get_model(model_name): 314 | """Fetch a model class.""" 315 | model_map = { 316 | 'multi_layer_fast_local_graph_model_v2': 317 | MultiLayerFastLocalGraphModelV2, 318 | } 319 | return model_map[model_name] 320 | -------------------------------------------------------------------------------- /models/nms.py: -------------------------------------------------------------------------------- 1 | """This file defines nms functions to merge boxes""" 2 | 3 | import time 4 | 5 | import cv2 6 | import numpy as np 7 | from shapely.geometry import Polygon 8 | 9 | def boxes_3d_to_corners(boxes_3d): 10 | all_corners = [] 11 | for box_3d in boxes_3d: 12 | x3d, y3d, z3d, l, h, w, yaw = box_3d 13 | R = np.array([[np.cos(yaw), 0, np.sin(yaw)], 14 | [0, 1, 0 ], 15 | [-np.sin(yaw), 0, np.cos(yaw)]]); 16 | corners = np.array([[ l/2, 0.0, w/2], # front up right 17 | [ l/2, 0.0, -w/2], # front up left 18 | [-l/2, 0.0, -w/2], # back up left 19 | [-l/2, 0.0, w/2], # back up right 20 | [ l/2, -h, w/2], # front down right 21 | [ l/2, -h, -w/2], # front down left 22 | [-l/2, -h, -w/2], # back down left 23 | [-l/2, -h, w/2]]) # back down right 24 | r_corners = corners.dot(np.transpose(R)) 25 | cam_points_xyz = r_corners+np.array([x3d, y3d, z3d]) 26 | all_corners.append(cam_points_xyz) 27 | return np.array(all_corners) 28 | 29 | def overlapped_boxes_3d(single_box, box_list): 30 | x0_max, y0_max, z0_max = np.max(single_box, axis=0) 31 | x0_min, y0_min, z0_min = np.min(single_box, axis=0) 32 | overlap = np.zeros(len(box_list)) 33 | for i, box in enumerate(box_list): 34 | x_max, y_max, z_max = np.max(box, axis=0) 35 | 
        x_min, y_min, z_min = np.min(box, axis=0)
36 |         if x0_max < x_min or x0_min > x_max:
37 |             overlap[i] = 0
38 |             continue
39 |         if y0_max < y_min or y0_min > y_max:
40 |             overlap[i] = 0
41 |             continue
42 |         if z0_max < z_min or z0_min > z_max:
43 |             overlap[i] = 0
44 |             continue
45 |         x_draw_min = min(x0_min, x_min)
46 |         x_draw_max = max(x0_max, x_max)
47 |         z_draw_min = min(z0_min, z_min)
48 |         z_draw_max = max(z0_max, z_max)
49 |         offset = np.array([x_draw_min, z_draw_min])
50 |         buf1 = np.zeros((z_draw_max-z_draw_min, x_draw_max-x_draw_min),
51 |             dtype=np.int32)
52 |         buf2 = np.zeros_like(buf1)
53 |         cv2.fillPoly(buf1, [single_box[:4, [0,2]]-offset], color=1)
54 |         cv2.fillPoly(buf2, [box[:4, [0,2]]-offset], color=1)
55 |         shared_area = cv2.countNonZero(buf1*buf2)
56 |         area1 = cv2.countNonZero(buf1)
57 |         area2 = cv2.countNonZero(buf2)
58 |         shared_y = min(y_max, y0_max) - max(y_min, y0_min)
59 |         intersection = shared_y * shared_area
60 |         union = (y_max-y_min) * area2 + (y0_max-y0_min) * area1
61 |         overlap[i] = np.float32(intersection) / (union - intersection)
62 |     return overlap
63 | 
64 | def overlapped_boxes_3d_fast_poly(single_box, box_list):
65 |     single_box_max_corner = np.max(single_box, axis=0)
66 |     single_box_min_corner = np.min(single_box, axis=0)
67 |     x0_max, y0_max, z0_max = single_box_max_corner
68 |     x0_min, y0_min, z0_min = single_box_min_corner
69 |     max_corner = np.max(box_list, axis=1)
70 |     min_corner = np.min(box_list, axis=1)
71 |     overlap = np.zeros(len(box_list))
72 |     non_overlap_mask = np.logical_or(single_box_max_corner < min_corner,
73 |         single_box_min_corner > max_corner)
74 |     non_overlap_mask = np.any(non_overlap_mask, axis=1)
75 |     p1 = Polygon(single_box[:4, [0,2]])
76 |     area1 = p1.area
77 |     for i in range(len(box_list)):
78 |         if not non_overlap_mask[i]:
79 |             x_max, y_max, z_max = max_corner[i]
80 |             x_min, y_min, z_min = min_corner[i]
81 |             p2 = Polygon(box_list[i][:4, [0,2]])
82 |             shared_area = p1.intersection(p2).area
83 |             area2 = p2.area
84 |             shared_y = min(y_max, y0_max) - max(y_min, y0_min)
85 |             intersection = shared_y * shared_area
86 |             union = (y_max-y_min) * area2 + (y0_max-y0_min) * area1
87 |             overlap[i] = np.float32(intersection) / (union - intersection)
88 |     return overlap
89 | 
90 | def bboxes_sort(classes, scores, bboxes, top_k=400, attributes=None):
91 |     """Sort bounding boxes in decreasing score order and keep only the top_k.
92 |     """
93 |     idxes = np.argsort(-scores)
94 |     classes = classes[idxes]
95 |     scores = scores[idxes]
96 |     bboxes = bboxes[idxes]
97 |     if attributes is not None:
98 |         attributes = attributes[idxes]
99 |     if top_k > 0:
100 |         if len(idxes) > top_k:
101 |             classes = classes[:top_k]
102 |             scores = scores[:top_k]
103 |             bboxes = bboxes[:top_k]
104 |             if attributes is not None:
105 |                 attributes = attributes[:top_k]
106 |     return classes, scores, bboxes, attributes
107 | 
108 | def bboxes_nms(classes, scores, bboxes, nms_threshold=0.45,
109 |     overlapped_fn=overlapped_boxes_3d, appr_factor=10.0, attributes=None):
110 |     """Apply non-maximum suppression to bounding boxes.
111 |     """
112 |     boxes_corners = boxes_3d_to_corners(bboxes)
113 |     # convert to pixels
114 |     boxes_corners = np.int32(boxes_corners*appr_factor)
115 |     keep_bboxes = np.ones(scores.shape, dtype=np.bool)
116 |     for i in range(scores.size-1):
117 |         if keep_bboxes[i]:
118 |             # Compute overlap with the bboxes that follow.
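            # Note: with the default overlapped_fn (overlapped_boxes_3d
            # above), the overlap is a 3D IoU computed by rasterizing the
            # boxes' bird's-eye-view footprints with cv2.fillPoly and
            # scaling the shared footprint area by the overlap along y.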
119 |             overlap = overlapped_fn(boxes_corners[i], boxes_corners[(i+1):])
120 |             # Overlap threshold for keeping + checking part of the same class
121 |             keep_overlap = np.logical_or(
122 |                 overlap <= nms_threshold, classes[(i+1):] != classes[i])
123 |             keep_bboxes[(i+1):] = np.logical_and(
124 |                 keep_bboxes[(i+1):], keep_overlap)
125 |     idxes = np.where(keep_bboxes)
126 |     classes = classes[idxes]
127 |     scores = scores[idxes]
128 |     bboxes = bboxes[idxes]
129 |     if attributes is not None:
130 |         attributes = attributes[idxes]
131 |     return classes, scores, bboxes, attributes
132 | 
133 | def bboxes_nms_uncertainty(classes, scores, bboxes, scores_threshold=0.25,
134 |     nms_threshold=0.45, overlapped_fn=overlapped_boxes_3d, appr_factor=10.0,
135 |     attributes=None):
136 |     """Apply non-maximum suppression to bounding boxes.
137 |     """
138 |     boxes_corners = boxes_3d_to_corners(bboxes)
139 |     # boxes_corners = bboxes
140 |     # convert to pixels
141 |     # boxes_corners = np.int32(boxes_corners*appr_factor)
142 |     keep_bboxes = np.ones(scores.shape, dtype=np.bool)
143 |     for i in range(scores.size-1):
144 |         if keep_bboxes[i]:
145 |             # Only compute on the rest of bboxes
146 |             valid = keep_bboxes[(i+1):]
147 |             # Compute overlap with the bboxes that follow.
148 |             overlap = overlapped_fn(
149 |                 boxes_corners[i], boxes_corners[(i+1):][valid])
150 |             # Overlap threshold for keeping + checking part of the same class
151 |             remove_overlap = np.logical_and(
152 |                 overlap > nms_threshold, classes[(i+1):][valid] == classes[i])
153 |             overlaped_bboxes = np.concatenate(
154 |                 [bboxes[(i+1):][valid][remove_overlap], bboxes[[i]]], axis=0)
155 |             boxes_mean = np.median(overlaped_bboxes, axis=0)
156 |             bboxes[i][:] = boxes_mean[:]
157 |             boxes_corners_mean = boxes_3d_to_corners(
158 |                 np.expand_dims(boxes_mean, axis=0))
159 |             boxes_mean_overlap = overlapped_fn(boxes_corners_mean[0],
160 |                 boxes_corners[(i+1):][valid][remove_overlap])
161 |             scores[i] += np.sum(
162 |                 scores[(i+1):][valid][remove_overlap]*boxes_mean_overlap)
163 |             keep_bboxes[(i+1):][valid] = np.logical_not(remove_overlap)
164 |     idxes = np.where(keep_bboxes)
165 |     classes = classes[idxes]
166 |     scores = scores[idxes]
167 |     bboxes = bboxes[idxes]
168 |     if attributes is not None:
169 |         attributes = attributes[idxes]
170 |     return classes, scores, bboxes, attributes
171 | 
172 | def bboxes_nms_merge_only(classes, scores, bboxes, scores_threshold=0.25,
173 |     nms_threshold=0.45, overlapped_fn=overlapped_boxes_3d, appr_factor=10.0,
174 |     attributes=None):
175 |     """Apply non-maximum suppression to bounding boxes.
176 |     """
177 |     boxes_corners = boxes_3d_to_corners(bboxes)
178 |     # convert to pixels
179 |     keep_bboxes = np.ones(scores.shape, dtype=np.bool)
180 |     for i in range(scores.size-1):
181 |         if keep_bboxes[i]:
182 |             # Only compute on the rest of bboxes
183 |             valid = keep_bboxes[(i+1):]
184 |             # Compute overlap with the bboxes that follow.
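            # Unlike plain NMS, suppressed same-class boxes are not just
            # discarded here: the kept box is replaced by the elementwise
            # median of itself and all the boxes it suppresses
            # (boxes_mean below).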
185 |             overlap = overlapped_fn(boxes_corners[i],
186 |                 boxes_corners[(i+1):][valid])
187 |             # Overlap threshold for keeping + checking part of the same class
188 |             remove_overlap = np.logical_and(overlap > nms_threshold,
189 |                 classes[(i+1):][valid] == classes[i])
190 |             overlaped_bboxes = np.concatenate(
191 |                 [bboxes[(i+1):][valid][remove_overlap], bboxes[[i]]], axis=0)
192 |             boxes_mean = np.median(overlaped_bboxes, axis=0)
193 |             bboxes[i][:] = boxes_mean[:]
194 |             keep_bboxes[(i+1):][valid] = np.logical_not(remove_overlap)
195 | 
196 |     idxes = np.where(keep_bboxes)
197 |     classes = classes[idxes]
198 |     scores = scores[idxes]
199 |     bboxes = bboxes[idxes]
200 |     if attributes is not None:
201 |         attributes = attributes[idxes]
202 |     return classes, scores, bboxes, attributes
203 | 
204 | def bboxes_nms_score_only(classes, scores, bboxes, scores_threshold=0.25,
205 |     nms_threshold=0.45, overlapped_fn=overlapped_boxes_3d, appr_factor=10.0,
206 |     attributes=None):
207 |     """Apply non-maximum suppression to bounding boxes.
208 |     """
209 |     boxes_corners = boxes_3d_to_corners(bboxes)
210 |     # convert to pixels
211 |     keep_bboxes = np.ones(scores.shape, dtype=np.bool)
212 |     for i in range(scores.size-1):
213 |         if keep_bboxes[i]:
214 |             # Only compute on the rest of bboxes
215 |             valid = keep_bboxes[(i+1):]
216 |             # Compute overlap with the bboxes that follow.
217 |             overlap = overlapped_fn(boxes_corners[i],
218 |                 boxes_corners[(i+1):][valid])
219 |             # Overlap threshold for keeping + checking part of the same class
220 |             remove_overlap = np.logical_and(overlap > nms_threshold,
221 |                 classes[(i+1):][valid] == classes[i])
222 |             overlaped_bboxes = np.concatenate(
223 |                 [bboxes[(i+1):][valid][remove_overlap], bboxes[[i]]], axis=0)
224 |             boxes_mean = bboxes[i][:]
225 |             bboxes[i][:] = boxes_mean[:]
226 |             boxes_corners_mean = boxes_3d_to_corners(
227 |                 np.expand_dims(boxes_mean, axis=0))
228 |             boxes_mean_overlap = overlapped_fn(boxes_corners_mean[0],
229 |                 boxes_corners[(i+1):][valid][remove_overlap])
230 |             scores[i] += np.sum(
231 |                 scores[(i+1):][valid][remove_overlap]*boxes_mean_overlap)
232 |             keep_bboxes[(i+1):][valid] = np.logical_not(remove_overlap)
233 |     idxes = np.where(keep_bboxes)
234 |     classes = classes[idxes]
235 |     scores = scores[idxes]
236 |     bboxes = bboxes[idxes]
237 |     if attributes is not None:
238 |         attributes = attributes[idxes]
239 |     return classes, scores, bboxes, attributes
240 | 
241 | def nms_boxes_3d(class_labels, detection_boxes_3d, detection_scores,
242 |     overlapped_thres=0.5, overlapped_fn=overlapped_boxes_3d, appr_factor=10.0,
243 |     top_k=-1, attributes=None):
244 |     class_labels, detection_scores, detection_boxes_3d, attributes = \
245 |         bboxes_sort(
246 |             class_labels, detection_scores, detection_boxes_3d, top_k=top_k,
247 |             attributes=attributes)
248 |     # nms
249 |     class_labels, detection_scores, detection_boxes_3d, attributes = \
250 |         bboxes_nms(
251 |             class_labels, detection_scores, detection_boxes_3d,
252 |             nms_threshold=overlapped_thres, overlapped_fn=overlapped_fn,
253 |             appr_factor=appr_factor, attributes=attributes)
254 |     return class_labels, detection_boxes_3d, detection_scores, attributes
255 | 
256 | def nms_boxes_3d_uncertainty(class_labels, detection_boxes_3d, detection_scores,
257 |     overlapped_thres=0.5, overlapped_fn=overlapped_boxes_3d, appr_factor=10.0,
258 |     top_k=-1, attributes=None):
259 | 
260 |     class_labels, detection_scores, detection_boxes_3d, attributes = \
261 |         bboxes_sort(
262 |             class_labels, detection_scores, detection_boxes_3d, top_k=top_k,
263 |
            attributes=attributes)
264 |     # nms
265 |     class_labels, detection_scores, detection_boxes_3d, attributes = \
266 |         bboxes_nms_uncertainty(
267 |             class_labels, detection_scores, detection_boxes_3d,
268 |             nms_threshold=overlapped_thres, overlapped_fn=overlapped_fn,
269 |             appr_factor=appr_factor, attributes=attributes)
270 |     return class_labels, detection_boxes_3d, detection_scores, attributes
271 | 
272 | def nms_boxes_3d_merge_only(class_labels, detection_boxes_3d, detection_scores,
273 |     overlapped_thres=0.5, overlapped_fn=overlapped_boxes_3d, appr_factor=10.0,
274 |     top_k=-1, attributes=None):
275 |     class_labels, detection_scores, detection_boxes_3d, attributes = \
276 |         bboxes_sort(
277 |             class_labels, detection_scores, detection_boxes_3d, top_k=top_k,
278 |             attributes=attributes)
279 |     # nms
280 |     class_labels, detection_scores, detection_boxes_3d, attributes = \
281 |         bboxes_nms_merge_only(
282 |             class_labels, detection_scores, detection_boxes_3d,
283 |             nms_threshold=overlapped_thres, overlapped_fn=overlapped_fn,
284 |             appr_factor=appr_factor, attributes=attributes)
285 |     return class_labels, detection_boxes_3d, detection_scores, attributes
286 | 
287 | def nms_boxes_3d_score_only(class_labels, detection_boxes_3d, detection_scores,
288 |     overlapped_thres=0.5, overlapped_fn=overlapped_boxes_3d, appr_factor=10.0,
289 |     top_k=-1, attributes=None):
290 |     class_labels, detection_scores, detection_boxes_3d, attributes = \
291 |         bboxes_sort(
292 |             class_labels, detection_scores, detection_boxes_3d, top_k=top_k,
293 |             attributes=attributes)
294 |     # nms
295 |     class_labels, detection_scores, detection_boxes_3d, attributes = \
296 |         bboxes_nms_score_only(
297 |             class_labels, detection_scores, detection_boxes_3d,
298 |             nms_threshold=overlapped_thres, overlapped_fn=overlapped_fn,
299 |             appr_factor=appr_factor, attributes=attributes)
300 |     return class_labels, detection_boxes_3d, detection_scores, attributes
301 | --------------------------------------------------------------------------------
/mytrain.py:
--------------------------------------------------------------------------------
1 | import torch
2 | from util.config_util import save_config, save_train_config, \
3 |     load_train_config, load_config
4 | from models.box_encoding import get_box_decoding_fn, get_box_encoding_fn, get_encoding_len
5 | import os
6 | from dataset.kitti_dataset import KittiDataset
7 | from kitty_dataset import DataProvider
8 | from model import *
9 | import numpy as np
10 | import argparse
11 | from util.metrics import recall_precisions, mAP
12 | from tqdm import trange
13 | from tqdm import tqdm
14 | 
15 | 
16 | 
17 | if __name__ == "__main__":
18 |     parser = argparse.ArgumentParser(description='Training of PointGNN')
19 |     parser.add_argument('train_config_path', type=str,
20 |         help='Path to train_config')
21 |     parser.add_argument('config_path', type=str,
22 |         help='Path to config')
23 |     parser.add_argument('--device', type=str, default='cuda:0',
24 |         help="Device for training, cuda or cpu")
25 |     parser.add_argument('--batch_size', type=int, default=1,
26 |         help='Batch size')
27 |     parser.add_argument('--epoches', type=int, default=100,
28 |         help='Training epochs')
29 |     parser.add_argument('--dataset_root_dir', type=str, default='../dataset/kitti/',
30 |         help='Path to KITTI dataset. Default="../dataset/kitti/"')
31 |     parser.add_argument('--dataset_split_file', type=str,
32 |         default='',
33 |         help='Path to KITTI dataset split file.'
34 | 'Default="DATASET_ROOT_DIR/3DOP_splits' 35 | '/train_config["train_dataset"]"') 36 | 37 | args = parser.parse_args() 38 | epoches = args.epoches 39 | batch_size = args.batch_size 40 | device = args.device 41 | train_config = load_train_config(args.train_config_path) 42 | DATASET_DIR = args.dataset_root_dir 43 | config_complete = load_config(args.config_path) 44 | if 'train' in config_complete: 45 | config = config_complete['train'] 46 | else: 47 | config = config_complete 48 | 49 | if args.dataset_split_file == '': 50 | DATASET_SPLIT_FILE = os.path.join(DATASET_DIR, 51 | './3DOP_splits/'+train_config['train_dataset']) 52 | else: 53 | DATASET_SPLIT_FILE = args.dataset_split_file 54 | 55 | # input function ============================================================== 56 | dataset = KittiDataset( 57 | os.path.join(DATASET_DIR, 'image/training/image_2'), 58 | os.path.join(DATASET_DIR, 'velodyne/training/velodyne/'), 59 | os.path.join(DATASET_DIR, 'calib/training/calib/'), 60 | os.path.join(DATASET_DIR, 'labels/training/label_2'), 61 | DATASET_SPLIT_FILE, 62 | num_classes=config['num_classes']) 63 | 64 | data_provider = DataProvider(dataset, train_config, config) 65 | #input_v, vertex_coord_list, keypoint_indices_list, edges_list, \ 66 | # cls_labels, encoded_boxes, valid_boxes = data_provider.provide_batch([1545, 1546]) 67 | 68 | batch = data_provider.provide_batch([1545, 1546]) 69 | input_v, vertex_coord_list, keypoint_indices_list, edges_list, \ 70 | cls_labels, encoded_boxes, valid_boxes = batch 71 | 72 | 73 | NUM_CLASSES = dataset.num_classes 74 | BOX_ENCODING_LEN = get_encoding_len(config['box_encoding_method']) 75 | 76 | model = MultiLayerFastLocalGraphModelV2(num_classes=NUM_CLASSES, 77 | box_encoding_len=BOX_ENCODING_LEN, mode='train', 78 | **config['model_kwargs']) 79 | model = model.to(device) 80 | 81 | optimizer = torch.optim.Adam(model.parameters(), lr=1e-4) 82 | NUM_TEST_SAMPLE = dataset.num_files 83 | 84 | os.system("mkdir saved_models") 85 | 86 | 87 | for epoch in range(1, epoches): 88 | recalls_list, precisions_list, mAP_list = {}, {}, {} 89 | for i in range(NUM_CLASSES): recalls_list[i], precisions_list[i], mAP_list[i] = [], [], [] 90 | 91 | frame_idx_list = np.random.permutation(NUM_TEST_SAMPLE) 92 | 93 | pbar = tqdm(list(range(0, NUM_TEST_SAMPLE-batch_size+1, batch_size)), desc="start training", leave=True) 94 | 95 | for batch_idx in pbar: 96 | #for batch_idx in range(0, NUM_TEST_SAMPLE-batch_size+1, batch_size): 97 | batch_frame_idx_list = frame_idx_list[batch_idx: batch_idx+batch_size] 98 | batch = data_provider.provide_batch(batch_frame_idx_list) 99 | input_v, vertex_coord_list, keypoint_indices_list, edges_list, \ 100 | cls_labels, encoded_boxes, valid_boxes = batch 101 | 102 | new_batch = [] 103 | for item in batch: 104 | if not isinstance(item, torch.Tensor): 105 | item = [x.to(device) for x in item] 106 | else: item = item.to(device) 107 | new_batch += [item] 108 | batch = new_batch 109 | input_v, vertex_coord_list, keypoint_indices_list, edges_list, \ 110 | cls_labels, encoded_boxes, valid_boxes = batch 111 | 112 | logits, box_encoding = model(batch, is_training=True) 113 | predictions = torch.argmax(logits, dim=1) 114 | 115 | loss_dict = model.loss(logits, cls_labels, box_encoding, encoded_boxes, valid_boxes) 116 | t_cls_loss, t_loc_loss, t_reg_loss = loss_dict['cls_loss'], loss_dict['loc_loss'], loss_dict['reg_loss'] 117 | pbar.set_description(f"{epoch}, t_cls_loss: {t_cls_loss}, t_loc_loss: {t_loc_loss}, t_reg_loss: {t_reg_loss}") 118 | t_total_loss = 
t_cls_loss + t_loc_loss + t_reg_loss 119 | optimizer.zero_grad() 120 | t_total_loss.backward() 121 | optimizer.step() 122 | 123 | # record metrics 124 | recalls, precisions = recall_precisions(cls_labels, predictions, NUM_CLASSES) 125 | #mAPs = mAP(cls_labels, logits, NUM_CLASSES) 126 | mAPs = mAP(cls_labels, logits.sigmoid(), NUM_CLASSES) 127 | for i in range(NUM_CLASSES): 128 | recalls_list[i] += [recalls[i]] 129 | precisions_list[i] += [precisions[i]] 130 | mAP_list[i] += [mAPs[i]] 131 | 132 | # print metrics 133 | for class_idx in range(NUM_CLASSES): 134 | print(f"class_idx:{class_idx}, recall: {np.mean(recalls_list[class_idx])}, precision: {np.mean(precisions_list[class_idx])}, mAP: {np.mean(mAP_list[class_idx])}") 135 | 136 | # save model 137 | torch.save(model.state_dict(), "saved_models/model_{}.pt".format(epoch)) 138 | 139 | 140 | 141 | 142 | 143 | 144 | -------------------------------------------------------------------------------- /scripts/point_cloud_downsample.py: -------------------------------------------------------------------------------- 1 | import numpy as np 2 | import open3d 3 | import os 4 | from dataset.kitti_dataset import KittiDataset 5 | from sklearn.cluster import KMeans 6 | from tqdm import tqdm 7 | 8 | 9 | dataset = KittiDataset( 10 | '../dataset/kitti/image/training/image_2', 11 | '../dataset/kitti/velodyne/training/velodyne/', 12 | '../dataset/kitti/calib/training/calib/', 13 | '', 14 | '../dataset/kitti/3DOP_splits/val.txt', 15 | is_training=False) 16 | 17 | downsample_rate = 2 18 | output_dir = '../dataset/kitti/velodyne/training_downsampled_%d/velodyne/' % downsample_rate 19 | for frame_idx in tqdm(range(0, dataset.num_files)): 20 | velo_points = dataset.get_velo_points(frame_idx) 21 | filename = dataset.get_filename(frame_idx) 22 | xyz = velo_points.xyz 23 | xyz_norm = np.sqrt(np.sum(xyz * xyz, axis=1, keepdims=True)) 24 | z_axis = np.array([[0], [0], [1]]) 25 | cos = xyz.dot(z_axis) / xyz_norm 26 | kmeans = KMeans(n_clusters=64, n_jobs=-1).fit(cos) 27 | centers = np.sort(np.squeeze(kmeans.cluster_centers_)) 28 | centers = [-1, ] + centers.tolist() + [1, ] 29 | cos = np.squeeze(cos) 30 | point_total_mask = np.zeros(len(velo_points.xyz), dtype=np.bool) 31 | for i in range(0, len(centers) - 2, downsample_rate): 32 | lower = (centers[i] + centers[i + 1]) / 2 33 | higher = (centers[i + 1] + centers[i + 2]) / 2 34 | point_mask = (cos > lower) * (cos < higher) 35 | point_total_mask += point_mask 36 | # Visualization 37 | # pcd = open3d.PointCloud() 38 | # pcd.points = open3d.Vector3dVector(velo_points.xyz[point_total_mask, :]) 39 | # def custom_draw_geometry_load_option(geometry_list): 40 | # vis = open3d.Visualizer() 41 | # vis.create_window() 42 | # for geometry in geometry_list: 43 | # vis.add_geometry(geometry) 44 | # ctr = vis.get_view_control() 45 | # ctr.rotate(0.0, 3141.0, 0) 46 | # vis.run() 47 | # vis.destroy_window() 48 | # custom_draw_geometry_load_option([pcd]) 49 | 50 | output = np.hstack([velo_points.xyz[point_total_mask, :], velo_points.attr[point_total_mask, :]]) 51 | point_file = output_dir + filename + '.bin' 52 | os.makedirs(os.path.dirname(point_file), exist_ok=True) 53 | output.tofile(point_file) 54 | -------------------------------------------------------------------------------- /splits/train_pedestrian_cyclist.txt: -------------------------------------------------------------------------------- 1 | 000000 2 | 000007 3 | 000010 4 | 000011 5 | 000043 6 | 000049 7 | 000051 8 | 000054 9 | 000060 10 | 000068 11 | 000070 12 | 000073 13 | 
000074 14 | 000085 15 | 000087 16 | 000091 17 | 000100 18 | 000103 19 | 000105 20 | 000109 21 | 000114 22 | 000119 23 | 000127 24 | 000130 25 | 000142 26 | 000144 27 | 000145 28 | 000146 29 | 000149 30 | 000150 31 | 000154 32 | 000157 33 | 000162 34 | 000164 35 | 000177 36 | 000179 37 | 000189 38 | 000193 39 | 000200 40 | 000206 41 | 000208 42 | 000209 43 | 000214 44 | 000217 45 | 000221 46 | 000228 47 | 000232 48 | 000245 49 | 000254 50 | 000264 51 | 000271 52 | 000274 53 | 000275 54 | 000277 55 | 000286 56 | 000287 57 | 000288 58 | 000292 59 | 000294 60 | 000295 61 | 000298 62 | 000303 63 | 000306 64 | 000310 65 | 000316 66 | 000317 67 | 000325 68 | 000330 69 | 000331 70 | 000337 71 | 000339 72 | 000342 73 | 000348 74 | 000364 75 | 000368 76 | 000371 77 | 000380 78 | 000384 79 | 000412 80 | 000423 81 | 000424 82 | 000435 83 | 000442 84 | 000445 85 | 000449 86 | 000460 87 | 000461 88 | 000462 89 | 000464 90 | 000465 91 | 000471 92 | 000484 93 | 000487 94 | 000490 95 | 000501 96 | 000502 97 | 000505 98 | 000514 99 | 000518 100 | 000520 101 | 000522 102 | 000529 103 | 000532 104 | 000535 105 | 000553 106 | 000575 107 | 000585 108 | 000596 109 | 000598 110 | 000599 111 | 000608 112 | 000622 113 | 000632 114 | 000633 115 | 000638 116 | 000646 117 | 000663 118 | 000671 119 | 000675 120 | 000685 121 | 000695 122 | 000701 123 | 000703 124 | 000705 125 | 000712 126 | 000713 127 | 000720 128 | 000726 129 | 000742 130 | 000743 131 | 000744 132 | 000755 133 | 000760 134 | 000763 135 | 000764 136 | 000770 137 | 000780 138 | 000784 139 | 000787 140 | 000788 141 | 000789 142 | 000796 143 | 000797 144 | 000799 145 | 000808 146 | 000825 147 | 000827 148 | 000830 149 | 000846 150 | 000855 151 | 000860 152 | 000861 153 | 000868 154 | 000870 155 | 000883 156 | 000886 157 | 000892 158 | 000902 159 | 000906 160 | 000910 161 | 000913 162 | 000919 163 | 000924 164 | 000925 165 | 000936 166 | 000937 167 | 000947 168 | 000955 169 | 000957 170 | 000959 171 | 000965 172 | 000972 173 | 000982 174 | 000987 175 | 000992 176 | 000998 177 | 001003 178 | 001029 179 | 001031 180 | 001033 181 | 001034 182 | 001036 183 | 001040 184 | 001045 185 | 001049 186 | 001057 187 | 001061 188 | 001062 189 | 001091 190 | 001100 191 | 001112 192 | 001124 193 | 001160 194 | 001161 195 | 001168 196 | 001169 197 | 001171 198 | 001174 199 | 001184 200 | 001196 201 | 001204 202 | 001205 203 | 001208 204 | 001211 205 | 001220 206 | 001233 207 | 001238 208 | 001256 209 | 001262 210 | 001279 211 | 001283 212 | 001290 213 | 001300 214 | 001310 215 | 001311 216 | 001319 217 | 001322 218 | 001323 219 | 001325 220 | 001328 221 | 001340 222 | 001348 223 | 001351 224 | 001354 225 | 001357 226 | 001358 227 | 001360 228 | 001366 229 | 001368 230 | 001370 231 | 001371 232 | 001378 233 | 001394 234 | 001396 235 | 001404 236 | 001405 237 | 001406 238 | 001409 239 | 001422 240 | 001423 241 | 001425 242 | 001426 243 | 001428 244 | 001433 245 | 001444 246 | 001449 247 | 001454 248 | 001455 249 | 001464 250 | 001467 251 | 001468 252 | 001496 253 | 001504 254 | 001509 255 | 001519 256 | 001523 257 | 001529 258 | 001531 259 | 001540 260 | 001543 261 | 001544 262 | 001548 263 | 001558 264 | 001566 265 | 001568 266 | 001570 267 | 001571 268 | 001572 269 | 001580 270 | 001608 271 | 001609 272 | 001611 273 | 001618 274 | 001622 275 | 001624 276 | 001632 277 | 001636 278 | 001644 279 | 001648 280 | 001655 281 | 001661 282 | 001672 283 | 001674 284 | 001676 285 | 001677 286 | 001692 287 | 001696 288 | 001700 289 | 001716 290 | 001720 291 | 001725 292 | 001738 293 | 
001748 294 | 001760 295 | 001763 296 | 001766 297 | 001770 298 | 001777 299 | 001779 300 | 001785 301 | 001788 302 | 001790 303 | 001792 304 | 001796 305 | 001811 306 | 001819 307 | 001827 308 | 001830 309 | 001837 310 | 001849 311 | 001864 312 | 001870 313 | 001873 314 | 001874 315 | 001879 316 | 001895 317 | 001902 318 | 001908 319 | 001910 320 | 001911 321 | 001912 322 | 001915 323 | 001916 324 | 001918 325 | 001921 326 | 001948 327 | 001951 328 | 001957 329 | 001958 330 | 001962 331 | 001963 332 | 001964 333 | 001970 334 | 001971 335 | 001974 336 | 001981 337 | 001987 338 | 001988 339 | 001992 340 | 001994 341 | 001998 342 | 002003 343 | 002005 344 | 002006 345 | 002007 346 | 002015 347 | 002016 348 | 002023 349 | 002024 350 | 002030 351 | 002059 352 | 002063 353 | 002065 354 | 002080 355 | 002084 356 | 002088 357 | 002092 358 | 002095 359 | 002096 360 | 002110 361 | 002114 362 | 002117 363 | 002133 364 | 002144 365 | 002147 366 | 002149 367 | 002156 368 | 002157 369 | 002164 370 | 002167 371 | 002171 372 | 002175 373 | 002176 374 | 002178 375 | 002180 376 | 002181 377 | 002189 378 | 002194 379 | 002199 380 | 002210 381 | 002211 382 | 002226 383 | 002227 384 | 002247 385 | 002249 386 | 002264 387 | 002267 388 | 002269 389 | 002271 390 | 002275 391 | 002285 392 | 002288 393 | 002289 394 | 002296 395 | 002301 396 | 002309 397 | 002317 398 | 002322 399 | 002324 400 | 002339 401 | 002342 402 | 002351 403 | 002360 404 | 002375 405 | 002379 406 | 002388 407 | 002400 408 | 002402 409 | 002403 410 | 002408 411 | 002410 412 | 002416 413 | 002435 414 | 002444 415 | 002456 416 | 002465 417 | 002471 418 | 002480 419 | 002485 420 | 002487 421 | 002493 422 | 002494 423 | 002496 424 | 002501 425 | 002507 426 | 002508 427 | 002510 428 | 002512 429 | 002514 430 | 002517 431 | 002524 432 | 002535 433 | 002536 434 | 002537 435 | 002549 436 | 002551 437 | 002554 438 | 002560 439 | 002571 440 | 002595 441 | 002605 442 | 002609 443 | 002616 444 | 002617 445 | 002618 446 | 002634 447 | 002643 448 | 002648 449 | 002649 450 | 002652 451 | 002654 452 | 002671 453 | 002689 454 | 002691 455 | 002703 456 | 002716 457 | 002718 458 | 002723 459 | 002731 460 | 002734 461 | 002738 462 | 002739 463 | 002754 464 | 002762 465 | 002769 466 | 002771 467 | 002774 468 | 002780 469 | 002781 470 | 002788 471 | 002792 472 | 002798 473 | 002799 474 | 002813 475 | 002816 476 | 002822 477 | 002829 478 | 002837 479 | 002838 480 | 002842 481 | 002843 482 | 002851 483 | 002857 484 | 002860 485 | 002864 486 | 002868 487 | 002870 488 | 002888 489 | 002897 490 | 002898 491 | 002904 492 | 002906 493 | 002915 494 | 002918 495 | 002920 496 | 002923 497 | 002927 498 | 002933 499 | 002939 500 | 002943 501 | 002952 502 | 002954 503 | 002956 504 | 002973 505 | 002986 506 | 003002 507 | 003008 508 | 003009 509 | 003013 510 | 003014 511 | 003015 512 | 003017 513 | 003018 514 | 003020 515 | 003028 516 | 003057 517 | 003059 518 | 003078 519 | 003084 520 | 003085 521 | 003095 522 | 003105 523 | 003115 524 | 003120 525 | 003121 526 | 003122 527 | 003130 528 | 003140 529 | 003147 530 | 003149 531 | 003155 532 | 003157 533 | 003158 534 | 003163 535 | 003164 536 | 003166 537 | 003171 538 | 003178 539 | 003185 540 | 003193 541 | 003195 542 | 003206 543 | 003235 544 | 003237 545 | 003238 546 | 003244 547 | 003249 548 | 003256 549 | 003258 550 | 003260 551 | 003264 552 | 003267 553 | 003277 554 | 003282 555 | 003286 556 | 003287 557 | 003289 558 | 003294 559 | 003307 560 | 003309 561 | 003311 562 | 003321 563 | 003329 564 | 003333 565 | 003336 566 | 
003339 567 | 003349 568 | 003356 569 | 003359 570 | 003369 571 | 003377 572 | 003380 573 | 003387 574 | 003390 575 | 003391 576 | 003398 577 | 003437 578 | 003438 579 | 003442 580 | 003454 581 | 003455 582 | 003459 583 | 003472 584 | 003476 585 | 003494 586 | 003498 587 | 003510 588 | 003512 589 | 003514 590 | 003516 591 | 003518 592 | 003526 593 | 003533 594 | 003534 595 | 003540 596 | 003548 597 | 003555 598 | 003556 599 | 003560 600 | 003564 601 | 003576 602 | 003585 603 | 003586 604 | 003587 605 | 003589 606 | 003591 607 | 003602 608 | 003603 609 | 003615 610 | 003617 611 | 003619 612 | 003626 613 | 003628 614 | 003641 615 | 003642 616 | 003644 617 | 003646 618 | 003663 619 | 003664 620 | 003670 621 | 003675 622 | 003686 623 | 003687 624 | 003694 625 | 003696 626 | 003700 627 | 003706 628 | 003709 629 | 003713 630 | 003722 631 | 003725 632 | 003740 633 | 003744 634 | 003758 635 | 003759 636 | 003765 637 | 003766 638 | 003770 639 | 003774 640 | 003784 641 | 003786 642 | 003790 643 | 003797 644 | 003801 645 | 003819 646 | 003823 647 | 003838 648 | 003849 649 | 003858 650 | 003861 651 | 003862 652 | 003876 653 | 003887 654 | 003912 655 | 003921 656 | 003922 657 | 003927 658 | 003929 659 | 003939 660 | 003947 661 | 003954 662 | 003955 663 | 003963 664 | 003967 665 | 003971 666 | 003974 667 | 003978 668 | 003979 669 | 003989 670 | 003993 671 | 003995 672 | 003997 673 | 004012 674 | 004014 675 | 004019 676 | 004022 677 | 004024 678 | 004035 679 | 004039 680 | 004053 681 | 004057 682 | 004058 683 | 004073 684 | 004086 685 | 004088 686 | 004090 687 | 004093 688 | 004094 689 | 004099 690 | 004103 691 | 004106 692 | 004114 693 | 004127 694 | 004139 695 | 004144 696 | 004147 697 | 004165 698 | 004170 699 | 004176 700 | 004177 701 | 004179 702 | 004186 703 | 004192 704 | 004194 705 | 004197 706 | 004212 707 | 004218 708 | 004219 709 | 004229 710 | 004230 711 | 004233 712 | 004235 713 | 004240 714 | 004247 715 | 004252 716 | 004261 717 | 004265 718 | 004268 719 | 004273 720 | 004274 721 | 004283 722 | 004287 723 | 004292 724 | 004297 725 | 004304 726 | 004308 727 | 004325 728 | 004328 729 | 004332 730 | 004346 731 | 004347 732 | 004351 733 | 004358 734 | 004361 735 | 004371 736 | 004378 737 | 004379 738 | 004382 739 | 004389 740 | 004400 741 | 004410 742 | 004411 743 | 004417 744 | 004428 745 | 004431 746 | 004436 747 | 004442 748 | 004446 749 | 004449 750 | 004459 751 | 004484 752 | 004488 753 | 004499 754 | 004503 755 | 004509 756 | 004512 757 | 004514 758 | 004515 759 | 004524 760 | 004536 761 | 004537 762 | 004543 763 | 004546 764 | 004552 765 | 004558 766 | 004561 767 | 004564 768 | 004580 769 | 004583 770 | 004584 771 | 004594 772 | 004595 773 | 004597 774 | 004602 775 | 004604 776 | 004605 777 | 004613 778 | 004619 779 | 004627 780 | 004637 781 | 004639 782 | 004646 783 | 004662 784 | 004674 785 | 004690 786 | 004696 787 | 004704 788 | 004712 789 | 004723 790 | 004729 791 | 004731 792 | 004733 793 | 004736 794 | 004747 795 | 004749 796 | 004755 797 | 004757 798 | 004772 799 | 004780 800 | 004786 801 | 004793 802 | 004796 803 | 004802 804 | 004818 805 | 004827 806 | 004836 807 | 004837 808 | 004844 809 | 004854 810 | 004855 811 | 004857 812 | 004866 813 | 004869 814 | 004870 815 | 004872 816 | 004877 817 | 004882 818 | 004884 819 | 004897 820 | 004901 821 | 004906 822 | 004908 823 | 004912 824 | 004915 825 | 004923 826 | 004939 827 | 004940 828 | 004945 829 | 004950 830 | 004955 831 | 004964 832 | 004965 833 | 004967 834 | 004968 835 | 004972 836 | 004973 837 | 004977 838 | 005005 839 | 
005007 840 | 005011 841 | 005023 842 | 005027 843 | 005030 844 | 005031 845 | 005035 846 | 005039 847 | 005043 848 | 005044 849 | 005059 850 | 005061 851 | 005083 852 | 005084 853 | 005085 854 | 005092 855 | 005097 856 | 005098 857 | 005102 858 | 005104 859 | 005106 860 | 005107 861 | 005114 862 | 005115 863 | 005116 864 | 005119 865 | 005131 866 | 005137 867 | 005142 868 | 005165 869 | 005171 870 | 005196 871 | 005203 872 | 005207 873 | 005208 874 | 005211 875 | 005228 876 | 005231 877 | 005235 878 | 005239 879 | 005245 880 | 005248 881 | 005258 882 | 005259 883 | 005263 884 | 005270 885 | 005278 886 | 005281 887 | 005286 888 | 005294 889 | 005300 890 | 005301 891 | 005314 892 | 005317 893 | 005324 894 | 005348 895 | 005353 896 | 005356 897 | 005358 898 | 005370 899 | 005380 900 | 005383 901 | 005395 902 | 005403 903 | 005408 904 | 005417 905 | 005418 906 | 005419 907 | 005428 908 | 005433 909 | 005436 910 | 005439 911 | 005462 912 | 005469 913 | 005485 914 | 005496 915 | 005504 916 | 005517 917 | 005518 918 | 005519 919 | 005524 920 | 005526 921 | 005533 922 | 005539 923 | 005541 924 | 005553 925 | 005554 926 | 005567 927 | 005568 928 | 005579 929 | 005598 930 | 005604 931 | 005608 932 | 005612 933 | 005621 934 | 005622 935 | 005637 936 | 005647 937 | 005654 938 | 005661 939 | 005663 940 | 005665 941 | 005666 942 | 005670 943 | 005671 944 | 005682 945 | 005686 946 | 005688 947 | 005690 948 | 005692 949 | 005693 950 | 005696 951 | 005711 952 | 005718 953 | 005719 954 | 005733 955 | 005750 956 | 005766 957 | 005769 958 | 005770 959 | 005771 960 | 005773 961 | 005778 962 | 005779 963 | 005791 964 | 005804 965 | 005808 966 | 005813 967 | 005817 968 | 005823 969 | 005824 970 | 005825 971 | 005835 972 | 005836 973 | 005837 974 | 005838 975 | 005846 976 | 005847 977 | 005849 978 | 005853 979 | 005866 980 | 005880 981 | 005888 982 | 005890 983 | 005891 984 | 005895 985 | 005898 986 | 005902 987 | 005915 988 | 005920 989 | 005924 990 | 005928 991 | 005934 992 | 005937 993 | 005945 994 | 005950 995 | 005954 996 | 005959 997 | 005964 998 | 005980 999 | 005991 1000 | 006006 1001 | 006011 1002 | 006019 1003 | 006021 1004 | 006035 1005 | 006051 1006 | 006053 1007 | 006065 1008 | 006076 1009 | 006082 1010 | 006089 1011 | 006092 1012 | 006094 1013 | 006109 1014 | 006158 1015 | 006159 1016 | 006160 1017 | 006164 1018 | 006166 1019 | 006174 1020 | 006179 1021 | 006184 1022 | 006189 1023 | 006201 1024 | 006211 1025 | 006214 1026 | 006221 1027 | 006224 1028 | 006226 1029 | 006231 1030 | 006234 1031 | 006236 1032 | 006242 1033 | 006256 1034 | 006257 1035 | 006261 1036 | 006262 1037 | 006268 1038 | 006277 1039 | 006284 1040 | 006292 1041 | 006293 1042 | 006294 1043 | 006298 1044 | 006299 1045 | 006303 1046 | 006308 1047 | 006313 1048 | 006328 1049 | 006337 1050 | 006341 1051 | 006352 1052 | 006358 1053 | 006361 1054 | 006362 1055 | 006378 1056 | 006382 1057 | 006390 1058 | 006400 1059 | 006412 1060 | 006414 1061 | 006422 1062 | 006431 1063 | 006449 1064 | 006458 1065 | 006460 1066 | 006467 1067 | 006471 1068 | 006476 1069 | 006479 1070 | 006480 1071 | 006492 1072 | 006500 1073 | 006502 1074 | 006511 1075 | 006522 1076 | 006528 1077 | 006536 1078 | 006539 1079 | 006546 1080 | 006552 1081 | 006564 1082 | 006572 1083 | 006573 1084 | 006575 1085 | 006579 1086 | 006585 1087 | 006589 1088 | 006600 1089 | 006605 1090 | 006608 1091 | 006617 1092 | 006627 1093 | 006635 1094 | 006639 1095 | 006642 1096 | 006645 1097 | 006652 1098 | 006653 1099 | 006657 1100 | 006661 1101 | 006690 1102 | 006700 1103 | 006704 1104 | 
006717 1105 | 006722 1106 | 006728 1107 | 006749 1108 | 006750 1109 | 006766 1110 | 006776 1111 | 006779 1112 | 006795 1113 | 006799 1114 | 006801 1115 | 006802 1116 | 006809 1117 | 006810 1118 | 006814 1119 | 006825 1120 | 006827 1121 | 006830 1122 | 006835 1123 | 006838 1124 | 006842 1125 | 006859 1126 | 006861 1127 | 006878 1128 | 006886 1129 | 006892 1130 | 006894 1131 | 006904 1132 | 006911 1133 | 006912 1134 | 006916 1135 | 006919 1136 | 006923 1137 | 006926 1138 | 006929 1139 | 006932 1140 | 006935 1141 | 006941 1142 | 006951 1143 | 006952 1144 | 006957 1145 | 006969 1146 | 006985 1147 | 007004 1148 | 007017 1149 | 007018 1150 | 007021 1151 | 007036 1152 | 007045 1153 | 007046 1154 | 007070 1155 | 007073 1156 | 007075 1157 | 007077 1158 | 007092 1159 | 007094 1160 | 007099 1161 | 007118 1162 | 007121 1163 | 007123 1164 | 007126 1165 | 007129 1166 | 007147 1167 | 007150 1168 | 007156 1169 | 007159 1170 | 007179 1171 | 007185 1172 | 007188 1173 | 007189 1174 | 007192 1175 | 007203 1176 | 007206 1177 | 007223 1178 | 007231 1179 | 007237 1180 | 007254 1181 | 007264 1182 | 007269 1183 | 007285 1184 | 007286 1185 | 007301 1186 | 007308 1187 | 007320 1188 | 007321 1189 | 007324 1190 | 007333 1191 | 007334 1192 | 007338 1193 | 007355 1194 | 007362 1195 | 007363 1196 | 007366 1197 | 007368 1198 | 007372 1199 | 007387 1200 | 007393 1201 | 007406 1202 | 007408 1203 | 007414 1204 | 007427 1205 | 007431 1206 | 007446 1207 | 007451 1208 | 007452 1209 | 007465 1210 | 007476 1211 |
--------------------------------------------------------------------------------
/train.sh:
--------------------------------------------------------------------------------
1 | #python3 kitty_dataset.py configs/car_auto_T3_train_train_config configs/car_auto_T3_train_config --dataset_root_dir data
2 | python3 mytrain.py configs/car_auto_T3_train_train_config configs/car_auto_T3_train_config --dataset_root_dir data
3 | #python3 train.py configs/car_auto_T3_train_train_config configs/car_auto_T3_train_config --dataset_root_dir data
4 | 
5 | 
--------------------------------------------------------------------------------
/train_tf.sh:
--------------------------------------------------------------------------------
1 | python3 train.py configs/car_auto_T3_train_train_config configs/car_auto_T3_train_config --dataset_root_dir data
2 | 
--------------------------------------------------------------------------------
/util/config_util.py:
--------------------------------------------------------------------------------
1 | """This file implements configuration functions."""
2 | 
3 | import json
4 | 
5 | def load_config(filename):
6 |     """Load a model configuration file."""
7 |     with open(filename, 'r') as f:
8 |         config = json.load(f)
9 |     return config
10 | 
11 | def save_config(filename, config):
12 |     """Save a model configuration file."""
13 |     with open(filename, 'w') as f:
14 |         json.dump(config, f, sort_keys=True, indent=4)
15 | 
16 | def load_train_config(filename):
17 |     """Load a training configuration file."""
18 |     with open(filename, 'r') as f:
19 |         config = json.load(f)
20 |     return config
21 | 
22 | def save_train_config(filename, train_config):
23 |     """Save a training configuration file."""
24 |     with open(filename, 'w') as f:
25 |         json.dump(train_config, f, sort_keys=True, indent=4)
26 | 
--------------------------------------------------------------------------------
/util/metrics.py:
--------------------------------------------------------------------------------
1 | import torch
2 | import numpy as np
3 | 
4 | def recall_precisions(labels, predictions, num_classes):  # per-class recall and precision from hard predictions
5 |     recalls, precisions = {}, {}
6 | 
7 |     for class_idx in range(num_classes):
8 |         gt = (labels == class_idx)
9 |         pred = (predictions == class_idx)
10 | 
11 |         TP = float(torch.logical_and(gt.squeeze(), pred.squeeze()).sum())
12 | 
13 |         recalls[class_idx] = TP / gt.sum().item() if gt.sum().item() != 0 else 0
14 |         precisions[class_idx] = TP / pred.sum().item() if pred.sum().item() != 0 else 0
15 | 
16 |     return recalls, precisions
17 | 
18 | def mAP(labels, logits, num_classes):  # mean precision over ~30 sampled score thresholds, a coarse surrogate for AP
19 |     mAPs = {}
20 | 
21 |     for class_idx in range(num_classes):
22 |         pred = logits[:, class_idx]
23 |         threshs = sorted(pred.tolist())
24 |         threshs = threshs[::max(1, len(threshs) // 30)]  # max(1, ...) avoids a zero slice step when there are fewer than 30 scores
25 |         gt = (labels == class_idx)
26 | 
27 |         precisions = []
28 |         for thresh in threshs:
29 |             _pred = (pred > thresh).bool()
30 |             TP = float(torch.logical_and(gt.squeeze(), _pred.squeeze()).sum())
31 |             if _pred.sum().float().item() == 0:
32 |                 precisions += [0]
33 |             else: precisions += [TP / _pred.sum().float().item()]
34 |         mAPs[class_idx] = np.mean(precisions)
35 |     return mAPs
36 | 
37 | if __name__ == "__main__":
38 |     labels = torch.randint(0, 4, (100,))
39 |     predictions = torch.randint(0, 4, (100,))
40 |     recalls, precisions = recall_precisions(labels, predictions, 4)
41 | 
42 |     logits = torch.rand(100, 4)
43 |     mAPs = mAP(labels, logits, 4)
44 |     print("recall: ", recalls)
45 |     print("precision: ", precisions)
46 |     print("mAPs: ", mAPs)
47 | 
--------------------------------------------------------------------------------
/util/summary_util.py:
--------------------------------------------------------------------------------
1 | """This file implements utility functions for tensorflow summary."""
2 | 
3 | import tensorflow as tf
4 | from tensorboard import summary as summary_lib
5 | 
6 | def write_summary_scale(key, value, global_step, summary_dir):
7 |     """Write a scalar summary to summary_dir."""
8 |     writer = tf.summary.FileWriterCache.get(summary_dir)
9 |     summary = tf.Summary(value=[
10 |         tf.Summary.Value(tag=key, simple_value=value),
11 |     ])
12 |     writer.add_summary(summary, global_step)
13 | 
14 | def write_summary(summary, global_step, summary_dir):
15 |     """Write a summary to summary_dir."""
16 |     writer = tf.summary.FileWriterCache.get(summary_dir)
17 |     writer.add_summary(summary, global_step)
18 | 
--------------------------------------------------------------------------------
/util/tf_util.py:
--------------------------------------------------------------------------------
1 | import tensorflow as tf
2 | 
3 | def average_gradients(tower_grads):
4 |     """Calculate the average gradient for each shared variable across all towers.
5 |     Note that this function provides a synchronization point across all towers.
6 |     Args:
7 |         tower_grads: List of lists of (gradient, variable) tuples. The outer list
8 |             is over individual gradients. The inner list is over the gradient
9 |             calculation for each tower.
10 |     Returns:
11 |         List of pairs of (gradient, variable) where the gradient has been averaged
12 |         across all towers.
13 |     """
14 |     average_grads = []
15 |     for grad_and_vars in zip(*tower_grads):
16 |         # Note that each grad_and_vars looks like the following:
17 |         #   ((grad0_gpu0, var0_gpu0), ..., (grad0_gpuN, var0_gpuN))
18 |         if grad_and_vars[0][0] is None:  # the variable received no gradient; pass (None, var) through unchanged
19 |             grad = grad_and_vars[0][0]
20 |             v = grad_and_vars[0][1]
21 |             grad_and_var = (grad, v)
22 |             average_grads.append(grad_and_var)
23 |             continue
24 | 
25 |         grads = []
26 |         for g, _ in grad_and_vars:
27 |             # Add a 0 dimension to the gradients to represent the tower.
28 |             expanded_g = tf.expand_dims(g, 0)
29 | 
30 |             # Append on a 'tower' dimension which we will average over below.
31 |             grads.append(expanded_g)
32 | 
33 |         # Average over the 'tower' dimension.
34 |         grad = tf.concat(axis=0, values=grads)
35 |         grad = tf.reduce_mean(grad, 0)
36 | 
37 |         # Keep in mind that the Variables are redundant because they are shared
38 |         # across towers. So we will just return the first tower's pointer to
39 |         # the Variable.
40 |         v = grad_and_vars[0][1]
41 |         grad_and_var = (grad, v)
42 |         average_grads.append(grad_and_var)
43 |     return average_grads
44 | 
--------------------------------------------------------------------------------
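Since `util/config_util.py` is a thin wrapper around `json`, a config can be loaded, edited, and written back in a few lines. A minimal sketch: the config path is one shipped in `configs/`, but the edited key is purely illustrative and may not exist in your config.

```python
from util.config_util import load_config, save_config

# Round-trip a shipped model config (inspect your config for real keys).
config = load_config('configs/car_auto_T3_train_config')
config['some_threshold'] = 0.5  # hypothetical key, shown only to illustrate editing
save_config('/tmp/car_auto_T3_train_config_edited', config)
```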
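The PyTorch metrics in `util/metrics.py` and the TF1 summary helpers in `util/summary_util.py` can be combined to log per-class scores to TensorBoard. A hedged sketch, assuming an evaluation step that already has `labels` and `logits` tensors; the tensor shapes and summary directory are assumptions.

```python
import torch
from util.metrics import recall_precisions, mAP
from util.summary_util import write_summary_scale

# Dummy stand-ins for one evaluation batch (shapes are illustrative).
labels = torch.randint(0, 4, (100,))   # per-point class labels
logits = torch.rand(100, 4)            # per-point class scores
predictions = logits.argmax(dim=1)

recalls, precisions = recall_precisions(labels, predictions, num_classes=4)
mAPs = mAP(labels, logits, num_classes=4)

# write_summary_scale goes through tf.summary.FileWriterCache (a TF1.x API),
# so one writer per directory is created lazily and reused across calls.
for class_idx in recalls:
    write_summary_scale('recall/class_%d' % class_idx, recalls[class_idx],
                        global_step=0, summary_dir='./checkpoints/summary')
```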
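`average_gradients` in `util/tf_util.py` implements the classic TF1 multi-tower pattern: each tower produces its own list of `(gradient, variable)` pairs over the same shared variables, the lists are averaged variable-by-variable, and the result is applied once. A self-contained toy sketch under TF1.x; the two-tower loop and quadratic loss are assumptions, and on real hardware each iteration would sit under `tf.device('/gpu:%d' % i)` with proper variable scoping.

```python
import tensorflow as tf
from util.tf_util import average_gradients

# One variable shared by every tower, and a toy quadratic loss per tower.
w = tf.get_variable('w', shape=[4, 1])
optimizer = tf.train.GradientDescentOptimizer(0.1)

tower_grads = []
for i in range(2):  # two towers; wrap in tf.device('/gpu:%d' % i) for real GPUs
    x = tf.random_normal([8, 4])
    loss = tf.reduce_mean(tf.square(tf.matmul(x, w)))
    tower_grads.append(optimizer.compute_gradients(loss, var_list=[w]))

avg_grads = average_gradients(tower_grads)   # one (grad, var) pair per shared variable
train_op = optimizer.apply_gradients(avg_grads)
```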