├── README.md
├── README_images
│   ├── figure_1.png
│   ├── figure_2.png
│   ├── image_0.png
│   ├── image_1.png
│   ├── image_2.png
│   ├── image_3.png
│   └── image_4.png
├── changeLog.txt
├── license.txt
├── multiInputCNN.mlx
├── partitionData.m
└── prepareDigitDataset.m
/README.md:
--------------------------------------------------------------------------------
1 | [![View Image Classification using CNN with Multi Input 複数の入力層を持つCNN on File Exchange](https://www.mathworks.com/matlabcentral/images/matlab-file-exchange.svg)](https://jp.mathworks.com/matlabcentral/fileexchange/74760-image-classification-using-cnn-with-multi-input-cnn)
2 | # Image Classification using Convolutional Neural Network with Multi-Input
3 |
4 | **[English]**
5 | This demo shows how to implement a convolutional neural network (CNN) for image classification with multiple inputs using the `custom training loop` method. As an example, MATLAB's built-in dataset of handwritten digit images (similar to MNIST) was split into an upper half and a bottom half as shown below, and the two halves were fed into the two input layers of the CNN.
6 | **[Japanese]**
7 | This is a demo of a convolutional neural network for classification that accepts two kinds of images. It has two input layers: for example, input layer A could take an image of an animal's face while input layer B takes an image of that animal's feet. Since R2019b, the so-called custom training loop has been available, which allows deep learning workflows to be customized in much finer detail. To keep the demo easy to try, the upper and lower halves of handwritten digits are fed through separate input layers, the features obtained after the convolutional stages are concatenated, and the computation then continues through fully connected layers. If you know of data or tasks better suited to this example, I would be glad to hear about them. There is still room for improvement, and I hope to keep updating this demo.
8 |
9 | ![image_0.png](README_images/image_0.png)
10 |
11 | ![image_1.png](README_images/image_1.png)
12 |
13 | ![image_2.png](README_images/image_2.png)
14 |
15 | The figure above shows the classification accuracy of the multi-input CNN. When both the top and bottom halves of the digits were fed into the multi-input CNN, the accuracy was over 96%. When only the top or only the bottom half was used, the accuracy was clearly lower than with both inputs.
16 |
17 |
18 |
19 | ![image_3.png](README_images/image_3.png)
20 |
21 | Note that this figure is taken from ref [1]. That paper deals with video classification rather than still-image classification, but the fusion models it describes are very informative. To my understanding, this demo is closest to late fusion (but please confirm this for yourself). Other types of fusion may be implemented in future work. Ref [2] proposes a deep learning model called TM-CNN for multi-lane traffic speed prediction, which is also related to this demo.
22 |
23 | [1] Karpathy, A., Toderici, G., Shetty, S., Leung, T., Sukthankar, R., \& Fei-Fei, L. (2014). Large-scale video classification with convolutional neural networks. In *Proceedings of the IEEE conference on Computer Vision and Pattern Recognition* (pp. 1725-1732).
24 |
25 | [2] Ke, R., Li, W., Cui, Z., \& Wang, Y. (2019). Two-stream multi-channel convolutional neural network (TM-CNN) for multi-lane traffic speed prediction considering traffic volume impact. *arXiv preprint arXiv:1903.01678*.
26 |
27 | # Data preparation
28 |
29 | This section saves the digit dataset into sub-folders. Use the `prepareDigitDataset` function to create the `upperHalf` and `bottomHalf` folders.
30 |
31 | ```matlab:Code
32 | clear;clc;close all
33 | if exist('bottomHalf','dir')~=7 % prepare the dataset only if the folders do not exist yet
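    % (exist(...,'dir') returns 7 when the folder exists, so ~=7 means the
    % prepared dataset is missing and must be generated first)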
34 |     disp('Preparing demo dataset for this script')
35 |     prepareDigitDataset
36 | end
37 | ```
38 |
39 | # Store the images in an `imageDatastore`
40 |
41 | ```matlab:Code
42 | inputSize=[14 28];
43 | firstFolderName='upperHalf';
44 | secondFolderName='bottomHalf';
45 | imdsUpper = imageDatastore(strcat(firstFolderName,filesep), 'IncludeSubfolders',true, 'LabelSource','foldernames');
46 | imdsBottom = imageDatastore(strcat(secondFolderName,filesep), 'IncludeSubfolders',true, 'LabelSource','foldernames');
47 | augmenter = imageDataAugmenter('RandXReflection',false);
48 | augimdsUpper = augmentedImageDatastore(inputSize,imdsUpper,'DataAugmentation',augmenter);
49 | augimdsBottom = augmentedImageDatastore(inputSize,imdsBottom,'DataAugmentation',augmenter);
50 | numAll=numel(imdsBottom.Files);
51 | ```
52 |
53 | # Divide into training, validation and test datasets
54 |
55 | ```matlab:Code
56 | % The ratios are specified here
57 | TrainRatio=0.8;
58 | ValidRatio=0.1;
59 | TestRatio=1-TrainRatio-ValidRatio;
60 | ```
61 |
62 | Use the helper function `partitionData`. It splits the dataset according to the ratios defined above.
63 |
64 | ```matlab:Code
65 | [XTrainUpper,XTrainBottom,XValidUpper,XValidBottom,XTestUpper,XTestBottom,YTrain,YValid,YTest]=partitionData(augimdsUpper,augimdsBottom,TrainRatio,ValidRatio,numAll,imdsUpper.Labels);
66 | classes = categories(YTrain); % retrieve the class names
67 | numClasses = numel(classes); % the number of classes
68 | ```
69 |
70 | # Define the convolutional neural network model
71 |
72 | ![image_4.png](README_images/image_4.png)
73 |
74 | ```matlab:Code
75 | numHiddenDimension=20; % specify the dimension of the hidden layer
76 | layers = createSimpleLayer(XTrainUpper,numHiddenDimension);
77 | layers2 = createSimpleLayer(XTrainBottom,numHiddenDimension);
78 | ```
79 |
80 | When the two branches are merged, duplicate layer names are not allowed. Use the `renameLayer` helper function to append a suffix to the layer names in `layers2`.
81 |
82 | ```matlab:Code
83 | layers2=renameLayer(layers2,'_2');
84 | layersAdd=[fullyConnectedLayer(20,'Name','fcAdd1')
85 |     fullyConnectedLayer(numClasses,'Name','fcAdd2')];
86 | layersRemoved=[layers(1:end);concatenationLayer(1,2,'Name','cat');layersAdd];
87 | lgraphAggregated = addLayers(layerGraph(layersRemoved),layers2(1:end));
88 | lgraphAggregated = connectLayers(lgraphAggregated,'fc_2','cat/in2');
89 | ```
90 |
91 | Convert into a deep learning network for custom training loops using `dlnetwork`.
92 |
93 | ```matlab:Code
94 | dlnet = dlnetwork(lgraphAggregated); % A dlnetwork object enables support for custom training loops using automatic differentiation
95 | ```
96 |
97 | # Specify training options
98 |
99 | ```matlab:Code
100 | miniBatchSize = 16; % mini-batch size. If you run out of memory, decrease this value (e.g., to 4)
101 | numEpochs = 30; % maximum number of epochs
102 | numObservations = numel(YTrain); % the number of training observations
103 | numIterationsPerEpoch = floor(numObservations./miniBatchSize); % number of iterations per epoch
104 | executionEnvironment = "gpu"; % set to "gpu" to train on a GPU ("auto" uses a GPU if one is available)
105 | ```
106 |
107 | Initial settings for the `Adam` optimizer.
108 |
109 | ```matlab:Code
110 | averageGrad = [];
111 | averageSqGrad = [];
112 | iteration = 1; % initialize the iteration counter
113 | ```
114 |
115 | # Create an `animatedline`
116 |
117 | `animatedline` creates an animated line that has no data and adds it to the current axes. Create an animation by adding points to the line in a loop using the `addpoints` function.
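As a tiny, self-contained illustration of this pattern (not part of the training script, and independent of the variables used in this demo), an animated line can be grown point by point inside a loop:

```matlab:Code
% Minimal animatedline sketch: plot a sine wave as it is "computed"
figure
h = animatedline('Color','r');
xlabel("Iteration"); ylabel("Value")
for k = 1:100
    addpoints(h,k,sin(k/10)); % append one new point per iteration
    drawnow limitrate         % refresh the figure without slowing the loop too much
end
```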
118 |
119 | ```matlab:Code
120 | plots = "training-progress";
121 | if plots == "training-progress"
122 |     f1=figure;
123 |     lineLossTrain = animatedline('Color','r');
124 |     xlabel("Total Iterations")
125 |     ylabel("Loss"); lineLossValid = animatedline('Color','b');
126 |     legend("Training loss","Validation loss")
127 | end
128 | ```
129 |
130 | # Prepare the validation data
131 |
132 | The validation data is used during training to monitor the network's performance.
133 |
134 | ```matlab:Code
135 | YValidPlot=zeros(numClasses,numel(YValid),'single');
136 | for c = 1:numClasses
137 |     YValidPlot(c,YValid==classes(c)) = 1;
138 | end
139 | % Convert the validation data to a dlarray.
140 | dlXValidUpper=dlarray(single(XValidUpper),'SSCB');
141 | dlXValidBottom=dlarray(single(XValidBottom),'SSCB');
142 |
143 | % If training on a GPU, then convert the validation data to a gpuArray.
144 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
145 |     dlXValidUpper = gpuArray(dlXValidUpper);
146 |     dlXValidBottom = gpuArray(dlXValidBottom);
147 | end
148 | ```
149 |
150 | # Train network in custom training loop
151 |
152 | ```matlab:Code
153 | for epoch = 1:numEpochs
154 |     % Shuffle data.
155 |     idx = randperm(numel(YTrain));
156 |     XTrainUpper = XTrainUpper(:,:,:,idx);
157 |     XTrainBottom = XTrainBottom(:,:,:,idx);
158 |     YTrain=YTrain(idx);
159 |
160 |     for i = 1:numIterationsPerEpoch
161 |
162 |         % Read mini-batch of data and convert the labels to dummy
163 |         % variables.
164 |         idx = (i-1)*miniBatchSize+1:i*miniBatchSize;
165 |         XUpper = XTrainUpper(:,:,:,idx);
166 |         XBottom = XTrainBottom(:,:,:,idx);
167 |
168 |         Y = zeros(numClasses, miniBatchSize, 'single');
169 |         for c = 1:numClasses
170 |             Y(c,YTrain(idx)==classes(c)) = 1;
171 |         end
172 |
173 |         % Convert mini-batch of data to a dlarray.
174 |         dlXUpper = dlarray(single(XUpper),'SSCB');
175 |         dlXBottom = dlarray(single(XBottom),'SSCB');
176 |
177 |         % If training on a GPU, then convert data to a gpuArray.
178 |         if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
179 |             dlXUpper = gpuArray(dlXUpper);
180 |             dlXBottom = gpuArray(dlXBottom);
181 |         end
182 |
183 |         % Evaluate the model gradients and loss using dlfeval and the
184 |         % modelGradientsMulti helper function.
185 |         [grad,loss] = dlfeval(@modelGradientsMulti,dlnet,dlXUpper,dlXBottom,Y);
186 |         lossValid = modelLossMulti(dlnet,dlXValidUpper,dlXValidBottom,YValidPlot); % validation loss for monitoring
187 |         % Update the network parameters using the Adam optimizer.
188 |         [dlnet,averageGrad,averageSqGrad] = adamupdate(dlnet,grad,averageGrad,averageSqGrad,iteration,0.0005); % learning rate = 0.0005
189 |
190 |         % Display the training progress.
191 |         if plots == "training-progress"
192 |             addpoints(lineLossTrain,iteration,double(gather(extractdata(loss))))
193 |             addpoints(lineLossValid,iteration,double(gather(extractdata(lossValid))))
194 |             title("Training and Validation Loss: Epoch - " + epoch + "; Iteration - " + i)
195 |             drawnow
196 |
197 |         end
198 |
199 |         % Increment the iteration counter.
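        % (The iteration count is also passed to adamupdate above as the
        % global time step used for Adam's moment bias correction, so it
        % keeps increasing across epochs rather than resetting each epoch.)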
200 |         iteration = iteration + 1;
201 |     end
202 | end
203 | ```
204 |
205 | ![figure_1.png](README_images/figure_1.png)
206 |
207 | # Compute classification accuracy
208 |
209 | ```matlab:Code
210 | dlXTestUpper = dlarray(single(XTestUpper),'SSCB');
211 | dlXTestBottom = dlarray(single(XTestBottom),'SSCB');
212 | ```
213 |
214 | Convert the test data into `gpuArray` to accelerate the computation on a GPU.
215 |
216 | ```matlab:Code
217 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
218 |     dlXTestUpper = gpuArray(dlXTestUpper);
219 |     dlXTestBottom = gpuArray(dlXTestBottom);
220 | end
221 | ```
222 |
223 | Two similar functions are available in MATLAB for computing the output of a deep learning network.
224 |
225 | `predict`: Compute deep learning network output for inference
226 |
227 | `forward`: Compute deep learning network output for training
228 |
229 | The difference is whether the network runs in training or inference mode. During training (`forward`), dropout is active and batch normalization uses the statistics of the current mini-batch; at inference time (`predict`), dropout is disabled and batch normalization uses its learned statistics.
230 |
231 | ```matlab:Code
232 | dlYPred = predict(dlnet,dlXTestUpper,dlXTestBottom); % use predict for testing
233 | [~,idx] = max(extractdata(dlYPred),[],1); % extract the class with the highest score
234 | YPred = classes(idx);
235 | ```
236 |
237 | Calculate the overall accuracy.
238 |
239 | ```matlab:Code
240 | accuracy = mean(YPred==YTest)
241 | ```
242 |
243 | ```text:Output
244 | accuracy = 0.9720
245 | ```
246 |
247 | Display the confusion matrix.
248 |
249 | ```matlab:Code
250 | confusionchart(YTest,categorical(cellstr(YPred)))
251 | ```
252 |
253 | ![figure_2.png](README_images/figure_2.png)
254 |
255 | # Helper functions
256 |
257 | ```matlab:Code
258 | function layers=createSimpleLayer(XTrainData_4D,numHiddenDimension)
259 | layers = [
260 |     imageInputLayer([14 28 3],"Name","imageinput","Mean",mean(XTrainData_4D,4)) % zero-center using the training-set mean
261 |     convolution2dLayer([3 3],8,"Name","conv_1","Padding","same")
262 |     reluLayer("Name","relu_1")
263 |     maxPooling2dLayer([2 2],"Name","maxpool_1","Stride",[2 2])
264 |     convolution2dLayer([3 3],16,"Name","conv_2","Padding","same")
265 |     reluLayer("Name","relu_2")
266 |     maxPooling2dLayer([2 2],"Name","maxpool_2","Stride",[2 2])
267 |     convolution2dLayer([3 3],32,"Name","conv_3","Padding","same")
268 |     reluLayer("Name","relu_3")
269 |     fullyConnectedLayer(numHiddenDimension,"Name","fc")];
270 | end
271 |
272 | function [gradients,loss] = modelGradientsMulti(dlnet,dlXupper,dlXBottom,Y)
273 | % Forward pass through both input branches, followed by softmax.
274 | dlYPred = forward(dlnet,dlXupper,dlXBottom);
275 | dlYPred = softmax(dlYPred);
276 |
277 | loss = crossentropy(dlYPred,Y);
278 | gradients = dlgradient(loss,dlnet.Learnables);
279 |
280 | end
281 |
282 | function layers=renameLayer(layers,suffix)
283 | for i=1:numel(layers)
284 |     layers(i).Name=[layers(i).Name,suffix];
285 | end
286 | end
287 |
288 | function loss = modelLossMulti(dlnet,dlXUpper,dlXBottom,Y)
289 | dlYPred = forward(dlnet,dlXUpper,dlXBottom);
290 | dlYPred = softmax(dlYPred);
291 | loss = crossentropy(dlYPred,Y);
292 | end
293 | ```
294 |
--------------------------------------------------------------------------------
/README_images/figure_1.png:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KentaItakura/Image-Classification-using-CNN-with-Multi-Input-using-MATLAB/ce087863ae15670bd7f7fb89bc030d4b6066e666/README_images/figure_1.png
--------------------------------------------------------------------------------
/README_images/figure_2.png:
-------------------------------------------------------------------------------- https://raw.githubusercontent.com/KentaItakura/Image-Classification-using-CNN-with-Multi-Input-using-MATLAB/ce087863ae15670bd7f7fb89bc030d4b6066e666/README_images/figure_2.png -------------------------------------------------------------------------------- /README_images/image_0.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KentaItakura/Image-Classification-using-CNN-with-Multi-Input-using-MATLAB/ce087863ae15670bd7f7fb89bc030d4b6066e666/README_images/image_0.png -------------------------------------------------------------------------------- /README_images/image_1.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KentaItakura/Image-Classification-using-CNN-with-Multi-Input-using-MATLAB/ce087863ae15670bd7f7fb89bc030d4b6066e666/README_images/image_1.png -------------------------------------------------------------------------------- /README_images/image_2.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KentaItakura/Image-Classification-using-CNN-with-Multi-Input-using-MATLAB/ce087863ae15670bd7f7fb89bc030d4b6066e666/README_images/image_2.png -------------------------------------------------------------------------------- /README_images/image_3.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KentaItakura/Image-Classification-using-CNN-with-Multi-Input-using-MATLAB/ce087863ae15670bd7f7fb89bc030d4b6066e666/README_images/image_3.png -------------------------------------------------------------------------------- /README_images/image_4.png: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/KentaItakura/Image-Classification-using-CNN-with-Multi-Input-using-MATLAB/ce087863ae15670bd7f7fb89bc030d4b6066e666/README_images/image_4.png -------------------------------------------------------------------------------- /changeLog.txt: -------------------------------------------------------------------------------- 1 | v2: The dlnet1 to 3 were merged into a single dlnet. Firstly uploaded into github as well as MATLAB file exchange 2 | v1: Three dlnet1 to 3 were prepared then trained using forward functions -------------------------------------------------------------------------------- /license.txt: -------------------------------------------------------------------------------- 1 | Copyright (c) 2021, Kenta Itakura 2 | All rights reserved. 3 | 4 | Redistribution and use in source and binary forms, with or without 5 | modification, are permitted provided that the following conditions are met: 6 | 7 | * Redistributions of source code must retain the above copyright notice, this 8 | list of conditions and the following disclaimer. 9 | 10 | * Redistributions in binary form must reproduce the above copyright notice, 11 | this list of conditions and the following disclaimer in the documentation 12 | and/or other materials provided with the distribution 13 | * Neither the name of nor the names of its 14 | contributors may be used to endorse or promote products derived from this 15 | software without specific prior written permission. 
16 | THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 | AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 | IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
19 | DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE
20 | FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 | DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 | SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 | CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 | OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
25 | OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 |
--------------------------------------------------------------------------------
/multiInputCNN.mlx:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/KentaItakura/Image-Classification-using-CNN-with-Multi-Input-using-MATLAB/ce087863ae15670bd7f7fb89bc030d4b6066e666/multiInputCNN.mlx
--------------------------------------------------------------------------------
/partitionData.m:
--------------------------------------------------------------------------------
1 | function [XTrainUpper,XTrainBottom,XValidUpper,XValidBottom,XTestUpper,XTestBottom,YTrain,YValid,YTest]=partitionData(augimdsUpper,augimdsBottom,TrainRatio,ValidRatio,numAll,labels) % randomly split the upper/bottom datastores and their labels into training, validation and test sets
2 | indices=randperm(numAll,numAll); % random permutation of all observation indices
3 | TrainIdx=indices(1:round(numel(indices)*TrainRatio)); % indices of the training observations
4 | ValidIdx=indices(round(numel(indices)*TrainRatio)+1:round(numel(indices)*(TrainRatio+ValidRatio))); % indices of the validation observations
5 | TestIdx=indices(round(numel(indices)*(TrainRatio+ValidRatio))+1:end); % indices of the test observations
6 | auimdsTrainUpper = partitionByIndex(augimdsUpper,TrainIdx); % split the augmented datastores by index
7 | auimdsTrainBottom = partitionByIndex(augimdsBottom,TrainIdx);
8 | auimdsValidUpper = partitionByIndex(augimdsUpper,ValidIdx);
9 | auimdsValidBottom = partitionByIndex(augimdsBottom,ValidIdx);
10 | auimdsTestUpper = partitionByIndex(augimdsUpper,TestIdx);
11 | auimdsTestBottom = partitionByIndex(augimdsBottom,TestIdx);
12 | XTrainUpper=readall(auimdsTrainUpper); % read all images of each split
13 | XTrainUpper=cat(4,XTrainUpper.input{:}); % stack them into a 4-D array (H x W x C x N)
14 | XTrainBottom=readall(auimdsTrainBottom);
15 | XTrainBottom=cat(4,XTrainBottom.input{:});
16 | XValidUpper=readall(auimdsValidUpper);
17 | XValidUpper=cat(4,XValidUpper.input{:});
18 | XValidBottom=readall(auimdsValidBottom);
19 | XValidBottom=cat(4,XValidBottom.input{:});
20 | XTestUpper=readall(auimdsTestUpper);
21 | XTestUpper=cat(4,XTestUpper.input{:});
22 | XTestBottom=readall(auimdsTestBottom);
23 | XTestBottom=cat(4,XTestBottom.input{:});
24 | YTrain=labels(TrainIdx);
25 | YValid=labels(ValidIdx);
26 | YTest=labels(TestIdx);
27 | end
--------------------------------------------------------------------------------
/prepareDigitDataset.m:
--------------------------------------------------------------------------------
1 | function prepareDigitDataset
2 | [XTrain,YTrain] = digitTrain4DArrayData; % load the built-in digit images (28 x 28 grayscale)
3 | XTrainUpper=XTrain(1:14,:,:,:);  % extract the upper half
4 | XTrainBottom=XTrain(15:28,:,:,:); % extract the bottom half
5 | mkdir upperHalf
6 | mkdir bottomHalf
7 | for n=1:10
8 |     mkdir(sprintf('upperHalf/%d',n-1))
9 |     mkdir(sprintf('bottomHalf/%d',n-1))
10 | end
11 |
12 | for i=1:size(XTrainUpper,4)
13 |     classNum=double(YTrain(i))-1; % the categorical labels are 1-based, so subtract 1 to get the digit
14 |     imwrite(uint8(repmat(XTrainUpper(:,:,:,i)*255,[1 1 3])),sprintf('upperHalf/%d/TrainImg%d.jpg',classNum,i)) % replicate to 3 channels and save as JPEG
15 |     imwrite(uint8(repmat(XTrainBottom(:,:,:,i)*255,[1 1 3])),sprintf('bottomHalf/%d/TrainImg%d.jpg',classNum,i))
16 | end
17 | end
--------------------------------------------------------------------------------