├── .circleci
│   └── config.yml
├── .gitignore
├── LICENSE
├── README.md
├── SECURITY.md
├── images
│   ├── SCNN_architecture.png
│   ├── caltech_washington1.gif
│   ├── detections.gif
│   ├── result.jpg
│   └── testImage.jpg
├── model
│   └── .gitkeep
├── spatialCNNLaneDetectionExample.m
├── spatialCNNLaneDetectionVideoExample.m
├── src
│   ├── +helper
│   │   ├── createSCNNDetectionParameters.m
│   │   ├── downloadSCNNLaneDetection.m
│   │   ├── generateLines.m
│   │   ├── plotLanes.m
│   │   ├── plotLanesVideo.m
│   │   └── processPredictions.m
│   ├── +layer
│   │   └── MessagePassingLayer.m
│   └── detectLaneMarkings.m
└── test
    ├── tMessagePassingLayer.m
    ├── tPretrainedSCNNLaneDetection.m
    ├── tdownloadSCNNLaneDetection.m
    ├── tload.m
    └── tools
        ├── DownloadSCNNLaneDetectionFixture.m
        └── getRepoRoot.m
/.circleci/config.yml:
--------------------------------------------------------------------------------
1 | version: 2.1
2 | orbs:
3 |   matlab: mathworks/matlab@0.4.0
4 |
5 | jobs:
6 |   build:
7 |     machine:
8 |       image: ubuntu-1604:201903-01
9 |     steps:
10 |       - checkout
11 |       - matlab/install
12 |       - matlab/run-tests:
13 |           test-results-junit: artifacts/test_results/matlab/results.xml
14 |           # Have to add test/tools to the path for certain tests.
15 |           source-folder: .;test/tools;src
16 |       - store_test_results:
17 |           path: artifacts/test_results
18 |       - store_artifacts:
19 |           path: artifacts/
20 |
--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
1 | model/
2 |
--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
1 | Copyright (c) 2021, The MathWorks, Inc.
2 | All rights reserved.
3 | Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4 | 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
5 | 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
6 | 3. In all cases, the software is, and all modifications and derivatives of the software shall be, licensed to you solely for use in conjunction with MathWorks products and service offerings.
7 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
8 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Lane Detection Using Deep Learning
2 |
3 | This repository implements a pretrained Spatial-CNN (SCNN)[1] lane detection model in MATLAB®.
4 |
5 | ## Requirements
6 | - MATLAB® R2021a or later.
7 | - Deep Learning Toolbox™.
8 | - Computer Vision Toolbox™.
9 | - Automated Driving Toolbox™.
10 |
11 | ## Overview
12 | This repository implements SCNN with VGG-16 as the backbone. The pretrained network detects lanes in an image and was trained on the [CULane](https://xingangpan.github.io/projects/CULane.html) dataset[1].
13 |
14 | Spatial-CNN (SCNN) applies slice-by-slice convolutions within the feature maps produced by the usual layer-by-layer convolutions, so that spatial information is propagated across the rows and columns of a feature map. This helps in detecting objects with a strong structural prior but few appearance cues, such as lanes, poles, or trucks with occlusions.
15 |
16 | ## Getting Started
17 | Download or clone this repository to your machine and open it in MATLAB®.
18 |
19 | ### Setup
20 | Add path to the source directory.
21 | ```
22 | addpath('src');
23 | ```
24 |
25 | ### Download and Load the Pretrained Network
26 | Use the helper function below to download and load the pretrained network. The network is downloaded and saved in the `model` directory.
27 | ```
28 | model = helper.downloadSCNNLaneDetection;
29 | net = model.net;
30 | ```
31 |
32 | ### Detect Lanes Using SCNN
33 | This snippet includes all the steps required to run the SCNN model on a single RGB image in MATLAB®. Use the script `spatialCNNLaneDetectionExample.m` to run inference on a single image.
34 |
35 | ```
36 | % Specify Detection Parameters.
37 | params = helper.createSCNNDetectionParameters;
38 |
39 | % Specify the executionEnvironment as either "cpu", "gpu", or "auto".
40 | executionEnvironment = "auto";
41 |
42 | % Read the test image.
43 | path = fullfile("images","testImage.jpg");
44 | image = imread(path);
45 |
46 | % Use detectLaneMarkings function to detect the lane markings.
47 | laneMarkings = detectLaneMarkings(net, image, params, executionEnvironment);
48 |
49 | % Visualize the detected lanes.
50 | fig = figure;
51 | helper.plotLanes(fig, image, laneMarkings);
52 |
53 | ```
54 |
55 | Alternatively, you can run the SCNN model on sample videos. Use the script `spatialCNNLaneDetectionVideoExample.m` to run inference on a driving scene; a condensed sketch of that workflow is shown below.
56 |
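The sketch below condenses that script into a frame-by-frame loop; it assumes `net`, `params`, and `executionEnvironment` from the snippets above (the script itself batches frames with a configurable mini-batch size for speed).

```
% Read the driving scene video (loading it requires Automated Driving Toolbox).
v = VideoReader('caltech_washington1.avi');
videoStartTime = v.CurrentTime;

% Detect lane markings frame by frame.
laneMarkings = {};
while hasFrame(v)
    frame = readFrame(v);
    laneMarkings = [laneMarkings; detectLaneMarkings(net, frame, params, executionEnvironment)];
end

% Plot the detections on the frames and save the resulting video.
helper.plotLanesVideo(v, laneMarkings, videoStartTime);
```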
57 |
58 | ### Result
59 | The left-side image is the input and the right-side image shows the detected lanes. The image is taken from the [Panda Set](https://scale.com/open-datasets/pandaset) dataset[2].
60 |
61 |
62 |
63 | ![Input image (left) and detected lanes (right)](images/result.jpg)
65 |
66 |
67 |
68 | Sample video output generated by the script `spatialCNNLaneDetectionVideoExample.m`.
69 |
70 |
71 |
72 | ## Evaluation Metrics
73 | The model is evaluated using the method specified in [1].
74 |
75 | | Dataset | Error Metric | IOU | Result |
76 | | ------------- | ------------- | ------------- | ------------- |
77 | | CULane | F-measure | 0.3 | 73.45 |
78 | | CULane | F-measure | 0.5 | 43.41 |
79 |
80 | ## Spatial-CNN Algorithm Details
81 | The SCNN network architecture is illustrated in the following diagram.
82 | |![SCNN architecture](images/SCNN_architecture.png)|
83 | |:--:|
84 | |**Fig.1**|
85 |
86 |
87 | The network takes RGB images as input and outputs a probability map and a confidence score for each lane. The pretrained SCNN model trained on [CULane](https://xingangpan.github.io/projects/CULane.html) can detect a maximum of 4 lanes (2 driving lanes and 2 lanes on either side of the driving lane). The probability map predicted by the network has 5 channels (4 lanes + 1 background). Lanes with a confidence score less than 0.5 are ignored. To generate the detections, the probability map is post-processed and curves are fit to the extracted lane points.
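As a rough illustration only (the actual implementation is in `helper.processPredictions` and `helper.generateLines`), the post-processing can be thought of as keeping, for each lane channel with sufficient confidence, the strongest column in every row of the probability map; here `probMap` is assumed to be an H-by-W-by-5 softmax output and `confidence` a 4-by-1 score vector.

```
threshold = 0.5;
lanePoints = cell(1,4);
for lane = 1:4
    if confidence(lane) <= threshold
        continue;                              % lane not detected
    end
    laneProb = probMap(:,:,lane+1);            % channel 1 is the background
    [maxVal, col] = max(laneProb, [], 2);      % strongest column in each row
    rows = find(maxVal > threshold);
    lanePoints{lane} = [col(rows), rows];      % [x, y] pixel coordinates
end
```

When the `FitPolyLine` option of `detectLaneMarkings` is true (the default), second-degree polynomials are then fit to these points with `fitPolynomialRANSAC` to smooth the detections.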
88 |
89 |
90 | The SCNN network in this repository has 4 message passing layers applied in sequence in the directions top-to-bottom, bottom-to-top, left-to-right, and right-to-left, each with a kernel size of 9; they are represented by up-down, down-up, left-right, and right-left respectively in Fig.1. The message passing layers are special layers that apply slice-by-slice convolutions within the feature map[1]. These layers are implemented as a custom nested deep learning layer. For more information about the custom nested deep learning layer, see [Define Nested Deep Learning Layer](https://www.mathworks.com/help/deeplearning/ug/define-nested-deep-learning-layer.html).
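A minimal, runnable sketch of a single top-to-bottom pass on a toy single-channel feature map is shown below; the random kernel `k` stands in for the learned weights and `max(.,0)` plays the role of the ReLU. The actual layer is implemented as a nested `dlnetwork` in `src/+layer/MessagePassingLayer.m`.

```
X = rand(36, 100);                 % H-by-W feature map, single channel
k = rand(1, 9);                    % stand-in for the learned 1-by-9 kernel
for row = 2:size(X, 1)
    msg = conv2(X(row-1, :), k, 'same');   % convolve the previous row slice
    X(row, :) = X(row, :) + max(msg, 0);   % add the ReLU-activated message
end
```

Each row thus receives information propagated from the rows above it, which is what helps the network follow thin, partially occluded structures such as lane markings.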
91 |
92 | ## References
93 |
94 | [1] Xingang Pan, Jianping Shi, Ping Luo, Xiaogang Wang, and Xiaoou Tang. "Spatial As Deep: Spatial CNN for Traffic Scene Understanding" AAAI Conference on Artificial Intelligence (AAAI) - 2018
95 |
96 | [2] [Panda Set](https://scale.com/open-datasets/pandaset) is provided by Hesai and Scale under the [CC-BY-4.0 license](https://creativecommons.org/licenses/by/4.0)
97 |
98 | ## See also
99 | [Visual Perception Using Monocular Camera](https://www.mathworks.com/help/driving/ug/visual-perception-using-monocular-camera.html)
100 |
101 | Copyright 2021 The MathWorks, Inc.
102 |
103 |
--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Reporting Security Vulnerabilities
2 |
3 | If you believe you have discovered a security vulnerability, please report it to
4 | [security@mathworks.com](mailto:security@mathworks.com). Please see
5 | [MathWorks Vulnerability Disclosure Policy for Security Researchers](https://www.mathworks.com/company/aboutus/policies_statements/vulnerability-disclosure-policy.html)
6 | for additional information.
7 |
--------------------------------------------------------------------------------
/images/SCNN_architecture.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-spatial-CNN/bc159eaf3bc4f411904598652d94fc90d7eeafa7/images/SCNN_architecture.png
--------------------------------------------------------------------------------
/images/caltech_washington1.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-spatial-CNN/bc159eaf3bc4f411904598652d94fc90d7eeafa7/images/caltech_washington1.gif
--------------------------------------------------------------------------------
/images/detections.gif:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-spatial-CNN/bc159eaf3bc4f411904598652d94fc90d7eeafa7/images/detections.gif
--------------------------------------------------------------------------------
/images/result.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-spatial-CNN/bc159eaf3bc4f411904598652d94fc90d7eeafa7/images/result.jpg
--------------------------------------------------------------------------------
/images/testImage.jpg:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/matlab-deep-learning/pretrained-spatial-CNN/bc159eaf3bc4f411904598652d94fc90d7eeafa7/images/testImage.jpg
--------------------------------------------------------------------------------
/model/.gitkeep:
--------------------------------------------------------------------------------
1 |
2 |
--------------------------------------------------------------------------------
/spatialCNNLaneDetectionExample.m:
--------------------------------------------------------------------------------
1 | %% Lane Detection Using Spatial-CNN Network
2 | % The following code demonstrates running lane detection on a pre-trained SCNN
3 | % network, trained on the CULane dataset.
4 |
5 | %% Prerequisites
6 | % To run this example you need the following prerequisites -
7 | % * MATLAB (R2021a or later).
8 | % * Deep Learning Toolbox.
9 | % * Pretrained Spatial-CNN network (download instructions below).
10 |
11 | %% Add path to the source directory
12 | addpath('src');
13 |
14 | %% Download Pre-trained Network
15 | model = helper.downloadSCNNLaneDetection;
16 | net = model.net;
17 |
18 | %% Specify Detection Parameters
19 | % Use the function helper.createSCNNDetectionParameters to specify the
20 | % parameters required for lane detection.
21 | params = helper.createSCNNDetectionParameters;
22 |
23 | % Specify the executionEnvironment as either "cpu", "gpu", or "auto".
24 | executionEnvironment = "auto";
25 |
26 | %% Detect on an Image
27 | % Read the test image.
28 | path = fullfile("images","testImage.jpg");
29 | image = imread(path);
30 |
31 | % Call detectLaneMarkings to detect the lane markings.
32 | laneMarkings = detectLaneMarkings(net, image, params, executionEnvironment);
33 |
34 | % Visualize the detected lanes.
35 | fig = figure;
36 | helper.plotLanes(fig, image, laneMarkings);
37 |
38 | % Copyright 2021 The MathWorks, Inc.
39 |
--------------------------------------------------------------------------------
/spatialCNNLaneDetectionVideoExample.m:
--------------------------------------------------------------------------------
1 | %% Lane Detection Using Spatial-CNN Network in Driving Scene
2 | % The following code demonstrates running lane detection using a pre-trained
3 | % SCNN network on a driving scene.
4 |
5 | %% Prerequisites
6 | % To run this example you need the following prerequisites -
7 | % * MATLAB (R2021a or later).
8 | % * Deep Learning Toolbox.
9 | % * Automated Driving Toolbox (required to load the video).
10 | % * Pretrained Spatial-CNN network (download instructions below).
11 |
12 | %% Add path to the source directory
13 | addpath('src');
14 |
15 | %% Download Pre-trained Network
16 | model = helper.downloadSCNNLaneDetection;
17 | net = model.net;
18 |
19 | %% Specify Detection Parameters
20 | % Use the function helper.createSCNNDetectionParameters to specify the
21 | % parameters required for lane detection.
22 | params = helper.createSCNNDetectionParameters;
23 |
24 | % Specify the mini-batch size as 8. Increase this value to speed up
25 | % detection.
26 | miniBatchSize = 8;
27 |
28 | % Specify the executionEnvironment as either "cpu", "gpu", or "auto".
29 | executionEnvironment = "auto";
30 |
31 | %% Detect in Video
32 | % Read the video.
33 | v = VideoReader('caltech_washington1.avi');
34 |
35 | % Store the video start time.
36 | videoStartTime = v.CurrentTime;
37 |
38 | % Detect using the detectLaneMarkingVideo function provided as helper
39 | % function below.
40 | laneMarkings = detectLaneMarkingVideo(net, v, params, miniBatchSize, executionEnvironment);
41 |
42 | % Plot detections in video and save result.
43 | helper.plotLanesVideo(v, laneMarkings, videoStartTime);
44 |
45 |
46 | %% Helper function to Detect in Video
47 | function laneMarkings = detectLaneMarkingVideo(net, v, params, miniBatchSize, executionEnvironment)
48 | % Detect lane markings in a video by reading frames in batches.
49 | laneMarkings = {};
50 | numBatches = ceil(v.NumFrames/miniBatchSize);
51 | for batch = 1:numBatches
52 | firstFrameIdx = miniBatchSize*batch-(miniBatchSize-1);
53 | if batch ~= numBatches
54 | lastFrameIdx = miniBatchSize*batch;
55 | else
56 | lastFrameIdx = v.NumFrames;
57 | end
58 | % Read batch of frames.
59 | frames = read(v, [firstFrameIdx, lastFrameIdx]);
60 |
61 | % Detect lanes using the function detectLaneMarkings.
62 | detections = detectLaneMarkings(net, frames, params, executionEnvironment);
63 |
64 | % Append the detections.
65 | laneMarkings = [laneMarkings; detections];
66 |
67 | % Print detection progress.
68 | fprintf("Detected %d frames out of %d frames.\n",lastFrameIdx,v.NumFrames);
69 | end
70 | end
71 |
72 | % Copyright 2021 The MathWorks, Inc.
73 |
--------------------------------------------------------------------------------
/src/+helper/createSCNNDetectionParameters.m:
--------------------------------------------------------------------------------
1 | function params = createSCNNDetectionParameters()
2 | % createSCNNDetectionParameters creates parameters required for detection.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | % Specify these parameters for detection -
7 | % * Set the threshold as 0.5. Detections with confidence score less than
8 | % threshold are ignored.
9 | % * Set the networkInputSize as [288,800].
10 | params.threshold = 0.5;
11 | params.networkInputSize = [288,800];
12 |
13 | end
14 |
--------------------------------------------------------------------------------
/src/+helper/downloadSCNNLaneDetection.m:
--------------------------------------------------------------------------------
1 | function model = downloadSCNNLaneDetection()
2 | % The downloadSCNNLaneDetection function downloads (if necessary) and
3 | % loads a pretrained SCNN network.
4 | %
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | dataPath = 'model';
8 | modelName = 'scnn-culane';
9 | netFileFullPath = fullfile(dataPath, modelName);
10 |
11 | % Add the extensions to filename.
12 | netMatFileFull = [netFileFullPath,'.mat'];
13 | netZipFileFull = [netFileFullPath,'.zip'];
14 |
15 | if ~exist(netZipFileFull,'file')
16 | fprintf(['Downloading pretrained ', modelName ,' network.\n']);
17 | fprintf('This can take several minutes to download...\n');
18 | url = 'https://ssd.mathworks.com/supportfiles/vision/deeplearning/models/SpatialCNN/SCNN.zip';
19 | websave(netZipFileFull,url);
20 | unzip(netZipFileFull, dataPath);
21 | path = fullfile(netMatFileFull);
22 | model = load(path);
23 | else
24 | if ~exist(netMatFileFull,'file')
25 | fprintf('Pretrained SCNN-CULane network already exists.\n\n');
26 | unzip(netZipFileFull, dataPath);
27 | end
28 | path = fullfile(netMatFileFull);
29 | model = load(path);
30 | end
31 | end
32 |
--------------------------------------------------------------------------------
/src/+helper/generateLines.m:
--------------------------------------------------------------------------------
1 | function coordinates = generateLines(pred, conf, imsize, params, fitPolyLine)
2 | % The generateLines function extracts the lane pixel coordinates from the
3 | % predicted probability map.
4 |
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | % Number of lanes to detect.
8 | numLanes = size(conf,1);
9 |
10 | % Filter size for gaussian blur.
11 | gaussFilterSize = [9,9];
12 |
13 | % Batch size of input.
14 | batchSize = size(pred,4);
15 |
16 | % Create cell array to store the lane coordinates.
17 | coordinates = cell(batchSize,numLanes);
18 |
19 | % Threshold value below which the detections are ignored.
20 | threshold = params.threshold;
21 |
22 | % Suppress warnings generated by fitPolynomialRANSAC.
23 | warning('off','vision:ransac:maxTrialsReached');
24 |
25 | % For each image
26 | for batch = 1:batchSize
27 | for lane = 1:numLanes
28 | probMap = pred(:,:,lane+1,batch);
29 | % Smoothen the probability maps.
30 | probMap = imgaussfilt(probMap,'FilterSize',gaussFilterSize,'Padding','replicate');
31 | if conf(lane,batch)>0
32 | % Extract the lane points from probability maps.
33 | coordXY = lanePoints(probMap, threshold, imsize);
34 | % Remove all zeros.
35 | mask = (coordXY(:,1) == 0) & (coordXY(:,2) == 0);
36 | coordXY(mask,:) = [];
37 | if ~isempty(coordXY)
38 | % Sample points after specified gap including the last
39 | % point.
40 | last = coordXY(end,:);
41 | coordXY = coordXY(1:1:end,:);
42 | if ~all(coordXY(end,:) == last)
43 | coordXY(end+1,:) = last;
44 | end
45 | if ~fitPolyLine
46 | coordinates{batch,lane} = coordXY(:,1:2);
47 | else
48 | % Fit the equation x = a.y^2 + b.y + c.
49 | y = [coordXY(1,2):-1:coordXY(end,2)]';
50 | % Second-degree polynomial.
51 | n = 2;
52 | % Maximum allowed distance for a point to be inlier.
53 | maxDistance = 3;
54 | p = fitPolynomialRANSAC([coordXY(:,2),coordXY(:,1)],n,maxDistance);
55 | x = round(polyval(p,y));
56 | coordXY = [x,y];
57 | ids = x>=1 & x<=imsize(2);
58 | coordXY = coordXY(ids,:);
59 | coordinates{batch,lane} = coordXY;
60 | end
61 | end
62 | end
63 | end
64 | end
65 | end
66 |
67 | function coordinates = lanePoints(probMap, thresh, imsize)
68 | % The function lanePoints returns the row and column indexes of the maximum
69 | % probability values.
70 | probshape = size(probMap);
71 | coordinates = zeros(probshape(1),2);
72 | for i=1:probshape(1)
73 | % Extract the indexes of the max value.
74 | rowID = probshape(1)-i+1;
75 | cols = probMap(rowID,:);
76 | [value, id] = max(cols);
77 | if value > thresh
78 | coordinates(i,1) = floor(id/probshape(2)*imsize(2));
79 | coordinates(i,2) = floor(rowID/probshape(1)*imsize(1));
80 | end
81 | end
82 | % Ignore if the number of points is less than 2.
83 | if sum(coordinates>0)<2
84 | coordinates = zeros(probshape(1),2);
85 | end
86 | end
87 |
--------------------------------------------------------------------------------
/src/+helper/plotLanes.m:
--------------------------------------------------------------------------------
1 | function plotLanes(f, image, detections)
2 | % The function plotLanes plots the lane coordinates as lines on the
3 | % image. The maximum number of lanes that can be marked is 4.
4 |
5 | % Copyright 2021 The MathWorks, Inc.
6 |
7 | figure(f);
8 | clf;
9 |
10 | % Show the image.
11 | imshow(image);
12 |
13 | % Initialize lane colours.
14 | colors = ["red", "green", "blue", "yellow"];
15 | % Add lane coordinates to the image.
16 | hold on;
17 | for i=1:size(detections,2)
18 | if ~isempty(detections{1,i})
19 | plot(detections{1,i}(:,1),detections{1,i}(:,2),'LineWidth',10,'Color',colors(i));
20 | end
21 | end
22 | hold off;
23 | end
24 |
--------------------------------------------------------------------------------
/src/+helper/plotLanesVideo.m:
--------------------------------------------------------------------------------
1 | function plotLanesVideo(video, detections, startTime)
2 | % The function plotLanesVideo plots the lane coordinates on the video
3 | % frames and saves the output video. The maximum number of lanes that can be
4 | % marked is 4.
5 |
6 | % Copyright 2021 The MathWorks, Inc.
7 |
8 | % Reset the current time to start time.
9 | video.CurrentTime = startTime;
10 |
11 | frameIdx = 1;
12 | [~,videoName,~] = fileparts(video.Name);
13 | filename = strcat(videoName,"_detected");
14 | savepath = fullfile(pwd, filename);
15 |
16 | % Initialize lane colours.
17 | colors = ["red", "green", "blue", "yellow"];
18 |
19 | % Create output video object.
20 | outVideo = VideoWriter(savepath);
21 | outVideo.FrameRate = video.FrameRate;
22 | open(outVideo);
23 |
24 | while video.hasFrame
25 | frame = readFrame(video);
26 | for i=1:size(detections,2)
27 | if ~isempty(detections{frameIdx,i})
28 | frame = insertMarker(frame, detections{frameIdx,i}(:,1:2),'o','Color',colors(i),'Size',5);
29 | end
30 | end
31 | writeVideo(outVideo,frame);
32 | frameIdx = frameIdx+1;
33 | end
34 | close(outVideo);
35 | end
36 |
--------------------------------------------------------------------------------
/src/+helper/processPredictions.m:
--------------------------------------------------------------------------------
1 | function [probMap, confidence] = processPredictions(predictions, confidence, threshold)
2 | % The function processPredictions applies post-processing to the
3 | % predictions output by the SCNN network.
4 | % The function returns the probability map and the confidence score of the
5 | % detections.
6 | %
7 | % Copyright 2021 The MathWorks, Inc.
8 |
9 | if ~strcmp(dims(confidence),'CB')
10 | tmp = predictions;
11 | predictions = confidence;
12 | confidence = tmp;
13 | end
14 |
15 | % Apply a softmax to the predictions.
16 | probMap = softmax(predictions);
17 |
18 | % Extract and gather the predictions.
19 | probMap = gather(extractdata(probMap));
20 | confidence = gather(extractdata(confidence));
21 |
22 | % Only consider the confidence values that are greater than threshold.
23 | confidence(confidence>threshold) = 1;
24 | confidence(confidence<=threshold) = 0;
25 | end
26 |
--------------------------------------------------------------------------------
/src/+layer/MessagePassingLayer.m:
--------------------------------------------------------------------------------
1 | classdef MessagePassingLayer < nnet.layer.Layer
2 | % MessagePassingLayer Message Passing layer that applies
3 | % slice-by-slice convolution and relu within the
4 | % feature map.
5 | %
6 | % To create a MessagePassingLayer, use:
7 | % layer = MessagePassingLayer(filterSize, numFilters, direction, ...
8 | % 'Name', 'msgpassinglayer');
9 | %
10 | % Inputs:
11 | % -------
12 | % filterSize - Specify a positive integer for the filter size of the
13 | % slice-by-slice convolution operation.
14 | %
15 | % numFilters - Specify a positive integer for the number of filters
16 | % in the convolution layer.
17 | %
18 | % direction - Specify a string or character array for the direction
19 | % of the slice-by-slice convolution operation. The valid
20 | % value are 'topDown', 'bottomUp', 'leftRight', and
21 | % 'rightLeft'.
22 | %
23 | % layer = MessagePassingLayer(__, 'PARAM', VAL) specifies optional
24 | % parameter name/value pairs for creating the layer:
25 | %
26 | % 'Name' - A string or character array that specifies the name
27 | % for the layer.
28 | %
29 | % Default : ''
30 | %
31 | % Example:
32 | % --------
33 | % Create a top-down message passing layer with a filter size of 9 and
34 | % 128 filters.
35 | % l = MessagePassingLayer(9, 128, 'topDown', 'Name', 'top-down');
36 |
37 | % Copyright 2021 The MathWorks, Inc.
38 |
39 | properties
40 | % NumFilters in message passing.
41 | NumFilters
42 |
43 | % Filter size of message passing block.
44 | FilterSize
45 |
46 | % Direction of message passing.
47 | Direction
48 | end
49 |
50 | properties (Learnable)
51 | % Layer learnable parameters
52 |
53 | MessagePassingBlock
54 | end
55 |
56 | methods
57 | function layer = MessagePassingLayer(filterSize,numFilters,direction,NameValueArgs)
58 | % Creates a conv+relu block with the specified filter size,
59 | % number of filters, and direction.
60 |
61 | % Parse input arguments.
62 | arguments
63 | filterSize {mustBeInteger,mustBePositive}
64 | numFilters {mustBeInteger,mustBePositive}
65 | direction {mustBeTextScalar}
66 | NameValueArgs.Name = ''
67 | end
68 |
69 | name = NameValueArgs.Name;
70 |
71 | % Set layer name.
72 | layer.Name = name;
73 | layer.NumFilters = numFilters;
74 | layer.FilterSize = filterSize;
75 | layer.Direction = iCheckAndReturnValidDirection(direction);
76 |
77 |
78 | % Set layer description.
79 | description = "Message passing " + layer.Direction + " block with num filters "+ layer.NumFilters+" and filter size "+layer.FilterSize;
80 | layer.Description = description;
81 |
82 | % Set layer type.
83 | layer.Type = "Message passing " + layer.Direction + " block";
84 |
85 | % Define nested layer graph.
86 | if strcmp(layer.Direction,'topDown') || strcmp(layer.Direction,'bottomUp')
87 | layers = [
88 | convolution2dLayer([1,filterSize],numFilters,'Padding',"same",'Name','conv1',"BiasLearnRateFactor",0)
89 | reluLayer('Name','relu1')
90 | ];
91 | else
92 | layers = [
93 | convolution2dLayer([filterSize,1],numFilters,'Padding',"same",'Name','conv1',"BiasLearnRateFactor",0)
94 | reluLayer('Name','relu1')
95 | ];
96 | end
97 |
98 | lgraph = layerGraph(layers);
99 |
100 | % Convert to dlnetwork.
101 | dlnet = dlnetwork(lgraph,'Initialize',false);
102 |
103 | % Set Network property.
104 | layer.MessagePassingBlock = dlnet;
105 | end
106 |
107 | function Z = predict(layer, X)
108 | % Forward input data through the layer at prediction time and
109 | % output the result.
110 | %
111 | % Inputs:
112 | % layer - Layer to forward propagate through
113 | % X - Input data
114 | % Outputs:
115 | % Z - Output of layer forward function
116 |
117 | % Convert input data to formatted dlarray.
118 | X = dlarray(X,'SSCB');
119 | dlnet = layer.MessagePassingBlock;
120 | fh = str2func(layer.Direction);
121 |
122 | % Process message passing block.
123 | Z = fh(dlnet,X);
124 | Z = stripdims(Z);
125 | end
126 | end
127 | end
128 |
129 | function value = iCheckAndReturnValidDirection(value)
130 | validateattributes(value, {'char','string'},{},'','Direction');
131 | value = validatestring(value, {'topDown', 'bottomUp', 'leftRight', 'rightLeft'},'','Direction');
132 | end
133 |
134 | function block = topDown(net,block)
135 | % Slice along the height.
136 | sliceDim = 1;
137 | for i = 2:size(block,sliceDim)
138 | z = predict(net,block(i-1,:,:,:));
139 | block(i,:,:,:) = block(i,:,:,:) + z;
140 | end
141 | end
142 |
143 | function block = bottomUp(net,block)
144 | % Slice along the height.
145 | sliceDim = 1;
146 | for i = size(block,sliceDim)-1:-1:1
147 | z = predict(net,block(i+1,:,:,:));
148 | block(i,:,:,:) = block(i,:,:,:) + z;
149 | end
150 | end
151 |
152 | function block = leftRight(net,block)
153 | % Slice along the width.
154 | sliceDim = 2;
155 | for i = 2:size(block,sliceDim)
156 | z = predict(net,block(:,i-1,:,:));
157 | block(:,i,:,:) = block(:,i,:,:) + z;
158 | end
159 | end
160 |
161 | function block = rightLeft(net,block)
162 | % Slice along the width.
163 | sliceDim = 2;
164 | for i = size(block,sliceDim)-1:-1:1
165 | z = predict(net,block(:,i+1,:,:));
166 | block(:,i,:,:) = block(:,i,:,:) + z;
167 | end
168 | end
169 |
--------------------------------------------------------------------------------
/src/detectLaneMarkings.m:
--------------------------------------------------------------------------------
1 | function detections = detectLaneMarkings(net, data, params, executionEnvironment, NameValueArgs)
2 | % detections = detectLaneMarkings(net, image, params, executionEnvironment)
3 | % runs prediction on a pre-trained SCNN network.
4 | %
5 | % Inputs:
6 | % -------
7 | % net - Pretrained SCNN dlnetwork.
8 | % data - Input data must be a single RGB image of size
9 | % HxWx3 or an array of RGB images of size HxWx3xB,
10 | % where H is the height, W is the width and B is the
11 | % number of images.
12 | % params - Parameters required to run inference on SCNN
13 | % created using
14 | % helper.createSCNNDetectionParameters.
15 | % executionEnvironment - Environment to run predictions on. Specify cpu,
16 | % gpu, or auto.
17 | %
18 | % detections = detectLaneMarkings(..., Name, Value) specifies the optional
19 | % name-value pair argument as described below.
20 | %
21 | % 'FitPolyLine' - Specify the value true or false. If true,
22 | % second-order polynomials are fit to the detections
23 | % to smooth them; otherwise, raw detections are
24 | % returned.
25 | %
26 | % Default: true
27 | %
28 | %
29 | % Output:
30 | % -------
31 | % detections - Returns an M-by-N cell array, where M is the
32 | % number of images in a batch and N is the number of
33 | % lanes detected. Each cell contains a P-by-2 array
34 | % of pixel coordinates of the detected lane
35 | % markings, where P is the number of points and 2
36 | % columns are the X and Y values of their respective
37 | % pixel coordinates.
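%
% Example (mirrors spatialCNNLaneDetectionExample.m; assumes the pretrained
% network has been downloaded with helper.downloadSCNNLaneDetection):
% --------
%   model = helper.downloadSCNNLaneDetection;
%   params = helper.createSCNNDetectionParameters;
%   I = imread(fullfile('images','testImage.jpg'));
%   detections = detectLaneMarkings(model.net, I, params, "auto");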
38 |
39 | % Copyright 2021 The MathWorks, Inc.
40 |
41 | % Parse input arguments.
42 | arguments
43 | net
44 | data
45 | params
46 | executionEnvironment
47 | NameValueArgs.FitPolyLine = true;
48 | end
49 | fitPolyLine = NameValueArgs.FitPolyLine;
50 |
51 | % Get the input image size.
52 | inputImageSize = size(data);
53 |
54 | % Resize the image to the input size of the network.
55 | resizedImage = imresize(data, params.networkInputSize);
56 |
57 | % Rescale the pixels in the range [0,1].
58 | resizedImage = im2single(resizedImage);
59 |
60 | % Convert the resized image to dlarray and gpuArray if specified.
61 | if canUseGPU && ~(strcmp(executionEnvironment,"cpu"))
62 | resizedImage = gpuArray(dlarray(resizedImage,'SSCB'));
63 | else
64 | resizedImage = dlarray(resizedImage,'SSCB');
65 | end
66 |
67 | % Predict the output.
68 | [laneMask, confidence] = predict(net, resizedImage);
69 |
70 | % Process the predictions to output probability map and confidence scores.
71 | [laneMask, confidence] = helper.processPredictions(laneMask, confidence, params.threshold);
72 |
73 | % Extract lane marking coordinates from the probability map and confidence
74 | % scores.
75 | detections = helper.generateLines(laneMask, confidence, inputImageSize, params, fitPolyLine);
76 |
77 | end
78 |
--------------------------------------------------------------------------------
/test/tMessagePassingLayer.m:
--------------------------------------------------------------------------------
1 | classdef tMessagePassingLayer < matlab.unittest.TestCase
2 | % tMessagePassingLayer. Test for MessagePassingLayer
3 | properties
4 | MessagePassingLayer = @layer.MessagePassingLayer;
5 | NumFilters= 128;
6 | FilterSize = 9;
7 | RS
8 | end
9 | properties(TestParameter)
10 | Direction = {'topDown', 'bottomUp', 'leftRight', 'rightLeft'};
11 | ExpectedOut = iGetExpectedOutputs();
12 | end
13 | methods(TestMethodSetup)
14 | function setupRNGSeed(test)
15 | test.RS = rng(0);
16 | end
17 | end
18 |
19 | methods(TestMethodTeardown)
20 | function resetRNGSeed(test)
21 | rng(test.RS);
22 | end
23 | end
24 | methods(Test,ParameterCombination = 'sequential')
25 | function testConstruction(test,Direction)
26 | layer = test.MessagePassingLayer(test.FilterSize,test.NumFilters,Direction,'Name',strcat('message_passing_',Direction));
27 |
28 | % Verifying the layer properties.
29 | test.verifyEqual(layer.Name,strcat('message_passing_',Direction))
30 | test.verifyEqual(layer.NumFilters,test.NumFilters);
31 | test.verifyEqual(layer.FilterSize,test.FilterSize);
32 | test.verifyEqual(layer.Type,strcat("Message passing " , Direction , " block"));
33 | test.verifyClass(layer.MessagePassingBlock,'dlnetwork')
34 | test.verifyEqual(numel(layer.MessagePassingBlock.Layers),2);
35 | test.verifyFalse(layer.MessagePassingBlock.Initialized);
36 | % Verifying the properties of layers inside
37 | % MessagePassingBlock.
38 | test.verifyClass(layer.MessagePassingBlock.Layers(1),'nnet.cnn.layer.Convolution2DLayer');
39 | test.verifyEqual(layer.MessagePassingBlock.Layers(1).Name,'conv1');
40 |
41 | if ismember(Direction,{'topDown', 'bottomUp'})
42 | test.verifyEqual(layer.MessagePassingBlock.Layers(1).FilterSize,[1 test.FilterSize]);
43 | else
44 | test.verifyEqual(layer.MessagePassingBlock.Layers(1).FilterSize,[test.FilterSize 1]);
45 | end
46 |
47 | test.verifyEqual(layer.MessagePassingBlock.Layers(1).NumFilters, test.NumFilters);
48 | test.verifyEqual(layer.MessagePassingBlock.Layers(1).PaddingMode,'same');
49 | test.verifyClass(layer.MessagePassingBlock.Layers(2),'nnet.cnn.layer.ReLULayer');
50 | test.verifyEqual(layer.MessagePassingBlock.Layers(2).Name,'relu1');
51 | end
52 | function testPredict(test,Direction,ExpectedOut)
53 | numFilters = 2;
54 | layer = test.MessagePassingLayer(test.FilterSize,numFilters,Direction);
55 | % Construct an initialized dlnetwork from the MessagePassingBlock layers
56 | % to exercise its predict method.
57 | dlnet = dlnetwork([imageInputLayer([3 3 3],'normalization','none') ; layer.MessagePassingBlock.Layers]);
58 | input = dlarray(ones(2,2,3),'SSC');
59 | out = dlnet.predict(input);
60 |
61 | test.verifyEqual(size(out),[2 2 numFilters]);
62 | test.verifyEqual(extractdata(out),single(ExpectedOut),'AbsTol',single(1e-3));
63 | end
64 | end
65 | end
66 | function out = iGetExpectedOutputs()
67 | td(:,:,1) = [0.5833 0.6376; 0.5833 0.6376];
68 | td(:,:,2) = [0 0 ; 0 0 ];
69 | bu(:,:,1) = [0.5833 0.6376; 0.5833 0.6376];
70 | bu(:,:,2) = [0 0 ; 0 0 ];
71 | lr(:,:,1) = [0.5833 0.5833; 0.6376 0.6376];
72 | lr(:,:,2) = [0 0; 0 0];
73 | rl(:,:,1) = [0.5833 0.5833; 0.6376 0.6376];
74 | rl(:,:,2) = [0 0;0 0];
75 |
76 | out = struct('topDownExp',td, 'bottomUpExp',bu, 'leftRightExp',lr, 'rightLeftExp',rl);
77 | end
78 |
--------------------------------------------------------------------------------
/test/tPretrainedSCNNLaneDetection.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadSCNNLaneDetectionFixture}) tPretrainedSCNNLaneDetection < matlab.unittest.TestCase
2 | % Test for tPretrainedSCNNLaneDetection
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | % The shared test fixture downloads the model. Here we check the
7 | % inference on the pretrained model.
8 | properties
9 | RepoRoot = getRepoRoot;
10 | ModelName = 'scnn-culane.mat';
11 | end
12 |
13 | methods(Test)
14 | function exerciseDetection(test)
15 | model = load(fullfile(test.RepoRoot,'model',test.ModelName));
16 | image = imread(fullfile(test.RepoRoot,"images","testImage.jpg"));
17 | executionEnvironment = "auto";
18 | params = helper.createSCNNDetectionParameters;
19 | laneMarkings = detectLaneMarkings(model.net, image, params, executionEnvironment);
20 |
21 | % verifying output class
22 | test.verifyClass(laneMarkings,'cell');
23 |
24 | % verifying size of output from detectLaneMarkings.
25 | test.verifyEqual(size(laneMarkings{1}),[322 2]);
26 | test.verifyEqual(size(laneMarkings{2}),[560 2]);
27 | test.verifyEqual(size(laneMarkings{3}),[556 2]);
28 | test.verifyEqual(size(laneMarkings{4}),[0 0]);
29 | end
30 | end
31 | end
32 |
--------------------------------------------------------------------------------
/test/tdownloadSCNNLaneDetection.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadSCNNLaneDetectionFixture}) tdownloadSCNNLaneDetection < matlab.unittest.TestCase
2 | % Test for downloadSCNNLaneDetection
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | % The shared test fixture DownloadSCNNLaneDetectionFixture calls
7 | % downloadSCNNLaneDetection. Here we check that the downloaded files
8 | % exist in the appropriate location.
9 |
10 | properties
11 | DataDir = fullfile(getRepoRoot(),'model');
12 | end
13 |
14 | methods(Test)
15 | function verifyDownloadedFilesExist(test)
16 | dataFileName = 'scnn-culane.mat';
17 | test.verifyTrue(isequal(exist(fullfile(test.DataDir,dataFileName),'file'),2));
18 | end
19 | end
20 | end
21 |
--------------------------------------------------------------------------------
/test/tload.m:
--------------------------------------------------------------------------------
1 | classdef(SharedTestFixtures = {DownloadSCNNLaneDetectionFixture}) tload < matlab.unittest.TestCase
2 | % Test for loading the downloaded model.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | % The shared test fixture DownloadSCNNLaneDetectionFixture calls
7 | % downloadSCNNLaneDetection. Here we check the properties of the
8 | % downloaded model.
9 |
10 | properties
11 | DataDir = fullfile(getRepoRoot(),'model');
12 | end
13 |
14 | methods(Test)
15 | function verifyModelAndFields(test)
16 | % Test point to verify the fields of the downloaded models are
17 | % as expected.
18 |
19 | loadedModel = load(fullfile(test.DataDir,'scnn-culane.mat'));
20 |
21 | test.verifyClass(loadedModel.net,'dlnetwork');
22 | test.verifyEqual(numel(loadedModel.net.Layers),62);
23 | test.verifyEqual(size(loadedModel.net.Connections),[61 2])
24 | test.verifyEqual(loadedModel.net.InputNames,{'input'});
25 | test.verifyEqual(loadedModel.net.OutputNames,{'layer2_2','fc_3'});
26 | end
27 | end
28 | end
--------------------------------------------------------------------------------
/test/tools/DownloadSCNNLaneDetectionFixture.m:
--------------------------------------------------------------------------------
1 | classdef DownloadSCNNLaneDetectionFixture < matlab.unittest.fixtures.Fixture
2 | % DownloadSCNNLaneDetectionFixture A fixture for calling
3 | % helper.downloadSCNNLaneDetection if necessary. This is to
4 | % ensure that this function is only called once and only when tests
5 | % need it. It also provides a teardown to return the test environment
6 | % to the expected state before testing.
7 |
8 | % Copyright 2021 The MathWorks, Inc
9 |
10 | properties(Constant)
11 | SCNNDataDir = fullfile(getRepoRoot(),'model')
12 | end
13 |
14 | properties
15 | SCNNExist (1,1) logical
16 | end
17 |
18 | methods
19 | function setup(this)
20 | import matlab.unittest.fixtures.CurrentFolderFixture;
21 | this.applyFixture(CurrentFolderFixture ...
22 | (getRepoRoot()));
23 |
24 | this.SCNNExist = exist(fullfile(this.SCNNDataDir,'scnn-culane.mat'),'file')==2;
25 |
26 | % Call this in eval to capture and drop any standard output
27 | % that we don't want polluting the test logs.
28 | if ~this.SCNNExist
29 | evalc('helper.downloadSCNNLaneDetection();');
30 | end
31 | end
32 |
33 | function teardown(this)
34 | if ~this.SCNNExist
35 | delete(fullfile(this.SCNNDataDir,'scnn-culane.mat'));
36 | end
37 | end
38 | end
39 | end
40 |
--------------------------------------------------------------------------------
/test/tools/getRepoRoot.m:
--------------------------------------------------------------------------------
1 | function path = getRepoRoot()
2 | % getRepoRoot Return a path to the repository's root directory.
3 |
4 | % Copyright 2021 The MathWorks, Inc.
5 |
6 | thisFile = mfilename('fullpath');
7 | thisDir = fileparts(thisFile);
8 |
9 | % the root is up two directories (/test/tools/getRepoRoot.m)
10 | path = fullfile(thisDir,'..','..');
11 | end
12 |
--------------------------------------------------------------------------------