├── .gitignore
├── utils
    ├── .DS_Store
    ├── plotFMap.m
    ├── objFuncTemplate.m
    ├── eigsReal.m
    ├── normalizeUnitArea.m
    ├── vertexAreas.m
    ├── plotMesh.m
    ├── strengthOfConnections.m
    ├── matrixBellmanFord.m
    ├── graphKMediods.m
    ├── myColorMap.m
    └── algebraicCoarsening.m
├── README.md
└── demo.m

/.gitignore:
--------------------------------------------------------------------------------
.DS_Store

--------------------------------------------------------------------------------
/utils/.DS_Store:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/HTDerekLiu/SpecCoarsen_MATLAB/HEAD/utils/.DS_Store
--------------------------------------------------------------------------------
/utils/plotFMap.m:
--------------------------------------------------------------------------------
function plotFMap(fMap)
% PLOTFMAP displays a matrix (e.g. a functional map) as an image whose
% color range is symmetric around zero, using the colormap of myColorMap
%
% plotFMap(fMap)
%
% Inputs:
%   fMap  a 2D matrix, e.g. eVecc' * Mc * P * eVec as computed in demo.m

maxAbs = full(max(abs(fMap(:))));
if isempty(maxAbs) || ~(maxAbs > 0) || isinf(maxAbs)
    % imagesc needs a finite, strictly increasing color range; fall back
    % to [-1 1] for empty or all-zero input
    maxAbs = 1;
end
imagesc(fMap, [-maxAbs maxAbs])
colormap(gca, myColorMap())
axis image;
--------------------------------------------------------------------------------
/utils/objFuncTemplate.m:
--------------------------------------------------------------------------------
function f = objFuncTemplate(G, A, B, L, Mc, invMc)
% OBJFUNCTEMPLATE evaluates the coarsening energy
%
%   f = 1/2 * sum(sum( Mc * (A - invMc * G' * L * G * B).^2 ))
%
% f = objFuncTemplate(G, A, B, L, Mc, invMc)
%
% Inputs:
%   G      n x nc matrix; the coarse operator is G' * L * G
%   A, B   nc x k matrices (in algebraicCoarsening: A = P*invM*L*U, B = P*U)
%   L      n x n fine operator
%   Mc     nc x nc coarse mass matrix
%   invMc  nc x nc inverse of Mc
% Outputs:
%   f      scalar energy value
%
% Note: `.^` binds tighter than `*`, so Mc multiplies the element-wise
% squared residual; for a diagonal Mc this is a mass-weighted squared
% Frobenius norm.
% (The first argument is named G rather than `var` to avoid shadowing the
% builtin var function.)
residual = A - invMc * G' * L * G * B;
f = full(sum(sum(Mc * residual.^2))) / 2;
end
--------------------------------------------------------------------------------
/utils/eigsReal.m:
--------------------------------------------------------------------------------
function [eVal, eVec] = eigsReal(L, massMat, numEigs, shift)
% EIGSREAL computes the numEigs eigenpairs of L * v = lambda * massMat * v
% with the smallest magnitude, sorted by increasing |lambda|
%
% [eVal, eVec] = eigsReal(L, massMat, numEigs)
% [eVal, eVec] = eigsReal(L, massMat, numEigs, shift)
%
% Inputs:
%   L        n x n (sparse) operator
%   massMat  n x n mass matrix
%   numEigs  number of eigenpairs to compute
%   shift    (optional, default 1e-8) multiple of the identity added to L
%            before calling eigs and subtracted from the eigenvalues
%            afterwards; 'sm' mode works with shift-invert about 0, which
%            breaks down when L is singular (e.g. a cotangent Laplacian)
% Outputs:
%   eVal     numEigs x 1 vector of eigenvalues
%   eVec     n x numEigs matrix of eigenvectors, columns ordered like eVal

if nargin < 4
    shift = 1e-8;
end
n = size(L,1);
[eVec, D] = eigs(L + shift .* speye(n,n), massMat, numEigs, 'sm');
eVal = diag(D);
[~, order] = sort(abs(eVal));
eVal = eVal(order) - shift;
eVec = eVec(:, order);
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
# Spectral Coarsening
This is the MATLAB implementation of "Spectral Coarsening of Geometric Operators" [Liu et al. 2019]. The only dependency is [gptoolbox](https://github.com/alecjacobson/gptoolbox).

To run the demo, set the gptoolbox path at the top of `demo.m` and run `demo.m` from the repository root (it adds `./utils/` to the path and reads the mesh `./bunny.obj`).

### bibtex
```
@article{Liu:SpecCoarse:2019,
  title = {Spectral Coarsening of Geometric Operators},
  author = {Hsueh-Ti Derek Liu and Alec Jacobson and Maks Ovsjanikov},
  year = {2019},
  journal = {ACM Transactions on Graphics},
}
```

--------------------------------------------------------------------------------
/utils/normalizeUnitArea.m:
--------------------------------------------------------------------------------
function V = normalizeUnitArea(V,F)
% NORMALIZEUNITAREA uniformly scales a mesh to have unit total surface area
%
% V = normalizeUnitArea(V,F)
%
% Inputs:
%   V  |V| x 3 matrix of vertex positions
%   F  |F| x 3 matrix of indices of triangle corners
% Outputs:
%   V  the scaled vertex positions (total area 1; the mesh is scaled about
%      the origin and is not re-centered)

% area grows quadratically with scale, so divide positions by sqrt(area)
V = V / sqrt(sum(vertexAreas(V,F)));
--------------------------------------------------------------------------------
/utils/vertexAreas.m:
--------------------------------------------------------------------------------
function VA = vertexAreas(V, F)
% VERTEXAREAS computes per-vertex areas of a triangle mesh
%
% VA = vertexAreas(V,F)
%
% Inputs:
%   V  |V| x 3 matrix of vertex positions
%   F  |F| x 3 matrix of indices of triangle corners
% Outputs:
%   VA  a |V| vector of vertex areas (each vertex receives 1/3 of the area
%       of every face it belongs to)

% face areas: half the norm of the cross product of two edge vectors.
% The explicit dim argument matters: without it, cross() works along the
% first dimension of length 3, which is the wrong one when |F| == 3.
e12 = V(F(:,1),:) - V(F(:,2),:);
e13 = V(F(:,1),:) - V(F(:,3),:);
FA = sqrt(sum(cross(e12, e13, 2).^2, 2)) ./ 2;

% scatter a third of every face area to each of its three corners
VA = accumarray(F(:), repmat(FA ./ 3, 3, 1), [size(V,1) 1]);
--------------------------------------------------------------------------------
/utils/plotMesh.m:
--------------------------------------------------------------------------------
function plotMesh(V,F,f, drawEdge)
% PLOTMESH renders a triangle mesh (via gptoolbox tsurf), optionally
% colored by a per-vertex scalar function
%
% plotMesh(V,F)
% plotMesh(V,F,f)
% plotMesh(V,F,f,drawEdge)
%
% Inputs:
%   V         |V| x 3 matrix of vertex positions
%   F         |F| x 3 matrix of indices of triangle corners
%   f         (optional) |V| vector of per-vertex values used as color;
%             replaced by zeros when its number of rows differs from |V|
%   drawEdge  (optional, default false) draw mesh edges in black

if nargin < 3
    f = zeros(size(V,1),1);
end
if nargin < 4
    drawEdge = false;
end
if size(f,1) ~= size(V,1)
    f = zeros(size(V,1),1);
end

t = tsurf(F,V);
axis equal
axis off

camlight;
if drawEdge
    t.EdgeColor = 'black';
else
    t.EdgeColor = 'none';
end
set(t, fphong, 'FaceVertexCData', f);
set(t, fsoft);
set(gca, 'Visible', 'off')
set(gcf, 'Color', [1,1,1])   % white background
colormap(myColorMap())
--------------------------------------------------------------------------------
/utils/strengthOfConnections.m:
--------------------------------------------------------------------------------
function SoC = strengthOfConnections(L, M, p)
% STRENGTHOFCONNECTIONS computes the strength of connections (SoC) between
% adjacent variables given the input operator L and the mass matrix M
%
% SoC = strengthOfConnections(L,M)
% SoC = strengthOfConnections(L,M,p)
%
% Inputs:
%   L  m x m PSD matrix of a differential operator
%      (diagonal positive, off-diagonal mostly negative)
%   M  m x m diagonal matrix of variable masses (e.g. vertex areas)
%   p  (optional, default 1/2) exponent applied to the summed masses,
%      controlling the unit (e.g. 0.5 for the LB operator)
% Outputs:
%   SoC  m x m sparse matrix; for every pair (i,j) with L(i,j) < 0 it holds
%        (M(i,i) + M(j,j))^p / -L(i,j), then symmetrized as (S + S')/2.
%        The diagonal and positive off-diagonal entries of L are ignored.
%        graphKMediods / matrixBellmanFord use these values as edge lengths.

if nargin < 3
    p = 1/2;
end

n = size(L,1);
masses = full(diag(M));
L(1:n+1:end) = 0;          % drop the diagonal
[i, j, w] = find(-L);      % w(k) = -L(i(k), j(k)) for off-diagonal nonzeros

% keep only the negative off-diagonal entries of L (w > 0)
keep = w > 0;
i = i(keep);
j = j(keep);
w = w(keep);

strength = max((masses(i) + masses(j)).^p ./ w, 0);
SoC = sparse(i, j, strength, n, n, numel(w));
SoC = (SoC + SoC') / 2;
--------------------------------------------------------------------------------
/utils/matrixBellmanFord.m:
--------------------------------------------------------------------------------
function [dist, nearestCenter] = matrixBellmanFord(SoC, centersIdx)
% MATRIXBELLMANFORD performs the matrix Bellman-Ford algorithm to compute
% the shortest graph distance from every node to a given set of centers
%
% Reference:
%   Bell, Algebraic Multigrid for Discrete Differential Forms, 2008
%
% [dist, nearestCenter] = matrixBellmanFord(SoC, centersIdx)
%
% Inputs:
%   SoC         |V| x |V| matrix of strength of connections; the nonzero
%               SoC(i,j) is used as the length of the edge i -> j
%   centersIdx  #centers vector of indices of the centers
% Outputs:
%   dist           a |V| vector of the graph distance to the closest center
%                  (Inf for nodes not reachable from any center)
%   nearestCenter  a |V| vector of indices of the nearest center
%                  (0 for nodes not reachable from any center)

nV = size(SoC,1);
dist = Inf(nV,1);
nearestCenter = zeros(nV,1);

% centers are at distance 0 from themselves
dist(centersIdx) = 0;
nearestCenter(centersIdx) = centersIdx;

[i, j, dij] = find(SoC);

% relax all edges at once until no distance improves; when several edges
% improve the same node in one sweep the last write wins, and later sweeps
% keep lowering it until the distances are final
while true
    idx = find(dist(i) + dij < dist(j));
    if isempty(idx)
        break;
    end
    dist(j(idx)) = dist(i(idx)) + dij(idx);
    nearestCenter(j(idx)) = nearestCenter(i(idx));
end
end
--------------------------------------------------------------------------------
/demo.m:
--------------------------------------------------------------------------------
clc; clearvars; close all;
addpath('./utils/')
addpath(genpath('/usr/local/gptoolbox')) % path to gptoolbox (edit for your machine)
rng(0); % algebraicCoarsening draws random initial seeds; fix them for reproducible results

% parameters
fMapSize = 100;  % size of functional map
numNc = 500;     % number of coarse points
lr = 2e-2;       % learning rate
decayIter = 1;   % number of learning rate decays (optional, only fine-tunes the result)
stallIter = 5;   % stalling iterations

% read mesh
[V,F] = readOBJ('./bunny.obj');
V = normalizeUnitArea(V,F);

% construct an initial operator and mass matrix
L = -cotmatrix(V,F);
M = massmatrix(V,F);

% algebraic coarsening
% note: this MATLAB implementation does not implement the sparse gradient in
% appendix A of "Spectral Coarsening of Geometric Operators" [Liu et al. 2019],
% so it is much slower than the C++ implementation.
[Lc, Mc, G, P, Cpt] = algebraicCoarsening(L, M, numNc, ...
    'lr', lr, 'decayIter', decayIter, 'stallIter', stallIter);

% visualize functional map
[~, eVecc] = eigsReal(Lc, Mc, fMapSize);
[~, eVec] = eigsReal(L, M, fMapSize);
fMap = eVecc' * Mc * P * eVec;
figure(1)
plotFMap(fMap)
title('functional map image')

% visualize one eigenfunction on the fine mesh and on the coarse points
% (eigenvectors are defined up to sign, so the colors may be flipped)
eigIdx = 10;
figure(2)
subplot(1,2,1)
plotMesh(V,F,eVec(:,eigIdx))
subplot(1,2,2)
scatter3(V(Cpt,1),V(Cpt,2),V(Cpt,3), 30, eVecc(:,eigIdx), 'filled')
axis equal off
title('visualize one eigenfunction')

% visualize root nodes
figure(3)
plotMesh(V,F)
hold on
scatter3(V(Cpt,1),V(Cpt,2),V(Cpt,3),20,'filled')
title('visualize root nodes')
--------------------------------------------------------------------------------
/utils/graphKMediods.m:
--------------------------------------------------------------------------------
function [clusterAssignIdx] = graphKMediods(SoC, seedsIdx, landmarks, maxIter)
% GRAPHKMEDIODS performs k-medoids clustering on a graph given a strength
% of connection matrix (see strengthOfConnections.m)
%
% clusterAssignIdx = graphKMediods(SoC, seedsIdx)
% clusterAssignIdx = graphKMediods(SoC, seedsIdx, landmarks, maxIter)
%
% Inputs:
%   SoC        |V| x |V| matrix of strength of connections (used as edge lengths)
%   seedsIdx   #seeds vector of indices of initial seeds for the clustering
%   landmarks  (optional, default []) only its length is used: the first
%              numel(landmarks) entries of seedsIdx are kept fixed, so
%              seedsIdx is expected to start with the landmark indices
%   maxIter    (optional, default 20) maximum number of iterations
% Outputs:
%   clusterAssignIdx  a |V| vector; entry v is the index of the seed vertex
%                     of the cluster containing v (0 if v is not reachable
%                     from any seed)

tStart = tic;

if nargin < 3
    landmarks = [];
end
if nargin < 4
    maxIter = 20;
end

numFixed = numel(landmarks);
clusterAssignIdx = zeros(size(SoC,1),1);
[row, col, dij] = find(SoC);
edges = unique(sort([row, col], 2), 'rows');   % undirected edge list

for iter = 1:maxIter
    [~, nearestCenter] = matrixBellmanFord_fast(SoC, seedsIdx, row, col, dij);

    % border vertices: endpoints of edges that connect two different clusters
    onBorder = nearestCenter(edges(:,1)) ~= nearestCenter(edges(:,2));
    borderIdx = unique(reshape(edges(onBorder,:), [], 1));
    distFromBorder = matrixBellmanFord_fast(SoC, borderIdx, row, col, dij);

    % move each non-landmark seed to the member of its cluster that is
    % farthest from the cluster border
    for k = numFixed+1:numel(seedsIdx)
        members = find(nearestCenter == seedsIdx(k));
        [~, farthest] = max(distFromBorder(members));
        seedsIdx(k) = members(farthest);
    end

    % converged once the assignment stops changing
    if isequal(clusterAssignIdx, nearestCenter)
        break;
    end
    clusterAssignIdx = nearestCenter;
end

fprintf('graphKMediods iteration %d\n', iter)
tEnd = toc(tStart);
fprintf('graphKMediods %.4f sec\n', tEnd)
end

function [dist, nearestCenter] = matrixBellmanFord_fast(A, centerIdx, i, j, dij)
% MATRIXBELLMANFORD_FAST performs the matrix Bellman-Ford algorithm to
% compute the graph distance from every node to a given set of seed points
%
% Note:
%   simplified version for graphKMediods: the edge list (i, j, dij) of A is
%   precomputed by the caller, and A itself is only used for its size.
%   For a complete version, please see matrixBellmanFord.m
%
% Reference:
%   Bell, Algebraic Multigrid for Discrete Differential Forms, 2008

nV = size(A,1);
dist = Inf(nV,1);
nearestCenter = zeros(nV,1);

dist(centerIdx) = 0;
nearestCenter(centerIdx) = centerIdx;

while true
    idx = find(dist(i) + dij < dist(j));
    if isempty(idx)
        break;
    end
    dist(j(idx)) = dist(i(idx)) + dij(idx);
    nearestCenter(j(idx)) = nearestCenter(i(idx));
end
end
--------------------------------------------------------------------------------
/utils/myColorMap.m:
--------------------------------------------------------------------------------
function CMap = myColorMap(type)
% MYCOLORMAP returns a smooth colormap obtained by linearly interpolating a
% small palette of anchor colors (500 samples per pair of anchors)
%
% CMap = myColorMap()
% CMap = myColorMap(type)
%
% Inputs:
%   type  (optional, default 'default') one of
%         'default'  dark red -> light tones -> dark blue
%         'red'      dark red -> near white
%         'blue'     near white -> dark blue
%         'heat'     dark red -> near white (longer light end)
%         'gray'     white -> black
% Outputs:
%   CMap  ((#anchors - 1) * 500) x 3 matrix of RGB values in [0,1]

if nargin < 1
    type = 'default';
end

switch type
    case 'default'
        colors = ...
            [103,   0,  31; ...
             178,  24,  43; ...
             214,  96,  77; ...
             244, 165, 130; ...
             253, 219, 199; ...
             209, 229, 240; ...
             146, 197, 222; ...
              67, 147, 195; ...
              33, 102, 172; ...
               5,  48,  97] / 255;
    case 'red'
        colors = flipud( ...
            [255, 247, 236; ...
             254, 232, 200; ...
             253, 212, 158; ...
             253, 187, 132; ...
             252, 141,  89; ...
             239, 101,  72; ...
             215,  48,  31; ...
             179,   0,   0; ...
             127,   0,   0] / 255);
    case 'blue'
        colors = ...
            [255, 247, 251; ...
             236, 231, 242; ...
             208, 209, 230; ...
             166, 189, 219; ...
             116, 169, 207; ...
              54, 144, 192; ...
               5, 112, 176; ...
               4,  90, 141; ...
               2,  56,  88] / 255;
    % (disabled 'test' palette: [239,243,255; 198,219,239; 158,202,225;
    %  107,174,214; 66,146,198] / 255)
    case 'heat'
        colors = flipud( ...
            [255, 247, 236; ...
             255, 247, 236; ...
             255, 247, 236; ...
             255, 242, 224; ...
             254, 237, 212; ...
             254, 232, 200; ...
             253, 222, 179; ...
             253, 212, 158; ...
             253, 200, 145; ...
             253, 187, 132; ...
             252, 141,  89; ...
             239, 101,  72; ...
             215,  48,  31; ...
             179,   0,   0; ...
             127,   0,   0] / 255);
    case 'gray'
        levels = linspace(255, 0, 11)';
        colors = [levels, levels, levels] / 255;
    otherwise
        error('myColorMap:unknownType', 'Unknown colormap type: %s', type);
end

CMap = interpolateColors(colors, 500);
end

function CMap = interpolateColors(colors, stepSize)
% INTERPOLATECOLORS linearly interpolates each pair of consecutive anchor
% colors with stepSize samples (the shared endpoints appear twice)
numSegments = size(colors,1) - 1;
CMap = zeros(numSegments*stepSize, 3);
for s = 1:numSegments
    rows = (s-1)*stepSize + (1:stepSize);
    for ch = 1:3
        CMap(rows, ch) = linspace(colors(s,ch), colors(s+1,ch), stepSize);
    end
end
end
--------------------------------------------------------------------------------
/utils/algebraicCoarsening.m:
--------------------------------------------------------------------------------
function [Lc, Mc, G, P, Cpt, errorHis] = algebraicCoarsening(L, M, numNc, varargin)
% ALGEBRAICCOARSENING coarsens an operator L (with mass matrix M) onto a
% set of coarse vertices and optimizes the coarse operator Lc = G'*L*G so
% that it matches the action of L on its low-frequency eigenvectors
% ("Spectral Coarsening of Geometric Operators" [Liu et al. 2019])
%
% [Lc, Mc, G, P, Cpt, errorHis] = algebraicCoarsening(L, M, numNc, ...)
%
% Inputs:
%   L      n x n operator (demo.m uses -cotmatrix(V,F))
%   M      n x n diagonal mass matrix
%   numNc  number of coarse vertices (seeds) to select
%   Optional name/value pairs (defaults in parentheses):
%     'lr'         step size of the NADAM updates (2e-2)
%     'decayIter'  how many times the step size may be halved when the
%                  energy stalls (2)
%     'stallIter'  number of non-improving iterations tolerated before
%                  halving the step size or stopping (5)
%     'maxIter'    maximum number of iterations (1000)
%     'landmarks'  indices of vertices that must be kept as coarse
%                  vertices ([])
%     'numEig'     number of eigenvectors used as test vectors
%                  (round(numNc/3))
% Outputs:
%   Lc        nc x nc coarse operator G'*L*G (symmetrized), nc = numel(Cpt)
%   Mc        nc x nc coarse mass matrix K*M*K'
%   G         n x nc sparse matrix with the lowest energy seen
%   P         nc x n selection matrix (P*x == x(Cpt))
%   Cpt       nc x 1 indices of the coarse vertices (cluster seeds)
%   errorHis  energy of the iterate evaluated at each iteration
%
% Note: the initial seeds are drawn with randperm, so results depend on
% the state of the random number generator.

%% default parameter values
lr = 2e-2;                 % learning rate (step size) for the NADAM updates
lrDecayIter = 2;           % maximum number of learning rate decays
maxIter = 1000;            % maximum number of iterations (usually not reached)
stallIter = 5;             % non-improving iterations tolerated
landmarks = [];            % vertices you want to keep in the coarsening
numEig = round(numNc / 3); % number of eigenvectors in use

% process user input parameters (public name -> local variable name)
params_to_variables = containers.Map(...
    {'lr', 'decayIter', 'stallIter', 'maxIter', 'landmarks','numEig'}, ...
    {'lr', 'lrDecayIter', 'stallIter', 'maxIter', 'landmarks','numEig'});
v = 1;
while v <= numel(varargin)
    param_name = varargin{v};
    if isKey(params_to_variables,param_name)
        assert(v+1<=numel(varargin), 'Missing value for parameter: %s', param_name);
        v = v+1;
        % Trick: use feval on anonymous function to use assignin to this workspace
        feval(@()assignin('caller',params_to_variables(param_name),varargin{v}));
    else
        error('Unsupported parameter: %s',varargin{v});
    end
    v=v+1;
end

nV = size(L,1);
landmarks = reshape(landmarks, 1, []);  % accept row or column landmark lists
assert(numNc >= numel(landmarks) && numNc <= nV, ...
    'numNc must lie between numel(landmarks) and size(L,1)');

%% Combinatorial Coarsening
invM = diag(diag(M).^-1);
SoC = strengthOfConnections(L, M);

% initial seeds: the landmarks first, then randomly shuffled other vertices
others = 1:nV;
others(landmarks) = [];
others = others(randperm(numel(others)));
seedsIdx = [landmarks, others];
seedsIdx = seedsIdx(1:numNc);
clusterAssignment = graphKMediods(SoC, seedsIdx, landmarks);

% P selects the coarse vertices; K assigns every vertex to its cluster
[Cpt, ~, idxToCpt] = unique(clusterAssignment);
assert(all(Cpt > 0), 'algebraicCoarsening:unreachable', ...
    'some vertices are not reachable from any cluster seed through SoC edges');
nc = numel(Cpt);
P = sparse(1:nc, Cpt, ones(nc,1), nc, nV, nc);
K = sparse(idxToCpt, 1:nV, ones(nV,1), nc, nV, nV);

%% Operator Optimization
Mc = K * M * K'; % coarse mass matrix
invMc = diag(diag(Mc).^-1);

% sparsity pattern of G: H(v,c) ~= 0 iff the cluster containing v is
% connected to cluster c through a nonzero of L
S = double(L ~= 0);
J = double(K * S * K' ~= 0);
H = double(K' * J ~= 0);

% eigenvectors as test vectors
[~, U] = eigsReal(L, M, numEig);

G = K'; % initial G: each vertex mapped to its own cluster

% energy/gradient precomputation (see objFuncTemplate)
A = P * invM * L * U;
B = P * U;
BB = B*B';
AB = A*B';

% Linear constraint G*P*U0 = U0 on the nonzeros g = G(ZIdx) of G, written
% as Aproj*g = U0. Each iterate is projected orthogonally onto it:
%   g <- g - Aproj' * ((Aproj*Aproj') \ (Aproj*g - U0))
U0 = U(:,1);
ZIdx = find(H);
[rIdx, cIdx] = find(H);
Z = sparse(ZIdx, 1:numel(ZIdx), ones(numel(ZIdx),1), size(G,1)*size(G,2), numel(ZIdx));
Aproj = kron((P*U0)', speye(size(G,1))) * Z;
AAt = Aproj * Aproj';

% stopping criteria (stall)
objValOld = Inf;
Gbest = G;
stop = 0;
lrDecayCount = 0;
decayRatio = 0.5;

% NADAM
timeStep = 0;
beta1 = 0.9;
beta2 = 0.9;
epsilon = 1e-8;
mt = sparse(size(G,1), size(G,2));
nt = sparse(size(G,1), size(G,2));

% function for computing objective function
objFunc = @(params) objFuncTemplate(params, A, B, L, Mc, invMc);
errorHis = [];

for iter = 1:maxIter
    % energy of the current iterate (kept so that Gbest is the iterate
    % that actually achieved the recorded energy)
    Gcur = G;
    objVal = objFunc(Gcur);
    errorHis = [errorHis objVal]; %#ok<AGROW>

    % compute gradient
    % Note that MATLAB implementation DOES NOT exploit the sparse gradient
    % because looping over the indices is slow in MATLAB. Therefore, the
    % MATLAB gradient computation is much slower
    LG = L*G;
    BBGLGinvMc = BB*(G'*(LG*invMc));
    grad = LG * (- AB' - AB + BBGLGinvMc + BBGLGinvMc');
    grad = grad .* H; % restrict to the sparsity pattern of G

    % NADAM
    timeStep = timeStep + 1;
    mt = beta1*mt + (1-beta1)*grad;
    nt = beta2*nt + (1-beta2)*(grad.^2);
    mt_hat = mt / (1-beta1^timeStep);
    nt_hat = nt / (1-beta2^timeStep);
    Ngrad = 1./(sqrt(nt_hat) + epsilon) .* (beta1*mt_hat + (1-beta1)*grad/(1-beta1^timeStep));

    % gradient step, then projection onto the constraint G*P*U0 = U0
    G = G - lr * Ngrad;
    g = G(ZIdx);
    g = g - Aproj' * (AAt \ (Aproj*g - U0));
    G = sparse(rIdx, cIdx, g, size(H,1), size(H,2));

    % print progress
    if mod(iter,10) == 0
        fprintf('iter %i, cost %f\n', iter, objVal);
    end

    % stopping criteria: once the energy has not improved for more than
    % stallIter iterations, halve the learning rate (up to lrDecayIter
    % times) or stop
    if objValOld < objVal
        stop = stop + 1;
        if stop > stallIter
            if lrDecayCount < lrDecayIter
                lr = lr * decayRatio;
                stop = 0;
                lrDecayCount = lrDecayCount + 1;
                fprintf('decrease learning rate to: %f\n', lr);
            else
                fprintf('iter %i, cost %f\n', iter, objValOld);
                break;
            end
        end
    else
        stop = 0;
        objValOld = objVal;
        Gbest = Gcur;
    end
end
G = Gbest;
Lc = G' * L * G;
Lc = (Lc + Lc')/2; % sometimes a slightly non-symmetric matrix fails 'eigs' due to numerical issues
--------------------------------------------------------------------------------