├── README.txt ├── .gitattributes ├── .gitignore ├── TCOT.m ├── OPW.m ├── OPW_w.m ├── pdist2.m └── sinkhornTransport.m /README.txt: -------------------------------------------------------------------------------- https://raw.githubusercontent.com/BingSu12/OPW/HEAD/README.txt -------------------------------------------------------------------------------- /.gitattributes: -------------------------------------------------------------------------------- 1 | # Auto detect text files and perform LF normalization 2 | * text=auto 3 | 4 | # Custom for Visual Studio 5 | *.cs diff=csharp 6 | 7 | # Standard to msysgit 8 | *.doc diff=astextplain 9 | *.DOC diff=astextplain 10 | *.docx diff=astextplain 11 | *.DOCX diff=astextplain 12 | *.dot diff=astextplain 13 | *.DOT diff=astextplain 14 | *.pdf diff=astextplain 15 | *.PDF diff=astextplain 16 | *.rtf diff=astextplain 17 | *.RTF diff=astextplain 18 | -------------------------------------------------------------------------------- /.gitignore: -------------------------------------------------------------------------------- 1 | # Windows image file caches 2 | Thumbs.db 3 | ehthumbs.db 4 | 5 | # Folder config file 6 | Desktop.ini 7 | 8 | # Recycle Bin used on file shares 9 | $RECYCLE.BIN/ 10 | 11 | # Windows Installer files 12 | *.cab 13 | *.msi 14 | *.msm 15 | *.msp 16 | 17 | # Windows shortcuts 18 | *.lnk 19 | 20 | # ========================= 21 | # Operating System Files 22 | # ========================= 23 | 24 | # OSX 25 | # ========================= 26 | 27 | .DS_Store 28 | .AppleDouble 29 | .LSOverride 30 | 31 | # Thumbnails 32 | ._* 33 | 34 | # Files that might appear in the root of a volume 35 | .DocumentRevisions-V100 36 | .fseventsd 37 | .Spotlight-V100 38 | .TemporaryItems 39 | .Trashes 40 | .VolumeIcon.icns 41 | 42 | # Directories potentially created on remote AFP share 43 | .AppleDB 44 | .AppleDesktop 45 | Network Trash Folder 46 | Temporary Items 47 | .apdisk 48 | -------------------------------------------------------------------------------- /TCOT.m: -------------------------------------------------------------------------------- 1 | function [dis,T] = TCOT(X,Y,lambda,VERBOSE) 2 | % Compute the Temporally Coupled Optimal Transport (TCOT) distance for two 3 | % sequences X and Y 4 | 5 | % ------------- 6 | % DEPENDENCY: 7 | % ------------- 8 | % "sinkhornTransport.m" by Marco Cuturi; website: http://marcocuturi.net/SI.html 9 | % Please download and add the code into the current directory 10 | % Relevant paper: 11 | % M. 
Cuturi,
12 | % Sinkhorn Distances : Lightspeed Computation of Optimal Transport,
13 | % Advances in Neural Information Processing Systems (NIPS) 26, 2013
14 | 
15 | % -------------
16 | % INPUT:
17 | % -------------
18 | % X: an N * d matrix representing the input sequence, consisting of N
19 | % d-dimensional vectors, where N is the number of instances (vectors) in X,
20 | % and d is the dimensionality of instances;
21 | % Y: an M * d matrix representing the input sequence, consisting of M
22 | % d-dimensional vectors, where M is the number of instances (vectors) in
23 | % Y, and d is the dimensionality of instances;
24 | % (the maximum number of Sinkhorn iterations is set by maxIter below)
25 | % lambda: the weight of the entropy regularization, default value: 1
26 | % VERBOSE: whether to display the iteration status, default value: 0 (not displayed)
27 | 
28 | % -------------
29 | % OUTPUT
30 | % -------------
31 | % dis: the TCOT distance between X and Y
32 | % T: the learned transport between X and Y, which is an N*M matrix
33 | 
34 | 
35 | % -------------
36 | % c : barycenter according to weights
37 | % ADVICE: divide M by median(M) to have a natural scale
38 | % for lambda
39 | 
40 | % -------------
41 | % Copyright (c) 2017 Bing Su, Gang Hua
42 | % -------------
43 | %
44 | % -------------
45 | % License
46 | % The code can be used for research purposes only.
47 | 
48 | if nargin<3 || isempty(lambda)
49 |     lambda = 1;
50 | end
51 | 
52 | if nargin<4 || isempty(VERBOSE)
53 |     VERBOSE = 0;
54 | end
55 | 
56 | tolerance=.5e-2;
57 | maxIter=100;
58 | % The maximum number of iterations;
59 | % Set it to a large value (e.g., 1000 or 10000) to obtain a more precise
60 | % transport;
61 | p_norm=inf;
62 | 
63 | N = size(X,1);
64 | M = size(Y,1);
65 | dim = size(X,2);
66 | if size(Y,2)~=dim
67 |     disp('The dimensions of instances in the input sequences must be the same!');
68 | end
69 | 
70 | D = zeros(N,M);
71 | for i = 1:N
72 |     for j = 1:M
73 |         D(i,j) = sum((X(i,:)-Y(j,:)).^2);
74 |         D(i,j) = D(i,j)*(1+abs(i/N-j/M));
75 |     end
76 | end
77 | 
78 | % D = pdist2(X,Y, 'sqeuclidean');
79 | % for i = 1:N
80 | %     for j = 1:M
81 | %         D(i,j) = D(i,j)*(1+abs(i/N-j/M));
82 | %     end
83 | % end
84 | 
85 | % If the instances in the sequences are not normalized and/or are very
86 | % high-dimensional, the matrix D can be normalized or scaled as follows:
87 | % D = D/max(max(D)); D = D/(10^2);
88 | 
89 | 
90 | K=exp(-lambda*D);
91 | % With some parameters, some entries of K may exceed the machine-precision
92 | % limit; in such cases, you may need to adjust the parameters, and/or
93 | % normalize the input features in sequences or the matrix D; please see the
94 | % paper for details.
95 | % In practical situations it might be a good idea to do the following:
96 | % K(K<1e-100)=1e-100;
97 | 
98 | U=K.*D;
99 | 
100 | a = ones(N,1)./N;
101 | b = ones(M,1)./M;
102 | 
103 | % Call the dependency "sinkhornTransport.m" to solve the matrix scaling
104 | % problem
105 | [dis,lowerEMD,l,m]=sinkhornTransport(a,b,K,U,lambda,[],p_norm,tolerance,maxIter,VERBOSE);
106 | T=bsxfun(@times,m',(bsxfun(@times,l,K))); % this is the optimal transport.
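% -------------
% Example usage (a hedged sketch, not part of the original code; the random
% sequences below are made up only to illustrate the calling convention
% documented above):
%   Xs = randn(20,5); Ys = randn(30,5);   % two sequences of 5-dimensional frames
%   [dis,T] = TCOT(Xs, Ys, 1, 0);         % lambda = 1, no verbose output
%   % dis is the TCOT distance; T is the 20*30 transport matrix, and
%   % sum(T(:)) should be close to 1 since both marginals sum to one.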
107 | 
108 | end
--------------------------------------------------------------------------------
/OPW.m:
--------------------------------------------------------------------------------
1 | function [dis,T] = OPW(X,Y,lamda1,lamda2,delta,VERBOSE)
2 | % Compute the Order-Preserving Wasserstein Distance (OPW) for two sequences
3 | % X and Y
4 | 
5 | % -------------
6 | % INPUT:
7 | % -------------
8 | % X: an N * d matrix representing the input sequence, consisting of N
9 | % d-dimensional vectors, where N is the number of instances (vectors) in X,
10 | % and d is the dimensionality of instances;
11 | % Y: an M * d matrix representing the input sequence, consisting of M
12 | % d-dimensional vectors, where M is the number of instances (vectors) in
13 | % Y, and d is the dimensionality of instances;
14 | % (the maximum number of Sinkhorn iterations is set by maxIter below)
15 | % lamda1: the weight of the IDM regularization, default value: 50
16 | % lamda2: the weight of the KL-divergence regularization, default value:
17 | % 0.1
18 | % delta: the parameter of the prior Gaussian distribution, default value: 1
19 | % VERBOSE: whether to display the iteration status, default value: 0 (not displayed)
20 | 
21 | % -------------
22 | % OUTPUT
23 | % -------------
24 | % dis: the OPW distance between X and Y
25 | % T: the learned transport between X and Y, which is an N*M matrix
26 | 
27 | 
28 | % -------------
29 | % c : barycenter according to weights
30 | % ADVICE: divide M by median(M) to have a natural scale
31 | % for lambda
32 | 
33 | % -------------
34 | % Copyright (c) 2017 Bing Su, Gang Hua
35 | % -------------
36 | %
37 | % -------------
38 | % License
39 | % The code can be used for research purposes only.
40 | 
41 | if nargin<3 || isempty(lamda1)
42 |     lamda1 = 50;
43 | end
44 | 
45 | if nargin<4 || isempty(lamda2)
46 |     lamda2 = 0.1;
47 | end
48 | 
49 | if nargin<5 || isempty(delta)
50 |     delta = 1;
51 | end
52 | 
53 | if nargin<6 || isempty(VERBOSE)
54 |     VERBOSE = 0;
55 | end
56 | 
57 | tolerance=.5e-2;
58 | maxIter= 20;
59 | % The maximum number of iterations; with a default small value, the
60 | % tolerance and VERBOSE may not be used;
61 | % Set it to a large value (e.g., 1000 or 10000) to obtain a more precise
62 | % transport;
63 | p_norm=inf;
64 | 
65 | N = size(X,1);
66 | M = size(Y,1);
67 | dim = size(X,2);
68 | if size(Y,2)~=dim
69 |     disp('The dimensions of instances in the input sequences must be the same!');
70 | end
71 | 
72 | P = zeros(N,M);
73 | mid_para = sqrt((1/(N^2) + 1/(M^2)));
74 | for i = 1:N
75 |     for j = 1:M
76 |         d = abs(i/N - j/M)/mid_para;
77 |         P(i,j) = exp(-d^2/(2*delta^2))/(delta*sqrt(2*pi));
78 |     end
79 | end
80 | 
81 | S = zeros(N,M);
82 | for i = 1:N
83 |     for j = 1:M
84 |         S(i,j) = lamda1/((i/N-j/M)^2+1);
85 |     end
86 | end
87 | 
88 | D = pdist2(X,Y, 'sqeuclidean');
89 | % If the instances in the sequences are not normalized and/or are very
90 | % high-dimensional, the matrix D can be normalized or scaled as follows:
91 | % D = D/max(max(D)); D = D/(10^2);
92 | 
93 | K = P.*exp((S - D)./lamda2);
94 | % With some parameters, some entries of K may exceed the machine-precision
95 | % limit; in such cases, you may need to adjust the parameters, and/or
96 | % normalize the input features in sequences or the matrix D; please see the
97 | % paper for details.
98 | % In practical situations it might be a good idea to do the following:
99 | % K(K<1e-100)=1e-100;
100 | 
101 | 
102 | a = ones(N,1)./N;
103 | b = ones(M,1)./M;
104 | 
105 | ainvK=bsxfun(@rdivide,K,a);
106 | 
107 | compt=0;
108 | u=ones(N,1)/N;
109 | 
110 | % The Sinkhorn's fixed point iteration
111 | % This part of the code is adopted from the code "sinkhornTransport.m" by Marco
112 | % Cuturi; website: http://marcocuturi.net/SI.html
113 | % Relevant paper:
114 | % M. Cuturi,
115 | % Sinkhorn Distances : Lightspeed Computation of Optimal Transport,
116 | % Advances in Neural Information Processing Systems (NIPS) 26, 2013
117 | while compt<maxIter
118 |     u=1./(ainvK*(b./(K'*u)));
119 |     compt=compt+1;
120 |     % check the stopping criterion every 20 fixed point iterations,
121 |     % or before the final iteration to store the most recent value of v
122 |     if mod(compt,20)==1 || compt==maxIter
123 |         % split computations to recover right and left scalings
124 |         v=b./(K'*u);
125 |         u=1./(ainvK*v);
126 |         % norm of the difference between the current and target marginals
127 |         Criterion=norm(sum(abs(v.*(K'*u)-b)),p_norm);
128 |         if Criterion<tolerance || isnan(Criterion)
129 |             break;
130 |         end
131 |         compt=compt+1;
132 |         if VERBOSE>0
133 |             disp(['Iteration :',num2str(compt),' Criterion: ',num2str(Criterion)]);
134 |         end
135 |     end
136 | end
137 | 
138 | U = K.*D;
139 | dis=sum(u.*(U*v));
140 | T=bsxfun(@times,v',(bsxfun(@times,u,K)));
141 | 
142 | end
--------------------------------------------------------------------------------
/OPW_w.m:
--------------------------------------------------------------------------------
1 | function [dis,T] = OPW_w(X,Y,a,b,lamda1,lamda2,delta,VERBOSE)
2 | % Compute the Order-Preserving Wasserstein Distance (OPW) for two sequences
3 | % X and Y, with instance weights a and b
4 | 
5 | % -------------
6 | % INPUT:
7 | % -------------
8 | % X: an N * d matrix representing the input sequence, consisting of N
9 | % d-dimensional vectors, where N is the number of instances (vectors) in X,
10 | % and d is the dimensionality of instances;
11 | % Y: an M * d matrix representing the input sequence, consisting of M
12 | % d-dimensional vectors, where M is the number of instances (vectors) in
13 | % Y, and d is the dimensionality of instances;
14 | % (the maximum number of Sinkhorn iterations is set by maxIter below)
15 | % a: an N * 1 weight vector for the vectors in X; defaults to uniform weights if [] is passed
16 | % b: an M * 1 weight vector for the vectors in Y; defaults to uniform weights if [] is passed
17 | % lamda1: the weight of the IDM regularization, default value: 50
18 | % lamda2: the weight of the KL-divergence regularization, default value:
19 | % 0.1
20 | % delta: the parameter of the prior Gaussian distribution, default value: 1
21 | % VERBOSE: whether to display the iteration status, default value: 0 (not displayed)
22 | 
23 | % -------------
24 | % OUTPUT
25 | % -------------
26 | % dis: the OPW distance between X and Y
27 | % T: the learned transport between X and Y, which is an N*M matrix
28 | 
29 | 
30 | % -------------
31 | % c : barycenter according to weights
32 | % ADVICE: divide M by median(M) to have a natural scale
33 | % for lambda
34 | 
35 | % -------------
36 | % Copyright (c) 2017 Bing Su, Gang Hua
37 | % -------------
38 | %
39 | % -------------
40 | % License
41 | % The code can be used for research purposes only.
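% -------------
% Example usage (a hedged sketch, not part of the original code; the
% sequences and weights below are made up only to illustrate the calling
% convention documented above):
%   Xs = randn(20,5); Ys = randn(30,5);
%   bw = linspace(1,2,30)'; bw = bw/sum(bw);        % non-uniform weights for Ys
%   [dis,T] = OPW_w(Xs, Ys, [], bw, 50, 0.1, 1, 0); % uniform weights for Xs
%   % Passing [] for both weight vectors reduces OPW_w to OPW, i.e.
%   % OPW_w(Xs, Ys, [], [], 50, 0.1, 1, 0) matches OPW(Xs, Ys, 50, 0.1, 1, 0).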
42 | 
43 | if nargin<5 || isempty(lamda1)
44 |     lamda1 = 50;
45 | end
46 | 
47 | if nargin<6 || isempty(lamda2)
48 |     lamda2 = 0.1;
49 | end
50 | 
51 | if nargin<7 || isempty(delta)
52 |     delta = 1;
53 | end
54 | 
55 | if nargin<8 || isempty(VERBOSE)
56 |     VERBOSE = 0;
57 | end
58 | 
59 | tolerance=.5e-2;
60 | maxIter= 20;
61 | % The maximum number of iterations; with a default small value, the
62 | % tolerance and VERBOSE may not be used;
63 | % Set it to a large value (e.g., 1000 or 10000) to obtain a more precise
64 | % transport;
65 | p_norm=inf;
66 | 
67 | N = size(X,1);
68 | M = size(Y,1);
69 | dim = size(X,2);
70 | if size(Y,2)~=dim
71 |     disp('The dimensions of instances in the input sequences must be the same!');
72 | end
73 | 
74 | P = zeros(N,M);
75 | mid_para = sqrt((1/(N^2) + 1/(M^2)));
76 | for i = 1:N
77 |     for j = 1:M
78 |         d = abs(i/N - j/M)/mid_para;
79 |         P(i,j) = exp(-d^2/(2*delta^2))/(delta*sqrt(2*pi));
80 |     end
81 | end
82 | 
83 | %D = zeros(N,M);
84 | S = zeros(N,M);
85 | for i = 1:N
86 |     for j = 1:M
87 |         %D(i,j) = sum((X(i,:)-Y(j,:)).^2);
88 |         S(i,j) = lamda1/((i/N-j/M)^2+1);
89 |     end
90 | end
91 | 
92 | D = pdist2(X,Y, 'sqeuclidean');
93 | %D = D/(10^2);
94 | % If the instances in the sequences are not normalized and/or are very
95 | % high-dimensional, the matrix D can be normalized or scaled as follows:
96 | % D = D/max(max(D)); D = D/(10^2);
97 | 
98 | K = P.*exp((S - D)./lamda2);
99 | % With some parameters, some entries of K may exceed the machine-precision
100 | % limit; in such cases, you may need to adjust the parameters, and/or
101 | % normalize the input features in sequences or the matrix D; please see the
102 | % paper for details.
103 | % In practical situations it might be a good idea to do the following:
104 | % K(K<1e-100)=1e-100;
105 | 
106 | if nargin<3 || isempty(a)
107 |     a = ones(N,1)./N;
108 | end
109 | 
110 | if nargin<4 || isempty(b)
111 |     b = ones(M,1)./M;
112 | end
113 | 
114 | ainvK=bsxfun(@rdivide,K,a);
115 | 
116 | compt=0;
117 | u=ones(N,1)/N;
118 | 
119 | % The Sinkhorn's fixed point iteration
120 | % This part of the code is adopted from the code "sinkhornTransport.m" by Marco
121 | % Cuturi; website: http://marcocuturi.net/SI.html
122 | % Relevant paper:
123 | % M. Cuturi,
124 | % Sinkhorn Distances : Lightspeed Computation of Optimal Transport,
125 | % Advances in Neural Information Processing Systems (NIPS) 26, 2013
126 | while compt<maxIter
127 |     u=1./(ainvK*(b./(K'*u)));
128 |     compt=compt+1;
129 |     % check the stopping criterion every 20 fixed point iterations,
130 |     % or before the final iteration to store the most recent value of v
131 |     if mod(compt,20)==1 || compt==maxIter
132 |         % split computations to recover right and left scalings
133 |         v=b./(K'*u);
134 |         u=1./(ainvK*v);
135 |         % norm of the difference between the current and target marginals
136 |         Criterion=norm(sum(abs(v.*(K'*u)-b)),p_norm);
137 |         if Criterion<tolerance || isnan(Criterion)
138 |             break;
139 |         end
140 |         compt=compt+1;
141 |         if VERBOSE>0
142 |             disp(['Iteration :',num2str(compt),' Criterion: ',num2str(Criterion)]);
143 |         end
144 |     end
145 | end
146 | 
147 | U = K.*D;
148 | dis=sum(u.*(U*v));
149 | T=bsxfun(@times,v',(bsxfun(@times,u,K)));
150 | 
151 | end
--------------------------------------------------------------------------------
/pdist2.m:
--------------------------------------------------------------------------------
1 | % This function belongs to Piotr Dollar's Toolbox
2 | % http://vision.ucsd.edu/~pdollar/toolbox/doc/index.html
3 | % Please refer to the above web page for definitions and clarifications
4 | %
5 | % Calculates the distance between sets of vectors.
6 | %
7 | % Let X be an m-by-p matrix representing m points in p-dimensional space
8 | % and Y be an n-by-p matrix representing another set of points in the same
9 | % space. This function computes the m-by-n distance matrix D where D(i,j)
10 | % is the distance between X(i,:) and Y(j,:). This function has been
11 | % optimized where possible, with most of the distance computations
12 | % requiring few or no loops.
13 | % 14 | % The metric can be one of the following: 15 | % 16 | % 'euclidean' / 'sqeuclidean': 17 | % Euclidean / SQUARED Euclidean distance. Note that 'sqeuclidean' 18 | % is significantly faster. 19 | % 20 | % 'chisq' 21 | % The chi-squared distance between two vectors is defined as: 22 | % d(x,y) = sum( (xi-yi)^2 / (xi+yi) ) / 2; 23 | % The chi-squared distance is useful when comparing histograms. 24 | % 25 | % 'cosine' 26 | % Distance is defined as the cosine of the angle between two vectors. 27 | % 28 | % 'emd' 29 | % Earth Mover's Distance (EMD) between positive vectors (histograms). 30 | % Note for 1D, with all histograms having equal weight, there is a simple 31 | % closed form for the calculation of the EMD. The EMD between histograms 32 | % x and y is given by the sum(abs(cdf(x)-cdf(y))), where cdf is the 33 | % cumulative distribution function (computed simply by cumsum). 34 | % 35 | % 'L1' 36 | % The L1 distance between two vectors is defined as: sum(abs(x-y)); 37 | % 38 | % 39 | % USAGE 40 | % D = pdist2( X, Y, [metric] ) 41 | % 42 | % INPUTS 43 | % X - [m x p] matrix of m p-dimensional vectors 44 | % Y - [n x p] matrix of n p-dimensional vectors 45 | % metric - ['sqeuclidean'], 'chisq', 'cosine', 'emd', 'euclidean', 'L1' 46 | % 47 | % OUTPUTS 48 | % D - [m x n] distance matrix 49 | % 50 | % EXAMPLE 51 | % [X,IDX] = demoGenData(100,0,5,4,10,2,0); 52 | % D = pdist2( X, X, 'sqeuclidean' ); 53 | % distMatrixShow( D, IDX ); 54 | % 55 | % See also PDIST, DISTMATRIXSHOW 56 | 57 | % Piotr's Image&Video Toolbox Version 2.0 58 | % Copyright (C) 2007 Piotr Dollar. [pdollar-at-caltech.edu] 59 | % Please email me if you find bugs, or have suggestions or questions! 60 | % Licensed under the Lesser GPL [see external/lgpl.txt] 61 | 62 | function D = pdist2( X, Y, metric ) 63 | 64 | if( nargin<3 || isempty(metric) ); metric=0; end; 65 | 66 | switch metric 67 | case {0,'sqeuclidean'} 68 | D = distEucSq( X, Y ); 69 | case 'euclidean' 70 | D = sqrt(distEucSq( X, Y )); 71 | case 'L1' 72 | D = distL1( X, Y ); 73 | case 'cosine' 74 | D = distCosine( X, Y ); 75 | case 'emd' 76 | D = distEmd( X, Y ); 77 | case 'chisq' 78 | D = distChiSq( X, Y ); 79 | otherwise 80 | error(['pdist2 - unknown metric: ' metric]); 81 | end 82 | 83 | 84 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 85 | function D = distL1( X, Y ) 86 | 87 | m = size(X,1); n = size(Y,1); 88 | mOnes = ones(1,m); D = zeros(m,n); 89 | for i=1:n 90 | yi = Y(i,:); yi = yi( mOnes, : ); 91 | D(:,i) = sum( abs( X-yi),2 ); 92 | end 93 | 94 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 95 | function D = distCosine( X, Y ) 96 | 97 | if( ~isa(X,'double') || ~isa(Y,'double')) 98 | error( 'Inputs must be of type double'); end; 99 | 100 | p=size(X,2); 101 | XX = sqrt(sum(X.*X,2)); X = X ./ XX(:,ones(1,p)); 102 | YY = sqrt(sum(Y.*Y,2)); Y = Y ./ YY(:,ones(1,p)); 103 | D = 1 - X*Y'; 104 | 105 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 106 | function D = distEmd( X, Y ) 107 | 108 | Xcdf = cumsum(X,2); 109 | Ycdf = cumsum(Y,2); 110 | 111 | m = size(X,1); n = size(Y,1); 112 | mOnes = ones(1,m); D = zeros(m,n); 113 | for i=1:n 114 | ycdf = Ycdf(i,:); 115 | ycdfRep = ycdf( mOnes, : ); 116 | D(:,i) = sum(abs(Xcdf - ycdfRep),2); 117 | end 118 | 119 | 120 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 121 | function D = distChiSq( X, Y ) 122 | 123 | %%% supposedly it's possible to implement this without a loop! 
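%%% One possible loop-free variant (a sketch only, not from the original
%%% toolbox; it forms an m x n x p array, so it trades memory for the loop):
%%%   Xb = permute(X,[1 3 2]); Yb = permute(Y,[3 1 2]);
%%%   s  = bsxfun(@plus,Xb,Yb); d = bsxfun(@minus,Xb,Yb);
%%%   D  = sum( d.^2 ./ (s+eps), 3 ) / 2;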
124 | m = size(X,1); n = size(Y,1); 125 | mOnes = ones(1,m); D = zeros(m,n); 126 | for i=1:n 127 | yi = Y(i,:); yiRep = yi( mOnes, : ); 128 | s = yiRep + X; d = yiRep - X; 129 | D(:,i) = sum( d.^2 ./ (s+eps), 2 ); 130 | end 131 | D = D/2; 132 | 133 | 134 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 135 | function D = distEucSq( X, Y ) 136 | 137 | %if( ~isa(X,'double') || ~isa(Y,'double')) 138 | % error( 'Inputs must be of type double'); end; 139 | m = size(X,1); n = size(Y,1); 140 | %Yt = Y'; 141 | XX = sum(X.*X,2); 142 | YY = sum(Y'.*Y',1); 143 | D = XX(:,ones(1,n)) + YY(ones(1,m),:) - 2*X*Y'; 144 | 145 | 146 | 147 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% 148 | % function D = distEucSq( X, Y ) 149 | %%%% code from Charles Elkan with variables renamed 150 | % m = size(X,1); n = size(Y,1); 151 | % D = sum(X.^2, 2) * ones(1,n) + ones(m,1) * sum(Y.^2, 2)' - 2.*X*Y'; 152 | 153 | 154 | %%% LOOP METHOD - SLOW 155 | % [m p] = size(X); 156 | % [n p] = size(Y); 157 | % 158 | % D = zeros(m,n); 159 | % onesM = ones(m,1); 160 | % for i=1:n 161 | % y = Y(i,:); 162 | % d = X - y(onesM,:); 163 | % D(:,i) = sum( d.*d, 2 ); 164 | % end 165 | 166 | 167 | %%% PARALLEL METHOD THAT IS SUPER SLOW (slower then loop)! 168 | % % From "MATLAB array manipulation tips and tricks" by Peter J. Acklam 169 | % Xb = permute(X, [1 3 2]); 170 | % Yb = permute(Y, [3 1 2]); 171 | % D = sum( (Xb(:,ones(1,n),:) - Yb(ones(1,m),:,:)).^2, 3); 172 | 173 | 174 | %%% USELESS FOR EVEN VERY LARGE ARRAYS X=16000x1000!! and Y=100x1000 175 | % call recursively to save memory 176 | % if( (m+n)*p > 10^5 && (m>1 || n>1)) 177 | % if( m>n ) 178 | % X1 = X(1:floor(end/2),:); 179 | % X2 = X((floor(end/2)+1):end,:); 180 | % D1 = distEucSq( X1, Y ); 181 | % D2 = distEucSq( X2, Y ); 182 | % D = cat( 1, D1, D2 ); 183 | % else 184 | % Y1 = Y(1:floor(end/2),:); 185 | % Y2 = Y((floor(end/2)+1):end,:); 186 | % D1 = distEucSq( X, Y1 ); 187 | % D2 = distEucSq( X, Y2 ); 188 | % D = cat( 2, D1, D2 ); 189 | % end 190 | % return; 191 | % end 192 | -------------------------------------------------------------------------------- /sinkhornTransport.m: -------------------------------------------------------------------------------- 1 | function [D,L,u,v]=sinkhornTransport(a,b,K,U,lambda,stoppingCriterion,p_norm,tolerance,maxIter,VERBOSE) 2 | % Compute N dual-Sinkhorn divergences (upper bound on the EMD) as well as 3 | % N lower bounds on the EMD for all the pairs 4 | % 5 | % D= [d(a_1,b_1), d(a_2,b_2), ... , d(a_N,b_N)]. 6 | % If needed, the function also outputs diagonal scalings to recover smoothed optimal 7 | % transport between each of the pairs (a_i,b_i). 8 | % 9 | %--------------------------- 10 | % Required Inputs: 11 | %--------------------------- 12 | % a is either 13 | % - a d1 x 1 column vector in the probability simplex (nonnegative, 14 | % summing to one). This is the [1-vs-N mode] 15 | % - a d_1 x N matrix, where each column vector is in the probability simplex 16 | % This is the [N x 1-vs-1 mode] 17 | % 18 | % b is a d2 x N matrix of N vectors in the probability simplex 19 | % 20 | % K is a d1 x d2 matrix, equal to exp(-lambda M), where M is the d1 x d2 21 | % matrix of pairwise distances between bins described in a and bins in the b_1,...b_N histograms. 
22 | % In the most simple case d_1=d_2 and M is simply a distance matrix (zero 23 | % on the diagonal and such that m_ij < m_ik + m_kj 24 | % 25 | % 26 | % U = K.*M is a d1 x d2 matrix, pre-stored to speed up the computation of 27 | % the distances. 28 | % 29 | % 30 | %--------------------------- 31 | % Optional Inputs: 32 | %--------------------------- 33 | % stoppingCriterion in {'marginalDifference','distanceRelativeDecrease'} 34 | % - marginalDifference (Default) : checks whether the difference between 35 | % the marginals of the current optimal transport and the 36 | % theoretical marginals set by a b_1,...,b_N are satisfied. 37 | % - distanceRelativeDecrease : only focus on convergence of the vector 38 | % of distances 39 | % 40 | % p_norm: parameter in {(1,+infty]} used to compute a stoppingCriterion statistic 41 | % from N numbers (these N numbers might be the 1-norm of marginal 42 | % differences or the vector of distances. 43 | % 44 | % tolerance : >0 number to test the stoppingCriterion. 45 | % 46 | % maxIter: maximal number of Sinkhorn fixed point iterations. 47 | % 48 | % verbose: verbose level. 0 by default. 49 | %--------------------------- 50 | % Output 51 | %--------------------------- 52 | % D : vector of N dual-sinkhorn divergences, or upper bounds to the EMD. 53 | % 54 | % L : vector of N lower bounds to the original OT problem, a.k.a EMD. This is computed by using 55 | % the dual variables of the smoothed problem, which, when modified 56 | % adequately, are feasible for the original (non-smoothed) OT dual problem 57 | % 58 | % u : d1 x N matrix of left scalings 59 | % v : d2 x N matrix of right scalings 60 | % 61 | % The smoothed optimal transport between (a_i,b_i) can be recovered as 62 | % T_i = diag(u(:,i)) * K * diag(v(:,i)); 63 | % 64 | % or, equivalently and substantially faster: 65 | % T_i = bsxfun(@times,v(:,i)',(bsxfun(@times,u(:,i),K))) 66 | % 67 | % 68 | % Relevant paper: 69 | % M. Cuturi, 70 | % Sinkhorn Distances : Lightspeed Computation of Optimal Transport, 71 | % Advances in Neural Information Processing Systems (NIPS) 26, 2013 72 | 73 | % This code, (c) Marco Cuturi 2013,2014 (see license block below) 74 | % v0.2b corrected a small bug in the definition of the first scaling 75 | % variable u. 76 | % v0.2 numerous improvements, including possibility to compute 77 | % simultaneously distances between different pairs of points 24/03/14 78 | % v0.1 added lower bound 26/11/13 79 | % v0.0 first version 20/11/2013 80 | 81 | % Change log: 82 | % 28/5/14: The initialization of u was u=ones(length(a),size(b,2))/length(a); which does not 83 | % work when the number of columns of a is larger than the number 84 | % of lines (i.e. more histograms than dimensions). The correct 85 | % initialization must use size(a,1) and not its length. 86 | % 24/3/14: Now possible to compute in parallel D(a_i,b_i) instead of being 87 | % limited to D(a,b_i). More optional inputs and better error checking. 88 | % Removed an unfortunate code error where 2 variables had the same name. 89 | % 90 | % 20/1/14: Another correction at the very end of the script to output weights. 91 | % 92 | % 15/1/14: Correction when outputting l at the very end of the script. replaced size(b) by size(a). 
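%
% Example usage (a hedged sketch, not from the original file; the bins,
% histograms, and parameter values below are made up only to illustrate the
% 1-vs-N calling convention documented above):
%   x = rand(10,2); M = pdist2(x,x,'euclidean');        % pairwise bin costs
%   M = M/median(M(M>0));                                % natural scale for lambda
%   lambda = 10; K = exp(-lambda*M); U = K.*M;
%   a = ones(10,1)/10;                                   % one source histogram
%   b = rand(10,5); b = bsxfun(@rdivide,b,sum(b,1));     % five target histograms
%   [D,L,u,v] = sinkhornTransport(a,b,K,U,lambda);
%   T1 = bsxfun(@times,v(:,1)',bsxfun(@times,u(:,1),K)); % transport for pair 1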
93 | 94 | %% Processing optional inputs 95 | 96 | if nargin<6 || isempty(stoppingCriterion), 97 | stoppingCriterion='marginalDifference'; % check marginal constraints by default 98 | end 99 | 100 | if nargin<7 || isempty(p_norm), 101 | p_norm=inf; 102 | end 103 | 104 | if nargin<8 || isempty(tolerance), 105 | tolerance=.5e-2; 106 | end 107 | 108 | if nargin<9 || isempty(maxIter), 109 | maxIter=5000; 110 | end 111 | 112 | if nargin<10 || isempty(VERBOSE), 113 | VERBOSE=0; 114 | end 115 | 116 | 117 | %% Checking the type of computation: 1-vs-N points or many pairs. 118 | 119 | if size(a,2)==1, 120 | ONE_VS_N=true; % We are computing [D(a,b_1), ... , D(a,b_N)] 121 | elseif size(a,2)==size(b,2), 122 | ONE_VS_N=false; % We are computing [D(a_1,b_1), ... , D(a_N,b_N)] 123 | else 124 | error('The first parameter a is either a column vector in the probability simplex, or N column vectors in the probability simplex where N is size(b,2)'); 125 | end 126 | 127 | %% Checking dimensionality: 128 | if size(b,2)>size(b,1), 129 | BIGN=true; 130 | else 131 | BIGN=false; 132 | end 133 | 134 | 135 | %% Small changes in the 1-vs-N case to go a bit faster. 136 | if ONE_VS_N, % if computing 1-vs-N make sure all components of a are >0. Otherwise we can get rid of some lines of K to go faster. 137 | I=(a>0); 138 | someZeroValues=false; 139 | if ~all(I), % need to update some vectors and matrices if a does not have full support 140 | someZeroValues=true; 141 | K=K(I,:); 142 | U=U(I,:); 143 | a=a(I); 144 | end 145 | ainvK=bsxfun(@rdivide,K,a); % precomputation of this matrix saves a d1 x N Schur product at each iteration. 146 | end 147 | 148 | %% Fixed point counter 149 | compt=0; 150 | 151 | %% Initialization of Left scaling Factors, N column vectors. 152 | u=ones(size(a,1),size(b,2))/size(a,1); 153 | 154 | 155 | if strcmp(stoppingCriterion,'distanceRelativeDecrease') 156 | Dold=ones(1,size(b,2)); %initialization of vector of distances. 157 | end 158 | 159 | 160 | %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% Fixed Point Loop 161 | % The computation below is mostly captured by the repeated iteration of 162 | % u=a./(K*(b./(K'*u))); 163 | % 164 | % In some cases, this iteration can be sped up further when considering a few 165 | % minor tricks (when computing the distances of 1 histogram vs many, 166 | % ONE_VS_N, or when the number of histograms N is larger than the dimension 167 | % of these histograms). 168 | % We consider such cases below. 169 | 170 | 171 | while compt0, 224 | disp(['Iteration :',num2str(compt),' Criterion: ',num2str(Criterion)]); 225 | end 226 | if any(isnan(Criterion)), % stop all computation if a computation of one of the pairs goes wrong. 227 | error('NaN values have appeared during the fixed point iteration. This problem appears because of insufficient machine precision when processing computations with a regularization value of lambda that is too high. Try again with a reduced regularization parameter lambda or with a thresholded metric matrix M.'); 228 | end 229 | end 230 | end 231 | 232 | if strcmp(stoppingCriterion,'marginalDifference'), % if we have been watching marginal differences, we need to compute the vector of distances. 233 | D=sum(u.*(U*v)); 234 | end 235 | 236 | if nargout>1, % user wants lower bounds 237 | alpha = log(u); 238 | beta = log(v); 239 | beta(beta==-inf)=0; % zero values of v (corresponding to zero values in b) generate inf numbers. 
240 | if ONE_VS_N 241 | L= (a'* alpha + sum(b.*beta))/lambda; 242 | else 243 | alpha(alpha==-inf)=0; % zero values of u (corresponding to zero values in a) generate inf numbers. in ONE-VS-ONE mode this never happens. 244 | L= (sum(a.*alpha) + sum(b.*beta))/lambda; 245 | end 246 | 247 | 248 | end 249 | 250 | if nargout>2 && ONE_VS_N && someZeroValues, % user wants scalings. We might have to arficially add zeros again in bins of a that were suppressed. 251 | uu=u; 252 | u=zeros(length(I),size(b,2)); 253 | u(I,:)=uu; 254 | end 255 | 256 | 257 | 258 | % ***** BEGIN LICENSE BLOCK ***** 259 | % * Version: MPL 1.1/GPL 2.0/LGPL 2.1 260 | % * 261 | % * The contents of this file are subject to the Mozilla Public License Version 262 | % * 1.1 (the "License"); you may not use this file except in compliance with 263 | % * the License. You may obtain a copy of the License at 264 | % * http://www.mozilla.org/MPL/ 265 | % * 266 | % * Software distributed under the License is distributed on an "AS IS" basis, 267 | % * WITHOUT WARRANTY OF ANY KIND, either express or implied. See the License 268 | % * for the specific language governing rights and limitations under the 269 | % * License. 270 | % * 271 | % * The Original Code is Sinkhorn Transport, (C) 2013, Marco Cuturi 272 | % * 273 | % * The Initial Developers of the Original Code is 274 | % * 275 | % * Marco Cuturi mcuturi@i.kyoto-u.ac.jp 276 | % * 277 | % * Portions created by the Initial Developers are 278 | % * Copyright (C) 2013 the Initial Developers. All Rights Reserved. 279 | % * 280 | % * 281 | % ***** END LICENSE BLOCK ***** 282 | --------------------------------------------------------------------------------