├── +helper
│   ├── augmentData.m
│   ├── displayLidarOverlayImage.m
│   ├── downloadPretrainedSalsaNext.m
│   ├── generateLidarData.m
│   ├── imageMatReader.m
│   ├── lidarColorMap.m
│   ├── partitionLidarData.m
│   └── pointCloudToImage.m
├── .circleci
│   └── config.yml
├── LICENSE
├── Pandaset.rights
├── README.md
├── SECURITY.md
├── codegenSalsaNext.m
├── images
│   ├── SalsaNext_architecture.png
│   ├── cuboid_0001.PNG
│   ├── result.gif
│   └── result_0001.PNG
├── model
│   └── .gitkeep
├── pointclouds
│   ├── Input1.pcd
│   ├── Input2.pcd
│   └── Input3.pcd
├── salsaNextSemanticSegmentationExample.m
├── salsaNextTransferLearning.m
├── salsaNextpredict.m
└── test
    ├── tPretrainedSalsaNext.m
    ├── tdownloadPretrainedSalsaNext.m
    ├── tload.m
    └── tools
        ├── DownloadSalsaNextFixture.m
        └── getRepoRoot.m

/+helper/augmentData.m:
--------------------------------------------------------------------------------
1 | function out = augmentData(inp)
2 | % Apply random horizontal flipping.
3 | 
4 | % Copyright 2021 The MathWorks, Inc.
5 | out = cell(size(inp));
6 | 
7 | % Randomly flip the five-channel image and pixel labels horizontally.
8 | I = inp{1};
9 | sz = size(I);
10 | tform = randomAffine2d('XReflection',true);
11 | rout = affineOutputView(sz,tform,'BoundsStyle','centerOutput');
12 | 
13 | out{1} = imwarp(I,tform,'OutputView',rout);
14 | out{2} = imwarp(inp{2},tform,'OutputView',rout);
15 | end
--------------------------------------------------------------------------------
/+helper/displayLidarOverlayImage.m:
--------------------------------------------------------------------------------
1 | function displayLidarOverlayImage(lidarImage, labelMap, classNames)
2 | %displayLidarOverlayImage Overlay labels over the intensity image.
3 | %
4 | % displayLidarOverlayImage(lidarImage, labelMap, classNames)
5 | % displays the overlaid image. lidarImage is a five-channel lidar input.
6 | % labelMap contains pixel labels and classNames is an array of label
7 | % names.
8 | %
9 | % Copyright 2021 The MathWorks, Inc.
10 | 
11 | % Read the intensity channel from the lidar image.
12 | intensityChannel = uint8(lidarImage(:,:,4));
13 | 
14 | % Load the lidar color map.
15 | cmap = helper.lidarColorMap();
16 | 
17 | % Overlay the labels over the intensity image.
18 | B = labeloverlay(intensityChannel,labelMap,'Colormap',cmap,'Transparency',0.4);
19 | 
20 | % Resize for better visualization.
21 | B = imresize(B, 'Scale', [3 1], 'method', 'nearest');
22 | imshow(B);
23 | helperPixelLabelColorbar(cmap, classNames);
24 | 
25 | end
26 | 
27 | function helperPixelLabelColorbar(cmap, classNames)
28 | 
29 | colormap(gca, cmap);
30 | 
31 | % Add a colorbar to the current figure.
32 | c = colorbar('peer', gca);
33 | 
34 | % Use class names for tick marks.
35 | c.TickLabels = classNames;
36 | numClasses = size(classNames, 1);
37 | 
38 | % Center tick labels.
39 | c.Ticks = 1/(numClasses * 2):1/numClasses:1;
40 | 
41 | % Remove tick marks.
42 | c.TickLength = 0;
43 | end
44 | 
--------------------------------------------------------------------------------
/+helper/downloadPretrainedSalsaNext.m:
--------------------------------------------------------------------------------
1 | function model = downloadPretrainedSalsaNext()
2 | 
3 | % The downloadPretrainedSalsaNext function downloads a SalsaNext network
4 | % pretrained on the Pandaset dataset.
5 | %
6 | % Copyright 2021 The MathWorks, Inc.
7 | 
8 | dataPath = 'model';
9 | modelName = 'SalsaNext';
10 | netFileFullPath = fullfile(dataPath,modelName);
11 | 
12 | % Add '.zip' extension to the network file name.
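% The zip archive is cached under the 'model' folder and extracted there, so
% subsequent calls load 'model/trainedSalsaNext.mat' without downloading again.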
13 | netFileFull = [netFileFullPath,'.zip'];
14 | 
15 | if ~exist(netFileFull,'file')
16 |     fprintf(['Downloading pretrained ', modelName, ' network.\n']);
17 |     fprintf('This can take several minutes to download...\n');
18 |     url = 'https://ssd.mathworks.com/supportfiles/lidar/data/trainedSalsaNextPandasetNet.zip';
19 |     websave(netFileFull,url);
20 |     unzip(netFileFull, dataPath);
21 |     model = load(fullfile(dataPath, 'trainedSalsaNext.mat'));
22 | else
23 |     fprintf('Pretrained SalsaNext network already exists.\n\n');
24 |     unzip(netFileFull, dataPath);
25 |     model = load(fullfile(dataPath, 'trainedSalsaNext.mat'));
26 | end
27 | 
28 | end
29 | 
--------------------------------------------------------------------------------
/+helper/generateLidarData.m:
--------------------------------------------------------------------------------
1 | function generateLidarData(lidarData, imageDataLocation)
2 | 
3 | %generateLidarData Function to generate images
4 | % from lidar point clouds. The inputs
5 | % lidarData, imageDataLocation are described below.
6 | %
7 | % Inputs
8 | % ------
9 | % lidarData            Lidar point clouds folder path.
10 | %
11 | % imageDataLocation    Folder where training images will be saved to
12 | %                      disk. Make sure this points to a valid location on
13 | %                      the filesystem.
14 | %
15 | % Copyright 2021 The MathWorks, Inc.
16 | 
17 | if ~exist(imageDataLocation,'dir')
18 |     mkdir(imageDataLocation);
19 | end
20 | 
21 | 
22 | tmpStr = '';
23 | lidarData = dir(fullfile(lidarData,'*.pcd'));
24 | numFiles = size(lidarData,1);
25 | for i = 1:numFiles
26 |     % Load the point cloud object.
27 |     data = fullfile(lidarData(i).folder,lidarData(i).name);
28 |     ptcloud = pcread(data);
29 |     % Images have 5 channels, namely x, y, z, intensity, and range.
30 |     im = helper.pointCloudToImage(ptcloud);
31 | 
32 |     % Save the five-channel image to disk as a .mat file.
33 |     imfile = fullfile(imageDataLocation,sprintf('%04d.mat',i));
34 |     save(imfile,'im');
35 | 
36 | 
37 |     % Display progress on screen every 300 files.
38 |     if ~mod(i,300)
39 |         msg = sprintf('Preprocessing data %3.2f%% complete', (i/numFiles)*100.0);
40 |         fprintf(1,'%s',[tmpStr, msg]);
41 |         tmpStr = repmat(sprintf('\b'), 1, length(msg));
42 |     end
43 | end
44 | 
45 | % Print completion message when done.
46 | msg = sprintf('Preprocessing data 100%% complete.\n');
47 | fprintf(1,'%s',[tmpStr, msg]);
48 | 
49 | end
--------------------------------------------------------------------------------
/+helper/imageMatReader.m:
--------------------------------------------------------------------------------
1 | function data = imageMatReader(filename)
2 | %imageMatReader Reads custom MAT files containing 5-channel
3 | % multispectral image data.
4 | %
5 | % DATA = imageMatReader(FILENAME) returns the first 5 channels of the
6 | % multispectral image saved in FILENAME.
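% NaN values in the loaded image are replaced with zeros.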
7 | 
8 | % Copyright 2021 The MathWorks, Inc.
9 | 
10 | d = load(filename);
11 | f = fieldnames(d);
12 | data = d.(f{1})(:,:,1:5);
13 | index = isnan(data);
14 | data(index) = 0;
15 | 
16 | 
--------------------------------------------------------------------------------
/+helper/lidarColorMap.m:
--------------------------------------------------------------------------------
1 | function cmap = lidarColorMap()
2 | % Lidar color map for the Pandaset classes.
3 | 
4 | % Copyright 2021 The MathWorks, Inc.
5 | 
6 | cmap = [[30,30,30];       % Unclassified
7 |         [0,255,0];        % Vegetation
8 |         [255,150,255];    % Ground
9 |         [255,0,255];      % Road
10 |         [255,0,0];        % Road Markings
11 |         [90,30,150];      % Side Walk
12 |         [245,150,100];    % Car
13 |         [250,80,100];     % Truck
14 |         [150,60,30];      % Other Vehicle
15 |         [255,255,0];      % Pedestrian
16 |         [0,200,255];      % Road Barriers
17 |         [170,100,150];    % Signs
18 |         [30,30,255]];     % Building
19 | 
20 | cmap = cmap./255;
21 | 
22 | end
--------------------------------------------------------------------------------
/+helper/partitionLidarData.m:
--------------------------------------------------------------------------------
1 | function [imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = partitionLidarData(imds,pxds)
2 | % Partition Pandaset data by randomly selecting 60% of the data for
3 | % training and 20% for validation. The rest is used for testing.
4 | %
5 | % Copyright 2021 The MathWorks, Inc.
6 | 
7 | % Set initial random state for example reproducibility.
8 | rng(0);
9 | numFiles = numel(imds.Files);
10 | shuffledIndices = randperm(numFiles);
11 | 
12 | % Use 60% of the images for training.
13 | numTrain = round(0.60 * numFiles);
14 | trainingIdx = shuffledIndices(1:numTrain);
15 | 
16 | % Use 20% of the images for validation.
17 | numVal = round(0.20 * numFiles);
18 | valIdx = shuffledIndices(numTrain+1:numTrain+numVal);
19 | 
20 | % Use the rest for testing.
21 | testIdx = shuffledIndices(numTrain+numVal+1:end);
22 | 
23 | % Create image datastores for training, validation, and testing.
24 | trainingImages = imds.Files(trainingIdx);
25 | valImages = imds.Files(valIdx);
26 | testImages = imds.Files(testIdx);
27 | 
28 | imdsTrain = imageDatastore(trainingImages,'FileExtensions', '.mat', ...
29 |     'ReadFcn', @helper.imageMatReader);
30 | imdsVal = imageDatastore(valImages,'FileExtensions', '.mat', ...
31 |     'ReadFcn', @helper.imageMatReader);
32 | imdsTest = imageDatastore(testImages,'FileExtensions', '.mat', ...
33 |     'ReadFcn', @helper.imageMatReader);
34 | 
35 | % Extract class and label IDs info.
36 | classes = pxds.ClassNames;
37 | labelIDs = 1 : numel(classes);
38 | 
39 | % Create pixel label datastores for training, validation, and testing.
40 | trainingLabels = pxds.Files(trainingIdx);
41 | valLabels = pxds.Files(valIdx);
42 | testLabels = pxds.Files(testIdx);
43 | 
44 | pxdsTrain = pixelLabelDatastore(trainingLabels, classes, labelIDs);
45 | pxdsVal = pixelLabelDatastore(valLabels, classes, labelIDs);
46 | pxdsTest = pixelLabelDatastore(testLabels, classes, labelIDs);
47 | end
--------------------------------------------------------------------------------
/+helper/pointCloudToImage.m:
--------------------------------------------------------------------------------
1 | function image = pointCloudToImage(ptcloud)
2 | % pointCloudToImage Converts an organized 3-D point cloud to a 5-channel
3 | % 2-D image.
4 | %
5 | % Copyright 2021 The MathWorks, Inc.
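% Channels 1-3 are the x, y, and z coordinates, channel 4 is intensity,
% and channel 5 is range.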
6 | 7 | image = ptcloud.Location; 8 | image(:,:,4) = ptcloud.Intensity; 9 | rangeData = iComputeRangeData(image(:,:,1),image(:,:,2),image(:,:,3)); 10 | image(:,:,5) = rangeData; 11 | index = isnan(image); 12 | image(index) = 0; 13 | end 14 | 15 | %-------------------------------------------------------------------------- 16 | function rangeData = iComputeRangeData(xChannel,yChannel,zChannel) 17 | rangeData = sqrt(xChannel.*xChannel+yChannel.*yChannel+zChannel.*zChannel); 18 | end 19 | -------------------------------------------------------------------------------- /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | orbs: 3 | matlab: mathworks/matlab@0.4.0 4 | 5 | jobs: 6 | build: 7 | machine: 8 | image: ubuntu-1604:201903-01 9 | steps: 10 | - checkout 11 | - matlab/install 12 | - matlab/run-tests: 13 | test-results-junit: artifacts/test_results/matlab/results.xml 14 | # Have to add test/tools to the path for certain tests. 15 | source-folder: .;test/tools 16 | - store_test_results: 17 | path: artifacts/test_results 18 | - store_artifacts: 19 | path: artifacts/ 20 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, The MathWorks, Inc. 2 | All rights reserved. 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 5 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 6 | 3. In all cases, the software is, and all modifications and derivatives of the software shall be, licensed to you solely for use in conjunction with MathWorks products and service offerings. 7 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 8 | -------------------------------------------------------------------------------- /Pandaset.rights: -------------------------------------------------------------------------------- 1 | Panda Set (https://scale.com/open-datasets/pandaset) is provided by Hesai and Scale under the CC-BY-4.0 license (https://creativecommons.org/licenses/by/4.0). 2 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # SalsaNext For Lidar Segmentation 2 | 3 | This repository provides a pretrained SalsaNext[1] segmentation model in MATLAB®. 
4 | 
5 | [![Open in MATLAB Online](https://www.mathworks.com/images/responsive/global/open-in-matlab-online.svg)](https://matlab.mathworks.com/open/github/v1?repo=matlab-deep-learning/pretrained-salsanext)
6 | 
7 | Requirements
8 | ------------
9 | 
10 | - MATLAB R2021a or later.
11 | - Deep Learning Toolbox™
12 | - Lidar Toolbox™
13 | 
14 | Overview
15 | --------
16 | 
17 | This repository provides the SalsaNext network trained to segment different object categories including road, cars, trucks, etc. The pretrained model is trained on the Pandaset dataset[2], which has 13 different object categories. This repository works on organized point clouds; if you have unorganized or raw point clouds, refer to the [Unorganized to Organized Conversion of Point Clouds Using Spherical Projection example](https://in.mathworks.com/help/lidar/ug/unorgaized-to-organized-pointcloud-conversion.html) to convert them into organized point clouds.
18 | 
19 | SalsaNext is a popular lidar semantic segmentation network used for segmentation of 3-D point clouds. SalsaNext builds upon SalsaNet, which has an encoder-decoder architecture: the encoder contains a residual dilated convolution stack with gradually increasing receptive fields, and the decoder uses a pixel-shuffle layer for upsampling.
20 | 
21 | Getting Started
22 | ---------------
23 | Download or clone this repository to your machine and open it in MATLAB.
24 | 
25 | ### Download the pretrained network
26 | Use the helper function below to download the pretrained network. The network is downloaded and saved in the `model` directory.
27 | 
28 | ```
29 | model = helper.downloadPretrainedSalsaNext;
30 | net = model.net;
31 | ```
32 | 
33 | ### Semantic Segmentation Using Pretrained SalsaNext
34 | 
35 | ```
36 | % Read test point cloud.
37 | ptCloud = pcread('pointclouds/Input1.pcd');
38 | 
39 | % Convert point cloud to 5-channel image.
40 | I = helper.pointCloudToImage(ptCloud);
41 | 
42 | % Segment objects from the test point cloud.
43 | predictedResult = semanticseg(I, net);
44 | 
45 | % Display the output.
46 | op = single(predictedResult);
47 | cmap = helper.lidarColorMap();
48 | colormap = cmap(op,:);
49 | ptCloudMod = pointCloud(reshape(I(:,:,1:3),[],3),"Color",colormap);
50 | figure
51 | pcshow(ptCloudMod);
52 | ```
53 | 
54 | 
55 | Video output generated on a test sequence.
56 | 
57 | 
58 | 
59 | ### Generate 3-D bounding boxes from segmentation result
60 | The segmentation result is transformed into cuboids by clustering the points of the class of interest and fitting a cuboid around each cluster. For more information about how to detect and track objects using lidar data, see [Detect, Classify, and Track Vehicles Using Lidar Example](https://www.mathworks.com/help/lidar/ug/detect-classify-and-track-vehicles-using-lidar.html).
61 | 
62 | ```
63 | % Get the indices of points for the required class.
64 | carIdx = (predictedResult == 'Car');
65 | 
66 | % Select the points of the required class and cluster them based on distance.
67 | ptCldMod = select(ptCloud,carIdx);
68 | [labels,numClusters] = pcsegdist(ptCldMod,0.5);
69 | 
70 | % Select each cluster and fit a cuboid to each cluster.
71 | bboxes = [];
72 | for num = 1:numClusters
73 |     labelIdx = (labels == num);
74 | 
75 |     % Ignore clusters that have fewer than 200 points.
76 |     if sum(labelIdx,'all') < 200
77 |         continue;
78 |     end
79 |     pcSeg = select(ptCldMod,labelIdx);
80 |     try
81 |         mdl = pcfitcuboid(pcSeg);
82 |         bboxes = [bboxes;mdl.Parameters];
83 |     catch
84 |         continue;
85 |     end
86 | end
87 | 
88 | % Display the output.
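% Each row of bboxes is a 9-element cuboid parameter vector,
% [xctr yctr zctr xlen ylen zlen xrot yrot zrot], taken from the fitted
% cuboidModel, so all clusters can be drawn in one showShape call.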
89 | figure;
90 | ax = pcshow(ptCloudMod);
91 | showShape('cuboid',bboxes,'Parent',ax,'Opacity',0.1,...
92 |     'Color','green','LineWidth',0.5);
93 | ```
94 | 
95 | 
96 | 
97 | Training Custom SalsaNext Using Transfer Learning
98 | -------------------------------------------------
99 | 
100 | Transfer learning enables you to adapt a pretrained SalsaNext network to your dataset. Create a custom SalsaNext network for transfer learning with a new set of classes using the `salsaNextTransferLearning.m` script.
101 | 
102 | 
103 | Code Generation for SalsaNext
104 | -----------------------------
105 | Code generation enables you to generate code and deploy SalsaNext on multiple embedded platforms.
106 | 
107 | Run `codegenSalsaNext.m`. This script calls the `salsaNextpredict.m` entry-point function, generates CUDA code for it, runs the generated MEX, and displays the output.
108 | 
109 | | Model | Inference Speed (FPS) |
110 | | ------ | ------ |
111 | | SalsaNext w/o codegen | 8.06 |
112 | | SalsaNext with codegen | 24.39 |
113 | 
114 | - Performance (in FPS) is measured on a TITAN-RTX GPU.
115 | 
116 | For more information about codegen, see [Deep Learning with GPU Coder](https://www.mathworks.com/help/gpucoder/gpucoder-deep-learning.html).
117 | 
118 | ## SalsaNext Architecture Details
119 | The SalsaNext network architecture is illustrated in the following diagram.
120 | 
121 | 
122 | 
123 | - **Context Module:** To aggregate global context information across different regions, a residual dilated convolution stack is used. This module fuses a larger receptive field with a smaller one by adding 1x1 and 3x3 kernels right at the beginning of the network, capturing the global context alongside more detailed spatial information.
124 | 
125 | - **Encoder:** The encoder block consists of a novel combination of a set of dilated convolutions having effective receptive fields of 3, 5, and 7. Each dilated convolution output is concatenated and then passed through a 1×1 convolution followed by a residual connection. This helps the network exploit more information from the fused features coming from various depths in the receptive field. Each of these new residual dilated convolution blocks is followed by dropout and pooling layers.
126 | 
127 | - **Decoder:** In general, decoders use transposed convolutions, which are computationally expensive. SalsaNext replaces these standard transposed convolutions with pixel-shuffle layers, which leverage the learned feature maps to produce upsampled feature maps by shuffling pixels from the channel dimension into the spatial dimension.
128 | 
129 | 
130 | 
131 | References
132 | ----------
133 | 
134 | [1] Cortinhal, Tiago, George Tzelepis, and Eren Erdal Aksoy. “SalsaNext: Fast, Uncertainty-Aware Semantic Segmentation of LiDAR Point Clouds for Autonomous Driving.” ArXiv:2003.03653 [Cs], July 9, 2020. http://arxiv.org/abs/2003.03653.
135 | 
136 | [2] [Panda Set](https://scale.com/open-datasets/pandaset) is provided by Hesai and Scale under the [CC-BY-4.0 license](https://creativecommons.org/licenses/by/4.0).
137 | 
138 | 
139 | Copyright 2021 The MathWorks, Inc.
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Reporting Security Vulnerabilities
2 | 
3 | If you believe you have discovered a security vulnerability, please report it to
4 | [security@mathworks.com](mailto:security@mathworks.com).
Please see
5 | [MathWorks Vulnerability Disclosure Policy for Security Researchers](https://www.mathworks.com/company/aboutus/policies_statements/vulnerability-disclosure-policy.html)
6 | for additional information.
7 | 
--------------------------------------------------------------------------------
/codegenSalsaNext.m:
--------------------------------------------------------------------------------
1 | %% Code Generation for the SalsaNext Network
2 | % The following script demonstrates how to perform code generation for a
3 | % pretrained SalsaNext semantic segmentation network, trained on the
4 | % Pandaset dataset.
5 | 
6 | %% Download the pretrained network
7 | helper.downloadPretrainedSalsaNext;
8 | 
9 | %% Read and process the input point cloud
10 | % Read test point cloud.
11 | ptCloud = pcread('pointclouds/Input1.pcd');
12 | 
13 | % Convert point cloud to 5-channel image.
14 | I = helper.pointCloudToImage(ptCloud);
15 | 
16 | %% Run MEX code generation
17 | % salsaNextpredict.m is the entry-point function that takes an input image
18 | % and returns the prediction scores. The function uses a persistent object
19 | % salsaNextObj to load the DAG network object and reuses the persistent
20 | % object for prediction on subsequent calls.
21 | %
22 | % To generate CUDA code for the salsaNextpredict entry-point function,
23 | % create a GPU code configuration object for a MEX target and set the
24 | % target language to C++.
25 | %
26 | % Use the coder.DeepLearningConfig (GPU Coder) function to create a CuDNN
27 | % deep learning configuration object and assign it to the DeepLearningConfig
28 | % property of the GPU code configuration object.
29 | %
30 | % Run the codegen command and specify the input size.
31 | cfg = coder.gpuConfig('mex');
32 | cfg.TargetLang = 'C++';
33 | cfg.DeepLearningConfig = coder.DeepLearningConfig('cudnn');
34 | codegen -config cfg salsaNextpredict -args {ones(64,1856,5)} -report
35 | 
36 | %% Perform Semantic Segmentation Using the Generated MEX
37 | % Call salsaNextpredict_mex on the input range image.
38 | predict_scores = salsaNextpredict_mex(I);
39 | 
40 | % The predict_scores variable is a three-dimensional matrix that has 13
41 | % channels corresponding to the pixel-wise prediction scores for every
42 | % class. Find the channel with the maximum prediction score for each pixel
43 | % to get the pixel-wise labels.
44 | [~,op] = max(predict_scores,[],3);
45 | 
46 | % Visualize the result.
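% Indexing the color map with the per-pixel labels gives one RGB color per
% point; the x, y, z channels of I are reshaped into an N-by-3 point list.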
47 | cmap = helper.lidarColorMap();
48 | colormap = cmap(op,:);
49 | ptCloudMod = pointCloud(reshape(I(:,:,1:3),[],3),"Color",colormap);
50 | figure
51 | ax1 = pcshow(ptCloudMod);
52 | zoom(ax1,3);
53 | 
54 | 
55 | % Copyright 2021 The MathWorks, Inc.
--------------------------------------------------------------------------------
/images/SalsaNext_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/images/SalsaNext_architecture.png
--------------------------------------------------------------------------------
/images/cuboid_0001.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/images/cuboid_0001.PNG
--------------------------------------------------------------------------------
/images/result.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/images/result.gif
--------------------------------------------------------------------------------
/images/result_0001.PNG:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/images/result_0001.PNG
--------------------------------------------------------------------------------
/model/.gitkeep:
--------------------------------------------------------------------------------
1 | 
2 | 
--------------------------------------------------------------------------------
/pointclouds/Input1.pcd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/pointclouds/Input1.pcd
--------------------------------------------------------------------------------
/pointclouds/Input2.pcd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/pointclouds/Input2.pcd
--------------------------------------------------------------------------------
/pointclouds/Input3.pcd:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-salsanext/5ddca9df1ef5295cd8caad0c62474284f2c12230/pointclouds/Input3.pcd
--------------------------------------------------------------------------------
/salsaNextSemanticSegmentationExample.m:
--------------------------------------------------------------------------------
1 | %% Lidar Point Cloud Semantic Segmentation Using SalsaNext Deep Learning Network
2 | % The following code demonstrates running prediction on a pretrained
3 | % SalsaNext network, trained on the Pandaset dataset.
4 | 
5 | %% Prerequisites
6 | % To run this example you need the following prerequisites:
7 | % # MATLAB (R2021a or later) with Lidar Toolbox and Deep Learning Toolbox.
8 | % # Pretrained SalsaNext network (download instructions below)
9 | 
10 | %% Download the pretrained network
11 | model = helper.downloadPretrainedSalsaNext;
12 | net = model.net;
13 | 
14 | % Define class names.
15 | classNames = ["unlabelled"
16 |     "Vegetation"
17 |     "Ground"
18 |     "Road"
19 |     "RoadMarkings"
20 |     "SideWalk"
21 |     "Car"
22 |     "Truck"
23 |     "OtherVehicle"
24 |     "Pedestrian"
25 |     "RoadBarriers"
26 |     "Signs"
27 |     "Buildings"];
28 | 
29 | %% Perform Semantic Segmentation Using SalsaNext Network
30 | % Read test point cloud.
31 | ptCloud = pcread('pointclouds/Input1.pcd');
32 | 
33 | % Convert point cloud to 5-channel image.
34 | I = helper.pointCloudToImage(ptCloud);
35 | 
36 | % Segment objects from the test point cloud.
37 | predictedResult = semanticseg(I, net,"ExecutionEnvironment","auto");
38 | 
39 | %% Display Output
40 | figure;
41 | helper.displayLidarOverlayImage(I, predictedResult, classNames);
42 | title('Semantic Segmentation Result');
43 | 
44 | % Display in point cloud format.
45 | cmap = helper.lidarColorMap();
46 | colormap = cmap(single(predictedResult),:);
47 | ptCloudMod = pointCloud(reshape(I(:,:,1:3),[],3),"Color",colormap);
48 | figure
49 | ax = pcshow(ptCloudMod);
50 | zoom(ax,3);
51 | 
52 | %% Get Bounding Boxes from Segmentation Output
53 | % Get the indices of points for the required class.
54 | carIdx = (predictedResult == 'Car');
55 | 
56 | % Select the points of the required class and cluster them based on distance.
57 | ptCldMod = select(ptCloud,carIdx);
58 | [labels,numClusters] = pcsegdist(ptCldMod,0.5);
59 | 
60 | % Select each cluster and fit a cuboid to each cluster.
61 | bboxes = [];
62 | for num = 1:numClusters
63 |     labelIdx = (labels == num);
64 | 
65 |     % Ignore clusters that have fewer than 200 points.
66 |     if sum(labelIdx,'all') < 200
67 |         continue;
68 |     end
69 |     pcSeg = select(ptCldMod,labelIdx);
70 |     try
71 |         mdl = pcfitcuboid(pcSeg);
72 |         bboxes = [bboxes;mdl.Parameters];
73 |     catch
74 |         continue;
75 |     end
76 | end
77 | 
78 | % Display the output.
79 | figure;
80 | ax = pcshow(ptCloudMod);
81 | showShape('cuboid',bboxes,'Parent',ax,'Opacity',0.1,...
82 |     'Color','green','LineWidth',0.5);
83 | zoom(ax,3);
84 | 
85 | % Copyright 2021 The MathWorks, Inc.
--------------------------------------------------------------------------------
/salsaNextTransferLearning.m:
--------------------------------------------------------------------------------
1 | %% Configure Pretrained SalsaNext Network for Transfer Learning
2 | % The following code demonstrates configuring a pretrained SalsaNext[1]
3 | % network for transfer learning on a custom dataset.
4 | 
5 | %% Download Pretrained Model
6 | 
7 | model = helper.downloadPretrainedSalsaNext;
8 | net = model.net;
9 | 
10 | %% Download Pandaset Data Set
11 | % This example uses a subset of PandaSet[2] that contains 2560
12 | % preprocessed organized point clouds. Each point cloud is specified as a
13 | % 64-by-1856 matrix. The corresponding ground truth contains the semantic
14 | % segmentation labels for 12 classes. The point clouds are stored in PCD
15 | % format, and the ground truth data is stored in PNG format. The size of
16 | % the data set is 5.2 GB. Execute this code to download the data set.
17 | 
18 | url = 'https://ssd.mathworks.com/supportfiles/lidar/data/Pandaset_LidarData.tar.gz';
19 | outputFolder = fullfile(tempdir,'Pandaset');
20 | 
21 | lidarDataTarFile = fullfile(outputFolder,'Pandaset_LidarData.tar.gz');
22 | if ~exist(lidarDataTarFile, 'file')
23 |     mkdir(outputFolder);
24 |     disp('Downloading Pandaset Lidar driving data (5.2 GB)...');
25 |     websave(lidarDataTarFile, url);
26 |     untar(lidarDataTarFile,outputFolder);
27 | end
28 | 
29 | % Check if the tar.gz file is downloaded, but not uncompressed.
30 | if (~exist(fullfile(outputFolder,'Lidar'), 'file'))...
31 |         &&(~exist(fullfile(outputFolder,'semanticLabels'), 'file'))
32 |     untar(lidarDataTarFile,outputFolder);
33 | end
34 | 
35 | lidarData = fullfile(outputFolder,'Lidar');
36 | labelsFolder = fullfile(outputFolder,'semanticLabels');
37 | 
38 | % Note: Depending on your Internet connection, the download process can
39 | % take some time. The code suspends MATLAB® execution until the download
40 | % process is complete. Alternatively, you can download the data set to your
41 | % local disk using your web browser, and then extract the Pandaset_LidarData
42 | % folder. To use the file you downloaded from the web, change the
43 | % outputFolder variable in the code to the location of the downloaded file.
44 | 
45 | %% Prepare Data for Training
46 | % Load Lidar Point Clouds and Class Labels. Use the generateLidarData
47 | % helper function to generate training data from the lidar point clouds.
48 | % The function uses point cloud data to create five-channel input images.
49 | % Each training image is specified as a 64-by-1856-by-5 array.
50 | %
51 | % Generate the five-channel training images.
52 | 
53 | imagesFolder = fullfile(outputFolder,'Images');
54 | helper.generateLidarData(lidarData,imagesFolder);
55 | 
56 | % The five-channel images are saved as MAT files.
57 | %
58 | % Note: Processing can take some time. The code suspends MATLAB® execution
59 | % until processing is complete.
60 | 
61 | %% Load Generated Images
62 | % Create ImageDatastore and PixelLabelDatastore. Use the ImageDatastore
63 | % object to extract and store the five channels of the 2-D spherical images
64 | % using the imageMatReader helper function, which is a custom MAT file
65 | % reader.
66 | 
67 | imds = imageDatastore(imagesFolder, ...
68 |     'FileExtensions', '.mat', ...
69 |     'ReadFcn', @helper.imageMatReader);
70 | 
71 | % Use the PixelLabelDatastore object to store pixel-wise labels from pixel
72 | % label images. The object maps each pixel label to a class name. In this
73 | % example, vegetation, ground, road, road markings, sidewalks, cars,
74 | % trucks, other vehicles, pedestrians, road barriers, signs, and buildings
75 | % are the objects of interest; all other pixels are the background. Specify
76 | % these classes and assign a unique label ID to each class.
77 | 
78 | classNames = ["unlabelled"
79 |     "Vegetation"
80 |     "Ground"
81 |     "Road"
82 |     "RoadMarkings"
83 |     "SideWalk"
84 |     "Car"
85 |     "Truck"
86 |     "OtherVehicle"
87 |     "Pedestrian"
88 |     "RoadBarriers"
89 |     "Signs"
90 |     "Buildings"];
91 | 
92 | numClasses = numel(classNames);
93 | 
94 | % Specify label IDs from 1 to the number of classes.
95 | labelIDs = 1 : numClasses;
96 | 
97 | pxds = pixelLabelDatastore(labelsFolder, classNames, labelIDs);
98 | 
99 | %% Prepare Training, Validation, and Test Sets
100 | % Use the partitionLidarData helper function to split the data into
101 | % training, validation, and test sets.
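% (The helper uses a fixed 60/20/20 split and seeds the random number
% generator with rng(0), so the partition is reproducible.)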
102 | 
103 | [imdsTrain, imdsVal, imdsTest, pxdsTrain, pxdsVal, pxdsTest] = helper.partitionLidarData(imds, pxds);
104 | 
105 | dsTrain = combine(imdsTrain,pxdsTrain);
106 | dsVal = combine(imdsVal,pxdsVal);
107 | 
108 | %% Data Augmentation
109 | % Data augmentation is used to improve network accuracy by randomly
110 | % transforming the original data during training. By using data
111 | % augmentation, you can add more variety to the training data without
112 | % actually having to increase the number of labeled training samples.
113 | %
114 | % Augment the training data by using the transform function with custom
115 | % preprocessing operations specified by the augmentData helper function.
116 | % This function randomly flips the multichannel 2-D image and associated
117 | % labels in the horizontal direction. Apply data augmentation to only the
118 | % training data set.
119 | 
120 | augmentedTrainingData = transform(dsTrain, @(x) helper.augmentData(x));
121 | %% Configure Pretrained Network
122 | 
123 | % Extract the layer graph from the pretrained network to perform custom
124 | % modification.
125 | lgraph = layerGraph(net);
126 | 
127 | % Change the output size to the required number of classes.
128 | lgraph = replaceLayer(lgraph, 'Conv_191', convolution2dLayer([1,1], numClasses, 'Name', 'Conv_191'));
129 | 
130 | %% Define training options.
131 | options = trainingOptions('sgdm', ...
132 |     'LearnRateSchedule','piecewise',...
133 |     'LearnRateDropPeriod',10,...
134 |     'LearnRateDropFactor',0.3,...
135 |     'Momentum',0.9, ...
136 |     'InitialLearnRate',1e-3, ...
137 |     'L2Regularization',0.005, ...
138 |     'ValidationData',dsVal,...
139 |     'MaxEpochs',6, ...
140 |     'MiniBatchSize',8, ...
141 |     'Shuffle','every-epoch', ...
142 |     'CheckpointPath', tempdir, ...
143 |     'VerboseFrequency',2,...
144 |     'Plots','training-progress',...
145 |     'ValidationPatience', 4);
146 | 
147 | % The learning rate uses a piecewise schedule. The learning rate is reduced
148 | % by a factor of 0.3 every 10 epochs. This allows the network to learn quickly
149 | % with a higher initial learning rate, while being able to find a solution
150 | % close to the local optimum once the learning rate drops.
151 | %
152 | % The network is tested against the validation data every epoch by setting
153 | % the 'ValidationData' parameter. The 'ValidationPatience' is set to 4 to
154 | % stop training early when the validation accuracy converges. This prevents
155 | % the network from overfitting on the training dataset.
156 | %
157 | % A mini-batch size of 8 is used for training. You can increase or decrease
158 | % this value based on the amount of GPU memory you have on your system.
159 | %
160 | % In addition, 'CheckpointPath' is set to a temporary location. This name-value
161 | % pair enables the saving of network checkpoints at the end of every training
162 | % epoch. If training is interrupted due to a system failure or power outage,
163 | % you can resume training from the saved checkpoint. Make sure that the location
164 | % specified by 'CheckpointPath' has enough space to store the network checkpoints.
165 | 
166 | % Now, you can pass 'augmentedTrainingData', 'lgraph' and 'options' to trainNetwork
167 | % as shown in the 'Train Network' section of the example 'Lidar Point Cloud
168 | % Semantic Segmentation Using SqueezeSegV2 Deep Learning Network Example'
169 | % (https://www.mathworks.com/help/lidar/ug/semantic-segmentation-using-squeezesegv2-network.html) to
170 | % obtain a SalsaNext model trained on the custom dataset.
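%
% For example, a minimal sketch of that call (training is compute intensive
% and is best run on a GPU):
%
%   [net, info] = trainNetwork(augmentedTrainingData, lgraph, options);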
171 | %
172 | % You can follow the 'Test Network on One Image' section for inference using
173 | % the trained model and the 'Evaluate Trained Network' section for evaluating
174 | % metrics.
175 | 
176 | 
177 | %% References
178 | 
179 | % [1] Cortinhal, Tiago, George Tzelepis, and Eren Erdal Aksoy. "SalsaNext: Fast,
180 | % Uncertainty-Aware Semantic Segmentation of LiDAR Point Clouds for Autonomous
181 | % Driving." ArXiv:2003.03653 [Cs], July 9, 2020. http://arxiv.org/abs/2003.03653.
182 | %
183 | % [2] https://scale.com/open-datasets/pandaset
184 | %
185 | % Copyright 2020 The MathWorks, Inc.
--------------------------------------------------------------------------------
/salsaNextpredict.m:
--------------------------------------------------------------------------------
1 | function out = salsaNextpredict(in)
2 | %#codegen
3 | % Copyright 2021 The MathWorks, Inc.
4 | 
5 | persistent salsaNextObj;
6 | 
7 | if isempty(salsaNextObj)
8 |     salsaNextObj = coder.loadDeepLearningNetwork('model/trainedSalsaNext.mat');
9 | end
10 | 
11 | % Pass the input through the network.
12 | out = predict(salsaNextObj,in);
13 | 
14 | end
--------------------------------------------------------------------------------
/test/tPretrainedSalsaNext.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadSalsaNextFixture}) tPretrainedSalsaNext < matlab.unittest.TestCase
2 | % Test for tPretrainedSalsaNext
3 | 
4 | % Copyright 2021 The MathWorks, Inc.
5 | 
6 | % The shared test fixture downloads the model. Here we check the
7 | % inference on the pretrained model.
8 | properties
9 |     RepoRoot = getRepoRoot;
10 |     ModelName = 'trainedSalsaNext.mat';
11 | end
12 | 
13 | methods(Test)
14 |     function exerciseDetection(test)
15 |         model = load(fullfile(test.RepoRoot,'model',test.ModelName));
16 |         ptCloud = pcread(fullfile(test.RepoRoot,'pointclouds','Input1.pcd'));
17 |         img = helper.pointCloudToImage(ptCloud);
18 | 
19 |         imSize = size(img);
20 |         imSize = imSize(:,1:2);
21 |         actualLabel1Count = 48420;
22 |         actualLabel2Count = 11990;
23 | 
24 |         result = semanticseg(img, model.net);
25 |         labelsCountTbl = countlabels(result(:));
26 |         labelCount = labelsCountTbl.Count(labelsCountTbl.Count > 0);
27 | 
28 |         % Verify the size of the output from semanticseg.
29 |         test.verifyEqual(size(result),imSize);
30 |         % Verify that all the pixels are labelled.
31 |         test.verifyEqual(sum(labelCount),prod(imSize));
32 |         % Verify the count of each label in the result.
33 |         test.verifyEqual(labelCount(1),actualLabel1Count);
34 |         test.verifyEqual(labelCount(2),actualLabel2Count);
35 |     end
36 | end
37 | end
--------------------------------------------------------------------------------
/test/tdownloadPretrainedSalsaNext.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadSalsaNextFixture}) tdownloadPretrainedSalsaNext < matlab.unittest.TestCase
2 | % Test for downloadPretrainedSalsaNext
3 | 
4 | % Copyright 2021 The MathWorks, Inc.
5 | 
6 | % The shared test fixture DownloadSalsaNextFixture calls
7 | % downloadPretrainedSalsaNext. Here we check that the downloaded file
8 | % exists in the appropriate location.
9 | 
10 | properties
11 |     DataDir = fullfile(getRepoRoot(),'model');
12 | end
13 | 
14 | methods(Test)
15 |     function verifyDownloadedFilesExist(test)
16 |         dataFileName = 'trainedSalsaNext.mat';
17 |         test.verifyTrue(isequal(exist(fullfile(test.DataDir,dataFileName),'file'),2));
18 |     end
19 | end
20 | end
--------------------------------------------------------------------------------
/test/tload.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadSalsaNextFixture}) tload < matlab.unittest.TestCase
2 | % Test for loading the downloaded models.
3 | 
4 | % Copyright 2021 The MathWorks, Inc.
5 | 
6 | % The shared test fixture DownloadSalsaNextFixture calls
7 | % downloadPretrainedSalsaNext. Here we check the properties of the
8 | % downloaded model.
9 | 
10 | properties
11 |     DataDir = fullfile(getRepoRoot(),'model');
12 | end
13 | 
14 | methods(Test)
15 |     function verifyModelAndFields(test)
16 |         % Test point to verify that the fields of the downloaded model are
17 |         % as expected.
18 | 
19 |         loadedModel = load(fullfile(test.DataDir,'trainedSalsaNext.mat'));
20 | 
21 |         test.verifyClass(loadedModel.net,'DAGNetwork');
22 |         test.verifyEqual(numel(loadedModel.net.Layers),175);
23 |         test.verifyEqual(size(loadedModel.net.Connections),[204 2]);
24 |         test.verifyEqual(loadedModel.net.InputNames,{'Input_input.1'});
25 |         test.verifyEqual(loadedModel.net.OutputNames,{'focalloss_out'});
26 |     end
27 | end
28 | end
--------------------------------------------------------------------------------
/test/tools/DownloadSalsaNextFixture.m:
--------------------------------------------------------------------------------
1 | classdef DownloadSalsaNextFixture < matlab.unittest.fixtures.Fixture
2 | % DownloadSalsaNextFixture   A fixture for calling
3 | % downloadPretrainedSalsaNext if necessary. This is to ensure that this
4 | % function is only called once and only when tests need it. It also
5 | % provides a teardown to return the test environment to the expected
6 | % state before testing.
7 | 
8 | % Copyright 2021 The MathWorks, Inc.
9 | 
10 | properties(Constant)
11 |     SalsaNextDataDir = fullfile(getRepoRoot(),'model')
12 | end
13 | 
14 | properties
15 |     SalsaNextExist (1,1) logical
16 | end
17 | 
18 | methods
19 |     function setup(this)
20 |         this.SalsaNextExist = exist(fullfile(this.SalsaNextDataDir,'trainedSalsaNext.mat'),'file')==2;
21 | 
22 |         % Call this in evalc to capture and drop any standard output
23 |         % that we don't want polluting the test logs.
24 |         if ~this.SalsaNextExist
25 |             evalc('helper.downloadPretrainedSalsaNext();');
26 |         end
27 |     end
28 | 
29 |     function teardown(this)
30 |         % Delete the model only if this fixture downloaded it, restoring
31 |         % the environment to its state before testing.
32 |         if ~this.SalsaNextExist
33 |             delete(fullfile(this.SalsaNextDataDir,'trainedSalsaNext.mat'));
34 |         end
35 |     end
36 | end
37 | end
--------------------------------------------------------------------------------
/test/tools/getRepoRoot.m:
--------------------------------------------------------------------------------
1 | function path = getRepoRoot()
2 | % getRepoRoot   Return a path to the repository's root directory.
3 | 
4 | % Copyright 2021 The MathWorks, Inc.
5 | 
6 | thisFile = mfilename('fullpath');
7 | thisDir = fileparts(thisFile);
8 | 
9 | % The root is up two directories (/test/tools/getRepoRoot.m).
10 | path = fullfile(thisDir,'..','..');
11 | end
--------------------------------------------------------------------------------