├── .circleci └── config.yml ├── .gitignore ├── LICENSE ├── README.md ├── SECURITY.md ├── images ├── SCNN_architecture.png ├── caltech_washington1.gif ├── detections.gif ├── result.jpg └── testImage.jpg ├── model └── .gitkeep ├── spatialCNNLaneDetectionExample.m ├── spatialCNNLaneDetectionVideoExample.m ├── src ├── +helper │ ├── createSCNNDetectionParameters.m │ ├── downloadSCNNLaneDetection.m │ ├── generateLines.m │ ├── plotLanes.m │ ├── plotLanesVideo.m │ └── processPredictions.m ├── +layer │ └── MessagePassingLayer.m └── detectLaneMarkings.m └── test ├── tMessagePassingLayer.m ├── tPretrainedSCNNLaneDetection.m ├── tdownloadSCNNLaneDetection.m ├── tload.m └── tools ├── DownloadSCNNLaneDetectionFixture.m └── getRepoRoot.m /.circleci/config.yml: -------------------------------------------------------------------------------- 1 | version: 2.1 2 | orbs: 3 | matlab: mathworks/matlab@0.4.0 4 | 5 | jobs: 6 | build: 7 | machine: 8 | image: ubuntu-1604:201903-01 9 | steps: 10 | - checkout 11 | - matlab/install 12 | - matlab/run-tests: 13 | test-results-junit: artifacts/test_results/matlab/results.xml 14 | # Have to add test/tools to the path for certain tests. 15 | source-folder: .;test/tools;src 16 | - store_test_results: 17 | path: artifacts/test_results 18 | - store_artifacts: 19 | path: artifacts/ 20 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | model/ 2 | -------------------------------------------------------------------------------- /LICENSE: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, The MathWorks, Inc. 2 | All rights reserved. 3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 4 | 1. 
Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 5 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 6 | 3. In all cases, the software is, and all modifications and derivatives of the software shall be, licensed to you solely for use in conjunction with MathWorks products and service offerings. 7 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 8 | -------------------------------------------------------------------------------- /README.md: -------------------------------------------------------------------------------- 1 | # Lane Detection Using Deep Learning 2 | 3 | This repository implements a pretrained Spatial-CNN (SCNN)[1] lane detection model in MATLAB®. 4 | 5 | ## Requirements 6 | - MATLAB® R2021a or later. 7 | - Deep Learning Toolbox™. 8 | - Computer Vision Toolbox™. 9 | - Automated Driving Toolbox™. 10 | 11 | ## Overview 12 | This repository implements SCNN with VGG-16 as the backbone. The pretrained network is trained to detect lanes in the image. 
The network is trained on the [CULane](https://xingangpan.github.io/projects/CULane.html) dataset [1]. 13 | 14 | Spatial-CNN (SCNN) applies slice-by-slice convolutions to the feature maps produced by conventional layer-by-layer convolutions, so that spatial information propagates across the rows and columns of each feature map. This helps detect objects that have a strong structural prior but weak appearance cues, such as lanes, poles, or occluded trucks. 15 | 16 | ## Getting Started 17 | Download or clone this repository to your machine and open it in MATLAB®. 18 | 19 | ### Setup 20 | Add the source directory to the path. 21 | ``` 22 | addpath('src'); 23 | ``` 24 | 25 | ### Download and Load the Pretrained Network 26 | Use the helper below to download and load the pretrained network. The network is downloaded and saved in the `model` directory. 27 | ``` 28 | model = helper.downloadSCNNLaneDetection; 29 | net = model.net; 30 | ``` 31 | 32 | ### Detect Lanes Using SCNN 33 | This snippet includes all the steps required to run the SCNN model on a single RGB image in MATLAB®. Use the script `spatialCNNLaneDetectionExample.m` to run inference on a single image. 34 | 35 | ``` 36 | % Specify Detection Parameters. 37 | params = helper.createSCNNDetectionParameters; 38 | 39 | % Specify the executionEnvironment as either "cpu", "gpu", or "auto". 40 | executionEnvironment = "auto"; 41 | 42 | % Read the test image. 43 | path = fullfile("images","testImage.jpg"); 44 | image = imread(path); 45 | 46 | % Use the detectLaneMarkings function to detect the lane markings. 47 | laneMarkings = detectLaneMarkings(net, image, params, executionEnvironment); 48 | 49 | % Visualize the detected lanes. 50 | fig = figure; 51 | helper.plotLanes(fig, image, laneMarkings); 52 | 53 | ``` 54 | 55 | Alternatively, you can run the SCNN model on a sample video. Use the script `spatialCNNLaneDetectionVideoExample.m` to run inference on a driving scene.
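The video example reads frames in mini-batches before passing each batch to `detectLaneMarkings`. The snippet below is a minimal sketch of the frame-range arithmetic that the local function `detectLaneMarkingVideo` performs in its loop, written in vectorized form. It assumes a hypothetical `numFrames` value standing in for `v.NumFrames`, so it runs without a video file or the pretrained network.

```matlab
% Hypothetical values: numFrames stands in for v.NumFrames of a VideoReader.
numFrames = 20;
miniBatchSize = 8;

% Number of mini-batches needed to cover every frame.
numBatches = ceil(numFrames/miniBatchSize);

% First frame of each batch: 1, 1+miniBatchSize, 1+2*miniBatchSize, ...
firstIdx = (0:numBatches-1)'*miniBatchSize + 1;

% Last frame of each batch; the final batch is clipped to numFrames.
lastIdx = min(firstIdx + miniBatchSize - 1, numFrames);

% Each row is a [first, last] pair, as passed to read(v, [first, last]).
frameRanges = [firstIdx, lastIdx];
disp(frameRanges)
```

With 20 frames and a mini-batch size of 8, the rows are `[1 8]`, `[9 16]`, and `[17 20]`, matching the `firstFrameIdx` and `lastFrameIdx` values the loop in the video script computes: only the last batch is shorter.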
56 | 57 | 58 | ### Result 59 | The left image is the input, and the right image shows the detected lanes. The image is taken from the [Panda Set](https://scale.com/open-datasets/pandaset) dataset [2]. 60 | 61 | 62 | 63 | 64 | 65 | 66 |
67 | 68 | Sample video output generated by the script `spatialCNNLaneDetectionVideoExample.m`. 69 | 70 | 71 | 72 | ## Evaluation Metrics 73 | The model is evaluated using the method specified in [1]. 74 | 75 | | Dataset | Metric | IoU Threshold | Result | 76 | | ------------- | ------------- | ------------- | ------------- | 77 | | CULane | F-measure | 0.3 | 73.45 | 78 | | CULane | F-measure | 0.5 | 43.41 | 79 | 80 | ## Spatial-CNN Algorithm Details 81 | The SCNN network architecture is illustrated in the following diagram. 82 | || 83 | |:--:| 84 | |**Fig.1**| 85 | 86 | 87 | The network takes RGB images as input and outputs a probability map and a confidence score for each lane. The pretrained SCNN model, trained on [CULane](https://xingangpan.github.io/projects/CULane.html), can detect a maximum of 4 lanes (the 2 lane markings of the driving lane and 1 lane marking on either side of it). The probability map predicted by the network has 5 channels (4 lanes + 1 background). Lanes with a confidence score of 0.5 or less are ignored. To generate the detections, the probability map is smoothed, the highest-probability position in each row is extracted, and, by default, a second-order polynomial is fit to the points of each lane. 88 | 89 | 90 | The SCNN network in this repository has 4 message passing layers in sequence, operating top-to-bottom, bottom-to-top, left-to-right, and right-to-left with a kernel size of 9; they are labeled up-down, down-up, left-right, and right-left, respectively, in Fig.1. The message passing layers are special layers that apply slice-by-slice convolutions within the feature map [1]. These layers are implemented as custom nested deep learning layers. For more information about nested deep learning layers, see [Define Nested Deep Learning Layer](https://www.mathworks.com/help/deeplearning/ug/define-nested-deep-learning-layer.html). 91 | 92 | ## References 93 | 94 | [1] Xingang Pan, Jianping Shi, Ping Luo, Xiaogang Wang, and Xiaoou Tang.
"Spatial As Deep: Spatial CNN for Traffic Scene Understanding" AAAI Conference on Artificial Intelligence (AAAI) - 2018 95 | 96 | [2] [Panda Set](https://scale.com/open-datasets/pandaset) is provided by Hesai and Scale under the [CC-BY-4.0 license](https://creativecommons.org/licenses/by/4.0) 97 | 98 | ## See also 99 | [Visual Perception Using Monocular Camera](https://www.mathworks.com/help/driving/ug/visual-perception-using-monocular-camera.html) 100 | 101 | Copyright 2021 The MathWorks, Inc. 102 | 103 | -------------------------------------------------------------------------------- /SECURITY.md: -------------------------------------------------------------------------------- 1 | # Reporting Security Vulnerabilities 2 | 3 | If you believe you have discovered a security vulnerability, please report it to 4 | [security@mathworks.com](mailto:security@mathworks.com). Please see 5 | [MathWorks Vulnerability Disclosure Policy for Security Researchers](https://www.mathworks.com/company/aboutus/policies_statements/vulnerability-disclosure-policy.html) 6 | for additional information. 
7 | -------------------------------------------------------------------------------- /images/SCNN_architecture.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/pretrained-spatial-CNN/bc159eaf3bc4f411904598652d94fc90d7eeafa7/images/SCNN_architecture.png -------------------------------------------------------------------------------- /images/caltech_washington1.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/pretrained-spatial-CNN/bc159eaf3bc4f411904598652d94fc90d7eeafa7/images/caltech_washington1.gif -------------------------------------------------------------------------------- /images/detections.gif: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/pretrained-spatial-CNN/bc159eaf3bc4f411904598652d94fc90d7eeafa7/images/detections.gif -------------------------------------------------------------------------------- /images/result.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/pretrained-spatial-CNN/bc159eaf3bc4f411904598652d94fc90d7eeafa7/images/result.jpg -------------------------------------------------------------------------------- /images/testImage.jpg: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/pretrained-spatial-CNN/bc159eaf3bc4f411904598652d94fc90d7eeafa7/images/testImage.jpg -------------------------------------------------------------------------------- /model/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /spatialCNNLaneDetectionExample.m: 
-------------------------------------------------------------------------------- 1 | %% Lane Detection Using Spatial-CNN Network 2 | % The following code demonstrates running lane detection on a pre-trained SCNN 3 | % network, trained on the CULane dataset. 4 | 5 | %% Prerequisites 6 | % To run this example, you need the following prerequisites: 7 | % * MATLAB (R2021a or later). 8 | % * Deep Learning Toolbox and Computer Vision Toolbox. 9 | % * Pretrained Spatial-CNN network (download instructions below). 10 | 11 | %% Add path to the source directory 12 | addpath('src'); 13 | 14 | %% Download Pre-trained Network 15 | model = helper.downloadSCNNLaneDetection; 16 | net = model.net; 17 | 18 | %% Specify Detection Parameters 19 | % Use the function helper.createSCNNDetectionParameters to specify the 20 | % parameters required for lane detection. 21 | params = helper.createSCNNDetectionParameters; 22 | 23 | % Specify the executionEnvironment as either "cpu", "gpu", or "auto". 24 | executionEnvironment = "auto"; 25 | 26 | %% Detect on an Image 27 | % Read the test image. 28 | path = fullfile("images","testImage.jpg"); 29 | image = imread(path); 30 | 31 | % Call detectLaneMarkings to detect the lane markings. 32 | laneMarkings = detectLaneMarkings(net, image, params, executionEnvironment); 33 | 34 | % Visualize the detected lanes. 35 | fig = figure; 36 | helper.plotLanes(fig, image, laneMarkings); 37 | 38 | % Copyright 2021 The MathWorks, Inc. 39 | -------------------------------------------------------------------------------- /spatialCNNLaneDetectionVideoExample.m: -------------------------------------------------------------------------------- 1 | %% Lane Detection Using Spatial-CNN Network in Driving Scene 2 | % The following code demonstrates running lane detection on a pre-trained 3 | % SCNN network on a driving scene. 4 | 5 | %% Prerequisites 6 | % To run this example, you need the following prerequisites: 7 | % * MATLAB (R2021a or later). 8 | % * Deep Learning Toolbox and Computer Vision Toolbox.
9 | % * Automated Driving Toolbox (required to load the video). 10 | % * Pretrained Spatial-CNN network (download instructions below). 11 | 12 | %% Add path to the source directory 13 | addpath('src'); 14 | 15 | %% Download Pre-trained Network 16 | model = helper.downloadSCNNLaneDetection; 17 | net = model.net; 18 | 19 | %% Specify Detection Parameters 20 | % Use the function helper.createSCNNDetectionParameters to specify the 21 | % parameters required for lane detection. 22 | params = helper.createSCNNDetectionParameters; 23 | 24 | % Specify the mini-batch size as 8. Increase this value to speed up 25 | % detection if memory allows. 26 | miniBatchSize = 8; 27 | 28 | % Specify the executionEnvironment as either "cpu", "gpu", or "auto". 29 | executionEnvironment = "auto"; 30 | 31 | %% Detect in Video 32 | % Read the video. 33 | v = VideoReader('caltech_washington1.avi'); 34 | 35 | % Store the video start time. 36 | videoStartTime = v.CurrentTime; 37 | 38 | % Detect using the detectLaneMarkingVideo function, provided as a local 39 | % function below. 40 | laneMarkings = detectLaneMarkingVideo(net, v, params, miniBatchSize, executionEnvironment); 41 | 42 | % Plot the detections in the video and save the result. 43 | helper.plotLanesVideo(v, laneMarkings, videoStartTime); 44 | 45 | 46 | %% Helper function to Detect in Video 47 | function laneMarkings = detectLaneMarkingVideo(net, v, params, miniBatchSize, executionEnvironment) 48 | % Detect lane markings in a video by reading frames in batches. 49 | laneMarkings = {}; 50 | numBatches = ceil(v.NumFrames/miniBatchSize); 51 | for batch = 1:numBatches 52 | firstFrameIdx = miniBatchSize*batch-(miniBatchSize-1); 53 | if batch ~= numBatches 54 | lastFrameIdx = miniBatchSize*batch; 55 | else 56 | lastFrameIdx = v.NumFrames; 57 | end 58 | % Read a batch of frames. 59 | frames = read(v, [firstFrameIdx, lastFrameIdx]); 60 | 61 | % Detect lanes using the function detectLaneMarkings.
62 | detections = detectLaneMarkings(net, frames, params, executionEnvironment); 63 | 64 | % Append the detections. 65 | laneMarkings = [laneMarkings; detections]; 66 | 67 | % Print detection progress. 68 | fprintf("Detected %d frames out of %d frames.\n",lastFrameIdx,v.NumFrames); 69 | end 70 | end 71 | 72 | % Copyright 2021 The MathWorks, Inc. 73 | -------------------------------------------------------------------------------- /src/+helper/createSCNNDetectionParameters.m: -------------------------------------------------------------------------------- 1 | function params = createSCNNDetectionParameters() 2 | % createSCNNDetectionParameters creates the parameters required for detection. 3 | 4 | % Copyright 2021 The MathWorks, Inc. 5 | 6 | % Specify these parameters for detection: 7 | % * Set the threshold to 0.5. Detections with a confidence score less than 8 | % or equal to the threshold are ignored. 9 | % * Set the networkInputSize to [288,800]. 10 | params.threshold = 0.5; 11 | params.networkInputSize = [288,800]; 12 | 13 | end 14 | -------------------------------------------------------------------------------- /src/+helper/downloadSCNNLaneDetection.m: -------------------------------------------------------------------------------- 1 | function model = downloadSCNNLaneDetection() 2 | % The downloadSCNNLaneDetection function downloads (if necessary) and loads 3 | % a pretrained SCNN network. 4 | % 5 | % Copyright 2021 The MathWorks, Inc. 6 | 7 | dataPath = 'model'; 8 | modelName = 'scnn-culane'; 9 | netFileFullPath = fullfile(dataPath, modelName); 10 | 11 | % Add the extensions to the filename.
12 | netMatFileFull = [netFileFullPath,'.mat']; 13 | netZipFileFull = [netFileFullPath,'.zip']; 14 | 15 | if ~exist(netZipFileFull,'file') 16 | fprintf(['Downloading pretrained ', modelName ,' network.\n']); 17 | fprintf('This can take several minutes to download...\n'); 18 | url = 'https://ssd.mathworks.com/supportfiles/vision/deeplearning/models/SpatialCNN/SCNN.zip'; 19 | websave(netFileFullPath,url); 20 | unzip(netZipFileFull, dataPath); 21 | path = fullfile(netMatFileFull); 22 | model = load(path); 23 | else 24 | if ~exist(netMatFileFull,'file') 25 | fprintf('Pretrained SCNN-CULane network already exists.\n\n'); 26 | unzip(netZipFileFull, dataPath); 27 | end 28 | path = fullfile(netMatFileFull); 29 | model = load(path); 30 | end 31 | end 32 | -------------------------------------------------------------------------------- /src/+helper/generateLines.m: -------------------------------------------------------------------------------- 1 | function coordinates = generateLines(pred, conf, imsize, params, fitPolyLine) 2 | % The generateLines function extracts the lane pixel coordinates from the 3 | % predicted probability map. 4 | 5 | % Copyright 2021 The MathWorks, Inc. 6 | 7 | % Number of lanes to detect. 8 | numLanes = size(conf,1); 9 | 10 | % Filter size for Gaussian blur. 11 | gaussFilterSize = [9,9]; 12 | 13 | % Batch size of input. 14 | batchSize = size(pred,4); 15 | 16 | % Create a cell array to store the lane coordinates. 17 | coordinates = cell(batchSize,numLanes); 18 | 19 | % Threshold value below which the detections are ignored. 20 | threshold = params.threshold; 21 | 22 | % Suppress warnings generated by fitPolynomialRANSAC. 23 | warning('off','vision:ransac:maxTrialsReached'); 24 | 25 | % Process each image in the batch. 26 | for batch = 1:batchSize 27 | for lane = 1:numLanes 28 | probMap = pred(:,:,lane+1,batch); 29 | % Smooth the probability map.
30 | probMap = imgaussfilt(probMap,'FilterSize',gaussFilterSize,'Padding','replicate'); 31 | if conf(lane,batch)>0 32 | % Extract the lane points from the probability map. 33 | coordXY = lanePoints(probMap, threshold, imsize); 34 | % Remove rows where both coordinates are zero (no detection). 35 | mask = (coordXY(:,1) == 0) & (coordXY(:,2) == 0); 36 | coordXY(mask,:) = []; 37 | if ~isempty(coordXY) 38 | % Sample the points with a step of 1 (keep every point), making 39 | % sure the last point is included. 40 | last = coordXY(end,:); 41 | coordXY = coordXY(1:1:end,:); 42 | if ~all(coordXY(end,:) == last) 43 | coordXY(end+1,:) = last; 44 | end 45 | if ~fitPolyLine 46 | coordinates{batch,lane} = coordXY(:,1:2); 47 | else 48 | % Fit the equation x = a*y^2 + b*y + c. 49 | y = [coordXY(1,2):-1:coordXY(end,2)]'; 50 | % Second-degree polynomial. 51 | n = 2; 52 | % Maximum allowed distance for a point to be an inlier. 53 | maxDistance = 3; 54 | p = fitPolynomialRANSAC([coordXY(:,2),coordXY(:,1)],n,maxDistance); 55 | x = round(polyval(p,y)); 56 | coordXY = [x,y]; 57 | ids = x>=1 & x<=imsize(2); 58 | coordXY = coordXY(ids,:); 59 | coordinates{batch,lane} = coordXY; 60 | end 61 | end 62 | end 63 | end 64 | end 65 | end 66 | 67 | function coordinates = lanePoints(probMap, thresh, imsize) 68 | % The function lanePoints returns, for each row, the image (x,y) coordinates 69 | % of the maximum probability value if that value exceeds the threshold. 70 | probshape = size(probMap); 71 | coordinates = zeros(probshape(1),2); 72 | for i=1:probshape(1) 73 | % Extract the indexes of the max value. 74 | rowID = probshape(1)-i+1; 75 | cols = probMap(rowID,:); 76 | [value, id] = max(cols); 77 | if value > thresh 78 | coordinates(i,1) = floor(id/probshape(2)*imsize(2)); 79 | coordinates(i,2) = floor(rowID/probshape(1)*imsize(1)); 80 | end 81 | end 82 | % Ignore the lane if fewer than 2 points are found.
83 | if sum(coordinates>0)<2 84 | coordinates = zeros(probshape(1),2); 85 | end 86 | end 87 | -------------------------------------------------------------------------------- /src/+helper/plotLanes.m: -------------------------------------------------------------------------------- 1 | function plotLanes(f, image, detections) 2 | % The function plotLanes plots the lane coordinates as lines on the 3 | % image. The maximum number of lanes that can be marked is 4. 4 | 5 | % Copyright 2021 The MathWorks, Inc. 6 | 7 | figure(f); 8 | clf; 9 | 10 | % Show the image. 11 | imshow(image); 12 | 13 | % Initialize lane colours. 14 | colors = ["red", "green", "blue", "yellow"]; 15 | % Add lane coordinates to the image. 16 | hold on; 17 | for i=1:size(detections,2) 18 | if ~isempty(detections{1,i}) 19 | plot(detections{1,i}(:,1),detections{1,i}(:,2),'LineWidth',10,'Color',colors(i)); 20 | end 21 | end 22 | hold off; 23 | end 24 | -------------------------------------------------------------------------------- /src/+helper/plotLanesVideo.m: -------------------------------------------------------------------------------- 1 | function plotLanesVideo(video, detections, startTime) 2 | % The function plotLanesVideo plots the lane coordinates on the video 3 | % frames and saves the output video. The maximum number of lanes that can be 4 | % marked is 4. 5 | 6 | % Copyright 2021 The MathWorks, Inc. 7 | 8 | % Reset the current time to the start time. 9 | video.CurrentTime = startTime; 10 | 11 | frameIdx = 1; 12 | [~,videoName,~] = fileparts(video.Name); 13 | filename = strcat(videoName,"_detected"); 14 | savepath = fullfile(pwd, filename); 15 | 16 | % Initialize lane colours. 17 | colors = ["red", "green", "blue", "yellow"]; 18 | 19 | % Create the output video object.
20 | outVideo = VideoWriter(savepath); 21 | outVideo.FrameRate = video.FrameRate; 22 | open(outVideo); 23 | 24 | while video.hasFrame 25 | frame = readFrame(video); 26 | for i=1:size(detections,2) 27 | if ~isempty(detections{frameIdx,i}) 28 | frame = insertMarker(frame, detections{frameIdx,i}(:,1:2),'o','Color',colors(i),'Size',5); 29 | end 30 | end 31 | writeVideo(outVideo,frame); 32 | frameIdx = frameIdx+1; 33 | end 34 | close(outVideo); 35 | end 36 | -------------------------------------------------------------------------------- /src/+helper/processPredictions.m: -------------------------------------------------------------------------------- 1 | function [probMap, confidence] = processPredictions(predictions, confidence, threshold) 2 | % The function processPredictions applies post-processing to the 3 | % predictions output by the SCNN network. 4 | % The function returns the probability map and the confidence score of the 5 | % detections. 6 | % 7 | % Copyright 2021 The MathWorks, Inc. 8 | 9 | if ~strcmp(dims(confidence),'CB') % Swap the outputs if they arrive in the opposite order. 10 | tmp = predictions; 11 | predictions = confidence; 12 | confidence = tmp; 13 | end 14 | 15 | % Apply a softmax to the predictions. 16 | probMap = softmax(predictions); 17 | 18 | % Extract and gather the predictions. 19 | probMap = gather(extractdata(probMap)); 20 | confidence = gather(extractdata(confidence)); 21 | 22 | % Binarize the confidence scores: 1 if greater than the threshold, else 0. 23 | confidence(confidence>threshold) = 1; 24 | confidence(confidence<=threshold) = 0; 25 | end 26 | -------------------------------------------------------------------------------- /src/+layer/MessagePassingLayer.m: -------------------------------------------------------------------------------- 1 | classdef MessagePassingLayer < nnet.layer.Layer 2 | % MessagePassingLayer Message Passing layer that applies 3 | % slice-by-slice convolution and relu within the 4 | % feature map.
5 | % 6 | % To create a MessagePassingLayer, use: 7 | % layer = MessagePassingLayer(filterSize, numFilters, direction, ... 8 | % 'Name', 'msgpassinglayer'); 9 | % 10 | % Inputs: 11 | % ------- 12 | % filterSize - Specify a positive integer for the filter size of the 13 | % slice-by-slice convolution operation. 14 | % 15 | % numFilters - Specify a positive integer for the number of filters 16 | % in the convolution layer. 17 | % 18 | % direction - Specify a string or character array for the direction 19 | % of the slice-by-slice convolution operation. The valid 20 | % values are 'topDown', 'bottomUp', 'leftRight', and 21 | % 'rightLeft'. 22 | % 23 | % layer = MessagePassingLayer(__, 'PARAM', VAL) specifies optional 24 | % parameter name/value pairs for creating the layer: 25 | % 26 | % 'Name' - A string or character array that specifies the name 27 | % for the layer. 28 | % 29 | % Default : '' 30 | % 31 | % Example: 32 | % -------- 33 | % Create a top-down message passing layer with a filter size of 9 and 34 | % 128 filters. 35 | % l = layer.MessagePassingLayer(9, 128, 'topDown', 'Name', 'top-down'); 36 | 37 | % Copyright 2021 The MathWorks, Inc. 38 | 39 | properties 40 | % Number of filters in message passing. 41 | NumFilters 42 | 43 | % Filter size of message passing block. 44 | FilterSize 45 | 46 | % Direction of message passing. 47 | Direction 48 | end 49 | 50 | properties (Learnable) 51 | % Layer learnable parameters 52 | 53 | MessagePassingBlock 54 | end 55 | 56 | methods 57 | function layer = MessagePassingLayer(filterSize,numFilters,direction,NameValueArgs) 58 | % Creates a conv+relu block with the specified number of 59 | % filters, filter size, and direction. 60 | 61 | % Parse input arguments. 62 | arguments 63 | filterSize {mustBeInteger,mustBePositive} 64 | numFilters {mustBeInteger,mustBePositive} 65 | direction {mustBeTextScalar} 66 | NameValueArgs.Name = '' 67 | end 68 | 69 | name = NameValueArgs.Name; 70 | 71 | % Set layer name.
72 | layer.Name = name; 73 | layer.NumFilters = numFilters; 74 | layer.FilterSize = filterSize; 75 | layer.Direction = iCheckAndReturnValidDirection(direction); 76 | 77 | 78 | % Set layer description. 79 | description = "Message passing " + layer.Direction + " block with num filters "+ layer.NumFilters+" and filter size "+layer.FilterSize; 80 | layer.Description = description; 81 | 82 | % Set layer type. 83 | layer.Type = "Message passing " + layer.Direction + " block"; 84 | 85 | % Define nested layer graph. 86 | if strcmp(layer.Direction,'topDown') || strcmp(layer.Direction,'bottomUp') 87 | layers = [ 88 | convolution2dLayer([1,filterSize],numFilters,'Padding',"same",'Name','conv1',"BiasLearnRateFactor",0) 89 | reluLayer('Name','relu1') 90 | ]; 91 | else 92 | layers = [ 93 | convolution2dLayer([filterSize,1],numFilters,'Padding',"same",'Name','conv1',"BiasLearnRateFactor",0) 94 | reluLayer('Name','relu1') 95 | ]; 96 | end 97 | 98 | lgraph = layerGraph(layers); 99 | 100 | % Convert to dlnetwork. 101 | dlnet = dlnetwork(lgraph,'Initialize',false); 102 | 103 | % Set Network property. 104 | layer.MessagePassingBlock = dlnet; 105 | end 106 | 107 | function Z = predict(layer, X) 108 | % Forward input data through the layer at prediction time and 109 | % output the result. 110 | % 111 | % Inputs: 112 | % layer - Layer to forward propagate through 113 | % X - Input data 114 | % Outputs: 115 | % Z - Output of layer forward function 116 | 117 | % Convert input data to formatted dlarray. 118 | X = dlarray(X,'SSCB'); 119 | dlnet = layer.MessagePassingBlock; 120 | fh = str2func(layer.Direction); 121 | 122 | % Process message passing block. 
123 | Z = fh(dlnet,X); 124 | Z = stripdims(Z); 125 | end 126 | end 127 | end 128 | 129 | function value = iCheckAndReturnValidDirection(value) 130 | validateattributes(value, {'char','string'},{},'','Direction'); 131 | value = validatestring(value, {'topDown', 'bottomUp', 'leftRight', 'rightLeft'},'','Direction'); 132 | end 133 | 134 | function block = topDown(net,block) 135 | % Slice along the height. 136 | sliceDim = 1; 137 | for i = 2:size(block,sliceDim) 138 | z = predict(net,block(i-1,:,:,:)); 139 | block(i,:,:,:) = block(i,:,:,:) + z; 140 | end 141 | end 142 | 143 | function block = bottomUp(net,block) 144 | % Slice along the height. 145 | sliceDim = 1; 146 | for i = size(block,sliceDim)-1:-1:1 147 | z = predict(net,block(i+1,:,:,:)); 148 | block(i,:,:,:) = block(i,:,:,:) + z; 149 | end 150 | end 151 | 152 | function block = leftRight(net,block) 153 | % Slice along the width. 154 | sliceDim = 2; 155 | for i = 2:size(block,sliceDim) 156 | z = predict(net,block(:,i-1,:,:)); 157 | block(:,i,:,:) = block(:,i,:,:) + z; 158 | end 159 | end 160 | 161 | function block = rightLeft(net,block) 162 | % Slice along the width. 163 | sliceDim = 2; 164 | for i = size(block,sliceDim)-1:-1:1 165 | z = predict(net,block(:,i+1,:,:)); 166 | block(:,i,:,:) = block(:,i,:,:) + z; 167 | end 168 | end 169 | -------------------------------------------------------------------------------- /src/detectLaneMarkings.m: -------------------------------------------------------------------------------- 1 | function detections = detectLaneMarkings(net, data, params, executionEnvironment, NameValueArgs) 2 | % detections = detectLaneMarkings(net, image, params, executionEnvironment) 3 | % runs prediction on a pre-trained SCNN network. 4 | % 5 | % Inputs: 6 | % ------- 7 | % net - Pretrained SCNN dlnetwork. 
8 | % data - Input data must be a single RGB image of size 9 | % HxWx3 or an array of RGB images of size HxWx3xB, 10 | % where H is the height, W is the width, and B is the 11 | % number of images. 12 | % params - Parameters required to run inference on SCNN, 13 | % created using 14 | % helper.createSCNNDetectionParameters. 15 | % executionEnvironment - Environment to run predictions on. Specify cpu, 16 | % gpu, or auto. 17 | % 18 | % detections = detectLaneMarkings(..., Name, Value) specifies the optional 19 | % name-value argument described below. 20 | % 21 | % 'FitPolyLine' - Specify true or false. If true, 22 | % second-order polynomials are fit to the detections 23 | % to smooth them; otherwise, raw detections are 24 | % returned. 25 | % 26 | % Default: true 27 | % 28 | % 29 | % Output: 30 | % ------- 31 | % detections - Returns an M-by-N cell array, where M is the 32 | % number of images in a batch and N is the maximum 33 | % number of lanes. Cells of undetected lanes are empty. 34 | % Each cell contains a P-by-2 array of pixel coordinates 35 | % of the detected lane markings, where P is the number of 36 | % points and the 2 columns are the X and Y values of the 37 | % respective pixel coordinates. 38 | 39 | % Copyright 2021 The MathWorks, Inc. 40 | 41 | % Parse input arguments. 42 | arguments 43 | net 44 | data 45 | params 46 | executionEnvironment 47 | NameValueArgs.FitPolyLine = true; 48 | end 49 | fitPolyLine = NameValueArgs.FitPolyLine; 50 | 51 | % Get the input image size. 52 | inputImageSize = size(data); 53 | 54 | % Resize the image to the input size of the network. 55 | resizedImage = imresize(data, params.networkInputSize); 56 | 57 | % Rescale the pixel values to the range [0,1]. 58 | resizedImage = im2single(resizedImage); 59 | 60 | % Convert the resized image to a dlarray, on the GPU if one is available and executionEnvironment is not "cpu".
61 | if canUseGPU && ~(strcmp(executionEnvironment,"cpu")) 62 | resizedImage = gpuArray(dlarray(resizedImage,'SSCB')); 63 | else 64 | resizedImage = dlarray(resizedImage,'SSCB'); 65 | end 66 | 67 | % Predict the output. 68 | [laneMask, confidence] = predict(net, resizedImage); 69 | 70 | % Process the predictions to output probability map and confidence scores. 71 | [laneMask, confidence] = helper.processPredictions(laneMask, confidence, params.threshold); 72 | 73 | % Extract lane marking coordinates from the probability map and confidence 74 | % scores. 75 | detections = helper.generateLines(laneMask, confidence, inputImageSize, params, fitPolyLine); 76 | 77 | end 78 | -------------------------------------------------------------------------------- /test/tMessagePassingLayer.m: -------------------------------------------------------------------------------- 1 | classdef tMessagePassingLayer < matlab.unittest.TestCase 2 | % tMessagePassingLayer. Test for MessagePassingLayer 3 | properties 4 | MessagePassingLayer = @layer.MessagePassingLayer; 5 | NumFilters= 128; 6 | FilterSize = 9; 7 | RS 8 | end 9 | properties(TestParameter) 10 | Direction = {'topDown', 'bottomUp', 'leftRight', 'rightLeft'}; 11 | ExpectedOut = iGetExpectedOutputs(); 12 | end 13 | methods(TestMethodSetup) 14 | function setupRNGSeed(test) 15 | test.RS = rng(0); 16 | end 17 | end 18 | 19 | methods(TestMethodTeardown) 20 | function resetRNGSeed(test) 21 | rng(test.RS); 22 | end 23 | end 24 | methods(Test,ParameterCombination = 'sequential') 25 | function testConstruction(test,Direction) 26 | layer = test.MessagePassingLayer(test.FilterSize,test.NumFilters,Direction,'Name',strcat('message_passing_',Direction)); 27 | 28 | % Verifying the layer properties. 
29 |             test.verifyEqual(layer.Name,strcat('message_passing_',Direction));
30 |             test.verifyEqual(layer.NumFilters,test.NumFilters);
31 |             test.verifyEqual(layer.FilterSize,test.FilterSize);
32 |             test.verifyEqual(layer.Type,strcat("Message passing ", Direction, " block"));
33 |             test.verifyClass(layer.MessagePassingBlock,'dlnetwork');
34 |             test.verifyEqual(numel(layer.MessagePassingBlock.Layers),2);
35 |             test.verifyFalse(layer.MessagePassingBlock.Initialized);
36 |             % Verifying the properties of layers inside
37 |             % MessagePassingBlock.
38 |             test.verifyClass(layer.MessagePassingBlock.Layers(1),'nnet.cnn.layer.Convolution2DLayer');
39 |             test.verifyEqual(layer.MessagePassingBlock.Layers(1).Name,'conv1');
40 |
41 |             if ismember(Direction,{'topDown', 'bottomUp'})
42 |                 test.verifyEqual(layer.MessagePassingBlock.Layers(1).FilterSize,[1 test.FilterSize]);
43 |             else
44 |                 test.verifyEqual(layer.MessagePassingBlock.Layers(1).FilterSize,[test.FilterSize 1]);
45 |             end
46 |
47 |             test.verifyEqual(layer.MessagePassingBlock.Layers(1).NumFilters, test.NumFilters);
48 |             test.verifyEqual(layer.MessagePassingBlock.Layers(1).PaddingMode,'same');
49 |             test.verifyClass(layer.MessagePassingBlock.Layers(2),'nnet.cnn.layer.ReLULayer');
50 |             test.verifyEqual(layer.MessagePassingBlock.Layers(2).Name,'relu1');
51 |         end
52 |         function testPredict(test,Direction,ExpectedOut)
53 |             numFilters = 2;
54 |             layer = test.MessagePassingLayer(test.FilterSize,numFilters,Direction);
55 |             % Construct an initialized dlnetwork from the layers of
56 |             % MessagePassingBlock so that predict can be exercised.
57 |             dlnet = dlnetwork([imageInputLayer([3 3 3],'Normalization','none'); layer.MessagePassingBlock.Layers]);
58 |             dlX = dlarray(ones(2,2,3),'SSC');
59 |             out = dlnet.predict(dlX);
60 |
61 |             test.verifyEqual(size(out),[2 2 numFilters]);
62 |             test.verifyEqual(double(extractdata(out)),ExpectedOut,'AbsTol',1e-4);
63 |         end
64 |     end
65 | end
66 | function out = iGetExpectedOutputs()
67 |     td(:,:,1) = [0.5833 0.6376; 0.5833 0.6376];
68 |     td(:,:,2) = [0 0; 0 0];
69 |     bu(:,:,1) = [0.5833 0.6376; 0.5833 0.6376];
70 |     bu(:,:,2) = [0 0; 0 0];
71 |     lr(:,:,1) = [0.5833 0.5833; 0.6376 0.6376];
72 |     lr(:,:,2) = [0 0; 0 0];
73 |     rl(:,:,1) = [0.5833 0.5833; 0.6376 0.6376];
74 |     rl(:,:,2) = [0 0; 0 0];
75 |
76 |     out = struct('topDownExp',td, 'bottomUpExp',bu, 'leftRightExp',lr, 'rightLeftExp',rl);
77 | end
78 |
--------------------------------------------------------------------------------
/test/tPretrainedSCNNLaneDetection.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadSCNNLaneDetectionFixture}) tPretrainedSCNNLaneDetection < matlab.unittest.TestCase
2 |     % Tests inference with the pretrained SCNN lane detection model.
3 |
4 |     % Copyright 2021 The MathWorks, Inc.
5 |
6 |     % The shared test fixture downloads the model. Here we check
7 |     % inference with the pretrained model.
8 |     properties
9 |         RepoRoot = getRepoRoot;
10 |         ModelName = 'scnn-culane.mat';
11 |     end
12 |
13 |     methods(Test)
14 |         function exerciseDetection(test)
15 |             model = load(fullfile(test.RepoRoot,'model',test.ModelName));
16 |             img = imread(fullfile(test.RepoRoot,"images","testImage.jpg"));
17 |             executionEnvironment = "auto";
18 |             params = helper.createSCNNDetectionParameters;
19 |             laneMarkings = detectLaneMarkings(model.net, img, params, executionEnvironment);
20 |
21 |             % Verify the output class.
22 |             test.verifyClass(laneMarkings,'cell');
23 |
24 |             % Verify the size of each output from detectLaneMarkings.
25 |             test.verifyEqual(size(laneMarkings{1}),[322 2]);
26 |             test.verifyEqual(size(laneMarkings{2}),[560 2]);
27 |             test.verifyEqual(size(laneMarkings{3}),[556 2]);
28 |             test.verifyEqual(size(laneMarkings{4}),[0 0]);
29 |         end
30 |     end
31 | end
32 |
--------------------------------------------------------------------------------
/test/tdownloadSCNNLaneDetection.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadSCNNLaneDetectionFixture}) tdownloadSCNNLaneDetection < matlab.unittest.TestCase
2 |     % Test for downloadSCNNLaneDetection
3 |
4 |     % Copyright 2021 The MathWorks, Inc.
5 |
6 |     % The shared test fixture DownloadSCNNLaneDetectionFixture calls
7 |     % downloadSCNNLaneDetection. Here we check that the downloaded files
8 |     % exist in the appropriate location.
9 |
10 |     properties
11 |         DataDir = fullfile(getRepoRoot(),'model');
12 |     end
13 |
14 |     methods(Test)
15 |         function verifyDownloadedFilesExist(test)
16 |             dataFileName = 'scnn-culane.mat';
17 |             test.verifyEqual(exist(fullfile(test.DataDir,dataFileName),'file'),2);
18 |         end
19 |     end
20 | end
21 |
--------------------------------------------------------------------------------
/test/tload.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadSCNNLaneDetectionFixture}) tload < matlab.unittest.TestCase
2 |     % Test for loading the downloaded model.
3 |
4 |     % Copyright 2021 The MathWorks, Inc.
5 |
6 |     % The shared test fixture DownloadSCNNLaneDetectionFixture calls
7 |     % downloadSCNNLaneDetection. Here we check the properties of the
8 |     % downloaded model.
9 |
10 |     properties
11 |         DataDir = fullfile(getRepoRoot(),'model');
12 |     end
13 |
14 |     methods(Test)
15 |         function verifyModelAndFields(test)
16 |             % Test point to verify that the fields of the downloaded model
17 |             % are as expected.
18 |
19 |             loadedModel = load(fullfile(test.DataDir,'scnn-culane.mat'));
20 |
21 |             test.verifyClass(loadedModel.net,'dlnetwork');
22 |             test.verifyEqual(numel(loadedModel.net.Layers),62);
23 |             test.verifyEqual(size(loadedModel.net.Connections),[61 2]);
24 |             test.verifyEqual(loadedModel.net.InputNames,{'input'});
25 |             test.verifyEqual(loadedModel.net.OutputNames,{'layer2_2','fc_3'});
26 |         end
27 |     end
28 | end
--------------------------------------------------------------------------------
/test/tools/DownloadSCNNLaneDetectionFixture.m:
--------------------------------------------------------------------------------
1 | classdef DownloadSCNNLaneDetectionFixture < matlab.unittest.fixtures.Fixture
2 |     % DownloadSCNNLaneDetectionFixture A fixture for calling
3 |     % helper.downloadSCNNLaneDetection if necessary. This ensures that
4 |     % the function is called only once, and only when tests need it. It
5 |     % also provides a teardown that returns the test environment to its
6 |     % state before testing.
7 |
8 |     % Copyright 2021 The MathWorks, Inc.
9 |
10 |     properties(Constant)
11 |         SCNNDataDir = fullfile(getRepoRoot(),'model')
12 |     end
13 |
14 |     properties
15 |         SCNNExist (1,1) logical
16 |     end
17 |
18 |     methods
19 |         function setup(this)
20 |             import matlab.unittest.fixtures.CurrentFolderFixture;
21 |             this.applyFixture(CurrentFolderFixture ...
22 |                 (getRepoRoot()));
23 |
24 |             this.SCNNExist = exist(fullfile(this.SCNNDataDir,'scnn-culane.mat'),'file')==2;
25 |
26 |             % Call the download through evalc to capture and discard any
27 |             % standard output that would otherwise pollute the test logs.
28 |             if ~this.SCNNExist
29 |                 evalc('helper.downloadSCNNLaneDetection();');
30 |             end
31 |         end
32 |
33 |         function teardown(this)
34 |             if ~this.SCNNExist % Delete the model only if this fixture downloaded it.
35 |                 delete(fullfile(this.SCNNDataDir,'scnn-culane.mat'));
36 |             end
37 |         end
38 |     end
39 | end
40 |
--------------------------------------------------------------------------------
/test/tools/getRepoRoot.m:
--------------------------------------------------------------------------------
1 | function rootPath = getRepoRoot()
2 | % getRepoRoot Return a path to the repository's root directory.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | thisFile = mfilename('fullpath');
7 | thisDir = fileparts(thisFile);
8 |
9 | % The root is two directories up (/test/tools/getRepoRoot.m).
10 | rootPath = fullfile(thisDir,'..','..');
11 | end
12 |
--------------------------------------------------------------------------------
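The 'FitPolyLine' option of detectLaneMarkings is documented as fitting second-order polynomials to the raw detections to smooth them. The MATLAB sketch below illustrates that idea on a single synthetic P-by-2 lane using polyfit and polyval. It is not the repository's helper.generateLines implementation: the choice to model x as a function of y, the synthetic coordinates, and the variable names (rawLane, smoothedLane, numSamples) are assumptions made only for this illustration.

```matlab
% Hypothetical illustration of second-order polyline fitting for one lane;
% this is NOT taken from helper.generateLines.

% Synthetic raw detection: a P-by-2 array of [x y] pixel coordinates.
% Rounding to whole pixels stands in for detection noise.
y = (10:10:200).';
xTrue = 0.002*y.^2 - 0.5*y + 300;
rawLane = [round(xTrue) y];

% Fit x as a quadratic function of y (assumption: lane markings run close
% to vertically in the image, so x = f(y) is single-valued).
coeffs = polyfit(rawLane(:,2), rawLane(:,1), 2);

% Resample the fitted curve at evenly spaced rows between the first and
% last detected points.
numSamples = 50;
ySamples = linspace(min(rawLane(:,2)), max(rawLane(:,2)), numSamples).';
smoothedLane = [polyval(coeffs, ySamples) ySamples];
```

Because each lane is fit independently, the same steps could be applied to every cell of the detections output. A quadratic fit needs at least three points, so empty cells (such as laneMarkings{4} in tPretrainedSCNNLaneDetection) would have to be skipped.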