├── Images
│   ├── pr_curve.png
│   └── detections.png
├── __pycache__
│   ├── PIXOR.cpython-36.pyc
│   └── evaluate_model.cpython-36.pyc
├── .idea
│   ├── vcs.xml
│   ├── misc.xml
│   ├── modules.xml
│   ├── PIXOR.iml
│   ├── dictionaries
│   │   └── MatsSteinweg.xml
│   └── workspace.xml
├── .gitignore
├── config.py
├── visualize_training.py
├── visualize_evaluation.py
├── README.md
├── detector.py
├── visualize_dataset.py
├── train_model.py
├── PIXOR.py
├── load_data.py
├── kitti_utils.py
└── evaluate_model.py
/Images/pr_curve.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matssteinweg/PIXOR/HEAD/Images/pr_curve.png
--------------------------------------------------------------------------------
/Images/detections.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matssteinweg/PIXOR/HEAD/Images/detections.png
--------------------------------------------------------------------------------
/__pycache__/PIXOR.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matssteinweg/PIXOR/HEAD/__pycache__/PIXOR.cpython-36.pyc
--------------------------------------------------------------------------------
/__pycache__/evaluate_model.cpython-36.pyc:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matssteinweg/PIXOR/HEAD/__pycache__/evaluate_model.cpython-36.pyc
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | .idea
2 | *.txt
3 | *.bin
4 | Data
5 | Images/Detections
6 | *.pt
7 | *.npz
8 | .DS_Store
9 | __pycache__/config.cpython-36.pyc
10 | __pycache__/early_stopping.cpython-36.pyc
11 | __pycache__/kitti_utils.cpython-36.pyc
12 | __pycache__/load_data.cpython-36.pyc
13 | __pycache__/PIXOR_Net.cpython-36.pyc
14 | __pycache__/show_gt.cpython-36.pyc
15 | Metrics/.DS_Store
16 | Models/.DS_Store
17 |
--------------------------------------------------------------------------------
/.idea/dictionaries/MatsSteinweg.xml:
--------------------------------------------------------------------------------
1 |
2 |
3 |
4 | calib
5 | cuda
6 | dataset
7 | datasets
8 | ious
9 | lidar
10 | optim
11 | overla
12 | pixor
13 | resized
14 | thresholded
15 | unsqueeze
16 | velodyne
17 | voxel
18 | voxelization
19 | voxelize
20 | voxelized
21 | voxels
22 |
23 |
24 |
--------------------------------------------------------------------------------
/config.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 |
3 | ###########################
4 | # project-level constants #
5 | ###########################
6 |
7 | # observable area in m in velodyne coordinates
8 | VOX_Y_MIN = -40
9 | VOX_Y_MAX = +40
10 | VOX_X_MIN = 0
11 | VOX_X_MAX = 70
12 | VOX_Z_MIN = -2.5
13 | VOX_Z_MAX = 1.0
14 |
15 | # voxel size in m (used to convert from metric to voxel coordinates)
16 | VOX_X_DIVISION = 0.1
17 | VOX_Y_DIVISION = 0.1
18 | VOX_Z_DIVISION = 0.1
19 |
20 | # dimensionality of network input (voxelized point cloud)
21 | INPUT_DIM_0 = int((VOX_X_MAX-VOX_X_MIN) // VOX_X_DIVISION) + 1
22 | INPUT_DIM_1 = int((VOX_Y_MAX-VOX_Y_MIN) // VOX_Y_DIVISION) + 1
23 | # + 1 for average reflectance value of the points in the respective voxel
24 | INPUT_DIM_2 = int((VOX_Z_MAX-VOX_Z_MIN) // VOX_Z_DIVISION) + 1 + 1
25 |
26 | # dimensionality of network output
27 | OUTPUT_DIM_0 = INPUT_DIM_0 // 4
28 | OUTPUT_DIM_1 = INPUT_DIM_1 // 4
29 | OUTPUT_DIM_REG = 6
30 | OUTPUT_DIM_CLA = 1
31 |
32 | # mean and std for normalization of the regression targets
33 | REG_MEAN = np.array([-0.01518276, -0.0626486, -0.05025632, -0.05040792, 0.49188597, 1.36500531])
34 | REG_STD = np.array([0.46370442, 0.88364181, 0.70925018, 1.0590797, 0.06251486, 0.10906765])
35 |
--------------------------------------------------------------------------------
/visualize_training.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use('TkAgg')
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | def plot_history(metrics):
8 | """
9 | Plot evolution of the training and validation loss over the training period.
10 | :param metrics: dictionary containing training and validation loss
11 | """
12 |
13 | fig, axs = plt.subplots(1, 2)
14 | plt.subplots_adjust(top=0.75, bottom=0.25, wspace=0.4)
15 |
16 | train_loss = metrics['train_loss']
17 | val_loss = metrics['val_loss']
18 |
19 | batch_size = 6
20 | epochs = batch_size * (len(train_loss) + len(val_loss)) // 6481  # estimate number of epochs from logged batches (6481 training frames)
21 |
22 | for ax_id, ax in enumerate(axs):
23 | ax.set_xlabel('Epochs')
24 | ax.set_ylabel('Loss')
25 | ax.set_yscale('log')
26 | ax.grid(True)
27 |
28 | if ax_id == 0:
29 | ax.set_ylim([0.5 * min(train_loss), max(train_loss)])
30 | ax.set_xlim([0.0, len(train_loss)])
31 | ax.set_title('Training Loss')
32 | step_size = 3
33 | ticks = np.arange(0, len(train_loss), step_size * len(train_loss) // (epochs))
34 | labels = np.arange(1, epochs + 1, step_size)
35 | ax.set_xticks(ticks)
36 | ax.set_xticklabels(labels)
37 | ax.plot(train_loss)
38 | else:
39 | ax.set_ylim([0.5 * min(val_loss), max(val_loss)])
40 | ax.set_xlim([0.0, len(val_loss)])
41 | ax.set_title('Validation Loss')
42 | step_size = 3
43 | ticks = np.arange(0, len(val_loss), step_size * len(val_loss) // (epochs))
44 | labels = np.arange(1, epochs + 1, step_size)
45 | ax.set_xticks(ticks)
46 | ax.set_xticklabels(labels)
47 | ax.plot(val_loss)
48 |
49 | plt.show()
50 |
51 |
52 | if __name__ == '__main__':
53 |
54 | metrics = np.load('Metrics/metrics_17.npz', allow_pickle=True)['history'].item()
55 | plot_history(metrics)
56 |
57 |
--------------------------------------------------------------------------------
/visualize_evaluation.py:
--------------------------------------------------------------------------------
1 | import numpy as np
2 | import matplotlib
3 | matplotlib.use('TkAgg')
4 | import matplotlib.pyplot as plt
5 |
6 |
7 | def plot_precision_recall_curve(eval_dicts):
8 | """
9 | Plot precision-recall curves using all available metrics stored in the evaluation dictionaries.
10 | :param eval_dicts: list of evaluation dictionaries
11 | """
12 |
13 | # set up figure for precision-recall curve
14 | # one subplot for each distance range evaluated
15 | max_distance_ranges = np.max([len([key for key in eval_dict.keys() if not isinstance(key, str)][::-1]) for eval_dict in eval_dicts])
16 | fig, axs = plt.subplots(1, max_distance_ranges)
17 | for ax in axs:
18 | ax.set_xlabel('Recall')
19 | ax.set_ylabel('Precision')
20 | ax.set_ylim([0.0, 1.05])
21 | ax.set_xlim([0.0, 1.05])
22 | ax.grid(True)
23 | ax.set_aspect(0.8)
24 | colors = [(80/255, 127/255, 255/255), 'k', 'r', 'g']
25 | description = ['-', '--', '-.', ':']
26 | plt.subplots_adjust(wspace=0.4)
27 |
28 | # iterate over all provided evaluation dictionaries
29 | for eval_id, eval_dict in enumerate(eval_dicts):
30 |
31 | distance_ranges = [key for key in eval_dict.keys() if not isinstance(key, str)][::-1]
32 |
33 | # iterate over all evaluated distance ranges
34 | for range_id, distance_range in enumerate(distance_ranges):
35 |
36 | distance_dict = eval_dict[distance_range]
37 |
38 | # store plots for legend
39 | plots = []
40 |
41 | # get IoU thresholds used for evaluation
42 | thresholds = [key for key in distance_dict.keys() if not isinstance(key, str)]
43 |
44 | # loop over all IoU-thresholds
45 | plot = None
46 | for threshold in thresholds:
47 | recall = distance_dict[threshold]['recall']
48 | precision = distance_dict[threshold]['precision']
49 | plot = axs[range_id].plot(recall, precision, linewidth=1, linestyle=description[eval_id],
50 | color=colors[eval_id], label='mAP[0.5:0.9] = {0:0.2f}'.format(distance_dict['mAP']))[0]
51 | plots.append(plot)
52 |
53 | # set subplot title and display legend
54 | axs[range_id].set_title('Range 0-{0:d}m'.format(distance_range))
55 | labels = [plot.get_label() for plot in plots]
56 | axs[range_id].legend(plots, labels, loc='lower right')
57 |
58 | plt.show()
59 |
60 |
61 | if __name__ == '__main__':
62 |
63 | eval_dict = np.load('Eval/eval_dict_epoch_17.npz', allow_pickle=True)['eval_dict'].item()
64 |
65 | eval_dicts = [eval_dict]
66 | plot_precision_recall_curve(eval_dicts)
67 |
68 |
69 |
70 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # PIXOR: Real-time 3D Object Detection from Point Clouds
2 | ## Unofficial PyTorch Implementation
3 |
4 | In this repository you'll find an unofficial PyTorch implementation of PIXOR. I implemented this project to gain some experience working with 3D object detection and to familiarize myself with the *KITTI dataset* used for training and evaluating the model. The vast majority of the code in this repository was written by me based on my interpretation of the [original paper](https://arxiv.org/pdf/1902.06326.pdf). Parts of the helper functions for loading and displaying data in ```kitti_utils.py``` are inspired by the [kitti_object_vis](https://github.com/kuixu/kitti_object_vis) repository.
5 |
6 |
7 |
8 |
9 |
10 |
11 | ### Requirements
12 |
13 | The project is built on a small set of libraries. Post-processing is predominantly performed in NumPy, with SciPy used for some vectorized operations. Shapely is required for computing bounding box IoUs. All image operations are performed with OpenCV (imported as ```cv2```, installed via the ```opencv-python``` package), while matplotlib is used for plotting the training history and evaluation graphs.
14 | ```
15 | torch 1.1.0
16 | shapely
17 | scipy
18 | cv2
19 | matplotlib
20 | numpy
21 | ```
22 |
23 | ### Dataset
24 |
25 | The KITTI dataset is among the most widely used datasets for the development of computer vision and sensor fusion applications. It can be downloaded free of charge from the [official website](http://www.cvlibs.net/datasets/kitti/eval_object.php?obj_benchmark=bev). For this project, only the LiDAR point clouds are used by the model. However, I also downloaded the camera images to get a better idea of the car's surroundings during dataset exploration, as well as the camera calibration matrices and the bounding box annotations.
26 | The project requires the dataset to be stored according to the following structure:
27 | ```
28 | Data
29 | training
30 | velodyne
31 | calib
32 | image_2
33 | label_2
34 | testing
35 | velodyne
36 | calib
37 | image_2
38 | label_2
39 | ```
40 | The official training set consists of 7481 annotated frames. An official test set of 7518 frames exists as well; however, its labels are only available for submissions associated with a published paper. Consequently, I had to rely on a portion of the original training set for the evaluation of my model: I split the available data into a training set of 6481 frames and a test set of 1000 frames. During training, the training set is further divided into training and validation subsets using a 90/10 split.
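With this structure in place, individual frames can be inspected using the small helper class in ```visualize_dataset.py```. A minimal sketch (frame index 0 chosen arbitrarily):
```
from visualize_dataset import KittiObject

dataset = KittiObject(root_dir='Data/', split='training')
point_cloud = dataset.get_lidar(0)        # (N, 4) array: x, y, z, reflectance
labels = dataset.get_label_objects(0)     # list of kitti_utils.Object3D annotations
calib = dataset.get_calibration(0)        # kitti_utils.Calibration object
image = dataset.get_image(0)              # camera image as loaded by OpenCV (BGR)
print(point_cloud.shape, len(labels), image.shape)
```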
41 |
42 | ### Implementation Details
43 |
44 | The implementation and evaluation of the model aim to reproduce the approach taken in the original paper, i.e. the basic implementation details provided there. Modifications discussed in the ablation study are not included.
45 | A difference worth mentioning is that the model is trained as a binary classifier. Instead of considering all classes available in the KITTI dataset (Car, Van, Truck, Pedestrian, Person (sitting), Cyclist, Tram and Misc), the model is trained only on annotations of the class Car. This is expected to be extended in the future.
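As a quick sanity check of the resolutions implied by ```config.py```, a dummy forward pass can be used (a minimal sketch, not part of the repository scripts): the voxelized input has 36 channels (35 height slices plus one reflectance channel) on a 700 x 800 BEV grid, and the network output is down-sampled by a factor of 4 with 6 regression channels and 1 classification score.
```
import torch
from PIXOR import PIXOR
from config import INPUT_DIM_0, INPUT_DIM_1, INPUT_DIM_2

pixor = PIXOR()
pixor.eval()
dummy_input = torch.zeros(1, INPUT_DIM_2, INPUT_DIM_0, INPUT_DIM_1)  # (1, 36, 700, 800)
with torch.no_grad():
    prediction = pixor(dummy_input)
print(prediction.shape)  # torch.Size([1, 7, 175, 200]) -> 6 regression channels + 1 class score
```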
46 |
47 | ### How to Navigate the Project
48 |
49 | With the folder structure for the dataset set up according to the specifications above, you're ready to navigate the project.
50 | The project directory contains three main files of interest, one each for training, evaluation and detection.
51 |
52 | #### Training
53 | First, we want to train a new PIXOR model. Training is performed in ```train_model.py```; running the script trains a new model using the specified training parameters. I trained the model for 30 epochs using Adam with an initial learning rate of 0.001 and a scheduler that reduces the learning rate by a factor of 0.1 after 10 and 20 epochs, respectively. Early stopping is used with a default patience of 8 epochs. The batch size is set to 6 due to memory constraints of the GPU I used for training. During training, the model is saved after each epoch in which the validation loss decreased compared to the previous best. In order to save the model, a folder ```Models``` has to exist in the working directory. Furthermore, a folder ```Metrics``` should be created; in this folder, a dictionary containing the training and validation loss is stored after every epoch to allow for visualizing the training process.
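Concretely, a training run boils down to the following (a minimal sketch mirroring the ```__main__``` block of ```train_model.py```; the two folders are created here for convenience):
```
import os
import torch
import torch.optim as optim

from load_data import load_dataset
from PIXOR import PIXOR
from train_model import train_model

os.makedirs('Models', exist_ok=True)
os.makedirs('Metrics', exist_ok=True)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
data_loaders = load_dataset(root='Data/', batch_size=6, device=device)

pixor = PIXOR().to(device)
optimizer = optim.Adam(pixor.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [10, 20], gamma=0.1)

history = train_model(pixor, optimizer, scheduler, data_loaders, n_epochs=30)
```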
54 |
55 | #### Evaluation
56 | With a PIXOR model trained and saved to the ```Models``` folder of the working directory, the model can be evaluated on the test set. Sticking to the evaluation scheme of the original paper, the performance is measured over three different distance ranges (0-30 m, 0-50 m and 0-70 m) and the mAP is computed as an average over IoU thresholds of 0.5, 0.6, 0.7, 0.8 and 0.9.
57 | Running ```evaluate_model.py``` creates and saves a dictionary containing all relevant performance measures. Prior to execution, a folder ```Eval``` has to be created in the working directory. Once the evaluation dictionary has been created, the evaluation can be visualized using ```visualize_evaluation.py```. For each of the evaluated distance ranges, precision-recall curves are plotted for each of the specified IoU thresholds. Moreover, the final mAP for each distance range is displayed.
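For reference, the mAP reported for each distance range is simply the mean of the average precision values over these IoU thresholds, e.g. (dummy numbers, purely illustrative):
```
import numpy as np

average_precision = {0.5: 0.85, 0.6: 0.80, 0.7: 0.70, 0.8: 0.45, 0.9: 0.10}  # AP per IoU threshold
mAP = np.mean(list(average_precision.values()))  # 0.58
```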
58 |
59 |
60 |
61 |
62 |
63 |
64 | #### Detection
65 | Having trained a PIXOR model, the detector can be run on unseen point clouds. For a visual inspection of the resulting detections, run ```detector.py```. In this script, the detector is run on a set of selected indices from the test set and the results are displayed. To give a good intuition about the quality of the results, the detections are shown on a BEV representation of the point cloud along with the ground-truth bounding boxes. Furthermore, an option exists to also display the original camera image with projections of the predicted and annotated bounding boxes.
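In essence, ```detector.py``` performs the following steps for each selected frame before drawing the results (a minimal sketch; the checkpoint name depends on the epoch at which your training run stopped):
```
import numpy as np
import torch

from PIXOR import PIXOR
from load_data import PointCloudDataset
from evaluate_model import process_predictions

device = torch.device('cpu')
dataset = PointCloudDataset('Data/', split='testing', get_image=True)
camera_image, point_cloud, labels, calib = dataset[0]

pixor = PIXOR()
pixor.load_state_dict(torch.load('Models/PIXOR_Epoch_17.pt', map_location=device))
pixor.eval()

with torch.no_grad():
    predictions = pixor(point_cloud.unsqueeze(0))
predictions = np.transpose(predictions.numpy(), (0, 2, 3, 1))
final_box_predictions = process_predictions(predictions, confidence_threshold=0.5)
```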
66 |
67 |
68 |
69 |
--------------------------------------------------------------------------------
/detector.py:
--------------------------------------------------------------------------------
1 | from evaluate_model import *
2 |
3 | ###############
4 | # show legend #
5 | ###############
6 |
7 |
8 | def show_legend(image, ground_truth_color, prediction_color, idx):
9 | """
10 | Display legend for color codes in provided image.
11 | :param image: image on which to display the legend
12 | :return: image with legend
13 | """
14 |
15 | text1 = 'Ground Truth'
16 | text2 = 'Prediction'
17 | text3 = 'Point Cloud Index: {:d}'.format(idx)
18 | font = cv2.FONT_HERSHEY_DUPLEX
19 | font_scale = 0.7
20 | thickness = 1
21 | size1 = cv2.getTextSize(text1, font, font_scale, thickness)
22 | size2 = cv2.getTextSize(text2, font, font_scale, thickness)
23 | size3 = cv2.getTextSize(text3, font, font_scale, thickness)
24 | rectangle = cv2.rectangle(image, (20, 20),
25 | (20 + max(size1[0][0], size2[0][0], size3[0][0]) + 10,
26 | 20 + size1[0][1] + size2[0][1] + size3[0][1] + 20),
27 | (200, 200, 200), 1)
28 | cv2.putText(rectangle, text3, (25, 25 + size3[0][1]), font, font_scale, (100, 100, 100), 1)
29 | cv2.putText(rectangle, text1, (25, 30 + size1[0][1] + size3[0][1]), font, font_scale, ground_truth_color, 1)
30 | cv2.putText(rectangle, text2, (25, 35 + size1[0][1] + size2[0][1] + size3[0][1]), font, font_scale, prediction_color, 1)
31 |
32 | return image
33 |
34 |
35 | ############
36 | # detector #
37 | ############
38 |
39 | if __name__ == '__main__':
40 |
41 | """
42 | Run the detector. Iterate over a set of indices and display final detections for the respective point cloud
43 | in a BEV image and the original camera image.
44 | """
45 |
46 | # root directory of the dataset
47 | root_dir = 'Data/'
48 |
49 | # device
50 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
51 |
52 | # create dataset
53 | dataset = PointCloudDataset(root_dir, split='testing', get_image=True)
54 |
55 | # select index from dataset
56 | ids = np.arange(0, dataset.__len__())
57 |
58 | for id in ids:
59 |
60 | # get image, point cloud, labels and calibration
61 | camera_image, point_cloud, labels, calib = dataset.__getitem__(id)
62 |
63 | # create model
64 | pixor = PIXOR()
65 | n_epochs_trained = 17
66 | pixor.load_state_dict(torch.load('Models/PIXOR_Epoch_' + str(n_epochs_trained) + '.pt', map_location=device))
67 |
68 | # unsqueeze first dimension for batch
69 | point_cloud = point_cloud.unsqueeze(0)
70 |
71 | # forward pass
72 | predictions = pixor(point_cloud)
73 |
74 | # convert network output to numpy for further processing
75 | predictions = np.transpose(predictions.detach().numpy(), (0, 2, 3, 1))
76 |
77 | # get final bounding box predictions
78 | final_box_predictions = process_predictions(predictions, confidence_threshold=0.5)
79 |
80 | ###################
81 | # display results #
82 | ###################
83 |
84 | # set colors
85 | ground_truth_color = (80, 127, 255)
86 | prediction_color = (255, 127, 80)
87 |
88 | # get point cloud as numpy array
89 | point_cloud = point_cloud[0].detach().numpy().transpose((1, 2, 0))
90 |
91 | # draw BEV image
92 | bev_image = kitti_utils.draw_bev_image(point_cloud)
93 |
94 | # display ground truth bounding boxes on BEV image and camera image
95 | for label in labels:
96 | # only consider annotations for class "Car"
97 | if label.type == 'Car':
98 | # compute corners of the bounding box
99 | bbox_corners_image_coord, bbox_corners_camera_coord = kitti_utils.compute_box_3d(label, calib.P)
100 | # display bounding box in BEV image
101 | bev_image = kitti_utils.draw_projected_box_bev(bev_image, bbox_corners_camera_coord, color=ground_truth_color)
102 | # display bounding box in camera image
103 | if bbox_corners_image_coord is not None:
104 | camera_image = kitti_utils.draw_projected_box_3d(camera_image, bbox_corners_image_coord, color=ground_truth_color)
105 |
106 | # display predicted bounding boxes on BEV image and camera image
107 | if final_box_predictions is not None:
108 | for prediction in final_box_predictions:
109 | bbox_corners_camera_coord = np.reshape(prediction[2:], (2, 4)).T
110 | # create 3D bounding box coordinates from BEV coordinates. Place all bounding boxes on the ground and
111 | # choose a height of 1.65 m
112 | bbox_corners_camera_coord = np.tile(bbox_corners_camera_coord, (2, 1))
113 | bbox_y_camera_coord = np.array([[0., 0., 0., 0., 1.65, 1.65, 1.65, 1.65]]).T
114 | bbox_corners_camera_coord = np.hstack((bbox_corners_camera_coord, bbox_y_camera_coord))
115 | switch_indices = np.argsort([0, 2, 1])
116 | bbox_corners_camera_coord = bbox_corners_camera_coord[:, switch_indices]
117 | bbox_corners_image_coord = kitti_utils.project_to_image(bbox_corners_camera_coord, calib.P)
118 |
119 | # display bounding box with confidence score in BEV image
120 | bev_image = kitti_utils.draw_projected_box_bev(bev_image, bbox_corners_camera_coord, color=prediction_color, confidence_score=prediction[1])
121 | # display bounding box in camera image
122 | if bbox_corners_image_coord is not None:
123 | camera_image = kitti_utils.draw_projected_box_3d(camera_image, bbox_corners_image_coord, color=prediction_color)
124 |
125 | # display legend on BEV Image
126 | bev_image = show_legend(bev_image, ground_truth_color, prediction_color, id)
127 |
128 | # show images
129 | # cv2.imshow('BEV Image', bev_image)
130 | # cv2.imshow('Camera Image', camera_image)
131 | # cv2.waitKey()
132 |
133 | # save image
134 | print('Index: ', id)
135 | cv2.imwrite('Images/Detections/detection_id_{:d}.png'.format(id), bev_image)
136 |
--------------------------------------------------------------------------------
/visualize_dataset.py:
--------------------------------------------------------------------------------
1 | from load_data import *
2 | import matplotlib
3 | matplotlib.use('TkAgg')
4 | import matplotlib.pyplot as plt
5 |
6 | ##################
7 | # dataset object #
8 | ##################
9 |
10 |
11 | class KittiObject(object):
12 | """
13 | Load and parse object data into a usable format.
14 |
15 | """
16 | def __init__(self, root_dir, split='testing'):
17 | """
18 | root_dir contains training and testing folders
19 | :param root_dir:
20 | :param split:
21 | :param args:
22 | """
23 | self.root_dir = root_dir
24 | self.split = split
25 | self.split_dir = os.path.join(root_dir, split)
26 |
27 | if split == 'training':
28 | self.num_samples = 6481
29 | elif split == 'testing':
30 | self.num_samples = 1000
31 | else:
32 | print('Unknown split: %s' % (split))
33 | exit(-1)
34 |
35 | self.image_dir = os.path.join(self.split_dir, 'image_2')
36 | self.calib_dir = os.path.join(self.split_dir, 'calib')
37 | self.lidar_dir = os.path.join(self.split_dir, 'velodyne')
38 | self.label_dir = os.path.join(self.split_dir, 'label_2')
39 |
40 | def __len__(self):
41 | return self.num_samples
42 |
43 | def get_image(self, idx):
44 | assert (idx < self.num_samples)
45 | img_filename = os.path.join(self.image_dir, '%06d.png' % idx)
46 | return cv2.imread(img_filename)
47 |
48 | def get_calibration(self, idx):
49 | assert (idx < self.num_samples)
50 | calib_filename = os.path.join(self.calib_dir, '%06d.txt' % idx)
51 | return kitti_utils.Calibration(calib_filename)
52 |
53 | def get_label_objects(self, idx):
54 | assert(idx < self.num_samples)
55 | label_filename = os.path.join(self.label_dir, '%06d.txt' % idx)
56 | return kitti_utils.read_label(label_filename)
57 |
58 | def get_lidar(self, idx, dtype=np.float32, n_vec=4):
59 | assert (idx < self.num_samples)
60 | lidar_filename = os.path.join(self.lidar_dir, '%06d.bin' % (idx))
61 | return kitti_utils.load_velo_scan(lidar_filename, dtype, n_vec)
62 |
63 |
64 | ########
65 | # main #
66 | ########
67 |
68 | if __name__ == '__main__':
69 |
70 | """
71 | Explore the dataset. For a random selection of indices, the corresponding camera image will be displayed along with
72 | all 3D bounding box annotations for the class "Car". Moreover, the BEV image of the LiDAR point cloud will be
73 | displayed with the bounding box annotations and a mask that shows the relevant pixels for the labels used for
74 | training the network.
75 | """
76 |
77 | # root directory of the dataset
78 | root_dir = 'Data/'
79 |
80 | # create dataset
81 | train_dataset = KittiObject(root_dir)
82 |
83 | # select random indices from dataset
84 | ids = np.random.randint(0, 1000, 30)
85 |
86 | # loop over random selection
87 | for id in ids:
88 |
89 | # get image, point cloud, labels and calibration
90 | image = train_dataset.get_image(idx=id)
91 | labels = train_dataset.get_label_objects(idx=id)
92 | calib = train_dataset.get_calibration(idx=id)
93 | point_cloud = train_dataset.get_lidar(idx=id)
94 |
95 | # voxelize the point cloud
96 | voxel_point_cloud = kitti_utils.voxelize(point_cloud)
97 |
98 | # get BEV image of point cloud
99 | bev_image = kitti_utils.draw_bev_image(voxel_point_cloud)
100 |
101 | # create empty labels
102 | regression_label = np.zeros((OUTPUT_DIM_0, OUTPUT_DIM_1, OUTPUT_DIM_REG))
103 | classification_label = np.zeros((OUTPUT_DIM_0, OUTPUT_DIM_1, OUTPUT_DIM_CLA))
104 |
105 | # loop over all annotations for current sample
106 | for idl, label in enumerate(labels):
107 | # only display objects labeled as Car
108 | if label.type == 'Car':
109 | # compute corners of the bounding box
110 | bbox_corners_image_coord, bbox_corners_camera_coord = kitti_utils.compute_box_3d(label, calib.P)
111 | # draw BEV bounding box on BEV image
112 | bev_image = kitti_utils.draw_projected_box_bev(bev_image, bbox_corners_camera_coord)
113 | # create labels
114 | regression_label, classification_label = compute_pixel_labels(regression_label, classification_label,
115 | label, bbox_corners_camera_coord)
116 | # draw 3D bounding box on image
117 | if bbox_corners_image_coord is not None:
118 | image = kitti_utils.draw_projected_box_3d(image, bbox_corners_image_coord)
119 |
120 | # create binary mask from relevant pixels in label
121 | label_mask = np.where(np.sum(np.abs(regression_label), axis=2) > 0, 255, 0).astype(np.uint8)
122 |
123 | # remove all points outside the specified area
124 | idx = np.where(point_cloud[:, 0] > VOX_X_MIN)
125 | point_cloud = point_cloud[idx]
126 | idx = np.where(point_cloud[:, 0] < VOX_X_MAX)
127 | point_cloud = point_cloud[idx]
128 | idx = np.where(point_cloud[:, 1] > VOX_Y_MIN)
129 | point_cloud = point_cloud[idx]
130 | idx = np.where(point_cloud[:, 1] < VOX_Y_MAX)
131 | point_cloud = point_cloud[idx]
132 | idx = np.where(point_cloud[:, 2] > VOX_Z_MIN)
133 | point_cloud = point_cloud[idx]
134 | idx = np.where(point_cloud[:, 2] < VOX_Z_MAX)
135 | point_cloud = point_cloud[idx]
136 |
137 | # get rectified point cloud for depth information
138 | point_cloud_rect = calib.project_velo_to_rect(point_cloud[:, :3])
139 |
140 | # color map to indicate depth of point
141 | cmap = plt.cm.get_cmap('hsv', 256)
142 | cmap = np.array([cmap(i) for i in range(256)])[:, :3] * 255
143 |
144 | # project point cloud to image plane
145 | point_cloud_2d = calib.project_velo_to_image(point_cloud[:, :3]).astype(np.int32)
146 |
147 | # draw points
148 | for i in range(point_cloud_2d.shape[0]):
149 | depth = point_cloud_rect[i, 2]
150 | if depth > 0.1:
151 | color = cmap[int(255 - depth / VOX_X_MAX * 255)-1, :]
152 | cv2.circle(image, (point_cloud_2d[i, 0], point_cloud_2d[i, 1]), radius=2, color=color, thickness=-1)
153 |
154 | # display images
155 | cv2.imshow('Label Mask', label_mask)
156 | cv2.imshow('Image', image)
157 | cv2.imshow('Image_BEV', bev_image)
158 | cv2.waitKey()
159 |
--------------------------------------------------------------------------------
/train_model.py:
--------------------------------------------------------------------------------
1 | import torch.optim as optim
2 | from load_data import *
3 | from PIXOR import PIXOR
4 | import torch.nn as nn
5 | import time
6 |
7 | ##################
8 | # early stopping #
9 | ##################
10 |
11 |
12 | class EarlyStopping:
13 | """
14 | Early stops the training if validation loss doesn't improve after a given patience.
15 | """
16 |
17 | def __init__(self, patience=7, verbose=False):
18 | """
19 | :param patience: How many epochs to wait after the last validation loss improvement
20 | :param verbose: If True, prints a message for each validation loss improvement.
21 | """
22 |
23 | self.patience = patience
24 | self.verbose = verbose
25 | self.counter = 0
26 | self.best_score = None
27 | self.best_epoch = None
28 | self.early_stop = False
29 | self.val_loss_min = np.Inf
30 |
31 | def __call__(self, val_loss, epoch, model):
32 |
33 | score = -val_loss
34 |
35 | # first epoch
36 | if self.best_score is None:
37 | self.best_score = score
38 | self.best_epoch = epoch + 1
39 | self.save_checkpoint(val_loss, model)
40 |
41 | # validation loss increased
42 | elif score < self.best_score:
43 |
44 | # increase counter
45 | self.counter += 1
46 |
47 | print('Validation loss did not decrease ({:.6f} --> {:.6f})'.format(self.val_loss_min, val_loss))
48 | print('EarlyStopping counter: {} out of {}'.format(self.counter, self.patience))
49 | print('###########################################################')
50 |
51 | # stop training if patience is reached
52 | if self.counter >= self.patience:
53 | self.early_stop = True
54 |
55 | # validation loss decreased
56 | else:
57 | self.best_score = score
58 | self.best_epoch = epoch + 1
59 | self.save_checkpoint(val_loss, model)
60 |
61 | # reset counter
62 | self.counter = 0
63 |
64 | def save_checkpoint(self, val_loss, model):
65 | """
66 | Saves model when validation loss decreased.
67 | """
68 |
69 | if self.verbose:
70 | print('Validation loss decreased ({:.6f} --> {:.6f}). '
71 | 'Saving model ...'.format(self.val_loss_min, val_loss))
72 | print('###########################################################')
73 |
74 | # save model
75 | torch.save(model.state_dict(), 'Models/PIXOR_Epoch_' + str(self.best_epoch) + '.pt')
76 |
77 | # set current loss as new minimum loss
78 | self.val_loss_min = val_loss
79 |
80 |
81 | ##############
82 | # focal loss #
83 | ##############
84 |
85 |
86 | class FocalLoss(nn.Module):
87 | """
88 | Focal loss class. Stabilizes training by reducing the weight of easily classified background samples and focusing
89 | on difficult foreground detections.
90 | """
91 |
92 | def __init__(self, gamma=0, size_average=False):
93 | super(FocalLoss, self).__init__()
94 | self.gamma = gamma
95 | self.size_average = size_average
96 |
97 | def forward(self, prediction, target):
98 |
99 | # get class probability
100 | pt = torch.where(target == 1.0, prediction, 1-prediction)
101 |
102 | # compute focal loss
103 | loss = -1 * (1-pt)**self.gamma * torch.log(pt)
104 |
105 | if self.size_average:
106 | return loss.mean()
107 | else:
108 | return loss.sum()
109 |
110 |
111 | ##################
112 | # calculate loss #
113 | ##################
114 |
115 |
116 | def calc_loss(batch_predictions, batch_labels):
117 | """
118 | Calculate the final loss function as a sum of the classification and the regression loss.
119 | :param batch_predictions: predictions for the current batch | shape: [batch_size, OUTPUT_DIM_0, OUTPUT_DIM_1, OUTPUT_DIM_CLA+OUTPUT_DIM_REG]
120 | :param batch_labels: labels for the current batch | shape: [batch_size, OUTPUT_DIM_0, OUTPUT_DIM_1, OUTPUT_DIM_CLA+OUTPUT_DIM_REG]
121 | :return: computed loss
122 | """
123 |
124 | # classification loss
125 | classification_prediction = batch_predictions[:, :, :, -1].contiguous().flatten()
126 | classification_label = batch_labels[:, :, :, -1].contiguous().flatten()
127 | focal_loss = FocalLoss(gamma=2)
128 | classification_loss = focal_loss(classification_prediction, classification_label)
129 |
130 | # regression loss
131 | regression_prediction = batch_predictions[:, :, :, :-1]
132 | regression_prediction = regression_prediction.contiguous().view([regression_prediction.size(0)*
133 | regression_prediction.size(1)*regression_prediction.size(2), regression_prediction.size(3)])
134 | regression_label = batch_labels[:, :, :, :-1]
135 | regression_label = regression_label.contiguous().view([regression_label.size(0)*regression_label.size(1)*
136 | regression_label.size(2), regression_label.size(3)])
137 | positive_mask = torch.nonzero(torch.sum(torch.abs(regression_label), dim=1))
138 | pos_regression_label = regression_label[positive_mask.squeeze(), :]
139 | pos_regression_prediction = regression_prediction[positive_mask.squeeze(), :]
140 | smooth_l1 = nn.SmoothL1Loss(reduction='sum')
141 | regression_loss = smooth_l1(pos_regression_prediction, pos_regression_label)
142 |
143 | # add two loss components
144 | multi_task_loss = classification_loss.add(regression_loss)
145 |
146 | return multi_task_loss
147 |
148 |
149 | ###############
150 | # train model #
151 | ###############
152 |
153 |
154 | def train_model(model, optimizer, scheduler, data_loaders, n_epochs=25, show_times=False):
155 |
156 | # evaluation dict
157 | metrics = {'train_loss': [], 'val_loss': [], 'lr': []}
158 |
159 | # early stopping object
160 | early_stopping = EarlyStopping(patience=8, verbose=True)
161 |
162 | # exponential moving average of the loss per phase
163 | moving_loss = {'train': None, 'val': None}
164 |
165 | # epochs
166 | for epoch in range(n_epochs):
167 |
168 | # each epoch has a training and validation phase
169 | for phase in ['train', 'val']:
170 |
171 | # track average loss per batch
172 | if phase == 'train':
173 | model.train() # Set model to training mode
174 | else:
175 | model.eval() # Set model to evaluate mode
176 |
177 | for batch_id, (batch_data, batch_labels) in enumerate(data_loaders[phase]):
178 |
179 | # zero the parameter gradients
180 | optimizer.zero_grad()
181 |
182 | # track history only if in train phase
183 | with torch.set_grad_enabled(phase == 'train'):
184 |
185 | # forward pass
186 | forward_pass_start_time = time.time()
187 | batch_predictions = model(batch_data)
188 | forward_pass_end_time = time.time()
189 | batch_predictions = batch_predictions.permute([0, 2, 3, 1])
190 |
191 | # calculate loss
192 | calc_loss_start_time = time.time()
193 | loss = calc_loss(batch_predictions, batch_labels)
194 | calc_loss_end_time = time.time()
195 |
196 | # accumulate loss
197 | if moving_loss[phase] is None:
198 | moving_loss[phase] = loss.item()
199 | else:
200 | moving_loss[phase] = 0.99 * moving_loss[phase] + 0.01 * loss.item()
201 |
202 | # append loss for each phase
203 | metrics[phase + '_loss'].append(moving_loss[phase])
204 |
205 | # backward + optimize only if in training phase
206 | if phase == 'train':
207 | backprop_start_time = time.time()
208 | loss.backward()
209 | optimizer.step()
210 | backprop_end_time = time.time()
211 |
212 | if show_times:
213 | print('Forward Pass Time: {:.2f}'.format(forward_pass_end_time - forward_pass_start_time))
214 | print('Calc Loss Time: {:.2f} '.format(calc_loss_end_time - calc_loss_start_time))
215 | print('Backprop Time: {:.2f}'.format(backprop_end_time-backprop_start_time))
216 |
217 | if (batch_id+1) % 10 == 0:
218 | n_batches_per_epoch = data_loaders[phase].dataset.__len__() // data_loaders[phase].batch_size
219 | print("{:d}/{:d} iterations | {:s} loss: {:.4f}".format(batch_id+1, n_batches_per_epoch, phase, moving_loss[phase]))
220 |
221 | # keep track of learning rate
222 | for param_group in optimizer.param_groups:
223 | metrics['lr'].append(param_group['lr'])
224 |
225 | # scheduler step
226 | scheduler.step()
227 |
228 | # output progress
229 | print('###########################################################')
230 | print('Epoch: ' + str(epoch+1) + '/' + str(n_epochs))
231 | print('Learning Rate: ', metrics['lr'][-1])
232 | print('Training Loss: %.4f' % metrics['train_loss'][-1])
233 | print('Validation Loss: %.4f' % metrics['val_loss'][-1])
234 |
235 | # save metrics
236 | np.savez('./Metrics/metrics_' + str(epoch + 1) + '.npz', history=metrics)
237 |
238 | # check early stopping
239 | early_stopping(val_loss=metrics['val_loss'][-1], epoch=epoch, model=model)
240 | if early_stopping.early_stop:
241 | print('Early Stopping!')
242 | break
243 |
244 | print('Training Finished!')
245 | print('Final Model was trained for ' + str(early_stopping.best_epoch) + ' epochs and achieved minimum loss of '
246 | '%.4f!' % early_stopping.val_loss_min)
247 |
248 | return metrics
249 |
250 |
251 | if __name__ == '__main__':
252 |
253 | # set device
254 | device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
255 |
256 | # training parameters
257 | n_epochs = 30
258 | batch_size = 6
259 | initial_learning_rate = 1e-3
260 |
261 | # create data loader
262 | root_dir = 'Data/'
263 | data_loader = load_dataset(root=root_dir, batch_size=batch_size, device=device)
264 |
265 | # create model
266 | pixor = PIXOR().to(device)
267 |
268 | # create optimizer and scheduler objects
269 | optimizer = optim.Adam(filter(lambda p: p.requires_grad, pixor.parameters()), lr=initial_learning_rate)
270 | scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, [10, 20], gamma=0.1)
271 |
272 | # train model
273 | history = train_model(pixor, optimizer, scheduler, data_loader, n_epochs=n_epochs)
274 |
275 | # save training history
276 | np.savez('history.npz', history=history)
277 |
--------------------------------------------------------------------------------
/PIXOR.py:
--------------------------------------------------------------------------------
1 | import torch.nn as nn
2 | import torch
3 | import copy
4 | import os
5 | import kitti_utils
6 |
7 |
8 | ###############
9 | # Basis Block #
10 | ###############
11 |
12 |
13 | class BasisBlock(nn.Module):
14 | """
15 | BasisBlock for input to ResNet
16 | """
17 |
18 | def __init__(self, n_input_channels):
19 | super(BasisBlock, self).__init__()
20 | self.conv1 = nn.Conv2d(n_input_channels, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
21 | self.bn1 = nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
22 | self.relu1 = nn.ReLU(inplace=True)
23 | self.conv2 = nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
24 | self.bn2 = nn.BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
25 | self.relu2 = nn.ReLU(inplace=True)
26 |
27 | def forward(self, x):
28 | x = self.conv1(x)
29 | x = self.bn1(x)
30 | x = self.relu1(x)
31 | x = self.conv2(x)
32 | x = self.bn2(x)
33 | x = self.relu2(x)
34 |
35 | return x
36 |
37 |
38 | #################
39 | # Residual Unit #
40 | #################
41 |
42 |
43 | class ResidualUnit(nn.Module):
44 | def __init__(self, n_input, n_output, downsample=False):
45 | """
46 | Residual Unit consisting of two convolutional layers and an identity mapping
47 | :param n_input: number of input channels
48 | :param n_output: number of output channels
49 | :param downsample: downsample the output by a factor of 2
50 | """
51 | super(ResidualUnit, self).__init__()
52 | self.conv1 = nn.Conv2d(n_input, n_output, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
53 | self.bn1 = nn.BatchNorm2d(n_output, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
54 | self.relu = nn.ReLU(inplace=True)
55 | self.conv2 = nn.Conv2d(n_output, n_output, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
56 | self.bn2 = nn.BatchNorm2d(n_output, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
57 |
58 | # down-sampling: use stride two for convolutional kernel and create 1x1 kernel for down-sampling of input
59 | self.downsample = None
60 | if downsample:
61 | self.conv1 = nn.Conv2d(n_input, n_output, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
62 | self.downsample = nn.Sequential(nn.Conv2d(n_input, n_output, kernel_size=(1, 1), stride=(2, 2), bias=False),
63 | nn.BatchNorm2d(n_output, eps=1e-05, momentum=0.1, affine=True,
64 | track_running_stats=True))
65 | else:
66 | self.identity_channels = nn.Conv2d(n_input, n_output, kernel_size=(1, 1), bias=False)
67 |
68 | def forward(self, x):
69 |
70 | # store input for skip-connection
71 | identity = x
72 |
73 | x = self.conv1(x)
74 | x = self.bn1(x)
75 | x = self.relu(x)
76 | x = self.conv2(x)
77 | x = self.bn2(x)
78 |
79 | # downsample input to match output dimensions
80 | if self.downsample is not None:
81 | identity = self.downsample(identity)
82 | else:
83 | identity = self.identity_channels(identity)
84 |
85 | # skip-connection
86 | x += identity
87 |
88 | # apply ReLU activation
89 | x = self.relu(x)
90 |
91 | return x
92 |
93 |
94 | ##################
95 | # Residual Block #
96 | ##################
97 |
98 |
99 | class ResidualBlock(nn.Module):
100 | """
101 | Residual Block containing specified number of residual layers
102 | """
103 |
104 | def __init__(self, n_input, n_output, n_res_units):
105 | super(ResidualBlock, self).__init__()
106 |
107 | # use down-sampling only in the first residual layer of the block
108 | first_unit = True
109 |
110 | # specific channel numbers
111 | if n_res_units == 3:
112 | inputs = [n_input, n_output//4, n_output//4]
113 | outputs = [n_output//4, n_output//4, n_output]
114 | else:
115 | inputs = [n_input, n_output // 4, n_output // 4, n_output // 4, n_output // 4, n_output]
116 | outputs = [n_output // 4, n_output // 4, n_output // 4, n_output // 4, n_output, n_output]
117 |
118 | # create residual units
119 | units = []
120 | for unit_id in range(n_res_units):
121 | if first_unit:
122 | units.append(ResidualUnit(inputs[unit_id], outputs[unit_id], downsample=True))
123 | first_unit = False
124 | else:
125 | units.append(ResidualUnit(inputs[unit_id], outputs[unit_id]))
126 | self.res_block = nn.Sequential(*units)
127 |
128 | def forward(self, x):
129 |
130 | x = self.res_block(x)
131 |
132 | return x
133 |
134 |
135 | #############
136 | # FPN Block #
137 | #############
138 |
139 |
140 | class FPNBlock(nn.Module):
141 | """
142 | Block for Feature Pyramid Network including up-sampling and concatenation of feature maps
143 | """
144 |
145 | def __init__(self, bottom_up_channels, top_down_channels, fused_channels):
146 | super(FPNBlock, self).__init__()
147 | # reduce number of top-down channels to 196
148 | intermediate_channels = 196
149 | if top_down_channels > 196:
150 | self.channel_conv_td = nn.Conv2d(top_down_channels, intermediate_channels, kernel_size=(1, 1),
151 | stride=(1, 1), bias=False)
152 | else:
153 | self.channel_conv_td = None
154 |
155 | # change number of bottom-up channels to the number of fused channels
156 | self.channel_conv_bu = nn.Conv2d(bottom_up_channels, fused_channels, kernel_size=(1, 1),
157 | stride=(1, 1), bias=False)
158 |
159 | # transposed convolution on top-down feature maps
160 | if fused_channels == 128:
161 | out_pad = (1, 1)
162 | else:
163 | out_pad = (0, 1)
164 | if self.channel_conv_td is not None:
165 | self.deconv = nn.ConvTranspose2d(intermediate_channels, fused_channels, kernel_size=(3, 3), padding=(1, 1),
166 | stride=2, output_padding=out_pad)
167 | else:
168 | self.deconv = nn.ConvTranspose2d(top_down_channels, fused_channels, kernel_size=(3, 3), padding=(1, 1),
169 | stride=2, output_padding=out_pad)
170 |
171 | def forward(self, x_td, x_bu):
172 |
173 | # apply 1x1 convolutional to obtain required number of channels if needed
174 | if self.channel_conv_td is not None:
175 | x_td = self.channel_conv_td(x_td)
176 |
177 | # up-sample top-down feature maps
178 | x_td = self.deconv(x_td)
179 |
180 | # apply 1x1 convolutional to obtain required number of channels
181 | x_bu = self.channel_conv_bu(x_bu)
182 |
183 | # perform element-wise addition
184 | x = x_td.add(x_bu)
185 |
186 | return x
187 |
188 |
189 | ####################
190 | # Detection Header #
191 | ####################
192 |
193 | class DetectionHeader(nn.Module):
194 |
195 | def __init__(self, n_input, n_output):
196 | super(DetectionHeader, self).__init__()
197 | basic_block = nn.Sequential(nn.Conv2d(n_input, n_output, kernel_size=(3, 3), padding=(1, 1), bias=False),
198 | nn.BatchNorm2d(n_output, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True),
199 | nn.ReLU(inplace=True))
200 | self.conv1 = basic_block
201 | self.conv2 = copy.deepcopy(basic_block)
202 | self.conv3 = copy.deepcopy(basic_block)
203 | self.conv4 = copy.deepcopy(basic_block)
204 | self.classification = nn.Conv2d(n_output, 1, kernel_size=(3, 3), padding=(1, 1))
205 | self.regression = nn.Conv2d(n_output, 6, kernel_size=(3, 3), padding=(1, 1))
206 | self.sigmoid = nn.Sigmoid()
207 |
208 | def forward(self, x):
209 |
210 | x = self.conv1(x)
211 | x = self.conv2(x)
212 | x = self.conv3(x)
213 | x = self.conv4(x)
214 | class_output = self.sigmoid(self.classification(x))
215 | regression_output = self.regression(x)
216 |
217 | return class_output, regression_output
218 |
219 |
220 | #########
221 | # PIXOR #
222 | #########
223 |
224 |
225 | class PIXOR(nn.Module):
226 | def __init__(self):
227 | super(PIXOR, self).__init__()
228 |
229 | # Backbone Network
230 | self.basis_block = BasisBlock(n_input_channels=36)
231 | self.res_block_1 = ResidualBlock(n_input=32, n_output=96, n_res_units=3)
232 | self.res_block_2 = ResidualBlock(n_input=96, n_output=196, n_res_units=6)
233 | self.res_block_3 = ResidualBlock(n_input=196, n_output=256, n_res_units=6)
234 | self.res_block_4 = ResidualBlock(n_input=256, n_output=384, n_res_units=3)
235 |
236 | # FPN blocks
237 | self.fpn_block_1 = FPNBlock(top_down_channels=384, bottom_up_channels=256, fused_channels=128)
238 | self.fpn_block_2 = FPNBlock(top_down_channels=128, bottom_up_channels=196, fused_channels=96)
239 |
240 | # Detection Header
241 | self.header = DetectionHeader(n_input=96, n_output=96)
242 |
243 | def forward(self, x):
244 | x_b = self.basis_block(x)
245 | # print(x_b.size())
246 | x_1 = self.res_block_1(x_b)
247 | # print(x_1.size())
248 | x_2 = self.res_block_2(x_1)
249 | # print(x_2.size())
250 | x_3 = self.res_block_3(x_2)
251 | # print(x_3.size())
252 | x_4 = self.res_block_4(x_3)
253 | # print(x_4.size())
254 | x_34 = self.fpn_block_1(x_4, x_3)
255 | # print(x_34.size())
256 | x_234 = self.fpn_block_2(x_34, x_2)
257 | # print(x_234.size())
258 | x_class, x_reg = self.header(x_234)
259 | # print(x_class.size())
260 | # print(x_reg.size())
261 | x_out = torch.cat((x_reg, x_class), dim=1)
262 |
263 | return x_out
264 |
265 |
266 | ########
267 | # Main #
268 | ########
269 |
270 |
271 | if __name__ == '__main__':
272 |
273 | # exemplary input point cloud
274 | base_dir = 'Data/training/velodyne'
275 | index = 1
276 | lidar_filename = os.path.join(base_dir, '%06d.bin' % index)
277 | lidar_data = kitti_utils.load_velo_scan(lidar_filename)
278 | # create torch tensor from numpy array
279 | voxel_point_cloud = torch.tensor(kitti_utils.voxelize(lidar_data), requires_grad=True, device='cpu').float()
280 | # channels along first dimensions according to PyTorch convention
281 | voxel_point_cloud = voxel_point_cloud.permute([2, 0, 1])
282 | voxel_point_cloud = torch.unsqueeze(voxel_point_cloud, 0) # add dimension 0 to tensor for batch
283 |
284 | # forward pass through network
285 | pixor = PIXOR()
286 | prediction = pixor(voxel_point_cloud)
287 | classification_prediction = prediction[:, -1, :, :]
288 | regression_prediction = prediction[:, :-1, :, :]
289 |
290 | print('+++++++++++++++++++++++++++++++++++++')
291 | print('BEV Backbone Network')
292 | print('+++++++++++++++++++++++++++++++++++++')
293 | print(pixor)
294 | print('+++++++++++++++++++++++++++++++++++++')
295 |
296 | for child_name, child in pixor.named_children():
297 | print('++++++++++++++++++++++')
298 | print(child_name)
299 | print('++++++++++++++++++++++')
300 | for parameter_name, parameter in child.named_parameters():
301 | print(parameter_name)
302 |
303 |
--------------------------------------------------------------------------------
/load_data.py:
--------------------------------------------------------------------------------
1 | from torch.utils.data import Dataset, DataLoader, random_split
2 | import torch
3 | import os
4 | import cv2
5 | import kitti_utils
6 | from config import *
7 | import time
8 |
9 |
10 | ############################
11 | # custom collate functions #
12 | ############################
13 |
14 |
15 | def my_collate_test(batch):
16 | """
17 | Collate function for test dataset. How to concatenate individual samples to a batch.
18 | Point Clouds will be stacked along first dimension, labels and calibration objects will be returned as a list
19 | :param batch: list containing a tuple of items for each sample
20 | :return: batch data in desired form
21 | """
22 |
23 | point_clouds = []
24 | labels = []
25 | calibs = []
26 | for sample in batch:
27 | point_clouds.append(sample[0])
28 | labels.append(sample[1])
29 | calibs.append(sample[2])
30 |
31 | point_clouds = torch.stack(point_clouds)
32 | return point_clouds, labels, calibs
33 |
34 |
35 | def my_collate_train(batch):
36 | """
37 | Collate function for training dataset. How to concatenate individual samples to a batch.
38 | Point Clouds and labels will be stacked along first dimension
39 | :param batch: list containing a tuple of items for each sample
40 | :return: batch data in desired form
41 | """
42 |
43 | point_clouds = []
44 | labels = []
45 | for sample in batch:
46 | point_clouds.append(sample[0])
47 | labels.append(sample[1])
48 |
49 | point_clouds = torch.stack(point_clouds)
50 | labels = torch.stack(labels)
51 | return point_clouds, labels
52 |
53 |
54 | ########################
55 | # compute pixel labels #
56 | ########################
57 |
58 | def compute_pixel_labels(regression_label, classification_label, label, bbox_corners_camera_coord):
59 | """
60 | Compute the label that will be fed into the network from the bounding box annotations of the respective point cloud.
61 | :param regression_label: empty numpy array | shape: [OUTPUT_DIM_0, OUTPUT_DIM_1, OUTPUT_DIM_REG]
62 | :param classification_label: empty numpy array | shape: [OUTPUT_DIM_0, OUTPUT_DIM_1, OUTPUT_DIM_CLA]
63 | :param label: 3D label object containing bounding box information
64 | :param bbox_corners_camera_coord: corners of the bounding box | shape: [8, 3]
65 | :return: regression_label and classification_label filled with relevant label information
66 | """
67 |
68 | # get label information
69 | angle_rad = label.ry # rotation of bounding box
70 | center_x_m = label.t[0]
71 | center_y_m = label.t[2]
72 | length_m = label.length
73 | width_m = label.width
74 |
75 | # extract corners of BEV bounding box
76 | bbox_corners_x = bbox_corners_camera_coord[:4, 0]
77 | bbox_corners_y = bbox_corners_camera_coord[:4, 2]
78 |
79 | # convert coordinates from m to pixels
80 | corners_x_px = ((bbox_corners_x - VOX_Y_MIN) // VOX_Y_DIVISION).astype(np.int32)
81 | corners_y_px = (INPUT_DIM_0 - ((bbox_corners_y - VOX_X_MIN) // VOX_X_DIVISION)).astype(np.int32)
82 | bbox_corners = np.vstack((corners_x_px, corners_y_px)).T
83 |
84 | # create a pixel mask of the target bounding box
85 | canvas = np.zeros((INPUT_DIM_0, INPUT_DIM_1, 3))
86 | canvas = cv2.fillPoly(canvas, pts=[bbox_corners], color=(255, 255, 255))
87 |
88 | # resize label to fit output shape
89 | canvas_resized = cv2.resize(canvas, (OUTPUT_DIM_1, OUTPUT_DIM_0), interpolation=cv2.INTER_NEAREST)
90 | bbox_mask = np.where(np.sum(canvas_resized, axis=2) == 765, 1, 0).astype(np.uint8)[:, :, np.newaxis]
91 |
92 | # get location of each pixel in m
93 | x_lin = np.linspace(VOX_Y_MIN, VOX_Y_MAX-0.4, OUTPUT_DIM_1)
94 | y_lin = np.linspace(VOX_X_MAX, VOX_X_MIN+0.4, OUTPUT_DIM_0)
95 | px_x, px_y = np.meshgrid(x_lin, y_lin)
96 |
97 | # create regression target
98 | target = np.array([[np.cos(angle_rad), np.sin(angle_rad), -center_x_m, -center_y_m, np.log(width_m), np.log(length_m)]])
99 | target = np.tile(target, (OUTPUT_DIM_0, OUTPUT_DIM_1, 1))
100 |
101 | # take offset from pixel as regression target for bounding box location
102 | target[:, :, 2] += px_x
103 | target[:, :, 3] += px_y
104 |
105 | # normalize target
106 | target = (target - REG_MEAN) / REG_STD
107 |
108 | # zero-out non-relevant pixels
109 | target *= bbox_mask
110 |
111 | # add current target to label for currently inspected point cloud
112 | regression_label += target
113 | classification_label += bbox_mask
114 |
115 | return regression_label, classification_label
116 |
117 |
118 | ###################
119 | # dataset classes #
120 | ###################
121 |
122 |
123 | class PointCloudDataset(Dataset):
124 | """
125 | Characterizes a dataset for PyTorch
126 | """
127 |
128 | def __init__(self, root_dir, split='training', device=torch.device('cpu'), show_times=True, get_image=False):
129 | """
130 | Dataset for training and testing containing point cloud, calibration object and in case of training labels
131 | :param root_dir: root directory of the dataset
132 | :param split: training or testing split of the dataset
133 | :param device: device on which dataset will be used
134 | :param show_times: show times of each step of the data loading (debug)
135 | """
136 |
137 | self.show_times = show_times # debug
138 | self.get_image = get_image # load camera image
139 |
140 | self.device = device
141 | self.root_dir = root_dir
142 | self.split = split
143 | self.split_dir = os.path.join(root_dir, split)
144 |
145 | if split == 'training':
146 | self.num_samples = 6481
147 | elif split == 'testing':
148 | self.num_samples = 1000
149 | else:
150 | print('Unknown split: %s' % split)
151 | exit(-1)
152 |
153 | # paths to camera, lidar, calibration and label directories
154 | self.lidar_dir = os.path.join(self.split_dir, 'velodyne')
155 | self.calib_dir = os.path.join(self.split_dir, 'calib')
156 | self.label_dir = os.path.join(self.split_dir, 'label_2')
157 | self.image_dir = os.path.join(self.split_dir, 'image_2')
158 |
159 | def __len__(self):
160 | # Denotes the total number of samples
161 | return self.num_samples
162 |
163 | def __getitem__(self, index):
164 |
165 | # start time
166 | get_item_start_time = time.time()
167 |
168 | # get point cloud
169 | lidar_filename = os.path.join(self.lidar_dir, '%06d.bin' % index)
170 | lidar_data = kitti_utils.load_velo_scan(lidar_filename)
171 |
172 | # time for loading point cloud
173 | read_point_cloud_end_time = time.time()
174 | read_point_cloud_time = read_point_cloud_end_time - get_item_start_time
175 |
176 | # voxelize point cloud
177 | voxel_point_cloud = torch.tensor(kitti_utils.voxelize(point_cloud=lidar_data), requires_grad=True, device=self.device).float()
178 |
179 | # time for voxelization
180 | voxelization_end_time = time.time()
181 | voxelization_time = voxelization_end_time - read_point_cloud_end_time
182 |
183 | # channels along first dimensions according to PyTorch convention
184 | voxel_point_cloud = voxel_point_cloud.permute([2, 0, 1])
185 |
186 | # get image
187 | if self.get_image:
188 | image_filename = os.path.join(self.image_dir, '%06d.png' % index)
189 | image = kitti_utils.get_image(image_filename)
190 |
191 | # get current time
192 | read_labels_start_time = time.time()
193 |
194 | # get calibration
195 | calib_filename = os.path.join(self.calib_dir, '%06d.txt' % index)
196 | calib = kitti_utils.Calibration(calib_filename)
197 |
198 | # get labels
199 | label_filename = os.path.join(self.label_dir, '%06d.txt' % index)
200 | labels = kitti_utils.read_label(label_filename)
201 |
202 | read_labels_end_time = time.time()
203 | read_labels_time = read_labels_end_time - read_labels_start_time
204 |
205 | # compute network label
206 | if self.split == 'training':
207 | # get current time
208 | compute_label_start_time = time.time()
209 |
210 | # create empty pixel labels
211 | regression_label = np.zeros((OUTPUT_DIM_0, OUTPUT_DIM_1, OUTPUT_DIM_REG))
212 | classification_label = np.zeros((OUTPUT_DIM_0, OUTPUT_DIM_1, OUTPUT_DIM_CLA))
213 |
214 | # iterate over all 3D label objects in list
215 | for label in labels:
216 | if label.type == 'Car':
217 | # compute corners of 3D bounding box in camera coordinates
218 | _, bbox_corners_camera_coord = kitti_utils.compute_box_3d(label, calib.P, scale=1.0)
219 | # get pixel label for classification and BEV bounding box
220 | regression_label, classification_label = compute_pixel_labels\
221 | (regression_label, classification_label, label, bbox_corners_camera_coord)
222 |
223 | # stack classification and regression label
224 | regression_label = torch.tensor(regression_label, device=self.device).float()
225 | classification_label = torch.tensor(classification_label, device=self.device).float()
226 | training_label = torch.cat((regression_label, classification_label), dim=2)
227 |
228 | # get time for computing pixel label
229 | compute_label_end_time = time.time()
230 | compute_label_time = compute_label_end_time - compute_label_start_time
231 |
232 | # total time for data loading
233 | get_item_end_time = time.time()
234 | get_item_time = get_item_end_time - get_item_start_time
235 |
236 | if self.show_times:
237 | print('---------------------------')
238 | print('Get Item Time: {:.4f} s'.format(get_item_time))
239 | print('---------------------------')
240 | print('Read Point Cloud Time: {:.4f} s'.format(read_point_cloud_time))
241 | print('Voxelization Time: {:.4f} s'.format(voxelization_time))
242 | print('Read Labels Time: {:.4f} s'.format(read_labels_time))
243 | print('Compute Labels Time: {:.4f} s'.format(compute_label_time))
244 |
245 | return voxel_point_cloud, training_label
246 |
247 | else:
248 | if self.get_image:
249 | return image, voxel_point_cloud, labels, calib
250 | else:
251 | return voxel_point_cloud, labels, calib
252 |
253 |
254 | #################
255 | # load datasets #
256 | #################
257 |
258 | def load_dataset(root='Data/', batch_size=1, train_val_split=0.9, test_set=False,
259 | device=torch.device('cpu'), show_times=False):
260 | """
261 | Create data loaders that read point clouds and the corresponding labels from the dataset directory
262 | :param device: device of the model
263 | :param root: root directory of the dataset
264 | :param batch_size: batch-size for the data loader
265 | :param train_val_split: fraction of the available data used for training
266 | :param test_set: if True, data loader will be generated that contains only a test set
267 | :param show_times: display times for each step of the data loading
268 | :return: torch data loader object
269 | """
270 |
271 | # number of worker processes for data loading (currently 0 for both CPU and GPU)
272 | if device != torch.device('cpu'):
273 | num_workers = 0
274 | else:
275 | num_workers = 0
276 |
277 | # create training and validation set
278 | if not test_set:
279 |
280 | # create customized dataset class
281 | dataset = PointCloudDataset(root_dir=root, device=device, split='training', show_times=show_times)
282 |
283 | # number of images used for training and validation
284 | n_images = dataset.__len__()
285 | n_train = int(train_val_split * n_images)
286 | n_val = n_images - n_train
287 |
288 | # generate training and validation sets
289 | train_dataset, val_dataset = random_split(dataset, [n_train, n_val])
290 |
291 | # create data_loaders
292 | data_loader = {
293 | 'train': DataLoader(train_dataset, batch_size=batch_size, shuffle=True, collate_fn=my_collate_train,
294 | num_workers=num_workers),
295 | 'val': DataLoader(val_dataset, batch_size=batch_size, shuffle=True, collate_fn=my_collate_train,
296 | num_workers=num_workers)
297 | }
298 |
299 | # create test set
300 | else:
301 |
302 | test_dataset = PointCloudDataset(root_dir=root, device=device, split='testing')
303 | data_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, collate_fn=my_collate_test,
304 | num_workers=num_workers, drop_last=True)
305 |
306 | return data_loader
307 |
308 |
309 | if __name__ == '__main__':
310 |
311 | # create data loader
312 | root_dir = 'Data/'
313 | batch_size = 1
314 | device = torch.device('cpu')
315 | data_loader = load_dataset(root=root_dir, batch_size=batch_size, device=device, show_times=True)['train']
316 |
317 | for batch_id, (batch_data, batch_labels) in enumerate(data_loader):
318 | pass
319 |
--------------------------------------------------------------------------------
/kitti_utils.py:
--------------------------------------------------------------------------------
1 | import cv2
2 | from config import *
3 | import copy
4 |
5 | ###################
6 | # 3D Label Object #
7 | ###################
8 |
9 |
10 | class Object3D(object):
11 | def __init__(self, label_file_line):
12 | data = label_file_line.split(' ')
13 | data[1:] = [float(x) for x in data[1:]]
14 |
15 | # extract label, truncation, occlusion
16 | self.type = data[0] # 'Car', 'Pedestrian', ...
17 | self.truncation = data[1] # truncated pixel ratio [0..1]
18 | self.occlusion = int(data[2]) # 0=visible, 1=partly occluded, 2=fully occluded, 3=unknown
19 | self.alpha = data[3] # object observation angle [-pi..pi]
20 |
21 | # extract 2d bounding box in 0-based coordinates
22 | self.xmin = data[4] # left
23 | self.ymin = data[5] # top
24 | self.xmax = data[6] # right
25 | self.ymax = data[7] # bottom
26 | self.box2d = np.array([self.xmin, self.ymin, self.xmax, self.ymax])
27 |
28 | # extract 3d bounding box information
29 | self.height = data[8] # box height
30 | self.width = data[9] # box width
31 | self.length = data[10] # box length (in meters)
32 | self.t = (data[11], data[12], data[13]) # location (x,y,z) in camera coord.
33 | self.ry = data[14] # yaw angle (around Y-axis in camera coordinates) [-pi..pi]
34 |
35 | def print_object(self):
36 | print('Type, truncation, occlusion, alpha: %s, %d, %d, %f' % \
37 | (self.type, self.truncation, self.occlusion, self.alpha))
38 | print('2d bbox (x0,y0,x1,y1): %f, %f, %f, %f' % \
39 | (self.xmin, self.ymin, self.xmax, self.ymax))
40 | print('3d bbox h,w,l: %f, %f, %f' % \
41 | (self.height, self.width, self.length))
42 | print('3d bbox location, ry: (%f, %f, %f), %f' % \
43 | (self.t[0], self.t[1], self.t[2], self.ry))
44 |
45 |
46 | ######################
47 | # Calibration Object #
48 | ######################
49 |
50 |
51 | class Calibration(object):
52 | """
53 | Calibration matrices and utils
54 |
55 | ------------------
56 | coordinate systems
57 | ------------------
58 | 3d XYZ in