├── LICENSE
├── Pandaset.rights
├── README.md
├── SECURITY.md
├── codegenComplexYOLOV4.m
├── complexYOLOv4Predict.m
├── complexYOLOv4PredictionExample.m
├── complexYOLOv4TransferLearn.m
├── createTrainingData.m
├── detectComplexYOLOv4.m
├── images
│   ├── BEVOutput.png
│   ├── LidarOutput.png
│   ├── network.png
│   └── overview.png
├── mishLayer.m
├── models
│   └── .gitkeep
├── pointclouds
│   ├── 0001.pcd
│   ├── 0002.pcd
│   └── 0003.pcd
├── sliceLayer.m
├── src
│   ├── +helper
│   │   ├── applyActivations.m
│   │   ├── applyAnchorBoxOffsets.m
│   │   ├── configureTrainingProgressPlotter.m
│   │   ├── createBatchData.m
│   │   ├── displayLossInfo.m
│   │   ├── downloadPretrainedYOLOv4.m
│   │   ├── extractPredictions.m
│   │   ├── generateTargets.m
│   │   ├── generateTiledAnchors.m
│   │   ├── getAnchors.m
│   │   ├── getClassNames.m
│   │   ├── getGridParameters.m
│   │   ├── piecewiseLearningRateWithWarmup.m
│   │   ├── postprocess.m
│   │   ├── preprocess.m
│   │   ├── preprocessData.m
│   │   ├── removeEmptyData.m
│   │   ├── transferbboxToPointCloud.m
│   │   ├── updatePlots.m
│   │   ├── validateInputData.m
│   │   └── yolov4Forward.m
│   ├── +loss
│   │   ├── bboxOffsetLoss.m
│   │   ├── classConfidenceLoss.m
│   │   └── objectnessLoss.m
│   ├── configureYOLOv4.m
│   └── modelGradients.m
└── test
    ├── tPretrainedYOLOv4.m
    ├── tdownloadPretrainedComplexYOLOv4.m
    ├── tload.m
    └── tools
        ├── DownloadComplexYolov4Fixture.m
        └── getRepoRoot.m
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021, The MathWorks, Inc.
2 | All rights reserved.
3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
5 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
6 | 3. In all cases, the software is, and all modifications and derivatives of the software shall be, licensed to you solely for use in conjunction with MathWorks products and service offerings.
7 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
8 |
--------------------------------------------------------------------------------
/Pandaset.rights:
--------------------------------------------------------------------------------
1 | Panda Set (https://scale.com/open-datasets/pandaset) is provided by Hesai and Scale under the CC-BY-4.0 license (https://creativecommons.org/licenses/by/4.0).
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Complex YOLO v4 Network For Lidar Object Detection
2 |
3 | This repository provides a pretrained Complex YOLO v4 Lidar object detection network for MATLAB®.
4 |
5 | [Open in MATLAB Online](https://matlab.mathworks.com/open/github/v1?repo=matlab-deep-learning/Lidar-object-detection-using-complex-yolov4)
6 |
7 | Requirements
8 | ------------
9 |
10 | - MATLAB® R2021a or later
11 | - Deep Learning Toolbox™
12 | - Lidar Toolbox™
13 |
14 | Overview
15 | --------
16 |
17 | YOLO v4 [1] is a popular single-stage object detector that performs detection and classification using CNNs. This repository uses the Complex-YOLO v4 [2] approach, an efficient method for lidar object detection that operates directly on bird's-eye-view (BEV) RGB maps to estimate and localize accurate 3-D bounding boxes. The bounding boxes detected on the RGB maps are projected back onto the point cloud to generate 3-D bounding boxes. The detection head of the YOLO v4 model is modified to predict the angle regression along with the bounding boxes, objectness score, and classification scores.
18 |
19 | ![Overview](images/overview.png)
20 |
21 | This repository implements two variants of the Complex YOLO v4 object detector:
22 | - **complex-yolov4-pandaset**: Standard complex yolov4 network for accurate object detection.
23 | - **tiny-complex-yolov4-pandaset**: Lightweight complex yolov4 network for faster inference.
24 |
25 | The pretrained networks are trained to detect three object categories: Car, Truck, and Pedestrain. They are trained on the Pandaset dataset, available at https://scale.com/open-datasets/pandaset.
26 |
27 | For another lidar object detection network, see [Lidar 3-D Object Detection Using PointPillars Deep Learning](https://www.mathworks.com/help/lidar/ug/object-detection-using-pointpillars-network.html).
28 |
29 | Getting Started
30 | ---------------
31 |
32 | Download or clone this repository to your machine and open it in MATLAB®.
33 |
34 | ### Setup
35 | Add path to the source directory.
36 |
37 | ```
38 | addpath('src');
39 | ```
40 |
41 | ### Download the pretrained network
42 | Use the helper function below to download a Complex YOLO v4 pretrained model. Set the model name to "complex-yolov4-pandaset" for the standard Complex YOLO v4 network or "tiny-complex-yolov4-pandaset" for the tiny network.
43 |
44 | ```
45 | modelName = 'complex-yolov4-pandaset';
46 | model = helper.downloadPretrainedYOLOv4(modelName);
47 | net = model.net;
48 | ```
49 |
50 | ### Detect Objects Using Pretrained complex YOLO v4
51 |
52 | ```
53 | % Read point cloud.
54 | ptCld = pcread('pointclouds/0001.pcd');
55 |
56 | % Get the configuration parameters.
57 | gridParams = helper.getGridParameters;
58 |
59 | % Get classnames of Pandaset dataset.
60 | classNames = helper.getClassNames;
61 |
62 | % Get the bird's-eye-view RGB map from the point cloud.
63 | [img,ptCldOut] = helper.preprocess(ptCld, gridParams);
64 |
65 | % Get anchors used in training of the pretrained model.
66 | anchors = helper.getAnchors(modelName);
67 |
68 | % Detect objects in test image.
69 | executionEnvironment = 'auto';
70 | [bboxes, scores, labels] = detectComplexYOLOv4(net, img, anchors, classNames, executionEnvironment);
71 |
72 | figure
73 | imshow(img)
74 | showShape('rectangle',bboxes(labels=='Car',:),...
75 | 'Color','green','LineWidth',0.5);hold on;
76 | showShape('rectangle',bboxes(labels=='Truck',:),...
77 | 'Color','magenta','LineWidth',0.5);
78 | showShape('rectangle',bboxes(labels=='Pedestrain',:),...
79 | 'Color','yellow','LineWidth',0.5);
80 | hold off;
81 | ```
82 | ### Transfer Bounding Boxes to Point Cloud
83 |
84 | ```
85 | % Transfer labels to point cloud.
86 | bboxCuboid = helper.transferbboxToPointCloud(bboxes,gridParams,ptCldOut);
87 |
88 | figure
89 | pcshow(ptCldOut.Location);
90 | showShape('cuboid',bboxCuboid(labels=='Car',:),...
91 | 'Color','green','LineWidth',0.5);hold on;
92 | showShape('cuboid',bboxCuboid(labels=='Truck',:),...
93 | 'Color','magenta','LineWidth',0.5);
94 | showShape('cuboid',bboxCuboid(labels=='Pedestrain',:),...
95 | 'Color','yellow','LineWidth',0.5);
96 | hold off;
97 | ```
98 |
99 | ### Results
100 | The left image shows the network output on the bird's-eye-view image, and the right image shows those detections transferred onto the point cloud. The images are generated from the [Panda Set](https://scale.com/open-datasets/pandaset) dataset.
101 |
102 | ![BEV Output](images/BEVOutput.png)
103 | ![Lidar Output](images/LidarOutput.png)
104 | 
105 | 
106 | 
107 | 
108 |
109 | Train Custom Complex YOLO v4 Using Transfer Learning
110 | ----------------------------------------------------
111 | Run the `createTrainingData.m` script to download the Pandaset dataset and create the BEV images and ground-truth labels from the lidar data used to train the Complex YOLO v4 network.
112 |
113 | Transfer learning enables you to adapt a pretrained Complex YOLO v4 network to your dataset. Create a custom Complex YOLO v4 network for transfer learning with a new set of classes and train it using the `complexYOLOv4TransferLearn.m` script, as sketched below.
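A minimal sketch of the workflow (both scripts live in the repository root; `createTrainingData.m` writes the BEV images and labels to `tempdir`):

```
addpath('src');

% Download Pandaset and create the BEV images and ground-truth labels.
createTrainingData;

% Estimate anchors, configure the network, and run the custom training loop.
complexYOLOv4TransferLearn;
```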
114 |
115 | Code Generation for Complex YOLO v4
116 | -----------------------------------
117 | Code generation enables you to generate code and deploy complex YOLO v4 on multiple embedded platforms.
118 |
119 | Run `codegenComplexYOLOV4.m`. This script calls the `complexYOLOv4Predict.m` entry-point function and generates CUDA code for the complex-yolov4-pandaset or tiny-complex-yolov4-pandaset models. It then runs the generated MEX and displays the output.
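The core codegen configuration used by the script (a condensed sketch; see `codegenComplexYOLOV4.m` for the full version, including how `matFile` and the input image `I` are prepared):

```
cfg = coder.gpuConfig('mex');
cfg.TargetLang = 'C++';
cfg.DeepLearningConfig = coder.DeepLearningConfig('cudnn');
args = {coder.Constant(matFile), single(I)};
codegen -config cfg complexYOLOv4Predict -args args -report
```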
120 |
121 | | Model | Speed (FPS) with Codegen | Speed (FPS) without Codegen |
122 | | ------ | ------ | ------ |
123 | | complex-yolov4-pandaset | 14.663 | 1.4738 |
124 | | tiny-complex-yolov4-pandaset | 44.248 | 9.93 |
125 |
126 | - Performance (in FPS) is measured on a TITAN-RTX GPU.
127 |
128 | For more information about code generation, see [Deep Learning with GPU Coder](https://www.mathworks.com/help/gpucoder/gpucoder-deep-learning.html).
129 |
130 | Complex YOLO v4 Network Details
131 | -------------------------------
132 | The Complex YOLO v4 network architecture comprises three sections: the backbone, the neck, and the detection head.
133 |
134 | 
135 |
136 | - **Backbone:** CSP-Darknet53 (Cross-Stage-Partial Darknet53) is used as the backbone for YOLO v4 networks. It is a model with a higher input resolution (608 x 608), a larger receptive field (725 x 725), more 3 x 3 convolutional layers, and more parameters. A larger receptive field helps the network view entire objects in an image and understand the context around them, while a higher input resolution helps in detecting small objects. Hence, CSP-Darknet53 is a suitable backbone for detecting multiple objects of different sizes in a single image.
137 |
138 | - **Neck:** This section comprises many bottom-up and top-down aggregation paths. It further increases the receptive field of the network and separates out the most significant context features with almost no reduction in network speed. SPP (Spatial Pyramid Pooling) blocks are added as the neck section over the CSP-Darknet53 backbone, and PANet (Path Aggregation Network) is used to aggregate parameters from different backbone levels for the different detector levels.
139 |
140 | - **Detection Head**: This section processes the aggregated features from the neck and predicts the bounding boxes, angle regression, objectness score, and classification scores. In addition to the bounding boxes predicted by conventional YOLO v4, the network predicts an angle regression that estimates the heading of each 3-D box.
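To examine these sections in the pretrained model, you can inspect the downloaded `dlnetwork` (a quick sketch, assuming `net` is the network loaded in the Getting Started section):

```
% List the detection head output layers and open the network analyzer.
disp(net.OutputNames)
analyzeNetwork(net)
```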
141 |
142 | References
143 | -----------
144 |
145 | [1] Bochkovskiy, Alexey, et al. “YOLOv4: Optimal Speed and Accuracy of Object Detection.” ArXiv:2004.10934 [Cs, Eess], Apr. 2020. arXiv.org, http://arxiv.org/abs/2004.10934.
146 |
147 | [2] Simon, Martin, et al. “Complex-YOLO: Real-Time 3D Object Detection on Point Clouds.” ArXiv:1803.06199 [Cs], Sept. 2018. arXiv.org, http://arxiv.org/abs/1803.06199.
148 |
149 | Copyright 2021 The MathWorks, Inc.
150 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Reporting Security Vulnerabilities
2 |
3 | If you believe you have discovered a security vulnerability, please report it to
4 | [security@mathworks.com](mailto:security@mathworks.com). Please see
5 | [MathWorks Vulnerability Disclosure Policy for Security Researchers](https://www.mathworks.com/company/aboutus/policies_statements/vulnerability-disclosure-policy.html)
6 | for additional information.
7 |
--------------------------------------------------------------------------------
/codegenComplexYOLOV4.m:
--------------------------------------------------------------------------------
1 | %% Code Generation for Complex YOLO v4
2 | % The following code demonstrates code generation for a pretrained Complex
3 | % YOLO v4 object detection network, trained on the Pandaset dataset.
4 |
5 | %% Setup
6 | % Add path to the source directory.
7 | addpath('src');
8 |
9 | %% Download the Pretrained Network
10 | % This repository uses two variants of YOLO v4 models.
11 | % *complex-yolov4-pandaset*
12 | % *tiny-complex-yolov4-pandaset*
13 | % Set modelName to one of the above to download that pretrained model.
14 | modelName = 'complex-yolov4-pandaset';
15 | model = helper.downloadPretrainedYOLOv4(modelName);
16 | net = model.net;
17 |
18 | %% Read and Preprocess Input Point Cloud.
19 | % Read point cloud.
20 | ptCld = pcread('pointclouds/0001.pcd');
21 |
22 | % Get the configuration parameters.
23 | gridParams = helper.getGridParameters();
24 |
25 | % Preprocess the input.
26 | [I,ptCldOut] = helper.preprocess(ptCld, gridParams);
27 |
28 | imgSize = [size(I,1),size(I,2)];
29 | inputSize = net.Layers(1).InputSize;
30 | scale = imgSize./inputSize(1:2);
31 |
32 | % Provide location of the mat file of the trained network.
33 | matFile = sprintf('models/complex-yolov4-models-master/models/complex-yolov4-pandaset/%s.mat',modelName);
34 |
35 | %% Run MEX code generation
36 | % complexYOLOv4Predict.m is the entry-point function that takes an input
37 | % image and returns predictions for the complex-yolov4-pandaset or
38 | % tiny-complex-yolov4-pandaset models. The function uses a persistent
39 | % object yolov4Obj to load the dlnetwork object and reuses that persistent
40 | % object for prediction on subsequent calls.
41 | %
42 | % To generate CUDA code for the entry-point function, create a GPU code
43 | % configuration object for a MEX target and set the target language to C++.
44 | %
45 | % Use the coder.DeepLearningConfig (GPU Coder) function to create a CuDNN
46 | % deep learning configuration object and assign it to the DeepLearningConfig
47 | % property of the GPU code configuration object.
48 | %
49 | % Run the codegen command.
50 | cfg = coder.gpuConfig('mex');
51 | cfg.TargetLang = 'C++';
52 | cfg.DeepLearningConfig = coder.DeepLearningConfig('cudnn');
53 | args = {coder.Constant(matFile), single(I)};
54 | codegen -config cfg complexYOLOv4Predict -args args -report
55 |
56 | %% Run Generated MEX
57 | % Call the generated MEX function on the input image.
58 | outFeatureMaps = complexYOLOv4Predict_mex(matFile,single(I));
59 |
60 | % Get classnames of Pandaset dataset.
61 | classNames = helper.getClassNames;
62 |
63 | % Get anchors used in training of the pretrained model.
64 | anchors = helper.getAnchors(modelName);
65 |
66 | % Visualize detection results.
67 | [bboxes,scores,labels] = helper.postprocess(outFeatureMaps, anchors, inputSize, scale, classNames);
68 |
69 | figure
70 | imshow(I)
71 | showShape('rectangle',bboxes(labels=='Car',:),...
72 | 'Color','green','LineWidth',0.5);hold on;
73 | showShape('rectangle',bboxes(labels=='Truck',:),...
74 | 'Color','magenta','LineWidth',0.5);
75 | showShape('rectangle',bboxes(labels=='Pedestrain',:),...
76 | 'Color','yellow','LineWidth',0.5);
77 | hold off;
78 |
79 | %% Project the bounding boxes to the point cloud
80 | % Transfer labels to point cloud.
81 | bboxCuboid = helper.transferbboxToPointCloud(bboxes,gridParams,ptCldOut);
82 |
83 | figure
84 | pcshow(ptCldOut.Location);
85 | showShape('cuboid',bboxCuboid(labels=='Car',:),...
86 | 'Color','green','LineWidth',0.5);hold on;
87 | showShape('cuboid',bboxCuboid(labels=='Truck',:),...
88 | 'Color','magenta','LineWidth',0.5);
89 | showShape('cuboid',bboxCuboid(labels=='Pedestrain',:),...
90 | 'Color','yellow','LineWidth',0.5);
91 | hold off;
92 |
93 | % Copyright 2021 The MathWorks, Inc.
94 |
--------------------------------------------------------------------------------
/complexYOLOv4Predict.m:
--------------------------------------------------------------------------------
1 | function out = complexYOLOv4Predict(matFile, image)
2 | %#codegen
3 | % Copyright 2021 The MathWorks, Inc.
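%
% A minimal usage sketch (mirrors codegenComplexYOLOV4.m; matFile points to the
% MAT-file of the pretrained network and I is a preprocessed BEV image):
%   out = complexYOLOv4Predict(matFile, single(I));        % direct MATLAB call
%   out = complexYOLOv4Predict_mex(matFile, single(I));    % after generating the MEX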
4 |
5 | % Convert input to dlarray.
6 | dlInput = dlarray(image, 'SSCB');
7 |
8 | persistent yolov4Obj;
9 |
10 | if isempty(yolov4Obj)
11 | yolov4Obj = coder.loadDeepLearningNetwork(matFile);
12 | end
13 |
14 | % Pass input.
15 | out = cell(size(yolov4Obj.OutputNames,2),1);
16 | [out{:}] = yolov4Obj.predict(dlInput);
17 |
--------------------------------------------------------------------------------
/complexYOLOv4PredictionExample.m:
--------------------------------------------------------------------------------
1 | %% Object Detection Using Complex Pretrained YOLO v4 Network
2 | % The following code demonstrates running object detection on point clouds
3 | % using a pretrained Complex YOLO v4 network, trained on Pandaset dataset.
4 |
5 | %% Prerequisites
6 | % To run this example you need the following prerequisites -
7 | %
8 | % # MATLAB (R2021a or later) with Lidar Toolbox and Deep Learning Toolbox.
9 | % # Pretrained Complex YOLOv4 network (download instructions below).
10 |
11 | %% Setup
12 | % Add path to the source directory.
13 | addpath('src');
14 |
15 | %% Download the pre-trained network
16 | % This repository uses two variants of Complex YOLO v4 models.
17 | % *complex-yolov4-pandaset*
18 | % *tiny-complex-yolov4-pandaset*
19 | % Set modelName to one of the above to download that pretrained model.
20 | modelName = 'complex-yolov4-pandaset';
21 | model = helper.downloadPretrainedYOLOv4(modelName);
22 | net = model.net;
23 |
24 | %% Detect Objects using Complex YOLO v4 Object Detector
25 | % Read point cloud.
26 | ptCld = pcread('pointclouds/0001.pcd');
27 |
28 | % Get the configuration parameters.
29 | gridParams = helper.getGridParameters;
30 |
31 | % Get classnames of Pandaset dataset.
32 | classNames = helper.getClassNames;
33 |
34 | % Get the bird's-eye-view RGB map from the point cloud.
35 | [img,ptCldOut] = helper.preprocess(ptCld, gridParams);
36 |
37 | % Get anchors used in training of the pretrained model.
38 | anchors = helper.getAnchors(modelName);
39 |
40 | % Detect objects in test image.
41 | executionEnvironment = 'auto';
42 | [bboxes, scores, labels] = detectComplexYOLOv4(net, img, anchors, classNames, executionEnvironment);
43 |
44 | % Display the results on an image.
45 | figure
46 | imshow(img)
47 | showShape('rectangle',bboxes(labels=='Car',:),...
48 | 'Color','green','LineWidth',0.5);hold on;
49 | showShape('rectangle',bboxes(labels=='Truck',:),...
50 | 'Color','magenta','LineWidth',0.5);
51 | showShape('rectangle',bboxes(labels=='Pedestrain',:),...
52 | 'Color','yellow','LineWidth',0.5);
53 | hold off;
54 |
55 | %% Project the bounding boxes to the point cloud
56 | % Transfer labels to point cloud.
57 | bboxCuboid = helper.transferbboxToPointCloud(bboxes,gridParams,ptCldOut);
58 |
59 | % Display the results on point cloud.
60 | figure
61 | pcshow(ptCldOut.Location);
62 | showShape('cuboid',bboxCuboid(labels=='Car',:),...
63 | 'Color','green','LineWidth',0.5);hold on;
64 | showShape('cuboid',bboxCuboid(labels=='Truck',:),...
65 | 'Color','magenta','LineWidth',0.5);
66 | showShape('cuboid',bboxCuboid(labels=='Pedestrain',:),...
67 | 'Color','yellow','LineWidth',0.5);
68 | hold off;
69 |
70 |
71 | % Copyright 2021 The MathWorks, Inc.
72 |
--------------------------------------------------------------------------------
/complexYOLOv4TransferLearn.m:
--------------------------------------------------------------------------------
1 | %% Transfer Learning Using Pretrained Complex YOLO v4 Network
2 | % The following code demonstrates how to perform transfer learning using
3 | % the pretrained Complex YOLO v4 network for object detection. This script
4 | % uses the "configureYOLOv4" function to create a custom Complex YOLO v4
5 | % network using the pretrained model.
6 |
7 | %% Setup
8 | % Add path to the source directory.
9 | addpath('src');
10 |
11 | %% Download Pretrained Network
12 | % This repository uses two variants of YOLO v4 models.
13 | % *complex-yolov4-pandaset*
14 | % *tiny-complex-yolov4-pandaset*
15 | % Set modelName to one of the above to download that pretrained model.
16 | modelName = 'tiny-complex-yolov4-pandaset';
17 | model = helper.downloadPretrainedYOLOv4(modelName);
18 | net = model.net;
19 |
20 | %% Load Data
21 | % Create a datastore for loading the BEV images.
22 | imageFileLocation = fullfile(tempdir,'Pandaset','BEVImages');
23 | imds = imageDatastore(imageFileLocation);
24 |
25 | % Create a datastore for loading the ground truth bounding boxes.
26 | boxLabelLocation = fullfile(tempdir,'Pandaset','Cuboids','BEVGroundTruthLabels.mat');
27 | load(boxLabelLocation,'processedLabels');
28 | bds = boxLabelDatastore(processedLabels);
29 |
30 | % Remove data with zero labels from the training data.
31 | [imds,bds] = helper.removeEmptyData(imds,bds);
32 |
33 | % Split the data set into a training set for training the network and a test
34 | % set for evaluating the network. Use 60% of the data for the training set
35 | % and the rest for the test set.
36 | rng(0);
37 | shuffledIndices = randperm(size(imds.Files,1));
38 | idx = floor(0.6 * length(shuffledIndices));
39 |
40 | % Split the image datastore into training and test set.
41 | imdsTrain = subset(imds,shuffledIndices(1:idx));
42 | imdsTest = subset(imds,shuffledIndices(idx+1:end));
43 |
44 | % Split the box label datastore into training and test set.
45 | bdsTrain = subset(bds,shuffledIndices(1:idx));
46 | bdsTest = subset(bds,shuffledIndices(idx+1:end));
47 |
48 | % Combine the image and box label datastore.
49 | trainingData = combine(imdsTrain,bdsTrain);
50 | testData = combine(imdsTest,bdsTest);
51 |
52 | helper.validateInputData(trainingData);
53 | helper.validateInputData(testData);
54 |
55 | %% Preprocess Training Data
56 | % Specify the network input size.
57 | networkInputSize = net.Layers(1).InputSize;
58 |
59 | % Preprocess the training data to prepare for training. The preprocessData
60 | % helper function, located in the src/+helper folder, applies the following
61 | % preprocessing operations to the input data.
62 | %
63 | % * Resize the images to the network input size
64 | % * Scale the image pixels in the range |[0 1]|.
65 | preprocessedTrainingData = transform(trainingData, @(data)helper.preprocessData(data, networkInputSize, 1));
66 |
67 | % Read the preprocessed training data.
68 | data = read(preprocessedTrainingData);
69 |
70 | % Display the image with the bounding boxes.
71 | I = data{1,1};
72 | bbox = data{1,2};
73 | labels = data{1,3};
74 |
75 | figure
76 | imshow(I)
77 | showShape('rectangle',bbox(labels=='Car',:),...
78 | 'Color','green','LineWidth',0.5);hold on;
79 | showShape('rectangle',bbox(labels=='Truck',:),...
80 | 'Color','magenta','LineWidth',0.5);
81 | showShape('rectangle',bbox(labels=='Pedestrain',:),...
82 | 'Color','yellow','LineWidth',0.5);
83 |
84 | % Reset the datastore.
85 | reset(preprocessedTrainingData);
86 |
87 | %% Modify Pretrained Complex YOLO v4 Network
88 | % The Complex YOLO v4 network uses anchor boxes estimated using training
89 | % data to have better initial priors corresponding to the type of data set
90 | % and to help the network learn to predict the boxes accurately.
91 | %
92 | % First, use transform to preprocess the training data for computing the
93 | % anchor boxes, as the training images used in this example vary in size.
94 | %
95 | % Specify the number of anchors as follows:
96 | % 'complex-yolov4-pandaset' model - 9
97 | % 'tiny-complex-yolov4-pandaset' model - 6
98 | %
99 | % Use the estimateAnchorBoxes function to estimate the anchor boxes. Note
100 | % that the estimation process is not deterministic. To prevent the
101 | % estimated anchor boxes from changing while tuning other hyperparameters
102 | % set the random seed prior to estimation using rng.
103 | %
104 | % Then pass the estimated anchor boxes to configureYOLOv4 function to
105 | % arrange them in correct order to be used in the training.
106 | rng(0)
107 | trainingDataForEstimation = transform(trainingData, @(data)helper.preprocessData(data, networkInputSize, 0));
108 | numAnchors = 6;
109 | [anchorBoxes, meanIoU] = estimateAnchorBoxes(trainingDataForEstimation, numAnchors);
110 |
111 | % Specify the classNames to be used for training.
112 | classNames = helper.getClassNames;
113 |
114 | % Configure the pretrained model for transfer learning using the
115 | % configureYOLOv4 function. This function returns the modified layer graph,
116 | % network output names, reordered anchorBoxes, and anchorBoxMasks used to
117 | % select the anchor boxes for each detection head.
118 | [lgraph, networkOutputs, anchorBoxes, anchorBoxMasks] = configureYOLOv4(net, classNames, anchorBoxes, modelName);
119 | anchors.anchorBoxes = anchorBoxes;
120 | anchors.anchorBoxMasks = anchorBoxMasks;
121 | %% Specify Training Options
122 | % Specify these training options.
123 | %
124 | % * Set the number of epochs to be 90.
125 | % * Set the mini-batch size as 8. Stable training is possible with higher
126 | % learning rates when a larger mini-batch size is used, although this should
127 | % be set depending on the available memory.
128 | % * Set the learning rate to 0.001.
129 | % * Set the warmup period as 1000 iterations. It helps in stabilizing the
130 | % gradients at higher learning rates.
131 | % * Set the L2 regularization factor to 0.001.
132 | % * Specify the penalty threshold as 0.5. Detections that overlap less than
133 | % 0.5 with the ground truth are penalized.
134 | % * Initialize the velocity of gradient as []. This is used by SGDM to store
135 | % the velocity of gradients.
136 | numEpochs = 90;
137 | miniBatchSize = 8;
138 | learningRate = 0.001;
139 | warmupPeriod = 1000;
140 | l2Regularization = 0.001;
141 | penaltyThreshold = 0.5;
142 | velocity = [];
143 |
144 | %% Train Model
145 | % Train on a GPU, if one is available. Using a GPU requires Parallel
146 | % Computing Toolbox™ and a CUDA® enabled NVIDIA® GPU.
147 | %
148 | % Use the minibatchqueue function to split the preprocessed training data
149 | % into batches with the supporting function createBatchData which returns
150 | % the batched images and bounding boxes combined with the respective class
151 | % IDs. For faster extraction of the batch data for training, set
152 | % dispatchInBackground to true, which preprocesses the data in the background
153 | % using a parallel pool.
154 | %
155 | % minibatchqueue automatically detects the availability of a GPU. If you do
156 | % not have a GPU, or do not want to use one for training, set the
157 | % OutputEnvironment parameter to "cpu".
158 | if canUseParallelPool
159 | dispatchInBackground = true;
160 | else
161 | dispatchInBackground = false;
162 | end
163 |
164 | mbqTrain = minibatchqueue(preprocessedTrainingData, 2,...
165 | "MiniBatchSize", miniBatchSize,...
166 | "MiniBatchFcn", @(images, boxes, labels) helper.createBatchData(images, boxes, labels, classNames), ...
167 | "MiniBatchFormat", ["SSCB", ""],...
168 | "DispatchInBackground", dispatchInBackground,...
169 | "OutputCast", ["", "double"]);
170 |
171 | % To train the network with a custom training loop and enable automatic
172 | % differentiation, convert the layer graph to a dlnetwork object. Then
173 | % create the training progress plotter using supporting function
174 | % configureTrainingProgressPlotter.
175 | %
176 | % Finally, specify the custom training loop. For each iteration:
177 | %
178 | % * Read data from the minibatchqueue. If it doesn't have any more data,
179 | % reset the minibatchqueue and shuffle.
180 | % * Evaluate the model gradients using dlfeval and the modelGradients
181 | % function. The function modelGradients, listed as a supporting function,
182 | % returns the gradients of the loss with respect to the learnable
183 | % parameters in net, the corresponding mini-batch loss, and the state of
184 | % the current batch.
185 | % * Apply a weight decay factor to the gradients for regularization and more
186 | % robust training.
187 | % * Determine the learning rate based on the iterations using the
188 | % piecewiseLearningRateWithWarmup supporting function.
189 | % * Update the network parameters using the sgdmupdate function.
190 | % * Update the state parameters of net with the moving average.
191 | % * Display the learning rate, total loss, and the individual losses (box
192 | % loss, object loss and class loss) for every iteration. These can be used
193 | % to interpret how the respective losses are changing in each iteration.
194 | % For example, a sudden spike in the box loss after a few iterations implies
195 | % that there are Inf or NaNs in the predictions.
196 | % * Update the training progress plot.
197 |
198 | % The training can also be terminated if the loss has saturated for a few
199 | % epochs.
200 |
201 | % Convert layer graph to dlnetwork.
202 | net = dlnetwork(lgraph);
203 |
204 | % Create subplots for the learning rate and mini-batch loss.
205 | fig = figure;
206 | [lossPlotter, learningRatePlotter] = helper.configureTrainingProgressPlotter(fig);
207 |
208 | iteration = 0;
209 | % Custom training loop.
210 | for epoch = 1:numEpochs
211 |
212 | reset(mbqTrain);
213 | shuffle(mbqTrain);
214 |
215 | while(hasdata(mbqTrain))
216 | iteration = iteration + 1;
217 |
218 | [XTrain, YTrain] = next(mbqTrain);
219 |
220 | % Evaluate the model gradients and loss using dlfeval and the
221 | % modelGradients function.
222 | [gradients, state, lossInfo] = dlfeval(@modelGradients, net, XTrain, YTrain, anchorBoxes, anchorBoxMasks, penaltyThreshold, networkOutputs);
223 |
224 | % Apply L2 regularization.
225 | gradients = dlupdate(@(g,w) g + l2Regularization*w, gradients, net.Learnables);
226 |
227 | % Determine the current learning rate value.
228 | currentLR = helper.piecewiseLearningRateWithWarmup(iteration, epoch, learningRate, warmupPeriod, numEpochs);
229 |
230 | % Update the network learnable parameters using the SGDM optimizer.
231 | [net, velocity] = sgdmupdate(net, gradients, velocity, currentLR);
232 |
233 | % Update the state parameters of dlnetwork.
234 | net.State = state;
235 |
236 | % Display progress.
237 | if mod(iteration,10)==1
238 | helper.displayLossInfo(epoch, iteration, currentLR, lossInfo);
239 | end
240 |
241 | % Update training plot with new points.
242 | helper.updatePlots(lossPlotter, learningRatePlotter, iteration, currentLR, lossInfo.totalLoss);
243 | end
244 | end
245 |
246 | %% Evaluate Model
247 | % Computer Vision Toolbox provides object detector evaluation
248 | % functions to measure common metrics such as average orientation
249 | % similarity (evaluateDetectionAOS). The average orientation similarity
250 | % provides a single number that incorporates the ability of the detector to
251 | % make correct classifications (precision) and the ability of the detector
252 | % to find all relevant objects (recall).
253 | %
254 | % Follow these steps to evaluate the trained dlnetwork object net on test data.
255 | %
256 | % * Specify the confidence threshold as 0.5 to keep only detections with confidence
257 | % scores above this value.
258 | % * Specify the overlap threshold as 0.5 to remove overlapping detections.
259 | % * Apply the same preprocessing transform to the test data as for the training
260 | % data. Note that data augmentation is not applied to the test data. Test data
261 | % must be representative of the original data and be left unmodified for unbiased
262 | % evaluation.
263 | % * Collect the detection results by running the detector on testData.
264 | % Use the supporting function detectComplexYOLOv4 to get the bounding boxes, object
265 | % confidence scores, and class labels.
266 | % * Call evaluateDetectionAOS with predicted results and preprocessedTestData
267 | % as arguments.
268 | confidenceThreshold = 0.5;
269 | overlapThreshold = 0.5;
270 |
271 | % Create a table to hold the bounding boxes, scores, and labels returned by
272 | % the detector.
273 | results = table('Size', [0 3], ...
274 | 'VariableTypes', {'cell','cell','cell'}, ...
275 | 'VariableNames', {'Boxes','Scores','Labels'});
276 |
277 | % Run detector on images in the test set and collect results.
278 | reset(testData)
279 | while hasdata(testData)
280 | % Read the datastore and get the image.
281 | data = read(testData);
282 | image = data{1,1};
283 |
284 | % Run the detector.
285 | executionEnvironment = 'auto';
286 | [bboxes, scores, labels] = detectComplexYOLOv4(net, image, anchors, classNames, executionEnvironment);
287 |
288 | % Collect the results.
289 | tbl = table({bboxes}, {scores}, {labels}, 'VariableNames', {'Boxes','Scores','Labels'});
290 | results = [results; tbl];
291 | end
292 |
293 | % Evaluate the object detector using the average orientation similarity (AOS) metric.
294 | metrics = evaluateDetectionAOS(results, testData);
295 |
296 | %% Detect Objects Using Trained Complex YOLO v4
297 | % Use the network for object detection.
298 | %
299 | % * Read an image.
300 | % * Convert the image to a dlarray and use a GPU if one is available.
301 | % * Use the supporting function detectComplexYOLOv4 to get the predicted
302 | % bounding boxes, confidence scores, and class labels.
303 | % * Display the image with bounding boxes and confidence scores.
304 |
305 | % Read the datastore.
306 | reset(testData)
307 | data = read(testData);
308 |
309 | % Get the image.
310 | I = data{1,1};
311 |
312 | % Run the detector.
313 | executionEnvironment = 'auto';
314 | [bboxes, scores, labels] = detectComplexYOLOv4(net, I, anchors, classNames, executionEnvironment);
315 |
316 | figure
317 | imshow(I)
318 | showShape('rectangle',bboxes(labels=='Car',:),...
319 | 'Color','green','LineWidth',0.5);hold on;
320 | showShape('rectangle',bboxes(labels=='Truck',:),...
321 | 'Color','magenta','LineWidth',0.5);
322 | showShape('rectangle',bboxes(labels=='Pedestrain',:),...
323 | 'Color','yellow','LineWidth',0.5);
324 | %% References
325 | % 1. Bochkovskiy, Alexey, Chien-Yao Wang, and Hong-Yuan Mark Liao. "YOLOv4:
326 | % Optimal Speed and Accuracy of Object Detection." arXiv preprint arXiv:2004.10934
327 | % (2020).
328 | %
329 | % Copyright 2021 The MathWorks, Inc.
--------------------------------------------------------------------------------
/createTrainingData.m:
--------------------------------------------------------------------------------
1 | %% Download Pandaset Data Set
2 | % This example uses a subset of PandaSet that contains 2560
3 | % preprocessed organized point clouds. Each point cloud is specified as a
4 | % 64-by-1856 matrix. The corresponding ground truth contains the semantic
5 | % segmentation labels for 12 classes. The point clouds are stored in PCD
6 | % format, and the ground truth data is stored in PNG format. The size of
7 | % the data set is 5.2 GB. Execute this code to download the data set.
8 |
9 | url = 'https://ssd.mathworks.com/supportfiles/lidar/data/Pandaset_LidarData.tar.gz';
10 | outputFolder = fullfile(tempdir,'Pandaset');
11 |
12 | lidarDataTarFile = fullfile(outputFolder,'Pandaset_LidarData.tar.gz');
13 | if ~exist(lidarDataTarFile, 'file')
14 | mkdir(outputFolder);
15 | disp('Downloading Pandaset Lidar driving data (5.2 GB)...');
16 | websave(lidarDataTarFile, url);
17 | untar(lidarDataTarFile,outputFolder);
18 | end
19 |
20 | % Check if tar.gz file is downloaded, but not uncompressed.
21 | if (~exist(fullfile(outputFolder,'Lidar'), 'file'))...
22 | &&(~exist(fullfile(outputFolder,'Cuboids'), 'file'))
23 | untar(lidarDataTarFile,outputFolder);
24 | end
25 |
26 | lidarFolder = fullfile(outputFolder,'Lidar');
27 | labelsFolder = fullfile(outputFolder,'Cuboids');
28 |
29 | % Note: Depending on your Internet connection, the download process can
30 | % take some time. The code suspends MATLAB® execution until the download
31 | % process is complete. Alternatively, you can download the data set to your
32 | % local disk using your web browser, and then extract Pandaset_LidarData
33 | % folder. To use the file you downloaded from the web, change the
34 | % outputFolder variable in the code to the location of the downloaded file.
35 |
36 | %% Create the Bird's Eye View Image from the point cloud data
37 |
38 | % Read the ground truth labels.
39 | gtPath = fullfile(labelsFolder,'PandaSetLidarGroundTruth.mat');
40 | data = load(gtPath,'lidarGtLabels');
41 | Labels = timetable2table(data.lidarGtLabels);
42 | boxLabels = Labels(:,2:end);
43 |
44 | % Get the configuration parameters.
45 | gridParams = helper.getGridParameters();
46 |
47 | % Get classnames of Pandaset dataset.
48 | classNames = boxLabels.Properties.VariableNames;
49 |
50 | numFiles = size(boxLabels,1);
51 | processedLabels = cell(size(boxLabels));
52 |
53 | for i = 1:numFiles
54 |
55 | lidarPath = fullfile(lidarFolder,sprintf('%04d.pcd',i));
56 | ptCloud = pcread(lidarPath);
57 |
58 | groundTruth = boxLabels(i,:);
59 |
60 | [processedData,~] = helper.preprocess(ptCloud,gridParams);
61 |
62 | for ii = 1:numel(classNames)
63 |
64 | labels = groundTruth(1,classNames{ii}).Variables;
65 | if(iscell(labels))
66 | labels = labels{1};
67 | end
68 | if ~isempty(labels)
69 |
70 | % Get the label indices that are in the selected RoI.
72 | labelsIndices = labels(:,1) - labels(:,4) > gridParams.xMin ...
73 | & labels(:,1) + labels(:,4) < gridParams.xMax ...
74 | & labels(:,2) - labels(:,5) > gridParams.yMin ...
75 | & labels(:,2) + labels(:,5) < gridParams.yMax ...
76 | & labels(:,4) > 0 ...
77 | & labels(:,5) > 0 ...
78 | & labels(:,6) > 0;
79 | labels = labels(labelsIndices,:);
80 |
81 |                 labelsBEV = labels(:,[2,1,5,4,9]); % [yCenter xCenter yLength xLength zRotation]
82 |                 labelsBEV(:,5) = -labelsBEV(:,5);  % Negate the yaw angle for the BEV convention.
83 |                 % Convert box centers and sizes from meters to BEV pixel coordinates.
84 | labelsBEV(:,1) = int32(floor(labelsBEV(:,1)/gridParams.gridW)) + 1;
85 | labelsBEV(:,2) = int32(floor(labelsBEV(:,2)/gridParams.gridH)+gridParams.bevHeight/2) + 1;
86 |
87 | labelsBEV(:,3) = int32(floor(labelsBEV(:,3)/gridParams.gridW)) + 1;
88 | labelsBEV(:,4) = int32(floor(labelsBEV(:,4)/gridParams.gridH)) + 1;
89 |             else, labelsBEV = zeros(0,5); % No boxes of this class lie in the RoI.
90 | end
91 | processedLabels{i,ii} = labelsBEV;
92 | end
93 |
94 | writePath = fullfile(outputFolder,'BEVImages');
95 | if ~isfolder(writePath)
96 | mkdir(writePath);
97 | end
98 |
99 | imgSavePath = fullfile(writePath,sprintf('%04d.jpg',i));
100 | imwrite(processedData,imgSavePath);
101 |
102 | end
103 |
104 | processedLabels = cell2table(processedLabels);
105 | numClasses = size(processedLabels,2);
106 | for j = 1:numClasses
107 | processedLabels.Properties.VariableNames{j} = classNames{j};
108 | end
109 |
110 | labelsSavePath = fullfile(outputFolder,'Cuboids/BEVGroundTruthLabels.mat');
111 | save(labelsSavePath,'processedLabels');
112 |
--------------------------------------------------------------------------------
/detectComplexYOLOv4.m:
--------------------------------------------------------------------------------
1 | function [bboxes, scores, labels] = detectComplexYOLOv4(dlnet, image, anchors, classNames, executionEnvironment)
2 | % detectComplexYOLOv4 runs prediction on a trained complex yolov4 network.
3 | %
4 | % Inputs:
5 | % dlnet - Pretrained complex yolov4 dlnetwork.
6 | % image - BEV image to run prediction on. (H x W x 3)
7 | % anchors - Anchors used in training of the pretrained model.
8 | % classNames - Classnames to be used in detection.
9 | % executionEnvironment - Environment to run predictions on. Specify cpu,
10 | % gpu, or auto.
11 | %
12 | % Outputs:
13 | % bboxes - Final bounding box detections ([x y w h rot]) formatted as
14 | % NumDetections x 5.
15 | % scores - NumDetections x 1 classification scores.
16 | % labels - NumDetections x 1 categorical class labels.
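%
% Example (a minimal sketch; assumes the pretrained model has already been
% downloaded with helper.downloadPretrainedYOLOv4 and that img is a BEV RGB
% map produced by helper.preprocess):
%   modelName  = 'complex-yolov4-pandaset';
%   model      = helper.downloadPretrainedYOLOv4(modelName);
%   anchors    = helper.getAnchors(modelName);
%   classNames = helper.getClassNames;
%   [bboxes, scores, labels] = detectComplexYOLOv4(model.net, img, anchors, classNames, 'auto');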
17 |
18 | % Copyright 2021 The MathWorks, Inc.
19 |
20 | % Get the input size of the network.
21 | inputSize = dlnet.Layers(1).InputSize;
22 |
23 | % Process the input image.
24 | imgSize = [size(image,1),size(image,2)];
25 | image = im2single(imresize(image,inputSize(:,1:2)));
26 | scale = imgSize./inputSize(1:2);
27 |
28 | % Convert to dlarray.
29 | dlInput = dlarray(image, 'SSCB');
30 |
31 | % If GPU is available, then convert data to gpuArray.
32 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
33 | dlInput = gpuArray(dlInput);
34 | end
35 |
36 | % Perform prediction on the input image.
37 | outFeatureMaps = cell(length(dlnet.OutputNames), 1);
38 | [outFeatureMaps{:}] = predict(dlnet, dlInput);
39 |
40 | % Apply postprocessing on the output feature maps.
41 | [bboxes,scores,labels] = helper.postprocess(outFeatureMaps, anchors, ...
42 | inputSize, scale, classNames);
43 | end
--------------------------------------------------------------------------------
/images/BEVOutput.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/Lidar-object-detection-using-complex-yolov4/2939c44aaa1e64c282e4fae4e429d7119be33f1a/images/BEVOutput.png
--------------------------------------------------------------------------------
/images/LidarOutput.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/Lidar-object-detection-using-complex-yolov4/2939c44aaa1e64c282e4fae4e429d7119be33f1a/images/LidarOutput.png
--------------------------------------------------------------------------------
/images/network.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/Lidar-object-detection-using-complex-yolov4/2939c44aaa1e64c282e4fae4e429d7119be33f1a/images/network.png
--------------------------------------------------------------------------------
/images/overview.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/Lidar-object-detection-using-complex-yolov4/2939c44aaa1e64c282e4fae4e429d7119be33f1a/images/overview.png
--------------------------------------------------------------------------------
/mishLayer.m:
--------------------------------------------------------------------------------
1 |
2 | classdef mishLayer < nnet.layer.Layer
3 | %#codegen
4 | % Custom layer for Mish activation function.
5 |
6 | % Copyright 2021 The MathWorks, Inc.
7 |
8 | methods
9 | function layer = mishLayer(name)
10 | % Set layer name.
11 | layer.Name = name;
12 |
13 | % Set layer description.
14 | layer.Description = "mish activation layer";
15 |
16 | % Set layer type.
17 | layer.Type = 'mishLayer';
18 | end
19 |
20 | function Z = predict(~, X)
21 |             Z1 = max(X,0) + log(1 + exp(-abs(X))); % Numerically stable softplus, log(1 + exp(X)).
22 |             Z = X.*tanh(Z1);                       % Mish activation: x*tanh(softplus(x)).
23 | end
24 | end
25 | end
--------------------------------------------------------------------------------
/models/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/pointclouds/0001.pcd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/Lidar-object-detection-using-complex-yolov4/2939c44aaa1e64c282e4fae4e429d7119be33f1a/pointclouds/0001.pcd
--------------------------------------------------------------------------------
/pointclouds/0002.pcd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/Lidar-object-detection-using-complex-yolov4/2939c44aaa1e64c282e4fae4e429d7119be33f1a/pointclouds/0002.pcd
--------------------------------------------------------------------------------
/pointclouds/0003.pcd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/Lidar-object-detection-using-complex-yolov4/2939c44aaa1e64c282e4fae4e429d7119be33f1a/pointclouds/0003.pcd
--------------------------------------------------------------------------------
/sliceLayer.m:
--------------------------------------------------------------------------------
1 | classdef sliceLayer < nnet.layer.Layer
2 | %#codegen
3 | % Custom layer used for channel grouping.
4 |
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | properties
8 | connectID
9 | groups
10 | group_id
11 | end
12 |
13 | methods
14 | function layer = sliceLayer(name,con,groups,group_id)
15 | % Set layer name.
16 | layer.Name = name;
17 |
18 | % Set layer description.
19 | text = [ num2str(groups), ' groups,group_id: ', num2str(group_id), ' sliceLayer '];
20 | layer.Description = text;
21 |
22 | % Set layer type.
23 | layer.Type = 'sliceLayer';
24 |
25 | % Set other properties.
26 | layer.connectID= con;
27 | layer.groups= groups;
28 | layer.group_id= group_id;
29 |             assert(group_id>0,'group_id must be greater than zero; indexing starts from 1');
30 | end
31 |
32 | function Z = predict(layer, X)
33 |             X = reshape(X,[size(X),1]);            % Ensure a trailing batch dimension.
34 |             channels = size(X,3);
35 |             deltaChannels = channels/layer.groups; % Number of channels per group.
36 |             selectStart = (layer.group_id-1)*deltaChannels+1;
37 |             selectEnd = layer.group_id*deltaChannels;
38 |             Z = X(:,:,selectStart:selectEnd,:);    % Keep only the selected channel group.
39 | end
40 | end
41 | end
42 |
--------------------------------------------------------------------------------
/src/+helper/applyActivations.m:
--------------------------------------------------------------------------------
1 | function YPredCell = applyActivations(YPredCell)
2 | % Apply activation functions on YOLOv4 outputs.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | YPredCell(:,1:3) = cellfun(@ sigmoid, YPredCell(:,1:3), 'UniformOutput', false); % Confidence score and x, y offsets.
7 | YPredCell(:,4:5) = cellfun(@ exp, YPredCell(:,4:5), 'UniformOutput', false);     % Width and height scales.
8 | YPredCell(:,6:7) = cellfun(@ tanh, YPredCell(:,6:7), 'UniformOutput', false);    % Cosine and sine of the yaw angle.
9 | YPredCell(:,8) = cellfun(@ sigmoid, YPredCell(:,8), 'UniformOutput', false);     % Class probabilities.
10 | end
--------------------------------------------------------------------------------
/src/+helper/applyAnchorBoxOffsets.m:
--------------------------------------------------------------------------------
1 | function tiledAnchors = applyAnchorBoxOffsets(tiledAnchors,YPredCell,inputImageSize)
2 | % Convert grid cell coordinates to box coordinates.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | for i=1:size(YPredCell,1)
7 | [h,w,~,~] = size(YPredCell{i,1});
8 | tiledAnchors{i,1} = (tiledAnchors{i,1}+YPredCell{i,1})./w;
9 | tiledAnchors{i,2} = (tiledAnchors{i,2}+YPredCell{i,2})./h;
10 | tiledAnchors{i,3} = (tiledAnchors{i,3}.*YPredCell{i,3})./inputImageSize(2);
11 | tiledAnchors{i,4} = (tiledAnchors{i,4}.*YPredCell{i,4})./inputImageSize(1);
12 | end
13 | end
14 |
--------------------------------------------------------------------------------
/src/+helper/configureTrainingProgressPlotter.m:
--------------------------------------------------------------------------------
1 | function [lossPlotter, learningRatePlotter] = configureTrainingProgressPlotter(f)
2 | % Create the subplots to display the loss and learning rate.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | figure(f);
7 | clf
8 | subplot(2,1,1);
9 | ylabel('Learning Rate');
10 | xlabel('Iteration');
11 | learningRatePlotter = animatedline;
12 | subplot(2,1,2);
13 | ylabel('Total Loss');
14 | xlabel('Iteration');
15 | lossPlotter = animatedline;
16 | end
--------------------------------------------------------------------------------
/src/+helper/createBatchData.m:
--------------------------------------------------------------------------------
1 | function [XTrain, YTrain] = createBatchData(data, groundTruthBoxes, groundTruthClasses, classNames)
2 | % Returns images combined along the batch dimension in XTrain and
3 | % normalized bounding boxes concatenated with classIDs in YTrain.
4 |
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | % Concatenate images along the batch dimension.
8 | XTrain = cat(4, data{:,1});
9 |
10 | % Get class IDs from the class names.
11 | classNames = repmat({categorical(classNames')}, size(groundTruthClasses));
12 | [~, classIndices] = cellfun(@(a,b)ismember(a,b), groundTruthClasses, classNames, 'UniformOutput', false);
13 |
14 | % Append the label indexes and training image size to scaled bounding boxes
15 | % and create a single cell array of responses.
16 | combinedResponses = cellfun(@(bbox, classid)[bbox, classid], groundTruthBoxes, classIndices, 'UniformOutput', false);
17 | len = max( cellfun(@(x)size(x,1), combinedResponses ) );
18 | paddedBBoxes = cellfun( @(v) padarray(v,[len-size(v,1),0],0,'post'), combinedResponses, 'UniformOutput',false);
19 | YTrain = cat(4, paddedBBoxes{:,1});
20 | end
--------------------------------------------------------------------------------
/src/+helper/displayLossInfo.m:
--------------------------------------------------------------------------------
1 | function displayLossInfo(epoch, iteration, currentLR, lossInfo)
2 | % Display loss information for each iteration.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | disp("Epoch : " + epoch + " | Iteration : " + iteration + " | Learning Rate : " + currentLR + ...
7 | " | Total Loss : " + double(gather(extractdata(lossInfo.totalLoss))) + ...
8 | " | Box Loss : " + double(gather(extractdata(lossInfo.boxLoss))) + ...
9 | " | Object Loss : " + double(gather(extractdata(lossInfo.objLoss))) + ...
10 | " | Class Loss : " + double(gather(extractdata(lossInfo.clsLoss))));
11 | end
--------------------------------------------------------------------------------
/src/+helper/downloadPretrainedYOLOv4.m:
--------------------------------------------------------------------------------
1 | function model = downloadPretrainedYOLOv4(modelName)
2 | % The downloadPretrainedYOLOv4 function downloads a Complex YOLO v4 network
3 | % pretrained on Pandaset dataset.
4 | %
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | supportedNetworks = ["complex-yolov4-pandaset", "tiny-complex-yolov4-pandaset"];
8 | validatestring(modelName, supportedNetworks);
9 |
10 | % Download the pretrained model.
11 | dataPath = 'models';
12 | filename = matlab.internal.examples.downloadSupportFile('lidar','data/complex-yolov4-models-master.zip');
13 | unzip(filename,dataPath);
14 |
15 | % Extract the model.
16 | netFileFullPath = fullfile(dataPath,'complex-yolov4-models-master','models',modelName,[modelName '.mat']);
17 | if ~exist(netFileFullPath,'file')
18 | netFileZipPath = fullfile(dataPath,'complex-yolov4-models-master','models',[modelName '.zip']);
19 | unzip(netFileZipPath,fullfile(dataPath,'complex-yolov4-models-master','models'));
20 | model = load(netFileFullPath);
21 | else
22 | model = load(netFileFullPath);
23 | end
24 |
25 | end
26 |
--------------------------------------------------------------------------------
/src/+helper/extractPredictions.m:
--------------------------------------------------------------------------------
1 | function predictions = extractPredictions(YPredictions, anchorBoxMask)
2 | % The extractPredictions function extracts and rearranges the prediction
3 | % outputs from the YOLO v4 network.
4 |
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | predictions = cell(size(YPredictions, 1),6);
8 | for ii = 1:size(YPredictions, 1)
9 | % Get the required info on feature size.
10 | numChannelsPred = size(YPredictions{ii},3);
11 | numAnchors = size(anchorBoxMask{ii},2);
12 | numPredElemsPerAnchors = numChannelsPred/numAnchors;
13 | allIds = (1:numChannelsPred);
14 |
15 | stride = numPredElemsPerAnchors;
16 | endIdx = numChannelsPred;
17 |
18 | % X positions.
19 | startIdx = 1;
20 | predictions{ii,2} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
21 | xIds = startIdx:stride:endIdx;
22 |
23 | % Y positions.
24 | startIdx = 2;
25 | predictions{ii,3} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
26 | yIds = startIdx:stride:endIdx;
27 |
28 | % Width.
29 | startIdx = 3;
30 | predictions{ii,4} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
31 | wIds = startIdx:stride:endIdx;
32 |
33 | % Height.
34 | startIdx = 4;
35 | predictions{ii,5} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
36 | hIds = startIdx:stride:endIdx;
37 |
38 | % Cos angle.
39 | startIdx = 5;
40 | predictions{ii,6} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
41 | angleIds1 = startIdx:stride:endIdx;
42 |
43 | % Sin angle.
44 | startIdx = 6;
45 | predictions{ii,7} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
46 | angleIds2 = startIdx:stride:endIdx;
47 |
48 | % Confidence scores.
49 | startIdx = 7;
50 | predictions{ii,1} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
51 | confIds = startIdx:stride:endIdx;
52 |
53 |     % Accumulate all the non-class indexes.
54 | nonClassIds = [xIds yIds wIds hIds angleIds1 angleIds2 confIds];
55 |
56 | % Class probabilities.
57 | % Get the indexes which do not belong to the nonClassIds
58 | classIdx = setdiff(allIds,nonClassIds);
59 | predictions{ii,8} = YPredictions{ii}(:,:,classIdx,:);
60 | end
61 | end
--------------------------------------------------------------------------------
/src/+helper/generateTargets.m:
--------------------------------------------------------------------------------
1 | function [boxDeltaTarget, objectnessTarget, classTarget, maskTarget, boxErrorScaleTarget] = generateTargets(YPredCellGathered, groundTruth, inputImageSize, anchorBoxes, anchorBoxMask, penaltyThreshold)
2 | % generateTargets creates target array for every prediction element
3 | % x, y, width, height, confidence scores and class probabilities.
4 |
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | boxDeltaTarget = cell(size(YPredCellGathered,1),6);
8 | objectnessTarget = cell(size(YPredCellGathered,1),1);
9 | classTarget = cell(size(YPredCellGathered,1),1);
10 | maskTarget = cell(size(YPredCellGathered,1),3);
11 | boxErrorScaleTarget = cell(size(YPredCellGathered,1),1);
12 |
13 | % Normalize the ground truth boxes w.r.t image input size.
14 | gtScale = [inputImageSize(2) inputImageSize(1) inputImageSize(2) inputImageSize(1)];
15 | groundTruth(:,1:4,:,:) = groundTruth(:,1:4,:,:)./gtScale;
16 |
17 |
18 | for numPred = 1:size(YPredCellGathered,1)
19 |
20 | % Select anchor boxes based on anchor box mask indices.
21 | anchors = anchorBoxes(anchorBoxMask{numPred},:);
22 |
23 | bx = YPredCellGathered{numPred,2};
24 | by = YPredCellGathered{numPred,3};
25 | bw = YPredCellGathered{numPred,4};
26 | bh = YPredCellGathered{numPred,5};
27 | predClasses = YPredCellGathered{numPred,8};
28 |
29 | gridSize = size(bx);
30 | if numel(gridSize)== 3
31 | gridSize(4) = 1;
32 | end
33 | numClasses = size(predClasses,3)/size(anchors,1);
34 |
35 | % Initialize the required variables.
36 | mask = single(zeros(size(bx)));
37 | confMask = single(ones(size(bx)));
38 | classMask = single(zeros(size(predClasses)));
39 | tx = single(zeros(size(bx)));
40 | ty = single(zeros(size(by)));
41 | tw = single(zeros(size(bw)));
42 | th = single(zeros(size(bh)));
43 | tangle1 = single(zeros(size(bh)));
44 | tangle2 = single(zeros(size(bh)));
45 |
46 | tconf = single(zeros(size(bx)));
47 | tclass = single(zeros(size(predClasses)));
48 | boxErrorScale = single(ones(size(bx)));
49 |
50 | % Get the IOU of predictions with groundtruth.
51 | iou = getMaxIOUPredictedWithGroundTruth(bx,by,bw,bh,groundTruth);
52 |
53 |     % Do not penalize the predictions that have an IoU greater than the
54 |     % penalty threshold.
55 | confMask(iou > penaltyThreshold) = 0;
56 |
57 | for batch = 1:gridSize(4)
58 | truthBatch = groundTruth(:,1:6,:,batch);
59 | truthBatch = truthBatch(all(truthBatch,2),:);
60 |
61 | % Get boxes with center as 0.
62 | gtPred = [0-truthBatch(:,3)/2,0-truthBatch(:,4)/2,truthBatch(:,3),truthBatch(:,4)];
63 | anchorPrior = [0-anchorBoxes(:,2)/(2*inputImageSize(2)),0-anchorBoxes(:,1)/(2*inputImageSize(1)),anchorBoxes(:,2)/inputImageSize(2),anchorBoxes(:,1)/inputImageSize(1)];
64 |
65 | % Get the iou of best matching anchor box.
66 | overLap = bboxOverlapRatio(gtPred,anchorPrior);
67 | [~,bestAnchorIdx] = max(overLap,[],2);
68 |
69 | % Select gt that are within the mask.
70 | index = ismember(bestAnchorIdx,anchorBoxMask{numPred});
71 | truthBatch = truthBatch(index,:);
72 | bestAnchorIdx = bestAnchorIdx(index,:);
73 | bestAnchorIdx = bestAnchorIdx - anchorBoxMask{numPred}(1,1) + 1;
74 |
75 | if ~isempty(truthBatch)
76 |
77 | errorScale = 2 - truthBatch(:,3).*truthBatch(:,4);
78 | truthBatch = [truthBatch(:,1)*gridSize(2),truthBatch(:,2)*gridSize(1),truthBatch(:,3)*inputImageSize(2),truthBatch(:,4)*inputImageSize(1),truthBatch(:,5),truthBatch(:,6)];
79 | for t = 1:size(truthBatch,1)
80 |
81 | % Get the position of ground-truth box in the grid.
82 | colIdx = ceil(truthBatch(t,1));
83 | colIdx(colIdx<1) = 1;
84 | colIdx(colIdx>gridSize(2)) = gridSize(2);
85 | rowIdx = ceil(truthBatch(t,2));
86 | rowIdx(rowIdx<1) = 1;
87 | rowIdx(rowIdx>gridSize(1)) = gridSize(1);
88 | pos = [rowIdx,colIdx];
89 | anchorIdx = bestAnchorIdx(t,1);
90 |
91 | mask(pos(1,1),pos(1,2),anchorIdx,batch) = 1;
92 | confMask(pos(1,1),pos(1,2),anchorIdx,batch) = 1;
93 |
94 | % Calculate the shift in ground-truth boxes.
95 | tShiftX = truthBatch(t,1)-pos(1,2)+1;
96 | tShiftY = truthBatch(t,2)-pos(1,1)+1;
97 | tShiftW = log(truthBatch(t,3)/anchors(anchorIdx,2));
98 | tShiftH = log(truthBatch(t,4)/anchors(anchorIdx,1));
99 | tShiftSinYaw = sind(truthBatch(t,5));
100 | tShiftCosYaw = cosd(truthBatch(t,5));
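% The yaw angle is encoded through its sine and cosine targets (tangle1 and
% tangle2) so that the regression loss stays continuous across the
% +/-180 degree wrap-around.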
101 |
102 | % Update the target box.
103 | tx(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftX;
104 | ty(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftY;
105 | tw(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftW;
106 | th(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftH;
107 | tangle1(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftSinYaw;
108 | tangle2(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftCosYaw;
109 |
110 | boxErrorScale(pos(1,1),pos(1,2),anchorIdx,batch) = errorScale(t);
111 | tconf(rowIdx,colIdx,anchorIdx,batch) = 1;
112 | classIdx = (numClasses*(anchorIdx-1))+truthBatch(t,6);
113 | tclass(rowIdx,colIdx,classIdx,batch) = 1;
114 | classMask(rowIdx,colIdx,(numClasses*(anchorIdx-1))+(1:numClasses),batch) = 1;
115 | end
116 | end
117 | end
118 | boxDeltaTarget(numPred,:) = [{tx} {ty} {tw} {th} {tangle1} {tangle2}];
119 | objectnessTarget{numPred,1} = tconf;
120 | classTarget{numPred,1} = tclass;
121 | maskTarget(numPred,:) = [{mask} {confMask} {classMask}];
122 | boxErrorScaleTarget{numPred,:} = boxErrorScale;
123 | end
124 | end
125 |
126 | function iou = getMaxIOUPredictedWithGroundTruth(predx,predy,predw,predh,truth)
127 | % getMaxIOUPredictedWithGroundTruth computes the maximum intersection over
128 | % union scores for every pair of predictions and ground-truth boxes.
129 |
130 | [h,w,c,n] = size(predx);
131 | iou = zeros([h w c n],'like',predx);
132 |
133 | % For each batch prepare the predictions and ground-truth.
134 | for batchSize = 1:n
135 | truthBatch = truth(:,1:5,1,batchSize);
136 | truthBatch = truthBatch(all(truthBatch,2),:);
137 | predxb = predx(:,:,:,batchSize);
138 | predyb = predy(:,:,:,batchSize);
139 | predwb = predw(:,:,:,batchSize);
140 | predhb = predh(:,:,:,batchSize);
141 | predb = [predxb(:),predyb(:),predwb(:),predhb(:)];
142 |
143 | % Add yaw
144 | predb = [predb zeros(size(predb,1),1)];
145 |
146 | % Compute and extract the maximum IOU of predictions with ground-truth.
147 | try
148 | rots = truthBatch(:,5);
149 | rots = rots - floor((rots+0.5)/pi)*pi;
150 | idx = (rots > pi/4);
151 | truthBatch(idx,:) = truthBatch(idx,[1,2,4,3,5]);
152 | overlap = bboxOverlapRatio(predb(:,[1,2,3,4]), truthBatch(:,[1,2,3,4]));
153 | catch me
154 | if(any(isnan(predb(:))|isinf(predb(:))))
155 | error(me.message + " NaN/Inf has been detected during training. Try reducing the learning rate.");
156 | elseif(any(predb(:,3)<=0 | predb(:,4)<=0))
157 | error(me.message + " Invalid predictions during training. Try reducing the learning rate.");
158 | else
159 | error(me.message + " Invalid ground truth. Check that your ground-truth boxes are non-empty and finite, are fully contained within the image boundary, and have positive width and height.");
160 | end
161 | end
162 |
163 | maxOverlap = max(overlap,[],2);
164 | iou(:,:,:,batchSize) = reshape(maxOverlap,h,w,c);
165 | end
166 | end
167 |
--------------------------------------------------------------------------------
/src/+helper/generateTiledAnchors.m:
--------------------------------------------------------------------------------
1 | function tiledAnchors = generateTiledAnchors(YPredCell,anchorBoxes,anchorBoxMask)
2 | % Generate tiled anchor offsets.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | tiledAnchors = cell(size(YPredCell));
7 | for i=1:size(YPredCell,1)
8 | anchors = anchorBoxes(anchorBoxMask{i}, :);
9 | [h,w,~,n] = size(YPredCell{i,1});
10 | [tiledAnchors{i,2}, tiledAnchors{i,1}] = ndgrid(0:h-1,0:w-1,1:size(anchors,1),1:n);
11 | [~,~,tiledAnchors{i,3}] = ndgrid(0:h-1,0:w-1,anchors(:,2),1:n);
12 | [~,~,tiledAnchors{i,4}] = ndgrid(0:h-1,0:w-1,anchors(:,1),1:n);
13 | end
14 | end
--------------------------------------------------------------------------------
/src/+helper/getAnchors.m:
--------------------------------------------------------------------------------
1 | function anchors = getAnchors(modelName)
2 | % The getAnchors function returns the anchors used in training of the
3 | % specified Complex YOLO v4 model.
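% Anchor boxes are returned as [height width] pairs in pixels of the
% 608-by-608 bird's-eye-view image, and anchorBoxMasks groups their row
% indices per detection head.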
4 | %
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | if isequal(modelName, 'complex-yolov4-pandaset')
8 | anchors.anchorBoxes = [10 10; 11 13; 22 35; ...
9 | 23 48; 23 55; 25 56; ...
10 | 25 62; 27 71; 35 95];
11 | anchors.anchorBoxMasks = {[1,2,3]
12 | [4,5,6]
13 | [7,8,9]};
14 | elseif isequal(modelName, 'tiny-complex-yolov4-pandaset')
15 | anchors.anchorBoxes = [36 118; 24 59; 23 50; ...
16 | 11 14; 10 10; 5 4];
17 | anchors.anchorBoxMasks = {[1,2,3]
18 | [4,5,6]};
19 | end
20 | end
21 |
--------------------------------------------------------------------------------
/src/+helper/getClassNames.m:
--------------------------------------------------------------------------------
1 | function classNames = getClassNames()
2 | % The getClassNames function returns the names of Pandaset dataset classes.
3 | %
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | names={'Car'
7 | 'Truck'
8 | 'Pedestrain'};
9 |
10 | classNames = categorical(names);
11 | end
12 |
--------------------------------------------------------------------------------
/src/+helper/getGridParameters.m:
--------------------------------------------------------------------------------
1 | function gridParams = getGridParameters()
2 | % The getGridParameters function returns the grid parameters that control
3 | % the range of the point cloud.
4 | %
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | xMin = -25.0; % Minimum value along X-axis.
8 | xMax = 25.0; % Maximum value along X-axis.
9 | yMin = 0.0; % Minimum value along Y-axis.
10 | yMax = 50.0; % Maximum value along Y-axis.
11 | zMin = -7.0; % Minimum value along Z-axis.
12 | zMax = 15.0; % Maximum value along Z-axis.
13 |
14 | pcRange = [xMin xMax yMin yMax zMin zMax];
15 |
16 | % Define the dimensions for the pseudo-image.
17 | bevHeight = 608;
18 | bevWidth = 608;
19 |
20 | % Find grid resolution.
21 | gridW = (pcRange(4) - pcRange(3))/bevWidth;
22 | gridH = (pcRange(2) - pcRange(1))/bevHeight;
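% With the default 50 m x 50 m range and a 608-by-608 pseudo-image, each grid
% cell covers 50/608, or roughly 0.0822 m, in both directions.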
23 |
24 | gridParams.xMin = xMin;
25 | gridParams.xMax = xMax;
26 | gridParams.yMin = yMin;
27 | gridParams.yMax = yMax;
28 | gridParams.zMin = zMin;
29 | gridParams.zMax = zMax;
30 |
31 | gridParams.bevHeight = bevHeight;
32 | gridParams.bevWidth = bevWidth;
33 |
34 | gridParams.gridH = gridH;
35 | gridParams.gridW = gridW;
36 | end
--------------------------------------------------------------------------------
/src/+helper/piecewiseLearningRateWithWarmup.m:
--------------------------------------------------------------------------------
1 | function currentLR = piecewiseLearningRateWithWarmup(iteration, epoch, learningRate, warmupPeriod, numEpochs)
2 | % The piecewiseLearningRateWithWarmup function computes the current
3 | % learning rate based on the iteration number.
4 |
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | persistent warmUpEpoch;
8 |
9 | if iteration <= warmupPeriod
10 | % Ramp up the learning rate over the warmup period.
11 | currentLR = learningRate * ((iteration/warmupPeriod)^4);
12 | warmUpEpoch = epoch;
13 | elseif iteration >= warmupPeriod && epoch < warmUpEpoch+floor(0.6*(numEpochs-warmUpEpoch))
14 | % After the warmup period, keep the learning rate constant for the first 60 percent of the remaining epochs.
15 | currentLR = learningRate;
16 |
17 | elseif epoch >= warmUpEpoch + floor(0.6*(numEpochs-warmUpEpoch)) && epoch < warmUpEpoch+floor(0.9*(numEpochs-warmUpEpoch))
18 | % Between 60 and 90 percent of the remaining epochs, multiply the
19 | % learning rate by 0.1.
20 | currentLR = learningRate*0.1;
21 |
22 | else
23 | % For the last 10 percent of the remaining epochs, multiply the
24 | % learning rate by 0.01.
25 | currentLR = learningRate*0.01;
26 | end
27 | end
--------------------------------------------------------------------------------
/src/+helper/postprocess.m:
--------------------------------------------------------------------------------
1 | function [bboxes,scores,labels] = postprocess(outFeatureMaps, anchors, inputSize, scale, classNames)
2 | % The postprocess function applies postprocessing on the generated feature
3 | % maps and returns bounding boxes, detection scores and labels.
4 | %
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | % Get number of classes.
8 | classNames = categorical(classNames);
9 | numClasses = numel(classNames);
10 |
11 | % Get anchor boxes and anchor boxes masks.
12 | anchorBoxes = anchors.anchorBoxes;
13 | anchorBoxMasks = anchors.anchorBoxMasks;
14 |
15 | % Postprocess generated feature maps.
16 | outputFeatures = [];
17 | for i = 1:size(outFeatureMaps,1)
18 | currentFeatureMap = outFeatureMaps{i};
19 | numY = size(currentFeatureMap,1);
20 | numX = size(currentFeatureMap,2);
21 | stride = max(inputSize)./max(numX, numY);
22 | batchsize = size(currentFeatureMap,4);
23 | h = numY;
24 | w = numX;
25 | numAnchors = size(anchorBoxMasks{i},2);
26 |
27 | currentFeatureMap = reshape(currentFeatureMap,h,w,7+numClasses,numAnchors,batchsize);
28 | currentFeatureMap = permute(currentFeatureMap,[5,4,1,2,3]);
29 |
30 | [~,~,yv,xv] = ndgrid(1:batchsize,1:numAnchors,0:h-1,0:w-1);
31 | gridXY = cat(5,xv,yv);
32 | currentFeatureMap(:,:,:,:,1:2) = sigmoid(currentFeatureMap(:,:,:,:,1:2)) + gridXY;
33 | anchorBoxesCurrentLevel= anchorBoxes(anchorBoxMasks{i}, :);
34 | anchorBoxesCurrentLevel(:,[2,1]) = anchorBoxesCurrentLevel(:,[1,2]);
35 | anchor_grid = anchorBoxesCurrentLevel/stride;
36 | anchor_grid = reshape(anchor_grid,1,numAnchors,1,1,2);
37 | currentFeatureMap(:,:,:,:,3:4) = exp(currentFeatureMap(:,:,:,:,3:4)).*anchor_grid;
38 | currentFeatureMap(:,:,:,:,1:4) = currentFeatureMap(:,:,:,:,1:4)*stride;
39 | currentFeatureMap(:,:,:,:,5:6) = tanh(currentFeatureMap(:,:,:,:,5:6));
40 | currentFeatureMap(:,:,:,:,7:end) = sigmoid(currentFeatureMap(:,:,:,:,7:end));
41 |
42 | if numClasses == 1
43 | currentFeatureMap(:,:,:,:,8) = 1;
44 | end
45 | currentFeatureMap = reshape(currentFeatureMap,batchsize,[],7+numClasses);
46 |
47 | if isempty(outputFeatures)
48 | outputFeatures = currentFeatureMap;
49 | else
50 | outputFeatures = cat(2,outputFeatures,currentFeatureMap);
51 | end
52 | end
53 |
54 | % Coordinate conversion to the original image.
55 | outputFeatures = extractdata(outputFeatures); % [x_center,y_center,w,h,sin(yaw),cos(yaw),Pobj,p1,p2,...,pn]
56 | outputFeatures(:,:,[1,3]) = outputFeatures(:,:,[1,3])*scale(2); % x_center,width
57 | outputFeatures(:,:,[2,4]) = outputFeatures(:,:,[2,4])*scale(1); % y_center,height
58 |
59 | outputFeatures(:,:,5) = rad2deg(atan2(outputFeatures(:,:,5),outputFeatures(:,:,6)));
60 | outputFeatures = squeeze(outputFeatures); % If it is a single image detection, the output size is M*(7+numClasses), otherwise it is bs*M*(7+numClasses)
61 |
62 | if(canUseGPU())
63 | outputFeatures = gather(outputFeatures);
64 | end
65 |
66 | % Apply confidence threshold and non-maximum suppression.
67 | confidenceThreshold = 0.1;
68 | overlapThreshold = 0.01;
69 |
70 | scores = outputFeatures(:,7);
71 | outFeatures = outputFeatures(scores>confidenceThreshold,:);
72 |
73 | allBBoxes = outFeatures(:,1:5);
74 | allScores = outFeatures(:,7);
75 | [maxScores,indxs] = max(outFeatures(:,8:end),[],2);
76 | allScores = allScores.*maxScores;
77 | allLabels = classNames(indxs);
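% allBBoxes holds rotated rectangles in [xCenter yCenter width height yaw]
% format, with the yaw in degrees.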
78 |
79 | bboxes = [];
80 | scores = [];
81 | labels = [];
82 | if ~isempty(allBBoxes)
83 | [bboxes,scores,labels] = selectStrongestBboxMulticlass(allBBoxes,allScores,allLabels,...
84 | 'RatioType','Min','OverlapThreshold',overlapThreshold);
85 | end
86 | end
87 |
--------------------------------------------------------------------------------
/src/+helper/preprocess.m:
--------------------------------------------------------------------------------
1 | function [imageMap,ptCldOut] = preprocess(ptCld,gridParams)
2 | % The preprocess function converts the point cloud to a bird's-eye-view image
3 | % based on the grid parameters, returning the image and the cropped point cloud.
4 | %
5 | % Copyright 2021 The MathWorks, Inc.
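%
% Example (using the sample point cloud shipped with this repository):
%   ptCld = pcread(fullfile('pointclouds','0001.pcd'));
%   gridParams = helper.getGridParameters();
%   [imageMap, ptCldOut] = helper.preprocess(ptCld, gridParams);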
6 |
7 | pcRange = [gridParams.xMin gridParams.xMax gridParams.yMin ...
8 | gridParams.yMax gridParams.zMin gridParams.zMax];
9 |
10 | indices = findPointsInROI(ptCld,pcRange);
11 | ptCldOut = select(ptCld,indices);
12 |
13 | bevHeight = gridParams.bevHeight;
14 | bevWidth = gridParams.bevWidth;
15 |
16 | % Find grid resolution.
17 | gridH = gridParams.gridH;
18 | gridW = gridParams.gridW;
19 |
20 | loc = ptCldOut.Location;
21 | intensity = ptCldOut.Intensity;
22 | intensity = normalize(intensity,'range');
23 |
24 | % Find the grid each point falls into.
25 | loc(:,1) = int32(floor(loc(:,1)/gridH)+bevHeight/2) + 1;
26 | loc(:,2) = int32(floor(loc(:,2)/gridW)) + 1;
27 |
28 | % Normalize the height.
29 | loc(:,3) = loc(:,3) - min(loc(:,3));
30 | loc(:,3) = loc(:,3)/(pcRange(6) - pcRange(5));
31 |
32 | % Sort the points based on height.
33 | [~,I] = sortrows(loc,[1,2,-3]);
34 | locMod = loc(I,:);
35 | intensityMod = intensity(I,:);
36 |
37 | % Initialize height and intensity map
38 | heightMap = zeros(bevHeight,bevWidth);
39 | intensityMap = zeros(bevHeight,bevWidth);
40 |
41 | locMod(:,1) = min(locMod(:,1),bevHeight);
42 | locMod(:,2) = min(locMod(:,2),bevWidth);
43 |
44 | % Find the unique indices having max height.
45 | mapIndices = sub2ind([bevHeight,bevWidth],locMod(:,1),locMod(:,2));
46 | [~,idx] = unique(mapIndices,"rows","first");
47 |
48 | binc = 1:bevWidth*bevHeight;
49 | counts = hist(mapIndices,binc);
50 |
51 | normalizedCounts = min(1.0, log(counts + 1) / log(64));
52 |
53 | for i = 1:size(idx,1)
54 | heightMap(mapIndices(idx(i))) = locMod(idx(i),3);
55 | intensityMap(mapIndices(idx(i))) = intensityMod(idx(i),1);
56 | end
57 |
58 | densityMap = reshape(normalizedCounts,[bevHeight,bevWidth]);
59 |
60 | imageMap = zeros(bevHeight,bevWidth,3);
61 | imageMap(:,:,1) = densityMap; % R channel
62 | imageMap(:,:,2) = heightMap; % G channel
63 | imageMap(:,:,3) = intensityMap; % B channel
64 | end
--------------------------------------------------------------------------------
/src/+helper/preprocessData.m:
--------------------------------------------------------------------------------
1 | function data = preprocessData(data, targetSize, isRotRect)
2 | % Resize the images and scale the pixels to between 0 and 1. Also scale the
3 | % corresponding bounding boxes.
4 | for ii = 1:size(data,1)
5 | I = data{ii,1};
6 | imgSize = size(I);
7 |
8 | % Convert an input image with single channel to 3 channels.
9 | if numel(imgSize) < 3
10 | I = repmat(I,1,1,3);
11 | end
12 | bboxes = data{ii,2};
13 |
14 | I = im2single(imresize(I,targetSize(1:2)));
15 | scale = targetSize(1:2)./imgSize(1:2);
16 | bboxes = bboxresize(bboxes,scale);
17 |
18 | if ~isRotRect
19 | bboxes = bboxes(:,1:4);
20 | end
21 |
22 | data(ii, 1:2) = {I, bboxes};
23 | end
24 | end
25 |
26 |
27 | % Copyright 2021 The MathWorks, Inc.
--------------------------------------------------------------------------------
/src/+helper/removeEmptyData.m:
--------------------------------------------------------------------------------
1 | function [imdsProcessed,bdsProcessed] = removeEmptyData(imds,bds)
2 | % Return non-empty indices from the saved data
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | % Read labels from the box label datastore.
7 | processedLabels = readall(bds);
8 |
9 | % Get the non-empty indices.
10 | indices = ~cellfun('isempty',processedLabels(:,1));
11 |
12 | imdsProcessed = subset(imds,indices);
13 | bdsProcessed = subset(bds,indices);
14 |
15 | end
--------------------------------------------------------------------------------
/src/+helper/transferbboxToPointCloud.m:
--------------------------------------------------------------------------------
1 | function bboxCuboid = transferbboxToPointCloud(bboxes,gridParams,ptCldOut)
2 | % Transfer labels from images to point cloud.
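% The output bboxCuboid is an M-by-9 matrix in the cuboid format
% [xctr yctr zctr xlen ylen zlen xrot yrot zrot], with the yaw angle stored in
% the zrot column.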
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | % Calculate the height of the ground plane.
7 | groundPtsIdx = segmentGroundSMRF(ptCldOut,3,'MaxWindowRadius',5,'ElevationThreshold',0.4,'ElevationScale',0.25);
8 | loc = ptCldOut.Location;
9 | groundHeight = mean(loc(groundPtsIdx,3));
10 |
11 | % Assume height of objects to be a constant based on input data.
12 | objectHeight = 1.56;
13 |
14 | % Transfer Labels back to the point cloud.
15 | bboxCuboid = zeros(size(bboxes,1),9);
16 | bboxCuboid(:,1) = (bboxes(:,2) - 1 - gridParams.bevHeight/2)*gridParams.gridH;
17 | bboxCuboid(:,2) = (bboxes(:,1) - 1 )*gridParams.gridW;
18 | bboxCuboid(:,4) = bboxes(:,4)*gridParams.gridH;
19 | bboxCuboid(:,5) = bboxes(:,3)*gridParams.gridW;
20 | bboxCuboid(:,9) = -bboxes(:,5);
21 |
22 | bboxCuboid(:,6) = objectHeight;
23 | bboxCuboid(:,3) = groundHeight + (objectHeight/2);
24 | end
--------------------------------------------------------------------------------
/src/+helper/updatePlots.m:
--------------------------------------------------------------------------------
1 | function updatePlots(lossPlotter, learningRatePlotter, iteration, currentLR, totalLoss)
2 | % Update loss and learning rate plots.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | addpoints(lossPlotter, iteration, double(extractdata(gather(totalLoss))));
7 | addpoints(learningRatePlotter, iteration, currentLR);
8 | drawnow
9 | end
--------------------------------------------------------------------------------
/src/+helper/validateInputData.m:
--------------------------------------------------------------------------------
1 | function validateInputData(ds)
2 | % Validates the input images, bounding boxes and labels and displays the
3 | % paths of invalid samples.
4 |
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | % Path to images
8 | info = ds.UnderlyingDatastores{1}.Files;
9 |
10 | ds = transform(ds, @isValidDetectorData);
11 | data = readall(ds);
12 |
13 | validImgs = [data.validImgs];
14 | validBoxes = [data.validBoxes];
15 | validLabels = [data.validLabels];
16 |
17 | msg = "";
18 |
19 | if(any(~validImgs))
20 | imPaths = info(~validImgs);
21 | str = strjoin(imPaths, '\n');
22 | imErrMsg = sprintf("Input images must be non-empty and have 2 or 3 dimensions. The following images are invalid:\n") + str;
23 | msg = (imErrMsg + newline + newline);
24 | end
25 |
26 | if(any(~validBoxes))
27 | imPaths = info(~validBoxes);
28 | str = strjoin(imPaths, '\n');
29 | boxErrMsg = sprintf("Bounding box data must be M-by-5 matrices with positive box coordinates and sizes and a finite yaw angle. The following images have invalid bounding box data:\n") ...
30 | + str;
31 |
32 | msg = (msg + boxErrMsg + newline + newline);
33 | end
34 |
35 | if(any(~validLabels))
36 | imPaths = info(~validLabels);
37 | str = strjoin(imPaths, '\n');
38 | labelErrMsg = sprintf("Labels must be non-empty and categorical. The following images have invalid labels:\n") + str;
39 |
40 | msg = (msg + labelErrMsg + newline);
41 | end
42 |
43 | if(~isempty(msg))
44 | error(msg);
45 | end
46 |
47 | end
48 |
49 | function out = isValidDetectorData(data)
50 | % Checks validity of images, bounding boxes and labels
51 | for i = 1:size(data,1)
52 | I = data{i,1};
53 | boxes = data{i,2};
54 | labels = data{i,3};
55 |
56 | imageSize = size(I);
57 | mSize = size(boxes, 1);
58 |
59 | out.validImgs(i) = iCheckImages(I);
60 | out.validBoxes(i) = iCheckBoxes(boxes, imageSize);
61 | out.validLabels(i) = iCheckLabels(labels, mSize);
62 | end
63 |
64 | end
65 |
66 | function valid = iCheckImages(I)
67 | % Validates the input images.
68 |
69 | valid = true;
70 | if ndims(I) == 2
71 | nDims = 2;
72 | else
73 | nDims = 3;
74 | end
75 | % Define image validation parameters.
76 | classes = {'numeric'};
77 | attrs = {'nonempty', 'nonsparse', 'nonnan', 'finite', 'ndims', nDims};
78 | try
79 | validateattributes(I, classes, attrs);
80 | catch
81 | valid = false;
82 | end
83 | end
84 |
85 | function valid = iCheckBoxes(boxes, imageSize)
86 | % Validates the ground-truth bounding boxes to be non-empty and finite.
87 |
88 | valid = true;
89 | % Define bounding box validation parameters.
90 | classes = {'numeric'};
91 | attrs = {'nonempty', 'nonnan', 'finite', 'positive', 'nonzero', 'nonsparse', '2d', 'ncols', 4};
92 | attrsYaw = {'nonempty', 'nonnan', 'finite', 'nonsparse'};
93 | try
94 | validateattributes(boxes(:,1:4), classes, attrs);
95 | validateattributes(boxes(:,5), classes, attrsYaw);
96 | % Validate that the bounding boxes lie within the image boundary.
97 | validateattributes(boxes(:,1)+boxes(:,3)-1, classes, {'<=', imageSize(2)});
98 | validateattributes(boxes(:,2)+boxes(:,4)-1, classes, {'<=', imageSize(1)});
99 | catch
100 | valid = false;
101 | end
102 | end
103 |
104 | function valid = iCheckLabels(labels, mSize)
105 | % Validates the labels.
106 |
107 | valid = true;
108 | % Define label validation parameters.
109 | classes = {'categorical'};
110 | attrs = {'nonempty', 'nonsparse', '2d', 'ncols', 1, 'nrows', mSize};
111 | try
112 | validateattributes(labels, classes, attrs);
113 | catch
114 | valid = false;
115 | end
116 | end
--------------------------------------------------------------------------------
/src/+helper/yolov4Forward.m:
--------------------------------------------------------------------------------
1 | function [YPredCell, state] = yolov4Forward(net, XTrain, networkOutputs, anchorBoxMask)
2 | % Predict the output of the network and extract the objectness score, x, y,
3 | % width, height, yaw components, and class predictions.
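% Each row of YPredCell corresponds to one detection head. Columns 1 to 8 hold
% the objectness score, x, y, width, height, sin(yaw), cos(yaw), and class
% scores; columns 9 and 10 keep the unactivated width and height used by the
% box offset loss.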
4 |
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | YPredictions = cell(size(networkOutputs));
8 | [YPredictions{:}, state] = forward(net, XTrain, 'Outputs', networkOutputs);
9 | YPredCell = helper.extractPredictions(YPredictions, anchorBoxMask);
10 |
11 | % Append predicted width and height to the end as they are required
12 | % for computing the loss.
13 | YPredCell(:,9:10) = YPredCell(:,4:5);
14 |
15 | % Apply sigmoid and exponential activation.
16 | YPredCell(:,1:8) = helper.applyActivations(YPredCell(:,1:8));
17 | end
--------------------------------------------------------------------------------
/src/+loss/bboxOffsetLoss.m:
--------------------------------------------------------------------------------
1 | function boxLoss = bboxOffsetLoss(boxPredCell, boxDeltaTarget, boxMaskTarget, boxErrorScaleTarget)
2 | % Mean squared error for bounding box position.
3 | lossX = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,1),boxDeltaTarget(:,1),boxMaskTarget(:,1),boxErrorScaleTarget));
4 | lossY = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,2),boxDeltaTarget(:,2),boxMaskTarget(:,1),boxErrorScaleTarget));
5 | lossW = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,3),boxDeltaTarget(:,3),boxMaskTarget(:,1),boxErrorScaleTarget));
6 | lossH = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,4),boxDeltaTarget(:,4),boxMaskTarget(:,1),boxErrorScaleTarget));
7 |
8 | lossYaw1 = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,5),boxDeltaTarget(:,5),boxMaskTarget(:,1),boxErrorScaleTarget));
9 | lossYaw2 = sum(cellfun(@(a,b,c,d) mse(a.*c.*d,b.*c.*d),boxPredCell(:,6),boxDeltaTarget(:,6),boxMaskTarget(:,1),boxErrorScaleTarget));
10 |
11 | boxLoss = lossX+lossY+lossW+lossH+lossYaw1+lossYaw2;
12 | end
--------------------------------------------------------------------------------
/src/+loss/classConfidenceLoss.m:
--------------------------------------------------------------------------------
1 | function clsLoss = classConfidenceLoss(classPredCell, classTarget, boxMaskTarget)
2 | % Binary cross-entropy loss for class confidence score.
3 | clsLoss = sum(cellfun(@(a,b,c) crossentropy(a.*c,b.*c,'TargetCategories','independent'),classPredCell,classTarget,boxMaskTarget(:,3)));
4 | end
--------------------------------------------------------------------------------
/src/+loss/objectnessLoss.m:
--------------------------------------------------------------------------------
1 | function objLoss = objectnessLoss(objectnessPredCell, objectnessDeltaTarget, boxMaskTarget)
2 | % Binary cross-entropy loss for objectness score.
3 | objLoss = sum(cellfun(@(a,b,c) crossentropy(a.*c,b.*c,'TargetCategories','independent'),objectnessPredCell,objectnessDeltaTarget,boxMaskTarget(:,2)));
4 | end
--------------------------------------------------------------------------------
/src/configureYOLOv4.m:
--------------------------------------------------------------------------------
1 | function [lgraph, networkOutputs, anchorBoxes, anchorBoxMasks] = configureYOLOv4(net, classNames, anchorBoxes, modelName)
2 | % Configure the pretrained network for transfer learning.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | % Specify anchorBoxMasks to select the anchor boxes to use in each of the
7 | % detection heads. anchorBoxMasks is an M-by-1 cell array, where M denotes
8 | % the number of detection heads. Each cell contains a 1-by-N array of row
9 | % indices into anchorBoxes, where N is the number of anchor boxes to use.
10 | % Select anchor boxes for each detection head based on size: use larger
11 | % anchor boxes at lower scales and smaller anchor boxes at higher
12 | % scales.
13 | if strcmp(modelName, 'complex-yolov4-pandaset')
14 | area = anchorBoxes(:, 1).*anchorBoxes(:, 2);
15 | [~, idx] = sort(area, 'ascend');
16 | anchorBoxes = anchorBoxes(idx, :);
17 | anchorBoxMasks = {[1,2,3]
18 | [4,5,6]
19 | [7,8,9]
20 | };
21 | elseif strcmp(modelName, 'tiny-complex-yolov4-pandaset')
22 | area = anchorBoxes(:, 1).*anchorBoxes(:, 2);
23 | [~, idx] = sort(area, 'descend');
24 | anchorBoxes = anchorBoxes(idx, :);
25 | anchorBoxMasks = {[1,2,3]
26 | [4,5,6]
27 | };
28 | end
29 |
30 | % Specify the number of object classes to detect, and number of prediction
31 | % elements per anchor box. The number of predictions per anchor box is set
32 | % to 7 plus the number of object classes. The 7 denotes the 4 bounding box
33 | % attributes, 2 angle attributes, and 1 object confidence score.
34 | numClasses = size(classNames, 1);
35 | numPredictorsPerAnchor = 7 + numClasses;
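% For example, with the 3 Pandaset classes and 3 anchor boxes per detection
% head, each head outputs 3*(7+3) = 30 channels per grid cell.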
36 |
37 | % Modify the layer graph to train with the new set of classes.
38 | lgraph = layerGraph(net);
39 |
40 | if strcmp(modelName, 'complex-yolov4-pandaset')
41 | yoloModule1 = convolution2dLayer(1,length(anchorBoxMasks{1})*numPredictorsPerAnchor,'Name','yoloconv1');
42 | yoloModule2 = convolution2dLayer(1,length(anchorBoxMasks{2})*numPredictorsPerAnchor,'Name','yoloconv2');
43 | yoloModule3 = convolution2dLayer(1,length(anchorBoxMasks{3})*numPredictorsPerAnchor,'Name','yoloconv3');
44 |
45 | lgraph = replaceLayer(lgraph,'yoloconv1',yoloModule1);
46 | lgraph = replaceLayer(lgraph,'yoloconv2',yoloModule2);
47 | lgraph = replaceLayer(lgraph,'yoloconv3',yoloModule3);
48 |
49 | networkOutputs = ["yoloconv1"
50 | "yoloconv2"
51 | "yoloconv3"
52 | ];
53 | elseif strcmp(modelName, 'tiny-complex-yolov4-pandaset')
54 | yoloModule1 = convolution2dLayer(1,length(anchorBoxMasks{1})*numPredictorsPerAnchor,'Name','yoloconv1');
55 | yoloModule2 = convolution2dLayer(1,length(anchorBoxMasks{2})*numPredictorsPerAnchor,'Name','yoloconv2');
56 |
57 | lgraph = replaceLayer(lgraph,'yoloconv1',yoloModule1);
58 | lgraph = replaceLayer(lgraph,'yoloconv2',yoloModule2);
59 |
60 | networkOutputs = ["yoloconv1"
61 | "yoloconv2"
62 | ];
63 | end
64 | end
--------------------------------------------------------------------------------
/src/modelGradients.m:
--------------------------------------------------------------------------------
1 | function [gradients, state, info] = modelGradients(net, XTrain, YTrain, anchors, mask, penaltyThreshold, networkOutputs)
2 | % The modelGradients function takes the dlnetwork object net, a mini-batch of
3 | % input data XTrain with the corresponding ground-truth boxes YTrain, the
4 | % anchor boxes, the anchor box mask, the specified penalty threshold, and the
5 | % network output names as input arguments, and returns the gradients of the
6 | % loss with respect to the learnable parameters in net, the corresponding
7 | % mini-batch loss information, and the state of the current batch.
8 | %
9 | % The model gradients function computes the total loss and gradients by performing
10 | % these operations.
11 |
12 | % * Generate predictions from the input batch of images using the supporting
13 | % function yolov4Forward.
14 | % * Collect predictions on the CPU for postprocessing.
15 | % * Convert the predictions from the YOLO v4 grid cell coordinates to bounding
16 | % box coordinates to allow easy comparison with the ground truth data by using
17 | % the supporting functions generateTiledAnchors and applyAnchorBoxOffsets.
18 | % * Generate targets for loss computation by using the converted predictions
19 | % and ground truth data. These targets are generated for bounding box positions
20 | % (x, y, width, height, yaw), object confidence, and class probabilities. See the helper
21 | % function generateTargets.
22 | % * Calculate the mean squared error of the predicted bounding box coordinates
23 | % with the target boxes. See the supporting function bboxOffsetLoss.
24 | % * Determine the binary cross-entropy of the predicted object confidence score
25 | % with the target object confidence score. See the supporting function objectnessLoss.
26 | % * Determine the binary cross-entropy of the predicted object class with
27 | % the target. See the supporting function classConfidenceLoss.
28 | % * Compute the total loss as the sum of all losses.
29 | % * Compute the gradients of the learnables with respect to the total loss.
30 |
31 | % Copyright 2021 The MathWorks, Inc.
32 |
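% Typical invocation from a custom training loop (illustrative): because the
% function calls dlgradient, it must be evaluated inside dlfeval:
%   [gradients, state, info] = dlfeval(@modelGradients, net, XTrain, YTrain, ...
%       anchors, mask, penaltyThreshold, networkOutputs);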
33 | inputImageSize = size(XTrain,1:2);
34 |
35 | % Gather the ground truths in the CPU for post processing
36 | YTrain = gather(extractdata(YTrain));
37 |
38 | % Extract the predictions from the network.
39 | [YPredCell, state] = helper.yolov4Forward(net,XTrain,networkOutputs,mask);
40 |
41 | % Gather the activations in the CPU for post processing and extract dlarray data.
42 | gatheredPredictions = cellfun(@ gather, YPredCell(:,1:8),'UniformOutput',false);
43 | gatheredPredictions = cellfun(@ extractdata, gatheredPredictions, 'UniformOutput', false);
44 |
45 | % Convert predictions from grid cell coordinates to box coordinates.
46 | tiledAnchors = helper.generateTiledAnchors(gatheredPredictions(:,2:5),anchors,mask);
47 | gatheredPredictions(:,2:5) = helper.applyAnchorBoxOffsets(tiledAnchors, gatheredPredictions(:,2:5), inputImageSize);
48 |
49 | % Generate target for predictions from the ground truth data.
50 | [boxTarget, objectnessTarget, classTarget, objectMaskTarget, boxErrorScale] = helper.generateTargets(gatheredPredictions, YTrain, inputImageSize, anchors, mask, penaltyThreshold);
51 |
52 | % Compute the loss.
53 | boxLoss = loss.bboxOffsetLoss(YPredCell(:,[2 3 9 10 6 7]),boxTarget,objectMaskTarget,boxErrorScale);
54 | objLoss = loss.objectnessLoss(YPredCell(:,1),objectnessTarget,objectMaskTarget);
55 | clsLoss = loss.classConfidenceLoss(YPredCell(:,8),classTarget,objectMaskTarget);
56 | totalLoss = boxLoss + objLoss + clsLoss;
57 |
58 | info.boxLoss = boxLoss;
59 | info.objLoss = objLoss;
60 | info.clsLoss = clsLoss;
61 | info.totalLoss = totalLoss;
62 |
63 | % Compute gradients of learnables with regard to loss.
64 | gradients = dlgradient(totalLoss, net.Learnables);
65 | end
--------------------------------------------------------------------------------
/test/tPretrainedYOLOv4.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadComplexYolov4Fixture}) tPretrainedYOLOv4 < matlab.unittest.TestCase
2 | % Test for PretrainedComplexYOLOv4
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | % The shared test fixture downloads the models. Here we check the
7 | % detections of each model.
8 | properties
9 | RepoRoot = getRepoRoot;
10 | end
11 |
12 | properties(TestParameter)
13 | Model = iGetDifferentModels();
14 | end
15 |
16 | methods(Test)
17 | function exerciseDetection(test,Model)
18 | detector = load(fullfile(test.RepoRoot,'models','complex-yolov4-models-master','models',Model.dataFileName));
19 | modelName = strsplit(Model.dataFileName,'/');
20 |
21 | ptCld = pcread(fullfile(test.RepoRoot,'..','pointclouds','0001.pcd'));
22 | gridParams = helper.getGridParameters;
23 | classNames = helper.getClassNames;
24 | [img,~] = helper.preprocess(ptCld, gridParams);
25 | anchors = helper.getAnchors(modelName{1});
26 | [bboxes, scores, labels] = detectComplexYOLOv4(detector.net, img, anchors, classNames, 'auto');
27 |
28 | test.verifyEqual(bboxes, Model.expectedBboxes,'AbsTol',single(1e-2));
29 | test.verifyEqual(scores, Model.expectedScores,'AbsTol',single(1e-2));
30 | test.verifyEqual(labels, Model.expectedLabels);
31 | end
32 | end
33 | end
34 |
35 | function Model = iGetDifferentModels()
36 | % Load YOLOv4-Pandaset
37 | dataFileName = 'complex-yolov4-pandaset/complex-yolov4-pandaset.mat';
38 |
39 | % Expected detection results.
40 | expectedBboxes = single([62.9979 562.6405 8.4515 10.0657 168.9640;...
41 | 321.4069 562.5641 8.2753 11.2603 123.9085;...
42 | 106.1246 353.8116 56.1350 22.1984 -3.0292;...
43 | 280.9620 253.8887 54.5222 22.7809 -2.8766;...
44 | 309.0541 153.2278 53.7926 22.8003 175.2755;...
45 | 425.4077 339.4636 53.9085 23.5097 2.8756;...
46 | 144.8896 117.9126 72.5260 32.5886 -176.6801;...
47 | 157.9700 196.1560 60.2388 26.2909 176.4073]);
48 | expectedScores = single([0.2301;0.1563;0.8486;0.9919;1.0000;0.8907;0.5088;0.9871]);
49 | expectedLabels = categorical({'Pedestrain';'Pedestrain';'Car';'Car';'Car';'Car';'Truck';'Car'});
50 | detectorYOLOv4Pandaset = struct('dataFileName',dataFileName,...
51 | 'expectedBboxes',expectedBboxes,'expectedScores',expectedScores,...
52 | 'expectedLabels',expectedLabels);
53 |
54 | % Load tiny-complex-yolov4-pandaset
55 | dataFileName = 'tiny-complex-yolov4-pandaset/tiny-complex-yolov4-pandaset.mat';
56 |
57 | % Expected detection results.
58 | expectedBboxes = single([101.5230 353.9243 51.5760 22.7369 -2.4730;...
59 | 154.0279 117.7253 68.4812 26.6877 177.9337;...
60 | 157.0230 195.9053 59.6339 23.1816 -178.8141;...
61 | 174.7080 152.9906 65.6461 26.1998 174.4401;...
62 | 210.7210 462.1014 49.4194 22.4370 -23.0138;...
63 | 195.0767 561.0181 59.5237 25.0804 -163.0908;...
64 | 283.1094 254.6001 56.8760 23.1281 -0.2960;...
65 | 312.2672 152.6835 57.9609 22.5368 178.7967;...
66 | 316.1564 486.0815 59.9362 25.2384 -91.3584;...
67 | 347.8901 398.4372 58.1425 23.9504 -35.8817;...
68 | 413.8113 300.3788 59.3578 25.0234 -0.5064;...
69 | 400.1172 411.2654 67.0539 25.7621 -179.9310;...
70 | 400.5548 461.0544 66.7300 27.3921 -106.2011;...
71 | 387.3457 523.3322 63.0952 24.9289 -121.0169;...
72 | 420.1742 338.2006 51.8193 23.7939 4.2179]);
73 | expectedScores = single([1.0000;0.8652;1.0000;0.5205;0.5793;0.9999;0.9960;...
74 | 0.9973;0.1788;0.9606;0.9998;0.1885;0.9848;0.9833;0.1407]);
75 | expectedLabels = categorical({'Car';'Truck';'Car';'Car';'Car';'Car';'Car';...
76 | 'Car';'Car';'Car';'Car';'Car';'Car';'Car';'Car'});
77 | detectorTinyYOLOv4Pandaset = struct('dataFileName',dataFileName,...
78 | 'expectedBboxes',expectedBboxes,'expectedScores',expectedScores,...
79 | 'expectedLabels',expectedLabels);
80 |
81 | Model = struct(...
82 | 'detectorYOLOv4Pandaset',detectorYOLOv4Pandaset,'detectorTinyYOLOv4Pandaset',detectorTinyYOLOv4Pandaset);
83 | end
--------------------------------------------------------------------------------
/test/tdownloadPretrainedComplexYOLOv4.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadComplexYolov4Fixture}) tdownloadPretrainedComplexYOLOv4 < matlab.unittest.TestCase
2 | % Test for downloadPretrainedYOLOv4
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | % The shared test fixture DownloadComplexYolov4Fixture calls
7 | % downloadPretrainedYOLOv4. Here we check that the downloaded files
8 | % exist in the appropriate location.
9 |
10 | properties
11 | DataDir = fullfile(getRepoRoot(),'models','complex-yolov4-models-master','models');
12 | end
13 |
14 |
15 | properties(TestParameter)
16 | Model = {'complex-yolov4-pandaset', 'tiny-complex-yolov4-pandaset'};
17 | end
18 |
19 | methods(Test)
20 | function verifyDownloadedFilesExist(test,Model)
21 | dataFileName = [Model,'.mat'];
22 | test.verifyTrue(isequal(exist(fullfile(test.DataDir,Model,dataFileName),'file'),2));
23 | end
24 | end
25 | end
--------------------------------------------------------------------------------
/test/tload.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadComplexYolov4Fixture}) tload < matlab.unittest.TestCase
2 | % Test for loading the downloaded models.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | % The shared test fixture DownloadComplexYolov4Fixture calls
7 | % downloadPretrainedYOLOv4. Here we check the properties of the
8 | % downloaded models.
9 |
10 | properties
11 | DataDir = fullfile(getRepoRoot(),'models','complex-yolov4-models-master','models');
12 | end
13 |
14 |
15 | properties(TestParameter)
16 | Model = iGetDifferentModels();
17 | end
18 |
19 | methods(Test)
20 | function verifyModelAndFields(test,Model)
21 | % Test point to verify the fields of the downloaded models are
22 | % as expected.
23 | loadedModel = load(fullfile(test.DataDir,Model.dataFileName));
24 |
25 | test.verifyClass(loadedModel.net,'dlnetwork');
26 | test.verifyEqual(numel(loadedModel.net.Layers), Model.expectedNumLayers);
27 | test.verifyEqual(size(loadedModel.net.Connections), Model.expectedConnectionsSize);
28 | test.verifyEqual(loadedModel.net.InputNames, Model.expectedInputNames);
29 | test.verifyEqual(loadedModel.net.OutputNames, Model.expectedOutputNames);
30 | end
31 | end
32 | end
33 |
34 | function Model = iGetDifferentModels()
35 | % Load complex-yolov4-pandaset
36 | dataFileName = 'complex-yolov4-pandaset/complex-yolov4-pandaset.mat';
37 |
38 | % Expected network properties.
39 | expectedNumLayers = 363;
40 | expectedConnectionsSize = [397 2];
41 | expectedInputNames = {{'input_1'}};
42 | expectedOutputNames = {{'yoloconv1' 'yoloconv2' 'yoloconv3'}};
43 |
44 | detectorYOLOv4Pandaset = struct('dataFileName',dataFileName,...
45 | 'expectedNumLayers',expectedNumLayers,'expectedConnectionsSize',expectedConnectionsSize,...
46 | 'expectedInputNames',expectedInputNames, 'expectedOutputNames',expectedOutputNames);
47 |
48 | % Load tiny-complex-yolov4-pandaset
49 | dataFileName = 'tiny-complex-yolov4-pandaset/tiny-complex-yolov4-pandaset.mat';
50 |
51 | % Expected network properties.
52 | expectedNumLayers = 74;
53 | expectedConnectionsSize = [80 2];
54 | expectedInputNames = {{'input_1'}};
55 | expectedOutputNames = {{'yoloconv1' 'yoloconv2'}};
56 |
57 | detectorTinyYOLOv4Pandaset = struct('dataFileName',dataFileName,...
58 | 'expectedNumLayers',expectedNumLayers,'expectedConnectionsSize',expectedConnectionsSize,...
59 | 'expectedInputNames',expectedInputNames, 'expectedOutputNames',expectedOutputNames);
60 |
61 | Model = struct(...
62 | 'detectorYOLOv4Pandaset',detectorYOLOv4Pandaset,'detectorTinyYOLOv4Pandaset',detectorTinyYOLOv4Pandaset);
63 | end
--------------------------------------------------------------------------------
/test/tools/DownloadComplexYolov4Fixture.m:
--------------------------------------------------------------------------------
1 | classdef DownloadComplexYolov4Fixture < matlab.unittest.fixtures.Fixture
2 | % DownloadComplexYolov4Fixture A fixture for calling
3 | % downloadPretrainedYOLOv4 if necessary. This is to ensure that this
4 | % function is only called once and only when tests need it. It also
5 | % provides a teardown to return the test environment to the expected
6 | % state before testing.
7 |
8 | % Copyright 2021 The MathWorks, Inc
9 |
10 | properties(Constant)
11 | Yolov4DataDir = fullfile(getRepoRoot(),'models','complex-yolov4-models-master','models')
12 | end
13 |
14 | properties
15 | Yolov4PandasetExist (1,1) logical
16 | TinyYolov4PandasetExist (1,1) logical
17 | end
18 |
19 | methods
20 | function setup(this)
21 | addpath(fullfile(getRepoRoot(),'..'));
22 | this.Yolov4PandasetExist = exist(fullfile(this.Yolov4DataDir,'complex-yolov4-pandaset','complex-yolov4-pandaset.mat'),'file')==2;
23 | this.TinyYolov4PandasetExist = exist(fullfile(this.Yolov4DataDir,'tiny-complex-yolov4-pandaset','tiny-complex-yolov4-pandaset.mat'),'file')==2;
24 |
25 | % Call these in evalc to capture and drop any standard output
26 | % that we don't want polluting the test logs.
27 | if ~this.Yolov4PandasetExist
28 | evalc('helper.downloadPretrainedYOLOv4(''complex-yolov4-pandaset'');');
29 | end
30 | if ~this.TinyYolov4PandasetExist
31 | evalc('helper.downloadPretrainedYOLOv4(''tiny-complex-yolov4-pandaset'');');
32 | end
33 | end
34 |
35 | function teardown(this)
36 | if ~this.Yolov4PandasetExist
37 | delete(fullfile(this.Yolov4DataDir,'complex-yolov4-pandaset','complex-yolov4-pandaset.mat'));
38 | end
39 | if ~this.TinyYolov4PandasetExist
40 | delete(fullfile(this.Yolov4DataDir,'tiny-complex-yolov4-pandaset','tiny-complex-yolov4-pandaset.mat'));
41 | end
42 | end
43 | end
44 | end
--------------------------------------------------------------------------------
/test/tools/getRepoRoot.m:
--------------------------------------------------------------------------------
1 | function path = getRepoRoot()
2 | % getRepoRoot Return a path to the repository's root directory.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | thisFile = mfilename('fullpath');
7 | thisDir = fileparts(thisFile);
8 |
9 | % the root is up two directories (/test/tools/getRepoRoot.m)
10 | path = fullfile(thisDir,'..');
11 | end
--------------------------------------------------------------------------------