├── .gitignore ├── LICENSE ├── README.md ├── config.py ├── model ├── __init__.py ├── group_pointcloud.py ├── model.py └── rpn.py ├── setup.py ├── test.py ├── train.py ├── train_hook.py └── utils ├── __init__.py ├── box_overlaps.c ├── box_overlaps.pyx ├── colorize.py ├── data_aug.py ├── kitti_loader.py ├── preprocess.py ├── setup.py └── utils.py /.gitignore: -------------------------------------------------------------------------------- 1 | sync 2 | data 3 | log 4 | save_model 5 | build 6 | 7 | # Created by https://www.gitignore.io/api/vim 8 | 9 | ### Vim ### 10 | # swap 11 | [._]*.s[a-v][a-z] 12 | [._]*.sw[a-p] 13 | [._]s[a-v][a-z] 14 | [._]sw[a-p] 15 | # session 16 | Session.vim 17 | # temporary 18 | .netrwhist 19 | *~ 20 | # auto-generated tag files 21 | tags 22 | 23 | # End of https://www.gitignore.io/api/vim 24 | 25 | # Created by https://www.gitignore.io/api/python 26 | 27 | ### Python ### 28 | # Byte-compiled / optimized / DLL files 29 | __pycache__/ 30 | *.py[cod] 31 | *$py.class 32 | 33 | # C extensions 34 | *.so 35 | 36 | # Distribution / packaging 37 | .Python 38 | build/ 39 | develop-eggs/ 40 | dist/ 41 | downloads/ 42 | eggs/ 43 | .eggs/ 44 | lib/ 45 | lib64/ 46 | parts/ 47 | sdist/ 48 | var/ 49 | wheels/ 50 | *.egg-info/ 51 | .installed.cfg 52 | *.egg 53 | 54 | # PyInstaller 55 | # Usually these files are written by a python script from a template 56 | # before PyInstaller builds the exe, so as to inject date/other infos into it. 57 | *.manifest 58 | *.spec 59 | 60 | # Installer logs 61 | pip-log.txt 62 | pip-delete-this-directory.txt 63 | 64 | # Unit test / coverage reports 65 | htmlcov/ 66 | .tox/ 67 | .coverage 68 | .coverage.* 69 | .cache 70 | nosetests.xml 71 | coverage.xml 72 | *.cover 73 | .hypothesis/ 74 | 75 | # Translations 76 | *.mo 77 | *.pot 78 | 79 | # Django stuff: 80 | *.log 81 | local_settings.py 82 | 83 | # Flask stuff: 84 | instance/ 85 | .webassets-cache 86 | 87 | # Scrapy stuff: 88 | .scrapy 89 | 90 | # Sphinx documentation 91 | docs/_build/ 92 | 93 | # PyBuilder 94 | target/ 95 | 96 | # Jupyter Notebook 97 | .ipynb_checkpoints 98 | 99 | # pyenv 100 | .python-version 101 | 102 | # celery beat schedule file 103 | celerybeat-schedule 104 | 105 | # SageMath parsed files 106 | *.sage.py 107 | 108 | # Environments 109 | .env 110 | .venv 111 | env/ 112 | venv/ 113 | ENV/ 114 | env.bak/ 115 | venv.bak/ 116 | 117 | # Spyder project settings 118 | .spyderproject 119 | .spyproject 120 | 121 | # Rope project settings 122 | .ropeproject 123 | 124 | # mkdocs documentation 125 | /site 126 | 127 | # mypy 128 | .mypy_cache/ 129 | 130 | # End of https://www.gitignore.io/api/python 131 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | MIT License 2 | 3 | Copyright (c) 2019 Tsinghua Robot Learning Lab 4 | 5 | Permission is hereby granted, free of charge, to any person obtaining a copy 6 | of this software and associated documentation files (the "Software"), to deal 7 | in the Software without restriction, including without limitation the rights 8 | to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 | copies of the Software, and to permit persons to whom the Software is 10 | furnished to do so, subject to the following conditions: 11 | 12 | The above copyright notice and this permission notice shall be included in all 13 | copies or substantial portions of the Software. 
14 | 15 | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 | IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 | OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 | SOFTWARE. 22 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # VoxelNet-tensorflow 2 | 3 | A TensorFlow implementation of [VoxelNet](https://arxiv.org/abs/1711.06396). 4 | 5 | ## Requirements 6 | 7 | 1. `Python 3.5+` 8 | 2. `tensorflow 1.4+` 9 | 3. `NumPy`, etc. 10 | 11 | ## Usage 12 | 13 | 0. Have a look at `config.py` for the model configuration, and split your data into train/test sets with [this split](https://xiaozhichen.github.io/files/mv3d/imagesets.tar.gz). 14 | 1. Run `setup.py` to build the Cython module: 15 | ```bash 16 | $ python setup.py build_ext --inplace 17 | ``` 18 | 2. Make sure your working directory looks like this (some files are omitted): 19 | ```plain 20 | ├── build <-- Cython build file 21 | ├── model <-- some src files 22 | ├── utils <-- some src files 23 | ├── setup.py 24 | ├── config.py 25 | ├── test.py 26 | ├── train.py 27 | ├── train_hook.py 28 | ├── README.md 29 | └── data <-- KITTI data directory 30 | └── object 31 | ├── training <-- training data 32 | | ├── image_2 33 | | ├── label_2 34 | | └── velodyne 35 | └── testing <-- testing data 36 | ├── image_2 37 | ├── label_2 38 | └── velodyne 39 | ``` 40 | 41 | 3. Run `train.py`. Some command-line parameters are needed; check `train.py` for them (e.g. `--tag`, `--single-batch-size`, `--lr`, `--max-epoch`). 42 | 4. Launch TensorBoard and wait for the training results. 43 | 44 | ## Data augmentation 45 | Since [c928317](https://github.com/jeasinema/tf_voxelnet/commit/c928317169f1bf23e2157dab20cb402bddb8ffe0), data augmentation is done online, so there is no need to generate augmented samples beforehand. 46 | 47 | ## Results 48 | 49 | TBD 50 | 51 | ## Acknowledgement 52 | 53 | Thanks to [@ring00](https://github.com/ring00) for the implementation of the VFE layer and to **Jialin Zhao** for the implementation of the RPN. 54 | 55 | ## License 56 | 57 | MIT 58 | -------------------------------------------------------------------------------- /config.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : config.py 5 | # Purpose : 6 | # Creation Date : 09-12-2017 7 | # Last Modified : Fri 19 Jan 2018 01:11:28 PM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | 11 | """VoxelNet config system.
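Hyper-parameters live in a single EasyDict (`cfg`); the rest of the code reads them as plain attributes. A minimal consumer sketch (mirroring how the files in model/ and utils/ import it):

    from config import cfg
    print(cfg.DETECT_OBJ, cfg.VOXEL_X_SIZE, cfg.FEATURE_WIDTH)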
12 | """ 13 | 14 | import os 15 | import os.path as osp 16 | import numpy as np 17 | from time import strftime, localtime 18 | from easydict import EasyDict as edict 19 | import math 20 | 21 | __C = edict() 22 | # Consumers can get config by: 23 | # import config as cfg 24 | cfg = __C 25 | 26 | # for gpu allocation 27 | __C.GPU_AVAILABLE = '3,1,2,0' 28 | __C.GPU_USE_COUNT = len(__C.GPU_AVAILABLE.split(',')) 29 | __C.GPU_MEMORY_FRACTION = 1 30 | 31 | # selected object 32 | __C.DETECT_OBJ = 'Car' # Pedestrian/Cyclist 33 | if __C.DETECT_OBJ == 'Car': 34 | __C.Y_MIN = -40 35 | __C.Y_MAX = 40 36 | __C.X_MIN = 0 37 | __C.X_MAX = 70.4 38 | __C.VOXEL_X_SIZE = 0.2 39 | __C.VOXEL_Y_SIZE = 0.2 40 | __C.VOXEL_POINT_COUNT = 35 41 | __C.INPUT_WIDTH = int((__C.X_MAX - __C.X_MIN) / __C.VOXEL_X_SIZE) 42 | __C.INPUT_HEIGHT = int((__C.Y_MAX - __C.Y_MIN) / __C.VOXEL_Y_SIZE) 43 | __C.FEATURE_RATIO = 2 44 | __C.FEATURE_WIDTH = int(__C.INPUT_WIDTH / __C.FEATURE_RATIO) 45 | __C.FEATURE_HEIGHT = int(__C.INPUT_HEIGHT / __C.FEATURE_RATIO) 46 | else: 47 | __C.Y_MIN = -20 48 | __C.Y_MAX = 20 49 | __C.X_MIN = 0 50 | __C.X_MAX = 48 51 | __C.VOXEL_X_SIZE = 0.2 52 | __C.VOXEL_Y_SIZE = 0.2 53 | __C.VOXEL_POINT_COUNT = 45 54 | __C.INPUT_WIDTH = int((__C.X_MAX - __C.X_MIN) / __C.VOXEL_X_SIZE) 55 | __C.INPUT_HEIGHT = int((__C.Y_MAX - __C.Y_MIN) / __C.VOXEL_Y_SIZE) 56 | __C.FEATURE_RATIO = 2 57 | __C.FEATURE_WIDTH = int(__C.INPUT_WIDTH / __C.FEATURE_RATIO) 58 | __C.FEATURE_HEIGHT = int(__C.INPUT_HEIGHT / __C.FEATURE_RATIO) 59 | 60 | # set the log image scale factor 61 | __C.BV_LOG_FACTOR = 8 62 | 63 | # for data set type 64 | __C.DATA_SETS_TYPE = 'kitti' 65 | 66 | # Root directory of project 67 | __C.CHECKPOINT_DIR = osp.join('checkpoint') 68 | __C.LOG_DIR = osp.join('log') 69 | 70 | # for data preprocess 71 | # sensors 72 | __C.VELODYNE_ANGULAR_RESOLUTION = 0.08 / 180 * math.pi 73 | __C.VELODYNE_VERTICAL_RESOLUTION = 0.4 / 180 * math.pi 74 | __C.VELODYNE_HEIGHT = 1.73 75 | # rgb 76 | if __C.DATA_SETS_TYPE == 'kitti': 77 | __C.IMAGE_WIDTH = 1242 78 | __C.IMAGE_HEIGHT = 375 79 | __C.IMAGE_CHANNEL = 3 80 | # top 81 | if __C.DATA_SETS_TYPE == 'kitti': 82 | __C.TOP_Y_MIN = -30 83 | __C.TOP_Y_MAX = +30 84 | __C.TOP_X_MIN = 0 85 | __C.TOP_X_MAX = 80 86 | __C.TOP_Z_MIN = -4.2 87 | __C.TOP_Z_MAX = 0.8 88 | 89 | __C.TOP_X_DIVISION = 0.1 90 | __C.TOP_Y_DIVISION = 0.1 91 | __C.TOP_Z_DIVISION = 0.2 92 | 93 | __C.TOP_WIDTH = (__C.TOP_X_MAX - __C.TOP_X_MIN) // __C.TOP_X_DIVISION 94 | __C.TOP_HEIGHT = (__C.TOP_Y_MAX - __C.TOP_Y_MIN) // __C.TOP_Y_DIVISION 95 | __C.TOP_CHANNEL = (__C.TOP_Z_MAX - __C.TOP_Z_MIN) // __C.TOP_Z_DIVISION 96 | 97 | # for 2d proposal to 3d proposal 98 | __C.PROPOSAL3D_Z_MIN = -2.3 # -2.52 99 | __C.PROPOSAL3D_Z_MAX = 1.5 # -1.02 100 | 101 | # for RPN basenet choose 102 | __C.USE_VGG_AS_RPN = 0 103 | __C.USE_RESNET_AS_RPN = 0 104 | __C.USE_RESNEXT_AS_RPN = 0 105 | 106 | # for camera and lidar coordination convert 107 | if __C.DATA_SETS_TYPE == 'kitti': 108 | # cal mean from train set 109 | __C.MATRIX_P2 = ([[719.787081, 0., 608.463003, 44.9538775], 110 | [0., 719.787081, 174.545111, 0.1066855], 111 | [0., 0., 1., 3.0106472e-03], 112 | [0., 0., 0., 0]]) 113 | 114 | # cal mean from train set 115 | __C.MATRIX_T_VELO_2_CAM = ([ 116 | [7.49916597e-03, -9.99971248e-01, -8.65110297e-04, -6.71807577e-03], 117 | [1.18652889e-02, 9.54520517e-04, -9.99910318e-01, -7.33152811e-02], 118 | [9.99882833e-01, 7.49141178e-03, 1.18719929e-02, -2.78557062e-01], 119 | [0, 0, 0, 1] 120 | ]) 121 | # cal mean from train set 122 | 
__C.MATRIX_R_RECT_0 = ([ 123 | [0.99992475, 0.00975976, -0.00734152, 0], 124 | [-0.0097913, 0.99994262, -0.00430371, 0], 125 | [0.00729911, 0.0043753, 0.99996319, 0], 126 | [0, 0, 0, 1] 127 | ]) 128 | 129 | # Faster-RCNN/SSD Hyper params 130 | if __C.DETECT_OBJ == 'Car': 131 | # car anchor 132 | __C.ANCHOR_L = 3.9 133 | __C.ANCHOR_W = 1.6 134 | __C.ANCHOR_H = 1.56 135 | __C.ANCHOR_Z = -1.0 - cfg.ANCHOR_H/2 136 | __C.RPN_POS_IOU = 0.6 137 | __C.RPN_NEG_IOU = 0.45 138 | 139 | elif __C.DETECT_OBJ == 'Pedestrian': 140 | # pedestrian anchor 141 | __C.ANCHOR_L = 0.8 142 | __C.ANCHOR_W = 0.6 143 | __C.ANCHOR_H = 1.73 144 | __C.ANCHOR_Z = -0.6 - cfg.ANCHOR_H/2 145 | __C.RPN_POS_IOU = 0.5 146 | __C.RPN_NEG_IOU = 0.35 147 | 148 | if __C.DETECT_OBJ == 'Cyclist': 149 | # cyclist anchor 150 | __C.ANCHOR_L = 1.76 151 | __C.ANCHOR_W = 0.6 152 | __C.ANCHOR_H = 1.73 153 | __C.ANCHOR_Z = -0.6 - cfg.ANCHOR_H/2 154 | __C.RPN_POS_IOU = 0.5 155 | __C.RPN_NEG_IOU = 0.35 156 | 157 | # for rpn nms 158 | __C.RPN_NMS_POST_TOPK = 20 159 | __C.RPN_NMS_THRESH = 0.3 160 | __C.RPN_SCORE_THRESH = 0.96 161 | 162 | 163 | # utils 164 | __C.CORNER2CENTER_AVG = True # average version or max version 165 | 166 | if __name__ == '__main__': 167 | print('__C.ROOT_DIR = ' + __C.ROOT_DIR) 168 | print('__C.DATA_SETS_DIR = ' + __C.DATA_SETS_DIR) 169 | -------------------------------------------------------------------------------- /model/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : 5 | # Purpose : 6 | # Creation Date : 21-12-2017 7 | # Last Modified : Thu 21 Dec 2017 08:03:57 PM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | from model.group_pointcloud import * 11 | from model.rpn import * 12 | from model.model import * 13 | -------------------------------------------------------------------------------- /model/group_pointcloud.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding: utf-8 -*- 3 | 4 | # File Name : rpn.py 5 | # Purpose : 6 | # Creation Date : 10-12-2017 7 | # Last Modified : Thu 21 Dec 2017 07:48:05 PM CST 8 | # Created By : Wei Zhang 9 | 10 | import os 11 | import numpy as np 12 | import tensorflow as tf 13 | import time 14 | 15 | from config import cfg 16 | 17 | 18 | class VFELayer(object): 19 | 20 | def __init__(self, out_channels, name): 21 | super(VFELayer, self).__init__() 22 | self.units = int(out_channels / 2) 23 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE) as scope: 24 | self.dense = tf.layers.Dense( 25 | self.units, tf.nn.relu, name='dense', _reuse=tf.AUTO_REUSE, _scope=scope) 26 | self.batch_norm = tf.layers.BatchNormalization( 27 | name='batch_norm', fused=True, _reuse=tf.AUTO_REUSE, _scope=scope) 28 | 29 | def apply(self, inputs, mask, training): 30 | # [K, T, 7] tensordot [7, units] = [K, T, units] 31 | pointwise = self.batch_norm.apply(self.dense.apply(inputs), training) 32 | 33 | #n [K, 1, units] 34 | aggregated = tf.reduce_max(pointwise, axis=1, keep_dims=True) 35 | 36 | # [K, T, units] 37 | repeated = tf.tile(aggregated, [1, cfg.VOXEL_POINT_COUNT, 1]) 38 | 39 | # [K, T, 2 * units] 40 | concatenated = tf.concat([pointwise, repeated], axis=2) 41 | 42 | mask = tf.tile(mask, [1, 1, 2 * self.units]) 43 | 44 | concatenated = tf.multiply(concatenated, tf.cast(mask, tf.float32)) 45 | 46 | return concatenated 47 | 48 | 49 | class FeatureNet(object): 50 | 51 | def __init__(self, 
training, batch_size, name=''): 52 | super(FeatureNet, self).__init__() 53 | self.training = training 54 | 55 | # scalar 56 | self.batch_size = batch_size 57 | # [ΣK, 35/45, 7] 58 | self.feature = tf.placeholder( 59 | tf.float32, [None, cfg.VOXEL_POINT_COUNT, 7], name='feature') 60 | # [ΣK] 61 | self.number = tf.placeholder(tf.int64, [None], name='number') 62 | # [ΣK, 4], each row stores (batch, d, h, w) 63 | self.coordinate = tf.placeholder( 64 | tf.int64, [None, 4], name='coordinate') 65 | 66 | with tf.variable_scope(name, reuse=tf.AUTO_REUSE) as scope: 67 | self.vfe1 = VFELayer(32, 'VFE-1') 68 | self.vfe2 = VFELayer(128, 'VFE-2') 69 | self.dense = tf.layers.Dense( 70 | 128, tf.nn.relu, name='dense', _reuse=tf.AUTO_REUSE, _scope=scope) 71 | self.batch_norm = tf.layers.BatchNormalization( 72 | name='batch_norm', fused=True, _reuse=tf.AUTO_REUSE, _scope=scope) 73 | # boolean mask [K, T, 2 * units] 74 | mask = tf.not_equal(tf.reduce_max( 75 | self.feature, axis=2, keep_dims=True), 0) 76 | x = self.vfe1.apply(self.feature, mask, self.training) 77 | x = self.vfe2.apply(x, mask, self.training) 78 | x = self.dense.apply(x) 79 | x = self.batch_norm.apply(x, self.training) 80 | 81 | # [ΣK, 128] 82 | voxelwise = tf.reduce_max(x, axis=1) 83 | 84 | # car: [N * 10 * 400 * 352 * 128] 85 | # pedestrian/cyclist: [N * 10 * 200 * 240 * 128] 86 | self.outputs = tf.scatter_nd( 87 | self.coordinate, voxelwise, [self.batch_size, 10, cfg.INPUT_HEIGHT, cfg.INPUT_WIDTH, 128]) 88 | 89 | 90 | def build_input(voxel_dict_list): 91 | batch_size = len(voxel_dict_list) 92 | 93 | feature_list = [] 94 | number_list = [] 95 | coordinate_list = [] 96 | for i, voxel_dict in zip(range(batch_size), voxel_dict_list): 97 | feature_list.append(voxel_dict['feature_buffer']) 98 | number_list.append(voxel_dict['number_buffer']) 99 | coordinate = voxel_dict['coordinate_buffer'] 100 | coordinate_list.append( 101 | np.pad(coordinate, ((0, 0), (1, 0)), 102 | mode='constant', constant_values=i)) 103 | 104 | feature = np.concatenate(feature_list) 105 | number = np.concatenate(number_list) 106 | coordinate = np.concatenate(coordinate_list) 107 | return batch_size, feature, number, coordinate 108 | 109 | 110 | def run(batch_size, feature, number, coordinate): 111 | """ 112 | Input: 113 | batch_size: scalar, the batch size 114 | feature: [ΣK, T, 7], voxel input feature buffer 115 | number: [ΣK], number of points in each voxel 116 | coordinate: [ΣK, 4], voxel coordinate buffer 117 | 118 | A feature tensor feature[i] has number[i] points in it and is located in 119 | coordinate[i] (a 1-D tensor reprents [batch, d, h, w]) in the output 120 | 121 | Input format is similiar to what's described in section 2.3 of the paper 122 | 123 | Suppose the batch size is 3, the 3 point cloud is loaded as 124 | 1. feature: [K1, T, 7] (K1 is the number of non-empty voxels) 125 | number: [K1] (number of points in the corresponding voxel) 126 | coordinate: [K1, 3] (each row is a tensor reprents [d, h, w]) 127 | 2. feature: [K2, T, 7] 128 | number: [K2] 129 | coordinate: [K2, 3] 130 | 3. 
feature: [K3, T, 7] 131 | number: [K3] 132 | coordinate: [K3, 3] 133 | Then the corresponding input is 134 | batch_size: 3 135 | feature: [K1 + K2 + K3, T, 7] 136 | number: [K1 + K2 + K3] 137 | coordinate: [K1 + K2 + K3, 4] (need to append the batch index of the 138 | corresponding voxel in front of each row) 139 | Output: 140 | outputs: [batch_size, 10, 400, 352, 128] 141 | """ 142 | gpu_options = tf.GPUOptions(visible_device_list='0,2,3') 143 | config = tf.ConfigProto( 144 | gpu_options=gpu_options, 145 | device_count={'GPU': 3} 146 | ) 147 | 148 | with tf.Session(config=config) as sess: 149 | model = FeatureNet(training=False, batch_size=batch_size) 150 | tf.global_variables_initializer().run() 151 | for i in range(10): 152 | time_start = time.time() 153 | feed = {model.feature: feature, 154 | model.number: number, 155 | model.coordinate: coordinate} 156 | outputs = sess.run([model.outputs], feed) 157 | print(outputs[0].shape) 158 | time_end = time.time() 159 | print(time_end - time_start) 160 | 161 | 162 | def main(): 163 | data_dir = './data/object/training/voxel' 164 | batch_size = 32 165 | 166 | filelist = [f for f in os.listdir(data_dir) if f.endswith('npz')] 167 | 168 | import time 169 | voxel_dict_list = [] 170 | for id in range(0, len(filelist), batch_size): 171 | pre_time = time.time() 172 | batch_file = [f for f in filelist[id:id + batch_size]] 173 | voxel_dict_list = [] 174 | for file in batch_file: 175 | voxel_dict_list.append(np.load(os.path.join(data_dir, file))) 176 | 177 | # example input with batch size 16 178 | batch_size, feature, number, coordinate = build_input(voxel_dict_list) 179 | print(time.time() - pre_time) 180 | 181 | run(batch_size, feature, number, coordinate) 182 | 183 | 184 | if __name__ == '__main__': 185 | main() 186 | -------------------------------------------------------------------------------- /model/model.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : model.py 5 | # Purpose : 6 | # Creation Date : 09-12-2017 7 | # Last Modified : Fri 05 Jan 2018 09:34:48 PM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | import sys 11 | import os 12 | import tensorflow as tf 13 | import cv2 14 | from numba import jit 15 | 16 | from config import cfg 17 | from utils import * 18 | from model.group_pointcloud import FeatureNet 19 | from model.rpn import MiddleAndRPN 20 | 21 | 22 | class RPN3D(object): 23 | 24 | def __init__(self, 25 | cls='Car', 26 | single_batch_size=2, # batch_size_per_gpu 27 | learning_rate=0.001, 28 | max_gradient_norm=5.0, 29 | alpha=1.5, 30 | beta=1, 31 | is_train=True, 32 | avail_gpus=['0']): 33 | # hyper parameters and status 34 | self.cls = cls 35 | self.single_batch_size = single_batch_size 36 | self.learning_rate = tf.Variable( 37 | float(learning_rate), trainable=False, dtype=tf.float32) 38 | self.global_step = tf.Variable(1, trainable=False) 39 | self.epoch = tf.Variable(0, trainable=False) 40 | self.epoch_add_op = self.epoch.assign(self.epoch + 1) 41 | self.alpha = alpha 42 | self.beta = beta 43 | self.avail_gpus = avail_gpus 44 | 45 | lr = tf.train.exponential_decay( 46 | self.learning_rate, self.global_step, 10000, 0.96) 47 | 48 | # build graph 49 | # input placeholders 50 | self.vox_feature = [] 51 | self.vox_number = [] 52 | self.vox_coordinate = [] 53 | self.targets = [] 54 | self.pos_equal_one = [] 55 | self.pos_equal_one_sum = [] 56 | self.pos_equal_one_for_reg = [] 57 | self.neg_equal_one = [] 
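# (each of the lists above and below holds one placeholder per GPU tower; they are
# filled in the per-device loop below and fed with that GPU's slice of the batch in
# train_step/validate_step)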
58 | self.neg_equal_one_sum = [] 59 | 60 | self.delta_output = [] 61 | self.prob_output = [] 62 | self.opt = tf.train.AdamOptimizer(lr) 63 | self.gradient_norm = [] 64 | self.tower_grads = [] 65 | with tf.variable_scope(tf.get_variable_scope()): 66 | for idx, dev in enumerate(self.avail_gpus): 67 | with tf.device('/gpu:{}'.format(dev)), tf.name_scope('gpu_{}'.format(dev)): 68 | # must use name scope here since we do not want to create new variables 69 | # graph 70 | feature = FeatureNet( 71 | training=is_train, batch_size=self.single_batch_size) 72 | rpn = MiddleAndRPN( 73 | input=feature.outputs, alpha=self.alpha, beta=self.beta, training=is_train) 74 | tf.get_variable_scope().reuse_variables() 75 | # input 76 | self.vox_feature.append(feature.feature) 77 | self.vox_number.append(feature.number) 78 | self.vox_coordinate.append(feature.coordinate) 79 | self.targets.append(rpn.targets) 80 | self.pos_equal_one.append(rpn.pos_equal_one) 81 | self.pos_equal_one_sum.append(rpn.pos_equal_one_sum) 82 | self.pos_equal_one_for_reg.append( 83 | rpn.pos_equal_one_for_reg) 84 | self.neg_equal_one.append(rpn.neg_equal_one) 85 | self.neg_equal_one_sum.append(rpn.neg_equal_one_sum) 86 | # output 87 | feature_output = feature.outputs 88 | delta_output = rpn.delta_output 89 | prob_output = rpn.prob_output 90 | # loss and grad 91 | self.loss = rpn.loss 92 | self.reg_loss = rpn.reg_loss 93 | self.cls_loss = rpn.cls_loss 94 | self.params = tf.trainable_variables() 95 | gradients = tf.gradients(self.loss, self.params) 96 | clipped_gradients, gradient_norm = tf.clip_by_global_norm( 97 | gradients, max_gradient_norm) 98 | 99 | self.delta_output.append(delta_output) 100 | self.prob_output.append(prob_output) 101 | self.tower_grads.append(clipped_gradients) 102 | self.gradient_norm.append(gradient_norm) 103 | self.rpn_output_shape = rpn.output_shape 104 | 105 | # loss and optimizer 106 | # self.xxxloss is only the loss for the lowest tower 107 | with tf.device('/gpu:{}'.format(self.avail_gpus[0])): 108 | self.grads = average_gradients(self.tower_grads) 109 | self.update = self.opt.apply_gradients( 110 | zip(self.grads, self.params), global_step=self.global_step) 111 | self.gradient_norm = tf.group(*self.gradient_norm) 112 | 113 | self.delta_output = tf.concat(self.delta_output, axis=0) 114 | self.prob_output = tf.concat(self.prob_output, axis=0) 115 | 116 | self.anchors = cal_anchors() 117 | # for predict and image summary 118 | self.rgb = tf.placeholder( 119 | tf.uint8, [None, cfg.IMAGE_HEIGHT, cfg.IMAGE_WIDTH, 3]) 120 | self.bv = tf.placeholder(tf.uint8, [ 121 | None, cfg.BV_LOG_FACTOR * cfg.INPUT_HEIGHT, cfg.BV_LOG_FACTOR * cfg.INPUT_WIDTH, 3]) 122 | self.bv_heatmap = tf.placeholder(tf.uint8, [ 123 | None, cfg.BV_LOG_FACTOR * cfg.FEATURE_HEIGHT, cfg.BV_LOG_FACTOR * cfg.FEATURE_WIDTH, 3]) 124 | self.boxes2d = tf.placeholder(tf.float32, [None, 4]) 125 | self.boxes2d_scores = tf.placeholder(tf.float32, [None]) 126 | 127 | # NMS(2D) 128 | with tf.device('/gpu:{}'.format(self.avail_gpus[0])): 129 | self.box2d_ind_after_nms = tf.image.non_max_suppression( 130 | self.boxes2d, self.boxes2d_scores, max_output_size=cfg.RPN_NMS_POST_TOPK, iou_threshold=cfg.RPN_NMS_THRESH) 131 | 132 | # summary and saver 133 | self.saver = tf.train.Saver(write_version=tf.train.SaverDef.V2, 134 | max_to_keep=10, pad_step_number=True, keep_checkpoint_every_n_hours=1.0) 135 | 136 | self.train_summary = tf.summary.merge([ 137 | tf.summary.scalar('train/loss', self.loss), 138 | tf.summary.scalar('train/reg_loss', self.reg_loss), 139 | 
tf.summary.scalar('train/cls_loss', self.cls_loss), 140 | *[tf.summary.histogram(each.name, each) for each in self.params] 141 | ]) 142 | 143 | self.validate_summary = tf.summary.merge([ 144 | tf.summary.scalar('validate/loss', self.loss), 145 | tf.summary.scalar('validate/reg_loss', self.reg_loss), 146 | tf.summary.scalar('validate/cls_loss', self.cls_loss) 147 | ]) 148 | 149 | # TODO: bird_view_summary and front_view_summary 150 | 151 | self.predict_summary = tf.summary.merge([ 152 | tf.summary.image('predict/bird_view_lidar', self.bv), 153 | tf.summary.image('predict/bird_view_heatmap', self.bv_heatmap), 154 | tf.summary.image('predict/front_view_rgb', self.rgb), 155 | ]) 156 | 157 | def train_step(self, session, data, train=False, summary=False): 158 | # input: 159 | # (N) tag 160 | # (N, N') label 161 | # vox_feature 162 | # vox_number 163 | # vox_coordinate 164 | tag = data[0] 165 | label = data[1] 166 | vox_feature = data[2] 167 | vox_number = data[3] 168 | vox_coordinate = data[4] 169 | print('train', tag) 170 | pos_equal_one, neg_equal_one, targets = cal_rpn_target( 171 | label, self.rpn_output_shape, self.anchors, cls=cfg.DETECT_OBJ, coordinate='lidar') 172 | pos_equal_one_for_reg = np.concatenate( 173 | [np.tile(pos_equal_one[..., [0]], 7), np.tile(pos_equal_one[..., [1]], 7)], axis=-1) 174 | pos_equal_one_sum = np.clip(np.sum(pos_equal_one, axis=( 175 | 1, 2, 3)).reshape(-1, 1, 1, 1), a_min=1, a_max=None) 176 | neg_equal_one_sum = np.clip(np.sum(neg_equal_one, axis=( 177 | 1, 2, 3)).reshape(-1, 1, 1, 1), a_min=1, a_max=None) 178 | 179 | input_feed = {} 180 | for idx in range(len(self.avail_gpus)): 181 | input_feed[self.vox_feature[idx]] = vox_feature[idx] 182 | input_feed[self.vox_number[idx]] = vox_number[idx] 183 | input_feed[self.vox_coordinate[idx]] = vox_coordinate[idx] 184 | input_feed[self.targets[idx]] = targets[idx * 185 | self.single_batch_size:(idx + 1) * self.single_batch_size] 186 | input_feed[self.pos_equal_one[idx]] = pos_equal_one[idx * 187 | self.single_batch_size:(idx + 1) * self.single_batch_size] 188 | input_feed[self.pos_equal_one_sum[idx]] = pos_equal_one_sum[idx * 189 | self.single_batch_size:(idx + 1) * self.single_batch_size] 190 | input_feed[self.pos_equal_one_for_reg[idx]] = pos_equal_one_for_reg[idx * 191 | self.single_batch_size:(idx + 1) * self.single_batch_size] 192 | input_feed[self.neg_equal_one[idx]] = neg_equal_one[idx * 193 | self.single_batch_size:(idx + 1) * self.single_batch_size] 194 | input_feed[self.neg_equal_one_sum[idx]] = neg_equal_one_sum[idx * 195 | self.single_batch_size:(idx + 1) * self.single_batch_size] 196 | if train: 197 | output_feed = [self.loss, self.reg_loss, 198 | self.cls_loss, self.gradient_norm, self.update] 199 | else: 200 | output_feed = [self.loss, self.reg_loss, self.cls_loss] 201 | if summary: 202 | output_feed.append(self.train_summary) 203 | # TODO: multi-gpu support for test and predict step 204 | return session.run(output_feed, input_feed) 205 | 206 | def validate_step(self, session, data, summary=False): 207 | # input: 208 | # (N) tag 209 | # (N, N') label 210 | # vox_feature 211 | # vox_number 212 | # vox_coordinate 213 | tag = data[0] 214 | label = data[1] 215 | vox_feature = data[2] 216 | vox_number = data[3] 217 | vox_coordinate = data[4] 218 | print('valid', tag) 219 | pos_equal_one, neg_equal_one, targets = cal_rpn_target( 220 | label, self.rpn_output_shape, self.anchors) 221 | pos_equal_one_for_reg = np.concatenate( 222 | [np.tile(pos_equal_one[..., [0]], 7), np.tile(pos_equal_one[..., [1]], 7)], 
axis=-1) 223 | pos_equal_one_sum = np.clip(np.sum(pos_equal_one, axis=( 224 | 1, 2, 3)).reshape(-1, 1, 1, 1), a_min=1, a_max=None) 225 | neg_equal_one_sum = np.clip(np.sum(neg_equal_one, axis=( 226 | 1, 2, 3)).reshape(-1, 1, 1, 1), a_min=1, a_max=None) 227 | 228 | input_feed = {} 229 | for idx in range(len(self.avail_gpus)): 230 | input_feed[self.vox_feature[idx]] = vox_feature[idx] 231 | input_feed[self.vox_number[idx]] = vox_number[idx] 232 | input_feed[self.vox_coordinate[idx]] = vox_coordinate[idx] 233 | input_feed[self.targets[idx]] = targets[idx * 234 | self.single_batch_size:(idx + 1) * self.single_batch_size] 235 | input_feed[self.pos_equal_one[idx]] = pos_equal_one[idx * 236 | self.single_batch_size:(idx + 1) * self.single_batch_size] 237 | input_feed[self.pos_equal_one_sum[idx]] = pos_equal_one_sum[idx * 238 | self.single_batch_size:(idx + 1) * self.single_batch_size] 239 | input_feed[self.pos_equal_one_for_reg[idx]] = pos_equal_one_for_reg[idx * 240 | self.single_batch_size:(idx + 1) * self.single_batch_size] 241 | input_feed[self.neg_equal_one[idx]] = neg_equal_one[idx * 242 | self.single_batch_size:(idx + 1) * self.single_batch_size] 243 | input_feed[self.neg_equal_one_sum[idx]] = neg_equal_one_sum[idx * 244 | self.single_batch_size:(idx + 1) * self.single_batch_size] 245 | 246 | output_feed = [self.loss, self.reg_loss, self.cls_loss] 247 | if summary: 248 | output_feed.append(self.validate_summary) 249 | return session.run(output_feed, input_feed) 250 | 251 | def predict_step(self, session, data, summary=False): 252 | # input: 253 | # (N) tag 254 | # (N, N') label(can be empty) 255 | # vox_feature 256 | # vox_number 257 | # vox_coordinate 258 | # img (N, w, l, 3) 259 | # lidar (N, N', 4) 260 | # output: A, B, C 261 | # A: (N) tag 262 | # B: (N, N') (class, x, y, z, h, w, l, rz, score) 263 | # C; summary(optional) 264 | tag = data[0] 265 | label = data[1] 266 | vox_feature = data[2] 267 | vox_number = data[3] 268 | vox_coordinate = data[4] 269 | img = data[5] 270 | lidar = data[6] 271 | 272 | if summary: 273 | batch_gt_boxes3d = label_to_gt_box3d( 274 | label, cls=self.cls, coordinate='lidar') 275 | print('predict', tag) 276 | input_feed = {} 277 | for idx in range(len(self.avail_gpus)): 278 | input_feed[self.vox_feature[idx]] = vox_feature[idx] 279 | input_feed[self.vox_number[idx]] = vox_number[idx] 280 | input_feed[self.vox_coordinate[idx]] = vox_coordinate[idx] 281 | 282 | output_feed = [self.prob_output, self.delta_output] 283 | probs, deltas = session.run(output_feed, input_feed) 284 | # BOTTLENECK 285 | batch_boxes3d = delta_to_boxes3d( 286 | deltas, self.anchors, coordinate='lidar') 287 | batch_boxes2d = batch_boxes3d[:, :, [0, 1, 4, 5, 6]] 288 | batch_probs = probs.reshape( 289 | (len(self.avail_gpus) * self.single_batch_size, -1)) 290 | # NMS 291 | ret_box3d = [] 292 | ret_score = [] 293 | for batch_id in range(len(self.avail_gpus) * self.single_batch_size): 294 | # remove box with low score 295 | ind = np.where(batch_probs[batch_id, :] >= cfg.RPN_SCORE_THRESH)[0] 296 | tmp_boxes3d = batch_boxes3d[batch_id, ind, ...] 297 | tmp_boxes2d = batch_boxes2d[batch_id, ind, ...] 298 | tmp_scores = batch_probs[batch_id, ind] 299 | 300 | # TODO: if possible, use rotate NMS 301 | boxes2d = corner_to_standup_box2d( 302 | center_to_corner_box2d(tmp_boxes2d, coordinate='lidar')) 303 | ind = session.run(self.box2d_ind_after_nms, { 304 | self.boxes2d: boxes2d, 305 | self.boxes2d_scores: tmp_scores 306 | }) 307 | tmp_boxes3d = tmp_boxes3d[ind, ...] 
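# keep only the detections whose indices survived the 2-D bird's-eye-view NMS above
# (the same index set selects boxes and scores)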
308 | tmp_scores = tmp_scores[ind] 309 | ret_box3d.append(tmp_boxes3d) 310 | ret_score.append(tmp_scores) 311 | 312 | ret_box3d_score = [] 313 | for boxes3d, scores in zip(ret_box3d, ret_score): 314 | ret_box3d_score.append(np.concatenate([np.tile(self.cls, len(boxes3d))[:, np.newaxis], 315 | boxes3d, scores[:, np.newaxis]], axis=-1)) 316 | 317 | if summary: 318 | # only summry 1 in a batch 319 | front_image = draw_lidar_box3d_on_image(img[0], ret_box3d[0], ret_score[0], 320 | batch_gt_boxes3d[0]) 321 | bird_view = lidar_to_bird_view_img( 322 | lidar[0], factor=cfg.BV_LOG_FACTOR) 323 | bird_view = draw_lidar_box3d_on_birdview(bird_view, ret_box3d[0], ret_score[0], 324 | batch_gt_boxes3d[0], factor=cfg.BV_LOG_FACTOR) 325 | heatmap = colorize(probs[0, ...], cfg.BV_LOG_FACTOR) 326 | ret_summary = session.run(self.predict_summary, { 327 | self.rgb: front_image[np.newaxis, ...], 328 | self.bv: bird_view[np.newaxis, ...], 329 | self.bv_heatmap: heatmap[np.newaxis, ...] 330 | }) 331 | 332 | return tag, ret_box3d_score, ret_summary 333 | 334 | return tag, ret_box3d_score 335 | 336 | 337 | def average_gradients(tower_grads): 338 | # ref: 339 | # https://github.com/tensorflow/models/blob/6db9f0282e2ab12795628de6200670892a8ad6ba/tutorials/image/cifar10/cifar10_multi_gpu_train.py#L103 340 | # but only contains grads, no vars 341 | average_grads = [] 342 | for grad_and_vars in zip(*tower_grads): 343 | grads = [] 344 | for g in grad_and_vars: 345 | # Add 0 dimension to the gradients to represent the tower. 346 | expanded_g = tf.expand_dims(g, 0) 347 | 348 | # Append on a 'tower' dimension which we will average over below. 349 | grads.append(expanded_g) 350 | 351 | # Average over the 'tower' dimension. 352 | grad = tf.concat(axis=0, values=grads) 353 | grad = tf.reduce_mean(grad, 0) 354 | 355 | # Keep in mind that the Variables are redundant because they are shared 356 | # across towers. So .. we will just return the first tower's pointer to 357 | # the Variable. 
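# Unlike the referenced tutorial, only the averaged gradient is returned here; the
# matching variables are re-attached later via zip(self.grads, self.params) in
# RPN3D.__init__.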
358 | grad_and_var = grad 359 | average_grads.append(grad_and_var) 360 | return average_grads 361 | 362 | 363 | if __name__ == '__main__': 364 | pass 365 | -------------------------------------------------------------------------------- /model/rpn.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : rpn.py 5 | # Purpose : 6 | # Creation Date : 10-12-2017 7 | # Last Modified : Thu 08 Mar 2018 02:20:43 PM CST 8 | # Created By : Jialin Zhao 9 | 10 | import tensorflow as tf 11 | import numpy as np 12 | 13 | from config import cfg 14 | 15 | 16 | small_addon_for_BCE = 1e-6 17 | 18 | 19 | class MiddleAndRPN: 20 | def __init__(self, input, alpha=1.5, beta=1, sigma=3, training=True, name=''): 21 | # scale = [batchsize, 10, 400/200, 352/240, 128] should be the output of feature learning network 22 | self.input = input 23 | self.training = training 24 | # groundtruth(target) - each anchor box, represent as △x, △y, △z, △l, △w, △h, rotation 25 | self.targets = tf.placeholder( 26 | tf.float32, [None, cfg.FEATURE_HEIGHT, cfg.FEATURE_WIDTH, 14]) 27 | # postive anchors equal to one and others equal to zero(2 anchors in 1 position) 28 | self.pos_equal_one = tf.placeholder( 29 | tf.float32, [None, cfg.FEATURE_HEIGHT, cfg.FEATURE_WIDTH, 2]) 30 | self.pos_equal_one_sum = tf.placeholder(tf.float32, [None, 1, 1, 1]) 31 | self.pos_equal_one_for_reg = tf.placeholder( 32 | tf.float32, [None, cfg.FEATURE_HEIGHT, cfg.FEATURE_WIDTH, 14]) 33 | # negative anchors equal to one and others equal to zero 34 | self.neg_equal_one = tf.placeholder( 35 | tf.float32, [None, cfg.FEATURE_HEIGHT, cfg.FEATURE_WIDTH, 2]) 36 | self.neg_equal_one_sum = tf.placeholder(tf.float32, [None, 1, 1, 1]) 37 | 38 | with tf.variable_scope('MiddleAndRPN_' + name): 39 | # convolutinal middle layers 40 | temp_conv = ConvMD(3, 128, 64, 3, (2, 1, 1), 41 | (1, 1, 1), self.input, name='conv1') 42 | temp_conv = ConvMD(3, 64, 64, 3, (1, 1, 1), 43 | (0, 1, 1), temp_conv, name='conv2') 44 | temp_conv = ConvMD(3, 64, 64, 3, (2, 1, 1), 45 | (1, 1, 1), temp_conv, name='conv3') 46 | temp_conv = tf.transpose(temp_conv, perm=[0, 2, 3, 4, 1]) 47 | temp_conv = tf.reshape( 48 | temp_conv, [-1, cfg.INPUT_HEIGHT, cfg.INPUT_WIDTH, 128]) 49 | 50 | # rpn 51 | # block1: 52 | temp_conv = ConvMD(2, 128, 128, 3, (2, 2), (1, 1), 53 | temp_conv, training=self.training, name='conv4') 54 | temp_conv = ConvMD(2, 128, 128, 3, (1, 1), (1, 1), 55 | temp_conv, training=self.training, name='conv5') 56 | temp_conv = ConvMD(2, 128, 128, 3, (1, 1), (1, 1), 57 | temp_conv, training=self.training, name='conv6') 58 | temp_conv = ConvMD(2, 128, 128, 3, (1, 1), (1, 1), 59 | temp_conv, training=self.training, name='conv7') 60 | deconv1 = Deconv2D(128, 256, 3, (1, 1), (0, 0), 61 | temp_conv, training=self.training, name='deconv1') 62 | 63 | # block2: 64 | temp_conv = ConvMD(2, 128, 128, 3, (2, 2), (1, 1), 65 | temp_conv, training=self.training, name='conv8') 66 | temp_conv = ConvMD(2, 128, 128, 3, (1, 1), (1, 1), 67 | temp_conv, training=self.training, name='conv9') 68 | temp_conv = ConvMD(2, 128, 128, 3, (1, 1), (1, 1), 69 | temp_conv, training=self.training, name='conv10') 70 | temp_conv = ConvMD(2, 128, 128, 3, (1, 1), (1, 1), 71 | temp_conv, training=self.training, name='conv11') 72 | temp_conv = ConvMD(2, 128, 128, 3, (1, 1), (1, 1), 73 | temp_conv, training=self.training, name='conv12') 74 | temp_conv = ConvMD(2, 128, 128, 3, (1, 1), (1, 1), 75 | temp_conv, training=self.training, 
name='conv13') 76 | deconv2 = Deconv2D(128, 256, 2, (2, 2), (0, 0), 77 | temp_conv, training=self.training, name='deconv2') 78 | 79 | # block3: 80 | temp_conv = ConvMD(2, 128, 256, 3, (2, 2), (1, 1), 81 | temp_conv, training=self.training, name='conv14') 82 | temp_conv = ConvMD(2, 256, 256, 3, (1, 1), (1, 1), 83 | temp_conv, training=self.training, name='conv15') 84 | temp_conv = ConvMD(2, 256, 256, 3, (1, 1), (1, 1), 85 | temp_conv, training=self.training, name='conv16') 86 | temp_conv = ConvMD(2, 256, 256, 3, (1, 1), (1, 1), 87 | temp_conv, training=self.training, name='conv17') 88 | temp_conv = ConvMD(2, 256, 256, 3, (1, 1), (1, 1), 89 | temp_conv, training=self.training, name='conv18') 90 | temp_conv = ConvMD(2, 256, 256, 3, (1, 1), (1, 1), 91 | temp_conv, training=self.training, name='conv19') 92 | deconv3 = Deconv2D(256, 256, 4, (4, 4), (0, 0), 93 | temp_conv, training=self.training, name='deconv3') 94 | 95 | # final: 96 | temp_conv = tf.concat([deconv3, deconv2, deconv1], -1) 97 | # Probability score map, scale = [None, 200/100, 176/120, 2] 98 | p_map = ConvMD(2, 768, 2, 1, (1, 1), (0, 0), temp_conv, activation=False, 99 | training=self.training, name='conv20') 100 | # Regression(residual) map, scale = [None, 200/100, 176/120, 14] 101 | r_map = ConvMD(2, 768, 14, 1, (1, 1), (0, 0), 102 | temp_conv, training=self.training, activation=False, name='conv21') 103 | # softmax output for positive anchor and negative anchor, scale = [None, 200/100, 176/120, 1] 104 | self.p_pos = tf.sigmoid(p_map) 105 | self.output_shape = [cfg.FEATURE_HEIGHT, cfg.FEATURE_WIDTH] 106 | 107 | self.cls_loss = alpha * (-self.pos_equal_one * tf.log(self.p_pos + small_addon_for_BCE)) / self.pos_equal_one_sum \ 108 | + beta * (-self.neg_equal_one * tf.log(1 - self.p_pos + 109 | small_addon_for_BCE)) / self.neg_equal_one_sum 110 | self.cls_loss = tf.reduce_sum(self.cls_loss) 111 | 112 | self.reg_loss = smooth_l1(r_map * self.pos_equal_one_for_reg, self.targets * 113 | self.pos_equal_one_for_reg, sigma) / self.pos_equal_one_sum 114 | self.reg_loss = tf.reduce_sum(self.reg_loss) 115 | 116 | self.loss = tf.reduce_sum(self.cls_loss + self.reg_loss) 117 | 118 | self.delta_output = r_map 119 | self.prob_output = self.p_pos 120 | 121 | 122 | def smooth_l1(deltas, targets, sigma=3.0): 123 | sigma2 = sigma * sigma 124 | diffs = tf.subtract(deltas, targets) 125 | smooth_l1_signs = tf.cast(tf.less(tf.abs(diffs), 1.0 / sigma2), tf.float32) 126 | 127 | smooth_l1_option1 = tf.multiply(diffs, diffs) * 0.5 * sigma2 128 | smooth_l1_option2 = tf.abs(diffs) - 0.5 / sigma2 129 | smooth_l1_add = tf.multiply(smooth_l1_option1, smooth_l1_signs) + \ 130 | tf.multiply(smooth_l1_option2, 1 - smooth_l1_signs) 131 | smooth_l1 = smooth_l1_add 132 | 133 | return smooth_l1 134 | 135 | 136 | def ConvMD(M, Cin, Cout, k, s, p, input, training=True, activation=True, name='conv'): 137 | temp_p = np.array(p) 138 | temp_p = np.lib.pad(temp_p, (1, 1), 'constant', constant_values=(0, 0)) 139 | with tf.variable_scope(name) as scope: 140 | if(M == 2): 141 | paddings = (np.array(temp_p)).repeat(2).reshape(4, 2) 142 | pad = tf.pad(input, paddings, "CONSTANT") 143 | temp_conv = tf.layers.conv2d( 144 | pad, Cout, k, strides=s, padding="valid", reuse=tf.AUTO_REUSE, name=scope) 145 | if(M == 3): 146 | paddings = (np.array(temp_p)).repeat(2).reshape(5, 2) 147 | pad = tf.pad(input, paddings, "CONSTANT") 148 | temp_conv = tf.layers.conv3d( 149 | pad, Cout, k, strides=s, padding="valid", reuse=tf.AUTO_REUSE, name=scope) 150 | temp_conv = 
tf.layers.batch_normalization( 151 | temp_conv, axis=-1, fused=True, training=training, reuse=tf.AUTO_REUSE, name=scope) 152 | if activation: 153 | return tf.nn.relu(temp_conv) 154 | else: 155 | return temp_conv 156 | 157 | def Deconv2D(Cin, Cout, k, s, p, input, training=True, name='deconv'): 158 | temp_p = np.array(p) 159 | temp_p = np.lib.pad(temp_p, (1, 1), 'constant', constant_values=(0, 0)) 160 | paddings = (np.array(temp_p)).repeat(2).reshape(4, 2) 161 | pad = tf.pad(input, paddings, "CONSTANT") 162 | with tf.variable_scope(name) as scope: 163 | temp_conv = tf.layers.conv2d_transpose( 164 | pad, Cout, k, strides=s, padding="SAME", reuse=tf.AUTO_REUSE, name=scope) 165 | temp_conv = tf.layers.batch_normalization( 166 | temp_conv, axis=-1, fused=True, training=training, reuse=tf.AUTO_REUSE, name=scope) 167 | return tf.nn.relu(temp_conv) 168 | 169 | 170 | if(__name__ == "__main__"): 171 | m = MiddleAndRPN(tf.placeholder( 172 | tf.float32, [None, 10, cfg.INPUT_HEIGHT, cfg.INPUT_WIDTH, 128])) 173 | -------------------------------------------------------------------------------- /setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : setup.py 5 | # Purpose : 6 | # Creation Date : 11-12-2017 7 | # Last Modified : Sat 23 Dec 2017 03:18:37 PM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | 11 | from distutils.core import setup 12 | from Cython.Build import cythonize 13 | 14 | setup( 15 | name='box overlaps', 16 | ext_modules=cythonize('./utils/box_overlaps.pyx') 17 | ) 18 | -------------------------------------------------------------------------------- /test.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : train.py 5 | # Purpose : 6 | # Creation Date : 09-12-2017 7 | # Last Modified : Fri 05 Jan 2018 09:35:00 PM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | import glob 11 | import argparse 12 | import os 13 | import time 14 | import tensorflow as tf 15 | 16 | from model import RPN3D 17 | from config import cfg 18 | from utils import * 19 | 20 | 21 | if __name__ == '__main__': 22 | parser = argparse.ArgumentParser(description='testing') 23 | 24 | parser.add_argument('-n', '--tag', type=str, nargs='?', default='default', 25 | help='set log tag') 26 | parser.add_argument('--output-path', type=str, nargs='?', 27 | default='./data/results/data', help='results output dir') 28 | parser.add_argument('-b', '--single-batch-size', type=int, nargs='?', default=1, 29 | help='set batch size for each gpu') 30 | 31 | args = parser.parse_args() 32 | 33 | dataset_dir = './data/object' 34 | save_model_dir = os.path.join('./save_model', args.tag) 35 | 36 | with tf.Graph().as_default(): 37 | with KittiLoader(object_dir=os.path.join(dataset_dir, 'testing_real'), queue_size=100, require_shuffle=False, is_testset=True, batch_size=args.single_batch_size * cfg.GPU_USE_COUNT, use_multi_process_num=8, multi_gpu_sum=cfg.GPU_USE_COUNT) as test_loader: 38 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION, 39 | visible_device_list=cfg.GPU_AVAILABLE, 40 | allow_growth=True) 41 | config = tf.ConfigProto( 42 | gpu_options=gpu_options, 43 | device_count={ 44 | "GPU": cfg.GPU_USE_COUNT, 45 | }, 46 | allow_soft_placement=True, 47 | ) 48 | 49 | with tf.Session(config=config) as sess: 50 | model = RPN3D( 51 | cls=cfg.DETECT_OBJ, 
52 | single_batch_size=args.single_batch_size, 53 | is_train=True, 54 | avail_gpus=cfg.GPU_AVAILABLE.split(',') 55 | ) 56 | if tf.train.get_checkpoint_state(save_model_dir): 57 | print("Reading model parameters from %s" % save_model_dir) 58 | model.saver.restore( 59 | sess, tf.train.latest_checkpoint(save_model_dir)) 60 | while True: 61 | data = test_loader.load() 62 | if data is None: 63 | print('test done.') 64 | break 65 | ret = model.predict_step(sess, data, summary=False) 66 | # ret: A, B 67 | # A: (N) tag 68 | # B: (N, N') (class, x, y, z, h, w, l, rz, score) 69 | for tag, result in zip(*ret): 70 | of_path = os.path.join(args.output_path, tag + '.txt') 71 | with open(of_path, 'w+') as f: 72 | labels = box3d_to_label([result[:, 1:8]], [result[:, 0]], [result[:, -1]], coordinate='lidar')[0] 73 | for line in labels: 74 | f.write(line) 75 | print('write out {} objects to {}'.format(len(labels), tag)) 76 | -------------------------------------------------------------------------------- /train.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : train.py 5 | # Purpose : 6 | # Creation Date : 09-12-2017 7 | # Last Modified : Fri 19 Jan 2018 10:38:47 AM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | import glob 11 | import argparse 12 | import os 13 | import time 14 | import sys 15 | import tensorflow as tf 16 | from itertools import count 17 | 18 | from config import cfg 19 | from model import RPN3D 20 | from utils.kitti_loader import KittiLoader 21 | from train_hook import check_if_should_pause 22 | 23 | parser = argparse.ArgumentParser(description='training') 24 | parser.add_argument('-i', '--max-epoch', type=int, nargs='?', default=10, 25 | help='max epoch') 26 | parser.add_argument('-n', '--tag', type=str, nargs='?', default='default', 27 | help='set log tag') 28 | parser.add_argument('-b', '--single-batch-size', type=int, nargs='?', default=1, 29 | help='set batch size for each gpu') 30 | parser.add_argument('-l', '--lr', type=float, nargs='?', default=0.001, 31 | help='set learning rate') 32 | args = parser.parse_args() 33 | 34 | dataset_dir = './data/object' 35 | log_dir = os.path.join('./log', args.tag) 36 | save_model_dir = os.path.join('./save_model', args.tag) 37 | os.makedirs(log_dir, exist_ok=True) 38 | os.makedirs(save_model_dir, exist_ok=True) 39 | 40 | 41 | def main(_): 42 | # TODO: split file support 43 | with tf.Graph().as_default(): 44 | global save_model_dir 45 | with KittiLoader(object_dir=os.path.join(dataset_dir, 'training'), queue_size=50, require_shuffle=True, 46 | is_testset=False, batch_size=args.single_batch_size * cfg.GPU_USE_COUNT, use_multi_process_num=8, multi_gpu_sum=cfg.GPU_USE_COUNT, aug=True) as train_loader, \ 47 | KittiLoader(object_dir=os.path.join(dataset_dir, 'testing'), queue_size=50, require_shuffle=True, 48 | is_testset=False, batch_size=args.single_batch_size * cfg.GPU_USE_COUNT, use_multi_process_num=8, multi_gpu_sum=cfg.GPU_USE_COUNT, aug=False) as valid_loader: 49 | 50 | gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=cfg.GPU_MEMORY_FRACTION, 51 | visible_device_list=cfg.GPU_AVAILABLE, 52 | allow_growth=True) 53 | config = tf.ConfigProto( 54 | gpu_options=gpu_options, 55 | device_count={ 56 | "GPU": cfg.GPU_USE_COUNT, 57 | }, 58 | allow_soft_placement=True, 59 | ) 60 | with tf.Session(config=config) as sess: 61 | model = RPN3D( 62 | cls=cfg.DETECT_OBJ, 63 | single_batch_size=args.single_batch_size, 64 
| learning_rate=args.lr, 65 | max_gradient_norm=5.0, 66 | is_train=True, 67 | alpha=1.5, 68 | beta=1, 69 | avail_gpus=cfg.GPU_AVAILABLE.split(',') 70 | ) 71 | # param init/restore 72 | if tf.train.get_checkpoint_state(save_model_dir): 73 | print("Reading model parameters from %s" % save_model_dir) 74 | model.saver.restore( 75 | sess, tf.train.latest_checkpoint(save_model_dir)) 76 | else: 77 | print("Created model with fresh parameters.") 78 | tf.global_variables_initializer().run() 79 | 80 | # train and validate 81 | iter_per_epoch = int( 82 | len(train_loader) / (args.single_batch_size * cfg.GPU_USE_COUNT)) 83 | is_summary, is_summary_image, is_validate = False, False, False 84 | 85 | summary_interval = 5 86 | summary_image_interval = 20 87 | save_model_interval = int(iter_per_epoch / 3) 88 | validate_interval = 60 89 | 90 | summary_writer = tf.summary.FileWriter(log_dir, sess.graph) 91 | while model.epoch.eval() < args.max_epoch: 92 | is_summary, is_summary_image, is_validate = False, False, False 93 | iter = model.global_step.eval() 94 | if not iter % summary_interval: 95 | is_summary = True 96 | if not iter % summary_image_interval: 97 | is_summary_image = True 98 | if not iter % save_model_interval: 99 | model.saver.save(sess, os.path.join( 100 | save_model_dir, 'checkpoint'), global_step=model.global_step) 101 | if not iter % validate_interval: 102 | is_validate = True 103 | if not iter % iter_per_epoch: 104 | sess.run(model.epoch_add_op) 105 | print('train {} epoch, total: {}'.format( 106 | model.epoch.eval(), args.max_epoch)) 107 | 108 | ret = model.train_step( 109 | sess, train_loader.load(), train=True, summary=is_summary) 110 | print('train: {}/{} @ epoch:{}/{} loss: {} reg_loss: {} cls_loss: {} {}'.format(iter, 111 | iter_per_epoch * args.max_epoch, model.epoch.eval(), args.max_epoch, ret[0], ret[1], ret[2], args.tag)) 112 | 113 | if is_summary: 114 | summary_writer.add_summary(ret[-1], iter) 115 | 116 | if is_summary_image: 117 | ret = model.predict_step( 118 | sess, valid_loader.load(), summary=True) 119 | summary_writer.add_summary(ret[-1], iter) 120 | 121 | if is_validate: 122 | ret = model.validate_step( 123 | sess, valid_loader.load(), summary=True) 124 | summary_writer.add_summary(ret[-1], iter) 125 | 126 | if check_if_should_pause(args.tag): 127 | model.saver.save(sess, os.path.join( 128 | save_model_dir, 'checkpoint'), global_step=model.global_step) 129 | print('pause and save model @ {} steps:{}'.format( 130 | save_model_dir, model.global_step.eval())) 131 | sys.exit(0) 132 | 133 | print('train done. 
total epoch:{} iter:{}'.format( 134 | model.epoch.eval(), model.global_step.eval())) 135 | 136 | # finallly save model 137 | model.saver.save(sess, os.path.join( 138 | save_model_dir, 'checkpoint'), global_step=model.global_step) 139 | 140 | 141 | if __name__ == '__main__': 142 | tf.app.run(main) 143 | -------------------------------------------------------------------------------- /train_hook.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : train_hook.py 5 | # Purpose : 6 | # Creation Date : 14-12-2017 7 | # Last Modified : Sat 23 Dec 2017 11:45:38 PM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | import os 11 | import argparse 12 | import pickle 13 | import numpy as np 14 | 15 | 16 | def check_if_should_pause(tag): 17 | fname = tag + '.pause.pkl' 18 | ret = False 19 | if os.path.exists(fname): 20 | s = pickle.load(open(tag + '.pause.pkl', 'rb')) 21 | if s == 'pause': 22 | ret = True 23 | os.remove(fname) 24 | return ret 25 | 26 | 27 | def pause_trainer(args): 28 | fname = args.tag + '.pause.pkl' 29 | if os.path.exists(fname): 30 | os.remove(fname) 31 | pickle.dump('pause', open(fname, 'wb')) 32 | 33 | 34 | if __name__ == '__main__': 35 | parser = argparse.ArgumentParser(description='training') 36 | parser.add_argument('--tag', type=str, nargs='?', default='default') 37 | args = parser.parse_args() 38 | 39 | pause_trainer(args) 40 | -------------------------------------------------------------------------------- /utils/__init__.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : __init__.py 5 | # Purpose : 6 | # Creation Date : 21-12-2017 7 | # Last Modified : Fri 19 Jan 2018 10:15:06 AM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | from utils.box_overlaps import * 11 | from utils.colorize import * 12 | from utils.kitti_loader import * 13 | from utils.utils import * 14 | from utils.preprocess import * 15 | from utils.data_aug import * 16 | -------------------------------------------------------------------------------- /utils/box_overlaps.pyx: -------------------------------------------------------------------------------- 1 | # -------------------------------------------------------- 2 | # Fast R-CNN 3 | # Copyright (c) 2015 Microsoft 4 | # Licensed under The MIT License [see LICENSE for details] 5 | # Written by Sergey Karayev 6 | # -------------------------------------------------------- 7 | 8 | import numpy as np 9 | cimport numpy as np 10 | from cython.parallel import prange, parallel 11 | 12 | 13 | DTYPE = np.float32 14 | ctypedef float DTYPE_t 15 | 16 | 17 | def bbox_overlaps( 18 | np.ndarray[DTYPE_t, ndim=2] boxes, 19 | np.ndarray[DTYPE_t, ndim=2] query_boxes): 20 | """ 21 | Parameters 22 | ---------- 23 | boxes: (N, 4) ndarray of float 24 | query_boxes: (K, 4) ndarray of float 25 | Returns 26 | ------- 27 | overlaps: (N, K) ndarray of overlap between boxes and query_boxes 28 | """ 29 | cdef unsigned int N = boxes.shape[0] 30 | cdef unsigned int K = query_boxes.shape[0] 31 | cdef np.ndarray[DTYPE_t, ndim=2] overlaps = np.zeros((N, K), dtype=DTYPE) 32 | cdef DTYPE_t iw, ih, box_area 33 | cdef DTYPE_t ua 34 | cdef unsigned int k, n 35 | for k in range(K): 36 | box_area = ( 37 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) * 38 | (query_boxes[k, 3] - query_boxes[k, 1] + 1) 39 | ) 40 | for n in range(N): 41 | iw = ( 42 | 
min(boxes[n, 2], query_boxes[k, 2]) - 43 | max(boxes[n, 0], query_boxes[k, 0]) + 1 44 | ) 45 | if iw > 0: 46 | ih = ( 47 | min(boxes[n, 3], query_boxes[k, 3]) - 48 | max(boxes[n, 1], query_boxes[k, 1]) + 1 49 | ) 50 | if ih > 0: 51 | ua = float( 52 | (boxes[n, 2] - boxes[n, 0] + 1) * 53 | (boxes[n, 3] - boxes[n, 1] + 1) + 54 | box_area - iw * ih 55 | ) 56 | overlaps[n, k] = iw * ih / ua 57 | return overlaps 58 | 59 | def bbox_intersections( 60 | np.ndarray[DTYPE_t, ndim=2] boxes, 61 | np.ndarray[DTYPE_t, ndim=2] query_boxes): 62 | """ 63 | For each query box compute the intersection ratio covered by boxes 64 | ---------- 65 | Parameters 66 | ---------- 67 | boxes: (N, 4) ndarray of float 68 | query_boxes: (K, 4) ndarray of float 69 | Returns 70 | ------- 71 | overlaps: (N, K) ndarray of intersec between boxes and query_boxes 72 | """ 73 | cdef unsigned int N = boxes.shape[0] 74 | cdef unsigned int K = query_boxes.shape[0] 75 | cdef np.ndarray[DTYPE_t, ndim=2] intersec = np.zeros((N, K), dtype=DTYPE) 76 | cdef DTYPE_t iw, ih, box_area 77 | cdef DTYPE_t ua 78 | cdef unsigned int k, n 79 | for k in range(K): 80 | box_area = ( 81 | (query_boxes[k, 2] - query_boxes[k, 0] + 1) * 82 | (query_boxes[k, 3] - query_boxes[k, 1] + 1) 83 | ) 84 | for n in range(N): 85 | iw = ( 86 | min(boxes[n, 2], query_boxes[k, 2]) - 87 | max(boxes[n, 0], query_boxes[k, 0]) + 1 88 | ) 89 | if iw > 0: 90 | ih = ( 91 | min(boxes[n, 3], query_boxes[k, 3]) - 92 | max(boxes[n, 1], query_boxes[k, 1]) + 1 93 | ) 94 | if ih > 0: 95 | intersec[n, k] = iw * ih / box_area 96 | return intersec 97 | 98 | # Compute bounding box voting 99 | def box_vote( 100 | np.ndarray[float, ndim=2] dets_NMS, 101 | np.ndarray[float, ndim=2] dets_all): 102 | cdef np.ndarray[float, ndim=2] dets_voted = np.zeros((dets_NMS.shape[0], dets_NMS.shape[1]), dtype=np.float32) 103 | cdef unsigned int N = dets_NMS.shape[0] 104 | cdef unsigned int M = dets_all.shape[0] 105 | 106 | cdef np.ndarray[float, ndim=1] det 107 | cdef np.ndarray[float, ndim=1] acc_box 108 | cdef float acc_score 109 | 110 | cdef np.ndarray[float, ndim=1] det2 111 | cdef float bi0, bi1, bit2, bi3 112 | cdef float iw, ih, ua 113 | 114 | cdef float thresh=0.5 115 | 116 | for i in range(N): 117 | det = dets_NMS[i, :] 118 | acc_box = np.zeros((4), dtype=np.float32) 119 | acc_score = 0.0 120 | 121 | for m in range(M): 122 | det2 = dets_all[m, :] 123 | 124 | bi0 = max(det[0], det2[0]) 125 | bi1 = max(det[1], det2[1]) 126 | bi2 = min(det[2], det2[2]) 127 | bi3 = min(det[3], det2[3]) 128 | 129 | iw = bi2 - bi0 + 1 130 | ih = bi3 - bi1 + 1 131 | 132 | if not (iw > 0 and ih > 0): 133 | continue 134 | 135 | ua = (det[2] - det[0] + 1) * (det[3] - det[1] + 1) + (det2[2] - det2[0] + 1) * (det2[3] - det2[1] + 1) - iw * ih 136 | ov = iw * ih / ua 137 | 138 | if (ov < thresh): 139 | continue 140 | 141 | acc_box += det2[4] * det2[0:4] 142 | acc_score += det2[4] 143 | 144 | dets_voted[i][0:4] = acc_box / acc_score 145 | dets_voted[i][4] = det[4] # Keep the original score 146 | 147 | return dets_voted 148 | -------------------------------------------------------------------------------- /utils/colorize.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : colorize.py 5 | # Purpose : 6 | # Creation Date : 21-12-2017 7 | # Last Modified : Thu 21 Dec 2017 09:02:22 PM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | # ref: 
https://gist.github.com/jimfleming/c1adfdb0f526465c99409cc143dea97b 11 | 12 | import matplotlib 13 | import matplotlib.cm 14 | import cv2 15 | import numpy as np 16 | 17 | import tensorflow as tf 18 | 19 | 20 | def colorize(value, factor=1, vmin=None, vmax=None): 21 | """ 22 | A utility function for TensorFlow that maps a grayscale image to a matplotlib 23 | colormap for use with TensorBoard image summaries. 24 | 25 | By default it will normalize the input value to the range 0..1 before mapping 26 | to a grayscale colormap. 27 | 28 | Arguments: 29 | - value: 2D Tensor of shape [height, width] or 3D Tensor of shape 30 | [height, width, 1]. 31 | - factor: resize factor, scalar 32 | - vmin: the minimum value of the range used for normalization. 33 | (Default: value minimum) 34 | - vmax: the maximum value of the range used for normalization. 35 | (Default: value maximum) 36 | 37 | Example usage: 38 | 39 | ``` 40 | output = tf.random_uniform(shape=[256, 256, 1]) 41 | output_color = colorize(output, vmin=0.0, vmax=1.0, cmap='viridis') 42 | tf.summary.image('output', output_color) 43 | ``` 44 | 45 | Returns a 3D tensor of shape [height, width, 3]. 46 | """ 47 | 48 | # normalize 49 | value = np.sum(value, axis=-1) 50 | vmin = np.min(value) if vmin is None else vmin 51 | vmax = np.max(value) if vmax is None else vmax 52 | value = (value - vmin) / (vmax - vmin) # vmin..vmax 53 | 54 | value = (value * 255).astype(np.uint8) 55 | value = cv2.applyColorMap(value, cv2.COLORMAP_JET) 56 | value = cv2.cvtColor(value, cv2.COLOR_BGR2RGB) 57 | x, y, _ = value.shape 58 | value = cv2.resize(value, (y * factor, x * factor)) 59 | 60 | return value 61 | 62 | 63 | def tf_colorize(value, factor=1, vmin=None, vmax=None, cmap=None): 64 | """ 65 | A utility function for TensorFlow that maps a grayscale image to a matplotlib 66 | colormap for use with TensorBoard image summaries. 67 | 68 | By default it will normalize the input value to the range 0..1 before mapping 69 | to a grayscale colormap. 70 | 71 | Arguments: 72 | - value: 2D Tensor of shape [height, width] or 3D Tensor of shape 73 | [height, width, 1]. 74 | - factor: resize factor, scalar 75 | - vmin: the minimum value of the range used for normalization. 76 | (Default: value minimum) 77 | - vmax: the maximum value of the range used for normalization. 78 | (Default: value maximum) 79 | - cmap: a valid cmap named for use with matplotlib's `get_cmap`. 80 | (Default: 'gray') 81 | 82 | Example usage: 83 | 84 | ``` 85 | output = tf.random_uniform(shape=[256, 256, 1]) 86 | output_color = colorize(output, vmin=0.0, vmax=1.0, cmap='viridis') 87 | tf.summary.image('output', output_color) 88 | ``` 89 | 90 | Returns a 3D tensor of shape [height, width, 3]. 
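Note: this TF variant gathers from `cm.colors`, which exists only for listed
colormaps (e.g. 'viridis'); the default 'gray' is a segmented colormap without a
`.colors` attribute, so passing an explicit listed cmap is likely required here.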
91 | """ 92 | 93 | # normalize 94 | vmin = tf.reduce_min(value) if vmin is None else vmin 95 | vmax = tf.reduce_max(value) if vmax is None else vmax 96 | value = (value - vmin) / (vmax - vmin) # vmin..vmax 97 | 98 | # squeeze last dim if it exists 99 | value = tf.squeeze(value) 100 | 101 | # quantize 102 | indices = tf.to_int32(tf.round(value * 255)) 103 | 104 | # gather 105 | cm = matplotlib.cm.get_cmap(cmap if cmap is not None else 'gray') 106 | colors = tf.constant(cm.colors, dtype=tf.float32) 107 | value = tf.gather(colors, indices) 108 | 109 | return value 110 | -------------------------------------------------------------------------------- /utils/data_aug.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : data_aug.py 5 | # Purpose : 6 | # Creation Date : 21-12-2017 7 | # Last Modified : Fri 19 Jan 2018 01:06:35 PM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | import numpy as np 11 | import cv2 12 | import os 13 | import multiprocessing as mp 14 | import argparse 15 | import glob 16 | 17 | from utils.utils import * 18 | from utils.preprocess import * 19 | 20 | object_dir = './data/object' 21 | 22 | 23 | def aug_data(tag, object_dir): 24 | np.random.seed() 25 | rgb = cv2.resize(cv2.imread(os.path.join(object_dir, 26 | 'image_2', tag + '.png')), (cfg.IMAGE_WIDTH, cfg.IMAGE_HEIGHT)) 27 | lidar = np.fromfile(os.path.join(object_dir, 28 | 'velodyne', tag + '.bin'), dtype=np.float32).reshape(-1, 4) 29 | label = np.array([line for line in open(os.path.join( 30 | object_dir, 'label_2', tag + '.txt'), 'r').readlines()]) # (N') 31 | cls = np.array([line.split()[0] for line in label]) # (N') 32 | gt_box3d = label_to_gt_box3d(np.array(label)[np.newaxis, :], cls='', coordinate='camera')[ 33 | 0] # (N', 7) x, y, z, h, w, l, r 34 | 35 | choice = np.random.randint(1, 10) 36 | if choice >= 7: 37 | lidar_center_gt_box3d = camera_to_lidar_box(gt_box3d) 38 | lidar_corner_gt_box3d = center_to_corner_box3d( 39 | lidar_center_gt_box3d, coordinate='lidar') 40 | for idx in range(len(lidar_corner_gt_box3d)): 41 | # TODO: precisely gather the point 42 | is_collision = True 43 | _count = 0 44 | while is_collision and _count < 100: 45 | t_rz = np.random.uniform(-np.pi / 10, np.pi / 10) 46 | t_x = np.random.normal() 47 | t_y = np.random.normal() 48 | t_z = np.random.normal() 49 | # check collision 50 | tmp = box_transform( 51 | lidar_center_gt_box3d[[idx]], t_x, t_y, t_z, t_rz, 'lidar') 52 | is_collision = False 53 | for idy in range(idx): 54 | x1, y1, w1, l1, r1 = tmp[0][[0, 1, 4, 5, 6]] 55 | x2, y2, w2, l2, r2 = lidar_center_gt_box3d[idy][[ 56 | 0, 1, 4, 5, 6]] 57 | iou = cal_iou2d(np.array([x1, y1, w1, l1, r1], dtype=np.float32), 58 | np.array([x2, y2, w2, l2, r2], dtype=np.float32)) 59 | if iou > 0: 60 | is_collision = True 61 | _count += 1 62 | break 63 | if not is_collision: 64 | box_corner = lidar_corner_gt_box3d[idx] 65 | minx = np.min(box_corner[:, 0]) 66 | miny = np.min(box_corner[:, 1]) 67 | minz = np.min(box_corner[:, 2]) 68 | maxx = np.max(box_corner[:, 0]) 69 | maxy = np.max(box_corner[:, 1]) 70 | maxz = np.max(box_corner[:, 2]) 71 | bound_x = np.logical_and( 72 | lidar[:, 0] >= minx, lidar[:, 0] <= maxx) 73 | bound_y = np.logical_and( 74 | lidar[:, 1] >= miny, lidar[:, 1] <= maxy) 75 | bound_z = np.logical_and( 76 | lidar[:, 2] >= minz, lidar[:, 2] <= maxz) 77 | bound_box = np.logical_and( 78 | np.logical_and(bound_x, bound_y), bound_z) 79 | lidar[bound_box, 0:3] = 
point_transform( 80 | lidar[bound_box, 0:3], t_x, t_y, t_z, rz=t_rz) 81 | lidar_center_gt_box3d[idx] = box_transform( 82 | lidar_center_gt_box3d[[idx]], t_x, t_y, t_z, t_rz, 'lidar') 83 | 84 | gt_box3d = lidar_to_camera_box(lidar_center_gt_box3d) 85 | newtag = 'aug_{}_1_{}'.format( 86 | tag, np.random.randint(1, 1024)) 87 | elif choice < 7 and choice >= 4: 88 | # global rotation 89 | angle = np.random.uniform(-np.pi / 4, np.pi / 4) 90 | lidar[:, 0:3] = point_transform(lidar[:, 0:3], 0, 0, 0, rz=angle) 91 | lidar_center_gt_box3d = camera_to_lidar_box(gt_box3d) 92 | lidar_center_gt_box3d = box_transform(lidar_center_gt_box3d, 0, 0, 0, r=angle, coordinate='lidar') 93 | gt_box3d = lidar_to_camera_box(lidar_center_gt_box3d) 94 | newtag = 'aug_{}_2_{:.4f}'.format(tag, angle).replace('.', '_') 95 | else: 96 | # global scaling 97 | factor = np.random.uniform(0.95, 1.05) 98 | lidar[:, 0:3] = lidar[:, 0:3] * factor 99 | lidar_center_gt_box3d = camera_to_lidar_box(gt_box3d) 100 | lidar_center_gt_box3d[:, 0:6] = lidar_center_gt_box3d[:, 0:6] * factor 101 | gt_box3d = lidar_to_camera_box(lidar_center_gt_box3d) 102 | newtag = 'aug_{}_3_{:.4f}'.format(tag, factor).replace('.', '_') 103 | 104 | label = box3d_to_label(gt_box3d[np.newaxis, ...], cls[np.newaxis, ...], coordinate='camera')[0] # (N') 105 | voxel_dict = process_pointcloud(lidar) 106 | return newtag, rgb, lidar, voxel_dict, label 107 | 108 | 109 | def worker(tag): 110 | new_tag, rgb, lidar, voxel_dict, label = aug_data(tag) 111 | output_path = os.path.join(object_dir, 'training_aug') 112 | 113 | cv2.imwrite(os.path.join(output_path, 'image_2', newtag + '.png'), rgb) 114 | lidar.reshape(-1).tofile(os.path.join(output_path, 115 | 'velodyne', newtag + '.bin')) 116 | np.savez_compressed(os.path.join( 117 | output_path, 'voxel' if cfg.DETECT_OBJ == 'Car' else 'voxel_ped', newtag), **voxel_dict) 118 | with open(os.path.join(output_path, 'label_2', newtag + '.txt'), 'w+') as f: 119 | for line in label: 120 | f.write(line) 121 | print(newtag) 122 | 123 | 124 | def main(): 125 | fl = glob.glob(os.path.join(object_dir, 'training', 'calib', '*.txt')) 126 | candidate = [f.split('/')[-1].split('.')[0] for f in fl] 127 | tags = [] 128 | for _ in range(args.aug_amount): 129 | tags.append(candidate[np.random.randint(0, len(candidate))]) 130 | print('generate {} tags'.format(len(tags))) 131 | pool = mp.Pool(args.num_workers) 132 | pool.map(worker, tags) 133 | 134 | 135 | if __name__ == '__main__': 136 | parser = argparse.ArgumentParser(description='') 137 | parser.add_argument('-i', '--aug-amount', type=int, nargs='?', default=1000) 138 | parser.add_argument('-n', '--num-workers', type=int, nargs='?', default=10) 139 | args = parser.parse_args() 140 | 141 | main() 142 | -------------------------------------------------------------------------------- /utils/kitti_loader.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : kitti_loader.py 5 | # Purpose : 6 | # Creation Date : 09-12-2017 7 | # Last Modified : Fri 19 Jan 2018 03:11:15 PM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | import cv2 11 | import numpy as np 12 | import os 13 | import sys 14 | import glob 15 | import threading 16 | import time 17 | import math 18 | import random 19 | from sklearn.utils import shuffle 20 | from multiprocessing import Lock, Process, Queue as Queue, Value, Array, cpu_count 21 | 22 | from config import cfg 23 | from utils.data_aug import aug_data 
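# aug_data performs on-the-fly augmentation; fill_queue() below calls it when the loader is constructed with aug=True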
24 | from utils.preprocess import process_pointcloud 25 | 26 | # for non-raw dataset 27 | 28 | 29 | class KittiLoader(object): 30 | 31 | # return: 32 | # tag (N) 33 | # label (N) (N') 34 | # rgb (N, H, W, C) 35 | # raw_lidar (N) (N', 4) 36 | # vox_feature 37 | # vox_number 38 | # vox_coordinate 39 | 40 | def __init__(self, object_dir='.', queue_size=20, require_shuffle=False, is_testset=True, batch_size=1, use_multi_process_num=0, split_file='', multi_gpu_sum=1, aug=False): 41 | assert(use_multi_process_num >= 0) 42 | self.object_dir = object_dir 43 | self.is_testset = is_testset 44 | self.use_multi_process_num = use_multi_process_num if not self.is_testset else 1 45 | self.require_shuffle = require_shuffle if not self.is_testset else False 46 | self.batch_size = batch_size 47 | self.split_file = split_file 48 | self.multi_gpu_sum = multi_gpu_sum 49 | self.aug = aug 50 | 51 | if self.split_file != '': 52 | # use split file 53 | _tag = [] 54 | self.f_rgb, self.f_lidar, self.f_label = [], [], [] 55 | for line in open(self.split_file, 'r').readlines(): 56 | line = line[:-1] # remove '\n' 57 | _tag.append(line) 58 | self.f_rgb.append(os.path.join( 59 | self.object_dir, 'image_2', line + '.png')) 60 | self.f_lidar.append(os.path.join( 61 | self.object_dir, 'velodyne', line + '.bin')) 62 | self.f_label.append(os.path.join( 63 | self.object_dir, 'label_2', line + '.txt')) 64 | else: 65 | self.f_rgb = glob.glob(os.path.join( 66 | self.object_dir, 'image_2', '*.png')) 67 | self.f_rgb.sort() 68 | self.f_lidar = glob.glob(os.path.join( 69 | self.object_dir, 'velodyne', '*.bin')) 70 | self.f_lidar.sort() 71 | self.f_label = glob.glob(os.path.join( 72 | self.object_dir, 'label_2', '*.txt')) 73 | self.f_label.sort() 74 | 75 | self.data_tag = [name.split('/')[-1].split('.')[-2] 76 | for name in self.f_rgb] 77 | assert(len(self.data_tag) == len(self.f_rgb) == len(self.f_lidar)) 78 | self.dataset_size = len(self.f_rgb) 79 | self.already_extract_data = 0 80 | self.cur_frame_info = '' 81 | 82 | print("Dataset total length: {}".format(self.dataset_size)) 83 | if self.require_shuffle: 84 | self.shuffle_dataset() 85 | 86 | self.queue_size = queue_size 87 | self.require_shuffle = require_shuffle 88 | # must use the queue provided by multiprocessing module(only this can be shared) 89 | self.dataset_queue = Queue() 90 | 91 | self.load_index = 0 92 | if self.use_multi_process_num == 0: 93 | self.loader_worker = [threading.Thread( 94 | target=self.loader_worker_main, args=(self.batch_size,))] 95 | else: 96 | self.loader_worker = [Process(target=self.loader_worker_main, args=( 97 | self.batch_size,)) for i in range(self.use_multi_process_num)] 98 | self.work_exit = Value('i', 0) 99 | [i.start() for i in self.loader_worker] 100 | 101 | # This operation is not thread-safe 102 | self.rgb_shape = (cfg.IMAGE_HEIGHT, cfg.IMAGE_WIDTH, 3) 103 | 104 | def __enter__(self): 105 | return self 106 | 107 | def __exit__(self, exc_type, exc_val, exc_tb): 108 | self.work_exit.value = True 109 | 110 | def __len__(self): 111 | return self.dataset_size 112 | 113 | def fill_queue(self, batch_size=0): 114 | load_index = self.load_index 115 | self.load_index += batch_size 116 | if self.load_index >= self.dataset_size: 117 | if not self.is_testset: # test set just end 118 | if self.require_shuffle: 119 | self.shuffle_dataset() 120 | load_index = 0 121 | self.load_index = load_index + batch_size 122 | else: 123 | self.work_exit.value = True 124 | 125 | labels, tag, voxel, rgb, raw_lidar = [], [], [], [], [] 126 | for _ in range(batch_size): 
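            # each pass loads one frame (an augmented sample from aug_data() when
            # self.aug is set, otherwise the raw image/label/velodyne files) and
            # ends up with a voxel_dict produced by process_pointcloud()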
127 | try: 128 | if self.aug: 129 | ret = aug_data(self.data_tag[load_index], self.object_dir) 130 | tag.append(ret[0]) 131 | rgb.append(ret[1]) 132 | raw_lidar.append(ret[2]) 133 | voxel.append(ret[3]) 134 | labels.append(ret[4]) 135 | else: 136 | rgb.append(cv2.resize(cv2.imread( 137 | self.f_rgb[load_index]), (cfg.IMAGE_WIDTH, cfg.IMAGE_HEIGHT))) 138 | raw_lidar.append(np.fromfile( 139 | self.f_lidar[load_index], dtype=np.float32).reshape((-1, 4))) 140 | if not self.is_testset: 141 | labels.append([line for line in open( 142 | self.f_label[load_index], 'r').readlines()]) 143 | else: 144 | labels.append(['']) 145 | tag.append(self.data_tag[load_index]) 146 | voxel.append(process_pointcloud(raw_lidar[-1])) 147 | 148 | load_index += 1 149 | except: 150 | if not self.is_testset: # test set just end 151 | self.load_index = 0 152 | if self.require_shuffle: 153 | self.shuffle_dataset() 154 | else: 155 | self.work_exit.value = True 156 | 157 | # only for voxel -> [gpu, k_single_batch, ...] 158 | vox_feature, vox_number, vox_coordinate = [], [], [] 159 | single_batch_size = int(self.batch_size / self.multi_gpu_sum) 160 | for idx in range(self.multi_gpu_sum): 161 | _, per_vox_feature, per_vox_number, per_vox_coordinate = build_input( 162 | voxel[idx * single_batch_size:(idx + 1) * single_batch_size]) 163 | vox_feature.append(per_vox_feature) 164 | vox_number.append(per_vox_number) 165 | vox_coordinate.append(per_vox_coordinate) 166 | 167 | self.dataset_queue.put_nowait( 168 | (labels, (vox_feature, vox_number, vox_coordinate), rgb, raw_lidar, tag)) 169 | 170 | def load(self): 171 | try: 172 | if self.is_testset and self.already_extract_data >= self.dataset_size: 173 | return None 174 | 175 | buff = self.dataset_queue.get() 176 | label = buff[0] 177 | vox_feature = buff[1][0] 178 | vox_number = buff[1][1] 179 | vox_coordinate = buff[1][2] 180 | rgb = buff[2] 181 | raw_lidar = buff[3] 182 | tag = buff[4] 183 | self.cur_frame_info = buff[4] 184 | 185 | self.already_extract_data += self.batch_size 186 | 187 | ret = ( 188 | np.array(tag), 189 | np.array(label), 190 | np.array(vox_feature), 191 | np.array(vox_number), 192 | np.array(vox_coordinate), 193 | np.array(rgb), 194 | np.array(raw_lidar) 195 | 196 | ) 197 | except: 198 | print("Dataset empty!") 199 | ret = None 200 | return ret 201 | 202 | def load_specified(self, index=0): 203 | rgb = cv2.resize(cv2.imread( 204 | self.f_rgb[index]), (cfg.IMAGE_WIDTH, cfg.IMAGE_HEIGHT)) 205 | raw_lidar = np.fromfile( 206 | self.f_lidar[index], dtype=np.float32).reshape((-1, 4)) 207 | labels = [line for line in open(self.f_label[index], 'r').readlines()] 208 | tag = self.data_tag[index] 209 | 210 | if self.is_testset: 211 | ret = ( 212 | np.array([tag]), 213 | np.array([rgb]), 214 | np.array([raw_lidar]), 215 | ) 216 | else: 217 | ret = ( 218 | np.array([tag]), 219 | np.array([labels]), 220 | np.array([rgb]), 221 | np.array([raw_lidar]), 222 | ) 223 | return ret 224 | 225 | def loader_worker_main(self, batch_size): 226 | if self.require_shuffle: 227 | self.shuffle_dataset() 228 | while not self.work_exit.value: 229 | if self.dataset_queue.qsize() >= self.queue_size // 2: 230 | time.sleep(1) 231 | else: 232 | # since we use multiprocessing, 1 is ok 233 | self.fill_queue(batch_size) 234 | 235 | def get_shape(self): 236 | return self.rgb_shape 237 | 238 | def shuffle_dataset(self): 239 | # to prevent diff loader load same data 240 | index = shuffle([i for i in range(len(self.f_label))], 241 | random_state=random.randint(0, self.use_multi_process_num**5)) 242 | 
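        # apply the same permutation to every per-frame list so labels, images,
        # point clouds and tags stay aligned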
self.f_label = [self.f_label[i] for i in index] 243 | self.f_rgb = [self.f_rgb[i] for i in index] 244 | self.f_lidar = [self.f_lidar[i] for i in index] 245 | self.data_tag = [self.data_tag[i] for i in index] 246 | 247 | def get_frame_info(self): 248 | return self.cur_frame_info 249 | 250 | 251 | def build_input(voxel_dict_list): 252 | batch_size = len(voxel_dict_list) 253 | 254 | feature_list = [] 255 | number_list = [] 256 | coordinate_list = [] 257 | for i, voxel_dict in zip(range(batch_size), voxel_dict_list): 258 | feature_list.append(voxel_dict['feature_buffer']) 259 | number_list.append(voxel_dict['number_buffer']) 260 | coordinate = voxel_dict['coordinate_buffer'] 261 | coordinate_list.append( 262 | np.pad(coordinate, ((0, 0), (1, 0)), 263 | mode='constant', constant_values=i)) 264 | 265 | feature = np.concatenate(feature_list) 266 | number = np.concatenate(number_list) 267 | coordinate = np.concatenate(coordinate_list) 268 | return batch_size, feature, number, coordinate 269 | 270 | 271 | if __name__ == '__main__': 272 | pass 273 | -------------------------------------------------------------------------------- /utils/preprocess.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : preprocess.py 5 | # Purpose : 6 | # Creation Date : 10-12-2017 7 | # Last Modified : Thu 18 Jan 2018 05:34:42 PM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | import os 11 | import multiprocessing 12 | import numpy as np 13 | 14 | from config import cfg 15 | 16 | data_dir = 'velodyne' 17 | 18 | def process_pointcloud(point_cloud, cls=cfg.DETECT_OBJ): 19 | # Input: 20 | # (N, 4) 21 | # Output: 22 | # voxel_dict 23 | if cls == 'Car': 24 | scene_size = np.array([4, 80, 70.4], dtype=np.float32) 25 | voxel_size = np.array([0.4, 0.2, 0.2], dtype=np.float32) 26 | grid_size = np.array([10, 400, 352], dtype=np.int64) 27 | lidar_coord = np.array([0, 40, 3], dtype=np.float32) 28 | max_point_number = 35 29 | else: 30 | scene_size = np.array([4, 40, 48], dtype=np.float32) 31 | voxel_size = np.array([0.4, 0.2, 0.2], dtype=np.float32) 32 | grid_size = np.array([10, 200, 240], dtype=np.int64) 33 | lidar_coord = np.array([0, 20, 3], dtype=np.float32) 34 | max_point_number = 45 35 | 36 | np.random.shuffle(point_cloud) 37 | 38 | shifted_coord = point_cloud[:, :3] + lidar_coord 39 | # reverse the point cloud coordinate (X, Y, Z) -> (Z, Y, X) 40 | voxel_index = np.floor( 41 | shifted_coord[:, ::-1] / voxel_size).astype(np.int) 42 | 43 | bound_x = np.logical_and( 44 | voxel_index[:, 2] >= 0, voxel_index[:, 2] < grid_size[2]) 45 | bound_y = np.logical_and( 46 | voxel_index[:, 1] >= 0, voxel_index[:, 1] < grid_size[1]) 47 | bound_z = np.logical_and( 48 | voxel_index[:, 0] >= 0, voxel_index[:, 0] < grid_size[0]) 49 | 50 | bound_box = np.logical_and(np.logical_and(bound_x, bound_y), bound_z) 51 | 52 | point_cloud = point_cloud[bound_box] 53 | voxel_index = voxel_index[bound_box] 54 | 55 | # [K, 3] coordinate buffer as described in the paper 56 | coordinate_buffer = np.unique(voxel_index, axis=0) 57 | 58 | K = len(coordinate_buffer) 59 | T = max_point_number 60 | 61 | # [K, 1] store number of points in each voxel grid 62 | number_buffer = np.zeros(shape=(K), dtype=np.int64) 63 | 64 | # [K, T, 7] feature buffer as described in the paper 65 | feature_buffer = np.zeros(shape=(K, T, 7), dtype=np.float32) 66 | 67 | # build a reverse index for coordinate buffer 68 | index_buffer = {} 69 | for i in range(K): 
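        # index_buffer maps a voxel's (z, y, x) coordinate tuple to its row in the
        # K-voxel feature/number buffers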
70 | index_buffer[tuple(coordinate_buffer[i])] = i 71 | 72 | for voxel, point in zip(voxel_index, point_cloud): 73 | index = index_buffer[tuple(voxel)] 74 | number = number_buffer[index] 75 | if number < T: 76 | feature_buffer[index, number, :4] = point 77 | number_buffer[index] += 1 78 | 79 | feature_buffer[:, :, -3:] = feature_buffer[:, :, :3] - \ 80 | feature_buffer[:, :, :3].sum(axis=1, keepdims=True)/number_buffer.reshape(K, 1, 1) 81 | 82 | voxel_dict = {'feature_buffer': feature_buffer, 83 | 'coordinate_buffer': coordinate_buffer, 84 | 'number_buffer': number_buffer} 85 | return voxel_dict 86 | 87 | 88 | def worker(filelist): 89 | for file in filelist: 90 | point_cloud = np.fromfile( 91 | os.path.join(data_dir, file), dtype=np.float32).reshape(-1, 4) 92 | 93 | name, extension = os.path.splitext(file) 94 | voxel_dict = process_pointcloud(point_cloud) 95 | output_dir = 'voxel' if cfg.DETECT_OBJ == 'Car' else 'voxel_ped' 96 | np.savez_compressed(os.path.join(output_dir, name), **voxel_dict) 97 | 98 | 99 | if __name__ == '__main__': 100 | filelist = [f for f in os.listdir(data_dir) if f.endswith('bin')] 101 | num_worker = 8 102 | for sublist in np.array_split(filelist, num_worker): 103 | p = multiprocessing.Process(target=worker, args=(sublist,)) 104 | p.start() 105 | -------------------------------------------------------------------------------- /utils/setup.py: -------------------------------------------------------------------------------- 1 | #!/usr/bin/env python 2 | # -*- coding:UTF-8 -*- 3 | 4 | # File Name : setup.py 5 | # Purpose : 6 | # Creation Date : 11-12-2017 7 | # Last Modified : Sat 23 Dec 2017 03:19:46 PM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | 11 | from distutils.core import setup 12 | from Cython.Build import cythonize 13 | 14 | setup( 15 | name='box overlaps', 16 | ext_modules=cythonize('./utils/box_overlaps.pyx') 17 | ) 18 | -------------------------------------------------------------------------------- /utils/utils.py: -------------------------------------------------------------------------------- 1 | 2 | # -*- cooing:UTF-8 -*- 3 | 4 | # File Name : utils.py 5 | # Purpose : 6 | # Creation Date : 09-12-2017 7 | # Last Modified : Thu 08 Mar 2018 02:30:56 PM CST 8 | # Created By : Jeasine Ma [jeasinema[at]gmail[dot]com] 9 | 10 | import cv2 11 | import numpy as np 12 | import shapely.geometry 13 | import shapely.affinity 14 | import math 15 | from numba import jit 16 | 17 | from config import cfg 18 | from utils.box_overlaps import * 19 | 20 | 21 | def lidar_to_bird_view(x, y, factor=1): 22 | # using the cfg.INPUT_XXX 23 | a = (x - cfg.X_MIN) / cfg.VOXEL_X_SIZE * factor 24 | b = (y - cfg.Y_MIN) / cfg.VOXEL_Y_SIZE * factor 25 | a = np.clip(a, a_max=(cfg.X_MAX - cfg.X_MIN) / cfg.VOXEL_X_SIZE * factor, a_min=0) 26 | b = np.clip(b, a_max=(cfg.Y_MAX - cfg.Y_MIN) / cfg.VOXEL_Y_SIZE * factor, a_min=0) 27 | return a, b 28 | 29 | def batch_lidar_to_bird_view(points, factor=1): 30 | # Input: 31 | # points (N, 2) 32 | # Outputs: 33 | # points (N, 2) 34 | # using the cfg.INPUT_XXX 35 | a = (points[:, 0] - cfg.X_MIN) / cfg.VOXEL_X_SIZE * factor 36 | b = (points[:, 1] - cfg.Y_MIN) / cfg.VOXEL_Y_SIZE * factor 37 | a = np.clip(a, a_max=(cfg.X_MAX - cfg.X_MIN) / cfg.VOXEL_X_SIZE * factor, a_min=0) 38 | b = np.clip(b, a_max=(cfg.Y_MAX - cfg.Y_MIN) / cfg.VOXEL_Y_SIZE * factor, a_min=0) 39 | return np.concatenate([a[:, np.newaxis], b[:, np.newaxis]], axis=-1) 40 | 41 | 42 | def angle_in_limit(angle): 43 | # To limit the angle in -pi/2 - pi/2 44 | 
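    # (values within limit_degree of -pi/2 are snapped to +pi/2 below, presumably so
    # the two equivalent boundary orientations share a single representation)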
limit_degree = 5 45 | while angle >= np.pi / 2: 46 | angle -= np.pi 47 | while angle < -np.pi / 2: 48 | angle += np.pi 49 | if abs(angle + np.pi / 2) < limit_degree / 180 * np.pi: 50 | angle = np.pi / 2 51 | return angle 52 | 53 | 54 | def camera_to_lidar(x, y, z): 55 | p = np.array([x, y, z, 1]) 56 | p = np.matmul(np.linalg.inv(np.array(cfg.MATRIX_R_RECT_0)), p) 57 | p = np.matmul(np.linalg.inv(np.array(cfg.MATRIX_T_VELO_2_CAM)), p) 58 | p = p[0:3] 59 | return tuple(p) 60 | 61 | 62 | def lidar_to_camera(x, y, z): 63 | p = np.array([x, y, z, 1]) 64 | p = np.matmul(np.array(cfg.MATRIX_T_VELO_2_CAM), p) 65 | p = np.matmul(np.array(cfg.MATRIX_R_RECT_0), p) 66 | p = p[0:3] 67 | return tuple(p) 68 | 69 | 70 | def camera_to_lidar_point(points): 71 | # (N, 3) -> (N, 3) 72 | N = points.shape[0] 73 | points = np.hstack([points, np.ones((N, 1))]).T # (N,4) -> (4,N) 74 | 75 | points = np.matmul(np.linalg.inv(np.array(cfg.MATRIX_R_RECT_0)), points) 76 | points = np.matmul(np.linalg.inv( 77 | np.array(cfg.MATRIX_T_VELO_2_CAM)), points).T # (4, N) -> (N, 4) 78 | points = points[:, 0:3] 79 | return points.reshape(-1, 3) 80 | 81 | 82 | def lidar_to_camera_point(points): 83 | # (N, 3) -> (N, 3) 84 | N = points.shape[0] 85 | points = np.hstack([points, np.ones((N, 1))]).T 86 | 87 | points = np.matmul(np.array(cfg.MATRIX_T_VELO_2_CAM), points) 88 | points = np.matmul(np.array(cfg.MATRIX_R_RECT_0), points).T 89 | points = points[:, 0:3] 90 | return points.reshape(-1, 3) 91 | 92 | 93 | def camera_to_lidar_box(boxes): 94 | # (N, 7) -> (N, 7) x,y,z,h,w,l,r 95 | ret = [] 96 | for box in boxes: 97 | x, y, z, h, w, l, ry = box 98 | (x, y, z), h, w, l, rz = camera_to_lidar( 99 | x, y, z), h, w, l, -ry - np.pi / 2 100 | rz = angle_in_limit(rz) 101 | ret.append([x, y, z, h, w, l, rz]) 102 | return np.array(ret).reshape(-1, 7) 103 | 104 | 105 | def lidar_to_camera_box(boxes): 106 | # (N, 7) -> (N, 7) x,y,z,h,w,l,r 107 | ret = [] 108 | for box in boxes: 109 | x, y, z, h, w, l, rz = box 110 | (x, y, z), h, w, l, ry = lidar_to_camera( 111 | x, y, z), h, w, l, -rz - np.pi / 2 112 | ry = angle_in_limit(ry) 113 | ret.append([x, y, z, h, w, l, ry]) 114 | return np.array(ret).reshape(-1, 7) 115 | 116 | 117 | def center_to_corner_box2d(boxes_center, coordinate='lidar'): 118 | # (N, 5) -> (N, 4, 2) 119 | N = boxes_center.shape[0] 120 | boxes3d_center = np.zeros((N, 7)) 121 | boxes3d_center[:, [0, 1, 4, 5, 6]] = boxes_center 122 | boxes3d_corner = center_to_corner_box3d( 123 | boxes3d_center, coordinate=coordinate) 124 | 125 | return boxes3d_corner[:, 0:4, 0:2] 126 | 127 | 128 | def center_to_corner_box3d(boxes_center, coordinate='lidar'): 129 | # (N, 7) -> (N, 8, 3) 130 | N = boxes_center.shape[0] 131 | ret = np.zeros((N, 8, 3), dtype=np.float32) 132 | 133 | if coordinate == 'camera': 134 | boxes_center = camera_to_lidar_box(boxes_center) 135 | 136 | for i in range(N): 137 | box = boxes_center[i] 138 | translation = box[0:3] 139 | size = box[3:6] 140 | rotation = [0, 0, box[-1]] 141 | 142 | h, w, l = size[0], size[1], size[2] 143 | trackletBox = np.array([ # in velodyne coordinates around zero point and without orientation yet 144 | [-l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2], \ 145 | [w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2], \ 146 | [0, 0, 0, 0, h, h, h, h]]) 147 | 148 | # re-create 3D bounding box in velodyne coordinate system 149 | yaw = rotation[2] 150 | rotMat = np.array([ 151 | [np.cos(yaw), -np.sin(yaw), 0.0], 152 | [np.sin(yaw), np.cos(yaw), 0.0], 153 | [0.0, 0.0, 1.0]]) 154 | 
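        # rotate the canonical corner template about the z-axis by yaw, then translate
        # the corners to the box centre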
cornerPosInVelo = np.dot(rotMat, trackletBox) + \ 155 | np.tile(translation, (8, 1)).T 156 | box3d = cornerPosInVelo.transpose() 157 | ret[i] = box3d 158 | 159 | if coordinate == 'camera': 160 | for idx in range(len(ret)): 161 | ret[idx] = lidar_to_camera_point(ret[idx]) 162 | 163 | return ret 164 | 165 | 166 | def corner_to_center_box2d(boxes_corner, coordinate='lidar'): 167 | # (N, 4, 2) -> (N, 5) x,y,w,l,r 168 | N = boxes_corner.shape[0] 169 | boxes3d_corner = np.zeros((N, 8, 3)) 170 | boxes3d_corner[:, 0:4, 0:2] = boxes_corner 171 | boxes3d_corner[:, 4:8, 0:2] = boxes_corner 172 | boxes3d_center = corner_to_center_box3d( 173 | boxes3d_corner, coordinate=coordinate) 174 | 175 | return boxes3d_center[:, [0, 1, 4, 5, 6]] 176 | 177 | 178 | def corner_to_standup_box2d(boxes_corner): 179 | # (N, 4, 2) -> (N, 4) x1, y1, x2, y2 180 | N = boxes_corner.shape[0] 181 | standup_boxes2d = np.zeros((N, 4)) 182 | standup_boxes2d[:, 0] = np.min(boxes_corner[:, :, 0], axis=1) 183 | standup_boxes2d[:, 1] = np.min(boxes_corner[:, :, 1], axis=1) 184 | standup_boxes2d[:, 2] = np.max(boxes_corner[:, :, 0], axis=1) 185 | standup_boxes2d[:, 3] = np.max(boxes_corner[:, :, 1], axis=1) 186 | 187 | return standup_boxes2d 188 | 189 | 190 | # TODO: 0/90 may be not correct 191 | def anchor_to_standup_box2d(anchors): 192 | # (N, 4) -> (N, 4) x,y,w,l -> x1,y1,x2,y2 193 | anchor_standup = np.zeros_like(anchors) 194 | # r == 0 195 | anchor_standup[::2, 0] = anchors[::2, 0] - anchors[::2, 3] / 2 196 | anchor_standup[::2, 1] = anchors[::2, 1] - anchors[::2, 2] / 2 197 | anchor_standup[::2, 2] = anchors[::2, 0] + anchors[::2, 3] / 2 198 | anchor_standup[::2, 3] = anchors[::2, 1] + anchors[::2, 2] / 2 199 | # r == pi/2 200 | anchor_standup[1::2, 0] = anchors[1::2, 0] - anchors[1::2, 2] / 2 201 | anchor_standup[1::2, 1] = anchors[1::2, 1] - anchors[1::2, 3] / 2 202 | anchor_standup[1::2, 2] = anchors[1::2, 0] + anchors[1::2, 2] / 2 203 | anchor_standup[1::2, 3] = anchors[1::2, 1] + anchors[1::2, 3] / 2 204 | 205 | return anchor_standup 206 | 207 | 208 | def corner_to_center_box3d(boxes_corner, coordinate='camera'): 209 | # (N, 8, 3) -> (N, 7) x,y,z,h,w,l,ry/z 210 | if coordinate == 'lidar': 211 | for idx in range(len(boxes_corner)): 212 | boxes_corner[idx] = lidar_to_camera_point(boxes_corner[idx]) 213 | ret = [] 214 | for roi in boxes_corner: 215 | if cfg.CORNER2CENTER_AVG: # average version 216 | roi = np.array(roi) 217 | h = abs(np.sum(roi[:4, 1] - roi[4:, 1]) / 4) 218 | w = np.sum( 219 | np.sqrt(np.sum((roi[0, [0, 2]] - roi[3, [0, 2]])**2)) + 220 | np.sqrt(np.sum((roi[1, [0, 2]] - roi[2, [0, 2]])**2)) + 221 | np.sqrt(np.sum((roi[4, [0, 2]] - roi[7, [0, 2]])**2)) + 222 | np.sqrt(np.sum((roi[5, [0, 2]] - roi[6, [0, 2]])**2)) 223 | ) / 4 224 | l = np.sum( 225 | np.sqrt(np.sum((roi[0, [0, 2]] - roi[1, [0, 2]])**2)) + 226 | np.sqrt(np.sum((roi[2, [0, 2]] - roi[3, [0, 2]])**2)) + 227 | np.sqrt(np.sum((roi[4, [0, 2]] - roi[5, [0, 2]])**2)) + 228 | np.sqrt(np.sum((roi[6, [0, 2]] - roi[7, [0, 2]])**2)) 229 | ) / 4 230 | x = np.sum(roi[:, 0], axis=0) / 8 231 | y = np.sum(roi[0:4, 1], axis=0) / 4 232 | z = np.sum(roi[:, 2], axis=0) / 8 233 | ry = np.sum( 234 | math.atan2(roi[2, 0] - roi[1, 0], roi[2, 2] - roi[1, 2]) + 235 | math.atan2(roi[6, 0] - roi[5, 0], roi[6, 2] - roi[5, 2]) + 236 | math.atan2(roi[3, 0] - roi[0, 0], roi[3, 2] - roi[0, 2]) + 237 | math.atan2(roi[7, 0] - roi[4, 0], roi[7, 2] - roi[4, 2]) + 238 | math.atan2(roi[0, 2] - roi[1, 2], roi[1, 0] - roi[0, 0]) + 239 | math.atan2(roi[4, 2] - roi[5, 2], roi[5, 0] - roi[4, 
0]) + 240 | math.atan2(roi[3, 2] - roi[2, 2], roi[2, 0] - roi[3, 0]) + 241 | math.atan2(roi[7, 2] - roi[6, 2], roi[6, 0] - roi[7, 0]) 242 | ) / 8 243 | if w > l: 244 | w, l = l, w 245 | ry = angle_in_limit(ry + np.pi / 2) 246 | else: # max version 247 | h = max(abs(roi[:4, 1] - roi[4:, 1])) 248 | w = np.max( 249 | np.sqrt(np.sum((roi[0, [0, 2]] - roi[3, [0, 2]])**2)) + 250 | np.sqrt(np.sum((roi[1, [0, 2]] - roi[2, [0, 2]])**2)) + 251 | np.sqrt(np.sum((roi[4, [0, 2]] - roi[7, [0, 2]])**2)) + 252 | np.sqrt(np.sum((roi[5, [0, 2]] - roi[6, [0, 2]])**2)) 253 | ) 254 | l = np.max( 255 | np.sqrt(np.sum((roi[0, [0, 2]] - roi[1, [0, 2]])**2)) + 256 | np.sqrt(np.sum((roi[2, [0, 2]] - roi[3, [0, 2]])**2)) + 257 | np.sqrt(np.sum((roi[4, [0, 2]] - roi[5, [0, 2]])**2)) + 258 | np.sqrt(np.sum((roi[6, [0, 2]] - roi[7, [0, 2]])**2)) 259 | ) 260 | x = np.sum(roi[:, 0], axis=0) / 8 261 | y = np.sum(roi[0:4, 1], axis=0) / 4 262 | z = np.sum(roi[:, 2], axis=0) / 8 263 | ry = np.sum( 264 | math.atan2(roi[2, 0] - roi[1, 0], roi[2, 2] - roi[1, 2]) + 265 | math.atan2(roi[6, 0] - roi[5, 0], roi[6, 2] - roi[5, 2]) + 266 | math.atan2(roi[3, 0] - roi[0, 0], roi[3, 2] - roi[0, 2]) + 267 | math.atan2(roi[7, 0] - roi[4, 0], roi[7, 2] - roi[4, 2]) + 268 | math.atan2(roi[0, 2] - roi[1, 2], roi[1, 0] - roi[0, 0]) + 269 | math.atan2(roi[4, 2] - roi[5, 2], roi[5, 0] - roi[4, 0]) + 270 | math.atan2(roi[3, 2] - roi[2, 2], roi[2, 0] - roi[3, 0]) + 271 | math.atan2(roi[7, 2] - roi[6, 2], roi[6, 0] - roi[7, 0]) 272 | ) / 8 273 | if w > l: 274 | w, l = l, w 275 | ry = angle_in_limit(ry + np.pi / 2) 276 | ret.append([x, y, z, h, w, l, ry]) 277 | if coordinate == 'lidar': 278 | ret = camera_to_lidar_box(np.array(ret)) 279 | 280 | return np.array(ret) 281 | 282 | 283 | # this just for visulize and testing 284 | def lidar_box3d_to_camera_box(boxes3d, cal_projection=False): 285 | # (N, 7) -> (N, 4)/(N, 8, 2) x,y,z,h,w,l,rz -> x1,y1,x2,y2/8*(x, y) 286 | num = len(boxes3d) 287 | boxes2d = np.zeros((num, 4), dtype=np.int32) 288 | projections = np.zeros((num, 8, 2), dtype=np.float32) 289 | 290 | lidar_boxes3d_corner = center_to_corner_box3d(boxes3d, coordinate='lidar') 291 | P2 = np.array(cfg.MATRIX_P2) 292 | 293 | for n in range(num): 294 | box3d = lidar_boxes3d_corner[n] 295 | box3d = lidar_to_camera_point(box3d) 296 | points = np.hstack((box3d, np.ones((8, 1)))).T # (8, 4) -> (4, 8) 297 | points = np.matmul(P2, points).T 298 | points[:, 0] /= points[:, 2] 299 | points[:, 1] /= points[:, 2] 300 | 301 | projections[n] = points[:, 0:2] 302 | minx = int(np.min(points[:, 0])) 303 | maxx = int(np.max(points[:, 0])) 304 | miny = int(np.min(points[:, 1])) 305 | maxy = int(np.max(points[:, 1])) 306 | 307 | boxes2d[n, :] = minx, miny, maxx, maxy 308 | 309 | return projections if cal_projection else boxes2d 310 | 311 | 312 | def lidar_to_bird_view_img(lidar, factor=1): 313 | # Input: 314 | # lidar: (N', 4) 315 | # Output: 316 | # birdview: (w, l, 3) 317 | birdview = np.zeros( 318 | (cfg.INPUT_HEIGHT * factor, cfg.INPUT_WIDTH * factor, 1)) 319 | for point in lidar: 320 | x, y = point[0:2] 321 | if cfg.X_MIN < x < cfg.X_MAX and cfg.Y_MIN < y < cfg.Y_MAX: 322 | x, y = int((x - cfg.X_MIN) / cfg.VOXEL_X_SIZE * 323 | factor), int((y - cfg.Y_MIN) / cfg.VOXEL_Y_SIZE * factor) 324 | birdview[y, x] += 1 325 | birdview = birdview - np.min(birdview) 326 | divisor = np.max(birdview) - np.min(birdview) 327 | # TODO: adjust this factor 328 | birdview = np.clip((birdview / divisor * 255) * 329 | 5 * factor, a_min=0, a_max=255) 330 | birdview = np.tile(birdview, 
3).astype(np.uint8) 331 | 332 | return birdview 333 | 334 | 335 | def draw_lidar_box3d_on_image(img, boxes3d, scores, gt_boxes3d=np.array([]), 336 | color=(0, 255, 255), gt_color=(255, 0, 255), thickness=1): 337 | # Input: 338 | # img: (h, w, 3) 339 | # boxes3d (N, 7) [x, y, z, h, w, l, r] 340 | # scores 341 | # gt_boxes3d (N, 7) [x, y, z, h, w, l, r] 342 | img = img.copy() 343 | projections = lidar_box3d_to_camera_box(boxes3d, cal_projection=True) 344 | gt_projections = lidar_box3d_to_camera_box(gt_boxes3d, cal_projection=True) 345 | 346 | # draw projections 347 | for qs in projections: 348 | for k in range(0, 4): 349 | i, j = k, (k + 1) % 4 350 | cv2.line(img, (qs[i, 0], qs[i, 1]), (qs[j, 0], 351 | qs[j, 1]), color, thickness, cv2.LINE_AA) 352 | 353 | i, j = k + 4, (k + 1) % 4 + 4 354 | cv2.line(img, (qs[i, 0], qs[i, 1]), (qs[j, 0], 355 | qs[j, 1]), color, thickness, cv2.LINE_AA) 356 | 357 | i, j = k, k + 4 358 | cv2.line(img, (qs[i, 0], qs[i, 1]), (qs[j, 0], 359 | qs[j, 1]), color, thickness, cv2.LINE_AA) 360 | 361 | # draw gt projections 362 | for qs in gt_projections: 363 | for k in range(0, 4): 364 | i, j = k, (k + 1) % 4 365 | cv2.line(img, (qs[i, 0], qs[i, 1]), (qs[j, 0], 366 | qs[j, 1]), gt_color, thickness, cv2.LINE_AA) 367 | 368 | i, j = k + 4, (k + 1) % 4 + 4 369 | cv2.line(img, (qs[i, 0], qs[i, 1]), (qs[j, 0], 370 | qs[j, 1]), gt_color, thickness, cv2.LINE_AA) 371 | 372 | i, j = k, k + 4 373 | cv2.line(img, (qs[i, 0], qs[i, 1]), (qs[j, 0], 374 | qs[j, 1]), gt_color, thickness, cv2.LINE_AA) 375 | 376 | return cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB) 377 | 378 | 379 | def draw_lidar_box3d_on_birdview(birdview, boxes3d, scores, gt_boxes3d=np.array([]), 380 | color=(0, 255, 255), gt_color=(255, 0, 255), thickness=1, factor=1): 381 | # Input: 382 | # birdview: (h, w, 3) 383 | # boxes3d (N, 7) [x, y, z, h, w, l, r] 384 | # scores 385 | # gt_boxes3d (N, 7) [x, y, z, h, w, l, r] 386 | img = birdview.copy() 387 | corner_boxes3d = center_to_corner_box3d(boxes3d, coordinate='lidar') 388 | corner_gt_boxes3d = center_to_corner_box3d(gt_boxes3d, coordinate='lidar') 389 | # draw gt 390 | for box in corner_gt_boxes3d: 391 | x0, y0 = lidar_to_bird_view(*box[0, 0:2], factor=factor) 392 | x1, y1 = lidar_to_bird_view(*box[1, 0:2], factor=factor) 393 | x2, y2 = lidar_to_bird_view(*box[2, 0:2], factor=factor) 394 | x3, y3 = lidar_to_bird_view(*box[3, 0:2], factor=factor) 395 | 396 | cv2.line(img, (int(x0), int(y0)), (int(x1), int(y1)), 397 | gt_color, thickness, cv2.LINE_AA) 398 | cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), 399 | gt_color, thickness, cv2.LINE_AA) 400 | cv2.line(img, (int(x2), int(y2)), (int(x3), int(y3)), 401 | gt_color, thickness, cv2.LINE_AA) 402 | cv2.line(img, (int(x3), int(y3)), (int(x0), int(y0)), 403 | gt_color, thickness, cv2.LINE_AA) 404 | 405 | # draw detections 406 | for box in corner_boxes3d: 407 | x0, y0 = lidar_to_bird_view(*box[0, 0:2], factor=factor) 408 | x1, y1 = lidar_to_bird_view(*box[1, 0:2], factor=factor) 409 | x2, y2 = lidar_to_bird_view(*box[2, 0:2], factor=factor) 410 | x3, y3 = lidar_to_bird_view(*box[3, 0:2], factor=factor) 411 | 412 | cv2.line(img, (int(x0), int(y0)), (int(x1), int(y1)), 413 | color, thickness, cv2.LINE_AA) 414 | cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), 415 | color, thickness, cv2.LINE_AA) 416 | cv2.line(img, (int(x2), int(y2)), (int(x3), int(y3)), 417 | color, thickness, cv2.LINE_AA) 418 | cv2.line(img, (int(x3), int(y3)), (int(x0), int(y0)), 419 | color, thickness, cv2.LINE_AA) 420 | 421 | return 
cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_BGR2RGB) 422 | 423 | 424 | def label_to_gt_box3d(labels, cls='Car', coordinate='camera'): 425 | # Input: 426 | # label: (N, N') 427 | # cls: 'Car' or 'Pedestrain' or 'Cyclist' 428 | # coordinate: 'camera' or 'lidar' 429 | # Output: 430 | # (N, N', 7) 431 | boxes3d = [] 432 | if cls == 'Car': 433 | acc_cls = ['Car', 'Van'] 434 | elif cls == 'Pedestrian': 435 | acc_cls = ['Pedestrian'] 436 | elif cls == 'Cyclist': 437 | acc_cls = ['Cyclist'] 438 | else: # all 439 | acc_cls = [] 440 | 441 | for label in labels: 442 | boxes3d_a_label = [] 443 | for line in label: 444 | ret = line.split() 445 | if ret[0] in acc_cls or acc_cls == []: 446 | h, w, l, x, y, z, r = [float(i) for i in ret[-7:]] 447 | box3d = np.array([x, y, z, h, w, l, r]) 448 | boxes3d_a_label.append(box3d) 449 | if coordinate == 'lidar': 450 | boxes3d_a_label = camera_to_lidar_box(np.array(boxes3d_a_label)) 451 | 452 | boxes3d.append(np.array(boxes3d_a_label).reshape(-1, 7)) 453 | return boxes3d 454 | 455 | 456 | def box3d_to_label(batch_box3d, batch_cls, batch_score=[], coordinate='camera'): 457 | # Input: 458 | # (N, N', 7) x y z h w l r 459 | # (N, N') 460 | # cls: (N, N') 'Car' or 'Pedestrain' or 'Cyclist' 461 | # coordinate(input): 'camera' or 'lidar' 462 | # Output: 463 | # label: (N, N') N batches and N lines 464 | batch_label = [] 465 | if batch_score: 466 | template = '{} ' + ' '.join(['{:.4f}' for i in range(15)]) + '\n' 467 | for boxes, scores, clses in zip(batch_box3d, batch_score, batch_cls): 468 | label = [] 469 | for box, score, cls in zip(boxes, scores, clses): 470 | if coordinate == 'camera': 471 | box3d = box 472 | box2d = lidar_box3d_to_camera_box( 473 | camera_to_lidar_box(box[np.newaxis, :].astype(np.float32)), cal_projection=False)[0] 474 | else: 475 | box3d = lidar_to_camera_box( 476 | box[np.newaxis, :].astype(np.float32))[0] 477 | box2d = lidar_box3d_to_camera_box( 478 | box[np.newaxis, :].astype(np.float32), cal_projection=False)[0] 479 | x, y, z, h, w, l, r = box3d 480 | box3d = [h, w, l, x, y, z, r] 481 | label.append(template.format( 482 | cls, 0, 0, 0, *box2d, *box3d, float(score))) 483 | batch_label.append(label) 484 | else: 485 | template = '{} ' + ' '.join(['{:.4f}' for i in range(14)]) + '\n' 486 | for boxes, clses in zip(batch_box3d, batch_cls): 487 | label = [] 488 | for box, cls in zip(boxes, clses): 489 | if coordinate == 'camera': 490 | box3d = box 491 | box2d = lidar_box3d_to_camera_box( 492 | camera_to_lidar_box(box[np.newaxis, :].astype(np.float32)), cal_projection=False)[0] 493 | else: 494 | box3d = lidar_to_camera_box( 495 | box[np.newaxis, :].astype(np.float32))[0] 496 | box2d = lidar_box3d_to_camera_box( 497 | box[np.newaxis, :].astype(np.float32), cal_projection=False)[0] 498 | x, y, z, h, w, l, r = box3d 499 | box3d = [h, w, l, x, y, z, r] 500 | label.append(template.format(cls, 0, 0, 0, *box2d, *box3d)) 501 | batch_label.append(label) 502 | 503 | return np.array(batch_label) 504 | 505 | 506 | def cal_anchors(): 507 | # Output: 508 | # anchors: (w, l, 2, 7) x y z h w l r 509 | x = np.linspace(cfg.X_MIN, cfg.X_MAX, cfg.FEATURE_WIDTH) 510 | y = np.linspace(cfg.Y_MIN, cfg.Y_MAX, cfg.FEATURE_HEIGHT) 511 | cx, cy = np.meshgrid(x, y) 512 | # all is (w, l, 2) 513 | cx = np.tile(cx[..., np.newaxis], 2) 514 | cy = np.tile(cy[..., np.newaxis], 2) 515 | cz = np.ones_like(cx) * cfg.ANCHOR_Z 516 | w = np.ones_like(cx) * cfg.ANCHOR_W 517 | l = np.ones_like(cx) * cfg.ANCHOR_L 518 | h = np.ones_like(cx) * cfg.ANCHOR_H 519 | r = np.ones_like(cx) 520 | 
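    # two anchors per feature-map location, identical except for yaw (0 and pi/2, set below)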
r[..., 0] = 0 # 0 521 | r[..., 1] = 90 / 180 * np.pi # 90 522 | 523 | # 7*(w,l,2) -> (w, l, 2, 7) 524 | anchors = np.stack([cx, cy, cz, h, w, l, r], axis=-1) 525 | 526 | return anchors 527 | 528 | 529 | def cal_rpn_target(labels, feature_map_shape, anchors, cls='Car', coordinate='lidar'): 530 | # Input: 531 | # labels: (N, N') 532 | # feature_map_shape: (w, l) 533 | # anchors: (w, l, 2, 7) 534 | # Output: 535 | # pos_equal_one (N, w, l, 2) 536 | # neg_equal_one (N, w, l, 2) 537 | # targets (N, w, l, 14) 538 | # attention: cal IoU on birdview 539 | batch_size = labels.shape[0] 540 | batch_gt_boxes3d = label_to_gt_box3d(labels, cls=cls, coordinate=coordinate) 541 | # defined in eq(1) in 2.2 542 | anchors_reshaped = anchors.reshape(-1, 7) 543 | anchors_d = np.sqrt(anchors_reshaped[:, 4]**2 + anchors_reshaped[:, 5]**2) 544 | pos_equal_one = np.zeros((batch_size, *feature_map_shape, 2)) 545 | neg_equal_one = np.zeros((batch_size, *feature_map_shape, 2)) 546 | targets = np.zeros((batch_size, *feature_map_shape, 14)) 547 | 548 | for batch_id in range(batch_size): 549 | # BOTTLENECK 550 | anchors_standup_2d = anchor_to_standup_box2d( 551 | anchors_reshaped[:, [0, 1, 4, 5]]) 552 | # BOTTLENECK 553 | gt_standup_2d = corner_to_standup_box2d(center_to_corner_box2d( 554 | batch_gt_boxes3d[batch_id][:, [0, 1, 4, 5, 6]], coordinate=coordinate)) 555 | 556 | iou = bbox_overlaps( 557 | np.ascontiguousarray(anchors_standup_2d).astype(np.float32), 558 | np.ascontiguousarray(gt_standup_2d).astype(np.float32), 559 | ) 560 | # iou = cal_box3d_iou( 561 | # anchors_reshaped, 562 | # batch_gt_boxes3d[batch_id] 563 | # ) 564 | 565 | # find anchor with highest iou(iou should also > 0) 566 | id_highest = np.argmax(iou.T, axis=1) 567 | id_highest_gt = np.arange(iou.T.shape[0]) 568 | mask = iou.T[id_highest_gt, id_highest] > 0 569 | id_highest, id_highest_gt = id_highest[mask], id_highest_gt[mask] 570 | 571 | # find anchor iou > cfg.XXX_POS_IOU 572 | id_pos, id_pos_gt = np.where(iou > cfg.RPN_POS_IOU) 573 | 574 | # find anchor iou < cfg.XXX_NEG_IOU 575 | id_neg = np.where(np.sum(iou < cfg.RPN_NEG_IOU, 576 | axis=1) == iou.shape[1])[0] 577 | 578 | id_pos = np.concatenate([id_pos, id_highest]) 579 | id_pos_gt = np.concatenate([id_pos_gt, id_highest_gt]) 580 | 581 | # TODO: uniquify the array in a more scientific way 582 | id_pos, index = np.unique(id_pos, return_index=True) 583 | id_pos_gt = id_pos_gt[index] 584 | id_neg.sort() 585 | 586 | # cal the target and set the equal one 587 | index_x, index_y, index_z = np.unravel_index( 588 | id_pos, (*feature_map_shape, 2)) 589 | pos_equal_one[batch_id, index_x, index_y, index_z] = 1 590 | 591 | # ATTENTION: index_z should be np.array 592 | targets[batch_id, index_x, index_y, np.array(index_z) * 7] = ( 593 | batch_gt_boxes3d[batch_id][id_pos_gt, 0] - anchors_reshaped[id_pos, 0]) / anchors_d[id_pos] 594 | targets[batch_id, index_x, index_y, np.array(index_z) * 7 + 1] = ( 595 | batch_gt_boxes3d[batch_id][id_pos_gt, 1] - anchors_reshaped[id_pos, 1]) / anchors_d[id_pos] 596 | targets[batch_id, index_x, index_y, np.array(index_z) * 7 + 2] = ( 597 | batch_gt_boxes3d[batch_id][id_pos_gt, 2] - anchors_reshaped[id_pos, 2]) / cfg.ANCHOR_H 598 | targets[batch_id, index_x, index_y, np.array(index_z) * 7 + 3] = np.log( 599 | batch_gt_boxes3d[batch_id][id_pos_gt, 3] / anchors_reshaped[id_pos, 3]) 600 | targets[batch_id, index_x, index_y, np.array(index_z) * 7 + 4] = np.log( 601 | batch_gt_boxes3d[batch_id][id_pos_gt, 4] / anchors_reshaped[id_pos, 4]) 602 | targets[batch_id, index_x, index_y, 
np.array(index_z) * 7 + 5] = np.log( 603 | batch_gt_boxes3d[batch_id][id_pos_gt, 5] / anchors_reshaped[id_pos, 5]) 604 | targets[batch_id, index_x, index_y, np.array(index_z) * 7 + 6] = ( 605 | batch_gt_boxes3d[batch_id][id_pos_gt, 6] - anchors_reshaped[id_pos, 6]) 606 | 607 | index_x, index_y, index_z = np.unravel_index( 608 | id_neg, (*feature_map_shape, 2)) 609 | neg_equal_one[batch_id, index_x, index_y, index_z] = 1 610 | # to avoid a box be pos/neg in the same time 611 | index_x, index_y, index_z = np.unravel_index( 612 | id_highest, (*feature_map_shape, 2)) 613 | neg_equal_one[batch_id, index_x, index_y, index_z] = 0 614 | 615 | return pos_equal_one, neg_equal_one, targets 616 | 617 | 618 | # BOTTLENECK 619 | def delta_to_boxes3d(deltas, anchors, coordinate='lidar'): 620 | # Input: 621 | # deltas: (N, w, l, 14) 622 | # feature_map_shape: (w, l) 623 | # anchors: (w, l, 2, 7) 624 | 625 | # Ouput: 626 | # boxes3d: (N, w*l*2, 7) 627 | anchors_reshaped = anchors.reshape(-1, 7) 628 | deltas = deltas.reshape(deltas.shape[0], -1, 7) 629 | anchors_d = np.sqrt(anchors_reshaped[:, 4]**2 + anchors_reshaped[:, 5]**2) 630 | boxes3d = np.zeros_like(deltas) 631 | boxes3d[..., [0, 1]] = deltas[..., [0, 1]] * \ 632 | anchors_d[:, np.newaxis] + anchors_reshaped[..., [0, 1]] 633 | boxes3d[..., [2]] = deltas[..., [2]] * \ 634 | cfg.ANCHOR_H + anchors_reshaped[..., [2]] 635 | boxes3d[..., [3, 4, 5]] = np.exp( 636 | deltas[..., [3, 4, 5]]) * anchors_reshaped[..., [3, 4, 5]] 637 | boxes3d[..., 6] = deltas[..., 6] + anchors_reshaped[..., 6] 638 | 639 | return boxes3d 640 | 641 | 642 | def point_transform(points, tx, ty, tz, rx=0, ry=0, rz=0): 643 | # Input: 644 | # points: (N, 3) 645 | # rx/y/z: in radians 646 | # Output: 647 | # points: (N, 3) 648 | N = points.shape[0] 649 | points = np.hstack([points, np.ones((N, 1))]) 650 | 651 | mat1 = np.eye(4) 652 | mat1[3, 0:3] = tx, ty, tz 653 | points = np.matmul(points, mat1) 654 | 655 | if rx != 0: 656 | mat = np.zeros((4, 4)) 657 | mat[0, 0] = 1 658 | mat[3, 3] = 1 659 | mat[1, 1] = np.cos(rx) 660 | mat[1, 2] = -np.sin(rx) 661 | mat[2, 1] = np.sin(rx) 662 | mat[2, 2] = np.cos(rx) 663 | points = np.matmul(points, mat) 664 | 665 | if ry != 0: 666 | mat = np.zeros((4, 4)) 667 | mat[1, 1] = 1 668 | mat[3, 3] = 1 669 | mat[0, 0] = np.cos(ry) 670 | mat[0, 2] = np.sin(ry) 671 | mat[2, 0] = -np.sin(ry) 672 | mat[2, 2] = np.cos(ry) 673 | points = np.matmul(points, mat) 674 | 675 | if rz != 0: 676 | mat = np.zeros((4, 4)) 677 | mat[2, 2] = 1 678 | mat[3, 3] = 1 679 | mat[0, 0] = np.cos(rz) 680 | mat[0, 1] = -np.sin(rz) 681 | mat[1, 0] = np.sin(rz) 682 | mat[1, 1] = np.cos(rz) 683 | points = np.matmul(points, mat) 684 | 685 | return points[:, 0:3] 686 | 687 | 688 | def box_transform(boxes, tx, ty, tz, r=0, coordinate='lidar'): 689 | # Input: 690 | # boxes: (N, 7) x y z h w l rz/y 691 | # Output: 692 | # boxes: (N, 7) x y z h w l rz/y 693 | boxes_corner = center_to_corner_box3d( 694 | boxes, coordinate=coordinate) # (N, 8, 3) 695 | for idx in range(len(boxes_corner)): 696 | if coordinate == 'lidar': 697 | boxes_corner[idx] = point_transform( 698 | boxes_corner[idx], tx, ty, tz, rz=r) 699 | else: 700 | boxes_corner[idx] = point_transform( 701 | boxes_corner[idx], tx, ty, tz, ry=r) 702 | 703 | return corner_to_center_box3d(boxes_corner, coordinate=coordinate) 704 | 705 | 706 | def cal_iou2d(box1, box2): 707 | # Input: 708 | # box1/2: x, y, w, l, r 709 | # Output : 710 | # iou 711 | buf1 = np.zeros((cfg.INPUT_HEIGHT, cfg.INPUT_WIDTH, 3)) 712 | buf2 = np.zeros((cfg.INPUT_HEIGHT, 
cfg.INPUT_WIDTH, 3)) 713 | tmp = center_to_corner_box2d(np.array([box1, box2]), coordinate='lidar') 714 | box1_corner = batch_lidar_to_bird_view(tmp[0]).astype(np.int32) 715 | box2_corner = batch_lidar_to_bird_view(tmp[1]).astype(np.int32) 716 | buf1 = cv2.fillConvexPoly(buf1, box1_corner, color=(1,1,1))[..., 0] 717 | buf2 = cv2.fillConvexPoly(buf2, box2_corner, color=(1,1,1))[..., 0] 718 | indiv = np.sum(np.absolute(buf1-buf2)) 719 | share = np.sum((buf1 + buf2) == 2) 720 | if indiv == 0: 721 | return 0.0 # when target is out of bound 722 | return share / (indiv + share) 723 | 724 | def cal_z_intersect(cz1, h1, cz2, h2): 725 | b1z1, b1z2 = cz1 - h1 / 2, cz1 + h1 / 2 726 | b2z1, b2z2 = cz2 - h2 / 2, cz2 + h2 / 2 727 | if b1z1 > b2z2 or b2z1 > b1z2: 728 | return 0 729 | elif b2z1 <= b1z1 <= b2z2: 730 | if b1z2 <= b2z2: 731 | return h1 / h2 732 | else: 733 | return (b2z2 - b1z1) / (b1z2 - b2z1) 734 | elif b1z1 < b2z1 < b1z2: 735 | if b2z2 <= b1z2: 736 | return h2 / h1 737 | else: 738 | return (b1z2 - b2z1) / (b2z2 - b1z1) 739 | 740 | 741 | def cal_iou3d(box1, box2): 742 | # Input: 743 | # box1/2: x, y, z, h, w, l, r 744 | # Output: 745 | # iou 746 | buf1 = np.zeros((cfg.INPUT_HEIGHT, cfg.INPUT_WIDTH, 3)) 747 | buf2 = np.zeros((cfg.INPUT_HEIGHT, cfg.INPUT_WIDTH, 3)) 748 | tmp = center_to_corner_box2d(np.array([box1[[0,1,4,5,6]], box2[[0,1,4,5,6]]]), coordinate='lidar') 749 | box1_corner = batch_lidar_to_bird_view(tmp[0]).astype(np.int32) 750 | box2_corner = batch_lidar_to_bird_view(tmp[1]).astype(np.int32) 751 | buf1 = cv2.fillConvexPoly(buf1, box1_corner, color=(1,1,1))[..., 0] 752 | buf2 = cv2.fillConvexPoly(buf2, box2_corner, color=(1,1,1))[..., 0] 753 | share = np.sum((buf1 + buf2) == 2) 754 | area1 = np.sum(buf1) 755 | area2 = np.sum(buf2) 756 | 757 | z1, h1, z2, h2 = box1[2], box1[3], box2[2], box2[3] 758 | z_intersect = cal_z_intersect(z1, h1, z2, h2) 759 | 760 | return share * z_intersect / (area1 * h1 + area2 * h2 - share * z_intersect) 761 | 762 | 763 | def cal_box3d_iou(boxes3d, gt_boxes3d, cal_3d=0): 764 | # Inputs: 765 | # boxes3d: (N1, 7) x,y,z,h,w,l,r 766 | # gt_boxed3d: (N2, 7) x,y,z,h,w,l,r 767 | # Outputs: 768 | # iou: (N1, N2) 769 | N1 = len(boxes3d) 770 | N2 = len(gt_boxes3d) 771 | output = np.zeros((N1, N2), dtype=np.float32) 772 | 773 | for idx in range(N1): 774 | for idy in range(N2): 775 | if cal_3d: 776 | output[idx, idy] = float( 777 | cal_iou3d(boxes3d[idx], gt_boxes3d[idy])) 778 | else: 779 | output[idx, idy] = float( 780 | cal_iou2d(boxes3d[idx, [0, 1, 4, 5, 6]], gt_boxes3d[idy, [0, 1, 4, 5, 6]])) 781 | 782 | return output 783 | 784 | 785 | def cal_box2d_iou(boxes2d, gt_boxes2d): 786 | # Inputs: 787 | # boxes2d: (N1, 5) x,y,w,l,r 788 | # gt_boxes2d: (N2, 5) x,y,w,l,r 789 | # Outputs: 790 | # iou: (N1, N2) 791 | N1 = len(boxes2d) 792 | N2 = len(gt_boxes2d) 793 | output = np.zeros((N1, N2), dtype=np.float32) 794 | for idx in range(N1): 795 | for idy in range(N2): 796 | output[idx, idy] = cal_iou2d(boxes2d[idx], gt_boxes2d[idy]) 797 | 798 | return output 799 | 800 | 801 | if __name__ == '__main__': 802 | pass 803 | --------------------------------------------------------------------------------
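As a worked illustration of the box-delta parameterization that `cal_rpn_target` writes into `targets` and `delta_to_boxes3d` inverts, here is a minimal standalone sketch in plain NumPy. The anchor and ground-truth values are made up for the example, and the anchor's own height stands in for `cfg.ANCHOR_H` (the two coincide, since `cal_anchors` gives every anchor the height `cfg.ANCHOR_H`).

```python
import numpy as np


def encode_box(gt, anchor):
    # delta encoding used by cal_rpn_target (eq. (1) in section 2.2 of the paper):
    # x/y are normalized by the anchor's birdview diagonal, z by the anchor height,
    # sizes are log-ratios and the yaw is a plain difference
    x, y, z, h, w, l, r = gt
    xa, ya, za, ha, wa, la, ra = anchor
    d = np.sqrt(wa ** 2 + la ** 2)
    return np.array([(x - xa) / d, (y - ya) / d, (z - za) / ha,
                     np.log(h / ha), np.log(w / wa), np.log(l / la), r - ra])


def decode_box(delta, anchor):
    # inverse transform, mirroring delta_to_boxes3d
    xa, ya, za, ha, wa, la, ra = anchor
    d = np.sqrt(wa ** 2 + la ** 2)
    dx, dy, dz, dh, dw, dl, dr = delta
    return np.array([dx * d + xa, dy * d + ya, dz * ha + za,
                     ha * np.exp(dh), wa * np.exp(dw), la * np.exp(dl), dr + ra])


if __name__ == '__main__':
    anchor = np.array([10.0, 2.0, -1.0, 1.56, 1.6, 3.9, 0.0])  # made-up car anchor (x y z h w l r)
    gt = np.array([11.2, 2.5, -0.8, 1.50, 1.7, 4.1, 0.3])      # made-up ground-truth box
    assert np.allclose(decode_box(encode_box(gt, anchor), anchor), gt)
    print(encode_box(gt, anchor))
```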