├── +helper
│   ├── augmentData.m
│   ├── displayLidarOverlayImage.m
│   ├── downloadPretrainedSalsaNext.m
│   ├── generateLidarData.m
│   ├── imageMatReader.m
│   ├── lidarColorMap.m
│   ├── partitionLidarData.m
│   └── pointCloudToImage.m
├── .circleci
│   └── config.yml
├── LICENSE
├── Pandaset.rights
├── README.md
├── SECURITY.md
├── codegenSalsaNext.m
├── images
│   ├── SalsaNext_architecture.png
│   ├── cuboid_0001.PNG
│   ├── result.gif
│   └── result_0001.PNG
├── model
│   └── .gitkeep
├── pointclouds
│   ├── Input1.pcd
│   ├── Input2.pcd
│   └── Input3.pcd
├── salsaNextSemanticSegmentationExample.m
├── salsaNextTransferLearning.m
├── salsaNextpredict.m
└── test
    ├── tPretrainedSalsaNext.m
    ├── tdownloadPretrainedSalsaNext.m
    ├── tload.m
    └── tools
        ├── DownloadSalsaNextFixture.m
        └── getRepoRoot.m
/+helper/augmentData.m:
--------------------------------------------------------------------------------
1 | function out = augmentData(inp)
2 | % Apply random horizontal flipping.
3 |
4 | % Copyright 2021 The MathWorks, Inc
5 | out = cell(size(inp));
6 |
7 | % Randomly flip the five-channel image and pixel labels horizontally.
8 | I = inp{1};
9 | sz = size(I);
10 | tform = randomAffine2d('XReflection',true);
11 | rout = affineOutputView(sz,tform,'BoundsStyle','centerOutput');
12 |
13 | out{1} = imwarp(I,tform,'OutputView',rout);
14 | out{2} = imwarp(inp{2},tform,'OutputView',rout);
15 | end
--------------------------------------------------------------------------------
/+helper/displayLidarOverlayImage.m:
--------------------------------------------------------------------------------
1 | function displayLidarOverlayImage(lidarImage, labelMap, classNames)
2 | %displayLidarOverlayImage Overlay labels over the intensity image.
3 | %
4 | % displayLidarOverlayImage(lidarImage, labelMap, classNames)
5 | % displays the overlaid image. lidarImage is a five-channel lidar input.
6 | % labelMap contains pixel labels and classNames is an array of label
7 | % names.
8 | %
9 | % Copyright 2021 The MathWorks, Inc.
10 |
11 | % Read the intensity channel from the lidar image.
12 | intensityChannel = uint8(lidarImage(:,:,4));
13 |
14 | % Load the lidar color map.
15 | cmap = helper.lidarColorMap();
16 |
17 | % Overlay the labels over the intensity image.
18 | B = labeloverlay(intensityChannel,labelMap,'Colormap',cmap,'Transparency',0.4);
19 |
20 | % Resize for better visualization.
21 | B = imresize(B, 'Scale', [3 1], 'method', 'nearest');
22 | imshow(B);
23 | helperPixelLabelColorbar(cmap, classNames);
24 |
25 | end
26 |
27 | function helperPixelLabelColorbar(cmap, classNames)
28 |
29 | colormap(gca, cmap);
30 |
31 | % Add a colorbar to the current figure.
32 | c = colorbar('peer', gca);
33 |
34 | % Use class names for tick marks.
35 | c.TickLabels = classNames;
36 | numClasses = size(classNames, 1);
37 |
38 | % Center tick labels.
39 | c.Ticks = 1/(numClasses * 2):1/numClasses:1;
40 |
41 | % Remove tick marks.
42 | c.TickLength = 0;
43 | end
44 |
--------------------------------------------------------------------------------
/+helper/downloadPretrainedSalsaNext.m:
--------------------------------------------------------------------------------
1 | function model = downloadPretrainedSalsaNext()
2 |
3 | % The downloadPretrainedSalsaNext function downloads a SalsaNext network
4 | % pretrained on Pandaset dataset.
5 | %
6 | % Copyright 2021 The MathWorks, Inc.
7 |
8 | dataPath = 'model';
9 | modelName = 'SalsaNext';
10 | netFileFullPath = fullfile(dataPath,modelName);
11 |
12 | % Add '.zip' extension to the file name.
13 | netFileFull = [netFileFullPath,'.zip'];
14 |
15 | if ~exist(netFileFull,'file')
16 | fprintf(['Downloading pretrained ', modelName, ' network.\n']);
17 | fprintf('This can take several minutes to download...\n');
18 | url = 'https://ssd.mathworks.com/supportfiles/lidar/data/trainedSalsaNextPandasetNet.zip';
19 | websave(netFileFull,url);
20 | unzip(netFileFull, dataPath);
21 | model = load(fullfile(dataPath, 'trainedSalsaNext.mat'));
22 | else
23 | fprintf('Pretrained SalsaNext network already exists.\n\n');
24 | unzip(netFileFull, dataPath);
25 | model = load(fullfile(dataPath, 'trainedSalsaNext.mat'));
26 | end
27 |
28 | end
29 |
--------------------------------------------------------------------------------
/+helper/generateLidarData.m:
--------------------------------------------------------------------------------
1 | function generateLidarData(lidarData, imageDataLocation)
2 |
3 | %generateLidarData Function to generate images
4 | % from Lidar Point Clouds. The inputs
5 | % lidarData, imageDataLocation are described below.
6 | %
7 | % Inputs
8 | % ------
9 | % lidarData Lidar point clouds folder path.
10 | %
11 | % imageDataLocation Folder where training images will be saved to
12 | % disk. Make sure this points to a valid location on
13 | % the filesystem.
14 | %
15 | % Copyright 2021 The MathWorks, Inc
16 |
17 | if ~exist(imageDataLocation,'dir')
18 | mkdir(imageDataLocation);
19 | end
20 |
21 |
22 | tmpStr = '';
23 | lidarData = dir(fullfile(lidarData,'*.pcd'));
24 | numFiles = size(lidarData,1);
25 | for i=1:numFiles
26 | % Load ptcloud object.
27 | data = fullfile(lidarData(i).folder,lidarData(i).name);
28 | ptcloud = pcread(data);
29 | % Images have 5 channels, namely x, y, z, intensity, and range.
30 | im = helper.pointCloudToImage(ptcloud);
31 |
32 | % Store each image as a .mat file.
33 | imfile = fullfile(imageDataLocation,sprintf('%04d.mat',i));
34 | save(imfile,'im');
35 |
36 |
37 | % Display progress on screen every 300 files.
38 | if ~mod(i,300)
39 | msg = sprintf('Preprocessing data %3.2f%% complete', (i/numFiles)*100.0);
40 | fprintf(1,'%s',[tmpStr, msg]);
41 | tmpStr = repmat(sprintf('\b'), 1, length(msg));
42 | end
43 | end
44 |
45 | % Print completion message when done.
46 | msg = sprintf('Preprocessing data 100%% complete');
47 | fprintf(1,'%s',[tmpStr, msg]);
48 |
49 | end
--------------------------------------------------------------------------------
/+helper/imageMatReader.m:
--------------------------------------------------------------------------------
1 | function data = imageMatReader(filename)
2 | %imageMatReader Reads custom MAT files containing 5-channel
3 | % multispectral image data.
4 | %
5 | % DATA = imageMatReader(FILENAME) returns the first 5 channels of the
6 | % multispectral image saved in FILENAME.
7 |
8 | % Copyright 2021 The MathWorks, Inc
9 |
10 | d = load(filename);
11 | f = fields(d);
12 | data = d.(f{1})(:,:,1:5);
13 | index = isnan(data);
14 | data(index) = 0;
15 |
16 |
17 |
--------------------------------------------------------------------------------
/+helper/lidarColorMap.m:
--------------------------------------------------------------------------------
1 | function cmap = lidarColorMap()
2 | % Lidar color map for the pandaset classes
3 |
4 | % Copyright 2021 The MathWorks, Inc
5 |
6 | cmap = [[30,30,30]; % UnClassified
7 | [0,255,0]; % Vegetation
8 | [255, 150, 255]; % Ground
9 | [255,0,255]; % Road
10 | [255,0,0]; % Road Markings
11 | [90, 30, 150]; % Side Walk
12 | [245,150,100]; % Car
13 | [250, 80, 100]; % Truck
14 | [150, 60, 30]; % Other Vehicle
15 | [255, 255, 0]; % Pedestrian
16 | [0, 200, 255]; % Road Barriers
17 | [170,100,150]; % Signs
18 | [30, 30, 255]]; % Building
19 |
20 | cmap = cmap./255;
21 |
22 | end
--------------------------------------------------------------------------------
/+helper/partitionLidarData.m:
--------------------------------------------------------------------------------
1 | function [imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = partitionLidarData(imds,pxds)
2 | % Partition Pandaset data by randomly selecting 60% of the data for training,
3 | % 20% for validation, and the remaining 20% for testing.
4 | %
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | % Set initial random state for example reproducibility.
8 | rng(0);
9 | numFiles = numel(imds.Files);
10 | shuffledIndices = randperm(numFiles);
11 |
12 | % Use 60% of the images for training.
13 | numTrain = round(0.60 * numFiles);
14 | trainingIdx = shuffledIndices(1:numTrain);
15 |
16 | % Use 20% of the images for validation
17 | numVal = round(0.20 * numFiles);
18 | valIdx = shuffledIndices(numTrain+1:numTrain+numVal);
19 |
20 | % Use the rest for testing.
21 | testIdx = shuffledIndices(numTrain+numVal+1:end);
22 |
23 | % Create image datastores for training, validation, and test.
24 | trainingImages = imds.Files(trainingIdx);
25 | valImages = imds.Files(valIdx);
26 | testImages = imds.Files(testIdx);
27 |
28 | imdsTrain = imageDatastore(trainingImages,'FileExtensions', '.mat', ...
29 | 'ReadFcn', @helper.imageMatReader);
30 | imdsVal = imageDatastore(valImages,'FileExtensions', '.mat', ...
31 | 'ReadFcn', @helper.imageMatReader);
32 | imdsTest = imageDatastore(testImages,'FileExtensions', '.mat', ...
33 | 'ReadFcn', @helper.imageMatReader);
34 |
35 | % Extract class and label IDs info.
36 | classes = pxds.ClassNames;
37 | labelIDs = 1 : numel(classes);
38 |
39 | % Create pixel label datastores for training, validation, and test.
40 | trainingLabels = pxds.Files(trainingIdx);
41 | valLabels = pxds.Files(valIdx);
42 | testLabels = pxds.Files(testIdx);
43 |
44 | pxdsTrain = pixelLabelDatastore(trainingLabels, classes, labelIDs);
45 | pxdsVal = pixelLabelDatastore(valLabels, classes, labelIDs);
46 | pxdsTest = pixelLabelDatastore(testLabels, classes, labelIDs);
47 | end
--------------------------------------------------------------------------------
/+helper/pointCloudToImage.m:
--------------------------------------------------------------------------------
1 | function image = pointCloudToImage(ptcloud)
2 | % pointCloudToImage Converts organized 3-D point cloud to 5-channel
3 | % 2-D image.
4 | %
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | image = ptcloud.Location;
8 | image(:,:,4) = ptcloud.Intensity;
9 | rangeData = iComputeRangeData(image(:,:,1),image(:,:,2),image(:,:,3));
10 | image(:,:,5) = rangeData;
11 | index = isnan(image);
12 | image(index) = 0;
13 | end
14 |
15 | %--------------------------------------------------------------------------
16 | function rangeData = iComputeRangeData(xChannel,yChannel,zChannel)
17 | rangeData = sqrt(xChannel.*xChannel+yChannel.*yChannel+zChannel.*zChannel);
18 | end
19 |
--------------------------------------------------------------------------------
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.1
2 | orbs:
3 | matlab: mathworks/matlab@0.4.0
4 |
5 | jobs:
6 | build:
7 | machine:
8 | image: ubuntu-1604:201903-01
9 | steps:
10 | - checkout
11 | - matlab/install
12 | - matlab/run-tests:
13 | test-results-junit: artifacts/test_results/matlab/results.xml
14 | # Have to add test/tools to the path for certain tests.
15 | source-folder: .;test/tools
16 | - store_test_results:
17 | path: artifacts/test_results
18 | - store_artifacts:
19 | path: artifacts/
20 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021, The MathWorks, Inc.
2 | All rights reserved.
3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
5 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
6 | 3. In all cases, the software is, and all modifications and derivatives of the software shall be, licensed to you solely for use in conjunction with MathWorks products and service offerings.
7 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
8 |
--------------------------------------------------------------------------------
/Pandaset.rights:
--------------------------------------------------------------------------------
1 | Panda Set (https://scale.com/open-datasets/pandaset) is provided by Hesai and Scale under the CC-BY-4.0 license (https://creativecommons.org/licenses/by/4.0).
2 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SalsaNext For Lidar Segmentation
2 |
3 | This repository provides a pretrained SalsaNext[1] segmentation model in MATLAB®.
4 |
5 | [![Open in MATLAB Online](https://www.mathworks.com/images/responsive/global/open-in-matlab-online.svg)](https://matlab.mathworks.com/open/github/v1?repo=matlab-deep-learning/pretrained-salsanext)
6 |
7 | Requirements
8 | ------------
9 |
10 | - MATLAB R2021a or later.
11 | - Deep Learning Toolbox™
12 | - Lidar Toolbox™
13 |
14 | Overview
15 | --------
16 |
17 | This repository provides the SalsaNext network trained to segment different object categories, such as road, cars, and trucks. The pretrained model is trained on the Pandaset dataset[2], which has 13 object categories. This repository works on organized point clouds; if you have unorganized or raw point clouds, refer to the [Unorganized to Organized Conversion of Point Clouds Using Spherical Projection example](https://in.mathworks.com/help/lidar/ug/unorgaized-to-organized-pointcloud-conversion.html) to convert them into organized point clouds.
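
For reference, that conversion uses the `pcorganize` function from Lidar Toolbox. A minimal sketch, assuming Lidar Toolbox R2021b or later; the input file, sensor preset, and horizontal resolution below are illustrative placeholders, not values from this repository:

```
% Read an unorganized point cloud (hypothetical file).
ptCloudUnorg = pcread('input.pcd');

% Describe the sensor layout used for the spherical projection
% (assumed sensor preset and horizontal resolution).
params = lidarParameters('OS1Gen1-64', 1024);

% Project the points onto the sensor grid to get an organized point cloud.
ptCloudOrg = pcorganize(ptCloudUnorg, params);
```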
18 |
19 | SalsaNext is a popular deep learning network for semantic segmentation of 3-D lidar point clouds. It builds on SalsaNet, an encoder-decoder architecture: the encoder uses a stack of residual dilated convolutions with gradually increasing receptive fields, and the decoder upsamples with pixel-shuffle layers.
20 |
21 | Getting Started
22 | ---------------
23 | Download or clone this repository to your machine and open it in MATLAB.
24 |
25 | ### Download the pretrained network
26 | Use the helper function below to download the pretrained network. The network is downloaded and saved in the `model` directory.
27 |
28 | ```
29 | model = helper.downloadPretrainedSalsaNext;
30 | net = model.net;
31 | ```
32 |
33 | ### Semantic Segmentation Using Pretrained SalsaNext
34 |
35 | ```
36 | % Read test point cloud.
37 | ptCloud = pcread('pointclouds/Input1.pcd');
38 |
39 | % Convert point cloud to 5-channel image.
40 | I = helper.pointCloudToImage(ptCloud);
41 |
42 | % Segment objects from the test point cloud.
43 | predictedResult = semanticseg(I, net);
44 |
45 | % Display the output.
46 | op = single(predictedResult);
47 | cmap = helper.lidarColorMap();
48 | colormap = cmap(op,:);
49 | ptCloudMod = pointCloud(reshape(I(:,:,1:3),[],3),"Color",colormap);
50 | figure
51 | pcshow(ptCloudMod);
52 | ```
53 |
54 | ![SalsaNext segmentation output](images/result_0001.PNG)
55 | Video output generated on test sequence.
56 |
57 | ![SalsaNext segmentation results on a test sequence](images/result.gif)
58 |
59 | ### Generate 3-D bounding boxes from segmentation result
60 | The segmentation result is transformed into cuboids by clustering the points of the class of interest and fitting a cuboid around each cluster. For more information about how to detect and track objects using lidar data, see [Detect, Classify, and Track Vehicles Using Lidar Example](https://www.mathworks.com/help/lidar/ug/detect-classify-and-track-vehicles-using-lidar.html).
61 |
62 | ```
63 | % Get the indices of points for the required class.
64 | carIdx = (predictedResult == 'Car');
65 |
66 | % Select the points of required class and cluster them based on distance.
67 | ptCldMod = select(ptCloud,carIdx);
68 | [labels,numClusters] = pcsegdist(ptCldMod,0.5);
69 |
70 | % Select each cluster and fit a cuboid to each cluster.
71 | bboxes = [];
72 | for num = 1:numClusters
73 | labelIdx = (labels == num);
74 |
75 | % Ignore clusters with fewer than 200 points.
76 | if sum(labelIdx,'all') < 200
77 | continue;
78 | end
79 | pcSeg = select(ptCldMod,labelIdx);
80 | try
81 | mdl = pcfitcuboid(pcSeg);
82 | bboxes = [bboxes;mdl.Parameters];
83 | catch
84 | continue;
85 | end
86 | end
87 |
88 | % Display the output.
89 | figure;
90 | ax = pcshow(ptCloudMod);
91 | showShape('cuboid',bboxes,'Parent',ax,'Opacity',0.1,...
92 | 'Color','green','LineWidth',0.5);
93 | ```
94 |
95 | ![Cuboid detection result](images/cuboid_0001.PNG)
96 |
97 | Training Custom SalsaNext Using Transfer Learning
98 | -------------------------------------------------
99 |
100 | Transfer learning enables you to adapt a pretrained SalsaNext network to your dataset. Create a custom SalsaNext network for transfer learning with a new set of classes using the `salsaNextTransferLearning.m` script.
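
The key step in that script is replacing the final 1-by-1 convolution so the output size matches your class count. A minimal sketch of that step, mirroring `salsaNextTransferLearning.m`; the class count here is a hypothetical placeholder:

```
% Load the pretrained network and extract its layer graph.
model = helper.downloadPretrainedSalsaNext;
lgraph = layerGraph(model.net);

% Replace the final convolution so the network predicts your classes.
numClasses = 5;   % hypothetical number of classes in your dataset
lgraph = replaceLayer(lgraph, 'Conv_191', ...
    convolution2dLayer([1 1], numClasses, 'Name', 'Conv_191'));
```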
101 |
102 |
103 | Code Generation for SalsaNext
104 | -----------------------------
105 | Code generation enables you to generate code and deploy SalsaNext on multiple embedded platforms.
106 |
107 | Run `codegenSalsaNext.m`. This script calls the `salsaNextpredict.m` entry-point function and generates CUDA code for it. It then runs the generated MEX and displays the output.
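
The core of the script (taken from `codegenSalsaNext.m`) creates a GPU MEX configuration with a cuDNN deep learning target and generates code for the entry point with the network's 64-by-1856-by-5 input size:

```
cfg = coder.gpuConfig('mex');
cfg.TargetLang = 'C++';
cfg.DeepLearningConfig = coder.DeepLearningConfig('cudnn');
codegen -config cfg salsaNextpredict -args {ones(64,1856,5)} -report
```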
108 |
109 | | Model | Inference Speed (FPS) |
110 | | ------ | ------ |
111 | | SalsaNext w/o codegen | 8.06 |
112 | | SalsaNext with codegen | 24.39 |
113 |
114 | - Performance (in FPS) is measured on a TITAN-RTX GPU.
115 |
116 | For more information about codegen, see [Deep Learning with GPU Coder](https://www.mathworks.com/help/gpucoder/gpucoder-deep-learning.html).
117 |
118 | ## SalsaNext Architecture Details
119 | The SalsaNext network architecture is illustrated in the following diagram.
120 |
121 | ![SalsaNext architecture](images/SalsaNext_architecture.png)
122 |
123 | - **Context Module:** To aggregate global context information from different regions, a residual dilated convolution stack is used. This module fuses a larger receptive field with a smaller one by adding 1x1 and 3x3 kernels right at the beginning of the network, capturing the global context alongside more detailed spatial information.
124 |
125 | - **Encoder:** The encoder block consists of a novel combination of a set of dilated convolutions having effective receptive fields of 3, 5 and 7. Each dilated convolution output is concatenated and then passed through a 1×1 convolution followed by a residual connection. This helps the network exploit more information from the fused features coming from various depths in the receptive field. Each of these new residual dilated convolution blocks is followed by dropout and pooling layers.
126 |
127 | - **Decoder:** Decoders generally use transposed convolutions, which are computationally expensive. SalsaNext replaces these standard transposed convolutions with pixel-shuffle layers, which produce the upsampled feature maps by shuffling pixels from the channel dimension into the spatial dimensions (see the sketch below).
128 |
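A minimal sketch of the pixel-shuffle (depth-to-space) operation, using the `depthToSpace` function. This assumes Image Processing Toolbox R2021a or later, and the feature-map size is an illustrative placeholder, not taken from the network:

```
% A 16-by-464 feature map with 64 channels (spatial-spatial-channel format).
X = dlarray(rand(16, 464, 64), 'SSC');

% A block size of 2 shuffles channels into space: the result is
% 32-by-928 with 64/(2*2) = 16 channels, i.e. 2x spatial upsampling.
Y = depthToSpace(X, 2);
size(Y)
```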
129 |
130 |
131 | References
132 | ----------
133 |
134 | [1] Cortinhal, Tiago, George Tzelepis, and Eren Erdal Aksoy. “SalsaNext: Fast, Uncertainty-Aware Semantic Segmentation of LiDAR Point Clouds for Autonomous Driving.” ArXiv:2003.03653 [Cs], July 9, 2020. http://arxiv.org/abs/2003.03653.
135 |
136 | [2] [Panda Set](https://scale.com/open-datasets/pandaset) is provided by Hesai and Scale under the [CC-BY-4.0 license](https://creativecommons.org/licenses/by/4.0)
137 |
138 |
139 | Copyright 2021 The MathWorks, Inc.
140 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Reporting Security Vulnerabilities
2 |
3 | If you believe you have discovered a security vulnerability, please report it to
4 | [security@mathworks.com](mailto:security@mathworks.com). Please see
5 | [MathWorks Vulnerability Disclosure Policy for Security Researchers](https://www.mathworks.com/company/aboutus/policies_statements/vulnerability-disclosure-policy.html)
6 | for additional information.
7 |
--------------------------------------------------------------------------------
/codegenSalsaNext.m:
--------------------------------------------------------------------------------
1 | %% Code Generation for SalsaNext Network
2 | % The following script demonstrates how to perform code generation for a
3 | % pretrained SalsaNext semantic segmentation network, trained on Pandaset
4 | % dataset.
5 |
6 | %% Download the pre-trained network
7 | helper.downloadPretrainedSalsaNext;
8 |
9 | %% Read and process the input point cloud
10 | % Read test point cloud.
11 | ptCloud = pcread('pointclouds/Input1.pcd');
12 |
13 | % Convert point cloud to 5-channel image.
14 | I = helper.pointCloudToImage(ptCloud);
15 |
16 | %% Run MEX code generation
17 | % salsaNextpredict.m is the entry-point function that takes an input image
18 | % and returns the prediction output. The function uses a persistent object
19 | % salsaNextObj to load the DAG network object and reuses the persistent
20 | % object for prediction on subsequent calls.
21 | %
22 | % To generate CUDA code for the salsaNextpredict entry-point function,
23 | % create a GPU code configuration object for a MEX target and set the
24 | % target language to C++.
25 | %
26 | % Use the coder.DeepLearningConfig (GPU Coder) function to create a CuDNN
27 | % deep learning configuration object and assign it to the DeepLearningConfig
28 | % property of the GPU code configuration object.
29 | %
30 | % Run the codegen command and specify the input size.
31 | cfg = coder.gpuConfig('mex');
32 | cfg.TargetLang = 'C++';
33 | cfg.DeepLearningConfig = coder.DeepLearningConfig('cudnn');
34 | codegen -config cfg salsaNextpredict -args {ones(64,1856,5)} -report
35 |
36 | %% Perform Semantic Segmentation Using Generated MEX
37 | % Call salsaNextpredict_mex on the input range image.
38 | predict_scores = salsaNextpredict_mex(I);
39 |
40 | % The predict_scores variable is a three-dimensional matrix that has 13
41 | % channels corresponding to the pixel-wise prediction scores for every
42 | % class. Compute the channel by using the maximum prediction score to get
43 | % pixel-wise labels.
44 | [~,op] = max(predict_scores,[],3);
45 |
46 | % Visualize the result.
47 | cmap = helper.lidarColorMap();
48 | colormap = cmap(op,:);
49 | ptCloudMod = pointCloud(reshape(I(:,:,1:3),[],3),"Color",colormap);
50 | figure
51 | ax1 = pcshow(ptCloudMod);
52 | zoom(ax1,3);
53 |
54 |
55 | % Copyright 2021 The MathWorks, Inc
--------------------------------------------------------------------------------
/images/SalsaNext_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/images/SalsaNext_architecture.png
--------------------------------------------------------------------------------
/images/cuboid_0001.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/images/cuboid_0001.PNG
--------------------------------------------------------------------------------
/images/result.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/images/result.gif
--------------------------------------------------------------------------------
/images/result_0001.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/images/result_0001.PNG
--------------------------------------------------------------------------------
/model/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/pointclouds/Input1.pcd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/pointclouds/Input1.pcd
--------------------------------------------------------------------------------
/pointclouds/Input2.pcd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/pointclouds/Input2.pcd
--------------------------------------------------------------------------------
/pointclouds/Input3.pcd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/pointclouds/Input3.pcd
--------------------------------------------------------------------------------
/salsaNextSemanticSegmentationExample.m:
--------------------------------------------------------------------------------
1 | %% Lidar Point Cloud Semantic Segmentation Using SalsaNext Deep Learning Network
2 | % The following code demonstrates running prediction on a pre-trained
3 | % SalsaNext network, trained on the Pandaset dataset.
4 |
5 | %% Prerequisites
6 | % To run this example you need the following prerequisites:
7 | % # MATLAB (R2021a or later) with Deep Learning Toolbox and Lidar Toolbox.
8 | % # Pretrained SalsaNext network (download instructions below)
9 |
10 | %% Download the pre-trained network
11 | model = helper.downloadPretrainedSalsaNext;
12 | net = model.net;
13 |
14 | % Define ClassNames
15 | classNames = ["unlabelled"
16 | "Vegetation"
17 | "Ground"
18 | "Road"
19 | "RoadMarkings"
20 | "SideWalk"
21 | "Car"
22 | "Truck"
23 | "OtherVehicle"
24 | "Pedestrian"
25 | "RoadBarriers"
26 | "Signs"
27 | "Buildings"];
28 |
29 | %% Perform Semantic Segmentation Using SalsaNext Network
30 | % Read test point cloud.
31 | ptCloud = pcread('pointclouds/Input1.pcd');
32 |
33 | % Convert point cloud to 5-channel image.
34 | I = helper.pointCloudToImage(ptCloud);
35 |
36 | % Segment objects from the test point cloud.
37 | predictedResult = semanticseg(I, net,"ExecutionEnvironment","auto");
38 |
39 | %% Display Output
40 | figure;
41 | helper.displayLidarOverlayImage(I, predictedResult, classNames);
42 | title('Semantic Segmentation Result');
43 |
44 | % Display in point cloud format.
45 | cmap = helper.lidarColorMap();
46 | colormap = cmap(single(predictedResult),:);
47 | ptCloudMod = pointCloud(reshape(I(:,:,1:3),[],3),"Color",colormap);
48 | figure
49 | ax = pcshow(ptCloudMod);
50 | zoom(ax,3);
51 |
52 | %% Get Bounding Boxes from Segmentation Output
53 | % Get the indices of points for the required class.
54 | carIdx = (predictedResult == 'Car');
55 |
56 | % Select the points of required class and cluster them based on distance.
57 | ptCldMod = select(ptCloud,carIdx);
58 | [labels,numClusters] = pcsegdist(ptCldMod,0.5);
59 |
60 | % Select each cluster and fit a cuboid to each cluster.
61 | bboxes = [];
62 | for num = 1:numClusters
63 | labelIdx = (labels == num);
64 |
65 | % Ignore clusters with fewer than 200 points.
66 | if sum(labelIdx,'all') < 200
67 | continue;
68 | end
69 | pcSeg = select(ptCldMod,labelIdx);
70 | try
71 | mdl = pcfitcuboid(pcSeg);
72 | bboxes = [bboxes;mdl.Parameters];
73 | catch
74 | continue;
75 | end
76 | end
77 |
78 | % Display the output.
79 | figure;
80 | ax = pcshow(ptCloudMod);
81 | showShape('cuboid',bboxes,'Parent',ax,'Opacity',0.1,...
82 | 'Color','green','LineWidth',0.5);
83 | zoom(ax,3);
84 |
85 | % Copyright 2021 The MathWorks, Inc
--------------------------------------------------------------------------------
/salsaNextTransferLearning.m:
--------------------------------------------------------------------------------
1 | %% Configure Pretrained SalsaNext Network for Transfer Learning
2 | % The following code demonstrates configuring a pretrained
3 | % SalsaNext[1] network for a custom dataset.
4 |
5 | %% Download Pretrained Model
6 |
7 | model = helper.downloadPretrainedSalsaNext;
8 | net = model.net;
9 |
10 | %% Download Pandaset Data Set
11 | % This example uses a subset of PandaSet[2], which contains 2560
12 | % preprocessed organized point clouds. Each point cloud is specified as a
13 | % 64-by-1856 matrix. The corresponding ground truth contains the semantic
14 | % segmentation labels for 12 classes. The point clouds are stored in PCD
15 | % format, and the ground truth data is stored in PNG format. The size of
16 | % the data set is 5.2 GB. Execute this code to download the data set.
17 |
18 | url = 'https://ssd.mathworks.com/supportfiles/lidar/data/Pandaset_LidarData.tar.gz';
19 | outputFolder = fullfile(tempdir,'Pandaset');
20 |
21 | lidarDataTarFile = fullfile(outputFolder,'Pandaset_LidarData.tar.gz');
22 | if ~exist(lidarDataTarFile, 'file')
23 | mkdir(outputFolder);
24 | disp('Downloading Pandaset Lidar driving data (5.2 GB)...');
25 | websave(lidarDataTarFile, url);
26 | untar(lidarDataTarFile,outputFolder);
27 | end
28 |
29 | % Check if tar.gz file is downloaded, but not uncompressed.
30 | if (~exist(fullfile(outputFolder,'Lidar'), 'file'))...
31 | &&(~exist(fullfile(outputFolder,'semanticLabels'), 'file'))
32 | untar(lidarDataTarFile,outputFolder);
33 | end
34 |
35 | lidarData = fullfile(outputFolder,'Lidar');
36 | labelsFolder = fullfile(outputFolder,'semanticLabels');
37 |
38 | % Note: Depending on your Internet connection, the download process can
39 | % take some time. The code suspends MATLAB® execution until the download
40 | % process is complete. Alternatively, you can download the data set to your
41 | % local disk using your web browser, and then extract Pandaset_LidarData
42 | % folder. To use the file you downloaded from the web, change the
43 | % outputFolder variable in the code to the location of the downloaded file.
44 |
45 | %% Prepare Data for Training
46 | % Load lidar point clouds and class labels: use the generateLidarData
47 | % helper function to generate training data from the lidar point clouds.
48 | % The function uses point cloud data to create five-channel input images.
49 | % Each training image is specified as a 64-by-1856-by-5 array:
50 | %
51 | % Generate the five-channel training images.
52 |
53 | imagesFolder = fullfile(outputFolder,'Images');
54 | helper.generateLidarData(lidarData,imagesFolder);
55 |
56 | % The five-channel images are saved as MAT files.
57 | %
58 | % Note: Processing can take some time. The code suspends MATLAB® execution
59 | % until processing is complete.
60 |
61 | %% Load Generated Images.
62 | % Create an ImageDatastore and a PixelLabelDatastore. Use the ImageDatastore
63 | % object to extract and store the five channels of the 2-D spherical images
64 | % using the imageMatReader helper function, which is a custom MAT file
65 | % reader.
66 |
67 | imds = imageDatastore(imagesFolder, ...
68 | 'FileExtensions', '.mat', ...
69 | 'ReadFcn', @helper.imageMatReader);
70 |
71 | % Use the PixelLabelDatastore object to store pixel-wise labels from pixel
72 | % label images. The object maps each pixel label to a class name. In this
73 | % example, vegetation, ground, road, road markings, sidewalk, cars,
74 | % trucks, other vehicles, pedestrians, road barriers, signs, and buildings
75 | % the objects of interest; all other pixels are the background. Specify
76 | % these classes and assign a unique label ID to each class.
77 |
78 | classNames = ["unlabelled"
79 | "Vegetation"
80 | "Ground"
81 | "Road"
82 | "RoadMarkings"
83 | "SideWalk"
84 | "Car"
85 | "Truck"
86 | "OtherVehicle"
87 | "Pedestrian"
88 | "RoadBarriers"
89 | "Signs"
90 | "Buildings"];
91 |
92 | numClasses = numel(classNames);
93 |
94 | % Specify label IDs from 1 to the number of classes.
95 | labelIDs = 1 : numClasses;
96 |
97 | pxds = pixelLabelDatastore(labelsFolder, classNames, labelIDs);
98 |
99 | %% Prepare Training, Validation, and Test Sets
100 | % Use the partitionLidarData helper function to split the data into
101 | % training, validation, and test sets.
102 |
103 | [imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = helper.partitionLidarData(imds, pxds);
104 |
105 | dsTrain = combine(imdsTrain,pxdsTrain);
106 | dsVal = combine(imdsVal,pxdsVal);
107 |
108 | %% Data Augmentation
109 | % Data augmentation is used to improve network accuracy by randomly
110 | % transforming the original data during training. By using data
111 | % augmentation, you can add more variety to the training data without
112 | % actually having to increase the number of labeled training samples.
113 | %
114 | % Augment the training data by using the transform function with custom
115 | % preprocessing operations specified by the augmentData helper function.
116 | % This function randomly flips the multichannel 2-D image and associated
117 | % labels in the horizontal direction. Apply data augmentation to only the
118 | % training data set.
119 |
120 | augmentedTrainingData = transform(dsTrain, @(x) helper.augmentData(x));
121 | %% Configure Pretrained Network
122 |
123 | % Extract the layergraph from the pretrained network to perform custom
124 | % modification.
125 | lgraph = layerGraph(net);
126 |
127 | % Change the output size to the required number of classes.
128 | lgraph = replaceLayer(lgraph, 'Conv_191', convolution2dLayer([1,1], numClasses, 'Name', 'Conv_191'));
129 |
130 | %% Define Training Options
131 | options = trainingOptions('sgdm', ...
132 | 'LearnRateSchedule','piecewise',...
133 | 'LearnRateDropPeriod',10,...
134 | 'LearnRateDropFactor',0.3,...
135 | 'Momentum',0.9, ...
136 | 'InitialLearnRate',1e-3, ...
137 | 'L2Regularization',0.005, ...
138 | 'ValidationData',dsVal,...
139 | 'MaxEpochs',6, ...
140 | 'MiniBatchSize',8, ...
141 | 'Shuffle','every-epoch', ...
142 | 'CheckpointPath', tempdir, ...
143 | 'VerboseFrequency',2,...
144 | 'Plots','training-progress',...
145 | 'ValidationPatience', 4);
146 |
147 | % The learning rate uses a piecewise schedule. The learning rate is reduced
148 | % by a factor of 0.3 every 10 epochs. This allows the network to learn quickly
149 | % with a higher initial learning rate, while being able to find a solution
150 | % close to the local optimum once the learning rate drops.
151 | %
152 | % The network is tested against the validation data every epoch by setting
153 | % the 'ValidationData' parameter. The 'ValidationPatience' is set to 4 to
154 | % stop training early when the validation accuracy converges. This prevents
155 | % the network from overfitting on the training dataset.
156 | %
157 | % A mini-batch size of 8 is used for training. You can increase or decrease
158 | % this value based on the amount of GPU memory you have on your system.
159 | %
160 | % In addition, 'CheckpointPath' is set to a temporary location. This name-value
161 | % pair enables the saving of network checkpoints at the end of every training
162 | % epoch. If training is interrupted due to a system failure or power outage,
163 | % you can resume training from the saved checkpoint. Make sure that the location
164 | % specified by 'CheckpointPath' has enough space to store the network checkpoints.
165 |
166 | % Now, you can pass 'augmentedTrainingData', 'lgraph', and 'options' to
167 | % trainNetwork as shown in the 'Train Network' section of the example 'Lidar
168 | % Point Cloud Semantic Segmentation Using SqueezeSegV2 Deep Learning Network'
169 | % (https://www.mathworks.com/help/lidar/ug/semantic-segmentation-using-squeezesegv2-network.html)
170 | % to obtain a SalsaNext model trained on the custom dataset.
171 | %
172 | % You can follow the sections 'Test Network on One Image' for inference using
173 | % the trained model and 'Evaluate Trained Network' for evaluating metrics.
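%
% A minimal sketch of that training call, assuming you want to train with the
% augmented datastore created above (uncomment to run):
%
% [net, info] = trainNetwork(augmentedTrainingData, lgraph, options);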
174 |
175 |
176 | %% References
177 |
178 | % [1] Cortinhal, Tiago, George Tzelepis, and Eren Erdal Aksoy. "SalsaNext: Fast,
179 | % Uncertainty-Aware Semantic Segmentation of LiDAR Point Clouds for Autonomous
180 | % Driving." ArXiv:2003.03653 [Cs], July 9, 2020. http://arxiv.org/abs/2003.03653
181 | % http://arxiv.org/abs/2003.03653.
182 | %
183 | % [2] https://scale.com/open-datasets/pandaset
184 | %
185 | % Copyright 2020 The MathWorks, Inc.
--------------------------------------------------------------------------------
/salsaNextpredict.m:
--------------------------------------------------------------------------------
1 | function out = salsaNextpredict(in)
2 | %#codegen
3 | % Copyright 2021 The MathWorks, Inc.
4 |
5 | persistent salsaNextObj;
6 |
7 | if isempty(salsaNextObj)
8 | salsaNextObj = coder.loadDeepLearningNetwork('model/trainedSalsaNext.mat');
9 | end
10 |
11 | % Pass input.
12 | out = predict(salsaNextObj,in);
13 |
14 | end
--------------------------------------------------------------------------------
/test/tPretrainedSalsaNext.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadSalsaNextFixture}) tPretrainedSalsaNext < matlab.unittest.TestCase
2 | % Test for tPretrainedSalsaNext
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | % The shared test fixture downloads the model. Here we check the
7 | % inference on the pretrained model.
8 | properties
9 | RepoRoot = getRepoRoot;
10 | ModelName = 'trainedSalsaNext.mat';
11 | end
12 |
13 | methods(Test)
14 | function exerciseDetection(test)
15 | model = load(fullfile(test.RepoRoot,'model',test.ModelName));
16 | ptCloud = pcread(fullfile(test.RepoRoot,'pointclouds','Input1.pcd'));
17 | img = helper.pointCloudToImage(ptCloud);
18 |
19 | imSize = size(img);
20 | imSize = imSize(:,1:2);
21 | actualLabel1Count = 48420;
22 | actualLabel2Count = 11990;
23 |
24 | result = semanticseg(img, model.net);
25 | labelsCountTbl = countlabels(result(:));
26 | labelCount = labelsCountTbl.Count(find(labelsCountTbl.Count));
27 |
28 | % verifying size of output from semanticseg.
29 | test.verifyEqual(size(result),imSize);
30 | % verifying that all the pixels are labelled.
31 | test.verifyEqual(sum(labelCount),prod(imSize));
32 | % verifying the count of each labels on the result.
33 | test.verifyEqual(labelCount(1),actualLabel1Count);
34 | test.verifyEqual(labelCount(2),actualLabel2Count);
35 | end
36 | end
37 | end
--------------------------------------------------------------------------------
/test/tdownloadPretrainedSalsaNext.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadSalsaNextFixture}) tdownloadPretrainedSalsaNext < matlab.unittest.TestCase
2 | % Test for downloadPretrainedSalsaNext
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | % The shared test fixture DownloadSalsaNextFixture calls
7 | % downloadPretrainedSalsaNext. Here we check that the downloaded files
8 | % exist in the appropriate location.
9 |
10 | properties
11 | DataDir = fullfile(getRepoRoot(),'model');
12 | end
13 |
14 | methods(Test)
15 | function verifyDownloadedFilesExist(test)
16 | dataFileName = 'trainedSalsaNext.mat';
17 | test.verifyTrue(isequal(exist(fullfile(test.DataDir,dataFileName),'file'),2));
18 | end
19 | end
20 | end
--------------------------------------------------------------------------------
/test/tload.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadSalsaNextFixture}) tload < matlab.unittest.TestCase
2 | % Test for loading the downloaded models.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | % The shared test fixture DownloadSalsaNextFixture calls
7 | % downloadPretrainedSalsaNext. Here we check the properties of the
8 | % downloaded models.
9 |
10 | properties
11 | DataDir = fullfile(getRepoRoot(),'model');
12 | end
13 |
14 | methods(Test)
15 | function verifyModelAndFields(test)
16 | % Test point to verify the fields of the downloaded models are
17 | % as expected.
18 |
19 | loadedModel = load(fullfile(test.DataDir,'trainedSalsaNext.mat'));
20 |
21 | test.verifyClass(loadedModel.net,'DAGNetwork');
22 | test.verifyEqual(numel(loadedModel.net.Layers),175);
23 | test.verifyEqual(size(loadedModel.net.Connections),[204 2])
24 | test.verifyEqual(loadedModel.net.InputNames,{'Input_input.1'});
25 | test.verifyEqual(loadedModel.net.OutputNames,{'focalloss_out'});
26 | end
27 | end
28 | end
--------------------------------------------------------------------------------
/test/tools/DownloadSalsaNextFixture.m:
--------------------------------------------------------------------------------
1 | classdef DownloadSalsaNextFixture < matlab.unittest.fixtures.Fixture
2 | % DownloadSalsaNextFixture A fixture for calling
3 | % downloadPretrainedSalsaNext if necessary. This is to ensure that this
4 | % function is only called once and only when tests need it. It also
5 | % provides a teardown to return the test environment to the expected
6 | % state before testing.
7 |
8 | % Copyright 2021 The MathWorks, Inc
9 |
10 | properties(Constant)
11 | SalsaNextDataDir = fullfile(getRepoRoot(),'model')
12 | end
13 |
14 | properties
15 | SalsaNextExist (1,1) logical
16 | end
17 |
18 | methods
19 | function setup(this)
20 | this.SalsaNextExist = exist(fullfile(this.SalsaNextDataDir,'trainedSalsaNext.mat'),'file')==2;
21 |
22 | % Call this in evalc to capture and drop any standard output
23 | % that we don't want polluting the test logs.
24 | if ~this.SalsaNextExist
25 | evalc('helper.downloadPretrainedSalsaNext();');
26 | end
27 | end
28 |
29 | function teardown(this)
30 | if this.SalsaNextExist
31 | delete(fullfile(this.SalsaNextDataDir,'trainedSalsaNext.mat'));
32 | end
33 | end
34 | end
35 | end
--------------------------------------------------------------------------------
/test/tools/getRepoRoot.m:
--------------------------------------------------------------------------------
1 | function path = getRepoRoot()
2 | % getRepoRoot Return a path to the repository's root directory.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | thisFile = mfilename('fullpath');
7 | thisDir = fileparts(thisFile);
8 |
9 | % the root is up two directories (/test/tools/getRepoRoot.m)
10 | path = fullfile(thisDir,'..','..');
11 | end
--------------------------------------------------------------------------------