├── Class_incremental_learning_for_device_identification_in_IoT_IoT_16942_2021.pdf ├── ContinualLearning ├── FCLayerAdapted.m ├── MNISTCustomLoop.m ├── MNISTDecoder.m ├── Readme.md ├── WorkStage │ ├── CSIL.m │ ├── CSILLockOldFPs.m │ ├── CSILLockOldFPsChessBoard.m │ ├── CSILLockOldFPsChessBoardPast5000.m │ ├── CSILLockOldFPsPast5000.m │ ├── CSILPast5000.m │ ├── Fixrep.m │ ├── FixrepExpanding.m │ ├── ILDashScript.m │ ├── ILDashScriptCSNet.m │ ├── ILDashScriptFixRep.m │ ├── Readme.md │ ├── adsb-107CSNet20.mat │ ├── noCSI-kd-ewc.m │ ├── noCSI.m │ ├── noCSIL_Finetune_EWC.m │ ├── noCS_adaptiveEwc.m │ └── noCS_restrictEwc.m ├── adsb_recognition_singleBurst_4.m ├── adsb_recognition_singleBurst_v4_2.m ├── adsb_recognition_singleBurst_v4_3.m ├── adsb_recognition_singleBurst_v4_3b.m ├── adsb_recognition_singleBurst_v4_3c.m ├── adsb_recognition_singleBurst_v4_4.m ├── adsb_recognition_singleBurst_v4_4b.m ├── adsb_recognition_singleBurst_v4_4c.m ├── adsb_recognition_singleBurst_v4_5.m ├── adsb_recognition_singleBurst_v4_5b.m ├── adsb_recognition_singleBurst_v4_5c.m ├── adsb_recognition_singleBurst_v4_5d.m ├── adsb_recognition_singleBurst_v4_5e.m └── structstruct.m ├── FCLayer.m ├── Formal Proof of Orthogonality.pdf ├── HypersphereLib ├── HyperSphere.m ├── HypersphereSurfArea.m ├── hypersphereCapArea.m ├── hypersphereCoverageTest.m ├── hypersphereCoverageTest.m~ ├── hypersphereLayer.m ├── matlabPlotForIEEEtran.m ├── matlabPlotForIEEEtran.pdf ├── plotfile3.eps ├── plotfile5.eps ├── randsphere.m └── unitHypersphereSurfArea.m ├── Readme.md ├── Storyline.pdf ├── countmember.m ├── covmatrix.m ├── extractNoise.m ├── makeDataTensor.m ├── numericalSimOfDoC ├── Readme.md ├── solveSpace.m ├── solveSpaceFunction.m └── solveSpaceFunction3.m ├── plotConfMat.m ├── preluLayer.m ├── tensorVectorLayer.m ├── weightedAdditionLayer.m ├── yxBatchNorm.m ├── zeroBiasFCLayer.m └── zeroBiasFCLayerEuc.m /Class_incremental_learning_for_device_identification_in_IoT_IoT_16942_2021.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcwhy/CSIL/8ce8637daf4dc60eeb1c56bff64c050c5b2353e9/Class_incremental_learning_for_device_identification_in_IoT_IoT_16942_2021.pdf -------------------------------------------------------------------------------- /ContinualLearning/FCLayerAdapted.m: -------------------------------------------------------------------------------- 1 | classdef FCLayerAdapted < nnet.layer.Layer 2 | 3 | properties (Learnable) 4 | % Layer learnable parameters 5 | Weights; 6 | Biases; 7 | end 8 | 9 | methods 10 | function layer = FCLayerAdapted(inputDim,outputDim,name,initialWeights) 11 | % layer = weightedAdditionLayer(numInputs,name) creates a 12 | 13 | % Set number of inputs. 14 | %layer.NumInputs = inputDim; 15 | %layer.NumOutputs = numOutputs; 16 | % Set layer name. 17 | layer.Name = name; 18 | % Set layer description. 19 | layer.Description = "FC layer without bias neurons with " + inputDim + ... 20 | " inputs"; 21 | % Initialize layer weights. 
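% --- Added commentary (not in the original source) ---
% The initializer below is a Glorot-style, fan-based scheme:
% stdGlorot = sqrt(10/(inputDim+outputDim)) scales a uniform rand()
% draw, which is then mean-centred. For comparison, the textbook
% Glorot/Xavier uniform limit is sqrt(6/(fanIn+fanOut)); the constant
% 10 here appears to be a deliberately wider variant, and any supplied
% initialWeights override this draw entirely.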
22 | % layer.Weights = dlarray(randn(outputDim,inputDim).*0.0001); 23 | % layer.Biases = dlarray(randn(outputDim,1).*0.0001); 24 | stdGlorot = sqrt(10/(inputDim + outputDim)); 25 | layer.Weights = dlarray(rand(outputDim,inputDim).*stdGlorot); 26 | layer.Weights = layer.Weights - mean(layer.Weights(:)); 27 | layer.Biases = dlarray(zeros(outputDim,1)); 28 | 29 | %layer.Biases = rand(outputDim,1); 30 | if numel(initialWeights) ~= 0 31 | layer.Weights = initialWeights; 32 | end 33 | end 34 | 35 | function Z = predict(layer, X) 36 | if ndims(X) >= 3 37 | batchSize = size(X,4); 38 | else 39 | batchSize = size(X,ndims(X)); 40 | end 41 | % Z = (layer.Weights*squeeze(X)+(layer.Biases )); 42 | % Z = layer.Weights*squeeze(X); 43 | 44 | % Z = reshape((layer.Weights./sqrt(sum((layer.Weights).^2,2)))*squeeze(X),... 45 | % [1,1,size(layer.Weights,1),batchSize]); 46 | if ndims(X) >= 3 47 | Z = reshape(... 48 | ( layer.Weights )*( squeeze(X) ) + layer.Biases,... 49 | [1,1,size(layer.Weights,1),batchSize]); 50 | % Z = reshape(... 51 | % ( layer.Weights )*( squeeze(X) ),... 52 | % [1,1,size(layer.Weights,1),batchSize]); 53 | else 54 | Z = reshape(... 55 | ( layer.Weights )*( squeeze(X) ) + layer.Biases,... 56 | [size(layer.Weights,1),batchSize]); 57 | % Z = reshape(... 58 | % ( layer.Weights )*( squeeze(X) ) ,... 59 | % [size(layer.Weights,1),batchSize]); 60 | end 61 | end 62 | end 63 | end 64 | -------------------------------------------------------------------------------- /ContinualLearning/MNISTCustomLoop.m: -------------------------------------------------------------------------------- 1 | clear; 2 | clc; 3 | close all; 4 | rng default 5 | 6 | addpath('../'); 7 | addpath('../HypersphereLib/'); 8 | set(0,'DefaultTextFontName','Times','DefaultTextFontSize',18,... 9 | 'DefaultAxesFontName','Times','DefaultAxesFontSize',18,... 
10 | 'DefaultLineLineWidth',1,'DefaultLineMarkerSize',7.75) 11 | 12 | 13 | [XTrain,YTrain] = digitTrain4DArrayData; 14 | YTrain = double(YTrain); 15 | cond = YTrain >=4; 16 | revCond = ~cond; 17 | uX = XTrain(:,:,:,cond); 18 | uY = YTrain(cond); 19 | % uX = 2*randn(size(uX)); 20 | XTrain = XTrain(:,:,:,revCond); 21 | YTrain = YTrain(revCond); 22 | YTrain = categorical(YTrain); 23 | 24 | perm = randperm(numel(YTrain)); 25 | XTrain = XTrain(:,:,:,perm); 26 | YTrain = YTrain(perm); 27 | XVal = XTrain(:,:,:,end - round(0.3*size(XTrain,4)):end); 28 | YVal = YTrain(end - round(0.3*size(YTrain,1)):end); 29 | XTrain = XTrain(:,:,:,1:round(0.7*size(XTrain,4))); 30 | YTrain = YTrain(1:round(0.7*size(YTrain,1))); 31 | 32 | numClasses = numel(categories(YTrain)); 33 | numFeatureDim = 10; 34 | layers = [ 35 | imageInputLayer([28 28 1],'Name','Input','Mean',0) 36 | 37 | convolution2dLayer(3,8,'Padding','same','Name','conv2d_1') 38 | batchNormalizationLayer('Name','batchNorm_1') 39 | reluLayer('Name','relu_1') 40 | 41 | maxPooling2dLayer(2,'Stride',2,'Name','maxPooling2d_1') 42 | 43 | convolution2dLayer(3,16,'Padding','same','Name','cov2d_2') 44 | batchNormalizationLayer('Name','batchNorm_2') 45 | reluLayer('Name','relu_2') 46 | 47 | maxPooling2dLayer(2,'Stride',2,'Name','maxPooling2d_2') 48 | 49 | convolution2dLayer(3,32,'Padding','same','Name','cov2d_3') 50 | batchNormalizationLayer('Name','batchNorm_3') 51 | reluLayer('Name','relu_3') 52 | tensorVectorLayer('Flatten') 53 | fullyConnectedLayer(numFeatureDim,'Name','fc_bf_fp') 54 | % FCLayer(numFeatureDim,numel(categories(YTrain)),'fp',[]) 55 | zeroBiasFCLayer(numFeatureDim,numel(categories(YTrain)),'fp',[]) 56 | yxSoftmax('softmax')]; 57 | 58 | lgraph = layerGraph(layers); 59 | YTrain = double(YTrain); 60 | numEpochs = 15; 61 | miniBatchSize = 128; 62 | plots = "training-progress"; 63 | executionEnvironment = "auto"; 64 | if plots == "training-progress" 65 | figure(10); 66 | lineLossTrain = animatedline('Color','#0072BD','lineWidth',1.5); 67 | lineClassificationLoss = animatedline('Color','#EDB120','lineWidth',1.5); 68 | 69 | ylim([-inf inf]) 70 | xlabel("Iteration") 71 | ylabel("Loss") 72 | legend('Loss','classificationLoss'); 73 | grid on; 74 | 75 | figure(11); 76 | lineCVAccuracy = animatedline('Color','#D95319','lineWidth',1.5); 77 | ylim([0 1.1]) 78 | xlabel("Iteration") 79 | ylabel("Loss") 80 | legend('CV Acc.','Avg. Kernel dist.'); 81 | grid on; 82 | end 83 | L2RegularizationFactor = 0.01; 84 | initialLearnRate = 0.01; 85 | decay = 0.01; 86 | momentumSGD = 0.9; 87 | velocities = []; 88 | learnRates = []; 89 | momentums = []; 90 | gradientMasks = []; 91 | numObservations = numel(YTrain); 92 | numIterationsPerEpoch = floor(numObservations./miniBatchSize); 93 | iteration = 0; 94 | start = tic; 95 | classes = categorical(YTrain); 96 | lgraph2 = lgraph; % No old weights 97 | dlnet = dlnetwork(lgraph2); 98 | 99 | % Loop over epochs. 100 | totalIters = 0; 101 | for epoch = 1:numEpochs 102 | idx = randperm(numel(YTrain)); 103 | XTrain = XTrain(:,:,:,idx); 104 | YTrain = YTrain(idx); 105 | % Loop over mini-batches. 106 | for i = 1:numIterationsPerEpoch 107 | iteration = iteration + 1; 108 | totalIters = totalIters + 1; 109 | % Read mini-batch of data and convert the labels to dummy 110 | % variables. 111 | idx = (i-1)*miniBatchSize+1:i*miniBatchSize; 112 | Xb = XTrain(:,:,:,idx); 113 | Yb = zeros(numClasses, miniBatchSize, 'single'); 114 | for c = 1:numClasses 115 | Yb(c,YTrain(idx)==(c)) = 1; 116 | end 117 | % Convert mini-batch of data to dlarray. 
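% --- Added commentary (not in the original source) ---
% Yb above is a numClasses-by-miniBatchSize one-hot matrix: an
% observation whose label is 3 contributes the column [0;0;1;0;...].
% An equivalent one-liner, assuming the Deep Learning Toolbox ind2vec:
%   Yb = single(full(ind2vec(YTrain(idx)', numClasses)));
% The 'SSCB' tag below marks the dlarray dimensions as
% Spatial-Spatial-Channel-Batch, the layout the image layers expect.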
118 | dlX = dlarray(single(Xb),'SSCB');
119 | % If training on a GPU, then convert data to gpuArray.
120 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
121 |     dlX = gpuArray(dlX);
122 | end
123 | % Evaluate the model gradients, state, and loss using dlfeval and the
124 | % modelGradients function, and update the network state.
125 | [gradients,state,loss,classificationLoss] = dlfeval(@modelGradientsOnWeights,dlnet,dlX,Yb);
126 | % [gradients,state,loss] = dlfeval(@modelGradientsOnWeights,dlnet,dlX,Yb);
127 | dlnet.State = state;
128 | % Determine the learning rate for the time-based decay schedule.
129 | learnRate = initialLearnRate/(1 + decay*iteration);
130 | % Update the network parameters using the SGDM optimizer:
131 | %[dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, momentum);
132 | % Update the network parameters using the SGD optimizer:
133 | %dlnet = dlupdate(@sgdFunction,dlnet,gradients);
134 | if isempty(velocities)
135 |     velocities = packScalar(gradients, 0);
136 |     learnRates = packScalar(gradients, learnRate);
137 |     momentumSGDs = packScalar(gradients, momentumSGD);
138 |     momentums = packScalar(gradients, 0);
139 |     L2Factors = packScalar(gradients, 0);
140 |     wd = packScalar(gradients, 0);
141 |     gradientMasks = packScalar(gradients, 1);
142 |     % % To freeze selected layers, zero their gradient masks:
143 |     % for k = 1:2
144 |     %     gradientMasks.Value{k}=dlarray(zeros(size(gradientMasks.Value{k})));
145 |     % end
146 | end
147 | %%%%----------- Check Point 2:
148 | %%%% Here you can specify which optimizer to use.
149 | % [dlnet, velocities] = dlupdate(@sgdmFunctionL2, ...
150 | %     dlnet, gradients, velocities, ...
151 | %     learnRates, momentumSGDs, L2Factors, gradientMasks); % classic
152 | %                                              % SGD with momentum
153 | totalIterInPackage = packScalar(gradients, totalIters); % dlupdate requires
154 |                                            % every extra argument to be
155 |                                            % a Learnables-shaped table,
156 |                                            % so the iteration counter is
157 |                                            % broadcast into one (used for
158 |                                            % Adam bias correction).
159 | [dlnet, velocities, momentums] = dlupdate(@adamFunction, ...
160 |     dlnet, gradients, velocities, ...
161 |     learnRates, momentums, wd, gradientMasks, ...
162 |     totalIterInPackage);
163 | % [dlnet] = dlupdate(@sgdFunction, ...
164 | %     dlnet, gradients); % vanilla SGD
165 | %%%%-----------End of Check Point 2
166 | 
167 | % Display the training progress.
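% --- Added commentary (not in the original source) ---
% With initialLearnRate = 0.01 and decay = 0.01, the schedule above gives
% learnRate = 0.01/(1 + 0.01*iteration), i.e. roughly half the initial
% rate after 100 iterations. The block below plots the running losses
% and, every 5 iterations, the validation accuracy via cvAccuracy.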
168 | if plots == "training-progress" 169 | D = duration(0,0,toc(start),'Format','hh:mm:ss'); 170 | XTest = XVal; 171 | YTest = categorical(YVal); 172 | if mod(iteration,5) == 0 173 | accuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment,0); 174 | addpoints(lineCVAccuracy,iteration, accuracy); 175 | end 176 | addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) 177 | addpoints(lineClassificationLoss,iteration,double(gather(extractdata(classificationLoss)))); 178 | title("Epoch: " + epoch + ", Elapsed: " + string(D)) 179 | drawnow 180 | end 181 | end 182 | end 183 | accuracy = cvAccuracy(dlnet, XVal, categorical(YVal), miniBatchSize, executionEnvironment, 1) 184 | 185 | 186 | 187 | function accuracy = cvAccuracy(dlnet, XTest, YTest, miniBatchSize, executionEnvironment, confusionChartFlg) 188 | dlXTest = dlarray(XTest,'SSCB'); 189 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 190 | dlXTest = gpuArray(dlXTest); 191 | end 192 | dlYPred = modelPredictions(dlnet,dlXTest,miniBatchSize); 193 | [~,idx] = max(extractdata(dlYPred),[],1); 194 | YPred = categorical(idx); 195 | accuracy = mean(YPred(:) == YTest(:)); 196 | if confusionChartFlg == 1 197 | figure 198 | confusionchart(YPred(:),YTest(:)); 199 | end 200 | end 201 | 202 | function dlYPred = modelPredictions(dlnet,dlX,miniBatchSize) 203 | numObservations = size(dlX,4); 204 | numIterations = ceil(numObservations / miniBatchSize); 205 | numClasses = size(dlnet.Layers(end-1).Weights,1); 206 | dlYPred = zeros(numClasses,numObservations,'like',dlX); 207 | for i = 1:numIterations 208 | idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations); 209 | dlYPred(:,idx) = predict(dlnet,dlX(:,:,:,idx)); 210 | end 211 | end 212 | 213 | 214 | function [gradients,state,loss,classificationLoss] = modelGradientsOnWeights(dlnet,dlX,Y) 215 | % %This is only used with softmax of matlab which only applies softmax 216 | % on 'C' and 'B' channels. 217 | [rawPredictions,state] = forward(dlnet,dlX,'Outputs', 'fp'); 218 | dlYPred = softmax(dlarray(squeeze(rawPredictions),'CB')); 219 | % [dlYPred,state] = forward(dlnet,dlX); 220 | penalty = 0; 221 | scalarL2Factor = 0; 222 | if scalarL2Factor ~= 0 223 | paramLst = dlnet.Learnables.Value; 224 | for i = 1:size(paramLst,1) 225 | penalty = penalty + sum((paramLst{i}(:)).^2); 226 | end 227 | end 228 | 229 | classificationLoss = crossentropy(squeeze(dlYPred),Y) + scalarL2Factor*penalty; 230 | 231 | loss = classificationLoss; 232 | % loss = classificationLoss + 0.2*(max(max(rawPredictions))-min(max(rawPredictions))); 233 | gradients = dlgradient(loss,dlnet.Learnables); 234 | %gradients = dlgradient(loss,dlnet.Learnables(4,:)); 235 | end 236 | 237 | function [params,velocityUpdates,momentumUpdate] = adamFunction(params, rawParamGradients,... 
238 | velocities, learnRates, momentums, wd, gradientMasks, iters) 239 | % https://arxiv.org/pdf/2010.07468.pdf %%AdaBelief 240 | % https://arxiv.org/pdf/1711.05101.pdf %%DeCoupled Weight Decay 241 | b1 = 0.9; 242 | b2 = 0.999; 243 | e = 1e-8; 244 | curIter = iters(:); 245 | curIter = curIter(1); 246 | 247 | gt = rawParamGradients; 248 | mt = (momentums.*b1 + ((1-b1)).*gt); 249 | vt = (velocities.*b2 + ((1-b2)).*((gt-mt).^2)); 250 | 251 | momentumUpdate = mt; 252 | velocityUpdates = vt; 253 | h_mt = mt./(1-b1.^curIter); 254 | h_vt = (vt+e)./(1-b2.^curIter); 255 | %%%%----------- Check Point 3: 256 | %%%% Here you can specify whether to use bias correction, 257 | %%%% or zero-bias dense layer 258 | %%%% in this test, we can just try to eliminate the effect of varying learning 259 | %%%% rates 260 | % params = params - 0.001.*(mt./(sqrt(vt)+e)).*gradientMasks... 261 | % - wd.*params.*gradientMasks; %This works better for zero-bias dense layer 262 | % params = params - 0.001.*(h_mt./(sqrt(h_vt)+e)).*gradientMasks... 263 | % -L2Foctors.*params.*gradientMasks; 264 | params = params - learnRates.*(h_mt./(sqrt(h_vt)+e)).*gradientMasks... 265 | -2*learnRates .* wd.*params.*gradientMasks; 266 | %%%% 267 | %%%%-----------End of Check Point 3 268 | end 269 | 270 | function param = sgdFunction(param,paramGradient) 271 | learnRate = 0.01; 272 | param = param - learnRate.*paramGradient; 273 | end 274 | 275 | function [params, velocityUpdates] = sgdmFunction(params, paramGradients,... 276 | velocities, learnRates, momentums) 277 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e 278 | % velocityUpdates = momentums.*velocities+learnRates.*paramGradients; 279 | velocityUpdates = momentums.*velocities+0.001.*paramGradients; 280 | params = params - velocityUpdates; 281 | end 282 | 283 | function [params, velocityUpdates] = sgdmFunctionL2(params, rawParamGradients,... 284 | velocities, learnRates, momentums, L2Foctors, gradientMasks) 285 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e 286 | % https://towardsdatascience.com/intuitions-on-l1-and-l2-regularisation-235f2db4c261 287 | paramGradients = rawParamGradients + 2*L2Foctors.*params; 288 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients; 289 | params = params - (velocityUpdates).*gradientMasks; 290 | end 291 | 292 | function tabVars = packScalar(target, scalar) 293 | % The matlabs' silly design results in such a strange function 294 | tabVars = target; 295 | for row = 1:size(tabVars(:,3),1) 296 | tabVars{row,3} = {... 297 | dlarray(... 298 | ones(size(tabVars.Value{row})).*scalar...%ones(size(tabVars(row,3).Value{1,1})).*scalar... 299 | )... 300 | }; 301 | end 302 | end 303 | 304 | 305 | 306 | -------------------------------------------------------------------------------- /ContinualLearning/MNISTDecoder.m: -------------------------------------------------------------------------------- 1 | clear; 2 | clc; 3 | close all; 4 | rng default 5 | 6 | addpath('../'); 7 | addpath('../HypersphereLib/'); 8 | set(0,'DefaultTextFontName','Times','DefaultTextFontSize',18,... 9 | 'DefaultAxesFontName','Times','DefaultAxesFontSize',18,... 
10 | 'DefaultLineLineWidth',1,'DefaultLineMarkerSize',7.75) 11 | 12 | % load('./fashion-mnist.mat') 13 | % clear XTest 14 | % clear YTest 15 | % YTrain = categorical(YTrain); 16 | % YTrain = grp2idx(YTrain); 17 | 18 | [XTrain,YTrain] = digitTrain4DArrayData; 19 | YTrain = double(YTrain); 20 | cond = YTrain >=8; 21 | revCond = ~cond; 22 | uX = XTrain(:,:,:,cond); 23 | uY = YTrain(cond); 24 | % uX = 2*randn(size(uX)); 25 | XTrain = XTrain(:,:,:,revCond); 26 | YTrain = YTrain(revCond); 27 | YTrain = categorical(YTrain); 28 | 29 | perm = randperm(numel(YTrain)); 30 | XTrain = XTrain(:,:,:,perm); 31 | YTrain = YTrain(perm); 32 | XVal = XTrain(:,:,:,end - round(0.3*size(XTrain,4)):end); 33 | YVal = YTrain(end - round(0.3*size(YTrain,1)):end); 34 | XTrain = XTrain(:,:,:,1:round(0.7*size(XTrain,4))); 35 | YTrain = YTrain(1:round(0.7*size(YTrain,1))); 36 | 37 | numClasses = numel(categories(YTrain)) 38 | numFeatureDim = numClasses + 1 39 | layers = [ 40 | imageInputLayer([28 28 1],'Name','Input','Mean',0) 41 | 42 | convolution2dLayer(3,8,'Padding','same','Name','conv2d_1') 43 | batchNormalizationLayer('Name','batchNorm_1') 44 | reluLayer('Name','relu_1') 45 | 46 | maxPooling2dLayer(2,'Stride',2,'Name','maxPooling2d_1') 47 | 48 | convolution2dLayer(3,16,'Padding','same','Name','cov2d_2') 49 | batchNormalizationLayer('Name','batchNorm_2') 50 | reluLayer('Name','relu_2') 51 | 52 | maxPooling2dLayer(2,'Stride',2,'Name','maxPooling2d_2') 53 | 54 | convolution2dLayer(3,32,'Padding','same','Name','cov2d_3') 55 | batchNormalizationLayer('Name','batchNorm_3') 56 | reluLayer('Name','relu_3') 57 | tensorVectorLayer('Flatten') 58 | dropoutLayer(0.3,'Name','someDropouts') 59 | fullyConnectedLayer(numFeatureDim,'Name','fc_bf_fp') 60 | yxBatchNorm('yxBatchnorm_1',numFeatureDim) 61 | % FCLayer(numFeatureDim,numel(categories(YTrain)),'fp',[]) 62 | zeroBiasFCLayer(numFeatureDim,numel(categories(YTrain)),'fp',[]) 63 | yxSoftmaxReverse('softmax') 64 | 65 | ]; 66 | 67 | 68 | lgraph = layerGraph(layers); 69 | shallowClassifier = [ 70 | zeroBiasFCLayer(numFeatureDim,numel(categories(YTrain)),'Shallowfp',[]) 71 | yxSoftmax('ShallowSoftmax')]; 72 | decoder = [ 73 | % fullyConnectedLayer(28*28*1,'Name','weakDecoderIn') 74 | tensorVectorLayer('weakDecoderIn') 75 | % FCLayer(numFeatureDim,28*28*1,'weakDecoderOut',[]) 76 | % dropoutLayer(0.2,'Name','decoderDropouts') 77 | fullyConnectedLayer(28*28*1,'Name','weakDecoderOut') 78 | ]; 79 | % decoder = fullyConnectedLayer(28*28*1,'Name','weakDecoder'); 80 | lgraph = addLayers(lgraph,decoder); 81 | lgraph = connectLayers(lgraph,'relu_1','weakDecoderIn'); 82 | YTrain = double(YTrain); 83 | numEpochs = 5; 84 | miniBatchSize = 128; 85 | plots = "training-progress"; 86 | executionEnvironment = "auto"; 87 | if plots == "training-progress" 88 | figure(10); 89 | lineLossTrain = animatedline('Color','#0072BD','lineWidth',1.5); 90 | lineClassificationLoss = animatedline('Color','#EDB120','lineWidth',1.5); 91 | ylim([-inf inf]) 92 | xlabel("Iteration") 93 | ylabel("Loss") 94 | legend('Loss','classificationLoss'); 95 | grid on; 96 | 97 | figure(11); 98 | lineCVAccuracy = animatedline('Color','#D95319','lineWidth',1.5); 99 | ylim([0 1.1]) 100 | xlabel("Iteration") 101 | ylabel("Loss") 102 | legend('CV Acc.'); 103 | grid on; 104 | 105 | figure(12); 106 | lineReconstructionLoss = animatedline('Color','#77AC30','lineWidth',1.5); 107 | ylim([-inf inf]) 108 | xlabel("Iteration") 109 | ylabel("ReconstructionLoss") 110 | 111 | figure(13); 112 | lineSumDeviation = 
animatedline('Color','#77AC30','lineWidth',1.5); 113 | ylim([-inf inf]) 114 | xlabel("Iteration") 115 | ylabel("Sum of deviations") 116 | end 117 | 118 | L2RegularizationFactor = 0.01; 119 | initialLearnRate = 0.01; 120 | decay = 0.01; 121 | momentumSGD = 0.9; 122 | velocities = []; 123 | learnRates = []; 124 | momentums = []; 125 | gradientMasks = []; 126 | numObservations = numel(YTrain); 127 | numIterationsPerEpoch = floor(numObservations./miniBatchSize); 128 | iteration = 0; 129 | start = tic; 130 | classes = categorical(YTrain); 131 | lgraph2 = lgraph; % No old weights 132 | dlnet = dlnetwork(lgraph2); 133 | 134 | % Loop over epochs. 135 | totalIters = 0; 136 | for epoch = 1:numEpochs 137 | idx = randperm(numel(YTrain)); 138 | XTrain = XTrain(:,:,:,idx); 139 | YTrain = YTrain(idx); 140 | % Loop over mini-batches. 141 | for i = 1:numIterationsPerEpoch 142 | iteration = iteration + 1; 143 | totalIters = totalIters + 1; 144 | % Read mini-batch of data and convert the labels to dummy 145 | % variables. 146 | idx = (i-1)*miniBatchSize+1:i*miniBatchSize; 147 | Xb = XTrain(:,:,:,idx); 148 | Yb = zeros(numClasses, miniBatchSize, 'single'); 149 | for c = 1:numClasses 150 | Yb(c,YTrain(idx)==(c)) = 1; 151 | end 152 | % Convert mini-batch of data to dlarray. 153 | dlX = dlarray(single(Xb),'SSCB'); 154 | % If training on a GPU, then convert data to gpuArray. 155 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 156 | dlX = gpuArray(dlX); 157 | end 158 | % Evaluate the model gradients, state, and loss using dlfeval and the 159 | % modelGradients function and update the network state. 160 | [gradients,state,... 161 | loss,classificationLoss,... 162 | reconstructionLoss] = dlfeval(@modelGradientsOnWeights,dlnet,dlX,Yb); 163 | % [gradients,state,loss] = dlfeval(@modelGradientsOnWeights,dlnet,dlX,Yb); 164 | dlnet.State = state; 165 | % Determine learning rate for time-based decay learning rate schedule. 166 | learnRate = initialLearnRate/(1 + decay*iteration); 167 | % Update the network parameters using the SGDM optimizer. 168 | %[dlnet, velocity] = sgdmupdate(dlnet, gradients, velocity, learnRate, momentum); 169 | % Update the network parameters using the SGD optimizer. 170 | %dlnet = dlupdate(@sgdFunction,dlnet,gradients); 171 | if isempty(velocities) 172 | velocities = packScalar(gradients, 0); 173 | learnRates = packScalar(gradients, learnRate); 174 | momentumSGDs = packScalar(gradients, momentumSGD); 175 | momentums = packScalar(gradients, 0); 176 | L2Foctors = packScalar(gradients, 0); 177 | wd = packScalar(gradients, 0); 178 | gradientMasks = packScalar(gradients, 1); 179 | % % Let's lock some weights 180 | % for k = 1:2 181 | % gradientMasks.Value{k}=dlarray(zeros(size(gradientMasks.Value{k}))); 182 | % end 183 | end 184 | %%%%----------- Check Point 2: 185 | %%%% Here you can specify which optimizer to use, 186 | % [dlnet, velocities] = dlupdate(@sgdmFunctionL2, ... 187 | % dlnet, gradients, velocities, ... 188 | % learnRates, momentumSGDs, L2Foctors, gradientMasks); % This is 189 | % % the famous SGD with momentum 190 | totalIterInPackage = packScalar(gradients, totalIters); % We have to make this... 191 | % stupid data 192 | % structure but it 193 | % only contains 194 | % the number of 195 | % iterations 196 | [dlnet, velocities, momentums] = dlupdate(@adamFunction, ... 197 | dlnet, gradients, velocities, ... 198 | learnRates, momentums, wd, gradientMasks, ... 199 | totalIterInPackage); 200 | % [dlnet] = dlupdate(@sgdFunction, ... 
201 | % dlnet, gradients); % the vanilla 202 | %%%%-----------End of Check Point 2 203 | 204 | % Display the training progress. 205 | if plots == "training-progress" 206 | D = duration(0,0,toc(start),'Format','hh:mm:ss'); 207 | XTest = XVal; 208 | YTest = categorical(YVal); 209 | if mod(iteration,5) == 0 210 | accuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment,0); 211 | addpoints(lineCVAccuracy, iteration, accuracy); 212 | end 213 | stepSumDeviations = calculateGrossMutualDistance(gpuArray(dlnet.Layers({dlnet.Layers.Name} == "fp").Weights)); 214 | 215 | addpoints(lineLossTrain,iteration,double(gather(extractdata(loss)))) 216 | addpoints(lineClassificationLoss,iteration,double(gather(extractdata(classificationLoss)))); 217 | addpoints(lineReconstructionLoss,iteration,double(gather(extractdata(reconstructionLoss)))); 218 | addpoints(lineSumDeviation,iteration, double(stepSumDeviations) ); 219 | 220 | title("Epoch: " + epoch + ", Elapsed: " + string(D)) 221 | drawnow 222 | end 223 | end 224 | end 225 | accuracy = cvAccuracy(dlnet, XVal, categorical(YVal), miniBatchSize, executionEnvironment, 1) 226 | % imagesc(similarMatrix) 227 | 228 | testReconstruction = sigmoid(predict(dlnet, dlarray(XVal(:,:,:,1:10),'SSCB'),'Outputs', 'weakDecoderOut')); 229 | 230 | featuresWithWeakDecoder = squeeze(dlnet.predict(dlarray(XVal,'SSCB'),'Outputs','Flatten')); 231 | [coeff, score, latent, tsquared, explained] = pca(extractdata(featuresWithWeakDecoder)'); 232 | featurePCA = sum((cumsum(latent))./sum(latent) <= 0.95) 233 | 234 | [sumDevationCosine,~] = calculateGrossMutualDistance(dlnet.Layers({dlnet.Layers.Name} == "fp").Weights) 235 | 236 | 237 | function [sumDeviationAngles,similarMatrix] = calculateGrossMutualDistance(FPs) 238 | % weights = FPs; 239 | % similarMatrix = zeros(size(weights,1), size(weights,1)); 240 | % for i = 1:size(weights,1) 241 | % curWeight = weights(i,:); 242 | % magCurWeight = sqrt(sum(curWeight.^2,2)); 243 | % for j = 1:size(weights,1) 244 | % nxtWeight = weights(j,:); 245 | % magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 246 | % similarMatrix(i,j)=sum(... 247 | % (nxtWeight./magNxtWeight)... 248 | % .*(curWeight./magCurWeight)... 
249 | % ); 250 | % % similarMatrix(i,j) = sqrt(sum((nxtWeight - curWeight).^2)); 251 | % end 252 | % end 253 | similarMatrix = FPs*FPs'./sqrt(diag(FPs*FPs'))./sqrt(diag(FPs*FPs')'); 254 | sumDeviationAngles = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2; 255 | end 256 | 257 | function accuracy = cvAccuracy(dlnet, XTest, YTest, miniBatchSize, executionEnvironment, confusionChartFlg) 258 | dlXTest = dlarray(XTest,'SSCB'); 259 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 260 | dlXTest = gpuArray(dlXTest); 261 | end 262 | dlYPred = modelPredictions(dlnet,dlXTest,miniBatchSize); 263 | [~,idx] = max(extractdata(dlYPred),[],1); 264 | YPred = categorical(idx); 265 | accuracy = mean(YPred(:) == YTest(:)); 266 | if confusionChartFlg == 1 267 | figure 268 | confusionchart(YPred(:),YTest(:)); 269 | end 270 | end 271 | 272 | function dlYPred = modelPredictions(dlnet,dlX,miniBatchSize) 273 | numObservations = size(dlX,4); 274 | numIterations = ceil(numObservations / miniBatchSize); 275 | numClasses = size(dlnet.Layers({dlnet.Layers.Name} == "fp").Weights,1); 276 | dlYPred = zeros(numClasses,numObservations,'like',dlX); 277 | for i = 1:numIterations 278 | idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations); 279 | dlYPred(:,idx) = predict(dlnet,dlX(:,:,:,idx)); 280 | end 281 | end 282 | 283 | 284 | function [gradients,state,loss,... 285 | classificationLoss, reconstructionLoss] = modelGradientsOnWeights(dlnet,dlX,Y) 286 | % %This is only used with softmax of matlab which only applies softmax 287 | % on 'C' and 'B' channels. 288 | % [rawPredictions, rawReconstruction, state] = forward(dlnet,dlX,'Outputs', {'fp', 'weakDecoderOut'}); 289 | % dlYPred = softmax(dlarray(squeeze(rawPredictions),'CB')); 290 | [dlYPred, rawReconstruction, state] = forward(dlnet,dlX,'Outputs', {'softmax', 'weakDecoderOut'}); 291 | reconstruction = (rawReconstruction); 292 | flatX = squeeze(reshape(dlX,[28*28,1,128])); 293 | % [dlYPred,state] = forward(dlnet,dlX); 294 | penalty = 0; 295 | scalarL2Factor = 0; 296 | if scalarL2Factor ~= 0 297 | paramLst = dlnet.Learnables.Value; 298 | for i = 1:size(paramLst,1) 299 | penalty = penalty + sum((paramLst{i}(:)).^2); 300 | end 301 | end 302 | 303 | classificationLoss = crossentropy(squeeze(dlYPred),Y) + scalarL2Factor*penalty; 304 | reconstructionLoss = sqrt(mse(reconstruction,flatX)); 305 | loss = 0.8.*classificationLoss + 0.1*reconstructionLoss; 306 | % loss = classificationLoss + 0.2*(max(max(rawPredictions))-min(max(rawPredictions))); 307 | gradients = dlgradient(loss,dlnet.Learnables); 308 | %gradients = dlgradient(loss,dlnet.Learnables(4,:)); 309 | end 310 | 311 | function [params,velocityUpdates,momentumUpdate] = adamFunction(params, rawParamGradients,... 
312 |     velocities, learnRates, momentums, wd, gradientMasks, iters)
313 | % https://arxiv.org/pdf/2010.07468.pdf %% AdaBelief
314 | % https://arxiv.org/pdf/1711.05101.pdf %% Decoupled weight decay (AdamW)
315 | b1 = 0.9;
316 | b2 = 0.999;
317 | e = 1e-8;
318 | curIter = iters(:);
319 | curIter = curIter(1);
320 | 
321 | gt = rawParamGradients;
322 | mt = (momentums.*b1 + ((1-b1)).*gt); % first moment: m_t = b1*m_{t-1} + (1-b1)*g_t
323 | vt = (velocities.*b2 + ((1-b2)).*((gt-mt).^2)); % AdaBelief-style second moment: tracks (g_t - m_t).^2
324 | 
325 | momentumUpdate = mt;
326 | velocityUpdates = vt;
327 | h_mt = mt./(1-b1.^curIter); % bias-corrected first moment
328 | h_vt = (vt+e)./(1-b2.^curIter); % bias-corrected second moment
329 | %%%%----------- Check Point 3:
330 | %%%% Here you can choose whether to apply bias correction, or the
331 | %%%% variant tuned for the zero-bias dense layer;
332 | %%%% in this test we simply try to eliminate the effect of varying
333 | %%%% learning rates.
334 | % params = params - 0.001.*(mt./(sqrt(vt)+e)).*gradientMasks...
335 | %     - wd.*params.*gradientMasks; % This variant works better for the zero-bias dense layer
336 | % params = params - 0.001.*(h_mt./(sqrt(h_vt)+e)).*gradientMasks...
337 | %     -L2Factors.*params.*gradientMasks;
338 | params = params - learnRates.*(h_mt./(sqrt(h_vt)+e)).*gradientMasks...
339 |     -2*learnRates .* wd.*params.*gradientMasks; % decoupled (AdamW-style) weight decay
340 | %%%%
341 | %%%%-----------End of Check Point 3
342 | end
343 | 
344 | function param = sgdFunction(param,paramGradient)
345 | learnRate = 0.01;
346 | param = param - learnRate.*paramGradient;
347 | end
348 | 
349 | function [params, velocityUpdates] = sgdmFunction(params, paramGradients,...
350 |     velocities, learnRates, momentums)
351 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e
352 | % velocityUpdates = momentums.*velocities+learnRates.*paramGradients;
353 | velocityUpdates = momentums.*velocities+0.001.*paramGradients; % note: fixed 0.001 step; learnRates is ignored here
354 | params = params - velocityUpdates;
355 | end
356 | 
357 | function [params, velocityUpdates] = sgdmFunctionL2(params, rawParamGradients,...
358 |     velocities, learnRates, momentums, L2Factors, gradientMasks)
359 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e
360 | % https://towardsdatascience.com/intuitions-on-l1-and-l2-regularisation-235f2db4c261
361 | paramGradients = rawParamGradients + 2*L2Factors.*params;
362 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients;
363 | params = params - (velocityUpdates).*gradientMasks;
364 | end
365 | 
366 | function tabVars = packScalar(target, scalar)
367 | % Broadcast a scalar into a Learnables-shaped table, as dlupdate requires for its extra arguments
368 | tabVars = target;
369 | for row = 1:size(tabVars(:,3),1)
370 |     tabVars{row,3} = {...
371 |         dlarray(...
372 |         ones(size(tabVars.Value{row})).*scalar...%ones(size(tabVars(row,3).Value{1,1})).*scalar...
373 |         )...
374 |         };
375 | end
376 | end
377 | 
378 | 
379 | 
380 | 
--------------------------------------------------------------------------------
/ContinualLearning/Readme.md:
--------------------------------------------------------------------------------
1 | Sample files for continual-learning experiments. The adsb_recognition_singleBurst_* scripts are for simple two-stage incremental learning.
2 | -------------------------------------------------------------------------------- /ContinualLearning/WorkStage/Fixrep.m: -------------------------------------------------------------------------------- 1 | 2 | %%%%% 3 | %Gather fisher information and weights from the old network 4 | prevClassNum = 0; 5 | if skipDAGNet == 0 6 | lgraph2 = layerGraph(net); 7 | lgraph2 = lgraph2.removeLayers('classify_1'); 8 | weights = net.Layers({net.Layers.Name}=="Fingerprints").Weights; 9 | numClasses = size(weights,1); 10 | prevClassNum = numClasses; 11 | else 12 | lgraph2 = layerGraph(dlnet); 13 | numClasses = size(dlnet.Layers(14).Weights,1); 14 | prevClassNum = numClasses; 15 | end 16 | prevDlnet = dlnetwork(lgraph2); 17 | prevWeights = prevDlnet.Learnables; 18 | prevCX = dlarray(single(cX),'SSCB'); 19 | prevCY = zeros(numClasses, size(prevCX,4), 'single'); 20 | for c = 1:numClasses 21 | prevCY(c,cY(:)==(c)) = 1; 22 | end 23 | executionEnvironment = "auto"; 24 | % If training on a GPU, then convert data to gpuArray. 25 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 26 | prevCX = gpuArray(prevCX); 27 | end 28 | [prevGradients0,state] = dlfeval(@logModelGradientsOnWeights,prevDlnet,prevCX); 29 | accuracy = cvAccuracy(prevDlnet, prevCX, (cY), 128, executionEnvironment) 30 | prevDlnet.State = state; 31 | prevFisherInfo0 = prevGradients0; 32 | for i = 1:size(prevFisherInfo0,1) 33 | prevFisherInfo0{i,3}={dlarray(exp(prevGradients0{i,3}{:}.^2))}; 34 | % prevFisherInfo0{i,3}={dlarray(exp(prevGradients0{i,3}{:}.^2))}; 35 | 36 | end 37 | fingerprintLayerIdx = 14; 38 | 39 | newClassesNum = floor(size(unique(uY),1)); 40 | unknownClassLabels = unique(uY); 41 | idx = randperm(size(unknownClassLabels,1)); 42 | unknownClassLabels = unknownClassLabels(idx); 43 | unknownClassLabels = unknownClassLabels(1:newClassesNum); 44 | 45 | % Generate some initial fingerprints. 
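% --- Added commentary (not in the original source) ---
% Each new class is seeded with a fingerprint equal to the unit-normalised
% mean of its penultimate-layer ('fc_bf_fp') activations, as the loop
% below computes. A vectorised sketch of the same idea (hypothetical
% variables: F is a dim-by-N feature matrix, y the labels):
%   G = splitapply(@(A) mean(A,1), F', findgroups(y)); % class means
%   newFingerprints = G ./ vecnorm(G, 2, 2);           % unit rows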
46 | %figure 47 | % hold on; 48 | existingFingerprints = prevDlnet.Layers(fingerprintLayerIdx).Weights; 49 | newFingerprints = dlarray([]); 50 | continualLearnX = dlarray([]); 51 | continualLearnY = []; 52 | cursor = 1; 53 | for i = 1:newClassesNum 54 | selection = uY==unknownClassLabels(i); 55 | ux_i = uX(:,:,:,selection); 56 | uy_i = uY(selection); 57 | samplePerClass = size(uy_i, 1); 58 | continualLearnX(:,:,:,cursor:cursor+samplePerClass-1) = ux_i; 59 | continualLearnY(cursor:cursor+samplePerClass-1)=uy_i; 60 | cursor = cursor+samplePerClass; 61 | % aUx_i = squeeze(prevDlnet.predict(dlarray(ux_i,'SSCB'),'Outputs','fc_bf_fp')); 62 | aUx_i = squeeze(activations(net,ux_i,'fc_bf_fp')); 63 | unitAUx_i=(aUx_i./sqrt(sum(aUx_i.^2)))'; 64 | fp_i = mean(aUx_i,2)'; 65 | magFp_i = sqrt(sum(fp_i.^2)); 66 | unitFp_i = fp_i ./ magFp_i; 67 | newFp = (unitFp_i); 68 | % newFp = findPotentialFingerprint(dlarray(existingFingerprints),dlarray(aUx_i)); 69 | newFingerprints(end+1,:) = newFp; 70 | % histogram(sum(unitFp_i.*unitFingerprints,2),[-1:0.2:1],'Normalization','probability'); 71 | % hold on; 72 | % histogram(sum(unitAUx_i.*unitFp_i,2),[-1:0.2:1],'Normalization','probability'); 73 | % legend('Correlation with existing fingerprints','Correlation with own samples'); 74 | end 75 | 76 | xNum = size(X,4); 77 | rndSeq = randperm(xNum); 78 | rndSeq = rndSeq(1:min(numel(rndSeq),5000)); 79 | 80 | continualLearnX = cat(4,continualLearnX,X(:,:,:,rndSeq)); 81 | continualLearnY = cat(2,continualLearnY,Y(rndSeq)'); 82 | 83 | randSeries = randperm(size(continualLearnY,2)); 84 | continualLearnX = continualLearnX(:,:,:,randSeries); 85 | continualLearnY = continualLearnY(randSeries); 86 | cvContinualLearnX = continualLearnX(:,:,:,floor(0.6*size(continualLearnX,4)):end); 87 | cvContinualLearnY = continualLearnY(floor(0.6*size(continualLearnY,2)):end); 88 | continualLearnX = continualLearnX(:,:,:,1:floor(0.6*size(continualLearnX,4))-1); 89 | continualLearnY = continualLearnY(1:floor(0.6*size(continualLearnY,2))-1); 90 | 91 | concatFingerprints = [existingFingerprints; newFingerprints]; 92 | 93 | % newFingerprintLayer = FCLayer(numClasses,numClasses+newClassesNum,'Fingerprints',concatFingerprints); 94 | newFingerprintLayer = zeroBiasFCLayer(numClasses,prevClassNum+newClassesNum,'Fingerprints',concatFingerprints); 95 | 96 | % newFingerprintLayer.normMag = prevDlnet.Layers(4).normMag; 97 | % newFingerprintLayer.b1 = prevDlnet.Layers(4).b1; 98 | 99 | numClasses = prevClassNum + newClassesNum; 100 | 101 | %Build new network and start continual learning. 
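% --- Added commentary (not in the original source) ---
% The head is rebuilt rather than retrained from scratch: the zero-bias
% 'Fingerprints' layer is replaced by one whose weight matrix stacks the
% frozen old fingerprints on top of the newly seeded rows, mapping the
% existing embedding to prevClassNum+newClassesNum matching scores. The
% training set assembled above mixes the new-class samples with up to
% 5000 replayed old samples (rndSeq) to limit forgetting.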
102 | lgraph2 = layerGraph(net); 103 | lgraph2 = lgraph2.removeLayers('classify_1'); 104 | dlnet = dlnetwork(replaceLayer(lgraph2, 'Fingerprints', newFingerprintLayer)); 105 | 106 | XTrain = continualLearnX; 107 | YTrain = continualLearnY; 108 | XTest = cuX; 109 | YTest = cuY; 110 | numEpochs = 3; 111 | miniBatchSize = 20; 112 | plots = "training-progress"; 113 | statusFigureIdx = []; 114 | statusFigureAxis = []; 115 | if plots == "training-progress" 116 | figure; 117 | statusFigureIdx = gcf; 118 | statusFigureAxis = gca; 119 | % Go into the documentation of animatedline for more color codes 120 | lineNewLossTrain = animatedline('Color', '#0072BD','LineWidth',1,'Marker','.','LineStyle','none'); 121 | lineNewCVAccuracy = animatedline('Color', '#D95319','LineWidth',1); 122 | lineOldCVAccuracy = animatedline('Color', '#EDB120','LineWidth',1); 123 | lineOldLossCV = animatedline('Color', '#7E2F8E','LineWidth',1,'Marker','.','LineStyle','none'); 124 | % ylim([0 inf]) 125 | ylim([0 2]) 126 | xlabel("Iteration") 127 | ylabel("Metrics") 128 | legend('New task loss','New task accuracy','Old task accuracy', 'Old task lost'); 129 | grid on 130 | end 131 | L2RegularizationFactor = 0.01; 132 | initialLearnRate = 0.01; 133 | decay = 0.01; 134 | momentum = 0.9; 135 | velocities = []; 136 | learnRates = []; 137 | momentums = []; 138 | gradientMasks = []; 139 | numObservations = numel(YTrain); 140 | numIterationsPerEpoch = floor(numObservations./miniBatchSize); 141 | iteration = 0; 142 | start = tic; 143 | classes = categorical(YTrain); 144 | 145 | % 146 | prevCY = [prevCY;zeros(newClassesNum,size(prevCY,2))]; 147 | 148 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 149 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 150 | disp('CV accuracy b.f. cont. learning'); 151 | [newCvAccuracy, oldCvAccuracy] 152 | 153 | % Loop over epochs. 154 | fisherLossLst = []; 155 | for epoch = 1:numEpochs 156 | idx = randperm(numel(YTrain)); 157 | XTrain = XTrain(:,:,:,idx); 158 | YTrain = YTrain(idx); 159 | % Loop over mini-batches. 160 | for i = 1:numIterationsPerEpoch 161 | iteration = iteration + 1; 162 | % Read mini-batch of data and convert the labels to dummy 163 | % variables. 164 | idx = (i-1)*miniBatchSize+1:i*miniBatchSize; 165 | Xb = XTrain(:,:,:,idx); 166 | Yb = zeros(numClasses, miniBatchSize, 'single'); 167 | for c = 1:numClasses 168 | Yb(c,YTrain(idx)==(c)) = 1; 169 | end 170 | % Convert mini-batch of data to dlarray. 171 | dlX = dlarray(single(Xb),'SSCB'); 172 | % If training on a GPU, then convert data to gpuArray. 173 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 174 | dlX = gpuArray(dlX); 175 | end 176 | % Evaluate the model gradients, state, and loss using dlfeval and the 177 | % modelGradients function and update the network state. 178 | % [gradients,state,loss] = dlfeval(@modelGradientsOnWeights,dlnet,dlX,Yb); 179 | % if epoch <= 5 180 | % [gradients,state,loss,fisherLoss] = dlfeval(@modelGradientsOnWeightsEWC,dlnet,dlX, Yb,... 181 | % prevFisherInfo0, prevDlnet, newClassesNum, 0, fingerprintLayerIdx); 182 | % else 183 | % [gradients,state,loss,fisherLoss] = dlfeval(@modelGradientsOnWeightsEWC,dlnet,dlX, Yb,... 184 | % prevFisherInfo0, prevDlnet, newClassesNum, 1, fingerprintLayerIdx); 185 | % end 186 | [gradients,state,loss,fisherLoss] = dlfeval(@modelGradientsOnWeightsEWC,dlnet,dlX, Yb,... 
187 | prevFisherInfo0, prevDlnet, newClassesNum, 0, fingerprintLayerIdx,prevClassNum); 188 | 189 | fisherLossLst(end+1)=extractdata(gather(fisherLoss)); 190 | 191 | dlnet.State = state; 192 | % Determine learning rate for time-based decay learning rate schedule. 193 | learnRate = initialLearnRate/(1 + decay*iteration); 194 | if isempty(velocities) 195 | velocities = packScalar(gradients, 0); 196 | learnRates = packScalar(gradients, learnRate); 197 | momentums = packScalar(gradients, momentum); 198 | L2Foctors = packScalar(gradients, 0.05); 199 | gradientMasks = packScalar(gradients, 1); 200 | % Let's lock some weights 201 | % Zero-bias dense layer only. 202 | % if epoch <= 5 203 | % gradientMasks{fingerprintLayerIdx,3} = {dlarray([zeros(18,18); ones(newClassesNum,fingerprintLen)])}; 204 | % else 205 | % gradientMasks{fingerprintLayerIdx,3} = {dlarray([ones(18,18); ones(newClassesNum,fingerprintLen)])}; 206 | % end 207 | 208 | % gradientMasks{fingerprintLayerIdx-1,3} = {dlarray([zeros(prevClassNum,20); ones(newClassesNum,20)])}; 209 | % gradientMasks{9+1,3} = {dlarray([zeros(prevClassNum,1); ones(newClassesNum,1)])}; 210 | 211 | % for k = 1:fingerprintLayerIdx-2 212 | % gradientMasks.Value{k}=dlarray(zeros(size(gradientMasks.Value{k}))); 213 | % end 214 | end 215 | %fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet,newClassesNum) 216 | [dlnet, velocities] = dlupdate(@sgdmFunctionL2, ... 217 | dlnet, gradients, velocities, ... 218 | learnRates, momentums, L2Foctors, gradientMasks); 219 | 220 | % Display the training progress. 221 | if plots == "training-progress" 222 | D = duration(0,0,toc(start),'Format','hh:mm:ss'); 223 | % figure(statusFigureIdx); 224 | if mod(iteration,20) == 0 225 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 226 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 227 | [~,~,oldCVLoss] = dlfeval(@modelGradientsOnWeights,dlnet,prevCX,prevCY); 228 | addpoints(lineNewCVAccuracy, iteration, newCvAccuracy); 229 | addpoints(lineOldCVAccuracy, iteration, oldCvAccuracy); 230 | addpoints(lineOldLossCV, iteration, double(gather(extractdata(oldCVLoss)))); 231 | end 232 | addpoints(lineNewLossTrain,iteration,double(gather(extractdata(loss)))) 233 | title(statusFigureAxis,"Epoch: " + epoch + ", Elapsed: " + string(D)) 234 | drawnow 235 | end 236 | end 237 | end 238 | 239 | 240 | % 241 | % figure 242 | % subplot(2,1,1) 243 | % imagesc(prevDlnet.Layers(fingerprintLayerIdx).Weights) 244 | % title('Old Finger prints'); 245 | % subplot(2,1,2) 246 | % imagesc(dlnet.Layers(fingerprintLayerIdx).Weights) 247 | % title('New Finger prints'); 248 | % 249 | figure 250 | weights = dlnet.Layers(14).Weights; 251 | similarMatrix = zeros(size(weights,1), size(weights,1)); 252 | for i = 1:size(weights,1) 253 | curWeight = weights(i,:); 254 | magCurWeight = sqrt(sum(curWeight.^2,2)); 255 | for j = 1:size(weights,1) 256 | nxtWeight = weights(j,:); 257 | magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 258 | similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 
259 | .*(curWeight./magCurWeight)); 260 | end 261 | end 262 | sumDeviationAnglesAll = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 263 | imagesc(similarMatrix) 264 | % 265 | % weights = dlnet.Layers(11).Weights(1:18,:); 266 | % similarMatrix = zeros(size(weights,1), size(weights,1)); 267 | % for i = 1:size(weights,1) 268 | % curWeight = weights(i,:); 269 | % magCurWeight = sqrt(sum(curWeight.^2,2)); 270 | % for j = 1:size(weights,1) 271 | % nxtWeight = weights(j,:); 272 | % magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 273 | % similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 274 | % .*(curWeight./magCurWeight)); 275 | % end 276 | % end 277 | % sumDeviationAnglesOld = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 278 | % 279 | % weights = dlnet.Layers(11).Weights(19:end,:); 280 | % similarMatrix = zeros(size(weights,1), size(weights,1)); 281 | % for i = 1:size(weights,1) 282 | % curWeight = weights(i,:); 283 | % magCurWeight = sqrt(sum(curWeight.^2,2)); 284 | % for j = 1:size(weights,1) 285 | % nxtWeight = weights(j,:); 286 | % magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 287 | % similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 288 | % .*(curWeight./magCurWeight)); 289 | % end 290 | % end 291 | % sumDeviationAnglesNew = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 292 | 293 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 294 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 295 | disp('CV accuracy a.f. cont. learning'); 296 | [newCvAccuracy, oldCvAccuracy] 297 | 298 | % figure 299 | % subplot(2,1,1) 300 | % imagesc(prevDlnet.Layers(2).Weights) 301 | % title('Old input dense'); 302 | % subplot(2,1,2) 303 | % imagesc(dlnet.Layers(2).Weights) 304 | % title('New input dense'); 305 | 306 | function [totalLoss,corrOld,corrNew,grad] = evalFp(existingFp, newFeatures, fp) 307 | magFp = sqrt(sum(fp.^2)); 308 | unitFp = fp ./ magFp; 309 | corrOld = mean(sum( existingFp./sqrt( sum((existingFp).^2,2)).*unitFp, 2 )); 310 | corrNew = mean(sum( newFeatures'./sqrt( sum((newFeatures').^2,2)).*unitFp, 2)); 311 | totalLoss = corrOld - 10*corrNew; 312 | grad = dlgradient(totalLoss, fp); 313 | end 314 | 315 | function fp = findPotentialFingerprint(existingFp, newFeatures) 316 | fp = mean(newFeatures,2)'; 317 | % magFp = sqrt(sum(fp.^2)); 318 | % unitFp = fp ./ magFp; 319 | record = []; 320 | for i = 1:20 321 | [totalLoss,corrOld,corrNew, grad] = dlfeval(@evalFp, (existingFp), (newFeatures), fp); 322 | record = [record;[totalLoss,corrOld,corrNew]]; 323 | fp = fp - 0.5.*grad; 324 | end 325 | % figure 326 | % plot(extractdata(record(:,1)),'LineWidth',1.5); 327 | % hold on; 328 | % plot(extractdata(record(:,2)),'LineWidth',1.5); 329 | % plot(extractdata(record(:,3)),'LineWidth',1.5); 330 | end 331 | 332 | function [gradients,state,loss,fisherLoss] = modelGradientsOnWeightsEWC(dlnet, dlX, Y,... 
333 | prevFisherInfo0, prevDlnet, newClassesNum, ewcLambda, fingerprintLayerIdx, prevClassNum) 334 | [dlYPred,state] = forward(dlnet,dlX); 335 | penalty = 0; 336 | scalarL2Factor = 0; 337 | if scalarL2Factor ~= 0 338 | paramLst = dlnet.Learnables.Value; 339 | for i = 1:size(paramLst,1) 340 | penalty = penalty + sum((paramLst{i}(:)).^2); 341 | end 342 | end 343 | fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet, newClassesNum, fingerprintLayerIdx)*ewcLambda/2; 344 | loss = crossentropy(dlYPred,Y) + scalarL2Factor*penalty + fisherLoss; 345 | gradients = dlgradient(loss, dlnet.Learnables); 346 | end 347 | 348 | function fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet, newClassesNum,fingerprintLayerIdx) 349 | prevWeights = prevDlnet.Learnables.Value; 350 | curWeights = dlnet.Learnables.Value; 351 | % fisherLossMatrix = {}; 352 | sumLoss = 0; 353 | elementCount = 1; 354 | for i = 1:size(prevWeights,1) 355 | if i >= fingerprintLayerIdx-1 356 | loss = ((prevWeights{i}-curWeights{i}(1:size(curWeights{i},1)-newClassesNum,:)).^2) .* prevFisherInfo0.Value{i}; 357 | % fisherLossMatrix{end+1} = loss; 358 | sumLoss = sumLoss + sum(loss(:)); 359 | elementCount = elementCount + numel(prevWeights{i}); 360 | else 361 | loss = ((prevWeights{i}-curWeights{i}).^2) .* prevFisherInfo0.Value{i}; 362 | % fisherLossMatrix{end+1} = loss; 363 | sumLoss = sumLoss + sum(loss(:)); 364 | elementCount = elementCount + numel(prevWeights{i}); 365 | end 366 | end 367 | fisherLoss = sumLoss; 368 | end 369 | 370 | function accuracy = cvAccuracy(dlnet, XTest, YTest, miniBatchSize, executionEnvironment) 371 | dlXTest = dlarray(XTest,'SSCB'); 372 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 373 | dlXTest = gpuArray(dlXTest); 374 | end 375 | dlYPred = modelPredictions(dlnet,dlXTest,miniBatchSize); 376 | [~,idx] = max(extractdata(dlYPred),[],1); 377 | YPred = (idx); 378 | accuracy = mean(YPred(:) == YTest(:)); 379 | end 380 | 381 | function dlYPred = modelPredictions(dlnet,dlX,miniBatchSize) 382 | numObservations = size(dlX,4); 383 | numIterations = ceil(numObservations / miniBatchSize); 384 | numClasses = size(dlnet.Layers(end-1).Weights,1); 385 | dlYPred = zeros(numClasses,numObservations,'like',dlX); 386 | for i = 1:numIterations 387 | idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations); 388 | dlYPred(:,idx) = predict(dlnet,dlX(:,:,:,idx)); 389 | end 390 | end 391 | 392 | function [gradients,state] = logModelGradientsOnWeights(dlnet,dlX) 393 | [dlYPred,state] = forward(dlnet,dlX); 394 | loglikelyhood = log(dlYPred-min(dlYPred(:))+1e-5); 395 | gradients = dlgradient(mean(loglikelyhood(:)),dlnet.Learnables); 396 | end 397 | 398 | function [gradients,state,loss] = modelGradientsOnWeights(dlnet,dlX,Y) 399 | [dlYPred,state] = forward(dlnet,dlX); 400 | penalty = 0; 401 | scalarL2Factor = 0; 402 | if scalarL2Factor ~= 0 403 | paramLst = dlnet.Learnables.Value; 404 | for i = 1:size(paramLst,1) 405 | penalty = penalty + sum((paramLst{i}(:)).^2); 406 | end 407 | end 408 | loss = crossentropy(dlYPred,Y) + scalarL2Factor*penalty; 409 | gradients = dlgradient(loss,dlnet.Learnables); 410 | %gradients = dlgradient(loss,dlnet.Learnables(4,:)); 411 | end 412 | 413 | function [params, velocityUpdates] = sgdmFunction(params, paramGradients,... 
414 | velocities, learnRates, momentums) 415 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e 416 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients; 417 | params = params - velocityUpdates; 418 | end 419 | 420 | function [params, velocityUpdates] = sgdmFunctionL2(params, rawParamGradients,... 421 | velocities, learnRates, momentums, L2Foctors, gradientMasks) 422 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e 423 | % https://towardsdatascience.com/intuitions-on-l1-and-l2-regularisation-235f2db4c261 424 | paramGradients = rawParamGradients + 2*L2Foctors.*params; 425 | %Please be noted that even if rawParamGradients = 0, L2 will still try 426 | %to reduce the magnitudes of parameters 427 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients; 428 | params = params - (velocityUpdates).*gradientMasks; 429 | end 430 | 431 | function tabVars = packScalar(target, scalar) 432 | % The matlabs' silly design results in such a strange function 433 | tabVars = target; 434 | for row = 1:size(tabVars(:,3),1) 435 | tabVars{row,3} = {... 436 | dlarray(... 437 | ones(size(tabVars.Value{row})).*scalar...%ones(size(tabVars(row,3).Value{1,1})).*scalar... 438 | )... 439 | }; 440 | end 441 | end -------------------------------------------------------------------------------- /ContinualLearning/WorkStage/FixrepExpanding.m: -------------------------------------------------------------------------------- 1 | 2 | %%%%% 3 | %Gather fisher information and weights from the old network 4 | prevClassNum = 0; 5 | if skipDAGNet == 0 6 | lgraph2 = layerGraph(net); 7 | lgraph2 = lgraph2.removeLayers('classify_1'); 8 | weights = net.Layers({net.Layers.Name}=="Fingerprints").Weights; 9 | numClasses = size(weights,1); 10 | prevClassNum = numClasses; 11 | fingerprintLayerIdx = 14; 12 | else 13 | lgraph2 = layerGraph(dlnet); 14 | numClasses = size(dlnet.Layers(13).Weights,1); 15 | prevClassNum = numClasses; 16 | fingerprintLayerIdx = 13; 17 | end 18 | prevDlnet = dlnetwork(lgraph2); 19 | prevWeights = prevDlnet.Learnables; 20 | dlnet = prevDlnet; 21 | prevCX = dlarray(single(cX),'SSCB'); 22 | prevCY = zeros(numClasses, size(prevCX,4), 'single'); 23 | for c = 1:numClasses 24 | prevCY(c,cY(:)==(c)) = 1; 25 | end 26 | executionEnvironment = "auto"; 27 | % If training on a GPU, then convert data to gpuArray. 28 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 29 | prevCX = gpuArray(prevCX); 30 | end 31 | [prevGradients0,state] = dlfeval(@logModelGradientsOnWeights,prevDlnet,prevCX); 32 | accuracy = cvAccuracy(prevDlnet, prevCX, (cY), 128, executionEnvironment) 33 | prevDlnet.State = state; 34 | prevFisherInfo0 = prevGradients0; 35 | for i = 1:size(prevFisherInfo0,1) 36 | prevFisherInfo0{i,3}={dlarray(exp(prevGradients0{i,3}{:}.^2))}; 37 | % prevFisherInfo0{i,3}={dlarray(exp(prevGradients0{i,3}{:}.^2))}; 38 | 39 | end 40 | 41 | newClassesNum = floor(size(unique(uY),1)); 42 | unknownClassLabels = unique(uY); 43 | idx = randperm(size(unknownClassLabels,1)); 44 | unknownClassLabels = unknownClassLabels(idx); 45 | unknownClassLabels = unknownClassLabels(1:newClassesNum); 46 | 47 | % Generate some initial fingerprints. 
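% --- Added commentary (not in the original source) ---
% As in Fixrep.m, each new class is seeded with the unit-normalised mean
% of its 'fc_bf_fp' feature vectors; the only difference is that the
% features are taken from the dlnetwork (predict + stripdims) rather
% than from a DAGNetwork activations call.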
48 | %figure 49 | % hold on; 50 | existingFingerprints = prevDlnet.Layers(fingerprintLayerIdx).Weights; 51 | newFingerprints = dlarray([]); 52 | continualLearnX = dlarray([]); 53 | continualLearnY = []; 54 | cursor = 1; 55 | for i = 1:newClassesNum 56 | selection = uY==unknownClassLabels(i); 57 | ux_i = uX(:,:,:,selection); 58 | uy_i = uY(selection); 59 | samplePerClass = size(uy_i, 1); 60 | continualLearnX(:,:,:,cursor:cursor+samplePerClass-1) = ux_i; 61 | continualLearnY(cursor:cursor+samplePerClass-1)=uy_i; 62 | cursor = cursor+samplePerClass; 63 | % aUx_i = squeeze(prevDlnet.predict(dlarray(ux_i,'SSCB'),'Outputs','fc_bf_fp')); 64 | % aUx_i = squeeze(activations(net,ux_i,'fc_bf_fp')); 65 | aUx_i = stripdims(squeeze(predict(dlnet,dlarray(ux_i,'SSCB'),'Outputs','fc_bf_fp'))); 66 | unitAUx_i=(aUx_i./sqrt(sum(aUx_i.^2)))'; 67 | fp_i = mean(aUx_i,2)'; 68 | magFp_i = sqrt(sum(fp_i.^2)); 69 | unitFp_i = fp_i ./ magFp_i; 70 | newFp = (unitFp_i); 71 | % newFp = findPotentialFingerprint(dlarray(existingFingerprints),dlarray(aUx_i)); 72 | newFingerprints(end+1,:) = newFp; 73 | % histogram(sum(unitFp_i.*unitFingerprints,2),[-1:0.2:1],'Normalization','probability'); 74 | % hold on; 75 | % histogram(sum(unitAUx_i.*unitFp_i,2),[-1:0.2:1],'Normalization','probability'); 76 | % legend('Correlation with existing fingerprints','Correlation with own samples'); 77 | end 78 | 79 | xNum = size(X,4); 80 | rndSeq = randperm(xNum); 81 | rndSeq = rndSeq(1:min(numel(rndSeq),5000)); 82 | 83 | continualLearnX = cat(4,continualLearnX,X(:,:,:,rndSeq)); 84 | continualLearnY = cat(2,continualLearnY,Y(rndSeq)'); 85 | 86 | randSeries = randperm(size(continualLearnY,2)); 87 | continualLearnX = continualLearnX(:,:,:,randSeries); 88 | continualLearnY = continualLearnY(randSeries); 89 | cvContinualLearnX = continualLearnX(:,:,:,floor(0.6*size(continualLearnX,4)):end); 90 | cvContinualLearnY = continualLearnY(floor(0.6*size(continualLearnY,2)):end); 91 | continualLearnX = continualLearnX(:,:,:,1:floor(0.6*size(continualLearnX,4))-1); 92 | continualLearnY = continualLearnY(1:floor(0.6*size(continualLearnY,2))-1); 93 | 94 | concatFingerprints = [[existingFingerprints; newFingerprints],... 95 | 0.001.*randn(prevClassNum+newClassesNum,newClassesNum)]; 96 | newFeatureEmbeddingLayer = fullyConnectedLayer(... 97 | prevClassNum+newClassesNum,'Weights',... 98 | [prevDlnet.Layers(12).Weights;... 99 | 0.001.*randn(newClassesNum,size(prevDlnet.Layers(12).Weights,2))],... 100 | 'Name','fc_bf_fp'); 101 | 102 | % newFingerprintLayer = FCLayer(numClasses,numClasses+newClassesNum,'Fingerprints',concatFingerprints); 103 | newFingerprintLayer = zeroBiasFCLayer(numClasses+newClassesNum,... 104 | prevClassNum+newClassesNum,'Fingerprints',concatFingerprints); 105 | 106 | % newFingerprintLayer.normMag = prevDlnet.Layers(4).normMag; 107 | % newFingerprintLayer.b1 = prevDlnet.Layers(4).b1; 108 | 109 | numClasses = prevClassNum + newClassesNum; 110 | 111 | %Build new network and start continual learning. 
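% --- Added commentary (not in the original source) ---
% Unlike Fixrep.m, this variant also expands the feature space:
% concatFingerprints pads every fingerprint with newClassesNum extra
% columns initialised to 0.001*randn, and fc_bf_fp gains the matching
% new output rows, so the embedding dimension and the class count grow
% together from prevClassNum to prevClassNum+newClassesNum.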
112 | lgraph2 = layerGraph(prevDlnet); 113 | % lgraph2 = lgraph2.removeLayers('classify_1'); 114 | lgraph2 = lgraph2.replaceLayer('Fingerprints', newFingerprintLayer); 115 | lgraph2 = lgraph2.replaceLayer('fc_bf_fp', newFeatureEmbeddingLayer); 116 | if skipDAGNet == 0 117 | lgraph2 = lgraph2.removeLayers('batchNorm_2'); 118 | lgraph2 = lgraph2.connectLayers('fc_bf_fp','Fingerprints'); 119 | end 120 | dlnet = dlnetwork(lgraph2); 121 | 122 | 123 | XTrain = continualLearnX; 124 | YTrain = continualLearnY; 125 | XTest = cuX; 126 | YTest = cuY; 127 | numEpochs = 3; 128 | miniBatchSize = 20; 129 | plots = "training-progress"; 130 | statusFigureIdx = []; 131 | statusFigureAxis = []; 132 | if plots == "training-progress" 133 | figure; 134 | statusFigureIdx = gcf; 135 | statusFigureAxis = gca; 136 | % Go into the documentation of animatedline for more color codes 137 | lineNewLossTrain = animatedline('Color', '#0072BD','LineWidth',1,'Marker','.','LineStyle','none'); 138 | lineNewCVAccuracy = animatedline('Color', '#D95319','LineWidth',1); 139 | lineOldCVAccuracy = animatedline('Color', '#EDB120','LineWidth',1); 140 | lineOldLossCV = animatedline('Color', '#7E2F8E','LineWidth',1,'Marker','.','LineStyle','none'); 141 | % ylim([0 inf]) 142 | ylim([0 2]) 143 | xlabel("Iteration") 144 | ylabel("Metrics") 145 | legend('New task loss','New task accuracy','Old task accuracy', 'Old task lost'); 146 | grid on 147 | end 148 | L2RegularizationFactor = 0.01; 149 | initialLearnRate = 0.01; 150 | decay = 0.01; 151 | momentum = 0.9; 152 | velocities = []; 153 | learnRates = []; 154 | momentums = []; 155 | gradientMasks = []; 156 | numObservations = numel(YTrain); 157 | numIterationsPerEpoch = floor(numObservations./miniBatchSize); 158 | iteration = 0; 159 | start = tic; 160 | classes = categorical(YTrain); 161 | 162 | % 163 | prevCY = [prevCY;zeros(newClassesNum,size(prevCY,2))]; 164 | 165 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 166 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 167 | disp('CV accuracy b.f. cont. learning'); 168 | [newCvAccuracy, oldCvAccuracy] 169 | 170 | % Loop over epochs. 171 | fisherLossLst = []; 172 | for epoch = 1:numEpochs 173 | idx = randperm(numel(YTrain)); 174 | XTrain = XTrain(:,:,:,idx); 175 | YTrain = YTrain(idx); 176 | % Loop over mini-batches. 177 | for i = 1:numIterationsPerEpoch 178 | iteration = iteration + 1; 179 | % Read mini-batch of data and convert the labels to dummy 180 | % variables. 181 | idx = (i-1)*miniBatchSize+1:i*miniBatchSize; 182 | Xb = XTrain(:,:,:,idx); 183 | Yb = zeros(numClasses, miniBatchSize, 'single'); 184 | for c = 1:numClasses 185 | Yb(c,YTrain(idx)==(c)) = 1; 186 | end 187 | % Convert mini-batch of data to dlarray. 188 | dlX = dlarray(single(Xb),'SSCB'); 189 | % If training on a GPU, then convert data to gpuArray. 190 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 191 | dlX = gpuArray(dlX); 192 | end 193 | % Evaluate the model gradients, state, and loss using dlfeval and the 194 | % modelGradients function and update the network state. 195 | % [gradients,state,loss] = dlfeval(@modelGradientsOnWeights,dlnet,dlX,Yb); 196 | % if epoch <= 5 197 | % [gradients,state,loss,fisherLoss] = dlfeval(@modelGradientsOnWeightsEWC,dlnet,dlX, Yb,... 198 | % prevFisherInfo0, prevDlnet, newClassesNum, 0, fingerprintLayerIdx); 199 | % else 200 | % [gradients,state,loss,fisherLoss] = dlfeval(@modelGradientsOnWeightsEWC,dlnet,dlX, Yb,... 
201 | % prevFisherInfo0, prevDlnet, newClassesNum, 1, fingerprintLayerIdx); 202 | % end 203 | [gradients,state,loss,fisherLoss] = dlfeval(@modelGradientsOnWeightsEWC,dlnet,dlX, Yb,... 204 | prevFisherInfo0, prevDlnet, newClassesNum, 0, fingerprintLayerIdx,prevClassNum); 205 | 206 | fisherLossLst(end+1)=(fisherLoss); 207 | 208 | dlnet.State = state; 209 | % Determine learning rate for time-based decay learning rate schedule. 210 | learnRate = initialLearnRate/(1 + decay*iteration); 211 | if isempty(velocities) 212 | velocities = packScalar(gradients, 0); 213 | learnRates = packScalar(gradients, learnRate); 214 | momentums = packScalar(gradients, momentum); 215 | L2Foctors = packScalar(gradients, 0.05); 216 | gradientMasks = packScalar(gradients, 1); 217 | % Let's lock some weights 218 | % Zero-bias dense layer only. 219 | % if epoch <= 5 220 | % gradientMasks{fingerprintLayerIdx,3} = {dlarray([zeros(18,18); ones(newClassesNum,fingerprintLen)])}; 221 | % else 222 | % gradientMasks{fingerprintLayerIdx,3} = {dlarray([ones(18,18); ones(newClassesNum,fingerprintLen)])}; 223 | % end 224 | 225 | % gradientMasks{fingerprintLayerIdx-1,3} = {dlarray([zeros(prevClassNum,20); ones(newClassesNum,20)])}; 226 | % gradientMasks{9+1,3} = {dlarray([zeros(prevClassNum,1); ones(newClassesNum,1)])}; 227 | 228 | % for k = 1:fingerprintLayerIdx-2 229 | % gradientMasks.Value{k}=dlarray(zeros(size(gradientMasks.Value{k}))); 230 | % end 231 | end 232 | %fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet,newClassesNum) 233 | [dlnet, velocities] = dlupdate(@sgdmFunctionL2, ... 234 | dlnet, gradients, velocities, ... 235 | learnRates, momentums, L2Foctors, gradientMasks); 236 | 237 | % Display the training progress. 238 | if plots == "training-progress" 239 | D = duration(0,0,toc(start),'Format','hh:mm:ss'); 240 | % figure(statusFigureIdx); 241 | if mod(iteration,20) == 0 242 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 243 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 244 | [~,~,oldCVLoss] = dlfeval(@modelGradientsOnWeights,dlnet,prevCX,prevCY); 245 | addpoints(lineNewCVAccuracy, iteration, newCvAccuracy); 246 | addpoints(lineOldCVAccuracy, iteration, oldCvAccuracy); 247 | addpoints(lineOldLossCV, iteration, double(gather(extractdata(oldCVLoss)))); 248 | end 249 | addpoints(lineNewLossTrain,iteration,double(gather(extractdata(loss)))) 250 | title(statusFigureAxis,"Epoch: " + epoch + ", Elapsed: " + string(D)) 251 | drawnow 252 | end 253 | end 254 | end 255 | 256 | 257 | % 258 | % figure 259 | % subplot(2,1,1) 260 | % imagesc(prevDlnet.Layers(fingerprintLayerIdx).Weights) 261 | % title('Old Finger prints'); 262 | % subplot(2,1,2) 263 | % imagesc(dlnet.Layers(fingerprintLayerIdx).Weights) 264 | % title('New Finger prints'); 265 | % 266 | figure 267 | weights = dlnet.Layers(13).Weights; 268 | similarMatrix = zeros(size(weights,1), size(weights,1)); 269 | for i = 1:size(weights,1) 270 | curWeight = weights(i,:); 271 | magCurWeight = sqrt(sum(curWeight.^2,2)); 272 | for j = 1:size(weights,1) 273 | nxtWeight = weights(j,:); 274 | magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 275 | similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 
276 | .*(curWeight./magCurWeight)); 277 | end 278 | end 279 | sumDeviationAnglesAll = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 280 | imagesc(similarMatrix) 281 | % 282 | % weights = dlnet.Layers(11).Weights(1:18,:); 283 | % similarMatrix = zeros(size(weights,1), size(weights,1)); 284 | % for i = 1:size(weights,1) 285 | % curWeight = weights(i,:); 286 | % magCurWeight = sqrt(sum(curWeight.^2,2)); 287 | % for j = 1:size(weights,1) 288 | % nxtWeight = weights(j,:); 289 | % magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 290 | % similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 291 | % .*(curWeight./magCurWeight)); 292 | % end 293 | % end 294 | % sumDeviationAnglesOld = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 295 | % 296 | % weights = dlnet.Layers(11).Weights(19:end,:); 297 | % similarMatrix = zeros(size(weights,1), size(weights,1)); 298 | % for i = 1:size(weights,1) 299 | % curWeight = weights(i,:); 300 | % magCurWeight = sqrt(sum(curWeight.^2,2)); 301 | % for j = 1:size(weights,1) 302 | % nxtWeight = weights(j,:); 303 | % magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 304 | % similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 305 | % .*(curWeight./magCurWeight)); 306 | % end 307 | % end 308 | % sumDeviationAnglesNew = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 309 | 310 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 311 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 312 | disp('CV accuracy a.f. cont. learning'); 313 | [newCvAccuracy, oldCvAccuracy] 314 | 315 | % figure 316 | % subplot(2,1,1) 317 | % imagesc(prevDlnet.Layers(2).Weights) 318 | % title('Old input dense'); 319 | % subplot(2,1,2) 320 | % imagesc(dlnet.Layers(2).Weights) 321 | % title('New input dense'); 322 | 323 | function [totalLoss,corrOld,corrNew,grad] = evalFp(existingFp, newFeatures, fp) 324 | magFp = sqrt(sum(fp.^2)); 325 | unitFp = fp ./ magFp; 326 | corrOld = mean(sum( existingFp./sqrt( sum((existingFp).^2,2)).*unitFp, 2 )); 327 | corrNew = mean(sum( newFeatures'./sqrt( sum((newFeatures').^2,2)).*unitFp, 2)); 328 | totalLoss = corrOld - 10*corrNew; 329 | grad = dlgradient(totalLoss, fp); 330 | end 331 | 332 | function fp = findPotentialFingerprint(existingFp, newFeatures) 333 | fp = mean(newFeatures,2)'; 334 | % magFp = sqrt(sum(fp.^2)); 335 | % unitFp = fp ./ magFp; 336 | record = []; 337 | for i = 1:20 338 | [totalLoss,corrOld,corrNew, grad] = dlfeval(@evalFp, (existingFp), (newFeatures), fp); 339 | record = [record;[totalLoss,corrOld,corrNew]]; 340 | fp = fp - 0.5.*grad; 341 | end 342 | % figure 343 | % plot(extractdata(record(:,1)),'LineWidth',1.5); 344 | % hold on; 345 | % plot(extractdata(record(:,2)),'LineWidth',1.5); 346 | % plot(extractdata(record(:,3)),'LineWidth',1.5); 347 | end 348 | 349 | function [gradients,state,loss,fisherLoss] = modelGradientsOnWeightsEWC(dlnet, dlX, Y,... 
350 | prevFisherInfo0, prevDlnet, newClassesNum, ewcLambda, fingerprintLayerIdx, prevClassNum)
351 | [dlYPred,state] = forward(dlnet,dlX);
352 | penalty = 0;
353 | scalarL2Factor = 0;
354 | if scalarL2Factor ~= 0
355 | paramLst = dlnet.Learnables.Value;
356 | for i = 1:size(paramLst,1)
357 | penalty = penalty + sum((paramLst{i}(:)).^2);
358 | end
359 | end
360 | fisherLoss = 0; % EWC penalty disabled in this variant of the loss
361 | % fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet, newClassesNum, fingerprintLayerIdx)*ewcLambda/2;
362 | loss = crossentropy(dlYPred,Y) + scalarL2Factor*penalty;
363 | gradients = dlgradient(loss, dlnet.Learnables);
364 | end
365 | 
366 | function fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet, newClassesNum,fingerprintLayerIdx)
367 | prevWeights = prevDlnet.Learnables.Value;
368 | curWeights = dlnet.Learnables.Value;
369 | % fisherLossMatrix = {};
370 | sumLoss = 0;
371 | elementCount = 1;
372 | for i = 1:size(prevWeights,1)
373 | if i >= fingerprintLayerIdx-1 % layers widened for new classes: compare old rows only
374 | loss = ((prevWeights{i}-curWeights{i}(1:size(curWeights{i},1)-newClassesNum,:)).^2) .* prevFisherInfo0.Value{i};
375 | % fisherLossMatrix{end+1} = loss;
376 | sumLoss = sumLoss + sum(loss(:));
377 | elementCount = elementCount + numel(prevWeights{i});
378 | else
379 | loss = ((prevWeights{i}-curWeights{i}).^2) .* prevFisherInfo0.Value{i};
380 | % fisherLossMatrix{end+1} = loss;
381 | sumLoss = sumLoss + sum(loss(:));
382 | elementCount = elementCount + numel(prevWeights{i});
383 | end
384 | end
385 | fisherLoss = sumLoss;
386 | end
387 | 
388 | function accuracy = cvAccuracy(dlnet, XTest, YTest, miniBatchSize, executionEnvironment)
389 | dlXTest = dlarray(XTest,'SSCB');
390 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
391 | dlXTest = gpuArray(dlXTest);
392 | end
393 | dlYPred = modelPredictions(dlnet,dlXTest,miniBatchSize);
394 | [~,idx] = max(extractdata(dlYPred),[],1);
395 | YPred = (idx);
396 | accuracy = mean(YPred(:) == YTest(:));
397 | end
398 | 
399 | function dlYPred = modelPredictions(dlnet,dlX,miniBatchSize)
400 | numObservations = size(dlX,4);
401 | numIterations = ceil(numObservations / miniBatchSize);
402 | numClasses = size(dlnet.Layers(end-1).Weights,1);
403 | dlYPred = zeros(numClasses,numObservations,'like',dlX);
404 | for i = 1:numIterations
405 | idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
406 | dlYPred(:,idx) = predict(dlnet,dlX(:,:,:,idx));
407 | end
408 | end
409 | 
410 | function [gradients,state] = logModelGradientsOnWeights(dlnet,dlX)
411 | [dlYPred,state] = forward(dlnet,dlX);
412 | logLikelihood = log(dlYPred-min(dlYPred(:))+1e-5); % shift keeps the log argument positive
413 | gradients = dlgradient(mean(logLikelihood(:)),dlnet.Learnables);
414 | end
415 | 
416 | function [gradients,state,loss] = modelGradientsOnWeights(dlnet,dlX,Y)
417 | [dlYPred,state] = forward(dlnet,dlX);
418 | penalty = 0;
419 | scalarL2Factor = 0;
420 | if scalarL2Factor ~= 0
421 | paramLst = dlnet.Learnables.Value;
422 | for i = 1:size(paramLst,1)
423 | penalty = penalty + sum((paramLst{i}(:)).^2);
424 | end
425 | end
426 | loss = crossentropy(dlYPred,Y) + scalarL2Factor*penalty;
427 | gradients = dlgradient(loss,dlnet.Learnables);
428 | %gradients = dlgradient(loss,dlnet.Learnables(4,:));
429 | end
430 | 
431 | function [params, velocityUpdates] = sgdmFunction(params, paramGradients,...
432 | velocities, learnRates, momentums)
433 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e
434 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients;
435 | params = params - velocityUpdates;
436 | end
437 | 
438 | function [params, velocityUpdates] = sgdmFunctionL2(params, rawParamGradients,...
439 | velocities, learnRates, momentums, L2Foctors, gradientMasks)
440 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e
441 | % https://towardsdatascience.com/intuitions-on-l1-and-l2-regularisation-235f2db4c261
442 | paramGradients = rawParamGradients + 2*L2Foctors.*params;
443 | % Note that even when rawParamGradients is zero, the L2 term still
444 | % shrinks the parameter magnitudes
445 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients;
446 | params = params - (velocityUpdates).*gradientMasks;
447 | end
448 | 
449 | function tabVars = packScalar(target, scalar)
450 | % Learnables is a table whose Value column holds nested cells, so
451 | % broadcasting a scalar to every parameter requires this helper.
452 | tabVars = target;
453 | for row = 1:size(tabVars(:,3),1)
454 | tabVars{row,3} = {...
455 | dlarray(...
456 | ones(size(tabVars.Value{row})).*scalar...%ones(size(tabVars(row,3).Value{1,1})).*scalar...
457 | )...
458 | };
459 | end
460 | end
--------------------------------------------------------------------------------
/ContinualLearning/WorkStage/ILDashScript.m:
--------------------------------------------------------------------------------
1 | addpath('../../')
2 | addpath('../')
3 | clear;
4 | close all;
5 | tic;load('./adsb-107loaded.mat');toc;
6 | load('./adsb-107Net20.mat')
7 | % rng default;
8 | 
9 | stepSz = 20;
10 | newCVaccLst = [];
11 | oldCVaccLst = [];
12 | storedDlnetModels = {};
13 | sumDevAngles = [];
14 | lockOldFps = 0;
15 | AdaptiveEWC = 0;
16 | for unkBound = 20:stepSz:80
17 | X2 = X;
18 | Y2 = Y;
19 | cX2 = cX;
20 | cY2 = cY;
21 | [unkBound + 1, unkBound + stepSz] % display the class range being added
22 | uX = X(:,:,:,logical(double(Y>unkBound) .* double(Y<=unkBound + stepSz)));
23 | uY = Y(logical(double(Y>unkBound) .* double(Y <= unkBound + stepSz)));
24 | cuX = cX(:,:,:,logical(double(cY>unkBound) .* double(cY<= unkBound + stepSz)));
25 | cuY = cY(logical(double(cY>unkBound) .* double(cY<= unkBound + stepSz)));
26 | 
27 | X = X(:,:,:,Y <= unkBound);
28 | cX = cX(:,:,:,cY <= unkBound);
29 | Y = Y(Y <= unkBound);
30 | cY = cY(cY <= unkBound);
31 | if unkBound == 20
32 | skipDAGNet = 0;
33 | else
34 | skipDAGNet = 1;
35 | end
36 | % noCSIL_Finetune_EWC
37 | % noCS_adaptiveEwc
38 | noCSI
39 | 
40 | X = X2;
41 | Y = Y2;
42 | cX = cX2;
43 | cY = cY2;
44 | newCVaccLst(end+1) = newCvAccuracy;
45 | oldCVaccLst(end+1) = oldCvAccuracy;
46 | 
47 | if unkBound == 20
48 | storedDlnetModels{end+1} = prevDlnet;
49 | end
50 | storedDlnetModels{end+1} = dlnet;
51 | sumDevAngles(end+1) = sumDeviationAnglesAll;
52 | end
53 | %
54 | % X2 = X;
55 | % Y2 = Y;
56 | % cX2 = cX;
57 | % cY2 = cY;
58 | %
59 | % uX = X(:,:,:,logical(double(Y>20) .* double(Y<=40)));
60 | % uY = Y(logical(double(Y>20) .* double(Y<=40)));
61 | % cuX = cX(:,:,:,logical(double(cY>20) .* double(cY<=40)));
62 | % cuY = cY(logical(double(cY>20) .* double(cY<=40)));
63 | %
64 | % X = X(:,:,:,Y <=20);
65 | % cX = cX(:,:,:,cY <= 20);
66 | % Y = Y(Y <= 20);
67 | % cY = cY(cY <= 20);
68 | %
69 | % skipDAGNet = 0;
70 | % noCSIL_Finetune_EWC
71 | %
72 | % X = X2;
73 | % Y = Y2;
74 | % cX = cX2;
75 | % cY = cY2;
76 | %
77 | % %%
78 | % uX = X(:,:,:,logical(double(Y>40) .* double(Y<=60)));
79 | % uY = 
Y(logical(double(Y>40) .* double(Y<=60))); 80 | % cuX = cX(:,:,:,logical(double(cY>40) .* double(cY<=60))); 81 | % cuY = cY(logical(double(cY>40) .* double(cY<=60))); 82 | % 83 | % X = X(:,:,:,Y <=40); 84 | % cX = cX(:,:,:,cY <= 40); 85 | % Y = Y(Y <= 40); 86 | % cY = cY(cY <= 40); 87 | % 88 | % skipDAGNet = 1; 89 | % noCSIL_Finetune_EWC 90 | % 91 | % X = X2; 92 | % Y = Y2; 93 | % cX = cX2; 94 | % cY = cY2; 95 | % 96 | % %% 97 | % uX = X(:,:,:,logical(double(Y>60) .* double(Y<=80))); 98 | % uY = Y(logical(double(Y>60) .* double(Y<=80))); 99 | % cuX = cX(:,:,:,logical(double(cY>60) .* double(cY<=80))); 100 | % cuY = cY(logical(double(cY>60) .* double(cY<=80))); 101 | % 102 | % X = X(:,:,:,Y <=60); 103 | % cX = cX(:,:,:,cY <= 60); 104 | % Y = Y(Y <= 60); 105 | % cY = cY(cY <= 60); 106 | % 107 | % skipDAGNet = 1; 108 | % noCSIL_Finetune_EWC 109 | % 110 | % X = X2; 111 | % Y = Y2; 112 | % cX = cX2; 113 | % cY = cY2; 114 | % %% 115 | % uX = X(:,:,:,logical(double(Y>80) .* double(Y<=100))); 116 | % uY = Y(logical(double(Y>80) .* double(Y<=100))); 117 | % cuX = cX(:,:,:,logical(double(cY>80) .* double(cY<=100))); 118 | % cuY = cY(logical(double(cY>80) .* double(cY<=100))); 119 | % 120 | % X = X(:,:,:,Y <=80); 121 | % cX = cX(:,:,:,cY <= 80); 122 | % Y = Y(Y <= 80); 123 | % cY = cY(cY <= 80); 124 | % 125 | % skipDAGNet = 1; 126 | % noCSIL_Finetune_EWC 127 | % 128 | % X = X2; 129 | % Y = Y2; 130 | % cX = cX2; 131 | % cY = cY2; 132 | -------------------------------------------------------------------------------- /ContinualLearning/WorkStage/ILDashScriptCSNet.m: -------------------------------------------------------------------------------- 1 | addpath('../../') 2 | addpath('../') 3 | clear; 4 | close all; 5 | tic;load('./adsb-107loaded.mat');toc; 6 | load('./adsb-107CSNet20.mat'); 7 | 8 | X2 = X; 9 | Y2 = Y; 10 | cX2 = cX; 11 | cY2 = cY; 12 | 13 | uX = X(:,:,:,logical(double(Y>20) .* double(Y<=40))); 14 | uY = Y(logical(double(Y>20) .* double(Y<=40))); 15 | cuX = cX(:,:,:,logical(double(cY>20) .* double(cY<=40))); 16 | cuY = cY(logical(double(cY>20) .* double(cY<=40))); 17 | 18 | X = X(:,:,:,Y <=20); 19 | cX = cX(:,:,:,cY <= 20); 20 | Y = Y(Y <= 20); 21 | cY = cY(cY <= 20); 22 | 23 | skipDAGNet = 0; 24 | % CSIL 25 | % CSILPast5000 26 | % CSILLockOldFPs 27 | % CSILLockOldFPsPast5000 28 | % CSILLockOldFPsChessBoard 29 | CSILLockOldFPsChessBoardPast5000 30 | 31 | X = X2; 32 | Y = Y2; 33 | cX = cX2; 34 | cY = cY2; 35 | 36 | %% 37 | uX = X(:,:,:,logical(double(Y>40) .* double(Y<=60))); 38 | uY = Y(logical(double(Y>40) .* double(Y<=60))); 39 | cuX = cX(:,:,:,logical(double(cY>40) .* double(cY<=60))); 40 | cuY = cY(logical(double(cY>40) .* double(cY<=60))); 41 | 42 | X = X(:,:,:,Y <=40); 43 | cX = cX(:,:,:,cY <= 40); 44 | Y = Y(Y <= 40); 45 | cY = cY(cY <= 40); 46 | 47 | skipDAGNet = 1; 48 | % CSIL 49 | % CSILPast5000 50 | % CSILLockOldFPs 51 | % CSILLockOldFPsPast5000 52 | % CSILLockOldFPsChessBoard 53 | CSILLockOldFPsChessBoardPast5000 54 | 55 | X = X2; 56 | Y = Y2; 57 | cX = cX2; 58 | cY = cY2; 59 | 60 | %% 61 | uX = X(:,:,:,logical(double(Y>60) .* double(Y<=80))); 62 | uY = Y(logical(double(Y>60) .* double(Y<=80))); 63 | cuX = cX(:,:,:,logical(double(cY>60) .* double(cY<=80))); 64 | cuY = cY(logical(double(cY>60) .* double(cY<=80))); 65 | 66 | X = X(:,:,:,Y <=60); 67 | cX = cX(:,:,:,cY <= 60); 68 | Y = Y(Y <= 60); 69 | cY = cY(cY <= 60); 70 | 71 | skipDAGNet = 1; 72 | % CSIL 73 | % CSILPast5000 74 | % CSILLockOldFPs 75 | % CSILLockOldFPsPast5000 76 | % CSILLockOldFPsChessBoard 77 | 
CSILLockOldFPsChessBoardPast5000 78 | 79 | X = X2; 80 | Y = Y2; 81 | cX = cX2; 82 | cY = cY2; 83 | %% 84 | uX = X(:,:,:,logical(double(Y>80) .* double(Y<=100))); 85 | uY = Y(logical(double(Y>80) .* double(Y<=100))); 86 | cuX = cX(:,:,:,logical(double(cY>80) .* double(cY<=100))); 87 | cuY = cY(logical(double(cY>80) .* double(cY<=100))); 88 | 89 | X = X(:,:,:,Y <=80); 90 | cX = cX(:,:,:,cY <= 80); 91 | Y = Y(Y <= 80); 92 | cY = cY(cY <= 80); 93 | 94 | skipDAGNet = 1; 95 | % CSIL 96 | % CSILPast5000 97 | % CSILLockOldFPs 98 | % CSILLockOldFPsPast5000 99 | % CSILLockOldFPsChessBoard 100 | CSILLockOldFPsChessBoardPast5000 101 | 102 | X = X2; 103 | Y = Y2; 104 | cX = cX2; 105 | cY = cY2; 106 | -------------------------------------------------------------------------------- /ContinualLearning/WorkStage/ILDashScriptFixRep.m: -------------------------------------------------------------------------------- 1 | addpath('../../') 2 | addpath('../') 3 | clear; 4 | close all; 5 | tic;load('./adsb-107loaded.mat');toc; 6 | load('./adsb-107Net20.mat') 7 | 8 | X2 = X; 9 | Y2 = Y; 10 | cX2 = cX; 11 | cY2 = cY; 12 | 13 | uX = X(:,:,:,logical(double(Y>20) .* double(Y<=40))); 14 | uY = Y(logical(double(Y>20) .* double(Y<=40))); 15 | cuX = cX(:,:,:,logical(double(cY>20) .* double(cY<=40))); 16 | cuY = cY(logical(double(cY>20) .* double(cY<=40))); 17 | 18 | X = X(:,:,:,Y <=20); 19 | cX = cX(:,:,:,cY <= 20); 20 | Y = Y(Y <= 20); 21 | cY = cY(cY <= 20); 22 | 23 | skipDAGNet = 0; 24 | FixrepExpanding 25 | 26 | X = X2; 27 | Y = Y2; 28 | cX = cX2; 29 | cY = cY2; 30 | 31 | %% 32 | uX = X(:,:,:,logical(double(Y>40) .* double(Y<=60))); 33 | uY = Y(logical(double(Y>40) .* double(Y<=60))); 34 | cuX = cX(:,:,:,logical(double(cY>40) .* double(cY<=60))); 35 | cuY = cY(logical(double(cY>40) .* double(cY<=60))); 36 | 37 | X = X(:,:,:,Y <=40); 38 | cX = cX(:,:,:,cY <= 40); 39 | Y = Y(Y <= 40); 40 | cY = cY(cY <= 40); 41 | 42 | skipDAGNet = 1; 43 | FixrepExpanding 44 | 45 | X = X2; 46 | Y = Y2; 47 | cX = cX2; 48 | cY = cY2; 49 | 50 | %% 51 | uX = X(:,:,:,logical(double(Y>60) .* double(Y<=80))); 52 | uY = Y(logical(double(Y>60) .* double(Y<=80))); 53 | cuX = cX(:,:,:,logical(double(cY>60) .* double(cY<=80))); 54 | cuY = cY(logical(double(cY>60) .* double(cY<=80))); 55 | 56 | X = X(:,:,:,Y <=60); 57 | cX = cX(:,:,:,cY <= 60); 58 | Y = Y(Y <= 60); 59 | cY = cY(cY <= 60); 60 | 61 | skipDAGNet = 1; 62 | FixrepExpanding 63 | 64 | X = X2; 65 | Y = Y2; 66 | cX = cX2; 67 | cY = cY2; 68 | %% 69 | uX = X(:,:,:,logical(double(Y>80) .* double(Y<=100))); 70 | uY = Y(logical(double(Y>80) .* double(Y<=100))); 71 | cuX = cX(:,:,:,logical(double(cY>80) .* double(cY<=100))); 72 | cuY = cY(logical(double(cY>80) .* double(cY<=100))); 73 | 74 | X = X(:,:,:,Y <=80); 75 | cX = cX(:,:,:,cY <= 80); 76 | Y = Y(Y <= 80); 77 | cY = cY(cY <= 80); 78 | 79 | skipDAGNet = 1; 80 | FixrepExpanding 81 | 82 | X = X2; 83 | Y = Y2; 84 | cX = cX2; 85 | cY = cY2; 86 | -------------------------------------------------------------------------------- /ContinualLearning/WorkStage/Readme.md: -------------------------------------------------------------------------------- 1 | Multi-stage incremental learning, remember to download the preprocessed dataset './adsb-107loaded.mat' from [IEEE Dataport](https://ieee-dataport.org/documents/ads-b-signals-records-non-cryptographic-identification-and-incremental-learning), Control & initialization scripts are in ILDashScript*.m 2 | -------------------------------------------------------------------------------- 
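A note on the staging idiom these ILDashScript*.m drivers share: each 20-class stage is selected with logical(double(Y>a) .* double(Y<=b)), which multiplies two 0/1 vectors and casts the result back to logical. That is exactly an elementwise AND, so the following toy check (illustrative values, editor's sketch) always returns true:

rng(0);
Y  = randi(100, 1, 12);                          % toy labels in 1..100
a  = 20; b = 40;
m1 = logical(double(Y > a) .* double(Y <= b));   % masking style used in the scripts
m2 = (Y > a) & (Y <= b);                         % equivalent, simpler form
isequal(m1, m2)                                  % true

The & form is usually preferred for readability, but the two are interchangeable for numeric label vectors like these.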
/ContinualLearning/WorkStage/adsb-107CSNet20.mat: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcwhy/CSIL/8ce8637daf4dc60eeb1c56bff64c050c5b2353e9/ContinualLearning/WorkStage/adsb-107CSNet20.mat -------------------------------------------------------------------------------- /ContinualLearning/WorkStage/noCSI-kd-ewc.m: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | %%%%% 8 | %Gather fisher information and weights from the old network 9 | prevClassNum = 0; 10 | if skipDAGNet == 0 11 | lgraph2 = layerGraph(net); 12 | lgraph2 = lgraph2.removeLayers('classify_1'); 13 | weights = net.Layers({net.Layers.Name}=="Fingerprints").Weights; 14 | numClasses = size(weights,1); 15 | prevClassNum = numClasses; 16 | else 17 | lgraph2 = layerGraph(dlnet); 18 | numClasses = size(dlnet.Layers(14).Weights,1); 19 | prevClassNum = numClasses; 20 | end 21 | prevDlnet = dlnetwork(lgraph2); 22 | prevWeights = prevDlnet.Learnables; 23 | prevCX = dlarray(single(cX),'SSCB'); 24 | prevCY = zeros(numClasses, size(prevCX,4), 'single'); 25 | for c = 1:numClasses 26 | prevCY(c,cY(:)==(c)) = 1; 27 | end 28 | executionEnvironment = "auto"; 29 | % If training on a GPU, then convert data to gpuArray. 30 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 31 | prevCX = gpuArray(prevCX); 32 | end 33 | [prevGradients0,state] = dlfeval(@logModelGradientsOnWeights,prevDlnet,prevCX); 34 | accuracy = cvAccuracy(prevDlnet, prevCX, (cY), 128, executionEnvironment) 35 | prevDlnet.State = state; 36 | prevFisherInfo0 = prevGradients0; 37 | for i = 1:size(prevFisherInfo0,1) 38 | prevFisherInfo0{i,3}={dlarray(exp(prevGradients0{i,3}{:}.^2))}; 39 | % prevFisherInfo0{i,3}={dlarray(exp(prevGradients0{i,3}{:}.^2))}; 40 | 41 | end 42 | fingerprintLayerIdx = 14; 43 | 44 | newClassesNum = floor(size(unique(uY),1)); 45 | unknownClassLabels = unique(uY); 46 | idx = randperm(size(unknownClassLabels,1)); 47 | unknownClassLabels = unknownClassLabels(idx); 48 | unknownClassLabels = unknownClassLabels(1:newClassesNum); 49 | 50 | % Generate some initial fingerprints. 
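The block above estimates per-parameter importance for EWC: it takes gradients of a mean log-likelihood surrogate over the old cross-validation set and stores exp(g.^2), so every parameter receives an importance weight of at least 1 (classic EWC uses g.^2 without the exp). calcFisherLoss further down turns these weights into a quadratic anchor around the old weights. A scalar-level editor's sketch with toy numbers:

g        = [0.05; -0.8; 1.5];          % toy gradients of the log-likelihood
F        = exp(g.^2);                  % importance, as computed above (exp of squared grad)
thetaOld = [1.0; 2.0; 3.0];            % weights of the previous network
theta    = [1.1; 2.0; 2.6];            % weights after some continual-learning updates
ewcLambda = 1;
fisherPenalty = (ewcLambda/2) * sum(F .* (theta - thetaOld).^2)
% Because exp(g.^2) >= 1 everywhere, even low-gradient parameters are
% anchored to some degree; classic F = g.^2 would let them drift freely.

The fingerprint initialization below then mirrors the one in MNISTCustomLoop.m.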
51 | %figure 52 | % hold on; 53 | existingFingerprints = prevDlnet.Layers(fingerprintLayerIdx).Weights; 54 | newFingerprints = dlarray([]); 55 | continualLearnX = dlarray([]); 56 | continualLearnY = []; 57 | cursor = 1; 58 | for i = 1:newClassesNum 59 | selection = uY==unknownClassLabels(i); 60 | ux_i = uX(:,:,:,selection); 61 | uy_i = uY(selection); 62 | samplePerClass = size(uy_i, 1); 63 | continualLearnX(:,:,:,cursor:cursor+samplePerClass-1) = ux_i; 64 | continualLearnY(cursor:cursor+samplePerClass-1)=uy_i; 65 | cursor = cursor+samplePerClass; 66 | % aUx_i = squeeze(prevDlnet.predict(dlarray(ux_i,'SSCB'),'Outputs','fc_bf_fp')); 67 | aUx_i = squeeze(activations(net,ux_i,'fc_bf_fp')); 68 | unitAUx_i=(aUx_i./sqrt(sum(aUx_i.^2)))'; 69 | fp_i = mean(aUx_i,2)'; 70 | magFp_i = sqrt(sum(fp_i.^2)); 71 | unitFp_i = fp_i ./ magFp_i; 72 | newFp = (unitFp_i); 73 | % newFp = findPotentialFingerprint(dlarray(existingFingerprints),dlarray(aUx_i)); 74 | newFingerprints(end+1,:) = newFp; 75 | % histogram(sum(unitFp_i.*unitFingerprints,2),[-1:0.2:1],'Normalization','probability'); 76 | % hold on; 77 | % histogram(sum(unitAUx_i.*unitFp_i,2),[-1:0.2:1],'Normalization','probability'); 78 | % legend('Correlation with existing fingerprints','Correlation with own samples'); 79 | end 80 | randSeries = randperm(size(continualLearnY,2)); 81 | continualLearnX = continualLearnX(:,:,:,randSeries); 82 | continualLearnY = continualLearnY(randSeries); 83 | cvContinualLearnX = continualLearnX(:,:,:,floor(0.6*size(continualLearnX,4)):end); 84 | cvContinualLearnY = continualLearnY(floor(0.6*size(continualLearnY,2)):end); 85 | continualLearnX = continualLearnX(:,:,:,1:floor(0.6*size(continualLearnX,4))-1); 86 | continualLearnY = continualLearnY(1:floor(0.6*size(continualLearnY,2))-1); 87 | 88 | concatFingerprints = [existingFingerprints; newFingerprints]; 89 | % newFingerprintLayer = FCLayer(numClasses,numClasses+newClassesNum,'Fingerprints',concatFingerprints); 90 | newFingerprintLayer = zeroBiasFCLayer(numClasses,prevClassNum+newClassesNum,'Fingerprints',concatFingerprints); 91 | 92 | % newFingerprintLayer.normMag = prevDlnet.Layers(4).normMag; 93 | % newFingerprintLayer.b1 = prevDlnet.Layers(4).b1; 94 | 95 | numClasses = prevClassNum + newClassesNum; 96 | 97 | %Build new network and start continual learning. 
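Before the rebuild below: unlike plain noCSI.m, this script's modelGradientsOnWeightsEWC (further down) keeps both the Fisher penalty and a knowledge-distillation term that pins the expanded network's first prevClassNum logits to the old network's 'Fingerprints' outputs. A toy version of the distillation term (editor's sketch; sizes are illustrative):

prevClassNum = 20; numClasses = 25; batch = 20;      % toy sizes
oldLogits = randn(prevClassNum, batch);              % previous network's outputs
newLogits = randn(numClasses, batch);                % expanded network's outputs
distilLoss = sum((oldLogits - newLogits(1:prevClassNum,:)).^2, 'all') ./ 32;
kdTerm = 0.2 .* distilLoss                           % weight used in the script
% The ./32 divisor is hard-coded in the script even though miniBatchSize is 20
% here, so it appears to be a fixed normalization constant rather than the
% batch size.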
98 | lgraph2 = layerGraph(net); 99 | lgraph2 = lgraph2.removeLayers('classify_1'); 100 | dlnet = dlnetwork(replaceLayer(lgraph2, 'Fingerprints', newFingerprintLayer)); 101 | 102 | XTrain = continualLearnX; 103 | YTrain = continualLearnY; 104 | XTest = cvContinualLearnX; 105 | YTest = (cvContinualLearnY); 106 | numEpochs = 3; 107 | miniBatchSize = 20; 108 | plots = "training-progress"; 109 | statusFigureIdx = []; 110 | statusFigureAxis = []; 111 | if plots == "training-progress" 112 | figure; 113 | statusFigureIdx = gcf; 114 | statusFigureAxis = gca; 115 | % Go into the documentation of animatedline for more color codes 116 | lineNewLossTrain = animatedline('Color', '#0072BD','LineWidth',1,'Marker','.','LineStyle','none'); 117 | lineNewCVAccuracy = animatedline('Color', '#D95319','LineWidth',1); 118 | lineOldCVAccuracy = animatedline('Color', '#EDB120','LineWidth',1); 119 | lineOldLossCV = animatedline('Color', '#7E2F8E','LineWidth',1,'Marker','.','LineStyle','none'); 120 | % ylim([0 inf]) 121 | ylim([0 2]) 122 | xlabel("Iteration") 123 | ylabel("Metrics") 124 | legend('New task loss','New task accuracy','Old task accuracy', 'Old task lost'); 125 | grid on 126 | end 127 | L2RegularizationFactor = 0.01; 128 | initialLearnRate = 0.01; 129 | decay = 0.01; 130 | momentum = 0.9; 131 | velocities = []; 132 | learnRates = []; 133 | momentums = []; 134 | gradientMasks = []; 135 | numObservations = numel(YTrain); 136 | numIterationsPerEpoch = floor(numObservations./miniBatchSize); 137 | iteration = 0; 138 | start = tic; 139 | classes = categorical(YTrain); 140 | 141 | % 142 | prevCY = [prevCY;zeros(newClassesNum,size(prevCY,2))]; 143 | 144 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 145 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 146 | disp('CV accuracy b.f. cont. learning'); 147 | [newCvAccuracy, oldCvAccuracy] 148 | 149 | % Loop over epochs. 150 | fisherLossLst = []; 151 | for epoch = 1:numEpochs 152 | idx = randperm(numel(YTrain)); 153 | XTrain = XTrain(:,:,:,idx); 154 | YTrain = YTrain(idx); 155 | % Loop over mini-batches. 156 | for i = 1:numIterationsPerEpoch 157 | iteration = iteration + 1; 158 | % Read mini-batch of data and convert the labels to dummy 159 | % variables. 160 | idx = (i-1)*miniBatchSize+1:i*miniBatchSize; 161 | Xb = XTrain(:,:,:,idx); 162 | Yb = zeros(numClasses, miniBatchSize, 'single'); 163 | for c = 1:numClasses 164 | Yb(c,YTrain(idx)==(c)) = 1; 165 | end 166 | % Convert mini-batch of data to dlarray. 167 | dlX = dlarray(single(Xb),'SSCB'); 168 | % If training on a GPU, then convert data to gpuArray. 169 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 170 | dlX = gpuArray(dlX); 171 | end 172 | % Evaluate the model gradients, state, and loss using dlfeval and the 173 | % modelGradients function and update the network state. 174 | % [gradients,state,loss] = dlfeval(@modelGradientsOnWeights,dlnet,dlX,Yb); 175 | % if epoch <= 5 176 | % [gradients,state,loss,fisherLoss] = dlfeval(@modelGradientsOnWeightsEWC,dlnet,dlX, Yb,... 177 | % prevFisherInfo0, prevDlnet, newClassesNum, 0, fingerprintLayerIdx); 178 | % else 179 | % [gradients,state,loss,fisherLoss] = dlfeval(@modelGradientsOnWeightsEWC,dlnet,dlX, Yb,... 180 | % prevFisherInfo0, prevDlnet, newClassesNum, 1, fingerprintLayerIdx); 181 | % end 182 | [gradients,state,loss,fisherLoss] = dlfeval(@modelGradientsOnWeightsEWC,dlnet,dlX, Yb,... 
183 | prevFisherInfo0, prevDlnet, newClassesNum, 1, fingerprintLayerIdx,prevClassNum); 184 | 185 | fisherLossLst(end+1)=extractdata(gather(fisherLoss)); 186 | 187 | dlnet.State = state; 188 | % Determine learning rate for time-based decay learning rate schedule. 189 | learnRate = initialLearnRate/(1 + decay*iteration); 190 | if isempty(velocities) 191 | velocities = packScalar(gradients, 0); 192 | learnRates = packScalar(gradients, learnRate); 193 | momentums = packScalar(gradients, momentum); 194 | L2Foctors = packScalar(gradients, 0.05); 195 | gradientMasks = packScalar(gradients, 1); 196 | % Let's lock some weights 197 | % Zero-bias dense layer only. 198 | % if epoch <= 5 199 | % gradientMasks{fingerprintLayerIdx,3} = {dlarray([zeros(18,18); ones(newClassesNum,fingerprintLen)])}; 200 | % else 201 | % gradientMasks{fingerprintLayerIdx,3} = {dlarray([ones(18,18); ones(newClassesNum,fingerprintLen)])}; 202 | % end 203 | 204 | % gradientMasks{fingerprintLayerIdx-1,3} = {dlarray([zeros(prevClassNum,20); ones(newClassesNum,20)])}; 205 | % gradientMasks{9+1,3} = {dlarray([zeros(prevClassNum,1); ones(newClassesNum,1)])}; 206 | 207 | for k = 1:fingerprintLayerIdx-2 208 | gradientMasks.Value{k}=dlarray(zeros(size(gradientMasks.Value{k}))); 209 | end 210 | end 211 | %fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet,newClassesNum) 212 | [dlnet, velocities] = dlupdate(@sgdmFunctionL2, ... 213 | dlnet, gradients, velocities, ... 214 | learnRates, momentums, L2Foctors, gradientMasks); 215 | 216 | % Display the training progress. 217 | if plots == "training-progress" 218 | D = duration(0,0,toc(start),'Format','hh:mm:ss'); 219 | % figure(statusFigureIdx); 220 | if mod(iteration,20) == 0 221 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 222 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 223 | [~,~,oldCVLoss] = dlfeval(@modelGradientsOnWeights,dlnet,prevCX,prevCY); 224 | addpoints(lineNewCVAccuracy, iteration, newCvAccuracy); 225 | addpoints(lineOldCVAccuracy, iteration, oldCvAccuracy); 226 | addpoints(lineOldLossCV, iteration, double(gather(extractdata(oldCVLoss)))); 227 | end 228 | addpoints(lineNewLossTrain,iteration,double(gather(extractdata(loss)))) 229 | title(statusFigureAxis,"Epoch: " + epoch + ", Elapsed: " + string(D)) 230 | drawnow 231 | end 232 | end 233 | end 234 | 235 | 236 | % 237 | % figure 238 | % subplot(2,1,1) 239 | % imagesc(prevDlnet.Layers(fingerprintLayerIdx).Weights) 240 | % title('Old Finger prints'); 241 | % subplot(2,1,2) 242 | % imagesc(dlnet.Layers(fingerprintLayerIdx).Weights) 243 | % title('New Finger prints'); 244 | % 245 | figure 246 | weights = dlnet.Layers(11).Weights; 247 | similarMatrix = zeros(size(weights,1), size(weights,1)); 248 | for i = 1:size(weights,1) 249 | curWeight = weights(i,:); 250 | magCurWeight = sqrt(sum(curWeight.^2,2)); 251 | for j = 1:size(weights,1) 252 | nxtWeight = weights(j,:); 253 | magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 254 | similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 
255 | .*(curWeight./magCurWeight)); 256 | end 257 | end 258 | sumDeviationAnglesAll = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 259 | imagesc(similarMatrix) 260 | % 261 | % weights = dlnet.Layers(11).Weights(1:18,:); 262 | % similarMatrix = zeros(size(weights,1), size(weights,1)); 263 | % for i = 1:size(weights,1) 264 | % curWeight = weights(i,:); 265 | % magCurWeight = sqrt(sum(curWeight.^2,2)); 266 | % for j = 1:size(weights,1) 267 | % nxtWeight = weights(j,:); 268 | % magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 269 | % similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 270 | % .*(curWeight./magCurWeight)); 271 | % end 272 | % end 273 | % sumDeviationAnglesOld = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 274 | % 275 | % weights = dlnet.Layers(11).Weights(19:end,:); 276 | % similarMatrix = zeros(size(weights,1), size(weights,1)); 277 | % for i = 1:size(weights,1) 278 | % curWeight = weights(i,:); 279 | % magCurWeight = sqrt(sum(curWeight.^2,2)); 280 | % for j = 1:size(weights,1) 281 | % nxtWeight = weights(j,:); 282 | % magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 283 | % similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 284 | % .*(curWeight./magCurWeight)); 285 | % end 286 | % end 287 | % sumDeviationAnglesNew = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 288 | 289 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 290 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 291 | disp('CV accuracy a.f. cont. learning'); 292 | [newCvAccuracy, oldCvAccuracy] 293 | 294 | % figure 295 | % subplot(2,1,1) 296 | % imagesc(prevDlnet.Layers(2).Weights) 297 | % title('Old input dense'); 298 | % subplot(2,1,2) 299 | % imagesc(dlnet.Layers(2).Weights) 300 | % title('New input dense'); 301 | 302 | function [totalLoss,corrOld,corrNew,grad] = evalFp(existingFp, newFeatures, fp) 303 | magFp = sqrt(sum(fp.^2)); 304 | unitFp = fp ./ magFp; 305 | corrOld = mean(sum( existingFp./sqrt( sum((existingFp).^2,2)).*unitFp, 2 )); 306 | corrNew = mean(sum( newFeatures'./sqrt( sum((newFeatures').^2,2)).*unitFp, 2)); 307 | totalLoss = corrOld - 10*corrNew; 308 | grad = dlgradient(totalLoss, fp); 309 | end 310 | 311 | function fp = findPotentialFingerprint(existingFp, newFeatures) 312 | fp = mean(newFeatures,2)'; 313 | % magFp = sqrt(sum(fp.^2)); 314 | % unitFp = fp ./ magFp; 315 | record = []; 316 | for i = 1:20 317 | [totalLoss,corrOld,corrNew, grad] = dlfeval(@evalFp, (existingFp), (newFeatures), fp); 318 | record = [record;[totalLoss,corrOld,corrNew]]; 319 | fp = fp - 0.5.*grad; 320 | end 321 | % figure 322 | % plot(extractdata(record(:,1)),'LineWidth',1.5); 323 | % hold on; 324 | % plot(extractdata(record(:,2)),'LineWidth',1.5); 325 | % plot(extractdata(record(:,3)),'LineWidth',1.5); 326 | end 327 | 328 | function [gradients,state,loss,fisherLoss] = modelGradientsOnWeightsEWC(dlnet, dlX, Y,... 
329 | prevFisherInfo0, prevDlnet, newClassesNum, ewcLambda, fingerprintLayerIdx, prevClassNum) 330 | % [dlYPred,state] = forward(dlnet,dlX); 331 | oldDlYPred = prevDlnet.predict(dlX,'Outputs','Fingerprints'); 332 | [dlYPred,dlYPred2,bfFp,state] = forward(dlnet,dlX,'Outputs',{'softmax_1','Fingerprints','fc_bf_fp'}); 333 | penalty = 0; 334 | scalarL2Factor = 0; 335 | if scalarL2Factor ~= 0 336 | paramLst = dlnet.Learnables.Value; 337 | for i = 1:size(paramLst,1) 338 | penalty = penalty + sum((paramLst{i}(:)).^2); 339 | end 340 | end 341 | fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet, newClassesNum, fingerprintLayerIdx)*ewcLambda/2; 342 | distilLoss = sum((squeeze(oldDlYPred)-squeeze(dlYPred2(1:prevClassNum,:))).^2,'all')./32; 343 | 344 | loss = crossentropy(dlYPred,Y) + scalarL2Factor*penalty + fisherLoss + 0.2.*distilLoss; 345 | gradients = dlgradient(loss, dlnet.Learnables); 346 | end 347 | 348 | function fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet, newClassesNum,fingerprintLayerIdx) 349 | prevWeights = prevDlnet.Learnables.Value; 350 | curWeights = dlnet.Learnables.Value; 351 | % fisherLossMatrix = {}; 352 | sumLoss = 0; 353 | elementCount = 1; 354 | for i = 1:size(prevWeights,1) 355 | if i >= fingerprintLayerIdx-1 356 | loss = ((prevWeights{i}-curWeights{i}(1:size(curWeights{i},1)-newClassesNum,:)).^2) .* prevFisherInfo0.Value{i}; 357 | % fisherLossMatrix{end+1} = loss; 358 | sumLoss = sumLoss + sum(loss(:)); 359 | elementCount = elementCount + numel(prevWeights{i}); 360 | else 361 | loss = ((prevWeights{i}-curWeights{i}).^2) .* prevFisherInfo0.Value{i}; 362 | % fisherLossMatrix{end+1} = loss; 363 | sumLoss = sumLoss + sum(loss(:)); 364 | elementCount = elementCount + numel(prevWeights{i}); 365 | end 366 | end 367 | fisherLoss = sumLoss; 368 | end 369 | 370 | function accuracy = cvAccuracy(dlnet, XTest, YTest, miniBatchSize, executionEnvironment) 371 | dlXTest = dlarray(XTest,'SSCB'); 372 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 373 | dlXTest = gpuArray(dlXTest); 374 | end 375 | dlYPred = modelPredictions(dlnet,dlXTest,miniBatchSize); 376 | [~,idx] = max(extractdata(dlYPred),[],1); 377 | YPred = (idx); 378 | accuracy = mean(YPred(:) == YTest(:)); 379 | end 380 | 381 | function dlYPred = modelPredictions(dlnet,dlX,miniBatchSize) 382 | numObservations = size(dlX,4); 383 | numIterations = ceil(numObservations / miniBatchSize); 384 | numClasses = size(dlnet.Layers(end-1).Weights,1); 385 | dlYPred = zeros(numClasses,numObservations,'like',dlX); 386 | for i = 1:numIterations 387 | idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations); 388 | dlYPred(:,idx) = predict(dlnet,dlX(:,:,:,idx)); 389 | end 390 | end 391 | 392 | function [gradients,state] = logModelGradientsOnWeights(dlnet,dlX) 393 | [dlYPred,state] = forward(dlnet,dlX); 394 | loglikelyhood = log(dlYPred-min(dlYPred(:))+1e-5); 395 | gradients = dlgradient(mean(loglikelyhood(:)),dlnet.Learnables); 396 | end 397 | 398 | function [gradients,state,loss] = modelGradientsOnWeights(dlnet,dlX,Y) 399 | [dlYPred,state] = forward(dlnet,dlX); 400 | penalty = 0; 401 | scalarL2Factor = 0; 402 | if scalarL2Factor ~= 0 403 | paramLst = dlnet.Learnables.Value; 404 | for i = 1:size(paramLst,1) 405 | penalty = penalty + sum((paramLst{i}(:)).^2); 406 | end 407 | end 408 | loss = crossentropy(dlYPred,Y) + scalarL2Factor*penalty; 409 | gradients = dlgradient(loss,dlnet.Learnables); 410 | %gradients = dlgradient(loss,dlnet.Learnables(4,:)); 411 | end 412 | 413 | function 
[params, velocityUpdates] = sgdmFunction(params, paramGradients,... 414 | velocities, learnRates, momentums) 415 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e 416 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients; 417 | params = params - velocityUpdates; 418 | end 419 | 420 | function [params, velocityUpdates] = sgdmFunctionL2(params, rawParamGradients,... 421 | velocities, learnRates, momentums, L2Foctors, gradientMasks) 422 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e 423 | % https://towardsdatascience.com/intuitions-on-l1-and-l2-regularisation-235f2db4c261 424 | paramGradients = rawParamGradients + 2*L2Foctors.*params; 425 | %Please be noted that even if rawParamGradients = 0, L2 will still try 426 | %to reduce the magnitudes of parameters 427 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients; 428 | params = params - (velocityUpdates).*gradientMasks; 429 | end 430 | 431 | function tabVars = packScalar(target, scalar) 432 | % The matlabs' silly design results in such a strange function 433 | tabVars = target; 434 | for row = 1:size(tabVars(:,3),1) 435 | tabVars{row,3} = {... 436 | dlarray(... 437 | ones(size(tabVars.Value{row})).*scalar...%ones(size(tabVars(row,3).Value{1,1})).*scalar... 438 | )... 439 | }; 440 | end 441 | end -------------------------------------------------------------------------------- /ContinualLearning/WorkStage/noCSI.m: -------------------------------------------------------------------------------- 1 | 2 | 3 | 4 | 5 | 6 | 7 | %%%%% 8 | %Gather fisher information and weights from the old network 9 | prevClassNum = 0; 10 | if skipDAGNet == 0 11 | lgraph2 = layerGraph(net); 12 | lgraph2 = lgraph2.removeLayers('classify_1'); 13 | weights = net.Layers({net.Layers.Name}=="Fingerprints").Weights; 14 | numClasses = size(weights,1); 15 | prevClassNum = numClasses; 16 | else 17 | lgraph2 = layerGraph(dlnet); 18 | numClasses = size(dlnet.Layers(14).Weights,1); 19 | prevClassNum = numClasses; 20 | end 21 | prevDlnet = dlnetwork(lgraph2); 22 | prevWeights = prevDlnet.Learnables; 23 | prevCX = dlarray(single(cX),'SSCB'); 24 | prevCY = zeros(numClasses, size(prevCX,4), 'single'); 25 | for c = 1:numClasses 26 | prevCY(c,cY(:)==(c)) = 1; 27 | end 28 | executionEnvironment = "auto"; 29 | % If training on a GPU, then convert data to gpuArray. 30 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 31 | prevCX = gpuArray(prevCX); 32 | end 33 | [prevGradients0,state] = dlfeval(@logModelGradientsOnWeights,prevDlnet,prevCX); 34 | accuracy = cvAccuracy(prevDlnet, prevCX, (cY), 128, executionEnvironment) 35 | prevDlnet.State = state; 36 | prevFisherInfo0 = prevGradients0; 37 | for i = 1:size(prevFisherInfo0,1) 38 | prevFisherInfo0{i,3}={dlarray(exp(prevGradients0{i,3}{:}.^2))}; 39 | % prevFisherInfo0{i,3}={dlarray(exp(prevGradients0{i,3}{:}.^2))}; 40 | 41 | end 42 | fingerprintLayerIdx = 14; 43 | 44 | newClassesNum = floor(size(unique(uY),1)); 45 | unknownClassLabels = unique(uY); 46 | idx = randperm(size(unknownClassLabels,1)); 47 | unknownClassLabels = unknownClassLabels(idx); 48 | unknownClassLabels = unknownClassLabels(1:newClassesNum); 49 | 50 | % Generate some initial fingerprints. 
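In the training loop further down, noCSI.m zeroes the gradient masks of every layer below the fingerprint head (the for k = 1:fingerprintLayerIdx-2 loop), so only the head is updated while the feature extractor stays frozen; sgdmFunctionL2 applies the mask multiplicatively to the SGDM update. A self-contained editor's sketch of that freezing mechanism (toy sizes):

w = ones(3, 4);  v = zeros(3, 4);       % toy weights and velocities
g = 0.5 .* ones(3, 4);                  % toy raw gradients
lr = 0.01; mom = 0.9; l2 = 0.05;
mask = [zeros(2, 4); ones(1, 4)];       % freeze the first two rows, train the third
gReg = g + 2*l2.*w;                     % L2 term, as in sgdmFunctionL2
v = mom.*v + lr.*gReg;                  % momentum update
w = w - v.*mask                         % masked rows keep their original values

The fingerprint initialization for the new classes continues below.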
51 | %figure 52 | % hold on; 53 | existingFingerprints = prevDlnet.Layers(fingerprintLayerIdx).Weights; 54 | newFingerprints = dlarray([]); 55 | continualLearnX = dlarray([]); 56 | continualLearnY = []; 57 | cursor = 1; 58 | for i = 1:newClassesNum 59 | selection = uY==unknownClassLabels(i); 60 | ux_i = uX(:,:,:,selection); 61 | uy_i = uY(selection); 62 | samplePerClass = size(uy_i, 1); 63 | continualLearnX(:,:,:,cursor:cursor+samplePerClass-1) = ux_i; 64 | continualLearnY(cursor:cursor+samplePerClass-1)=uy_i; 65 | cursor = cursor+samplePerClass; 66 | % aUx_i = squeeze(prevDlnet.predict(dlarray(ux_i,'SSCB'),'Outputs','fc_bf_fp')); 67 | aUx_i = squeeze(activations(net,ux_i,'fc_bf_fp')); 68 | unitAUx_i=(aUx_i./sqrt(sum(aUx_i.^2)))'; 69 | fp_i = mean(aUx_i,2)'; 70 | magFp_i = sqrt(sum(fp_i.^2)); 71 | unitFp_i = fp_i ./ magFp_i; 72 | newFp = (unitFp_i); 73 | % newFp = findPotentialFingerprint(dlarray(existingFingerprints),dlarray(aUx_i)); 74 | newFingerprints(end+1,:) = newFp; 75 | % histogram(sum(unitFp_i.*unitFingerprints,2),[-1:0.2:1],'Normalization','probability'); 76 | % hold on; 77 | % histogram(sum(unitAUx_i.*unitFp_i,2),[-1:0.2:1],'Normalization','probability'); 78 | % legend('Correlation with existing fingerprints','Correlation with own samples'); 79 | end 80 | randSeries = randperm(size(continualLearnY,2)); 81 | continualLearnX = continualLearnX(:,:,:,randSeries); 82 | continualLearnY = continualLearnY(randSeries); 83 | cvContinualLearnX = continualLearnX(:,:,:,floor(0.6*size(continualLearnX,4)):end); 84 | cvContinualLearnY = continualLearnY(floor(0.6*size(continualLearnY,2)):end); 85 | continualLearnX = continualLearnX(:,:,:,1:floor(0.6*size(continualLearnX,4))-1); 86 | continualLearnY = continualLearnY(1:floor(0.6*size(continualLearnY,2))-1); 87 | 88 | concatFingerprints = [existingFingerprints; newFingerprints]; 89 | % newFingerprintLayer = FCLayer(numClasses,numClasses+newClassesNum,'Fingerprints',concatFingerprints); 90 | newFingerprintLayer = zeroBiasFCLayer(numClasses,prevClassNum+newClassesNum,'Fingerprints',concatFingerprints); 91 | 92 | % newFingerprintLayer.normMag = prevDlnet.Layers(4).normMag; 93 | % newFingerprintLayer.b1 = prevDlnet.Layers(4).b1; 94 | 95 | numClasses = prevClassNum + newClassesNum; 96 | 97 | %Build new network and start continual learning. 
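The training loop below one-hot encodes each mini-batch with a per-class loop. An equivalent vectorized form using sub2ind, shown here on toy labels (editor's sketch), produces the same matrix and avoids the loop over classes:

labels = [2 1 3 2];  numClasses = 3;                     % toy labels
Yb = zeros(numClasses, numel(labels), 'single');
for c = 1:numClasses                                     % loop style used in the script
    Yb(c, labels == c) = 1;
end
Yb2 = zeros(numClasses, numel(labels), 'single');
Yb2(sub2ind(size(Yb2), labels, 1:numel(labels))) = 1;    % vectorized equivalent
isequal(Yb, Yb2)                                         % true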
98 | lgraph2 = layerGraph(net); 99 | lgraph2 = lgraph2.removeLayers('classify_1'); 100 | dlnet = dlnetwork(replaceLayer(lgraph2, 'Fingerprints', newFingerprintLayer)); 101 | 102 | XTrain = continualLearnX; 103 | YTrain = continualLearnY; 104 | XTest = cvContinualLearnX; 105 | YTest = (cvContinualLearnY); 106 | numEpochs = 3; 107 | miniBatchSize = 20; 108 | plots = "training-progress"; 109 | statusFigureIdx = []; 110 | statusFigureAxis = []; 111 | if plots == "training-progress" 112 | figure; 113 | statusFigureIdx = gcf; 114 | statusFigureAxis = gca; 115 | % Go into the documentation of animatedline for more color codes 116 | lineNewLossTrain = animatedline('Color', '#0072BD','LineWidth',1,'Marker','.','LineStyle','none'); 117 | lineNewCVAccuracy = animatedline('Color', '#D95319','LineWidth',1); 118 | lineOldCVAccuracy = animatedline('Color', '#EDB120','LineWidth',1); 119 | lineOldLossCV = animatedline('Color', '#7E2F8E','LineWidth',1,'Marker','.','LineStyle','none'); 120 | % ylim([0 inf]) 121 | ylim([0 2]) 122 | xlabel("Iteration") 123 | ylabel("Metrics") 124 | legend('New task loss','New task accuracy','Old task accuracy', 'Old task lost'); 125 | grid on 126 | end 127 | L2RegularizationFactor = 0.01; 128 | initialLearnRate = 0.01; 129 | decay = 0.01; 130 | momentum = 0.9; 131 | velocities = []; 132 | learnRates = []; 133 | momentums = []; 134 | gradientMasks = []; 135 | numObservations = numel(YTrain); 136 | numIterationsPerEpoch = floor(numObservations./miniBatchSize); 137 | iteration = 0; 138 | start = tic; 139 | classes = categorical(YTrain); 140 | 141 | % 142 | prevCY = [prevCY;zeros(newClassesNum,size(prevCY,2))]; 143 | 144 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 145 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 146 | disp('CV accuracy b.f. cont. learning'); 147 | [newCvAccuracy, oldCvAccuracy] 148 | 149 | % Loop over epochs. 150 | fisherLossLst = []; 151 | for epoch = 1:numEpochs 152 | idx = randperm(numel(YTrain)); 153 | XTrain = XTrain(:,:,:,idx); 154 | YTrain = YTrain(idx); 155 | % Loop over mini-batches. 156 | for i = 1:numIterationsPerEpoch 157 | iteration = iteration + 1; 158 | % Read mini-batch of data and convert the labels to dummy 159 | % variables. 160 | idx = (i-1)*miniBatchSize+1:i*miniBatchSize; 161 | Xb = XTrain(:,:,:,idx); 162 | Yb = zeros(numClasses, miniBatchSize, 'single'); 163 | for c = 1:numClasses 164 | Yb(c,YTrain(idx)==(c)) = 1; 165 | end 166 | % Convert mini-batch of data to dlarray. 167 | dlX = dlarray(single(Xb),'SSCB'); 168 | % If training on a GPU, then convert data to gpuArray. 169 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 170 | dlX = gpuArray(dlX); 171 | end 172 | % Evaluate the model gradients, state, and loss using dlfeval and the 173 | % modelGradients function and update the network state. 174 | % [gradients,state,loss] = dlfeval(@modelGradientsOnWeights,dlnet,dlX,Yb); 175 | % if epoch <= 5 176 | % [gradients,state,loss,fisherLoss] = dlfeval(@modelGradientsOnWeightsEWC,dlnet,dlX, Yb,... 177 | % prevFisherInfo0, prevDlnet, newClassesNum, 0, fingerprintLayerIdx); 178 | % else 179 | % [gradients,state,loss,fisherLoss] = dlfeval(@modelGradientsOnWeightsEWC,dlnet,dlX, Yb,... 180 | % prevFisherInfo0, prevDlnet, newClassesNum, 1, fingerprintLayerIdx); 181 | % end 182 | [gradients,state,loss,fisherLoss] = dlfeval(@modelGradientsOnWeightsEWC,dlnet,dlX, Yb,... 
183 | prevFisherInfo0, prevDlnet, newClassesNum, 1, fingerprintLayerIdx,prevClassNum); 184 | 185 | fisherLossLst(end+1)=extractdata(gather(fisherLoss)); 186 | 187 | dlnet.State = state; 188 | % Determine learning rate for time-based decay learning rate schedule. 189 | learnRate = initialLearnRate/(1 + decay*iteration); 190 | if isempty(velocities) 191 | velocities = packScalar(gradients, 0); 192 | learnRates = packScalar(gradients, learnRate); 193 | momentums = packScalar(gradients, momentum); 194 | L2Foctors = packScalar(gradients, 0.05); 195 | gradientMasks = packScalar(gradients, 1); 196 | % Let's lock some weights 197 | % Zero-bias dense layer only. 198 | % if epoch <= 5 199 | % gradientMasks{fingerprintLayerIdx,3} = {dlarray([zeros(18,18); ones(newClassesNum,fingerprintLen)])}; 200 | % else 201 | % gradientMasks{fingerprintLayerIdx,3} = {dlarray([ones(18,18); ones(newClassesNum,fingerprintLen)])}; 202 | % end 203 | 204 | % gradientMasks{fingerprintLayerIdx-1,3} = {dlarray([zeros(prevClassNum,20); ones(newClassesNum,20)])}; 205 | % gradientMasks{9+1,3} = {dlarray([zeros(prevClassNum,1); ones(newClassesNum,1)])}; 206 | 207 | for k = 1:fingerprintLayerIdx-2 208 | gradientMasks.Value{k}=dlarray(zeros(size(gradientMasks.Value{k}))); 209 | end 210 | end 211 | %fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet,newClassesNum) 212 | [dlnet, velocities] = dlupdate(@sgdmFunctionL2, ... 213 | dlnet, gradients, velocities, ... 214 | learnRates, momentums, L2Foctors, gradientMasks); 215 | 216 | % Display the training progress. 217 | if plots == "training-progress" 218 | D = duration(0,0,toc(start),'Format','hh:mm:ss'); 219 | % figure(statusFigureIdx); 220 | if mod(iteration,20) == 0 221 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 222 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 223 | [~,~,oldCVLoss] = dlfeval(@modelGradientsOnWeights,dlnet,prevCX,prevCY); 224 | addpoints(lineNewCVAccuracy, iteration, newCvAccuracy); 225 | addpoints(lineOldCVAccuracy, iteration, oldCvAccuracy); 226 | addpoints(lineOldLossCV, iteration, double(gather(extractdata(oldCVLoss)))); 227 | end 228 | addpoints(lineNewLossTrain,iteration,double(gather(extractdata(loss)))) 229 | title(statusFigureAxis,"Epoch: " + epoch + ", Elapsed: " + string(D)) 230 | drawnow 231 | end 232 | end 233 | end 234 | 235 | 236 | % 237 | % figure 238 | % subplot(2,1,1) 239 | % imagesc(prevDlnet.Layers(fingerprintLayerIdx).Weights) 240 | % title('Old Finger prints'); 241 | % subplot(2,1,2) 242 | % imagesc(dlnet.Layers(fingerprintLayerIdx).Weights) 243 | % title('New Finger prints'); 244 | % 245 | figure 246 | weights = dlnet.Layers({dlnet.Layers.Name} == "Fingerprints").Weights; 247 | similarMatrix = zeros(size(weights,1), size(weights,1)); 248 | for i = 1:size(weights,1) 249 | curWeight = weights(i,:); 250 | magCurWeight = sqrt(sum(curWeight.^2,2)); 251 | for j = 1:size(weights,1) 252 | nxtWeight = weights(j,:); 253 | magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 254 | similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 
255 | .*(curWeight./magCurWeight)); 256 | end 257 | end 258 | sumDeviationAnglesAll = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 259 | imagesc(similarMatrix) 260 | % 261 | % weights = dlnet.Layers(11).Weights(1:18,:); 262 | % similarMatrix = zeros(size(weights,1), size(weights,1)); 263 | % for i = 1:size(weights,1) 264 | % curWeight = weights(i,:); 265 | % magCurWeight = sqrt(sum(curWeight.^2,2)); 266 | % for j = 1:size(weights,1) 267 | % nxtWeight = weights(j,:); 268 | % magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 269 | % similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 270 | % .*(curWeight./magCurWeight)); 271 | % end 272 | % end 273 | % sumDeviationAnglesOld = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 274 | % 275 | % weights = dlnet.Layers(11).Weights(19:end,:); 276 | % similarMatrix = zeros(size(weights,1), size(weights,1)); 277 | % for i = 1:size(weights,1) 278 | % curWeight = weights(i,:); 279 | % magCurWeight = sqrt(sum(curWeight.^2,2)); 280 | % for j = 1:size(weights,1) 281 | % nxtWeight = weights(j,:); 282 | % magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 283 | % similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 284 | % .*(curWeight./magCurWeight)); 285 | % end 286 | % end 287 | % sumDeviationAnglesNew = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 288 | 289 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 290 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 291 | disp('CV accuracy a.f. cont. learning'); 292 | [newCvAccuracy, oldCvAccuracy] 293 | 294 | % figure 295 | % subplot(2,1,1) 296 | % imagesc(prevDlnet.Layers(2).Weights) 297 | % title('Old input dense'); 298 | % subplot(2,1,2) 299 | % imagesc(dlnet.Layers(2).Weights) 300 | % title('New input dense'); 301 | 302 | function [totalLoss,corrOld,corrNew,grad] = evalFp(existingFp, newFeatures, fp) 303 | magFp = sqrt(sum(fp.^2)); 304 | unitFp = fp ./ magFp; 305 | corrOld = mean(sum( existingFp./sqrt( sum((existingFp).^2,2)).*unitFp, 2 )); 306 | corrNew = mean(sum( newFeatures'./sqrt( sum((newFeatures').^2,2)).*unitFp, 2)); 307 | totalLoss = corrOld - 10*corrNew; 308 | grad = dlgradient(totalLoss, fp); 309 | end 310 | 311 | function fp = findPotentialFingerprint(existingFp, newFeatures) 312 | fp = mean(newFeatures,2)'; 313 | % magFp = sqrt(sum(fp.^2)); 314 | % unitFp = fp ./ magFp; 315 | record = []; 316 | for i = 1:20 317 | [totalLoss,corrOld,corrNew, grad] = dlfeval(@evalFp, (existingFp), (newFeatures), fp); 318 | record = [record;[totalLoss,corrOld,corrNew]]; 319 | fp = fp - 0.5.*grad; 320 | end 321 | % figure 322 | % plot(extractdata(record(:,1)),'LineWidth',1.5); 323 | % hold on; 324 | % plot(extractdata(record(:,2)),'LineWidth',1.5); 325 | % plot(extractdata(record(:,3)),'LineWidth',1.5); 326 | end 327 | 328 | function [gradients,state,loss,fisherLoss] = modelGradientsOnWeightsEWC(dlnet, dlX, Y,... 
329 | prevFisherInfo0, prevDlnet, newClassesNum, ewcLambda, fingerprintLayerIdx, prevClassNum) 330 | % [dlYPred,state] = forward(dlnet,dlX); 331 | oldDlYPred = prevDlnet.predict(dlX,'Outputs','Fingerprints'); 332 | [dlYPred,dlYPred2,bfFp,state] = forward(dlnet,dlX,'Outputs',{'softmax_1','Fingerprints','fc_bf_fp'}); 333 | penalty = 0; 334 | scalarL2Factor = 0; 335 | if scalarL2Factor ~= 0 336 | paramLst = dlnet.Learnables.Value; 337 | for i = 1:size(paramLst,1) 338 | penalty = penalty + sum((paramLst{i}(:)).^2); 339 | end 340 | end 341 | fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet, newClassesNum, fingerprintLayerIdx)*ewcLambda/2; 342 | distilLoss = sum((squeeze(oldDlYPred)-squeeze(dlYPred2(1:prevClassNum,:))).^2,'all')./32; 343 | 344 | loss = crossentropy(dlYPred,Y) + scalarL2Factor*penalty + 0.*fisherLoss + 0.2.*distilLoss; 345 | gradients = dlgradient(loss, dlnet.Learnables); 346 | end 347 | 348 | function fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet, newClassesNum,fingerprintLayerIdx) 349 | prevWeights = prevDlnet.Learnables.Value; 350 | curWeights = dlnet.Learnables.Value; 351 | % fisherLossMatrix = {}; 352 | sumLoss = 0; 353 | elementCount = 1; 354 | for i = 1:size(prevWeights,1) 355 | if i >= fingerprintLayerIdx-1 356 | loss = ((prevWeights{i}-curWeights{i}(1:size(curWeights{i},1)-newClassesNum,:)).^2) .* prevFisherInfo0.Value{i}; 357 | % fisherLossMatrix{end+1} = loss; 358 | sumLoss = sumLoss + sum(loss(:)); 359 | elementCount = elementCount + numel(prevWeights{i}); 360 | else 361 | loss = ((prevWeights{i}-curWeights{i}).^2) .* prevFisherInfo0.Value{i}; 362 | % fisherLossMatrix{end+1} = loss; 363 | sumLoss = sumLoss + sum(loss(:)); 364 | elementCount = elementCount + numel(prevWeights{i}); 365 | end 366 | end 367 | fisherLoss = sumLoss; 368 | end 369 | 370 | function accuracy = cvAccuracy(dlnet, XTest, YTest, miniBatchSize, executionEnvironment) 371 | dlXTest = dlarray(XTest,'SSCB'); 372 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 373 | dlXTest = gpuArray(dlXTest); 374 | end 375 | dlYPred = modelPredictions(dlnet,dlXTest,miniBatchSize); 376 | [~,idx] = max(extractdata(dlYPred),[],1); 377 | YPred = (idx); 378 | accuracy = mean(YPred(:) == YTest(:)); 379 | end 380 | 381 | function dlYPred = modelPredictions(dlnet,dlX,miniBatchSize) 382 | numObservations = size(dlX,4); 383 | numIterations = ceil(numObservations / miniBatchSize); 384 | numClasses = size(dlnet.Layers(end-1).Weights,1); 385 | dlYPred = zeros(numClasses,numObservations,'like',dlX); 386 | for i = 1:numIterations 387 | idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations); 388 | dlYPred(:,idx) = predict(dlnet,dlX(:,:,:,idx)); 389 | end 390 | end 391 | 392 | function [gradients,state] = logModelGradientsOnWeights(dlnet,dlX) 393 | [dlYPred,state] = forward(dlnet,dlX); 394 | loglikelyhood = log(dlYPred-min(dlYPred(:))+1e-5); 395 | gradients = dlgradient(mean(loglikelyhood(:)),dlnet.Learnables); 396 | end 397 | 398 | function [gradients,state,loss] = modelGradientsOnWeights(dlnet,dlX,Y) 399 | [dlYPred,state] = forward(dlnet,dlX); 400 | penalty = 0; 401 | scalarL2Factor = 0; 402 | if scalarL2Factor ~= 0 403 | paramLst = dlnet.Learnables.Value; 404 | for i = 1:size(paramLst,1) 405 | penalty = penalty + sum((paramLst{i}(:)).^2); 406 | end 407 | end 408 | loss = crossentropy(dlYPred,Y) + scalarL2Factor*penalty; 409 | gradients = dlgradient(loss,dlnet.Learnables); 410 | %gradients = dlgradient(loss,dlnet.Learnables(4,:)); 411 | end 412 | 413 | 
function [params, velocityUpdates] = sgdmFunction(params, paramGradients,... 414 | velocities, learnRates, momentums) 415 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e 416 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients; 417 | params = params - velocityUpdates; 418 | end 419 | 420 | function [params, velocityUpdates] = sgdmFunctionL2(params, rawParamGradients,... 421 | velocities, learnRates, momentums, L2Foctors, gradientMasks) 422 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e 423 | % https://towardsdatascience.com/intuitions-on-l1-and-l2-regularisation-235f2db4c261 424 | paramGradients = rawParamGradients + 2*L2Foctors.*params; 425 | %Please be noted that even if rawParamGradients = 0, L2 will still try 426 | %to reduce the magnitudes of parameters 427 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients; 428 | params = params - (velocityUpdates).*gradientMasks; 429 | end 430 | 431 | function tabVars = packScalar(target, scalar) 432 | % The matlabs' silly design results in such a strange function 433 | tabVars = target; 434 | for row = 1:size(tabVars(:,3),1) 435 | tabVars{row,3} = {... 436 | dlarray(... 437 | ones(size(tabVars.Value{row})).*scalar...%ones(size(tabVars(row,3).Value{1,1})).*scalar... 438 | )... 439 | }; 440 | end 441 | end -------------------------------------------------------------------------------- /ContinualLearning/WorkStage/noCSIL_Finetune_EWC.m: -------------------------------------------------------------------------------- 1 | 2 | % %%%% 3 | % Gather fisher information and weights from the old network 4 | prevClassNum = 0; 5 | if skipDAGNet == 0 6 | lgraph2 = layerGraph(net); 7 | lgraph2 = lgraph2.removeLayers('classify_1'); 8 | weights = net.Layers({net.Layers.Name} == "Fingerprints").Weights; 9 | numClasses = size(weights, 1); 10 | prevClassNum = numClasses; 11 | else 12 | lgraph2 = layerGraph(dlnet); 13 | numClasses = size(dlnet.Layers(14).Weights, 1); 14 | prevClassNum = numClasses; 15 | end 16 | prevDlnet = dlnetwork(lgraph2); 17 | prevWeights = prevDlnet.Learnables; 18 | prevCX = dlarray(single(cX), 'SSCB'); 19 | prevCY = zeros(numClasses, size(prevCX, 4), 'single'); 20 | for c = 1:numClasses 21 | prevCY(c, cY(:) == (c)) = 1; 22 | end 23 | executionEnvironment = "auto"; 24 | % If training on a GPU, then convert data to gpuArray. 25 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 26 | prevCX = gpuArray(prevCX); 27 | end 28 | [prevGradients0, state] = dlfeval(@logModelGradientsOnWeights, prevDlnet, prevCX); 29 | accuracy = cvAccuracy(prevDlnet, prevCX, (cY), 128, executionEnvironment) 30 | prevDlnet.State = state; 31 | prevFisherInfo0 = prevGradients0; 32 | for i = 1:size(prevFisherInfo0, 1) 33 | prevFisherInfo0{i, 3} = {dlarray(exp(prevGradients0{i, 3}{:} .^ 2))}; 34 | % prevFisherInfo0{i,3}={dlarray(exp(prevGradients0{i,3}{:}.^2))}; 35 | 36 | end 37 | fingerprintLayerIdx = 14; 38 | 39 | newClassesNum = floor(size(unique(uY), 1)); 40 | unknownClassLabels = unique(uY); 41 | idx = randperm(size(unknownClassLabels, 1)); 42 | unknownClassLabels = unknownClassLabels(idx); 43 | unknownClassLabels = unknownClassLabels(1:newClassesNum); 44 | 45 | % Generate some initial fingerprints. 
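These scripts all measure fingerprint separation with a double loop that fills similarMatrix with pairwise cosine similarities and then sums the off-diagonal entries, halved for symmetry. An equivalent vectorized editor's sketch (vecnorm needs R2017b+, sum(...,'all') needs R2018b+); note the quantity summed is cosine similarities, not angles, despite the variable name:

W = randn(6, 20);                            % toy fingerprint matrix: 6 classes x 20 dims
U = W ./ vecnorm(W, 2, 2);                   % unit-normalize each row
S = U * U';                                  % pairwise cosine similarities
sumDeviationAngles = (sum(S, 'all') - trace(S)) ./ 2   % off-diagonal sum, halved

Values near zero indicate nearly orthogonal fingerprints, which is the geometry the zero-bias layer is meant to preserve across stages.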
46 | % figure 47 | % hold on; 48 | existingFingerprints = prevDlnet.Layers(fingerprintLayerIdx).Weights; 49 | newFingerprints = dlarray([]); 50 | continualLearnX = dlarray([]); 51 | continualLearnY = []; 52 | cursor = 1; 53 | for i = 1:newClassesNum 54 | selection = uY == unknownClassLabels(i); 55 | ux_i = uX(:, :, :, selection); 56 | uy_i = uY(selection); 57 | samplePerClass = size(uy_i, 1); 58 | continualLearnX(:, :, :, cursor:cursor + samplePerClass - 1) = ux_i; 59 | continualLearnY(cursor:cursor + samplePerClass - 1) = uy_i; 60 | cursor = cursor + samplePerClass; 61 | % aUx_i = squeeze(prevDlnet.predict(dlarray(ux_i,'SSCB'),'Outputs','fc_bf_fp')); 62 | aUx_i = squeeze(activations(net, ux_i, 'fc_bf_fp')); 63 | unitAUx_i = (aUx_i ./ sqrt(sum(aUx_i .^ 2)))'; 64 | fp_i = mean(aUx_i, 2)'; 65 | magFp_i = sqrt(sum(fp_i .^ 2)); 66 | unitFp_i = fp_i ./ magFp_i; 67 | newFp = (unitFp_i); 68 | % newFp = findPotentialFingerprint(dlarray(existingFingerprints),dlarray(aUx_i)); 69 | newFingerprints(end + 1, :) = newFp; 70 | % histogram(sum(unitFp_i.*unitFingerprints,2),[-1:0.2:1],'Normalization','probability'); 71 | % hold on; 72 | % histogram(sum(unitAUx_i.*unitFp_i,2),[-1:0.2:1],'Normalization','probability'); 73 | % legend('Correlation with existing fingerprints','Correlation with own samples'); 74 | end 75 | randSeries = randperm(size(continualLearnY, 2)); 76 | continualLearnX = continualLearnX(:, :, :, randSeries); 77 | continualLearnY = continualLearnY(randSeries); 78 | cvContinualLearnX = continualLearnX(:, :, :, floor(0.6 * size(continualLearnX, 4)):end); 79 | cvContinualLearnY = continualLearnY(floor(0.6 * size(continualLearnY, 2)):end); 80 | continualLearnX = continualLearnX(:, :, :, 1:floor(0.6 * size(continualLearnX, 4)) - 1); 81 | continualLearnY = continualLearnY(1:floor(0.6 * size(continualLearnY, 2)) - 1); 82 | 83 | concatFingerprints = [existingFingerprints; newFingerprints]; 84 | % newFingerprintLayer = FCLayer(numClasses,numClasses+newClassesNum,'Fingerprints',concatFingerprints); 85 | newFingerprintLayer = zeroBiasFCLayer(numClasses, prevClassNum + newClassesNum, 'Fingerprints', concatFingerprints); 86 | 87 | % newFingerprintLayer.normMag = prevDlnet.Layers(4).normMag; 88 | % newFingerprintLayer.b1 = prevDlnet.Layers(4).b1; 89 | 90 | numClasses = prevClassNum + newClassesNum; 91 | 92 | % Build new network and start continual learning. 
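% Editorial sketch: before the rebuild below, the concatenated fingerprints
% can be sanity-checked (uses concatFingerprints, prevClassNum and
% newClassesNum from above). An off-diagonal cosine similarity close to 1
% would mean a new fingerprint collides with an existing one.
fpChk = double(extractdata(dlarray(concatFingerprints)));
unitFpChk = fpChk ./ sqrt(sum(fpChk.^2, 2));
offDiagChk = unitFpChk*unitFpChk' - eye(prevClassNum + newClassesNum);
maxPairwiseCosine = max(abs(offDiagChk(:))) % expect values well below 1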
93 | lgraph2 = layerGraph(net);
94 | lgraph2 = lgraph2.removeLayers('classify_1');
95 | dlnet = dlnetwork(replaceLayer(lgraph2, 'Fingerprints', newFingerprintLayer));
96 | 
97 | XTrain = continualLearnX;
98 | YTrain = continualLearnY;
99 | XTest = cvContinualLearnX;
100 | YTest = (cvContinualLearnY);
101 | numEpochs = 3;
102 | miniBatchSize = 20;
103 | plots = "training-progress";
104 | statusFigureIdx = [];
105 | statusFigureAxis = [];
106 | if plots == "training-progress"
107 | figure;
108 | statusFigureIdx = gcf;
109 | statusFigureAxis = gca;
110 | % See the animatedline documentation for more color codes
111 | lineNewLossTrain = animatedline('Color', '#0072BD', 'LineWidth', 1, 'Marker', '.', 'LineStyle', 'none');
112 | lineNewCVAccuracy = animatedline('Color', '#D95319', 'LineWidth', 1);
113 | lineOldCVAccuracy = animatedline('Color', '#EDB120', 'LineWidth', 1);
114 | lineOldLossCV = animatedline('Color', '#7E2F8E', 'LineWidth', 1, 'Marker', '.', 'LineStyle', 'none');
115 | % ylim([0 inf])
116 | ylim([0 2])
117 | xlabel("Iteration")
118 | ylabel("Metrics")
119 | legend('New task loss', 'New task accuracy', 'Old task accuracy', 'Old task loss');
120 | grid on
121 | end
122 | L2RegularizationFactor = 0.01;
123 | initialLearnRate = 0.01;
124 | decay = 0.01;
125 | momentum = 0.9;
126 | velocities = [];
127 | learnRates = [];
128 | momentums = [];
129 | gradientMasks = [];
130 | numObservations = numel(YTrain);
131 | numIterationsPerEpoch = floor(numObservations ./ miniBatchSize);
132 | iteration = 0;
133 | start = tic;
134 | classes = categorical(YTrain);
135 | 
136 | %
137 | prevCY = [prevCY; zeros(newClassesNum, size(prevCY, 2))];
138 | 
139 | newCvAccuracy = cvAccuracy(dlnet, XTest, YTest, miniBatchSize, executionEnvironment);
140 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize, executionEnvironment);
141 | disp('CV accuracy b.f. cont. learning');
142 | [newCvAccuracy, oldCvAccuracy]
143 | % Loop over epochs.
144 | fisherLossLst = [];
145 | for epoch = 1:numEpochs
146 | idx = randperm(numel(YTrain));
147 | XTrain = XTrain(:, :, :, idx);
148 | YTrain = YTrain(idx);
149 | % Loop over mini-batches.
150 | for i = 1:numIterationsPerEpoch
151 | iteration = iteration + 1;
152 | % Read mini-batch of data and convert the labels to dummy
153 | % variables.
154 | idx = (i - 1) * miniBatchSize + 1:i * miniBatchSize;
155 | Xb = XTrain(:, :, :, idx);
156 | Yb = zeros(numClasses, miniBatchSize, 'single');
157 | for c = 1:numClasses
158 | Yb(c, YTrain(idx) == (c)) = 1;
159 | end
160 | % Convert mini-batch of data to dlarray.
161 | dlX = dlarray(single(Xb), 'SSCB');
162 | % If training on a GPU, then convert data to gpuArray.
163 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
164 | dlX = gpuArray(dlX);
165 | end
166 | % Evaluate the model gradients, state, and loss using dlfeval and the
167 | % modelGradients function and update the network state.
168 | [gradients, state, loss, fisherLoss] = dlfeval(@modelGradientsOnWeightsEWC, dlnet, dlX, Yb, ...
169 | prevFisherInfo0, prevDlnet, newClassesNum, 1, fingerprintLayerIdx, prevClassNum);
170 | 
171 | fisherLossLst(end + 1) = extractdata(gather(fisherLoss));
172 | 
173 | dlnet.State = state;
174 | % Determine learning rate for time-based decay learning rate schedule.
175 | learnRate = initialLearnRate / (1 + decay * iteration); 176 | if isempty(velocities) 177 | velocities = packScalar(gradients, 0); 178 | learnRates = packScalar(gradients, learnRate); 179 | momentums = packScalar(gradients, momentum); 180 | L2Foctors = packScalar(gradients, 0.05); 181 | gradientMasks = packScalar(gradients, 1); 182 | % Let's lock some weights 183 | % Zero-bias dense layer only. 184 | % if epoch <= 5 185 | % gradientMasks{fingerprintLayerIdx,3} = {dlarray([zeros(18,18); ones(newClassesNum,fingerprintLen)])}; 186 | % else 187 | % gradientMasks{fingerprintLayerIdx,3} = {dlarray([ones(18,18); ones(newClassesNum,fingerprintLen)])}; 188 | % end 189 | if lockOldFps == 1 190 | gradientMasks{fingerprintLayerIdx - 1, 3} = {dlarray([zeros(prevClassNum, 20); ones(newClassesNum, 20)])}; % specify whether lock old fps or not 191 | % gradientMasks{9+1,3} = {dlarray([zeros(prevClassNum,1); ones(newClassesNum,1)])}; 192 | end 193 | 194 | for k = 1:fingerprintLayerIdx - 2 195 | gradientMasks.Value{k} = dlarray(zeros(size(gradientMasks.Value{k}))); 196 | end 197 | end 198 | % fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet,newClassesNum) 199 | [dlnet, velocities] = dlupdate(@sgdmFunctionL2, ... 200 | dlnet, gradients, velocities, ... 201 | learnRates, momentums, L2Foctors, gradientMasks); 202 | 203 | % Display the training progress. 204 | if plots == "training-progress" 205 | D = duration(0, 0, toc(start), 'Format', 'hh:mm:ss'); 206 | % figure(statusFigureIdx); 207 | if mod(iteration, 20) == 0 208 | newCvAccuracy = cvAccuracy(dlnet, XTest, YTest, miniBatchSize, executionEnvironment); 209 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize, executionEnvironment); 210 | [~, ~, oldCVLoss] = dlfeval(@modelGradientsOnWeights, dlnet, prevCX, prevCY); 211 | addpoints(lineNewCVAccuracy, iteration, newCvAccuracy); 212 | addpoints(lineOldCVAccuracy, iteration, oldCvAccuracy); 213 | addpoints(lineOldLossCV, iteration, double(gather(extractdata(oldCVLoss)))); 214 | end 215 | addpoints(lineNewLossTrain, iteration, double(gather(extractdata(loss)))) 216 | title(statusFigureAxis, "Epoch: " + epoch + ", Elapsed: " + string(D)) 217 | drawnow 218 | end 219 | end 220 | end 221 | 222 | 223 | 224 | figure 225 | weights = dlnet.Layers(14).Weights(:,:); 226 | similarMatrix = zeros(size(weights,1), size(weights,1)); 227 | for i = 1:size(weights,1) 228 | curWeight = weights(i,:); 229 | magCurWeight = sqrt(sum(curWeight.^2,2)); 230 | for j = 1:size(weights,1) 231 | nxtWeight = weights(j,:); 232 | magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 233 | similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 234 | .*(curWeight./magCurWeight)); 235 | end 236 | end 237 | sumDeviationAnglesAll = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 238 | imagesc(similarMatrix) 239 | 240 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 241 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 242 | disp('CV accuracy a.f. cont. 
learning'); 243 | [newCvAccuracy, oldCvAccuracy] 244 | 245 | 246 | function [totalLoss,corrOld,corrNew,grad] = evalFp(existingFp, newFeatures, fp) 247 | magFp = sqrt(sum(fp.^2)); 248 | unitFp = fp ./ magFp; 249 | corrOld = mean(sum( existingFp./sqrt( sum((existingFp).^2,2)).*unitFp, 2 )); 250 | corrNew = mean(sum( newFeatures'./sqrt( sum((newFeatures').^2,2)).*unitFp, 2)); 251 | totalLoss = corrOld - 10*corrNew; 252 | grad = dlgradient(totalLoss, fp); 253 | end 254 | 255 | function fp = findPotentialFingerprint(existingFp, newFeatures) 256 | fp = mean(newFeatures,2)'; 257 | % magFp = sqrt(sum(fp.^2)); 258 | % unitFp = fp ./ magFp; 259 | record = []; 260 | for i = 1:20 261 | [totalLoss,corrOld,corrNew, grad] = dlfeval(@evalFp, (existingFp), (newFeatures), fp); 262 | record = [record;[totalLoss,corrOld,corrNew]]; 263 | fp = fp - 0.5.*grad; 264 | end 265 | % figure 266 | % plot(extractdata(record(:,1)),'LineWidth',1.5); 267 | % hold on; 268 | % plot(extractdata(record(:,2)),'LineWidth',1.5); 269 | % plot(extractdata(record(:,3)),'LineWidth',1.5); 270 | end 271 | 272 | function [gradients,state,loss,fisherLoss] = modelGradientsOnWeightsEWC(dlnet, dlX, Y,... 273 | prevFisherInfo0, prevDlnet, newClassesNum, ewcLambda, fingerprintLayerIdx, prevClassNum) 274 | [dlYPred,state] = forward(dlnet,dlX); 275 | penalty = 0; 276 | scalarL2Factor = 0; 277 | if scalarL2Factor ~= 0 278 | paramLst = dlnet.Learnables.Value; 279 | for i = 1:size(paramLst,1) 280 | penalty = penalty + sum((paramLst{i}(:)).^2); 281 | end 282 | end 283 | fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet, newClassesNum, fingerprintLayerIdx)*ewcLambda/2; 284 | loss = crossentropy(dlYPred,Y) + scalarL2Factor*penalty + fisherLoss; 285 | gradients = dlgradient(loss, dlnet.Learnables); 286 | end 287 | 288 | function fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet, newClassesNum,fingerprintLayerIdx) 289 | prevWeights = prevDlnet.Learnables.Value; 290 | curWeights = dlnet.Learnables.Value; 291 | % fisherLossMatrix = {}; 292 | sumLoss = 0; 293 | elementCount = 1; 294 | for i = 1:size(prevWeights,1) 295 | if i >= fingerprintLayerIdx-1 296 | loss = ((prevWeights{i}-curWeights{i}(1:size(curWeights{i},1)-newClassesNum,:)).^2) .* prevFisherInfo0.Value{i}; 297 | % fisherLossMatrix{end+1} = loss; 298 | sumLoss = sumLoss + sum(loss(:)); 299 | elementCount = elementCount + numel(prevWeights{i}); 300 | else 301 | loss = ((prevWeights{i}-curWeights{i}).^2) .* prevFisherInfo0.Value{i}; 302 | % fisherLossMatrix{end+1} = loss; 303 | sumLoss = sumLoss + sum(loss(:)); 304 | elementCount = elementCount + numel(prevWeights{i}); 305 | end 306 | end 307 | fisherLoss = sumLoss; 308 | end 309 | 310 | function accuracy = cvAccuracy(dlnet, XTest, YTest, miniBatchSize, executionEnvironment) 311 | dlXTest = dlarray(XTest,'SSCB'); 312 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 313 | dlXTest = gpuArray(dlXTest); 314 | end 315 | dlYPred = modelPredictions(dlnet,dlXTest,miniBatchSize); 316 | [~,idx] = max(extractdata(dlYPred),[],1); 317 | YPred = (idx); 318 | accuracy = mean(YPred(:) == YTest(:)); 319 | end 320 | 321 | function dlYPred = modelPredictions(dlnet,dlX,miniBatchSize) 322 | numObservations = size(dlX,4); 323 | numIterations = ceil(numObservations / miniBatchSize); 324 | numClasses = size(dlnet.Layers(end-1).Weights,1); 325 | dlYPred = zeros(numClasses,numObservations,'like',dlX); 326 | for i = 1:numIterations 327 | idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations); 328 | 
dlYPred(:,idx) = predict(dlnet,dlX(:,:,:,idx));
329 | end
330 | end
331 | 
332 | function [gradients,state] = logModelGradientsOnWeights(dlnet,dlX)
333 | [dlYPred,state] = forward(dlnet,dlX);
334 | loglikelyhood = log(dlYPred-min(dlYPred(:))+1e-5);
335 | gradients = dlgradient(mean(loglikelyhood(:)),dlnet.Learnables);
336 | end
337 | 
338 | function [gradients,state,loss] = modelGradientsOnWeights(dlnet,dlX,Y)
339 | [dlYPred,state] = forward(dlnet,dlX);
340 | penalty = 0;
341 | scalarL2Factor = 0;
342 | if scalarL2Factor ~= 0
343 | paramLst = dlnet.Learnables.Value;
344 | for i = 1:size(paramLst,1)
345 | penalty = penalty + sum((paramLst{i}(:)).^2);
346 | end
347 | end
348 | loss = crossentropy(dlYPred,Y) + scalarL2Factor*penalty;
349 | gradients = dlgradient(loss,dlnet.Learnables);
350 | %gradients = dlgradient(loss,dlnet.Learnables(4,:));
351 | end
352 | 
353 | function [params, velocityUpdates] = sgdmFunction(params, paramGradients,...
354 | velocities, learnRates, momentums)
355 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e
356 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients;
357 | params = params - velocityUpdates;
358 | end
359 | 
360 | function [params, velocityUpdates] = sgdmFunctionL2(params, rawParamGradients,...
361 | velocities, learnRates, momentums, L2Foctors, gradientMasks)
362 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e
363 | % https://towardsdatascience.com/intuitions-on-l1-and-l2-regularisation-235f2db4c261
364 | paramGradients = rawParamGradients + 2*L2Foctors.*params;
365 | % Note that even if rawParamGradients = 0, the L2 term still
366 | % shrinks the magnitudes of the parameters.
367 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients;
368 | params = params - (velocityUpdates).*gradientMasks;
369 | end
370 | 
371 | function tabVars = packScalar(target, scalar)
372 | % MATLAB keeps Learnables in a table, so a scalar hyperparameter has to be expanded into one dlarray per parameter tensor; this helper does that.
373 | tabVars = target;
374 | for row = 1:size(tabVars(:,3),1)
375 | tabVars{row,3} = {...
376 | dlarray(...
377 | ones(size(tabVars.Value{row})).*scalar...%ones(size(tabVars(row,3).Value{1,1})).*scalar...
378 | )...
379 | };
380 | end
381 | end
382 | 
383 | 
384 | 
385 | 
--------------------------------------------------------------------------------
/ContinualLearning/WorkStage/noCS_restrictEwc.m:
--------------------------------------------------------------------------------
1 | 
2 | % %%%%
3 | % Gather Fisher information and weights from the old network
4 | prevClassNum = 0;
5 | if skipDAGNet == 0
6 | lgraph2 = layerGraph(net);
7 | lgraph2 = lgraph2.removeLayers('classify_1');
8 | weights = net.Layers({net.Layers.Name} == "Fingerprints").Weights;
9 | numClasses = size(weights, 1);
10 | prevClassNum = numClasses;
11 | 
12 | else
13 | lgraph2 = layerGraph(dlnet);
14 | numClasses = size(dlnet.Layers(14).Weights, 1);
15 | prevClassNum = numClasses;
16 | 
17 | end
18 | prevDlnet = dlnetwork(lgraph2);
19 | prevWeights = prevDlnet.Learnables;
20 | prevCX = dlarray(single(cX), 'SSCB');
21 | prevCY = zeros(numClasses, size(prevCX, 4), 'single');
22 | for c = 1:numClasses
23 | prevCY(c, cY(:) == (c)) = 1;
24 | end
25 | executionEnvironment = "auto";
26 | % If training on a GPU, then convert data to gpuArray.
27 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu" 28 | prevCX = gpuArray(prevCX); 29 | end 30 | [prevGradients0, state] = dlfeval(@logModelGradientsOnWeights, prevDlnet, prevCX); 31 | accuracy = cvAccuracy(prevDlnet, prevCX, (cY), 128, executionEnvironment) 32 | prevDlnet.State = state; 33 | prevFisherInfo0 = prevGradients0; 34 | for i = 1:size(prevFisherInfo0, 1) 35 | prevFisherInfo0{i, 3} = {dlarray(exp(prevGradients0{i, 3}{:} .^ 2))}; 36 | % prevFisherInfo0{i,3}={dlarray(exp(prevGradients0{i,3}{:}.^2))}; 37 | 38 | end 39 | fingerprintLayerIdx = 14; 40 | 41 | newClassesNum = floor(size(unique(uY), 1)); 42 | unknownClassLabels = unique(uY); 43 | idx = randperm(size(unknownClassLabels, 1)); 44 | unknownClassLabels = unknownClassLabels(idx); 45 | unknownClassLabels = unknownClassLabels(1:newClassesNum); 46 | 47 | % Generate some initial fingerprints. 48 | % figure 49 | % hold on; 50 | existingFingerprints = prevDlnet.Layers(fingerprintLayerIdx).Weights; 51 | newFingerprints = dlarray([]); 52 | continualLearnX = dlarray([]); 53 | continualLearnY = []; 54 | cursor = 1; 55 | for i = 1:newClassesNum 56 | selection = uY == unknownClassLabels(i); 57 | ux_i = uX(:, :, :, selection); 58 | uy_i = uY(selection); 59 | samplePerClass = size(uy_i, 1); 60 | continualLearnX(:, :, :, cursor:cursor + samplePerClass - 1) = ux_i; 61 | continualLearnY(cursor:cursor + samplePerClass - 1) = uy_i; 62 | cursor = cursor + samplePerClass; 63 | % aUx_i = squeeze(prevDlnet.predict(dlarray(ux_i,'SSCB'),'Outputs','fc_bf_fp')); 64 | aUx_i = squeeze(activations(net, ux_i, 'fc_bf_fp')); 65 | unitAUx_i = (aUx_i ./ sqrt(sum(aUx_i .^ 2)))'; 66 | fp_i = mean(aUx_i, 2)'; 67 | magFp_i = sqrt(sum(fp_i .^ 2)); 68 | unitFp_i = fp_i ./ magFp_i; 69 | newFp = (unitFp_i); 70 | % newFp = findPotentialFingerprint(dlarray(existingFingerprints),dlarray(aUx_i)); 71 | newFingerprints(end + 1, :) = newFp; 72 | % histogram(sum(unitFp_i.*unitFingerprints,2),[-1:0.2:1],'Normalization','probability'); 73 | % hold on; 74 | % histogram(sum(unitAUx_i.*unitFp_i,2),[-1:0.2:1],'Normalization','probability'); 75 | % legend('Correlation with existing fingerprints','Correlation with own samples'); 76 | end 77 | randSeries = randperm(size(continualLearnY, 2)); 78 | continualLearnX = continualLearnX(:, :, :, randSeries); 79 | continualLearnY = continualLearnY(randSeries); 80 | cvContinualLearnX = continualLearnX(:, :, :, floor(0.6 * size(continualLearnX, 4)):end); 81 | cvContinualLearnY = continualLearnY(floor(0.6 * size(continualLearnY, 2)):end); 82 | continualLearnX = continualLearnX(:, :, :, 1:floor(0.6 * size(continualLearnX, 4)) - 1); 83 | continualLearnY = continualLearnY(1:floor(0.6 * size(continualLearnY, 2)) - 1); 84 | 85 | concatFingerprints = [existingFingerprints; newFingerprints]; 86 | % newFingerprintLayer = FCLayer(numClasses,numClasses+newClassesNum,'Fingerprints',concatFingerprints); 87 | newFingerprintLayer = zeroBiasFCLayer(numClasses, prevClassNum + newClassesNum, 'Fingerprints', concatFingerprints); 88 | 89 | % newFingerprintLayer.normMag = prevDlnet.Layers(4).normMag; 90 | % newFingerprintLayer.b1 = prevDlnet.Layers(4).b1; 91 | 92 | numClasses = prevClassNum + newClassesNum; 93 | 94 | % Build new network and start continual learning. 
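% Editorial sketch: a dimension check before the rebuild below (uses variables
% from above; purely illustrative). The replacement layer keeps the feature
% width of the old fingerprint matrix and appends one row per new class, which
% is what later allows the gradient mask to freeze exactly the old rows.
assert(size(concatFingerprints, 1) == prevClassNum + newClassesNum);
assert(size(concatFingerprints, 2) == size(existingFingerprints, 2));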
95 | lgraph2 = layerGraph(net);
96 | lgraph2 = lgraph2.removeLayers('classify_1');
97 | dlnet = dlnetwork(replaceLayer(lgraph2, 'Fingerprints', newFingerprintLayer));
98 | 
99 | 
100 | XTrain = continualLearnX;
101 | YTrain = continualLearnY;
102 | XTest = cvContinualLearnX;
103 | YTest = (cvContinualLearnY);
104 | numEpochs = 3;
105 | miniBatchSize = 20;
106 | plots = "training-progress";
107 | statusFigureIdx = [];
108 | statusFigureAxis = [];
109 | if plots == "training-progress"
110 | figure;
111 | statusFigureIdx = gcf;
112 | statusFigureAxis = gca;
113 | % See the animatedline documentation for more color codes
114 | lineNewLossTrain = animatedline('Color', '#0072BD', 'LineWidth', 1, 'Marker', '.', 'LineStyle', 'none');
115 | lineNewCVAccuracy = animatedline('Color', '#D95319', 'LineWidth', 1);
116 | lineOldCVAccuracy = animatedline('Color', '#EDB120', 'LineWidth', 1);
117 | lineOldLossCV = animatedline('Color', '#7E2F8E', 'LineWidth', 1, 'Marker', '.', 'LineStyle', 'none');
118 | % ylim([0 inf])
119 | ylim([0 2])
120 | xlabel("Iteration")
121 | ylabel("Metrics")
122 | legend('New task loss', 'New task accuracy', 'Old task accuracy', 'Old task loss');
123 | grid on
124 | figure;
125 | lineDepth = animatedline('LineWidth',1);
126 | ylabel('Depth')
127 | ylim([-inf inf])
128 | xlabel("Iteration")
129 | end
130 | L2RegularizationFactor = 0.01;
131 | initialLearnRate = 0.01;
132 | decay = 0.01;
133 | momentum = 0.9;
134 | velocities = [];
135 | learnRates = [];
136 | momentums = [];
137 | gradientMasks = [];
138 | numObservations = numel(YTrain);
139 | numIterationsPerEpoch = floor(numObservations ./ miniBatchSize);
140 | iteration = 0;
141 | start = tic;
142 | classes = categorical(YTrain);
143 | 
144 | %
145 | prevCY = [prevCY; zeros(newClassesNum, size(prevCY, 2))];
146 | 
147 | newCvAccuracy = cvAccuracy(dlnet, XTest, YTest, miniBatchSize, executionEnvironment);
148 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize, executionEnvironment);
149 | disp('CV accuracy b.f. cont. learning');
150 | [newCvAccuracy, oldCvAccuracy]
151 | % Loop over epochs.
152 | fisherLossLst = [];
153 | for epoch = 1:numEpochs
154 | idx = randperm(numel(YTrain));
155 | XTrain = XTrain(:, :, :, idx);
156 | YTrain = YTrain(idx);
157 | % Loop over mini-batches.
158 | for i = 1:numIterationsPerEpoch
159 | iteration = iteration + 1;
160 | % Read mini-batch of data and convert the labels to dummy
161 | % variables.
162 | idx = (i - 1) * miniBatchSize + 1:i * miniBatchSize;
163 | Xb = XTrain(:, :, :, idx);
164 | Yb = zeros(numClasses, miniBatchSize, 'single');
165 | for c = 1:numClasses
166 | Yb(c, YTrain(idx) == (c)) = 1;
167 | end
168 | % Convert mini-batch of data to dlarray.
169 | dlX = dlarray(single(Xb), 'SSCB');
170 | % If training on a GPU, then convert data to gpuArray.
171 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
172 | dlX = gpuArray(dlX);
173 | end
174 | % Evaluate the model gradients, state, and loss using dlfeval and the
175 | % modelGradients function and update the network state.
176 | [gradients, state, loss, fisherLoss, depth] = dlfeval(@modelGradientsOnWeightsEWC, dlnet, dlX, Yb, ...
177 | prevFisherInfo0, prevDlnet, newClassesNum, 1, fingerprintLayerIdx, prevClassNum);
178 | 
179 | fisherLossLst(end + 1) = extractdata(gather(fisherLoss));
180 | 
181 | dlnet.State = state;
182 | % Determine learning rate for time-based decay learning rate schedule.
183 | learnRate = initialLearnRate / (1 + decay * iteration); 184 | if isempty(velocities) 185 | velocities = packScalar(gradients, 0); 186 | learnRates = packScalar(gradients, learnRate); 187 | momentums = packScalar(gradients, momentum); 188 | L2Foctors = packScalar(gradients, 0.05); 189 | gradientMasks = packScalar(gradients, 1); 190 | % Let's lock some weights 191 | % Zero-bias dense layer only. 192 | % if epoch <= 5 193 | % gradientMasks{fingerprintLayerIdx,3} = {dlarray([zeros(18,18); ones(newClassesNum,fingerprintLen)])}; 194 | % else 195 | % gradientMasks{fingerprintLayerIdx,3} = {dlarray([ones(18,18); ones(newClassesNum,fingerprintLen)])}; 196 | % end 197 | if lockOldFps == 1 198 | gradientMasks{fingerprintLayerIdx - 1, 3} = {dlarray([zeros(prevClassNum, 20); ones(newClassesNum, 20)])}; % specify whether lock old fps or not 199 | % gradientMasks{9+1,3} = {dlarray([zeros(prevClassNum,1); ones(newClassesNum,1)])}; 200 | end 201 | 202 | for k = 1:fingerprintLayerIdx - 2 203 | gradientMasks.Value{k} = dlarray(zeros(size(gradientMasks.Value{k}))); 204 | end 205 | end 206 | % fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet,newClassesNum) 207 | [dlnet, velocities] = dlupdate(@sgdmFunctionL2, ... 208 | dlnet, gradients, velocities, ... 209 | learnRates, momentums, L2Foctors, gradientMasks); 210 | 211 | % Display the training progress. 212 | if plots == "training-progress" 213 | D = duration(0, 0, toc(start), 'Format', 'hh:mm:ss'); 214 | % figure(statusFigureIdx); 215 | if mod(iteration, 20) == 0 216 | newCvAccuracy = cvAccuracy(dlnet, XTest, YTest, miniBatchSize, executionEnvironment); 217 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize, executionEnvironment); 218 | [~, ~, oldCVLoss] = dlfeval(@modelGradientsOnWeights, dlnet, prevCX, prevCY); 219 | addpoints(lineNewCVAccuracy, iteration, newCvAccuracy); 220 | addpoints(lineOldCVAccuracy, iteration, oldCvAccuracy); 221 | addpoints(lineOldLossCV, iteration, double(gather(extractdata(oldCVLoss)))); 222 | end 223 | addpoints(lineNewLossTrain, iteration, double(gather(extractdata(loss)))) 224 | addpoints(lineDepth, iteration, double(extractdata(depth))); 225 | title(statusFigureAxis, "Epoch: " + epoch + ", Elapsed: " + string(D)) 226 | drawnow 227 | end 228 | end 229 | end 230 | 231 | figure 232 | weights = dlnet.Layers(14).Weights(:,:); 233 | similarMatrix = zeros(size(weights,1), size(weights,1)); 234 | for i = 1:size(weights,1) 235 | curWeight = weights(i,:); 236 | magCurWeight = sqrt(sum(curWeight.^2,2)); 237 | for j = 1:size(weights,1) 238 | nxtWeight = weights(j,:); 239 | magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 240 | similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 241 | .*(curWeight./magCurWeight)); 242 | end 243 | end 244 | sumDeviationAnglesAll = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 245 | imagesc(similarMatrix) 246 | 247 | newCvAccuracy = cvAccuracy(dlnet, XTest,YTest,miniBatchSize,executionEnvironment); 248 | oldCvAccuracy = cvAccuracy(dlnet, cX, (cY), miniBatchSize,executionEnvironment); 249 | disp('CV accuracy a.f. cont. 
learning');
250 | [newCvAccuracy, oldCvAccuracy]
251 | 
252 | 
253 | function [totalLoss,corrOld,corrNew,grad] = evalFp(existingFp, newFeatures, fp)
254 | magFp = sqrt(sum(fp.^2));
255 | unitFp = fp ./ magFp;
256 | corrOld = mean(sum( existingFp./sqrt( sum((existingFp).^2,2)).*unitFp, 2 ));
257 | corrNew = mean(sum( newFeatures'./sqrt( sum((newFeatures').^2,2)).*unitFp, 2));
258 | totalLoss = corrOld - 10*corrNew;
259 | grad = dlgradient(totalLoss, fp);
260 | end
261 | 
262 | function fp = findPotentialFingerprint(existingFp, newFeatures)
263 | fp = mean(newFeatures,2)';
264 | % magFp = sqrt(sum(fp.^2));
265 | % unitFp = fp ./ magFp;
266 | record = [];
267 | for i = 1:20
268 | [totalLoss,corrOld,corrNew, grad] = dlfeval(@evalFp, (existingFp), (newFeatures), fp);
269 | record = [record;[totalLoss,corrOld,corrNew]];
270 | fp = fp - 0.5.*grad;
271 | end
272 | % figure
273 | % plot(extractdata(record(:,1)),'LineWidth',1.5);
274 | % hold on;
275 | % plot(extractdata(record(:,2)),'LineWidth',1.5);
276 | % plot(extractdata(record(:,3)),'LineWidth',1.5);
277 | end
278 | 
279 | function [gradients,state,loss,fisherLoss,depth] = modelGradientsOnWeightsEWC(dlnet, dlX, Y,...
280 | prevFisherInfo0, prevDlnet, newClassesNum, ewcLambda, fingerprintLayerIdx, prevClassNum)
281 | [dlYPred,state] = forward(dlnet,dlX);
282 | penalty = 0;
283 | scalarL2Factor = 0;
284 | if scalarL2Factor ~= 0
285 | paramLst = dlnet.Learnables.Value;
286 | for i = 1:size(paramLst,1)
287 | penalty = penalty + sum((paramLst{i}(:)).^2);
288 | end
289 | end
290 | 
291 | numClasses = prevClassNum + newClassesNum;
292 | newAvgCosine = -1./(numClasses-1);
293 | prevSumDevLimit = prevClassNum*(prevClassNum-1)/2 * newAvgCosine;
294 | depth = dlarray(prevSumDevLimit); % Editorial fix: 'depth' is declared as an output but was never assigned, which would error at the dlfeval call site; the theoretical deviation-sum limit is returned as a placeholder since the original trainDepth quantity is not computed in this file.
295 | fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet, newClassesNum, fingerprintLayerIdx)*ewcLambda/2;
296 | % loss = crossentropy(dlYPred,Y) + scalarL2Factor*penalty + fisherLoss + sqrt((trainDepth-prevSumDevLimit).^2);
297 | loss = crossentropy(dlYPred,Y) + scalarL2Factor*penalty + fisherLoss;
298 | gradients = dlgradient(loss, dlnet.Learnables);
299 | end
300 | 
301 | function fisherLoss = calcFisherLoss(prevFisherInfo0, dlnet, prevDlnet, newClassesNum,fingerprintLayerIdx)
302 | prevWeights = prevDlnet.Learnables.Value;
303 | curWeights = dlnet.Learnables.Value;
304 | % fisherLossMatrix = {};
305 | sumLoss = 0;
306 | elementCount = 1;
307 | for i = 1:size(prevWeights,1)
308 | if i >= fingerprintLayerIdx-1
309 | loss = ((prevWeights{i}-curWeights{i}(1:size(curWeights{i},1)-newClassesNum,:)).^2) .* prevFisherInfo0.Value{i};
310 | % fisherLossMatrix{end+1} = loss;
311 | sumLoss = sumLoss + sum(loss(:));
312 | elementCount = elementCount + numel(prevWeights{i});
313 | else
314 | loss = ((prevWeights{i}-curWeights{i}).^2) .* prevFisherInfo0.Value{i};
315 | % fisherLossMatrix{end+1} = loss;
316 | sumLoss = sumLoss + sum(loss(:));
317 | elementCount = elementCount + numel(prevWeights{i});
318 | end
319 | end
320 | fisherLoss = sumLoss;
321 | end
322 | 
323 | function accuracy = cvAccuracy(dlnet, XTest, YTest, miniBatchSize, executionEnvironment)
324 | dlXTest = dlarray(XTest,'SSCB');
325 | if (executionEnvironment == "auto" && canUseGPU) || executionEnvironment == "gpu"
326 | dlXTest = gpuArray(dlXTest);
327 | end
328 | dlYPred = modelPredictions(dlnet,dlXTest,miniBatchSize);
329 | [~,idx] = max(extractdata(dlYPred),[],1);
330 | YPred = (idx);
331 | accuracy = mean(YPred(:) == YTest(:));
332 | end
333 | 
334 | function dlYPred = modelPredictions(dlnet,dlX,miniBatchSize)
335 | numObservations = 
size(dlX,4);
336 | numIterations = ceil(numObservations / miniBatchSize);
337 | numClasses = size(dlnet.Layers(end-1).Weights,1);
338 | dlYPred = zeros(numClasses,numObservations,'like',dlX);
339 | for i = 1:numIterations
340 | idx = (i-1)*miniBatchSize+1:min(i*miniBatchSize,numObservations);
341 | dlYPred(:,idx) = predict(dlnet,dlX(:,:,:,idx));
342 | end
343 | end
344 | 
345 | function [gradients,state] = logModelGradientsOnWeights(dlnet,dlX)
346 | [dlYPred,state] = forward(dlnet,dlX);
347 | loglikelyhood = log(dlYPred-min(dlYPred(:))+1e-5);
348 | gradients = dlgradient(mean(loglikelyhood(:)),dlnet.Learnables);
349 | end
350 | 
351 | function [gradients,state,loss] = modelGradientsOnWeights(dlnet,dlX,Y)
352 | [dlYPred,state] = forward(dlnet,dlX);
353 | penalty = 0;
354 | scalarL2Factor = 0;
355 | if scalarL2Factor ~= 0
356 | paramLst = dlnet.Learnables.Value;
357 | for i = 1:size(paramLst,1)
358 | penalty = penalty + sum((paramLst{i}(:)).^2);
359 | end
360 | end
361 | loss = crossentropy(dlYPred,Y) + scalarL2Factor*penalty;
362 | gradients = dlgradient(loss,dlnet.Learnables);
363 | %gradients = dlgradient(loss,dlnet.Learnables(4,:));
364 | end
365 | 
366 | function [params, velocityUpdates] = sgdmFunction(params, paramGradients,...
367 | velocities, learnRates, momentums)
368 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e
369 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients;
370 | params = params - velocityUpdates;
371 | end
372 | 
373 | function [params, velocityUpdates] = sgdmFunctionL2(params, rawParamGradients,...
374 | velocities, learnRates, momentums, L2Foctors, gradientMasks)
375 | % https://towardsdatascience.com/stochastic-gradient-descent-momentum-explanation-8548a1cd264e
376 | % https://towardsdatascience.com/intuitions-on-l1-and-l2-regularisation-235f2db4c261
377 | paramGradients = rawParamGradients + 2*L2Foctors.*params;
378 | % Note that even if rawParamGradients = 0, the L2 term still
379 | % shrinks the magnitudes of the parameters.
380 | velocityUpdates = momentums.*velocities+learnRates.*paramGradients;
381 | params = params - (velocityUpdates).*gradientMasks;
382 | end
383 | 
384 | function tabVars = packScalar(target, scalar)
385 | % MATLAB keeps Learnables in a table, so a scalar hyperparameter has to be expanded into one dlarray per parameter tensor; this helper does that.
386 | tabVars = target;
387 | for row = 1:size(tabVars(:,3),1)
388 | tabVars{row,3} = {...
389 | dlarray(...
390 | ones(size(tabVars.Value{row})).*scalar...%ones(size(tabVars(row,3).Value{1,1})).*scalar...
391 | )...
392 | };
393 | end
394 | end
395 | 
396 | 
--------------------------------------------------------------------------------
/ContinualLearning/adsb_recognition_singleBurst_v4_3b.m:
--------------------------------------------------------------------------------
1 | % For experiments with different training depths.
2 | % Addition layer and dropout.
3 | 
4 | clc;
5 | close all;
6 | clear;
7 | rng default;
8 | 
9 | addpath('./matplotlib')
10 | addpath('../');
11 | 
12 | load('adsb_records_qt.mat');
13 | % load('adsb_bladerf2_10M_qt0.mat');
14 | payloadMatrix = reshape(payloadMatrix', ...
15 | length(payloadMatrix)/length(msgIdLst), length(msgIdLst))';
16 | rawIMatrix = reshape(rawIMatrix', ...
17 | length(rawIMatrix)/length(msgIdLst), length(msgIdLst))';
18 | rawQMatrix = reshape(rawQMatrix', ...
19 | length(rawQMatrix)/length(msgIdLst), length(msgIdLst))'; 20 | rawCompMatrix = rawIMatrix + rawQMatrix.*1j; 21 | if size(rawCompMatrix,2) < 1024 22 | appendingBits = (ceil(sqrt(size(rawCompMatrix,2))))^2 - size(rawCompMatrix,2); 23 | rawCompMatrix = [rawCompMatrix, zeros(size(rawCompMatrix,1), appendingBits)]; 24 | else 25 | rawCompMatrix = rawCompMatrix(:,1:1024); 26 | end 27 | uIcao = unique(icaoLst); 28 | c = countmember(uIcao,icaoLst); 29 | icaoOccurTb = [uIcao,c]; 30 | icaoOccurTb = sortrows(icaoOccurTb,2,'descend'); 31 | 32 | cond1 = latitudeLst >= -200; 33 | cond2 = longitudeLst >= -200; 34 | cond4 = altitudeLst >= 0; 35 | cond3 = cond1.*cond2.*cond4; 36 | DrawPic = [longitudeLst(logical(cond3),:),... 37 | latitudeLst(logical(cond3),:),... 38 | altitudeLst(logical(cond3),:),snrLst(logical(cond3),:)]; 39 | 40 | cond1 = icaoOccurTb(:,2)>=400; 41 | cond2 = icaoOccurTb(:,2)<=8000; 42 | cond3 = icaoOccurTb(:,2)>=350; 43 | cond4 = icaoOccurTb(:,2)<500; 44 | 45 | cond = logical(cond1.*cond2); 46 | selectedPlanes = icaoOccurTb(cond,:); 47 | unknowPlanes = icaoOccurTb(logical(cond3.*cond4),:); 48 | %Clip away ICAO IDs. 49 | rawCompMatrix(:,1:32*8) = zeros(size(rawCompMatrix,1),32*8); 50 | allTrainData = [icaoLst, abs(rawCompMatrix)]; 51 | 52 | minTrainChance = 300; 53 | maxTrainChance = 400; 54 | 55 | selectedBasebandData = zeros(size(allTrainData)); 56 | selectedRawCompData = zeros(size(rawCompMatrix)); 57 | cursor = 1; 58 | for i = 1:size(selectedPlanes,1) 59 | selection = allTrainData(:,1)==selectedPlanes(i,1); 60 | localBaseband = allTrainData(selection,:); 61 | localBaseband(:,1) = ones(size(localBaseband,1),1).*i; 62 | localComplex = rawCompMatrix(selection,:); 63 | localAngles = (angle(localComplex(:,:))); 64 | 65 | % figure 66 | % for k = 1:size(localAngles,1) 67 | % plot(localAngles(k,:),'.'); 68 | % title(strcat(num2str(k), ' / ', num2str(size(localAngles,1)))); 69 | % pause(30/1000); 70 | % end 71 | 72 | if size(localBaseband,1) < minTrainChance 73 | continue; 74 | elseif size(localBaseband,1) >= maxTrainChance 75 | rndSeq = randperm(size(localBaseband,1)); 76 | rndSeq = rndSeq(1:maxTrainChance); 77 | localBaseband = localBaseband(rndSeq,:); 78 | localComplex = localComplex(rndSeq,:); 79 | else 80 | %Nothing to do 81 | end 82 | selectedBasebandData(cursor:cursor+size(localBaseband,1)-1,:) = localBaseband; 83 | selectedRawCompData(cursor:cursor+size(localComplex,1)-1,:) = localComplex; 84 | cursor = cursor+size(localBaseband,1); 85 | % selectedBasebandData = [selectedBasebandData; localBaseband]; 86 | % selectedRawCompData = [selectedRawCompData; localComplex]; 87 | end 88 | selectedBasebandData = selectedBasebandData(1:cursor-1,:); 89 | selectedRawCompData = selectedRawCompData(1:cursor-1,:); 90 | 91 | offset = size(selectedPlanes,1); 92 | unknownBasebandData = zeros(size(allTrainData)); 93 | unknownRawCompData = zeros(size(rawCompMatrix)); 94 | cursor = 1; 95 | for i = 1:size(unknowPlanes,1) 96 | selection = allTrainData(:,1)==unknowPlanes(i,1); 97 | localBaseband = allTrainData(selection,:); 98 | localBaseband(:,1) = ones(size(localBaseband,1),1).*(i+offset); 99 | localComplex = rawCompMatrix(selection,:); 100 | localAngles = (angle(localComplex(:,:))); 101 | unknownBasebandData(cursor:cursor+size(localBaseband,1)-1,:) = localBaseband; 102 | unknownRawCompData(cursor:cursor+size(localComplex,1)-1,:) = localComplex; 103 | cursor = cursor+size(localBaseband,1); 104 | % unknownBasebandData = [unknownBasebandData; localBaseband]; 105 | % unknownRawCompData = [unknownRawCompData; 
localComplex]; 106 | end 107 | unknownBasebandData = unknownBasebandData(1:cursor-1,:); 108 | unknownRawCompData = unknownRawCompData(1:cursor-1,:); 109 | 110 | randSeries = randperm(size(selectedBasebandData,1)); 111 | selectedBasebandData = selectedBasebandData(randSeries,:); 112 | selectedRawCompData = selectedRawCompData(randSeries,:); 113 | 114 | randSeries = randperm(size(unknownBasebandData,1)); 115 | unknownBasebandData = unknownBasebandData(randSeries,:); 116 | unknownRawCompData = unknownRawCompData(randSeries,:); 117 | 118 | [X,cX,Y,cY] = makeDataTensor(selectedBasebandData,selectedRawCompData); 119 | [uX,cuX,uY,cuY] = makeDataTensor(unknownBasebandData,unknownRawCompData); 120 | 121 | inputSize = [size(X,1) size(X,2) size(X,3)]; 122 | % numClasses = size(unique(selectedBasebandData(:,1)),1); 123 | numClasses = max(numel(unique(Y)),numel(unique(cY))); 124 | 125 | layers = [ 126 | imageInputLayer(inputSize, 'Name', 'input') 127 | convolution2dLayer(5,10, 'Name', 'conv2d_1') 128 | batchNormalizationLayer('Name', 'batchNorm_1') 129 | reluLayer('Name', 'relu_1') 130 | convolution2dLayer(3, 10, 'Padding', 1, 'Name', 'conv2d_2') 131 | reluLayer('Name', 'relu_2') 132 | convolution2dLayer(3, 10, 'Padding', 1, 'Name', 'conv2d_3') 133 | reluLayer('Name', 'relu_3') 134 | additionLayer(2,'Name', 'add_1') 135 | % depthConcatenationLayer(2,'Name','add_1') 136 | tensorVectorLayer('Flatten') 137 | dropoutLayer(0.1,'Name','Dropout_1') 138 | fullyConnectedLayer(numClasses, 'Name', 'fc_bf_fp') % 11th 139 | batchNormalizationLayer('Name', 'batchNorm_2') 140 | 141 | 142 | zeroBiasFCLayer(numClasses,numClasses,'Fingerprints',[]) 143 | % fullyConnectedLayer(numClasses, 'Name', 'Fingerprints') 144 | 145 | 146 | %FCLayer(2*numClasses,numClasses,'Fingerprints',[]) 147 | 148 | %dropoutLayer(0.3,'Name','dropOut_1') 149 | %amplificationLayer(numClasses,'Fingerprints',[]) 150 | %nearestNeighbourLayer(numClasses,numClasses,'Fingerprints',[]) 151 | softmaxLayer('Name', 'softmax_1') 152 | classificationLayer('Name', 'classify_1') 153 | ]; 154 | 155 | 156 | lgraph = layerGraph(layers); 157 | lgraph = connectLayers(lgraph, 'relu_1', 'add_1/in2'); 158 | %lgraph = connectLayers(lgraph, 'relu_2', 'add_1/in3'); 159 | plot(lgraph); 160 | options = trainingOptions('sgdm',... 161 | 'Plots', 'training-progress',... 162 | 'ExecutionEnvironment','auto',... 163 | 'ValidationData',{cX,categorical(cY)},... 164 | 'MaxEpochs', 20, ... 165 | 'MiniBatchSize',128,... 166 | 'L2Regularization',0.01); 167 | [net,info] = trainNetwork(X, categorical(Y), lgraph, options); 168 | layerOneWeights = zeros(5*size(net.Layers(2).Weights,4)*inputSize(3),5); 169 | cursor = 1; 170 | for i = 1:size(net.Layers(2).Weights,4) 171 | for j = 1:inputSize(3) 172 | layerOneWeights(cursor:cursor+5-1,:) ... 
173 | = squeeze(net.Layers(2).Weights(:,:,j,i)); 174 | cursor = cursor + 5; 175 | end 176 | end 177 | %iacc2 = find(~isnan(info.ValidationAccuracy) == 1) 178 | %vacc2 = info.ValidationAccuracy(~isnan(info.ValidationAccuracy) == 1); 179 | cursor = 1; 180 | layerOneResponse = {}; 181 | for i = 1:size(net.Layers(2).Weights,4)*inputSize(3) 182 | [H,W] = freqz2(layerOneWeights(cursor:cursor+5-1,:),1024); 183 | cursor = cursor + 5; 184 | layerOneResponse{end+1} = {H,W}; 185 | end 186 | figure 187 | cursor = 1; 188 | for i = 1:5 189 | for j = 1:6 190 | subplot(5,6,cursor); 191 | imagesc(abs(layerOneResponse{cursor}{1})); 192 | cursor = cursor + 1; 193 | end 194 | end 195 | 196 | YPred = classify(net, cX); 197 | accuracy = sum(categorical(cY) == YPred)/numel(cY) 198 | cm = confusionmat(categorical(cY),YPred); 199 | cm = cm./sum(cm,2); 200 | figure 201 | imagesc(cm); 202 | 203 | 204 | weights = net.Layers({net.Layers.Name}=="Fingerprints").Weights; 205 | similarMatrix = zeros(size(weights,1), size(weights,1)); 206 | for i = 1:size(weights,1) 207 | curWeight = weights(i,:); 208 | magCurWeight = sqrt(sum(curWeight.^2,2)); 209 | for j = 1:size(weights,1) 210 | nxtWeight = weights(j,:); 211 | magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 212 | similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 213 | .*(curWeight./magCurWeight)); 214 | end 215 | end 216 | sumDeviationAngles = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 217 | rationalSumDevAngles = -numClasses./2 218 | figure;imagesc(similarMatrix) 219 | 220 | 221 | 222 | 223 | -------------------------------------------------------------------------------- /ContinualLearning/structstruct.m: -------------------------------------------------------------------------------- 1 | % 2 | % https://www.mathworks.com/matlabcentral/fileexchange/32879-graphically-display-the-branch-structure-of-a-struct-variable 3 | % 4 | 5 | % structstruct(S) takes in a structure variable and displays its structure. 6 | % 7 | % INPUTS: 8 | % 9 | % Recursive function 'structstruct.m' accepts a single input of any class. 10 | % For non-structure input, structstruct displays the class and size of the 11 | % input and then exits. For structure input, structstruct displays the 12 | % fields and sub-fields of the input in an ASCII graphical printout in the 13 | % command window. The order of structure fields is preserved. 14 | % 15 | % OUTPUTS: 16 | % 17 | % (none yet!) 
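% EXAMPLE (editorial addition; hypothetical variable):
%
% s.version = 2;
% s.payload.raw = zeros(4,1024);
% s.payload.label = 'adsb';
% structstruct(s) % prints an ASCII tree of the fields and sub-fields of s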
18 | function structstruct(S) 19 | % Figure the type and class of the input 20 | whosout = whos('S'); 21 | sizes = whosout.size; 22 | sizestr = [int2str(sizes(1)),'x',int2str(sizes(2))]; 23 | endstr = [': [' sizestr '] ' whosout.class]; 24 | % Print out the properties of the input variable 25 | disp(' '); 26 | disp([inputname(1) endstr]); 27 | % Check if S is a structure, then call the recursive function 28 | if isstruct(S) 29 | recursor(S,0,''); 30 | end 31 | % Print out a blank line 32 | disp(' '); 33 | end 34 | function recursor(S,level,recstr) 35 | recstr = [recstr ' |']; 36 | fnames = fieldnames(S); 37 | for i = 1:length(fnames) 38 | 39 | %% Print out the current fieldname 40 | 41 | % Take out the i'th field 42 | tmpstruct = S.(fnames{i}); 43 | 44 | % Figure the type and class of the current field 45 | whosout = whos('tmpstruct'); 46 | sizes = whosout.size; 47 | sizestr = [int2str(sizes(1)),'x',int2str(sizes(2))]; 48 | endstr = [': [' sizestr '] ' whosout.class]; 49 | 50 | % Create the strings 51 | if i == length(fnames) % Last field in the current level 52 | str = [recstr(1:(end-1)) '''--' fnames{i} endstr]; 53 | recstr(end) = ' '; 54 | else % Not the last field in the current level 55 | str = [recstr '--' fnames{i} endstr]; 56 | end 57 | 58 | % Print the output string to the command line 59 | disp(str); 60 | 61 | %% Determine if each field is a struct 62 | 63 | % Check if the i'th field of S is a struct 64 | if isstruct(tmpstruct) % If tmpstruct is a struct, recursive function call 65 | recursor(tmpstruct,level+1,recstr); % Call self 66 | end 67 | 68 | end 69 | end -------------------------------------------------------------------------------- /FCLayer.m: -------------------------------------------------------------------------------- 1 | classdef FCLayer < nnet.layer.Layer 2 | 3 | properties (Learnable) 4 | % Layer learnable parameters 5 | Weights; 6 | Biases; 7 | end 8 | 9 | methods 10 | function layer = FCLayer(inputDim,outputDim,name,initialWeights) 11 | % layer = weightedAdditionLayer(numInputs,name) creates a 12 | 13 | % Set number of inputs. 14 | %layer.NumInputs = inputDim; 15 | %layer.NumOutputs = numOutputs; 16 | % Set layer name. 17 | layer.Name = name; 18 | % Set layer description. 19 | layer.Description = "FC layer without bias neurons with " + inputDim + ... 20 | " inputs"; 21 | % Initialize layer weights. 22 | % layer.Weights = dlarray(randn(outputDim,inputDim).*0.0001); 23 | % layer.Biases = dlarray(randn(outputDim,1).*0.0001); 24 | stdGlorot = sqrt(2/(inputDim + outputDim)); 25 | layer.Weights = dlarray(rand(outputDim,inputDim).*stdGlorot); 26 | layer.Biases = dlarray(zeros(outputDim,1)); 27 | 28 | %layer.Biases = rand(outputDim,1); 29 | if numel(initialWeights) ~= 0 30 | layer.Weights = initialWeights; 31 | end 32 | end 33 | 34 | function Z = predict(layer, X) 35 | if ndims(X) >= 3 36 | batchSize = size(X,4); 37 | else 38 | batchSize = size(X,ndims(X)); 39 | end 40 | Z = (layer.Weights*squeeze(X)+(layer.Biases )); 41 | % Z = layer.Weights*squeeze(X); 42 | 43 | % % Z = reshape((layer.Weights./sqrt(sum((layer.Weights).^2,2)))*squeeze(X),... 44 | % % [1,1,size(layer.Weights,1),batchSize]); 45 | % if ndims(X) >= 3 46 | % Z = reshape(... 47 | % ( layer.Weights )*( squeeze(X) ) + layer.Biases,... 48 | % [1,1,size(layer.Weights,1),batchSize]); 49 | % % Z = reshape(... 50 | % % ( layer.Weights )*( squeeze(X) ),... 51 | % % [1,1,size(layer.Weights,1),batchSize]); 52 | % else 53 | % Z = reshape(... 54 | % ( layer.Weights )*( squeeze(X) ) + layer.Biases,... 
55 | % [size(layer.Weights,1),batchSize]); 56 | % % Z = reshape(... 57 | % % ( layer.Weights )*( squeeze(X) ) ,... 58 | % % [size(layer.Weights,1),batchSize]); 59 | % end 60 | end 61 | end 62 | end 63 | -------------------------------------------------------------------------------- /Formal Proof of Orthogonality.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcwhy/CSIL/8ce8637daf4dc60eeb1c56bff64c050c5b2353e9/Formal Proof of Orthogonality.pdf -------------------------------------------------------------------------------- /HypersphereLib/HyperSphere.m: -------------------------------------------------------------------------------- 1 | function Y = HyperSphere(X,varargin) 2 | % Gianluca Dorini (2020). hypersphere (https://www.mathworks.com/matlabcentral/fileexchange/5397-hypersphere), MATLAB Central File Exchange. Retrieved November 9, 2020. 3 | % n-dimensional hypersphere 4 | % 5 | % syntax: Y = HyperSphere(X,r) 6 | % Y = HyperSphere(X) 7 | % 8 | % INPUT ARGUMENTS 9 | % 10 | % -X (n - 1) x (N) every row of this matrix is a tuple 11 | % of n-1 parametric coordinates. If the number of points is 12 | % N, the number of rows is N 13 | % 14 | % -r the radius, default value: r = 1 15 | % 16 | % OUTPUT ARGUMENTS 17 | % 18 | % Y (n) x (N) every row of this matrix is a tuple 19 | % of n catesian coordinates of the n-dimensional hypersphere 20 | % 21 | % this function calculates the cartesian coordinates of an n-dimensional hypersphere 22 | % given the n-1 dimensions vector 'X' of parametric coordinates and the 23 | % radius 'r' 24 | % 25 | % 26 | % The n-hypersphere (often simply called the n-sphere) is a generalization 27 | % of the circle (2-sphere) and usual sphere (the 3-sphere) to dimensions n > 3. 28 | % The n-sphere is therefore defined as the set of n-tuples of points 29 | % (x1,x2,x3, ...,xn ) such that 30 | % 31 | % x1^2 + x2^2 + x3^2 + ... + xn^2 = r^2 32 | % 33 | % where r is the radius of the hypersphere. 
34 | % 35 | % 36 | % EXAMPLES 37 | % 38 | % 39 | % % linspace points on a 2-sphere line: 40 | % 41 | % N = 20; 42 | % X = linspace(0,2*pi, N); 43 | % r = 2; 44 | % Y = HyperSphere(X,r); 45 | % plot(Y(1,:),Y(2,:),'r.'); 46 | % 47 | % 48 | % % random points on 3-sphere surface: 49 | % 50 | % N = 1000; 51 | % X = 2*pi*rand(2,N); 52 | % r = .5; 53 | % Y = HyperSphere([X],r); 54 | % plot3(Y(1,:),Y(2,:),Y(3,:),'r.'); 55 | if nargin == 2 56 | r = varargin{1}; 57 | else 58 | r = 1; 59 | end 60 | % % % % % N = size(X,1) + 1; 61 | % % % % % n = size(X,2); 62 | % % % % % Y = ones(N,n); 63 | % % % % % for i = 1:N - 1 64 | % % % % % y = [repmat(cos(X(i,:)),i,1)]; 65 | % % % % % y = [y ; sin(X(i,:))]; 66 | % % % % % y = [y ; ones(N - i - 1,n)]; 67 | % % % % % 68 | % % % % % Y = Y.*y; 69 | % % % % % end 70 | % % % % % Y = Y*r; 71 | 72 | % 73 | m = size(X,1); 74 | % number of input vectors 75 | n = size(X,2); 76 | 77 | % reverse the order of X 78 | X = X(m:-1:1,:); 79 | 80 | % sine terms 81 | S = [sin(X); ones(1,n)]; 82 | % cosine terms 83 | C = [ones(1,n); cumprod(cos(X),1)]; 84 | % calculate output 85 | Y = r * C .* S; 86 | 87 | % reverse the order of Y 88 | Y = Y(m+1:-1:1,:); -------------------------------------------------------------------------------- /HypersphereLib/HypersphereSurfArea.m: -------------------------------------------------------------------------------- 1 | function s = HypersphereSurfArea(n,r) 2 | %https://www.phys.uconn.edu/~rozman/Courses/P2400_17S/downloads/nsphere.pdf 3 | s = unitHypersphereSurfArea(n)*r^(n-1); 4 | end -------------------------------------------------------------------------------- /HypersphereLib/hypersphereCapArea.m: -------------------------------------------------------------------------------- 1 | function s = hypersphereCapArea(n,r,h) 2 | %https://en.wikipedia.org/wiki/Spherical_cap#Hyperspherical_cap 3 | 4 | s = 0.5*unitHypersphereSurfArea(n)*(r^(n-1))... 5 | *betainc((2*r*h-h^2)/(r^2),(n-1)/2,0.5); 6 | end -------------------------------------------------------------------------------- /HypersphereLib/hypersphereCoverageTest.m: -------------------------------------------------------------------------------- 1 | clc 2 | clear 3 | close all; 4 | cosAlpha = [0.3:0.1:0.9]; 5 | results = []; 6 | results2 = []; 7 | set(0,'DefaultTextFontName','Times','DefaultTextFontSize',18,... 8 | 'DefaultAxesFontName','Times','DefaultAxesFontSize',18,... 
9 | 'DefaultLineLineWidth',1,'DefaultLineMarkerSize',7.75) 10 | for i = 2:100 11 | hsa = HypersphereSurfArea(i,1); 12 | hsc = []; 13 | ratios = []; 14 | classNum = []; 15 | 16 | for j = 1:numel(cosAlpha) 17 | hsc(end+1) = hypersphereCapArea(i,1,1- cosAlpha(j) ); 18 | ratios(end+1) = hsc(end)./hsa; 19 | classNum(end+1) = hsa./hsc(end); 20 | end 21 | results(end+1,:) = [i,ratios]; 22 | results2(end+1,:) = [i,classNum]; 23 | 24 | end 25 | figure 26 | hold on; 27 | legendStrings = {}; 28 | for i = 2:size(results,2) 29 | plot(results(:,1),results(:,i),'LineWidth',1.5); 30 | legendStrings{end+1} = string('$cos(\sigma)=$') + string(cosAlpha(i-1)); 31 | end 32 | xlabel('Dimensions') 33 | ylabel('Hypersphere coverage ratio / class'); 34 | legend(legendStrings,'Interpreter','latex'); 35 | % set(gca,'FontSize',12) 36 | 37 | xlim([2,50]) 38 | grid on; 39 | box off; 40 | 41 | figure 42 | legendStrings = {}; 43 | for i = 2:size(results,2) 44 | semilogy(results2(:,1),(results2(:,i)),'LineWidth',1.5); 45 | hold on; 46 | legendStrings{end+1} = string('$cos(\sigma)=$') + string(cosAlpha(i-1)); 47 | end 48 | xlabel('Dimensions') 49 | ylabel('Maximum number of classes'); 50 | legend(legendStrings,'Interpreter','latex'); 51 | % set(gca,'FontSize',12) 52 | 53 | xlim([2,50]) 54 | grid on; 55 | box off; 56 | -------------------------------------------------------------------------------- /HypersphereLib/hypersphereCoverageTest.m~: -------------------------------------------------------------------------------- 1 | clc 2 | clear 3 | close all; 4 | cosAlpha = [0.2:0.1:0.9]; 5 | results = []; 6 | results = []; 7 | 8 | for i = 2:100 9 | hsa = HypersphereSurfArea(i,1); 10 | hsc = []; 11 | ratios = []; 12 | classNum = []; 13 | 14 | for j = 1:numel(cosAlpha) 15 | hsc(end+1) = hypersphereCapArea(i,1,1- cosAlpha(j) ); 16 | ratios(end+1) = hsc(end)./hsa; 17 | classNum(end+1) = hsc(end)./hsa; 18 | end 19 | results(end+1,:) = [i,ratios]; 20 | 21 | end 22 | figure 23 | hold on; 24 | legendStrings = {}; 25 | for i = 2:size(results,2) 26 | plot(results(:,1),results(:,i),'LineWidth',1.5); 27 | legendStrings{end+1} = string('$cos(\sigma)=$') + string(cosAlpha(i-1)); 28 | end 29 | xlabel('Dimensions') 30 | ylabel('Hypersphere coverage ratio / class'); 31 | legend(legendStrings,'Interpreter','latex'); 32 | % set(gca,'FontSize',12) 33 | set(gca,'DefaultTextFontName','Times','DefaultTextFontSize',18,... 34 | 'DefaultAxesFontName','Times','DefaultAxesFontSize',18,... 35 | 'DefaultLineLineWidth',1,'DefaultLineMarkerSize',7.75) 36 | xlim([2,50]) 37 | grid on; 38 | box off; 39 | 40 | -------------------------------------------------------------------------------- /HypersphereLib/hypersphereLayer.m: -------------------------------------------------------------------------------- 1 | classdef hypersphereLayer < nnet.layer.Layer 2 | % Example custom weighted addition layer. 3 | properties (Learnable) 4 | c; 5 | end 6 | properties 7 | c0; 8 | end 9 | 10 | methods 11 | function layer = hypersphereLayer(name,inputDim,R,lambda) 12 | % layer = weightedAdditionLayer(numInputs,name) creates a 13 | 14 | % Set number of inputs. 15 | %layer.NumInputs = inputDim; 16 | %layer.NumOutputs = numOutputs; 17 | % Set layer name. 18 | layer.Name = name; 19 | layer.c = randn(inputDim,1); 20 | layer.c0 = layer.c+2; 21 | % Set layer description. 
22 | layer.Description = "calculate the hypersphere loss"; 23 | end 24 | 25 | function Z = predict(layer, X) 26 | if ndims(X) >= 3 27 | batchSize = size(X,4); 28 | else 29 | batchSize = size(X,ndims(X)); 30 | end 31 | sumDist = sum((squeeze(X)-layer.c).^2,1); 32 | 33 | Z = sumDist'; 34 | % Z = reshape(layer.Weights*squeeze(X)+(layer.Biases ),... 35 | % [1,1,size(layer.Weights,1),batchSize]); 36 | 37 | end 38 | end 39 | end 40 | -------------------------------------------------------------------------------- /HypersphereLib/matlabPlotForIEEEtran.m: -------------------------------------------------------------------------------- 1 | %% 2 | set(0,'DefaultTextFontName','Times','DefaultTextFontSize',18,... 3 | 'DefaultAxesFontName','Times','DefaultAxesFontSize',18,... 4 | 'DefaultLineLineWidth',1,'DefaultLineMarkerSize',7.75) 5 | 6 | %% 7 | 8 | t = linspace(0,1,200); % Create data to plot 9 | y = sin(2*pi*t); 10 | tau = linspace(0,1,10); 11 | x = sin(2*pi*tau); 12 | 13 | % subplot(2,2,1) 14 | plot(t,y,tau,x,'ro') 15 | 16 | grid on; % optional 17 | 18 | % Optionally add some text, a label, and a title 19 | text(0.6,0.5,'sin(2\pi\itt\rm)') 20 | xlabel('\itt') 21 | title('Plotting a Function with MatLab') 22 | 23 | print('-depsc2','-r600','plotfile3.eps') % Print to file 24 | %% 25 | 26 | % Script to plot -t log t function 27 | f = @(t)-t.*log(t); 28 | t = linspace(0,1.5,200); 29 | y = f(t); 30 | tau = linspace(0,1.5,7); 31 | x = f(tau); 32 | 33 | subplot(2,2,1) 34 | plot(t,y,'k',tau,x,'ro'); grid on 35 | 36 | xlabels = strvcat('0', ' ', '1'); 37 | ylabels = strvcat('0', ' ' ); 38 | set(gca,'XTick',[0 1/exp(1) 1],'XTickLabel',xlabels,... 39 | 'YTick',[0 1/exp(1)],'YTickLabel',ylabels) 40 | text('Interpreter','latex','String','$- \kern.8em t \log t$',... 41 | 'Position',[1 1/exp(1)-.2 ]) 42 | text('Interpreter','latex','String','$\frac{\log e}{e}$',... 43 | 'Position',[-.3 1/exp(1)]) 44 | text('Interpreter','latex','String','$1/e$',... 45 | 'Position',[1/exp(1)-.09 -1.15]) 46 | 47 | print('-depsc2','-r600','plotfile4.eps') 48 | %% Histograms 49 | n = 50000; % Number of simulations 50 | X = rand(1,n); 51 | Y = rand(1,n)*2; 52 | Z = X+Y; 53 | nbins = 40; % Number of bins for histogram 54 | hstgrm = makedenshist(Z,nbins); 55 | plothist(hstgrm) 56 | % Now plot true density of Z 57 | z = linspace(0,3,200); 58 | f = @(z).5*(z.*(0<=z & z<1)+(1<=z & z<2)+(3-z).* ... 59 | (2<=z & z<=3)); 60 | hold on 61 | plot(z,f(z),'k'); grid on 62 | hold off 63 | print('-depsc2','-r600','plotfile5.eps') 64 | %% 65 | 66 | function [hstgrm,varargout] = makedenshist(Z,nbins) 67 | % Make a density histogram with nbins bins out of the data in Z. 68 | % We return the 2-by-nbins array hstgrm, where 69 | % hstgrm(1,:) = the list of bin centers, and 70 | % hstgrm(2,:) = normalized histogram heights. 71 | % 72 | % The command 73 | % 74 | % hstgrm = makedenshist(Z,nbins) 75 | % 76 | % always prints the minimum and maximum data samples, 77 | % denoted by minZ and maxZ. Alternatively, the command 78 | % 79 | % [hstgrm,minZ,maxZ] = makedenshist(Z,nbins) 80 | % 81 | % returns these values to you without printing them. 
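% EXAMPLE (editorial addition, hypothetical data):
%
%   Z = randn(1,1e4); % any sample vector
%   h = makedenshist(Z,40); % prints the data range
%   plothist(h) % bars are normalized so that they integrate to ~1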
82 | 83 | hstgrm = zeros(2,nbins); % Pre-allocate space 84 | 85 | minZ = min(Z); % Determine range of data 86 | maxZ = max(Z); 87 | if nargout==3 88 | varargout{1} = minZ; 89 | varargout{2} = maxZ; 90 | else 91 | fprintf('makedenshist: Data range = [ %g , %g ].\n',minZ,maxZ) 92 | end 93 | 94 | e = linspace(minZ,maxZ,nbins+1); % Set edges of bins 95 | 96 | a = e(1:nbins); % Compute centers of bins 97 | b = e(2:nbins+1); % and store result in 98 | hstgrm(1,:) = (a+b)/2; % hstgrm(1,:) 99 | 100 | H = histc(Z,e); % Get bin heights 101 | 102 | H(nbins) = H(nbins)+H(nbins+1); % Put any hits on right-most 103 | % edge into last bin 104 | 105 | % Compute and store the normalized bin heights 106 | bw = (maxZ-minZ)/nbins; 107 | hstgrm(2,:) = H(1:nbins)/(bw*length(Z)); 108 | end 109 | %% 110 | 111 | function plothist(hstgrm); 112 | % Plot a histogram generated by makedenshist. 113 | % Actually, as long as 114 | % hstgrm(1,:) = the list of bin centers, and 115 | % hstgrm(2,:) = normalized histogram heights, 116 | % plothist will work for you. 117 | 118 | bar(hstgrm(1,:),hstgrm(2,:),'hist') 119 | h = findobj(gca,'Type','patch'); 120 | set(h,'FaceColor','w','EdgeColor','k') 121 | end -------------------------------------------------------------------------------- /HypersphereLib/matlabPlotForIEEEtran.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcwhy/CSIL/8ce8637daf4dc60eeb1c56bff64c050c5b2353e9/HypersphereLib/matlabPlotForIEEEtran.pdf -------------------------------------------------------------------------------- /HypersphereLib/randsphere.m: -------------------------------------------------------------------------------- 1 | function X = randsphere(m,n,r) 2 | % Roger Stafford (2020). Random Points in an n-Dimensional Hypersphere (https://www.mathworks.com/matlabcentral/fileexchange/9443-random-points-in-an-n-dimensional-hypersphere), MATLAB Central File Exchange. Retrieved November 9, 2020. 3 | 4 | % This function returns an m by n array, X, in which 5 | % each of the m rows has the n Cartesian coordinates 6 | % of a random point uniformly-distributed over the 7 | % interior of an n-dimensional hypersphere with 8 | % radius r and center at the origin. The function 9 | % 'randn' is initially used to generate m sets of n 10 | % random variables with independent multivariate 11 | % normal distribution, with mean 0 and variance 1. 12 | % Then the incomplete gamma function, 'gammainc', 13 | % is used to map these points radially to fit in the 14 | % hypersphere of finite radius r with a uniform % spatial distribution. 15 | % Roger Stafford - 12/23/05 16 | 17 | X = randn(m,n); 18 | s2 = sum(X.^2,2); 19 | % X = X.*repmat(r*(gammainc(s2/2,n/2).^(1/n))./sqrt(s2),1,n); 20 | 21 | X = X.*repmat(r*(rand(m,1).^(1/n))./sqrt(s2),1,n); 22 | 23 | -------------------------------------------------------------------------------- /HypersphereLib/unitHypersphereSurfArea.m: -------------------------------------------------------------------------------- 1 | function a = unitHypersphereSurfArea(n) 2 | %https://en.wikipedia.org/wiki/Unit_sphere#:~:text=The%20surface%20area%20of%20an,dimensional%20ball%20of%20radius%20r. 
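% Editorial note: the formula below, A(n) = 2*pi^(n/2)/gamma(n/2), gives the
% surface area of the unit (n-1)-sphere embedded in R^n. Two quick checks:
% unitHypersphereSurfArea(2) = 2*pi (circumference of the unit circle) and
% unitHypersphereSurfArea(3) = 4*pi (surface area of the ordinary unit sphere).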
3 | %https://en.wikipedia.org/wiki/N-sphere#Volume_and_surface_area 4 | a = (2*pi^(n/2))/(gamma(n/2)); 5 | end 6 | 7 | -------------------------------------------------------------------------------- /Readme.md: -------------------------------------------------------------------------------- 1 | This is the public repository of our paper: Class-Incremental Learning for Wireless Device Identification in IoT, which is available [HERE](Class_incremental_learning_for_device_identification_in_IoT_IoT_16942_2021.pdf) and in the [IEEE Internet of Things Journal](https://ieeexplore.ieee.org/document/9425491). 2 | 3 | More importantly, we [mathematically proved](https://github.com/pcwhy/CSIL/blob/main/Formal%20Proof%20of%20Orthogonality.pdf) and verified the effect of orthogonal memory representations within artificial neural networks. 4 | 5 | We are delighted to see that recent advances in neuroscience also show [biological evidence](https://www.nature.com/articles/s41593-021-00821-9) of orthogonal memory representations, although we arrived at them via a totally different [storyline](Storyline.pdf). 6 | 7 | 8 | Our raw dataset is available at https://ieee-dataport.org/documents/ads-b-signals-records-non-cryptographic-identification-and-incremental-learning. 9 | It is a labeled raw ADS-B signal dataset, captured using a BladeRF2 SDR receiver at 1090 MHz with a sample rate of 10 MHz. 10 | 11 | Please go to [IEEE Dataport](https://ieee-dataport.org/documents/ads-b-signals-records-non-cryptographic-identification-and-incremental-learning) for the dataset (adsb_bladerf2_10M_qt0.mat) and the preprocessed data (adsb-107loaded.mat). 12 | 13 | Sample code for data preprocessing and incremental learning is in [ContinualLearning](https://github.com/pcwhy/CSIL/tree/main/ContinualLearning). 14 | 15 | The code that drove the discovery of Remark 1 is in the [numerical simulation folder](https://github.com/pcwhy/CSIL/tree/main/numericalSimOfDoC). 16 | 17 | Comparisons of various incremental learning algorithms are in [ContinualLearning/WorkStage](https://github.com/pcwhy/CSIL/tree/main/ContinualLearning/WorkStage). 18 | -------------------------------------------------------------------------------- /Storyline.pdf: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/pcwhy/CSIL/8ce8637daf4dc60eeb1c56bff64c050c5b2353e9/Storyline.pdf -------------------------------------------------------------------------------- /countmember.m: -------------------------------------------------------------------------------- 1 | function COUNT = countmember(A,B) 2 | % COUNTMEMBER - count members 3 | % 4 | % COUNT = COUNTMEMBER(A,B) counts the number of times the elements of array A are 5 | % present in array B, so that COUNT(k) equals the number of occurrences of 6 | % A(k) in B. A may contain non-unique elements. COUNT will have the same size as A. 7 | % A and B should be of the same type, and can be cell arrays of strings.
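% A minimal sketch of the idea behind the implementation below: UNIQUE
% indexes the distinct values of A, ISMEMBER maps each element of B onto
% that index, and HISTCOUNTS tallies the hits.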
8 | % 9 | % Examples: 10 | % countmember([1 2 1 3],[1 2 2 2 2]) 11 | % % -> 1 4 1 0 12 | % countmember({'a','b','c'}, {'a','x','a'}) 13 | % % -> 2 0 0 14 | % 15 | % Elements = {'H' 'He' 'Li' 'C' 'N' 'O'} ; 16 | % Sample = randsample(Elements, 1e6, true,[.5 .1 .2 .1 .05 .05]) ; 17 | % bar(countmember(Elements, Sample)) ; 18 | % set(gca, 'xticklabel', Elements) 19 | % xlabel('Chemical element'), ylabel('N') ; 20 | % 21 | % See also ISMEMBER, UNIQUE, HISTC 22 | % tested in R2018b 23 | % version 2.1 (jan 2019) 24 | % (c) Jos van der Geest 25 | % email: samelinoa@gmail.com 26 | % History: 27 | % 1.0 (2005) created 28 | % 1.1 (??): removed dum variable from [AU,dum,j] = unique(A(:)) to reduce 29 | % overhead 30 | % 1.2 (dec 2008) - added comments, fixed some spelling and grammar 31 | % mistakes, after being selected as Pick of the Week (dec 2008) 32 | % 2.0 (apr 2016) - updated for R2015a 33 | % 2.1 (jan 2019) - using histcounts instead of histc; minor grammar fixes 34 | % input checks 35 | narginchk(2,2) ; 36 | if ~isequal(class(A), class(B)) 37 | error('Both inputs should be of the same class.') ; 38 | end 39 | if isempty(A) || isempty(B) 40 | % nothing to do 41 | COUNT = zeros(size(A)) ; 42 | else 43 | % which elements are unique in A, 44 | % also store the position to re-order later on 45 | [AUnique, ~, posJ] = unique(A(:)) ; 46 | % assign each element in B a number corresponding to the element of A 47 | [~, loc] = ismember(B, AUnique) ; 48 | % count these numbers 49 | N = histcounts(loc(:), 1:length(AUnique)+1) ; 50 | % re-order according to A, and reshape 51 | COUNT = reshape(N(posJ),size(A)) ; 52 | end -------------------------------------------------------------------------------- /covmatrix.m: -------------------------------------------------------------------------------- 1 | function [C, m] = covmatrix(X) 2 | %COVMATRIX Computes the covariance matrix of a vector population. 3 | % [C, M] = COVMATRIX(X) computes the covariance matrix C and the 4 | % mean vector M of a vector population organized as the rows of 5 | % matrix X. C is of size N-by-N and M is of size N-by-1, where N is 6 | % the dimension of the vectors (the number of columns of X). 7 | 8 | % Copyright 2002-2004 R. C. Gonzalez, R. E. Woods, & S. L. Eddins 9 | % Digital Image Processing Using MATLAB, Prentice-Hall, 2004 10 | % $Revision: 1.4 $ $Date: 2003/05/19 12:09:06 $ 11 | 12 | [K, n] = size(X); 13 | X = double(X); 14 | if n == 1 % Handle special case. 15 | C = 0; 16 | m = X; 17 | else 18 | % Compute an unbiased estimate of m. 19 | m = sum(X, 1)/K; 20 | % Subtract the mean from each row of X. 21 | X = X - m(ones(K, 1), :); 22 | % Compute an unbiased estimate of C. Note that the product is 23 | % X'*X because the vectors are rows of X. 24 | C = (X'*X)/(K - 1); 25 | m = m'; % Convert to a column vector. 
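        % At this point C and m match MATLAB's cov(X) and mean(X,1)',
        % i.e. the unbiased 1/(K-1) sample covariance of the rows of X.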
26 | end -------------------------------------------------------------------------------- /extractNoise.m: -------------------------------------------------------------------------------- 1 | function [noiseVector,medoids] = extractNoise(inputSignal) 2 | mu = mean(inputSignal); 3 | sigma = std(inputSignal); 4 | 5 | % Use probabilistic (3-sigma) filtering to restore the rational signal 6 | maskPositive = inputSignal <= mu + 3*sigma; 7 | maskNegative = inputSignal >= mu - 3*sigma; 8 | mask = maskPositive.*maskNegative; 9 | cleanInput = inputSignal.*mask; 10 | 11 | % Estimate the generative model (a 3-round, two-means split of the samples into high and low levels) 12 | lowMean = min(cleanInput); 13 | highMean = max(cleanInput); 14 | 15 | 16 | for rounds = 1:3 17 | highLst = []; 18 | lowLst = []; 19 | for i = 1:length(cleanInput) 20 | if abs(cleanInput(i) - highMean) <= abs(cleanInput(i)-lowMean) 21 | highLst(end+1) = i; 22 | else 23 | lowLst(end+1) = i; 24 | end 25 | end 26 | highMean = mean(cleanInput(highLst)); 27 | lowMean = mean(cleanInput(lowLst)); 28 | end 29 | 30 | threshold = (highMean - lowMean)/2+lowMean; 31 | rationalSignal = double(cleanInput >= threshold); 32 | rationalSignal = rationalSignal.*(highMean - lowMean)+lowMean; 33 | noiseVector = (rationalSignal - cleanInput); 34 | medoids = rationalSignal; 35 | 36 | end 37 | 38 | -------------------------------------------------------------------------------- /makeDataTensor.m: -------------------------------------------------------------------------------- 1 | function [dataTensor,cvDataTensor,label,cvLabel] = makeDataTensor(selectedBasebandData,selectedRawCompData) 2 | 3 | 4 | selectedNoiseVector = zeros(size(selectedBasebandData)); 5 | selectedNoiseVector(:,1) = selectedBasebandData(:,1); 6 | medoidVector = zeros(size(selectedBasebandData)); 7 | for i = 1:size(selectedBasebandData,1) 8 | [selectedNoiseVector(i,2:size(selectedBasebandData,2)),medoids]...
9 | = extractNoise(selectedBasebandData(i,2:size(selectedBasebandData,2))); 10 | medoidVector(i,1:1024) = medoids; 11 | % plot(selectedNoiseVector(1,2:size(selectedBasebandData,2)),'.','LineWidth',1) 12 | % hold on; 13 | % plot(selectedBasebandData(i,2:size(selectedBasebandData,2)),'--','LineWidth',1) 14 | % plot(medoids,'LineWidth',1); 15 | % xlim([1,300]) 16 | % ylim([-0.1,0.3]) 17 | % set(gca,'FontSize',12) 18 | % %set(gcf,'position',[405 330 528 279]); 19 | % set(gcf,'position',[405 417 441 192]); 20 | % legend('Pseudo noise','Baseband signals','Medoids') 21 | % box off 22 | end 23 | cvNoise = selectedNoiseVector(ceil(0.7*size(selectedNoiseVector,1)):size(selectedNoiseVector,1),:); 24 | selectedNoiseVector = selectedNoiseVector(1:ceil(0.7*size(selectedNoiseVector,1))-1,:); 25 | 26 | 27 | % noiseSample = cvNoise(:,2:end); 28 | % corrvals = []; 29 | % noiseCorrvals = []; 30 | % for i = 1:size(noiseSample,1) 31 | % r = corrcoef(noiseSample(i,:),selectedBasebandData(i,2:end)); 32 | % nr = corrcoef(randn(size(selectedBasebandData(i,2:end))),selectedBasebandData(i,2:end)); 33 | % corrvals(end+1)=r(2,1); 34 | % noiseCorrvals(end+1) = nr(2,1); 35 | % end 36 | % histogram(corrvals,'Normalization','probability') 37 | % hold on; 38 | % histogram(noiseCorrvals,'Normalization','probability') 39 | % set(gca,'FontSize',12) 40 | % set(gcf,'position',[405 417 441 147]); 41 | % box off 42 | % xlabel('Correlation coefficients') 43 | % ylabel('Probability') 44 | % legend('Pseudo noise','Gaussian noise ~ N(0,1)'); 45 | 46 | 47 | selectedFFTVector = zeros(size(selectedBasebandData)); 48 | selectedFFTVector(:,1) = selectedBasebandData(:,1); 49 | for i = 1:size(selectedBasebandData,1) 50 | % selectedFFTVector(i,2:end) = fftshift(fft(selectedBasebandData(i,2:end))); 51 | actualFFT = fftshift(fft(selectedRawCompData(i,:),1024)); 52 | rationaleFFT = fftshift(fft(medoidVector(i,:),1024)); 53 | %selectedFFTVector(i,2:end) = fftshift(fft(selectedRawCompData(i,:),1024)); 54 | selectedFFTVector(i,2:end) = actualFFT-rationaleFFT; 55 | end 56 | cvFFT = selectedFFTVector(ceil(0.7*size(selectedFFTVector,1)):size(selectedFFTVector,1),:); 57 | selectedFFTVector = selectedFFTVector(1:ceil(0.7*size(selectedFFTVector,1))-1,:); 58 | selectedFFTmag = abs(selectedFFTVector); 59 | selectedFFTang = angle(selectedFFTVector); 60 | %selectedFFTang = unwrap(atan2(real(selectedFFTVector),imag(selectedFFTVector))); 61 | 62 | cvFFTmag = abs(cvFFT); 63 | cvFFTang = angle(cvFFT); 64 | %cvFFTang = unwrap(atan2(real(cvFFT),imag(cvFFT))); 65 | 66 | 67 | featureDims = 3; 68 | 69 | trainDataTensor = zeros(size(selectedNoiseVector,1)*featureDims,... 70 | size(selectedNoiseVector,2)-1); 71 | 72 | for i = 1:size(selectedNoiseVector,1) 73 | cursor = i*featureDims-(featureDims-1); 74 | trainDataTensor(cursor,:) = selectedNoiseVector(i,2:end); 75 | trainDataTensor(cursor+1,:) = real(selectedFFTVector(i,2:end)); 76 | trainDataTensor(cursor+2,:) = imag(selectedFFTVector(i,2:end)); 77 | end 78 | 79 | cvDataTensor = zeros(size(cvNoise,1)*featureDims, size(cvNoise,2)-1); 80 | for i = 1:size(cvNoise,1) 81 | cursor = i*featureDims-(featureDims-1); 82 | cvDataTensor(cursor,:) = cvNoise(i,2:end); 83 | cvDataTensor(cursor+1,:) = real(cvFFT(i,2:end)); 84 | cvDataTensor(cursor+2,:) = imag(cvFFT(i,2:end)); 85 | end 86 | 87 | 88 | label = selectedNoiseVector(:,1); 89 | cvLabel = cvNoise(:,1); 90 | 91 | dataTensor = reshape(trainDataTensor',[sqrt(1024), sqrt(1024),...
92 | featureDims, size(selectedNoiseVector,1)]); 93 | cvDataTensor = reshape(cvDataTensor',[sqrt(1024), sqrt(1024),... 94 | featureDims, size(cvLabel,1)]); 95 | 96 | % One way to restore the original signal is: 97 | % sig = X2(:,:,:,1); 98 | % sig2 = reshape(sig,[1,1024]); 99 | 100 | 101 | 102 | 103 | -------------------------------------------------------------------------------- /numericalSimOfDoC/Readme.md: -------------------------------------------------------------------------------- 1 | Here is the numerical simulation code that drove the discovery of the critical Remark 1 in our paper. 2 | -------------------------------------------------------------------------------- /numericalSimOfDoC/solveSpace.m: -------------------------------------------------------------------------------- 1 | clc 2 | close all; 3 | clear; 4 | % rng default 5 | 6 | x1 = optimvar('x1',3); 7 | x2 = optimvar('x2',3); 8 | x3 = optimvar('x3',3); 9 | 10 | prob = optimproblem; 11 | % prob.Constraints.cons1 = x1(1)^2 + x1(2)^2 + x1(3).^2 == 1; 12 | % prob.Constraints.cons2 = x2(1)^2 + x2(2)^2 + x2(3).^2 == 1; 13 | % prob.Constraints.cons3 = x3(1)^2 + x3(2)^2 + x3(3).^2 == 1; 14 | prob.Constraints.cons1 = sum(x1.^2,'all') == 1; 15 | prob.Constraints.cons2 = sum(x2.^2,'all') == 1; 16 | prob.Constraints.cons3 = sum(x3.^2,'all') == 1; 17 | 18 | prob.Objective = sum(x1.*x2 + x1.*x3 + x2.*x3,'all'); 19 | % prob.Objective = sum(x1.*x2.*x3,'all'); 20 | 21 | x01 = randn(3,1); 22 | x02 = randn(3,1); 23 | x03 = randn(3,1); 24 | x0.x1=x01./vecnorm(x01); 25 | x0.x2=x02./vecnorm(x02); 26 | x0.x3=x03./vecnorm(x03); 27 | sum(x0.x1.*x0.x2+x0.x1.*x0.x3+x0.x2.*x0.x3,'all') 28 | [sol,fval,exitflag] = solve(prob,x0); 29 | sum(sol.x1.*sol.x2+sol.x1.*sol.x3+sol.x2.*sol.x3,'all') 30 | 31 | 32 | 33 | 34 | 35 | 36 | 37 | -------------------------------------------------------------------------------- /numericalSimOfDoC/solveSpaceFunction.m: -------------------------------------------------------------------------------- 1 | clc 2 | close all; 3 | clear; 4 | % rng default 5 | 6 | numFPs = 3; 7 | numDims = 2; 8 | 9 | variableArray = {}; 10 | for i = 1:numFPs 11 | variableArray{end+1} = optimvar(strcat('x',string(i)),numDims); 12 | end 13 | 14 | prob = optimproblem; 15 | 16 | expr = optimexpr(1,1); 17 | count = 0; 18 | for i = 1:numFPs 19 | for j = 1:i-1 20 | expr = expr + variableArray{i}.*variableArray{j}; 21 | count = count + 1; 22 | end 23 | end 24 | expr = sum(expr); 25 | show(expr); 26 | % prob.Objective = sum(x1.*x2 + x1.*x3 + x2.*x3,'all'); 27 | prob.Objective = expr; 28 | 29 | consStruct = struct; 30 | consCount = 1; 31 | expr = optimexpr; 32 | for i = 1:(numFPs) 33 | expr = sum(variableArray{i}.^2,'all') == 1; 34 | consStruct = setfield(consStruct, strcat('cons',string(consCount)),expr); 35 | consCount = consCount + 1; 36 | end 37 | % for i = 1:numFPs 38 | % for j = 1:i-1 39 | % expr = variableArray{i}.*variableArray{j} <= 0.5; 40 | % consStruct = setfield(consStruct, strcat('cons',string(consCount)),expr); 41 | % consCount = consCount + 1; 42 | % end 43 | % end 44 | 45 | prob.Constraints = consStruct; 46 | 47 | 48 | x0 = struct; 49 | for i = 1:numFPs 50 | rv = randn(numDims,1); 51 | rv = rv./vecnorm(rv); 52 | x0 = setfield(x0, strcat('x',string(i)),rv); 53 | end 54 | % x01 = randn(3,1); 55 | % x02 = randn(3,1); 56 | % x03 = randn(3,1); 57 | % x0.x1=x01./vecnorm(x01); 58 | % x0.x2=x02./vecnorm(x02); 59 | % x0.x3=x03./vecnorm(x03); 60 | prob.Objective.evaluate(x0) 61 | [sol,fval,exitflag,output,lambda] = solve(prob,x0) 62 | 
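% A quick sanity check (a sketch, not part of the original script): for
% unit vectors, sum_{i<j} xi'*xj = (||x1+...+xN||^2 - numFPs)/2 >= -numFPs/2,
% so the objective is tight exactly when the solution vectors sum to zero.
sumVec = zeros(numDims,1);
for fp = 1:numFPs
    sumVec = sumVec + sol.(strcat('x',string(fp)));
end
fprintf('Norm of the solution-vector sum: %g (near 0 at a tight optimum)\n', norm(sumVec));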
prob.Objective.evaluate(sol) 63 | 64 | FPs = []; 65 | sol = struct2cell(sol); 66 | lambda.Constraints 67 | for i = 1:numFPs 68 | FPs(end+1,:) = sol{i}'; 69 | end 70 | % FPs 71 | if numDims == 2 72 | figure; 73 | plot(0,0,'o','LineWidth',2); 74 | hold on; 75 | plot(FPs(:,1),FPs(:,2),'*','LineWidth',2) 76 | else 77 | weights = FPs; 78 | similarMatrix = zeros(size(weights,1), size(weights,1)); 79 | for i = 1:size(weights,1) 80 | curWeight = weights(i,:); 81 | magCurWeight = sqrt(sum(curWeight.^2,2)); 82 | for j = 1:size(weights,1) 83 | nxtWeight = weights(j,:); 84 | magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 85 | similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 86 | .*(curWeight./magCurWeight)); 87 | end 88 | end 89 | imagesc(similarMatrix) 90 | end 91 | 92 | 93 | 94 | 95 | 96 | -------------------------------------------------------------------------------- /numericalSimOfDoC/solveSpaceFunction3.m: -------------------------------------------------------------------------------- 1 | clc 2 | close all; 3 | clear; 4 | %rng default 5 | 6 | numFPs = 4; 7 | numDims = 2; 8 | numLambdas = numFPs; 9 | 10 | variableArray = {}; 11 | for i = 1:numFPs 12 | variableArray{end+1} = optimvar(strcat('x',string(i)),numDims); 13 | end 14 | for i = 1:numLambdas 15 | variableArray{end+1} = optimvar(strcat('lbd',string(i)),1); 16 | end 17 | 18 | prob = eqnproblem; 19 | 20 | eqStruct = struct; 21 | eqCount = 1; 22 | expr = optimexpr; 23 | for i = 1:(numLambdas) 24 | expr = sum(variableArray{i}.^2,'all') == 1; 25 | eqStruct = setfield(eqStruct, strcat('eq',string(eqCount)),expr); 26 | eqCount = eqCount + 1; 27 | end 28 | expr = 0; 29 | for i = 1:numFPs 30 | for j = 1:numDims 31 | for k = 1:numFPs 32 | if k ~= i 33 | expr = expr + variableArray{k}(j); 34 | else 35 | expr = expr + 2.*variableArray{i+numFPs}(1).*variableArray{k}(j); 36 | end 37 | end 38 | expr = expr == 0; 39 | eqStruct = setfield(eqStruct, strcat('eq',string(eqCount)),expr); 40 | eqCount = eqCount + 1; 41 | expr = 0; 42 | end 43 | end 44 | prob.Equations = eqStruct; 45 | x0 = struct; 46 | for i = 1:numFPs 47 | rv = randn(numDims,1); 48 | rv = rv./vecnorm(rv); 49 | x0 = setfield(x0, strcat('x',string(i)),rv); 50 | end 51 | for i = 1:numFPs 52 | rv = 1; 53 | x0 = setfield(x0, strcat('lbd',string(i)),rv); 54 | end 55 | show(prob) 56 | % prob.Objective.evaluate(x0) 57 | [sol,fval,exitflag] = solve(prob,x0) 58 | % prob.Objective.evaluate(sol) 59 | 60 | FPs = []; 61 | sol = struct2cell(sol); 62 | for i = 1:numFPs 63 | FPs(end+1,:) = sol{numFPs+i}'; 64 | end 65 | FPs 66 | if numDims == 2 67 | figure; 68 | plot(0,0,'o','LineWidth',2); 69 | hold on; 70 | plot(FPs(:,1),FPs(:,2),'*','LineWidth',2) 71 | else 72 | weights = FPs; 73 | similarMatrix = zeros(size(weights,1), size(weights,1)); 74 | for i = 1:size(weights,1) 75 | curWeight = weights(i,:); 76 | magCurWeight = sqrt(sum(curWeight.^2,2)); 77 | for k = 1:size(weights,1) 78 | nxtWeight = weights(k,:); 79 | magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 80 | similarMatrix(i,k)=sum((nxtWeight./magNxtWeight)... 81 | .*(curWeight./magCurWeight)); 82 | end 83 | end 84 | imagesc(similarMatrix) 85 | end 86 | 87 | weights = FPs; 88 | similarMatrix = zeros(size(weights,1), size(weights,1)); 89 | for i = 1:size(weights,1) 90 | curWeight = weights(i,:); 91 | magCurWeight = sqrt(sum(curWeight.^2,2)); 92 | for j = 1:size(weights,1) 93 | nxtWeight = weights(j,:); 94 | magNxtWeight = sqrt(sum(nxtWeight.^2,2)); 95 | similarMatrix(i,j)=sum((nxtWeight./magNxtWeight)... 
96 | .*(curWeight./magCurWeight)); 97 | end 98 | end 99 | sumDeviationAngles = (sum(similarMatrix,'all') - sum(diag(similarMatrix)))./2 100 | 101 | 102 | 103 | 104 | 105 | -------------------------------------------------------------------------------- /plotConfMat.m: -------------------------------------------------------------------------------- 1 | function plotConfMat(varargin) 2 | %PLOTCONFMAT plots the confusion matrix with colorscale, absolute numbers 3 | % and precision normalized percentages 4 | % 5 | % usage: 6 | % PLOTCONFMAT(confmat) plots the confmat with integers 1 to n as class labels 7 | % PLOTCONFMAT(confmat, labels) plots the confmat with the specified labels 8 | % 9 | % Vahe Tshitoyan 10 | % 20/08/2017 11 | % 12 | % Arguments 13 | % confmat: a square confusion matrix 14 | % labels (optional): vector of class labels 15 | 16 | % number of arguments 17 | switch (nargin) 18 | case 0 19 | confmat = 1; 20 | labels = {'1'}; 21 | case 1 22 | confmat = varargin{1}; 23 | labels = 1:size(confmat, 1); 24 | otherwise 25 | confmat = varargin{1}; 26 | labels = varargin{2}; 27 | end 28 | 29 | confmat(isnan(confmat))=0; % in case there are NaN elements 30 | numlabels = size(confmat, 1); % number of labels 31 | 32 | % calculate the percentage accuracies 33 | confpercent = 100*confmat./repmat(sum(confmat, 1),numlabels,1); 34 | 35 | % plotting the colors 36 | imagesc(confpercent); 37 | title(sprintf('Accuracy: %.2f%%', 100*trace(confmat)/sum(confmat(:)))); 38 | ylabel('Output Class'); xlabel('Target Class'); 39 | 40 | % set the colormap 41 | colormap(flipud(gray)); 42 | 43 | % Create strings from the matrix values and remove spaces 44 | textStrings = num2str([confpercent(:), confmat(:)], '%.1f%%\n%d\n'); 45 | textStrings = strtrim(cellstr(textStrings)); 46 | 47 | % Create x and y coordinates for the strings and plot them 48 | [x,y] = meshgrid(1:numlabels); 49 | hStrings = text(x(:),y(:),textStrings(:), ... 50 | 'HorizontalAlignment','center'); 51 | 52 | % Get the middle value of the color range 53 | midValue = mean(get(gca,'CLim')); 54 | 55 | % Choose white or black for the text color of the strings so 56 | % they can be easily seen over the background color 57 | textColors = repmat(confpercent(:) > midValue,1,3); 58 | set(hStrings,{'Color'},num2cell(textColors,2)); 59 | 60 | % Setting the axis labels 61 | set(gca,'XTick',1:numlabels,... 62 | 'XTickLabel',labels,... 63 | 'YTick',1:numlabels,... 64 | 'YTickLabel',labels,... 65 | 'TickLength',[0 0]); 66 | end -------------------------------------------------------------------------------- /preluLayer.m: -------------------------------------------------------------------------------- 1 | classdef preluLayer < nnet.layer.Layer 2 | % Example custom PReLU layer. 3 | 4 | properties (Learnable) 5 | % Layer learnable parameters 6 | % Scaling coefficient 7 | Alpha 8 | end 9 | 10 | methods 11 | function layer = preluLayer(numChannels, name) 12 | % layer = preluLayer(numChannels, name) creates a PReLU layer 13 | % with numChannels channels and specifies the layer name. 14 | % Set layer name. 15 | layer.Name = name; 16 | % Set layer description. 17 | layer.Description = "PReLU with " + numChannels + " channels"; 18 | % Initialize scaling coefficient. 19 | layer.Alpha = rand([1 1 numChannels]); 20 | end 21 | 22 | function Z = predict(layer, X) 23 | % Z = predict(layer, X) forwards the input data X through the 24 | % layer and outputs the result Z. 
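            % PReLU: identity for positive activations, a learned
            % per-channel slope Alpha for negative ones, i.e.
            % Z = max(0,X) + Alpha.*min(0,X).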
25 | Z = max(0, X) + layer.Alpha .* min(0, X); 26 | end 27 | end 28 | end 29 | -------------------------------------------------------------------------------- /tensorVectorLayer.m: -------------------------------------------------------------------------------- 1 | classdef tensorVectorLayer < nnet.layer.Layer 2 | % Custom layer that flattens an input tensor into a vector. 3 | methods 4 | function layer = tensorVectorLayer(name) 5 | % layer = tensorVectorLayer(name) creates a layer that flattens its input and specifies the layer name. 6 | 7 | % Set number of inputs. 8 | %layer.NumInputs = inputDim; 9 | %layer.NumOutputs = numOutputs; 10 | % Set layer name. 11 | layer.Name = name; 12 | % Set layer description. 13 | layer.Description = "flatten any tensor into vector"; 14 | end 15 | 16 | function Z = predict(layer, X) 17 | if ndims(X) >= 3 18 | batchSize = size(X,4); 19 | else 20 | batchSize = size(X,ndims(X)); 21 | end 22 | sX = squeeze(X); 23 | flattenX = reshape(sX,[1,1,numel(X)./batchSize,batchSize]); 24 | % flattenX(1,1,:) = 1e3.*flattenX(1,1,:)./sqrt(sum(flattenX(1,1,:).^2,'all')); 25 | Z = stripdims(flattenX); 26 | % Z = reshape(layer.Weights*squeeze(X)+(layer.Biases ),... 27 | % [1,1,size(layer.Weights,1),batchSize]); 28 | 29 | % Z = reshape((layer.Weights./sqrt(sum((layer.Weights).^2,2)))*squeeze(X),... 30 | % [1,1,size(layer.Weights,1),batchSize]); 31 | 32 | end 33 | end 34 | end 35 | -------------------------------------------------------------------------------- /weightedAdditionLayer.m: -------------------------------------------------------------------------------- 1 | classdef weightedAdditionLayer < nnet.layer.Layer 2 | % Example custom weighted addition layer. 3 | 4 | properties (Learnable) 5 | % Layer learnable parameters 6 | 7 | % Scaling coefficients 8 | Weights 9 | end 10 | 11 | methods 12 | function layer = weightedAdditionLayer(numInputs,name) 13 | % layer = weightedAdditionLayer(numInputs,name) creates a 14 | % weighted addition layer and specifies the number of inputs 15 | % and the layer name. 16 | 17 | % Set number of inputs. 18 | layer.NumInputs = numInputs; 19 | 20 | % Set layer name. 21 | layer.Name = name; 22 | 23 | % Set layer description. 24 | layer.Description = "Weighted addition of " + numInputs + ... 25 | " inputs"; 26 | 27 | % Initialize layer weights. 28 | layer.Weights = rand(1,numInputs); 29 | end 30 | 31 | function Z = predict(layer, varargin) 32 | % Z = predict(layer, X1, ..., Xn) forwards the input data X1, 33 | % ..., Xn through the layer and outputs the result Z. 34 | 35 | X = varargin; 36 | W = layer.Weights; 37 | 38 | % Initialize output 39 | X1 = X{1}; 40 | sz = size(X1); 41 | Z = zeros(sz,'like',X1); 42 | 43 | % Weighted addition 44 | for i = 1:layer.NumInputs 45 | Z = Z + W(i)*X{i}; 46 | end 47 | end 48 | end 49 | end 50 | -------------------------------------------------------------------------------- /yxBatchNorm.m: -------------------------------------------------------------------------------- 1 | classdef yxBatchNorm < nnet.layer.Layer 2 | 3 | properties (Learnable) 4 | gama; 5 | beta; 6 | end 7 | methods 8 | function layer = yxBatchNorm(name, indim) 9 | % Set layer name. 10 | layer.Name = name; 11 | layer.gama = randn(indim,1); 12 | layer.beta = randn(indim,1); 13 | % Set layer description.
14 | layer.Description = "Yongxin's BatchNorm for compatibility"; 15 | end 16 | function Z = predict(layer, X) 17 | eps = 1e-9; 18 | batchMu = mean(X,2); 19 | batchVar = var(X,0,2); 20 | Z = layer.gama.*(X-batchMu)./(sqrt(batchVar+eps)) + layer.beta; 21 | end 22 | end 23 | end 24 | -------------------------------------------------------------------------------- /zeroBiasFCLayer.m: -------------------------------------------------------------------------------- 1 | classdef zeroBiasFCLayer < nnet.layer.Layer 2 | % Example custom weighted addition layer. 3 | 4 | properties (Learnable) 5 | % Layer learnable parameters 6 | Weights; 7 | end 8 | properties 9 | normMag; 10 | b1;%, Biases; 11 | end 12 | methods 13 | function layer = zeroBiasFCLayer(inputDim,outputDim,name,initialWeights) 14 | % layer = weightedAdditionLayer(numInputs,name) creates a 15 | 16 | % Set number of inputs. 17 | %layer.NumInputs = inputDim; 18 | %layer.NumOutputs = numOutputs; 19 | % Set layer name. 20 | layer.Name = name; 21 | % Set layer description. 22 | layer.Description = "zero-bias FC layer with " + inputDim + ... 23 | " inputs"; 24 | % Initialize layer weights. 25 | stdGlorot = sqrt(1/(inputDim + outputDim)); 26 | layer.Weights = dlarray(rand(outputDim,inputDim).*stdGlorot); 27 | %layer.Weights = rand(outputDim,inputDim); 28 | layer.normMag = 5; 29 | layer.b1 = 1e-9; % a real small number to maintain numerical stability. 30 | %layer.Biases = rand(outputDim,1); 31 | if numel(initialWeights) ~= 0 32 | layer.Weights = initialWeights; 33 | end 34 | end 35 | 36 | function Z = predict(layer, X) 37 | if ndims(X) >= 3 38 | batchSize = size(X,4); 39 | else 40 | batchSize = size(X,ndims(X)); 41 | end 42 | % size(X) 43 | % Z = reshape(layer.Weights*squeeze(X)+(layer.Biases ),... 44 | % [1,1,size(layer.Weights,1),batchSize]); 45 | 46 | % Z = reshape((layer.Weights./sqrt(sum((layer.Weights).^2,2)))*squeeze(X),... 47 | % [1,1,size(layer.Weights,1),batchSize]); 48 | %size(X) 49 | if ndims(X) >= 3 50 | Z = reshape(... 51 | ( layer.Weights./sqrt((layer.b1)^2 + sum((layer.Weights).^2,2)) )... 52 | *( layer.normMag*squeeze(X)./sqrt((layer.b1)^2 + sum(squeeze(X).^2,1))),... 53 | [1,1,size(layer.Weights,1),batchSize])+layer.normMag; 54 | % Z = reshape(... 55 | % ( layer.Weights./sqrt((layer.b1)^2 + sum((layer.Weights).^2,2)) )... 56 | % *( layer.normMag*squeeze(X)),... 57 | % [1,1,size(layer.Weights,1),batchSize])+layer.normMag; 58 | else 59 | Z = reshape(... 60 | ( layer.Weights./sqrt((layer.b1)^2 + sum((layer.Weights).^2,2)) )... 61 | *( layer.normMag*squeeze(X)./sqrt((layer.b1)^2 + sum(squeeze(X).^2,1))),... 62 | [size(layer.Weights,1),batchSize])+layer.normMag; 63 | % Z = reshape(... 64 | % ( layer.Weights./sqrt((layer.b1)^2 + sum((layer.Weights).^2,2)) )... 65 | % *( layer.normMag*squeeze(X)),... 66 | % [size(layer.Weights,1),batchSize])+layer.normMag; 67 | end 68 | end 69 | end 70 | end 71 | -------------------------------------------------------------------------------- /zeroBiasFCLayerEuc.m: -------------------------------------------------------------------------------- 1 | classdef zeroBiasFCLayerEuc < nnet.layer.Layer 2 | % Example custom weighted addition layer. 3 | 4 | properties (Learnable) 5 | % Layer learnable parameters 6 | Weights; 7 | % normMag = 1; 8 | end 9 | properties 10 | normMag; 11 | b1;%, Biases; 12 | end 13 | methods 14 | function layer = zeroBiasFCLayerEuc(inputDim,outputDim,name,initialWeights) 15 | % layer = weightedAdditionLayer(numInputs,name) creates a 16 | 17 | % Set number of inputs. 
18 | %layer.NumInputs = inputDim; 19 | %layer.NumOutputs = numOutputs; 20 | % Set layer name. 21 | layer.Name = name; 22 | % Set layer description. 23 | layer.Description = "zero-bias FC layer with " + inputDim + ... 24 | " inputs"; 25 | % Initialize layer weights. 26 | stdGlorot = sqrt(2/(inputDim + outputDim)); 27 | layer.Weights = dlarray(rand(outputDim,inputDim).*stdGlorot); 28 | %layer.Weights = rand(outputDim,inputDim); 29 | layer.normMag = 5; 30 | layer.b1 = 1e-9; % a very small number to maintain numerical stability. 31 | %layer.Biases = rand(outputDim,1); 32 | if numel(initialWeights) ~= 0 33 | layer.Weights = initialWeights; 34 | end 35 | end 36 | 37 | function Z = predict(layer, X) 38 | if ndims(X) >= 3 39 | batchSize = size(X,4); 40 | else 41 | batchSize = size(X,ndims(X)); 42 | end 43 | 44 | if ndims(X) >= 3 45 | % res = ( layer.Weights./sqrt((layer.b1)^2 + sum((layer.Weights).^2,2)) )... 46 | % *( layer.normMag*squeeze(X)./sqrt((layer.b1)^2 + sum(squeeze(X).^2,1))); 47 | % res = sqrt(-2*layer.Weights*squeeze(X) + sum((layer.Weights).^2,2) ... 48 | % + sum(squeeze(X).^2,1)); 49 | 50 | % sx = squeeze(X)./sqrt((layer.b1)^2 + sum(squeeze(X).^2,1)); 51 | sx = squeeze(X); 52 | res = sqrt(-2*layer.Weights*sx + sum((layer.Weights).^2,2) ... 53 | + sum(sx.^2,1)); % ||w_i - x_j|| for all pairs via one matrix product 54 | Z = reshape(res, [1,1,size(layer.Weights,1),batchSize]); 55 | 56 | else 57 | % res = sqrt(-2*layer.Weights*squeeze(X) + sum((layer.Weights).^2,2) ... 58 | % + sum(squeeze(X).^2,1)); 59 | 60 | % sx = squeeze(X)./sqrt((layer.b1)^2 + sum(squeeze(X).^2,1)); 61 | sx = squeeze(X); 62 | res = sqrt(-2*layer.Weights*sx + sum((layer.Weights).^2,2) ... 63 | + sum(sx.^2,1)); % ||w_i - x_j|| for all pairs via one matrix product 64 | Z = reshape(res, [size(layer.Weights,1),batchSize]); 65 | 66 | end 67 | end 68 | end 69 | end 70 | --------------------------------------------------------------------------------
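A minimal usage sketch for the two zero-bias layers above (an illustrative check, assuming MATLAB R2019b+ with the Deep Learning Toolbox; pdist2 additionally requires the Statistics and Machine Learning Toolbox):

    % Verify that zeroBiasFCLayerEuc's forward pass returns plain Euclidean
    % distances between each weight row and each input column.
    inputDim = 8; outputDim = 4; batchSize = 16;
    layer = zeroBiasFCLayerEuc(inputDim, outputDim, 'euc', []); % [] keeps random weights
    X = randn(inputDim, batchSize);               % exercises the 2-D branch of predict
    Z = layer.predict(X);                         % outputDim-by-batchSize distance matrix
    D = pdist2(extractdata(layer.Weights), X');   % reference distances
    max(abs(Z - D), [], 'all')                    % expected to be ~1e-12

The same pattern applies to zeroBiasFCLayer, whose output is instead a cosine similarity scaled by normMag and shifted by normMag.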