├── DB
    └── MNIST_vs_USPS.mat
├── utils
    ├── recall_precision.m
    ├── normalize1.m
    ├── gen_color.m
    ├── compactbit.m
    ├── distMat.m
    ├── area_RP.m
    ├── gen_marker.m
    ├── recall_precision5.m
    ├── EuDist2.m
    ├── hammingDist.m
    └── L2_distance.m
├── histogram_label.m
├── MGramSchmidt.m
├── W1_obj.m
├── fea_trans.m
├── cl.m
├── README.md
├── construct_dataset.m
├── demo.m
├── main_demo.m
└── OptStiefelGBB.m


/DB/MNIST_vs_USPS.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuxianghuang1/PWCF/HEAD/DB/MNIST_vs_USPS.mat
--------------------------------------------------------------------------------
/utils/recall_precision.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuxianghuang1/PWCF/HEAD/utils/recall_precision.m
--------------------------------------------------------------------------------
/utils/normalize1.m:
--------------------------------------------------------------------------------
function [X] = normalize1(X)
% NORMALIZE1  Scale every row of X to unit Euclidean (2-) norm.
%   Rows whose norm is exactly zero are returned unchanged.
for r = 1:size(X, 1)
    rowNorm = norm(X(r,:));
    if rowNorm == 0
        continue;   % all-zero row: leave as is (avoids 0/0)
    end
    X(r,:) = X(r,:) ./ rowNorm;
end
end
--------------------------------------------------------------------------------
/histogram_label.m:
--------------------------------------------------------------------------------
function [h] = histogram_label(l,classnum,k)
% HISTOGRAM_LABEL  Normalised label histogram.
%   h(c) = (number of entries of l equal to c) / k,  for c = 1..classnum.
%   Entries of l outside 1..classnum (e.g. zeros used as a mask) are not
%   counted. Returns a classnum-by-1 column vector.
h = zeros(classnum, 1);
for c = 1:classnum
    h(c) = sum(l(:) == c);
end
h = h / k;
end

--------------------------------------------------------------------------------
/MGramSchmidt.m:
--------------------------------------------------------------------------------
function V = MGramSchmidt(V)
% MGRAMSCHMIDT  Orthonormalise the columns of V (modified Gram-Schmidt).
%   Each column has the projections onto all earlier (already normalised)
%   columns removed one at a time, and is then scaled to unit norm.
numCols = size(V, 2);
for col = 1:numCols
    for prev = 1:col-1
        V(:,col) = V(:,col) - proj(V(:,prev), V(:,col));
    end
    V(:,col) = V(:,col) / norm(V(:,col));
end
end


% PROJ  Projection of v onto the direction of u.
function v = proj(u, v)
v = (dot(v,u) / dot(u,u)) * u;
end
--------------------------------------------------------------------------------
/utils/gen_color.m:
--------------------------------------------------------------------------------
function color=gen_color(curve_idx)
% GEN_COLOR  Plot colour for curve number curve_idx; cycles through a
% fixed palette of 9 entries (colour letters/names or RGB triples).
palette = {'b', 'r', 'g', 'm', 'c', 'black', ...
           [0.7 0 0.7 ], [0 0.7 0.7 ], [ 0.83 0.33 0]};
color = palette{mod(curve_idx - 1, numel(palette)) + 1};
end
--------------------------------------------------------------------------------
/W1_obj.m:
--------------------------------------------------------------------------------
function [ Ft, Gt] = W1_obj(W,F,D,X,L,Xs,Xt,Bs,Bt,paras)
% W1_OBJ  Objective value and gradient with respect to W, in the
% [F, G] = fun(X, ...) form used by OptStiefelGBB.
%   F, D are supplied by the caller and used as-is (not recomputed from W).
%   Ft = F + theta1*||Bs - W'*Xs'||_F^2 + theta2*||Bt - W'*Xt'||_F^2
%          + lambda1*trace(W'*X'*L*X*W)
%   Gt = 2*D plus the gradients of the three quadratic terms above.
t1 = paras.theta1;
t2 = paras.theta2;
l1 = paras.lambda1;
srcFit = norm(Bs - W'*Xs', 'fro')^2;
tgtFit = norm(Bt - W'*Xt', 'fro')^2;
graphTerm = trace(W'*X'*L*X*W);
Ft = F + t1*srcFit + t2*tgtFit + l1*graphTerm;
Gt = 2*D - 2*t1*Xs'*Bs' + 2*t1*(Xs'*Xs)*W - 2*t2*Xt'*Bt' + 2*t2*(Xt'*Xt)*W + 2*l1*X'*L*X*W;
end
--------------------------------------------------------------------------------
/utils/compactbit.m:
--------------------------------------------------------------------------------
function cb = compactbit(b)
% COMPACTBIT  Pack a binary matrix into uint8 bytes, 8 bits per byte.
%   b  : nSamples-by-nbits array of 0/1 (or logical) values.
%   cb : nSamples-by-ceil(nbits/8) uint8 array; bit j of a row is written
%        to byte ceil(j/8), at bit position mod(j-1,8)+1.
[numRows, numBits] = size(b);
cb = zeros([numRows, ceil(numBits/8)], 'uint8');
for bitIdx = 1:numBits
    byteIdx = floor((bitIdx - 1) / 8) + 1;
    cb(:, byteIdx) = bitset(cb(:, byteIdx), mod(bitIdx - 1, 8) + 1, b(:, bitIdx));
end
end
--------------------------------------------------------------------------------
/utils/distMat.m:
--------------------------------------------------------------------------------
function D=distMat(P1, P2)
% DISTMAT  Euclidean distances between row vectors.
%   D = distMat(P1, P2) : size(P1,1)-by-size(P2,1) distances between the
%                         rows of P1 and the rows of P2.
%   D = distMat(P1)     : distances among the rows of P1.
%   Slightly negative squared distances caused by round-off give 0,
%   because only the real part of the square root is kept.
if nargin == 2
    P1 = double(P1);
    P2 = double(P2);

    X1 = repmat(sum(P1.^2,2), [1 size(P2,1)]);
    X2 = repmat(sum(P2.^2,2), [1 size(P1,1)]);
    R = P1*P2';
    D = real(sqrt(X1 + X2' - 2*R));
else
    P1 = double(P1);

    % each vector is one row
    X1 = repmat(sum(P1.^2,2), [1 size(P1,1)]);
    R = P1*P1';
    D = X1 + X1' - 2*R;
    D = real(sqrt(D));
end
end


--------------------------------------------------------------------------------
/utils/area_RP.m:
--------------------------------------------------------------------------------
function [value] = area_RP(recall, precision)
% AREA_RP  Area under a precision-recall curve (trapezoidal rule).
%   recall, precision : vectors of equal length (row or column).
%   If the curve does not start at recall 0, the point (0, precision(1))
%   is prepended. Duplicate recall values are collapsed with UNIQUE.
%   Invalid precision values (NaN or negative) are replaced by the next
%   valid value to their right; an invalid last value is treated as 0
%   (previously this indexed past the end and errored).
%   Empty input gives an area of 0.
if isempty(recall)
    value = 0;
    return;
end
recall = recall(:)';        % accept row or column vectors
precision = precision(:)';

if recall(1) ~= 0
    xx = [0, recall];
    yy = [precision(1), precision];
else
    xx = recall;
    yy = precision;
end
[xx, index] = unique(xx);
yy = yy(index);

% back-fill invalid precision values from the right
if ~(yy(end) >= 0)
    yy(end) = 0;
end
for ic = numel(yy)-1:-1:1
    if ~(yy(ic) >= 0)
        yy(ic) = yy(ic+1);
    end
end

% trapezoidal rule
value = sum(0.5 * diff(xx) .* (yy(1:end-1) + yy(2:end)));
end
--------------------------------------------------------------------------------
/utils/gen_marker.m:
--------------------------------------------------------------------------------

function marker=gen_marker(curve_idx)
% GEN_MARKER  Plot marker for curve number curve_idx; cycles through a
% fixed list of 19 marker symbols.
markers = {'s', 'o', 'd', '^', '*', 'v', 'x', '+', '>', '<', 'p', 'h', ...
           'x', 'o', '*', 'd', 'p', 's', 'h'};
marker = markers{mod(curve_idx - 1, numel(markers)) + 1};
end

--------------------------------------------------------------------------------
/utils/recall_precision5.m:
--------------------------------------------------------------------------------
function [recall, precision] = recall_precision5(Wtrue, Dhat, pos)
% RECALL_PRECISION5  Recall and precision after retrieving the top-g items.
%
% Input:
%    Wtrue = ground-truth relevance [Ntest * Ndataset] (nonzero = relevant)
%    Dhat  = estimated distances [Ntest * Ndataset] (smaller = closer)
%    pos   = vector of cut-offs g (number of items retrieved per query)
%
% Output (one entry per element of pos). Each query's ranking is its row
% of Dhat sorted in ascending order:
%
%                # relevant items in the top-g lists of all queries
%   recall    = ---------------------------------------------------
%                # relevant items in total
%
%                # relevant items in the top-g lists of all queries
%   precision = ---------------------------------------------------
%                Ntest * g
%
% Cut-offs larger than Ndataset are clamped to Ndataset (previously this
% indexed past the last column and errored).
[numQueries, numDb] = size(Wtrue);
for i = 1:size(Dhat,1)
    [~, order] = sort(Dhat(i,:), 'ascend');
    Wtrue(i,:) = Wtrue(i,order);
end
total_good_pairs = sum(Wtrue(:));
% good_upto(g) = number of relevant items within the first g columns
good_upto = cumsum(sum(Wtrue, 1));

grid = min(pos, numDb);
recall = zeros(1, numel(grid));
precision = zeros(1, numel(grid));
for i = 1:numel(grid)
    g = grid(i);
    if g < 1
        retrieved_good_pairs = 0;
    else
        retrieved_good_pairs = good_upto(g);
    end
    recall(i) = retrieved_good_pairs / total_good_pairs;
    precision(i) = retrieved_good_pairs / (numQueries * g);
end
end




--------------------------------------------------------------------------------
/utils/EuDist2.m:
--------------------------------------------------------------------------------

function D = EuDist2(fea_a,fea_b,bSqrt)
%EUDIST2 Efficiently Compute the Euclidean Distance Matrix by Exploring the
%Matlab matrix operations.
%
%   D = EuDist(fea_a,fea_b)
%   fea_a:    nSample_a * nFeature
%   fea_b:    nSample_b * nFeature
%   D:      nSample_a * nSample_a
%       or  nSample_a * nSample_b
%   bSqrt (optional, default 1): 1 returns distances, 0 squared distances.
%
%    Examples:
%
%       a = rand(500,10);
%       b = rand(1000,10);
%
%       A = EuDist2(a); % A: 500*500
%       D = EuDist2(a,b); % D: 500*1000
%
%   version 2.1 --November/2011
%   version 2.0 --May/2009
%   version 1.0 --November/2005
%
%   Written by Deng Cai (dengcai AT gmail.com)


if ~exist('bSqrt','var')
    bSqrt = 1;
end

if (~exist('fea_b','var')) || isempty(fea_b)
    aa = sum(fea_a.*fea_a,2);
    ab = fea_a*fea_a';

    if issparse(aa)
        aa = full(aa);
    end

    D = bsxfun(@plus,aa,aa') - 2*ab;
    D(D<0) = 0;     % clamp round-off negatives before sqrt
    if bSqrt
        D = sqrt(D);
    end
    D = max(D,D');  % enforce exact symmetry
else
    aa = sum(fea_a.*fea_a,2);
    bb = sum(fea_b.*fea_b,2);
    ab = fea_a*fea_b';

    if issparse(aa)
        aa = full(aa);
        bb = full(bb);
    end

    D = bsxfun(@plus,aa,bb') - 2*ab;
    D(D<0) = 0;
    if bSqrt
        D = sqrt(D);
    end
end
end

--------------------------------------------------------------------------------
/utils/hammingDist.m:
--------------------------------------------------------------------------------
function Dh=hammingDist(B1, B2)
%
% Compute hamming distance between two sets of samples (B1, B2)
%
% Dh=hammingDist(B1, B2);
%
% Input
%    B1, B2: compact bit vectors, one byte per element (e.g. the uint8
%            output of compactbit). Each datapoint is one row.
%    size(B1) = [ndatapoints1, nwords]
%    size(B2) = [ndatapoints2, nwords]
%    It is faster if ndatapoints1 < ndatapoints2
%
% Output
%    Dh = hamming distance (uint16).
%    size(Dh) = [ndatapoints1, ndatapoints2]

% example query
% Dhamm = hammingDist(B2, B1);
% this will give the same result as:
%    Dhamm = distMat(U2>0, U1>0).^2;
% the size of the distance matrix is:
%    size(Dhamm) = [Ntest x Ntraining]

% look-up table (row vector): bit_in_char(v+1) = number of set bits in byte v
bit_in_char = uint16(sum(dec2bin(0:255) == '1', 2)).';

n1 = size(B1,1);
[n2, nwords] = size(B2);

Dh = zeros([n1 n2], 'uint16');
for j = 1:n1
    for n = 1:nwords
        y = bitxor(B1(j,n), B2(:,n));
        % Convert to double before adding 1: integer arithmetic saturates,
        % so uint8(255)+1 stays 255 and a byte whose 8 bits all differ
        % would be counted as 7 instead of 8.
        Dh(j,:) = Dh(j,:) + bit_in_char(double(y) + 1);
    end
end
end
--------------------------------------------------------------------------------
/fea_trans.m:
--------------------------------------------------------------------------------
function [HS,HT,H] = fea_trans(DT,DS,YT,YS,paras)
% FEA_TRANS  k-nearest-neighbour label histograms for source and target.
%   DT, DS : target / source features, one sample per COLUMN.
%   YT, YS : target / source label vectors (row or column), labels 1..C.
%   paras  : struct; paras.k is the number of neighbours.
%   HS     : C-by-ns; column i is histogram_label of the labels of the k
%            nearest source samples to source sample i (itself excluded).
%   HT     : C-by-nt; the same for the target domain using YT.
%   H      : [HS, HT].
%   C is the number of distinct source labels. A warning is issued when
%   the target has a different number of distinct labels.
k = paras.k;
clsnum = length(unique(YS));
if clsnum ~= length(unique(YT))
    warning('fea_trans:classCount', ...
        'Source and target domains have a different number of classes; using the source class count (%d).', clsnum);
end
HS = knn_label_hist(DS, YS, k, clsnum);
HT = knn_label_hist(DT, YT, k, clsnum);
H = [HS, HT];
end

% KNN_LABEL_HIST  Label histogram of the k nearest neighbours of every column of Dat.
function Hist = knn_label_hist(Dat, labels, k, clsnum)
labels = labels(:)';                     % accept row or column labels
Dist = EuDist2(Dat', Dat');
Dist(logical(eye(size(Dist)))) = Inf;    % a sample is never its own neighbour
n = size(Dist, 1);
Hist = zeros(clsnum, n);
for i = 1:n
    [~, order] = sort(Dist(i,:), 'ascend');
    Hist(:, i) = histogram_label(labels(order(1:k)), clsnum, k);
end
end

--------------------------------------------------------------------------------
/cl.m:
--------------------------------------------------------------------------------
function [ L] = cl(H,paras)
% CL  Laplacian of a k-nearest-neighbour heat-kernel graph.
%   H     : features, one sample per COLUMN (distances via L2_distance).
%   paras : struct with field sigma (heat-kernel width) and, optionally,
%           graph_k, the number of neighbours kept per sample (default
%           500, the value previously hard-coded here).
%   L     : sparse Laplacian D - G. G holds exp(-g/(2*sigma^2)) on the
%           retained (non-zero) entries and D is the diagonal degree
%           matrix. NaN/Inf entries of L are set to 0.
if isfield(paras, 'graph_k')
    k = paras.graph_k;
else
    k = 500;
end
sigma = paras.sigma;

% Construct neighborhood graph
disp('Constructing neighborhood graph...');

G = L2_distance(H, H);

% Column i of ind lists samples by increasing distance to sample i. Row i
% keeps its first k+1 entries (normally sample i itself plus k neighbours);
% all other entries are zeroed.
[~, ind] = sort(G);
for i = 1:size(G, 1)
    G(i, ind((2 + k):end, i)) = 0;
end
G = sparse(double(G));
G = max(G, G');             % Make sure distance matrix is symmetric
G = G .^ 2;
G = G ./ max(max(G));

% NOTE(review): the same symmetrise/square/normalise step is applied a
% second time here. It belonged to an alternative graph construction that
% has been removed, so the kernel effectively sees normalised distance^4.
% Kept as-is so results with the existing parameters (e.g. sigma in
% demo.m) do not change — confirm whether it is intended.
G = sparse(double(G));
G = max(G, G');
G = G .^ 2;
G = G ./ max(max(G));

% Compute weights: heat kernel on the retained (non-zero) entries only
disp('Computing weight matrices...');
G(G ~= 0) = exp(-G(G ~= 0) / (2 * sigma ^ 2));
D = diag(sum(G, 2));

% Compute Laplacian
L = D - G;
L(isnan(L)) = 0;
L(isinf(L)) = 0;

end
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Probability Weighted Compact Feature for Domain Adaptive Retrieval
2 | Published in CVPR 2020
3 | Contact: huangfuxiang@cqu.edu.cn
4 | 
5 | Environment: MATLAB R2017
6 | 
7 | To run the model: execute main_demo.m
8 | 
9 | Additional datasets can be downloaded from https://pan.baidu.com/s/1EVlYCz51AyDnh5y7PJ5W_Q?pwd=qyrv
10 | 
11 | 
12 | *If you want to cite the experimental results, please pay attention to the experimental details in the paper. For the handwritten digit datasets (MNIST and USPS), following common transfer-learning practice, we select 2000 images from MNIST as the source domain and 1800 images from USPS as the target domain. In addition, for each dataset we randomly select 500 target images as the test set (i.e., queries) and use the remaining images as the training set. Note that, to investigate more samples, we use more datasets and different settings in another paper, published in TNNLS 2021, i.e., Domain Adaptation Preconceived Hashing for Unconstrained Visual Retrieval. Specifically, we use 60000 images from MNIST as the source domain and 10000 images from USPS as the target domain. For each dataset, we randomly select 10% of the target images as the test set (i.e., queries) and use the remaining images as the training set.
13 | 
14 | Cite: If you find this code useful in your research, please cite
15 | 
16 | ```bibtex
17 | @inproceedings{huang2020PWCF,
18 |   title={Probability Weighted Compact Feature for Domain Adaptive Retrieval},
19 |   author={Huang, Fuxiang and Zhang, Lei and Yang, Yang and Zhou, Xichuan},
20 |   booktitle={CVPR},
21 |   pages={9579-9588},
22 |   year={2020}
23 | }
24 | 
25 | 
26 | @article{huang2021domain,
27 |   author={Huang, Fuxiang and Zhang, Lei and Gao, Xinbo},
28 |   journal={IEEE Transactions on Neural Networks and Learning Systems},
29 |   title={Domain Adaptation Preconceived Hashing for Unconstrained Visual Retrieval},
30 |   year={2021},
31 |   pages={1-15},
32 |   doi={10.1109/TNNLS.2021.3071127}
33 | }
34 | ```
35 | 
--------------------------------------------------------------------------------
/utils/L2_distance.m:
--------------------------------------------------------------------------------
function d = L2_distance(a, b)
% L2_DISTANCE - computes Euclidean distance matrix
%
% E = L2_distance(A,B)
%
%    A - (DxM) matrix
%    B - (DxN) matrix
%
% Returns:
%    E - (MxN) Euclidean distances between the columns of A and of B
%
%
% Description :
%    This fully vectorized (VERY FAST!) m-file computes the
%    Euclidean distance between two vectors by:
%
%                 ||A-B|| = sqrt ( ||A||^2 + ||B||^2 - 2*A.B )
%
% Example :
%    A = rand(400,100); B = rand(400,200);
%    d = L2_distance(A,B);

% Author   : Roland Bunschoten
%            University of Amsterdam
%            Intelligent Autonomous Systems (IAS) group
%            Kruislaan 403  1098 SJ Amsterdam
%            tel.(+31)20-5257524
%            bunschot@wins.uva.nl
% Last Rev : Wed Oct 20 08:58:08 MET DST 1999
% Tested   : PC Matlab v5.2 and Solaris Matlab v5.3

% Copyright notice: You are free to modify, extend and distribute
%    this code granted that the author of the original code is
%    mentioned as the original author of the code.

% Fixed by JBT (3/18/00) to work for 1-dimensional vectors
% and to warn for imaginary numbers.  Also ensures that
% output is all real, and allows the option of forcing diagonals to
% be zero.
%
%

% This file is part of the Matlab Toolbox for Dimensionality Reduction v0.7.2b.
% The toolbox can be obtained from http://homepage.tudelft.nl/19j49
% You are free to use, change, or redistribute this code in any way you
% want for non-commercial purposes. However, it is appreciated if you
% maintain the name of the original author.
%
% (C) Laurens van der Maaten, 2010
% University California, San Diego / Delft University of Technology


if nargin < 2
    error('Not enough input arguments');
end
if size(a, 1) ~= size(b, 1)
    error('A and B should be of same dimensionality');
end
if ~isreal(a) || ~isreal(b)
    warning('Computing distance table using imaginary inputs. Results may be off.');
end

% Pad 1-D inputs with a zero row so the matrix products below keep their shape
if size(a, 1) == 1
    a = [a; zeros(1, size(a, 2))];
    b = [b; zeros(1, size(b, 2))];
end

% Compute distance table
d = sqrt(bsxfun(@plus, sum(a .* a)', bsxfun(@minus, sum(b .* b), 2 * a' * b)));

% Make sure result is real (round-off can make squared distances slightly negative)
d = real(d);
end

--------------------------------------------------------------------------------
/construct_dataset.m:
--------------------------------------------------------------------------------
function [exp_data]=construct_dataset(db_name,num_test,topNum)
% CONSTRUCT_DATASET  Load a source/target dataset pair and split the target.
%   db_name  : 'VOC2007&Caltech101', 'Caltech256&ImageNet' or 'MNIST&USPS'
%              (.mat files are loaded from ./DB/). Other names raise an error.
%   num_test : number of target samples drawn at random as queries.
%   topNum   : unused; kept for call compatibility. It only fed a
%              distance-based ground truth that was never returned.
%
%   exp_data fields:
%     Xs    - source features (one sample per row), L2-normalised, mean-centred
%     train - target training features, same preprocessing
%     test  - target query features, same preprocessing
%     ys    - source labels (column)
%     yt    - true labels of the target training samples (column)
%     ytnew - 1-NN pseudo-labels of the target training samples, predicted
%             from the labelled source domain
%     WTT   - num_test-by-ns logical, WTT(i,j) = (label of query i == ys(j))
%   The centring mean is computed over all source and target rows
%   (queries included).
addpath('./DB/');
%% choose and load dataset
if strcmp(db_name,'VOC2007&Caltech101')
    load VOC2007 %source data
    Xs = double(data(:,1:end-1));
    Xs = normalize1(Xs);
    ys = data(:,end);
    clear data;

    load Caltech101 %target data
    Xt = double(data(:,1:end-1));
    Xt = normalize1(Xt);
    yt = data(:,end);
    clear data;

elseif strcmp(db_name, 'Caltech256&ImageNet')
    load dense_imagenet_decaf7_subsampled; %target data
    Xt = normalize1(fts);
    Xt = double(Xt);
    yt = double(labels);
    clear fts; clear labels;

    load dense_caltech256_decaf7_subsampled; %source data
    Xs = normalize1(fts);
    Xs = double(Xs);
    ys = double(labels);
    clear fts; clear labels;
elseif strcmp(db_name,'MNIST&USPS')
    load MNIST_vs_USPS.mat %source and target data
    Xs = double(X_src)';
    Xs = normalize1(Xs);
    ys = double(Y_src); %vector label
    Xt = double(X_tar)';
    Xt = normalize1(Xt);
    yt = double(Y_tar); %vector label
    clear X_src; clear Y_src; clear X_tar; clear Y_tar;
else
    error('construct_dataset:unknownDataset', 'Unknown dataset name: %s', db_name);
end
ys = ys(:);
yt = yt(:);

% random split of the target domain into queries and training samples
R = randperm(size(Xt, 1));
test = Xt(R(1:num_test),:);
ytest = yt(R(1:num_test));
R(1:num_test) = [];
train = Xt(R,:);
yt = yt(R);

% 1-nearest-neighbour (Euclidean) pseudo-labels for the target training
% samples from the labelled source domain. This matches the default
% knnclassify(train, Xs, ys) call used before; knnclassify has been
% removed from recent MATLAB releases.
[~, nn] = min(distMat(train, Xs), [], 2);
ytnew = ys(nn);
acc = mean(ytnew == yt);
fprintf('1-NN pseudo-label accuracy on the target training set: %f\n', acc);

% label-based ground truth: WTT(i,j) is true when query i and source
% sample j share a label
WTT = bsxfun(@eq, ytest, ys');

% centre every split with the mean of all source and target samples
samplemean = mean([Xs; Xt], 1);
Xs = bsxfun(@minus, Xs, samplemean);
train = bsxfun(@minus, train, samplemean);
test = bsxfun(@minus, test, samplemean);


exp_data.Xs = Xs;
exp_data.test = test;
exp_data.train = train;
exp_data.ytnew = ytnew;
exp_data.yt = yt;
exp_data.ys = ys;
exp_data.WTT = WTT;
end

--------------------------------------------------------------------------------
/demo.m:
--------------------------------------------------------------------------------
function [recall, precision, mAP, rec, pre, retrieved_list] = demo(exp_data, param, method)
% DEMO  Learn PWCF hash functions and evaluate retrieval.
%   exp_data : struct from construct_dataset (Xs, train, test, ys, ytnew, WTT).
%   param    : struct with fields r (code length in bits), pos (cut-offs for
%              recall/precision@g) and choice (evaluation mode; only
%              'evaluation_PR_MAP' is supported).
%   method   : method name (currently unused).
%   Outputs  : recall/precision curve from recall_precision and its area
%              (mAP, via area_RP); rec/pre at the cut-offs in param.pos
%              (recall_precision5); retrieved_list is empty.
%   The hashed source samples Xs form the retrieval database and
%   exp_data.test the queries; exp_data.WTT is the ground truth.

WtrueTestTraining = exp_data.WTT;
pos = param.pos;
r = param.r;
Xs = exp_data.Xs;
Xt = exp_data.train;
ys = exp_data.ys;
yt = exp_data.ytnew;      % pseudo-labels of the target training samples
test = exp_data.test;

%% set parameters
setting.record = 0;
setting.mxitr = 10;
setting.xtol = 1e-5;
setting.gtol = 1e-5;
setting.ftol = 1e-8;
paras.k = 50;
paras.sigma = 0.4;
paras.m = 0.3;
paras.theta1 = 10;
paras.theta2 = 10;
paras.lambda1 = 1;
paras.lambda2 = 10;
paras.lambda3 = 1e4;
paras.max_iter = 50;
[paras.nt,paras.d] = size(Xt);
[paras.ns,paras.d] = size(Xs);

%% learning
fprintf('......%s start ......\n\n', 'PWCF');
X = [Xs; Xt];
X1 = [Xs; Xt; Xt];        % triplet anchors: source, target, target (cross-domain)
N = size(X1,1);

% initialise W with the top-r eigenvectors of X'*X
[vec,val] = eig(X'*X);
[~,Idx] = sort(diag(val),'descend');
W = vec(:,Idx(1:r));
clear Idx vec val;

%% construct triplets: one positive and one negative per row of X1
% source anchors / source candidates
[Ips, Ins] = mine_pairs(EuDist2(Xs,Xs), bsxfun(@eq, ys, ys'));
% target anchors / target candidates (pseudo-labels)
[Ipt, Int] = mine_pairs(EuDist2(Xt,Xt), bsxfun(@eq, yt, yt'));
% target anchors / source candidates, compared through label histograms
[HS,HT,H] = fea_trans(Xt',Xs',yt,ys,paras);
[Ipx, Inx] = mine_pairs(EuDist2(HT',HS'), bsxfun(@eq, yt, ys'));
Xp = [Xs(Ips,:); Xt(Ipt,:); Xs(Ipx,:)];   % positives, aligned with X1
Xn = [Xs(Ins,:); Xt(Int,:); Xs(Inx,:)];   % negatives, aligned with X1

Y = full(sparse(1:length(ys), double(ys), 1));   % one-hot source labels
L = cl(H,paras);
D = zeros(paras.d,r);
F = 0;

Bs = sign(2*rand(r,paras.ns)-1);
Bt = sign(2*rand(r,paras.nt)-1);

for iter = 1:paras.max_iter
    % NOTE(review): F and D are initialised once, before this loop, and keep
    % accumulating across outer iterations (terms computed with earlier W
    % are never discarded). Confirm whether they should be reset here.
    for i = 1:N
        xi = X1(i,:);
        ap = xi - Xp(i,:);
        an = xi - Xn(i,:);
        dp = norm(W'*ap', 'fro');
        dn = norm(W'*an', 'fro');
        if dp - dn + paras.m >= 0          % margin violated: active triplet
            omega = (1 - exp(dn - dp - paras.m))^2;
            F = F + omega*dp - omega*dn;
            D = D + omega*(an'*an - ap'*ap)*W;
        end
    end
    [W, ~] = OptStiefelGBB(W, @W1_obj, setting, F, D, X, L, Xs, Xt, Bs, Bt, paras);
    % update the label projection A, then the codes Bs and Bt
    A = (paras.lambda2*(Bs*Bs') + paras.lambda3*eye(size(Bs*Bs'))) \ (paras.lambda2*Bs*Y);
    Bs = sign((paras.lambda2*(A*A') + paras.theta1*eye(size(A*A'))) \ (paras.lambda2*A*Y' + paras.theta1*W'*Xs'));
    Bt = sign(W'*Xt');
end

B_train = (Xs*W > 0);
B_test = (test*W > 0);
B_trn = compactbit(B_train);
B_tst = compactbit(B_test);

% compute Hamming metric and compute recall precision
Dhamm = hammingDist(B_tst, B_trn);
clear B_tst B_trn;
switch param.choice
    case 'evaluation_PR_MAP'
        [recall, precision, ~] = recall_precision(WtrueTestTraining, Dhamm);
        [rec, pre] = recall_precision5(WtrueTestTraining, Dhamm, pos); % recall VS. the number of retrieved sample
        mAP = area_RP(recall, precision);
        retrieved_list = [];
    case {'evaluation_PR', 'visualization'}
        % These modes relied on variables never defined in this function
        % (trueRank, ID, train_data, test_data), so they could not run.
        error('demo:unsupportedChoice', ...
            'param.choice ''%s'' is not supported by this release.', param.choice);
    otherwise
        error('demo:unknownChoice', 'Unknown param.choice: %s', param.choice);
end

end

% MINE_PAIRS  For each row of the distance matrix Dist: index of the
% farthest same-label column (positive) and of the nearest different-label
% column (negative). Same is a logical matrix of the same size as Dist.
% Same-label entries are masked with Inf for the negative search;
% previously they were 0, so the "negative" was always a same-class sample.
function [Ip, In] = mine_pairs(Dist, Same)
[~, Ip] = max(Same .* Dist, [], 2);
DistNeg = Dist;
DistNeg(Same) = Inf;
[~, In] = min(DistNeg, [], 2);
end

--------------------------------------------------------------------------------
/main_demo.m:
--------------------------------------------------------------------------------

close all; clear all; clc;
addpath('./utils/');
db_name = 'MNIST&USPS';
param.choice = 'evaluation_PR_MAP';
loopnbits = [32 64];               % code lengths (bits) to evaluate
num_test = 500;                    % number of target queries
runtimes = 1; % increase to average over several random splits (smoother results)
param.pos = [1:10:40 50:50:1000];  % numbers of retrieved samples for the Recall/Precision@N curves
hashmethods = {'PWCF'};

nhmethods = length(hashmethods);

for k = 1:runtimes
    fprintf('The %d run time, start constructing data\n\n', k);
    exp_data = construct_dataset(db_name,num_test,0);
    fprintf('Constructing data finished\n\n');
    for i = 1:length(loopnbits)
        fprintf('======start %d bits encoding======\n\n', loopnbits(i));
        param.r = loopnbits(i);
        for j = 1:nhmethods
            [recall{k}{i, j}, precision{k}{i, j}, mAP{k}{i,j}, rec{k}{i, j}, pre{k}{i, j}, ~] = demo(exp_data, param, hashmethods{1, j});
        end
    end
end

% plot attribution
line_width = 1.5;
marker_size = 4;
xy_font_size = 16;
legend_font_size = 14;
linewidth = 1.6;
title_font_size = 18;


% average MAP over the runs
for j = 1:nhmethods
    for i = 1:length(loopnbits)
        tmp = zeros(size(mAP{1, 1}{i, j}));
        for k = 1:runtimes
            tmp = tmp + mAP{1, k}{i, j};
        end
        MAP{i, j} = tmp/runtimes;
    end
    clear tmp;
end
MAP
choose_bits = 1;   % index into loopnbits of the code length to plot
choose_times = 1;  % index of the run to plot
str_nbits = num2str(loopnbits(choose_bits));

%% recall and precision vs. the number of retrieved samples
curve_sets = {rec, pre};
curve_names = {'Recall', 'Precision'};
for c = 1:numel(curve_sets)
    figure('Color', [1 1 1]); hold on;
    for j = 1:nhmethods
        p = plot(param.pos, curve_sets{c}{choose_times}{choose_bits, j});
        set(p, 'Color', gen_color(j));
        set(p, 'Marker', gen_marker(j));
        set(p, 'LineWidth', line_width);
        set(p, 'MarkerSize', marker_size);
    end
    set(gca, 'linewidth', linewidth);
    h1 = xlabel('The number of retrieved samples');
    h2 = ylabel([curve_names{c}, ' @ ', str_nbits, ' bits']);
    title(db_name, 'FontSize', title_font_size);
    set(h1, 'FontSize', xy_font_size);
    set(h2, 'FontSize', xy_font_size);
    hleg = legend(hashmethods);
    set(hleg, 'FontSize', legend_font_size);
    set(hleg, 'Location', 'best');
    box on;
    grid on;
    hold off;
end

%% precision vs. recall at the chosen code length
figure('Color', [1 1 1]); hold on;
for j = 1:nhmethods
    p = plot(recall{choose_times}{choose_bits, j}, precision{choose_times}{choose_bits, j});
    set(p, 'Color', gen_color(j));
    set(p, 'Marker', gen_marker(j));
    set(p, 'LineWidth', line_width);
    set(p, 'MarkerSize', marker_size);
end
h1 = xlabel(['Recall @ ', str_nbits, ' bits']);
h2 = ylabel('Precision');
title(db_name, 'FontSize', title_font_size);
set(h1, 'FontSize', xy_font_size);
set(h2, 'FontSize', xy_font_size);
hleg = legend(hashmethods);
set(hleg, 'FontSize', legend_font_size);
set(hleg, 'Location', 'best');
set(gca, 'linewidth', linewidth);
box on;
grid on;
hold off;

%% mAP vs. code length. This mAP function is provided by Yunchao Gong
figure('Color', [1 1 1]); hold on;
for j = 1:nhmethods
    map = [MAP{:, j}];
    p = plot(log2(loopnbits), map);
    set(p, 'Color', gen_color(j));
    set(p, 'Marker', gen_marker(j));
    set(p, 'LineWidth', line_width);
    set(p, 'MarkerSize', marker_size);
end

h1 = xlabel('Number of bits');
h2 = ylabel('mean Average Precision (mAP)');
title(db_name, 'FontSize', title_font_size);
set(h1, 'FontSize', xy_font_size);
set(h2, 'FontSize', xy_font_size);
set(gca, 'xtick', log2(loopnbits));
% Label each tick with its actual code length. The labels used to be a
% fixed list {'16','32',...}, which mislabelled the ticks for [32 64].
set(gca, 'XtickLabel', arrayfun(@num2str, loopnbits, 'UniformOutput', false));
set(gca, 'linewidth', linewidth);
hleg = legend(hashmethods);
set(hleg, 'FontSize', legend_font_size);
set(hleg, 'Location', 'best');
box on;
grid on;
hold off;
--------------------------------------------------------------------------------
/OptStiefelGBB.m:
--------------------------------------------------------------------------------
1 | function [X, out]= OptStiefelGBB(X, fun, opts, varargin) 2 | %------------------------------------------------------------------------- 3 | % curvilinear search algorithm for optimization on Stiefel manifold 4 | % 5 | % min F(X), S.t., X'*X = I_k, where X \in R^{n,k} 6 | % 7 | % H = [G, X]*[X -G]' 8 | % U = 0.5*tau*[G, X]; V = [X -G] 9 | % X(tau) = X - 2*U * inv( I + V'*U ) * V'*X 10 | % 11 | % ------------------------------------- 12 | % U = -[G,X]; V = [X -G]; VU = V'*U; 13 | % X(tau) = X - tau*U * inv( I + 0.5*tau*VU ) * V'*X 14 | % 15 | % 16 | % Input: 17 | % X --- n by k matrix such that X'*X = I 18 | % fun --- objective function and its gradient: 19 | % [F, G] = fun(X, data1, data2) 20 | % F, G are the objective function value and gradient, repectively 21 | % data1, data2 are addtional data,
and can be more 22 | % Calling syntax: 23 | % [X, out]= OptStiefelGBB(X0, @fun, opts, data1, data2); 24 | % 25 | % opts --- option structure with fields: 26 | % record = 0, no print out 27 | % mxitr max number of iterations 28 | % xtol stop control for ||X_k - X_{k-1}|| 29 | % gtol stop control for the projected gradient 30 | % ftol stop control for |F_k - F_{k-1}|/(1+|F_{k-1}|) 31 | % usually, max{xtol, gtol} > ftol 32 | % 33 | % Output: 34 | % X --- solution 35 | % Out --- output information 36 | % 37 | % ------------------------------------- 38 | % For example, consider the eigenvalue problem F(X) = -0.5*Tr(X'*A*X); 39 | % 40 | % function demo 41 | % 42 | % function [F, G] = fun(X, A) 43 | % G = -(A*X); 44 | % F = 0.5*sum(dot(G,X,1)); 45 | % end 46 | % 47 | % n = 1000; k = 6; 48 | % A = randn(n); A = A'*A; 49 | % opts.record = 0; % 50 | % opts.mxitr = 1000; 51 | % opts.xtol = 1e-5; 52 | % opts.gtol = 1e-5; 53 | % opts.ftol = 1e-8; 54 | % 55 | % X0 = randn(n,k); X0 = orth(X0); 56 | % tic; [X, out]= OptStiefelGBB(X0, @fun, opts, A); tsolve = toc; 57 | % out.fval = -2*out.fval; % convert the function value to the sum of eigenvalues 58 | % fprintf('\nOptM: obj: %7.6e, itr: %d, nfe: %d, cpu: %f, norm(XT*X-I): %3.2e \n', ... 59 | % out.fval, out.itr, out.nfe, tsolve, norm(X'*X - eye(k), 'fro') ); 60 | % 61 | % end 62 | % ------------------------------------- 63 | % 64 | % Reference: 65 | % Z. Wen and W. Yin 66 | % A feasible method for optimization with orthogonality constraints 67 | % 68 | % Author: Zaiwen Wen, Wotao Yin 69 | % Version 1.0 .... 
2010/10 70 | %------------------------------------------------------------------------- 71 | 72 | 73 | %% Size information 74 | if isempty(X) 75 | error('input X is an empty matrix'); 76 | else 77 | [n, k] = size(X); 78 | end 79 | 80 | if isfield(opts, 'xtol') 81 | if opts.xtol < 0 || opts.xtol > 1 82 | opts.xtol = 1e-6; 83 | end 84 | else 85 | opts.xtol = 1e-6; 86 | end 87 | 88 | if isfield(opts, 'gtol') 89 | if opts.gtol < 0 || opts.gtol > 1 90 | opts.gtol = 1e-6; 91 | end 92 | else 93 | opts.gtol = 1e-6; 94 | end 95 | 96 | if isfield(opts, 'ftol') 97 | if opts.ftol < 0 || opts.ftol > 1 98 | opts.ftol = 1e-12; 99 | end 100 | else 101 | opts.ftol = 1e-12; 102 | end 103 | 104 | % parameters for control the linear approximation in line search 105 | if isfield(opts, 'rho') 106 | if opts.rho < 0 || opts.rho > 1 107 | opts.rho = 1e-4; 108 | end 109 | else 110 | opts.rho = 1e-4; 111 | end 112 | 113 | % factor for decreasing the step size in the backtracking line search 114 | if isfield(opts, 'eta') 115 | if opts.eta < 0 || opts.eta > 1 116 | opts.eta = 0.1; 117 | end 118 | else 119 | opts.eta = 0.2; 120 | end 121 | 122 | % parameters for updating C by HongChao, Zhang 123 | if isfield(opts, 'gamma') 124 | if opts.gamma < 0 || opts.gamma > 1 125 | opts.gamma = 0.85; 126 | end 127 | else 128 | opts.gamma = 0.85; 129 | end 130 | 131 | if isfield(opts, 'tau') 132 | if opts.tau < 0 || opts.tau > 1e3 133 | opts.tau = 1e-3; 134 | end 135 | else 136 | opts.tau = 1e-3; 137 | end 138 | 139 | % parameters for the nonmontone line search by Raydan 140 | if ~isfield(opts, 'STPEPS') 141 | opts.STPEPS = 1e-10; 142 | end 143 | 144 | if isfield(opts, 'nt') 145 | if opts.nt < 0 || opts.nt > 100 146 | opts.nt = 5; 147 | end 148 | else 149 | opts.nt = 5; 150 | end 151 | 152 | if isfield(opts, 'projG') 153 | switch opts.projG 154 | case {1,2}; otherwise; opts.projG = 1; 155 | end 156 | else 157 | opts.projG = 1; 158 | end 159 | 160 | if isfield(opts, 'iscomplex') 161 | switch opts.iscomplex 
162 | case {0, 1}; otherwise; opts.iscomplex = 0; 163 | end 164 | else 165 | opts.iscomplex = 0; 166 | end 167 | 168 | if isfield(opts, 'mxitr') 169 | if opts.mxitr < 0 || opts.mxitr > 2^20 170 | opts.mxitr = 1000; 171 | end 172 | else 173 | opts.mxitr = 1000; 174 | end 175 | 176 | if ~isfield(opts, 'record') 177 | opts.record = 0; 178 | end 179 | 180 | 181 | %------------------------------------------------------------------------------- 182 | % copy parameters 183 | xtol = opts.xtol; 184 | gtol = opts.gtol; 185 | ftol = opts.ftol; 186 | rho = opts.rho; 187 | STPEPS = opts.STPEPS; 188 | eta = opts.eta; 189 | gamma = opts.gamma; 190 | iscomplex = opts.iscomplex; 191 | record = opts.record; 192 | 193 | nt = opts.nt; crit = ones(nt, 3); 194 | 195 | invH = true; if k < n/2; invH = false; eye2k = eye(2*k); end 196 | 197 | %% Initial function value and gradient 198 | % prepare for iterations 199 | [F, G] = feval(fun, X , varargin{:}); out.nfe = 1; 200 | GX = G'*X; 201 | 202 | if invH 203 | GXT = G*X'; H = 0.5*(GXT - GXT'); RX = H*X; 204 | else 205 | if opts.projG == 1 206 | U = [G, X]; V = [X, -G]; VU = V'*U; 207 | elseif opts.projG == 2 208 | GB = G - 0.5*X*(X'*G); 209 | U = [GB, X]; V = [X, -GB]; VU = V'*U; 210 | end 211 | %U = [G, X]; VU = [GX', X'*X; -(G'*G), -GX]; 212 | %VX = VU(:,k+1:end); %VX = V'*X; 213 | VX = V'*X; 214 | end 215 | dtX = G - X*GX; nrmG = norm(dtX, 'fro'); 216 | 217 | Q = 1; Cval = F; tau = opts.tau; 218 | 219 | %% Print iteration header if debug == 1 220 | if (opts.record == 1) 221 | fid = 1; 222 | fprintf(fid, '----------- Gradient Method with Line search ----------- \n'); 223 | fprintf(fid, '%4s %8s %8s %10s %10s\n', 'Iter', 'tau', 'F(X)', 'nrmG', 'XDiff'); 224 | %fprintf(fid, '%4d \t %3.2e \t %3.2e \t %5d \t %5d \t %6d \n', 0, 0, F, 0, 0, 0); 225 | end 226 | 227 | %% main iteration 228 | for itr = 1 : opts.mxitr 229 | XP = X; FP = F; GP = G; dtXP = dtX; 230 | % scale step size 231 | 232 | nls = 1; deriv = rho*nrmG^2; %deriv 233 | while 1 234 
| % calculate G, F, 235 | if invH 236 | [X, infX] = linsolve(eye(n) + tau*H, XP - tau*RX); 237 | else 238 | [aa, infR] = linsolve(eye2k + (0.5*tau)*VU, VX); 239 | X = XP - U*(tau*aa); 240 | end 241 | %if norm(X'*X - eye(k),'fro') > 1e-6; error('X^T*X~=I'); end 242 | if ~isreal(X) && ~iscomplex ; error('X is complex'); end 243 | 244 | [F,G] = feval(fun, X, varargin{:}); 245 | out.nfe = out.nfe + 1; 246 | 247 | if F <= Cval - tau*deriv || nls >= 5 248 | break; 249 | end 250 | tau = eta*tau; nls = nls+1; 251 | end 252 | 253 | GX = G'*X; 254 | if invH 255 | GXT = G*X'; H = 0.5*(GXT - GXT'); RX = H*X; 256 | else 257 | if opts.projG == 1 258 | U = [G, X]; V = [X, -G]; VU = V'*U; 259 | elseif opts.projG == 2 260 | GB = G - 0.5*X*(X'*G); 261 | U = [GB, X]; V = [X, -GB]; VU = V'*U; 262 | end 263 | %U = [G, X]; VU = [GX', X'*X; -(G'*G), -GX]; 264 | %VX = VU(:,k+1:end); % VX = V'*X; 265 | VX = V'*X; 266 | end 267 | dtX = G - X*GX; nrmG = norm(dtX, 'fro'); 268 | 269 | S = X - XP; XDiff = norm(S,'fro')/sqrt(n); 270 | tau = opts.tau; FDiff = abs(FP-F)/(abs(FP)+1); 271 | 272 | if iscomplex 273 | %Y = dtX - dtXP; SY = (sum(sum(real(conj(S).*Y)))); 274 | Y = dtX - dtXP; SY = abs(sum(sum(conj(S).*Y))); 275 | if mod(itr,2)==0; tau = sum(sum(conj(S).*S))/SY; 276 | else tau = SY/sum(sum(conj(Y).*Y)); end 277 | else 278 | %Y = G - GP; SY = abs(sum(sum(S.*Y))); 279 | Y = dtX - dtXP; SY = abs(sum(sum(S.*Y))); 280 | %alpha = sum(sum(S.*S))/SY; 281 | %alpha = SY/sum(sum(Y.*Y)); 282 | %alpha = max([sum(sum(S.*S))/SY, SY/sum(sum(Y.*Y))]); 283 | if mod(itr,2)==0; tau = sum(sum(S.*S))/SY; 284 | else tau = SY/sum(sum(Y.*Y)); end 285 | 286 | % %Y = G - GP; 287 | % Y = dtX - dtXP; 288 | % YX = Y'*X; SX = S'*X; 289 | % SY = abs(sum(sum(S.*Y)) - 0.5*sum(sum(YX.*SX)) ); 290 | % if mod(itr,2)==0; 291 | % tau = SY/(sum(sum(S.*S))- 0.5*sum(sum(SX.*SX))); 292 | % else 293 | % tau = (sum(sum(Y.*Y)) -0.5*sum(sum(YX.*YX)))/SY; 294 | % end 295 | 296 | end 297 | tau = max(min(tau, 1e20), 1e-20); 298 | 299 | 
if (record >= 1) 300 | fprintf('%4d %3.2e %4.3e %3.2e %3.2e %3.2e %2d\n', ... 301 | itr, tau, F, nrmG, XDiff, FDiff, nls); 302 | %fprintf('%4d %3.2e %4.3e %3.2e %3.2e (%3.2e, %3.2e)\n', ... 303 | % itr, tau, F, nrmG, XDiff, alpha1, alpha2); 304 | end 305 | 306 | crit(itr,:) = [nrmG, XDiff, FDiff]; 307 | mcrit = mean(crit(itr-min(nt,itr)+1:itr, :),1); 308 | %if (XDiff < xtol && nrmG < gtol ) || FDiff < ftol 309 | %if (XDiff < xtol || nrmG < gtol ) || FDiff < ftol 310 | %if ( XDiff < xtol && FDiff < ftol ) || nrmG < gtol 311 | %if ( XDiff < xtol || FDiff < ftol ) || nrmG < gtol 312 | if ( XDiff < xtol && FDiff < ftol ) || nrmG < gtol || all(mcrit(2:3) < 10*[xtol, ftol]) 313 | if itr <= 2 314 | ftol = 0.1*ftol; 315 | xtol = 0.1*xtol; 316 | gtol = 0.1*gtol; 317 | else 318 | out.msg = 'converge'; 319 | break; 320 | end 321 | end 322 | 323 | Qp = Q; Q = gamma*Qp + 1; Cval = (gamma*Qp*Cval + F)/Q; 324 | end 325 | 326 | if itr >= opts.mxitr 327 | out.msg = 'exceed max iteration'; 328 | end 329 | 330 | out.feasi = norm(X'*X-eye(k),'fro'); 331 | if out.feasi > 1e-13 332 | X = MGramSchmidt(X); 333 | [F,G] = feval(fun, X, varargin{:}); 334 | out.nfe = out.nfe + 1; 335 | out.feasi = norm(X'*X-eye(k),'fro'); 336 | end 337 | 338 | out.nrmG = nrmG; 339 | out.fval = F; 340 | out.itr = itr; 341 | 342 | 343 | 344 | 345 | --------------------------------------------------------------------------------