├── DictionaryLearning.m
├── EdgeDetectionUsingOvercompleteDictionary.m
├── FaceRecognitionDKSVD.m
├── ImageDenoisingUsingOvercompleteDictionary.m
├── ImageInpaintingUsingOvercompleteDictionary.m
├── LinearClassifierUsingKSVD.m
├── README.md
├── ReconstructiveDiscrimination.m
├── overcompleteDCTdictionary.m
└── utilities
    ├── extractImagePatches.m
    ├── hardThreshold.m
    ├── initDictionaryFromPatches.m
    ├── learnDictionary.m
    ├── normalizeColumns.m
    ├── softThreshold.m
    ├── sparseCode.m
    ├── strictThreshold.m
    ├── substractMeanCols.m
    ├── updateDictionary.m
    └── visualizeDictionary.m

/DictionaryLearning.m:
--------------------------------------------------------------------------------
1 | %% DICTIONARY LEARNING
2 | clearvars
3 | close all
4 | clc
5 | % Demo: learn an overcomplete dictionary from random patches of barb.png
6 | %% INITIALIZATION
7 | addpath('utilities')
8 | addpath('data')
9 |
10 | % size of extracted square (w*w) patches
11 | blockSize = 16;
12 |
13 | % number of image patches in set Y
14 | N = 1000;
15 |
16 | % length of signal y (vectorized image patch)
17 | n = blockSize^2;
18 |
19 | % desired sparsity (number of non-zero elements in sparse representation vector)
20 | T0 = 20;
21 |
22 | % number of atoms in dictionary D
23 | K = 324;
24 |
25 | % load image for patch extraction (path relative to the repository root)
26 | imagePath = fullfile('data', 'barb.png');
27 | image = im2double(imresize(imread(imagePath), 1));
28 |
29 | % add additive Gaussian noise (disabled)
30 | % sigma = 0.1;
31 | % image = image + sigma*randn(size(image));
32 |
33 | [imH, imW] = size(image);
34 |
35 | %% EXTRACT IMAGE PATCHES & INITIALIZE DICTIONARY D0 & PLOT DICTIONARY
36 |
37 | [~, Y] = extractImagePatches(image, blockSize, 'rand', 'nPatches', N);
38 | % NOTE: 'rand' mode drops duplicate patches, so size(Y, 2) may be smaller than N
39 | % Y = kron(dctmtx(16),dctmtx(16))*Y;
40 |
41 | Y = Y - repmat(mean(Y, 1), [blockSize^2,1]);
42 |
43 | D0 = initDictionaryFromPatches(n, K, Y);
44 |
45 | % ALTERNATIVE: generate overcomplete DCT dictionary
46 | % D0 = overcompleteDCTdictionary(n, K);
47 |
48 | visualizeDictionary(D0);
49 | title('Initial Dictionary')
50 |
51 | %% CALCULATE COEFFICIENTS X
52 | D = D0;
53 | X = zeros(size(D, 2), size(Y, 2));
54 |
55 | X = sparseCode(Y, D, T0, 20, 'Plot', 0);
56 |
57 | %% UPDATE DICTIONARY D
58 | D = D0;
59 |
60 | D = updateDictionary(Y, X, D, 'ksvd', 'nIter', 15, 'Plot', 0, 'Verbose', 1);
61 |
62 |
63 | %% DICTIONARY LEARNING
64 | % perform dictionary learning by iteratively repeating coefficient
65 | % calculation and dictionary update steps
66 |
67 | niter_learn = 20;
68 | niter_coeff = 10;
69 | niter_dict = 10;
70 |
71 | D = D0;
72 | X = zeros(size(D, 2), size(Y, 2));
73 | E0 = [];
74 |
75 | sigma = 0.1;
76 | lambda = 1.5 * sigma;
77 |
78 |
79 | [D, X, E0] = learnDictionary(Y, D, T0, 'Plot', 1);
80 |
81 |
82 | %% DICTIONARY LEARNING - SPAMS
83 |
84 | param.lambda = 0.1;
85 | param.numThreads = -1; % number of threads
86 | param.iter = 50; % number of training iterations
87 | param.mode = 5;
88 | param.D = D0;
89 |
90 | D = mexTrainDL(Y, param);
91 |
92 | %%
93 | figure
94 | visualizeDictionary(D)
95 | title('Final Dictionary')
96 |
97 | %%
98 |
99 | [x_rec, r, coeff, iopt] = wmpalg('OMP', Y(:,1), D0, 'itermax', 100, 'maxerr', {'L2', 20});
100 |
101 | iopt
102 |
103 |
104 |
105 |
--------------------------------------------------------------------------------
/EdgeDetectionUsingOvercompleteDictionary.m:
--------------------------------------------------------------------------------
1 | clearvars
2 | close all
3 | clc
4 | % Edge detection with PCA filters computed from mean-subtracted 7x7 image blocks
5 | image = im2double(imread('lena.png'));
6 | % image = imgaussfilt(image, 1.2);
7 |
8 | % image = rgb2gray(image);
9 |
10 | % image=imresize(imbinarize(checkerboard(64)),8);
11 |
12 | bb = 7;
13 |
14 | imageBlocks = im2col(image, [bb, bb], 'sliding');
15 |
16 | averagedImageBlocks = sum(imageBlocks, 1)./(bb.^2);
17 |
18 | imageBlocksSubtracted = imageBlocks - repmat(averagedImageBlocks, (bb.^2), 1);
19 | % eigenvectors of the block covariance are used as convolution masks below
20 | [V, D, W] = eig(cov(imageBlocksSubtracted'));
21 |
22 | for i=1:size(V,2)
23 |     mask(:,:,i) = reshape(W(:,i), [bb, bb]);
24 |
25 |     filteredImages(:,:,i) = conv2(image, mask(:,:,i), 'same');
26 |
27 |     figure(1)
28 |     imshow(filteredImages(:,:,i))
29 |     drawnow
30 |     % waitforbuttonpress
31 |
32 |
33 | end
34 |
35 | edgeMap = max((filteredImages(:,:, 1:end)), [], 3);
36 | % edgeMap = mean2((filteredImages(:,:, 1:end)), 3);
37 |
38 | % se = strel('disk', 2);
39 | % edgeMap = imerode(edgeMap, se);
40 |
41 | [Ix, Iy] = gradient(filteredImages(:,:,i)); % NOTE(review): uses only the last filter response (i == size(V,2)); confirm intended
42 |
43 | I = hypot(Ix, Iy);
44 |
45 | % edgeMap=I;
46 |
47 | orient = atan2(Ix, Iy);
48 | orient(orient < 0) = orient(orient < 0) + pi;
49 |
50 | % orient = smoothorient(orient, 2);
51 |
52 | orient = round(rad2deg(orient)); % round instead of int8: int8 saturates at 127 and would clip the 0-180 degree range
53 |
54 | radius = 1.5;
55 | % nonmaxsup/hysthresh are not MATLAB built-ins (presumably P. Kovesi's functions) - they must be on the path
56 | edgeMap = nonmaxsup(edgeMap, orient, radius);
57 |
58 | threshold = 5*mean2(edgeMap);
59 |
60 | edgeMap = hysthresh(edgeMap, threshold, 0.95*threshold);
61 |
62 | % edgeMap = imbinarize(edgeMap, mean(edgeMap(:)));
63 |
64 | % edgeMap= bwmorph(edgeMap, 'close', 100);
65 |
66 | edgeMapMatlab = edge(image, 'Canny');
67 |
68 | figure, colormap gray
69 | subplot(121), imshow(edgeMap)
70 | subplot(122), imshow(edgeMapMatlab)
71 | title('Detected Edges')
72 |
73 |
74 |
--------------------------------------------------------------------------------
/FaceRecognitionDKSVD.m:
--------------------------------------------------------------------------------
1 | % FACE RECOGNITION USING D-KSVD
2 | close all
3 | clearvars
4 | clc
5 |
6 | %% LOAD IMAGES
7 | % 10 persons in set and 20 images per person
8 | % 10 images used for training and 10 images used for testing
9 |
10 | addpath('utilities')
11 | addpath('data')
12 |
13 | for personID = 1:10
14 |     for imageID = 1:20
15 |         fileName = fullfile('data', 'face recognition', 'malestaff', sprintf('%d', personID), sprintf('%02d.jpg', imageID));
16 |         image = rgb2gray(imread(fileName));
17 |         image = im2double(image);
18 |         images(:, :, imageID, personID) = image;
19 |
20 |         % figure(1)
21 |         % imagesc(images(:, :, imageID, personID))
22 |         % drawnow
23 |     end
24 | end
25 |
26 | nTraining = 10;
27 | nTesting = 10;
28 |
29 | images_training = images(:, :, 1:nTraining, :);
30 | images_testing = images(:, :, nTraining+1:nTraining+nTesting, :);
31 |
32 | images_training = reshape(images_training, size(images_training, 1)*size(images_training, 2), size(images_training, 3)* size(images_training, 4));
33 | images_testing = reshape(images_testing, size(images_testing, 1)*size(images_testing, 2), size(images_testing, 3)* size(images_testing, 4));
34 |
35 | images_training = substractMeanCols(images_training);
36 | images_testing = substractMeanCols(images_testing);
37 |
38 |
39 | % for personID = 1:10
40 | % for imageID = 1:10
41 | % personID, imageID
42 | % figure(1)
43 | % subplot(121)
44 | % imagesc(reshape(images_training(:,10*(personID-1)+1+imageID-1), 200,180))
45 | % subplot(122)
46 | % imagesc(reshape(images_testing(:,10*(personID-1)+1+imageID-1), 200,180))
47 | % waitforbuttonpress
48 | % end
49 | % end
50 |
51 | % randomfaces projection dimension
52 | % random faces is linear projection generated by Gaussian random mask
53 | n = 756;
54 |
55 | R = randn(n, size(images, 1)*size(images,2));
56 | R = normalizeColumns(R')';
57 |
58 | Y0 = R*images_training;
59 |
60 | K = 100;
61 |
62 | D0 = initDictionaryFromPatches(n, K, Y0);
63 |
64 | % X = OMP(D0,images_testing(:,2), T0)
65 |
66 | %%
67 |
68 | niter_learn = 20;
69 | niter_coeff = 30;
70 | niter_dict = 10;
71 | T0 = 10;
72 |
73 | D_cat = [];
74 |
75 | % initialize class label matrix
76 | H = kron(diag(ones(10,1)), ones(size(Y0,2)/10, 1))';
77 |
78 | param.K = K/10;
79 | param.numIteration=niter_learn;
80 | param.InitializationMethod='DataElements';
81 | param.preserveDCAtom=0;
82 | param.L = T0;
83 |
84 | for i = 1:10
85 |     % training columns of person i are contiguous (nTraining per person)
86 |     Y_part = Y0(:,(i-1)*nTraining+1:i*nTraining);
87 |     % D_part = D0(:,(i-1)*10+1:i*10);
88 |
89 |     D_part = initDictionaryFromPatches(n, K/10, Y_part);
90 |
91 |     % [D_part, out] = KSVD(Y_part, param);
92 |     [D_part, X] = learnDictionary(Y_part, D_part, 5, 'nIterLearn', 5);
93 |
94 |     D_cat = [D_cat, D_part];
95 | end
96 |
97 |
98 | %%
99 |
100 | X = OMP(D_cat,Y0,T0);
101 | % X = sparseCode(Y0, D_cat, T0, 10);
102 |
103 | % initialize linear classifier W
104 | W0 = (H*X')/(X*X'+eye(size(X*X')));
105 |
106 | % factor that controls reconstructive/discriminative dictionary properties
107 | gamma = 1000;
108 | % W0 was fitted to the codes of D_cat, so D_cat (not D0) is stacked with it
109 | D = [D_cat; sqrt(gamma)*W0];
110 | Y = [Y0; sqrt(gamma)*H];
111 |
112 | param.K = K;
113 | param.numIteration = 10;
114 | param.InitializationMethod = 'GivenMatrix';
115 | param.initialDictionary = D;
116 | param.preserveDCAtom = 0;
117 | param.L = T0;
118 |
119 | [D, out] = KSVD(Y, param);
120 |
121 | D_final = (D(1:n,:))./(sqrt(sum(abs(D(1:n,:).^2),1)));
122 | W_final = (D(n+1:end,:))./(sqrt(sum(abs(D(1:n,:).^2),1)));
123 |
124 | %%
125 | Y0 = R*images_testing;
126 | %
127 | % Y0 = R*images_training;
128 |
129 | X = zeros(size(D, 2), size(Y0, 2));
130 |
131 | X = OMP(D_final, Y0, 50);
132 |
133 | classCode = (W_final*X);
134 |
135 | [c, i] = max(abs(classCode))
136 |
137 | figure
138 | imagesc(i)
139 |
--------------------------------------------------------------------------------
/ImageDenoisingUsingOvercompleteDictionary.m:
--------------------------------------------------------------------------------
1 | %% IMAGE DENOISING USING OVERCOMPLETE DICTIONARY
2 | clearvars
3 | close all
4 | clc
5 |
6 | %% INITIALIZATION
7 | addpath('utilities')
8 | addpath('data')
9 |
10 | % size of extracted square (w*w) patches
11 | blockSize = 16;
12 |
13 | % number of image patches in set Y
14 | N = 5000;
15 |
16 | % length of signal y (vectorized image patch)
17 | n = blockSize^2;
18 |
19 | % desired sparsity (number of non-zero elements in sparse representation vector)
20 | T0 = 20;
21 |
22 | % number of atoms in dictionary D
23 | K = 300;
24 |
25 | % load image for patch extraction (path relative to the repository root)
26 | imagePath = fullfile('data', 'barb.png');
27 | image = im2double(imresize(imread(imagePath), 1));
28 |
29 | % add additive Gaussian noise
30 | sigma = 0.1;
31 | image = image + sigma*randn(size(image));
32 |
33 | [imH, imW] = size(image);
34 |
35 | %% EXTRACT IMAGE PATCHES & INITIALIZE DICTIONARY D0 & PLOT DICTIONARY
36 |
37 | [~, Y] = extractImagePatches(image, blockSize, 'rand', 'nPatches', N);
38 | % [~, Y] = extractImagePatches(image, blockSize, 'seq', 'Overlap', 0);
39 |
40 | Y = Y - repmat(mean(Y, 1), [blockSize^2,1]);
41 |
42 | D0 = initDictionaryFromPatches(n, K, Y);
43 |
44 | % ALTERNATIVE: generate overcomplete DCT dictionary
45 | % D0 = overcompleteDCTdictionary(n, K);
46 |
47 | visualizeDictionary(D0);
48 | title('Initial Dictionary')
49 | %% DICTIONARY LEARNING
50 | % perform dictionary learning by iteratively repeating coefficient
51 | % calculation and dictionary update steps
52 |
53 | niter_learn = 20;
54 | niter_coeff = 10;
55 | niter_dict = 10;
56 |
57 | D = D0;
58 | X = zeros(size(D, 2), size(Y, 2));
59 | E0 = [];
60 |
61 | sigma = 0.1;
62 | lambda = 1.5 * sigma;
63 |
64 |
65 | for iter = 1:niter_learn
66 |     fprintf('Dictionary Learning Iteration No. %d\n', iter);
67 |
68 |     %%%%%%%%%%%%%%%% coefficient calculation %%%%%%%%%%%%%%%%%%%%%%%
69 |     X = sparseCode(Y, D, T0, niter_coeff, 'StepSize', 20000, 'Verbose', 1);
70 |
71 |     E0(end+1) = norm(Y-D*X, 'fro')^2;
72 |
73 |     %%%%%%%%%%%%%%%% dictionary update %%%%%%%%%%%%%%%%%%%%%%%%%%%
74 |     [D, X] = updateDictionary(Y, X, D, 'aksvd', 'nIter', niter_dict, 'Verbose', 1);
75 |
76 |     E0(end+1) = norm(Y-D*X, 'fro')^2;
77 | end
78 |
79 |
80 | figure,
81 | hold on
82 | plot(1:2*niter_learn, E0);
83 | plot(1:2:2*niter_learn, E0(1:2:2*niter_learn), '*');
84 | plot(2:2:2*niter_learn, E0(2:2:2*niter_learn), 'o');
85 | axis tight;
86 | legend('|Y-DX|^2', 'After coefficient update', 'After dictionary update');
87 |
88 | %% DICTIONARY LEARNING - SPAMS
89 |
90 | % param.lambda = 0.1;
91 | % param.numThreads = -1; % number of threads
92 | % param.iter = 50; % number of training iterations
93 | % param.mode = 5;
94 | % param.D = D0;
95 | %
96 | % D = mexTrainDL(Y, param);
97 |
98 | %%
99 | figure
100 | visualizeDictionary(D)
101 | title('Final Dictionary')
102 |
103 | %% IMAGE DENOISING
104 |
105 | [~, Y, Xp, Yp] = extractImagePatches(image, blockSize, 'seq', 'Overlap', blockSize-1);
106 | meanY = mean(Y, 1);
107 |
108 | Y = Y - repmat(mean(Y, 1), [blockSize^2,1]);
109 |
110 | X = zeros(size(D, 2), size(Y, 2));
111 | X = sparseCode(Y, D, 5, 10, 'StepSize', 10000, 'Plot', 0, 'Verbose', 1);
112 |
113 |
114 | PA = reshape((D*X), [blockSize blockSize size(Y, 2)]);
115 | PA = PA - repmat( mean(mean(PA)), [blockSize blockSize] );
116 | PA = PA + reshape(repmat( meanY, [blockSize^2 1] ), [blockSize blockSize size(Y, 2)]);
117 |
118 | W = zeros(imH, imW);
119 | denoisedImage = zeros(imH, imW);
120 |
121 | for i=1:size(Y, 2)
122 |     x = Xp(:,:,i);
123 |     y = Yp(:,:,i);
124 |
125 |     denoisedImage(x+(y-1)*imH) = denoisedImage(x+(y-1)*imH) + PA(:,:,i);
126 |     W(x+(y-1)*imH) = W(x+(y-1)*imH) + 1;
127 | end
128 |
129 | denoisedImage = denoisedImage ./ W;
130 |
131 |
132 | figure,
133 | subplot(121), imagesc(image), title('Noisy image'), axis image
134 | subplot(122), imagesc(denoisedImage), title('Denoised image'), axis image
135 |
136 |
137 |
138 |
139 |
--------------------------------------------------------------------------------
/ImageInpaintingUsingOvercompleteDictionary.m:
--------------------------------------------------------------------------------
1 | % IMAGE INPAINTING USING OVERCOMPLETE DICTIONARY
2 | close all
3 | clearvars
4 | clc
5 | %%
6 |
7 |
8 | % load image for patch extraction
9 | image = im2double(imresize(imread('lena.png'), 0.5));
10 | [imH, imW] = size(image);
11 |
12 | % im = im + 0.1*randn(size(im));
13 |
14 | % generate random binary mask
15 | mask = abs(randn(size(image)))>0.5;
16 |
17 | % masking out image pixels
18 | image = mask .* image;
19 |
20 | % size of extracted patch
21 | w = 16;
22 |
23 | % number of image patches in set Y (recomputed below from the patch grid)
24 | N = 5000;
25 |
26 | % length of signal y
27 | n = w^2;
28 |
29 | % desired sparsity
30 | T0 = 10;
31 |
32 | % number of atoms in dictionary
33 | K = 128;
34 |
35 | % overlap
36 | q = 1;
37 | % patch origins: x runs over rows (up to imH), y over columns (up to imW)
38 | [y, x] = meshgrid(1:q:imW-w/2, 1:q:imH-w/2);
39 | [dY,dX] = meshgrid(0:w-1,0:w-1);
40 |
41 | N = size(x(:),1);
42 |
43 | Xp = repmat(dX,[1 1 N]) + repmat( reshape(x(:),[1 1 N]), [w w 1]);
44 | Yp = repmat(dY,[1 1 N]) + repmat( reshape(y(:),[1 1 N]), [w w 1]);
45 |
46 | Xp(Xp>imH) = 2*imH-Xp(Xp>imH);
47 | Yp(Yp>imW) = 2*imW-Yp(Yp>imW);
48 |
49 | Y = image(Xp+(Yp-1)*imH);
50 | Y = reshape(Y, [n, N]);
51 |
52 | M = mask(Xp+(Yp-1)*imH);
53 | M = reshape(M, [n, N]);
54 |
55 | [Y, meanY] = substractMeanCols(Y);
56 |
57 |
58 | select = @(A,k)repmat(A(k,:), [size(A,1) 1]);
59 | hardThresh = @(X,k)X .* (abs(X) >= select(sort(abs(X), 'descend'),k));
60 | softThresh = @(X,th)sign(X).*max(abs(X)-th,0);
61 |
62 | %%
63 |
64 | niter_coeff = 1;
65 | D = overcompleteDCTdictionary(n, 300);
66 |
67 |
68 | X = zeros(size(D,2),size(Y,2));
69 | E0 = [];
70 |
71 | sigma = 0.000001;
72 | lambda = 1.5 * sigma;
73 |
74 | tau = 1.9/norm(D*D');
75 | E = [];
76 | th=tau*lambda;
77 |
78 | step=1000;
79 |
80 |
81 |
82 | for jj = 1:step:size(Y,2)
83 |     jj
84 |
85 |     jumpSize=min(jj+step-1,size(Y,2));
86 |     X_tmp = zeros(size(D,2),1);
87 |
88 |     for i = 1:niter_coeff
89 |
90 |         i
91 |
92 |         for kk = jj:jumpSize
93 |
94 |             R = M(:,kk).*D*X_tmp-Y(:,kk);
95 |
96 |             % X_tmp = hardThresh(X_tmp-tau*(M(:,kk).*D)'*R, T0);
97 |             % X_tmp = OMP(M(:,kk).*D, Y(:,kk), T0);
98 |
99 |             % param.lambda = th(1);
100 |             param.L = 50;
101 |             % X_tmp = mexLasso(Y(:,kk), M(:,kk).*D, param);
102 |             % X_tmp = mexOMP(Y(:,kk), M(:,kk).*D, param);
103 |
104 |             th = tau*lambda;
105 |             X_tmp = softThresh(X_tmp-tau*(M(:,kk).*D)'*R, th');
106 |
107 |
108 |             X(:,kk)=X_tmp;
109 |         end
110 |     end
111 | end
112 |
113 |
114 | %%
115 |
116 | PA = reshape(D*X, [w w N]);
117 | PA = PA - repmat( mean(mean(PA)), [w w]);
118 | PA = PA + reshape(repmat( meanY, [w^2 1] ), [w w N]);
119 |
120 | W = zeros(imH,imW);
121 | M1 = zeros(imH, imW);
122 |
123 |
124 | for i=1:N
125 |     x = Xp(:,:,i);
126 |     y = Yp(:,:,i);
127 |     % column-major linear index into the imH-by-imW image: row + (col-1)*imH
128 |     M1(x+(y-1)*imH) = M1(x+(y-1)*imH) + PA(:,:,i);
129 |     W(x+(y-1)*imH) = W(x+(y-1)*imH) + 1;
130 | end
131 |
132 | M1 = M1 ./ W;
133 |
134 | figure
135 | imagesc(M1)
--------------------------------------------------------------------------------
/LinearClassifierUsingKSVD.m:
--------------------------------------------------------------------------------
1 | %% LEARNING A LINEAR CLASSIFIER WITH KSVD
2 | close all
3 | clearvars
4 | clc
5 |
6 | %% INITIALIZATION
7 |
8 | addpath('utilities')
9 | addpath('data')
10 |
11 | % size of extracted square (w*w) patch
12 | blockSize = 32;
13 |
14 | % length of signal y (vectorized image patch)
15 | n = blockSize^2;
16 |
17 | Y_cat = [];
18 | D_cat = [];
19 |
20 |
21 | %%
22 | % desired sparsity (number of non-zero elements in sparse representation vector)
23 | T0 = 5;
24 |
25 | % number of atoms in dictionary D
26 | K = 64;
27 |
28 | % load image for patch extraction (path relative to the repository root)
29 | imagePath = fullfile('data', 'textures', '1.1.12.tiff');
30 | % imagePath = fullfile('data', 'barb.png');
31 |
32 | image = im2double(imresize(imread(imagePath), 0.5));
33 |
34 | % add additive Gaussian noise (disabled)
35 | % sigma = 0.1;
36 | % im = im + sigma*randn(size(im));
37 |
38 | [imH, imW] = size(image);
39 |
40 | [~, Y, Xp, Yp] = extractImagePatches(image, blockSize, 'rand', 'nPatches', 1000);
41 |
42 | Y = Y - repmat(mean(Y), [n,1]);
43 |
44 |
45 | D0 = initDictionaryFromPatches(n, K, Y);
46 |
47 | [~, Y, Xp, Yp] = extractImagePatches(image, blockSize, 'seq', 'Overlap', 0);
48 |
49 | meanY = mean(Y);
50 | Y = Y - repmat(mean(Y), [n,1]);
51 |
52 | % DICTIONARY LEARNING
53 | % perform dictionary learning by iteratively repeating coefficient
54 | % calculation and dictionary update steps
55 | niter_learn = 3;
56 | niter_coeff = 10;
57 | niter_dict = 10;
58 |
59 | D = D0;
60 | X = zeros(size(D, 2), size(Y, 2));
61 |
62 | for iter = 1:niter_learn
63 |     fprintf('Dictionary Learning Iteration No. %d\n', iter);
64 |
65 |     %%%%%%%%%%%%%%%% coefficient calculation %%%%%%%%%%%%%%%%%%%%%%%
66 |     X = sparseCode(Y, D, T0, niter_coeff, 'Verbose', 0, 'StepSize', 10000);
67 |
68 |
69 |     %%%%%%%%%%%%%%%% dictionary update %%%%%%%%%%%%%%%%%%%%%%%%%%%
70 |     [D, X] = updateDictionary(Y, X, D, 'ksvd', 'nIter', niter_dict, 'Verbose', 0);
71 |
72 | end
73 |
74 | D_cat = [D_cat, D];
75 | Y_cat = [Y_cat, Y];
76 |
77 | % figure
78 | % subplot(121), visualizeDictionary(D0), title('Initial Dictionary')
79 | % subplot(122), visualizeDictionary(D) , title('Trained Dictionary')
80 | %
81 |
82 | %%
83 |
84 | figure
85 | imagesc(X)
86 | % NOTE(review): H assumes two classes - presumably the texture cell above is run once per texture so Y_cat holds both
87 | H = kron(diag(ones(2,1)), ones(size(Y,2), 1))';
88 |
89 | Y0 = Y_cat;
90 | D0 = D_cat;
91 |
92 | X0 = zeros(size(D0, 2), size(Y0, 2));
93 |
94 | X0 = sparseCode(Y0, D0, 20, 10);
95 |
96 | W0 = (H*X0')/(X0*X0'+eye(size(X0*X0')));
97 |
98 | %%
99 | gamma = 0.8;
100 |
101 | D = [D0; sqrt(gamma)*W0];
102 | Y = [Y0; sqrt(gamma)*H];
103 | X = zeros(size(D,2), size(Y,2));
104 |
105 | niter_learn = 20;
106 | niter_coeff = 5;
107 | niter_dict = 5;
108 |
109 | for iter = 1:niter_learn
110 |     fprintf('Dictionary Learning Iteration No. %d\n', iter);
111 |
112 |     %%%%%%%%%%%%%%%% coefficient calculation %%%%%%%%%%%%%%%%%%%%%%%
113 |     X = sparseCode(Y, D, T0, niter_coeff, 'Verbose', 0, 'StepSize', 10000);
114 |
115 |
116 |     %%%%%%%%%%%%%%%% dictionary update %%%%%%%%%%%%%%%%%%%%%%%%%%%
117 |     [D, X] = updateDictionary(Y, X, D, 'ksvd', 'nIter', niter_dict, 'Verbose', 0);
118 |
119 | end
120 |
121 | %%
122 |
123 | D_final = (D(1:n,:))./(sqrt(sum(abs(D(1:n,:).^2),1)));
124 | W_final = (D(n+1:end,:))./(sqrt(sum(abs(D(1:n,:).^2),1)));
125 |
126 |
127 | %%
128 |
129 | residual = (W_final*X);
130 | [c, i] = max(abs(residual))
131 |
132 | labelIdx = sub2ind(size(residual), i, 1:size(residual, 2)); % (row i(j), column j) for every column j
133 | label=zeros(size(residual));
134 | label(labelIdx)=residual(labelIdx);
135 |
136 |
137 | figure
138 | imagesc(i)
139 |
140 | % visualizeDictionary(Y)
141 | % figure
142 | % visualizeDictionary(repmat(i, 1024,1))
143 |
144 | % imagesc(repmat(i, 1024,1))
145 |
146 | %%
147 |
148 | X0 = zeros(size(D_final, 2), size(Y_cat, 2));
149 |
150 | X = sparseCode(Y_cat, D_final, 1, 10, 'Verbose', 1);
151 |
152 |
153 |
154 |
155 | %% reconstruct with the image-part dictionary D_final (D also holds the classifier rows)
156 | PA = reshape((D_final*X), [blockSize blockSize size(Y, 2)]);
157 | PA = PA - repmat( mean(mean(PA)), [blockSize blockSize] );
158 | PA = PA + reshape(repmat( meanY, [blockSize^2 1]), [blockSize blockSize size(Y, 2)]);
159 |
160 | W = zeros(imH, imW);
161 | denoisedImage = zeros(imH, imW);
162 |
163 | for i=1:size(Y, 2)
164 |     x = Xp(:,:,i);
165 |     y = Yp(:,:,i);
166 |
167 |     denoisedImage(x+(y-1)*imH) = denoisedImage(x+(y-1)*imH) + PA(:,:,i);
168 |     W(x+(y-1)*imH) = W(x+(y-1)*imH) + 1;
169 | end
170 |
171 | denoisedImage = denoisedImage ./ W;
172 |
173 |
174 | figure,
175 | subplot(121), imagesc(image), title('Noisy image'), axis image
176 | subplot(122), imagesc(denoisedImage), title('Denoised image'), axis image
177 |
178 |
179 | %%
180 |
181 |
182 |
183 |
184 |
185 |
186 |
187 |
188 |
189 |
190 |
191 |
192 |
193 |
194 |
195 |
196 |
197 |
198 |
199 |
200 |
201 |
202 |
203 |
204 |
205 |
206 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # SparseDictionaryLearning
--------------------------------------------------------------------------------
/ReconstructiveDiscrimination.m:
--------------------------------------------------------------------------------
1 | close all
2 | clearvars
3 | clc
4 | %%
5 | % Learn a dictionary (SPAMS mexTrainDL) on texture patches, then map the sparse-coding residual over a second image
6 | D0 = [];
7 | Y = [];
8 |
9 | directoryPath = fullfile('data', 'textures');
10 | fileExtension = '.tiff';
11 |
12 | files = dir(fullfile(directoryPath, strcat('*', fileExtension)));
13 |
14 |
15 |
16 | for i=1:1
17 |
18 |     % fileName = files(i).name;
19 |
20 |     fileName = fullfile(directoryPath, '1.1.02.tiff');
21 |
22 |     image=im2double(imread((fileName)));
23 |     image = imresize(image,0.5);
24 |     % image=im2double(imread(fullfile(directoryPath, fileName)));
25 |
26 |     if(size(image,3) > 1)
27 |         image = rgb2gray(image);
28 |     end
29 |
30 |     % if(p.Results.Plot)
31 |     % figure(1)
32 |     % imagesc(image)
33 |     % drawnow
34 |     % end
35 |
36 |     % images(:,:,i)=(image);
37 |     % disp(['File No: ' num2str(i, '%02d')]);
38 |
39 |
40 |     %%%%%%%%%%%%%
41 |
42 |     % size of extracted square (w*w) patch
43 |     w = 16;
44 |
45 |     % number of image patches in set Y
46 |     m = 5000;
47 |
48 |     % length of signal y (vectorized image patch)
49 |     n = w^2;
50 |
51 |     % desired sparsity (number of non-zero elements in sparse representation vector)
52 |     K = 10;
53 |
54 |     % number of atoms in dictionary D
55 |     p = 500;
56 |
57 |     [N1, N2] = size(image);
58 |
59 |     % number of randomly selected image patches
60 |     q = 3*m;
61 |
62 |     % select q random locations in image (upper left block corners)
63 |     x = floor(rand(1,1,q)*(N1-w))+1;
64 |     y = floor(rand(1,1,q)*(N2-w))+1;
65 |
66 |     % create rectangular mesh wxw
67 |     [dY,dX] = meshgrid(0:w-1,0:w-1);
68 |
69 |     % generate matrices containing block locations
70 |     Xp = repmat(dX, [1 1 q]) + repmat(x, [w w 1]);
71 |     Yp = repmat(dY, [1 1 q]) + repmat(y, [w w 1]);
72 |
73 |     % extract and vectorize blocks
74 |     Y_part = image(Xp+(Yp-1)*N1);
75 |     Y_part = reshape(Y_part, [n, q]);
76 |
77 |     % subtract mean value from the blocks
78 |     Y_part = Y_part - repmat(mean(Y_part), [n,1]);
79 |
80 |     Y_part = unique(Y_part', 'rows')';
81 |
82 |     % choose m highest energy blocks (fewer if unique removed too many duplicates)
83 |     [~, idx] = sort(sum(Y_part.^2), 'descend');
84 |     Y_part = Y_part(:, idx(1:min(m, numel(idx))));
85 |
86 |     % randomly select p blocks (p atoms in the dictionary D)
87 |     sel = randperm(size(Y_part, 2));
88 |     sel = sel(1:p);
89 |
90 |     % normalize columns of Y (normalized image patches)
91 |     D0_part = normc(Y_part(:,sel));
92 |
93 |     D0 = [D0, D0_part];
94 |     Y = [Y, Y_part];
95 |
96 | end
97 |
98 |
99 | %%
100 |
101 | % param.K = 100; % learns a dictionary with 100 elements
102 | param.lambda = 0.05;
103 | param.numThreads = -1; % number of threads
104 | param.batchsize = 1024;
105 | param.verbose = true;
106 | param.iter = 20; % number of training iterations
107 | param.mode = 5;
108 | param.D = D0;
109 |
110 | D = mexTrainDL(Y, param);
111 |
112 | %%
113 |
114 |
115 | % load image for patch extraction (path relative to the repository root)
116 | filePath = fullfile('data', 'textures', '1234.tiff');
117 |
118 | image = im2double(imresize(imread(filePath), 0.5));
119 |
120 | if(size(image,3) > 1)
121 |     image = rgb2gray(image);
122 | end
123 |
124 | [N1, N2] = size(image);
125 |
126 | % patch size w (16x16) is reused from the dictionary-learning cell above
127 |
128 | % number of image patches in set Y
129 | m = 5000;
130 |
131 | % length of signal y
132 | n = w^2;
133 |
134 | % desired sparsity
135 | k = 256;
136 |
137 | % number of atoms in dictionary
138 | p = 128;
139 |
140 | % overlap
141 | q = 1;
142 |
143 | [y, x] = meshgrid(1:q:N2-w/2, 1:q:N1-w/2);
144 | [dY,dX] = meshgrid(0:w-1,0:w-1);
145 |
146 | m = size(x(:),1);
147 | % m=256;
148 |
149 | Xp = repmat(dX,[1 1 m]) + repmat( reshape(x(:),[1 1 m]), [w w 1]);
150 | Yp = repmat(dY,[1 1 m]) + repmat( reshape(y(:),[1 1 m]), [w w 1]);
151 |
152 | Xp(Xp>N1) = 2*N1-Xp(Xp>N1);
153 | Yp(Yp>N2) = 2*N2-Yp(Yp>N2);
154 |
155 | Y = image(Xp+(Yp-1)*N1);
156 | Y = reshape(Y, [n, m]);
157 |
158 | a = mean(Y);
159 | Y = Y - repmat(mean(Y), [n,1]);
160 |
161 |
162 | select = @(A,k)repmat(A(k,:), [size(A,1) 1]);
163 | hardThresh = @(X,k)X .* (abs(X) >= select(sort(abs(X), 'descend'),k));
164 | softThresh = @(X,th)sign(X).*max(abs(X)-th,0);
165 |
166 | %%
167 |
168 | niter_coeff = 20;
169 |
170 | % load D.mat
171 |
172 | % D=D3;
173 |
174 | X = zeros(size(D,2),size(Y,2));
175 | E0 = [];
176 |
177 | sigma = 0;
178 | lambda = 1.5 * sigma;
179 |
180 | tau = 1.9/norm(D*D');
181 | E = [];
182 | th=tau*lambda;
183 |
184 | step = 2000;
185 |
186 | for jj = 1:step:size(Y,2)
187 |     jj
188 |     jumpSize=min(jj+step-1,size(Y,2));
189 |
190 |     X_tmp = zeros(size(D,2),size(Y(:,jj:jumpSize),2));
191 |
192 |     for i = 1:niter_coeff
193 |
194 |
195 |
196 |         % Y(:,jj:jumpSize)
197 |         i
198 |         R(:,jj:jumpSize) = D*X_tmp-Y(:,jj:jumpSize);
199 |         % E(end+1,:) = sum(R.^2);
200 |
201 |         X_tmp = hardThresh(X_tmp-tau*D'*R(:,jj:jumpSize) , 5);
202 |
203 |         % th = tau*lambda;
204 |         % X_tmp = softThresh(X_tmp-tau*D'*R, th');
205 |
206 |
207 |         % param.eps=sigma;
208 |         % param.lambda=0.000000001;
209 |         % param.mode = 2;
210 |
211 |         % X_tmp = mexLasso(Y(:,jj:jumpSize),D,param);
212 |
213 |         % X_tmp = mexOMP(Y(:,jj:jumpSize),D,param);
214 |         % X_tmp = OMPerr(D,Y(:,jj:jumpSize),sigma);
215 |
216 |         X(:,jj:jumpSize)=X_tmp;
217 |
218 |
219 |         % X = wthresh(X-tau*D'*R, 'h', tau*lambda);
220 |     end
221 | end
222 |
223 | %%
224 | % R = D*X - Y;
225 |
226 | % sum(sum(R.^2))
227 | %
228 | % Y_rec = D*X;
229 | % NOTE: the residual R (not D*X) is mapped back to the image as a reconstruction-error map
230 | Y_rec = R;
231 | %
232 | PA = reshape(Y_rec, [w w m]);
233 |
234 | % PA = PA - repmat( mean(mean(PA)), [w w] );
235 | % PA = PA + reshape(repmat( a, [w^2 1] ), [w w m]);
236 |
237 |
238 | W = zeros(N1,N2);
239 | M1 = zeros(N1,N2);
240 |
241 |
242 | for i=1:m
243 |     x = Xp(:,:,i); y = Yp(:,:,i);
244 |
245 |     M1(x+(y-1)*N1) = M1(x+(y-1)*N1) + PA(:,:,i);
246 |
247 |     W(x+(y-1)*N1) = W(x+(y-1)*N1) + 1;
248 | end
249 |
250 | M1 = M1 ./ W;
251 |
252 |
253 | figure
254 | imagesc(M1.^2)
255 |
256 | mean(M1(:).^2, 'omitnan')
257 |
258 |
--------------------------------------------------------------------------------
/overcompleteDCTdictionary.m:
--------------------------------------------------------------------------------
1 | function [ D ] = overcompleteDCTdictionary( M, K)
2 | %overcompleteDCTdictionary Create an overcomplete separable 2-D DCT dictionary
3 | %   D = overcompleteDCTdictionary(M, K) returns an M-by-ceil(sqrt(K))^2 matrix with unit-norm columns (assumes M is a perfect square).
4 | % side length of the (square) patch and number of 1-D frequencies per axis
5 | M=sqrt(M);
6 | K=ceil(sqrt(K));
7 | % K equally spaced (fractional) frequencies in [0, M); built from the count so there are exactly K
8 | k=(0:K-1)*(M/K);
9 | n=0:1:M-1;
10 | % 1-D DCT-II style atoms, one frequency per row (K-by-M)
11 | E=cos(pi/M*k'*(n+0.5));
12 | % separable 2-D atoms: Kronecker product of the 1-D atom sets
13 | D=kron(E,E);
14 | D=normc(D');
15 |
16 | end
17 |
18 |
--------------------------------------------------------------------------------
/utilities/extractImagePatches.m:
--------------------------------------------------------------------------------
1 | function [ patches, patchesVectorized, Xp, Yp ] = extractImagePatches( image, blockSize, mode, varargin )
2 | %extractImagePatches Extract square blockSize-by-blockSize patches from a 2-D image
3 | %   mode 'seq': regular grid with step blockSize-Overlap (indices past the border are reflected); mode 'rand': nPatches random patches.
4 | %   Returns patches (blockSize x blockSize x m), patchesVectorized (blockSize^2 columns) and the row/column index grids Xp, Yp.
5 | p=inputParser;
6 |
7 | p.addRequired('image', @ismatrix);
8 | p.addRequired('blockSize', @(x) isnumeric(x) && (x>0));
9 | p.addRequired('mode', @ischar);
10 | p.addParameter('Overlap', 0, @(x) isnumeric(x) && (x>=0));
11 | p.addParameter('nPatches', 0, @(x) isnumeric(x) && (x>0));
12 |
13 | p.parse(image, blockSize, mode, varargin{:});
14 |
15 | [imH, imW] = size(image);
16 |
17 | if(p.Results.Overlap > (blockSize - 1))
18 |     error('Invalid overlap parameter!')
19 | end
20 |
21 | if(strcmp(p.Results.mode, 'seq'))
22 |
23 |     % define step for patch extraction
24 |     % default is no overlap(q=blockSize) and maximum is q=1 in that case
25 |     % block centered on every pixel is extracted
26 |     q = blockSize - p.Results.Overlap;
27 |
28 |     [y, x] = meshgrid(1:q:imW-blockSize/2, 1:q:imH-blockSize/2);
29 |     [dY,dX] = meshgrid(0:blockSize-1,0:blockSize-1);
30 |
31 |     m = size(x(:),1);
32 |
33 |     % create indexing grids for block extraction
34 |     Xp = repmat(dX,[1 1 m]) + repmat( reshape(x(:),[1 1 m]), [blockSize blockSize 1]);
35 |     Yp = repmat(dY,[1 1 m]) + repmat( reshape(y(:),[1 1 m]), [blockSize blockSize 1]);
36 |
37 |     % boundary indices condition
38 |     Xp(Xp>imH) = 2*imH-Xp(Xp>imH);
39 |     Yp(Yp>imW) = 2*imW-Yp(Yp>imW);
40 |
41 |     patches = image(Xp+(Yp-1)*imH);
42 |
43 |     % h = fspecial('log', blockSize);
44 |     % patches = patches .* repmat(h, 1, 1, size(patches, 3));
45 |
46 |     patchesVectorized = reshape(patches, [blockSize^2, m]);
47 |
48 | end
49 |
50 |
51 | if(strcmp(p.Results.mode, 'rand'))
52 |
53 |     % number of randomly selected image patches
54 |     q = p.Results.nPatches;
55 |
56 |     % select q random locations in image (upper left block corners)
57 |     x = floor(rand(1,1,q)*(imH-blockSize))+1;
58 |     y = floor(rand(1,1,q)*(imW-blockSize))+1;
59 |
60 |     % create rectangular mesh wxw
61 |     [dY,dX] = meshgrid(0:blockSize-1,0:blockSize-1);
62 |
63 |     % generate matrices containing block locations
64 |     Xp = repmat(dX, [1 1 q]) + repmat(x, [blockSize blockSize 1]);
65 |     Yp = repmat(dY, [1 1 q]) + repmat(y, [blockSize blockSize 1]);
66 |
67 |     % extract and vectorize blocks
68 |     patches = image(Xp+(Yp-1)*imH);
69 |     % column-major linear index: row + (col-1)*imH (number of rows), as in the 'seq' branch
70 |
71 |
72 |     % h = fspecial('log', blockSize);
73 |     % patches = patches .* repmat(h, 1, 1, size(patches, 3));
74 |
75 |     patchesVectorized = reshape(patches, [blockSize^2, q]);
76 |     % NOTE: unique drops duplicate patches and sorts the columns, so patchesVectorized may no longer align with patches/Xp/Yp
77 |     patchesVectorized = unique(patchesVectorized', 'rows')';
78 |
79 | end
80 |
81 | end
82 |
83 |
--------------------------------------------------------------------------------
/utilities/hardThreshold.m:
--------------------------------------------------------------------------------
1 | function [ out ] = hardThreshold( in, th )
2 | %hardThreshold Elementwise hard thresholding
3 | %   out = hardThreshold(in, th) keeps elements with abs(in) > th (strict) and sets all others to zero.
4 |
5 | out = in .* (abs(in)>th);
6 |
7 | end
8 |
9 |
--------------------------------------------------------------------------------
/utilities/initDictionaryFromPatches.m:
--------------------------------------------------------------------------------
1 | function [ D0 ] = initDictionaryFromPatches( n, K, patchesVectorized )
2 | %initDictionaryFromPatches Build an initial K-atom dictionary from the highest-energy patches
3 | %   D0 = initDictionaryFromPatches(n, K, patchesVectorized) returns the K patches with the largest sum of squares, normalized to unit-norm columns. n is currently unused.
4 |
5 | % sort patches by energy (sum of squares), highest first
6 | [~, idx] = sort(sum(patchesVectorized.^2), 'descend');
7 | patchesVectorized = patchesVectorized(:, idx);
8 |
9 | % ALTERNATIVE: randomly select K blocks instead of the K highest-energy ones
10 | % sel = randperm(size(patchesVectorized, 2));
11 | % sel = sel(1:K);
12 | if K > size(patchesVectorized, 2), error('initDictionaryFromPatches:notEnoughPatches', 'K = %d exceeds the number of available patches (%d).', K, size(patchesVectorized, 2)); end
13 | sel = 1:K;
14 |
15 | % normalize the selected patches to unit l2-norm columns
16 | D0 = normalizeColumns(patchesVectorized(:,sel));
17 |
18 | end
19 |
20 |
--------------------------------------------------------------------------------
/utilities/learnDictionary.m:
--------------------------------------------------------------------------------
1 | function [ D, X, E0 ] = learnDictionary( Y, D, T0, varargin )
2 | %learnDictionary Alternate sparse coding and dictionary update for nIterLearn rounds
3 | %   [D, X, E0] = learnDictionary(Y, D, T0, ...) returns the updated dictionary, the codes and E0 = ||Y-D*X||_F^2 after every half-step.
4 | %   NOTE: the 'modeCoeff' and 'Verbose' options are parsed but currently unused.
5 | p = inputParser;
6 |
7 | p.addRequired('Y', @ismatrix);
8 | p.addRequired('D', @ismatrix);
9 | p.addRequired('T0', @isnumeric);
10 | p.addParameter('nIterLearn', 10, @isnumeric);
11 | p.addParameter('nIterDict', 10, @isnumeric);
12 | p.addParameter('nIterCoeff', 10, @isnumeric);
13 | p.addParameter('modeDict', 'ksvd', @ischar);
14 | p.addParameter('modeCoeff', 'grad', @ischar);
15 | p.addParameter('Plot', 0, @isnumeric);
16 | p.addParameter('Verbose', 0, @isnumeric);
17 |
18 | p.parse(Y, D, T0, varargin{:});
19 |
20 | X = zeros(size(D, 2), size(Y, 2));
21 | E0 = [];
22 |
23 | for iter = 1:p.Results.nIterLearn
24 |     fprintf('Dictionary Learning Iteration No. %d\n', iter);
25 |
26 |     %%%%%%%%%%%%%%%% coefficient calculation %%%%%%%%%%%%%%%%%%%%%%%
27 |     X = sparseCode(Y, D, T0, p.Results.nIterCoeff);
28 |
29 |     E0(end+1) = norm(Y-D*X, 'fro')^2;
30 |
31 |     %%%%%%%%%%%%%%%% dictionary update %%%%%%%%%%%%%%%%%%%%%%%%%%%
32 |     [D, X] = updateDictionary(Y, X, D, p.Results.modeDict, 'nIter', p.Results.nIterDict);
33 |
34 |     E0(end+1) = norm(Y-D*X, 'fro')^2;
35 | end
36 |
37 | if(p.Results.Plot)
38 |     figure,
39 |     hold on
40 |     plot(1:2*p.Results.nIterLearn, E0);
41 |     plot(1:2:2*p.Results.nIterLearn, E0(1:2:2*p.Results.nIterLearn), '*');
42 |     plot(2:2:2*p.Results.nIterLearn, E0(2:2:2*p.Results.nIterLearn), 'o');
43 |     axis tight;
44 |     legend('|Y-DX|^2', 'After coefficient update', 'After dictionary update');
45 | end
46 |
47 |
48 | end
49 |
50 |
--------------------------------------------------------------------------------
/utilities/normalizeColumns.m:
--------------------------------------------------------------------------------
1 | function [ out ] = normalizeColumns( in )
2 | %normalizeColumns Scale every column of a matrix to unit l2 norm
3 | %   Norms are taken along dimension 1; all-zero columns are returned as zeros instead of NaN.
4 |
5 | out = in ./ max(sqrt(sum(in.^2, 1)), eps);
6 |
7 | end
8 |
9 |
--------------------------------------------------------------------------------
/utilities/softThreshold.m:
--------------------------------------------------------------------------------
1 | function [ out ] = softThreshold( in, th )
2 | %softThreshold Soft-threshold (shrink) the entries of an array.
3 | %   OUT = softThreshold(IN, TH) computes sign(IN) .* max(|IN| - TH, 0):
4 | %   every entry moves towards zero by TH, and entries whose magnitude does
5 | %   not exceed TH become zero. This is the proximal operator of TH*||.||_1.
6 | %
7 | %   Example: softThreshold([3 -0.5 -2], 1) returns [2 0 -1].
8 |
9 | out = sign(in) .* max(abs(in)-th,0);
10 |
11 | end
12 |
13 |
--------------------------------------------------------------------------------
/utilities/sparseCode.m:
--------------------------------------------------------------------------------
1 | function [ X ] = sparseCode( Y, D, T0, nIter, varargin )
2 | %sparseCode T0-sparse coefficients by iterative hard thresholding.
3 | %   X = sparseCode(Y, D, T0, NITER, Name, Value, ...) approximately solves
4 | %       min_X ||Y - D*X||_F^2  s.t. at most T0 non-zeros per column of X
5 | %   by running NITER projected gradient steps, starting from X = 0: a
6 | %   gradient step on the residual D*X - Y followed by strictThreshold,
7 | %   which keeps the T0 largest-magnitude entries of every column.
8 | %
9 | %   Inputs:
10 | %     Y     - signals, one per column.
11 | %     D     - dictionary, one atom per column.
12 | %     T0    - number of non-zeros kept per column (positive).
13 | %     nIter - number of iterations per block of columns (positive).
14 | %
15 | %   Name-value options (defaults in parentheses):
16 | %     'StepSize' (0) - number of columns of Y processed per block; 0 means
17 | %                      all columns in a single block.
18 | %     'Plot'     (0) - if non-zero, plot log10(J - min J) over the
19 | %                      iterations for the first (up to) 10 columns, where
20 | %                      J is the squared residual norm of a column.
21 | %     'Verbose'  (0) - if non-zero, print the column range of each block.
22 | %
23 | %   Output:
24 | %     X - size(D,2)-by-size(Y,2) coefficient matrix.
25 |
26 | p = inputParser;
27 |
28 | p.addRequired('Y', @ismatrix);
29 | p.addRequired('D', @ismatrix);
30 | p.addRequired('T0', @(x) isnumeric(x) && x>0);
31 | p.addRequired('nIter', @(x) isnumeric(x) && x>0);
32 | p.addParameter('StepSize', 0, @isnumeric);
33 | p.addParameter('Plot', 0, @isnumeric);
34 | p.addParameter('Verbose', 0, @isnumeric);
35 |
36 | p.parse(Y, D, T0, nIter, varargin{:})
37 |
38 | nCols = size(Y, 2);
39 | nAtoms = size(D, 2);
40 | doPlot = p.Results.Plot;
41 |
42 | if p.Results.StepSize == 0
43 |     step = nCols;
44 | else
45 |     step = p.Results.StepSize;
46 | end
47 |
48 | X = zeros(nAtoms, nCols);
49 |
50 | % gradient step size tau = 1.6/L, where L = norm(D*D') is the Lipschitz
51 | % constant of the gradient D'*(D*X - Y)
52 | tau = 1.6/norm(D*D');
53 |
54 | % Alternative l1 (soft-threshold) update, kept for reference:
55 | %   sigma = 0.1; lambda = 1.5*sigma; th = lambda*tau;
56 | %   Xblk = softThreshold(Xblk - tau*D'*R, th);
57 |
58 | % per-iteration squared residual norm of every column (only for 'Plot')
59 | if doPlot
60 |     E = zeros(nIter, nCols);
61 | end
62 |
63 | for first = 1:step:nCols
64 |     last = min(first+step-1, nCols);
65 |     if p.Results.Verbose
66 |         fprintf('Sparse Coding Col. No. %d - %d/%d\n', first, last, nCols);
67 |     end
68 |
69 |     Yblk = Y(:, first:last);
70 |     Xblk = zeros(nAtoms, last-first+1);
71 |
72 |     for i = 1:nIter
73 |         R = D*Xblk - Yblk;
74 |         Xblk = strictThreshold(Xblk - tau*D'*R, T0);
75 |         if doPlot
76 |             E(i, first:last) = sum((D*Xblk - Yblk).^2, 1);
77 |         end
78 |     end
79 |
80 |     X(:, first:last) = Xblk;
81 | end
82 |
83 | if doPlot && nCols > 0
84 |     sel = 1:min(10, nCols);
85 |     figure,
86 |     plot(log10(E(:,sel) - min(E(:,sel), [], 1)));
87 |     axis tight;
88 |     title('$$log_{10}(J(x_j) - J(x_j^*))$$', 'Interpreter', 'latex');
89 | end
90 |
91 | end
92 |
93 |
--------------------------------------------------------------------------------
/utilities/strictThreshold.m:
--------------------------------------------------------------------------------
1 | function [ out ] = strictThreshold( in, k )
2 | %strictThreshold Keep the k largest-magnitude entries of every column.
3 | %   OUT = strictThreshold(IN, K) zeroes every entry of IN whose magnitude
4 | %   is below the K-th largest magnitude of its column. Entries tied with
5 | %   that K-th value are all kept, so a column can retain more than K
6 | %   non-zeros when magnitudes tie. If K is at least the number of rows,
7 | %   IN is returned unchanged.
8 | %
9 | %   Example: strictThreshold([3 1; -5 2; 1 4], 1) returns [0 0; -5 0; 0 4].
10 |
11 | if k >= size(in, 1)
12 |     out = in;
13 |     return
14 | end
15 |
16 | mags = abs(in);
17 | % sort each column (dimension 1 explicitly) in descending order
18 | sortedMags = sort(mags, 1, 'descend');
19 | kth = sortedMags(k, :);          % K-th largest magnitude of each column
20 | out = in .* (mags >= kth);       % implicit expansion over the rows
21 |
22 | end
23 |
24 |
--------------------------------------------------------------------------------
/utilities/substractMeanCols.m:
--------------------------------------------------------------------------------
1 | function [ out, meanValue] = substractMeanCols( in )
2 | %substractMeanCols Subtract the mean of every column.
3 | %   [OUT, MEANVALUE] = substractMeanCols(IN) returns MEANVALUE, the row
4 | %   vector of column means of IN, and OUT = IN - MEANVALUE, i.e. IN with
5 | %   every column shifted to zero mean.
6 | %
7 | %   Example: substractMeanCols([1 2; 3 6]) returns [-1 -2; 1 2] and [2 4].
8 |
9 | % dimension 1 is explicit: for a 1-by-N input, mean would otherwise average
10 | % the whole row instead of each (single-element) column
11 | meanValue = mean(in, 1);
12 | out = in - meanValue;   % implicit expansion over the rows
13 |
14 | end
15 |
16 |
--------------------------------------------------------------------------------
/utilities/updateDictionary.m:
--------------------------------------------------------------------------------
1 | function [ D, X ] = updateDictionary( Y, X, D, mode, varargin )
2 | %updateDictionary Update dictionary atoms for given sparse coefficients.
3 | %   [D, X] = updateDictionary(Y, X, D, MODE, Name, Value, ...) updates the
4 | %   dictionary D (one atom per column) so that D*X better approximates the
5 | %   signals Y (one per column). MODE selects the algorithm:
6 | %     'grad'  - nIter projected gradient steps on ||Y - D*X||_F^2, each
7 | %               followed by column normalization; X is not modified.
8 | %     'mod'   - Method of Optimal Directions: D = Y*pinv(X) with
9 | %               normalized columns; X is not modified.
10 | %     'ksvd'  - K-SVD: every atom whose coefficient row has entries with
11 | %               magnitude above 1e-3 is replaced by the leading left
12 | %               singular vector of the residual restricted to those
13 | %               signals, and those coefficients by the matching singular
14 | %               value times the right singular vector.
15 | %     'aksvd' - approximate K-SVD: like 'ksvd', but the atom becomes the
16 | %               normalized Ri*x' (Ri the restricted residual, x the old
17 | %               coefficients) and the coefficients become atom'*Ri.
18 | %   Any other MODE raises an error.
19 | %
20 | %   Name-value options (defaults in parentheses):
21 | %     'nIter'   (10) - number of iterations of the 'grad' mode.
22 | %     'Plot'    (0)  - if non-zero, plot log10(E - min(E)) for the first
23 | %                      half of the recorded residual energies E.
24 | %     'Verbose' (0)  - if non-zero, print progress messages.
25 |
26 | p = inputParser;
27 |
28 | p.addRequired('Y', @ismatrix);
29 | p.addRequired('X', @ismatrix);
30 | p.addRequired('D', @ismatrix);
31 | p.addRequired('mode', @isstr);
32 | p.addParameter('nIter', 10, @(x) isnumeric(x) && x>0);
33 | p.addParameter('Plot', 0, @isnumeric);
34 | p.addParameter('Verbose', 0, @isnumeric);
35 |
36 | p.parse(Y, X, D, mode, varargin{:})
37 |
38 | verbose = p.Results.Verbose;
39 | T = 1e-3;   % (A)K-SVD: coefficients with |x| <= T count as unused
40 | E = [];     % residual energy ||Y - D*X||_F^2 recorded during the update
41 |
42 | switch mode
43 |     case 'grad'
44 |         % projected gradient descent; 1/norm(X*X') is the inverse
45 |         % Lipschitz constant of D -> (D*X - Y)*X'
46 |         L = norm(X*X');
47 |         if L > 0   % X == 0: the gradient vanishes, nothing to update
48 |             tau = 1/L;
49 |             for i = 1:p.Results.nIter
50 |                 if verbose
51 |                     fprintf('Dictionary Update Iteration No. %d\n', i);
52 |                 end
53 |                 R = D*X - Y;
54 |                 E(end+1) = sum(R(:).^2); %#ok<AGROW>
55 |                 D = normalizeColumns(D - tau*R*X');
56 |             end
57 |         end
58 |
59 |     case 'mod'
60 |         % MOD: Y*pinv(X) minimizes ||Y - D*X||_F over all D for fixed X
61 |         D = normalizeColumns(Y*pinv(X));
62 |
63 |     case 'ksvd'
64 |         R = Y - D*X;
65 |         for kk = 1:size(D, 2)
66 |             if verbose
67 |                 fprintf('K-SVD Current Column No. %d\n', kk);
68 |             end
69 |             idx = find(abs(X(kk,:)) > T);
70 |             if isempty(idx)
71 |                 continue   % atom not used by any signal: keep it
72 |             end
73 |             % residual without atom kk, restricted to the signals using it
74 |             Ri = R(:,idx) + D(:,kk)*X(kk,idx);
75 |             [U, S, V] = svds(Ri, 1, 'L');
76 |             D(:,kk) = U;
77 |             X(kk,idx) = S*V';
78 |             R(:,idx) = Ri - D(:,kk)*X(kk,idx);
79 |             E(end+1) = sum(R(:).^2); %#ok<AGROW>
80 |         end
81 |
82 |     case 'aksvd'
83 |         R = Y - D*X;
84 |         for kk = 1:size(D, 2)
85 |             if verbose
86 |                 fprintf('AK-SVD Current Column No. %d\n', kk);
87 |             end
88 |             idx = find(abs(X(kk,:)) > T);
89 |             if isempty(idx)
90 |                 continue
91 |             end
92 |             Ri = R(:,idx) + D(:,kk)*X(kk,idx);
93 |             dk = Ri * X(kk,idx)';
94 |             dkNorm = norm(dk);
95 |             if dkNorm == 0
96 |                 continue   % no direction to normalize: keep the old atom
97 |             end
98 |             dk = dk/dkNorm;
99 |             D(:,kk) = dk;
100 |             X(kk,idx) = dk'*Ri;
101 |             R(:,idx) = Ri - D(:,kk)*X(kk,idx);
102 |             E(end+1) = sum(R(:).^2); %#ok<AGROW>
103 |         end
104 |
105 |     otherwise
106 |         error('updateDictionary:unknownMode', ...
107 |             'Unknown mode ''%s''; expected ''grad'', ''mod'', ''ksvd'' or ''aksvd''.', mode);
108 | end
109 |
110 | if p.Results.Plot && ~isempty(E)
111 |     figure,
112 |     plot(log10(E(1:end/2)-min(E)));
113 |     axis tight;
114 | end
115 |
116 | end
--------------------------------------------------------------------------------
/utilities/visualizeDictionary.m:
--------------------------------------------------------------------------------
1 | function [] = visualizeDictionary( D )
2 | %visualizeDictionary Display dictionary atoms as a grid of square tiles.
3 | %   visualizeDictionary(D) reshapes each column of D into a
4 | %   blockSize-by-blockSize tile, blockSize = sqrt(size(D, 1)) (so size(D,1)
5 | %   is expected to be a perfect square), and shows the tiles in the current
6 | %   axes with imagesc, separated by black grid lines. The grid is blockSize
7 | %   tiles high; only the first floor(size(D, 2)/blockSize)*blockSize atoms
8 | %   are shown, so trailing atoms that do not fill a whole grid column are
9 | %   omitted. Uses col2im (Image Processing Toolbox) and xticks/yticks.
10 |
11 | blockSize = sqrt(size(D, 1));
12 |
13 | nVis = floor(size(D, 2)/blockSize)*blockSize;
14 |
15 | dictVisual = col2im(D(:,1:nVis), [blockSize, blockSize], size(D(:,1:nVis)), 'distinct');
16 |
17 | imagesc(dictVisual), axis image
18 |
19 | xticks(0.5:blockSize:size(dictVisual,2))
20 | yticks(0.5:blockSize:size(dictVisual,1))
21 |
22 | set(gca,'xticklabel',[])
23 | set(gca,'yticklabel',[])
24 |
25 | grid on
26 |
27 | ax = gca;
28 | ax.GridColor = 'black';
29 | ax.GridAlpha = 1;
30 | set(gca,'LineWidth', 2);
31 |
32 | end
33 |
34 |
--------------------------------------------------------------------------------