├── .circleci
│   └── config.yml
├── .gitignore
├── LICENSE
├── README.md
├── SECURITY.md
├── images
│   ├── network.png
│   └── result.png
├── models
│   └── .gitkeep
└── src
    └── +helper
        ├── applyActivations.m
        ├── applyAnchorBoxOffsets.m
        ├── augmentData.m
        ├── coco-classes.txt
        ├── configureTrainingProgressPlotter.m
        ├── createBatchData.m
        ├── displayLossInfo.m
        ├── downloadPretrainedYOLOv4.m
        ├── extractPredictions.m
        ├── generateTargets.m
        ├── generateTiledAnchors.m
        ├── getAnchors.m
        ├── getCOCOClassNames.m
        ├── piecewiseLearningRateWithWarmup.m
        ├── postprocess.m
        ├── preprocess.m
        ├── preprocessData.m
        ├── updatePlots.m
        ├── validateInputData.m
        └── yolov4Forward.m

/.circleci/config.yml:
--------------------------------------------------------------------------------
version: 2.1
orbs:
  matlab: mathworks/matlab@0.4.0

jobs:
  build:
    machine:
      image: ubuntu-1604:201903-01
    steps:
      - checkout
      - matlab/install:
          release: R2021a
      - matlab/run-tests:
          test-results-junit: artifacts/test_results/matlab/results.xml
          # Have to add test/tools to the path for certain tests.
          source-folder: .;test/tools;src
      - store_test_results:
          path: artifacts/test_results
      - store_artifacts:
          path: artifacts/

--------------------------------------------------------------------------------
/.gitignore:
--------------------------------------------------------------------------------
models/

--------------------------------------------------------------------------------
/LICENSE:
--------------------------------------------------------------------------------
Copyright (c) 2021, The MathWorks, Inc.
All rights reserved.
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
3. In all cases, the software is, and all modifications and derivatives of the software shall be, licensed to you solely for use in conjunction with MathWorks products and service offerings.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Pretrained YOLO v4 Network For Object Detection
This repository provides a pretrained YOLO v4[1] object detection network for MATLAB®.
[![Open in MATLAB Online](https://www.mathworks.com/images/responsive/global/open-in-matlab-online.svg)](https://matlab.mathworks.com/open/github/v1?repo=matlab-deep-learning/pretrained-yolo-v4)

**Creator**: MathWorks Development


## Requirements
- MATLAB® R2022a or later
- Deep Learning Toolbox™
- Computer Vision Toolbox™
- Computer Vision Toolbox™ Model for YOLO v4 Object Detection

Note: Users of previous MATLAB® releases can use [this](https://github.com/matlab-deep-learning/pretrained-yolo-v4/tree/previous) branch to download the pretrained models.


## Getting Started
[Getting Started with YOLO v4](https://in.mathworks.com/help/vision/ug/getting-started-with-yolo-v4.html)


### Detect Objects Using Pretrained YOLO v4
Use the code below to perform detection on an example image using the pretrained model.

Note: This functionality requires Deep Learning Toolbox™ and the Computer Vision Toolbox™ Model for YOLO v4 Object Detection. You can install the Computer Vision Toolbox Model for YOLO v4 Object Detection from Add-On Explorer. For more information about installing add-ons, see [Get and Manage Add-Ons](https://in.mathworks.com/help/matlab/matlab_env/get-add-ons.html).

```
% Load the pretrained detector.
modelName = 'csp-darknet53-coco';
detector = yolov4ObjectDetector(modelName);

% Read the test image.
img = imread('visionteam.jpg');

% Detect objects in the test image.
[bboxes, scores, labels] = detect(detector, img);

% Visualize the detection results.
img = insertObjectAnnotation(img, 'rectangle', bboxes, scores);
figure, imshow(img)
```
![alt text](images/result.png?raw=true)

### Choosing a Pretrained YOLO v4 Object Detector
You can choose the YOLO v4 object detector best suited to your application based on the table below:

| Model | Input image resolution | mAP | Size (MB) | Classes |
| ------ | ------ | ------ | ------ | ------ |
| YOLOv4-coco | 608 x 608 | 44.2 | 229 | [coco class names](src/+helper/coco-classes.txt) |
| YOLOv4-tiny-coco | 416 x 416 | 19.7 | 21.5 | [coco class names](src/+helper/coco-classes.txt) |

- mAP for models trained on the COCO dataset is computed as the average over IoU thresholds from 0.5 to 0.95.

### Train Custom YOLO v4 Detector Using Transfer Learning
To train a YOLO v4 object detection network on a labeled data set, first specify the class names of the data set. Then, train an untrained or pretrained network by using the [trainYOLOv4ObjectDetector](https://in.mathworks.com/help/vision/ref/trainyolov4objectdetector.html) function. The training function returns the trained network as a [yolov4ObjectDetector](https://in.mathworks.com/help/vision/ref/yolov4objectdetector.html) object.

For more information about training a YOLO v4 object detector, see [Object Detection using YOLO v4 Deep Learning Example](https://in.mathworks.com/help/vision/ug/object-detection-using-yolov4-deep-learning.html).
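
As a concrete starting point, the sketch below outlines the typical calls involved. It is a minimal outline rather than a full recipe: `trainingData` (a datastore returning images with boxes and labels), the class names, and the training options are placeholders to adapt to your own data set.

```
% Estimate anchor boxes from the training data, then split them across
% the three detection heads (largest anchors on the coarsest head).
numAnchors = 9;
[anchors, meanIoU] = estimateAnchorBoxes(trainingData, numAnchors);
area = anchors(:,1).*anchors(:,2);
[~, idx] = sort(area, 'descend');
anchors = anchors(idx, :);
anchorBoxes = {anchors(1:3,:); anchors(4:6,:); anchors(7:9,:)};

% Configure a pretrained detector for the new classes.
classes = {'vehicle', 'person'};   % placeholder class names
detector = yolov4ObjectDetector('csp-darknet53-coco', classes, anchorBoxes);

% Set the training options and train with transfer learning.
options = trainingOptions('adam', ...
    'InitialLearnRate', 1e-3, ...
    'MiniBatchSize', 8, ...
    'MaxEpochs', 20);
detector = trainYOLOv4ObjectDetector(trainingData, detector, options);
```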

## Code Generation for YOLO v4
Code generation enables you to deploy YOLO v4 on multiple embedded platforms. For more information about generating CUDA® code using the YOLO v4 object detector, see [Code Generation for Object Detection by Using YOLO v4](https://in.mathworks.com/help/gpucoder/ug/code-generation-for-object-detection-using-YOLO-v4.html).

## YOLO v4 Network Details
The YOLO v4 network architecture comprises three sections: the backbone, the neck, and the detection head.

![alt text](images/network.png?raw=true)

- **Backbone:** CSP-Darknet53 (Cross-Stage-Partial Darknet53) is used as the backbone for YOLO v4 networks. It is a model with a high input resolution (608 x 608), a large receptive field (725 x 725), a large number of 3 x 3 convolutional layers, and a large number of parameters. The large receptive field helps the network view entire objects in an image and understand the context around them, while the high input resolution improves the detection of small objects. This makes CSP-Darknet53 a suitable backbone for detecting multiple objects of different sizes in a single image.

- **Neck:** This section is composed of several bottom-up and top-down aggregation paths. It further increases the receptive field of the network and separates out the most significant context features, with almost no reduction in network operation speed. SPP (Spatial Pyramid Pooling) blocks are added as the neck section over the CSP-Darknet53 backbone, and PANet (Path Aggregation Network) is used as the method of parameter aggregation from different backbone levels for different detector levels.

- **Detection Head**: This section processes the aggregated features from the neck section and predicts the bounding boxes, objectness scores, and classification scores. It follows the principle of a one-stage anchor-based object detector.

## References
[1] Bochkovskiy, Alexey, Chien-Yao Wang, and Hong-Yuan Mark Liao. "YOLOv4: Optimal Speed and Accuracy of Object Detection." arXiv preprint arXiv:2004.10934 (2020).

[2] Lin, T., et al. "Microsoft COCO: Common Objects in Context." arXiv preprint arXiv:1405.0312 (2014).

Copyright 2021 - 2024 The MathWorks, Inc.

--------------------------------------------------------------------------------
/SECURITY.md:
--------------------------------------------------------------------------------
1 | # Reporting Security Vulnerabilities
2 | 
3 | If you believe you have discovered a security vulnerability, please report it to
4 | [security@mathworks.com](mailto:security@mathworks.com). Please see
5 | [MathWorks Vulnerability Disclosure Policy for Security Researchers](https://www.mathworks.com/company/aboutus/policies_statements/vulnerability-disclosure-policy.html)
6 | for additional information.
7 | -------------------------------------------------------------------------------- /images/network.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/pretrained-yolo-v4/3e300733529ed68af360fbc7c4e931c1c2a60a50/images/network.png -------------------------------------------------------------------------------- /images/result.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/matlab-deep-learning/pretrained-yolo-v4/3e300733529ed68af360fbc7c4e931c1c2a60a50/images/result.png -------------------------------------------------------------------------------- /models/.gitkeep: -------------------------------------------------------------------------------- 1 | 2 | -------------------------------------------------------------------------------- /src/+helper/applyActivations.m: -------------------------------------------------------------------------------- 1 | function YPredCell = applyActivations(YPredCell) 2 | % Apply activation functions on YOLOv4 outputs. 3 | 4 | % Copyright 2021 The MathWorks, Inc. 5 | 6 | YPredCell(:,1:3) = cellfun(@ sigmoid, YPredCell(:,1:3), 'UniformOutput', false); 7 | YPredCell(:,4:5) = cellfun(@ exp, YPredCell(:,4:5), 'UniformOutput', false); 8 | YPredCell(:,6) = cellfun(@ sigmoid, YPredCell(:,6), 'UniformOutput', false); 9 | end -------------------------------------------------------------------------------- /src/+helper/applyAnchorBoxOffsets.m: -------------------------------------------------------------------------------- 1 | function tiledAnchors = applyAnchorBoxOffsets(tiledAnchors,YPredCell,inputImageSize) 2 | % Convert grid cell coordinates to box coordinates. 3 | 4 | % Copyright 2021 The MathWorks, Inc. 5 | 6 | for i=1:size(YPredCell,1) 7 | [h,w,~,~] = size(YPredCell{i,1}); 8 | tiledAnchors{i,1} = (tiledAnchors{i,1}+YPredCell{i,1})./w; 9 | tiledAnchors{i,2} = (tiledAnchors{i,2}+YPredCell{i,2})./h; 10 | tiledAnchors{i,3} = (tiledAnchors{i,3}.*YPredCell{i,3})./inputImageSize(2); 11 | tiledAnchors{i,4} = (tiledAnchors{i,4}.*YPredCell{i,4})./inputImageSize(1); 12 | end 13 | end 14 | -------------------------------------------------------------------------------- /src/+helper/augmentData.m: -------------------------------------------------------------------------------- 1 | function data = augmentData(A) 2 | % Apply random horizontal flipping, and random X/Y scaling. Boxes that get 3 | % scaled outside the bounds are clipped if the overlap is above 0.25. Also, 4 | % jitter image color. 5 | 6 | % Copyright 2021 The MathWorks, Inc. 7 | 8 | data = cell(size(A)); 9 | for ii = 1:size(A,1) 10 | I = A{ii,1}; 11 | bboxes = A{ii,2}; 12 | labels = A{ii,3}; 13 | sz = size(I); 14 | 15 | if numel(sz) == 3 && sz(3) == 3 16 | I = jitterColorHSV(I,... 17 | 'Contrast',0.0,... 18 | 'Hue',0.1,... 19 | 'Saturation',0.2,... 20 | 'Brightness',0.2); 21 | end 22 | 23 | % Randomly flip image. 24 | tform = randomAffine2d('XReflection',true,'Scale',[1 1.1]); 25 | rout = affineOutputView(sz,tform,'BoundsStyle','centerOutput'); 26 | I = imwarp(I,tform,'OutputView',rout); 27 | 28 | % Apply same transform to boxes. 29 | [bboxes,indices] = bboxwarp(bboxes,tform,rout,'OverlapThreshold',0.25); 30 | labels = labels(indices); 31 | 32 | % Return original data only when all boxes are removed by warping. 
33 | if isempty(indices) 34 | data(ii,:) = A(ii,:); 35 | else 36 | data(ii,:) = {I, bboxes, labels}; 37 | end 38 | end 39 | end 40 | -------------------------------------------------------------------------------- /src/+helper/coco-classes.txt: -------------------------------------------------------------------------------- 1 | person 2 | bicycle 3 | car 4 | motorbike 5 | aeroplane 6 | bus 7 | train 8 | truck 9 | boat 10 | traffic light 11 | fire hydrant 12 | stop sign 13 | parking meter 14 | bench 15 | bird 16 | cat 17 | dog 18 | horse 19 | sheep 20 | cow 21 | elephant 22 | bear 23 | zebra 24 | giraffe 25 | backpack 26 | umbrella 27 | handbag 28 | tie 29 | suitcase 30 | frisbee 31 | skis 32 | snowboard 33 | sports ball 34 | kite 35 | baseball bat 36 | baseball glove 37 | skateboard 38 | surfboard 39 | tennis racket 40 | bottle 41 | wine glass 42 | cup 43 | fork 44 | knife 45 | spoon 46 | bowl 47 | banana 48 | apple 49 | sandwich 50 | orange 51 | broccoli 52 | carrot 53 | hot dog 54 | pizza 55 | donut 56 | cake 57 | chair 58 | sofa 59 | pottedplant 60 | bed 61 | diningtable 62 | toilet 63 | tvmonitor 64 | laptop 65 | mouse 66 | remote 67 | keyboard 68 | cell phone 69 | microwave 70 | oven 71 | toaster 72 | sink 73 | refrigerator 74 | book 75 | clock 76 | vase 77 | scissors 78 | teddy bear 79 | hair drier 80 | toothbrush -------------------------------------------------------------------------------- /src/+helper/configureTrainingProgressPlotter.m: -------------------------------------------------------------------------------- 1 | function [lossPlotter, learningRatePlotter] = configureTrainingProgressPlotter(f) 2 | % Create the subplots to display the loss and learning rate. 3 | 4 | % Copyright 2021 The MathWorks, Inc. 5 | 6 | figure(f); 7 | clf 8 | subplot(2,1,1); 9 | ylabel('Learning Rate'); 10 | xlabel('Iteration'); 11 | learningRatePlotter = animatedline; 12 | subplot(2,1,2); 13 | ylabel('Total Loss'); 14 | xlabel('Iteration'); 15 | lossPlotter = animatedline; 16 | end -------------------------------------------------------------------------------- /src/+helper/createBatchData.m: -------------------------------------------------------------------------------- 1 | function [XTrain, YTrain] = createBatchData(data, groundTruthBoxes, groundTruthClasses, classNames) 2 | % Returns images combined along the batch dimension in XTrain and 3 | % normalized bounding boxes concatenated with classIDs in YTrain. 4 | 5 | % Copyright 2021 The MathWorks, Inc. 6 | 7 | % Concatenate images along the batch dimension. 8 | XTrain = cat(4, data{:,1}); 9 | 10 | % Get class IDs from the class names. 11 | classNames = repmat({categorical(classNames')}, size(groundTruthClasses)); 12 | [~, classIndices] = cellfun(@(a,b)ismember(a,b), groundTruthClasses, classNames, 'UniformOutput', false); 13 | 14 | % Append the label indexes and training image size to scaled bounding boxes 15 | % and create a single cell array of responses. 
16 | combinedResponses = cellfun(@(bbox, classid)[bbox, classid], groundTruthBoxes, classIndices, 'UniformOutput', false); 17 | len = max( cellfun(@(x)size(x,1), combinedResponses ) ); 18 | paddedBBoxes = cellfun( @(v) padarray(v,[len-size(v,1),0],0,'post'), combinedResponses, 'UniformOutput',false); 19 | YTrain = cat(4, paddedBBoxes{:,1}); 20 | end -------------------------------------------------------------------------------- /src/+helper/displayLossInfo.m: -------------------------------------------------------------------------------- 1 | function displayLossInfo(epoch, iteration, currentLR, lossInfo) 2 | % Display loss information for each iteration. 3 | 4 | % Copyright 2021 The MathWorks, Inc. 5 | 6 | disp("Epoch : " + epoch + " | Iteration : " + iteration + " | Learning Rate : " + currentLR + ... 7 | " | Total Loss : " + double(gather(extractdata(lossInfo.totalLoss))) + ... 8 | " | Box Loss : " + double(gather(extractdata(lossInfo.boxLoss))) + ... 9 | " | Object Loss : " + double(gather(extractdata(lossInfo.objLoss))) + ... 10 | " | Class Loss : " + double(gather(extractdata(lossInfo.clsLoss)))); 11 | end -------------------------------------------------------------------------------- /src/+helper/downloadPretrainedYOLOv4.m: -------------------------------------------------------------------------------- 1 | function model = downloadPretrainedYOLOv4(modelName) 2 | % The downloadPretrainedYOLOv4 function downloads a YOLO v4 network 3 | % pretrained on COCO dataset. 4 | % 5 | % Copyright 2021 The MathWorks, Inc. 6 | 7 | supportedNetworks = ["YOLOv4-coco", "YOLOv4-tiny-coco"]; 8 | validatestring(modelName, supportedNetworks); 9 | 10 | dataPath = 'models'; 11 | netMatFileFullPath = fullfile(dataPath, [modelName, '.mat']); 12 | netZipFileFullPath = fullfile(dataPath, [modelName, '.zip']); 13 | 14 | if ~exist(netMatFileFullPath,'file') 15 | if ~exist(netZipFileFullPath,'file') 16 | fprintf(['Downloading pretrained ', modelName ,' network.\n']); 17 | fprintf('This can take several minutes to download...\n'); 18 | url = ['https://ssd.mathworks.com/supportfiles/vision/deeplearning/models/yolov4/', modelName, '.zip']; 19 | websave(netZipFileFullPath, url); 20 | fprintf('Done.\n\n'); 21 | unzip(netZipFileFullPath, dataPath); 22 | else 23 | fprintf(['Pretrained ', modelName, ' network already exists.\n\n']); 24 | unzip(netZipFileFullPath, dataPath); 25 | end 26 | else 27 | fprintf(['Pretrained ', modelName, ' network already exists.\n\n']); 28 | end 29 | 30 | model = load(netMatFileFullPath); 31 | end 32 | -------------------------------------------------------------------------------- /src/+helper/extractPredictions.m: -------------------------------------------------------------------------------- 1 | function predictions = extractPredictions(YPredictions, anchorBoxMask) 2 | % Function extractPrediction extracts and rearranges the prediction outputs 3 | % from YOLOv4 network. 4 | 5 | % Copyright 2021 The MathWorks, Inc. 6 | 7 | predictions = cell(size(YPredictions, 1),6); 8 | for ii = 1:size(YPredictions, 1) 9 | % Get the required info on feature size. 10 | numChannelsPred = size(YPredictions{ii},3); 11 | numAnchors = size(anchorBoxMask{ii},2); 12 | numPredElemsPerAnchors = numChannelsPred/numAnchors; 13 | allIds = (1:numChannelsPred); 14 | 15 | stride = numPredElemsPerAnchors; 16 | endIdx = numChannelsPred; 17 | 18 | % X positions. 19 | startIdx = 1; 20 | predictions{ii,2} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:); 21 | xIds = startIdx:stride:endIdx; 22 | 23 | % Y positions. 
24 |     startIdx = 2;
25 |     predictions{ii,3} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
26 |     yIds = startIdx:stride:endIdx;
27 | 
28 |     % Width.
29 |     startIdx = 3;
30 |     predictions{ii,4} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
31 |     wIds = startIdx:stride:endIdx;
32 | 
33 |     % Height.
34 |     startIdx = 4;
35 |     predictions{ii,5} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
36 |     hIds = startIdx:stride:endIdx;
37 | 
38 |     % Confidence scores.
39 |     startIdx = 5;
40 |     predictions{ii,1} = YPredictions{ii}(:,:,startIdx:stride:endIdx,:);
41 |     confIds = startIdx:stride:endIdx;
42 | 
43 |     % Accumulate all the non-class indexes.
44 |     nonClassIds = [xIds yIds wIds hIds confIds];
45 | 
46 |     % Class probabilities.
47 |     % Get the indexes which do not belong to the nonClassIds.
48 |     classIdx = setdiff(allIds,nonClassIds);
49 |     predictions{ii,6} = YPredictions{ii}(:,:,classIdx,:);
50 | end
51 | end
--------------------------------------------------------------------------------
/src/+helper/generateTargets.m:
--------------------------------------------------------------------------------
1 | function [boxDeltaTarget, objectnessTarget, classTarget, maskTarget, boxErrorScaleTarget] = generateTargets(YPredCellGathered, groundTruth, inputImageSize, anchorBoxes, anchorBoxMask, penaltyThreshold)
2 | % generateTargets creates the target array for every prediction element:
3 | % x, y, width, height, confidence scores and class probabilities.
4 | 
5 | % Copyright 2021 The MathWorks, Inc.
6 | 
7 | boxDeltaTarget = cell(size(YPredCellGathered,1),4);
8 | objectnessTarget = cell(size(YPredCellGathered,1),1);
9 | classTarget = cell(size(YPredCellGathered,1),1);
10 | maskTarget = cell(size(YPredCellGathered,1),3);
11 | boxErrorScaleTarget = cell(size(YPredCellGathered,1),1);
12 | 
13 | % Normalize the ground truth boxes w.r.t. the image input size.
14 | gtScale = [inputImageSize(2) inputImageSize(1) inputImageSize(2) inputImageSize(1)];
15 | groundTruth(:,1:4,:,:) = groundTruth(:,1:4,:,:)./gtScale;
16 | 
17 | scale_X_Y = [1.2 1.1 1.05];
18 | 
19 | for numPred = 1:size(YPredCellGathered,1)
20 | 
21 |     % Select anchor boxes based on anchor box mask indices.
22 |     anchors = anchorBoxes(anchorBoxMask{numPred},:);
23 |     scaleXY = scale_X_Y(numPred);
24 | 
25 |     bx = YPredCellGathered{numPred,2};
26 |     by = YPredCellGathered{numPred,3};
27 |     bw = YPredCellGathered{numPred,4};
28 |     bh = YPredCellGathered{numPred,5};
29 |     predClasses = YPredCellGathered{numPred,6};
30 | 
31 |     gridSize = size(bx);
32 |     if numel(gridSize) == 3
33 |         gridSize(4) = 1;
34 |     end
35 |     numClasses = size(predClasses,3)/size(anchors,1);
36 | 
37 |     % Initialize the required variables.
38 |     mask = single(zeros(size(bx)));
39 |     confMask = single(ones(size(bx)));
40 |     classMask = single(zeros(size(predClasses)));
41 |     tx = single(zeros(size(bx)));
42 |     ty = single(zeros(size(by)));
43 |     tw = single(zeros(size(bw)));
44 |     th = single(zeros(size(bh)));
45 |     tconf = single(zeros(size(bx)));
46 |     tclass = single(zeros(size(predClasses)));
47 |     boxErrorScale = single(ones(size(bx)));
48 | 
49 |     % Get the IoU of predictions with the ground truth.
50 |     iou = getMaxIOUPredictedWithGroundTruth(bx,by,bw,bh,groundTruth);
51 | 
52 |     % Do not penalize predictions that have an IoU greater than the
53 |     % penalty threshold.
54 |     confMask(iou > penaltyThreshold) = 0;
55 | 
56 |     for batch = 1:gridSize(4)
57 |         truthBatch = groundTruth(:,1:5,:,batch);
58 |         truthBatch = truthBatch(all(truthBatch,2),:);
59 | 
60 |         % Get boxes with their centers at zero.
61 | gtPred = [0-truthBatch(:,3)/2,0-truthBatch(:,4)/2,truthBatch(:,3),truthBatch(:,4)]; 62 | anchorPrior = [0-anchorBoxes(:,2)/(2*inputImageSize(2)),0-anchorBoxes(:,1)/(2*inputImageSize(1)),anchorBoxes(:,2)/inputImageSize(2),anchorBoxes(:,1)/inputImageSize(1)]; 63 | 64 | % Get the iou of best matching anchor box. 65 | overLap = bboxOverlapRatio(gtPred,anchorPrior); 66 | [~,bestAnchorIdx] = max(overLap,[],2); 67 | 68 | % Select gt that are within the mask. 69 | index = ismember(bestAnchorIdx,anchorBoxMask{numPred}); 70 | truthBatch = truthBatch(index,:); 71 | bestAnchorIdx = bestAnchorIdx(index,:); 72 | bestAnchorIdx = bestAnchorIdx - anchorBoxMask{numPred}(1,1) + 1; 73 | 74 | if ~isempty(truthBatch) 75 | % Convert top left position of ground-truth to centre coordinates. 76 | truthBatch = [truthBatch(:,1)+truthBatch(:,3)./2,truthBatch(:,2)+truthBatch(:,4)./2,truthBatch(:,3),truthBatch(:,4),truthBatch(:,5)]; 77 | 78 | errorScale = 2 - truthBatch(:,3).*truthBatch(:,4); 79 | truthBatch = [truthBatch(:,1)*gridSize(2),truthBatch(:,2)*gridSize(1),truthBatch(:,3)*inputImageSize(2),truthBatch(:,4)*inputImageSize(1),truthBatch(:,5)]; 80 | for t = 1:size(truthBatch,1) 81 | 82 | % Get the position of ground-truth box in the grid. 83 | colIdx = ceil(truthBatch(t,1)); 84 | colIdx(colIdx<1) = 1; 85 | colIdx(colIdx>gridSize(2)) = gridSize(2); 86 | rowIdx = ceil(truthBatch(t,2)); 87 | rowIdx(rowIdx<1) = 1; 88 | rowIdx(rowIdx>gridSize(1)) = gridSize(1); 89 | pos = [rowIdx,colIdx]; 90 | anchorIdx = bestAnchorIdx(t,1); 91 | 92 | mask(pos(1,1),pos(1,2),anchorIdx,batch) = 1; 93 | confMask(pos(1,1),pos(1,2),anchorIdx,batch) = 1; 94 | 95 | % Calculate the shift in ground-truth boxes. 96 | tShiftX = truthBatch(t,1)-pos(1,2)+1; 97 | tShiftY = truthBatch(t,2)-pos(1,1)+1; 98 | tShiftW = log(truthBatch(t,3)/anchors(anchorIdx,2)); 99 | tShiftH = log(truthBatch(t,4)/anchors(anchorIdx,1)); 100 | 101 | % Update the target box. 102 | tx(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftX; 103 | ty(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftY; 104 | tw(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftW; 105 | th(pos(1,1),pos(1,2),anchorIdx,batch) = tShiftH; 106 | boxErrorScale(pos(1,1),pos(1,2),anchorIdx,batch) = errorScale(t); 107 | tconf(rowIdx,colIdx,anchorIdx,batch) = 1; 108 | classIdx = (numClasses*(anchorIdx-1))+truthBatch(t,5); 109 | tclass(rowIdx,colIdx,classIdx,batch) = 1; 110 | classMask(rowIdx,colIdx,(numClasses*(anchorIdx-1))+(1:numClasses),batch) = 1; 111 | end 112 | end 113 | end 114 | boxDeltaTarget(numPred,:) = [{tx} {ty} {tw} {th}]; 115 | objectnessTarget{numPred,1} = tconf; 116 | classTarget{numPred,1} = tclass; 117 | maskTarget(numPred,:) = [{mask} {confMask} {classMask}]; 118 | boxErrorScaleTarget{numPred,:} = boxErrorScale; 119 | end 120 | end 121 | 122 | function iou = getMaxIOUPredictedWithGroundTruth(predx,predy,predw,predh,truth) 123 | % getMaxIOUPredictedWithGroundTruth computes the maximum intersection over 124 | % union scores for every pair of predictions and ground-truth boxes. 125 | 126 | [h,w,c,n] = size(predx); 127 | iou = zeros([h w c n],'like',predx); 128 | 129 | % For each batch prepare the predictions and ground-truth. 
130 | for batchSize = 1:n 131 | truthBatch = truth(:,1:4,1,batchSize); 132 | truthBatch = truthBatch(all(truthBatch,2),:); 133 | predxb = predx(:,:,:,batchSize); 134 | predyb = predy(:,:,:,batchSize); 135 | predwb = predw(:,:,:,batchSize); 136 | predhb = predh(:,:,:,batchSize); 137 | predb = [predxb(:),predyb(:),predwb(:),predhb(:)]; 138 | 139 | % Convert from center xy coordinate to topleft xy coordinate. 140 | predb = [predb(:,1)-predb(:,3)./2, predb(:,2)-predb(:,4)./2, predb(:,3), predb(:,4)]; 141 | 142 | % Compute and extract the maximum IOU of predictions with ground-truth. 143 | try 144 | overlap = bboxOverlapRatio(predb, truthBatch); 145 | catch me 146 | if(any(isnan(predb(:))|isinf(predb(:)))) 147 | error(me.message + " NaN/Inf has been detected during training. Try reducing the learning rate."); 148 | elseif(any(predb(:,3)<=0 | predb(:,4)<=0)) 149 | error(me.message + " Invalid predictions during training. Try reducing the learning rate."); 150 | else 151 | error(me.message + " Invalid groundtruth. Check that your ground truth boxes are not empty and finite, are fully contained within the image boundary, and have positive width and height."); 152 | end 153 | end 154 | 155 | maxOverlap = max(overlap,[],2); 156 | iou(:,:,:,batchSize) = reshape(maxOverlap,h,w,c); 157 | end 158 | end 159 | -------------------------------------------------------------------------------- /src/+helper/generateTiledAnchors.m: -------------------------------------------------------------------------------- 1 | function tiledAnchors = generateTiledAnchors(YPredCell,anchorBoxes,anchorBoxMask) 2 | % Generate tiled anchor offset. 3 | 4 | % Copyright 2021 The MathWorks, Inc. 5 | 6 | tiledAnchors = cell(size(YPredCell)); 7 | for i=1:size(YPredCell,1) 8 | anchors = anchorBoxes(anchorBoxMask{i}, :); 9 | [h,w,~,n] = size(YPredCell{i,1}); 10 | [tiledAnchors{i,2}, tiledAnchors{i,1}] = ndgrid(0:h-1,0:w-1,1:size(anchors,1),1:n); 11 | [~,~,tiledAnchors{i,3}] = ndgrid(0:h-1,0:w-1,anchors(:,2),1:n); 12 | [~,~,tiledAnchors{i,4}] = ndgrid(0:h-1,0:w-1,anchors(:,1),1:n); 13 | end 14 | end -------------------------------------------------------------------------------- /src/+helper/getAnchors.m: -------------------------------------------------------------------------------- 1 | function anchors = getAnchors(modelName) 2 | % getAnchors function returns the anchors used in the training of the 3 | % specified pretrained YOLO v4 model. 4 | 5 | % Copyright 2021 The MathWorks, Inc. 6 | 7 | if isequal(modelName, 'YOLOv4-coco') 8 | anchors.anchorBoxes = [16 12; 36 19; 28 40;... 9 | 75 36; 55 76; 146 72;... 10 | 110 142; 243 192; 401 459]; 11 | anchors.anchorBoxMasks = {[1,2,3] 12 | [4,5,6] 13 | [7,8,9]}; 14 | elseif isequal(modelName, 'YOLOv4-tiny-coco') 15 | anchors.anchorBoxes = [82 81; 169 135; 319 344;... 16 | 27 23; 58 37; 82 81]; 17 | anchors.anchorBoxMasks = {[1,2,3] 18 | [4,5,6]}; 19 | end 20 | end -------------------------------------------------------------------------------- /src/+helper/getCOCOClassNames.m: -------------------------------------------------------------------------------- 1 | function classNames = getCOCOClassNames() 2 | % getCOCOClassNames function returns the names of COCO dataset classes. 3 | 4 | % Copyright 2021 The MathWorks, Inc. 
5 | 6 | classNames={'person' 7 | 'bicycle' 8 | 'car' 9 | 'motorbike' 10 | 'aeroplane' 11 | 'bus' 12 | 'train' 13 | 'truck' 14 | 'boat' 15 | 'traffic light' 16 | 'fire hydrant' 17 | 'stop sign' 18 | 'parking meter' 19 | 'bench' 20 | 'bird' 21 | 'cat' 22 | 'dog' 23 | 'horse' 24 | 'sheep' 25 | 'cow' 26 | 'elephant' 27 | 'bear' 28 | 'zebra' 29 | 'giraffe' 30 | 'backpack' 31 | 'umbrella' 32 | 'handbag' 33 | 'tie' 34 | 'suitcase' 35 | 'frisbee' 36 | 'skis' 37 | 'snowboard' 38 | 'sports ball' 39 | 'kite' 40 | 'baseball bat' 41 | 'baseball glove' 42 | 'skateboard' 43 | 'surfboard' 44 | 'tennis racket' 45 | 'bottle' 46 | 'wine glass' 47 | 'cup' 48 | 'fork' 49 | 'knife' 50 | 'spoon' 51 | 'bowl' 52 | 'banana' 53 | 'apple' 54 | 'sandwich' 55 | 'orange' 56 | 'broccoli' 57 | 'carrot' 58 | 'hot dog' 59 | 'pizza' 60 | 'donut' 61 | 'cake' 62 | 'chair' 63 | 'sofa' 64 | 'pottedplant' 65 | 'bed' 66 | 'diningtable' 67 | 'toilet' 68 | 'tvmonitor' 69 | 'laptop' 70 | 'mouse' 71 | 'remote' 72 | 'keyboard' 73 | 'cell phone' 74 | 'microwave' 75 | 'oven' 76 | 'toaster' 77 | 'sink' 78 | 'refrigerator' 79 | 'book' 80 | 'clock' 81 | 'vase' 82 | 'scissors' 83 | 'teddy bear' 84 | 'hair drier' 85 | 'toothbrush'}; 86 | end -------------------------------------------------------------------------------- /src/+helper/piecewiseLearningRateWithWarmup.m: -------------------------------------------------------------------------------- 1 | function currentLR = piecewiseLearningRateWithWarmup(iteration, epoch, learningRate, warmupPeriod, numEpochs) 2 | % The piecewiseLearningRateWithWarmup function computes the current 3 | % learning rate based on the iteration number. 4 | 5 | % Copyright 2021 The MathWorks, Inc. 6 | 7 | persistent warmUpEpoch; 8 | 9 | if iteration <= warmupPeriod 10 | % Increase the learning rate for number of iterations in warmup period. 11 | currentLR = learningRate * ((iteration/warmupPeriod)^4); 12 | warmUpEpoch = epoch; 13 | elseif iteration >= warmupPeriod && epoch < warmUpEpoch+floor(0.6*(numEpochs-warmUpEpoch)) 14 | % After warm up period, keep the learning rate constant if the remaining number of epochs is less than 60 percent. 15 | currentLR = learningRate; 16 | 17 | elseif epoch >= warmUpEpoch + floor(0.6*(numEpochs-warmUpEpoch)) && epoch < warmUpEpoch+floor(0.9*(numEpochs-warmUpEpoch)) 18 | % If the remaining number of epochs is more than 60 percent but less 19 | % than 90 percent multiply the learning rate by 0.1. 20 | currentLR = learningRate*0.1; 21 | 22 | else 23 | % If remaining epochs are more than 90 percent multiply the learning 24 | % rate by 0.01. 25 | currentLR = learningRate*0.01; 26 | end 27 | end -------------------------------------------------------------------------------- /src/+helper/postprocess.m: -------------------------------------------------------------------------------- 1 | function [bboxes,scores,labels] = postprocess(outFeatureMaps, anchors, netInputSize, scale, classNames) 2 | % The postprocess function applies postprocessing on the generated feature 3 | % maps and returns bounding boxes, detection scores and labels. 4 | 5 | % Copyright 2021 The MathWorks, Inc. 6 | 7 | % Get number of classes. 8 | if isrow(classNames) 9 | classNames = classNames'; 10 | end 11 | classNames = categorical(classNames); 12 | numClasses = size(classNames,1); 13 | 14 | % Get anchor boxes and anchor boxes masks. 15 | anchorBoxes = anchors.anchorBoxes; 16 | anchorBoxMasks = anchors.anchorBoxMasks; 17 | 18 | % Postprocess generated feature maps. 
19 | outputFeatures = []; 20 | for i = 1:size(outFeatureMaps,1) 21 | currentFeatureMap = outFeatureMaps{i}; 22 | numY = size(currentFeatureMap,1); 23 | numX = size(currentFeatureMap,2); 24 | stride = max(netInputSize)./max(numX, numY); 25 | batchsize = size(currentFeatureMap,4); 26 | h = numY; 27 | w = numX; 28 | numAnchors = size(anchorBoxMasks{i},2); 29 | 30 | currentFeatureMap = reshape(currentFeatureMap,h,w,5+numClasses,numAnchors,batchsize); 31 | currentFeatureMap = permute(currentFeatureMap,[5,4,1,2,3]); 32 | 33 | [~,~,yv,xv] = ndgrid(1:batchsize,1:numAnchors,0:h-1,0:w-1); 34 | gridXY = cat(5,xv,yv); 35 | currentFeatureMap(:,:,:,:,1:2) = sigmoid(currentFeatureMap(:,:,:,:,1:2)) + gridXY; 36 | anchorBoxesCurrentLevel= anchorBoxes(anchorBoxMasks{i}, :); 37 | anchorBoxesCurrentLevel(:,[2,1]) = anchorBoxesCurrentLevel(:,[1,2]); 38 | anchor_grid = anchorBoxesCurrentLevel/stride; 39 | anchor_grid = reshape(anchor_grid,1,numAnchors,1,1,2); 40 | currentFeatureMap(:,:,:,:,3:4) = exp(currentFeatureMap(:,:,:,:,3:4)).*anchor_grid; 41 | currentFeatureMap(:,:,:,:,1:4) = currentFeatureMap(:,:,:,:,1:4)*stride; 42 | currentFeatureMap(:,:,:,:,5:end) = sigmoid(currentFeatureMap(:,:,:,:,5:end)); 43 | 44 | if numClasses == 1 45 | currentFeatureMap(:,:,:,:,6) = 1; 46 | end 47 | currentFeatureMap = reshape(currentFeatureMap,batchsize,[],5+numClasses); 48 | 49 | if isempty(outputFeatures) 50 | outputFeatures = currentFeatureMap; 51 | else 52 | outputFeatures = cat(2,outputFeatures,currentFeatureMap); 53 | end 54 | end 55 | 56 | % Coordinate conversion to the original image. 57 | outputFeatures = extractdata(outputFeatures);% [x_center,y_center,w,h,Pobj,p1,p2,...,pn] 58 | outputFeatures(:,:,[1,3]) = outputFeatures(:,:,[1,3])*scale(2);% x_center,width 59 | outputFeatures(:,:,[2,4]) = outputFeatures(:,:,[2,4])*scale(1);% y_center,height 60 | outputFeatures(:,:,1) = outputFeatures(:,:,1) -outputFeatures(:,:,3)/2;% x 61 | outputFeatures(:,:,2) = outputFeatures(:,:,2) -outputFeatures(:,:,4)/2; % y 62 | outputFeatures = squeeze(outputFeatures); % If it is a single image detection, the output size is M*(5+numClasses), otherwise it is bs*M*(5+numClasses) 63 | 64 | if(canUseGPU()) 65 | outputFeatures = gather(outputFeatures); 66 | end 67 | 68 | % Apply Confidence threshold and Non-maximum suppression. 69 | confidenceThreshold = 0.5; 70 | overlapThresold = 0.5; 71 | 72 | scores = outputFeatures(:,5); 73 | outFeatures = outputFeatures(scores>confidenceThreshold,:); 74 | 75 | allBBoxes = outFeatures(:,1:4); 76 | allScores = outFeatures(:,5); 77 | [maxScores,indxs] = max(outFeatures(:,6:end),[],2); 78 | allScores = allScores.*maxScores; 79 | allLabels = classNames(indxs); 80 | 81 | bboxes = []; 82 | scores = []; 83 | labels = []; 84 | if ~isempty(allBBoxes) 85 | [bboxes,scores,labels] = selectStrongestBboxMulticlass(allBBoxes,allScores,allLabels,... 86 | 'RatioType','Min','OverlapThreshold',overlapThresold); 87 | end 88 | end 89 | -------------------------------------------------------------------------------- /src/+helper/preprocess.m: -------------------------------------------------------------------------------- 1 | function [output, scale] = preprocess(image, netInputSize) 2 | % The preprocess function applies preprocessing on the input image. 3 | 4 | % Copyright 2021 The MathWorks, Inc. 
5 | 6 | inputSize = [size(image,1),size(image,2)]; 7 | scale = inputSize./netInputSize(1:2); 8 | 9 | output = im2single(imresize(image,netInputSize(1:2))); 10 | end -------------------------------------------------------------------------------- /src/+helper/preprocessData.m: -------------------------------------------------------------------------------- 1 | function data = preprocessData(data, targetSize) 2 | % Resize the images and scale the pixels to between 0 and 1. Also scale the 3 | % corresponding bounding boxes. 4 | 5 | % Copyright 2021 The MathWorks, Inc. 6 | 7 | for ii = 1:size(data,1) 8 | I = data{ii,1}; 9 | imgSize = size(I); 10 | 11 | % Convert an input image with single channel to 3 channels. 12 | if numel(imgSize) < 3 13 | I = repmat(I,1,1,3); 14 | end 15 | bboxes = data{ii,2}; 16 | 17 | I = im2single(imresize(I,targetSize(1:2))); 18 | scale = targetSize(1:2)./imgSize(1:2); 19 | bboxes = bboxresize(bboxes,scale); 20 | 21 | data(ii, 1:2) = {I, bboxes}; 22 | end 23 | end -------------------------------------------------------------------------------- /src/+helper/updatePlots.m: -------------------------------------------------------------------------------- 1 | function updatePlots(lossPlotter, learningRatePlotter, iteration, currentLR, totalLoss) 2 | % Update loss and learning rate plots. 3 | 4 | % Copyright 2021 The MathWorks, Inc. 5 | 6 | addpoints(lossPlotter, iteration, double(extractdata(gather(totalLoss)))); 7 | addpoints(learningRatePlotter, iteration, currentLR); 8 | drawnow 9 | end -------------------------------------------------------------------------------- /src/+helper/validateInputData.m: -------------------------------------------------------------------------------- 1 | function validateInputData(ds) 2 | % Validates the input images, bounding boxes and labels and displays the 3 | % paths of invalid samples. 4 | 5 | % Copyright 2021 The MathWorks, Inc. 6 | 7 | % Path to images 8 | info = ds.UnderlyingDatastores{1}.Files; 9 | 10 | ds = transform(ds, @isValidDetectorData); 11 | data = readall(ds); 12 | 13 | validImgs = [data.validImgs]; 14 | validBoxes = [data.validBoxes]; 15 | validLabels = [data.validLabels]; 16 | 17 | msg = ""; 18 | 19 | if(any(~validImgs)) 20 | imPaths = info(~validImgs); 21 | str = strjoin(imPaths, '\n'); 22 | imErrMsg = sprintf("Input images must be non-empty and have 2 or 3 dimensions. The following images are invalid:\n") + str; 23 | msg = (imErrMsg + newline + newline); 24 | end 25 | 26 | if(any(~validBoxes)) 27 | imPaths = info(~validBoxes); 28 | str = strjoin(imPaths, '\n'); 29 | boxErrMsg = sprintf("Bounding box data must be M-by-4 matrices of positive integer values. The following images have invalid bounding box data:\n") ... 30 | + str; 31 | 32 | msg = (msg + boxErrMsg + newline + newline); 33 | end 34 | 35 | if(any(~validLabels)) 36 | imPaths = info(~validLabels); 37 | str = strjoin(imPaths, '\n'); 38 | labelErrMsg = sprintf("Labels must be non-empty and categorical. 
The following images have invalid labels:\n") + str; 39 | 40 | msg = (msg + labelErrMsg + newline); 41 | end 42 | 43 | if(~isempty(msg)) 44 | error(msg); 45 | end 46 | 47 | end 48 | 49 | function out = isValidDetectorData(data) 50 | % Checks validity of images, bounding boxes and labels 51 | for i = 1:size(data,1) 52 | I = data{i,1}; 53 | boxes = data{i,2}; 54 | labels = data{i,3}; 55 | 56 | imageSize = size(I); 57 | mSize = size(boxes, 1); 58 | 59 | out.validImgs(i) = iCheckImages(I); 60 | out.validBoxes(i) = iCheckBoxes(boxes, imageSize); 61 | out.validLabels(i) = iCheckLabels(labels, mSize); 62 | end 63 | 64 | end 65 | 66 | function valid = iCheckImages(I) 67 | % Validates the input images. 68 | 69 | valid = true; 70 | if ndims(I) == 2 71 | nDims = 2; 72 | else 73 | nDims = 3; 74 | end 75 | % Define image validation parameters. 76 | classes = {'numeric'}; 77 | attrs = {'nonempty', 'nonsparse', 'nonnan', 'finite', 'ndims', nDims}; 78 | try 79 | validateattributes(I, classes, attrs); 80 | catch 81 | valid = false; 82 | end 83 | end 84 | 85 | function valid = iCheckBoxes(boxes, imageSize) 86 | % Validates the ground-truth bounding boxes to be non-empty and finite. 87 | 88 | valid = true; 89 | % Define bounding box validation parameters. 90 | classes = {'numeric'}; 91 | attrs = {'nonempty', 'integer', 'nonnan', 'finite', 'positive', 'nonzero', 'nonsparse', '2d', 'ncols', 4}; 92 | try 93 | validateattributes(boxes, classes, attrs); 94 | % Validate if bounding box in within image boundary. 95 | validateattributes(boxes(:,1)+boxes(:,3)-1, classes, {'<=', imageSize(2)}); 96 | validateattributes(boxes(:,2)+boxes(:,4)-1, classes, {'<=', imageSize(1)}); 97 | catch 98 | valid = false; 99 | end 100 | end 101 | 102 | function valid = iCheckLabels(labels, mSize) 103 | % Validates the labels. 104 | 105 | valid = true; 106 | % Define label validation parameters. 107 | classes = {'categorical'}; 108 | attrs = {'nonempty', 'nonsparse', '2d', 'ncols', 1, 'nrows', mSize}; 109 | try 110 | validateattributes(labels, classes, attrs); 111 | catch 112 | valid = false; 113 | end 114 | end -------------------------------------------------------------------------------- /src/+helper/yolov4Forward.m: -------------------------------------------------------------------------------- 1 | function [YPredCell, state] = yolov4Forward(net, XTrain, networkOutputs, anchorBoxMask) 2 | % Predict the output of network and extract the confidence score, x, y, 3 | % width, height, and class. 4 | 5 | % Copyright 2021 The MathWorks, Inc. 6 | 7 | YPredictions = cell(size(networkOutputs)); 8 | [YPredictions{:}, state] = forward(net, XTrain, 'Outputs', networkOutputs); 9 | YPredCell = helper.extractPredictions(YPredictions, anchorBoxMask); 10 | 11 | % Append predicted width and height to the end as they are required 12 | % for computing the loss. 13 | YPredCell(:,7:8) = YPredCell(:,4:5); 14 | 15 | % Apply sigmoid and exponential activation. 16 | YPredCell(:,1:6) = helper.applyActivations(YPredCell(:,1:6)); 17 | end --------------------------------------------------------------------------------
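
Taken together, these helpers cover the full inference path for the pretrained models: download, preprocessing, the forward pass, and box decoding. The sketch below shows one plausible way to wire them up from the repository root; the field name `model.net` is an assumption, so inspect the loaded struct if your MAT-file uses a different name.

```
% Download (or load) a pretrained model and pull out the network.
addpath('src');
modelName = 'YOLOv4-coco';
model = helper.downloadPretrainedYOLOv4(modelName);
net = model.net;   % assumed field name; inspect `model` after loading

% Preprocess the test image to the network input size. Keep the scale so
% detections can be mapped back to the original image coordinates.
img = imread('visionteam.jpg');
inputSize = net.Layers(1).InputSize;
[Ipre, scale] = helper.preprocess(img, inputSize);
dlX = dlarray(Ipre, 'SSCB');

% Run the network and collect the multi-scale output feature maps.
outFeatureMaps = cell(numel(net.OutputNames), 1);
[outFeatureMaps{:}] = predict(net, dlX);

% Decode the boxes, apply the confidence threshold, and run NMS.
anchors = helper.getAnchors(modelName);
classNames = helper.getCOCOClassNames();
[bboxes, scores, labels] = helper.postprocess(outFeatureMaps, anchors, ...
    inputSize, scale, classNames);
```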