├── DB
│   └── MNIST_vs_USPS.mat
├── utils
│   ├── recall_precision.m
│   ├── normalize1.m
│   ├── gen_color.m
│   ├── compactbit.m
│   ├── distMat.m
│   ├── area_RP.m
│   ├── gen_marker.m
│   ├── recall_precision5.m
│   ├── EuDist2.m
│   ├── hammingDist.m
│   └── L2_distance.m
├── histogram_label.m
├── MGramSchmidt.m
├── W1_obj.m
├── fea_trans.m
├── cl.m
├── README.md
├── construct_dataset.m
├── demo.m
├── main_demo.m
└── OptStiefelGBB.m
/DB/MNIST_vs_USPS.mat:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuxianghuang1/PWCF/HEAD/DB/MNIST_vs_USPS.mat
--------------------------------------------------------------------------------
/utils/recall_precision.m:
--------------------------------------------------------------------------------
https://raw.githubusercontent.com/fuxianghuang1/PWCF/HEAD/utils/recall_precision.m
--------------------------------------------------------------------------------
/utils/normalize1.m:
--------------------------------------------------------------------------------
function [X] = normalize1(X)
%NORMALIZE1 Scale every row of X to unit Euclidean (2-)norm.
%   X = normalize1(X) divides each row of X by norm(row). Rows whose norm
%   is exactly zero are left unchanged, which avoids producing 0/0 = NaN.

for r = 1:size(X, 1)
    rowNorm = norm(X(r, :));
    % Only rescale rows with a non-zero norm.
    if rowNorm ~= 0
        X(r, :) = X(r, :) ./ rowNorm;
    end
end
end
--------------------------------------------------------------------------------
/histogram_label.m:
--------------------------------------------------------------------------------
function [h] = histogram_label(l,classnum,k)
%HISTOGRAM_LABEL Normalised count of each class label in a label vector.
%   h = histogram_label(l, classnum, k) returns a classnum-by-1 vector whose
%   c-th entry is the number of elements of l equal to c, divided by k.
%   Values of l outside 1..classnum (such as 0 used as a placeholder for
%   "not selected") are not counted.
counts = zeros(classnum, 1);
for c = 1:classnum
    counts(c) = nnz(l == c);
end
h = counts / k;
end
11 |
12 |
--------------------------------------------------------------------------------
/MGramSchmidt.m:
--------------------------------------------------------------------------------
function V = MGramSchmidt(V)
%MGRAMSCHMIDT Orthonormalise the columns of V by modified Gram-Schmidt.
%   V = MGramSchmidt(V) returns a matrix of the same size whose columns are
%   orthonormal and span, column by column from left to right, the same
%   spaces as the input columns. Each column is orthogonalised against the
%   already-processed columns one projection at a time, using the partially
%   updated column (the "modified" variant).
%
%   Works for real and complex V. A zero or linearly dependent column is
%   divided by a zero norm and therefore yields NaN/Inf entries.
[~, k] = size(V);

for dj = 1:k
    for di = 1:dj-1
        V(:,dj) = V(:,dj) - proj(V(:,di), V(:,dj));
    end
    V(:,dj) = V(:,dj)/norm(V(:,dj));
end
end


% Orthogonal projection of v onto the line spanned by u.
% dot(u,v) conjugates its FIRST argument, giving (u'*v)/(u'*u), which is the
% correct coefficient for complex data. The previous dot(v,u) produced the
% conjugate coefficient. For real input the two forms are identical.
function v = proj(u,v)
v = (dot(u,v)/dot(u,u))*u;
end
17 |
--------------------------------------------------------------------------------
/utils/gen_color.m:
--------------------------------------------------------------------------------
function color=gen_color(curve_idx)
%GEN_COLOR Colour specification for the curve_idx-th plotted curve.
%   color = gen_color(curve_idx) cycles through a fixed palette of nine
%   colours: six named colours followed by three RGB triples. Index 1 maps
%   to 'b' and the sequence repeats every nine curves.
palette = {'b', 'r', 'g', 'm', 'c', 'black', ...
           [0.7 0 0.7], [0 0.7 0.7], [0.83 0.33 0]};

color = palette{mod(curve_idx - 1, numel(palette)) + 1};
end
--------------------------------------------------------------------------------
/W1_obj.m:
--------------------------------------------------------------------------------
function [ Ft, Gt] = W1_obj(W,F,D,X,L,Xs,Xt,Bs,Bt,paras)
%W1_OBJ Objective value and gradient for the projection subproblem in W.
%   [Ft, Gt] = W1_obj(W, F, D, X, L, Xs, Xt, Bs, Bt, paras) evaluates
%     Ft = F + theta1*||Bs - W'*Xs'||_F^2 + theta2*||Bt - W'*Xt'||_F^2
%            + lambda1*trace(W'*X'*L*X*W)
%     Gt = 2*D - 2*theta1*Xs'*Bs' + 2*theta1*(Xs'*Xs)*W
%              - 2*theta2*Xt'*Bt' + 2*theta2*(Xt'*Xt)*W
%              + 2*lambda1*X'*L*X*W
%   F and D are supplied by the caller and enter as given. The last
%   gradient term equals the derivative of the trace term only when L is
%   symmetric.
t1 = paras.theta1;
t2 = paras.theta2;
lam = paras.lambda1;

% Objective: terms summed in the same left-to-right order as before.
srcFit   = t1*norm(Bs-W'*Xs','fro')^2;
tgtFit   = t2*norm(Bt-W'*Xt','fro')^2;
graphReg = lam*trace(W'*X'*L*X*W);
Ft = F + srcFit + tgtFit + graphReg;

% Gradient pieces, each written with the original operand grouping.
g1 = 2*D;
g2 = 2*t1*Xs'*Bs';
g3 = 2*t1*(Xs'*Xs)*W;
g4 = 2*t2*Xt'*Bt';
g5 = 2*t2*(Xt'*Xt)*W;
g6 = 2*lam*X'*L*X*W;
Gt = g1 - g2 + g3 - g4 + g5 + g6;
end
--------------------------------------------------------------------------------
/utils/compactbit.m:
--------------------------------------------------------------------------------
function cb = compactbit(b)
%COMPACTBIT Pack a matrix of bits into bytes.
%   cb = compactbit(b) packs each row of the nSamples-by-nbits array b into
%   ceil(nbits/8) uint8 words. Bit j of a row is stored in word ceil(j/8)
%   at bit position mod(j-1,8)+1, i.e. least-significant bit first.

[numRows, numBits] = size(b);
cb = zeros([numRows, ceil(numBits/8)], 'uint8');

for bitIdx = 1:numBits
    wordIdx = floor((bitIdx - 1)/8) + 1;   % == ceil(bitIdx/8)
    bitPos  = rem(bitIdx - 1, 8) + 1;      % 1..8 within the word
    cb(:, wordIdx) = bitset(cb(:, wordIdx), bitPos, b(:, bitIdx));
end
end
--------------------------------------------------------------------------------
/utils/distMat.m:
--------------------------------------------------------------------------------
function D=distMat(P1, P2)
%DISTMAT Pairwise Euclidean distances between row vectors.
%   D = distMat(P1, P2) returns the size(P1,1)-by-size(P2,1) matrix with
%   D(i,j) = ||P1(i,:) - P2(j,:)||.
%   D = distMat(P1) returns the distances among the rows of P1.
%   Inputs are converted to double. Uses ||a||^2 + ||b||^2 - 2*a*b'; tiny
%   negative radicands caused by rounding become 0 through real(sqrt(.)).

P1 = double(P1);
sq1 = sum(P1.^2, 2);

if nargin == 2
    P2 = double(P2);
    sq2 = sum(P2.^2, 2);
    cross = P1*P2';
else
    sq2 = sq1;
    cross = P1*P1';
end

D = real(sqrt(bsxfun(@plus, sq1, sq2') - 2*cross));
end
23 |
24 |
25 |
--------------------------------------------------------------------------------
/utils/area_RP.m:
--------------------------------------------------------------------------------
function [value] = area_RP(recall, precision)
%AREA_RP Area under a precision-recall curve (trapezoidal rule).
%   value = area_RP(recall, precision) integrates precision over recall.
%   recall and precision are equal-length vectors (row or column).
%
%   Processing steps:
%     * If the curve does not start at recall 0, the point
%       (0, precision(1)) is prepended.
%     * Duplicate recall values are collapsed with unique().
%     * Invalid precision values (negative or NaN) are replaced by the next
%       valid value to their right (filled from the end backwards). A
%       trailing invalid value has no right neighbour and is set to 0.
%   Empty input returns 0.

recall = recall(:)';
precision = precision(:)';

if isempty(recall)
    value = 0;
    return;
end

if recall(1) ~= 0
    xx = [0, recall];
    yy = [precision(1), precision];
else
    xx = recall;
    yy = precision;
end
[xx, index] = unique(xx);
yy = yy(index);

% Backward fill of invalid precision values; "~(y >= 0)" is true for
% negatives and NaN alike.
n = numel(yy);
for ic = n:-1:1
    if ~(yy(ic) >= 0)
        if ic < n
            yy(ic) = yy(ic+1);
        else
            yy(ic) = 0;
        end
    end
end

% Trapezoidal rule over consecutive (recall, precision) points.
value = sum(diff(xx) .* (yy(1:end-1) + yy(2:end))) / 2;
end
--------------------------------------------------------------------------------
/utils/gen_marker.m:
--------------------------------------------------------------------------------
1 |
function marker=gen_marker(curve_idx)
%GEN_MARKER Marker symbol for the curve_idx-th plotted curve.
%   marker = gen_marker(curve_idx) cycles through a fixed list of 19 marker
%   characters. Index 1 maps to 's' and the list repeats every 19 curves.
symbols = {'s', 'o', 'd', '^', '*', 'v', 'x', '+', '>', '<', ...
           'p', 'h', 'x', 'o', '*', 'd', 'p', 's', 'h'};

marker = symbols{mod(curve_idx - 1, numel(symbols)) + 1};
end
46 |
--------------------------------------------------------------------------------
/utils/recall_precision5.m:
--------------------------------------------------------------------------------
function [recall, precision] = recall_precision5(Wtrue, Dhat, pos)
%RECALL_PRECISION5 Recall and precision versus the number of retrieved items.
%
% Input:
%   Wtrue = ground-truth relevance [Ntest * Ndataset] (nonzero = relevant)
%   Dhat  = estimated distances, same layout as Wtrue; each row is ranked
%           in ascending order of distance
%   pos   = vector of cut-offs (number of retrieved items per query)
%
% Output (row vectors, one entry per element of pos):
%   recall(i)    = # relevant items within the top pos(i), over all queries
%                  / total # relevant items
%   precision(i) = # relevant items within the top pos(i), over all queries
%                  / (Ntest * pos(i))
%
% Cut-offs larger than Ndataset are clamped to Ndataset (all items
% retrieved) instead of raising an index error.

% Reorder every row of Wtrue by the ranking induced by Dhat.
for i = 1:size(Dhat,1)
    [~, order] = sort(Dhat(i,:), 'ascend');
    Wtrue(i,:) = Wtrue(i,order);
end
total_good_pairs = sum(Wtrue(:));

[nRows, nCols] = size(Wtrue);
recall = zeros(1, numel(pos));
precision = zeros(1, numel(pos));

for i = 1:numel(pos)
    g = min(pos(i), nCols);
    retrieved_good_pairs = sum(sum(Wtrue(:,1:g)));
    total_pairs = nRows*g;
    recall(i) = retrieved_good_pairs/total_good_pairs;
    precision(i) = retrieved_good_pairs/total_pairs;
end
end
32 |
33 |
34 |
35 |
36 |
--------------------------------------------------------------------------------
/utils/EuDist2.m:
--------------------------------------------------------------------------------
1 |
function D = EuDist2(fea_a,fea_b,bSqrt)
%EUDIST2 Euclidean distance matrix computed with vectorised matrix operations.
%
%   D = EuDist2(fea_a)             distances among the rows of fea_a
%   D = EuDist2(fea_a, fea_b)      distances between rows of fea_a and fea_b
%   D = EuDist2(fea_a, fea_b, 0)   squared distances (no square root)
%
%   fea_a: nSample_a * nFeature
%   fea_b: nSample_b * nFeature (omitted or [] means fea_b = fea_a)
%   D    : nSample_a * nSample_a, or nSample_a * nSample_b
%
%   Uses ||a||^2 + ||b||^2 - 2*a*b'. Negative values caused by rounding are
%   clipped to 0. In the single-input case the result is also symmetrised
%   with max(D, D').
%
%   Examples:
%       a = rand(500,10);
%       b = rand(1000,10);
%       A = EuDist2(a);     % A: 500*500
%       D = EuDist2(a,b);   % D: 500*1000
%
%   Written by Deng Cai (dengcai AT gmail.com)
%   version 2.1 --November/2011, 2.0 --May/2009, 1.0 --November/2005

if ~exist('bSqrt','var')
    bSqrt = 1;
end

selfDist = ~exist('fea_b','var') || isempty(fea_b);

sqA = sum(fea_a.*fea_a, 2);
if selfDist
    sqB = sqA;
    cross = fea_a*fea_a';
else
    sqB = sum(fea_b.*fea_b, 2);
    cross = fea_a*fea_b';
end

% Sparsity of fea_a decides whether the squared norms are densified.
if issparse(sqA)
    sqA = full(sqA);
    sqB = full(sqB);
end

D = bsxfun(@plus, sqA, sqB') - 2*cross;
D(D<0) = 0;
if bSqrt
    D = sqrt(D);
end
if selfDist
    D = max(D, D');
end
end
61 |
--------------------------------------------------------------------------------
/utils/hammingDist.m:
--------------------------------------------------------------------------------
function Dh=hammingDist(B1, B2)
%
% Compute hamming distance between two sets of samples (B1, B2)
%
% Dh=hammingDist(B1, B2);
%
% Input
%    B1, B2: compact bit vectors (e.g. uint8 output of compactbit).
%            Each datapoint is one row.
%    size(B1) = [ndatapoints1, nwords]
%    size(B2) = [ndatapoints2, nwords]
%    It is faster if ndatapoints1 < ndatapoints2
%
% Output
%    Dh = hamming distance (uint16).
%    size(Dh) = [ndatapoints1, ndatapoints2]

% example query
% Dhamm = hammingDist(B2, B1);
% this will give the same result than:
%    Dhamm = distMat(U2>0, U1>0).^2;
% the size of the distance matrix is:
%    size(Dhamm) = [Ntest x Ntraining]

% look-up table: bit_in_char(v+1) = number of set bits in the byte v (0..255)
bit_in_char = uint16([...
    0 1 1 2 1 2 2 3 1 2 2 3 2 3 3 4 1 2 2 3 2 3 ...
    3 4 2 3 3 4 3 4 4 5 1 2 2 3 2 3 3 4 2 3 3 4 ...
    3 4 4 5 2 3 3 4 3 4 4 5 3 4 4 5 4 5 5 6 1 2 ...
    2 3 2 3 3 4 2 3 3 4 3 4 4 5 2 3 3 4 3 4 4 5 ...
    3 4 4 5 4 5 5 6 2 3 3 4 3 4 4 5 3 4 4 5 4 5 ...
    5 6 3 4 4 5 4 5 5 6 4 5 5 6 5 6 6 7 1 2 2 3 ...
    2 3 3 4 2 3 3 4 3 4 4 5 2 3 3 4 3 4 4 5 3 4 ...
    4 5 4 5 5 6 2 3 3 4 3 4 4 5 3 4 4 5 4 5 5 6 ...
    3 4 4 5 4 5 5 6 4 5 5 6 5 6 6 7 2 3 3 4 3 4 ...
    4 5 3 4 4 5 4 5 5 6 3 4 4 5 4 5 5 6 4 5 5 6 ...
    5 6 6 7 3 4 4 5 4 5 5 6 4 5 5 6 5 6 6 7 4 5 ...
    5 6 5 6 6 7 5 6 6 7 6 7 7 8]);

n1 = size(B1,1);
[n2, nwords] = size(B2);

Dh = zeros([n1 n2], 'uint16');
for j = 1:n1
    for n=1:nwords
        y = bitxor(B1(j,n),B2(:,n));
        % y has the integer class of the inputs (uint8 for compactbit
        % codes). uint8 arithmetic saturates, so y+1 would map 255 (all 8
        % bits different) to index 255 -> 7 instead of 256 -> 8. Promote
        % to double before forming the table index.
        Dh(j,:) = Dh(j,:) + bit_in_char(double(y)+1);
    end
end
end
49 |
--------------------------------------------------------------------------------
/fea_trans.m:
--------------------------------------------------------------------------------
function [HS,HT,H] = fea_trans(DT,DS,YT,YS,paras)
%FEA_TRANS Neighbourhood label-histogram features for source and target.
%   [HS, HT, H] = fea_trans(DT, DS, YT, YS, paras) represents every sample
%   by the histogram of class labels among its paras.k nearest neighbours
%   within its OWN domain (Euclidean distance, the sample itself excluded).
%   Each histogram is divided by paras.k.
%
%   DT, DS : target / source features with one sample per COLUMN
%            (distances are computed between the rows of DT' and DS').
%   YT, YS : target / source label vectors (YT may hold pseudo-labels).
%            Labels are presumably integers 1..clsnum; values outside that
%            range are not counted (TODO confirm with callers).
%   paras  : struct; only paras.k (number of neighbours) is read here.
%
%   Outputs:
%     HS : clsnum x ns source histograms (one column per source sample)
%     HT : clsnum x nt target histograms
%     H  : [HS, HT]
%   clsnum is the number of distinct source labels.
k = paras.k;

clsnum1 = length(unique(YS));
clsnum2 = length(unique(YT));
if clsnum1 ~= clsnum2
    fprintf('......%s start ......\n\n', 'inequal class number for source and target domains');
end
clsnum = clsnum1;

% Compute each domain's distance matrix separately and release it right
% away to keep peak memory at one n x n matrix.
SS = EuDist2(DS',DS');
HS = knn_label_hist(SS, YS, clsnum, k);
clear SS;

ST = EuDist2(DT',DT');
HT = knn_label_hist(ST, YT, clsnum, k);
clear ST;

H = [HS,HT];
end


% Column i = histogram_label of the labels of the k nearest neighbours of
% sample i, taken from row i of the square distance matrix Dist.
function Hist = knn_label_hist(Dist, labels, clsnum, k)
n = size(Dist, 1);
% Exclude each sample from its own neighbour list: +Inf sorts last
% regardless of the scale of the distances.
Dist(1:n+1:end) = Inf;
Hist = zeros(clsnum, n);
for i = 1:n
    [~, order] = sort(Dist(i,:), 'ascend');
    Hist(:, i) = histogram_label(labels(order(1:k)), clsnum, k);
end
end
46 |
47 |
--------------------------------------------------------------------------------
/cl.m:
--------------------------------------------------------------------------------
function [ L] = cl(H,paras)
%CL Graph Laplacian of a heat-kernel k-nearest-neighbour graph.
%   L = cl(H, paras) treats every COLUMN of H as a point (L2_distance works
%   column-wise) and builds a graph as follows:
%     1. keep, for each point, the entries of its k+1 smallest distances
%        (the first is normally the point itself, at distance 0);
%     2. symmetrise with max(G, G'), square the distances and divide by
%        the largest value;
%     3. weight every remaining non-zero entry with exp(-G/(2*sigma^2)).
%   It returns the unnormalised Laplacian L = Dg - G with
%   Dg = diag(sum(G, 2)). NaN and Inf entries of L are set to 0.
%
%   paras fields:
%     sigma       - heat-kernel width (required)
%     graph_k     - neighbours per point (optional, default 500)
%     square_once - optional, default false. With the default, step 2 is
%                   applied twice, exactly as in the original code. Because
%                   the first pass already normalises the maximum to 1, the
%                   net effect is that (d^2/max d^2)^2 enters the kernel.
%                   Set to true to apply step 2 once, as in the standard
%                   Laplacian-eigenmaps construction.
%
%   NOTE(review): entries at exactly zero distance (the diagonal and any
%   duplicate points) are skipped by the G ~= 0 test and get weight 0, not 1.

if isfield(paras, 'graph_k')
    k = paras.graph_k;
else
    k = 500;
end
square_once = isfield(paras, 'square_once') && paras.square_once;
sigma = paras.sigma;

% Construct neighborhood graph
disp('Constructing neighborhood graph...');

G = L2_distance(H, H);

% Column i of ind ranks all points by distance to point i. Zero every entry
% of row i beyond the (k+1) closest.
[~, ind] = sort(G);
for i = 1:size(G, 1)
    G(i, ind((2 + k):end, i)) = 0;
end
G = sparse(double(G));
G = max(G, G');                 % Make sure distance matrix is symmetric
G = G .^ 2;
G = G ./ max(max(G));

if ~square_once
    % NOTE(review): second pass of the same post-processing. It was left
    % over from an earlier, now-removed graph construction, but the
    % published results were produced with it, so it remains the default.
    G = sparse(double(G));
    G = max(G, G');
    G = G .^ 2;
    G = G ./ max(max(G));
end

% Compute weights (W = G)
disp('Computing weight matrices...');

% Compute Gaussian kernel (heat kernel-based weights)
G(G ~= 0) = exp(-G(G ~= 0) / (2 * sigma ^ 2));
D = diag(sum(G, 2));

% Compute Laplacian
L = D - G;
L(isnan(L)) = 0;
L(isinf(L)) = 0;

end
62 |
--------------------------------------------------------------------------------
/README.md:
--------------------------------------------------------------------------------
1 | # Probability Weighted Compact Feature for Domain Adaptive Retrieval
2 | Published in CVPR 2020
3 | Contact : huangfuxiang@cqu.edu.cn
4 |
5 | Usage: MATLAB R2017
6 |
7 | Running Models: Run main_demo.m
8 |
9 | More datasets are available at https://pan.baidu.com/s/1EVlYCz51AyDnh5y7PJ5W_Q?pwd=qyrv
10 |
11 |
12 | *If you want to cite the experimental results, please pay attention to the experimental details in the paper. For the handwritten digit datasets (MNIST and USPS), following the standard transfer-learning protocol, we select 2000 images from MNIST as the source domain and 1800 images from USPS as the target domain. For each dataset, we randomly select 500 target images as the test set (i.e., queries) and use the remaining images as the training set. Note that, to investigate more samples, we use more datasets and different settings in our other paper, published in TNNLS 2021 (Domain Adaptation Preconceived Hashing for Unconstrained Visual Retrieval). Specifically, we use 60000 images from MNIST as the source domain and 10000 images from USPS as the target domain, and for each dataset we randomly select 10% of the target images as the test set (i.e., queries) and use the remaining images as the training set.*
13 |
14 | Cite: If you find this code useful in your research, please cite:
15 |
16 | ```bibtex
17 | @inproceedings{huang2020PWCF,
18 | title={Probability Weighted Compact Feature for Domain Adaptive Retrieval},
19 | author={Huang, Fuxiang and Zhang, Lei and Yang, Yang and Zhou, Xichuan},
20 | booktitle={CVPR},
21 | pages={9579-9588},
22 | year={2020}
23 | }
24 |
25 |
26 | @article{huang2021domain,
27 | author={Huang, Fuxiang and Zhang, Lei and Gao, Xinbo},
28 | journal={IEEE Transactions on Neural Networks and Learning Systems},
29 | title={Domain Adaptation Preconceived Hashing for Unconstrained Visual Retrieval},
30 | year={2021},
31 | pages={1-15},
32 | doi={10.1109/TNNLS.2021.3071127}
33 | }
34 | ```
35 |
--------------------------------------------------------------------------------
/utils/L2_distance.m:
--------------------------------------------------------------------------------
function d = L2_distance(a, b)
% L2_DISTANCE - computes Euclidean distance matrix
%
%   E = L2_distance(A,B)
%
%     A - (DxM) matrix, one point per column
%     B - (DxN) matrix, one point per column
%
%   Returns:
%     E - (MxN) matrix with E(i,j) = ||A(:,i) - B(:,j)||
%
%   Uses the expansion ||a - b|| = sqrt(||a||^2 + ||b||^2 - 2*a.b).
%   Rounding can make the radicand slightly negative; real() maps those
%   entries to 0 so the output is always real. A warning is issued for
%   complex inputs.
%
%   Example:
%     A = rand(400,100); B = rand(400,200);
%     d = L2_distance(A,B);   % 100x200
%
% Author   : Roland Bunschoten
%            University of Amsterdam
%            Intelligent Autonomous Systems (IAS) group
%            Kruislaan 403  1098 SJ Amsterdam
%            tel.(+31)20-5257524
%            bunschot@wins.uva.nl
% Last Rev : Wed Oct 20 08:58:08 MET DST 1999
% Tested   : PC Matlab v5.2 and Solaris Matlab v5.3
%
% Copyright notice: You are free to modify, extend and distribute
%    this code granted that the author of the original code is
%    mentioned as the original author of the code.
%
% Fixed by JBT (3/18/00) to work for 1-dimensional vectors
% and to warn for imaginary numbers. Also ensures that
% output is all real.
%
% This file is part of the Matlab Toolbox for Dimensionality Reduction v0.7.2b.
% The toolbox can be obtained from http://homepage.tudelft.nl/19j49
% You are free to use, change, or redistribute this code in any way you
% want for non-commercial purposes. However, it is appreciated if you
% maintain the name of the original author.
%
% (C) Laurens van der Maaten, 2010
% University California, San Diego / Delft University of Technology

if nargin < 2
    error('Not enough input arguments');
end
if size(a, 1) ~= size(b, 1)
    error('A and B should be of same dimensionality');
end
if ~(isreal(a) && isreal(b))
    warning('Computing distance table using imaginary inputs. Results may be off.');
end

% With a single row, sum() would collapse along dimension 2. Append a zero
% row to both inputs so sum() works column-wise; the extra zero coordinate
% changes no distance.
if size(a, 1) == 1
    a = [a; zeros(1, size(a, 2))];
    b = [b; zeros(1, size(b, 2))];
end

sqA   = sum(a .* a);      % 1xM squared norms
sqB   = sum(b .* b);      % 1xN squared norms
inner = 2 * a' * b;       % MxN doubled inner products

d = real(sqrt(bsxfun(@plus, sqA', bsxfun(@minus, sqB, inner))));
end
74 |
75 |
--------------------------------------------------------------------------------
/construct_dataset.m:
--------------------------------------------------------------------------------
function [exp_data]=construct_dataset(db_name,num_test,topNum)
%CONSTRUCT_DATASET Load a source/target dataset pair and build the retrieval split.
%   exp_data = construct_dataset(db_name, num_test, topNum)
%
%   db_name  : 'VOC2007&Caltech101', 'Caltech256&ImageNet' or 'MNIST&USPS'.
%              The corresponding .mat files are loaded from ./DB/ (added
%              to the path).
%   num_test : number of target samples drawn at random (randperm) as
%              queries. The remaining target samples form the target
%              training set.
%   topNum   : accepted for backward compatibility but unused. It only fed
%              a Euclidean ground-truth matrix that was never returned.
%
%   exp_data fields:
%     Xs    - source features (rows L2-normalised by normalize1, then
%             centred with the mean of all source + target rows)
%     train - target training features (same preprocessing)
%     test  - target query features (same preprocessing)
%     ys    - source labels
%     yt    - true labels of the target training samples
%     ytnew - pseudo-labels of the target training samples: label of the
%             nearest (Euclidean, 1-NN) source sample
%     WTT   - logical num_test x ns matrix, WTT(i,j) = (label of query i == ys(j))
%
%   The pseudo-label accuracy is displayed as "acc".
addpath('./DB/');
%% choose and load dataset
if strcmp(db_name,'VOC2007&Caltech101')
    load VOC2007 %source data
    Xs = double(data(:,1:end-1));
    Xs = normalize1(Xs);
    ys = data(:,end);
    clear data;

    load Caltech101 %target data
    Xt = double(data(:,1:end-1));
    Xt = normalize1(Xt);
    yt = data(:,end);
    clear data;

elseif strcmp(db_name, 'Caltech256&ImageNet')
    load dense_imagenet_decaf7_subsampled; %target data
    Xt = normalize1(fts);
    Xt = double(Xt);
    yt = double(labels);
    clear fts; clear labels;

    load dense_caltech256_decaf7_subsampled; %source data
    Xs = normalize1(fts);
    Xs = double(Xs);
    ys = double(labels);
    clear fts; clear labels;
elseif strcmp(db_name,'MNIST&USPS')
    load MNIST_vs_USPS.mat %source and target data
    Xs = double(X_src)';
    Xs = normalize1(Xs);
    ys = double(Y_src); %vector label
    Xt = double(X_tar)';
    Xt = normalize1(Xt);
    yt = double(Y_tar); %vector label
    clear X_src; clear Y_src; clear X_tar; clear Y_tar;
else
    error('construct_dataset:unknownDataset', ...
        'Unknown dataset name "%s".', db_name);
end

%% random query / training split of the target domain
[ndatat,~] = size(Xt);
R = randperm(ndatat);
test = Xt(R(1:num_test),:);
ytest = yt(R(1:num_test));
R(1:num_test) = [];
train = Xt(R,:);
yt = yt(R);

%% pseudo-labels for the target training set (1-NN from the source domain)
if exist('knnclassify', 'file')
    ytnew = knnclassify(train,Xs,ys);
else
    % knnclassify (Bioinformatics Toolbox) is missing from newer MATLAB
    % releases. Equivalent 1-NN with Euclidean distance; min() returns
    % the first index on ties.
    [~, nn] = min(EuDist2(train, Xs), [], 2);
    ytnew = ys(nn);
    ytnew = ytnew(:);
end
acc = length(find(ytnew==yt))/length(yt)

%% semantic ground truth: query i is relevant to source sample j if labels match
WTT = bsxfun(@eq, ytest(:), ys(:)');

%% centre all features with the mean of source + (all) target samples
samplemean = mean([Xs;Xt],1);
Xs = bsxfun(@minus, Xs, samplemean);
train = bsxfun(@minus, train, samplemean);
test = bsxfun(@minus, test, samplemean);

exp_data.Xs = Xs;
exp_data.test = test;
exp_data.train = train;
exp_data.ytnew = ytnew;
exp_data.yt = yt;
exp_data.ys = ys;
exp_data.WTT = WTT;
end
86 |
--------------------------------------------------------------------------------
/demo.m:
--------------------------------------------------------------------------------
function [recall, precision, mAP, rec, pre, retrieved_list] = demo(exp_data, param, method)
%DEMO Train PWCF hash projections and evaluate Hamming-ranking retrieval.
%   [recall, precision, mAP, rec, pre, retrieved_list] = demo(exp_data, param, method)
%
%   exp_data : struct from construct_dataset (fields Xs, train, test, ys,
%              ytnew, WTT).
%   param    : struct with fields
%                r      - code length (number of bits)
%                pos    - cut-offs for recall/precision vs. #retrieved
%                choice - evaluation mode; only 'evaluation_PR_MAP' is
%                         implemented
%   method   : method name; currently unused (only PWCF is implemented).
%
%   Outputs for 'evaluation_PR_MAP':
%     recall, precision : precision-recall curve from recall_precision
%     mAP               : area under that curve (area_RP)
%     rec, pre          : recall / precision vs. number of retrieved items
%     retrieved_list    : []
%
%   The retrieval database is the SOURCE set (codes of exp_data.Xs). The
%   queries are the held-out target samples (exp_data.test).

WtrueTestTraining = exp_data.WTT;
pos = param.pos;
r = param.r;
Xs = exp_data.Xs;
Xt = exp_data.train;
ys = exp_data.ys;
yt = exp_data.ytnew;        % pseudo-labels of the target training samples
test = exp_data.test;

%% set parameters
setting.record = 0;
setting.mxitr = 10;
setting.xtol = 1e-5;
setting.gtol = 1e-5;
setting.ftol = 1e-8;
paras.k = 50;
paras.sigma = 0.4;
paras.m = 0.3;
paras.theta1 = 10;
paras.theta2 = 10;
paras.lambda1 = 1;
paras.lambda2 = 10;
paras.lambda3 = 1e4;
paras.max_iter = 50;
[paras.nt,paras.d] = size(Xt);
[paras.ns,paras.d] = size(Xs);

%% learning
fprintf('......%s start ......\n\n', 'PWCF');
X = [Xs; Xt];
X1 = [Xs; Xt; Xt];          % triplet anchors: source, target, target (cross-domain)
N = size(X1,1);

% Initialise W with the eigenvectors of X'*X for the r largest eigenvalues.
[vec,val] = eig(X'*X);
[~,Idx] = sort(diag(val),'descend');
W = vec(:,Idx(1:r));
clear Idx vec val;

%% construct triplets (row i of Xp / Xn pairs with anchor row i of X1)
% source anchors, candidates from the source set
Ds = EuDist2(Xs,Xs);
[XpS, XnS] = mine_triplets(Ds, bsxfun(@eq, ys, ys'), Xs);
clear Ds;

% target anchors, candidates from the target set (pseudo-labels)
Dt = EuDist2(Xt,Xt);
[XpT, XnT] = mine_triplets(Dt, bsxfun(@eq, yt, yt'), Xt);
clear Dt;

% target anchors, source candidates, compared via label-histogram features
[HS,HT,H] = fea_trans(Xt',Xs',yt,ys,paras);
Dts = EuDist2(HT',HS');
[XpC, XnC] = mine_triplets(Dts, bsxfun(@eq, yt, ys'), Xs);
clear Dts;

Xp = [XpS; XpT; XpC];
Xn = [XnS; XnT; XnC];

Y = full(sparse(1:length(ys), double(ys), 1));   % one-hot source labels
L = cl(H,paras);
D = zeros(paras.d,r);
F = 0;
% NOTE(review): F and D are initialised once here and keep accumulating
% over every outer iteration below. If per-iteration values were intended,
% reset them at the top of the iter loop.

Bs = sign(2*rand(r,paras.ns)-1);
Bt = sign(2*rand(r,paras.nt)-1);

for iter=1:paras.max_iter
    for i=1:N
        xi = X1(i,:);
        xp = Xp(i,:);
        xn = Xn(i,:);

        dPos = norm(W'*(xi-xp)','fro');
        dNeg = norm(W'*(xi-xn)','fro');
        % only triplets violating the margin paras.m contribute
        if dPos - dNeg + paras.m >= 0
            omega = (1-exp(dNeg - dPos - paras.m))^2;
            F = F + omega*dPos - omega*dNeg;
            D = D + omega*((xi-xn)'*(xi-xn)-(xi-xp)'*(xi-xp))*W;
        end
    end
    [W, ~] = OptStiefelGBB(W, @W1_obj,setting,F,D,X,L,Xs,Xt,Bs,Bt,paras);
    % update the label regressor A, then the source codes Bs and target codes Bt
    A=(paras.lambda2*(Bs*Bs')+paras.lambda3*eye(size(Bs*Bs')))\(paras.lambda2*Bs*Y);
    Bs = sign((paras.lambda2*(A*A')+paras.theta1*eye(size(A*A')))\(paras.lambda2*A*Y'+paras.theta1 *W'*Xs'));
    Bt = sign(W'*Xt');
end

%% encode and evaluate
B_train = (Xs*W>0);
B_test = (test*W>0);
B_trn = compactbit(B_train);
B_tst = compactbit(B_test);

% compute Hamming metric and compute recall precision
Dhamm = hammingDist(B_tst, B_trn);
clear B_tst B_trn;

retrieved_list = [];
switch(param.choice)
    case 'evaluation_PR_MAP'
        [recall, precision, ~] = recall_precision(WtrueTestTraining, Dhamm);
        [rec, pre]= recall_precision5(WtrueTestTraining, Dhamm, pos); % recall VS. the number of retrieved sample
        [mAP] = area_RP(recall, precision);
    case {'evaluation_PR', 'visualization'}
        % These modes depended on variables that are never defined in this
        % function (trueRank, ID, train_data, test_data), so they could only
        % fail with an undefined-name error.
        error('demo:unsupportedChoice', ...
            'param.choice ''%s'' is not supported in this release.', param.choice);
    otherwise
        error('demo:unknownChoice', 'Unknown param.choice ''%s''.', param.choice);
end

end


% For each anchor (row i of Dist), select candidates from the rows of Pool:
%   Xp(i,:) - the same-label candidate with the LARGEST distance;
%   Xn(i,:) - the candidate minimising Dist - S.*Dist.
% S(i,j) is true when anchor i and candidate j share a label.
function [Xp, Xn] = mine_triplets(Dist, S, Pool)
Dp = S .* Dist;
[~, idxPos] = max(Dp, [], 2);
Dn = Dist - Dp;
% NOTE(review): Dn is 0 at every same-label entry, and Dist is
% non-negative. min() therefore typically returns the FIRST same-label
% column, not the nearest different-label sample. For a hardest negative,
% set same-label entries to Inf before min(). Kept unchanged so results
% match the published code.
[~, idxNeg] = min(Dn, [], 2);
Xp = Pool(idxPos, :);
Xn = Pool(idxNeg, :);
end
155 |
--------------------------------------------------------------------------------
/main_demo.m:
--------------------------------------------------------------------------------
1 |
close all; clear all; clc;
addpath('./utils/');
db_name = 'MNIST&USPS';
param.choice = 'evaluation_PR_MAP';
loopnbits = [32 64];
num_test=500;
runtimes = 1; % increase to average over several random splits (smoother curves)
param.pos = [1:10:40 50:50:1000]; % cut-offs for the recall/precision vs. number of retrieved samples curves
hashmethods = {'PWCF'};

nhmethods = length(hashmethods);

for k = 1:runtimes
    fprintf('The %d run time, start constructing data\n\n', k);
    exp_data = construct_dataset(db_name,num_test,0);
    fprintf('Constructing data finished\n\n');
    for i =1:length(loopnbits)
        fprintf('======start %d bits encoding======\n\n', loopnbits(i));
        param.r = loopnbits(i);
        for j = 1:nhmethods
            [recall{k}{i, j}, precision{k}{i, j}, mAP{k}{i,j}, rec{k}{i, j}, pre{k}{i, j}, ~] = demo(exp_data, param, hashmethods{1, j});
        end
    end
end

% plot settings
line_width = 1.5;
marker_size = 4;
xy_font_size = 16;
legend_font_size = 14;
linewidth = 1.6;
title_font_size = 18;

% tick labels for the mAP plot, derived from the code lengths actually run
bit_labels = arrayfun(@num2str, loopnbits, 'UniformOutput', false);

% average MAP over the runs
for j = 1:nhmethods
    for i =1: length(loopnbits)
        tmp = zeros(size(mAP{1, 1}{i, j}));
        for k = 1:runtimes
            tmp = tmp+mAP{1, k}{i, j};
        end
        MAP{i, j} = tmp/runtimes;
    end
    clear tmp;
end
MAP
choose_bits = 1; % index into loopnbits of the code length to plot
choose_times = 1; % which run to plot
%% show recall vs. the number of retrieved sample.
figure('Color', [1 1 1]); hold on;
for j = 1: nhmethods
    pos = param.pos;
    recc = rec{choose_times}{choose_bits, j};
    p = plot(pos(1,1:end), recc(1,1:end));
    color = gen_color(j);
    marker = gen_marker(j);
    set(p,'Color', color)
    set(p,'Marker', marker);
    set(p,'LineWidth', line_width);
    set(p,'MarkerSize', marker_size);
end

str_nbits = num2str(loopnbits(choose_bits));
set(gca, 'linewidth', linewidth);
h1 = xlabel('The number of retrieved samples');
h2 = ylabel(['Recall @ ', str_nbits, ' bits']);
title(db_name, 'FontSize', title_font_size);
set(h1, 'FontSize', xy_font_size);
set(h2, 'FontSize', xy_font_size);
hleg = legend(hashmethods);
set(hleg, 'FontSize', legend_font_size);
set(hleg,'Location', 'best');
box on;
grid on;
hold off;

%% show precision vs. the number of retrieved sample.
figure('Color', [1 1 1]); hold on;
for j = 1: nhmethods
    pos = param.pos;
    prec = pre{choose_times}{choose_bits, j};
    p = plot(pos(1,1:end), prec(1,1:end));
    color = gen_color(j);
    marker = gen_marker(j);
    set(p,'Color', color)
    set(p,'Marker', marker);
    set(p,'LineWidth', line_width);
    set(p,'MarkerSize', marker_size);
end

str_nbits = num2str(loopnbits(choose_bits));
set(gca, 'linewidth', linewidth);
h1 = xlabel('The number of retrieved samples');
h2 = ylabel(['Precision @ ', str_nbits, ' bits']);
title(db_name, 'FontSize', title_font_size);
set(h1, 'FontSize', xy_font_size);
set(h2, 'FontSize', xy_font_size);
hleg = legend(hashmethods);
set(hleg, 'FontSize', legend_font_size);
set(hleg,'Location', 'best');
box on;
grid on;
hold off;

%% show precision vs. recall for the selected code length.
figure('Color', [1 1 1]); hold on;

for j = 1: nhmethods
    p = plot(recall{choose_times}{choose_bits, j}, precision{choose_times}{choose_bits, j});
    color=gen_color(j);
    marker=gen_marker(j);
    set(p,'Color', color)
    set(p,'Marker', marker);
    set(p,'LineWidth', line_width);
    set(p,'MarkerSize', marker_size);
end

str_nbits = num2str(loopnbits(choose_bits));
h1 = xlabel(['Recall @ ', str_nbits, ' bits']);
h2 = ylabel('Precision');
title(db_name, 'FontSize', title_font_size);
set(h1, 'FontSize', xy_font_size);
set(h2, 'FontSize', xy_font_size);
hleg = legend(hashmethods);
set(hleg, 'FontSize', legend_font_size);
set(hleg,'Location', 'best');
set(gca, 'linewidth', linewidth);
box on;
grid on;
hold off;

%% show mAP. This mAP function is provided by Yunchao Gong
figure('Color', [1 1 1]); hold on;
for j = 1: nhmethods
    map = [];
    for i = 1: length(loopnbits)
        map = [map, MAP{i, j}];
    end
    p = plot(log2(loopnbits), map);
    color=gen_color(j);
    marker=gen_marker(j);
    set(p,'Color', color);
    set(p,'Marker', marker);
    set(p,'LineWidth', line_width);
    set(p,'MarkerSize', marker_size);
end

h1 = xlabel('Number of bits');
h2 = ylabel('mean Average Precision (mAP)');
title(db_name, 'FontSize', title_font_size);
set(h1, 'FontSize', xy_font_size);
set(h2, 'FontSize', xy_font_size);
set(gca, 'xtick', log2(loopnbits));
% Labels must match the ticks one-to-one. A hard-coded {'16','32',...}
% list would label the 32- and 64-bit ticks as 16 and 32.
set(gca, 'XtickLabel', bit_labels);
set(gca, 'linewidth', linewidth);
hleg = legend(hashmethods);
set(hleg, 'FontSize', legend_font_size);
set(hleg, 'Location', 'best');
box on;
grid on;
hold off;
171 |
--------------------------------------------------------------------------------
/OptStiefelGBB.m:
--------------------------------------------------------------------------------
function [X, out] = OptStiefelGBB(X, fun, opts, varargin)
%-------------------------------------------------------------------------
% Curvilinear search algorithm for optimization on the Stiefel manifold
%
%   min F(X), s.t. X'*X = I_k, where X \in R^{n,k}
%
%   H = [G, X]*[X -G]'
%   U = 0.5*tau*[G, X]; V = [X -G]
%   X(tau) = X - 2*U * inv( I + V'*U ) * V'*X
%
%   -------------------------------------
%   U = -[G,X]; V = [X -G]; VU = V'*U;
%   X(tau) = X - tau*U * inv( I + 0.5*tau*VU ) * V'*X
%
% Input:
%   X    --- n by k matrix such that X'*X = I
%   fun  --- objective function and its gradient:
%            [F, G] = fun(X, data1, data2)
%            F, G are the objective function value and gradient, respectively;
%            data1, data2 are additional data, and there can be more.
%            Calling syntax:
%              [X, out] = OptStiefelGBB(X0, @fun, opts, data1, data2);
%
%   opts --- option structure. A missing or out-of-range field falls back to
%            the default given in brackets:
%            record     >= 1 prints an iteration log                  [0]
%            mxitr      max number of iterations                      [1000]
%            xtol       stop control for ||X_k - X_{k-1}||            [1e-6]
%            gtol       stop control for the projected gradient       [1e-6]
%            ftol       stop control for |F_k - F_{k-1}|/(1+|F_{k-1}|) [1e-12]
%                       usually, max{xtol, gtol} > ftol
%            rho        sufficient-decrease parameter of the line search [1e-4]
%            eta        step-size shrink factor of the backtracking   [0.2]
%            gamma      weight of the nonmonotone reference value     [0.85]
%            tau        initial step size                             [1e-3]
%            nt         window of the averaged stopping criterion     [5]
%            projG      1: build the update from G,
%                       2: from G - 0.5*X*(X'*G)                      [1]
%            iscomplex  0: error if an iterate becomes complex        [0]
%
% Output:
%   X    --- solution
%   out  --- output information:
%            out.nfe    number of calls to fun
%            out.msg    'converge' or 'exceed max iteration'
%            out.feasi  ||X'*X - I||_F at the returned X
%            out.nrmG   ||G - X*G'*X||_F at the returned X
%            out.fval   objective value at the returned X
%            out.itr    number of iterations performed (0 if mxitr == 0)
%
% -------------------------------------
% For example, consider the eigenvalue problem F(X) = -0.5*Tr(X'*A*X);
%
% function demo
%
%   function [F, G] = fun(X, A)
%     G = -(A*X);
%     F = 0.5*sum(dot(G,X,1));
%   end
%
%   n = 1000; k = 6;
%   A = randn(n); A = A'*A;
%   opts.record = 0;
%   opts.mxitr  = 1000;
%   opts.xtol = 1e-5;
%   opts.gtol = 1e-5;
%   opts.ftol = 1e-8;
%
%   X0 = randn(n,k); X0 = orth(X0);
%   tic; [X, out]= OptStiefelGBB(X0, @fun, opts, A); tsolve = toc;
%   out.fval = -2*out.fval; % convert the function value to the sum of eigenvalues
%   fprintf('\nOptM: obj: %7.6e, itr: %d, nfe: %d, cpu: %f, norm(XT*X-I): %3.2e \n', ...
%           out.fval, out.itr, out.nfe, tsolve, norm(X'*X - eye(k), 'fro') );
%
% end
% -------------------------------------
%
% Reference:
%   Z. Wen and W. Yin
%   A feasible method for optimization with orthogonality constraints
%
% Author: Zaiwen Wen, Wotao Yin
% Version 1.0 .... 2010/10
%-------------------------------------------------------------------------

%% Size information
if isempty(X)
    error('input X is an empty matrix');
end
[n, k] = size(X);

%% Options: each falls back to its default when missing or out of range
outside01 = @(v) v < 0 || v > 1;
opts = opt_default(opts, 'xtol',  1e-6,  outside01);
opts = opt_default(opts, 'gtol',  1e-6,  outside01);
opts = opt_default(opts, 'ftol',  1e-12, outside01);
% parameter controlling the linear approximation in the line search
opts = opt_default(opts, 'rho',   1e-4,  outside01);
% factor for decreasing the step size in the backtracking line search
% (an invalid value now falls back to the same 0.2 used when it is missing)
opts = opt_default(opts, 'eta',   0.2,   outside01);
% weight for updating the nonmonotone reference value Cval
opts = opt_default(opts, 'gamma', 0.85,  outside01);
opts = opt_default(opts, 'tau',   1e-3,  @(v) v < 0 || v > 1e3);
opts = opt_default(opts, 'nt',    5,     @(v) v < 0 || v > 100);
opts = opt_default(opts, 'projG', 1,     @(v) ~(isequal(v, 1) || isequal(v, 2)));
opts = opt_default(opts, 'iscomplex', 0, @(v) ~(isequal(v, 0) || isequal(v, 1)));
opts = opt_default(opts, 'mxitr', 1000,  @(v) v < 0 || v > 2^20);
opts = opt_default(opts, 'record', 0,    @(v) false);

%% Copy parameters
xtol = opts.xtol;
gtol = opts.gtol;
ftol = opts.ftol;
rho = opts.rho;
eta = opts.eta;
gamma = opts.gamma;
iscomplex = opts.iscomplex;
record = opts.record;
projG = opts.projG;
nt = opts.nt; crit = ones(nt, 3);
max_ls = 5;   % maximum number of trial steps per line search

% For k >= n/2 solve with the n-by-n matrix I + tau*H directly; otherwise
% use the rank-2k form, which only needs a 2k-by-2k solve.
invH = ~(k < n/2);
if invH
    eyen = eye(n);
else
    eye2k = eye(2*k);
end

%% Initial function value and gradient
[F, G] = feval(fun, X, varargin{:});
out.nfe = 1;
[H, RX, U, VU, VX] = curve_factors(G, X, invH, projG);
dtX = G - X*(G'*X);            % projected gradient
nrmG = norm(dtX, 'fro');

Q = 1; Cval = F; tau = opts.tau;
out.msg = 'exceed max iteration';   % replaced by 'converge' on convergence

%% Print iteration header
if record >= 1
    fid = 1;
    fprintf(fid, '----------- Gradient Method with Line search ----------- \n');
    fprintf(fid, '%4s %8s %8s %10s %10s\n', 'Iter', 'tau', 'F(X)', 'nrmG', 'XDiff');
end

%% Main iteration
for itr = 1:opts.mxitr
    XP = X; FP = F; dtXP = dtX;

    % Nonmonotone backtracking along the curve X(tau). The second output of
    % linsolve is requested only to suppress its ill-conditioning warning.
    nls = 1; deriv = rho*nrmG^2;
    while true
        if invH
            [X, ~] = linsolve(eyen + tau*H, XP - tau*RX);
        else
            [aa, ~] = linsolve(eye2k + (0.5*tau)*VU, VX);
            X = XP - U*(tau*aa);
        end
        if ~isreal(X) && ~iscomplex
            error('X is complex');
        end

        [F, G] = feval(fun, X, varargin{:});
        out.nfe = out.nfe + 1;

        if F <= Cval - tau*deriv || nls >= max_ls
            break;
        end
        tau = eta*tau;
        nls = nls + 1;
    end

    [H, RX, U, VU, VX] = curve_factors(G, X, invH, projG);
    dtX = G - X*(G'*X);
    nrmG = norm(dtX, 'fro');

    S = X - XP;
    XDiff = norm(S, 'fro')/sqrt(n);
    FDiff = abs(FP - F)/(abs(FP) + 1);

    % Alternate the two Barzilai-Borwein step sizes, then clamp.
    Y = dtX - dtXP;
    if iscomplex
        SY = abs(sum(sum(conj(S).*Y)));
        if mod(itr, 2) == 0
            tau = sum(sum(conj(S).*S))/SY;
        else
            tau = SY/sum(sum(conj(Y).*Y));
        end
    else
        SY = abs(sum(sum(S.*Y)));
        if mod(itr, 2) == 0
            tau = sum(sum(S.*S))/SY;
        else
            tau = SY/sum(sum(Y.*Y));
        end
    end
    tau = max(min(tau, 1e20), 1e-20);

    if record >= 1
        fprintf('%4d %3.2e %4.3e %3.2e %3.2e %3.2e %2d\n', ...
            itr, tau, F, nrmG, XDiff, FDiff, nls);
    end

    % Stop on a small step and objective change, a small projected gradient,
    % or small averages of XDiff/FDiff over the last nt iterations. During
    % the first two iterations the tolerances are tightened instead.
    crit(itr, :) = [nrmG, XDiff, FDiff];
    mcrit = mean(crit(itr-min(nt, itr)+1:itr, :), 1);
    if (XDiff < xtol && FDiff < ftol) || nrmG < gtol || all(mcrit(2:3) < 10*[xtol, ftol])
        if itr <= 2
            ftol = 0.1*ftol;
            xtol = 0.1*xtol;
            gtol = 0.1*gtol;
        else
            out.msg = 'converge';
            break;
        end
    end

    % Update the nonmonotone reference value as a weighted average of F.
    Qp = Q; Q = gamma*Qp + 1; Cval = (gamma*Qp*Cval + F)/Q;
end
if isempty(itr)    % the loop body never ran (mxitr == 0)
    itr = 0;
end

%% Re-orthonormalize if X drifted from X'*X = I, and refresh F, G, nrmG
out.feasi = norm(X'*X - eye(k), 'fro');
if out.feasi > 1e-13
    X = MGramSchmidt(X);
    [F, G] = feval(fun, X, varargin{:});
    out.nfe = out.nfe + 1;
    out.feasi = norm(X'*X - eye(k), 'fro');
    nrmG = norm(G - X*(G'*X), 'fro');
end

out.nrmG = nrmG;
out.fval = F;
out.itr = itr;
end


function opts = opt_default(opts, name, default, isbad)
% Set opts.(name) to DEFAULT when the field is missing or ISBAD(value) is true.
if ~isfield(opts, name) || isbad(opts.(name))
    opts.(name) = default;
end
end


function [H, RX, U, VU, VX] = curve_factors(G, X, invH, projG)
% Build the quantities defining the search curve X(tau) at the point X.
%   invH true : H = skew part of G*X', RX = H*X          (U, VU, VX empty)
%   invH false: U = [G, X], V = [X, -G], VU = V'*U, VX = V'*X (H, RX empty),
%               with G replaced by G - 0.5*X*(X'*G) when projG == 2.
H = []; RX = []; U = []; VU = []; VX = [];
if invH
    GXT = G*X';
    H = 0.5*(GXT - GXT');
    RX = H*X;
else
    if projG == 2
        G = G - 0.5*X*(X'*G);
    end
    U = [G, X];
    V = [X, -G];
    VU = V'*U;
    VX = V'*X;
end
end
341 |
342 |
343 |
344 |
345 |
--------------------------------------------------------------------------------