├── ProximalDictLearnConst.m
└── README.md

/ProximalDictLearnConst.m:
--------------------------------------------------------------------------------
function [ObjFunc, ObjFuncVal, A, X] = ProximalDictLearnConst(Y, lambda, beta, k, Ainit, MaxIter, Yv)
% Solves   min_{A,X}  0.5*||Y - A*X||_F^2 + lambda*||X||_1   s.t.  ||A||_F^2 <= beta
% by alternating a proximal-gradient (soft-thresholding) step on X with a
% norm-constrained least-squares update of the dictionary A.
% Inputs:
%   Y       n x N training data
%   lambda  sparsity regularization weight
%   beta    bound on the squared Frobenius norm of A
%   k       number of dictionary atoms; X has size k x N
%   Ainit   (optional) n x k initial dictionary
%   MaxIter (optional) maximum number of iterations
%   Yv      (optional) n x Nv validation data
% (c) Meisam Razaviyayn

%% initialization
[n, N] = size(Y);
if nargin <= 4
    Ainit = randn(n,k);
    MaxIter = 300;
    Yv = zeros(n,1);
elseif nargin == 5
    MaxIter = 300;
    Yv = zeros(n,1);
elseif nargin == 6
    Yv = zeros(n,1);
end
Nv = size(Yv,2);
Xvinit = zeros(k,Nv);
NewtonEps = 0.01;                   % relative tolerance for the Newton search on the constraint
A = Ainit;
A = A / norm(A,'fro') * sqrt(beta); % scale the initial dictionary onto the constraint boundary
ObjFunc = zeros(MaxIter,1);
ObjFuncVal = zeros(MaxIter,1);
relErrThe = 1e-6;                   % stopping criterion on the relative objective change

%% X initialization
[X, obj_init, relerr_init] = L1L2MatNest(A, Y, lambda, (A'*A + lambda*eye(k)) \ (A'*Y), 0, 50);
XX = zeros(k,N);

%% begin iteration
for IterNum = 1:MaxIter
    % adaptive step size tau = 1/||A'*A||_2, recomputed every iteration
    opts.issym = 1;
    tau = 1 / eigs(A'*A, 1, 'lm', opts);
    ts = tic;

    % X-update: one gradient step on the smooth term followed by
    % soft-thresholding at level lambda*tau
    XX = 0 * XX;
    temp = X - tau * A' * (A*X - Y);
    tempIndexPositive = find(temp >= lambda*tau);
    XX(tempIndexPositive) = temp(tempIndexPositive) - lambda*tau;
    tempIndexNegative = find(temp <= -lambda*tau);
    XX(tempIndexNegative) = temp(tempIndexNegative) + lambda*tau;
    X = XX;

    % A-update: minimize 0.5*||Y - A*X||_F^2 subject to ||A||_F^2 <= beta.
    % The candidate A = Y*X'*(X*X' + theta*I)^{-1} is formed in the eigenbasis
    % of X*X'; if the (nearly) unconstrained solution violates the constraint,
    % the multiplier theta is found by Newton's method on the constraint residual.
    TempXXt = X * X';
    TempYXt = Y * X';
    theta = 1e-11;
    [U, Diag] = eig(TempXXt);
    eigVal = diag(Diag);
    TempYXtU = TempYXt * U;
    TempYXtUinvDiag = TempYXtU * diag(1 ./ (eigVal + theta));
    if (norm(TempYXtUinvDiag,'fro')^2) > beta
        counter = 0;
        TempYXtUinvDiagNorm = norm(TempYXtUinvDiag,'fro');
        while (counter <= 20) && ~((TempYXtUinvDiagNorm^2 < beta*(1+NewtonEps)) && ...
                (TempYXtUinvDiagNorm^2 > beta*(1-NewtonEps)))
            counter = counter + 1;
            constval = TempYXtUinvDiagNorm^2 - beta;
            consDerivative = -2 * norm(TempYXtUinvDiag * diag(1 ./ sqrt(eigVal + theta)), 'fro')^2;
            theta = theta - constval / consDerivative;   % Newton step on the constraint residual
            TempYXtUinvDiag = TempYXtU * diag(1 ./ (eigVal + theta));
            TempYXtUinvDiagNorm = norm(TempYXtUinvDiag,'fro');
        end
    end
    A = TempYXtUinvDiag * U';

    % per-sample training objective
    ObjFunc(IterNum) = (0.5 * norm(Y - A*X,'fro')^2 + lambda * sum(sum(abs(X)))) / N;
    if (mod(IterNum,10) == 1) && (nargin >= 7) % evaluate the validation objective every 10 iterations when Yv is supplied
        [Xvinit, obj_val, relerr_val] = L1L2MatNest(A, Yv, lambda, Xvinit);
        ObjFuncVal(IterNum) = obj_val(end) / Nv;
    end
    tElap = toc(ts);

    if IterNum > 1
        relerr = abs(ObjFunc(IterNum) - ObjFunc(IterNum-1)) / abs(ObjFunc(IterNum));
    else
        relerr = 1;
    end
    disp(strcat('Iter:', num2str(IterNum), ', relerr:', num2str(relerr), ...
        ', time:', num2str(tElap), ', train obj:', num2str(ObjFunc(IterNum)), ...
        ', val obj:', num2str(ObjFuncVal(IterNum))));
    if relerr < relErrThe
        break;
    end
end
ObjFunc = ObjFunc(1:IterNum);
ObjFuncVal = ObjFuncVal(1:IterNum);

end
--------------------------------------------------------------------------------

/README.md:
--------------------------------------------------------------------------------
# Sparse-Dictionary-Learning
Code for dictionary learning for sparse representation.

The code is based on the paper at http://arxiv.org/abs/1511.01776
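
## Usage

A minimal calling sketch; the problem sizes and the values of `lambda` and `beta` below are illustrative assumptions, not values from the paper, and the helper `L1L2MatNest` used inside `ProximalDictLearnConst` must be on the MATLAB path:

```matlab
n = 64;           % signal dimension
N = 1000;         % number of training samples
k = 128;          % number of dictionary atoms
Y = randn(n, N);  % replace with your training data

lambda = 0.1;     % sparsity weight (problem-dependent)
beta   = k;       % bound on ||A||_F^2 (problem-dependent)

% Learn the dictionary A (n x k) and the sparse codes X (k x N)
[ObjFunc, ObjFuncVal, A, X] = ProximalDictLearnConst(Y, lambda, beta, k);

% ObjFunc holds the per-sample training objective at each iteration
plot(ObjFunc); xlabel('iteration'); ylabel('objective / N');
```
--------------------------------------------------------------------------------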